diff --git a/python-api/app/services/point_service.py b/python-api/app/services/point_service.py index 0e5d79e..709f78d 100644 --- a/python-api/app/services/point_service.py +++ b/python-api/app/services/point_service.py @@ -25,7 +25,7 @@ import logging import math from datetime import UTC, datetime, timedelta from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import yaml from sqlalchemy import select @@ -33,6 +33,7 @@ from sqlalchemy.ext.asyncio import AsyncSession logger = logging.getLogger(__name__) +from app.core.exceptions import InsufficientPointsException from app.models.point_batch import PointBatch from app.models.point_transaction import PointTransaction from app.models.user_point import UserPoint @@ -46,11 +47,11 @@ if TYPE_CHECKING: _CONFIG_PATH = Path(__file__).resolve().parents[2] / "config" / "points-config.yaml" -def _load_points_config() -> dict: +def _load_points_config() -> dict[str, Any]: """加载积分计费配置。服务启动时读取一次,后续内存中使用。""" if not _CONFIG_PATH.exists(): raise FileNotFoundError(f"积分配置文件不存在: {_CONFIG_PATH}") - with open(_CONFIG_PATH, "r", encoding="utf-8") as f: + with open(_CONFIG_PATH, encoding="utf-8") as f: cfg = yaml.safe_load(f) or {} # 合并为统一的查询字典:source_type -> {"mode": "fixed|duration|free", ...} merged: dict[str, dict] = {} @@ -65,18 +66,22 @@ def _load_points_config() -> dict: return merged -POINTS_CONFIG: dict[str, dict] = _load_points_config() +POINTS_CONFIG: dict[str, Any] = _load_points_config() def get_recharge_options() -> list[dict]: """获取充值档位配置(由后端控制,支持积分赠送)""" - return POINTS_CONFIG.get("_recharge_options", []) + options = POINTS_CONFIG.get("_recharge_options", []) + if isinstance(options, list): + return options + return [] def get_chargeable_source_types() -> list[str]: """获取所有需要扣费的业务类型列表(排除免费业务)""" return [ - key for key, cfg in POINTS_CONFIG.items() + key + for key, cfg in POINTS_CONFIG.items() if not key.startswith("_") and cfg.get("mode") != "free" ] @@ -163,11 +168,10 @@ def _estimate_max_cost(source_type: str, param: dict | None = None) -> int: # ── 余额查询 ────────────────────────────────────────── + async def get_user_balance(db: AsyncSession, user_id: UUID | str) -> dict: """获取用户积分余额快照(实时计算,排除已过期批次)。""" - result = await db.execute( - select(UserPoint).where(UserPoint.user_id == user_id) - ) + result = await db.execute(select(UserPoint).where(UserPoint.user_id == user_id)) up = result.scalar_one_or_none() if not up: @@ -182,8 +186,7 @@ async def get_user_balance(db: AsyncSession, user_id: UUID | str) -> dict: from sqlalchemy import func as _func available_result = await db.execute( - select(_func.coalesce(_func.sum(PointBatch.remaining), 0)) - .where( + select(_func.coalesce(_func.sum(PointBatch.remaining), 0)).where( PointBatch.user_id == user_id, PointBatch.remaining > 0, PointBatch.expired_at > _now(), @@ -221,6 +224,7 @@ async def check_balance( # ── 充值 ────────────────────────────────────────────── + async def recharge( db: AsyncSession, *, @@ -247,8 +251,7 @@ async def recharge( # 幂等保护:同一笔订单(order_id)只能充值一次 if order_id: existing_result = await db.execute( - select(PointTransaction) - .where( + select(PointTransaction).where( PointTransaction.source_id == str(order_id), PointTransaction.type == "recharge", ) @@ -259,9 +262,7 @@ async def recharge( return existing_tx # 1. 获取或创建用户积分账户 - result = await db.execute( - select(UserPoint).where(UserPoint.user_id == user_id) - ) + result = await db.execute(select(UserPoint).where(UserPoint.user_id == user_id)) up = result.scalar_one_or_none() if not up: @@ -353,7 +354,7 @@ async def consume( 直接扣费(后置计费)。 业务执行成功后调用,按实际消耗直接扣除余额。 - 默认不允许欠费(余额不足时抛出 ValueError)。 + 默认不允许欠费(余额不足时抛出 InsufficientPointsException)。 Scheduler 后置扣费等场景可设置 allow_negative=True,允许余额变负。 :param points: 实际消耗积分(正整数) @@ -383,9 +384,7 @@ async def consume( # 2. 获取用户积分账户(加锁) result = await db.execute( - select(UserPoint) - .where(UserPoint.user_id == user_id) - .with_for_update() + select(UserPoint).where(UserPoint.user_id == user_id).with_for_update() ) up = result.scalar_one_or_none() @@ -404,7 +403,7 @@ async def consume( # 3. 余额检查:用实时可用余额(未过期批次 remaining 总和),避免 expire_batches 延迟导致超扣 available = sum(b.remaining for b in batches) if not allow_negative and available < points: - raise ValueError(f"积分不足,当前可用余额 {available},需要 {points} 积分") + raise InsufficientPointsException(f"积分不足,当前可用余额 {available},需要 {points} 积分") remaining_to_deduct = points for batch in batches: @@ -440,6 +439,7 @@ async def consume( # ── 过期回收 ────────────────────────────────────────── + async def expire_batches(db: AsyncSession) -> int: """ 回收过期积分批次。返回过期积分总数。 @@ -468,9 +468,7 @@ async def expire_batches(db: AsyncSession) -> int: # 获取用户积分账户(加锁) up_result = await db.execute( - select(UserPoint) - .where(UserPoint.user_id == batch.user_id) - .with_for_update() + select(UserPoint).where(UserPoint.user_id == batch.user_id).with_for_update() ) up = up_result.scalar_one_or_none() if not up: diff --git a/python-api/app/services/script_service.py b/python-api/app/services/script_service.py index 49aa4b1..60f58d5 100644 --- a/python-api/app/services/script_service.py +++ b/python-api/app/services/script_service.py @@ -7,9 +7,16 @@ import asyncio import logging import time from pathlib import Path +from typing import Any from app.ai.model_router import get_model_router from app.ai.prompts import load_prompt_file, load_script_user_prompt +from app.core.exceptions import ( + AIEmptyResponseException, + AIParseErrorException, + AITimeoutException, + PromptNotFoundException, +) from app.schemas.script import ScriptShot from app.services.ai_response_utils import ( safe_parse_ai_json_response, @@ -22,12 +29,9 @@ logger = logging.getLogger(__name__) class ScriptService: """脚本生成服务""" - def __init__(self): self.prompts_dir = Path(__file__).parent.parent / "ai" / "prompts" - - def _load_prompt(self, name: str) -> str: """加载 Prompt 模板""" prompt_file = self.prompts_dir / f"{name}.txt" @@ -58,7 +62,7 @@ class ScriptService: # 加载 Prompt system_prompt = load_prompt_file(category, filename) if not system_prompt: - raise ValueError(f"未找到提示词: category={category}, filename={filename}") + raise PromptNotFoundException(f"未找到提示词: category={category}, filename={filename}") # 用户提示词 user_prompt = load_script_user_prompt( @@ -75,24 +79,26 @@ class ScriptService: ) if not result.content or not result.content.strip(): - raise ValueError("AI 返回内容为空,请检查模型配置或重试") + raise AIEmptyResponseException("AI 返回内容为空,请检查模型配置或重试") success, parsed_data, error_msg = safe_parse_ai_json_response(result.content) if not success: - raise ValueError(error_msg or "AI 返回格式错误,无法解析为 JSON") + raise AIParseErrorException(error_msg or "AI 返回格式错误,无法解析为 JSON") try: shots_data = validate_and_normalize_shots(parsed_data) if not shots_data: - raise ValueError("AI 返回的分镜数据为空或格式不正确") + raise AIEmptyResponseException("AI 返回的分镜数据为空或格式不正确") shots = [ScriptShot(**shot) for shot in shots_data] return shots + except (AIEmptyResponseException, AIParseErrorException): + raise except Exception as e: - raise ValueError(f"分镜数据处理失败: {str(e)}") + raise AIParseErrorException(f"分镜数据处理失败: {str(e)}") async def polish_content( self, @@ -144,21 +150,23 @@ class ScriptService: ) return result.content.strip() except TimeoutError: - raise ValueError("润色请求超时,请重试") + raise AITimeoutException("润色请求超时,请重试") + except (AIEmptyResponseException, AIParseErrorException, AITimeoutException): + raise except Exception as e: - raise ValueError(f"润色失败: {str(e)}") + raise AIParseErrorException(f"润色失败: {str(e)}") async def check_model_health(self) -> dict: """检查模型健康状态""" model_router = await get_model_router() health_results = await model_router.health_check() - models = [] + models: list[dict[str, Any]] = [] available_count = 0 - recommended = None + recommended: dict[str, Any] | None = None for _provider_id, health in health_results.items(): - model_info = { + model_info: dict[str, Any] = { "id": health.id, "name": health.name, "is_available": health.is_available, @@ -169,9 +177,12 @@ class ScriptService: if health.is_available: available_count += 1 - if recommended is None or health.response_time < recommended.get( - "response_time", float("inf") - ): + current_best = ( + float("inf") + if recommended is None + else float(recommended.get("response_time") or float("inf")) + ) + if health.response_time < current_best: recommended = model_info total = len(models) @@ -188,7 +199,6 @@ class ScriptService: """测试指定模型连接""" model_router = await get_model_router() - start_time = time.time() try: diff --git a/python-api/app/services/vidu_service.py b/python-api/app/services/vidu_service.py index 054823f..61bbdcd 100644 --- a/python-api/app/services/vidu_service.py +++ b/python-api/app/services/vidu_service.py @@ -207,8 +207,9 @@ class ViduService: error_type=PlatformErrorType.BAD_REQUEST, ) - logger.info(f"[Vidu Clone] 复刻成功: voice_id={result.data.get('voice_id')}") - return result.data or {} + clone_data = result.data or {} + logger.info(f"[Vidu Clone] 复刻成功: voice_id={clone_data.get('voice_id')}") + return clone_data async def query_clone_task(self, voice_id: str) -> dict[str, Any]: """Vidu 声音复刻是同步接口,无独立查询。 @@ -270,6 +271,8 @@ class ViduService: result_data = status.result or {} return { "state": ViduAdapter.denormalize_state(status.state), - "creations": [{"url": result_data.get("video_url")}] if result_data.get("video_url") else [], + "creations": ( + [{"url": result_data.get("video_url")}] if result_data.get("video_url") else [] + ), "message": status.error_message, } diff --git a/python-api/app/services/volcengine_caption_service.py b/python-api/app/services/volcengine_caption_service.py index 83f59b4..9a565a5 100644 --- a/python-api/app/services/volcengine_caption_service.py +++ b/python-api/app/services/volcengine_caption_service.py @@ -155,10 +155,7 @@ class VolcengineCaptionService: error_type=PlatformErrorType.BAD_REQUEST, ) - logger.warning( - f"{task_name}超过最大轮询次数: task_id={task_id}, " - f"retries={retries}" - ) + logger.warning(f"{task_name}超过最大轮询次数: task_id={task_id}, " f"retries={retries}") raise PlatformError( f"{task_name}超时,请稍后重试", platform="volcengine_caption",