206 lines
7.4 KiB
Python
206 lines
7.4 KiB
Python
"""
Script 任务处理器
================

管理脚本生成任务的执行。

不占用 Volc 槽位，使用独立的 script 槽位池。
"""
|
|
|
|
import logging
|
|
from typing import Any
|
|
|
|
from app.ai.prompts.loader import _load_system_meta
|
|
from app.core.exceptions import InsufficientPointsException
|
|
from app.core.platform_config import get_platform_config_loader
|
|
from app.db.session import AsyncSessionLocal
|
|
from app.scheduler.handlers.base import AsyncHandler
|
|
from app.scheduler.models import StateChange
|
|
from app.scheduler.registry import TaskRegistry
|
|
from app.scheduler.slot_manager import SlotManager
|
|
from app.services import point_service as ps
|
|
from app.services.script_service import ScriptService
|
|
|
|
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _get_category_name(category: str, filename: str) -> str:
|
|
"""从文件名解析文案作为 title"""
|
|
meta = _load_system_meta()
|
|
cat_name = category
|
|
for cat in meta.get("categories", []):
|
|
if cat.get("code") == category:
|
|
cat_name = cat.get("name", category)
|
|
break
|
|
# 从文件名解析文案(前半部分)
|
|
if filename:
|
|
name = filename.replace(".txt", "")
|
|
if "——" in name:
|
|
label = name.split("——", 1)[0]
|
|
return f"{cat_name} · {label}"
|
|
return cat_name
|
|
|
|
|
|
# Key of the dedicated script slot pool (kept separate from the Volc pool).
SLOT_KEY: str = "script:slots"
|
|
|
|
|
|
def _get_script_max_slots(cap: int = 10) -> int:
    """Return the size of the script slot pool.

    Reads ``rate_limit_qps`` of the ``volcengine_ark`` platform from
    platform-config.yaml and clamps it to ``[1, cap]``. Falls back to ``cap``
    when the platform is not configured or the config cannot be read.

    Args:
        cap: Upper bound and fallback value for the pool size (expected >= 1).
            LLM inference requests are slow, so the pool should not exceed
            what the server can carry.

    Returns:
        The slot count. It is at least 1 whenever a platform value is used.
    """
    try:
        loader = get_platform_config_loader()
        platform = loader.get_platform("volcengine_ark")
        if platform:
            # A qps below 1 (0, negative, or e.g. 0.5 -> int 0) would create
            # an empty pool in which no script task could ever get a slot.
            return max(1, min(int(platform.rate_limit_qps), cap))
    except Exception as e:
        logger.warning("读取脚本平台 rate_limit 配置失败: %s", e)
    return cap
|
|
|
|
|
|
class ScriptHandler(AsyncHandler):
    """Scheduler handler that runs script-generation tasks.

    Script tasks use their own slot pool (``slot_key``) instead of the Volc
    pool. For each task the handler generates the script via ``ScriptService``,
    charges the user's points only after generation succeeded, and reports the
    outcome as a list of ``StateChange`` records.
    """

    name = "script"
    slot_key = SLOT_KEY
    # Evaluated once, when the class body executes (i.e. at import time).
    max_slots = _get_script_max_slots()

    def __init__(self, service: ScriptService | None = None):
        # The service must be injected; _get_service raises if it is missing.
        self.service = service

    def _get_service(self) -> ScriptService:
        """Return the injected ScriptService.

        Raises:
            RuntimeError: If the handler was constructed without a service.
        """
        if self.service is None:
            raise RuntimeError("ScriptHandler 需要通过构造函数传入 ScriptService 实例")
        return self.service

    @staticmethod
    def _field_changes(task_id: Any, /, **fields: Any) -> list[StateChange]:
        """Build one StateChange per keyword argument, in argument order."""
        return [
            StateChange(task_id=task_id, field_path=field, value=value)
            for field, value in fields.items()
        ]

    @classmethod
    def _failure_changes(cls, task_id: Any, /, **fields: Any) -> list[StateChange]:
        """Build a ``status="failed"`` change followed by the given extra fields."""
        return cls._field_changes(task_id, status="failed", **fields)

    async def tick(
        self, tasks: list[Any], registry: TaskRegistry, slots: SlotManager
    ) -> list[StateChange]:
        """Process every task for which a script slot can be acquired.

        Tasks that do not get a slot are skipped in this tick. An exception
        escaping ``_process_task`` marks only that task as failed, recording the
        error text truncated to 500 characters.

        Returns:
            The state changes produced for all processed tasks.
        """
        changes: list[StateChange] = []

        for task in tasks:
            # Use the class attribute (not the module constant) so a subclass
            # overriding slot_key acquires from its own pool.
            async with slots.acquire_ctx(
                self.slot_key, task.task_id, self.max_slots
            ) as acquired:
                if not acquired:
                    continue
                try:
                    changes.extend(await self._process_task(task, registry, slots))
                except Exception as e:
                    logger.exception("[Script %s] failed", task.task_id)
                    changes.extend(
                        self._failure_changes(task.task_id, error=str(e)[:500])
                    )

        return changes

    async def _charge_points(
        self, task_id: Any, params: dict[str, Any]
    ) -> list[StateChange] | None:
        """Charge the task's ``required_points`` to its ``user_id``.

        Nothing is charged when ``user_id`` is falsy or ``required_points`` is
        missing/None or not positive.

        Returns:
            ``None`` when the charge succeeded or none was needed; otherwise
            the failure state changes to report for the task.
        """
        user_id = params.get("user_id")
        # `or 0` treats an explicit None as "no charge" instead of raising on
        # `None > 0`.
        required_points = params.get("required_points") or 0
        if not (user_id and required_points > 0):
            return None

        try:
            async with AsyncSessionLocal() as db:
                await ps.consume(
                    db,
                    user_id=user_id,
                    points=required_points,
                    source_type="script",
                    source_id=task_id,
                    description="【脚本生成】",
                )
                await db.commit()
        except InsufficientPointsException:
            return self._failure_changes(
                task_id, message="积分不足", error_code="insufficient_points"
            )
        except Exception as e:
            # Keep the traceback and record the error detail like the other
            # failure paths do.
            logger.exception("[ScriptTask %s] 扣费失败: %s", task_id, e)
            return self._failure_changes(
                task_id, message="扣费失败", error=str(e)[:500]
            )
        return None

    async def _process_task(
        self, task: Any, registry: TaskRegistry, slots: SlotManager
    ) -> list[StateChange]:
        """Generate the script for one task, charge points, and report the outcome.

        The task is first marked running via ``registry.update``. If that
        initial update raises, the exception propagates to the caller. Any
        exception raised afterwards is converted into failure state changes.

        Returns:
            State changes describing the final status of the task.
        """
        task_id = task.task_id
        params = task.params or {}
        category = params.get("category", "")
        filename = params.get("filename", "")

        await registry.update(
            task_id,
            status="running",
            progress=10,
            message="分析需求中...",
            completed=0,
            total=1,
        )

        try:
            # NOTE(review): progress stays at 10 for both phases; confirm
            # whether a higher value was intended here.
            await registry.update(task_id, progress=10, message="构思脚本中...")

            service = self._get_service()
            shots = await service.generate_script(
                category=category,
                filename=filename,
            )

            # Real total duration of the shots (shots without a duration are skipped).
            total_duration = sum(s.duration for s in shots if s.duration)
            result_data = {
                "title": _get_category_name(category, filename),
                "scenes": [s.model_dump() for s in shots],
                "total_duration": total_duration,
                "shot_count": len(shots),
            }

            # Charge only after generation succeeded.
            billing_failure = await self._charge_points(task_id, params)
            if billing_failure is not None:
                return billing_failure

            return self._field_changes(
                task_id,
                status="completed",
                progress=100,
                message="脚本生成完成",
                completed=1,
                total=1,
                result=result_data,
            )
        except Exception as exc:
            logger.exception("[ScriptTask %s] Failed", task_id)
            return self._failure_changes(
                task_id, message=str(exc)[:200], error=str(exc)[:500]
            )
|