diff --git a/sync2/device_data_ticker_test.go b/sync2/device_data_ticker_test.go index daa5081..f07a99a 100644 --- a/sync2/device_data_ticker_test.go +++ b/sync2/device_data_ticker_test.go @@ -10,12 +10,31 @@ import ( "github.com/matrix-org/sliding-sync/pubsub" ) +type syncSlice[T any] struct { + slice []T + mu sync.Mutex +} + +func (s *syncSlice[T]) append(item T) { + s.mu.Lock() + defer s.mu.Unlock() + s.slice = append(s.slice, item) +} + +func (s *syncSlice[T]) clone() []T { + s.mu.Lock() + defer s.mu.Unlock() + result := make([]T, len(s.slice)) + copy(result, s.slice) + return result +} + func TestDeviceTickerBasic(t *testing.T) { duration := time.Millisecond ticker := NewDeviceDataTicker(duration) - var payloads []*pubsub.V2DeviceData + var payloads syncSlice[*pubsub.V2DeviceData] ticker.SetCallback(func(payload *pubsub.V2DeviceData) { - payloads = append(payloads, payload) + payloads.append(payload) }) var wg sync.WaitGroup wg.Add(1) @@ -31,29 +50,31 @@ func TestDeviceTickerBasic(t *testing.T) { DeviceID: "b", }) time.Sleep(duration * 2) - if len(payloads) != 1 { - t.Fatalf("expected 1 callback, got %d", len(payloads)) + result := payloads.clone() + if len(result) != 1 { + t.Fatalf("expected 1 callback, got %d", len(result)) } want := map[string][]string{ "a": {"b"}, } - assertPayloadEqual(t, payloads[0].UserIDToDeviceIDs, want) + assertPayloadEqual(t, result[0].UserIDToDeviceIDs, want) // check stopping works - payloads = []*pubsub.V2DeviceData{} + payloads = syncSlice[*pubsub.V2DeviceData]{} ticker.Stop() wg.Wait() time.Sleep(duration * 2) - if len(payloads) != 0 { - t.Fatalf("got extra payloads: %+v", payloads) + result = payloads.clone() + if len(result) != 0 { + t.Fatalf("got extra payloads: %+v", result) } } func TestDeviceTickerBatchesCorrectly(t *testing.T) { duration := 100 * time.Millisecond ticker := NewDeviceDataTicker(duration) - var payloads []*pubsub.V2DeviceData + var payloads syncSlice[*pubsub.V2DeviceData] ticker.SetCallback(func(payload *pubsub.V2DeviceData) { - payloads = append(payloads, payload) + payloads.append(payload) }) go ticker.Run() defer ticker.Stop() @@ -74,23 +95,23 @@ func TestDeviceTickerBatchesCorrectly(t *testing.T) { DeviceID: "y", // new device and user }) time.Sleep(duration * 2) - if len(payloads) != 1 { - t.Fatalf("expected 1 callback, got %d", len(payloads)) + result := payloads.clone() + if len(result) != 1 { + t.Fatalf("expected 1 callback, got %d", len(result)) } want := map[string][]string{ "a": {"b", "bb"}, "x": {"y"}, } - assertPayloadEqual(t, payloads[0].UserIDToDeviceIDs, want) + assertPayloadEqual(t, result[0].UserIDToDeviceIDs, want) } func TestDeviceTickerForgetsAfterEmitting(t *testing.T) { duration := time.Millisecond ticker := NewDeviceDataTicker(duration) - var payloads []*pubsub.V2DeviceData - + var payloads syncSlice[*pubsub.V2DeviceData] ticker.SetCallback(func(payload *pubsub.V2DeviceData) { - payloads = append(payloads, payload) + payloads.append(payload) }) ticker.Remember(PollerID{ UserID: "a", @@ -104,8 +125,9 @@ func TestDeviceTickerForgetsAfterEmitting(t *testing.T) { DeviceID: "b", }) time.Sleep(10 * duration) - if len(payloads) != 1 { - t.Fatalf("got %d payloads, want 1", len(payloads)) + result := payloads.clone() + if len(result) != 1 { + t.Fatalf("got %d payloads, want 1", len(result)) } }