添加注释

This commit is contained in:
05412 2024-07-18 09:49:01 +08:00
parent 26e229f05f
commit dd9ebbaa65
36 changed files with 414 additions and 48 deletions

View File

@ -9,11 +9,13 @@ import org.springframework.context.ConfigurableApplicationContext
@SpringBootApplication
@EnableConfigurationProperties(BaseConfig::class)
open class SurlApplication {
// 伴生对象,用于获取上下文
companion object {
lateinit var context: ConfigurableApplicationContext
}
}
fun main(args: Array<String>) {
// 启动并获取上下文
SurlApplication.context = runApplication<SurlApplication>(*args)
}

View File

@ -7,13 +7,39 @@ import java.time.temporal.ChronoUnit
import java.util.Date
import javax.crypto.SecretKey
/**
* 基础配置
*/
@ConfigurationProperties(prefix = "base.configs")
class BaseConfig(
/**
* 主站域名/IP
*/
val site: String = "http://127.0.0.1",
val expire: Long = 3600000, // token expire time
/**
* token过期数值
*/
val expire: Long = 3600000,
/**
* token过期单位
*/
val unit: ChronoUnit = ChronoUnit.MILLIS,
/**
* token头
*/
val tokenHead: String = "Bearer ",
/**
* 免认证白名单
*/
whiteList: List<String> = listOf("/login"),
/**
* JWT密钥
*/
secret: String = numberToKey(Date().time).repeat(5),
) {
val secretKey: SecretKey = Keys.hmacShaKeyFor(secret.toByteArray())

View File

@ -3,8 +3,8 @@ package dev.surl.surl.cfg
import org.slf4j.Logger
import org.slf4j.LoggerFactory
/**
* 获取日志对象的扩展函数
*/
@Suppress("UNUSED")
fun <T: Any> T.logger(): Logger = LoggerFactory.getLogger(this::class.java)
@Suppress("UNUSED")
fun logger(name: String): Logger = LoggerFactory.getLogger(name)

View File

@ -1,10 +0,0 @@
package dev.surl.surl.cfg
import org.springframework.context.annotation.Configuration
@Configuration
@Suppress("UNUSED")
open class PatternConfig {
val usernamePattern = Regex("""\w{6,20}""")
val passwordPattern = Regex("""^((?=\S*?[A-Z])(?=\S*?[a-z])(?=\S*?[0-9])(?=\S*?)).{10,}\S$""")
}

View File

@ -2,14 +2,21 @@ package dev.surl.surl.cfg
import org.springframework.context.annotation.Bean
import org.springframework.data.redis.connection.RedisConnectionFactory
import org.springframework.data.redis.core.RedisTemplate
import org.springframework.data.redis.core.StringRedisTemplate
import org.springframework.stereotype.Component
/**
* Redis配置类
*/
@Component
class RedisConfig {
/**
* 默认RedisTemplate
*/
@Bean
fun baseRedis(factory: RedisConnectionFactory): StringRedisTemplate {
fun baseRedis(factory: RedisConnectionFactory): RedisTemplate<String, String> {
return StringRedisTemplate(factory)
}
}

View File

@ -2,7 +2,6 @@ package dev.surl.surl.cfg.security
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import org.springframework.security.crypto.bcrypt.BCrypt
import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder
@Configuration
@ -15,9 +14,4 @@ open class EncoderConfig {
open fun passwordEncoder(): BCryptPasswordEncoder {
return BCryptPasswordEncoder(BCryptPasswordEncoder.BCryptVersion.`$2B`)
}
@Bean
open fun cryoto(): BCrypt {
return BCrypt()
}
}

View File

@ -9,6 +9,9 @@ import org.springframework.security.config.annotation.web.configuration.EnableWe
import org.springframework.security.web.SecurityFilterChain
import org.springframework.security.config.annotation.web.invoke
/**
* 网安配置
*/
@Configuration
@EnableWebSecurity
open class WebSecurityConfig {
@ -22,10 +25,10 @@ open class WebSecurityConfig {
response: HttpServletResponse): SecurityFilterChain {
http {
csrf { disable() } // 关闭csrf
formLogin { disable() }
httpBasic { disable() }
formLogin { disable() } // 关闭表单登录
httpBasic { disable() } // 关闭basic认证
authorizeHttpRequests {
authorize(anyRequest, permitAll)
authorize(anyRequest, permitAll) // 放行所有请求
}
headers {
cacheControl { } // 禁用缓存

View File

@ -1,5 +1,8 @@
package dev.surl.surl.common
/**
* 用户权限枚举
*/
enum class Access {
ADMIN, READ, WRITE
}

View File

@ -1,5 +1,8 @@
package dev.surl.surl.common
/**
* 通用接口返回格式
*/
data class Msg<T>(
val code: Int = 0, val msg: String? = null, val value: T? = null
)

View File

@ -1,5 +1,8 @@
package dev.surl.surl.common.enums
/**
* Redis存储的前缀
*/
enum class RedisStorage {
TOKEN
}

View File

@ -1,3 +1,6 @@
package dev.surl.surl.common.exception
/**
* 自定义权限异常
*/
class UnauthorizedExcecption(message: String? = null, cause: Throwable? = null) : Exception(message, cause)

View File

@ -1,3 +1,6 @@
package dev.surl.surl.common.exception
/**
* 自定义注册异常
*/
class UserRegistException(message: String? = null, cause: Throwable? = null): Exception(message, cause)

View File

@ -12,20 +12,30 @@ import org.springframework.web.bind.annotation.PathVariable
import org.springframework.web.bind.annotation.RestController
import java.net.URI
/**
* 短链接跳转控制器
*/
@RestController
class RedirectController(private val service: SurlService) {
/**
* 短链接跳转
*/
@GetMapping("/{key}")
fun redirect(
@PathVariable
@Valid
@Length(min = 1, max = 11, message = "Key length is not valid")
@Pattern(regexp = "[\\w!*().\\-_~]+", message = "Key format is not valid")
key: String
@PathVariable @Valid @Length(
min = 1,
max = 11,
message = "Key length is not valid"
) @Pattern(regexp = "[\\w!*().\\-_~]+", message = "Key format is not valid") key: String
): ResponseEntity<Any> {
// 根据key获取原始链接
val redirectUrl = service.getUrlByKey(key)
return if(redirectUrl.isBlank()) {
return if (redirectUrl.isBlank()) {
// 未找到,返回异常信息
ResponseEntity(Msg<String>(code = -1, msg = "key `$key` not found"), HttpStatus.NOT_FOUND)
} else {
// 找到发送302跳转
ResponseEntity.status(302).location(URI.create(redirectUrl)).build()
}
}

View File

@ -4,14 +4,40 @@ import dev.surl.surl.cfg.BaseConfig
import dev.surl.surl.common.Msg
import dev.surl.surl.dto.SurlDto
import dev.surl.surl.service.SurlService
import dev.surl.surl.util.JwtTokenUtil
import jakarta.validation.Valid
import org.springframework.http.HttpHeaders
import org.springframework.web.bind.annotation.PostMapping
import org.springframework.web.bind.annotation.RequestBody
import org.springframework.web.bind.annotation.RequestHeader
import org.springframework.web.bind.annotation.RestController
/**
* 短链接新增控制器
*/
@RestController
class SurlAddController(private val service: SurlService, private val cfg: BaseConfig) {
class SurlAddController(
private val service: SurlService, private val cfg: BaseConfig, private val jwtTokenUtil: JwtTokenUtil
) {
/**
* 短链接新增
*/
@PostMapping("/api/surl/add")
fun addSurl(@Valid @RequestBody body: SurlDto) =
Msg(code = 0, value = "${cfg.site}/${service.addSurl(body.url ?: "")}")
fun addSurl(@RequestHeader headers: HttpHeaders, @Valid @RequestBody body: SurlDto): Msg<String> {
// 从认证头获取用户名
val username = jwtTokenUtil.getUsernameFromHeader(headers)
// 获取主站域名/IP
val site = cfg.site
// 添加短链接
val key = service.addSurl(body.url ?: "", username)
// 拼接短链接
val url = "$site/$key"
return Msg(code = 0, value = url)
}
}

View File

@ -8,16 +8,25 @@ import org.springframework.web.bind.annotation.GetMapping
import org.springframework.web.bind.annotation.RequestHeader
import org.springframework.web.bind.annotation.RestController
/**
* 获取用户名下短链接列表控制器
*/
@RestController
class SurlGetController(
private val surlService: SurlService,
private val jwtTokenUtil: JwtTokenUtil
private val surlService: SurlService, private val jwtTokenUtil: JwtTokenUtil
) {
/**
* 获取用户名下短链接列表
*/
@GetMapping(path = ["/api/surl/get"])
fun getUrlsByUser(@RequestHeader headers: HttpHeaders): Msg<List<String>> {
val token = jwtTokenUtil.getTokenFromHeader(headers[HttpHeaders.AUTHORIZATION]?.last() ?: "")
val username = jwtTokenUtil.getUsernameFromToken(token)
// 从认证头获取用户名
val username = jwtTokenUtil.getUsernameFromHeader(headers)
// 获取用户名下短链接列表
val urls = surlService.getUrlsByUser(username)
return Msg(value = urls)
}
}

View File

@ -9,10 +9,16 @@ import org.springframework.web.bind.annotation.RequestMapping
import org.springframework.web.bind.annotation.RequestMethod
import org.springframework.web.bind.annotation.RestController
/**
* 用户操作控制器
*/
@RestController
class UserController {
/**
* 用户注册
*/
@RequestMapping(method = [RequestMethod.POST], path = ["/reg"])
fun reg(
@Autowired service: UserService, @Valid @RequestBody(required = true) user: UserDto
) = service.addUser(user.username!!, user.password!!)
) = service.addUser(user.username!!, user.password!!) // 新增用户
}

View File

@ -5,9 +5,24 @@ import org.jetbrains.exposed.dao.Entity
import org.jetbrains.exposed.dao.EntityClass
import org.jetbrains.exposed.dao.id.EntityID
/**
* 短链接实体类
*/
@Suppress("UNUSED")
class Surl(id: EntityID<Long>): Entity<Long>(id) {
/**
* 短链接实体类伴生对象用于crud操作
*/
companion object: EntityClass<Long, Surl>(Surls)
/**
* 短链接url
*/
var url by Surls.url
/**
* 短链接所属用户
*/
var user by User optionalReferencedOn Surls.user
}

View File

@ -5,8 +5,23 @@ import org.jetbrains.exposed.dao.LongEntity
import org.jetbrains.exposed.dao.LongEntityClass
import org.jetbrains.exposed.dao.id.EntityID
/**
* 用户实体
*/
class User(id: EntityID<Long>): LongEntity(id) {
/**
* 用户实体伴生对象用于CRUD操作
*/
companion object EntityClass: LongEntityClass<User>(Users)
/**
* 用户名
*/
var username by Users.username
/**
* 密码
*/
var password by Users.password
}

View File

@ -6,12 +6,27 @@ import org.jetbrains.exposed.dao.LongEntity
import org.jetbrains.exposed.dao.LongEntityClass
import org.jetbrains.exposed.dao.id.EntityID
/**
* 用户权限实体类
*/
class UserAccess(id: EntityID<Long>): LongEntity(id) {
/**
* 伴生对象用于CRUD操作
*/
companion object EntityClass: LongEntityClass<UserAccess>(UserAccesses)
/**
* 权限枚举类型自动转换为数据库存储的整数
*/
var access by UserAccesses.access.transform(toColumn = {
it.ordinal.toShort()
}, toReal = {
Access.entries[it.toInt()]
})
/**
* 用户
*/
var user by User referencedOn UserAccesses.user
}

View File

@ -2,6 +2,9 @@ package dev.surl.surl.dsl
import org.jetbrains.exposed.dao.id.IdTable
/**
* 短链接表
*/
object Surls: IdTable<Long>("surl") {
override val id = long("id").entityId()
val url = varchar("url", 2048)

View File

@ -2,6 +2,9 @@ package dev.surl.surl.dsl
import org.jetbrains.exposed.dao.id.IdTable
/**
* 用户权限表
*/
object UserAccesses: IdTable<Long>("user_access") {
override val id = long("id").entityId()
val user = reference("user", Users).index()

View File

@ -2,6 +2,9 @@ package dev.surl.surl.dsl
import org.jetbrains.exposed.dao.id.IdTable
/**
* 用户权限表
*/
object Users: IdTable<Long>("users") {
override val id = long("id").entityId()
val username = varchar("username", 256).uniqueIndex()

View File

@ -4,6 +4,9 @@ import com.fasterxml.jackson.annotation.JsonProperty
import jakarta.validation.constraints.NotNull
import org.hibernate.validator.constraints.Length
/**
* 短链接新增请求体
*/
data class SurlDto(
@JsonProperty("url")
@get:NotNull(message = "url cannot be empty")

View File

@ -4,6 +4,9 @@ import com.fasterxml.jackson.annotation.JsonProperty
import jakarta.validation.constraints.NotNull
import org.hibernate.validator.constraints.Length
/**
* 用户信息请求体
*/
data class UserDto (
@JsonProperty("username")
@get:Length(max = 16, min = 4, message = "username length must be between 4 and 16")

View File

@ -14,6 +14,9 @@ import org.springframework.http.HttpHeaders
import org.springframework.stereotype.Component
import org.springframework.web.filter.OncePerRequestFilter
/**
* JWT认证过滤器
*/
@Component
class JwtAuthenticationTokenFilter(
private val jwtTokenUtil: JwtTokenUtil,
@ -26,8 +29,10 @@ class JwtAuthenticationTokenFilter(
response: HttpServletResponse,
filterChain: FilterChain
) {
// 检查请求路径是否在白名单内
if (request.servletPath notMatchedIn cfg.whiteList) {
try {
// 验证token
val exp = UnauthorizedExcecption("unauthorized")
val authHeader = request.getHeader(HttpHeaders.AUTHORIZATION) ?: throw exp
val token = jwtTokenUtil.getTokenFromHeader(authHeader)
@ -38,8 +43,10 @@ class JwtAuthenticationTokenFilter(
throw exp
}
}
// redis缓存内检查不到已存在token拒绝认证抛出异常
if (cachedToken != token) throw exp
} catch (e: UnauthorizedExcecption) {
// 认证失败
response.status = HttpServletResponse.SC_UNAUTHORIZED
val responseBody = om.writeValueAsString(Msg<String>(code = -1, msg = e.message))
response.writer.run {
@ -49,9 +56,13 @@ class JwtAuthenticationTokenFilter(
return
}
}
// 认证成功
filterChain.doFilter(request, response)
}
/**
* 判断字符串是否匹配正则列表
*/
private infix fun String.matchedIn(regexes: List<Regex>): Boolean {
for (regex in regexes) {
if (this.matches(regex)) return true
@ -59,6 +70,9 @@ class JwtAuthenticationTokenFilter(
return false
}
/**
* 判断字符串是否不匹配正则列表
*/
private infix fun String.notMatchedIn(regexes: List<Regex>): Boolean {
return !(this matchedIn regexes)
}

View File

@ -21,6 +21,9 @@ import org.springframework.security.web.authentication.UsernamePasswordAuthentic
import org.springframework.stereotype.Component
import java.nio.charset.StandardCharsets
/**
* 登录过滤器
*/
@Component
class UsernamePasswordAuthenticationCheckFilter(
private val om: ObjectMapper,
@ -31,15 +34,21 @@ class UsernamePasswordAuthenticationCheckFilter(
) : UsernamePasswordAuthenticationFilter() {
init {
// 设置登录地址
setFilterProcessesUrl("/login")
authenticationManager = AuthenticationManager { it }
}
/**
* 尝试登录
*/
override fun attemptAuthentication(request: HttpServletRequest?, response: HttpServletResponse?): Authentication {
request ?: throw IllegalArgumentException("request is null")
val userDto = request.run {
om.readValue(String(inputStream.readAllBytes(), StandardCharsets.UTF_8), UserDto::class.java)
}
// 尝试验证登录信息
try {
validate(userDto, validator)
} catch (e: ConstraintViolationException) {
@ -55,6 +64,9 @@ class UsernamePasswordAuthenticationCheckFilter(
)
}
/**
* 登录成功生成并返回token
*/
override fun successfulAuthentication(
request: HttpServletRequest?, response: HttpServletResponse?, chain: FilterChain?, authResult: Authentication?
) {
@ -71,6 +83,9 @@ class UsernamePasswordAuthenticationCheckFilter(
}
}
/**
* 登录失败 返回错误信息
*/
override fun unsuccessfulAuthentication(
request: HttpServletRequest?, response: HttpServletResponse?, failed: AuthenticationException?
) {

View File

@ -6,6 +6,9 @@ import org.springframework.security.access.AccessDeniedException
import org.springframework.security.web.access.AccessDeniedHandler
import org.springframework.web.bind.annotation.ControllerAdvice
/**
* 访问权限异常处理器
*/
@ControllerAdvice
class AccessHandler: AccessDeniedHandler {
override fun handle(
@ -13,6 +16,7 @@ class AccessHandler: AccessDeniedHandler {
response: HttpServletResponse?,
accessDeniedException: AccessDeniedException?
) {
// 跳转登录页
response?.sendRedirect("/login")
}
}

View File

@ -16,14 +16,24 @@ import org.springframework.web.context.request.WebRequest
import org.springframework.web.method.annotation.HandlerMethodValidationException
import org.springframework.web.servlet.mvc.method.annotation.ResponseEntityExceptionHandler
/**
* 自定义异常处理
*/
@ControllerAdvice
class DefaultExceptionHandler : ResponseEntityExceptionHandler() {
/**
* 处理方法参数校验异常
*/
override fun handleMethodValidationException(
ex: MethodValidationException, headers: HttpHeaders, status: HttpStatus, request: WebRequest
): ResponseEntity<Any> {
return ResponseEntity(Msg<String>(code = -1, msg = ex.allValidationResults.joinToString(";")), status)
}
/**
* 处理方法参数校验异常
*/
override fun handleHandlerMethodValidationException(
ex: HandlerMethodValidationException,
headers: HttpHeaders,
@ -38,6 +48,9 @@ class DefaultExceptionHandler : ResponseEntityExceptionHandler() {
}), status)
}
/**
* 处理方法参数校验异常
*/
override fun handleMethodArgumentNotValid(
ex: MethodArgumentNotValidException, headers: HttpHeaders, status: HttpStatusCode, request: WebRequest
): ResponseEntity<Any> {
@ -48,12 +61,18 @@ class DefaultExceptionHandler : ResponseEntityExceptionHandler() {
)
}
/**
* 处理请求体解析异常
*/
override fun handleHttpMessageNotReadable(
ex: HttpMessageNotReadableException, headers: HttpHeaders, status: HttpStatusCode, request: WebRequest
): ResponseEntity<Any> {
return ResponseEntity(Msg<String>(code = -1, msg = ex.message ?: "unknown error"), status)
}
/**
* 处理其他异常
*/
@ExceptionHandler(value = [IllegalStateException::class, Exception::class])
fun handleException(
ex: Exception
@ -61,12 +80,18 @@ class DefaultExceptionHandler : ResponseEntityExceptionHandler() {
return ResponseEntity(Msg(code = -1, msg = ex.message ?: "unknown error"), HttpStatus.INTERNAL_SERVER_ERROR)
}
/**
* 处理用户注册异常
*/
@ExceptionHandler(value = [UserRegistException::class])
fun handleUserRegistException(ex: Exception
): ResponseEntity<Msg<String>>{
return ResponseEntity(Msg(code = -1, msg = ex.message ?: "unknown regist error"), HttpStatus.BAD_REQUEST)
}
/**
* 处理校验异常
*/
@ExceptionHandler(value = [ConstraintViolationException::class])
fun handleConstraintViolationException(ex: Exception): ResponseEntity<Msg<String>> {
return ResponseEntity(Msg(code = -1, msg = ex.message ?: "unknown validation error"), HttpStatus.BAD_REQUEST)

View File

@ -8,19 +8,35 @@ import org.jetbrains.exposed.sql.batchInsert
import org.jetbrains.exposed.sql.transactions.transaction
import org.springframework.stereotype.Service
/**
* 短链接服务
*/
@Service
class SurlService {
private val userService: UserService by autowired()
fun addSurl(baseurl: String): String = runBlocking {
/**
* 添加短链接
* @param baseurl 原始链接
* @param username 用户名
*/
fun addSurl(baseurl: String, username: String): String = runBlocking {
// 使用雪花算法生成id
val id = genSnowflakeUID()
transaction {
Surl.new(id) {
url = baseurl
user = userService.getUserByUsername(username)
}
}
// 返回id转换后的生成的key
numberToKey(id)
}
/**
* 批量添加短链接
* @param baseurls 原始链接列表
*/
fun batchAddSurl(baseurls: List<String>) = transaction {
Surls.batchInsert(baseurls, shouldReturnGeneratedValues = false) {
this[Surls.url] = it
@ -28,6 +44,10 @@ class SurlService {
}
}
/**
* 根据key获取原始链接
* @param key 短链接key
*/
fun getUrlByKey(key: String): String {
return transaction {
Surls.select(Surls.url).where {
@ -35,11 +55,16 @@ class SurlService {
}.firstOrNull()?.get(Surls.url) ?: ""
}
}
/**
* 根据用户名获取短链接列表
* @param username 用户名
*/
fun getUrlsByUser(username: String): List<String> {
val user = userService.getUserByUsername(username) ?: return emptyList()
return transaction {
Surl.find {
Surls.id eq user.id
Surls.user eq user.id
}.map {
it.url
}

View File

@ -20,15 +20,26 @@ import org.springframework.stereotype.Service
typealias AUser = org.springframework.security.core.userdetails.User
/**
* 用户服务
*/
@Service
class UserService: UserDetailsService {
private val passwordEncoder: BCryptPasswordEncoder by autowired()
private val validator: Validator by autowired()
/**
* 注册用户
* @param username 用户名
* @param password 密码
* @return 注册成功返回用户id和用户名
*/
fun addUser(username: String, password: String): Msg<Map<String, String>> {
val (id, accessId) = runBlocking {
Pair(genSnowflakeUID(), genSnowflakeUID())
val id = runBlocking {
genSnowflakeUID()
}
// 密码加密
val encryptedPassword = passwordEncoder.encode(password)
transaction {
if (isUserExist(username)) {
@ -38,7 +49,8 @@ class UserService: UserDetailsService {
this.username = username
this.password = encryptedPassword
}
addDefaultAccess(accessId, user)
// 添加默认权限
addDefaultAccess(user)
}
return Msg(value = mapOf(
"id" to numberToKey(id),
@ -46,6 +58,11 @@ class UserService: UserDetailsService {
))
}
/**
* 根据用户名获取用户信息
* @param username 用户名
* @return 用户信息
*/
fun getUserByUsername(username: String): User? {
return transaction {
User.find {
@ -54,23 +71,43 @@ class UserService: UserDetailsService {
}
}
/**
* 判断用户是否存在
* @param username 用户名
* @return 用户是否存在
*/
private fun isUserExist(username: String) = !User.find {
Users.username eq username
}.empty()
private fun addDefaultAccess(id: Long, user: User) {
/**
* 添加默认权限
* @param user 用户
*/
private fun addDefaultAccess(user: User) {
val id = runBlocking { genSnowflakeUID() }
UserAccess.new(id) {
this.access = Access.READ
this.user = user
}
}
/**
* 验证用户密码
* @param userDto 用户信息
* @return 验证结果
*/
fun authUser(userDto: UserDto):Boolean {
validate(userDto, validator)
val user = getUserByUsername(userDto.username!!) ?: throw UsernameNotFoundException("user `${userDto.username}` not found")
return passwordEncoder.matches(userDto.password!!, user.password)
}
/**
* 根据用户名获取用户信息
* @param username 用户名
* @return 用户信息
*/
override fun loadUserByUsername(username: String): UserDetails {
val user = getUserByUsername(username) ?: throw UsernameNotFoundException("user '$username' not found")
return AUser.builder().apply {

View File

@ -4,6 +4,9 @@ import dev.surl.surl.SurlApplication
import kotlin.reflect.KClass
import kotlin.reflect.KProperty
/**
* 注入代理类
*/
class Autowired<T : Any>(private val type: KClass<T>, private val name: String?) {
private val value: T by lazy {
if (name == null) {
@ -14,4 +17,8 @@ class Autowired<T : Any>(private val type: KClass<T>, private val name: String?)
}
operator fun getValue(thisRef: Any?, property: KProperty<*>): T = value
}
/**
* 注入代理器
*/
inline fun <reified T : Any> autowired(name: String? = null) = Autowired(T::class, name)

View File

@ -3,19 +3,29 @@ package dev.surl.surl.util
import dev.surl.surl.cfg.BaseConfig
import io.jsonwebtoken.Claims
import io.jsonwebtoken.Jwts
import org.springframework.http.HttpHeaders
import org.springframework.oxm.ValidationFailureException
import org.springframework.stereotype.Component
import java.time.LocalDateTime
import java.time.ZoneId
import java.util.Date
/**
* JWT token 工具类
*/
@Component
class JwtTokenUtil(private val cfg: BaseConfig) {
fun getToken(identityId: String, authorizes: List<String>): Pair<Date, String> {
/**
* 生成token
* @param username 用户名
* @param authorizes 用户角色
* @return Pair<Date, String> token过期时间token
*/
fun getToken(username: String, authorizes: List<String>): Pair<Date, String> {
val now = LocalDateTime.now()
val expireAt = Date.from(now.plus(cfg.expire, cfg.unit).atZone(ZoneId.systemDefault()).toInstant())
val token = Jwts.builder().run {
subject(identityId)
subject(username)
issuedAt(Date())
expiration(expireAt)
signWith(cfg.secretKey)
@ -25,24 +35,55 @@ class JwtTokenUtil(private val cfg: BaseConfig) {
return Pair(expireAt, token)
}
/**
* 获取token信息
* @param token token
* @return Claims
*/
private fun getTokenClaim(token: String): Claims? {
return Jwts.parser().verifyWith(cfg.secretKey).build().parseSignedClaims(token).payload
}
/**
* 从token中获取用户名
* @param token token
* @return 用户名
*/
fun getUsernameFromToken(token: String): String {
return getClaimFromToken(token) {
it?.subject
} ?: throw ValidationFailureException("invalid token, userinfo not found")
}
/**
* 从token中获取claims
* @param token token
* @param resolver 解析器
* @return T 解析结果
*/
private fun <T> getClaimFromToken(token: String, resolver: (Claims?) -> T): T {
val claims = getTokenClaim(token)
return resolver(claims)
}
/**
* 从header中获取token
* @param header header
* @return token
*/
fun getTokenFromHeader(header: String): String {
return if (header.startsWith(cfg.tokenHead)) {
header.substring(cfg.tokenHead.length)
} else throw ValidationFailureException("invalid token")
}
/**
* 从header中获取用户名
* @param headers headers
* @return 用户名
*/
fun getUsernameFromHeader(headers: HttpHeaders): String {
val token = getTokenFromHeader(headers[HttpHeaders.AUTHORIZATION]?.last() ?: "")
return getUsernameFromToken(token)
}
}

View File

@ -3,9 +3,16 @@ package dev.surl.surl.util
import org.noelware.charted.snowflake.Snowflake
import kotlin.math.pow
// 70进制映射
private val CHARS = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!*().-_~".toCharArray()
// 雪花ID生成器
private val snowflake = Snowflake()
/**
* 将数字转换为key70进制
* @param number number
*/
fun numberToKey(number: Long): String {
if(number == 0L) throw Exception("serial number cannot be zero")
var num = number
@ -18,6 +25,10 @@ fun numberToKey(number: Long): String {
return sb.reverse().toString()
}
/**
* 将key转换为10进制数字
* @param key key
*/
fun keyToNumber(key: String): Long {
var sum = 0L
for(i in key.indices) {
@ -28,6 +39,9 @@ fun keyToNumber(key: String): Long {
return sum
}
/**
* 生成雪花ID
*/
suspend fun genSnowflakeUID(): Long {
return snowflake.generate().value
}

View File

@ -3,6 +3,11 @@ package dev.surl.surl.util
import jakarta.validation.ConstraintViolationException
import jakarta.validation.Validator
/**
* 请求体验证
* @param dto 待验证的dto
* @param validator 验证器
*/
fun <T: Any?> validate(dto: T,validator: Validator) {
if(dto == null) throw IllegalArgumentException("dto for validation is null")
val violations = validator.validate(dto)

View File

@ -7,10 +7,20 @@ import org.springframework.stereotype.Component
import java.time.temporal.ChronoUnit
import java.util.concurrent.TimeUnit
/**
* Redis工具类
*/
@Suppress("UNUSED")
@Component
class RedisUtil(private val template: StringRedisTemplate, private val cfg: BaseConfig) {
private val ops = template.opsForValue()
/**
* 获取字符串
* @param key
* @param type 存储类型
* @return 字符串
*/
fun getString(key: String, type: RedisStorage? = null): String? {
if (type == null) {
return ops.get(key)
@ -18,6 +28,12 @@ class RedisUtil(private val template: StringRedisTemplate, private val cfg: Base
return ops.get("${type.name}_$key")
}
/**
* 设置字符串
* @param key
* @param value
* @param type 存储类型
*/
fun setString(key: String, value: String, type: RedisStorage? = null) {
if (type == null) {
ops.set(key, value, cfg.expire, chronoUnitToTimeUnit(cfg.unit))
@ -26,6 +42,11 @@ class RedisUtil(private val template: StringRedisTemplate, private val cfg: Base
}
}
/**
* 删除键
* @param key
* @param type 存储类型
*/
fun delKey(key: String, type: RedisStorage? = null) {
if (type == null) {
ops.operations.delete(key)
@ -34,12 +55,20 @@ class RedisUtil(private val template: StringRedisTemplate, private val cfg: Base
ops.operations.delete("${type.name}_$key")
}
/**
* 清空数据库
*/
fun flushdb() {
template.execute {
it.serverCommands().flushDb()
}
}
/**
* 将ChronoUnit转换为TimeUnit
* @param unit ChronoUnit
* @return TimeUnit
*/
private fun chronoUnitToTimeUnit(unit: ChronoUnit): TimeUnit {
return when (unit) {
ChronoUnit.MILLIS -> TimeUnit.MILLISECONDS

View File

@ -24,6 +24,7 @@ spring:
time-zone: Asia/Shanghai
serialization:
indent-output: true
date-format: yyyy-MM-dd HH:mm:ss.SSS
data:
redis:
host: localhost
@ -45,7 +46,8 @@ logging:
base:
configs:
site: http://127.0.0.1:18888
expire: 3600000
expire: 6
unit: hours
secret: Is#45Ddw29apkbHawwaHb4d^&w29apkbHawwaHb4d^&
white-list:
- ^/login$