Merge pull request #235 from matrix-org/kegan/leave-event-shouldnt-snapshot

Do not make snapshots for lone leave events
This commit is contained in:
kegsay 2023-08-02 04:53:40 -07:00 committed by GitHub
commit a61a3fdde2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 111 additions and 68 deletions

View File

@ -73,8 +73,9 @@ type V2AccountData struct {
func (*V2AccountData) Type() string { return "V2AccountData" }
type V2LeaveRoom struct {
UserID string
RoomID string
UserID string
RoomID string
LeaveEvent json.RawMessage
}
func (*V2LeaveRoom) Type() string { return "V2LeaveRoom" }

View File

@ -293,7 +293,7 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia
// to exist in the database, and the sync stream is already linearised for us.
// - Else it creates a new room state snapshot if the timeline contains state events (as this now represents the current state)
// - It adds entries to the membership log for membership events.
func (a *Accumulator) Accumulate(txn *sqlx.Tx, roomID string, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
// The first stage of accumulating events is mostly about validating what the upstream HS sends us. For accumulation to work correctly
// we expect:
// - there to be no duplicate events
@ -308,6 +308,36 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, roomID string, prevBatch string,
return 0, nil, err // nothing to do
}
// Given a timeline of [E1, E2, S3, E4, S5, S6, E7] (E=message event, S=state event)
// And a prior state snapshot of SNAP0 then the BEFORE snapshot IDs are grouped as:
// E1,E2,S3 => SNAP0
// E4, S5 => (SNAP0 + S3)
// S6 => (SNAP0 + S3 + S5)
// E7 => (SNAP0 + S3 + S5 + S6)
// We can track this by loading the current snapshot ID (after snapshot) then rolling forward
// the timeline until we hit a state event, at which point we make a new snapshot but critically
// do NOT assign the new state event in the snapshot so as to represent the state before the event.
snapID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID)
if err != nil {
return 0, nil, err
}
// if we have just got a leave event for the polling user, and there is no snapshot for this room already, then
// we do NOT want to add this event to the events table, nor do we want to make a room snapshot. This is because
// this leave event is an invite rejection, rather than a normal event. Invite rejections cannot be processed in
// a normal way because we lack room state (no create event, PLs, etc). If we were to process the invite rejection,
// the room state would just be a single event: this leave event, which is wrong.
if len(dedupedEvents) == 1 &&
dedupedEvents[0].Type == "m.room.member" &&
(dedupedEvents[0].Membership == "leave" || dedupedEvents[0].Membership == "_leave") &&
dedupedEvents[0].StateKey == userID &&
snapID == 0 {
logger.Info().Str("event_id", dedupedEvents[0].ID).Str("room_id", roomID).Str("user_id", userID).Err(err).Msg(
"Accumulator: skipping processing of leave event, as no snapshot exists",
)
return 0, nil, nil
}
eventIDToNID, err := a.eventsTable.Insert(txn, dedupedEvents, false)
if err != nil {
return 0, nil, err
@ -339,19 +369,6 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, roomID string, prevBatch string,
}
}
// Given a timeline of [E1, E2, S3, E4, S5, S6, E7] (E=message event, S=state event)
// And a prior state snapshot of SNAP0 then the BEFORE snapshot IDs are grouped as:
// E1,E2,S3 => SNAP0
// E4, S5 => (SNAP0 + S3)
// S6 => (SNAP0 + S3 + S5)
// E7 => (SNAP0 + S3 + S5 + S6)
// We can track this by loading the current snapshot ID (after snapshot) then rolling forward
// the timeline until we hit a state event, at which point we make a new snapshot but critically
// do NOT assign the new state event in the snapshot so as to represent the state before the event.
snapID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID)
if err != nil {
return 0, nil, err
}
for _, ev := range newEvents {
var replacesNID int64
// the snapshot ID we assign to this event is unaffected by whether /this/ event is state or not,

View File

@ -14,6 +14,10 @@ import (
"github.com/tidwall/gjson"
)
var (
userID = "@me:localhost"
)
func TestAccumulatorInitialise(t *testing.T) {
roomID := "!TestAccumulatorInitialise:localhost"
roomEvents := []json.RawMessage{
@ -118,7 +122,7 @@ func TestAccumulatorAccumulate(t *testing.T) {
var numNew int
var latestNIDs []int64
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
numNew, latestNIDs, err = accumulator.Accumulate(txn, roomID, "", newEvents)
numNew, latestNIDs, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
return err
})
if err != nil {
@ -192,7 +196,7 @@ func TestAccumulatorAccumulate(t *testing.T) {
// subsequent calls do nothing and are not an error
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, roomID, "", newEvents)
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
return err
})
if err != nil {
@ -228,7 +232,7 @@ func TestAccumulatorMembershipLogs(t *testing.T) {
[]byte(`{"event_id":"` + roomEventIDs[7] + `", "type":"m.room.member", "state_key":"@me:localhost","unsigned":{"prev_content":{"membership":"join", "displayname":"Me"}}, "content":{"membership":"leave"}}`),
}
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, roomID, "", roomEvents)
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", roomEvents)
return err
})
if err != nil {
@ -355,7 +359,7 @@ func TestAccumulatorDupeEvents(t *testing.T) {
}
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, roomID, "", joinRoom.Timeline.Events)
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", joinRoom.Timeline.Events)
return err
})
if err != nil {
@ -555,7 +559,7 @@ func TestAccumulatorConcurrency(t *testing.T) {
defer wg.Done()
subset := newEvents[:(i + 1)] // i=0 => [1], i=1 => [1,2], etc
err := sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
numNew, _, err := accumulator.Accumulate(txn, roomID, "", subset)
numNew, _, err := accumulator.Accumulate(txn, userID, roomID, "", subset)
totalNumNew += numNew
return err
})

View File

@ -307,12 +307,12 @@ func (s *Storage) currentNotMembershipStateEventsInAllRooms(txn *sqlx.Tx, eventT
return result, nil
}
func (s *Storage) Accumulate(roomID, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
func (s *Storage) Accumulate(userID, roomID, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
if len(timeline) == 0 {
return 0, nil, nil
}
err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
numNew, timelineNIDs, err = s.Accumulator.Accumulate(txn, roomID, prevBatch, timeline)
numNew, timelineNIDs, err = s.Accumulator.Accumulate(txn, userID, roomID, prevBatch, timeline)
return err
})
return

View File

@ -31,7 +31,7 @@ func TestStorageRoomStateBeforeAndAfterEventPosition(t *testing.T) {
testutils.NewStateEvent(t, "m.room.join_rules", "", alice, map[string]interface{}{"join_rule": "invite"}),
testutils.NewStateEvent(t, "m.room.member", bob, alice, map[string]interface{}{"membership": "invite"}),
}
_, latestNIDs, err := store.Accumulate(roomID, "", events)
_, latestNIDs, err := store.Accumulate(userID, roomID, "", events)
if err != nil {
t.Fatalf("Accumulate returned error: %s", err)
}
@ -161,7 +161,7 @@ func TestStorageJoinedRoomsAfterPosition(t *testing.T) {
var latestNIDs []int64
var err error
for roomID, eventMap := range roomIDToEventMap {
_, latestNIDs, err = store.Accumulate(roomID, "", eventMap)
_, latestNIDs, err = store.Accumulate(userID, roomID, "", eventMap)
if err != nil {
t.Fatalf("Accumulate on %s failed: %s", roomID, err)
}
@ -351,7 +351,7 @@ func TestVisibleEventNIDsBetween(t *testing.T) {
},
}
for _, tl := range timelineInjections {
numNew, _, err := store.Accumulate(tl.RoomID, "", tl.Events)
numNew, _, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
if err != nil {
t.Fatalf("Accumulate on %s failed: %s", tl.RoomID, err)
}
@ -454,7 +454,7 @@ func TestVisibleEventNIDsBetween(t *testing.T) {
t.Fatalf("LatestEventNID: %s", err)
}
for _, tl := range timelineInjections {
numNew, _, err := store.Accumulate(tl.RoomID, "", tl.Events)
numNew, _, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
if err != nil {
t.Fatalf("Accumulate on %s failed: %s", tl.RoomID, err)
}
@ -534,7 +534,7 @@ func TestStorageLatestEventsInRoomsPrevBatch(t *testing.T) {
}
eventIDs := []string{}
for _, timeline := range timelines {
_, _, err = store.Accumulate(roomID, timeline.prevBatch, timeline.timeline)
_, _, err = store.Accumulate(userID, roomID, timeline.prevBatch, timeline.timeline)
if err != nil {
t.Fatalf("failed to accumulate: %s", err)
}
@ -776,7 +776,7 @@ func TestAllJoinedMembers(t *testing.T) {
}, serialise(tc.InitMemberships)...))
assertNoError(t, err)
_, _, err = store.Accumulate(roomID, "foo", serialise(tc.AccumulateMemberships))
_, _, err = store.Accumulate(userID, roomID, "foo", serialise(tc.AccumulateMemberships))
assertNoError(t, err)
testCases[i].RoomID = roomID // remember this for later
}

View File

@ -269,7 +269,7 @@ func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID, prev
}
// Insert new events
numNew, latestNIDs, err := h.Store.Accumulate(roomID, prevBatch, timeline)
numNew, latestNIDs, err := h.Store.Accumulate(userID, roomID, prevBatch, timeline)
if err != nil {
logger.Err(err).Int("timeline", len(timeline)).Str("room", roomID).Msg("V2: failed to accumulate room")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
@ -465,16 +465,18 @@ func (h *Handler) OnInvite(ctx context.Context, userID, roomID string, inviteSta
})
}
func (h *Handler) OnLeftRoom(ctx context.Context, userID, roomID string) {
func (h *Handler) OnLeftRoom(ctx context.Context, userID, roomID string, leaveEv json.RawMessage) {
// remove any invites for this user if they are rejecting an invite
err := h.Store.InvitesTable.RemoveInvite(userID, roomID)
if err != nil {
logger.Err(err).Str("user", userID).Str("room", roomID).Msg("failed to retire invite")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
}
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2LeaveRoom{
UserID: userID,
RoomID: roomID,
UserID: userID,
RoomID: roomID,
LeaveEvent: leaveEv,
})
}

View File

@ -51,7 +51,7 @@ type V2DataReceiver interface {
// Sent when there is a room in the `invite` section of the v2 response.
OnInvite(ctx context.Context, userID, roomID string, inviteState []json.RawMessage) // invitestate in db
// Sent when there is a room in the `leave` section of the v2 response.
OnLeftRoom(ctx context.Context, userID, roomID string)
OnLeftRoom(ctx context.Context, userID, roomID string, leaveEvent json.RawMessage)
// Sent when there is a _change_ in E2EE data, not all the time
OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int)
// Sent when the poll loop terminates
@ -316,11 +316,11 @@ func (h *PollerMap) OnInvite(ctx context.Context, userID, roomID string, inviteS
wg.Wait()
}
func (h *PollerMap) OnLeftRoom(ctx context.Context, userID, roomID string) {
func (h *PollerMap) OnLeftRoom(ctx context.Context, userID, roomID string, leaveEvent json.RawMessage) {
var wg sync.WaitGroup
wg.Add(1)
h.executor <- func() {
h.callbacks.OnLeftRoom(ctx, userID, roomID)
h.callbacks.OnLeftRoom(ctx, userID, roomID, leaveEvent)
wg.Done()
}
wg.Wait()
@ -731,12 +731,23 @@ func (p *poller) parseRoomsResponse(ctx context.Context, res *SyncResponse) {
}
}
for roomID, roomData := range res.Rooms.Leave {
// TODO: do we care about state?
if len(roomData.Timeline.Events) > 0 {
p.trackTimelineSize(len(roomData.Timeline.Events), roomData.Timeline.Limited)
p.receiver.Accumulate(ctx, p.userID, p.deviceID, roomID, roomData.Timeline.PrevBatch, roomData.Timeline.Events)
}
p.receiver.OnLeftRoom(ctx, p.userID, roomID)
// Pass the leave event directly to OnLeftRoom. We need to do this _in addition_ to calling Accumulate to handle
// the case where a user rejects an invite (there will be no room state, but the user still expects to see the leave event).
var leaveEvent json.RawMessage
for _, ev := range roomData.Timeline.Events {
leaveEv := gjson.ParseBytes(ev)
if leaveEv.Get("content.membership").Str == "leave" && leaveEv.Get("state_key").Str == p.userID {
leaveEvent = ev
break
}
}
if leaveEvent != nil {
p.receiver.OnLeftRoom(ctx, p.userID, roomID, leaveEvent)
}
}
for roomID, roomData := range res.Rooms.Invite {
p.receiver.OnInvite(ctx, p.userID, roomID, roomData.InviteState.Events)

View File

@ -617,7 +617,8 @@ func (s *mockDataReceiver) OnReceipt(ctx context.Context, userID, roomID, ephEve
}
func (s *mockDataReceiver) OnInvite(ctx context.Context, userID, roomID string, inviteState []json.RawMessage) {
}
func (s *mockDataReceiver) OnLeftRoom(ctx context.Context, userID, roomID string) {}
func (s *mockDataReceiver) OnLeftRoom(ctx context.Context, userID, roomID string, leaveEvent json.RawMessage) {
}
func (s *mockDataReceiver) OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int) {
}
func (s *mockDataReceiver) OnTerminated(ctx context.Context, userID, deviceID string) {}

View File

@ -38,12 +38,12 @@ func TestGlobalCacheLoadState(t *testing.T) {
testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]interface{}{"name": "The Room Name"}),
testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]interface{}{"name": "The Updated Room Name"}),
}
_, _, err := store.Accumulate(roomID2, "", eventsRoom2)
_, _, err := store.Accumulate(alice, roomID2, "", eventsRoom2)
if err != nil {
t.Fatalf("Accumulate: %s", err)
}
_, latestNIDs, err := store.Accumulate(roomID, "", events)
_, latestNIDs, err := store.Accumulate(alice, roomID, "", events)
if err != nil {
t.Fatalf("Accumulate: %s", err)
}

View File

@ -38,15 +38,6 @@ func (u *InviteUpdate) Type() string {
return fmt.Sprintf("InviteUpdate[%s]", u.RoomID())
}
// LeftRoomUpdate corresponds to a key-value pair from a v2 sync's `leave` section.
type LeftRoomUpdate struct {
RoomUpdate
}
func (u *LeftRoomUpdate) Type() string {
return fmt.Sprintf("LeftRoomUpdate[%s]", u.RoomID())
}
// TypingUpdate corresponds to a typing EDU in the `ephemeral` section of a joined room's v2 sync response.
type TypingUpdate struct {
RoomUpdate

View File

@ -606,7 +606,7 @@ func (c *UserCache) OnInvite(ctx context.Context, roomID string, inviteStateEven
c.emitOnRoomUpdate(ctx, up)
}
func (c *UserCache) OnLeftRoom(ctx context.Context, roomID string) {
func (c *UserCache) OnLeftRoom(ctx context.Context, roomID string, leaveEvent json.RawMessage) {
urd := c.LoadRoomData(roomID)
urd.IsInvite = false
urd.HasLeft = true
@ -616,7 +616,10 @@ func (c *UserCache) OnLeftRoom(ctx context.Context, roomID string) {
c.roomToData[roomID] = urd
c.roomToDataMu.Unlock()
up := &LeftRoomUpdate{
ev := gjson.ParseBytes(leaveEvent)
stateKey := ev.Get("state_key").Str
up := &RoomEventUpdate{
RoomUpdate: &roomUpdateCache{
roomID: roomID,
// do NOT pull from the global cache as it is a snapshot of the room at the point of
@ -624,6 +627,18 @@ func (c *UserCache) OnLeftRoom(ctx context.Context, roomID string) {
globalRoomData: internal.NewRoomMetadata(roomID),
userRoomData: &urd,
},
EventData: &EventData{
Event: leaveEvent,
RoomID: roomID,
EventType: ev.Get("type").Str,
StateKey: &stateKey,
Content: ev.Get("content"),
Timestamp: ev.Get("origin_server_ts").Uint(),
Sender: ev.Get("sender").Str,
// if this is an invite rejection we need to make sure we tell the client, and not
// skip it because of the lack of a NID (this event may not be in the events table)
AlwaysProcess: true,
},
}
c.emitOnRoomUpdate(ctx, up)
}

View File

@ -181,11 +181,13 @@ func (s *connStateLive) processLiveUpdate(ctx context.Context, up caches.Update,
if roomEventUpdate != nil && roomEventUpdate.EventData.Event != nil {
r.NumLive++
advancedPastEvent := false
if roomEventUpdate.EventData.NID <= s.loadPositions[roomEventUpdate.RoomID()] {
// this update has been accounted for by the initial:true room snapshot
advancedPastEvent = true
if !roomEventUpdate.EventData.AlwaysProcess {
if roomEventUpdate.EventData.NID <= s.loadPositions[roomEventUpdate.RoomID()] {
// this update has been accounted for by the initial:true room snapshot
advancedPastEvent = true
}
s.loadPositions[roomEventUpdate.RoomID()] = roomEventUpdate.EventData.NID
}
s.loadPositions[roomEventUpdate.RoomID()] = roomEventUpdate.EventData.NID
// we only append to the timeline if we haven't already got this event. This can happen when:
// - 2 live events for a room mid-connection
// - next request bumps a room from outside to inside the window

View File

@ -320,7 +320,6 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
// setupConnection associates this request with an existing connection or makes a new connection.
// It also sets a v2 sync poll loop going if one didn't exist already for this user.
// When this function returns, the connection is alive and active.
func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Request, containsPos bool) (*sync3.Conn, *internal.HandlerError) {
taskCtx, task := internal.StartTask(req.Context(), "setupConnection")
defer task.End()
@ -704,7 +703,7 @@ func (h *SyncLiveHandler) OnLeftRoom(p *pubsub.V2LeaveRoom) {
if !ok {
return
}
userCache.(*caches.UserCache).OnLeftRoom(ctx, p.RoomID)
userCache.(*caches.UserCache).OnLeftRoom(ctx, p.RoomID, p.LeaveEvent)
}
func (h *SyncLiveHandler) OnReceipt(p *pubsub.V2Receipt) {

View File

@ -100,15 +100,16 @@ func TestSecurityLiveStreamEventLeftLeak(t *testing.T) {
})
// check Alice sees both events
assertEventsEqual(t, []Event{
{
Type: "m.room.member",
StateKey: ptr(eve.UserID),
Content: map[string]interface{}{
"membership": "leave",
},
Sender: alice.UserID,
kickEvent := Event{
Type: "m.room.member",
StateKey: ptr(eve.UserID),
Content: map[string]interface{}{
"membership": "leave",
},
Sender: alice.UserID,
}
assertEventsEqual(t, []Event{
kickEvent,
{
Type: "m.room.name",
StateKey: ptr(""),
@ -120,7 +121,6 @@ func TestSecurityLiveStreamEventLeftLeak(t *testing.T) {
},
}, timeline)
kickEvent := timeline[0]
// Ensure Eve doesn't see this message in the timeline, name calc or required_state
eveRes = eve.SlidingSync(t, sync3.Request{
Lists: map[string]sync3.RequestList{
@ -140,7 +140,7 @@ func TestSecurityLiveStreamEventLeftLeak(t *testing.T) {
}, WithPos(eveRes.Pos))
// the room is deleted from eve's point of view and she sees up to and including her kick event
m.MatchResponse(t, eveRes, m.MatchList("a", m.MatchV3Count(0), m.MatchV3Ops(m.MatchV3DeleteOp(0))), m.MatchRoomSubscription(
roomID, m.MatchRoomName(""), m.MatchRoomRequiredState(nil), m.MatchRoomTimelineMostRecent(1, []json.RawMessage{kickEvent}),
roomID, m.MatchRoomName(""), m.MatchRoomRequiredState(nil), MatchRoomTimelineMostRecent(1, []Event{kickEvent}),
))
}