mirror of
https://github.com/ent/ent.git
synced 2026-05-22 09:31:45 +03:00
dialect/sql/sqlgraph: allow arbitrary last insert id type (#1104)
This commit is contained in:
@@ -6,6 +6,7 @@ package sql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
@@ -71,6 +72,15 @@ func ScanString(rows ColumnScanner) (string, error) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// ScanValue scans and returns a driver.Value from the rows columns.
|
||||
func ScanValue(rows ColumnScanner) (driver.Value, error) {
|
||||
var v driver.Value
|
||||
if err := ScanOne(rows, &v); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// ScanSlice scans the given ColumnScanner (basically, sql.Row or sql.Rows) into the given slice.
|
||||
func ScanSlice(rows ColumnScanner, v interface{}) error {
|
||||
columns, err := rows.Columns()
|
||||
|
||||
@@ -171,6 +171,14 @@ func TestScanInt64(t *testing.T) {
|
||||
require.EqualValues(t, 10, n)
|
||||
}
|
||||
|
||||
func TestScanValue(t *testing.T) {
|
||||
mock := sqlmock.NewRows([]string{"count"}).
|
||||
AddRow(10)
|
||||
n, err := ScanValue(toRows(mock))
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 10, n)
|
||||
}
|
||||
|
||||
func TestScanOne(t *testing.T) {
|
||||
mock := sqlmock.NewRows([]string{"name"}).
|
||||
AddRow("10").
|
||||
|
||||
@@ -1179,7 +1179,7 @@ func setTableColumns(fields []*FieldSpec, edges map[Rel][]*EdgeSpec, set func(st
|
||||
}
|
||||
|
||||
// insertLastID invokes the insert query on the transaction and returns the LastInsertID.
|
||||
func insertLastID(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) (int64, error) {
|
||||
func insertLastID(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) (driver.Value, error) {
|
||||
query, args := insert.Query()
|
||||
// PostgreSQL does not support the LastInsertId() method of sql.Result
|
||||
// on Exec, and should be extracted manually using the `RETURNING` clause.
|
||||
@@ -1189,7 +1189,7 @@ func insertLastID(ctx context.Context, tx dialect.ExecQuerier, insert *sql.Inser
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return sql.ScanInt64(rows)
|
||||
return sql.ScanValue(rows)
|
||||
}
|
||||
// MySQL, SQLite, etc.
|
||||
var res sql.Result
|
||||
|
||||
Reference in New Issue
Block a user