mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
223 lines
7.8 KiB
Go
223 lines
7.8 KiB
Go
package state
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
"github.com/lib/pq"
|
|
"github.com/matrix-org/sliding-sync/sqlutil"
|
|
"github.com/tidwall/gjson"
|
|
)
|
|
|
|
// Values stored in the `action` column of syncv3_to_device_messages. The
// column defaults to 0, which means the action is unknown / not applicable.
const (
	// ActionRequest marks an m.room_key_request whose content.action is "request".
	ActionRequest = 1
	// ActionCancel marks an m.room_key_request whose content.action is "request_cancellation".
	ActionCancel = 2
)
|
|
|
|
// ToDeviceTable stores to_device messages for devices.
// It also tracks a per-device unacknowledged position in syncv3_to_device_ack_pos.
type ToDeviceTable struct {
	// db is the handle every query and transaction on this table runs against.
	db *sqlx.DB
}
|
|
|
|
// ToDeviceRow is one row of syncv3_to_device_messages. The `db` tags map each
// field to its column for sqlx scanning and named-parameter inserts.
type ToDeviceRow struct {
	// Position is the row's primary key, assigned from syncv3_to_device_messages_seq on insert.
	Position int64 `db:"position"`
	// UserID and DeviceID identify the inbox the message is queued for.
	UserID   string `db:"user_id"`
	DeviceID string `db:"device_id"`
	// Message is the raw JSON of the to-device event.
	Message string `db:"message"`
	// Type is the event's top-level "type" field.
	Type string `db:"event_type"`
	// Sender is the event's top-level "sender" field.
	Sender string `db:"sender"`
	// UniqueKey is only set for m.room_key_request events (nil otherwise); requests
	// and their cancellations share the same key.
	UniqueKey *string `db:"unique_key"`
	// Action is ActionRequest, ActionCancel, or 0 when unknown.
	Action int `db:"action"`
}
|
|
|
|
type ToDeviceRowChunker []ToDeviceRow
|
|
|
|
func (c ToDeviceRowChunker) Len() int {
|
|
return len(c)
|
|
}
|
|
func (c ToDeviceRowChunker) Subslice(i, j int) sqlutil.Chunker {
|
|
return c[i:j]
|
|
}
|
|
|
|
func NewToDeviceTable(db *sqlx.DB) *ToDeviceTable {
|
|
// make sure tables are made
|
|
db.MustExec(`
|
|
CREATE SEQUENCE IF NOT EXISTS syncv3_to_device_messages_seq;
|
|
CREATE TABLE IF NOT EXISTS syncv3_to_device_messages (
|
|
position BIGINT NOT NULL PRIMARY KEY DEFAULT nextval('syncv3_to_device_messages_seq'),
|
|
user_id TEXT NOT NULL,
|
|
device_id TEXT NOT NULL,
|
|
event_type TEXT NOT NULL,
|
|
sender TEXT NOT NULL,
|
|
message TEXT NOT NULL,
|
|
-- nullable as these fields are not on all to-device events
|
|
unique_key TEXT,
|
|
action SMALLINT DEFAULT 0 -- 0 means unknown
|
|
);
|
|
CREATE TABLE IF NOT EXISTS syncv3_to_device_ack_pos (
|
|
user_id TEXT NOT NULL,
|
|
device_id TEXT NOT NULL,
|
|
PRIMARY KEY (user_id, device_id),
|
|
unack_pos BIGINT NOT NULL
|
|
);
|
|
CREATE INDEX IF NOT EXISTS syncv3_to_device_messages_device_idx ON syncv3_to_device_messages(device_id);
|
|
CREATE INDEX IF NOT EXISTS syncv3_to_device_messages_ukey_idx ON syncv3_to_device_messages(unique_key, device_id);
|
|
CREATE INDEX IF NOT EXISTS syncv3_to_device_messages_pos_device_idx ON syncv3_to_device_messages(position, device_id);
|
|
`)
|
|
return &ToDeviceTable{db}
|
|
}
|
|
|
|
func (t *ToDeviceTable) SetUnackedPosition(userID, deviceID string, pos int64) error {
|
|
_, err := t.db.Exec(`INSERT INTO syncv3_to_device_ack_pos(user_id, device_id, unack_pos) VALUES($1,$2,$3) ON CONFLICT (user_id, device_id)
|
|
DO UPDATE SET unack_pos=excluded.unack_pos`, userID, deviceID, pos)
|
|
return err
|
|
}
|
|
|
|
func (t *ToDeviceTable) DeleteMessagesUpToAndIncluding(userID, deviceID string, toIncl int64) error {
|
|
_, err := t.db.Exec(`DELETE FROM syncv3_to_device_messages WHERE user_id = $1 AND device_id = $2 AND position <= $3`, userID, deviceID, toIncl)
|
|
return err
|
|
}
|
|
|
|
func (t *ToDeviceTable) DeleteAllMessagesForDevice(userID, deviceID string) error {
|
|
// TODO: should these deletes take place in a transaction?
|
|
_, err := t.db.Exec(`DELETE FROM syncv3_to_device_messages WHERE user_id = $1 AND device_id = $2`, userID, deviceID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = t.db.Exec(`DELETE FROM syncv3_to_device_ack_pos WHERE user_id = $1 AND device_id = $2`, userID, deviceID)
|
|
return err
|
|
}
|
|
|
|
// Messages fetches up to `limit` to-device messages for this device, starting from and excluding `from`.
|
|
// Returns the fetches messages ordered by ascending position, as well as the position of the last to-device message
|
|
// fetched.
|
|
func (t *ToDeviceTable) Messages(userID, deviceID string, from, limit int64) (msgs []json.RawMessage, upTo int64, err error) {
|
|
upTo = from
|
|
var rows []ToDeviceRow
|
|
err = t.db.Select(&rows,
|
|
`SELECT position, message FROM syncv3_to_device_messages WHERE user_id = $1 AND device_id = $2 AND position > $3 ORDER BY position ASC LIMIT $4`,
|
|
userID, deviceID, from, limit,
|
|
)
|
|
if len(rows) == 0 {
|
|
return
|
|
}
|
|
msgs = make([]json.RawMessage, len(rows))
|
|
for i := range rows {
|
|
msgs[i] = json.RawMessage(rows[i].Message)
|
|
m := gjson.ParseBytes(msgs[i])
|
|
msgId := m.Get(`content.org\.matrix\.msgid`).Str
|
|
if msgId != "" {
|
|
logger.Info().Str("msgid", msgId).Str("user", userID).Str("device", deviceID).Msg("ToDeviceTable.Messages")
|
|
}
|
|
}
|
|
upTo = rows[len(rows)-1].Position
|
|
return
|
|
}
|
|
|
|
func (t *ToDeviceTable) InsertMessages(userID, deviceID string, msgs []json.RawMessage) (pos int64, err error) {
|
|
var lastPos int64
|
|
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
|
|
var unackPos int64
|
|
err = txn.QueryRow(`SELECT unack_pos FROM syncv3_to_device_ack_pos WHERE user_id=$1 AND device_id=$2`, userID, deviceID).Scan(&unackPos)
|
|
if err != nil && err != sql.ErrNoRows {
|
|
return fmt.Errorf("unable to select unacked pos: %s", err)
|
|
}
|
|
|
|
// Some of these events may be "cancel" actions. If we find events for the unique key of this event, then delete them
|
|
// and ignore the "cancel" action.
|
|
cancels := []string{}
|
|
allRequests := make(map[string]struct{})
|
|
allCancels := make(map[string]struct{})
|
|
|
|
rows := make([]ToDeviceRow, len(msgs))
|
|
for i := range msgs {
|
|
m := gjson.ParseBytes(msgs[i])
|
|
rows[i] = ToDeviceRow{
|
|
UserID: userID,
|
|
DeviceID: deviceID,
|
|
Message: string(msgs[i]),
|
|
Type: m.Get("type").Str,
|
|
Sender: m.Get("sender").Str,
|
|
}
|
|
msgId := m.Get(`content.org\.matrix\.msgid`).Str
|
|
if msgId != "" {
|
|
logger.Debug().Str("msgid", msgId).Str("user", userID).Str("device", deviceID).Msg("ToDeviceTable.InsertMessages")
|
|
}
|
|
switch rows[i].Type {
|
|
case "m.room_key_request":
|
|
action := m.Get("content.action").Str
|
|
if action == "request" {
|
|
rows[i].Action = ActionRequest
|
|
} else if action == "request_cancellation" {
|
|
rows[i].Action = ActionCancel
|
|
}
|
|
// "the same request_id and requesting_device_id fields, sent by the same user."
|
|
key := fmt.Sprintf("%s-%s-%s-%s", rows[i].Type, rows[i].Sender, m.Get("content.requesting_device_id").Str, m.Get("content.request_id").Str)
|
|
rows[i].UniqueKey = &key
|
|
}
|
|
if rows[i].Action == ActionCancel && rows[i].UniqueKey != nil {
|
|
cancels = append(cancels, *rows[i].UniqueKey)
|
|
allCancels[*rows[i].UniqueKey] = struct{}{}
|
|
} else if rows[i].Action == ActionRequest && rows[i].UniqueKey != nil {
|
|
allRequests[*rows[i].UniqueKey] = struct{}{}
|
|
}
|
|
}
|
|
if len(cancels) > 0 {
|
|
var cancelled []string
|
|
// delete action: request events which have the same unique key, for this device inbox, only if they are not sent to the client already (unacked)
|
|
err = txn.Select(&cancelled, `DELETE FROM syncv3_to_device_messages WHERE unique_key = ANY($1) AND user_id = $2 AND device_id = $3 AND position > $4 RETURNING unique_key`,
|
|
pq.StringArray(cancels), userID, deviceID, unackPos)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete cancelled events: %s", err)
|
|
}
|
|
cancelledInDBSet := make(map[string]struct{}, len(cancelled))
|
|
for _, ukey := range cancelled {
|
|
cancelledInDBSet[ukey] = struct{}{}
|
|
}
|
|
// do not insert the cancelled unique keys
|
|
newRows := make([]ToDeviceRow, 0, len(rows))
|
|
for i := range rows {
|
|
if rows[i].UniqueKey != nil {
|
|
ukey := *rows[i].UniqueKey
|
|
_, exists := cancelledInDBSet[ukey]
|
|
if exists {
|
|
continue // the request was deleted so don't insert the cancel
|
|
}
|
|
// we may be requesting and cancelling in one go, check it and ignore if so
|
|
_, reqExists := allRequests[ukey]
|
|
_, cancelExists := allCancels[ukey]
|
|
if reqExists && cancelExists {
|
|
continue
|
|
}
|
|
}
|
|
newRows = append(newRows, rows[i])
|
|
}
|
|
rows = newRows
|
|
}
|
|
// we may have nothing to do if the entire set of events were cancellations
|
|
if len(rows) == 0 {
|
|
return nil
|
|
}
|
|
|
|
chunks := sqlutil.Chunkify(7, MaxPostgresParameters, ToDeviceRowChunker(rows))
|
|
for _, chunk := range chunks {
|
|
result, err := txn.NamedQuery(`INSERT INTO syncv3_to_device_messages (user_id, device_id, message, event_type, sender, action, unique_key)
|
|
VALUES (:user_id, :device_id, :message, :event_type, :sender, :action, :unique_key) RETURNING position`, chunk)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for result.Next() {
|
|
if err = result.Scan(&lastPos); err != nil {
|
|
result.Close()
|
|
return err
|
|
}
|
|
}
|
|
result.Close()
|
|
}
|
|
return nil
|
|
})
|
|
return lastPos, err
|
|
}
|