feat: init meijiaka-zj project from ai-meijiaka template
This commit is contained in:
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
TokenManager 使用示例
|
||||
|
||||
展示如何在 Provider 中使用 TokenManager 来管理认证 Token。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from app.core.token_manager import (
|
||||
JWTTokenStrategy,
|
||||
OAuth2TokenStrategy,
|
||||
TokenManager,
|
||||
get_jwt_token,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def example_jwt():
|
||||
"""JWT Token 示例(KlingAI 模式)"""
|
||||
print("=" * 60)
|
||||
print("JWT Token 示例 (KlingAI)")
|
||||
print("=" * 60)
|
||||
|
||||
# 方法1: 使用便捷函数(推荐简单场景)
|
||||
try:
|
||||
token_info = await get_jwt_token(
|
||||
access_key="test_access_key",
|
||||
secret_key="test_secret_key",
|
||||
)
|
||||
print(f"Token: {token_info.token[:50]}...")
|
||||
print(f"Expires in: {token_info.expires_in:.0f} seconds")
|
||||
print(f"Is expired: {token_info.is_expired}")
|
||||
except Exception as e:
|
||||
print(f"JWT generation failed (expected in demo): {e}")
|
||||
|
||||
# 方法2: 使用 TokenManager + Strategy(推荐 Provider 集成)
|
||||
strategy = JWTTokenStrategy(
|
||||
access_key="your_access_key",
|
||||
secret_key="your_secret_key",
|
||||
expires_in=1800, # 30分钟
|
||||
)
|
||||
|
||||
# 第一次获取会生成新 token
|
||||
token1 = await TokenManager.get_instance().get_token(strategy)
|
||||
print(f"\nFirst token: {token1.token[:30]}...")
|
||||
|
||||
# 第二次获取会命中缓存(如果未过期)
|
||||
token2 = await TokenManager.get_instance().get_token(strategy)
|
||||
print(f"Second token: {token2.token[:30]}...")
|
||||
print(f"Same token: {token1.token == token2.token}")
|
||||
|
||||
# 查看缓存统计
|
||||
stats = TokenManager.get_instance().get_stats()
|
||||
print(f"\nCache stats: {stats}")
|
||||
|
||||
|
||||
async def example_oauth2():
|
||||
"""OAuth2 Token 示例"""
|
||||
print("\n" + "=" * 60)
|
||||
print("OAuth2 Token 示例")
|
||||
print("=" * 60)
|
||||
|
||||
strategy = OAuth2TokenStrategy(
|
||||
client_id="your_client_id",
|
||||
client_secret="your_client_secret",
|
||||
token_url="https://api.example.com/oauth2/token",
|
||||
scope="read write",
|
||||
)
|
||||
|
||||
print("OAuth2 strategy created")
|
||||
print(f"Cache key: {strategy.get_cache_key()}")
|
||||
|
||||
|
||||
async def example_provider_integration():
|
||||
"""Provider 集成示例"""
|
||||
print("\n" + "=" * 60)
|
||||
print("Provider 集成示例")
|
||||
print("=" * 60)
|
||||
|
||||
# 这是一个模拟的 Provider 类
|
||||
class ExampleProvider:
|
||||
def __init__(self, access_key: str, secret_key: str):
|
||||
self.access_key = access_key
|
||||
self.secret_key = secret_key
|
||||
self._token_strategy = JWTTokenStrategy(
|
||||
access_key=access_key,
|
||||
secret_key=secret_key,
|
||||
expires_in=1800,
|
||||
)
|
||||
|
||||
async def _get_headers(self) -> dict[str, str]:
|
||||
"""获取带认证的请求头"""
|
||||
token_info = await TokenManager.get_instance().get_token(self._token_strategy)
|
||||
return {
|
||||
"Authorization": f"Bearer {token_info.token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async def make_request(self):
|
||||
"""模拟 API 请求"""
|
||||
headers = await self._get_headers()
|
||||
print(f"Request headers: {headers}")
|
||||
# 实际使用时: await session.post(url, headers=headers, ...)
|
||||
|
||||
provider = ExampleProvider("access_key_123", "secret_key_456")
|
||||
await provider.make_request()
|
||||
|
||||
|
||||
async def example_concurrent_requests():
|
||||
"""并发请求示例 - 测试 token 刷新时的并发安全"""
|
||||
print("\n" + "=" * 60)
|
||||
print("并发请求示例")
|
||||
print("=" * 60)
|
||||
|
||||
strategy = JWTTokenStrategy(
|
||||
access_key="concurrent_test_key",
|
||||
secret_key="concurrent_test_secret",
|
||||
expires_in=1800,
|
||||
)
|
||||
|
||||
async def request_task(task_id: int):
|
||||
"""模拟单个请求"""
|
||||
token_info = await TokenManager.get_instance().get_token(strategy)
|
||||
print(f"Task {task_id}: got token (expires in {token_info.expires_in:.0f}s)")
|
||||
return token_info
|
||||
|
||||
# 并发10个请求,应该只触发一次 token 生成
|
||||
print("Launching 10 concurrent requests...")
|
||||
results = await asyncio.gather(*[request_task(i) for i in range(10)])
|
||||
|
||||
# 验证所有请求拿到的是同一个 token
|
||||
tokens = [r.token for r in results]
|
||||
unique_tokens = set(tokens)
|
||||
print(f"\nTotal requests: {len(tokens)}")
|
||||
print(f"Unique tokens generated: {len(unique_tokens)}")
|
||||
print(f"Concurrent safety: {'✓ PASS' if len(unique_tokens) == 1 else '✗ FAIL'}")
|
||||
|
||||
|
||||
async def example_stats():
|
||||
"""查看 TokenManager 统计信息"""
|
||||
print("\n" + "=" * 60)
|
||||
print("TokenManager 统计")
|
||||
print("=" * 60)
|
||||
|
||||
manager = TokenManager.get_instance()
|
||||
stats = manager.get_stats()
|
||||
|
||||
print(f"Total cached tokens: {stats['total_cached']}")
|
||||
print(f"Active background tasks: {stats['active_tasks']}")
|
||||
print(f"Token details: {stats['tokens']}")
|
||||
|
||||
|
||||
async def main():
|
||||
"""运行所有示例"""
|
||||
await example_jwt()
|
||||
await example_oauth2()
|
||||
await example_provider_integration()
|
||||
await example_concurrent_requests()
|
||||
await example_stats()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("所有示例完成")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user