Files

332 lines
13 KiB
Diff

diff --git a/python-api/app/services/point_service.py b/python-api/app/services/point_service.py
index 0e5d79e..709f78d 100644
--- a/python-api/app/services/point_service.py
+++ b/python-api/app/services/point_service.py
@@ -25,7 +25,7 @@ import logging
import math
from datetime import UTC, datetime, timedelta
from pathlib import Path
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
import yaml
from sqlalchemy import select
@@ -33,6 +33,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
logger = logging.getLogger(__name__)
+from app.core.exceptions import InsufficientPointsException
from app.models.point_batch import PointBatch
from app.models.point_transaction import PointTransaction
from app.models.user_point import UserPoint
@@ -46,11 +47,11 @@ if TYPE_CHECKING:
_CONFIG_PATH = Path(__file__).resolve().parents[2] / "config" / "points-config.yaml"
-def _load_points_config() -> dict:
+def _load_points_config() -> dict[str, Any]:
"""加载积分计费配置。服务启动时读取一次,后续内存中使用。"""
if not _CONFIG_PATH.exists():
raise FileNotFoundError(f"积分配置文件不存在: {_CONFIG_PATH}")
- with open(_CONFIG_PATH, "r", encoding="utf-8") as f:
+ with open(_CONFIG_PATH, encoding="utf-8") as f:
cfg = yaml.safe_load(f) or {}
# 合并为统一的查询字典:source_type -> {"mode": "fixed|duration|free", ...}
merged: dict[str, dict] = {}
@@ -65,18 +66,22 @@ def _load_points_config() -> dict:
return merged
-POINTS_CONFIG: dict[str, dict] = _load_points_config()
+POINTS_CONFIG: dict[str, Any] = _load_points_config()
def get_recharge_options() -> list[dict]:
"""获取充值档位配置(由后端控制,支持积分赠送)"""
- return POINTS_CONFIG.get("_recharge_options", [])
+ options = POINTS_CONFIG.get("_recharge_options", [])
+ if isinstance(options, list):
+ return options
+ return []
def get_chargeable_source_types() -> list[str]:
"""获取所有需要扣费的业务类型列表(排除免费业务)"""
return [
- key for key, cfg in POINTS_CONFIG.items()
+ key
+ for key, cfg in POINTS_CONFIG.items()
if not key.startswith("_") and cfg.get("mode") != "free"
]
@@ -163,11 +168,10 @@ def _estimate_max_cost(source_type: str, param: dict | None = None) -> int:
# ── 余额查询 ──────────────────────────────────────────
+
async def get_user_balance(db: AsyncSession, user_id: UUID | str) -> dict:
"""获取用户积分余额快照(实时计算,排除已过期批次)。"""
- result = await db.execute(
- select(UserPoint).where(UserPoint.user_id == user_id)
- )
+ result = await db.execute(select(UserPoint).where(UserPoint.user_id == user_id))
up = result.scalar_one_or_none()
if not up:
@@ -182,8 +186,7 @@ async def get_user_balance(db: AsyncSession, user_id: UUID | str) -> dict:
from sqlalchemy import func as _func
available_result = await db.execute(
- select(_func.coalesce(_func.sum(PointBatch.remaining), 0))
- .where(
+ select(_func.coalesce(_func.sum(PointBatch.remaining), 0)).where(
PointBatch.user_id == user_id,
PointBatch.remaining > 0,
PointBatch.expired_at > _now(),
@@ -221,6 +224,7 @@ async def check_balance(
# ── 充值 ──────────────────────────────────────────────
+
async def recharge(
db: AsyncSession,
*,
@@ -247,8 +251,7 @@ async def recharge(
# 幂等保护:同一笔订单(order_id)只能充值一次
if order_id:
existing_result = await db.execute(
- select(PointTransaction)
- .where(
+ select(PointTransaction).where(
PointTransaction.source_id == str(order_id),
PointTransaction.type == "recharge",
)
@@ -259,9 +262,7 @@ async def recharge(
return existing_tx
# 1. 获取或创建用户积分账户
- result = await db.execute(
- select(UserPoint).where(UserPoint.user_id == user_id)
- )
+ result = await db.execute(select(UserPoint).where(UserPoint.user_id == user_id))
up = result.scalar_one_or_none()
if not up:
@@ -353,7 +354,7 @@ async def consume(
直接扣费(后置计费)。
业务执行成功后调用,按实际消耗直接扣除余额。
- 默认不允许欠费(余额不足时抛出 ValueError)。
+ 默认不允许欠费(余额不足时抛出 InsufficientPointsException)。
Scheduler 后置扣费等场景可设置 allow_negative=True,允许余额变负。
:param points: 实际消耗积分(正整数)
@@ -383,9 +384,7 @@ async def consume(
# 2. 获取用户积分账户(加锁)
result = await db.execute(
- select(UserPoint)
- .where(UserPoint.user_id == user_id)
- .with_for_update()
+ select(UserPoint).where(UserPoint.user_id == user_id).with_for_update()
)
up = result.scalar_one_or_none()
@@ -404,7 +403,7 @@ async def consume(
# 3. 余额检查:用实时可用余额(未过期批次 remaining 总和),避免 expire_batches 延迟导致超扣
available = sum(b.remaining for b in batches)
if not allow_negative and available < points:
- raise ValueError(f"积分不足,当前可用余额 {available},需要 {points} 积分")
+ raise InsufficientPointsException(f"积分不足,当前可用余额 {available},需要 {points} 积分")
remaining_to_deduct = points
for batch in batches:
@@ -440,6 +439,7 @@ async def consume(
# ── 过期回收 ──────────────────────────────────────────
+
async def expire_batches(db: AsyncSession) -> int:
"""
回收过期积分批次。返回过期积分总数。
@@ -468,9 +468,7 @@ async def expire_batches(db: AsyncSession) -> int:
# 获取用户积分账户(加锁)
up_result = await db.execute(
- select(UserPoint)
- .where(UserPoint.user_id == batch.user_id)
- .with_for_update()
+ select(UserPoint).where(UserPoint.user_id == batch.user_id).with_for_update()
)
up = up_result.scalar_one_or_none()
if not up:
diff --git a/python-api/app/services/script_service.py b/python-api/app/services/script_service.py
index 49aa4b1..60f58d5 100644
--- a/python-api/app/services/script_service.py
+++ b/python-api/app/services/script_service.py
@@ -7,9 +7,16 @@ import asyncio
import logging
import time
from pathlib import Path
+from typing import Any
from app.ai.model_router import get_model_router
from app.ai.prompts import load_prompt_file, load_script_user_prompt
+from app.core.exceptions import (
+ AIEmptyResponseException,
+ AIParseErrorException,
+ AITimeoutException,
+ PromptNotFoundException,
+)
from app.schemas.script import ScriptShot
from app.services.ai_response_utils import (
safe_parse_ai_json_response,
@@ -22,12 +29,9 @@ logger = logging.getLogger(__name__)
class ScriptService:
"""脚本生成服务"""
-
def __init__(self):
self.prompts_dir = Path(__file__).parent.parent / "ai" / "prompts"
-
-
def _load_prompt(self, name: str) -> str:
"""加载 Prompt 模板"""
prompt_file = self.prompts_dir / f"{name}.txt"
@@ -58,7 +62,7 @@ class ScriptService:
# 加载 Prompt
system_prompt = load_prompt_file(category, filename)
if not system_prompt:
- raise ValueError(f"未找到提示词: category={category}, filename={filename}")
+ raise PromptNotFoundException(f"未找到提示词: category={category}, filename={filename}")
# 用户提示词
user_prompt = load_script_user_prompt(
@@ -75,24 +79,26 @@ class ScriptService:
)
if not result.content or not result.content.strip():
- raise ValueError("AI 返回内容为空,请检查模型配置或重试")
+ raise AIEmptyResponseException("AI 返回内容为空,请检查模型配置或重试")
success, parsed_data, error_msg = safe_parse_ai_json_response(result.content)
if not success:
- raise ValueError(error_msg or "AI 返回格式错误,无法解析为 JSON")
+ raise AIParseErrorException(error_msg or "AI 返回格式错误,无法解析为 JSON")
try:
shots_data = validate_and_normalize_shots(parsed_data)
if not shots_data:
- raise ValueError("AI 返回的分镜数据为空或格式不正确")
+ raise AIEmptyResponseException("AI 返回的分镜数据为空或格式不正确")
shots = [ScriptShot(**shot) for shot in shots_data]
return shots
+ except (AIEmptyResponseException, AIParseErrorException):
+ raise
except Exception as e:
- raise ValueError(f"分镜数据处理失败: {str(e)}")
+ raise AIParseErrorException(f"分镜数据处理失败: {str(e)}")
async def polish_content(
self,
@@ -144,21 +150,23 @@ class ScriptService:
)
return result.content.strip()
except TimeoutError:
- raise ValueError("润色请求超时,请重试")
+ raise AITimeoutException("润色请求超时,请重试")
+ except (AIEmptyResponseException, AIParseErrorException, AITimeoutException):
+ raise
except Exception as e:
- raise ValueError(f"润色失败: {str(e)}")
+ raise AIParseErrorException(f"润色失败: {str(e)}")
async def check_model_health(self) -> dict:
"""检查模型健康状态"""
model_router = await get_model_router()
health_results = await model_router.health_check()
- models = []
+ models: list[dict[str, Any]] = []
available_count = 0
- recommended = None
+ recommended: dict[str, Any] | None = None
for _provider_id, health in health_results.items():
- model_info = {
+ model_info: dict[str, Any] = {
"id": health.id,
"name": health.name,
"is_available": health.is_available,
@@ -169,9 +177,12 @@ class ScriptService:
if health.is_available:
available_count += 1
- if recommended is None or health.response_time < recommended.get(
- "response_time", float("inf")
- ):
+ current_best = (
+ float("inf")
+ if recommended is None
+ else float(recommended.get("response_time") or float("inf"))
+ )
+ if health.response_time < current_best:
recommended = model_info
total = len(models)
@@ -188,7 +199,6 @@ class ScriptService:
"""测试指定模型连接"""
model_router = await get_model_router()
-
start_time = time.time()
try:
diff --git a/python-api/app/services/vidu_service.py b/python-api/app/services/vidu_service.py
index 054823f..61bbdcd 100644
--- a/python-api/app/services/vidu_service.py
+++ b/python-api/app/services/vidu_service.py
@@ -207,8 +207,9 @@ class ViduService:
error_type=PlatformErrorType.BAD_REQUEST,
)
- logger.info(f"[Vidu Clone] 复刻成功: voice_id={result.data.get('voice_id')}")
- return result.data or {}
+ clone_data = result.data or {}
+ logger.info(f"[Vidu Clone] 复刻成功: voice_id={clone_data.get('voice_id')}")
+ return clone_data
async def query_clone_task(self, voice_id: str) -> dict[str, Any]:
"""Vidu 声音复刻是同步接口,无独立查询。
@@ -270,6 +271,8 @@ class ViduService:
result_data = status.result or {}
return {
"state": ViduAdapter.denormalize_state(status.state),
- "creations": [{"url": result_data.get("video_url")}] if result_data.get("video_url") else [],
+ "creations": (
+ [{"url": result_data.get("video_url")}] if result_data.get("video_url") else []
+ ),
"message": status.error_message,
}
diff --git a/python-api/app/services/volcengine_caption_service.py b/python-api/app/services/volcengine_caption_service.py
index 83f59b4..9a565a5 100644
--- a/python-api/app/services/volcengine_caption_service.py
+++ b/python-api/app/services/volcengine_caption_service.py
@@ -155,10 +155,7 @@ class VolcengineCaptionService:
error_type=PlatformErrorType.BAD_REQUEST,
)
- logger.warning(
- f"{task_name}超过最大轮询次数: task_id={task_id}, "
- f"retries={retries}"
- )
+ logger.warning(f"{task_name}超过最大轮询次数: task_id={task_id}, " f"retries={retries}")
raise PlatformError(
f"{task_name}超时,请稍后重试",
platform="volcengine_caption",