mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
Factor out AccumulateResult struct
This commit is contained in:
parent
be78e6f6e4
commit
777cb357fe
@ -317,6 +317,12 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia
|
||||
return res, err
|
||||
}
|
||||
|
||||
type AccumulateResult struct {
|
||||
// TODO: is this redundant---identical to len(TimelineNIDs)?
|
||||
NumNew int
|
||||
TimelineNIDs []int64
|
||||
}
|
||||
|
||||
// Accumulate internal state from a user's sync response. The timeline order MUST be in the order
|
||||
// received from the server. Returns the number of new events in the timeline, the new timeline event NIDs
|
||||
// or an error.
|
||||
@ -328,7 +334,7 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia
|
||||
// to exist in the database, and the sync stream is already linearised for us.
|
||||
// - Else it creates a new room state snapshot if the timeline contains state events (as this now represents the current state)
|
||||
// - It adds entries to the membership log for membership events.
|
||||
func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
|
||||
func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch string, timeline []json.RawMessage) (AccumulateResult, error) {
|
||||
// The first stage of accumulating events is mostly around validation around what the upstream HS sends us. For accumulation to work correctly
|
||||
// we expect:
|
||||
// - there to be no duplicate events
|
||||
@ -337,10 +343,10 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
|
||||
dedupedEvents, err := a.filterAndParseTimelineEvents(txn, roomID, timeline, prevBatch)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("filterTimelineEvents: %w", err)
|
||||
return
|
||||
return AccumulateResult{}, err
|
||||
}
|
||||
if len(dedupedEvents) == 0 {
|
||||
return 0, nil, err // nothing to do
|
||||
return AccumulateResult{}, nil // nothing to do
|
||||
}
|
||||
|
||||
// Given a timeline of [E1, E2, S3, E4, S5, S6, E7] (E=message event, S=state event)
|
||||
@ -354,7 +360,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
|
||||
// do NOT assign the new state event in the snapshot so as to represent the state before the event.
|
||||
snapID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
return AccumulateResult{}, err
|
||||
}
|
||||
|
||||
// The only situation where no prior snapshot should exist is if this timeline is
|
||||
@ -389,19 +395,22 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
|
||||
})
|
||||
sentry.CaptureMessage(msg)
|
||||
})
|
||||
return 0, nil, nil
|
||||
return AccumulateResult{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
eventIDToNID, err := a.eventsTable.Insert(txn, dedupedEvents, false)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
return AccumulateResult{}, err
|
||||
}
|
||||
if len(eventIDToNID) == 0 {
|
||||
// nothing to do, we already know about these events
|
||||
return 0, nil, nil
|
||||
return AccumulateResult{}, nil
|
||||
}
|
||||
|
||||
result := AccumulateResult{
|
||||
NumNew: len(eventIDToNID),
|
||||
}
|
||||
numNew = len(eventIDToNID)
|
||||
|
||||
var latestNID int64
|
||||
newEvents := make([]Event, 0, len(eventIDToNID))
|
||||
@ -433,7 +442,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
|
||||
}
|
||||
}
|
||||
newEvents = append(newEvents, ev)
|
||||
timelineNIDs = append(timelineNIDs, ev.NID)
|
||||
result.TimelineNIDs = append(result.TimelineNIDs, ev.NID)
|
||||
}
|
||||
}
|
||||
|
||||
@ -443,7 +452,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
|
||||
if len(redactTheseEventIDs) > 0 {
|
||||
createEventJSON, err := a.eventsTable.SelectCreateEvent(txn, roomID)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("SelectCreateEvent: %w", err)
|
||||
return AccumulateResult{}, fmt.Errorf("SelectCreateEvent: %w", err)
|
||||
}
|
||||
roomVersion = gjson.GetBytes(createEventJSON, "content.room_version").Str
|
||||
if roomVersion == "" {
|
||||
@ -454,7 +463,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
|
||||
)
|
||||
}
|
||||
if err = a.eventsTable.Redact(txn, roomVersion, redactTheseEventIDs); err != nil {
|
||||
return 0, nil, err
|
||||
return AccumulateResult{}, err
|
||||
}
|
||||
}
|
||||
|
||||
@ -470,12 +479,12 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
|
||||
if snapID != 0 {
|
||||
oldStripped, err = a.strippedEventsForSnapshot(txn, snapID)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("failed to load stripped state events for snapshot %d: %s", snapID, err)
|
||||
return AccumulateResult{}, fmt.Errorf("failed to load stripped state events for snapshot %d: %s", snapID, err)
|
||||
}
|
||||
}
|
||||
newStripped, replacedNID, err := a.calculateNewSnapshot(oldStripped, ev)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("failed to calculateNewSnapshot: %s", err)
|
||||
return AccumulateResult{}, fmt.Errorf("failed to calculateNewSnapshot: %s", err)
|
||||
}
|
||||
replacesNID = replacedNID
|
||||
memNIDs, otherNIDs := newStripped.NIDs()
|
||||
@ -485,7 +494,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
|
||||
OtherEvents: otherNIDs,
|
||||
}
|
||||
if err = a.snapshotTable.Insert(txn, newSnapshot); err != nil {
|
||||
return 0, nil, fmt.Errorf("failed to insert new snapshot: %w", err)
|
||||
return AccumulateResult{}, fmt.Errorf("failed to insert new snapshot: %w", err)
|
||||
}
|
||||
if a.snapshotMemberCountVec != nil {
|
||||
logger.Trace().Str("room_id", roomID).Int("members", len(memNIDs)).Msg("Inserted new snapshot")
|
||||
@ -494,20 +503,20 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
|
||||
snapID = newSnapshot.SnapshotID
|
||||
}
|
||||
if err := a.eventsTable.UpdateBeforeSnapshotID(txn, ev.NID, beforeSnapID, replacesNID); err != nil {
|
||||
return 0, nil, err
|
||||
return AccumulateResult{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if err = a.spacesTable.HandleSpaceUpdates(txn, newEvents); err != nil {
|
||||
return 0, nil, fmt.Errorf("HandleSpaceUpdates: %s", err)
|
||||
return AccumulateResult{}, fmt.Errorf("HandleSpaceUpdates: %s", err)
|
||||
}
|
||||
|
||||
// the last fetched snapshot ID is the current one
|
||||
info := a.roomInfoDelta(roomID, newEvents)
|
||||
if err = a.roomsTable.Upsert(txn, info, snapID, latestNID); err != nil {
|
||||
return 0, nil, fmt.Errorf("failed to UpdateCurrentSnapshotID to %d: %w", snapID, err)
|
||||
return AccumulateResult{}, fmt.Errorf("failed to UpdateCurrentSnapshotID to %d: %w", snapID, err)
|
||||
}
|
||||
return numNew, timelineNIDs, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// filterAndParseTimelineEvents takes a raw timeline array from sync v2 and applies sanity to it:
|
||||
|
@ -135,25 +135,24 @@ func TestAccumulatorAccumulate(t *testing.T) {
|
||||
// new state event should be added to the snapshot
|
||||
[]byte(`{"event_id":"I", "type":"m.room.history_visibility", "state_key":"", "content":{"visibility":"public"}}`),
|
||||
}
|
||||
var numNew int
|
||||
var latestNIDs []int64
|
||||
var result AccumulateResult
|
||||
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
|
||||
numNew, latestNIDs, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
|
||||
result, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to Accumulate: %s", err)
|
||||
}
|
||||
if numNew != len(newEvents) {
|
||||
t.Fatalf("got %d new events, want %d", numNew, len(newEvents))
|
||||
if result.NumNew != len(newEvents) {
|
||||
t.Fatalf("got %d new events, want %d", result.NumNew, len(newEvents))
|
||||
}
|
||||
// latest nid shoould match
|
||||
wantLatestNID, err := accumulator.eventsTable.SelectHighestNID()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to check latest NID from Accumulate: %s", err)
|
||||
}
|
||||
if latestNIDs[len(latestNIDs)-1] != wantLatestNID {
|
||||
t.Errorf("Accumulator.Accumulate returned latest nid %d, want %d", latestNIDs[len(latestNIDs)-1], wantLatestNID)
|
||||
if result.TimelineNIDs[len(result.TimelineNIDs)-1] != wantLatestNID {
|
||||
t.Errorf("Accumulator.Accumulate returned latest nid %d, want %d", result.TimelineNIDs[len(result.TimelineNIDs)-1], wantLatestNID)
|
||||
}
|
||||
|
||||
// Begin assertions
|
||||
@ -212,7 +211,7 @@ func TestAccumulatorAccumulate(t *testing.T) {
|
||||
|
||||
// subsequent calls do nothing and are not an error
|
||||
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
|
||||
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
|
||||
_, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
@ -248,7 +247,7 @@ func TestAccumulatorMembershipLogs(t *testing.T) {
|
||||
[]byte(`{"event_id":"` + roomEventIDs[7] + `", "type":"m.room.member", "state_key":"@me:localhost","unsigned":{"prev_content":{"membership":"join", "displayname":"Me"}}, "content":{"membership":"leave"}}`),
|
||||
}
|
||||
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
|
||||
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", roomEvents)
|
||||
_, err = accumulator.Accumulate(txn, userID, roomID, "", roomEvents)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
@ -384,7 +383,7 @@ func TestAccumulatorDupeEvents(t *testing.T) {
|
||||
}
|
||||
|
||||
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
|
||||
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", joinRoom.Timeline.Events)
|
||||
_, err = accumulator.Accumulate(txn, userID, roomID, "", joinRoom.Timeline.Events)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
@ -584,8 +583,8 @@ func TestAccumulatorConcurrency(t *testing.T) {
|
||||
defer wg.Done()
|
||||
subset := newEvents[:(i + 1)] // i=0 => [1], i=1 => [1,2], etc
|
||||
err := sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
|
||||
numNew, _, err := accumulator.Accumulate(txn, userID, roomID, "", subset)
|
||||
totalNumNew += numNew
|
||||
result, err := accumulator.Accumulate(txn, userID, roomID, "", subset)
|
||||
totalNumNew += result.NumNew
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -336,15 +336,15 @@ func (s *Storage) currentNotMembershipStateEventsInAllRooms(txn *sqlx.Tx, eventT
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *Storage) Accumulate(userID, roomID, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
|
||||
func (s *Storage) Accumulate(userID, roomID, prevBatch string, timeline []json.RawMessage) (result AccumulateResult, err error) {
|
||||
if len(timeline) == 0 {
|
||||
return 0, nil, nil
|
||||
return AccumulateResult{}, nil
|
||||
}
|
||||
err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
|
||||
numNew, timelineNIDs, err = s.Accumulator.Accumulate(txn, userID, roomID, prevBatch, timeline)
|
||||
result, err = s.Accumulator.Accumulate(txn, userID, roomID, prevBatch, timeline)
|
||||
return err
|
||||
})
|
||||
return
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (s *Storage) Initialise(roomID string, state []json.RawMessage) (InitialiseResult, error) {
|
||||
|
@ -31,11 +31,11 @@ func TestStorageRoomStateBeforeAndAfterEventPosition(t *testing.T) {
|
||||
testutils.NewStateEvent(t, "m.room.join_rules", "", alice, map[string]interface{}{"join_rule": "invite"}),
|
||||
testutils.NewStateEvent(t, "m.room.member", bob, alice, map[string]interface{}{"membership": "invite"}),
|
||||
}
|
||||
_, latestNIDs, err := store.Accumulate(userID, roomID, "", events)
|
||||
accResult, err := store.Accumulate(userID, roomID, "", events)
|
||||
if err != nil {
|
||||
t.Fatalf("Accumulate returned error: %s", err)
|
||||
}
|
||||
latest := latestNIDs[len(latestNIDs)-1]
|
||||
latest := accResult.TimelineNIDs[len(accResult.TimelineNIDs)-1]
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@ -158,14 +158,13 @@ func TestStorageJoinedRoomsAfterPosition(t *testing.T) {
|
||||
},
|
||||
}
|
||||
var latestPos int64
|
||||
var latestNIDs []int64
|
||||
var err error
|
||||
for roomID, eventMap := range roomIDToEventMap {
|
||||
_, latestNIDs, err = store.Accumulate(userID, roomID, "", eventMap)
|
||||
accResult, err := store.Accumulate(userID, roomID, "", eventMap)
|
||||
if err != nil {
|
||||
t.Fatalf("Accumulate on %s failed: %s", roomID, err)
|
||||
}
|
||||
latestPos = latestNIDs[len(latestNIDs)-1]
|
||||
latestPos = accResult.TimelineNIDs[len(accResult.TimelineNIDs)-1]
|
||||
}
|
||||
aliceJoinTimingsByRoomID, err := store.JoinedRoomsAfterPosition(alice, latestPos)
|
||||
if err != nil {
|
||||
@ -351,11 +350,11 @@ func TestVisibleEventNIDsBetween(t *testing.T) {
|
||||
},
|
||||
}
|
||||
for _, tl := range timelineInjections {
|
||||
numNew, _, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
|
||||
accResult, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
|
||||
if err != nil {
|
||||
t.Fatalf("Accumulate on %s failed: %s", tl.RoomID, err)
|
||||
}
|
||||
t.Logf("%s added %d new events", tl.RoomID, numNew)
|
||||
t.Logf("%s added %d new events", tl.RoomID, accResult.NumNew)
|
||||
}
|
||||
latestPos, err := store.LatestEventNID()
|
||||
if err != nil {
|
||||
@ -454,11 +453,11 @@ func TestVisibleEventNIDsBetween(t *testing.T) {
|
||||
t.Fatalf("LatestEventNID: %s", err)
|
||||
}
|
||||
for _, tl := range timelineInjections {
|
||||
numNew, _, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
|
||||
accResult, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
|
||||
if err != nil {
|
||||
t.Fatalf("Accumulate on %s failed: %s", tl.RoomID, err)
|
||||
}
|
||||
t.Logf("%s added %d new events", tl.RoomID, numNew)
|
||||
t.Logf("%s added %d new events", tl.RoomID, accResult.NumNew)
|
||||
}
|
||||
latestPos, err = store.LatestEventNID()
|
||||
if err != nil {
|
||||
@ -534,7 +533,7 @@ func TestStorageLatestEventsInRoomsPrevBatch(t *testing.T) {
|
||||
}
|
||||
eventIDs := []string{}
|
||||
for _, timeline := range timelines {
|
||||
_, _, err = store.Accumulate(userID, roomID, timeline.prevBatch, timeline.timeline)
|
||||
_, err := store.Accumulate(userID, roomID, timeline.prevBatch, timeline.timeline)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to accumulate: %s", err)
|
||||
}
|
||||
@ -776,7 +775,7 @@ func TestAllJoinedMembers(t *testing.T) {
|
||||
}, serialise(tc.InitMemberships)...))
|
||||
assertNoError(t, err)
|
||||
|
||||
_, _, err = store.Accumulate(userID, roomID, "foo", serialise(tc.AccumulateMemberships))
|
||||
_, err = store.Accumulate(userID, roomID, "foo", serialise(tc.AccumulateMemberships))
|
||||
assertNoError(t, err)
|
||||
testCases[i].RoomID = roomID // remember this for later
|
||||
}
|
||||
|
@ -294,7 +294,7 @@ func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID, prev
|
||||
}
|
||||
|
||||
// Insert new events
|
||||
numNew, latestNIDs, err := h.Store.Accumulate(userID, roomID, prevBatch, timeline)
|
||||
accResult, err := h.Store.Accumulate(userID, roomID, prevBatch, timeline)
|
||||
if err != nil {
|
||||
logger.Err(err).Int("timeline", len(timeline)).Str("room", roomID).Msg("V2: failed to accumulate room")
|
||||
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
|
||||
@ -302,11 +302,11 @@ func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID, prev
|
||||
}
|
||||
|
||||
// We've updated the database. Now tell any pubsub listeners what we learned.
|
||||
if numNew != 0 {
|
||||
if accResult.NumNew != 0 {
|
||||
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Accumulate{
|
||||
RoomID: roomID,
|
||||
PrevBatch: prevBatch,
|
||||
EventNIDs: latestNIDs,
|
||||
EventNIDs: accResult.TimelineNIDs,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -38,16 +38,16 @@ func TestGlobalCacheLoadState(t *testing.T) {
|
||||
testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]interface{}{"name": "The Room Name"}),
|
||||
testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]interface{}{"name": "The Updated Room Name"}),
|
||||
}
|
||||
_, _, err := store.Accumulate(alice, roomID2, "", eventsRoom2)
|
||||
_, err := store.Accumulate(alice, roomID2, "", eventsRoom2)
|
||||
if err != nil {
|
||||
t.Fatalf("Accumulate: %s", err)
|
||||
}
|
||||
|
||||
_, latestNIDs, err := store.Accumulate(alice, roomID, "", events)
|
||||
accResult, err := store.Accumulate(alice, roomID, "", events)
|
||||
if err != nil {
|
||||
t.Fatalf("Accumulate: %s", err)
|
||||
}
|
||||
latest := latestNIDs[len(latestNIDs)-1]
|
||||
latest := accResult.TimelineNIDs[len(accResult.TimelineNIDs)-1]
|
||||
globalCache := caches.NewGlobalCache(store)
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
Loading…
x
Reference in New Issue
Block a user