bb08d0f586
主要变更: - 修复 /tasks/script 路由 404(去掉重复 prefix) - 开发模式自动认证兜底(无需登录即可测试流程) - Docker 基础设施独立化(共用 db/redis) - 前端 API 端口改为 8081 - 新增 TTS/语音克隆、视频粗剪、音频混音等智剪功能 - 删除智影专属模块(avatar、model_usage、qiniu 上传等)
191 lines
7.4 KiB
Python
191 lines
7.4 KiB
Python
"""
|
|
Image 任务处理器
|
|
===============
|
|
|
|
管理 Kling 图片生成的提交、轮询、下载。
|
|
不占用 Kling Video/Avatar 槽位,使用独立的图片槽位池。
|
|
"""
|
|
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import aiohttp
|
|
|
|
from app.ai.providers.klingai_provider import KlingAIProvider
|
|
from app.config import get_settings
|
|
from app.core.config_loader import get_config_loader
|
|
from app.scheduler.handlers.base import AsyncHandler
|
|
from app.scheduler.models import StateChange
|
|
from app.scheduler.registry import JobRegistry
|
|
from app.scheduler.slot_manager import SlotManager
|
|
|
|
# Redis-style key of the slot pool dedicated to image generation; kept separate
# from the Kling video/avatar pools so image jobs never compete with them.
SLOT_KEY: str = "kling:image_slots"
# Maximum number of image tasks that may hold a slot at the same time.
MAX_SLOTS: int = 9

logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class ImageHandler(AsyncHandler):
    """Scheduler handler for Kling image-generation jobs.

    On each tick a job is either submitted to Kling (after acquiring a slot
    from the dedicated image slot pool) or, if it already carries a
    ``provider_task_id``, polled. Finished images are downloaded into the
    project's ``images`` directory, and every outcome is reported as
    ``StateChange`` records.
    """

    name = "image"
    slot_key = SLOT_KEY
    max_slots = MAX_SLOTS

    # Upper bound (seconds) for downloading one generated image. Jobs are
    # processed sequentially within a tick, so without this a stalled download
    # would hold up every other job for aiohttp's default 5-minute timeout.
    download_timeout_s: float = 60.0

    # Extensions taken verbatim when they terminate the image URL's path.
    _KNOWN_IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".webp")

    async def _get_provider(self) -> KlingAIProvider:
        """Build a Kling provider from settings and the ``klingai`` platform config.

        Missing keys become empty strings. When no ``klingai`` platform is
        configured, the Beijing API endpoint is used.
        """
        settings = get_settings()
        config_loader = get_config_loader()
        platform = config_loader.get_platform("klingai")
        return KlingAIProvider(
            {
                "access_key": settings.KLINGAI_ACCESS_KEY or "",
                "secret_key": settings.KLINGAI_SECRET_KEY or "",
                "base_url": platform.base_url if platform else "https://api-beijing.klingai.com",
            }
        )

    async def tick(
        self, jobs: list[Any], registry: JobRegistry, slots: SlotManager
    ) -> list[StateChange]:
        """Advance every image job by one step.

        Jobs whose ``params`` already contain ``provider_task_id`` are polled.
        All others are submitted, provided an image slot can be acquired.

        Args:
            jobs: Image jobs handed over by the scheduler; each exposes
                ``job_id`` and ``params`` (a dict or ``None``).
            registry: Job registry (not used by this handler).
            slots: Slot manager backing the image slot pool.

        Returns:
            The state changes to apply to the jobs.
        """
        if not jobs:
            return []

        changes: list[StateChange] = []
        provider = await self._get_provider()
        for job in jobs:
            params = job.params or {}
            if params.get("provider_task_id"):
                changes.extend(await self._poll(job, params, provider, slots))
            else:
                changes.extend(await self._submit(job, params, provider, slots))
        return changes

    async def _poll(
        self,
        job: Any,
        params: dict[str, Any],
        provider: KlingAIProvider,
        slots: SlotManager,
    ) -> list[StateChange]:
        """Poll a submitted Kling task and finalize the job once it has finished."""
        job_id = job.job_id
        try:
            result = await provider.get_image_task(params["provider_task_id"])
            status = result.get("task_status", "unknown")
        except Exception as e:
            # Poll errors are not fatal: the job keeps its slot and is polled
            # again on the next tick.
            logger.error("[Image %s] poll error: %s", job_id, e)
            return []

        if status in ("processing", "submitted"):
            return []

        if status == "failed":
            await slots.release(SLOT_KEY, job_id)
            error_msg = result.get("task_status_msg") or "图片生成失败"
            return self._failure(job_id, error_msg, error_msg)

        # Any other status is treated as success. "task_result" may be present
        # but null, so guard every level of the lookup.
        task_result = result.get("task_result") or {}
        images = task_result.get("images") or []
        image_url = images[0].get("url") if images else None
        if not image_url:
            await slots.release(SLOT_KEY, job_id)
            message = "图片生成成功但未返回图片"
            return self._failure(job_id, message, message)

        project_id = params.get("project_id", "")
        image_type = params.get("image_type", "cover")
        image_dir = (
            Path.home() / "Documents" / "Meijiaka-zj" / "projects" / project_id / "images"
        )
        local_path = image_dir / f"{image_type}_{job_id[:8]}{self._image_ext(image_url)}"

        try:
            # mkdir is inside the try so a filesystem error fails only this job
            # instead of aborting the whole tick.
            image_dir.mkdir(parents=True, exist_ok=True)
            await self._download(image_url, local_path)
        except Exception as e:
            logger.warning("[Image %s] download failed: %s", job_id, e)
            await slots.release(SLOT_KEY, job_id)
            message = f"图片下载失败: {e}"
            return self._failure(job_id, message, message)

        await slots.release(SLOT_KEY, job_id)
        result_data = {
            "project_id": project_id,
            "image_type": image_type,
            "local_path": str(local_path),
            "prompt": params.get("prompt", ""),
        }
        return [
            StateChange(job_id=job_id, field_path="status", value="completed"),
            StateChange(job_id=job_id, field_path="message", value="图片生成完成"),
            StateChange(job_id=job_id, field_path="completed", value=1),
            StateChange(job_id=job_id, field_path="total", value=1),
            StateChange(job_id=job_id, field_path="result", value=result_data),
        ]

    async def _submit(
        self,
        job: Any,
        params: dict[str, Any],
        provider: KlingAIProvider,
        slots: SlotManager,
    ) -> list[StateChange]:
        """Acquire an image slot and submit a new generation task to Kling."""
        job_id = job.job_id
        if not await slots.acquire(SLOT_KEY, job_id, MAX_SLOTS):
            # Pool is full; try again on the next tick.
            return []

        prompt = params.get("prompt", "")
        reference_image = params.get("reference_image")
        try:
            if reference_image:
                result = await provider.generate_image(
                    prompt=prompt,
                    image_url=reference_image,
                    model="kling-v3",
                )
            else:
                # Plain text-to-image. Jobs carrying ``human_id`` also land here:
                # that value is not forwarded to the provider.
                result = await provider.generate_image(
                    prompt=prompt,
                    model="kling-v3",
                    aspect_ratio="9:16",
                )
            provider_task_id = result.get("task_id")
            if not provider_task_id:
                raise ValueError("未返回任务ID")
        except Exception as e:
            logger.warning("[Image %s] submit failed: %s", job_id, e)
            await slots.release(SLOT_KEY, job_id)
            return self._failure(job_id, str(e)[:200], str(e)[:500])

        # Record the task id so the next tick polls instead of resubmitting.
        params["provider_task_id"] = provider_task_id
        return [
            StateChange(job_id=job_id, field_path="params", value=params),
            StateChange(job_id=job_id, field_path="message", value="图片任务已提交"),
        ]

    async def _download(self, url: str, dest: Path) -> None:
        """Fetch ``url`` and write the body to ``dest``; raise on HTTP errors or timeout."""
        timeout = aiohttp.ClientTimeout(total=self.download_timeout_s)
        async with aiohttp.ClientSession(timeout=timeout) as session, session.get(url) as resp:
            resp.raise_for_status()
            data = await resp.read()
        dest.write_bytes(data)

    @classmethod
    def _image_ext(cls, url: str) -> str:
        """Choose a file extension for an image URL.

        Uses the URL path's suffix (ignoring any query string) when it is a
        known image type. Otherwise returns ``.jpg`` if ".jpg" appears anywhere
        in the URL, else ``.png``.
        """
        import posixpath
        from urllib.parse import urlparse

        suffix = posixpath.splitext(urlparse(url).path)[1].lower()
        if suffix in cls._KNOWN_IMAGE_EXTS:
            return suffix
        return ".jpg" if ".jpg" in url else ".png"

    @staticmethod
    def _failure(job_id: str, message: str, error: str) -> list[StateChange]:
        """Build the state changes that mark ``job_id`` as failed."""
        return [
            StateChange(job_id=job_id, field_path="status", value="failed"),
            StateChange(job_id=job_id, field_path="message", value=message),
            StateChange(job_id=job_id, field_path="error", value=error),
        ]
|