feature: Support account_data extension

Along with a battery of tests to make sure we give account data only
for rooms being tracked in the sliding lists, unless it's global in
which case we just send it on.
This commit is contained in:
Kegan Dougal 2022-03-24 19:38:55 +00:00
parent e8131a920f
commit 3ecef2062e
15 changed files with 488 additions and 37 deletions

View File

@ -285,3 +285,193 @@ func TestExtensionToDevice(t *testing.T) {
// - do we need sessions at all? Can we delete if the since value is incremented?
// - check with ios folks if this level of co-ordination between processes is possible.
}
// tests that the account data extension works:
// 1- check global account data is sent on first connection
// 2- check global account data updates are proxied through
// 3- check room account data for the list only is sent
// 4- check room account data for subscriptions are sent
// 5- when the range changes, make sure room account data is sent
// 6- when a room bumps into a range, make sure room account data is sent
func TestExtensionAccountData(t *testing.T) {
pqString := testutils.PrepareDBConnectionString()
// setup code
v2 := runTestV2Server(t)
v3 := runTestServer(t, v2, pqString)
defer v2.close()
defer v3.close()
alice := "@alice:localhost"
aliceToken := "ALICE_BEARER_TOKEN"
roomA := "!a:localhost"
roomB := "!b:localhost"
roomC := "!c:localhost"
globalAccountData := []json.RawMessage{
testutils.NewAccountData(t, "im-global", map[string]interface{}{"body": "yep"}),
testutils.NewAccountData(t, "im-also-global", map[string]interface{}{"body": "yep"}),
}
roomAAccountData := []json.RawMessage{
testutils.NewAccountData(t, "im-a", map[string]interface{}{"body": "yep a"}),
testutils.NewAccountData(t, "im-also-a", map[string]interface{}{"body": "yep A"}),
}
roomBAccountData := []json.RawMessage{
testutils.NewAccountData(t, "im-b", map[string]interface{}{"body": "yep b"}),
testutils.NewAccountData(t, "im-also-b", map[string]interface{}{"body": "yep B"}),
}
roomCAccountData := []json.RawMessage{
testutils.NewAccountData(t, "im-c", map[string]interface{}{"body": "yep c"}),
testutils.NewAccountData(t, "im-also-c", map[string]interface{}{"body": "yep C"}),
}
v2.addAccount(alice, aliceToken)
v2.queueResponse(alice, sync2.SyncResponse{
AccountData: sync2.EventsResponse{
Events: globalAccountData,
},
Rooms: sync2.SyncRoomsResponse{
Join: map[string]sync2.SyncV2JoinResponse{
roomA: {
State: sync2.EventsResponse{
Events: createRoomState(t, alice, time.Now()),
},
AccountData: sync2.EventsResponse{
Events: roomAAccountData,
},
},
roomB: {
State: sync2.EventsResponse{
Events: createRoomState(t, alice, time.Now().Add(-1*time.Minute)),
},
AccountData: sync2.EventsResponse{
Events: roomBAccountData,
},
},
roomC: {
State: sync2.EventsResponse{
Events: createRoomState(t, alice, time.Now().Add(-2*time.Minute)),
},
AccountData: sync2.EventsResponse{
Events: roomCAccountData,
},
},
},
},
})
// 1- check global account data is sent on first connection
// 3- check room account data for the list only is sent
res := v3.mustDoV3Request(t, aliceToken, sync3.Request{
Extensions: extensions.Request{
AccountData: &extensions.AccountDataRequest{
Enabled: true,
},
},
Lists: []sync3.RequestList{{
Ranges: sync3.SliceRanges{
[2]int64{0, 1}, // first two rooms A,B
},
Sort: []string{sync3.SortByRecency},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 0,
},
}},
})
MatchResponse(t, res, MatchAccountData(
globalAccountData,
map[string][]json.RawMessage{
roomA: roomAAccountData,
roomB: roomBAccountData,
},
))
// 5- when the range changes, make sure room account data is sent
res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{
Lists: []sync3.RequestList{{
Ranges: sync3.SliceRanges{
[2]int64{0, 2}, // A,B,C
},
}},
})
MatchResponse(t, res, MatchAccountData(
nil,
map[string][]json.RawMessage{
roomC: roomCAccountData,
},
))
// 4- check room account data for subscriptions are sent
res = v3.mustDoV3Request(t, aliceToken, sync3.Request{
Extensions: extensions.Request{
AccountData: &extensions.AccountDataRequest{
Enabled: true,
},
},
RoomSubscriptions: map[string]sync3.RoomSubscription{
roomB: {
TimelineLimit: 1,
},
},
})
MatchResponse(t, res, MatchAccountData(
globalAccountData,
map[string][]json.RawMessage{
roomB: roomBAccountData,
},
))
// 2- check global account data updates are proxied through
newGlobalEvent := testutils.NewAccountData(t, "new_fun_event", map[string]interface{}{"much": "excite"})
v2.queueResponse(alice, sync2.SyncResponse{
AccountData: sync2.EventsResponse{
Events: []json.RawMessage{newGlobalEvent},
},
})
v2.waitUntilEmpty(t, alice)
res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{})
MatchResponse(t, res, MatchAccountData(
[]json.RawMessage{newGlobalEvent},
nil,
))
// 6- when a room bumps into a range, make sure room account data is sent
res = v3.mustDoV3Request(t, aliceToken, sync3.Request{
Extensions: extensions.Request{
AccountData: &extensions.AccountDataRequest{
Enabled: true,
},
},
Lists: []sync3.RequestList{{
Ranges: sync3.SliceRanges{
[2]int64{0, 1}, // first two rooms A,B
},
Sort: []string{sync3.SortByRecency},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 0,
},
}},
})
// bump C to position 0
v2.queueResponse(alice, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: roomC,
events: []json.RawMessage{
testutils.NewEvent(t, "m.poke", alice, map[string]interface{}{}, time.Now().Add(time.Millisecond)),
},
}),
},
})
v2.waitUntilEmpty(t, alice)
// now we should get room account data for C
res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{
Lists: []sync3.RequestList{{
Ranges: sync3.SliceRanges{
[2]int64{0, 1}, // first two rooms A,B
},
}},
})
MatchResponse(t, res, MatchAccountData(
nil,
map[string][]json.RawMessage{
roomC: roomCAccountData,
},
))
}

View File

@ -5,6 +5,7 @@ import (
"fmt"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
"github.com/matrix-org/sync-v3/sqlutil"
)
@ -72,6 +73,17 @@ func (t *AccountDataTable) Select(txn *sqlx.Tx, userID, eventType, roomID string
return &acc, err
}
func (t *AccountDataTable) SelectMany(txn *sqlx.Tx, userID string, roomIDs ...string) (datas []AccountData, err error) {
if len(roomIDs) == 0 {
err = txn.Select(&datas, `SELECT user_id, room_id, type, data FROM syncv3_account_data
WHERE user_id=$1 AND room_id = $2`, userID, AccountDataGlobalRoom)
return
}
err = txn.Select(&datas, `SELECT user_id, room_id, type, data FROM syncv3_account_data
WHERE user_id=$1 AND room_id=ANY($2)`, userID, pq.StringArray(roomIDs))
return
}
type AccountDataChunker []AccountData
func (c AccountDataChunker) Len() int {

View File

@ -2,12 +2,26 @@ package state
import (
"reflect"
"sort"
"testing"
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sync-v3/sync2"
)
func accountDatasEqual(gots, wants []AccountData) bool {
key := func(a AccountData) string {
return a.UserID + a.RoomID + a.Type
}
sort.Slice(gots, func(i, j int) bool {
return key(gots[i]) < key(gots[j])
})
sort.Slice(wants, func(i, j int) bool {
return key(wants[i]) < key(wants[j])
})
return reflect.DeepEqual(gots, wants)
}
func TestAccountData(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString)
if err != nil {
@ -53,6 +67,12 @@ func TestAccountData(t *testing.T) {
Type: eventType,
Data: []byte(`{"foo":"bar4"}`),
},
{
UserID: alice,
RoomID: sync2.AccountDataGlobalRoom,
Type: "dummy",
Data: []byte(`{"foo":"bar5"}`),
},
// this should replace the first element
{
UserID: alice,
@ -83,8 +103,41 @@ func TestAccountData(t *testing.T) {
if err != nil {
t.Fatalf("Select: %s", err)
}
if !reflect.DeepEqual(*gotData, accountData[len(accountData)-2]) {
t.Fatalf("Select: expected global event to be returned but wasn't. Got %+v want %+v", gotData, accountData[len(accountData)-2])
if !reflect.DeepEqual(*gotData, accountData[len(accountData)-3]) {
t.Fatalf("Select: expected global event to be returned but wasn't. Got %+v want %+v", gotData, accountData[len(accountData)-3])
}
// Select all global events for alice
wantDatas := []AccountData{
accountData[4], accountData[5],
}
gotDatas, err := table.SelectMany(txn, alice)
if err != nil {
t.Fatalf("SelectMany: %s", err)
}
if !accountDatasEqual(gotDatas, wantDatas) {
t.Fatalf("SelectMany: got %v want %v", gotDatas, wantDatas)
}
// Select all room events for alice
wantDatas = []AccountData{
accountData[6],
}
gotDatas, err = table.SelectMany(txn, alice, roomA)
if err != nil {
t.Fatalf("SelectMany: %s", err)
}
if !accountDatasEqual(gotDatas, wantDatas) {
t.Fatalf("SelectMany: got %v want %v", gotDatas, wantDatas)
}
// Select all room events for unknown user
gotDatas, err = table.SelectMany(txn, "@someone-else:localhost", roomA)
if err != nil {
t.Fatalf("SelectMany: %s", err)
}
if len(gotDatas) != 0 {
t.Fatalf("SelectMany: got %d account data, want 0", len(gotDatas))
}
}

View File

@ -61,6 +61,16 @@ func (s *Storage) AccountData(userID, roomID, eventType string) (data *AccountDa
return
}
// Pull out all account data for this user. If roomIDs is empty, global account data is returned.
// If roomIDs is non-empty, all account data for these rooms are extracted.
func (s *Storage) AccountDatas(userID string, roomIDs ...string) (datas []AccountData, err error) {
err = sqlutil.WithTransaction(s.accumulator.db, func(txn *sqlx.Tx) error {
datas, err = s.AccountDataTable.SelectMany(txn, userID, roomIDs...)
return err
})
return
}
func (s *Storage) InsertAccountData(userID, roomID string, events []json.RawMessage) (data []AccountData, err error) {
data = make([]AccountData, len(events))
for i := range events {

View File

@ -25,7 +25,11 @@ type UnreadCountUpdate struct {
HasCountDecreased bool
}
type RoomAccountDataAlert struct {
type AccountDataUpdate struct {
AccountData []state.AccountData
}
type RoomAccountDataUpdate struct {
RoomUpdate
AccountData []state.AccountData
}

View File

@ -21,9 +21,9 @@ type UserCacheListener interface {
// Called when there is an update affecting a room e.g new event, unread count update, room account data.
// Type-cast to find out what the update is about.
OnRoomUpdate(up RoomUpdate)
// Called when there is an update affecting this user e.g global account data, presence.
// Called when there is an update affecting this user but not in the room e.g global account data, presence.
// Type-cast to find out what the update is about.
// OnUserUpdate(up UserUpdate)
OnUpdate(up Update)
}
// Tracks data specific to a given user. Specifically, this is the map of room ID to UserRoomData.
@ -222,7 +222,11 @@ func (c *UserCache) OnNewEvent(eventData *EventData) {
}
func (c *UserCache) OnAccountData(datas []state.AccountData) {
roomUpdates := make(map[string][]state.AccountData)
for _, d := range datas {
up := roomUpdates[d.RoomID]
up = append(up, d)
roomUpdates[d.RoomID] = up
if d.Type == "m.direct" {
dmRoomSet := make(map[string]struct{})
// pull out rooms and mark them as DMs
@ -249,5 +253,26 @@ func (c *UserCache) OnAccountData(datas []state.AccountData) {
}
c.roomToDataMu.Unlock()
}
}
// bucket account data updates per-room and globally then invoke listeners
for roomID, updates := range roomUpdates {
if roomID == state.AccountDataGlobalRoom {
globalUpdate := &AccountDataUpdate{
AccountData: updates,
}
for _, l := range c.listeners {
l.OnUpdate(globalUpdate)
}
} else {
roomUpdate := &RoomAccountDataUpdate{
AccountData: updates,
RoomUpdate: c.newRoomUpdate(roomID),
}
for _, l := range c.listeners {
l.OnRoomUpdate(roomUpdate)
}
}
}
}

View File

@ -9,24 +9,11 @@ import (
// Client created request params
type AccountDataRequest struct {
Enabled bool `json:"enabled"`
GlobalAccountDataTypes []string `json:"global_account_data_types"`
RoomAccountDataTypes map[int][]string `json:"room_account_data_types"`
Enabled bool `json:"enabled"`
}
func (r AccountDataRequest) ApplyDelta(next *AccountDataRequest) *AccountDataRequest {
r.Enabled = next.Enabled
if next.GlobalAccountDataTypes != nil {
r.GlobalAccountDataTypes = next.GlobalAccountDataTypes
}
if next.RoomAccountDataTypes != nil {
if r.RoomAccountDataTypes == nil {
r.RoomAccountDataTypes = make(map[int][]string)
}
for listIndex, types := range next.RoomAccountDataTypes {
r.RoomAccountDataTypes[listIndex] = types
}
}
return &r
}
@ -43,14 +30,74 @@ func (r *AccountDataResponse) HasData(isInitial bool) bool {
return len(r.Rooms) > 0
}
func ProcessAccountData(store *state.Storage, userCache *caches.UserCache, userID string, isInitial bool, req *AccountDataRequest) (res *AccountDataResponse) {
if isInitial {
// register with the user cache TODO: need destructor call to unregister
// pull account data from store and return that
func accountEventsAsJSON(events []state.AccountData) []json.RawMessage {
j := make([]json.RawMessage, len(events))
for i := range events {
j[i] = events[i].Data
}
// on new account data callback, buffer it and then return it assuming isInitial=false
return j
}
// TODO: how to handle room account data? Need to know which rooms are being tracked.
func ProcessLiveAccountData(up caches.Update, store *state.Storage, updateWillReturnResponse bool, userID string, req *AccountDataRequest) (res *AccountDataResponse) {
switch update := up.(type) {
case *caches.AccountDataUpdate:
return &AccountDataResponse{
Global: accountEventsAsJSON(update.AccountData),
}
case *caches.RoomAccountDataUpdate:
return &AccountDataResponse{
Rooms: map[string][]json.RawMessage{
update.RoomID(): accountEventsAsJSON(update.AccountData),
},
}
case caches.RoomUpdate:
// this is a room update which is causing us to return, meaning we are interested in this room.
// send account data for this room.
if updateWillReturnResponse {
roomAccountData, err := store.AccountDatas(userID, update.RoomID())
if err != nil {
logger.Err(err).Str("user", userID).Str("room", update.RoomID()).Msg("failed to fetch room account data")
} else {
return &AccountDataResponse{
Rooms: map[string][]json.RawMessage{
update.RoomID(): accountEventsAsJSON(roomAccountData),
},
}
}
}
}
return nil
}
func ProcessAccountData(store *state.Storage, listRoomIDs map[string]struct{}, userID string, isInitial bool, req *AccountDataRequest) (res *AccountDataResponse) {
roomIDs := make([]string, len(listRoomIDs))
i := 0
for roomID := range listRoomIDs {
roomIDs[i] = roomID
i++
}
res = &AccountDataResponse{}
// room account data needs to be sent every time the user scrolls the list to get new room IDs
// TODO: remember which rooms the client has been told about
if len(roomIDs) > 0 {
roomsAccountData, err := store.AccountDatas(userID, roomIDs...)
if err != nil {
logger.Err(err).Str("user", userID).Strs("rooms", roomIDs).Msg("failed to fetch room account data")
} else {
res.Rooms = make(map[string][]json.RawMessage)
for _, ad := range roomsAccountData {
res.Rooms[ad.RoomID] = append(res.Rooms[ad.RoomID], ad.Data)
}
}
}
// global account data is only sent on the first connection, then we live stream
if isInitial {
globalAccountData, err := store.AccountDatas(userID)
if err != nil {
logger.Err(err).Str("user", userID).Msg("failed to fetch global account data")
} else {
res.Global = accountEventsAsJSON(globalAccountData)
}
}
return
}

View File

@ -29,6 +29,9 @@ func (r Request) ApplyDelta(next *Request) Request {
if next.E2EE != nil {
r.E2EE = r.E2EE.ApplyDelta(next.E2EE)
}
if next.AccountData != nil {
r.AccountData = r.AccountData.ApplyDelta(next.AccountData)
}
return r
}
@ -45,8 +48,8 @@ func (e Response) HasData(isInitial bool) bool {
}
type HandlerInterface interface {
Handle(req Request, userCache *caches.UserCache, isInitial bool) (res Response)
HandleLiveData(req Request, res *Response, userCache *caches.UserCache, isInitial bool)
Handle(req Request, listRoomIDs map[string]struct{}, isInitial bool) (res Response)
HandleLiveUpdate(update caches.Update, req Request, res *Response, updateWillReturnResponse, isInitial bool)
}
type Handler struct {
@ -54,12 +57,13 @@ type Handler struct {
E2EEFetcher sync2.E2EEFetcher
}
func (h *Handler) HandleLiveData(req Request, res *Response, userCache *caches.UserCache, isInitial bool) {
// TODO: Define live data event in caches pkg, NOT ConnEvent.
// Update `Response` object.
func (h *Handler) HandleLiveUpdate(update caches.Update, req Request, res *Response, updateWillReturnResponse, isInitial bool) {
if req.AccountData != nil && req.AccountData.Enabled {
res.AccountData = ProcessLiveAccountData(update, h.Store, updateWillReturnResponse, req.UserID, req.AccountData)
}
}
func (h *Handler) Handle(req Request, userCache *caches.UserCache, isInitial bool) (res Response) {
func (h *Handler) Handle(req Request, listRoomIDs map[string]struct{}, isInitial bool) (res Response) {
if req.ToDevice != nil && req.ToDevice.Enabled != nil && *req.ToDevice.Enabled {
res.ToDevice = ProcessToDevice(h.Store, req.UserID, req.DeviceID, req.ToDevice)
}
@ -67,7 +71,7 @@ func (h *Handler) Handle(req Request, userCache *caches.UserCache, isInitial boo
res.E2EE = ProcessE2EE(h.E2EEFetcher, req.UserID, req.DeviceID, req.E2EE)
}
if req.AccountData != nil && req.AccountData.Enabled {
res.AccountData = ProcessAccountData(h.Store, userCache, req.UserID, isInitial, req.AccountData)
res.AccountData = ProcessAccountData(h.Store, listRoomIDs, req.UserID, isInitial, req.AccountData)
}
return
}

View File

@ -162,9 +162,13 @@ func (s *ConnState) onIncomingRequest(ctx context.Context, req *sync3.Request, i
responseOperations = append(responseOperations, ops...)
}
includedRoomIDs := sync3.IncludedRoomIDsInOps(responseOperations)
for _, roomID := range newSubs { // include room subs in addition to lists
includedRoomIDs[roomID] = struct{}{}
}
// Handle extensions AFTER processing lists as extensions may need to know which rooms the client
// is being notified about (e.g. for room account data)
response.Extensions = s.extensionsHandler.Handle(ex, s.userCache, isInitial)
response.Extensions = s.extensionsHandler.Handle(ex, includedRoomIDs, isInitial)
// do live tracking if we have nothing to tell the client yet
responseOperations = s.live.liveUpdate(ctx, req, ex, isInitial, response, responseOperations)
@ -418,6 +422,10 @@ func (s *ConnState) moveRoom(
}
func (s *ConnState) OnUpdate(up caches.Update) {
s.live.onUpdate(up)
}
// Called by the user cache when updates arrive
func (s *ConnState) OnRoomUpdate(up caches.RoomUpdate) {
switch update := up.(type) {

View File

@ -62,12 +62,14 @@ func (s *connStateLive) liveUpdate(
return responseOperations
case update := <-s.updates:
responseOperations = s.processLiveUpdate(update, responseOperations, response)
updateWillReturnResponse := len(responseOperations) > 0
// pass event to extensions AFTER processing
s.extensionsHandler.HandleLiveData(ex, &response.Extensions, s.userCache, isInitial)
s.extensionsHandler.HandleLiveUpdate(update, ex, &response.Extensions, updateWillReturnResponse, isInitial)
// if there's more updates and we don't have lots stacked up already, go ahead and process another
for len(s.updates) > 0 && len(responseOperations) < 50 {
update = <-s.updates
responseOperations = s.processLiveUpdate(update, responseOperations, response)
s.extensionsHandler.HandleLiveUpdate(update, ex, &response.Extensions, updateWillReturnResponse, isInitial)
}
}
}

View File

@ -19,12 +19,11 @@ import (
type NopExtensionHandler struct{}
func (h *NopExtensionHandler) Handle(req extensions.Request, uc *caches.UserCache, isInitial bool) (res extensions.Response) {
func (h *NopExtensionHandler) Handle(req extensions.Request, listRoomIDs map[string]struct{}, isInitial bool) (res extensions.Response) {
return
}
func (h *NopExtensionHandler) HandleLiveData(req extensions.Request, res *extensions.Response, uc *caches.UserCache, isInitial bool) {
return
func (h *NopExtensionHandler) HandleLiveUpdate(u caches.Update, req extensions.Request, res *extensions.Response, updateWillReturnResponse, isInitial bool) {
}
type NopJoinTracker struct{}

View File

@ -69,6 +69,8 @@ type pointInfo struct {
isOpen bool
}
// TODO: A,B,C track A,B then B,C incorrectly keeps B?
// Delta returns the ranges which are unchanged, added and removed.
// Intelligently handles overlaps.
func (r SliceRanges) Delta(next SliceRanges) (added SliceRanges, removed SliceRanges, same SliceRanges) {

View File

@ -75,6 +75,20 @@ func (r *Response) UnmarshalJSON(b []byte) error {
type ResponseOp interface {
Op() string
// which rooms are we giving data about
IncludedRoomIDs() []string
}
// Return which room IDs these set of operations are returning information on. Information means
// things like SYNC/INSERT/UPDATE, and not DELETE/INVALIDATE.
func IncludedRoomIDsInOps(ops []ResponseOp) map[string]struct{} {
set := make(map[string]struct{})
for _, o := range ops {
for _, roomID := range o.IncludedRoomIDs() {
set[roomID] = struct{}{}
}
}
return set
}
type ResponseOpRange struct {
@ -87,6 +101,16 @@ type ResponseOpRange struct {
func (r *ResponseOpRange) Op() string {
return r.Operation
}
func (r *ResponseOpRange) IncludedRoomIDs() []string {
if r.Op() == OpInvalidate {
return nil // the rooms are being excluded
}
roomIDs := make([]string, len(r.Rooms))
for i := range r.Rooms {
roomIDs[i] = r.Rooms[i].RoomID
}
return roomIDs
}
type ResponseOpSingle struct {
Operation string `json:"op"`
@ -98,3 +122,12 @@ type ResponseOpSingle struct {
func (r *ResponseOpSingle) Op() string {
return r.Operation
}
func (r *ResponseOpSingle) IncludedRoomIDs() []string {
if r.Op() == OpDelete || r.Room == nil {
return nil // the room is being excluded
}
return []string{
r.Room.RoomID,
}
}

View File

@ -89,3 +89,18 @@ func NewEvent(t *testing.T, evType, sender string, content interface{}, originSe
}
return j
}
func NewAccountData(t *testing.T, evType string, content interface{}) json.RawMessage {
e := struct {
Type string `json:"type"`
Content interface{} `json:"content"`
}{
Type: evType,
Content: content,
}
j, err := json.Marshal(&e)
if err != nil {
t.Fatalf("NewAccountData: failed to make event JSON: %s", err)
}
return j
}

View File

@ -11,6 +11,7 @@ import (
"net/http"
"net/http/httptest"
"reflect"
"sort"
"strings"
"sync"
"testing"
@ -605,6 +606,34 @@ func MatchV3Ops(matchOps ...opMatcher) respMatcher {
}
}
func MatchAccountData(globals []json.RawMessage, rooms map[string][]json.RawMessage) respMatcher {
return func(res *sync3.Response) error {
if res.Extensions.AccountData == nil {
return fmt.Errorf("MatchAccountData: no account_data extension")
}
if len(globals) > 0 {
if err := equalAnyOrder(res.Extensions.AccountData.Global, globals); err != nil {
return fmt.Errorf("MatchAccountData[global]: %s", err)
}
}
if len(rooms) > 0 {
if len(rooms) != len(res.Extensions.AccountData.Rooms) {
return fmt.Errorf("MatchAccountData: got %d rooms with account data, want %d", len(res.Extensions.AccountData.Rooms), len(rooms))
}
for roomID := range rooms {
gots := res.Extensions.AccountData.Rooms[roomID]
if gots == nil {
return fmt.Errorf("MatchAccountData: want room account data for %s but it was missing", roomID)
}
if err := equalAnyOrder(gots, rooms[roomID]); err != nil {
return fmt.Errorf("MatchAccountData[room]: %s", err)
}
}
}
return nil
}
}
func MatchResponse(t *testing.T, res *sync3.Response, matchers ...respMatcher) {
t.Helper()
for _, m := range matchers {
@ -619,3 +648,21 @@ func MatchResponse(t *testing.T, res *sync3.Response, matchers ...respMatcher) {
func ptr(i int) *int {
return &i
}
func equalAnyOrder(got, want []json.RawMessage) error {
if len(got) != len(want) {
return fmt.Errorf("equalAnyOrder: got %d, want %d", len(got), len(want))
}
sort.Slice(got, func(i, j int) bool {
return string(got[i]) < string(got[j])
})
sort.Slice(want, func(i, j int) bool {
return string(want[i]) < string(want[j])
})
for i := range got {
if !reflect.DeepEqual(got[i], want[i]) {
return fmt.Errorf("equalAnyOrder: [%d] got %v want %v", i, string(got[i]), string(want[i]))
}
}
return nil
}