Kegan Dougal 905f815794 bugfix: ensure sentry username/id values are correct
Previously they were wrong under high concurrency due to
using the global hub instead of a per-request hub.
2024-03-05 11:30:05 +00:00

package handler

import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"os"
"reflect"
"strconv"
"sync"
"syscall"
"time"
"github.com/getsentry/sentry-go"
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/sqlutil"
"github.com/matrix-org/sliding-sync/internal"
"github.com/matrix-org/sliding-sync/pubsub"
"github.com/matrix-org/sliding-sync/state"
"github.com/matrix-org/sliding-sync/sync2"
"github.com/matrix-org/sliding-sync/sync3"
"github.com/matrix-org/sliding-sync/sync3/caches"
"github.com/matrix-org/sliding-sync/sync3/extensions"
"github.com/prometheus/client_golang/prometheus"
"github.com/rs/zerolog"
"github.com/rs/zerolog/hlog"
"github.com/rs/zerolog/log"
)
const DefaultSessionID = "default"
var logger = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.ConsoleWriter{
Out: os.Stderr,
TimeFormat: "15:04:05",
})
// SyncLiveHandler is a net/http Handler for sync v3. It is responsible for pairing requests
// to Conns and for ensuring that the sync v2 poller is running for this client.
type SyncLiveHandler struct {
V2 sync2.Client
Storage *state.Storage
V2Store *sync2.Storage
V2Sub *pubsub.V2Sub
EnsurePoller *EnsurePoller
ConnMap *sync3.ConnMap
Extensions *extensions.Handler
// inserts are done by v2 poll loops, selects are done by v3 request threads
// but the v3 requests touch non-overlapping keys, which is a good use case for sync.Map
// > (2) when multiple goroutines read, write, and overwrite entries for disjoint sets of keys.
userCaches *sync.Map // map[user_id]*UserCache
Dispatcher *sync3.Dispatcher
GlobalCache *caches.GlobalCache
maxPendingEventUpdates int
maxTransactionIDDelay time.Duration
setupHistVec *prometheus.HistogramVec
histVec *prometheus.HistogramVec
slowReqs prometheus.Counter
// destroyedConns is the number of connections that have been destroyed after
// a room invalidation payload.
// TODO: could make this a CounterVec labelled by reason, to track expiry due
// to update buffer filling, expiry due to inactivity, etc.
destroyedConns prometheus.Counter
}
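// NewSync3Handler constructs a SyncLiveHandler, wiring up its extensions handler, the
// EnsurePoller and the v2 pubsub subscriber, and optionally registering Prometheus metrics.
// Callers should then call Startup to load state and Listen to start consuming v2 messages.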
func NewSync3Handler(
store *state.Storage, storev2 *sync2.Storage, v2Client sync2.Client, secret string,
pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool, maxPendingEventUpdates int,
maxTransactionIDDelay time.Duration,
) (*SyncLiveHandler, error) {
logger.Info().Msg("creating handler")
sh := &SyncLiveHandler{
V2: v2Client,
Storage: store,
V2Store: storev2,
ConnMap: sync3.NewConnMap(enablePrometheus, 30*time.Minute),
userCaches: &sync.Map{},
Dispatcher: sync3.NewDispatcher(),
GlobalCache: caches.NewGlobalCache(store),
maxPendingEventUpdates: maxPendingEventUpdates,
maxTransactionIDDelay: maxTransactionIDDelay,
}
sh.Extensions = &extensions.Handler{
Store: store,
E2EEFetcher: sh,
GlobalCache: sh.GlobalCache,
}
if enablePrometheus {
sh.addPrometheusMetrics()
pub = pubsub.NewPromNotifier(pub, "api")
}
// set up pubsub mechanism to start from this point
sh.EnsurePoller = NewEnsurePoller(pub, enablePrometheus)
sh.V2Sub = pubsub.NewV2Sub(sub, sh)
return sh, nil
}
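// Startup loads the startup snapshot into the in-memory caches: joined members go to the
// Dispatcher and global room metadata goes to the GlobalCache.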
func (h *SyncLiveHandler) Startup(storeSnapshot *state.StartupSnapshot) error {
if err := h.Dispatcher.Startup(storeSnapshot.AllJoinedMembers); err != nil {
return fmt.Errorf("failed to load sync3.Dispatcher: %s", err)
}
h.Dispatcher.Register(context.Background(), sync3.DispatcherAllUsers, h.GlobalCache)
if err := h.GlobalCache.Startup(storeSnapshot.GlobalMetadata); err != nil {
return fmt.Errorf("failed to populate global cache: %s", err)
}
return nil
}
// Listen starts all consumers
func (h *SyncLiveHandler) Listen() {
go func() {
defer internal.ReportPanicsToSentry()
err := h.V2Sub.Listen()
if err != nil {
logger.Err(err).Msg("Failed to listen for v2 messages")
sentry.CaptureException(err)
}
}()
}
// Teardown closes DB connections and pubsub consumers, and unregisters Prometheus metrics. Used in tests.
func (h *SyncLiveHandler) Teardown() {
// tear down DB conns
h.Storage.Teardown()
h.V2Sub.Teardown()
h.EnsurePoller.Teardown()
h.ConnMap.Teardown()
if h.setupHistVec != nil {
prometheus.Unregister(h.setupHistVec)
}
if h.histVec != nil {
prometheus.Unregister(h.histVec)
}
if h.slowReqs != nil {
prometheus.Unregister(h.slowReqs)
}
if h.destroyedConns != nil {
prometheus.Unregister(h.destroyedConns)
}
}
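// addPrometheusMetrics creates and registers this handler's Prometheus collectors.
// They are unregistered again in Teardown.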
func (h *SyncLiveHandler) addPrometheusMetrics() {
h.setupHistVec = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: "sliding_sync",
Subsystem: "api",
Name: "setup_duration_secs",
Help: "Time taken in seconds after receiving a request before we start calculating a sliding sync response.",
Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
}, []string{"initial"})
h.histVec = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: "sliding_sync",
Subsystem: "api",
Name: "process_duration_secs",
Help: "Time taken in seconds for the sliding sync response to be calculated, excludes long polling",
Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
}, []string{"initial"})
h.slowReqs = prometheus.NewCounter(prometheus.CounterOpts{
Namespace: "sliding_sync",
Subsystem: "api",
Name: "slow_requests",
Help: "Counter of slow (>=50s) requests, initial or otherwise.",
})
h.destroyedConns = prometheus.NewCounter(prometheus.CounterOpts{
Namespace: "sliding_sync",
Subsystem: "api",
Name: "destroyed_conns",
Help: "Counter of conns that were destroyed.",
})
prometheus.MustRegister(h.setupHistVec)
prometheus.MustRegister(h.histVec)
prometheus.MustRegister(h.slowReqs)
prometheus.MustRegister(h.destroyedConns)
}
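// ServeHTTP implements http.Handler. Only POST is accepted. Errors returned by serve are
// written back as JSON HandlerErrors; all errors except M_UNKNOWN_POS are delayed by a
// second to stop misbehaving clients from tight-looping.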
func (h *SyncLiveHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if req.Method != "POST" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
err := h.serve(w, req)
if err != nil {
herr, ok := err.(*internal.HandlerError)
if !ok {
herr = &internal.HandlerError{
StatusCode: 500,
Err: err,
}
}
if herr.ErrCode != "M_UNKNOWN_POS" {
// artificially wait a bit before sending back the error
// this guards against tightlooping when the client hammers the server with invalid requests,
// but not for M_UNKNOWN_POS which we expect to send back after expiring a client's connection.
// We want to recover rapidly in that scenario, hence not sleeping.
time.Sleep(time.Second)
}
w.WriteHeader(herr.StatusCode)
w.Write(herr.JSON())
}
}
// Entry point for sync v3
func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error {
start := time.Now()
defer func() {
dur := time.Since(start)
if dur > 50*time.Second {
if h.slowReqs != nil {
h.slowReqs.Add(1.0)
}
internal.DecorateLogger(req.Context(), log.Warn()).Dur("duration", dur).Msg("slow request")
}
}()
var requestBody sync3.Request
if req.ContentLength != 0 {
defer req.Body.Close()
if err := json.NewDecoder(req.Body).Decode(&requestBody); err != nil {
log.Warn().Err(err).Msg("failed to read/decode request body")
return &internal.HandlerError{
StatusCode: 400,
Err: err,
}
}
if err := requestBody.Validate(); err != nil {
return &internal.HandlerError{
StatusCode: 400,
Err: err,
}
}
}
if requestBody.ConnID != "" {
req = req.WithContext(internal.SetAttributeOnContext(req.Context(), internal.OTLPTagConnID, requestBody.ConnID))
}
if requestBody.TxnID != "" {
req = req.WithContext(internal.SetAttributeOnContext(req.Context(), internal.OTLPTagTxnID, requestBody.TxnID))
}
hlog.FromRequest(req).UpdateContext(func(c zerolog.Context) zerolog.Context {
c.Str("txn_id", requestBody.TxnID)
return c
})
for listKey, l := range requestBody.Lists {
if l.Ranges != nil && !l.Ranges.Valid() {
return &internal.HandlerError{
StatusCode: 400,
Err: fmt.Errorf("list[%v] invalid ranges %v", listKey, l.Ranges),
}
}
}
logErrorOrWarning := func(msg string, herr *internal.HandlerError) {
if herr.StatusCode >= 500 {
hlog.FromRequest(req).Err(herr).Msg(msg)
} else {
hlog.FromRequest(req).Warn().Err(herr).Msg(msg)
}
}
cancelCtx, cancel := context.WithCancel(req.Context())
req = req.WithContext(cancelCtx)
req, conn, herr := h.setupConnection(req, cancel, &requestBody, req.URL.Query().Get("pos") != "")
if herr != nil {
logErrorOrWarning("failed to get or create Conn", herr)
return herr
}
// set pos and timeout if specified
cpos, herr := parseIntFromQuery(req.URL, "pos")
if herr != nil {
return herr
}
requestBody.SetPos(cpos)
log := hlog.FromRequest(req).With().Str("user", conn.UserID).Int64("pos", cpos).Logger()
var timeout int
if req.URL.Query().Get("timeout") == "" {
timeout = sync3.DefaultTimeoutMSecs
} else {
timeout64, herr := parseIntFromQuery(req.URL, "timeout")
if herr != nil {
return herr
}
timeout = int(timeout64)
}
requestBody.SetTimeoutMSecs(timeout)
log.Trace().Int("timeout", timeout).Msg("recv")
resp, herr := conn.OnIncomingRequest(req.Context(), &requestBody, start)
if herr != nil {
logErrorOrWarning("failed to OnIncomingRequest", herr)
return herr
}
// for logging
var numToDeviceEvents int
if resp.Extensions.ToDevice != nil {
numToDeviceEvents = len(resp.Extensions.ToDevice.Events)
}
var numGlobalAccountData int
if resp.Extensions.AccountData != nil {
numGlobalAccountData = len(resp.Extensions.AccountData.Global)
}
var numChangedDevices, numLeftDevices int
if resp.Extensions.E2EE != nil && resp.Extensions.E2EE.DeviceLists != nil {
numChangedDevices = len(resp.Extensions.E2EE.DeviceLists.Changed)
numLeftDevices = len(resp.Extensions.E2EE.DeviceLists.Left)
}
internal.SetRequestContextResponseInfo(
req.Context(), cpos, resp.PosInt(), len(resp.Rooms), requestBody.TxnID, numToDeviceEvents, numGlobalAccountData,
numChangedDevices, numLeftDevices, requestBody.ConnID, len(requestBody.Lists), len(requestBody.RoomSubscriptions), len(requestBody.UnsubscribeRooms),
)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
if err := json.NewEncoder(w).Encode(resp); err != nil {
herr = &internal.HandlerError{
StatusCode: 500,
Err: err,
}
if errors.Is(err, syscall.EPIPE) {
// Client closed the connection. Use a 499 status code internally so that
// we consider this a warning rather than an error. 499 is nonstandard,
// but a) the client has already gone, so this status code will only show
// up in our logs; and b) nginx uses 499 to mean "Client Closed Request",
// see e.g.
// https://www.nginx.com/resources/wiki/extending/api/http/#http-return-codes
herr.StatusCode = 499
}
logErrorOrWarning("failed to JSON-encode result", herr)
return herr
}
return nil
}
// setupConnection associates this request with an existing connection or makes a new connection.
// It also sets a v2 sync poll loop going if one didn't exist already for this user.
// When this function returns, the connection is alive and active.
func (h *SyncLiveHandler) setupConnection(req *http.Request, cancel context.CancelFunc, syncReq *sync3.Request, containsPos bool) (*http.Request, *sync3.Conn, *internal.HandlerError) {
ctx, task := internal.StartTask(req.Context(), "setupConnection")
req = req.WithContext(ctx)
defer task.End()
var conn *sync3.Conn
// Extract an access token
accessToken, err := internal.ExtractAccessToken(req)
if err != nil || accessToken == "" {
hlog.FromRequest(req).Warn().Err(err).Msg("failed to get access token from request")
return req, nil, &internal.HandlerError{
StatusCode: http.StatusUnauthorized,
Err: err,
}
}
// Try to lookup a record of this token
var token *sync2.Token
token, err = h.V2Store.TokensTable.Token(accessToken)
if err != nil {
if err == sql.ErrNoRows {
hlog.FromRequest(req).Info().Msg("Received connection from unknown access token, querying with homeserver")
newToken, herr := h.identifyUnknownAccessToken(req.Context(), accessToken, hlog.FromRequest(req))
if herr != nil {
return req, nil, herr
}
token = newToken
} else {
hlog.FromRequest(req).Err(err).Msg("Failed to lookup access token")
return req, nil, &internal.HandlerError{
StatusCode: http.StatusInternalServerError,
Err: err,
}
}
}
req = req.WithContext(internal.SetAttributeOnContext(req.Context(), internal.OTLPTagUserID, token.UserID))
req = req.WithContext(internal.SetAttributeOnContext(req.Context(), internal.OTLPTagDeviceID, token.DeviceID))
log := hlog.FromRequest(req).With().
Str("user", token.UserID).
Str("device", token.DeviceID).
Str("conn", syncReq.ConnID).
Logger()
req = req.WithContext(internal.AssociateUserIDWithRequest(req.Context(), token.UserID, token.DeviceID))
internal.Logf(req.Context(), "setupConnection", "identified access token as user=%s device=%s", token.UserID, token.DeviceID)
// Record the fact that we've received a request from this token
err = h.V2Store.TokensTable.MaybeUpdateLastSeen(token, time.Now())
if err != nil {
// Not fatal; log and continue.
log.Warn().Err(err).Msg("Unable to update last seen timestamp")
}
connID := sync3.ConnID{
UserID: token.UserID,
DeviceID: token.DeviceID,
CID: syncReq.ConnID,
}
// client thinks they have a connection
if containsPos {
// Lookup the connection
conn = h.ConnMap.Conn(connID)
if conn != nil {
conn.SetCancelCallback(cancel)
log.Trace().Str("conn", conn.ConnID.String()).Msg("reusing conn")
return req, conn, nil
}
// conn doesn't exist, we probably nuked it.
return req, nil, internal.ExpiredSessionError()
}
pid := sync2.PollerID{UserID: token.UserID, DeviceID: token.DeviceID}
log.Trace().Any("pid", pid).Msg("checking poller exists and is running")
expiredToken := h.EnsurePoller.EnsurePolling(req.Context(), pid, token.AccessTokenHash)
if expiredToken {
log.Error().Msg("EnsurePolling failed, returning 401")
// Assumption: the only way that EnsurePolling fails is if the access token is invalid.
return req, nil, &internal.HandlerError{
StatusCode: http.StatusUnauthorized,
ErrCode: "M_UNKNOWN_TOKEN",
Err: fmt.Errorf("EnsurePolling failed: access token invalid or invalidated"),
}
}
log.Trace().Msg("poller exists and is running")
// this may take a while, so if the client has given up (e.g. timed out) by this point, just stop.
// We'll be quicker next time as the poller will already exist.
if req.Context().Err() != nil {
log.Warn().Msg("client gave up, not creating connection")
return req, nil, &internal.HandlerError{
StatusCode: 400,
Err: req.Context().Err(),
}
}
userCache, err := h.userCache(token.UserID)
if err != nil {
log.Warn().Err(err).Msg("failed to load user cache")
return req, nil, &internal.HandlerError{
StatusCode: 500,
Err: err,
}
}
// once we have the conn, make sure our metrics are correct
defer h.ConnMap.UpdateMetrics()
// Now the v2 side of things is running, we can make a v3 live sync conn
// NB: this isn't inherently racy (we did the check for an existing conn before EnsurePolling)
// because we *either* do the existing check *or* make a new conn. It's important for CreateConn
// to check for an existing connection though, as it's possible for the client to call /sync
// twice for a new connection.
conn = h.ConnMap.CreateConn(connID, cancel, func() sync3.ConnHandler {
return NewConnState(token.UserID, token.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.setupHistVec, h.histVec, h.maxPendingEventUpdates, h.maxTransactionIDDelay)
})
log.Info().Msg("created new connection")
return req, conn, nil
}
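// identifyUnknownAccessToken takes an access token we have no record of and asks the
// homeserver who owns it, then persists a new token row and device row for it.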
func (h *SyncLiveHandler) identifyUnknownAccessToken(ctx context.Context, accessToken string, logger *zerolog.Logger) (*sync2.Token, *internal.HandlerError) {
// We don't recognise the given accessToken. Ask the homeserver who owns it.
userID, deviceID, err := h.V2.WhoAmI(ctx, accessToken)
if err != nil {
if err == sync2.HTTP401 {
return nil, &internal.HandlerError{
StatusCode: 401,
Err: fmt.Errorf("/whoami returned HTTP 401"),
ErrCode: "M_UNKNOWN_TOKEN",
}
}
logger.Warn().Err(err).Msg("failed to get user ID from access token")
return nil, &internal.HandlerError{
StatusCode: http.StatusBadGateway,
Err: err,
}
}
var token *sync2.Token
err = sqlutil.WithTransaction(h.V2Store.DB, func(txn *sqlx.Tx) error {
// Create a brand-new row for this token.
token, err = h.V2Store.TokensTable.Insert(txn, accessToken, userID, deviceID, time.Now())
if err != nil {
logger.Warn().Err(err).Str("user", userID).Str("device", deviceID).Msg("failed to insert v2 token")
return err
}
// Ensure we have a device row for this token.
err = h.V2Store.DevicesTable.InsertDevice(txn, userID, deviceID)
if err != nil {
logger.Warn().Err(err).Str("user", userID).Str("device", deviceID).Msg("failed to insert v2 device")
return err
}
return nil
})
if err != nil {
return nil, &internal.HandlerError{StatusCode: 500, Err: err}
}
return token, nil
}
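// CacheForUser returns the UserCache for this user, or nil if none exists.
// Unlike userCache, it never creates a new cache.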
func (h *SyncLiveHandler) CacheForUser(userID string) *caches.UserCache {
c, ok := h.userCaches.Load(userID)
if ok {
return c.(*caches.UserCache)
}
return nil
}
// userCache fetches an existing caches.UserCache for this user if one exists. If not, it:
//   - creates a blank caches.UserCache struct,
//   - fires callbacks on that struct as necessary to populate it with initial state,
//   - stores the struct so it will not be recreated in the future, and
//   - registers the cache with the Dispatcher.
//
// Some extra initialisation takes place in caches.UserCache.OnRegister.
//
// TODO: the calls to uc.OnBlahBlah etc can be moved into NewUserCache, now that the
// UserCache holds a reference to the storage layer.
func (h *SyncLiveHandler) userCache(userID string) (*caches.UserCache, error) {
// bail if we already have a cache
c, ok := h.userCaches.Load(userID)
if ok {
return c.(*caches.UserCache), nil
}
uc := caches.NewUserCache(userID, h.GlobalCache, h.Storage, h, h.Dispatcher)
// select all non-zero highlight or notif counts and set them, as this is less costly than looping every room/user pair
err := h.Storage.UnreadTable.SelectAllNonZeroCountsForUser(userID, func(roomID string, highlightCount, notificationCount int) {
uc.OnUnreadCounts(context.Background(), roomID, &highlightCount, &notificationCount)
})
if err != nil {
return nil, fmt.Errorf("failed to load unread counts: %s", err)
}
// select the DM account data event and set DM room status
directEvent, err := h.Storage.AccountData(userID, sync2.AccountDataGlobalRoom, []string{"m.direct"})
if err != nil {
return nil, fmt.Errorf("failed to load direct message status for rooms: %s", err)
}
if len(directEvent) == 1 {
uc.OnAccountData(context.Background(), []state.AccountData{directEvent[0]})
}
// select the ignored users account data event and set ignored user list
ignoreEvent, err := h.Storage.AccountData(userID, sync2.AccountDataGlobalRoom, []string{"m.ignored_user_list"})
if err != nil {
return nil, fmt.Errorf("failed to load ignored user list for user %s: %w", userID, err)
}
if len(ignoreEvent) == 1 {
uc.OnAccountData(context.Background(), []state.AccountData{ignoreEvent[0]})
}
// select all room tag account data and set it
tagEvents, err := h.Storage.RoomAccountDatasWithType(userID, "m.tag")
if err != nil {
return nil, fmt.Errorf("failed to load room tags %s", err)
}
if len(tagEvents) > 0 {
uc.OnAccountData(context.Background(), tagEvents)
}
// select outstanding invites
invites, err := h.Storage.InvitesTable.SelectAllInvitesForUser(userID)
if err != nil {
return nil, fmt.Errorf("failed to load outstanding invites for user: %s", err)
}
for roomID, inviteState := range invites {
uc.OnInvite(context.Background(), roomID, inviteState)
}
// use LoadOrStore here else we can race as 2 brand new /sync conns can both get to this point
// at the same time
actualUC, loaded := h.userCaches.LoadOrStore(userID, uc)
uc = actualUC.(*caches.UserCache)
if !loaded { // we actually inserted the cache, so register with the dispatcher.
if err = h.Dispatcher.Register(context.Background(), userID, uc); err != nil {
h.Dispatcher.Unregister(userID)
h.userCaches.Delete(userID)
return nil, fmt.Errorf("failed to register user cache with dispatcher: %s", err)
}
}
return uc, nil
}
// Implements E2EEFetcher
// DeviceData returns the latest device data for this user. isInitial should be set if this is for
// an initial /sync request.
func (h *SyncLiveHandler) DeviceData(ctx context.Context, userID, deviceID string, isInitial bool) *internal.DeviceData {
// We have 2 sources of DeviceData:
// - pubsub updates stored in deviceDataMap
// - the database itself
// Most of the time we would like to pull from deviceDataMap and ignore the database entirely;
// however, in most cases we need to do a database hit to atomically swap device lists over. Why?
//
// changed|left are much more important and special because:
//
//   - sync v2 only sends deltas, unlike otk counts and fallback key types which are sent in full.
//   - we MUST guarantee that we send this to the client, as missing a user in `changed` can result in us having the wrong
//     device lists for that user, resulting in encryption breaking when the client encrypts for known devices.
//   - we MUST NOT continually send the same device list changes on each subsequent request, i.e. we need to delete them.
//
// We accumulate device list deltas on the v2 poller side, upserting them into the database and sending pubsub notifs.
// The accumulated deltas are stored in DeviceData.DeviceLists.New
// To guarantee we send this to the client, we need to consider a few failure modes:
// - The response is lost and the request is retried to this proxy -> ConnMap caches will get it.
// - The response is lost and the client doesn't retry until the connection expires. They then retry ->
// ConnMap cache miss, sends HTTP 400 due to invalid ?pos=
// - The response is received and the client sends the next request -> do not send deltas.
// To handle the case where responses are lost, we just need to see if this is an initial request
// and if so, return a "Read-Only" snapshot of the last sent device list changes. This means we may send
// duplicate device list changes if the response did in fact get to the client and the next request hit a
// new proxy, but that's better than losing updates. In this scenario, we do not delete any data.
// To ensure we delete device list updates over time, we now want to swap what was New to Sent and then
// send Sent. That means we forget what was originally in Sent and New is empty. We need to read and swap
// atomically else the v2 poller may insert a new update after the read but before the swap (DELETE on New)
// To ensure atomicity, we need to do this in a txn.
// Atomically move New to Sent so New is now empty and what was originally in Sent is forgotten.
shouldSwap := !isInitial
dd, err := h.Storage.DeviceDataTable.Select(userID, deviceID, shouldSwap)
if err != nil {
logger.Err(err).Str("user", userID).Str("device", deviceID).Msg("failed to select device data")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
return nil
}
return dd
}
// Implements TransactionIDFetcher
func (h *SyncLiveHandler) TransactionIDForEvents(userID string, deviceID string, eventIDs []string) (eventIDToTxnID map[string]string) {
eventIDToTxnID, err := h.Storage.TransactionsTable.Select(userID, deviceID, eventIDs)
if err != nil {
logger.Warn().Str("err", err.Error()).Str("device", deviceID).Msg("failed to select txn IDs for events")
}
return
}
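// OnInitialSyncComplete is called from the v2 poller, implements V2DataReceiver.
// It is forwarded to the EnsurePoller, unblocking requests waiting in EnsurePolling.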
func (h *SyncLiveHandler) OnInitialSyncComplete(p *pubsub.V2InitialSyncComplete) {
h.EnsurePoller.OnInitialSyncComplete(p)
}
// Called from the v2 poller, implements V2DataReceiver
func (h *SyncLiveHandler) Accumulate(p *pubsub.V2Accumulate) {
ctx, task := internal.StartTask(context.Background(), "Accumulate")
defer task.End()
// note: events is sorted in ascending NID order, even if p.EventNIDs isn't.
events, err := h.Storage.EventNIDs(p.EventNIDs)
if err != nil {
logger.Err(err).Str("room", p.RoomID).Msg("Accumulate: failed to EventNIDs")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
return
}
if len(events) == 0 {
return
}
internal.Logf(ctx, "room", fmt.Sprintf("%s: %d events", p.RoomID, len(events)))
// we have new events, notify active connections
for i := range events {
h.Dispatcher.OnNewEvent(ctx, p.RoomID, events[i], p.EventNIDs[i])
}
}
// OnTransactionID is called from the v2 poller, implements V2DataReceiver.
func (h *SyncLiveHandler) OnTransactionID(p *pubsub.V2TransactionID) {
_, task := internal.StartTask(context.Background(), "TransactionID")
defer task.End()
// There is some event E for which we now have a transaction ID, or else now know
// that we will never get a transaction ID. In either case, tell the sender's
// connections to unblock that event in the transaction ID waiter.
h.ConnMap.ClearUpdateQueues(p.UserID, p.RoomID, p.NID)
}
// Called from the v2 poller, implements V2DataReceiver
func (h *SyncLiveHandler) Initialise(p *pubsub.V2Initialise) {
ctx, task := internal.StartTask(context.Background(), "Initialise")
defer task.End()
state, err := h.Storage.StateSnapshot(p.SnapshotNID)
if err != nil {
logger.Err(err).Int64("snap", p.SnapshotNID).Str("room", p.RoomID).Msg("Initialise: failed to get StateSnapshot")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
return
}
// we have new state, notify caches
h.Dispatcher.OnNewInitialRoomState(ctx, p.RoomID, state)
}
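// OnUnreadCounts is called from the v2 poller, implements V2DataReceiver.
// It pushes new unread counts to the user's cache, if one is loaded.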
func (h *SyncLiveHandler) OnUnreadCounts(p *pubsub.V2UnreadCounts) {
ctx, task := internal.StartTask(context.Background(), "OnUnreadCounts")
defer task.End()
userCache, ok := h.userCaches.Load(p.UserID)
if !ok {
return
}
userCache.(*caches.UserCache).OnUnreadCounts(ctx, p.RoomID, p.HighlightCount, p.NotificationCount)
}
// push device data updates on waiting conns (otk counts, device list changes)
func (h *SyncLiveHandler) OnDeviceData(p *pubsub.V2DeviceData) {
ctx, task := internal.StartTask(context.Background(), "OnDeviceData")
defer task.End()
internal.Logf(ctx, "device_data", fmt.Sprintf("%v users to notify", len(p.UserIDToDeviceIDs)))
for userID, deviceIDs := range p.UserIDToDeviceIDs {
for _, deviceID := range deviceIDs {
conns := h.ConnMap.Conns(userID, deviceID)
for _, conn := range conns {
conn.OnUpdate(ctx, caches.DeviceDataUpdate{})
}
}
}
}
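// OnDeviceMessages is called from the v2 poller, implements V2DataReceiver.
// It wakes up any conns for this device so new to-device events can be sent.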
func (h *SyncLiveHandler) OnDeviceMessages(p *pubsub.V2DeviceMessages) {
ctx, task := internal.StartTask(context.Background(), "OnDeviceMessages")
defer task.End()
conns := h.ConnMap.Conns(p.UserID, p.DeviceID)
for _, conn := range conns {
conn.OnUpdate(ctx, caches.DeviceEventsUpdate{})
}
}
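// OnInvite is called from the v2 poller, implements V2DataReceiver.
// It loads the stored invite state and pushes it to the user's cache, if one is loaded.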
func (h *SyncLiveHandler) OnInvite(p *pubsub.V2InviteRoom) {
ctx, task := internal.StartTask(context.Background(), "OnInvite")
defer task.End()
userCache, ok := h.userCaches.Load(p.UserID)
if !ok {
return
}
inviteState, err := h.Storage.InvitesTable.SelectInviteState(p.UserID, p.RoomID)
if err != nil {
logger.Err(err).Str("user", p.UserID).Str("room", p.RoomID).Msg("failed to get invite state")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
return
}
userCache.(*caches.UserCache).OnInvite(ctx, p.RoomID, inviteState)
}
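// OnLeftRoom is called from the v2 poller, implements V2DataReceiver.
// It pushes the leave event to the user's cache, if one is loaded.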
func (h *SyncLiveHandler) OnLeftRoom(p *pubsub.V2LeaveRoom) {
ctx, task := internal.StartTask(context.Background(), "OnLeftRoom")
defer task.End()
userCache, ok := h.userCaches.Load(p.UserID)
if !ok {
return
}
userCache.(*caches.UserCache).OnLeftRoom(ctx, p.RoomID, p.LeaveEvent)
}
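// OnReceipt is called from the v2 poller, implements V2DataReceiver.
// Private receipts go directly to the sender's user cache; public receipts go via the Dispatcher.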
func (h *SyncLiveHandler) OnReceipt(p *pubsub.V2Receipt) {
ctx, task := internal.StartTask(context.Background(), "OnReceipt")
defer task.End()
// split receipts into public / private
userToPrivateReceipts := make(map[string][]internal.Receipt)
publicReceipts := make([]internal.Receipt, 0, len(p.Receipts))
for _, r := range p.Receipts {
if r.IsPrivate {
userToPrivateReceipts[r.UserID] = append(userToPrivateReceipts[r.UserID], r)
} else {
publicReceipts = append(publicReceipts, r)
}
}
// always send private receipts, directly to the connected user cache if one exists
for userID, privateReceipts := range userToPrivateReceipts {
userCache, ok := h.userCaches.Load(userID)
if !ok {
continue
}
for _, pr := range privateReceipts {
userCache.(*caches.UserCache).OnReceipt(ctx, pr)
}
}
if len(publicReceipts) == 0 {
return
}
// inform the dispatcher of global receipts
for _, pr := range publicReceipts {
h.Dispatcher.OnReceipt(ctx, pr)
}
}
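// OnTyping is called from the v2 poller, implements V2DataReceiver.
// Duplicate typing events (which happen when multiple polled users share a room) are dropped.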
func (h *SyncLiveHandler) OnTyping(p *pubsub.V2Typing) {
ctx, task := internal.StartTask(context.Background(), "OnTyping")
defer task.End()
rooms := h.GlobalCache.LoadRooms(ctx, p.RoomID)
if rooms[p.RoomID] != nil {
if reflect.DeepEqual(p.EphemeralEvent, rooms[p.RoomID].TypingEvent) {
return // it's a duplicate, which happens when 2+ users are in the same room
}
}
h.Dispatcher.OnEphemeralEvent(ctx, p.RoomID, p.EphemeralEvent)
}
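// OnAccountData is called from the v2 poller, implements V2DataReceiver.
// It loads the new account data from storage and pushes it to the user's cache, if one is loaded.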
func (h *SyncLiveHandler) OnAccountData(p *pubsub.V2AccountData) {
ctx, task := internal.StartTask(context.Background(), "OnAccountData")
defer task.End()
userCache, ok := h.userCaches.Load(p.UserID)
if !ok {
return
}
data, err := h.Storage.AccountData(p.UserID, p.RoomID, p.Types)
if err != nil {
logger.Err(err).Str("user", p.UserID).Str("room", p.RoomID).Msg("OnAccountData: failed to lookup")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
return
}
userCache.(*caches.UserCache).OnAccountData(ctx, data)
}
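// OnExpiredToken is called from the v2 poller when an access token is no longer valid,
// implements V2DataReceiver. It informs the EnsurePoller and closes any conns for this device.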
func (h *SyncLiveHandler) OnExpiredToken(p *pubsub.V2ExpiredToken) {
h.EnsurePoller.OnExpiredToken(p)
h.ConnMap.CloseConnsForDevice(p.UserID, p.DeviceID)
}
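// OnStateRedaction is called from the v2 poller when current state events are redacted,
// implements V2DataReceiver.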
func (h *SyncLiveHandler) OnStateRedaction(p *pubsub.V2StateRedaction) {
// We only need to reload the global metadata here: mercifully, there isn't anything
// in the user cache that needs to be reloaded after state gets redacted.
ctx, task := internal.StartTask(context.Background(), "OnStateRedaction")
defer task.End()
h.GlobalCache.OnInvalidateRoom(ctx, p.RoomID)
}
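// OnInvalidateRoom is called from the v2 poller, implements V2DataReceiver.
// It reloads the room's global metadata and then destroys the caches and conns of
// every user involved in the room, so they are rebuilt from the database.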
func (h *SyncLiveHandler) OnInvalidateRoom(p *pubsub.V2InvalidateRoom) {
ctx, task := internal.StartTask(context.Background(), "OnInvalidateRoom")
defer task.End()
// 1. Reload the global cache.
h.GlobalCache.OnInvalidateRoom(ctx, p.RoomID)
// Work out who is affected.
joins, invites, leaves, err := h.Storage.FetchMemberships(p.RoomID)
involvedUsers := make([]string, 0, len(joins)+len(invites)+len(leaves))
involvedUsers = append(involvedUsers, joins...)
involvedUsers = append(involvedUsers, invites...)
involvedUsers = append(involvedUsers, leaves...)
if err != nil {
hub := internal.GetSentryHubFromContextOrDefault(ctx)
hub.WithScope(func(scope *sentry.Scope) {
scope.SetContext(internal.SentryCtxKey, map[string]any{
"room_id": p.RoomID,
})
hub.CaptureException(err)
})
logger.Err(err).
Str("room_id", p.RoomID).
Msg("Failed to fetch members after cache invalidation")
return
}
// 2. Reload the joined-room tracker.
h.Dispatcher.OnInvalidateRoom(p.RoomID, joins, invites)
// 3. Destroy involved users' caches.
// We filter to only those users which had a userCache registered to receive updates.
unregistered := h.Dispatcher.UnregisterBulk(involvedUsers)
for _, userID := range unregistered {
h.userCaches.Delete(userID)
}
// 4. Destroy involved users' connections.
// Since creating a conn creates a user cache, it is safe to only close conns for the
// users whose caches we just deleted.
destroyed := h.ConnMap.CloseConnsForUsers(unregistered)
if h.destroyedConns != nil {
h.destroyedConns.Add(float64(destroyed))
}
// invalidations are rare and dangerous if we get them wrong, so log information about them.
logger.Info().
Str("room_id", p.RoomID).Int("joins", len(joins)).Int("invites", len(invites)).Int("leaves", len(leaves)).
Int("del_user_caches", len(unregistered)).Int("conns_destroyed", destroyed).Msg("OnInvalidateRoom")
}
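// parseIntFromQuery parses an optional integer query parameter such as ?pos= or ?timeout=.
// It returns 0 if the parameter is absent, or a 400 HandlerError if the value cannot be parsed.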
func parseIntFromQuery(u *url.URL, param string) (result int64, err *internal.HandlerError) {
queryPos := u.Query().Get(param)
if queryPos != "" {
var err error
result, err = strconv.ParseInt(queryPos, 10, 64)
if err != nil {
return 0, &internal.HandlerError{
StatusCode: 400,
Err: fmt.Errorf("invalid %s: %s", param, queryPos),
}
}
}
return
}