Implement MSC4102

This commit is contained in:
Kegan Dougal 2024-02-15 18:02:40 +00:00
parent bbb886efd9
commit 782703cd48
2 changed files with 200 additions and 16 deletions

View File

@ -11,11 +11,13 @@ import (
) )
type receiptEDU struct { type receiptEDU struct {
Type string `json:"type"` Type string `json:"type"`
Content map[string]struct { Content map[string]receiptContent `json:"content"`
Read map[string]receiptInfo `json:"m.read,omitempty"` }
ReadPrivate map[string]receiptInfo `json:"m.read.private,omitempty"`
} `json:"content"` type receiptContent struct {
Read map[string]receiptInfo `json:"m.read,omitempty"`
ReadPrivate map[string]receiptInfo `json:"m.read.private,omitempty"`
} }
type receiptInfo struct { type receiptInfo struct {
@ -164,29 +166,35 @@ func (t *ReceiptTable) bulkInsert(tableName string, txn *sqlx.Tx, receipts []int
// client connections. // client connections.
func PackReceiptsIntoEDU(receipts []internal.Receipt) (json.RawMessage, error) { func PackReceiptsIntoEDU(receipts []internal.Receipt) (json.RawMessage, error) {
newReceiptEDU := receiptEDU{ newReceiptEDU := receiptEDU{
Type: "m.receipt", Type: "m.receipt",
Content: make(map[string]struct { Content: make(map[string]receiptContent),
Read map[string]receiptInfo `json:"m.read,omitempty"`
ReadPrivate map[string]receiptInfo `json:"m.read.private,omitempty"`
}),
} }
for _, r := range receipts { for _, r := range receipts {
thisReceiptIsUnthreaded := r.ThreadID == ""
receiptsForEvent := newReceiptEDU.Content[r.EventID] receiptsForEvent := newReceiptEDU.Content[r.EventID]
if r.IsPrivate { if r.IsPrivate {
if receiptsForEvent.ReadPrivate == nil { if receiptsForEvent.ReadPrivate == nil {
receiptsForEvent.ReadPrivate = make(map[string]receiptInfo) receiptsForEvent.ReadPrivate = make(map[string]receiptInfo)
} }
receiptsForEvent.ReadPrivate[r.UserID] = receiptInfo{ // MSC4102: always replace threaded receipts with unthreaded ones if there is a clash
TS: r.TS, _, receiptAlreadyExists := receiptsForEvent.ReadPrivate[r.UserID]
ThreadID: r.ThreadID, if !receiptAlreadyExists || (receiptAlreadyExists && thisReceiptIsUnthreaded) {
receiptsForEvent.ReadPrivate[r.UserID] = receiptInfo{
TS: r.TS,
ThreadID: r.ThreadID,
}
} }
} else { } else {
if receiptsForEvent.Read == nil { if receiptsForEvent.Read == nil {
receiptsForEvent.Read = make(map[string]receiptInfo) receiptsForEvent.Read = make(map[string]receiptInfo)
} }
receiptsForEvent.Read[r.UserID] = receiptInfo{ // MSC4102: always replace threaded receipts with unthreaded ones if there is a clash
TS: r.TS, _, receiptAlreadyExists := receiptsForEvent.Read[r.UserID]
ThreadID: r.ThreadID, if !receiptAlreadyExists || (receiptAlreadyExists && thisReceiptIsUnthreaded) {
receiptsForEvent.Read[r.UserID] = receiptInfo{
TS: r.TS,
ThreadID: r.ThreadID,
}
} }
} }
newReceiptEDU.Content[r.EventID] = receiptsForEvent newReceiptEDU.Content[r.EventID] = receiptsForEvent

View File

@ -31,6 +31,182 @@ func parsedReceiptsEqual(t *testing.T, got, want []internal.Receipt) {
} }
} }
func TestReceiptPacking(t *testing.T) {
testCases := []struct {
receipts []internal.Receipt
wantEDU receiptEDU
name string
}{
{
name: "single receipt",
receipts: []internal.Receipt{
{
RoomID: "!foo",
EventID: "$bar",
UserID: "@baz",
TS: 42,
},
},
wantEDU: receiptEDU{
Type: "m.receipt",
Content: map[string]receiptContent{
"$bar": {
Read: map[string]receiptInfo{
"@baz": {
TS: 42,
},
},
},
},
},
},
{
name: "two distinct receipt",
receipts: []internal.Receipt{
{
RoomID: "!foo",
EventID: "$bar",
UserID: "@baz",
TS: 42,
},
{
RoomID: "!foo2",
EventID: "$bar2",
UserID: "@baz2",
TS: 422,
},
},
wantEDU: receiptEDU{
Type: "m.receipt",
Content: map[string]receiptContent{
"$bar": {
Read: map[string]receiptInfo{
"@baz": {
TS: 42,
},
},
},
"$bar2": {
Read: map[string]receiptInfo{
"@baz2": {
TS: 422,
},
},
},
},
},
},
{
name: "MSC4102: unthreaded wins when threaded first",
receipts: []internal.Receipt{
{
RoomID: "!foo",
EventID: "$bar",
UserID: "@baz",
TS: 42,
ThreadID: "thread_id",
},
{
RoomID: "!foo",
EventID: "$bar",
UserID: "@baz",
TS: 420,
},
},
wantEDU: receiptEDU{
Type: "m.receipt",
Content: map[string]receiptContent{
"$bar": {
Read: map[string]receiptInfo{
"@baz": {
TS: 420,
},
},
},
},
},
},
{
name: "MSC4102: unthreaded wins when unthreaded first",
receipts: []internal.Receipt{
{
RoomID: "!foo",
EventID: "$bar",
UserID: "@baz",
TS: 420,
},
{
RoomID: "!foo",
EventID: "$bar",
UserID: "@baz",
TS: 42,
ThreadID: "thread_id",
},
},
wantEDU: receiptEDU{
Type: "m.receipt",
Content: map[string]receiptContent{
"$bar": {
Read: map[string]receiptInfo{
"@baz": {
TS: 420,
},
},
},
},
},
},
{
name: "MSC4102: unthreaded wins in private receipts when unthreaded first",
receipts: []internal.Receipt{
{
RoomID: "!foo",
EventID: "$bar",
UserID: "@baz",
TS: 420,
IsPrivate: true,
},
{
RoomID: "!foo",
EventID: "$bar",
UserID: "@baz",
TS: 42,
ThreadID: "thread_id",
IsPrivate: true,
},
},
wantEDU: receiptEDU{
Type: "m.receipt",
Content: map[string]receiptContent{
"$bar": {
ReadPrivate: map[string]receiptInfo{
"@baz": {
TS: 420,
},
},
},
},
},
},
}
for _, tc := range testCases {
edu, err := PackReceiptsIntoEDU(tc.receipts)
if err != nil {
t.Fatalf("%s: PackReceiptsIntoEDU: %s", tc.name, err)
}
gotEDU := receiptEDU{
Type: "m.receipt",
Content: make(map[string]receiptContent),
}
if err := json.Unmarshal(edu, &gotEDU); err != nil {
t.Fatalf("%s: json.Unmarshal: %s", tc.name, err)
}
if !reflect.DeepEqual(gotEDU, tc.wantEDU) {
t.Errorf("%s: EDU mismatch, got %+v\n want %+v", tc.name, gotEDU, tc.wantEDU)
}
}
}
func TestReceiptTable(t *testing.T) { func TestReceiptTable(t *testing.T) {
db, close := connectToDB(t) db, close := connectToDB(t)
defer close() defer close()