Fix race condition in test

This commit is contained in:
Kegan Dougal 2023-10-11 12:36:21 +01:00
parent 4d8cbb2709
commit d65c4ebdcf

View File

@ -10,12 +10,31 @@ import (
"github.com/matrix-org/sliding-sync/pubsub" "github.com/matrix-org/sliding-sync/pubsub"
) )
type syncSlice[T any] struct {
slice []T
mu sync.Mutex
}
func (s *syncSlice[T]) append(item T) {
s.mu.Lock()
defer s.mu.Unlock()
s.slice = append(s.slice, item)
}
func (s *syncSlice[T]) clone() []T {
s.mu.Lock()
defer s.mu.Unlock()
result := make([]T, len(s.slice))
copy(result, s.slice)
return result
}
func TestDeviceTickerBasic(t *testing.T) { func TestDeviceTickerBasic(t *testing.T) {
duration := time.Millisecond duration := time.Millisecond
ticker := NewDeviceDataTicker(duration) ticker := NewDeviceDataTicker(duration)
var payloads []*pubsub.V2DeviceData var payloads syncSlice[*pubsub.V2DeviceData]
ticker.SetCallback(func(payload *pubsub.V2DeviceData) { ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
payloads = append(payloads, payload) payloads.append(payload)
}) })
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
@ -31,29 +50,31 @@ func TestDeviceTickerBasic(t *testing.T) {
DeviceID: "b", DeviceID: "b",
}) })
time.Sleep(duration * 2) time.Sleep(duration * 2)
if len(payloads) != 1 { result := payloads.clone()
t.Fatalf("expected 1 callback, got %d", len(payloads)) if len(result) != 1 {
t.Fatalf("expected 1 callback, got %d", len(result))
} }
want := map[string][]string{ want := map[string][]string{
"a": {"b"}, "a": {"b"},
} }
assertPayloadEqual(t, payloads[0].UserIDToDeviceIDs, want) assertPayloadEqual(t, result[0].UserIDToDeviceIDs, want)
// check stopping works // check stopping works
payloads = []*pubsub.V2DeviceData{} payloads = syncSlice[*pubsub.V2DeviceData]{}
ticker.Stop() ticker.Stop()
wg.Wait() wg.Wait()
time.Sleep(duration * 2) time.Sleep(duration * 2)
if len(payloads) != 0 { result = payloads.clone()
t.Fatalf("got extra payloads: %+v", payloads) if len(result) != 0 {
t.Fatalf("got extra payloads: %+v", result)
} }
} }
func TestDeviceTickerBatchesCorrectly(t *testing.T) { func TestDeviceTickerBatchesCorrectly(t *testing.T) {
duration := 100 * time.Millisecond duration := 100 * time.Millisecond
ticker := NewDeviceDataTicker(duration) ticker := NewDeviceDataTicker(duration)
var payloads []*pubsub.V2DeviceData var payloads syncSlice[*pubsub.V2DeviceData]
ticker.SetCallback(func(payload *pubsub.V2DeviceData) { ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
payloads = append(payloads, payload) payloads.append(payload)
}) })
go ticker.Run() go ticker.Run()
defer ticker.Stop() defer ticker.Stop()
@ -74,23 +95,23 @@ func TestDeviceTickerBatchesCorrectly(t *testing.T) {
DeviceID: "y", // new device and user DeviceID: "y", // new device and user
}) })
time.Sleep(duration * 2) time.Sleep(duration * 2)
if len(payloads) != 1 { result := payloads.clone()
t.Fatalf("expected 1 callback, got %d", len(payloads)) if len(result) != 1 {
t.Fatalf("expected 1 callback, got %d", len(result))
} }
want := map[string][]string{ want := map[string][]string{
"a": {"b", "bb"}, "a": {"b", "bb"},
"x": {"y"}, "x": {"y"},
} }
assertPayloadEqual(t, payloads[0].UserIDToDeviceIDs, want) assertPayloadEqual(t, result[0].UserIDToDeviceIDs, want)
} }
func TestDeviceTickerForgetsAfterEmitting(t *testing.T) { func TestDeviceTickerForgetsAfterEmitting(t *testing.T) {
duration := time.Millisecond duration := time.Millisecond
ticker := NewDeviceDataTicker(duration) ticker := NewDeviceDataTicker(duration)
var payloads []*pubsub.V2DeviceData var payloads syncSlice[*pubsub.V2DeviceData]
ticker.SetCallback(func(payload *pubsub.V2DeviceData) { ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
payloads = append(payloads, payload) payloads.append(payload)
}) })
ticker.Remember(PollerID{ ticker.Remember(PollerID{
UserID: "a", UserID: "a",
@ -104,8 +125,9 @@ func TestDeviceTickerForgetsAfterEmitting(t *testing.T) {
DeviceID: "b", DeviceID: "b",
}) })
time.Sleep(10 * duration) time.Sleep(10 * duration)
if len(payloads) != 1 { result := payloads.clone()
t.Fatalf("got %d payloads, want 1", len(payloads)) if len(result) != 1 {
t.Fatalf("got %d payloads, want 1", len(result))
} }
} }