mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
824 lines
25 KiB
Go
824 lines
25 KiB
Go
package m
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"sort"
|
|
"testing"
|
|
|
|
"github.com/matrix-org/sliding-sync/sync3"
|
|
"github.com/tidwall/gjson"
|
|
)
|
|
|
|
type RespMatcher func(res *sync3.Response) error
|
|
type ListMatcher func(list sync3.ResponseList) error
|
|
type OpMatcher func(op sync3.ResponseOp) error
|
|
type RoomMatcher func(r sync3.Room) error
|
|
|
|
// LogRoom builds a matcher that always succeeds. As a side-effect, it pretty-prints
|
|
// the given room to the test log. This is useful when debugging a test.
|
|
func LogRoom(t *testing.T) RoomMatcher {
|
|
return func(room sync3.Room) error {
|
|
dump, _ := json.MarshalIndent(room, "", " ")
|
|
t.Logf("Response was: %s", dump)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchRoomName(name string) RoomMatcher {
|
|
return func(r sync3.Room) error {
|
|
if name == "" {
|
|
return nil
|
|
}
|
|
if r.Name != name {
|
|
return fmt.Errorf("name mismatch, got %s want %s", r.Name, name)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// MatchRoomAvatar builds a RoomMatcher which checks that the given room response has
|
|
// set the room's avatar to the given value.
|
|
func MatchRoomAvatar(wantAvatar string) RoomMatcher {
|
|
return func(r sync3.Room) error {
|
|
if string(r.AvatarChange) != wantAvatar {
|
|
return fmt.Errorf("MatchRoomAvatar: got \"%s\" want \"%s\"", r.AvatarChange, wantAvatar)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// MatchRoomUnsetAvatar builds a RoomMatcher which checks that the given room has no
|
|
// avatar, or has had its avatar deleted.
|
|
func MatchRoomUnsetAvatar() RoomMatcher {
|
|
return func(r sync3.Room) error {
|
|
if r.AvatarChange != sync3.DeletedAvatar {
|
|
return fmt.Errorf("MatchRoomAvatar: got \"%s\" want \"%s\"", r.AvatarChange, sync3.DeletedAvatar)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// MatchRoomUnchangedAvatar builds a RoomMatcher which checks that the given room has no
|
|
// change to its avatar, or has had its avatar deleted.
|
|
func MatchRoomUnchangedAvatar() RoomMatcher {
|
|
return func(r sync3.Room) error {
|
|
if r.AvatarChange != sync3.UnchangedAvatar {
|
|
return fmt.Errorf("MatchRoomAvatar: got \"%s\" want \"%s\"", r.AvatarChange, sync3.UnchangedAvatar)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchRoomIsDM(wantDM bool) RoomMatcher {
|
|
return func(r sync3.Room) error {
|
|
if r.IsDM != wantDM {
|
|
return fmt.Errorf("MatchRoomIsDM: got %t want %t", r.IsDM, wantDM)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchJoinCount(count int) RoomMatcher {
|
|
return func(r sync3.Room) error {
|
|
if r.JoinedCount != count {
|
|
return fmt.Errorf("MatchJoinCount: got %v want %v", r.JoinedCount, count)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchNoInviteCount() RoomMatcher {
|
|
return func(r sync3.Room) error {
|
|
if r.InvitedCount != nil {
|
|
return fmt.Errorf("MatchInviteCount: invited_count is present when it should be missing: val=%v", *r.InvitedCount)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchInviteCount(count int) RoomMatcher {
|
|
return func(r sync3.Room) error {
|
|
if r.InvitedCount == nil {
|
|
return fmt.Errorf("MatchInviteCount: invited_count is missing")
|
|
}
|
|
if *r.InvitedCount != count {
|
|
return fmt.Errorf("MatchInviteCount: got %v want %v", *r.InvitedCount, count)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchNumLive(numLive int) RoomMatcher {
|
|
return func(r sync3.Room) error {
|
|
if r.NumLive != numLive {
|
|
return fmt.Errorf("MatchNumLive: got %v want %v", r.NumLive, numLive)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchRoomRequiredState(events []json.RawMessage) RoomMatcher {
|
|
return func(r sync3.Room) error {
|
|
if len(r.RequiredState) != len(events) {
|
|
return fmt.Errorf("required state length mismatch, got %d want %d", len(r.RequiredState), len(events))
|
|
}
|
|
// allow any ordering for required state
|
|
for _, want := range events {
|
|
found := false
|
|
for _, got := range r.RequiredState {
|
|
if bytes.Equal(got, want) {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
return fmt.Errorf("required state want event %v but it does not exist", string(want))
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
func MatchRoomInviteState(events []json.RawMessage) RoomMatcher {
|
|
return func(r sync3.Room) error {
|
|
if len(r.InviteState) != len(events) {
|
|
return fmt.Errorf("invite state length mismatch, got %d want %d", len(r.InviteState), len(events))
|
|
}
|
|
// allow any ordering for required state
|
|
for _, want := range events {
|
|
found := false
|
|
for _, got := range r.InviteState {
|
|
if bytes.Equal(got, want) {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
return fmt.Errorf("required state want event %v but it does not exist", string(want))
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchRoomHasInviteState() RoomMatcher {
|
|
return func(r sync3.Room) error {
|
|
if len(r.InviteState) == 0 {
|
|
return fmt.Errorf("missing or empty invite state, expected at least one piece of invite state")
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchRoomLacksInviteState() RoomMatcher {
|
|
return func(r sync3.Room) error {
|
|
if len(r.InviteState) > 0 {
|
|
return fmt.Errorf("invite state present, but expected no invite state")
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Similar to MatchRoomTimeline but takes the last n events of `events` and only checks with the last
|
|
// n events of the timeline.
|
|
func MatchRoomTimelineMostRecent(n int, events []json.RawMessage) RoomMatcher {
|
|
subset := events[len(events)-n:]
|
|
return func(r sync3.Room) error {
|
|
if len(r.Timeline) < len(subset) {
|
|
return fmt.Errorf("MatchRoomTimelineMostRecent: timeline length mismatch: got %d want at least %d", len(r.Timeline), len(subset))
|
|
}
|
|
gotSubset := r.Timeline[len(r.Timeline)-n:]
|
|
for i := range gotSubset {
|
|
if !bytes.Equal(gotSubset[i], subset[i]) {
|
|
return fmt.Errorf("timeline[%d]\ngot %v \nwant %v", i, string(gotSubset[i]), string(subset[i]))
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchRoomPrevBatch(prevBatch string) RoomMatcher {
|
|
return func(r sync3.Room) error {
|
|
if prevBatch != r.PrevBatch {
|
|
return fmt.Errorf("MatchRoomPrevBatch: got %v want %v", r.PrevBatch, prevBatch)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Match the timeline with exactly these events in exactly this order
|
|
func MatchRoomTimeline(events []json.RawMessage) RoomMatcher {
|
|
return func(r sync3.Room) error {
|
|
if len(r.Timeline) != len(events) {
|
|
return fmt.Errorf("timeline length mismatch: got %d want %d", len(r.Timeline), len(events))
|
|
}
|
|
for i := range r.Timeline {
|
|
if !bytes.Equal(r.Timeline[i], events[i]) {
|
|
return fmt.Errorf("timeline[%d]\ngot %v \nwant %v", i, string(r.Timeline[i]), string(events[i]))
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchRoomHighlightCount(count int64) RoomMatcher {
|
|
return func(r sync3.Room) error {
|
|
if r.HighlightCount != count {
|
|
return fmt.Errorf("highlight count mismatch, got %d want %d", r.HighlightCount, count)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
func MatchRoomNotificationCount(count int64) RoomMatcher {
|
|
return func(r sync3.Room) error {
|
|
if r.NotificationCount != count {
|
|
return fmt.Errorf("notification count mismatch, got %d want %d", r.NotificationCount, count)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchRoomInitial(initial bool) RoomMatcher {
|
|
return func(r sync3.Room) error {
|
|
if r.Initial != initial {
|
|
return fmt.Errorf("MatchRoomInitial: got %v want %v", r.Initial, initial)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchV3Count(wantCount int) ListMatcher {
|
|
return func(res sync3.ResponseList) error {
|
|
if res.Count != wantCount {
|
|
return fmt.Errorf("list got count %d want %d", res.Count, wantCount)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchRoomSubscriptionsStrict(wantSubs map[string][]RoomMatcher) RespMatcher {
|
|
return func(res *sync3.Response) error {
|
|
if len(res.Rooms) != len(wantSubs) {
|
|
return fmt.Errorf("MatchRoomSubscriptionsStrict: strict length on: got %v subs want %v", len(res.Rooms), len(wantSubs))
|
|
}
|
|
for roomID, matchers := range wantSubs {
|
|
room, ok := res.Rooms[roomID]
|
|
if !ok {
|
|
return fmt.Errorf("MatchRoomSubscriptionsStrict: want sub for %s but it was missing", roomID)
|
|
}
|
|
for _, m := range matchers {
|
|
if err := m(room); err != nil {
|
|
return fmt.Errorf("MatchRoomSubscriptionsStrict[%s]: %s", roomID, err)
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchRoomSubscription(roomID string, matchers ...RoomMatcher) RespMatcher {
|
|
return func(res *sync3.Response) error {
|
|
room, ok := res.Rooms[roomID]
|
|
if !ok {
|
|
return fmt.Errorf("MatchRoomSubscription[%s]: want sub but it was missing", roomID)
|
|
}
|
|
errs := make([]error, 0, len(matchers))
|
|
for _, m := range matchers {
|
|
if err := m(room); err != nil {
|
|
errs = append(errs, err)
|
|
}
|
|
}
|
|
|
|
if len(errs) > 1 {
|
|
return fmt.Errorf("MatchRoomSubscription[%s]: %d errors:\n%w", roomID, len(errs), errors.Join(errs...))
|
|
} else if len(errs) == 1 {
|
|
return fmt.Errorf("MatchRoomSubscription[%s]: %w", roomID, errs[0])
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchRoomSubscriptions(wantSubs map[string][]RoomMatcher) RespMatcher {
|
|
return func(res *sync3.Response) error {
|
|
for roomID, matchers := range wantSubs {
|
|
room, ok := res.Rooms[roomID]
|
|
if !ok {
|
|
return fmt.Errorf("MatchRoomSubscriptions: want sub for %s but it was missing", roomID)
|
|
}
|
|
for _, m := range matchers {
|
|
if err := m(room); err != nil {
|
|
return fmt.Errorf("MatchRoomSubscriptions[%s]: %s", roomID, err)
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchNoE2EEExtension() RespMatcher {
|
|
return func(res *sync3.Response) error {
|
|
if res.Extensions.E2EE != nil {
|
|
return fmt.Errorf("MatchNoE2EEExtension: got E2EE extension: %+v", res.Extensions.E2EE)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchNoReceiptsExtension() RespMatcher {
|
|
return func(res *sync3.Response) error {
|
|
if res.Extensions.Receipts != nil {
|
|
return fmt.Errorf("MatchNoReceiptsExtension: got Receipts extension: %+v", res.Extensions.Receipts)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchOTKCounts(otkCounts map[string]int) RespMatcher {
|
|
return func(res *sync3.Response) error {
|
|
if res.Extensions.E2EE == nil {
|
|
return fmt.Errorf("MatchOTKCounts: no E2EE extension present")
|
|
}
|
|
if !reflect.DeepEqual(res.Extensions.E2EE.OTKCounts, otkCounts) {
|
|
return fmt.Errorf("MatchOTKCounts: got %v want %v", res.Extensions.E2EE.OTKCounts, otkCounts)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchFallbackKeyTypes(fallbackKeyTypes []string) RespMatcher {
|
|
return func(res *sync3.Response) error {
|
|
if res.Extensions.E2EE == nil {
|
|
return fmt.Errorf("MatchFallbackKeyTypes: no E2EE extension present")
|
|
}
|
|
if res.Extensions.E2EE.FallbackKeyTypes == nil { // not supplied
|
|
if fallbackKeyTypes == nil {
|
|
return nil
|
|
}
|
|
return fmt.Errorf("MatchFallbackKeyTypes: FallbackKeyTypes is missing but want %v", fallbackKeyTypes)
|
|
}
|
|
if !reflect.DeepEqual(*res.Extensions.E2EE.FallbackKeyTypes, fallbackKeyTypes) {
|
|
return fmt.Errorf("MatchFallbackKeyTypes: got %v want %v", res.Extensions.E2EE.FallbackKeyTypes, fallbackKeyTypes)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchDeviceLists(changed, left []string) RespMatcher {
|
|
return func(res *sync3.Response) error {
|
|
if res.Extensions.E2EE == nil {
|
|
return fmt.Errorf("MatchDeviceLists: no E2EE extension present")
|
|
}
|
|
if res.Extensions.E2EE.DeviceLists == nil {
|
|
return fmt.Errorf("MatchDeviceLists: no device lists present")
|
|
}
|
|
if !reflect.DeepEqual(res.Extensions.E2EE.DeviceLists.Changed, changed) {
|
|
return fmt.Errorf("MatchDeviceLists: got changed: %v want %v", res.Extensions.E2EE.DeviceLists.Changed, changed)
|
|
}
|
|
if !reflect.DeepEqual(res.Extensions.E2EE.DeviceLists.Left, left) {
|
|
return fmt.Errorf("MatchDeviceLists: got left: %v want %v", res.Extensions.E2EE.DeviceLists.Left, left)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchToDeviceMessages(wantMsgs []json.RawMessage) RespMatcher {
|
|
return func(res *sync3.Response) error {
|
|
if res.Extensions.ToDevice == nil {
|
|
return fmt.Errorf("MatchToDeviceMessages: missing to_device extension")
|
|
}
|
|
if len(res.Extensions.ToDevice.Events) != len(wantMsgs) {
|
|
return fmt.Errorf("MatchToDeviceMessages: got %d events, want %d", len(res.Extensions.ToDevice.Events), len(wantMsgs))
|
|
}
|
|
for i := 0; i < len(wantMsgs); i++ {
|
|
if !reflect.DeepEqual(res.Extensions.ToDevice.Events[i], wantMsgs[i]) {
|
|
return fmt.Errorf("MatchToDeviceMessages[%d]: got %v want %v", i, string(res.Extensions.ToDevice.Events[i]), string(wantMsgs[i]))
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchV3SyncOp(start, end int64, roomIDs []string, anyOrder ...bool) OpMatcher {
|
|
allowAnyOrder := len(anyOrder) > 0 && anyOrder[0]
|
|
return func(op sync3.ResponseOp) error {
|
|
if op.Op() != sync3.OpSync {
|
|
return fmt.Errorf("op: %s != %s", op.Op(), sync3.OpSync)
|
|
}
|
|
oper := op.(*sync3.ResponseOpRange)
|
|
if oper.Range[0] != start {
|
|
return fmt.Errorf("%s: got start %d want %d", sync3.OpSync, oper.Range[0], start)
|
|
}
|
|
if oper.Range[1] != end {
|
|
return fmt.Errorf("%s: got end %d want %d", sync3.OpSync, oper.Range[1], end)
|
|
}
|
|
if allowAnyOrder {
|
|
sort.Strings(oper.RoomIDs)
|
|
sort.Strings(roomIDs)
|
|
}
|
|
if !reflect.DeepEqual(roomIDs, oper.RoomIDs) {
|
|
return fmt.Errorf("%s: got rooms %v want %v", sync3.OpSync, oper.RoomIDs, roomIDs)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchV3SyncOpFn(fn func(op *sync3.ResponseOpRange) error) OpMatcher {
|
|
return func(op sync3.ResponseOp) error {
|
|
if op.Op() != sync3.OpSync {
|
|
return fmt.Errorf("op: %s != %s", op.Op(), sync3.OpSync)
|
|
}
|
|
oper := op.(*sync3.ResponseOpRange)
|
|
return fn(oper)
|
|
}
|
|
}
|
|
|
|
func MatchV3InsertOp(roomIndex int, roomID string) OpMatcher {
|
|
return func(op sync3.ResponseOp) error {
|
|
if op.Op() != sync3.OpInsert {
|
|
return fmt.Errorf("op: %s != %s", op.Op(), sync3.OpInsert)
|
|
}
|
|
oper := op.(*sync3.ResponseOpSingle)
|
|
if *oper.Index != roomIndex {
|
|
return fmt.Errorf("%s: got index %d want %d", sync3.OpInsert, *oper.Index, roomIndex)
|
|
}
|
|
if oper.RoomID != roomID {
|
|
return fmt.Errorf("%s: got %s want %s", sync3.OpInsert, oper.RoomID, roomID)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchV3DeleteOp(roomIndex int) OpMatcher {
|
|
return func(op sync3.ResponseOp) error {
|
|
if op.Op() != sync3.OpDelete {
|
|
return fmt.Errorf("op: %s != %s", op.Op(), sync3.OpDelete)
|
|
}
|
|
oper := op.(*sync3.ResponseOpSingle)
|
|
if *oper.Index != roomIndex {
|
|
return fmt.Errorf("%s: got room index %d want %d", sync3.OpDelete, *oper.Index, roomIndex)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchV3InvalidateOp(start, end int64) OpMatcher {
|
|
return func(op sync3.ResponseOp) error {
|
|
if op.Op() != sync3.OpInvalidate {
|
|
return fmt.Errorf("op: %s != %s", op.Op(), sync3.OpInvalidate)
|
|
}
|
|
oper := op.(*sync3.ResponseOpRange)
|
|
if oper.Range[0] != start {
|
|
return fmt.Errorf("%s: got start %d want %d", sync3.OpInvalidate, oper.Range[0], start)
|
|
}
|
|
if oper.Range[1] != end {
|
|
return fmt.Errorf("%s: got end %d want %d", sync3.OpInvalidate, oper.Range[1], end)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchNoV3Ops() RespMatcher {
|
|
return func(res *sync3.Response) error {
|
|
for key, l := range res.Lists {
|
|
if len(l.Ops) > 0 {
|
|
return fmt.Errorf("MatchNoV3Ops: list %v got %d ops", key, len(l.Ops))
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchV3Ops(matchOps ...OpMatcher) ListMatcher {
|
|
return func(res sync3.ResponseList) error {
|
|
if len(matchOps) != len(res.Ops) {
|
|
return fmt.Errorf("MatchV3Ops: got %d ops want %d", len(res.Ops), len(matchOps))
|
|
}
|
|
for i := range res.Ops {
|
|
op := res.Ops[i]
|
|
if err := matchOps[i](op); err != nil {
|
|
return fmt.Errorf("MatchV3Ops: op[%d](%s) - %s", i, op.Op(), err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchTyping(roomID string, wantUserIDs []string) RespMatcher {
|
|
return func(res *sync3.Response) error {
|
|
if res.Extensions.Typing == nil {
|
|
return fmt.Errorf("MatchTyping: no typing extension")
|
|
}
|
|
if len(res.Extensions.Typing.Rooms) == 0 || res.Extensions.Typing.Rooms[roomID] == nil {
|
|
serialised, _ := json.Marshal(res.Extensions.Typing)
|
|
return fmt.Errorf("MatchTyping: missing room %s: got %s", roomID, serialised)
|
|
}
|
|
sort.Strings(wantUserIDs)
|
|
ev := res.Extensions.Typing.Rooms[roomID]
|
|
userIDs := gjson.ParseBytes(ev).Get("content.user_ids").Array()
|
|
gotUserIDs := make([]string, len(userIDs))
|
|
for i := range userIDs {
|
|
gotUserIDs[i] = userIDs[i].Str
|
|
}
|
|
sort.Strings(gotUserIDs)
|
|
if !reflect.DeepEqual(gotUserIDs, wantUserIDs) {
|
|
return fmt.Errorf("MatchTyping: mismatched typing users, got %v want %v", gotUserIDs, wantUserIDs)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchNotTyping(roomID string, dontWantUserIDs []string) RespMatcher {
|
|
return func(res *sync3.Response) error {
|
|
if res.Extensions.Typing == nil {
|
|
return nil
|
|
}
|
|
if len(res.Extensions.Typing.Rooms) == 0 || res.Extensions.Typing.Rooms[roomID] == nil {
|
|
return nil
|
|
}
|
|
ev := res.Extensions.Typing.Rooms[roomID]
|
|
typingIDs := gjson.ParseBytes(ev).Get("content.user_ids").Array()
|
|
|
|
for _, dontWantID := range dontWantUserIDs {
|
|
for _, typingID := range typingIDs {
|
|
// Quick and dirty: report the first mismatch we see.
|
|
if dontWantID == typingID.Str {
|
|
return fmt.Errorf("MatchTyping: user %s should not be typing in %s, but is", dontWantID, roomID)
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Receipt identifies a single receipt as compared by MatchReceipts: the event it
// refers to, the user who sent it, the receipt type and an optional thread ID.
type Receipt struct {
	EventID  string
	UserID   string
	Type     string
	ThreadID string
}

// sortReceipts sorts receipts in place, ordering by EventID, then UserID, then Type,
// then ThreadID.
//
// Fields are compared one at a time rather than via a concatenated key: concatenation
// is ambiguous ("a"+"bc" equals "ab"+"c"), which would let two distinct receipts tie
// and end up in an input-dependent order.
func sortReceipts(receipts []Receipt) {
	sort.Slice(receipts, func(i, j int) bool {
		a, b := receipts[i], receipts[j]
		switch {
		case a.EventID != b.EventID:
			return a.EventID < b.EventID
		case a.UserID != b.UserID:
			return a.UserID < b.UserID
		case a.Type != b.Type:
			return a.Type < b.Type
		default:
			return a.ThreadID < b.ThreadID
		}
	})
}
|
|
|
|
// MatchReceipts builds a matcher which asserts that a sync response has the expected
|
|
// set of read receipts in a given room is the expected set of `wantReceipts`.
|
|
//
|
|
// The match fails if:
|
|
// - there is no receipts extension in the sync response,
|
|
// - the room is missing from the sync response and `wantReceipts` is nonempty,
|
|
// - the room is present in the sync response but has a different set of receipts
|
|
// to `wantReceipts`.
|
|
func MatchReceipts(roomID string, wantReceipts []Receipt) RespMatcher {
|
|
return func(res *sync3.Response) error {
|
|
if res.Extensions.Receipts == nil {
|
|
return fmt.Errorf("MatchReceipts: no receipts extension")
|
|
}
|
|
if len(res.Extensions.Receipts.Rooms) == 0 || res.Extensions.Receipts.Rooms[roomID] == nil {
|
|
if len(wantReceipts) == 0 {
|
|
return nil // want nothing
|
|
}
|
|
return fmt.Errorf("MatchReceipts: missing room %s: got %+v", roomID, res.Extensions.Receipts)
|
|
}
|
|
var gotReceipts []Receipt
|
|
ev := res.Extensions.Receipts.Rooms[roomID]
|
|
gjson.ParseBytes(ev).Get("content").ForEach(func(key, value gjson.Result) bool {
|
|
eventID := key.Str
|
|
value.ForEach(func(key, value gjson.Result) bool {
|
|
receiptType := key.Str
|
|
value.ForEach(func(key, value gjson.Result) bool {
|
|
userID := key.Str
|
|
threadID := value.Get("thread_id").Str
|
|
gotReceipts = append(gotReceipts, Receipt{
|
|
EventID: eventID,
|
|
UserID: userID,
|
|
Type: receiptType,
|
|
ThreadID: threadID,
|
|
})
|
|
return true
|
|
})
|
|
return true
|
|
})
|
|
return true
|
|
})
|
|
sortReceipts(gotReceipts)
|
|
sortReceipts(wantReceipts)
|
|
if !reflect.DeepEqual(gotReceipts, wantReceipts) {
|
|
return fmt.Errorf("MatchReceipts: wrong receipts, got %v want %v", gotReceipts, wantReceipts)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// MatchAccountData builds a matcher which asserts that the account data in a sync
|
|
// response /exactly/ matches the given `globals` and `rooms`, up to ordering.
|
|
//
|
|
// - If there is no account data extension in the response, the matche fails.
|
|
// - If globals is non-nil:
|
|
// - if globals is not equal to the global account data, the match fails.
|
|
// Equality is determined using EqualAnyOrder.
|
|
//
|
|
// - If rooms is non-nil:
|
|
// - If the set of given rooms differs from the set of rooms present in the account
|
|
// data response, the match fails.
|
|
// - If a given room's account data events are not equal to its account data events
|
|
// in the sync response, the match fails. Again, equality is determined using
|
|
// EqualAnyOrder.
|
|
func MatchAccountData(globals []json.RawMessage, rooms map[string][]json.RawMessage) RespMatcher {
|
|
return func(res *sync3.Response) error {
|
|
if res.Extensions.AccountData == nil {
|
|
return fmt.Errorf("MatchAccountData: no account_data extension")
|
|
}
|
|
if len(globals) > 0 {
|
|
if err := EqualAnyOrder(res.Extensions.AccountData.Global, globals); err != nil {
|
|
return fmt.Errorf("MatchAccountData[global]: %s", err)
|
|
}
|
|
}
|
|
if len(rooms) > 0 {
|
|
if len(rooms) != len(res.Extensions.AccountData.Rooms) {
|
|
return fmt.Errorf("MatchAccountData: got %d rooms with account data, want %d", len(res.Extensions.AccountData.Rooms), len(rooms))
|
|
}
|
|
for roomID := range rooms {
|
|
gots := res.Extensions.AccountData.Rooms[roomID]
|
|
if gots == nil {
|
|
return fmt.Errorf("MatchAccountData: want room account data for %s but it was missing", roomID)
|
|
}
|
|
if err := EqualAnyOrder(gots, rooms[roomID]); err != nil {
|
|
return fmt.Errorf("MatchAccountData[room]: %s", err)
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// MatchHasGlobalAccountData builds a matcher which asserts that the given event is
|
|
// present in a global account data response.
|
|
func MatchHasGlobalAccountData(want json.RawMessage) RespMatcher {
|
|
return func(res *sync3.Response) error {
|
|
if res.Extensions.AccountData == nil {
|
|
return fmt.Errorf("No account data section in sync response")
|
|
}
|
|
for _, msg := range res.Extensions.AccountData.Global {
|
|
if bytes.Equal(msg, want) {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return fmt.Errorf("could not find %s in global account data", want)
|
|
}
|
|
}
|
|
|
|
// MatchNoGlobalAccountData builds a matcher which asserts that no global account data
|
|
// is present in a sync response.
|
|
func MatchNoGlobalAccountData() RespMatcher {
|
|
return func(res *sync3.Response) error {
|
|
if res.Extensions.AccountData == nil {
|
|
return nil
|
|
}
|
|
accountDataEvents := res.Extensions.AccountData.Global
|
|
if len(accountDataEvents) > 0 {
|
|
return fmt.Errorf("MatchNoGlobalAccountData: got %d account data events, but expected none", len(accountDataEvents))
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// MatchNoRoomAccountData builds a matcher which asserts that none of the given roomIDs
|
|
// have room account data in a sync response.
|
|
func MatchNoRoomAccountData(roomIDs []string) RespMatcher {
|
|
return func(res *sync3.Response) error {
|
|
if res.Extensions.AccountData == nil {
|
|
return nil
|
|
}
|
|
for _, roomID := range roomIDs {
|
|
// quick and dirty: complain the first time we see something we shouldn't
|
|
roomData := res.Extensions.AccountData.Rooms[roomID]
|
|
if roomData != nil {
|
|
return fmt.Errorf("MatchNoRoomAccountData: got account data for %s, but expected it to be missing", roomID)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// LogResponse builds a matcher that always succeeds. As a side-effect, it pretty-prints
|
|
// the given sync response to the test log. This is useful when debugging a test.
|
|
func LogResponse(t *testing.T) RespMatcher {
|
|
return func(res *sync3.Response) error {
|
|
t.Helper()
|
|
dump, _ := json.MarshalIndent(res, "", " ")
|
|
t.Logf("Response was: %s", dump)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// LogRooms is like LogResponse, but only logs the rooms section of the response.
|
|
func LogRooms(t *testing.T) RespMatcher {
|
|
return func(res *sync3.Response) error {
|
|
t.Helper()
|
|
dump, _ := json.MarshalIndent(res.Rooms, "", " ")
|
|
t.Logf("Response rooms were: %s", dump)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func CheckList(listKey string, res sync3.ResponseList, matchers ...ListMatcher) error {
|
|
for _, m := range matchers {
|
|
if err := m(res); err != nil {
|
|
return fmt.Errorf("MatchList[%v]: %v", listKey, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func MatchTxnID(txnID string) RespMatcher {
|
|
return func(res *sync3.Response) error {
|
|
if txnID != res.TxnID {
|
|
return fmt.Errorf("MatchTxnID: got %v want %v", res.TxnID, txnID)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchList(listKey string, matchers ...ListMatcher) RespMatcher {
|
|
return func(res *sync3.Response) error {
|
|
if _, exists := res.Lists[listKey]; !exists {
|
|
return fmt.Errorf("MatchSingleList: key %v does not exist, got %d lists", listKey, len(res.Lists))
|
|
}
|
|
list := res.Lists[listKey]
|
|
return CheckList(listKey, list, matchers...)
|
|
}
|
|
}
|
|
|
|
func MatchLists(matchers map[string][]ListMatcher) RespMatcher {
|
|
return func(res *sync3.Response) error {
|
|
if len(matchers) != len(res.Lists) {
|
|
return fmt.Errorf("MatchLists: got %d matchers for %d lists", len(matchers), len(res.Lists))
|
|
}
|
|
for listKey, matchersForList := range matchers {
|
|
if err := CheckList(listKey, res.Lists[listKey], matchersForList...); err != nil {
|
|
return fmt.Errorf("MatchLists[%v]: %v", listKey, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// ANSI escape sequences used to colour failure output.
const (
	// AnsiRedForeground switches the terminal foreground colour to red.
	AnsiRedForeground = "\x1b[31m"
	// AnsiResetForeground restores the terminal's default foreground colour.
	AnsiResetForeground = "\x1b[39m"
)
|
|
|
|
func MatchResponse(t *testing.T, res *sync3.Response, matchers ...RespMatcher) {
|
|
t.Helper()
|
|
errs := []error{}
|
|
for _, m := range matchers {
|
|
err := m(res)
|
|
if err != nil {
|
|
errs = append(errs, fmt.Errorf("%v%s%v", AnsiRedForeground, err, AnsiResetForeground))
|
|
}
|
|
}
|
|
|
|
if len(errs) > 0 {
|
|
if len(errs) == 1 {
|
|
t.Errorf("%vMatchResponse: %s", AnsiRedForeground, errs[0])
|
|
} else {
|
|
t.Errorf("%vMatchResponse: there were %d errors\n%s", AnsiRedForeground, len(errs), errors.Join(errs...))
|
|
}
|
|
LogResponse(t)(res)
|
|
}
|
|
}
|
|
|
|
func CheckRoom(r sync3.Room, matchers ...RoomMatcher) error {
|
|
for _, m := range matchers {
|
|
if err := m(r); err != nil {
|
|
return fmt.Errorf("MatchRoom : %s", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// EqualAnyOrder reports whether got and want hold the same raw JSON messages,
// compared byte-for-byte, irrespective of ordering. It returns nil when they do and
// an error describing the first difference otherwise.
//
// Both inputs are left untouched: the comparison is made on sorted copies, so neither
// the caller's expectations nor the sync response being inspected are reordered.
func EqualAnyOrder(got, want []json.RawMessage) error {
	if len(got) != len(want) {
		return fmt.Errorf("EqualAnyOrder: got %d, want %d", len(got), len(want))
	}
	// sortedCopy returns a lexicographically sorted duplicate of in.
	sortedCopy := func(in []json.RawMessage) []json.RawMessage {
		out := make([]json.RawMessage, len(in))
		copy(out, in)
		sort.Slice(out, func(i, j int) bool {
			return bytes.Compare(out[i], out[j]) < 0
		})
		return out
	}
	sortedGot, sortedWant := sortedCopy(got), sortedCopy(want)
	for i := range sortedGot {
		if !bytes.Equal(sortedGot[i], sortedWant[i]) {
			return fmt.Errorf("EqualAnyOrder: [%d] got %v want %v", i, string(sortedGot[i]), string(sortedWant[i]))
		}
	}
	return nil
}
|