"""
全局并发槽位管理器
==================

基于 Redis SET + Lua 脚本实现严格原子的槽位申请与释放。
"""
import logging
from collections.abc import AsyncIterator, Awaitable
from contextlib import asynccontextmanager
from typing import cast

from redis.asyncio import Redis

logger = logging.getLogger(__name__)

# Lua script: atomically run SADD -> SCARD -> conditional SREM.
#
# KEYS[1]: key of the Redis set holding occupied slot ids
# ARGV[1]: slot id to add
# ARGV[2]: maximum number of members allowed in the set
#
# Returns 1 when the slot id is in the set and the set is within capacity,
# 0 when the set is over capacity. The rollback SREM only runs when this call
# actually added the id (SADD returned 1). A repeated acquire of an id that is
# already present therefore never evicts that existing entry.
# On success the TTL of the whole set key is reset to 1200 seconds. EXPIRE
# applies to the key, not to individual members.
_ACQUIRE_LUA = """
local key = KEYS[1]
local slot_id = ARGV[1]
local max_slots = tonumber(ARGV[2])
local added = redis.call('sadd', key, slot_id)
local count = redis.call('scard', key)
if count > max_slots then
    if added == 1 then
        redis.call('srem', key, slot_id)
    end
    return 0
end
redis.call('expire', key, 1200)
return 1
"""

class SlotManager:
    """Global concurrency-slot manager backed by a Redis set.

    Each occupied slot is a member of the set stored at ``slot_key``.
    Acquisition runs the ``_ACQUIRE_LUA`` script in a single EVAL call.
    Release is a single ``SREM``.
    """

    def __init__(self, redis: Redis):
        # Async Redis client used for every slot operation.
        self.redis = redis

    async def acquire(self, slot_key: str, slot_id: str, max_slots: int) -> bool:
        """Try to acquire a slot.

        Args:
            slot_key: Redis key of the set holding occupied slot ids.
            slot_id: Identifier recorded in the set for this holder.
            max_slots: Maximum number of members the set may hold.

        Returns:
            True if the slot was acquired. False if the set is full or the
            Redis call failed. Errors derived from ``Exception`` are logged
            and swallowed, never raised.
        """
        try:
            # The script replies with a Lua integer (1 = acquired, 0 = full),
            # which the client returns as a Python int.
            result = await cast(
                Awaitable[int],
                self.redis.eval(_ACQUIRE_LUA, 1, slot_key, slot_id, str(max_slots)),
            )
        except Exception as e:
            # Fail closed: any Redis/script error counts as "not acquired".
            logger.warning("Slot acquire error: %s/%s: %s", slot_key, slot_id, e)
            return False

        acquired = result == 1
        if acquired:
            logger.debug("Slot acquired: %s/%s (max=%s)", slot_key, slot_id, max_slots)
        else:
            logger.debug("Slot full: %s/%s (max=%s)", slot_key, slot_id, max_slots)
        return acquired

    async def release(self, slot_key: str, slot_id: str) -> None:
        """Release a slot by removing ``slot_id`` from the set at ``slot_key``.

        Errors derived from ``Exception`` are logged and swallowed, so cleanup
        never raises them.
        """
        try:
            await cast(Awaitable[int], self.redis.srem(slot_key, slot_id))
        except Exception as e:
            logger.warning("Slot release error: %s/%s: %s", slot_key, slot_id, e)
            return
        logger.debug("Slot released: %s/%s", slot_key, slot_id)

    @asynccontextmanager
    async def acquire_ctx(
        self, slot_key: str, slot_id: str, max_slots: int
    ) -> AsyncIterator[bool]:
        """Acquire a slot for the duration of an ``async with`` block.

        Yields True if the slot was acquired. In that case the slot is released
        when the block exits, including on error. Yields False if it was not
        acquired, and nothing is released.

        Usage::

            async with slots.acquire_ctx(SLOT_KEY, task.task_id, max_slots) as acquired:
                if not acquired:
                    continue
                # business logic... released automatically on exit
        """
        # NOTE(review): if the task is cancelled while awaiting the EVAL reply,
        # the script may already have added the id on the server. That slot then
        # stays occupied until the set's TTL expires. Confirm this is acceptable.
        acquired = await self.acquire(slot_key, slot_id, max_slots)
        if not acquired:
            yield False
            return
        try:
            yield True
        finally:
            await self.release(slot_key, slot_id)

    async def count(self, slot_key: str) -> int:
        """Return the number of occupied slots, or 0 if the Redis call fails."""
        try:
            return await cast(Awaitable[int], self.redis.scard(slot_key))
        except Exception as e:
            # Best-effort read, but log instead of failing silently.
            logger.warning("Slot count error: %s: %s", slot_key, e)
            return 0