Files

148 lines
5.0 KiB
Python

"""
作业注册表 - Redis 运行时状态读写
==================================
所有 running 作业的状态统一存储在 Redis 中,供 Scheduler Tick 读取、更新。
"""
import json
import logging
from collections.abc import Awaitable
from datetime import UTC
from typing import Any, cast
from redis.asyncio import Redis
from app.scheduler.models import TaskRecord
# Module-level logger; the registry emits its create/update traces at DEBUG level.
logger = logging.getLogger(__name__)
# Redis SET holding the IDs of all tasks currently marked as running
# (written by TaskRegistry.add_running / remove_running / delete).
KEY_RUNNING_SET = "scheduler:running_tasks"
def _task_key(task_id: str) -> str:
return f"task:{task_id}"
class TaskRegistry:
    """Redis-backed task registry.

    Each task is stored as a Redis hash under ``_task_key(task_id)``. The IDs
    of running tasks are additionally tracked in the ``KEY_RUNNING_SET`` set.
    Every hash value is written as a string and converted back in ``get``.

    NOTE(review): field parsing compares keys against ``str`` literals, so the
    client is presumably created with ``decode_responses=True`` -- confirm.
    """

    def __init__(self, redis: Redis):
        """Wrap an existing async Redis client; this class never closes it."""
        self.redis = redis

    async def create(
        self,
        task_id: str,
        task_type: str,
        user_id: str,
        status: str = "pending",
        params: dict[str, Any] | None = None,
        ttl: int = 86400,
    ) -> None:
        """Create a new task record.

        The initial hash fields and the key expiry are written in a single
        MULTI/EXEC transaction. This means the record can never be left behind
        without a TTL if the process fails between the two commands.

        Args:
            task_id: Unique task identifier.
            task_type: Task type, stored under the ``type`` field.
            user_id: Owning user.
            status: Initial status (default ``"pending"``).
            params: Optional parameters, stored as JSON under ``params``.
                ``None`` or an empty dict stores no ``params`` field.
            ttl: Lifetime of the record in seconds (default 86400 = one day).
        """
        from datetime import datetime

        key = _task_key(task_id)
        data = {
            "type": task_type,
            "user_id": user_id,
            "status": status,
            "progress": "0",
            "message": "等待执行...",
            "completed": "0",
            "total": "0",
            "created_at": datetime.now(UTC).isoformat(),
        }
        if params:
            data["params"] = json.dumps(params, ensure_ascii=False)
        # Queue HSET + EXPIRE and run them atomically (MULTI/EXEC).
        async with self.redis.pipeline(transaction=True) as pipe:
            pipe.hset(key, mapping=data)
            pipe.expire(key, ttl)
            await pipe.execute()
        logger.debug("Registry created: %s, type=%s", task_id, task_type)

    @staticmethod
    def _encode(value: Any) -> str:
        """Serialize one field value into the string stored in the Redis hash.

        dict/list values become JSON, ``None`` becomes ``""``, and anything
        else is converted with ``str()``.
        """
        if isinstance(value, dict | list):
            return json.dumps(value, ensure_ascii=False)
        if value is None:
            return ""
        return str(value)

    async def update(self, task_id: str, **fields: Any) -> None:
        """Update one or more fields of a task record.

        Values are encoded with ``_encode``. Calling this with no fields is a
        no-op.

        NOTE(review): HSET recreates the hash if the record has already
        expired, and the recreated key has no TTL. Confirm whether callers can
        update a task after its expiry.
        """
        if not fields:
            # redis-py raises DataError for HSET with an empty mapping.
            return
        mapping = {key: self._encode(value) for key, value in fields.items()}
        await cast(Awaitable[int], self.redis.hset(_task_key(task_id), mapping=mapping))
        logger.debug("Registry updated: %s, fields=%s", task_id, list(fields))

    @staticmethod
    def _parse_field(key: str, raw: str) -> Any:
        """Convert one stored hash value back into its Python value.

        - ``result``/``params``: JSON-decoded. Invalid JSON is returned as-is,
          and an empty string is treated as absent (``{}``).
        - ``progress``/``completed``/``total``: integers. Float strings such as
          ``"12.5"`` are truncated; anything unparsable becomes ``0``.
        - ``error``/``error_code``: an empty string becomes ``None``.
        - Every other field is returned unchanged.
        """
        if key in ("result", "params"):
            if not raw:
                # Written as None by ``update``: treat it like a missing field.
                return {}
            try:
                return json.loads(raw)
            except json.JSONDecodeError:
                return raw
        if key in ("progress", "completed", "total"):
            try:
                return int(raw)
            except ValueError:
                pass
            try:
                # ``update`` stores str(value), so a float becomes "12.5".
                return int(float(raw))
            except (ValueError, OverflowError):
                return 0
        if key in ("error_code", "error") and raw == "":
            return None
        return raw

    @classmethod
    def _to_record(cls, task_id: str, data: dict[Any, Any]) -> TaskRecord | None:
        """Build a ``TaskRecord`` from a raw HGETALL reply.

        Returns ``None`` for an empty reply, i.e. the hash does not exist.
        """
        if not data:
            return None
        parsed = {k: cls._parse_field(k, v) for k, v in data.items()}
        params_raw = parsed.get("params", {})
        params = params_raw if isinstance(params_raw, dict) else {}
        return TaskRecord(
            task_id=task_id,
            task_type=parsed.get("type", ""),
            user_id=parsed.get("user_id", ""),
            project_id=str(params.get("project_id", "")),
            status=parsed.get("status", "unknown"),
            progress=parsed.get("progress", 0),
            message=parsed.get("message", ""),
            completed=parsed.get("completed", 0),
            total=parsed.get("total", 0),
            result=parsed.get("result", {}),
            error=parsed.get("error"),
            error_code=parsed.get("error_code"),
            params=params,
            created_at=parsed.get("created_at", ""),
        )

    async def get(self, task_id: str) -> TaskRecord | None:
        """Read a full task record, or ``None`` if it does not exist."""
        data = await cast(Awaitable[dict[Any, Any]], self.redis.hgetall(_task_key(task_id)))
        return self._to_record(task_id, data)

    async def add_running(self, task_id: str) -> None:
        """Mark a task as running by adding its ID to the global running set."""
        await cast(Awaitable[int], self.redis.sadd(KEY_RUNNING_SET, task_id))

    async def remove_running(self, task_id: str) -> None:
        """Remove a task ID from the global running set."""
        await cast(Awaitable[int], self.redis.srem(KEY_RUNNING_SET, task_id))

    async def get_running_task_ids(self) -> list[str]:
        """Return the IDs of all tasks in the running set (in no particular order)."""
        members = await cast(Awaitable[set[Any]], self.redis.smembers(KEY_RUNNING_SET))
        return list(members)

    async def list_running_by_user(self, user_id: str) -> list[TaskRecord]:
        """Return every running task owned by ``user_id``.

        All task hashes are fetched in one pipelined round trip instead of one
        HGETALL per task. IDs whose hash no longer exists (for example, because
        it expired) are skipped, but their entries are left in the running set.
        """
        task_ids = await self.get_running_task_ids()
        if not task_ids:
            return []
        # Non-transactional pipeline: only batches the reads into one round trip.
        async with self.redis.pipeline(transaction=False) as pipe:
            for task_id in task_ids:
                pipe.hgetall(_task_key(task_id))
            replies = await pipe.execute()
        results: list[TaskRecord] = []
        for task_id, data in zip(task_ids, replies, strict=True):
            task = self._to_record(task_id, data)
            if task is not None and task.user_id == user_id:
                results.append(task)
        return results

    async def delete(self, task_id: str) -> None:
        """Delete a task record and remove its ID from the running set.

        Both commands run in one MULTI/EXEC transaction, so the running set
        never keeps the ID of a hash that has already been deleted.
        """
        async with self.redis.pipeline(transaction=True) as pipe:
            pipe.delete(_task_key(task_id))
            pipe.srem(KEY_RUNNING_SET, task_id)
            await pipe.execute()