mirror of https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00

Merge branch 'main' into dmr/debug-from-stable

This commit is contained in: commit b72ad3bded

README.md (12 lines changed)
@@ -2,7 +2,11 @@

Run a sliding sync proxy. An implementation of [MSC3575](https://github.com/matrix-org/matrix-doc/blob/kegan/sync-v3/proposals/3575-sync.md).

Proxy version to MSC API specification:
## Proxy version to MSC API specification

This describes which proxy versions implement which version of the API drafted
in MSC3575. See https://github.com/matrix-org/sliding-sync/releases for the
changes in the proxy itself.

- Version 0.1.x: [2022/04/01](https://github.com/matrix-org/matrix-spec-proposals/blob/615e8f5a7bfe4da813bc2db661ed0bd00bccac20/proposals/3575-sync.md)
  - First release
@@ -21,10 +25,12 @@ Proxy version to MSC API specification:
  - Support for `errcode` when sessions expire.
- Version 0.99.1 [2023/01/20](https://github.com/matrix-org/matrix-spec-proposals/blob/b4b4e7ff306920d2c862c6ff4d245110f6fa5bc7/proposals/3575-sync.md)
  - Preparing for major v1.x release: lists-as-keys support.
- Version 0.99.2 [2024/07/27](https://github.com/matrix-org/matrix-spec-proposals/blob/eab643cb3ca63b03537a260fa343e1fb2d1ee284/proposals/3575-sync.md)
- Version 0.99.2 [2023/03/31](https://github.com/matrix-org/matrix-spec-proposals/blob/eab643cb3ca63b03537a260fa343e1fb2d1ee284/proposals/3575-sync.md)
  - Experimental support for `bump_event_types` when ordering rooms by recency.
  - Support for opting in to extensions on a per-list and per-room basis.
  - Sentry support.
- Version 0.99.3 [2023/05/23](https://github.com/matrix-org/matrix-spec-proposals/blob/4103ee768a4a3e1decee80c2987f50f4c6b3d539/proposals/3575-sync.md)
  - Support for per-list `bump_event_types`.
  - Support for [`conn_id`](https://github.com/matrix-org/matrix-spec-proposals/blob/4103ee768a4a3e1decee80c2987f50f4c6b3d539/proposals/3575-sync.md#concurrent-connections) for distinguishing multiple concurrent connections.

## Usage
@@ -2,6 +2,14 @@ package main

import (
"fmt"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
"strings"
"syscall"
"time"

"github.com/getsentry/sentry-go"
sentryhttp "github.com/getsentry/sentry-go/http"
syncv3 "github.com/matrix-org/sliding-sync"
@@ -10,18 +18,11 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rs/zerolog"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
"strings"
"syscall"
"time"
)

var GitCommit string

const version = "0.99.2"
const version = "0.99.3"

const (
// Required fields
@@ -163,6 +164,8 @@ func main() {

h2, h3 := syncv3.Setup(args[EnvServer], args[EnvDB], args[EnvSecret], syncv3.Opts{
AddPrometheusMetrics: args[EnvPrometheus] != "",
DBMaxConns: 100,
DBConnMaxIdleTime: time.Hour,
})

go h2.StartV2Pollers()
@@ -44,7 +44,7 @@ func NewDeviceDataTable(db *sqlx.DB) *DeviceDataTable {
func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *internal.DeviceData, err error) {
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
var row DeviceDataRow
err = t.db.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, userID, deviceID)
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, userID, deviceID)
if err != nil {
if err == sql.ErrNoRows {
// if there is no device data for this user, it's not an error.
@@ -78,7 +78,7 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in
// the device_data table.
return nil
}
_, err = t.db.Exec(`UPDATE syncv3_device_data SET data=$1 WHERE user_id=$2 AND device_id=$3`, data, userID, deviceID)
_, err = txn.Exec(`UPDATE syncv3_device_data SET data=$1 WHERE user_id=$2 AND device_id=$3`, data, userID, deviceID)
return err
})
return
@@ -94,7 +94,7 @@ func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (pos int64, err error)
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
// select what already exists
var row DeviceDataRow
err = t.db.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, dd.UserID, dd.DeviceID)
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, dd.UserID, dd.DeviceID)
if err != nil && err != sql.ErrNoRows {
return err
}
@@ -119,7 +119,7 @@ func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (pos int64, err error)
if err != nil {
return err
}
err = t.db.QueryRow(
err = txn.QueryRow(
`INSERT INTO syncv3_device_data(user_id, device_id, data) VALUES($1,$2,$3)
ON CONFLICT (user_id, device_id) DO UPDATE SET data=$3, id=nextval('syncv3_device_data_seq') RETURNING id`,
dd.UserID, dd.DeviceID, data,
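The hunks above move the reads and writes inside the `sqlutil.WithTransaction` closure onto the transaction handle (`txn.Get`/`txn.Exec`) instead of the connection pool (`t.db.Get`/`t.db.Exec`), so the select and the subsequent write actually share one transaction. Below is a minimal sketch of that read-modify-write pattern, assuming the repo's `sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error)` helper as used throughout this diff; `example_table` and its columns are hypothetical.

```go
package example

import (
	"database/sql"

	"github.com/jmoiron/sqlx"
	"github.com/matrix-org/sliding-sync/sqlutil"
)

// updateCounter sketches the read-modify-write pattern this commit adopts:
// every statement inside the closure runs on txn, so the SELECT and the
// UPDATE are part of the same transaction. "example_table" is hypothetical.
func updateCounter(db *sqlx.DB, id string) error {
	return sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
		var count int
		// Use txn, not db: db.Get would run on the pool, outside the transaction.
		err := txn.Get(&count, `SELECT count FROM example_table WHERE id=$1`, id)
		if err != nil && err != sql.ErrNoRows {
			return err
		}
		_, err = txn.Exec(`UPDATE example_table SET count=$1 WHERE id=$2`, count+1, id)
		return err // a non-nil error rolls the whole transaction back
	})
}
```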
@@ -317,6 +317,25 @@ func (t *EventTable) LatestEventInRooms(txn *sqlx.Tx, roomIDs []string, highestN
return
}

func (t *EventTable) LatestEventNIDInRooms(roomIDs []string, highestNID int64) (roomToNID map[string]int64, err error) {
// the position (event nid) may be for a random different room, so we need to find the highest nid <= this position for this room
var events []Event
err = t.db.Select(
&events,
`SELECT event_nid, room_id FROM syncv3_events
WHERE event_nid IN (SELECT max(event_nid) FROM syncv3_events WHERE event_nid <= $1 AND room_id = ANY($2) GROUP BY room_id)`,
highestNID, pq.StringArray(roomIDs),
)
if err == sql.ErrNoRows {
err = nil
}
roomToNID = make(map[string]int64)
for _, ev := range events {
roomToNID[ev.RoomID] = ev.NID
}
return
}
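The new `LatestEventNIDInRooms` query picks, for each requested room, the highest event NID that is still at or below the supplied anchor; rooms with no event at or below that position are simply absent from the result map. A hedged usage sketch follows (the room IDs and the anchor value are made up for illustration).

```go
package example

import (
	"fmt"

	"github.com/matrix-org/sliding-sync/state"
)

// printLatestNIDs sketches how the new query is meant to be consumed.
// The room IDs and the anchor NID (12345) are illustrative only.
func printLatestNIDs(table *state.EventTable) error {
	roomToNID, err := table.LatestEventNIDInRooms([]string{"!a:localhost", "!b:localhost"}, 12345)
	if err != nil {
		return err
	}
	for roomID, nid := range roomToNID {
		// nid is the highest event_nid in roomID with event_nid <= 12345;
		// rooms with nothing at or below the anchor do not appear in the map.
		fmt.Println(roomID, nid)
	}
	return nil
}
```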
func (t *EventTable) SelectEventsBetween(txn *sqlx.Tx, roomID string, lowerExclusive, upperInclusive int64, limit int) ([]Event, error) {
var events []Event
err := txn.Select(&events, `SELECT event_nid, event FROM syncv3_events WHERE event_nid > $1 AND event_nid <= $2 AND room_id = $3 ORDER BY event_nid ASC LIMIT $4`,
@@ -419,8 +438,8 @@ func (t *EventTable) SelectClosestPrevBatchByID(roomID string, eventID string) (

// Select the closest prev batch token for the provided event NID. Returns the empty string if there
// is no closest.
func (t *EventTable) SelectClosestPrevBatch(roomID string, eventNID int64) (prevBatch string, err error) {
err = t.db.QueryRow(
func (t *EventTable) SelectClosestPrevBatch(txn *sqlx.Tx, roomID string, eventNID int64) (prevBatch string, err error) {
err = txn.QueryRow(
`SELECT prev_batch FROM syncv3_events WHERE prev_batch IS NOT NULL AND room_id=$1 AND event_nid >= $2 LIMIT 1`, roomID, eventNID,
).Scan(&prevBatch)
if err == sql.ErrNoRows {
@@ -4,8 +4,10 @@ import (
"bytes"
"database/sql"
"fmt"
"reflect"
"testing"

"github.com/jmoiron/sqlx"
"github.com/tidwall/gjson"

"github.com/matrix-org/sliding-sync/sqlutil"
@@ -776,10 +778,14 @@ func TestEventTablePrevBatch(t *testing.T) {
}

assertPrevBatch := func(roomID string, index int, wantPrevBatch string) {
gotPrevBatch, err := table.SelectClosestPrevBatch(roomID, int64(idToNID[events[index].ID]))
if err != nil {
t.Fatalf("failed to SelectClosestPrevBatch: %s", err)
}
var gotPrevBatch string
_ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
gotPrevBatch, err = table.SelectClosestPrevBatch(txn, roomID, int64(idToNID[events[index].ID]))
if err != nil {
t.Fatalf("failed to SelectClosestPrevBatch: %s", err)
}
return nil
})
if wantPrevBatch != "" {
if gotPrevBatch == "" || gotPrevBatch != wantPrevBatch {
t.Fatalf("SelectClosestPrevBatch: got %v want %v", gotPrevBatch, wantPrevBatch)
@@ -871,6 +877,93 @@ func TestRemoveUnsignedTXNID(t *testing.T) {
}
}

func TestLatestEventNIDInRooms(t *testing.T) {
db, close := connectToDB(t)
defer close()
table := NewEventTable(db)

var result map[string]int64
var err error
// Insert the following:
// - Room FIRST: [N]
// - Room SECOND: [N+1, N+2, N+3] (replace)
// - Room THIRD: [N+4] (max)
first := "!FIRST"
second := "!SECOND"
third := "!THIRD"
err = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
result, err = table.Insert(txn, []Event{
{
ID: "$N",
Type: "message",
RoomID: first,
JSON: []byte(`{}`),
},
{
ID: "$N+1",
Type: "message",
RoomID: second,
JSON: []byte(`{}`),
},
{
ID: "$N+2",
Type: "message",
RoomID: second,
JSON: []byte(`{}`),
},
{
ID: "$N+3",
Type: "message",
RoomID: second,
JSON: []byte(`{}`),
},
{
ID: "$N+4",
Type: "message",
RoomID: third,
JSON: []byte(`{}`),
},
}, false)
return err
})
assertNoError(t, err)

testCases := []struct {
roomIDs []string
highestNID int64
wantMap map[string]string
}{
// We should see FIRST=N, SECOND=N+3, THIRD=N+4 when querying LatestEventNIDInRooms with N+4
{
roomIDs: []string{first, second, third},
highestNID: result["$N+4"],
wantMap: map[string]string{
first: "$N", second: "$N+3", third: "$N+4",
},
},
// We should see FIRST=N, SECOND=N+2 when querying LatestEventNIDInRooms with N+2
{
roomIDs: []string{first, second, third},
highestNID: result["$N+2"],
wantMap: map[string]string{
first: "$N", second: "$N+2",
},
},
}
for _, tc := range testCases {
gotRoomToNID, err := table.LatestEventNIDInRooms(tc.roomIDs, int64(tc.highestNID))
assertNoError(t, err)
want := make(map[string]int64) // map event IDs to nids
for roomID, eventID := range tc.wantMap {
want[roomID] = int64(result[eventID])
}
if !reflect.DeepEqual(gotRoomToNID, want) {
t.Errorf("%+v: got %v want %v", tc, gotRoomToNID, want)
}
}

}

func TestEventTableSelectUnknownEventIDs(t *testing.T) {
db, close := connectToDB(t)
defer close()
@@ -31,6 +31,12 @@ type StartupSnapshot struct {
AllJoinedMembers map[string][]string // room_id -> [user_id]
}

type LatestEvents struct {
Timeline []json.RawMessage
PrevBatch string
LatestNID int64
}

type Storage struct {
Accumulator *Accumulator
EventsTable *EventTable
@@ -535,7 +541,7 @@ func (s *Storage) RoomStateAfterEventPosition(ctx context.Context, roomIDs []str
if err != nil {
return fmt.Errorf("failed to form sql query: %s", err)
}
rows, err := s.Accumulator.db.Query(s.Accumulator.db.Rebind(query), args...)
rows, err := txn.Query(txn.Rebind(query), args...)
if err != nil {
return fmt.Errorf("failed to execute query: %s", err)
}
@@ -580,16 +586,16 @@ func (s *Storage) RoomStateAfterEventPosition(ctx context.Context, roomIDs []str
return
}

func (s *Storage) LatestEventsInRooms(userID string, roomIDs []string, to int64, limit int) (map[string][]json.RawMessage, map[string]string, error) {
func (s *Storage) LatestEventsInRooms(userID string, roomIDs []string, to int64, limit int) (map[string]*LatestEvents, error) {
roomIDToRanges, err := s.visibleEventNIDsBetweenForRooms(userID, roomIDs, 0, to)
if err != nil {
return nil, nil, err
return nil, err
}
result := make(map[string][]json.RawMessage, len(roomIDs))
prevBatches := make(map[string]string, len(roomIDs))
result := make(map[string]*LatestEvents, len(roomIDs))
err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
for roomID, ranges := range roomIDToRanges {
var earliestEventNID int64
var latestEventNID int64
var roomEvents []json.RawMessage
// start at the most recent range as we want to return the most recent `limit` events
for i := len(ranges) - 1; i >= 0; i-- {
@@ -604,6 +610,9 @@ func (s *Storage) LatestEventsInRooms(userID string, roomIDs []string, to int64,
}
// keep pushing to the front so we end up with A,B,C
for _, ev := range events {
if latestEventNID == 0 { // set first time and never again
latestEventNID = ev.NID
}
roomEvents = append([]json.RawMessage{ev.JSON}, roomEvents...)
earliestEventNID = ev.NID
if len(roomEvents) >= limit {
@@ -611,19 +620,23 @@ func (s *Storage) LatestEventsInRooms(userID string, roomIDs []string, to int64,
}
}
}
latestEvents := LatestEvents{
LatestNID: latestEventNID,
Timeline: roomEvents,
}
if earliestEventNID != 0 {
// the oldest event needs a prev batch token, so find one now
prevBatch, err := s.EventsTable.SelectClosestPrevBatch(roomID, earliestEventNID)
prevBatch, err := s.EventsTable.SelectClosestPrevBatch(txn, roomID, earliestEventNID)
if err != nil {
return fmt.Errorf("failed to select prev_batch for room %s : %s", roomID, err)
}
prevBatches[roomID] = prevBatch
latestEvents.PrevBatch = prevBatch
}
result[roomID] = roomEvents
result[roomID] = &latestEvents
}
return nil
})
return result, prevBatches, err
return result, err
}

func (s *Storage) visibleEventNIDsBetweenForRooms(userID string, roomIDs []string, from, to int64) (map[string][][2]int64, error) {
@@ -637,7 +650,7 @@ func (s *Storage) visibleEventNIDsBetweenForRooms(userID string, roomIDs []strin
return nil, fmt.Errorf("VisibleEventNIDsBetweenForRooms.SelectEventsWithTypeStateKeyInRooms: %s", err)
}
}
joinTimingsByRoomID, err := s.determineJoinedRoomsFromMemberships(membershipEvents)
joinTimingsAtFromByRoomID, err := s.determineJoinedRoomsFromMemberships(membershipEvents)
if err != nil {
return nil, fmt.Errorf("failed to work out joined rooms for %s at pos %d: %s", userID, from, err)
}
@@ -648,7 +661,7 @@ func (s *Storage) visibleEventNIDsBetweenForRooms(userID string, roomIDs []strin
return nil, fmt.Errorf("failed to load membership events: %s", err)
}

return s.visibleEventNIDsWithData(joinTimingsByRoomID, membershipEvents, userID, from, to)
return s.visibleEventNIDsWithData(joinTimingsAtFromByRoomID, membershipEvents, userID, from, to)
}

// Work out the NID ranges to pull events from for this user. Given a from and to event nid stream position,
@@ -678,7 +691,7 @@ func (s *Storage) visibleEventNIDsBetweenForRooms(userID string, roomIDs []strin
// - For Room E: from=1, to=15 returns { RoomE: [ [3,3], [13,15] ] } (tests invites)
func (s *Storage) VisibleEventNIDsBetween(userID string, from, to int64) (map[string][][2]int64, error) {
// load *ALL* joined rooms for this user at from (inclusive)
joinTimingsByRoomID, err := s.JoinedRoomsAfterPosition(userID, from)
joinTimingsAtFromByRoomID, err := s.JoinedRoomsAfterPosition(userID, from)
if err != nil {
return nil, fmt.Errorf("failed to work out joined rooms for %s at pos %d: %s", userID, from, err)
}
@@ -689,10 +702,10 @@ func (s *Storage) VisibleEventNIDsBetween(userID string, from, to int64) (map[st
return nil, fmt.Errorf("failed to load membership events: %s", err)
}

return s.visibleEventNIDsWithData(joinTimingsByRoomID, membershipEvents, userID, from, to)
return s.visibleEventNIDsWithData(joinTimingsAtFromByRoomID, membershipEvents, userID, from, to)
}

func (s *Storage) visibleEventNIDsWithData(joinTimingsByRoomID map[string]internal.EventMetadata, membershipEvents []Event, userID string, from, to int64) (map[string][][2]int64, error) {
func (s *Storage) visibleEventNIDsWithData(joinTimingsAtFromByRoomID map[string]internal.EventMetadata, membershipEvents []Event, userID string, from, to int64) (map[string][][2]int64, error) {
// load membership events in order and bucket based on room ID
roomIDToLogs := make(map[string][]membershipEvent)
for _, ev := range membershipEvents {
@@ -754,7 +767,7 @@ func (s *Storage) visibleEventNIDsWithData(joinTimingsByRoomID map[string]intern

// For each joined room, perform the algorithm and delete the logs afterwards
result := make(map[string][][2]int64)
for joinedRoomID, _ := range joinTimingsByRoomID {
for joinedRoomID, _ := range joinTimingsAtFromByRoomID {
roomResult := calculateVisibleEventNIDs(true, from, to, roomIDToLogs[joinedRoomID])
result[joinedRoomID] = roomResult
delete(roomIDToLogs, joinedRoomID)
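After this change, `Storage.LatestEventsInRooms` returns a single `map[string]*LatestEvents` instead of two parallel maps, so a caller gets the timeline, the prev_batch token and the latest NID for a room from one value. A hedged sketch of consuming the new return type (the user ID, room ID, position and limit are illustrative only):

```go
package example

import (
	"fmt"

	"github.com/matrix-org/sliding-sync/state"
)

// latestTimeline sketches how a caller consumes the new return type.
// The userID/roomID/position/limit values are made up for illustration.
func latestTimeline(store *state.Storage) error {
	result, err := store.LatestEventsInRooms("@alice:localhost", []string{"!room:localhost"}, 99999, 20)
	if err != nil {
		return err
	}
	if le := result["!room:localhost"]; le != nil {
		// One struct now carries what the two old maps did, plus LatestNID.
		fmt.Println(len(le.Timeline), le.PrevBatch, le.LatestNID)
	}
	return nil
}
```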
@@ -566,10 +566,15 @@ func TestStorageLatestEventsInRoomsPrevBatch(t *testing.T) {
wantPrevBatch := wantPrevBatches[i]
eventNID := idsToNIDs[eventIDs[i]]
// closest batch to the last event in the chunk (latest nid) is always the next prev batch token
pb, err := store.EventsTable.SelectClosestPrevBatch(roomID, eventNID)
if err != nil {
t.Fatalf("failed to SelectClosestPrevBatch: %s", err)
}
var pb string
_ = sqlutil.WithTransaction(store.DB, func(txn *sqlx.Tx) (err error) {
pb, err = store.EventsTable.SelectClosestPrevBatch(txn, roomID, eventNID)
if err != nil {
t.Fatalf("failed to SelectClosestPrevBatch: %s", err)
}
return nil
})

if pb != wantPrevBatch {
t.Fatalf("SelectClosestPrevBatch: got %v want %v", pb, wantPrevBatch)
}
@@ -32,8 +32,8 @@ func NewDevicesTable(db *sqlx.DB) *DevicesTable {

// InsertDevice creates a new devices row with a blank since token if no such row
// exists. Otherwise, it does nothing.
func (t *DevicesTable) InsertDevice(userID, deviceID string) error {
_, err := t.db.Exec(
func (t *DevicesTable) InsertDevice(txn *sqlx.Tx, userID, deviceID string) error {
_, err := txn.Exec(
` INSERT INTO syncv3_sync2_devices(user_id, device_id, since) VALUES($1,$2,$3)
ON CONFLICT (user_id, device_id) DO NOTHING`,
userID, deviceID, "",
@@ -2,6 +2,7 @@ package sync2

import (
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/sqlutil"
"os"
"sort"
"testing"
@@ -41,18 +42,25 @@ func TestDevicesTableSinceColumn(t *testing.T) {
aliceSecret1 := "mysecret1"
aliceSecret2 := "mysecret2"

t.Log("Insert two tokens for Alice.")
aliceToken, err := tokens.Insert(aliceSecret1, alice, aliceDevice, time.Now())
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
aliceToken2, err := tokens.Insert(aliceSecret2, alice, aliceDevice, time.Now())
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
var aliceToken, aliceToken2 *Token
_ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) (err error) {
t.Log("Insert two tokens for Alice.")
aliceToken, err = tokens.Insert(txn, aliceSecret1, alice, aliceDevice, time.Now())
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
aliceToken2, err = tokens.Insert(txn, aliceSecret2, alice, aliceDevice, time.Now())
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}

t.Log("Add a devices row for Alice")
err = devices.InsertDevice(alice, aliceDevice)
t.Log("Add a devices row for Alice")
err = devices.InsertDevice(txn, alice, aliceDevice)
if err != nil {
t.Fatalf("Failed to Insert device: %s", err)
}
return nil
})

t.Log("Pretend we're about to start a poller. Fetch Alice's token along with the since value tracked by the devices table.")
accessToken, since, err := tokens.GetTokenAndSince(alice, aliceDevice, aliceToken.AccessTokenHash)
@@ -104,39 +112,49 @@ func TestTokenForEachDevice(t *testing.T) {
chris := "chris"
chrisDevice := "chris_desktop"

t.Log("Add a device for Alice, Bob and Chris.")
err := devices.InsertDevice(alice, aliceDevice)
if err != nil {
t.Fatalf("InsertDevice returned error: %s", err)
}
err = devices.InsertDevice(bob, bobDevice)
if err != nil {
t.Fatalf("InsertDevice returned error: %s", err)
}
err = devices.InsertDevice(chris, chrisDevice)
if err != nil {
t.Fatalf("InsertDevice returned error: %s", err)
}
_ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
t.Log("Add a device for Alice, Bob and Chris.")
err := devices.InsertDevice(txn, alice, aliceDevice)
if err != nil {
t.Fatalf("InsertDevice returned error: %s", err)
}
err = devices.InsertDevice(txn, bob, bobDevice)
if err != nil {
t.Fatalf("InsertDevice returned error: %s", err)
}
err = devices.InsertDevice(txn, chris, chrisDevice)
if err != nil {
t.Fatalf("InsertDevice returned error: %s", err)
}
return nil
})

t.Log("Mark Alice's device with a since token.")
sinceValue := "s-1-2-3-4"
devices.UpdateDeviceSince(alice, aliceDevice, sinceValue)
err := devices.UpdateDeviceSince(alice, aliceDevice, sinceValue)
if err != nil {
t.Fatalf("UpdateDeviceSince returned error: %s", err)
}

t.Log("Insert 2 tokens for Alice, one for Bob and none for Chris.")
aliceLastSeen1 := time.Now()
_, err = tokens.Insert("alice_secret", alice, aliceDevice, aliceLastSeen1)
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
aliceLastSeen2 := aliceLastSeen1.Add(1 * time.Minute)
aliceToken2, err := tokens.Insert("alice_secret2", alice, aliceDevice, aliceLastSeen2)
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
bobToken, err := tokens.Insert("bob_secret", bob, bobDevice, time.Time{})
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
var aliceToken2, bobToken *Token
_ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
t.Log("Insert 2 tokens for Alice, one for Bob and none for Chris.")
aliceLastSeen1 := time.Now()
_, err = tokens.Insert(txn, "alice_secret", alice, aliceDevice, aliceLastSeen1)
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
aliceLastSeen2 := aliceLastSeen1.Add(1 * time.Minute)
aliceToken2, err = tokens.Insert(txn, "alice_secret2", alice, aliceDevice, aliceLastSeen2)
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
bobToken, err = tokens.Insert(txn, "bob_secret", bob, bobDevice, time.Time{})
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
return nil
})

t.Log("Fetch a token for every device")
gotTokens, err := tokens.TokenForEachDevice(nil)
@@ -12,6 +12,9 @@ import (
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/sqlutil"

"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/sqlutil"

"github.com/getsentry/sentry-go"

"github.com/matrix-org/sliding-sync/internal"
@@ -32,12 +35,11 @@ var logger = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.C
// processing v2 data (as a sync2.V2DataReceiver) and publishing updates (pubsub.Payload to V2Listeners);
// and receiving and processing EnsurePolling events.
type Handler struct {
pMap *sync2.PollerMap
pMap sync2.IPollerMap
v2Store *sync2.Storage
Store *state.Storage
v2Pub pubsub.Notifier
v3Sub *pubsub.V3Sub
client sync2.Client
unreadMap map[string]struct {
Highlight int
Notif int
@@ -53,13 +55,12 @@ type Handler struct {
}

func NewHandler(
connStr string, pMap *sync2.PollerMap, v2Store *sync2.Storage, store *state.Storage, client sync2.Client,
pMap *sync2.PollerMap, v2Store *sync2.Storage, store *state.Storage,
pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool, deviceDataUpdateDuration time.Duration,
) (*Handler, error) {
h := &Handler{
pMap: pMap,
v2Store: v2Store,
client: client,
Store: store,
subSystem: "poller",
unreadMap: make(map[string]struct {
sync2/handler2/handler_test.go — new file (170 lines)

@@ -0,0 +1,170 @@
package handler2_test

import (
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/sqlutil"
"os"
"reflect"
"sync"
"testing"
"time"

"github.com/matrix-org/sliding-sync/pubsub"
"github.com/matrix-org/sliding-sync/state"
"github.com/matrix-org/sliding-sync/sync2"
"github.com/matrix-org/sliding-sync/sync2/handler2"
"github.com/matrix-org/sliding-sync/testutils"
"github.com/rs/zerolog"
)

var postgresURI string

func TestMain(m *testing.M) {
postgresURI = testutils.PrepareDBConnectionString()
exitCode := m.Run()
os.Exit(exitCode)
}

type pollInfo struct {
pid sync2.PollerID
accessToken string
v2since string
isStartup bool
}

type mockPollerMap struct {
calls []pollInfo
}

func (p *mockPollerMap) NumPollers() int {
return 0
}
func (p *mockPollerMap) Terminate() {}

func (p *mockPollerMap) EnsurePolling(pid sync2.PollerID, accessToken, v2since string, isStartup bool, logger zerolog.Logger) {
p.calls = append(p.calls, pollInfo{
pid: pid,
accessToken: accessToken,
v2since: v2since,
isStartup: isStartup,
})
}
func (p *mockPollerMap) assertCallExists(t *testing.T, pi pollInfo) {
for _, c := range p.calls {
if reflect.DeepEqual(pi, c) {
return
}
}
t.Fatalf("assertCallExists: did not find %+v", pi)
}

type mockPub struct {
calls []pubsub.Payload
mu *sync.Mutex
waiters map[string][]chan struct{}
}

func newMockPub() *mockPub {
return &mockPub{
mu: &sync.Mutex{},
waiters: make(map[string][]chan struct{}),
}
}

// Notify chanName that there is a new payload p. Return an error if we failed to send the notification.
func (p *mockPub) Notify(chanName string, payload pubsub.Payload) error {
p.calls = append(p.calls, payload)
p.mu.Lock()
for _, ch := range p.waiters[payload.Type()] {
close(ch)
}
p.waiters[payload.Type()] = nil // don't re-notify for 2nd+ payload
p.mu.Unlock()
return nil
}

func (p *mockPub) WaitForPayloadType(t string) chan struct{} {
ch := make(chan struct{})
p.mu.Lock()
p.waiters[t] = append(p.waiters[t], ch)
p.mu.Unlock()
return ch
}

func (p *mockPub) DoWait(t *testing.T, errMsg string, ch chan struct{}) {
select {
case <-ch:
return
case <-time.After(time.Second):
t.Fatalf("DoWait: timed out waiting: %s", errMsg)
}
}

// Close is called when we should stop listening.
func (p *mockPub) Close() error { return nil }

type mockSub struct{}

// Begin listening on this channel with this callback starting from this position. Blocks until Close() is called.
func (s *mockSub) Listen(chanName string, fn func(p pubsub.Payload)) error { return nil }

// Close the listener. No more callbacks should fire.
func (s *mockSub) Close() error { return nil }

func assertNoError(t *testing.T, err error) {
t.Helper()
if err == nil {
return
}
t.Fatalf("assertNoError: %v", err)
}

// Test that if you call EnsurePolling you get back V2InitialSyncComplete down pubsub and the poller
// map is called correctly
func TestHandlerFreshEnsurePolling(t *testing.T) {
store := state.NewStorage(postgresURI)
v2Store := sync2.NewStore(postgresURI, "secret")
pMap := &mockPollerMap{}
pub := newMockPub()
sub := &mockSub{}
h, err := handler2.NewHandler(pMap, v2Store, store, pub, sub, false)
assertNoError(t, err)
alice := "@alice:localhost"
deviceID := "ALICE"
token := "aliceToken"

var tok *sync2.Token
sqlutil.WithTransaction(v2Store.DB, func(txn *sqlx.Tx) error {
// the device and token needs to already exist prior to EnsurePolling
err = v2Store.DevicesTable.InsertDevice(txn, alice, deviceID)
assertNoError(t, err)
tok, err = v2Store.TokensTable.Insert(txn, token, alice, deviceID, time.Now())
assertNoError(t, err)
return nil
})

payloadInitialSyncComplete := pubsub.V2InitialSyncComplete{
UserID: alice,
DeviceID: deviceID,
}
ch := pub.WaitForPayloadType(payloadInitialSyncComplete.Type())
// ask the handler to start polling
h.EnsurePolling(&pubsub.V3EnsurePolling{
UserID: alice,
DeviceID: deviceID,
AccessTokenHash: tok.AccessTokenHash,
})
pub.DoWait(t, "didn't see V2InitialSyncComplete", ch)

// make sure we polled with the token i.e it did a db hit
pMap.assertCallExists(t, pollInfo{
pid: sync2.PollerID{
UserID: alice,
DeviceID: deviceID,
},
accessToken: token,
v2since: "",
isStartup: false,
})

}
@@ -59,6 +59,12 @@ type V2DataReceiver interface {
OnExpiredToken(ctx context.Context, accessTokenHash, userID, deviceID string)
}

type IPollerMap interface {
EnsurePolling(pid PollerID, accessToken, v2since string, isStartup bool, logger zerolog.Logger)
NumPollers() int
Terminate()
}

// PollerMap is a map of device ID to Poller
type PollerMap struct {
v2Client Client
@@ -508,7 +514,8 @@ func (p *poller) poll(ctx context.Context, s *pollLoopState) error {
}
if err != nil {
// check if temporary
if statusCode != 401 {
isFatal := statusCode == 401 || statusCode == 403
if !isFatal {
p.logger.Warn().Int("code", statusCode).Err(err).Msg("Poller: sync v2 poll returned temporary error")
s.failCount += 1
return nil
@@ -1,10 +1,11 @@
package sync2

import (
"os"

"github.com/getsentry/sentry-go"
"github.com/jmoiron/sqlx"
"github.com/rs/zerolog"
"os"
)

var logger = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.ConsoleWriter{
@@ -171,10 +171,10 @@ func (t *TokensTable) TokenForEachDevice(txn *sqlx.Tx) (tokens []TokenForPoller,
}

// Insert a new token into the table.
func (t *TokensTable) Insert(plaintextToken, userID, deviceID string, lastSeen time.Time) (*Token, error) {
func (t *TokensTable) Insert(txn *sqlx.Tx, plaintextToken, userID, deviceID string, lastSeen time.Time) (*Token, error) {
hashedToken := hashToken(plaintextToken)
encToken := t.encrypt(plaintextToken)
_, err := t.db.Exec(
_, err := txn.Exec(
`INSERT INTO syncv3_sync2_tokens(token_hash, token_encrypted, user_id, device_id, last_seen)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (token_hash) DO NOTHING;`,
@@ -1,6 +1,8 @@
package sync2

import (
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/sqlutil"
"testing"
"time"
)
@@ -26,27 +28,31 @@ func TestTokensTable(t *testing.T) {
aliceSecret1 := "mysecret1"
aliceToken1FirstSeen := time.Now()

// Test a single token
t.Log("Insert a new token from Alice.")
aliceToken, err := tokens.Insert(aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
var aliceToken, reinsertedToken *Token
_ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) (err error) {
// Test a single token
t.Log("Insert a new token from Alice.")
aliceToken, err = tokens.Insert(txn, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}

t.Log("The returned Token struct should have been populated correctly.")
assertEqualTokens(t, tokens, aliceToken, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
t.Log("The returned Token struct should have been populated correctly.")
assertEqualTokens(t, tokens, aliceToken, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)

t.Log("Reinsert the same token.")
reinsertedToken, err := tokens.Insert(aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
t.Log("Reinsert the same token.")
reinsertedToken, err = tokens.Insert(txn, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
return nil
})

t.Log("This should yield an equal Token struct.")
assertEqualTokens(t, tokens, reinsertedToken, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)

t.Log("Try to mark Alice's token as being used after an hour.")
err = tokens.MaybeUpdateLastSeen(aliceToken, aliceToken1FirstSeen.Add(time.Hour))
err := tokens.MaybeUpdateLastSeen(aliceToken, aliceToken1FirstSeen.Add(time.Hour))
if err != nil {
t.Fatalf("Failed to update last seen: %s", err)
}
@@ -74,17 +80,20 @@ func TestTokensTable(t *testing.T) {
}
assertEqualTokens(t, tokens, fetchedToken, aliceSecret1, alice, aliceDevice, aliceToken1LastSeen)

// Test a second token for Alice
t.Log("Insert a second token for Alice.")
aliceSecret2 := "mysecret2"
aliceToken2FirstSeen := aliceToken1LastSeen.Add(time.Minute)
aliceToken2, err := tokens.Insert(aliceSecret2, alice, aliceDevice, aliceToken2FirstSeen)
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
_ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
// Test a second token for Alice
t.Log("Insert a second token for Alice.")
aliceSecret2 := "mysecret2"
aliceToken2FirstSeen := aliceToken1LastSeen.Add(time.Minute)
aliceToken2, err := tokens.Insert(txn, aliceSecret2, alice, aliceDevice, aliceToken2FirstSeen)
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}

t.Log("The returned Token struct should have been populated correctly.")
assertEqualTokens(t, tokens, aliceToken2, aliceSecret2, alice, aliceDevice, aliceToken2FirstSeen)
t.Log("The returned Token struct should have been populated correctly.")
assertEqualTokens(t, tokens, aliceToken2, aliceSecret2, alice, aliceDevice, aliceToken2FirstSeen)
return nil
})
}

func TestDeletingTokens(t *testing.T) {
@@ -94,11 +103,15 @@ func TestDeletingTokens(t *testing.T) {

t.Log("Insert a new token from Alice.")
accessToken := "mytoken"
token, err := tokens.Insert(accessToken, "@bob:builders.com", "device", time.Time{})
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}

var token *Token
err := sqlutil.WithTransaction(db, func(txn *sqlx.Tx) (err error) {
token, err = tokens.Insert(txn, accessToken, "@bob:builders.com", "device", time.Time{})
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
return nil
})
t.Log("We should be able to fetch this token without error.")
_, err = tokens.Token(accessToken)
if err != nil {
@@ -58,7 +58,7 @@ var logger = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.C
// Dispatcher for new events.
type GlobalCache struct {
// LoadJoinedRoomsOverride allows tests to mock out the behaviour of LoadJoinedRooms.
LoadJoinedRoomsOverride func(userID string) (pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimings map[string]internal.EventMetadata, err error)
LoadJoinedRoomsOverride func(userID string) (pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimings map[string]internal.EventMetadata, latestNIDs map[string]int64, err error)

// inserts are done by v2 poll loops, selects are done by v3 request threads
// there are lots of overlapping keys as many users (threads) can be joined to the same room (key)
@@ -135,23 +135,37 @@ func (c *GlobalCache) copyRoom(roomID string) *internal.RoomMetadata {
// The two maps returned by this function have exactly the same set of keys. Each is nil
// iff a non-nil error is returned.
// TODO: remove with LoadRoomState?
// FIXME: return args are a mess
func (c *GlobalCache) LoadJoinedRooms(ctx context.Context, userID string) (
pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimingByRoomID map[string]internal.EventMetadata, err error,
pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimingByRoomID map[string]internal.EventMetadata,
latestNIDs map[string]int64, err error,
) {
if c.LoadJoinedRoomsOverride != nil {
return c.LoadJoinedRoomsOverride(userID)
}
initialLoadPosition, err := c.store.LatestEventNID()
if err != nil {
return 0, nil, nil, err
return 0, nil, nil, nil, err
}
joinTimingByRoomID, err = c.store.JoinedRoomsAfterPosition(userID, initialLoadPosition)
if err != nil {
return 0, nil, nil, err
return 0, nil, nil, nil, err
}
roomIDs := make([]string, len(joinTimingByRoomID))
i := 0
for roomID := range joinTimingByRoomID {
roomIDs[i] = roomID
i++
}

latestNIDs, err = c.store.EventsTable.LatestEventNIDInRooms(roomIDs, initialLoadPosition)
if err != nil {
return 0, nil, nil, nil, err
}

// TODO: no guarantee that this state is the same as latest unless called in a dispatcher loop
rooms := c.LoadRoomsFromMap(ctx, joinTimingByRoomID)
return initialLoadPosition, rooms, joinTimingByRoomID, nil
return initialLoadPosition, rooms, joinTimingByRoomID, latestNIDs, nil
}

func (c *GlobalCache) LoadStateEvent(ctx context.Context, roomID string, loadPosition int64, evType, stateKey string) json.RawMessage {
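`LoadJoinedRooms` now returns a fifth value, `latestNIDs`, mapping each joined room to its latest event NID at the load position; callers that do not need it can discard it with `_`, as `UserCache.OnRegistered` does in the next file. A hedged sketch of the new call shape (the user ID is illustrative, and the construction of the cache is omitted):

```go
package example

import (
	"context"
	"fmt"

	"github.com/matrix-org/sliding-sync/sync3/caches"
)

// seedPositions sketches the new 5-value return of LoadJoinedRooms. The user
// ID is made up; callers that only want metadata can blank out latestNIDs.
func seedPositions(ctx context.Context, gc *caches.GlobalCache) (map[string]int64, error) {
	pos, rooms, joinTimings, latestNIDs, err := gc.LoadJoinedRooms(ctx, "@alice:localhost")
	if err != nil {
		return nil, err
	}
	fmt.Println("loaded", len(rooms), "rooms at position", pos, "with", len(joinTimings), "join timings")
	// latestNIDs anchors each room's subsequent timeline/state loads to what
	// existed at pos, which is how ConnState.load seeds its loadPositions map.
	return latestNIDs, nil
}
```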
@@ -42,9 +42,9 @@ type UserRoomData struct {
HighlightCount int
Invite *InviteData

// these fields are set by LazyLoadTimelines and are per-function call, and are not persisted in-memory.
RequestedPrevBatch string
RequestedTimeline []json.RawMessage
// this field is set by LazyLoadTimelines and is per-function call, and is not persisted in-memory.
// The zero value of this safe to use (0 latest nid, no prev batch, no timeline).
RequestedLatestEvents state.LatestEvents

// TODO: should Canonicalised really be in RoomConMetadata? It's only set in SetRoom AFAICS
CanonicalisedName string // stripped leading symbols like #, all in lower case
@@ -218,7 +218,7 @@ func (c *UserCache) Unsubscribe(id int) {
func (c *UserCache) OnRegistered(ctx context.Context) error {
// select all spaces the user is a part of to seed the cache correctly. This has to be done in
// the OnRegistered callback which has locking guarantees. This is why...
_, joinedRooms, joinTimings, err := c.globalCache.LoadJoinedRooms(ctx, c.UserID)
_, joinedRooms, joinTimings, _, err := c.globalCache.LoadJoinedRooms(ctx, c.UserID)
if err != nil {
return fmt.Errorf("failed to load joined rooms: %s", err)
}
@@ -295,7 +295,7 @@ func (c *UserCache) LazyLoadTimelines(ctx context.Context, loadPos int64, roomID
return c.LazyRoomDataOverride(loadPos, roomIDs, maxTimelineEvents)
}
result := make(map[string]UserRoomData)
roomIDToEvents, roomIDToPrevBatch, err := c.store.LatestEventsInRooms(c.UserID, roomIDs, loadPos, maxTimelineEvents)
roomIDToLatestEvents, err := c.store.LatestEventsInRooms(c.UserID, roomIDs, loadPos, maxTimelineEvents)
if err != nil {
logger.Err(err).Strs("rooms", roomIDs).Msg("failed to get LatestEventsInRooms")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
@@ -303,16 +303,14 @@ func (c *UserCache) LazyLoadTimelines(ctx context.Context, loadPos int64, roomID
}
c.roomToDataMu.Lock()
for _, requestedRoomID := range roomIDs {
events := roomIDToEvents[requestedRoomID]
latestEvents := roomIDToLatestEvents[requestedRoomID]
urd, ok := c.roomToData[requestedRoomID]
if !ok {
urd = NewUserRoomData()
}
urd.RequestedTimeline = events
if len(events) > 0 {
urd.RequestedPrevBatch = roomIDToPrevBatch[requestedRoomID]
if latestEvents != nil {
urd.RequestedLatestEvents = *latestEvents
}

result[requestedRoomID] = urd
}
c.roomToDataMu.Unlock()
@@ -30,8 +30,15 @@ type ConnState struct {
// "is the user joined to this room?" whereas subscriptions in muxedReq are untrusted.
roomSubscriptions map[string]sync3.RoomSubscription // room_id -> subscription

// TODO: remove this as it is unreliable when you have concurrent updates
loadPosition int64
// This is some event NID which is used to anchor any requests for room data from the database
// to their per-room latest NIDs. It does this by selecting the latest NID for each requested room
// where the NID is <= this anchor value. Note that there are no ordering guarantees here: it's
// possible for the anchor to be higher than room X's latest NID and for this connection to have
// not yet seen room X's latest NID (it'll be sitting in the live buffer). This is why it's important
// that ConnState DOES NOT ignore events based on this value - it must ignore events based on the real
// load position for the room.
// If this value is negative or 0, it means that this connection has not been loaded yet.
anchorLoadPosition int64
// roomID -> latest load pos
loadPositions map[string]int64
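The comment above distinguishes the global `anchorLoadPosition`, used only to anchor database loads, from the per-room `loadPositions`, which decide whether a live event has already been accounted for. A small sketch of that per-room check, mirroring the logic `connStateLive.processLiveUpdate` adopts later in this diff; the helper name is hypothetical:

```go
package example

// alreadyLoaded is a hypothetical helper mirroring the per-room check this
// commit introduces: a live event is skipped only when we have a load position
// for its room (lp > 0) and the event sits strictly behind that position.
func alreadyLoaded(loadPositions map[string]int64, roomID string, eventNID int64) bool {
	lp, ok := loadPositions[roomID]
	if !ok || lp <= 0 {
		// Unknown room: we cannot claim the event was already loaded, so process it.
		return false
	}
	return eventNID < lp
}
```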
@@ -59,7 +66,7 @@ func NewConnState(
userCache: userCache,
userID: userID,
deviceID: deviceID,
loadPosition: -1,
anchorLoadPosition: -1,
loadPositions: make(map[string]int64),
roomSubscriptions: make(map[string]sync3.RoomSubscription),
lists: sync3.NewInternalRequestLists(),
@@ -73,6 +80,8 @@ func NewConnState(
ConnState: cs,
updates: make(chan caches.Update, maxPendingEventUpdates),
}
// subscribe for updates before loading. We risk seeing dupes but that's fine as load positions
// will stop us double-processing.
cs.userCacheID = cs.userCache.Subsribe(cs)
return cs
}
@@ -89,10 +98,13 @@ func NewConnState(
// - load() bases its current state based on the latest position, which includes processing of these N events.
// - post load() we read N events, processing them a 2nd time.
func (s *ConnState) load(ctx context.Context, req *sync3.Request) error {
initialLoadPosition, joinedRooms, joinTimings, err := s.globalCache.LoadJoinedRooms(ctx, s.userID)
initialLoadPosition, joinedRooms, joinTimings, loadPositions, err := s.globalCache.LoadJoinedRooms(ctx, s.userID)
if err != nil {
return err
}
for roomID, pos := range loadPositions {
s.loadPositions[roomID] = pos
}
rooms := make([]sync3.RoomConnMetadata, len(joinedRooms))
i := 0
for _, metadata := range joinedRooms {
@@ -145,16 +157,21 @@ func (s *ConnState) load(ctx context.Context, req *sync3.Request) error {
for _, r := range rooms {
s.lists.SetRoom(r)
}
s.loadPosition = initialLoadPosition
s.anchorLoadPosition = initialLoadPosition
return nil
}

// OnIncomingRequest is guaranteed to be called sequentially (it's protected by a mutex in conn.go)
func (s *ConnState) OnIncomingRequest(ctx context.Context, cid sync3.ConnID, req *sync3.Request, isInitial bool, start time.Time) (*sync3.Response, error) {
if s.loadPosition == -1 {
if s.anchorLoadPosition <= 0 {
// load() needs no ctx so drop it
_, region := internal.StartSpan(ctx, "load")
s.load(ctx, req)
err := s.load(ctx, req)
if err != nil {
// in practice this means DB hit failures. If we try again later maybe it'll work, and we will because
// anchorLoadPosition is unset.
logger.Err(err).Str("conn", cid.String()).Msg("failed to load initial data")
}
region.End()
}
setupTime := time.Since(start)
@@ -165,19 +182,19 @@ func (s *ConnState) OnIncomingRequest(ctx context.Context, cid sync3.ConnID, req
// onIncomingRequest is a callback which fires when the client makes a request to the server. Whilst each request may
// be on their own goroutine, the requests are linearised for us by Conn so it is safe to modify ConnState without
// additional locking mechanisms.
func (s *ConnState) onIncomingRequest(ctx context.Context, req *sync3.Request, isInitial bool) (*sync3.Response, error) {
func (s *ConnState) onIncomingRequest(reqCtx context.Context, req *sync3.Request, isInitial bool) (*sync3.Response, error) {
start := time.Now()
// ApplyDelta works fine if s.muxedReq is nil
var delta *sync3.RequestDelta
s.muxedReq, delta = s.muxedReq.ApplyDelta(req)
internal.Logf(ctx, "connstate", "new subs=%v unsubs=%v num_lists=%v", len(delta.Subs), len(delta.Unsubs), len(delta.Lists))
internal.Logf(reqCtx, "connstate", "new subs=%v unsubs=%v num_lists=%v", len(delta.Subs), len(delta.Unsubs), len(delta.Lists))
for key, l := range delta.Lists {
listData := ""
if l.Curr != nil {
listDataBytes, _ := json.Marshal(l.Curr)
listData = string(listDataBytes)
}
internal.Logf(ctx, "connstate", "list[%v] prev_empty=%v curr=%v", key, l.Prev == nil, listData)
internal.Logf(reqCtx, "connstate", "list[%v] prev_empty=%v curr=%v", key, l.Prev == nil, listData)
}
for roomID, sub := range s.muxedReq.RoomSubscriptions {
internal.Logf(ctx, "connstate", "room sub[%v] %v", roomID, sub)
@@ -187,20 +204,20 @@ func (s *ConnState) onIncomingRequest(ctx context.Context, req *sync3.Request, i
// for it to mix together
builder := NewRoomsBuilder()
// works out which rooms are subscribed to but doesn't pull room data
s.buildRoomSubscriptions(ctx, builder, delta.Subs, delta.Unsubs)
s.buildRoomSubscriptions(reqCtx, builder, delta.Subs, delta.Unsubs)
// works out how rooms get moved about but doesn't pull room data
respLists := s.buildListSubscriptions(ctx, builder, delta.Lists)
respLists := s.buildListSubscriptions(reqCtx, builder, delta.Lists)

// pull room data and set changes on the response
response := &sync3.Response{
Rooms: s.buildRooms(ctx, builder.BuildSubscriptions()), // pull room data
Rooms: s.buildRooms(reqCtx, builder.BuildSubscriptions()), // pull room data
Lists: respLists,
}

// Handle extensions AFTER processing lists as extensions may need to know which rooms the client
// is being notified about (e.g. for room account data)
ctx, region := internal.StartSpan(ctx, "extensions")
response.Extensions = s.extensionsHandler.Handle(ctx, s.muxedReq.Extensions, extensions.Context{
extCtx, region := internal.StartSpan(reqCtx, "extensions")
response.Extensions = s.extensionsHandler.Handle(extCtx, s.muxedReq.Extensions, extensions.Context{
UserID: s.userID,
DeviceID: s.deviceID,
RoomIDToTimeline: response.RoomIDsToTimelineEventIDs(),
@@ -218,8 +235,8 @@ func (s *ConnState) onIncomingRequest(ctx context.Context, req *sync3.Request, i
}

// do live tracking if we have nothing to tell the client yet
ctx, region = internal.StartSpan(ctx, "liveUpdate")
s.live.liveUpdate(ctx, req, s.muxedReq.Extensions, isInitial, response)
updateCtx, region := internal.StartSpan(reqCtx, "liveUpdate")
s.live.liveUpdate(updateCtx, req, s.muxedReq.Extensions, isInitial, response)
region.End()

// counts are AFTER events are applied, hence after liveUpdate
@@ -232,7 +249,7 @@ func (s *ConnState) onIncomingRequest(ctx context.Context, req *sync3.Request, i
// Add membership events for users sending typing notifications. We do this after live update
// and initial room loading code so we LL room members in all cases.
if response.Extensions.Typing != nil && response.Extensions.Typing.HasData(isInitial) {
s.lazyLoadTypingMembers(ctx, response)
s.lazyLoadTypingMembers(reqCtx, response)
}
return response, nil
}
@@ -495,7 +512,7 @@ func (s *ConnState) lazyLoadTypingMembers(ctx context.Context, response *sync3.R
continue
}
// load the state event
memberEvent := s.globalCache.LoadStateEvent(ctx, roomID, s.loadPosition, "m.room.member", typingUserID.Str)
memberEvent := s.globalCache.LoadStateEvent(ctx, roomID, s.loadPositions[roomID], "m.room.member", typingUserID.Str)
if memberEvent != nil {
room.RequiredState = append(room.RequiredState, memberEvent)
s.lazyCache.AddUser(roomID, typingUserID.Str)
@@ -512,15 +529,20 @@ func (s *ConnState) getInitialRoomData(ctx context.Context, roomSub sync3.RoomSu
ctx, span := internal.StartSpan(ctx, "getInitialRoomData")
defer span.End()
rooms := make(map[string]sync3.Room, len(roomIDs))
// We want to grab the user room data and the room metadata for each room ID.
roomIDToUserRoomData := s.userCache.LazyLoadTimelines(ctx, s.loadPosition, roomIDs, int(roomSub.TimelineLimit))
// We want to grab the user room data and the room metadata for each room ID. We use the globally
// highest NID we've seen to act as an anchor for the request. This anchor does not guarantee that
// events returned here have already been seen - the position is not globally ordered - so because
// room A has a position of 6 and B has 7 (so the highest is 7) does not mean that this connection
// has seen 6, as concurrent room updates cause A and B to race. This is why we then go through the
// response to this call to assign new load positions for each room.
roomIDToUserRoomData := s.userCache.LazyLoadTimelines(ctx, s.anchorLoadPosition, roomIDs, int(roomSub.TimelineLimit))
roomMetadatas := s.globalCache.LoadRooms(ctx, roomIDs...)
// prepare lazy loading data structures, txn IDs
roomToUsersInTimeline := make(map[string][]string, len(roomIDToUserRoomData))
roomToTimeline := make(map[string][]json.RawMessage)
for roomID, urd := range roomIDToUserRoomData {
set := make(map[string]struct{})
for _, ev := range urd.RequestedTimeline {
for _, ev := range urd.RequestedLatestEvents.Timeline {
set[gjson.GetBytes(ev, "sender").Str] = struct{}{}
}
userIDs := make([]string, len(set))
@@ -530,11 +552,22 @@ func (s *ConnState) getInitialRoomData(ctx context.Context, roomSub sync3.RoomSu
i++
}
roomToUsersInTimeline[roomID] = userIDs
roomToTimeline[roomID] = urd.RequestedTimeline
roomToTimeline[roomID] = urd.RequestedLatestEvents.Timeline
// remember what we just loaded so if we see these events down the live stream we know to ignore them.
// This means that requesting a direct room subscription causes the connection to jump ahead to whatever
// is in the database at the time of the call, rather than gradually converging by consuming live data.
// This is fine, so long as we jump ahead on a per-room basis. We need to make sure (ideally) that the
// room state is also pinned to the load position here, else you could see weird things in individual
// responses such as an updated room.name without the associated m.room.name event (though this will
// come through on the next request -> it converges to the right state so it isn't critical).
s.loadPositions[roomID] = urd.RequestedLatestEvents.LatestNID
}
roomToTimeline = s.userCache.AnnotateWithTransactionIDs(ctx, s.userID, s.deviceID, roomToTimeline)
rsm := roomSub.RequiredStateMap(s.userID)
roomIDToState := s.globalCache.LoadRoomState(ctx, roomIDs, s.loadPosition, rsm, roomToUsersInTimeline)
// by reusing the same global load position anchor here, we can be sure that the state returned here
// matches the timeline we loaded earlier - the race conditions happen around pubsub updates and not
// the events table itself, so whatever position is picked based on this anchor is immutable.
roomIDToState := s.globalCache.LoadRoomState(ctx, roomIDs, s.anchorLoadPosition, rsm, roomToUsersInTimeline)
if roomIDToState == nil { // e.g no required_state
roomIDToState = make(map[string][]json.RawMessage)
}
@@ -572,7 +605,7 @@ func (s *ConnState) getInitialRoomData(ctx context.Context, roomSub sync3.RoomSu
IsDM: userRoomData.IsDM,
JoinedCount: metadata.JoinCount,
InvitedCount: &metadata.InviteCount,
PrevBatch: userRoomData.RequestedPrevBatch,
PrevBatch: userRoomData.RequestedLatestEvents.PrevBatch,
}
}
@@ -128,9 +128,13 @@ func (s *connStateLive) processLiveUpdate(ctx context.Context, up caches.Update,
    roomEventUpdate, _ := up.(*caches.RoomEventUpdate)
    // if this is a room event update we may not want to process this if the event nid is < loadPos,
    // as that means we have already taken it into account
-   if roomEventUpdate != nil && !roomEventUpdate.EventData.AlwaysProcess && roomEventUpdate.EventData.NID < s.loadPosition {
-       internal.Logf(ctx, "liveUpdate", "not process update %v < %v", roomEventUpdate.EventData.NID, s.loadPosition)
-       return false
+   if roomEventUpdate != nil && !roomEventUpdate.EventData.AlwaysProcess {
+       // check if we should skip this update. Do we know of this room (lp > 0) and if so, is this event
+       // behind what we've processed before?
+       lp := s.loadPositions[roomEventUpdate.RoomID()]
+       if lp > 0 && roomEventUpdate.EventData.NID < lp {
+           return false
+       }
    }

    // for initial rooms e.g a room comes into the window or a subscription now exists
@@ -161,9 +165,6 @@ func (s *connStateLive) processLiveUpdate(ctx context.Context, up caches.Update,
    rooms := s.buildRooms(ctx, builder.BuildSubscriptions())
    for roomID, room := range rooms {
        response.Rooms[roomID] = room
-       // remember what point we snapshotted this room, incase we see live events which we have
-       // already snapshotted here.
-       s.loadPositions[roomID] = s.loadPosition
    }

    // TODO: find a better way to determine if the triggering event should be included e.g ask the lists?
@@ -195,7 +196,7 @@ func (s *connStateLive) processLiveUpdate(ctx context.Context, up caches.Update,
    sender := roomEventUpdate.EventData.Sender
    if s.lazyCache.IsLazyLoading(roomID) && !s.lazyCache.IsSet(roomID, sender) {
        // load the state event
-       memberEvent := s.globalCache.LoadStateEvent(context.Background(), roomID, s.loadPosition, "m.room.member", sender)
+       memberEvent := s.globalCache.LoadStateEvent(context.Background(), roomID, s.loadPositions[roomID], "m.room.member", sender)
        if memberEvent != nil {
            r.RequiredState = append(r.RequiredState, memberEvent)
            s.lazyCache.AddUser(roomID, sender)
@@ -296,12 +297,9 @@ func (s *connStateLive) processGlobalUpdates(ctx context.Context, builder *Rooms
        })
    }

-   if isRoomEventUpdate {
-       // TODO: we should do this check before lists.SetRoom
-       if roomEventUpdate.EventData.NID <= s.loadPosition {
-           return // if this update is in the past then ignore it
-       }
-       s.loadPosition = roomEventUpdate.EventData.NID
+   // update the anchor for this new event
+   if isRoomEventUpdate && roomEventUpdate.EventData.NID > s.anchorLoadPosition {
+       s.anchorLoadPosition = roomEventUpdate.EventData.NID
    }
    return
}

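The bookkeeping in the two hunks above reduces to one small pattern: when a room is loaded from the database, remember the latest event NID that was served for that room, then drop any live event at or below that mark unless it is flagged to always be processed. The following standalone sketch illustrates that idea only; the tracker type, its method names and the room IDs are made up for illustration and are not the proxy's actual types.

package main

import "fmt"

// tracker is an illustrative stand-in for per-connection state: it remembers,
// per room, the highest event NID that has already been sent to the client.
type tracker struct {
    loadPositions map[string]int64 // roomID -> latest NID served
}

// snapshot records the NID we just loaded from the database for this room.
func (t *tracker) snapshot(roomID string, latestNID int64) {
    t.loadPositions[roomID] = latestNID
}

// shouldProcess reports whether a live event should be sent on, mirroring the
// "lp > 0 && NID < lp" check in the diff: rooms we have not tracked yet are
// always processed, and events already served (lower NID) are skipped.
func (t *tracker) shouldProcess(roomID string, nid int64, alwaysProcess bool) bool {
    if alwaysProcess {
        return true
    }
    lp := t.loadPositions[roomID]
    return !(lp > 0 && nid < lp)
}

func main() {
    tr := &tracker{loadPositions: make(map[string]int64)}
    tr.snapshot("!room:example.org", 552) // initial load jumped ahead to NID 552
    fmt.Println(tr.shouldProcess("!room:example.org", 540, false)) // false: already served
    fmt.Println(tr.shouldProcess("!room:example.org", 553, false)) // true: newer than the snapshot
    fmt.Println(tr.shouldProcess("!other:example.org", 1, false))  // true: room not tracked yet
}
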
@@ -48,7 +48,7 @@ func mockLazyRoomOverride(loadPos int64, roomIDs []string, maxTimelineEvents int
    result := make(map[string]caches.UserRoomData)
    for _, roomID := range roomIDs {
        u := caches.NewUserRoomData()
-       u.RequestedTimeline = []json.RawMessage{[]byte(`{}`)}
+       u.RequestedLatestEvents.Timeline = []json.RawMessage{[]byte(`{}`)}
        result[roomID] = u
    }
    return result
@@ -84,7 +84,7 @@ func TestConnStateInitial(t *testing.T) {
        roomB.RoomID: {userID},
        roomC.RoomID: {userID},
    })
-   globalCache.LoadJoinedRoomsOverride = func(userID string) (pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimings map[string]internal.EventMetadata, err error) {
+   globalCache.LoadJoinedRoomsOverride = func(userID string) (pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimings map[string]internal.EventMetadata, loadPositions map[string]int64, err error) {
        return 1, map[string]*internal.RoomMetadata{
            roomA.RoomID: &roomA,
            roomB.RoomID: &roomB,
@@ -93,7 +93,7 @@ func TestConnStateInitial(t *testing.T) {
            roomA.RoomID: {NID: 123, Timestamp: 123},
            roomB.RoomID: {NID: 456, Timestamp: 456},
            roomC.RoomID: {NID: 780, Timestamp: 789},
-       }, nil
+       }, nil, nil
    }
    userCache := caches.NewUserCache(userID, globalCache, nil, &NopTransactionFetcher{})
    dispatcher.Register(context.Background(), userCache.UserID, userCache)
@@ -102,7 +102,7 @@ func TestConnStateInitial(t *testing.T) {
        result := make(map[string]caches.UserRoomData)
        for _, roomID := range roomIDs {
            u := caches.NewUserRoomData()
-           u.RequestedTimeline = []json.RawMessage{timeline[roomID]}
+           u.RequestedLatestEvents.Timeline = []json.RawMessage{timeline[roomID]}
            result[roomID] = u
        }
        return result
@@ -256,7 +256,7 @@ func TestConnStateMultipleRanges(t *testing.T) {
            roomID: {userID},
        })
    }
-   globalCache.LoadJoinedRoomsOverride = func(userID string) (pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimings map[string]internal.EventMetadata, err error) {
+   globalCache.LoadJoinedRoomsOverride = func(userID string) (pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimings map[string]internal.EventMetadata, loadPositions map[string]int64, err error) {
        roomMetadata := make(map[string]*internal.RoomMetadata)
        joinTimings = make(map[string]internal.EventMetadata)
        for i, r := range rooms {
@@ -266,7 +266,7 @@ func TestConnStateMultipleRanges(t *testing.T) {
                Timestamp: 123456,
            }
        }
-       return 1, roomMetadata, joinTimings, nil
+       return 1, roomMetadata, joinTimings, nil, nil
    }
    userCache := caches.NewUserCache(userID, globalCache, nil, &NopTransactionFetcher{})
    userCache.LazyRoomDataOverride = mockLazyRoomOverride
@@ -433,7 +433,7 @@ func TestBumpToOutsideRange(t *testing.T) {
        roomC.RoomID: {userID},
        roomD.RoomID: {userID},
    })
-   globalCache.LoadJoinedRoomsOverride = func(userID string) (pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimings map[string]internal.EventMetadata, err error) {
+   globalCache.LoadJoinedRoomsOverride = func(userID string) (pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimings map[string]internal.EventMetadata, loadPositions map[string]int64, err error) {
        return 1, map[string]*internal.RoomMetadata{
            roomA.RoomID: &roomA,
            roomB.RoomID: &roomB,
@@ -444,7 +444,7 @@ func TestBumpToOutsideRange(t *testing.T) {
            roomB.RoomID: {NID: 2, Timestamp: 2},
            roomC.RoomID: {NID: 3, Timestamp: 3},
            roomD.RoomID: {NID: 4, Timestamp: 4},
-       }, nil
+       }, nil, nil

    }
    userCache := caches.NewUserCache(userID, globalCache, nil, &NopTransactionFetcher{})
@@ -537,7 +537,7 @@ func TestConnStateRoomSubscriptions(t *testing.T) {
        roomC.RoomID: testutils.NewEvent(t, "m.room.message", userID, map[string]interface{}{"body": "c"}),
        roomD.RoomID: testutils.NewEvent(t, "m.room.message", userID, map[string]interface{}{"body": "d"}),
    }
-   globalCache.LoadJoinedRoomsOverride = func(userID string) (pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimings map[string]internal.EventMetadata, err error) {
+   globalCache.LoadJoinedRoomsOverride = func(userID string) (pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimings map[string]internal.EventMetadata, loadPositions map[string]int64, err error) {
        return 1, map[string]*internal.RoomMetadata{
            roomA.RoomID: &roomA,
            roomB.RoomID: &roomB,
@@ -548,14 +548,14 @@ func TestConnStateRoomSubscriptions(t *testing.T) {
            roomB.RoomID: {NID: 2, Timestamp: 2},
            roomC.RoomID: {NID: 3, Timestamp: 3},
            roomD.RoomID: {NID: 4, Timestamp: 4},
-       }, nil
+       }, nil, nil
    }
    userCache := caches.NewUserCache(userID, globalCache, nil, &NopTransactionFetcher{})
    userCache.LazyRoomDataOverride = func(loadPos int64, roomIDs []string, maxTimelineEvents int) map[string]caches.UserRoomData {
        result := make(map[string]caches.UserRoomData)
        for _, roomID := range roomIDs {
            u := caches.NewUserRoomData()
-           u.RequestedTimeline = []json.RawMessage{timeline[roomID]}
+           u.RequestedLatestEvents.Timeline = []json.RawMessage{timeline[roomID]}
            result[roomID] = u
        }
        return result

@@ -1,6 +1,5 @@
package handler

-import "C"
import (
    "context"
    "database/sql"
@@ -67,7 +66,7 @@ type SyncLiveHandler struct {
}

func NewSync3Handler(
-   store *state.Storage, storev2 *sync2.Storage, v2Client sync2.Client, postgresDBURI, secret string,
+   store *state.Storage, storev2 *sync2.Storage, v2Client sync2.Client, secret string,
    pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool, maxPendingEventUpdates int,
) (*SyncLiveHandler, error) {
    logger.Info().Msg("creating handler")
@@ -225,8 +224,7 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
    if req.ContentLength != 0 {
        defer req.Body.Close()
        if err := json.NewDecoder(req.Body).Decode(&requestBody); err != nil {
-           log.Err(err).Msg("failed to read/decode request body")
-           internal.GetSentryHubFromContextOrDefault(req.Context()).CaptureException(err)
+           log.Warn().Err(err).Msg("failed to read/decode request body")
            return &internal.HandlerError{
                StatusCode: 400,
                Err: err,
@@ -339,6 +337,8 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
// When this function returns, the connection is alive and active.

func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Request, containsPos bool) (*sync3.Conn, *internal.HandlerError) {
+   taskCtx, task := internal.StartTask(req.Context(), "setupConnection")
+   defer task.End()
    var conn *sync3.Conn
    // Extract an access token
    accessToken, err := internal.ExtractAccessToken(req)
@@ -371,6 +371,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
    }
    log := hlog.FromRequest(req).With().Str("user", token.UserID).Str("device", token.DeviceID).Logger()
    internal.SetRequestContextUserID(req.Context(), token.UserID, token.DeviceID)
+   internal.Logf(taskCtx, "setupConnection", "identified access token as user=%s device=%s", token.UserID, token.DeviceID)

    // Record the fact that we've recieved a request from this token
    err = h.V2Store.TokensTable.MaybeUpdateLastSeen(token, time.Now())
@@ -396,8 +397,8 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
        return nil, internal.ExpiredSessionError()
    }

-   log.Trace().Msg("checking poller exists and is running")
    pid := sync2.PollerID{UserID: token.UserID, DeviceID: token.DeviceID}
+   log.Trace().Any("pid", pid).Msg("checking poller exists and is running")
    h.EnsurePoller.EnsurePolling(req.Context(), pid, token.AccessTokenHash)
    log.Trace().Msg("poller exists and is running")
    // this may take a while so if the client has given up (e.g timed out) by this point, just stop.
@@ -458,14 +459,14 @@ func (h *SyncLiveHandler) identifyUnknownAccessToken(accessToken string, logger
    var token *sync2.Token
    err = sqlutil.WithTransaction(h.V2Store.DB, func(txn *sqlx.Tx) error {
        // Create a brand-new row for this token.
-       token, err = h.V2Store.TokensTable.Insert(accessToken, userID, deviceID, time.Now())
+       token, err = h.V2Store.TokensTable.Insert(txn, accessToken, userID, deviceID, time.Now())
        if err != nil {
            logger.Warn().Err(err).Str("user", userID).Str("device", deviceID).Msg("failed to insert v2 token")
            return err
        }

        // Ensure we have a device row for this token.
-       err = h.V2Store.DevicesTable.InsertDevice(userID, deviceID)
+       err = h.V2Store.DevicesTable.InsertDevice(txn, userID, deviceID)
        if err != nil {
            log.Warn().Err(err).Str("user", userID).Str("device", deviceID).Msg("failed to insert v2 device")
            return err

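The last hunk above threads the transaction handle through both table methods, so the token row and its device row are written atomically: either both inserts commit or neither does. A rough sketch of the same begin/rollback/commit pattern follows, using only the sqlx API; the withTransaction helper, the connection string and the table names are hypothetical stand-ins, not the project's sqlutil.WithTransaction or its schema.

package main

import (
    "fmt"

    "github.com/jmoiron/sqlx"
    _ "github.com/lib/pq" // Postgres driver; any database/sql driver works for the sketch
)

// withTransaction is a hypothetical stand-in for a WithTransaction-style helper:
// begin a transaction, run the callback, roll back on error, commit on success.
func withTransaction(db *sqlx.DB, fn func(txn *sqlx.Tx) error) error {
    txn, err := db.Beginx()
    if err != nil {
        return err
    }
    if err := fn(txn); err != nil {
        txn.Rollback() // best effort; the callback's error is what we report
        return err
    }
    return txn.Commit()
}

func main() {
    db, err := sqlx.Open("postgres", "postgres://localhost/example?sslmode=disable") // assumed DSN
    if err != nil {
        panic(err)
    }
    // Insert a token row and a device row atomically: if the second insert fails,
    // the first is rolled back too. Table and column names are made up.
    err = withTransaction(db, func(txn *sqlx.Tx) error {
        if _, err := txn.Exec(`INSERT INTO example_tokens(token, user_id) VALUES ($1, $2)`, "t0ken", "@alice:example.org"); err != nil {
            return err
        }
        _, err := txn.Exec(`INSERT INTO example_devices(user_id, device_id) VALUES ($1, $2)`, "@alice:example.org", "DEVICE1")
        return err
    })
    fmt.Println("transaction error:", err)
}
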
103    tests-integration/db_test.go    Normal file
@@ -0,0 +1,103 @@
package syncv3

import (
    "encoding/json"
    "fmt"
    "sync"
    "testing"
    "time"

    syncv3 "github.com/matrix-org/sliding-sync"
    "github.com/matrix-org/sliding-sync/sync2"
    "github.com/matrix-org/sliding-sync/sync3"
    "github.com/matrix-org/sliding-sync/testutils"
    "github.com/matrix-org/sliding-sync/testutils/m"
)

// Test that the proxy works fine with low max conns. Low max conns can be a problem
// if a request A needs 2 conns to respond and that blocks forward progress on the server,
// and the request can only obtain 1 conn.
func TestMaxDBConns(t *testing.T) {
    pqString := testutils.PrepareDBConnectionString()
    // setup code
    v2 := runTestV2Server(t)
    opts := syncv3.Opts{
        DBMaxConns: 1,
    }
    v3 := runTestServer(t, v2, pqString, opts)
    defer v2.close()
    defer v3.close()

    testMaxDBConns := func() {
        // make N users and drip feed some events, make sure they are all seen
        numUsers := 5
        var wg sync.WaitGroup
        wg.Add(numUsers)
        for i := 0; i < numUsers; i++ {
            go func(n int) {
                defer wg.Done()
                userID := fmt.Sprintf("@maxconns_%d:localhost", n)
                token := fmt.Sprintf("maxconns_%d", n)
                roomID := fmt.Sprintf("!maxconns_%d", n)
                v2.addAccount(t, userID, token)
                v2.queueResponse(userID, sync2.SyncResponse{
                    Rooms: sync2.SyncRoomsResponse{
                        Join: v2JoinTimeline(roomEvents{
                            roomID: roomID,
                            state:  createRoomState(t, userID, time.Now()),
                        }),
                    },
                })
                // initial sync
                res := v3.mustDoV3Request(t, token, sync3.Request{
                    Lists: map[string]sync3.RequestList{"a": {
                        Ranges: sync3.SliceRanges{
                            [2]int64{0, 1},
                        },
                        RoomSubscription: sync3.RoomSubscription{
                            TimelineLimit: 1,
                        },
                    }},
                })
                t.Logf("user %s has done an initial /sync OK", userID)
                m.MatchResponse(t, res, m.MatchList("a", m.MatchV3Count(1), m.MatchV3Ops(
                    m.MatchV3SyncOp(0, 0, []string{roomID}),
                )), m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
                    roomID: {
                        m.MatchJoinCount(1),
                    },
                }))
                // drip feed and get update
                dripMsg := testutils.NewEvent(t, "m.room.message", userID, map[string]interface{}{
                    "msgtype": "m.text",
                    "body":    "drip drip",
                })
                v2.queueResponse(userID, sync2.SyncResponse{
                    Rooms: sync2.SyncRoomsResponse{
                        Join: v2JoinTimeline(roomEvents{
                            roomID: roomID,
                            events: []json.RawMessage{
                                dripMsg,
                            },
                        }),
                    },
                })
                t.Logf("user %s has queued the drip", userID)
                v2.waitUntilEmpty(t, userID)
                t.Logf("user %s poller has received the drip", userID)
                res = v3.mustDoV3RequestWithPos(t, token, res.Pos, sync3.Request{})
                m.MatchResponse(t, res, m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
                    roomID: {
                        m.MatchRoomTimelineMostRecent(1, []json.RawMessage{dripMsg}),
                    },
                }))
                t.Logf("user %s has received the drip", userID)
            }(i)
        }
        wg.Wait()
    }

    testMaxDBConns()
    v3.restart(t, v2, pqString, opts)
    testMaxDBConns()
}

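The comment at the top of this new test names the failure mode being guarded against: with a pool capped at one connection, a request that holds a connection while waiting to acquire another can never make progress. A tiny illustration of how the standard database/sql pool serialises checkouts under SetMaxOpenConns(1) follows; the driver import and DSN are assumptions for the sketch and need a reachable Postgres to run.

package main

import (
    "context"
    "database/sql"
    "fmt"
    "time"

    _ "github.com/lib/pq" // assumed driver; any database/sql driver behaves the same way
)

func main() {
    db, err := sql.Open("postgres", "postgres://localhost/example?sslmode=disable") // assumed DSN
    if err != nil {
        panic(err)
    }
    db.SetMaxOpenConns(1) // the pool now hands out at most one connection

    ctx := context.Background()
    conn, err := db.Conn(ctx) // checks out the only connection
    if err != nil {
        panic(err)
    }

    go func() {
        time.Sleep(time.Second)
        conn.Close() // return the connection to the pool after a second
    }()

    start := time.Now()
    conn2, err := db.Conn(ctx) // blocks until the first connection is released
    if err != nil {
        panic(err)
    }
    defer conn2.Close()
    fmt.Printf("second checkout waited %v\n", time.Since(start)) // roughly one second
}
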
@@ -291,11 +291,11 @@ func (s *testV3Server) close() {
    s.h2.Teardown()
}

-func (s *testV3Server) restart(t *testing.T, v2 *testV2Server, pq string) {
+func (s *testV3Server) restart(t *testing.T, v2 *testV2Server, pq string, opts ...syncv3.Opts) {
    t.Helper()
    log.Printf("restarting server")
    s.close()
-   ss := runTestServer(t, v2, pq)
+   ss := runTestServer(t, v2, pq, opts...)
    // replace all the fields which will be close()d to ensure we don't leak
    s.srv = ss.srv
    s.h2 = ss.h2
@@ -366,20 +366,22 @@ func runTestServer(t testutils.TestBenchInterface, v2Server *testV2Server, postg
    //tests often repeat requests. To ensure tests remain fast, reduce the spam protection limits.
    sync3.SpamProtectionInterval = time.Millisecond

-   metricsEnabled := false
-   maxPendingEventUpdates := 200
+   combinedOpts := syncv3.Opts{
+       TestingSynchronousPubsub: true, // critical to avoid flakey tests
+       AddPrometheusMetrics:     false,
+       MaxPendingEventUpdates:   200,
+   }
    if len(opts) > 0 {
-       metricsEnabled = opts[0].AddPrometheusMetrics
-       if opts[0].MaxPendingEventUpdates > 0 {
-           maxPendingEventUpdates = opts[0].MaxPendingEventUpdates
+       opt := opts[0]
+       combinedOpts.AddPrometheusMetrics = opt.AddPrometheusMetrics
+       combinedOpts.DBConnMaxIdleTime = opt.DBConnMaxIdleTime
+       combinedOpts.DBMaxConns = opt.DBMaxConns
+       if opt.MaxPendingEventUpdates > 0 {
+           combinedOpts.MaxPendingEventUpdates = opt.MaxPendingEventUpdates
            handler.BufferWaitTime = 5 * time.Millisecond
        }
    }
-   h2, h3 := syncv3.Setup(v2Server.url(), postgresConnectionString, os.Getenv("SYNCV3_SECRET"), syncv3.Opts{
-       TestingSynchronousPubsub: true, // critical to avoid flakey tests
-       MaxPendingEventUpdates:   maxPendingEventUpdates,
-       AddPrometheusMetrics:     metricsEnabled,
-   })
+   h2, h3 := syncv3.Setup(v2Server.url(), postgresConnectionString, os.Getenv("SYNCV3_SECRET"), combinedOpts)
    // for ease of use we don't start v2 pollers at startup in tests
    r := mux.NewRouter()
    r.Use(hlog.NewHandler(logger))

20    v3.go
@@ -9,6 +9,7 @@ import (
    "time"

    "github.com/getsentry/sentry-go"
+   "github.com/jmoiron/sqlx"

    "github.com/gorilla/mux"
    "github.com/matrix-org/sliding-sync/internal"
@@ -36,6 +37,9 @@ type Opts struct {
    // if true, publishing messages will block until the consumer has consumed it.
    // Assumes a single producer and a single consumer.
    TestingSynchronousPubsub bool
+
+   DBMaxConns        int
+   DBConnMaxIdleTime time.Duration
}

type server struct {
@@ -75,6 +79,18 @@ func Setup(destHomeserver, postgresURI, secret string, opts Opts) (*handler2.Han
    }
    store := state.NewStorage(postgresURI)
    storev2 := sync2.NewStore(postgresURI, secret)
+   for _, db := range []*sqlx.DB{store.DB, storev2.DB} {
+       if opts.DBMaxConns > 0 {
+           // https://github.com/go-sql-driver/mysql#important-settings
+           // "db.SetMaxIdleConns() is recommended to be set same to db.SetMaxOpenConns(). When it is smaller
+           // than SetMaxOpenConns(), connections can be opened and closed much more frequently than you expect."
+           db.SetMaxOpenConns(opts.DBMaxConns)
+           db.SetMaxIdleConns(opts.DBMaxConns)
+       }
+       if opts.DBConnMaxIdleTime > 0 {
+           db.SetConnMaxIdleTime(opts.DBConnMaxIdleTime)
+       }
+   }
    bufferSize := 50
    deviceDataUpdateFrequency := time.Second
    if opts.TestingSynchronousPubsub {
@@ -88,14 +104,14 @@ func Setup(destHomeserver, postgresURI, secret string, opts Opts) (*handler2.Han

    pMap := sync2.NewPollerMap(v2Client, opts.AddPrometheusMetrics)
    // create v2 handler
-   h2, err := handler2.NewHandler(postgresURI, pMap, storev2, store, v2Client, pubSub, pubSub, opts.AddPrometheusMetrics, deviceDataUpdateFrequency)
+   h2, err := handler2.NewHandler(pMap, storev2, store, pubSub, pubSub, opts.AddPrometheusMetrics, deviceDataUpdateFrequency)
    if err != nil {
        panic(err)
    }
    pMap.SetCallbacks(h2)

    // create v3 handler
-   h3, err := handler.NewSync3Handler(store, storev2, v2Client, postgresURI, secret, pubSub, pubSub, opts.AddPrometheusMetrics, opts.MaxPendingEventUpdates)
+   h3, err := handler.NewSync3Handler(store, storev2, v2Client, secret, pubSub, pubSub, opts.AddPrometheusMetrics, opts.MaxPendingEventUpdates)
    if err != nil {
        panic(err)
    }
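The two new Opts fields added here feed directly into the standard database/sql pool knobs on both sqlx handles (SetMaxOpenConns together with SetMaxIdleConns, and SetConnMaxIdleTime). A sketch of how an embedding program might set them when calling syncv3.Setup follows; the numeric values are arbitrary examples, and the environment variable names are only assumed to match the proxy's usual SYNCV3_* configuration.

package main

import (
    "os"
    "time"

    syncv3 "github.com/matrix-org/sliding-sync"
)

func main() {
    // Example values only: cap both DB pools at 50 connections and recycle
    // connections that have sat idle for an hour.
    h2, h3 := syncv3.Setup(
        os.Getenv("SYNCV3_SERVER"), // upstream homeserver URL
        os.Getenv("SYNCV3_DB"),     // postgres connection string
        os.Getenv("SYNCV3_SECRET"), // shared secret
        syncv3.Opts{
            DBMaxConns:        50,
            DBConnMaxIdleTime: time.Hour,
        },
    )
    _ = h2 // v2 poller handler
    _ = h3 // v3 sliding sync handler; wire these into an HTTP server as needed
}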