e262134148
删除内容: - KlingAI Provider、MiniMax Provider - Kling 视频/图片/TTS/语音克隆/形象克隆 Service 和 Scheduler Handler - 已废弃的 TTSService、VoiceCloneService - config 中 KLINGAI_*/MINIMAX_* 配置项 - ai_models.yaml 中 klingai 平台和模型配置 - docker-compose 中相关环境变量 - .env.example 中相关配置示例 - deploy-test.sh 中相关检查 - Makefile 中 klingai 语义检查排除规则 - KlingTaskStatus 枚举 修改内容: - model_router.py 移除 KlingAI 平台分支 - voice.py 重写,修复批量合成/文件保存中 service 未定义的 Bug - vidu_service.py 移除 MiniMax 相关注释 - script_handler.py 更新注释
436 lines
13 KiB
Python
436 lines
13 KiB
Python
"""
|
||
Token 管理器 - 通用 API 认证 Token 缓存与自动刷新
|
||
|
||
支持:
|
||
- JWT Token
|
||
- OAuth2 Access Token
|
||
- 自定义 Token 类型
|
||
|
||
特性:
|
||
- 线程/协程安全的 token 缓存
|
||
- 自动刷新(带安全边界)
|
||
- 后台预热机制
|
||
- 支持多 Provider 实例隔离
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import logging
|
||
import time
|
||
from abc import ABC, abstractmethod
|
||
from dataclasses import dataclass, field
|
||
from typing import Any, Protocol
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
class TokenInfo:
    """
    Container for a token value, its type, its expiry and any extra fields.

    ``expires_at`` is an absolute Unix timestamp in seconds and is compared
    against ``time.time()`` by the helpers below.
    """

    token: str
    expires_at: float  # absolute expiry time, Unix timestamp in seconds
    token_type: str = "Bearer"
    extra_data: dict[str, Any] = field(default_factory=dict)

    @property
    def is_expired(self) -> bool:
        """True once the current time has reached ``expires_at``."""
        return self.expires_at <= time.time()

    @property
    def expires_in(self) -> float:
        """Seconds of validity left; 0 once the token has expired."""
        remaining = self.expires_at - time.time()
        return remaining if remaining > 0 else 0

    def is_near_expiry(self, safety_margin: float = 300) -> bool:
        """
        Whether the token is within ``safety_margin`` seconds of expiring.

        Args:
            safety_margin: Seconds before ``expires_at`` from which the token
                is already considered due for refresh (default 300 = 5 minutes).
        """
        deadline = self.expires_at - safety_margin
        return deadline <= time.time()
|
||
|
||
|
||
class TokenGenerator(Protocol):
    """Structural type: a zero-argument coroutine callable returning a ``TokenInfo``."""

    async def __call__(self) -> TokenInfo:
        """Produce a fresh token (generated locally or obtained from a server)."""
        pass
|
||
|
||
|
||
class BaseTokenStrategy(ABC):
    """
    Abstract base for the token strategies consumed by ``TokenManager``.

    A concrete strategy implements ``generate`` (obtain a brand-new token) and
    ``get_cache_key`` (name the cache slot that token belongs to).
    """

    @abstractmethod
    async def generate(self) -> TokenInfo:
        """Create and return a brand-new token."""
        ...

    @abstractmethod
    def get_cache_key(self) -> str:
        """
        Return the key under which this strategy's token is cached.

        Distinct credentials/configurations should return distinct keys so
        multiple provider instances stay isolated from each other.
        """
        ...
|
||
|
||
|
||
class JWTTokenStrategy(BaseTokenStrategy):
    """
    JWT strategy: signs a short-lived token locally with a shared secret.

    Claims: ``iss`` = access key, ``exp`` = now + ``expires_in``, and ``nbf``
    backdated by 5 seconds (presumably to tolerate clock skew with the
    server - confirm against the provider's spec).
    """

    def __init__(
        self,
        access_key: str,
        secret_key: str,
        expires_in: int = 1800,
        algorithm: str = "HS256",
        token_type: str = "JWT",
    ):
        self.access_key = access_key
        self.secret_key = secret_key
        self.expires_in = expires_in  # token lifetime in seconds (default 30 minutes)
        self.algorithm = algorithm
        self.token_type = token_type  # value of the JWT "typ" header

    async def generate(self) -> TokenInfo:
        """Sign and return a new JWT; the returned TokenInfo uses type "Bearer"."""
        from jose import jwt

        headers = {"alg": self.algorithm, "typ": self.token_type}
        current_time = int(time.time())
        expires_at = current_time + self.expires_in
        payload = {
            "iss": self.access_key,
            "exp": expires_at,
            "nbf": current_time - 5,
        }

        token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm, headers=headers)

        return TokenInfo(
            token=token,
            expires_at=expires_at,
            token_type="Bearer",
        )

    def get_cache_key(self) -> str:
        """
        Cache key unique to this access key and signing configuration.

        The readable ``access_key[:8]`` prefix is kept for log readability.
        It is followed by a hash over the full access key, algorithm and typ
        header, so two access keys that merely share a prefix get separate
        cache slots. ``hash()`` of a str is stable within a process (and
        salted per process unless PYTHONHASHSEED is fixed), which is all an
        in-memory cache needs. The secret key is deliberately not part of the
        key, so nothing derived from it appears in logs.
        """
        fingerprint = repr((self.access_key, self.algorithm, self.token_type))
        return f"jwt:{self.access_key[:8]}:{hash(fingerprint) & 0xFFFFFFFFFFFFFFFF:016x}"
|
||
|
||
|
||
class OAuth2TokenStrategy(BaseTokenStrategy):
    """
    OAuth2 strategy: obtains an access token from ``token_url``.

    By default the request uses ``grant_type=client_credentials``; entries in
    ``extra_params`` are merged into the form body and may override it.
    """

    # Response fields already mapped onto TokenInfo attributes.
    _MAPPED_FIELDS = frozenset({"access_token", "expires_in", "token_type"})

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        token_url: str,
        scope: str | None = None,
        extra_params: dict[str, Any] | None = None,
    ):
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_url = token_url
        self.scope = scope
        self.extra_params = extra_params or {}

    async def generate(self) -> TokenInfo:
        """
        Request a new access token from the OAuth2 token endpoint.

        Returns:
            TokenInfo whose ``extra_data`` holds every response field other
            than access_token / expires_in / token_type.

        Raises:
            httpx.HTTPStatusError: the endpoint answered with an error status.
            KeyError: the response JSON has no ``access_token`` field.
            ValueError: ``expires_in`` is present but not numeric.
        """
        import httpx

        data = {
            "grant_type": "client_credentials",
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            **self.extra_params,
        }
        if self.scope:
            data["scope"] = self.scope

        async with httpx.AsyncClient() as client:
            response = await client.post(self.token_url, data=data)
            response.raise_for_status()
            result = response.json()

        access_token = result["access_token"]
        # expires_in is optional in RFC 6749 and some providers may send it as a
        # string or null; normalise to float seconds, defaulting to one hour.
        raw_expires_in = result.get("expires_in")
        expires_in = float(raw_expires_in) if raw_expires_in is not None else 3600.0
        token_type = result.get("token_type", "Bearer")

        return TokenInfo(
            token=access_token,
            expires_at=time.time() + expires_in,
            token_type=token_type,
            extra_data={
                k: v for k, v in result.items() if k not in self._MAPPED_FIELDS
            },
        )

    def get_cache_key(self) -> str:
        """
        Cache key unique to this client, endpoint, scope and extra parameters.

        The key used to be ``client_id[:8]`` plus a 4-digit url hash.
        Clients sharing an id prefix collided, and different scopes or
        ``extra_params`` (which can select a different principal) shared one
        token. The readable prefix is kept, followed by a process-salted
        ``hash()`` over all identifying fields. The client secret is excluded.
        """
        fingerprint = repr(
            (
                self.client_id,
                self.token_url,
                self.scope,
                sorted(self.extra_params.items()),
            )
        )
        return f"oauth2:{self.client_id[:8]}:{hash(fingerprint) & 0xFFFFFFFFFFFFFFFF:016x}"
|
||
|
||
|
||
class TokenManager:
    """
    Token manager: a process-wide singleton caching one token per strategy key.

    Tokens are cached under ``strategy.get_cache_key()``. A cached token is
    served until it comes within ``_safety_margin`` seconds of expiry; then the
    next request refreshes it. Refreshes for the same key are serialized by a
    per-key ``asyncio.Lock`` and re-check the cache, so concurrent callers
    trigger a single ``strategy.generate()``.

    With ``_preemptive_refresh`` enabled, every newly generated token also
    schedules a background task that refreshes it ahead of expiry, so callers
    normally do not wait on a refresh.

    Concurrency: locks are asyncio primitives created once in ``__init__``;
    share the singleton between coroutines of one event loop. asyncio.Lock is
    not thread-safe, and reuse across different event loops may fail.

    Usage:
        # JWT
        strategy = JWTTokenStrategy(access_key="xxx", secret_key="yyy")
        token = await TokenManager.get_instance().get_token(strategy)

        # OAuth2
        strategy = OAuth2TokenStrategy(
            client_id="xxx",
            client_secret="yyy",
            token_url="https://api.example.com/oauth2/token"
        )
        token = await TokenManager.get_instance().get_token(strategy)
    """

    _instance: TokenManager | None = None
    # Not referenced by this class; kept for backward compatibility.
    _lock: asyncio.Lock | None = None

    def __new__(cls) -> TokenManager:
        # Singleton: every TokenManager() call returns the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    @classmethod
    def get_instance(cls) -> TokenManager:
        """Return the shared singleton, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        # __init__ runs on every TokenManager() call; initialize state only once.
        if self._initialized:
            return

        # Token cache: {cache_key: TokenInfo}
        self._tokens: dict[str, TokenInfo] = {}

        # Per-key refresh locks: {cache_key: asyncio.Lock}
        self._refresh_locks: dict[str, asyncio.Lock] = {}

        # Guards creation of entries in _refresh_locks.
        self._global_lock = asyncio.Lock()

        # Strong references to pending background tasks (asyncio only keeps
        # weak references to tasks, so they must be held somewhere).
        self._background_tasks: set[asyncio.Task] = set()

        # Refresh policy
        self._safety_margin = 300  # refresh when within 5 minutes of expiry
        self._preemptive_refresh = True  # enable background pre-emptive refresh

        self._initialized = True

    async def get_token(
        self,
        strategy: BaseTokenStrategy,
        force_refresh: bool = False,
    ) -> TokenInfo:
        """
        Return a usable token for ``strategy``.

        Args:
            strategy: Token generation strategy; its cache key selects the slot.
            force_refresh: Fetch a new token even if the cached one is fresh.
                Concurrent forced refreshes of the same token are coalesced
                into a single ``generate()`` call.

        Returns:
            TokenInfo: the cached token while it is outside the safety margin,
            otherwise a newly generated one (or, if generation fails, the
            previous token as long as it has not expired).

        Raises:
            Whatever ``strategy.generate()`` raises when no unexpired cached
            token is available to fall back to.
        """
        cache_key = strategy.get_cache_key()
        cached = self._tokens.get(cache_key)

        # Fast path: serve a cached token that is not close to expiry.
        if (
            not force_refresh
            and cached is not None
            and not cached.is_near_expiry(self._safety_margin)
        ):
            logger.debug(f"Token cache hit for {cache_key}")
            return cached

        # For a forced refresh, mark the current token as stale so the
        # double-check in _refresh_token does not simply hand it back.
        return await self._refresh_token(strategy, stale=cached if force_refresh else None)

    async def get_token_string(
        self,
        strategy: BaseTokenStrategy,
        force_refresh: bool = False,
    ) -> str:
        """
        Shortcut returning the token as an Authorization header value.

        Returns:
            str: ``"<token_type> <token>"``, e.g. ``"Bearer abc..."``.
        """
        token_info = await self.get_token(strategy, force_refresh)
        return f"{token_info.token_type} {token_info.token}"

    async def _refresh_token(
        self,
        strategy: BaseTokenStrategy,
        stale: TokenInfo | None = None,
    ) -> TokenInfo:
        """
        Refresh the token for ``strategy`` under its per-key lock.

        After acquiring the lock the cache is checked again, so callers that
        queued behind an in-flight refresh reuse its result instead of calling
        ``generate()`` again.

        Args:
            strategy: Token generation strategy.
            stale: A token the caller already considers unusable (forced or
                pre-emptive refresh). If the cache still holds exactly this
                object it is replaced even when it is not near expiry; if it
                was already replaced by a fresh token, that token is returned.

        Returns:
            TokenInfo: the new token, one refreshed by another task, or - when
            generation fails - the previous token if it has not yet expired.

        Raises:
            Re-raises the error from ``strategy.generate()`` when there is no
            unexpired cached token to fall back to.
        """
        cache_key = strategy.get_cache_key()

        # Get or create the refresh lock dedicated to this cache key.
        async with self._global_lock:
            if cache_key not in self._refresh_locks:
                self._refresh_locks[cache_key] = asyncio.Lock()
            refresh_lock = self._refresh_locks[cache_key]

        async with refresh_lock:
            # Double-check: another task may have refreshed while we waited.
            cached = self._tokens.get(cache_key)
            if (
                cached is not None
                and cached is not stale
                and not cached.is_near_expiry(self._safety_margin)
            ):
                logger.debug(f"Token refreshed by another task for {cache_key}")
                return cached

            logger.info(f"Refreshing token for {cache_key}")
            try:
                new_token = await strategy.generate()
            except Exception as e:
                logger.error(f"Failed to refresh token for {cache_key}: {e}")
                # Refresh failed: keep serving the cached token while still valid.
                fallback = self._tokens.get(cache_key)
                if fallback is not None and not fallback.is_expired:
                    logger.warning(
                        f"Using cached token (not yet expired) for {cache_key} due to refresh failure"
                    )
                    return fallback
                raise

            self._tokens[cache_key] = new_token

            if self._preemptive_refresh:
                self._schedule_preemptive_refresh(strategy, new_token)

            logger.info(
                f"Token refreshed successfully for {cache_key}, expires in {new_token.expires_in:.0f}s"
            )
            return new_token

    def _schedule_preemptive_refresh(self, strategy: BaseTokenStrategy, token_info: TokenInfo):
        """
        Schedule a background refresh of ``token_info`` before it expires.

        The task wakes ``2 * _safety_margin`` seconds before expiry, or halfway
        through the remaining lifetime if that is later. A pre-emptive refresh
        therefore never happens more than twice per token lifetime, even for
        short-lived tokens. On waking it refreshes the token, unless that token
        was superseded or invalidated in the meantime.
        """
        cache_key = strategy.get_cache_key()

        remaining = token_info.expires_at - time.time()
        lead = min(self._safety_margin * 2, remaining / 2)
        delay = remaining - lead
        if delay <= 0:
            # Already expired on arrival: nothing sensible to pre-empt.
            logger.debug(f"Skipping preemptive refresh for {cache_key}: token already expired")
            return

        async def _refresh_task():
            await asyncio.sleep(delay)
            # A newer token was cached, or the entry was invalidated/cleared.
            if self._tokens.get(cache_key) is not token_info:
                return
            try:
                logger.info(f"Preemptive token refresh for {cache_key}")
                await self._refresh_token(strategy, stale=token_info)
            except Exception as e:
                logger.error(f"Preemptive refresh failed for {cache_key}: {e}")

        task = asyncio.create_task(_refresh_task())
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

        logger.debug(f"Scheduled preemptive refresh for {cache_key} in {delay:.0f}s")

    async def invalidate(self, strategy: BaseTokenStrategy) -> bool:
        """
        Drop the cached token for ``strategy``.

        Returns:
            bool: True if a cached token was removed, False if none was cached.
        """
        cache_key = strategy.get_cache_key()
        if cache_key in self._tokens:
            del self._tokens[cache_key]
            logger.info(f"Token cache invalidated for {cache_key}")
            return True
        return False

    def clear(self):
        """Drop every cached token (refresh locks and pending tasks are kept)."""
        self._tokens.clear()
        logger.info("All token caches cleared")

    def get_stats(self) -> dict[str, Any]:
        """Return cache statistics: counts plus per-key expiry information."""
        stats = {
            "total_cached": len(self._tokens),
            "active_tasks": len(self._background_tasks),
            "tokens": {},
        }

        for key, token_info in self._tokens.items():
            stats["tokens"][key] = {
                "expires_in": token_info.expires_in,
                "is_expired": token_info.is_expired,
                "is_near_expiry": token_info.is_near_expiry(self._safety_margin),
            }

        return stats
|
||
|
||
|
||
# 便捷函数
|
||
|
||
|
||
async def get_jwt_token(
    access_key: str,
    secret_key: str,
    expires_in: int = 1800,
    algorithm: str = "HS256",
) -> TokenInfo:
    """
    Return a JWT through the shared ``TokenManager`` singleton (cached when fresh).

    Example:
        token_info = await get_jwt_token("access_key", "secret_key")
        headers = {"Authorization": f"Bearer {token_info.token}"}
    """
    manager = TokenManager.get_instance()
    jwt_strategy = JWTTokenStrategy(
        access_key=access_key,
        secret_key=secret_key,
        expires_in=expires_in,
        algorithm=algorithm,
    )
    token_info = await manager.get_token(jwt_strategy)
    return token_info
|
||
|
||
|
||
async def get_oauth2_token(
    client_id: str,
    client_secret: str,
    token_url: str,
    scope: str | None = None,
) -> TokenInfo:
    """
    Return an OAuth2 access token through the shared ``TokenManager`` singleton.

    Example:
        token_info = await get_oauth2_token(
            client_id="xxx",
            client_secret="yyy",
            token_url="https://api.example.com/oauth2/token"
        )
        headers = {"Authorization": f"Bearer {token_info.token}"}
    """
    manager = TokenManager.get_instance()
    oauth_strategy = OAuth2TokenStrategy(
        client_id=client_id,
        client_secret=client_secret,
        token_url=token_url,
        scope=scope,
    )
    token_info = await manager.get_token(oauth_strategy)
    return token_info
|