// extractor
package site.daydream.colleen
import java.io.InputStream
import java.lang.reflect.Parameter
import java.lang.reflect.ParameterizedType
import java.lang.reflect.Type
import kotlin.reflect.KFunction
import kotlin.reflect.jvm.javaMethod
// ===== Exception definitions =====
/**
 * Base type for the failures raised in this file while turning request data into handler
 * arguments.
 *
 * The hierarchy is sealed: the nested classes below are its only variants.
 *
 * @param message human-readable description of the failure
 * @param cause underlying exception, if any
 */
sealed class ExtractionError(
    message: String,
    cause: Throwable? = null
) : RuntimeException(message, cause) {

    /**
     * A required value could not be found.
     *
     * @param name name of the missing parameter
     * @param type display name of the parameter's declared type
     */
    class MissingParameter(name: String, type: String) : ExtractionError(
        "Required parameter '$name' of type $type not found"
    )

    /**
     * A declared type cannot be handled.
     *
     * @param type display name of the offending type
     * @param reason why the type was rejected
     */
    class InvalidType(type: String, reason: String) : ExtractionError(
        "Invalid type $type: $reason"
    )

    /**
     * A textual value could not be converted to the requested type.
     *
     * @param value the text that failed to convert
     * @param from display name of the source type
     * @param to display name of the target type
     * @param cause underlying exception, if any
     */
    class ConversionFailed(
        value: String,
        from: String,
        to: String,
        cause: Throwable? = null
    ) : ExtractionError("Cannot convert '$value' from $from to $to", cause)
}
/**
 * Raised when calling a handler function fails.
 *
 * @param handler display name of the class that declares the handler
 * @param method name of the handler method that was being invoked
 * @param cause the failure that interrupted the invocation
 */
class InvocationError(
    handler: String,
    method: String,
    cause: Throwable
) : RuntimeException("Failed to invoke $handler.$method", cause)
// ===== Parameter extractor interface =====
/**
 * Common shape of the parameter wrapper types declared in this file.
 *
 * @param T type of the wrapped value
 */
interface ParamExtractor<T> {
    /** The value carried by the wrapper. */
    val value: T
}
// ===== Parameter wrapper types =====
/**
 * Wrapper for a value taken from a path parameter.
 *
 * @param T type the path segment is converted to
 */
data class Path<T>(
    override val value: T
) : ParamExtractor<T>
/**
 * Wrapper for a request header value.
 *
 * The extractor in this file supplies an empty string when the header is absent.
 */
data class Header(
    override val value: String
) : ParamExtractor<String>
/**
 * Wrapper for a cookie value.
 *
 * The extractor in this file supplies an empty string when the cookie is absent.
 */
data class Cookie(
    override val value: String
) : ParamExtractor<String>
/**
 * Wrapper for a value taken from the query string.
 *
 * @param T type of the extracted value (a map, a list, a simple type or another class)
 */
data class Query<T>(
    override val value: T
) : ParamExtractor<T>
/**
 * Wrapper for the request body read as text.
 *
 * Note: the extractor in this file always supplies a `String` (empty when no text is
 * available), regardless of [T].
 *
 * @param T declared type of the wrapped value
 */
data class Text<T>(
    override val value: T
) : ParamExtractor<T>
/**
 * Wrapper for a request body decoded from JSON.
 *
 * @param T type the body is decoded into
 */
data class Json<T>(
    override val value: T
) : ParamExtractor<T>
/**
 * Wrapper for a value taken from submitted form data.
 *
 * @param T type of the extracted value (a map, a list, a simple type or another class)
 */
data class Form<T>(
    override val value: T
) : ParamExtractor<T>
/** Wrapper for the request's input stream. */
data class Stream(
    override val value: InputStream
) : ParamExtractor<InputStream>
/** Wrapper for a file uploaded with the request, as represented by `Request.UploadedFile`. */
data class UploadedFile(
    override val value: Request.UploadedFile
) : ParamExtractor<Request.UploadedFile>
// ===== Annotation definitions =====
/**
 * Names the request value that a handler parameter is bound to.
 *
 * Retained at runtime so it can be read through reflection. When [value] is left empty the
 * annotation does not supply a name.
 *
 * @property value the name to look up; empty by default
 */
@Retention(AnnotationRetention.RUNTIME)
@Target(AnnotationTarget.VALUE_PARAMETER)
annotation class Param(
    val value: String = ""
)
// ===== cx function implementation =====
/**
 * Adapts a function reference into a [Handler] whose arguments are resolved from the request
 * [Context].
 *
 * One extractor per declared parameter is built up front, so parameter types and annotations
 * are inspected once rather than on every request. The call itself still goes through
 * `Method.invoke` with a `null` receiver, so the target must be invocable without an instance.
 *
 * @param fn the function to expose as a handler
 * @return a handler that extracts the arguments and invokes [fn]
 * @throws ExtractionError.InvalidType if [fn] has no backing Java method
 * @throws ExtractionError from the returned handler, when an argument cannot be extracted
 * @throws InvocationError from the returned handler, when the call fails for any other reason
 */
fun cx(fn: KFunction<*>): Handler {
    val method = fn.javaMethod
        ?: throw ExtractionError.InvalidType("KFunction", "Cannot get Java method")
    // Class.getSimpleName yields "" (not null) for anonymous classes, so test for emptiness.
    val handlerName = method.declaringClass.simpleName.orEmpty().ifEmpty { "Handler" }

    // Build every extractor once; the handler below only runs the cached closures.
    val extractors = method.parameters.map { param ->
        buildExtractor(param, handlerName, method.name)
    }

    return Handler { ctx ->
        try {
            val args = extractors.map { it(ctx) }
            method.invoke(null, *args.toTypedArray())
        } catch (e: ExtractionError) {
            throw e
        } catch (e: InvocationError) {
            throw e
        } catch (e: java.lang.reflect.InvocationTargetException) {
            // Reflection wraps whatever the handler threw; report the handler's own exception.
            throw InvocationError(handlerName, method.name, e.targetException ?: e)
        } catch (e: Exception) {
            // Any other failure is reported intact so its own cause chain is preserved.
            throw InvocationError(handlerName, method.name, e)
        }
    }
}
// ===== Extractor context =====
/**
 * Identifies the handler parameter an extractor is being built for; used for lookups and in
 * error messages.
 *
 * @property handler display name of the class that declares the handler
 * @property method name of the handler method
 * @property paramName resolved name of the parameter (may be empty)
 * @property paramType display name of the parameter's declared type
 */
private data class ExtractorContext(
    val handler: String,
    val method: String,
    val paramName: String,
    val paramType: String
) {
    /** Renders the location as `Handler.method(Type name)`. */
    override fun toString(): String {
        return "$handler.$method($paramType $paramName)"
    }
}
// ===== Parameter extractor construction =====
/**
 * Chooses the extractor for one handler parameter, based on its declared type.
 *
 * A [Context] parameter receives the request context itself. The wrapper types declared in
 * this file each map to their dedicated builder. Any other type is resolved through
 * `Context.getService`.
 *
 * @param param the reflected handler parameter
 * @param handler display name of the declaring class, for error messages
 * @param method name of the handler method, for error messages
 * @return a function that produces the argument value for a given request context
 */
private fun buildExtractor(param: Parameter, handler: String, method: String): (Context) -> Any? {
    val declaredType = param.type

    // The request context is handed over untouched.
    if (declaredType == Context::class.java) {
        return { requestContext -> requestContext }
    }

    val generics = extractTypeInfo(param.parameterizedType)
    val resolvedName = getParameterName(param)
    val site = ExtractorContext(handler, method, resolvedName, declaredType.simpleName)

    return when (declaredType) {
        Path::class.java -> buildPathExtractor(site, generics)
        Header::class.java -> buildHeaderExtractor(site)
        Cookie::class.java -> buildCookieExtractor(site)
        Query::class.java -> buildQueryExtractor(site, generics)
        Text::class.java -> buildTextExtractor()
        Json::class.java -> buildJsonExtractor(site, generics)
        Form::class.java -> buildFormExtractor(site, generics)
        Stream::class.java -> buildStreamExtractor(site)
        UploadedFile::class.java -> buildFileExtractor(site)
        // Not a known wrapper: look the type up as a service.
        else -> { requestContext ->
            requestContext.getService(declaredType)
        }
    }
}
/**
 * Builds the extractor for a `Text` handler parameter.
 *
 * @return a function wrapping `Context.text()` in [Text], using an empty string when it is null
 */
private fun buildTextExtractor(): (Context) -> Any? = { context ->
    Text(context.text() ?: "")
}
// ===== Parameter name resolution =====
/**
 * Resolves the lookup name for a handler parameter.
 *
 * A non-empty [Param] value wins. Otherwise the reflected parameter name is used, unless it
 * has the `argN` form, in which case no name is available.
 *
 * @param param the reflected handler parameter
 * @return the resolved name, or an empty string when none is available
 */
private fun getParameterName(param: Parameter): String {
    val annotated = param.getAnnotation(Param::class.java)?.value
    if (!annotated.isNullOrEmpty()) {
        return annotated
    }
    val reflected = param.name
    return if (reflected.matches(Regex("arg\\d+"))) "" else reflected
}
// ===== Type information extraction =====
/**
 * Generic type information of a wrapper parameter such as `Query<T>`.
 *
 * @property rawType class of the wrapper's first type argument, or `null` when unresolved
 * @property elementType element class when [rawType] is a list with a class element type,
 *   otherwise `null`
 */
private data class TypeInfo(
    val rawType: Class<*>?,
    val elementType: Class<*>?
)
/**
 * Reads the first generic type argument of a parameter's declared type.
 *
 * - Not a parameterized type: nothing is resolved.
 * - Argument is a plain class: that class becomes the raw type.
 * - Argument is itself parameterized: its raw class becomes the raw type, and when that raw
 *   class is a list, its first argument (if a plain class) becomes the element type.
 * - Any other kind of argument: nothing is resolved.
 *
 * @param type the generic declared type of a handler parameter
 * @return the resolved [TypeInfo]; both fields are `null` when nothing could be resolved
 */
private fun extractTypeInfo(type: Type): TypeInfo {
    val unresolved = TypeInfo(null, null)
    val declared = type as? ParameterizedType ?: return unresolved

    val argument = declared.actualTypeArguments.firstOrNull()
    if (argument is Class<*>) {
        return TypeInfo(argument, null)
    }
    if (argument !is ParameterizedType) {
        return unresolved
    }

    val outer = argument.rawType as? Class<*>
    val inner = when {
        outer != null && outer.isList() -> argument.actualTypeArguments.firstOrNull() as? Class<*>
        else -> null
    }
    return TypeInfo(outer, inner)
}
// ===== Type-check extensions =====
/** Returns `true` when this class is one of the entries of `SIMPLE_TYPES`. */
private fun Class<*>.isSimple(): Boolean = SIMPLE_TYPES.contains(this)
/** Returns `true` only for the `Map` interface itself; implementing classes do not match. */
private fun Class<*>.isMap(): Boolean = when (this) {
    Map::class.java, java.util.Map::class.java -> true
    else -> false
}
/** Returns `true` only for the `List` interface itself; implementing classes do not match. */
private fun Class<*>.isList(): Boolean = when (this) {
    List::class.java, java.util.List::class.java -> true
    else -> false
}
// ===== Per-type extractors =====
/**
 * Builds the extractor for a `Path<T>` handler parameter.
 *
 * @param ctx describes the parameter being bound; supplies the lookup name and error context
 * @param typeInfo generic type information of the declared `Path<T>`
 * @return a function that reads the named path parameter, converts it to `T` and wraps it in
 *   [Path]
 * @throws IllegalArgumentException if the parameter has no usable name
 * @throws ExtractionError.InvalidType if the generic argument could not be resolved
 * @throws ExtractionError.MissingParameter from the returned function, when the path
 *   parameter is absent
 */
private fun buildPathExtractor(ctx: ExtractorContext, typeInfo: TypeInfo): (Context) -> Any? {
    val name = ctx.paramName
    require(name.isNotEmpty()) { "Path parameter requires @Param at $ctx" }

    val target = typeInfo.rawType
    if (target == null) {
        throw ExtractionError.InvalidType(ctx.paramType, "Path requires generic type at $ctx")
    }

    return { context ->
        val segment = context.param(name)
        if (segment == null) {
            throw ExtractionError.MissingParameter(name, ctx.paramType)
        }
        Path(segment.convert(target, ctx))
    }
}
/**
 * Builds the extractor for a `Header` handler parameter.
 *
 * @param ctx describes the parameter being bound; supplies the header name
 * @return a function wrapping the named header in [Header], using an empty string when absent
 * @throws IllegalArgumentException if the parameter has no usable name
 */
private fun buildHeaderExtractor(ctx: ExtractorContext): (Context) -> Any? {
    val headerName = ctx.paramName
    require(headerName.isNotEmpty()) { "Header requires @Param at $ctx" }
    return { context ->
        val found = context.header(headerName)
        Header(found ?: "")
    }
}
/**
 * Builds the extractor for a `Cookie` handler parameter.
 *
 * @param ctx describes the parameter being bound; supplies the cookie name
 * @return a function wrapping the named cookie in [Cookie], using an empty string when absent
 * @throws IllegalArgumentException if the parameter has no usable name
 */
private fun buildCookieExtractor(ctx: ExtractorContext): (Context) -> Any? {
    val cookieName = ctx.paramName
    require(cookieName.isNotEmpty()) { "Cookie requires @Param at $ctx" }
    return { context ->
        val found = context.cookie(cookieName)
        Cookie(found ?: "")
    }
}
/**
 * Builds the extractor for a `Query<T>` handler parameter.
 *
 * The generic argument decides where the value comes from. That decision is made once, here,
 * so the returned function only performs the lookup:
 * - `Map`: all query parameters, via `Context.queries()`.
 * - `List<E>` with a resolved element class: every value of the named parameter, each
 *   converted to `E` (an empty list when there are none).
 * - a simple type (see `SIMPLE_TYPES`): the named parameter converted to `T`, or that type's
 *   default value when the parameter is absent.
 * - any other class: delegated to `Context.queriesAs`.
 *
 * Declaration problems are reported while building, in line with the Path/Header/Cookie/Json
 * builders, instead of surfacing on the first request.
 *
 * @param ctx describes the parameter being bound; supplies the lookup name and error context
 * @param typeInfo generic type information of the declared `Query<T>`
 * @return a function producing a [Query] wrapper for a request context
 * @throws ExtractionError.InvalidType if the generic argument could not be resolved
 * @throws IllegalArgumentException if a list or simple-typed parameter has no usable name
 */
private fun buildQueryExtractor(ctx: ExtractorContext, typeInfo: TypeInfo): (Context) -> Any? {
    val raw = typeInfo.rawType
        ?: throw ExtractionError.InvalidType(ctx.paramType, "Invalid Query type at $ctx")
    val element = typeInfo.elementType
    val name = ctx.paramName

    if (raw.isMap()) {
        return { context ->
            val value: Any? = context.queries()
            Query(value)
        }
    }

    if (raw.isList() && element != null) {
        // Looking up an empty name can never match; fail while building instead.
        require(name.isNotEmpty()) { "Query requires @Param at $ctx" }
        return { context ->
            val value: Any? = context.request.queryAll(name).map { it.convert(element, ctx) }
            Query(value)
        }
    }

    if (raw.isSimple()) {
        require(name.isNotEmpty()) { "Query requires @Param at $ctx" }
        return { context ->
            // An absent parameter falls back to the type's default value.
            val value: Any? = context.query(name)?.convert(raw, ctx) ?: getDefaultValue(raw)
            Query(value)
        }
    }

    return { context ->
        val value: Any? = context.queriesAs(raw)
        Query(value)
    }
}
/**
 * Builds the extractor for a `Json<T>` handler parameter.
 *
 * @param ctx describes the parameter being bound; used for error context
 * @param typeInfo generic type information of the declared `Json<T>`
 * @return a function wrapping the result of `Context.jsonAs` for `T` in [Json]
 * @throws ExtractionError.InvalidType if the generic argument could not be resolved
 */
private fun buildJsonExtractor(ctx: ExtractorContext, typeInfo: TypeInfo): (Context) -> Any? {
    val bodyType = typeInfo.rawType
    if (bodyType == null) {
        throw ExtractionError.InvalidType(ctx.paramType, "Json requires generic type at $ctx")
    }
    return { context -> Json(context.jsonAs(bodyType)) }
}
/**
 * Builds the extractor for a `Form<T>` handler parameter.
 *
 * The generic argument decides where the value comes from. That decision is made once, here,
 * so the returned function only performs the lookup:
 * - `Map`: all form fields, via `Context.forms()`.
 * - `List<E>` with a resolved element class: every value of the named field, each converted
 *   to `E` (an empty list when there are none).
 * - a simple type (see `SIMPLE_TYPES`): the named field converted to `T`, or that type's
 *   default value when the field is absent.
 * - any other class: delegated to `Context.formsAs`.
 *
 * Declaration problems are reported while building, in line with the Path/Header/Cookie/Json
 * builders, instead of surfacing on the first request.
 *
 * @param ctx describes the parameter being bound; supplies the lookup name and error context
 * @param typeInfo generic type information of the declared `Form<T>`
 * @return a function producing a [Form] wrapper for a request context
 * @throws ExtractionError.InvalidType if the generic argument could not be resolved
 * @throws IllegalArgumentException if a list or simple-typed parameter has no usable name
 */
private fun buildFormExtractor(ctx: ExtractorContext, typeInfo: TypeInfo): (Context) -> Any? {
    val raw = typeInfo.rawType
        ?: throw ExtractionError.InvalidType(ctx.paramType, "Invalid Form type at $ctx")
    val element = typeInfo.elementType
    val name = ctx.paramName

    if (raw.isMap()) {
        return { context ->
            val value: Any? = context.forms()
            Form(value)
        }
    }

    if (raw.isList() && element != null) {
        // Looking up an empty name can never match; fail while building instead.
        require(name.isNotEmpty()) { "Form requires @Param at $ctx" }
        return { context ->
            val value: Any? = context.request.formAll(name).map { it.convert(element, ctx) }
            Form(value)
        }
    }

    if (raw.isSimple()) {
        require(name.isNotEmpty()) { "Form requires @Param at $ctx" }
        return { context ->
            // An absent field falls back to the type's default value.
            val value: Any? = context.form(name)?.convert(raw, ctx) ?: getDefaultValue(raw)
            Form(value)
        }
    }

    return { context ->
        val value: Any? = context.formsAs(raw)
        Form(value)
    }
}
/**
 * Builds the extractor for a `Stream` handler parameter.
 *
 * @param ctx describes the parameter being bound; used for error context
 * @return a function wrapping the request's stream in [Stream]
 * @throws ExtractionError.MissingParameter from the returned function, when the request has
 *   no stream
 */
private fun buildStreamExtractor(ctx: ExtractorContext): (Context) -> Any? = { context ->
    val body = context.request.stream
        ?: throw ExtractionError.MissingParameter("stream", ctx.paramType)
    Stream(body)
}
/**
 * Builds the extractor for an `UploadedFile` handler parameter.
 *
 * @param ctx describes the parameter being bound; supplies the lookup name and error context
 * @return a function wrapping the named upload in [UploadedFile]
 * @throws ExtractionError.MissingParameter from the returned function, when no file with that
 *   name was uploaded
 */
private fun buildFileExtractor(ctx: ExtractorContext): (Context) -> Any? = { context ->
    val upload = context.file(ctx.paramName)
        ?: throw ExtractionError.MissingParameter(ctx.paramName, ctx.paramType)
    UploadedFile(upload)
}
// ===== Type conversion =====
/**
 * Converts this text to [targetType].
 *
 * Supported targets are `String`, `Int`, `Long`, `Double`, `Float` and `Boolean`, in both
 * primitive and boxed form. Boolean parsing is strict: only `"true"` and `"false"` are
 * accepted.
 *
 * @param targetType the class to convert to
 * @param ctx describes the parameter being bound; used in the unsupported-type message
 * @return the converted value
 * @throws ExtractionError.ConversionFailed if the text is not a valid value of the target type
 * @throws ExtractionError.InvalidType if [targetType] is not supported
 */
private fun String.convert(targetType: Class<*>, ctx: ExtractorContext): Any {
    val source = this

    // Shared failure path for the typed branches below.
    fun rejected(to: String): Nothing =
        throw ExtractionError.ConversionFailed(source, "String", to)

    try {
        return when (targetType) {
            String::class.java -> source
            Int::class.java, Integer::class.java -> source.toIntOrNull() ?: rejected("Int")
            Long::class.java, java.lang.Long::class.java -> source.toLongOrNull() ?: rejected("Long")
            Double::class.java, java.lang.Double::class.java -> source.toDoubleOrNull() ?: rejected("Double")
            Float::class.java, java.lang.Float::class.java -> source.toFloatOrNull() ?: rejected("Float")
            Boolean::class.java, java.lang.Boolean::class.java ->
                source.toBooleanStrictOrNull() ?: rejected("Boolean")
            else -> throw ExtractionError.InvalidType(targetType.simpleName, "Unsupported type at $ctx")
        }
    } catch (e: ExtractionError) {
        // Already in the reporting shape callers expect.
        throw e
    } catch (e: Exception) {
        throw ExtractionError.ConversionFailed(source, "String", targetType.simpleName, e)
    }
}
/**
 * Supplies the value used when a simple-typed parameter is absent.
 *
 * Numbers default to zero, booleans to `false` and strings to the empty string.
 *
 * @param clazz the simple type to provide a default for
 * @return the default value for [clazz]
 * @throws ExtractionError.InvalidType if [clazz] has no default
 */
private fun getDefaultValue(clazz: Class<*>): Any {
    return when (clazz) {
        String::class.java -> ""
        Boolean::class.java, java.lang.Boolean::class.java -> false
        Int::class.java, Integer::class.java -> 0
        Long::class.java, java.lang.Long::class.java -> 0L
        Float::class.java, java.lang.Float::class.java -> 0.0f
        Double::class.java, java.lang.Double::class.java -> 0.0
        else -> throw ExtractionError.InvalidType(clazz.simpleName, "No default value")
    }
}
// ===== Constants =====
/**
 * Types that can be read from a single textual value: `String` plus the primitive and boxed
 * forms of `Int`, `Long`, `Double`, `Float` and `Boolean`.
 */
private val SIMPLE_TYPES: Set<Class<*>> = setOf<Class<*>>(
    String::class.java,
    Int::class.java,
    Integer::class.java,
    Long::class.java,
    java.lang.Long::class.java,
    Double::class.java,
    java.lang.Double::class.java,
    Float::class.java,
    java.lang.Float::class.java,
    Boolean::class.java,
    java.lang.Boolean::class.java
)