E2EE extension: Add support for device_unused_fallback_key_types

With tests
This commit is contained in:
Kegan Dougal 2022-08-09 10:05:18 +01:00
parent ca2b19310e
commit 47ddc04652
5 changed files with 39 additions and 13 deletions

View File

@ -99,7 +99,8 @@ type SyncResponse struct {
Changed []string `json:"changed,omitempty"`
Left []string `json:"left,omitempty"`
} `json:"device_lists"`
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
DeviceUnusedFallbackKeyTypes []string `json:"device_unused_fallback_key_types,omitempty"`
}
type SyncRoomsResponse struct {

View File

@ -31,7 +31,7 @@ type V2DataReceiver interface {
// Fetcher which PollerMap satisfies used by the E2EE extension
type E2EEFetcher interface {
LatestE2EEData(deviceID string) (otkCounts map[string]int, changed, left []string)
LatestE2EEData(deviceID string) (otkCounts map[string]int, fallbackKeyTypes, changed, left []string)
}
type TransactionIDFetcher interface {
@ -88,7 +88,7 @@ func (h *PollerMap) TransactionIDForEvent(userID, eventID string) string {
// LatestE2EEData pulls the latest device_lists and device_one_time_keys_count values from the poller.
// These bits of data are ephemeral and do not need to be persisted.
func (h *PollerMap) LatestE2EEData(deviceID string) (otkCounts map[string]int, changed, left []string) {
func (h *PollerMap) LatestE2EEData(deviceID string) (otkCounts map[string]int, fallbackKeyTypes, changed, left []string) {
h.pollerMu.Lock()
poller := h.Pollers[deviceID]
h.pollerMu.Unlock()
@ -98,6 +98,7 @@ func (h *PollerMap) LatestE2EEData(deviceID string) (otkCounts map[string]int, c
return
}
otkCounts = poller.OTKCounts()
fallbackKeyTypes = poller.FallbackKeyTypes()
changed, left = poller.DeviceListChanges()
return
}
@ -245,6 +246,7 @@ type Poller struct {
// E2EE fields
e2eeMu *sync.Mutex
fallbackKeyTypes []string
otkCounts map[string]int
deviceListChanges map[string]string // latest user_id -> state e.g "@alice" -> "left"
@ -330,6 +332,9 @@ func (p *Poller) OTKCounts() map[string]int {
defer p.e2eeMu.Unlock()
return p.otkCounts
}
func (p *Poller) FallbackKeyTypes() []string {
return p.fallbackKeyTypes
}
func (p *Poller) DeviceListChanges() (changed, left []string) {
p.e2eeMu.Lock()
@ -364,6 +369,9 @@ func (p *Poller) parseE2EEData(res *SyncResponse) {
if res.DeviceListsOTKCount != nil {
p.otkCounts = res.DeviceListsOTKCount
}
if len(res.DeviceUnusedFallbackKeyTypes) > 0 {
p.fallbackKeyTypes = res.DeviceUnusedFallbackKeyTypes
}
for _, userID := range res.DeviceLists.Changed {
p.deviceListChanges[userID] = "changed"
}

View File

@ -16,8 +16,9 @@ func (r E2EERequest) ApplyDelta(next *E2EERequest) *E2EERequest {
// Server response
type E2EEResponse struct {
OTKCounts map[string]int `json:"device_one_time_keys_count"`
DeviceLists *E2EEDeviceList `json:"device_lists,omitempty"`
OTKCounts map[string]int `json:"device_one_time_keys_count"`
DeviceLists *E2EEDeviceList `json:"device_lists,omitempty"`
FallbackKeyTypes []string `json:"device_unused_fallback_key_types,omitempty"`
}
type E2EEDeviceList struct {
@ -35,9 +36,10 @@ func (r *E2EEResponse) HasData(isInitial bool) bool {
func ProcessE2EE(fetcher sync2.E2EEFetcher, userID, deviceID string, req *E2EERequest) (res *E2EEResponse) {
// pull OTK counts and changed/left from v2 poller
otkCounts, changed, left := fetcher.LatestE2EEData(deviceID)
otkCounts, fallbackKeyTypes, changed, left := fetcher.LatestE2EEData(deviceID)
res = &E2EEResponse{
OTKCounts: otkCounts,
OTKCounts: otkCounts,
FallbackKeyTypes: fallbackKeyTypes,
}
if len(changed) > 0 || len(left) > 0 {
res.DeviceLists = &E2EEDeviceList{

View File

@ -24,14 +24,16 @@ func TestExtensionE2EE(t *testing.T) {
defer v2.close()
defer v3.close()
// check that OTK counts go through
// check that OTK counts / fallback key types go through
otkCounts := map[string]int{
"curve25519": 10,
"signed_curve25519": 100,
}
fallbackKeyTypes := []string{"signed_curve25519"}
v2.addAccount(alice, aliceToken)
v2.queueResponse(alice, sync2.SyncResponse{
DeviceListsOTKCount: otkCounts,
DeviceListsOTKCount: otkCounts,
DeviceUnusedFallbackKeyTypes: fallbackKeyTypes,
})
res := v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: []sync3.RequestList{{
@ -46,9 +48,9 @@ func TestExtensionE2EE(t *testing.T) {
},
},
})
m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts))
m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts), m.MatchFallbackKeyTypes(fallbackKeyTypes))
// check that OTK counts remain constant when they aren't included in the v2 response.
// check that OTK counts / fallback key types remain constant when they aren't included in the v2 response.
// Do this by feeding in a new joined room
v2.queueResponse(alice, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
@ -67,9 +69,10 @@ func TestExtensionE2EE(t *testing.T) {
}},
// skip enabled: true as it should be sticky
})
m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts))
m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts), m.MatchFallbackKeyTypes(fallbackKeyTypes))
// check that OTK counts update when they are included in the v2 response
// check fallback key types persist when not included
otkCounts = map[string]int{
"curve25519": 99,
"signed_curve25519": 999,
@ -91,7 +94,7 @@ func TestExtensionE2EE(t *testing.T) {
},
},
})
m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts))
m.MatchResponse(t, res, m.MatchOTKCounts(otkCounts), m.MatchFallbackKeyTypes(fallbackKeyTypes))
// check that changed|left get passed to v3
wantChanged := []string{"bob"}

View File

@ -210,6 +210,18 @@ func MatchOTKCounts(otkCounts map[string]int) RespMatcher {
}
}
func MatchFallbackKeyTypes(fallbackKeyTypes []string) RespMatcher {
return func(res *sync3.Response) error {
if res.Extensions.E2EE == nil {
return fmt.Errorf("MatchFallbackKeyTypes: no E2EE extension present")
}
if !reflect.DeepEqual(res.Extensions.E2EE.FallbackKeyTypes, fallbackKeyTypes) {
return fmt.Errorf("MatchFallbackKeyTypes: got %v want %v", res.Extensions.E2EE.FallbackKeyTypes, fallbackKeyTypes)
}
return nil
}
}
func MatchDeviceLists(changed, left []string) RespMatcher {
return func(res *sync3.Response) error {
if res.Extensions.E2EE == nil {