From 6c4f7d37225ed9d2b3a18268ac2fe8fedecdc7a4 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Thu, 22 Dec 2022 15:08:42 +0000 Subject: [PATCH] improvement: completely refactor device data updates - `Conn`s now expose a direct `OnUpdate(caches.Update)` function for updates which concern a specific device ID. - Add a bitset in `DeviceData` to indicate whether the OTK counts or fallback key types have changed. - Pass through the affected `DeviceID` in `pubsub.V2DeviceData` updates. - Remove `DeviceDataTable.SelectFrom` as it was unused. - Refactor how the poller invokes `OnE2EEData`: it now only does this if there are changes to OTK counts and/or fallback key types and/or device lists, and _only_ sends those fields, setting the rest to the zero value. - Remove noisy logging. - Add `caches.DeviceDataUpdate` which has no data but serves to wake up the long poller. - Only send OTK counts / fallback key types when they have changed, not constantly. This matches the behaviour described in MSC3884. The entire flow now looks like: - Poller notices a diff against the in-memory version of the OTK count and invokes `OnE2EEData`. - Handler updates device data table, bumps the changed bit for the OTK count. - Other handler gets the pubsub update, directly finds the `Conn` based on the `DeviceID`. Invokes `OnUpdate(caches.DeviceDataUpdate)`. - This update is handled by the E2EE extension which then pulls the data out from the database and returns it. - On initial connections, all OTK / fallback data is returned. 
--- internal/device_data.go | 33 +++++ internal/device_data_test.go | 57 +++++++++ pubsub/v2.go | 3 +- state/accumulator.go | 15 +-- state/device_data_table.go | 27 +---- state/device_data_table_test.go | 175 +++++++++++++-------------- state/storage.go | 2 +- state/to_device_table.go | 4 +- sync2/handler2/handler.go | 3 +- sync2/poller.go | 33 +++-- sync3/caches/global.go | 10 -- sync3/caches/update.go | 7 +- sync3/conn.go | 6 + sync3/conn_test.go | 6 +- sync3/extensions/e2ee.go | 16 ++- sync3/extensions/extensions.go | 8 ++ sync3/handler/handler.go | 11 +- tests-integration/extensions_test.go | 54 ++++++--- testutils/m/match.go | 9 ++ 19 files changed, 297 insertions(+), 182 deletions(-) create mode 100644 internal/device_data_test.go diff --git a/internal/device_data.go b/internal/device_data.go index 3d8612d..09cee9f 100644 --- a/internal/device_data.go +++ b/internal/device_data.go @@ -4,6 +4,20 @@ import ( "sync" ) +const ( + bitOTKCount int = iota + bitFallbackKeyTypes +) + +func setBit(n int, bit int) int { + n |= (1 << bit) + return n +} +func isBitSet(n int, bit int) bool { + val := n & (1 << bit) + return val > 0 +} + // DeviceData contains useful data for this user's device. This list can be expanded without prompting // schema changes. These values are upserted into the database and persisted forever. type DeviceData struct { @@ -16,10 +30,29 @@ type DeviceData struct { DeviceLists DeviceLists `json:"dl"` + // bitset for which device data changes are present. 
They accumulate until they get swapped over + // when they get reset + ChangedBits int `json:"c"` + UserID string DeviceID string } +func (dd *DeviceData) SetOTKCountChanged() { + dd.ChangedBits = setBit(dd.ChangedBits, bitOTKCount) +} + +func (dd *DeviceData) SetFallbackKeysChanged() { + dd.ChangedBits = setBit(dd.ChangedBits, bitFallbackKeyTypes) +} + +func (dd *DeviceData) OTKCountChanged() bool { + return isBitSet(dd.ChangedBits, bitOTKCount) +} +func (dd *DeviceData) FallbackKeysChanged() bool { + return isBitSet(dd.ChangedBits, bitFallbackKeyTypes) +} + type UserDeviceKey struct { UserID string DeviceID string diff --git a/internal/device_data_test.go b/internal/device_data_test.go new file mode 100644 index 0000000..17b7381 --- /dev/null +++ b/internal/device_data_test.go @@ -0,0 +1,57 @@ +package internal + +import "testing" + +func TestDeviceDataBitset(t *testing.T) { + testCases := []struct { + get func() DeviceData + fallbackSet bool + otkSet bool + }{ + { + get: func() DeviceData { + var dd DeviceData + dd.SetFallbackKeysChanged() + return dd + }, + fallbackSet: true, + otkSet: false, + }, + { + get: func() DeviceData { + var dd DeviceData + dd.SetOTKCountChanged() + return dd + }, + fallbackSet: false, + otkSet: true, + }, + { + get: func() DeviceData { + var dd DeviceData + dd.SetFallbackKeysChanged() + dd.SetOTKCountChanged() + return dd + }, + fallbackSet: true, + otkSet: true, + }, + { + get: func() DeviceData { + var dd DeviceData + return dd + }, + fallbackSet: false, + otkSet: false, + }, + } + for _, tc := range testCases { + dd := tc.get() + if dd.FallbackKeysChanged() != tc.fallbackSet { + t.Errorf("%v : wrong fallback value, want %v", dd, tc.fallbackSet) + } + if dd.OTKCountChanged() != tc.otkSet { + t.Errorf("%v : wrong OTK value, want %v", dd, tc.otkSet) + } + } +} diff --git a/pubsub/v2.go b/pubsub/v2.go index cdc5afb..1abc704 100644 --- a/pubsub/v2.go +++ b/pubsub/v2.go @@ -76,7 +76,8 @@ type V2InitialSyncComplete struct { func 
(*V2InitialSyncComplete) Type() string { return "V2InitialSyncComplete" } type V2DeviceData struct { - Pos int64 + DeviceID string + Pos int64 } func (*V2DeviceData) Type() string { return "V2DeviceData" } diff --git a/state/accumulator.go b/state/accumulator.go index 81472e1..67e7dc8 100644 --- a/state/accumulator.go +++ b/state/accumulator.go @@ -4,20 +4,13 @@ import ( "database/sql" "encoding/json" "fmt" - "os" "github.com/jmoiron/sqlx" "github.com/lib/pq" "github.com/matrix-org/sliding-sync/sqlutil" - "github.com/rs/zerolog" "github.com/tidwall/gjson" ) -var log = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.ConsoleWriter{ - Out: os.Stderr, - TimeFormat: "15:04:05", -}) - // Accumulator tracks room state and timelines. // // In order for it to remain simple(ish), the accumulator DOES NOT SUPPORT arbitrary timeline gaps. @@ -77,7 +70,7 @@ func (a *Accumulator) calculateNewSnapshot(old StrippedEvents, new Event) (Strip // ruh roh. This should be impossible, but it can happen if the v2 response sends the same // event in both state and timeline. We need to alert the operator and whine badly as it means // we have lost an event by now. - log.Warn().Str("new_event_id", new.ID).Str("old_event_id", e.ID).Str("room_id", new.RoomID).Str("type", new.Type).Str("state_key", new.StateKey).Msg( + logger.Warn().Str("new_event_id", new.ID).Str("old_event_id", e.ID).Str("room_id", new.RoomID).Str("type", new.Type).Str("state_key", new.StateKey).Msg( "Detected different event IDs with the same NID when rolling forward state. This has resulted in data loss in this room (1 event). " + "This can happen when the v2 /sync response sends the same event in both state and timeline sections. 
" + "The event in this log line has been dropped!", @@ -160,7 +153,7 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (bool, } if snapshotID > 0 { // we only initialise rooms once - log.Info().Str("room_id", roomID).Int64("snapshot_id", snapshotID).Msg("Accumulator.Initialise called but current snapshot already exists, bailing early") + logger.Info().Str("room_id", roomID).Int64("snapshot_id", snapshotID).Msg("Accumulator.Initialise called but current snapshot already exists, bailing early") return nil } @@ -183,7 +176,7 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (bool, if len(eventIDToNID) == 0 { // we don't have a current snapshot for this room but yet no events are new, // no idea how this should be handled. - log.Error().Str("room_id", roomID).Msg( + logger.Error().Str("room_id", roomID).Msg( "Accumulator.Initialise: room has no current snapshot but also no new inserted events, doing nothing. This is probably a bug.", ) return nil @@ -270,7 +263,7 @@ func (a *Accumulator) Accumulate(roomID string, prevBatch string, timeline []jso return fmt.Errorf("event malformed: %s", err) } if _, ok := seenEvents[e.ID]; ok { - log.Warn().Str("event_id", e.ID).Str("room_id", roomID).Msg( + logger.Warn().Str("event_id", e.ID).Str("room_id", roomID).Msg( "Accumulator.Accumulate: seen the same event ID twice, ignoring", ) continue diff --git a/state/device_data_table.go b/state/device_data_table.go index 4c26ee6..892940b 100644 --- a/state/device_data_table.go +++ b/state/device_data_table.go @@ -66,6 +66,9 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (dd *intern n := tempDD.DeviceLists.New tempDD.DeviceLists.Sent = n tempDD.DeviceLists.New = make(map[string]int) + // reset changed bits + changedBits := tempDD.ChangedBits + tempDD.ChangedBits = 0 // re-marshal and write data, err := json.Marshal(tempDD) @@ -74,32 +77,12 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) 
(dd *intern } _, err = t.db.Exec(`UPDATE syncv3_device_data SET data=$1 WHERE user_id=$2 AND device_id=$3`, data, userID, deviceID) dd = &tempDD + dd.ChangedBits = changedBits return err }) return } -func (t *DeviceDataTable) SelectFrom(pos int64) (results []internal.DeviceData, nextPos int64, err error) { - nextPos = pos - var rows []DeviceDataRow - err = t.db.Select(&rows, `SELECT id, user_id, device_id, data FROM syncv3_device_data WHERE id > $1 ORDER BY id ASC`, pos) - if err != nil { - return - } - results = make([]internal.DeviceData, len(rows)) - for i := range rows { - var dd internal.DeviceData - if err = json.Unmarshal(rows[i].Data, &dd); err != nil { - return - } - dd.UserID = rows[i].UserID - dd.DeviceID = rows[i].DeviceID - results[i] = dd - nextPos = rows[i].ID - } - return -} - // Upsert combines what is in the database for this user|device with the partial entry `dd` func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (pos int64, err error) { err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error { @@ -118,9 +101,11 @@ func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (pos int64, err error) } if dd.FallbackKeyTypes != nil { tempDD.FallbackKeyTypes = dd.FallbackKeyTypes + tempDD.SetFallbackKeysChanged() } if dd.OTKCounts != nil { tempDD.OTKCounts = dd.OTKCounts + tempDD.SetOTKCountChanged() } tempDD.DeviceLists = tempDD.DeviceLists.Combine(dd.DeviceLists) diff --git a/state/device_data_table_test.go b/state/device_data_table_test.go index 35e8579..60c3435 100644 --- a/state/device_data_table_test.go +++ b/state/device_data_table_test.go @@ -9,92 +9,19 @@ import ( ) func assertVal(t *testing.T, msg string, got, want interface{}) { + t.Helper() if !reflect.DeepEqual(got, want) { t.Errorf("%s: got %v want %v", msg, got, want) } } -func assertDeviceDatas(t *testing.T, got, want []internal.DeviceData) { +func assertDeviceData(t *testing.T, g, w internal.DeviceData) { t.Helper() - if len(got) != len(want) { - t.Fatalf("got %d devices, 
want %d : %+v", len(got), len(want), got) - } - for i := range want { - g := got[i] - w := want[i] - assertVal(t, "device id", g.DeviceID, w.DeviceID) - assertVal(t, "user id", g.UserID, w.UserID) - assertVal(t, "FallbackKeyTypes", g.FallbackKeyTypes, w.FallbackKeyTypes) - assertVal(t, "OTKCounts", g.OTKCounts, w.OTKCounts) - } -} - -func TestDeviceDataTable(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } - table := NewDeviceDataTable(db) - userID := "@alice" - deviceID := "ALICE" - dd := &internal.DeviceData{ - UserID: userID, - DeviceID: deviceID, - OTKCounts: map[string]int{ - "foo": 100, - }, - FallbackKeyTypes: []string{"foo", "bar"}, - } - - // test basic insert -> select - pos, err := table.Upsert(dd) - assertNoError(t, err) - results, nextPos, err := table.SelectFrom(-1) - assertNoError(t, err) - if pos != nextPos { - t.Fatalf("Upsert returned pos %v but SelectFrom returned pos %v", pos, nextPos) - } - assertDeviceDatas(t, results, []internal.DeviceData{*dd}) - - // at latest -> no results - results, nextPos, err = table.SelectFrom(nextPos) - assertNoError(t, err) - if pos != nextPos { - t.Fatalf("Upsert returned pos %v but SelectFrom returned pos %v", pos, nextPos) - } - assertDeviceDatas(t, results, nil) - - // multiple insert -> replace on user|device - dd2 := *dd - dd2.OTKCounts = map[string]int{"foo": 99} - _, err = table.Upsert(&dd2) - assertNoError(t, err) - dd3 := *dd - dd3.OTKCounts = map[string]int{"foo": 98} - pos, err = table.Upsert(&dd3) - assertNoError(t, err) - results, nextPos, err = table.SelectFrom(nextPos) - assertNoError(t, err) - if pos != nextPos { - t.Fatalf("Upsert returned pos %v but SelectFrom returned pos %v", pos, nextPos) - } - assertDeviceDatas(t, results, []internal.DeviceData{dd3}) - - // multiple insert -> different user, same device + same user, different device - dd4 := *dd - dd4.UserID = "@bob" - _, err = table.Upsert(&dd4) - 
assertNoError(t, err) - dd5 := *dd - dd5.DeviceID = "ANOTHER" - pos, err = table.Upsert(&dd5) - assertNoError(t, err) - results, nextPos, err = table.SelectFrom(nextPos) - assertNoError(t, err) - if pos != nextPos { - t.Fatalf("Upsert returned pos %v but SelectFrom returned pos %v", pos, nextPos) - } - assertDeviceDatas(t, results, []internal.DeviceData{dd4, dd5}) + assertVal(t, "device id", g.DeviceID, w.DeviceID) + assertVal(t, "user id", g.UserID, w.UserID) + assertVal(t, "FallbackKeyTypes", g.FallbackKeyTypes, w.FallbackKeyTypes) + assertVal(t, "OTKCounts", g.OTKCounts, w.OTKCounts) + assertVal(t, "ChangedBits", g.ChangedBits, w.ChangedBits) } func TestDeviceDataTableSwaps(t *testing.T) { @@ -155,11 +82,13 @@ func TestDeviceDataTableSwaps(t *testing.T) { New: internal.ToDeviceListChangesMap([]string{"alice", "bob"}, nil), }, } + want.SetFallbackKeysChanged() + want.SetOTKCountChanged() // check we can read-only select for i := 0; i < 3; i++ { got, err := table.Select(userID, deviceID, false) assertNoError(t, err) - assertDeviceDatas(t, []internal.DeviceData{*got}, []internal.DeviceData{want}) + assertDeviceData(t, *got, want) } // now swap-er-roo got, err := table.Select(userID, deviceID, true) @@ -169,12 +98,16 @@ func TestDeviceDataTableSwaps(t *testing.T) { Sent: internal.ToDeviceListChangesMap([]string{"alice"}, nil), New: nil, } - assertDeviceDatas(t, []internal.DeviceData{*got}, []internal.DeviceData{want2}) + assertDeviceData(t, *got, want2) + + // changed bits were reset when we swapped + want2.ChangedBits = 0 + want.ChangedBits = 0 // this is permanent, read-only views show this too got, err = table.Select(userID, deviceID, false) assertNoError(t, err) - assertDeviceDatas(t, []internal.DeviceData{*got}, []internal.DeviceData{want2}) + assertDeviceData(t, *got, want2) // another swap causes sent to be cleared out got, err = table.Select(userID, deviceID, true) @@ -184,16 +117,18 @@ func TestDeviceDataTableSwaps(t *testing.T) { Sent: nil, New: nil, } - 
assertDeviceDatas(t, []internal.DeviceData{*got}, []internal.DeviceData{want3}) + assertDeviceData(t, *got, want3) // get back the original state for _, dd := range deltas { _, err = table.Upsert(&dd) assertNoError(t, err) } + want.SetFallbackKeysChanged() + want.SetOTKCountChanged() got, err = table.Select(userID, deviceID, false) assertNoError(t, err) - assertDeviceDatas(t, []internal.DeviceData{*got}, []internal.DeviceData{want}) + assertDeviceData(t, *got, want) // swap once then add once so both sent and new are populated _, err = table.Select(userID, deviceID, true) @@ -207,6 +142,8 @@ func TestDeviceDataTableSwaps(t *testing.T) { }) assertNoError(t, err) + want.ChangedBits = 0 + want4 := want want4.DeviceLists = internal.DeviceLists{ Sent: internal.ToDeviceListChangesMap([]string{"alice"}, nil), @@ -214,7 +151,7 @@ func TestDeviceDataTableSwaps(t *testing.T) { } got, err = table.Select(userID, deviceID, false) assertNoError(t, err) - assertDeviceDatas(t, []internal.DeviceData{*got}, []internal.DeviceData{want4}) + assertDeviceData(t, *got, want4) // another append then consume _, err = table.Upsert(&internal.DeviceData{ @@ -232,6 +169,68 @@ func TestDeviceDataTableSwaps(t *testing.T) { Sent: internal.ToDeviceListChangesMap([]string{"bob", "dave"}, []string{"charlie", "dave"}), New: nil, } - assertDeviceDatas(t, []internal.DeviceData{*got}, []internal.DeviceData{want5}) - + assertDeviceData(t, *got, want5) +} + +func TestDeviceDataTableBitset(t *testing.T) { + db, err := sqlx.Open("postgres", postgresConnectionString) + if err != nil { + t.Fatalf("failed to open SQL db: %s", err) + } + table := NewDeviceDataTable(db) + userID := "@bobTestDeviceDataTableBitset" + deviceID := "BOBTestDeviceDataTableBitset" + otkUpdate := internal.DeviceData{ + UserID: userID, + DeviceID: deviceID, + OTKCounts: map[string]int{ + "foo": 100, + "bar": 92, + }, + } + fallbakKeyUpdate := internal.DeviceData{ + UserID: userID, + DeviceID: deviceID, + FallbackKeyTypes: []string{"foo", 
"bar"}, + } + bothUpdate := internal.DeviceData{ + UserID: userID, + DeviceID: deviceID, + FallbackKeyTypes: []string{"both"}, + OTKCounts: map[string]int{ + "both": 100, + }, + } + + _, err = table.Upsert(&otkUpdate) + assertNoError(t, err) + got, err := table.Select(userID, deviceID, true) + assertNoError(t, err) + otkUpdate.SetOTKCountChanged() + assertDeviceData(t, *got, otkUpdate) + // second time swapping causes no OTKs as there have been no changes + got, err = table.Select(userID, deviceID, true) + assertNoError(t, err) + otkUpdate.ChangedBits = 0 + assertDeviceData(t, *got, otkUpdate) + // now same for fallback keys, but we won't swap them so it should return those diffs + _, err = table.Upsert(&fallbakKeyUpdate) + assertNoError(t, err) + fallbakKeyUpdate.OTKCounts = otkUpdate.OTKCounts + got, err = table.Select(userID, deviceID, false) + assertNoError(t, err) + fallbakKeyUpdate.SetFallbackKeysChanged() + assertDeviceData(t, *got, fallbakKeyUpdate) + got, err = table.Select(userID, deviceID, false) + assertNoError(t, err) + fallbakKeyUpdate.SetFallbackKeysChanged() + assertDeviceData(t, *got, fallbakKeyUpdate) + // updating both works + _, err = table.Upsert(&bothUpdate) + assertNoError(t, err) + got, err = table.Select(userID, deviceID, true) + assertNoError(t, err) + bothUpdate.SetFallbackKeysChanged() + bothUpdate.SetOTKCountChanged() + assertDeviceData(t, *got, bothUpdate) } diff --git a/state/storage.go b/state/storage.go index baa1c03..98bad63 100644 --- a/state/storage.go +++ b/state/storage.go @@ -47,7 +47,7 @@ type Storage struct { func NewStorage(postgresURI string) *Storage { db, err := sqlx.Open("postgres", postgresURI) if err != nil { - log.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB") + logger.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB") } acc := &Accumulator{ db: db, diff --git a/state/to_device_table.go b/state/to_device_table.go index 2645d49..2ab2245 100644 --- a/state/to_device_table.go 
+++ b/state/to_device_table.go @@ -92,7 +92,7 @@ func (t *ToDeviceTable) Messages(deviceID string, from, limit int64) (msgs []jso m := gjson.ParseBytes(msgs[i]) msgId := m.Get(`content.org\.matrix\.msgid`).Str if msgId != "" { - log.Debug().Str("msgid", msgId).Str("device", deviceID).Msg("ToDeviceTable.Messages") + logger.Info().Str("msgid", msgId).Str("device", deviceID).Msg("ToDeviceTable.Messages") } } upTo = rows[len(rows)-1].Position @@ -125,7 +125,7 @@ func (t *ToDeviceTable) InsertMessages(deviceID string, msgs []json.RawMessage) } msgId := m.Get(`content.org\.matrix\.msgid`).Str if msgId != "" { - log.Debug().Str("msgid", msgId).Str("device", deviceID).Msg("ToDeviceTable.InsertMessages") + logger.Debug().Str("msgid", msgId).Str("device", deviceID).Msg("ToDeviceTable.InsertMessages") } switch rows[i].Type { case "m.room_key_request": diff --git a/sync2/handler2/handler.go b/sync2/handler2/handler.go index 0632c3f..fe495f2 100644 --- a/sync2/handler2/handler.go +++ b/sync2/handler2/handler.go @@ -178,7 +178,8 @@ func (h *Handler) OnE2EEData(userID, deviceID string, otkCounts map[string]int, return } h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2DeviceData{ - Pos: nextPos, + DeviceID: deviceID, + Pos: nextPos, }) } diff --git a/sync2/poller.go b/sync2/poller.go index 136b112..db7813b 100644 --- a/sync2/poller.go +++ b/sync2/poller.go @@ -459,31 +459,29 @@ func (p *poller) parseToDeviceMessages(res *SyncResponse) { } func (p *poller) parseE2EEData(res *SyncResponse) { - hasE2EEChanges := false + var changedOTKCounts map[string]int if res.DeviceListsOTKCount != nil && len(res.DeviceListsOTKCount) > 0 { if len(p.otkCounts) != len(res.DeviceListsOTKCount) { - hasE2EEChanges = true - } - if !hasE2EEChanges && p.otkCounts != nil { + changedOTKCounts = res.DeviceListsOTKCount + } else if p.otkCounts != nil { for k := range res.DeviceListsOTKCount { if res.DeviceListsOTKCount[k] != p.otkCounts[k] { - hasE2EEChanges = true + changedOTKCounts = res.DeviceListsOTKCount break } 
} } p.otkCounts = res.DeviceListsOTKCount } + var changedFallbackTypes []string if len(res.DeviceUnusedFallbackKeyTypes) > 0 { - if !hasE2EEChanges { - if len(p.fallbackKeyTypes) != len(res.DeviceUnusedFallbackKeyTypes) { - hasE2EEChanges = true - } else { - for i := range res.DeviceUnusedFallbackKeyTypes { - if res.DeviceUnusedFallbackKeyTypes[i] != p.fallbackKeyTypes[i] { - hasE2EEChanges = true - break - } + if len(p.fallbackKeyTypes) != len(res.DeviceUnusedFallbackKeyTypes) { + changedFallbackTypes = res.DeviceUnusedFallbackKeyTypes + } else { + for i := range res.DeviceUnusedFallbackKeyTypes { + if res.DeviceUnusedFallbackKeyTypes[i] != p.fallbackKeyTypes[i] { + changedFallbackTypes = res.DeviceUnusedFallbackKeyTypes + break } } } @@ -491,12 +489,9 @@ func (p *poller) parseE2EEData(res *SyncResponse) { } deviceListChanges := internal.ToDeviceListChangesMap(res.DeviceLists.Changed, res.DeviceLists.Left) - if deviceListChanges != nil { - hasE2EEChanges = true - } - if hasE2EEChanges { - p.receiver.OnE2EEData(p.userID, p.deviceID, p.otkCounts, p.fallbackKeyTypes, deviceListChanges) + if deviceListChanges != nil || changedFallbackTypes != nil || changedOTKCounts != nil { + p.receiver.OnE2EEData(p.userID, p.deviceID, changedOTKCounts, changedFallbackTypes, deviceListChanges) } } diff --git a/sync3/caches/global.go b/sync3/caches/global.go index 6e7bbcc..0557a7c 100644 --- a/sync3/caches/global.go +++ b/sync3/caches/global.go @@ -7,7 +7,6 @@ import ( "sort" "sync" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/sliding-sync/internal" "github.com/matrix-org/sliding-sync/state" "github.com/rs/zerolog" @@ -187,15 +186,6 @@ func (c *GlobalCache) Startup(roomIDToMetadata map[string]internal.RoomMetadata) sort.Strings(roomIDs) for _, roomID := range roomIDs { metadata := roomIDToMetadata[roomID] - var upgradedRoomID string - if metadata.UpgradedRoomID != nil { - upgradedRoomID = *metadata.UpgradedRoomID - } - logger.Debug().Str("room", 
roomID).Interface( - "recent", gomatrixserverlib.Timestamp(metadata.LastMessageTimestamp).Time(), - ).Bool("encrypted", metadata.Encrypted).Str("tombstone", upgradedRoomID).Int("joins", metadata.JoinCount).Msg( - "", - ) internal.Assert("room ID is set", metadata.RoomID != "") internal.Assert("last message timestamp exists", metadata.LastMessageTimestamp > 1) c.roomIDToMetadata[roomID] = &metadata diff --git a/sync3/caches/update.go b/sync3/caches/update.go index bac3620..aa5a40b 100644 --- a/sync3/caches/update.go +++ b/sync3/caches/update.go @@ -54,6 +54,7 @@ type RoomAccountDataUpdate struct { AccountData []state.AccountData } -// Alerts result in changes to ops, subs or ext modifications -// Alerts can update internal conn state -// Dispatcher thread ultimately fires alerts OR poller thread e.g OnUnreadCounts +type DeviceDataUpdate struct { + // no data; just wakes up the connection + // data comes via sidechannels e.g the database +} diff --git a/sync3/conn.go b/sync3/conn.go index 4fde6d1..b390674 100644 --- a/sync3/conn.go +++ b/sync3/conn.go @@ -7,6 +7,7 @@ import ( "sync" "github.com/matrix-org/sliding-sync/internal" + "github.com/matrix-org/sliding-sync/sync3/caches" ) type ConnID struct { @@ -22,6 +23,7 @@ type ConnHandler interface { // to send back or an error. Errors of type *internal.HandlerError are inspected for the correct // status code to send back. 
OnIncomingRequest(ctx context.Context, cid ConnID, req *Request, isInitial bool) (*Response, error) + OnUpdate(update caches.Update) UserID() string Destroy() Alive() bool @@ -69,6 +71,10 @@ func (c *Conn) Alive() bool { return c.handler.Alive() } +func (c *Conn) OnUpdate(update caches.Update) { + c.handler.OnUpdate(update) +} + func (c *Conn) tryRequest(ctx context.Context, req *Request) (res *Response, err error) { defer func() { panicErr := recover() diff --git a/sync3/conn_test.go b/sync3/conn_test.go index b987e5f..d250590 100644 --- a/sync3/conn_test.go +++ b/sync3/conn_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/matrix-org/sliding-sync/internal" + "github.com/matrix-org/sliding-sync/sync3/caches" ) type connHandlerMock struct { @@ -21,8 +22,9 @@ func (c *connHandlerMock) OnIncomingRequest(ctx context.Context, cid ConnID, req func (c *connHandlerMock) UserID() string { return "dummy" } -func (c *connHandlerMock) Destroy() {} -func (c *connHandlerMock) Alive() bool { return true } +func (c *connHandlerMock) Destroy() {} +func (c *connHandlerMock) Alive() bool { return true } +func (c *connHandlerMock) OnUpdate(update caches.Update) {} // Test that Conn can send and receive requests based on positions func TestConn(t *testing.T) { diff --git a/sync3/extensions/e2ee.go b/sync3/extensions/e2ee.go index cec182d..5ef8b68 100644 --- a/sync3/extensions/e2ee.go +++ b/sync3/extensions/e2ee.go @@ -3,6 +3,7 @@ package extensions import ( "github.com/matrix-org/sliding-sync/internal" "github.com/matrix-org/sliding-sync/sync2" + "github.com/matrix-org/sliding-sync/sync3/caches" ) // Client created request params @@ -31,8 +32,15 @@ func (r *E2EEResponse) HasData(isInitial bool) bool { if isInitial { return true // ensure we send OTK counts immediately } - // OTK counts aren't enough to make /sync return early as we send them liberally, not just on change - return r.DeviceLists != nil + return r.DeviceLists != nil || len(r.FallbackKeyTypes) > 0 || len(r.OTKCounts) > 
0 +} + +func ProcessLiveE2EE(up caches.Update, fetcher sync2.E2EEFetcher, userID, deviceID string, req *E2EERequest) (res *E2EEResponse) { + _, ok := up.(caches.DeviceDataUpdate) + if !ok { + return nil + } + return ProcessE2EE(fetcher, userID, deviceID, req, false) } func ProcessE2EE(fetcher sync2.E2EEFetcher, userID, deviceID string, req *E2EERequest, isInitial bool) (res *E2EEResponse) { @@ -42,10 +50,10 @@ func ProcessE2EE(fetcher sync2.E2EEFetcher, userID, deviceID string, req *E2EERe if dd == nil { return res // unknown device? } - if dd.FallbackKeyTypes != nil { + if dd.FallbackKeyTypes != nil && (dd.FallbackKeysChanged() || isInitial) { res.FallbackKeyTypes = dd.FallbackKeyTypes } - if dd.OTKCounts != nil { + if dd.OTKCounts != nil && (dd.OTKCountChanged() || isInitial) { res.OTKCounts = dd.OTKCounts } changed, left := internal.DeviceListChangesArrays(dd.DeviceLists.Sent) diff --git a/sync3/extensions/extensions.go b/sync3/extensions/extensions.go index 4033ff2..fc209c5 100644 --- a/sync3/extensions/extensions.go +++ b/sync3/extensions/extensions.go @@ -78,6 +78,14 @@ func (h *Handler) HandleLiveUpdate(update caches.Update, req Request, res *Respo if req.Receipts != nil && req.Receipts.Enabled { res.Receipts = ProcessLiveReceipts(update, updateWillReturnResponse, req.UserID, req.Receipts) } + // only process 'live' e2ee when we aren't going to return data as we need to ensure that we don't calculate this twice + // e.g once on incoming request then again due to wakeup + if req.E2EE != nil && req.E2EE.Enabled { + if res.E2EE != nil && res.E2EE.HasData(false) { + return + } + res.E2EE = ProcessLiveE2EE(update, h.E2EEFetcher, req.UserID, req.DeviceID, req.E2EE) + } } func (h *Handler) Handle(req Request, roomIDToTimeline map[string][]string, isInitial bool) (res Response) { diff --git a/sync3/handler/handler.go b/sync3/handler/handler.go index b4ecd57..9bdc2ea 100644 --- a/sync3/handler/handler.go +++ b/sync3/handler/handler.go @@ -522,10 +522,15 @@ func (h 
*SyncLiveHandler) OnUnreadCounts(p *pubsub.V2UnreadCounts) { userCache.(*caches.UserCache).OnUnreadCounts(p.RoomID, p.HighlightCount, p.NotificationCount) } -// TODO: We don't eagerly push device data updates on waiting conns (otk counts, device list changes) -// Do we need to? +// push device data updates on waiting conns (otk counts, device list changes) func (h *SyncLiveHandler) OnDeviceData(p *pubsub.V2DeviceData) { - // Do nothing for now + conn := h.ConnMap.Conn(sync3.ConnID{ + DeviceID: p.DeviceID, + }) + if conn == nil { + return + } + conn.OnUpdate(caches.DeviceDataUpdate{}) } func (h *SyncLiveHandler) OnInvite(p *pubsub.V2InviteRoom) { diff --git a/tests-integration/extensions_test.go b/tests-integration/extensions_test.go index 5b5cab4..be2da25 100644 --- a/tests-integration/extensions_test.go +++ b/tests-integration/extensions_test.go @@ -2,7 +2,6 @@ package syncv3 import ( "encoding/json" - "fmt" "testing" "time" @@ -50,7 +49,16 @@ func TestExtensionE2EE(t *testing.T) { }) m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts), m.MatchFallbackKeyTypes(fallbackKeyTypes)) - // check that OTK counts / fallback key types remain constant when they aren't included in the v2 response. + // Dummy request as we will see the same otk/fallback keys twice initially + res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{ + Lists: []sync3.RequestList{{ + Ranges: sync3.SliceRanges{ + [2]int64{0, 10}, // doesn't matter + }, + }}, + }) + + // check that OTK counts / fallback key types aren't present afterwards as they haven't changed. 
// Do this by feeding in a new joined room v2.queueResponse(alice, sync2.SyncResponse{ Rooms: sync2.SyncRoomsResponse{ @@ -61,6 +69,7 @@ func TestExtensionE2EE(t *testing.T) { }), }, }) + v2.waitUntilEmpty(t, alice) res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{ Lists: []sync3.RequestList{{ Ranges: sync3.SliceRanges{ @@ -69,10 +78,9 @@ func TestExtensionE2EE(t *testing.T) { }}, // skip enabled: true as it should be sticky }) - m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts), m.MatchFallbackKeyTypes(fallbackKeyTypes)) + m.MatchResponse(t, res, m.MatchNoE2EEExtension()) // No E2EE changes = no extension // check that OTK counts update when they are included in the v2 response - // check fallback key types persist when not included otkCounts = map[string]int{ "curve25519": 99, "signed_curve25519": 999, @@ -87,14 +95,8 @@ func TestExtensionE2EE(t *testing.T) { [2]int64{0, 10}, // doesn't matter }, }}, - // enable the E2EE extension - Extensions: extensions.Request{ - E2EE: &extensions.E2EERequest{ - Enabled: true, - }, - }, }) - m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts), m.MatchFallbackKeyTypes(fallbackKeyTypes)) + m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts), m.MatchFallbackKeyTypes(nil)) // check that changed|left get passed to v3 wantChanged := []string{"bob"} @@ -152,6 +154,7 @@ func TestExtensionE2EE(t *testing.T) { }), }, }) + v2.waitUntilEmpty(t, alice) res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{ Lists: []sync3.RequestList{{ Ranges: sync3.SliceRanges{ @@ -165,12 +168,31 @@ func TestExtensionE2EE(t *testing.T) { }, }, }) - m.MatchResponse(t, res, func(res *sync3.Response) error { - if res.Extensions.E2EE.DeviceLists != nil { - return fmt.Errorf("e2ee device lists present when it shouldn't be") - } - return nil + m.MatchResponse(t, res, m.MatchNoE2EEExtension()) + + // Check that OTK counts are immediately sent to the client + otkCounts = map[string]int{ + "curve25519": 42, + 
"signed_curve25519": 420, + } + v2.queueResponse(alice, sync2.SyncResponse{ + DeviceListsOTKCount: otkCounts, }) + v2.waitUntilEmpty(t, alice) + req := sync3.Request{ + Lists: []sync3.RequestList{{ + Ranges: sync3.SliceRanges{ + [2]int64{0, 10}, // doesn't matter + }, + }}, + } + req.SetTimeoutMSecs(500) + start := time.Now() + res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, req) + m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts)) + if time.Since(start) >= (500 * time.Millisecond) { + t.Fatalf("sync request did not return immediately with OTK counts") + } } // Checks that to-device messages are passed from v2 to v3 diff --git a/testutils/m/match.go b/testutils/m/match.go index 78d8f15..3f562e0 100644 --- a/testutils/m/match.go +++ b/testutils/m/match.go @@ -227,6 +227,15 @@ func MatchRoomSubscriptions(wantSubs map[string][]RoomMatcher) RespMatcher { } } +func MatchNoE2EEExtension() RespMatcher { + return func(res *sync3.Response) error { + if res.Extensions.E2EE != nil { + return fmt.Errorf("MatchNoE2EEExtension: got E2EE extension: %+v", res.Extensions.E2EE) + } + return nil + } +} + func MatchOTKCounts(otkCounts map[string]int) RespMatcher { return func(res *sync3.Response) error { if res.Extensions.E2EE == nil {