Add account data aggregation; with tests

This commit is contained in:
Kegan Dougal 2023-02-10 12:53:23 +00:00
parent b3d76e90be
commit c063352508
3 changed files with 132 additions and 16 deletions

View File

@ -39,17 +39,13 @@ func accountEventsAsJSON(events []state.AccountData) []json.RawMessage {
}
func (r *AccountDataRequest) AppendLive(ctx context.Context, res *Response, extCtx Context, up caches.Update) {
var globalMsgs []json.RawMessage
roomToMsgs := map[string][]json.RawMessage{}
switch update := up.(type) {
case *caches.AccountDataUpdate:
res.AccountData = &AccountDataResponse{ // TODO: aggregate
Global: accountEventsAsJSON(update.AccountData),
}
globalMsgs = accountEventsAsJSON(update.AccountData)
case *caches.RoomAccountDataUpdate:
res.AccountData = &AccountDataResponse{ // TODO: aggregate
Rooms: map[string][]json.RawMessage{
update.RoomID(): accountEventsAsJSON(update.AccountData),
},
}
roomToMsgs[update.RoomID()] = accountEventsAsJSON(update.AccountData)
case caches.RoomUpdate:
// if this is a room update which is included in the response, send account data for this room
if _, exists := extCtx.RoomIDToTimeline[update.RoomID()]; exists {
@ -57,14 +53,22 @@ func (r *AccountDataRequest) AppendLive(ctx context.Context, res *Response, extC
if err != nil {
logger.Err(err).Str("user", extCtx.UserID).Str("room", update.RoomID()).Msg("failed to fetch room account data")
} else {
res.AccountData = &AccountDataResponse{ // TODO: aggregate
Rooms: map[string][]json.RawMessage{
update.RoomID(): accountEventsAsJSON(roomAccountData),
},
}
roomToMsgs[update.RoomID()] = accountEventsAsJSON(roomAccountData)
}
}
}
if len(globalMsgs) == 0 && len(roomToMsgs) == 0 {
return
}
if res.AccountData == nil {
res.AccountData = &AccountDataResponse{
Rooms: make(map[string][]json.RawMessage),
}
}
res.AccountData.Global = append(res.AccountData.Global, globalMsgs...)
for roomID, roomAccountData := range roomToMsgs {
res.AccountData.Rooms[roomID] = append(res.AccountData.Rooms[roomID], roomAccountData...)
}
}
func (r *AccountDataRequest) ProcessInitial(ctx context.Context, res *Response, extCtx Context) {
@ -74,7 +78,9 @@ func (r *AccountDataRequest) ProcessInitial(ctx context.Context, res *Response,
roomIDs[i] = roomID
i++
}
extRes := &AccountDataResponse{}
extRes := &AccountDataResponse{
Rooms: make(map[string][]json.RawMessage),
}
// room account data needs to be sent every time the user scrolls the list to get new room IDs
// TODO: remember which rooms the client has been told about
if len(roomIDs) > 0 {
@ -97,5 +103,7 @@ func (r *AccountDataRequest) ProcessInitial(ctx context.Context, res *Response,
extRes.Global = accountEventsAsJSON(globalAccountData)
}
}
res.AccountData = extRes
if len(extRes.Rooms) > 0 || len(extRes.Global) > 0 {
res.AccountData = extRes
}
}

View File

@ -0,0 +1,107 @@
package extensions
import (
"encoding/json"
"reflect"
"testing"
"github.com/matrix-org/sliding-sync/internal"
"github.com/matrix-org/sliding-sync/state"
"github.com/matrix-org/sliding-sync/sync3/caches"
)
// Test that aggregation works, which is hard to assert in integration tests
func TestLiveAccountDataAggregation(t *testing.T) {
boolTrue := true
ext := &AccountDataRequest{
Enableable: Enableable{
Enabled: &boolTrue,
},
}
var res Response
var extCtx Context
room1 := &caches.RoomAccountDataUpdate{
RoomUpdate: &dummyRoomUpdate{
roomID: roomA,
globalMetadata: &internal.RoomMetadata{
RoomID: roomA,
},
},
AccountData: []state.AccountData{
{
Data: []byte(`{"foo":"bar"}`),
},
{
Data: []byte(`{"foo2":"bar2"}`),
},
},
}
room2 := &caches.RoomAccountDataUpdate{
RoomUpdate: &dummyRoomUpdate{
roomID: roomA,
globalMetadata: &internal.RoomMetadata{
RoomID: roomA,
},
},
AccountData: []state.AccountData{
{
Data: []byte(`{"foo3":"bar3"}`),
},
{
Data: []byte(`{"foo4":"bar4"}`),
},
},
}
global1 := &caches.AccountDataUpdate{
AccountData: []state.AccountData{
{
Data: []byte(`{"global":"bar"}`),
},
{
Data: []byte(`{"global2":"bar2"}`),
},
},
}
global2 := &caches.AccountDataUpdate{
AccountData: []state.AccountData{
{
Data: []byte(`{"global3":"bar3"}`),
},
{
Data: []byte(`{"global4":"bar4"}`),
},
},
}
ext.AppendLive(ctx, &res, extCtx, room1)
wantRoomAccountData := map[string][]json.RawMessage{
roomA: []json.RawMessage{
room1.AccountData[0].Data, room1.AccountData[1].Data,
},
}
if !reflect.DeepEqual(res.AccountData.Rooms, wantRoomAccountData) {
t.Fatalf("got %+v\nwant %+v", res.AccountData.Rooms, wantRoomAccountData)
}
ext.AppendLive(ctx, &res, extCtx, room2)
ext.AppendLive(ctx, &res, extCtx, global1)
ext.AppendLive(ctx, &res, extCtx, global2)
if res.AccountData == nil {
t.Fatalf("account_data response is empty")
}
wantRoomAccountData = map[string][]json.RawMessage{
roomA: []json.RawMessage{
room1.AccountData[0].Data, room1.AccountData[1].Data,
room2.AccountData[0].Data, room2.AccountData[1].Data,
},
}
if !reflect.DeepEqual(res.AccountData.Rooms, wantRoomAccountData) {
t.Fatalf("got %+v\nwant %+v", res.AccountData.Rooms, wantRoomAccountData)
}
wantGlobalAccountData := []json.RawMessage{
global1.AccountData[0].Data, global1.AccountData[1].Data,
global2.AccountData[0].Data, global2.AccountData[1].Data,
}
if !reflect.DeepEqual(res.AccountData.Global, wantGlobalAccountData) {
t.Fatalf("got %+v\nwant %+v", res.AccountData.Global, wantGlobalAccountData)
}
}

View File

@ -81,5 +81,6 @@ func (r *E2EERequest) ProcessInitial(ctx context.Context, res *Response, extCtx
if !hasUpdates {
return
}
res.E2EE = extRes // TODO: aggregate
// doesn't need aggregation as we just replace from the db
res.E2EE = extRes
}