mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
Fix race condition in test
This commit is contained in:
parent
4d8cbb2709
commit
d65c4ebdcf
@ -10,12 +10,31 @@ import (
|
||||
"github.com/matrix-org/sliding-sync/pubsub"
|
||||
)
|
||||
|
||||
type syncSlice[T any] struct {
|
||||
slice []T
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (s *syncSlice[T]) append(item T) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.slice = append(s.slice, item)
|
||||
}
|
||||
|
||||
func (s *syncSlice[T]) clone() []T {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
result := make([]T, len(s.slice))
|
||||
copy(result, s.slice)
|
||||
return result
|
||||
}
|
||||
|
||||
func TestDeviceTickerBasic(t *testing.T) {
|
||||
duration := time.Millisecond
|
||||
ticker := NewDeviceDataTicker(duration)
|
||||
var payloads []*pubsub.V2DeviceData
|
||||
var payloads syncSlice[*pubsub.V2DeviceData]
|
||||
ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
|
||||
payloads = append(payloads, payload)
|
||||
payloads.append(payload)
|
||||
})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
@ -31,29 +50,31 @@ func TestDeviceTickerBasic(t *testing.T) {
|
||||
DeviceID: "b",
|
||||
})
|
||||
time.Sleep(duration * 2)
|
||||
if len(payloads) != 1 {
|
||||
t.Fatalf("expected 1 callback, got %d", len(payloads))
|
||||
result := payloads.clone()
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("expected 1 callback, got %d", len(result))
|
||||
}
|
||||
want := map[string][]string{
|
||||
"a": {"b"},
|
||||
}
|
||||
assertPayloadEqual(t, payloads[0].UserIDToDeviceIDs, want)
|
||||
assertPayloadEqual(t, result[0].UserIDToDeviceIDs, want)
|
||||
// check stopping works
|
||||
payloads = []*pubsub.V2DeviceData{}
|
||||
payloads = syncSlice[*pubsub.V2DeviceData]{}
|
||||
ticker.Stop()
|
||||
wg.Wait()
|
||||
time.Sleep(duration * 2)
|
||||
if len(payloads) != 0 {
|
||||
t.Fatalf("got extra payloads: %+v", payloads)
|
||||
result = payloads.clone()
|
||||
if len(result) != 0 {
|
||||
t.Fatalf("got extra payloads: %+v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceTickerBatchesCorrectly(t *testing.T) {
|
||||
duration := 100 * time.Millisecond
|
||||
ticker := NewDeviceDataTicker(duration)
|
||||
var payloads []*pubsub.V2DeviceData
|
||||
var payloads syncSlice[*pubsub.V2DeviceData]
|
||||
ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
|
||||
payloads = append(payloads, payload)
|
||||
payloads.append(payload)
|
||||
})
|
||||
go ticker.Run()
|
||||
defer ticker.Stop()
|
||||
@ -74,23 +95,23 @@ func TestDeviceTickerBatchesCorrectly(t *testing.T) {
|
||||
DeviceID: "y", // new device and user
|
||||
})
|
||||
time.Sleep(duration * 2)
|
||||
if len(payloads) != 1 {
|
||||
t.Fatalf("expected 1 callback, got %d", len(payloads))
|
||||
result := payloads.clone()
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("expected 1 callback, got %d", len(result))
|
||||
}
|
||||
want := map[string][]string{
|
||||
"a": {"b", "bb"},
|
||||
"x": {"y"},
|
||||
}
|
||||
assertPayloadEqual(t, payloads[0].UserIDToDeviceIDs, want)
|
||||
assertPayloadEqual(t, result[0].UserIDToDeviceIDs, want)
|
||||
}
|
||||
|
||||
func TestDeviceTickerForgetsAfterEmitting(t *testing.T) {
|
||||
duration := time.Millisecond
|
||||
ticker := NewDeviceDataTicker(duration)
|
||||
var payloads []*pubsub.V2DeviceData
|
||||
|
||||
var payloads syncSlice[*pubsub.V2DeviceData]
|
||||
ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
|
||||
payloads = append(payloads, payload)
|
||||
payloads.append(payload)
|
||||
})
|
||||
ticker.Remember(PollerID{
|
||||
UserID: "a",
|
||||
@ -104,8 +125,9 @@ func TestDeviceTickerForgetsAfterEmitting(t *testing.T) {
|
||||
DeviceID: "b",
|
||||
})
|
||||
time.Sleep(10 * duration)
|
||||
if len(payloads) != 1 {
|
||||
t.Fatalf("got %d payloads, want 1", len(payloads))
|
||||
result := payloads.clone()
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("got %d payloads, want 1", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user