103 lines
2.6 KiB
Go
103 lines
2.6 KiB
Go
// Package ratelimit implements a Redis-backed sliding-window limiter.
|
|
//
|
|
// We use a sorted-set per workspace where the score is the unix-nano
|
|
// timestamp. On each request we:
|
|
// 1. ZREMRANGEBYSCORE -- evict entries older than window
|
|
// 2. ZCARD -- count current
|
|
// 3. if count < limit : ZADD + EXPIRE, allow
|
|
// 4. else : compute retry-after from oldest entry, deny
|
|
//
|
|
// Steps 1-3/4 are wrapped in a Lua script for atomicity.
|
|
package ratelimit
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/redis/rueidis"
|
|
)
|
|
|
|
type Decision struct {
|
|
Allowed bool
|
|
Remaining int
|
|
RetryAfterMS int
|
|
}
|
|
|
|
type Limiter interface {
|
|
Allow(ctx context.Context, workspaceID string, limit int, window time.Duration) (Decision, error)
|
|
}
|
|
|
|
type redisLimiter struct {
|
|
client rueidis.Client
|
|
}
|
|
|
|
func New(client rueidis.Client) Limiter {
|
|
return &redisLimiter{client: client}
|
|
}
|
|
|
|
// Lua script: KEYS[1]=zset key, ARGV[1]=now_ms, ARGV[2]=window_ms,
|
|
// ARGV[3]=limit, ARGV[4]=member (unique per request).
|
|
//
|
|
// Returns: {allowed (1/0), remaining, retry_after_ms}
|
|
const slidingWindowLua = `
|
|
local key = KEYS[1]
|
|
local now = tonumber(ARGV[1])
|
|
local window = tonumber(ARGV[2])
|
|
local limit = tonumber(ARGV[3])
|
|
local member = ARGV[4]
|
|
local cutoff = now - window
|
|
|
|
redis.call('ZREMRANGEBYSCORE', key, 0, cutoff)
|
|
local count = tonumber(redis.call('ZCARD', key))
|
|
|
|
if count < limit then
|
|
redis.call('ZADD', key, now, member)
|
|
redis.call('PEXPIRE', key, window)
|
|
return {1, limit - count - 1, 0}
|
|
end
|
|
|
|
local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
|
|
local retry = window
|
|
if oldest and oldest[2] then
|
|
retry = (tonumber(oldest[2]) + window) - now
|
|
if retry < 0 then retry = 0 end
|
|
end
|
|
return {0, 0, retry}
|
|
`
|
|
|
|
func (l *redisLimiter) Allow(ctx context.Context, workspaceID string, limit int, window time.Duration) (Decision, error) {
|
|
key := "rate:" + workspaceID
|
|
now := time.Now().UnixMilli()
|
|
member := strconv.FormatInt(now, 10) + ":" + workspaceID
|
|
|
|
cmd := l.client.B().Eval().Script(slidingWindowLua).
|
|
Numkeys(1).
|
|
Key(key).
|
|
Arg(strconv.FormatInt(now, 10),
|
|
strconv.FormatInt(window.Milliseconds(), 10),
|
|
strconv.Itoa(limit),
|
|
member).
|
|
Build()
|
|
|
|
res := l.client.Do(ctx, cmd)
|
|
if err := res.Error(); err != nil {
|
|
return Decision{}, fmt.Errorf("ratelimit eval: %w", err)
|
|
}
|
|
|
|
arr, err := res.ToArray()
|
|
if err != nil || len(arr) != 3 {
|
|
return Decision{}, fmt.Errorf("ratelimit bad reply: %w", err)
|
|
}
|
|
allowed, _ := arr[0].AsInt64()
|
|
remaining, _ := arr[1].AsInt64()
|
|
retry, _ := arr[2].AsInt64()
|
|
|
|
return Decision{
|
|
Allowed: allowed == 1,
|
|
Remaining: int(remaining),
|
|
RetryAfterMS: int(retry),
|
|
}, nil
|
|
}
|