init ingestion
This commit is contained in:
193
ingestion/ingest/internal/middleware/middleware.go
Normal file
193
ingestion/ingest/internal/middleware/middleware.go
Normal file
@@ -0,0 +1,193 @@
|
||||
// Package middleware provides chi-compatible HTTP middleware: auth, logging,
|
||||
// payload-limit, request-id, panic recovery, CORS.
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/dbiz/cdp/ingestion/ingest/internal/apperr"
|
||||
"github.com/dbiz/cdp/ingestion/ingest/internal/model"
|
||||
"github.com/dbiz/cdp/ingestion/ingest/internal/service"
|
||||
)
|
||||
|
||||
type ctxKey string
|
||||
|
||||
const (
|
||||
ctxKeyRequestID ctxKey = "request_id"
|
||||
ctxKeyWriteKey ctxKey = "write_key"
|
||||
)
|
||||
|
||||
// RequestID assigns a uuid v4 to each request and stores it in context.
|
||||
func RequestID(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.Header.Get("X-Request-Id")
|
||||
if id == "" {
|
||||
id = uuid.NewString()
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), ctxKeyRequestID, id)
|
||||
w.Header().Set("X-Request-Id", id)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func RequestIDFromCtx(ctx context.Context) string {
|
||||
v, _ := ctx.Value(ctxKeyRequestID).(string)
|
||||
return v
|
||||
}
|
||||
|
||||
// Recover handles panics so a buggy handler can't take down the server.
|
||||
func Recover(log *zap.Logger) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
log.Error("panic in handler",
|
||||
zap.Any("panic", rec),
|
||||
zap.String("path", r.URL.Path),
|
||||
zap.ByteString("stack", debug.Stack()))
|
||||
http.Error(w, `{"error":"internal server error"}`, http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// PayloadLimit caps the request body size to limitKB kilobytes.
|
||||
func PayloadLimit(limitKB int) func(http.Handler) http.Handler {
|
||||
max := int64(limitKB) * 1024
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, max)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Logger logs one structured line per request.
|
||||
func Logger(log *zap.Logger) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
rw := &statusRecorder{ResponseWriter: w, status: 200}
|
||||
next.ServeHTTP(rw, r)
|
||||
log.Info("http",
|
||||
zap.String("method", r.Method),
|
||||
zap.String("path", r.URL.Path),
|
||||
zap.Int("status", rw.status),
|
||||
zap.Int64("duration_ms", time.Since(start).Milliseconds()),
|
||||
zap.String("request_id", RequestIDFromCtx(r.Context())),
|
||||
zap.String("ip", clientIP(r)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// CORS returns a permissive CORS handler. Browser SDKs (web tracker) require it.
|
||||
func CORS(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Request-Id")
|
||||
w.Header().Set("Access-Control-Max-Age", "86400")
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// Auth resolves the write key from the request and stores it in context.
|
||||
// Accepts both `Authorization: Basic <base64(key:)>` (Segment-style) and
|
||||
// `Authorization: Bearer <key>`.
|
||||
func Auth(s *service.AuthService) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
key, err := extractWriteKey(r)
|
||||
if err != nil {
|
||||
writeAuthError(w, err)
|
||||
return
|
||||
}
|
||||
wk, err := s.Resolve(r.Context(), key)
|
||||
if err != nil {
|
||||
writeAuthError(w, err)
|
||||
return
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), ctxKeyWriteKey, wk)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// WriteKeyFromCtx returns the resolved key set by Auth middleware.
|
||||
func WriteKeyFromCtx(ctx context.Context) *model.WriteKey {
|
||||
v, _ := ctx.Value(ctxKeyWriteKey).(*model.WriteKey)
|
||||
return v
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func extractWriteKey(r *http.Request) (string, error) {
|
||||
h := r.Header.Get("Authorization")
|
||||
if h == "" {
|
||||
return "", apperr.Unauthorized("missing Authorization header")
|
||||
}
|
||||
if strings.HasPrefix(h, "Bearer ") {
|
||||
return strings.TrimPrefix(h, "Bearer "), nil
|
||||
}
|
||||
if strings.HasPrefix(h, "Basic ") {
|
||||
raw, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(h, "Basic "))
|
||||
if err != nil {
|
||||
return "", apperr.Unauthorized("invalid basic auth")
|
||||
}
|
||||
// Segment uses `key:` (no password). Take everything before the first colon.
|
||||
s := string(raw)
|
||||
if i := strings.Index(s, ":"); i >= 0 {
|
||||
return s[:i], nil
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
return "", apperr.Unauthorized("unsupported auth scheme")
|
||||
}
|
||||
|
||||
func writeAuthError(w http.ResponseWriter, err error) {
|
||||
if ae, ok := apperr.As(err); ok {
|
||||
http.Error(w, `{"error":"`+ae.Message+`"}`, ae.Code)
|
||||
return
|
||||
}
|
||||
http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
func clientIP(r *http.Request) string {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
if i := strings.Index(xff, ","); i >= 0 {
|
||||
return strings.TrimSpace(xff[:i])
|
||||
}
|
||||
return strings.TrimSpace(xff)
|
||||
}
|
||||
if rip := r.Header.Get("X-Real-Ip"); rip != "" {
|
||||
return rip
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
type statusRecorder struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
}
|
||||
|
||||
func (s *statusRecorder) WriteHeader(code int) {
|
||||
s.status = code
|
||||
s.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user