添加注释

This commit is contained in:
05412 2024-07-18 09:49:01 +08:00
parent 26e229f05f
commit dd9ebbaa65
36 changed files with 414 additions and 48 deletions

View File

@ -9,11 +9,13 @@ import org.springframework.context.ConfigurableApplicationContext
@SpringBootApplication @SpringBootApplication
@EnableConfigurationProperties(BaseConfig::class) @EnableConfigurationProperties(BaseConfig::class)
open class SurlApplication { open class SurlApplication {
// 伴生对象,用于获取上下文
companion object { companion object {
lateinit var context: ConfigurableApplicationContext lateinit var context: ConfigurableApplicationContext
} }
} }
fun main(args: Array<String>) { fun main(args: Array<String>) {
// 启动并获取上下文
SurlApplication.context = runApplication<SurlApplication>(*args) SurlApplication.context = runApplication<SurlApplication>(*args)
} }

View File

@ -7,13 +7,39 @@ import java.time.temporal.ChronoUnit
import java.util.Date import java.util.Date
import javax.crypto.SecretKey import javax.crypto.SecretKey
/**
* 基础配置
*/
@ConfigurationProperties(prefix = "base.configs") @ConfigurationProperties(prefix = "base.configs")
class BaseConfig( class BaseConfig(
/**
* 主站域名/IP
*/
val site: String = "http://127.0.0.1", val site: String = "http://127.0.0.1",
val expire: Long = 3600000, // token expire time
/**
* token过期数值
*/
val expire: Long = 3600000,
/**
* token过期单位
*/
val unit: ChronoUnit = ChronoUnit.MILLIS, val unit: ChronoUnit = ChronoUnit.MILLIS,
/**
* token头
*/
val tokenHead: String = "Bearer ", val tokenHead: String = "Bearer ",
/**
* 免认证白名单
*/
whiteList: List<String> = listOf("/login"), whiteList: List<String> = listOf("/login"),
/**
* JWT密钥
*/
secret: String = numberToKey(Date().time).repeat(5), secret: String = numberToKey(Date().time).repeat(5),
) { ) {
val secretKey: SecretKey = Keys.hmacShaKeyFor(secret.toByteArray()) val secretKey: SecretKey = Keys.hmacShaKeyFor(secret.toByteArray())

View File

@ -3,8 +3,8 @@ package dev.surl.surl.cfg
import org.slf4j.Logger import org.slf4j.Logger
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
/**
* 获取日志对象的扩展函数
*/
@Suppress("UNUSED") @Suppress("UNUSED")
fun <T: Any> T.logger(): Logger = LoggerFactory.getLogger(this::class.java) fun <T: Any> T.logger(): Logger = LoggerFactory.getLogger(this::class.java)
@Suppress("UNUSED")
fun logger(name: String): Logger = LoggerFactory.getLogger(name)

View File

@ -1,10 +0,0 @@
package dev.surl.surl.cfg
import org.springframework.context.annotation.Configuration
@Configuration
@Suppress("UNUSED")
open class PatternConfig {
val usernamePattern = Regex("""\w{6,20}""")
val passwordPattern = Regex("""^((?=\S*?[A-Z])(?=\S*?[a-z])(?=\S*?[0-9])(?=\S*?)).{10,}\S$""")
}

View File

@ -2,14 +2,21 @@ package dev.surl.surl.cfg
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
import org.springframework.data.redis.connection.RedisConnectionFactory import org.springframework.data.redis.connection.RedisConnectionFactory
import org.springframework.data.redis.core.RedisTemplate
import org.springframework.data.redis.core.StringRedisTemplate import org.springframework.data.redis.core.StringRedisTemplate
import org.springframework.stereotype.Component import org.springframework.stereotype.Component
/**
* Redis配置类
*/
@Component @Component
class RedisConfig { class RedisConfig {
/**
* 默认RedisTemplate
*/
@Bean @Bean
fun baseRedis(factory: RedisConnectionFactory): StringRedisTemplate { fun baseRedis(factory: RedisConnectionFactory): RedisTemplate<String, String> {
return StringRedisTemplate(factory) return StringRedisTemplate(factory)
} }
} }

View File

@ -2,7 +2,6 @@ package dev.surl.surl.cfg.security
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration import org.springframework.context.annotation.Configuration
import org.springframework.security.crypto.bcrypt.BCrypt
import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder
@Configuration @Configuration
@ -15,9 +14,4 @@ open class EncoderConfig {
open fun passwordEncoder(): BCryptPasswordEncoder { open fun passwordEncoder(): BCryptPasswordEncoder {
return BCryptPasswordEncoder(BCryptPasswordEncoder.BCryptVersion.`$2B`) return BCryptPasswordEncoder(BCryptPasswordEncoder.BCryptVersion.`$2B`)
} }
@Bean
open fun cryoto(): BCrypt {
return BCrypt()
}
} }

View File

@ -9,6 +9,9 @@ import org.springframework.security.config.annotation.web.configuration.EnableWe
import org.springframework.security.web.SecurityFilterChain import org.springframework.security.web.SecurityFilterChain
import org.springframework.security.config.annotation.web.invoke import org.springframework.security.config.annotation.web.invoke
/**
* 网安配置
*/
@Configuration @Configuration
@EnableWebSecurity @EnableWebSecurity
open class WebSecurityConfig { open class WebSecurityConfig {
@ -22,10 +25,10 @@ open class WebSecurityConfig {
response: HttpServletResponse): SecurityFilterChain { response: HttpServletResponse): SecurityFilterChain {
http { http {
csrf { disable() } // 关闭csrf csrf { disable() } // 关闭csrf
formLogin { disable() } formLogin { disable() } // 关闭表单登录
httpBasic { disable() } httpBasic { disable() } // 关闭basic认证
authorizeHttpRequests { authorizeHttpRequests {
authorize(anyRequest, permitAll) authorize(anyRequest, permitAll) // 放行所有请求
} }
headers { headers {
cacheControl { } // 禁用缓存 cacheControl { } // 禁用缓存

View File

@ -1,5 +1,8 @@
package dev.surl.surl.common package dev.surl.surl.common
/**
* 用户权限枚举
*/
enum class Access { enum class Access {
ADMIN, READ, WRITE ADMIN, READ, WRITE
} }

View File

@ -1,5 +1,8 @@
package dev.surl.surl.common package dev.surl.surl.common
/**
* 通用接口返回格式
*/
data class Msg<T>( data class Msg<T>(
val code: Int = 0, val msg: String? = null, val value: T? = null val code: Int = 0, val msg: String? = null, val value: T? = null
) )

View File

@ -1,5 +1,8 @@
package dev.surl.surl.common.enums package dev.surl.surl.common.enums
/**
* Redis存储的前缀
*/
enum class RedisStorage { enum class RedisStorage {
TOKEN TOKEN
} }

View File

@ -1,3 +1,6 @@
package dev.surl.surl.common.exception package dev.surl.surl.common.exception
/**
* 自定义权限异常
*/
class UnauthorizedExcecption(message: String? = null, cause: Throwable? = null) : Exception(message, cause) class UnauthorizedExcecption(message: String? = null, cause: Throwable? = null) : Exception(message, cause)

View File

@ -1,3 +1,6 @@
package dev.surl.surl.common.exception package dev.surl.surl.common.exception
/**
* 自定义注册异常
*/
class UserRegistException(message: String? = null, cause: Throwable? = null): Exception(message, cause) class UserRegistException(message: String? = null, cause: Throwable? = null): Exception(message, cause)

View File

@ -12,20 +12,30 @@ import org.springframework.web.bind.annotation.PathVariable
import org.springframework.web.bind.annotation.RestController import org.springframework.web.bind.annotation.RestController
import java.net.URI import java.net.URI
/**
* 短链接跳转控制器
*/
@RestController @RestController
class RedirectController(private val service: SurlService) { class RedirectController(private val service: SurlService) {
/**
* 短链接跳转
*/
@GetMapping("/{key}") @GetMapping("/{key}")
fun redirect( fun redirect(
@PathVariable @PathVariable @Valid @Length(
@Valid min = 1,
@Length(min = 1, max = 11, message = "Key length is not valid") max = 11,
@Pattern(regexp = "[\\w!*().\\-_~]+", message = "Key format is not valid") message = "Key length is not valid"
key: String ) @Pattern(regexp = "[\\w!*().\\-_~]+", message = "Key format is not valid") key: String
): ResponseEntity<Any> { ): ResponseEntity<Any> {
// 根据key获取原始链接
val redirectUrl = service.getUrlByKey(key) val redirectUrl = service.getUrlByKey(key)
return if (redirectUrl.isBlank()) { return if (redirectUrl.isBlank()) {
// 未找到,返回异常信息
ResponseEntity(Msg<String>(code = -1, msg = "key `$key` not found"), HttpStatus.NOT_FOUND) ResponseEntity(Msg<String>(code = -1, msg = "key `$key` not found"), HttpStatus.NOT_FOUND)
} else { } else {
// 找到发送302跳转
ResponseEntity.status(302).location(URI.create(redirectUrl)).build() ResponseEntity.status(302).location(URI.create(redirectUrl)).build()
} }
} }

View File

@ -4,14 +4,40 @@ import dev.surl.surl.cfg.BaseConfig
import dev.surl.surl.common.Msg import dev.surl.surl.common.Msg
import dev.surl.surl.dto.SurlDto import dev.surl.surl.dto.SurlDto
import dev.surl.surl.service.SurlService import dev.surl.surl.service.SurlService
import dev.surl.surl.util.JwtTokenUtil
import jakarta.validation.Valid import jakarta.validation.Valid
import org.springframework.http.HttpHeaders
import org.springframework.web.bind.annotation.PostMapping import org.springframework.web.bind.annotation.PostMapping
import org.springframework.web.bind.annotation.RequestBody import org.springframework.web.bind.annotation.RequestBody
import org.springframework.web.bind.annotation.RequestHeader
import org.springframework.web.bind.annotation.RestController import org.springframework.web.bind.annotation.RestController
/**
* 短链接新增控制器
*/
@RestController @RestController
class SurlAddController(private val service: SurlService, private val cfg: BaseConfig) { class SurlAddController(
private val service: SurlService, private val cfg: BaseConfig, private val jwtTokenUtil: JwtTokenUtil
) {
/**
* 短链接新增
*/
@PostMapping("/api/surl/add") @PostMapping("/api/surl/add")
fun addSurl(@Valid @RequestBody body: SurlDto) = fun addSurl(@RequestHeader headers: HttpHeaders, @Valid @RequestBody body: SurlDto): Msg<String> {
Msg(code = 0, value = "${cfg.site}/${service.addSurl(body.url ?: "")}")
// 从认证头获取用户名
val username = jwtTokenUtil.getUsernameFromHeader(headers)
// 获取主站域名/IP
val site = cfg.site
// 添加短链接
val key = service.addSurl(body.url ?: "", username)
// 拼接短链接
val url = "$site/$key"
return Msg(code = 0, value = url)
}
} }

View File

@ -8,16 +8,25 @@ import org.springframework.web.bind.annotation.GetMapping
import org.springframework.web.bind.annotation.RequestHeader import org.springframework.web.bind.annotation.RequestHeader
import org.springframework.web.bind.annotation.RestController import org.springframework.web.bind.annotation.RestController
/**
* 获取用户名下短链接列表控制器
*/
@RestController @RestController
class SurlGetController( class SurlGetController(
private val surlService: SurlService, private val surlService: SurlService, private val jwtTokenUtil: JwtTokenUtil
private val jwtTokenUtil: JwtTokenUtil
) { ) {
/**
* 获取用户名下短链接列表
*/
@GetMapping(path = ["/api/surl/get"]) @GetMapping(path = ["/api/surl/get"])
fun getUrlsByUser(@RequestHeader headers: HttpHeaders): Msg<List<String>> { fun getUrlsByUser(@RequestHeader headers: HttpHeaders): Msg<List<String>> {
val token = jwtTokenUtil.getTokenFromHeader(headers[HttpHeaders.AUTHORIZATION]?.last() ?: "")
val username = jwtTokenUtil.getUsernameFromToken(token) // 从认证头获取用户名
val username = jwtTokenUtil.getUsernameFromHeader(headers)
// 获取用户名下短链接列表
val urls = surlService.getUrlsByUser(username) val urls = surlService.getUrlsByUser(username)
return Msg(value = urls) return Msg(value = urls)
} }
} }

View File

@ -9,10 +9,16 @@ import org.springframework.web.bind.annotation.RequestMapping
import org.springframework.web.bind.annotation.RequestMethod import org.springframework.web.bind.annotation.RequestMethod
import org.springframework.web.bind.annotation.RestController import org.springframework.web.bind.annotation.RestController
/**
* 用户操作控制器
*/
@RestController @RestController
class UserController { class UserController {
/**
* 用户注册
*/
@RequestMapping(method = [RequestMethod.POST], path = ["/reg"]) @RequestMapping(method = [RequestMethod.POST], path = ["/reg"])
fun reg( fun reg(
@Autowired service: UserService, @Valid @RequestBody(required = true) user: UserDto @Autowired service: UserService, @Valid @RequestBody(required = true) user: UserDto
) = service.addUser(user.username!!, user.password!!) ) = service.addUser(user.username!!, user.password!!) // 新增用户
} }

View File

@ -5,9 +5,24 @@ import org.jetbrains.exposed.dao.Entity
import org.jetbrains.exposed.dao.EntityClass import org.jetbrains.exposed.dao.EntityClass
import org.jetbrains.exposed.dao.id.EntityID import org.jetbrains.exposed.dao.id.EntityID
/**
* 短链接实体类
*/
@Suppress("UNUSED") @Suppress("UNUSED")
class Surl(id: EntityID<Long>): Entity<Long>(id) { class Surl(id: EntityID<Long>): Entity<Long>(id) {
/**
* 短链接实体类伴生对象用于crud操作
*/
companion object: EntityClass<Long, Surl>(Surls) companion object: EntityClass<Long, Surl>(Surls)
/**
* 短链接url
*/
var url by Surls.url var url by Surls.url
/**
* 短链接所属用户
*/
var user by User optionalReferencedOn Surls.user var user by User optionalReferencedOn Surls.user
} }

View File

@ -5,8 +5,23 @@ import org.jetbrains.exposed.dao.LongEntity
import org.jetbrains.exposed.dao.LongEntityClass import org.jetbrains.exposed.dao.LongEntityClass
import org.jetbrains.exposed.dao.id.EntityID import org.jetbrains.exposed.dao.id.EntityID
/**
* 用户实体
*/
class User(id: EntityID<Long>): LongEntity(id) { class User(id: EntityID<Long>): LongEntity(id) {
/**
* 用户实体伴生对象用于CRUD操作
*/
companion object EntityClass: LongEntityClass<User>(Users) companion object EntityClass: LongEntityClass<User>(Users)
/**
* 用户名
*/
var username by Users.username var username by Users.username
/**
* 密码
*/
var password by Users.password var password by Users.password
} }

View File

@ -6,12 +6,27 @@ import org.jetbrains.exposed.dao.LongEntity
import org.jetbrains.exposed.dao.LongEntityClass import org.jetbrains.exposed.dao.LongEntityClass
import org.jetbrains.exposed.dao.id.EntityID import org.jetbrains.exposed.dao.id.EntityID
/**
* 用户权限实体类
*/
class UserAccess(id: EntityID<Long>): LongEntity(id) { class UserAccess(id: EntityID<Long>): LongEntity(id) {
/**
* 伴生对象用于CRUD操作
*/
companion object EntityClass: LongEntityClass<UserAccess>(UserAccesses) companion object EntityClass: LongEntityClass<UserAccess>(UserAccesses)
/**
* 权限枚举类型自动转换为数据库存储的整数
*/
var access by UserAccesses.access.transform(toColumn = { var access by UserAccesses.access.transform(toColumn = {
it.ordinal.toShort() it.ordinal.toShort()
}, toReal = { }, toReal = {
Access.entries[it.toInt()] Access.entries[it.toInt()]
}) })
/**
* 用户
*/
var user by User referencedOn UserAccesses.user var user by User referencedOn UserAccesses.user
} }

View File

@ -2,6 +2,9 @@ package dev.surl.surl.dsl
import org.jetbrains.exposed.dao.id.IdTable import org.jetbrains.exposed.dao.id.IdTable
/**
* 短链接表
*/
object Surls: IdTable<Long>("surl") { object Surls: IdTable<Long>("surl") {
override val id = long("id").entityId() override val id = long("id").entityId()
val url = varchar("url", 2048) val url = varchar("url", 2048)

View File

@ -2,6 +2,9 @@ package dev.surl.surl.dsl
import org.jetbrains.exposed.dao.id.IdTable import org.jetbrains.exposed.dao.id.IdTable
/**
* 用户权限表
*/
object UserAccesses: IdTable<Long>("user_access") { object UserAccesses: IdTable<Long>("user_access") {
override val id = long("id").entityId() override val id = long("id").entityId()
val user = reference("user", Users).index() val user = reference("user", Users).index()

View File

@ -2,6 +2,9 @@ package dev.surl.surl.dsl
import org.jetbrains.exposed.dao.id.IdTable import org.jetbrains.exposed.dao.id.IdTable
/**
* 用户权限表
*/
object Users: IdTable<Long>("users") { object Users: IdTable<Long>("users") {
override val id = long("id").entityId() override val id = long("id").entityId()
val username = varchar("username", 256).uniqueIndex() val username = varchar("username", 256).uniqueIndex()

View File

@ -4,6 +4,9 @@ import com.fasterxml.jackson.annotation.JsonProperty
import jakarta.validation.constraints.NotNull import jakarta.validation.constraints.NotNull
import org.hibernate.validator.constraints.Length import org.hibernate.validator.constraints.Length
/**
* 短链接新增请求体
*/
data class SurlDto( data class SurlDto(
@JsonProperty("url") @JsonProperty("url")
@get:NotNull(message = "url cannot be empty") @get:NotNull(message = "url cannot be empty")

View File

@ -4,6 +4,9 @@ import com.fasterxml.jackson.annotation.JsonProperty
import jakarta.validation.constraints.NotNull import jakarta.validation.constraints.NotNull
import org.hibernate.validator.constraints.Length import org.hibernate.validator.constraints.Length
/**
* 用户信息请求体
*/
data class UserDto ( data class UserDto (
@JsonProperty("username") @JsonProperty("username")
@get:Length(max = 16, min = 4, message = "username length must be between 4 and 16") @get:Length(max = 16, min = 4, message = "username length must be between 4 and 16")

View File

@ -14,6 +14,9 @@ import org.springframework.http.HttpHeaders
import org.springframework.stereotype.Component import org.springframework.stereotype.Component
import org.springframework.web.filter.OncePerRequestFilter import org.springframework.web.filter.OncePerRequestFilter
/**
* JWT认证过滤器
*/
@Component @Component
class JwtAuthenticationTokenFilter( class JwtAuthenticationTokenFilter(
private val jwtTokenUtil: JwtTokenUtil, private val jwtTokenUtil: JwtTokenUtil,
@ -26,8 +29,10 @@ class JwtAuthenticationTokenFilter(
response: HttpServletResponse, response: HttpServletResponse,
filterChain: FilterChain filterChain: FilterChain
) { ) {
// 检查请求路径是否在白名单内
if (request.servletPath notMatchedIn cfg.whiteList) { if (request.servletPath notMatchedIn cfg.whiteList) {
try { try {
// 验证token
val exp = UnauthorizedExcecption("unauthorized") val exp = UnauthorizedExcecption("unauthorized")
val authHeader = request.getHeader(HttpHeaders.AUTHORIZATION) ?: throw exp val authHeader = request.getHeader(HttpHeaders.AUTHORIZATION) ?: throw exp
val token = jwtTokenUtil.getTokenFromHeader(authHeader) val token = jwtTokenUtil.getTokenFromHeader(authHeader)
@ -38,8 +43,10 @@ class JwtAuthenticationTokenFilter(
throw exp throw exp
} }
} }
// redis缓存内检查不到已存在token拒绝认证抛出异常
if (cachedToken != token) throw exp if (cachedToken != token) throw exp
} catch (e: UnauthorizedExcecption) { } catch (e: UnauthorizedExcecption) {
// 认证失败
response.status = HttpServletResponse.SC_UNAUTHORIZED response.status = HttpServletResponse.SC_UNAUTHORIZED
val responseBody = om.writeValueAsString(Msg<String>(code = -1, msg = e.message)) val responseBody = om.writeValueAsString(Msg<String>(code = -1, msg = e.message))
response.writer.run { response.writer.run {
@ -49,9 +56,13 @@ class JwtAuthenticationTokenFilter(
return return
} }
} }
// 认证成功
filterChain.doFilter(request, response) filterChain.doFilter(request, response)
} }
/**
* 判断字符串是否匹配正则列表
*/
private infix fun String.matchedIn(regexes: List<Regex>): Boolean { private infix fun String.matchedIn(regexes: List<Regex>): Boolean {
for (regex in regexes) { for (regex in regexes) {
if (this.matches(regex)) return true if (this.matches(regex)) return true
@ -59,6 +70,9 @@ class JwtAuthenticationTokenFilter(
return false return false
} }
/**
* 判断字符串是否不匹配正则列表
*/
private infix fun String.notMatchedIn(regexes: List<Regex>): Boolean { private infix fun String.notMatchedIn(regexes: List<Regex>): Boolean {
return !(this matchedIn regexes) return !(this matchedIn regexes)
} }

View File

@ -21,6 +21,9 @@ import org.springframework.security.web.authentication.UsernamePasswordAuthentic
import org.springframework.stereotype.Component import org.springframework.stereotype.Component
import java.nio.charset.StandardCharsets import java.nio.charset.StandardCharsets
/**
* 登录过滤器
*/
@Component @Component
class UsernamePasswordAuthenticationCheckFilter( class UsernamePasswordAuthenticationCheckFilter(
private val om: ObjectMapper, private val om: ObjectMapper,
@ -31,15 +34,21 @@ class UsernamePasswordAuthenticationCheckFilter(
) : UsernamePasswordAuthenticationFilter() { ) : UsernamePasswordAuthenticationFilter() {
init { init {
// 设置登录地址
setFilterProcessesUrl("/login") setFilterProcessesUrl("/login")
authenticationManager = AuthenticationManager { it } authenticationManager = AuthenticationManager { it }
} }
/**
* 尝试登录
*/
override fun attemptAuthentication(request: HttpServletRequest?, response: HttpServletResponse?): Authentication { override fun attemptAuthentication(request: HttpServletRequest?, response: HttpServletResponse?): Authentication {
request ?: throw IllegalArgumentException("request is null") request ?: throw IllegalArgumentException("request is null")
val userDto = request.run { val userDto = request.run {
om.readValue(String(inputStream.readAllBytes(), StandardCharsets.UTF_8), UserDto::class.java) om.readValue(String(inputStream.readAllBytes(), StandardCharsets.UTF_8), UserDto::class.java)
} }
// 尝试验证登录信息
try { try {
validate(userDto, validator) validate(userDto, validator)
} catch (e: ConstraintViolationException) { } catch (e: ConstraintViolationException) {
@ -55,6 +64,9 @@ class UsernamePasswordAuthenticationCheckFilter(
) )
} }
/**
* 登录成功生成并返回token
*/
override fun successfulAuthentication( override fun successfulAuthentication(
request: HttpServletRequest?, response: HttpServletResponse?, chain: FilterChain?, authResult: Authentication? request: HttpServletRequest?, response: HttpServletResponse?, chain: FilterChain?, authResult: Authentication?
) { ) {
@ -71,6 +83,9 @@ class UsernamePasswordAuthenticationCheckFilter(
} }
} }
/**
* 登录失败 返回错误信息
*/
override fun unsuccessfulAuthentication( override fun unsuccessfulAuthentication(
request: HttpServletRequest?, response: HttpServletResponse?, failed: AuthenticationException? request: HttpServletRequest?, response: HttpServletResponse?, failed: AuthenticationException?
) { ) {

View File

@ -6,6 +6,9 @@ import org.springframework.security.access.AccessDeniedException
import org.springframework.security.web.access.AccessDeniedHandler import org.springframework.security.web.access.AccessDeniedHandler
import org.springframework.web.bind.annotation.ControllerAdvice import org.springframework.web.bind.annotation.ControllerAdvice
/**
* 访问权限异常处理器
*/
@ControllerAdvice @ControllerAdvice
class AccessHandler: AccessDeniedHandler { class AccessHandler: AccessDeniedHandler {
override fun handle( override fun handle(
@ -13,6 +16,7 @@ class AccessHandler: AccessDeniedHandler {
response: HttpServletResponse?, response: HttpServletResponse?,
accessDeniedException: AccessDeniedException? accessDeniedException: AccessDeniedException?
) { ) {
// 跳转登录页
response?.sendRedirect("/login") response?.sendRedirect("/login")
} }
} }

View File

@ -16,14 +16,24 @@ import org.springframework.web.context.request.WebRequest
import org.springframework.web.method.annotation.HandlerMethodValidationException import org.springframework.web.method.annotation.HandlerMethodValidationException
import org.springframework.web.servlet.mvc.method.annotation.ResponseEntityExceptionHandler import org.springframework.web.servlet.mvc.method.annotation.ResponseEntityExceptionHandler
/**
* 自定义异常处理
*/
@ControllerAdvice @ControllerAdvice
class DefaultExceptionHandler : ResponseEntityExceptionHandler() { class DefaultExceptionHandler : ResponseEntityExceptionHandler() {
/**
* 处理方法参数校验异常
*/
override fun handleMethodValidationException( override fun handleMethodValidationException(
ex: MethodValidationException, headers: HttpHeaders, status: HttpStatus, request: WebRequest ex: MethodValidationException, headers: HttpHeaders, status: HttpStatus, request: WebRequest
): ResponseEntity<Any> { ): ResponseEntity<Any> {
return ResponseEntity(Msg<String>(code = -1, msg = ex.allValidationResults.joinToString(";")), status) return ResponseEntity(Msg<String>(code = -1, msg = ex.allValidationResults.joinToString(";")), status)
} }
/**
* 处理方法参数校验异常
*/
override fun handleHandlerMethodValidationException( override fun handleHandlerMethodValidationException(
ex: HandlerMethodValidationException, ex: HandlerMethodValidationException,
headers: HttpHeaders, headers: HttpHeaders,
@ -38,6 +48,9 @@ class DefaultExceptionHandler : ResponseEntityExceptionHandler() {
}), status) }), status)
} }
/**
* 处理方法参数校验异常
*/
override fun handleMethodArgumentNotValid( override fun handleMethodArgumentNotValid(
ex: MethodArgumentNotValidException, headers: HttpHeaders, status: HttpStatusCode, request: WebRequest ex: MethodArgumentNotValidException, headers: HttpHeaders, status: HttpStatusCode, request: WebRequest
): ResponseEntity<Any> { ): ResponseEntity<Any> {
@ -48,12 +61,18 @@ class DefaultExceptionHandler : ResponseEntityExceptionHandler() {
) )
} }
/**
* 处理请求体解析异常
*/
override fun handleHttpMessageNotReadable( override fun handleHttpMessageNotReadable(
ex: HttpMessageNotReadableException, headers: HttpHeaders, status: HttpStatusCode, request: WebRequest ex: HttpMessageNotReadableException, headers: HttpHeaders, status: HttpStatusCode, request: WebRequest
): ResponseEntity<Any> { ): ResponseEntity<Any> {
return ResponseEntity(Msg<String>(code = -1, msg = ex.message ?: "unknown error"), status) return ResponseEntity(Msg<String>(code = -1, msg = ex.message ?: "unknown error"), status)
} }
/**
* 处理其他异常
*/
@ExceptionHandler(value = [IllegalStateException::class, Exception::class]) @ExceptionHandler(value = [IllegalStateException::class, Exception::class])
fun handleException( fun handleException(
ex: Exception ex: Exception
@ -61,12 +80,18 @@ class DefaultExceptionHandler : ResponseEntityExceptionHandler() {
return ResponseEntity(Msg(code = -1, msg = ex.message ?: "unknown error"), HttpStatus.INTERNAL_SERVER_ERROR) return ResponseEntity(Msg(code = -1, msg = ex.message ?: "unknown error"), HttpStatus.INTERNAL_SERVER_ERROR)
} }
/**
* 处理用户注册异常
*/
@ExceptionHandler(value = [UserRegistException::class]) @ExceptionHandler(value = [UserRegistException::class])
fun handleUserRegistException(ex: Exception fun handleUserRegistException(ex: Exception
): ResponseEntity<Msg<String>>{ ): ResponseEntity<Msg<String>>{
return ResponseEntity(Msg(code = -1, msg = ex.message ?: "unknown regist error"), HttpStatus.BAD_REQUEST) return ResponseEntity(Msg(code = -1, msg = ex.message ?: "unknown regist error"), HttpStatus.BAD_REQUEST)
} }
/**
* 处理校验异常
*/
@ExceptionHandler(value = [ConstraintViolationException::class]) @ExceptionHandler(value = [ConstraintViolationException::class])
fun handleConstraintViolationException(ex: Exception): ResponseEntity<Msg<String>> { fun handleConstraintViolationException(ex: Exception): ResponseEntity<Msg<String>> {
return ResponseEntity(Msg(code = -1, msg = ex.message ?: "unknown validation error"), HttpStatus.BAD_REQUEST) return ResponseEntity(Msg(code = -1, msg = ex.message ?: "unknown validation error"), HttpStatus.BAD_REQUEST)

View File

@ -8,19 +8,35 @@ import org.jetbrains.exposed.sql.batchInsert
import org.jetbrains.exposed.sql.transactions.transaction import org.jetbrains.exposed.sql.transactions.transaction
import org.springframework.stereotype.Service import org.springframework.stereotype.Service
/**
* 短链接服务
*/
@Service @Service
class SurlService { class SurlService {
private val userService: UserService by autowired() private val userService: UserService by autowired()
fun addSurl(baseurl: String): String = runBlocking {
/**
* 添加短链接
* @param baseurl 原始链接
* @param username 用户名
*/
fun addSurl(baseurl: String, username: String): String = runBlocking {
// 使用雪花算法生成id
val id = genSnowflakeUID() val id = genSnowflakeUID()
transaction { transaction {
Surl.new(id) { Surl.new(id) {
url = baseurl url = baseurl
user = userService.getUserByUsername(username)
} }
} }
// 返回id转换后的生成的key
numberToKey(id) numberToKey(id)
} }
/**
* 批量添加短链接
* @param baseurls 原始链接列表
*/
fun batchAddSurl(baseurls: List<String>) = transaction { fun batchAddSurl(baseurls: List<String>) = transaction {
Surls.batchInsert(baseurls, shouldReturnGeneratedValues = false) { Surls.batchInsert(baseurls, shouldReturnGeneratedValues = false) {
this[Surls.url] = it this[Surls.url] = it
@ -28,6 +44,10 @@ class SurlService {
} }
} }
/**
* 根据key获取原始链接
* @param key 短链接key
*/
fun getUrlByKey(key: String): String { fun getUrlByKey(key: String): String {
return transaction { return transaction {
Surls.select(Surls.url).where { Surls.select(Surls.url).where {
@ -35,11 +55,16 @@ class SurlService {
}.firstOrNull()?.get(Surls.url) ?: "" }.firstOrNull()?.get(Surls.url) ?: ""
} }
} }
/**
* 根据用户名获取短链接列表
* @param username 用户名
*/
fun getUrlsByUser(username: String): List<String> { fun getUrlsByUser(username: String): List<String> {
val user = userService.getUserByUsername(username) ?: return emptyList() val user = userService.getUserByUsername(username) ?: return emptyList()
return transaction { return transaction {
Surl.find { Surl.find {
Surls.id eq user.id Surls.user eq user.id
}.map { }.map {
it.url it.url
} }

View File

@ -20,15 +20,26 @@ import org.springframework.stereotype.Service
typealias AUser = org.springframework.security.core.userdetails.User typealias AUser = org.springframework.security.core.userdetails.User
/**
* 用户服务
*/
@Service @Service
class UserService: UserDetailsService { class UserService: UserDetailsService {
private val passwordEncoder: BCryptPasswordEncoder by autowired() private val passwordEncoder: BCryptPasswordEncoder by autowired()
private val validator: Validator by autowired() private val validator: Validator by autowired()
/**
* 注册用户
* @param username 用户名
* @param password 密码
* @return 注册成功返回用户id和用户名
*/
fun addUser(username: String, password: String): Msg<Map<String, String>> { fun addUser(username: String, password: String): Msg<Map<String, String>> {
val (id, accessId) = runBlocking { val id = runBlocking {
Pair(genSnowflakeUID(), genSnowflakeUID()) genSnowflakeUID()
} }
// 密码加密
val encryptedPassword = passwordEncoder.encode(password) val encryptedPassword = passwordEncoder.encode(password)
transaction { transaction {
if (isUserExist(username)) { if (isUserExist(username)) {
@ -38,7 +49,8 @@ class UserService: UserDetailsService {
this.username = username this.username = username
this.password = encryptedPassword this.password = encryptedPassword
} }
addDefaultAccess(accessId, user) // 添加默认权限
addDefaultAccess(user)
} }
return Msg(value = mapOf( return Msg(value = mapOf(
"id" to numberToKey(id), "id" to numberToKey(id),
@ -46,6 +58,11 @@ class UserService: UserDetailsService {
)) ))
} }
/**
* 根据用户名获取用户信息
* @param username 用户名
* @return 用户信息
*/
fun getUserByUsername(username: String): User? { fun getUserByUsername(username: String): User? {
return transaction { return transaction {
User.find { User.find {
@ -54,23 +71,43 @@ class UserService: UserDetailsService {
} }
} }
/**
* 判断用户是否存在
* @param username 用户名
* @return 用户是否存在
*/
private fun isUserExist(username: String) = !User.find { private fun isUserExist(username: String) = !User.find {
Users.username eq username Users.username eq username
}.empty() }.empty()
private fun addDefaultAccess(id: Long, user: User) { /**
* 添加默认权限
* @param user 用户
*/
private fun addDefaultAccess(user: User) {
val id = runBlocking { genSnowflakeUID() }
UserAccess.new(id) { UserAccess.new(id) {
this.access = Access.READ this.access = Access.READ
this.user = user this.user = user
} }
} }
/**
* 验证用户密码
* @param userDto 用户信息
* @return 验证结果
*/
fun authUser(userDto: UserDto):Boolean { fun authUser(userDto: UserDto):Boolean {
validate(userDto, validator) validate(userDto, validator)
val user = getUserByUsername(userDto.username!!) ?: throw UsernameNotFoundException("user `${userDto.username}` not found") val user = getUserByUsername(userDto.username!!) ?: throw UsernameNotFoundException("user `${userDto.username}` not found")
return passwordEncoder.matches(userDto.password!!, user.password) return passwordEncoder.matches(userDto.password!!, user.password)
} }
/**
* 根据用户名获取用户信息
* @param username 用户名
* @return 用户信息
*/
override fun loadUserByUsername(username: String): UserDetails { override fun loadUserByUsername(username: String): UserDetails {
val user = getUserByUsername(username) ?: throw UsernameNotFoundException("user '$username' not found") val user = getUserByUsername(username) ?: throw UsernameNotFoundException("user '$username' not found")
return AUser.builder().apply { return AUser.builder().apply {

View File

@ -4,6 +4,9 @@ import dev.surl.surl.SurlApplication
import kotlin.reflect.KClass import kotlin.reflect.KClass
import kotlin.reflect.KProperty import kotlin.reflect.KProperty
/**
* 注入代理类
*/
class Autowired<T : Any>(private val type: KClass<T>, private val name: String?) { class Autowired<T : Any>(private val type: KClass<T>, private val name: String?) {
private val value: T by lazy { private val value: T by lazy {
if (name == null) { if (name == null) {
@ -14,4 +17,8 @@ class Autowired<T : Any>(private val type: KClass<T>, private val name: String?)
} }
operator fun getValue(thisRef: Any?, property: KProperty<*>): T = value operator fun getValue(thisRef: Any?, property: KProperty<*>): T = value
} }
/**
* 注入代理器
*/
inline fun <reified T : Any> autowired(name: String? = null) = Autowired(T::class, name) inline fun <reified T : Any> autowired(name: String? = null) = Autowired(T::class, name)

View File

@ -3,19 +3,29 @@ package dev.surl.surl.util
import dev.surl.surl.cfg.BaseConfig import dev.surl.surl.cfg.BaseConfig
import io.jsonwebtoken.Claims import io.jsonwebtoken.Claims
import io.jsonwebtoken.Jwts import io.jsonwebtoken.Jwts
import org.springframework.http.HttpHeaders
import org.springframework.oxm.ValidationFailureException import org.springframework.oxm.ValidationFailureException
import org.springframework.stereotype.Component import org.springframework.stereotype.Component
import java.time.LocalDateTime import java.time.LocalDateTime
import java.time.ZoneId import java.time.ZoneId
import java.util.Date import java.util.Date
/**
* JWT token 工具类
*/
@Component @Component
class JwtTokenUtil(private val cfg: BaseConfig) { class JwtTokenUtil(private val cfg: BaseConfig) {
fun getToken(identityId: String, authorizes: List<String>): Pair<Date, String> { /**
* 生成token
* @param username 用户名
* @param authorizes 用户角色
* @return Pair<Date, String> token过期时间token
*/
fun getToken(username: String, authorizes: List<String>): Pair<Date, String> {
val now = LocalDateTime.now() val now = LocalDateTime.now()
val expireAt = Date.from(now.plus(cfg.expire, cfg.unit).atZone(ZoneId.systemDefault()).toInstant()) val expireAt = Date.from(now.plus(cfg.expire, cfg.unit).atZone(ZoneId.systemDefault()).toInstant())
val token = Jwts.builder().run { val token = Jwts.builder().run {
subject(identityId) subject(username)
issuedAt(Date()) issuedAt(Date())
expiration(expireAt) expiration(expireAt)
signWith(cfg.secretKey) signWith(cfg.secretKey)
@ -25,24 +35,55 @@ class JwtTokenUtil(private val cfg: BaseConfig) {
return Pair(expireAt, token) return Pair(expireAt, token)
} }
/**
* 获取token信息
* @param token token
* @return Claims
*/
private fun getTokenClaim(token: String): Claims? { private fun getTokenClaim(token: String): Claims? {
return Jwts.parser().verifyWith(cfg.secretKey).build().parseSignedClaims(token).payload return Jwts.parser().verifyWith(cfg.secretKey).build().parseSignedClaims(token).payload
} }
/**
* 从token中获取用户名
* @param token token
* @return 用户名
*/
fun getUsernameFromToken(token: String): String { fun getUsernameFromToken(token: String): String {
return getClaimFromToken(token) { return getClaimFromToken(token) {
it?.subject it?.subject
} ?: throw ValidationFailureException("invalid token, userinfo not found") } ?: throw ValidationFailureException("invalid token, userinfo not found")
} }
/**
* 从token中获取claims
* @param token token
* @param resolver 解析器
* @return T 解析结果
*/
private fun <T> getClaimFromToken(token: String, resolver: (Claims?) -> T): T { private fun <T> getClaimFromToken(token: String, resolver: (Claims?) -> T): T {
val claims = getTokenClaim(token) val claims = getTokenClaim(token)
return resolver(claims) return resolver(claims)
} }
/**
* 从header中获取token
* @param header header
* @return token
*/
fun getTokenFromHeader(header: String): String { fun getTokenFromHeader(header: String): String {
return if (header.startsWith(cfg.tokenHead)) { return if (header.startsWith(cfg.tokenHead)) {
header.substring(cfg.tokenHead.length) header.substring(cfg.tokenHead.length)
} else throw ValidationFailureException("invalid token") } else throw ValidationFailureException("invalid token")
} }
/**
* 从header中获取用户名
* @param headers headers
* @return 用户名
*/
fun getUsernameFromHeader(headers: HttpHeaders): String {
val token = getTokenFromHeader(headers[HttpHeaders.AUTHORIZATION]?.last() ?: "")
return getUsernameFromToken(token)
}
} }

View File

@ -3,9 +3,16 @@ package dev.surl.surl.util
import org.noelware.charted.snowflake.Snowflake import org.noelware.charted.snowflake.Snowflake
import kotlin.math.pow import kotlin.math.pow
// 70进制映射
private val CHARS = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!*().-_~".toCharArray() private val CHARS = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!*().-_~".toCharArray()
// 雪花ID生成器
private val snowflake = Snowflake() private val snowflake = Snowflake()
/**
* 将数字转换为key70进制
* @param number number
*/
fun numberToKey(number: Long): String { fun numberToKey(number: Long): String {
if(number == 0L) throw Exception("serial number cannot be zero") if(number == 0L) throw Exception("serial number cannot be zero")
var num = number var num = number
@ -18,6 +25,10 @@ fun numberToKey(number: Long): String {
return sb.reverse().toString() return sb.reverse().toString()
} }
/**
* 将key转换为10进制数字
* @param key key
*/
fun keyToNumber(key: String): Long { fun keyToNumber(key: String): Long {
var sum = 0L var sum = 0L
for(i in key.indices) { for(i in key.indices) {
@ -28,6 +39,9 @@ fun keyToNumber(key: String): Long {
return sum return sum
} }
/**
* 生成雪花ID
*/
suspend fun genSnowflakeUID(): Long { suspend fun genSnowflakeUID(): Long {
return snowflake.generate().value return snowflake.generate().value
} }

View File

@ -3,6 +3,11 @@ package dev.surl.surl.util
import jakarta.validation.ConstraintViolationException import jakarta.validation.ConstraintViolationException
import jakarta.validation.Validator import jakarta.validation.Validator
/**
* 请求体验证
* @param dto 待验证的dto
* @param validator 验证器
*/
fun <T: Any?> validate(dto: T,validator: Validator) { fun <T: Any?> validate(dto: T,validator: Validator) {
if(dto == null) throw IllegalArgumentException("dto for validation is null") if(dto == null) throw IllegalArgumentException("dto for validation is null")
val violations = validator.validate(dto) val violations = validator.validate(dto)

View File

@ -7,10 +7,20 @@ import org.springframework.stereotype.Component
import java.time.temporal.ChronoUnit import java.time.temporal.ChronoUnit
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
/**
* Redis工具类
*/
@Suppress("UNUSED") @Suppress("UNUSED")
@Component @Component
class RedisUtil(private val template: StringRedisTemplate, private val cfg: BaseConfig) { class RedisUtil(private val template: StringRedisTemplate, private val cfg: BaseConfig) {
private val ops = template.opsForValue() private val ops = template.opsForValue()
/**
* 获取字符串
* @param key
* @param type 存储类型
* @return 字符串
*/
fun getString(key: String, type: RedisStorage? = null): String? { fun getString(key: String, type: RedisStorage? = null): String? {
if (type == null) { if (type == null) {
return ops.get(key) return ops.get(key)
@ -18,6 +28,12 @@ class RedisUtil(private val template: StringRedisTemplate, private val cfg: Base
return ops.get("${type.name}_$key") return ops.get("${type.name}_$key")
} }
/**
* 设置字符串
* @param key
* @param value
* @param type 存储类型
*/
fun setString(key: String, value: String, type: RedisStorage? = null) { fun setString(key: String, value: String, type: RedisStorage? = null) {
if (type == null) { if (type == null) {
ops.set(key, value, cfg.expire, chronoUnitToTimeUnit(cfg.unit)) ops.set(key, value, cfg.expire, chronoUnitToTimeUnit(cfg.unit))
@ -26,6 +42,11 @@ class RedisUtil(private val template: StringRedisTemplate, private val cfg: Base
} }
} }
/**
* 删除键
* @param key
* @param type 存储类型
*/
fun delKey(key: String, type: RedisStorage? = null) { fun delKey(key: String, type: RedisStorage? = null) {
if (type == null) { if (type == null) {
ops.operations.delete(key) ops.operations.delete(key)
@ -34,12 +55,20 @@ class RedisUtil(private val template: StringRedisTemplate, private val cfg: Base
ops.operations.delete("${type.name}_$key") ops.operations.delete("${type.name}_$key")
} }
/**
* 清空数据库
*/
fun flushdb() { fun flushdb() {
template.execute { template.execute {
it.serverCommands().flushDb() it.serverCommands().flushDb()
} }
} }
/**
* 将ChronoUnit转换为TimeUnit
* @param unit ChronoUnit
* @return TimeUnit
*/
private fun chronoUnitToTimeUnit(unit: ChronoUnit): TimeUnit { private fun chronoUnitToTimeUnit(unit: ChronoUnit): TimeUnit {
return when (unit) { return when (unit) {
ChronoUnit.MILLIS -> TimeUnit.MILLISECONDS ChronoUnit.MILLIS -> TimeUnit.MILLISECONDS

View File

@ -24,6 +24,7 @@ spring:
time-zone: Asia/Shanghai time-zone: Asia/Shanghai
serialization: serialization:
indent-output: true indent-output: true
date-format: yyyy-MM-dd HH:mm:ss.SSS
data: data:
redis: redis:
host: localhost host: localhost
@ -45,7 +46,8 @@ logging:
base: base:
configs: configs:
site: http://127.0.0.1:18888 site: http://127.0.0.1:18888
expire: 3600000 expire: 6
unit: hours
secret: Is#45Ddw29apkbHawwaHb4d^&w29apkbHawwaHb4d^& secret: Is#45Ddw29apkbHawwaHb4d^&w29apkbHawwaHb4d^&
white-list: white-list:
- ^/login$ - ^/login$