mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
Make tests work for others, add timeline calculations
This commit is contained in:
parent
5909d1a6b0
commit
0dcd3fac09
@ -11,8 +11,6 @@ import (
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// TODO: event table needs room_id for querying timeline
|
||||
|
||||
var log = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.ConsoleWriter{
|
||||
Out: os.Stderr,
|
||||
TimeFormat: "15:04:05",
|
||||
@ -149,7 +147,8 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) error {
|
||||
events := make([]Event, len(state))
|
||||
for i := range events {
|
||||
events[i] = Event{
|
||||
JSON: state[i],
|
||||
JSON: state[i],
|
||||
RoomID: roomID,
|
||||
}
|
||||
}
|
||||
numNew, err := a.eventsTable.Insert(txn, events)
|
||||
@ -217,7 +216,8 @@ func (a *Accumulator) Accumulate(roomID string, timeline []json.RawMessage) erro
|
||||
events := make([]Event, len(timeline))
|
||||
for i := range events {
|
||||
events[i] = Event{
|
||||
JSON: timeline[i],
|
||||
JSON: timeline[i],
|
||||
RoomID: roomID,
|
||||
}
|
||||
}
|
||||
numNew, err := a.eventsTable.Insert(txn, events)
|
||||
@ -286,6 +286,25 @@ func (a *Accumulator) Accumulate(roomID string, timeline []json.RawMessage) erro
|
||||
})
|
||||
}
|
||||
|
||||
// Delta returns a list of events of at most `limit` for the room not including `lastEventNID`.
|
||||
// Returns the latest NID of the last event (most recent)
|
||||
func (a *Accumulator) Delta(roomID string, lastEventNID int64, limit int) (eventsJSON []json.RawMessage, latest int64, err error) {
|
||||
txn, err := a.db.Beginx()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer txn.Commit()
|
||||
events, err := a.eventsTable.SelectEventsBetween(txn, roomID, lastEventNID, EventsEnd, limit)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
eventsJSON = make([]json.RawMessage, len(events))
|
||||
for i := range events {
|
||||
eventsJSON[i] = events[i].JSON
|
||||
}
|
||||
return eventsJSON, int64(events[len(events)-1].NID), nil
|
||||
}
|
||||
|
||||
// WithTransaction runs a block of code passing in an SQL transaction
|
||||
// If the code returns an error or panics then the transactions is rolled back
|
||||
// Otherwise the transaction is committed.
|
||||
|
@ -15,7 +15,7 @@ func TestAccumulatorInitialise(t *testing.T) {
|
||||
[]byte(`{"event_id":"C", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
|
||||
}
|
||||
roomEventIDs := []string{"A", "B", "C"}
|
||||
accumulator := NewAccumulator("user=kegan dbname=syncv3 sslmode=disable")
|
||||
accumulator := NewAccumulator(postgresConnectionString)
|
||||
err := accumulator.Initialise(roomID, roomEvents)
|
||||
if err != nil {
|
||||
t.Fatalf("falied to Initialise accumulator: %s", err)
|
||||
@ -82,7 +82,7 @@ func TestAccumulatorAccumulate(t *testing.T) {
|
||||
[]byte(`{"event_id":"E", "type":"m.room.member", "state_key":"@me:localhost", "content":{"membership":"join"}}`),
|
||||
[]byte(`{"event_id":"F", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
|
||||
}
|
||||
accumulator := NewAccumulator("user=kegan dbname=syncv3 sslmode=disable")
|
||||
accumulator := NewAccumulator(postgresConnectionString)
|
||||
err := accumulator.Initialise(roomID, roomEvents)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to Initialise accumulator: %s", err)
|
||||
@ -151,3 +151,31 @@ func TestAccumulatorAccumulate(t *testing.T) {
|
||||
t.Fatalf("failed to Accumulate: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
func TestAccumulatorDelta(t *testing.T) {
|
||||
roomID := "!TestAccumulatorAccumulate:localhost"
|
||||
roomEvents := []json.RawMessage{
|
||||
[]byte(`{"event_id":"D", "type":"m.room.create", "state_key":"", "content":{"creator":"@me:localhost"}}`),
|
||||
[]byte(`{"event_id":"E", "type":"m.room.member", "state_key":"@me:localhost", "content":{"membership":"join"}}`),
|
||||
[]byte(`{"event_id":"F", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
|
||||
}
|
||||
accumulator := NewAccumulator(postgresConnectionString)
|
||||
err := accumulator.Initialise(roomID, roomEvents)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to Initialise accumulator: %s", err)
|
||||
}
|
||||
|
||||
// accumulate new state makes a new snapshot and removes the old snapshot
|
||||
newEvents := []json.RawMessage{
|
||||
// non-state event does nothing
|
||||
[]byte(`{"event_id":"G", "type":"m.room.message","content":{"body":"Hello World","msgtype":"m.text"}}`),
|
||||
// join_rules should clobber the one from initialise
|
||||
[]byte(`{"event_id":"H", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
|
||||
// new state event should be added to the snapshot
|
||||
[]byte(`{"event_id":"I", "type":"m.room.history_visibility", "state_key":"", "content":{"visibility":"public"}}`),
|
||||
}
|
||||
if err = accumulator.Accumulate(roomID, newEvents); err != nil {
|
||||
t.Fatalf("failed to Accumulate: %s", err)
|
||||
}
|
||||
} */
|
||||
|
@ -2,15 +2,22 @@ package state
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
EventsStart = -1
|
||||
EventsEnd = math.MaxInt64 - 1
|
||||
)
|
||||
|
||||
type Event struct {
|
||||
NID int `db:"event_nid"`
|
||||
ID string `db:"event_id"`
|
||||
JSON []byte `db:"event"`
|
||||
NID int `db:"event_nid"`
|
||||
ID string `db:"event_id"`
|
||||
RoomID string `db:"room_id"`
|
||||
JSON []byte `db:"event"`
|
||||
}
|
||||
|
||||
type StrippedEvent struct {
|
||||
@ -39,6 +46,7 @@ func NewEventTable(db *sqlx.DB) *EventTable {
|
||||
CREATE TABLE IF NOT EXISTS syncv3_events (
|
||||
event_nid BIGINT PRIMARY KEY NOT NULL DEFAULT nextval('syncv3_event_nids_seq'),
|
||||
event_id TEXT NOT NULL UNIQUE,
|
||||
room_id TEXT NOT NULL,
|
||||
event JSONB NOT NULL
|
||||
);
|
||||
`)
|
||||
@ -51,18 +59,24 @@ func (t *EventTable) Insert(txn *sqlx.Tx, events []Event) (int, error) {
|
||||
// ensure event_id is set
|
||||
for i := range events {
|
||||
ev := events[i]
|
||||
if ev.ID != "" {
|
||||
continue
|
||||
if ev.RoomID == "" {
|
||||
roomIDResult := gjson.GetBytes(ev.JSON, "room_id")
|
||||
if !roomIDResult.Exists() || roomIDResult.Str == "" {
|
||||
return 0, fmt.Errorf("event missing room_id key")
|
||||
}
|
||||
ev.RoomID = roomIDResult.Str
|
||||
}
|
||||
eventIDResult := gjson.GetBytes(ev.JSON, "event_id")
|
||||
if !eventIDResult.Exists() || eventIDResult.Str == "" {
|
||||
return 0, fmt.Errorf("event JSON missing event_id key")
|
||||
if ev.ID == "" {
|
||||
eventIDResult := gjson.GetBytes(ev.JSON, "event_id")
|
||||
if !eventIDResult.Exists() || eventIDResult.Str == "" {
|
||||
return 0, fmt.Errorf("event JSON missing event_id key")
|
||||
}
|
||||
ev.ID = eventIDResult.Str
|
||||
}
|
||||
ev.ID = eventIDResult.Str
|
||||
events[i] = ev
|
||||
}
|
||||
result, err := txn.NamedExec(`INSERT INTO syncv3_events (event_id, event)
|
||||
VALUES (:event_id, :event) ON CONFLICT (event_id) DO NOTHING`, events)
|
||||
result, err := txn.NamedExec(`INSERT INTO syncv3_events (event_id, room_id, event)
|
||||
VALUES (:event_id, :room_id, :event) ON CONFLICT (event_id) DO NOTHING`, events)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -121,3 +135,11 @@ func (t *EventTable) SelectStrippedEventsByIDs(txn *sqlx.Tx, ids []string) (Stri
|
||||
err = txn.Select(&events, query, args...)
|
||||
return events, err
|
||||
}
|
||||
|
||||
func (t *EventTable) SelectEventsBetween(txn *sqlx.Tx, roomID string, lowerExclusive, upperInclusive int64, limit int) ([]Event, error) {
|
||||
var events []Event
|
||||
err := txn.Select(&events, `SELECT event_nid, event FROM syncv3_events WHERE event_nid > $1 AND event_nid <= $2 AND room_id = $3 LIMIT $4`,
|
||||
lowerExclusive, upperInclusive, roomID, limit,
|
||||
)
|
||||
return events, err
|
||||
}
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
func TestEventTable(t *testing.T) {
|
||||
db, err := sqlx.Open("postgres", "user=kegan dbname=syncv3 sslmode=disable")
|
||||
db, err := sqlx.Open("postgres", postgresConnectionString)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open SQL db: %s", err)
|
||||
}
|
||||
@ -20,15 +20,15 @@ func TestEventTable(t *testing.T) {
|
||||
events := []Event{
|
||||
{
|
||||
ID: "100",
|
||||
JSON: []byte(`{"event_id":"100", "foo":"bar", "type": "T1", "state_key":"S1"}`),
|
||||
JSON: []byte(`{"event_id":"100", "foo":"bar", "type": "T1", "state_key":"S1", "room_id":"!0:localhost"}`),
|
||||
},
|
||||
{
|
||||
ID: "101",
|
||||
JSON: []byte(`{"event_id":"101", "foo":"bar", "type": "T2", "state_key":"S2"}`),
|
||||
JSON: []byte(`{"event_id":"101", "foo":"bar", "type": "T2", "state_key":"S2", "room_id":"!0:localhost"}`),
|
||||
},
|
||||
{
|
||||
// ID is optional, it will pull event_id out if it's missing
|
||||
JSON: []byte(`{"event_id":"102", "foo":"bar", "type": "T3", "state_key":""}`),
|
||||
JSON: []byte(`{"event_id":"102", "foo":"bar", "type": "T3", "state_key":"", "room_id":"!0:localhost"}`),
|
||||
},
|
||||
}
|
||||
numNew, err := table.Insert(txn, events)
|
||||
@ -130,3 +130,122 @@ func TestEventTable(t *testing.T) {
|
||||
}
|
||||
verifyStripped(strippedEvents)
|
||||
}
|
||||
|
||||
func TestEventTableSelectEventsBetween(t *testing.T) {
|
||||
db, err := sqlx.Open("postgres", postgresConnectionString)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open SQL db: %s", err)
|
||||
}
|
||||
txn, err := db.Beginx()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start txn: %s", err)
|
||||
}
|
||||
table := NewEventTable(db)
|
||||
searchRoomID := "!0TestEventTableSelectEventsBetween:localhost"
|
||||
eventIDs := []string{
|
||||
"100TestEventTableSelectEventsBetween",
|
||||
"101TestEventTableSelectEventsBetween",
|
||||
"102TestEventTableSelectEventsBetween",
|
||||
"103TestEventTableSelectEventsBetween",
|
||||
"104TestEventTableSelectEventsBetween",
|
||||
}
|
||||
events := []Event{
|
||||
{
|
||||
JSON: []byte(`{"event_id":"` + eventIDs[0] + `","type": "T1", "state_key":"S1", "room_id":"` + searchRoomID + `"}`),
|
||||
},
|
||||
{
|
||||
JSON: []byte(`{"event_id":"` + eventIDs[1] + `","type": "T2", "state_key":"S2", "room_id":"` + searchRoomID + `"}`),
|
||||
},
|
||||
{
|
||||
JSON: []byte(`{"event_id":"` + eventIDs[2] + `","type": "T3", "state_key":"", "room_id":"` + searchRoomID + `"}`),
|
||||
},
|
||||
{
|
||||
// different room
|
||||
JSON: []byte(`{"event_id":"` + eventIDs[3] + `","type": "T4", "state_key":"", "room_id":"!1TestEventTableSelectEventsBetween:localhost"}`),
|
||||
},
|
||||
{
|
||||
JSON: []byte(`{"event_id":"` + eventIDs[4] + `","type": "T5", "state_key":"", "room_id":"` + searchRoomID + `"}`),
|
||||
},
|
||||
}
|
||||
numNew, err := table.Insert(txn, events)
|
||||
if err != nil {
|
||||
t.Fatalf("Insert failed: %s", err)
|
||||
}
|
||||
if numNew != len(events) {
|
||||
t.Fatalf("failed to insert events: got %d want %d", numNew, len(events))
|
||||
}
|
||||
txn.Commit()
|
||||
|
||||
t.Run("selecting multiple events known lower bound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
txn2, err := db.Beginx()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start txn: %s", err)
|
||||
}
|
||||
defer txn2.Rollback()
|
||||
events, err := table.SelectByIDs(txn2, []string{eventIDs[0]})
|
||||
if err != nil || len(events) == 0 {
|
||||
t.Fatalf("failed to extract event for lower bound: %s", err)
|
||||
}
|
||||
events, err = table.SelectEventsBetween(txn2, searchRoomID, int64(events[0].NID), EventsEnd, 1000)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to SelectEventsBetween: %s", err)
|
||||
}
|
||||
// 3 as 1 is from a different room
|
||||
if len(events) != 3 {
|
||||
t.Fatalf("wanted 3 events, got %d", len(events))
|
||||
}
|
||||
})
|
||||
t.Run("selecting multiple events known lower and upper bound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
txn3, err := db.Beginx()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start txn: %s", err)
|
||||
}
|
||||
defer txn3.Rollback()
|
||||
events, err := table.SelectByIDs(txn3, []string{eventIDs[0], eventIDs[2]})
|
||||
if err != nil || len(events) == 0 {
|
||||
t.Fatalf("failed to extract event for lower/upper bound: %s", err)
|
||||
}
|
||||
events, err = table.SelectEventsBetween(txn3, searchRoomID, int64(events[0].NID), int64(events[1].NID), 1000)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to SelectEventsBetween: %s", err)
|
||||
}
|
||||
// eventIDs[1] and eventIDs[2]
|
||||
if len(events) != 2 {
|
||||
t.Fatalf("wanted 2 events, got %d", len(events))
|
||||
}
|
||||
})
|
||||
t.Run("selecting multiple events unknown bounds (all events)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
txn4, err := db.Beginx()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start txn: %s", err)
|
||||
}
|
||||
defer txn4.Rollback()
|
||||
gotEvents, err := table.SelectEventsBetween(txn4, searchRoomID, EventsStart, EventsEnd, 1000)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to SelectEventsBetween: %s", err)
|
||||
}
|
||||
// one less as one event is for a different room
|
||||
if len(gotEvents) != (len(events) - 1) {
|
||||
t.Fatalf("wanted %d events, got %d", len(events)-1, len(gotEvents))
|
||||
}
|
||||
})
|
||||
t.Run("selecting multiple events hitting the limit", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
txn5, err := db.Beginx()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start txn: %s", err)
|
||||
}
|
||||
defer txn5.Rollback()
|
||||
limit := 2
|
||||
gotEvents, err := table.SelectEventsBetween(txn5, searchRoomID, EventsStart, EventsEnd, limit)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to SelectEventsBetween: %s", err)
|
||||
}
|
||||
if len(gotEvents) != limit {
|
||||
t.Fatalf("wanted %d events, got %d", limit, len(gotEvents))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
40
state/main_test.go
Normal file
40
state/main_test.go
Normal file
@ -0,0 +1,40 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var postgresConnectionString = "user=xxxxx dbname=syncv3_test sslmode=disable"
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
fmt.Println("Note: tests require a postgres install accessible to the current user")
|
||||
dbName := "syncv3_test"
|
||||
// cleanup if we died mid-way last time, hence don't check err output in case it was already deleted
|
||||
exec.Command("dropdb", dbName).Run()
|
||||
|
||||
user, err := user.Current()
|
||||
if err != nil {
|
||||
fmt.Println("cannot get current user: ", err)
|
||||
os.Exit(2)
|
||||
}
|
||||
postgresConnectionString = fmt.Sprintf(
|
||||
"user=%s dbname=%s sslmode=disable",
|
||||
user.Username, dbName,
|
||||
)
|
||||
|
||||
if err := exec.Command("createdb", dbName).Run(); err != nil {
|
||||
fmt.Println("createdb failed: ", err)
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
exitCode := m.Run()
|
||||
|
||||
// cleanup
|
||||
fmt.Println("cleaning up database")
|
||||
exec.Command("dropdb", dbName).Run()
|
||||
os.Exit(exitCode)
|
||||
}
|
@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
func TestRoomsTable(t *testing.T) {
|
||||
db, err := sqlx.Open("postgres", "user=kegan dbname=syncv3 sslmode=disable")
|
||||
db, err := sqlx.Open("postgres", postgresConnectionString)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open SQL db: %s", err)
|
||||
}
|
||||
|
@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
func TestSnapshotRefCountTable(t *testing.T) {
|
||||
db, err := sqlx.Open("postgres", "user=kegan dbname=syncv3 sslmode=disable")
|
||||
db, err := sqlx.Open("postgres", postgresConnectionString)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open SQL db: %s", err)
|
||||
}
|
||||
@ -26,7 +26,7 @@ func TestSnapshotRefCountTable(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSnapshotRefCountTableConcurrent(t *testing.T) {
|
||||
db, err := sqlx.Open("postgres", "user=kegan dbname=syncv3 sslmode=disable")
|
||||
db, err := sqlx.Open("postgres", postgresConnectionString)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open SQL db: %s", err)
|
||||
}
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
func TestSnapshotTable(t *testing.T) {
|
||||
db, err := sqlx.Open("postgres", "user=kegan dbname=syncv3 sslmode=disable")
|
||||
db, err := sqlx.Open("postgres", postgresConnectionString)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open SQL db: %s", err)
|
||||
}
|
||||
@ -17,6 +17,7 @@ func TestSnapshotTable(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start txn: %s", err)
|
||||
}
|
||||
defer txn.Rollback()
|
||||
table := NewSnapshotsTable(db)
|
||||
|
||||
// Insert a snapshot
|
||||
|
Loading…
x
Reference in New Issue
Block a user