Files
2026-05-24 22:59:24 +07:00

103 lines
2.6 KiB
Go

// Package ratelimit implements a Redis-backed sliding-window limiter.
//
// We keep one sorted set per workspace, where each member's score is the
// request's unix timestamp in milliseconds. On each request we:
//  1. ZREMRANGEBYSCORE -- evict entries older than the window
//  2. ZCARD -- count the requests still inside the window
//  3. if count < limit: ZADD + PEXPIRE, allow
//  4. else: compute retry-after from the stored timestamps, deny
//
// All four steps run inside a single Lua script, so they execute atomically.
package ratelimit
import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"fmt"
	"strconv"
	"time"

	"github.com/redis/rueidis"
)
// Decision is the outcome of a single Limiter.Allow call.
type Decision struct {
	// Allowed reports whether the request was admitted.
	Allowed bool
	// Remaining is how many further requests fit in the current window.
	// The Redis implementation reports 0 when the request was denied.
	Remaining int
	// RetryAfterMS is how many milliseconds the caller should wait before
	// retrying. The Redis implementation reports 0 when the request was
	// allowed.
	RetryAfterMS int
}
// Limiter decides whether a workspace may make another request.
type Limiter interface {
	// Allow reports whether workspaceID may make one more request when at
	// most limit requests are permitted per sliding window of length window.
	// A non-nil error means no decision was reached, and the returned
	// Decision must be ignored. (The Redis implementation records the
	// request only when it is allowed.)
	Allow(ctx context.Context, workspaceID string, limit int, window time.Duration) (Decision, error)
}
// redisLimiter is the Redis-backed Limiter; the algorithm itself lives in
// slidingWindowLua.
type redisLimiter struct {
	client rueidis.Client // used to EVAL the sliding-window script
}

// New returns a Limiter whose per-workspace windows are stored in Redis
// through client.
func New(client rueidis.Client) Limiter {
	rl := new(redisLimiter)
	rl.client = client
	return rl
}
// slidingWindowLua performs the whole check-and-record step atomically.
//
// KEYS[1] is the per-workspace sorted set. ARGV[1] is the current time in
// unix milliseconds, ARGV[2] the window length in milliseconds, ARGV[3] the
// request limit, and ARGV[4] a member string that must be unique per
// request: ZADD on an existing member only updates its score, so a reused
// member would not be counted again.
//
// Returns: {allowed (1/0), remaining, retry_after_ms}
//
// When denying, retry_after_ms is the time until enough entries leave the
// window for one more request to fit. With count == limit that is the oldest
// entry (rank 0). If the set holds more than limit entries (e.g. the caller
// lowered limit since they were recorded), several entries must expire
// first, so the script reads the entry at rank count-limit instead.
const slidingWindowLua = `
local key = KEYS[1]
local now = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local limit = tonumber(ARGV[3])
local member = ARGV[4]
local cutoff = now - window
redis.call('ZREMRANGEBYSCORE', key, '-inf', cutoff)
local count = tonumber(redis.call('ZCARD', key))
if count < limit then
  redis.call('ZADD', key, now, member)
  redis.call('PEXPIRE', key, window)
  return {1, limit - count - 1, 0}
end
-- A slot frees once the entry at zero-based rank (count - limit) leaves the
-- window; with count == limit that is simply the oldest entry.
local idx = count - limit
local entry = redis.call('ZRANGE', key, idx, idx, 'WITHSCORES')
local retry = window
if entry and entry[2] then
  retry = (tonumber(entry[2]) + window) - now
  if retry < 0 then retry = 0 end
end
return {0, 0, retry}
`
func (l *redisLimiter) Allow(ctx context.Context, workspaceID string, limit int, window time.Duration) (Decision, error) {
key := "rate:" + workspaceID
now := time.Now().UnixMilli()
member := strconv.FormatInt(now, 10) + ":" + workspaceID
cmd := l.client.B().Eval().Script(slidingWindowLua).
Numkeys(1).
Key(key).
Arg(strconv.FormatInt(now, 10),
strconv.FormatInt(window.Milliseconds(), 10),
strconv.Itoa(limit),
member).
Build()
res := l.client.Do(ctx, cmd)
if err := res.Error(); err != nil {
return Decision{}, fmt.Errorf("ratelimit eval: %w", err)
}
arr, err := res.ToArray()
if err != nil || len(arr) != 3 {
return Decision{}, fmt.Errorf("ratelimit bad reply: %w", err)
}
allowed, _ := arr[0].AsInt64()
remaining, _ := arr[1].AsInt64()
retry, _ := arr[2].AsInt64()
return Decision{
Allowed: allowed == 1,
Remaining: int(remaining),
RetryAfterMS: int(retry),
}, nil
}