"""
Token manager: generic caching and automatic refresh of API authentication tokens.

Supported token kinds:
- JWT tokens (e.g. KlingAI)
- OAuth2 access tokens
- Custom token types (via a custom strategy)

Features:
- Token cache shared by the coroutines of one event loop (asyncio locks, not thread locks)
- Automatic refresh with a safety margin before expiry
- Background (preemptive) refresh
- Isolation of multiple providers/credentials through per-strategy cache keys
"""

from __future__ import annotations

import asyncio
import logging
import time
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from dataclasses import field
from typing import Any
from typing import Protocol

logger = logging.getLogger(__name__)


@dataclass
class TokenInfo:
    """Holds one issued token together with its expiry metadata."""

    token: str
    expires_at: float  # absolute expiry time as a Unix timestamp, in seconds
    token_type: str = "Bearer"
    extra_data: dict[str, Any] = field(default_factory=dict)

    @property
    def is_expired(self) -> bool:
        """True once the wall clock has reached ``expires_at``."""
        return self.expires_at <= time.time()

    @property
    def expires_in(self) -> float:
        """Seconds of validity left; never negative (clamped to 0)."""
        remaining = self.expires_at - time.time()
        return remaining if remaining > 0 else 0

    def is_near_expiry(self, safety_margin: float = 300) -> bool:
        """
        Tell whether the token is inside its refresh window.

        Args:
            safety_margin: How many seconds before ``expires_at`` the token already
                counts as due for refresh (default 300, i.e. 5 minutes).
        """
        refresh_deadline = self.expires_at - safety_margin
        return time.time() >= refresh_deadline


class TokenGenerator(Protocol):
    """Structural type for an async callable that produces a fresh ``TokenInfo``."""

    async def __call__(self) -> TokenInfo:
        """Create or fetch a new token."""
        ...
class BaseTokenStrategy(ABC):
    """Base class for token strategies: how to obtain a token and how to key its cache entry."""

    @abstractmethod
    async def generate(self) -> TokenInfo:
        """Create or fetch a brand-new token."""

    @abstractmethod
    def get_cache_key(self) -> str:
        """Return the key that isolates this strategy's token from other credentials."""

    @staticmethod
    def _fingerprint(*parts: object) -> str:
        """
        Return a short, stable digest of ``parts`` for use inside cache keys.

        SHA-256 is used instead of truncating identifiers or calling ``hash()``.
        ``hash()`` is randomized per process. A digest means distinct credentials do not
        collide in practice, and secrets only ever enter a cache key (and therefore log
        lines) in hashed form.
        """
        import hashlib

        return hashlib.sha256(repr(parts).encode("utf-8")).hexdigest()[:16]


class JWTTokenStrategy(BaseTokenStrategy):
    """JWT strategy (used e.g. for KlingAI): signs a token locally with ``secret_key``."""

    def __init__(
        self,
        access_key: str,
        secret_key: str,
        expires_in: int = 1800,
        algorithm: str = "HS256",
        token_type: str = "JWT",
    ):
        """
        Args:
            access_key: Issuer identifier, written to the ``iss`` claim.
            secret_key: Signing key.
            expires_in: Token lifetime in seconds (default 1800, i.e. 30 minutes).
            algorithm: JWT signing algorithm.
            token_type: Value for the JWT ``typ`` header.
        """
        self.access_key = access_key
        self.secret_key = secret_key
        self.expires_in = expires_in
        self.algorithm = algorithm
        self.token_type = token_type

    async def generate(self) -> TokenInfo:
        """Sign a new JWT with ``iss``, ``exp`` and a slightly back-dated ``nbf``."""
        from jose import jwt

        headers = {"alg": self.algorithm, "typ": self.token_type}
        current_time = int(time.time())
        payload = {
            "iss": self.access_key,
            "exp": current_time + self.expires_in,
            # Back-date by 5s to tolerate small clock skew on the server side.
            "nbf": current_time - 5,
        }
        token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm, headers=headers)
        return TokenInfo(
            token=token,
            expires_at=current_time + self.expires_in,
            token_type="Bearer",
        )

    def get_cache_key(self) -> str:
        """
        Return ``jwt:<access_key prefix>:<digest>``.

        The readable prefix helps in logs. The digest covers the full access key, the
        secret and the algorithm. Keys that merely share their first 8 characters, or a
        rotated secret, therefore get separate cache entries.
        """
        digest = self._fingerprint(self.access_key, self.secret_key, self.algorithm)
        return f"jwt:{self.access_key[:8]}:{digest}"


class OAuth2TokenStrategy(BaseTokenStrategy):
    """OAuth2 client-credentials strategy: fetches a token from ``token_url``."""

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        token_url: str,
        scope: str | None = None,
        extra_params: dict[str, Any] | None = None,
    ):
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_url = token_url
        self.scope = scope
        self.extra_params = extra_params or {}

    async def generate(self) -> TokenInfo:
        """
        Fetch a token using the ``client_credentials`` grant.

        Raises:
            httpx.HTTPStatusError: if the token endpoint returns an error status.
            KeyError: if the response JSON has no ``access_token``.
            ValueError: if ``expires_in`` is present but not numeric.
        """
        import httpx

        data = {
            "grant_type": "client_credentials",
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            **self.extra_params,
        }
        if self.scope:
            data["scope"] = self.scope

        async with httpx.AsyncClient() as client:
            response = await client.post(self.token_url, data=data)
            response.raise_for_status()
            result = response.json()

        access_token = result["access_token"]
        # expires_in is optional in RFC 6749, and some servers send null or a numeric
        # string. Default to 1 hour and coerce to float so expiry arithmetic never fails.
        raw_expires_in = result.get("expires_in")
        expires_in = 3600.0 if raw_expires_in is None else float(raw_expires_in)
        token_type = result.get("token_type", "Bearer")

        return TokenInfo(
            token=access_token,
            expires_at=time.time() + expires_in,
            token_type=token_type,
            extra_data={
                k: v
                for k, v in result.items()
                if k not in ["access_token", "expires_in", "token_type"]
            },
        )

    def get_cache_key(self) -> str:
        """
        Return ``oauth2:<client_id prefix>:<digest>``.

        The digest covers every input that changes which token the server issues: client
        id and secret, token URL, scope and extra form parameters. Without this,
        different scopes for the same client would share one cached token.
        """
        params = tuple(sorted(self.extra_params.items(), key=lambda item: str(item[0])))
        digest = self._fingerprint(
            self.client_id, self.client_secret, self.token_url, self.scope, params
        )
        return f"oauth2:{self.client_id[:8]}:{digest}"


class TokenManager:
    """
    Token manager: a process-wide singleton caching one token per strategy cache key.

    Example::

        # JWT (KlingAI)
        strategy = JWTTokenStrategy(access_key="xxx", secret_key="yyy")
        token = await TokenManager.get_instance().get_token(strategy)

        # OAuth2
        strategy = OAuth2TokenStrategy(
            client_id="xxx",
            client_secret="yyy",
            token_url="https://api.example.com/oauth2/token",
        )
        token = await TokenManager.get_instance().get_token(strategy)
    """

    _instance: TokenManager | None = None
    _lock: asyncio.Lock | None = None  # kept for backward compatibility; unused internally

    def __new__(cls) -> TokenManager:
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    @classmethod
    def get_instance(cls) -> TokenManager:
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        # __new__ always returns the singleton, so __init__ runs on every TokenManager()
        # call; only initialise state the first time.
        if self._initialized:
            return
        # Token cache: {cache_key: TokenInfo}
        self._tokens: dict[str, TokenInfo] = {}
        # Per-key refresh locks: {cache_key: asyncio.Lock}
        self._refresh_locks: dict[str, asyncio.Lock] = {}
        # Guards creation of entries in _refresh_locks.
        self._global_lock = asyncio.Lock()
        # Strong references to background tasks so they are not garbage-collected.
        self._background_tasks: set[asyncio.Task] = set()
        # At most one pending preemptive-refresh task per cache key.
        self._preemptive_tasks: dict[str, asyncio.Task] = {}
        self._safety_margin = 300  # refresh tokens 5 minutes before expiry
        self._preemptive_refresh = True  # enable background refresh
        self._initialized = True

    async def get_token(
        self,
        strategy: BaseTokenStrategy,
        force_refresh: bool = False,
    ) -> TokenInfo:
        """
        Return a valid token for ``strategy``, refreshing it if needed.

        Args:
            strategy: Token strategy.
            force_refresh: Generate a new token even if the cached one is still valid.

        Returns:
            TokenInfo: A token outside the safety margin. If refreshing fails, an
            unexpired cached token may be returned instead.
        """
        cache_key = strategy.get_cache_key()
        cached = self._tokens.get(cache_key)

        if (
            not force_refresh
            and cached is not None
            and not cached.is_near_expiry(self._safety_margin)
        ):
            logger.debug("Token cache hit for %s", cache_key)
            return cached

        return await self._refresh_token(strategy, force=force_refresh, stale=cached)

    async def get_token_string(
        self,
        strategy: BaseTokenStrategy,
        force_refresh: bool = False,
    ) -> str:
        """
        Shortcut returning ``"<token_type> <token>"``, ready for an Authorization header.

        The prefix is the token's ``token_type`` (``"Bearer"`` for the built-in JWT
        strategy; whatever the server sent for OAuth2).
        """
        token_info = await self.get_token(strategy, force_refresh)
        return f"{token_info.token_type} {token_info.token}"

    async def _refresh_token(
        self,
        strategy: BaseTokenStrategy,
        force: bool = False,
        stale: TokenInfo | None = None,
    ) -> TokenInfo:
        """
        Refresh the token for ``strategy`` under a per-key lock.

        The cache is re-checked after the lock is acquired. When several callers wait at
        once, only the first one calls ``strategy.generate()``.

        Args:
            strategy: Token strategy to refresh.
            force: Regenerate even if the cached token is still outside the safety margin.
                Without this, ``force_refresh`` and preemptive refreshes would be answered
                from the cache by the re-check.
            stale: The token the caller saw (and wants replaced) before waiting on the lock.
                With ``force``, a cached token other than ``stale`` was produced by a
                concurrent refresh. It is returned instead of generating yet another one.

        Raises:
            Exception: Whatever ``strategy.generate()`` raised, when no unexpired cached
                token is available as a fallback.
        """
        cache_key = strategy.get_cache_key()

        async with self._global_lock:
            if cache_key not in self._refresh_locks:
                self._refresh_locks[cache_key] = asyncio.Lock()
            refresh_lock = self._refresh_locks[cache_key]

        async with refresh_lock:
            current = self._tokens.get(cache_key)
            if (
                current is not None
                and not current.is_near_expiry(self._safety_margin)
                and (not force or current is not stale)
            ):
                logger.debug("Token refreshed by another task for %s", cache_key)
                return current

            logger.info("Refreshing token for %s", cache_key)
            try:
                new_token = await strategy.generate()
            except Exception as e:
                logger.error("Failed to refresh token for %s: %s", cache_key, e)
                # Best effort: a token that has not actually expired is still usable.
                cached = self._tokens.get(cache_key)
                if cached is not None and not cached.is_expired:
                    logger.warning(
                        "Using still-valid cached token for %s due to refresh failure",
                        cache_key,
                    )
                    return cached
                raise

            self._tokens[cache_key] = new_token
            if self._preemptive_refresh:
                self._schedule_preemptive_refresh(strategy, new_token)
            logger.info(
                "Token refreshed successfully for %s, expires in %.0fs",
                cache_key,
                new_token.expires_in,
            )
            return new_token

    @staticmethod
    def _preemptive_delay(
        remaining: float,
        safety_margin: float,
        min_delay: float = 1.0,
    ) -> float | None:
        """
        Return how long to wait before refreshing a token with ``remaining`` seconds left.

        Normally the refresh fires ``2 * safety_margin`` before expiry. Tokens with less
        life left than that are refreshed half-way through their remaining lifetime
        instead. The wait never drops below ``min_delay``. Otherwise every refresh of a
        short-lived token would schedule another zero-delay refresh, spinning the event
        loop and hammering the token endpoint.

        Returns:
            The delay in seconds, or ``None`` if the token is already expired and no
            background refresh should be scheduled.
        """
        if remaining <= 0:
            return None
        delay = remaining - safety_margin * 2
        if delay <= 0:
            delay = remaining / 2
        return max(delay, min_delay)

    def _schedule_preemptive_refresh(self, strategy: BaseTokenStrategy, token_info: TokenInfo):
        """
        Schedule a background refresh of ``token_info`` before it expires.

        Any previously scheduled refresh for the same key is cancelled, so at most one is
        pending per key. When the task wakes up, it regenerates only if ``token_info`` is
        still the cached token, i.e. it was not replaced or invalidated in the meantime.
        """
        cache_key = strategy.get_cache_key()

        previous = self._preemptive_tasks.pop(cache_key, None)
        # The preemptive task itself reaches this point through _refresh_token; it must not
        # cancel itself.
        if (
            previous is not None
            and not previous.done()
            and previous is not asyncio.current_task()
        ):
            previous.cancel()

        delay = self._preemptive_delay(token_info.expires_in, self._safety_margin)
        if delay is None:
            logger.debug("Skipping preemptive refresh for %s: token already expired", cache_key)
            return

        async def _refresh_task():
            await asyncio.sleep(delay)
            if self._tokens.get(cache_key) is not token_info:
                # Already replaced by another refresh, or invalidated: nothing to do.
                return
            try:
                logger.info("Preemptive token refresh for %s", cache_key)
                # force=True: the token may still be outside the safety margin here. A
                # plain refresh would return it from the cache and never renew it.
                await self._refresh_token(strategy, force=True, stale=token_info)
            except Exception as e:
                logger.error("Preemptive refresh failed for %s: %s", cache_key, e)

        task = asyncio.create_task(_refresh_task())
        self._background_tasks.add(task)
        self._preemptive_tasks[cache_key] = task

        def _on_done(finished: asyncio.Task) -> None:
            self._background_tasks.discard(finished)
            # Only drop the mapping if a newer task has not already taken this key.
            if self._preemptive_tasks.get(cache_key) is finished:
                del self._preemptive_tasks[cache_key]

        task.add_done_callback(_on_done)
        logger.debug("Scheduled preemptive refresh for %s in %.0fs", cache_key, delay)

    async def invalidate(self, strategy: BaseTokenStrategy) -> bool:
        """
        Drop the cached token for ``strategy``.

        A preemptive refresh still sleeping for this token will skip its work on wake-up,
        because that token is no longer cached.

        Returns:
            bool: True if a cached token was removed, False if none was cached.
        """
        cache_key = strategy.get_cache_key()
        if self._tokens.pop(cache_key, None) is None:
            return False
        logger.info("Token cache invalidated for %s", cache_key)
        return True

    def clear(self):
        """Drop every cached token (pending preemptive refreshes then skip their work)."""
        self._tokens.clear()
        logger.info("All token caches cleared")

    def get_stats(self) -> dict[str, Any]:
        """Return cache statistics: entry count, background task count and per-key expiry."""
        return {
            "total_cached": len(self._tokens),
            "active_tasks": len(self._background_tasks),
            "tokens": {
                key: {
                    "expires_in": token_info.expires_in,
                    "is_expired": token_info.is_expired,
                    "is_near_expiry": token_info.is_near_expiry(self._safety_margin),
                }
                for key, token_info in self._tokens.items()
            },
        }


# Convenience functions
async def get_jwt_token(
    access_key: str,
    secret_key: str,
    expires_in: int = 1800,
    algorithm: str = "HS256",
) -> TokenInfo:
    """
    Get a JWT token through the global TokenManager.

    Example::

        token_info = await get_jwt_token("access_key", "secret_key")
        headers = {"Authorization": f"Bearer {token_info.token}"}
    """
    strategy = JWTTokenStrategy(
        access_key=access_key,
        secret_key=secret_key,
        expires_in=expires_in,
        algorithm=algorithm,
    )
    return await TokenManager.get_instance().get_token(strategy)


async def get_oauth2_token(
    client_id: str,
    client_secret: str,
    token_url: str,
    scope: str | None = None,
    extra_params: dict[str, Any] | None = None,
) -> TokenInfo:
    """
    Get an OAuth2 client-credentials token through the global TokenManager.

    Args:
        extra_params: Optional extra form fields sent to the token endpoint. They are
            forwarded to ``OAuth2TokenStrategy`` and are part of its cache key.

    Example::

        token_info = await get_oauth2_token(
            client_id="xxx",
            client_secret="yyy",
            token_url="https://api.example.com/oauth2/token",
        )
        headers = {"Authorization": f"Bearer {token_info.token}"}
    """
    strategy = OAuth2TokenStrategy(
        client_id=client_id,
        client_secret=client_secret,
        token_url=token_url,
        scope=scope,
        extra_params=extra_params,
    )
    return await TokenManager.get_instance().get_token(strategy)