Make tests work for others, add timeline calculations

This commit is contained in:
Kegan Dougal 2021-06-03 14:35:34 +01:00
parent 5909d1a6b0
commit 0dcd3fac09
8 changed files with 254 additions and 25 deletions

View File

@ -11,8 +11,6 @@ import (
"github.com/tidwall/gjson"
)
// TODO: event table needs room_id for querying timeline
var log = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.ConsoleWriter{
Out: os.Stderr,
TimeFormat: "15:04:05",
@ -149,7 +147,8 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) error {
events := make([]Event, len(state))
for i := range events {
events[i] = Event{
JSON: state[i],
JSON: state[i],
RoomID: roomID,
}
}
numNew, err := a.eventsTable.Insert(txn, events)
@ -217,7 +216,8 @@ func (a *Accumulator) Accumulate(roomID string, timeline []json.RawMessage) erro
events := make([]Event, len(timeline))
for i := range events {
events[i] = Event{
JSON: timeline[i],
JSON: timeline[i],
RoomID: roomID,
}
}
numNew, err := a.eventsTable.Insert(txn, events)
@ -286,6 +286,25 @@ func (a *Accumulator) Accumulate(roomID string, timeline []json.RawMessage) erro
})
}
// Delta returns a list of events of at most `limit` for the room not including `lastEventNID`.
// Returns the latest NID of the last event (most recent)
func (a *Accumulator) Delta(roomID string, lastEventNID int64, limit int) (eventsJSON []json.RawMessage, latest int64, err error) {
txn, err := a.db.Beginx()
if err != nil {
return nil, 0, err
}
defer txn.Commit()
events, err := a.eventsTable.SelectEventsBetween(txn, roomID, lastEventNID, EventsEnd, limit)
if err != nil {
return nil, 0, err
}
eventsJSON = make([]json.RawMessage, len(events))
for i := range events {
eventsJSON[i] = events[i].JSON
}
return eventsJSON, int64(events[len(events)-1].NID), nil
}
// WithTransaction runs a block of code passing in an SQL transaction
// If the code returns an error or panics then the transactions is rolled back
// Otherwise the transaction is committed.

View File

@ -15,7 +15,7 @@ func TestAccumulatorInitialise(t *testing.T) {
[]byte(`{"event_id":"C", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
}
roomEventIDs := []string{"A", "B", "C"}
accumulator := NewAccumulator("user=kegan dbname=syncv3 sslmode=disable")
accumulator := NewAccumulator(postgresConnectionString)
err := accumulator.Initialise(roomID, roomEvents)
if err != nil {
t.Fatalf("falied to Initialise accumulator: %s", err)
@ -82,7 +82,7 @@ func TestAccumulatorAccumulate(t *testing.T) {
[]byte(`{"event_id":"E", "type":"m.room.member", "state_key":"@me:localhost", "content":{"membership":"join"}}`),
[]byte(`{"event_id":"F", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
}
accumulator := NewAccumulator("user=kegan dbname=syncv3 sslmode=disable")
accumulator := NewAccumulator(postgresConnectionString)
err := accumulator.Initialise(roomID, roomEvents)
if err != nil {
t.Fatalf("failed to Initialise accumulator: %s", err)
@ -151,3 +151,31 @@ func TestAccumulatorAccumulate(t *testing.T) {
t.Fatalf("failed to Accumulate: %s", err)
}
}
/*
func TestAccumulatorDelta(t *testing.T) {
roomID := "!TestAccumulatorAccumulate:localhost"
roomEvents := []json.RawMessage{
[]byte(`{"event_id":"D", "type":"m.room.create", "state_key":"", "content":{"creator":"@me:localhost"}}`),
[]byte(`{"event_id":"E", "type":"m.room.member", "state_key":"@me:localhost", "content":{"membership":"join"}}`),
[]byte(`{"event_id":"F", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
}
accumulator := NewAccumulator(postgresConnectionString)
err := accumulator.Initialise(roomID, roomEvents)
if err != nil {
t.Fatalf("failed to Initialise accumulator: %s", err)
}
// accumulate new state makes a new snapshot and removes the old snapshot
newEvents := []json.RawMessage{
// non-state event does nothing
[]byte(`{"event_id":"G", "type":"m.room.message","content":{"body":"Hello World","msgtype":"m.text"}}`),
// join_rules should clobber the one from initialise
[]byte(`{"event_id":"H", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
// new state event should be added to the snapshot
[]byte(`{"event_id":"I", "type":"m.room.history_visibility", "state_key":"", "content":{"visibility":"public"}}`),
}
if err = accumulator.Accumulate(roomID, newEvents); err != nil {
t.Fatalf("failed to Accumulate: %s", err)
}
} */

View File

@ -2,15 +2,22 @@ package state
import (
"fmt"
"math"
"github.com/jmoiron/sqlx"
"github.com/tidwall/gjson"
)
const (
EventsStart = -1
EventsEnd = math.MaxInt64 - 1
)
type Event struct {
NID int `db:"event_nid"`
ID string `db:"event_id"`
JSON []byte `db:"event"`
NID int `db:"event_nid"`
ID string `db:"event_id"`
RoomID string `db:"room_id"`
JSON []byte `db:"event"`
}
type StrippedEvent struct {
@ -39,6 +46,7 @@ func NewEventTable(db *sqlx.DB) *EventTable {
CREATE TABLE IF NOT EXISTS syncv3_events (
event_nid BIGINT PRIMARY KEY NOT NULL DEFAULT nextval('syncv3_event_nids_seq'),
event_id TEXT NOT NULL UNIQUE,
room_id TEXT NOT NULL,
event JSONB NOT NULL
);
`)
@ -51,18 +59,24 @@ func (t *EventTable) Insert(txn *sqlx.Tx, events []Event) (int, error) {
// ensure event_id is set
for i := range events {
ev := events[i]
if ev.ID != "" {
continue
if ev.RoomID == "" {
roomIDResult := gjson.GetBytes(ev.JSON, "room_id")
if !roomIDResult.Exists() || roomIDResult.Str == "" {
return 0, fmt.Errorf("event missing room_id key")
}
ev.RoomID = roomIDResult.Str
}
eventIDResult := gjson.GetBytes(ev.JSON, "event_id")
if !eventIDResult.Exists() || eventIDResult.Str == "" {
return 0, fmt.Errorf("event JSON missing event_id key")
if ev.ID == "" {
eventIDResult := gjson.GetBytes(ev.JSON, "event_id")
if !eventIDResult.Exists() || eventIDResult.Str == "" {
return 0, fmt.Errorf("event JSON missing event_id key")
}
ev.ID = eventIDResult.Str
}
ev.ID = eventIDResult.Str
events[i] = ev
}
result, err := txn.NamedExec(`INSERT INTO syncv3_events (event_id, event)
VALUES (:event_id, :event) ON CONFLICT (event_id) DO NOTHING`, events)
result, err := txn.NamedExec(`INSERT INTO syncv3_events (event_id, room_id, event)
VALUES (:event_id, :room_id, :event) ON CONFLICT (event_id) DO NOTHING`, events)
if err != nil {
return 0, err
}
@ -121,3 +135,11 @@ func (t *EventTable) SelectStrippedEventsByIDs(txn *sqlx.Tx, ids []string) (Stri
err = txn.Select(&events, query, args...)
return events, err
}
func (t *EventTable) SelectEventsBetween(txn *sqlx.Tx, roomID string, lowerExclusive, upperInclusive int64, limit int) ([]Event, error) {
var events []Event
err := txn.Select(&events, `SELECT event_nid, event FROM syncv3_events WHERE event_nid > $1 AND event_nid <= $2 AND room_id = $3 LIMIT $4`,
lowerExclusive, upperInclusive, roomID, limit,
)
return events, err
}

View File

@ -7,7 +7,7 @@ import (
)
func TestEventTable(t *testing.T) {
db, err := sqlx.Open("postgres", "user=kegan dbname=syncv3 sslmode=disable")
db, err := sqlx.Open("postgres", postgresConnectionString)
if err != nil {
t.Fatalf("failed to open SQL db: %s", err)
}
@ -20,15 +20,15 @@ func TestEventTable(t *testing.T) {
events := []Event{
{
ID: "100",
JSON: []byte(`{"event_id":"100", "foo":"bar", "type": "T1", "state_key":"S1"}`),
JSON: []byte(`{"event_id":"100", "foo":"bar", "type": "T1", "state_key":"S1", "room_id":"!0:localhost"}`),
},
{
ID: "101",
JSON: []byte(`{"event_id":"101", "foo":"bar", "type": "T2", "state_key":"S2"}`),
JSON: []byte(`{"event_id":"101", "foo":"bar", "type": "T2", "state_key":"S2", "room_id":"!0:localhost"}`),
},
{
// ID is optional, it will pull event_id out if it's missing
JSON: []byte(`{"event_id":"102", "foo":"bar", "type": "T3", "state_key":""}`),
JSON: []byte(`{"event_id":"102", "foo":"bar", "type": "T3", "state_key":"", "room_id":"!0:localhost"}`),
},
}
numNew, err := table.Insert(txn, events)
@ -130,3 +130,122 @@ func TestEventTable(t *testing.T) {
}
verifyStripped(strippedEvents)
}
func TestEventTableSelectEventsBetween(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString)
if err != nil {
t.Fatalf("failed to open SQL db: %s", err)
}
txn, err := db.Beginx()
if err != nil {
t.Fatalf("failed to start txn: %s", err)
}
table := NewEventTable(db)
searchRoomID := "!0TestEventTableSelectEventsBetween:localhost"
eventIDs := []string{
"100TestEventTableSelectEventsBetween",
"101TestEventTableSelectEventsBetween",
"102TestEventTableSelectEventsBetween",
"103TestEventTableSelectEventsBetween",
"104TestEventTableSelectEventsBetween",
}
events := []Event{
{
JSON: []byte(`{"event_id":"` + eventIDs[0] + `","type": "T1", "state_key":"S1", "room_id":"` + searchRoomID + `"}`),
},
{
JSON: []byte(`{"event_id":"` + eventIDs[1] + `","type": "T2", "state_key":"S2", "room_id":"` + searchRoomID + `"}`),
},
{
JSON: []byte(`{"event_id":"` + eventIDs[2] + `","type": "T3", "state_key":"", "room_id":"` + searchRoomID + `"}`),
},
{
// different room
JSON: []byte(`{"event_id":"` + eventIDs[3] + `","type": "T4", "state_key":"", "room_id":"!1TestEventTableSelectEventsBetween:localhost"}`),
},
{
JSON: []byte(`{"event_id":"` + eventIDs[4] + `","type": "T5", "state_key":"", "room_id":"` + searchRoomID + `"}`),
},
}
numNew, err := table.Insert(txn, events)
if err != nil {
t.Fatalf("Insert failed: %s", err)
}
if numNew != len(events) {
t.Fatalf("failed to insert events: got %d want %d", numNew, len(events))
}
txn.Commit()
t.Run("selecting multiple events known lower bound", func(t *testing.T) {
t.Parallel()
txn2, err := db.Beginx()
if err != nil {
t.Fatalf("failed to start txn: %s", err)
}
defer txn2.Rollback()
events, err := table.SelectByIDs(txn2, []string{eventIDs[0]})
if err != nil || len(events) == 0 {
t.Fatalf("failed to extract event for lower bound: %s", err)
}
events, err = table.SelectEventsBetween(txn2, searchRoomID, int64(events[0].NID), EventsEnd, 1000)
if err != nil {
t.Fatalf("failed to SelectEventsBetween: %s", err)
}
// 3 as 1 is from a different room
if len(events) != 3 {
t.Fatalf("wanted 3 events, got %d", len(events))
}
})
t.Run("selecting multiple events known lower and upper bound", func(t *testing.T) {
t.Parallel()
txn3, err := db.Beginx()
if err != nil {
t.Fatalf("failed to start txn: %s", err)
}
defer txn3.Rollback()
events, err := table.SelectByIDs(txn3, []string{eventIDs[0], eventIDs[2]})
if err != nil || len(events) == 0 {
t.Fatalf("failed to extract event for lower/upper bound: %s", err)
}
events, err = table.SelectEventsBetween(txn3, searchRoomID, int64(events[0].NID), int64(events[1].NID), 1000)
if err != nil {
t.Fatalf("failed to SelectEventsBetween: %s", err)
}
// eventIDs[1] and eventIDs[2]
if len(events) != 2 {
t.Fatalf("wanted 2 events, got %d", len(events))
}
})
t.Run("selecting multiple events unknown bounds (all events)", func(t *testing.T) {
t.Parallel()
txn4, err := db.Beginx()
if err != nil {
t.Fatalf("failed to start txn: %s", err)
}
defer txn4.Rollback()
gotEvents, err := table.SelectEventsBetween(txn4, searchRoomID, EventsStart, EventsEnd, 1000)
if err != nil {
t.Fatalf("failed to SelectEventsBetween: %s", err)
}
// one less as one event is for a different room
if len(gotEvents) != (len(events) - 1) {
t.Fatalf("wanted %d events, got %d", len(events)-1, len(gotEvents))
}
})
t.Run("selecting multiple events hitting the limit", func(t *testing.T) {
t.Parallel()
txn5, err := db.Beginx()
if err != nil {
t.Fatalf("failed to start txn: %s", err)
}
defer txn5.Rollback()
limit := 2
gotEvents, err := table.SelectEventsBetween(txn5, searchRoomID, EventsStart, EventsEnd, limit)
if err != nil {
t.Fatalf("failed to SelectEventsBetween: %s", err)
}
if len(gotEvents) != limit {
t.Fatalf("wanted %d events, got %d", limit, len(gotEvents))
}
})
}

40
state/main_test.go Normal file
View File

@ -0,0 +1,40 @@
package state
import (
"fmt"
"os"
"os/exec"
"os/user"
"testing"
)
var postgresConnectionString = "user=xxxxx dbname=syncv3_test sslmode=disable"
func TestMain(m *testing.M) {
fmt.Println("Note: tests require a postgres install accessible to the current user")
dbName := "syncv3_test"
// cleanup if we died mid-way last time, hence don't check err output in case it was already deleted
exec.Command("dropdb", dbName).Run()
user, err := user.Current()
if err != nil {
fmt.Println("cannot get current user: ", err)
os.Exit(2)
}
postgresConnectionString = fmt.Sprintf(
"user=%s dbname=%s sslmode=disable",
user.Username, dbName,
)
if err := exec.Command("createdb", dbName).Run(); err != nil {
fmt.Println("createdb failed: ", err)
os.Exit(2)
}
exitCode := m.Run()
// cleanup
fmt.Println("cleaning up database")
exec.Command("dropdb", dbName).Run()
os.Exit(exitCode)
}

View File

@ -7,7 +7,7 @@ import (
)
func TestRoomsTable(t *testing.T) {
db, err := sqlx.Open("postgres", "user=kegan dbname=syncv3 sslmode=disable")
db, err := sqlx.Open("postgres", postgresConnectionString)
if err != nil {
t.Fatalf("failed to open SQL db: %s", err)
}

View File

@ -8,7 +8,7 @@ import (
)
func TestSnapshotRefCountTable(t *testing.T) {
db, err := sqlx.Open("postgres", "user=kegan dbname=syncv3 sslmode=disable")
db, err := sqlx.Open("postgres", postgresConnectionString)
if err != nil {
t.Fatalf("failed to open SQL db: %s", err)
}
@ -26,7 +26,7 @@ func TestSnapshotRefCountTable(t *testing.T) {
}
func TestSnapshotRefCountTableConcurrent(t *testing.T) {
db, err := sqlx.Open("postgres", "user=kegan dbname=syncv3 sslmode=disable")
db, err := sqlx.Open("postgres", postgresConnectionString)
if err != nil {
t.Fatalf("failed to open SQL db: %s", err)
}

View File

@ -9,7 +9,7 @@ import (
)
func TestSnapshotTable(t *testing.T) {
db, err := sqlx.Open("postgres", "user=kegan dbname=syncv3 sslmode=disable")
db, err := sqlx.Open("postgres", postgresConnectionString)
if err != nil {
t.Fatalf("failed to open SQL db: %s", err)
}
@ -17,6 +17,7 @@ func TestSnapshotTable(t *testing.T) {
if err != nil {
t.Fatalf("failed to start txn: %s", err)
}
defer txn.Rollback()
table := NewSnapshotsTable(db)
// Insert a snapshot