mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
feature: Support account_data extension
Along with a battery of tests to make sure we give account data only for rooms being tracked in the sliding lists, unless it's global in which case we just send it on.
This commit is contained in:
parent
e8131a920f
commit
3ecef2062e
@ -285,3 +285,193 @@ func TestExtensionToDevice(t *testing.T) {
|
||||
// - do we need sessions at all? Can we delete if the since value is incremented?
|
||||
// - check with ios folks if this level of co-ordination between processes is possible.
|
||||
}
|
||||
|
||||
// tests that the account data extension works:
|
||||
// 1- check global account data is sent on first connection
|
||||
// 2- check global account data updates are proxied through
|
||||
// 3- check room account data for the list only is sent
|
||||
// 4- check room account data for subscriptions are sent
|
||||
// 5- when the range changes, make sure room account data is sent
|
||||
// 6- when a room bumps into a range, make sure room account data is sent
|
||||
func TestExtensionAccountData(t *testing.T) {
|
||||
pqString := testutils.PrepareDBConnectionString()
|
||||
// setup code
|
||||
v2 := runTestV2Server(t)
|
||||
v3 := runTestServer(t, v2, pqString)
|
||||
defer v2.close()
|
||||
defer v3.close()
|
||||
alice := "@alice:localhost"
|
||||
aliceToken := "ALICE_BEARER_TOKEN"
|
||||
roomA := "!a:localhost"
|
||||
roomB := "!b:localhost"
|
||||
roomC := "!c:localhost"
|
||||
globalAccountData := []json.RawMessage{
|
||||
testutils.NewAccountData(t, "im-global", map[string]interface{}{"body": "yep"}),
|
||||
testutils.NewAccountData(t, "im-also-global", map[string]interface{}{"body": "yep"}),
|
||||
}
|
||||
roomAAccountData := []json.RawMessage{
|
||||
testutils.NewAccountData(t, "im-a", map[string]interface{}{"body": "yep a"}),
|
||||
testutils.NewAccountData(t, "im-also-a", map[string]interface{}{"body": "yep A"}),
|
||||
}
|
||||
roomBAccountData := []json.RawMessage{
|
||||
testutils.NewAccountData(t, "im-b", map[string]interface{}{"body": "yep b"}),
|
||||
testutils.NewAccountData(t, "im-also-b", map[string]interface{}{"body": "yep B"}),
|
||||
}
|
||||
roomCAccountData := []json.RawMessage{
|
||||
testutils.NewAccountData(t, "im-c", map[string]interface{}{"body": "yep c"}),
|
||||
testutils.NewAccountData(t, "im-also-c", map[string]interface{}{"body": "yep C"}),
|
||||
}
|
||||
v2.addAccount(alice, aliceToken)
|
||||
v2.queueResponse(alice, sync2.SyncResponse{
|
||||
AccountData: sync2.EventsResponse{
|
||||
Events: globalAccountData,
|
||||
},
|
||||
Rooms: sync2.SyncRoomsResponse{
|
||||
Join: map[string]sync2.SyncV2JoinResponse{
|
||||
roomA: {
|
||||
State: sync2.EventsResponse{
|
||||
Events: createRoomState(t, alice, time.Now()),
|
||||
},
|
||||
AccountData: sync2.EventsResponse{
|
||||
Events: roomAAccountData,
|
||||
},
|
||||
},
|
||||
roomB: {
|
||||
State: sync2.EventsResponse{
|
||||
Events: createRoomState(t, alice, time.Now().Add(-1*time.Minute)),
|
||||
},
|
||||
AccountData: sync2.EventsResponse{
|
||||
Events: roomBAccountData,
|
||||
},
|
||||
},
|
||||
roomC: {
|
||||
State: sync2.EventsResponse{
|
||||
Events: createRoomState(t, alice, time.Now().Add(-2*time.Minute)),
|
||||
},
|
||||
AccountData: sync2.EventsResponse{
|
||||
Events: roomCAccountData,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 1- check global account data is sent on first connection
|
||||
// 3- check room account data for the list only is sent
|
||||
res := v3.mustDoV3Request(t, aliceToken, sync3.Request{
|
||||
Extensions: extensions.Request{
|
||||
AccountData: &extensions.AccountDataRequest{
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
Lists: []sync3.RequestList{{
|
||||
Ranges: sync3.SliceRanges{
|
||||
[2]int64{0, 1}, // first two rooms A,B
|
||||
},
|
||||
Sort: []string{sync3.SortByRecency},
|
||||
RoomSubscription: sync3.RoomSubscription{
|
||||
TimelineLimit: 0,
|
||||
},
|
||||
}},
|
||||
})
|
||||
MatchResponse(t, res, MatchAccountData(
|
||||
globalAccountData,
|
||||
map[string][]json.RawMessage{
|
||||
roomA: roomAAccountData,
|
||||
roomB: roomBAccountData,
|
||||
},
|
||||
))
|
||||
|
||||
// 5- when the range changes, make sure room account data is sent
|
||||
res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{
|
||||
Lists: []sync3.RequestList{{
|
||||
Ranges: sync3.SliceRanges{
|
||||
[2]int64{0, 2}, // A,B,C
|
||||
},
|
||||
}},
|
||||
})
|
||||
MatchResponse(t, res, MatchAccountData(
|
||||
nil,
|
||||
map[string][]json.RawMessage{
|
||||
roomC: roomCAccountData,
|
||||
},
|
||||
))
|
||||
|
||||
// 4- check room account data for subscriptions are sent
|
||||
res = v3.mustDoV3Request(t, aliceToken, sync3.Request{
|
||||
Extensions: extensions.Request{
|
||||
AccountData: &extensions.AccountDataRequest{
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
RoomSubscriptions: map[string]sync3.RoomSubscription{
|
||||
roomB: {
|
||||
TimelineLimit: 1,
|
||||
},
|
||||
},
|
||||
})
|
||||
MatchResponse(t, res, MatchAccountData(
|
||||
globalAccountData,
|
||||
map[string][]json.RawMessage{
|
||||
roomB: roomBAccountData,
|
||||
},
|
||||
))
|
||||
|
||||
// 2- check global account data updates are proxied through
|
||||
newGlobalEvent := testutils.NewAccountData(t, "new_fun_event", map[string]interface{}{"much": "excite"})
|
||||
v2.queueResponse(alice, sync2.SyncResponse{
|
||||
AccountData: sync2.EventsResponse{
|
||||
Events: []json.RawMessage{newGlobalEvent},
|
||||
},
|
||||
})
|
||||
v2.waitUntilEmpty(t, alice)
|
||||
res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{})
|
||||
MatchResponse(t, res, MatchAccountData(
|
||||
[]json.RawMessage{newGlobalEvent},
|
||||
nil,
|
||||
))
|
||||
|
||||
// 6- when a room bumps into a range, make sure room account data is sent
|
||||
res = v3.mustDoV3Request(t, aliceToken, sync3.Request{
|
||||
Extensions: extensions.Request{
|
||||
AccountData: &extensions.AccountDataRequest{
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
Lists: []sync3.RequestList{{
|
||||
Ranges: sync3.SliceRanges{
|
||||
[2]int64{0, 1}, // first two rooms A,B
|
||||
},
|
||||
Sort: []string{sync3.SortByRecency},
|
||||
RoomSubscription: sync3.RoomSubscription{
|
||||
TimelineLimit: 0,
|
||||
},
|
||||
}},
|
||||
})
|
||||
// bump C to position 0
|
||||
v2.queueResponse(alice, sync2.SyncResponse{
|
||||
Rooms: sync2.SyncRoomsResponse{
|
||||
Join: v2JoinTimeline(roomEvents{
|
||||
roomID: roomC,
|
||||
events: []json.RawMessage{
|
||||
testutils.NewEvent(t, "m.poke", alice, map[string]interface{}{}, time.Now().Add(time.Millisecond)),
|
||||
},
|
||||
}),
|
||||
},
|
||||
})
|
||||
v2.waitUntilEmpty(t, alice)
|
||||
// now we should get room account data for C
|
||||
res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{
|
||||
Lists: []sync3.RequestList{{
|
||||
Ranges: sync3.SliceRanges{
|
||||
[2]int64{0, 1}, // first two rooms A,B
|
||||
},
|
||||
}},
|
||||
})
|
||||
MatchResponse(t, res, MatchAccountData(
|
||||
nil,
|
||||
map[string][]json.RawMessage{
|
||||
roomC: roomCAccountData,
|
||||
},
|
||||
))
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/sync-v3/sqlutil"
|
||||
)
|
||||
|
||||
@ -72,6 +73,17 @@ func (t *AccountDataTable) Select(txn *sqlx.Tx, userID, eventType, roomID string
|
||||
return &acc, err
|
||||
}
|
||||
|
||||
func (t *AccountDataTable) SelectMany(txn *sqlx.Tx, userID string, roomIDs ...string) (datas []AccountData, err error) {
|
||||
if len(roomIDs) == 0 {
|
||||
err = txn.Select(&datas, `SELECT user_id, room_id, type, data FROM syncv3_account_data
|
||||
WHERE user_id=$1 AND room_id = $2`, userID, AccountDataGlobalRoom)
|
||||
return
|
||||
}
|
||||
err = txn.Select(&datas, `SELECT user_id, room_id, type, data FROM syncv3_account_data
|
||||
WHERE user_id=$1 AND room_id=ANY($2)`, userID, pq.StringArray(roomIDs))
|
||||
return
|
||||
}
|
||||
|
||||
type AccountDataChunker []AccountData
|
||||
|
||||
func (c AccountDataChunker) Len() int {
|
||||
|
@ -2,12 +2,26 @@ package state
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/matrix-org/sync-v3/sync2"
|
||||
)
|
||||
|
||||
func accountDatasEqual(gots, wants []AccountData) bool {
|
||||
key := func(a AccountData) string {
|
||||
return a.UserID + a.RoomID + a.Type
|
||||
}
|
||||
sort.Slice(gots, func(i, j int) bool {
|
||||
return key(gots[i]) < key(gots[j])
|
||||
})
|
||||
sort.Slice(wants, func(i, j int) bool {
|
||||
return key(wants[i]) < key(wants[j])
|
||||
})
|
||||
return reflect.DeepEqual(gots, wants)
|
||||
}
|
||||
|
||||
func TestAccountData(t *testing.T) {
|
||||
db, err := sqlx.Open("postgres", postgresConnectionString)
|
||||
if err != nil {
|
||||
@ -53,6 +67,12 @@ func TestAccountData(t *testing.T) {
|
||||
Type: eventType,
|
||||
Data: []byte(`{"foo":"bar4"}`),
|
||||
},
|
||||
{
|
||||
UserID: alice,
|
||||
RoomID: sync2.AccountDataGlobalRoom,
|
||||
Type: "dummy",
|
||||
Data: []byte(`{"foo":"bar5"}`),
|
||||
},
|
||||
// this should replace the first element
|
||||
{
|
||||
UserID: alice,
|
||||
@ -83,8 +103,41 @@ func TestAccountData(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Select: %s", err)
|
||||
}
|
||||
if !reflect.DeepEqual(*gotData, accountData[len(accountData)-2]) {
|
||||
t.Fatalf("Select: expected global event to be returned but wasn't. Got %+v want %+v", gotData, accountData[len(accountData)-2])
|
||||
if !reflect.DeepEqual(*gotData, accountData[len(accountData)-3]) {
|
||||
t.Fatalf("Select: expected global event to be returned but wasn't. Got %+v want %+v", gotData, accountData[len(accountData)-3])
|
||||
}
|
||||
|
||||
// Select all global events for alice
|
||||
wantDatas := []AccountData{
|
||||
accountData[4], accountData[5],
|
||||
}
|
||||
gotDatas, err := table.SelectMany(txn, alice)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectMany: %s", err)
|
||||
}
|
||||
if !accountDatasEqual(gotDatas, wantDatas) {
|
||||
t.Fatalf("SelectMany: got %v want %v", gotDatas, wantDatas)
|
||||
}
|
||||
|
||||
// Select all room events for alice
|
||||
wantDatas = []AccountData{
|
||||
accountData[6],
|
||||
}
|
||||
gotDatas, err = table.SelectMany(txn, alice, roomA)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectMany: %s", err)
|
||||
}
|
||||
if !accountDatasEqual(gotDatas, wantDatas) {
|
||||
t.Fatalf("SelectMany: got %v want %v", gotDatas, wantDatas)
|
||||
}
|
||||
|
||||
// Select all room events for unknown user
|
||||
gotDatas, err = table.SelectMany(txn, "@someone-else:localhost", roomA)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectMany: %s", err)
|
||||
}
|
||||
if len(gotDatas) != 0 {
|
||||
t.Fatalf("SelectMany: got %d account data, want 0", len(gotDatas))
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -61,6 +61,16 @@ func (s *Storage) AccountData(userID, roomID, eventType string) (data *AccountDa
|
||||
return
|
||||
}
|
||||
|
||||
// Pull out all account data for this user. If roomIDs is empty, global account data is returned.
|
||||
// If roomIDs is non-empty, all account data for these rooms are extracted.
|
||||
func (s *Storage) AccountDatas(userID string, roomIDs ...string) (datas []AccountData, err error) {
|
||||
err = sqlutil.WithTransaction(s.accumulator.db, func(txn *sqlx.Tx) error {
|
||||
datas, err = s.AccountDataTable.SelectMany(txn, userID, roomIDs...)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Storage) InsertAccountData(userID, roomID string, events []json.RawMessage) (data []AccountData, err error) {
|
||||
data = make([]AccountData, len(events))
|
||||
for i := range events {
|
||||
|
@ -25,7 +25,11 @@ type UnreadCountUpdate struct {
|
||||
HasCountDecreased bool
|
||||
}
|
||||
|
||||
type RoomAccountDataAlert struct {
|
||||
type AccountDataUpdate struct {
|
||||
AccountData []state.AccountData
|
||||
}
|
||||
|
||||
type RoomAccountDataUpdate struct {
|
||||
RoomUpdate
|
||||
AccountData []state.AccountData
|
||||
}
|
||||
|
@ -21,9 +21,9 @@ type UserCacheListener interface {
|
||||
// Called when there is an update affecting a room e.g new event, unread count update, room account data.
|
||||
// Type-cast to find out what the update is about.
|
||||
OnRoomUpdate(up RoomUpdate)
|
||||
// Called when there is an update affecting this user e.g global account data, presence.
|
||||
// Called when there is an update affecting this user but not in the room e.g global account data, presence.
|
||||
// Type-cast to find out what the update is about.
|
||||
// OnUserUpdate(up UserUpdate)
|
||||
OnUpdate(up Update)
|
||||
}
|
||||
|
||||
// Tracks data specific to a given user. Specifically, this is the map of room ID to UserRoomData.
|
||||
@ -222,7 +222,11 @@ func (c *UserCache) OnNewEvent(eventData *EventData) {
|
||||
}
|
||||
|
||||
func (c *UserCache) OnAccountData(datas []state.AccountData) {
|
||||
roomUpdates := make(map[string][]state.AccountData)
|
||||
for _, d := range datas {
|
||||
up := roomUpdates[d.RoomID]
|
||||
up = append(up, d)
|
||||
roomUpdates[d.RoomID] = up
|
||||
if d.Type == "m.direct" {
|
||||
dmRoomSet := make(map[string]struct{})
|
||||
// pull out rooms and mark them as DMs
|
||||
@ -249,5 +253,26 @@ func (c *UserCache) OnAccountData(datas []state.AccountData) {
|
||||
}
|
||||
c.roomToDataMu.Unlock()
|
||||
}
|
||||
|
||||
}
|
||||
// bucket account data updates per-room and globally then invoke listeners
|
||||
for roomID, updates := range roomUpdates {
|
||||
if roomID == state.AccountDataGlobalRoom {
|
||||
globalUpdate := &AccountDataUpdate{
|
||||
AccountData: updates,
|
||||
}
|
||||
for _, l := range c.listeners {
|
||||
l.OnUpdate(globalUpdate)
|
||||
}
|
||||
} else {
|
||||
roomUpdate := &RoomAccountDataUpdate{
|
||||
AccountData: updates,
|
||||
RoomUpdate: c.newRoomUpdate(roomID),
|
||||
}
|
||||
for _, l := range c.listeners {
|
||||
l.OnRoomUpdate(roomUpdate)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -9,24 +9,11 @@ import (
|
||||
|
||||
// Client created request params
|
||||
type AccountDataRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
GlobalAccountDataTypes []string `json:"global_account_data_types"`
|
||||
RoomAccountDataTypes map[int][]string `json:"room_account_data_types"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
func (r AccountDataRequest) ApplyDelta(next *AccountDataRequest) *AccountDataRequest {
|
||||
r.Enabled = next.Enabled
|
||||
if next.GlobalAccountDataTypes != nil {
|
||||
r.GlobalAccountDataTypes = next.GlobalAccountDataTypes
|
||||
}
|
||||
if next.RoomAccountDataTypes != nil {
|
||||
if r.RoomAccountDataTypes == nil {
|
||||
r.RoomAccountDataTypes = make(map[int][]string)
|
||||
}
|
||||
for listIndex, types := range next.RoomAccountDataTypes {
|
||||
r.RoomAccountDataTypes[listIndex] = types
|
||||
}
|
||||
}
|
||||
return &r
|
||||
}
|
||||
|
||||
@ -43,14 +30,74 @@ func (r *AccountDataResponse) HasData(isInitial bool) bool {
|
||||
return len(r.Rooms) > 0
|
||||
}
|
||||
|
||||
func ProcessAccountData(store *state.Storage, userCache *caches.UserCache, userID string, isInitial bool, req *AccountDataRequest) (res *AccountDataResponse) {
|
||||
if isInitial {
|
||||
// register with the user cache TODO: need destructor call to unregister
|
||||
// pull account data from store and return that
|
||||
func accountEventsAsJSON(events []state.AccountData) []json.RawMessage {
|
||||
j := make([]json.RawMessage, len(events))
|
||||
for i := range events {
|
||||
j[i] = events[i].Data
|
||||
}
|
||||
// on new account data callback, buffer it and then return it assuming isInitial=false
|
||||
return j
|
||||
}
|
||||
|
||||
// TODO: how to handle room account data? Need to know which rooms are being tracked.
|
||||
func ProcessLiveAccountData(up caches.Update, store *state.Storage, updateWillReturnResponse bool, userID string, req *AccountDataRequest) (res *AccountDataResponse) {
|
||||
switch update := up.(type) {
|
||||
case *caches.AccountDataUpdate:
|
||||
return &AccountDataResponse{
|
||||
Global: accountEventsAsJSON(update.AccountData),
|
||||
}
|
||||
case *caches.RoomAccountDataUpdate:
|
||||
return &AccountDataResponse{
|
||||
Rooms: map[string][]json.RawMessage{
|
||||
update.RoomID(): accountEventsAsJSON(update.AccountData),
|
||||
},
|
||||
}
|
||||
case caches.RoomUpdate:
|
||||
// this is a room update which is causing us to return, meaning we are interested in this room.
|
||||
// send account data for this room.
|
||||
if updateWillReturnResponse {
|
||||
roomAccountData, err := store.AccountDatas(userID, update.RoomID())
|
||||
if err != nil {
|
||||
logger.Err(err).Str("user", userID).Str("room", update.RoomID()).Msg("failed to fetch room account data")
|
||||
} else {
|
||||
return &AccountDataResponse{
|
||||
Rooms: map[string][]json.RawMessage{
|
||||
update.RoomID(): accountEventsAsJSON(roomAccountData),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ProcessAccountData(store *state.Storage, listRoomIDs map[string]struct{}, userID string, isInitial bool, req *AccountDataRequest) (res *AccountDataResponse) {
|
||||
roomIDs := make([]string, len(listRoomIDs))
|
||||
i := 0
|
||||
for roomID := range listRoomIDs {
|
||||
roomIDs[i] = roomID
|
||||
i++
|
||||
}
|
||||
res = &AccountDataResponse{}
|
||||
// room account data needs to be sent every time the user scrolls the list to get new room IDs
|
||||
// TODO: remember which rooms the client has been told about
|
||||
if len(roomIDs) > 0 {
|
||||
roomsAccountData, err := store.AccountDatas(userID, roomIDs...)
|
||||
if err != nil {
|
||||
logger.Err(err).Str("user", userID).Strs("rooms", roomIDs).Msg("failed to fetch room account data")
|
||||
} else {
|
||||
res.Rooms = make(map[string][]json.RawMessage)
|
||||
for _, ad := range roomsAccountData {
|
||||
res.Rooms[ad.RoomID] = append(res.Rooms[ad.RoomID], ad.Data)
|
||||
}
|
||||
}
|
||||
}
|
||||
// global account data is only sent on the first connection, then we live stream
|
||||
if isInitial {
|
||||
globalAccountData, err := store.AccountDatas(userID)
|
||||
if err != nil {
|
||||
logger.Err(err).Str("user", userID).Msg("failed to fetch global account data")
|
||||
} else {
|
||||
res.Global = accountEventsAsJSON(globalAccountData)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -29,6 +29,9 @@ func (r Request) ApplyDelta(next *Request) Request {
|
||||
if next.E2EE != nil {
|
||||
r.E2EE = r.E2EE.ApplyDelta(next.E2EE)
|
||||
}
|
||||
if next.AccountData != nil {
|
||||
r.AccountData = r.AccountData.ApplyDelta(next.AccountData)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
@ -45,8 +48,8 @@ func (e Response) HasData(isInitial bool) bool {
|
||||
}
|
||||
|
||||
type HandlerInterface interface {
|
||||
Handle(req Request, userCache *caches.UserCache, isInitial bool) (res Response)
|
||||
HandleLiveData(req Request, res *Response, userCache *caches.UserCache, isInitial bool)
|
||||
Handle(req Request, listRoomIDs map[string]struct{}, isInitial bool) (res Response)
|
||||
HandleLiveUpdate(update caches.Update, req Request, res *Response, updateWillReturnResponse, isInitial bool)
|
||||
}
|
||||
|
||||
type Handler struct {
|
||||
@ -54,12 +57,13 @@ type Handler struct {
|
||||
E2EEFetcher sync2.E2EEFetcher
|
||||
}
|
||||
|
||||
func (h *Handler) HandleLiveData(req Request, res *Response, userCache *caches.UserCache, isInitial bool) {
|
||||
// TODO: Define live data event in caches pkg, NOT ConnEvent.
|
||||
// Update `Response` object.
|
||||
func (h *Handler) HandleLiveUpdate(update caches.Update, req Request, res *Response, updateWillReturnResponse, isInitial bool) {
|
||||
if req.AccountData != nil && req.AccountData.Enabled {
|
||||
res.AccountData = ProcessLiveAccountData(update, h.Store, updateWillReturnResponse, req.UserID, req.AccountData)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) Handle(req Request, userCache *caches.UserCache, isInitial bool) (res Response) {
|
||||
func (h *Handler) Handle(req Request, listRoomIDs map[string]struct{}, isInitial bool) (res Response) {
|
||||
if req.ToDevice != nil && req.ToDevice.Enabled != nil && *req.ToDevice.Enabled {
|
||||
res.ToDevice = ProcessToDevice(h.Store, req.UserID, req.DeviceID, req.ToDevice)
|
||||
}
|
||||
@ -67,7 +71,7 @@ func (h *Handler) Handle(req Request, userCache *caches.UserCache, isInitial boo
|
||||
res.E2EE = ProcessE2EE(h.E2EEFetcher, req.UserID, req.DeviceID, req.E2EE)
|
||||
}
|
||||
if req.AccountData != nil && req.AccountData.Enabled {
|
||||
res.AccountData = ProcessAccountData(h.Store, userCache, req.UserID, isInitial, req.AccountData)
|
||||
res.AccountData = ProcessAccountData(h.Store, listRoomIDs, req.UserID, isInitial, req.AccountData)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -162,9 +162,13 @@ func (s *ConnState) onIncomingRequest(ctx context.Context, req *sync3.Request, i
|
||||
responseOperations = append(responseOperations, ops...)
|
||||
}
|
||||
|
||||
includedRoomIDs := sync3.IncludedRoomIDsInOps(responseOperations)
|
||||
for _, roomID := range newSubs { // include room subs in addition to lists
|
||||
includedRoomIDs[roomID] = struct{}{}
|
||||
}
|
||||
// Handle extensions AFTER processing lists as extensions may need to know which rooms the client
|
||||
// is being notified about (e.g. for room account data)
|
||||
response.Extensions = s.extensionsHandler.Handle(ex, s.userCache, isInitial)
|
||||
response.Extensions = s.extensionsHandler.Handle(ex, includedRoomIDs, isInitial)
|
||||
|
||||
// do live tracking if we have nothing to tell the client yet
|
||||
responseOperations = s.live.liveUpdate(ctx, req, ex, isInitial, response, responseOperations)
|
||||
@ -418,6 +422,10 @@ func (s *ConnState) moveRoom(
|
||||
|
||||
}
|
||||
|
||||
func (s *ConnState) OnUpdate(up caches.Update) {
|
||||
s.live.onUpdate(up)
|
||||
}
|
||||
|
||||
// Called by the user cache when updates arrive
|
||||
func (s *ConnState) OnRoomUpdate(up caches.RoomUpdate) {
|
||||
switch update := up.(type) {
|
||||
|
@ -62,12 +62,14 @@ func (s *connStateLive) liveUpdate(
|
||||
return responseOperations
|
||||
case update := <-s.updates:
|
||||
responseOperations = s.processLiveUpdate(update, responseOperations, response)
|
||||
updateWillReturnResponse := len(responseOperations) > 0
|
||||
// pass event to extensions AFTER processing
|
||||
s.extensionsHandler.HandleLiveData(ex, &response.Extensions, s.userCache, isInitial)
|
||||
s.extensionsHandler.HandleLiveUpdate(update, ex, &response.Extensions, updateWillReturnResponse, isInitial)
|
||||
// if there's more updates and we don't have lots stacked up already, go ahead and process another
|
||||
for len(s.updates) > 0 && len(responseOperations) < 50 {
|
||||
update = <-s.updates
|
||||
responseOperations = s.processLiveUpdate(update, responseOperations, response)
|
||||
s.extensionsHandler.HandleLiveUpdate(update, ex, &response.Extensions, updateWillReturnResponse, isInitial)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -19,12 +19,11 @@ import (
|
||||
|
||||
type NopExtensionHandler struct{}
|
||||
|
||||
func (h *NopExtensionHandler) Handle(req extensions.Request, uc *caches.UserCache, isInitial bool) (res extensions.Response) {
|
||||
func (h *NopExtensionHandler) Handle(req extensions.Request, listRoomIDs map[string]struct{}, isInitial bool) (res extensions.Response) {
|
||||
return
|
||||
}
|
||||
|
||||
func (h *NopExtensionHandler) HandleLiveData(req extensions.Request, res *extensions.Response, uc *caches.UserCache, isInitial bool) {
|
||||
return
|
||||
func (h *NopExtensionHandler) HandleLiveUpdate(u caches.Update, req extensions.Request, res *extensions.Response, updateWillReturnResponse, isInitial bool) {
|
||||
}
|
||||
|
||||
type NopJoinTracker struct{}
|
||||
|
@ -69,6 +69,8 @@ type pointInfo struct {
|
||||
isOpen bool
|
||||
}
|
||||
|
||||
// TODO: A,B,C track A,B then B,C incorrectly keeps B?
|
||||
|
||||
// Delta returns the ranges which are unchanged, added and removed.
|
||||
// Intelligently handles overlaps.
|
||||
func (r SliceRanges) Delta(next SliceRanges) (added SliceRanges, removed SliceRanges, same SliceRanges) {
|
||||
|
@ -75,6 +75,20 @@ func (r *Response) UnmarshalJSON(b []byte) error {
|
||||
|
||||
type ResponseOp interface {
|
||||
Op() string
|
||||
// which rooms are we giving data about
|
||||
IncludedRoomIDs() []string
|
||||
}
|
||||
|
||||
// Return which room IDs these set of operations are returning information on. Information means
|
||||
// things like SYNC/INSERT/UPDATE, and not DELETE/INVALIDATE.
|
||||
func IncludedRoomIDsInOps(ops []ResponseOp) map[string]struct{} {
|
||||
set := make(map[string]struct{})
|
||||
for _, o := range ops {
|
||||
for _, roomID := range o.IncludedRoomIDs() {
|
||||
set[roomID] = struct{}{}
|
||||
}
|
||||
}
|
||||
return set
|
||||
}
|
||||
|
||||
type ResponseOpRange struct {
|
||||
@ -87,6 +101,16 @@ type ResponseOpRange struct {
|
||||
func (r *ResponseOpRange) Op() string {
|
||||
return r.Operation
|
||||
}
|
||||
func (r *ResponseOpRange) IncludedRoomIDs() []string {
|
||||
if r.Op() == OpInvalidate {
|
||||
return nil // the rooms are being excluded
|
||||
}
|
||||
roomIDs := make([]string, len(r.Rooms))
|
||||
for i := range r.Rooms {
|
||||
roomIDs[i] = r.Rooms[i].RoomID
|
||||
}
|
||||
return roomIDs
|
||||
}
|
||||
|
||||
type ResponseOpSingle struct {
|
||||
Operation string `json:"op"`
|
||||
@ -98,3 +122,12 @@ type ResponseOpSingle struct {
|
||||
func (r *ResponseOpSingle) Op() string {
|
||||
return r.Operation
|
||||
}
|
||||
|
||||
func (r *ResponseOpSingle) IncludedRoomIDs() []string {
|
||||
if r.Op() == OpDelete || r.Room == nil {
|
||||
return nil // the room is being excluded
|
||||
}
|
||||
return []string{
|
||||
r.Room.RoomID,
|
||||
}
|
||||
}
|
||||
|
@ -89,3 +89,18 @@ func NewEvent(t *testing.T, evType, sender string, content interface{}, originSe
|
||||
}
|
||||
return j
|
||||
}
|
||||
|
||||
func NewAccountData(t *testing.T, evType string, content interface{}) json.RawMessage {
|
||||
e := struct {
|
||||
Type string `json:"type"`
|
||||
Content interface{} `json:"content"`
|
||||
}{
|
||||
Type: evType,
|
||||
Content: content,
|
||||
}
|
||||
j, err := json.Marshal(&e)
|
||||
if err != nil {
|
||||
t.Fatalf("NewAccountData: failed to make event JSON: %s", err)
|
||||
}
|
||||
return j
|
||||
}
|
||||
|
47
v3_test.go
47
v3_test.go
@ -11,6 +11,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
@ -605,6 +606,34 @@ func MatchV3Ops(matchOps ...opMatcher) respMatcher {
|
||||
}
|
||||
}
|
||||
|
||||
func MatchAccountData(globals []json.RawMessage, rooms map[string][]json.RawMessage) respMatcher {
|
||||
return func(res *sync3.Response) error {
|
||||
if res.Extensions.AccountData == nil {
|
||||
return fmt.Errorf("MatchAccountData: no account_data extension")
|
||||
}
|
||||
if len(globals) > 0 {
|
||||
if err := equalAnyOrder(res.Extensions.AccountData.Global, globals); err != nil {
|
||||
return fmt.Errorf("MatchAccountData[global]: %s", err)
|
||||
}
|
||||
}
|
||||
if len(rooms) > 0 {
|
||||
if len(rooms) != len(res.Extensions.AccountData.Rooms) {
|
||||
return fmt.Errorf("MatchAccountData: got %d rooms with account data, want %d", len(res.Extensions.AccountData.Rooms), len(rooms))
|
||||
}
|
||||
for roomID := range rooms {
|
||||
gots := res.Extensions.AccountData.Rooms[roomID]
|
||||
if gots == nil {
|
||||
return fmt.Errorf("MatchAccountData: want room account data for %s but it was missing", roomID)
|
||||
}
|
||||
if err := equalAnyOrder(gots, rooms[roomID]); err != nil {
|
||||
return fmt.Errorf("MatchAccountData[room]: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func MatchResponse(t *testing.T, res *sync3.Response, matchers ...respMatcher) {
|
||||
t.Helper()
|
||||
for _, m := range matchers {
|
||||
@ -619,3 +648,21 @@ func MatchResponse(t *testing.T, res *sync3.Response, matchers ...respMatcher) {
|
||||
func ptr(i int) *int {
|
||||
return &i
|
||||
}
|
||||
|
||||
func equalAnyOrder(got, want []json.RawMessage) error {
|
||||
if len(got) != len(want) {
|
||||
return fmt.Errorf("equalAnyOrder: got %d, want %d", len(got), len(want))
|
||||
}
|
||||
sort.Slice(got, func(i, j int) bool {
|
||||
return string(got[i]) < string(got[j])
|
||||
})
|
||||
sort.Slice(want, func(i, j int) bool {
|
||||
return string(want[i]) < string(want[j])
|
||||
})
|
||||
for i := range got {
|
||||
if !reflect.DeepEqual(got[i], want[i]) {
|
||||
return fmt.Errorf("equalAnyOrder: [%d] got %v want %v", i, string(got[i]), string(want[i]))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user