Add room name filtering; Remove session IDs entirely

Should fix #19
This commit is contained in:
Kegan Dougal 2022-02-18 16:49:26 +00:00
parent 129f6fa61c
commit b208a2e2b3
12 changed files with 111 additions and 117 deletions

View File

@ -22,7 +22,6 @@ func TestExtensionE2EE(t *testing.T) {
defer v3.close()
alice := "@TestExtensionE2EE_alice:localhost"
aliceToken := "ALICE_BEARER_TOKEN_TestExtensionE2EE"
sessionID := "sid"
// check that OTK counts go through
otkCounts := map[string]int{
@ -45,7 +44,6 @@ func TestExtensionE2EE(t *testing.T) {
Enabled: true,
},
},
SessionID: sessionID,
})
MatchResponse(t, res, MatchOTKCounts(otkCounts))
@ -72,7 +70,6 @@ func TestExtensionE2EE(t *testing.T) {
Enabled: true,
},
},
SessionID: sessionID,
})
MatchResponse(t, res, MatchOTKCounts(otkCounts))
@ -97,7 +94,6 @@ func TestExtensionE2EE(t *testing.T) {
Enabled: true,
},
},
SessionID: sessionID,
})
MatchResponse(t, res, MatchOTKCounts(otkCounts))
@ -127,7 +123,6 @@ func TestExtensionE2EE(t *testing.T) {
Enabled: true,
},
},
SessionID: sessionID,
})
MatchResponse(t, res, MatchDeviceLists(wantChanged, wantLeft))
@ -144,7 +139,6 @@ func TestExtensionE2EE(t *testing.T) {
Enabled: true,
},
},
SessionID: sessionID,
})
MatchResponse(t, res, MatchDeviceLists(wantChanged, wantLeft))
@ -171,7 +165,6 @@ func TestExtensionE2EE(t *testing.T) {
Enabled: true,
},
},
SessionID: sessionID,
})
MatchResponse(t, res, func(res *sync3.Response) error {
if res.Extensions.E2EE.DeviceLists != nil {
@ -191,7 +184,6 @@ func TestExtensionToDevice(t *testing.T) {
defer v3.close()
alice := "@TestExtensionToDevice_alice:localhost"
aliceToken := "ALICE_BEARER_TOKEN_TestExtensionToDevice"
sessionID := "sid"
v2.addAccount(alice, aliceToken)
toDeviceMsgs := []json.RawMessage{
json.RawMessage(`{"sender":"alice","type":"something","content":{"foo":"1"}}`),
@ -217,7 +209,6 @@ func TestExtensionToDevice(t *testing.T) {
Enabled: true,
},
},
SessionID: sessionID,
})
MatchResponse(t, res, MatchV3Count(0), MatchToDeviceMessages(toDeviceMsgs))
@ -233,7 +224,6 @@ func TestExtensionToDevice(t *testing.T) {
Enabled: true,
},
},
SessionID: sessionID,
})
MatchResponse(t, res, MatchV3Count(0), MatchToDeviceMessages(toDeviceMsgs))
@ -250,7 +240,6 @@ func TestExtensionToDevice(t *testing.T) {
Since: res.Extensions.ToDevice.NextBatch,
},
},
SessionID: sessionID,
})
MatchResponse(t, res, MatchV3Count(0), MatchToDeviceMessages([]json.RawMessage{}))
@ -276,7 +265,6 @@ func TestExtensionToDevice(t *testing.T) {
Since: res.Extensions.ToDevice.NextBatch,
},
},
SessionID: sessionID,
})
MatchResponse(t, res, MatchV3Count(0), MatchToDeviceMessages(newToDeviceMsgs))
@ -293,7 +281,6 @@ func TestExtensionToDevice(t *testing.T) {
Since: res.Extensions.ToDevice.NextBatch,
},
},
SessionID: sessionID,
})
MatchResponse(t, res, MatchV3Count(0), MatchToDeviceMessages([]json.RawMessage{}))

View File

@ -53,45 +53,44 @@ func TestFilters(t *testing.T) {
})
// connect and make sure either the encrypted room or not depending on what the filter says
encryptedSessionID := "encrypted_session"
encryptedRes := v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: []sync3.RequestList{{
Ranges: sync3.SliceRanges{
[2]int64{0, int64(len(allRooms) - 1)}, // all rooms
res := v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: []sync3.RequestList{
{
Ranges: sync3.SliceRanges{
[2]int64{0, int64(len(allRooms) - 1)}, // all rooms
},
Filters: &sync3.RequestFilters{
IsEncrypted: &boolTrue,
},
},
Filters: &sync3.RequestFilters{
IsEncrypted: &boolTrue,
{
Ranges: sync3.SliceRanges{
[2]int64{0, int64(len(allRooms) - 1)}, // all rooms
},
Filters: &sync3.RequestFilters{
IsEncrypted: &boolFalse,
},
},
}},
SessionID: encryptedSessionID,
},
})
MatchResponse(t, encryptedRes, MatchV3Count(1), MatchV3Ops(
MatchResponse(t, res, MatchV3Counts([]int{1, 1}), MatchV3Ops(
MatchV3SyncOp(func(op *sync3.ResponseOpRange) error {
if len(op.Rooms) != 1 {
return fmt.Errorf("want %d rooms, got %d", 1, len(op.Rooms))
}
if op.List != 0 {
return fmt.Errorf("unknown list: %d", op.List)
}
return allRooms[0].MatchRoom(op.Rooms[0]) // encrypted room
}),
))
unencryptedSessionID := "unencrypted_session"
unencryptedRes := v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: []sync3.RequestList{{
Ranges: sync3.SliceRanges{
[2]int64{0, int64(len(allRooms) - 1)}, // all rooms
},
Filters: &sync3.RequestFilters{
IsEncrypted: &boolFalse,
},
}},
SessionID: unencryptedSessionID,
})
MatchResponse(t, unencryptedRes, MatchV3Count(1), MatchV3Ops(
MatchV3SyncOp(func(op *sync3.ResponseOpRange) error {
if len(op.Rooms) != 1 {
return fmt.Errorf("want %d rooms, got %d", 1, len(op.Rooms))
}
return allRooms[1].MatchRoom(op.Rooms[0]) // unencrypted room
if op.List != 1 {
return fmt.Errorf("unknown list: %d", op.List)
}
return allRooms[1].MatchRoom(op.Rooms[0]) // encrypted room
}),
))
@ -118,22 +117,30 @@ func TestFilters(t *testing.T) {
v2.waitUntilEmpty(t, alice)
// now requesting the encrypted list should include it (added)
encryptedRes = v3.mustDoV3RequestWithPos(t, aliceToken, encryptedRes.Pos, sync3.Request{
Lists: []sync3.RequestList{{
Ranges: sync3.SliceRanges{
[2]int64{0, int64(len(allRooms) - 1)}, // all rooms
res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{
Lists: []sync3.RequestList{
{
Ranges: sync3.SliceRanges{
[2]int64{0, int64(len(allRooms) - 1)}, // all rooms
},
// sticky; should remember filters
},
// sticky; should remember filters
}},
SessionID: encryptedSessionID,
{
Ranges: sync3.SliceRanges{
[2]int64{0, int64(len(allRooms) - 1)}, // all rooms
},
// sticky; should remember filters
},
},
})
MatchResponse(t, encryptedRes, MatchV3Count(len(allRooms)), MatchV3Ops(
MatchResponse(t, res, MatchV3Counts([]int{len(allRooms), 0}), MatchV3Ops(
MatchV3DeleteOp(1, 0),
MatchV3DeleteOp(0, 1),
MatchV3InsertOp(0, 0, unencryptedRoomID),
))
// requesting the encrypted list from scratch returns 2 rooms now
encryptedRes = v3.mustDoV3Request(t, aliceToken, sync3.Request{
res = v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: []sync3.RequestList{{
Ranges: sync3.SliceRanges{
[2]int64{0, int64(len(allRooms) - 1)}, // all rooms
@ -142,9 +149,8 @@ func TestFilters(t *testing.T) {
IsEncrypted: &boolTrue,
},
}},
SessionID: "new_encrypted_session",
})
MatchResponse(t, encryptedRes, MatchV3Count(2), MatchV3Ops(
MatchResponse(t, res, MatchV3Count(2), MatchV3Ops(
MatchV3SyncOp(func(op *sync3.ResponseOpRange) error {
if len(op.Rooms) != len(allRooms) {
return fmt.Errorf("want %d rooms, got %d", len(allRooms), len(op.Rooms))
@ -162,22 +168,8 @@ func TestFilters(t *testing.T) {
}),
))
// requesting the unencrypted stream DELETEs the room without a corresponding INSERT
unencryptedRes = v3.mustDoV3RequestWithPos(t, aliceToken, unencryptedRes.Pos, sync3.Request{
Lists: []sync3.RequestList{{
Ranges: sync3.SliceRanges{
[2]int64{0, int64(len(allRooms) - 1)}, // all rooms
},
// sticky; should remember filters
}},
SessionID: unencryptedSessionID,
})
MatchResponse(t, unencryptedRes, MatchV3Count(0), MatchV3Ops(
MatchV3DeleteOp(0, 0),
))
// requesting the unencrypted stream from scratch returns 0 rooms
unencryptedRes = v3.mustDoV3Request(t, aliceToken, sync3.Request{
res = v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: []sync3.RequestList{{
Ranges: sync3.SliceRanges{
[2]int64{0, int64(len(allRooms) - 1)}, // all rooms
@ -186,7 +178,6 @@ func TestFilters(t *testing.T) {
IsEncrypted: &boolFalse,
},
}},
SessionID: "new_unencrypted_session",
})
MatchResponse(t, unencryptedRes, MatchV3Count(0))
MatchResponse(t, res, MatchV3Count(0))
}

View File

@ -63,7 +63,6 @@ func TestNotificationsOnTop(t *testing.T) {
// prefer highlight count first, THEN eventually recency
Sort: []string{sync3.SortByHighlightCount, sync3.SortByNotificationCount, sync3.SortByRecency},
}},
SessionID: t.Name(),
}
res := v3.mustDoV3Request(t, aliceToken, syncRequestBody)
MatchResponse(t, res, MatchV3Count(len(allRooms)), MatchV3Ops(
@ -140,7 +139,6 @@ func TestNotificationsOnTop(t *testing.T) {
// prefer highlight count first, THEN eventually recency
Sort: []string{sync3.SortByHighlightCount, sync3.SortByNotificationCount, sync3.SortByRecency},
}},
SessionID: t.Name(),
})
MatchResponse(t, res, MatchV3Count(len(allRooms)), MatchV3Ops(
MatchV3SyncOp(func(op *sync3.ResponseOpRange) error {

View File

@ -78,7 +78,6 @@ func TestRoomNames(t *testing.T) {
},
TimelineLimit: int64(100),
}},
SessionID: sessionID,
})
MatchResponse(t, res, MatchV3Count(len(allRooms)), MatchV3Ops(
MatchV3SyncOp(func(op *sync3.ResponseOpRange) error {
@ -103,4 +102,36 @@ func TestRoomNames(t *testing.T) {
// restart the server and repeat the tests, should still be the same when reading from the database
v3.restart(t, v2, pqString)
checkRoomNames("b")
// now check that we can filter the rooms by name
checkRoomNameFilter := func(searchTerm string, wantRooms []roomEvents) {
t.Helper()
// do a sync, make sure room names are sensible
res := v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: []sync3.RequestList{{
Ranges: sync3.SliceRanges{
[2]int64{0, int64(len(allRooms) - 1)}, // all rooms
},
Filters: &sync3.RequestFilters{
RoomNameFilter: searchTerm,
},
}},
})
matchers := make([][]roomMatcher, len(wantRooms))
for i := range wantRooms {
matchers[i] = []roomMatcher{
MatchRoomName(wantRooms[i].name),
MatchRoomID(wantRooms[i].roomID),
}
}
MatchResponse(t, res, MatchV3Count(len(wantRooms)), MatchV3Ops(
MatchV3SyncOpWithMatchers(MatchRoomRange(matchers...)),
))
}
// case-insensitive matching
checkRoomNameFilter("my room name", []roomEvents{allRooms[1]})
// partial matching
checkRoomNameFilter("room na", []roomEvents{allRooms[1]})
// multiple matches
checkRoomNameFilter("bob", []roomEvents{allRooms[0], allRooms[3]})
}

View File

@ -10,12 +10,11 @@ import (
)
type ConnID struct {
SessionID string
DeviceID string
DeviceID string
}
func (c *ConnID) String() string {
return c.SessionID + "-" + c.DeviceID
return c.DeviceID
}
type ConnHandler interface {

View File

@ -29,8 +29,7 @@ func (c *connHandlerMock) Alive() bool { return true }
func TestConn(t *testing.T) {
ctx := context.Background()
connID := ConnID{
DeviceID: "d",
SessionID: "s",
DeviceID: "d",
}
count := 100
c := NewConn(connID, &connHandlerMock{func(ctx context.Context, cid ConnID, req *Request) (*Response, error) {
@ -74,8 +73,7 @@ func TestConn(t *testing.T) {
func TestConnBlocking(t *testing.T) {
ctx := context.Background()
connID := ConnID{
DeviceID: "d",
SessionID: "s",
DeviceID: "d",
}
ch := make(chan string)
c := NewConn(connID, &connHandlerMock{func(ctx context.Context, cid ConnID, req *Request) (*Response, error) {
@ -129,8 +127,7 @@ func TestConnBlocking(t *testing.T) {
func TestConnRetries(t *testing.T) {
ctx := context.Background()
connID := ConnID{
DeviceID: "d",
SessionID: "s",
DeviceID: "d",
}
callCount := int64(0)
c := NewConn(connID, &connHandlerMock{func(ctx context.Context, cid ConnID, req *Request) (*Response, error) {
@ -169,8 +166,7 @@ func TestConnRetries(t *testing.T) {
func TestConnErrors(t *testing.T) {
ctx := context.Background()
connID := ConnID{
DeviceID: "d",
SessionID: "s",
DeviceID: "d",
}
errCh := make(chan error, 1)
c := NewConn(connID, &connHandlerMock{func(ctx context.Context, cid ConnID, req *Request) (*Response, error) {
@ -197,8 +193,7 @@ func TestConnErrors(t *testing.T) {
func TestConnErrorsNoCache(t *testing.T) {
ctx := context.Background()
connID := ConnID{
DeviceID: "d",
SessionID: "s",
DeviceID: "d",
}
errCh := make(chan error, 1)
c := NewConn(connID, &connHandlerMock{func(ctx context.Context, cid ConnID, req *Request) (*Response, error) {

View File

@ -1,6 +1,7 @@
package sync3
import (
"fmt"
"sync"
"time"
@ -41,19 +42,21 @@ func (m *ConnMap) Conn(cid ConnID) *Conn {
}
// Atomically gets or creates a connection with this connection ID. Calls newConn if a new connection is required.
func (m *ConnMap) GetOrCreateConn(cid ConnID, newConnHandler func() ConnHandler) (*Conn, bool) {
// atomically check if a conn exists already and return that if so
func (m *ConnMap) CreateConn(cid ConnID, newConnHandler func() ConnHandler) (*Conn, bool) {
// atomically check if a conn exists already and nuke it if it exists
m.mu.Lock()
defer m.mu.Unlock()
conn := m.Conn(cid)
if conn != nil {
return conn, false
// tear down this connection and fallthrough
m.closeConn(conn)
}
h := newConnHandler()
conn = NewConn(cid, h)
m.cache.Set(cid.String(), conn)
m.connIDToConn[cid.String()] = conn
m.userIDToConn[h.UserID()] = append(m.userIDToConn[h.UserID()], conn)
fmt.Println("created connection", cid.String())
return conn, true
}
@ -78,6 +81,7 @@ func (m *ConnMap) closeConn(conn *Conn) {
}
connID := conn.ConnID.String()
fmt.Println("tearing down connection", connID)
// remove conn from all the maps
delete(m.connIDToConn, connID)
h := conn.handler

View File

@ -52,8 +52,7 @@ func mockLazyRoomOverride(loadPos int64, roomIDs []string, maxTimelineEvents int
// that basic UPDATE and DELETE/INSERT works when tracking all rooms.
func TestConnStateInitial(t *testing.T) {
ConnID := sync3.ConnID{
SessionID: "s",
DeviceID: "d",
DeviceID: "d",
}
userID := "@TestConnStateInitial_alice:localhost"
deviceID := "yep"
@ -207,8 +206,7 @@ func TestConnStateInitial(t *testing.T) {
func TestConnStateMultipleRanges(t *testing.T) {
t.Skip("flakey")
ConnID := sync3.ConnID{
SessionID: "s",
DeviceID: "d",
DeviceID: "d",
}
userID := "@TestConnStateMultipleRanges_alice:localhost"
deviceID := "yep"
@ -393,8 +391,7 @@ func TestConnStateMultipleRanges(t *testing.T) {
// Regression test for https://github.com/matrix-org/sync-v3/commit/732ea46f1ccde2b6a382e0f849bbd166b80900ed
func TestBumpToOutsideRange(t *testing.T) {
ConnID := sync3.ConnID{
SessionID: "s",
DeviceID: "d",
DeviceID: "d",
}
userID := "@TestBumpToOutsideRange_alice:localhost"
deviceID := "yep"
@ -485,8 +482,7 @@ func TestBumpToOutsideRange(t *testing.T) {
// Test that room subscriptions can be made and that events are pushed for them.
func TestConnStateRoomSubscriptions(t *testing.T) {
ConnID := sync3.ConnID{
SessionID: "s",
DeviceID: "d",
DeviceID: "d",
}
userID := "@TestConnStateRoomSubscriptions_alice:localhost"
deviceID := "yep"

View File

@ -115,10 +115,7 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
}
}
}
if requestBody.SessionID == "" {
requestBody.SessionID = DefaultSessionID
}
fmt.Println("incoming sync pos=", req.URL.Query().Get("pos"))
conn, err := h.setupConnection(req, &requestBody, req.URL.Query().Get("pos") != "")
if err != nil {
hlog.FromRequest(req).Err(err).Msg("failed to get or create Conn")
@ -181,10 +178,8 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
// client thinks they have a connection
if containsPos {
// Lookup the connection
// we need to map based on both as the session ID isn't crypto secure but the device ID is (Auth header)
conn = h.ConnMap.Conn(sync3.ConnID{
SessionID: syncReq.SessionID,
DeviceID: deviceID,
DeviceID: deviceID,
})
if err != nil {
log.Warn().Err(err).Msg("failed to lookup conn for request")
@ -194,6 +189,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
}
}
if conn != nil {
fmt.Println("returning existing conn", conn.ConnID.String())
return conn, nil
}
// conn doesn't exist, we probably nuked it.
@ -245,10 +241,9 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
// NB: this isn't inherently racey (we did the check for an existing conn before EnsurePolling)
// because we *either* do the existing check *or* make a new conn. It's important for CreateConn
// to check for an existing connection though, as it's possible for the client to call /sync
// twice for a new connection and get the same session ID.
conn, created := h.ConnMap.GetOrCreateConn(sync3.ConnID{
SessionID: syncReq.SessionID,
DeviceID: deviceID,
// twice for a new connection.
conn, created := h.ConnMap.CreateConn(sync3.ConnID{
DeviceID: deviceID,
}, func() sync3.ConnHandler {
return NewConnState(v2device.UserID, v2device.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher)
})

View File

@ -3,7 +3,9 @@ package sync3
import (
"bytes"
"encoding/json"
"strings"
"github.com/matrix-org/sync-v3/internal"
"github.com/matrix-org/sync-v3/sync3/extensions"
)
@ -27,7 +29,6 @@ type Request struct {
// set via query params or inferred
pos int64
timeoutSecs int
SessionID string `json:"session_id"`
}
type RequestList struct {
@ -65,12 +66,7 @@ func (r *Request) Same(other *Request) bool {
func (r *Request) ApplyDelta(nextReq *Request) (result *Request, subs, unsubs []string) {
// Use the newer values unless they aren't specified, then use the older ones.
// Go is ew in that this can't be represented in a nicer way
sessionID := nextReq.SessionID
if sessionID == "" {
sessionID = r.SessionID
}
result = &Request{
SessionID: sessionID,
Extensions: nextReq.Extensions, // TODO: make them sticky
}
lists := make([]RequestList, len(nextReq.Lists))
@ -174,10 +170,11 @@ func (r *Request) GetRequiredState(listIndex int, roomID string) [][2]string {
}
type RequestFilters struct {
Spaces []string `json:"spaces"`
IsDM *bool `json:"is_dm"`
IsEncrypted *bool `json:"is_encrypted"`
IsInvite *bool `json:"is_invite"`
Spaces []string `json:"spaces"`
IsDM *bool `json:"is_dm"`
IsEncrypted *bool `json:"is_encrypted"`
IsInvite *bool `json:"is_invite"`
RoomNameFilter string `json:"room_name_like"`
// TODO options to control which events should be live-streamed e.g not_types, types from sync v2
}
@ -191,6 +188,9 @@ func (rf *RequestFilters) Include(r *RoomConnMetadata) bool {
if rf.IsInvite != nil && *rf.IsInvite != r.IsInvite {
return false
}
if rf.RoomNameFilter != "" && !strings.Contains(strings.ToLower(internal.CalculateRoomName(&r.RoomMetadata, 5)), strings.ToLower(rf.RoomNameFilter)) {
return false
}
return true
}

View File

@ -15,7 +15,6 @@ func TestRequestApplyDeltas(t *testing.T) {
}{
{
input: Request{
SessionID: "a",
Lists: []RequestList{
{
Sort: []string{SortByName},

View File

@ -394,7 +394,6 @@ func testTimelineLoadInitialEvents(v3 *testV3Server, token string, count int, wa
},
TimelineLimit: int64(numTimelineEventsPerRoom),
}},
SessionID: t.Name(),
})
MatchResponse(t, res, MatchV3Count(count), MatchV3Ops(