mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
Merge pull request #90 from matrix-org/dmr/tokens-table
OIDC: track tokens separately to devices
This commit is contained in:
commit
0adaf75cfc
1
.github/workflows/tests.yml
vendored
1
.github/workflows/tests.yml
vendored
@ -4,7 +4,6 @@ on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
|
||||
permissions:
|
||||
packages: read
|
||||
|
@ -1,24 +1,16 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func HashedTokenFromRequest(req *http.Request) (hashAccessToken string, accessToken string, err error) {
|
||||
// return a hash of the access token
|
||||
func ExtractAccessToken(req *http.Request) (accessToken string, err error) {
|
||||
ah := req.Header.Get("Authorization")
|
||||
if ah == "" {
|
||||
return "", "", fmt.Errorf("missing Authorization header")
|
||||
return "", fmt.Errorf("missing Authorization header")
|
||||
}
|
||||
accessToken = strings.TrimPrefix(ah, "Bearer ")
|
||||
// important that this is a cryptographically secure hash function to prevent
|
||||
// preimage attacks where Eve can use a fake token to hash to an existing device ID
|
||||
// on the server.
|
||||
hash := sha256.New()
|
||||
hash.Write([]byte(accessToken))
|
||||
return hex.EncodeToString(hash.Sum(nil)), accessToken, nil
|
||||
return accessToken, nil
|
||||
}
|
||||
|
@ -1,24 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDeviceIDFromRequest(t *testing.T) {
|
||||
req, _ := http.NewRequest("POST", "http://localhost:8008", nil)
|
||||
req.Header.Set("Authorization", "Bearer A")
|
||||
deviceIDA, _, err := HashedTokenFromRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("HashedTokenFromRequest returned %s", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer B")
|
||||
deviceIDB, _, err := HashedTokenFromRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("HashedTokenFromRequest returned %s", err)
|
||||
}
|
||||
if deviceIDA == deviceIDB {
|
||||
t.Fatalf("HashedTokenFromRequest: hashed to same device ID: %s", deviceIDA)
|
||||
}
|
||||
|
||||
}
|
@ -78,6 +78,7 @@ type V2InitialSyncComplete struct {
|
||||
func (*V2InitialSyncComplete) Type() string { return "V2InitialSyncComplete" }
|
||||
|
||||
type V2DeviceData struct {
|
||||
UserID string
|
||||
DeviceID string
|
||||
Pos int64
|
||||
}
|
||||
@ -106,6 +107,7 @@ type V2DeviceMessages struct {
|
||||
func (*V2DeviceMessages) Type() string { return "V2DeviceMessages" }
|
||||
|
||||
type V2ExpiredToken struct {
|
||||
UserID string
|
||||
DeviceID string
|
||||
}
|
||||
|
||||
|
@ -8,8 +8,11 @@ type V3Listener interface {
|
||||
}
|
||||
|
||||
type V3EnsurePolling struct {
|
||||
UserID string
|
||||
DeviceID string
|
||||
// TODO: we only really need to provide the access token hash here.
|
||||
// Passing through a user means we can log something sensible though.
|
||||
UserID string
|
||||
DeviceID string
|
||||
AccessTokenHash string
|
||||
}
|
||||
|
||||
func (*V3EnsurePolling) Type() string { return "V3EnsurePolling" }
|
||||
|
47
sync2/devices_table.go
Normal file
47
sync2/devices_table.go
Normal file
@ -0,0 +1,47 @@
|
||||
package sync2
|
||||
|
||||
import (
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
type Device struct {
|
||||
UserID string `db:"user_id"`
|
||||
DeviceID string `db:"device_id"`
|
||||
Since string `db:"since"`
|
||||
}
|
||||
|
||||
// DevicesTable remembers syncv2 since positions per-device
|
||||
type DevicesTable struct {
|
||||
db *sqlx.DB
|
||||
}
|
||||
|
||||
func NewDevicesTable(db *sqlx.DB) *DevicesTable {
|
||||
db.MustExec(`
|
||||
CREATE TABLE IF NOT EXISTS syncv3_sync2_devices (
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, device_id),
|
||||
since TEXT NOT NULL
|
||||
);`)
|
||||
|
||||
return &DevicesTable{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// InsertDevice creates a new devices row with a blank since token if no such row
|
||||
// exists. Otherwise, it does nothing.
|
||||
func (t *DevicesTable) InsertDevice(userID, deviceID string) error {
|
||||
_, err := t.db.Exec(
|
||||
` INSERT INTO syncv3_sync2_devices(user_id, device_id, since) VALUES($1,$2,$3)
|
||||
ON CONFLICT (user_id, device_id) DO NOTHING`,
|
||||
userID, deviceID, "",
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *DevicesTable) UpdateDeviceSince(userID, deviceID, since string) error {
|
||||
_, err := t.db.Exec(`UPDATE syncv3_sync2_devices SET since = $1 WHERE user_id = $2 AND device_id = $3`, since, userID, deviceID)
|
||||
return err
|
||||
}
|
168
sync2/devices_table_test.go
Normal file
168
sync2/devices_table_test.go
Normal file
@ -0,0 +1,168 @@
|
||||
package sync2
|
||||
|
||||
import (
|
||||
"github.com/jmoiron/sqlx"
|
||||
"os"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/sliding-sync/testutils"
|
||||
)
|
||||
|
||||
var postgresConnectionString = "user=xxxxx dbname=syncv3_test sslmode=disable"
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
postgresConnectionString = testutils.PrepareDBConnectionString()
|
||||
exitCode := m.Run()
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
func connectToDB(t *testing.T) (*sqlx.DB, func()) {
|
||||
db, err := sqlx.Open("postgres", postgresConnectionString)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open SQL db: %s", err)
|
||||
}
|
||||
return db, func() {
|
||||
db.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Note that we currently only ever read from (devices JOIN tokens), so there is some
|
||||
// overlap with tokens_table_test.go here.
|
||||
func TestDevicesTableSinceColumn(t *testing.T) {
|
||||
db, close := connectToDB(t)
|
||||
defer close()
|
||||
tokens := NewTokensTable(db, "my_secret")
|
||||
devices := NewDevicesTable(db)
|
||||
|
||||
alice := "@alice:localhost"
|
||||
aliceDevice := "alice_phone"
|
||||
aliceSecret1 := "mysecret1"
|
||||
aliceSecret2 := "mysecret2"
|
||||
|
||||
t.Log("Insert two tokens for Alice.")
|
||||
aliceToken, err := tokens.Insert(aliceSecret1, alice, aliceDevice, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
aliceToken2, err := tokens.Insert(aliceSecret2, alice, aliceDevice, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
|
||||
t.Log("Add a devices row for Alice")
|
||||
err = devices.InsertDevice(alice, aliceDevice)
|
||||
|
||||
t.Log("Pretend we're about to start a poller. Fetch Alice's token along with the since value tracked by the devices table.")
|
||||
accessToken, since, err := tokens.GetTokenAndSince(alice, aliceDevice, aliceToken.AccessTokenHash)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to GetTokenAndSince: %s", err)
|
||||
}
|
||||
|
||||
t.Log("The since token should be empty.")
|
||||
assertEqual(t, accessToken, aliceToken.AccessToken, "Token.AccessToken mismatch")
|
||||
assertEqual(t, since, "", "Device.Since mismatch")
|
||||
|
||||
t.Log("Update the since column.")
|
||||
sinceValue := "s-1-2-3-4"
|
||||
err = devices.UpdateDeviceSince(alice, aliceDevice, sinceValue)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update since column: %s", err)
|
||||
}
|
||||
|
||||
t.Log("We should see the new since value when the poller refetches alice's token")
|
||||
_, since, err = tokens.GetTokenAndSince(alice, aliceDevice, aliceToken.AccessTokenHash)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to GetTokenAndSince: %s", err)
|
||||
}
|
||||
assertEqual(t, since, sinceValue, "Device.Since mismatch")
|
||||
|
||||
t.Log("We should also see the new since value when the poller fetches alice's second token")
|
||||
_, since, err = tokens.GetTokenAndSince(alice, aliceDevice, aliceToken2.AccessTokenHash)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to GetTokenAndSince: %s", err)
|
||||
}
|
||||
assertEqual(t, since, sinceValue, "Device.Since mismatch")
|
||||
}
|
||||
|
||||
func TestTokenForEachDevice(t *testing.T) {
|
||||
db, close := connectToDB(t)
|
||||
defer close()
|
||||
|
||||
// HACK: discard rows inserted by other tests. We don't normally need to do this,
|
||||
// but this is testing a query that scans the entire devices table.
|
||||
db.MustExec("TRUNCATE syncv3_sync2_devices, syncv3_sync2_tokens;")
|
||||
|
||||
tokens := NewTokensTable(db, "my_secret")
|
||||
devices := NewDevicesTable(db)
|
||||
|
||||
alice := "alice"
|
||||
aliceDevice := "alice_phone"
|
||||
bob := "bob"
|
||||
bobDevice := "bob_laptop"
|
||||
chris := "chris"
|
||||
chrisDevice := "chris_desktop"
|
||||
|
||||
t.Log("Add a device for Alice, Bob and Chris.")
|
||||
err := devices.InsertDevice(alice, aliceDevice)
|
||||
if err != nil {
|
||||
t.Fatalf("InsertDevice returned error: %s", err)
|
||||
}
|
||||
err = devices.InsertDevice(bob, bobDevice)
|
||||
if err != nil {
|
||||
t.Fatalf("InsertDevice returned error: %s", err)
|
||||
}
|
||||
err = devices.InsertDevice(chris, chrisDevice)
|
||||
if err != nil {
|
||||
t.Fatalf("InsertDevice returned error: %s", err)
|
||||
}
|
||||
|
||||
t.Log("Mark Alice's device with a since token.")
|
||||
sinceValue := "s-1-2-3-4"
|
||||
devices.UpdateDeviceSince(alice, aliceDevice, sinceValue)
|
||||
|
||||
t.Log("Insert 2 tokens for Alice, one for Bob and none for Chris.")
|
||||
aliceLastSeen1 := time.Now()
|
||||
_, err = tokens.Insert("alice_secret", alice, aliceDevice, aliceLastSeen1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
aliceLastSeen2 := aliceLastSeen1.Add(1 * time.Minute)
|
||||
aliceToken2, err := tokens.Insert("alice_secret2", alice, aliceDevice, aliceLastSeen2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
bobToken, err := tokens.Insert("bob_secret", bob, bobDevice, time.Time{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
|
||||
t.Log("Fetch a token for every device")
|
||||
gotTokens, err := tokens.TokenForEachDevice()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed TokenForEachDevice: %s", err)
|
||||
}
|
||||
|
||||
expectAlice := TokenForPoller{Token: aliceToken2, Since: sinceValue}
|
||||
expectBob := TokenForPoller{Token: bobToken, Since: ""}
|
||||
wantTokens := []*TokenForPoller{&expectAlice, &expectBob}
|
||||
|
||||
if len(gotTokens) != len(wantTokens) {
|
||||
t.Fatalf("AllDevices: got %d tokens, want %d", len(gotTokens), len(wantTokens))
|
||||
}
|
||||
|
||||
sort.Slice(gotTokens, func(i, j int) bool {
|
||||
if gotTokens[i].UserID < gotTokens[j].UserID {
|
||||
return true
|
||||
}
|
||||
return gotTokens[i].DeviceID < gotTokens[j].DeviceID
|
||||
})
|
||||
|
||||
for i := range gotTokens {
|
||||
assertEqual(t, gotTokens[i].Since, wantTokens[i].Since, "Device.Since mismatch")
|
||||
assertEqual(t, gotTokens[i].UserID, wantTokens[i].UserID, "Token.UserID mismatch")
|
||||
assertEqual(t, gotTokens[i].DeviceID, wantTokens[i].DeviceID, "Token.DeviceID mismatch")
|
||||
assertEqual(t, gotTokens[i].AccessToken, wantTokens[i].AccessToken, "Token.AccessToken mismatch")
|
||||
}
|
||||
}
|
@ -95,9 +95,9 @@ func (h *Handler) Teardown() {
|
||||
}
|
||||
|
||||
func (h *Handler) StartV2Pollers() {
|
||||
devices, err := h.v2Store.AllDevices()
|
||||
tokens, err := h.v2Store.TokensTable.TokenForEachDevice()
|
||||
if err != nil {
|
||||
logger.Err(err).Msg("StartV2Pollers: failed to query devices")
|
||||
logger.Err(err).Msg("StartV2Pollers: failed to query tokens")
|
||||
sentry.CaptureException(err)
|
||||
return
|
||||
}
|
||||
@ -106,30 +106,34 @@ func (h *Handler) StartV2Pollers() {
|
||||
// Too low and this will take ages for the v2 pollers to startup.
|
||||
numWorkers := 16
|
||||
numFails := 0
|
||||
ch := make(chan sync2.Device, len(devices))
|
||||
for _, d := range devices {
|
||||
ch := make(chan sync2.TokenForPoller, len(tokens))
|
||||
for _, t := range tokens {
|
||||
// if we fail to decrypt the access token, skip it.
|
||||
if d.AccessToken == "" {
|
||||
if t.AccessToken == "" {
|
||||
numFails++
|
||||
continue
|
||||
}
|
||||
ch <- d
|
||||
ch <- t
|
||||
}
|
||||
close(ch)
|
||||
logger.Info().Int("num_devices", len(devices)).Int("num_fail_decrypt", numFails).Msg("StartV2Pollers")
|
||||
logger.Info().Int("num_devices", len(tokens)).Int("num_fail_decrypt", numFails).Msg("StartV2Pollers")
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numWorkers)
|
||||
for i := 0; i < numWorkers; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for d := range ch {
|
||||
for t := range ch {
|
||||
pid := sync2.PollerID{
|
||||
UserID: t.UserID,
|
||||
DeviceID: t.DeviceID,
|
||||
}
|
||||
h.pMap.EnsurePolling(
|
||||
d.AccessToken, d.UserID, d.DeviceID, d.Since, true,
|
||||
logger.With().Str("user_id", d.UserID).Logger(),
|
||||
pid, t.AccessToken, t.Since, true,
|
||||
logger.With().Str("user_id", t.UserID).Str("device_id", t.DeviceID).Logger(),
|
||||
)
|
||||
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2InitialSyncComplete{
|
||||
UserID: d.UserID,
|
||||
DeviceID: d.DeviceID,
|
||||
UserID: t.UserID,
|
||||
DeviceID: t.DeviceID,
|
||||
})
|
||||
}
|
||||
}()
|
||||
@ -150,12 +154,11 @@ func (h *Handler) OnTerminated(userID, deviceID string) {
|
||||
h.updateMetrics()
|
||||
}
|
||||
|
||||
func (h *Handler) OnExpiredToken(userID, deviceID string) {
|
||||
h.v2Store.RemoveDevice(deviceID)
|
||||
h.Store.ToDeviceTable.DeleteAllMessagesForDevice(deviceID)
|
||||
h.Store.DeviceDataTable.DeleteDevice(userID, deviceID)
|
||||
// also notify v3 side so it can remove the connection from ConnMap
|
||||
func (h *Handler) OnExpiredToken(accessTokenHash, userID, deviceID string) {
|
||||
h.v2Store.TokensTable.Delete(accessTokenHash)
|
||||
// Notify v3 side so it can remove the connection from ConnMap
|
||||
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2ExpiredToken{
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
})
|
||||
}
|
||||
@ -171,10 +174,10 @@ func (h *Handler) addPrometheusMetrics() {
|
||||
}
|
||||
|
||||
// Emits nothing as no downstream components need it.
|
||||
func (h *Handler) UpdateDeviceSince(deviceID, since string) {
|
||||
err := h.v2Store.UpdateDeviceSince(deviceID, since)
|
||||
func (h *Handler) UpdateDeviceSince(userID, deviceID, since string) {
|
||||
err := h.v2Store.DevicesTable.UpdateDeviceSince(userID, deviceID, since)
|
||||
if err != nil {
|
||||
logger.Err(err).Str("device", deviceID).Str("since", since).Msg("V2: failed to persist since token")
|
||||
logger.Err(err).Str("user", userID).Str("device", deviceID).Str("since", since).Msg("V2: failed to persist since token")
|
||||
sentry.CaptureException(err)
|
||||
}
|
||||
}
|
||||
@ -197,6 +200,7 @@ func (h *Handler) OnE2EEData(userID, deviceID string, otkCounts map[string]int,
|
||||
return
|
||||
}
|
||||
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2DeviceData{
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
Pos: nextPos,
|
||||
})
|
||||
@ -387,7 +391,7 @@ func (h *Handler) EnsurePolling(p *pubsub.V3EnsurePolling) {
|
||||
defer func() {
|
||||
logger.Info().Str("user", p.UserID).Msg("EnsurePolling: request finished")
|
||||
}()
|
||||
dev, err := h.v2Store.Device(p.DeviceID)
|
||||
accessToken, since, err := h.v2Store.TokensTable.GetTokenAndSince(p.UserID, p.DeviceID, p.AccessTokenHash)
|
||||
if err != nil {
|
||||
logger.Err(err).Str("user", p.UserID).Str("device", p.DeviceID).Msg("V3Sub: EnsurePolling unknown device")
|
||||
sentry.CaptureException(err)
|
||||
@ -396,9 +400,13 @@ func (h *Handler) EnsurePolling(p *pubsub.V3EnsurePolling) {
|
||||
// don't block us from consuming more pubsub messages just because someone wants to sync
|
||||
go func() {
|
||||
// blocks until an initial sync is done
|
||||
pid := sync2.PollerID{
|
||||
UserID: p.UserID,
|
||||
DeviceID: p.DeviceID,
|
||||
}
|
||||
h.pMap.EnsurePolling(
|
||||
dev.AccessToken, dev.UserID, dev.DeviceID, dev.Since, false,
|
||||
logger.With().Str("user_id", dev.UserID).Logger(),
|
||||
pid, accessToken, since, false,
|
||||
logger.With().Str("user_id", p.UserID).Str("device_id", p.DeviceID).Logger(),
|
||||
)
|
||||
h.updateMetrics()
|
||||
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2InitialSyncComplete{
|
||||
|
@ -14,13 +14,18 @@ import (
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type PollerID struct {
|
||||
UserID string
|
||||
DeviceID string
|
||||
}
|
||||
|
||||
// alias time.Sleep so tests can monkey patch it out
|
||||
var timeSleep = time.Sleep
|
||||
|
||||
// V2DataReceiver is the receiver for all the v2 sync data the poller gets
|
||||
type V2DataReceiver interface {
|
||||
// Update the since token for this device. Called AFTER all other data in this sync response has been processed.
|
||||
UpdateDeviceSince(deviceID, since string)
|
||||
UpdateDeviceSince(userID, deviceID, since string)
|
||||
// Accumulate data for this room. This means the timeline section of the v2 response.
|
||||
Accumulate(deviceID, roomID, prevBatch string, timeline []json.RawMessage) // latest pos with event nids of timeline entries
|
||||
// Initialise the room, if it hasn't been already. This means the state section of the v2 response.
|
||||
@ -45,7 +50,7 @@ type V2DataReceiver interface {
|
||||
// Sent when the poll loop terminates
|
||||
OnTerminated(userID, deviceID string)
|
||||
// Sent when the token gets a 401 response
|
||||
OnExpiredToken(userID, deviceID string)
|
||||
OnExpiredToken(accessTokenHash, userID, deviceID string)
|
||||
}
|
||||
|
||||
// PollerMap is a map of device ID to Poller
|
||||
@ -53,7 +58,7 @@ type PollerMap struct {
|
||||
v2Client Client
|
||||
callbacks V2DataReceiver
|
||||
pollerMu *sync.Mutex
|
||||
Pollers map[string]*poller // device_id -> poller
|
||||
Pollers map[PollerID]*poller
|
||||
executor chan func()
|
||||
executorRunning bool
|
||||
processHistogramVec *prometheus.HistogramVec
|
||||
@ -88,7 +93,7 @@ func NewPollerMap(v2Client Client, enablePrometheus bool) *PollerMap {
|
||||
pm := &PollerMap{
|
||||
v2Client: v2Client,
|
||||
pollerMu: &sync.Mutex{},
|
||||
Pollers: make(map[string]*poller),
|
||||
Pollers: make(map[PollerID]*poller),
|
||||
executor: make(chan func(), 0),
|
||||
}
|
||||
if enablePrometheus {
|
||||
@ -151,15 +156,18 @@ func (h *PollerMap) NumPollers() (count int) {
|
||||
// Note that we will immediately return if there is a poller for the same user but a different device.
|
||||
// We do this to allow for logins on clients to be snappy fast, even though they won't yet have the
|
||||
// to-device msgs to decrypt E2EE roms.
|
||||
func (h *PollerMap) EnsurePolling(accessToken, userID, deviceID, v2since string, isStartup bool, logger zerolog.Logger) {
|
||||
func (h *PollerMap) EnsurePolling(pid PollerID, accessToken, v2since string, isStartup bool, logger zerolog.Logger) {
|
||||
h.pollerMu.Lock()
|
||||
if !h.executorRunning {
|
||||
h.executorRunning = true
|
||||
go h.execute()
|
||||
}
|
||||
poller, ok := h.Pollers[deviceID]
|
||||
poller, ok := h.Pollers[pid]
|
||||
// a poller exists and hasn't been terminated so we don't need to do anything
|
||||
if ok && !poller.terminated.Load() {
|
||||
if poller.accessToken != accessToken {
|
||||
logger.Warn().Msg("PollerMap.EnsurePolling: poller already running with different access token")
|
||||
}
|
||||
h.pollerMu.Unlock()
|
||||
// this existing poller may not have completed the initial sync yet, so we need to make sure
|
||||
// it has before we return.
|
||||
@ -169,28 +177,31 @@ func (h *PollerMap) EnsurePolling(accessToken, userID, deviceID, v2since string,
|
||||
// check if we need to wait at all: we don't need to if this user is already syncing on a different device
|
||||
// This is O(n) so we may want to map this if we get a lot of users...
|
||||
needToWait := true
|
||||
for pollerDeviceID, poller := range h.Pollers {
|
||||
if deviceID == pollerDeviceID {
|
||||
for existingPID, poller := range h.Pollers {
|
||||
// Ignore different users. Also ignore same-user same-device.
|
||||
if pid.UserID != existingPID.UserID || pid.DeviceID == existingPID.DeviceID {
|
||||
continue
|
||||
}
|
||||
if poller.userID == userID && !poller.terminated.Load() {
|
||||
// Now we have same-user different-device.
|
||||
if !poller.terminated.Load() {
|
||||
needToWait = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// replace the poller. If we don't need to wait, then we just want to nab to-device events initially.
|
||||
// We don't do that on startup though as we cannot be sure that other pollers will not be using expired tokens.
|
||||
poller = newPoller(userID, accessToken, deviceID, h.v2Client, h, logger, !needToWait && !isStartup)
|
||||
poller = newPoller(pid, accessToken, h.v2Client, h, logger, !needToWait && !isStartup)
|
||||
poller.processHistogramVec = h.processHistogramVec
|
||||
poller.timelineSizeVec = h.timelineSizeHistogramVec
|
||||
go poller.Poll(v2since)
|
||||
h.Pollers[deviceID] = poller
|
||||
h.Pollers[pid] = poller
|
||||
|
||||
h.pollerMu.Unlock()
|
||||
if needToWait {
|
||||
poller.WaitUntilInitialSync()
|
||||
} else {
|
||||
logger.Info().Str("user", userID).Msg("a poller exists for this user; not waiting for this device to do an initial sync")
|
||||
logger.Info().Msg("a poller exists for this user; not waiting for this device to do an initial sync")
|
||||
}
|
||||
}
|
||||
|
||||
@ -200,8 +211,8 @@ func (h *PollerMap) execute() {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *PollerMap) UpdateDeviceSince(deviceID, since string) {
|
||||
h.callbacks.UpdateDeviceSince(deviceID, since)
|
||||
func (h *PollerMap) UpdateDeviceSince(userID, deviceID, since string) {
|
||||
h.callbacks.UpdateDeviceSince(userID, deviceID, since)
|
||||
}
|
||||
func (h *PollerMap) Accumulate(deviceID, roomID, prevBatch string, timeline []json.RawMessage) {
|
||||
var wg sync.WaitGroup
|
||||
@ -261,8 +272,8 @@ func (h *PollerMap) OnTerminated(userID, deviceID string) {
|
||||
h.callbacks.OnTerminated(userID, deviceID)
|
||||
}
|
||||
|
||||
func (h *PollerMap) OnExpiredToken(userID, deviceID string) {
|
||||
h.callbacks.OnExpiredToken(userID, deviceID)
|
||||
func (h *PollerMap) OnExpiredToken(accessTokenHash, userID, deviceID string) {
|
||||
h.callbacks.OnExpiredToken(accessTokenHash, userID, deviceID)
|
||||
}
|
||||
|
||||
func (h *PollerMap) UpdateUnreadCounts(roomID, userID string, highlightCount, notifCount *int) {
|
||||
@ -308,8 +319,8 @@ func (h *PollerMap) OnE2EEData(userID, deviceID string, otkCounts map[string]int
|
||||
// Poller can automatically poll the sync v2 endpoint and accumulate the responses in storage
|
||||
type poller struct {
|
||||
userID string
|
||||
accessToken string
|
||||
deviceID string
|
||||
accessToken string
|
||||
client Client
|
||||
receiver V2DataReceiver
|
||||
logger zerolog.Logger
|
||||
@ -329,13 +340,13 @@ type poller struct {
|
||||
timelineSizeVec *prometheus.HistogramVec
|
||||
}
|
||||
|
||||
func newPoller(userID, accessToken, deviceID string, client Client, receiver V2DataReceiver, logger zerolog.Logger, initialToDeviceOnly bool) *poller {
|
||||
func newPoller(pid PollerID, accessToken string, client Client, receiver V2DataReceiver, logger zerolog.Logger, initialToDeviceOnly bool) *poller {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
return &poller{
|
||||
userID: pid.UserID,
|
||||
deviceID: pid.DeviceID,
|
||||
accessToken: accessToken,
|
||||
userID: userID,
|
||||
deviceID: deviceID,
|
||||
client: client,
|
||||
receiver: receiver,
|
||||
terminated: &atomic.Bool{},
|
||||
@ -390,7 +401,7 @@ func (p *poller) Poll(since string) {
|
||||
continue
|
||||
} else {
|
||||
p.logger.Warn().Msg("Poller: access token has been invalidated, terminating loop")
|
||||
p.receiver.OnExpiredToken(p.userID, p.deviceID)
|
||||
p.receiver.OnExpiredToken(hashToken(p.accessToken), p.userID, p.deviceID)
|
||||
p.Terminate()
|
||||
break
|
||||
}
|
||||
@ -411,7 +422,7 @@ func (p *poller) Poll(since string) {
|
||||
|
||||
since = resp.NextBatch
|
||||
// persist the since token (TODO: this could get slow if we hammer the DB too much)
|
||||
p.receiver.UpdateDeviceSince(p.deviceID, since)
|
||||
p.receiver.UpdateDeviceSince(p.userID, p.deviceID, since)
|
||||
|
||||
if firstTime {
|
||||
firstTime = false
|
||||
@ -522,7 +533,7 @@ func (p *poller) parseRoomsResponse(res *SyncResponse) {
|
||||
// the timeline now so that future events are received under the
|
||||
// correct room state.
|
||||
const warnMsg = "parseRoomsResponse: prepending state events to timeline after gappy poll"
|
||||
log.Warn().Str("room_id", roomID).Int("prependStateEvents", len(prependStateEvents)).Msg(warnMsg)
|
||||
logger.Warn().Str("room_id", roomID).Int("prependStateEvents", len(prependStateEvents)).Msg(warnMsg)
|
||||
sentry.WithScope(func(scope *sentry.Scope) {
|
||||
scope.SetContext("sliding-sync", map[string]interface{}{
|
||||
"room_id": roomID,
|
||||
|
@ -51,7 +51,10 @@ func TestPollerMapEnsurePolling(t *testing.T) {
|
||||
|
||||
ensurePollingUnblocked := make(chan struct{})
|
||||
go func() {
|
||||
pm.EnsurePolling("access_token", "@alice:localhost", "FOOBAR", "", false, zerolog.New(os.Stderr))
|
||||
pm.EnsurePolling(PollerID{
|
||||
UserID: "alice:localhost",
|
||||
DeviceID: "FOOBAR",
|
||||
}, "access_token", "", false, zerolog.New(os.Stderr))
|
||||
close(ensurePollingUnblocked)
|
||||
}()
|
||||
ensureBlocking := func() {
|
||||
@ -136,7 +139,7 @@ func TestPollerMapEnsurePollingIdempotent(t *testing.T) {
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
t.Logf("EnsurePolling")
|
||||
pm.EnsurePolling("access_token", "@alice:localhost", "FOOBAR", "", false, zerolog.New(os.Stderr))
|
||||
pm.EnsurePolling(PollerID{UserID: "@alice:localhost", DeviceID: "FOOBAR"}, "access_token", "", false, zerolog.New(os.Stderr))
|
||||
wg.Done()
|
||||
t.Logf("EnsurePolling unblocked")
|
||||
}()
|
||||
@ -195,7 +198,7 @@ func TestPollerMapEnsurePollingIdempotent(t *testing.T) {
|
||||
// Check that a call to Poll starts polling and accumulating, and terminates on 401s.
|
||||
func TestPollerPollFromNothing(t *testing.T) {
|
||||
nextSince := "next"
|
||||
deviceID := "FOOBAR"
|
||||
pid := PollerID{UserID: "@alice:localhost", DeviceID: "FOOBAR"}
|
||||
roomID := "!foo:bar"
|
||||
roomState := []json.RawMessage{
|
||||
json.RawMessage(`{"event":1}`),
|
||||
@ -224,7 +227,7 @@ func TestPollerPollFromNothing(t *testing.T) {
|
||||
})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
poller := newPoller("@alice:localhost", "Authorization: hello world", deviceID, client, accumulator, zerolog.New(os.Stderr), false)
|
||||
poller := newPoller(pid, "Authorization: hello world", client, accumulator, zerolog.New(os.Stderr), false)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
poller.Poll("")
|
||||
@ -244,14 +247,14 @@ func TestPollerPollFromNothing(t *testing.T) {
|
||||
if len(accumulator.states[roomID]) != len(roomState) {
|
||||
t.Errorf("did not accumulate initial state for room, got %d events want %d", len(accumulator.states[roomID]), len(roomState))
|
||||
}
|
||||
if accumulator.deviceIDToSince[deviceID] != nextSince {
|
||||
t.Errorf("did not persist latest since token, got %s want %s", accumulator.deviceIDToSince[deviceID], nextSince)
|
||||
if accumulator.pollerIDToSince[pid] != nextSince {
|
||||
t.Errorf("did not persist latest since token, got %s want %s", accumulator.pollerIDToSince[pid], nextSince)
|
||||
}
|
||||
}
|
||||
|
||||
// Check that a call to Poll starts polling with an existing since token and accumulates timeline entries
|
||||
func TestPollerPollFromExisting(t *testing.T) {
|
||||
deviceID := "FOOBAR"
|
||||
pid := PollerID{UserID: "@alice:localhost", DeviceID: "FOOBAR"}
|
||||
roomID := "!foo:bar"
|
||||
since := "0"
|
||||
roomTimelineResponses := [][]json.RawMessage{
|
||||
@ -307,7 +310,7 @@ func TestPollerPollFromExisting(t *testing.T) {
|
||||
})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
poller := newPoller("@alice:localhost", "Authorization: hello world", deviceID, client, accumulator, zerolog.New(os.Stderr), false)
|
||||
poller := newPoller(pid, "Authorization: hello world", client, accumulator, zerolog.New(os.Stderr), false)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
poller.Poll(since)
|
||||
@ -328,8 +331,8 @@ func TestPollerPollFromExisting(t *testing.T) {
|
||||
t.Errorf("did not accumulate timelines for room, got %d events want %d", len(accumulator.timelines[roomID]), 10)
|
||||
}
|
||||
wantSince := fmt.Sprintf("%d", len(roomTimelineResponses))
|
||||
if accumulator.deviceIDToSince[deviceID] != wantSince {
|
||||
t.Errorf("did not persist latest since token, got %s want %s", accumulator.deviceIDToSince[deviceID], wantSince)
|
||||
if accumulator.pollerIDToSince[pid] != wantSince {
|
||||
t.Errorf("did not persist latest since token, got %s want %s", accumulator.pollerIDToSince[pid], wantSince)
|
||||
}
|
||||
}
|
||||
|
||||
@ -383,7 +386,7 @@ func TestPollerBackoff(t *testing.T) {
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
poller := newPoller("@alice:localhost", "Authorization: hello world", deviceID, client, accumulator, zerolog.New(os.Stderr), false)
|
||||
poller := newPoller(PollerID{UserID: "@alice:localhost", DeviceID: deviceID}, "Authorization: hello world", client, accumulator, zerolog.New(os.Stderr), false)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
poller.Poll("some_since_value")
|
||||
@ -413,7 +416,7 @@ func TestPollerUnblocksIfTerminatedInitially(t *testing.T) {
|
||||
|
||||
pollUnblocked := make(chan struct{})
|
||||
waitUntilInitialSyncUnblocked := make(chan struct{})
|
||||
poller := newPoller("@alice:localhost", "Authorization: hello world", deviceID, client, accumulator, zerolog.New(os.Stderr), false)
|
||||
poller := newPoller(PollerID{UserID: "@alice:localhost", DeviceID: deviceID}, "Authorization: hello world", client, accumulator, zerolog.New(os.Stderr), false)
|
||||
go func() {
|
||||
poller.Poll("")
|
||||
close(pollUnblocked)
|
||||
@ -452,7 +455,7 @@ func (c *mockClient) WhoAmI(authHeader string) (string, string, error) {
|
||||
type mockDataReceiver struct {
|
||||
states map[string][]json.RawMessage
|
||||
timelines map[string][]json.RawMessage
|
||||
deviceIDToSince map[string]string
|
||||
pollerIDToSince map[PollerID]string
|
||||
incomingProcess chan struct{}
|
||||
unblockProcess chan struct{}
|
||||
}
|
||||
@ -474,8 +477,8 @@ func (a *mockDataReceiver) Initialise(roomID string, state []json.RawMessage) []
|
||||
}
|
||||
func (a *mockDataReceiver) SetTyping(roomID string, ephEvent json.RawMessage) {
|
||||
}
|
||||
func (s *mockDataReceiver) UpdateDeviceSince(deviceID, since string) {
|
||||
s.deviceIDToSince[deviceID] = since
|
||||
func (s *mockDataReceiver) UpdateDeviceSince(userID, deviceID, since string) {
|
||||
s.pollerIDToSince[PollerID{UserID: userID, DeviceID: deviceID}] = since
|
||||
}
|
||||
func (s *mockDataReceiver) AddToDeviceMessages(userID, deviceID string, msgs []json.RawMessage) {
|
||||
}
|
||||
@ -488,8 +491,8 @@ func (s *mockDataReceiver) OnInvite(userID, roomID string, inviteState []json.Ra
|
||||
func (s *mockDataReceiver) OnLeftRoom(userID, roomID string) {}
|
||||
func (s *mockDataReceiver) OnE2EEData(userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int) {
|
||||
}
|
||||
func (s *mockDataReceiver) OnTerminated(userID, deviceID string) {}
|
||||
func (s *mockDataReceiver) OnExpiredToken(userID, deviceID string) {}
|
||||
func (s *mockDataReceiver) OnTerminated(userID, deviceID string) {}
|
||||
func (s *mockDataReceiver) OnExpiredToken(accessTokenHash, userID, deviceID string) {}
|
||||
|
||||
func newMocks(doSyncV2 func(authHeader, since string) (*SyncResponse, int, error)) (*mockDataReceiver, *mockClient) {
|
||||
client := &mockClient{
|
||||
@ -498,7 +501,7 @@ func newMocks(doSyncV2 func(authHeader, since string) (*SyncResponse, int, error
|
||||
accumulator := &mockDataReceiver{
|
||||
states: make(map[string][]json.RawMessage),
|
||||
timelines: make(map[string][]json.RawMessage),
|
||||
deviceIDToSince: make(map[string]string),
|
||||
pollerIDToSince: make(map[PollerID]string),
|
||||
}
|
||||
return accumulator, client
|
||||
}
|
||||
|
164
sync2/storage.go
164
sync2/storage.go
@ -1,45 +1,21 @@
|
||||
package sync2
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"github.com/getsentry/sentry-go"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/matrix-org/sliding-sync/sqlutil"
|
||||
"github.com/rs/zerolog"
|
||||
"os"
|
||||
)
|
||||
|
||||
var log = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.ConsoleWriter{
|
||||
var logger = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.ConsoleWriter{
|
||||
Out: os.Stderr,
|
||||
TimeFormat: "15:04:05",
|
||||
})
|
||||
|
||||
type Device struct {
|
||||
UserID string `db:"user_id"`
|
||||
DeviceID string `db:"device_id"`
|
||||
Since string `db:"since"`
|
||||
AccessToken string
|
||||
AccessTokenEncrypted string `db:"v2_token_encrypted"`
|
||||
}
|
||||
|
||||
// Storage remembers sync v2 tokens per-device
|
||||
type Storage struct {
|
||||
db *sqlx.DB
|
||||
// A separate secret used to en/decrypt access tokens prior to / after retrieval from the database.
|
||||
// This provides additional security as a simple SQL injection attack would be insufficient to retrieve
|
||||
// users access tokens due to the encryption key not living inside the database / on that machine at all.
|
||||
// https://cheatsheetseries.owasp.org/cheatsheets/Cryptographic_Storage_Cheat_Sheet.html#separation-of-keys-and-data
|
||||
// We cannot use bcrypt/scrypt as we need the plaintext to do sync requests!
|
||||
key256 []byte
|
||||
DevicesTable *DevicesTable
|
||||
TokensTable *TokensTable
|
||||
DB *sqlx.DB
|
||||
}
|
||||
|
||||
func NewStore(postgresURI, secret string) *Storage {
|
||||
@ -47,138 +23,18 @@ func NewStore(postgresURI, secret string) *Storage {
|
||||
if err != nil {
|
||||
sentry.CaptureException(err)
|
||||
// TODO: if we panic(), will sentry have a chance to flush the event?
|
||||
log.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
|
||||
logger.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
|
||||
}
|
||||
db.MustExec(`
|
||||
CREATE TABLE IF NOT EXISTS syncv3_sync2_devices (
|
||||
device_id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL, -- populated from /whoami
|
||||
v2_token_encrypted TEXT NOT NULL,
|
||||
since TEXT NOT NULL
|
||||
);`)
|
||||
|
||||
// derive the key from the secret
|
||||
hash := sha256.New()
|
||||
hash.Write([]byte(secret))
|
||||
|
||||
return &Storage{
|
||||
db: db,
|
||||
key256: hash.Sum(nil),
|
||||
DevicesTable: NewDevicesTable(db),
|
||||
TokensTable: NewTokensTable(db, secret),
|
||||
DB: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Storage) Teardown() {
|
||||
err := s.db.Close()
|
||||
err := s.DB.Close()
|
||||
if err != nil {
|
||||
panic("V2Storage.Teardown: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Storage) encrypt(token string) string {
|
||||
block, err := aes.NewCipher(s.key256)
|
||||
if err != nil {
|
||||
panic("sync2.Storage encrypt: " + err.Error())
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
panic("sync2.Storage encrypt: " + err.Error())
|
||||
}
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
panic("sync2.Storage encrypt: " + err.Error())
|
||||
}
|
||||
return hex.EncodeToString(nonce) + " " + hex.EncodeToString(gcm.Seal(nil, nonce, []byte(token), nil))
|
||||
}
|
||||
func (s *Storage) decrypt(nonceAndEncToken string) (string, error) {
|
||||
segs := strings.Split(nonceAndEncToken, " ")
|
||||
nonce := segs[0]
|
||||
nonceBytes, err := hex.DecodeString(nonce)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt nonce: failed to decode hex: %s", err)
|
||||
}
|
||||
encToken := segs[1]
|
||||
ciphertext, err := hex.DecodeString(encToken)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt token: failed to decode hex: %s", err)
|
||||
}
|
||||
block, err := aes.NewCipher(s.key256)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
aesgcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
token, err := aesgcm.Open(nil, nonceBytes, ciphertext, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(token), nil
|
||||
}
|
||||
|
||||
func (s *Storage) Device(deviceID string) (*Device, error) {
|
||||
var d Device
|
||||
err := s.db.Get(&d, `SELECT device_id, user_id, since, v2_token_encrypted FROM syncv3_sync2_devices WHERE device_id=$1`, deviceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to lookup device '%s': %s", deviceID, err)
|
||||
}
|
||||
d.AccessToken, err = s.decrypt(d.AccessTokenEncrypted)
|
||||
return &d, err
|
||||
}
|
||||
|
||||
func (s *Storage) AllDevices() (devices []Device, err error) {
|
||||
err = s.db.Select(&devices, `SELECT device_id, user_id, since, v2_token_encrypted FROM syncv3_sync2_devices`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for i := range devices {
|
||||
devices[i].AccessToken, _ = s.decrypt(devices[i].AccessTokenEncrypted)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Storage) RemoveDevice(deviceID string) error {
|
||||
_, err := s.db.Exec(
|
||||
`DELETE FROM syncv3_sync2_devices WHERE device_id = $1`, deviceID,
|
||||
)
|
||||
log.Info().Str("device", deviceID).Msg("Deleting device")
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Storage) InsertDevice(deviceID, accessToken string) (*Device, error) {
|
||||
var device Device
|
||||
device.AccessToken = accessToken
|
||||
device.AccessTokenEncrypted = s.encrypt(accessToken)
|
||||
err := sqlutil.WithTransaction(s.db, func(txn *sqlx.Tx) error {
|
||||
// make sure there is a device entry for this device ID. If one already exists, don't clobber
|
||||
// the since value else we'll forget our position!
|
||||
result, err := txn.Exec(`
|
||||
INSERT INTO syncv3_sync2_devices(device_id, since, user_id, v2_token_encrypted) VALUES($1,$2,$3,$4)
|
||||
ON CONFLICT (device_id) DO NOTHING`,
|
||||
deviceID, "", "", device.AccessTokenEncrypted,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
device.DeviceID = deviceID
|
||||
|
||||
// if we inserted a row that means it's a brand new device ergo there is no since token
|
||||
if ra, err := result.RowsAffected(); err == nil && ra == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Return the since value as we may start a new poller with this session.
|
||||
return txn.QueryRow("SELECT since, user_id FROM syncv3_sync2_devices WHERE device_id = $1", deviceID).Scan(&device.Since, &device.UserID)
|
||||
})
|
||||
return &device, err
|
||||
}
|
||||
|
||||
func (s *Storage) UpdateDeviceSince(deviceID, since string) error {
|
||||
_, err := s.db.Exec(`UPDATE syncv3_sync2_devices SET since = $1 WHERE device_id = $2`, since, deviceID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Storage) UpdateUserIDForDevice(deviceID, userID string) error {
|
||||
_, err := s.db.Exec(`UPDATE syncv3_sync2_devices SET user_id = $1 WHERE device_id = $2`, userID, deviceID)
|
||||
return err
|
||||
}
|
||||
|
@ -1,89 +0,0 @@
|
||||
package sync2
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/sliding-sync/testutils"
|
||||
)
|
||||
|
||||
var postgresConnectionString = "user=xxxxx dbname=syncv3_test sslmode=disable"
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
postgresConnectionString = testutils.PrepareDBConnectionString()
|
||||
exitCode := m.Run()
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
func TestStorage(t *testing.T) {
|
||||
deviceID := "ALICE"
|
||||
accessToken := "my_access_token"
|
||||
store := NewStore(postgresConnectionString, "my_secret")
|
||||
device, err := store.InsertDevice(deviceID, accessToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to InsertDevice: %s", err)
|
||||
}
|
||||
assertEqual(t, device.DeviceID, deviceID, "Device.DeviceID mismatch")
|
||||
assertEqual(t, device.AccessToken, accessToken, "Device.AccessToken mismatch")
|
||||
if err = store.UpdateDeviceSince(deviceID, "s1"); err != nil {
|
||||
t.Fatalf("UpdateDeviceSince returned error: %s", err)
|
||||
}
|
||||
if err = store.UpdateUserIDForDevice(deviceID, "@alice:localhost"); err != nil {
|
||||
t.Fatalf("UpdateUserIDForDevice returned error: %s", err)
|
||||
}
|
||||
|
||||
// now check that device retrieval has the latest values
|
||||
device, err = store.Device(deviceID)
|
||||
if err != nil {
|
||||
t.Fatalf("Device returned error: %s", err)
|
||||
}
|
||||
assertEqual(t, device.DeviceID, deviceID, "Device.DeviceID mismatch")
|
||||
assertEqual(t, device.Since, "s1", "Device.Since mismatch")
|
||||
assertEqual(t, device.UserID, "@alice:localhost", "Device.UserID mismatch")
|
||||
assertEqual(t, device.AccessToken, accessToken, "Device.AccessToken mismatch")
|
||||
|
||||
// now check new devices remember the v2 since value and user ID
|
||||
s2, err := store.InsertDevice(deviceID, accessToken)
|
||||
if err != nil {
|
||||
t.Fatalf("InsertDevice returned error: %s", err)
|
||||
}
|
||||
assertEqual(t, s2.Since, "s1", "Device.Since mismatch")
|
||||
assertEqual(t, s2.UserID, "@alice:localhost", "Device.UserID mismatch")
|
||||
assertEqual(t, s2.DeviceID, deviceID, "Device.DeviceID mismatch")
|
||||
assertEqual(t, s2.AccessToken, accessToken, "Device.AccessToken mismatch")
|
||||
|
||||
// check all devices works
|
||||
deviceID2 := "BOB"
|
||||
accessToken2 := "BOB_ACCESS_TOKEN"
|
||||
bobDevice, err := store.InsertDevice(deviceID2, accessToken2)
|
||||
if err != nil {
|
||||
t.Fatalf("InsertDevice returned error: %s", err)
|
||||
}
|
||||
devices, err := store.AllDevices()
|
||||
if err != nil {
|
||||
t.Fatalf("AllDevices: %s", err)
|
||||
}
|
||||
sort.Slice(devices, func(i, j int) bool {
|
||||
return devices[i].DeviceID < devices[j].DeviceID
|
||||
})
|
||||
wantDevices := []*Device{
|
||||
device, bobDevice,
|
||||
}
|
||||
if len(devices) != len(wantDevices) {
|
||||
t.Fatalf("AllDevices: got %d devices, want %d", len(devices), len(wantDevices))
|
||||
}
|
||||
for i := range devices {
|
||||
assertEqual(t, devices[i].Since, wantDevices[i].Since, "Device.Since mismatch")
|
||||
assertEqual(t, devices[i].UserID, wantDevices[i].UserID, "Device.UserID mismatch")
|
||||
assertEqual(t, devices[i].DeviceID, wantDevices[i].DeviceID, "Device.DeviceID mismatch")
|
||||
assertEqual(t, devices[i].AccessToken, wantDevices[i].AccessToken, "Device.AccessToken mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func assertEqual(t *testing.T, got, want, msg string) {
|
||||
t.Helper()
|
||||
if got != want {
|
||||
t.Fatalf("%s: got %s want %s", msg, got, want)
|
||||
}
|
||||
}
|
245
sync2/tokens_table.go
Normal file
245
sync2/tokens_table.go
Normal file
@ -0,0 +1,245 @@
|
||||
package sync2
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
AccessToken string
|
||||
AccessTokenHash string
|
||||
AccessTokenEncrypted string `db:"token_encrypted"`
|
||||
UserID string `db:"user_id"`
|
||||
DeviceID string `db:"device_id"`
|
||||
LastSeen time.Time `db:"last_seen"`
|
||||
}
|
||||
|
||||
// TokensTable remembers sync v2 tokens
|
||||
type TokensTable struct {
|
||||
db *sqlx.DB
|
||||
// A separate secret used to en/decrypt access tokens prior to / after retrieval from the database.
|
||||
// This provides additional security as a simple SQL injection attack would be insufficient to retrieve
|
||||
// users access tokens due to the encryption key not living inside the database / on that machine at all.
|
||||
// https://cheatsheetseries.owasp.org/cheatsheets/Cryptographic_Storage_Cheat_Sheet.html#separation-of-keys-and-data
|
||||
// We cannot use bcrypt/scrypt as we need the plaintext to do sync requests!
|
||||
key256 []byte
|
||||
}
|
||||
|
||||
// NewTokensTable creates the syncv3_sync2_tokens table if it does not already exist.
|
||||
func NewTokensTable(db *sqlx.DB, secret string) *TokensTable {
|
||||
db.MustExec(`
|
||||
CREATE TABLE IF NOT EXISTS syncv3_sync2_tokens (
|
||||
token_hash TEXT NOT NULL PRIMARY KEY, -- SHA256(access token)
|
||||
token_encrypted TEXT NOT NULL,
|
||||
-- TODO: FK constraints to devices table?
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
last_seen TIMESTAMP WITH TIME ZONE NOT NULL
|
||||
);`)
|
||||
|
||||
// derive the key from the secret
|
||||
hash := sha256.New()
|
||||
hash.Write([]byte(secret))
|
||||
|
||||
return &TokensTable{
|
||||
db: db,
|
||||
key256: hash.Sum(nil),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TokensTable) encrypt(token string) string {
|
||||
block, err := aes.NewCipher(t.key256)
|
||||
if err != nil {
|
||||
panic("sync2.DevicesTable encrypt: " + err.Error())
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
panic("sync2.DevicesTable encrypt: " + err.Error())
|
||||
}
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
panic("sync2.DevicesTable encrypt: " + err.Error())
|
||||
}
|
||||
return hex.EncodeToString(nonce) + " " + hex.EncodeToString(gcm.Seal(nil, nonce, []byte(token), nil))
|
||||
}
|
||||
func (t *TokensTable) decrypt(nonceAndEncToken string) (string, error) {
|
||||
segs := strings.Split(nonceAndEncToken, " ")
|
||||
nonce := segs[0]
|
||||
nonceBytes, err := hex.DecodeString(nonce)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt nonce: failed to decode hex: %s", err)
|
||||
}
|
||||
encToken := segs[1]
|
||||
ciphertext, err := hex.DecodeString(encToken)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt token: failed to decode hex: %s", err)
|
||||
}
|
||||
block, err := aes.NewCipher(t.key256)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
aesgcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
token, err := aesgcm.Open(nil, nonceBytes, ciphertext, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(token), nil
|
||||
}
|
||||
|
||||
func hashToken(accessToken string) string {
|
||||
// important that this is a cryptographically secure hash function to prevent
|
||||
// preimage attacks where Eve can use a fake token to hash to an existing device ID
|
||||
// on the server.
|
||||
hash := sha256.New()
|
||||
hash.Write([]byte(accessToken))
|
||||
return hex.EncodeToString(hash.Sum(nil))
|
||||
}
|
||||
|
||||
// Token retrieves a tokens row from the database if it exists.
|
||||
// Errors with sql.NoRowsError if the token does not exist.
|
||||
// Errors with an unspecified error otherwise.
|
||||
func (t *TokensTable) Token(plaintextToken string) (*Token, error) {
|
||||
tokenHash := hashToken(plaintextToken)
|
||||
var token Token
|
||||
err := t.db.Get(
|
||||
&token,
|
||||
`SELECT token_encrypted, user_id, device_id, last_seen FROM syncv3_sync2_tokens WHERE token_hash=$1`,
|
||||
tokenHash,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token.AccessToken = plaintextToken
|
||||
token.AccessTokenHash = tokenHash
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// TokenForPoller represents a row of the tokens table, together with any data
|
||||
// maintained by pollers for that token's device.
|
||||
type TokenForPoller struct {
|
||||
*Token
|
||||
Since string `db:"since"`
|
||||
}
|
||||
|
||||
func (t *TokensTable) TokenForEachDevice() (tokens []TokenForPoller, err error) {
|
||||
// Fetches the most recently seen token for each device, see e.g.
|
||||
// https://www.postgresql.org/docs/11/sql-select.html#SQL-DISTINCT
|
||||
err = t.db.Select(
|
||||
&tokens,
|
||||
`SELECT DISTINCT ON (user_id, device_id) token_encrypted, user_id, device_id, last_seen, since
|
||||
FROM syncv3_sync2_tokens JOIN syncv3_sync2_devices USING (user_id, device_id)
|
||||
ORDER BY user_id, device_id, last_seen DESC
|
||||
`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, token := range tokens {
|
||||
token.AccessToken, err = t.decrypt(token.AccessTokenEncrypted)
|
||||
if err != nil {
|
||||
// Ignore decryption failure.
|
||||
continue
|
||||
}
|
||||
token.AccessTokenHash = hashToken(token.AccessToken)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Insert a new token into the table.
|
||||
func (t *TokensTable) Insert(plaintextToken, userID, deviceID string, lastSeen time.Time) (*Token, error) {
|
||||
hashedToken := hashToken(plaintextToken)
|
||||
encToken := t.encrypt(plaintextToken)
|
||||
_, err := t.db.Exec(
|
||||
`INSERT INTO syncv3_sync2_tokens(token_hash, token_encrypted, user_id, device_id, last_seen)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (token_hash) DO NOTHING;`,
|
||||
hashedToken, encToken, userID, deviceID, lastSeen,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Token{
|
||||
AccessToken: plaintextToken,
|
||||
AccessTokenHash: hashedToken,
|
||||
// Note: if this token already exists in the DB, encToken will differ from
|
||||
// the DB token_encrypted column. (t.encrypt is nondeterministic, see e.g.
|
||||
// https://en.wikipedia.org/wiki/Probabilistic_encryption).
|
||||
// The rest of the program should ignore this field; it only lives here so
|
||||
// we can Scan the DB row into the Tokens struct. Could make it private?
|
||||
AccessTokenEncrypted: encToken,
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
LastSeen: lastSeen,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// MaybeUpdateLastSeen actions a request to update a Token struct with its last_seen value
|
||||
// in the DB. To avoid spamming the DB with a write every time a sync3 request arrives,
|
||||
// we only update the last seen timestamp or the if it is at least 24 hours old.
|
||||
// The timestamp is updated on the Token struct if and only if it is updated in the DB.
|
||||
func (t *TokensTable) MaybeUpdateLastSeen(token *Token, newLastSeen time.Time) error {
|
||||
sinceLastSeen := newLastSeen.Sub(token.LastSeen)
|
||||
if sinceLastSeen < (24 * time.Hour) {
|
||||
return nil
|
||||
}
|
||||
_, err := t.db.Exec(
|
||||
`UPDATE syncv3_sync2_tokens SET last_seen = $1 WHERE token_hash = $2`,
|
||||
newLastSeen, token.AccessTokenHash,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
token.LastSeen = newLastSeen
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TokensTable) GetTokenAndSince(userID, deviceID, tokenHash string) (accessToken, since string, err error) {
|
||||
var encToken, gotUserID, gotDeviceID string
|
||||
query := `SELECT token_encrypted, since, user_id, device_id
|
||||
FROM syncv3_sync2_tokens JOIN syncv3_sync2_devices USING (user_id, device_id)
|
||||
WHERE token_hash = $1;`
|
||||
err = t.db.QueryRow(query, tokenHash).Scan(&encToken, &since, &gotUserID, &gotDeviceID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if gotUserID != userID || gotDeviceID != deviceID {
|
||||
err = fmt.Errorf(
|
||||
"token (hash %s) found with user+device mismatch: got (%s, %s), expected (%s, %s)",
|
||||
tokenHash, gotUserID, gotDeviceID, userID, deviceID,
|
||||
)
|
||||
return
|
||||
}
|
||||
accessToken, err = t.decrypt(encToken)
|
||||
return
|
||||
}
|
||||
|
||||
// Delete looks up a token by its hash and deletes the row. If no token exists with the
|
||||
// given hash, a warning is logged but no error is returned.
|
||||
func (t *TokensTable) Delete(accessTokenHash string) error {
|
||||
result, err := t.db.Exec(
|
||||
`DELETE FROM syncv3_sync2_tokens WHERE token_hash = $1`,
|
||||
accessTokenHash,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ra, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ra != 1 {
|
||||
logger.Warn().Msgf("Tokens.Delete: expected to delete one token, but actually deleted %d", ra)
|
||||
}
|
||||
return nil
|
||||
}
|
150
sync2/tokens_table_test.go
Normal file
150
sync2/tokens_table_test.go
Normal file
@ -0,0 +1,150 @@
|
||||
package sync2
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Sanity check that different tokens have different hashes
|
||||
func TestHash(t *testing.T) {
|
||||
token1 := "ABCD"
|
||||
token2 := "EFGH"
|
||||
hash1 := hashToken(token1)
|
||||
hash2 := hashToken(token2)
|
||||
if hash1 == hash2 {
|
||||
t.Fatalf("HashedTokenFromRequest: %s and %s have the same hash", token1, token2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokensTable(t *testing.T) {
|
||||
db, close := connectToDB(t)
|
||||
defer close()
|
||||
tokens := NewTokensTable(db, "my_secret")
|
||||
|
||||
alice := "@alice:localhost"
|
||||
aliceDevice := "alice_phone"
|
||||
aliceSecret1 := "mysecret1"
|
||||
aliceToken1FirstSeen := time.Now()
|
||||
|
||||
// Test a single token
|
||||
t.Log("Insert a new token from Alice.")
|
||||
aliceToken, err := tokens.Insert(aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
|
||||
t.Log("The returned Token struct should have been populated correctly.")
|
||||
assertEqualTokens(t, tokens, aliceToken, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
|
||||
|
||||
t.Log("Reinsert the same token.")
|
||||
reinsertedToken, err := tokens.Insert(aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
|
||||
t.Log("This should yield an equal Token struct.")
|
||||
assertEqualTokens(t, tokens, reinsertedToken, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
|
||||
|
||||
t.Log("Try to mark Alice's token as being used after an hour.")
|
||||
err = tokens.MaybeUpdateLastSeen(aliceToken, aliceToken1FirstSeen.Add(time.Hour))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update last seen: %s", err)
|
||||
}
|
||||
|
||||
t.Log("The token should not be updated in memory, nor in the DB.")
|
||||
assertEqualTimes(t, aliceToken.LastSeen, aliceToken1FirstSeen, "Token.LastSeen mismatch")
|
||||
fetchedToken, err := tokens.Token(aliceSecret1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch token: %s", err)
|
||||
}
|
||||
assertEqualTokens(t, tokens, fetchedToken, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
|
||||
|
||||
t.Log("Try to mark Alice's token as being used after two days.")
|
||||
aliceToken1LastSeen := aliceToken1FirstSeen.Add(48 * time.Hour)
|
||||
err = tokens.MaybeUpdateLastSeen(aliceToken, aliceToken1LastSeen)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update last seen: %s", err)
|
||||
}
|
||||
|
||||
t.Log("The token should now be updated in-memory and in the DB.")
|
||||
assertEqualTimes(t, aliceToken.LastSeen, aliceToken1LastSeen, "Token.LastSeen mismatch")
|
||||
fetchedToken, err = tokens.Token(aliceSecret1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch token: %s", err)
|
||||
}
|
||||
assertEqualTokens(t, tokens, fetchedToken, aliceSecret1, alice, aliceDevice, aliceToken1LastSeen)
|
||||
|
||||
// Test a second token for Alice
|
||||
t.Log("Insert a second token for Alice.")
|
||||
aliceSecret2 := "mysecret2"
|
||||
aliceToken2FirstSeen := aliceToken1LastSeen.Add(time.Minute)
|
||||
aliceToken2, err := tokens.Insert(aliceSecret2, alice, aliceDevice, aliceToken2FirstSeen)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
|
||||
t.Log("The returned Token struct should have been populated correctly.")
|
||||
assertEqualTokens(t, tokens, aliceToken2, aliceSecret2, alice, aliceDevice, aliceToken2FirstSeen)
|
||||
}
|
||||
|
||||
func TestDeletingTokens(t *testing.T) {
|
||||
db, close := connectToDB(t)
|
||||
defer close()
|
||||
tokens := NewTokensTable(db, "my_secret")
|
||||
|
||||
t.Log("Insert a new token from Alice.")
|
||||
accessToken := "mytoken"
|
||||
token, err := tokens.Insert(accessToken, "@bob:builders.com", "device", time.Time{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
|
||||
t.Log("We should be able to fetch this token without error.")
|
||||
_, err = tokens.Token(accessToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch token: %s", err)
|
||||
}
|
||||
|
||||
t.Log("Delete the token")
|
||||
err = tokens.Delete(token.AccessTokenHash)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete token: %s", err)
|
||||
}
|
||||
|
||||
t.Log("We should no longer be able to fetch this token.")
|
||||
token, err = tokens.Token(accessToken)
|
||||
if token != nil || err == nil {
|
||||
t.Fatalf("Fetching token after deletion did not fail: got %s, %s", token, err)
|
||||
}
|
||||
}
|
||||
|
||||
func assertEqualTokens(t *testing.T, table *TokensTable, got *Token, accessToken, userID, deviceID string, lastSeen time.Time) {
|
||||
t.Helper()
|
||||
assertEqual(t, got.AccessToken, accessToken, "Token.AccessToken mismatch")
|
||||
assertEqual(t, got.AccessTokenHash, hashToken(accessToken), "Token.AccessTokenHashed mismatch")
|
||||
// We don't care what the encrypted token is here. The fact that we store encrypted values is an
|
||||
// implementation detail; the rest of the program doesn't care.
|
||||
assertEqual(t, got.UserID, userID, "Token.UserID mismatch")
|
||||
assertEqual(t, got.DeviceID, deviceID, "Token.DeviceID mismatch")
|
||||
assertEqualTimes(t, got.LastSeen, lastSeen, "Token.LastSeen mismatch")
|
||||
}
|
||||
|
||||
func assertEqual(t *testing.T, got, want, msg string) {
|
||||
t.Helper()
|
||||
if got != want {
|
||||
t.Fatalf("%s: got %s want %s", msg, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func assertEqualTimes(t *testing.T, got, want time.Time, msg string) {
|
||||
t.Helper()
|
||||
// Postgres stores timestamps with microsecond resolution, so we might lose some
|
||||
// precision by storing and fetching a time.Time in/from the DB. Resolution of
|
||||
// a second will suffice.
|
||||
if !got.Round(time.Second).Equal(want.Round(time.Second)) {
|
||||
t.Fatalf("%s: got %v want %v", msg, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// see devices_table_test.go for tests which join the tokens and devices tables.
|
@ -11,11 +11,12 @@ import (
|
||||
)
|
||||
|
||||
type ConnID struct {
|
||||
UserID string
|
||||
DeviceID string
|
||||
}
|
||||
|
||||
func (c *ConnID) String() string {
|
||||
return c.DeviceID
|
||||
return c.UserID + "|" + c.DeviceID
|
||||
}
|
||||
|
||||
type ConnHandler interface {
|
||||
@ -24,7 +25,6 @@ type ConnHandler interface {
|
||||
// status code to send back.
|
||||
OnIncomingRequest(ctx context.Context, cid ConnID, req *Request, isInitial bool) (*Response, error)
|
||||
OnUpdate(ctx context.Context, update caches.Update)
|
||||
UserID() string
|
||||
Destroy()
|
||||
Alive() bool
|
||||
}
|
||||
@ -33,7 +33,7 @@ type ConnHandler interface {
|
||||
// of the /sync request, including sending cached data in the event of retries. It does not handle
|
||||
// the contents of the data at all.
|
||||
type Conn struct {
|
||||
ConnID ConnID
|
||||
ConnID
|
||||
|
||||
handler ConnHandler
|
||||
|
||||
@ -65,10 +65,6 @@ func NewConn(connID ConnID, h ConnHandler) *Conn {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) UserID() string {
|
||||
return c.handler.UserID()
|
||||
}
|
||||
|
||||
func (c *Conn) Alive() bool {
|
||||
return c.handler.Alive()
|
||||
}
|
||||
@ -105,7 +101,7 @@ func (c *Conn) tryRequest(ctx context.Context, req *Request) (res *Response, err
|
||||
}
|
||||
ctx, task := internal.StartTask(ctx, taskType)
|
||||
defer task.End()
|
||||
internal.Logf(ctx, "connstate", "starting user=%v device=%v pos=%v", c.handler.UserID(), c.ConnID.DeviceID, req.pos)
|
||||
internal.Logf(ctx, "connstate", "starting user=%v device=%v pos=%v", c.UserID, c.ConnID.DeviceID, req.pos)
|
||||
return c.handler.OnIncomingRequest(ctx, c.ConnID, req, req.pos == 0)
|
||||
}
|
||||
|
||||
@ -164,7 +160,7 @@ func (c *Conn) OnIncomingRequest(ctx context.Context, req *Request) (resp *Respo
|
||||
c.serverResponses = c.serverResponses[delIndex+1:] // slice out the first delIndex+1 elements
|
||||
|
||||
defer func() {
|
||||
l := logger.Trace().Int("num_res_acks", delIndex+1).Bool("is_retransmit", isRetransmit).Bool("is_first", isFirstRequest).Bool("is_same", isSameRequest).Int64("pos", req.pos).Str("user", c.handler.UserID())
|
||||
l := logger.Trace().Int("num_res_acks", delIndex+1).Bool("is_retransmit", isRetransmit).Bool("is_first", isFirstRequest).Bool("is_same", isSameRequest).Int64("pos", req.pos).Str("user", c.UserID)
|
||||
if nextUnACKedResponse != nil {
|
||||
l.Int64("new_pos", nextUnACKedResponse.PosInt())
|
||||
}
|
||||
|
@ -71,7 +71,7 @@ func (m *ConnMap) CreateConn(cid ConnID, newConnHandler func() ConnHandler) (*Co
|
||||
conn = NewConn(cid, h)
|
||||
m.cache.Set(cid.String(), conn)
|
||||
m.connIDToConn[cid.String()] = conn
|
||||
m.userIDToConn[h.UserID()] = append(m.userIDToConn[h.UserID()], conn)
|
||||
m.userIDToConn[cid.UserID] = append(m.userIDToConn[cid.UserID], conn)
|
||||
return conn, true
|
||||
}
|
||||
|
||||
@ -94,20 +94,20 @@ func (m *ConnMap) closeConn(conn *Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
connID := conn.ConnID.String()
|
||||
logger.Trace().Str("conn", connID).Msg("closing connection")
|
||||
connKey := conn.ConnID.String()
|
||||
logger.Trace().Str("conn", connKey).Msg("closing connection")
|
||||
// remove conn from all the maps
|
||||
delete(m.connIDToConn, connID)
|
||||
delete(m.connIDToConn, connKey)
|
||||
h := conn.handler
|
||||
conns := m.userIDToConn[h.UserID()]
|
||||
conns := m.userIDToConn[conn.UserID]
|
||||
for i := 0; i < len(conns); i++ {
|
||||
if conns[i].ConnID.String() == connID {
|
||||
if conns[i].DeviceID == conn.DeviceID {
|
||||
// delete without preserving order
|
||||
conns[i] = conns[len(conns)-1]
|
||||
conns = conns[:len(conns)-1]
|
||||
}
|
||||
}
|
||||
m.userIDToConn[h.UserID()] = conns
|
||||
m.userIDToConn[conn.UserID] = conns
|
||||
// remove user cache listeners etc
|
||||
h.Destroy()
|
||||
}
|
||||
|
@ -54,6 +54,7 @@ func (s *connStateLive) liveUpdate(
|
||||
ctx context.Context, req *sync3.Request, ex extensions.Request, isInitial bool,
|
||||
response *sync3.Response,
|
||||
) {
|
||||
log := logger.With().Str("user", s.userID).Str("device", s.deviceID).Logger()
|
||||
// we need to ensure that we keep consuming from the updates channel, even if they want a response
|
||||
// immediately. If we have new list data we won't wait, but if we don't then we need to be able to
|
||||
// catch-up to the current head position, hence giving 100ms grace period for processing.
|
||||
@ -67,17 +68,17 @@ func (s *connStateLive) liveUpdate(
|
||||
timeWaited := time.Since(startTime)
|
||||
timeLeftToWait := timeToWait - timeWaited
|
||||
if timeLeftToWait < 0 {
|
||||
logger.Trace().Str("user", s.userID).Str("time_waited", timeWaited.String()).Msg("liveUpdate: timed out")
|
||||
log.Trace().Str("time_waited", timeWaited.String()).Msg("liveUpdate: timed out")
|
||||
return
|
||||
}
|
||||
logger.Trace().Str("user", s.userID).Str("dur", timeLeftToWait.String()).Msg("liveUpdate: no response data yet; blocking")
|
||||
log.Trace().Str("dur", timeLeftToWait.String()).Msg("liveUpdate: no response data yet; blocking")
|
||||
select {
|
||||
case <-ctx.Done(): // client has given up
|
||||
logger.Trace().Str("user", s.userID).Msg("liveUpdate: client gave up")
|
||||
log.Trace().Msg("liveUpdate: client gave up")
|
||||
internal.Logf(ctx, "liveUpdate", "context cancelled")
|
||||
return
|
||||
case <-time.After(timeLeftToWait): // we've timed out
|
||||
logger.Trace().Str("user", s.userID).Msg("liveUpdate: timed out")
|
||||
log.Trace().Msg("liveUpdate: timed out")
|
||||
internal.Logf(ctx, "liveUpdate", "timed out after %v", timeLeftToWait)
|
||||
return
|
||||
case update := <-s.updates:
|
||||
@ -111,7 +112,7 @@ func (s *connStateLive) liveUpdate(
|
||||
}
|
||||
}
|
||||
}
|
||||
logger.Trace().Str("user", s.userID).Int("subs", len(response.Rooms)).Msg("liveUpdate: returning")
|
||||
log.Trace().Msg("liveUpdate: returning")
|
||||
// TODO: op consolidation
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/sliding-sync/sync2"
|
||||
"sync"
|
||||
|
||||
"github.com/matrix-org/sliding-sync/pubsub"
|
||||
@ -14,7 +15,7 @@ type pendingInfo struct {
|
||||
type EnsurePoller struct {
|
||||
chanName string
|
||||
mu *sync.Mutex
|
||||
pendingPolls map[string]pendingInfo
|
||||
pendingPolls map[sync2.PollerID]pendingInfo
|
||||
notifier pubsub.Notifier
|
||||
}
|
||||
|
||||
@ -22,23 +23,22 @@ func NewEnsurePoller(notifier pubsub.Notifier) *EnsurePoller {
|
||||
return &EnsurePoller{
|
||||
chanName: pubsub.ChanV3,
|
||||
mu: &sync.Mutex{},
|
||||
pendingPolls: make(map[string]pendingInfo),
|
||||
pendingPolls: make(map[sync2.PollerID]pendingInfo),
|
||||
notifier: notifier,
|
||||
}
|
||||
}
|
||||
|
||||
// EnsurePolling blocks until the V2InitialSyncComplete response is received for this device. It is
|
||||
// the caller's responsibility to call OnInitialSyncComplete when new events arrive.
|
||||
func (p *EnsurePoller) EnsurePolling(userID, deviceID string) {
|
||||
key := userID + "|" + deviceID
|
||||
func (p *EnsurePoller) EnsurePolling(pid sync2.PollerID, tokenHash string) {
|
||||
p.mu.Lock()
|
||||
// do we need to wait?
|
||||
if p.pendingPolls[key].done {
|
||||
if p.pendingPolls[pid].done {
|
||||
p.mu.Unlock()
|
||||
return
|
||||
}
|
||||
// have we called EnsurePolling for this user/device before?
|
||||
ch := p.pendingPolls[key].ch
|
||||
ch := p.pendingPolls[pid].ch
|
||||
if ch != nil {
|
||||
p.mu.Unlock()
|
||||
// we already called EnsurePolling on this device, so just listen for the close
|
||||
@ -50,15 +50,16 @@ func (p *EnsurePoller) EnsurePolling(userID, deviceID string) {
|
||||
}
|
||||
// Make a channel to wait until we have done an initial sync
|
||||
ch = make(chan struct{})
|
||||
p.pendingPolls[key] = pendingInfo{
|
||||
p.pendingPolls[pid] = pendingInfo{
|
||||
done: false,
|
||||
ch: ch,
|
||||
}
|
||||
p.mu.Unlock()
|
||||
// ask the pollers to poll for this device
|
||||
p.notifier.Notify(p.chanName, &pubsub.V3EnsurePolling{
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
UserID: pid.UserID,
|
||||
DeviceID: pid.DeviceID,
|
||||
AccessTokenHash: tokenHash,
|
||||
})
|
||||
// if by some miracle the notify AND sync completes before we receive on ch then this is
|
||||
// still fine as recv on a closed channel will return immediately.
|
||||
@ -66,15 +67,15 @@ func (p *EnsurePoller) EnsurePolling(userID, deviceID string) {
|
||||
}
|
||||
|
||||
func (p *EnsurePoller) OnInitialSyncComplete(payload *pubsub.V2InitialSyncComplete) {
|
||||
key := payload.UserID + "|" + payload.DeviceID
|
||||
pid := sync2.PollerID{UserID: payload.UserID, DeviceID: payload.DeviceID}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
pending, ok := p.pendingPolls[key]
|
||||
pending, ok := p.pendingPolls[pid]
|
||||
// were we waiting for this initial sync to complete?
|
||||
if !ok {
|
||||
// This can happen when the v2 poller spontaneously starts polling even without us asking it to
|
||||
// e.g from the database
|
||||
p.pendingPolls[key] = pendingInfo{
|
||||
p.pendingPolls[pid] = pendingInfo{
|
||||
done: true,
|
||||
}
|
||||
return
|
||||
@ -88,7 +89,7 @@ func (p *EnsurePoller) OnInitialSyncComplete(payload *pubsub.V2InitialSyncComple
|
||||
ch := pending.ch
|
||||
pending.done = true
|
||||
pending.ch = nil
|
||||
p.pendingPolls[key] = pending
|
||||
p.pendingPolls[pid] = pending
|
||||
close(ch)
|
||||
}
|
||||
|
||||
|
@ -3,9 +3,12 @@ package handler
|
||||
import "C"
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/matrix-org/sliding-sync/sqlutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@ -194,6 +197,10 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
|
||||
}
|
||||
}
|
||||
}
|
||||
hlog.FromRequest(req).UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
c.Str("txn_id", requestBody.TxnID)
|
||||
return c
|
||||
})
|
||||
for listKey, l := range requestBody.Lists {
|
||||
if l.Ranges != nil && !l.Ranges.Valid() {
|
||||
return &internal.HandlerError{
|
||||
@ -215,8 +222,8 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
|
||||
return herr
|
||||
}
|
||||
requestBody.SetPos(cpos)
|
||||
internal.SetRequestContextUserID(req.Context(), conn.UserID())
|
||||
log := hlog.FromRequest(req).With().Str("user", conn.UserID()).Int64("pos", cpos).Logger()
|
||||
internal.SetRequestContextUserID(req.Context(), conn.UserID)
|
||||
log := hlog.FromRequest(req).With().Str("user", conn.UserID).Int64("pos", cpos).Logger()
|
||||
|
||||
var timeout int
|
||||
if req.URL.Query().Get("timeout") == "" {
|
||||
@ -276,25 +283,53 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
|
||||
// It also sets a v2 sync poll loop going if one didn't exist already for this user.
|
||||
// When this function returns, the connection is alive and active.
|
||||
func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Request, containsPos bool) (*sync3.Conn, error) {
|
||||
log := hlog.FromRequest(req)
|
||||
var conn *sync3.Conn
|
||||
|
||||
// Identify the device
|
||||
deviceID, accessToken, err := internal.HashedTokenFromRequest(req)
|
||||
// Extract an access token
|
||||
accessToken, err := internal.ExtractAccessToken(req)
|
||||
if err != nil || accessToken == "" {
|
||||
log.Warn().Err(err).Msg("failed to get device ID from request")
|
||||
hlog.FromRequest(req).Warn().Err(err).Msg("failed to get access token from request")
|
||||
return nil, &internal.HandlerError{
|
||||
StatusCode: 400,
|
||||
StatusCode: http.StatusUnauthorized,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Try to lookup a record of this token
|
||||
var token *sync2.Token
|
||||
token, err = h.V2Store.TokensTable.Token(accessToken)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
hlog.FromRequest(req).Info().Msg("Received connection from unknown access token, querying with homeserver")
|
||||
newToken, herr := h.identifyUnknownAccessToken(accessToken, hlog.FromRequest(req))
|
||||
if herr != nil {
|
||||
return nil, herr
|
||||
}
|
||||
token = newToken
|
||||
} else {
|
||||
hlog.FromRequest(req).Err(err).Msg("Failed to lookup access token")
|
||||
return nil, &internal.HandlerError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
}
|
||||
log := hlog.FromRequest(req).With().Str("user", token.UserID).Str("device", token.DeviceID).Logger()
|
||||
|
||||
// Record the fact that we've recieved a request from this token
|
||||
err = h.V2Store.TokensTable.MaybeUpdateLastSeen(token, time.Now())
|
||||
if err != nil {
|
||||
// Not fatal---log and continue.
|
||||
log.Warn().Err(err).Msg("Unable to update last seen timestamp")
|
||||
}
|
||||
|
||||
connID := sync3.ConnID{
|
||||
UserID: token.UserID,
|
||||
DeviceID: token.DeviceID,
|
||||
}
|
||||
// client thinks they have a connection
|
||||
if containsPos {
|
||||
// Lookup the connection
|
||||
conn = h.ConnMap.Conn(sync3.ConnID{
|
||||
DeviceID: deviceID,
|
||||
})
|
||||
conn = h.ConnMap.Conn(connID)
|
||||
if conn != nil {
|
||||
log.Trace().Str("conn", conn.ConnID.String()).Msg("reusing conn")
|
||||
return conn, nil
|
||||
@ -303,55 +338,23 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
|
||||
return nil, internal.ExpiredSessionError()
|
||||
}
|
||||
|
||||
// We're going to make a new connection
|
||||
// Ensure we have the v2 side of things hooked up
|
||||
v2device, err := h.V2Store.InsertDevice(deviceID, accessToken)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("device_id", deviceID).Msg("failed to insert v2 device")
|
||||
return nil, &internal.HandlerError{
|
||||
StatusCode: 500,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
if v2device.UserID == "" {
|
||||
v2device.UserID, _, err = h.V2.WhoAmI(accessToken)
|
||||
if err != nil {
|
||||
if err == sync2.HTTP401 {
|
||||
return nil, &internal.HandlerError{
|
||||
StatusCode: 401,
|
||||
Err: fmt.Errorf("/whoami returned HTTP 401"),
|
||||
}
|
||||
}
|
||||
log.Warn().Err(err).Str("device_id", deviceID).Msg("failed to get user ID from device ID")
|
||||
return nil, &internal.HandlerError{
|
||||
StatusCode: http.StatusBadGateway,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
if err = h.V2Store.UpdateUserIDForDevice(deviceID, v2device.UserID); err != nil {
|
||||
log.Warn().Err(err).Str("device_id", deviceID).Msg("failed to persist user ID -> device ID mapping")
|
||||
// non-fatal, we can still work without doing this
|
||||
}
|
||||
}
|
||||
|
||||
log.Trace().Str("user", v2device.UserID).Msg("checking poller exists and is running")
|
||||
h.V3Pub.EnsurePolling(v2device.UserID, v2device.DeviceID)
|
||||
log.Trace().Str("user", v2device.UserID).Msg("poller exists and is running")
|
||||
log.Trace().Msg("checking poller exists and is running")
|
||||
pid := sync2.PollerID{UserID: token.UserID, DeviceID: token.DeviceID}
|
||||
h.V3Pub.EnsurePolling(pid, token.AccessTokenHash)
|
||||
log.Trace().Msg("poller exists and is running")
|
||||
// this may take a while so if the client has given up (e.g timed out) by this point, just stop.
|
||||
// We'll be quicker next time as the poller will already exist.
|
||||
if req.Context().Err() != nil {
|
||||
log.Warn().Str("user_id", v2device.UserID).Msg(
|
||||
"client gave up, not creating connection",
|
||||
)
|
||||
log.Warn().Msg("client gave up, not creating connection")
|
||||
return nil, &internal.HandlerError{
|
||||
StatusCode: 400,
|
||||
Err: req.Context().Err(),
|
||||
}
|
||||
}
|
||||
|
||||
userCache, err := h.userCache(v2device.UserID)
|
||||
userCache, err := h.userCache(token.UserID)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("user_id", v2device.UserID).Msg("failed to load user cache")
|
||||
log.Warn().Err(err).Msg("failed to load user cache")
|
||||
return nil, &internal.HandlerError{
|
||||
StatusCode: 500,
|
||||
Err: err,
|
||||
@ -366,19 +369,59 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
|
||||
// because we *either* do the existing check *or* make a new conn. It's important for CreateConn
|
||||
// to check for an existing connection though, as it's possible for the client to call /sync
|
||||
// twice for a new connection.
|
||||
conn, created := h.ConnMap.CreateConn(sync3.ConnID{
|
||||
DeviceID: deviceID,
|
||||
}, func() sync3.ConnHandler {
|
||||
return NewConnState(v2device.UserID, v2device.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.histVec, h.maxPendingEventUpdates)
|
||||
conn, created := h.ConnMap.CreateConn(connID, func() sync3.ConnHandler {
|
||||
return NewConnState(token.UserID, token.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.histVec, h.maxPendingEventUpdates)
|
||||
})
|
||||
if created {
|
||||
log.Info().Str("user", v2device.UserID).Str("conn_id", conn.ConnID.String()).Msg("created new connection")
|
||||
log.Info().Msg("created new connection")
|
||||
} else {
|
||||
log.Info().Str("user", v2device.UserID).Str("conn_id", conn.ConnID.String()).Msg("using existing connection")
|
||||
log.Info().Msg("using existing connection")
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (h *SyncLiveHandler) identifyUnknownAccessToken(accessToken string, logger *zerolog.Logger) (*sync2.Token, *internal.HandlerError) {
|
||||
// We don't recognise the given accessToken. Ask the homeserver who owns it.
|
||||
userID, deviceID, err := h.V2.WhoAmI(accessToken)
|
||||
if err != nil {
|
||||
if err == sync2.HTTP401 {
|
||||
return nil, &internal.HandlerError{
|
||||
StatusCode: 401,
|
||||
Err: fmt.Errorf("/whoami returned HTTP 401"),
|
||||
}
|
||||
}
|
||||
log.Warn().Err(err).Msg("failed to get user ID from device ID")
|
||||
return nil, &internal.HandlerError{
|
||||
StatusCode: http.StatusBadGateway,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
var token *sync2.Token
|
||||
err = sqlutil.WithTransaction(h.V2Store.DB, func(txn *sqlx.Tx) error {
|
||||
// Create a brand-new row for this token.
|
||||
token, err = h.V2Store.TokensTable.Insert(accessToken, userID, deviceID, time.Now())
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Str("user", userID).Str("device", deviceID).Msg("failed to insert v2 token")
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure we have a device row for this token.
|
||||
err = h.V2Store.DevicesTable.InsertDevice(userID, deviceID)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("user", userID).Str("device", deviceID).Msg("failed to insert v2 device")
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, &internal.HandlerError{StatusCode: 500, Err: err}
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (h *SyncLiveHandler) CacheForUser(userID string) *caches.UserCache {
|
||||
c, ok := h.userCaches.Load(userID)
|
||||
if ok {
|
||||
@ -551,6 +594,7 @@ func (h *SyncLiveHandler) OnDeviceData(p *pubsub.V2DeviceData) {
|
||||
ctx, task := internal.StartTask(context.Background(), "OnDeviceData")
|
||||
defer task.End()
|
||||
conn := h.ConnMap.Conn(sync3.ConnID{
|
||||
UserID: p.UserID,
|
||||
DeviceID: p.DeviceID,
|
||||
})
|
||||
if conn == nil {
|
||||
@ -563,6 +607,7 @@ func (h *SyncLiveHandler) OnDeviceMessages(p *pubsub.V2DeviceMessages) {
|
||||
ctx, task := internal.StartTask(context.Background(), "OnDeviceMessages")
|
||||
defer task.End()
|
||||
conn := h.ConnMap.Conn(sync3.ConnID{
|
||||
UserID: p.UserID,
|
||||
DeviceID: p.DeviceID,
|
||||
})
|
||||
if conn == nil {
|
||||
@ -659,6 +704,7 @@ func (h *SyncLiveHandler) OnAccountData(p *pubsub.V2AccountData) {
|
||||
|
||||
func (h *SyncLiveHandler) OnExpiredToken(p *pubsub.V2ExpiredToken) {
|
||||
h.ConnMap.CloseConn(sync3.ConnID{
|
||||
UserID: p.UserID,
|
||||
DeviceID: p.DeviceID,
|
||||
})
|
||||
}
|
||||
|
@ -639,17 +639,10 @@ func TestExpiredAccessToken(t *testing.T) {
|
||||
})
|
||||
// now expire the token
|
||||
v2.invalidateToken(aliceToken)
|
||||
// now do another request, this should 400 as it expires the session
|
||||
// now do another request, this should 401
|
||||
req := sync3.Request{}
|
||||
req.SetTimeoutMSecs(1)
|
||||
_, body, statusCode := v3.doV3Request(t, context.Background(), aliceToken, res.Pos, req)
|
||||
if statusCode != 400 {
|
||||
t.Fatalf("got %d want 400 : %v", statusCode, string(body))
|
||||
}
|
||||
// do a fresh request, this should 401
|
||||
req = sync3.Request{}
|
||||
req.SetTimeoutMSecs(1)
|
||||
_, body, statusCode = v3.doV3Request(t, context.Background(), aliceToken, "", req)
|
||||
if statusCode != 401 {
|
||||
t.Fatalf("got %d want 401 : %v", statusCode, string(body))
|
||||
}
|
||||
|
@ -224,7 +224,7 @@ func TestExtensionToDevice(t *testing.T) {
|
||||
},
|
||||
})
|
||||
|
||||
// 1: check that a fresh sync returns to-device messages
|
||||
t.Log("1: check that a fresh sync returns to-device messages")
|
||||
res := v3.mustDoV3Request(t, aliceToken, sync3.Request{
|
||||
Lists: map[string]sync3.RequestList{"a": {
|
||||
Ranges: sync3.SliceRanges{
|
||||
@ -239,7 +239,7 @@ func TestExtensionToDevice(t *testing.T) {
|
||||
})
|
||||
m.MatchResponse(t, res, m.MatchList("a", m.MatchV3Count(0)), m.MatchToDeviceMessages(toDeviceMsgs))
|
||||
|
||||
// 2: repeating the fresh sync request returns the same messages (not deleted)
|
||||
t.Log("2: repeating the fresh sync request returns the same messages (not deleted)")
|
||||
res = v3.mustDoV3Request(t, aliceToken, sync3.Request{
|
||||
Lists: map[string]sync3.RequestList{"a": {
|
||||
Ranges: sync3.SliceRanges{
|
||||
@ -254,7 +254,7 @@ func TestExtensionToDevice(t *testing.T) {
|
||||
})
|
||||
m.MatchResponse(t, res, m.MatchList("a", m.MatchV3Count(0)), m.MatchToDeviceMessages(toDeviceMsgs))
|
||||
|
||||
// 3: update the since token -> no new messages
|
||||
t.Log("3: update the since token -> no new messages")
|
||||
res = v3.mustDoV3Request(t, aliceToken, sync3.Request{
|
||||
Lists: map[string]sync3.RequestList{"a": {
|
||||
Ranges: sync3.SliceRanges{
|
||||
@ -270,7 +270,7 @@ func TestExtensionToDevice(t *testing.T) {
|
||||
})
|
||||
m.MatchResponse(t, res, m.MatchList("a", m.MatchV3Count(0)), m.MatchToDeviceMessages([]json.RawMessage{}))
|
||||
|
||||
// 4: inject live to-device messages -> receive them only.
|
||||
t.Log("4: inject live to-device messages -> receive them only.")
|
||||
sinceBeforeMsgs := res.Extensions.ToDevice.NextBatch
|
||||
newToDeviceMsgs := []json.RawMessage{
|
||||
json.RawMessage(`{"sender":"alice","type":"something","content":{"foo":"5"}}`),
|
||||
@ -296,7 +296,7 @@ func TestExtensionToDevice(t *testing.T) {
|
||||
})
|
||||
m.MatchResponse(t, res, m.MatchList("a", m.MatchV3Count(0)), m.MatchToDeviceMessages(newToDeviceMsgs))
|
||||
|
||||
// 5: repeating the previous sync request returns the same live to-device messages (retransmit)
|
||||
t.Log("5: repeating the previous sync request returns the same live to-device messages (retransmit)")
|
||||
res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{
|
||||
Lists: map[string]sync3.RequestList{"a": {
|
||||
Ranges: sync3.SliceRanges{
|
||||
@ -327,7 +327,7 @@ func TestExtensionToDevice(t *testing.T) {
|
||||
// this response contains nothing
|
||||
m.MatchResponse(t, res, m.MatchList("a", m.MatchV3Count(0)), m.MatchToDeviceMessages([]json.RawMessage{}))
|
||||
|
||||
// 6: using an old since token does not return to-device messages anymore as they were deleted.
|
||||
t.Log("6: using an old since token does not return to-device messages anymore as they were deleted.")
|
||||
res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{
|
||||
Lists: map[string]sync3.RequestList{"a": {
|
||||
Ranges: sync3.SliceRanges{
|
||||
@ -343,7 +343,7 @@ func TestExtensionToDevice(t *testing.T) {
|
||||
|
||||
m.MatchResponse(t, res, m.MatchList("a", m.MatchV3Count(0)), m.MatchToDeviceMessages([]json.RawMessage{}))
|
||||
|
||||
// live stream and block, then send a to-device msg which should go through immediately
|
||||
t.Log("7: live stream and block, then send a to-device msg which should go through immediately")
|
||||
start := time.Now()
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
@ -25,7 +25,7 @@ func TestSecondPollerFiltersToDevice(t *testing.T) {
|
||||
defer v2.close()
|
||||
defer v3.close()
|
||||
deviceAToken := "DEVICE_A_TOKEN"
|
||||
v2.addAccount(alice, deviceAToken)
|
||||
v2.addAccountWithDeviceID(alice, "A", deviceAToken)
|
||||
v2.queueResponse(deviceAToken, sync2.SyncResponse{
|
||||
Rooms: sync2.SyncRoomsResponse{
|
||||
Join: v2JoinTimeline(roomEvents{
|
||||
@ -39,7 +39,7 @@ func TestSecondPollerFiltersToDevice(t *testing.T) {
|
||||
|
||||
// now sync with device B, and check we send the filter up
|
||||
deviceBToken := "DEVICE_B_TOKEN"
|
||||
v2.addAccount(alice, deviceBToken)
|
||||
v2.addAccountWithDeviceID(alice, "B", deviceBToken)
|
||||
seenInitialRequest := false
|
||||
v2.CheckRequest = func(userID, token string, req *http.Request) {
|
||||
if userID != alice || token != deviceBToken {
|
||||
|
@ -49,6 +49,7 @@ type testV2Server struct {
|
||||
CheckRequest func(userID, token string, req *http.Request)
|
||||
mu *sync.Mutex
|
||||
tokenToUser map[string]string
|
||||
tokenToDevice map[string]string
|
||||
queues map[string]chan sync2.SyncResponse
|
||||
waiting map[string]*sync.Cond // broadcasts when the server is about to read a blocking input
|
||||
srv *httptest.Server
|
||||
@ -56,10 +57,21 @@ type testV2Server struct {
|
||||
timeToWaitForV2Response time.Duration
|
||||
}
|
||||
|
||||
// Most tests only use a single device per user. Give them this helper so they don't
|
||||
// have to care about providing a device name.
|
||||
func (s *testV2Server) addAccount(userID, token string) {
|
||||
// To keep our future selves sane while debugging, use a device name that
|
||||
// includes the mxid localpart.
|
||||
atLocalPart, _, _ := strings.Cut(userID, ":")
|
||||
s.addAccountWithDeviceID(userID, atLocalPart[1:]+"_device", token)
|
||||
}
|
||||
|
||||
// Tests that use multiple devices for the same user need to be more explicit.
|
||||
func (s *testV2Server) addAccountWithDeviceID(userID, deviceID, token string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.tokenToUser[token] = userID
|
||||
s.tokenToDevice[token] = deviceID
|
||||
s.queues[token] = make(chan sync2.SyncResponse, 100)
|
||||
s.waiting[token] = &sync.Cond{
|
||||
L: &sync.Mutex{},
|
||||
@ -77,6 +89,7 @@ func (s *testV2Server) invalidateToken(token string) {
|
||||
wg.Done()
|
||||
}
|
||||
delete(s.tokenToUser, token)
|
||||
delete(s.tokenToDevice, token)
|
||||
s.mu.Unlock()
|
||||
|
||||
// kick over the connection so the next request 401s and wait till we get said request
|
||||
@ -97,6 +110,12 @@ func (s *testV2Server) userID(token string) string {
|
||||
return s.tokenToUser[token]
|
||||
}
|
||||
|
||||
func (s *testV2Server) deviceID(token string) string {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.tokenToDevice[token]
|
||||
}
|
||||
|
||||
func (s *testV2Server) queueResponse(userIDOrToken string, resp sync2.SyncResponse) {
|
||||
s.mu.Lock()
|
||||
ch := s.queues[userIDOrToken]
|
||||
@ -185,6 +204,7 @@ func runTestV2Server(t testutils.TestBenchInterface) *testV2Server {
|
||||
t.Helper()
|
||||
server := &testV2Server{
|
||||
tokenToUser: make(map[string]string),
|
||||
tokenToDevice: make(map[string]string),
|
||||
queues: make(map[string]chan sync2.SyncResponse),
|
||||
waiting: make(map[string]*sync.Cond),
|
||||
invalidations: make(map[string]func()),
|
||||
@ -195,7 +215,8 @@ func runTestV2Server(t testutils.TestBenchInterface) *testV2Server {
|
||||
r.HandleFunc("/_matrix/client/r0/account/whoami", func(w http.ResponseWriter, req *http.Request) {
|
||||
token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
|
||||
userID := server.userID(token)
|
||||
if userID == "" {
|
||||
deviceID := server.deviceID(token)
|
||||
if userID == "" || deviceID == "" {
|
||||
w.WriteHeader(401)
|
||||
server.mu.Lock()
|
||||
fn := server.invalidations[token]
|
||||
@ -206,7 +227,7 @@ func runTestV2Server(t testutils.TestBenchInterface) *testV2Server {
|
||||
return
|
||||
}
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(fmt.Sprintf(`{"user_id":"%s"}`, userID)))
|
||||
w.Write([]byte(fmt.Sprintf(`{"user_id":"%s","device_id":"%s"}`, userID, deviceID)))
|
||||
})
|
||||
r.HandleFunc("/_matrix/client/r0/sync", func(w http.ResponseWriter, req *http.Request) {
|
||||
token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
|
||||
|
Loading…
x
Reference in New Issue
Block a user