change newSocket func interface

This commit is contained in:
syumai 2024-01-03 21:20:07 +09:00
parent fe098426b2
commit ac293fb733
2 changed files with 46 additions and 31 deletions

View File

@ -3,17 +3,21 @@ package sockets
import ( import (
"context" "context"
"net" "net"
"time"
"github.com/syumai/workers/internal/jsutil"
"github.com/syumai/workers/cloudflare/internal/cfruntimecontext" "github.com/syumai/workers/cloudflare/internal/cfruntimecontext"
"github.com/syumai/workers/internal/jsutil"
) )
type SecureTransport string type SecureTransport string
const ( const (
SecureTransportOn SecureTransport = "on" // SecureTransportOn indicates "Use TLS".
SecureTransportOff SecureTransport = "off" SecureTransportOn SecureTransport = "on"
// SecureTransportOff indicates "Do not use TLS".
SecureTransportOff SecureTransport = "off"
// SecureTransportStartTLS indicates "Do not use TLS initially, but allow the socket to be upgraded
// to use TLS by calling *Socket.StartTLS()".
SecureTransportStartTLS SecureTransport = "starttls" SecureTransportStartTLS SecureTransport = "starttls"
) )
@ -22,6 +26,8 @@ type SocketOptions struct {
AllowHalfOpen bool `json:"allowHalfOpen"` AllowHalfOpen bool `json:"allowHalfOpen"`
} }
const defaultDeadline = 999999 * time.Hour
func Connect(ctx context.Context, addr string, opts *SocketOptions) (net.Conn, error) { func Connect(ctx context.Context, addr string, opts *SocketOptions) (net.Conn, error) {
connect, err := cfruntimecontext.GetRuntimeContextValue(ctx, "connect") connect, err := cfruntimecontext.GetRuntimeContextValue(ctx, "connect")
if err != nil { if err != nil {
@ -37,5 +43,6 @@ func Connect(ctx context.Context, addr string, opts *SocketOptions) (net.Conn, e
} }
} }
sockVal := connect.Invoke(addr, optionsObj) sockVal := connect.Invoke(addr, optionsObj)
return newSocket(ctx, sockVal), nil deadline := time.Now().Add(defaultDeadline)
return newSocket(ctx, sockVal, deadline, deadline), nil
} }

View File

@ -11,34 +11,41 @@ import (
"github.com/syumai/workers/internal/jsutil" "github.com/syumai/workers/internal/jsutil"
) )
func newSocket(ctx context.Context, sockVal js.Value) *Socket { func newSocket(ctx context.Context, sockVal js.Value, readDeadline, writeDeadline time.Time) *Socket {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
reader := sockVal.Get("readable").Call("getReader") writerVal := sockVal.Get("writable").Call("getWriter")
sock := &Socket{ readerVal := sockVal.Get("readable").Call("getReader")
socket: sockVal, return &Socket{
writer: sockVal.Get("writable").Call("getWriter"),
reader: reader,
rd: jsutil.ConvertStreamReaderToReader(reader),
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
reader: jsutil.ConvertStreamReaderToReader(readerVal),
writerVal: writerVal,
readDeadline: readDeadline,
writeDeadline: writeDeadline,
startTLS: func() js.Value { return sockVal.Call("startTls") },
close: func() { sockVal.Call("close") },
closeRead: func() { readerVal.Call("close") },
closeWrite: func() { writerVal.Call("close") },
} }
// SetDeadline returns no error
_ = sock.SetDeadline(time.Now().Add(999999 * time.Hour))
return sock
} }
type Socket struct { type Socket struct {
socket js.Value ctx context.Context
writer js.Value cancel context.CancelFunc
reader js.Value
rd io.Reader reader io.Reader
writerVal js.Value
readDeadline time.Time readDeadline time.Time
writeDeadline time.Time writeDeadline time.Time
ctx context.Context startTLS func() js.Value
cancel context.CancelFunc close func()
closeRead func()
closeWrite func()
} }
var _ net.Conn = (*Socket)(nil) var _ net.Conn = (*Socket)(nil)
@ -51,7 +58,7 @@ func (t *Socket) Read(b []byte) (n int, err error) {
defer cancel() defer cancel()
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
n, err = t.rd.Read(b) n, err = t.reader.Read(b)
close(done) close(done)
}() }()
select { select {
@ -72,7 +79,8 @@ func (t *Socket) Write(b []byte) (n int, err error) {
go func() { go func() {
arr := jsutil.NewUint8Array(len(b)) arr := jsutil.NewUint8Array(len(b))
js.CopyBytesToJS(arr, b) js.CopyBytesToJS(arr, b)
_, err = jsutil.AwaitPromise(t.writer.Call("write", arr)) _, err = jsutil.AwaitPromise(t.writerVal.Call("write", arr))
// TODO: handle error
if err == nil { if err == nil {
n = len(b) n = len(b)
} }
@ -86,29 +94,29 @@ func (t *Socket) Write(b []byte) (n int, err error) {
} }
} }
// StartTls will call startTls on the socket // StartTLS upgrades an insecure socket to a secure one that uses TLS, returning a new *Socket.
func (t *Socket) StartTls() *Socket {
sockVal := t.socket.Call("startTls") func (t *Socket) StartTLS() *Socket {
return newSocket(t.ctx, sockVal) return newSocket(t.ctx, t.startTLS(), t.readDeadline, t.writeDeadline)
} }
// Close closes the connection. // Close closes the connection.
// Any blocked Read or Write operations will be unblocked and return errors. // Any blocked Read or Write operations will be unblocked and return errors.
func (t *Socket) Close() error { func (t *Socket) Close() error {
t.cancel() defer t.cancel()
t.socket.Call("close") t.close()
return nil return nil
} }
// CloseRead closes the read side of the connection. // CloseRead closes the read side of the connection.
func (t *Socket) CloseRead() error { func (t *Socket) CloseRead() error {
t.reader.Call("close") t.closeRead()
return nil return nil
} }
// CloseWrite closes the write side of the connection. // CloseWrite closes the write side of the connection.
func (t *Socket) CloseWrite() error { func (t *Socket) CloseWrite() error {
t.writer.Call("close") t.closeWrite()
return nil return nil
} }