mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
Merge pull request #382 from matrix-org/kegan/conn-map-tests
bugfix: when connections expire, only delete the affected connection
commit c31eb3e661
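Context for this fix (a summary of the diff below): closeConn matched connections by DeviceID alone, so when one connection expired, every other connection held by the same device was removed from the user's connection list along with it. The fix also matches on CID, and replaces the swap-with-last removal with slices.Delete plus an index rewind. The standalone sketch below illustrates the predicate change; the conn struct and removeMatching helper are hypothetical stand-ins, not code from this repository.

package main

import (
    "fmt"

    "golang.org/x/exp/slices" // same dependency the commit introduces
)

// conn is a hypothetical stand-in for sync3.Conn, reduced to the two
// fields that matter for the bug.
type conn struct{ DeviceID, CID string }

// removeMatching deletes every element for which match returns true,
// using the same slices.Delete + i-- pattern as the fixed closeConn.
func removeMatching(conns []conn, match func(conn) bool) []conn {
    for i := 0; i < len(conns); i++ {
        if match(conns[i]) {
            conns = slices.Delete(conns, i, i+1)
            i-- // index i now holds the next element; re-examine it
        }
    }
    return conns
}

func main() {
    device := []conn{
        {DeviceID: "A", CID: "room-list"},
        {DeviceID: "A", CID: "encryption"},
    }
    expired := conn{DeviceID: "A", CID: "room-list"} // only this conn expired

    // Old predicate: DeviceID alone. Expiring one conn nukes its siblings.
    old := removeMatching(slices.Clone(device), func(c conn) bool {
        return c.DeviceID == expired.DeviceID
    })
    fmt.Println(len(old)) // 0 - the live "encryption" conn was deleted too

    // Fixed predicate: DeviceID and CID. Only the expired conn is removed.
    fixed := removeMatching(slices.Clone(device), func(c conn) bool {
        return c.DeviceID == expired.DeviceID && c.CID == expired.CID
    })
    fmt.Println(len(fixed)) // 1 - "encryption" survives
}

Running this prints 0 then 1: with the old predicate the sibling connection is lost; with the fixed one it survives.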
sync3/connmap.go
@@ -5,7 +5,10 @@ import (
 	"sync"
 	"time"
 
+	"golang.org/x/exp/slices"
+
 	"github.com/ReneKroon/ttlcache/v2"
+	"github.com/matrix-org/sliding-sync/internal"
 	"github.com/prometheus/client_golang/prometheus"
 )
 
@@ -25,14 +28,14 @@ type ConnMap struct {
 	mu *sync.Mutex
 }
 
-func NewConnMap(enablePrometheus bool) *ConnMap {
+func NewConnMap(enablePrometheus bool, ttl time.Duration) *ConnMap {
 	cm := &ConnMap{
 		userIDToConn: make(map[string][]*Conn),
 		connIDToConn: make(map[string]*Conn),
 		cache:        ttlcache.NewCache(),
 		mu:           &sync.Mutex{},
 	}
-	cm.cache.SetTTL(30 * time.Minute) // TODO: customisable
+	cm.cache.SetTTL(ttl)
 	cm.cache.SetExpirationCallback(cm.closeConnExpires)
 
 	if enablePrometheus {
@@ -132,7 +135,7 @@ func (m *ConnMap) getConn(cid ConnID) *Conn {
 }
 
 // Atomically gets or creates a connection with this connection ID. Calls newConn if a new connection is required.
-func (m *ConnMap) CreateConn(cid ConnID, cancel context.CancelFunc, newConnHandler func() ConnHandler) (*Conn, bool) {
+func (m *ConnMap) CreateConn(cid ConnID, cancel context.CancelFunc, newConnHandler func() ConnHandler) *Conn {
 	// atomically check if a conn exists already and nuke it if it exists
 	m.mu.Lock()
 	defer m.mu.Unlock()
@@ -156,7 +159,7 @@ func (m *ConnMap) CreateConn(cid ConnID, cancel context.CancelFunc, newConnHandl
 	m.connIDToConn[cid.String()] = conn
 	m.userIDToConn[cid.UserID] = append(m.userIDToConn[cid.UserID], conn)
 	m.updateMetrics(len(m.connIDToConn))
-	return conn, true
+	return conn
 }
 
 func (m *ConnMap) CloseConnsForDevice(userID, deviceID string) {
@@ -164,7 +167,11 @@ func (m *ConnMap) CloseConnsForDevice(userID, deviceID string) {
 	// gather open connections for this user|device
 	connIDs := m.connIDsForDevice(userID, deviceID)
 	for _, cid := range connIDs {
-		m.cache.Remove(cid.String()) // this will fire TTL callbacks which calls closeConn
+		err := m.cache.Remove(cid.String()) // this will fire TTL callbacks which calls closeConn
+		if err != nil {
+			logger.Err(err).Str("cid", cid.String()).Msg("CloseConnsForDevice: cid did not exist in ttlcache")
+			internal.GetSentryHubFromContextOrDefault(context.Background()).CaptureException(err)
+		}
 	}
 }
 
@@ -191,7 +198,11 @@ func (m *ConnMap) CloseConnsForUsers(userIDs []string) (closed int) {
 		logger.Trace().Str("user", userID).Int("num_conns", len(conns)).Msg("closing all device connections due to CloseConn()")
 
 		for _, conn := range conns {
-			m.cache.Remove(conn.String()) // this will fire TTL callbacks which calls closeConn
+			err := m.cache.Remove(conn.String()) // this will fire TTL callbacks which calls closeConn
+			if err != nil {
+				logger.Err(err).Str("cid", conn.String()).Msg("CloseConnsForDevice: cid did not exist in ttlcache")
+				internal.GetSentryHubFromContextOrDefault(context.Background()).CaptureException(err)
+			}
 		}
 		closed += len(conns)
 	}
@@ -222,10 +233,11 @@ func (m *ConnMap) closeConn(conn *Conn) {
 	h := conn.handler
 	conns := m.userIDToConn[conn.UserID]
 	for i := 0; i < len(conns); i++ {
-		if conns[i].DeviceID == conn.DeviceID {
+		if conns[i].DeviceID == conn.DeviceID && conns[i].CID == conn.CID {
 			// delete without preserving order
-			conns[i] = conns[len(conns)-1]
-			conns = conns[:len(conns)-1]
+			conns[i] = nil // allow GC
+			conns = slices.Delete(conns, i, i+1)
+			i--
 		}
 	}
 	m.userIDToConn[conn.UserID] = conns
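The error handling added above leans on two ttlcache/v2 behaviours: Remove returns an error when the key is absent (now logged and reported to Sentry rather than silently ignored), and, per the comments in the diff, evicting a live entry fires the expiration callback, so all connection teardown funnels through closeConnExpires. Below is a minimal standalone sketch of that pattern, assuming the ReneKroon/ttlcache/v2 API as imported above; it is an illustration, not code from this PR.

package main

import (
    "fmt"
    "time"

    "github.com/ReneKroon/ttlcache/v2"
)

func main() {
    cache := ttlcache.NewCache()
    _ = cache.SetTTL(time.Second) // entries expire one second after insertion

    // One teardown path: the callback fires on TTL expiry, and per the
    // diff's comments it also fires when Remove evicts a live entry.
    cache.SetExpirationCallback(func(key string, value interface{}) {
        fmt.Println("expired:", key)
    })

    _ = cache.Set("conn-1", struct{}{})

    // Remove on an unknown key returns an error, which the patched
    // CloseConnsFor* functions now log and report instead of dropping.
    if err := cache.Remove("no-such-conn"); err != nil {
        fmt.Println("remove failed:", err)
    }

    time.Sleep(1500 * time.Millisecond) // give "conn-1" time to expire
    cache.Close()
}

This prints the remove error immediately and the expiry line about a second later.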
sync3/connmap_test.go (new file, 229 lines)
@@ -0,0 +1,229 @@
package sync3

import (
	"context"
	"fmt"
	"reflect"
	"sort"
	"testing"
	"time"

	"github.com/matrix-org/sliding-sync/sync3/caches"
)

const (
	alice = "@alice:localhost"
	bob   = "@bob:localhost"
)

// mustEqual ensures that got==want else logs an error.
// The 'msg' is displayed with the error to provide extra context.
func mustEqual[V comparable](t *testing.T, got, want V, msg string) {
	t.Helper()
	if got != want {
		t.Errorf("Equal %s: got '%v' want '%v'", msg, got, want)
	}
}

func TestConnMap(t *testing.T) {
	cm := NewConnMap(false, time.Minute)
	cid := ConnID{UserID: alice, DeviceID: "A", CID: "room-list"}
	_, cancel := context.WithCancel(context.Background())
	conn := cm.CreateConn(cid, cancel, func() ConnHandler {
		return &mockConnHandler{}
	})
	mustEqual(t, conn.ConnID, cid, "cid mismatch")

	// lookups work
	mustEqual(t, cm.Conn(cid), conn, "*Conn wasn't the same when fetched via Conn(ConnID)")
	conns := cm.Conns(cid.UserID, cid.DeviceID)
	mustEqual(t, len(conns), 1, "Conns length mismatch")
	mustEqual(t, conns[0], conn, "*Conn wasn't the same when fetched via Conns()[0]")
}

func TestConnMap_CloseConnsForDevice(t *testing.T) {
	cm := NewConnMap(false, time.Minute)
	otherCID := ConnID{UserID: bob, DeviceID: "A", CID: "room-list"}
	cidToConn := map[ConnID]*Conn{
		{UserID: alice, DeviceID: "A", CID: "room-list"}:     nil,
		{UserID: alice, DeviceID: "A", CID: "encryption"}:    nil,
		{UserID: alice, DeviceID: "A", CID: "notifications"}: nil,
		{UserID: alice, DeviceID: "B", CID: "room-list"}:     nil,
		{UserID: alice, DeviceID: "B", CID: "encryption"}:    nil,
		{UserID: alice, DeviceID: "B", CID: "notifications"}: nil,
		otherCID: nil,
	}
	for cid := range cidToConn {
		_, cancel := context.WithCancel(context.Background())
		conn := cm.CreateConn(cid, cancel, func() ConnHandler {
			return &mockConnHandler{}
		})
		cidToConn[cid] = conn
	}

	closedDevice := "A"
	cm.CloseConnsForDevice(alice, closedDevice)
	time.Sleep(100 * time.Millisecond) // some stuff happens asyncly in goroutines

	// Destroy should have been called for all alice|A connections
	assertDestroyedConns(t, cidToConn, func(cid ConnID) bool {
		return cid.UserID == alice && cid.DeviceID == "A"
	})
}

func TestConnMap_CloseConnsForUser(t *testing.T) {
	cm := NewConnMap(false, time.Minute)
	otherCID := ConnID{UserID: bob, DeviceID: "A", CID: "room-list"}
	cidToConn := map[ConnID]*Conn{
		{UserID: alice, DeviceID: "A", CID: "room-list"}:     nil,
		{UserID: alice, DeviceID: "A", CID: "encryption"}:    nil,
		{UserID: alice, DeviceID: "A", CID: "notifications"}: nil,
		{UserID: alice, DeviceID: "B", CID: "room-list"}:     nil,
		{UserID: alice, DeviceID: "B", CID: "encryption"}:    nil,
		{UserID: alice, DeviceID: "B", CID: "notifications"}: nil,
		otherCID: nil,
	}
	for cid := range cidToConn {
		_, cancel := context.WithCancel(context.Background())
		conn := cm.CreateConn(cid, cancel, func() ConnHandler {
			return &mockConnHandler{}
		})
		cidToConn[cid] = conn
	}

	num := cm.CloseConnsForUsers([]string{alice})
	time.Sleep(100 * time.Millisecond) // some stuff happens asyncly in goroutines
	mustEqual(t, num, 6, "unexpected number of closed conns")

	// Destroy should have been called for all alice connections
	assertDestroyedConns(t, cidToConn, func(cid ConnID) bool {
		return cid.UserID == alice
	})
}

func TestConnMap_TTLExpiry(t *testing.T) {
	cm := NewConnMap(false, time.Second) // 1s expiry
	expiredCIDs := []ConnID{
		{UserID: alice, DeviceID: "A", CID: "room-list"},
		{UserID: alice, DeviceID: "A", CID: "encryption"},
		{UserID: alice, DeviceID: "A", CID: "notifications"},
	}
	cidToConn := map[ConnID]*Conn{}
	for _, cid := range expiredCIDs {
		_, cancel := context.WithCancel(context.Background())
		conn := cm.CreateConn(cid, cancel, func() ConnHandler {
			return &mockConnHandler{}
		})
		cidToConn[cid] = conn
	}
	time.Sleep(time.Millisecond * 500)

	unexpiredCIDs := []ConnID{
		{UserID: alice, DeviceID: "B", CID: "room-list"},
		{UserID: alice, DeviceID: "B", CID: "encryption"},
		{UserID: alice, DeviceID: "B", CID: "notifications"},
	}
	for _, cid := range unexpiredCIDs {
		_, cancel := context.WithCancel(context.Background())
		conn := cm.CreateConn(cid, cancel, func() ConnHandler {
			return &mockConnHandler{}
		})
		cidToConn[cid] = conn
	}

	time.Sleep(510 * time.Millisecond) // all 'A' device conns must have expired

	// Destroy should have been called for all alice|A connections
	assertDestroyedConns(t, cidToConn, func(cid ConnID) bool {
		return cid.DeviceID == "A"
	})
}

func TestConnMap_TTLExpiryStaggeredDevices(t *testing.T) {
	cm := NewConnMap(false, time.Second) // 1s expiry
	expiredCIDs := []ConnID{
		{UserID: alice, DeviceID: "A", CID: "room-list"},
		{UserID: alice, DeviceID: "B", CID: "encryption"},
		{UserID: alice, DeviceID: "B", CID: "notifications"},
	}
	cidToConn := map[ConnID]*Conn{}
	for _, cid := range expiredCIDs {
		_, cancel := context.WithCancel(context.Background())
		conn := cm.CreateConn(cid, cancel, func() ConnHandler {
			return &mockConnHandler{}
		})
		cidToConn[cid] = conn
	}
	time.Sleep(time.Millisecond * 500)

	unexpiredCIDs := []ConnID{
		{UserID: alice, DeviceID: "B", CID: "room-list"},
		{UserID: alice, DeviceID: "A", CID: "encryption"},
		{UserID: alice, DeviceID: "A", CID: "notifications"},
	}
	for _, cid := range unexpiredCIDs {
		_, cancel := context.WithCancel(context.Background())
		conn := cm.CreateConn(cid, cancel, func() ConnHandler {
			return &mockConnHandler{}
		})
		cidToConn[cid] = conn
	}

	// all expiredCIDs should have expired, none from unexpiredCIDs
	time.Sleep(510 * time.Millisecond)

	// Destroy should have been called for all expiredCIDs connections
	assertDestroyedConns(t, cidToConn, func(cid ConnID) bool {
		for _, expCID := range expiredCIDs {
			if expCID.String() == cid.String() {
				return true
			}
		}
		return false
	})

	// double check this by querying connmap
	conns := cm.Conns(alice, "A")
	var gotIDs []string
	for _, c := range conns {
		t.Logf(c.String())
		gotIDs = append(gotIDs, c.CID)
	}
	sort.Strings(gotIDs)
	wantIDs := []string{"encryption", "notifications"}
	mustEqual(t, len(conns), 2, "unexpected number of Conns for device")
	if !reflect.DeepEqual(gotIDs, wantIDs) {
		t.Fatalf("unexpected active conns: got %v want %v", gotIDs, wantIDs)
	}
}

func assertDestroyedConns(t *testing.T, cidToConn map[ConnID]*Conn, isDestroyedFn func(cid ConnID) bool) {
	t.Helper()
	for cid, conn := range cidToConn {
		if isDestroyedFn(cid) {
			mustEqual(t, conn.handler.(*mockConnHandler).isDestroyed, true, fmt.Sprintf("conn %+v was not destroyed", cid))
		} else {
			mustEqual(t, conn.handler.(*mockConnHandler).isDestroyed, false, fmt.Sprintf("conn %+v was destroyed", cid))
		}
	}
}

type mockConnHandler struct {
	isDestroyed bool
	cancel      context.CancelFunc
}

func (c *mockConnHandler) OnIncomingRequest(ctx context.Context, cid ConnID, req *Request, isInitial bool, start time.Time) (*Response, error) {
	return nil, nil
}
func (c *mockConnHandler) OnUpdate(ctx context.Context, update caches.Update) {}
func (c *mockConnHandler) PublishEventsUpTo(roomID string, nid int64)         {}
func (c *mockConnHandler) Destroy() {
	c.isDestroyed = true
}
func (c *mockConnHandler) Alive() bool {
	return true // buffer never fills up
}
func (c *mockConnHandler) SetCancelCallback(cancel context.CancelFunc) {
	c.cancel = cancel
}
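A note on the fixed time.Sleep calls in these tests: expiration callbacks fire on ttlcache's own goroutines, so the tests sleep briefly before asserting. A poll-based wait like the hypothetical helper below (not part of this PR; it would slot into the same _test.go file and reuses its testing and time imports) is one way to make such assertions less sensitive to scheduler timing.

// waitUntil is a hypothetical test helper (not part of this PR): it polls
// cond until it returns true or the deadline passes, instead of sleeping
// for a fixed duration before asserting.
func waitUntil(t *testing.T, timeout time.Duration, cond func() bool) {
	t.Helper()
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		if cond() {
			return
		}
		time.Sleep(10 * time.Millisecond)
	}
	t.Fatalf("condition not met within %v", timeout)
}

For example, the 100ms sleep in TestConnMap_CloseConnsForDevice could become waitUntil(t, time.Second, func() bool { return len(cm.Conns(alice, "A")) == 0 }).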
sync3/handler/handler.go
@@ -80,7 +80,7 @@ func NewSync3Handler(
 		V2:          v2Client,
 		Storage:     store,
 		V2Store:     storev2,
-		ConnMap:     sync3.NewConnMap(enablePrometheus),
+		ConnMap:     sync3.NewConnMap(enablePrometheus, 30*time.Minute),
 		userCaches:  &sync.Map{},
 		Dispatcher:  sync3.NewDispatcher(),
 		GlobalCache: caches.NewGlobalCache(store),
@@ -453,14 +453,10 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, cancel context.Canc
 	// because we *either* do the existing check *or* make a new conn. It's important for CreateConn
 	// to check for an existing connection though, as it's possible for the client to call /sync
 	// twice for a new connection.
-	conn, created := h.ConnMap.CreateConn(connID, cancel, func() sync3.ConnHandler {
+	conn = h.ConnMap.CreateConn(connID, cancel, func() sync3.ConnHandler {
 		return NewConnState(token.UserID, token.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.setupHistVec, h.histVec, h.maxPendingEventUpdates, h.maxTransactionIDDelay)
 	})
-	if created {
-		log.Info().Msg("created new connection")
-	} else {
-		log.Info().Msg("using existing connection")
-	}
+	log.Info().Msg("created new connection")
 	return req, conn, nil
 }