extensions refactor: handle processing extensions in the same way

This allows us to automatically trace and automatically process only
enabled extensions. Live update code will be modified to use the same
code paths.
This commit is contained in:
Kegan Dougal 2023-02-08 12:58:52 +00:00
parent b7defa2723
commit 66a010f249
11 changed files with 150 additions and 127 deletions

View File

@ -1,6 +1,7 @@
package extensions
import (
"context"
"encoding/json"
"github.com/matrix-org/sliding-sync/state"
@ -68,35 +69,35 @@ func ProcessLiveAccountData(up caches.Update, store *state.Storage, updateWillRe
return nil
}
func ProcessAccountData(store *state.Storage, roomIDToTimeline map[string][]string, userID string, isInitial bool, req *AccountDataRequest) (res *AccountDataResponse) {
roomIDs := make([]string, len(roomIDToTimeline))
func (r *AccountDataRequest) Process(ctx context.Context, res *Response, extCtx Context) {
roomIDs := make([]string, len(extCtx.RoomIDToTimeline))
i := 0
for roomID := range roomIDToTimeline {
for roomID := range extCtx.RoomIDToTimeline {
roomIDs[i] = roomID
i++
}
res = &AccountDataResponse{}
extRes := &AccountDataResponse{}
// room account data needs to be sent every time the user scrolls the list to get new room IDs
// TODO: remember which rooms the client has been told about
if len(roomIDs) > 0 {
roomsAccountData, err := store.AccountDatas(userID, roomIDs...)
roomsAccountData, err := extCtx.Store.AccountDatas(extCtx.UserID, roomIDs...)
if err != nil {
logger.Err(err).Str("user", userID).Strs("rooms", roomIDs).Msg("failed to fetch room account data")
logger.Err(err).Str("user", extCtx.UserID).Strs("rooms", roomIDs).Msg("failed to fetch room account data")
} else {
res.Rooms = make(map[string][]json.RawMessage)
extRes.Rooms = make(map[string][]json.RawMessage)
for _, ad := range roomsAccountData {
res.Rooms[ad.RoomID] = append(res.Rooms[ad.RoomID], ad.Data)
extRes.Rooms[ad.RoomID] = append(extRes.Rooms[ad.RoomID], ad.Data)
}
}
}
// global account data is only sent on the first connection, then we live stream
if isInitial {
globalAccountData, err := store.AccountDatas(userID)
if extCtx.IsInitial {
globalAccountData, err := extCtx.Store.AccountDatas(extCtx.UserID)
if err != nil {
logger.Err(err).Str("user", userID).Msg("failed to fetch global account data")
logger.Err(err).Str("user", extCtx.UserID).Msg("failed to fetch global account data")
} else {
res.Global = accountEventsAsJSON(globalAccountData)
extRes.Global = accountEventsAsJSON(globalAccountData)
}
}
return
res.AccountData = extRes
}

View File

@ -1,8 +1,9 @@
package extensions
import (
"context"
"github.com/matrix-org/sliding-sync/internal"
"github.com/matrix-org/sliding-sync/sync3/caches"
)
// Fetcher used by the E2EE extension
@ -38,33 +39,32 @@ func (r *E2EEResponse) HasData(isInitial bool) bool {
return r.DeviceLists != nil || len(r.FallbackKeyTypes) > 0 || len(r.OTKCounts) > 0
}
func ProcessLiveE2EE(up caches.Update, fetcher E2EEFetcher, userID, deviceID string, req *E2EERequest) (res *E2EEResponse) {
_, ok := up.(caches.DeviceDataUpdate)
if !ok {
return nil
}
return ProcessE2EE(fetcher, userID, deviceID, req, false)
}
func ProcessE2EE(fetcher E2EEFetcher, userID, deviceID string, req *E2EERequest, isInitial bool) (res *E2EEResponse) {
func (r *E2EERequest) Process(ctx context.Context, res *Response, extCtx Context) {
// pull OTK counts and changed/left from device data
dd := fetcher.DeviceData(userID, deviceID, isInitial)
res = &E2EEResponse{}
dd := extCtx.E2EEFetcher.DeviceData(extCtx.UserID, extCtx.DeviceID, extCtx.IsInitial)
if dd == nil {
return res // unknown device?
return // unknown device?
}
if dd.FallbackKeyTypes != nil && (dd.FallbackKeysChanged() || isInitial) {
res.FallbackKeyTypes = dd.FallbackKeyTypes
extRes := &E2EEResponse{}
hasUpdates := false
if dd.FallbackKeyTypes != nil && (dd.FallbackKeysChanged() || extCtx.IsInitial) {
extRes.FallbackKeyTypes = dd.FallbackKeyTypes
hasUpdates = true
}
if dd.OTKCounts != nil && (dd.OTKCountChanged() || isInitial) {
res.OTKCounts = dd.OTKCounts
if dd.OTKCounts != nil && (dd.OTKCountChanged() || extCtx.IsInitial) {
extRes.OTKCounts = dd.OTKCounts
hasUpdates = true
}
changed, left := internal.DeviceListChangesArrays(dd.DeviceLists.Sent)
if len(changed) > 0 || len(left) > 0 {
res.DeviceLists = &E2EEDeviceList{
extRes.DeviceLists = &E2EEDeviceList{
Changed: changed,
Left: left,
}
hasUpdates = true
}
return
if !hasUpdates {
return
}
res.E2EE = extRes // TODO: aggregate
}

View File

@ -98,8 +98,16 @@ func (e Response) HasData(isInitial bool) bool {
(e.Receipts != nil && e.Receipts.HasData(isInitial))
}
type Context struct {
*Handler
RoomIDToTimeline map[string][]string
IsInitial bool
UserID string
DeviceID string
}
type HandlerInterface interface {
Handle(ctx context.Context, req Request, roomIDToTimeline map[string][]string, isInitial bool) (res Response)
Handle(ctx context.Context, req Request, extCtx Context) (res Response)
HandleLiveUpdate(update caches.Update, req Request, res *Response, updateWillReturnResponse, isInitial bool)
}
@ -131,7 +139,15 @@ func (h *Handler) HandleLiveUpdate(update caches.Update, req Request, res *Respo
}
}
if req.ToDevice != nil && req.ToDevice.Enabled != nil && *req.ToDevice.Enabled {
res.ToDevice = ProcessLiveToDeviceEvents(update, h.Store, req.UserID, req.DeviceID, req.ToDevice)
_, ok := update.(caches.DeviceEventsUpdate)
if ok {
req.ToDevice.Process(context.Background(), res, Context{
Handler: h,
IsInitial: false,
UserID: req.UserID,
DeviceID: req.DeviceID,
})
}
}
// only process 'live' e2ee when we aren't going to return data as we need to ensure that we don't calculate this twice
// e.g once on incoming request then again due to wakeup
@ -139,34 +155,24 @@ func (h *Handler) HandleLiveUpdate(update caches.Update, req Request, res *Respo
if res.E2EE != nil && res.E2EE.HasData(false) {
return
}
res.E2EE = ProcessLiveE2EE(update, h.E2EEFetcher, req.UserID, req.DeviceID, req.E2EE)
_, ok := update.(caches.DeviceDataUpdate)
if ok {
req.E2EE.Process(context.Background(), res, Context{
Handler: h,
IsInitial: false,
UserID: req.UserID,
DeviceID: req.DeviceID,
})
}
}
}
func (h *Handler) Handle(ctx context.Context, req Request, roomIDToTimeline map[string][]string, isInitial bool) (res Response) {
if req.ToDevice != nil && req.ToDevice.Enabled != nil && *req.ToDevice.Enabled {
region := trace.StartRegion(ctx, "extension_to_device")
res.ToDevice = ProcessToDevice(h.Store, req.UserID, req.DeviceID, req.ToDevice, isInitial)
region.End()
}
if req.E2EE != nil && req.E2EE.Enabled != nil && *req.E2EE.Enabled {
region := trace.StartRegion(ctx, "extension_e2ee")
res.E2EE = ProcessE2EE(h.E2EEFetcher, req.UserID, req.DeviceID, req.E2EE, isInitial)
region.End()
}
if req.AccountData != nil && req.AccountData.Enabled != nil && *req.AccountData.Enabled {
region := trace.StartRegion(ctx, "extension_account_data")
res.AccountData = ProcessAccountData(h.Store, roomIDToTimeline, req.UserID, isInitial, req.AccountData)
region.End()
}
if req.Typing != nil && req.Typing.Enabled != nil && *req.Typing.Enabled {
region := trace.StartRegion(ctx, "extension_typing")
res.Typing = ProcessTyping(h.GlobalCache, roomIDToTimeline, req.UserID, isInitial, req.Typing)
region.End()
}
if req.Receipts != nil && req.Receipts.Enabled != nil && *req.Receipts.Enabled {
region := trace.StartRegion(ctx, "extension_receipts")
res.Receipts = ProcessReceipts(h.Store, roomIDToTimeline, req.UserID, isInitial, req.Receipts)
func (h *Handler) Handle(ctx context.Context, req Request, extCtx Context) (res Response) {
extCtx.Handler = h
exts := req.EnabledExtensions()
for _, ext := range exts {
region := trace.StartRegion(ctx, "extension_"+ext.Name())
ext.Process(ctx, &res, extCtx)
region.End()
}
return

View File

@ -1,11 +1,15 @@
package extensions
import "context"
type GenericRequest interface {
Name() string
// Returns the value of the `enabled` JSON key. nil for "not specified".
IsEnabled() *bool
// Overwrite fields in the request by side-effecting on this struct.
ApplyDelta(next GenericRequest)
// Process this request and put the response into *Response.
Process(ctx context.Context, res *Response, extCtx Context)
}
// mixin for managing the enabled flag

View File

@ -1,6 +1,7 @@
package extensions
import (
"context"
"encoding/json"
"github.com/matrix-org/sliding-sync/state"
@ -42,27 +43,29 @@ func ProcessLiveReceipts(up caches.Update, updateWillReturnResponse bool, userID
return nil
}
func ProcessReceipts(store *state.Storage, roomIDToTimeline map[string][]string, userID string, isInitial bool, req *ReceiptsRequest) (res *ReceiptsResponse) {
func (r *ReceiptsRequest) Process(ctx context.Context, res *Response, extCtx Context) {
// grab receipts for all timelines for all the rooms we're going to return
res = &ReceiptsResponse{
Rooms: make(map[string]json.RawMessage),
}
for roomID, timeline := range roomIDToTimeline {
receipts, err := store.ReceiptTable.SelectReceiptsForEvents(roomID, timeline)
rooms := make(map[string]json.RawMessage)
for roomID, timeline := range extCtx.RoomIDToTimeline {
receipts, err := extCtx.Store.ReceiptTable.SelectReceiptsForEvents(roomID, timeline)
if err != nil {
logger.Err(err).Str("user", userID).Str("room", roomID).Msg("failed to SelectReceiptsForEvents")
logger.Err(err).Str("user", extCtx.UserID).Str("room", roomID).Msg("failed to SelectReceiptsForEvents")
continue
}
// always include your own receipts
ownReceipts, err := store.ReceiptTable.SelectReceiptsForUser(roomID, userID)
ownReceipts, err := extCtx.Store.ReceiptTable.SelectReceiptsForUser(roomID, extCtx.UserID)
if err != nil {
logger.Err(err).Str("user", userID).Str("room", roomID).Msg("failed to SelectReceiptsForUser")
logger.Err(err).Str("user", extCtx.UserID).Str("room", roomID).Msg("failed to SelectReceiptsForUser")
continue
}
if len(receipts) == 0 && len(ownReceipts) == 0 {
continue
}
res.Rooms[roomID], _ = state.PackReceiptsIntoEDU(append(receipts, ownReceipts...))
rooms[roomID], _ = state.PackReceiptsIntoEDU(append(receipts, ownReceipts...))
}
if len(rooms) > 0 {
res.Receipts = &ReceiptsResponse{
Rooms: rooms, // TODO aggregate
}
}
return
}

View File

@ -1,13 +1,11 @@
package extensions
import (
"context"
"encoding/json"
"fmt"
"strconv"
"sync"
"github.com/matrix-org/sliding-sync/state"
"github.com/matrix-org/sliding-sync/sync3/caches"
)
// used to remember since positions to warn when they are not incremented. This can happen
@ -49,62 +47,53 @@ func (r *ToDeviceResponse) HasData(isInitial bool) bool {
return len(r.Events) > 0
}
func ProcessLiveToDeviceEvents(up caches.Update, store *state.Storage, userID, deviceID string, req *ToDeviceRequest) (res *ToDeviceResponse) {
_, ok := up.(caches.DeviceEventsUpdate)
if !ok {
return nil
func (r *ToDeviceRequest) Process(ctx context.Context, res *Response, extCtx Context) {
if r.Limit == 0 {
r.Limit = 100 // default to 100
}
return ProcessToDevice(store, userID, deviceID, req, false)
}
func ProcessToDevice(store *state.Storage, userID, deviceID string, req *ToDeviceRequest, isInitial bool) (res *ToDeviceResponse) {
if req.Limit == 0 {
req.Limit = 100 // default to 100
}
l := logger.With().Str("user", userID).Str("device", deviceID).Logger()
l := logger.With().Str("user", extCtx.UserID).Str("device", extCtx.DeviceID).Logger()
var from int64
var err error
if req.Since != "" {
from, err = strconv.ParseInt(req.Since, 10, 64)
if r.Since != "" {
from, err = strconv.ParseInt(r.Since, 10, 64)
if err != nil {
l.Err(err).Str("since", req.Since).Msg("invalid since value")
return nil
l.Err(err).Str("since", r.Since).Msg("invalid since value")
return
}
// the client is confirming messages up to `from` so delete everything up to and including it.
if err = store.ToDeviceTable.DeleteMessagesUpToAndIncluding(deviceID, from); err != nil {
l.Err(err).Str("since", req.Since).Msg("failed to delete to-device messages up to this value")
if err = extCtx.Store.ToDeviceTable.DeleteMessagesUpToAndIncluding(extCtx.DeviceID, from); err != nil {
l.Err(err).Str("since", r.Since).Msg("failed to delete to-device messages up to this value")
// non-fatal TODO sentry
}
}
mapMu.Lock()
lastSentPos := deviceIDToSinceDebugOnly[deviceID]
lastSentPos := deviceIDToSinceDebugOnly[extCtx.DeviceID]
mapMu.Unlock()
if from < lastSentPos {
// we told the client about a newer position, but yet they are using an older position, yell loudly
// TODO sentry
l.Warn().Int64("last_sent", lastSentPos).Int64("recv", from).Bool("initial", isInitial).Msg(
l.Warn().Int64("last_sent", lastSentPos).Int64("recv", from).Bool("initial", extCtx.IsInitial).Msg(
"Client did not increment since token: possibly sending back duplicate to-device events!",
)
}
msgs, upTo, err := store.ToDeviceTable.Messages(deviceID, from, int64(req.Limit))
msgs, upTo, err := extCtx.Store.ToDeviceTable.Messages(extCtx.DeviceID, from, int64(r.Limit))
if err != nil {
l.Err(err).Int64("from", from).Msg("cannot query to-device messages")
// TODO sentry
return nil
return
}
err = store.ToDeviceTable.SetUnackedPosition(deviceID, upTo)
err = extCtx.Store.ToDeviceTable.SetUnackedPosition(extCtx.DeviceID, upTo)
if err != nil {
l.Err(err).Msg("cannot set unacked position")
// TODO sentry
return nil
return
}
mapMu.Lock()
deviceIDToSinceDebugOnly[deviceID] = upTo
deviceIDToSinceDebugOnly[extCtx.DeviceID] = upTo
mapMu.Unlock()
res = &ToDeviceResponse{
res.ToDevice = &ToDeviceResponse{ // TODO: aggregate
NextBatch: fmt.Sprintf("%d", upTo),
Events: msgs,
}
return
}

View File

@ -1,6 +1,7 @@
package extensions
import (
"context"
"encoding/json"
"github.com/matrix-org/sliding-sync/sync3/caches"
@ -55,22 +56,25 @@ func ProcessLiveTyping(up caches.Update, updateWillReturnResponse bool, userID s
return nil
}
func ProcessTyping(globalCache *caches.GlobalCache, roomIDToTimeline map[string][]string, userID string, isInitial bool, req *TypingRequest) (res *TypingResponse) {
func (r *TypingRequest) Process(ctx context.Context, res *Response, extCtx Context) {
// grab typing users for all the rooms we're going to return
res = &TypingResponse{
Rooms: make(map[string]json.RawMessage),
}
roomIDs := make([]string, 0, len(roomIDToTimeline))
for roomID := range roomIDToTimeline {
rooms := make(map[string]json.RawMessage)
roomIDs := make([]string, 0, len(extCtx.RoomIDToTimeline))
for roomID := range extCtx.RoomIDToTimeline {
roomIDs = append(roomIDs, roomID)
}
roomToGlobalMetadata := globalCache.LoadRooms(roomIDs...)
for roomID := range roomIDToTimeline {
roomToGlobalMetadata := extCtx.GlobalCache.LoadRooms(roomIDs...)
for roomID := range extCtx.RoomIDToTimeline {
meta := roomToGlobalMetadata[roomID]
if meta == nil || meta.TypingEvent == nil {
continue
}
res.Rooms[roomID] = meta.TypingEvent
rooms[roomID] = meta.TypingEvent
}
if len(rooms) == 0 {
return // don't add a typing extension, no data!
}
res.Typing = &TypingResponse{
Rooms: rooms, // TODO aggregate
}
return
}

View File

@ -174,7 +174,12 @@ func (s *ConnState) onIncomingRequest(ctx context.Context, req *sync3.Request, i
// Handle extensions AFTER processing lists as extensions may need to know which rooms the client
// is being notified about (e.g. for room account data)
region := trace.StartRegion(ctx, "extensions")
response.Extensions = s.extensionsHandler.Handle(ctx, ex, includedRoomIDs, isInitial)
response.Extensions = s.extensionsHandler.Handle(ctx, ex, extensions.Context{
UserID: ex.UserID,
DeviceID: ex.DeviceID,
RoomIDToTimeline: includedRoomIDs,
IsInitial: isInitial,
})
region.End()
if response.ListOps() > 0 || len(response.Rooms) > 0 || response.Extensions.HasData(isInitial) {

View File

@ -18,7 +18,7 @@ import (
type NopExtensionHandler struct{}
func (h *NopExtensionHandler) Handle(ctx context.Context, req extensions.Request, listRoomIDs map[string][]string, isInitial bool) (res extensions.Response) {
func (h *NopExtensionHandler) Handle(ctx context.Context, req extensions.Request, extCtx extensions.Context) (res extensions.Response) {
return
}

View File

@ -185,22 +185,8 @@ func TestReceiptsPrivate(t *testing.T) {
// bob secretly reads this
bob.SendReceipt(t, roomID, eventID, "m.read.private")
time.Sleep(300 * time.Millisecond) // TODO: find a better way to wait until the proxy has processed this.
// alice does sliding sync -> does not see private RR
res := alice.SlidingSync(t, sync3.Request{
RoomSubscriptions: map[string]sync3.RoomSubscription{
roomID: {
TimelineLimit: 1,
},
},
Extensions: extensions.Request{
Receipts: &extensions.ReceiptsRequest{
Enableable: extensions.Enableable{Enabled: &boolTrue},
},
},
})
m.MatchResponse(t, res, m.MatchReceipts(roomID, nil))
// bob does sliding sync -> sees private RR
res = bob.SlidingSync(t, sync3.Request{
res := bob.SlidingSync(t, sync3.Request{
RoomSubscriptions: map[string]sync3.RoomSubscription{
roomID: {
TimelineLimit: 1,
@ -219,4 +205,20 @@ func TestReceiptsPrivate(t *testing.T) {
Type: "m.read.private",
},
}))
// alice does sliding sync -> does not see private RR
// We do this _after_ bob's sync request so we know we got the private RR and it is actively
// suppressed, rather than the private RR not making it to the proxy yet.
res = alice.SlidingSync(t, sync3.Request{
RoomSubscriptions: map[string]sync3.RoomSubscription{
roomID: {
TimelineLimit: 1,
},
},
Extensions: extensions.Request{
Receipts: &extensions.ReceiptsRequest{
Enableable: extensions.Enableable{Enabled: &boolTrue},
},
},
})
m.MatchResponse(t, res, m.MatchNoReceiptsExtension())
}

View File

@ -236,6 +236,15 @@ func MatchNoE2EEExtension() RespMatcher {
}
}
func MatchNoReceiptsExtension() RespMatcher {
return func(res *sync3.Response) error {
if res.Extensions.Receipts != nil {
return fmt.Errorf("MatchNoReceiptsExtension: got Receipts extension: %+v", res.Extensions.Receipts)
}
return nil
}
}
func MatchOTKCounts(otkCounts map[string]int) RespMatcher {
return func(res *sync3.Response) error {
if res.Extensions.E2EE == nil {