Factor out AccumulateResult struct

This commit is contained in:
David Robertson 2023-09-07 19:21:44 +01:00
parent be78e6f6e4
commit 777cb357fe
No known key found for this signature in database
GPG Key ID: 903ECE108A39DEDD
6 changed files with 58 additions and 51 deletions

View File

@ -317,6 +317,12 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia
return res, err
}
type AccumulateResult struct {
// TODO: is this redundant---identical to len(TimelineNIDs)?
NumNew int
TimelineNIDs []int64
}
// Accumulate internal state from a user's sync response. The timeline order MUST be in the order
// received from the server. Returns the number of new events in the timeline, the new timeline event NIDs
// or an error.
@ -328,7 +334,7 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia
// to exist in the database, and the sync stream is already linearised for us.
// - Else it creates a new room state snapshot if the timeline contains state events (as this now represents the current state)
// - It adds entries to the membership log for membership events.
func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch string, timeline []json.RawMessage) (AccumulateResult, error) {
// The first stage of accumulating events is mostly around validation around what the upstream HS sends us. For accumulation to work correctly
// we expect:
// - there to be no duplicate events
@ -337,10 +343,10 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
dedupedEvents, err := a.filterAndParseTimelineEvents(txn, roomID, timeline, prevBatch)
if err != nil {
err = fmt.Errorf("filterTimelineEvents: %w", err)
return
return AccumulateResult{}, err
}
if len(dedupedEvents) == 0 {
return 0, nil, err // nothing to do
return AccumulateResult{}, nil // nothing to do
}
// Given a timeline of [E1, E2, S3, E4, S5, S6, E7] (E=message event, S=state event)
@ -354,7 +360,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
// do NOT assign the new state event in the snapshot so as to represent the state before the event.
snapID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID)
if err != nil {
return 0, nil, err
return AccumulateResult{}, err
}
// The only situation where no prior snapshot should exist is if this timeline is
@ -389,19 +395,22 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
})
sentry.CaptureMessage(msg)
})
return 0, nil, nil
return AccumulateResult{}, nil
}
}
eventIDToNID, err := a.eventsTable.Insert(txn, dedupedEvents, false)
if err != nil {
return 0, nil, err
return AccumulateResult{}, err
}
if len(eventIDToNID) == 0 {
// nothing to do, we already know about these events
return 0, nil, nil
return AccumulateResult{}, nil
}
result := AccumulateResult{
NumNew: len(eventIDToNID),
}
numNew = len(eventIDToNID)
var latestNID int64
newEvents := make([]Event, 0, len(eventIDToNID))
@ -433,7 +442,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
}
}
newEvents = append(newEvents, ev)
timelineNIDs = append(timelineNIDs, ev.NID)
result.TimelineNIDs = append(result.TimelineNIDs, ev.NID)
}
}
@ -443,7 +452,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
if len(redactTheseEventIDs) > 0 {
createEventJSON, err := a.eventsTable.SelectCreateEvent(txn, roomID)
if err != nil {
return 0, nil, fmt.Errorf("SelectCreateEvent: %w", err)
return AccumulateResult{}, fmt.Errorf("SelectCreateEvent: %w", err)
}
roomVersion = gjson.GetBytes(createEventJSON, "content.room_version").Str
if roomVersion == "" {
@ -454,7 +463,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
)
}
if err = a.eventsTable.Redact(txn, roomVersion, redactTheseEventIDs); err != nil {
return 0, nil, err
return AccumulateResult{}, err
}
}
@ -470,12 +479,12 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
if snapID != 0 {
oldStripped, err = a.strippedEventsForSnapshot(txn, snapID)
if err != nil {
return 0, nil, fmt.Errorf("failed to load stripped state events for snapshot %d: %s", snapID, err)
return AccumulateResult{}, fmt.Errorf("failed to load stripped state events for snapshot %d: %s", snapID, err)
}
}
newStripped, replacedNID, err := a.calculateNewSnapshot(oldStripped, ev)
if err != nil {
return 0, nil, fmt.Errorf("failed to calculateNewSnapshot: %s", err)
return AccumulateResult{}, fmt.Errorf("failed to calculateNewSnapshot: %s", err)
}
replacesNID = replacedNID
memNIDs, otherNIDs := newStripped.NIDs()
@ -485,7 +494,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
OtherEvents: otherNIDs,
}
if err = a.snapshotTable.Insert(txn, newSnapshot); err != nil {
return 0, nil, fmt.Errorf("failed to insert new snapshot: %w", err)
return AccumulateResult{}, fmt.Errorf("failed to insert new snapshot: %w", err)
}
if a.snapshotMemberCountVec != nil {
logger.Trace().Str("room_id", roomID).Int("members", len(memNIDs)).Msg("Inserted new snapshot")
@ -494,20 +503,20 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
snapID = newSnapshot.SnapshotID
}
if err := a.eventsTable.UpdateBeforeSnapshotID(txn, ev.NID, beforeSnapID, replacesNID); err != nil {
return 0, nil, err
return AccumulateResult{}, err
}
}
if err = a.spacesTable.HandleSpaceUpdates(txn, newEvents); err != nil {
return 0, nil, fmt.Errorf("HandleSpaceUpdates: %s", err)
return AccumulateResult{}, fmt.Errorf("HandleSpaceUpdates: %s", err)
}
// the last fetched snapshot ID is the current one
info := a.roomInfoDelta(roomID, newEvents)
if err = a.roomsTable.Upsert(txn, info, snapID, latestNID); err != nil {
return 0, nil, fmt.Errorf("failed to UpdateCurrentSnapshotID to %d: %w", snapID, err)
return AccumulateResult{}, fmt.Errorf("failed to UpdateCurrentSnapshotID to %d: %w", snapID, err)
}
return numNew, timelineNIDs, nil
return result, nil
}
// filterAndParseTimelineEvents takes a raw timeline array from sync v2 and applies sanity to it:

View File

@ -135,25 +135,24 @@ func TestAccumulatorAccumulate(t *testing.T) {
// new state event should be added to the snapshot
[]byte(`{"event_id":"I", "type":"m.room.history_visibility", "state_key":"", "content":{"visibility":"public"}}`),
}
var numNew int
var latestNIDs []int64
var result AccumulateResult
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
numNew, latestNIDs, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
result, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
return err
})
if err != nil {
t.Fatalf("failed to Accumulate: %s", err)
}
if numNew != len(newEvents) {
t.Fatalf("got %d new events, want %d", numNew, len(newEvents))
if result.NumNew != len(newEvents) {
t.Fatalf("got %d new events, want %d", result.NumNew, len(newEvents))
}
// latest nid shoould match
wantLatestNID, err := accumulator.eventsTable.SelectHighestNID()
if err != nil {
t.Fatalf("failed to check latest NID from Accumulate: %s", err)
}
if latestNIDs[len(latestNIDs)-1] != wantLatestNID {
t.Errorf("Accumulator.Accumulate returned latest nid %d, want %d", latestNIDs[len(latestNIDs)-1], wantLatestNID)
if result.TimelineNIDs[len(result.TimelineNIDs)-1] != wantLatestNID {
t.Errorf("Accumulator.Accumulate returned latest nid %d, want %d", result.TimelineNIDs[len(result.TimelineNIDs)-1], wantLatestNID)
}
// Begin assertions
@ -212,7 +211,7 @@ func TestAccumulatorAccumulate(t *testing.T) {
// subsequent calls do nothing and are not an error
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
_, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
return err
})
if err != nil {
@ -248,7 +247,7 @@ func TestAccumulatorMembershipLogs(t *testing.T) {
[]byte(`{"event_id":"` + roomEventIDs[7] + `", "type":"m.room.member", "state_key":"@me:localhost","unsigned":{"prev_content":{"membership":"join", "displayname":"Me"}}, "content":{"membership":"leave"}}`),
}
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", roomEvents)
_, err = accumulator.Accumulate(txn, userID, roomID, "", roomEvents)
return err
})
if err != nil {
@ -384,7 +383,7 @@ func TestAccumulatorDupeEvents(t *testing.T) {
}
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", joinRoom.Timeline.Events)
_, err = accumulator.Accumulate(txn, userID, roomID, "", joinRoom.Timeline.Events)
return err
})
if err != nil {
@ -584,8 +583,8 @@ func TestAccumulatorConcurrency(t *testing.T) {
defer wg.Done()
subset := newEvents[:(i + 1)] // i=0 => [1], i=1 => [1,2], etc
err := sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
numNew, _, err := accumulator.Accumulate(txn, userID, roomID, "", subset)
totalNumNew += numNew
result, err := accumulator.Accumulate(txn, userID, roomID, "", subset)
totalNumNew += result.NumNew
return err
})
if err != nil {

View File

@ -336,15 +336,15 @@ func (s *Storage) currentNotMembershipStateEventsInAllRooms(txn *sqlx.Tx, eventT
return result, nil
}
func (s *Storage) Accumulate(userID, roomID, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
func (s *Storage) Accumulate(userID, roomID, prevBatch string, timeline []json.RawMessage) (result AccumulateResult, err error) {
if len(timeline) == 0 {
return 0, nil, nil
return AccumulateResult{}, nil
}
err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
numNew, timelineNIDs, err = s.Accumulator.Accumulate(txn, userID, roomID, prevBatch, timeline)
result, err = s.Accumulator.Accumulate(txn, userID, roomID, prevBatch, timeline)
return err
})
return
return result, err
}
func (s *Storage) Initialise(roomID string, state []json.RawMessage) (InitialiseResult, error) {

View File

@ -31,11 +31,11 @@ func TestStorageRoomStateBeforeAndAfterEventPosition(t *testing.T) {
testutils.NewStateEvent(t, "m.room.join_rules", "", alice, map[string]interface{}{"join_rule": "invite"}),
testutils.NewStateEvent(t, "m.room.member", bob, alice, map[string]interface{}{"membership": "invite"}),
}
_, latestNIDs, err := store.Accumulate(userID, roomID, "", events)
accResult, err := store.Accumulate(userID, roomID, "", events)
if err != nil {
t.Fatalf("Accumulate returned error: %s", err)
}
latest := latestNIDs[len(latestNIDs)-1]
latest := accResult.TimelineNIDs[len(accResult.TimelineNIDs)-1]
testCases := []struct {
name string
@ -158,14 +158,13 @@ func TestStorageJoinedRoomsAfterPosition(t *testing.T) {
},
}
var latestPos int64
var latestNIDs []int64
var err error
for roomID, eventMap := range roomIDToEventMap {
_, latestNIDs, err = store.Accumulate(userID, roomID, "", eventMap)
accResult, err := store.Accumulate(userID, roomID, "", eventMap)
if err != nil {
t.Fatalf("Accumulate on %s failed: %s", roomID, err)
}
latestPos = latestNIDs[len(latestNIDs)-1]
latestPos = accResult.TimelineNIDs[len(accResult.TimelineNIDs)-1]
}
aliceJoinTimingsByRoomID, err := store.JoinedRoomsAfterPosition(alice, latestPos)
if err != nil {
@ -351,11 +350,11 @@ func TestVisibleEventNIDsBetween(t *testing.T) {
},
}
for _, tl := range timelineInjections {
numNew, _, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
accResult, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
if err != nil {
t.Fatalf("Accumulate on %s failed: %s", tl.RoomID, err)
}
t.Logf("%s added %d new events", tl.RoomID, numNew)
t.Logf("%s added %d new events", tl.RoomID, accResult.NumNew)
}
latestPos, err := store.LatestEventNID()
if err != nil {
@ -454,11 +453,11 @@ func TestVisibleEventNIDsBetween(t *testing.T) {
t.Fatalf("LatestEventNID: %s", err)
}
for _, tl := range timelineInjections {
numNew, _, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
accResult, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
if err != nil {
t.Fatalf("Accumulate on %s failed: %s", tl.RoomID, err)
}
t.Logf("%s added %d new events", tl.RoomID, numNew)
t.Logf("%s added %d new events", tl.RoomID, accResult.NumNew)
}
latestPos, err = store.LatestEventNID()
if err != nil {
@ -534,7 +533,7 @@ func TestStorageLatestEventsInRoomsPrevBatch(t *testing.T) {
}
eventIDs := []string{}
for _, timeline := range timelines {
_, _, err = store.Accumulate(userID, roomID, timeline.prevBatch, timeline.timeline)
_, err := store.Accumulate(userID, roomID, timeline.prevBatch, timeline.timeline)
if err != nil {
t.Fatalf("failed to accumulate: %s", err)
}
@ -776,7 +775,7 @@ func TestAllJoinedMembers(t *testing.T) {
}, serialise(tc.InitMemberships)...))
assertNoError(t, err)
_, _, err = store.Accumulate(userID, roomID, "foo", serialise(tc.AccumulateMemberships))
_, err = store.Accumulate(userID, roomID, "foo", serialise(tc.AccumulateMemberships))
assertNoError(t, err)
testCases[i].RoomID = roomID // remember this for later
}

View File

@ -294,7 +294,7 @@ func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID, prev
}
// Insert new events
numNew, latestNIDs, err := h.Store.Accumulate(userID, roomID, prevBatch, timeline)
accResult, err := h.Store.Accumulate(userID, roomID, prevBatch, timeline)
if err != nil {
logger.Err(err).Int("timeline", len(timeline)).Str("room", roomID).Msg("V2: failed to accumulate room")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
@ -302,11 +302,11 @@ func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID, prev
}
// We've updated the database. Now tell any pubsub listeners what we learned.
if numNew != 0 {
if accResult.NumNew != 0 {
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Accumulate{
RoomID: roomID,
PrevBatch: prevBatch,
EventNIDs: latestNIDs,
EventNIDs: accResult.TimelineNIDs,
})
}

View File

@ -38,16 +38,16 @@ func TestGlobalCacheLoadState(t *testing.T) {
testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]interface{}{"name": "The Room Name"}),
testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]interface{}{"name": "The Updated Room Name"}),
}
_, _, err := store.Accumulate(alice, roomID2, "", eventsRoom2)
_, err := store.Accumulate(alice, roomID2, "", eventsRoom2)
if err != nil {
t.Fatalf("Accumulate: %s", err)
}
_, latestNIDs, err := store.Accumulate(alice, roomID, "", events)
accResult, err := store.Accumulate(alice, roomID, "", events)
if err != nil {
t.Fatalf("Accumulate: %s", err)
}
latest := latestNIDs[len(latestNIDs)-1]
latest := accResult.TimelineNIDs[len(accResult.TimelineNIDs)-1]
globalCache := caches.NewGlobalCache(store)
testCases := []struct {
name string