mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
improvement: completely refactor device data updates
- `Conn`s now expose a direct `OnUpdate(caches.Update)` function for updates which concern a specific device ID. - Add a bitset in `DeviceData` to indicate if the OTK or fallback keys were changed. - Pass through the affected `DeviceID` in `pubsub.V2DeviceData` updates. - Remove `DeviceDataTable.SelectFrom` as it was unused. - Refactor how the poller invokes `OnE2EEData`: it now only does this if there are changes to OTK counts and/or fallback key types and/or device lists, and _only_ sends those fields, setting the rest to the zero value. - Remove noisy logging. - Add `caches.DeviceDataUpdate` which has no data but serves to wake-up the long poller. - Only send OTK counts / fallback key types when they have changed, not constantly. This matches the behaviour described in MSC3884 The entire flow now looks like: - Poller notices a diff against in-memory version of otk count and invokes `OnE2EEData` - Handler updates device data table, bumps the changed bit for otk count. - Other handler gets the pubsub update, directly finds the `Conn` based on the `DeviceID`. Invokes `OnUpdate(caches.DeviceDataUpdate)` - This update is handled by the E2EE extension which then pulls the data out from the database and returns it. - On initial connections, all OTK / fallback data is returned.
This commit is contained in:
parent
f80dc00eaf
commit
6c4f7d3722
@ -4,6 +4,20 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
bitOTKCount int = iota
|
||||
bitFallbackKeyTypes
|
||||
)
|
||||
|
||||
func setBit(n int, bit int) int {
|
||||
n |= (1 << bit)
|
||||
return n
|
||||
}
|
||||
func isBitSet(n int, bit int) bool {
|
||||
val := n & (1 << bit)
|
||||
return val > 0
|
||||
}
|
||||
|
||||
// DeviceData contains useful data for this user's device. This list can be expanded without prompting
|
||||
// schema changes. These values are upserted into the database and persisted forever.
|
||||
type DeviceData struct {
|
||||
@ -16,10 +30,29 @@ type DeviceData struct {
|
||||
|
||||
DeviceLists DeviceLists `json:"dl"`
|
||||
|
||||
// bitset for which device data changes are present. They accumulate until they get swapped over
|
||||
// when they get reset
|
||||
ChangedBits int `json:"c"`
|
||||
|
||||
UserID string
|
||||
DeviceID string
|
||||
}
|
||||
|
||||
func (dd *DeviceData) SetOTKCountChanged() {
|
||||
dd.ChangedBits = setBit(dd.ChangedBits, bitOTKCount)
|
||||
}
|
||||
|
||||
func (dd *DeviceData) SetFallbackKeysChanged() {
|
||||
dd.ChangedBits = setBit(dd.ChangedBits, bitFallbackKeyTypes)
|
||||
}
|
||||
|
||||
func (dd *DeviceData) OTKCountChanged() bool {
|
||||
return isBitSet(dd.ChangedBits, bitOTKCount)
|
||||
}
|
||||
func (dd *DeviceData) FallbackKeysChanged() bool {
|
||||
return isBitSet(dd.ChangedBits, bitFallbackKeyTypes)
|
||||
}
|
||||
|
||||
type UserDeviceKey struct {
|
||||
UserID string
|
||||
DeviceID string
|
||||
|
57
internal/device_data_test.go
Normal file
57
internal/device_data_test.go
Normal file
@ -0,0 +1,57 @@
|
||||
package internal
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDeviceDataBitset(t *testing.T) {
|
||||
testCases := []struct {
|
||||
get func() DeviceData
|
||||
fallbackSet bool
|
||||
otkSet bool
|
||||
}{
|
||||
{
|
||||
get: func() DeviceData {
|
||||
var dd DeviceData
|
||||
dd.SetFallbackKeysChanged()
|
||||
return dd
|
||||
},
|
||||
fallbackSet: true,
|
||||
otkSet: false,
|
||||
},
|
||||
{
|
||||
get: func() DeviceData {
|
||||
var dd DeviceData
|
||||
dd.SetOTKCountChanged()
|
||||
return dd
|
||||
},
|
||||
fallbackSet: false,
|
||||
otkSet: true,
|
||||
},
|
||||
{
|
||||
get: func() DeviceData {
|
||||
var dd DeviceData
|
||||
dd.SetFallbackKeysChanged()
|
||||
dd.SetOTKCountChanged()
|
||||
return dd
|
||||
},
|
||||
fallbackSet: true,
|
||||
otkSet: true,
|
||||
},
|
||||
{
|
||||
get: func() DeviceData {
|
||||
var dd DeviceData
|
||||
return dd
|
||||
},
|
||||
fallbackSet: false,
|
||||
otkSet: false,
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
dd := tc.get()
|
||||
if dd.FallbackKeysChanged() != tc.fallbackSet {
|
||||
t.Errorf("%v : wrong fallback value, want %v", dd, tc.fallbackSet)
|
||||
}
|
||||
if dd.OTKCountChanged() != tc.otkSet {
|
||||
t.Errorf("%v : wrong OTK value, want %v", dd, tc.otkSet)
|
||||
}
|
||||
}
|
||||
}
|
@ -76,7 +76,8 @@ type V2InitialSyncComplete struct {
|
||||
func (*V2InitialSyncComplete) Type() string { return "V2InitialSyncComplete" }
|
||||
|
||||
type V2DeviceData struct {
|
||||
Pos int64
|
||||
DeviceID string
|
||||
Pos int64
|
||||
}
|
||||
|
||||
func (*V2DeviceData) Type() string { return "V2DeviceData" }
|
||||
|
@ -4,20 +4,13 @@ import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/sliding-sync/sqlutil"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
var log = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.ConsoleWriter{
|
||||
Out: os.Stderr,
|
||||
TimeFormat: "15:04:05",
|
||||
})
|
||||
|
||||
// Accumulator tracks room state and timelines.
|
||||
//
|
||||
// In order for it to remain simple(ish), the accumulator DOES NOT SUPPORT arbitrary timeline gaps.
|
||||
@ -77,7 +70,7 @@ func (a *Accumulator) calculateNewSnapshot(old StrippedEvents, new Event) (Strip
|
||||
// ruh roh. This should be impossible, but it can happen if the v2 response sends the same
|
||||
// event in both state and timeline. We need to alert the operator and whine badly as it means
|
||||
// we have lost an event by now.
|
||||
log.Warn().Str("new_event_id", new.ID).Str("old_event_id", e.ID).Str("room_id", new.RoomID).Str("type", new.Type).Str("state_key", new.StateKey).Msg(
|
||||
logger.Warn().Str("new_event_id", new.ID).Str("old_event_id", e.ID).Str("room_id", new.RoomID).Str("type", new.Type).Str("state_key", new.StateKey).Msg(
|
||||
"Detected different event IDs with the same NID when rolling forward state. This has resulted in data loss in this room (1 event). " +
|
||||
"This can happen when the v2 /sync response sends the same event in both state and timeline sections. " +
|
||||
"The event in this log line has been dropped!",
|
||||
@ -160,7 +153,7 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (bool,
|
||||
}
|
||||
if snapshotID > 0 {
|
||||
// we only initialise rooms once
|
||||
log.Info().Str("room_id", roomID).Int64("snapshot_id", snapshotID).Msg("Accumulator.Initialise called but current snapshot already exists, bailing early")
|
||||
logger.Info().Str("room_id", roomID).Int64("snapshot_id", snapshotID).Msg("Accumulator.Initialise called but current snapshot already exists, bailing early")
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -183,7 +176,7 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (bool,
|
||||
if len(eventIDToNID) == 0 {
|
||||
// we don't have a current snapshot for this room but yet no events are new,
|
||||
// no idea how this should be handled.
|
||||
log.Error().Str("room_id", roomID).Msg(
|
||||
logger.Error().Str("room_id", roomID).Msg(
|
||||
"Accumulator.Initialise: room has no current snapshot but also no new inserted events, doing nothing. This is probably a bug.",
|
||||
)
|
||||
return nil
|
||||
@ -270,7 +263,7 @@ func (a *Accumulator) Accumulate(roomID string, prevBatch string, timeline []jso
|
||||
return fmt.Errorf("event malformed: %s", err)
|
||||
}
|
||||
if _, ok := seenEvents[e.ID]; ok {
|
||||
log.Warn().Str("event_id", e.ID).Str("room_id", roomID).Msg(
|
||||
logger.Warn().Str("event_id", e.ID).Str("room_id", roomID).Msg(
|
||||
"Accumulator.Accumulate: seen the same event ID twice, ignoring",
|
||||
)
|
||||
continue
|
||||
|
@ -66,6 +66,9 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (dd *intern
|
||||
n := tempDD.DeviceLists.New
|
||||
tempDD.DeviceLists.Sent = n
|
||||
tempDD.DeviceLists.New = make(map[string]int)
|
||||
// reset changed bits
|
||||
changedBits := tempDD.ChangedBits
|
||||
tempDD.ChangedBits = 0
|
||||
|
||||
// re-marshal and write
|
||||
data, err := json.Marshal(tempDD)
|
||||
@ -74,32 +77,12 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (dd *intern
|
||||
}
|
||||
_, err = t.db.Exec(`UPDATE syncv3_device_data SET data=$1 WHERE user_id=$2 AND device_id=$3`, data, userID, deviceID)
|
||||
dd = &tempDD
|
||||
dd.ChangedBits = changedBits
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (t *DeviceDataTable) SelectFrom(pos int64) (results []internal.DeviceData, nextPos int64, err error) {
|
||||
nextPos = pos
|
||||
var rows []DeviceDataRow
|
||||
err = t.db.Select(&rows, `SELECT id, user_id, device_id, data FROM syncv3_device_data WHERE id > $1 ORDER BY id ASC`, pos)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
results = make([]internal.DeviceData, len(rows))
|
||||
for i := range rows {
|
||||
var dd internal.DeviceData
|
||||
if err = json.Unmarshal(rows[i].Data, &dd); err != nil {
|
||||
return
|
||||
}
|
||||
dd.UserID = rows[i].UserID
|
||||
dd.DeviceID = rows[i].DeviceID
|
||||
results[i] = dd
|
||||
nextPos = rows[i].ID
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Upsert combines what is in the database for this user|device with the partial entry `dd`
|
||||
func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (pos int64, err error) {
|
||||
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
|
||||
@ -118,9 +101,11 @@ func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (pos int64, err error)
|
||||
}
|
||||
if dd.FallbackKeyTypes != nil {
|
||||
tempDD.FallbackKeyTypes = dd.FallbackKeyTypes
|
||||
tempDD.SetFallbackKeysChanged()
|
||||
}
|
||||
if dd.OTKCounts != nil {
|
||||
tempDD.OTKCounts = dd.OTKCounts
|
||||
tempDD.SetOTKCountChanged()
|
||||
}
|
||||
tempDD.DeviceLists = tempDD.DeviceLists.Combine(dd.DeviceLists)
|
||||
|
||||
|
@ -9,92 +9,19 @@ import (
|
||||
)
|
||||
|
||||
func assertVal(t *testing.T, msg string, got, want interface{}) {
|
||||
t.Helper()
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("%s: got %v want %v", msg, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func assertDeviceDatas(t *testing.T, got, want []internal.DeviceData) {
|
||||
func assertDeviceData(t *testing.T, g, w internal.DeviceData) {
|
||||
t.Helper()
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("got %d devices, want %d : %+v", len(got), len(want), got)
|
||||
}
|
||||
for i := range want {
|
||||
g := got[i]
|
||||
w := want[i]
|
||||
assertVal(t, "device id", g.DeviceID, w.DeviceID)
|
||||
assertVal(t, "user id", g.UserID, w.UserID)
|
||||
assertVal(t, "FallbackKeyTypes", g.FallbackKeyTypes, w.FallbackKeyTypes)
|
||||
assertVal(t, "OTKCounts", g.OTKCounts, w.OTKCounts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceDataTable(t *testing.T) {
|
||||
db, err := sqlx.Open("postgres", postgresConnectionString)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open SQL db: %s", err)
|
||||
}
|
||||
table := NewDeviceDataTable(db)
|
||||
userID := "@alice"
|
||||
deviceID := "ALICE"
|
||||
dd := &internal.DeviceData{
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
OTKCounts: map[string]int{
|
||||
"foo": 100,
|
||||
},
|
||||
FallbackKeyTypes: []string{"foo", "bar"},
|
||||
}
|
||||
|
||||
// test basic insert -> select
|
||||
pos, err := table.Upsert(dd)
|
||||
assertNoError(t, err)
|
||||
results, nextPos, err := table.SelectFrom(-1)
|
||||
assertNoError(t, err)
|
||||
if pos != nextPos {
|
||||
t.Fatalf("Upsert returned pos %v but SelectFrom returned pos %v", pos, nextPos)
|
||||
}
|
||||
assertDeviceDatas(t, results, []internal.DeviceData{*dd})
|
||||
|
||||
// at latest -> no results
|
||||
results, nextPos, err = table.SelectFrom(nextPos)
|
||||
assertNoError(t, err)
|
||||
if pos != nextPos {
|
||||
t.Fatalf("Upsert returned pos %v but SelectFrom returned pos %v", pos, nextPos)
|
||||
}
|
||||
assertDeviceDatas(t, results, nil)
|
||||
|
||||
// multiple insert -> replace on user|device
|
||||
dd2 := *dd
|
||||
dd2.OTKCounts = map[string]int{"foo": 99}
|
||||
_, err = table.Upsert(&dd2)
|
||||
assertNoError(t, err)
|
||||
dd3 := *dd
|
||||
dd3.OTKCounts = map[string]int{"foo": 98}
|
||||
pos, err = table.Upsert(&dd3)
|
||||
assertNoError(t, err)
|
||||
results, nextPos, err = table.SelectFrom(nextPos)
|
||||
assertNoError(t, err)
|
||||
if pos != nextPos {
|
||||
t.Fatalf("Upsert returned pos %v but SelectFrom returned pos %v", pos, nextPos)
|
||||
}
|
||||
assertDeviceDatas(t, results, []internal.DeviceData{dd3})
|
||||
|
||||
// multiple insert -> different user, same device + same user, different device
|
||||
dd4 := *dd
|
||||
dd4.UserID = "@bob"
|
||||
_, err = table.Upsert(&dd4)
|
||||
assertNoError(t, err)
|
||||
dd5 := *dd
|
||||
dd5.DeviceID = "ANOTHER"
|
||||
pos, err = table.Upsert(&dd5)
|
||||
assertNoError(t, err)
|
||||
results, nextPos, err = table.SelectFrom(nextPos)
|
||||
assertNoError(t, err)
|
||||
if pos != nextPos {
|
||||
t.Fatalf("Upsert returned pos %v but SelectFrom returned pos %v", pos, nextPos)
|
||||
}
|
||||
assertDeviceDatas(t, results, []internal.DeviceData{dd4, dd5})
|
||||
assertVal(t, "device id", g.DeviceID, w.DeviceID)
|
||||
assertVal(t, "user id", g.UserID, w.UserID)
|
||||
assertVal(t, "FallbackKeyTypes", g.FallbackKeyTypes, w.FallbackKeyTypes)
|
||||
assertVal(t, "OTKCounts", g.OTKCounts, w.OTKCounts)
|
||||
assertVal(t, "ChangedBits", g.ChangedBits, w.ChangedBits)
|
||||
}
|
||||
|
||||
func TestDeviceDataTableSwaps(t *testing.T) {
|
||||
@ -155,11 +82,13 @@ func TestDeviceDataTableSwaps(t *testing.T) {
|
||||
New: internal.ToDeviceListChangesMap([]string{"alice", "bob"}, nil),
|
||||
},
|
||||
}
|
||||
want.SetFallbackKeysChanged()
|
||||
want.SetOTKCountChanged()
|
||||
// check we can read-only select
|
||||
for i := 0; i < 3; i++ {
|
||||
got, err := table.Select(userID, deviceID, false)
|
||||
assertNoError(t, err)
|
||||
assertDeviceDatas(t, []internal.DeviceData{*got}, []internal.DeviceData{want})
|
||||
assertDeviceData(t, *got, want)
|
||||
}
|
||||
// now swap-er-roo
|
||||
got, err := table.Select(userID, deviceID, true)
|
||||
@ -169,12 +98,16 @@ func TestDeviceDataTableSwaps(t *testing.T) {
|
||||
Sent: internal.ToDeviceListChangesMap([]string{"alice"}, nil),
|
||||
New: nil,
|
||||
}
|
||||
assertDeviceDatas(t, []internal.DeviceData{*got}, []internal.DeviceData{want2})
|
||||
assertDeviceData(t, *got, want2)
|
||||
|
||||
// changed bits were reset when we swapped
|
||||
want2.ChangedBits = 0
|
||||
want.ChangedBits = 0
|
||||
|
||||
// this is permanent, read-only views show this too
|
||||
got, err = table.Select(userID, deviceID, false)
|
||||
assertNoError(t, err)
|
||||
assertDeviceDatas(t, []internal.DeviceData{*got}, []internal.DeviceData{want2})
|
||||
assertDeviceData(t, *got, want2)
|
||||
|
||||
// another swap causes sent to be cleared out
|
||||
got, err = table.Select(userID, deviceID, true)
|
||||
@ -184,16 +117,18 @@ func TestDeviceDataTableSwaps(t *testing.T) {
|
||||
Sent: nil,
|
||||
New: nil,
|
||||
}
|
||||
assertDeviceDatas(t, []internal.DeviceData{*got}, []internal.DeviceData{want3})
|
||||
assertDeviceData(t, *got, want3)
|
||||
|
||||
// get back the original state
|
||||
for _, dd := range deltas {
|
||||
_, err = table.Upsert(&dd)
|
||||
assertNoError(t, err)
|
||||
}
|
||||
want.SetFallbackKeysChanged()
|
||||
want.SetOTKCountChanged()
|
||||
got, err = table.Select(userID, deviceID, false)
|
||||
assertNoError(t, err)
|
||||
assertDeviceDatas(t, []internal.DeviceData{*got}, []internal.DeviceData{want})
|
||||
assertDeviceData(t, *got, want)
|
||||
|
||||
// swap once then add once so both sent and new are populated
|
||||
_, err = table.Select(userID, deviceID, true)
|
||||
@ -207,6 +142,8 @@ func TestDeviceDataTableSwaps(t *testing.T) {
|
||||
})
|
||||
assertNoError(t, err)
|
||||
|
||||
want.ChangedBits = 0
|
||||
|
||||
want4 := want
|
||||
want4.DeviceLists = internal.DeviceLists{
|
||||
Sent: internal.ToDeviceListChangesMap([]string{"alice"}, nil),
|
||||
@ -214,7 +151,7 @@ func TestDeviceDataTableSwaps(t *testing.T) {
|
||||
}
|
||||
got, err = table.Select(userID, deviceID, false)
|
||||
assertNoError(t, err)
|
||||
assertDeviceDatas(t, []internal.DeviceData{*got}, []internal.DeviceData{want4})
|
||||
assertDeviceData(t, *got, want4)
|
||||
|
||||
// another append then consume
|
||||
_, err = table.Upsert(&internal.DeviceData{
|
||||
@ -232,6 +169,68 @@ func TestDeviceDataTableSwaps(t *testing.T) {
|
||||
Sent: internal.ToDeviceListChangesMap([]string{"bob", "dave"}, []string{"charlie", "dave"}),
|
||||
New: nil,
|
||||
}
|
||||
assertDeviceDatas(t, []internal.DeviceData{*got}, []internal.DeviceData{want5})
|
||||
|
||||
assertDeviceData(t, *got, want5)
|
||||
}
|
||||
|
||||
func TestDeviceDataTableBitset(t *testing.T) {
|
||||
db, err := sqlx.Open("postgres", postgresConnectionString)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open SQL db: %s", err)
|
||||
}
|
||||
table := NewDeviceDataTable(db)
|
||||
userID := "@bobTestDeviceDataTableBitset"
|
||||
deviceID := "BOBTestDeviceDataTableBitset"
|
||||
otkUpdate := internal.DeviceData{
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
OTKCounts: map[string]int{
|
||||
"foo": 100,
|
||||
"bar": 92,
|
||||
},
|
||||
}
|
||||
fallbakKeyUpdate := internal.DeviceData{
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
FallbackKeyTypes: []string{"foo", "bar"},
|
||||
}
|
||||
bothUpdate := internal.DeviceData{
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
FallbackKeyTypes: []string{"both"},
|
||||
OTKCounts: map[string]int{
|
||||
"both": 100,
|
||||
},
|
||||
}
|
||||
|
||||
_, err = table.Upsert(&otkUpdate)
|
||||
assertNoError(t, err)
|
||||
got, err := table.Select(userID, deviceID, true)
|
||||
assertNoError(t, err)
|
||||
otkUpdate.SetOTKCountChanged()
|
||||
assertDeviceData(t, *got, otkUpdate)
|
||||
// second time swapping causes no OTKs as there have been no changes
|
||||
got, err = table.Select(userID, deviceID, true)
|
||||
assertNoError(t, err)
|
||||
otkUpdate.ChangedBits = 0
|
||||
assertDeviceData(t, *got, otkUpdate)
|
||||
// now same for fallback keys, but we won't swap them so it should return those diffs
|
||||
_, err = table.Upsert(&fallbakKeyUpdate)
|
||||
assertNoError(t, err)
|
||||
fallbakKeyUpdate.OTKCounts = otkUpdate.OTKCounts
|
||||
got, err = table.Select(userID, deviceID, false)
|
||||
assertNoError(t, err)
|
||||
fallbakKeyUpdate.SetFallbackKeysChanged()
|
||||
assertDeviceData(t, *got, fallbakKeyUpdate)
|
||||
got, err = table.Select(userID, deviceID, false)
|
||||
assertNoError(t, err)
|
||||
fallbakKeyUpdate.SetFallbackKeysChanged()
|
||||
assertDeviceData(t, *got, fallbakKeyUpdate)
|
||||
// updating both works
|
||||
_, err = table.Upsert(&bothUpdate)
|
||||
assertNoError(t, err)
|
||||
got, err = table.Select(userID, deviceID, true)
|
||||
assertNoError(t, err)
|
||||
bothUpdate.SetFallbackKeysChanged()
|
||||
bothUpdate.SetOTKCountChanged()
|
||||
assertDeviceData(t, *got, bothUpdate)
|
||||
}
|
||||
|
@ -47,7 +47,7 @@ type Storage struct {
|
||||
func NewStorage(postgresURI string) *Storage {
|
||||
db, err := sqlx.Open("postgres", postgresURI)
|
||||
if err != nil {
|
||||
log.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
|
||||
logger.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
|
||||
}
|
||||
acc := &Accumulator{
|
||||
db: db,
|
||||
|
@ -92,7 +92,7 @@ func (t *ToDeviceTable) Messages(deviceID string, from, limit int64) (msgs []jso
|
||||
m := gjson.ParseBytes(msgs[i])
|
||||
msgId := m.Get(`content.org\.matrix\.msgid`).Str
|
||||
if msgId != "" {
|
||||
log.Debug().Str("msgid", msgId).Str("device", deviceID).Msg("ToDeviceTable.Messages")
|
||||
logger.Info().Str("msgid", msgId).Str("device", deviceID).Msg("ToDeviceTable.Messages")
|
||||
}
|
||||
}
|
||||
upTo = rows[len(rows)-1].Position
|
||||
@ -125,7 +125,7 @@ func (t *ToDeviceTable) InsertMessages(deviceID string, msgs []json.RawMessage)
|
||||
}
|
||||
msgId := m.Get(`content.org\.matrix\.msgid`).Str
|
||||
if msgId != "" {
|
||||
log.Debug().Str("msgid", msgId).Str("device", deviceID).Msg("ToDeviceTable.InsertMessages")
|
||||
logger.Debug().Str("msgid", msgId).Str("device", deviceID).Msg("ToDeviceTable.InsertMessages")
|
||||
}
|
||||
switch rows[i].Type {
|
||||
case "m.room_key_request":
|
||||
|
@ -178,7 +178,8 @@ func (h *Handler) OnE2EEData(userID, deviceID string, otkCounts map[string]int,
|
||||
return
|
||||
}
|
||||
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2DeviceData{
|
||||
Pos: nextPos,
|
||||
DeviceID: deviceID,
|
||||
Pos: nextPos,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -459,31 +459,29 @@ func (p *poller) parseToDeviceMessages(res *SyncResponse) {
|
||||
}
|
||||
|
||||
func (p *poller) parseE2EEData(res *SyncResponse) {
|
||||
hasE2EEChanges := false
|
||||
var changedOTKCounts map[string]int
|
||||
if res.DeviceListsOTKCount != nil && len(res.DeviceListsOTKCount) > 0 {
|
||||
if len(p.otkCounts) != len(res.DeviceListsOTKCount) {
|
||||
hasE2EEChanges = true
|
||||
}
|
||||
if !hasE2EEChanges && p.otkCounts != nil {
|
||||
changedOTKCounts = res.DeviceListsOTKCount
|
||||
} else if p.otkCounts != nil {
|
||||
for k := range res.DeviceListsOTKCount {
|
||||
if res.DeviceListsOTKCount[k] != p.otkCounts[k] {
|
||||
hasE2EEChanges = true
|
||||
changedOTKCounts = res.DeviceListsOTKCount
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
p.otkCounts = res.DeviceListsOTKCount
|
||||
}
|
||||
var changedFallbackTypes []string
|
||||
if len(res.DeviceUnusedFallbackKeyTypes) > 0 {
|
||||
if !hasE2EEChanges {
|
||||
if len(p.fallbackKeyTypes) != len(res.DeviceUnusedFallbackKeyTypes) {
|
||||
hasE2EEChanges = true
|
||||
} else {
|
||||
for i := range res.DeviceUnusedFallbackKeyTypes {
|
||||
if res.DeviceUnusedFallbackKeyTypes[i] != p.fallbackKeyTypes[i] {
|
||||
hasE2EEChanges = true
|
||||
break
|
||||
}
|
||||
if len(p.fallbackKeyTypes) != len(res.DeviceUnusedFallbackKeyTypes) {
|
||||
changedFallbackTypes = res.DeviceUnusedFallbackKeyTypes
|
||||
} else {
|
||||
for i := range res.DeviceUnusedFallbackKeyTypes {
|
||||
if res.DeviceUnusedFallbackKeyTypes[i] != p.fallbackKeyTypes[i] {
|
||||
changedFallbackTypes = res.DeviceUnusedFallbackKeyTypes
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -491,12 +489,9 @@ func (p *poller) parseE2EEData(res *SyncResponse) {
|
||||
}
|
||||
|
||||
deviceListChanges := internal.ToDeviceListChangesMap(res.DeviceLists.Changed, res.DeviceLists.Left)
|
||||
if deviceListChanges != nil {
|
||||
hasE2EEChanges = true
|
||||
}
|
||||
|
||||
if hasE2EEChanges {
|
||||
p.receiver.OnE2EEData(p.userID, p.deviceID, p.otkCounts, p.fallbackKeyTypes, deviceListChanges)
|
||||
if deviceListChanges != nil || changedFallbackTypes != nil || changedOTKCounts != nil {
|
||||
p.receiver.OnE2EEData(p.userID, p.deviceID, changedOTKCounts, changedFallbackTypes, deviceListChanges)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/sliding-sync/internal"
|
||||
"github.com/matrix-org/sliding-sync/state"
|
||||
"github.com/rs/zerolog"
|
||||
@ -187,15 +186,6 @@ func (c *GlobalCache) Startup(roomIDToMetadata map[string]internal.RoomMetadata)
|
||||
sort.Strings(roomIDs)
|
||||
for _, roomID := range roomIDs {
|
||||
metadata := roomIDToMetadata[roomID]
|
||||
var upgradedRoomID string
|
||||
if metadata.UpgradedRoomID != nil {
|
||||
upgradedRoomID = *metadata.UpgradedRoomID
|
||||
}
|
||||
logger.Debug().Str("room", roomID).Interface(
|
||||
"recent", gomatrixserverlib.Timestamp(metadata.LastMessageTimestamp).Time(),
|
||||
).Bool("encrypted", metadata.Encrypted).Str("tombstone", upgradedRoomID).Int("joins", metadata.JoinCount).Msg(
|
||||
"",
|
||||
)
|
||||
internal.Assert("room ID is set", metadata.RoomID != "")
|
||||
internal.Assert("last message timestamp exists", metadata.LastMessageTimestamp > 1)
|
||||
c.roomIDToMetadata[roomID] = &metadata
|
||||
|
@ -54,6 +54,7 @@ type RoomAccountDataUpdate struct {
|
||||
AccountData []state.AccountData
|
||||
}
|
||||
|
||||
// Alerts result in changes to ops, subs or ext modifications
|
||||
// Alerts can update internal conn state
|
||||
// Dispatcher thread ultimately fires alerts OR poller thread e.g OnUnreadCounts
|
||||
type DeviceDataUpdate struct {
|
||||
// no data; just wakes up the connection
|
||||
// data comes via sidechannels e.g the database
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/matrix-org/sliding-sync/internal"
|
||||
"github.com/matrix-org/sliding-sync/sync3/caches"
|
||||
)
|
||||
|
||||
type ConnID struct {
|
||||
@ -22,6 +23,7 @@ type ConnHandler interface {
|
||||
// to send back or an error. Errors of type *internal.HandlerError are inspected for the correct
|
||||
// status code to send back.
|
||||
OnIncomingRequest(ctx context.Context, cid ConnID, req *Request, isInitial bool) (*Response, error)
|
||||
OnUpdate(update caches.Update)
|
||||
UserID() string
|
||||
Destroy()
|
||||
Alive() bool
|
||||
@ -69,6 +71,10 @@ func (c *Conn) Alive() bool {
|
||||
return c.handler.Alive()
|
||||
}
|
||||
|
||||
func (c *Conn) OnUpdate(update caches.Update) {
|
||||
c.handler.OnUpdate(update)
|
||||
}
|
||||
|
||||
func (c *Conn) tryRequest(ctx context.Context, req *Request) (res *Response, err error) {
|
||||
defer func() {
|
||||
panicErr := recover()
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/sliding-sync/internal"
|
||||
"github.com/matrix-org/sliding-sync/sync3/caches"
|
||||
)
|
||||
|
||||
type connHandlerMock struct {
|
||||
@ -21,8 +22,9 @@ func (c *connHandlerMock) OnIncomingRequest(ctx context.Context, cid ConnID, req
|
||||
func (c *connHandlerMock) UserID() string {
|
||||
return "dummy"
|
||||
}
|
||||
func (c *connHandlerMock) Destroy() {}
|
||||
func (c *connHandlerMock) Alive() bool { return true }
|
||||
func (c *connHandlerMock) Destroy() {}
|
||||
func (c *connHandlerMock) Alive() bool { return true }
|
||||
func (c *connHandlerMock) OnUpdate(update caches.Update) {}
|
||||
|
||||
// Test that Conn can send and receive requests based on positions
|
||||
func TestConn(t *testing.T) {
|
||||
|
@ -3,6 +3,7 @@ package extensions
|
||||
import (
|
||||
"github.com/matrix-org/sliding-sync/internal"
|
||||
"github.com/matrix-org/sliding-sync/sync2"
|
||||
"github.com/matrix-org/sliding-sync/sync3/caches"
|
||||
)
|
||||
|
||||
// Client created request params
|
||||
@ -31,8 +32,15 @@ func (r *E2EEResponse) HasData(isInitial bool) bool {
|
||||
if isInitial {
|
||||
return true // ensure we send OTK counts immediately
|
||||
}
|
||||
// OTK counts aren't enough to make /sync return early as we send them liberally, not just on change
|
||||
return r.DeviceLists != nil
|
||||
return r.DeviceLists != nil || len(r.FallbackKeyTypes) > 0 || len(r.OTKCounts) > 0
|
||||
}
|
||||
|
||||
func ProcessLiveE2EE(up caches.Update, fetcher sync2.E2EEFetcher, userID, deviceID string, req *E2EERequest) (res *E2EEResponse) {
|
||||
_, ok := up.(caches.DeviceDataUpdate)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return ProcessE2EE(fetcher, userID, deviceID, req, false)
|
||||
}
|
||||
|
||||
func ProcessE2EE(fetcher sync2.E2EEFetcher, userID, deviceID string, req *E2EERequest, isInitial bool) (res *E2EEResponse) {
|
||||
@ -42,10 +50,10 @@ func ProcessE2EE(fetcher sync2.E2EEFetcher, userID, deviceID string, req *E2EERe
|
||||
if dd == nil {
|
||||
return res // unknown device?
|
||||
}
|
||||
if dd.FallbackKeyTypes != nil {
|
||||
if dd.FallbackKeyTypes != nil && (dd.FallbackKeysChanged() || isInitial) {
|
||||
res.FallbackKeyTypes = dd.FallbackKeyTypes
|
||||
}
|
||||
if dd.OTKCounts != nil {
|
||||
if dd.OTKCounts != nil && (dd.OTKCountChanged() || isInitial) {
|
||||
res.OTKCounts = dd.OTKCounts
|
||||
}
|
||||
changed, left := internal.DeviceListChangesArrays(dd.DeviceLists.Sent)
|
||||
|
@ -78,6 +78,14 @@ func (h *Handler) HandleLiveUpdate(update caches.Update, req Request, res *Respo
|
||||
if req.Receipts != nil && req.Receipts.Enabled {
|
||||
res.Receipts = ProcessLiveReceipts(update, updateWillReturnResponse, req.UserID, req.Receipts)
|
||||
}
|
||||
// only process 'live' e2ee when we aren't going to return data as we need to ensure that we don't calculate this twice
|
||||
// e.g once on incoming request then again due to wakeup
|
||||
if req.E2EE != nil && req.E2EE.Enabled {
|
||||
if res.E2EE != nil && res.E2EE.HasData(false) {
|
||||
return
|
||||
}
|
||||
res.E2EE = ProcessLiveE2EE(update, h.E2EEFetcher, req.UserID, req.DeviceID, req.E2EE)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) Handle(req Request, roomIDToTimeline map[string][]string, isInitial bool) (res Response) {
|
||||
|
@ -522,10 +522,15 @@ func (h *SyncLiveHandler) OnUnreadCounts(p *pubsub.V2UnreadCounts) {
|
||||
userCache.(*caches.UserCache).OnUnreadCounts(p.RoomID, p.HighlightCount, p.NotificationCount)
|
||||
}
|
||||
|
||||
// TODO: We don't eagerly push device data updates on waiting conns (otk counts, device list changes)
|
||||
// Do we need to?
|
||||
// push device data updates on waiting conns (otk counts, device list changes)
|
||||
func (h *SyncLiveHandler) OnDeviceData(p *pubsub.V2DeviceData) {
|
||||
// Do nothing for now
|
||||
conn := h.ConnMap.Conn(sync3.ConnID{
|
||||
DeviceID: p.DeviceID,
|
||||
})
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
conn.OnUpdate(caches.DeviceDataUpdate{})
|
||||
}
|
||||
|
||||
func (h *SyncLiveHandler) OnInvite(p *pubsub.V2InviteRoom) {
|
||||
|
@ -2,7 +2,6 @@ package syncv3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -50,7 +49,16 @@ func TestExtensionE2EE(t *testing.T) {
|
||||
})
|
||||
m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts), m.MatchFallbackKeyTypes(fallbackKeyTypes))
|
||||
|
||||
// check that OTK counts / fallback key types remain constant when they aren't included in the v2 response.
|
||||
// Dummy request as we will see the same otk/fallback keys twice initially
|
||||
res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{
|
||||
Lists: []sync3.RequestList{{
|
||||
Ranges: sync3.SliceRanges{
|
||||
[2]int64{0, 10}, // doesn't matter
|
||||
},
|
||||
}},
|
||||
})
|
||||
|
||||
// check that OTK counts / fallback key types aren't present afterwards as they haven't changed.
|
||||
// Do this by feeding in a new joined room
|
||||
v2.queueResponse(alice, sync2.SyncResponse{
|
||||
Rooms: sync2.SyncRoomsResponse{
|
||||
@ -61,6 +69,7 @@ func TestExtensionE2EE(t *testing.T) {
|
||||
}),
|
||||
},
|
||||
})
|
||||
v2.waitUntilEmpty(t, alice)
|
||||
res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{
|
||||
Lists: []sync3.RequestList{{
|
||||
Ranges: sync3.SliceRanges{
|
||||
@ -69,10 +78,9 @@ func TestExtensionE2EE(t *testing.T) {
|
||||
}},
|
||||
// skip enabled: true as it should be sticky
|
||||
})
|
||||
m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts), m.MatchFallbackKeyTypes(fallbackKeyTypes))
|
||||
m.MatchResponse(t, res, m.MatchNoE2EEExtension()) // No E2EE changes = no extension
|
||||
|
||||
// check that OTK counts update when they are included in the v2 response
|
||||
// check fallback key types persist when not included
|
||||
otkCounts = map[string]int{
|
||||
"curve25519": 99,
|
||||
"signed_curve25519": 999,
|
||||
@ -87,14 +95,8 @@ func TestExtensionE2EE(t *testing.T) {
|
||||
[2]int64{0, 10}, // doesn't matter
|
||||
},
|
||||
}},
|
||||
// enable the E2EE extension
|
||||
Extensions: extensions.Request{
|
||||
E2EE: &extensions.E2EERequest{
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts), m.MatchFallbackKeyTypes(fallbackKeyTypes))
|
||||
m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts), m.MatchFallbackKeyTypes(nil))
|
||||
|
||||
// check that changed|left get passed to v3
|
||||
wantChanged := []string{"bob"}
|
||||
@ -152,6 +154,7 @@ func TestExtensionE2EE(t *testing.T) {
|
||||
}),
|
||||
},
|
||||
})
|
||||
v2.waitUntilEmpty(t, alice)
|
||||
res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{
|
||||
Lists: []sync3.RequestList{{
|
||||
Ranges: sync3.SliceRanges{
|
||||
@ -165,12 +168,31 @@ func TestExtensionE2EE(t *testing.T) {
|
||||
},
|
||||
},
|
||||
})
|
||||
m.MatchResponse(t, res, func(res *sync3.Response) error {
|
||||
if res.Extensions.E2EE.DeviceLists != nil {
|
||||
return fmt.Errorf("e2ee device lists present when it shouldn't be")
|
||||
}
|
||||
return nil
|
||||
m.MatchResponse(t, res, m.MatchNoE2EEExtension())
|
||||
|
||||
// Check that OTK counts are immediately sent to the client
|
||||
otkCounts = map[string]int{
|
||||
"curve25519": 42,
|
||||
"signed_curve25519": 420,
|
||||
}
|
||||
v2.queueResponse(alice, sync2.SyncResponse{
|
||||
DeviceListsOTKCount: otkCounts,
|
||||
})
|
||||
v2.waitUntilEmpty(t, alice)
|
||||
req := sync3.Request{
|
||||
Lists: []sync3.RequestList{{
|
||||
Ranges: sync3.SliceRanges{
|
||||
[2]int64{0, 10}, // doesn't matter
|
||||
},
|
||||
}},
|
||||
}
|
||||
req.SetTimeoutMSecs(500)
|
||||
start := time.Now()
|
||||
res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, req)
|
||||
m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts))
|
||||
if time.Since(start) >= (500 * time.Millisecond) {
|
||||
t.Fatalf("sync request did not return immediately with OTK counts")
|
||||
}
|
||||
}
|
||||
|
||||
// Checks that to-device messages are passed from v2 to v3
|
||||
|
@ -227,6 +227,15 @@ func MatchRoomSubscriptions(wantSubs map[string][]RoomMatcher) RespMatcher {
|
||||
}
|
||||
}
|
||||
|
||||
func MatchNoE2EEExtension() RespMatcher {
|
||||
return func(res *sync3.Response) error {
|
||||
if res.Extensions.E2EE != nil {
|
||||
return fmt.Errorf("MatchNoE2EEExtension: got E2EE extension: %+v", res.Extensions.E2EE)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func MatchOTKCounts(otkCounts map[string]int) RespMatcher {
|
||||
return func(res *sync3.Response) error {
|
||||
if res.Extensions.E2EE == nil {
|
||||
|
Loading…
x
Reference in New Issue
Block a user