Merge pull request #363 from matrix-org/dmr/resnapshot-3

This commit is contained in:
David Robertson 2023-11-10 11:38:12 +00:00 committed by GitHub
commit cbd3c3c5c0
12 changed files with 1080 additions and 220 deletions

View File

@ -141,69 +141,62 @@ type InitialiseResult struct {
// AddedEvents is true iff this call to Initialise added new state events to the DB.
AddedEvents bool
// SnapshotID is the ID of the snapshot which incorporates all added events.
// It has no meaning if AddedEvents is False.
// It has no meaning if AddedEvents is false.
SnapshotID int64
// PrependTimelineEvents is empty if the room was not initialised prior to this call.
// Otherwise, it is an order-preserving subset of the `state` argument to Initialise
// containing all events that were not persisted prior to the Initialise call. These
// should be prepended to the room timeline by the caller.
PrependTimelineEvents []json.RawMessage
// ReplacedExistingSnapshot is true when we created a new snapshot for the room and
// there was a pre-existing room snapshot. It has no meaning if AddedEvents is false.
ReplacedExistingSnapshot bool
}
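For orientation (not part of this change itself): the two booleans above are consumed by the poller handler later in this PR, with ReplacedExistingSnapshot taking priority over AddedEvents. A minimal sketch of that consumer, reusing identifiers that appear further down in this diff:

	res, err := h.Store.Initialise(roomID, state)
	if err != nil {
		return err
	}
	if res.ReplacedExistingSnapshot {
		// An existing snapshot was clobbered: downstream caches must be rebuilt.
		h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2InvalidateRoom{RoomID: roomID})
	} else if res.AddedEvents {
		// First snapshot for this room: seed caches with the new snapshot ID.
		h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Initialise{RoomID: roomID, SnapshotNID: res.SnapshotID})
	}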
// Initialise starts a new sync accumulator for the given room using the given state as a baseline.
// Initialise processes the state block of a V2 sync response for a particular room. If
// the state of the room has changed, we persist any new state events and create a new
// "snapshot" of its entire state.
//
// This will only take effect if this is the first time the v3 server has seen this room, and it wasn't
// possible to get all events up to the create event (e.g Matrix HQ).
// This function:
// - Stores these events
// - Sets up the current snapshot based on the state list given.
// Summary of the logic:
//
// If the v3 server has seen this room before, this function
// - queries the DB to determine which state events are known to the server,
// - returns (via InitialiseResult.PrependTimelineEvents) a slice of unknown state events,
// 0. Ensure the state block is not empty.
//
// and otherwise does nothing.
// 1. Capture the current snapshot ID, possibly zero. If it is zero, ensure that the
// state block contains a `create event`.
//
// 2. Insert the events. If there are no newly inserted events, bail. If there are new
// events, then the state block has definitely changed. Note: we ignore cases where
// the state has only changed to a known subset of state events (i.e. in the case of
// state resets, slow pollers) as it is impossible to then reconcile that state with
// any new events, as any "catchup" state will be ignored due to the events already
// existing.
//
// 3. Fetch the current state of the room, as a map from (type, state_key) to event.
// If there is no existing state snapshot, this map is the empty map.
// If the state hasn't altered, bail.
//
// 4. Create new snapshot. Update the map from (3) with the events in `state`.
// (There is similar logic for this in Accumulate.)
// Store the snapshot. Mark the room's current state as being this snapshot.
//
// 5. Any other processing of the new state events.
//
// 6. Return an "AddedEvents" bool (if true, emit an Initialise payload) and a
// "ReplacedSnapshot" bool (if true, emit a cache invalidation payload).
func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (InitialiseResult, error) {
var res InitialiseResult
var startingSnapshotID int64
// 0. Ensure the state block is not empty.
if len(state) == 0 {
return res, nil
}
err := sqlutil.WithTransaction(a.db, func(txn *sqlx.Tx) error {
err := sqlutil.WithTransaction(a.db, func(txn *sqlx.Tx) (err error) {
// 1. Capture the current snapshot ID, checking for a create event if this is our first snapshot.
// Attempt to short-circuit. This has to be done inside a transaction to make sure
// we don't race with multiple calls to Initialise with the same room ID.
snapshotID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID)
startingSnapshotID, err = a.roomsTable.CurrentAfterSnapshotID(txn, roomID)
if err != nil {
return fmt.Errorf("error fetching snapshot id for room %s: %s", roomID, err)
return fmt.Errorf("error fetching snapshot id for room %s: %w", roomID, err)
}
if snapshotID > 0 {
// Poller A has received a gappy sync v2 response with a state block, and
// we have seen this room before. If we knew for certain that there is some
// other active poller B in this room then we could safely skip this logic.
// Log at debug for now. If we find an unknown event, we'll return it so
// that the poller can log a warning.
logger.Debug().Str("room_id", roomID).Int64("snapshot_id", snapshotID).Msg("Accumulator.Initialise called with incremental state but current snapshot already exists.")
eventIDs := make([]string, len(state))
eventIDToRawEvent := make(map[string]json.RawMessage, len(state))
for i := range state {
eventID := gjson.ParseBytes(state[i]).Get("event_id")
if !eventID.Exists() || eventID.Type != gjson.String {
return fmt.Errorf("Event %d lacks an event ID", i)
}
eventIDToRawEvent[eventID.Str] = state[i]
eventIDs[i] = eventID.Str
}
unknownEventIDs, err := a.eventsTable.SelectUnknownEventIDs(txn, eventIDs)
if err != nil {
return fmt.Errorf("error determing which event IDs are unknown: %s", err)
}
for unknownEventID := range unknownEventIDs {
res.PrependTimelineEvents = append(res.PrependTimelineEvents, eventIDToRawEvent[unknownEventID])
}
return nil
}
// We don't have a snapshot for this room. Parse the events first.
// Start by parsing the events in the state block.
events := make([]Event, len(state))
for i := range events {
events[i] = Event{
@ -214,77 +207,77 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia
}
events = filterAndEnsureFieldsSet(events)
if len(events) == 0 {
return fmt.Errorf("failed to insert events, all events were filtered out: %w", err)
return fmt.Errorf("failed to parse state block, all events were filtered out: %w", err)
}
// Before proceeding further, ensure that we have "proper" state and not just a
// single stray event by looking for the create event.
hasCreate := false
for _, e := range events {
if e.Type == "m.room.create" && e.StateKey == "" {
hasCreate = true
break
if startingSnapshotID == 0 {
// Ensure that we have "proper" state and not "stray" events from Synapse.
if err = ensureStateHasCreateEvent(events); err != nil {
return err
}
}
if !hasCreate {
const errMsg = "cannot create first snapshot without a create event"
sentry.WithScope(func(scope *sentry.Scope) {
scope.SetContext(internal.SentryCtxKey, map[string]interface{}{
"room_id": roomID,
"len_state": len(events),
})
sentry.CaptureMessage(errMsg)
})
logger.Warn().
Str("room_id", roomID).
Int("len_state", len(events)).
Msg(errMsg)
// the HS gave us bad data so there's no point retrying => return DataError
return internal.NewDataError(errMsg)
}
// Insert the events.
eventIDToNID, err := a.eventsTable.Insert(txn, events, false)
// 2. Insert the events and determine which ones are new.
newEventIDToNID, err := a.eventsTable.Insert(txn, events, false)
if err != nil {
return fmt.Errorf("failed to insert events: %w", err)
}
if len(eventIDToNID) == 0 {
// we don't have a current snapshot for this room but yet no events are new,
// no idea how this should be handled.
const errMsg = "Accumulator.Initialise: room has no current snapshot but also no new inserted events, doing nothing. This is probably a bug."
logger.Error().Str("room_id", roomID).Msg(errMsg)
sentry.CaptureException(fmt.Errorf(errMsg))
if len(newEventIDToNID) == 0 {
if startingSnapshotID == 0 {
// we don't have a current snapshot for this room, yet no events are new;
// it's unclear how this should be handled.
const errMsg = "Accumulator.Initialise: room has no current snapshot but also no new inserted events, doing nothing. This is probably a bug."
logger.Error().Str("room_id", roomID).Msg(errMsg)
sentry.CaptureException(fmt.Errorf(errMsg))
}
// Note: we otherwise ignore cases where the state has only changed to a
// known subset of state events (i.e. in the case of state resets, slow
// pollers) as it is impossible to then reconcile that state with
// any new events, as any "catchup" state will be ignored due to the events
// already existing.
return nil
}
// pull out the event NIDs we just inserted
membershipEventIDs := make(map[string]struct{}, len(events))
newEvents := make([]Event, 0, len(newEventIDToNID))
for _, event := range events {
if event.Type == "m.room.member" {
membershipEventIDs[event.ID] = struct{}{}
}
}
memberNIDs := make([]int64, 0, len(eventIDToNID))
otherNIDs := make([]int64, 0, len(eventIDToNID))
for evID, nid := range eventIDToNID {
if _, exists := membershipEventIDs[evID]; exists {
memberNIDs = append(memberNIDs, int64(nid))
} else {
otherNIDs = append(otherNIDs, int64(nid))
newNid, isNew := newEventIDToNID[event.ID]
if isNew {
event.NID = newNid
newEvents = append(newEvents, event)
}
}
// Make a current snapshot
// 3. Fetch the current state of the room.
var currentState stateMap
if startingSnapshotID > 0 {
currentState, err = a.stateMapAtSnapshot(txn, startingSnapshotID)
if err != nil {
return fmt.Errorf("failed to load state map: %w", err)
}
} else {
currentState = stateMap{
// Typically expect Other to be small, but Memberships may be large (think: Matrix HQ.)
Memberships: make(map[string]int64, len(events)),
Other: make(map[[2]string]int64),
}
}
// 4. Update the map from (3) with the new events to create a new snapshot.
for _, ev := range newEvents {
currentState.Ingest(ev)
}
memberNIDs, otherNIDs := currentState.NIDs()
snapshot := &SnapshotRow{
RoomID: roomID,
MembershipEvents: pq.Int64Array(memberNIDs),
OtherEvents: pq.Int64Array(otherNIDs),
MembershipEvents: memberNIDs,
OtherEvents: otherNIDs,
}
err = a.snapshotTable.Insert(txn, snapshot)
if err != nil {
return fmt.Errorf("failed to insert snapshot: %w", err)
}
res.AddedEvents = true
// 5. Any other processing of new state events.
latestNID := int64(0)
for _, nid := range otherNIDs {
if nid > latestNID {
@ -313,8 +306,16 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia
// will have an associated state snapshot ID on the event.
// Set the snapshot ID as the current state
err = a.roomsTable.Upsert(txn, info, snapshot.SnapshotID, latestNID)
if err != nil {
return err
}
// 6. Tell the caller what happened, so they know what payloads to emit.
res.SnapshotID = snapshot.SnapshotID
return a.roomsTable.Upsert(txn, info, snapshot.SnapshotID, latestNID)
res.AddedEvents = true
res.ReplacedExistingSnapshot = startingSnapshotID > 0
return nil
})
return res, err
}
@ -652,3 +653,82 @@ func (a *Accumulator) filterToNewTimelineEvents(txn *sqlx.Tx, dedupedEvents []Ev
// A is seen event s[A,B,C] => s[0+1:] => [B,C]
return dedupedEvents[seenIndex+1:], nil
}
func ensureStateHasCreateEvent(events []Event) error {
hasCreate := false
for _, e := range events {
if e.Type == "m.room.create" && e.StateKey == "" {
hasCreate = true
break
}
}
if !hasCreate {
const errMsg = "cannot create first snapshot without a create event"
sentry.WithScope(func(scope *sentry.Scope) {
scope.SetContext(internal.SentryCtxKey, map[string]interface{}{
"room_id": events[0].RoomID,
"len_state": len(events),
})
sentry.CaptureMessage(errMsg)
})
logger.Warn().
Str("room_id", events[0].RoomID).
Int("len_state", len(events)).
Msg(errMsg)
// the HS gave us bad data so there's no point retrying => return DataError
return internal.NewDataError(errMsg)
}
return nil
}
type stateMap struct {
// state_key (user id) -> NID
Memberships map[string]int64
// type, state_key -> NID
Other map[[2]string]int64
}
func (s *stateMap) Ingest(e Event) (replacedNID int64) {
if e.Type == "m.room.member" {
replacedNID = s.Memberships[e.StateKey]
s.Memberships[e.StateKey] = e.NID
} else {
key := [2]string{e.Type, e.StateKey}
replacedNID = s.Other[key]
s.Other[key] = e.NID
}
return
}
func (s *stateMap) NIDs() (membershipNIDs, otherNIDs []int64) {
membershipNIDs = make([]int64, 0, len(s.Memberships))
otherNIDs = make([]int64, 0, len(s.Other))
for _, nid := range s.Memberships {
membershipNIDs = append(membershipNIDs, nid)
}
for _, nid := range s.Other {
otherNIDs = append(otherNIDs, nid)
}
return
}
func (a *Accumulator) stateMapAtSnapshot(txn *sqlx.Tx, snapID int64) (stateMap, error) {
snapshot, err := a.snapshotTable.Select(txn, snapID)
if err != nil {
return stateMap{}, err
}
// pull stripped events as this may be huge (think Matrix HQ)
events, err := a.eventsTable.SelectStrippedEventsByNIDs(txn, true, append(snapshot.MembershipEvents, snapshot.OtherEvents...))
if err != nil {
return stateMap{}, err
}
state := stateMap{
Memberships: make(map[string]int64, len(snapshot.MembershipEvents)),
Other: make(map[[2]string]int64, len(snapshot.OtherEvents)),
}
for _, e := range events {
state.Ingest(e)
}
return state, nil
}
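Aside (not part of this change): a minimal illustration of the stateMap helper above, with hypothetical Event values. Ingest replaces any NID previously stored under the same (type, state_key), so NIDs() yields exactly one event NID per state key.

	sm := stateMap{
		Memberships: make(map[string]int64),
		Other:       make(map[[2]string]int64),
	}
	sm.Ingest(Event{Type: "m.room.member", StateKey: "@alice:test", NID: 1})
	sm.Ingest(Event{Type: "m.room.name", StateKey: "", NID: 2})
	replaced := sm.Ingest(Event{Type: "m.room.name", StateKey: "", NID: 3}) // replaced == 2
	memberNIDs, otherNIDs := sm.NIDs() // memberNIDs == [1], otherNIDs == [3]
	_, _, _ = replaced, memberNIDs, otherNIDs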

View File

@ -35,9 +35,8 @@ func TestAccumulatorInitialise(t *testing.T) {
if err != nil {
t.Fatalf("falied to Initialise accumulator: %s", err)
}
if !res.AddedEvents {
t.Fatalf("didn't add events, wanted it to")
}
assertValue(t, "res.AddedEvents", res.AddedEvents, true)
assertValue(t, "res.ReplacedExistingSnapshot", res.ReplacedExistingSnapshot, false)
txn, err := accumulator.db.Beginx()
if err != nil {
@ -46,21 +45,21 @@ func TestAccumulatorInitialise(t *testing.T) {
defer txn.Rollback()
// There should be one snapshot on the current state
snapID, err := accumulator.roomsTable.CurrentAfterSnapshotID(txn, roomID)
snapID1, err := accumulator.roomsTable.CurrentAfterSnapshotID(txn, roomID)
if err != nil {
t.Fatalf("failed to select current snapshot: %s", err)
}
if snapID == 0 {
if snapID1 == 0 {
t.Fatalf("Initialise did not store a current snapshot")
}
if snapID != res.SnapshotID {
t.Fatalf("Initialise returned wrong snapshot ID, got %v want %v", res.SnapshotID, snapID)
if snapID1 != res.SnapshotID {
t.Fatalf("Initialise returned wrong snapshot ID, got %v want %v", res.SnapshotID, snapID1)
}
// this snapshot should have 1 member event and 2 other events in it
row, err := accumulator.snapshotTable.Select(txn, snapID)
row, err := accumulator.snapshotTable.Select(txn, snapID1)
if err != nil {
t.Fatalf("failed to select snapshot %d: %s", snapID, err)
t.Fatalf("failed to select snapshot %d: %s", snapID1, err)
}
if len(row.MembershipEvents) != 1 {
t.Fatalf("got %d membership events, want %d in current state snapshot", len(row.MembershipEvents), 1)
@ -87,7 +86,7 @@ func TestAccumulatorInitialise(t *testing.T) {
}
}
// Subsequent calls do nothing and are not an error
// Subsequent calls with the same set of events do nothing and are not an error.
res, err = accumulator.Initialise(roomID, roomEvents)
if err != nil {
t.Fatalf("falied to Initialise accumulator: %s", err)
@ -95,6 +94,37 @@ func TestAccumulatorInitialise(t *testing.T) {
if res.AddedEvents {
t.Fatalf("added events when it shouldn't have")
}
// Subsequent calls with a subset of events do nothing and are not an error
res, err = accumulator.Initialise(roomID, roomEvents[:2])
if err != nil {
t.Fatalf("falied to Initialise accumulator: %s", err)
}
if res.AddedEvents {
t.Fatalf("added events when it shouldn't have")
}
// Subsequent calls with at least one new event expand or replace existing state.
// C, D, E
roomEvents2 := append(roomEvents[2:3],
[]byte(`{"event_id":"D", "type":"m.room.topic", "state_key":"", "content":{"topic":"Dr Rick Dagless MD"}}`),
[]byte(`{"event_id":"E", "type":"m.room.member", "state_key":"@me:localhost", "content":{"membership":"join", "displayname": "Garth""}}`),
)
res, err = accumulator.Initialise(roomID, roomEvents2)
assertNoError(t, err)
assertValue(t, "res.AddedEvents", res.AddedEvents, true)
assertValue(t, "res.ReplacedExistingSnapshot", res.ReplacedExistingSnapshot, true)
snapID2, err := accumulator.roomsTable.CurrentAfterSnapshotID(txn, roomID)
assertNoError(t, err)
if snapID2 == snapID1 || snapID2 == 0 {
t.Errorf("Expected snapID2 (%d) to be neither snapID1 (%d) nor 0", snapID2, snapID1)
}
row, err = accumulator.snapshotTable.Select(txn, snapID2)
assertNoError(t, err)
assertValue(t, "len(row.MembershipEvents)", len(row.MembershipEvents), 1)
assertValue(t, "len(row.OtherEvents)", len(row.OtherEvents), 3)
}
// Test that an unknown room shouldn't initialise if given state without a create event.
@ -115,9 +145,9 @@ func TestAccumulatorInitialiseBadInputs(t *testing.T) {
func TestAccumulatorAccumulate(t *testing.T) {
roomID := "!TestAccumulatorAccumulate:localhost"
roomEvents := []json.RawMessage{
[]byte(`{"event_id":"D", "type":"m.room.create", "state_key":"", "content":{"creator":"@me:localhost"}}`),
[]byte(`{"event_id":"E", "type":"m.room.member", "state_key":"@me:localhost", "content":{"membership":"join"}}`),
[]byte(`{"event_id":"F", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
[]byte(`{"event_id":"G", "type":"m.room.create", "state_key":"", "content":{"creator":"@me:localhost"}}`),
[]byte(`{"event_id":"H", "type":"m.room.member", "state_key":"@me:localhost", "content":{"membership":"join"}}`),
[]byte(`{"event_id":"I", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
}
db, close := connectToDB(t)
defer close()
@ -130,11 +160,11 @@ func TestAccumulatorAccumulate(t *testing.T) {
// accumulate new state makes a new snapshot and removes the old snapshot
newEvents := []json.RawMessage{
// non-state event does nothing
[]byte(`{"event_id":"G", "type":"m.room.message","content":{"body":"Hello World","msgtype":"m.text"}}`),
[]byte(`{"event_id":"J", "type":"m.room.message","content":{"body":"Hello World","msgtype":"m.text"}}`),
// join_rules should clobber the one from initialise
[]byte(`{"event_id":"H", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
[]byte(`{"event_id":"K", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
// new state event should be added to the snapshot
[]byte(`{"event_id":"I", "type":"m.room.history_visibility", "state_key":"", "content":{"visibility":"public"}}`),
[]byte(`{"event_id":"L", "type":"m.room.history_visibility", "state_key":"", "content":{"visibility":"public"}}`),
}
var result AccumulateResult
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {

View File

@ -367,6 +367,8 @@ func (s *Storage) ResetMetadataState(metadata *internal.RoomMetadata) error {
// For now, don't bother reloading Encrypted, PredecessorID and UpgradedRoomID.
// These shouldn't be changing during a room's lifetime in normal operation.
// We haven't updated LatestEventsByType because that's not part of the timeline.
return nil
}

View File

@ -370,20 +370,24 @@ func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID strin
return nil
}
func (h *Handler) Initialise(ctx context.Context, roomID string, state []json.RawMessage) ([]json.RawMessage, error) {
func (h *Handler) Initialise(ctx context.Context, roomID string, state []json.RawMessage) error {
res, err := h.Store.Initialise(roomID, state)
if err != nil {
logger.Err(err).Int("state", len(state)).Str("room", roomID).Msg("V2: failed to initialise room")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
return nil, err
return err
}
if res.AddedEvents {
if res.ReplacedExistingSnapshot {
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2InvalidateRoom{
RoomID: roomID,
})
} else if res.AddedEvents {
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Initialise{
RoomID: roomID,
SnapshotNID: res.SnapshotID,
})
}
return res.PrependTimelineEvents, nil
return nil
}
func (h *Handler) SetTyping(ctx context.Context, pollerID sync2.PollerID, roomID string, ephEvent json.RawMessage) {

View File

@ -42,7 +42,7 @@ type V2DataReceiver interface {
// Initialise the room, if it hasn't been already. This means the state section of the v2 response.
// If given a state delta from an incremental sync, persists any new state events; nothing is returned to the caller.
// Return an error to stop the since token advancing.
Initialise(ctx context.Context, roomID string, state []json.RawMessage) ([]json.RawMessage, error) // snapshot ID?
Initialise(ctx context.Context, roomID string, state []json.RawMessage) error // snapshot ID?
// SetTyping indicates which users are typing.
SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage)
// Sent when there is a new receipt
@ -328,11 +328,11 @@ func (h *PollerMap) Accumulate(ctx context.Context, userID, deviceID, roomID str
wg.Wait()
return
}
func (h *PollerMap) Initialise(ctx context.Context, roomID string, state []json.RawMessage) (result []json.RawMessage, err error) {
func (h *PollerMap) Initialise(ctx context.Context, roomID string, state []json.RawMessage) (err error) {
var wg sync.WaitGroup
wg.Add(1)
h.executor <- func() {
result, err = h.callbacks.Initialise(ctx, roomID, state)
err = h.callbacks.Initialise(ctx, roomID, state)
wg.Done()
}
wg.Wait()
@ -791,7 +791,10 @@ func (p *poller) parseRoomsResponse(ctx context.Context, res *SyncResponse) erro
for roomID, roomData := range res.Rooms.Join {
if len(roomData.State.Events) > 0 {
stateCalls++
prependStateEvents, err := p.receiver.Initialise(ctx, roomID, roomData.State.Events)
if roomData.Timeline.Limited {
p.trackGappyStateSize(len(roomData.State.Events))
}
err := p.receiver.Initialise(ctx, roomID, roomData.State.Events)
if err != nil {
_, ok := err.(*internal.DataError)
if ok {
@ -836,25 +839,6 @@ func (p *poller) parseRoomsResponse(ctx context.Context, res *SyncResponse) erro
continue
}
}
if len(prependStateEvents) > 0 {
// The poller has just learned of these state events due to an
// incremental poller sync; we must have missed the opportunity to see
// these down /sync in a timeline. As a workaround, inject these into
// the timeline now so that future events are received under the
// correct room state.
const warnMsg = "parseRoomsResponse: prepending state events to timeline after gappy poll"
logger.Warn().Str("room_id", roomID).Int("prependStateEvents", len(prependStateEvents)).Msg(warnMsg)
hub := internal.GetSentryHubFromContextOrDefault(ctx)
hub.WithScope(func(scope *sentry.Scope) {
scope.SetContext(internal.SentryCtxKey, map[string]interface{}{
"room_id": roomID,
"num_prepend_state_events": len(prependStateEvents),
})
hub.CaptureMessage(warnMsg)
})
p.trackGappyStateSize(len(prependStateEvents))
roomData.Timeline.Events = append(prependStateEvents, roomData.Timeline.Events...)
}
}
// process typing/receipts before events so we seed the caches correctly for when we return the room
for _, ephEvent := range roomData.Ephemeral.Events {

View File

@ -830,8 +830,8 @@ func TestPollerResendsOnCallbackError(t *testing.T) {
// generate a receiver which errors for the right callback
generateReceiver: func() V2DataReceiver {
return &overrideDataReceiver{
initialise: func(ctx context.Context, roomID string, state []json.RawMessage) ([]json.RawMessage, error) {
return nil, fmt.Errorf("initialise error")
initialise: func(ctx context.Context, roomID string, state []json.RawMessage) error {
return fmt.Errorf("initialise error")
},
}
},
@ -1273,7 +1273,7 @@ func (a *mockDataReceiver) Accumulate(ctx context.Context, userID, deviceID, roo
a.timelines[roomID] = append(a.timelines[roomID], timeline.Events...)
return nil
}
func (a *mockDataReceiver) Initialise(ctx context.Context, roomID string, state []json.RawMessage) ([]json.RawMessage, error) {
func (a *mockDataReceiver) Initialise(ctx context.Context, roomID string, state []json.RawMessage) error {
a.states[roomID] = state
if a.incomingProcess != nil {
a.incomingProcess <- struct{}{}
@ -1283,7 +1283,7 @@ func (a *mockDataReceiver) Initialise(ctx context.Context, roomID string, state
}
// Initialise no longer returns events to prepend to the timeline; just return a
// nil error here.
return nil, nil
return nil
}
func (s *mockDataReceiver) UpdateDeviceSince(ctx context.Context, userID, deviceID, since string) {
s.mu.Lock()
@ -1296,7 +1296,7 @@ func (s *mockDataReceiver) UpdateDeviceSince(ctx context.Context, userID, device
type overrideDataReceiver struct {
accumulate func(ctx context.Context, userID, deviceID, roomID, prevBatch string, timeline []json.RawMessage) error
initialise func(ctx context.Context, roomID string, state []json.RawMessage) ([]json.RawMessage, error)
initialise func(ctx context.Context, roomID string, state []json.RawMessage) error
setTyping func(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage)
updateDeviceSince func(ctx context.Context, userID, deviceID, since string)
addToDeviceMessages func(ctx context.Context, userID, deviceID string, msgs []json.RawMessage) error
@ -1316,9 +1316,9 @@ func (s *overrideDataReceiver) Accumulate(ctx context.Context, userID, deviceID,
}
return s.accumulate(ctx, userID, deviceID, roomID, timeline.PrevBatch, timeline.Events)
}
func (s *overrideDataReceiver) Initialise(ctx context.Context, roomID string, state []json.RawMessage) ([]json.RawMessage, error) {
func (s *overrideDataReceiver) Initialise(ctx context.Context, roomID string, state []json.RawMessage) error {
if s.initialise == nil {
return nil, nil
return nil
}
return s.initialise(ctx, roomID, state)
}

View File

@ -771,10 +771,3 @@ func (u *UserCache) ShouldIgnore(userID string) bool {
_, ignored := u.ignoredUsers[userID]
return ignored
}
func (u *UserCache) OnInvalidateRoom(ctx context.Context, roomID string) {
// Nothing for now. In UserRoomData the fields dependent on room state are
// IsDM, IsInvite, HasLeft, Invite, CanonicalisedName, ResolvedAvatarURL, Spaces.
// Not clear to me if we need to reload these or if we will inherit any changes from
// the global cache.
}

View File

@ -181,6 +181,23 @@ func (m *ConnMap) connIDsForDevice(userID, deviceID string) []ConnID {
return connIDs
}
// CloseConnsForUsers closes all conns for a given slice of users. Returns the number of
// conns closed.
func (m *ConnMap) CloseConnsForUsers(userIDs []string) (closed int) {
m.mu.Lock()
defer m.mu.Unlock()
for _, userID := range userIDs {
conns := m.userIDToConn[userID]
logger.Trace().Str("user", userID).Int("num_conns", len(conns)).Msg("closing all device connections due to CloseConn()")
for _, conn := range conns {
m.cache.Remove(conn.String()) // this will fire TTL callbacks which calls closeConn
}
closed += len(conns)
}
return closed
}
func (m *ConnMap) closeConnExpires(connID string, value interface{}) {
m.mu.Lock()
defer m.mu.Unlock()

View File

@ -24,7 +24,6 @@ type Receiver interface {
OnNewEvent(ctx context.Context, event *caches.EventData)
OnReceipt(ctx context.Context, receipt internal.Receipt)
OnEphemeralEvent(ctx context.Context, roomID string, ephEvent json.RawMessage)
OnInvalidateRoom(ctx context.Context, roomID string)
// OnRegistered is called after a successful call to Dispatcher.Register
OnRegistered(ctx context.Context) error
}
@ -62,6 +61,24 @@ func (d *Dispatcher) Unregister(userID string) {
delete(d.userToReceiver, userID)
}
// UnregisterBulk accepts a slice of user IDs to unregister. The given users need not
// already be registered (in which case unregistering them is a no-op). Returns the
// list of users that were unregistered.
func (d *Dispatcher) UnregisterBulk(userIDs []string) []string {
d.userToReceiverMu.Lock()
defer d.userToReceiverMu.Unlock()
unregistered := make([]string, 0)
for _, userID := range userIDs {
_, exists := d.userToReceiver[userID]
if exists {
delete(d.userToReceiver, userID)
unregistered = append(unregistered, userID)
}
}
return unregistered
}
func (d *Dispatcher) Register(ctx context.Context, userID string, r Receiver) error {
d.userToReceiverMu.Lock()
defer d.userToReceiverMu.Unlock()
@ -276,22 +293,7 @@ func (d *Dispatcher) notifyListeners(ctx context.Context, ed *caches.EventData,
}
}
func (d *Dispatcher) OnInvalidateRoom(ctx context.Context, roomID string) {
// First dispatch to the global cache.
receiver, ok := d.userToReceiver[DispatcherAllUsers]
if !ok {
logger.Error().Msgf("No receiver for global cache")
}
receiver.OnInvalidateRoom(ctx, roomID)
// Then dispatch to any users who are joined to that room.
joinedUsers, _ := d.jrt.JoinedUsersForRoom(roomID, nil)
d.userToReceiverMu.RLock()
defer d.userToReceiverMu.RUnlock()
for _, userID := range joinedUsers {
receiver = d.userToReceiver[userID]
if receiver != nil {
receiver.OnInvalidateRoom(ctx, roomID)
}
}
func (d *Dispatcher) OnInvalidateRoom(roomID string, joins, invites []string) {
// Reset the joined room tracker.
d.jrt.ReloadMembershipsForRoom(roomID, joins, invites)
}

View File

@ -63,6 +63,11 @@ type SyncLiveHandler struct {
setupHistVec *prometheus.HistogramVec
histVec *prometheus.HistogramVec
slowReqs prometheus.Counter
// destroyedConns is the number of connections that have been destroyed after
// a room invalidation payload.
// TODO: could make this a CounterVec labelled by reason, to track expiry due
// to update buffer filling, expiry due to inactivity, etc. (See the sketch
// after this struct.)
destroyedConns prometheus.Counter
}
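One possible shape for the TODO above, sketched as an assumption rather than as part of this change: the prometheus client supports labelled counters, so the destruction reason could be tracked with a CounterVec.

	destroyedConns := prometheus.NewCounterVec(prometheus.CounterOpts{
		Namespace: "sliding_sync",
		Subsystem: "api",
		Name:      "destroyed_conns",
		Help:      "Counter of conns that were destroyed, labelled by reason.",
	}, []string{"reason"})
	prometheus.MustRegister(destroyedConns)
	// ...later, e.g. when handling a room invalidation payload:
	destroyedConns.WithLabelValues("room_invalidation").Add(float64(destroyed))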
func NewSync3Handler(
@ -139,6 +144,9 @@ func (h *SyncLiveHandler) Teardown() {
if h.slowReqs != nil {
prometheus.Unregister(h.slowReqs)
}
if h.destroyedConns != nil {
prometheus.Unregister(h.destroyedConns)
}
}
func (h *SyncLiveHandler) addPrometheusMetrics() {
@ -162,9 +170,17 @@ func (h *SyncLiveHandler) addPrometheusMetrics() {
Name: "slow_requests",
Help: "Counter of slow (>=50s) requests, initial or otherwise.",
})
h.destroyedConns = prometheus.NewCounter(prometheus.CounterOpts{
Namespace: "sliding_sync",
Subsystem: "api",
Name: "destroyed_conns",
Help: "Counter of conns that were destroyed.",
})
prometheus.MustRegister(h.setupHistVec)
prometheus.MustRegister(h.histVec)
prometheus.MustRegister(h.slowReqs)
prometheus.MustRegister(h.destroyedConns)
}
func (h *SyncLiveHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
@ -818,7 +834,46 @@ func (h *SyncLiveHandler) OnInvalidateRoom(p *pubsub.V2InvalidateRoom) {
ctx, task := internal.StartTask(context.Background(), "OnInvalidateRoom")
defer task.End()
h.Dispatcher.OnInvalidateRoom(ctx, p.RoomID)
// 1. Reload the global cache.
h.GlobalCache.OnInvalidateRoom(ctx, p.RoomID)
// Work out who is affected.
joins, invites, leaves, err := h.Storage.FetchMemberships(p.RoomID)
involvedUsers := make([]string, 0, len(joins)+len(invites)+len(leaves))
involvedUsers = append(involvedUsers, joins...)
involvedUsers = append(involvedUsers, invites...)
involvedUsers = append(involvedUsers, leaves...)
if err != nil {
hub := internal.GetSentryHubFromContextOrDefault(ctx)
hub.WithScope(func(scope *sentry.Scope) {
scope.SetContext(internal.SentryCtxKey, map[string]any{
"room_id": p.RoomID,
})
hub.CaptureException(err)
})
logger.Err(err).
Str("room_id", p.RoomID).
Msg("Failed to fetch members after cache invalidation")
return
}
// 2. Reload the joined-room tracker.
h.Dispatcher.OnInvalidateRoom(p.RoomID, joins, invites)
// 3. Destroy involved users' caches.
// We filter to only those users which had a userCache registered to receive updates.
unregistered := h.Dispatcher.UnregisterBulk(involvedUsers)
for _, userID := range unregistered {
h.userCaches.Delete(userID)
}
// 4. Destroy involved users' connections.
// Since creating a conn creates a user cache, it is safe to loop over
destroyed := h.ConnMap.CloseConnsForUsers(unregistered)
if h.destroyedConns != nil {
h.destroyedConns.Add(float64(destroyed))
}
}
func parseIntFromQuery(u *url.URL, param string) (result int64, err *internal.HandlerError) {

View File

@ -1,6 +1,7 @@
package syncv3_test
import (
"encoding/json"
"fmt"
"testing"
@ -65,10 +66,21 @@ func TestGappyState(t *testing.T) {
Content: nameContent,
})
t.Log("Alice sends lots of message events (more than the poller will request in a timeline.")
var latestMessageID string
for i := 0; i < 51; i++ {
latestMessageID = alice.Unsafe_SendEventUnsynced(t, roomID, b.Event{
t.Log("Alice sends lots of other state events.")
const numOtherState = 40
for i := 0; i < numOtherState; i++ {
alice.Unsafe_SendEventUnsynced(t, roomID, b.Event{
Type: "com.example.dummy",
StateKey: ptr(fmt.Sprintf("%d", i)),
Content: map[string]any{},
})
}
t.Log("Alice sends a batch of message events.")
const numMessages = 20
var lastMsgID string
for i := 0; i < numMessages; i++ {
lastMsgID = alice.Unsafe_SendEventUnsynced(t, roomID, b.Event{
Type: "m.room.message",
Content: map[string]interface{}{
"msgtype": "m.text",
@ -77,28 +89,50 @@ func TestGappyState(t *testing.T) {
})
}
t.Log("Alice requests an initial sliding sync on device 2.")
t.Logf("The proxy is now %d events behind the HS, which should trigger a limited sync", 1+numOtherState+numMessages)
t.Log("Alice requests an initial sliding sync on device 2, with timeline limit big enough to see her first message at the start of the test.")
syncResp = alice.SlidingSync(t,
sync3.Request{
Lists: map[string]sync3.RequestList{
"a": {
Ranges: [][2]int64{{0, 20}},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 10,
TimelineLimit: 100,
},
},
},
},
)
t.Log("She should see her latest message with the room name updated")
// We're testing here that the state events from the gappy poll are NOT injected
// into the timeline. The poll is only going to use timeline limit 1 because it's
// the first poll on a new device. See integration test for a "proper" gappy poll.
t.Log("She should see the updated room name, her most recent message, but NOT the state events in the gap nor messages from before the gap.")
m.MatchResponse(
t,
syncResp,
m.MatchRoomSubscription(
roomID,
m.MatchRoomName("potato"),
MatchRoomTimelineMostRecent(1, []Event{{ID: latestMessageID}}),
MatchRoomTimelineMostRecent(1, []Event{{ID: lastMsgID}}),
func(r sync3.Room) error {
for _, rawEv := range r.Timeline {
var ev Event
err := json.Unmarshal(rawEv, &ev)
if err != nil {
t.Fatal(err)
}
// Shouldn't see the state events, only messages
if ev.Type != "m.room.message" {
return fmt.Errorf("timeline contained event %s of type %s (expected m.room.message)", ev.ID, ev.Type)
}
if ev.ID == firstMessageID {
return fmt.Errorf("timeline contained first message from before the gap")
}
}
return nil
},
),
)
}

View File

@ -119,18 +119,14 @@ func TestSecondPollerFiltersToDevice(t *testing.T) {
m.MatchResponse(t, res, m.MatchToDeviceMessages([]json.RawMessage{wantMsg}))
}
// Test that the poller makes a best-effort attempt to integrate state seen in a
// v2 sync state block. Our strategy for doing so is to prepend any unknown state events
// to the start of the v2 sync response's timeline, which should then be visible to
// sync v3 clients as ordinary state events in the room timeline.
func TestPollerHandlesUnknownStateEventsOnIncrementalSync(t *testing.T) {
// FIXME: this should resolve once we update downstream caches
t.Skip("We will never see the name/PL event in the timeline with the new code due to those events being part of the state block.")
pqString := testutils.PrepareDBConnectionString()
v2 := runTestV2Server(t)
v3 := runTestServer(t, v2, pqString)
defer v2.close()
defer v3.close()
t.Log("Alice creates a room.")
v2.addAccount(t, alice, aliceToken)
const roomID = "!unimportant"
v2.queueResponse(aliceToken, sync2.SyncResponse{
@ -141,18 +137,21 @@ func TestPollerHandlesUnknownStateEventsOnIncrementalSync(t *testing.T) {
}),
},
})
res := v3.mustDoV3Request(t, aliceToken, sync3.Request{
t.Log("Alice sliding syncs, explicitly requesting power levels.")
aliceReq := sync3.Request{
Lists: map[string]sync3.RequestList{
"a": {
Ranges: [][2]int64{{0, 20}},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 10,
RequiredState: [][2]string{{"m.room.power_levels", ""}},
},
},
},
})
}
res := v3.mustDoV3Request(t, aliceToken, aliceReq)
t.Log("The poller receives a gappy incremental sync response with a state block. The power levels and room name have changed.")
t.Log("Alice's poller receives a gappy poll with a state block. The power levels and room name have changed.")
nameEvent := testutils.NewStateEvent(
t,
"m.room.name",
@ -187,37 +186,26 @@ func TestPollerHandlesUnknownStateEventsOnIncrementalSync(t *testing.T) {
},
},
})
v2.waitUntilEmpty(t, aliceToken)
res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{})
m.MatchResponse(
t,
res,
m.MatchRoomSubscription(
roomID,
func(r sync3.Room) error {
// syncv2 doesn't assign any meaning to the order of events in a state
// block, so check for both possibilities
nameFirst := m.MatchRoomTimeline([]json.RawMessage{nameEvent, powerLevelsEvent, messageEvent})
powerLevelsFirst := m.MatchRoomTimeline([]json.RawMessage{powerLevelsEvent, nameEvent, messageEvent})
if nameFirst(r) != nil && powerLevelsFirst(r) != nil {
return fmt.Errorf("did not see state before message")
}
return nil
},
m.MatchRoomName("banana"),
),
)
t.Log("Alice incremental sliding syncs.")
_, respBytes, statusCode := v3.doV3Request(t, context.Background(), aliceToken, res.Pos, sync3.Request{})
t.Log("The server should have closed the long-polling session.")
assertUnknownPos(t, respBytes, statusCode)
t.Log("Alice sliding syncs from scratch.")
res = v3.mustDoV3Request(t, aliceToken, aliceReq)
t.Log("Alice sees the new room name and power levels.")
m.MatchResponse(t, res, m.MatchRoomSubscription(roomID,
m.MatchRoomRequiredState([]json.RawMessage{powerLevelsEvent}),
m.MatchRoomName("banana"),
))
}
// Similar to TestPollerHandlesUnknownStateEventsOnIncrementalSync. Here we are testing
// that if Alice's poller sees Bob leave in a state block, the events seen in that
// timeline are not visible to Bob.
func TestPollerUpdatesRoomMemberTrackerOnGappySyncStateBlock(t *testing.T) {
// the room state should update to make bob no longer be a member, which should update downstream caches
// DO WE SEND THESE GAPPY STATES TO THE CLIENT? It's NOT part of the timeline, but we need to let the client
// know somehow? I think the best case here would be to invalidate that _room_ (if that were possible in the API)
// to force the client to resync the state.
t.Skip("figure out what the valid thing to do here is")
pqString := testutils.PrepareDBConnectionString()
v2 := runTestV2Server(t)
v3 := runTestServer(t, v2, pqString)
@ -312,15 +300,21 @@ func TestPollerUpdatesRoomMemberTrackerOnGappySyncStateBlock(t *testing.T) {
},
},
})
v2.waitUntilEmpty(t, aliceToken)
t.Log("Bob makes an incremental sliding sync request.")
bobRes = v3.mustDoV3RequestWithPos(t, bobToken, bobRes.Pos, sync3.Request{})
t.Log("He should see his leave event in the room timeline.")
_, respBytes, statusCode := v3.doV3Request(t, context.Background(), bobToken, bobRes.Pos, sync3.Request{})
assertUnknownPos(t, respBytes, statusCode)
t.Log("Bob makes a new sliding sync session.")
bobRes = v3.mustDoV3Request(t, bobToken, syncRequest)
t.Log("He shouldn't see any evidence of the room.")
m.MatchResponse(
t,
bobRes,
m.MatchList("a", m.MatchV3Count(1)),
m.MatchRoomSubscription(roomID, m.MatchRoomTimelineMostRecent(1, []json.RawMessage{bobLeave})),
m.MatchList("a", m.MatchV3Count(0)),
m.MatchRoomSubscriptionsStrict(nil),
)
}
@ -600,3 +594,668 @@ func TestTimelineStopsLoadingWhenMissingPrevious(t *testing.T) {
m.MatchRoomPrevBatch("dummyPrevBatch"),
))
}
// The "prepend state events" mechanism added in
// https://github.com/matrix-org/sliding-sync/pull/71 ensured that the proxy
// communicated state events in "gappy syncs" to users. But it did so via Accumulate,
// which made one snapshot for each state event. This was not an accurate model of the
// room's history (the state block comes in no particular order) and had awful
// performance for large gappy states.
//
// We now want to handle these in Initialise, making a single snapshot for the state
// block. This test ensures that is the case. The logic is very similar to the e2e test
// TestGappyState.
func TestGappyStateDoesNotAccumulateTheStateBlock(t *testing.T) {
pqString := testutils.PrepareDBConnectionString()
v2 := runTestV2Server(t)
defer v2.close()
v3 := runTestServer(t, v2, pqString)
defer v3.close()
v2.addAccount(t, alice, aliceToken)
v2.addAccount(t, bob, bobToken)
t.Log("Alice creates a room, sets its name and sends a message.")
const roomID = "!unimportant"
name1 := testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]any{
"name": "wonderland",
})
msg1 := testutils.NewMessageEvent(t, alice, "0118 999 881 999 119 7253")
joinTimeline := v2JoinTimeline(roomEvents{
roomID: roomID,
events: append(
createRoomState(t, alice, time.Now()),
name1,
msg1,
),
})
v2.queueResponse(aliceToken, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: joinTimeline,
},
})
t.Log("Alice sliding syncs with a huge timeline limit, subscribing to the room she just created.")
aliceReq := sync3.Request{
RoomSubscriptions: map[string]sync3.RoomSubscription{
roomID: {TimelineLimit: 100},
},
}
res := v3.mustDoV3Request(t, aliceToken, aliceReq)
t.Log("Alice sees the room with the expected name, with the name event and message at the end of the timeline.")
m.MatchResponse(t, res, m.MatchRoomSubscription(roomID,
m.MatchRoomName("wonderland"),
m.MatchRoomTimelineMostRecent(2, []json.RawMessage{name1, msg1}),
))
t.Log("Alice's poller receives a gappy sync, including a room name change, bob joining, and two messages.")
stateBlock := make([]json.RawMessage, 0)
for i := 0; i < 10; i++ {
statePiece := testutils.NewStateEvent(t, "com.example.custom", fmt.Sprintf("%d", i), alice, map[string]any{})
stateBlock = append(stateBlock, statePiece)
}
name2 := testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]any{
"name": "not wonderland",
})
bobJoin := testutils.NewJoinEvent(t, bob)
stateBlock = append(stateBlock, name2, bobJoin)
msg2 := testutils.NewMessageEvent(t, alice, "Good morning!")
msg3 := testutils.NewMessageEvent(t, alice, "That's a nice tnetennba.")
v2.queueResponse(aliceToken, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: map[string]sync2.SyncV2JoinResponse{
roomID: {
State: sync2.EventsResponse{
Events: stateBlock,
},
Timeline: sync2.TimelineResponse{
Events: []json.RawMessage{msg2, msg3},
Limited: true,
PrevBatch: "dummyPrevBatch",
},
},
},
},
})
v2.waitUntilEmpty(t, aliceToken)
t.Log("Alice syncs. The server should close her long-polling session.")
_, respBytes, statusCode := v3.doV3Request(t, context.Background(), aliceToken, res.Pos, sync3.Request{})
assertUnknownPos(t, respBytes, statusCode)
t.Log("Alice sliding syncs from scratch. She should see the two most recent message in the timeline only. The room name should have changed too.")
res = v3.mustDoV3Request(t, aliceToken, aliceReq)
m.MatchResponse(t, res, m.MatchRoomSubscription(roomID,
m.MatchRoomName("not wonderland"),
// In particular, we shouldn't see state here because it's not part of the timeline.
// Nor should we see msg1, as that comes before a gap.
m.MatchRoomTimeline([]json.RawMessage{msg2, msg3}),
))
}
// Right, this has turned out to be very involved. This test has three varying
// parameters:
// - Bert's initial membership (in 3 below),
// - his final membership in (5), and
// - whether his sync in (6) is initial or long-polling ("live").
//
// The test:
// 1. Registers two users Ana and Bert.
// 2. Has Ana create a public room.
// 3. Sets an initial membership for Bert in that room.
// 4. Sliding syncs for Bert, if he will live-sync in (6) below.
// 5. Gives Ana's poller a gappy poll in which Bert's membership changes.
// 6. Has Bert do a sliding sync.
// 7. Ana invites Bert to a DM.
//
// We perform the following assertions:
// - After (3), Ana sees her membership, Bert's initial membership, appropriate
// join and invite counts, and an appropriate timeline.
// - If applicable: after (4), Bert sees his initial membership.
// - After (5), Ana's connection is closed. When opening a new one, she sees her
// membership, Bert's new membership, and the post-gap timeline.
// - After (6), Bert's connection is closed if he was expecting a live update.
// - After (6), Bert sees his new membership (if there is anything to see).
// - After (7), Bert sees the DM invite.
//
// Remarks:
// - Use a per-test Ana and Bert here so we don't clash with the global constants
// alice and bob.
// - We're feeding all this information in via Ana's poller to check that stuff
// propagates from her poller to Bert's client. However, when Bert's membership is
// "invite" we need to directly send the invite to his poller.
// - Step (7) serves as a sentinel to prove that the proxy has processed (5) in the
// case where there is nothing for Bert to see in (6), e.g. a preemptive ban or
// an unban during the gap.
// - Testing all the membership transitions is likely overkill. But it was useful
// for finding edge cases in the proxy's assumptions at first, before we decided to
// nuke conns and userCaches and start from scratch.
func TestClientsSeeMembershipTransitionsInGappyPolls(t *testing.T) {
pqString := testutils.PrepareDBConnectionString()
v2 := runTestV2Server(t)
// TODO remove this? Otherwise running tests is sloooooow
v2.timeToWaitForV2Response /= 20
defer v2.close()
v3 := runTestServer(t, v2, pqString)
defer v3.close()
type testcase struct {
// Inputs
beforeMembership string
afterMembership string
viaLiveUpdate bool
// Scratch space
id string
ana string
anaToken string
bert string
bertToken string
publicRoomID string // room that will receive gappy state
dmRoomID string // DM between ana and bert, used to send a sentinel message
}
var tcs []testcase
transitions := map[string][]string{
// before: {possible after}
// https://spec.matrix.org/v1.8/client-server-api/#room-membership for the list of allowed transitions
"none": {"ban", "invite", "join", "leave"},
"invite": {"ban", "join", "leave"},
// Note: can also join->join here e.g. for displayname change, but will do that in a separate test
"join": {"ban", "leave"},
"leave": {"ban", "invite", "join"},
"ban": {"leave"},
}
for before, afterOptions := range transitions {
for _, after := range afterOptions {
for _, live := range []bool{true, false} {
idStr := fmt.Sprintf("%s-%s", before, after)
if live {
idStr += "-live"
}
tc := testcase{
beforeMembership: before,
afterMembership: after,
viaLiveUpdate: live,
id: idStr,
publicRoomID: fmt.Sprintf("!%s-public", idStr),
dmRoomID: fmt.Sprintf("!%s-dm", idStr),
// Using ana and bert to stop myself from pulling in package-level constants alice and bob
ana: fmt.Sprintf("@ana-%s:localhost", idStr),
bert: fmt.Sprintf("@bert-%s:localhost", idStr),
}
tc.anaToken = tc.ana + "_token"
tc.bertToken = tc.bert + "_token"
tcs = append(tcs, tc)
}
}
}
ssRequest := sync3.Request{
Lists: map[string]sync3.RequestList{
"a": {
Ranges: sync3.SliceRanges{{0, 10}},
RoomSubscription: sync3.RoomSubscription{
RequiredState: [][2]string{{"m.room.member", "*"}},
TimelineLimit: 20,
},
},
},
}
setup := func(t *testing.T, tc testcase) (publicEvents []json.RawMessage, anaMembership json.RawMessage, anaRes *sync3.Response) {
// 1. Register two users Ana and Bert.
v2.addAccount(t, tc.ana, tc.anaToken)
v2.addAccount(t, tc.bert, tc.bertToken)
// 2. Have Ana create a public room.
t.Log("Ana creates a public room.")
publicEvents = createRoomState(t, tc.ana, time.Now())
for _, ev := range publicEvents {
parsed := gjson.ParseBytes(ev)
if parsed.Get("type").Str == "m.room.member" && parsed.Get("state_key").Str == tc.ana {
anaMembership = ev
break
}
}
// 3. Set an initial membership for Bert.
var wantJoinCount int
var wantInviteCount int
var bertMembership json.RawMessage
switch tc.beforeMembership {
case "none":
t.Log("Bert has no membership in the room.")
wantJoinCount = 1
wantInviteCount = 0
case "invite":
t.Log("Bert is invited.")
bertMembership = testutils.NewStateEvent(t, "m.room.member", tc.bert, tc.ana, map[string]any{"membership": "invite"})
wantJoinCount = 1
wantInviteCount = 1
case "join":
t.Log("Bert joins the room.")
bertMembership = testutils.NewJoinEvent(t, tc.bert)
wantJoinCount = 2
wantInviteCount = 0
case "leave":
t.Log("Bert is pre-emptively kicked.")
bertMembership = testutils.NewStateEvent(t, "m.room.member", tc.bert, tc.ana, map[string]any{"membership": "leave"})
wantJoinCount = 1
wantInviteCount = 0
case "ban":
t.Log("Bert is banned.")
bertMembership = testutils.NewStateEvent(t, "m.room.member", tc.bert, tc.ana, map[string]any{"membership": "ban"})
wantJoinCount = 1
wantInviteCount = 0
default:
panic(fmt.Errorf("unknown beforeMembership %s", tc.beforeMembership))
}
if len(bertMembership) > 0 {
publicEvents = append(publicEvents, bertMembership)
}
t.Log("Ana's poller sees the public room for the first time.")
v2.queueResponse(tc.anaToken, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: map[string]sync2.SyncV2JoinResponse{
tc.publicRoomID: {
Timeline: sync2.TimelineResponse{
Events: publicEvents,
PrevBatch: "anaPublicPrevBatch1",
},
},
},
},
NextBatch: "anaSync1",
})
t.Log("Ana sliding syncs, requesting all room members.")
anaRes = v3.mustDoV3Request(t, tc.anaToken, ssRequest)
t.Log("She sees herself joined to both rooms, with appropriate timelines and counts.")
// Note: we only expect timeline[1:] here, not the create event. See
// https://github.com/matrix-org/sliding-sync/issues/343
expectedMembers := []json.RawMessage{anaMembership}
if len(bertMembership) > 0 {
expectedMembers = append(expectedMembers, bertMembership)
}
m.MatchResponse(t, anaRes,
m.MatchRoomSubscription(tc.publicRoomID,
m.MatchRoomTimeline(publicEvents[1:]),
m.MatchRoomRequiredState(expectedMembers),
m.MatchJoinCount(wantJoinCount),
m.MatchInviteCount(wantInviteCount),
),
)
return
}
gappyPoll := func(t *testing.T, tc testcase, anaMembership json.RawMessage, anaRes *sync3.Response) (newMembership json.RawMessage, publicTimeline []json.RawMessage) {
t.Logf("Ana's poller gets a gappy sync response for the public room. Bert's membership is now %s, and Ana has sent 10 messages.", tc.afterMembership)
publicTimeline = make([]json.RawMessage, 10)
for i := range publicTimeline {
publicTimeline[i] = testutils.NewMessageEvent(t, tc.ana, fmt.Sprintf("hello %d", i))
}
switch tc.afterMembership {
case "invite":
newMembership = testutils.NewStateEvent(t, "m.room.member", tc.bert, tc.ana, map[string]any{"membership": "invite"})
case "join":
newMembership = testutils.NewJoinEvent(t, tc.bert)
case "leave":
newMembership = testutils.NewStateEvent(t, "m.room.member", tc.bert, tc.ana, map[string]any{"membership": "leave"})
case "ban":
newMembership = testutils.NewStateEvent(t, "m.room.member", tc.bert, tc.ana, map[string]any{"membership": "ban"})
default:
panic(fmt.Errorf("unknown afterMembership %s", tc.afterMembership))
}
v2.queueResponse(tc.anaToken, sync2.SyncResponse{
NextBatch: "ana2",
Rooms: sync2.SyncRoomsResponse{
Join: map[string]sync2.SyncV2JoinResponse{
tc.publicRoomID: {
State: sync2.EventsResponse{
Events: []json.RawMessage{newMembership},
},
Timeline: sync2.TimelineResponse{
Events: publicTimeline,
Limited: true,
PrevBatch: "anaPublicPrevBatch2",
},
},
},
},
})
v2.waitUntilEmpty(t, tc.anaToken)
if tc.afterMembership == "invite" {
t.Log("Bert's poller sees his invite.")
v2.queueResponse(tc.bertToken, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Invite: map[string]sync2.SyncV2InviteResponse{
tc.publicRoomID: {
InviteState: sync2.EventsResponse{
// TODO: this really ought to be stripped state events
Events: []json.RawMessage{anaMembership, newMembership},
},
},
}},
NextBatch: tc.bert + "_invite",
})
}
t.Log("Ana syncs.")
_, respBytes, statusCode := v3.doV3Request(t, context.Background(), tc.anaToken, anaRes.Pos, sync3.Request{})
t.Log("Her long-polling session has been closed by the server.")
assertUnknownPos(t, respBytes, statusCode)
t.Log("Ana syncs again from scratch.")
anaRes = v3.mustDoV3Request(t, tc.anaToken, ssRequest)
t.Log("She sees both her and Bob's membership, and the timeline from the gappy poll.")
// Note: we don't expect to see the pre-gap timeline here, because we stop at
// the first gap we see in the timeline.
m.MatchResponse(t, anaRes, m.MatchRoomSubscription(tc.publicRoomID,
m.MatchRoomRequiredState([]json.RawMessage{anaMembership, newMembership}),
m.MatchRoomTimeline(publicTimeline),
))
return
}
for _, tc := range tcs {
t.Run(tc.id, func(t *testing.T) {
// 1--3: Register users, create public room, set Bert's membership.
publicEvents, anaMembership, anaRes := setup(t, tc)
defer func() {
// Cleanup these users once we're done with them. This helps stop log spam when debugging.
v2.invalidateTokenImmediately(tc.anaToken)
v2.invalidateTokenImmediately(tc.bertToken)
}()
// Ensure the proxy considers Bert to already be polling. In particular, if
// Bert is initially invited, make sure his poller sees the invite.
if tc.beforeMembership == "invite" {
t.Log("Bert's poller sees his invite.")
v2.queueResponse(tc.bertToken, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Invite: map[string]sync2.SyncV2InviteResponse{
tc.publicRoomID: {
InviteState: sync2.EventsResponse{
// TODO: this really ought to be stripped state events
Events: publicEvents,
},
},
}},
NextBatch: tc.bert + "_invite",
})
} else {
t.Log("Queue up an empty poller response for Bert.")
v2.queueResponse(tc.bertToken, sync2.SyncResponse{
NextBatch: tc.bert + "_empty_sync",
})
}
t.Log("Bert makes a dummy request with a different connection ID, to ensure his poller has started.")
v3.mustDoV3Request(t, tc.bertToken, sync3.Request{
ConnID: "bert-dummy-conn",
})
var bertRes *sync3.Response
// 4: sliding sync for Bert, if he will live-sync in (6) below.
if tc.viaLiveUpdate {
t.Log("Bert sliding syncs.")
bertRes = v3.mustDoV3Request(t, tc.bertToken, ssRequest)
// Bert will see the entire history of these rooms, so there shouldn't be any prev batch tokens.
expectedSubscriptions := map[string][]m.RoomMatcher{}
switch tc.beforeMembership {
case "invite":
t.Log("Bert sees his invite.")
expectedSubscriptions[tc.publicRoomID] = []m.RoomMatcher{
m.MatchRoomHasInviteState(),
m.MatchInviteCount(1),
m.MatchJoinCount(1),
m.MatchRoomPrevBatch(""),
}
case "join":
t.Log("Bert sees his join.")
expectedSubscriptions[tc.publicRoomID] = []m.RoomMatcher{
m.MatchRoomLacksInviteState(),
m.MatchInviteCount(0),
m.MatchJoinCount(2),
m.MatchRoomPrevBatch(""),
}
case "none":
fallthrough
case "leave":
fallthrough
case "ban":
t.Log("Bert does not see the room.")
default:
panic(fmt.Errorf("unknown beforeMembership %s", tc.beforeMembership))
}
m.MatchResponse(t, bertRes, m.MatchRoomSubscriptionsStrict(expectedSubscriptions))
}
// 5: Ana receives a gappy poll, plus a sentinel in her DM with Bert.
newMembership, publicTimeline := gappyPoll(t, tc, anaMembership, anaRes)
// 6: Bert sliding syncs.
if tc.viaLiveUpdate {
wasInvolvedInRoom := tc.beforeMembership == "join" || tc.beforeMembership == "invite"
if wasInvolvedInRoom {
t.Log("Bert makes an incremental sliding sync.")
_, respBytes, statusCode := v3.doV3Request(t, context.Background(), tc.bertToken, bertRes.Pos, ssRequest)
assertUnknownPos(t, respBytes, statusCode)
}
} else {
t.Log("Queue up an empty poller response for Bert. so the proxy will consider him to be polling.")
v2.queueResponse(tc.bertToken, sync2.SyncResponse{
NextBatch: tc.bert + "_empty_sync",
})
}
t.Log("Bert makes new sliding sync connection.")
bertRes = v3.mustDoV3Request(t, tc.bertToken, ssRequest)
// Work out what Bert should see.
respMatchers := []m.RespMatcher{}
switch tc.afterMembership {
case "invite":
t.Log("Bert should see his invite.")
respMatchers = append(respMatchers,
m.MatchList("a", m.MatchV3Count(1)),
m.MatchRoomSubscription(tc.publicRoomID,
m.MatchRoomHasInviteState(),
m.MatchInviteCount(1),
m.MatchJoinCount(1),
))
case "join":
t.Log("Bert should see himself joined to the room, and Alice's messages.")
respMatchers = append(respMatchers,
m.MatchList("a", m.MatchV3Count(1)),
m.MatchRoomSubscription(tc.publicRoomID,
m.MatchRoomLacksInviteState(),
m.MatchRoomRequiredState([]json.RawMessage{anaMembership, newMembership}),
m.MatchInviteCount(0),
m.MatchJoinCount(2),
m.MatchRoomTimelineMostRecent(len(publicTimeline), publicTimeline),
m.MatchRoomPrevBatch("anaPublicPrevBatch2"),
))
case "leave":
fallthrough
case "ban":
respMatchers = append(respMatchers, m.MatchList("a", m.MatchV3Count(0)))
// Any prior connection has been closed by the server, so Bert won't see
// a transition here.
t.Logf("Bob shouldn't see his %s (membership was: %s)", tc.afterMembership, tc.beforeMembership)
respMatchers = append(respMatchers, m.MatchRoomSubscriptionsStrict(nil))
default:
panic(fmt.Errorf("unknown afterMembership %s", tc.afterMembership))
}
m.MatchResponse(t, bertRes, respMatchers...)
// 7: Ana invites Bert to a DM. He accepts.
// This is a sentinel which proves the proxy has processed the gappy poll
// properly in the situations where there's nothing for Bert to see in his
// second sync, e.g. ban -> leave (an unban).
t.Log("Ana invites Bert to a DM. He accepts.")
bertDMJoin := testutils.NewJoinEvent(t, tc.bert)
dmTimeline := append(
createRoomState(t, tc.ana, time.Now()),
testutils.NewStateEvent(t, "m.room.member", tc.bert, tc.ana, map[string]any{"membership": "invite"}),
bertDMJoin,
)
v2.queueResponse(tc.anaToken, sync2.SyncResponse{
NextBatch: "ana3",
Rooms: sync2.SyncRoomsResponse{
Join: map[string]sync2.SyncV2JoinResponse{
tc.dmRoomID: {
Timeline: sync2.TimelineResponse{
Events: dmTimeline,
PrevBatch: "anaDM",
},
},
},
},
})
v2.waitUntilEmpty(t, tc.anaToken)
t.Log("Bert sliding syncs")
bertRes = v3.mustDoV3RequestWithPos(t, tc.bertToken, bertRes.Pos, ssRequest)
t.Log("Bert sees his join to the DM.")
m.MatchResponse(t, bertRes, m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
tc.dmRoomID: {m.MatchRoomLacksInviteState(), m.MatchRoomTimelineMostRecent(1, []json.RawMessage{bertDMJoin})},
}))
})
}
}
// This is a minimal version of the test above, which is helpful for debugging (because
// the above test is a monstrosity---apologies to the reader.)
func TestTimelineAfterRequestingStateAfterGappyPoll(t *testing.T) {
pqString := testutils.PrepareDBConnectionString()
v2 := runTestV2Server(t)
defer v2.close()
v3 := runTestServer(t, v2, pqString)
defer v3.close()
alice := "alice"
aliceToken := "alicetoken"
bob := "bob"
roomID := "!unimportant"
v2.addAccount(t, alice, aliceToken)
t.Log("alice creates a public room.")
timeline1 := createRoomState(t, alice, time.Now())
var aliceMembership json.RawMessage
for _, ev := range timeline1 {
parsed := gjson.ParseBytes(ev)
if parsed.Get("type").Str == "m.room.member" && parsed.Get("state_key").Str == alice {
aliceMembership = ev
break
}
}
if len(aliceMembership) == 0 {
t.Fatal("Initial timeline did not have a membership for Alice")
}
v2.queueResponse(aliceToken, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: map[string]sync2.SyncV2JoinResponse{
roomID: {
Timeline: sync2.TimelineResponse{
Events: timeline1,
PrevBatch: "alicePublicPrevBatch1",
},
},
},
},
NextBatch: "aliceSync1",
})
t.Log("alice sliding syncs, requesting all memberships in state.")
aliceReq := sync3.Request{
RoomSubscriptions: map[string]sync3.RoomSubscription{
roomID: {
TimelineLimit: 20,
RequiredState: [][2]string{{"m.room.member", "*"}},
},
},
}
aliceRes := v3.mustDoV3Request(t, aliceToken, aliceReq)
t.Log("She sees herself joined to her room, with an appropriate timeline.")
// Note: we only expect timeline1[1:] here, excluding the create event. See
// https://github.com/matrix-org/sliding-sync/issues/343
m.MatchResponse(t, aliceRes,
m.LogResponse(t),
m.MatchRoomSubscription(roomID,
m.MatchRoomRequiredState([]json.RawMessage{aliceMembership}),
m.MatchRoomTimeline(timeline1[1:])),
)
t.Logf("Alice's poller gets a gappy sync response for the public room. bob's membership is now join, and alice has sent 10 messages.")
timeline2 := make([]json.RawMessage, 10)
for i := range timeline2 {
timeline2[i] = testutils.NewMessageEvent(t, alice, fmt.Sprintf("hello %d", i))
}
bobMembership := testutils.NewJoinEvent(t, bob)
v2.queueResponse(aliceToken, sync2.SyncResponse{
NextBatch: "alice2",
Rooms: sync2.SyncRoomsResponse{
Join: map[string]sync2.SyncV2JoinResponse{
roomID: {
State: sync2.EventsResponse{
Events: []json.RawMessage{bobMembership},
},
Timeline: sync2.TimelineResponse{
Events: timeline2,
Limited: true,
PrevBatch: "alicePublicPrevBatch2",
},
},
},
},
})
v2.waitUntilEmpty(t, aliceToken)
t.Log("Alice does an incremental sliding sync.")
_, respBytes, statusCode := v3.doV3Request(t, context.Background(), aliceToken, aliceRes.Pos, sync3.Request{})
t.Log("Her long-polling session has been closed by the server.")
assertUnknownPos(t, respBytes, statusCode)
t.Log("Alice syncs again from scratch.")
aliceRes = v3.mustDoV3Request(t, aliceToken, aliceReq)
t.Log("She sees both her and Bob's membership, and the timeline from the gappy poll.")
// Note: we don't expect to see timeline1 here because we stop at the first gap we
// see in the timeline.
m.MatchResponse(t, aliceRes, m.MatchRoomSubscription(roomID,
m.MatchRoomRequiredState([]json.RawMessage{aliceMembership, bobMembership}),
m.MatchRoomTimeline(timeline2),
))
}
func assertUnknownPos(t *testing.T, respBytes []byte, statusCode int) {
if statusCode != http.StatusBadRequest {
t.Errorf("Got status %d, expected %d", statusCode, http.StatusBadRequest)
}
if errcode := gjson.GetBytes(respBytes, "errcode").Str; errcode != "M_UNKNOWN_POS" {
t.Errorf("Got errcode %s, expected %s", errcode, "M_UNKNOWN_POS")
}
}