332 lines
13 KiB
Diff
332 lines
13 KiB
Diff
diff --git a/python-api/app/services/point_service.py b/python-api/app/services/point_service.py
|
|
index 0e5d79e..709f78d 100644
|
|
--- a/python-api/app/services/point_service.py
|
|
+++ b/python-api/app/services/point_service.py
|
|
@@ -25,7 +25,7 @@ import logging
|
|
import math
|
|
from datetime import UTC, datetime, timedelta
|
|
from pathlib import Path
|
|
-from typing import TYPE_CHECKING
|
|
+from typing import TYPE_CHECKING, Any
|
|
|
|
import yaml
|
|
from sqlalchemy import select
|
|
@@ -33,6 +33,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
+from app.core.exceptions import InsufficientPointsException
|
|
from app.models.point_batch import PointBatch
|
|
from app.models.point_transaction import PointTransaction
|
|
from app.models.user_point import UserPoint
|
|
@@ -46,11 +47,11 @@ if TYPE_CHECKING:
|
|
_CONFIG_PATH = Path(__file__).resolve().parents[2] / "config" / "points-config.yaml"
|
|
|
|
|
|
-def _load_points_config() -> dict:
|
|
+def _load_points_config() -> dict[str, Any]:
|
|
"""加载积分计费配置。服务启动时读取一次,后续内存中使用。"""
|
|
if not _CONFIG_PATH.exists():
|
|
raise FileNotFoundError(f"积分配置文件不存在: {_CONFIG_PATH}")
|
|
- with open(_CONFIG_PATH, "r", encoding="utf-8") as f:
|
|
+ with open(_CONFIG_PATH, encoding="utf-8") as f:
|
|
cfg = yaml.safe_load(f) or {}
|
|
# 合并为统一的查询字典:source_type -> {"mode": "fixed|duration|free", ...}
|
|
merged: dict[str, dict] = {}
|
|
@@ -65,18 +66,22 @@ def _load_points_config() -> dict:
|
|
return merged
|
|
|
|
|
|
-POINTS_CONFIG: dict[str, dict] = _load_points_config()
|
|
+POINTS_CONFIG: dict[str, Any] = _load_points_config()
|
|
|
|
|
|
def get_recharge_options() -> list[dict]:
|
|
"""获取充值档位配置(由后端控制,支持积分赠送)"""
|
|
- return POINTS_CONFIG.get("_recharge_options", [])
|
|
+ options = POINTS_CONFIG.get("_recharge_options", [])
|
|
+ if isinstance(options, list):
|
|
+ return options
|
|
+ return []
|
|
|
|
|
|
def get_chargeable_source_types() -> list[str]:
|
|
"""获取所有需要扣费的业务类型列表(排除免费业务)"""
|
|
return [
|
|
- key for key, cfg in POINTS_CONFIG.items()
|
|
+ key
|
|
+ for key, cfg in POINTS_CONFIG.items()
|
|
if not key.startswith("_") and cfg.get("mode") != "free"
|
|
]
|
|
|
|
@@ -163,11 +168,10 @@ def _estimate_max_cost(source_type: str, param: dict | None = None) -> int:
|
|
|
|
# ── 余额查询 ──────────────────────────────────────────
|
|
|
|
+
|
|
async def get_user_balance(db: AsyncSession, user_id: UUID | str) -> dict:
|
|
"""获取用户积分余额快照(实时计算,排除已过期批次)。"""
|
|
- result = await db.execute(
|
|
- select(UserPoint).where(UserPoint.user_id == user_id)
|
|
- )
|
|
+ result = await db.execute(select(UserPoint).where(UserPoint.user_id == user_id))
|
|
up = result.scalar_one_or_none()
|
|
|
|
if not up:
|
|
@@ -182,8 +186,7 @@ async def get_user_balance(db: AsyncSession, user_id: UUID | str) -> dict:
|
|
from sqlalchemy import func as _func
|
|
|
|
available_result = await db.execute(
|
|
- select(_func.coalesce(_func.sum(PointBatch.remaining), 0))
|
|
- .where(
|
|
+ select(_func.coalesce(_func.sum(PointBatch.remaining), 0)).where(
|
|
PointBatch.user_id == user_id,
|
|
PointBatch.remaining > 0,
|
|
PointBatch.expired_at > _now(),
|
|
@@ -221,6 +224,7 @@ async def check_balance(
|
|
|
|
# ── 充值 ──────────────────────────────────────────────
|
|
|
|
+
|
|
async def recharge(
|
|
db: AsyncSession,
|
|
*,
|
|
@@ -247,8 +251,7 @@ async def recharge(
|
|
# 幂等保护:同一笔订单(order_id)只能充值一次
|
|
if order_id:
|
|
existing_result = await db.execute(
|
|
- select(PointTransaction)
|
|
- .where(
|
|
+ select(PointTransaction).where(
|
|
PointTransaction.source_id == str(order_id),
|
|
PointTransaction.type == "recharge",
|
|
)
|
|
@@ -259,9 +262,7 @@ async def recharge(
|
|
return existing_tx
|
|
|
|
# 1. 获取或创建用户积分账户
|
|
- result = await db.execute(
|
|
- select(UserPoint).where(UserPoint.user_id == user_id)
|
|
- )
|
|
+ result = await db.execute(select(UserPoint).where(UserPoint.user_id == user_id))
|
|
up = result.scalar_one_or_none()
|
|
|
|
if not up:
|
|
@@ -353,7 +354,7 @@ async def consume(
|
|
直接扣费(后置计费)。
|
|
|
|
业务执行成功后调用,按实际消耗直接扣除余额。
|
|
- 默认不允许欠费(余额不足时抛出 ValueError)。
|
|
+ 默认不允许欠费(余额不足时抛出 InsufficientPointsException)。
|
|
Scheduler 后置扣费等场景可设置 allow_negative=True,允许余额变负。
|
|
|
|
:param points: 实际消耗积分(正整数)
|
|
@@ -383,9 +384,7 @@ async def consume(
|
|
|
|
# 2. 获取用户积分账户(加锁)
|
|
result = await db.execute(
|
|
- select(UserPoint)
|
|
- .where(UserPoint.user_id == user_id)
|
|
- .with_for_update()
|
|
+ select(UserPoint).where(UserPoint.user_id == user_id).with_for_update()
|
|
)
|
|
up = result.scalar_one_or_none()
|
|
|
|
@@ -404,7 +403,7 @@ async def consume(
|
|
# 3. 余额检查:用实时可用余额(未过期批次 remaining 总和),避免 expire_batches 延迟导致超扣
|
|
available = sum(b.remaining for b in batches)
|
|
if not allow_negative and available < points:
|
|
- raise ValueError(f"积分不足,当前可用余额 {available},需要 {points} 积分")
|
|
+ raise InsufficientPointsException(f"积分不足,当前可用余额 {available},需要 {points} 积分")
|
|
|
|
remaining_to_deduct = points
|
|
for batch in batches:
|
|
@@ -440,6 +439,7 @@ async def consume(
|
|
|
|
# ── 过期回收 ──────────────────────────────────────────
|
|
|
|
+
|
|
async def expire_batches(db: AsyncSession) -> int:
|
|
"""
|
|
回收过期积分批次。返回过期积分总数。
|
|
@@ -468,9 +468,7 @@ async def expire_batches(db: AsyncSession) -> int:
|
|
|
|
# 获取用户积分账户(加锁)
|
|
up_result = await db.execute(
|
|
- select(UserPoint)
|
|
- .where(UserPoint.user_id == batch.user_id)
|
|
- .with_for_update()
|
|
+ select(UserPoint).where(UserPoint.user_id == batch.user_id).with_for_update()
|
|
)
|
|
up = up_result.scalar_one_or_none()
|
|
if not up:
|
|
diff --git a/python-api/app/services/script_service.py b/python-api/app/services/script_service.py
|
|
index 49aa4b1..60f58d5 100644
|
|
--- a/python-api/app/services/script_service.py
|
|
+++ b/python-api/app/services/script_service.py
|
|
@@ -7,9 +7,16 @@ import asyncio
|
|
import logging
|
|
import time
|
|
from pathlib import Path
|
|
+from typing import Any
|
|
|
|
from app.ai.model_router import get_model_router
|
|
from app.ai.prompts import load_prompt_file, load_script_user_prompt
|
|
+from app.core.exceptions import (
|
|
+ AIEmptyResponseException,
|
|
+ AIParseErrorException,
|
|
+ AITimeoutException,
|
|
+ PromptNotFoundException,
|
|
+)
|
|
from app.schemas.script import ScriptShot
|
|
from app.services.ai_response_utils import (
|
|
safe_parse_ai_json_response,
|
|
@@ -22,12 +29,9 @@ logger = logging.getLogger(__name__)
|
|
class ScriptService:
|
|
"""脚本生成服务"""
|
|
|
|
-
|
|
def __init__(self):
|
|
self.prompts_dir = Path(__file__).parent.parent / "ai" / "prompts"
|
|
|
|
-
|
|
-
|
|
def _load_prompt(self, name: str) -> str:
|
|
"""加载 Prompt 模板"""
|
|
prompt_file = self.prompts_dir / f"{name}.txt"
|
|
@@ -58,7 +62,7 @@ class ScriptService:
|
|
# 加载 Prompt
|
|
system_prompt = load_prompt_file(category, filename)
|
|
if not system_prompt:
|
|
- raise ValueError(f"未找到提示词: category={category}, filename={filename}")
|
|
+ raise PromptNotFoundException(f"未找到提示词: category={category}, filename={filename}")
|
|
|
|
# 用户提示词
|
|
user_prompt = load_script_user_prompt(
|
|
@@ -75,24 +79,26 @@ class ScriptService:
|
|
)
|
|
|
|
if not result.content or not result.content.strip():
|
|
- raise ValueError("AI 返回内容为空,请检查模型配置或重试")
|
|
+ raise AIEmptyResponseException("AI 返回内容为空,请检查模型配置或重试")
|
|
|
|
success, parsed_data, error_msg = safe_parse_ai_json_response(result.content)
|
|
|
|
if not success:
|
|
- raise ValueError(error_msg or "AI 返回格式错误,无法解析为 JSON")
|
|
+ raise AIParseErrorException(error_msg or "AI 返回格式错误,无法解析为 JSON")
|
|
|
|
try:
|
|
shots_data = validate_and_normalize_shots(parsed_data)
|
|
|
|
if not shots_data:
|
|
- raise ValueError("AI 返回的分镜数据为空或格式不正确")
|
|
+ raise AIEmptyResponseException("AI 返回的分镜数据为空或格式不正确")
|
|
|
|
shots = [ScriptShot(**shot) for shot in shots_data]
|
|
return shots
|
|
|
|
+ except (AIEmptyResponseException, AIParseErrorException):
|
|
+ raise
|
|
except Exception as e:
|
|
- raise ValueError(f"分镜数据处理失败: {str(e)}")
|
|
+ raise AIParseErrorException(f"分镜数据处理失败: {str(e)}")
|
|
|
|
async def polish_content(
|
|
self,
|
|
@@ -144,21 +150,23 @@ class ScriptService:
|
|
)
|
|
return result.content.strip()
|
|
except TimeoutError:
|
|
- raise ValueError("润色请求超时,请重试")
|
|
+ raise AITimeoutException("润色请求超时,请重试")
|
|
+ except (AIEmptyResponseException, AIParseErrorException, AITimeoutException):
|
|
+ raise
|
|
except Exception as e:
|
|
- raise ValueError(f"润色失败: {str(e)}")
|
|
+ raise AIParseErrorException(f"润色失败: {str(e)}")
|
|
|
|
async def check_model_health(self) -> dict:
|
|
"""检查模型健康状态"""
|
|
model_router = await get_model_router()
|
|
health_results = await model_router.health_check()
|
|
|
|
- models = []
|
|
+ models: list[dict[str, Any]] = []
|
|
available_count = 0
|
|
- recommended = None
|
|
+ recommended: dict[str, Any] | None = None
|
|
|
|
for _provider_id, health in health_results.items():
|
|
- model_info = {
|
|
+ model_info: dict[str, Any] = {
|
|
"id": health.id,
|
|
"name": health.name,
|
|
"is_available": health.is_available,
|
|
@@ -169,9 +177,12 @@ class ScriptService:
|
|
|
|
if health.is_available:
|
|
available_count += 1
|
|
- if recommended is None or health.response_time < recommended.get(
|
|
- "response_time", float("inf")
|
|
- ):
|
|
+ current_best = (
|
|
+ float("inf")
|
|
+ if recommended is None
|
|
+ else float(recommended.get("response_time") or float("inf"))
|
|
+ )
|
|
+ if health.response_time < current_best:
|
|
recommended = model_info
|
|
|
|
total = len(models)
|
|
@@ -188,7 +199,6 @@ class ScriptService:
|
|
"""测试指定模型连接"""
|
|
model_router = await get_model_router()
|
|
|
|
-
|
|
start_time = time.time()
|
|
|
|
try:
|
|
diff --git a/python-api/app/services/vidu_service.py b/python-api/app/services/vidu_service.py
|
|
index 054823f..61bbdcd 100644
|
|
--- a/python-api/app/services/vidu_service.py
|
|
+++ b/python-api/app/services/vidu_service.py
|
|
@@ -207,8 +207,9 @@ class ViduService:
|
|
error_type=PlatformErrorType.BAD_REQUEST,
|
|
)
|
|
|
|
- logger.info(f"[Vidu Clone] 复刻成功: voice_id={result.data.get('voice_id')}")
|
|
- return result.data or {}
|
|
+ clone_data = result.data or {}
|
|
+ logger.info(f"[Vidu Clone] 复刻成功: voice_id={clone_data.get('voice_id')}")
|
|
+ return clone_data
|
|
|
|
async def query_clone_task(self, voice_id: str) -> dict[str, Any]:
|
|
"""Vidu 声音复刻是同步接口,无独立查询。
|
|
@@ -270,6 +271,8 @@ class ViduService:
|
|
result_data = status.result or {}
|
|
return {
|
|
"state": ViduAdapter.denormalize_state(status.state),
|
|
- "creations": [{"url": result_data.get("video_url")}] if result_data.get("video_url") else [],
|
|
+ "creations": (
|
|
+ [{"url": result_data.get("video_url")}] if result_data.get("video_url") else []
|
|
+ ),
|
|
"message": status.error_message,
|
|
}
|
|
diff --git a/python-api/app/services/volcengine_caption_service.py b/python-api/app/services/volcengine_caption_service.py
|
|
index 83f59b4..9a565a5 100644
|
|
--- a/python-api/app/services/volcengine_caption_service.py
|
|
+++ b/python-api/app/services/volcengine_caption_service.py
|
|
@@ -155,10 +155,7 @@ class VolcengineCaptionService:
|
|
error_type=PlatformErrorType.BAD_REQUEST,
|
|
)
|
|
|
|
- logger.warning(
|
|
- f"{task_name}超过最大轮询次数: task_id={task_id}, "
|
|
- f"retries={retries}"
|
|
- )
|
|
+ logger.warning(f"{task_name}超过最大轮询次数: task_id={task_id}, " f"retries={retries}")
|
|
raise PlatformError(
|
|
f"{task_name}超时,请稍后重试",
|
|
platform="volcengine_caption",
|