diff --git a/state/storage.go b/state/storage.go
index bf83823..0d0faa7 100644
--- a/state/storage.go
+++ b/state/storage.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"os"
 	"strings"
+	"time"
 
 	"golang.org/x/exp/slices"
 
@@ -69,6 +70,8 @@ type Storage struct {
 	ReceiptTable     *ReceiptTable
 	DB               *sqlx.DB
 	MaxTimelineLimit int
+	shutdownCh       chan struct{}
+	shutdown         bool
 }
 
 func NewStorage(postgresURI string) *Storage {
@@ -104,6 +107,7 @@ func NewStorageWithDB(db *sqlx.DB, addPrometheusMetrics bool) *Storage {
 		ReceiptTable:     NewReceiptTable(db),
 		DB:               db,
 		MaxTimelineLimit: 50,
+		shutdownCh:       make(chan struct{}),
 	}
 }
 
@@ -758,6 +762,50 @@ func (s *Storage) LatestEventsInRooms(userID string, roomIDs []string, to int64,
 	return result, err
 }
 
+// RemoveInaccessibleStateSnapshots removes state snapshots which cannot be accessed by clients.
+// The latest MaxTimelineLimit snapshots must be kept, +1 for the current state. This handles the
+// worst case where all MaxTimelineLimit timeline events are state events and hence each event
+// makes a new snapshot. We can safely delete all older snapshots, as clients cannot reach them:
+// the proxy does not handle historical state (deferring to the homeserver for that).
+func (s *Storage) RemoveInaccessibleStateSnapshots() error {
+	numToKeep := s.MaxTimelineLimit + 1
+	// Create a CTE which ranks each snapshot so we can figure out which snapshots to delete,
+	// then execute the delete using the CTE.
+	//
+	// A per-room version of this query:
+	// WITH ranked_snapshots AS (
+	//     SELECT
+	//         snapshot_id,
+	//         room_id,
+	//         ROW_NUMBER() OVER (PARTITION BY room_id ORDER BY snapshot_id DESC) AS row_num
+	//     FROM syncv3_snapshots
+	// )
+	// DELETE FROM syncv3_snapshots WHERE snapshot_id IN(
+	//     SELECT snapshot_id FROM ranked_snapshots WHERE row_num > 51 AND room_id='!....'
+	// );
+	awfulQuery := fmt.Sprintf(`WITH ranked_snapshots AS (
+		SELECT
+			snapshot_id,
+			room_id,
+			ROW_NUMBER() OVER (PARTITION BY room_id ORDER BY snapshot_id DESC) AS row_num
+		FROM
+			syncv3_snapshots
+	)
+	DELETE FROM syncv3_snapshots USING ranked_snapshots
+	WHERE syncv3_snapshots.snapshot_id = ranked_snapshots.snapshot_id
+	AND ranked_snapshots.row_num > %d;`, numToKeep)
+
+	result, err := s.DB.Exec(awfulQuery)
+	if err != nil {
+		return fmt.Errorf("failed to RemoveInaccessibleStateSnapshots: Exec: %w", err)
+	}
+	rowsAffected, err := result.RowsAffected()
+	if err == nil {
+		logger.Info().Int64("rows_affected", rowsAffected).Msg("RemoveInaccessibleStateSnapshots: deleted rows")
+	}
+	return nil
+}
+
 func (s *Storage) GetClosestPrevBatch(roomID string, eventNID int64) (prevBatch string) {
 	var err error
 	sqlutil.WithTransaction(s.DB, func(txn *sqlx.Tx) error {
@@ -1024,6 +1072,34 @@ func (s *Storage) AllJoinedMembers(txn *sqlx.Tx, tempTableName string) (joinedMe
 	return joinedMembers, metadata, nil
 }
 
+// Cleaner periodically deletes stale data from the database. Every n it removes
+// transaction IDs older than the boundary time and state snapshots which clients
+// can no longer reach. It blocks until Teardown is called, so it is intended to
+// be run in its own goroutine.
+func (s *Storage) Cleaner(n time.Duration) {
+Loop:
+	for {
+		select {
+		case <-time.After(n):
+			now := time.Now()
+			// retain at least one hour of transaction IDs, even for short cleanup intervals
+			boundaryTime := now.Add(-1 * n)
+			if n < time.Hour {
+				boundaryTime = now.Add(-1 * time.Hour)
+			}
+			logger.Info().Time("boundaryTime", boundaryTime).Msg("Cleaner running")
+			err := s.TransactionsTable.Clean(boundaryTime)
+			if err != nil {
+				logger.Warn().Err(err).Msg("failed to clean txn ID table")
+				sentry.CaptureException(err)
+			}
+			// we also want to clean up stale state snapshots which are inaccessible, to
+			// keep the size of the syncv3_snapshots table low.
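+			// With the default MaxTimelineLimit of 50 this keeps at most 51
+			// snapshots per room, which is the most a client can still reach.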
+			if err = s.RemoveInaccessibleStateSnapshots(); err != nil {
+				logger.Warn().Err(err).Msg("failed to remove inaccessible state snapshots")
+				sentry.CaptureException(err)
+			}
+		case <-s.shutdownCh:
+			break Loop
+		}
+	}
+}
+
 func (s *Storage) LatestEventNIDInRooms(roomIDs []string, highestNID int64) (roomToNID map[string]int64, err error) {
 	roomToNID = make(map[string]int64)
 	err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
@@ -1113,6 +1189,11 @@ func (s *Storage) determineJoinedRoomsFromMemberships(membershipEvents []Event)
 }
 
 func (s *Storage) Teardown() {
+	if !s.shutdown {
+		s.shutdown = true
+		close(s.shutdownCh)
+	}
+
 	err := s.Accumulator.db.Close()
 	if err != nil {
 		panic("Storage.Teardown: " + err.Error())
diff --git a/state/storage_test.go b/state/storage_test.go
index f11d549..4cef726 100644
--- a/state/storage_test.go
+++ b/state/storage_test.go
@@ -5,11 +5,13 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"math/rand"
 	"reflect"
 	"sort"
 	"testing"
 	"time"
 
+	"github.com/matrix-org/gomatrixserverlib"
 	"github.com/matrix-org/sliding-sync/sync2"
 
 	"github.com/jmoiron/sqlx"
@@ -913,6 +915,174 @@ func TestStorage_FetchMemberships(t *testing.T) {
 	assertValue(t, "joins", leaves, []string{"@chris:test", "@david:test", "@glory:test", "@helen:test"})
 }
 
+type persistOpts struct {
+	withInitialEvents bool
+	numTimelineEvents int
+	ofWhichNumState   int
+}
+
+func mustPersistEvents(t *testing.T, roomID string, store *Storage, opts persistOpts) {
+	t.Helper()
+	var events []json.RawMessage
+	if opts.withInitialEvents {
+		events = createInitialEvents(t, userID)
+	}
+	numAddedStateEvents := 0
+	for i := 0; i < opts.numTimelineEvents; i++ {
+		var ev json.RawMessage
+		if numAddedStateEvents < opts.ofWhichNumState {
+			numAddedStateEvents++
+			ev = testutils.NewStateEvent(t, "some_kind_of_state", fmt.Sprintf("%d", rand.Int63()), userID, map[string]interface{}{
+				"num": numAddedStateEvents,
+			})
+		} else {
+			ev = testutils.NewEvent(t, "some_kind_of_message", userID, map[string]interface{}{
+				"msg": "yep",
+			})
+		}
+		events = append(events, ev)
+	}
+	mustAccumulate(t, store, roomID, events)
+}
+
+func mustAccumulate(t *testing.T, store *Storage, roomID string, events []json.RawMessage) {
+	t.Helper()
+	_, err := store.Accumulate(userID, roomID, sync2.TimelineResponse{
+		Events: events,
+	})
+	if err != nil {
+		t.Fatalf("Failed to accumulate: %s", err)
+	}
+}
+
+func mustHaveNumSnapshots(t *testing.T, db *sqlx.DB, roomID string, numSnapshots int) {
+	t.Helper()
+	var val int
+	err := db.QueryRow(`SELECT count(*) FROM syncv3_snapshots WHERE room_id=$1`, roomID).Scan(&val)
+	if err != nil {
+		t.Fatalf("mustHaveNumSnapshots: %s", err)
+	}
+	if val != numSnapshots {
+		t.Fatalf("mustHaveNumSnapshots: got %d want %d snapshots", val, numSnapshots)
+	}
+}
+
+func mustNotError(t *testing.T, err error) {
+	t.Helper()
+	if err == nil {
+		return
+	}
+	t.Fatalf("err: %s", err)
+}
+
+func TestRemoveInaccessibleStateSnapshots(t *testing.T) {
+	store := NewStorage(postgresConnectionString)
+	store.MaxTimelineLimit = 50 // we nuke if we have >50+1 snapshots
+
+	roomOnlyMessages := "!TestRemoveInaccessibleStateSnapshots_roomOnlyMessages:localhost"
+	mustPersistEvents(t, roomOnlyMessages, store, persistOpts{
+		withInitialEvents: true,
+		numTimelineEvents: 100,
+		ofWhichNumState:   0,
+	})
+	roomOnlyState := "!TestRemoveInaccessibleStateSnapshots_roomOnlyState:localhost"
+	mustPersistEvents(t, roomOnlyState, store, persistOpts{
+		withInitialEvents: true,
+		numTimelineEvents: 100,
+		ofWhichNumState:   100,
+	})
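+	// a room whose timeline is a mix of state events (30) and message events (70),
+	// so only some of the timeline events create new snapshots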
+	roomPartialStateAndMessages := "!TestRemoveInaccessibleStateSnapshots_roomPartialStateAndMessages:localhost"
+	mustPersistEvents(t, roomPartialStateAndMessages, store, persistOpts{
+		withInitialEvents: true,
+		numTimelineEvents: 100,
+		ofWhichNumState:   30,
+	})
+	roomOverwriteState := "!TestRemoveInaccessibleStateSnapshots_roomOverwriteState:localhost"
+	mustPersistEvents(t, roomOverwriteState, store, persistOpts{
+		withInitialEvents: true,
+	})
+	mustAccumulate(t, store, roomOverwriteState, []json.RawMessage{testutils.NewStateEvent(t, "overwrite", "", userID, map[string]interface{}{"val": 1})})
+	mustAccumulate(t, store, roomOverwriteState, []json.RawMessage{testutils.NewStateEvent(t, "overwrite", "", userID, map[string]interface{}{"val": 2})})
+	mustHaveNumSnapshots(t, store.DB, roomOnlyMessages, 4)             // initial state only: one snapshot per initial state event
+	mustHaveNumSnapshots(t, store.DB, roomOnlyState, 104)              // initial state + 100 state events
+	mustHaveNumSnapshots(t, store.DB, roomPartialStateAndMessages, 34) // initial state + 30 state events
+	mustHaveNumSnapshots(t, store.DB, roomOverwriteState, 6)           // initial state + 2 overwrite state events
+	mustNotError(t, store.RemoveInaccessibleStateSnapshots())
+	mustHaveNumSnapshots(t, store.DB, roomOnlyMessages, 4)             // it should not be touched as 4 < 51
+	mustHaveNumSnapshots(t, store.DB, roomOnlyState, 51)               // it should be capped at 51
+	mustHaveNumSnapshots(t, store.DB, roomPartialStateAndMessages, 34) // it should not be touched as 34 < 51
+	mustHaveNumSnapshots(t, store.DB, roomOverwriteState, 6)           // it should not be touched as 6 < 51
+	// calling it again does nothing
+	mustNotError(t, store.RemoveInaccessibleStateSnapshots())
+	mustHaveNumSnapshots(t, store.DB, roomOnlyMessages, 4)
+	mustHaveNumSnapshots(t, store.DB, roomOnlyState, 51)
+	mustHaveNumSnapshots(t, store.DB, roomPartialStateAndMessages, 34)
+	mustHaveNumSnapshots(t, store.DB, roomOverwriteState, 6)
+	// adding one extra state snapshot to each room and repeating RemoveInaccessibleStateSnapshots
+	mustPersistEvents(t, roomOnlyMessages, store, persistOpts{numTimelineEvents: 1, ofWhichNumState: 1})
+	mustPersistEvents(t, roomOnlyState, store, persistOpts{numTimelineEvents: 1, ofWhichNumState: 1})
+	mustPersistEvents(t, roomPartialStateAndMessages, store, persistOpts{numTimelineEvents: 1, ofWhichNumState: 1})
+	mustNotError(t, store.RemoveInaccessibleStateSnapshots())
+	mustHaveNumSnapshots(t, store.DB, roomOnlyMessages, 5)
+	mustHaveNumSnapshots(t, store.DB, roomOnlyState, 51) // still capped
+	mustHaveNumSnapshots(t, store.DB, roomPartialStateAndMessages, 35)
+	// adding 51 timeline (message) events and repeating RemoveInaccessibleStateSnapshots does nothing
+	mustPersistEvents(t, roomOnlyMessages, store, persistOpts{numTimelineEvents: 51})
+	mustPersistEvents(t, roomOnlyState, store, persistOpts{numTimelineEvents: 51})
+	mustPersistEvents(t, roomPartialStateAndMessages, store, persistOpts{numTimelineEvents: 51})
+	mustNotError(t, store.RemoveInaccessibleStateSnapshots())
+	mustHaveNumSnapshots(t, store.DB, roomOnlyMessages, 5)
+	mustHaveNumSnapshots(t, store.DB, roomOnlyState, 51)
+	mustHaveNumSnapshots(t, store.DB, roomPartialStateAndMessages, 35)
+
+	// overwrite 52 times and check the current state is 52 (shows we are keeping the right snapshots)
+	for i := 0; i < 52; i++ {
+		mustAccumulate(t, store, roomOverwriteState, []json.RawMessage{testutils.NewStateEvent(t, "overwrite", "", userID, map[string]interface{}{"val": 1 + i})})
+	}
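+	// 6 snapshots survived the earlier cleanups, and each of the 52 overwrites adds one more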
+	mustHaveNumSnapshots(t, store.DB, roomOverwriteState, 58)
+	mustNotError(t, store.RemoveInaccessibleStateSnapshots())
+	mustHaveNumSnapshots(t, store.DB, roomOverwriteState, 51)
+	roomsTable := NewRoomsTable(store.DB)
+	mustNotError(t, sqlutil.WithTransaction(store.DB, func(txn *sqlx.Tx) error {
+		snapID, err := roomsTable.CurrentAfterSnapshotID(txn, roomOverwriteState)
+		if err != nil {
+			return err
+		}
+		state, err := store.StateSnapshot(snapID)
+		if err != nil {
+			return err
+		}
+		// find the 'overwrite' event and make sure the val is 52
+		foundOverwrite := false
+		for _, ev := range state {
+			evv := gjson.ParseBytes(ev)
+			if evv.Get("type").Str != "overwrite" {
+				continue
+			}
+			foundOverwrite = true
+			if evv.Get("content.val").Int() != 52 {
+				return fmt.Errorf("val for overwrite state event was not 52: %v", evv.Raw)
+			}
+		}
+		if !foundOverwrite {
+			return fmt.Errorf("current state snapshot is missing the 'overwrite' event")
+		}
+		return nil
+	}))
+}
+
+func createInitialEvents(t *testing.T, creator string) []json.RawMessage {
+	t.Helper()
+	baseTimestamp := time.Now()
+	var pl gomatrixserverlib.PowerLevelContent
+	pl.Defaults()
+	pl.Users = map[string]int64{
+		creator: 100,
+	}
+	// all with the same timestamp as they get made atomically
+	return []json.RawMessage{
+		testutils.NewStateEvent(t, "m.room.create", "", creator, map[string]interface{}{"creator": creator}, testutils.WithTimestamp(baseTimestamp)),
+		testutils.NewJoinEvent(t, creator, testutils.WithTimestamp(baseTimestamp)),
+		testutils.NewStateEvent(t, "m.room.power_levels", "", creator, pl, testutils.WithTimestamp(baseTimestamp)),
+		testutils.NewStateEvent(t, "m.room.join_rules", "", creator, map[string]interface{}{"join_rule": "public"}, testutils.WithTimestamp(baseTimestamp)),
+	}
+}
+
 func cleanDB(t *testing.T) error {
 	// make a fresh DB which is unpolluted from other tests
 	db, close := connectToDB(t)