PollerMap: ensure callbacks are always called from a single goroutine

Document a nasty race condition which can happen if >1 user is joined
to the same room. Fix it by ensuring that `GlobalCache` always stays
in sync with the database without having to hit the database.
This commit is contained in:
Kegan Dougal 2021-10-28 16:15:17 +01:00
parent a9a57ddfda
commit 9f3364d9ed
5 changed files with 114 additions and 64 deletions

View File

@ -323,6 +323,7 @@ const indexInRange = (i) => {
const doSyncLoop = async(accessToken, sessionId) => {
console.log("Starting sync loop. Active: ", activeSessionId, " this:", sessionId);
let currentPos;
let currentError = null;
let currentSub = "";

View File

@ -16,31 +16,55 @@ var timeSleep = time.Sleep
// V2DataReceiver is the receiver for all the v2 sync data the poller gets
type V2DataReceiver interface {
UpdateDeviceSince(deviceID, since string) error
Accumulate(roomID string, timeline []json.RawMessage) error
Initialise(roomID string, state []json.RawMessage) error
SetTyping(roomID string, userIDs []string) (int64, error)
UpdateDeviceSince(deviceID, since string)
Accumulate(roomID string, timeline []json.RawMessage)
Initialise(roomID string, state []json.RawMessage)
SetTyping(roomID string, userIDs []string)
// Add messages for this device. If an error is returned, the poll loop is terminated as continuing
// would implicitly acknowledge these messages.
AddToDeviceMessages(userID, deviceID string, msgs []gomatrixserverlib.SendToDeviceEvent) error
AddToDeviceMessages(userID, deviceID string, msgs []gomatrixserverlib.SendToDeviceEvent)
UpdateUnreadCounts(roomID, userID string, highlightCount, notifCount *int)
}
// PollerMap is a map of device ID to Poller
type PollerMap struct {
v2Client Client
callbacks V2DataReceiver
pollerMu *sync.Mutex
Pollers map[string]*Poller // device_id -> poller
v2Client Client
callbacks V2DataReceiver
pollerMu *sync.Mutex
Pollers map[string]*Poller // device_id -> poller
executor chan func()
executorRunning bool
}
// NewPollerMap makes a new PollerMap. Guarantees that the V2DataReceiver will be called on the same
// goroutine for all pollers. This is required to avoid race conditions at the Go level. Whilst we
// use SQL transactions to ensure that the DB doesn't race, we then subsequently feed new events
// from that call into a global cache. This can race which can result in out of order latest NIDs
// which, if we assert NIDs only increment, will result in missed events.
//
// Consider these events in the same room, with 3 different pollers getting the data:
//
//	1 2 3 4 5 6 7   eventual DB event NID
//	A B C D E F G
//	-----           poll loop 1 = A,B,C         new events = A,B,C  latest=3
//	---------       poll loop 2 = A,B,C,D,E     new events = D,E    latest=5
//	-------------   poll loop 3 = A,B,C,D,E,F,G new events = F,G    latest=7
//
// The DB layer will correctly assign NIDs and stop duplicates, resulting in a set of new events which
// do not overlap. However, there is a gap between this point and updating the cache, where variable
// delays can be introduced, so F,G latest=7 could be injected first. If we then never walk back to
// earlier NIDs, A,B,C,D,E will be dropped from the cache.
//
// This only affects resources which are shared across multiple DEVICES such as:
//   - room resources: events, EDUs
//   - user resources: notif counts, account data
//
// NOT to-device messages, or since tokens.
func NewPollerMap(v2Client Client, callbacks V2DataReceiver) *PollerMap {
	return &PollerMap{
		v2Client:  v2Client,
		callbacks: callbacks,
		pollerMu:  &sync.Mutex{},
		Pollers:   make(map[string]*Poller),
		// Deliberately unbuffered: a send blocks until the executor
		// goroutine picks the callback up, serialising all executor-routed
		// callbacks onto that one goroutine.
		executor: make(chan func()),
	}
}
@ -50,6 +74,10 @@ func NewPollerMap(v2Client Client, callbacks V2DataReceiver) *PollerMap {
// Guarantees only 1 poller will be running per deviceID.
func (h *PollerMap) EnsurePolling(authHeader, userID, deviceID, v2since string, logger zerolog.Logger) {
h.pollerMu.Lock()
if !h.executorRunning {
h.executorRunning = true
go h.execute()
}
poller, ok := h.Pollers[deviceID]
// a poller exists and hasn't been terminated so we don't need to do anything
if ok && !poller.Terminated {
@ -57,7 +85,7 @@ func (h *PollerMap) EnsurePolling(authHeader, userID, deviceID, v2since string,
return
}
// replace the poller
poller = NewPoller(userID, authHeader, deviceID, h.v2Client, h.callbacks, logger)
poller = NewPoller(userID, authHeader, deviceID, h.v2Client, h, logger)
var wg sync.WaitGroup
wg.Add(1)
go poller.Poll(v2since, func() {
@ -68,6 +96,43 @@ func (h *PollerMap) EnsurePolling(authHeader, userID, deviceID, v2since string,
wg.Wait()
}
// execute drains the executor channel, invoking each queued callback in
// turn. Because this is the only reader, every callback routed through the
// executor runs on this single goroutine, one at a time.
func (h *PollerMap) execute() {
	for {
		fn, ok := <-h.executor
		if !ok {
			return
		}
		fn()
	}
}
// UpdateDeviceSince persists the latest since token for this device.
// Deliberately NOT routed via the executor: per the NewPollerMap doc
// comment, since tokens are per-device state and are not shared across
// pollers, so there is no cross-poller ordering to enforce.
func (h *PollerMap) UpdateDeviceSince(deviceID, since string) {
	h.callbacks.UpdateDeviceSince(deviceID, since)
}
// Accumulate forwards new timeline events for a room to the receiver.
// Room data is shared across pollers, so the callback is queued onto the
// executor goroutine to serialise it with all other shared-resource updates.
func (h *PollerMap) Accumulate(roomID string, timeline []json.RawMessage) {
	work := func() {
		h.callbacks.Accumulate(roomID, timeline)
	}
	h.executor <- work
}
// Initialise forwards a room's state snapshot to the receiver.
// Room state is shared across pollers, so the callback is queued onto the
// executor goroutine to serialise it with all other shared-resource updates.
func (h *PollerMap) Initialise(roomID string, state []json.RawMessage) {
	work := func() {
		h.callbacks.Initialise(roomID, state)
	}
	h.executor <- work
}
// SetTyping forwards the set of users currently typing in a room to the
// receiver. Typing EDUs are room (shared) resources, so the callback is
// queued onto the executor goroutine to serialise it.
func (h *PollerMap) SetTyping(roomID string, userIDs []string) {
	work := func() {
		h.callbacks.SetTyping(roomID, userIDs)
	}
	h.executor <- work
}
// AddToDeviceMessages forwards to-device messages straight to the receiver.
// Deliberately NOT routed via the executor: per the NewPollerMap doc
// comment, to-device messages are per-device and not shared across pollers,
// so there is no cross-poller ordering to enforce.
// (NOTE(review): the old "If an error is returned, the poll loop is
// terminated" contract no longer applies — this no longer returns an error.)
func (h *PollerMap) AddToDeviceMessages(userID, deviceID string, msgs []gomatrixserverlib.SendToDeviceEvent) {
	h.callbacks.AddToDeviceMessages(userID, deviceID, msgs)
}
// UpdateUnreadCounts forwards highlight/notification counts for a user in a
// room to the receiver. Notif counts are shared user resources, so the
// callback is queued onto the executor goroutine to serialise it.
func (h *PollerMap) UpdateUnreadCounts(roomID, userID string, highlightCount, notifCount *int) {
	work := func() {
		h.callbacks.UpdateUnreadCounts(roomID, userID, highlightCount, notifCount)
	}
	h.executor <- work
}
// Poller can automatically poll the sync v2 endpoint and accumulate the responses in storage
type Poller struct {
userID string
@ -121,18 +186,11 @@ func (p *Poller) Poll(since string, callback func()) {
}
failCount = 0
p.parseRoomsResponse(resp)
if err = p.parseToDeviceMessages(resp); err != nil {
p.logger.Err(err).Str("since", since).Msg("Poller: V2DataReceiver failed to persist to-device messages. Terminating loop.")
p.Terminated = true
return
}
p.parseToDeviceMessages(resp)
since = resp.NextBatch
// persist the since token (TODO: this could get slow if we hammer the DB too much)
err = p.receiver.UpdateDeviceSince(p.deviceID, since)
if err != nil {
// non-fatal
p.logger.Warn().Str("since", since).Err(err).Msg("Poller: V2DataReceiver failed to persist new since value")
}
p.receiver.UpdateDeviceSince(p.deviceID, since)
if firstTime {
firstTime = false
@ -141,11 +199,11 @@ func (p *Poller) Poll(since string, callback func()) {
}
}
func (p *Poller) parseToDeviceMessages(res *SyncResponse) error {
func (p *Poller) parseToDeviceMessages(res *SyncResponse) {
if len(res.ToDevice.Events) == 0 {
return nil
return
}
return p.receiver.AddToDeviceMessages(p.userID, p.deviceID, res.ToDevice.Events)
p.receiver.AddToDeviceMessages(p.userID, p.deviceID, res.ToDevice.Events)
}
func (p *Poller) parseRoomsResponse(res *SyncResponse) {
@ -155,10 +213,7 @@ func (p *Poller) parseRoomsResponse(res *SyncResponse) {
for roomID, roomData := range res.Rooms.Join {
if len(roomData.State.Events) > 0 {
stateCalls++
err := p.receiver.Initialise(roomID, roomData.State.Events)
if err != nil {
p.logger.Err(err).Str("room_id", roomID).Int("num_state_events", len(roomData.State.Events)).Msg("Poller: V2DataReceiver.Initialise failed")
}
p.receiver.Initialise(roomID, roomData.State.Events)
}
// process unread counts before events else we might push the event without including said event in the count
if roomData.UnreadNotifications.HighlightCount != nil || roomData.UnreadNotifications.NotificationCount != nil {
@ -168,10 +223,7 @@ func (p *Poller) parseRoomsResponse(res *SyncResponse) {
}
if len(roomData.Timeline.Events) > 0 {
timelineCalls++
err := p.receiver.Accumulate(roomID, roomData.Timeline.Events)
if err != nil {
p.logger.Err(err).Str("room_id", roomID).Int("num_timeline_events", len(roomData.Timeline.Events)).Msg("Poller: V2DataReceiver.Accumulate failed")
}
p.receiver.Accumulate(roomID, roomData.Timeline.Events)
}
for _, ephEvent := range roomData.Ephemeral.Events {
if gjson.GetBytes(ephEvent, "type").Str == "m.typing" {
@ -186,10 +238,7 @@ func (p *Poller) parseRoomsResponse(res *SyncResponse) {
}
}
typingCalls++
_, err := p.receiver.SetTyping(roomID, userIDs)
if err != nil {
p.logger.Err(err).Str("room_id", roomID).Strs("user_ids", userIDs).Msg("Poller: V2DataReceiver failed to SetTyping")
}
p.receiver.SetTyping(roomID, userIDs)
}
}
}
@ -197,10 +246,7 @@ func (p *Poller) parseRoomsResponse(res *SyncResponse) {
// TODO: do we care about state?
if len(roomData.Timeline.Events) > 0 {
err := p.receiver.Accumulate(roomID, roomData.Timeline.Events)
if err != nil {
p.logger.Err(err).Str("room_id", roomID).Int("num_timeline_events", len(roomData.Timeline.Events)).Msg("Poller: V2DataReceiver.Accumulate left room failed")
}
p.receiver.Accumulate(roomID, roomData.Timeline.Events)
}
}
p.logger.Info().Ints(

View File

@ -225,23 +225,18 @@ type mockDataReceiver struct {
deviceIDToSince map[string]string
}
func (a *mockDataReceiver) Accumulate(roomID string, timeline []json.RawMessage) error {
func (a *mockDataReceiver) Accumulate(roomID string, timeline []json.RawMessage) {
a.timelines[roomID] = append(a.timelines[roomID], timeline...)
return nil
}
func (a *mockDataReceiver) Initialise(roomID string, state []json.RawMessage) error {
func (a *mockDataReceiver) Initialise(roomID string, state []json.RawMessage) {
a.states[roomID] = state
return nil
}
func (a *mockDataReceiver) SetTyping(roomID string, userIDs []string) (int64, error) {
return 0, nil
func (a *mockDataReceiver) SetTyping(roomID string, userIDs []string) {
}
func (s *mockDataReceiver) UpdateDeviceSince(deviceID, since string) error {
func (s *mockDataReceiver) UpdateDeviceSince(deviceID, since string) {
s.deviceIDToSince[deviceID] = since
return nil
}
func (s *mockDataReceiver) AddToDeviceMessages(userID, deviceID string, msgs []gomatrixserverlib.SendToDeviceEvent) error {
return nil
func (s *mockDataReceiver) AddToDeviceMessages(userID, deviceID string, msgs []gomatrixserverlib.SendToDeviceEvent) {
}
func (s *mockDataReceiver) UpdateUnreadCounts(roomID, userID string, highlightCount, notifCount *int) {

View File

@ -182,7 +182,7 @@ func (s *ConnState) onIncomingRequest(ctx context.Context, req *Request) (*Respo
break blockloop
case <-time.After(10 * time.Second): // TODO configurable
break blockloop
case updateEvent := <-s.updateEvents:
case updateEvent := <-s.updateEvents: // TODO: keep reading until it is empty before responding.
if updateEvent.latestPos > s.loadPosition {
s.loadPosition = updateEvent.latestPos
}

View File

@ -261,53 +261,61 @@ func (h *SyncLiveHandler) userCache(userID string) (*UserCache, error) {
}
// Called from the v2 poller, implements V2DataReceiver
func (h *SyncLiveHandler) UpdateDeviceSince(deviceID, since string) error {
return h.V2Store.UpdateDeviceSince(deviceID, since)
func (h *SyncLiveHandler) UpdateDeviceSince(deviceID, since string) {
err := h.V2Store.UpdateDeviceSince(deviceID, since)
if err != nil {
logger.Err(err).Str("device", deviceID).Str("since", since).Msg("V2: failed to persist since token")
}
}
// Called from the v2 poller, implements V2DataReceiver
func (h *SyncLiveHandler) Accumulate(roomID string, timeline []json.RawMessage) error {
func (h *SyncLiveHandler) Accumulate(roomID string, timeline []json.RawMessage) {
numNew, latestPos, err := h.Storage.Accumulate(roomID, timeline)
if err != nil {
return err
logger.Err(err).Int("timeline", len(timeline)).Str("room", roomID).Msg("V2: failed to accumulate room")
return
}
if numNew == 0 {
// no new events
return nil
return
}
newEvents := timeline[len(timeline)-numNew:]
// we have new events, notify active connections
h.dispatcher.OnNewEvents(roomID, newEvents, latestPos)
return err
}
// Called from the v2 poller, implements V2DataReceiver
func (h *SyncLiveHandler) Initialise(roomID string, state []json.RawMessage) error {
func (h *SyncLiveHandler) Initialise(roomID string, state []json.RawMessage) {
added, err := h.Storage.Initialise(roomID, state)
if err != nil {
return err
logger.Err(err).Int("state", len(state)).Str("room", roomID).Msg("V2: failed to initialise room")
return
}
if !added {
// no new events
return nil
return
}
// we have new events, notify active connections
h.dispatcher.OnNewEvents(roomID, state, 0)
return err
}
// Called from the v2 poller, implements V2DataReceiver
func (h *SyncLiveHandler) SetTyping(roomID string, userIDs []string) (int64, error) {
return h.Storage.TypingTable.SetTyping(roomID, userIDs)
func (h *SyncLiveHandler) SetTyping(roomID string, userIDs []string) {
_, err := h.Storage.TypingTable.SetTyping(roomID, userIDs)
if err != nil {
logger.Err(err).Strs("users", userIDs).Str("room", roomID).Msg("V2: failed to store typing")
}
}
// Called from the v2 poller, implements V2DataReceiver
// Add messages for this device. If an error is returned, the poll loop is terminated as continuing
// would implicitly acknowledge these messages.
func (h *SyncLiveHandler) AddToDeviceMessages(userID, deviceID string, msgs []gomatrixserverlib.SendToDeviceEvent) error {
func (h *SyncLiveHandler) AddToDeviceMessages(userID, deviceID string, msgs []gomatrixserverlib.SendToDeviceEvent) {
_, err := h.Storage.ToDeviceTable.InsertMessages(deviceID, msgs)
return err
if err != nil {
logger.Err(err).Str("user", userID).Str("device", deviceID).Int("msgs", len(msgs)).Msg("V2: failed to store to-device messages")
}
}
func (h *SyncLiveHandler) UpdateUnreadCounts(roomID, userID string, highlightCount, notifCount *int) {