Update the to-device table

This commit is contained in:
David Robertson 2023-05-02 14:34:22 +01:00
parent 0adaf75cfc
commit e1bc972ff7
No known key found for this signature in database
GPG Key ID: 903ECE108A39DEDD
4 changed files with 66 additions and 57 deletions

View File

@ -23,6 +23,7 @@ type ToDeviceTable struct {
type ToDeviceRow struct {
Position int64 `db:"position"`
UserID string `db:"user_id"`
DeviceID string `db:"device_id"`
Message string `db:"message"`
Type string `db:"event_type"`
@ -46,6 +47,7 @@ func NewToDeviceTable(db *sqlx.DB) *ToDeviceTable {
CREATE SEQUENCE IF NOT EXISTS syncv3_to_device_messages_seq;
CREATE TABLE IF NOT EXISTS syncv3_to_device_messages (
position BIGINT NOT NULL PRIMARY KEY DEFAULT nextval('syncv3_to_device_messages_seq'),
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
event_type TEXT NOT NULL,
sender TEXT NOT NULL,
@ -55,7 +57,9 @@ func NewToDeviceTable(db *sqlx.DB) *ToDeviceTable {
action SMALLINT DEFAULT 0 -- 0 means unknown
);
CREATE TABLE IF NOT EXISTS syncv3_to_device_ack_pos (
device_id TEXT NOT NULL PRIMARY KEY,
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
PRIMARY KEY (user_id, device_id),
unack_pos BIGINT NOT NULL
);
CREATE INDEX IF NOT EXISTS syncv3_to_device_messages_device_idx ON syncv3_to_device_messages(device_id);
@ -65,29 +69,31 @@ func NewToDeviceTable(db *sqlx.DB) *ToDeviceTable {
return &ToDeviceTable{db}
}
func (t *ToDeviceTable) SetUnackedPosition(deviceID string, pos int64) error {
_, err := t.db.Exec(`INSERT INTO syncv3_to_device_ack_pos(device_id, unack_pos) VALUES($1,$2) ON CONFLICT (device_id)
DO UPDATE SET unack_pos=$2`, deviceID, pos)
func (t *ToDeviceTable) SetUnackedPosition(userID, deviceID string, pos int64) error {
_, err := t.db.Exec(`INSERT INTO syncv3_to_device_ack_pos(user_id, device_id, unack_pos) VALUES($1,$2,$3) ON CONFLICT (user_id, device_id)
DO UPDATE SET unack_pos=excluded.unack_pos`, userID, deviceID, pos)
return err
}
func (t *ToDeviceTable) DeleteMessagesUpToAndIncluding(deviceID string, toIncl int64) error {
_, err := t.db.Exec(`DELETE FROM syncv3_to_device_messages WHERE device_id = $1 AND position <= $2`, deviceID, toIncl)
func (t *ToDeviceTable) DeleteMessagesUpToAndIncluding(userID, deviceID string, toIncl int64) error {
_, err := t.db.Exec(`DELETE FROM syncv3_to_device_messages WHERE user_id = $1 AND device_id = $2 AND position <= $3`, userID, deviceID, toIncl)
return err
}
func (t *ToDeviceTable) DeleteAllMessagesForDevice(deviceID string) error {
_, err := t.db.Exec(`DELETE FROM syncv3_to_device_messages WHERE device_id = $1`, deviceID)
func (t *ToDeviceTable) DeleteAllMessagesForDevice(userID, deviceID string) error {
_, err := t.db.Exec(`DELETE FROM syncv3_to_device_messages WHERE user_id = $1 AND device_id = $2`, userID, deviceID)
return err
}
// Query to-device messages for this device, exclusive of from and inclusive of to. If a to value is unknown, use -1.
func (t *ToDeviceTable) Messages(deviceID string, from, limit int64) (msgs []json.RawMessage, upTo int64, err error) {
// Messages fetches up to `limit` to-device messages for this device, starting from and excluding `from`.
// Returns the fetches messages ordered by ascending position, as well as the position of the last to-device message
// fetched.
func (t *ToDeviceTable) Messages(userID, deviceID string, from, limit int64) (msgs []json.RawMessage, upTo int64, err error) {
upTo = from
var rows []ToDeviceRow
err = t.db.Select(&rows,
`SELECT position, message FROM syncv3_to_device_messages WHERE device_id = $1 AND position > $2 ORDER BY position ASC LIMIT $3`,
deviceID, from, limit,
`SELECT position, message FROM syncv3_to_device_messages WHERE user_id = $1 AND device_id = $2 AND position > $3 ORDER BY position ASC LIMIT $4`,
userID, deviceID, from, limit,
)
if len(rows) == 0 {
return
@ -98,18 +104,18 @@ func (t *ToDeviceTable) Messages(deviceID string, from, limit int64) (msgs []jso
m := gjson.ParseBytes(msgs[i])
msgId := m.Get(`content.org\.matrix\.msgid`).Str
if msgId != "" {
logger.Info().Str("msgid", msgId).Str("device", deviceID).Msg("ToDeviceTable.Messages")
logger.Info().Str("msgid", msgId).Str("user", userID).Str("device", deviceID).Msg("ToDeviceTable.Messages")
}
}
upTo = rows[len(rows)-1].Position
return
}
func (t *ToDeviceTable) InsertMessages(deviceID string, msgs []json.RawMessage) (pos int64, err error) {
func (t *ToDeviceTable) InsertMessages(userID, deviceID string, msgs []json.RawMessage) (pos int64, err error) {
var lastPos int64
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
var unackPos int64
err = txn.QueryRow(`SELECT unack_pos FROM syncv3_to_device_ack_pos WHERE device_id=$1`, deviceID).Scan(&unackPos)
err = txn.QueryRow(`SELECT unack_pos FROM syncv3_to_device_ack_pos WHERE user_id=$1 AND device_id=$2`, userID, deviceID).Scan(&unackPos)
if err != nil && err != sql.ErrNoRows {
return fmt.Errorf("unable to select unacked pos: %s", err)
}
@ -124,6 +130,7 @@ func (t *ToDeviceTable) InsertMessages(deviceID string, msgs []json.RawMessage)
for i := range msgs {
m := gjson.ParseBytes(msgs[i])
rows[i] = ToDeviceRow{
UserID: userID,
DeviceID: deviceID,
Message: string(msgs[i]),
Type: m.Get("type").Str,
@ -131,7 +138,7 @@ func (t *ToDeviceTable) InsertMessages(deviceID string, msgs []json.RawMessage)
}
msgId := m.Get(`content.org\.matrix\.msgid`).Str
if msgId != "" {
logger.Debug().Str("msgid", msgId).Str("device", deviceID).Msg("ToDeviceTable.InsertMessages")
logger.Debug().Str("msgid", msgId).Str("user", userID).Str("device", deviceID).Msg("ToDeviceTable.InsertMessages")
}
switch rows[i].Type {
case "m.room_key_request":
@ -155,8 +162,8 @@ func (t *ToDeviceTable) InsertMessages(deviceID string, msgs []json.RawMessage)
if len(cancels) > 0 {
var cancelled []string
// delete action: request events which have the same unique key, for this device inbox, only if they are not sent to the client already (unacked)
err = txn.Select(&cancelled, `DELETE FROM syncv3_to_device_messages WHERE unique_key = ANY($1) AND device_id = $2 AND position > $3 RETURNING unique_key`,
pq.StringArray(cancels), deviceID, unackPos)
err = txn.Select(&cancelled, `DELETE FROM syncv3_to_device_messages WHERE unique_key = ANY($1) AND user_id = $2 AND device_id = $3 AND position > $4 RETURNING unique_key`,
pq.StringArray(cancels), userID, deviceID, unackPos)
if err != nil {
return fmt.Errorf("failed to delete cancelled events: %s", err)
}
@ -189,10 +196,10 @@ func (t *ToDeviceTable) InsertMessages(deviceID string, msgs []json.RawMessage)
return nil
}
chunks := sqlutil.Chunkify(6, MaxPostgresParameters, ToDeviceRowChunker(rows))
chunks := sqlutil.Chunkify(7, MaxPostgresParameters, ToDeviceRowChunker(rows))
for _, chunk := range chunks {
result, err := txn.NamedQuery(`INSERT INTO syncv3_to_device_messages (device_id, message, event_type, sender, action, unique_key)
VALUES (:device_id, :message, :event_type, :sender, :action, :unique_key) RETURNING position`, chunk)
result, err := txn.NamedQuery(`INSERT INTO syncv3_to_device_messages (user_id, device_id, message, event_type, sender, action, unique_key)
VALUES (:user_id, :device_id, :message, :event_type, :sender, :action, :unique_key) RETURNING position`, chunk)
if err != nil {
return err
}

View File

@ -12,6 +12,7 @@ func TestToDeviceTable(t *testing.T) {
db, close := connectToDB(t)
defer close()
table := NewToDeviceTable(db)
sender := "@alice:localhost"
deviceID := "FOO"
var limit int64 = 999
msgs := []json.RawMessage{
@ -20,13 +21,13 @@ func TestToDeviceTable(t *testing.T) {
}
var lastPos int64
var err error
if lastPos, err = table.InsertMessages(deviceID, msgs); err != nil {
if lastPos, err = table.InsertMessages(sender, deviceID, msgs); err != nil {
t.Fatalf("InsertMessages: %s", err)
}
if lastPos != 2 {
t.Fatalf("InsertMessages: bad pos returned, got %d want 2", lastPos)
}
gotMsgs, upTo, err := table.Messages(deviceID, 0, limit)
gotMsgs, upTo, err := table.Messages(sender, deviceID, 0, limit)
if err != nil {
t.Fatalf("Messages: %s", err)
}
@ -43,7 +44,7 @@ func TestToDeviceTable(t *testing.T) {
}
// same to= token, no messages
gotMsgs, upTo, err = table.Messages(deviceID, lastPos, limit)
gotMsgs, upTo, err = table.Messages(sender, deviceID, lastPos, limit)
if err != nil {
t.Fatalf("Messages: %s", err)
}
@ -55,7 +56,7 @@ func TestToDeviceTable(t *testing.T) {
}
// different device ID, no messages
gotMsgs, upTo, err = table.Messages("OTHER_DEVICE", 0, limit)
gotMsgs, upTo, err = table.Messages(sender, "OTHER_DEVICE", 0, limit)
if err != nil {
t.Fatalf("Messages: %s", err)
}
@ -67,7 +68,7 @@ func TestToDeviceTable(t *testing.T) {
}
// zero limit, no messages
gotMsgs, upTo, err = table.Messages(deviceID, 0, 0)
gotMsgs, upTo, err = table.Messages(sender, deviceID, 0, 0)
if err != nil {
t.Fatalf("Messages: %s", err)
}
@ -80,7 +81,7 @@ func TestToDeviceTable(t *testing.T) {
// lower limit, cap out
var wantLimit int64 = 1
gotMsgs, upTo, err = table.Messages(deviceID, 0, wantLimit)
gotMsgs, upTo, err = table.Messages(sender, deviceID, 0, wantLimit)
if err != nil {
t.Fatalf("Messages: %s", err)
}
@ -93,10 +94,10 @@ func TestToDeviceTable(t *testing.T) {
}
// delete the first message, requerying only gives 1 message
if err := table.DeleteMessagesUpToAndIncluding(deviceID, lastPos-1); err != nil {
if err := table.DeleteMessagesUpToAndIncluding(sender, deviceID, lastPos-1); err != nil {
t.Fatalf("DeleteMessagesUpTo: %s", err)
}
gotMsgs, upTo, err = table.Messages(deviceID, 0, limit)
gotMsgs, upTo, err = table.Messages(sender, deviceID, 0, limit)
if err != nil {
t.Fatalf("Messages: %s", err)
}
@ -114,9 +115,9 @@ func TestToDeviceTable(t *testing.T) {
t.Fatalf("Messages: deleted message but unexpected message left: got %s want %s", string(gotMsgs[0]), string(want))
}
// delete everything and check it works
err = table.DeleteAllMessagesForDevice(deviceID)
err = table.DeleteAllMessagesForDevice(sender, deviceID)
assertNoError(t, err)
msgs, _, err = table.Messages(deviceID, -1, 10)
msgs, _, err = table.Messages(sender, deviceID, -1, 10)
assertNoError(t, err)
assertVal(t, "wanted 0 msgs", len(msgs), 0)
}
@ -132,41 +133,41 @@ func TestToDeviceTableDeleteCancels(t *testing.T) {
reqEv1 := newRoomKeyEvent(t, "request", "1", sender, map[string]interface{}{
"foo": "bar",
})
_, err := table.InsertMessages(destination, []json.RawMessage{reqEv1})
_, err := table.InsertMessages(sender, destination, []json.RawMessage{reqEv1})
assertNoError(t, err)
gotMsgs, _, err := table.Messages(destination, 0, 10)
gotMsgs, _, err := table.Messages(sender, destination, 0, 10)
assertNoError(t, err)
bytesEqual(t, gotMsgs[0], reqEv1)
reqEv2 := newRoomKeyEvent(t, "request", "2", sender, map[string]interface{}{
"foo": "baz",
})
_, err = table.InsertMessages(destination, []json.RawMessage{reqEv2})
_, err = table.InsertMessages(sender, destination, []json.RawMessage{reqEv2})
assertNoError(t, err)
gotMsgs, _, err = table.Messages(destination, 0, 10)
gotMsgs, _, err = table.Messages(sender, destination, 0, 10)
assertNoError(t, err)
bytesEqual(t, gotMsgs[1], reqEv2)
// now delete 1
cancelEv1 := newRoomKeyEvent(t, "request_cancellation", "1", sender, nil)
_, err = table.InsertMessages(destination, []json.RawMessage{cancelEv1})
_, err = table.InsertMessages(sender, destination, []json.RawMessage{cancelEv1})
assertNoError(t, err)
// selecting messages now returns only reqEv2
gotMsgs, _, err = table.Messages(destination, 0, 10)
gotMsgs, _, err = table.Messages(sender, destination, 0, 10)
assertNoError(t, err)
bytesEqual(t, gotMsgs[0], reqEv2)
// now do lots of close but not quite cancellation requests that should not match reqEv2
_, err = table.InsertMessages(destination, []json.RawMessage{
_, err = table.InsertMessages(sender, destination, []json.RawMessage{
newRoomKeyEvent(t, "cancellation", "2", sender, nil), // wrong action
newRoomKeyEvent(t, "request_cancellation", "22", sender, nil), // wrong request ID
newRoomKeyEvent(t, "request_cancellation", "2", "not_who_you_think", nil), // wrong req device id
})
assertNoError(t, err)
_, err = table.InsertMessages("wrong_destination", []json.RawMessage{ // wrong destination
_, err = table.InsertMessages(sender, "wrong_destination", []json.RawMessage{ // wrong destination
newRoomKeyEvent(t, "request_cancellation", "2", sender, nil),
})
assertNoError(t, err)
gotMsgs, _, err = table.Messages(destination, 0, 10)
gotMsgs, _, err = table.Messages(sender, destination, 0, 10)
assertNoError(t, err)
bytesEqual(t, gotMsgs[0], reqEv2) // the request lives on
if len(gotMsgs) != 4 { // the cancellations live on too, but not the one sent to the wrong dest
@ -175,14 +176,14 @@ func TestToDeviceTableDeleteCancels(t *testing.T) {
// request + cancel in one go => nothing inserted
destination2 := "DEST2"
_, err = table.InsertMessages(destination2, []json.RawMessage{
_, err = table.InsertMessages(sender, destination2, []json.RawMessage{
newRoomKeyEvent(t, "request", "A", sender, map[string]interface{}{
"foo": "baz",
}),
newRoomKeyEvent(t, "request_cancellation", "A", sender, nil),
})
assertNoError(t, err)
gotMsgs, _, err = table.Messages(destination2, 0, 10)
gotMsgs, _, err = table.Messages(sender, destination2, 0, 10)
assertNoError(t, err)
if len(gotMsgs) > 0 {
t.Errorf("Got %+v want nothing", jsonArrStr(gotMsgs))
@ -200,18 +201,18 @@ func TestToDeviceTableNoDeleteUnacks(t *testing.T) {
reqEv := newRoomKeyEvent(t, "request", "1", sender, map[string]interface{}{
"foo": "bar",
})
pos, err := table.InsertMessages(destination, []json.RawMessage{reqEv})
pos, err := table.InsertMessages(sender, destination, []json.RawMessage{reqEv})
assertNoError(t, err)
// mark this position as unacked: this means the client MAY know about this request so it isn't
// safe to delete it
err = table.SetUnackedPosition(destination, pos)
err = table.SetUnackedPosition(sender, destination, pos)
assertNoError(t, err)
// now issue a cancellation: this should NOT result in a cancellation due to protection for unacked events
cancelEv := newRoomKeyEvent(t, "request_cancellation", "1", sender, nil)
_, err = table.InsertMessages(destination, []json.RawMessage{cancelEv})
_, err = table.InsertMessages(sender, destination, []json.RawMessage{cancelEv})
assertNoError(t, err)
// selecting messages returns both events
gotMsgs, _, err := table.Messages(destination, 0, 10)
gotMsgs, _, err := table.Messages(sender, destination, 0, 10)
assertNoError(t, err)
if len(gotMsgs) != 2 {
t.Fatalf("got %d msgs, want 2: %v", len(gotMsgs), jsonArrStr(gotMsgs))
@ -220,14 +221,14 @@ func TestToDeviceTableNoDeleteUnacks(t *testing.T) {
bytesEqual(t, gotMsgs[1], cancelEv)
// test that injecting another req/cancel does cause them to be deleted
_, err = table.InsertMessages(destination, []json.RawMessage{newRoomKeyEvent(t, "request", "2", sender, map[string]interface{}{
_, err = table.InsertMessages(sender, destination, []json.RawMessage{newRoomKeyEvent(t, "request", "2", sender, map[string]interface{}{
"foo": "bar",
})})
assertNoError(t, err)
_, err = table.InsertMessages(destination, []json.RawMessage{newRoomKeyEvent(t, "request_cancellation", "2", sender, nil)})
_, err = table.InsertMessages(sender, destination, []json.RawMessage{newRoomKeyEvent(t, "request_cancellation", "2", sender, nil)})
assertNoError(t, err)
// selecting messages returns the same as before
gotMsgs, _, err = table.Messages(destination, 0, 10)
gotMsgs, _, err = table.Messages(sender, destination, 0, 10)
assertNoError(t, err)
if len(gotMsgs) != 2 {
t.Fatalf("got %d msgs, want 2: %v", len(gotMsgs), jsonArrStr(gotMsgs))
@ -240,6 +241,7 @@ func TestToDeviceTableNoDeleteUnacks(t *testing.T) {
func TestToDeviceTableBytesInEqualBytesOut(t *testing.T) {
db, close := connectToDB(t)
defer close()
sender := "@sendymcsendface:localhost"
table := NewToDeviceTable(db)
testCases := []json.RawMessage{
json.RawMessage(`{}`),
@ -250,11 +252,11 @@ func TestToDeviceTableBytesInEqualBytesOut(t *testing.T) {
}
var pos int64
for _, msg := range testCases {
nextPos, err := table.InsertMessages("A", []json.RawMessage{msg})
nextPos, err := table.InsertMessages(sender, "A", []json.RawMessage{msg})
if err != nil {
t.Fatalf("InsertMessages: %s", err)
}
got, _, err := table.Messages("A", pos, 1)
got, _, err := table.Messages(sender, "A", pos, 1)
if err != nil {
t.Fatalf("Messages: %s", err)
}
@ -262,11 +264,11 @@ func TestToDeviceTableBytesInEqualBytesOut(t *testing.T) {
pos = nextPos
}
// and all at once
_, err := table.InsertMessages("B", testCases)
_, err := table.InsertMessages(sender, "B", testCases)
if err != nil {
t.Fatalf("InsertMessages: %s", err)
}
got, _, err := table.Messages("B", 0, 100)
got, _, err := table.Messages(sender, "B", 0, 100)
if err != nil {
t.Fatalf("Messages: %s", err)
}

View File

@ -294,7 +294,7 @@ func (h *Handler) OnReceipt(userID, roomID, ephEventType string, ephEvent json.R
}
func (h *Handler) AddToDeviceMessages(userID, deviceID string, msgs []json.RawMessage) {
_, err := h.Store.ToDeviceTable.InsertMessages(deviceID, msgs)
_, err := h.Store.ToDeviceTable.InsertMessages(userID, deviceID, msgs)
if err != nil {
logger.Err(err).Str("user", userID).Str("device", deviceID).Int("msgs", len(msgs)).Msg("V2: failed to store to-device messages")
sentry.CaptureException(err)

View File

@ -75,7 +75,7 @@ func (r *ToDeviceRequest) ProcessInitial(ctx context.Context, res *Response, ext
return
}
// the client is confirming messages up to `from` so delete everything up to and including it.
if err = extCtx.Store.ToDeviceTable.DeleteMessagesUpToAndIncluding(extCtx.DeviceID, from); err != nil {
if err = extCtx.Store.ToDeviceTable.DeleteMessagesUpToAndIncluding(extCtx.UserID, extCtx.DeviceID, from); err != nil {
l.Err(err).Str("since", r.Since).Msg("failed to delete to-device messages up to this value")
// TODO add context to sentry
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
@ -94,14 +94,14 @@ func (r *ToDeviceRequest) ProcessInitial(ctx context.Context, res *Response, ext
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(fmt.Errorf(errMsg))
}
msgs, upTo, err := extCtx.Store.ToDeviceTable.Messages(extCtx.DeviceID, from, int64(r.Limit))
msgs, upTo, err := extCtx.Store.ToDeviceTable.Messages(extCtx.UserID, extCtx.DeviceID, from, int64(r.Limit))
if err != nil {
l.Err(err).Int64("from", from).Msg("cannot query to-device messages")
// TODO add context to sentry
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
return
}
err = extCtx.Store.ToDeviceTable.SetUnackedPosition(extCtx.DeviceID, upTo)
err = extCtx.Store.ToDeviceTable.SetUnackedPosition(extCtx.UserID, extCtx.DeviceID, upTo)
if err != nil {
l.Err(err).Msg("cannot set unacked position")
// TODO add context to sentry