Files
meijiaka-zy/python-api/app/scheduler/registry.py
T
2026-05-04 19:18:22 +08:00

144 lines
4.6 KiB
Python

"""
作业注册表 - Redis 运行时状态读写
==================================
所有 running 作业的状态统一存储在 Redis 中,供 Scheduler Tick 读取、更新。
"""
import json
import logging
from datetime import UTC, datetime
from typing import Any

from redis.asyncio import Redis

from app.scheduler.models import TaskRecord
# Module-level logger, named after this module's dotted import path.
logger: logging.Logger = logging.getLogger(name=__name__)
# Redis SET holding the IDs of every task currently marked as running.
KEY_RUNNING_SET: str = "scheduler:running_tasks"
def _task_key(task_id: str) -> str:
return f"task:{task_id}"
class TaskRegistry:
    """Redis-backed task registry.

    Every task is stored as a Redis hash under ``_task_key(task_id)``; the IDs
    of tasks currently marked running are kept in the global set
    ``KEY_RUNNING_SET``. All hash values are written as strings and decoded
    back into a ``TaskRecord`` by :meth:`get`.

    NOTE(review): field decoding compares hash keys against ``str`` literals,
    so the client is presumably built with ``decode_responses=True`` —
    confirm where the ``Redis`` instance is constructed.
    """

    # Hash fields stored as JSON text (see create() / update()).
    _JSON_FIELDS = frozenset({"result", "params"})
    # Hash fields holding integer counters.
    _INT_FIELDS = frozenset({"progress", "completed", "total"})

    def __init__(self, redis: Redis):
        """Wrap an existing async Redis client; no I/O happens here."""
        self.redis = redis

    async def create(
        self,
        task_id: str,
        task_type: str,
        user_id: str,
        status: str = "pending",
        params: dict[str, Any] | None = None,
        ttl: int = 86400,
    ) -> None:
        """Create a new task record.

        The hash and its expiry are written in one MULTI/EXEC transaction, so
        a failure between the two commands cannot leave a record that never
        expires.

        Args:
            task_id: Unique task identifier.
            task_type: Task type, stored in the ``type`` field.
            user_id: Owner of the task.
            status: Initial status (default ``"pending"``).
            params: Optional parameters, stored JSON-encoded; omitted when empty.
            ttl: Lifetime of the record in seconds (default 86400 = one day).
        """
        data = {
            "type": task_type,
            "user_id": user_id,
            "status": status,
            "progress": "0",
            "message": "等待执行...",
            "completed": "0",
            "total": "0",
            "created_at": datetime.now(UTC).isoformat(),
        }
        if params:
            data["params"] = json.dumps(params, ensure_ascii=False)
        key = _task_key(task_id)
        async with self.redis.pipeline(transaction=True) as pipe:
            # Pipeline commands are buffered (not awaited) until execute().
            pipe.hset(key, mapping=data)
            pipe.expire(key, ttl)
            await pipe.execute()
        logger.debug("Registry created: %s, type=%s", task_id, task_type)

    @staticmethod
    def _encode_value(value: Any) -> str:
        """Serialize one field value to the string form stored in the hash."""
        if isinstance(value, dict | list):
            return json.dumps(value, ensure_ascii=False)
        if value is None:
            return ""
        return str(value)

    async def update(self, task_id: str, **fields: Any) -> None:
        """Overwrite the given fields of a task record.

        ``dict``/``list`` values are JSON-encoded, ``None`` is stored as an
        empty string and everything else via ``str()``. Calling with no
        fields is a no-op (redis-py raises ``DataError`` for an HSET without
        any field/value pairs).

        NOTE(review): HSET on a key that has already expired recreates the
        hash *without* a TTL; callers presumably only update live tasks —
        confirm.
        """
        if not fields:
            return
        mapping = {key: self._encode_value(value) for key, value in fields.items()}
        await self.redis.hset(_task_key(task_id), mapping=mapping)
        logger.debug("Registry updated: %s, fields=%s", task_id, list(fields))

    @classmethod
    def _decode_field(cls, key: str, raw: str) -> Any:
        """Decode one raw hash value according to the field it belongs to.

        JSON fields fall back to the raw string when they are not valid JSON;
        counter fields fall back to 0 when they cannot be parsed.
        """
        if key in cls._JSON_FIELDS and raw:
            try:
                return json.loads(raw)
            except json.JSONDecodeError:
                return raw
        if key in cls._INT_FIELDS:
            try:
                return int(raw)
            except ValueError:
                pass
            try:
                # update() stores str(value), so a float counter such as 12.5
                # arrives as "12.5"; truncate instead of discarding it.
                return int(float(raw))
            except (ValueError, OverflowError):
                return 0
        return raw

    @classmethod
    def _to_record(cls, task_id: str, data: dict[str, str]) -> TaskRecord:
        """Build a ``TaskRecord`` from the raw (non-empty) hash of one task."""
        parsed = {key: cls._decode_field(key, raw) for key, raw in data.items()}
        params_raw = parsed.get("params", {})
        # Non-dict params (invalid JSON, empty string, JSON list) degrade to {}.
        params = params_raw if isinstance(params_raw, dict) else {}
        return TaskRecord(
            task_id=task_id,
            task_type=parsed.get("type", ""),
            user_id=parsed.get("user_id", ""),
            project_id=str(params.get("project_id", "")),
            status=parsed.get("status", "unknown"),
            progress=parsed.get("progress", 0),
            message=parsed.get("message", ""),
            completed=parsed.get("completed", 0),
            total=parsed.get("total", 0),
            result=parsed.get("result", {}),
            error=parsed.get("error"),
            params=params,
            created_at=parsed.get("created_at", ""),
        )

    async def get(self, task_id: str) -> TaskRecord | None:
        """Read a full task record, or ``None`` if the hash is missing/expired."""
        data = await self.redis.hgetall(_task_key(task_id))
        if not data:
            return None
        return self._to_record(task_id, data)

    async def add_running(self, task_id: str) -> None:
        """Mark a task as running by adding it to the global running set."""
        await self.redis.sadd(KEY_RUNNING_SET, task_id)

    async def remove_running(self, task_id: str) -> None:
        """Remove a task from the global running set."""
        await self.redis.srem(KEY_RUNNING_SET, task_id)

    async def get_running_task_ids(self) -> list[str]:
        """Return the IDs of all tasks in the running set (unordered)."""
        members = await self.redis.smembers(KEY_RUNNING_SET)
        return list(members)

    async def list_running_by_user(self, user_id: str) -> list[TaskRecord]:
        """Return every running task owned by *user_id*.

        All hashes are fetched in a single non-transactional pipeline (one
        round trip) instead of one HGETALL per task. IDs whose hash no longer
        exists are skipped; they are not removed from the running set here.
        """
        task_ids = await self.get_running_task_ids()
        if not task_ids:
            return []
        async with self.redis.pipeline(transaction=False) as pipe:
            for task_id in task_ids:
                pipe.hgetall(_task_key(task_id))
            hashes = await pipe.execute()
        results: list[TaskRecord] = []
        for task_id, data in zip(task_ids, hashes):
            if not data:
                continue
            task = self._to_record(task_id, data)
            if task.user_id == user_id:
                results.append(task)
        return results

    async def delete(self, task_id: str) -> None:
        """Delete a task record and drop it from the running set atomically."""
        async with self.redis.pipeline(transaction=True) as pipe:
            pipe.delete(_task_key(task_id))
            pipe.srem(KEY_RUNNING_SET, task_id)
            await pipe.execute()