mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
Update the to-device table
This commit is contained in:
parent
0adaf75cfc
commit
e1bc972ff7
@ -23,6 +23,7 @@ type ToDeviceTable struct {
|
|||||||
|
|
||||||
type ToDeviceRow struct {
|
type ToDeviceRow struct {
|
||||||
Position int64 `db:"position"`
|
Position int64 `db:"position"`
|
||||||
|
UserID string `db:"user_id"`
|
||||||
DeviceID string `db:"device_id"`
|
DeviceID string `db:"device_id"`
|
||||||
Message string `db:"message"`
|
Message string `db:"message"`
|
||||||
Type string `db:"event_type"`
|
Type string `db:"event_type"`
|
||||||
@ -46,6 +47,7 @@ func NewToDeviceTable(db *sqlx.DB) *ToDeviceTable {
|
|||||||
CREATE SEQUENCE IF NOT EXISTS syncv3_to_device_messages_seq;
|
CREATE SEQUENCE IF NOT EXISTS syncv3_to_device_messages_seq;
|
||||||
CREATE TABLE IF NOT EXISTS syncv3_to_device_messages (
|
CREATE TABLE IF NOT EXISTS syncv3_to_device_messages (
|
||||||
position BIGINT NOT NULL PRIMARY KEY DEFAULT nextval('syncv3_to_device_messages_seq'),
|
position BIGINT NOT NULL PRIMARY KEY DEFAULT nextval('syncv3_to_device_messages_seq'),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
device_id TEXT NOT NULL,
|
device_id TEXT NOT NULL,
|
||||||
event_type TEXT NOT NULL,
|
event_type TEXT NOT NULL,
|
||||||
sender TEXT NOT NULL,
|
sender TEXT NOT NULL,
|
||||||
@ -55,7 +57,9 @@ func NewToDeviceTable(db *sqlx.DB) *ToDeviceTable {
|
|||||||
action SMALLINT DEFAULT 0 -- 0 means unknown
|
action SMALLINT DEFAULT 0 -- 0 means unknown
|
||||||
);
|
);
|
||||||
CREATE TABLE IF NOT EXISTS syncv3_to_device_ack_pos (
|
CREATE TABLE IF NOT EXISTS syncv3_to_device_ack_pos (
|
||||||
device_id TEXT NOT NULL PRIMARY KEY,
|
user_id TEXT NOT NULL,
|
||||||
|
device_id TEXT NOT NULL,
|
||||||
|
PRIMARY KEY (user_id, device_id),
|
||||||
unack_pos BIGINT NOT NULL
|
unack_pos BIGINT NOT NULL
|
||||||
);
|
);
|
||||||
CREATE INDEX IF NOT EXISTS syncv3_to_device_messages_device_idx ON syncv3_to_device_messages(device_id);
|
CREATE INDEX IF NOT EXISTS syncv3_to_device_messages_device_idx ON syncv3_to_device_messages(device_id);
|
||||||
@ -65,29 +69,31 @@ func NewToDeviceTable(db *sqlx.DB) *ToDeviceTable {
|
|||||||
return &ToDeviceTable{db}
|
return &ToDeviceTable{db}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ToDeviceTable) SetUnackedPosition(deviceID string, pos int64) error {
|
func (t *ToDeviceTable) SetUnackedPosition(userID, deviceID string, pos int64) error {
|
||||||
_, err := t.db.Exec(`INSERT INTO syncv3_to_device_ack_pos(device_id, unack_pos) VALUES($1,$2) ON CONFLICT (device_id)
|
_, err := t.db.Exec(`INSERT INTO syncv3_to_device_ack_pos(user_id, device_id, unack_pos) VALUES($1,$2,$3) ON CONFLICT (user_id, device_id)
|
||||||
DO UPDATE SET unack_pos=$2`, deviceID, pos)
|
DO UPDATE SET unack_pos=excluded.unack_pos`, userID, deviceID, pos)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ToDeviceTable) DeleteMessagesUpToAndIncluding(deviceID string, toIncl int64) error {
|
func (t *ToDeviceTable) DeleteMessagesUpToAndIncluding(userID, deviceID string, toIncl int64) error {
|
||||||
_, err := t.db.Exec(`DELETE FROM syncv3_to_device_messages WHERE device_id = $1 AND position <= $2`, deviceID, toIncl)
|
_, err := t.db.Exec(`DELETE FROM syncv3_to_device_messages WHERE user_id = $1 AND device_id = $2 AND position <= $3`, userID, deviceID, toIncl)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ToDeviceTable) DeleteAllMessagesForDevice(deviceID string) error {
|
func (t *ToDeviceTable) DeleteAllMessagesForDevice(userID, deviceID string) error {
|
||||||
_, err := t.db.Exec(`DELETE FROM syncv3_to_device_messages WHERE device_id = $1`, deviceID)
|
_, err := t.db.Exec(`DELETE FROM syncv3_to_device_messages WHERE user_id = $1 AND device_id = $2`, userID, deviceID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query to-device messages for this device, exclusive of from and inclusive of to. If a to value is unknown, use -1.
|
// Messages fetches up to `limit` to-device messages for this device, starting from and excluding `from`.
|
||||||
func (t *ToDeviceTable) Messages(deviceID string, from, limit int64) (msgs []json.RawMessage, upTo int64, err error) {
|
// Returns the fetches messages ordered by ascending position, as well as the position of the last to-device message
|
||||||
|
// fetched.
|
||||||
|
func (t *ToDeviceTable) Messages(userID, deviceID string, from, limit int64) (msgs []json.RawMessage, upTo int64, err error) {
|
||||||
upTo = from
|
upTo = from
|
||||||
var rows []ToDeviceRow
|
var rows []ToDeviceRow
|
||||||
err = t.db.Select(&rows,
|
err = t.db.Select(&rows,
|
||||||
`SELECT position, message FROM syncv3_to_device_messages WHERE device_id = $1 AND position > $2 ORDER BY position ASC LIMIT $3`,
|
`SELECT position, message FROM syncv3_to_device_messages WHERE user_id = $1 AND device_id = $2 AND position > $3 ORDER BY position ASC LIMIT $4`,
|
||||||
deviceID, from, limit,
|
userID, deviceID, from, limit,
|
||||||
)
|
)
|
||||||
if len(rows) == 0 {
|
if len(rows) == 0 {
|
||||||
return
|
return
|
||||||
@ -98,18 +104,18 @@ func (t *ToDeviceTable) Messages(deviceID string, from, limit int64) (msgs []jso
|
|||||||
m := gjson.ParseBytes(msgs[i])
|
m := gjson.ParseBytes(msgs[i])
|
||||||
msgId := m.Get(`content.org\.matrix\.msgid`).Str
|
msgId := m.Get(`content.org\.matrix\.msgid`).Str
|
||||||
if msgId != "" {
|
if msgId != "" {
|
||||||
logger.Info().Str("msgid", msgId).Str("device", deviceID).Msg("ToDeviceTable.Messages")
|
logger.Info().Str("msgid", msgId).Str("user", userID).Str("device", deviceID).Msg("ToDeviceTable.Messages")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
upTo = rows[len(rows)-1].Position
|
upTo = rows[len(rows)-1].Position
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ToDeviceTable) InsertMessages(deviceID string, msgs []json.RawMessage) (pos int64, err error) {
|
func (t *ToDeviceTable) InsertMessages(userID, deviceID string, msgs []json.RawMessage) (pos int64, err error) {
|
||||||
var lastPos int64
|
var lastPos int64
|
||||||
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
|
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
|
||||||
var unackPos int64
|
var unackPos int64
|
||||||
err = txn.QueryRow(`SELECT unack_pos FROM syncv3_to_device_ack_pos WHERE device_id=$1`, deviceID).Scan(&unackPos)
|
err = txn.QueryRow(`SELECT unack_pos FROM syncv3_to_device_ack_pos WHERE user_id=$1 AND device_id=$2`, userID, deviceID).Scan(&unackPos)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
return fmt.Errorf("unable to select unacked pos: %s", err)
|
return fmt.Errorf("unable to select unacked pos: %s", err)
|
||||||
}
|
}
|
||||||
@ -124,6 +130,7 @@ func (t *ToDeviceTable) InsertMessages(deviceID string, msgs []json.RawMessage)
|
|||||||
for i := range msgs {
|
for i := range msgs {
|
||||||
m := gjson.ParseBytes(msgs[i])
|
m := gjson.ParseBytes(msgs[i])
|
||||||
rows[i] = ToDeviceRow{
|
rows[i] = ToDeviceRow{
|
||||||
|
UserID: userID,
|
||||||
DeviceID: deviceID,
|
DeviceID: deviceID,
|
||||||
Message: string(msgs[i]),
|
Message: string(msgs[i]),
|
||||||
Type: m.Get("type").Str,
|
Type: m.Get("type").Str,
|
||||||
@ -131,7 +138,7 @@ func (t *ToDeviceTable) InsertMessages(deviceID string, msgs []json.RawMessage)
|
|||||||
}
|
}
|
||||||
msgId := m.Get(`content.org\.matrix\.msgid`).Str
|
msgId := m.Get(`content.org\.matrix\.msgid`).Str
|
||||||
if msgId != "" {
|
if msgId != "" {
|
||||||
logger.Debug().Str("msgid", msgId).Str("device", deviceID).Msg("ToDeviceTable.InsertMessages")
|
logger.Debug().Str("msgid", msgId).Str("user", userID).Str("device", deviceID).Msg("ToDeviceTable.InsertMessages")
|
||||||
}
|
}
|
||||||
switch rows[i].Type {
|
switch rows[i].Type {
|
||||||
case "m.room_key_request":
|
case "m.room_key_request":
|
||||||
@ -155,8 +162,8 @@ func (t *ToDeviceTable) InsertMessages(deviceID string, msgs []json.RawMessage)
|
|||||||
if len(cancels) > 0 {
|
if len(cancels) > 0 {
|
||||||
var cancelled []string
|
var cancelled []string
|
||||||
// delete action: request events which have the same unique key, for this device inbox, only if they are not sent to the client already (unacked)
|
// delete action: request events which have the same unique key, for this device inbox, only if they are not sent to the client already (unacked)
|
||||||
err = txn.Select(&cancelled, `DELETE FROM syncv3_to_device_messages WHERE unique_key = ANY($1) AND device_id = $2 AND position > $3 RETURNING unique_key`,
|
err = txn.Select(&cancelled, `DELETE FROM syncv3_to_device_messages WHERE unique_key = ANY($1) AND user_id = $2 AND device_id = $3 AND position > $4 RETURNING unique_key`,
|
||||||
pq.StringArray(cancels), deviceID, unackPos)
|
pq.StringArray(cancels), userID, deviceID, unackPos)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to delete cancelled events: %s", err)
|
return fmt.Errorf("failed to delete cancelled events: %s", err)
|
||||||
}
|
}
|
||||||
@ -189,10 +196,10 @@ func (t *ToDeviceTable) InsertMessages(deviceID string, msgs []json.RawMessage)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
chunks := sqlutil.Chunkify(6, MaxPostgresParameters, ToDeviceRowChunker(rows))
|
chunks := sqlutil.Chunkify(7, MaxPostgresParameters, ToDeviceRowChunker(rows))
|
||||||
for _, chunk := range chunks {
|
for _, chunk := range chunks {
|
||||||
result, err := txn.NamedQuery(`INSERT INTO syncv3_to_device_messages (device_id, message, event_type, sender, action, unique_key)
|
result, err := txn.NamedQuery(`INSERT INTO syncv3_to_device_messages (user_id, device_id, message, event_type, sender, action, unique_key)
|
||||||
VALUES (:device_id, :message, :event_type, :sender, :action, :unique_key) RETURNING position`, chunk)
|
VALUES (:user_id, :device_id, :message, :event_type, :sender, :action, :unique_key) RETURNING position`, chunk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -12,6 +12,7 @@ func TestToDeviceTable(t *testing.T) {
|
|||||||
db, close := connectToDB(t)
|
db, close := connectToDB(t)
|
||||||
defer close()
|
defer close()
|
||||||
table := NewToDeviceTable(db)
|
table := NewToDeviceTable(db)
|
||||||
|
sender := "@alice:localhost"
|
||||||
deviceID := "FOO"
|
deviceID := "FOO"
|
||||||
var limit int64 = 999
|
var limit int64 = 999
|
||||||
msgs := []json.RawMessage{
|
msgs := []json.RawMessage{
|
||||||
@ -20,13 +21,13 @@ func TestToDeviceTable(t *testing.T) {
|
|||||||
}
|
}
|
||||||
var lastPos int64
|
var lastPos int64
|
||||||
var err error
|
var err error
|
||||||
if lastPos, err = table.InsertMessages(deviceID, msgs); err != nil {
|
if lastPos, err = table.InsertMessages(sender, deviceID, msgs); err != nil {
|
||||||
t.Fatalf("InsertMessages: %s", err)
|
t.Fatalf("InsertMessages: %s", err)
|
||||||
}
|
}
|
||||||
if lastPos != 2 {
|
if lastPos != 2 {
|
||||||
t.Fatalf("InsertMessages: bad pos returned, got %d want 2", lastPos)
|
t.Fatalf("InsertMessages: bad pos returned, got %d want 2", lastPos)
|
||||||
}
|
}
|
||||||
gotMsgs, upTo, err := table.Messages(deviceID, 0, limit)
|
gotMsgs, upTo, err := table.Messages(sender, deviceID, 0, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Messages: %s", err)
|
t.Fatalf("Messages: %s", err)
|
||||||
}
|
}
|
||||||
@ -43,7 +44,7 @@ func TestToDeviceTable(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// same to= token, no messages
|
// same to= token, no messages
|
||||||
gotMsgs, upTo, err = table.Messages(deviceID, lastPos, limit)
|
gotMsgs, upTo, err = table.Messages(sender, deviceID, lastPos, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Messages: %s", err)
|
t.Fatalf("Messages: %s", err)
|
||||||
}
|
}
|
||||||
@ -55,7 +56,7 @@ func TestToDeviceTable(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// different device ID, no messages
|
// different device ID, no messages
|
||||||
gotMsgs, upTo, err = table.Messages("OTHER_DEVICE", 0, limit)
|
gotMsgs, upTo, err = table.Messages(sender, "OTHER_DEVICE", 0, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Messages: %s", err)
|
t.Fatalf("Messages: %s", err)
|
||||||
}
|
}
|
||||||
@ -67,7 +68,7 @@ func TestToDeviceTable(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// zero limit, no messages
|
// zero limit, no messages
|
||||||
gotMsgs, upTo, err = table.Messages(deviceID, 0, 0)
|
gotMsgs, upTo, err = table.Messages(sender, deviceID, 0, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Messages: %s", err)
|
t.Fatalf("Messages: %s", err)
|
||||||
}
|
}
|
||||||
@ -80,7 +81,7 @@ func TestToDeviceTable(t *testing.T) {
|
|||||||
|
|
||||||
// lower limit, cap out
|
// lower limit, cap out
|
||||||
var wantLimit int64 = 1
|
var wantLimit int64 = 1
|
||||||
gotMsgs, upTo, err = table.Messages(deviceID, 0, wantLimit)
|
gotMsgs, upTo, err = table.Messages(sender, deviceID, 0, wantLimit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Messages: %s", err)
|
t.Fatalf("Messages: %s", err)
|
||||||
}
|
}
|
||||||
@ -93,10 +94,10 @@ func TestToDeviceTable(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// delete the first message, requerying only gives 1 message
|
// delete the first message, requerying only gives 1 message
|
||||||
if err := table.DeleteMessagesUpToAndIncluding(deviceID, lastPos-1); err != nil {
|
if err := table.DeleteMessagesUpToAndIncluding(sender, deviceID, lastPos-1); err != nil {
|
||||||
t.Fatalf("DeleteMessagesUpTo: %s", err)
|
t.Fatalf("DeleteMessagesUpTo: %s", err)
|
||||||
}
|
}
|
||||||
gotMsgs, upTo, err = table.Messages(deviceID, 0, limit)
|
gotMsgs, upTo, err = table.Messages(sender, deviceID, 0, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Messages: %s", err)
|
t.Fatalf("Messages: %s", err)
|
||||||
}
|
}
|
||||||
@ -114,9 +115,9 @@ func TestToDeviceTable(t *testing.T) {
|
|||||||
t.Fatalf("Messages: deleted message but unexpected message left: got %s want %s", string(gotMsgs[0]), string(want))
|
t.Fatalf("Messages: deleted message but unexpected message left: got %s want %s", string(gotMsgs[0]), string(want))
|
||||||
}
|
}
|
||||||
// delete everything and check it works
|
// delete everything and check it works
|
||||||
err = table.DeleteAllMessagesForDevice(deviceID)
|
err = table.DeleteAllMessagesForDevice(sender, deviceID)
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
msgs, _, err = table.Messages(deviceID, -1, 10)
|
msgs, _, err = table.Messages(sender, deviceID, -1, 10)
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
assertVal(t, "wanted 0 msgs", len(msgs), 0)
|
assertVal(t, "wanted 0 msgs", len(msgs), 0)
|
||||||
}
|
}
|
||||||
@ -132,41 +133,41 @@ func TestToDeviceTableDeleteCancels(t *testing.T) {
|
|||||||
reqEv1 := newRoomKeyEvent(t, "request", "1", sender, map[string]interface{}{
|
reqEv1 := newRoomKeyEvent(t, "request", "1", sender, map[string]interface{}{
|
||||||
"foo": "bar",
|
"foo": "bar",
|
||||||
})
|
})
|
||||||
_, err := table.InsertMessages(destination, []json.RawMessage{reqEv1})
|
_, err := table.InsertMessages(sender, destination, []json.RawMessage{reqEv1})
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
gotMsgs, _, err := table.Messages(destination, 0, 10)
|
gotMsgs, _, err := table.Messages(sender, destination, 0, 10)
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
bytesEqual(t, gotMsgs[0], reqEv1)
|
bytesEqual(t, gotMsgs[0], reqEv1)
|
||||||
reqEv2 := newRoomKeyEvent(t, "request", "2", sender, map[string]interface{}{
|
reqEv2 := newRoomKeyEvent(t, "request", "2", sender, map[string]interface{}{
|
||||||
"foo": "baz",
|
"foo": "baz",
|
||||||
})
|
})
|
||||||
_, err = table.InsertMessages(destination, []json.RawMessage{reqEv2})
|
_, err = table.InsertMessages(sender, destination, []json.RawMessage{reqEv2})
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
gotMsgs, _, err = table.Messages(destination, 0, 10)
|
gotMsgs, _, err = table.Messages(sender, destination, 0, 10)
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
bytesEqual(t, gotMsgs[1], reqEv2)
|
bytesEqual(t, gotMsgs[1], reqEv2)
|
||||||
|
|
||||||
// now delete 1
|
// now delete 1
|
||||||
cancelEv1 := newRoomKeyEvent(t, "request_cancellation", "1", sender, nil)
|
cancelEv1 := newRoomKeyEvent(t, "request_cancellation", "1", sender, nil)
|
||||||
_, err = table.InsertMessages(destination, []json.RawMessage{cancelEv1})
|
_, err = table.InsertMessages(sender, destination, []json.RawMessage{cancelEv1})
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
// selecting messages now returns only reqEv2
|
// selecting messages now returns only reqEv2
|
||||||
gotMsgs, _, err = table.Messages(destination, 0, 10)
|
gotMsgs, _, err = table.Messages(sender, destination, 0, 10)
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
bytesEqual(t, gotMsgs[0], reqEv2)
|
bytesEqual(t, gotMsgs[0], reqEv2)
|
||||||
|
|
||||||
// now do lots of close but not quite cancellation requests that should not match reqEv2
|
// now do lots of close but not quite cancellation requests that should not match reqEv2
|
||||||
_, err = table.InsertMessages(destination, []json.RawMessage{
|
_, err = table.InsertMessages(sender, destination, []json.RawMessage{
|
||||||
newRoomKeyEvent(t, "cancellation", "2", sender, nil), // wrong action
|
newRoomKeyEvent(t, "cancellation", "2", sender, nil), // wrong action
|
||||||
newRoomKeyEvent(t, "request_cancellation", "22", sender, nil), // wrong request ID
|
newRoomKeyEvent(t, "request_cancellation", "22", sender, nil), // wrong request ID
|
||||||
newRoomKeyEvent(t, "request_cancellation", "2", "not_who_you_think", nil), // wrong req device id
|
newRoomKeyEvent(t, "request_cancellation", "2", "not_who_you_think", nil), // wrong req device id
|
||||||
})
|
})
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
_, err = table.InsertMessages("wrong_destination", []json.RawMessage{ // wrong destination
|
_, err = table.InsertMessages(sender, "wrong_destination", []json.RawMessage{ // wrong destination
|
||||||
newRoomKeyEvent(t, "request_cancellation", "2", sender, nil),
|
newRoomKeyEvent(t, "request_cancellation", "2", sender, nil),
|
||||||
})
|
})
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
gotMsgs, _, err = table.Messages(destination, 0, 10)
|
gotMsgs, _, err = table.Messages(sender, destination, 0, 10)
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
bytesEqual(t, gotMsgs[0], reqEv2) // the request lives on
|
bytesEqual(t, gotMsgs[0], reqEv2) // the request lives on
|
||||||
if len(gotMsgs) != 4 { // the cancellations live on too, but not the one sent to the wrong dest
|
if len(gotMsgs) != 4 { // the cancellations live on too, but not the one sent to the wrong dest
|
||||||
@ -175,14 +176,14 @@ func TestToDeviceTableDeleteCancels(t *testing.T) {
|
|||||||
|
|
||||||
// request + cancel in one go => nothing inserted
|
// request + cancel in one go => nothing inserted
|
||||||
destination2 := "DEST2"
|
destination2 := "DEST2"
|
||||||
_, err = table.InsertMessages(destination2, []json.RawMessage{
|
_, err = table.InsertMessages(sender, destination2, []json.RawMessage{
|
||||||
newRoomKeyEvent(t, "request", "A", sender, map[string]interface{}{
|
newRoomKeyEvent(t, "request", "A", sender, map[string]interface{}{
|
||||||
"foo": "baz",
|
"foo": "baz",
|
||||||
}),
|
}),
|
||||||
newRoomKeyEvent(t, "request_cancellation", "A", sender, nil),
|
newRoomKeyEvent(t, "request_cancellation", "A", sender, nil),
|
||||||
})
|
})
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
gotMsgs, _, err = table.Messages(destination2, 0, 10)
|
gotMsgs, _, err = table.Messages(sender, destination2, 0, 10)
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
if len(gotMsgs) > 0 {
|
if len(gotMsgs) > 0 {
|
||||||
t.Errorf("Got %+v want nothing", jsonArrStr(gotMsgs))
|
t.Errorf("Got %+v want nothing", jsonArrStr(gotMsgs))
|
||||||
@ -200,18 +201,18 @@ func TestToDeviceTableNoDeleteUnacks(t *testing.T) {
|
|||||||
reqEv := newRoomKeyEvent(t, "request", "1", sender, map[string]interface{}{
|
reqEv := newRoomKeyEvent(t, "request", "1", sender, map[string]interface{}{
|
||||||
"foo": "bar",
|
"foo": "bar",
|
||||||
})
|
})
|
||||||
pos, err := table.InsertMessages(destination, []json.RawMessage{reqEv})
|
pos, err := table.InsertMessages(sender, destination, []json.RawMessage{reqEv})
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
// mark this position as unacked: this means the client MAY know about this request so it isn't
|
// mark this position as unacked: this means the client MAY know about this request so it isn't
|
||||||
// safe to delete it
|
// safe to delete it
|
||||||
err = table.SetUnackedPosition(destination, pos)
|
err = table.SetUnackedPosition(sender, destination, pos)
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
// now issue a cancellation: this should NOT result in a cancellation due to protection for unacked events
|
// now issue a cancellation: this should NOT result in a cancellation due to protection for unacked events
|
||||||
cancelEv := newRoomKeyEvent(t, "request_cancellation", "1", sender, nil)
|
cancelEv := newRoomKeyEvent(t, "request_cancellation", "1", sender, nil)
|
||||||
_, err = table.InsertMessages(destination, []json.RawMessage{cancelEv})
|
_, err = table.InsertMessages(sender, destination, []json.RawMessage{cancelEv})
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
// selecting messages returns both events
|
// selecting messages returns both events
|
||||||
gotMsgs, _, err := table.Messages(destination, 0, 10)
|
gotMsgs, _, err := table.Messages(sender, destination, 0, 10)
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
if len(gotMsgs) != 2 {
|
if len(gotMsgs) != 2 {
|
||||||
t.Fatalf("got %d msgs, want 2: %v", len(gotMsgs), jsonArrStr(gotMsgs))
|
t.Fatalf("got %d msgs, want 2: %v", len(gotMsgs), jsonArrStr(gotMsgs))
|
||||||
@ -220,14 +221,14 @@ func TestToDeviceTableNoDeleteUnacks(t *testing.T) {
|
|||||||
bytesEqual(t, gotMsgs[1], cancelEv)
|
bytesEqual(t, gotMsgs[1], cancelEv)
|
||||||
|
|
||||||
// test that injecting another req/cancel does cause them to be deleted
|
// test that injecting another req/cancel does cause them to be deleted
|
||||||
_, err = table.InsertMessages(destination, []json.RawMessage{newRoomKeyEvent(t, "request", "2", sender, map[string]interface{}{
|
_, err = table.InsertMessages(sender, destination, []json.RawMessage{newRoomKeyEvent(t, "request", "2", sender, map[string]interface{}{
|
||||||
"foo": "bar",
|
"foo": "bar",
|
||||||
})})
|
})})
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
_, err = table.InsertMessages(destination, []json.RawMessage{newRoomKeyEvent(t, "request_cancellation", "2", sender, nil)})
|
_, err = table.InsertMessages(sender, destination, []json.RawMessage{newRoomKeyEvent(t, "request_cancellation", "2", sender, nil)})
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
// selecting messages returns the same as before
|
// selecting messages returns the same as before
|
||||||
gotMsgs, _, err = table.Messages(destination, 0, 10)
|
gotMsgs, _, err = table.Messages(sender, destination, 0, 10)
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
if len(gotMsgs) != 2 {
|
if len(gotMsgs) != 2 {
|
||||||
t.Fatalf("got %d msgs, want 2: %v", len(gotMsgs), jsonArrStr(gotMsgs))
|
t.Fatalf("got %d msgs, want 2: %v", len(gotMsgs), jsonArrStr(gotMsgs))
|
||||||
@ -240,6 +241,7 @@ func TestToDeviceTableNoDeleteUnacks(t *testing.T) {
|
|||||||
func TestToDeviceTableBytesInEqualBytesOut(t *testing.T) {
|
func TestToDeviceTableBytesInEqualBytesOut(t *testing.T) {
|
||||||
db, close := connectToDB(t)
|
db, close := connectToDB(t)
|
||||||
defer close()
|
defer close()
|
||||||
|
sender := "@sendymcsendface:localhost"
|
||||||
table := NewToDeviceTable(db)
|
table := NewToDeviceTable(db)
|
||||||
testCases := []json.RawMessage{
|
testCases := []json.RawMessage{
|
||||||
json.RawMessage(`{}`),
|
json.RawMessage(`{}`),
|
||||||
@ -250,11 +252,11 @@ func TestToDeviceTableBytesInEqualBytesOut(t *testing.T) {
|
|||||||
}
|
}
|
||||||
var pos int64
|
var pos int64
|
||||||
for _, msg := range testCases {
|
for _, msg := range testCases {
|
||||||
nextPos, err := table.InsertMessages("A", []json.RawMessage{msg})
|
nextPos, err := table.InsertMessages(sender, "A", []json.RawMessage{msg})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("InsertMessages: %s", err)
|
t.Fatalf("InsertMessages: %s", err)
|
||||||
}
|
}
|
||||||
got, _, err := table.Messages("A", pos, 1)
|
got, _, err := table.Messages(sender, "A", pos, 1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Messages: %s", err)
|
t.Fatalf("Messages: %s", err)
|
||||||
}
|
}
|
||||||
@ -262,11 +264,11 @@ func TestToDeviceTableBytesInEqualBytesOut(t *testing.T) {
|
|||||||
pos = nextPos
|
pos = nextPos
|
||||||
}
|
}
|
||||||
// and all at once
|
// and all at once
|
||||||
_, err := table.InsertMessages("B", testCases)
|
_, err := table.InsertMessages(sender, "B", testCases)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("InsertMessages: %s", err)
|
t.Fatalf("InsertMessages: %s", err)
|
||||||
}
|
}
|
||||||
got, _, err := table.Messages("B", 0, 100)
|
got, _, err := table.Messages(sender, "B", 0, 100)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Messages: %s", err)
|
t.Fatalf("Messages: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -294,7 +294,7 @@ func (h *Handler) OnReceipt(userID, roomID, ephEventType string, ephEvent json.R
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) AddToDeviceMessages(userID, deviceID string, msgs []json.RawMessage) {
|
func (h *Handler) AddToDeviceMessages(userID, deviceID string, msgs []json.RawMessage) {
|
||||||
_, err := h.Store.ToDeviceTable.InsertMessages(deviceID, msgs)
|
_, err := h.Store.ToDeviceTable.InsertMessages(userID, deviceID, msgs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Err(err).Str("user", userID).Str("device", deviceID).Int("msgs", len(msgs)).Msg("V2: failed to store to-device messages")
|
logger.Err(err).Str("user", userID).Str("device", deviceID).Int("msgs", len(msgs)).Msg("V2: failed to store to-device messages")
|
||||||
sentry.CaptureException(err)
|
sentry.CaptureException(err)
|
||||||
|
@ -75,7 +75,7 @@ func (r *ToDeviceRequest) ProcessInitial(ctx context.Context, res *Response, ext
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
// the client is confirming messages up to `from` so delete everything up to and including it.
|
// the client is confirming messages up to `from` so delete everything up to and including it.
|
||||||
if err = extCtx.Store.ToDeviceTable.DeleteMessagesUpToAndIncluding(extCtx.DeviceID, from); err != nil {
|
if err = extCtx.Store.ToDeviceTable.DeleteMessagesUpToAndIncluding(extCtx.UserID, extCtx.DeviceID, from); err != nil {
|
||||||
l.Err(err).Str("since", r.Since).Msg("failed to delete to-device messages up to this value")
|
l.Err(err).Str("since", r.Since).Msg("failed to delete to-device messages up to this value")
|
||||||
// TODO add context to sentry
|
// TODO add context to sentry
|
||||||
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
|
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
|
||||||
@ -94,14 +94,14 @@ func (r *ToDeviceRequest) ProcessInitial(ctx context.Context, res *Response, ext
|
|||||||
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(fmt.Errorf(errMsg))
|
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(fmt.Errorf(errMsg))
|
||||||
}
|
}
|
||||||
|
|
||||||
msgs, upTo, err := extCtx.Store.ToDeviceTable.Messages(extCtx.DeviceID, from, int64(r.Limit))
|
msgs, upTo, err := extCtx.Store.ToDeviceTable.Messages(extCtx.UserID, extCtx.DeviceID, from, int64(r.Limit))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Err(err).Int64("from", from).Msg("cannot query to-device messages")
|
l.Err(err).Int64("from", from).Msg("cannot query to-device messages")
|
||||||
// TODO add context to sentry
|
// TODO add context to sentry
|
||||||
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
|
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = extCtx.Store.ToDeviceTable.SetUnackedPosition(extCtx.DeviceID, upTo)
|
err = extCtx.Store.ToDeviceTable.SetUnackedPosition(extCtx.UserID, extCtx.DeviceID, upTo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Err(err).Msg("cannot set unacked position")
|
l.Err(err).Msg("cannot set unacked position")
|
||||||
// TODO add context to sentry
|
// TODO add context to sentry
|
||||||
|
Loading…
x
Reference in New Issue
Block a user