dialect/sql/sqlgraph: allow arbitrary last insert id type (#1104)

This commit is contained in:
Ciaran Liedeman
2020-12-31 12:54:44 +02:00
committed by GitHub
parent 402a1a1a0e
commit 4a1ac1eef1
3 changed files with 20 additions and 2 deletions

View File

@@ -6,6 +6,7 @@ package sql
import (
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"strings"
@@ -71,6 +72,15 @@ func ScanString(rows ColumnScanner) (string, error) {
return s, nil
}
// ScanValue scans and returns a driver.Value from the rows columns.
func ScanValue(rows ColumnScanner) (driver.Value, error) {
var v driver.Value
if err := ScanOne(rows, &v); err != nil {
return "", err
}
return v, nil
}
// ScanSlice scans the given ColumnScanner (basically, sql.Row or sql.Rows) into the given slice.
func ScanSlice(rows ColumnScanner, v interface{}) error {
columns, err := rows.Columns()

View File

@@ -171,6 +171,14 @@ func TestScanInt64(t *testing.T) {
require.EqualValues(t, 10, n)
}
func TestScanValue(t *testing.T) {
mock := sqlmock.NewRows([]string{"count"}).
AddRow(10)
n, err := ScanValue(toRows(mock))
require.NoError(t, err)
require.EqualValues(t, 10, n)
}
func TestScanOne(t *testing.T) {
mock := sqlmock.NewRows([]string{"name"}).
AddRow("10").

View File

@@ -1179,7 +1179,7 @@ func setTableColumns(fields []*FieldSpec, edges map[Rel][]*EdgeSpec, set func(st
}
// insertLastID invokes the insert query on the transaction and returns the LastInsertID.
func insertLastID(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) (int64, error) {
func insertLastID(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) (driver.Value, error) {
query, args := insert.Query()
// PostgreSQL does not support the LastInsertId() method of sql.Result
// on Exec, and should be extracted manually using the `RETURNING` clause.
@@ -1189,7 +1189,7 @@ func insertLastID(ctx context.Context, tx dialect.ExecQuerier, insert *sql.Inser
return 0, err
}
defer rows.Close()
return sql.ScanInt64(rows)
return sql.ScanValue(rows)
}
// MySQL, SQLite, etc.
var res sql.Result