194 lines
5.6 KiB
Go
194 lines
5.6 KiB
Go
// Package middleware provides chi-compatible HTTP middleware: authentication,
// request logging, payload-size limits, request IDs, panic recovery and CORS.
package middleware
|
|
|
|
import (
	"context"
	"encoding/base64"
	"encoding/json"
	"net"
	"net/http"
	"runtime/debug"
	"strings"
	"time"

	"github.com/google/uuid"
	"go.uber.org/zap"

	"github.com/dbiz/cdp/ingestion/ingest/internal/apperr"
	"github.com/dbiz/cdp/ingestion/ingest/internal/model"
	"github.com/dbiz/cdp/ingestion/ingest/internal/service"
)
|
|
|
|
// ctxKey is the unexported type used for every context key this package sets.
// Because the type is private, the keys cannot collide with context values
// stored by other packages, even when the underlying strings match.
type ctxKey string

// Context keys for the per-request values stored by RequestID and Auth.
const (
	ctxKeyRequestID = ctxKey("request_id")
	ctxKeyWriteKey  = ctxKey("write_key")
)
|
|
|
|
// RequestID assigns a uuid v4 to each request and stores it in context.
|
|
func RequestID(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
id := r.Header.Get("X-Request-Id")
|
|
if id == "" {
|
|
id = uuid.NewString()
|
|
}
|
|
ctx := context.WithValue(r.Context(), ctxKeyRequestID, id)
|
|
w.Header().Set("X-Request-Id", id)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
func RequestIDFromCtx(ctx context.Context) string {
|
|
v, _ := ctx.Value(ctxKeyRequestID).(string)
|
|
return v
|
|
}
|
|
|
|
// Recover handles panics so a buggy handler can't take down the server.
|
|
func Recover(log *zap.Logger) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
defer func() {
|
|
if rec := recover(); rec != nil {
|
|
log.Error("panic in handler",
|
|
zap.Any("panic", rec),
|
|
zap.String("path", r.URL.Path),
|
|
zap.ByteString("stack", debug.Stack()))
|
|
http.Error(w, `{"error":"internal server error"}`, http.StatusInternalServerError)
|
|
}
|
|
}()
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// PayloadLimit returns middleware that caps every request body at limitKB
// kibibytes (limitKB*1024 bytes) by wrapping it in http.MaxBytesReader.
// Downstream handlers see an error from Read once they try to consume more
// than the cap.
func PayloadLimit(limitKB int) func(http.Handler) http.Handler {
	limitBytes := int64(limitKB) * 1024 // named so it no longer shadows the built-in max
	wrap := func(next http.Handler) http.Handler {
		limited := func(w http.ResponseWriter, r *http.Request) {
			r.Body = http.MaxBytesReader(w, r.Body, limitBytes)
			next.ServeHTTP(w, r)
		}
		return http.HandlerFunc(limited)
	}
	return wrap
}
|
|
|
|
// Logger logs one structured line per request.
|
|
func Logger(log *zap.Logger) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
start := time.Now()
|
|
rw := &statusRecorder{ResponseWriter: w, status: 200}
|
|
next.ServeHTTP(rw, r)
|
|
log.Info("http",
|
|
zap.String("method", r.Method),
|
|
zap.String("path", r.URL.Path),
|
|
zap.Int("status", rw.status),
|
|
zap.Int64("duration_ms", time.Since(start).Milliseconds()),
|
|
zap.String("request_id", RequestIDFromCtx(r.Context())),
|
|
zap.String("ip", clientIP(r)))
|
|
})
|
|
}
|
|
}
|
|
|
|
// CORS returns a permissive CORS handler. Browser SDKs (web tracker) require it.
//
// Every response allows any origin and advertises the permitted methods and
// headers, with a one-day preflight cache. Any OPTIONS request is treated as
// a preflight and answered directly with 204 No Content. All other requests
// are passed on to next.
func CORS(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		hdr.Set("Access-Control-Allow-Origin", "*")
		hdr.Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
		hdr.Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Request-Id")
		hdr.Set("Access-Control-Max-Age", "86400")

		if r.Method != http.MethodOptions {
			next.ServeHTTP(w, r)
			return
		}
		w.WriteHeader(http.StatusNoContent)
	})
}
|
|
|
|
// Auth resolves the write key from the request and stores it in context.
|
|
// Accepts both `Authorization: Basic <base64(key:)>` (Segment-style) and
|
|
// `Authorization: Bearer <key>`.
|
|
func Auth(s *service.AuthService) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
key, err := extractWriteKey(r)
|
|
if err != nil {
|
|
writeAuthError(w, err)
|
|
return
|
|
}
|
|
wk, err := s.Resolve(r.Context(), key)
|
|
if err != nil {
|
|
writeAuthError(w, err)
|
|
return
|
|
}
|
|
ctx := context.WithValue(r.Context(), ctxKeyWriteKey, wk)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
}
|
|
|
|
// WriteKeyFromCtx returns the resolved key set by Auth middleware.
|
|
func WriteKeyFromCtx(ctx context.Context) *model.WriteKey {
|
|
v, _ := ctx.Value(ctxKeyWriteKey).(*model.WriteKey)
|
|
return v
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
// helpers
// ---------------------------------------------------------------------------
|
|
|
|
func extractWriteKey(r *http.Request) (string, error) {
|
|
h := r.Header.Get("Authorization")
|
|
if h == "" {
|
|
return "", apperr.Unauthorized("missing Authorization header")
|
|
}
|
|
if strings.HasPrefix(h, "Bearer ") {
|
|
return strings.TrimPrefix(h, "Bearer "), nil
|
|
}
|
|
if strings.HasPrefix(h, "Basic ") {
|
|
raw, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(h, "Basic "))
|
|
if err != nil {
|
|
return "", apperr.Unauthorized("invalid basic auth")
|
|
}
|
|
// Segment uses `key:` (no password). Take everything before the first colon.
|
|
s := string(raw)
|
|
if i := strings.Index(s, ":"); i >= 0 {
|
|
return s[:i], nil
|
|
}
|
|
return s, nil
|
|
}
|
|
return "", apperr.Unauthorized("unsupported auth scheme")
|
|
}
|
|
|
|
func writeAuthError(w http.ResponseWriter, err error) {
|
|
if ae, ok := apperr.As(err); ok {
|
|
http.Error(w, `{"error":"`+ae.Message+`"}`, ae.Code)
|
|
return
|
|
}
|
|
http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized)
|
|
}
|
|
|
|
// clientIP returns a best-effort client address for logging. Sources are
// tried in this order:
//  1. the first non-empty hop of X-Forwarded-For;
//  2. X-Real-Ip;
//  3. the host part of r.RemoteAddr, with the port stripped so all three
//     sources yield a bare address. The raw RemoteAddr is used if it cannot
//     be split.
//
// Both headers are set by the client unless a trusted proxy overwrites them,
// so the result must not be used for access control or rate limiting.
func clientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}
	if rip := strings.TrimSpace(r.Header.Get("X-Real-Ip")); rip != "" {
		return rip
	}
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
|
|
|
|
// statusRecorder wraps an http.ResponseWriter and remembers the status code
// actually sent to the client, so Logger can report it. Callers initialise
// status to the implicit default (200).
type statusRecorder struct {
	http.ResponseWriter
	status      int
	wroteHeader bool // set once a final status or any body bytes have been sent
}

// WriteHeader forwards every call, but records only the first final status.
// net/http ignores later calls ("superfluous WriteHeader"). Informational 1xx
// codes (other than 101 Switching Protocols) may precede the final status and
// are not recorded.
func (s *statusRecorder) WriteHeader(code int) {
	if !s.wroteHeader && (code >= 200 || code == http.StatusSwitchingProtocols) {
		s.status = code
		s.wroteHeader = true
	}
	s.ResponseWriter.WriteHeader(code)
}

// Write marks the header as committed before forwarding. net/http sends an
// implicit 200 on the first Write, so any WriteHeader afterwards has no effect.
func (s *statusRecorder) Write(b []byte) (int, error) {
	s.wroteHeader = true
	return s.ResponseWriter.Write(b)
}

// Unwrap exposes the wrapped writer. http.ResponseController uses it to reach
// optional interfaces (Flush, Hijack, deadlines) that this wrapper hides.
func (s *statusRecorder) Unwrap() http.ResponseWriter {
	return s.ResponseWriter
}
|
|
|