Files
meijiaka-zy/python-api/app/ai/providers/generic_llm_provider.py
T

315 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
OpenAI Provider 实现
====================
"""
from __future__ import annotations
import time
from collections.abc import AsyncIterator
from openai import AsyncOpenAI
from app.ai.providers.base import (
GenerationResult,
LLMProvider,
ModelHealth,
ProviderError,
)
class GenericLLMProvider(LLMProvider):
    """
    Provider for the OpenAI API and OpenAI-compatible endpoints.

    Supports:
    - The official OpenAI API
    - Azure OpenAI
    - Any OpenAI-compatible endpoint (e.g. a local vLLM server)

    Constructor keyword arguments are forwarded to ``LLMProvider``; a
    ``default_model`` keyword selects the model used when a call does not
    name one (falls back to ``"gpt-3.5-turbo"``).
    """

    provider_id = "openai"
    provider_name = "OpenAI"

    # Fallback model list, used when the provider config has no "models" entry.
    DEFAULT_MODELS = [
        "gpt-4-turbo-preview",
        "gpt-4",
        "gpt-3.5-turbo",
        "gpt-3.5-turbo-16k",
    ]

    def __init__(self, api_key: str | None = None, base_url: str | None = None, **kwargs):
        """Create the async OpenAI client.

        ``self.api_key`` / ``self.base_url`` are presumably populated by
        ``LLMProvider.__init__`` from the arguments (verify against base class).
        When no base URL is set, the official OpenAI endpoint is used.

        Raises:
            ProviderError: if no API key is configured.
        """
        super().__init__(api_key, base_url, **kwargs)
        if not self.api_key:
            raise ProviderError("OpenAI API Key 未配置", provider_id=self.provider_id)
        self.client = AsyncOpenAI(
            api_key=self.api_key,
            base_url=self.base_url or "https://api.openai.com/v1",
        )
        self.default_model = kwargs.get("default_model", "gpt-3.5-turbo")

    def _chat_params(
        self,
        prompt: str,
        model: str | None,
        temperature: float,
        max_tokens: int | None,
        stream: bool,
    ) -> dict:
        """Build the keyword arguments shared by every chat-completion request."""
        params = {
            "model": model or self.default_model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
            "stream": stream,
        }
        # Leave max_tokens out entirely when unset rather than sending an
        # explicit JSON null, which stricter compatible servers may reject.
        if max_tokens is not None:
            params["max_tokens"] = max_tokens
        return params

    async def generate(
        self,
        prompt: str,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> GenerationResult:
        """Generate a complete (non-streaming) response for a single user prompt.

        Extra ``kwargs`` are forwarded to ``chat.completions.create``; passing a
        key that this method already sets (e.g. ``messages``) is an error.

        Returns:
            GenerationResult with the first choice's text ("" if the model
            returned no content), the usage dict (or None) and the model name
            reported by the API.

        Raises:
            ProviderError: wrapping any failure of the request or of reading
                the response; the original exception is chained as the cause.
        """
        try:
            response = await self.client.chat.completions.create(
                **self._chat_params(prompt, model, temperature, max_tokens, stream=False),
                **kwargs,
            )
            return GenerationResult(
                content=response.choices[0].message.content or "",
                usage=response.usage.model_dump() if response.usage else None,
                model=response.model,
            )
        except Exception as e:
            raise ProviderError(
                f"OpenAI 生成失败: {str(e)}", provider_id=self.provider_id, original_error=e
            ) from e

    async def generate_stream(
        self,
        prompt: str,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> AsyncIterator[str]:
        """Stream a response, yielding each non-empty content delta as it arrives.

        Extra ``kwargs`` are forwarded to ``chat.completions.create``.

        Raises:
            ProviderError: wrapping any failure while opening or consuming the
                stream; the original exception is chained as the cause.
        """
        try:
            stream = await self.client.chat.completions.create(
                **self._chat_params(prompt, model, temperature, max_tokens, stream=True),
                **kwargs,
            )
            async for chunk in stream:
                # Some chunks (e.g. role-only or final chunks) carry no text.
                if chunk.choices and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content
        except Exception as e:
            raise ProviderError(
                f"OpenAI 流式生成失败: {str(e)}", provider_id=self.provider_id, original_error=e
            ) from e

    async def health_check(self, model: str | None = None) -> ModelHealth:
        """Probe a model with a tiny request and report availability and latency.

        Never raises for request failures: errors are reported through
        ``is_available=False`` and ``last_error``. ``response_time`` is in
        milliseconds, measured with a monotonic clock.
        """
        test_model = model or self.default_model
        display_name = f"OpenAI {test_model}"
        started = time.perf_counter()
        try:
            await self.client.chat.completions.create(
                model=test_model,
                messages=[{"role": "user", "content": "Hi"}],
                max_tokens=5,
                timeout=10,
            )
        except Exception as e:
            return ModelHealth(
                id=test_model,
                name=display_name,
                is_available=False,
                response_time=(time.perf_counter() - started) * 1000,
                last_error=str(e),
            )
        return ModelHealth(
            id=test_model,
            name=display_name,
            is_available=True,
            response_time=(time.perf_counter() - started) * 1000,
            last_error=None,
        )

    @property
    def available_models(self) -> list[str]:
        """Return the configured model list, or DEFAULT_MODELS when none is configured.

        A fresh list is returned so callers cannot mutate the shared class-level
        default or the stored config.
        """
        models = self.config.get("models")
        if models is None:
            models = self.DEFAULT_MODELS
        return list(models)
class MockProvider(LLMProvider):
    """
    Mock provider for tests and demos.

    Never calls a real API. When the prompt looks like a polish request
    (contains "润色", or "polish" case-insensitively) a canned polished text is
    returned; otherwise a canned JSON array of script shots is returned.
    """

    provider_id = "mock"
    provider_name = "Mock(测试)"

    # Canned script returned by ``generate`` for non-polish prompts.
    _GENERATE_SHOTS = [
        {
            "id": 1,
            "type": "segment",
            "scene": "镜头从门外缓缓推入,展示客厅整体布局,自然光从落地窗洒入",
            "voiceover": "大家好,今天给大家讲讲家装验收最容易被忽略的5个细节",
            "duration": "5s",
        },
        {
            "id": 2,
            "type": "segment",
            "scene": "特写墙面,手指划过检查平整度,展示一处细微裂纹",
            "voiceover": "第一,墙面验收。很多人只看颜色,其实平整度和裂纹更重要",
            "duration": "8s",
        },
        {
            "id": 3,
            "type": "segment",
            "scene": "蹲下来拍摄地板接缝处,展示踢脚线与地板的缝隙",
            "voiceover": "第二,地板验收。重点看接缝是否均匀,踢脚线是否贴合",
            "duration": "8s",
        },
        {
            "id": 4,
            "type": "empty_shot",
            "scene": "现代简约风格卫生间,白色瓷砖,柔和灯光,镜头缓慢平移",
            "voiceover": "",
            "duration": "3s",
        },
        {
            "id": 5,
            "type": "segment",
            "scene": "打开水龙头,检查水流和水压,特写地漏排水速度",
            "voiceover": "第三,水电验收。测试所有开关、龙头,检查排水是否顺畅",
            "duration": "8s",
        },
        {
            "id": 6,
            "type": "segment",
            "scene": "开关面板特写,逐一测试灯光开关,展示一处松动的面板",
            "voiceover": "第四,电路验收。每个开关都要试,面板安装是否牢固",
            "duration": "7s",
        },
        {
            "id": 7,
            "type": "segment",
            "scene": "主人公安慰地微笑,竖起大拇指,背景是温馨的客厅",
            "voiceover": "记住这5点,验收不踩坑!关注我,更多家装干货等你",
            "duration": "6s",
        },
    ]

    # Shorter canned script streamed by ``generate_stream`` for non-polish prompts.
    _STREAM_SHOTS = [
        {
            "id": 1,
            "type": "segment",
            "scene": "主播站在毛坯房里,表情严肃",
            "voiceover": "装修被坑了8万的业主,昨天来找我哭诉...",
            "duration": "5s",
        },
        {
            "id": 2,
            "type": "segment",
            "scene": "主播指着墙面,手指划过",
            "voiceover": "第一坑,水电改造!很多人图便宜找游击队",
            "duration": "8s",
        },
        {
            "id": 3,
            "type": "empty_shot",
            "scene": "现代装修施工现场,水电管线整齐排列,4K画质",
            "voiceover": "看,这就是专业施工",
            "duration": "3s",
        },
    ]

    @staticmethod
    def _is_polish_request(prompt: str) -> bool:
        """Return True when the prompt looks like a text-polishing request."""
        return "润色" in prompt or "polish" in prompt.lower()

    def _extract_content_from_prompt(self, prompt: str) -> str:
        """Return the stripped text between the 【原文】 and 【润色要求】 markers.

        Falls back to a fixed placeholder when the markers are not found.
        """
        import re

        found = re.search(r"【原文】\s*(.+?)\s*【润色要求】", prompt, re.DOTALL)
        return found.group(1).strip() if found else "优化后的文案"

    async def generate(
        self,
        prompt: str,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> GenerationResult:
        """Simulate a completion after a 0.5 s delay.

        Polish prompts get the extracted original text wrapped in a canned
        "polished" sentence; all other prompts get the canned script as JSON.
        ``temperature``, ``max_tokens`` and ``kwargs`` are ignored.
        """
        import asyncio
        import json

        await asyncio.sleep(0.5)  # simulate request latency
        model_name = model or "mock-model"

        if not self._is_polish_request(prompt):
            return GenerationResult(
                content=json.dumps(self._GENERATE_SHOTS, ensure_ascii=False),
                usage={"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300},
                model=model_name,
            )

        source_text = self._extract_content_from_prompt(prompt)
        return GenerationResult(
            content=f"【润色后】{source_text}——这句话说得更有感染力了,适合短视频口播!",
            usage={"prompt_tokens": 50, "completion_tokens": 50, "total_tokens": 100},
            model=model_name,
        )

    async def generate_stream(
        self,
        prompt: str,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> AsyncIterator[str]:
        """Simulate streaming: emit a canned response in 10-character chunks.

        Polish prompts stream a fixed polished sentence; all other prompts
        stream the short canned script as JSON. Each chunk is followed by a
        0.05 s pause to mimic a typewriter effect.
        """
        import asyncio
        import json

        if self._is_polish_request(prompt):
            text = "【润色后】优化后的文案,更适合短视频口播!"
        else:
            text = json.dumps(self._STREAM_SHOTS, ensure_ascii=False)

        step = 10  # characters per chunk
        offset = 0
        while offset < len(text):
            yield text[offset : offset + step]
            offset += step
            await asyncio.sleep(0.05)  # typewriter pacing between chunks

    async def health_check(self, model: str | None = None) -> ModelHealth:
        """Report the mock model as always available with a fixed 50 ms latency."""
        model_id = model or "mock-model"
        return ModelHealth(
            id=model_id,
            name="Mock Model",
            is_available=True,
            response_time=50.0,
            last_error=None,
        )

    @property
    def available_models(self) -> list[str]:
        """Return the fixed list of mock model identifiers (a new list each call)."""
        models = ["mock-model", "mock-gpt-3.5", "mock-gpt-4"]
        return models