From e1bc972ff73aef77c874d4658cc53ed818ee969c Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 2 May 2023 14:34:22 +0100 Subject: [PATCH] Update the to-device table --- state/to_device_table.go | 49 +++++++++++++++----------- state/to_device_table_test.go | 66 ++++++++++++++++++----------------- sync2/handler2/handler.go | 2 +- sync3/extensions/todevice.go | 6 ++-- 4 files changed, 66 insertions(+), 57 deletions(-) diff --git a/state/to_device_table.go b/state/to_device_table.go index 0d428ac..6bf48e2 100644 --- a/state/to_device_table.go +++ b/state/to_device_table.go @@ -23,6 +23,7 @@ type ToDeviceTable struct { type ToDeviceRow struct { Position int64 `db:"position"` + UserID string `db:"user_id"` DeviceID string `db:"device_id"` Message string `db:"message"` Type string `db:"event_type"` @@ -46,6 +47,7 @@ func NewToDeviceTable(db *sqlx.DB) *ToDeviceTable { CREATE SEQUENCE IF NOT EXISTS syncv3_to_device_messages_seq; CREATE TABLE IF NOT EXISTS syncv3_to_device_messages ( position BIGINT NOT NULL PRIMARY KEY DEFAULT nextval('syncv3_to_device_messages_seq'), + user_id TEXT NOT NULL, device_id TEXT NOT NULL, event_type TEXT NOT NULL, sender TEXT NOT NULL, @@ -55,7 +57,9 @@ func NewToDeviceTable(db *sqlx.DB) *ToDeviceTable { action SMALLINT DEFAULT 0 -- 0 means unknown ); CREATE TABLE IF NOT EXISTS syncv3_to_device_ack_pos ( - device_id TEXT NOT NULL PRIMARY KEY, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + PRIMARY KEY (user_id, device_id), unack_pos BIGINT NOT NULL ); CREATE INDEX IF NOT EXISTS syncv3_to_device_messages_device_idx ON syncv3_to_device_messages(device_id); @@ -65,29 +69,31 @@ func NewToDeviceTable(db *sqlx.DB) *ToDeviceTable { return &ToDeviceTable{db} } -func (t *ToDeviceTable) SetUnackedPosition(deviceID string, pos int64) error { - _, err := t.db.Exec(`INSERT INTO syncv3_to_device_ack_pos(device_id, unack_pos) VALUES($1,$2) ON CONFLICT (device_id) - DO UPDATE SET unack_pos=$2`, deviceID, pos) +func (t *ToDeviceTable) SetUnackedPosition(userID, deviceID string, pos int64) error { + _, err := t.db.Exec(`INSERT INTO syncv3_to_device_ack_pos(user_id, device_id, unack_pos) VALUES($1,$2,$3) ON CONFLICT (user_id, device_id) + DO UPDATE SET unack_pos=excluded.unack_pos`, userID, deviceID, pos) return err } -func (t *ToDeviceTable) DeleteMessagesUpToAndIncluding(deviceID string, toIncl int64) error { - _, err := t.db.Exec(`DELETE FROM syncv3_to_device_messages WHERE device_id = $1 AND position <= $2`, deviceID, toIncl) +func (t *ToDeviceTable) DeleteMessagesUpToAndIncluding(userID, deviceID string, toIncl int64) error { + _, err := t.db.Exec(`DELETE FROM syncv3_to_device_messages WHERE user_id = $1 AND device_id = $2 AND position <= $3`, userID, deviceID, toIncl) return err } -func (t *ToDeviceTable) DeleteAllMessagesForDevice(deviceID string) error { - _, err := t.db.Exec(`DELETE FROM syncv3_to_device_messages WHERE device_id = $1`, deviceID) +func (t *ToDeviceTable) DeleteAllMessagesForDevice(userID, deviceID string) error { + _, err := t.db.Exec(`DELETE FROM syncv3_to_device_messages WHERE user_id = $1 AND device_id = $2`, userID, deviceID) return err } -// Query to-device messages for this device, exclusive of from and inclusive of to. If a to value is unknown, use -1. -func (t *ToDeviceTable) Messages(deviceID string, from, limit int64) (msgs []json.RawMessage, upTo int64, err error) { +// Messages fetches up to `limit` to-device messages for this device, starting from and excluding `from`. +// Returns the fetches messages ordered by ascending position, as well as the position of the last to-device message +// fetched. +func (t *ToDeviceTable) Messages(userID, deviceID string, from, limit int64) (msgs []json.RawMessage, upTo int64, err error) { upTo = from var rows []ToDeviceRow err = t.db.Select(&rows, - `SELECT position, message FROM syncv3_to_device_messages WHERE device_id = $1 AND position > $2 ORDER BY position ASC LIMIT $3`, - deviceID, from, limit, + `SELECT position, message FROM syncv3_to_device_messages WHERE user_id = $1 AND device_id = $2 AND position > $3 ORDER BY position ASC LIMIT $4`, + userID, deviceID, from, limit, ) if len(rows) == 0 { return @@ -98,18 +104,18 @@ func (t *ToDeviceTable) Messages(deviceID string, from, limit int64) (msgs []jso m := gjson.ParseBytes(msgs[i]) msgId := m.Get(`content.org\.matrix\.msgid`).Str if msgId != "" { - logger.Info().Str("msgid", msgId).Str("device", deviceID).Msg("ToDeviceTable.Messages") + logger.Info().Str("msgid", msgId).Str("user", userID).Str("device", deviceID).Msg("ToDeviceTable.Messages") } } upTo = rows[len(rows)-1].Position return } -func (t *ToDeviceTable) InsertMessages(deviceID string, msgs []json.RawMessage) (pos int64, err error) { +func (t *ToDeviceTable) InsertMessages(userID, deviceID string, msgs []json.RawMessage) (pos int64, err error) { var lastPos int64 err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error { var unackPos int64 - err = txn.QueryRow(`SELECT unack_pos FROM syncv3_to_device_ack_pos WHERE device_id=$1`, deviceID).Scan(&unackPos) + err = txn.QueryRow(`SELECT unack_pos FROM syncv3_to_device_ack_pos WHERE user_id=$1 AND device_id=$2`, userID, deviceID).Scan(&unackPos) if err != nil && err != sql.ErrNoRows { return fmt.Errorf("unable to select unacked pos: %s", err) } @@ -124,6 +130,7 @@ func (t *ToDeviceTable) InsertMessages(deviceID string, msgs []json.RawMessage) for i := range msgs { m := gjson.ParseBytes(msgs[i]) rows[i] = ToDeviceRow{ + UserID: userID, DeviceID: deviceID, Message: string(msgs[i]), Type: m.Get("type").Str, @@ -131,7 +138,7 @@ func (t *ToDeviceTable) InsertMessages(deviceID string, msgs []json.RawMessage) } msgId := m.Get(`content.org\.matrix\.msgid`).Str if msgId != "" { - logger.Debug().Str("msgid", msgId).Str("device", deviceID).Msg("ToDeviceTable.InsertMessages") + logger.Debug().Str("msgid", msgId).Str("user", userID).Str("device", deviceID).Msg("ToDeviceTable.InsertMessages") } switch rows[i].Type { case "m.room_key_request": @@ -155,8 +162,8 @@ func (t *ToDeviceTable) InsertMessages(deviceID string, msgs []json.RawMessage) if len(cancels) > 0 { var cancelled []string // delete action: request events which have the same unique key, for this device inbox, only if they are not sent to the client already (unacked) - err = txn.Select(&cancelled, `DELETE FROM syncv3_to_device_messages WHERE unique_key = ANY($1) AND device_id = $2 AND position > $3 RETURNING unique_key`, - pq.StringArray(cancels), deviceID, unackPos) + err = txn.Select(&cancelled, `DELETE FROM syncv3_to_device_messages WHERE unique_key = ANY($1) AND user_id = $2 AND device_id = $3 AND position > $4 RETURNING unique_key`, + pq.StringArray(cancels), userID, deviceID, unackPos) if err != nil { return fmt.Errorf("failed to delete cancelled events: %s", err) } @@ -189,10 +196,10 @@ func (t *ToDeviceTable) InsertMessages(deviceID string, msgs []json.RawMessage) return nil } - chunks := sqlutil.Chunkify(6, MaxPostgresParameters, ToDeviceRowChunker(rows)) + chunks := sqlutil.Chunkify(7, MaxPostgresParameters, ToDeviceRowChunker(rows)) for _, chunk := range chunks { - result, err := txn.NamedQuery(`INSERT INTO syncv3_to_device_messages (device_id, message, event_type, sender, action, unique_key) - VALUES (:device_id, :message, :event_type, :sender, :action, :unique_key) RETURNING position`, chunk) + result, err := txn.NamedQuery(`INSERT INTO syncv3_to_device_messages (user_id, device_id, message, event_type, sender, action, unique_key) + VALUES (:user_id, :device_id, :message, :event_type, :sender, :action, :unique_key) RETURNING position`, chunk) if err != nil { return err } diff --git a/state/to_device_table_test.go b/state/to_device_table_test.go index 8495bad..7f092a5 100644 --- a/state/to_device_table_test.go +++ b/state/to_device_table_test.go @@ -12,6 +12,7 @@ func TestToDeviceTable(t *testing.T) { db, close := connectToDB(t) defer close() table := NewToDeviceTable(db) + sender := "@alice:localhost" deviceID := "FOO" var limit int64 = 999 msgs := []json.RawMessage{ @@ -20,13 +21,13 @@ func TestToDeviceTable(t *testing.T) { } var lastPos int64 var err error - if lastPos, err = table.InsertMessages(deviceID, msgs); err != nil { + if lastPos, err = table.InsertMessages(sender, deviceID, msgs); err != nil { t.Fatalf("InsertMessages: %s", err) } if lastPos != 2 { t.Fatalf("InsertMessages: bad pos returned, got %d want 2", lastPos) } - gotMsgs, upTo, err := table.Messages(deviceID, 0, limit) + gotMsgs, upTo, err := table.Messages(sender, deviceID, 0, limit) if err != nil { t.Fatalf("Messages: %s", err) } @@ -43,7 +44,7 @@ func TestToDeviceTable(t *testing.T) { } // same to= token, no messages - gotMsgs, upTo, err = table.Messages(deviceID, lastPos, limit) + gotMsgs, upTo, err = table.Messages(sender, deviceID, lastPos, limit) if err != nil { t.Fatalf("Messages: %s", err) } @@ -55,7 +56,7 @@ func TestToDeviceTable(t *testing.T) { } // different device ID, no messages - gotMsgs, upTo, err = table.Messages("OTHER_DEVICE", 0, limit) + gotMsgs, upTo, err = table.Messages(sender, "OTHER_DEVICE", 0, limit) if err != nil { t.Fatalf("Messages: %s", err) } @@ -67,7 +68,7 @@ func TestToDeviceTable(t *testing.T) { } // zero limit, no messages - gotMsgs, upTo, err = table.Messages(deviceID, 0, 0) + gotMsgs, upTo, err = table.Messages(sender, deviceID, 0, 0) if err != nil { t.Fatalf("Messages: %s", err) } @@ -80,7 +81,7 @@ func TestToDeviceTable(t *testing.T) { // lower limit, cap out var wantLimit int64 = 1 - gotMsgs, upTo, err = table.Messages(deviceID, 0, wantLimit) + gotMsgs, upTo, err = table.Messages(sender, deviceID, 0, wantLimit) if err != nil { t.Fatalf("Messages: %s", err) } @@ -93,10 +94,10 @@ func TestToDeviceTable(t *testing.T) { } // delete the first message, requerying only gives 1 message - if err := table.DeleteMessagesUpToAndIncluding(deviceID, lastPos-1); err != nil { + if err := table.DeleteMessagesUpToAndIncluding(sender, deviceID, lastPos-1); err != nil { t.Fatalf("DeleteMessagesUpTo: %s", err) } - gotMsgs, upTo, err = table.Messages(deviceID, 0, limit) + gotMsgs, upTo, err = table.Messages(sender, deviceID, 0, limit) if err != nil { t.Fatalf("Messages: %s", err) } @@ -114,9 +115,9 @@ func TestToDeviceTable(t *testing.T) { t.Fatalf("Messages: deleted message but unexpected message left: got %s want %s", string(gotMsgs[0]), string(want)) } // delete everything and check it works - err = table.DeleteAllMessagesForDevice(deviceID) + err = table.DeleteAllMessagesForDevice(sender, deviceID) assertNoError(t, err) - msgs, _, err = table.Messages(deviceID, -1, 10) + msgs, _, err = table.Messages(sender, deviceID, -1, 10) assertNoError(t, err) assertVal(t, "wanted 0 msgs", len(msgs), 0) } @@ -132,41 +133,41 @@ func TestToDeviceTableDeleteCancels(t *testing.T) { reqEv1 := newRoomKeyEvent(t, "request", "1", sender, map[string]interface{}{ "foo": "bar", }) - _, err := table.InsertMessages(destination, []json.RawMessage{reqEv1}) + _, err := table.InsertMessages(sender, destination, []json.RawMessage{reqEv1}) assertNoError(t, err) - gotMsgs, _, err := table.Messages(destination, 0, 10) + gotMsgs, _, err := table.Messages(sender, destination, 0, 10) assertNoError(t, err) bytesEqual(t, gotMsgs[0], reqEv1) reqEv2 := newRoomKeyEvent(t, "request", "2", sender, map[string]interface{}{ "foo": "baz", }) - _, err = table.InsertMessages(destination, []json.RawMessage{reqEv2}) + _, err = table.InsertMessages(sender, destination, []json.RawMessage{reqEv2}) assertNoError(t, err) - gotMsgs, _, err = table.Messages(destination, 0, 10) + gotMsgs, _, err = table.Messages(sender, destination, 0, 10) assertNoError(t, err) bytesEqual(t, gotMsgs[1], reqEv2) // now delete 1 cancelEv1 := newRoomKeyEvent(t, "request_cancellation", "1", sender, nil) - _, err = table.InsertMessages(destination, []json.RawMessage{cancelEv1}) + _, err = table.InsertMessages(sender, destination, []json.RawMessage{cancelEv1}) assertNoError(t, err) // selecting messages now returns only reqEv2 - gotMsgs, _, err = table.Messages(destination, 0, 10) + gotMsgs, _, err = table.Messages(sender, destination, 0, 10) assertNoError(t, err) bytesEqual(t, gotMsgs[0], reqEv2) // now do lots of close but not quite cancellation requests that should not match reqEv2 - _, err = table.InsertMessages(destination, []json.RawMessage{ + _, err = table.InsertMessages(sender, destination, []json.RawMessage{ newRoomKeyEvent(t, "cancellation", "2", sender, nil), // wrong action newRoomKeyEvent(t, "request_cancellation", "22", sender, nil), // wrong request ID newRoomKeyEvent(t, "request_cancellation", "2", "not_who_you_think", nil), // wrong req device id }) assertNoError(t, err) - _, err = table.InsertMessages("wrong_destination", []json.RawMessage{ // wrong destination + _, err = table.InsertMessages(sender, "wrong_destination", []json.RawMessage{ // wrong destination newRoomKeyEvent(t, "request_cancellation", "2", sender, nil), }) assertNoError(t, err) - gotMsgs, _, err = table.Messages(destination, 0, 10) + gotMsgs, _, err = table.Messages(sender, destination, 0, 10) assertNoError(t, err) bytesEqual(t, gotMsgs[0], reqEv2) // the request lives on if len(gotMsgs) != 4 { // the cancellations live on too, but not the one sent to the wrong dest @@ -175,14 +176,14 @@ func TestToDeviceTableDeleteCancels(t *testing.T) { // request + cancel in one go => nothing inserted destination2 := "DEST2" - _, err = table.InsertMessages(destination2, []json.RawMessage{ + _, err = table.InsertMessages(sender, destination2, []json.RawMessage{ newRoomKeyEvent(t, "request", "A", sender, map[string]interface{}{ "foo": "baz", }), newRoomKeyEvent(t, "request_cancellation", "A", sender, nil), }) assertNoError(t, err) - gotMsgs, _, err = table.Messages(destination2, 0, 10) + gotMsgs, _, err = table.Messages(sender, destination2, 0, 10) assertNoError(t, err) if len(gotMsgs) > 0 { t.Errorf("Got %+v want nothing", jsonArrStr(gotMsgs)) @@ -200,18 +201,18 @@ func TestToDeviceTableNoDeleteUnacks(t *testing.T) { reqEv := newRoomKeyEvent(t, "request", "1", sender, map[string]interface{}{ "foo": "bar", }) - pos, err := table.InsertMessages(destination, []json.RawMessage{reqEv}) + pos, err := table.InsertMessages(sender, destination, []json.RawMessage{reqEv}) assertNoError(t, err) // mark this position as unacked: this means the client MAY know about this request so it isn't // safe to delete it - err = table.SetUnackedPosition(destination, pos) + err = table.SetUnackedPosition(sender, destination, pos) assertNoError(t, err) // now issue a cancellation: this should NOT result in a cancellation due to protection for unacked events cancelEv := newRoomKeyEvent(t, "request_cancellation", "1", sender, nil) - _, err = table.InsertMessages(destination, []json.RawMessage{cancelEv}) + _, err = table.InsertMessages(sender, destination, []json.RawMessage{cancelEv}) assertNoError(t, err) // selecting messages returns both events - gotMsgs, _, err := table.Messages(destination, 0, 10) + gotMsgs, _, err := table.Messages(sender, destination, 0, 10) assertNoError(t, err) if len(gotMsgs) != 2 { t.Fatalf("got %d msgs, want 2: %v", len(gotMsgs), jsonArrStr(gotMsgs)) @@ -220,14 +221,14 @@ func TestToDeviceTableNoDeleteUnacks(t *testing.T) { bytesEqual(t, gotMsgs[1], cancelEv) // test that injecting another req/cancel does cause them to be deleted - _, err = table.InsertMessages(destination, []json.RawMessage{newRoomKeyEvent(t, "request", "2", sender, map[string]interface{}{ + _, err = table.InsertMessages(sender, destination, []json.RawMessage{newRoomKeyEvent(t, "request", "2", sender, map[string]interface{}{ "foo": "bar", })}) assertNoError(t, err) - _, err = table.InsertMessages(destination, []json.RawMessage{newRoomKeyEvent(t, "request_cancellation", "2", sender, nil)}) + _, err = table.InsertMessages(sender, destination, []json.RawMessage{newRoomKeyEvent(t, "request_cancellation", "2", sender, nil)}) assertNoError(t, err) // selecting messages returns the same as before - gotMsgs, _, err = table.Messages(destination, 0, 10) + gotMsgs, _, err = table.Messages(sender, destination, 0, 10) assertNoError(t, err) if len(gotMsgs) != 2 { t.Fatalf("got %d msgs, want 2: %v", len(gotMsgs), jsonArrStr(gotMsgs)) @@ -240,6 +241,7 @@ func TestToDeviceTableNoDeleteUnacks(t *testing.T) { func TestToDeviceTableBytesInEqualBytesOut(t *testing.T) { db, close := connectToDB(t) defer close() + sender := "@sendymcsendface:localhost" table := NewToDeviceTable(db) testCases := []json.RawMessage{ json.RawMessage(`{}`), @@ -250,11 +252,11 @@ func TestToDeviceTableBytesInEqualBytesOut(t *testing.T) { } var pos int64 for _, msg := range testCases { - nextPos, err := table.InsertMessages("A", []json.RawMessage{msg}) + nextPos, err := table.InsertMessages(sender, "A", []json.RawMessage{msg}) if err != nil { t.Fatalf("InsertMessages: %s", err) } - got, _, err := table.Messages("A", pos, 1) + got, _, err := table.Messages(sender, "A", pos, 1) if err != nil { t.Fatalf("Messages: %s", err) } @@ -262,11 +264,11 @@ func TestToDeviceTableBytesInEqualBytesOut(t *testing.T) { pos = nextPos } // and all at once - _, err := table.InsertMessages("B", testCases) + _, err := table.InsertMessages(sender, "B", testCases) if err != nil { t.Fatalf("InsertMessages: %s", err) } - got, _, err := table.Messages("B", 0, 100) + got, _, err := table.Messages(sender, "B", 0, 100) if err != nil { t.Fatalf("Messages: %s", err) } diff --git a/sync2/handler2/handler.go b/sync2/handler2/handler.go index ece4885..9b42306 100644 --- a/sync2/handler2/handler.go +++ b/sync2/handler2/handler.go @@ -294,7 +294,7 @@ func (h *Handler) OnReceipt(userID, roomID, ephEventType string, ephEvent json.R } func (h *Handler) AddToDeviceMessages(userID, deviceID string, msgs []json.RawMessage) { - _, err := h.Store.ToDeviceTable.InsertMessages(deviceID, msgs) + _, err := h.Store.ToDeviceTable.InsertMessages(userID, deviceID, msgs) if err != nil { logger.Err(err).Str("user", userID).Str("device", deviceID).Int("msgs", len(msgs)).Msg("V2: failed to store to-device messages") sentry.CaptureException(err) diff --git a/sync3/extensions/todevice.go b/sync3/extensions/todevice.go index c6a306f..9df95b4 100644 --- a/sync3/extensions/todevice.go +++ b/sync3/extensions/todevice.go @@ -75,7 +75,7 @@ func (r *ToDeviceRequest) ProcessInitial(ctx context.Context, res *Response, ext return } // the client is confirming messages up to `from` so delete everything up to and including it. - if err = extCtx.Store.ToDeviceTable.DeleteMessagesUpToAndIncluding(extCtx.DeviceID, from); err != nil { + if err = extCtx.Store.ToDeviceTable.DeleteMessagesUpToAndIncluding(extCtx.UserID, extCtx.DeviceID, from); err != nil { l.Err(err).Str("since", r.Since).Msg("failed to delete to-device messages up to this value") // TODO add context to sentry internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err) @@ -94,14 +94,14 @@ func (r *ToDeviceRequest) ProcessInitial(ctx context.Context, res *Response, ext internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(fmt.Errorf(errMsg)) } - msgs, upTo, err := extCtx.Store.ToDeviceTable.Messages(extCtx.DeviceID, from, int64(r.Limit)) + msgs, upTo, err := extCtx.Store.ToDeviceTable.Messages(extCtx.UserID, extCtx.DeviceID, from, int64(r.Limit)) if err != nil { l.Err(err).Int64("from", from).Msg("cannot query to-device messages") // TODO add context to sentry internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err) return } - err = extCtx.Store.ToDeviceTable.SetUnackedPosition(extCtx.DeviceID, upTo) + err = extCtx.Store.ToDeviceTable.SetUnackedPosition(extCtx.UserID, extCtx.DeviceID, upTo) if err != nil { l.Err(err).Msg("cannot set unacked position") // TODO add context to sentry