mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
Automatically start v2 pollers on startup
We can do this now because we store the access token for each device. Throttled at 16 concurrent sync requests to avoid causing thundering herds on startup.
This commit is contained in:
parent
ed9e9ed48c
commit
47b74a6be6
@ -64,5 +64,6 @@ func main() {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
h.StartV2Pollers()
|
||||
syncv3.RunSyncV3Server(h, flagBindAddr, flagDestinationServer)
|
||||
}
|
||||
|
@ -13,8 +13,8 @@ import (
|
||||
const AccountDataGlobalRoom = ""
|
||||
|
||||
type Client interface {
|
||||
WhoAmI(authHeader string) (string, error)
|
||||
DoSyncV2(authHeader, since string, isFirst bool) (*SyncResponse, int, error)
|
||||
WhoAmI(accessToken string) (string, error)
|
||||
DoSyncV2(accessToken, since string, isFirst bool) (*SyncResponse, int, error)
|
||||
}
|
||||
|
||||
// HTTPClient represents a Sync v2 Client.
|
||||
@ -24,13 +24,13 @@ type HTTPClient struct {
|
||||
DestinationServer string
|
||||
}
|
||||
|
||||
func (v *HTTPClient) WhoAmI(authHeader string) (string, error) {
|
||||
func (v *HTTPClient) WhoAmI(accessToken string) (string, error) {
|
||||
req, err := http.NewRequest("GET", v.DestinationServer+"/_matrix/client/r0/account/whoami", nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("User-Agent", "sync-v3-proxy")
|
||||
req.Header.Set("Authorization", authHeader)
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
res, err := v.Client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@ -48,7 +48,7 @@ func (v *HTTPClient) WhoAmI(authHeader string) (string, error) {
|
||||
|
||||
// DoSyncV2 performs a sync v2 request. Returns the sync response and the response status code
|
||||
// or an error. Set isFirst=true on the first sync to force a timeout=0 sync to ensure snappiness.
|
||||
func (v *HTTPClient) DoSyncV2(authHeader, since string, isFirst bool) (*SyncResponse, int, error) {
|
||||
func (v *HTTPClient) DoSyncV2(accessToken, since string, isFirst bool) (*SyncResponse, int, error) {
|
||||
qps := "?"
|
||||
if isFirst {
|
||||
qps += "timeout=0"
|
||||
@ -62,7 +62,7 @@ func (v *HTTPClient) DoSyncV2(authHeader, since string, isFirst bool) (*SyncResp
|
||||
"GET", v.DestinationServer+"/_matrix/client/r0/sync"+qps, nil,
|
||||
)
|
||||
req.Header.Set("User-Agent", "sync-v3-proxy")
|
||||
req.Header.Set("Authorization", authHeader)
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("DoSyncV2: NewRequest failed: %w", err)
|
||||
}
|
||||
|
@ -109,7 +109,7 @@ func (h *PollerMap) LatestE2EEData(deviceID string) (otkCounts map[string]int, c
|
||||
// Note that we will immediately return if there is a poller for the same user but a different device.
|
||||
// We do this to allow for logins on clients to be snappy fast, even though they won't yet have the
|
||||
// to-device msgs to decrypt E2EE rooms.
|
||||
func (h *PollerMap) EnsurePolling(authHeader, userID, deviceID, v2since string, logger zerolog.Logger) {
|
||||
func (h *PollerMap) EnsurePolling(accessToken, userID, deviceID, v2since string, logger zerolog.Logger) {
|
||||
h.pollerMu.Lock()
|
||||
if !h.executorRunning {
|
||||
h.executorRunning = true
|
||||
@ -125,7 +125,7 @@ func (h *PollerMap) EnsurePolling(authHeader, userID, deviceID, v2since string,
|
||||
return
|
||||
}
|
||||
// replace the poller
|
||||
poller = NewPoller(userID, authHeader, deviceID, h.v2Client, h, h.txnCache, logger)
|
||||
poller = NewPoller(userID, accessToken, deviceID, h.v2Client, h, h.txnCache, logger)
|
||||
go poller.Poll(v2since)
|
||||
h.Pollers[deviceID] = poller
|
||||
|
||||
@ -233,12 +233,12 @@ func (h *PollerMap) OnAccountData(userID, roomID string, events []json.RawMessag
|
||||
|
||||
// Poller can automatically poll the sync v2 endpoint and accumulate the responses in storage
|
||||
type Poller struct {
|
||||
userID string
|
||||
authorizationHeader string
|
||||
deviceID string
|
||||
client Client
|
||||
receiver V2DataReceiver
|
||||
logger zerolog.Logger
|
||||
userID string
|
||||
accessToken string
|
||||
deviceID string
|
||||
client Client
|
||||
receiver V2DataReceiver
|
||||
logger zerolog.Logger
|
||||
|
||||
// remember txn ids
|
||||
txnCache *TransactionIDCache
|
||||
@ -253,21 +253,21 @@ type Poller struct {
|
||||
wg *sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewPoller(userID, authHeader, deviceID string, client Client, receiver V2DataReceiver, txnCache *TransactionIDCache, logger zerolog.Logger) *Poller {
|
||||
func NewPoller(userID, accessToken, deviceID string, client Client, receiver V2DataReceiver, txnCache *TransactionIDCache, logger zerolog.Logger) *Poller {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
return &Poller{
|
||||
authorizationHeader: authHeader,
|
||||
userID: userID,
|
||||
deviceID: deviceID,
|
||||
client: client,
|
||||
receiver: receiver,
|
||||
Terminated: false,
|
||||
logger: logger,
|
||||
e2eeMu: &sync.Mutex{},
|
||||
deviceListChanges: make(map[string]string),
|
||||
wg: &wg,
|
||||
txnCache: txnCache,
|
||||
accessToken: accessToken,
|
||||
userID: userID,
|
||||
deviceID: deviceID,
|
||||
client: client,
|
||||
receiver: receiver,
|
||||
Terminated: false,
|
||||
logger: logger,
|
||||
e2eeMu: &sync.Mutex{},
|
||||
deviceListChanges: make(map[string]string),
|
||||
wg: &wg,
|
||||
txnCache: txnCache,
|
||||
}
|
||||
}
|
||||
|
||||
@ -292,7 +292,7 @@ func (p *Poller) Poll(since string) {
|
||||
p.logger.Warn().Str("duration", waitTime.String()).Int("fail-count", failCount).Msg("Poller: waiting before next poll")
|
||||
timeSleep(waitTime)
|
||||
}
|
||||
resp, statusCode, err := p.client.DoSyncV2(p.authorizationHeader, since, firstTime)
|
||||
resp, statusCode, err := p.client.DoSyncV2(p.accessToken, since, firstTime)
|
||||
if err != nil {
|
||||
// check if temporary
|
||||
if statusCode != 401 {
|
||||
|
@ -116,6 +116,20 @@ func (s *Storage) Device(deviceID string) (*Device, error) {
|
||||
return &d, err
|
||||
}
|
||||
|
||||
func (s *Storage) AllDevices() (devices []Device, err error) {
|
||||
err = s.db.Select(&devices, `SELECT device_id, user_id, since, v2_token_encrypted FROM syncv3_sync2_devices`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for i := range devices {
|
||||
devices[i].AccessToken, err = s.decrypt(devices[i].AccessTokenEncrypted)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Storage) InsertDevice(deviceID, accessToken string) (*Device, error) {
|
||||
var device Device
|
||||
device.AccessToken = accessToken
|
||||
|
@ -2,6 +2,7 @@ package sync2
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/sync-v3/testutils"
|
||||
@ -16,7 +17,7 @@ func TestMain(m *testing.M) {
|
||||
}
|
||||
|
||||
func TestStorage(t *testing.T) {
|
||||
deviceID := "TEST_DEVICE_ID"
|
||||
deviceID := "ALICE"
|
||||
accessToken := "my_access_token"
|
||||
store := NewStore(postgresConnectionString, "my_secret")
|
||||
device, err := store.InsertDevice(deviceID, accessToken)
|
||||
@ -51,6 +52,33 @@ func TestStorage(t *testing.T) {
|
||||
assertEqual(t, s2.UserID, "@alice:localhost", "Device.UserID mismatch")
|
||||
assertEqual(t, s2.DeviceID, deviceID, "Device.DeviceID mismatch")
|
||||
assertEqual(t, s2.AccessToken, accessToken, "Device.AccessToken mismatch")
|
||||
|
||||
// check all devices works
|
||||
deviceID2 := "BOB"
|
||||
accessToken2 := "BOB_ACCESS_TOKEN"
|
||||
bobDevice, err := store.InsertDevice(deviceID2, accessToken2)
|
||||
if err != nil {
|
||||
t.Fatalf("InsertDevice returned error: %s", err)
|
||||
}
|
||||
devices, err := store.AllDevices()
|
||||
if err != nil {
|
||||
t.Fatalf("AllDevices: %s", err)
|
||||
}
|
||||
sort.Slice(devices, func(i, j int) bool {
|
||||
return devices[i].DeviceID < devices[j].DeviceID
|
||||
})
|
||||
wantDevices := []*Device{
|
||||
device, bobDevice,
|
||||
}
|
||||
if len(devices) != len(wantDevices) {
|
||||
t.Fatalf("AllDevices: got %d devices, want %d", len(devices), len(wantDevices))
|
||||
}
|
||||
for i := range devices {
|
||||
assertEqual(t, devices[i].Since, wantDevices[i].Since, "Device.Since mismatch")
|
||||
assertEqual(t, devices[i].UserID, wantDevices[i].UserID, "Device.UserID mismatch")
|
||||
assertEqual(t, devices[i].DeviceID, wantDevices[i].DeviceID, "Device.DeviceID mismatch")
|
||||
assertEqual(t, devices[i].AccessToken, wantDevices[i].AccessToken, "Device.AccessToken mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func assertEqual(t *testing.T, got, want, msg string) {
|
||||
|
@ -94,6 +94,39 @@ func (h *SyncLiveHandler) Teardown() {
|
||||
h.Storage.Teardown()
|
||||
}
|
||||
|
||||
func (h *SyncLiveHandler) StartV2Pollers() {
|
||||
devices, err := h.V2Store.AllDevices()
|
||||
if err != nil {
|
||||
logger.Err(err).Msg("StartV2Pollers: failed to query devices")
|
||||
return
|
||||
}
|
||||
logger.Info().Int("num_devices", len(devices)).Msg("StartV2Pollers")
|
||||
// how many concurrent pollers to make at startup.
|
||||
// Too high and this will flood the upstream server with sync requests at startup.
|
||||
// Too low and this will take ages for the v2 pollers to startup.
|
||||
numWorkers := 16
|
||||
ch := make(chan sync2.Device, len(devices))
|
||||
for _, d := range devices {
|
||||
ch <- d
|
||||
}
|
||||
close(ch)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numWorkers)
|
||||
for i := 0; i < numWorkers; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for d := range ch {
|
||||
h.PollerMap.EnsurePolling(
|
||||
d.AccessToken, d.UserID, d.DeviceID, d.Since,
|
||||
logger.With().Str("user_id", d.UserID).Logger(),
|
||||
)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
logger.Info().Msg("StartV2Pollers finished")
|
||||
}
|
||||
|
||||
func (h *SyncLiveHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != "POST" {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
@ -222,7 +255,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
|
||||
}
|
||||
}
|
||||
if v2device.UserID == "" {
|
||||
v2device.UserID, err = h.V2.WhoAmI(req.Header.Get("Authorization"))
|
||||
v2device.UserID, err = h.V2.WhoAmI(accessToken)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("device_id", deviceID).Msg("failed to get user ID from device ID")
|
||||
return nil, &internal.HandlerError{
|
||||
@ -238,7 +271,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
|
||||
|
||||
log.Trace().Str("user", v2device.UserID).Msg("checking poller exists and is running")
|
||||
h.PollerMap.EnsurePolling(
|
||||
req.Header.Get("Authorization"), v2device.UserID, v2device.DeviceID, v2device.Since,
|
||||
accessToken, v2device.UserID, v2device.DeviceID, v2device.Since,
|
||||
hlog.FromRequest(req).With().Str("user_id", v2device.UserID).Logger(),
|
||||
)
|
||||
log.Trace().Str("user", v2device.UserID).Msg("poller exists and is running")
|
||||
|
3
v3.go
3
v3.go
@ -68,6 +68,9 @@ func RunSyncV3Server(h http.Handler, bindAddr, destV2Server string) {
|
||||
chain: []func(next http.Handler) http.Handler{
|
||||
hlog.NewHandler(logger),
|
||||
hlog.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
|
||||
if r.Method == "OPTIONS" {
|
||||
return
|
||||
}
|
||||
hlog.FromRequest(r).Info().
|
||||
Str("method", r.Method).
|
||||
Int("status", status).
|
||||
|
Loading…
x
Reference in New Issue
Block a user