Attempt to prevent data races in PubSub

This commit is contained in:
David Robertson 2023-08-09 16:02:10 +01:00
parent 8cd00cc731
commit 32a7a05af0
No known key found for this signature in database
GPG Key ID: 903ECE108A39DEDD

View File

@ -49,8 +49,14 @@ type Notifier interface {
Close() error
}
type channel struct {
data chan Payload
closing chan struct{}
}
type PubSub struct {
chans map[string]chan Payload
chans map[string]channel
// mu guards the chans map.
// It does not guard any access to the channels themselves.
mu *sync.Mutex
closed bool
bufferSize int
@ -58,33 +64,49 @@ type PubSub struct {
func NewPubSub(bufferSize int) *PubSub {
return &PubSub{
chans: make(map[string]chan Payload),
chans: make(map[string]channel),
mu: &sync.Mutex{},
bufferSize: bufferSize,
}
}
func (ps *PubSub) getChan(chanName string) chan Payload {
func (ps *PubSub) getChan(chanName string) (chan Payload, chan struct{}) {
ps.mu.Lock()
defer ps.mu.Unlock()
ch := ps.chans[chanName]
if ch == nil {
ch = make(chan Payload, ps.bufferSize)
if ch.data == nil {
ch = channel{
data: make(chan Payload, ps.bufferSize),
closing: make(chan struct{}, 1),
}
ps.chans[chanName] = ch
}
return ch
return ch.data, ch.closing
}
func (ps *PubSub) Notify(chanName string, p Payload) error {
ch := ps.getChan(chanName)
data, closing := ps.getChan(chanName)
select {
case ch <- p:
case <-closing:
// Do not send if we've been asked to shut down.
// This avoids races between closing and sending.
close(data)
return nil
case data <- p:
break
case <-time.After(5 * time.Second):
return fmt.Errorf("notify with payload %v timed out", p.Type())
}
if ps.bufferSize == 0 {
ch <- &emptyPayload{}
select {
case <-closing:
close(data)
return nil
case data <- &emptyPayload{}:
break
case <-time.After(5 * time.Second):
return fmt.Errorf("notify with empty payload timed out")
}
}
return nil
}
@ -97,13 +119,14 @@ func (ps *PubSub) Close() error {
ps.mu.Lock()
defer ps.mu.Unlock()
for _, ch := range ps.chans {
close(ch)
ch.closing <- struct{}{}
close(ch.closing)
}
return nil
}
func (ps *PubSub) Listen(chanName string, fn func(p Payload)) error {
ch := ps.getChan(chanName)
ch, _ := ps.getChan(chanName)
for payload := range ch {
if payload.Type() == emptyPayloadType {
continue