315 lines
10 KiB
Python
315 lines
10 KiB
Python
"""
|
||
OpenAI Provider 实现
|
||
====================
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import time
|
||
from collections.abc import AsyncIterator
|
||
|
||
from openai import AsyncOpenAI
|
||
|
||
from app.ai.providers.base import (
|
||
GenerationResult,
|
||
LLMProvider,
|
||
ModelHealth,
|
||
ProviderError,
|
||
)
|
||
|
||
|
||
class GenericLLMProvider(LLMProvider):
    """
    Provider for OpenAI and OpenAI-compatible chat-completion APIs.

    Intended targets (as listed by the original author):
    - the official OpenAI API
    - Azure OpenAI
    - any OpenAI-compatible endpoint (e.g. a local vLLM server)
    """

    provider_id = "openai"
    provider_name = "OpenAI"

    # Models advertised when the configuration does not list any.
    DEFAULT_MODELS = [
        "gpt-4-turbo-preview",
        "gpt-4",
        "gpt-3.5-turbo",
        "gpt-3.5-turbo-16k",
    ]

    def __init__(self, api_key: str | None = None, base_url: str | None = None, **kwargs):
        """
        Build the async OpenAI client.

        Args:
            api_key: API key, forwarded to the base class.
            base_url: Endpoint override; the official OpenAI URL is used
                when the stored value is falsy.
            **kwargs: Forwarded to the base class. ``default_model`` is also
                read here and falls back to ``"gpt-3.5-turbo"``.

        Raises:
            ProviderError: If ``self.api_key`` is falsy after base-class
                initialisation.
        """
        super().__init__(api_key, base_url, **kwargs)

        # NOTE(review): self.api_key / self.base_url are presumably assigned
        # by LLMProvider.__init__ -- confirm against the base class.
        if not self.api_key:
            raise ProviderError("OpenAI API Key 未配置", provider_id=self.provider_id)

        self.client = AsyncOpenAI(
            api_key=self.api_key,
            base_url=self.base_url or "https://api.openai.com/v1",
        )
        self.default_model = kwargs.get("default_model", "gpt-3.5-turbo")

    async def generate(
        self,
        prompt: str,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> GenerationResult:
        """
        Non-streaming generation.

        Sends ``prompt`` as a single user message and returns the content of
        the first choice (empty string when the content is ``None``).

        Args:
            prompt: User message text.
            model: Model name; ``self.default_model`` when falsy.
            temperature: Sampling temperature passed to the API.
            max_tokens: Completion token limit passed to the API as-is.
            **kwargs: Extra keyword arguments forwarded to
                ``chat.completions.create``.

        Returns:
            GenerationResult with the text, the usage dict (or ``None``) and
            the model name reported by the API.

        Raises:
            ProviderError: Wrapping any ``Exception`` raised by the call or
                while reading the response.
        """
        try:
            response = await self.client.chat.completions.create(
                model=model or self.default_model,
                messages=[{"role": "user", "content": prompt}],
                temperature=temperature,
                max_tokens=max_tokens,
                stream=False,
                **kwargs,
            )

            return GenerationResult(
                content=response.choices[0].message.content or "",
                usage=response.usage.model_dump() if response.usage else None,
                model=response.model,
            )

        except Exception as e:
            # Chain explicitly so the original failure is the __cause__.
            raise ProviderError(
                f"OpenAI 生成失败: {e!s}", provider_id=self.provider_id, original_error=e
            ) from e

    async def generate_stream(
        self,
        prompt: str,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> AsyncIterator[str]:
        """
        Streaming generation.

        Yields each non-empty ``delta.content`` fragment of the first choice
        as it arrives. Parameters mirror :meth:`generate`.

        Raises:
            ProviderError: Wrapping any ``Exception`` raised while opening or
                consuming the stream.
        """
        try:
            stream = await self.client.chat.completions.create(
                model=model or self.default_model,
                messages=[{"role": "user", "content": prompt}],
                temperature=temperature,
                max_tokens=max_tokens,
                stream=True,
                **kwargs,
            )

            async for chunk in stream:
                # Skip chunks with no choices or with an empty/None delta.
                if chunk.choices and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content

        except Exception as e:
            # Chain explicitly so the original failure is the __cause__.
            raise ProviderError(
                f"OpenAI 流式生成失败: {e!s}", provider_id=self.provider_id, original_error=e
            ) from e

    async def health_check(self, model: str | None = None) -> ModelHealth:
        """
        Probe a model with a tiny chat request.

        Args:
            model: Model to probe; ``self.default_model`` when falsy.

        Returns:
            ModelHealth whose ``response_time`` is the elapsed time in
            milliseconds. Any ``Exception`` is reported through
            ``is_available=False`` and ``last_error`` instead of being raised.
        """
        # perf_counter is monotonic, so the measured latency cannot be skewed
        # by wall-clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        test_model = model or self.default_model

        try:
            # Only success/failure matters; the response body is discarded.
            await self.client.chat.completions.create(
                model=test_model,
                messages=[{"role": "user", "content": "Hi"}],
                max_tokens=5,
                timeout=10,
            )

            response_time = (time.perf_counter() - start_time) * 1000

            return ModelHealth(
                id=test_model,
                name=f"OpenAI {test_model}",
                is_available=True,
                response_time=response_time,
                last_error=None,
            )

        except Exception as e:
            # Deliberately broad: a health probe reports failures, not raises.
            return ModelHealth(
                id=test_model,
                name=f"OpenAI {test_model}",
                is_available=False,
                response_time=(time.perf_counter() - start_time) * 1000,
                last_error=str(e),
            )

    @property
    def available_models(self) -> list[str]:
        """
        Return the configured model list.

        Reads ``"models"`` from ``self.config`` and falls back to
        ``DEFAULT_MODELS`` when the key is absent.
        """
        # NOTE(review): self.config is presumably populated by the base class
        # from **kwargs -- confirm against LLMProvider.
        return self.config.get("models", self.DEFAULT_MODELS)
|
||
|
||
|
||
class MockProvider(LLMProvider):
    """
    Mock provider for tests and demos.

    Never calls a real API; every method answers with canned data.
    """

    provider_id = "mock"
    provider_name = "Mock(测试)"

    @staticmethod
    def _is_polish_request(prompt: str) -> bool:
        """Tell whether the prompt mentions polishing ("润色", or "polish" in any case)."""
        return "润色" in prompt or "polish" in prompt.lower()

    @staticmethod
    def _shots_from_rows(rows) -> list:
        """Expand (type, scene, voiceover, duration) rows into shot dicts with ids counted from 1."""
        return [
            {
                "id": number,
                "type": kind,
                "scene": scene,
                "voiceover": voiceover,
                "duration": duration,
            }
            for number, (kind, scene, voiceover, duration) in enumerate(rows, start=1)
        ]

    def _extract_content_from_prompt(self, prompt: str) -> str:
        """
        Pull the original text out of a polish prompt.

        Returns the stripped text found between the 【原文】 and 【润色要求】
        markers, or a fixed placeholder when the markers are not found.
        """
        import re

        found = re.search(r"【原文】\s*(.+?)\s*【润色要求】", prompt, flags=re.DOTALL)
        if found is None:
            return "优化后的文案"
        return found.group(1).strip()

    async def generate(
        self,
        prompt: str,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> GenerationResult:
        """
        Simulated generation.

        After a 0.5 s artificial delay, a polish prompt gets a polished
        sentence built from the extracted original text; any other prompt
        gets a JSON array of seven canned script shots.
        """
        import asyncio
        import json

        # Simulated latency.
        await asyncio.sleep(0.5)

        model_name = model or "mock-model"

        if self._is_polish_request(prompt):
            source_text = self._extract_content_from_prompt(prompt)
            return GenerationResult(
                content=f"【润色后】{source_text}——这句话说得更有感染力了,适合短视频口播!",
                usage={"prompt_tokens": 50, "completion_tokens": 50, "total_tokens": 100},
                model=model_name,
            )

        # Script request: (type, scene, voiceover, duration) per shot.
        script_rows = [
            (
                "segment",
                "镜头从门外缓缓推入,展示客厅整体布局,自然光从落地窗洒入",
                "大家好,今天给大家讲讲家装验收最容易被忽略的5个细节",
                "5s",
            ),
            (
                "segment",
                "特写墙面,手指划过检查平整度,展示一处细微裂纹",
                "第一,墙面验收。很多人只看颜色,其实平整度和裂纹更重要",
                "8s",
            ),
            (
                "segment",
                "蹲下来拍摄地板接缝处,展示踢脚线与地板的缝隙",
                "第二,地板验收。重点看接缝是否均匀,踢脚线是否贴合",
                "8s",
            ),
            (
                "empty_shot",
                "现代简约风格卫生间,白色瓷砖,柔和灯光,镜头缓慢平移",
                "",
                "3s",
            ),
            (
                "segment",
                "打开水龙头,检查水流和水压,特写地漏排水速度",
                "第三,水电验收。测试所有开关、龙头,检查排水是否顺畅",
                "8s",
            ),
            (
                "segment",
                "开关面板特写,逐一测试灯光开关,展示一处松动的面板",
                "第四,电路验收。每个开关都要试,面板安装是否牢固",
                "7s",
            ),
            (
                "segment",
                "主人公安慰地微笑,竖起大拇指,背景是温馨的客厅",
                "记住这5点,验收不踩坑!关注我,更多家装干货等你",
                "6s",
            ),
        ]

        return GenerationResult(
            content=json.dumps(self._shots_from_rows(script_rows), ensure_ascii=False),
            usage={"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300},
            model=model_name,
        )

    async def generate_stream(
        self,
        prompt: str,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> AsyncIterator[str]:
        """
        Simulated streaming generation.

        Emits either a fixed polished sentence (polish prompts) or a JSON
        array of three canned shots, in 10-character chunks with a 0.05 s
        pause after each chunk.
        """
        import asyncio
        import json

        if self._is_polish_request(prompt):
            payload = "【润色后】优化后的文案,更适合短视频口播!"
        else:
            # Script request: (type, scene, voiceover, duration) per shot.
            stream_rows = [
                (
                    "segment",
                    "主播站在毛坯房里,表情严肃",
                    "装修被坑了8万的业主,昨天来找我哭诉...",
                    "5s",
                ),
                (
                    "segment",
                    "主播指着墙面,手指划过",
                    "第一坑,水电改造!很多人图便宜找游击队",
                    "8s",
                ),
                (
                    "empty_shot",
                    "现代装修施工现场,水电管线整齐排列,4K画质",
                    "看,这就是专业施工",
                    "3s",
                ),
            ]
            payload = json.dumps(self._shots_from_rows(stream_rows), ensure_ascii=False)

        # Emit the payload ten characters at a time.
        step = 10
        offset = 0
        total = len(payload)
        while offset < total:
            yield payload[offset : offset + step]
            # Typewriter effect.
            await asyncio.sleep(0.05)
            offset += step

    async def health_check(self, model: str | None = None) -> ModelHealth:
        """Simulated health check: always available, with a fixed 50.0 response time."""
        model_id = model or "mock-model"
        return ModelHealth(
            id=model_id,
            name="Mock Model",
            is_available=True,
            response_time=50.0,
            last_error=None,
        )

    @property
    def available_models(self) -> list[str]:
        """Return the fixed list of mock model names."""
        return ["mock-model", "mock-gpt-3.5", "mock-gpt-4"]
|