"""
Global concurrency slot manager
===============================

Strictly atomic slot acquisition and release built on a Redis SET plus a Lua
script.
"""

import logging
from collections.abc import AsyncIterator, Awaitable
from contextlib import asynccontextmanager
from typing import cast

from redis.asyncio import Redis

logger = logging.getLogger(__name__)

# Default lifetime (seconds) of a slot set; refreshed on every successful acquire.
_DEFAULT_SLOT_TTL_SECONDS = 1200

# Lua script: atomically SADD -> SCARD -> conditional SREM -> EXPIRE.
#   KEYS[1]  slot set key
#   ARGV[1]  slot id
#   ARGV[2]  maximum number of members allowed in the set
#   ARGV[3]  TTL in seconds applied to the whole set on success
# Returns 1 when the slot is held after the call, 0 when the set is full.
_ACQUIRE_LUA = """
local key = KEYS[1]
local slot_id = ARGV[1]
local max_slots = tonumber(ARGV[2])
local ttl = tonumber(ARGV[3])

local added = redis.call('sadd', key, slot_id)
local count = redis.call('scard', key)
if count > max_slots then
    -- Roll back only a member this call inserted. If slot_id was already
    -- present it belongs to an existing holder and must not be evicted.
    if added == 1 then
        redis.call('srem', key, slot_id)
    end
    return 0
end
redis.call('expire', key, ttl)
return 1
"""


class SlotManager:
    """Global concurrency slot manager backed by a Redis set.

    Each held slot is one member (``slot_id``) of the set stored at
    ``slot_key``; the set's cardinality is the number of occupied slots.
    """

    def __init__(self, redis: Redis, ttl_seconds: int = _DEFAULT_SLOT_TTL_SECONDS):
        """
        Args:
            redis: Async Redis client used for all slot operations.
            ttl_seconds: Expiry applied to the whole slot set on every
                successful acquire. Slots leaked by a holder that never
                released them disappear only once the set goes this long
                without a successful acquire. Must be positive.

        Raises:
            ValueError: If ``ttl_seconds`` is not positive.
        """
        if ttl_seconds <= 0:
            # EXPIRE with a non-positive TTL deletes the key immediately,
            # which would drop every held slot.
            raise ValueError(f"ttl_seconds must be positive, got {ttl_seconds}")
        self.redis = redis
        self.ttl_seconds = ttl_seconds

    async def acquire(self, slot_key: str, slot_id: str, max_slots: int) -> bool:
        """Try to take one slot in ``slot_key`` for ``slot_id``.

        Re-acquiring an id that is already a member succeeds without taking a
        second slot, provided the set is not over capacity; an over-capacity
        re-acquire returns False but leaves the existing membership intact.

        Args:
            slot_key: Redis key of the slot set.
            slot_id: Identifier of the holder.
            max_slots: Maximum number of concurrently held slots.

        Returns:
            True if the slot is held; False if the set is full or the Redis
            call failed (errors are logged, never raised).
        """
        try:
            result = await cast(
                Awaitable[int],
                self.redis.eval(
                    _ACQUIRE_LUA,
                    1,
                    slot_key,
                    slot_id,
                    str(max_slots),
                    str(self.ttl_seconds),
                ),
            )
        except Exception as e:
            # Best-effort: a Redis failure is reported as "no slot" rather
            # than propagated to the caller.
            logger.warning("Slot acquire error: %s/%s: %s", slot_key, slot_id, e)
            return False

        acquired = result == 1
        if acquired:
            logger.debug("Slot acquired: %s/%s (max=%s)", slot_key, slot_id, max_slots)
        else:
            logger.debug("Slot full: %s/%s (max=%s)", slot_key, slot_id, max_slots)
        return acquired

    async def release(self, slot_key: str, slot_id: str) -> None:
        """Release the slot held by ``slot_id`` in ``slot_key``.

        Releasing an id that is not a member is a no-op. Redis errors are
        logged and swallowed; the slot then lingers until the set's TTL
        expires.
        """
        try:
            await cast(Awaitable[int], self.redis.srem(slot_key, slot_id))
        except Exception as e:
            logger.warning("Slot release error: %s/%s: %s", slot_key, slot_id, e)
            return
        logger.debug("Slot released: %s/%s", slot_key, slot_id)

    @asynccontextmanager
    async def acquire_ctx(
        self, slot_key: str, slot_id: str, max_slots: int
    ) -> AsyncIterator[bool]:
        """Acquire a slot for the duration of an ``async with`` block.

        Usage::

            async with slots.acquire_ctx(SLOT_KEY, task.task_id, max_slots) as acquired:
                if not acquired:
                    continue
                # business logic...

        Yields True when the slot was acquired (it is released on exit, even
        if the body raises), or False when it was not (nothing is released).
        """
        acquired = await self.acquire(slot_key, slot_id, max_slots)
        if not acquired:
            yield False
            return
        try:
            yield True
        finally:
            await self.release(slot_key, slot_id)

    async def count(self, slot_key: str) -> int:
        """Return the number of currently occupied slots in ``slot_key``.

        Returns 0, after logging a warning, if the Redis call fails.
        """
        try:
            return await cast(Awaitable[int], self.redis.scard(slot_key))
        except Exception as e:
            logger.warning("Slot count error: %s: %s", slot_key, e)
            return 0