Files
meijiaka-zy/python-api/app/core/token_manager.py
T
小鱼开发 e262134148 refactor: 移除 KlingAI 和 MiniMax 相关代码
删除内容:
- KlingAI Provider、MiniMax Provider
- Kling 视频/图片/TTS/语音克隆/形象克隆 Service 和 Scheduler Handler
- 已废弃的 TTSService、VoiceCloneService
- config 中 KLINGAI_*/MINIMAX_* 配置项
- ai_models.yaml 中 klingai 平台和模型配置
- docker-compose 中相关环境变量
- .env.example 中相关配置示例
- deploy-test.sh 中相关检查
- Makefile 中 klingai 语义检查排除规则
- KlingTaskStatus 枚举

修改内容:
- model_router.py 移除 KlingAI 平台分支
- voice.py 重写,修复批量合成/文件保存中 service 未定义的 Bug
- vidu_service.py 移除 MiniMax 相关注释
- script_handler.py 更新注释
2026-05-02 23:16:14 +08:00

436 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Token 管理器 - 通用 API 认证 Token 缓存与自动刷新
支持:
- JWT Token
- OAuth2 Access Token
- 自定义 Token 类型
特性:
- 线程/协程安全的 token 缓存
- 自动刷新(带安全边界)
- 后台预热机制
- 支持多 Provider 实例隔离
"""
from __future__ import annotations
import asyncio
import logging
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Protocol
# Module-level logger; TokenManager reports cache hits, refreshes and refresh failures through it.
logger = logging.getLogger(__name__)
@dataclass
class TokenInfo:
    """A token value together with its absolute expiry time.

    Attributes:
        token: The raw token string.
        expires_at: Expiry as a Unix timestamp in seconds (compared against
            ``time.time()``).
        token_type: Scheme label, e.g. used as the Authorization header prefix.
        extra_data: Additional metadata kept alongside the token (for example
            extra fields of an OAuth2 token response).
    """

    token: str
    expires_at: float  # absolute Unix timestamp, seconds
    token_type: str = "Bearer"
    extra_data: dict[str, Any] = field(default_factory=dict)

    @property
    def is_expired(self) -> bool:
        """True once the current wall-clock time has reached ``expires_at``."""
        return self.expires_at <= time.time()

    @property
    def expires_in(self) -> float:
        """Seconds of validity left, clamped so it never drops below zero."""
        remaining = self.expires_at - time.time()
        return remaining if remaining > 0 else 0

    def is_near_expiry(self, safety_margin: float = 300) -> bool:
        """Return True when the token is within ``safety_margin`` seconds of expiry.

        Args:
            safety_margin: Lead time in seconds before ``expires_at`` from which
                the token counts as due for refresh (default 300 = 5 minutes).
        """
        refresh_deadline = self.expires_at - safety_margin
        return refresh_deadline <= time.time()
class TokenGenerator(Protocol):
    """Structural type for a zero-argument coroutine callable yielding a token.

    Any object whose ``__call__`` is an ``async`` method returning a
    ``TokenInfo`` matches this protocol.
    """

    async def __call__(self) -> TokenInfo:
        """Create or fetch a fresh token and return it as ``TokenInfo``."""
        ...
class BaseTokenStrategy(ABC):
    """Abstract recipe for producing a token and naming its cache slot.

    Concrete strategies implement ``generate`` (obtain a new token) and
    ``get_cache_key`` (a string that keeps tokens of different strategy
    instances apart in a cache).
    """

    @abstractmethod
    async def generate(self) -> TokenInfo:
        """Produce and return a brand-new token."""
        ...

    @abstractmethod
    def get_cache_key(self) -> str:
        """Return the identifier under which this strategy's token is cached."""
        ...
class JWTTokenStrategy(BaseTokenStrategy):
    """Token strategy that signs a JWT locally from an access-key / secret-key pair.

    The token carries ``iss`` (the access key), ``exp`` and ``nbf`` claims and is
    signed with ``secret_key`` using ``algorithm``.
    """

    def __init__(
        self,
        access_key: str,
        secret_key: str,
        expires_in: int = 1800,
        algorithm: str = "HS256",
        token_type: str = "JWT",
    ):
        """
        Args:
            access_key: Issuer identity written into the ``iss`` claim.
            secret_key: Key used to sign the token.
            expires_in: Token lifetime in seconds (default 1800 = 30 minutes).
            algorithm: Signing algorithm; also written to the header ``alg``.
            token_type: Header ``typ`` value. The returned ``TokenInfo`` is
                always tagged ``"Bearer"`` regardless of this value.
        """
        self.access_key = access_key
        self.secret_key = secret_key
        self.expires_in = expires_in
        self.algorithm = algorithm
        self.token_type = token_type

    async def generate(self) -> TokenInfo:
        """Sign and return a new JWT.

        Returns:
            TokenInfo whose ``expires_at`` equals the ``exp`` claim and whose
            ``token_type`` is ``"Bearer"``.
        """
        # Deferred import: python-jose is only loaded when a JWT is actually generated.
        from jose import jwt

        issued_at = int(time.time())
        expires_at = issued_at + self.expires_in
        headers = {"alg": self.algorithm, "typ": self.token_type}
        payload = {
            "iss": self.access_key,
            "exp": expires_at,
            # Backdated by 5 s, presumably to tolerate clock skew on the verifying side.
            "nbf": issued_at - 5,
        }
        token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm, headers=headers)
        return TokenInfo(token=token, expires_at=expires_at, token_type="Bearer")

    def get_cache_key(self) -> str:
        """Return a cache key that is unique to this credential/configuration.

        The first 8 characters of ``access_key`` are kept for readable logs. A
        SHA-256 digest over every input that affects the signed token (access
        key, secret key, algorithm, lifetime, header ``typ``) is appended, so
        credentials that merely share a prefix never share a cached token and a
        rotated secret gets a fresh cache slot. The secret itself never appears
        in the key.
        """
        import hashlib

        fingerprint = repr(
            (self.access_key, self.secret_key, self.algorithm, self.expires_in, self.token_type)
        )
        digest = hashlib.sha256(fingerprint.encode("utf-8")).hexdigest()[:16]
        return f"jwt:{self.access_key[:8]}:{digest}"
class OAuth2TokenStrategy(BaseTokenStrategy):
    """Token strategy using the OAuth2 client-credentials grant against ``token_url``."""

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        token_url: str,
        scope: str | None = None,
        extra_params: dict[str, Any] | None = None,
    ):
        """
        Args:
            client_id: OAuth2 client identifier.
            client_secret: OAuth2 client secret (sent in the form body).
            token_url: Token endpoint URL.
            scope: Optional ``scope`` form field.
            extra_params: Additional form fields merged into the request body;
                they are placed after the default fields and can override them.
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_url = token_url
        self.scope = scope
        self.extra_params = extra_params or {}

    async def generate(self) -> TokenInfo:
        """POST a client-credentials request to ``token_url`` and wrap the response.

        Returns:
            TokenInfo expiring ``expires_in`` seconds from now (3600 s when the
            server omits ``expires_in`` or sends null). Response fields other
            than ``access_token``/``expires_in``/``token_type`` go into
            ``extra_data``.

        Raises:
            httpx.HTTPStatusError: the endpoint answered with a 4xx/5xx status.
            KeyError: the response JSON has no ``access_token``.
            ValueError: ``expires_in`` is present but not numeric.
        """
        import httpx

        form = {
            "grant_type": "client_credentials",
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            **self.extra_params,
        }
        if self.scope:
            form["scope"] = self.scope

        async with httpx.AsyncClient() as client:
            response = await client.post(self.token_url, data=form)
            response.raise_for_status()
            result = response.json()

        access_token = result["access_token"]
        # expires_in is optional (RFC 6749 §5.1), and some servers send it as a
        # string; coerce it so the arithmetic below cannot raise TypeError.
        raw_expires_in = result.get("expires_in")
        expires_in = float(raw_expires_in) if raw_expires_in is not None else 3600.0
        reserved = {"access_token", "expires_in", "token_type"}
        return TokenInfo(
            token=access_token,
            expires_at=time.time() + expires_in,
            token_type=result.get("token_type", "Bearer"),
            extra_data={k: v for k, v in result.items() if k not in reserved},
        )

    def get_cache_key(self) -> str:
        """Return a cache key unique to this client, endpoint and request shape.

        ``scope`` and ``extra_params`` are part of the key because they change
        which token the server issues. The client secret is also included, so a
        rotated secret does not reuse an old token. A stable SHA-256 digest is
        used instead of the per-process-randomized ``hash()``; only the first
        8 characters of ``client_id`` appear in clear text.
        """
        import hashlib

        fingerprint = repr(
            (
                self.client_id,
                self.client_secret,
                self.token_url,
                self.scope,
                sorted((str(k), repr(v)) for k, v in self.extra_params.items()),
            )
        )
        digest = hashlib.sha256(fingerprint.encode("utf-8")).hexdigest()[:16]
        return f"oauth2:{self.client_id[:8]}:{digest}"
class TokenManager:
    """Process-wide singleton that caches tokens per strategy and refreshes them.

    Tokens are cached under ``strategy.get_cache_key()``. A cached token is
    reused until it is within ``_safety_margin`` seconds of expiry. Refreshes
    for the same key are serialized by a per-key ``asyncio.Lock``, so
    concurrent callers trigger a single ``generate()``. After every successful
    refresh, a background task is scheduled to refresh again
    ``2 * _safety_margin`` seconds before expiry, so callers rarely wait.

    Concurrency: all locking uses ``asyncio`` primitives, so the manager is safe
    for coroutines on one event loop; it is not designed for multi-threaded use.

    Usage::

        # JWT
        strategy = JWTTokenStrategy(access_key="xxx", secret_key="yyy")
        token = await TokenManager.get_instance().get_token(strategy)

        # OAuth2
        strategy = OAuth2TokenStrategy(
            client_id="xxx",
            client_secret="yyy",
            token_url="https://api.example.com/oauth2/token",
        )
        token = await TokenManager.get_instance().get_token(strategy)
    """

    _instance: TokenManager | None = None
    _lock: asyncio.Lock | None = None  # NOTE(review): not referenced anywhere in this class

    def __new__(cls) -> TokenManager:
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    @classmethod
    def get_instance(cls) -> TokenManager:
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        # __init__ runs on every TokenManager() call; only set up state once.
        if self._initialized:
            return
        # Token cache: {cache_key: TokenInfo}
        self._tokens: dict[str, TokenInfo] = {}
        # Per-key refresh locks: {cache_key: asyncio.Lock}
        self._refresh_locks: dict[str, asyncio.Lock] = {}
        # Guards creation of entries in _refresh_locks.
        self._global_lock = asyncio.Lock()
        # Strong references to background tasks so they are not garbage-collected mid-flight.
        self._background_tasks: set[asyncio.Task] = set()
        # Most recently scheduled preemptive refresh task per cache key.
        self._preemptive_tasks: dict[str, asyncio.Task] = {}
        self._safety_margin = 300  # refresh once fewer than 5 minutes remain
        self._preemptive_refresh = True  # enable background pre-refresh
        self._initialized = True

    async def get_token(
        self,
        strategy: BaseTokenStrategy,
        force_refresh: bool = False,
    ) -> TokenInfo:
        """Return a valid token for ``strategy``, refreshing it when needed.

        Args:
            strategy: Token generation strategy.
            force_refresh: Bypass the cache and obtain a new token even if the
                cached one is still valid. Concurrent forced refreshes for the
                same key are coalesced into a single ``generate()`` call.

        Returns:
            TokenInfo: a token not within the safety margin of expiry (or, if
            refreshing failed, a cached token that has not expired yet).
        """
        cache_key = strategy.get_cache_key()
        if not force_refresh:
            cached = self._tokens.get(cache_key)
            if cached is not None and not cached.is_near_expiry(self._safety_margin):
                logger.debug(f"Token cache hit for {cache_key}")
                return cached
        return await self._refresh_token(strategy, force=force_refresh)

    async def get_token_string(
        self,
        strategy: BaseTokenStrategy,
        force_refresh: bool = False,
    ) -> str:
        """Return ``"<token_type> <token>"``, ready for an Authorization header.

        The prefix is the token's own ``token_type`` (typically ``Bearer``).
        """
        token_info = await self.get_token(strategy, force_refresh)
        return f"{token_info.token_type} {token_info.token}"

    async def _refresh_token(self, strategy: BaseTokenStrategy, force: bool = False) -> TokenInfo:
        """Refresh the token for ``strategy`` under its per-key lock.

        Callers that queued on the lock re-check the cache once they get it, so
        only one of them calls ``generate()``.

        Args:
            strategy: Token generation strategy.
            force: Refresh even if the cached token is not near expiry, unless
                another task already replaced it while this one was waiting.
        """
        cache_key = strategy.get_cache_key()
        # Snapshot taken before any await: lets a forced refresh detect that a
        # concurrent task already replaced the token while we waited.
        seen_before_wait = self._tokens.get(cache_key)

        async with self._global_lock:
            if cache_key not in self._refresh_locks:
                self._refresh_locks[cache_key] = asyncio.Lock()
            refresh_lock = self._refresh_locks[cache_key]

        async with refresh_lock:
            current = self._tokens.get(cache_key)
            if current is not None and not current.is_near_expiry(self._safety_margin):
                if not force or current is not seen_before_wait:
                    logger.debug(f"Token refreshed by another task for {cache_key}")
                    return current

            logger.info(f"Refreshing token for {cache_key}")
            try:
                new_token = await strategy.generate()
            except Exception as e:
                logger.error(f"Failed to refresh token for {cache_key}: {e}")
                # Best effort: fall back to a cached token that has not expired yet.
                cached = self._tokens.get(cache_key)
                if cached is not None and not cached.is_expired:
                    logger.warning(
                        f"Using cached (not yet expired) token for {cache_key} due to refresh failure"
                    )
                    return cached
                raise

            self._tokens[cache_key] = new_token
            if self._preemptive_refresh:
                self._schedule_preemptive_refresh(strategy, new_token)
            logger.info(
                f"Token refreshed successfully for {cache_key}, expires in {new_token.expires_in:.0f}s"
            )
            return new_token

    def _schedule_preemptive_refresh(self, strategy: BaseTokenStrategy, token_info: TokenInfo) -> None:
        """Schedule a background refresh ``2 * _safety_margin`` seconds before expiry.

        Any earlier preemptive task for the same key is cancelled. When the
        token's remaining lifetime is not longer than ``2 * _safety_margin``,
        nothing is scheduled: a zero-delay refresh would immediately regenerate
        a token with the same short lifetime and loop forever. Such tokens are
        still refreshed on demand by ``get_token``.
        """
        cache_key = strategy.get_cache_key()

        previous = self._preemptive_tasks.pop(cache_key, None)
        if (
            previous is not None
            and not previous.done()
            # The preemptive task itself reaches this method via _refresh_token;
            # never cancel the task we are running in.
            and previous is not asyncio.current_task()
        ):
            previous.cancel()

        delay = token_info.expires_at - self._safety_margin * 2 - time.time()
        if delay <= 0:
            logger.debug(
                f"Token lifetime for {cache_key} too short for preemptive refresh; "
                f"relying on on-demand refresh"
            )
            return

        async def _refresh_task() -> None:
            await asyncio.sleep(delay)
            # Skip if the token was replaced or invalidated while we slept.
            if self._tokens.get(cache_key) is not token_info:
                return
            try:
                logger.info(f"Preemptive token refresh for {cache_key}")
                # force=True: the token is deliberately still outside the
                # near-expiry window, which would otherwise short-circuit the refresh.
                await self._refresh_token(strategy, force=True)
            except Exception as e:
                logger.error(f"Preemptive refresh failed for {cache_key}: {e}")

        task = asyncio.create_task(_refresh_task())
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        self._preemptive_tasks[cache_key] = task
        logger.debug(f"Scheduled preemptive refresh for {cache_key} in {delay:.0f}s")

    async def invalidate(self, strategy: BaseTokenStrategy) -> bool:
        """Drop the cached token for ``strategy``.

        Returns:
            bool: True if a cached token was removed, False if none was cached.
        """
        cache_key = strategy.get_cache_key()
        if cache_key in self._tokens:
            del self._tokens[cache_key]
            logger.info(f"Token cache invalidated for {cache_key}")
            return True
        return False

    def clear(self):
        """Remove every cached token."""
        self._tokens.clear()
        logger.info("All token caches cleared")

    def get_stats(self) -> dict[str, Any]:
        """Return a snapshot of cache size, background tasks and per-key expiry state."""
        stats = {
            "total_cached": len(self._tokens),
            "active_tasks": len(self._background_tasks),
            "tokens": {},
        }
        for key, token_info in self._tokens.items():
            stats["tokens"][key] = {
                "expires_in": token_info.expires_in,
                "is_expired": token_info.is_expired,
                "is_near_expiry": token_info.is_near_expiry(self._safety_margin),
            }
        return stats
# Convenience helpers
async def get_jwt_token(
    access_key: str,
    secret_key: str,
    expires_in: int = 1800,
    algorithm: str = "HS256",
) -> TokenInfo:
    """Return a cached-or-fresh JWT via the shared ``TokenManager``.

    Builds a ``JWTTokenStrategy`` from the arguments and delegates to
    ``TokenManager.get_instance().get_token``.

    Example::

        token_info = await get_jwt_token("access_key", "secret_key")
        headers = {"Authorization": f"Bearer {token_info.token}"}
    """
    manager = TokenManager.get_instance()
    jwt_strategy = JWTTokenStrategy(
        access_key=access_key,
        secret_key=secret_key,
        expires_in=expires_in,
        algorithm=algorithm,
    )
    return await manager.get_token(jwt_strategy)
async def get_oauth2_token(
    client_id: str,
    client_secret: str,
    token_url: str,
    scope: str | None = None,
) -> TokenInfo:
    """Return a cached-or-fresh OAuth2 access token via the shared ``TokenManager``.

    Builds an ``OAuth2TokenStrategy`` (client-credentials grant) from the
    arguments and delegates to ``TokenManager.get_instance().get_token``.

    Example::

        token_info = await get_oauth2_token(
            client_id="xxx",
            client_secret="yyy",
            token_url="https://api.example.com/oauth2/token",
        )
        headers = {"Authorization": f"Bearer {token_info.token}"}
    """
    manager = TokenManager.get_instance()
    oauth_strategy = OAuth2TokenStrategy(
        client_id=client_id,
        client_secret=client_secret,
        token_url=token_url,
        scope=scope,
    )
    return await manager.get_token(oauth_strategy)