mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
E2EE extension: Add support for device_unused_fallback_key_types
With tests
This commit is contained in:
parent
ca2b19310e
commit
47ddc04652
@ -99,7 +99,8 @@ type SyncResponse struct {
|
||||
Changed []string `json:"changed,omitempty"`
|
||||
Left []string `json:"left,omitempty"`
|
||||
} `json:"device_lists"`
|
||||
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
|
||||
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
|
||||
DeviceUnusedFallbackKeyTypes []string `json:"device_unused_fallback_key_types,omitempty"`
|
||||
}
|
||||
|
||||
type SyncRoomsResponse struct {
|
||||
|
@ -31,7 +31,7 @@ type V2DataReceiver interface {
|
||||
|
||||
// Fetcher which PollerMap satisfies used by the E2EE extension
|
||||
type E2EEFetcher interface {
|
||||
LatestE2EEData(deviceID string) (otkCounts map[string]int, changed, left []string)
|
||||
LatestE2EEData(deviceID string) (otkCounts map[string]int, fallbackKeyTypes, changed, left []string)
|
||||
}
|
||||
|
||||
type TransactionIDFetcher interface {
|
||||
@ -88,7 +88,7 @@ func (h *PollerMap) TransactionIDForEvent(userID, eventID string) string {
|
||||
|
||||
// LatestE2EEData pulls the latest device_lists and device_one_time_keys_count values from the poller.
|
||||
// These bits of data are ephemeral and do not need to be persisted.
|
||||
func (h *PollerMap) LatestE2EEData(deviceID string) (otkCounts map[string]int, changed, left []string) {
|
||||
func (h *PollerMap) LatestE2EEData(deviceID string) (otkCounts map[string]int, fallbackKeyTypes, changed, left []string) {
|
||||
h.pollerMu.Lock()
|
||||
poller := h.Pollers[deviceID]
|
||||
h.pollerMu.Unlock()
|
||||
@ -98,6 +98,7 @@ func (h *PollerMap) LatestE2EEData(deviceID string) (otkCounts map[string]int, c
|
||||
return
|
||||
}
|
||||
otkCounts = poller.OTKCounts()
|
||||
fallbackKeyTypes = poller.FallbackKeyTypes()
|
||||
changed, left = poller.DeviceListChanges()
|
||||
return
|
||||
}
|
||||
@ -245,6 +246,7 @@ type Poller struct {
|
||||
|
||||
// E2EE fields
|
||||
e2eeMu *sync.Mutex
|
||||
fallbackKeyTypes []string
|
||||
otkCounts map[string]int
|
||||
deviceListChanges map[string]string // latest user_id -> state e.g "@alice" -> "left"
|
||||
|
||||
@ -330,6 +332,9 @@ func (p *Poller) OTKCounts() map[string]int {
|
||||
defer p.e2eeMu.Unlock()
|
||||
return p.otkCounts
|
||||
}
|
||||
func (p *Poller) FallbackKeyTypes() []string {
|
||||
return p.fallbackKeyTypes
|
||||
}
|
||||
|
||||
func (p *Poller) DeviceListChanges() (changed, left []string) {
|
||||
p.e2eeMu.Lock()
|
||||
@ -364,6 +369,9 @@ func (p *Poller) parseE2EEData(res *SyncResponse) {
|
||||
if res.DeviceListsOTKCount != nil {
|
||||
p.otkCounts = res.DeviceListsOTKCount
|
||||
}
|
||||
if len(res.DeviceUnusedFallbackKeyTypes) > 0 {
|
||||
p.fallbackKeyTypes = res.DeviceUnusedFallbackKeyTypes
|
||||
}
|
||||
for _, userID := range res.DeviceLists.Changed {
|
||||
p.deviceListChanges[userID] = "changed"
|
||||
}
|
||||
|
@ -16,8 +16,9 @@ func (r E2EERequest) ApplyDelta(next *E2EERequest) *E2EERequest {
|
||||
|
||||
// Server response
|
||||
type E2EEResponse struct {
|
||||
OTKCounts map[string]int `json:"device_one_time_keys_count"`
|
||||
DeviceLists *E2EEDeviceList `json:"device_lists,omitempty"`
|
||||
OTKCounts map[string]int `json:"device_one_time_keys_count"`
|
||||
DeviceLists *E2EEDeviceList `json:"device_lists,omitempty"`
|
||||
FallbackKeyTypes []string `json:"device_unused_fallback_key_types,omitempty"`
|
||||
}
|
||||
|
||||
type E2EEDeviceList struct {
|
||||
@ -35,9 +36,10 @@ func (r *E2EEResponse) HasData(isInitial bool) bool {
|
||||
|
||||
func ProcessE2EE(fetcher sync2.E2EEFetcher, userID, deviceID string, req *E2EERequest) (res *E2EEResponse) {
|
||||
// pull OTK counts and changed/left from v2 poller
|
||||
otkCounts, changed, left := fetcher.LatestE2EEData(deviceID)
|
||||
otkCounts, fallbackKeyTypes, changed, left := fetcher.LatestE2EEData(deviceID)
|
||||
res = &E2EEResponse{
|
||||
OTKCounts: otkCounts,
|
||||
OTKCounts: otkCounts,
|
||||
FallbackKeyTypes: fallbackKeyTypes,
|
||||
}
|
||||
if len(changed) > 0 || len(left) > 0 {
|
||||
res.DeviceLists = &E2EEDeviceList{
|
||||
|
@ -24,14 +24,16 @@ func TestExtensionE2EE(t *testing.T) {
|
||||
defer v2.close()
|
||||
defer v3.close()
|
||||
|
||||
// check that OTK counts go through
|
||||
// check that OTK counts / fallback key types go through
|
||||
otkCounts := map[string]int{
|
||||
"curve25519": 10,
|
||||
"signed_curve25519": 100,
|
||||
}
|
||||
fallbackKeyTypes := []string{"signed_curve25519"}
|
||||
v2.addAccount(alice, aliceToken)
|
||||
v2.queueResponse(alice, sync2.SyncResponse{
|
||||
DeviceListsOTKCount: otkCounts,
|
||||
DeviceListsOTKCount: otkCounts,
|
||||
DeviceUnusedFallbackKeyTypes: fallbackKeyTypes,
|
||||
})
|
||||
res := v3.mustDoV3Request(t, aliceToken, sync3.Request{
|
||||
Lists: []sync3.RequestList{{
|
||||
@ -46,9 +48,9 @@ func TestExtensionE2EE(t *testing.T) {
|
||||
},
|
||||
},
|
||||
})
|
||||
m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts))
|
||||
m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts), m.MatchFallbackKeyTypes(fallbackKeyTypes))
|
||||
|
||||
// check that OTK counts remain constant when they aren't included in the v2 response.
|
||||
// check that OTK counts / fallback key types remain constant when they aren't included in the v2 response.
|
||||
// Do this by feeding in a new joined room
|
||||
v2.queueResponse(alice, sync2.SyncResponse{
|
||||
Rooms: sync2.SyncRoomsResponse{
|
||||
@ -67,9 +69,10 @@ func TestExtensionE2EE(t *testing.T) {
|
||||
}},
|
||||
// skip enabled: true as it should be sticky
|
||||
})
|
||||
m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts))
|
||||
m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts), m.MatchFallbackKeyTypes(fallbackKeyTypes))
|
||||
|
||||
// check that OTK counts update when they are included in the v2 response
|
||||
// check fallback key types persist when not included
|
||||
otkCounts = map[string]int{
|
||||
"curve25519": 99,
|
||||
"signed_curve25519": 999,
|
||||
@ -91,7 +94,7 @@ func TestExtensionE2EE(t *testing.T) {
|
||||
},
|
||||
},
|
||||
})
|
||||
m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts))
|
||||
m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts), m.MatchFallbackKeyTypes(fallbackKeyTypes))
|
||||
|
||||
// check that changed|left get passed to v3
|
||||
wantChanged := []string{"bob"}
|
||||
|
@ -210,6 +210,18 @@ func MatchOTKCounts(otkCounts map[string]int) RespMatcher {
|
||||
}
|
||||
}
|
||||
|
||||
func MatchFallbackKeyTypes(fallbackKeyTypes []string) RespMatcher {
|
||||
return func(res *sync3.Response) error {
|
||||
if res.Extensions.E2EE == nil {
|
||||
return fmt.Errorf("MatchFallbackKeyTypes: no E2EE extension present")
|
||||
}
|
||||
if !reflect.DeepEqual(res.Extensions.E2EE.FallbackKeyTypes, fallbackKeyTypes) {
|
||||
return fmt.Errorf("MatchFallbackKeyTypes: got %v want %v", res.Extensions.E2EE.FallbackKeyTypes, fallbackKeyTypes)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func MatchDeviceLists(changed, left []string) RespMatcher {
|
||||
return func(res *sync3.Response) error {
|
||||
if res.Extensions.E2EE == nil {
|
||||
|
Loading…
x
Reference in New Issue
Block a user