sliding-sync/state/receipt_table.go
2024-02-15 18:02:40 +00:00

264 lines
8.3 KiB
Go

package state
import (
"encoding/json"
"fmt"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
"github.com/matrix-org/sliding-sync/internal"
"github.com/matrix-org/sliding-sync/sqlutil"
)
// Wire-format types for an m.receipt EDU. They are declared leaf-first:
// per-user info, then per-event content, then the enclosing EDU.
type (
	// receiptInfo is the payload stored against a single user ID inside a
	// receipt-type bucket.
	receiptInfo struct {
		// TS is the "ts" value of the receipt.
		TS int64 `json:"ts"`
		// ThreadID is the "thread_id" value of the receipt; it is omitted from
		// the JSON when empty.
		ThreadID string `json:"thread_id,omitempty"`
	}

	// receiptContent holds the receipts for one event, split by receipt type
	// and keyed by user ID. Empty buckets are omitted from the JSON.
	receiptContent struct {
		Read        map[string]receiptInfo `json:"m.read,omitempty"`
		ReadPrivate map[string]receiptInfo `json:"m.read.private,omitempty"`
	}

	// receiptEDU is the top-level EDU object: a type string plus content keyed
	// by event ID.
	receiptEDU struct {
		Type    string                    `json:"type"`
		Content map[string]receiptContent `json:"content"`
	}
)
// ReceiptTable provides access to the syncv3_receipts and
// syncv3_receipts_private tables through the wrapped database handle.
type ReceiptTable struct {
// db is the handle used for all queries and transactions issued by this table.
db *sqlx.DB
}
// NewReceiptTable creates the receipt tables and their indexes if they do not
// already exist, and returns a ReceiptTable backed by db. Schema creation uses
// MustExec, so a failure panics rather than returning an error.
func NewReceiptTable(db *sqlx.DB) *ReceiptTable {
	// Public and private receipts live in two structurally identical tables. Keeping
	// them apart means the compound key is just room/user/thread rather than
	// room/user/thread/receipt_type, which should help query performance somewhat.
	const schema = `
CREATE TABLE IF NOT EXISTS %s (
room_id TEXT NOT NULL,
user_id TEXT NOT NULL,
thread_id TEXT NOT NULL,
event_id TEXT NOT NULL,
ts BIGINT NOT NULL,
UNIQUE(room_id, user_id, thread_id)
);
-- for querying by events in the timeline, need to search by event id
CREATE INDEX IF NOT EXISTS %s_by_event_idx ON %s(room_id, event_id);
-- for querying all receipts for a user in a room, need to search by user id
CREATE INDEX IF NOT EXISTS %s_by_user_idx ON %s(room_id, user_id);
`
	for _, name := range [...]string{"syncv3_receipts", "syncv3_receipts_private"} {
		// The table name appears five times in the template: once for the table
		// itself and twice for each of the two indexes.
		stmt := fmt.Sprintf(schema, name, name, name, name, name)
		db.MustExec(stmt)
	}
	return &ReceiptTable{db: db}
}
// Insert stores the receipts contained in a receipt EDU for the given room.
//
// The EDU is unpacked into public (m.read) and private (m.read.private)
// receipts, and both sets are written inside a single transaction. It returns
// the receipts that the database reported as newly inserted or changed, public
// ones first, or nil if there are none. These can then be sent to the API
// processes for live updates.
//
// An error from unpacking the EDU is returned as-is. An error from the
// transaction is wrapped with %w so callers can still inspect the underlying
// cause with errors.Is / errors.As.
func (t *ReceiptTable) Insert(roomID string, ephEvent json.RawMessage) (receipts []internal.Receipt, err error) {
	readReceipts, privateReceipts, err := UnpackReceiptsFromEDU(roomID, ephEvent)
	if err != nil {
		return nil, err
	}
	// Nothing to store: avoid opening a transaction at all.
	if len(readReceipts) == 0 && len(privateReceipts) == 0 {
		return nil, nil
	}
	err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
		var txErr error
		// From here on the slices hold only what the database actually stored,
		// not everything the EDU contained.
		readReceipts, txErr = t.bulkInsert("syncv3_receipts", txn, readReceipts)
		if txErr != nil {
			return txErr
		}
		privateReceipts, txErr = t.bulkInsert("syncv3_receipts_private", txn, privateReceipts)
		return txErr
	})
	if err != nil {
		return nil, fmt.Errorf("failed to insert receipts: %w", err)
	}
	// no new receipts
	if len(readReceipts) == 0 && len(privateReceipts) == 0 {
		return nil, nil
	}
	// combine together new receipts
	return append(readReceipts, privateReceipts...), nil
}
// SelectReceiptsForEvents returns every non-private receipt that points at one of
// the given event IDs within the given room. Private receipts are never returned
// because only the syncv3_receipts table is queried.
//
// The receipts come back parsed so callers can use their fields in further
// queries, e.g. to look up profile information for the users who sent them. Call
// PackReceiptsIntoEDU when sending to clients.
func (t *ReceiptTable) SelectReceiptsForEvents(roomID string, eventIDs []string) (receipts []internal.Receipt, err error) {
	const query = `SELECT room_id, event_id, user_id, ts, thread_id FROM syncv3_receipts
WHERE room_id=$1 AND event_id = ANY($2)`
	ids := pq.StringArray(eventIDs)
	err = t.db.Select(&receipts, query, roomID, ids)
	return receipts, err
}
// SelectReceiptsForUser returns all receipts, public and private, that the given
// user has in the given rooms, grouped by room ID. Within each room the public
// receipts come before the private ones, and private receipts have IsPrivate set.
// On success the returned map is never nil, though it may be empty.
func (t *ReceiptTable) SelectReceiptsForUser(roomIDs []string, userID string) (receiptsByRoom map[string][]internal.Receipt, err error) {
	rooms := pq.StringArray(roomIDs)

	var public []internal.Receipt
	if err = t.db.Select(&public, `SELECT room_id, event_id, user_id, ts, thread_id FROM syncv3_receipts
WHERE room_id=ANY($1) AND user_id = $2`, rooms, userID); err != nil {
		return nil, err
	}

	var private []internal.Receipt
	if err = t.db.Select(&private, `SELECT room_id, event_id, user_id, ts, thread_id FROM syncv3_receipts_private
WHERE room_id=ANY($1) AND user_id = $2`, rooms, userID); err != nil {
		return nil, err
	}

	// bucket by room: public receipts first, then private ones
	receiptsByRoom = make(map[string][]internal.Receipt)
	for _, pub := range public {
		receiptsByRoom[pub.RoomID] = append(receiptsByRoom[pub.RoomID], pub)
	}
	for i := range private {
		// The private table has no column for this flag; set it from the table of origin.
		private[i].IsPrivate = true
		room := private[i].RoomID
		receiptsByRoom[room] = append(receiptsByRoom[room], private[i])
	}
	return receiptsByRoom, nil
}
// bulkInsert upserts receipts into tableName within txn and returns the rows the
// database reported back via RETURNING.
//
// Rows are keyed on (room_id, user_id, thread_id). On conflict the stored
// event_id and ts are replaced, but only when the incoming event_id differs from
// the stored one, so re-sending an identical receipt yields no returned row.
// Returned receipts have IsPrivate set when tableName is
// "syncv3_receipts_private".
//
// Receipts are inserted in chunks produced by sqlutil.Chunkify, called with 5
// (one per column bound for each receipt) and MaxPostgresParameters. Any query,
// scan or row-iteration error aborts the insert and is returned with a nil slice.
//
// tableName is concatenated directly into the SQL text, so it must only ever be
// one of the fixed table names used by this file, never caller-supplied input.
func (t *ReceiptTable) bulkInsert(tableName string, txn *sqlx.Tx, receipts []internal.Receipt) (newReceipts []internal.Receipt, err error) {
	if len(receipts) == 0 {
		return nil, nil
	}
	isPrivate := tableName == "syncv3_receipts_private"
	// The statement is identical for every chunk, so build it once.
	query := `
INSERT INTO ` + tableName + ` AS old (room_id, event_id, user_id, ts, thread_id)
VALUES (:room_id, :event_id, :user_id, :ts, :thread_id) ON CONFLICT (room_id, user_id, thread_id) DO UPDATE SET event_id=excluded.event_id, ts=excluded.ts WHERE old.event_id <> excluded.event_id
RETURNING room_id, user_id, thread_id, event_id, ts`

	chunks := sqlutil.Chunkify(5, MaxPostgresParameters, ReceiptChunker(receipts))
	for _, chunk := range chunks {
		rows, queryErr := txn.NamedQuery(query, chunk)
		if queryErr != nil {
			return nil, queryErr
		}
		for rows.Next() {
			var (
				roomID, userID, threadID, eventID string
				ts                                int64
			)
			// Scan order must match the RETURNING column order above.
			if scanErr := rows.Scan(&roomID, &userID, &threadID, &eventID, &ts); scanErr != nil {
				// Best-effort cleanup; the scan error is the one worth reporting.
				rows.Close()
				return nil, scanErr
			}
			newReceipts = append(newReceipts, internal.Receipt{
				RoomID:    roomID,
				EventID:   eventID,
				UserID:    userID,
				TS:        ts,
				ThreadID:  threadID,
				IsPrivate: isPrivate,
			})
		}
		// rows.Next() returns false on failure as well as on exhaustion; without
		// this check a mid-iteration error would look like a short result set.
		iterErr := rows.Err()
		// Best-effort cleanup; any iteration error has already been captured.
		rows.Close()
		if iterErr != nil {
			return nil, iterErr
		}
	}
	return newReceipts, nil
}
// PackReceiptsIntoEDU serialises the given receipts as one m.receipt EDU, in the
// shape sent down client connections. Receipts are grouped by event ID, then by
// receipt type (m.read or m.read.private, chosen by IsPrivate), then by user ID.
//
// If the same user has more than one receipt of the same type on the same event,
// the first one seen is kept unless a later one is unthreaded (empty ThreadID),
// in which case the later one replaces it.
func PackReceiptsIntoEDU(receipts []internal.Receipt) (json.RawMessage, error) {
	content := make(map[string]receiptContent)
	for _, rcpt := range receipts {
		perEvent := content[rcpt.EventID]

		// Select the per-user bucket for this receipt type, creating it on first use.
		var bucket map[string]receiptInfo
		if rcpt.IsPrivate {
			if perEvent.ReadPrivate == nil {
				perEvent.ReadPrivate = make(map[string]receiptInfo)
			}
			bucket = perEvent.ReadPrivate
		} else {
			if perEvent.Read == nil {
				perEvent.Read = make(map[string]receiptInfo)
			}
			bucket = perEvent.Read
		}

		// MSC4102: always replace threaded receipts with unthreaded ones if there is a clash
		_, clash := bucket[rcpt.UserID]
		unthreaded := rcpt.ThreadID == ""
		if !clash || unthreaded {
			bucket[rcpt.UserID] = receiptInfo{
				TS:       rcpt.TS,
				ThreadID: rcpt.ThreadID,
			}
		}

		// perEvent is a copy of the map value, so write it back to keep any
		// bucket that was created above.
		content[rcpt.EventID] = perEvent
	}
	return json.Marshal(receiptEDU{
		Type:    "m.receipt",
		Content: content,
	})
}
// UnpackReceiptsFromEDU parses an m.receipt EDU and flattens it into two lists:
// the public (m.read) receipts and the private (m.read.private) receipts. Every
// returned receipt is stamped with roomID; private ones also have IsPrivate set.
//
// The expected EDU shape is:
//
//	{
//	  "content": {
//	    "$1435641916114394fHBLK:matrix.org": {
//	      "m.read": {
//	        "@rikj:jki.re": {
//	          "ts": 1436451550453,
//	          "thread_id": "$aaabbbccc"
//	        }
//	      },
//	      "m.read.private": {
//	        "@self:example.org": {
//	          "ts": 1661384801651
//	        }
//	      }
//	    }
//	  },
//	  "type": "m.receipt"
//	}
//
// A JSON decoding failure is returned as the error. An EDU whose type is not
// "m.receipt" is not an error: both lists are nil. The order of the returned
// receipts is unspecified because it follows Go map iteration order.
func UnpackReceiptsFromEDU(roomID string, ephEvent json.RawMessage) (readReceipts, privateReceipts []internal.Receipt, err error) {
	var parsed receiptEDU
	if decodeErr := json.Unmarshal(ephEvent, &parsed); decodeErr != nil {
		return nil, nil, decodeErr
	}
	if parsed.Type != "m.receipt" {
		return nil, nil, nil
	}

	// flatten converts one (event, user, info) triple into a Receipt for this room.
	flatten := func(eventID, userID string, info receiptInfo, private bool) internal.Receipt {
		return internal.Receipt{
			UserID:    userID,
			RoomID:    roomID,
			EventID:   eventID,
			TS:        info.TS,
			ThreadID:  info.ThreadID,
			IsPrivate: private,
		}
	}

	for eventID, byType := range parsed.Content {
		for userID, info := range byType.Read {
			readReceipts = append(readReceipts, flatten(eventID, userID, info, false))
		}
		for userID, info := range byType.ReadPrivate {
			privateReceipts = append(privateReceipts, flatten(eventID, userID, info, true))
		}
	}
	return readReceipts, privateReceipts, nil
}
// ReceiptChunker is a slice of receipts with the Len and Subslice methods needed
// to pass it to sqlutil.Chunkify as a sqlutil.Chunker.
type ReceiptChunker []internal.Receipt

// Len reports how many receipts the chunker holds.
func (rc ReceiptChunker) Len() int { return len(rc) }

// Subslice returns the receipts in the half-open range [i, j) as a Chunker. The
// result shares its backing array with the receiver.
func (rc ReceiptChunker) Subslice(i, j int) sqlutil.Chunker {
	part := rc[i:j]
	return part
}