improvement: completely refactor device data updates

- `Conn`s now expose a direct `OnUpdate(caches.Update)` function
  for updates which concern a specific device ID.
- Add a bitset in `DeviceData` to indicate if the OTK or fallback keys were changed.
- Pass through the affected `DeviceID` in `pubsub.V2DeviceData` updates.
- Remove `DeviceDataTable.SelectFrom` as it was unused.
- Refactor how the poller invokes `OnE2EEData`: it now only does this if
  there are changes to OTK counts and/or fallback key types and/or device lists,
  and _only_ sends those fields, setting the rest to the zero value.
- Remove noisy logging.
- Add `caches.DeviceDataUpdate` which has no data but serves to wake-up the long poller.
- Only send OTK counts / fallback key types when they have changed, not constantly. This
  matches the behaviour described in MSC3884

The entire flow now looks like:
- Poller notices a diff against in-memory version of otk count and invokes `OnE2EEData`
- Handler updates device data table, bumps the changed bit for otk count.
- Other handler gets the pubsub update, directly finds the `Conn` based on the `DeviceID`.
  Invokes `OnUpdate(caches.DeviceDataUpdate)`
- This update is handled by the E2EE extension which then pulls the data out from the database
  and returns it.
- On initial connections, all OTK / fallback data is returned.
This commit is contained in:
Kegan Dougal 2022-12-22 15:08:42 +00:00
parent f80dc00eaf
commit 6c4f7d3722
19 changed files with 297 additions and 182 deletions

View File

@ -4,6 +4,20 @@ import (
"sync"
)
const (
bitOTKCount int = iota
bitFallbackKeyTypes
)
func setBit(n int, bit int) int {
n |= (1 << bit)
return n
}
func isBitSet(n int, bit int) bool {
val := n & (1 << bit)
return val > 0
}
// DeviceData contains useful data for this user's device. This list can be expanded without prompting
// schema changes. These values are upserted into the database and persisted forever.
type DeviceData struct {
@ -16,10 +30,29 @@ type DeviceData struct {
DeviceLists DeviceLists `json:"dl"`
// bitset for which device data changes are present. They accumulate until they get swapped over
// when they get reset
ChangedBits int `json:"c"`
UserID string
DeviceID string
}
func (dd *DeviceData) SetOTKCountChanged() {
dd.ChangedBits = setBit(dd.ChangedBits, bitOTKCount)
}
func (dd *DeviceData) SetFallbackKeysChanged() {
dd.ChangedBits = setBit(dd.ChangedBits, bitFallbackKeyTypes)
}
func (dd *DeviceData) OTKCountChanged() bool {
return isBitSet(dd.ChangedBits, bitOTKCount)
}
func (dd *DeviceData) FallbackKeysChanged() bool {
return isBitSet(dd.ChangedBits, bitFallbackKeyTypes)
}
type UserDeviceKey struct {
UserID string
DeviceID string

View File

@ -0,0 +1,57 @@
package internal
import "testing"
func TestDeviceDataBitset(t *testing.T) {
testCases := []struct {
get func() DeviceData
fallbackSet bool
otkSet bool
}{
{
get: func() DeviceData {
var dd DeviceData
dd.SetFallbackKeysChanged()
return dd
},
fallbackSet: true,
otkSet: false,
},
{
get: func() DeviceData {
var dd DeviceData
dd.SetOTKCountChanged()
return dd
},
fallbackSet: false,
otkSet: true,
},
{
get: func() DeviceData {
var dd DeviceData
dd.SetFallbackKeysChanged()
dd.SetOTKCountChanged()
return dd
},
fallbackSet: true,
otkSet: true,
},
{
get: func() DeviceData {
var dd DeviceData
return dd
},
fallbackSet: false,
otkSet: false,
},
}
for _, tc := range testCases {
dd := tc.get()
if dd.FallbackKeysChanged() != tc.fallbackSet {
t.Errorf("%v : wrong fallback value, want %v", dd, tc.fallbackSet)
}
if dd.OTKCountChanged() != tc.otkSet {
t.Errorf("%v : wrong OTK value, want %v", dd, tc.otkSet)
}
}
}

View File

@ -76,7 +76,8 @@ type V2InitialSyncComplete struct {
func (*V2InitialSyncComplete) Type() string { return "V2InitialSyncComplete" }
type V2DeviceData struct {
Pos int64
DeviceID string
Pos int64
}
func (*V2DeviceData) Type() string { return "V2DeviceData" }

View File

@ -4,20 +4,13 @@ import (
"database/sql"
"encoding/json"
"fmt"
"os"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
"github.com/matrix-org/sliding-sync/sqlutil"
"github.com/rs/zerolog"
"github.com/tidwall/gjson"
)
var log = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.ConsoleWriter{
Out: os.Stderr,
TimeFormat: "15:04:05",
})
// Accumulator tracks room state and timelines.
//
// In order for it to remain simple(ish), the accumulator DOES NOT SUPPORT arbitrary timeline gaps.
@ -77,7 +70,7 @@ func (a *Accumulator) calculateNewSnapshot(old StrippedEvents, new Event) (Strip
// ruh roh. This should be impossible, but it can happen if the v2 response sends the same
// event in both state and timeline. We need to alert the operator and whine badly as it means
// we have lost an event by now.
log.Warn().Str("new_event_id", new.ID).Str("old_event_id", e.ID).Str("room_id", new.RoomID).Str("type", new.Type).Str("state_key", new.StateKey).Msg(
logger.Warn().Str("new_event_id", new.ID).Str("old_event_id", e.ID).Str("room_id", new.RoomID).Str("type", new.Type).Str("state_key", new.StateKey).Msg(
"Detected different event IDs with the same NID when rolling forward state. This has resulted in data loss in this room (1 event). " +
"This can happen when the v2 /sync response sends the same event in both state and timeline sections. " +
"The event in this log line has been dropped!",
@ -160,7 +153,7 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (bool,
}
if snapshotID > 0 {
// we only initialise rooms once
log.Info().Str("room_id", roomID).Int64("snapshot_id", snapshotID).Msg("Accumulator.Initialise called but current snapshot already exists, bailing early")
logger.Info().Str("room_id", roomID).Int64("snapshot_id", snapshotID).Msg("Accumulator.Initialise called but current snapshot already exists, bailing early")
return nil
}
@ -183,7 +176,7 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (bool,
if len(eventIDToNID) == 0 {
// we don't have a current snapshot for this room but yet no events are new,
// no idea how this should be handled.
log.Error().Str("room_id", roomID).Msg(
logger.Error().Str("room_id", roomID).Msg(
"Accumulator.Initialise: room has no current snapshot but also no new inserted events, doing nothing. This is probably a bug.",
)
return nil
@ -270,7 +263,7 @@ func (a *Accumulator) Accumulate(roomID string, prevBatch string, timeline []jso
return fmt.Errorf("event malformed: %s", err)
}
if _, ok := seenEvents[e.ID]; ok {
log.Warn().Str("event_id", e.ID).Str("room_id", roomID).Msg(
logger.Warn().Str("event_id", e.ID).Str("room_id", roomID).Msg(
"Accumulator.Accumulate: seen the same event ID twice, ignoring",
)
continue

View File

@ -66,6 +66,9 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (dd *intern
n := tempDD.DeviceLists.New
tempDD.DeviceLists.Sent = n
tempDD.DeviceLists.New = make(map[string]int)
// reset changed bits
changedBits := tempDD.ChangedBits
tempDD.ChangedBits = 0
// re-marshal and write
data, err := json.Marshal(tempDD)
@ -74,32 +77,12 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (dd *intern
}
_, err = t.db.Exec(`UPDATE syncv3_device_data SET data=$1 WHERE user_id=$2 AND device_id=$3`, data, userID, deviceID)
dd = &tempDD
dd.ChangedBits = changedBits
return err
})
return
}
func (t *DeviceDataTable) SelectFrom(pos int64) (results []internal.DeviceData, nextPos int64, err error) {
nextPos = pos
var rows []DeviceDataRow
err = t.db.Select(&rows, `SELECT id, user_id, device_id, data FROM syncv3_device_data WHERE id > $1 ORDER BY id ASC`, pos)
if err != nil {
return
}
results = make([]internal.DeviceData, len(rows))
for i := range rows {
var dd internal.DeviceData
if err = json.Unmarshal(rows[i].Data, &dd); err != nil {
return
}
dd.UserID = rows[i].UserID
dd.DeviceID = rows[i].DeviceID
results[i] = dd
nextPos = rows[i].ID
}
return
}
// Upsert combines what is in the database for this user|device with the partial entry `dd`
func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (pos int64, err error) {
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
@ -118,9 +101,11 @@ func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (pos int64, err error)
}
if dd.FallbackKeyTypes != nil {
tempDD.FallbackKeyTypes = dd.FallbackKeyTypes
tempDD.SetFallbackKeysChanged()
}
if dd.OTKCounts != nil {
tempDD.OTKCounts = dd.OTKCounts
tempDD.SetOTKCountChanged()
}
tempDD.DeviceLists = tempDD.DeviceLists.Combine(dd.DeviceLists)

View File

@ -9,92 +9,19 @@ import (
)
func assertVal(t *testing.T, msg string, got, want interface{}) {
t.Helper()
if !reflect.DeepEqual(got, want) {
t.Errorf("%s: got %v want %v", msg, got, want)
}
}
func assertDeviceDatas(t *testing.T, got, want []internal.DeviceData) {
func assertDeviceData(t *testing.T, g, w internal.DeviceData) {
t.Helper()
if len(got) != len(want) {
t.Fatalf("got %d devices, want %d : %+v", len(got), len(want), got)
}
for i := range want {
g := got[i]
w := want[i]
assertVal(t, "device id", g.DeviceID, w.DeviceID)
assertVal(t, "user id", g.UserID, w.UserID)
assertVal(t, "FallbackKeyTypes", g.FallbackKeyTypes, w.FallbackKeyTypes)
assertVal(t, "OTKCounts", g.OTKCounts, w.OTKCounts)
}
}
func TestDeviceDataTable(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString)
if err != nil {
t.Fatalf("failed to open SQL db: %s", err)
}
table := NewDeviceDataTable(db)
userID := "@alice"
deviceID := "ALICE"
dd := &internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
OTKCounts: map[string]int{
"foo": 100,
},
FallbackKeyTypes: []string{"foo", "bar"},
}
// test basic insert -> select
pos, err := table.Upsert(dd)
assertNoError(t, err)
results, nextPos, err := table.SelectFrom(-1)
assertNoError(t, err)
if pos != nextPos {
t.Fatalf("Upsert returned pos %v but SelectFrom returned pos %v", pos, nextPos)
}
assertDeviceDatas(t, results, []internal.DeviceData{*dd})
// at latest -> no results
results, nextPos, err = table.SelectFrom(nextPos)
assertNoError(t, err)
if pos != nextPos {
t.Fatalf("Upsert returned pos %v but SelectFrom returned pos %v", pos, nextPos)
}
assertDeviceDatas(t, results, nil)
// multiple insert -> replace on user|device
dd2 := *dd
dd2.OTKCounts = map[string]int{"foo": 99}
_, err = table.Upsert(&dd2)
assertNoError(t, err)
dd3 := *dd
dd3.OTKCounts = map[string]int{"foo": 98}
pos, err = table.Upsert(&dd3)
assertNoError(t, err)
results, nextPos, err = table.SelectFrom(nextPos)
assertNoError(t, err)
if pos != nextPos {
t.Fatalf("Upsert returned pos %v but SelectFrom returned pos %v", pos, nextPos)
}
assertDeviceDatas(t, results, []internal.DeviceData{dd3})
// multiple insert -> different user, same device + same user, different device
dd4 := *dd
dd4.UserID = "@bob"
_, err = table.Upsert(&dd4)
assertNoError(t, err)
dd5 := *dd
dd5.DeviceID = "ANOTHER"
pos, err = table.Upsert(&dd5)
assertNoError(t, err)
results, nextPos, err = table.SelectFrom(nextPos)
assertNoError(t, err)
if pos != nextPos {
t.Fatalf("Upsert returned pos %v but SelectFrom returned pos %v", pos, nextPos)
}
assertDeviceDatas(t, results, []internal.DeviceData{dd4, dd5})
assertVal(t, "device id", g.DeviceID, w.DeviceID)
assertVal(t, "user id", g.UserID, w.UserID)
assertVal(t, "FallbackKeyTypes", g.FallbackKeyTypes, w.FallbackKeyTypes)
assertVal(t, "OTKCounts", g.OTKCounts, w.OTKCounts)
assertVal(t, "ChangedBits", g.ChangedBits, w.ChangedBits)
}
func TestDeviceDataTableSwaps(t *testing.T) {
@ -155,11 +82,13 @@ func TestDeviceDataTableSwaps(t *testing.T) {
New: internal.ToDeviceListChangesMap([]string{"alice", "bob"}, nil),
},
}
want.SetFallbackKeysChanged()
want.SetOTKCountChanged()
// check we can read-only select
for i := 0; i < 3; i++ {
got, err := table.Select(userID, deviceID, false)
assertNoError(t, err)
assertDeviceDatas(t, []internal.DeviceData{*got}, []internal.DeviceData{want})
assertDeviceData(t, *got, want)
}
// now swap-er-roo
got, err := table.Select(userID, deviceID, true)
@ -169,12 +98,16 @@ func TestDeviceDataTableSwaps(t *testing.T) {
Sent: internal.ToDeviceListChangesMap([]string{"alice"}, nil),
New: nil,
}
assertDeviceDatas(t, []internal.DeviceData{*got}, []internal.DeviceData{want2})
assertDeviceData(t, *got, want2)
// changed bits were reset when we swapped
want2.ChangedBits = 0
want.ChangedBits = 0
// this is permanent, read-only views show this too
got, err = table.Select(userID, deviceID, false)
assertNoError(t, err)
assertDeviceDatas(t, []internal.DeviceData{*got}, []internal.DeviceData{want2})
assertDeviceData(t, *got, want2)
// another swap causes sent to be cleared out
got, err = table.Select(userID, deviceID, true)
@ -184,16 +117,18 @@ func TestDeviceDataTableSwaps(t *testing.T) {
Sent: nil,
New: nil,
}
assertDeviceDatas(t, []internal.DeviceData{*got}, []internal.DeviceData{want3})
assertDeviceData(t, *got, want3)
// get back the original state
for _, dd := range deltas {
_, err = table.Upsert(&dd)
assertNoError(t, err)
}
want.SetFallbackKeysChanged()
want.SetOTKCountChanged()
got, err = table.Select(userID, deviceID, false)
assertNoError(t, err)
assertDeviceDatas(t, []internal.DeviceData{*got}, []internal.DeviceData{want})
assertDeviceData(t, *got, want)
// swap once then add once so both sent and new are populated
_, err = table.Select(userID, deviceID, true)
@ -207,6 +142,8 @@ func TestDeviceDataTableSwaps(t *testing.T) {
})
assertNoError(t, err)
want.ChangedBits = 0
want4 := want
want4.DeviceLists = internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"alice"}, nil),
@ -214,7 +151,7 @@ func TestDeviceDataTableSwaps(t *testing.T) {
}
got, err = table.Select(userID, deviceID, false)
assertNoError(t, err)
assertDeviceDatas(t, []internal.DeviceData{*got}, []internal.DeviceData{want4})
assertDeviceData(t, *got, want4)
// another append then consume
_, err = table.Upsert(&internal.DeviceData{
@ -232,6 +169,68 @@ func TestDeviceDataTableSwaps(t *testing.T) {
Sent: internal.ToDeviceListChangesMap([]string{"bob", "dave"}, []string{"charlie", "dave"}),
New: nil,
}
assertDeviceDatas(t, []internal.DeviceData{*got}, []internal.DeviceData{want5})
assertDeviceData(t, *got, want5)
}
func TestDeviceDataTableBitset(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString)
if err != nil {
t.Fatalf("failed to open SQL db: %s", err)
}
table := NewDeviceDataTable(db)
userID := "@bobTestDeviceDataTableBitset"
deviceID := "BOBTestDeviceDataTableBitset"
otkUpdate := internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
OTKCounts: map[string]int{
"foo": 100,
"bar": 92,
},
}
fallbakKeyUpdate := internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
FallbackKeyTypes: []string{"foo", "bar"},
}
bothUpdate := internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
FallbackKeyTypes: []string{"both"},
OTKCounts: map[string]int{
"both": 100,
},
}
_, err = table.Upsert(&otkUpdate)
assertNoError(t, err)
got, err := table.Select(userID, deviceID, true)
assertNoError(t, err)
otkUpdate.SetOTKCountChanged()
assertDeviceData(t, *got, otkUpdate)
// second time swapping causes no OTKs as there have been no changes
got, err = table.Select(userID, deviceID, true)
assertNoError(t, err)
otkUpdate.ChangedBits = 0
assertDeviceData(t, *got, otkUpdate)
// now same for fallback keys, but we won't swap them so it should return those diffs
_, err = table.Upsert(&fallbakKeyUpdate)
assertNoError(t, err)
fallbakKeyUpdate.OTKCounts = otkUpdate.OTKCounts
got, err = table.Select(userID, deviceID, false)
assertNoError(t, err)
fallbakKeyUpdate.SetFallbackKeysChanged()
assertDeviceData(t, *got, fallbakKeyUpdate)
got, err = table.Select(userID, deviceID, false)
assertNoError(t, err)
fallbakKeyUpdate.SetFallbackKeysChanged()
assertDeviceData(t, *got, fallbakKeyUpdate)
// updating both works
_, err = table.Upsert(&bothUpdate)
assertNoError(t, err)
got, err = table.Select(userID, deviceID, true)
assertNoError(t, err)
bothUpdate.SetFallbackKeysChanged()
bothUpdate.SetOTKCountChanged()
assertDeviceData(t, *got, bothUpdate)
}

View File

@ -47,7 +47,7 @@ type Storage struct {
func NewStorage(postgresURI string) *Storage {
db, err := sqlx.Open("postgres", postgresURI)
if err != nil {
log.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
logger.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
}
acc := &Accumulator{
db: db,

View File

@ -92,7 +92,7 @@ func (t *ToDeviceTable) Messages(deviceID string, from, limit int64) (msgs []jso
m := gjson.ParseBytes(msgs[i])
msgId := m.Get(`content.org\.matrix\.msgid`).Str
if msgId != "" {
log.Debug().Str("msgid", msgId).Str("device", deviceID).Msg("ToDeviceTable.Messages")
logger.Info().Str("msgid", msgId).Str("device", deviceID).Msg("ToDeviceTable.Messages")
}
}
upTo = rows[len(rows)-1].Position
@ -125,7 +125,7 @@ func (t *ToDeviceTable) InsertMessages(deviceID string, msgs []json.RawMessage)
}
msgId := m.Get(`content.org\.matrix\.msgid`).Str
if msgId != "" {
log.Debug().Str("msgid", msgId).Str("device", deviceID).Msg("ToDeviceTable.InsertMessages")
logger.Debug().Str("msgid", msgId).Str("device", deviceID).Msg("ToDeviceTable.InsertMessages")
}
switch rows[i].Type {
case "m.room_key_request":

View File

@ -178,7 +178,8 @@ func (h *Handler) OnE2EEData(userID, deviceID string, otkCounts map[string]int,
return
}
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2DeviceData{
Pos: nextPos,
DeviceID: deviceID,
Pos: nextPos,
})
}

View File

@ -459,31 +459,29 @@ func (p *poller) parseToDeviceMessages(res *SyncResponse) {
}
func (p *poller) parseE2EEData(res *SyncResponse) {
hasE2EEChanges := false
var changedOTKCounts map[string]int
if res.DeviceListsOTKCount != nil && len(res.DeviceListsOTKCount) > 0 {
if len(p.otkCounts) != len(res.DeviceListsOTKCount) {
hasE2EEChanges = true
}
if !hasE2EEChanges && p.otkCounts != nil {
changedOTKCounts = res.DeviceListsOTKCount
} else if p.otkCounts != nil {
for k := range res.DeviceListsOTKCount {
if res.DeviceListsOTKCount[k] != p.otkCounts[k] {
hasE2EEChanges = true
changedOTKCounts = res.DeviceListsOTKCount
break
}
}
}
p.otkCounts = res.DeviceListsOTKCount
}
var changedFallbackTypes []string
if len(res.DeviceUnusedFallbackKeyTypes) > 0 {
if !hasE2EEChanges {
if len(p.fallbackKeyTypes) != len(res.DeviceUnusedFallbackKeyTypes) {
hasE2EEChanges = true
} else {
for i := range res.DeviceUnusedFallbackKeyTypes {
if res.DeviceUnusedFallbackKeyTypes[i] != p.fallbackKeyTypes[i] {
hasE2EEChanges = true
break
}
if len(p.fallbackKeyTypes) != len(res.DeviceUnusedFallbackKeyTypes) {
changedFallbackTypes = res.DeviceUnusedFallbackKeyTypes
} else {
for i := range res.DeviceUnusedFallbackKeyTypes {
if res.DeviceUnusedFallbackKeyTypes[i] != p.fallbackKeyTypes[i] {
changedFallbackTypes = res.DeviceUnusedFallbackKeyTypes
break
}
}
}
@ -491,12 +489,9 @@ func (p *poller) parseE2EEData(res *SyncResponse) {
}
deviceListChanges := internal.ToDeviceListChangesMap(res.DeviceLists.Changed, res.DeviceLists.Left)
if deviceListChanges != nil {
hasE2EEChanges = true
}
if hasE2EEChanges {
p.receiver.OnE2EEData(p.userID, p.deviceID, p.otkCounts, p.fallbackKeyTypes, deviceListChanges)
if deviceListChanges != nil || changedFallbackTypes != nil || changedOTKCounts != nil {
p.receiver.OnE2EEData(p.userID, p.deviceID, changedOTKCounts, changedFallbackTypes, deviceListChanges)
}
}

View File

@ -7,7 +7,6 @@ import (
"sort"
"sync"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/sliding-sync/internal"
"github.com/matrix-org/sliding-sync/state"
"github.com/rs/zerolog"
@ -187,15 +186,6 @@ func (c *GlobalCache) Startup(roomIDToMetadata map[string]internal.RoomMetadata)
sort.Strings(roomIDs)
for _, roomID := range roomIDs {
metadata := roomIDToMetadata[roomID]
var upgradedRoomID string
if metadata.UpgradedRoomID != nil {
upgradedRoomID = *metadata.UpgradedRoomID
}
logger.Debug().Str("room", roomID).Interface(
"recent", gomatrixserverlib.Timestamp(metadata.LastMessageTimestamp).Time(),
).Bool("encrypted", metadata.Encrypted).Str("tombstone", upgradedRoomID).Int("joins", metadata.JoinCount).Msg(
"",
)
internal.Assert("room ID is set", metadata.RoomID != "")
internal.Assert("last message timestamp exists", metadata.LastMessageTimestamp > 1)
c.roomIDToMetadata[roomID] = &metadata

View File

@ -54,6 +54,7 @@ type RoomAccountDataUpdate struct {
AccountData []state.AccountData
}
// Alerts result in changes to ops, subs or ext modifications
// Alerts can update internal conn state
// Dispatcher thread ultimately fires alerts OR poller thread e.g OnUnreadCounts
type DeviceDataUpdate struct {
// no data; just wakes up the connection
// data comes via sidechannels e.g the database
}

View File

@ -7,6 +7,7 @@ import (
"sync"
"github.com/matrix-org/sliding-sync/internal"
"github.com/matrix-org/sliding-sync/sync3/caches"
)
type ConnID struct {
@ -22,6 +23,7 @@ type ConnHandler interface {
// to send back or an error. Errors of type *internal.HandlerError are inspected for the correct
// status code to send back.
OnIncomingRequest(ctx context.Context, cid ConnID, req *Request, isInitial bool) (*Response, error)
OnUpdate(update caches.Update)
UserID() string
Destroy()
Alive() bool
@ -69,6 +71,10 @@ func (c *Conn) Alive() bool {
return c.handler.Alive()
}
func (c *Conn) OnUpdate(update caches.Update) {
c.handler.OnUpdate(update)
}
func (c *Conn) tryRequest(ctx context.Context, req *Request) (res *Response, err error) {
defer func() {
panicErr := recover()

View File

@ -9,6 +9,7 @@ import (
"time"
"github.com/matrix-org/sliding-sync/internal"
"github.com/matrix-org/sliding-sync/sync3/caches"
)
type connHandlerMock struct {
@ -21,8 +22,9 @@ func (c *connHandlerMock) OnIncomingRequest(ctx context.Context, cid ConnID, req
func (c *connHandlerMock) UserID() string {
return "dummy"
}
func (c *connHandlerMock) Destroy() {}
func (c *connHandlerMock) Alive() bool { return true }
func (c *connHandlerMock) Destroy() {}
func (c *connHandlerMock) Alive() bool { return true }
func (c *connHandlerMock) OnUpdate(update caches.Update) {}
// Test that Conn can send and receive requests based on positions
func TestConn(t *testing.T) {

View File

@ -3,6 +3,7 @@ package extensions
import (
"github.com/matrix-org/sliding-sync/internal"
"github.com/matrix-org/sliding-sync/sync2"
"github.com/matrix-org/sliding-sync/sync3/caches"
)
// Client created request params
@ -31,8 +32,15 @@ func (r *E2EEResponse) HasData(isInitial bool) bool {
if isInitial {
return true // ensure we send OTK counts immediately
}
// OTK counts aren't enough to make /sync return early as we send them liberally, not just on change
return r.DeviceLists != nil
return r.DeviceLists != nil || len(r.FallbackKeyTypes) > 0 || len(r.OTKCounts) > 0
}
func ProcessLiveE2EE(up caches.Update, fetcher sync2.E2EEFetcher, userID, deviceID string, req *E2EERequest) (res *E2EEResponse) {
_, ok := up.(caches.DeviceDataUpdate)
if !ok {
return nil
}
return ProcessE2EE(fetcher, userID, deviceID, req, false)
}
func ProcessE2EE(fetcher sync2.E2EEFetcher, userID, deviceID string, req *E2EERequest, isInitial bool) (res *E2EEResponse) {
@ -42,10 +50,10 @@ func ProcessE2EE(fetcher sync2.E2EEFetcher, userID, deviceID string, req *E2EERe
if dd == nil {
return res // unknown device?
}
if dd.FallbackKeyTypes != nil {
if dd.FallbackKeyTypes != nil && (dd.FallbackKeysChanged() || isInitial) {
res.FallbackKeyTypes = dd.FallbackKeyTypes
}
if dd.OTKCounts != nil {
if dd.OTKCounts != nil && (dd.OTKCountChanged() || isInitial) {
res.OTKCounts = dd.OTKCounts
}
changed, left := internal.DeviceListChangesArrays(dd.DeviceLists.Sent)

View File

@ -78,6 +78,14 @@ func (h *Handler) HandleLiveUpdate(update caches.Update, req Request, res *Respo
if req.Receipts != nil && req.Receipts.Enabled {
res.Receipts = ProcessLiveReceipts(update, updateWillReturnResponse, req.UserID, req.Receipts)
}
// only process 'live' e2ee when we aren't going to return data as we need to ensure that we don't calculate this twice
// e.g once on incoming request then again due to wakeup
if req.E2EE != nil && req.E2EE.Enabled {
if res.E2EE != nil && res.E2EE.HasData(false) {
return
}
res.E2EE = ProcessLiveE2EE(update, h.E2EEFetcher, req.UserID, req.DeviceID, req.E2EE)
}
}
func (h *Handler) Handle(req Request, roomIDToTimeline map[string][]string, isInitial bool) (res Response) {

View File

@ -522,10 +522,15 @@ func (h *SyncLiveHandler) OnUnreadCounts(p *pubsub.V2UnreadCounts) {
userCache.(*caches.UserCache).OnUnreadCounts(p.RoomID, p.HighlightCount, p.NotificationCount)
}
// TODO: We don't eagerly push device data updates on waiting conns (otk counts, device list changes)
// Do we need to?
// push device data updates on waiting conns (otk counts, device list changes)
func (h *SyncLiveHandler) OnDeviceData(p *pubsub.V2DeviceData) {
// Do nothing for now
conn := h.ConnMap.Conn(sync3.ConnID{
DeviceID: p.DeviceID,
})
if conn == nil {
return
}
conn.OnUpdate(caches.DeviceDataUpdate{})
}
func (h *SyncLiveHandler) OnInvite(p *pubsub.V2InviteRoom) {

View File

@ -2,7 +2,6 @@ package syncv3
import (
"encoding/json"
"fmt"
"testing"
"time"
@ -50,7 +49,16 @@ func TestExtensionE2EE(t *testing.T) {
})
m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts), m.MatchFallbackKeyTypes(fallbackKeyTypes))
// check that OTK counts / fallback key types remain constant when they aren't included in the v2 response.
// Dummy request as we will see the same otk/fallback keys twice initially
res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{
Lists: []sync3.RequestList{{
Ranges: sync3.SliceRanges{
[2]int64{0, 10}, // doesn't matter
},
}},
})
// check that OTK counts / fallback key types aren't present afterwards as they haven't changed.
// Do this by feeding in a new joined room
v2.queueResponse(alice, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
@ -61,6 +69,7 @@ func TestExtensionE2EE(t *testing.T) {
}),
},
})
v2.waitUntilEmpty(t, alice)
res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{
Lists: []sync3.RequestList{{
Ranges: sync3.SliceRanges{
@ -69,10 +78,9 @@ func TestExtensionE2EE(t *testing.T) {
}},
// skip enabled: true as it should be sticky
})
m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts), m.MatchFallbackKeyTypes(fallbackKeyTypes))
m.MatchResponse(t, res, m.MatchNoE2EEExtension()) // No E2EE changes = no extension
// check that OTK counts update when they are included in the v2 response
// check fallback key types persist when not included
otkCounts = map[string]int{
"curve25519": 99,
"signed_curve25519": 999,
@ -87,14 +95,8 @@ func TestExtensionE2EE(t *testing.T) {
[2]int64{0, 10}, // doesn't matter
},
}},
// enable the E2EE extension
Extensions: extensions.Request{
E2EE: &extensions.E2EERequest{
Enabled: true,
},
},
})
m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts), m.MatchFallbackKeyTypes(fallbackKeyTypes))
m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts), m.MatchFallbackKeyTypes(nil))
// check that changed|left get passed to v3
wantChanged := []string{"bob"}
@ -152,6 +154,7 @@ func TestExtensionE2EE(t *testing.T) {
}),
},
})
v2.waitUntilEmpty(t, alice)
res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{
Lists: []sync3.RequestList{{
Ranges: sync3.SliceRanges{
@ -165,12 +168,31 @@ func TestExtensionE2EE(t *testing.T) {
},
},
})
m.MatchResponse(t, res, func(res *sync3.Response) error {
if res.Extensions.E2EE.DeviceLists != nil {
return fmt.Errorf("e2ee device lists present when it shouldn't be")
}
return nil
m.MatchResponse(t, res, m.MatchNoE2EEExtension())
// Check that OTK counts are immediately sent to the client
otkCounts = map[string]int{
"curve25519": 42,
"signed_curve25519": 420,
}
v2.queueResponse(alice, sync2.SyncResponse{
DeviceListsOTKCount: otkCounts,
})
v2.waitUntilEmpty(t, alice)
req := sync3.Request{
Lists: []sync3.RequestList{{
Ranges: sync3.SliceRanges{
[2]int64{0, 10}, // doesn't matter
},
}},
}
req.SetTimeoutMSecs(500)
start := time.Now()
res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, req)
m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts))
if time.Since(start) >= (500 * time.Millisecond) {
t.Fatalf("sync request did not return immediately with OTK counts")
}
}
// Checks that to-device messages are passed from v2 to v3

View File

@ -227,6 +227,15 @@ func MatchRoomSubscriptions(wantSubs map[string][]RoomMatcher) RespMatcher {
}
}
func MatchNoE2EEExtension() RespMatcher {
return func(res *sync3.Response) error {
if res.Extensions.E2EE != nil {
return fmt.Errorf("MatchNoE2EEExtension: got E2EE extension: %+v", res.Extensions.E2EE)
}
return nil
}
}
func MatchOTKCounts(otkCounts map[string]int) RespMatcher {
return func(res *sync3.Response) error {
if res.Extensions.E2EE == nil {