diff --git a/README.md b/README.md index fd0e870..3c6a899 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ import ( "os" "ariga.io/atlas-provider-gorm/gormschema" - _ "ariga.io/atlas-provider-gorm/recordriver" + _ "ariga.io/atlas-go-sdk/recordriver" "github.com///path/to/models" ) diff --git a/go.mod b/go.mod index 8cb9acc..1fb3ea6 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module ariga.io/atlas-provider-gorm go 1.20 require ( + ariga.io/atlas-go-sdk v0.0.0-20230709063453-1058d6508503 github.com/alecthomas/kong v0.7.1 github.com/stretchr/testify v1.8.4 golang.org/x/tools v0.10.0 diff --git a/go.sum b/go.sum index 0e55d8e..68cb8e7 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +ariga.io/atlas-go-sdk v0.0.0-20230709063453-1058d6508503 h1:D12EzAAjhL7xEJk5jFnkKYUXVrDQ4mh0IjB8vqTmo8I= +ariga.io/atlas-go-sdk v0.0.0-20230709063453-1058d6508503/go.mod h1:fwi5nIOFLedo6CqZ0a172dhykLWBnoD25bqmZhvW948= github.com/alecthomas/assert/v2 v2.1.0 h1:tbredtNcQnoSd3QBhQWI7QZ3XHOVkw1Moklp2ojoH/0= github.com/alecthomas/kong v0.7.1 h1:azoTh0IOfwlAX3qN9sHWTxACE2oV8Bg2gAwBsMwDQY4= github.com/alecthomas/kong v0.7.1/go.mod h1:n1iCIO2xS46oE8ZfYCNDqdR0b0wZNrXAIAqro/2132U= diff --git a/gormschema/gorm.go b/gormschema/gorm.go index 48a875f..3448bee 100644 --- a/gormschema/gorm.go +++ b/gormschema/gorm.go @@ -5,7 +5,7 @@ import ( "database/sql/driver" "fmt" - "ariga.io/atlas-provider-gorm/recordriver" + "ariga.io/atlas-go-sdk/recordriver" "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" diff --git a/loader.tmpl b/loader.tmpl index 51c01e1..ee235d5 100644 --- a/loader.tmpl +++ b/loader.tmpl @@ -9,7 +9,7 @@ import ( "{{ . }}" {{- end}} "ariga.io/atlas-provider-gorm/gormschema" - _ "ariga.io/atlas-provider-gorm/recordriver" + _ "ariga.io/atlas-go-sdk/recordriver" ) func main() { diff --git a/recordriver/driver.go b/recordriver/driver.go deleted file mode 100644 index 35d11aa..0000000 --- a/recordriver/driver.go +++ /dev/null @@ -1,179 +0,0 @@ -// package recordriver provides a driver for database/sql which records queries and statements -// and allows you to set responses for queries. It is used for testing or providing a runtime replacement -// for a real database in cases where you want to learn the queries and statements that are executed. - -package recordriver - -import ( - "database/sql" - "database/sql/driver" - "io" - "strings" - "sync" -) - -func init() { - sql.Register("recordriver", &drv{}) -} - -var ( - sessions = map[string]*session{} - mu sync.Mutex -) - -type ( - // session is a session of recordriver which records queries and statements. - session struct { - Queries []string - Statements []string - responses map[string]*Response - } - // Response is a response to a query. - Response struct { - Cols []string - Data [][]driver.Value - } - drv struct{} - conn struct { - session string - } - stmt struct { - query string - session string - } - tx struct{} - emptyResult struct{} -) - -// Stmts returns the statements as a string, separated by semicolons and newlines. -func (s *session) Stmts() string { - var sb strings.Builder - for _, stmt := range s.Statements { - sb.WriteString(stmt) - sb.WriteString(";\n") - } - return sb.String() -} - -// Session returns the session with the given name and reports whether it exists. -func Session(name string) (*session, bool) { - mu.Lock() - defer mu.Unlock() - h, ok := sessions[name] - return h, ok -} - -// SetResponse sets the response for the given session and query. -func SetResponse(s string, query string, resp *Response) { - mu.Lock() - defer mu.Unlock() - if _, ok := sessions[s]; !ok { - sessions[s] = &session{ - responses: make(map[string]*Response), - } - } - sessions[s].responses[query] = resp -} - -// Open returns a new connection to the database. -func (d *drv) Open(name string) (driver.Conn, error) { - mu.Lock() - defer mu.Unlock() - if _, ok := sessions[name]; !ok { - sessions[name] = &session{ - responses: make(map[string]*Response), - } - } - return &conn{session: name}, nil -} - -// Prepare returns a prepared statement, bound to this connection. -func (c *conn) Prepare(query string) (driver.Stmt, error) { - return &stmt{query: query, session: c.session}, nil -} - -// Close closes the connection. -func (c *conn) Close() error { - mu.Lock() - defer mu.Unlock() - delete(sessions, c.session) - return nil -} - -// Begin starts and returns a new transaction. -func (c *conn) Begin() (driver.Tx, error) { - return &tx{}, nil -} - -// Commit commits the transaction. It is a noop. -func (*tx) Commit() error { - return nil -} - -// Rollback rolls back the transaction. It is a noop. -func (*tx) Rollback() error { - return nil -} - -// Close closes the statement. -func (*stmt) Close() error { - return nil -} - -// NumInput returns the number of placeholder parameters. Reporting -1 does not know the -// number of parameters. -func (*stmt) NumInput() int { - return -1 -} - -// Exec executes a query that doesn't return rows, such as an CREATE or ALTER TABLE. -func (s *stmt) Exec(_ []driver.Value) (driver.Result, error) { - mu.Lock() - defer mu.Unlock() - sessions[s.session].Statements = append(sessions[s.session].Statements, s.query) - return emptyResult{}, nil -} - -// Query executes a query that may return rows, such as an SELECT. -func (s *stmt) Query(_ []driver.Value) (driver.Rows, error) { - mu.Lock() - defer mu.Unlock() - sess := s.session - sessions[sess].Queries = append(sessions[sess].Queries, s.query) - if resp, ok := sessions[sess].responses[s.query]; ok { - return resp, nil - } - return &Response{}, nil -} - -// Columns returns the names of the columns in the result set. -func (r *Response) Columns() []string { - return r.Cols -} - -// Close closes the rows iterator. It is a noop. -func (*Response) Close() error { - return nil -} - -// Next is called to populate the next row of data into the provided slice. -func (r *Response) Next(dest []driver.Value) error { - if len(r.Data) == 0 { - return io.EOF - } - copy(dest, r.Data[0]) - r.Data = r.Data[1:] - return nil -} - -// LastInsertId returns the integer generated by the database in response to a command. LastInsertId -// always returns a value of 0. -func (emptyResult) LastInsertId() (int64, error) { - return 0, nil -} - -// RowsAffected returns the number of rows affected by the query. RowsAffected always returns a -// value of 0. -func (emptyResult) RowsAffected() (int64, error) { - return 0, nil -} diff --git a/recordriver/driver_test.go b/recordriver/driver_test.go deleted file mode 100644 index 9a138df..0000000 --- a/recordriver/driver_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package recordriver - -import ( - "database/sql" - "database/sql/driver" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestDriver(t *testing.T) { - db, err := sql.Open("recordriver", "t1") - require.NoError(t, err) - defer db.Close() - SetResponse("t1", "select sqlite_version()", &Response{ - Cols: []string{"sqlite_version()"}, - Data: [][]driver.Value{{"3.30.1"}}, - }) - query, err := db.Query("select sqlite_version()") - require.NoError(t, err) - defer query.Close() - for query.Next() { - var version string - err = query.Scan(&version) - require.NoError(t, err) - require.Equal(t, "3.30.1", version) - } - hi, ok := Session("t1") - require.True(t, ok) - require.Len(t, hi.Queries, 1) -} - -func TestInputs(t *testing.T) { - db, err := sql.Open("recordriver", "t1") - require.NoError(t, err) - defer db.Close() - _, err = db.Query("select * from t where id = ?", 1) - require.NoError(t, err) -}