2021-08-03 09:33:38 +01:00
|
|
|
package state
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
2021-08-03 12:06:09 +01:00
|
|
|
"encoding/json"
|
2021-08-03 09:33:38 +01:00
|
|
|
"testing"
|
|
|
|
|
2022-12-21 11:28:43 +00:00
|
|
|
"github.com/tidwall/gjson"
|
2021-08-03 09:33:38 +01:00
|
|
|
)
|
|
|
|
|
|
|
|
func TestToDeviceTable(t *testing.T) {
|
2023-01-18 14:54:26 +00:00
|
|
|
db, close := connectToDB(t)
|
|
|
|
defer close()
|
2021-08-03 09:33:38 +01:00
|
|
|
table := NewToDeviceTable(db)
|
2023-05-02 14:34:22 +01:00
|
|
|
sender := "@alice:localhost"
|
2021-08-03 09:33:38 +01:00
|
|
|
deviceID := "FOO"
|
2021-08-05 10:54:04 +01:00
|
|
|
var limit int64 = 999
|
2021-12-14 11:51:47 +00:00
|
|
|
msgs := []json.RawMessage{
|
|
|
|
json.RawMessage(`{"sender":"alice","type":"something","content":{"foo":"bar"}}`),
|
|
|
|
json.RawMessage(`{"sender":"bob","type":"something","content":{"foo":"bar2"}}`),
|
2021-08-03 09:33:38 +01:00
|
|
|
}
|
2021-08-03 12:06:09 +01:00
|
|
|
var lastPos int64
|
2023-01-18 14:54:26 +00:00
|
|
|
var err error
|
2023-05-02 14:34:22 +01:00
|
|
|
if lastPos, err = table.InsertMessages(sender, deviceID, msgs); err != nil {
|
2021-08-03 09:33:38 +01:00
|
|
|
t.Fatalf("InsertMessages: %s", err)
|
|
|
|
}
|
2021-08-03 12:06:09 +01:00
|
|
|
if lastPos != 2 {
|
|
|
|
t.Fatalf("InsertMessages: bad pos returned, got %d want 2", lastPos)
|
|
|
|
}
|
2023-05-02 14:34:22 +01:00
|
|
|
gotMsgs, upTo, err := table.Messages(sender, deviceID, 0, limit)
|
2021-12-14 14:38:39 +00:00
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Messages: %s", err)
|
|
|
|
}
|
|
|
|
if upTo != lastPos {
|
|
|
|
t.Errorf("Message: got up to %d want %d", upTo, lastPos)
|
|
|
|
}
|
|
|
|
if len(gotMsgs) != len(msgs) {
|
|
|
|
t.Fatalf("Messages: got %d messages, want %d", len(gotMsgs), len(msgs))
|
|
|
|
}
|
|
|
|
for i := range msgs {
|
|
|
|
if !bytes.Equal(msgs[i], gotMsgs[i]) {
|
2021-08-03 09:33:38 +01:00
|
|
|
t.Fatalf("Messages: got %+v want %+v", gotMsgs[i], msgs[i])
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// same to= token, no messages
|
2023-05-02 14:34:22 +01:00
|
|
|
gotMsgs, upTo, err = table.Messages(sender, deviceID, lastPos, limit)
|
2021-08-03 09:33:38 +01:00
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Messages: %s", err)
|
|
|
|
}
|
2021-08-05 10:54:04 +01:00
|
|
|
if upTo != lastPos {
|
|
|
|
t.Errorf("Message: got up to %d want %d", upTo, lastPos)
|
|
|
|
}
|
2021-08-03 09:33:38 +01:00
|
|
|
if len(gotMsgs) > 0 {
|
|
|
|
t.Fatalf("Messages: got %d messages, want none", len(gotMsgs))
|
|
|
|
}
|
|
|
|
|
|
|
|
// different device ID, no messages
|
2023-05-02 14:34:22 +01:00
|
|
|
gotMsgs, upTo, err = table.Messages(sender, "OTHER_DEVICE", 0, limit)
|
2021-08-03 09:33:38 +01:00
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Messages: %s", err)
|
|
|
|
}
|
2022-12-14 18:53:55 +00:00
|
|
|
if upTo != 0 {
|
|
|
|
t.Errorf("Message: got up to %d want %d", upTo, 0)
|
2021-08-05 10:54:04 +01:00
|
|
|
}
|
2021-08-03 09:33:38 +01:00
|
|
|
if len(gotMsgs) > 0 {
|
|
|
|
t.Fatalf("Messages: got %d messages, want none", len(gotMsgs))
|
|
|
|
}
|
2021-08-05 10:54:04 +01:00
|
|
|
|
|
|
|
// zero limit, no messages
|
2023-05-02 14:34:22 +01:00
|
|
|
gotMsgs, upTo, err = table.Messages(sender, deviceID, 0, 0)
|
2021-08-05 10:54:04 +01:00
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Messages: %s", err)
|
|
|
|
}
|
2022-12-14 18:53:55 +00:00
|
|
|
if upTo != 0 {
|
|
|
|
t.Errorf("Message: got up to %d want %d", upTo, 0)
|
2021-08-05 10:54:04 +01:00
|
|
|
}
|
|
|
|
if len(gotMsgs) > 0 {
|
|
|
|
t.Fatalf("Messages: got %d messages, want none", len(gotMsgs))
|
|
|
|
}
|
|
|
|
|
|
|
|
// lower limit, cap out
|
|
|
|
var wantLimit int64 = 1
|
2023-05-02 14:34:22 +01:00
|
|
|
gotMsgs, upTo, err = table.Messages(sender, deviceID, 0, wantLimit)
|
2021-08-05 10:54:04 +01:00
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Messages: %s", err)
|
|
|
|
}
|
|
|
|
// we inserted 2 messages, and request a limit of 1 so the position should be one before
|
|
|
|
if upTo != (lastPos - 1) {
|
|
|
|
t.Errorf("Message: got up to %d want %d", upTo, lastPos-1)
|
|
|
|
}
|
|
|
|
if int64(len(gotMsgs)) != wantLimit {
|
|
|
|
t.Fatalf("Messages: got %d messages, want %d", len(gotMsgs), wantLimit)
|
|
|
|
}
|
2021-08-05 12:50:03 +01:00
|
|
|
|
|
|
|
// delete the first message, requerying only gives 1 message
|
2023-05-02 14:34:22 +01:00
|
|
|
if err := table.DeleteMessagesUpToAndIncluding(sender, deviceID, lastPos-1); err != nil {
|
2021-08-05 12:50:03 +01:00
|
|
|
t.Fatalf("DeleteMessagesUpTo: %s", err)
|
|
|
|
}
|
2023-05-02 14:34:22 +01:00
|
|
|
gotMsgs, upTo, err = table.Messages(sender, deviceID, 0, limit)
|
2021-08-05 12:50:03 +01:00
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Messages: %s", err)
|
|
|
|
}
|
|
|
|
if upTo != lastPos {
|
|
|
|
t.Errorf("Message: got up to %d want %d", upTo, lastPos)
|
|
|
|
}
|
|
|
|
if len(gotMsgs) != 1 {
|
|
|
|
t.Fatalf("Messages: got %d messages, want %d", len(gotMsgs), 1)
|
|
|
|
}
|
|
|
|
want, err := json.Marshal(msgs[1])
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("failed to marshal msg: %s", err)
|
|
|
|
}
|
|
|
|
if !bytes.Equal(gotMsgs[0], want) {
|
|
|
|
t.Fatalf("Messages: deleted message but unexpected message left: got %s want %s", string(gotMsgs[0]), string(want))
|
|
|
|
}
|
2023-03-01 16:56:04 +00:00
|
|
|
// delete everything and check it works
|
2023-05-02 14:34:22 +01:00
|
|
|
err = table.DeleteAllMessagesForDevice(sender, deviceID)
|
2023-03-01 16:56:04 +00:00
|
|
|
assertNoError(t, err)
|
2023-05-02 14:34:22 +01:00
|
|
|
msgs, _, err = table.Messages(sender, deviceID, -1, 10)
|
2023-03-01 16:56:04 +00:00
|
|
|
assertNoError(t, err)
|
|
|
|
assertVal(t, "wanted 0 msgs", len(msgs), 0)
|
2021-08-03 09:33:38 +01:00
|
|
|
}
|
2022-12-14 18:53:55 +00:00
|
|
|
|
|
|
|
// Test that https://github.com/uhoreg/matrix-doc/blob/drop-stale-to-device/proposals/3944-drop-stale-to-device.md works for m.room_key_request
|
|
|
|
func TestToDeviceTableDeleteCancels(t *testing.T) {
|
2023-01-18 14:54:26 +00:00
|
|
|
db, close := connectToDB(t)
|
|
|
|
defer close()
|
2022-12-14 18:53:55 +00:00
|
|
|
sender := "SENDER"
|
|
|
|
destination := "DEST"
|
|
|
|
table := NewToDeviceTable(db)
|
|
|
|
// insert 2 requests
|
|
|
|
reqEv1 := newRoomKeyEvent(t, "request", "1", sender, map[string]interface{}{
|
|
|
|
"foo": "bar",
|
|
|
|
})
|
2023-05-02 14:34:22 +01:00
|
|
|
_, err := table.InsertMessages(sender, destination, []json.RawMessage{reqEv1})
|
2022-12-14 18:53:55 +00:00
|
|
|
assertNoError(t, err)
|
2023-05-02 14:34:22 +01:00
|
|
|
gotMsgs, _, err := table.Messages(sender, destination, 0, 10)
|
2022-12-14 18:53:55 +00:00
|
|
|
assertNoError(t, err)
|
|
|
|
bytesEqual(t, gotMsgs[0], reqEv1)
|
|
|
|
reqEv2 := newRoomKeyEvent(t, "request", "2", sender, map[string]interface{}{
|
|
|
|
"foo": "baz",
|
|
|
|
})
|
2023-05-02 14:34:22 +01:00
|
|
|
_, err = table.InsertMessages(sender, destination, []json.RawMessage{reqEv2})
|
2022-12-14 18:53:55 +00:00
|
|
|
assertNoError(t, err)
|
2023-05-02 14:34:22 +01:00
|
|
|
gotMsgs, _, err = table.Messages(sender, destination, 0, 10)
|
2022-12-14 18:53:55 +00:00
|
|
|
assertNoError(t, err)
|
|
|
|
bytesEqual(t, gotMsgs[1], reqEv2)
|
|
|
|
|
|
|
|
// now delete 1
|
|
|
|
cancelEv1 := newRoomKeyEvent(t, "request_cancellation", "1", sender, nil)
|
2023-05-02 14:34:22 +01:00
|
|
|
_, err = table.InsertMessages(sender, destination, []json.RawMessage{cancelEv1})
|
2022-12-14 18:53:55 +00:00
|
|
|
assertNoError(t, err)
|
|
|
|
// selecting messages now returns only reqEv2
|
2023-05-02 14:34:22 +01:00
|
|
|
gotMsgs, _, err = table.Messages(sender, destination, 0, 10)
|
2022-12-14 18:53:55 +00:00
|
|
|
assertNoError(t, err)
|
|
|
|
bytesEqual(t, gotMsgs[0], reqEv2)
|
|
|
|
|
|
|
|
// now do lots of close but not quite cancellation requests that should not match reqEv2
|
2023-05-02 14:34:22 +01:00
|
|
|
_, err = table.InsertMessages(sender, destination, []json.RawMessage{
|
2022-12-14 18:53:55 +00:00
|
|
|
newRoomKeyEvent(t, "cancellation", "2", sender, nil), // wrong action
|
|
|
|
newRoomKeyEvent(t, "request_cancellation", "22", sender, nil), // wrong request ID
|
|
|
|
newRoomKeyEvent(t, "request_cancellation", "2", "not_who_you_think", nil), // wrong req device id
|
|
|
|
})
|
|
|
|
assertNoError(t, err)
|
2023-05-02 14:34:22 +01:00
|
|
|
_, err = table.InsertMessages(sender, "wrong_destination", []json.RawMessage{ // wrong destination
|
2022-12-14 18:53:55 +00:00
|
|
|
newRoomKeyEvent(t, "request_cancellation", "2", sender, nil),
|
|
|
|
})
|
|
|
|
assertNoError(t, err)
|
2023-05-02 14:34:22 +01:00
|
|
|
gotMsgs, _, err = table.Messages(sender, destination, 0, 10)
|
2022-12-14 18:53:55 +00:00
|
|
|
assertNoError(t, err)
|
|
|
|
bytesEqual(t, gotMsgs[0], reqEv2) // the request lives on
|
|
|
|
if len(gotMsgs) != 4 { // the cancellations live on too, but not the one sent to the wrong dest
|
|
|
|
t.Errorf("got %d msgs, want 4", len(gotMsgs))
|
|
|
|
}
|
|
|
|
|
|
|
|
// request + cancel in one go => nothing inserted
|
|
|
|
destination2 := "DEST2"
|
2023-05-02 14:34:22 +01:00
|
|
|
_, err = table.InsertMessages(sender, destination2, []json.RawMessage{
|
2022-12-14 18:53:55 +00:00
|
|
|
newRoomKeyEvent(t, "request", "A", sender, map[string]interface{}{
|
|
|
|
"foo": "baz",
|
|
|
|
}),
|
|
|
|
newRoomKeyEvent(t, "request_cancellation", "A", sender, nil),
|
|
|
|
})
|
|
|
|
assertNoError(t, err)
|
2023-05-02 14:34:22 +01:00
|
|
|
gotMsgs, _, err = table.Messages(sender, destination2, 0, 10)
|
2022-12-14 18:53:55 +00:00
|
|
|
assertNoError(t, err)
|
|
|
|
if len(gotMsgs) > 0 {
|
|
|
|
t.Errorf("Got %+v want nothing", jsonArrStr(gotMsgs))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Test that unacked events are safe from deletion
|
|
|
|
func TestToDeviceTableNoDeleteUnacks(t *testing.T) {
|
2023-01-18 14:54:26 +00:00
|
|
|
db, close := connectToDB(t)
|
|
|
|
defer close()
|
2022-12-14 18:53:55 +00:00
|
|
|
sender := "SENDER2"
|
|
|
|
destination := "DEST2"
|
|
|
|
table := NewToDeviceTable(db)
|
|
|
|
// insert request
|
|
|
|
reqEv := newRoomKeyEvent(t, "request", "1", sender, map[string]interface{}{
|
|
|
|
"foo": "bar",
|
|
|
|
})
|
2023-05-02 14:34:22 +01:00
|
|
|
pos, err := table.InsertMessages(sender, destination, []json.RawMessage{reqEv})
|
2022-12-14 18:53:55 +00:00
|
|
|
assertNoError(t, err)
|
|
|
|
// mark this position as unacked: this means the client MAY know about this request so it isn't
|
|
|
|
// safe to delete it
|
2023-05-02 14:34:22 +01:00
|
|
|
err = table.SetUnackedPosition(sender, destination, pos)
|
2022-12-14 18:53:55 +00:00
|
|
|
assertNoError(t, err)
|
|
|
|
// now issue a cancellation: this should NOT result in a cancellation due to protection for unacked events
|
|
|
|
cancelEv := newRoomKeyEvent(t, "request_cancellation", "1", sender, nil)
|
2023-05-02 14:34:22 +01:00
|
|
|
_, err = table.InsertMessages(sender, destination, []json.RawMessage{cancelEv})
|
2022-12-14 18:53:55 +00:00
|
|
|
assertNoError(t, err)
|
|
|
|
// selecting messages returns both events
|
2023-05-02 14:34:22 +01:00
|
|
|
gotMsgs, _, err := table.Messages(sender, destination, 0, 10)
|
2022-12-14 18:53:55 +00:00
|
|
|
assertNoError(t, err)
|
|
|
|
if len(gotMsgs) != 2 {
|
|
|
|
t.Fatalf("got %d msgs, want 2: %v", len(gotMsgs), jsonArrStr(gotMsgs))
|
|
|
|
}
|
|
|
|
bytesEqual(t, gotMsgs[0], reqEv)
|
|
|
|
bytesEqual(t, gotMsgs[1], cancelEv)
|
|
|
|
|
|
|
|
// test that injecting another req/cancel does cause them to be deleted
|
2023-05-02 14:34:22 +01:00
|
|
|
_, err = table.InsertMessages(sender, destination, []json.RawMessage{newRoomKeyEvent(t, "request", "2", sender, map[string]interface{}{
|
2022-12-14 18:53:55 +00:00
|
|
|
"foo": "bar",
|
|
|
|
})})
|
|
|
|
assertNoError(t, err)
|
2023-05-02 14:34:22 +01:00
|
|
|
_, err = table.InsertMessages(sender, destination, []json.RawMessage{newRoomKeyEvent(t, "request_cancellation", "2", sender, nil)})
|
2022-12-14 18:53:55 +00:00
|
|
|
assertNoError(t, err)
|
|
|
|
// selecting messages returns the same as before
|
2023-05-02 14:34:22 +01:00
|
|
|
gotMsgs, _, err = table.Messages(sender, destination, 0, 10)
|
2022-12-14 18:53:55 +00:00
|
|
|
assertNoError(t, err)
|
|
|
|
if len(gotMsgs) != 2 {
|
|
|
|
t.Fatalf("got %d msgs, want 2: %v", len(gotMsgs), jsonArrStr(gotMsgs))
|
|
|
|
}
|
|
|
|
bytesEqual(t, gotMsgs[0], reqEv)
|
|
|
|
bytesEqual(t, gotMsgs[1], cancelEv)
|
2022-12-21 10:53:05 +00:00
|
|
|
}
|
2022-12-14 18:53:55 +00:00
|
|
|
|
2022-12-21 10:53:05 +00:00
|
|
|
// Guard against possible message truncation?
|
|
|
|
func TestToDeviceTableBytesInEqualBytesOut(t *testing.T) {
|
2023-01-18 14:54:26 +00:00
|
|
|
db, close := connectToDB(t)
|
|
|
|
defer close()
|
2023-05-02 14:34:22 +01:00
|
|
|
sender := "@sendymcsendface:localhost"
|
2022-12-21 10:53:05 +00:00
|
|
|
table := NewToDeviceTable(db)
|
|
|
|
testCases := []json.RawMessage{
|
|
|
|
json.RawMessage(`{}`),
|
|
|
|
json.RawMessage(`{"foo":"bar"}`),
|
|
|
|
json.RawMessage(`{ "foo": "bar" }`),
|
|
|
|
json.RawMessage(`{ not even valid json :D }`),
|
2022-12-21 10:58:57 +00:00
|
|
|
json.RawMessage(`{ "\~./.-$%_!@£?;'\[]= }`),
|
2022-12-21 10:53:05 +00:00
|
|
|
}
|
|
|
|
var pos int64
|
|
|
|
for _, msg := range testCases {
|
2023-05-02 14:34:22 +01:00
|
|
|
nextPos, err := table.InsertMessages(sender, "A", []json.RawMessage{msg})
|
2022-12-21 10:53:05 +00:00
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("InsertMessages: %s", err)
|
|
|
|
}
|
2023-05-02 14:34:22 +01:00
|
|
|
got, _, err := table.Messages(sender, "A", pos, 1)
|
2022-12-21 10:53:05 +00:00
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Messages: %s", err)
|
|
|
|
}
|
|
|
|
bytesEqual(t, got[0], msg)
|
|
|
|
pos = nextPos
|
|
|
|
}
|
2022-12-21 10:56:25 +00:00
|
|
|
// and all at once
|
2023-05-02 14:34:22 +01:00
|
|
|
_, err := table.InsertMessages(sender, "B", testCases)
|
2022-12-21 10:56:25 +00:00
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("InsertMessages: %s", err)
|
|
|
|
}
|
2023-05-02 14:34:22 +01:00
|
|
|
got, _, err := table.Messages(sender, "B", 0, 100)
|
2022-12-21 10:56:25 +00:00
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Messages: %s", err)
|
|
|
|
}
|
|
|
|
if len(got) != len(testCases) {
|
|
|
|
t.Fatalf("got %d messages, want %d", len(got), len(testCases))
|
|
|
|
}
|
|
|
|
for i := range testCases {
|
|
|
|
bytesEqual(t, got[i], testCases[i])
|
|
|
|
}
|
2022-12-14 18:53:55 +00:00
|
|
|
}
|
|
|
|
|
2022-12-21 11:28:43 +00:00
|
|
|
func TestMsgID(t *testing.T) {
|
|
|
|
data := json.RawMessage(`{
|
|
|
|
"content": {
|
|
|
|
"algorithm": "m.olm.v1.curve25519-aes-sha2",
|
|
|
|
"ciphertext": {
|
|
|
|
"gMObR+/4dqL5T4DisRRRYBJpn+OjzFnkyCFOktP6Eyw": {
|
|
|
|
"body": "AwogrdbTbG8VCW....slqU",
|
|
|
|
"type": 0
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"org.matrix.msgid": "6390a372-fd3c-4f56-b0d5-2f2ce39f2d56",
|
|
|
|
"sender_key": "EWnYTm/yIQ1lStSIqO6fdVYvS69OfU2DzrX+q+1d+w8"
|
|
|
|
},
|
|
|
|
"type": "m.room.encrypted",
|
|
|
|
"sender": "@sample:localhost:8480"
|
|
|
|
}`)
|
|
|
|
m := gjson.ParseBytes(data)
|
|
|
|
got := m.Get(`content.org\.matrix\.msgid`).Str
|
|
|
|
want := "6390a372-fd3c-4f56-b0d5-2f2ce39f2d56"
|
|
|
|
if got != want {
|
|
|
|
t.Fatalf("got %v want %v", got, want)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-12-14 18:53:55 +00:00
|
|
|
// bytesEqual is a test helper that fails the test immediately, showing both
// payloads as text, unless got and want hold exactly the same bytes. Nil and
// empty inputs compare as equal.
func bytesEqual(t *testing.T, got, want json.RawMessage) {
	t.Helper()
	if string(got) == string(want) {
		return
	}
	t.Fatalf("bytesEqual: \ngot %s\n want %s", string(got), string(want))
}
|
|
|
|
|
|
|
|
// roomKeyRequest is the minimal shape of an m.room_key_request to-device event
// built by these tests. Field order matters: it fixes the byte layout that
// json.Marshal produces, and tests compare events byte-for-byte.
type roomKeyRequest struct {
	Type    string                `json:"type"`
	Content roomKeyRequestContent `json:"content"`
}

// roomKeyRequestContent is the content of a roomKeyRequest. Body is omitted
// from the encoded JSON when it is nil or empty (omitempty).
type roomKeyRequestContent struct {
	Action             string                 `json:"action"`
	RequestID          string                 `json:"request_id"`
	RequestingDeviceID string                 `json:"requesting_device_id"`
	Body               map[string]interface{} `json:"body,omitempty"`
}

// newRoomKeyEvent returns the JSON encoding of an m.room_key_request event
// with the given action (e.g. "request" or "request_cancellation"), request
// ID, requesting device ID and optional body. It fails the test if marshalling
// fails.
func newRoomKeyEvent(t *testing.T, action, reqID, reqDeviceID string, body map[string]interface{}) json.RawMessage {
	// Report a marshalling failure at the caller's line, not here.
	t.Helper()
	rkr := roomKeyRequest{
		Type: "m.room_key_request",
		Content: roomKeyRequestContent{
			Action:             action,
			RequestID:          reqID,
			RequestingDeviceID: reqDeviceID,
			Body:               body,
		},
	}
	b, err := json.Marshal(rkr)
	if err != nil {
		t.Fatalf("newRoomKeyEvent: %s", err)
	}
	return json.RawMessage(b)
}
|