code
好问题。先说清楚为什么,再给完整代码。
---
## 为什么 `synchronized` 对虚拟线程不友好
JDK 21 的虚拟线程如果在 `synchronized` 块/方法内部执行阻塞操作(如 I/O、`Object.wait()`),或在竞争进入 `synchronized` 时被阻塞,就会发生 **pinning**(钉住)——虚拟线程被固定在其载体线程(carrier thread,即一个平台线程)上,无法卸载让出,载体线程随之一起阻塞,虚拟线程的可伸缩性优势因此丧失。
`ReentrantLock` 则不会 pinning,虚拟线程在等锁时会正确地挂起并释放载体线程。
> JDK 24 已通过 JEP 491 解决了 `synchronized` 导致的绝大部分 pinning 问题,但在 JDK 21 LTS 上使用 `ReentrantLock` 仍是更稳妥的选择。
---
## 完整重构代码
```kotlin
package io.github.cymoo.colleen.ws
import io.github.cymoo.colleen.Colleen
import io.github.cymoo.colleen.util.http.Headers
import java.io.IOException
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock
import kotlin.reflect.KClass
// ============================================================================
// WsMessage
// ============================================================================
/**
 * An inbound WebSocket message payload, as delivered to handlers registered
 * through [WsConnection.onMessage].
 */
sealed class WsMessage {
    /** A text message carrying its decoded string payload in [data]. */
    data class Text(val data: String) : WsMessage()

    /**
     * A binary message carrying its raw payload in [data].
     *
     * `equals`/`hashCode` are overridden so that two instances compare by array
     * content; arrays otherwise use identity equality.
     */
    data class Binary(val data: ByteArray) : WsMessage() {
        override fun equals(other: Any?): Boolean =
            this === other || (other is Binary && data.contentEquals(other.data))

        override fun hashCode(): Int = data.contentHashCode()
    }
}
// ============================================================================
// WsCloseReason
// ============================================================================
/**
 * The reason a [WsConnection] ended. The final value is handed to every
 * callback registered with [WsConnection.onClose].
 */
sealed class WsCloseReason {
    /** The connection was closed deliberately, without an error. */
    object Normal : WsCloseReason() {
        override fun toString(): String = "Normal"
    }

    /** The remote peer went away (for example, a send failed with an I/O error). */
    object ClientDisconnected : WsCloseReason() {
        override fun toString(): String = "ClientDisconnected"
    }

    /** The connection ended because of [cause]. */
    data class Error(val cause: Throwable) : WsCloseReason()

    /** A close carrying an explicit WebSocket status [code] and [reason] text. */
    data class Protocol(val code: Int, val reason: String) : WsCloseReason()
}
// ============================================================================
// WsChannel
// ============================================================================
/**
 * Transport-level WebSocket endpoint.
 *
 * A server adapter (e.g. Undertow) implements this so that the framework-level
 * [WsConnection] can write to the socket without knowing which server is in use.
 *
 * Implementations do not have to be thread-safe for sends: [WsConnection]
 * serializes [sendText] and [sendBinary] behind a single lock. [close] is not
 * guarded by that lock, however. It may therefore be invoked while a send on
 * another thread is still in progress.
 */
interface WsChannel : AutoCloseable {
    /** Sends [text] as a text message. */
    @Throws(IOException::class)
    fun sendText(text: String)

    /** Sends the bytes in [data] as a binary message. */
    @Throws(IOException::class)
    fun sendBinary(data: ByteBuffer)

    /** Closes the channel using WebSocket status [code] and the given [reason] text. */
    fun close(code: Int, reason: String)
}
// ============================================================================
// WsConnection
// ============================================================================
/**
 * WebSocket connection.
 *
 * This is the primary API surface exposed to user handlers.
 *
 * ## Threading model
 * - [send] methods are serialized internally and may be called from multiple threads.
 * - [close] may be called at any time from any thread.
 * - Event callbacks ([onMessage], [onClose], [onError]) are dispatched to a worker
 *   thread by the server adapter to avoid blocking IO threads.
 *   Messages for the same connection are processed sequentially in order.
 * - User callbacks are never run while one of this connection's internal locks is
 *   held. A callback may therefore safely call [send], [close], or register more callbacks.
 *
 * ## Lifecycle
 * 1. Connection is established by the framework after successful WebSocket handshake.
 * 2. User handler receives the connection and registers callbacks.
 * 3. Messages arrive via [onMessage] callbacks.
 * 4. Connection closes via [close] or when the remote peer disconnects.
 * 5. [onClose] callbacks are invoked exactly once.
 *
 * ## Virtual-thread compatibility
 * All internal locks use [ReentrantLock] instead of `synchronized` to avoid
 * carrier-thread pinning on JDK 21 virtual threads.
 */
class WsConnection internal constructor(
    private val channel: WsChannel,
    val pathParams: Map<String, String>,
    val queryParams: Map<String, List<String>> = emptyMap(),
    private val app: Colleen? = null,
    private val states: MutableMap<String, Any?> = mutableMapOf(),
    private val requestHeaders: Headers = Headers(),
) : AutoCloseable {

    // ========================================================================
    // State
    // ========================================================================

    private val closed = AtomicBoolean(false)

    // Assigned once, by the thread that wins the close() race, while holding
    // closeCallbacksLock. @Volatile additionally guarantees visibility should it
    // ever be read without that lock.
    @Volatile private var closeReason: WsCloseReason = WsCloseReason.Normal

    // Becomes true once close() has drained and cleared closeCallbacks.
    // Guarded by closeCallbacksLock. From then on, onClose() invokes newly
    // registered callbacks immediately instead of queueing them.
    private var closeCallbacksDrained = false

    // ========================================================================
    // Locks
    // ========================================================================

    private val sendLock = ReentrantLock()
    private val statesLock = ReentrantLock()
    private val messageCallbacksLock = ReentrantLock()
    private val closeCallbacksLock = ReentrantLock()
    private val errorCallbacksLock = ReentrantLock()

    // ========================================================================
    // Callback lists (each guarded by its own lock above)
    // ========================================================================

    private val messageCallbacks = ArrayList<(WsMessage) -> Unit>()
    private val closeCallbacks = ArrayList<(WsCloseReason) -> Unit>()
    private val errorCallbacks = ArrayList<(Throwable) -> Unit>()

    // ========================================================================
    // Public state
    // ========================================================================

    /** True once [close] has been called (by user code or due to a failed send). */
    val isClosed: Boolean get() = closed.get()

    // ========================================================================
    // Path parameters
    // ========================================================================

    /** Returns the path parameter named [key], or null if absent. */
    fun pathParam(key: String): String? = pathParams[key]

    // ========================================================================
    // Query parameters
    // ========================================================================

    /** Returns the first value of query parameter [key], or null if absent. */
    fun query(key: String): String? = queryParams[key]?.firstOrNull()

    /** Returns all values of query parameter [key], or an empty list if absent. */
    fun queryList(key: String): List<String> = queryParams[key] ?: emptyList()

    // ========================================================================
    // Request headers (from the WebSocket upgrade / handshake request)
    // ========================================================================

    /**
     * Returns the first value of the specified HTTP header, or null if absent.
     * Header names are case-insensitive.
     */
    fun header(key: String): String? = requestHeaders[key]

    /**
     * Returns all values of the specified HTTP header.
     * Returns an empty list if the header is absent.
     * Header names are case-insensitive.
     */
    fun headerValues(key: String): List<String> = requestHeaders.getAll(key)

    // ========================================================================
    // Connection-scoped state
    // ========================================================================

    /**
     * Returns true if the state key exists, regardless of whether its value is null.
     */
    fun hasState(key: String): Boolean = statesLock.withLock { states.containsKey(key) }

    /**
     * Returns the state value for the given key.
     *
     * The cast to [T] is unchecked, so a value of the wrong type is not detected here.
     *
     * @throws NoSuchElementException if the key does not exist.
     * @throws NullPointerException if the value is null.
     */
    @Suppress("UNCHECKED_CAST")
    fun <T : Any> getState(key: String): T = statesLock.withLock {
        if (!states.containsKey(key)) throw NoSuchElementException("State '$key' not found")
        // An unchecked cast to a type parameter does not null-check at runtime. Enforce
        // the documented NullPointerException explicitly. Otherwise a null would escape
        // typed as a non-null T.
        val value = states[key] ?: throw NullPointerException("State '$key' is null")
        value as T
    }

    /**
     * Returns the state value for the given key, or null if the key does not exist.
     */
    @Suppress("UNCHECKED_CAST")
    fun <T> getStateOrNull(key: String): T? = statesLock.withLock {
        if (!states.containsKey(key)) return null
        states[key] as T?
    }

    /**
     * Sets a state value. The value may be null.
     */
    fun setState(key: String, value: Any?) = statesLock.withLock {
        states[key] = value
    }

    // ========================================================================
    // Service injection
    // ========================================================================

    /**
     * Retrieves a required service instance.
     * Resolution walks up the app parent chain (for mounted sub-apps).
     *
     * @throws IllegalStateException if the service is not registered.
     */
    inline fun <reified T : Any> getService(qualifier: Any? = null): T =
        resolveService(T::class, qualifier)
            ?: error("Service ${T::class.simpleName}(qualifier=$qualifier) not registered")

    /**
     * Retrieves an optional service instance, or null if not registered.
     * Resolution walks up the app parent chain (for mounted sub-apps).
     */
    inline fun <reified T : Any> getServiceOrNull(qualifier: Any? = null): T? =
        resolveService(T::class, qualifier)

    /**
     * Retrieves all registered instances of type [T], regardless of qualifier.
     *
     * Only this connection's own app is queried. Unlike [getService], parent apps
     * are not consulted.
     * NOTE(review): confirm this asymmetry is intended.
     */
    inline fun <reified T : Any> getServices(): List<T> =
        resolveAllServices(T::class)

    @PublishedApi
    internal fun <T : Any> resolveAllServices(kClass: KClass<T>): List<T> =
        app?.serviceContainer?.getAll(kClass) ?: emptyList()

    @PublishedApi
    internal fun <T : Any> resolveService(kClass: KClass<T>, qualifier: Any? = null): T? =
        resolveServiceFromApp(app, kClass, qualifier)

    /** Looks up [kClass] in [current], then in each ancestor app, returning the first hit. */
    private tailrec fun <T : Any> resolveServiceFromApp(
        current: Colleen?,
        kClass: KClass<T>,
        qualifier: Any?,
    ): T? {
        if (current == null) return null
        return current.serviceContainer.getOrNull(kClass, qualifier)
            ?: resolveServiceFromApp(current.parent, kClass, qualifier)
    }

    // ========================================================================
    // Java-compatible service injection
    // ========================================================================

    /**
     * Java-friendly variant of [getService].
     *
     * @throws IllegalStateException if the service is not registered.
     */
    @JvmOverloads
    fun <T : Any> getService(clazz: Class<T>, qualifier: Any? = null): T =
        resolveService(clazz.kotlin, qualifier)
            ?: error("Service ${clazz.simpleName}(qualifier=$qualifier) not registered")

    /** Java-friendly variant of [getServiceOrNull]. */
    @JvmOverloads
    fun <T : Any> getServiceOrNull(clazz: Class<T>, qualifier: Any? = null): T? =
        resolveService(clazz.kotlin, qualifier)

    // ========================================================================
    // Send
    // ========================================================================

    /**
     * Sends a text message. Thread-safe; blocks until the message is written.
     *
     * If the write fails, the connection is closed with
     * [WsCloseReason.ClientDisconnected] before the exception is rethrown.
     *
     * @throws IOException if the connection is closed or the write fails.
     */
    @Throws(IOException::class)
    fun send(text: String): Unit = writeOrClose { channel.sendText(text) }

    /**
     * Sends a binary message. Thread-safe; blocks until the message is written.
     *
     * If the write fails, the connection is closed with
     * [WsCloseReason.ClientDisconnected] before the exception is rethrown.
     *
     * @throws IOException if the connection is closed or the write fails.
     */
    @Throws(IOException::class)
    fun send(data: ByteArray): Unit = writeOrClose { channel.sendBinary(ByteBuffer.wrap(data)) }

    /**
     * Runs [write] under [sendLock] after checking that the connection is open.
     *
     * An [IOException] from [write] is captured inside the lock. The connection is
     * then closed *after* [sendLock] is released, and the exception is rethrown.
     * [close] runs user callbacks, and running them while holding [sendLock] can
     * deadlock. For example, two connections whose close callbacks send to each
     * other would each wait for the other's send lock.
     */
    private inline fun writeOrClose(write: () -> Unit) {
        val failure: IOException? = sendLock.withLock {
            ensureOpen()
            try {
                write()
                null
            } catch (e: IOException) {
                e
            }
        }
        if (failure != null) {
            close(WsCloseReason.ClientDisconnected)
            throw failure
        }
    }

    // ========================================================================
    // Callback registration
    // ========================================================================

    /**
     * Registers a callback for incoming messages.
     * Multiple callbacks may be registered; they are invoked in registration order.
     */
    fun onMessage(callback: (WsMessage) -> Unit) {
        messageCallbacksLock.withLock { messageCallbacks.add(callback) }
    }

    /**
     * Registers a callback invoked when the connection closes.
     *
     * If the connection is already closed and callbacks have been drained,
     * the callback is invoked immediately on the calling thread with the final
     * [WsCloseReason]. Otherwise it is queued and invoked by [close].
     * Exceptions thrown by the callback are swallowed.
     *
     * Multiple callbacks may be registered; each is invoked exactly once.
     */
    fun onClose(callback: (WsCloseReason) -> Unit) {
        // Decide under the lock, but invoke after releasing it so that user code
        // never runs while closeCallbacksLock is held.
        val finalReason: WsCloseReason? = closeCallbacksLock.withLock {
            if (closeCallbacksDrained) {
                closeReason
            } else {
                closeCallbacks.add(callback)
                null
            }
        }
        if (finalReason != null) runCatching { callback(finalReason) }
    }

    /**
     * Registers a callback for transport or message-processing errors.
     */
    fun onError(callback: (Throwable) -> Unit) {
        errorCallbacksLock.withLock { errorCallbacks.add(callback) }
    }

    // ========================================================================
    // Internal dispatch — called by the server adapter
    // ========================================================================

    /**
     * Delivers [message] to a snapshot of the message callbacks, in registration order.
     * No-op once closed. A callback's exception is routed to [dispatchError] and does
     * not prevent later callbacks from running.
     */
    internal fun dispatchMessage(message: WsMessage) {
        if (isClosed) return
        val snapshot = messageCallbacksLock.withLock { ArrayList(messageCallbacks) }
        snapshot.forEach { cb -> runCatching { cb(message) }.onFailure { dispatchError(it) } }
    }

    /**
     * Delivers [error] to a snapshot of the error callbacks. No-op once closed;
     * exceptions thrown by error callbacks are swallowed.
     */
    internal fun dispatchError(error: Throwable) {
        if (isClosed) return
        val snapshot = errorCallbacksLock.withLock { ArrayList(errorCallbacks) }
        snapshot.forEach { cb -> runCatching { cb(error) } }
    }

    // ========================================================================
    // Lifecycle
    // ========================================================================

    /** Closes the connection with [WsCloseReason.Normal]. */
    override fun close() = close(WsCloseReason.Normal)

    /**
     * Closes the connection with the given reason.
     *
     * Idempotent — only the first call takes effect.
     * Sending the close frame is best-effort; failures are ignored.
     * Close callbacks are invoked exactly once, in registration order, on the
     * calling thread after all internal locks have been released.
     */
    fun close(reason: WsCloseReason) {
        if (!closed.compareAndSet(false, true)) return

        // Send the close frame outside any lock so no lock is held during
        // potentially blocking network I/O. Failure is deliberately ignored.
        runCatching { channel.close(closeCode(reason), closeMessage(reason)) }

        // Atomically record the reason, drain the callback list, and raise
        // the drained flag. Any concurrent onClose() call then either:
        //   (a) sees drained == false and adds to the list (we will invoke it), or
        //   (b) sees drained == true and invokes immediately with the correct reason.
        val callbacks = closeCallbacksLock.withLock {
            closeReason = reason
            val snapshot = ArrayList(closeCallbacks)
            closeCallbacks.clear()
            closeCallbacksDrained = true
            snapshot
        }

        // Clear the remaining callback lists — no more dispatches after close.
        messageCallbacksLock.withLock { messageCallbacks.clear() }
        errorCallbacksLock.withLock { errorCallbacks.clear() }

        callbacks.forEach { cb -> runCatching { cb(reason) } }
    }

    // ========================================================================
    // Helpers
    // ========================================================================

    /** @throws IOException if the connection has been closed. */
    private fun ensureOpen() {
        if (isClosed) throw IOException("WebSocket connection closed")
    }

    /** Maps [reason] to the WebSocket status code sent in the close frame. */
    private fun closeCode(reason: WsCloseReason): Int = when (reason) {
        is WsCloseReason.Normal -> 1000
        is WsCloseReason.Protocol -> reason.code
        // NOTE(review): RFC 6455 §7.4.1 defines 1011 ("internal error") for unexpected
        // server-side failures, which may suit WsCloseReason.Error better than 1001.
        is WsCloseReason.ClientDisconnected, is WsCloseReason.Error -> 1001
    }

    /**
     * Builds the close-frame reason text for [reason].
     *
     * The result is truncated to [MAX_CLOSE_REASON_BYTES] UTF-8 bytes. An
     * implementation may reject an over-long reason. That failure would be
     * swallowed in [close], and no close frame would be sent at all.
     */
    private fun closeMessage(reason: WsCloseReason): String {
        val text = when (reason) {
            is WsCloseReason.Normal -> ""
            is WsCloseReason.Protocol -> reason.reason
            is WsCloseReason.ClientDisconnected -> "Client disconnected"
            // NOTE(review): this sends the exception message to the remote peer,
            // which may expose internal details.
            is WsCloseReason.Error -> reason.cause.message ?: "Error"
        }
        return truncateUtf8(text, MAX_CLOSE_REASON_BYTES)
    }

    private companion object {
        /**
         * Maximum UTF-8 length of a close reason. A control-frame payload is limited
         * to 125 bytes (RFC 6455 §5.5), and two of them carry the status code (§5.5.1).
         */
        const val MAX_CLOSE_REASON_BYTES = 123

        /**
         * Returns [text] unchanged if its UTF-8 encoding fits in [maxBytes]. Otherwise
         * returns the longest prefix that fits without splitting a multi-byte sequence.
         */
        fun truncateUtf8(text: String, maxBytes: Int): String {
            val bytes = text.toByteArray(Charsets.UTF_8)
            if (bytes.size <= maxBytes) return text
            var end = maxBytes
            // Step back over UTF-8 continuation bytes (10xxxxxx) so the cut lands on
            // the start of a code point.
            while (end > 0 && (bytes[end].toInt() and 0xC0) == 0x80) end--
            return String(bytes, 0, end, Charsets.UTF_8)
        }
    }
}
```