Merge pull request #90 from matrix-org/dmr/tokens-table

OIDC: track tokens separately to devices
This commit is contained in:
David Robertson 2023-05-02 18:14:19 +01:00 committed by GitHub
commit 0adaf75cfc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 883 additions and 454 deletions

View File

@ -4,7 +4,6 @@ on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
permissions:
packages: read

View File

@ -1,24 +1,16 @@
package internal
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"strings"
)
func HashedTokenFromRequest(req *http.Request) (hashAccessToken string, accessToken string, err error) {
// return a hash of the access token
func ExtractAccessToken(req *http.Request) (accessToken string, err error) {
ah := req.Header.Get("Authorization")
if ah == "" {
return "", "", fmt.Errorf("missing Authorization header")
return "", fmt.Errorf("missing Authorization header")
}
accessToken = strings.TrimPrefix(ah, "Bearer ")
// important that this is a cryptographically secure hash function to prevent
// preimage attacks where Eve can use a fake token to hash to an existing device ID
// on the server.
hash := sha256.New()
hash.Write([]byte(accessToken))
return hex.EncodeToString(hash.Sum(nil)), accessToken, nil
return accessToken, nil
}

View File

@ -1,24 +0,0 @@
package internal
import (
"net/http"
"testing"
)
func TestDeviceIDFromRequest(t *testing.T) {
req, _ := http.NewRequest("POST", "http://localhost:8008", nil)
req.Header.Set("Authorization", "Bearer A")
deviceIDA, _, err := HashedTokenFromRequest(req)
if err != nil {
t.Fatalf("HashedTokenFromRequest returned %s", err)
}
req.Header.Set("Authorization", "Bearer B")
deviceIDB, _, err := HashedTokenFromRequest(req)
if err != nil {
t.Fatalf("HashedTokenFromRequest returned %s", err)
}
if deviceIDA == deviceIDB {
t.Fatalf("HashedTokenFromRequest: hashed to same device ID: %s", deviceIDA)
}
}

View File

@ -78,6 +78,7 @@ type V2InitialSyncComplete struct {
func (*V2InitialSyncComplete) Type() string { return "V2InitialSyncComplete" }
type V2DeviceData struct {
UserID string
DeviceID string
Pos int64
}
@ -106,6 +107,7 @@ type V2DeviceMessages struct {
func (*V2DeviceMessages) Type() string { return "V2DeviceMessages" }
type V2ExpiredToken struct {
UserID string
DeviceID string
}

View File

@ -8,8 +8,11 @@ type V3Listener interface {
}
type V3EnsurePolling struct {
UserID string
DeviceID string
// TODO: we only really need to provide the access token hash here.
// Passing through a user means we can log something sensible though.
UserID string
DeviceID string
AccessTokenHash string
}
func (*V3EnsurePolling) Type() string { return "V3EnsurePolling" }

47
sync2/devices_table.go Normal file
View File

@ -0,0 +1,47 @@
package sync2
import (
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
)
// Device is one row of the syncv3_sync2_devices table: the sync v2 `since`
// position tracked for a single (user, device) pair.
type Device struct {
UserID string `db:"user_id"`
DeviceID string `db:"device_id"`
// Since is the most recent sync v2 since token seen for this device.
// Empty string means the device has not yet completed an initial sync.
Since string `db:"since"`
}
// DevicesTable remembers syncv2 since positions per-device.
// It wraps the syncv3_sync2_devices table; rows are created via InsertDevice
// and their since column is advanced via UpdateDeviceSince.
type DevicesTable struct {
db *sqlx.DB
}
// NewDevicesTable ensures the syncv3_sync2_devices table exists and returns a
// DevicesTable wrapping the given database handle. Panics (via MustExec) if
// the DDL cannot be executed.
func NewDevicesTable(db *sqlx.DB) *DevicesTable {
	// Note: the table-level PRIMARY KEY constraint is listed after all column
	// definitions. (Postgres tolerates a constraint between columns, but that
	// is non-standard SQL and harder to read.)
	db.MustExec(`
	CREATE TABLE IF NOT EXISTS syncv3_sync2_devices (
		user_id TEXT NOT NULL,
		device_id TEXT NOT NULL,
		since TEXT NOT NULL,
		PRIMARY KEY (user_id, device_id)
	);`)
	return &DevicesTable{
		db: db,
	}
}
// InsertDevice creates a new devices row with a blank since token if no such row
// exists. Otherwise, it does nothing.
func (t *DevicesTable) InsertDevice(userID, deviceID string) error {
	const insertQuery = ` INSERT INTO syncv3_sync2_devices(user_id, device_id, since) VALUES($1,$2,$3)
	ON CONFLICT (user_id, device_id) DO NOTHING`
	// A fresh device always starts with an empty since token; conflicts mean
	// the row already exists and must not be clobbered.
	_, err := t.db.Exec(insertQuery, userID, deviceID, "")
	return err
}
// UpdateDeviceSince persists the latest sync v2 since token for the given
// (user, device) pair.
func (t *DevicesTable) UpdateDeviceSince(userID, deviceID, since string) error {
	const updateQuery = `UPDATE syncv3_sync2_devices SET since = $1 WHERE user_id = $2 AND device_id = $3`
	_, err := t.db.Exec(updateQuery, since, userID, deviceID)
	return err
}

168
sync2/devices_table_test.go Normal file
View File

@ -0,0 +1,168 @@
package sync2
import (
"github.com/jmoiron/sqlx"
"os"
"sort"
"testing"
"time"
"github.com/matrix-org/sliding-sync/testutils"
)
var postgresConnectionString = "user=xxxxx dbname=syncv3_test sslmode=disable"
// TestMain prepares the test database connection string before running the
// package's tests, and propagates the test exit code to the process.
func TestMain(m *testing.M) {
	postgresConnectionString = testutils.PrepareDBConnectionString()
	os.Exit(m.Run())
}
// connectToDB opens a connection to the test postgres database, failing the
// test on error. The returned cleanup function closes the connection and
// should be deferred by the caller.
func connectToDB(t *testing.T) (*sqlx.DB, func()) {
	db, err := sqlx.Open("postgres", postgresConnectionString)
	if err != nil {
		t.Fatalf("failed to open SQL db: %s", err)
	}
	cleanup := func() {
		db.Close()
	}
	return db, cleanup
}
// Note that we currently only ever read from (devices JOIN tokens), so there is some
// overlap with tokens_table_test.go here.
func TestDevicesTableSinceColumn(t *testing.T) {
	db, close := connectToDB(t)
	defer close()
	tokens := NewTokensTable(db, "my_secret")
	devices := NewDevicesTable(db)

	alice := "@alice:localhost"
	aliceDevice := "alice_phone"
	aliceSecret1 := "mysecret1"
	aliceSecret2 := "mysecret2"

	t.Log("Insert two tokens for Alice.")
	aliceToken, err := tokens.Insert(aliceSecret1, alice, aliceDevice, time.Now())
	if err != nil {
		t.Fatalf("Failed to Insert token: %s", err)
	}
	aliceToken2, err := tokens.Insert(aliceSecret2, alice, aliceDevice, time.Now())
	if err != nil {
		t.Fatalf("Failed to Insert token: %s", err)
	}

	t.Log("Add a devices row for Alice")
	err = devices.InsertDevice(alice, aliceDevice)
	// BUGFIX: this error was previously dropped (immediately overwritten by the
	// next assignment), which would mask a failed test setup.
	if err != nil {
		t.Fatalf("InsertDevice returned error: %s", err)
	}

	t.Log("Pretend we're about to start a poller. Fetch Alice's token along with the since value tracked by the devices table.")
	accessToken, since, err := tokens.GetTokenAndSince(alice, aliceDevice, aliceToken.AccessTokenHash)
	if err != nil {
		t.Fatalf("Failed to GetTokenAndSince: %s", err)
	}
	t.Log("The since token should be empty.")
	assertEqual(t, accessToken, aliceToken.AccessToken, "Token.AccessToken mismatch")
	assertEqual(t, since, "", "Device.Since mismatch")

	t.Log("Update the since column.")
	sinceValue := "s-1-2-3-4"
	err = devices.UpdateDeviceSince(alice, aliceDevice, sinceValue)
	if err != nil {
		t.Fatalf("Failed to update since column: %s", err)
	}

	t.Log("We should see the new since value when the poller refetches alice's token")
	_, since, err = tokens.GetTokenAndSince(alice, aliceDevice, aliceToken.AccessTokenHash)
	if err != nil {
		t.Fatalf("Failed to GetTokenAndSince: %s", err)
	}
	assertEqual(t, since, sinceValue, "Device.Since mismatch")

	t.Log("We should also see the new since value when the poller fetches alice's second token")
	_, since, err = tokens.GetTokenAndSince(alice, aliceDevice, aliceToken2.AccessTokenHash)
	if err != nil {
		t.Fatalf("Failed to GetTokenAndSince: %s", err)
	}
	assertEqual(t, since, sinceValue, "Device.Since mismatch")
}
// TestTokenForEachDevice checks that TokenForEachDevice returns exactly one
// token per known device (the most recently seen token), together with that
// device's since value, and omits devices that have no tokens at all.
func TestTokenForEachDevice(t *testing.T) {
	db, close := connectToDB(t)
	defer close()

	// HACK: discard rows inserted by other tests. We don't normally need to do this,
	// but this is testing a query that scans the entire devices table.
	db.MustExec("TRUNCATE syncv3_sync2_devices, syncv3_sync2_tokens;")
	tokens := NewTokensTable(db, "my_secret")
	devices := NewDevicesTable(db)

	alice := "alice"
	aliceDevice := "alice_phone"
	bob := "bob"
	bobDevice := "bob_laptop"
	chris := "chris"
	chrisDevice := "chris_desktop"

	t.Log("Add a device for Alice, Bob and Chris.")
	err := devices.InsertDevice(alice, aliceDevice)
	if err != nil {
		t.Fatalf("InsertDevice returned error: %s", err)
	}
	err = devices.InsertDevice(bob, bobDevice)
	if err != nil {
		t.Fatalf("InsertDevice returned error: %s", err)
	}
	err = devices.InsertDevice(chris, chrisDevice)
	if err != nil {
		t.Fatalf("InsertDevice returned error: %s", err)
	}

	t.Log("Mark Alice's device with a since token.")
	sinceValue := "s-1-2-3-4"
	// BUGFIX: this error was previously ignored entirely.
	if err = devices.UpdateDeviceSince(alice, aliceDevice, sinceValue); err != nil {
		t.Fatalf("UpdateDeviceSince returned error: %s", err)
	}

	t.Log("Insert 2 tokens for Alice, one for Bob and none for Chris.")
	aliceLastSeen1 := time.Now()
	_, err = tokens.Insert("alice_secret", alice, aliceDevice, aliceLastSeen1)
	if err != nil {
		t.Fatalf("Failed to Insert token: %s", err)
	}
	aliceLastSeen2 := aliceLastSeen1.Add(1 * time.Minute)
	aliceToken2, err := tokens.Insert("alice_secret2", alice, aliceDevice, aliceLastSeen2)
	if err != nil {
		t.Fatalf("Failed to Insert token: %s", err)
	}
	bobToken, err := tokens.Insert("bob_secret", bob, bobDevice, time.Time{})
	if err != nil {
		t.Fatalf("Failed to Insert token: %s", err)
	}

	t.Log("Fetch a token for every device")
	gotTokens, err := tokens.TokenForEachDevice()
	if err != nil {
		t.Fatalf("Failed TokenForEachDevice: %s", err)
	}

	expectAlice := TokenForPoller{Token: aliceToken2, Since: sinceValue}
	expectBob := TokenForPoller{Token: bobToken, Since: ""}
	wantTokens := []*TokenForPoller{&expectAlice, &expectBob}

	if len(gotTokens) != len(wantTokens) {
		t.Fatalf("AllDevices: got %d tokens, want %d", len(gotTokens), len(wantTokens))
	}

	// Sort by (UserID, DeviceID) to get a deterministic order for comparison.
	// BUGFIX: the previous comparator fell through to the DeviceID comparison
	// even when UserID[i] > UserID[j], which is not a valid strict ordering and
	// could yield a wrong sort order.
	sort.Slice(gotTokens, func(i, j int) bool {
		if gotTokens[i].UserID != gotTokens[j].UserID {
			return gotTokens[i].UserID < gotTokens[j].UserID
		}
		return gotTokens[i].DeviceID < gotTokens[j].DeviceID
	})

	for i := range gotTokens {
		assertEqual(t, gotTokens[i].Since, wantTokens[i].Since, "Device.Since mismatch")
		assertEqual(t, gotTokens[i].UserID, wantTokens[i].UserID, "Token.UserID mismatch")
		assertEqual(t, gotTokens[i].DeviceID, wantTokens[i].DeviceID, "Token.DeviceID mismatch")
		assertEqual(t, gotTokens[i].AccessToken, wantTokens[i].AccessToken, "Token.AccessToken mismatch")
	}
}

View File

@ -95,9 +95,9 @@ func (h *Handler) Teardown() {
}
func (h *Handler) StartV2Pollers() {
devices, err := h.v2Store.AllDevices()
tokens, err := h.v2Store.TokensTable.TokenForEachDevice()
if err != nil {
logger.Err(err).Msg("StartV2Pollers: failed to query devices")
logger.Err(err).Msg("StartV2Pollers: failed to query tokens")
sentry.CaptureException(err)
return
}
@ -106,30 +106,34 @@ func (h *Handler) StartV2Pollers() {
// Too low and this will take ages for the v2 pollers to startup.
numWorkers := 16
numFails := 0
ch := make(chan sync2.Device, len(devices))
for _, d := range devices {
ch := make(chan sync2.TokenForPoller, len(tokens))
for _, t := range tokens {
// if we fail to decrypt the access token, skip it.
if d.AccessToken == "" {
if t.AccessToken == "" {
numFails++
continue
}
ch <- d
ch <- t
}
close(ch)
logger.Info().Int("num_devices", len(devices)).Int("num_fail_decrypt", numFails).Msg("StartV2Pollers")
logger.Info().Int("num_devices", len(tokens)).Int("num_fail_decrypt", numFails).Msg("StartV2Pollers")
var wg sync.WaitGroup
wg.Add(numWorkers)
for i := 0; i < numWorkers; i++ {
go func() {
defer wg.Done()
for d := range ch {
for t := range ch {
pid := sync2.PollerID{
UserID: t.UserID,
DeviceID: t.DeviceID,
}
h.pMap.EnsurePolling(
d.AccessToken, d.UserID, d.DeviceID, d.Since, true,
logger.With().Str("user_id", d.UserID).Logger(),
pid, t.AccessToken, t.Since, true,
logger.With().Str("user_id", t.UserID).Str("device_id", t.DeviceID).Logger(),
)
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2InitialSyncComplete{
UserID: d.UserID,
DeviceID: d.DeviceID,
UserID: t.UserID,
DeviceID: t.DeviceID,
})
}
}()
@ -150,12 +154,11 @@ func (h *Handler) OnTerminated(userID, deviceID string) {
h.updateMetrics()
}
func (h *Handler) OnExpiredToken(userID, deviceID string) {
h.v2Store.RemoveDevice(deviceID)
h.Store.ToDeviceTable.DeleteAllMessagesForDevice(deviceID)
h.Store.DeviceDataTable.DeleteDevice(userID, deviceID)
// also notify v3 side so it can remove the connection from ConnMap
func (h *Handler) OnExpiredToken(accessTokenHash, userID, deviceID string) {
h.v2Store.TokensTable.Delete(accessTokenHash)
// Notify v3 side so it can remove the connection from ConnMap
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2ExpiredToken{
UserID: userID,
DeviceID: deviceID,
})
}
@ -171,10 +174,10 @@ func (h *Handler) addPrometheusMetrics() {
}
// Emits nothing as no downstream components need it.
func (h *Handler) UpdateDeviceSince(deviceID, since string) {
err := h.v2Store.UpdateDeviceSince(deviceID, since)
func (h *Handler) UpdateDeviceSince(userID, deviceID, since string) {
err := h.v2Store.DevicesTable.UpdateDeviceSince(userID, deviceID, since)
if err != nil {
logger.Err(err).Str("device", deviceID).Str("since", since).Msg("V2: failed to persist since token")
logger.Err(err).Str("user", userID).Str("device", deviceID).Str("since", since).Msg("V2: failed to persist since token")
sentry.CaptureException(err)
}
}
@ -197,6 +200,7 @@ func (h *Handler) OnE2EEData(userID, deviceID string, otkCounts map[string]int,
return
}
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2DeviceData{
UserID: userID,
DeviceID: deviceID,
Pos: nextPos,
})
@ -387,7 +391,7 @@ func (h *Handler) EnsurePolling(p *pubsub.V3EnsurePolling) {
defer func() {
logger.Info().Str("user", p.UserID).Msg("EnsurePolling: request finished")
}()
dev, err := h.v2Store.Device(p.DeviceID)
accessToken, since, err := h.v2Store.TokensTable.GetTokenAndSince(p.UserID, p.DeviceID, p.AccessTokenHash)
if err != nil {
logger.Err(err).Str("user", p.UserID).Str("device", p.DeviceID).Msg("V3Sub: EnsurePolling unknown device")
sentry.CaptureException(err)
@ -396,9 +400,13 @@ func (h *Handler) EnsurePolling(p *pubsub.V3EnsurePolling) {
// don't block us from consuming more pubsub messages just because someone wants to sync
go func() {
// blocks until an initial sync is done
pid := sync2.PollerID{
UserID: p.UserID,
DeviceID: p.DeviceID,
}
h.pMap.EnsurePolling(
dev.AccessToken, dev.UserID, dev.DeviceID, dev.Since, false,
logger.With().Str("user_id", dev.UserID).Logger(),
pid, accessToken, since, false,
logger.With().Str("user_id", p.UserID).Str("device_id", p.DeviceID).Logger(),
)
h.updateMetrics()
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2InitialSyncComplete{

View File

@ -14,13 +14,18 @@ import (
"github.com/tidwall/gjson"
)
type PollerID struct {
UserID string
DeviceID string
}
// alias time.Sleep so tests can monkey patch it out
var timeSleep = time.Sleep
// V2DataReceiver is the receiver for all the v2 sync data the poller gets
type V2DataReceiver interface {
// Update the since token for this device. Called AFTER all other data in this sync response has been processed.
UpdateDeviceSince(deviceID, since string)
UpdateDeviceSince(userID, deviceID, since string)
// Accumulate data for this room. This means the timeline section of the v2 response.
Accumulate(deviceID, roomID, prevBatch string, timeline []json.RawMessage) // latest pos with event nids of timeline entries
// Initialise the room, if it hasn't been already. This means the state section of the v2 response.
@ -45,7 +50,7 @@ type V2DataReceiver interface {
// Sent when the poll loop terminates
OnTerminated(userID, deviceID string)
// Sent when the token gets a 401 response
OnExpiredToken(userID, deviceID string)
OnExpiredToken(accessTokenHash, userID, deviceID string)
}
// PollerMap is a map of device ID to Poller
@ -53,7 +58,7 @@ type PollerMap struct {
v2Client Client
callbacks V2DataReceiver
pollerMu *sync.Mutex
Pollers map[string]*poller // device_id -> poller
Pollers map[PollerID]*poller
executor chan func()
executorRunning bool
processHistogramVec *prometheus.HistogramVec
@ -88,7 +93,7 @@ func NewPollerMap(v2Client Client, enablePrometheus bool) *PollerMap {
pm := &PollerMap{
v2Client: v2Client,
pollerMu: &sync.Mutex{},
Pollers: make(map[string]*poller),
Pollers: make(map[PollerID]*poller),
executor: make(chan func(), 0),
}
if enablePrometheus {
@ -151,15 +156,18 @@ func (h *PollerMap) NumPollers() (count int) {
// Note that we will immediately return if there is a poller for the same user but a different device.
// We do this to allow for logins on clients to be snappy fast, even though they won't yet have the
// to-device msgs to decrypt E2EE roms.
func (h *PollerMap) EnsurePolling(accessToken, userID, deviceID, v2since string, isStartup bool, logger zerolog.Logger) {
func (h *PollerMap) EnsurePolling(pid PollerID, accessToken, v2since string, isStartup bool, logger zerolog.Logger) {
h.pollerMu.Lock()
if !h.executorRunning {
h.executorRunning = true
go h.execute()
}
poller, ok := h.Pollers[deviceID]
poller, ok := h.Pollers[pid]
// a poller exists and hasn't been terminated so we don't need to do anything
if ok && !poller.terminated.Load() {
if poller.accessToken != accessToken {
logger.Warn().Msg("PollerMap.EnsurePolling: poller already running with different access token")
}
h.pollerMu.Unlock()
// this existing poller may not have completed the initial sync yet, so we need to make sure
// it has before we return.
@ -169,28 +177,31 @@ func (h *PollerMap) EnsurePolling(accessToken, userID, deviceID, v2since string,
// check if we need to wait at all: we don't need to if this user is already syncing on a different device
// This is O(n) so we may want to map this if we get a lot of users...
needToWait := true
for pollerDeviceID, poller := range h.Pollers {
if deviceID == pollerDeviceID {
for existingPID, poller := range h.Pollers {
// Ignore different users. Also ignore same-user same-device.
if pid.UserID != existingPID.UserID || pid.DeviceID == existingPID.DeviceID {
continue
}
if poller.userID == userID && !poller.terminated.Load() {
// Now we have same-user different-device.
if !poller.terminated.Load() {
needToWait = false
break
}
}
// replace the poller. If we don't need to wait, then we just want to nab to-device events initially.
// We don't do that on startup though as we cannot be sure that other pollers will not be using expired tokens.
poller = newPoller(userID, accessToken, deviceID, h.v2Client, h, logger, !needToWait && !isStartup)
poller = newPoller(pid, accessToken, h.v2Client, h, logger, !needToWait && !isStartup)
poller.processHistogramVec = h.processHistogramVec
poller.timelineSizeVec = h.timelineSizeHistogramVec
go poller.Poll(v2since)
h.Pollers[deviceID] = poller
h.Pollers[pid] = poller
h.pollerMu.Unlock()
if needToWait {
poller.WaitUntilInitialSync()
} else {
logger.Info().Str("user", userID).Msg("a poller exists for this user; not waiting for this device to do an initial sync")
logger.Info().Msg("a poller exists for this user; not waiting for this device to do an initial sync")
}
}
@ -200,8 +211,8 @@ func (h *PollerMap) execute() {
}
}
func (h *PollerMap) UpdateDeviceSince(deviceID, since string) {
h.callbacks.UpdateDeviceSince(deviceID, since)
func (h *PollerMap) UpdateDeviceSince(userID, deviceID, since string) {
h.callbacks.UpdateDeviceSince(userID, deviceID, since)
}
func (h *PollerMap) Accumulate(deviceID, roomID, prevBatch string, timeline []json.RawMessage) {
var wg sync.WaitGroup
@ -261,8 +272,8 @@ func (h *PollerMap) OnTerminated(userID, deviceID string) {
h.callbacks.OnTerminated(userID, deviceID)
}
func (h *PollerMap) OnExpiredToken(userID, deviceID string) {
h.callbacks.OnExpiredToken(userID, deviceID)
func (h *PollerMap) OnExpiredToken(accessTokenHash, userID, deviceID string) {
h.callbacks.OnExpiredToken(accessTokenHash, userID, deviceID)
}
func (h *PollerMap) UpdateUnreadCounts(roomID, userID string, highlightCount, notifCount *int) {
@ -308,8 +319,8 @@ func (h *PollerMap) OnE2EEData(userID, deviceID string, otkCounts map[string]int
// Poller can automatically poll the sync v2 endpoint and accumulate the responses in storage
type poller struct {
userID string
accessToken string
deviceID string
accessToken string
client Client
receiver V2DataReceiver
logger zerolog.Logger
@ -329,13 +340,13 @@ type poller struct {
timelineSizeVec *prometheus.HistogramVec
}
func newPoller(userID, accessToken, deviceID string, client Client, receiver V2DataReceiver, logger zerolog.Logger, initialToDeviceOnly bool) *poller {
func newPoller(pid PollerID, accessToken string, client Client, receiver V2DataReceiver, logger zerolog.Logger, initialToDeviceOnly bool) *poller {
var wg sync.WaitGroup
wg.Add(1)
return &poller{
userID: pid.UserID,
deviceID: pid.DeviceID,
accessToken: accessToken,
userID: userID,
deviceID: deviceID,
client: client,
receiver: receiver,
terminated: &atomic.Bool{},
@ -390,7 +401,7 @@ func (p *poller) Poll(since string) {
continue
} else {
p.logger.Warn().Msg("Poller: access token has been invalidated, terminating loop")
p.receiver.OnExpiredToken(p.userID, p.deviceID)
p.receiver.OnExpiredToken(hashToken(p.accessToken), p.userID, p.deviceID)
p.Terminate()
break
}
@ -411,7 +422,7 @@ func (p *poller) Poll(since string) {
since = resp.NextBatch
// persist the since token (TODO: this could get slow if we hammer the DB too much)
p.receiver.UpdateDeviceSince(p.deviceID, since)
p.receiver.UpdateDeviceSince(p.userID, p.deviceID, since)
if firstTime {
firstTime = false
@ -522,7 +533,7 @@ func (p *poller) parseRoomsResponse(res *SyncResponse) {
// the timeline now so that future events are received under the
// correct room state.
const warnMsg = "parseRoomsResponse: prepending state events to timeline after gappy poll"
log.Warn().Str("room_id", roomID).Int("prependStateEvents", len(prependStateEvents)).Msg(warnMsg)
logger.Warn().Str("room_id", roomID).Int("prependStateEvents", len(prependStateEvents)).Msg(warnMsg)
sentry.WithScope(func(scope *sentry.Scope) {
scope.SetContext("sliding-sync", map[string]interface{}{
"room_id": roomID,

View File

@ -51,7 +51,10 @@ func TestPollerMapEnsurePolling(t *testing.T) {
ensurePollingUnblocked := make(chan struct{})
go func() {
pm.EnsurePolling("access_token", "@alice:localhost", "FOOBAR", "", false, zerolog.New(os.Stderr))
pm.EnsurePolling(PollerID{
UserID: "alice:localhost",
DeviceID: "FOOBAR",
}, "access_token", "", false, zerolog.New(os.Stderr))
close(ensurePollingUnblocked)
}()
ensureBlocking := func() {
@ -136,7 +139,7 @@ func TestPollerMapEnsurePollingIdempotent(t *testing.T) {
for i := 0; i < n; i++ {
go func() {
t.Logf("EnsurePolling")
pm.EnsurePolling("access_token", "@alice:localhost", "FOOBAR", "", false, zerolog.New(os.Stderr))
pm.EnsurePolling(PollerID{UserID: "@alice:localhost", DeviceID: "FOOBAR"}, "access_token", "", false, zerolog.New(os.Stderr))
wg.Done()
t.Logf("EnsurePolling unblocked")
}()
@ -195,7 +198,7 @@ func TestPollerMapEnsurePollingIdempotent(t *testing.T) {
// Check that a call to Poll starts polling and accumulating, and terminates on 401s.
func TestPollerPollFromNothing(t *testing.T) {
nextSince := "next"
deviceID := "FOOBAR"
pid := PollerID{UserID: "@alice:localhost", DeviceID: "FOOBAR"}
roomID := "!foo:bar"
roomState := []json.RawMessage{
json.RawMessage(`{"event":1}`),
@ -224,7 +227,7 @@ func TestPollerPollFromNothing(t *testing.T) {
})
var wg sync.WaitGroup
wg.Add(1)
poller := newPoller("@alice:localhost", "Authorization: hello world", deviceID, client, accumulator, zerolog.New(os.Stderr), false)
poller := newPoller(pid, "Authorization: hello world", client, accumulator, zerolog.New(os.Stderr), false)
go func() {
defer wg.Done()
poller.Poll("")
@ -244,14 +247,14 @@ func TestPollerPollFromNothing(t *testing.T) {
if len(accumulator.states[roomID]) != len(roomState) {
t.Errorf("did not accumulate initial state for room, got %d events want %d", len(accumulator.states[roomID]), len(roomState))
}
if accumulator.deviceIDToSince[deviceID] != nextSince {
t.Errorf("did not persist latest since token, got %s want %s", accumulator.deviceIDToSince[deviceID], nextSince)
if accumulator.pollerIDToSince[pid] != nextSince {
t.Errorf("did not persist latest since token, got %s want %s", accumulator.pollerIDToSince[pid], nextSince)
}
}
// Check that a call to Poll starts polling with an existing since token and accumulates timeline entries
func TestPollerPollFromExisting(t *testing.T) {
deviceID := "FOOBAR"
pid := PollerID{UserID: "@alice:localhost", DeviceID: "FOOBAR"}
roomID := "!foo:bar"
since := "0"
roomTimelineResponses := [][]json.RawMessage{
@ -307,7 +310,7 @@ func TestPollerPollFromExisting(t *testing.T) {
})
var wg sync.WaitGroup
wg.Add(1)
poller := newPoller("@alice:localhost", "Authorization: hello world", deviceID, client, accumulator, zerolog.New(os.Stderr), false)
poller := newPoller(pid, "Authorization: hello world", client, accumulator, zerolog.New(os.Stderr), false)
go func() {
defer wg.Done()
poller.Poll(since)
@ -328,8 +331,8 @@ func TestPollerPollFromExisting(t *testing.T) {
t.Errorf("did not accumulate timelines for room, got %d events want %d", len(accumulator.timelines[roomID]), 10)
}
wantSince := fmt.Sprintf("%d", len(roomTimelineResponses))
if accumulator.deviceIDToSince[deviceID] != wantSince {
t.Errorf("did not persist latest since token, got %s want %s", accumulator.deviceIDToSince[deviceID], wantSince)
if accumulator.pollerIDToSince[pid] != wantSince {
t.Errorf("did not persist latest since token, got %s want %s", accumulator.pollerIDToSince[pid], wantSince)
}
}
@ -383,7 +386,7 @@ func TestPollerBackoff(t *testing.T) {
}
var wg sync.WaitGroup
wg.Add(1)
poller := newPoller("@alice:localhost", "Authorization: hello world", deviceID, client, accumulator, zerolog.New(os.Stderr), false)
poller := newPoller(PollerID{UserID: "@alice:localhost", DeviceID: deviceID}, "Authorization: hello world", client, accumulator, zerolog.New(os.Stderr), false)
go func() {
defer wg.Done()
poller.Poll("some_since_value")
@ -413,7 +416,7 @@ func TestPollerUnblocksIfTerminatedInitially(t *testing.T) {
pollUnblocked := make(chan struct{})
waitUntilInitialSyncUnblocked := make(chan struct{})
poller := newPoller("@alice:localhost", "Authorization: hello world", deviceID, client, accumulator, zerolog.New(os.Stderr), false)
poller := newPoller(PollerID{UserID: "@alice:localhost", DeviceID: deviceID}, "Authorization: hello world", client, accumulator, zerolog.New(os.Stderr), false)
go func() {
poller.Poll("")
close(pollUnblocked)
@ -452,7 +455,7 @@ func (c *mockClient) WhoAmI(authHeader string) (string, string, error) {
type mockDataReceiver struct {
states map[string][]json.RawMessage
timelines map[string][]json.RawMessage
deviceIDToSince map[string]string
pollerIDToSince map[PollerID]string
incomingProcess chan struct{}
unblockProcess chan struct{}
}
@ -474,8 +477,8 @@ func (a *mockDataReceiver) Initialise(roomID string, state []json.RawMessage) []
}
func (a *mockDataReceiver) SetTyping(roomID string, ephEvent json.RawMessage) {
}
func (s *mockDataReceiver) UpdateDeviceSince(deviceID, since string) {
s.deviceIDToSince[deviceID] = since
func (s *mockDataReceiver) UpdateDeviceSince(userID, deviceID, since string) {
s.pollerIDToSince[PollerID{UserID: userID, DeviceID: deviceID}] = since
}
func (s *mockDataReceiver) AddToDeviceMessages(userID, deviceID string, msgs []json.RawMessage) {
}
@ -488,8 +491,8 @@ func (s *mockDataReceiver) OnInvite(userID, roomID string, inviteState []json.Ra
func (s *mockDataReceiver) OnLeftRoom(userID, roomID string) {}
func (s *mockDataReceiver) OnE2EEData(userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int) {
}
func (s *mockDataReceiver) OnTerminated(userID, deviceID string) {}
func (s *mockDataReceiver) OnExpiredToken(userID, deviceID string) {}
func (s *mockDataReceiver) OnTerminated(userID, deviceID string) {}
func (s *mockDataReceiver) OnExpiredToken(accessTokenHash, userID, deviceID string) {}
func newMocks(doSyncV2 func(authHeader, since string) (*SyncResponse, int, error)) (*mockDataReceiver, *mockClient) {
client := &mockClient{
@ -498,7 +501,7 @@ func newMocks(doSyncV2 func(authHeader, since string) (*SyncResponse, int, error
accumulator := &mockDataReceiver{
states: make(map[string][]json.RawMessage),
timelines: make(map[string][]json.RawMessage),
deviceIDToSince: make(map[string]string),
pollerIDToSince: make(map[PollerID]string),
}
return accumulator, client
}

View File

@ -1,45 +1,21 @@
package sync2
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"github.com/getsentry/sentry-go"
"io"
"os"
"strings"
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
"github.com/matrix-org/sliding-sync/sqlutil"
"github.com/rs/zerolog"
"os"
)
var log = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.ConsoleWriter{
var logger = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.ConsoleWriter{
Out: os.Stderr,
TimeFormat: "15:04:05",
})
type Device struct {
UserID string `db:"user_id"`
DeviceID string `db:"device_id"`
Since string `db:"since"`
AccessToken string
AccessTokenEncrypted string `db:"v2_token_encrypted"`
}
// Storage remembers sync v2 tokens per-device
type Storage struct {
db *sqlx.DB
// A separate secret used to en/decrypt access tokens prior to / after retrieval from the database.
// This provides additional security as a simple SQL injection attack would be insufficient to retrieve
// users access tokens due to the encryption key not living inside the database / on that machine at all.
// https://cheatsheetseries.owasp.org/cheatsheets/Cryptographic_Storage_Cheat_Sheet.html#separation-of-keys-and-data
// We cannot use bcrypt/scrypt as we need the plaintext to do sync requests!
key256 []byte
DevicesTable *DevicesTable
TokensTable *TokensTable
DB *sqlx.DB
}
func NewStore(postgresURI, secret string) *Storage {
@ -47,138 +23,18 @@ func NewStore(postgresURI, secret string) *Storage {
if err != nil {
sentry.CaptureException(err)
// TODO: if we panic(), will sentry have a chance to flush the event?
log.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
logger.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
}
db.MustExec(`
CREATE TABLE IF NOT EXISTS syncv3_sync2_devices (
device_id TEXT PRIMARY KEY,
user_id TEXT NOT NULL, -- populated from /whoami
v2_token_encrypted TEXT NOT NULL,
since TEXT NOT NULL
);`)
// derive the key from the secret
hash := sha256.New()
hash.Write([]byte(secret))
return &Storage{
db: db,
key256: hash.Sum(nil),
DevicesTable: NewDevicesTable(db),
TokensTable: NewTokensTable(db, secret),
DB: db,
}
}
func (s *Storage) Teardown() {
err := s.db.Close()
err := s.DB.Close()
if err != nil {
panic("V2Storage.Teardown: " + err.Error())
}
}
func (s *Storage) encrypt(token string) string {
block, err := aes.NewCipher(s.key256)
if err != nil {
panic("sync2.Storage encrypt: " + err.Error())
}
gcm, err := cipher.NewGCM(block)
if err != nil {
panic("sync2.Storage encrypt: " + err.Error())
}
nonce := make([]byte, gcm.NonceSize())
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
panic("sync2.Storage encrypt: " + err.Error())
}
return hex.EncodeToString(nonce) + " " + hex.EncodeToString(gcm.Seal(nil, nonce, []byte(token), nil))
}
func (s *Storage) decrypt(nonceAndEncToken string) (string, error) {
segs := strings.Split(nonceAndEncToken, " ")
nonce := segs[0]
nonceBytes, err := hex.DecodeString(nonce)
if err != nil {
return "", fmt.Errorf("decrypt nonce: failed to decode hex: %s", err)
}
encToken := segs[1]
ciphertext, err := hex.DecodeString(encToken)
if err != nil {
return "", fmt.Errorf("decrypt token: failed to decode hex: %s", err)
}
block, err := aes.NewCipher(s.key256)
if err != nil {
return "", err
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}
token, err := aesgcm.Open(nil, nonceBytes, ciphertext, nil)
if err != nil {
return "", err
}
return string(token), nil
}
func (s *Storage) Device(deviceID string) (*Device, error) {
var d Device
err := s.db.Get(&d, `SELECT device_id, user_id, since, v2_token_encrypted FROM syncv3_sync2_devices WHERE device_id=$1`, deviceID)
if err != nil {
return nil, fmt.Errorf("failed to lookup device '%s': %s", deviceID, err)
}
d.AccessToken, err = s.decrypt(d.AccessTokenEncrypted)
return &d, err
}
func (s *Storage) AllDevices() (devices []Device, err error) {
err = s.db.Select(&devices, `SELECT device_id, user_id, since, v2_token_encrypted FROM syncv3_sync2_devices`)
if err != nil {
return
}
for i := range devices {
devices[i].AccessToken, _ = s.decrypt(devices[i].AccessTokenEncrypted)
}
return
}
func (s *Storage) RemoveDevice(deviceID string) error {
_, err := s.db.Exec(
`DELETE FROM syncv3_sync2_devices WHERE device_id = $1`, deviceID,
)
log.Info().Str("device", deviceID).Msg("Deleting device")
return err
}
// InsertDevice ensures a row exists for this device ID, storing the encrypted
// access token alongside it. If the row already exists its since/user_id values
// are preserved (not clobbered) and returned, so an existing poller does not
// lose its v2 sync position.
func (s *Storage) InsertDevice(deviceID, accessToken string) (*Device, error) {
	var device Device
	device.AccessToken = accessToken
	device.AccessTokenEncrypted = s.encrypt(accessToken)
	err := sqlutil.WithTransaction(s.db, func(txn *sqlx.Tx) error {
		// make sure there is a device entry for this device ID. If one already exists, don't clobber
		// the since value else we'll forget our position!
		result, err := txn.Exec(`
INSERT INTO syncv3_sync2_devices(device_id, since, user_id, v2_token_encrypted) VALUES($1,$2,$3,$4)
ON CONFLICT (device_id) DO NOTHING`,
			deviceID, "", "", device.AccessTokenEncrypted,
		)
		if err != nil {
			return err
		}
		device.DeviceID = deviceID
		// if we inserted a row that means it's a brand new device ergo there is no since token
		if ra, err := result.RowsAffected(); err == nil && ra == 1 {
			return nil
		}
		// Return the since value as we may start a new poller with this session.
		return txn.QueryRow("SELECT since, user_id FROM syncv3_sync2_devices WHERE device_id = $1", deviceID).Scan(&device.Since, &device.UserID)
	})
	return &device, err
}
// UpdateDeviceSince persists the latest v2 since token for the given device.
func (s *Storage) UpdateDeviceSince(deviceID, since string) error {
	_, err := s.db.Exec(
		`UPDATE syncv3_sync2_devices SET since = $1 WHERE device_id = $2`,
		since, deviceID,
	)
	return err
}
// UpdateUserIDForDevice records which user owns the given device.
func (s *Storage) UpdateUserIDForDevice(deviceID, userID string) error {
	_, err := s.db.Exec(
		`UPDATE syncv3_sync2_devices SET user_id = $1 WHERE device_id = $2`,
		userID, deviceID,
	)
	return err
}

View File

@ -1,89 +0,0 @@
package sync2
import (
"os"
"sort"
"testing"
"github.com/matrix-org/sliding-sync/testutils"
)
// postgresConnectionString is a placeholder; it is overwritten in TestMain with
// the connection string for the prepared test database.
var postgresConnectionString = "user=xxxxx dbname=syncv3_test sslmode=disable"

// TestMain prepares a test Postgres database before running this package's tests.
func TestMain(m *testing.M) {
	postgresConnectionString = testutils.PrepareDBConnectionString()
	exitCode := m.Run()
	os.Exit(exitCode)
}
// TestStorage exercises the devices table end-to-end against a real Postgres:
// insert, since/user-ID updates, idempotent re-insert (must not clobber the
// stored since value), and AllDevices retrieval.
func TestStorage(t *testing.T) {
	deviceID := "ALICE"
	accessToken := "my_access_token"
	store := NewStore(postgresConnectionString, "my_secret")
	device, err := store.InsertDevice(deviceID, accessToken)
	if err != nil {
		t.Fatalf("Failed to InsertDevice: %s", err)
	}
	assertEqual(t, device.DeviceID, deviceID, "Device.DeviceID mismatch")
	assertEqual(t, device.AccessToken, accessToken, "Device.AccessToken mismatch")
	if err = store.UpdateDeviceSince(deviceID, "s1"); err != nil {
		t.Fatalf("UpdateDeviceSince returned error: %s", err)
	}
	if err = store.UpdateUserIDForDevice(deviceID, "@alice:localhost"); err != nil {
		t.Fatalf("UpdateUserIDForDevice returned error: %s", err)
	}
	// now check that device retrieval has the latest values
	device, err = store.Device(deviceID)
	if err != nil {
		t.Fatalf("Device returned error: %s", err)
	}
	assertEqual(t, device.DeviceID, deviceID, "Device.DeviceID mismatch")
	assertEqual(t, device.Since, "s1", "Device.Since mismatch")
	assertEqual(t, device.UserID, "@alice:localhost", "Device.UserID mismatch")
	assertEqual(t, device.AccessToken, accessToken, "Device.AccessToken mismatch")
	// now check new devices remember the v2 since value and user ID
	s2, err := store.InsertDevice(deviceID, accessToken)
	if err != nil {
		t.Fatalf("InsertDevice returned error: %s", err)
	}
	assertEqual(t, s2.Since, "s1", "Device.Since mismatch")
	assertEqual(t, s2.UserID, "@alice:localhost", "Device.UserID mismatch")
	assertEqual(t, s2.DeviceID, deviceID, "Device.DeviceID mismatch")
	assertEqual(t, s2.AccessToken, accessToken, "Device.AccessToken mismatch")
	// check all devices works
	deviceID2 := "BOB"
	accessToken2 := "BOB_ACCESS_TOKEN"
	bobDevice, err := store.InsertDevice(deviceID2, accessToken2)
	if err != nil {
		t.Fatalf("InsertDevice returned error: %s", err)
	}
	devices, err := store.AllDevices()
	if err != nil {
		t.Fatalf("AllDevices: %s", err)
	}
	// sort so the element-wise comparison below is deterministic ("ALICE" < "BOB")
	sort.Slice(devices, func(i, j int) bool {
		return devices[i].DeviceID < devices[j].DeviceID
	})
	wantDevices := []*Device{
		device, bobDevice,
	}
	if len(devices) != len(wantDevices) {
		t.Fatalf("AllDevices: got %d devices, want %d", len(devices), len(wantDevices))
	}
	for i := range devices {
		assertEqual(t, devices[i].Since, wantDevices[i].Since, "Device.Since mismatch")
		assertEqual(t, devices[i].UserID, wantDevices[i].UserID, "Device.UserID mismatch")
		assertEqual(t, devices[i].DeviceID, wantDevices[i].DeviceID, "Device.DeviceID mismatch")
		assertEqual(t, devices[i].AccessToken, wantDevices[i].AccessToken, "Device.AccessToken mismatch")
	}
}
// assertEqual fails the test with msg unless got equals want.
func assertEqual(t *testing.T, got, want, msg string) {
	t.Helper()
	if got == want {
		return
	}
	t.Fatalf("%s: got %s want %s", msg, got, want)
}

245
sync2/tokens_table.go Normal file
View File

@ -0,0 +1,245 @@
package sync2
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"github.com/jmoiron/sqlx"
"io"
"strings"
"time"
)
// Token is the in-memory representation of a row of the syncv3_sync2_tokens table.
type Token struct {
	// AccessToken is the plaintext access token. It is never stored in the DB;
	// only the hash and the encrypted form below are persisted.
	AccessToken string
	// AccessTokenHash is the hex-encoded SHA-256 of AccessToken; it is the
	// table's primary key.
	AccessTokenHash string
	// AccessTokenEncrypted is the AES-GCM-encrypted token, as stored in the DB.
	AccessTokenEncrypted string `db:"token_encrypted"`
	UserID               string `db:"user_id"`
	DeviceID             string `db:"device_id"`
	// LastSeen records when the proxy last saw a request using this token
	// (updated at most once per 24h; see MaybeUpdateLastSeen).
	LastSeen time.Time `db:"last_seen"`
}
// TokensTable remembers sync v2 tokens
type TokensTable struct {
	// db is the handle to the shared Postgres database.
	db *sqlx.DB
	// A separate secret used to en/decrypt access tokens prior to / after retrieval from the database.
	// This provides additional security as a simple SQL injection attack would be insufficient to retrieve
	// users access tokens due to the encryption key not living inside the database / on that machine at all.
	// https://cheatsheetseries.owasp.org/cheatsheets/Cryptographic_Storage_Cheat_Sheet.html#separation-of-keys-and-data
	// We cannot use bcrypt/scrypt as we need the plaintext to do sync requests!
	key256 []byte
}
// NewTokensTable creates the syncv3_sync2_tokens table if it does not already exist.
// The given secret is SHA-256-hashed to derive the AES-256 key used to encrypt
// tokens at rest (see TokensTable.key256).
func NewTokensTable(db *sqlx.DB, secret string) *TokensTable {
	db.MustExec(`
CREATE TABLE IF NOT EXISTS syncv3_sync2_tokens (
	token_hash TEXT NOT NULL PRIMARY KEY, -- SHA256(access token)
	token_encrypted TEXT NOT NULL,
	-- TODO: FK constraints to devices table?
	user_id TEXT NOT NULL,
	device_id TEXT NOT NULL,
	last_seen TIMESTAMP WITH TIME ZONE NOT NULL
);`)
	// derive the key from the secret
	hash := sha256.New()
	hash.Write([]byte(secret))
	return &TokensTable{
		db:     db,
		key256: hash.Sum(nil),
	}
}
// encrypt AES-GCM-encrypts the given plaintext token under the table's derived
// key with a fresh random nonce, returning "<hex nonce> <hex ciphertext>"
// suitable for storage in a TEXT column. Panics only on programmer error (bad
// key length) or if the system RNG fails.
func (t *TokensTable) encrypt(token string) string {
	block, err := aes.NewCipher(t.key256)
	if err != nil {
		// Fixed: the panic messages previously said "sync2.DevicesTable"
		// (copy-paste from the devices table), misattributing the failure site.
		panic("sync2.TokensTable encrypt: " + err.Error())
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		panic("sync2.TokensTable encrypt: " + err.Error())
	}
	nonce := make([]byte, gcm.NonceSize())
	if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
		panic("sync2.TokensTable encrypt: " + err.Error())
	}
	return hex.EncodeToString(nonce) + " " + hex.EncodeToString(gcm.Seal(nil, nonce, []byte(token), nil))
}
// decrypt reverses encrypt: it takes a string of the form
// "<hex nonce> <hex ciphertext>" and returns the plaintext token.
// Errors if the input is malformed, the hex fails to decode, or the AES-GCM
// open fails (wrong key or tampered ciphertext).
func (t *TokensTable) decrypt(nonceAndEncToken string) (string, error) {
	segs := strings.Split(nonceAndEncToken, " ")
	if len(segs) != 2 {
		// Guard against malformed input: the original indexed segs[1]
		// unconditionally, which panics when there is no space separator.
		return "", fmt.Errorf("decrypt: expected 'nonce ciphertext', got %d segment(s)", len(segs))
	}
	nonceBytes, err := hex.DecodeString(segs[0])
	if err != nil {
		return "", fmt.Errorf("decrypt nonce: failed to decode hex: %s", err)
	}
	ciphertext, err := hex.DecodeString(segs[1])
	if err != nil {
		return "", fmt.Errorf("decrypt token: failed to decode hex: %s", err)
	}
	block, err := aes.NewCipher(t.key256)
	if err != nil {
		return "", err
	}
	aesgcm, err := cipher.NewGCM(block)
	if err != nil {
		return "", err
	}
	token, err := aesgcm.Open(nil, nonceBytes, ciphertext, nil)
	if err != nil {
		return "", err
	}
	return string(token), nil
}
func hashToken(accessToken string) string {
// important that this is a cryptographically secure hash function to prevent
// preimage attacks where Eve can use a fake token to hash to an existing device ID
// on the server.
hash := sha256.New()
hash.Write([]byte(accessToken))
return hex.EncodeToString(hash.Sum(nil))
}
// Token retrieves a tokens row from the database if it exists.
// Errors with sql.ErrNoRows if the token does not exist.
// Errors with an unspecified error otherwise.
func (t *TokensTable) Token(plaintextToken string) (*Token, error) {
	tokenHash := hashToken(plaintextToken)
	var token Token
	err := t.db.Get(
		&token,
		`SELECT token_encrypted, user_id, device_id, last_seen FROM syncv3_sync2_tokens WHERE token_hash=$1`,
		tokenHash,
	)
	if err != nil {
		return nil, err
	}
	// The row holds only the encrypted form; fill in the plaintext and hash
	// from what the caller supplied / we computed above.
	token.AccessToken = plaintextToken
	token.AccessTokenHash = tokenHash
	return &token, nil
}
// TokenForPoller represents a row of the tokens table, together with any data
// maintained by pollers for that token's device.
type TokenForPoller struct {
	*Token
	// Since is the v2 sync position for this token's device, as stored in the
	// devices table.
	Since string `db:"since"`
}
// TokenForEachDevice fetches the most recently seen token for each known
// (user, device) pair, together with that device's since value. Rows whose
// encrypted token cannot be decrypted are returned with empty
// AccessToken/AccessTokenHash rather than failing the whole call.
func (t *TokensTable) TokenForEachDevice() (tokens []TokenForPoller, err error) {
	// Fetches the most recently seen token for each device, see e.g.
	// https://www.postgresql.org/docs/11/sql-select.html#SQL-DISTINCT
	err = t.db.Select(
		&tokens,
		`SELECT DISTINCT ON (user_id, device_id) token_encrypted, user_id, device_id, last_seen, since
		FROM syncv3_sync2_tokens JOIN syncv3_sync2_devices USING (user_id, device_id)
		ORDER BY user_id, device_id, last_seen DESC
	`)
	if err != nil {
		return
	}
	// Fixed two bugs here:
	//  1. `for _, token := range tokens` iterated over COPIES, so the decrypted
	//     AccessToken/AccessTokenHash were written to the copy and discarded.
	//  2. a decrypt failure assigned the named return `err`, so a bad final row
	//     made the whole call return an error despite the `continue`.
	for i := range tokens {
		token := &tokens[i]
		plaintext, decErr := t.decrypt(token.AccessTokenEncrypted)
		if decErr != nil {
			// Ignore decryption failure for this row only.
			continue
		}
		token.AccessToken = plaintext
		token.AccessTokenHash = hashToken(plaintext)
	}
	return
}
// Insert a new token into the table.
// If a row with the same hash already exists, the DB row is left untouched
// (ON CONFLICT DO NOTHING) but a fully populated Token struct is still returned.
func (t *TokensTable) Insert(plaintextToken, userID, deviceID string, lastSeen time.Time) (*Token, error) {
	hashedToken := hashToken(plaintextToken)
	encToken := t.encrypt(plaintextToken)
	_, err := t.db.Exec(
		`INSERT INTO syncv3_sync2_tokens(token_hash, token_encrypted, user_id, device_id, last_seen)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (token_hash) DO NOTHING;`,
		hashedToken, encToken, userID, deviceID, lastSeen,
	)
	if err != nil {
		return nil, err
	}
	return &Token{
		AccessToken:     plaintextToken,
		AccessTokenHash: hashedToken,
		// Note: if this token already exists in the DB, encToken will differ from
		// the DB token_encrypted column. (t.encrypt is nondeterministic, see e.g.
		// https://en.wikipedia.org/wiki/Probabilistic_encryption).
		// The rest of the program should ignore this field; it only lives here so
		// we can Scan the DB row into the Tokens struct. Could make it private?
		AccessTokenEncrypted: encToken,
		UserID:               userID,
		DeviceID:             deviceID,
		LastSeen:             lastSeen,
	}, nil
}
// MaybeUpdateLastSeen actions a request to update a Token struct with its last_seen value
// in the DB. To avoid spamming the DB with a write every time a sync3 request arrives,
// we only update the last seen timestamp if the stored one is at least 24 hours old.
// The timestamp is updated on the Token struct if and only if it is updated in the DB.
func (t *TokensTable) MaybeUpdateLastSeen(token *Token, newLastSeen time.Time) error {
	sinceLastSeen := newLastSeen.Sub(token.LastSeen)
	if sinceLastSeen < (24 * time.Hour) {
		// Seen recently enough; skip the DB write entirely.
		return nil
	}
	_, err := t.db.Exec(
		`UPDATE syncv3_sync2_tokens SET last_seen = $1 WHERE token_hash = $2`,
		newLastSeen, token.AccessTokenHash,
	)
	if err != nil {
		return err
	}
	// Only mirror the new timestamp in memory once the DB write has succeeded.
	token.LastSeen = newLastSeen
	return nil
}
// GetTokenAndSince looks up a token by its hash, decrypts it, and returns the
// plaintext access token together with the since value of the token's device.
// Errors if the stored user/device do not match the expected userID/deviceID,
// guarding against a hash being associated with a different user's token.
func (t *TokensTable) GetTokenAndSince(userID, deviceID, tokenHash string) (accessToken, since string, err error) {
	var encToken, gotUserID, gotDeviceID string
	query := `SELECT token_encrypted, since, user_id, device_id
FROM syncv3_sync2_tokens JOIN syncv3_sync2_devices USING (user_id, device_id)
WHERE token_hash = $1;`
	err = t.db.QueryRow(query, tokenHash).Scan(&encToken, &since, &gotUserID, &gotDeviceID)
	if err != nil {
		return
	}
	if gotUserID != userID || gotDeviceID != deviceID {
		err = fmt.Errorf(
			"token (hash %s) found with user+device mismatch: got (%s, %s), expected (%s, %s)",
			tokenHash, gotUserID, gotDeviceID, userID, deviceID,
		)
		return
	}
	accessToken, err = t.decrypt(encToken)
	return
}
// Delete looks up a token by its hash and deletes the row. If no token exists with the
// given hash, a warning is logged but no error is returned.
func (t *TokensTable) Delete(accessTokenHash string) error {
	result, execErr := t.db.Exec(
		`DELETE FROM syncv3_sync2_tokens WHERE token_hash = $1`,
		accessTokenHash,
	)
	if execErr != nil {
		return execErr
	}
	deleted, raErr := result.RowsAffected()
	if raErr != nil {
		return raErr
	}
	if deleted != 1 {
		logger.Warn().Msgf("Tokens.Delete: expected to delete one token, but actually deleted %d", deleted)
	}
	return nil
}

150
sync2/tokens_table_test.go Normal file
View File

@ -0,0 +1,150 @@
package sync2
import (
"testing"
"time"
)
// Sanity check that different tokens have different hashes
func TestHash(t *testing.T) {
	token1 := "ABCD"
	token2 := "EFGH"
	hash1 := hashToken(token1)
	hash2 := hashToken(token2)
	if hash1 == hash2 {
		// Fixed: the message previously referenced HashedTokenFromRequest, a
		// function removed in this refactor; name the function under test.
		t.Fatalf("hashToken: %s and %s have the same hash", token1, token2)
	}
}
// TestTokensTable exercises the tokens table end-to-end against a real Postgres:
// insert, idempotent re-insert, the 24h last_seen write-throttling, and a
// second token for the same device.
func TestTokensTable(t *testing.T) {
	db, close := connectToDB(t)
	defer close()
	tokens := NewTokensTable(db, "my_secret")
	alice := "@alice:localhost"
	aliceDevice := "alice_phone"
	aliceSecret1 := "mysecret1"
	aliceToken1FirstSeen := time.Now()
	// Test a single token
	t.Log("Insert a new token from Alice.")
	aliceToken, err := tokens.Insert(aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
	if err != nil {
		t.Fatalf("Failed to Insert token: %s", err)
	}
	t.Log("The returned Token struct should have been populated correctly.")
	assertEqualTokens(t, tokens, aliceToken, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
	t.Log("Reinsert the same token.")
	reinsertedToken, err := tokens.Insert(aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
	if err != nil {
		t.Fatalf("Failed to Insert token: %s", err)
	}
	t.Log("This should yield an equal Token struct.")
	assertEqualTokens(t, tokens, reinsertedToken, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
	t.Log("Try to mark Alice's token as being used after an hour.")
	err = tokens.MaybeUpdateLastSeen(aliceToken, aliceToken1FirstSeen.Add(time.Hour))
	if err != nil {
		t.Fatalf("Failed to update last seen: %s", err)
	}
	// An hour is below the 24h threshold, so nothing should have been written.
	t.Log("The token should not be updated in memory, nor in the DB.")
	assertEqualTimes(t, aliceToken.LastSeen, aliceToken1FirstSeen, "Token.LastSeen mismatch")
	fetchedToken, err := tokens.Token(aliceSecret1)
	if err != nil {
		t.Fatalf("Failed to fetch token: %s", err)
	}
	assertEqualTokens(t, tokens, fetchedToken, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
	t.Log("Try to mark Alice's token as being used after two days.")
	aliceToken1LastSeen := aliceToken1FirstSeen.Add(48 * time.Hour)
	err = tokens.MaybeUpdateLastSeen(aliceToken, aliceToken1LastSeen)
	if err != nil {
		t.Fatalf("Failed to update last seen: %s", err)
	}
	t.Log("The token should now be updated in-memory and in the DB.")
	assertEqualTimes(t, aliceToken.LastSeen, aliceToken1LastSeen, "Token.LastSeen mismatch")
	fetchedToken, err = tokens.Token(aliceSecret1)
	if err != nil {
		t.Fatalf("Failed to fetch token: %s", err)
	}
	assertEqualTokens(t, tokens, fetchedToken, aliceSecret1, alice, aliceDevice, aliceToken1LastSeen)
	// Test a second token for Alice
	t.Log("Insert a second token for Alice.")
	aliceSecret2 := "mysecret2"
	aliceToken2FirstSeen := aliceToken1LastSeen.Add(time.Minute)
	aliceToken2, err := tokens.Insert(aliceSecret2, alice, aliceDevice, aliceToken2FirstSeen)
	if err != nil {
		t.Fatalf("Failed to Insert token: %s", err)
	}
	t.Log("The returned Token struct should have been populated correctly.")
	assertEqualTokens(t, tokens, aliceToken2, aliceSecret2, alice, aliceDevice, aliceToken2FirstSeen)
}
// TestDeletingTokens checks that a deleted token can no longer be fetched.
func TestDeletingTokens(t *testing.T) {
	db, close := connectToDB(t)
	defer close()
	tokens := NewTokensTable(db, "my_secret")
	// Fixed: the log previously said "from Alice" but the row belongs to @bob.
	t.Log("Insert a new token from Bob.")
	accessToken := "mytoken"
	token, err := tokens.Insert(accessToken, "@bob:builders.com", "device", time.Time{})
	if err != nil {
		t.Fatalf("Failed to Insert token: %s", err)
	}
	t.Log("We should be able to fetch this token without error.")
	_, err = tokens.Token(accessToken)
	if err != nil {
		t.Fatalf("Failed to fetch token: %s", err)
	}
	t.Log("Delete the token")
	err = tokens.Delete(token.AccessTokenHash)
	if err != nil {
		t.Fatalf("Failed to delete token: %s", err)
	}
	t.Log("We should no longer be able to fetch this token.")
	token, err = tokens.Token(accessToken)
	if token != nil || err == nil {
		// %v rather than %s: *Token has no String method, so %s trips the vet
		// printf check and prints an unhelpful representation.
		t.Fatalf("Fetching token after deletion did not fail: got %v, %v", token, err)
	}
}
// assertEqualTokens checks that got's fields match the expected plaintext
// token (and its hash), user, device and last-seen time.
func assertEqualTokens(t *testing.T, table *TokensTable, got *Token, accessToken, userID, deviceID string, lastSeen time.Time) {
	t.Helper()
	assertEqual(t, got.AccessToken, accessToken, "Token.AccessToken mismatch")
	assertEqual(t, got.AccessTokenHash, hashToken(accessToken), "Token.AccessTokenHashed mismatch")
	// We don't care what the encrypted token is here. The fact that we store encrypted values is an
	// implementation detail; the rest of the program doesn't care.
	assertEqual(t, got.UserID, userID, "Token.UserID mismatch")
	assertEqual(t, got.DeviceID, deviceID, "Token.DeviceID mismatch")
	assertEqualTimes(t, got.LastSeen, lastSeen, "Token.LastSeen mismatch")
}
// assertEqual fails the test with msg unless got equals want.
func assertEqual(t *testing.T, got, want, msg string) {
	t.Helper()
	if got == want {
		return
	}
	t.Fatalf("%s: got %s want %s", msg, got, want)
}
func assertEqualTimes(t *testing.T, got, want time.Time, msg string) {
t.Helper()
// Postgres stores timestamps with microsecond resolution, so we might lose some
// precision by storing and fetching a time.Time in/from the DB. Resolution of
// a second will suffice.
if !got.Round(time.Second).Equal(want.Round(time.Second)) {
t.Fatalf("%s: got %v want %v", msg, got, want)
}
}
// see devices_table_test.go for tests which join the tokens and devices tables.

View File

@ -11,11 +11,12 @@ import (
)
type ConnID struct {
UserID string
DeviceID string
}
func (c *ConnID) String() string {
return c.DeviceID
return c.UserID + "|" + c.DeviceID
}
type ConnHandler interface {
@ -24,7 +25,6 @@ type ConnHandler interface {
// status code to send back.
OnIncomingRequest(ctx context.Context, cid ConnID, req *Request, isInitial bool) (*Response, error)
OnUpdate(ctx context.Context, update caches.Update)
UserID() string
Destroy()
Alive() bool
}
@ -33,7 +33,7 @@ type ConnHandler interface {
// of the /sync request, including sending cached data in the event of retries. It does not handle
// the contents of the data at all.
type Conn struct {
ConnID ConnID
ConnID
handler ConnHandler
@ -65,10 +65,6 @@ func NewConn(connID ConnID, h ConnHandler) *Conn {
}
}
func (c *Conn) UserID() string {
return c.handler.UserID()
}
func (c *Conn) Alive() bool {
return c.handler.Alive()
}
@ -105,7 +101,7 @@ func (c *Conn) tryRequest(ctx context.Context, req *Request) (res *Response, err
}
ctx, task := internal.StartTask(ctx, taskType)
defer task.End()
internal.Logf(ctx, "connstate", "starting user=%v device=%v pos=%v", c.handler.UserID(), c.ConnID.DeviceID, req.pos)
internal.Logf(ctx, "connstate", "starting user=%v device=%v pos=%v", c.UserID, c.ConnID.DeviceID, req.pos)
return c.handler.OnIncomingRequest(ctx, c.ConnID, req, req.pos == 0)
}
@ -164,7 +160,7 @@ func (c *Conn) OnIncomingRequest(ctx context.Context, req *Request) (resp *Respo
c.serverResponses = c.serverResponses[delIndex+1:] // slice out the first delIndex+1 elements
defer func() {
l := logger.Trace().Int("num_res_acks", delIndex+1).Bool("is_retransmit", isRetransmit).Bool("is_first", isFirstRequest).Bool("is_same", isSameRequest).Int64("pos", req.pos).Str("user", c.handler.UserID())
l := logger.Trace().Int("num_res_acks", delIndex+1).Bool("is_retransmit", isRetransmit).Bool("is_first", isFirstRequest).Bool("is_same", isSameRequest).Int64("pos", req.pos).Str("user", c.UserID)
if nextUnACKedResponse != nil {
l.Int64("new_pos", nextUnACKedResponse.PosInt())
}

View File

@ -71,7 +71,7 @@ func (m *ConnMap) CreateConn(cid ConnID, newConnHandler func() ConnHandler) (*Co
conn = NewConn(cid, h)
m.cache.Set(cid.String(), conn)
m.connIDToConn[cid.String()] = conn
m.userIDToConn[h.UserID()] = append(m.userIDToConn[h.UserID()], conn)
m.userIDToConn[cid.UserID] = append(m.userIDToConn[cid.UserID], conn)
return conn, true
}
@ -94,20 +94,20 @@ func (m *ConnMap) closeConn(conn *Conn) {
return
}
connID := conn.ConnID.String()
logger.Trace().Str("conn", connID).Msg("closing connection")
connKey := conn.ConnID.String()
logger.Trace().Str("conn", connKey).Msg("closing connection")
// remove conn from all the maps
delete(m.connIDToConn, connID)
delete(m.connIDToConn, connKey)
h := conn.handler
conns := m.userIDToConn[h.UserID()]
conns := m.userIDToConn[conn.UserID]
for i := 0; i < len(conns); i++ {
if conns[i].ConnID.String() == connID {
if conns[i].DeviceID == conn.DeviceID {
// delete without preserving order
conns[i] = conns[len(conns)-1]
conns = conns[:len(conns)-1]
}
}
m.userIDToConn[h.UserID()] = conns
m.userIDToConn[conn.UserID] = conns
// remove user cache listeners etc
h.Destroy()
}

View File

@ -54,6 +54,7 @@ func (s *connStateLive) liveUpdate(
ctx context.Context, req *sync3.Request, ex extensions.Request, isInitial bool,
response *sync3.Response,
) {
log := logger.With().Str("user", s.userID).Str("device", s.deviceID).Logger()
// we need to ensure that we keep consuming from the updates channel, even if they want a response
// immediately. If we have new list data we won't wait, but if we don't then we need to be able to
// catch-up to the current head position, hence giving 100ms grace period for processing.
@ -67,17 +68,17 @@ func (s *connStateLive) liveUpdate(
timeWaited := time.Since(startTime)
timeLeftToWait := timeToWait - timeWaited
if timeLeftToWait < 0 {
logger.Trace().Str("user", s.userID).Str("time_waited", timeWaited.String()).Msg("liveUpdate: timed out")
log.Trace().Str("time_waited", timeWaited.String()).Msg("liveUpdate: timed out")
return
}
logger.Trace().Str("user", s.userID).Str("dur", timeLeftToWait.String()).Msg("liveUpdate: no response data yet; blocking")
log.Trace().Str("dur", timeLeftToWait.String()).Msg("liveUpdate: no response data yet; blocking")
select {
case <-ctx.Done(): // client has given up
logger.Trace().Str("user", s.userID).Msg("liveUpdate: client gave up")
log.Trace().Msg("liveUpdate: client gave up")
internal.Logf(ctx, "liveUpdate", "context cancelled")
return
case <-time.After(timeLeftToWait): // we've timed out
logger.Trace().Str("user", s.userID).Msg("liveUpdate: timed out")
log.Trace().Msg("liveUpdate: timed out")
internal.Logf(ctx, "liveUpdate", "timed out after %v", timeLeftToWait)
return
case update := <-s.updates:
@ -111,7 +112,7 @@ func (s *connStateLive) liveUpdate(
}
}
}
logger.Trace().Str("user", s.userID).Int("subs", len(response.Rooms)).Msg("liveUpdate: returning")
log.Trace().Msg("liveUpdate: returning")
// TODO: op consolidation
}

View File

@ -1,6 +1,7 @@
package handler
import (
"github.com/matrix-org/sliding-sync/sync2"
"sync"
"github.com/matrix-org/sliding-sync/pubsub"
@ -14,7 +15,7 @@ type pendingInfo struct {
type EnsurePoller struct {
chanName string
mu *sync.Mutex
pendingPolls map[string]pendingInfo
pendingPolls map[sync2.PollerID]pendingInfo
notifier pubsub.Notifier
}
@ -22,23 +23,22 @@ func NewEnsurePoller(notifier pubsub.Notifier) *EnsurePoller {
return &EnsurePoller{
chanName: pubsub.ChanV3,
mu: &sync.Mutex{},
pendingPolls: make(map[string]pendingInfo),
pendingPolls: make(map[sync2.PollerID]pendingInfo),
notifier: notifier,
}
}
// EnsurePolling blocks until the V2InitialSyncComplete response is received for this device. It is
// the caller's responsibility to call OnInitialSyncComplete when new events arrive.
func (p *EnsurePoller) EnsurePolling(userID, deviceID string) {
key := userID + "|" + deviceID
func (p *EnsurePoller) EnsurePolling(pid sync2.PollerID, tokenHash string) {
p.mu.Lock()
// do we need to wait?
if p.pendingPolls[key].done {
if p.pendingPolls[pid].done {
p.mu.Unlock()
return
}
// have we called EnsurePolling for this user/device before?
ch := p.pendingPolls[key].ch
ch := p.pendingPolls[pid].ch
if ch != nil {
p.mu.Unlock()
// we already called EnsurePolling on this device, so just listen for the close
@ -50,15 +50,16 @@ func (p *EnsurePoller) EnsurePolling(userID, deviceID string) {
}
// Make a channel to wait until we have done an initial sync
ch = make(chan struct{})
p.pendingPolls[key] = pendingInfo{
p.pendingPolls[pid] = pendingInfo{
done: false,
ch: ch,
}
p.mu.Unlock()
// ask the pollers to poll for this device
p.notifier.Notify(p.chanName, &pubsub.V3EnsurePolling{
UserID: userID,
DeviceID: deviceID,
UserID: pid.UserID,
DeviceID: pid.DeviceID,
AccessTokenHash: tokenHash,
})
// if by some miracle the notify AND sync completes before we receive on ch then this is
// still fine as recv on a closed channel will return immediately.
@ -66,15 +67,15 @@ func (p *EnsurePoller) EnsurePolling(userID, deviceID string) {
}
func (p *EnsurePoller) OnInitialSyncComplete(payload *pubsub.V2InitialSyncComplete) {
key := payload.UserID + "|" + payload.DeviceID
pid := sync2.PollerID{UserID: payload.UserID, DeviceID: payload.DeviceID}
p.mu.Lock()
defer p.mu.Unlock()
pending, ok := p.pendingPolls[key]
pending, ok := p.pendingPolls[pid]
// were we waiting for this initial sync to complete?
if !ok {
// This can happen when the v2 poller spontaneously starts polling even without us asking it to
// e.g from the database
p.pendingPolls[key] = pendingInfo{
p.pendingPolls[pid] = pendingInfo{
done: true,
}
return
@ -88,7 +89,7 @@ func (p *EnsurePoller) OnInitialSyncComplete(payload *pubsub.V2InitialSyncComple
ch := pending.ch
pending.done = true
pending.ch = nil
p.pendingPolls[key] = pending
p.pendingPolls[pid] = pending
close(ch)
}

View File

@ -3,9 +3,12 @@ package handler
import "C"
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"github.com/getsentry/sentry-go"
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/sqlutil"
"net/http"
"net/url"
"os"
@ -194,6 +197,10 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
}
}
}
hlog.FromRequest(req).UpdateContext(func(c zerolog.Context) zerolog.Context {
c.Str("txn_id", requestBody.TxnID)
return c
})
for listKey, l := range requestBody.Lists {
if l.Ranges != nil && !l.Ranges.Valid() {
return &internal.HandlerError{
@ -215,8 +222,8 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
return herr
}
requestBody.SetPos(cpos)
internal.SetRequestContextUserID(req.Context(), conn.UserID())
log := hlog.FromRequest(req).With().Str("user", conn.UserID()).Int64("pos", cpos).Logger()
internal.SetRequestContextUserID(req.Context(), conn.UserID)
log := hlog.FromRequest(req).With().Str("user", conn.UserID).Int64("pos", cpos).Logger()
var timeout int
if req.URL.Query().Get("timeout") == "" {
@ -276,25 +283,53 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
// It also sets a v2 sync poll loop going if one didn't exist already for this user.
// When this function returns, the connection is alive and active.
func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Request, containsPos bool) (*sync3.Conn, error) {
log := hlog.FromRequest(req)
var conn *sync3.Conn
// Identify the device
deviceID, accessToken, err := internal.HashedTokenFromRequest(req)
// Extract an access token
accessToken, err := internal.ExtractAccessToken(req)
if err != nil || accessToken == "" {
log.Warn().Err(err).Msg("failed to get device ID from request")
hlog.FromRequest(req).Warn().Err(err).Msg("failed to get access token from request")
return nil, &internal.HandlerError{
StatusCode: 400,
StatusCode: http.StatusUnauthorized,
Err: err,
}
}
// Try to lookup a record of this token
var token *sync2.Token
token, err = h.V2Store.TokensTable.Token(accessToken)
if err != nil {
if err == sql.ErrNoRows {
hlog.FromRequest(req).Info().Msg("Received connection from unknown access token, querying with homeserver")
newToken, herr := h.identifyUnknownAccessToken(accessToken, hlog.FromRequest(req))
if herr != nil {
return nil, herr
}
token = newToken
} else {
hlog.FromRequest(req).Err(err).Msg("Failed to lookup access token")
return nil, &internal.HandlerError{
StatusCode: http.StatusInternalServerError,
Err: err,
}
}
}
log := hlog.FromRequest(req).With().Str("user", token.UserID).Str("device", token.DeviceID).Logger()
// Record the fact that we've recieved a request from this token
err = h.V2Store.TokensTable.MaybeUpdateLastSeen(token, time.Now())
if err != nil {
// Not fatal---log and continue.
log.Warn().Err(err).Msg("Unable to update last seen timestamp")
}
connID := sync3.ConnID{
UserID: token.UserID,
DeviceID: token.DeviceID,
}
// client thinks they have a connection
if containsPos {
// Lookup the connection
conn = h.ConnMap.Conn(sync3.ConnID{
DeviceID: deviceID,
})
conn = h.ConnMap.Conn(connID)
if conn != nil {
log.Trace().Str("conn", conn.ConnID.String()).Msg("reusing conn")
return conn, nil
@ -303,55 +338,23 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
return nil, internal.ExpiredSessionError()
}
// We're going to make a new connection
// Ensure we have the v2 side of things hooked up
v2device, err := h.V2Store.InsertDevice(deviceID, accessToken)
if err != nil {
log.Warn().Err(err).Str("device_id", deviceID).Msg("failed to insert v2 device")
return nil, &internal.HandlerError{
StatusCode: 500,
Err: err,
}
}
if v2device.UserID == "" {
v2device.UserID, _, err = h.V2.WhoAmI(accessToken)
if err != nil {
if err == sync2.HTTP401 {
return nil, &internal.HandlerError{
StatusCode: 401,
Err: fmt.Errorf("/whoami returned HTTP 401"),
}
}
log.Warn().Err(err).Str("device_id", deviceID).Msg("failed to get user ID from device ID")
return nil, &internal.HandlerError{
StatusCode: http.StatusBadGateway,
Err: err,
}
}
if err = h.V2Store.UpdateUserIDForDevice(deviceID, v2device.UserID); err != nil {
log.Warn().Err(err).Str("device_id", deviceID).Msg("failed to persist user ID -> device ID mapping")
// non-fatal, we can still work without doing this
}
}
log.Trace().Str("user", v2device.UserID).Msg("checking poller exists and is running")
h.V3Pub.EnsurePolling(v2device.UserID, v2device.DeviceID)
log.Trace().Str("user", v2device.UserID).Msg("poller exists and is running")
log.Trace().Msg("checking poller exists and is running")
pid := sync2.PollerID{UserID: token.UserID, DeviceID: token.DeviceID}
h.V3Pub.EnsurePolling(pid, token.AccessTokenHash)
log.Trace().Msg("poller exists and is running")
// this may take a while so if the client has given up (e.g timed out) by this point, just stop.
// We'll be quicker next time as the poller will already exist.
if req.Context().Err() != nil {
log.Warn().Str("user_id", v2device.UserID).Msg(
"client gave up, not creating connection",
)
log.Warn().Msg("client gave up, not creating connection")
return nil, &internal.HandlerError{
StatusCode: 400,
Err: req.Context().Err(),
}
}
userCache, err := h.userCache(v2device.UserID)
userCache, err := h.userCache(token.UserID)
if err != nil {
log.Warn().Err(err).Str("user_id", v2device.UserID).Msg("failed to load user cache")
log.Warn().Err(err).Msg("failed to load user cache")
return nil, &internal.HandlerError{
StatusCode: 500,
Err: err,
@ -366,19 +369,59 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
// because we *either* do the existing check *or* make a new conn. It's important for CreateConn
// to check for an existing connection though, as it's possible for the client to call /sync
// twice for a new connection.
conn, created := h.ConnMap.CreateConn(sync3.ConnID{
DeviceID: deviceID,
}, func() sync3.ConnHandler {
return NewConnState(v2device.UserID, v2device.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.histVec, h.maxPendingEventUpdates)
conn, created := h.ConnMap.CreateConn(connID, func() sync3.ConnHandler {
return NewConnState(token.UserID, token.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.histVec, h.maxPendingEventUpdates)
})
if created {
log.Info().Str("user", v2device.UserID).Str("conn_id", conn.ConnID.String()).Msg("created new connection")
log.Info().Msg("created new connection")
} else {
log.Info().Str("user", v2device.UserID).Str("conn_id", conn.ConnID.String()).Msg("using existing connection")
log.Info().Msg("using existing connection")
}
return conn, nil
}
// identifyUnknownAccessToken handles an access token the proxy has not seen
// before: it asks the homeserver (via /whoami) who owns it, then records the
// token and its device in the DB inside one transaction.
// Returns 401 if the homeserver rejects the token, 502 if the homeserver is
// unreachable/errors, and 500 if the DB writes fail.
func (h *SyncLiveHandler) identifyUnknownAccessToken(accessToken string, logger *zerolog.Logger) (*sync2.Token, *internal.HandlerError) {
	// We don't recognise the given accessToken. Ask the homeserver who owns it.
	userID, deviceID, err := h.V2.WhoAmI(accessToken)
	if err != nil {
		if err == sync2.HTTP401 {
			return nil, &internal.HandlerError{
				StatusCode: 401,
				Err:        fmt.Errorf("/whoami returned HTTP 401"),
			}
		}
		// Fixed: use the request-scoped logger consistently. The original mixed
		// the `logger` parameter with the package-level `log` in this function.
		logger.Warn().Err(err).Msg("failed to get user ID from device ID")
		return nil, &internal.HandlerError{
			StatusCode: http.StatusBadGateway,
			Err:        err,
		}
	}
	var token *sync2.Token
	err = sqlutil.WithTransaction(h.V2Store.DB, func(txn *sqlx.Tx) error {
		// NOTE(review): Insert and InsertDevice appear to use their own DB
		// handles rather than txn, so this "transaction" may not actually cover
		// them — confirm against the sync2 table implementations.
		// Create a brand-new row for this token.
		token, err = h.V2Store.TokensTable.Insert(accessToken, userID, deviceID, time.Now())
		if err != nil {
			logger.Warn().Err(err).Str("user", userID).Str("device", deviceID).Msg("failed to insert v2 token")
			return err
		}
		// Ensure we have a device row for this token.
		err = h.V2Store.DevicesTable.InsertDevice(userID, deviceID)
		if err != nil {
			logger.Warn().Err(err).Str("user", userID).Str("device", deviceID).Msg("failed to insert v2 device")
			return err
		}
		return nil
	})
	if err != nil {
		return nil, &internal.HandlerError{StatusCode: 500, Err: err}
	}
	return token, nil
}
func (h *SyncLiveHandler) CacheForUser(userID string) *caches.UserCache {
c, ok := h.userCaches.Load(userID)
if ok {
@ -551,6 +594,7 @@ func (h *SyncLiveHandler) OnDeviceData(p *pubsub.V2DeviceData) {
ctx, task := internal.StartTask(context.Background(), "OnDeviceData")
defer task.End()
conn := h.ConnMap.Conn(sync3.ConnID{
UserID: p.UserID,
DeviceID: p.DeviceID,
})
if conn == nil {
@ -563,6 +607,7 @@ func (h *SyncLiveHandler) OnDeviceMessages(p *pubsub.V2DeviceMessages) {
ctx, task := internal.StartTask(context.Background(), "OnDeviceMessages")
defer task.End()
conn := h.ConnMap.Conn(sync3.ConnID{
UserID: p.UserID,
DeviceID: p.DeviceID,
})
if conn == nil {
@ -659,6 +704,7 @@ func (h *SyncLiveHandler) OnAccountData(p *pubsub.V2AccountData) {
// OnExpiredToken closes any live sliding-sync connection belonging to the
// user/device named in the expiry notification.
func (h *SyncLiveHandler) OnExpiredToken(p *pubsub.V2ExpiredToken) {
	connID := sync3.ConnID{
		UserID:   p.UserID,
		DeviceID: p.DeviceID,
	}
	h.ConnMap.CloseConn(connID)
}

View File

@ -639,17 +639,10 @@ func TestExpiredAccessToken(t *testing.T) {
})
// now expire the token
v2.invalidateToken(aliceToken)
// now do another request, this should 400 as it expires the session
// now do another request, this should 401
req := sync3.Request{}
req.SetTimeoutMSecs(1)
_, body, statusCode := v3.doV3Request(t, context.Background(), aliceToken, res.Pos, req)
if statusCode != 400 {
t.Fatalf("got %d want 400 : %v", statusCode, string(body))
}
// do a fresh request, this should 401
req = sync3.Request{}
req.SetTimeoutMSecs(1)
_, body, statusCode = v3.doV3Request(t, context.Background(), aliceToken, "", req)
if statusCode != 401 {
t.Fatalf("got %d want 401 : %v", statusCode, string(body))
}

View File

@ -224,7 +224,7 @@ func TestExtensionToDevice(t *testing.T) {
},
})
// 1: check that a fresh sync returns to-device messages
t.Log("1: check that a fresh sync returns to-device messages")
res := v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: map[string]sync3.RequestList{"a": {
Ranges: sync3.SliceRanges{
@ -239,7 +239,7 @@ func TestExtensionToDevice(t *testing.T) {
})
m.MatchResponse(t, res, m.MatchList("a", m.MatchV3Count(0)), m.MatchToDeviceMessages(toDeviceMsgs))
// 2: repeating the fresh sync request returns the same messages (not deleted)
t.Log("2: repeating the fresh sync request returns the same messages (not deleted)")
res = v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: map[string]sync3.RequestList{"a": {
Ranges: sync3.SliceRanges{
@ -254,7 +254,7 @@ func TestExtensionToDevice(t *testing.T) {
})
m.MatchResponse(t, res, m.MatchList("a", m.MatchV3Count(0)), m.MatchToDeviceMessages(toDeviceMsgs))
// 3: update the since token -> no new messages
t.Log("3: update the since token -> no new messages")
res = v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: map[string]sync3.RequestList{"a": {
Ranges: sync3.SliceRanges{
@ -270,7 +270,7 @@ func TestExtensionToDevice(t *testing.T) {
})
m.MatchResponse(t, res, m.MatchList("a", m.MatchV3Count(0)), m.MatchToDeviceMessages([]json.RawMessage{}))
// 4: inject live to-device messages -> receive them only.
t.Log("4: inject live to-device messages -> receive them only.")
sinceBeforeMsgs := res.Extensions.ToDevice.NextBatch
newToDeviceMsgs := []json.RawMessage{
json.RawMessage(`{"sender":"alice","type":"something","content":{"foo":"5"}}`),
@ -296,7 +296,7 @@ func TestExtensionToDevice(t *testing.T) {
})
m.MatchResponse(t, res, m.MatchList("a", m.MatchV3Count(0)), m.MatchToDeviceMessages(newToDeviceMsgs))
// 5: repeating the previous sync request returns the same live to-device messages (retransmit)
t.Log("5: repeating the previous sync request returns the same live to-device messages (retransmit)")
res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{
Lists: map[string]sync3.RequestList{"a": {
Ranges: sync3.SliceRanges{
@ -327,7 +327,7 @@ func TestExtensionToDevice(t *testing.T) {
// this response contains nothing
m.MatchResponse(t, res, m.MatchList("a", m.MatchV3Count(0)), m.MatchToDeviceMessages([]json.RawMessage{}))
// 6: using an old since token does not return to-device messages anymore as they were deleted.
t.Log("6: using an old since token does not return to-device messages anymore as they were deleted.")
res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{
Lists: map[string]sync3.RequestList{"a": {
Ranges: sync3.SliceRanges{
@ -343,7 +343,7 @@ func TestExtensionToDevice(t *testing.T) {
m.MatchResponse(t, res, m.MatchList("a", m.MatchV3Count(0)), m.MatchToDeviceMessages([]json.RawMessage{}))
// live stream and block, then send a to-device msg which should go through immediately
t.Log("7: live stream and block, then send a to-device msg which should go through immediately")
start := time.Now()
go func() {
time.Sleep(500 * time.Millisecond)

View File

@ -25,7 +25,7 @@ func TestSecondPollerFiltersToDevice(t *testing.T) {
defer v2.close()
defer v3.close()
deviceAToken := "DEVICE_A_TOKEN"
v2.addAccount(alice, deviceAToken)
v2.addAccountWithDeviceID(alice, "A", deviceAToken)
v2.queueResponse(deviceAToken, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
@ -39,7 +39,7 @@ func TestSecondPollerFiltersToDevice(t *testing.T) {
// now sync with device B, and check we send the filter up
deviceBToken := "DEVICE_B_TOKEN"
v2.addAccount(alice, deviceBToken)
v2.addAccountWithDeviceID(alice, "B", deviceBToken)
seenInitialRequest := false
v2.CheckRequest = func(userID, token string, req *http.Request) {
if userID != alice || token != deviceBToken {

View File

@ -49,6 +49,7 @@ type testV2Server struct {
CheckRequest func(userID, token string, req *http.Request)
mu *sync.Mutex
tokenToUser map[string]string
tokenToDevice map[string]string
queues map[string]chan sync2.SyncResponse
waiting map[string]*sync.Cond // broadcasts when the server is about to read a blocking input
srv *httptest.Server
@ -56,10 +57,21 @@ type testV2Server struct {
timeToWaitForV2Response time.Duration
}
// Most tests only use a single device per user. Give them this helper so they
// don't have to care about providing a device name.
func (s *testV2Server) addAccount(userID, token string) {
	// To keep our future selves sane while debugging, derive the device name
	// from the mxid localpart: the part before ":", minus the leading sigil,
	// plus a "_device" suffix.
	localpartWithSigil, _, _ := strings.Cut(userID, ":")
	deviceID := localpartWithSigil[1:] + "_device"
	s.addAccountWithDeviceID(userID, deviceID, token)
}
// Tests that use multiple devices for the same user need to be more explicit.
func (s *testV2Server) addAccountWithDeviceID(userID, deviceID, token string) {
s.mu.Lock()
defer s.mu.Unlock()
s.tokenToUser[token] = userID
s.tokenToDevice[token] = deviceID
s.queues[token] = make(chan sync2.SyncResponse, 100)
s.waiting[token] = &sync.Cond{
L: &sync.Mutex{},
@ -77,6 +89,7 @@ func (s *testV2Server) invalidateToken(token string) {
wg.Done()
}
delete(s.tokenToUser, token)
delete(s.tokenToDevice, token)
s.mu.Unlock()
// kick over the connection so the next request 401s and wait till we get said request
@ -97,6 +110,12 @@ func (s *testV2Server) userID(token string) string {
return s.tokenToUser[token]
}
// deviceID returns the device ID registered for the given access token, or
// the empty string if the token is unknown.
func (s *testV2Server) deviceID(token string) string {
	s.mu.Lock()
	device := s.tokenToDevice[token]
	s.mu.Unlock()
	return device
}
func (s *testV2Server) queueResponse(userIDOrToken string, resp sync2.SyncResponse) {
s.mu.Lock()
ch := s.queues[userIDOrToken]
@ -185,6 +204,7 @@ func runTestV2Server(t testutils.TestBenchInterface) *testV2Server {
t.Helper()
server := &testV2Server{
tokenToUser: make(map[string]string),
tokenToDevice: make(map[string]string),
queues: make(map[string]chan sync2.SyncResponse),
waiting: make(map[string]*sync.Cond),
invalidations: make(map[string]func()),
@ -195,7 +215,8 @@ func runTestV2Server(t testutils.TestBenchInterface) *testV2Server {
r.HandleFunc("/_matrix/client/r0/account/whoami", func(w http.ResponseWriter, req *http.Request) {
token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
userID := server.userID(token)
if userID == "" {
deviceID := server.deviceID(token)
if userID == "" || deviceID == "" {
w.WriteHeader(401)
server.mu.Lock()
fn := server.invalidations[token]
@ -206,7 +227,7 @@ func runTestV2Server(t testutils.TestBenchInterface) *testV2Server {
return
}
w.WriteHeader(200)
w.Write([]byte(fmt.Sprintf(`{"user_id":"%s"}`, userID)))
w.Write([]byte(fmt.Sprintf(`{"user_id":"%s","device_id":"%s"}`, userID, deviceID)))
})
r.HandleFunc("/_matrix/client/r0/sync", func(w http.ResponseWriter, req *http.Request) {
token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")