2023-11-24 13:18:24 +00:00
|
|
|
package sync3
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"fmt"
|
2023-11-24 15:19:08 +00:00
|
|
|
"reflect"
|
|
|
|
"sort"
|
2024-03-11 10:22:18 +00:00
|
|
|
"sync/atomic"
|
2023-11-24 13:18:24 +00:00
|
|
|
"testing"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/matrix-org/sliding-sync/sync3/caches"
|
|
|
|
)
|
|
|
|
|
|
|
|
// Matrix user IDs shared by the tests in this file.
const (
	// alice owns the connections that the tests create, close and expire.
	alice = "@alice:localhost"
	// bob owns the unrelated connection used to check that other users are unaffected.
	bob = "@bob:localhost"
)
|
|
|
|
|
2023-11-24 15:25:53 +00:00
|
|
|
// mustEqual compares got against want and, when they differ, records a test
// error via t.Errorf; the test keeps running afterwards.
// msg is embedded in the error text to identify which comparison failed.
func mustEqual[V comparable](t *testing.T, got, want V, msg string) {
	t.Helper()
	if got == want {
		return
	}
	t.Errorf("Equal %s: got '%v' want '%v'", msg, got, want)
}
|
|
|
|
|
2023-11-24 13:18:24 +00:00
|
|
|
func TestConnMap(t *testing.T) {
|
|
|
|
cm := NewConnMap(false, time.Minute)
|
|
|
|
cid := ConnID{UserID: alice, DeviceID: "A", CID: "room-list"}
|
|
|
|
_, cancel := context.WithCancel(context.Background())
|
|
|
|
conn := cm.CreateConn(cid, cancel, func() ConnHandler {
|
|
|
|
return &mockConnHandler{}
|
|
|
|
})
|
2023-11-24 15:25:53 +00:00
|
|
|
mustEqual(t, conn.ConnID, cid, "cid mismatch")
|
2023-11-24 13:18:24 +00:00
|
|
|
|
|
|
|
// lookups work
|
2023-11-24 15:25:53 +00:00
|
|
|
mustEqual(t, cm.Conn(cid), conn, "*Conn wasn't the same when fetched via Conn(ConnID)")
|
2023-11-24 13:18:24 +00:00
|
|
|
conns := cm.Conns(cid.UserID, cid.DeviceID)
|
2023-11-24 15:25:53 +00:00
|
|
|
mustEqual(t, len(conns), 1, "Conns length mismatch")
|
|
|
|
mustEqual(t, conns[0], conn, "*Conn wasn't the same when fetched via Conns()[0]")
|
2023-11-24 13:18:24 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
func TestConnMap_CloseConnsForDevice(t *testing.T) {
|
|
|
|
cm := NewConnMap(false, time.Minute)
|
|
|
|
otherCID := ConnID{UserID: bob, DeviceID: "A", CID: "room-list"}
|
|
|
|
cidToConn := map[ConnID]*Conn{
|
|
|
|
{UserID: alice, DeviceID: "A", CID: "room-list"}: nil,
|
|
|
|
{UserID: alice, DeviceID: "A", CID: "encryption"}: nil,
|
|
|
|
{UserID: alice, DeviceID: "A", CID: "notifications"}: nil,
|
|
|
|
{UserID: alice, DeviceID: "B", CID: "room-list"}: nil,
|
|
|
|
{UserID: alice, DeviceID: "B", CID: "encryption"}: nil,
|
|
|
|
{UserID: alice, DeviceID: "B", CID: "notifications"}: nil,
|
|
|
|
otherCID: nil,
|
|
|
|
}
|
|
|
|
for cid := range cidToConn {
|
|
|
|
_, cancel := context.WithCancel(context.Background())
|
|
|
|
conn := cm.CreateConn(cid, cancel, func() ConnHandler {
|
|
|
|
return &mockConnHandler{}
|
|
|
|
})
|
|
|
|
cidToConn[cid] = conn
|
|
|
|
}
|
|
|
|
|
|
|
|
closedDevice := "A"
|
|
|
|
cm.CloseConnsForDevice(alice, closedDevice)
|
|
|
|
time.Sleep(100 * time.Millisecond) // some stuff happens asyncly in goroutines
|
|
|
|
|
|
|
|
// Destroy should have been called for all alice|A connections
|
|
|
|
assertDestroyedConns(t, cidToConn, func(cid ConnID) bool {
|
|
|
|
return cid.UserID == alice && cid.DeviceID == "A"
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestConnMap_CloseConnsForUser(t *testing.T) {
|
|
|
|
cm := NewConnMap(false, time.Minute)
|
|
|
|
otherCID := ConnID{UserID: bob, DeviceID: "A", CID: "room-list"}
|
|
|
|
cidToConn := map[ConnID]*Conn{
|
|
|
|
{UserID: alice, DeviceID: "A", CID: "room-list"}: nil,
|
|
|
|
{UserID: alice, DeviceID: "A", CID: "encryption"}: nil,
|
|
|
|
{UserID: alice, DeviceID: "A", CID: "notifications"}: nil,
|
|
|
|
{UserID: alice, DeviceID: "B", CID: "room-list"}: nil,
|
|
|
|
{UserID: alice, DeviceID: "B", CID: "encryption"}: nil,
|
|
|
|
{UserID: alice, DeviceID: "B", CID: "notifications"}: nil,
|
|
|
|
otherCID: nil,
|
|
|
|
}
|
|
|
|
for cid := range cidToConn {
|
|
|
|
_, cancel := context.WithCancel(context.Background())
|
|
|
|
conn := cm.CreateConn(cid, cancel, func() ConnHandler {
|
|
|
|
return &mockConnHandler{}
|
|
|
|
})
|
|
|
|
cidToConn[cid] = conn
|
|
|
|
}
|
|
|
|
|
|
|
|
num := cm.CloseConnsForUsers([]string{alice})
|
|
|
|
time.Sleep(100 * time.Millisecond) // some stuff happens asyncly in goroutines
|
2023-11-24 15:25:53 +00:00
|
|
|
mustEqual(t, num, 6, "unexpected number of closed conns")
|
2023-11-24 13:18:24 +00:00
|
|
|
|
2023-11-24 15:23:35 +00:00
|
|
|
// Destroy should have been called for all alice connections
|
2023-11-24 13:18:24 +00:00
|
|
|
assertDestroyedConns(t, cidToConn, func(cid ConnID) bool {
|
|
|
|
return cid.UserID == alice
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestConnMap_TTLExpiry(t *testing.T) {
|
|
|
|
cm := NewConnMap(false, time.Second) // 1s expiry
|
|
|
|
expiredCIDs := []ConnID{
|
|
|
|
{UserID: alice, DeviceID: "A", CID: "room-list"},
|
|
|
|
{UserID: alice, DeviceID: "A", CID: "encryption"},
|
|
|
|
{UserID: alice, DeviceID: "A", CID: "notifications"},
|
|
|
|
}
|
|
|
|
cidToConn := map[ConnID]*Conn{}
|
|
|
|
for _, cid := range expiredCIDs {
|
|
|
|
_, cancel := context.WithCancel(context.Background())
|
|
|
|
conn := cm.CreateConn(cid, cancel, func() ConnHandler {
|
|
|
|
return &mockConnHandler{}
|
|
|
|
})
|
|
|
|
cidToConn[cid] = conn
|
|
|
|
}
|
|
|
|
time.Sleep(time.Millisecond * 500)
|
|
|
|
|
|
|
|
unexpiredCIDs := []ConnID{
|
|
|
|
{UserID: alice, DeviceID: "B", CID: "room-list"},
|
|
|
|
{UserID: alice, DeviceID: "B", CID: "encryption"},
|
|
|
|
{UserID: alice, DeviceID: "B", CID: "notifications"},
|
|
|
|
}
|
|
|
|
for _, cid := range unexpiredCIDs {
|
|
|
|
_, cancel := context.WithCancel(context.Background())
|
|
|
|
conn := cm.CreateConn(cid, cancel, func() ConnHandler {
|
|
|
|
return &mockConnHandler{}
|
|
|
|
})
|
|
|
|
cidToConn[cid] = conn
|
|
|
|
}
|
|
|
|
|
|
|
|
time.Sleep(510 * time.Millisecond) // all 'A' device conns must have expired
|
|
|
|
|
|
|
|
// Destroy should have been called for all alice|A connections
|
|
|
|
assertDestroyedConns(t, cidToConn, func(cid ConnID) bool {
|
|
|
|
return cid.DeviceID == "A"
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
2023-11-24 14:30:08 +00:00
|
|
|
func TestConnMap_TTLExpiryStaggeredDevices(t *testing.T) {
|
|
|
|
cm := NewConnMap(false, time.Second) // 1s expiry
|
|
|
|
expiredCIDs := []ConnID{
|
|
|
|
{UserID: alice, DeviceID: "A", CID: "room-list"},
|
|
|
|
{UserID: alice, DeviceID: "B", CID: "encryption"},
|
|
|
|
{UserID: alice, DeviceID: "B", CID: "notifications"},
|
|
|
|
}
|
|
|
|
cidToConn := map[ConnID]*Conn{}
|
|
|
|
for _, cid := range expiredCIDs {
|
|
|
|
_, cancel := context.WithCancel(context.Background())
|
|
|
|
conn := cm.CreateConn(cid, cancel, func() ConnHandler {
|
|
|
|
return &mockConnHandler{}
|
|
|
|
})
|
|
|
|
cidToConn[cid] = conn
|
|
|
|
}
|
|
|
|
time.Sleep(time.Millisecond * 500)
|
|
|
|
|
|
|
|
unexpiredCIDs := []ConnID{
|
|
|
|
{UserID: alice, DeviceID: "B", CID: "room-list"},
|
|
|
|
{UserID: alice, DeviceID: "A", CID: "encryption"},
|
|
|
|
{UserID: alice, DeviceID: "A", CID: "notifications"},
|
|
|
|
}
|
|
|
|
for _, cid := range unexpiredCIDs {
|
|
|
|
_, cancel := context.WithCancel(context.Background())
|
|
|
|
conn := cm.CreateConn(cid, cancel, func() ConnHandler {
|
|
|
|
return &mockConnHandler{}
|
|
|
|
})
|
|
|
|
cidToConn[cid] = conn
|
|
|
|
}
|
|
|
|
|
2023-11-24 15:19:08 +00:00
|
|
|
// all expiredCIDs should have expired, none from unexpiredCIDs
|
|
|
|
time.Sleep(510 * time.Millisecond)
|
2023-11-24 14:30:08 +00:00
|
|
|
|
2023-11-24 15:19:08 +00:00
|
|
|
// Destroy should have been called for all expiredCIDs connections
|
2023-11-24 14:30:08 +00:00
|
|
|
assertDestroyedConns(t, cidToConn, func(cid ConnID) bool {
|
|
|
|
for _, expCID := range expiredCIDs {
|
|
|
|
if expCID.String() == cid.String() {
|
|
|
|
return true
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return false
|
|
|
|
})
|
|
|
|
|
|
|
|
// double check this by querying connmap
|
|
|
|
conns := cm.Conns(alice, "A")
|
2023-11-24 15:19:08 +00:00
|
|
|
var gotIDs []string
|
2023-11-24 14:30:08 +00:00
|
|
|
for _, c := range conns {
|
|
|
|
t.Logf(c.String())
|
2023-11-24 15:19:08 +00:00
|
|
|
gotIDs = append(gotIDs, c.CID)
|
2023-11-24 14:30:08 +00:00
|
|
|
}
|
2023-11-24 15:19:08 +00:00
|
|
|
sort.Strings(gotIDs)
|
|
|
|
wantIDs := []string{"encryption", "notifications"}
|
2023-11-24 15:25:53 +00:00
|
|
|
mustEqual(t, len(conns), 2, "unexpected number of Conns for device")
|
2023-11-24 15:19:08 +00:00
|
|
|
if !reflect.DeepEqual(gotIDs, wantIDs) {
|
|
|
|
t.Fatalf("unexpected active conns: got %v want %v", gotIDs, wantIDs)
|
|
|
|
}
|
2023-11-24 14:30:08 +00:00
|
|
|
}
|
|
|
|
|
2023-11-24 13:18:24 +00:00
|
|
|
func assertDestroyedConns(t *testing.T, cidToConn map[ConnID]*Conn, isDestroyedFn func(cid ConnID) bool) {
|
|
|
|
t.Helper()
|
|
|
|
for cid, conn := range cidToConn {
|
|
|
|
if isDestroyedFn(cid) {
|
2024-03-11 10:22:18 +00:00
|
|
|
mustEqual(t, conn.handler.(*mockConnHandler).isDestroyed.Load(), true, fmt.Sprintf("conn %+v was not destroyed", cid))
|
2023-11-24 13:18:24 +00:00
|
|
|
} else {
|
2024-03-11 10:22:18 +00:00
|
|
|
mustEqual(t, conn.handler.(*mockConnHandler).isDestroyed.Load(), false, fmt.Sprintf("conn %+v was destroyed", cid))
|
2023-11-24 13:18:24 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
type mockConnHandler struct {
|
2024-03-11 10:22:18 +00:00
|
|
|
isDestroyed atomic.Bool
|
2023-11-24 13:18:24 +00:00
|
|
|
cancel context.CancelFunc
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *mockConnHandler) OnIncomingRequest(ctx context.Context, cid ConnID, req *Request, isInitial bool, start time.Time) (*Response, error) {
|
|
|
|
return nil, nil
|
|
|
|
}
|
|
|
|
func (c *mockConnHandler) OnUpdate(ctx context.Context, update caches.Update) {}
|
|
|
|
func (c *mockConnHandler) PublishEventsUpTo(roomID string, nid int64) {}
|
|
|
|
func (c *mockConnHandler) Destroy() {
|
2024-03-11 10:22:18 +00:00
|
|
|
c.isDestroyed.Store(true)
|
2023-11-24 13:18:24 +00:00
|
|
|
}
|
|
|
|
func (c *mockConnHandler) Alive() bool {
|
|
|
|
return true // buffer never fills up
|
|
|
|
}
|
|
|
|
func (c *mockConnHandler) SetCancelCallback(cancel context.CancelFunc) {
|
|
|
|
c.cancel = cancel
|
|
|
|
}
|