173 lines
6.4 KiB
Python
173 lines
6.4 KiB
Python
"""
Async Engine core scheduler
===========================

Drives the tick loop of every handler, with batched queries and batched updates.
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
from datetime import UTC, datetime
|
|
from typing import Any
|
|
|
|
from app.core.redis_client import get_redis_client
|
|
from app.scheduler.handlers.base import AsyncHandler
|
|
from app.scheduler.models import StateChange
|
|
from app.scheduler.registry import TaskRegistry
|
|
from app.scheduler.slot_manager import SlotManager
|
|
|
|
logger = logging.getLogger(__name__)

# Maximum run time per task type, in seconds. A task that runs longer than
# this is automatically marked as failed.
TASK_TIMEOUT_SECONDS = {
    task_type: minutes * 60
    for task_type, minutes in (("script", 5), ("subtitle", 10), ("video", 30))
}
|
|
|
|
|
|
class AsyncEngine:
    """Unified asynchronous job scheduling engine.

    Each tick loads the ids of the running tasks from the registry, fails the
    tasks that exceeded their per-type timeout, hands the remaining tasks to
    the handler registered under their ``task_type``, writes all collected
    state changes to Redis through one pipeline, and finally removes finished
    tasks from the running set.
    """

    def __init__(self, handlers: list[AsyncHandler] | None = None):
        """Create the engine and register the optional initial handlers.

        Args:
            handlers: Handlers to register up front; each is stored under its
                ``name`` attribute. A later handler with the same name
                replaces an earlier one.
        """
        self.redis = get_redis_client()
        self.registry = TaskRegistry(self.redis)
        self.slots = SlotManager(self.redis)
        self.handlers: dict[str, AsyncHandler] = {}
        for handler in handlers or ():
            self.handlers[handler.name] = handler

    def register(self, handler: AsyncHandler) -> None:
        """Register a handler, replacing any handler with the same name."""
        self.handlers[handler.name] = handler
        logger.info("Registered handler: %s", handler.name)

    async def tick(self) -> None:
        """Run one complete scheduling tick.

        Exceptions raised while processing are logged and swallowed so that
        the caller's loop keeps running. A heartbeat is written to Redis at
        the end of every tick, including empty and failed ones.
        """
        loop = asyncio.get_running_loop()
        tick_start = loop.time()

        try:
            # 1. Load the ids of all running tasks.
            running_ids = await self.registry.get_running_task_ids()
            if not running_ids:
                logger.debug("Tick: no running tasks")
                return

            # 2. Group the records by task_type and detect timed-out tasks.
            tasks_by_type: dict[str, list[Any]] = {}
            timeout_changes: list[StateChange] = []
            timed_out_ids: list[Any] = []
            now = datetime.now(UTC)

            for task_id in running_ids:
                record = await self.registry.get(task_id)
                if not record:
                    # The record is gone; drop the dangling id.
                    await self.registry.remove_running(task_id)
                    continue

                max_duration = TASK_TIMEOUT_SECONDS.get(record.task_type)
                if max_duration and record.status == "running" and record.created_at:
                    # The timeout is measured from created_at.
                    age = self._age_seconds(record.created_at, now)
                    if age is None:
                        # A single bad timestamp must not abort the whole tick.
                        logger.warning(
                            "Task %s has unparseable created_at=%r; skipping timeout check",
                            task_id,
                            record.created_at,
                        )
                    elif age > max_duration:
                        logger.warning(
                            "Task timeout: %s, type=%s, created_at=%s",
                            task_id,
                            record.task_type,
                            record.created_at,
                        )
                        timeout_changes.extend(self._timeout_changes(task_id, max_duration))
                        timed_out_ids.append(task_id)
                        continue

                tasks_by_type.setdefault(record.task_type, []).append(record)

            # 3. Run every handler's tick concurrently.
            results = await asyncio.gather(
                *(
                    self._safe_tick(name, handler, tasks_by_type.get(name, []))
                    for name, handler in self.handlers.items()
                )
            )

            # 4. Collect and apply the state changes (timeouts included).
            all_changes: list[StateChange] = [
                change for changes in results if changes for change in changes
            ]
            all_changes.extend(timeout_changes)
            if all_changes:
                await self._apply_changes(all_changes)

            # Remove timed-out tasks from the running set only after their
            # "failed" state has been written. If the write fails they stay in
            # the running set and the timeout is retried on the next tick,
            # instead of being left in "running" status with nothing watching.
            for task_id in timed_out_ids:
                await self.registry.remove_running(task_id)

            # 5. Clean up finished tasks.
            await self._cleanup_finished()

        except Exception:
            logger.exception("Scheduler tick failed")
        finally:
            elapsed = loop.time() - tick_start
            logger.debug("Tick completed in %.2fs", elapsed)
            await self._write_heartbeat()

    @staticmethod
    def _age_seconds(created_at: Any, now: datetime) -> float | None:
        """Return the seconds elapsed between ``created_at`` and ``now``.

        Args:
            created_at: ISO-8601 string (or ``datetime``) of the creation time.
            now: Timezone-aware reference time.

        Returns:
            The age in seconds, or ``None`` when ``created_at`` cannot be
            interpreted as a timestamp.
        """
        if isinstance(created_at, datetime):
            created = created_at
        else:
            try:
                created = datetime.fromisoformat(created_at)
            except (TypeError, ValueError):
                return None
        if created.tzinfo is None:
            # Subtracting a naive from an aware datetime raises TypeError.
            # NOTE(review): naive timestamps are assumed to be UTC - confirm
            # against the code that writes created_at.
            created = created.replace(tzinfo=UTC)
        return (now - created).total_seconds()

    @staticmethod
    def _timeout_changes(task_id: Any, max_duration: int) -> list[StateChange]:
        """Build the state changes that mark a timed-out task as failed."""
        return [
            StateChange(task_id=task_id, field_path="status", value="failed"),
            StateChange(
                task_id=task_id,
                field_path="message",
                value="任务执行超时,请稍后重试",
            ),
            StateChange(
                task_id=task_id,
                field_path="error",
                value=f"任务执行超过 {max_duration} 秒",
            ),
        ]

    async def _write_heartbeat(self) -> None:
        """Write the scheduler heartbeat key (60 s expiry) for the healthcheck.

        The stored value is the event loop's clock reading, not wall-clock
        time. Failures are logged rather than raised so that a Redis outage
        cannot escape ``tick`` and terminate ``run_forever``.
        """
        try:
            await self.redis.set(
                "scheduler:heartbeat",
                str(asyncio.get_running_loop().time()),
                ex=60,
            )
        except Exception:
            logger.exception("Failed to write scheduler heartbeat")

    async def _safe_tick(
        self, name: str, handler: AsyncHandler, tasks: list[Any]
    ) -> list[StateChange]:
        """Run one handler's tick, logging and swallowing any exception.

        Args:
            name: Handler name, used only for logging.
            handler: The handler to run.
            tasks: Task records whose ``task_type`` equals the handler name.

        Returns:
            The state changes produced by the handler, or an empty list when
            the handler raised.
        """
        try:
            return await handler.tick(tasks, self.registry, self.slots)
        except Exception:
            logger.exception("Handler tick failed: %s", name)
            return []

    async def _apply_changes(self, changes: list[StateChange]) -> None:
        """Apply the state changes to Redis through a single pipeline."""
        if not changes:
            return
        pipe = self.redis.pipeline()
        for change in changes:
            key, field, value = change.to_redis_command()
            pipe.hset(key, field, value)
        await pipe.execute()

    async def _cleanup_finished(self) -> None:
        """Remove finished or missing tasks from the running set."""
        running_ids = await self.registry.get_running_task_ids()
        for task_id in running_ids:
            record = await self.registry.get(task_id)
            if not record:
                await self.registry.remove_running(task_id)
                continue
            if record.status in ("completed", "failed"):
                await self.registry.remove_running(task_id)
                logger.info("Task moved to finished: %s (%s)", task_id, record.status)

    async def run_forever(self, interval: float = 10.0, min_interval: float = 2.0) -> None:
        """Run ticks in an endless loop.

        Args:
            interval: Target number of seconds between the starts of two ticks.
            min_interval: Minimum pause after a tick, applied even when the
                tick itself took longer than ``interval``.
        """
        logger.info("Async Engine started")
        loop = asyncio.get_running_loop()
        while True:
            tick_start = loop.time()
            await self.tick()
            elapsed = loop.time() - tick_start
            await asyncio.sleep(max(interval - elapsed, min_interval))