Mirror of https://github.com/matrix-org/sliding-sync.git (synced 2025-03-10 13:37:11 +00:00)

Merge pull request #422 from matrix-org/kegan/db-size

Clean the syncv3_snapshots table periodically

Commit abf5aeff82
@@ -220,6 +220,7 @@ func main() {
	})

	go h2.StartV2Pollers()
	go h2.Store.Cleaner(time.Hour)
	if args[EnvOTLP] != "" {
		h3 = otelhttp.NewHandler(h3, "Sync")
	}

@@ -6,6 +6,7 @@ import (
	"fmt"
	"os"
	"strings"
	"time"

	"golang.org/x/exp/slices"

@@ -69,6 +70,8 @@ type Storage struct {
	ReceiptTable     *ReceiptTable
	DB               *sqlx.DB
	MaxTimelineLimit int
	shutdownCh       chan struct{}
	shutdown         bool
}

func NewStorage(postgresURI string) *Storage {

@@ -104,6 +107,7 @@ func NewStorageWithDB(db *sqlx.DB, addPrometheusMetrics bool) *Storage {
		ReceiptTable:     NewReceiptTable(db),
		DB:               db,
		MaxTimelineLimit: 50,
		shutdownCh:       make(chan struct{}),
	}
}

@@ -758,6 +762,50 @@ func (s *Storage) LatestEventsInRooms(userID string, roomIDs []string, to int64,
	return result, err
}

// RemoveInaccessibleStateSnapshots removes state snapshots which cannot be accessed by
// clients. The latest MaxTimelineLimit snapshots must be kept, +1 for the current state.
// This handles the worst case where all MaxTimelineLimit timeline events are state events,
// and hence each event makes a new snapshot. We can safely delete all snapshots older than
// this, as it's not possible for clients to reach them: the proxy does not handle
// historical state (deferring to the homeserver for that).
func (s *Storage) RemoveInaccessibleStateSnapshots() error {
	numToKeep := s.MaxTimelineLimit + 1
	// Create a CTE which ranks each snapshot so we can figure out which snapshots to delete,
	// then execute the delete using the CTE.
	//
	// A per-room version of this query:
	//  WITH ranked_snapshots AS (
	//      SELECT
	//          snapshot_id,
	//          room_id,
	//          ROW_NUMBER() OVER (PARTITION BY room_id ORDER BY snapshot_id DESC) AS row_num
	//      FROM syncv3_snapshots
	//  )
	//  DELETE FROM syncv3_snapshots WHERE snapshot_id IN(
	//      SELECT snapshot_id FROM ranked_snapshots WHERE row_num > 51 AND room_id='!....'
	//  );
	awfulQuery := fmt.Sprintf(`WITH ranked_snapshots AS (
		SELECT
			snapshot_id,
			room_id,
			ROW_NUMBER() OVER (PARTITION BY room_id ORDER BY snapshot_id DESC) AS row_num
		FROM
			syncv3_snapshots
	)
	DELETE FROM syncv3_snapshots USING ranked_snapshots
	WHERE syncv3_snapshots.snapshot_id = ranked_snapshots.snapshot_id
	AND ranked_snapshots.row_num > %d;`, numToKeep)

	result, err := s.DB.Exec(awfulQuery)
	if err != nil {
		return fmt.Errorf("failed to RemoveInaccessibleStateSnapshots: Exec: %s", err)
	}
	rowsAffected, err := result.RowsAffected()
	if err == nil {
		logger.Info().Int64("rows_affected", rowsAffected).Msg("RemoveInaccessibleStateSnapshots: deleted rows")
	}
	return nil
}

func (s *Storage) GetClosestPrevBatch(roomID string, eventNID int64) (prevBatch string) {
	var err error
	sqlutil.WithTransaction(s.DB, func(txn *sqlx.Tx) error {

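The cap is interpolated into the SQL via fmt.Sprintf (the variable name "awfulQuery" is honest about that), which is safe here only because numToKeep is a server-controlled integer, never user input. For comparison, a bind-parameter sketch of the same statement (an illustration assuming the same syncv3_snapshots schema; not part of the commit):

	const removeSnapshotsSQL = `WITH ranked_snapshots AS (
	    SELECT snapshot_id,
	           ROW_NUMBER() OVER (PARTITION BY room_id ORDER BY snapshot_id DESC) AS row_num
	    FROM syncv3_snapshots
	)
	DELETE FROM syncv3_snapshots USING ranked_snapshots
	WHERE syncv3_snapshots.snapshot_id = ranked_snapshots.snapshot_id
	AND ranked_snapshots.row_num > $1;`

	// usage: _, err := s.DB.Exec(removeSnapshotsSQL, numToKeep)
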
@@ -1024,6 +1072,34 @@ func (s *Storage) AllJoinedMembers(txn *sqlx.Tx, tempTableName string) (joinedMe
	return joinedMembers, metadata, nil
}

// Cleaner deletes stale data from the database every n: expired transaction IDs and
// state snapshots which clients can no longer access. It blocks until Teardown is
// called, so run it on its own goroutine.
func (s *Storage) Cleaner(n time.Duration) {
Loop:
	for {
		select {
		case <-time.After(n):
			now := time.Now()
			boundaryTime := now.Add(-1 * n)
			if n < time.Hour {
				boundaryTime = now.Add(-1 * time.Hour)
			}
			logger.Info().Time("boundaryTime", boundaryTime).Msg("Cleaner running")
			err := s.TransactionsTable.Clean(boundaryTime)
			if err != nil {
				logger.Warn().Err(err).Msg("failed to clean txn ID table")
				sentry.CaptureException(err)
			}
			// We also want to clean up stale state snapshots which are inaccessible, to
			// keep the size of the syncv3_snapshots table low.
			if err = s.RemoveInaccessibleStateSnapshots(); err != nil {
				logger.Warn().Err(err).Msg("failed to remove inaccessible state snapshots")
				sentry.CaptureException(err)
			}
		case <-s.shutdownCh:
			break Loop
		}
	}
}

func (s *Storage) LatestEventNIDInRooms(roomIDs []string, highestNID int64) (roomToNID map[string]int64, err error) {
	roomToNID = make(map[string]int64)
	err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {

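Note the clamp on boundaryTime above: transaction IDs are retained for at least an hour even if the cleaner ticks more often, so the effective retention window is max(n, 1h). A standalone restatement of that calculation (a hypothetical helper, not in the codebase):

	func cleanBoundary(now time.Time, n time.Duration) time.Time {
	    if n < time.Hour {
	        n = time.Hour
	    }
	    return now.Add(-n)
	}
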
@@ -1113,6 +1189,11 @@ func (s *Storage) determineJoinedRoomsFromMemberships(membershipEvents []Event)
}

func (s *Storage) Teardown() {
	if !s.shutdown {
		s.shutdown = true
		close(s.shutdownCh)
	}

	err := s.Accumulator.db.Close()
	if err != nil {
		panic("Storage.Teardown: " + err.Error())

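Putting the pieces together, the intended lifecycle is: start the cleaner on its own goroutine and stop it by tearing down the store, which closes shutdownCh and breaks the Cleaner loop. A minimal sketch using the API added in this commit (postgresURI is a placeholder):

	store := NewStorage(postgresURI)
	go store.Cleaner(time.Hour) // first pass fires after an hour, then hourly
	// ... serve traffic ...
	store.Teardown() // closes shutdownCh, stopping Cleaner; also closes the DB

This mirrors main.go above, where the proxy runs go h2.Store.Cleaner(time.Hour) at startup.
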
@@ -5,11 +5,13 @@ import (
	"context"
	"encoding/json"
	"fmt"
	"math/rand"
	"reflect"
	"sort"
	"testing"
	"time"

	"github.com/matrix-org/gomatrixserverlib"
	"github.com/matrix-org/sliding-sync/sync2"

	"github.com/jmoiron/sqlx"

@@ -913,6 +915,174 @@ func TestStorage_FetchMemberships(t *testing.T) {
	assertValue(t, "leaves", leaves, []string{"@chris:test", "@david:test", "@glory:test", "@helen:test"})
}

type persistOpts struct {
	withInitialEvents bool
	numTimelineEvents int
	ofWhichNumState   int
}

func mustPersistEvents(t *testing.T, roomID string, store *Storage, opts persistOpts) {
	t.Helper()
	var events []json.RawMessage
	if opts.withInitialEvents {
		events = createInitialEvents(t, userID)
	}
	numAddedStateEvents := 0
	for i := 0; i < opts.numTimelineEvents; i++ {
		var ev json.RawMessage
		if numAddedStateEvents < opts.ofWhichNumState {
			numAddedStateEvents++
			ev = testutils.NewStateEvent(t, "some_kind_of_state", fmt.Sprintf("%d", rand.Int63()), userID, map[string]interface{}{
				"num": numAddedStateEvents,
			})
		} else {
			ev = testutils.NewEvent(t, "some_kind_of_message", userID, map[string]interface{}{
				"msg": "yep",
			})
		}
		events = append(events, ev)
	}
	mustAccumulate(t, store, roomID, events)
}

func mustAccumulate(t *testing.T, store *Storage, roomID string, events []json.RawMessage) {
	t.Helper()
	_, err := store.Accumulate(userID, roomID, sync2.TimelineResponse{
		Events: events,
	})
	if err != nil {
		t.Fatalf("Failed to accumulate: %s", err)
	}
}

func mustHaveNumSnapshots(t *testing.T, db *sqlx.DB, roomID string, numSnapshots int) {
	t.Helper()
	var val int
	err := db.QueryRow(`SELECT count(*) FROM syncv3_snapshots WHERE room_id=$1`, roomID).Scan(&val)
	if err != nil {
		t.Fatalf("mustHaveNumSnapshots: %s", err)
	}
	if val != numSnapshots {
		t.Fatalf("mustHaveNumSnapshots: got %d want %d snapshots", val, numSnapshots)
	}
}

func mustNotError(t *testing.T, err error) {
	t.Helper()
	if err == nil {
		return
	}
	t.Fatalf("err: %s", err)
}

func TestRemoveInaccessibleStateSnapshots(t *testing.T) {
	store := NewStorage(postgresConnectionString)
	store.MaxTimelineLimit = 50 // we nuke if we have >50+1 snapshots

	roomOnlyMessages := "!TestRemoveInaccessibleStateSnapshots_roomOnlyMessages:localhost"
	mustPersistEvents(t, roomOnlyMessages, store, persistOpts{
		withInitialEvents: true,
		numTimelineEvents: 100,
		ofWhichNumState:   0,
	})
	roomOnlyState := "!TestRemoveInaccessibleStateSnapshots_roomOnlyState:localhost"
	mustPersistEvents(t, roomOnlyState, store, persistOpts{
		withInitialEvents: true,
		numTimelineEvents: 100,
		ofWhichNumState:   100,
	})
	roomPartialStateAndMessages := "!TestRemoveInaccessibleStateSnapshots_roomPartialStateAndMessages:localhost"
	mustPersistEvents(t, roomPartialStateAndMessages, store, persistOpts{
		withInitialEvents: true,
		numTimelineEvents: 100,
		ofWhichNumState:   30,
	})
	roomOverwriteState := "!TestRemoveInaccessibleStateSnapshots_roomOverwriteState:localhost"
	mustPersistEvents(t, roomOverwriteState, store, persistOpts{
		withInitialEvents: true,
	})
	mustAccumulate(t, store, roomOverwriteState, []json.RawMessage{testutils.NewStateEvent(t, "overwrite", "", userID, map[string]interface{}{"val": 1})})
	mustAccumulate(t, store, roomOverwriteState, []json.RawMessage{testutils.NewStateEvent(t, "overwrite", "", userID, map[string]interface{}{"val": 2})})
	mustHaveNumSnapshots(t, store.DB, roomOnlyMessages, 4)             // initial state only: one snapshot per initial state event
	mustHaveNumSnapshots(t, store.DB, roomOnlyState, 104)              // initial state + 100 state events
	mustHaveNumSnapshots(t, store.DB, roomPartialStateAndMessages, 34) // initial state + 30 state events
	mustHaveNumSnapshots(t, store.DB, roomOverwriteState, 6)           // initial state + 2 overwrite state events
	mustNotError(t, store.RemoveInaccessibleStateSnapshots())
	mustHaveNumSnapshots(t, store.DB, roomOnlyMessages, 4)             // it should not be touched as 4 < 51
	mustHaveNumSnapshots(t, store.DB, roomOnlyState, 51)               // it should be capped at 51
	mustHaveNumSnapshots(t, store.DB, roomPartialStateAndMessages, 34) // it should not be touched as 34 < 51
	mustHaveNumSnapshots(t, store.DB, roomOverwriteState, 6)           // it should not be touched as 6 < 51
	// calling it again does nothing
	mustNotError(t, store.RemoveInaccessibleStateSnapshots())
	mustHaveNumSnapshots(t, store.DB, roomOnlyMessages, 4)
	mustHaveNumSnapshots(t, store.DB, roomOnlyState, 51)
	mustHaveNumSnapshots(t, store.DB, roomPartialStateAndMessages, 34)
	mustHaveNumSnapshots(t, store.DB, roomOverwriteState, 6) // it should not be touched as 6 < 51
	// adding one extra state snapshot to each room and repeating RemoveInaccessibleStateSnapshots
	mustPersistEvents(t, roomOnlyMessages, store, persistOpts{numTimelineEvents: 1, ofWhichNumState: 1})
	mustPersistEvents(t, roomOnlyState, store, persistOpts{numTimelineEvents: 1, ofWhichNumState: 1})
	mustPersistEvents(t, roomPartialStateAndMessages, store, persistOpts{numTimelineEvents: 1, ofWhichNumState: 1})
	mustNotError(t, store.RemoveInaccessibleStateSnapshots())
	mustHaveNumSnapshots(t, store.DB, roomOnlyMessages, 5)
	mustHaveNumSnapshots(t, store.DB, roomOnlyState, 51) // still capped
	mustHaveNumSnapshots(t, store.DB, roomPartialStateAndMessages, 35)
	// adding 51 timeline events and repeating RemoveInaccessibleStateSnapshots does nothing
	mustPersistEvents(t, roomOnlyMessages, store, persistOpts{numTimelineEvents: 51})
	mustPersistEvents(t, roomOnlyState, store, persistOpts{numTimelineEvents: 51})
	mustPersistEvents(t, roomPartialStateAndMessages, store, persistOpts{numTimelineEvents: 51})
	mustNotError(t, store.RemoveInaccessibleStateSnapshots())
	mustHaveNumSnapshots(t, store.DB, roomOnlyMessages, 5)
	mustHaveNumSnapshots(t, store.DB, roomOnlyState, 51)
	mustHaveNumSnapshots(t, store.DB, roomPartialStateAndMessages, 35)

	// overwrite 52 times and check the current state is 52 (shows we are keeping the right snapshots)
	for i := 0; i < 52; i++ {
		mustAccumulate(t, store, roomOverwriteState, []json.RawMessage{testutils.NewStateEvent(t, "overwrite", "", userID, map[string]interface{}{"val": 1 + i})})
	}
	mustHaveNumSnapshots(t, store.DB, roomOverwriteState, 58)
	mustNotError(t, store.RemoveInaccessibleStateSnapshots())
	mustHaveNumSnapshots(t, store.DB, roomOverwriteState, 51)
	roomsTable := NewRoomsTable(store.DB)
	mustNotError(t, sqlutil.WithTransaction(store.DB, func(txn *sqlx.Tx) error {
		snapID, err := roomsTable.CurrentAfterSnapshotID(txn, roomOverwriteState)
		if err != nil {
			return err
		}
		state, err := store.StateSnapshot(snapID)
		if err != nil {
			return err
		}
		// find the 'overwrite' event and make sure the val is 52
		for _, ev := range state {
			evv := gjson.ParseBytes(ev)
			if evv.Get("type").Str != "overwrite" {
				continue
			}
			if evv.Get("content.val").Int() != 52 {
				return fmt.Errorf("val for overwrite state event was not 52: %v", evv.Raw)
			}
		}
		return nil
	}))
}

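For reference, the counts asserted above follow from createInitialEvents (below) emitting four state events (create, member, power_levels, join_rules), each of which produces a snapshot: e.g. roomOnlyState accumulates 4 + 100 = 104 snapshots, which RemoveInaccessibleStateSnapshots caps at MaxTimelineLimit + 1 = 51, while rooms already at or under the cap are left untouched.
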
func createInitialEvents(t *testing.T, creator string) []json.RawMessage {
	t.Helper()
	baseTimestamp := time.Now()
	var pl gomatrixserverlib.PowerLevelContent
	pl.Defaults()
	pl.Users = map[string]int64{
		creator: 100,
	}
	// all with the same timestamp as they get made atomically
	return []json.RawMessage{
		testutils.NewStateEvent(t, "m.room.create", "", creator, map[string]interface{}{"creator": creator}, testutils.WithTimestamp(baseTimestamp)),
		testutils.NewJoinEvent(t, creator, testutils.WithTimestamp(baseTimestamp)),
		testutils.NewStateEvent(t, "m.room.power_levels", "", creator, pl, testutils.WithTimestamp(baseTimestamp)),
		testutils.NewStateEvent(t, "m.room.join_rules", "", creator, map[string]interface{}{"join_rule": "public"}, testutils.WithTimestamp(baseTimestamp)),
	}
}

func cleanDB(t *testing.T) error {
	// make a fresh DB which is unpolluted from other tests
	db, close := connectToDB(t)
…