diff --git a/cloudflare/d1/conn.go b/cloudflare/d1/conn.go index ce32feb..64875f0 100644 --- a/cloudflare/d1/conn.go +++ b/cloudflare/d1/conn.go @@ -15,17 +15,17 @@ var ( _ driver.Conn = (*Conn)(nil) _ driver.ConnBeginTx = (*Conn)(nil) _ driver.ConnPrepareContext = (*Conn)(nil) - _ driver.QueryerContext = (*Conn)(nil) ) func (c *Conn) Prepare(query string) (driver.Stmt, error) { - //TODO implement me - panic("implement me") + stmtObj := c.dbObj.Call("prepare", query) + return &stmt{ + stmtObj: stmtObj, + }, nil } func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { - //TODO implement me - panic("implement me") + return c.Prepare(query) } func (c *Conn) Close() error { @@ -34,14 +34,9 @@ func (c *Conn) Close() error { } func (c *Conn) Begin() (driver.Tx, error) { - return nil, errors.New("d1: transaction is not currently supported") + return nil, errors.New("d1: Begin is deprecated and not implemented") } func (c *Conn) BeginTx(context.Context, driver.TxOptions) (driver.Tx, error) { return nil, errors.New("d1: transaction is not currently supported") } - -func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - //TODO implement me - panic("implement me") -} diff --git a/cloudflare/d1/result.go b/cloudflare/d1/result.go index edbf0ff..e3d37c0 100644 --- a/cloudflare/d1/result.go +++ b/cloudflare/d1/result.go @@ -1,15 +1,48 @@ package d1 -import "database/sql" +import ( + "database/sql" + "errors" + "syscall/js" +) -type result struct{} +type result struct { + resultObj js.Value +} -var _ sql.Result = (*result)(nil) +// Result is the interface which represents Cloudflare's D1Result type. +// For `changes` field, RowsAffected method can be used. +// see: https://github.com/cloudflare/workers-types/blob/v3.18.0/src/workers.json#L1608 +type Result interface { + // LastRowId returns id of result's last row. + // If LastRowId can't be retrieved, this method returns nil. + LastRowId() *int + // Duration returns duration of executed query. + Duration() int +} + +var ( + _ sql.Result = (*result)(nil) + _ Result = (*result)(nil) +) func (r *result) LastInsertId() (int64, error) { - panic("implement me") + return 0, errors.New("d1: LastInsertId is not implemented. instead of it, please use d1.Result.LastRowId") } func (r *result) RowsAffected() (int64, error) { - panic("implement me") + return int64(r.resultObj.Get("changes").Int()), nil +} + +func (r *result) LastRowId() *int { + v := r.resultObj.Get("lastRowId") + if v.IsNull() { + return nil + } + id := v.Int() + return &id +} + +func (r *result) Duration() int { + return r.resultObj.Get("duration").Int() } diff --git a/cloudflare/d1/rows.go b/cloudflare/d1/rows.go index f2df05b..eead042 100644 --- a/cloudflare/d1/rows.go +++ b/cloudflare/d1/rows.go @@ -2,20 +2,95 @@ package d1 import ( "database/sql/driver" + "errors" + "github.com/syumai/workers/internal/jsutil" + "io" + "sync" + "syscall/js" ) -type rows struct{} +type rows struct { + rowsObj js.Value + currentRow int + // columns is cached value of Columns method. + // do not use this directly. + _columns []string + onceColumns sync.Once + // _rowsLen is cached value of rowsLen method. + // do not use this directly. + _rowsLen int + onceRowsLen sync.Once + mu sync.Mutex +} var _ driver.Rows = (*rows)(nil) +// Columns returns column names retrieved from query result object's keys. +// If rows are empty, this returns nil. func (r *rows) Columns() []string { - panic("implement me") + r.onceColumns.Do(func() { + if r.rowsObj.Length() == 0 { + // return nothing when row count is zero. + return + } + colsArray := jsutil.ObjectClass.Call("keys", r.rowsObj.Index(0)) + colsLen := colsArray.Length() + cols := make([]string, colsLen) + for i := 0; i < colsLen; i++ { + cols[i] = colsArray.Index(i).String() + } + r._columns = cols + }) + return r._columns } func (r *rows) Close() error { - panic("implement me") + // do nothing + return nil +} + +// convertRowColumnValueToDriverValue converts row column's value in JS to Go's driver.Value. +// row column value is `null | Number | String | ArrayBuffer`. +// see: https://developers.cloudflare.com/d1/platform/client-api/#type-conversion +func convertRowColumnValueToAny(v js.Value) (driver.Value, error) { + switch v.Type() { + case js.TypeNull: + return nil, nil + case js.TypeNumber: + // TODO: handle INTEGER type. + return v.Float(), nil + case js.TypeString: + return v.String(), nil + case js.TypeObject: + // TODO: handle BLOB type (ArrayBuffer). + // see: https://developers.cloudflare.com/d1/platform/client-api/#type-conversion + return nil, errors.New("d1: row column value type object is not currently supported") + } + return nil, errors.New("d1: unexpected row column value type") } func (r *rows) Next(dest []driver.Value) error { - panic("implement me") + r.mu.Lock() + defer r.mu.Unlock() + if r.currentRow == r.rowsLen() { + return io.EOF + } + rowObj := r.rowsObj.Index(r.currentRow) + cols := r.Columns() + for i, col := range cols { + v, err := convertRowColumnValueToAny(rowObj.Get(col)) + if err != nil { + return err + } + dest[i] = v + } + r.currentRow++ + return nil +} + +func (r *rows) rowsLen() int { + r.onceRowsLen.Do(func() { + r._rowsLen = r.rowsObj.Length() + }) + return r._rowsLen } diff --git a/cloudflare/d1/stmt.go b/cloudflare/d1/stmt.go index cf6905f..f7b7061 100644 --- a/cloudflare/d1/stmt.go +++ b/cloudflare/d1/stmt.go @@ -4,9 +4,13 @@ import ( "context" "database/sql/driver" "errors" + "github.com/syumai/workers/internal/jsutil" + "syscall/js" ) -type stmt struct{} +type stmt struct { + stmtObj js.Value +} var ( _ driver.Stmt = (*stmt)(nil) @@ -15,25 +19,51 @@ var ( ) func (s *stmt) Close() error { - panic("implement me") + // do nothing + return nil } +// NumInput is not supported and always returns -1. func (s *stmt) NumInput() int { - panic("implement me") + return -1 } -func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { +func (s *stmt) Exec([]driver.Value) (driver.Result, error) { return nil, errors.New("d1: Exec is deprecated and not implemented") } -func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { - panic("implement me") +// ExecContext executes prepared statement. +// Given []drier.NamedValue's `Name` field will be ignored because Cloudflare D1 client doesn't support it. +func (s *stmt) ExecContext(_ context.Context, args []driver.NamedValue) (driver.Result, error) { + argValues := make([]any, len(args)) + for i, arg := range args { + argValues[i] = arg.Value + } + resultPromise := s.stmtObj.Call("bind", argValues...).Call("run") + resultObj, err := jsutil.AwaitPromise(resultPromise) + if err != nil { + return nil, err + } + return &result{ + resultObj: resultObj, + }, nil } -func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { +func (s *stmt) Query([]driver.Value) (driver.Rows, error) { return nil, errors.New("d1: Query is deprecated and not implemented") } -func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { - panic("implement me") +func (s *stmt) QueryContext(_ context.Context, args []driver.NamedValue) (driver.Rows, error) { + argValues := make([]any, len(args)) + for i, arg := range args { + argValues[i] = arg + } + resultPromise := s.stmtObj.Call("bind", argValues...).Call("all") + rowsObj, err := jsutil.AwaitPromise(resultPromise) + if err != nil { + return nil, err + } + return &rows{ + rowsObj: rowsObj, + }, nil }