package syncv3_test import ( "encoding/json" "fmt" "os" "reflect" "sort" "strings" "sync/atomic" "testing" "time" "github.com/matrix-org/complement/client" "github.com/matrix-org/sliding-sync/sync3" "github.com/matrix-org/sliding-sync/testutils/m" "github.com/tidwall/gjson" ) var ( proxyBaseURL = os.Getenv("SYNCV3_ADDR") homeserverBaseURL = os.Getenv("SYNCV3_SERVER") userCounter uint64 ) func TestMain(m *testing.M) { if proxyBaseURL == "" { fmt.Println("SYNCV3_ADDR must be set e.g 'http://localhost:8008'") os.Exit(1) } fmt.Println("proxy located at", proxyBaseURL) exitCode := m.Run() os.Exit(exitCode) } func assertEventsEqual(t *testing.T, wantList []Event, gotList []json.RawMessage) { t.Helper() err := eventsEqual(wantList, gotList) if err != nil { t.Errorf(err.Error()) } } func eventsEqual(wantList []Event, gotList []json.RawMessage) error { if len(wantList) != len(gotList) { return fmt.Errorf("got %d events, want %d", len(gotList), len(wantList)) } for i := 0; i < len(wantList); i++ { want := wantList[i] var got Event if err := json.Unmarshal(gotList[i], &got); err != nil { return fmt.Errorf("failed to unmarshal event %d: %s", i, err) } if want.ID != "" && got.ID != want.ID { return fmt.Errorf("event %d ID mismatch: got %v want %v", i, got.ID, want.ID) } if want.Content != nil && !reflect.DeepEqual(got.Content, want.Content) { return fmt.Errorf("event %d content mismatch: got %+v want %+v", i, got.Content, want.Content) } if want.Type != "" && want.Type != got.Type { return fmt.Errorf("event %d Type mismatch: got %v want %v", i, got.Type, want.Type) } if want.StateKey != nil { if got.StateKey == nil { return fmt.Errorf("event %d StateKey mismatch: want %v got ", i, *want.StateKey) } else if *want.StateKey != *got.StateKey { return fmt.Errorf("event %d StateKey mismatch: got %v want %v", i, *got.StateKey, *want.StateKey) } } if want.Sender != "" && want.Sender != got.Sender { return fmt.Errorf("event %d Sender mismatch: got %v want %v", i, got.Sender, want.Sender) } // loop each key on unsigned as unsigned also includes "age" which is unpredictable so cannot DeepEqual if want.Unsigned != nil { for k, v := range want.Unsigned { got := got.Unsigned[k] if !reflect.DeepEqual(got, v) { return fmt.Errorf("event %d Unsigned.%s mismatch: got %v want %v", i, k, got, v) } } } } return nil } // MatchRoomTimelineMostRecent builds a matcher which checks that the last `n` elements // of `events` are the same as the last n elements of the room timeline. If either list // contains fewer than `n` events, the match fails. // Events are tested for equality using `eventsEqual`. func MatchRoomTimelineMostRecent(n int, events []Event) m.RoomMatcher { return func(r sync3.Room) error { if len(events) < n { return fmt.Errorf("list of wanted events has %d events, expected at least %d", len(events), n) } wantList := events[len(events)-n:] if len(r.Timeline) < n { return fmt.Errorf("timeline has %d events, expected at least %d", len(r.Timeline), n) } gotList := r.Timeline[len(r.Timeline)-n:] return eventsEqual(wantList, gotList) } } func MatchRoomTimeline(events []Event) m.RoomMatcher { return func(r sync3.Room) error { return eventsEqual(events, r.Timeline) } } func MatchRoomTimelineContains(event Event) m.RoomMatcher { return func(r sync3.Room) error { var err error for _, got := range r.Timeline { if err = eventsEqual([]Event{event}, []json.RawMessage{got}); err == nil { return nil } } return err } } func MatchRoomRequiredState(events []Event) m.RoomMatcher { return func(r sync3.Room) error { // allow any ordering for required state for _, want := range events { found := false for _, got := range r.RequiredState { if err := eventsEqual([]Event{want}, []json.RawMessage{got}); err == nil { found = true break } } if !found { return fmt.Errorf("required state want event %+v but did not find exact match", want) } } return nil } } func MatchRoomRequiredStateStrict(events []Event) m.RoomMatcher { return func(r sync3.Room) error { if len(r.RequiredState) != len(events) { return fmt.Errorf("required state length mismatch, got %d want %d", len(r.RequiredState), len(events)) } return MatchRoomRequiredState(events)(r) } } func MatchRoomInviteState(events []Event, partial bool) m.RoomMatcher { return func(r sync3.Room) error { if !partial && len(r.InviteState) != len(events) { return fmt.Errorf("MatchRoomInviteState: length mismatch, got %d want %d", len(r.InviteState), len(events)) } // allow any ordering for state for _, want := range events { found := false for _, got := range r.InviteState { if err := eventsEqual([]Event{want}, []json.RawMessage{got}); err == nil { found = true break } } if !found { return fmt.Errorf("MatchRoomInviteState: want event %+v but it does not exist or failed to pass equality checks", want) } } return nil } } // MatchGlobalAccountData builds a matcher which asserts that the account data in a sync // response matches the given `globals`, with any ordering. // If there is no account data extension in the response, the match fails. func MatchGlobalAccountData(globals []Event) m.RespMatcher { // sort want list by type sort.Slice(globals, func(i, j int) bool { return globals[i].Type < globals[j].Type }) return func(res *sync3.Response) error { if res.Extensions.AccountData == nil { return fmt.Errorf("MatchGlobalAccountData: no account_data extension") } if len(globals) != len(res.Extensions.AccountData.Global) { return fmt.Errorf("MatchGlobalAccountData: got %v global account data, want %v", len(res.Extensions.AccountData.Global), len(globals)) } // sort the got list by type got := res.Extensions.AccountData.Global sort.Slice(got, func(i, j int) bool { return gjson.GetBytes(got[i], "type").Str < gjson.GetBytes(got[j], "type").Str }) if err := eventsEqual(globals, got); err != nil { return fmt.Errorf("MatchGlobalAccountData: %s", err) } return nil } } func registerNewUser(t *testing.T) *CSAPI { return registerNamedUser(t, "user") } func registerNamedUser(t *testing.T, localpartPrefix string) *CSAPI { // create user localpart := fmt.Sprintf("%s-%d-%d", localpartPrefix, time.Now().Unix(), atomic.AddUint64(&userCounter, 1)) httpClient := client.NewLoggedClient(t, "localhost", nil) client := &CSAPI{ CSAPI: &client.CSAPI{ Client: httpClient, BaseURL: homeserverBaseURL, SyncUntilTimeout: 3 * time.Second, }, } client.UserID, client.AccessToken, client.DeviceID = client.RegisterUser(t, localpart, "password") parts := strings.Split(client.UserID, ":") client.Localpart = parts[0][1:] client.Domain = strings.Split(client.UserID, ":")[1] return client } func ptr(s string) *string { return &s } func assertEqual(t *testing.T, msg string, got, want interface{}) { t.Helper() if !reflect.DeepEqual(got, want) { t.Fatalf("%s: got %v want %v", msg, got, want) } }