Go / Web / WebSockets

Implementing WebSockets in Go

WebSockets enable real-time, bidirectional communication between clients and servers. This guide covers implementing WebSocket servers and handling connections in Go.

Basic Setup

WebSocket Server

package main

import (
    "log"
    "net/http"
    "time"

    "github.com/gorilla/websocket"
)

// upgrader performs the HTTP -> WebSocket handshake.
//
// CheckOrigin is deliberately left nil so gorilla/websocket applies its
// default policy: a request that carries an Origin header is rejected unless
// the origin's host matches the request's Host header. Returning true
// unconditionally would let any web page open a WebSocket to this server
// with the visitor's cookies (cross-site WebSocket hijacking). To accept
// other origins, set CheckOrigin to a function that checks an explicit
// allowlist.
var upgrader = websocket.Upgrader{
    ReadBufferSize:  1024,
    WriteBufferSize: 1024,
}

// handleWebSocket upgrades the request to a WebSocket and echoes every frame
// it receives back to the peer, with the same message type. It returns, and
// closes the connection, on the first read or write error.
func handleWebSocket(w http.ResponseWriter, r *http.Request) {
    ws, upgradeErr := upgrader.Upgrade(w, r, nil)
    if upgradeErr != nil {
        log.Printf("WebSocket upgrade error: %v", upgradeErr)
        return
    }
    defer ws.Close()

    for {
        kind, payload, readErr := ws.ReadMessage()
        if readErr != nil {
            // Only close errors whose code is neither "going away" nor
            // "abnormal closure" are worth logging.
            if websocket.IsUnexpectedCloseError(readErr, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
                log.Printf("WebSocket read error: %v", readErr)
            }
            return
        }

        if writeErr := ws.WriteMessage(kind, payload); writeErr != nil {
            log.Printf("WebSocket write error: %v", writeErr)
            return
        }
    }
}

// main registers the echo endpoint at /ws and serves it on :8080.
func main() {
    http.HandleFunc("/ws", handleWebSocket)

    srv := &http.Server{
        Addr: ":8080",
        // http.ListenAndServe uses a server with no timeouts at all, so a
        // client that trickles its request headers can hold a connection
        // open indefinitely. Bounding header reads closes that hole without
        // limiting how long an upgraded connection may live.
        ReadHeaderTimeout: 10 * time.Second,
    }
    log.Fatal(srv.ListenAndServe())
}

Client Connection

// JavaScript client example
// Connects to the echo server, greets it once the socket opens, and logs
// every message, close, and error event.
const ws = new WebSocket('ws://localhost:8080/ws');

Object.assign(ws, {
    onopen: () => {
        console.log('Connected to WebSocket');
        ws.send('Hello, Server!');
    },
    onmessage: ({ data }) => {
        console.log('Received:', data);
    },
    onclose: () => {
        console.log('Disconnected from WebSocket');
    },
    onerror: (error) => {
        console.error('WebSocket error:', error);
    },
});

Connection Management

Connection Hub

// Client is a single WebSocket connection registered with a Hub.
type Client struct {
    hub  *Hub
    conn *websocket.Conn
    // send buffers outbound messages for the client's writer; the Hub closes
    // it when the client is unregistered or falls behind.
    send chan []byte
    // userID is read by Room.join when announcing the client. It was missing
    // from the struct even though that code depends on it.
    // NOTE(review): presumably populated from the authenticated user stored
    // by WebSocketAuthMiddleware — confirm where it is set.
    userID string
    // done stops startHeartbeat when closed (or when a value is sent). That
    // code selects on it, but the field was missing from the struct.
    // NOTE(review): the code that closes it is not shown — confirm ownership.
    done chan struct{}
}

// Hub keeps the set of connected clients and fans messages out to them.
// Other goroutines interact with it by sending on its channels, which
// (*Hub).run services.
type Hub struct {
    // clients is the set of registered clients; stored values are always true.
    clients map[*Client]bool

    // broadcast carries messages to deliver to every registered client.
    broadcast chan []byte

    // register adds a client to the set.
    register chan *Client

    // unregister removes a client and closes its send channel.
    unregister chan *Client
}

// newHub returns a Hub with an empty client set and unbuffered channels.
// Because the channels are unbuffered, sends on them block until (*Hub).run
// is receiving, so run must be started before clients register.
func newHub() *Hub {
    h := new(Hub)
    h.clients = make(map[*Client]bool)
    h.broadcast = make(chan []byte)
    h.register = make(chan *Client)
    h.unregister = make(chan *Client)
    return h
}

// run is the hub's event loop. It never returns, so start it in its own
// goroutine. Registration, unregistration and broadcast are handled one at a
// time in a single select, so those operations never interleave.
//
// Removing a client (on unregister, or when its send buffer is full during a
// broadcast) closes its send channel, which tells its writer to finish.
func (h *Hub) run() {
    for {
        select {
        case joined := <-h.register:
            h.clients[joined] = true

        case leaving := <-h.unregister:
            // The broadcast branch may already have dropped this client;
            // closing its send channel a second time would panic.
            if _, registered := h.clients[leaving]; registered {
                delete(h.clients, leaving)
                close(leaving.send)
            }

        case payload := <-h.broadcast:
            for target := range h.clients {
                select {
                case target.send <- payload:
                    // queued for target's writer
                default:
                    // target is not keeping up; drop it instead of blocking the hub.
                    close(target.send)
                    delete(h.clients, target)
                }
            }
        }
    }
}

Client Handling

// readPump reads frames from the peer and forwards each payload to the hub's
// broadcast channel. Reads are capped at maxMessageSize bytes. The read
// deadline starts at pongWait and is pushed forward by pongWait on every pong.
// On the first read error it unregisters the client and closes the connection.
func (c *Client) readPump() {
    defer func() {
        c.hub.unregister <- c
        c.conn.Close()
    }()

    // Shared by the initial deadline and the pong handler.
    bumpDeadline := func(string) error {
        c.conn.SetReadDeadline(time.Now().Add(pongWait))
        return nil
    }

    c.conn.SetReadLimit(maxMessageSize)
    bumpDeadline("")
    c.conn.SetPongHandler(bumpDeadline)

    for {
        _, payload, err := c.conn.ReadMessage()
        if err == nil {
            c.hub.broadcast <- payload
            continue
        }
        // Going-away and abnormal closures are routine; log other close codes.
        if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
            log.Printf("error: %v", err)
        }
        return
    }
}

// writePump delivers queued messages from c.send to the peer as text frames
// and sends a ping every pingPeriod. Every write gets a writeWait deadline.
// When c.send is closed it sends a close frame and returns. Any write failure
// also ends it, and the connection is closed on exit.
func (c *Client) writePump() {
    pinger := time.NewTicker(pingPeriod)
    defer func() {
        pinger.Stop()
        c.conn.Close()
    }()

    for {
        select {
        case <-pinger.C:
            c.conn.SetWriteDeadline(time.Now().Add(writeWait))
            if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
                return
            }

        case payload, open := <-c.send:
            c.conn.SetWriteDeadline(time.Now().Add(writeWait))
            if !open {
                // The hub closed the channel: tell the peer and stop.
                c.conn.WriteMessage(websocket.CloseMessage, []byte{})
                return
            }

            frame, err := c.conn.NextWriter(websocket.TextMessage)
            if err != nil {
                return
            }
            frame.Write(payload)
            // Close flushes the frame; a write error surfaces here.
            if err = frame.Close(); err != nil {
                return
            }
        }
    }
}

Message Handling

Message Types

// Message is the JSON envelope exchanged with clients. Type selects how
// Payload should be interpreted. Payload is kept as raw JSON so it can be
// decoded into the matching type-specific struct later.
type Message struct {
    Type    string          `json:"type"`    // discriminator, e.g. "chat" or "presence"
    Payload json.RawMessage `json:"payload"` // undecoded body whose shape depends on Type
}

// ChatMessage is the payload of a Message whose Type is "chat".
type ChatMessage struct {
    UserID  string    `json:"user_id"` // sender's user ID
    Content string    `json:"content"` // message text
    Time    time.Time `json:"time"`    // timestamp; encoding/json renders it as RFC 3339
}

// handleMessage decodes a raw frame as a Message envelope and dispatches on
// its Type. "chat" payloads are decoded into a ChatMessage first. "presence"
// payloads are passed through undecoded. Any other type is an error, as is
// malformed JSON in either the envelope or a chat payload.
func (c *Client) handleMessage(message []byte) error {
    var envelope Message
    if err := json.Unmarshal(message, &envelope); err != nil {
        return err
    }

    switch envelope.Type {
    case "chat":
        var chat ChatMessage
        if err := json.Unmarshal(envelope.Payload, &chat); err != nil {
            return err
        }
        return c.handleChatMessage(chat)

    case "presence":
        return c.handlePresenceUpdate(envelope.Payload)
    }

    return fmt.Errorf("unknown message type: %s", envelope.Type)
}

Message Broadcasting

// Room is a named group of clients that receive each other's broadcasts.
// It has no lock of its own, so callers must serialize access to clients.
type Room struct {
    ID string // room identifier

    clients map[*Client]bool // current members (set semantics)
    hub     *Hub             // associated Hub
}

// broadcast queues message for every room member except sender.
//
// A member whose send buffer is full is removed from the room and its
// connection is closed. The room must not close client.send itself. The Hub
// also closes that channel when the client unregisters, because the client
// is still in the Hub's set, and a second close panics ("close of closed
// channel"). Closing the connection instead makes the client's readPump fail.
// Its deferred unregister then lets the Hub close send exactly once. Per the
// gorilla/websocket docs, Conn.Close is safe to call concurrently with the
// connection's reader and writer.
//
// NOTE(review): if the Hub has already closed client.send while the client
// is still a member here, the send below panics. Rooms should drop members
// when they unregister from the Hub.
//
// The room does no locking; do not call this concurrently with join or
// another broadcast on the same room.
func (r *Room) broadcast(message []byte, sender *Client) {
    for client := range r.clients {
        if client == sender {
            continue
        }
        select {
        case client.send <- message:
        default:
            delete(r.clients, client)
            client.conn.Close()
        }
    }
}

// join adds client to the room and announces it to the other members with a
// "system" Message whose payload is {"event":"join","user_id":<client.userID>}.
//
// The payload is built with json.Marshal rather than string concatenation.
// With concatenation, a userID containing a quote or backslash produced
// invalid (or injected) JSON. Marshal then failed, and the ignored error
// meant a nil message was broadcast.
func (r *Room) join(client *Client) {
    r.clients[client] = true

    payload, err := json.Marshal(struct {
        Event  string `json:"event"`
        UserID string `json:"user_id"`
    }{Event: "join", UserID: client.userID})
    if err != nil {
        log.Printf("encoding join payload: %v", err)
        return
    }

    msgBytes, err := json.Marshal(Message{Type: "system", Payload: payload})
    if err != nil {
        log.Printf("encoding join message: %v", err)
        return
    }
    r.broadcast(msgBytes, client)
}

Middleware

Authentication

// WebSocketAuthMiddleware guards next with a token check. The token is read
// from the "token" query parameter.
//   - missing token          -> 401 "Unauthorized"
//   - validateToken fails    -> 401 "Invalid token"
//   - otherwise next is called with the validated user stored in the request
//     context under the key "user".
//
// NOTE(review): a plain string context key can collide with other packages
// (staticcheck SA1029). Consider an unexported key type once every reader of
// this value is updated. Tokens in query strings may also end up in access
// logs.
func WebSocketAuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
    return func(w http.ResponseWriter, r *http.Request) {
        tok := r.URL.Query().Get("token")
        if len(tok) == 0 {
            http.Error(w, "Unauthorized", http.StatusUnauthorized)
            return
        }

        u, verr := validateToken(tok)
        if verr != nil {
            http.Error(w, "Invalid token", http.StatusUnauthorized)
            return
        }

        ctx := context.WithValue(r.Context(), "user", u)
        next(w, r.WithContext(ctx))
    }
}

Rate Limiting

// RateLimiter is a per-key sliding-window limiter. For each key, at most
// limit calls to Allow return true within any window-long interval.
//
// Set limit and window before use; counts is allocated on first use, so a
// literal without counts works. Allow holds mu for its whole body, so it is
// safe to call from multiple goroutines.
type RateLimiter struct {
    limit  int
    window time.Duration
    counts map[string][]time.Time // per-key timestamps of allowed calls, in call order
    mu     sync.Mutex
}

// Allow reports whether clientID may proceed now. If so, it records the call.
// Timestamps that have fallen out of the window are discarded first.
func (rl *RateLimiter) Allow(clientID string) bool {
    rl.mu.Lock()
    defer rl.mu.Unlock()

    // Writing to a nil map panics; allocate lazily so a literal without
    // counts is usable.
    if rl.counts == nil {
        rl.counts = make(map[string][]time.Time)
    }

    now := time.Now()
    windowStart := now.Add(-rl.window)

    // Compact the still-live timestamps to the front of the slice in place.
    times := rl.counts[clientID]
    valid := 0
    for _, t := range times {
        if t.After(windowStart) {
            times[valid] = t
            valid++
        }
    }
    times = times[:valid]

    // Store the compacted slice even when the call is denied. The loop above
    // overwrote the front of the backing array, so keeping the old length
    // would duplicate surviving timestamps and over-count later calls.
    rl.counts[clientID] = times

    if len(times) >= rl.limit {
        return false
    }

    rl.counts[clientID] = append(times, now)
    return true
}

Best Practices

1. Connection Parameters

// Connection tuning shared by the read and write paths.
const (
    // maxMessageSize is the largest message, in bytes, accepted from the peer.
    maxMessageSize = 512

    // writeWait bounds how long a single write to the peer may take.
    writeWait = 10 * time.Second

    // pongWait is how long the read deadline waits for the next pong.
    pongWait = 60 * time.Second

    // pingPeriod is 90% of pongWait, so each ping is sent (and its pong can
    // arrive) before the read deadline set from pongWait expires.
    pingPeriod = pongWait * 9 / 10
)

// configureConnection applies the standard read-side settings to conn:
//   - a maxMessageSize read limit;
//   - an initial read deadline pongWait from now;
//   - a pong handler that pushes the deadline forward by pongWait on each pong.
func configureConnection(conn *websocket.Conn) {
    refresh := func(string) error {
        conn.SetReadDeadline(time.Now().Add(pongWait))
        return nil
    }

    conn.SetReadLimit(maxMessageSize)
    refresh("")
    conn.SetPongHandler(refresh)
}

2. Error Handling

// WSError is the body of an "error" message sent to a client.
type WSError struct {
    Code    int    `json:"code"`    // application-defined error code
    Message string `json:"message"` // human-readable description
}

// sendError queues an "error" Message for the client. Its payload is the JSON
// encoding of err, e.g. {"code":5000,"message":"Internal server error"}.
//
// Message.Payload is a json.RawMessage ([]byte), so err must be marshalled
// first. Assigning the *WSError to it directly does not compile. Encoding
// failures are logged and nothing is sent.
//
// NOTE(review): the send blocks while c.send is full and panics if the hub
// has already closed c.send. Only use it for clients that are still
// registered.
func (c *Client) sendError(err *WSError) {
    payload, mErr := json.Marshal(err)
    if mErr != nil {
        log.Printf("encoding error payload: %v", mErr)
        return
    }

    msgBytes, mErr := json.Marshal(Message{Type: "error", Payload: payload})
    if mErr != nil {
        log.Printf("encoding error message: %v", mErr)
        return
    }
    c.send <- msgBytes
}

// handleWSError reacts to an error that occurred on client's connection.
//   - nil: nothing happened; ignored.
//   - *websocket.CloseError: logged unless the code is a normal closure.
//   - anything else: the client is sent a generic WSError with code 5000.
func handleWSError(err error, client *Client) {
    switch e := err.(type) {
    case nil:
        // Without this case a nil error fell through to default and sent the
        // client a spurious "Internal server error".
        return
    case *websocket.CloseError:
        if e.Code != websocket.CloseNormalClosure {
            log.Printf("WebSocket closed abnormally: %v", err)
        }
    default:
        client.sendError(&WSError{
            Code:    5000,
            Message: "Internal server error",
        })
    }
}

3. Graceful Shutdown

// Server pairs an http.Server with the Hub whose clients Shutdown closes.
type Server struct {
    hub    *Hub         // clients to disconnect on shutdown
    server *http.Server // underlying HTTP listener

    // done is awaited by Shutdown after every client has been told to close.
    // NOTE(review): the code that signals it is not shown here — confirm it
    // fires once all client goroutines have exited.
    done chan bool
}

// Shutdown stops the HTTP server and disconnects every WebSocket client:
//  1. It stops the HTTP server from accepting new connections.
//  2. It sends each client a "going away" close frame and closes its
//     connection. http.Server.Shutdown does not close hijacked connections,
//     so this step is needed.
//  3. It waits until s.done is signalled (returns nil) or ctx ends
//     (returns ctx.Err()).
//
// An error from the HTTP server's Shutdown is returned immediately.
func (s *Server) Shutdown(ctx context.Context) error {
    if err := s.server.Shutdown(ctx); err != nil {
        return err
    }

    closeMsg := websocket.FormatCloseMessage(websocket.CloseGoingAway, "")

    // NOTE(review): s.hub.clients is read here without synchronization while
    // (*Hub).run may still be mutating it. Stop the hub first, or route this
    // through the hub goroutine, to avoid a data race.
    for client := range s.hub.clients {
        // Use WriteControl, not WriteMessage. gorilla/websocket allows
        // WriteControl concurrently with the client's writePump; WriteMessage
        // would be a second concurrent writer. The deadline stops an
        // unresponsive peer from stalling shutdown. Best effort: the
        // connection is closed regardless of the outcome.
        _ = client.conn.WriteControl(websocket.CloseMessage, closeMsg, time.Now().Add(writeWait))
        client.conn.Close()
    }

    select {
    case <-s.done:
        return nil
    case <-ctx.Done():
        return ctx.Err()
    }
}

Common Patterns

1. Heartbeat

// startHeartbeat sends a ping control frame every pingPeriod, each with a
// writeWait deadline. It returns when a ping fails (after logging it) or when
// c.done becomes readable. It blocks, so callers typically run it in a
// goroutine.
func (c *Client) startHeartbeat() {
    pinger := time.NewTicker(pingPeriod)
    defer pinger.Stop()

    for {
        select {
        case <-c.done:
            return
        case <-pinger.C:
            deadline := time.Now().Add(writeWait)
            if err := c.conn.WriteControl(websocket.PingMessage, nil, deadline); err != nil {
                log.Printf("ping error: %v", err)
                return
            }
        }
    }
}

// setupPongHandler installs a pong handler that moves the connection's read
// deadline pongWait into the future each time a pong arrives.
func (c *Client) setupPongHandler() {
    onPong := func(appData string) error {
        c.conn.SetReadDeadline(time.Now().Add(pongWait))
        return nil
    }
    c.conn.SetPongHandler(onPong)
}

2. Reconnection

// Client-side reconnection logic
//
// WSClient wraps a browser WebSocket and reconnects after the socket closes.
// It waits options.reconnectInterval ms between attempts and gives up after
// options.maxReconnectAttempts consecutive attempts without a successful open.
// Callers or subclasses may define optional onOpen(event) and
// onMessage(event) hooks. Call close() to disconnect without reconnecting.
class WSClient {
    constructor(url, options = {}) {
        this.url = url;
        this.options = {
            reconnectInterval: 1000,
            maxReconnectAttempts: 5,
            ...options,
        };
        this.reconnectAttempts = 0;
        this.closedByUser = false;
        this.reconnectTimer = null;
        this.connect();
    }

    connect() {
        this.ws = new WebSocket(this.url);

        this.ws.onopen = (event) => {
            this.reconnectAttempts = 0;
            // onOpen is an optional hook; calling it unconditionally throws a
            // TypeError when it has not been defined.
            if (typeof this.onOpen === 'function') {
                this.onOpen(event);
            }
        };

        this.ws.onmessage = (event) => {
            if (typeof this.onMessage === 'function') {
                this.onMessage(event);
            }
        };

        this.ws.onclose = () => {
            // An intentional close() must not trigger a reconnect.
            if (!this.closedByUser) {
                this.reconnect();
            }
        };

        this.ws.onerror = (error) => {
            console.error('WebSocket error:', error);
        };
    }

    reconnect() {
        if (this.reconnectAttempts >= this.options.maxReconnectAttempts) {
            console.error('Max reconnection attempts reached');
            return;
        }

        this.reconnectAttempts++;
        this.reconnectTimer = setTimeout(() => {
            this.reconnectTimer = null;
            console.log('Reconnecting...');
            this.connect();
        }, this.options.reconnectInterval);
    }

    // close disconnects permanently and cancels any pending reconnect.
    close(code, reason) {
        this.closedByUser = true;
        if (this.reconnectTimer !== null) {
            clearTimeout(this.reconnectTimer);
            this.reconnectTimer = null;
        }
        this.ws.close(code, reason);
    }
}

Next Steps