Implement Accumulate

This commit is contained in:
Kegan Dougal 2021-05-28 16:07:28 +01:00
parent 91dd5609f7
commit 5909d1a6b0
2 changed files with 120 additions and 7 deletions

View File

@ -84,8 +84,41 @@ func (a *Accumulator) strippedEventsForSnapshot(txn *sqlx.Tx, snapID int) (Strip
return a.eventsTable.SelectStrippedEventsByNIDs(txn, snapshot.Events)
}
// calculateNewSnapshot works out the new snapshot by combining an old and new snapshot. Events get replaced
// if the tuple of event type/state_key match. A new slice is returning (the inputs are not modified)
func (a *Accumulator) calculateNewSnapshot(old StrippedEvents, new StrippedEvents) (StrippedEvents, error) {
return nil, nil
// TODO: implement dendrite's binary tree diff algorithm
tupleKey := func(e StrippedEvent) string {
// 0x1f = unit separator
return e.Type + "\x1f" + e.StateKey
}
tupleToNew := make(map[string]StrippedEvent)
for _, e := range new {
tupleToNew[tupleKey(e)] = e
}
var result StrippedEvents
for _, e := range old {
newEvent := tupleToNew[tupleKey(e)]
if newEvent.NID > 0 {
result = append(result, StrippedEvent{
NID: newEvent.NID,
Type: e.Type,
StateKey: e.StateKey,
})
delete(tupleToNew, tupleKey(e))
} else {
result = append(result, StrippedEvent{
NID: e.NID,
Type: e.Type,
StateKey: e.StateKey,
})
}
}
// add genuinely new state events from new
for _, newEvent := range tupleToNew {
result = append(result, newEvent)
}
return result, nil
}
// Initialise starts a new sync accumulator for the given room using the given state as a baseline.
@ -168,12 +201,13 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) error {
//
// This function does several things:
// - It ensures all events are persisted in the database. This is shared amongst users.
// - If the last event in the timeline has been stored before, then it short circuits and returns.
// This is because we must have already processed this in order for the event to exist in the database,
// and the sync stream is already linearised for us.
// - Else it creates a new room state snapshot (as this now represents the current state)
// - If all events have been stored before, then it short circuits and returns.
// This is because we must have already processed this part of the timeline in order for the event
// to exist in the database, and the sync stream is already linearised for us.
// - Else it creates a new room state snapshot if the timeline contains state events (as this now represents the current state)
// - It checks if there are outstanding references for the previous snapshot, and if not, removes the old snapshot from the database.
// References are made when clients have synced up to a given snapshot (hence may paginate at that point).
// The server itself also holds a ref to the current state, which is then moved to the new current state.
func (a *Accumulator) Accumulate(roomID string, timeline []json.RawMessage) error {
if len(timeline) == 0 {
return nil

View File

@ -3,10 +3,12 @@ package state
import (
"encoding/json"
"testing"
"github.com/tidwall/gjson"
)
func TestAccumulator(t *testing.T) {
roomID := "!1:localhost"
func TestAccumulatorInitialise(t *testing.T) {
roomID := "!TestAccumulatorInitialise:localhost"
roomEvents := []json.RawMessage{
[]byte(`{"event_id":"A", "type":"m.room.create", "state_key":"", "content":{"creator":"@me:localhost"}}`),
[]byte(`{"event_id":"B", "type":"m.room.member", "state_key":"@me:localhost", "content":{"membership":"join"}}`),
@ -72,3 +74,80 @@ func TestAccumulator(t *testing.T) {
t.Fatalf("falied to Initialise accumulator: %s", err)
}
}
func TestAccumulatorAccumulate(t *testing.T) {
roomID := "!TestAccumulatorAccumulate:localhost"
roomEvents := []json.RawMessage{
[]byte(`{"event_id":"D", "type":"m.room.create", "state_key":"", "content":{"creator":"@me:localhost"}}`),
[]byte(`{"event_id":"E", "type":"m.room.member", "state_key":"@me:localhost", "content":{"membership":"join"}}`),
[]byte(`{"event_id":"F", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
}
accumulator := NewAccumulator("user=kegan dbname=syncv3 sslmode=disable")
err := accumulator.Initialise(roomID, roomEvents)
if err != nil {
t.Fatalf("failed to Initialise accumulator: %s", err)
}
// accumulate new state makes a new snapshot and removes the old snapshot
newEvents := []json.RawMessage{
// non-state event does nothing
[]byte(`{"event_id":"G", "type":"m.room.message","content":{"body":"Hello World","msgtype":"m.text"}}`),
// join_rules should clobber the one from initialise
[]byte(`{"event_id":"H", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
// new state event should be added to the snapshot
[]byte(`{"event_id":"I", "type":"m.room.history_visibility", "state_key":"", "content":{"visibility":"public"}}`),
}
if err = accumulator.Accumulate(roomID, newEvents); err != nil {
t.Fatalf("failed to Accumulate: %s", err)
}
// Begin assertions
wantStateEvents := []json.RawMessage{
roomEvents[0], // create event
roomEvents[1], // member event
newEvents[1], // join rules
newEvents[2], // history visibility
}
txn, err := accumulator.db.Beginx()
if err != nil {
t.Fatalf("failed to start assert txn: %s", err)
}
defer txn.Rollback()
// There should be one snapshot in current state
snapID, err := accumulator.roomsTable.CurrentSnapshotID(txn, roomID)
if err != nil {
t.Fatalf("failed to select current snapshot: %s", err)
}
if snapID == 0 {
t.Fatalf("Initialise or Accumulate did not store a current snapshot")
}
// The snapshot should have 4 events
row, err := accumulator.snapshotTable.Select(txn, snapID)
if err != nil {
t.Fatalf("failed to select snapshot %d: %s", snapID, err)
}
if len(row.Events) != len(wantStateEvents) {
t.Fatalf("snapshot: %d got %d events, want %d in current state snapshot", snapID, len(row.Events), len(wantStateEvents))
}
// these 4 events should map to the create/member events from initialise, then the join_rules/history_visibility from accumulate
events, err := accumulator.eventsTable.SelectByNIDs(txn, row.Events)
if err != nil {
t.Fatalf("failed to extract events in snapshot: %s", err)
}
if len(events) != len(wantStateEvents) {
t.Fatalf("failed to extract %d events, got %d", len(wantStateEvents), len(events))
}
for i := range events {
if events[i].ID != gjson.GetBytes(wantStateEvents[i], "event_id").Str {
t.Errorf("event %d was not stored correctly: got ID %s want %s", i, events[i].ID, gjson.GetBytes(wantStateEvents[i], "event_id").Str)
}
}
// subsequent calls do nothing and are not an error
if err = accumulator.Accumulate(roomID, newEvents); err != nil {
t.Fatalf("failed to Accumulate: %s", err)
}
}