mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
Add Accumulator.Initialise with tests
This commit is contained in:
parent
f3e0f96d91
commit
cd20d07d9f
215
state/accumulator.go
Normal file
215
state/accumulator.go
Normal file
@ -0,0 +1,215 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/lib/pq"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
var log = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.ConsoleWriter{
|
||||
Out: os.Stderr,
|
||||
TimeFormat: "15:04:05",
|
||||
})
|
||||
|
||||
// Accumulator tracks room state and timelines.
|
||||
//
|
||||
// In order for it to remain simple(ish), the accumulator DOES NOT SUPPORT arbitrary timeline gaps.
|
||||
// There is an Initialise function for new rooms (with some pre-determined state) and then a constant
|
||||
// Accumulate function for timeline events. v2 sync must be called with a large enough timeline.limit
|
||||
// for this to work!
|
||||
type Accumulator struct {
|
||||
mu *sync.Mutex // lock for locks
|
||||
// TODO: unbounded on number of rooms
|
||||
locks map[string]*sync.Mutex // room_id -> mutex
|
||||
|
||||
db *sqlx.DB
|
||||
roomsTable *RoomsTable
|
||||
eventsTable *EventTable
|
||||
snapshotTable *SnapshotTable
|
||||
snapshotRefCountTable *SnapshotRefCountsTable
|
||||
}
|
||||
|
||||
func NewAccumulator(postgresURI string) *Accumulator {
|
||||
db, err := sqlx.Open("postgres", postgresURI)
|
||||
if err != nil {
|
||||
log.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
|
||||
}
|
||||
return &Accumulator{
|
||||
db: db,
|
||||
mu: &sync.Mutex{},
|
||||
locks: make(map[string]*sync.Mutex),
|
||||
roomsTable: NewRoomsTable(db),
|
||||
eventsTable: NewEventTable(db),
|
||||
snapshotTable: NewSnapshotsTable(db),
|
||||
snapshotRefCountTable: NewSnapshotRefCountsTable(db),
|
||||
}
|
||||
}
|
||||
|
||||
// obtain a per-room lock
|
||||
func (a *Accumulator) mutex(roomID string) *sync.Mutex {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
lock, ok := a.locks[roomID]
|
||||
if !ok {
|
||||
lock = &sync.Mutex{}
|
||||
a.locks[roomID] = lock
|
||||
}
|
||||
return lock
|
||||
}
|
||||
|
||||
// clearSnapshots deletes all snapshots with 0 refs to it
|
||||
func (a *Accumulator) clearSnapshots(txn *sqlx.Tx) {
|
||||
snapshotIDs, err := a.snapshotRefCountTable.DeleteEmptyRefs(txn)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("failed to DeleteEmptyRefs")
|
||||
return
|
||||
}
|
||||
err = a.snapshotTable.Delete(txn, snapshotIDs)
|
||||
if err != nil {
|
||||
log.Err(err).Ints("snapshots", snapshotIDs).Msg("failed to delete snapshot IDs")
|
||||
}
|
||||
}
|
||||
|
||||
// moveSnapshotRef decrements the from snapshot ID and increments the to snapshot ID
|
||||
// Clears the from snapshot if there are 0 refs to it
|
||||
func (a *Accumulator) moveSnapshotRef(txn *sqlx.Tx, from, to int) error {
|
||||
if from != 0 {
|
||||
count, err := a.snapshotRefCountTable.Decrement(txn, from)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count == 0 {
|
||||
a.clearSnapshots(txn)
|
||||
}
|
||||
}
|
||||
_, err := a.snapshotRefCountTable.Increment(txn, to)
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialise starts a new sync accumulator for the given room using the given state as a baseline.
|
||||
// This will only take effect if this is the first time the v3 server has seen this room, and it wasn't
|
||||
// possible to get all events up to the create event (e.g Matrix HQ).
|
||||
//
|
||||
// This function:
|
||||
// - Stores these events
|
||||
// - Sets up the current snapshot based on the state list given.
|
||||
func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) error {
|
||||
if len(state) == 0 {
|
||||
return nil
|
||||
}
|
||||
return WithTransaction(a.db, func(txn *sqlx.Tx) error {
|
||||
// Attempt to short-circuit. This has to be done inside a transaction to make sure
|
||||
// we don't race with multiple calls to Initialise with the same room ID.
|
||||
snapshotID, err := a.roomsTable.CurrentSnapshotID(txn, roomID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error fetching snapshot id for room %s: %s", roomID, err)
|
||||
}
|
||||
if snapshotID > 0 {
|
||||
// we only initialise rooms once
|
||||
log.Info().Str("room_id", roomID).Int("snapshot_id", snapshotID).Msg("Accumulator.Initialise called but current snapshot already exists, bailing early")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Insert the events
|
||||
events := make([]Event, len(state))
|
||||
for i := range events {
|
||||
events[i] = Event{
|
||||
JSON: state[i],
|
||||
}
|
||||
}
|
||||
numNew, err := a.eventsTable.Insert(txn, events)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if numNew == 0 {
|
||||
// we don't have a current snapshot for this room but yet no events are new,
|
||||
// no idea how this should be handled.
|
||||
log.Error().Str("room_id", roomID).Msg(
|
||||
"Accumulator.Initialise: room has no current snapshot but also no new inserted events, doing nothing. This is probably a bug.",
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// pull out the event NIDs we just inserted
|
||||
eventIDs := make([]string, len(events))
|
||||
for i := range eventIDs {
|
||||
eventIDs[i] = events[i].ID
|
||||
}
|
||||
nids, err := a.eventsTable.SelectNIDsByIDs(txn, eventIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Make a current snapshot
|
||||
snapshot := &SnapshotRow{
|
||||
RoomID: roomID,
|
||||
Events: pq.Int64Array(nids),
|
||||
}
|
||||
err = a.snapshotTable.Insert(txn, snapshot)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Increment the ref counter
|
||||
err = a.moveSnapshotRef(txn, 0, snapshot.SnapshotID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the snapshot ID as the current state
|
||||
return a.roomsTable.UpdateCurrentSnapshotID(txn, roomID, snapshot.SnapshotID)
|
||||
})
|
||||
}
|
||||
|
||||
// Accumulate internal state from a user's sync response. Locks per-room.
|
||||
//
|
||||
// This function does several things:
|
||||
// - It ensures all events are persisted in the database. This is shared amongst users.
|
||||
// - If the last event in the timeline has been stored before, then it short circuits and returns.
|
||||
// This is because we must have already processed this in order for the event to exist in the database,
|
||||
// and the sync stream is already linearised for us.
|
||||
// - Else it creates a new room state snapshot (as this now represents the current state)
|
||||
// - It checks if there are outstanding references for the previous snapshot, and if not, removes the old snapshot from the database.
|
||||
// References are made when clients have synced up to a given snapshot (hence may paginate at that point).
|
||||
func (a *Accumulator) Accumulate(roomID string, timeline []json.RawMessage) error {
|
||||
if len(timeline) == 0 {
|
||||
return nil
|
||||
}
|
||||
return WithTransaction(a.db, func(txn *sqlx.Tx) error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// WithTransaction runs a block of code passing in an SQL transaction
|
||||
// If the code returns an error or panics then the transactions is rolled back
|
||||
// Otherwise the transaction is committed.
|
||||
func WithTransaction(db *sqlx.DB, fn func(txn *sqlx.Tx) error) (err error) {
|
||||
txn, err := db.Beginx()
|
||||
if err != nil {
|
||||
return fmt.Errorf("WithTransaction.Begin: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
panicErr := recover()
|
||||
if err == nil && panicErr != nil {
|
||||
err = fmt.Errorf("panic: %v", panicErr)
|
||||
}
|
||||
var txnErr error
|
||||
if err != nil {
|
||||
txnErr = txn.Rollback()
|
||||
} else {
|
||||
txnErr = txn.Commit()
|
||||
}
|
||||
if txnErr != nil && err == nil {
|
||||
err = fmt.Errorf("WithTransaction failed to commit/rollback: %w", txnErr)
|
||||
}
|
||||
}()
|
||||
|
||||
err = fn(txn)
|
||||
return
|
||||
}
|
67
state/accumulator_test.go
Normal file
67
state/accumulator_test.go
Normal file
@ -0,0 +1,67 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAccumulator(t *testing.T) {
|
||||
roomID := "!1:localhost"
|
||||
roomEvents := []json.RawMessage{
|
||||
[]byte(`{"event_id":"A", "type":"m.room.create", "state_key":"", "content":{"creator":"@me:localhost"}}`),
|
||||
[]byte(`{"event_id":"B", "type":"m.room.member", "state_key":"@me:localhost", "content":{"membership":"join"}}`),
|
||||
[]byte(`{"event_id":"C", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
|
||||
}
|
||||
roomEventIDs := []string{"A", "B", "C"}
|
||||
accumulator := NewAccumulator("user=kegan dbname=syncv3 sslmode=disable")
|
||||
err := accumulator.Initialise(roomID, roomEvents)
|
||||
if err != nil {
|
||||
t.Fatalf("falied to Initialise accumulator: %s", err)
|
||||
}
|
||||
|
||||
txn, err := accumulator.db.Beginx()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start assert txn: %s", err)
|
||||
}
|
||||
|
||||
// There should be one snapshot on the current state
|
||||
snapID, err := accumulator.roomsTable.CurrentSnapshotID(txn, roomID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to select current snapshot: %s", err)
|
||||
}
|
||||
if snapID == 0 {
|
||||
t.Fatalf("Initialise did not store a current snapshot")
|
||||
}
|
||||
|
||||
// this snapshot should have 3 events in it
|
||||
row, err := accumulator.snapshotTable.Select(txn, snapID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to select snapshot %d: %s", snapID, err)
|
||||
}
|
||||
if len(row.Events) != len(roomEvents) {
|
||||
t.Fatalf("got %d events, want %d in current state snapshot", len(row.Events), len(roomEvents))
|
||||
}
|
||||
|
||||
// these 3 events should map to the three events we initialised with
|
||||
events, err := accumulator.eventsTable.SelectByNIDs(txn, row.Events)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to extract events in snapshot: %s", err)
|
||||
}
|
||||
if len(events) != len(roomEvents) {
|
||||
t.Fatalf("failed to extract %d events, got %d", len(roomEvents), len(events))
|
||||
}
|
||||
for i := range events {
|
||||
if events[i].ID != roomEventIDs[i] {
|
||||
t.Errorf("event %d was not stored correctly: got ID %s want %s", i, events[i].ID, roomEventIDs[i])
|
||||
}
|
||||
}
|
||||
|
||||
// the ref counter should be 1 for this snapshot
|
||||
emptyRefs, err := accumulator.snapshotRefCountTable.DeleteEmptyRefs(txn)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to delete empty refs: %s", err)
|
||||
}
|
||||
if len(emptyRefs) > 0 {
|
||||
t.Fatalf("got %d empty refs, want none", len(emptyRefs))
|
||||
}
|
||||
}
|
@ -13,15 +13,12 @@ type Event struct {
|
||||
JSON []byte `db:"event"`
|
||||
}
|
||||
|
||||
// EventTable stores events. A unique numeric ID is associated with each event.
|
||||
type EventTable struct {
|
||||
db *sqlx.DB
|
||||
}
|
||||
|
||||
func NewEventTable(postgresURI string) *EventTable {
|
||||
db, err := sqlx.Open("postgres", postgresURI)
|
||||
if err != nil {
|
||||
log.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
|
||||
}
|
||||
// NewEventTable makes a new EventTable
|
||||
func NewEventTable(db *sqlx.DB) *EventTable {
|
||||
// make sure tables are made
|
||||
db.MustExec(`
|
||||
CREATE SEQUENCE IF NOT EXISTS syncv3_event_nids_seq;
|
||||
@ -31,12 +28,12 @@ func NewEventTable(postgresURI string) *EventTable {
|
||||
event JSONB NOT NULL
|
||||
);
|
||||
`)
|
||||
return &EventTable{
|
||||
db: db,
|
||||
}
|
||||
return &EventTable{}
|
||||
}
|
||||
|
||||
func (t *EventTable) Insert(events []Event) error {
|
||||
// Insert events into the event table. Returns the number of rows added. If the number of rows is >0,
|
||||
// and the list of events is in sync stream order, it can be inferred that the last element(s) are new.
|
||||
func (t *EventTable) Insert(txn *sqlx.Tx, events []Event) (int, error) {
|
||||
// ensure event_id is set
|
||||
for i := range events {
|
||||
ev := events[i]
|
||||
@ -45,32 +42,46 @@ func (t *EventTable) Insert(events []Event) error {
|
||||
}
|
||||
eventIDResult := gjson.GetBytes(ev.JSON, "event_id")
|
||||
if !eventIDResult.Exists() || eventIDResult.Str == "" {
|
||||
return fmt.Errorf("event JSON missing event_id key")
|
||||
return 0, fmt.Errorf("event JSON missing event_id key")
|
||||
}
|
||||
ev.ID = eventIDResult.Str
|
||||
events[i] = ev
|
||||
}
|
||||
_, err := t.db.NamedExec(`INSERT INTO syncv3_events (event_id, event)
|
||||
result, err := txn.NamedExec(`INSERT INTO syncv3_events (event_id, event)
|
||||
VALUES (:event_id, :event) ON CONFLICT (event_id) DO NOTHING`, events)
|
||||
return err
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
ra, err := result.RowsAffected()
|
||||
return int(ra), err
|
||||
}
|
||||
|
||||
func (t *EventTable) SelectByNIDs(nids []int) (events []Event, err error) {
|
||||
query, args, err := sqlx.In("SELECT * FROM syncv3_events WHERE event_nid IN (?);", nids)
|
||||
query = t.db.Rebind(query)
|
||||
func (t *EventTable) SelectByNIDs(txn *sqlx.Tx, nids []int64) (events []Event, err error) {
|
||||
query, args, err := sqlx.In("SELECT event_nid, event_id, event FROM syncv3_events WHERE event_nid IN (?) ORDER BY event_nid ASC;", nids)
|
||||
query = txn.Rebind(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = t.db.Select(&events, query, args...)
|
||||
err = txn.Select(&events, query, args...)
|
||||
return
|
||||
}
|
||||
|
||||
func (t *EventTable) SelectByIDs(ids []string) (events []Event, err error) {
|
||||
query, args, err := sqlx.In("SELECT * FROM syncv3_events WHERE event_id IN (?);", ids)
|
||||
query = t.db.Rebind(query)
|
||||
func (t *EventTable) SelectByIDs(txn *sqlx.Tx, ids []string) (events []Event, err error) {
|
||||
query, args, err := sqlx.In("SELECT event_nid, event_id, event FROM syncv3_events WHERE event_id IN (?);", ids)
|
||||
query = txn.Rebind(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = t.db.Select(&events, query, args...)
|
||||
err = txn.Select(&events, query, args...)
|
||||
return
|
||||
}
|
||||
|
||||
func (t *EventTable) SelectNIDsByIDs(txn *sqlx.Tx, ids []string) (nids []int64, err error) {
|
||||
query, args, err := sqlx.In("SELECT event_nid FROM syncv3_events WHERE event_id IN (?);", ids)
|
||||
query = txn.Rebind(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = txn.Select(&nids, query, args...)
|
||||
return
|
||||
}
|
||||
|
@ -1,9 +1,22 @@
|
||||
package state
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
func TestEventTable(t *testing.T) {
|
||||
table := NewEventTable("user=kegan dbname=syncv3 sslmode=disable")
|
||||
db, err := sqlx.Open("postgres", "user=kegan dbname=syncv3 sslmode=disable")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open SQL db: %s", err)
|
||||
}
|
||||
txn, err := db.Beginx()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start txn: %s", err)
|
||||
}
|
||||
defer txn.Rollback()
|
||||
table := NewEventTable(db)
|
||||
events := []Event{
|
||||
{
|
||||
ID: "100",
|
||||
@ -18,16 +31,24 @@ func TestEventTable(t *testing.T) {
|
||||
JSON: []byte(`{"event_id":"102", "foo":"bar"}`),
|
||||
},
|
||||
}
|
||||
if err := table.Insert(events); err != nil {
|
||||
numNew, err := table.Insert(txn, events)
|
||||
if err != nil {
|
||||
t.Fatalf("Insert failed: %s", err)
|
||||
}
|
||||
if numNew != len(events) {
|
||||
t.Fatalf("wanted %d new events, got %d", len(events), numNew)
|
||||
}
|
||||
// duplicate insert is ok
|
||||
if err := table.Insert(events); err != nil {
|
||||
numNew, err = table.Insert(txn, events)
|
||||
if err != nil {
|
||||
t.Fatalf("Insert failed: %s", err)
|
||||
}
|
||||
if numNew != 0 {
|
||||
t.Fatalf("wanted 0 new events, got %d", numNew)
|
||||
}
|
||||
|
||||
// pulling non-existent ids returns no error but a zero slice
|
||||
events, err := table.SelectByIDs([]string{"101010101010"})
|
||||
events, err = table.SelectByIDs(txn, []string{"101010101010"})
|
||||
if err != nil {
|
||||
t.Fatalf("SelectByIDs failed: %s", err)
|
||||
}
|
||||
@ -36,7 +57,7 @@ func TestEventTable(t *testing.T) {
|
||||
}
|
||||
|
||||
// pulling events by event_id is ok
|
||||
events, err = table.SelectByIDs([]string{"100", "101", "102"})
|
||||
events, err = table.SelectByIDs(txn, []string{"100", "101", "102"})
|
||||
if err != nil {
|
||||
t.Fatalf("SelectByIDs failed: %s", err)
|
||||
}
|
||||
@ -44,12 +65,21 @@ func TestEventTable(t *testing.T) {
|
||||
t.Fatalf("SelectByIDs returned %d events, want 3", len(events))
|
||||
}
|
||||
|
||||
var nids []int
|
||||
// pulling nids by event_id is ok
|
||||
gotnids, err := table.SelectNIDsByIDs(txn, []string{"100", "101", "102"})
|
||||
if err != nil {
|
||||
t.Fatalf("SelectNIDsByIDs failed: %s", err)
|
||||
}
|
||||
if len(gotnids) != 3 {
|
||||
t.Fatalf("SelectNIDsByIDs returned %d events, want 3", len(gotnids))
|
||||
}
|
||||
|
||||
var nids []int64
|
||||
for _, ev := range events {
|
||||
nids = append(nids, ev.NID)
|
||||
nids = append(nids, int64(ev.NID))
|
||||
}
|
||||
// pulling events by event nid is ok
|
||||
events, err = table.SelectByNIDs(nids)
|
||||
events, err = table.SelectByNIDs(txn, nids)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectByNIDs failed: %s", err)
|
||||
}
|
||||
@ -58,7 +88,7 @@ func TestEventTable(t *testing.T) {
|
||||
}
|
||||
|
||||
// pulling non-existent nids returns no error but a zero slice
|
||||
events, err = table.SelectByNIDs([]int{9999999})
|
||||
events, err = table.SelectByNIDs(txn, []int64{9999999})
|
||||
if err != nil {
|
||||
t.Fatalf("SelectByNIDs failed: %s", err)
|
||||
}
|
||||
|
37
state/rooms_table.go
Normal file
37
state/rooms_table.go
Normal file
@ -0,0 +1,37 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// RoomsTable stores the current snapshot for a room.
|
||||
type RoomsTable struct {
|
||||
}
|
||||
|
||||
func NewRoomsTable(db *sqlx.DB) *RoomsTable {
|
||||
// make sure tables are made
|
||||
db.MustExec(`
|
||||
CREATE TABLE IF NOT EXISTS syncv3_rooms (
|
||||
room_id TEXT NOT NULL PRIMARY KEY,
|
||||
current_snapshot_id BIGINT NOT NULL
|
||||
);
|
||||
`)
|
||||
return &RoomsTable{}
|
||||
}
|
||||
|
||||
func (t *RoomsTable) UpdateCurrentSnapshotID(txn *sqlx.Tx, roomID string, snapshotID int) (err error) {
|
||||
_, err = txn.Exec(`
|
||||
INSERT INTO syncv3_rooms(room_id, current_snapshot_id) VALUES($1, $2)
|
||||
ON CONFLICT (room_id) DO UPDATE SET current_snapshot_id = $2`, roomID, snapshotID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *RoomsTable) CurrentSnapshotID(txn *sqlx.Tx, roomID string) (snapshotID int, err error) {
|
||||
err = txn.QueryRow(`SELECT current_snapshot_id FROM syncv3_rooms WHERE room_id=$1`, roomID).Scan(&snapshotID)
|
||||
if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
52
state/rooms_table_test.go
Normal file
52
state/rooms_table_test.go
Normal file
@ -0,0 +1,52 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
func TestRoomsTable(t *testing.T) {
|
||||
db, err := sqlx.Open("postgres", "user=kegan dbname=syncv3 sslmode=disable")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open SQL db: %s", err)
|
||||
}
|
||||
txn, err := db.Beginx()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start txn: %s", err)
|
||||
}
|
||||
defer txn.Rollback()
|
||||
table := NewRoomsTable(db)
|
||||
|
||||
// Insert 100
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start txn: %s", err)
|
||||
}
|
||||
roomID := "!1:localhost"
|
||||
if err = table.UpdateCurrentSnapshotID(txn, roomID, 100); err != nil {
|
||||
t.Fatalf("Failed to update current snapshot ID: %s", err)
|
||||
}
|
||||
|
||||
// Select 100
|
||||
id, err := table.CurrentSnapshotID(txn, roomID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to select current snapshot ID: %s", err)
|
||||
}
|
||||
if id != 100 {
|
||||
t.Fatalf("current snapshot id mismatch, got %d want %d", id, 100)
|
||||
}
|
||||
|
||||
// Update to 101
|
||||
if table.UpdateCurrentSnapshotID(txn, roomID, 101); err != nil {
|
||||
t.Fatalf("Failed to update current snapshot ID: %s", err)
|
||||
}
|
||||
|
||||
// Select 101
|
||||
id, err = table.CurrentSnapshotID(txn, roomID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to select current snapshot ID: %s", err)
|
||||
}
|
||||
if id != 101 {
|
||||
t.Fatalf("current snapshot id mismatch, got %d want %d", id, 101)
|
||||
}
|
||||
}
|
@ -1,20 +1,19 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// SnapshotRefCountsTable maintains a counter per room snapshot which represents
|
||||
// how many clients have this snapshot as their 'latest' state. These snapshots
|
||||
// are paginatable by clients.
|
||||
//
|
||||
// The current state snapshot for a room always has a ref count > 0 (the reference is held by the server, not any clients)
|
||||
// to prevent the current state snapshot from being garbage collected.
|
||||
type SnapshotRefCountsTable struct {
|
||||
db *sqlx.DB
|
||||
}
|
||||
|
||||
func NewSnapshotRefCountsTable(postgresURI string) *SnapshotRefCountsTable {
|
||||
db, err := sqlx.Open("postgres", postgresURI)
|
||||
if err != nil {
|
||||
log.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
|
||||
}
|
||||
func NewSnapshotRefCountsTable(db *sqlx.DB) *SnapshotRefCountsTable {
|
||||
// make sure tables are made
|
||||
db.MustExec(`
|
||||
CREATE TABLE IF NOT EXISTS syncv3_snapshot_ref_counts (
|
||||
@ -22,20 +21,28 @@ func NewSnapshotRefCountsTable(postgresURI string) *SnapshotRefCountsTable {
|
||||
ref_count BIGINT NOT NULL
|
||||
);
|
||||
`)
|
||||
return &SnapshotRefCountsTable{
|
||||
db: db,
|
||||
return &SnapshotRefCountsTable{}
|
||||
}
|
||||
|
||||
// DeleteEmptyRefs removes snapshot ref counts which are 0 and returns the snapshot IDs of the affected rows.
|
||||
func (s *SnapshotRefCountsTable) DeleteEmptyRefs(txn *sqlx.Tx) (emptyRefs []int, err error) {
|
||||
err = txn.Select(&emptyRefs, `SELECT snapshot_id FROM syncv3_snapshot_ref_counts WHERE ref_count = 0`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = txn.Exec(`DELETE FROM syncv3_snapshot_ref_counts WHERE ref_count = 0`)
|
||||
return
|
||||
}
|
||||
|
||||
// Select a row based on its snapshot ID.
|
||||
func (s *SnapshotRefCountsTable) Decrement(tx *sql.Tx, snapshotID int) (count int, err error) {
|
||||
err = tx.QueryRow(`UPDATE syncv3_snapshot_ref_counts SET ref_count = syncv3_snapshot_ref_counts.ref_count - 1 WHERE snapshot_id=$1 RETURNING ref_count`, snapshotID).Scan(&count)
|
||||
func (s *SnapshotRefCountsTable) Decrement(txn *sqlx.Tx, snapshotID int) (count int, err error) {
|
||||
err = txn.QueryRow(`UPDATE syncv3_snapshot_ref_counts SET ref_count = syncv3_snapshot_ref_counts.ref_count - 1 WHERE snapshot_id=$1 RETURNING ref_count`, snapshotID).Scan(&count)
|
||||
return
|
||||
}
|
||||
|
||||
// Insert the row. Modifies SnapshotID to be the inserted primary key.
|
||||
func (s *SnapshotRefCountsTable) Increment(tx *sql.Tx, snapshotID int) (count int, err error) {
|
||||
err = tx.QueryRow(`INSERT INTO syncv3_snapshot_ref_counts(snapshot_id, ref_count) VALUES ($1, $2)
|
||||
func (s *SnapshotRefCountsTable) Increment(txn *sqlx.Tx, snapshotID int) (count int, err error) {
|
||||
err = txn.QueryRow(`INSERT INTO syncv3_snapshot_ref_counts(snapshot_id, ref_count) VALUES ($1, $2)
|
||||
ON CONFLICT (snapshot_id) DO UPDATE SET ref_count=syncv3_snapshot_ref_counts.ref_count+1 RETURNING ref_count`, snapshotID, 1).Scan(&count)
|
||||
return
|
||||
}
|
||||
|
@ -1,32 +1,41 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
func TestSnapshotRefCountTable(t *testing.T) {
|
||||
table := NewSnapshotRefCountsTable("user=kegan dbname=syncv3 sslmode=disable")
|
||||
tx, err := table.db.Begin()
|
||||
db, err := sqlx.Open("postgres", "user=kegan dbname=syncv3 sslmode=disable")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open SQL db: %s", err)
|
||||
}
|
||||
table := NewSnapshotRefCountsTable(db)
|
||||
tx, err := db.Beginx()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start txn: %s", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
snapshotID := 100
|
||||
assertIncrement(t, tx, table, snapshotID, 1)
|
||||
assertIncrement(t, tx, table, snapshotID, 2)
|
||||
assertDecrement(t, tx, table, snapshotID, 1)
|
||||
assertDecrement(t, tx, table, snapshotID, 0)
|
||||
tx.Commit()
|
||||
}
|
||||
|
||||
func TestSnapshotRefCountTableConcurrent(t *testing.T) {
|
||||
table := NewSnapshotRefCountsTable("user=kegan dbname=syncv3 sslmode=disable")
|
||||
tx1, err := table.db.Begin()
|
||||
db, err := sqlx.Open("postgres", "user=kegan dbname=syncv3 sslmode=disable")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open SQL db: %s", err)
|
||||
}
|
||||
table := NewSnapshotRefCountsTable(db)
|
||||
tx1, err := db.Beginx()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start txn1: %s", err)
|
||||
}
|
||||
tx2, err := table.db.Begin()
|
||||
tx2, err := db.Beginx()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start txn2: %s", err)
|
||||
}
|
||||
@ -36,7 +45,7 @@ func TestSnapshotRefCountTableConcurrent(t *testing.T) {
|
||||
// so we should get 4 as the value
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
incrBy2 := func(tx *sql.Tx) {
|
||||
incrBy2 := func(tx *sqlx.Tx) {
|
||||
defer wg.Done()
|
||||
_, err = table.Increment(tx, snapshotID)
|
||||
if err != nil {
|
||||
@ -56,16 +65,19 @@ func TestSnapshotRefCountTableConcurrent(t *testing.T) {
|
||||
go incrBy2(tx2)
|
||||
wg.Wait()
|
||||
|
||||
tx3, err := table.db.Begin()
|
||||
tx3, err := db.Beginx()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start txn3: %s", err)
|
||||
}
|
||||
// 4-1=3
|
||||
assertDecrement(t, tx3, table, snapshotID, 3)
|
||||
assertDecrement(t, tx3, table, snapshotID, 2)
|
||||
assertDecrement(t, tx3, table, snapshotID, 1)
|
||||
assertDecrement(t, tx3, table, snapshotID, 0)
|
||||
tx3.Commit()
|
||||
}
|
||||
|
||||
func assertIncrement(t *testing.T, tx *sql.Tx, table *SnapshotRefCountsTable, snapshotID int, wantVal int) {
|
||||
func assertIncrement(t *testing.T, tx *sqlx.Tx, table *SnapshotRefCountsTable, snapshotID int, wantVal int) {
|
||||
t.Helper()
|
||||
count, err := table.Increment(tx, snapshotID)
|
||||
if err != nil {
|
||||
@ -76,7 +88,7 @@ func assertIncrement(t *testing.T, tx *sql.Tx, table *SnapshotRefCountsTable, sn
|
||||
}
|
||||
}
|
||||
|
||||
func assertDecrement(t *testing.T, tx *sql.Tx, table *SnapshotRefCountsTable, snapshotID int, wantVal int) {
|
||||
func assertDecrement(t *testing.T, tx *sqlx.Tx, table *SnapshotRefCountsTable, snapshotID int, wantVal int) {
|
||||
t.Helper()
|
||||
count, err := table.Decrement(tx, snapshotID)
|
||||
if err != nil {
|
||||
|
@ -11,15 +11,12 @@ type SnapshotRow struct {
|
||||
Events pq.Int64Array `db:"events"`
|
||||
}
|
||||
|
||||
// SnapshotTable stores room state snapshots. Each snapshot has a unique numeric ID.
|
||||
// Not every event will be associated with a snapshot.
|
||||
type SnapshotTable struct {
|
||||
db *sqlx.DB
|
||||
}
|
||||
|
||||
func NewSnapshotsTable(postgresURI string) *SnapshotTable {
|
||||
db, err := sqlx.Open("postgres", postgresURI)
|
||||
if err != nil {
|
||||
log.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
|
||||
}
|
||||
func NewSnapshotsTable(db *sqlx.DB) *SnapshotTable {
|
||||
// make sure tables are made
|
||||
db.MustExec(`
|
||||
CREATE SEQUENCE IF NOT EXISTS syncv3_snapshots_seq;
|
||||
@ -29,21 +26,30 @@ func NewSnapshotsTable(postgresURI string) *SnapshotTable {
|
||||
events BIGINT[] NOT NULL
|
||||
);
|
||||
`)
|
||||
return &SnapshotTable{
|
||||
db: db,
|
||||
}
|
||||
return &SnapshotTable{}
|
||||
}
|
||||
|
||||
// Select a row based on its snapshot ID.
|
||||
func (s *SnapshotTable) Select(snapshotID int) (row SnapshotRow, err error) {
|
||||
err = s.db.Get(&row, `SELECT * FROM syncv3_snapshots WHERE snapshot_id = $1`, snapshotID)
|
||||
func (s *SnapshotTable) Select(txn *sqlx.Tx, snapshotID int) (row SnapshotRow, err error) {
|
||||
err = txn.Get(&row, `SELECT * FROM syncv3_snapshots WHERE snapshot_id = $1`, snapshotID)
|
||||
return
|
||||
}
|
||||
|
||||
// Insert the row. Modifies SnapshotID to be the inserted primary key.
|
||||
func (s *SnapshotTable) Insert(row *SnapshotRow) error {
|
||||
func (s *SnapshotTable) Insert(txn *sqlx.Tx, row *SnapshotRow) error {
|
||||
var id int
|
||||
err := s.db.QueryRow(`INSERT INTO syncv3_snapshots(room_id, events) VALUES($1, $2) RETURNING snapshot_id`, row.RoomID, row.Events).Scan(&id)
|
||||
err := txn.QueryRow(`INSERT INTO syncv3_snapshots(room_id, events) VALUES($1, $2) RETURNING snapshot_id`, row.RoomID, row.Events).Scan(&id)
|
||||
row.SnapshotID = id
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete the snapshot IDs given
|
||||
func (s *SnapshotTable) Delete(txn *sqlx.Tx, snapshotIDs []int) error {
|
||||
query, args, err := sqlx.In(`DELETE FROM syncv3_snapshots WHERE snapshot_id IN (?)`, snapshotIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
query = txn.Rebind(query)
|
||||
_, err = txn.Exec(query, args...)
|
||||
return err
|
||||
}
|
||||
|
@ -4,23 +4,36 @@ import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
func TestSnapshotTable(t *testing.T) {
|
||||
table := NewSnapshotsTable("user=kegan dbname=syncv3 sslmode=disable")
|
||||
db, err := sqlx.Open("postgres", "user=kegan dbname=syncv3 sslmode=disable")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open SQL db: %s", err)
|
||||
}
|
||||
txn, err := db.Beginx()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start txn: %s", err)
|
||||
}
|
||||
table := NewSnapshotsTable(db)
|
||||
|
||||
// Insert a snapshot
|
||||
want := &SnapshotRow{
|
||||
RoomID: "A",
|
||||
Events: pq.Int64Array{1, 2, 3, 4, 5, 6, 7},
|
||||
}
|
||||
err := table.Insert(want)
|
||||
err = table.Insert(txn, want)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert: %s", err)
|
||||
}
|
||||
if want.SnapshotID == 0 {
|
||||
t.Fatalf("Snapshot ID not set")
|
||||
}
|
||||
got, err := table.Select(want.SnapshotID)
|
||||
|
||||
// Select the same snapshot and assert
|
||||
got, err := table.Select(txn, want.SnapshotID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to select: %s", err)
|
||||
}
|
||||
@ -33,4 +46,10 @@ func TestSnapshotTable(t *testing.T) {
|
||||
if !reflect.DeepEqual(got.Events, want.Events) {
|
||||
t.Errorf("mismatched events, got: %+v want: %+v", got.Events, want.Events)
|
||||
}
|
||||
|
||||
// Delete the snapshot
|
||||
err = table.Delete(txn, []int{want.SnapshotID})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to delete snapshot: %s", err)
|
||||
}
|
||||
}
|
||||
|
@ -1,12 +0,0 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
var log = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.ConsoleWriter{
|
||||
Out: os.Stderr,
|
||||
TimeFormat: "15:04:05",
|
||||
})
|
Loading…
x
Reference in New Issue
Block a user