// Package ratelimit implements a Redis-backed sliding-window limiter. // // We use a sorted-set per workspace where the score is the unix-nano // timestamp. On each request we: // 1. ZREMRANGEBYSCORE -- evict entries older than window // 2. ZCARD -- count current // 3. if count < limit : ZADD + EXPIRE, allow // 4. else : compute retry-after from oldest entry, deny // // Steps 1-3/4 are wrapped in a Lua script for atomicity. package ratelimit import ( "context" "fmt" "strconv" "time" "github.com/redis/rueidis" ) type Decision struct { Allowed bool Remaining int RetryAfterMS int } type Limiter interface { Allow(ctx context.Context, workspaceID string, limit int, window time.Duration) (Decision, error) } type redisLimiter struct { client rueidis.Client } func New(client rueidis.Client) Limiter { return &redisLimiter{client: client} } // Lua script: KEYS[1]=zset key, ARGV[1]=now_ms, ARGV[2]=window_ms, // ARGV[3]=limit, ARGV[4]=member (unique per request). // // Returns: {allowed (1/0), remaining, retry_after_ms} const slidingWindowLua = ` local key = KEYS[1] local now = tonumber(ARGV[1]) local window = tonumber(ARGV[2]) local limit = tonumber(ARGV[3]) local member = ARGV[4] local cutoff = now - window redis.call('ZREMRANGEBYSCORE', key, 0, cutoff) local count = tonumber(redis.call('ZCARD', key)) if count < limit then redis.call('ZADD', key, now, member) redis.call('PEXPIRE', key, window) return {1, limit - count - 1, 0} end local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES') local retry = window if oldest and oldest[2] then retry = (tonumber(oldest[2]) + window) - now if retry < 0 then retry = 0 end end return {0, 0, retry} ` func (l *redisLimiter) Allow(ctx context.Context, workspaceID string, limit int, window time.Duration) (Decision, error) { key := "rate:" + workspaceID now := time.Now().UnixMilli() member := strconv.FormatInt(now, 10) + ":" + workspaceID cmd := l.client.B().Eval().Script(slidingWindowLua). Numkeys(1). Key(key). Arg(strconv.FormatInt(now, 10), strconv.FormatInt(window.Milliseconds(), 10), strconv.Itoa(limit), member). Build() res := l.client.Do(ctx, cmd) if err := res.Error(); err != nil { return Decision{}, fmt.Errorf("ratelimit eval: %w", err) } arr, err := res.ToArray() if err != nil || len(arr) != 3 { return Decision{}, fmt.Errorf("ratelimit bad reply: %w", err) } allowed, _ := arr[0].AsInt64() remaining, _ := arr[1].AsInt64() retry, _ := arr[2].AsInt64() return Decision{ Allowed: allowed == 1, Remaining: int(remaining), RetryAfterMS: int(retry), }, nil }