ConnState: add tests and fix bugs

Test basic SYNC,INSERT,DELETE,UPDATE functionality.
This commit is contained in:
Kegan Dougal 2021-09-27 17:47:43 +01:00
parent 6d284ef158
commit 3f9794ef33
7 changed files with 343 additions and 63 deletions

View File

@ -2,10 +2,11 @@
<head>
<title>Sync v3 experiments</title>
<script>
let restart = false;
const doSyncLoop = async(accessToken) => {
let currentPos;
let rooms = [];
while (true) {
while (!restart) {
let resp = await doSyncRequest(accessToken, currentPos, [0,9]);
currentPos = resp.pos;
if (!resp.ops) {
@ -74,6 +75,10 @@
console.log(output);
document.getElementById("list").textContent = output;
}
console.log("restarting");
document.getElementById("list").textContent = "";
restart = false;
doSyncLoop(accessToken);
}
// accessToken = string, pos = int, ranges = [2]int e.g [0,99]
const doSyncRequest = async (accessToken, pos, ranges) => {
@ -96,7 +101,9 @@
document.getElementById("syncButton").onclick = () => {
const accessToken = document.getElementById("accessToken").value;
doSyncLoop(accessToken);
}
document.getElementById("resetButton").onclick = () => {
restart = true;
}
});
</script>
@ -105,6 +112,7 @@
<div>
<input id="accessToken" type="password" placeholder="matrix.org access token" />
<input id="syncButton" type="button" value="Sync" />
<input id="resetButton" type="button" value="Reset" />
<p id="list" style="white-space: pre;"></p>
</div>
</body>

View File

@ -17,7 +17,7 @@ func (c *ConnID) String() string {
return c.SessionID + "-" + c.DeviceID
}
type HandlerIncomingReqFunc func(ctx context.Context, conn *Conn, req *Request) (*Response, error)
type HandlerIncomingReqFunc func(ctx context.Context, cid ConnID, req *Request) (*Response, error)
// Conn is an abstraction of a long-poll connection. It automatically handles the position values
// of the /sync request, including sending cached data in the event of retries. It does not handle
@ -82,7 +82,7 @@ func (c *Conn) OnIncomingRequest(ctx context.Context, req *Request) (resp *Respo
}
c.lastClientRequest = *req
resp, err := c.HandleIncomingRequest(ctx, c, req)
resp, err := c.HandleIncomingRequest(ctx, c.ConnID, req)
if err != nil {
herr, ok := err.(*internal.HandlerError)
if !ok {

View File

@ -18,7 +18,7 @@ func TestConn(t *testing.T) {
SessionID: "s",
}
count := int64(100)
c := NewConn(connID, nil, func(ctx context.Context, conn *Conn, req *Request) (*Response, error) {
c := NewConn(connID, nil, func(ctx context.Context, cid ConnID, req *Request) (*Response, error) {
count += 1
return &Response{
Count: count,
@ -62,7 +62,7 @@ func TestConnBlocking(t *testing.T) {
SessionID: "s",
}
ch := make(chan string)
c := NewConn(connID, nil, func(ctx context.Context, conn *Conn, req *Request) (*Response, error) {
c := NewConn(connID, nil, func(ctx context.Context, cid ConnID, req *Request) (*Response, error) {
if req.Sort[0] == "hi" {
time.Sleep(10 * time.Millisecond)
}
@ -109,7 +109,7 @@ func TestConnRetries(t *testing.T) {
SessionID: "s",
}
callCount := int64(0)
c := NewConn(connID, nil, func(ctx context.Context, conn *Conn, req *Request) (*Response, error) {
c := NewConn(connID, nil, func(ctx context.Context, cid ConnID, req *Request) (*Response, error) {
callCount += 1
return &Response{Count: 20}, nil
})
@ -144,7 +144,7 @@ func TestConnErrors(t *testing.T) {
SessionID: "s",
}
errCh := make(chan error, 1)
c := NewConn(connID, nil, func(ctx context.Context, conn *Conn, req *Request) (*Response, error) {
c := NewConn(connID, nil, func(ctx context.Context, cid ConnID, req *Request) (*Response, error) {
return nil, <-errCh
})
@ -172,7 +172,7 @@ func TestConnErrorsNoCache(t *testing.T) {
SessionID: "s",
}
errCh := make(chan error, 1)
c := NewConn(connID, nil, func(ctx context.Context, conn *Conn, req *Request) (*Response, error) {
c := NewConn(connID, nil, func(ctx context.Context, cid ConnID, req *Request) (*Response, error) {
select {
case e := <-errCh:
return nil, e

View File

@ -17,6 +17,7 @@ type EventData struct {
eventType string
stateKey *string
content gjson.Result
timestamp int64
// the absolute latest position for this event data. The NID for this event is guaranteed to
// be <= this value.
latestPos int64
@ -77,7 +78,7 @@ func (m *ConnMap) GetOrCreateConn(cid ConnID, userID string) (*Conn, bool) {
if conn != nil {
return conn, false
}
state := NewConnState(userID, m.store, m.roomInfo)
state := NewConnState(userID, m)
conn = NewConn(cid, state, state.HandleIncomingRequest)
m.cache.Set(cid.String(), conn)
m.connIDToConn[cid.String()] = conn
@ -122,17 +123,26 @@ func (m *ConnMap) LoadBaseline(roomIDToUserIDs map[string][]string) error {
for _, userID := range userIDs {
m.jrt.UserJoinedRoom(userID, roomID)
}
fmt.Printf("Room: %+v \n", room)
fmt.Printf("Room: %s - %s - %s \n", room.RoomID, room.Name, time.Unix(room.LastMessageTimestamp/1000, 0))
}
return nil
}
func (m *ConnMap) roomInfo(roomID string) *SortableRoom {
func (m *ConnMap) LoadRoom(roomID string) *SortableRoom {
m.mu.Lock()
defer m.mu.Unlock()
return m.globalRoomInfo[roomID]
}
func (m *ConnMap) Load(userID string) (joinedRoomIDs []string, initialLoadPosition int64, err error) {
initialLoadPosition, err = m.store.LatestEventNID()
if err != nil {
return
}
joinedRoomIDs, err = m.store.JoinedRoomsAfterPosition(userID, initialLoadPosition)
return
}
func (m *ConnMap) closeConn(connID string, value interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
@ -210,6 +220,7 @@ func (m *ConnMap) onNewEvent(
stateKey: stateKey,
content: ev.Get("content"),
latestPos: latestPos,
timestamp: ev.Get("origin_server_ts").Int(),
}
// notify all people in this room

View File

@ -3,12 +3,9 @@ package synclive
import (
"context"
"encoding/json"
"fmt"
"reflect"
"sort"
"time"
"github.com/matrix-org/sync-v3/state"
)
var (
@ -18,28 +15,31 @@ var (
MaxPendingEventUpdates = 100
)
type ConnStateStore interface {
LoadRoom(roomID string) *SortableRoom
Load(userID string) (joinedRoomIDs []string, initialLoadPosition int64, err error)
}
// ConnState tracks all high-level connection state for this connection, like the combined request
// and the underlying sorted room list. It doesn't track session IDs or positions of the connection.
type ConnState struct {
store *state.Storage
store ConnStateStore
muxedReq *Request
userID string
sortedJoinedRooms SortableRooms
sortedJoinedRoomsPositions map[string]int // room_id -> index in sortedJoinedRooms
roomSubscriptions map[string]*Room // TODO
initialLoadPosition int64
loadRoom func(roomID string) *SortableRoom
// A channel which v2 poll loops use to send updates to, via the ConnMap.
// Consumed when the conn is read. There is a limit to how many updates we will store before
// saying the client is ded and cleaning up the conn.
updateEvents chan *EventData
}
func NewConnState(userID string, store *state.Storage, loadRoom func(roomID string) *SortableRoom) *ConnState {
func NewConnState(userID string, store ConnStateStore) *ConnState {
return &ConnState{
store: store,
userID: userID,
loadRoom: loadRoom,
roomSubscriptions: make(map[string]*Room),
sortedJoinedRoomsPositions: make(map[string]int),
updateEvents: make(chan *EventData, MaxPendingEventUpdates), // TODO: customisable
@ -58,27 +58,20 @@ func NewConnState(userID string, store *state.Storage, loadRoom func(roomID stri
// - load() bases its current state based on the latest position, which includes processing of these N events.
// - post load() we read N events, processing them a 2nd time.
func (c *ConnState) load(req *Request) error {
// load from store
var err error
c.initialLoadPosition, err = c.store.LatestEventNID()
if err != nil {
return err
}
joinedRoomIDs, err := c.store.JoinedRoomsAfterPosition(c.userID, c.initialLoadPosition)
joinedRoomIDs, initialLoadPosition, err := c.store.Load(c.userID)
if err != nil {
return err
}
c.initialLoadPosition = initialLoadPosition
c.sortedJoinedRooms = make([]SortableRoom, len(joinedRoomIDs))
for i, roomID := range joinedRoomIDs {
// load global room info
sr := c.loadRoom(roomID)
c.sortedJoinedRooms[i] = SortableRoom{
RoomID: sr.RoomID,
Name: sr.Name,
}
sr := c.store.LoadRoom(roomID)
c.sortedJoinedRooms[i] = *sr
c.sortedJoinedRoomsPositions[sr.RoomID] = i
}
c.sort(req.Sort)
return nil
}
@ -93,7 +86,7 @@ func (c *ConnState) sort(sortBy []string) {
//logger.Info().Interface("pos", c.sortedJoinedRoomsPositions).Msg("sorted")
}
func (c *ConnState) HandleIncomingRequest(ctx context.Context, conn *Conn, req *Request) (*Response, error) {
func (c *ConnState) HandleIncomingRequest(ctx context.Context, cid ConnID, req *Request) (*Response, error) {
if c.initialLoadPosition == 0 {
c.load(req)
}
@ -140,7 +133,6 @@ func (s *ConnState) onIncomingRequest(ctx context.Context, req *Request) (*Respo
// TODO: update room subscriptions
// TODO: calculate the M values for N < M calcs
fmt.Println("range", s.muxedReq.Rooms, "prev_range", prevRange, "sort", prevSort)
var responseOperations []ResponseOp
@ -152,12 +144,14 @@ func (s *ConnState) onIncomingRequest(ctx context.Context, req *Request) (*Respo
}
if !reflect.DeepEqual(prevSort, s.muxedReq.Sort) {
// the sort operations have changed, invalidate everything, re-sort and re-SYNC
for _, r := range s.muxedReq.Rooms {
responseOperations = append(responseOperations, &ResponseOpRange{
Operation: "INVALIDATE",
Range: r[:],
})
// the sort operations have changed, invalidate everything (if there were previous syncs), re-sort and re-SYNC
if prevSort != nil {
for _, r := range s.muxedReq.Rooms {
responseOperations = append(responseOperations, &ResponseOpRange{
Operation: "INVALIDATE",
Range: r[:],
})
}
}
s.sort(s.muxedReq.Sort)
added = s.muxedReq.Rooms
@ -207,35 +201,28 @@ func (s *ConnState) onIncomingRequest(ctx context.Context, req *Request) (*Respo
// TODO: Implement sorting by something other than recency. With recency sorting,
// most operations are DELETE/INSERT to bump rooms to the top of the list. We only
// do an UPDATE if the most recent room gets a 2nd event.
var targetRoom SortableRoom
fromIndex, ok := s.sortedJoinedRoomsPositions[updateEvent.roomID]
if !ok {
// the user may have just joined the room hence not have an entry in this list yet.
fromIndex = -1
fromIndex = len(s.sortedJoinedRooms)
newRoom := s.store.LoadRoom(updateEvent.roomID)
newRoom.LastMessageTimestamp = updateEvent.timestamp
s.sortedJoinedRooms = append(s.sortedJoinedRooms, *newRoom)
targetRoom = *newRoom
} else {
targetRoom = s.sortedJoinedRooms[fromIndex]
targetRoom.LastMessageTimestamp = updateEvent.timestamp
s.sortedJoinedRooms[fromIndex] = targetRoom
}
toIndex := 0 // TODO: this won't always be 0 if we sort by something other than recency
logger.Info().Int("from", fromIndex).Int("to", toIndex).Str("room", updateEvent.roomID).Msg(
logger.Info().Int("from", fromIndex).Interface("room", targetRoom).Msg(
"moving room",
)
// move the server's representation
swap := s.sortedJoinedRooms[toIndex]
var room *SortableRoom
if fromIndex == -1 {
logger.Info().Str("room", updateEvent.roomID).Msg("loading brand new room into sorted list")
room = s.loadRoom(updateEvent.roomID)
// TODO: work out which index position this should be sorted into, depending on the sort operations
// for now we always insert it into toIndex+1
s.sortedJoinedRooms = append([]SortableRoom{
s.sortedJoinedRooms[0], *room,
}, s.sortedJoinedRooms[1:]...)
fromIndex = 1
} else {
room = &s.sortedJoinedRooms[fromIndex]
}
s.sortedJoinedRooms[toIndex] = *room
s.sortedJoinedRooms[fromIndex] = swap
s.sortedJoinedRoomsPositions[room.RoomID] = toIndex
s.sortedJoinedRoomsPositions[swap.RoomID] = fromIndex
// re-sort
s.sort(nil)
toIndex := s.sortedJoinedRoomsPositions[updateEvent.roomID]
logger.Info().Int("to", toIndex).Msg("moved!")
responseOperations = append(
responseOperations, s.moveRoom(updateEvent, fromIndex, toIndex, s.muxedReq.Rooms)...,
@ -255,6 +242,7 @@ func (s *ConnState) UserID() string {
return s.userID
}
// Move a room from an absolute index position to another absolute position.
// 1,2,3,4,5
// 3 bumps to top -> 3,1,2,4,5 -> DELETE index=2, INSERT val=3 index=0
// 7 bumps to top -> 7,1,2,3,4 -> DELETE index=4, INSERT val=7 index=0
@ -285,7 +273,7 @@ func (s *ConnState) moveRoom(updateEvent *EventData, fromIndex, toIndex int, ran
// to the highest end-range marker < index
deleteIndex = int(ranges.LowerClamp(int64(fromIndex)))
}
room := s.loadRoom(updateEvent.roomID)
room := s.store.LoadRoom(updateEvent.roomID)
return []ResponseOp{
&ResponseOpSingle{
Operation: "DELETE",

255
synclive/connstate_test.go Normal file
View File

@ -0,0 +1,255 @@
package synclive
import (
"bytes"
"context"
"encoding/json"
"reflect"
"testing"
)
type connStateStoreMock struct {
roomIDToRoom map[string]*SortableRoom
userIDToJoinedRooms map[string][]string
userIDToPosition map[string]int64
}
func (s *connStateStoreMock) LoadRoom(roomID string) *SortableRoom {
return s.roomIDToRoom[roomID]
}
func (s *connStateStoreMock) Load(userID string) (joinedRoomIDs []string, initialLoadPosition int64, err error) {
joinedRoomIDs = s.userIDToJoinedRooms[userID]
initialLoadPosition = s.userIDToPosition[userID]
if initialLoadPosition == 0 {
initialLoadPosition = 1 // so we don't continually load the same rooms
}
return
}
// Sync an account with 3 rooms and check that we can grab all rooms and they are sorted correctly initially. Checks
// that basic UPDATE and DELETE/INSERT works when tracking all rooms.
func TestConnStateInitial(t *testing.T) {
connID := ConnID{
SessionID: "s",
DeviceID: "d",
}
userID := "@alice:localhost"
roomA := "!a:localhost"
roomB := "!b:localhost"
roomC := "!c:localhost"
timestampNow := int64(1632131678061)
// initial sort order B, C, A
cs := NewConnState(userID, &connStateStoreMock{
userIDToJoinedRooms: map[string][]string{
userID: {roomA, roomB, roomC},
},
roomIDToRoom: map[string]*SortableRoom{
roomA: {
RoomID: roomA,
Name: "Room A",
LastMessageTimestamp: timestampNow - 8000,
},
roomB: {
RoomID: roomB,
Name: "Room B",
LastMessageTimestamp: timestampNow,
},
roomC: {
RoomID: roomC,
Name: "Room C",
LastMessageTimestamp: timestampNow - 4000,
},
},
})
if userID != cs.UserID() {
t.Fatalf("UserID returned wrong value, got %v want %v", cs.UserID(), userID)
}
res, err := cs.HandleIncomingRequest(context.Background(), connID, &Request{
Sort: []string{SortByRecency},
Rooms: SliceRanges([][2]int64{
{0, 9},
}),
})
if err != nil {
t.Fatalf("HandleIncomingRequest returned error : %s", err)
}
checkResponse(t, false, res, &Response{
Count: 3,
Ops: []ResponseOp{
&ResponseOpRange{
Operation: "SYNC",
Range: []int64{0, 9},
Rooms: []Room{
{
RoomID: roomB,
Name: "Room B",
},
{
RoomID: roomC,
Name: "Room C",
},
{
RoomID: roomA,
Name: "Room A",
},
},
},
},
})
// bump A to the top
cs.PushNewEvent(&EventData{
event: json.RawMessage(`{}`),
roomID: roomA,
eventType: "unimportant",
timestamp: timestampNow + 1000,
})
// request again for the diff
res, err = cs.HandleIncomingRequest(context.Background(), connID, &Request{
Sort: []string{SortByRecency},
Rooms: SliceRanges([][2]int64{
{0, 9},
}),
})
if err != nil {
t.Fatalf("HandleIncomingRequest returned error : %s", err)
}
checkResponse(t, true, res, &Response{
Count: 3,
Ops: []ResponseOp{
&ResponseOpSingle{
Operation: "DELETE",
Index: intPtr(2),
},
&ResponseOpSingle{
Operation: "INSERT",
Index: intPtr(0),
Room: &Room{
RoomID: roomA,
},
},
},
})
// another message should just update
cs.PushNewEvent(&EventData{
event: json.RawMessage(`{}`),
roomID: roomA,
eventType: "still unimportant",
timestamp: timestampNow + 2000,
})
res, err = cs.HandleIncomingRequest(context.Background(), connID, &Request{
Sort: []string{SortByRecency},
Rooms: SliceRanges([][2]int64{
{0, 9},
}),
})
if err != nil {
t.Fatalf("HandleIncomingRequest returned error : %s", err)
}
checkResponse(t, true, res, &Response{
Count: 3,
Ops: []ResponseOp{
&ResponseOpSingle{
Operation: "UPDATE",
Index: intPtr(0),
Room: &Room{
RoomID: roomA,
},
},
},
})
}
func checkResponse(t *testing.T, checkRoomIDsOnly bool, got, want *Response) {
t.Helper()
if want.Count > 0 {
if got.Count != want.Count {
t.Errorf("response Count: got %d want %d", got.Count, want.Count)
}
}
if len(want.Ops) > 0 {
t.Logf("got %v", serialise(t, got))
t.Logf("want %v", serialise(t, want))
defer func() {
t.Helper()
if !t.Failed() {
t.Logf("OK!")
}
}()
if len(got.Ops) != len(want.Ops) {
t.Fatalf("got %d ops, want %d", len(got.Ops), len(want.Ops))
}
for i, wantOpVal := range want.Ops {
gotOp := got.Ops[i]
if gotOp.Op() != wantOpVal.Op() {
t.Errorf("operation i=%d got '%s' want '%s'", i, gotOp.Op(), wantOpVal.Op())
}
switch wantOp := wantOpVal.(type) {
case *ResponseOpRange:
gotOpRange, ok := gotOp.(*ResponseOpRange)
if !ok {
t.Fatalf("operation i=%d (%s) want type ResponseOpRange but it isn't", i, gotOp.Op())
}
if !reflect.DeepEqual(gotOpRange.Range, wantOp.Range) {
t.Errorf("operation i=%d (%s) got range %v want range %v", i, gotOp.Op(), gotOpRange.Range, wantOp.Range)
}
if len(gotOpRange.Rooms) != len(wantOp.Rooms) {
t.Fatalf("operation i=%d (%s) got %d rooms in array, want %d", i, gotOp.Op(), len(gotOpRange.Rooms), len(wantOp.Rooms))
}
for j := range wantOp.Rooms {
checkRoomsEqual(t, checkRoomIDsOnly, &gotOpRange.Rooms[j], &wantOp.Rooms[j])
}
case *ResponseOpSingle:
gotOpSingle, ok := gotOp.(*ResponseOpSingle)
if !ok {
t.Fatalf("operation i=%d (%s) want type ResponseOpSingle but it isn't", i, gotOp.Op())
}
if *gotOpSingle.Index != *wantOp.Index {
t.Errorf("operation i=%d (%s) single op on index %d want index %d", i, gotOp.Op(), *gotOpSingle.Index, *wantOp.Index)
}
checkRoomsEqual(t, checkRoomIDsOnly, gotOpSingle.Room, wantOp.Room)
}
}
}
}
func checkRoomsEqual(t *testing.T, checkRoomIDsOnly bool, got, want *Room) {
t.Helper()
if got == nil && want == nil {
return // e.g DELETE ops
}
if (got == nil && want != nil) || (want == nil && got != nil) {
t.Fatalf("nil room, got %+v want %+v", got, want)
}
if checkRoomIDsOnly {
if got.RoomID != want.RoomID {
t.Fatalf("got room '%s' want room '%s'", got.RoomID, want.RoomID)
}
return
}
gotBytes, err := json.Marshal(got)
if err != nil {
t.Fatalf("cannot marshal got room: %s", err)
}
wantBytes, err := json.Marshal(want)
if err != nil {
t.Fatalf("cannot marshal want room: %s", err)
}
if !bytes.Equal(gotBytes, wantBytes) {
t.Errorf("rooms do not match,\ngot %s want %s", string(gotBytes), string(wantBytes))
}
}
func serialise(t *testing.T, thing interface{}) string {
b, err := json.Marshal(thing)
if err != nil {
t.Fatalf("cannot serialise: %s", err)
}
return string(b)
}
func intPtr(val int) *int {
return &val
}

18
synclive/main_test.go Normal file
View File

@ -0,0 +1,18 @@
package synclive
import (
"os"
"testing"
"github.com/rs/zerolog"
)
func TestMain(m *testing.M) {
logger = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.ConsoleWriter{
Out: os.Stderr,
TimeFormat: "15:04:05",
NoColor: true,
})
exitCode := m.Run()
os.Exit(exitCode)
}