Update the to-device table

This commit is contained in:
David Robertson 2023-05-02 14:34:22 +01:00
parent 0adaf75cfc
commit e1bc972ff7
No known key found for this signature in database
GPG Key ID: 903ECE108A39DEDD
4 changed files with 66 additions and 57 deletions

View File

@ -23,6 +23,7 @@ type ToDeviceTable struct {
type ToDeviceRow struct { type ToDeviceRow struct {
Position int64 `db:"position"` Position int64 `db:"position"`
UserID string `db:"user_id"`
DeviceID string `db:"device_id"` DeviceID string `db:"device_id"`
Message string `db:"message"` Message string `db:"message"`
Type string `db:"event_type"` Type string `db:"event_type"`
@ -46,6 +47,7 @@ func NewToDeviceTable(db *sqlx.DB) *ToDeviceTable {
CREATE SEQUENCE IF NOT EXISTS syncv3_to_device_messages_seq; CREATE SEQUENCE IF NOT EXISTS syncv3_to_device_messages_seq;
CREATE TABLE IF NOT EXISTS syncv3_to_device_messages ( CREATE TABLE IF NOT EXISTS syncv3_to_device_messages (
position BIGINT NOT NULL PRIMARY KEY DEFAULT nextval('syncv3_to_device_messages_seq'), position BIGINT NOT NULL PRIMARY KEY DEFAULT nextval('syncv3_to_device_messages_seq'),
user_id TEXT NOT NULL,
device_id TEXT NOT NULL, device_id TEXT NOT NULL,
event_type TEXT NOT NULL, event_type TEXT NOT NULL,
sender TEXT NOT NULL, sender TEXT NOT NULL,
@ -55,7 +57,9 @@ func NewToDeviceTable(db *sqlx.DB) *ToDeviceTable {
action SMALLINT DEFAULT 0 -- 0 means unknown action SMALLINT DEFAULT 0 -- 0 means unknown
); );
CREATE TABLE IF NOT EXISTS syncv3_to_device_ack_pos ( CREATE TABLE IF NOT EXISTS syncv3_to_device_ack_pos (
device_id TEXT NOT NULL PRIMARY KEY, user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
PRIMARY KEY (user_id, device_id),
unack_pos BIGINT NOT NULL unack_pos BIGINT NOT NULL
); );
CREATE INDEX IF NOT EXISTS syncv3_to_device_messages_device_idx ON syncv3_to_device_messages(device_id); CREATE INDEX IF NOT EXISTS syncv3_to_device_messages_device_idx ON syncv3_to_device_messages(device_id);
@ -65,29 +69,31 @@ func NewToDeviceTable(db *sqlx.DB) *ToDeviceTable {
return &ToDeviceTable{db} return &ToDeviceTable{db}
} }
func (t *ToDeviceTable) SetUnackedPosition(deviceID string, pos int64) error { func (t *ToDeviceTable) SetUnackedPosition(userID, deviceID string, pos int64) error {
_, err := t.db.Exec(`INSERT INTO syncv3_to_device_ack_pos(device_id, unack_pos) VALUES($1,$2) ON CONFLICT (device_id) _, err := t.db.Exec(`INSERT INTO syncv3_to_device_ack_pos(user_id, device_id, unack_pos) VALUES($1,$2,$3) ON CONFLICT (user_id, device_id)
DO UPDATE SET unack_pos=$2`, deviceID, pos) DO UPDATE SET unack_pos=excluded.unack_pos`, userID, deviceID, pos)
return err return err
} }
func (t *ToDeviceTable) DeleteMessagesUpToAndIncluding(deviceID string, toIncl int64) error { func (t *ToDeviceTable) DeleteMessagesUpToAndIncluding(userID, deviceID string, toIncl int64) error {
_, err := t.db.Exec(`DELETE FROM syncv3_to_device_messages WHERE device_id = $1 AND position <= $2`, deviceID, toIncl) _, err := t.db.Exec(`DELETE FROM syncv3_to_device_messages WHERE user_id = $1 AND device_id = $2 AND position <= $3`, userID, deviceID, toIncl)
return err return err
} }
func (t *ToDeviceTable) DeleteAllMessagesForDevice(deviceID string) error { func (t *ToDeviceTable) DeleteAllMessagesForDevice(userID, deviceID string) error {
_, err := t.db.Exec(`DELETE FROM syncv3_to_device_messages WHERE device_id = $1`, deviceID) _, err := t.db.Exec(`DELETE FROM syncv3_to_device_messages WHERE user_id = $1 AND device_id = $2`, userID, deviceID)
return err return err
} }
// Query to-device messages for this device, exclusive of from and inclusive of to. If a to value is unknown, use -1. // Messages fetches up to `limit` to-device messages for this device, starting from and excluding `from`.
func (t *ToDeviceTable) Messages(deviceID string, from, limit int64) (msgs []json.RawMessage, upTo int64, err error) { // Returns the fetches messages ordered by ascending position, as well as the position of the last to-device message
// fetched.
func (t *ToDeviceTable) Messages(userID, deviceID string, from, limit int64) (msgs []json.RawMessage, upTo int64, err error) {
upTo = from upTo = from
var rows []ToDeviceRow var rows []ToDeviceRow
err = t.db.Select(&rows, err = t.db.Select(&rows,
`SELECT position, message FROM syncv3_to_device_messages WHERE device_id = $1 AND position > $2 ORDER BY position ASC LIMIT $3`, `SELECT position, message FROM syncv3_to_device_messages WHERE user_id = $1 AND device_id = $2 AND position > $3 ORDER BY position ASC LIMIT $4`,
deviceID, from, limit, userID, deviceID, from, limit,
) )
if len(rows) == 0 { if len(rows) == 0 {
return return
@ -98,18 +104,18 @@ func (t *ToDeviceTable) Messages(deviceID string, from, limit int64) (msgs []jso
m := gjson.ParseBytes(msgs[i]) m := gjson.ParseBytes(msgs[i])
msgId := m.Get(`content.org\.matrix\.msgid`).Str msgId := m.Get(`content.org\.matrix\.msgid`).Str
if msgId != "" { if msgId != "" {
logger.Info().Str("msgid", msgId).Str("device", deviceID).Msg("ToDeviceTable.Messages") logger.Info().Str("msgid", msgId).Str("user", userID).Str("device", deviceID).Msg("ToDeviceTable.Messages")
} }
} }
upTo = rows[len(rows)-1].Position upTo = rows[len(rows)-1].Position
return return
} }
func (t *ToDeviceTable) InsertMessages(deviceID string, msgs []json.RawMessage) (pos int64, err error) { func (t *ToDeviceTable) InsertMessages(userID, deviceID string, msgs []json.RawMessage) (pos int64, err error) {
var lastPos int64 var lastPos int64
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error { err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
var unackPos int64 var unackPos int64
err = txn.QueryRow(`SELECT unack_pos FROM syncv3_to_device_ack_pos WHERE device_id=$1`, deviceID).Scan(&unackPos) err = txn.QueryRow(`SELECT unack_pos FROM syncv3_to_device_ack_pos WHERE user_id=$1 AND device_id=$2`, userID, deviceID).Scan(&unackPos)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return fmt.Errorf("unable to select unacked pos: %s", err) return fmt.Errorf("unable to select unacked pos: %s", err)
} }
@ -124,6 +130,7 @@ func (t *ToDeviceTable) InsertMessages(deviceID string, msgs []json.RawMessage)
for i := range msgs { for i := range msgs {
m := gjson.ParseBytes(msgs[i]) m := gjson.ParseBytes(msgs[i])
rows[i] = ToDeviceRow{ rows[i] = ToDeviceRow{
UserID: userID,
DeviceID: deviceID, DeviceID: deviceID,
Message: string(msgs[i]), Message: string(msgs[i]),
Type: m.Get("type").Str, Type: m.Get("type").Str,
@ -131,7 +138,7 @@ func (t *ToDeviceTable) InsertMessages(deviceID string, msgs []json.RawMessage)
} }
msgId := m.Get(`content.org\.matrix\.msgid`).Str msgId := m.Get(`content.org\.matrix\.msgid`).Str
if msgId != "" { if msgId != "" {
logger.Debug().Str("msgid", msgId).Str("device", deviceID).Msg("ToDeviceTable.InsertMessages") logger.Debug().Str("msgid", msgId).Str("user", userID).Str("device", deviceID).Msg("ToDeviceTable.InsertMessages")
} }
switch rows[i].Type { switch rows[i].Type {
case "m.room_key_request": case "m.room_key_request":
@ -155,8 +162,8 @@ func (t *ToDeviceTable) InsertMessages(deviceID string, msgs []json.RawMessage)
if len(cancels) > 0 { if len(cancels) > 0 {
var cancelled []string var cancelled []string
// delete action: request events which have the same unique key, for this device inbox, only if they are not sent to the client already (unacked) // delete action: request events which have the same unique key, for this device inbox, only if they are not sent to the client already (unacked)
err = txn.Select(&cancelled, `DELETE FROM syncv3_to_device_messages WHERE unique_key = ANY($1) AND device_id = $2 AND position > $3 RETURNING unique_key`, err = txn.Select(&cancelled, `DELETE FROM syncv3_to_device_messages WHERE unique_key = ANY($1) AND user_id = $2 AND device_id = $3 AND position > $4 RETURNING unique_key`,
pq.StringArray(cancels), deviceID, unackPos) pq.StringArray(cancels), userID, deviceID, unackPos)
if err != nil { if err != nil {
return fmt.Errorf("failed to delete cancelled events: %s", err) return fmt.Errorf("failed to delete cancelled events: %s", err)
} }
@ -189,10 +196,10 @@ func (t *ToDeviceTable) InsertMessages(deviceID string, msgs []json.RawMessage)
return nil return nil
} }
chunks := sqlutil.Chunkify(6, MaxPostgresParameters, ToDeviceRowChunker(rows)) chunks := sqlutil.Chunkify(7, MaxPostgresParameters, ToDeviceRowChunker(rows))
for _, chunk := range chunks { for _, chunk := range chunks {
result, err := txn.NamedQuery(`INSERT INTO syncv3_to_device_messages (device_id, message, event_type, sender, action, unique_key) result, err := txn.NamedQuery(`INSERT INTO syncv3_to_device_messages (user_id, device_id, message, event_type, sender, action, unique_key)
VALUES (:device_id, :message, :event_type, :sender, :action, :unique_key) RETURNING position`, chunk) VALUES (:user_id, :device_id, :message, :event_type, :sender, :action, :unique_key) RETURNING position`, chunk)
if err != nil { if err != nil {
return err return err
} }

View File

@ -12,6 +12,7 @@ func TestToDeviceTable(t *testing.T) {
db, close := connectToDB(t) db, close := connectToDB(t)
defer close() defer close()
table := NewToDeviceTable(db) table := NewToDeviceTable(db)
sender := "@alice:localhost"
deviceID := "FOO" deviceID := "FOO"
var limit int64 = 999 var limit int64 = 999
msgs := []json.RawMessage{ msgs := []json.RawMessage{
@ -20,13 +21,13 @@ func TestToDeviceTable(t *testing.T) {
} }
var lastPos int64 var lastPos int64
var err error var err error
if lastPos, err = table.InsertMessages(deviceID, msgs); err != nil { if lastPos, err = table.InsertMessages(sender, deviceID, msgs); err != nil {
t.Fatalf("InsertMessages: %s", err) t.Fatalf("InsertMessages: %s", err)
} }
if lastPos != 2 { if lastPos != 2 {
t.Fatalf("InsertMessages: bad pos returned, got %d want 2", lastPos) t.Fatalf("InsertMessages: bad pos returned, got %d want 2", lastPos)
} }
gotMsgs, upTo, err := table.Messages(deviceID, 0, limit) gotMsgs, upTo, err := table.Messages(sender, deviceID, 0, limit)
if err != nil { if err != nil {
t.Fatalf("Messages: %s", err) t.Fatalf("Messages: %s", err)
} }
@ -43,7 +44,7 @@ func TestToDeviceTable(t *testing.T) {
} }
// same to= token, no messages // same to= token, no messages
gotMsgs, upTo, err = table.Messages(deviceID, lastPos, limit) gotMsgs, upTo, err = table.Messages(sender, deviceID, lastPos, limit)
if err != nil { if err != nil {
t.Fatalf("Messages: %s", err) t.Fatalf("Messages: %s", err)
} }
@ -55,7 +56,7 @@ func TestToDeviceTable(t *testing.T) {
} }
// different device ID, no messages // different device ID, no messages
gotMsgs, upTo, err = table.Messages("OTHER_DEVICE", 0, limit) gotMsgs, upTo, err = table.Messages(sender, "OTHER_DEVICE", 0, limit)
if err != nil { if err != nil {
t.Fatalf("Messages: %s", err) t.Fatalf("Messages: %s", err)
} }
@ -67,7 +68,7 @@ func TestToDeviceTable(t *testing.T) {
} }
// zero limit, no messages // zero limit, no messages
gotMsgs, upTo, err = table.Messages(deviceID, 0, 0) gotMsgs, upTo, err = table.Messages(sender, deviceID, 0, 0)
if err != nil { if err != nil {
t.Fatalf("Messages: %s", err) t.Fatalf("Messages: %s", err)
} }
@ -80,7 +81,7 @@ func TestToDeviceTable(t *testing.T) {
// lower limit, cap out // lower limit, cap out
var wantLimit int64 = 1 var wantLimit int64 = 1
gotMsgs, upTo, err = table.Messages(deviceID, 0, wantLimit) gotMsgs, upTo, err = table.Messages(sender, deviceID, 0, wantLimit)
if err != nil { if err != nil {
t.Fatalf("Messages: %s", err) t.Fatalf("Messages: %s", err)
} }
@ -93,10 +94,10 @@ func TestToDeviceTable(t *testing.T) {
} }
// delete the first message, requerying only gives 1 message // delete the first message, requerying only gives 1 message
if err := table.DeleteMessagesUpToAndIncluding(deviceID, lastPos-1); err != nil { if err := table.DeleteMessagesUpToAndIncluding(sender, deviceID, lastPos-1); err != nil {
t.Fatalf("DeleteMessagesUpTo: %s", err) t.Fatalf("DeleteMessagesUpTo: %s", err)
} }
gotMsgs, upTo, err = table.Messages(deviceID, 0, limit) gotMsgs, upTo, err = table.Messages(sender, deviceID, 0, limit)
if err != nil { if err != nil {
t.Fatalf("Messages: %s", err) t.Fatalf("Messages: %s", err)
} }
@ -114,9 +115,9 @@ func TestToDeviceTable(t *testing.T) {
t.Fatalf("Messages: deleted message but unexpected message left: got %s want %s", string(gotMsgs[0]), string(want)) t.Fatalf("Messages: deleted message but unexpected message left: got %s want %s", string(gotMsgs[0]), string(want))
} }
// delete everything and check it works // delete everything and check it works
err = table.DeleteAllMessagesForDevice(deviceID) err = table.DeleteAllMessagesForDevice(sender, deviceID)
assertNoError(t, err) assertNoError(t, err)
msgs, _, err = table.Messages(deviceID, -1, 10) msgs, _, err = table.Messages(sender, deviceID, -1, 10)
assertNoError(t, err) assertNoError(t, err)
assertVal(t, "wanted 0 msgs", len(msgs), 0) assertVal(t, "wanted 0 msgs", len(msgs), 0)
} }
@ -132,41 +133,41 @@ func TestToDeviceTableDeleteCancels(t *testing.T) {
reqEv1 := newRoomKeyEvent(t, "request", "1", sender, map[string]interface{}{ reqEv1 := newRoomKeyEvent(t, "request", "1", sender, map[string]interface{}{
"foo": "bar", "foo": "bar",
}) })
_, err := table.InsertMessages(destination, []json.RawMessage{reqEv1}) _, err := table.InsertMessages(sender, destination, []json.RawMessage{reqEv1})
assertNoError(t, err) assertNoError(t, err)
gotMsgs, _, err := table.Messages(destination, 0, 10) gotMsgs, _, err := table.Messages(sender, destination, 0, 10)
assertNoError(t, err) assertNoError(t, err)
bytesEqual(t, gotMsgs[0], reqEv1) bytesEqual(t, gotMsgs[0], reqEv1)
reqEv2 := newRoomKeyEvent(t, "request", "2", sender, map[string]interface{}{ reqEv2 := newRoomKeyEvent(t, "request", "2", sender, map[string]interface{}{
"foo": "baz", "foo": "baz",
}) })
_, err = table.InsertMessages(destination, []json.RawMessage{reqEv2}) _, err = table.InsertMessages(sender, destination, []json.RawMessage{reqEv2})
assertNoError(t, err) assertNoError(t, err)
gotMsgs, _, err = table.Messages(destination, 0, 10) gotMsgs, _, err = table.Messages(sender, destination, 0, 10)
assertNoError(t, err) assertNoError(t, err)
bytesEqual(t, gotMsgs[1], reqEv2) bytesEqual(t, gotMsgs[1], reqEv2)
// now delete 1 // now delete 1
cancelEv1 := newRoomKeyEvent(t, "request_cancellation", "1", sender, nil) cancelEv1 := newRoomKeyEvent(t, "request_cancellation", "1", sender, nil)
_, err = table.InsertMessages(destination, []json.RawMessage{cancelEv1}) _, err = table.InsertMessages(sender, destination, []json.RawMessage{cancelEv1})
assertNoError(t, err) assertNoError(t, err)
// selecting messages now returns only reqEv2 // selecting messages now returns only reqEv2
gotMsgs, _, err = table.Messages(destination, 0, 10) gotMsgs, _, err = table.Messages(sender, destination, 0, 10)
assertNoError(t, err) assertNoError(t, err)
bytesEqual(t, gotMsgs[0], reqEv2) bytesEqual(t, gotMsgs[0], reqEv2)
// now do lots of close but not quite cancellation requests that should not match reqEv2 // now do lots of close but not quite cancellation requests that should not match reqEv2
_, err = table.InsertMessages(destination, []json.RawMessage{ _, err = table.InsertMessages(sender, destination, []json.RawMessage{
newRoomKeyEvent(t, "cancellation", "2", sender, nil), // wrong action newRoomKeyEvent(t, "cancellation", "2", sender, nil), // wrong action
newRoomKeyEvent(t, "request_cancellation", "22", sender, nil), // wrong request ID newRoomKeyEvent(t, "request_cancellation", "22", sender, nil), // wrong request ID
newRoomKeyEvent(t, "request_cancellation", "2", "not_who_you_think", nil), // wrong req device id newRoomKeyEvent(t, "request_cancellation", "2", "not_who_you_think", nil), // wrong req device id
}) })
assertNoError(t, err) assertNoError(t, err)
_, err = table.InsertMessages("wrong_destination", []json.RawMessage{ // wrong destination _, err = table.InsertMessages(sender, "wrong_destination", []json.RawMessage{ // wrong destination
newRoomKeyEvent(t, "request_cancellation", "2", sender, nil), newRoomKeyEvent(t, "request_cancellation", "2", sender, nil),
}) })
assertNoError(t, err) assertNoError(t, err)
gotMsgs, _, err = table.Messages(destination, 0, 10) gotMsgs, _, err = table.Messages(sender, destination, 0, 10)
assertNoError(t, err) assertNoError(t, err)
bytesEqual(t, gotMsgs[0], reqEv2) // the request lives on bytesEqual(t, gotMsgs[0], reqEv2) // the request lives on
if len(gotMsgs) != 4 { // the cancellations live on too, but not the one sent to the wrong dest if len(gotMsgs) != 4 { // the cancellations live on too, but not the one sent to the wrong dest
@ -175,14 +176,14 @@ func TestToDeviceTableDeleteCancels(t *testing.T) {
// request + cancel in one go => nothing inserted // request + cancel in one go => nothing inserted
destination2 := "DEST2" destination2 := "DEST2"
_, err = table.InsertMessages(destination2, []json.RawMessage{ _, err = table.InsertMessages(sender, destination2, []json.RawMessage{
newRoomKeyEvent(t, "request", "A", sender, map[string]interface{}{ newRoomKeyEvent(t, "request", "A", sender, map[string]interface{}{
"foo": "baz", "foo": "baz",
}), }),
newRoomKeyEvent(t, "request_cancellation", "A", sender, nil), newRoomKeyEvent(t, "request_cancellation", "A", sender, nil),
}) })
assertNoError(t, err) assertNoError(t, err)
gotMsgs, _, err = table.Messages(destination2, 0, 10) gotMsgs, _, err = table.Messages(sender, destination2, 0, 10)
assertNoError(t, err) assertNoError(t, err)
if len(gotMsgs) > 0 { if len(gotMsgs) > 0 {
t.Errorf("Got %+v want nothing", jsonArrStr(gotMsgs)) t.Errorf("Got %+v want nothing", jsonArrStr(gotMsgs))
@ -200,18 +201,18 @@ func TestToDeviceTableNoDeleteUnacks(t *testing.T) {
reqEv := newRoomKeyEvent(t, "request", "1", sender, map[string]interface{}{ reqEv := newRoomKeyEvent(t, "request", "1", sender, map[string]interface{}{
"foo": "bar", "foo": "bar",
}) })
pos, err := table.InsertMessages(destination, []json.RawMessage{reqEv}) pos, err := table.InsertMessages(sender, destination, []json.RawMessage{reqEv})
assertNoError(t, err) assertNoError(t, err)
// mark this position as unacked: this means the client MAY know about this request so it isn't // mark this position as unacked: this means the client MAY know about this request so it isn't
// safe to delete it // safe to delete it
err = table.SetUnackedPosition(destination, pos) err = table.SetUnackedPosition(sender, destination, pos)
assertNoError(t, err) assertNoError(t, err)
// now issue a cancellation: this should NOT result in a cancellation due to protection for unacked events // now issue a cancellation: this should NOT result in a cancellation due to protection for unacked events
cancelEv := newRoomKeyEvent(t, "request_cancellation", "1", sender, nil) cancelEv := newRoomKeyEvent(t, "request_cancellation", "1", sender, nil)
_, err = table.InsertMessages(destination, []json.RawMessage{cancelEv}) _, err = table.InsertMessages(sender, destination, []json.RawMessage{cancelEv})
assertNoError(t, err) assertNoError(t, err)
// selecting messages returns both events // selecting messages returns both events
gotMsgs, _, err := table.Messages(destination, 0, 10) gotMsgs, _, err := table.Messages(sender, destination, 0, 10)
assertNoError(t, err) assertNoError(t, err)
if len(gotMsgs) != 2 { if len(gotMsgs) != 2 {
t.Fatalf("got %d msgs, want 2: %v", len(gotMsgs), jsonArrStr(gotMsgs)) t.Fatalf("got %d msgs, want 2: %v", len(gotMsgs), jsonArrStr(gotMsgs))
@ -220,14 +221,14 @@ func TestToDeviceTableNoDeleteUnacks(t *testing.T) {
bytesEqual(t, gotMsgs[1], cancelEv) bytesEqual(t, gotMsgs[1], cancelEv)
// test that injecting another req/cancel does cause them to be deleted // test that injecting another req/cancel does cause them to be deleted
_, err = table.InsertMessages(destination, []json.RawMessage{newRoomKeyEvent(t, "request", "2", sender, map[string]interface{}{ _, err = table.InsertMessages(sender, destination, []json.RawMessage{newRoomKeyEvent(t, "request", "2", sender, map[string]interface{}{
"foo": "bar", "foo": "bar",
})}) })})
assertNoError(t, err) assertNoError(t, err)
_, err = table.InsertMessages(destination, []json.RawMessage{newRoomKeyEvent(t, "request_cancellation", "2", sender, nil)}) _, err = table.InsertMessages(sender, destination, []json.RawMessage{newRoomKeyEvent(t, "request_cancellation", "2", sender, nil)})
assertNoError(t, err) assertNoError(t, err)
// selecting messages returns the same as before // selecting messages returns the same as before
gotMsgs, _, err = table.Messages(destination, 0, 10) gotMsgs, _, err = table.Messages(sender, destination, 0, 10)
assertNoError(t, err) assertNoError(t, err)
if len(gotMsgs) != 2 { if len(gotMsgs) != 2 {
t.Fatalf("got %d msgs, want 2: %v", len(gotMsgs), jsonArrStr(gotMsgs)) t.Fatalf("got %d msgs, want 2: %v", len(gotMsgs), jsonArrStr(gotMsgs))
@ -240,6 +241,7 @@ func TestToDeviceTableNoDeleteUnacks(t *testing.T) {
func TestToDeviceTableBytesInEqualBytesOut(t *testing.T) { func TestToDeviceTableBytesInEqualBytesOut(t *testing.T) {
db, close := connectToDB(t) db, close := connectToDB(t)
defer close() defer close()
sender := "@sendymcsendface:localhost"
table := NewToDeviceTable(db) table := NewToDeviceTable(db)
testCases := []json.RawMessage{ testCases := []json.RawMessage{
json.RawMessage(`{}`), json.RawMessage(`{}`),
@ -250,11 +252,11 @@ func TestToDeviceTableBytesInEqualBytesOut(t *testing.T) {
} }
var pos int64 var pos int64
for _, msg := range testCases { for _, msg := range testCases {
nextPos, err := table.InsertMessages("A", []json.RawMessage{msg}) nextPos, err := table.InsertMessages(sender, "A", []json.RawMessage{msg})
if err != nil { if err != nil {
t.Fatalf("InsertMessages: %s", err) t.Fatalf("InsertMessages: %s", err)
} }
got, _, err := table.Messages("A", pos, 1) got, _, err := table.Messages(sender, "A", pos, 1)
if err != nil { if err != nil {
t.Fatalf("Messages: %s", err) t.Fatalf("Messages: %s", err)
} }
@ -262,11 +264,11 @@ func TestToDeviceTableBytesInEqualBytesOut(t *testing.T) {
pos = nextPos pos = nextPos
} }
// and all at once // and all at once
_, err := table.InsertMessages("B", testCases) _, err := table.InsertMessages(sender, "B", testCases)
if err != nil { if err != nil {
t.Fatalf("InsertMessages: %s", err) t.Fatalf("InsertMessages: %s", err)
} }
got, _, err := table.Messages("B", 0, 100) got, _, err := table.Messages(sender, "B", 0, 100)
if err != nil { if err != nil {
t.Fatalf("Messages: %s", err) t.Fatalf("Messages: %s", err)
} }

View File

@ -294,7 +294,7 @@ func (h *Handler) OnReceipt(userID, roomID, ephEventType string, ephEvent json.R
} }
func (h *Handler) AddToDeviceMessages(userID, deviceID string, msgs []json.RawMessage) { func (h *Handler) AddToDeviceMessages(userID, deviceID string, msgs []json.RawMessage) {
_, err := h.Store.ToDeviceTable.InsertMessages(deviceID, msgs) _, err := h.Store.ToDeviceTable.InsertMessages(userID, deviceID, msgs)
if err != nil { if err != nil {
logger.Err(err).Str("user", userID).Str("device", deviceID).Int("msgs", len(msgs)).Msg("V2: failed to store to-device messages") logger.Err(err).Str("user", userID).Str("device", deviceID).Int("msgs", len(msgs)).Msg("V2: failed to store to-device messages")
sentry.CaptureException(err) sentry.CaptureException(err)

View File

@ -75,7 +75,7 @@ func (r *ToDeviceRequest) ProcessInitial(ctx context.Context, res *Response, ext
return return
} }
// the client is confirming messages up to `from` so delete everything up to and including it. // the client is confirming messages up to `from` so delete everything up to and including it.
if err = extCtx.Store.ToDeviceTable.DeleteMessagesUpToAndIncluding(extCtx.DeviceID, from); err != nil { if err = extCtx.Store.ToDeviceTable.DeleteMessagesUpToAndIncluding(extCtx.UserID, extCtx.DeviceID, from); err != nil {
l.Err(err).Str("since", r.Since).Msg("failed to delete to-device messages up to this value") l.Err(err).Str("since", r.Since).Msg("failed to delete to-device messages up to this value")
// TODO add context to sentry // TODO add context to sentry
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err) internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
@ -94,14 +94,14 @@ func (r *ToDeviceRequest) ProcessInitial(ctx context.Context, res *Response, ext
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(fmt.Errorf(errMsg)) internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(fmt.Errorf(errMsg))
} }
msgs, upTo, err := extCtx.Store.ToDeviceTable.Messages(extCtx.DeviceID, from, int64(r.Limit)) msgs, upTo, err := extCtx.Store.ToDeviceTable.Messages(extCtx.UserID, extCtx.DeviceID, from, int64(r.Limit))
if err != nil { if err != nil {
l.Err(err).Int64("from", from).Msg("cannot query to-device messages") l.Err(err).Int64("from", from).Msg("cannot query to-device messages")
// TODO add context to sentry // TODO add context to sentry
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err) internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
return return
} }
err = extCtx.Store.ToDeviceTable.SetUnackedPosition(extCtx.DeviceID, upTo) err = extCtx.Store.ToDeviceTable.SetUnackedPosition(extCtx.UserID, extCtx.DeviceID, upTo)
if err != nil { if err != nil {
l.Err(err).Msg("cannot set unacked position") l.Err(err).Msg("cannot set unacked position")
// TODO add context to sentry // TODO add context to sentry