Add Accumulator.Initialise with tests

This commit is contained in:
Kegan Dougal 2021-05-27 19:20:36 +01:00
parent f3e0f96d91
commit cd20d07d9f
11 changed files with 528 additions and 84 deletions

215
state/accumulator.go Normal file
View File

@ -0,0 +1,215 @@
package state
import (
"encoding/json"
"fmt"
"os"
"sync"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
"github.com/rs/zerolog"
)
var log = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.ConsoleWriter{
Out: os.Stderr,
TimeFormat: "15:04:05",
})
// Accumulator tracks room state and timelines.
//
// In order for it to remain simple(ish), the accumulator DOES NOT SUPPORT arbitrary timeline gaps.
// There is an Initialise function for new rooms (with some pre-determined state) and then a constant
// Accumulate function for timeline events. v2 sync must be called with a large enough timeline.limit
// for this to work!
type Accumulator struct {
mu *sync.Mutex // lock for locks
// TODO: unbounded on number of rooms
locks map[string]*sync.Mutex // room_id -> mutex
db *sqlx.DB
roomsTable *RoomsTable
eventsTable *EventTable
snapshotTable *SnapshotTable
snapshotRefCountTable *SnapshotRefCountsTable
}
func NewAccumulator(postgresURI string) *Accumulator {
db, err := sqlx.Open("postgres", postgresURI)
if err != nil {
log.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
}
return &Accumulator{
db: db,
mu: &sync.Mutex{},
locks: make(map[string]*sync.Mutex),
roomsTable: NewRoomsTable(db),
eventsTable: NewEventTable(db),
snapshotTable: NewSnapshotsTable(db),
snapshotRefCountTable: NewSnapshotRefCountsTable(db),
}
}
// obtain a per-room lock
func (a *Accumulator) mutex(roomID string) *sync.Mutex {
a.mu.Lock()
defer a.mu.Unlock()
lock, ok := a.locks[roomID]
if !ok {
lock = &sync.Mutex{}
a.locks[roomID] = lock
}
return lock
}
// clearSnapshots deletes all snapshots with 0 refs to it
func (a *Accumulator) clearSnapshots(txn *sqlx.Tx) {
snapshotIDs, err := a.snapshotRefCountTable.DeleteEmptyRefs(txn)
if err != nil {
log.Err(err).Msg("failed to DeleteEmptyRefs")
return
}
err = a.snapshotTable.Delete(txn, snapshotIDs)
if err != nil {
log.Err(err).Ints("snapshots", snapshotIDs).Msg("failed to delete snapshot IDs")
}
}
// moveSnapshotRef decrements the from snapshot ID and increments the to snapshot ID
// Clears the from snapshot if there are 0 refs to it
func (a *Accumulator) moveSnapshotRef(txn *sqlx.Tx, from, to int) error {
if from != 0 {
count, err := a.snapshotRefCountTable.Decrement(txn, from)
if err != nil {
return err
}
if count == 0 {
a.clearSnapshots(txn)
}
}
_, err := a.snapshotRefCountTable.Increment(txn, to)
return err
}
// Initialise starts a new sync accumulator for the given room using the given state as a baseline.
// This will only take effect if this is the first time the v3 server has seen this room, and it wasn't
// possible to get all events up to the create event (e.g Matrix HQ).
//
// This function:
// - Stores these events
// - Sets up the current snapshot based on the state list given.
func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) error {
if len(state) == 0 {
return nil
}
return WithTransaction(a.db, func(txn *sqlx.Tx) error {
// Attempt to short-circuit. This has to be done inside a transaction to make sure
// we don't race with multiple calls to Initialise with the same room ID.
snapshotID, err := a.roomsTable.CurrentSnapshotID(txn, roomID)
if err != nil {
return fmt.Errorf("error fetching snapshot id for room %s: %s", roomID, err)
}
if snapshotID > 0 {
// we only initialise rooms once
log.Info().Str("room_id", roomID).Int("snapshot_id", snapshotID).Msg("Accumulator.Initialise called but current snapshot already exists, bailing early")
return nil
}
// Insert the events
events := make([]Event, len(state))
for i := range events {
events[i] = Event{
JSON: state[i],
}
}
numNew, err := a.eventsTable.Insert(txn, events)
if err != nil {
return err
}
if numNew == 0 {
// we don't have a current snapshot for this room but yet no events are new,
// no idea how this should be handled.
log.Error().Str("room_id", roomID).Msg(
"Accumulator.Initialise: room has no current snapshot but also no new inserted events, doing nothing. This is probably a bug.",
)
return nil
}
// pull out the event NIDs we just inserted
eventIDs := make([]string, len(events))
for i := range eventIDs {
eventIDs[i] = events[i].ID
}
nids, err := a.eventsTable.SelectNIDsByIDs(txn, eventIDs)
if err != nil {
return err
}
// Make a current snapshot
snapshot := &SnapshotRow{
RoomID: roomID,
Events: pq.Int64Array(nids),
}
err = a.snapshotTable.Insert(txn, snapshot)
if err != nil {
return err
}
// Increment the ref counter
err = a.moveSnapshotRef(txn, 0, snapshot.SnapshotID)
if err != nil {
return err
}
// Set the snapshot ID as the current state
return a.roomsTable.UpdateCurrentSnapshotID(txn, roomID, snapshot.SnapshotID)
})
}
// Accumulate internal state from a user's sync response. Locks per-room.
//
// This function does several things:
// - It ensures all events are persisted in the database. This is shared amongst users.
// - If the last event in the timeline has been stored before, then it short circuits and returns.
// This is because we must have already processed this in order for the event to exist in the database,
// and the sync stream is already linearised for us.
// - Else it creates a new room state snapshot (as this now represents the current state)
// - It checks if there are outstanding references for the previous snapshot, and if not, removes the old snapshot from the database.
// References are made when clients have synced up to a given snapshot (hence may paginate at that point).
func (a *Accumulator) Accumulate(roomID string, timeline []json.RawMessage) error {
if len(timeline) == 0 {
return nil
}
return WithTransaction(a.db, func(txn *sqlx.Tx) error {
return nil
})
}
// WithTransaction runs a block of code passing in an SQL transaction
// If the code returns an error or panics then the transactions is rolled back
// Otherwise the transaction is committed.
func WithTransaction(db *sqlx.DB, fn func(txn *sqlx.Tx) error) (err error) {
txn, err := db.Beginx()
if err != nil {
return fmt.Errorf("WithTransaction.Begin: %w", err)
}
defer func() {
panicErr := recover()
if err == nil && panicErr != nil {
err = fmt.Errorf("panic: %v", panicErr)
}
var txnErr error
if err != nil {
txnErr = txn.Rollback()
} else {
txnErr = txn.Commit()
}
if txnErr != nil && err == nil {
err = fmt.Errorf("WithTransaction failed to commit/rollback: %w", txnErr)
}
}()
err = fn(txn)
return
}

67
state/accumulator_test.go Normal file
View File

@ -0,0 +1,67 @@
package state
import (
"encoding/json"
"testing"
)
func TestAccumulator(t *testing.T) {
roomID := "!1:localhost"
roomEvents := []json.RawMessage{
[]byte(`{"event_id":"A", "type":"m.room.create", "state_key":"", "content":{"creator":"@me:localhost"}}`),
[]byte(`{"event_id":"B", "type":"m.room.member", "state_key":"@me:localhost", "content":{"membership":"join"}}`),
[]byte(`{"event_id":"C", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
}
roomEventIDs := []string{"A", "B", "C"}
accumulator := NewAccumulator("user=kegan dbname=syncv3 sslmode=disable")
err := accumulator.Initialise(roomID, roomEvents)
if err != nil {
t.Fatalf("falied to Initialise accumulator: %s", err)
}
txn, err := accumulator.db.Beginx()
if err != nil {
t.Fatalf("failed to start assert txn: %s", err)
}
// There should be one snapshot on the current state
snapID, err := accumulator.roomsTable.CurrentSnapshotID(txn, roomID)
if err != nil {
t.Fatalf("failed to select current snapshot: %s", err)
}
if snapID == 0 {
t.Fatalf("Initialise did not store a current snapshot")
}
// this snapshot should have 3 events in it
row, err := accumulator.snapshotTable.Select(txn, snapID)
if err != nil {
t.Fatalf("failed to select snapshot %d: %s", snapID, err)
}
if len(row.Events) != len(roomEvents) {
t.Fatalf("got %d events, want %d in current state snapshot", len(row.Events), len(roomEvents))
}
// these 3 events should map to the three events we initialised with
events, err := accumulator.eventsTable.SelectByNIDs(txn, row.Events)
if err != nil {
t.Fatalf("failed to extract events in snapshot: %s", err)
}
if len(events) != len(roomEvents) {
t.Fatalf("failed to extract %d events, got %d", len(roomEvents), len(events))
}
for i := range events {
if events[i].ID != roomEventIDs[i] {
t.Errorf("event %d was not stored correctly: got ID %s want %s", i, events[i].ID, roomEventIDs[i])
}
}
// the ref counter should be 1 for this snapshot
emptyRefs, err := accumulator.snapshotRefCountTable.DeleteEmptyRefs(txn)
if err != nil {
t.Fatalf("failed to delete empty refs: %s", err)
}
if len(emptyRefs) > 0 {
t.Fatalf("got %d empty refs, want none", len(emptyRefs))
}
}

View File

@ -13,15 +13,12 @@ type Event struct {
JSON []byte `db:"event"`
}
// EventTable stores events. A unique numeric ID is associated with each event.
type EventTable struct {
db *sqlx.DB
}
func NewEventTable(postgresURI string) *EventTable {
db, err := sqlx.Open("postgres", postgresURI)
if err != nil {
log.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
}
// NewEventTable makes a new EventTable
func NewEventTable(db *sqlx.DB) *EventTable {
// make sure tables are made
db.MustExec(`
CREATE SEQUENCE IF NOT EXISTS syncv3_event_nids_seq;
@ -31,12 +28,12 @@ func NewEventTable(postgresURI string) *EventTable {
event JSONB NOT NULL
);
`)
return &EventTable{
db: db,
}
return &EventTable{}
}
func (t *EventTable) Insert(events []Event) error {
// Insert events into the event table. Returns the number of rows added. If the number of rows is >0,
// and the list of events is in sync stream order, it can be inferred that the last element(s) are new.
func (t *EventTable) Insert(txn *sqlx.Tx, events []Event) (int, error) {
// ensure event_id is set
for i := range events {
ev := events[i]
@ -45,32 +42,46 @@ func (t *EventTable) Insert(events []Event) error {
}
eventIDResult := gjson.GetBytes(ev.JSON, "event_id")
if !eventIDResult.Exists() || eventIDResult.Str == "" {
return fmt.Errorf("event JSON missing event_id key")
return 0, fmt.Errorf("event JSON missing event_id key")
}
ev.ID = eventIDResult.Str
events[i] = ev
}
_, err := t.db.NamedExec(`INSERT INTO syncv3_events (event_id, event)
result, err := txn.NamedExec(`INSERT INTO syncv3_events (event_id, event)
VALUES (:event_id, :event) ON CONFLICT (event_id) DO NOTHING`, events)
return err
if err != nil {
return 0, err
}
ra, err := result.RowsAffected()
return int(ra), err
}
func (t *EventTable) SelectByNIDs(nids []int) (events []Event, err error) {
query, args, err := sqlx.In("SELECT * FROM syncv3_events WHERE event_nid IN (?);", nids)
query = t.db.Rebind(query)
func (t *EventTable) SelectByNIDs(txn *sqlx.Tx, nids []int64) (events []Event, err error) {
query, args, err := sqlx.In("SELECT event_nid, event_id, event FROM syncv3_events WHERE event_nid IN (?) ORDER BY event_nid ASC;", nids)
query = txn.Rebind(query)
if err != nil {
return nil, err
}
err = t.db.Select(&events, query, args...)
err = txn.Select(&events, query, args...)
return
}
func (t *EventTable) SelectByIDs(ids []string) (events []Event, err error) {
query, args, err := sqlx.In("SELECT * FROM syncv3_events WHERE event_id IN (?);", ids)
query = t.db.Rebind(query)
func (t *EventTable) SelectByIDs(txn *sqlx.Tx, ids []string) (events []Event, err error) {
query, args, err := sqlx.In("SELECT event_nid, event_id, event FROM syncv3_events WHERE event_id IN (?);", ids)
query = txn.Rebind(query)
if err != nil {
return nil, err
}
err = t.db.Select(&events, query, args...)
err = txn.Select(&events, query, args...)
return
}
func (t *EventTable) SelectNIDsByIDs(txn *sqlx.Tx, ids []string) (nids []int64, err error) {
query, args, err := sqlx.In("SELECT event_nid FROM syncv3_events WHERE event_id IN (?);", ids)
query = txn.Rebind(query)
if err != nil {
return nil, err
}
err = txn.Select(&nids, query, args...)
return
}

View File

@ -1,9 +1,22 @@
package state
import "testing"
import (
"testing"
"github.com/jmoiron/sqlx"
)
func TestEventTable(t *testing.T) {
table := NewEventTable("user=kegan dbname=syncv3 sslmode=disable")
db, err := sqlx.Open("postgres", "user=kegan dbname=syncv3 sslmode=disable")
if err != nil {
t.Fatalf("failed to open SQL db: %s", err)
}
txn, err := db.Beginx()
if err != nil {
t.Fatalf("failed to start txn: %s", err)
}
defer txn.Rollback()
table := NewEventTable(db)
events := []Event{
{
ID: "100",
@ -18,16 +31,24 @@ func TestEventTable(t *testing.T) {
JSON: []byte(`{"event_id":"102", "foo":"bar"}`),
},
}
if err := table.Insert(events); err != nil {
numNew, err := table.Insert(txn, events)
if err != nil {
t.Fatalf("Insert failed: %s", err)
}
if numNew != len(events) {
t.Fatalf("wanted %d new events, got %d", len(events), numNew)
}
// duplicate insert is ok
if err := table.Insert(events); err != nil {
numNew, err = table.Insert(txn, events)
if err != nil {
t.Fatalf("Insert failed: %s", err)
}
if numNew != 0 {
t.Fatalf("wanted 0 new events, got %d", numNew)
}
// pulling non-existent ids returns no error but a zero slice
events, err := table.SelectByIDs([]string{"101010101010"})
events, err = table.SelectByIDs(txn, []string{"101010101010"})
if err != nil {
t.Fatalf("SelectByIDs failed: %s", err)
}
@ -36,7 +57,7 @@ func TestEventTable(t *testing.T) {
}
// pulling events by event_id is ok
events, err = table.SelectByIDs([]string{"100", "101", "102"})
events, err = table.SelectByIDs(txn, []string{"100", "101", "102"})
if err != nil {
t.Fatalf("SelectByIDs failed: %s", err)
}
@ -44,12 +65,21 @@ func TestEventTable(t *testing.T) {
t.Fatalf("SelectByIDs returned %d events, want 3", len(events))
}
var nids []int
// pulling nids by event_id is ok
gotnids, err := table.SelectNIDsByIDs(txn, []string{"100", "101", "102"})
if err != nil {
t.Fatalf("SelectNIDsByIDs failed: %s", err)
}
if len(gotnids) != 3 {
t.Fatalf("SelectNIDsByIDs returned %d events, want 3", len(gotnids))
}
var nids []int64
for _, ev := range events {
nids = append(nids, ev.NID)
nids = append(nids, int64(ev.NID))
}
// pulling events by event nid is ok
events, err = table.SelectByNIDs(nids)
events, err = table.SelectByNIDs(txn, nids)
if err != nil {
t.Fatalf("SelectByNIDs failed: %s", err)
}
@ -58,7 +88,7 @@ func TestEventTable(t *testing.T) {
}
// pulling non-existent nids returns no error but a zero slice
events, err = table.SelectByNIDs([]int{9999999})
events, err = table.SelectByNIDs(txn, []int64{9999999})
if err != nil {
t.Fatalf("SelectByNIDs failed: %s", err)
}

37
state/rooms_table.go Normal file
View File

@ -0,0 +1,37 @@
package state
import (
"database/sql"
"github.com/jmoiron/sqlx"
)
// RoomsTable stores the current snapshot for a room.
type RoomsTable struct {
}
func NewRoomsTable(db *sqlx.DB) *RoomsTable {
// make sure tables are made
db.MustExec(`
CREATE TABLE IF NOT EXISTS syncv3_rooms (
room_id TEXT NOT NULL PRIMARY KEY,
current_snapshot_id BIGINT NOT NULL
);
`)
return &RoomsTable{}
}
func (t *RoomsTable) UpdateCurrentSnapshotID(txn *sqlx.Tx, roomID string, snapshotID int) (err error) {
_, err = txn.Exec(`
INSERT INTO syncv3_rooms(room_id, current_snapshot_id) VALUES($1, $2)
ON CONFLICT (room_id) DO UPDATE SET current_snapshot_id = $2`, roomID, snapshotID)
return err
}
func (t *RoomsTable) CurrentSnapshotID(txn *sqlx.Tx, roomID string) (snapshotID int, err error) {
err = txn.QueryRow(`SELECT current_snapshot_id FROM syncv3_rooms WHERE room_id=$1`, roomID).Scan(&snapshotID)
if err == sql.ErrNoRows {
err = nil
}
return
}

52
state/rooms_table_test.go Normal file
View File

@ -0,0 +1,52 @@
package state
import (
"testing"
"github.com/jmoiron/sqlx"
)
func TestRoomsTable(t *testing.T) {
db, err := sqlx.Open("postgres", "user=kegan dbname=syncv3 sslmode=disable")
if err != nil {
t.Fatalf("failed to open SQL db: %s", err)
}
txn, err := db.Beginx()
if err != nil {
t.Fatalf("failed to start txn: %s", err)
}
defer txn.Rollback()
table := NewRoomsTable(db)
// Insert 100
if err != nil {
t.Fatalf("Failed to start txn: %s", err)
}
roomID := "!1:localhost"
if err = table.UpdateCurrentSnapshotID(txn, roomID, 100); err != nil {
t.Fatalf("Failed to update current snapshot ID: %s", err)
}
// Select 100
id, err := table.CurrentSnapshotID(txn, roomID)
if err != nil {
t.Fatalf("Failed to select current snapshot ID: %s", err)
}
if id != 100 {
t.Fatalf("current snapshot id mismatch, got %d want %d", id, 100)
}
// Update to 101
if table.UpdateCurrentSnapshotID(txn, roomID, 101); err != nil {
t.Fatalf("Failed to update current snapshot ID: %s", err)
}
// Select 101
id, err = table.CurrentSnapshotID(txn, roomID)
if err != nil {
t.Fatalf("Failed to select current snapshot ID: %s", err)
}
if id != 101 {
t.Fatalf("current snapshot id mismatch, got %d want %d", id, 101)
}
}

View File

@ -1,20 +1,19 @@
package state
import (
"database/sql"
"github.com/jmoiron/sqlx"
)
// SnapshotRefCountsTable maintains a counter per room snapshot which represents
// how many clients have this snapshot as their 'latest' state. These snapshots
// are paginatable by clients.
//
// The current state snapshot for a room always has a ref count > 0 (the reference is held by the server, not any clients)
// to prevent the current state snapshot from being garbage collected.
type SnapshotRefCountsTable struct {
db *sqlx.DB
}
func NewSnapshotRefCountsTable(postgresURI string) *SnapshotRefCountsTable {
db, err := sqlx.Open("postgres", postgresURI)
if err != nil {
log.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
}
func NewSnapshotRefCountsTable(db *sqlx.DB) *SnapshotRefCountsTable {
// make sure tables are made
db.MustExec(`
CREATE TABLE IF NOT EXISTS syncv3_snapshot_ref_counts (
@ -22,20 +21,28 @@ func NewSnapshotRefCountsTable(postgresURI string) *SnapshotRefCountsTable {
ref_count BIGINT NOT NULL
);
`)
return &SnapshotRefCountsTable{
db: db,
return &SnapshotRefCountsTable{}
}
// DeleteEmptyRefs removes snapshot ref counts which are 0 and returns the snapshot IDs of the affected rows.
func (s *SnapshotRefCountsTable) DeleteEmptyRefs(txn *sqlx.Tx) (emptyRefs []int, err error) {
err = txn.Select(&emptyRefs, `SELECT snapshot_id FROM syncv3_snapshot_ref_counts WHERE ref_count = 0`)
if err != nil {
return
}
_, err = txn.Exec(`DELETE FROM syncv3_snapshot_ref_counts WHERE ref_count = 0`)
return
}
// Select a row based on its snapshot ID.
func (s *SnapshotRefCountsTable) Decrement(tx *sql.Tx, snapshotID int) (count int, err error) {
err = tx.QueryRow(`UPDATE syncv3_snapshot_ref_counts SET ref_count = syncv3_snapshot_ref_counts.ref_count - 1 WHERE snapshot_id=$1 RETURNING ref_count`, snapshotID).Scan(&count)
func (s *SnapshotRefCountsTable) Decrement(txn *sqlx.Tx, snapshotID int) (count int, err error) {
err = txn.QueryRow(`UPDATE syncv3_snapshot_ref_counts SET ref_count = syncv3_snapshot_ref_counts.ref_count - 1 WHERE snapshot_id=$1 RETURNING ref_count`, snapshotID).Scan(&count)
return
}
// Insert the row. Modifies SnapshotID to be the inserted primary key.
func (s *SnapshotRefCountsTable) Increment(tx *sql.Tx, snapshotID int) (count int, err error) {
err = tx.QueryRow(`INSERT INTO syncv3_snapshot_ref_counts(snapshot_id, ref_count) VALUES ($1, $2)
func (s *SnapshotRefCountsTable) Increment(txn *sqlx.Tx, snapshotID int) (count int, err error) {
err = txn.QueryRow(`INSERT INTO syncv3_snapshot_ref_counts(snapshot_id, ref_count) VALUES ($1, $2)
ON CONFLICT (snapshot_id) DO UPDATE SET ref_count=syncv3_snapshot_ref_counts.ref_count+1 RETURNING ref_count`, snapshotID, 1).Scan(&count)
return
}

View File

@ -1,32 +1,41 @@
package state
import (
"database/sql"
"sync"
"testing"
"github.com/jmoiron/sqlx"
)
func TestSnapshotRefCountTable(t *testing.T) {
table := NewSnapshotRefCountsTable("user=kegan dbname=syncv3 sslmode=disable")
tx, err := table.db.Begin()
db, err := sqlx.Open("postgres", "user=kegan dbname=syncv3 sslmode=disable")
if err != nil {
t.Fatalf("failed to open SQL db: %s", err)
}
table := NewSnapshotRefCountsTable(db)
tx, err := db.Beginx()
if err != nil {
t.Fatalf("failed to start txn: %s", err)
}
defer tx.Rollback()
snapshotID := 100
assertIncrement(t, tx, table, snapshotID, 1)
assertIncrement(t, tx, table, snapshotID, 2)
assertDecrement(t, tx, table, snapshotID, 1)
assertDecrement(t, tx, table, snapshotID, 0)
tx.Commit()
}
func TestSnapshotRefCountTableConcurrent(t *testing.T) {
table := NewSnapshotRefCountsTable("user=kegan dbname=syncv3 sslmode=disable")
tx1, err := table.db.Begin()
db, err := sqlx.Open("postgres", "user=kegan dbname=syncv3 sslmode=disable")
if err != nil {
t.Fatalf("failed to open SQL db: %s", err)
}
table := NewSnapshotRefCountsTable(db)
tx1, err := db.Beginx()
if err != nil {
t.Fatalf("failed to start txn1: %s", err)
}
tx2, err := table.db.Begin()
tx2, err := db.Beginx()
if err != nil {
t.Fatalf("failed to start txn2: %s", err)
}
@ -36,7 +45,7 @@ func TestSnapshotRefCountTableConcurrent(t *testing.T) {
// so we should get 4 as the value
var wg sync.WaitGroup
wg.Add(2)
incrBy2 := func(tx *sql.Tx) {
incrBy2 := func(tx *sqlx.Tx) {
defer wg.Done()
_, err = table.Increment(tx, snapshotID)
if err != nil {
@ -56,16 +65,19 @@ func TestSnapshotRefCountTableConcurrent(t *testing.T) {
go incrBy2(tx2)
wg.Wait()
tx3, err := table.db.Begin()
tx3, err := db.Beginx()
if err != nil {
t.Fatalf("failed to start txn3: %s", err)
}
// 4-1=3
assertDecrement(t, tx3, table, snapshotID, 3)
assertDecrement(t, tx3, table, snapshotID, 2)
assertDecrement(t, tx3, table, snapshotID, 1)
assertDecrement(t, tx3, table, snapshotID, 0)
tx3.Commit()
}
func assertIncrement(t *testing.T, tx *sql.Tx, table *SnapshotRefCountsTable, snapshotID int, wantVal int) {
func assertIncrement(t *testing.T, tx *sqlx.Tx, table *SnapshotRefCountsTable, snapshotID int, wantVal int) {
t.Helper()
count, err := table.Increment(tx, snapshotID)
if err != nil {
@ -76,7 +88,7 @@ func assertIncrement(t *testing.T, tx *sql.Tx, table *SnapshotRefCountsTable, sn
}
}
func assertDecrement(t *testing.T, tx *sql.Tx, table *SnapshotRefCountsTable, snapshotID int, wantVal int) {
func assertDecrement(t *testing.T, tx *sqlx.Tx, table *SnapshotRefCountsTable, snapshotID int, wantVal int) {
t.Helper()
count, err := table.Decrement(tx, snapshotID)
if err != nil {

View File

@ -11,15 +11,12 @@ type SnapshotRow struct {
Events pq.Int64Array `db:"events"`
}
// SnapshotTable stores room state snapshots. Each snapshot has a unique numeric ID.
// Not every event will be associated with a snapshot.
type SnapshotTable struct {
db *sqlx.DB
}
func NewSnapshotsTable(postgresURI string) *SnapshotTable {
db, err := sqlx.Open("postgres", postgresURI)
if err != nil {
log.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
}
func NewSnapshotsTable(db *sqlx.DB) *SnapshotTable {
// make sure tables are made
db.MustExec(`
CREATE SEQUENCE IF NOT EXISTS syncv3_snapshots_seq;
@ -29,21 +26,30 @@ func NewSnapshotsTable(postgresURI string) *SnapshotTable {
events BIGINT[] NOT NULL
);
`)
return &SnapshotTable{
db: db,
}
return &SnapshotTable{}
}
// Select a row based on its snapshot ID.
func (s *SnapshotTable) Select(snapshotID int) (row SnapshotRow, err error) {
err = s.db.Get(&row, `SELECT * FROM syncv3_snapshots WHERE snapshot_id = $1`, snapshotID)
func (s *SnapshotTable) Select(txn *sqlx.Tx, snapshotID int) (row SnapshotRow, err error) {
err = txn.Get(&row, `SELECT * FROM syncv3_snapshots WHERE snapshot_id = $1`, snapshotID)
return
}
// Insert the row. Modifies SnapshotID to be the inserted primary key.
func (s *SnapshotTable) Insert(row *SnapshotRow) error {
func (s *SnapshotTable) Insert(txn *sqlx.Tx, row *SnapshotRow) error {
var id int
err := s.db.QueryRow(`INSERT INTO syncv3_snapshots(room_id, events) VALUES($1, $2) RETURNING snapshot_id`, row.RoomID, row.Events).Scan(&id)
err := txn.QueryRow(`INSERT INTO syncv3_snapshots(room_id, events) VALUES($1, $2) RETURNING snapshot_id`, row.RoomID, row.Events).Scan(&id)
row.SnapshotID = id
return err
}
// Delete the snapshot IDs given
func (s *SnapshotTable) Delete(txn *sqlx.Tx, snapshotIDs []int) error {
query, args, err := sqlx.In(`DELETE FROM syncv3_snapshots WHERE snapshot_id IN (?)`, snapshotIDs)
if err != nil {
return err
}
query = txn.Rebind(query)
_, err = txn.Exec(query, args...)
return err
}

View File

@ -4,23 +4,36 @@ import (
"reflect"
"testing"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
)
func TestSnapshotTable(t *testing.T) {
table := NewSnapshotsTable("user=kegan dbname=syncv3 sslmode=disable")
db, err := sqlx.Open("postgres", "user=kegan dbname=syncv3 sslmode=disable")
if err != nil {
t.Fatalf("failed to open SQL db: %s", err)
}
txn, err := db.Beginx()
if err != nil {
t.Fatalf("failed to start txn: %s", err)
}
table := NewSnapshotsTable(db)
// Insert a snapshot
want := &SnapshotRow{
RoomID: "A",
Events: pq.Int64Array{1, 2, 3, 4, 5, 6, 7},
}
err := table.Insert(want)
err = table.Insert(txn, want)
if err != nil {
t.Fatalf("Failed to insert: %s", err)
}
if want.SnapshotID == 0 {
t.Fatalf("Snapshot ID not set")
}
got, err := table.Select(want.SnapshotID)
// Select the same snapshot and assert
got, err := table.Select(txn, want.SnapshotID)
if err != nil {
t.Fatalf("Failed to select: %s", err)
}
@ -33,4 +46,10 @@ func TestSnapshotTable(t *testing.T) {
if !reflect.DeepEqual(got.Events, want.Events) {
t.Errorf("mismatched events, got: %+v want: %+v", got.Events, want.Events)
}
// Delete the snapshot
err = table.Delete(txn, []int{want.SnapshotID})
if err != nil {
t.Fatalf("failed to delete snapshot: %s", err)
}
}

View File

@ -1,12 +0,0 @@
package state
import (
"os"
"github.com/rs/zerolog"
)
var log = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.ConsoleWriter{
Out: os.Stderr,
TimeFormat: "15:04:05",
})