package sync3

import (
	"context"
	"sync"
	"time"

	"golang.org/x/exp/slices"

	"github.com/ReneKroon/ttlcache/v2"
	"github.com/matrix-org/sliding-sync/internal"
	"github.com/prometheus/client_golang/prometheus"
)

// ConnMap stores a collection of Conns.
type ConnMap struct {
	cache *ttlcache.Cache

	// map of user_id to active connections. Inspect the ConnID to find the device ID.
	userIDToConn map[string][]*Conn
	connIDToConn map[string]*Conn

	numConns prometheus.Gauge
	// counters for reasons why connections have expired
	expiryTimedOutCounter   prometheus.Counter
	expiryBufferFullCounter prometheus.Counter

	mu *sync.Mutex
}
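
// NewConnMap creates a ConnMap. Connections are stored in a TTL cache and are closed when
// their entry expires after ttl. If enablePrometheus is set, a gauge for the number of active
// connections and counters for the expiry reasons are registered with the default Prometheus
// registerer; Teardown unregisters them again.
//
// Illustrative usage sketch (the TTL value, cancel func and handler constructor below are
// assumptions, not part of this file):
//
//	cm := NewConnMap(false, 30*time.Minute)
//	defer cm.Teardown()
//	conn := cm.CreateConn(connID, cancel, func() ConnHandler { return myHandler })
//	_ = conn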
func NewConnMap(enablePrometheus bool, ttl time.Duration) *ConnMap {
	cm := &ConnMap{
		userIDToConn: make(map[string][]*Conn),
		connIDToConn: make(map[string]*Conn),
		cache:        ttlcache.NewCache(),
		mu:           &sync.Mutex{},
	}
	cm.cache.SetTTL(ttl)
	cm.cache.SetExpirationCallback(cm.closeConnExpires)

	if enablePrometheus {
		cm.expiryTimedOutCounter = prometheus.NewCounter(prometheus.CounterOpts{
			Namespace: "sliding_sync",
			Subsystem: "api",
			Name:      "expiry_conn_timed_out",
			Help:      "Counter of expired API connections due to reaching TTL limit",
		})
		prometheus.MustRegister(cm.expiryTimedOutCounter)
		cm.expiryBufferFullCounter = prometheus.NewCounter(prometheus.CounterOpts{
			Namespace: "sliding_sync",
			Subsystem: "api",
			Name:      "expiry_conn_buffer_full",
			Help:      "Counter of expired API connections due to reaching buffer update limit",
		})
		prometheus.MustRegister(cm.expiryBufferFullCounter)
		cm.numConns = prometheus.NewGauge(prometheus.GaugeOpts{
			Namespace: "sliding_sync",
			Subsystem: "api",
			Name:      "num_active_conns",
			Help:      "Number of active sliding sync connections.",
		})
		prometheus.MustRegister(cm.numConns)
	}
	return cm
}
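
// Teardown closes the TTL cache and unregisters any Prometheus metrics registered by NewConnMap.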
func (m *ConnMap) Teardown() {
	m.cache.Close()

	if m.numConns != nil {
		prometheus.Unregister(m.numConns)
	}
	if m.expiryBufferFullCounter != nil {
		prometheus.Unregister(m.expiryBufferFullCounter)
	}
	if m.expiryTimedOutCounter != nil {
		prometheus.Unregister(m.expiryTimedOutCounter)
	}
}

// UpdateMetrics recalculates the number of active connections. Do this when you think there is a change.
func (m *ConnMap) UpdateMetrics() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.updateMetrics(len(m.connIDToConn))
}

// updateMetrics is like UpdateMetrics but doesn't touch connIDToConn and hence doesn't need to lock. We use this internally
// when we need to update the metric and already have the lock held, as calling UpdateMetrics would deadlock.
func (m *ConnMap) updateMetrics(numConns int) {
	if m.numConns == nil {
		return
	}
	m.numConns.Set(float64(numConns))
}

// Conns returns all connections for this user|device pair.
func (m *ConnMap) Conns(userID, deviceID string) []*Conn {
	connIDs := m.connIDsForDevice(userID, deviceID)
	var conns []*Conn
	for _, connID := range connIDs {
		c := m.Conn(connID)
		if c != nil {
			conns = append(conns, c)
		}
	}
	return conns
}

// Conn returns a connection with this ConnID. Returns nil if no connection exists.
func (m *ConnMap) Conn(cid ConnID) *Conn {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.getConn(cid)
}

// getConn returns a connection with this ConnID. Returns nil if no connection exists. Expires connections if the buffer is full.
// Must hold mu.
func (m *ConnMap) getConn(cid ConnID) *Conn {
	cint, _ := m.cache.Get(cid.String())
	if cint == nil {
		return nil
	}
	conn := cint.(*Conn)
	if conn.Alive() {
		return conn
	}
	// e.g. buffer exceeded: close it and remove it from the cache
	logger.Info().Str("conn", cid.String()).Msg("closing connection due to dead connection (buffer full)")
	m.closeConn(conn)
	if m.expiryBufferFullCounter != nil {
		m.expiryBufferFullCounter.Inc()
	}
	return nil
}

// CreateConn atomically gets or creates a connection with this connection ID.
// Calls newConnHandler if a new connection is required.
func (m *ConnMap) CreateConn(cid ConnID, cancel context.CancelFunc, newConnHandler func() ConnHandler) *Conn {
	// atomically check if a conn exists already and nuke it if it exists
	m.mu.Lock()
	defer m.mu.Unlock()
	conn := m.getConn(cid)
	if conn != nil {
		// tear down this connection and fall through
		isSpamming := conn.lastPos <= 1
		if isSpamming {
			// the existing connection has only just been used for one response, and now they are asking
			// for a new connection. Apply an artificial delay here to stop buggy clients from spamming
			// /sync without a `?pos=` value.
			time.Sleep(SpamProtectionInterval)
		}
		logger.Trace().Str("conn", cid.String()).Bool("spamming", isSpamming).Msg("closing connection due to CreateConn called again")
		m.closeConn(conn)
	}
	h := newConnHandler()
	h.SetCancelCallback(cancel)
	conn = NewConn(cid, h)
	m.cache.Set(cid.String(), conn)
	m.connIDToConn[cid.String()] = conn
	m.userIDToConn[cid.UserID] = append(m.userIDToConn[cid.UserID], conn)
	m.updateMetrics(len(m.connIDToConn))
	return conn
}
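
// CloseConnsForDevice closes all open connections for this user|device pair.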
func (m *ConnMap) CloseConnsForDevice(userID, deviceID string) {
	logger.Trace().Str("user", userID).Str("device", deviceID).Msg("closing connections due to CloseConn()")
	// gather open connections for this user|device
	connIDs := m.connIDsForDevice(userID, deviceID)
	for _, cid := range connIDs {
		err := m.cache.Remove(cid.String()) // this will fire TTL callbacks, which call closeConn
		if err != nil {
			logger.Err(err).Str("cid", cid.String()).Msg("CloseConnsForDevice: cid did not exist in ttlcache")
			internal.GetSentryHubFromContextOrDefault(context.Background()).CaptureException(err)
		}
	}
}
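
// connIDsForDevice returns the IDs of the active connections for this user|device pair.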
func (m *ConnMap) connIDsForDevice(userID, deviceID string) []ConnID {
	m.mu.Lock()
	defer m.mu.Unlock()
	var connIDs []ConnID
	conns := m.userIDToConn[userID]
	for _, c := range conns {
		if c.DeviceID == deviceID {
			connIDs = append(connIDs, c.ConnID)
		}
	}
	return connIDs
}

// CloseConnsForUsers closes all conns for a given slice of users. Returns the number of
// conns closed.
func (m *ConnMap) CloseConnsForUsers(userIDs []string) (closed int) {
	m.mu.Lock()
	defer m.mu.Unlock()
	for _, userID := range userIDs {
		conns := m.userIDToConn[userID]
		logger.Trace().Str("user", userID).Int("num_conns", len(conns)).Msg("closing all device connections due to CloseConn()")

		for _, conn := range conns {
			err := m.cache.Remove(conn.String()) // this will fire TTL callbacks, which call closeConn
			if err != nil {
				logger.Err(err).Str("cid", conn.String()).Msg("CloseConnsForUsers: cid did not exist in ttlcache")
				internal.GetSentryHubFromContextOrDefault(context.Background()).CaptureException(err)
			}
		}
		closed += len(conns)
	}
	return closed
}
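
// closeConnExpires is the ttlcache expiration callback: it closes the expired connection
// and bumps the timed-out expiry counter.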
func (m *ConnMap) closeConnExpires(connID string, value interface{}) {
	m.mu.Lock()
	defer m.mu.Unlock()
	conn := value.(*Conn)
	logger.Info().Str("conn", connID).Msg("closing connection due to expired TTL in cache")
	if m.expiryTimedOutCounter != nil {
		m.expiryTimedOutCounter.Inc()
	}
	m.closeConn(conn)
}
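
// closeConn removes the connection from the connID and userID maps and destroys its handler.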
// must hold mu
func (m *ConnMap) closeConn(conn *Conn) {
	if conn == nil {
		return
	}

	connKey := conn.ConnID.String()
	logger.Trace().Str("conn", connKey).Msg("closing connection")
	// remove conn from all the maps
	delete(m.connIDToConn, connKey)
	h := conn.handler
	conns := m.userIDToConn[conn.UserID]
	for i := 0; i < len(conns); i++ {
		if conns[i].DeviceID == conn.DeviceID && conns[i].CID == conn.CID {
			// delete without preserving order
			conns[i] = nil // allow GC
			conns = slices.Delete(conns, i, i+1)
			i--
		}
	}
	m.userIDToConn[conn.UserID] = conns
	// remove user cache listeners etc
	h.Destroy()
	m.updateMetrics(len(m.connIDToConn))
}
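
// ClearUpdateQueues instructs every connection for this user to publish events for the
// given room up to and including the given NID.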
func (m *ConnMap) ClearUpdateQueues(userID, roomID string, nid int64) {
	m.mu.Lock()
	defer m.mu.Unlock()

	for _, conn := range m.userIDToConn[userID] {
		conn.handler.PublishEventsUpTo(roomID, nid)
	}
}