// Package middleware provides chi-compatible HTTP middleware: auth, logging, // payload-limit, request-id, panic recovery, CORS. package middleware import ( "context" "encoding/base64" "net/http" "runtime/debug" "strings" "time" "github.com/google/uuid" "go.uber.org/zap" "github.com/dbiz/cdp/ingestion/ingest/internal/apperr" "github.com/dbiz/cdp/ingestion/ingest/internal/model" "github.com/dbiz/cdp/ingestion/ingest/internal/service" ) type ctxKey string const ( ctxKeyRequestID ctxKey = "request_id" ctxKeyWriteKey ctxKey = "write_key" ) // RequestID assigns a uuid v4 to each request and stores it in context. func RequestID(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { id := r.Header.Get("X-Request-Id") if id == "" { id = uuid.NewString() } ctx := context.WithValue(r.Context(), ctxKeyRequestID, id) w.Header().Set("X-Request-Id", id) next.ServeHTTP(w, r.WithContext(ctx)) }) } func RequestIDFromCtx(ctx context.Context) string { v, _ := ctx.Value(ctxKeyRequestID).(string) return v } // Recover handles panics so a buggy handler can't take down the server. func Recover(log *zap.Logger) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer func() { if rec := recover(); rec != nil { log.Error("panic in handler", zap.Any("panic", rec), zap.String("path", r.URL.Path), zap.ByteString("stack", debug.Stack())) http.Error(w, `{"error":"internal server error"}`, http.StatusInternalServerError) } }() next.ServeHTTP(w, r) }) } } // PayloadLimit caps the request body size to limitKB kilobytes. func PayloadLimit(limitKB int) func(http.Handler) http.Handler { max := int64(limitKB) * 1024 return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, max) next.ServeHTTP(w, r) }) } } // Logger logs one structured line per request. func Logger(log *zap.Logger) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() rw := &statusRecorder{ResponseWriter: w, status: 200} next.ServeHTTP(rw, r) log.Info("http", zap.String("method", r.Method), zap.String("path", r.URL.Path), zap.Int("status", rw.status), zap.Int64("duration_ms", time.Since(start).Milliseconds()), zap.String("request_id", RequestIDFromCtx(r.Context())), zap.String("ip", clientIP(r))) }) } } // CORS returns a permissive CORS handler. Browser SDKs (web tracker) require it. func CORS(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Request-Id") w.Header().Set("Access-Control-Max-Age", "86400") if r.Method == http.MethodOptions { w.WriteHeader(http.StatusNoContent) return } next.ServeHTTP(w, r) }) } // Auth resolves the write key from the request and stores it in context. // Accepts both `Authorization: Basic ` (Segment-style) and // `Authorization: Bearer `. func Auth(s *service.AuthService) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { key, err := extractWriteKey(r) if err != nil { writeAuthError(w, err) return } wk, err := s.Resolve(r.Context(), key) if err != nil { writeAuthError(w, err) return } ctx := context.WithValue(r.Context(), ctxKeyWriteKey, wk) next.ServeHTTP(w, r.WithContext(ctx)) }) } } // WriteKeyFromCtx returns the resolved key set by Auth middleware. func WriteKeyFromCtx(ctx context.Context) *model.WriteKey { v, _ := ctx.Value(ctxKeyWriteKey).(*model.WriteKey) return v } // --------------------------------------------------------------------------- // helpers // --------------------------------------------------------------------------- func extractWriteKey(r *http.Request) (string, error) { h := r.Header.Get("Authorization") if h == "" { return "", apperr.Unauthorized("missing Authorization header") } if strings.HasPrefix(h, "Bearer ") { return strings.TrimPrefix(h, "Bearer "), nil } if strings.HasPrefix(h, "Basic ") { raw, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(h, "Basic ")) if err != nil { return "", apperr.Unauthorized("invalid basic auth") } // Segment uses `key:` (no password). Take everything before the first colon. s := string(raw) if i := strings.Index(s, ":"); i >= 0 { return s[:i], nil } return s, nil } return "", apperr.Unauthorized("unsupported auth scheme") } func writeAuthError(w http.ResponseWriter, err error) { if ae, ok := apperr.As(err); ok { http.Error(w, `{"error":"`+ae.Message+`"}`, ae.Code) return } http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized) } func clientIP(r *http.Request) string { if xff := r.Header.Get("X-Forwarded-For"); xff != "" { if i := strings.Index(xff, ","); i >= 0 { return strings.TrimSpace(xff[:i]) } return strings.TrimSpace(xff) } if rip := r.Header.Get("X-Real-Ip"); rip != "" { return rip } return r.RemoteAddr } type statusRecorder struct { http.ResponseWriter status int } func (s *statusRecorder) WriteHeader(code int) { s.status = code s.ResponseWriter.WriteHeader(code) } // Flush delegates so SSE handlers can still call w.(http.Flusher).Flush() // after the Logger middleware wraps the original ResponseWriter. func (s *statusRecorder) Flush() { if f, ok := s.ResponseWriter.(http.Flusher); ok { f.Flush() } }