code

好问题。先说清楚为什么,再给完整代码。

---

## 为什么 `synchronized` 对虚拟线程不友好

JDK 21 的虚拟线程如果在 `synchronized` 块或方法内部发生阻塞(例如执行阻塞 I/O、调用 `Object.wait()`,或在竞争中等待获取监视器锁),就会发生 **pinning**(钉住)。此时虚拟线程被固定在其载体线程(carrier thread,本身是一个平台线程)上,无法卸载让出。阻塞期间载体线程也随之被占用,退化为平台线程式的阻塞行为,丧失虚拟线程的核心优势。

`ReentrantLock` 则不会 pinning,虚拟线程在等锁时会正确地挂起并释放载体线程。

> JDK 24(JEP 491)已消除绝大多数由 `synchronized` 引起的 pinning,但在 JDK 21 LTS 上使用 `ReentrantLock` 仍是更稳妥的选择。

---

## 完整重构代码

```kotlin
package io.github.cymoo.colleen.ws

import io.github.cymoo.colleen.Colleen
import io.github.cymoo.colleen.util.http.Headers
import java.io.IOException
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock
import kotlin.reflect.KClass

// ============================================================================
// WsMessage
// ============================================================================

/**
 * Represents an incoming WebSocket message.
 */
sealed class WsMessage {
    /** A text message carrying [data]. */
    data class Text(val data: String) : WsMessage()

    /**
     * A binary message carrying [data].
     *
     * `equals`/`hashCode` compare array contents; the data-class defaults would
     * compare array identity. `toString` reports only the payload size, because
     * the generated version prints the array's identity hash, which is useless
     * (and nondeterministic) in logs.
     */
    data class Binary(val data: ByteArray) : WsMessage() {
        override fun equals(other: Any?): Boolean {
            if (this === other) return true
            if (other !is Binary) return false
            return data.contentEquals(other.data)
        }

        override fun hashCode(): Int = data.contentHashCode()

        override fun toString(): String = "Binary(size=${data.size})"
    }
}

// ============================================================================
// WsCloseReason
// ============================================================================

/**
 * Why a WebSocket connection ended.
 *
 * The two singleton cases render via `toString` as their simple names;
 * the data-class cases use the compiler-generated `toString`.
 */
sealed class WsCloseReason {
    /** The connection was closed normally. */
    object Normal : WsCloseReason() {
        override fun toString(): String {
            return "Normal"
        }
    }

    /** The remote peer went away. */
    object ClientDisconnected : WsCloseReason() {
        override fun toString(): String {
            return "ClientDisconnected"
        }
    }

    /** The connection ended because of [cause]. */
    data class Error(val cause: Throwable) : WsCloseReason()

    /** The connection ended with an explicit status [code] and [reason] text. */
    data class Protocol(val code: Int, val reason: String) : WsCloseReason()
}

// ============================================================================
// WsChannel
// ============================================================================

/**
 * Transport-level WebSocket channel.
 *
 * A server adapter (for example Undertow) implements this interface so the
 * framework-level [WsConnection] can write messages without depending on the
 * underlying server.
 *
 * Implementations need NOT be thread-safe: [WsConnection] is responsible
 * for all concurrency control.
 */
interface WsChannel : AutoCloseable {
    /**
     * Writes [text] as a text message.
     *
     * @throws IOException if the write fails.
     */
    @Throws(IOException::class)
    fun sendText(text: String)

    /**
     * Writes [data] as a binary message.
     *
     * @throws IOException if the write fails.
     */
    @Throws(IOException::class)
    fun sendBinary(data: ByteBuffer)

    /**
     * Closes the channel with the given status [code] and [reason] text.
     * Presumably these are sent in the close frame — adapter-specific, confirm per implementation.
     */
    fun close(code: Int, reason: String)
}

// ============================================================================
// WsConnection
// ============================================================================

/**
 * WebSocket connection.
 *
 * This is the primary API surface exposed to user handlers.
 *
 * ## Threading model
 * - [send] methods are serialized internally and may be called from multiple threads.
 * - [close] may be called at any time from any thread.
 * - Event callbacks ([onMessage], [onClose], [onError]) are dispatched to a worker
 *   thread by the server adapter to avoid blocking IO threads.
 *   Messages for the same connection are processed sequentially in order.
 * - Close callbacks run on the thread that calls [close]. A callback registered
 *   after close has completed runs immediately on the thread that calls [onClose].
 *   In both cases no internal lock is held while user code runs.
 *
 * ## Lifecycle
 * 1. Connection is established by the framework after successful WebSocket handshake.
 * 2. User handler receives the connection and registers callbacks.
 * 3. Messages arrive via [onMessage] callbacks.
 * 4. Connection closes via [close] or when the remote peer disconnects.
 * 5. [onClose] callbacks are invoked exactly once.
 *
 * ## Virtual-thread compatibility
 * All internal locks use [ReentrantLock] instead of `synchronized` to avoid
 * carrier-thread pinning on JDK 21 virtual threads.
 */
class WsConnection internal constructor(
    private val channel: WsChannel,
    val pathParams: Map<String, String>,
    val queryParams: Map<String, List<String>> = emptyMap(),
    private val app: Colleen? = null,
    private val states: MutableMap<String, Any?> = mutableMapOf(),
    private val requestHeaders: Headers = Headers(),
) : AutoCloseable {

    // ========================================================================
    // State
    // ========================================================================

    private val closed = AtomicBoolean(false)

    // Written once by close() while holding closeCallbacksLock, and read by
    // onClose() under the same lock. @Volatile is kept as a cheap visibility
    // guarantee in case it is ever read outside the lock.
    @Volatile private var closeReason: WsCloseReason = WsCloseReason.Normal

    // True once close() has drained and cleared the closeCallbacks list.
    // Guarded by closeCallbacksLock. When true, newly registered onClose
    // callbacks are invoked immediately rather than queued.
    private var closeCallbacksDrained = false

    // ========================================================================
    // Locks
    // ========================================================================

    private val sendLock = ReentrantLock()
    private val statesLock = ReentrantLock()
    private val messageCallbacksLock = ReentrantLock()
    private val closeCallbacksLock = ReentrantLock()
    private val errorCallbacksLock = ReentrantLock()

    // ========================================================================
    // Callback lists  (each guarded by its own lock above)
    // ========================================================================

    private val messageCallbacks = ArrayList<(WsMessage) -> Unit>()
    private val closeCallbacks = ArrayList<(WsCloseReason) -> Unit>()
    private val errorCallbacks = ArrayList<(Throwable) -> Unit>()

    // ========================================================================
    // Public state
    // ========================================================================

    /** True once [close] has been called (by either side). */
    val isClosed: Boolean get() = closed.get()

    // ========================================================================
    // Path parameters
    // ========================================================================

    /** Returns the path parameter named [key], or null if absent. */
    fun pathParam(key: String): String? = pathParams[key]

    // ========================================================================
    // Query parameters
    // ========================================================================

    /** Returns the first value of query parameter [key], or null if absent. */
    fun query(key: String): String? = queryParams[key]?.firstOrNull()

    /** Returns all values of query parameter [key], or an empty list if absent. */
    fun queryList(key: String): List<String> = queryParams[key] ?: emptyList()

    // ========================================================================
    // Request headers (from the WebSocket upgrade / handshake request)
    // ========================================================================

    /**
     * Returns the first value of the specified HTTP header, or null if absent.
     * Header names are case-insensitive.
     */
    fun header(key: String): String? = requestHeaders[key]

    /**
     * Returns all values of the specified HTTP header.
     * Returns an empty list if the header is absent.
     * Header names are case-insensitive.
     */
    fun headerValues(key: String): List<String> = requestHeaders.getAll(key)

    // ========================================================================
    // Connection-scoped state
    // ========================================================================

    /**
     * Returns true if the state key exists, regardless of whether its value is null.
     */
    fun hasState(key: String): Boolean = statesLock.withLock { states.containsKey(key) }

    /**
     * Returns the state value for the given key.
     *
     * @throws NoSuchElementException if the key does not exist.
     * @throws NullPointerException if the value is null.
     */
    @Suppress("UNCHECKED_CAST")
    fun <T : Any> getState(key: String): T = statesLock.withLock {
        if (!states.containsKey(key)) throw NoSuchElementException("State '$key' not found")
        // Casting a stored null to the non-null T throws NullPointerException.
        states[key] as T
    }

    /**
     * Returns the state value for the given key, or null if the key does not exist
     * (or is explicitly set to null).
     */
    @Suppress("UNCHECKED_CAST")
    fun <T> getStateOrNull(key: String): T? = statesLock.withLock {
        // Map.get already yields null for an absent key; no containsKey check needed.
        states[key] as T?
    }

    /**
     * Sets a state value. The value may be null.
     */
    fun setState(key: String, value: Any?): Unit = statesLock.withLock {
        states[key] = value
    }

    // ========================================================================
    // Service injection
    // ========================================================================

    /**
     * Retrieves a required service instance.
     * Resolution walks up the app parent chain (for mounted sub-apps).
     *
     * @throws IllegalStateException if the service is not registered.
     */
    inline fun <reified T : Any> getService(qualifier: Any? = null): T =
        resolveService(T::class, qualifier)
            ?: error("Service ${T::class.simpleName}(qualifier=$qualifier) not registered")

    /**
     * Retrieves an optional service instance, or null if not registered.
     */
    inline fun <reified T : Any> getServiceOrNull(qualifier: Any? = null): T? =
        resolveService(T::class, qualifier)

    /**
     * Retrieves all registered instances of type [T], regardless of qualifier.
     * Only the current app's container is consulted (no parent walk).
     */
    inline fun <reified T : Any> getServices(): List<T> =
        resolveAllServices(T::class)

    @PublishedApi
    internal fun <T : Any> resolveAllServices(kClass: KClass<T>): List<T> =
        app?.serviceContainer?.getAll(kClass) ?: emptyList()

    @PublishedApi
    internal fun <T : Any> resolveService(kClass: KClass<T>, qualifier: Any? = null): T? =
        resolveServiceFromApp(app, kClass, qualifier)

    /** Looks up the service in [current], then recursively in its parents. */
    private tailrec fun <T : Any> resolveServiceFromApp(
        current: Colleen?,
        kClass: KClass<T>,
        qualifier: Any?,
    ): T? {
        if (current == null) return null
        return current.serviceContainer.getOrNull(kClass, qualifier)
            ?: resolveServiceFromApp(current.parent, kClass, qualifier)
    }

    // ========================================================================
    // Java-compatible service injection
    // ========================================================================

    /**
     * Java-friendly variant of the reified [getService].
     *
     * @throws IllegalStateException if the service is not registered.
     */
    @JvmOverloads
    fun <T : Any> getService(clazz: Class<T>, qualifier: Any? = null): T =
        resolveService(clazz.kotlin, qualifier)
            ?: error("Service ${clazz.simpleName}(qualifier=$qualifier) not registered")

    /** Java-friendly variant of the reified [getServiceOrNull]. */
    @JvmOverloads
    fun <T : Any> getServiceOrNull(clazz: Class<T>, qualifier: Any? = null): T? =
        resolveService(clazz.kotlin, qualifier)

    // ========================================================================
    // Send
    // ========================================================================

    /**
     * Sends a text message. Thread-safe; blocks until the message is written.
     *
     * @throws IOException if the connection is closed or the write fails.
     */
    @Throws(IOException::class)
    fun send(text: String): Unit = sendGuarded { channel.sendText(text) }

    /**
     * Sends a binary message. Thread-safe; blocks until the message is written.
     *
     * @throws IOException if the connection is closed or the write fails.
     */
    @Throws(IOException::class)
    fun send(data: ByteArray): Unit = sendGuarded { channel.sendBinary(ByteBuffer.wrap(data)) }

    /**
     * Shared send path: runs [write] under [sendLock] after checking the connection
     * is open. A failed write closes the connection as
     * [WsCloseReason.ClientDisconnected] before rethrowing the original exception.
     */
    private inline fun sendGuarded(write: () -> Unit): Unit = sendLock.withLock {
        ensureOpen()
        try {
            write()
        } catch (e: IOException) {
            close(WsCloseReason.ClientDisconnected)
            throw e
        }
    }

    // ========================================================================
    // Callback registration
    // ========================================================================

    /**
     * Registers a callback for incoming messages.
     * Multiple callbacks may be registered; they are invoked in registration order.
     */
    fun onMessage(callback: (WsMessage) -> Unit) {
        messageCallbacksLock.withLock { messageCallbacks.add(callback) }
    }

    /**
     * Registers a callback invoked when the connection closes.
     *
     * If the connection is already closed and callbacks have been drained,
     * the callback is invoked immediately (on the calling thread) with the final
     * [WsCloseReason]. Otherwise it is queued and invoked by [close].
     * Exceptions thrown by the callback are swallowed.
     *
     * Multiple callbacks may be registered; each is invoked exactly once.
     */
    fun onClose(callback: (WsCloseReason) -> Unit) {
        // Decide under the lock, but run user code outside it so a slow or
        // blocking callback cannot stall other registrations.
        val finalReason: WsCloseReason? = closeCallbacksLock.withLock {
            if (closeCallbacksDrained) {
                // close() has already drained the list; closeReason is final.
                closeReason
            } else {
                closeCallbacks.add(callback)
                null
            }
        }
        finalReason?.let { reason -> runCatching { callback(reason) } }
    }

    /**
     * Registers a callback for transport or message-processing errors.
     */
    fun onError(callback: (Throwable) -> Unit) {
        errorCallbacksLock.withLock { errorCallbacks.add(callback) }
    }

    // ========================================================================
    // Internal dispatch — called by the server adapter
    // ========================================================================

    /**
     * Delivers [message] to every registered message callback (no-op once closed).
     * A callback failure is routed to [dispatchError] and does not stop the others.
     */
    internal fun dispatchMessage(message: WsMessage) {
        if (isClosed) return
        val snapshot = messageCallbacksLock.withLock { ArrayList(messageCallbacks) }
        snapshot.forEach { cb -> runCatching { cb(message) }.onFailure { dispatchError(it) } }
    }

    /**
     * Delivers [error] to every registered error callback (no-op once closed).
     * Exceptions thrown by error callbacks are swallowed.
     */
    internal fun dispatchError(error: Throwable) {
        if (isClosed) return
        val snapshot = errorCallbacksLock.withLock { ArrayList(errorCallbacks) }
        snapshot.forEach { cb -> runCatching { cb(error) } }
    }

    // ========================================================================
    // Lifecycle
    // ========================================================================

    /** Closes the connection with [WsCloseReason.Normal]. */
    override fun close() = close(WsCloseReason.Normal)

    /**
     * Closes the connection with the given reason.
     *
     * Idempotent — only the first call takes effect.
     * Close callbacks are invoked exactly once, in registration order, on the
     * calling thread. Exceptions thrown by them are swallowed.
     */
    fun close(reason: WsCloseReason) {
        if (!closed.compareAndSet(false, true)) return

        // Send the close frame before taking closeCallbacksLock, to avoid holding
        // it during potentially blocking network I/O. Failures are ignored: the
        // peer may already be gone.
        // NOTE(review): this runs without sendLock, so it can overlap an in-flight
        // sendText/sendBinary even though WsChannel implementations are not required
        // to be thread-safe. Taking sendLock here would make close() wait behind a
        // slow send (and could deadlock if the adapter calls close() from the thread
        // that drives that send) — confirm the intended trade-off with the adapter.
        runCatching { channel.close(closeCode(reason), closeMessage(reason)) }

        // Atomically record the reason, drain the callback list, and raise
        // the drained flag so that any concurrent onClose() call either:
        //   (a) sees drained == false and adds to the list (we will invoke it), or
        //   (b) sees drained == true and invokes immediately with the correct reason.
        val callbacks = closeCallbacksLock.withLock {
            closeReason = reason
            val snapshot = ArrayList(closeCallbacks)
            closeCallbacks.clear()
            closeCallbacksDrained = true
            snapshot
        }

        // Clear the remaining callback lists — no more dispatches after close.
        messageCallbacksLock.withLock { messageCallbacks.clear() }
        errorCallbacksLock.withLock { errorCallbacks.clear() }

        callbacks.forEach { cb -> runCatching { cb(reason) } }
    }

    // ========================================================================
    // Helpers
    // ========================================================================

    /** @throws IOException if the connection has already been closed. */
    private fun ensureOpen() {
        if (isClosed) throw IOException("WebSocket connection closed")
    }

    /**
     * Maps a close reason to a WebSocket status code (RFC 6455 §7.4.1):
     * 1000 normal closure, 1001 endpoint going away, 1011 unexpected condition.
     * Exhaustive on purpose — a new [WsCloseReason] subtype must pick a code.
     */
    private fun closeCode(reason: WsCloseReason): Int = when (reason) {
        is WsCloseReason.Normal -> 1000
        is WsCloseReason.ClientDisconnected -> 1001
        is WsCloseReason.Error -> 1011
        is WsCloseReason.Protocol -> reason.code
    }

    /**
     * Builds the close-frame reason text, truncated to the protocol limit of
     * [MAX_CLOSE_REASON_BYTES] UTF-8 bytes. Exception messages in particular can
     * easily exceed it, which would otherwise make the close frame invalid.
     */
    private fun closeMessage(reason: WsCloseReason): String {
        val text = when (reason) {
            is WsCloseReason.Normal -> ""
            is WsCloseReason.Protocol -> reason.reason
            is WsCloseReason.ClientDisconnected -> "Client disconnected"
            is WsCloseReason.Error -> reason.cause.message ?: "Error"
        }
        return truncateUtf8(text, MAX_CLOSE_REASON_BYTES)
    }

    /**
     * Returns [text] unchanged if its UTF-8 encoding fits in [maxBytes]; otherwise
     * the longest prefix that fits without splitting a multi-byte character.
     */
    private fun truncateUtf8(text: String, maxBytes: Int): String {
        val bytes = text.toByteArray(Charsets.UTF_8)
        if (bytes.size <= maxBytes) return text
        var end = maxBytes
        // bytes[end] is the first excluded byte; if it is a continuation byte
        // (0b10xxxxxx) the character straddles the cut, so back up to its start.
        while (end > 0 && (bytes[end].toInt() and 0xC0) == 0x80) end--
        return String(bytes, 0, end, Charsets.UTF_8)
    }

    private companion object {
        /**
         * RFC 6455 §5.5: control-frame payloads are at most 125 bytes, two of which
         * hold the close status code, leaving 123 bytes for the reason text.
         */
        const val MAX_CLOSE_REASON_BYTES = 123
    }
}
```