refactor(script): remove sync endpoint, add thread-pool & timeout
- Remove the unused POST /script/generate sync endpoint and the frontend generate() helper
- Move JSON parsing/validation to asyncio.to_thread() to avoid blocking the event loop
- Add a 60s asyncio.timeout() around the entire script generation pipeline
- Migrate volcengine_provider to the unified AsyncArk client
This commit is contained in:
@@ -35,21 +35,13 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# 尝试导入火山方舟 SDK
|
||||
try:
|
||||
from volcenginesdkarkruntime import Ark
|
||||
from volcenginesdkarkruntime import AsyncArk
|
||||
|
||||
VOLCENGINE_SDK_AVAILABLE = True
|
||||
except ImportError:
|
||||
VOLCENGINE_SDK_AVAILABLE = False
|
||||
logger.warning("火山方舟 SDK 未安装,请运行: pip install 'volcengine-python-sdk[ark]'")
|
||||
|
||||
try:
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
ASYNC_OPENAI_AVAILABLE = True
|
||||
except ImportError:
|
||||
ASYNC_OPENAI_AVAILABLE = False
|
||||
logger.warning("OpenAI SDK 未安装,流式生成将不可用")
|
||||
|
||||
|
||||
class VolcengineProvider(LLMProvider):
|
||||
"""
|
||||
@@ -137,16 +129,10 @@ class VolcengineProvider(LLMProvider):
|
||||
self.default_model = "doubao-seed-2-0-lite-260215"
|
||||
|
||||
self.client = self._create_client()
|
||||
import httpx
|
||||
self.async_client = AsyncOpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url or self.DEFAULT_BASE_URL,
|
||||
http_client=httpx.AsyncClient(headers={"Accept-Encoding": "identity"}),
|
||||
)
|
||||
|
||||
def _create_client(self) -> Ark:
|
||||
"""创建火山方舟客户端"""
|
||||
return Ark(
|
||||
def _create_client(self) -> AsyncArk:
|
||||
"""创建火山方舟异步客户端"""
|
||||
return AsyncArk(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url or self.DEFAULT_BASE_URL,
|
||||
timeout=self.timeout,
|
||||
@@ -203,7 +189,7 @@ class VolcengineProvider(LLMProvider):
|
||||
request_params["response_format"] = {"type": "json_object"}
|
||||
|
||||
# 调用 API
|
||||
completion = self.client.chat.completions.create(**request_params)
|
||||
completion = await self.client.chat.completions.create(**request_params)
|
||||
|
||||
# 解析结果
|
||||
content = completion.choices[0].message.content or ""
|
||||
@@ -238,8 +224,6 @@ class VolcengineProvider(LLMProvider):
|
||||
"""
|
||||
流式生成文本
|
||||
|
||||
使用 AsyncOpenAI 客户端避免 Ark SDK 同步流式缓冲问题。
|
||||
|
||||
Args:
|
||||
prompt: 用户提示词
|
||||
model: 模型名称
|
||||
@@ -271,7 +255,7 @@ class VolcengineProvider(LLMProvider):
|
||||
if "reasoning_effort" in kwargs:
|
||||
request_params["extra_body"] = {"reasoning_effort": kwargs["reasoning_effort"]}
|
||||
|
||||
stream = await self.async_client.chat.completions.create(**request_params)
|
||||
stream = await self.client.chat.completions.create(**request_params)
|
||||
|
||||
async for chunk in stream:
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
@@ -338,7 +322,7 @@ class VolcengineProvider(LLMProvider):
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
|
||||
stream = await self.async_client.chat.completions.create(**request_params)
|
||||
stream = await self.client.chat.completions.create(**request_params)
|
||||
|
||||
async for chunk in stream:
|
||||
if chunk.usage:
|
||||
@@ -387,7 +371,7 @@ class VolcengineProvider(LLMProvider):
|
||||
# 图片生成需要单独的图片模型,不在当前配置中
|
||||
# 如需使用,请在模型广场开通 doubao-seed-1.6 并配置
|
||||
image_model = model or "doubao-seed-1.6-flash-250828"
|
||||
response = self.client.images.generate(
|
||||
response = await self.client.images.generate(
|
||||
model=image_model, prompt=prompt, size=size, **kwargs
|
||||
)
|
||||
|
||||
@@ -424,7 +408,7 @@ class VolcengineProvider(LLMProvider):
|
||||
dict: 包含向量化结果
|
||||
"""
|
||||
try:
|
||||
response = self.client.embeddings.create(
|
||||
response = await self.client.embeddings.create(
|
||||
model=model or "doubao-embedding-1.5", input=texts, **kwargs
|
||||
)
|
||||
|
||||
@@ -457,7 +441,7 @@ class VolcengineProvider(LLMProvider):
|
||||
test_model = model or self.default_model
|
||||
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
response = await self.client.chat.completions.create(
|
||||
model=test_model,
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
max_tokens=5,
|
||||
|
||||
@@ -49,29 +49,6 @@ async def get_categories():
|
||||
)
|
||||
|
||||
|
||||
@router.post("/generate", response_model=ApiResponse[list[ScriptShot]])
|
||||
async def generate_script(request: GenerateScriptRequest):
|
||||
"""
|
||||
同步生成脚本
|
||||
|
||||
直接返回生成的分镜列表,适合快速预览。
|
||||
"""
|
||||
service = get_script_service()
|
||||
|
||||
shots = await service.generate_script(
|
||||
category=request.category,
|
||||
subcategory=request.subcategory,
|
||||
duration=request.duration,
|
||||
script_type=request.script_type,
|
||||
model=request.model,
|
||||
)
|
||||
|
||||
return success_response(
|
||||
data=shots,
|
||||
message=f"成功生成 {len(shots)} 个分镜",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/generate/stream")
|
||||
async def generate_script_stream(request: Request, data: GenerateScriptRequest):
|
||||
"""
|
||||
|
||||
@@ -141,89 +141,100 @@ class ScriptService:
|
||||
model_router = await get_model_router()
|
||||
|
||||
try:
|
||||
# 加载 Prompt
|
||||
system_prompt = load_system_prompt(category, subcategory)
|
||||
if not system_prompt:
|
||||
yield ScriptGenerationEvent(
|
||||
type="error",
|
||||
message=f"未找到提示词: category={category}, subcategory={subcategory}",
|
||||
)
|
||||
return
|
||||
|
||||
user_prompt = load_script_user_prompt(
|
||||
topic=f"{category}/{subcategory}",
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
yield ScriptGenerationEvent(
|
||||
type="start",
|
||||
message="准备生成...",
|
||||
)
|
||||
|
||||
full_content = ""
|
||||
has_shown_generating = False
|
||||
|
||||
async for chunk in model_router.generate_stream_with_progress(
|
||||
prompt=user_prompt,
|
||||
system_prompt=system_prompt,
|
||||
model_id=model,
|
||||
task_type="script",
|
||||
temperature=0.7,
|
||||
response_format="json_object",
|
||||
):
|
||||
if chunk["type"] == "chunk":
|
||||
chunk_content = chunk.get("content", "")
|
||||
if not chunk_content:
|
||||
continue
|
||||
full_content += chunk_content
|
||||
|
||||
if not has_shown_generating:
|
||||
yield ScriptGenerationEvent(
|
||||
type="generating",
|
||||
message="正在创作脚本...",
|
||||
)
|
||||
has_shown_generating = True
|
||||
|
||||
if not full_content or not full_content.strip():
|
||||
yield ScriptGenerationEvent(
|
||||
type="error",
|
||||
message="AI 返回内容为空,请检查模型配置或重试",
|
||||
)
|
||||
return
|
||||
|
||||
success, parsed_data, error_msg = safe_parse_ai_json_response(full_content)
|
||||
|
||||
if not success:
|
||||
yield ScriptGenerationEvent(
|
||||
type="error",
|
||||
message=f"脚本解析失败: {error_msg or '无法解析 AI 返回的内容'}",
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
shots_data = validate_and_normalize_shots(parsed_data)
|
||||
|
||||
if not shots_data:
|
||||
async with asyncio.timeout(60):
|
||||
# 加载 Prompt
|
||||
system_prompt = load_system_prompt(category, subcategory)
|
||||
if not system_prompt:
|
||||
yield ScriptGenerationEvent(
|
||||
type="error",
|
||||
message="AI 返回的分镜数据为空或格式不正确",
|
||||
message=f"未找到提示词: category={category}, subcategory={subcategory}",
|
||||
)
|
||||
return
|
||||
|
||||
shots = [ScriptShot(**shot) for shot in shots_data]
|
||||
|
||||
yield ScriptGenerationEvent(
|
||||
type="complete",
|
||||
message="脚本生成成功",
|
||||
result=shots,
|
||||
user_prompt = load_script_user_prompt(
|
||||
topic=f"{category}/{subcategory}",
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
yield ScriptGenerationEvent(
|
||||
type="error",
|
||||
message=f"分镜数据处理失败: {str(e)}",
|
||||
type="start",
|
||||
message="准备生成...",
|
||||
)
|
||||
|
||||
full_content = ""
|
||||
has_shown_generating = False
|
||||
|
||||
async for chunk in model_router.generate_stream_with_progress(
|
||||
prompt=user_prompt,
|
||||
system_prompt=system_prompt,
|
||||
model_id=model,
|
||||
task_type="script",
|
||||
temperature=0.7,
|
||||
response_format="json_object",
|
||||
):
|
||||
if chunk["type"] == "chunk":
|
||||
chunk_content = chunk.get("content", "")
|
||||
if not chunk_content:
|
||||
continue
|
||||
full_content += chunk_content
|
||||
|
||||
if not has_shown_generating:
|
||||
yield ScriptGenerationEvent(
|
||||
type="generating",
|
||||
message="正在创作脚本...",
|
||||
)
|
||||
has_shown_generating = True
|
||||
|
||||
if not full_content or not full_content.strip():
|
||||
yield ScriptGenerationEvent(
|
||||
type="error",
|
||||
message="AI 返回内容为空,请检查模型配置或重试",
|
||||
)
|
||||
return
|
||||
|
||||
success, parsed_data, error_msg = await asyncio.to_thread(
|
||||
safe_parse_ai_json_response, full_content
|
||||
)
|
||||
|
||||
if not success:
|
||||
yield ScriptGenerationEvent(
|
||||
type="error",
|
||||
message=f"脚本解析失败: {error_msg or '无法解析 AI 返回的内容'}",
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
shots_data = await asyncio.to_thread(
|
||||
validate_and_normalize_shots, parsed_data
|
||||
)
|
||||
|
||||
if not shots_data:
|
||||
yield ScriptGenerationEvent(
|
||||
type="error",
|
||||
message="AI 返回的分镜数据为空或格式不正确",
|
||||
)
|
||||
return
|
||||
|
||||
shots = [ScriptShot(**shot) for shot in shots_data]
|
||||
|
||||
yield ScriptGenerationEvent(
|
||||
type="complete",
|
||||
message="脚本生成成功",
|
||||
result=shots,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
yield ScriptGenerationEvent(
|
||||
type="error",
|
||||
message=f"分镜数据处理失败: {str(e)}",
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
yield ScriptGenerationEvent(
|
||||
type="error",
|
||||
message="脚本生成超时,请重试",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
yield ScriptGenerationEvent(
|
||||
type="error",
|
||||
|
||||
@@ -112,19 +112,6 @@ export const scriptApi = {
|
||||
return client.get<CategoryItem[]>('/script/categories');
|
||||
},
|
||||
|
||||
/**
|
||||
* 生成脚本内容(同步)
|
||||
* POST /script/generate
|
||||
*/
|
||||
generate: async (params: GenerateScriptParams): Promise<ScriptShot[]> => {
|
||||
return client.post<ScriptShot[]>('/script/generate', {
|
||||
category: params.category,
|
||||
subcategory: params.subcategory,
|
||||
duration: params.duration,
|
||||
scriptType: params.type,
|
||||
});
|
||||
},
|
||||
|
||||
/**
|
||||
* 流式生成脚本(SSE)
|
||||
* POST /script/generate/stream
|
||||
|
||||
Reference in New Issue
Block a user