diff --git a/internal/home_server_url.go b/internal/home_server_url.go deleted file mode 100644 index 0e5544a..0000000 --- a/internal/home_server_url.go +++ /dev/null @@ -1,25 +0,0 @@ -package internal - -import "strings" - -type HomeServerUrl struct { - HttpOrUnixStr string -} - -func (u HomeServerUrl) IsUnixSocket() bool { - return strings.HasPrefix(u.HttpOrUnixStr, "/") -} - -func (u HomeServerUrl) GetUnixSocket() string { - if u.IsUnixSocket() { - return u.HttpOrUnixStr - } - return "" -} - -func (u HomeServerUrl) GetBaseUrl() string { - if u.IsUnixSocket() { - return "http://unix" - } - return u.HttpOrUnixStr -} diff --git a/internal/home_server_url_test.go b/internal/home_server_url_test.go deleted file mode 100644 index 360e374..0000000 --- a/internal/home_server_url_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package internal - -import ( - "github.com/stretchr/testify/assert" - "testing" -) - -func TestHomeServerUrl_IsUnixSocket_True(t *testing.T) { - assert.True(t, HomeServerUrl{"/path/to/socket"}.IsUnixSocket()) -} - -func TestHomeServerUrl_IsUnixSocket_False(t *testing.T) { - assert.False(t, HomeServerUrl{"localhost:8080"}.IsUnixSocket()) -} - -func TestHomeServerUrl_GetUnixSocket(t *testing.T) { - assert.Equal(t, "/path/to/socket", HomeServerUrl{"/path/to/socket"}.GetUnixSocket()) -} - -func TestHomeServerUrl_GetUnixSocket_Http(t *testing.T) { - assert.Equal(t, "", HomeServerUrl{"localhost:8080"}.GetUnixSocket()) -} - -func TestHomeServerUrl_GetBaseUrl_UnixSocket(t *testing.T) { - assert.Equal(t, "http://unix", HomeServerUrl{"/path/to/socket"}.GetBaseUrl()) -} - -func TestHomeServerUrl_GetBaseUrl_Http(t *testing.T) { - assert.Equal(t, "localhost:8080", HomeServerUrl{"localhost:8080"}.GetBaseUrl()) -} diff --git a/internal/util.go b/internal/util.go index 51760a7..c0dffad 100644 --- a/internal/util.go +++ b/internal/util.go @@ -1,5 +1,12 @@ package internal +import ( + "context" + "net" + "net/http" + "strings" +) + // Keys returns a slice containing copies of the keys of the given map, in no particular // order. func Keys[K comparable, V any](m map[K]V) []K { @@ -12,3 +19,22 @@ func Keys[K comparable, V any](m map[K]V) []K { } return output } + +func IsUnixSocket(httpOrUnixStr string) bool { + return strings.HasPrefix(httpOrUnixStr, "/") +} + +func UnixTransport(httpOrUnixStr string) *http.Transport { + return &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", httpOrUnixStr) + }, + } +} + +func GetBaseURL(httpOrUnixStr string) string { + if IsUnixSocket(httpOrUnixStr) { + return "http://unix" + } + return httpOrUnixStr +} diff --git a/internal/util_test.go b/internal/util_test.go index 63c1423..33f88c0 100644 --- a/internal/util_test.go +++ b/internal/util_test.go @@ -28,3 +28,31 @@ func assertSlice(t *testing.T, got, want []string) { t.Errorf("After sorting, got %v but expected %v", got, want) } } + +func TestUnixSocket_True(t *testing.T) { + address := "/path/to/socket" + if !IsUnixSocket(address) { + t.Errorf("%s is socket", address) + } +} + +func TestUnixSocket_False(t *testing.T) { + address := "localhost:8080" + if IsUnixSocket(address) { + t.Errorf("%s is not socket", address) + } +} + +func TestGetBaseUrl_UnixSocket(t *testing.T) { + address := "/path/to/socket" + if GetBaseURL(address) != "http://unix" { + t.Errorf("%s is unix socket", address) + } +} + +func TestGetBaseUrl_Http(t *testing.T) { + address := "localhost:8080" + if GetBaseURL(address) != "localhost:8080" { + t.Errorf("%s is not a unix socket", address) + } +} diff --git a/sync2/client.go b/sync2/client.go index 45aec87..7d47f61 100644 --- a/sync2/client.go +++ b/sync2/client.go @@ -6,7 +6,6 @@ import ( "fmt" "github.com/matrix-org/sliding-sync/internal" "io" - "net" "net/http" "net/url" "time" @@ -41,22 +40,17 @@ type HTTPClient struct { } func NewHTTPClient(shortTimeout, longTimeout time.Duration, destHomeServer string) *HTTPClient { - hsUrl := internal.HomeServerUrl{HttpOrUnixStr: destHomeServer} return &HTTPClient{ - LongTimeoutClient: newClient(longTimeout, hsUrl), - Client: newClient(shortTimeout, hsUrl), - DestinationServer: hsUrl.GetBaseUrl(), + LongTimeoutClient: newClient(longTimeout, destHomeServer), + Client: newClient(shortTimeout, destHomeServer), + DestinationServer: internal.GetBaseURL(destHomeServer), } } -func newClient(timeout time.Duration, url internal.HomeServerUrl) *http.Client { +func newClient(timeout time.Duration, destHomeServer string) *http.Client { transport := http.DefaultTransport - if url.IsUnixSocket() { - transport = &http.Transport{ - DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { - return net.Dial("unix", url.GetUnixSocket()) - }, - } + if internal.IsUnixSocket(destHomeServer) { + transport = internal.UnixTransport(destHomeServer) } return &http.Client{ Timeout: timeout,