back to util functions

This commit is contained in:
Boris Rybalkin 2023-11-16 19:31:43 +00:00
parent 4c858fe3ef
commit 8d38785ac0
5 changed files with 60 additions and 67 deletions

View File

@ -1,25 +0,0 @@
package internal
import "strings"
type HomeServerUrl struct {
HttpOrUnixStr string
}
func (u HomeServerUrl) IsUnixSocket() bool {
return strings.HasPrefix(u.HttpOrUnixStr, "/")
}
func (u HomeServerUrl) GetUnixSocket() string {
if u.IsUnixSocket() {
return u.HttpOrUnixStr
}
return ""
}
func (u HomeServerUrl) GetBaseUrl() string {
if u.IsUnixSocket() {
return "http://unix"
}
return u.HttpOrUnixStr
}

View File

@ -1,30 +0,0 @@
package internal
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestHomeServerUrl_IsUnixSocket_True(t *testing.T) {
assert.True(t, HomeServerUrl{"/path/to/socket"}.IsUnixSocket())
}
func TestHomeServerUrl_IsUnixSocket_False(t *testing.T) {
assert.False(t, HomeServerUrl{"localhost:8080"}.IsUnixSocket())
}
func TestHomeServerUrl_GetUnixSocket(t *testing.T) {
assert.Equal(t, "/path/to/socket", HomeServerUrl{"/path/to/socket"}.GetUnixSocket())
}
func TestHomeServerUrl_GetUnixSocket_Http(t *testing.T) {
assert.Equal(t, "", HomeServerUrl{"localhost:8080"}.GetUnixSocket())
}
func TestHomeServerUrl_GetBaseUrl_UnixSocket(t *testing.T) {
assert.Equal(t, "http://unix", HomeServerUrl{"/path/to/socket"}.GetBaseUrl())
}
func TestHomeServerUrl_GetBaseUrl_Http(t *testing.T) {
assert.Equal(t, "localhost:8080", HomeServerUrl{"localhost:8080"}.GetBaseUrl())
}

View File

@ -1,5 +1,12 @@
package internal
import (
"context"
"net"
"net/http"
"strings"
)
// Keys returns a slice containing copies of the keys of the given map, in no particular
// order.
func Keys[K comparable, V any](m map[K]V) []K {
@ -12,3 +19,22 @@ func Keys[K comparable, V any](m map[K]V) []K {
}
return output
}
func IsUnixSocket(httpOrUnixStr string) bool {
return strings.HasPrefix(httpOrUnixStr, "/")
}
func UnixTransport(httpOrUnixStr string) *http.Transport {
return &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial("unix", httpOrUnixStr)
},
}
}
func GetBaseURL(httpOrUnixStr string) string {
if IsUnixSocket(httpOrUnixStr) {
return "http://unix"
}
return httpOrUnixStr
}

View File

@ -28,3 +28,31 @@ func assertSlice(t *testing.T, got, want []string) {
t.Errorf("After sorting, got %v but expected %v", got, want)
}
}
func TestUnixSocket_True(t *testing.T) {
address := "/path/to/socket"
if !IsUnixSocket(address) {
t.Errorf("%s is socket", address)
}
}
func TestUnixSocket_False(t *testing.T) {
address := "localhost:8080"
if IsUnixSocket(address) {
t.Errorf("%s is not socket", address)
}
}
func TestGetBaseUrl_UnixSocket(t *testing.T) {
address := "/path/to/socket"
if GetBaseURL(address) != "http://unix" {
t.Errorf("%s is unix socket", address)
}
}
func TestGetBaseUrl_Http(t *testing.T) {
address := "localhost:8080"
if GetBaseURL(address) != "localhost:8080" {
t.Errorf("%s is not a unix socket", address)
}
}

View File

@ -6,7 +6,6 @@ import (
"fmt"
"github.com/matrix-org/sliding-sync/internal"
"io"
"net"
"net/http"
"net/url"
"time"
@ -41,22 +40,17 @@ type HTTPClient struct {
}
func NewHTTPClient(shortTimeout, longTimeout time.Duration, destHomeServer string) *HTTPClient {
hsUrl := internal.HomeServerUrl{HttpOrUnixStr: destHomeServer}
return &HTTPClient{
LongTimeoutClient: newClient(longTimeout, hsUrl),
Client: newClient(shortTimeout, hsUrl),
DestinationServer: hsUrl.GetBaseUrl(),
LongTimeoutClient: newClient(longTimeout, destHomeServer),
Client: newClient(shortTimeout, destHomeServer),
DestinationServer: internal.GetBaseURL(destHomeServer),
}
}
func newClient(timeout time.Duration, url internal.HomeServerUrl) *http.Client {
func newClient(timeout time.Duration, destHomeServer string) *http.Client {
transport := http.DefaultTransport
if url.IsUnixSocket() {
transport = &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial("unix", url.GetUnixSocket())
},
}
if internal.IsUnixSocket(destHomeServer) {
transport = internal.UnixTransport(destHomeServer)
}
return &http.Client{
Timeout: timeout,