WebSockets in Go with Gorilla WebSocket
Jun 8, 2019
2 minute read

Learn how to use WebSockets in Go to build a real-time application with Gorilla WebSocket.

hub.go

package hub

import (
	"log"
	"net/http"
	"sync"

	"github.com/faizalpribadi/learn/client"
	"github.com/gorilla/websocket"
)

// upgrader holds the WebSocket handshake settings copied into every hub:
// 1 KiB read and write buffers and an origin check that approves any request.
//
// NOTE(review): approving every Origin lets any web page open a WebSocket to
// this server on a visitor's behalf (cross-site WebSocket hijacking). Before
// exposing this publicly, replace allowAnyOrigin with an allow-list, or leave
// CheckOrigin nil to get gorilla's default same-origin check.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1 << 10,
	WriteBufferSize: 1 << 10,
	CheckOrigin:     allowAnyOrigin,
}

// allowAnyOrigin is an Upgrader.CheckOrigin hook that accepts every request.
func allowAnyOrigin(*http.Request) bool { return true }

// Hub accepts WebSocket connections over HTTP and keeps track of the resulting
// clients, looked up by user ID.
type Hub interface {
	// Serve returns an HTTP handler that upgrades requests to WebSocket
	// connections; clients created by the handler are handed to the hub
	// through the given channel.
	Serve(chan client.Client) func(http.ResponseWriter, *http.Request)
	// HasClient reports whether a client is registered for userId.
	HasClient(userId string) bool
	// GetClient returns the client registered for userId and whether one was
	// found.
	GetClient(userId string) (client.Client, bool)
	// Stop shuts the hub down so its handler stops accepting new connections.
	Stop()
}

// hub is the default Hub implementation. It keeps at most one client per user
// ID; a newer connection for the same ID replaces the older map entry (the
// replaced client is not destroyed here).
//
// connections is written by the listener goroutine and read by HTTP handler
// goroutines, and closed is written by Stop and read by handlers, so both are
// guarded by mu.
//
// NOTE(review): nothing in this package reads from the connections, so gorilla
// never processes the peer's pong/close frames and clients are never removed
// from connections when their peer disconnects. TODO add a per-client read
// loop that unregisters the client on error.
type hub struct {
	connections map[string]client.Client
	upgrader    websocket.Upgrader
	chClose     chan struct{}
	closed      bool
	mu          sync.RWMutex
}

// NewHub returns a Hub that uses the package-level upgrader settings.
func NewHub() Hub {
	return &hub{
		connections: make(map[string]client.Client),
		upgrader:    upgrader,
		chClose:     make(chan struct{}),
	}
}

// startClientListener registers every client received on chClient until the
// hub is stopped or chClient is closed. Serve runs it on its own goroutine.
//
// chClient belongs to the caller of Serve, so it is not closed here: closing it
// would make any handler still sending on it panic.
func (h *hub) startClientListener(chClient chan client.Client) {
	for {
		select {
		case c, ok := <-chClient:
			if !ok {
				return
			}
			h.mu.Lock()
			h.connections[c.ID()] = c
			h.mu.Unlock()
		case <-h.chClose:
			// A bare break would only leave the select, not the loop.
			return
		}
	}
}

// Stop marks the hub closed, so its handler rejects new requests, and ends the
// listener goroutine. It never blocks and is safe to call more than once and
// from any goroutine. Clients already registered are left untouched.
func (h *hub) Stop() {
	h.mu.Lock()
	defer h.mu.Unlock()
	if h.closed {
		return
	}
	h.closed = true
	// Closing (rather than sending) wakes the listener and every handler
	// waiting in Serve, and needs no receiver to be present.
	close(h.chClose)
}

// isClosed reports whether Stop has been called.
func (h *hub) isClosed() bool {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.closed
}

// Serve starts the client listener and returns an HTTP handler that upgrades
// requests carrying a non-empty "user_id" query parameter to WebSocket
// connections and hands the resulting clients to the listener via chClient.
//
// Registration is asynchronous: the handler may return slightly before
// HasClient reports the new client.
func (h *hub) Serve(chClient chan client.Client) func(http.ResponseWriter, *http.Request) {
	go h.startClientListener(chClient)
	log.Println("Hub is serving...")
	return func(writer http.ResponseWriter, req *http.Request) {
		userID := req.URL.Query().Get("user_id")
		switch {
		case h.isClosed():
			http.Error(writer, "Server was stopped", http.StatusInternalServerError)
			return
		case userID == "":
			http.Error(writer, "Unauthorized", http.StatusUnauthorized)
			return
		}
		conn, err := h.upgrader.Upgrade(writer, req, nil)
		if err != nil {
			// Upgrade has already replied with an HTTP error; writing a second
			// response here would only produce a "superfluous WriteHeader".
			log.Println("Failed to upgrade to websocket", err.Error())
			return
		}
		log.Printf("Connected user: %s\n", userID)
		c := client.NewClient(conn, userID)
		select {
		case chClient <- c:
		case <-h.chClose:
			// The hub stopped while we were upgrading; the listener is gone and
			// would never receive c, so tear it down instead of blocking forever.
			c.Destroy()
			_ = conn.Close() // best effort; the client may already have closed it
		}
	}
}

// HasClient reports whether a client is registered for userID.
func (h *hub) HasClient(userID string) bool {
	_, ok := h.GetClient(userID)
	return ok
}

// GetClient returns the client registered for userID, or (nil, false) when
// there is none.
func (h *hub) GetClient(userID string) (client.Client, bool) {
	h.mu.RLock()
	defer h.mu.RUnlock()
	c, ok := h.connections[userID]
	return c, ok
}

client.go

package client

import (
	"fmt"
	"time"

	"github.com/gorilla/websocket"
)

// ErrWriteTimeout is a sentinel error describing a write that did not finish
// within a client's write timeout.
//
// NOTE(review): nothing shown here returns it; Send passes gorilla's own write
// error through unchanged. Confirm whether callers rely on it before using
// errors.Is against it.
var ErrWriteTimeout = fmt.Errorf("WriteTimeout")

// Client is a single WebSocket peer identified by a string ID.
type Client interface {
	// ID returns the identifier the client was created with.
	ID() string
	// Ping sends ping frames to the peer; the implementation in this package
	// does so in a loop until Destroy is called.
	Ping()
	// Destroy tears the client down; exactly what is released is up to the
	// implementation.
	Destroy()
	// Send writes a value to the peer (as JSON in this package's
	// implementation) and returns any write error.
	Send(interface{}) error
}

// pingInterval is how often a client's background loop sends a ping frame.
const pingInterval = 100 * time.Millisecond

// client is the default Client implementation, wrapping one gorilla
// WebSocket connection. chClose carries the stop signal for the ping loop; it
// has a one-slot buffer so Destroy never blocks, even when the loop has
// already exited or Destroy is called again.
type client struct {
	conn         *websocket.Conn
	id           string
	writeTimeout time.Duration
	chClose      chan struct{}
}

// NewClient wraps conn as the client identified by userID and starts its ping
// loop on a new goroutine. Every write (pings and Send) gives up after 100ms.
func NewClient(conn *websocket.Conn, userID string) Client {
	c := &client{
		conn:         conn,
		id:           userID,
		writeTimeout: 100 * time.Millisecond,
		chClose:      make(chan struct{}, 1),
	}
	go c.Ping()
	return c
}

// ID returns the user ID the client was created with.
func (c *client) ID() string {
	return c.id
}

// Ping sends a ping control frame every pingInterval until Destroy is called
// or a ping cannot be written within writeTimeout, which means the connection
// is unusable.
//
// WriteControl is used because gorilla allows it concurrently with the other
// write methods, so this loop does not race with Send.
func (c *client) Ping() {
	ticker := time.NewTicker(pingInterval)
	defer ticker.Stop()
	for {
		select {
		case <-c.chClose:
			return
		case <-ticker.C:
			deadline := time.Now().Add(c.writeTimeout)
			if err := c.conn.WriteControl(websocket.PingMessage, nil, deadline); err != nil {
				return
			}
		}
	}
}

// Send writes data to the peer as one JSON message, giving up after
// writeTimeout. Gorilla permits only one concurrent writer, so Send must not be
// called from several goroutines at once.
func (c *client) Send(data interface{}) error {
	if err := c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)); err != nil {
		return err
	}
	return c.conn.WriteJSON(data)
}

// Destroy stops the ping loop and closes the underlying connection. It never
// blocks and may safely be called more than once.
func (c *client) Destroy() {
	select {
	case c.chClose <- struct{}{}:
	default:
		// A stop signal is already pending; the ping loop will see it.
	}
	// Close may be called concurrently with other Conn methods; a repeated
	// call only returns an error, which is not actionable here.
	_ = c.conn.Close()
}

main.go

// main serves the notification hub's WebSocket endpoint at /notification on
// port 8080.
func main() {
	r := mux.NewRouter()
	// Clients upgraded by the hub's handler reach its listener goroutine
	// through this channel.
	chClient := make(chan client.Client)
	// A local name is used because a package-level var named hub would clash
	// with the imported hub package (and refer to itself in its initializer).
	notifications := hub.NewHub()
	r.HandleFunc("/notification", notifications.Serve(chClient))
	// ListenAndServe only returns on failure (e.g. the port is in use); surface
	// it instead of exiting silently.
	if err := http.ListenAndServe(":8080", r); err != nil {
		panic(err)
	}
}