1. go
  2. /web
  3. /middleware

Web Middleware in Go

Middleware in Go allows you to add functionality to your HTTP request handling pipeline. This guide covers how to create and use middleware effectively.

Basic Middleware

Simple Middleware

// loggingMiddleware wraps next so that each request's method and URL path are
// written to the standard logger before next handles the request.
func loggingMiddleware(next http.Handler) http.Handler {
    logged := func(w http.ResponseWriter, req *http.Request) {
        method, path := req.Method, req.URL.Path
        log.Printf("%s %s", method, path)

        // Hand the untouched request on to the wrapped handler.
        next.ServeHTTP(w, req)
    }
    return http.HandlerFunc(logged)
}

// main registers finalHandler behind loggingMiddleware on the default mux and
// serves it on port 8080.
func main() {
    // Adapt the plain handler function to http.Handler.
    handler := http.HandlerFunc(finalHandler)

    // Wrap with middleware
    http.Handle("/", loggingMiddleware(handler))

    // ListenAndServe only returns when serving fails (for example when the
    // port is already in use); report that instead of exiting silently.
    log.Fatal(http.ListenAndServe(":8080", nil))
}

Chaining Middleware

// chainMiddleware composes mw into a single middleware. The first element of
// mw becomes the outermost wrapper, so it sees each request first; with no
// middleware at all the handler is returned unchanged.
func chainMiddleware(mw ...func(http.Handler) http.Handler) func(http.Handler) http.Handler {
    return func(final http.Handler) http.Handler {
        wrapped := final
        // Wrap starting from the last middleware so mw[0] ends up outermost.
        for remaining := len(mw); remaining > 0; remaining-- {
            wrapped = mw[remaining-1](wrapped)
        }
        return wrapped
    }
}

// main builds a logging, auth and rate-limit chain around finalHandler,
// registers it on the default mux and serves it on port 8080.
func main() {
    // Middleware listed first is outermost and therefore runs first.
    chain := chainMiddleware(
        loggingMiddleware,
        authMiddleware,
        // NOTE(review): rateLimitMiddleware is not defined in this guide —
        // presumably a RateLimiter's Middleware method value; confirm.
        rateLimitMiddleware,
    )

    // Apply chain to handler
    handler := chain(http.HandlerFunc(finalHandler))
    http.Handle("/", handler)

    // Registering the handler does not start a server; without this call the
    // program would exit immediately. ListenAndServe only returns on failure.
    log.Fatal(http.ListenAndServe(":8080", nil))
}

Common Middleware Types

Authentication Middleware

// authContextKey is the unexported key type for values authMiddleware stores
// in a request context. A dedicated type cannot collide with keys set by other
// code, which a plain string key such as "user" can.
type authContextKey string

// authUserKey is the context key under which authMiddleware stores the value
// returned by validateToken. Handlers read it back with
// r.Context().Value(authUserKey).
const authUserKey authContextKey = "user"

// authMiddleware rejects requests without a valid Authorization header with
// 401 Unauthorized. For accepted requests, the value returned by validateToken
// is attached to the request context under authUserKey before next is called.
func authMiddleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        // The raw header value is handed to validateToken as-is.
        token := r.Header.Get("Authorization")
        if token == "" {
            http.Error(w, "Unauthorized", http.StatusUnauthorized)
            return
        }

        user, err := validateToken(token)
        if err != nil {
            http.Error(w, "Invalid token", http.StatusUnauthorized)
            return
        }

        // Pass the authenticated user downstream via the request context.
        ctx := context.WithValue(r.Context(), authUserKey, user)
        next.ServeHTTP(w, r.WithContext(ctx))
    })
}

CORS Middleware

// corsMiddleware adds permissive CORS headers to every response. An OPTIONS
// (preflight) request is answered with 200 OK and never reaches next; any
// other method is forwarded to next.
func corsMiddleware(next http.Handler) http.Handler {
    serve := func(w http.ResponseWriter, r *http.Request) {
        hdr := w.Header()
        hdr.Set("Access-Control-Allow-Origin", "*")
        hdr.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
        hdr.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")

        // Everything except a preflight goes on to the wrapped handler.
        if r.Method != http.MethodOptions {
            next.ServeHTTP(w, r)
            return
        }
        w.WriteHeader(http.StatusOK)
    }
    return http.HandlerFunc(serve)
}

Rate Limiting Middleware

// RateLimiter carries the state and settings behind the rate-limiting
// middleware. Values are created with NewRateLimiter.
// NOTE(review): the Allow method that Middleware calls is not shown in this
// guide, so the field descriptions below are assumptions to confirm.
type RateLimiter struct {
    // requests maps a string key to a list of timestamps; presumably the
    // recent request times per client.
    requests map[string][]time.Time
    // mu presumably guards requests against concurrent access.
    mu       sync.Mutex
    // rate is the limit passed to NewRateLimiter; presumably the maximum
    // number of requests allowed per window.
    rate     int
    // window is the duration passed to NewRateLimiter; presumably the period
    // over which rate is enforced.
    window   time.Duration
}

// NewRateLimiter returns a RateLimiter configured with the given rate and
// window and an empty, ready-to-use requests map.
func NewRateLimiter(rate int, window time.Duration) *RateLimiter {
    limiter := new(RateLimiter)
    limiter.rate = rate
    limiter.window = window
    // The map must be allocated up front: writing to a nil map panics.
    limiter.requests = map[string][]time.Time{}
    return limiter
}

// Middleware returns a handler that consults rl.Allow for each request before
// passing it to next. Requests that are not allowed receive
// 429 Too Many Requests and never reach next.
//
// Clients are keyed by the host part of r.RemoteAddr. RemoteAddr has the form
// "IP:port" and the port differs between connections, so keying on the full
// value would give every new connection from the same client a fresh quota.
// This uses net.SplitHostPort, so the file must import "net".
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        client, _, err := net.SplitHostPort(r.RemoteAddr)
        if err != nil {
            // RemoteAddr is not in host:port form; fall back to the raw value.
            client = r.RemoteAddr
        }

        if !rl.Allow(client) {
            http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
            return
        }

        next.ServeHTTP(w, r)
    })
}

Advanced Middleware

Response Writer Wrapper

// responseWriter wraps an http.ResponseWriter and remembers the status code
// passed to WriteHeader so it can be inspected after the handler returns.
type responseWriter struct {
    http.ResponseWriter
    // status is the code recorded by the first WriteHeader call.
    status      int
    // wroteHeader is set once WriteHeader has been forwarded to the wrapped
    // writer; later WriteHeader calls are ignored.
    wroteHeader bool
}

// WriteHeader records code and forwards it to the wrapped writer. Only the
// first call has any effect, matching net/http, where the status line can be
// sent only once.
func (rw *responseWriter) WriteHeader(code int) {
    if rw.wroteHeader {
        return
    }
    rw.status = code
    rw.ResponseWriter.WriteHeader(code)
    rw.wroteHeader = true
}

// Write sends b to the wrapped writer. net/http implicitly sends 200 OK on the
// first Write when WriteHeader has not been called; recording that here keeps
// a later, ineffective WriteHeader call from overwriting the stored status.
func (rw *responseWriter) Write(b []byte) (int, error) {
    if !rw.wroteHeader {
        rw.WriteHeader(http.StatusOK)
    }
    return rw.ResponseWriter.Write(b)
}

// loggingMiddleware logs one line per request — method, path, response status
// and elapsed time — after next has finished. The logged status starts out as
// 200 and changes only when next calls WriteHeader.
func loggingMiddleware(next http.Handler) http.Handler {
    timed := func(w http.ResponseWriter, req *http.Request) {
        // Wrap the writer so the status code chosen by next can be read back.
        recorder := &responseWriter{ResponseWriter: w, status: http.StatusOK}

        began := time.Now()
        next.ServeHTTP(recorder, req)

        log.Printf("%s %s %d %v", req.Method, req.URL.Path, recorder.status, time.Since(began))
    }
    return http.HandlerFunc(timed)
}

Panic Recovery Middleware

// recoveryMiddleware turns a panic raised by next into a logged stack trace
// and a 500 Internal Server Error response.
//
// http.ErrAbortHandler is re-panicked rather than handled: net/http defines it
// as the sentinel value for deliberately aborting a response, and the server
// itself deals with it without logging a stack trace.
func recoveryMiddleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        defer func() {
            recovered := recover()
            if recovered == nil {
                return
            }
            if recovered == http.ErrAbortHandler {
                panic(recovered)
            }

            // Log the panic value together with the stack of this goroutine.
            log.Printf("panic: %v\n%s", recovered, debug.Stack())

            // Return a generic error so internal details do not reach the client.
            http.Error(w, "Internal server error", http.StatusInternalServerError)
        }()

        next.ServeHTTP(w, r)
    })
}

Best Practices

1. Middleware Organization

// Middleware bundles the dependencies shared by the middleware constructors
// defined as its methods (Logging and Metrics).
type Middleware struct {
    // handler is stored by NewMiddleware.
    // NOTE(review): none of the methods shown in this guide read it — confirm
    // it is needed.
    handler http.Handler
    // logger receives the request lines written by Logging.
    logger  *log.Logger
    // metrics receives the request durations reported by Metrics.
    metrics MetricsClient
}

// NewMiddleware returns a Middleware holding the given handler, logger and
// metrics client.
func NewMiddleware(handler http.Handler, logger *log.Logger, metrics MetricsClient) *Middleware {
    m := new(Middleware)
    m.handler, m.logger, m.metrics = handler, logger, metrics
    return m
}

// Logging returns middleware that writes "Request: <method> <path>" to the
// Middleware's logger and then passes the request to the next handler.
func (m *Middleware) Logging() func(http.Handler) http.Handler {
    wrap := func(next http.Handler) http.Handler {
        logThenServe := func(w http.ResponseWriter, req *http.Request) {
            m.logger.Printf("Request: %s %s", req.Method, req.URL.Path)
            next.ServeHTTP(w, req)
        }
        return http.HandlerFunc(logThenServe)
    }
    return wrap
}

// Metrics returns middleware that measures how long the next handler takes
// and reports it to the metrics client under "request.duration".
func (m *Middleware) Metrics() func(http.Handler) http.Handler {
    wrap := func(next http.Handler) http.Handler {
        measure := func(w http.ResponseWriter, req *http.Request) {
            began := time.Now()
            next.ServeHTTP(w, req)
            // Reported only after next returns, so the duration covers the
            // whole downstream handling.
            m.metrics.Timing("request.duration", time.Since(began))
        }
        return http.HandlerFunc(measure)
    }
    return wrap
}

2. Context Usage

// contextKey is a private string type for request-context keys, so that these
// keys cannot collide with plain string keys used by other code.
type contextKey string

// Typed context keys. traceContextKey is the key under which
// contextMiddleware stores the per-request trace ID.
const (
    userContextKey  = contextKey("user")
    traceContextKey = contextKey("trace")
)

func contextMiddleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        // Add values to context
        ctx := r.Context()
        ctx = context.WithValue(ctx, traceContextKey, uuid.New().String())
        
        // Call next handler with updated context
        next.ServeHTTP(w, r.WithContext(ctx))
    })
}

// handler logs the trace ID stored in the request context under
// traceContextKey. When no string trace ID is present (for example because
// contextMiddleware is not part of the chain) it logs "unknown" instead of
// panicking on a failed type assertion.
func handler(w http.ResponseWriter, r *http.Request) {
    traceID, ok := r.Context().Value(traceContextKey).(string)
    if !ok {
        traceID = "unknown"
    }
    log.Printf("Handling request with trace ID: %s", traceID)
}

3. Error Handling

// ErrorResponse is the JSON body that errorMiddleware sends to the client
// when a handler has reported an error.
type ErrorResponse struct {
    // Error is the error's message text.
    Error   string `json:"error"`
    // Code is the HTTP status code sent with the response.
    Code    int    `json:"code"`
    // TraceID is the request's trace ID; it is left out of the JSON when empty.
    TraceID string `json:"trace_id,omitempty"`
}

// errorMiddleware runs next against an errorWriter and, when that writer ends
// up holding an error, answers the client with a JSON ErrorResponse.
//
// NOTE(review): errorWriter is not defined in this guide; this code assumes it
// embeds http.ResponseWriter and exposes err and status fields — confirm.
func errorMiddleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        ew := &errorWriter{ResponseWriter: w}
        next.ServeHTTP(ew, r)

        if ew.err == nil {
            return
        }

        // net/http panics on status codes outside 100–999, so an error
        // recorded without a usable status is reported as 500.
        status := ew.status
        if status < 100 || status > 999 {
            status = http.StatusInternalServerError
        }

        // The trace ID is optional: it is absent when no middleware stored
        // one, and ErrorResponse omits an empty trace_id from the JSON.
        traceID, _ := r.Context().Value(traceContextKey).(string)

        response := ErrorResponse{
            Error:   ew.err.Error(),
            Code:    status,
            TraceID: traceID,
        }

        w.Header().Set("Content-Type", "application/json")
        w.WriteHeader(status)
        if err := json.NewEncoder(w).Encode(response); err != nil {
            // The status line is already sent; all that is left is to record it.
            log.Printf("errorMiddleware: writing error response: %v", err)
        }
    })
}

Common Patterns

1. Middleware Factory

// timeoutWriter buffers a handler's response in memory so that WithTimeout
// can decide, once the handler finishes or the deadline passes, what is
// actually sent to the client.
type timeoutWriter struct {
    mu       sync.Mutex
    header   http.Header
    body     bytes.Buffer
    status   int  // 0 until the handler picks a status
    timedOut bool // set once the timeout response has been chosen
}

// Header returns the buffered header map; it is copied to the real response
// only if the handler finishes in time.
func (tw *timeoutWriter) Header() http.Header { return tw.header }

// WriteHeader records the first status code the handler chooses.
func (tw *timeoutWriter) WriteHeader(code int) {
    tw.mu.Lock()
    defer tw.mu.Unlock()
    if tw.timedOut || tw.status != 0 {
        return
    }
    tw.status = code
}

// Write appends p to the buffered body, or fails with http.ErrHandlerTimeout
// once the timeout has fired.
func (tw *timeoutWriter) Write(p []byte) (int, error) {
    tw.mu.Lock()
    defer tw.mu.Unlock()
    if tw.timedOut {
        return 0, http.ErrHandlerTimeout
    }
    if tw.status == 0 {
        tw.status = http.StatusOK
    }
    return tw.body.Write(p)
}

// WithTimeout returns middleware that gives the wrapped handler at most
// timeout to respond. The handler receives a request whose context is
// cancelled at the deadline; if it has not finished by then the client gets
// 504 Gateway Timeout with the body "Request timeout".
//
// The handler runs in its own goroutine against a buffering writer, so it
// never touches the real ResponseWriter: nothing is written concurrently or
// after this middleware has returned, and a handler that outlives the
// deadline can still finish without blocking forever. Because the response
// is buffered, streaming (http.Flusher) is not available to the wrapped
// handler. A panic in the handler is re-raised on the request goroutine so
// the server's usual panic handling applies.
//
// This uses the bytes and sync packages in addition to context, net/http and
// time.
func WithTimeout(timeout time.Duration) func(http.Handler) http.Handler {
    return func(next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            ctx, cancel := context.WithTimeout(r.Context(), timeout)
            defer cancel()

            tw := &timeoutWriter{header: make(http.Header)}
            done := make(chan struct{})
            // Buffered so a late panic never blocks the handler goroutine.
            panicked := make(chan interface{}, 1)

            go func() {
                defer func() {
                    if p := recover(); p != nil {
                        panicked <- p
                    }
                }()
                next.ServeHTTP(tw, r.WithContext(ctx))
                // Closing (rather than sending) never blocks, so the
                // goroutine always exits even after a timeout.
                close(done)
            }()

            select {
            case p := <-panicked:
                panic(p)
            case <-done:
                tw.mu.Lock()
                defer tw.mu.Unlock()
                dst := w.Header()
                for key, values := range tw.header {
                    dst[key] = values
                }
                if tw.status == 0 {
                    tw.status = http.StatusOK
                }
                w.WriteHeader(tw.status)
                // A write error here means the client went away; there is
                // nobody left to report it to.
                _, _ = w.Write(tw.body.Bytes())
            case <-ctx.Done():
                tw.mu.Lock()
                tw.timedOut = true
                tw.mu.Unlock()
                http.Error(w, "Request timeout", http.StatusGatewayTimeout)
            }
        })
    }
}

2. Conditional Middleware

// IfMiddleware returns middleware that applies middleware only to requests
// for which condition reports true; all other requests go straight to next.
//
// middleware is applied to next once, when the chain is built, rather than on
// every request. This avoids re-wrapping per request and lets middleware that
// sets up state while wrapping keep that state across requests.
func IfMiddleware(condition func(*http.Request) bool, middleware func(http.Handler) http.Handler) func(http.Handler) http.Handler {
    return func(next http.Handler) http.Handler {
        wrapped := middleware(next)
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            if condition(r) {
                wrapped.ServeHTTP(w, r)
                return
            }
            next.ServeHTTP(w, r)
        })
    }
}

// Usage: require authentication only for paths under /api.
// finalHandler is adapted with http.HandlerFunc, as elsewhere in this guide,
// because the middleware expects an http.Handler rather than a bare function.
handler := IfMiddleware(
    func(r *http.Request) bool {
        return strings.HasPrefix(r.URL.Path, "/api")
    },
    authMiddleware,
)(http.HandlerFunc(finalHandler))

Next Steps