redis 限流
redis 令牌桶限流
@dataclass(frozen=True)
class RateLimitResult:
allowed: bool
remaining_tokens: float
retry_after_seconds: float
now_ms: int
class RedisTokenBucketRateLimiter:
_TOKEN_BUCKET_LUA = r"""
local key = KEYS[1] -- 令牌桶状态存储的 Redis Key(Hash)
local capacity = tonumber(ARGV[1]) -- 桶容量:最多可累积的令牌数
local rate = tonumber(ARGV[2]) -- 补充速率:每秒补充多少令牌(tokens/s)
local requested = tonumber(ARGV[3]) -- 本次请求要消耗的令牌数
local ttl_ms = tonumber(ARGV[4]) -- 令牌桶 Key 的过期时间(毫秒),用于回收冷 Key
local t = redis.call('TIME') -- 读取 Redis 服务器时间:{秒, 微秒}
local now = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000) -- 服务器时间换算为毫秒时间戳
local data = redis.call('HMGET', key, 'tokens', 'ts') -- 读取当前令牌数(tokens)与上次更新时间(ts)
local tokens = tonumber(data[1]) -- 当前令牌数(可能为 nil)
local ts = tonumber(data[2]) -- 上次更新时间戳(毫秒,可能为 nil)
if tokens == nil then tokens = capacity end -- 初始化:若不存在,则认为桶是满的
if ts == nil then ts = now end -- 初始化:若不存在,则将上次更新时间设为当前
if now < ts then ts = now end -- 防御:若时钟回拨导致 now < ts,则强制对齐
local delta_ms = now - ts -- 距离上次更新过去了多少毫秒
local refill = (delta_ms / 1000.0) * rate -- 这段时间应补充的令牌数量(可为小数)
tokens = math.min(capacity, tokens + refill) -- 将令牌补充后并截断到容量上限
local allowed = 0 -- 是否允许:0/1(Lua 没有布尔返回给 Python 的统一类型)
local retry_after_ms = 0 -- 如果不允许,建议等待多少毫秒再试
if tokens >= requested then -- 若令牌充足
allowed = 1 -- 标记允许
tokens = tokens - requested -- 扣除本次消耗
else -- 否则令牌不足
allowed = 0 -- 标记拒绝
if rate > 0 then -- 若补充速率 > 0,可计算需要等待的时间
local missing = requested - tokens -- 还差多少令牌
retry_after_ms = math.ceil((missing / rate) * 1000) -- 差额按速率换算为等待毫秒数(向上取整)
else -- 若 rate=0,则永远补不回令牌
retry_after_ms = -1 -- 用 -1 表示无法通过等待获得令牌
end -- 结束 rate>0 分支
end -- 结束 tokens>=requested 分支
redis.call('HMSET', key, 'tokens', tokens, 'ts', now) -- 写回最新 tokens 与 ts(原子更新)
if ttl_ms ~= nil and ttl_ms > 0 then -- 若配置了过期时间且 >0
redis.call('PEXPIRE', key, ttl_ms) -- 给令牌桶 Key 设置过期(毫秒)
end -- 结束 ttl 分支
return {allowed, tokens, retry_after_ms, now} -- 返回:是否允许、剩余令牌、建议等待(ms)、当前时间(ms)
"""
def __init__(
self,
redis: Optional[AsyncRedis] = None,
*,
capacity: Union[int, float],
refill_rate: Union[int, float],
requested: Union[int, float] = 1,
ttl_ms: Optional[int] = None,
key_prefix: str = "rate_limit:token_bucket:",
) -> None:
capacity_f = float(capacity)
rate_f = float(refill_rate)
req_f = float(requested)
if capacity_f <= 0:
raise ValueError("capacity 必须 > 0")
if rate_f < 0:
raise ValueError("refill_rate 必须 >= 0")
if req_f <= 0:
raise ValueError("requested 必须 > 0")
if req_f > capacity_f:
raise ValueError("requested 不能大于 capacity")
self._redis = redis or async_redis_client
self._key_prefix = key_prefix
self._capacity = capacity_f
self._refill_rate = rate_f
self._requested = req_f
self._ttl_ms = int(ttl_ms) if ttl_ms is not None else None
def _full_key(self, key: str) -> str:
return f"{self._key_prefix}{key}"
@staticmethod
def _default_ttl_ms(
capacity: Union[int, float], refill_rate: Union[int, float]
) -> int:
# 让桶在“完全补满所需时间”的 2 倍后过期,避免大量冷 key 常驻
# 至少 5 秒,防止极小桶/极大 rate 造成频繁抖动
if refill_rate <= 0:
return 60_000
seconds = max(5.0, float(capacity) / float(refill_rate) * 2.0)
return int(seconds * 1000)
async def allow(
self,
key: str,
*,
requested: Optional[Union[int, float]] = None,
ttl_ms: Optional[int] = None,
) -> RateLimitResult:
"""
尝试拿令牌(令牌桶)。
- capacity/refill_rate/requested 默认来自 __init__
- requested/ttl_ms 允许在单次调用中覆盖
"""
capacity_f = self._capacity
rate_f = self._refill_rate
req_f = float(requested) if requested is not None else self._requested
if req_f <= 0:
raise ValueError("requested 必须 > 0")
if req_f > capacity_f:
raise ValueError("requested 不能大于 capacity")
ttl_ms_i = (
int(ttl_ms)
if ttl_ms is not None
else (
self._ttl_ms
if self._ttl_ms is not None
else self._default_ttl_ms(capacity_f, rate_f)
)
)
full_key = self._full_key(key)
res = await self._redis.eval(
self._TOKEN_BUCKET_LUA,
1,
full_key,
capacity_f,
rate_f,
req_f,
ttl_ms_i,
)
allowed = bool(int(res[0]))
remaining = float(res[1])
retry_after_ms = int(res[2])
now_ms = int(res[3])
retry_after_seconds = 0.0
if not allowed:
retry_after_seconds = (
0.0 if retry_after_ms <= 0 else retry_after_ms / 1000.0
)
return RateLimitResult(
allowed=allowed,
remaining_tokens=remaining,
retry_after_seconds=retry_after_seconds,
now_ms=now_ms,
)
async def wait(
self,
key: str,
*,
requested: Optional[Union[int, float]] = None,
ttl_ms: Optional[int] = None,
max_wait_seconds: Optional[float] = None,
poll_min_sleep: float = 0.01,
) -> RateLimitResult:
"""
一直等到拿到令牌为止(适合 Activity 里对外部依赖做限流)。
max_wait_seconds: 超过则抛 TimeoutError
"""
start = time.perf_counter()
while True:
r = await self.allow(
key,
requested=requested,
ttl_ms=ttl_ms,
)
if r.allowed:
return r
sleep_s = max(poll_min_sleep, r.retry_after_seconds)
if max_wait_seconds is not None:
elapsed = time.perf_counter() - start
if elapsed + sleep_s > max_wait_seconds:
raise TimeoutError(
f"rate limit wait timeout: key={key}, waited={elapsed:.3f}s, next_sleep={sleep_s:.3f}s"
)
await asyncio.sleep(sleep_s)
redis 滑动窗口限流
@dataclass(frozen=True)
class SlidingWindowResult:
allowed: bool
current_count: int
retry_after_seconds: float
now_ms: int
class RedisSlidingWindowRateLimiter:
"""
Redis 滑动窗口限流(ZSET)。
语义:任意连续 window_seconds 秒内最多 limit 次通过。
"""
_SLIDING_WINDOW_LUA = r"""
local key = KEYS[1] -- 滑动窗口限流使用的 ZSET Key
local window_ms = tonumber(ARGV[1]) -- 窗口长度(毫秒)
local limit = tonumber(ARGV[2]) -- 窗口内允许的最大通过次数
local ttl_ms = tonumber(ARGV[3]) -- Key 过期时间(毫秒),用于回收冷 Key
-- Redis server time (ms) -- 注释:以下用 Redis TIME 作为统一时钟,避免多机时间不一致
local t = redis.call('TIME') -- 读取 Redis 服务器时间:{秒, 微秒}
local now = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000) -- 转为毫秒时间戳
-- window cleanup -- 注释:删除窗口之外的旧请求记录
local start = now - window_ms -- 窗口起点(毫秒)
redis.call('ZREMRANGEBYSCORE', key, '-inf', start) -- 删除 score <= start 的成员(窗口外)
local cnt = tonumber(redis.call('ZCARD', key)) -- 统计窗口内当前已有多少次请求
if cnt < limit then -- 若窗口内次数仍未达到上限(严格 <,避免多放 1 次)
-- unique member to avoid overwrite under same ms -- 注释:member 唯一化,避免同毫秒覆盖导致少计数
local seq = redis.call('INCR', key .. ':seq') -- 递增序列号(辅助唯一 member)
local member = tostring(now) .. ':' .. tostring(seq) -- member = now:seq
redis.call('ZADD', key, now, member) -- 写入本次请求记录:score=now(member 的时间)
if ttl_ms ~= nil and ttl_ms > 0 then -- 若配置了过期时间且 >0
redis.call('PEXPIRE', key, ttl_ms) -- 给 ZSET Key 设置过期时间
redis.call('PEXPIRE', key .. ':seq', ttl_ms) -- 给序列号 Key 也设置过期时间
end -- 结束 ttl 分支
return {1, cnt + 1, 0, now} -- 返回:允许(1)、窗口内计数(含本次)、建议等待(ms=0)、now(ms)
else -- 否则窗口内次数已满,需要拒绝
-- compute retry_after_ms: when the oldest entry exits window -- 注释:计算何时最早记录滑出窗口
local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES') -- 取窗口内最早的一条(score 最小)
local oldest_ts = nil -- 最早记录的时间戳(毫秒)
if oldest ~= nil and #oldest >= 2 then -- 若存在成员且带 score
oldest_ts = tonumber(oldest[2]) -- oldest[2] 为 score(ZSET 的时间戳)
end -- 结束 oldest 解析
local retry_after = 0 -- 建议等待的毫秒数
if oldest_ts ~= nil then -- 若能拿到最早时间戳
retry_after = math.max(0, math.ceil((oldest_ts + window_ms) - now)) -- 等到 oldest_ts+window_ms 才会滑出
else -- 极端情况:ZSET 为空但 cnt>=limit(理论不应发生)
retry_after = math.ceil(window_ms) -- 保守等待一个窗口长度
end -- 结束 oldest_ts 分支
if ttl_ms ~= nil and ttl_ms > 0 then -- 若配置了过期时间且 >0
redis.call('PEXPIRE', key, ttl_ms) -- 维持 key 的过期时间(避免热 key 被误删)
redis.call('PEXPIRE', key .. ':seq', ttl_ms) -- 维持 seq key 的过期时间
end -- 结束 ttl 分支
return {0, cnt, retry_after, now} -- 返回:拒绝(0)、当前窗口计数、建议等待(ms)、now(ms)
end -- 结束 if cnt < limit 分支
"""
def __init__(
self,
redis: Optional[AsyncRedis] = None,
*,
limit: int,
window_seconds: float,
ttl_ms: Optional[int] = None,
key_prefix: str = "rate_limit:sliding_window:",
) -> None:
if limit <= 0:
raise ValueError("limit 必须 > 0")
if window_seconds <= 0:
raise ValueError("window_seconds 必须 > 0")
self._redis = redis or async_redis_client
self._key_prefix = key_prefix
self._limit = int(limit)
self._window_seconds = float(window_seconds)
self._ttl_ms = int(ttl_ms) if ttl_ms is not None else None
def _full_key(self, key: str) -> str:
return f"{self._key_prefix}{key}"
@staticmethod
def _default_ttl_ms(window_seconds: float) -> int:
# key 存活时间略大于窗口,方便清理且避免冷 key 常驻
return int(max(5.0, window_seconds * 2.0) * 1000)
async def allow(
self,
key: str,
*,
ttl_ms: Optional[int] = None,
) -> SlidingWindowResult:
window_ms = int(self._window_seconds * 1000)
ttl_ms_i = (
int(ttl_ms)
if ttl_ms is not None
else (
self._ttl_ms
if self._ttl_ms is not None
else self._default_ttl_ms(self._window_seconds)
)
)
res = await self._redis.eval(
self._SLIDING_WINDOW_LUA,
1,
self._full_key(key),
window_ms,
self._limit,
ttl_ms_i,
)
allowed = bool(int(res[0]))
current_count = int(res[1])
retry_after_ms = int(res[2])
now_ms = int(res[3])
return SlidingWindowResult(
allowed=allowed,
current_count=current_count,
retry_after_seconds=retry_after_ms / 1000.0 if retry_after_ms > 0 else 0.0,
now_ms=now_ms,
)
async def wait(
self,
key: str,
*,
ttl_ms: Optional[int] = None,
max_wait_seconds: Optional[float] = None,
poll_min_sleep: float = 0.01,
) -> SlidingWindowResult:
start = asyncio.get_running_loop().time()
while True:
r = await self.allow(
key,
ttl_ms=ttl_ms,
)
if r.allowed:
return r
sleep_s = max(poll_min_sleep, r.retry_after_seconds)
if max_wait_seconds is not None:
elapsed = asyncio.get_running_loop().time() - start
if elapsed + sleep_s > max_wait_seconds:
raise TimeoutError(
f"sliding window wait timeout: key={key}, waited={elapsed:.3f}s, next_sleep={sleep_s:.3f}s"
)
await asyncio.sleep(sleep_s)
Comments | NOTHING