Remember since tokens for each device

Tie it to the device so a since-less v3 sync still pulls in stored pollers
This commit is contained in:
Kegan Dougal 2021-06-09 18:13:55 +01:00
parent 29b17b89e5
commit b0146bb031
7 changed files with 233 additions and 140 deletions

View File

@ -1,61 +0,0 @@
package syncv3
import (
"database/sql"
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
)
// Session is one row of the syncv3_sessions table, mapped by sqlx via the
// `db` struct tags below.
type Session struct {
// ID is the session_id column. The column is a BIGINT filled from
// syncv3_session_id_seq, held here in string form.
ID string `db:"session_id"`
// DeviceID is the device_id column: the device this session belongs to.
DeviceID string `db:"device_id"`
// LastToDeviceACK is the last_to_device_ack column. New sessions are
// inserted with an empty string here.
LastToDeviceACK string `db:"last_to_device_ack"`
}
// Sessions is the database-backed store of Session rows.
type Sessions struct {
// db is the connection handle used for every query issued by this store.
db *sqlx.DB
// NOTE(review): upsertSessionStmt is never assigned or read in the code
// visible here — confirm whether it is still needed.
upsertSessionStmt *sql.Stmt
}
// NewSessions opens a postgres handle for postgresURI, makes sure the session
// sequence and table exist, and returns a store backed by that handle.
// It panics (via the logger, or via MustExec) if the handle cannot be opened
// or the schema statements fail.
func NewSessions(postgresURI string) *Sessions {
	// Idempotent schema: safe to run on every start-up.
	const schema = `
CREATE SEQUENCE IF NOT EXISTS syncv3_session_id_seq;
CREATE TABLE IF NOT EXISTS syncv3_sessions (
session_id BIGINT PRIMARY KEY DEFAULT nextval('syncv3_session_id_seq'),
-- user_id TEXT NOT NULL, (we don't know the user ID from the access token alone, at least in dendrite!)
device_id TEXT NOT NULL,
last_to_device_ack TEXT NOT NULL,
CONSTRAINT syncv3_sessions_unique UNIQUE (device_id, session_id)
);
`
	conn, openErr := sqlx.Open("postgres", postgresURI)
	if openErr != nil {
		log.Panic().Err(openErr).Str("uri", postgresURI).Msg("failed to open SQL DB")
	}
	conn.MustExec(schema)

	store := new(Sessions)
	store.db = conn
	return store
}
// NewSession inserts a fresh session row for deviceID with an empty
// last_to_device_ack and returns it with the database-generated session ID.
func (s *Sessions) NewSession(deviceID string) (*Session, error) {
	const insertSession = "INSERT INTO syncv3_sessions(device_id, last_to_device_ack) VALUES($1,'') RETURNING session_id"

	var sessionID string
	if err := s.db.QueryRow(insertSession, deviceID).Scan(&sessionID); err != nil {
		return nil, err
	}
	created := Session{
		ID:       sessionID,
		DeviceID: deviceID,
	}
	return &created, nil
}
// Session looks up the session identified by sessionID that belongs to
// deviceID.
//
// Returns (nil, nil) when no such row exists. Any other database failure is
// returned to the caller; previously such errors were dropped and a
// zero-value Session was handed back as if the lookup had succeeded.
func (s *Sessions) Session(sessionID, deviceID string) (*Session, error) {
	var result Session
	err := s.db.Get(&result, "SELECT * FROM syncv3_sessions WHERE session_id=$1 AND device_id=$2", sessionID, deviceID)
	switch {
	case err == sql.ErrNoRows:
		// Unknown session for this device: not an error, just absent.
		return nil, nil
	case err != nil:
		return nil, err
	}
	return &result, nil
}

36
sqlutil/sql.go Normal file
View File

@ -0,0 +1,36 @@
package sqlutil
import (
"fmt"
"github.com/jmoiron/sqlx"
)
// WithTransaction begins a transaction on db and hands it to fn.
//
// The transaction is committed when fn returns nil. It is rolled back when fn
// returns an error or panics; a panic is recovered and reported as an error
// rather than propagated. A commit failure is returned to the caller. A
// rollback failure is discarded so that the original error is preserved.
func WithTransaction(db *sqlx.DB, fn func(txn *sqlx.Tx) error) (err error) {
	txn, beginErr := db.Beginx()
	if beginErr != nil {
		return fmt.Errorf("WithTransaction.Begin: %w", beginErr)
	}

	defer func() {
		// recover must be called directly from this deferred function.
		if r := recover(); r != nil && err == nil {
			err = fmt.Errorf("panic: %v", r)
		}
		if err != nil {
			// Keep the error that triggered the rollback, not the rollback's own.
			_ = txn.Rollback()
			return
		}
		if commitErr := txn.Commit(); commitErr != nil {
			err = fmt.Errorf("WithTransaction failed to commit/rollback: %w", commitErr)
		}
	}()

	return fn(txn)
}

View File

@ -7,6 +7,7 @@ import (
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
"github.com/matrix-org/sync-v3/sqlutil"
"github.com/rs/zerolog"
"github.com/tidwall/gjson"
)
@ -130,7 +131,7 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) error {
if len(state) == 0 {
return nil
}
return WithTransaction(a.db, func(txn *sqlx.Tx) error {
return sqlutil.WithTransaction(a.db, func(txn *sqlx.Tx) error {
// Attempt to short-circuit. This has to be done inside a transaction to make sure
// we don't race with multiple calls to Initialise with the same room ID.
snapshotID, err := a.roomsTable.CurrentSnapshotID(txn, roomID)
@ -211,7 +212,7 @@ func (a *Accumulator) Accumulate(roomID string, timeline []json.RawMessage) erro
if len(timeline) == 0 {
return nil
}
return WithTransaction(a.db, func(txn *sqlx.Tx) error {
return sqlutil.WithTransaction(a.db, func(txn *sqlx.Tx) error {
// Insert the events
events := make([]Event, len(timeline))
for i := range events {
@ -321,32 +322,3 @@ func (a *Accumulator) Delta(roomID string, lastEventNID int64, limit int) (event
}
return eventsJSON, int64(events[len(events)-1].NID), nil
}
// WithTransaction begins a transaction on db and hands it to fn.
//
// The transaction is committed when fn returns nil. It is rolled back when fn
// returns an error or panics; a panic is recovered and reported as an error
// rather than propagated. A commit failure is returned to the caller. A
// rollback failure is discarded so that the original error is preserved.
func WithTransaction(db *sqlx.DB, fn func(txn *sqlx.Tx) error) (err error) {
	txn, beginErr := db.Beginx()
	if beginErr != nil {
		return fmt.Errorf("WithTransaction.Begin: %w", beginErr)
	}

	defer func() {
		// recover must be called directly from this deferred function.
		if r := recover(); r != nil && err == nil {
			err = fmt.Errorf("panic: %v", r)
		}
		if err != nil {
			// Keep the error that triggered the rollback, not the rollback's own.
			_ = txn.Rollback()
			return
		}
		if commitErr := txn.Commit(); commitErr != nil {
			err = fmt.Errorf("WithTransaction failed to commit/rollback: %w", commitErr)
		}
	}()

	return fn(txn)
}

View File

@ -6,6 +6,7 @@ import (
"time"
"github.com/matrix-org/sync-v3/state"
"github.com/matrix-org/sync-v3/sync3"
"github.com/rs/zerolog"
)
@ -15,17 +16,19 @@ type Poller struct {
DeviceID string
Client *Client
Accumulator *state.Accumulator
Sessions *sync3.Sessions
// flag set to true when poll() returns due to expired access tokens
Terminated bool
logger zerolog.Logger
}
func NewPoller(authHeader, deviceID string, client *Client, accumulator *state.Accumulator) *Poller {
func NewPoller(authHeader, deviceID string, client *Client, accumulator *state.Accumulator, sessions *sync3.Sessions) *Poller {
return &Poller{
AuthorizationHeader: authHeader,
DeviceID: deviceID,
Client: client,
Accumulator: accumulator,
Sessions: sessions,
Terminated: false,
logger: zerolog.New(os.Stdout).With().Timestamp().Logger().With().Str("device", deviceID).Logger().Output(zerolog.ConsoleWriter{
Out: os.Stderr,
@ -37,7 +40,7 @@ func NewPoller(authHeader, deviceID string, client *Client, accumulator *state.A
// Poll will block forever, repeatedly calling v2 sync. Do this in a goroutine.
// Returns if the access token gets invalidated. Invokes the callback on first success.
func (p *Poller) Poll(since string, callback func()) {
p.logger.Info().Msg("v2 poll loop started")
p.logger.Info().Str("since", since).Msg("v2 poll loop started")
failCount := 0
firstTime := true
for {
@ -63,6 +66,13 @@ func (p *Poller) Poll(since string, callback func()) {
failCount = 0
p.accumulate(resp)
since = resp.NextBatch
// persist the since token (TODO: this could get slow if we hammer the DB too much)
err = p.Sessions.UpdateDeviceSince(p.DeviceID, since)
if err != nil {
// non-fatal
p.logger.Warn().Str("since", since).Err(err).Msg("failed to persist new since value")
}
if firstTime {
firstTime = false
callback()

123
sync3/sessions.go Normal file
View File

@ -0,0 +1,123 @@
package sync3
import (
"database/sql"
"os"
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
"github.com/matrix-org/sync-v3/sqlutil"
"github.com/rs/zerolog"
)
var log = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.ConsoleWriter{
Out: os.Stderr,
TimeFormat: "15:04:05",
})
// A Session represents a single device's sync stream. One device can have many sessions open at
// once. Sessions are created when devices sync without a since token. Sessions are destroyed
// after a configurable period of inactivity.
//
// The `db` tags map the fields onto columns of syncv3_sessions, except Since,
// which lives in syncv3_sessions_v2devices and is shared by all sessions of
// the same device.
type Session struct {
// ID is the session_id column (a sequence-generated BIGINT, held as a string).
ID string `db:"session_id"`
// DeviceID is the device that owns this session.
DeviceID string `db:"device_id"`
// LastToDeviceACK is the last_to_device_ack column; new sessions start with "".
LastToDeviceACK string `db:"last_to_device_ack"`
// Since is the per-device `since` value stored in syncv3_sessions_v2devices;
// it is empty for a device that has no stored value yet.
Since string `db:"since"`
}
// Sessions is the database-backed store for sessions and for the per-device
// since values.
type Sessions struct {
// db is the connection handle used for every query issued by this store.
db *sqlx.DB
// NOTE(review): upsertSessionStmt is never assigned or read in the code
// visible here — confirm whether it is still needed.
upsertSessionStmt *sql.Stmt
}
// NewSessions opens a postgres handle for postgresURI, makes sure the session
// sequence, the sessions table and the per-device since table exist, and
// returns a store backed by that handle.
// It panics (via the logger, or via MustExec) if the handle cannot be opened
// or the schema statements fail.
func NewSessions(postgresURI string) *Sessions {
	// Idempotent schema: safe to run on every start-up.
	const schema = `
CREATE SEQUENCE IF NOT EXISTS syncv3_session_id_seq;
CREATE TABLE IF NOT EXISTS syncv3_sessions (
session_id BIGINT PRIMARY KEY DEFAULT nextval('syncv3_session_id_seq'),
device_id TEXT NOT NULL,
last_to_device_ack TEXT NOT NULL,
CONSTRAINT syncv3_sessions_unique UNIQUE (device_id, session_id)
);
CREATE TABLE IF NOT EXISTS syncv3_sessions_v2devices (
-- user_id TEXT NOT NULL, (we don't know the user ID from the access token alone, at least in dendrite!)
device_id TEXT PRIMARY KEY,
since TEXT NOT NULL
);
`
	conn, openErr := sqlx.Open("postgres", postgresURI)
	if openErr != nil {
		log.Panic().Err(openErr).Str("uri", postgresURI).Msg("failed to open SQL DB")
	}
	conn.MustExec(schema)

	store := new(Sessions)
	store.db = conn
	return store
}
// NewSession creates a session row for deviceID and, within the same
// transaction, makes sure a per-device row exists in
// syncv3_sessions_v2devices.
//
// The returned Session carries the device's stored since value when the
// device was already known, and an empty Since for a brand new device.
func (s *Sessions) NewSession(deviceID string) (*Session, error) {
	const (
		insertSession = "INSERT INTO syncv3_sessions(device_id, last_to_device_ack) VALUES($1,'') RETURNING session_id"
		// DO NOTHING on conflict so an existing since position is never clobbered.
		ensureDevice = `
INSERT INTO syncv3_sessions_v2devices(device_id, since) VALUES($1,$2)
ON CONFLICT (device_id) DO NOTHING`
		selectSince = "SELECT since FROM syncv3_sessions_v2devices WHERE device_id = $1"
	)

	var created *Session
	txErr := sqlutil.WithTransaction(s.db, func(txn *sqlx.Tx) error {
		var sessionID string
		if err := txn.QueryRow(insertSession, deviceID).Scan(&sessionID); err != nil {
			return err
		}

		res, err := txn.Exec(ensureDevice, deviceID, "")
		if err != nil {
			return err
		}

		fresh := &Session{
			ID:       sessionID,
			DeviceID: deviceID,
		}
		// Exactly one inserted row means the device is new and has no since
		// token. Otherwise, or if the row count is unavailable, read the
		// stored value, as a new poller may be started from it.
		inserted, countErr := res.RowsAffected()
		if countErr != nil || inserted != 1 {
			if err := txn.QueryRow(selectSince, deviceID).Scan(&fresh.Since); err != nil {
				return err
			}
		}
		created = fresh
		return nil
	})
	return created, txErr
}
// Session looks up the session identified by sessionID that belongs to
// deviceID, together with the since value stored for that device.
//
// Returns (nil, nil) when no such session exists. Any other database failure
// is returned to the caller; previously such errors were dropped and a
// zero-value Session was handed back as if the lookup had succeeded.
func (s *Sessions) Session(sessionID, deviceID string) (*Session, error) {
	var result Session
	err := s.db.Get(&result,
		// Every column is table-qualified: device_id exists in both joined
		// tables, so an unqualified reference is rejected by PostgreSQL as
		// ambiguous. COALESCE guards against a session whose device row is
		// missing, because the LEFT JOIN would otherwise produce a NULL that
		// cannot be scanned into a Go string.
		`SELECT syncv3_sessions.session_id, syncv3_sessions.device_id, syncv3_sessions.last_to_device_ack,
COALESCE(syncv3_sessions_v2devices.since, '') AS since
FROM syncv3_sessions
LEFT JOIN syncv3_sessions_v2devices
ON syncv3_sessions.device_id = syncv3_sessions_v2devices.device_id
WHERE syncv3_sessions.session_id=$1 AND syncv3_sessions.device_id=$2`,
		sessionID, deviceID,
	)
	switch {
	case err == sql.ErrNoRows:
		// Unknown session for this device: not an error, just absent.
		return nil, nil
	case err != nil:
		return nil, err
	}
	return &result, nil
}
// UpdateDeviceSince overwrites the stored since value for deviceID.
// It returns the error from executing the UPDATE, if any. An UPDATE that
// matches no row is not reported as an error.
func (s *Sessions) UpdateDeviceSince(deviceID, since string) error {
	const updateSince = `UPDATE syncv3_sessions_v2devices SET since = $1 WHERE device_id = $2`
	if _, err := s.db.Exec(updateSince, since, deviceID); err != nil {
		return err
	}
	return nil
}

View File

@ -5,31 +5,29 @@ import (
"strings"
)
// V3_S1_F2-3-4-5-6_V2_s111_222_333_444
// "V3_" $SESSION "_" $FILTERS "_V2_" $V2TOKEN
// V3_S1_F2-3-4-5-6
// "V3_" $SESSION "_" $FILTERS
type Token struct {
SessionID string
FilterIDs []string
V2token string
}
func (t Token) String() string {
filters := strings.Join(t.FilterIDs, "-")
return fmt.Sprintf("V3_S%s_F%s_V2_%s", t.SessionID, filters, t.V2token)
return fmt.Sprintf("V3_S%s_F%s", t.SessionID, filters)
}
func NewSyncToken(since string) (*Token, error) {
segments := strings.SplitN(since, "_", 5)
if len(segments) != 5 {
segments := strings.SplitN(since, "_", 3)
if len(segments) != 3 {
return nil, fmt.Errorf("not a sync v3 token")
}
if segments[0] != "V3" || segments[3] != "V2" {
if segments[0] != "V3" {
return nil, fmt.Errorf("not a sync v3 token: %s", since)
}
filters := strings.TrimPrefix(segments[2], "F")
return &Token{
SessionID: strings.TrimPrefix(segments[1], "S"),
FilterIDs: strings.Split(filters, "-"),
V2token: segments[4],
}, nil
}

91
v3.go
View File

@ -49,7 +49,7 @@ func RunSyncV3Server(destinationServer, bindAddr, postgresDBURI string) {
},
DestinationServer: destinationServer,
},
Sessions: NewSessions(postgresDBURI),
Sessions: sync3.NewSessions(postgresDBURI),
Accumulator: state.NewAccumulator(postgresDBURI),
Pollers: make(map[string]*sync2.Poller),
pollerMu: &sync.Mutex{},
@ -67,9 +67,18 @@ func RunSyncV3Server(destinationServer, bindAddr, postgresDBURI string) {
}
}
// handlerError pairs an underlying error with the HTTP status code that
// should be sent to the client for it.
type handlerError struct {
	// StatusCode is the HTTP status to respond with.
	StatusCode int
	// err is the underlying cause; Error dereferences it, so it must be non-nil.
	err error
}

// Error implements the error interface, rendering "HTTP <status> : <cause>".
func (e *handlerError) Error() string {
	cause := e.err.Error()
	return fmt.Sprintf("HTTP %d : %s", e.StatusCode, cause)
}
type SyncV3Handler struct {
V2 *sync2.Client
Sessions *Sessions
Sessions *sync3.Sessions
Accumulator *state.Accumulator
pollerMu *sync.Mutex
@ -77,16 +86,44 @@ type SyncV3Handler struct {
}
// ServeHTTP implements http.Handler. It delegates to serve and, if serve
// reports a failure, writes that failure's status code followed by its JSON
// error body.
func (h *SyncV3Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	herr := h.serve(w, req)
	if herr == nil {
		// serve has already written the successful response.
		return
	}
	w.WriteHeader(herr.StatusCode)
	w.Write(asJSONError(herr))
}
// serve handles a single /v3/sync request: it resolves (or creates) the
// session for the requesting device, makes sure a poller is running for that
// device, and writes the next sync v3 token as the response body.
//
// It returns nil after writing a 200 response, or a handlerError describing
// the status code and cause when the request cannot be served; in the error
// case nothing has been written to w.
func (h *SyncV3Handler) serve(w http.ResponseWriter, req *http.Request) *handlerError {
	// Get or create a Session
	session, _, herr := h.getOrCreateSession(req)
	if herr != nil {
		return herr
	}
	// Sessions.Session reports an unknown session as (nil, nil), so a stale or
	// forged since token reaches this point with no session and no error.
	// Reject it here instead of dereferencing a nil pointer below.
	if session == nil {
		return &handlerError{StatusCode: 400, err: fmt.Errorf("unknown session")}
	}
	log.Info().Str("session", session.ID).Str("device", session.DeviceID).Msg("recv /v3/sync")

	// make sure we have a poller for this device
	h.ensurePolling(req.Header.Get("Authorization"), session)

	// return data based on filters
	w.WriteHeader(200)
	w.Write([]byte(sync3.Token{
		SessionID: session.ID,
		FilterIDs: []string{},
	}.String()))
	return nil
}
// getOrCreateSession retrieves an existing session if ?since= is set, else makes a new session.
// Returns a session or an error. Returns a token if and only if there is an existing session.
func (h *SyncV3Handler) getOrCreateSession(req *http.Request) (*sync3.Session, *sync3.Token, *handlerError) {
var session *sync3.Session
var tokv3 *sync3.Token
deviceID, err := deviceIDFromRequest(req)
if err != nil {
log.Warn().Err(err).Msg("failed to get device ID from request")
w.WriteHeader(400)
w.Write(asJSONError(err))
return
return nil, nil, &handlerError{400, err}
}
// Get or create a Session
var session *Session
var tokv3 *sync3.Token
sincev3 := req.URL.Query().Get("since")
if sincev3 == "" {
session, err = h.Sessions.NewSession(deviceID)
@ -94,58 +131,36 @@ func (h *SyncV3Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
tokv3, err = sync3.NewSyncToken(sincev3)
if err != nil {
log.Warn().Err(err).Msg("failed to parse sync v3 token")
w.WriteHeader(400)
w.Write(asJSONError(err))
return
return nil, nil, &handlerError{400, err}
}
session, err = h.Sessions.Session(tokv3.SessionID, deviceID)
}
if err != nil {
log.Warn().Err(err).Str("device", deviceID).Msg("failed to ensure Session existed for device")
w.WriteHeader(500)
w.Write(asJSONError(err))
return
return nil, nil, &handlerError{500, err}
}
log.Info().Str("session", session.ID).Str("device", session.DeviceID).Msg("recv /v3/sync")
// map sync v3 token to sync v2 token
var sincev2 string
if tokv3 != nil {
sincev2 = tokv3.V2token
}
// make sure we have a poller for this device
h.ensurePolling(req.Header.Get("Authorization"), session.DeviceID, sincev2)
// return data based on filters
w.WriteHeader(200)
w.Write([]byte(sync3.Token{
V2token: "v2tokengoeshere",
SessionID: session.ID,
FilterIDs: []string{},
}.String()))
return session, tokv3, nil
}
// ensurePolling makes sure there is a poller for this device, making one if need be.
// Blocks until at least 1 sync is done if and only if the poller was just created.
// This ensures that calls to the database will return data.
func (h *SyncV3Handler) ensurePolling(authHeader, deviceID, since string) {
func (h *SyncV3Handler) ensurePolling(authHeader string, session *sync3.Session) {
h.pollerMu.Lock()
poller, ok := h.Pollers[deviceID]
poller, ok := h.Pollers[session.DeviceID]
// either no poller exists or it did but it died
if ok && !poller.Terminated {
h.pollerMu.Unlock()
return
}
// replace the poller
poller = sync2.NewPoller(authHeader, deviceID, h.V2, h.Accumulator)
poller = sync2.NewPoller(authHeader, session.DeviceID, h.V2, h.Accumulator, h.Sessions)
var wg sync.WaitGroup
wg.Add(1)
go poller.Poll(since, func() {
go poller.Poll(session.Since, func() {
wg.Done()
})
h.Pollers[deviceID] = poller
h.Pollers[session.DeviceID] = poller
h.pollerMu.Unlock()
wg.Wait()
}