Do not make snapshots for lone leave events

Specifically this is targetting invite rejections, where the leave
event is inside the leave block of the sync v2 response.

Previously, we would make a snapshot with this leave event. If the
proxy wasn't in this room, it would mean the room state would just
be the leave event, which is wrong. If the proxy was in the room,
then state would correctly be rolled forward.
This commit is contained in:
Kegan Dougal 2023-07-31 17:53:15 +01:00
parent 13673175b5
commit 6623ddb9e3
14 changed files with 108 additions and 68 deletions

View File

@ -70,8 +70,9 @@ type V2AccountData struct {
func (*V2AccountData) Type() string { return "V2AccountData" }
type V2LeaveRoom struct {
UserID string
RoomID string
UserID string
RoomID string
LeaveEvent json.RawMessage
}
func (*V2LeaveRoom) Type() string { return "V2LeaveRoom" }

View File

@ -293,7 +293,7 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia
// to exist in the database, and the sync stream is already linearised for us.
// - Else it creates a new room state snapshot if the timeline contains state events (as this now represents the current state)
// - It adds entries to the membership log for membership events.
func (a *Accumulator) Accumulate(txn *sqlx.Tx, roomID string, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
// The first stage of accumulating events is mostly around validation around what the upstream HS sends us. For accumulation to work correctly
// we expect:
// - there to be no duplicate events
@ -308,6 +308,33 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, roomID string, prevBatch string,
return 0, nil, err // nothing to do
}
// Given a timeline of [E1, E2, S3, E4, S5, S6, E7] (E=message event, S=state event)
// And a prior state snapshot of SNAP0 then the BEFORE snapshot IDs are grouped as:
// E1,E2,S3 => SNAP0
// E4, S5 => (SNAP0 + S3)
// S6 => (SNAP0 + S3 + S5)
// E7 => (SNAP0 + S3 + S5 + S6)
// We can track this by loading the current snapshot ID (after snapshot) then rolling forward
// the timeline until we hit a state event, at which point we make a new snapshot but critically
// do NOT assign the new state event in the snapshot so as to represent the state before the event.
snapID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID)
if err != nil {
return 0, nil, err
}
// if we have just got a leave event for the polling user, and there is no snapshot for this room already, then
// we do NOT want to add this event to the events table, nor do we want to make a room snapshot. This is because
// this leave event is an invite rejection, rather than a normal event. Invite rejections cannot be processed in
// a normal way because we lack room state (no create event, PLs, etc). If we were to process the invite rejection,
// the room state would just be a single event: this leave event, which is wrong.
if len(dedupedEvents) == 1 && dedupedEvents[0].Type == "m.room.member" && dedupedEvents[0].Membership == "leave" &&
dedupedEvents[0].StateKey == userID && snapID == 0 {
logger.Info().Str("event_id", dedupedEvents[0].ID).Str("room_id", roomID).Str("user_id", userID).Err(err).Msg(
"Accumulator: skipping processing of leave event, as no snapshot exists",
)
return 0, nil, nil
}
eventIDToNID, err := a.eventsTable.Insert(txn, dedupedEvents, false)
if err != nil {
return 0, nil, err
@ -339,19 +366,6 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, roomID string, prevBatch string,
}
}
// Given a timeline of [E1, E2, S3, E4, S5, S6, E7] (E=message event, S=state event)
// And a prior state snapshot of SNAP0 then the BEFORE snapshot IDs are grouped as:
// E1,E2,S3 => SNAP0
// E4, S5 => (SNAP0 + S3)
// S6 => (SNAP0 + S3 + S5)
// E7 => (SNAP0 + S3 + S5 + S6)
// We can track this by loading the current snapshot ID (after snapshot) then rolling forward
// the timeline until we hit a state event, at which point we make a new snapshot but critically
// do NOT assign the new state event in the snapshot so as to represent the state before the event.
snapID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID)
if err != nil {
return 0, nil, err
}
for _, ev := range newEvents {
var replacesNID int64
// the snapshot ID we assign to this event is unaffected by whether /this/ event is state or not,

View File

@ -14,6 +14,10 @@ import (
"github.com/tidwall/gjson"
)
var (
userID = "@me:localhost"
)
func TestAccumulatorInitialise(t *testing.T) {
roomID := "!TestAccumulatorInitialise:localhost"
roomEvents := []json.RawMessage{
@ -118,7 +122,7 @@ func TestAccumulatorAccumulate(t *testing.T) {
var numNew int
var latestNIDs []int64
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
numNew, latestNIDs, err = accumulator.Accumulate(txn, roomID, "", newEvents)
numNew, latestNIDs, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
return err
})
if err != nil {
@ -192,7 +196,7 @@ func TestAccumulatorAccumulate(t *testing.T) {
// subsequent calls do nothing and are not an error
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, roomID, "", newEvents)
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
return err
})
if err != nil {
@ -228,7 +232,7 @@ func TestAccumulatorMembershipLogs(t *testing.T) {
[]byte(`{"event_id":"` + roomEventIDs[7] + `", "type":"m.room.member", "state_key":"@me:localhost","unsigned":{"prev_content":{"membership":"join", "displayname":"Me"}}, "content":{"membership":"leave"}}`),
}
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, roomID, "", roomEvents)
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", roomEvents)
return err
})
if err != nil {
@ -355,7 +359,7 @@ func TestAccumulatorDupeEvents(t *testing.T) {
}
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, roomID, "", joinRoom.Timeline.Events)
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", joinRoom.Timeline.Events)
return err
})
if err != nil {
@ -555,7 +559,7 @@ func TestAccumulatorConcurrency(t *testing.T) {
defer wg.Done()
subset := newEvents[:(i + 1)] // i=0 => [1], i=1 => [1,2], etc
err := sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
numNew, _, err := accumulator.Accumulate(txn, roomID, "", subset)
numNew, _, err := accumulator.Accumulate(txn, userID, roomID, "", subset)
totalNumNew += numNew
return err
})

View File

@ -307,12 +307,12 @@ func (s *Storage) currentNotMembershipStateEventsInAllRooms(txn *sqlx.Tx, eventT
return result, nil
}
func (s *Storage) Accumulate(roomID, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
func (s *Storage) Accumulate(userID, roomID, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
if len(timeline) == 0 {
return 0, nil, nil
}
err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
numNew, timelineNIDs, err = s.Accumulator.Accumulate(txn, roomID, prevBatch, timeline)
numNew, timelineNIDs, err = s.Accumulator.Accumulate(txn, userID, roomID, prevBatch, timeline)
return err
})
return

View File

@ -31,7 +31,7 @@ func TestStorageRoomStateBeforeAndAfterEventPosition(t *testing.T) {
testutils.NewStateEvent(t, "m.room.join_rules", "", alice, map[string]interface{}{"join_rule": "invite"}),
testutils.NewStateEvent(t, "m.room.member", bob, alice, map[string]interface{}{"membership": "invite"}),
}
_, latestNIDs, err := store.Accumulate(roomID, "", events)
_, latestNIDs, err := store.Accumulate(userID, roomID, "", events)
if err != nil {
t.Fatalf("Accumulate returned error: %s", err)
}
@ -161,7 +161,7 @@ func TestStorageJoinedRoomsAfterPosition(t *testing.T) {
var latestNIDs []int64
var err error
for roomID, eventMap := range roomIDToEventMap {
_, latestNIDs, err = store.Accumulate(roomID, "", eventMap)
_, latestNIDs, err = store.Accumulate(userID, roomID, "", eventMap)
if err != nil {
t.Fatalf("Accumulate on %s failed: %s", roomID, err)
}
@ -351,7 +351,7 @@ func TestVisibleEventNIDsBetween(t *testing.T) {
},
}
for _, tl := range timelineInjections {
numNew, _, err := store.Accumulate(tl.RoomID, "", tl.Events)
numNew, _, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
if err != nil {
t.Fatalf("Accumulate on %s failed: %s", tl.RoomID, err)
}
@ -454,7 +454,7 @@ func TestVisibleEventNIDsBetween(t *testing.T) {
t.Fatalf("LatestEventNID: %s", err)
}
for _, tl := range timelineInjections {
numNew, _, err := store.Accumulate(tl.RoomID, "", tl.Events)
numNew, _, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
if err != nil {
t.Fatalf("Accumulate on %s failed: %s", tl.RoomID, err)
}
@ -534,7 +534,7 @@ func TestStorageLatestEventsInRoomsPrevBatch(t *testing.T) {
}
eventIDs := []string{}
for _, timeline := range timelines {
_, _, err = store.Accumulate(roomID, timeline.prevBatch, timeline.timeline)
_, _, err = store.Accumulate(userID, roomID, timeline.prevBatch, timeline.timeline)
if err != nil {
t.Fatalf("failed to accumulate: %s", err)
}
@ -776,7 +776,7 @@ func TestAllJoinedMembers(t *testing.T) {
}, serialise(tc.InitMemberships)...))
assertNoError(t, err)
_, _, err = store.Accumulate(roomID, "foo", serialise(tc.AccumulateMemberships))
_, _, err = store.Accumulate(userID, roomID, "foo", serialise(tc.AccumulateMemberships))
assertNoError(t, err)
testCases[i].RoomID = roomID // remember this for later
}

View File

@ -259,7 +259,7 @@ func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID, prev
}
// Insert new events
numNew, latestNIDs, err := h.Store.Accumulate(roomID, prevBatch, timeline)
numNew, latestNIDs, err := h.Store.Accumulate(userID, roomID, prevBatch, timeline)
if err != nil {
logger.Err(err).Int("timeline", len(timeline)).Str("room", roomID).Msg("V2: failed to accumulate room")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
@ -448,16 +448,18 @@ func (h *Handler) OnInvite(ctx context.Context, userID, roomID string, inviteSta
})
}
func (h *Handler) OnLeftRoom(ctx context.Context, userID, roomID string) {
func (h *Handler) OnLeftRoom(ctx context.Context, userID, roomID string, leaveEv json.RawMessage) {
// remove any invites for this user if they are rejecting an invite
err := h.Store.InvitesTable.RemoveInvite(userID, roomID)
if err != nil {
logger.Err(err).Str("user", userID).Str("room", roomID).Msg("failed to retire invite")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
}
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2LeaveRoom{
UserID: userID,
RoomID: roomID,
UserID: userID,
RoomID: roomID,
LeaveEvent: leaveEv,
})
}

View File

@ -51,7 +51,7 @@ type V2DataReceiver interface {
// Sent when there is a room in the `invite` section of the v2 response.
OnInvite(ctx context.Context, userID, roomID string, inviteState []json.RawMessage) // invitestate in db
// Sent when there is a room in the `leave` section of the v2 response.
OnLeftRoom(ctx context.Context, userID, roomID string)
OnLeftRoom(ctx context.Context, userID, roomID string, leaveEvent json.RawMessage)
// Sent when there is a _change_ in E2EE data, not all the time
OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int)
// Sent when the poll loop terminates
@ -301,11 +301,11 @@ func (h *PollerMap) OnInvite(ctx context.Context, userID, roomID string, inviteS
wg.Wait()
}
func (h *PollerMap) OnLeftRoom(ctx context.Context, userID, roomID string) {
func (h *PollerMap) OnLeftRoom(ctx context.Context, userID, roomID string, leaveEvent json.RawMessage) {
var wg sync.WaitGroup
wg.Add(1)
h.executor <- func() {
h.callbacks.OnLeftRoom(ctx, userID, roomID)
h.callbacks.OnLeftRoom(ctx, userID, roomID, leaveEvent)
wg.Done()
}
wg.Wait()
@ -716,12 +716,23 @@ func (p *poller) parseRoomsResponse(ctx context.Context, res *SyncResponse) {
}
}
for roomID, roomData := range res.Rooms.Leave {
// TODO: do we care about state?
if len(roomData.Timeline.Events) > 0 {
p.trackTimelineSize(len(roomData.Timeline.Events), roomData.Timeline.Limited)
p.receiver.Accumulate(ctx, p.userID, p.deviceID, roomID, roomData.Timeline.PrevBatch, roomData.Timeline.Events)
}
p.receiver.OnLeftRoom(ctx, p.userID, roomID)
// Pass the leave event directly to OnLeftRoom. We need to do this _in addition_ to calling Accumulate to handle
// the case where a user rejects an invite (there will be no room state, but the user still expects to see the leave event).
var leaveEvent json.RawMessage
for _, ev := range roomData.Timeline.Events {
leaveEv := gjson.ParseBytes(ev)
if leaveEv.Get("content.membership").Str == "leave" && leaveEv.Get("state_key").Str == p.userID {
leaveEvent = ev
break
}
}
if leaveEvent != nil {
p.receiver.OnLeftRoom(ctx, p.userID, roomID, leaveEvent)
}
}
for roomID, roomData := range res.Rooms.Invite {
p.receiver.OnInvite(ctx, p.userID, roomID, roomData.InviteState.Events)

View File

@ -617,7 +617,8 @@ func (s *mockDataReceiver) OnReceipt(ctx context.Context, userID, roomID, ephEve
}
func (s *mockDataReceiver) OnInvite(ctx context.Context, userID, roomID string, inviteState []json.RawMessage) {
}
func (s *mockDataReceiver) OnLeftRoom(ctx context.Context, userID, roomID string) {}
func (s *mockDataReceiver) OnLeftRoom(ctx context.Context, userID, roomID string, leaveEvent json.RawMessage) {
}
func (s *mockDataReceiver) OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int) {
}
func (s *mockDataReceiver) OnTerminated(ctx context.Context, userID, deviceID string) {}

View File

@ -38,12 +38,12 @@ func TestGlobalCacheLoadState(t *testing.T) {
testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]interface{}{"name": "The Room Name"}),
testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]interface{}{"name": "The Updated Room Name"}),
}
_, _, err := store.Accumulate(roomID2, "", eventsRoom2)
_, _, err := store.Accumulate(alice, roomID2, "", eventsRoom2)
if err != nil {
t.Fatalf("Accumulate: %s", err)
}
_, latestNIDs, err := store.Accumulate(roomID, "", events)
_, latestNIDs, err := store.Accumulate(alice, roomID, "", events)
if err != nil {
t.Fatalf("Accumulate: %s", err)
}

View File

@ -38,15 +38,6 @@ func (u *InviteUpdate) Type() string {
return fmt.Sprintf("InviteUpdate[%s]", u.RoomID())
}
// LeftRoomUpdate corresponds to a key-value pair from a v2 sync's `leave` section.
type LeftRoomUpdate struct {
RoomUpdate
}
func (u *LeftRoomUpdate) Type() string {
return fmt.Sprintf("LeftRoomUpdate[%s]", u.RoomID())
}
// TypingEdu corresponds to a typing EDU in the `ephemeral` section of a joined room's v2 sync resposne.
type TypingUpdate struct {
RoomUpdate

View File

@ -606,7 +606,7 @@ func (c *UserCache) OnInvite(ctx context.Context, roomID string, inviteStateEven
c.emitOnRoomUpdate(ctx, up)
}
func (c *UserCache) OnLeftRoom(ctx context.Context, roomID string) {
func (c *UserCache) OnLeftRoom(ctx context.Context, roomID string, leaveEvent json.RawMessage) {
urd := c.LoadRoomData(roomID)
urd.IsInvite = false
urd.HasLeft = true
@ -616,7 +616,10 @@ func (c *UserCache) OnLeftRoom(ctx context.Context, roomID string) {
c.roomToData[roomID] = urd
c.roomToDataMu.Unlock()
up := &LeftRoomUpdate{
ev := gjson.ParseBytes(leaveEvent)
stateKey := ev.Get("state_key").Str
up := &RoomEventUpdate{
RoomUpdate: &roomUpdateCache{
roomID: roomID,
// do NOT pull from the global cache as it is a snapshot of the room at the point of
@ -624,6 +627,18 @@ func (c *UserCache) OnLeftRoom(ctx context.Context, roomID string) {
globalRoomData: internal.NewRoomMetadata(roomID),
userRoomData: &urd,
},
EventData: &EventData{
Event: leaveEvent,
RoomID: roomID,
EventType: ev.Get("type").Str,
StateKey: &stateKey,
Content: ev.Get("content"),
Timestamp: ev.Get("origin_server_ts").Uint(),
Sender: ev.Get("sender").Str,
// if this is an invite rejection we need to make sure we tell the client, and not
// skip it because of the lack of a NID (this event may not be in the events table)
AlwaysProcess: true,
},
}
c.emitOnRoomUpdate(ctx, up)
}

View File

@ -181,11 +181,13 @@ func (s *connStateLive) processLiveUpdate(ctx context.Context, up caches.Update,
if roomEventUpdate != nil && roomEventUpdate.EventData.Event != nil {
r.NumLive++
advancedPastEvent := false
if roomEventUpdate.EventData.NID <= s.loadPositions[roomEventUpdate.RoomID()] {
// this update has been accounted for by the initial:true room snapshot
advancedPastEvent = true
if !roomEventUpdate.EventData.AlwaysProcess {
if roomEventUpdate.EventData.NID <= s.loadPositions[roomEventUpdate.RoomID()] {
// this update has been accounted for by the initial:true room snapshot
advancedPastEvent = true
}
s.loadPositions[roomEventUpdate.RoomID()] = roomEventUpdate.EventData.NID
}
s.loadPositions[roomEventUpdate.RoomID()] = roomEventUpdate.EventData.NID
// we only append to the timeline if we haven't already got this event. This can happen when:
// - 2 live events for a room mid-connection
// - next request bumps a room from outside to inside the window

View File

@ -317,7 +317,6 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
// setupConnection associates this request with an existing connection or makes a new connection.
// It also sets a v2 sync poll loop going if one didn't exist already for this user.
// When this function returns, the connection is alive and active.
func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Request, containsPos bool) (*sync3.Conn, *internal.HandlerError) {
taskCtx, task := internal.StartTask(req.Context(), "setupConnection")
defer task.End()
@ -697,7 +696,7 @@ func (h *SyncLiveHandler) OnLeftRoom(p *pubsub.V2LeaveRoom) {
if !ok {
return
}
userCache.(*caches.UserCache).OnLeftRoom(ctx, p.RoomID)
userCache.(*caches.UserCache).OnLeftRoom(ctx, p.RoomID, p.LeaveEvent)
}
func (h *SyncLiveHandler) OnReceipt(p *pubsub.V2Receipt) {

View File

@ -100,15 +100,16 @@ func TestSecurityLiveStreamEventLeftLeak(t *testing.T) {
})
// check Alice sees both events
assertEventsEqual(t, []Event{
{
Type: "m.room.member",
StateKey: ptr(eve.UserID),
Content: map[string]interface{}{
"membership": "leave",
},
Sender: alice.UserID,
kickEvent := Event{
Type: "m.room.member",
StateKey: ptr(eve.UserID),
Content: map[string]interface{}{
"membership": "leave",
},
Sender: alice.UserID,
}
assertEventsEqual(t, []Event{
kickEvent,
{
Type: "m.room.name",
StateKey: ptr(""),
@ -120,7 +121,6 @@ func TestSecurityLiveStreamEventLeftLeak(t *testing.T) {
},
}, timeline)
kickEvent := timeline[0]
// Ensure Eve doesn't see this message in the timeline, name calc or required_state
eveRes = eve.SlidingSync(t, sync3.Request{
Lists: map[string]sync3.RequestList{
@ -140,7 +140,7 @@ func TestSecurityLiveStreamEventLeftLeak(t *testing.T) {
}, WithPos(eveRes.Pos))
// the room is deleted from eve's point of view and she sees up to and including her kick event
m.MatchResponse(t, eveRes, m.MatchList("a", m.MatchV3Count(0), m.MatchV3Ops(m.MatchV3DeleteOp(0))), m.MatchRoomSubscription(
roomID, m.MatchRoomName(""), m.MatchRoomRequiredState(nil), m.MatchRoomTimelineMostRecent(1, []json.RawMessage{kickEvent}),
roomID, m.MatchRoomName(""), m.MatchRoomRequiredState(nil), MatchRoomTimelineMostRecent(1, []Event{kickEvent}),
))
}