Compare commits
3 Commits
c8e33884c0
...
262321a1e2
| Author | SHA1 | Date | |
|---|---|---|---|
|
262321a1e2
|
|||
|
c8b8f87f6c
|
|||
|
f59593e9e8
|
1
go.mod
1
go.mod
@@ -10,4 +10,5 @@ require (
|
||||
github.com/jmoiron/sqlx v1.2.0
|
||||
github.com/spf13/viper v1.7.1
|
||||
github.com/stretchr/testify v1.5.1
|
||||
golang.org/x/crypto v0.0.0-20200709230013-948cd5f35899
|
||||
)
|
||||
|
||||
@@ -36,10 +36,14 @@ func (e *errorStore) SelectActionsByPlanID(plan *models.Plan) ([]*models.Action,
|
||||
return nil, e.error
|
||||
}
|
||||
|
||||
func (e *errorStore) SelectUserByID(id int) (*models.User, error) {
|
||||
func (e *errorStore) SelectUserByUsername(username string) (*models.User, error) {
|
||||
return nil, e.error
|
||||
}
|
||||
|
||||
func (e *errorStore) InsertUser(user *models.User) (int, error) {
|
||||
return 0, e.error
|
||||
}
|
||||
|
||||
func (e *errorStore) ConnectionLive() error {
|
||||
return e.error
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package models
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type notFoundError struct {
|
||||
@@ -27,3 +28,27 @@ func wrapNotFound(err error) error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type invalidLoginError struct {
|
||||
error
|
||||
}
|
||||
|
||||
func (e *invalidLoginError) InvalidLogin() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// IsInvalidLoginError returns true if the model deems it an invalid login error.
|
||||
func IsInvalidLoginError(err error) bool {
|
||||
type invalidLogin interface {
|
||||
InvalidLogin() bool
|
||||
}
|
||||
te, ok := err.(invalidLogin)
|
||||
return ok && te.InvalidLogin()
|
||||
}
|
||||
|
||||
func wrapInvalidLogin(err error) error {
|
||||
if err == sql.ErrNoRows || err == bcrypt.ErrMismatchedHashAndPassword {
|
||||
return &invalidLoginError{error: err}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -40,3 +40,10 @@ func TestErrorModelWrapping(t *testing.T) {
|
||||
_, err = m.Action(0)
|
||||
assert.True(models.IsNotFoundError(err))
|
||||
}
|
||||
func TestErrorModelInvalidLogin(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
m := getErrorModel(sql.ErrNoRows)
|
||||
|
||||
_, err := m.VerifyUserByUsernamePassword("duck", "duck")
|
||||
assert.True(models.IsInvalidLoginError(err))
|
||||
}
|
||||
|
||||
@@ -15,7 +15,8 @@ type Store interface {
|
||||
SelectPlanByID(id int) (*Plan, error)
|
||||
InsertPlan(plan *Plan) (int, error)
|
||||
SelectActionsByPlanID(plan *Plan) ([]*Action, error)
|
||||
SelectUserByID(id int) (*User, error)
|
||||
SelectUserByUsername(username string) (*User, error)
|
||||
InsertUser(user *User) (int, error)
|
||||
}
|
||||
|
||||
// Model represents a current model item.
|
||||
|
||||
@@ -43,8 +43,13 @@ func (ms *multiStore) SelectActionsByPlanID(plan *models.Plan) ([]*models.Action
|
||||
return ms.actions, nil
|
||||
}
|
||||
|
||||
func (ms *multiStore) SelectUserByID(id int) (*models.User, error) {
|
||||
return &models.User{UserID: int64(id), Username: "test", DisplayName: "Ted Est", Password: []byte("oh no")}, nil
|
||||
func (ms *multiStore) SelectUserByUsername(username string) (*models.User, error) {
|
||||
// password is "password"
|
||||
return &models.User{UserID: int64(1), Username: username, DisplayName: "Ted Est", Password: []byte("$2y$05$6SVV35GX4cB4PDPhRaDD/exsL.HV8QtMMr60YL6dLyqtX4l58q.cy")}, nil
|
||||
}
|
||||
|
||||
func (ms *multiStore) InsertUser(user *models.User) (int, error) {
|
||||
return int(user.UserID), nil
|
||||
}
|
||||
|
||||
func (ms *multiStore) ConnectionLive() error {
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// User represents the full DB user field, for inserts and compares.
|
||||
// No reason to return the hashed pw on the route though.
|
||||
type User struct {
|
||||
@@ -17,13 +22,21 @@ type UserNoPassword struct {
|
||||
DisplayName string `json:"display_name"`
|
||||
}
|
||||
|
||||
// User returns a single plan from the store by plan_id.
|
||||
func (m *Model) User(id int) (*UserNoPassword, error) {
|
||||
user, err := m.SelectUserByID(id)
|
||||
if user == nil {
|
||||
return nil, wrapNotFound(err)
|
||||
// VerifyUserByUsernamePassword returns a single user by the unique username, if the provided password is correct.
|
||||
func (m *Model) VerifyUserByUsernamePassword(username string, password string) (*UserNoPassword, error) {
|
||||
user, err := m.SelectUserByUsername(username)
|
||||
if err != nil {
|
||||
// throwaway to pad time
|
||||
hashPassword(username)
|
||||
return nil, wrapInvalidLogin(err)
|
||||
}
|
||||
return user.NoPassword(), wrapNotFound(err)
|
||||
|
||||
err = bcrypt.CompareHashAndPassword(user.Password, []byte(password))
|
||||
if err != nil {
|
||||
return nil, wrapInvalidLogin(err)
|
||||
}
|
||||
|
||||
return user.NoPassword(), nil
|
||||
}
|
||||
|
||||
// NoPassword strips the user of password.
|
||||
@@ -34,3 +47,38 @@ func (u *User) NoPassword() *UserNoPassword {
|
||||
DisplayName: u.DisplayName,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateUserRequest represents a desired user creation.
|
||||
type CreateUserRequest struct {
|
||||
Username string `json:"username"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// CreateUser takes in a create user request and returns the ID of the newly created user.
|
||||
func (m *Model) CreateUser(req *CreateUserRequest) (int, error) {
|
||||
if req.Username == "" {
|
||||
return -1, fmt.Errorf("No username provided")
|
||||
}
|
||||
if req.Password == "" {
|
||||
return -1, fmt.Errorf("No password provided")
|
||||
}
|
||||
hash, err := hashPassword(req.Password)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
desiredUser := &User{
|
||||
Username: req.Username,
|
||||
DisplayName: req.DisplayName,
|
||||
Password: hash,
|
||||
}
|
||||
|
||||
return m.InsertUser(desiredUser)
|
||||
}
|
||||
|
||||
// hashPassword hashes a password
|
||||
func hashPassword(password string) ([]byte, error) {
|
||||
bytes, err := bcrypt.GenerateFromPassword([]byte(password), 11)
|
||||
return bytes, err
|
||||
}
|
||||
|
||||
@@ -17,30 +17,67 @@ func TestModelUsers(t *testing.T) {
|
||||
[]*models.Plan{p}}
|
||||
m := models.New(ss)
|
||||
|
||||
user, err := m.User(3)
|
||||
user, err := m.VerifyUserByUsernamePassword("test", "password")
|
||||
assert.Nil(err)
|
||||
assert.NotNil(user)
|
||||
|
||||
user, err = m.VerifyUserByUsernamePassword("test", "wrong_password")
|
||||
assert.NotNil(err)
|
||||
assert.Nil(user)
|
||||
}
|
||||
|
||||
func TestErrorUsers(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
m := getErrorModel(fmt.Errorf("err"))
|
||||
|
||||
user, err := m.User(3)
|
||||
user, err := m.VerifyUserByUsernamePassword("snth", "aoeu")
|
||||
assert.Nil(user)
|
||||
assert.NotNil(err)
|
||||
}
|
||||
|
||||
func TestUserNoPassword(t *testing.T) {
|
||||
func TestCreateUser(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
id := int64(3)
|
||||
username := "test"
|
||||
displayName := "Ted Est"
|
||||
pass := []byte("abc")
|
||||
u := &models.User{UserID: id, Username: username, DisplayName: displayName, Password: pass}
|
||||
pass := "abc"
|
||||
u := &models.CreateUserRequest{Username: username, DisplayName: displayName, Password: pass}
|
||||
|
||||
unp := u.NoPassword()
|
||||
assert.EqualValues(id, unp.UserID)
|
||||
assert.Equal(username, unp.Username)
|
||||
assert.Equal(displayName, unp.DisplayName)
|
||||
ss := &multiStore{
|
||||
[]*models.Action{},
|
||||
[]*models.Plan{}}
|
||||
m := models.New(ss)
|
||||
|
||||
_, err := m.CreateUser(u)
|
||||
assert.Nil(err)
|
||||
}
|
||||
func TestCreateUserFailValidation(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
username := ""
|
||||
displayName := "Ted Est"
|
||||
pass := "abc"
|
||||
u := &models.CreateUserRequest{Username: username, DisplayName: displayName, Password: pass}
|
||||
|
||||
ss := &multiStore{
|
||||
[]*models.Action{},
|
||||
[]*models.Plan{}}
|
||||
m := models.New(ss)
|
||||
|
||||
_, err := m.CreateUser(u)
|
||||
assert.NotNil(err)
|
||||
}
|
||||
|
||||
func TestCreateUserFailValidationPassword(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
username := "aoeu"
|
||||
displayName := "Ted Est"
|
||||
pass := ""
|
||||
u := &models.CreateUserRequest{Username: username, DisplayName: displayName, Password: pass}
|
||||
|
||||
ss := &multiStore{
|
||||
[]*models.Action{},
|
||||
[]*models.Plan{}}
|
||||
m := models.New(ss)
|
||||
|
||||
_, err := m.CreateUser(u)
|
||||
assert.NotNil(err)
|
||||
}
|
||||
|
||||
96
routes/auth.go
Normal file
96
routes/auth.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"gitea.deepak.science/deepak/gogmagog/models"
|
||||
"github.com/go-chi/chi"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func newAuthRouter(m *models.Model) http.Handler {
|
||||
router := chi.NewRouter()
|
||||
router.Post("/register", postUserFunc(m))
|
||||
router.Post("/tokens", createTokenFunc(m))
|
||||
return router
|
||||
}
|
||||
|
||||
type createUserResponse struct {
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
func postUserFunc(m *models.Model) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
r.Body = http.MaxBytesReader(w, r.Body, 1024)
|
||||
dec := json.NewDecoder(r.Body)
|
||||
dec.DisallowUnknownFields()
|
||||
var req models.CreateUserRequest
|
||||
err := dec.Decode(&req)
|
||||
if err != nil {
|
||||
badRequestError(w, err)
|
||||
return
|
||||
}
|
||||
err = dec.Decode(&struct{}{})
|
||||
if err != io.EOF {
|
||||
badRequestError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = m.CreateUser(&req)
|
||||
if err != nil {
|
||||
serverError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
response := &createUserResponse{
|
||||
Username: req.Username,
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
serverError(w, err)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
type loginCreds struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
func createTokenFunc(m *models.Model) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
r.Body = http.MaxBytesReader(w, r.Body, 1024)
|
||||
dec := json.NewDecoder(r.Body)
|
||||
dec.DisallowUnknownFields()
|
||||
var creds loginCreds
|
||||
err := dec.Decode(&creds)
|
||||
if err != nil {
|
||||
badRequestError(w, err)
|
||||
return
|
||||
}
|
||||
err = dec.Decode(&struct{}{})
|
||||
if err != io.EOF {
|
||||
badRequestError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := m.VerifyUserByUsernamePassword(creds.Username, creds.Password)
|
||||
if err != nil {
|
||||
if models.IsInvalidLoginError(err) {
|
||||
unauthorizedHandler(w, r)
|
||||
return
|
||||
}
|
||||
serverError(w, err)
|
||||
return
|
||||
|
||||
}
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(user); err != nil {
|
||||
serverError(w, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -25,3 +25,8 @@ func notFoundHandler(w http.ResponseWriter, r *http.Request) {
|
||||
code := http.StatusNotFound
|
||||
http.Error(w, http.StatusText(code), code)
|
||||
}
|
||||
|
||||
func unauthorizedHandler(w http.ResponseWriter, r *http.Request) {
|
||||
code := http.StatusUnauthorized
|
||||
http.Error(w, http.StatusText(code), code)
|
||||
}
|
||||
|
||||
@@ -58,10 +58,14 @@ func (ms *multiStore) SelectActionsByPlanID(plan *models.Plan) ([]*models.Action
|
||||
return ms.actions, nil
|
||||
}
|
||||
|
||||
func (ms *multiStore) SelectUserByID(id int) (*models.User, error) {
|
||||
func (ms *multiStore) SelectUserByUsername(name string) (*models.User, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (ms *multiStore) InsertUser(user *models.User) (int, error) {
|
||||
return int(user.UserID), nil
|
||||
}
|
||||
|
||||
func (ms *multiStore) ConnectionLive() error {
|
||||
return nil
|
||||
}
|
||||
@@ -115,10 +119,14 @@ func (e *errorStore) SelectActionsByPlanID(plan *models.Plan) ([]*models.Action,
|
||||
return nil, e.error
|
||||
}
|
||||
|
||||
func (e *errorStore) SelectUserByID(id int) (*models.User, error) {
|
||||
func (e *errorStore) SelectUserByUsername(name string) (*models.User, error) {
|
||||
return nil, e.error
|
||||
}
|
||||
|
||||
func (e *errorStore) InsertUser(user *models.User) (int, error) {
|
||||
return 0, e.error
|
||||
}
|
||||
|
||||
func (e *errorStore) ConnectionLive() error {
|
||||
return e.error
|
||||
}
|
||||
@@ -164,10 +172,14 @@ func (e *onlyCreateStore) SelectActionsByPlanID(plan *models.Plan) ([]*models.Ac
|
||||
return nil, e.error
|
||||
}
|
||||
|
||||
func (e *onlyCreateStore) SelectUserByID(id int) (*models.User, error) {
|
||||
func (e *onlyCreateStore) SelectUserByUsername(name string) (*models.User, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (e *onlyCreateStore) InsertUser(user *models.User) (int, error) {
|
||||
return 0, e.error
|
||||
}
|
||||
|
||||
func (e *onlyCreateStore) ConnectionLive() error {
|
||||
return e.error
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ func NewRouter(m *models.Model) http.Handler {
|
||||
router.NotFound(notFoundHandler)
|
||||
router.Mount("/plans", newPlanRouter(m))
|
||||
router.Mount("/actions", newActionRouter(m))
|
||||
router.Mount("/auth", newAuthRouter(m))
|
||||
router.Mount("/health", newHealthRouter(m))
|
||||
router.Get("/ping", ping)
|
||||
return router
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"gitea.deepak.science/deepak/gogmagog/models"
|
||||
"github.com/go-chi/chi"
|
||||
"net/http"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func newUserRouter(m *models.Model) http.Handler {
|
||||
router := chi.NewRouter()
|
||||
// router.Post("/", postUserFunc(m))
|
||||
router.Get("/{userid}", getUserByIDFunc(m))
|
||||
return router
|
||||
}
|
||||
|
||||
func getUserByIDFunc(m *models.Model) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id, err := strconv.Atoi(chi.URLParam(r, "userid"))
|
||||
if err != nil {
|
||||
notFoundHandler(w, r)
|
||||
return
|
||||
}
|
||||
user, err := m.User(id)
|
||||
if err != nil {
|
||||
if models.IsNotFoundError(err) {
|
||||
notFoundHandler(w, r)
|
||||
return
|
||||
}
|
||||
serverError(w, err)
|
||||
return
|
||||
|
||||
}
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(user); err != nil {
|
||||
serverError(w, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -18,7 +18,7 @@ CREATE TABLE IF NOT EXISTS actions(
|
||||
|
||||
CREATE TABLE IF NOT EXISTS users(
|
||||
user_id serial PRIMARY KEY,
|
||||
username VARCHAR(50) NOT NULL,
|
||||
username VARCHAR(50) NOT NULL UNIQUE,
|
||||
display_name VARCHAR (100) NOT NULL,
|
||||
password bytea,
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
|
||||
@@ -135,11 +135,27 @@ func (store *postgresStore) ConnectionLive() error {
|
||||
return store.db.Ping()
|
||||
}
|
||||
|
||||
func (store *postgresStore) SelectUserByID(id int) (*models.User, error) {
|
||||
func (store *postgresStore) SelectUserByUsername(username string) (*models.User, error) {
|
||||
user := models.User{}
|
||||
err := store.db.Get(&user, store.db.Rebind("SELECT user_id, username, display_name, password FROM users WHERE user_id = ?"), id)
|
||||
err := store.db.Get(&user, store.db.Rebind("SELECT user_id, username, display_name, password FROM users WHERE username = ?"), username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (store *postgresStore) InsertUser(user *models.User) (int, error) {
|
||||
queryString := store.db.Rebind("INSERT INTO users (username, display_name, password) VALUES (?, ?, ?) RETURNING user_id")
|
||||
tx := store.db.MustBegin()
|
||||
var id int
|
||||
err := tx.Get(&id, queryString, user.Username, user.DisplayName, user.Password)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return -1, err
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
@@ -2,12 +2,13 @@ package store_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gitea.deepak.science/deepak/gogmagog/models"
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSelectUserById(t *testing.T) {
|
||||
func TestSelectUserByUsername(t *testing.T) {
|
||||
// set up test
|
||||
assert := assert.New(t)
|
||||
|
||||
@@ -26,12 +27,12 @@ func TestSelectUserById(t *testing.T) {
|
||||
}).
|
||||
AddRow(id, username, displayName, password)
|
||||
|
||||
mock.ExpectQuery(`^SELECT user_id, username, display_name, password FROM users WHERE user_id = \$1`).
|
||||
WithArgs(id).
|
||||
mock.ExpectQuery(`^SELECT user_id, username, display_name, password FROM users WHERE username = \$1`).
|
||||
WithArgs(username).
|
||||
WillReturnRows(rows)
|
||||
|
||||
// function under test
|
||||
user, err := str.SelectUserByID(1)
|
||||
user, err := str.SelectUserByUsername(username)
|
||||
|
||||
// test results
|
||||
assert.Nil(err)
|
||||
@@ -48,13 +49,11 @@ func TestSelectUserById(t *testing.T) {
|
||||
func TestErrUserByID(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
idToUse := 1
|
||||
|
||||
str, mock := getDbMock(t)
|
||||
username := "snth"
|
||||
mock.ExpectQuery(`^SELECT user_id, username, display_name, password FROM users WHERE username = \$1`).WithArgs(username).WillReturnError(fmt.Errorf("example error"))
|
||||
|
||||
mock.ExpectQuery(`^SELECT user_id, username, display_name, password FROM users WHERE user_id = \$1`).WithArgs(idToUse).WillReturnError(fmt.Errorf("example error"))
|
||||
|
||||
user, err := str.SelectUserByID(idToUse)
|
||||
user, err := str.SelectUserByUsername(username)
|
||||
assert.NotNil(err)
|
||||
assert.Nil(user)
|
||||
|
||||
@@ -62,3 +61,90 @@ func TestErrUserByID(t *testing.T) {
|
||||
t.Errorf("unfulfilled expectations: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInsertUser(t *testing.T) {
|
||||
// setup
|
||||
assert := assert.New(t)
|
||||
|
||||
str, mock := getDbMock(t)
|
||||
username := "test"
|
||||
displayName := "Tom Est"
|
||||
password := []byte("ABC€")
|
||||
usr := &models.User{Username: username, DisplayName: displayName, Password: password}
|
||||
|
||||
idToUse := 8
|
||||
|
||||
rows := sqlmock.NewRows([]string{"user_id"}).AddRow(8)
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(`^INSERT INTO users \(username, display_name, password\) VALUES \(\$1, \$2, \$3\) RETURNING user_id$`).
|
||||
WithArgs(username, displayName, password).
|
||||
WillReturnRows(rows)
|
||||
mock.ExpectCommit()
|
||||
|
||||
// function under test
|
||||
insertedId, err := str.InsertUser(usr)
|
||||
// check results
|
||||
assert.Nil(err)
|
||||
assert.EqualValues(idToUse, insertedId)
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unfulfilled expectations: %s", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestInsertUserErr(t *testing.T) {
|
||||
// setup
|
||||
assert := assert.New(t)
|
||||
|
||||
str, mock := getDbMock(t)
|
||||
username := "test"
|
||||
displayName := "Tom Est"
|
||||
password := []byte("ABC€")
|
||||
usr := &models.User{Username: username, DisplayName: displayName, Password: password}
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(`^INSERT INTO users \(username, display_name, password\) VALUES \(\$1, \$2, \$3\) RETURNING user_id$`).
|
||||
WithArgs(username, displayName, password).
|
||||
WillReturnError(fmt.Errorf("example error"))
|
||||
mock.ExpectRollback()
|
||||
|
||||
// function under test
|
||||
_, err := str.InsertUser(usr)
|
||||
// check results
|
||||
assert.NotNil(err)
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unfulfilled expectations: %s", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestInsertUserCommitErr(t *testing.T) {
|
||||
// setup
|
||||
assert := assert.New(t)
|
||||
|
||||
str, mock := getDbMock(t)
|
||||
username := "test"
|
||||
displayName := "Tom Est"
|
||||
password := []byte("ABC€")
|
||||
usr := &models.User{Username: username, DisplayName: displayName, Password: password}
|
||||
|
||||
idToUse := 8
|
||||
|
||||
rows := sqlmock.NewRows([]string{"user_id"}).AddRow(idToUse)
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(`^INSERT INTO users \(username, display_name, password\) VALUES \(\$1, \$2, \$3\) RETURNING user_id$`).
|
||||
WithArgs(username, displayName, password).
|
||||
WillReturnRows(rows)
|
||||
mock.ExpectCommit().WillReturnError(fmt.Errorf("another error example"))
|
||||
|
||||
// function under test
|
||||
_, err := str.InsertUser(usr)
|
||||
// check results
|
||||
assert.NotNil(err)
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unfulfilled expectations: %s", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user