mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
back to util functions
This commit is contained in:
parent
4c858fe3ef
commit
8d38785ac0
@ -1,25 +0,0 @@
|
||||
package internal
|
||||
|
||||
import "strings"
|
||||
|
||||
type HomeServerUrl struct {
|
||||
HttpOrUnixStr string
|
||||
}
|
||||
|
||||
func (u HomeServerUrl) IsUnixSocket() bool {
|
||||
return strings.HasPrefix(u.HttpOrUnixStr, "/")
|
||||
}
|
||||
|
||||
func (u HomeServerUrl) GetUnixSocket() string {
|
||||
if u.IsUnixSocket() {
|
||||
return u.HttpOrUnixStr
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (u HomeServerUrl) GetBaseUrl() string {
|
||||
if u.IsUnixSocket() {
|
||||
return "http://unix"
|
||||
}
|
||||
return u.HttpOrUnixStr
|
||||
}
|
@ -1,30 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHomeServerUrl_IsUnixSocket_True(t *testing.T) {
|
||||
assert.True(t, HomeServerUrl{"/path/to/socket"}.IsUnixSocket())
|
||||
}
|
||||
|
||||
func TestHomeServerUrl_IsUnixSocket_False(t *testing.T) {
|
||||
assert.False(t, HomeServerUrl{"localhost:8080"}.IsUnixSocket())
|
||||
}
|
||||
|
||||
func TestHomeServerUrl_GetUnixSocket(t *testing.T) {
|
||||
assert.Equal(t, "/path/to/socket", HomeServerUrl{"/path/to/socket"}.GetUnixSocket())
|
||||
}
|
||||
|
||||
func TestHomeServerUrl_GetUnixSocket_Http(t *testing.T) {
|
||||
assert.Equal(t, "", HomeServerUrl{"localhost:8080"}.GetUnixSocket())
|
||||
}
|
||||
|
||||
func TestHomeServerUrl_GetBaseUrl_UnixSocket(t *testing.T) {
|
||||
assert.Equal(t, "http://unix", HomeServerUrl{"/path/to/socket"}.GetBaseUrl())
|
||||
}
|
||||
|
||||
func TestHomeServerUrl_GetBaseUrl_Http(t *testing.T) {
|
||||
assert.Equal(t, "localhost:8080", HomeServerUrl{"localhost:8080"}.GetBaseUrl())
|
||||
}
|
@ -1,5 +1,12 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Keys returns a slice containing copies of the keys of the given map, in no particular
|
||||
// order.
|
||||
func Keys[K comparable, V any](m map[K]V) []K {
|
||||
@ -12,3 +19,22 @@ func Keys[K comparable, V any](m map[K]V) []K {
|
||||
}
|
||||
return output
|
||||
}
|
||||
|
||||
func IsUnixSocket(httpOrUnixStr string) bool {
|
||||
return strings.HasPrefix(httpOrUnixStr, "/")
|
||||
}
|
||||
|
||||
func UnixTransport(httpOrUnixStr string) *http.Transport {
|
||||
return &http.Transport{
|
||||
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
|
||||
return net.Dial("unix", httpOrUnixStr)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func GetBaseURL(httpOrUnixStr string) string {
|
||||
if IsUnixSocket(httpOrUnixStr) {
|
||||
return "http://unix"
|
||||
}
|
||||
return httpOrUnixStr
|
||||
}
|
||||
|
@ -28,3 +28,31 @@ func assertSlice(t *testing.T, got, want []string) {
|
||||
t.Errorf("After sorting, got %v but expected %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnixSocket_True(t *testing.T) {
|
||||
address := "/path/to/socket"
|
||||
if !IsUnixSocket(address) {
|
||||
t.Errorf("%s is socket", address)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnixSocket_False(t *testing.T) {
|
||||
address := "localhost:8080"
|
||||
if IsUnixSocket(address) {
|
||||
t.Errorf("%s is not socket", address)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetBaseUrl_UnixSocket(t *testing.T) {
|
||||
address := "/path/to/socket"
|
||||
if GetBaseURL(address) != "http://unix" {
|
||||
t.Errorf("%s is unix socket", address)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetBaseUrl_Http(t *testing.T) {
|
||||
address := "localhost:8080"
|
||||
if GetBaseURL(address) != "localhost:8080" {
|
||||
t.Errorf("%s is not a unix socket", address)
|
||||
}
|
||||
}
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"github.com/matrix-org/sliding-sync/internal"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
@ -41,22 +40,17 @@ type HTTPClient struct {
|
||||
}
|
||||
|
||||
func NewHTTPClient(shortTimeout, longTimeout time.Duration, destHomeServer string) *HTTPClient {
|
||||
hsUrl := internal.HomeServerUrl{HttpOrUnixStr: destHomeServer}
|
||||
return &HTTPClient{
|
||||
LongTimeoutClient: newClient(longTimeout, hsUrl),
|
||||
Client: newClient(shortTimeout, hsUrl),
|
||||
DestinationServer: hsUrl.GetBaseUrl(),
|
||||
LongTimeoutClient: newClient(longTimeout, destHomeServer),
|
||||
Client: newClient(shortTimeout, destHomeServer),
|
||||
DestinationServer: internal.GetBaseURL(destHomeServer),
|
||||
}
|
||||
}
|
||||
|
||||
func newClient(timeout time.Duration, url internal.HomeServerUrl) *http.Client {
|
||||
func newClient(timeout time.Duration, destHomeServer string) *http.Client {
|
||||
transport := http.DefaultTransport
|
||||
if url.IsUnixSocket() {
|
||||
transport = &http.Transport{
|
||||
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
|
||||
return net.Dial("unix", url.GetUnixSocket())
|
||||
},
|
||||
}
|
||||
if internal.IsUnixSocket(destHomeServer) {
|
||||
transport = internal.UnixTransport(destHomeServer)
|
||||
}
|
||||
return &http.Client{
|
||||
Timeout: timeout,
|
||||
|
Loading…
x
Reference in New Issue
Block a user