Flesh out to_device_table with tests, update gjson dep

This commit is contained in:
Kegan Dougal 2021-08-03 09:33:38 +01:00
parent bcfe9b051f
commit 0075b46bc9
5 changed files with 133 additions and 8 deletions

2
go.mod
View File

@ -8,5 +8,5 @@ require (
github.com/lib/pq v1.10.1
github.com/matrix-org/gomatrixserverlib v0.0.0-20210510192107-124228cb9548
github.com/rs/zerolog v1.21.0
github.com/tidwall/gjson v1.6.0
github.com/tidwall/gjson v1.8.1
)

6
go.sum
View File

@ -55,11 +55,17 @@ github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJy
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/tidwall/gjson v1.6.0 h1:9VEQWz6LLMUsUl6PueE49ir4Ka6CzLymOAZDxpFsTDc=
github.com/tidwall/gjson v1.6.0/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls=
github.com/tidwall/gjson v1.8.1 h1:8j5EE9Hrh3l9Od1OIEDAb7IpezNA20UdRngNAj5N0WU=
github.com/tidwall/gjson v1.8.1/go.mod h1:5/xDoumyyDNerp2U36lyolv46b3uF/9Bu6OfyQ9GImk=
github.com/tidwall/match v1.0.1 h1:PnKP62LPNxHKTwvHHZZzdOAOCtsJTjo6dZLCwpKm5xc=
github.com/tidwall/match v1.0.1/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E=
github.com/tidwall/match v1.0.3 h1:FQUVvBImDutD8wJLN6c5eMzWtjgONK9MwIBCOrUJKeE=
github.com/tidwall/match v1.0.3/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/tidwall/pretty v1.0.1 h1:WE4RBSZ1x6McVVC8S/Md+Qse8YUv6HRObAx6ke00NY8=
github.com/tidwall/pretty v1.0.1/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/tidwall/pretty v1.1.0 h1:K3hMW5epkdAVwibsQEfR/7Zj0Qgt4DxtNumTq/VloO8=
github.com/tidwall/pretty v1.1.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/tidwall/sjson v1.0.3 h1:DeF+0LZqvIt4fKYw41aPB29ZGlvwVkHKktoXJ1YW9Y8=
github.com/tidwall/sjson v1.0.3/go.mod h1:bURseu1nuBkFpIES5cz6zBtjmYeOQmEESshn7VpF15Y=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=

View File

@ -8,8 +8,9 @@ import (
)
type Storage struct {
accumulator *Accumulator
typingTable *TypingTable
accumulator *Accumulator
typingTable *TypingTable
toDeviceTable *ToDeviceTable
}
func NewStorage(postgresURI string) *Storage {
@ -27,8 +28,9 @@ func NewStorage(postgresURI string) *Storage {
entityName: "server",
}
return &Storage{
accumulator: acc,
typingTable: NewTypingTable(db),
accumulator: acc,
typingTable: NewTypingTable(db),
toDeviceTable: NewToDeviceTable(db),
}
}

View File

@ -6,25 +6,58 @@ import (
"github.com/jmoiron/sqlx"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/sync-v3/sqlutil"
)
// ToDeviceTable stores to_device messages for devices.
type ToDeviceTable struct{}
type ToDeviceRow struct {
Position int64 `db:"position"`
DeviceID string `db:"device_id"`
Message string `db:"message"`
}
func NewToDeviceTable(db *sqlx.DB) *RoomsTable {
type ToDeviceRowChunker []ToDeviceRow
func (c ToDeviceRowChunker) Len() int {
return len(c)
}
func (c ToDeviceRowChunker) Subslice(i, j int) sqlutil.Chunker {
return c[i:j]
}
func NewToDeviceTable(db *sqlx.DB) *ToDeviceTable {
// make sure tables are made
db.MustExec(`
CREATE SEQUENCE IF NOT EXISTS syncv3_to_device_messages_seq;
CREATE TABLE IF NOT EXISTS syncv3_to_device_messages (
device_id TEXT NOT NULL PRIMARY KEY,
position BIGINT NOT NULL PRIMARY KEY DEFAULT nextval('syncv3_to_device_messages_seq'),
device_id TEXT NOT NULL,
message TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS syncv3_to_device_messages_device_idx ON syncv3_to_device_messages(device_id);
`)
return &RoomsTable{}
return &ToDeviceTable{}
}
func (t *ToDeviceTable) Messages(txn *sqlx.Tx, deviceID string, from int64) (msgs []gomatrixserverlib.SendToDeviceEvent, to int64, err error) {
var rows []ToDeviceRow
err = txn.Select(&rows, `SELECT position, message FROM syncv3_to_device_messages WHERE device_id = $1 AND position > $2 ORDER BY position ASC`, deviceID, from)
if len(rows) == 0 {
to = from
return
}
to = rows[len(rows)-1].Position
msgs = make([]gomatrixserverlib.SendToDeviceEvent, len(rows))
for i := range rows {
var stdev gomatrixserverlib.SendToDeviceEvent
if err = json.Unmarshal([]byte(rows[i].Message), &stdev); err != nil {
return
}
msgs[i] = stdev
}
return
}
func (t *ToDeviceTable) InsertMessages(txn *sqlx.Tx, deviceID string, msgs []gomatrixserverlib.SendToDeviceEvent) (err error) {
@ -39,5 +72,13 @@ func (t *ToDeviceTable) InsertMessages(txn *sqlx.Tx, deviceID string, msgs []gom
Message: string(msgJSON),
}
}
chunks := sqlutil.Chunkify(2, 65535, ToDeviceRowChunker(rows))
for _, chunk := range chunks {
_, err := txn.NamedExec(`INSERT INTO syncv3_to_device_messages (device_id, message)
VALUES (:device_id, :message)`, chunk)
if err != nil {
return err
}
}
return nil
}

View File

@ -0,0 +1,76 @@
package state
import (
"bytes"
"testing"
"github.com/jmoiron/sqlx"
"github.com/matrix-org/gomatrixserverlib"
)
func TestToDeviceTable(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString)
if err != nil {
t.Fatalf("failed to open SQL db: %s", err)
}
table := NewToDeviceTable(db)
txn, err := db.Beginx()
if err != nil {
t.Fatalf("failed to start txn: %s", err)
}
deviceID := "FOO"
msgs := []gomatrixserverlib.SendToDeviceEvent{
{
Sender: "alice",
Type: "something",
Content: []byte(`{"foo":"bar"}`),
},
{
Sender: "bob",
Type: "something",
Content: []byte(`{"foo":"bar2"}`),
},
}
if err = table.InsertMessages(txn, deviceID, msgs); err != nil {
t.Fatalf("InsertMessages: %s", err)
}
gotMsgs, to, err := table.Messages(txn, deviceID, 0)
if err != nil {
t.Fatalf("Messages: %s", err)
}
if to == 0 {
t.Fatalf("Messages: got to=0")
}
if len(gotMsgs) != len(msgs) {
t.Fatalf("Messages: got %d messages, want %d", len(gotMsgs), len(msgs))
}
for i := range msgs {
if !bytes.Equal(msgs[i].Content, gotMsgs[i].Content) {
t.Fatalf("Messages: got %+v want %+v", gotMsgs[i], msgs[i])
}
}
// same to= token, no messages
gotMsgs, to2, err := table.Messages(txn, deviceID, to)
if err != nil {
t.Fatalf("Messages: %s", err)
}
if to2 != to {
t.Fatalf("Messages: got to=%d want to=%d", to2, to)
}
if len(gotMsgs) > 0 {
t.Fatalf("Messages: got %d messages, want none", len(gotMsgs))
}
// different device ID, no messages
gotMsgs, to, err = table.Messages(txn, "OTHER_DEVICE", 0)
if err != nil {
t.Fatalf("Messages: %s", err)
}
if to != 0 {
t.Fatalf("Messages: got to=%d want 0", to)
}
if len(gotMsgs) > 0 {
t.Fatalf("Messages: got %d messages, want none", len(gotMsgs))
}
}