// Task manager

package io.github.cymoo.cleary

import com.cronutils.model.CronType
import com.cronutils.model.definition.CronDefinitionBuilder
import com.cronutils.model.time.ExecutionTime
import com.cronutils.parser.CronParser
import java.time.Duration
import java.time.Instant
import java.time.ZoneId
import java.time.ZonedDateTime
import java.util.concurrent.*
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicLong

/**
 * A lightweight, production-grade task scheduler supporting cron expressions,
 * fixed-rate scheduling, and manual task execution.
 *
 * Example usage:
 * ```kotlin
 * val tasks = TaskManager {
 *     concurrency = 4
 *     context["db"] = database
 * }
 *
 * tasks.task("heartbeat", every = Duration.ofSeconds(5)) {
 *     println("heartbeat: $taskName")
 * }
 *
 * tasks.start()
 * // ...
 * tasks.shutdown()
 * ```
 *
 * ### Scheduling semantics
 * - **Cron**: next execution is anchored to the previously *scheduled* time, so the
 *   schedule stays on the wall-clock grid even under load.
 * - **FixedRate**: next execution = last *scheduled* time + interval, preventing drift.
 * - **Once**: executes exactly once at the given [Instant].
 * - **WithInitialDelay**: delays the first execution of any inner [Schedule].
 *
 * ### Concurrency guard
 * When `allowConcurrent = false` (default), an execution that is still running when
 * the next slot arrives is simply **skipped** — no queuing, no backpressure.
 * This applies to both scheduled and manual ([run] / [runBlocking]) invocations.
 *
 * ### Scheduling chains
 * Each time a task is put on the timeline (at [start], at registration after start,
 * or on [enable]) it receives a fresh *epoch* token. Queue entries carrying an
 * outdated epoch are dropped when they come due. This keeps at most one live
 * scheduling chain per task, even across disable/enable cycles or when a name is
 * removed and registered again.
 *
 * ### Error reporting
 * Errors from scheduled runs are delivered to `onTaskComplete`. When no such listener
 * is configured, they are printed to stderr instead of being silently discarded.
 */
class TaskManager private constructor(
    private val config: TaskManagerConfig
) {
    companion object {
        private val threadCounter = AtomicLong(0)

        // Global source of scheduling-chain tokens. incrementAndGet() starts at 1,
        // so NO_EPOCH can never match a token carried by a queue entry.
        private val epochCounter = AtomicLong(0)
        private const val NO_EPOCH = -1L

        @JvmStatic
        fun builder(): Builder = Builder()

        operator fun invoke(block: TaskManagerConfig.() -> Unit = {}): TaskManager =
            TaskManager(TaskManagerConfig().apply(block))
    }

    // ── Execution pool ────────────────────────────────────────────────────────
    private val executor = ScheduledThreadPoolExecutor(config.concurrency) { r ->
        Thread(r, "${config.threadNamePrefix}-worker-${threadCounter.incrementAndGet()}").apply {
            isDaemon = true
        }
    }.apply { removeOnCancelPolicy = true }

    // ── Scheduler ─────────────────────────────────────────────────────────────
    // DelayQueue.take() blocks until the head element is due.
    // Inserting an element earlier than the current head automatically wakes the
    // waiting thread — no manual sleep/interrupt tricks needed.
    private val taskQueue = DelayQueue<ScheduledTask>()
    private val schedulerRunning = AtomicBoolean(false)
    private val started = AtomicBoolean(false)
    private val schedulerThread = Thread(::runScheduler, "${config.threadNamePrefix}-scheduler")
        .apply { isDaemon = true }

    // ── State ─────────────────────────────────────────────────────────────────
    private val tasks = ConcurrentHashMap<String, TaskEntry>()

    // Instance-level context — NOT a companion-object field, which would pollute
    // all TaskManager instances sharing the same classloader.
    private val globalContext = ConcurrentHashMap<String, Any>().apply {
        putAll(config.context)
    }

    // =========================================================================
    // Task Registration
    // =========================================================================

    /**
     * Registers a task with a simple schedule.
     *
     * At most one of [cron] or [every] may be set. If both are null, the task can
     * only be triggered via [run] / [runBlocking].
     *
     * @param name            unique task identifier
     * @param cron            Quartz cron expression, e.g. `"0 0/5 * * * ?"`
     * @param every           fixed-rate interval (must be positive)
     * @param allowConcurrent if false (default) overlapping executions are skipped
     * @param block           work to perform; receiver is a [TaskContext]
     * @throws IllegalArgumentException if both [cron] and [every] are set, the name is
     *   blank or already registered, or the schedule is invalid
     */
    fun task(
        name: String,
        cron: String? = null,
        every: Duration? = null,
        allowConcurrent: Boolean = false,
        block: TaskContext.() -> Any?
    ) {
        require(cron == null || every == null) { "Cannot specify both 'cron' and 'every'" }
        val schedule = when {
            cron  != null -> Schedule.Cron(cron)
            every != null -> Schedule.FixedRate(every)
            else          -> null
        }
        register(name, schedule, allowConcurrent, block)
    }

    /**
     * Registers a task with an explicit [Schedule].
     *
     * @throws IllegalArgumentException if the name is blank or already registered,
     *   or the schedule is invalid
     */
    fun task(
        name: String,
        schedule: Schedule,
        allowConcurrent: Boolean = false,
        block: TaskContext.() -> Any?
    ) = register(name, schedule, allowConcurrent, block)

    private fun register(
        name: String,
        schedule: Schedule?,
        allowConcurrent: Boolean,
        block: TaskContext.() -> Any?
    ) {
        require(name.isNotBlank()) { "Task name cannot be blank" }
        // Fast-fail check first, so a duplicate name is reported before the schedule
        // is converted to a trigger (which may itself fail on an invalid cron).
        require(!tasks.containsKey(name)) { "Task '$name' is already registered" }

        val entry = TaskEntry(
            name            = name,
            schedule        = schedule,
            trigger         = schedule?.toTrigger(),
            allowConcurrent = allowConcurrent,
            enabled         = AtomicBoolean(true),
            executing       = AtomicBoolean(false),
            epoch           = AtomicLong(NO_EPOCH),
            block           = block
        )
        // putIfAbsent makes check-and-insert a single atomic step. Two concurrent
        // registrations of the same name therefore cannot both succeed, and one cannot
        // silently overwrite the other.
        require(tasks.putIfAbsent(name, entry) == null) { "Task '$name' is already registered" }

        if (started.get() && entry.trigger != null && entry.enabled.get()) {
            enqueue(entry)
        }
    }

    // =========================================================================
    // Lifecycle
    // =========================================================================

    /** Starts the scheduler. Safe to call only once; subsequent calls are no-ops. */
    fun start() {
        if (!started.compareAndSet(false, true)) return
        schedulerRunning.set(true)
        schedulerThread.start()
        tasks.values
            .filter { it.trigger != null && it.enabled.get() }
            .forEach { enqueue(it) }
    }

    /**
     * Shuts down the scheduler.
     *
     * @param awaitTermination if true (default), blocks up to 30 s for in-flight
     *   tasks to complete before forcing a shutdown
     */
    fun shutdown(awaitTermination: Boolean = true) {
        if (!started.get()) return
        schedulerRunning.set(false)
        schedulerThread.interrupt()
        // Wait for the scheduler thread to stop before shutting down the executor,
        // so it cannot submit new tasks after executor.shutdown() is called.
        try {
            schedulerThread.join()
        } catch (_: InterruptedException) {
            Thread.currentThread().interrupt()
        }
        executor.shutdown()
        if (awaitTermination) {
            try {
                if (!executor.awaitTermination(30, TimeUnit.SECONDS)) executor.shutdownNow()
            } catch (_: InterruptedException) {
                executor.shutdownNow()
                Thread.currentThread().interrupt()
            }
        } else {
            executor.shutdownNow()
        }
    }

    // =========================================================================
    // Manual Execution
    // =========================================================================

    /**
     * Submits a task for immediate asynchronous execution and returns a [Future].
     *
     * Respects the `allowConcurrent` guard. If another execution of the same task
     * is running and `allowConcurrent = false`, the submission is skipped and the
     * future completes with `null`.
     *
     * @throws IllegalStateException if the scheduler has not been started or has been shut down
     * @throws NoSuchElementException if no task named [name] is registered
     */
    fun run(name: String, contextValues: Map<String, Any> = emptyMap()): Future<Any?> {
        check(started.get() && !executor.isShutdown) { "TaskManager is not running" }
        val entry = tasks[name] ?: throw NoSuchElementException("Task '$name' not found")
        return executor.submit(Callable {
            val skipped = !entry.allowConcurrent && !entry.executing.compareAndSet(false, true)
            if (skipped) return@Callable null
            try {
                executeTask(entry, contextValues, scheduledTime = null)
            } finally {
                if (!entry.allowConcurrent) entry.executing.set(false)
            }
        })
    }

    /**
     * Runs a task and blocks until it completes, re-throwing any exception the task raised.
     *
     * @throws IllegalStateException if the scheduler has not been started or has been shut down
     */
    fun runBlocking(name: String, contextValues: Map<String, Any> = emptyMap()): Any? =
        try {
            run(name, contextValues).get()
        } catch (e: ExecutionException) {
            throw e.cause ?: e
        }

    // =========================================================================
    // Task Control
    // =========================================================================

    /**
     * Re-enables a previously disabled task.
     *
     * If the scheduler is running, the task is put back on its timeline starting from
     * now, under a fresh epoch. Any queue entry left over from before the disable is
     * then dropped when it comes due, so the task never ends up with two chains.
     * When the trigger reports no further execution time (for example a one-shot
     * trigger whose slot was already consumed), nothing new is enqueued.
     */
    fun enable(name: String) {
        val entry = tasks[name] ?: throw NoSuchElementException("Task '$name' not found")
        if (entry.enabled.compareAndSet(false, true) && started.get() && entry.trigger != null) {
            enqueue(entry)
        }
    }

    /**
     * Disables scheduled execution of a task.
     * In-flight executions are not interrupted.
     * Queue entries for the task are discarded when they come due.
     */
    fun disable(name: String) {
        (tasks[name] ?: throw NoSuchElementException("Task '$name' not found")).enabled.set(false)
    }

    /**
     * Removes a task from the registry.
     * In-flight executions are not interrupted.
     * Any pending queue entry is discarded when it comes due, even if a task with the
     * same name is registered again in the meantime.
     */
    fun remove(name: String) { tasks.remove(name) }

    // =========================================================================
    // Query
    // =========================================================================

    fun exists(name: String): Boolean = tasks.containsKey(name)
    fun listTaskNames(): List<String> = tasks.keys.toList()

    fun getTaskDefinition(name: String): TaskDefinition? {
        val e = tasks[name] ?: return null
        return TaskDefinition(
            name                = e.name,
            scheduleDescription = e.schedule?.describe(),
            enabled             = e.enabled.get(),
            allowConcurrent     = e.allowConcurrent
        )
    }

    // =========================================================================
    // Internal scheduling
    // =========================================================================

    /**
     * Starts a new scheduling chain for [entry]. It computes the first scheduled time,
     * stamps the entry with a fresh epoch, and adds the first element to the queue.
     *
     * We subtract 1 ms from "now" before calling [Trigger.nextExecutionTime] because
     * cron implementations treat the base time as exclusive (they return the next time
     * *strictly after* the base). Without this adjustment, a task registered at exactly
     * a cron boundary would silently skip that slot.
     *
     * The epoch is bumped only when there actually is a next time. A trigger that is
     * already exhausted therefore leaves any still-queued element of the current chain
     * valid. When two threads enqueue concurrently, the epoch set last wins and the
     * other chain is dropped at dispatch.
     */
    private fun enqueue(entry: TaskEntry) {
        val trigger = entry.trigger ?: return
        val nextTime = trigger.nextExecutionTime(System.currentTimeMillis() - 1) ?: return
        val epoch = epochCounter.incrementAndGet()
        entry.epoch.set(epoch)
        taskQueue.offer(ScheduledTask(entry.name, epoch, nextTime))
    }

    private fun runScheduler() {
        while (schedulerRunning.get()) {
            try {
                val scheduled = taskQueue.take()
                dispatch(scheduled)
            } catch (_: InterruptedException) {
                break
            } catch (e: Exception) {
                // Only infrastructure errors reach here (e.g. a bug in dispatch).
                // Task-level errors are handled on the worker thread (see dispatch).
                System.err.println("[TaskManager] Scheduler error: ${e.message}")
            }
        }
    }

    private fun dispatch(scheduled: ScheduledTask) {
        // Look up the live entry by name rather than holding a direct reference in
        // ScheduledTask, so stale entries (from remove() or disable()) are naturally
        // garbage-collected once the queue drains them.
        val entry = tasks[scheduled.taskName] ?: return
        // Drop elements of a superseded chain: the task was re-enabled, or the name was
        // removed and registered again, since this element was queued.
        if (entry.epoch.get() != scheduled.epoch) return
        if (!entry.enabled.get()) return

        // Re-enqueue *before* submitting to the executor.
        // Anchoring to scheduledTime (not wall clock) prevents FixedRate / Cron drift.
        val nextTime = entry.trigger?.nextExecutionTime(scheduled.scheduledTime)
        if (nextTime != null) {
            taskQueue.offer(ScheduledTask(scheduled.taskName, scheduled.epoch, nextTime))
        }

        try {
            executor.submit {
                val skipped = !entry.allowConcurrent && !entry.executing.compareAndSet(false, true)
                if (!skipped) {
                    try {
                        executeTask(entry, emptyMap(), scheduledTime = scheduled.scheduledTime)
                    } catch (t: Throwable) {
                        // The Future returned by submit() is discarded, so anything thrown
                        // here would otherwise vanish. Task errors already reach
                        // onTaskComplete when one is configured; otherwise fall back to stderr.
                        if (config.onTaskComplete == null) {
                            System.err.println("[TaskManager] Task '${entry.name}' failed: $t")
                        }
                    } finally {
                        if (!entry.allowConcurrent) entry.executing.set(false)
                    }
                }
            }
        } catch (_: RejectedExecutionException) {
            // Executor shut down between the schedulerRunning check and submit().
            // Harmless — the scheduler loop will exit on its next iteration.
        }
    }

    /**
     * Runs [entry]'s block once with a fresh per-execution context. The context is a
     * copy of the global context plus [extra]. Start and complete listeners are
     * notified around the block.
     *
     * @param scheduledTime the planned trigger time for scheduled runs, or null for manual runs
     * @return the block's result
     * @throws Throwable whatever the block threw, re-thrown after `onTaskComplete` ran
     */
    private fun executeTask(
        entry: TaskEntry,
        extra: Map<String, Any>,
        scheduledTime: Long?
    ): Any? {
        val ctx = ConcurrentHashMap(globalContext).apply { putAll(extra) }
        val actualStartTime = System.currentTimeMillis()

        config.onTaskStart?.invoke(
            TaskStartEvent(
                taskName      = entry.name,
                scheduledTime = scheduledTime ?: actualStartTime,
                actualTime    = actualStartTime,
                context       = ctx   // live map — onTaskStart may inject values visible to the task
            )
        )

        var result: Any? = null
        var error: Throwable? = null
        try {
            result = entry.block(TaskContextImpl(entry.name, ctx))
        } catch (t: Throwable) {
            error = t
        }

        val endTime = System.currentTimeMillis()
        config.onTaskComplete?.invoke(
            TaskCompleteEvent(entry.name, actualStartTime, endTime, result, error)
        )

        if (error != null) throw error
        return result
    }

    // =========================================================================
    // Internal data structures
    // =========================================================================

    private data class TaskEntry(
        val name:            String,
        val schedule:        Schedule?,
        val trigger:         Trigger?,
        val allowConcurrent: Boolean,
        val enabled:         AtomicBoolean,
        val executing:       AtomicBoolean,
        /** Token of the currently valid scheduling chain; compared against [ScheduledTask.epoch]. */
        val epoch:           AtomicLong,
        val block:           TaskContext.() -> Any?
    )

    /**
     * An element in the [DelayQueue] that becomes due at [scheduledTime].
     *
     * Stores the task name rather than a direct [TaskEntry] reference, so that removed
     * tasks can be garbage-collected once the queue drains their element instead of
     * being kept alive by the queue itself.
     *
     * [epoch] identifies the scheduling chain this element belongs to. Dispatch ignores
     * the element unless it still matches the live entry's epoch.
     *
     * [scheduledTime] is the *planned* trigger time and is passed to
     * [Trigger.nextExecutionTime] on re-queue. This keeps FixedRate and Cron schedules
     * anchored to the original time grid rather than the (potentially delayed) wall clock.
     */
    private class ScheduledTask(
        val taskName:      String,
        val epoch:         Long,
        val scheduledTime: Long
    ) : Delayed {
        override fun getDelay(unit: TimeUnit): Long =
            unit.convert(scheduledTime - System.currentTimeMillis(), TimeUnit.MILLISECONDS)

        // `other` is always a ScheduledTask inside a single DelayQueue instance.
        override fun compareTo(other: Delayed): Int =
            scheduledTime.compareTo((other as ScheduledTask).scheduledTime)
    }

    // =========================================================================
    // Java-friendly Builder
    // =========================================================================

    class Builder {
        private val config = TaskManagerConfig()
        fun concurrency(n: Int)                     = apply { config.concurrency = n }
        fun threadNamePrefix(s: String)             = apply { config.threadNamePrefix = s }
        fun onTaskStart(l: TaskStartListener)       = apply { config.onTaskStart    = l::onStart }
        fun onTaskComplete(l: TaskCompleteListener) = apply { config.onTaskComplete = l::onComplete }
        fun putContext(key: String, value: Any)     = apply { config.context[key] = value }

        /**
         * Builds a [TaskManager] from a snapshot of the current settings. Further calls
         * on this builder (e.g. replacing a listener) therefore do not alter managers
         * that were already built.
         */
        fun build(): TaskManager = TaskManager(snapshot())

        // NOTE: keep in sync with the properties of TaskManagerConfig.
        private fun snapshot(): TaskManagerConfig = TaskManagerConfig().also { copy ->
            copy.concurrency      = config.concurrency
            copy.threadNamePrefix = config.threadNamePrefix
            copy.onTaskStart      = config.onTaskStart
            copy.onTaskComplete   = config.onTaskComplete
            copy.context.putAll(config.context)
        }
    }
}

// =============================================================================
// Configuration
// =============================================================================

/**
 * Mutable settings consumed by [TaskManager], configured via `TaskManager { ... }`
 * or [TaskManager.Builder].
 */
class TaskManagerConfig {
    /** Size of the worker pool. Defaults to the number of available processors. */
    var concurrency: Int = Runtime.getRuntime().availableProcessors()

    /** Prefix used when naming the worker and scheduler threads. */
    var threadNamePrefix: String = "task-manager"

    /** Initial global context; the manager copies these entries when it is constructed. */
    val context: MutableMap<String, Any> = ConcurrentHashMap<String, Any>()

    /** Called synchronously before each task body, with the live per-execution context. */
    var onTaskStart: ((TaskStartEvent) -> Unit)? = null

    /** Called after each task body finishes, carrying its result or the error it threw. */
    var onTaskComplete: ((TaskCompleteEvent) -> Unit)? = null
}

// =============================================================================
// Events
// =============================================================================

/**
 * Emitted immediately before a task body runs.
 *
 * @property taskName      name of the task about to execute
 * @property scheduledTime epoch millis at which the run was *planned*; for manual runs
 *   this is the same value as [actualTime]
 * @property actualTime    epoch millis at which execution *actually* began
 * @property context       the live, per-execution context map. Entries written here
 *   (trace IDs, request metadata, …) are visible to the task block, because the start
 *   callback runs synchronously, before the block and on the same thread. Each
 *   execution works on its own copy of the global context, so changes never leak into
 *   the manager's global context or into other executions.
 */
data class TaskStartEvent(
    val taskName: String,
    val scheduledTime: Long,
    val actualTime: Long,
    val context: MutableMap<String, Any>
)

/**
 * Emitted after a task body has finished, successfully or not.
 *
 * @property taskName  name of the task that ran
 * @property startTime epoch millis at which execution began
 * @property endTime   epoch millis at which execution ended
 * @property result    value returned by the task block (null on failure)
 * @property error     the throwable raised by the task block, or null on success
 */
data class TaskCompleteEvent(
    val taskName: String,
    val startTime: Long,
    val endTime: Long,
    val result: Any?,
    val error: Throwable?
) {
    /** Wall-clock run time in milliseconds. */
    val duration: Long
        get() = endTime - startTime

    /** True when the block completed without throwing. */
    val isSuccess: Boolean
        get() = error == null
}

// =============================================================================
// Context
// =============================================================================

/**
 * Per-execution key/value view handed to every task block as its receiver.
 *
 * Each execution works on its own map, seeded from the manager's global context
 * plus any values passed to `run` / `runBlocking`. Writes through [set] / [remove]
 * therefore affect only the current execution.
 *
 * Type arguments on the getters are erased at runtime, so implementations cannot
 * verify that a stored value really is a `T`. A mismatch typically surfaces as a
 * [ClassCastException] where the returned value is used, not inside these methods.
 */
interface TaskContext {
    /** Name of the task this context belongs to. */
    val taskName: String

    /**
     * Returns the value stored under [key].
     *
     * @throws NoSuchElementException if [key] is absent
     */
    operator fun <T : Any> get(key: String): T

    /** Returns the value stored under [key], or `null` if absent. */
    fun <T> getOrNull(key: String): T?

    /** Returns the value stored under [key], or [default] if absent. */
    fun <T> getOrDefault(key: String, default: T): T

    /** Stores [value] under [key] for the rest of this execution. */
    operator fun set(key: String, value: Any)

    /** Removes [key] from this execution's context, if present. */
    fun remove(key: String)
}

/**
 * Default [TaskContext] backed by a mutable map that belongs to a single task execution.
 *
 * All casts are unchecked: because `T` is erased, a stored value of another type is
 * handed back as-is. It then fails with a ClassCastException where the caller uses it,
 * rather than triggering the "not found" exception or the default value here.
 */
internal class TaskContextImpl(
    override val taskName: String,
    private val states: MutableMap<String, Any>
) : TaskContext {

    @Suppress("UNCHECKED_CAST")
    override fun <T : Any> get(key: String): T {
        val value = states[key] as? T
        return value ?: throw NoSuchElementException("Context key '$key' not found")
    }

    @Suppress("UNCHECKED_CAST")
    override fun <T> getOrNull(key: String): T? {
        val raw = states[key] ?: return null
        return raw as? T
    }

    @Suppress("UNCHECKED_CAST")
    override fun <T> getOrDefault(key: String, default: T): T {
        val raw = states[key] ?: return default
        return raw as? T ?: default
    }

    override fun set(key: String, value: Any) {
        states[key] = value
    }

    override fun remove(key: String) {
        states.remove(key)
    }
}

// =============================================================================
// Query results
// =============================================================================

/**
 * Read-only snapshot of a registered task, as returned by `TaskManager.getTaskDefinition`.
 *
 * @property name                unique task identifier
 * @property scheduleDescription human-readable schedule summary, or `null` for manual-only tasks
 * @property enabled             whether scheduled execution was enabled when the snapshot was taken
 * @property allowConcurrent     whether overlapping executions are permitted
 */
data class TaskDefinition(
    val name: String,
    val scheduleDescription: String?,
    val enabled: Boolean,
    val allowConcurrent: Boolean
)

// =============================================================================
// Schedule (public API)
// =============================================================================

/**
 * Describes *when* a task runs. Converted to an internal trigger when the task is registered.
 */
sealed class Schedule {
    /**
     * Quartz cron expression, e.g. `"0 0/5 * * * ?"`, evaluated in [zone].
     *
     * The expression is parsed and validated when the schedule is converted to a
     * trigger (at task registration), not when this object is constructed.
     */
    data class Cron(val expression: String, val zone: ZoneId = ZoneId.systemDefault()) : Schedule()

    /**
     * Execute repeatedly at a fixed rate regardless of how long each execution takes.
     *
     * The scheduler works at millisecond resolution, so [interval] must be at least
     * 1 ms. A sub-millisecond interval would truncate to 0 ms, so every next execution
     * time would equal the previous one and the scheduler thread would spin.
     */
    data class FixedRate(val interval: Duration) : Schedule() {
        init {
            require(!interval.isZero && !interval.isNegative) {
                "FixedRate interval must be positive, got: $interval"
            }
            // Compared as Durations (not via toMillis()) so huge intervals cannot overflow here.
            require(interval >= Duration.ofMillis(1)) {
                "FixedRate interval must be at least 1 ms, got: $interval"
            }
        }
    }

    /**
     * Execute exactly once at the given [Instant].
     *
     * If [at] is already in the past when the scheduler processes it, the task will
     * fire immediately (the [java.util.concurrent.DelayQueue] returns a non-positive
     * delay, which it treats as "due now"). No validation is performed at registration
     * time. The gap between construction and scheduling makes an at-construction check
     * unreliable, and a `data class` `init` block with `Instant.now()` would also break
     * `copy()`.
     */
    data class Once(val at: Instant) : Schedule()

    /** Delay the first execution of [schedule] by [delay]. */
    data class WithInitialDelay(val delay: Duration, val schedule: Schedule) : Schedule() {
        init {
            require(!delay.isNegative) { "WithInitialDelay.delay must be non-negative, got: $delay" }
        }
    }

    internal fun toTrigger(): Trigger = when (this) {
        is Cron             -> CronTrigger(expression, zone)
        is FixedRate        -> FixedRateTrigger(interval)
        is Once             -> OnceTrigger(at)
        is WithInitialDelay -> InitialDelayTrigger(delay, schedule.toTrigger())
    }

    /** Human-readable summary, e.g. `"every 5s"` or `"every 1500ms"`. */
    internal fun describe(): String = when (this) {
        is Cron             -> "cron[$zone]: $expression"
        is FixedRate        -> "every ${formatDuration(interval)}"
        is Once             -> "once at $at"
        is WithInitialDelay -> "initial-delay ${formatDuration(delay)}, then ${schedule.describe()}"
    }

    /**
     * Formats a non-negative duration as whole seconds when it has no fractional part
     * (`"5s"`), otherwise as milliseconds (`"1500ms"`). The previous `toSeconds()`
     * formatting silently truncated 500 ms to `"0s"`.
     */
    private fun formatDuration(d: Duration): String =
        if (d.nano == 0) "${d.seconds}s" else "${d.toMillis()}ms"
}

// =============================================================================
// Triggers (internal)
// =============================================================================

/**
 * Produces the sequence of planned execution times for one task.
 *
 * Implementations may be stateful: the scheduler calls [nextExecutionTime] once to
 * start a chain and then once after every dispatched slot.
 */
internal interface Trigger {
    /**
     * Computes the planned time of the next execution.
     *
     * @param lastScheduledTime epoch millis of the **previously planned** trigger time,
     *   not of the actual wall-clock start of the last run. Anchoring to the plan is
     *   what keeps FixedRate and Cron schedules from drifting under load.
     * @return epoch millis of the next execution, or `null` when the schedule is exhausted
     */
    fun nextExecutionTime(lastScheduledTime: Long): Long?
}

/**
 * [Trigger] backed by a Quartz-style cron expression evaluated in [zone].
 *
 * The expression is parsed and validated eagerly. Any failure in that process is
 * rethrown as [IllegalArgumentException] naming the offending expression.
 */
internal class CronTrigger(expression: String, private val zone: ZoneId) : Trigger {

    private val executionTime: ExecutionTime =
        try {
            val quartzDefinition = CronDefinitionBuilder.instanceDefinitionFor(CronType.QUARTZ)
            val cron = CronParser(quartzDefinition).parse(expression)
            cron.validate()
            ExecutionTime.forCron(cron)
        } catch (cause: Exception) {
            throw IllegalArgumentException("Invalid cron expression: '$expression'", cause)
        }

    /**
     * Returns the first cron slot strictly after [lastScheduledTime], or `null` if the
     * expression has no further slots.
     *
     * Anchoring to the previously *planned* time keeps the schedule on the wall-clock
     * grid. On the very first call, the caller passes "now minus 1 ms", so that a task
     * registered exactly on a boundary does not miss that slot.
     */
    override fun nextExecutionTime(lastScheduledTime: Long): Long? {
        val anchor = Instant.ofEpochMilli(lastScheduledTime).atZone(zone)
        val next = executionTime.nextExecution(anchor)
        return if (next.isPresent) next.get().toInstant().toEpochMilli() else null
    }
}

/**
 * [Trigger] that fires every [interval], measured from the previously *planned* time.
 *
 * Because the result is always `lastScheduledTime + interval`, a chain that fell more
 * than one interval behind produces times that are already in the past. The delay
 * queue treats those as due immediately, so missed slots are dispatched back-to-back.
 */
internal class FixedRateTrigger(private val interval: Duration) : Trigger {
    override fun nextExecutionTime(lastScheduledTime: Long): Long {
        val stepMillis = interval.toMillis()
        return lastScheduledTime + stepMillis
    }
}

/**
 * [Trigger] that yields [at] on its first call and `null` on every call after that,
 * whatever [Trigger.nextExecutionTime]'s argument. The first-call flag is atomic, so
 * only one caller ever receives the instant.
 */
internal class OnceTrigger(private val at: Instant) : Trigger {
    private val fired = AtomicBoolean(false)

    override fun nextExecutionTime(lastScheduledTime: Long): Long? {
        val isFirstCall = !fired.getAndSet(true)
        return if (isFirstCall) at.toEpochMilli() else null
    }
}

/**
 * [Trigger] decorator that places the first execution [delay] after the current wall
 * clock and delegates every later call to [inner].
 *
 * NOTE(review): when [inner] is a one-shot trigger, the first call here never
 * consults it. The next call then receives the inner trigger's own first time, so the
 * task fires at now+delay *and* at that time. Confirm this is intended.
 */
internal class InitialDelayTrigger(
    private val delay: Duration,
    private val inner: Trigger
) : Trigger {
    private val pendingFirst = AtomicBoolean(true)

    override fun nextExecutionTime(lastScheduledTime: Long): Long? {
        if (!pendingFirst.compareAndSet(true, false)) {
            return inner.nextExecutionTime(lastScheduledTime)
        }
        // The initial delay is measured from "now", not from any previously planned
        // time: the point is to push the first run a fixed amount into the future.
        return System.currentTimeMillis() + delay.toMillis()
    }
}

// =============================================================================
// Java-friendly listener interfaces
// =============================================================================

/** Java-friendly callback invoked before each task execution (see `TaskManager.Builder.onTaskStart`). */
fun interface TaskStartListener {
    fun onStart(event: TaskStartEvent)
}

/** Java-friendly callback invoked after each task execution (see `TaskManager.Builder.onTaskComplete`). */
fun interface TaskCompleteListener {
    fun onComplete(event: TaskCompleteEvent)
}