sliding-sync/sync3/connmap.go

package sync3

import (
	"context"
	"sync"
	"time"

	"github.com/ReneKroon/ttlcache/v2"
	"github.com/matrix-org/sliding-sync/internal"
	"github.com/prometheus/client_golang/prometheus"
	"golang.org/x/exp/slices"
)

// ConnMap stores a collection of Conns.
type ConnMap struct {
	cache *ttlcache.Cache

	// map of user_id to active connections. Inspect the ConnID to find the device ID.
	userIDToConn map[string][]*Conn
	connIDToConn map[string]*Conn

	numConns prometheus.Gauge
	// counters for reasons why connections have expired
	expiryTimedOutCounter   prometheus.Counter
	expiryBufferFullCounter prometheus.Counter

	mu *sync.Mutex
}
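
// NewConnMap makes a ConnMap backed by a TTL cache: connections which are not used
// within `ttl` are expired and closed. If enablePrometheus is set, expiry counters
// and an active-connection gauge are registered with Prometheus.
//
// A minimal usage sketch (illustrative only; `cancel` and `newHandler` are assumed
// to be supplied by the caller, e.g. the HTTP request handler):
//
//	cm := NewConnMap(false, 30*time.Minute)
//	defer cm.Teardown()
//	conn := cm.CreateConn(ConnID{UserID: "@alice:example.org", DeviceID: "DEVICE", CID: "1"}, cancel, newHandler)
//	if c := cm.Conn(conn.ConnID); c != nil {
//		// the connection is still alive and can serve the next request
//	}
//	cm.CloseConnsForDevice("@alice:example.org", "DEVICE")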
func NewConnMap(enablePrometheus bool, ttl time.Duration) *ConnMap {
	cm := &ConnMap{
		userIDToConn: make(map[string][]*Conn),
		connIDToConn: make(map[string]*Conn),
		cache:        ttlcache.NewCache(),
		mu:           &sync.Mutex{},
	}
	cm.cache.SetTTL(ttl)
	cm.cache.SetExpirationCallback(cm.closeConnExpires)

	if enablePrometheus {
		cm.expiryTimedOutCounter = prometheus.NewCounter(prometheus.CounterOpts{
			Namespace: "sliding_sync",
			Subsystem: "api",
			Name:      "expiry_conn_timed_out",
			Help:      "Counter of expired API connections due to reaching TTL limit",
		})
		prometheus.MustRegister(cm.expiryTimedOutCounter)
		cm.expiryBufferFullCounter = prometheus.NewCounter(prometheus.CounterOpts{
			Namespace: "sliding_sync",
			Subsystem: "api",
			Name:      "expiry_conn_buffer_full",
			Help:      "Counter of expired API connections due to reaching buffer update limit",
		})
		prometheus.MustRegister(cm.expiryBufferFullCounter)
		cm.numConns = prometheus.NewGauge(prometheus.GaugeOpts{
			Namespace: "sliding_sync",
			Subsystem: "api",
			Name:      "num_active_conns",
			Help:      "Number of active sliding sync connections.",
		})
		prometheus.MustRegister(cm.numConns)
	}

	return cm
}
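
// Teardown closes the TTL cache and unregisters any Prometheus metrics.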
func (m *ConnMap) Teardown() {
	m.cache.Close()

	if m.numConns != nil {
		prometheus.Unregister(m.numConns)
	}
	if m.expiryBufferFullCounter != nil {
		prometheus.Unregister(m.expiryBufferFullCounter)
	}
	if m.expiryTimedOutCounter != nil {
		prometheus.Unregister(m.expiryTimedOutCounter)
	}
}

// UpdateMetrics recalculates the number of active connections. Do this when you think there is a change.
func (m *ConnMap) UpdateMetrics() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.updateMetrics(len(m.connIDToConn))
}

// updateMetrics is like UpdateMetrics but doesn't touch connIDToConn and hence doesn't need to lock. We use this internally
// when we need to update the metric and already have the lock held, as calling UpdateMetrics would deadlock.
func (m *ConnMap) updateMetrics(numConns int) {
	if m.numConns == nil {
		return
	}
	m.numConns.Set(float64(numConns))
}

// Conns returns all connections for this user|device pair.
func (m *ConnMap) Conns(userID, deviceID string) []*Conn {
	connIDs := m.connIDsForDevice(userID, deviceID)
	var conns []*Conn
	for _, connID := range connIDs {
		c := m.Conn(connID)
		if c != nil {
			conns = append(conns, c)
		}
	}
	return conns
}

// Conn returns a connection with this ConnID. Returns nil if no connection exists.
func (m *ConnMap) Conn(cid ConnID) *Conn {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.getConn(cid)
}

// getConn returns a connection with this ConnID. Returns nil if no connection exists. Expires connections if the buffer is full.
// Must hold mu.
func (m *ConnMap) getConn(cid ConnID) *Conn {
	cint, _ := m.cache.Get(cid.String())
	if cint == nil {
		return nil
	}
	conn := cint.(*Conn)
	if conn.Alive() {
		return conn
	}
	// e.g. buffer exceeded, close it and remove it from the cache
	logger.Info().Str("conn", cid.String()).Msg("closing connection due to dead connection (buffer full)")
	m.closeConn(conn)
	if m.expiryBufferFullCounter != nil {
		m.expiryBufferFullCounter.Inc()
	}
	return nil
}

// CreateConn atomically gets or creates a connection with this connection ID.
// It calls newConnHandler to create a fresh handler if a new connection is required.
func (m *ConnMap) CreateConn(cid ConnID, cancel context.CancelFunc, newConnHandler func() ConnHandler) *Conn {
	// atomically check if a conn exists already and nuke it if it exists
	m.mu.Lock()
	defer m.mu.Unlock()
	conn := m.getConn(cid)
	if conn != nil {
		// tear down this connection and fall through
		isSpamming := conn.lastPos <= 1
		if isSpamming {
			// the existing connection has only just been used for one response, and now they are asking
			// for a new connection. Apply an artificial delay here to stop buggy clients from spamming
			// /sync without a `?pos=` value.
			time.Sleep(SpamProtectionInterval)
		}
		logger.Trace().Str("conn", cid.String()).Bool("spamming", isSpamming).Msg("closing connection due to CreateConn called again")
		m.closeConn(conn)
	}

	h := newConnHandler()
	h.SetCancelCallback(cancel)
	conn = NewConn(cid, h)
	m.cache.Set(cid.String(), conn)
	m.connIDToConn[cid.String()] = conn
	m.userIDToConn[cid.UserID] = append(m.userIDToConn[cid.UserID], conn)
	m.updateMetrics(len(m.connIDToConn))
	return conn
}
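
// CloseConnsForDevice closes all connections for this user|device pair.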
func (m *ConnMap) CloseConnsForDevice(userID, deviceID string) {
	logger.Trace().Str("user", userID).Str("device", deviceID).Msg("closing connections due to CloseConn()")
	// gather open connections for this user|device
	connIDs := m.connIDsForDevice(userID, deviceID)
	for _, cid := range connIDs {
		err := m.cache.Remove(cid.String()) // this will fire TTL callbacks which call closeConn
		if err != nil {
			logger.Err(err).Str("cid", cid.String()).Msg("CloseConnsForDevice: cid did not exist in ttlcache")
			internal.GetSentryHubFromContextOrDefault(context.Background()).CaptureException(err)
		}
	}
}
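
// connIDsForDevice returns the IDs of all connections opened by this user|device pair.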
func (m *ConnMap) connIDsForDevice(userID, deviceID string) []ConnID {
	m.mu.Lock()
	defer m.mu.Unlock()
	var connIDs []ConnID
	conns := m.userIDToConn[userID]
	for _, c := range conns {
		if c.DeviceID == deviceID {
			connIDs = append(connIDs, c.ConnID)
		}
	}
	return connIDs
}

// CloseConnsForUsers closes all conns for a given slice of users. Returns the number of
// conns closed.
func (m *ConnMap) CloseConnsForUsers(userIDs []string) (closed int) {
	m.mu.Lock()
	defer m.mu.Unlock()
	for _, userID := range userIDs {
		conns := m.userIDToConn[userID]
		logger.Trace().Str("user", userID).Int("num_conns", len(conns)).Msg("closing all device connections due to CloseConn()")
		for _, conn := range conns {
			err := m.cache.Remove(conn.String()) // this will fire TTL callbacks which call closeConn
			if err != nil {
				logger.Err(err).Str("cid", conn.String()).Msg("CloseConnsForUsers: cid did not exist in ttlcache")
				internal.GetSentryHubFromContextOrDefault(context.Background()).CaptureException(err)
			}
		}
		closed += len(conns)
	}
	return closed
}
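
// closeConnExpires is the TTL cache's expiration callback: it closes the connection
// whose cache entry has expired or been removed.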
func (m *ConnMap) closeConnExpires(connID string, value interface{}) {
	m.mu.Lock()
	defer m.mu.Unlock()
	conn := value.(*Conn)
	logger.Info().Str("conn", connID).Msg("closing connection due to expired TTL in cache")
	if m.expiryTimedOutCounter != nil {
		m.expiryTimedOutCounter.Inc()
	}
	m.closeConn(conn)
}

// closeConn closes the given connection and removes it from the in-memory maps. The caller must hold mu.
func (m *ConnMap) closeConn(conn *Conn) {
	if conn == nil {
		return
	}

	connKey := conn.ConnID.String()
	logger.Trace().Str("conn", connKey).Msg("closing connection")
	// remove conn from all the maps
	delete(m.connIDToConn, connKey)
	h := conn.handler
	conns := m.userIDToConn[conn.UserID]
	for i := 0; i < len(conns); i++ {
		if conns[i].DeviceID == conn.DeviceID && conns[i].CID == conn.CID {
			// remove this conn from the slice
			conns[i] = nil // allow GC
			conns = slices.Delete(conns, i, i+1)
			i--
		}
	}
	m.userIDToConn[conn.UserID] = conns
	// remove user cache listeners etc
	h.Destroy()
	m.updateMetrics(len(m.connIDToConn))
}
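
// ClearUpdateQueues calls PublishEventsUpTo(roomID, nid) on every connection handler
// belonging to this user.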
func (m *ConnMap) ClearUpdateQueues(userID, roomID string, nid int64) {
	m.mu.Lock()
	defer m.mu.Unlock()
	for _, conn := range m.userIDToConn[userID] {
		conn.handler.PublishEventsUpTo(roomID, nid)
	}
}