claude
package com.taskmanager
import com.cronutils.model.CronType
import com.cronutils.model.definition.CronDefinitionBuilder
import com.cronutils.model.time.ExecutionTime
import com.cronutils.parser.CronParser
import java.time.Duration
import java.time.Instant
import java.time.ZonedDateTime
import java.util.concurrent.*
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicLong
/**
 * A lightweight task scheduler supporting cron expressions, fixed-rate scheduling,
 * and manual task execution.
 *
 * Each instance owns its worker pool, its scheduler thread and its own copy of the
 * configured context, so several managers in one JVM do not see each other's values.
 *
 * Example usage:
 * ```kotlin
 * val tasks = TaskManager {
 *     concurrency = 4
 *     context["db"] = database
 * }
 *
 * tasks.task("heartbeat", every = Duration.ofSeconds(5)) {
 *     println("heartbeat")
 * }
 *
 * tasks.start()
 * ```
 */
class TaskManager private constructor(
    private val config: TaskManagerConfig
) {
    // Worker pool that runs task bodies. Threads are daemons, so an idle pool never keeps the JVM alive.
    private val executor: ScheduledThreadPoolExecutor = ScheduledThreadPoolExecutor(
        config.concurrency,
        ThreadFactory { runnable ->
            Thread(runnable, "${config.threadNamePrefix}-${threadCounter.incrementAndGet()}").apply {
                isDaemon = true
            }
        }
    ).apply {
        removeOnCancelPolicy = true
    }

    // Per-instance snapshot of the configured context. This used to live in a companion object,
    // which leaked values between unrelated TaskManager instances.
    private val globalContext: Map<String, Any> = HashMap<String, Any>(config.context)

    // Task registry keyed by task name.
    private val tasks = ConcurrentHashMap<String, TaskEntry>()

    // Pending executions ordered by due time. DelayQueue.take() returns an entry only once it is due,
    // and offering an earlier entry wakes the waiting scheduler, so a newly added short-interval task
    // is never stuck behind a long sleep for some later task.
    private val taskQueue = DelayQueue<TaskQueueEntry>()

    private val schedulerRunning = AtomicBoolean(false)
    private val started = AtomicBoolean(false)

    // Set by shutdown(); once set, start() does nothing.
    private val stopped = AtomicBoolean(false)

    // Single thread that waits for due queue entries and hands them to the executor.
    private val schedulerThread: Thread = Thread({ runSchedulerLoop() }, "${config.threadNamePrefix}-scheduler").apply {
        isDaemon = true
    }

    // Must come after every property start() touches, because autoStart runs it during construction.
    init {
        if (config.autoStart) {
            start()
        }
    }

    // ========================================================================
    // Task Registration
    // ========================================================================

    /**
     * Registers a task with simple scheduling parameters.
     * With neither [cron] nor [every] the task is manual-only (see [run]).
     *
     * @param name unique task identifier
     * @param cron Quartz cron expression (e.g., "0 * * * * ?" for every minute)
     * @param every fixed-rate interval; must be positive
     * @param allowConcurrent whether to allow concurrent scheduled executions of the same task
     * @param block the task execution block
     * @throws IllegalArgumentException on a blank or duplicate name, both [cron] and [every] set,
     *   a non-positive [every], or an invalid cron expression
     */
    fun task(
        name: String,
        cron: String? = null,
        every: Duration? = null,
        allowConcurrent: Boolean = false,
        block: TaskContext.() -> Any?
    ) {
        require(name.isNotBlank()) { "Task name cannot be blank" }
        require(tasks[name] == null) { "Task '$name' already registered" }
        require(cron == null || every == null) { "Cannot specify both cron and every" }
        // A zero/negative interval would make the task permanently due and spin the scheduler.
        require(every == null || (!every.isZero && !every.isNegative)) { "Interval must be positive: $every" }
        val schedule = when {
            cron != null -> Schedule.Cron(cron)
            every != null -> Schedule.FixedRate(every)
            else -> null
        }
        registerTask(name, schedule, allowConcurrent, block)
    }

    /**
     * Registers a task with advanced scheduling.
     *
     * @param name unique task identifier
     * @param schedule the scheduling strategy
     * @param allowConcurrent whether to allow concurrent scheduled executions of the same task
     * @param block the task execution block
     * @throws IllegalArgumentException on a blank or duplicate name, or an invalid cron expression
     */
    fun task(
        name: String,
        schedule: Schedule,
        allowConcurrent: Boolean = false,
        block: TaskContext.() -> Any?
    ) {
        require(name.isNotBlank()) { "Task name cannot be blank" }
        require(tasks[name] == null) { "Task '$name' already registered" }
        registerTask(name, schedule, allowConcurrent, block)
    }

    // Builds the registry entry and, if the scheduler is already running, queues its first execution.
    private fun registerTask(
        name: String,
        schedule: Schedule?,
        allowConcurrent: Boolean,
        block: TaskContext.() -> Any?
    ) {
        val entry = TaskEntry(
            name = name,
            schedule = schedule,
            trigger = schedule?.toTrigger(),
            allowConcurrent = allowConcurrent,
            block = block
        )
        // putIfAbsent makes the duplicate check atomic with the insert; the checks in task() are a fast path.
        require(tasks.putIfAbsent(name, entry) == null) { "Task '$name' already registered" }
        if (started.get() && entry.trigger != null) {
            scheduleTask(entry)
        }
    }

    // ========================================================================
    // Lifecycle
    // ========================================================================

    /**
     * Starts the task scheduler. Tasks begin executing according to their schedules.
     * Repeated calls, and calls after [shutdown], do nothing.
     */
    fun start() {
        if (stopped.get() || !started.compareAndSet(false, true)) return
        schedulerRunning.set(true)
        schedulerThread.start()
        tasks.values
            .filter { it.trigger != null && it.enabled.get() }
            .forEach { scheduleTask(it) }
    }

    /**
     * Stops the scheduler and the worker pool (including the pool used by [run], even if
     * [start] was never called). Safe to call more than once. After this, [start] is a no-op.
     *
     * @param awaitTermination if true, waits up to 30 seconds for running tasks to complete and
     *   then forcibly cancels whatever is still running; if false, cancels immediately
     */
    fun shutdown(awaitTermination: Boolean = true) {
        stopped.set(true)
        schedulerRunning.set(false)
        schedulerThread.interrupt()
        taskQueue.clear()
        executor.shutdown()
        if (!awaitTermination) {
            executor.shutdownNow()
            return
        }
        try {
            if (!executor.awaitTermination(30, TimeUnit.SECONDS)) {
                executor.shutdownNow()
            }
        } catch (e: InterruptedException) {
            executor.shutdownNow()
            // Preserve the caller's interrupt status instead of swallowing it.
            Thread.currentThread().interrupt()
        }
    }

    // ========================================================================
    // Manual Execution
    // ========================================================================

    /**
     * Manually runs a task with optional context values.
     * Manual runs bypass the task's enabled flag and its allowConcurrent setting.
     *
     * @param name the task name
     * @param contextValues additional context values for this execution
     * @return Future that completes when the task finishes
     * @throws NoSuchElementException if no task is registered under [name]
     */
    fun run(name: String, contextValues: Map<String, Any> = emptyMap()): Future<Any?> {
        val entry = tasks[name] ?: throw NoSuchElementException("Task '$name' not found")
        return executor.submit(Callable { executeTask(entry, contextValues) })
    }

    /**
     * Manually runs a task and blocks until completion.
     *
     * @param name the task name
     * @param contextValues additional context values for this execution
     * @return the task result
     * @throws java.util.concurrent.ExecutionException wrapping any failure thrown by the task
     */
    fun runBlocking(name: String, contextValues: Map<String, Any> = emptyMap()): Any? {
        return run(name, contextValues).get()
    }

    // ========================================================================
    // Task Control
    // ========================================================================

    /**
     * Enables a task's scheduled execution, queueing its next run if the scheduler is running.
     */
    fun enable(name: String) {
        val entry = tasks[name] ?: throw NoSuchElementException("Task '$name' not found")
        if (entry.enabled.compareAndSet(false, true) && started.get() && entry.trigger != null) {
            scheduleTask(entry)
        }
    }

    /**
     * Disables a task's scheduled execution.
     * Running executions are not affected.
     */
    fun disable(name: String) {
        val entry = tasks[name] ?: throw NoSuchElementException("Task '$name' not found")
        entry.enabled.set(false)
    }

    /**
     * Removes a task from the scheduler. Unknown names are ignored.
     * Running executions are not affected; already-queued executions are discarded when they come due.
     */
    fun remove(name: String) {
        tasks.remove(name)
    }

    // ========================================================================
    // Query API
    // ========================================================================

    /**
     * Checks if a task exists.
     */
    fun exists(name: String): Boolean = tasks.containsKey(name)

    /**
     * Gets the definition of a task.
     * @return TaskDefinition or null if not found
     */
    fun getTaskDefinition(name: String): TaskDefinition? {
        val entry = tasks[name] ?: return null
        return TaskDefinition(
            name = entry.name,
            scheduleDescription = entry.schedule?.describe(),
            enabled = entry.enabled.get(),
            allowConcurrent = entry.allowConcurrent
        )
    }

    /**
     * Lists all registered task names.
     */
    fun listTaskNames(): List<String> = tasks.keys.toList()

    // ========================================================================
    // Internal Scheduling Logic
    // ========================================================================

    // Body of the scheduler thread: dispatch due entries until shutdown interrupts it.
    private fun runSchedulerLoop() {
        while (schedulerRunning.get()) {
            try {
                scheduleNextTask()
            } catch (e: InterruptedException) {
                break
            } catch (e: Exception) {
                // Keep the scheduler alive; one failing dispatch must not stop every other task.
                System.err.println("Scheduler error: ${e.message}")
            }
        }
    }

    // Queues the first execution of [entry]. Bumping the generation invalidates anything already
    // queued for this task, so disable()/enable() cycles cannot leave two live schedules behind.
    private fun scheduleTask(entry: TaskEntry) {
        val generation = entry.generation.incrementAndGet()
        val nextTime = entry.trigger?.nextExecutionTime(null) ?: return
        taskQueue.offer(TaskQueueEntry(entry, nextTime, generation))
    }

    // Blocks until the earliest entry is due, dispatches it, and queues the task's following run.
    private fun scheduleNextTask() {
        val due = taskQueue.take()
        val entry = due.entry
        // Identity check: a task removed and re-registered under the same name is a different entry.
        val stale = tasks[entry.name] !== entry ||
            !entry.enabled.get() ||
            due.generation != entry.generation.get()
        if (stale) return

        executor.submit(Runnable { executeScheduledTask(entry, due.nextExecutionTime) })

        // Base the next run on the slot that just fired rather than on "now", so fixed-rate
        // schedules do not drift by the dispatch latency on every cycle.
        val nextTime = entry.trigger?.nextExecutionTime(due.nextExecutionTime)
        if (nextTime != null) {
            taskQueue.offer(TaskQueueEntry(entry, nextTime, due.generation))
        }
    }

    // Runs a scheduled execution, skipping it when a previous run of a non-concurrent task is still busy.
    private fun executeScheduledTask(entry: TaskEntry, scheduledTime: Long) {
        if (!entry.allowConcurrent && !entry.executing.compareAndSet(false, true)) {
            return
        }
        try {
            executeTask(entry, emptyMap(), scheduledTime)
        } finally {
            if (!entry.allowConcurrent) {
                entry.executing.set(false)
            }
        }
    }

    /**
     * Runs [entry]'s body once and notifies the configured listeners.
     *
     * @param additionalContext per-run values layered over the manager's context
     * @param scheduledTime due time of a scheduled run; null for manual runs
     * @return whatever the task body returned
     * @throws Throwable rethrows the task body's failure after onTaskComplete has seen it
     */
    private fun executeTask(
        entry: TaskEntry,
        additionalContext: Map<String, Any>,
        scheduledTime: Long? = null
    ): Any? {
        val executionCount = entry.executionCount.incrementAndGet()
        val startTime = System.currentTimeMillis()

        // One map (a ConcurrentHashMap, which does not accept null values) backs both the task's
        // context and the start event, so values a start listener puts in are visible to the task.
        val contextMap = ConcurrentHashMap<String, Any>(globalContext)
        contextMap.putAll(additionalContext)
        @Suppress("UNCHECKED_CAST")
        val states = contextMap as MutableMap<String, Any?>
        val context = TaskContextImpl(
            taskName = entry.name,
            executionCount = executionCount,
            states = states
        )

        config.onTaskStart?.invoke(
            TaskStartEvent(
                taskName = entry.name,
                scheduledTime = scheduledTime ?: startTime,
                actualTime = startTime,
                executionCount = executionCount,
                context = contextMap
            )
        )

        var error: Throwable? = null
        val result: Any? = try {
            entry.block.invoke(context)
        } catch (e: Throwable) {
            error = e
            null
        }

        config.onTaskComplete?.invoke(
            TaskExecution(
                taskName = entry.name,
                startTime = startTime,
                endTime = System.currentTimeMillis(),
                executionCount = executionCount,
                result = result,
                error = error
            )
        )

        if (error != null) {
            throw error
        }
        return result
    }

    // ========================================================================
    // Builder for Java
    // ========================================================================

    /** Java-friendly builder; obtain one via [TaskManager.builder]. */
    class Builder {
        private val config = TaskManagerConfig()

        fun concurrency(value: Int) = apply { config.concurrency = value }

        fun threadNamePrefix(value: String) = apply { config.threadNamePrefix = value }

        fun autoStart(value: Boolean) = apply { config.autoStart = value }

        fun onTaskStart(listener: TaskStartListener) = apply {
            config.onTaskStart = { listener.onStart(it) }
        }

        fun onTaskComplete(listener: TaskCompleteListener) = apply {
            config.onTaskComplete = { listener.onComplete(it) }
        }

        fun putContext(key: String, value: Any) = apply {
            config.context[key] = value
        }

        fun build(): TaskManager = TaskManager(config)
    }

    // ========================================================================
    // Internal Data Structures
    // ========================================================================

    // Registry record for one task. Identity matters (see scheduleNextTask), so this is not a data class.
    private class TaskEntry(
        val name: String,
        val schedule: Schedule?,
        val trigger: Trigger?,
        val allowConcurrent: Boolean,
        val block: TaskContext.() -> Any?
    ) {
        val enabled = AtomicBoolean(true)
        val executing = AtomicBoolean(false)
        val executionCount = AtomicLong(0)

        // Incremented by every scheduleTask(); queue entries carrying an older value are discarded.
        val generation = AtomicLong(0)
    }

    // One pending execution; ordered and released by DelayQueue according to nextExecutionTime.
    private class TaskQueueEntry(
        val entry: TaskEntry,
        val nextExecutionTime: Long,
        val generation: Long
    ) : Delayed {
        override fun getDelay(unit: TimeUnit): Long =
            unit.convert(nextExecutionTime - System.currentTimeMillis(), TimeUnit.MILLISECONDS)

        override fun compareTo(other: Delayed): Int =
            if (other is TaskQueueEntry) {
                nextExecutionTime.compareTo(other.nextExecutionTime)
            } else {
                getDelay(TimeUnit.MILLISECONDS).compareTo(other.getDelay(TimeUnit.MILLISECONDS))
            }
    }

    // ========================================================================
    // Static
    // ========================================================================

    // Kotlin allows only one companion object per class; both former companions are merged here.
    companion object {
        // Shared across instances only to keep worker thread names unique within the JVM.
        private val threadCounter = AtomicLong(0)

        /** Returns a Java-friendly [Builder]. */
        @JvmStatic
        fun builder(): Builder = Builder()

        /** Creates a manager configured by [config], e.g. `TaskManager { concurrency = 4 }`. */
        operator fun invoke(config: TaskManagerConfig.() -> Unit = {}): TaskManager =
            TaskManager(TaskManagerConfig().apply(config))
    }
}
// ============================================================================
// Configuration
// ============================================================================
/**
 * Mutable settings read by [TaskManager] when it is constructed.
 */
class TaskManagerConfig {
    /** Core size of the worker pool; defaults to the number of available processors. */
    var concurrency: Int = Runtime.getRuntime().availableProcessors()

    /** Prefix used when naming worker and scheduler threads. */
    var threadNamePrefix: String = "task-manager"

    /** When true, scheduling begins as soon as the manager is constructed. */
    var autoStart: Boolean = false

    /** Shared key/value pairs made available through each task's [TaskContext]. */
    val context: MutableMap<String, Any> = ConcurrentHashMap<String, Any>()

    /** Called before each task body runs; null means no listener. */
    var onTaskStart: ((TaskStartEvent) -> Unit)? = null

    /** Called after each task body finishes, whether it succeeded or threw; null means no listener. */
    var onTaskComplete: ((TaskExecution) -> Unit)? = null
}
// ============================================================================
// Events
// ============================================================================
/**
 * Emitted just before a task body starts running.
 *
 * @property taskName name of the task about to run
 * @property scheduledTime epoch millis at which the run was due (equal to [actualTime] when no
 *   separate due time is known)
 * @property actualTime epoch millis at which the run actually started
 * @property executionCount 1-based count of executions of this task, including this one
 * @property context the mutable context map for this run
 */
data class TaskStartEvent(
    val taskName: String,
    val scheduledTime: Long,
    val actualTime: Long,
    val executionCount: Long,
    val context: MutableMap<String, Any>
)
/**
 * Outcome of one task execution, delivered to the completion listener.
 *
 * @property taskName name of the task that ran
 * @property startTime epoch millis when the body started
 * @property endTime epoch millis when the body finished
 * @property executionCount 1-based count of executions of this task, including this one
 * @property result value returned by the body, or null if it threw
 * @property error the throwable raised by the body, or null on success
 */
data class TaskExecution(
    val taskName: String,
    val startTime: Long,
    val endTime: Long,
    val executionCount: Long,
    val result: Any?,
    val error: Throwable?
) {
    /** Elapsed wall-clock time in milliseconds. */
    val duration: Long
        get() {
            return endTime - startTime
        }

    /** True when the body completed without throwing. */
    val isSuccess: Boolean
        get() {
            return error === null
        }
}
// ============================================================================
// Context
// ============================================================================
/**
 * Per-execution view handed to a task body as its receiver: identifies the run and
 * exposes a key/value store of context values.
 */
interface TaskContext {
    /** Name of the task being executed. */
    val taskName: String

    /** 1-based number of this execution of the task. */
    val executionCount: Long

    /** Returns the value for [key]; implementations signal a missing key by throwing. */
    operator fun <T : Any> get(key: String): T

    /** Returns the value for [key], or null when it is absent. */
    fun <T> getOrNull(key: String): T?

    /** Returns the value for [key], or [default] when it is absent. */
    fun <T> getOrDefault(key: String, default: T): T

    /** Stores [value] under [key]. */
    operator fun set(key: String, value: Any?)

    /** Deletes [key] from the store. */
    fun remove(key: String)
}
/**
 * Default [TaskContext] backed by a mutable map.
 *
 * Note: the generic casts below are unchecked because of type erasure. A value of the wrong
 * type is not detected here; it surfaces as a ClassCastException where the caller uses it.
 *
 * @param taskName name of the task being executed
 * @param executionCount 1-based execution number
 * @param states backing store; defaults to a ConcurrentHashMap
 */
internal class TaskContextImpl(
    override val taskName: String,
    override val executionCount: Long,
    private val states: MutableMap<String, Any?> = ConcurrentHashMap()
) : TaskContext {
    /** @throws NoSuchElementException when [key] is absent or maps to null */
    @Suppress("UNCHECKED_CAST")
    override fun <T : Any> get(key: String): T {
        return states[key] as? T
            ?: throw NoSuchElementException("Context key '$key' not found")
    }

    @Suppress("UNCHECKED_CAST")
    override fun <T> getOrNull(key: String): T? = states[key] as? T

    @Suppress("UNCHECKED_CAST")
    override fun <T> getOrDefault(key: String, default: T): T {
        return states[key] as? T ?: default
    }

    /**
     * Stores [value] under [key]. A null [value] removes the key instead: ConcurrentHashMap (the
     * default store) rejects null values with a NullPointerException, and every getter already
     * treats a missing key and a null value identically.
     */
    override fun set(key: String, value: Any?) {
        if (value == null) {
            states.remove(key)
        } else {
            states[key] = value
        }
    }

    override fun remove(key: String) {
        states.remove(key)
    }
}
// ============================================================================
// Query Results
// ============================================================================
/**
 * Read-only snapshot of a registered task.
 *
 * @property name the unique task name
 * @property scheduleDescription human-readable schedule, or null for a task registered without one
 * @property enabled whether scheduled execution was enabled when the snapshot was taken
 * @property allowConcurrent whether overlapping scheduled executions are permitted
 */
data class TaskDefinition(
    val name: String,
    val scheduleDescription: String?,
    val enabled: Boolean,
    val allowConcurrent: Boolean
)
// ============================================================================
// Schedule
// ============================================================================
/**
 * Describes when a task should run. Each variant maps to an internal [Trigger] via [toTrigger].
 */
sealed class Schedule {
    /** Quartz-style cron [expression], e.g. `"0 * * * * ?"` for the start of every minute. */
    data class Cron(val expression: String) : Schedule()

    /** Runs every [interval], measured from the previous run. */
    data class FixedRate(val interval: Duration) : Schedule()

    /** Runs every [delay]; its trigger currently computes times the same way as [FixedRate]. */
    data class FixedDelay(val delay: Duration) : Schedule()

    /** Runs a single time at [at]. */
    data class Once(val at: Instant) : Schedule()

    /** Waits [delay] before the first run, then follows [schedule]. */
    data class WithInitialDelay(val delay: Duration, val schedule: Schedule) : Schedule()

    internal fun toTrigger(): Trigger? {
        return when (this) {
            is Cron -> CronTrigger(expression)
            is FixedRate -> FixedRateTrigger(interval)
            is FixedDelay -> FixedDelayTrigger(delay)
            is Once -> OnceTrigger(at)
            is WithInitialDelay -> InitialDelayTrigger(delay, schedule.toTrigger())
        }
    }

    internal fun describe(): String {
        return when (this) {
            is Cron -> "cron: $expression"
            is FixedRate -> "every: ${formatDuration(interval)}"
            is FixedDelay -> "fixed-delay: ${formatDuration(delay)}"
            is Once -> "once at: $at"
            is WithInitialDelay -> "initial-delay: ${formatDuration(delay)}, then ${schedule.describe()}"
        }
    }

    // Whole-second durations keep the "<n>s" form; sub-second precision is shown in milliseconds
    // instead of being truncated (previously 500ms was described as "0s").
    private fun formatDuration(duration: Duration): String =
        if (duration.nano == 0) "${duration.seconds}s" else "${duration.toMillis()}ms"
}
// ============================================================================
// Triggers (Internal)
// ============================================================================
/**
 * Computes the epoch-millis time of a task's next execution.
 */
internal interface Trigger {
    /**
     * @param lastExecution epoch millis of the previous run's slot, or null when scheduling the first run
     * @return epoch millis of the next run, or null if the task should not run again
     */
    fun nextExecutionTime(lastExecution: Long?): Long?
}
/**
 * Trigger backed by a Quartz-format cron expression (parsed with cron-utils).
 *
 * @param expression the cron expression
 * @throws IllegalArgumentException if the expression cannot be parsed or validated
 */
internal class CronTrigger(expression: String) : Trigger {
    private val executionTime: ExecutionTime

    init {
        try {
            val cronDefinition = CronDefinitionBuilder.instanceDefinitionFor(CronType.QUARTZ)
            val parser = CronParser(cronDefinition)
            val cron = parser.parse(expression)
            cron.validate()
            executionTime = ExecutionTime.forCron(cron)
        } catch (e: Exception) {
            throw IllegalArgumentException("Invalid cron expression: $expression", e)
        }
    }

    override fun nextExecutionTime(lastExecution: Long?): Long? {
        val now = ZonedDateTime.now()
        val nowMillis = now.toInstant().toEpochMilli()
        // Anchor on the previous slot so the schedule stays tied to cron slots rather than to
        // whenever the scheduler happened to wake up (the original ignored lastExecution).
        val afterLast = lastExecution?.let { nextAfter(Instant.ofEpochMilli(it).atZone(now.zone)) }
        if (afterLast != null && afterLast > nowMillis) {
            return afterLast
        }
        // First run, or the previous slot's successor is already past: skip missed slots.
        return nextAfter(now)
    }

    // Epoch millis of the first cron slot after [time], or null if the expression has no more slots.
    private fun nextAfter(time: ZonedDateTime): Long? =
        executionTime.nextExecution(time).map { it.toInstant().toEpochMilli() }.orElse(null)
}
/**
 * Fires every [interval]: the next time is the previous slot plus [interval], or now plus
 * [interval] when there is no previous slot.
 */
internal class FixedRateTrigger(private val interval: Duration) : Trigger {
    override fun nextExecutionTime(lastExecution: Long?): Long =
        interval.toMillis() + (lastExecution ?: System.currentTimeMillis())
}
/**
 * Intended to fire [delay] after each run completes.
 *
 * NOTE: completion times are not tracked yet, so this currently computes exactly like a
 * fixed-rate trigger: previous slot (or now) plus [delay].
 */
internal class FixedDelayTrigger(private val delay: Duration) : Trigger {
    override fun nextExecutionTime(lastExecution: Long?): Long {
        val anchor = if (lastExecution != null) lastExecution else System.currentTimeMillis()
        return anchor + delay.toMillis()
    }
}
/**
 * Fires exactly once at [at]; every later call returns null.
 */
internal class OnceTrigger(private val at: Instant) : Trigger {
    // Atomic because the trigger is queried from caller threads (enable/registration) and from the
    // scheduler thread; a plain var gave no visibility guarantee and let two racing callers both fire.
    private val executed = AtomicBoolean(false)

    override fun nextExecutionTime(lastExecution: Long?): Long? =
        if (executed.compareAndSet(false, true)) at.toEpochMilli() else null
}
/**
 * Fires once [delay] after the first query, then delegates to [innerTrigger]
 * (or stops if it is null).
 */
internal class InitialDelayTrigger(
    private val delay: Duration,
    private val innerTrigger: Trigger?
) : Trigger {
    // Atomic for the same reason as OnceTrigger: queried from both caller threads and the
    // scheduler thread, so exactly one caller must observe the "first call".
    private val firstCall = AtomicBoolean(true)

    override fun nextExecutionTime(lastExecution: Long?): Long? {
        if (firstCall.compareAndSet(true, false)) {
            return System.currentTimeMillis() + delay.toMillis()
        }
        return innerTrigger?.nextExecutionTime(lastExecution)
    }
}
// ============================================================================
// Java Friendly Interfaces
// ============================================================================
/**
 * Java-friendly callback invoked before each task body runs.
 */
interface TaskStartListener {
    /** @param event details of the execution that is about to start */
    fun onStart(event: TaskStartEvent)
}
/**
 * Java-friendly callback invoked after each task body finishes, successfully or not.
 */
interface TaskCompleteListener {
    /** @param execution outcome of the finished execution */
    fun onComplete(execution: TaskExecution)
}