From ac293fb7333b66d657ef7b3ae021b7a62c943275 Mon Sep 17 00:00:00 2001 From: syumai Date: Wed, 3 Jan 2024 21:20:07 +0900 Subject: [PATCH] change newSocket func interface --- cloudflare/sockets/dialer.go | 17 +++++++--- cloudflare/sockets/socket.go | 60 ++++++++++++++++++++---------------- 2 files changed, 46 insertions(+), 31 deletions(-) diff --git a/cloudflare/sockets/dialer.go b/cloudflare/sockets/dialer.go index ed6dd5a..2498188 100644 --- a/cloudflare/sockets/dialer.go +++ b/cloudflare/sockets/dialer.go @@ -3,17 +3,21 @@ package sockets import ( "context" "net" - - "github.com/syumai/workers/internal/jsutil" + "time" "github.com/syumai/workers/cloudflare/internal/cfruntimecontext" + "github.com/syumai/workers/internal/jsutil" ) type SecureTransport string const ( - SecureTransportOn SecureTransport = "on" - SecureTransportOff SecureTransport = "off" + // SecureTransportOn indicates "Use TLS". + SecureTransportOn SecureTransport = "on" + // SecureTransportOff indicates "Do not use TLS". + SecureTransportOff SecureTransport = "off" + // SecureTransportStartTLS indicates "Do not use TLS initially, but allow the socket to be upgraded + // to use TLS by calling *Socket.StartTLS()". SecureTransportStartTLS SecureTransport = "starttls" ) @@ -22,6 +26,8 @@ type SocketOptions struct { AllowHalfOpen bool `json:"allowHalfOpen"` } +const defaultDeadline = 999999 * time.Hour + func Connect(ctx context.Context, addr string, opts *SocketOptions) (net.Conn, error) { connect, err := cfruntimecontext.GetRuntimeContextValue(ctx, "connect") if err != nil { @@ -37,5 +43,6 @@ func Connect(ctx context.Context, addr string, opts *SocketOptions) (net.Conn, e } } sockVal := connect.Invoke(addr, optionsObj) - return newSocket(ctx, sockVal), nil + deadline := time.Now().Add(defaultDeadline) + return newSocket(ctx, sockVal, deadline, deadline), nil } diff --git a/cloudflare/sockets/socket.go b/cloudflare/sockets/socket.go index b13ceed..4c80fd7 100644 --- a/cloudflare/sockets/socket.go +++ b/cloudflare/sockets/socket.go @@ -11,34 +11,41 @@ import ( "github.com/syumai/workers/internal/jsutil" ) -func newSocket(ctx context.Context, sockVal js.Value) *Socket { +func newSocket(ctx context.Context, sockVal js.Value, readDeadline, writeDeadline time.Time) *Socket { ctx, cancel := context.WithCancel(ctx) - reader := sockVal.Get("readable").Call("getReader") - sock := &Socket{ - socket: sockVal, - writer: sockVal.Get("writable").Call("getWriter"), - reader: reader, - rd: jsutil.ConvertStreamReaderToReader(reader), + writerVal := sockVal.Get("writable").Call("getWriter") + readerVal := sockVal.Get("readable").Call("getReader") + return &Socket{ ctx: ctx, cancel: cancel, + + reader: jsutil.ConvertStreamReaderToReader(readerVal), + writerVal: writerVal, + + readDeadline: readDeadline, + writeDeadline: writeDeadline, + + startTLS: func() js.Value { return sockVal.Call("startTls") }, + close: func() { sockVal.Call("close") }, + closeRead: func() { readerVal.Call("close") }, + closeWrite: func() { writerVal.Call("close") }, } - // SetDeadline returns no error - _ = sock.SetDeadline(time.Now().Add(999999 * time.Hour)) - return sock } type Socket struct { - socket js.Value - writer js.Value - reader js.Value + ctx context.Context + cancel context.CancelFunc - rd io.Reader + reader io.Reader + writerVal js.Value readDeadline time.Time writeDeadline time.Time - ctx context.Context - cancel context.CancelFunc + startTLS func() js.Value + close func() + closeRead func() + closeWrite func() } var _ net.Conn = (*Socket)(nil) @@ -51,7 +58,7 @@ func (t *Socket) Read(b []byte) (n int, err error) { defer cancel() done := make(chan struct{}) go func() { - n, err = t.rd.Read(b) + n, err = t.reader.Read(b) close(done) }() select { @@ -72,7 +79,8 @@ func (t *Socket) Write(b []byte) (n int, err error) { go func() { arr := jsutil.NewUint8Array(len(b)) js.CopyBytesToJS(arr, b) - _, err = jsutil.AwaitPromise(t.writer.Call("write", arr)) + _, err = jsutil.AwaitPromise(t.writerVal.Call("write", arr)) + // TODO: handle error if err == nil { n = len(b) } @@ -86,29 +94,29 @@ func (t *Socket) Write(b []byte) (n int, err error) { } } -// StartTls will call startTls on the socket -func (t *Socket) StartTls() *Socket { - sockVal := t.socket.Call("startTls") - return newSocket(t.ctx, sockVal) +// StartTLS upgrades an insecure socket to a secure one that uses TLS, returning a new *Socket. + +func (t *Socket) StartTLS() *Socket { + return newSocket(t.ctx, t.startTLS(), t.readDeadline, t.writeDeadline) } // Close closes the connection. // Any blocked Read or Write operations will be unblocked and return errors. func (t *Socket) Close() error { - t.cancel() - t.socket.Call("close") + defer t.cancel() + t.close() return nil } // CloseRead closes the read side of the connection. func (t *Socket) CloseRead() error { - t.reader.Call("close") + t.closeRead() return nil } // CloseWrite closes the write side of the connection. func (t *Socket) CloseWrite() error { - t.writer.Call("close") + t.closeWrite() return nil }