Automatically start v2 pollers on startup

We can do this now because we store the access token for each device.

Throttled at 16 concurrent sync requests to avoid causing
thundering herds on startup.
commit 47b74a6be6 (parent ed9e9ed48c)
Kegan Dougal, 2022-07-14 10:48:45 +01:00
7 changed files with 109 additions and 30 deletions
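The throttle works by enqueueing every stored device into a buffered channel and draining it with a fixed pool of 16 goroutines (see the StartV2Pollers hunk below). Here is a minimal standalone sketch of the same pattern; the device struct and startPoller function are hypothetical stand-ins for the proxy's Device type and PollerMap.EnsurePolling:

package main

import (
	"fmt"
	"sync"
)

// device is a hypothetical stand-in for the proxy's stored device record.
type device struct {
	UserID, DeviceID, AccessToken, Since string
}

// startPoller is a hypothetical stand-in for PollerMap.EnsurePolling.
func startPoller(d device) {
	fmt.Println("started poller for", d.DeviceID)
}

func main() {
	devices := []device{{DeviceID: "ALICE"}, {DeviceID: "BOB"}}
	const numWorkers = 16 // cap on concurrent initial /sync requests
	// Buffer the channel so every send completes before workers start draining.
	ch := make(chan device, len(devices))
	for _, d := range devices {
		ch <- d
	}
	close(ch) // workers exit once the channel is drained
	var wg sync.WaitGroup
	wg.Add(numWorkers)
	for i := 0; i < numWorkers; i++ {
		go func() {
			defer wg.Done()
			for d := range ch {
				startPoller(d) // at most numWorkers of these run concurrently
			}
		}()
	}
	wg.Wait() // block until every device has a poller
}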


@@ -64,5 +64,6 @@ func main() {
 	if err != nil {
 		panic(err)
 	}
+	h.StartV2Pollers()
 	syncv3.RunSyncV3Server(h, flagBindAddr, flagDestinationServer)
 }


@@ -13,8 +13,8 @@ import (
 const AccountDataGlobalRoom = ""
 type Client interface {
-	WhoAmI(authHeader string) (string, error)
-	DoSyncV2(authHeader, since string, isFirst bool) (*SyncResponse, int, error)
+	WhoAmI(accessToken string) (string, error)
+	DoSyncV2(accessToken, since string, isFirst bool) (*SyncResponse, int, error)
 }
 // HTTPClient represents a Sync v2 Client.
@@ -24,13 +24,13 @@ type HTTPClient struct {
 	DestinationServer string
 }
-func (v *HTTPClient) WhoAmI(authHeader string) (string, error) {
+func (v *HTTPClient) WhoAmI(accessToken string) (string, error) {
 	req, err := http.NewRequest("GET", v.DestinationServer+"/_matrix/client/r0/account/whoami", nil)
 	if err != nil {
 		return "", err
 	}
 	req.Header.Set("User-Agent", "sync-v3-proxy")
-	req.Header.Set("Authorization", authHeader)
+	req.Header.Set("Authorization", "Bearer "+accessToken)
 	res, err := v.Client.Do(req)
 	if err != nil {
 		return "", err
@@ -48,7 +48,7 @@ func (v *HTTPClient) WhoAmI(authHeader string) (string, error) {
 // DoSyncV2 performs a sync v2 request. Returns the sync response and the response status code
 // or an error. Set isFirst=true on the first sync to force a timeout=0 sync to ensure snappiness.
-func (v *HTTPClient) DoSyncV2(authHeader, since string, isFirst bool) (*SyncResponse, int, error) {
+func (v *HTTPClient) DoSyncV2(accessToken, since string, isFirst bool) (*SyncResponse, int, error) {
 	qps := "?"
 	if isFirst {
 		qps += "timeout=0"
@@ -62,7 +62,7 @@ func (v *HTTPClient) DoSyncV2(authHeader, since string, isFirst bool) (*SyncResponse, int, error) {
 		"GET", v.DestinationServer+"/_matrix/client/r0/sync"+qps, nil,
 	)
 	req.Header.Set("User-Agent", "sync-v3-proxy")
-	req.Header.Set("Authorization", authHeader)
+	req.Header.Set("Authorization", "Bearer "+accessToken)
 	if err != nil {
 		return nil, 0, fmt.Errorf("DoSyncV2: NewRequest failed: %w", err)
 	}
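After this change the rest of the codebase passes raw access tokens around, and HTTPClient is the only place that knows about the "Bearer " header scheme. A hedged usage sketch, assuming the Client field is a plain *http.Client and using placeholder values for the homeserver URL and token:

client := &HTTPClient{
	Client:            &http.Client{},                // assumes Client is a *http.Client
	DestinationServer: "https://synapse.example.com", // placeholder homeserver URL
}
// WhoAmI adds the "Bearer " prefix itself, so callers pass the bare token.
userID, err := client.WhoAmI("placeholder_access_token")
if err != nil {
	// a 401 here typically means the token was invalidated
	return err
}
fmt.Println("token belongs to", userID)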


@@ -109,7 +109,7 @@ func (h *PollerMap) LatestE2EEData(deviceID string) (otkCounts map[string]int, c
 // Note that we will immediately return if there is a poller for the same user but a different device.
 // We do this to allow for logins on clients to be snappy fast, even though they won't yet have the
 // to-device msgs to decrypt E2EE rooms.
-func (h *PollerMap) EnsurePolling(authHeader, userID, deviceID, v2since string, logger zerolog.Logger) {
+func (h *PollerMap) EnsurePolling(accessToken, userID, deviceID, v2since string, logger zerolog.Logger) {
 	h.pollerMu.Lock()
 	if !h.executorRunning {
 		h.executorRunning = true
@@ -125,7 +125,7 @@ func (h *PollerMap) EnsurePolling(authHeader, userID, deviceID, v2since string,
 		return
 	}
 	// replace the poller
-	poller = NewPoller(userID, authHeader, deviceID, h.v2Client, h, h.txnCache, logger)
+	poller = NewPoller(userID, accessToken, deviceID, h.v2Client, h, h.txnCache, logger)
 	go poller.Poll(v2since)
 	h.Pollers[deviceID] = poller
@@ -233,12 +233,12 @@ func (h *PollerMap) OnAccountData(userID, roomID string, events []json.RawMessage
 // Poller can automatically poll the sync v2 endpoint and accumulate the responses in storage
 type Poller struct {
-	userID              string
-	authorizationHeader string
-	deviceID            string
-	client              Client
-	receiver            V2DataReceiver
-	logger              zerolog.Logger
+	userID      string
+	accessToken string
+	deviceID    string
+	client      Client
+	receiver    V2DataReceiver
+	logger      zerolog.Logger
 	// remember txn ids
 	txnCache *TransactionIDCache
@@ -253,21 +253,21 @@ type Poller struct {
 	wg *sync.WaitGroup
 }
-func NewPoller(userID, authHeader, deviceID string, client Client, receiver V2DataReceiver, txnCache *TransactionIDCache, logger zerolog.Logger) *Poller {
+func NewPoller(userID, accessToken, deviceID string, client Client, receiver V2DataReceiver, txnCache *TransactionIDCache, logger zerolog.Logger) *Poller {
 	var wg sync.WaitGroup
 	wg.Add(1)
 	return &Poller{
-		authorizationHeader: authHeader,
-		userID:              userID,
-		deviceID:            deviceID,
-		client:              client,
-		receiver:            receiver,
-		Terminated:          false,
-		logger:              logger,
-		e2eeMu:              &sync.Mutex{},
-		deviceListChanges:   make(map[string]string),
-		wg:                  &wg,
-		txnCache:            txnCache,
+		accessToken:       accessToken,
+		userID:            userID,
+		deviceID:          deviceID,
+		client:            client,
+		receiver:          receiver,
+		Terminated:        false,
+		logger:            logger,
+		e2eeMu:            &sync.Mutex{},
+		deviceListChanges: make(map[string]string),
+		wg:                &wg,
+		txnCache:          txnCache,
 	}
 }
@@ -292,7 +292,7 @@ func (p *Poller) Poll(since string) {
 			p.logger.Warn().Str("duration", waitTime.String()).Int("fail-count", failCount).Msg("Poller: waiting before next poll")
 			timeSleep(waitTime)
 		}
-		resp, statusCode, err := p.client.DoSyncV2(p.authorizationHeader, since, firstTime)
+		resp, statusCode, err := p.client.DoSyncV2(p.accessToken, since, firstTime)
 		if err != nil {
 			// check if temporary
 			if statusCode != 401 {


@@ -116,6 +116,20 @@ func (s *Storage) Device(deviceID string) (*Device, error) {
 	return &d, err
 }
+func (s *Storage) AllDevices() (devices []Device, err error) {
+	err = s.db.Select(&devices, `SELECT device_id, user_id, since, v2_token_encrypted FROM syncv3_sync2_devices`)
+	if err != nil {
+		return
+	}
+	for i := range devices {
+		devices[i].AccessToken, err = s.decrypt(devices[i].AccessTokenEncrypted)
+		if err != nil {
+			return
+		}
+	}
+	return
+}
 func (s *Storage) InsertDevice(deviceID, accessToken string) (*Device, error) {
 	var device Device
 	device.AccessToken = accessToken


@@ -2,6 +2,7 @@ package sync2
 import (
 	"os"
+	"sort"
 	"testing"
 	"github.com/matrix-org/sync-v3/testutils"
@@ -16,7 +17,7 @@ func TestMain(m *testing.M) {
 }
 func TestStorage(t *testing.T) {
-	deviceID := "TEST_DEVICE_ID"
+	deviceID := "ALICE"
 	accessToken := "my_access_token"
 	store := NewStore(postgresConnectionString, "my_secret")
 	device, err := store.InsertDevice(deviceID, accessToken)
@@ -51,6 +52,33 @@ func TestStorage(t *testing.T) {
 	assertEqual(t, s2.UserID, "@alice:localhost", "Device.UserID mismatch")
 	assertEqual(t, s2.DeviceID, deviceID, "Device.DeviceID mismatch")
 	assertEqual(t, s2.AccessToken, accessToken, "Device.AccessToken mismatch")
+	// check all devices works
+	deviceID2 := "BOB"
+	accessToken2 := "BOB_ACCESS_TOKEN"
+	bobDevice, err := store.InsertDevice(deviceID2, accessToken2)
+	if err != nil {
+		t.Fatalf("InsertDevice returned error: %s", err)
+	}
+	devices, err := store.AllDevices()
+	if err != nil {
+		t.Fatalf("AllDevices: %s", err)
+	}
+	sort.Slice(devices, func(i, j int) bool {
+		return devices[i].DeviceID < devices[j].DeviceID
+	})
+	wantDevices := []*Device{
+		device, bobDevice,
+	}
+	if len(devices) != len(wantDevices) {
+		t.Fatalf("AllDevices: got %d devices, want %d", len(devices), len(wantDevices))
+	}
+	for i := range devices {
+		assertEqual(t, devices[i].Since, wantDevices[i].Since, "Device.Since mismatch")
+		assertEqual(t, devices[i].UserID, wantDevices[i].UserID, "Device.UserID mismatch")
+		assertEqual(t, devices[i].DeviceID, wantDevices[i].DeviceID, "Device.DeviceID mismatch")
+		assertEqual(t, devices[i].AccessToken, wantDevices[i].AccessToken, "Device.AccessToken mismatch")
+	}
 }
 func assertEqual(t *testing.T, got, want, msg string) {


@@ -94,6 +94,39 @@ func (h *SyncLiveHandler) Teardown() {
 	h.Storage.Teardown()
 }
+func (h *SyncLiveHandler) StartV2Pollers() {
+	devices, err := h.V2Store.AllDevices()
+	if err != nil {
+		logger.Err(err).Msg("StartV2Pollers: failed to query devices")
+		return
+	}
+	logger.Info().Int("num_devices", len(devices)).Msg("StartV2Pollers")
+	// how many concurrent pollers to make at startup.
+	// Too high and this will flood the upstream server with sync requests at startup.
+	// Too low and this will take ages for the v2 pollers to start up.
+	numWorkers := 16
+	ch := make(chan sync2.Device, len(devices))
+	for _, d := range devices {
+		ch <- d
+	}
+	close(ch)
+	var wg sync.WaitGroup
+	wg.Add(numWorkers)
+	for i := 0; i < numWorkers; i++ {
+		go func() {
+			defer wg.Done()
+			for d := range ch {
+				h.PollerMap.EnsurePolling(
+					d.AccessToken, d.UserID, d.DeviceID, d.Since,
+					logger.With().Str("user_id", d.UserID).Logger(),
+				)
+			}
+		}()
+	}
+	wg.Wait()
+	logger.Info().Msg("StartV2Pollers finished")
+}
 func (h *SyncLiveHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 	if req.Method != "POST" {
 		w.WriteHeader(http.StatusMethodNotAllowed)
@@ -222,7 +255,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Request
 		}
 	}
 	if v2device.UserID == "" {
-		v2device.UserID, err = h.V2.WhoAmI(req.Header.Get("Authorization"))
+		v2device.UserID, err = h.V2.WhoAmI(accessToken)
 		if err != nil {
 			log.Warn().Err(err).Str("device_id", deviceID).Msg("failed to get user ID from device ID")
 			return nil, &internal.HandlerError{
@@ -238,7 +271,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Request
 	log.Trace().Str("user", v2device.UserID).Msg("checking poller exists and is running")
 	h.PollerMap.EnsurePolling(
-		req.Header.Get("Authorization"), v2device.UserID, v2device.DeviceID, v2device.Since,
+		accessToken, v2device.UserID, v2device.DeviceID, v2device.Since,
 		hlog.FromRequest(req).With().Str("user_id", v2device.UserID).Logger(),
 	)
 	log.Trace().Str("user", v2device.UserID).Msg("poller exists and is running")

v3.go

@@ -68,6 +68,9 @@ func RunSyncV3Server(h http.Handler, bindAddr, destV2Server string) {
 		chain: []func(next http.Handler) http.Handler{
 			hlog.NewHandler(logger),
 			hlog.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
+				if r.Method == "OPTIONS" {
+					return
+				}
 				hlog.FromRequest(r).Info().
 					Str("method", r.Method).
 					Int("status", status).