Add regression UT

This commit is contained in:
Kegan Dougal 2023-09-15 10:52:44 +01:00
parent 54a030236f
commit 864d67a5b3

View File

@ -129,6 +129,71 @@ func TestEnsurePollerCachesResponses(t *testing.T) {
n.MustHaveNoSentPayloads(t)
}
// Regression test for when we did cache failures, causing no poller to start for the device
func TestEnsurePollerDoesntCacheFailures(t *testing.T) {
n := &mockNotifier{ch: make(chan pubsub.Payload, 100)}
ctx := context.Background()
pid := sync2.PollerID{UserID: "@alice:localhost", DeviceID: "DEVICE"}
ep := NewEnsurePoller(n, false)
finished := make(chan bool) // dummy
var expired atomic.Bool
go func() {
exp := ep.EnsurePolling(ctx, pid, "tokenHash")
expired.Store(exp)
close(finished)
}()
n.WaitForNextPayload(t, time.Second) // wait for V3EnsurePolling
// send back the response, which failed
ep.OnInitialSyncComplete(&pubsub.V2InitialSyncComplete{
UserID: pid.UserID,
DeviceID: pid.DeviceID,
Success: false,
})
select {
case <-finished:
case <-time.After(time.Second):
t.Fatalf("EnsurePolling didn't unblock after response was sent")
}
if !expired.Load() {
t.Fatalf("EnsurePolling returned not expired, wanted expired due to Success=false")
}
// hitting EnsurePolling again should do a new request (i.e not cached the failure)
var expiredAgain atomic.Bool
finished = make(chan bool) // dummy
go func() {
exp := ep.EnsurePolling(ctx, pid, "tokenHash")
expiredAgain.Store(exp)
close(finished)
}()
p := n.WaitForNextPayload(t, time.Second) // wait for V3EnsurePolling
// check it's a V3EnsurePolling payload
pp, ok := p.(*pubsub.V3EnsurePolling)
if !ok {
t.Fatalf("unexpected payload: %+v", p)
}
assertVal(t, pp.UserID, pid.UserID)
assertVal(t, pp.DeviceID, pid.DeviceID)
assertVal(t, pp.AccessTokenHash, "tokenHash")
// send back the response, which succeeded this time
ep.OnInitialSyncComplete(&pubsub.V2InitialSyncComplete{
UserID: pid.UserID,
DeviceID: pid.DeviceID,
Success: true,
})
select {
case <-finished:
case <-time.After(time.Second):
t.Fatalf("EnsurePolling didn't unblock after response was sent")
}
}
func assertVal(t *testing.T, got, want interface{}) {
t.Helper()
if !reflect.DeepEqual(got, want) {