mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
234 lines
7.0 KiB
Go
234 lines
7.0 KiB
Go
package syncv3_test
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"reflect"
|
|
"sort"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/matrix-org/complement/client"
|
|
"github.com/matrix-org/sliding-sync/sync3"
|
|
"github.com/matrix-org/sliding-sync/testutils/m"
|
|
"github.com/tidwall/gjson"
|
|
)
|
|
|
|
// Test environment configuration, read from the process environment at start-up.
var (
	// proxyBaseURL is the base URL of the sliding sync proxy (SYNCV3_ADDR).
	proxyBaseURL = os.Getenv("SYNCV3_ADDR")
	// homeserverBaseURL is the base URL used when registering test users (SYNCV3_SERVER).
	homeserverBaseURL = os.Getenv("SYNCV3_SERVER")
)

// userCounter is bumped with sync/atomic when registering users so that every
// generated localpart in this process is distinct.
var userCounter uint64
|
|
|
|
func TestMain(m *testing.M) {
|
|
if proxyBaseURL == "" {
|
|
fmt.Println("SYNCV3_ADDR must be set e.g 'http://localhost:8008'")
|
|
os.Exit(1)
|
|
}
|
|
fmt.Println("proxy located at", proxyBaseURL)
|
|
exitCode := m.Run()
|
|
os.Exit(exitCode)
|
|
}
|
|
|
|
func assertEventsEqual(t *testing.T, wantList []Event, gotList []json.RawMessage) {
|
|
t.Helper()
|
|
err := eventsEqual(wantList, gotList)
|
|
if err != nil {
|
|
t.Errorf(err.Error())
|
|
}
|
|
}
|
|
|
|
func eventsEqual(wantList []Event, gotList []json.RawMessage) error {
|
|
if len(wantList) != len(gotList) {
|
|
return fmt.Errorf("got %d events, want %d", len(gotList), len(wantList))
|
|
}
|
|
for i := 0; i < len(wantList); i++ {
|
|
want := wantList[i]
|
|
var got Event
|
|
if err := json.Unmarshal(gotList[i], &got); err != nil {
|
|
return fmt.Errorf("failed to unmarshal event %d: %s", i, err)
|
|
}
|
|
if want.ID != "" && got.ID != want.ID {
|
|
return fmt.Errorf("event %d ID mismatch: got %v want %v", i, got.ID, want.ID)
|
|
}
|
|
if want.Content != nil && !reflect.DeepEqual(got.Content, want.Content) {
|
|
return fmt.Errorf("event %d content mismatch: got %+v want %+v", i, got.Content, want.Content)
|
|
}
|
|
if want.Type != "" && want.Type != got.Type {
|
|
return fmt.Errorf("event %d Type mismatch: got %v want %v", i, got.Type, want.Type)
|
|
}
|
|
if want.StateKey != nil {
|
|
if got.StateKey == nil {
|
|
return fmt.Errorf("event %d StateKey mismatch: want %v got <nil>", i, *want.StateKey)
|
|
} else if *want.StateKey != *got.StateKey {
|
|
return fmt.Errorf("event %d StateKey mismatch: got %v want %v", i, *got.StateKey, *want.StateKey)
|
|
}
|
|
}
|
|
if want.Sender != "" && want.Sender != got.Sender {
|
|
return fmt.Errorf("event %d Sender mismatch: got %v want %v", i, got.Sender, want.Sender)
|
|
}
|
|
// loop each key on unsigned as unsigned also includes "age" which is unpredictable so cannot DeepEqual
|
|
if want.Unsigned != nil {
|
|
for k, v := range want.Unsigned {
|
|
got := got.Unsigned[k]
|
|
if !reflect.DeepEqual(got, v) {
|
|
return fmt.Errorf("event %d Unsigned.%s mismatch: got %v want %v", i, k, got, v)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// MatchRoomTimelineMostRecent builds a matcher which checks that the last `n` elements
|
|
// of `events` are the same as the last n elements of the room timeline. If either list
|
|
// contains fewer than `n` events, the match fails.
|
|
// Events are tested for equality using `eventsEqual`.
|
|
func MatchRoomTimelineMostRecent(n int, events []Event) m.RoomMatcher {
|
|
return func(r sync3.Room) error {
|
|
if len(events) < n {
|
|
return fmt.Errorf("list of wanted events has %d events, expected at least %d", len(events), n)
|
|
}
|
|
wantList := events[len(events)-n:]
|
|
if len(r.Timeline) < n {
|
|
return fmt.Errorf("timeline has %d events, expected at least %d", len(r.Timeline), n)
|
|
}
|
|
|
|
gotList := r.Timeline[len(r.Timeline)-n:]
|
|
return eventsEqual(wantList, gotList)
|
|
}
|
|
}
|
|
|
|
func MatchRoomTimeline(events []Event) m.RoomMatcher {
|
|
return func(r sync3.Room) error {
|
|
return eventsEqual(events, r.Timeline)
|
|
}
|
|
}
|
|
|
|
func MatchRoomTimelineContains(event Event) m.RoomMatcher {
|
|
return func(r sync3.Room) error {
|
|
var err error
|
|
for _, got := range r.Timeline {
|
|
if err = eventsEqual([]Event{event}, []json.RawMessage{got}); err == nil {
|
|
return nil
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
}
|
|
|
|
func MatchRoomRequiredState(events []Event) m.RoomMatcher {
|
|
return func(r sync3.Room) error {
|
|
// allow any ordering for required state
|
|
for _, want := range events {
|
|
found := false
|
|
for _, got := range r.RequiredState {
|
|
if err := eventsEqual([]Event{want}, []json.RawMessage{got}); err == nil {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
return fmt.Errorf("required state want event %+v but did not find exact match", want)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func MatchRoomRequiredStateStrict(events []Event) m.RoomMatcher {
|
|
return func(r sync3.Room) error {
|
|
if len(r.RequiredState) != len(events) {
|
|
return fmt.Errorf("required state length mismatch, got %d want %d", len(r.RequiredState), len(events))
|
|
}
|
|
return MatchRoomRequiredState(events)(r)
|
|
}
|
|
}
|
|
|
|
func MatchRoomInviteState(events []Event, partial bool) m.RoomMatcher {
|
|
return func(r sync3.Room) error {
|
|
if !partial && len(r.InviteState) != len(events) {
|
|
return fmt.Errorf("MatchRoomInviteState: length mismatch, got %d want %d", len(r.InviteState), len(events))
|
|
}
|
|
// allow any ordering for state
|
|
for _, want := range events {
|
|
found := false
|
|
for _, got := range r.InviteState {
|
|
if err := eventsEqual([]Event{want}, []json.RawMessage{got}); err == nil {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
return fmt.Errorf("MatchRoomInviteState: want event %+v but it does not exist or failed to pass equality checks", want)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// MatchGlobalAccountData builds a matcher which asserts that the account data in a sync
|
|
// response matches the given `globals`, with any ordering.
|
|
// If there is no account data extension in the response, the match fails.
|
|
func MatchGlobalAccountData(globals []Event) m.RespMatcher {
|
|
// sort want list by type
|
|
sort.Slice(globals, func(i, j int) bool {
|
|
return globals[i].Type < globals[j].Type
|
|
})
|
|
|
|
return func(res *sync3.Response) error {
|
|
if res.Extensions.AccountData == nil {
|
|
return fmt.Errorf("MatchGlobalAccountData: no account_data extension")
|
|
}
|
|
if len(globals) != len(res.Extensions.AccountData.Global) {
|
|
return fmt.Errorf("MatchGlobalAccountData: got %v global account data, want %v", len(res.Extensions.AccountData.Global), len(globals))
|
|
}
|
|
// sort the got list by type
|
|
got := res.Extensions.AccountData.Global
|
|
sort.Slice(got, func(i, j int) bool {
|
|
return gjson.GetBytes(got[i], "type").Str < gjson.GetBytes(got[j], "type").Str
|
|
})
|
|
if err := eventsEqual(globals, got); err != nil {
|
|
return fmt.Errorf("MatchGlobalAccountData: %s", err)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func registerNewUser(t *testing.T) *CSAPI {
|
|
return registerNamedUser(t, "user")
|
|
}
|
|
|
|
func registerNamedUser(t *testing.T, localpartPrefix string) *CSAPI {
|
|
// create user
|
|
localpart := fmt.Sprintf("%s-%d-%d", localpartPrefix, time.Now().Unix(), atomic.AddUint64(&userCounter, 1))
|
|
httpClient := client.NewLoggedClient(t, "localhost", nil)
|
|
client := &CSAPI{
|
|
CSAPI: &client.CSAPI{
|
|
Client: httpClient,
|
|
BaseURL: homeserverBaseURL,
|
|
SyncUntilTimeout: 3 * time.Second,
|
|
},
|
|
}
|
|
|
|
client.UserID, client.AccessToken, client.DeviceID = client.RegisterUser(t, localpart, "password")
|
|
parts := strings.Split(client.UserID, ":")
|
|
client.Localpart = parts[0][1:]
|
|
client.Domain = strings.Split(client.UserID, ":")[1]
|
|
return client
|
|
}
|
|
|
|
// ptr returns a pointer to a fresh copy of s. This is handy for filling optional
// *string fields in test expectations.
func ptr(s string) *string {
	p := new(string)
	*p = s
	return p
}
|
|
|
|
// assertEqual stops the test immediately if got and want are not deeply equal
// according to reflect.DeepEqual. The failure message is prefixed with msg.
func assertEqual(t *testing.T, msg string, got, want interface{}) {
	t.Helper()
	if reflect.DeepEqual(got, want) {
		return
	}
	t.Fatalf("%s: got %v want %v", msg, got, want)
}
|