2021-08-02 16:45:09 +01:00
package state
import (
2021-12-14 14:38:39 +00:00
"database/sql"
2021-08-02 16:45:09 +01:00
"encoding/json"
2022-12-14 18:53:55 +00:00
"fmt"
2021-08-02 16:45:09 +01:00
"github.com/jmoiron/sqlx"
2022-12-14 18:53:55 +00:00
"github.com/lib/pq"
2022-12-15 11:08:50 +00:00
"github.com/matrix-org/sliding-sync/sqlutil"
2022-02-22 17:04:31 +00:00
"github.com/tidwall/gjson"
2021-08-02 16:45:09 +01:00
)
2022-12-14 18:53:55 +00:00
// Values stored in the action column of syncv3_to_device_messages for m.room_key_request
// events. Rows whose action is neither of these hold 0, which the schema treats as unknown.
const (
	// ActionRequest marks an m.room_key_request whose content.action is "request".
	ActionRequest = iota + 1
	// ActionCancel marks an m.room_key_request whose content.action is "request_cancellation".
	ActionCancel
)
2021-08-02 16:45:09 +01:00
// ToDeviceTable stores to_device messages for devices.
2021-08-03 12:06:09 +01:00
type ToDeviceTable struct {
2022-12-14 18:53:55 +00:00
db * sqlx . DB
2021-08-03 12:06:09 +01:00
}
2021-08-02 16:45:09 +01:00
// ToDeviceRow mirrors one row of syncv3_to_device_messages. The db tags name the columns
// used both when scanning query results and when binding named insert parameters.
type ToDeviceRow struct {
	Position  int64   `db:"position"`   // row position; filled by the table's sequence on insert
	UserID    string  `db:"user_id"`    // user owning the inbox
	DeviceID  string  `db:"device_id"`  // device whose inbox holds the message
	Message   string  `db:"message"`    // raw JSON of the to-device event
	Type      string  `db:"event_type"` // the event's "type" field
	Sender    string  `db:"sender"`     // the event's "sender" field
	UniqueKey *string `db:"unique_key"` // key pairing a room key request with its cancellation; nil when unset
	Action    int     `db:"action"`     // ActionRequest, ActionCancel, or 0 (unknown)
}
2021-08-03 09:33:38 +01:00
// ToDeviceRowChunker adapts a slice of ToDeviceRow to sqlutil.Chunker, letting
// sqlutil.Chunkify split a large batch of rows into parameter-limited inserts.
type ToDeviceRowChunker []ToDeviceRow

// Len reports the number of rows held.
func (c ToDeviceRowChunker) Len() int {
	return len(c)
}

// Subslice returns rows [i, j) as a chunker sharing c's backing array.
func (c ToDeviceRowChunker) Subslice(i, j int) sqlutil.Chunker {
	sub := ToDeviceRowChunker(c[i:j])
	return sub
}
func NewToDeviceTable ( db * sqlx . DB ) * ToDeviceTable {
2021-08-02 16:45:09 +01:00
// make sure tables are made
db . MustExec ( `
2021-08-03 09:33:38 +01:00
CREATE SEQUENCE IF NOT EXISTS syncv3_to_device_messages_seq ;
2021-08-02 16:45:09 +01:00
CREATE TABLE IF NOT EXISTS syncv3_to_device_messages (
2021-08-03 09:33:38 +01:00
position BIGINT NOT NULL PRIMARY KEY DEFAULT nextval ( ' syncv3_to_device_messages_seq ' ) ,
2023-05-02 14:34:22 +01:00
user_id TEXT NOT NULL ,
2021-08-03 09:33:38 +01:00
device_id TEXT NOT NULL ,
2022-02-22 17:04:31 +00:00
event_type TEXT NOT NULL ,
sender TEXT NOT NULL ,
2022-12-14 18:53:55 +00:00
message TEXT NOT NULL ,
-- nullable as these fields are not on all to - device events
unique_key TEXT ,
action SMALLINT DEFAULT 0 -- 0 means unknown
) ;
CREATE TABLE IF NOT EXISTS syncv3_to_device_ack_pos (
2023-05-03 11:25:37 +01:00
user_id TEXT NOT NULL ,
2023-05-02 14:34:22 +01:00
device_id TEXT NOT NULL ,
PRIMARY KEY ( user_id , device_id ) ,
2022-12-14 18:53:55 +00:00
unack_pos BIGINT NOT NULL
2021-08-02 16:45:09 +01:00
) ;
2021-08-03 09:33:38 +01:00
CREATE INDEX IF NOT EXISTS syncv3_to_device_messages_device_idx ON syncv3_to_device_messages ( device_id ) ;
2022-12-14 18:53:55 +00:00
CREATE INDEX IF NOT EXISTS syncv3_to_device_messages_ukey_idx ON syncv3_to_device_messages ( unique_key , device_id ) ;
2023-01-04 13:59:30 +00:00
CREATE INDEX IF NOT EXISTS syncv3_to_device_messages_pos_device_idx ON syncv3_to_device_messages ( position , device_id ) ;
2021-08-02 16:45:09 +01:00
` )
2022-12-14 18:53:55 +00:00
return & ToDeviceTable { db }
}
2023-05-02 14:34:22 +01:00
func ( t * ToDeviceTable ) SetUnackedPosition ( userID , deviceID string , pos int64 ) error {
_ , err := t . db . Exec ( ` INSERT INTO syncv3_to_device_ack_pos ( user_id , device_id , unack_pos ) VALUES ( $ 1 , $ 2 , $ 3 ) ON CONFLICT ( user_id , device_id )
DO UPDATE SET unack_pos = excluded . unack_pos ` , userID , deviceID , pos )
2022-12-14 18:53:55 +00:00
return err
2021-08-03 09:33:38 +01:00
}
2023-05-02 14:34:22 +01:00
func ( t * ToDeviceTable ) DeleteMessagesUpToAndIncluding ( userID , deviceID string , toIncl int64 ) error {
_ , err := t . db . Exec ( ` DELETE FROM syncv3_to_device_messages WHERE user_id = $1 AND device_id = $2 AND position <= $3 ` , userID , deviceID , toIncl )
2021-08-05 12:50:03 +01:00
return err
}
2023-05-02 14:34:22 +01:00
func ( t * ToDeviceTable ) DeleteAllMessagesForDevice ( userID , deviceID string ) error {
2023-05-10 11:22:46 +01:00
// TODO: should these deletes take place in a transaction?
2023-05-02 14:34:22 +01:00
_ , err := t . db . Exec ( ` DELETE FROM syncv3_to_device_messages WHERE user_id = $1 AND device_id = $2 ` , userID , deviceID )
2023-05-10 11:22:46 +01:00
if err != nil {
return err
}
_ , err = t . db . Exec ( ` DELETE FROM syncv3_to_device_ack_pos WHERE user_id = $1 AND device_id = $2 ` , userID , deviceID )
2023-03-01 16:56:04 +00:00
return err
}
2023-05-02 14:34:22 +01:00
// Messages fetches up to `limit` to-device messages for this device, starting from and excluding `from`.
// Returns the fetches messages ordered by ascending position, as well as the position of the last to-device message
// fetched.
func ( t * ToDeviceTable ) Messages ( userID , deviceID string , from , limit int64 ) ( msgs [ ] json . RawMessage , upTo int64 , err error ) {
2022-12-14 18:53:55 +00:00
upTo = from
2021-08-03 09:33:38 +01:00
var rows [ ] ToDeviceRow
2021-08-03 12:06:09 +01:00
err = t . db . Select ( & rows ,
2023-05-02 14:34:22 +01:00
` SELECT position, message FROM syncv3_to_device_messages WHERE user_id = $1 AND device_id = $2 AND position > $3 ORDER BY position ASC LIMIT $4 ` ,
userID , deviceID , from , limit ,
2021-08-03 12:06:09 +01:00
)
2021-08-03 09:33:38 +01:00
if len ( rows ) == 0 {
return
}
2021-08-03 12:06:09 +01:00
msgs = make ( [ ] json . RawMessage , len ( rows ) )
2021-08-03 09:33:38 +01:00
for i := range rows {
2021-08-03 12:06:09 +01:00
msgs [ i ] = json . RawMessage ( rows [ i ] . Message )
2022-12-21 11:28:43 +00:00
m := gjson . ParseBytes ( msgs [ i ] )
msgId := m . Get ( ` content.org\.matrix\.msgid ` ) . Str
if msgId != "" {
2023-05-02 14:34:22 +01:00
logger . Info ( ) . Str ( "msgid" , msgId ) . Str ( "user" , userID ) . Str ( "device" , deviceID ) . Msg ( "ToDeviceTable.Messages" )
2022-12-21 11:28:43 +00:00
}
2021-08-03 09:33:38 +01:00
}
2021-08-05 10:54:04 +01:00
upTo = rows [ len ( rows ) - 1 ] . Position
2021-08-03 09:33:38 +01:00
return
2021-08-02 16:45:09 +01:00
}
2023-05-02 14:34:22 +01:00
func ( t * ToDeviceTable ) InsertMessages ( userID , deviceID string , msgs [ ] json . RawMessage ) ( pos int64 , err error ) {
2021-08-03 12:06:09 +01:00
var lastPos int64
err = sqlutil . WithTransaction ( t . db , func ( txn * sqlx . Tx ) error {
2022-12-14 18:53:55 +00:00
var unackPos int64
2023-05-02 14:34:22 +01:00
err = txn . QueryRow ( ` SELECT unack_pos FROM syncv3_to_device_ack_pos WHERE user_id=$1 AND device_id=$2 ` , userID , deviceID ) . Scan ( & unackPos )
2022-12-14 18:53:55 +00:00
if err != nil && err != sql . ErrNoRows {
return fmt . Errorf ( "unable to select unacked pos: %s" , err )
}
// Some of these events may be "cancel" actions. If we find events for the unique key of this event, then delete them
// and ignore the "cancel" action.
cancels := [ ] string { }
allRequests := make ( map [ string ] struct { } )
allCancels := make ( map [ string ] struct { } )
2021-08-03 12:06:09 +01:00
rows := make ( [ ] ToDeviceRow , len ( msgs ) )
for i := range msgs {
2022-02-22 17:04:31 +00:00
m := gjson . ParseBytes ( msgs [ i ] )
2021-08-03 12:06:09 +01:00
rows [ i ] = ToDeviceRow {
2023-05-02 14:34:22 +01:00
UserID : userID ,
2021-08-03 12:06:09 +01:00
DeviceID : deviceID ,
2021-12-14 11:51:47 +00:00
Message : string ( msgs [ i ] ) ,
2022-02-22 17:04:31 +00:00
Type : m . Get ( "type" ) . Str ,
Sender : m . Get ( "sender" ) . Str ,
2021-08-03 12:06:09 +01:00
}
2022-12-21 11:28:43 +00:00
msgId := m . Get ( ` content.org\.matrix\.msgid ` ) . Str
if msgId != "" {
2023-05-02 14:34:22 +01:00
logger . Debug ( ) . Str ( "msgid" , msgId ) . Str ( "user" , userID ) . Str ( "device" , deviceID ) . Msg ( "ToDeviceTable.InsertMessages" )
2022-12-21 11:28:43 +00:00
}
2022-12-14 18:53:55 +00:00
switch rows [ i ] . Type {
case "m.room_key_request" :
action := m . Get ( "content.action" ) . Str
if action == "request" {
rows [ i ] . Action = ActionRequest
} else if action == "request_cancellation" {
rows [ i ] . Action = ActionCancel
}
// "the same request_id and requesting_device_id fields, sent by the same user."
key := fmt . Sprintf ( "%s-%s-%s-%s" , rows [ i ] . Type , rows [ i ] . Sender , m . Get ( "content.requesting_device_id" ) . Str , m . Get ( "content.request_id" ) . Str )
rows [ i ] . UniqueKey = & key
}
if rows [ i ] . Action == ActionCancel && rows [ i ] . UniqueKey != nil {
cancels = append ( cancels , * rows [ i ] . UniqueKey )
allCancels [ * rows [ i ] . UniqueKey ] = struct { } { }
} else if rows [ i ] . Action == ActionRequest && rows [ i ] . UniqueKey != nil {
allRequests [ * rows [ i ] . UniqueKey ] = struct { } { }
}
}
if len ( cancels ) > 0 {
var cancelled [ ] string
// delete action: request events which have the same unique key, for this device inbox, only if they are not sent to the client already (unacked)
2023-05-02 14:34:22 +01:00
err = txn . Select ( & cancelled , ` DELETE FROM syncv3_to_device_messages WHERE unique_key = ANY($1) AND user_id = $2 AND device_id = $3 AND position > $4 RETURNING unique_key ` ,
pq . StringArray ( cancels ) , userID , deviceID , unackPos )
2022-12-14 18:53:55 +00:00
if err != nil {
return fmt . Errorf ( "failed to delete cancelled events: %s" , err )
}
cancelledInDBSet := make ( map [ string ] struct { } , len ( cancelled ) )
for _ , ukey := range cancelled {
cancelledInDBSet [ ukey ] = struct { } { }
}
// do not insert the cancelled unique keys
newRows := make ( [ ] ToDeviceRow , 0 , len ( rows ) )
for i := range rows {
if rows [ i ] . UniqueKey != nil {
ukey := * rows [ i ] . UniqueKey
_ , exists := cancelledInDBSet [ ukey ]
if exists {
continue // the request was deleted so don't insert the cancel
}
// we may be requesting and cancelling in one go, check it and ignore if so
_ , reqExists := allRequests [ ukey ]
_ , cancelExists := allCancels [ ukey ]
if reqExists && cancelExists {
continue
}
}
newRows = append ( newRows , rows [ i ] )
}
rows = newRows
}
// we may have nothing to do if the entire set of events were cancellations
if len ( rows ) == 0 {
return nil
2021-08-02 16:45:09 +01:00
}
2021-08-03 12:06:09 +01:00
2023-05-02 14:34:22 +01:00
chunks := sqlutil . Chunkify ( 7 , MaxPostgresParameters , ToDeviceRowChunker ( rows ) )
2021-08-03 12:06:09 +01:00
for _ , chunk := range chunks {
2023-05-02 14:34:22 +01:00
result , err := txn . NamedQuery ( ` INSERT INTO syncv3_to_device_messages ( user_id , device_id , message , event_type , sender , action , unique_key )
VALUES ( : user_id , : device_id , : message , : event_type , : sender , : action , : unique_key ) RETURNING position ` , chunk )
2021-08-03 12:06:09 +01:00
if err != nil {
return err
}
for result . Next ( ) {
if err = result . Scan ( & lastPos ) ; err != nil {
2022-12-14 18:53:55 +00:00
result . Close ( )
2021-08-03 12:06:09 +01:00
return err
}
}
result . Close ( )
2021-08-03 09:33:38 +01:00
}
2021-08-03 12:06:09 +01:00
return nil
} )
return lastPos , err
2021-08-02 16:45:09 +01:00
}