diff --git a/config/config-sample.yaml b/config/config-sample.yaml index 07a78de..e481575 100644 --- a/config/config-sample.yaml +++ b/config/config-sample.yaml @@ -9,3 +9,4 @@ db: user: USER password: PASSWORD database: g2 + droponstart: true # don't use this in production! diff --git a/config/config.go b/config/config.go index 599550e..25a4ff8 100644 --- a/config/config.go +++ b/config/config.go @@ -15,12 +15,13 @@ type AppConfig struct { // DBConfig is the config for the DB connection. type DBConfig struct { - Type string - Host string - Port string - User string - Password string - Database string + Type string + Host string + Port string + User string + Password string + Database string + DropOnStart bool } // Conf represents the overall configuration of the application. @@ -37,12 +38,13 @@ func createDefaultConf() *Conf { Timezone: "America/New_York", }, Db: DBConfig{ - Type: "postgres", - Host: "localhost", - Port: "5432", - User: "", - Password: "", - Database: "gogmagog", + Type: "postgres", + Host: "localhost", + Port: "5432", + User: "", + Password: "", + Database: "gogmagog", + DropOnStart: false, }, } diff --git a/config/config_test.go b/config/config_test.go index 13c1e83..8697ea0 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -25,6 +25,7 @@ func TestSample(t *testing.T) { assert.Equal("USER", dbConf.User) assert.Equal("PASSWORD", dbConf.Password) assert.Equal("g2", dbConf.Database) + assert.True(dbConf.DropOnStart) } func TestDefault(t *testing.T) { @@ -45,6 +46,7 @@ func TestDefault(t *testing.T) { assert.Equal("", dbConf.User) assert.Equal("", dbConf.Password) assert.Equal("gogmagog", dbConf.Database) + assert.False(dbConf.DropOnStart) } func TestMissingFile(t *testing.T) { diff --git a/go.mod b/go.mod index 52b1f71..cd079e2 100644 --- a/go.mod +++ b/go.mod @@ -5,9 +5,12 @@ go 1.15 require ( github.com/DATA-DOG/go-sqlmock v1.5.0 github.com/go-chi/chi v4.1.2+incompatible + github.com/go-chi/jwtauth v1.1.1 github.com/golang-migrate/migrate/v4 v4.14.1 github.com/jackc/pgx/v4 v4.10.1 github.com/jmoiron/sqlx v1.2.0 + github.com/lestrrat-go/jwx v1.0.6-0.20201127121120-26218808f029 github.com/spf13/viper v1.7.1 github.com/stretchr/testify v1.5.1 + golang.org/x/crypto v0.0.0-20200709230013-948cd5f35899 ) diff --git a/go.sum b/go.sum index 0fb2c70..2c2db65 100644 --- a/go.sum +++ b/go.sum @@ -108,8 +108,11 @@ github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsouza/fake-gcs-server v1.17.0/go.mod h1:D1rTE4YCyHFNa99oyJJ5HyclvN/0uQR+pM/VdlL83bw= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/go-chi/chi v1.5.1/go.mod h1:REp24E+25iKvxgeTfHmdUoL5x15kBiDBlnIl5bCwe2k= github.com/go-chi/chi v4.1.2+incompatible h1:fGFk2Gmi/YKXk0OmGfBh0WgmN3XB8lVnEyNz34tQRec= github.com/go-chi/chi v4.1.2+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ= +github.com/go-chi/jwtauth v1.1.1 h1:CtUHwzvXUfZeZSbASLgzaTZQ8mL7p+vitX59NBTL1vY= +github.com/go-chi/jwtauth v1.1.1/go.mod h1:znOWz9e5/GfBOKiZlOUoEfjSjUF+cLZO3GcpkoGXvFI= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= @@ -307,6 +310,11 @@ github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/ktrysmt/go-bitbucket v0.6.4/go.mod h1:9u0v3hsd2rqCHRIpbir1oP7F58uo5dq19sBYvuMoyQ4= +github.com/lestrrat-go/iter v0.0.0-20200422075355-fc1769541911 h1:FvnrqecqX4zT0wOIbYK1gNgTm0677INEWiFY8UEYggY= +github.com/lestrrat-go/iter v0.0.0-20200422075355-fc1769541911/go.mod h1:zIdgO1mRKhn8l9vrZJZz9TUMMFbQbLeTsbqPDrJ/OJc= +github.com/lestrrat-go/jwx v1.0.6-0.20201127121120-26218808f029 h1:+HTAqhgKkKqizghOYb4uEpZ7wK8tl3Y48ZbUTHF521c= +github.com/lestrrat-go/jwx v1.0.6-0.20201127121120-26218808f029/go.mod h1:TPF17WiSFegZo+c20fdpw49QD+/7n4/IsGvEmCSWwT0= +github.com/lestrrat-go/pdebug v0.0.0-20200204225717-4d6bd78da58d/go.mod h1:B06CSso/AWxiPejj+fheUINGeBKeeEZNt8w+EoU7+L8= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= @@ -651,6 +659,7 @@ golang.org/x/tools v0.0.0-20200227222343-706bc42d1f0d/go.mod h1:TB2adYChydJhpapK golang.org/x/tools v0.0.0-20200304193943-95d2e580d8eb/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw= golang.org/x/tools v0.0.0-20200312045724-11d5b4c81c7d/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw= golang.org/x/tools v0.0.0-20200331025713-a30bf2db82d4/go.mod h1:Sl4aGygMT6LrqrWclx+PTx3U+LnKx/seiNR+3G19Ar8= +golang.org/x/tools v0.0.0-20200417140056-c07e33ef3290/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200501065659-ab2804fb9c9d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200512131952-2bc93b1c0c88/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200515010526-7d3b6ebf133d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= diff --git a/main.go b/main.go index acc2a24..9719c9f 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "gitea.deepak.science/deepak/gogmagog/models" "gitea.deepak.science/deepak/gogmagog/routes" "gitea.deepak.science/deepak/gogmagog/store" + "gitea.deepak.science/deepak/gogmagog/tokens" "log" "net/http" "os" @@ -38,7 +39,7 @@ func main() { log.Println("created model") } - router := routes.NewRouter(m) + router := routes.NewRouter(m, tokens.New("my secret")) log.Println("Running server on " + port) diff --git a/models/err_model_test.go b/models/err_model_test.go index 27c357a..3928c6b 100644 --- a/models/err_model_test.go +++ b/models/err_model_test.go @@ -2,49 +2,10 @@ package models_test import ( "gitea.deepak.science/deepak/gogmagog/models" + "gitea.deepak.science/deepak/gogmagog/store" ) -func (e *errorStore) SelectActions() ([]*models.Action, error) { - return nil, e.error -} - -func (e *errorStore) SelectActionByID(id int) (*models.Action, error) { - return nil, e.error -} - -func (e *errorStore) InsertAction(action *models.Action) (int, error) { - return 0, e.error -} - -func (e *errorStore) UpdateAction(action *models.Action) error { - return e.error -} - -func (e *errorStore) SelectPlans() ([]*models.Plan, error) { - return nil, e.error -} - -func (e *errorStore) SelectPlanByID(id int) (*models.Plan, error) { - return nil, e.error -} - -func (e *errorStore) InsertPlan(plan *models.Plan) (int, error) { - return 0, e.error -} - -func (e *errorStore) SelectActionsByPlanID(plan *models.Plan) ([]*models.Action, error) { - return nil, e.error -} - -func (e *errorStore) ConnectionLive() error { - return e.error -} - -type errorStore struct { - error error -} - func getErrorModel(err error) *models.Model { - e := &errorStore{error: err} - return models.New(e) + str := store.GetErrorStoreForError(err, true) + return models.New(str) } diff --git a/models/errors.go b/models/errors.go index 8274dc4..6d071fb 100644 --- a/models/errors.go +++ b/models/errors.go @@ -2,6 +2,7 @@ package models import ( "database/sql" + "golang.org/x/crypto/bcrypt" ) type notFoundError struct { @@ -27,3 +28,27 @@ func wrapNotFound(err error) error { } return err } + +type invalidLoginError struct { + error +} + +func (e *invalidLoginError) InvalidLogin() bool { + return true +} + +// IsInvalidLoginError returns true if the model deems it an invalid login error. +func IsInvalidLoginError(err error) bool { + type invalidLogin interface { + InvalidLogin() bool + } + te, ok := err.(invalidLogin) + return ok && te.InvalidLogin() +} + +func wrapInvalidLogin(err error) error { + if err == sql.ErrNoRows || err == bcrypt.ErrMismatchedHashAndPassword { + return &invalidLoginError{error: err} + } + return err +} diff --git a/models/errors_test.go b/models/errors_test.go index 88baab0..99ed40f 100644 --- a/models/errors_test.go +++ b/models/errors_test.go @@ -35,8 +35,15 @@ func TestErrorModelWrapping(t *testing.T) { assert := assert.New(t) m := getErrorModel(sql.ErrNoRows) - _, err := m.Plan(0) + _, err := m.Plan(0, 0) assert.True(models.IsNotFoundError(err)) _, err = m.Action(0) assert.True(models.IsNotFoundError(err)) } +func TestErrorModelInvalidLogin(t *testing.T) { + assert := assert.New(t) + m := getErrorModel(sql.ErrNoRows) + + _, err := m.VerifyUserByUsernamePassword("duck", "duck") + assert.True(models.IsInvalidLoginError(err)) +} diff --git a/models/models.go b/models/models.go index d25a1ac..02c4e4f 100644 --- a/models/models.go +++ b/models/models.go @@ -11,10 +11,12 @@ type Store interface { SelectActionByID(id int) (*Action, error) InsertAction(action *Action) (int, error) UpdateAction(action *Action) error - SelectPlans() ([]*Plan, error) - SelectPlanByID(id int) (*Plan, error) - InsertPlan(plan *Plan) (int, error) + SelectPlans(userID int) ([]*Plan, error) + SelectPlanByID(id int, userID int) (*Plan, error) + InsertPlan(plan *Plan, userID int) (int, error) SelectActionsByPlanID(plan *Plan) ([]*Action, error) + SelectUserByUsername(username string) (*User, error) + InsertUser(user *User) (int, error) } // Model represents a current model item. diff --git a/models/models_test.go b/models/models_test.go index 7cb62e8..a82acc8 100644 --- a/models/models_test.go +++ b/models/models_test.go @@ -2,72 +2,34 @@ package models_test import ( "gitea.deepak.science/deepak/gogmagog/models" + "gitea.deepak.science/deepak/gogmagog/store" "github.com/stretchr/testify/assert" "testing" ) -type multiStore struct { - actions []*models.Action - plans []*models.Plan -} - -func (ms *multiStore) SelectActions() ([]*models.Action, error) { - return ms.actions, nil -} - -func (ms *multiStore) SelectActionByID(id int) (*models.Action, error) { - return ms.actions[0], nil -} - -func (ms *multiStore) InsertAction(action *models.Action) (int, error) { - return int(action.ActionID), nil -} - -func (ms *multiStore) UpdateAction(action *models.Action) error { - return nil -} - -func (ms *multiStore) SelectPlans() ([]*models.Plan, error) { - return ms.plans, nil -} - -func (ms *multiStore) SelectPlanByID(id int) (*models.Plan, error) { - return ms.plans[0], nil -} - -func (ms *multiStore) InsertPlan(plan *models.Plan) (int, error) { - return int(plan.PlanID), nil -} - -func (ms *multiStore) SelectActionsByPlanID(plan *models.Plan) ([]*models.Action, error) { - return ms.actions, nil -} - -func (ms *multiStore) ConnectionLive() error { - return nil -} - func TestModelActions(t *testing.T) { assert := assert.New(t) a1 := &models.Action{ActionID: 3} a2 := &models.Action{ActionID: 4} + userID := 3 p := &models.Plan{PlanID: 6} - ss := &multiStore{ - []*models.Action{a1, a2}, - []*models.Plan{p}} - m := models.New(ss) + + str, _ := store.GetInMemoryStore() + str.InsertAction(a1) + str.InsertPlan(p, userID) + m := models.New(str) actions, err := m.Actions() assert.Nil(err) - assert.Equal(2, len(actions)) + assert.Equal(1, len(actions)) - firstAction, err := m.Action(3) + firstAction, err := m.Action(1) assert.Nil(err) - assert.EqualValues(3, firstAction.ActionID) + assert.EqualValues(1, firstAction.ActionID) - actionID, err := m.AddAction(a1) + actionID, err := m.AddAction(a2) assert.Nil(err) - assert.EqualValues(3, actionID) + assert.EqualValues(2, actionID) err = m.SaveAction(a1) assert.Nil(err) @@ -75,41 +37,39 @@ func TestModelActions(t *testing.T) { func TestModelPlanMethods(t *testing.T) { assert := assert.New(t) - a1 := &models.Action{ActionID: 3} + userID := 3 + a1 := &models.Action{ActionID: 3, PlanID: 1} a2 := &models.Action{ActionID: 4} - p := &models.Plan{PlanID: 6} + p := &models.Plan{} - ss := &multiStore{ - []*models.Action{a1, a2}, - []*models.Plan{p}, - } - m := models.New(ss) + str, _ := store.GetInMemoryStore() + str.InsertPlan(p, userID) + str.InsertAction(a1) + str.InsertAction(a2) + m := models.New(str) - plans, err := m.Plans() + plans, err := m.Plans(userID) assert.Nil(err) assert.Equal(1, len(plans)) - firstPlan, err := m.Plan(6) + firstPlan, err := m.Plan(1, userID) assert.Nil(err) - assert.EqualValues(6, firstPlan.PlanID) + assert.EqualValues(1, firstPlan.PlanID) actions, err := m.GetActions(firstPlan) assert.Nil(err) - assert.Equal(2, len(actions)) + assert.Equal(1, len(actions)) - planId, err := m.AddPlan(p) + planId, err := m.AddPlan(&models.Plan{}, userID) assert.Nil(err) - assert.EqualValues(6, planId) + assert.EqualValues(2, planId) } func TestModelHealthy(t *testing.T) { assert := assert.New(t) - ss := &multiStore{ - []*models.Action{}, - []*models.Plan{}, - } - m := models.New(ss) + str, _ := store.GetInMemoryStore() + m := models.New(str) err := m.Healthy() assert.Nil(err) diff --git a/models/plan.go b/models/plan.go index 02e0212..f661b2d 100644 --- a/models/plan.go +++ b/models/plan.go @@ -8,22 +8,23 @@ import ( type Plan struct { PlanID int64 `json:"plan_id"` PlanDate *time.Time `json:"plan_date"` + UserID int64 `json:"user_id"` } // Plans returns all plans in the model. -func (m *Model) Plans() ([]*Plan, error) { - return m.SelectPlans() +func (m *Model) Plans(userID int) ([]*Plan, error) { + return m.SelectPlans(userID) } // Plan returns a single plan from the store by plan_id. -func (m *Model) Plan(id int) (*Plan, error) { - plan, err := m.SelectPlanByID(id) +func (m *Model) Plan(id int, userID int) (*Plan, error) { + plan, err := m.SelectPlanByID(id, userID) return plan, wrapNotFound(err) } // AddPlan inserts a given plan into the store, returning the generated PlanID. The provided PlanID is ignored. -func (m *Model) AddPlan(plan *Plan) (int, error) { - return m.InsertPlan(plan) +func (m *Model) AddPlan(plan *Plan, userID int) (int, error) { + return m.InsertPlan(plan, userID) } // GetActions returns the actions associated with a particular plan. diff --git a/models/user.go b/models/user.go new file mode 100644 index 0000000..1db7a0e --- /dev/null +++ b/models/user.go @@ -0,0 +1,96 @@ +package models + +import ( + "fmt" + "golang.org/x/crypto/bcrypt" +) + +// User represents the full DB user field, for inserts and compares. +// No reason to return the hashed pw on the route though. +type User struct { + UserID int64 + Username string + DisplayName string + Password []byte +} + +// UserNoPassword contains the non password user fields. +// This is preferred outside of the model / store. +type UserNoPassword struct { + UserID int64 `json:"user_id"` + Username string `json:"username"` + DisplayName string `json:"display_name"` +} + +// VerifyUserByUsernamePassword returns a single user by the unique username, if the provided password is correct. +func (m *Model) VerifyUserByUsernamePassword(username string, password string) (*UserNoPassword, error) { + user, err := m.SelectUserByUsername(username) + if err != nil { + // throwaway to pad time + hashPassword(username) + return nil, wrapInvalidLogin(err) + } + + err = bcrypt.CompareHashAndPassword(user.Password, []byte(password)) + if err != nil { + return nil, wrapInvalidLogin(err) + } + + return user.NoPassword(), nil +} + +// NoPassword strips the user of password. +func (u *User) NoPassword() *UserNoPassword { + return &UserNoPassword{ + UserID: u.UserID, + Username: u.Username, + DisplayName: u.DisplayName, + } +} + +// CreateUserRequest represents a desired user creation. +type CreateUserRequest struct { + Username string `json:"username"` + DisplayName string `json:"display_name"` + Password string `json:"password"` +} + +// CreateUser takes in a create user request and returns the ID of the newly created user. +func (m *Model) CreateUser(req *CreateUserRequest) (int, error) { + if req.Username == "" { + return -1, fmt.Errorf("No username provided") + } + if req.Password == "" { + return -1, fmt.Errorf("No password provided") + } + hash, err := hashPassword(req.Password) + if err != nil { + return -1, err + } + + desiredUser := &User{ + Username: req.Username, + DisplayName: req.DisplayName, + Password: hash, + } + + return m.InsertUser(desiredUser) +} + +// UserByUsername retrieves a single username from the store, verifying the passed in userID. +func (m *Model) UserByUsername(username string, userID int) (*UserNoPassword, error) { + user, err := m.SelectUserByUsername(username) + if user == nil { + return nil, wrapNotFound(err) + } + if int(user.UserID) != userID { + return nil, ¬FoundError{error: fmt.Errorf("provided userID does not match the retrieved user")} + } + return user.NoPassword(), wrapNotFound(err) +} + +// hashPassword hashes a password +func hashPassword(password string) ([]byte, error) { + bytes, err := bcrypt.GenerateFromPassword([]byte(password), 11) + return bytes, err +} diff --git a/models/user_test.go b/models/user_test.go new file mode 100644 index 0000000..fa5a9bd --- /dev/null +++ b/models/user_test.go @@ -0,0 +1,104 @@ +package models_test + +import ( + "fmt" + "gitea.deepak.science/deepak/gogmagog/models" + "gitea.deepak.science/deepak/gogmagog/store" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestModelUsers(t *testing.T) { + assert := assert.New(t) + + a1 := &models.Action{ActionID: 3} + a2 := &models.Action{ActionID: 4} + p := &models.Plan{PlanID: 6} + + username := "test1" + // password := password + user1 := &models.User{Username: username, DisplayName: "Ted Est", Password: []byte("$2y$05$6SVV35GX4cB4PDPhRaDD/exsL.HV8QtMMr60YL6dLyqtX4l58q.cy")} + str, _ := store.GetInMemoryStore() + str.InsertPlan(p, 3) + str.InsertAction(a1) + str.InsertAction(a2) + str.InsertUser(user1) + m := models.New(str) + + userNoPass, err := m.UserByUsername("test1", 1) + assert.Nil(err) + assert.NotNil(userNoPass) + + userNoPass, err = m.UserByUsername("test1", 2) + assert.NotNil(err) + assert.True(models.IsNotFoundError(err)) + assert.Nil(userNoPass) + + userNoPass, err = m.UserByUsername("test2", 2) + assert.NotNil(err) + assert.True(models.IsNotFoundError(err)) + assert.Nil(userNoPass) + + user, err := m.VerifyUserByUsernamePassword("test1", "password") + assert.Nil(err) + assert.NotNil(user) + + user, err = m.VerifyUserByUsernamePassword("test1", "wrong_password") + assert.NotNil(err) + assert.Nil(user) + + user, err = m.VerifyUserByUsernamePassword("test2", "password") + assert.NotNil(err) + assert.Nil(user) +} + +func TestErrorUsers(t *testing.T) { + assert := assert.New(t) + m := getErrorModel(fmt.Errorf("err")) + + user, err := m.VerifyUserByUsernamePassword("snth", "aoeu") + assert.Nil(user) + assert.NotNil(err) +} + +func TestCreateUser(t *testing.T) { + assert := assert.New(t) + username := "test" + displayName := "Ted Est" + pass := "abc" + u := &models.CreateUserRequest{Username: username, DisplayName: displayName, Password: pass} + + str, _ := store.GetInMemoryStore() + m := models.New(str) + + id, err := m.CreateUser(u) + assert.Nil(err) + assert.EqualValues(1, id) +} +func TestCreateUserFailValidation(t *testing.T) { + assert := assert.New(t) + username := "" + displayName := "Ted Est" + pass := "abc" + u := &models.CreateUserRequest{Username: username, DisplayName: displayName, Password: pass} + + str, _ := store.GetInMemoryStore() + m := models.New(str) + + _, err := m.CreateUser(u) + assert.NotNil(err) +} + +func TestCreateUserFailValidationPassword(t *testing.T) { + assert := assert.New(t) + username := "aoeu" + displayName := "Ted Est" + pass := "" + u := &models.CreateUserRequest{Username: username, DisplayName: displayName, Password: pass} + + str, _ := store.GetInMemoryStore() + m := models.New(str) + + _, err := m.CreateUser(u) + assert.NotNil(err) +} diff --git a/routes/actions.go b/routes/actions.go index d63d294..3526803 100644 --- a/routes/actions.go +++ b/routes/actions.go @@ -9,7 +9,8 @@ import ( "strconv" ) -func newActionRouter(m *models.Model) http.Handler { +// NewActionRouter returns a new action router +func NewActionRouter(m *models.Model) http.Handler { router := chi.NewRouter() router.Get("/", getActionsFunc(m)) router.Post("/", postActionFunc(m)) diff --git a/routes/actions_test.go b/routes/actions_test.go index 2008eef..66a1fc5 100644 --- a/routes/actions_test.go +++ b/routes/actions_test.go @@ -15,8 +15,8 @@ func TestEmptyActions(t *testing.T) { // set up assert := assert.New(t) m := getEmptyModel() - router := routes.NewRouter(m) - req, _ := http.NewRequest("GET", "/actions", nil) + router := routes.NewActionRouter(m) + req, _ := http.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() @@ -40,9 +40,11 @@ func TestOneAction(t *testing.T) { updatedDate, _ := time.Parse("2006-01-02", "2021-01-02") completedDate, _ := time.Parse("2006-01-02", "2021-01-03") a1 := &models.Action{ActionID: 3, ActionDescription: "testing", CompletedChunks: 1, CompletedOn: &completedDate, CreatedAt: &createdDate, UpdatedAt: &updatedDate, EstimatedChunks: 3, PlanID: 0} - m := getModel([]*models.Plan{}, []*models.Action{a1}) - router := routes.NewRouter(m) - req, _ := http.NewRequest("GET", "/actions", nil) + m := getEmptyModel() + m.AddAction(a1) + + router := routes.NewActionRouter(m) + req, _ := http.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() @@ -55,7 +57,7 @@ func TestOneAction(t *testing.T) { // We pass in the date as a time.time so it makes sense that it comes back with a midnight timestamp. expected := `[ { - "action_id": 3, + "action_id": 1, "action_description": "testing", "estimated_chunks": 3, "completed_chunks": 1, @@ -76,8 +78,8 @@ func TestErrorAction(t *testing.T) { m := getErrorModel("Model always errors") - router := routes.NewRouter(m) - req, _ := http.NewRequest("GET", "/actions", nil) + router := routes.NewActionRouter(m) + req, _ := http.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() @@ -98,8 +100,8 @@ func TestEmptyActionErrorWriter(t *testing.T) { m := getEmptyModel() - router := routes.NewRouter(m) - req, _ := http.NewRequest("GET", "/actions", nil) + router := routes.NewActionRouter(m) + req, _ := http.NewRequest("GET", "/", nil) rr := NewBadWriter() @@ -118,9 +120,10 @@ func TestOneActionByID(t *testing.T) { createdDate, _ := time.Parse("2006-01-02", "2021-01-01") updatedDate, _ := time.Parse("2006-01-02", "2021-01-02") a := &models.Action{ActionID: 6, ActionDescription: "howdy", CompletedOn: nil, CreatedAt: &createdDate, UpdatedAt: &updatedDate, CompletedChunks: 0, EstimatedChunks: 54, PlanID: 3} - m := getModel([]*models.Plan{}, []*models.Action{a}) - router := routes.NewRouter(m) - req, _ := http.NewRequest("GET", "/actions/6", nil) + m := getEmptyModel() + m.InsertAction(a) + router := routes.NewActionRouter(m) + req, _ := http.NewRequest("GET", "/1", nil) rr := httptest.NewRecorder() @@ -132,7 +135,7 @@ func TestOneActionByID(t *testing.T) { assert.Equal(http.StatusOK, status) // We pass in the date as a time.time so it makes sense that it comes back with a midnight timestamp. expected := `{ - "action_id": 6, + "action_id": 1, "action_description": "howdy", "estimated_chunks": 54, "completed_chunks": 0, @@ -151,8 +154,8 @@ func TestErrorActionByID(t *testing.T) { m := getErrorModel("Model always errors") - router := routes.NewRouter(m) - req, _ := http.NewRequest("GET", "/actions/5", nil) + router := routes.NewActionRouter(m) + req, _ := http.NewRequest("GET", "/5", nil) rr := httptest.NewRecorder() @@ -172,10 +175,11 @@ func TestEmptyActionErrorWriterByID(t *testing.T) { assert := assert.New(t) a := &models.Action{ActionID: 6} - m := getModel([]*models.Plan{}, []*models.Action{a}) + m := getEmptyModel() + m.AddAction(a) - router := routes.NewRouter(m) - req, _ := http.NewRequest("GET", "/actions/6", nil) + router := routes.NewActionRouter(m) + req, _ := http.NewRequest("GET", "/1", nil) rr := NewBadWriter() @@ -194,8 +198,8 @@ func TestNotFoundActionByIDText(t *testing.T) { m := getEmptyModel() - router := routes.NewRouter(m) - req, _ := http.NewRequest("GET", "/actions/wo", nil) + router := routes.NewActionRouter(m) + req, _ := http.NewRequest("GET", "/wo", nil) rr := httptest.NewRecorder() @@ -213,8 +217,8 @@ func TestNotFoundActionByIDEmpty(t *testing.T) { m := getEmptyModel() - router := routes.NewRouter(m) - req, _ := http.NewRequest("GET", "/actions/1", nil) + router := routes.NewActionRouter(m) + req, _ := http.NewRequest("GET", "/1", nil) rr := httptest.NewRecorder() @@ -232,10 +236,11 @@ func TestActionsByPlanID(t *testing.T) { assert := assert.New(t) createdDate, _ := time.Parse("2006-01-02", "2021-01-01") updatedDate, _ := time.Parse("2006-01-02", "2021-01-02") - a := &models.Action{ActionID: 6, ActionDescription: "howdy", CompletedOn: nil, CreatedAt: &createdDate, UpdatedAt: &updatedDate, CompletedChunks: 0, EstimatedChunks: 54, PlanID: 3} - m := getModel([]*models.Plan{}, []*models.Action{a}) - router := routes.NewRouter(m) - req, _ := http.NewRequest("GET", "/actions?plan_id=6", nil) + a := &models.Action{ActionID: 1, ActionDescription: "howdy", CompletedOn: nil, CreatedAt: &createdDate, UpdatedAt: &updatedDate, CompletedChunks: 0, EstimatedChunks: 54, PlanID: 6} + m := getEmptyModel() + m.AddAction(a) + router := routes.NewActionRouter(m) + req, _ := http.NewRequest("GET", "/?plan_id=6", nil) rr := httptest.NewRecorder() @@ -248,13 +253,13 @@ func TestActionsByPlanID(t *testing.T) { // We pass in the date as a time.time so it makes sense that it comes back with a midnight timestamp. expected := `[ { - "action_id": 6, + "action_id": 1, "action_description": "howdy", "estimated_chunks": 54, "completed_chunks": 0, "updated_at": "2021-01-02T00:00:00Z", "created_at": "2021-01-01T00:00:00Z", - "plan_id": 3 + "plan_id": 6 } ]` assert.JSONEq(expected, rr.Body.String()) @@ -268,9 +273,10 @@ func TestActionsByPlanIDInvalidID(t *testing.T) { createdDate, _ := time.Parse("2006-01-02", "2021-01-01") updatedDate, _ := time.Parse("2006-01-02", "2021-01-02") a := &models.Action{ActionID: 6, ActionDescription: "howdy", CompletedOn: nil, CreatedAt: &createdDate, UpdatedAt: &updatedDate, CompletedChunks: 0, EstimatedChunks: 54, PlanID: 3} - m := getModel([]*models.Plan{}, []*models.Action{a}) - router := routes.NewRouter(m) - req, _ := http.NewRequest("GET", "/actions?plan_id=aoeu", nil) + m := getEmptyModel() + m.AddAction(a) + router := routes.NewActionRouter(m) + req, _ := http.NewRequest("GET", "/?plan_id=aoeu", nil) rr := httptest.NewRecorder() diff --git a/routes/auth.go b/routes/auth.go new file mode 100644 index 0000000..f487b62 --- /dev/null +++ b/routes/auth.go @@ -0,0 +1,103 @@ +package routes + +import ( + "encoding/json" + "gitea.deepak.science/deepak/gogmagog/models" + "gitea.deepak.science/deepak/gogmagog/tokens" + "github.com/go-chi/chi" + "io" + "net/http" +) + +// NewAuthRouter returns a new auth router. +func NewAuthRouter(m *models.Model, tok tokens.Toker) http.Handler { + router := chi.NewRouter() + router.Post("/register", postUserFunc(m)) + router.Post("/tokens", createTokenFunc(m, tok)) + return router +} + +type createUserResponse struct { + Username string `json:"username"` +} + +func postUserFunc(m *models.Model) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + + r.Body = http.MaxBytesReader(w, r.Body, 1024) + dec := json.NewDecoder(r.Body) + dec.DisallowUnknownFields() + var req models.CreateUserRequest + err := dec.Decode(&req) + if err != nil { + badRequestError(w, err) + return + } + err = dec.Decode(&struct{}{}) + if err != io.EOF { + badRequestError(w, err) + return + } + + _, err = m.CreateUser(&req) + if err != nil { + serverError(w, err) + return + } + + response := &createUserResponse{ + Username: req.Username, + } + w.WriteHeader(http.StatusCreated) + w.Header().Add("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(response); err != nil { + serverError(w, err) + } + + } +} + +type loginCreds struct { + Username string `json:"username"` + Password string `json:"password"` +} +type createdToken struct { + Token string `json:"token"` +} + +func createTokenFunc(m *models.Model, tok tokens.Toker) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + + r.Body = http.MaxBytesReader(w, r.Body, 1024) + dec := json.NewDecoder(r.Body) + dec.DisallowUnknownFields() + var creds loginCreds + err := dec.Decode(&creds) + if err != nil { + badRequestError(w, err) + return + } + err = dec.Decode(&struct{}{}) + if err != io.EOF { + badRequestError(w, err) + return + } + + user, err := m.VerifyUserByUsernamePassword(creds.Username, creds.Password) + if err != nil { + if models.IsInvalidLoginError(err) { + unauthorizedHandler(w, r) + return + } + serverError(w, err) + return + + } + w.Header().Add("Content-Type", "application/json") + response := &createdToken{Token: tok.EncodeUser(user)} + if err := json.NewEncoder(w).Encode(response); err != nil { + serverError(w, err) + return + } + } +} diff --git a/routes/auth_login_test.go b/routes/auth_login_test.go new file mode 100644 index 0000000..db4925a --- /dev/null +++ b/routes/auth_login_test.go @@ -0,0 +1,202 @@ +package routes_test + +import ( + "bytes" + "gitea.deepak.science/deepak/gogmagog/models" + "gitea.deepak.science/deepak/gogmagog/routes" + "gitea.deepak.science/deepak/gogmagog/tokens" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "testing" +) + +func TestLoginAuth(t *testing.T) { + // set up + assert := assert.New(t) + + m := getEmptyModel() + username := "testing_username" + displayName := "testing_name" + password := "pass" + m.CreateUser(&models.CreateUserRequest{Username: username, DisplayName: displayName, Password: password}) + + toker := tokens.GetDeterministicToker() + data := []byte(`{ + "username": "testing_username", + "password": "pass" + }`) + req, _ := http.NewRequest("POST", "/tokens", bytes.NewBuffer(data)) + + rr := httptest.NewRecorder() + + // function under test + router := routes.NewAuthRouter(m, toker) + router.ServeHTTP(rr, req) + + // check results + status := rr.Code + assert.Equal(http.StatusOK, status) + expected := `{ + "token": "{\"ID\":1,\"Username\":\"testing_username\"}" + }` + assert.JSONEq(expected, rr.Body.String()) + contentType := rr.Header().Get("Content-Type") + assert.Equal("application/json", contentType) + +} + +func TestLoginBadCreds(t *testing.T) { + // set up + assert := assert.New(t) + m := getEmptyModel() + toker := tokens.GetDeterministicToker() + data := []byte(`{ + "username": "testing_use + }`) + req, _ := http.NewRequest("POST", "/tokens", bytes.NewBuffer(data)) + + rr := httptest.NewRecorder() + + // function under test + router := routes.NewAuthRouter(m, toker) + router.ServeHTTP(rr, req) + + // check results + status := rr.Code + assert.Equal(http.StatusBadRequest, status) + +} + +func TestLoginBadRequestTwoBodies(t *testing.T) { + // set up + assert := assert.New(t) + m := getEmptyModel() + toker := tokens.GetDeterministicToker() + data := []byte(`{ + "username": "testing_username", + "password": "pass" + }{ + "username": "testing_username", + "password": "pass" + }`) + req, _ := http.NewRequest("POST", "/tokens", bytes.NewBuffer(data)) + + rr := httptest.NewRecorder() + + // function under test + router := routes.NewAuthRouter(m, toker) + router.ServeHTTP(rr, req) + + // check results + status := rr.Code + assert.Equal(http.StatusBadRequest, status) +} + +func TestLoginAuthWrongPass(t *testing.T) { + // set up + assert := assert.New(t) + + m := getEmptyModel() + username := "testing_username" + displayName := "testing_name" + password := "pass" + m.CreateUser(&models.CreateUserRequest{Username: username, DisplayName: displayName, Password: password}) + + toker := tokens.GetDeterministicToker() + data := []byte(`{ + "username": "testing_username", + "password": "badpass" + }`) + req, _ := http.NewRequest("POST", "/tokens", bytes.NewBuffer(data)) + + rr := httptest.NewRecorder() + + // function under test + router := routes.NewAuthRouter(m, toker) + router.ServeHTTP(rr, req) + + // check results + status := rr.Code + assert.Equal(http.StatusUnauthorized, status) + +} + +func TestLoginErrorModel(t *testing.T) { + // set up + assert := assert.New(t) + + m := getErrorModel("error") + + toker := tokens.GetDeterministicToker() + data := []byte(`{ + "username": "testing_username", + "password": "badpass" + }`) + req, _ := http.NewRequest("POST", "/tokens", bytes.NewBuffer(data)) + + rr := httptest.NewRecorder() + + // function under test + router := routes.NewAuthRouter(m, toker) + router.ServeHTTP(rr, req) + + // check results + status := rr.Code + assert.Equal(http.StatusInternalServerError, status) + +} + +func TestLoginBadWriter(t *testing.T) { + // set up + assert := assert.New(t) + + m := getEmptyModel() + username := "testing_username" + displayName := "testing_name" + password := "pass" + m.CreateUser(&models.CreateUserRequest{Username: username, DisplayName: displayName, Password: password}) + + toker := tokens.GetDeterministicToker() + data := []byte(`{ + "username": "testing_username", + "password": "pass" + }`) + req, _ := http.NewRequest("POST", "/tokens", bytes.NewBuffer(data)) + + rr := NewBadWriter() + + // function under test + router := routes.NewAuthRouter(m, toker) + router.ServeHTTP(rr, req) + + // check results + status := rr.Code + assert.Equal(http.StatusInternalServerError, status) + +} + +// +// func TestRegisterBadWriter(t *testing.T) { +// // set up +// assert := assert.New(t) +// m := getEmptyModel() +// toker := tokens.New("secret") +// data := []byte(`{ +// "username": "test", +// "password": "pass", +// "display_name": "My Display Name" +// }`) +// req, _ := http.NewRequest("POST", "/register", bytes.NewBuffer(data)) +// +// rr := NewBadWriter() +// +// // function under test +// router := routes.NewAuthRouter(m, toker) +// router.ServeHTTP(rr, req) +// +// // check results +// status := rr.Code +// assert.Equal(http.StatusInternalServerError, status) +// +// } diff --git a/routes/auth_register_test.go b/routes/auth_register_test.go new file mode 100644 index 0000000..1208a22 --- /dev/null +++ b/routes/auth_register_test.go @@ -0,0 +1,139 @@ +package routes_test + +import ( + "bytes" + "gitea.deepak.science/deepak/gogmagog/routes" + "gitea.deepak.science/deepak/gogmagog/tokens" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "testing" +) + +func TestRegisterAuth(t *testing.T) { + // set up + assert := assert.New(t) + m := getEmptyModel() + toker := tokens.New("secret") + data := []byte(`{ + "username": "test", + "password": "pass", + "display_name": "My Display Name" + }`) + req, _ := http.NewRequest("POST", "/register", bytes.NewBuffer(data)) + + rr := httptest.NewRecorder() + + // function under test + router := routes.NewAuthRouter(m, toker) + router.ServeHTTP(rr, req) + + // check results + status := rr.Code + assert.Equal(http.StatusCreated, status) + expected := `{ + "username": "test" + }` + assert.JSONEq(expected, rr.Body.String()) + contentType := rr.Header().Get("Content-Type") + assert.Equal("application/json", contentType) + +} + +func TestRegisterBadRequestAuth(t *testing.T) { + // set up + assert := assert.New(t) + m := getEmptyModel() + toker := tokens.New("secret") + data := []byte(`{ + "username": y Display Name" + }`) + req, _ := http.NewRequest("POST", "/register", bytes.NewBuffer(data)) + + rr := httptest.NewRecorder() + + // function under test + router := routes.NewAuthRouter(m, toker) + router.ServeHTTP(rr, req) + + // check results + status := rr.Code + assert.Equal(http.StatusBadRequest, status) + +} + +func TestRegisterBadRequestTwoBodies(t *testing.T) { + // set up + assert := assert.New(t) + m := getEmptyModel() + toker := tokens.New("secret") + data := []byte(`{ + "username": "test", + "password": "pass", + "display_name": "My Display Name" + }, { + "username": "test", + "password": "pass", + "display_name": "My Display Name" + }`) + req, _ := http.NewRequest("POST", "/register", bytes.NewBuffer(data)) + + rr := httptest.NewRecorder() + + // function under test + router := routes.NewAuthRouter(m, toker) + router.ServeHTTP(rr, req) + + // check results + status := rr.Code + assert.Equal(http.StatusBadRequest, status) + +} + +func TestRegisterErrorModel(t *testing.T) { + // set up + assert := assert.New(t) + m := getErrorModel("here's an error") + toker := tokens.New("secret") + data := []byte(`{ + "username": "test", + "password": "pass", + "display_name": "My Display Name" + }`) + req, _ := http.NewRequest("POST", "/register", bytes.NewBuffer(data)) + + rr := httptest.NewRecorder() + + // function under test + router := routes.NewAuthRouter(m, toker) + router.ServeHTTP(rr, req) + + // check results + status := rr.Code + assert.Equal(http.StatusInternalServerError, status) + +} + +func TestRegisterBadWriter(t *testing.T) { + // set up + assert := assert.New(t) + m := getEmptyModel() + toker := tokens.New("secret") + data := []byte(`{ + "username": "test", + "password": "pass", + "display_name": "My Display Name" + }`) + req, _ := http.NewRequest("POST", "/register", bytes.NewBuffer(data)) + + rr := NewBadWriter() + + // function under test + router := routes.NewAuthRouter(m, toker) + router.ServeHTTP(rr, req) + + // check results + status := rr.Code + assert.Equal(http.StatusInternalServerError, status) + +} diff --git a/routes/currentUser.go b/routes/currentUser.go new file mode 100644 index 0000000..e65ce13 --- /dev/null +++ b/routes/currentUser.go @@ -0,0 +1,48 @@ +package routes + +import ( + "encoding/json" + "gitea.deepak.science/deepak/gogmagog/models" + "gitea.deepak.science/deepak/gogmagog/tokens" + "github.com/go-chi/chi" + "log" + "net/http" +) + +// NewCurrentUserRouter returns a new router for getting the current user. +func NewCurrentUserRouter(m *models.Model) http.Handler { + router := chi.NewRouter() + router.Get("/", getMeFunc(m)) + return router +} + +func getMeFunc(m *models.Model) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + userID, err := tokens.GetUserID(r.Context()) + if err != nil { + log.Print(err) + unauthorizedHandler(w, r) + return + } + username, err := tokens.GetUsername(r.Context()) + if err != nil { + log.Print(err) + unauthorizedHandler(w, r) + return + } + + user, err := m.UserByUsername(username, userID) + if err != nil { + if models.IsNotFoundError(err) { + notFoundHandler(w, r) + return + } + serverError(w, err) + return + } + w.Header().Add("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(user); err != nil { + serverError(w, err) + } + } +} diff --git a/routes/currentUser_test.go b/routes/currentUser_test.go new file mode 100644 index 0000000..4ac993e --- /dev/null +++ b/routes/currentUser_test.go @@ -0,0 +1,158 @@ +package routes_test + +import ( + "context" + "gitea.deepak.science/deepak/gogmagog/models" + "gitea.deepak.science/deepak/gogmagog/routes" + "gitea.deepak.science/deepak/gogmagog/tokens" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "testing" +) + +func TestEmptyCurrentUser(t *testing.T) { + // set up + assert := assert.New(t) + m := getEmptyModel() + router := routes.NewCurrentUserRouter(m) + req, _ := http.NewRequestWithContext(tokens.GetContextForUserValues(3, "testing"), "GET", "/", nil) + + rr := httptest.NewRecorder() + + // function under test + router.ServeHTTP(rr, req) + + // check results + status := rr.Code + assert.Equal(http.StatusNotFound, status) + +} + +func TestSingleUser(t *testing.T) { + // set up + assert := assert.New(t) + m := getEmptyModel() + + idToUse := 1 + username := "testing_username" + displayName := "testing_name" + password := "pass" + m.CreateUser(&models.CreateUserRequest{Username: username, DisplayName: displayName, Password: password}) + + router := routes.NewCurrentUserRouter(m) + req, _ := http.NewRequestWithContext(tokens.GetContextForUserValues(idToUse, username), "GET", "/", nil) + + rr := httptest.NewRecorder() + + // function under test + router.ServeHTTP(rr, req) + + // check results + status := rr.Code + assert.Equal(http.StatusOK, status) + expected := `{ + "user_id": 1, + "username": "testing_username", + "display_name": "testing_name" + }` + assert.JSONEq(expected, rr.Body.String()) + contentType := rr.Header().Get("Content-Type") + assert.Equal("application/json", contentType) + +} + +func TestSingleUserEmptyContext(t *testing.T) { + // set up + assert := assert.New(t) + m := getEmptyModel() + + username := "testing_username" + displayName := "testing_name" + password := "pass" + m.CreateUser(&models.CreateUserRequest{Username: username, DisplayName: displayName, Password: password}) + + router := routes.NewCurrentUserRouter(m) + req, _ := http.NewRequestWithContext(context.Background(), "GET", "/", nil) + + rr := httptest.NewRecorder() + + // function under test + router.ServeHTTP(rr, req) + + // check results + status := rr.Code + assert.Equal(http.StatusUnauthorized, status) + +} + +func TestSingleUserContextNoUserID(t *testing.T) { + // set up + assert := assert.New(t) + m := getEmptyModel() + + idToUse := 1 + username := "testing_username" + displayName := "testing_name" + password := "pass" + m.CreateUser(&models.CreateUserRequest{Username: username, DisplayName: displayName, Password: password}) + + router := routes.NewCurrentUserRouter(m) + req, _ := http.NewRequestWithContext(tokens.SetUserID(context.Background(), idToUse), "GET", "/", nil) + + rr := httptest.NewRecorder() + + // function under test + router.ServeHTTP(rr, req) + + // check results + status := rr.Code + assert.Equal(http.StatusUnauthorized, status) + +} + +func TestErrorUserContextNoUserID(t *testing.T) { + // set up + assert := assert.New(t) + m := getErrorModel("Here's an error.") + + idToUse := 1 + + router := routes.NewCurrentUserRouter(m) + req, _ := http.NewRequestWithContext(tokens.GetContextForUserValues(idToUse, "username"), "GET", "/", nil) + + rr := httptest.NewRecorder() + + // function under test + router.ServeHTTP(rr, req) + + // check results + status := rr.Code + assert.Equal(http.StatusInternalServerError, status) + +} + +func TestSingleUserErrorWriter(t *testing.T) { + // set up + assert := assert.New(t) + m := getEmptyModel() + + idToUse := 1 + username := "testing_username" + displayName := "testing_name" + password := "pass" + m.CreateUser(&models.CreateUserRequest{Username: username, DisplayName: displayName, Password: password}) + + router := routes.NewCurrentUserRouter(m) + req, _ := http.NewRequestWithContext(tokens.GetContextForUserValues(idToUse, username), "GET", "/", nil) + + rr := NewBadWriter() + + // function under test + router.ServeHTTP(rr, req) + + // check results + status := rr.Code + assert.Equal(http.StatusInternalServerError, status) + +} diff --git a/routes/errors.go b/routes/errors.go index a43793a..43a457a 100644 --- a/routes/errors.go +++ b/routes/errors.go @@ -25,3 +25,8 @@ func notFoundHandler(w http.ResponseWriter, r *http.Request) { code := http.StatusNotFound http.Error(w, http.StatusText(code), code) } + +func unauthorizedHandler(w http.ResponseWriter, r *http.Request) { + code := http.StatusUnauthorized + http.Error(w, http.StatusText(code), code) +} diff --git a/routes/health_test.go b/routes/health_test.go index 92902b3..30b7100 100644 --- a/routes/health_test.go +++ b/routes/health_test.go @@ -3,6 +3,7 @@ package routes_test import ( "fmt" "gitea.deepak.science/deepak/gogmagog/routes" + "gitea.deepak.science/deepak/gogmagog/tokens" "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" @@ -15,7 +16,7 @@ func TestEmptyHeatlhErrorWriter(t *testing.T) { m := getEmptyModel() - router := routes.NewRouter(m) + router := routes.NewRouter(m, tokens.New("whatever")) req, _ := http.NewRequest("GET", "/health", nil) rr := NewBadWriter() @@ -33,7 +34,7 @@ func TestEmptyHealth(t *testing.T) { // set up assert := assert.New(t) m := getEmptyModel() - router := routes.NewRouter(m) + router := routes.NewRouter(m, tokens.New("whatever")) req, _ := http.NewRequest("GET", "/health", nil) rr := httptest.NewRecorder() @@ -62,7 +63,7 @@ func TestUnhealthyDB(t *testing.T) { assert := assert.New(t) errorMsg := "error" m := getErrorModel(errorMsg) - router := routes.NewRouter(m) + router := routes.NewRouter(m, tokens.New("whatever")) req, _ := http.NewRequest("GET", "/health", nil) rr := httptest.NewRecorder() diff --git a/routes/plans.go b/routes/plans.go index 6d2921c..c1a26ae 100644 --- a/routes/plans.go +++ b/routes/plans.go @@ -3,13 +3,15 @@ package routes import ( "encoding/json" "gitea.deepak.science/deepak/gogmagog/models" + "gitea.deepak.science/deepak/gogmagog/tokens" "github.com/go-chi/chi" "io" "net/http" "strconv" ) -func newPlanRouter(m *models.Model) http.Handler { +// NewPlanRouter returns the http.Handler for the passed in model to route plan methods. +func NewPlanRouter(m *models.Model) http.Handler { router := chi.NewRouter() router.Get("/", getAllPlansFunc(m)) router.Post("/", postPlanFunc(m)) @@ -19,7 +21,13 @@ func newPlanRouter(m *models.Model) http.Handler { func getAllPlansFunc(m *models.Model) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - plans, err := m.Plans() + userID, err := tokens.GetUserID(r.Context()) + if err != nil { + unauthorizedHandler(w, r) + return + } + + plans, err := m.Plans(userID) if err != nil { serverError(w, err) return @@ -33,12 +41,19 @@ func getAllPlansFunc(m *models.Model) http.HandlerFunc { func getPlanByIDFunc(m *models.Model) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + userID, err := tokens.GetUserID(r.Context()) + if err != nil { + unauthorizedHandler(w, r) + return + } + id, err := strconv.Atoi(chi.URLParam(r, "planid")) if err != nil { notFoundHandler(w, r) return } - plan, err := m.Plan(id) + // todo get real user id + plan, err := m.Plan(id, userID) if err != nil { if models.IsNotFoundError(err) { notFoundHandler(w, r) @@ -62,12 +77,17 @@ type createPlanResponse struct { func postPlanFunc(m *models.Model) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + userID, err := tokens.GetUserID(r.Context()) + if err != nil { + unauthorizedHandler(w, r) + return + } r.Body = http.MaxBytesReader(w, r.Body, 1024) dec := json.NewDecoder(r.Body) dec.DisallowUnknownFields() var p models.Plan - err := dec.Decode(&p) + err = dec.Decode(&p) if err != nil { badRequestError(w, err) return @@ -78,13 +98,14 @@ func postPlanFunc(m *models.Model) http.HandlerFunc { return } - plan := &models.Plan{PlanDate: p.PlanDate} - id, err := m.AddPlan(plan) + // Map the fields we allow to be set to the plan to be created. + plan := &models.Plan{PlanDate: p.PlanDate, UserID: p.UserID} + id, err := m.AddPlan(plan, userID) if err != nil { serverError(w, err) return } - plan, err = m.Plan(id) + plan, err = m.Plan(id, userID) if err != nil { serverError(w, err) return diff --git a/routes/plans_test.go b/routes/plans_test.go index fd58b1c..04d0254 100644 --- a/routes/plans_test.go +++ b/routes/plans_test.go @@ -3,6 +3,7 @@ package routes_test import ( "gitea.deepak.science/deepak/gogmagog/models" "gitea.deepak.science/deepak/gogmagog/routes" + "gitea.deepak.science/deepak/gogmagog/tokens" "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" @@ -11,12 +12,14 @@ import ( "time" ) +var sampleContext = tokens.GetContextForUserValues(3, "testing") + func TestEmptyPlans(t *testing.T) { // set up assert := assert.New(t) m := getEmptyModel() - router := routes.NewRouter(m) - req, _ := http.NewRequest("GET", "/plans", nil) + router := routes.NewPlanRouter(m) + req, _ := http.NewRequestWithContext(sampleContext, "GET", "/", nil) rr := httptest.NewRecorder() @@ -37,10 +40,11 @@ func TestOnePlan(t *testing.T) { // set up assert := assert.New(t) planDate, _ := time.Parse("2006-01-02", "2021-01-01") - p := &models.Plan{PlanID: 6, PlanDate: &planDate} - m := getModel([]*models.Plan{p}, []*models.Action{}) - router := routes.NewRouter(m) - req, _ := http.NewRequest("GET", "/plans", nil) + p := &models.Plan{PlanID: 6, PlanDate: &planDate, UserID: 3} + m := getEmptyModel() + m.AddPlan(p, 3) + router := routes.NewPlanRouter(m) + req, _ := http.NewRequestWithContext(sampleContext, "GET", "/", nil) rr := httptest.NewRecorder() @@ -53,8 +57,9 @@ func TestOnePlan(t *testing.T) { // We pass in the date as a time.time so it makes sense that it comes back with a midnight timestamp. expected := `[ { - "plan_id": 6, - "plan_date": "2021-01-01T00:00:00Z" + "plan_id": 1, + "plan_date": "2021-01-01T00:00:00Z", + "user_id": 3 } ]` assert.JSONEq(expected, rr.Body.String()) @@ -68,8 +73,8 @@ func TestErrorPlan(t *testing.T) { m := getErrorModel("Model always errors") - router := routes.NewRouter(m) - req, _ := http.NewRequest("GET", "/plans", nil) + router := routes.NewPlanRouter(m) + req, _ := http.NewRequestWithContext(sampleContext, "GET", "/", nil) rr := httptest.NewRecorder() @@ -90,8 +95,8 @@ func TestEmptyPlanErrorWriter(t *testing.T) { m := getEmptyModel() - router := routes.NewRouter(m) - req, _ := http.NewRequest("GET", "/plans", nil) + router := routes.NewPlanRouter(m) + req, _ := http.NewRequestWithContext(sampleContext, "GET", "/", nil) rr := NewBadWriter() @@ -108,10 +113,11 @@ func TestOnePlanByID(t *testing.T) { // set up assert := assert.New(t) planDate, _ := time.Parse("2006-01-02", "2021-01-01") - p := &models.Plan{PlanID: 6, PlanDate: &planDate} - m := getModel([]*models.Plan{p}, []*models.Action{}) - router := routes.NewRouter(m) - req, _ := http.NewRequest("GET", "/plans/6", nil) + p := &models.Plan{PlanID: 6, PlanDate: &planDate, UserID: 3} + m := getEmptyModel() + m.AddPlan(p, 3) + router := routes.NewPlanRouter(m) + req, _ := http.NewRequestWithContext(sampleContext, "GET", "/1", nil) rr := httptest.NewRecorder() @@ -123,8 +129,9 @@ func TestOnePlanByID(t *testing.T) { assert.Equal(http.StatusOK, status) // We pass in the date as a time.time so it makes sense that it comes back with a midnight timestamp. expected := `{ - "plan_id": 6, - "plan_date": "2021-01-01T00:00:00Z" + "plan_id": 1, + "plan_date": "2021-01-01T00:00:00Z", + "user_id": 3 }` assert.JSONEq(expected, rr.Body.String()) contentType := rr.Header().Get("Content-Type") @@ -137,8 +144,8 @@ func TestErrorPlanByID(t *testing.T) { m := getErrorModel("Model always errors") - router := routes.NewRouter(m) - req, _ := http.NewRequest("GET", "/plans/5", nil) + router := routes.NewPlanRouter(m) + req, _ := http.NewRequestWithContext(sampleContext, "GET", "/5", nil) rr := httptest.NewRecorder() @@ -158,11 +165,12 @@ func TestEmptyPlanErrorWriterByID(t *testing.T) { assert := assert.New(t) planDate, _ := time.Parse("2006-01-02", "2021-01-01") - p := &models.Plan{PlanID: 6, PlanDate: &planDate} - m := getModel([]*models.Plan{p}, []*models.Action{}) + p := &models.Plan{PlanID: 1, PlanDate: &planDate} + m := getEmptyModel() + m.AddPlan(p, 3) - router := routes.NewRouter(m) - req, _ := http.NewRequest("GET", "/plans/6", nil) + router := routes.NewPlanRouter(m) + req, _ := http.NewRequestWithContext(sampleContext, "GET", "/1", nil) rr := NewBadWriter() @@ -181,8 +189,8 @@ func TestNotFoundPlanByIDText(t *testing.T) { m := getEmptyModel() - router := routes.NewRouter(m) - req, _ := http.NewRequest("GET", "/plans/wo", nil) + router := routes.NewPlanRouter(m) + req, _ := http.NewRequestWithContext(sampleContext, "GET", "/wo", nil) rr := httptest.NewRecorder() @@ -200,8 +208,8 @@ func TestNotFoundPlanByIDEmpty(t *testing.T) { m := getEmptyModel() - router := routes.NewRouter(m) - req, _ := http.NewRequest("GET", "/plans/1", nil) + router := routes.NewPlanRouter(m) + req, _ := http.NewRequestWithContext(sampleContext, "GET", "/1", nil) rr := httptest.NewRecorder() diff --git a/routes/plans_unauthorized_test.go b/routes/plans_unauthorized_test.go new file mode 100644 index 0000000..763daa1 --- /dev/null +++ b/routes/plans_unauthorized_test.go @@ -0,0 +1,95 @@ +package routes_test + +import ( + "bytes" + "context" + "gitea.deepak.science/deepak/gogmagog/models" + "gitea.deepak.science/deepak/gogmagog/routes" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestEmptyPlanEmptyContext(t *testing.T) { + // set up + assert := assert.New(t) + m := getEmptyModel() + router := routes.NewPlanRouter(m) + req, _ := http.NewRequestWithContext(context.Background(), "GET", "/", nil) + + rr := httptest.NewRecorder() + + // function under test + router.ServeHTTP(rr, req) + + // check results + status := rr.Code + assert.Equal(http.StatusUnauthorized, status) + +} + +func TestOnePlanEmptyContext(t *testing.T) { + // set up + assert := assert.New(t) + planDate, _ := time.Parse("2006-01-02", "2021-01-01") + p := &models.Plan{PlanID: 6, PlanDate: &planDate, UserID: 3} + m := getEmptyModel() + m.AddPlan(p, 3) + router := routes.NewPlanRouter(m) + req, _ := http.NewRequestWithContext(context.Background(), "GET", "/", nil) + + rr := httptest.NewRecorder() + + // function under test + router.ServeHTTP(rr, req) + + // check results + status := rr.Code + assert.Equal(http.StatusUnauthorized, status) + +} + +func TestOnePlanByIDEmptyContext(t *testing.T) { + // set up + assert := assert.New(t) + planDate, _ := time.Parse("2006-01-02", "2021-01-01") + p := &models.Plan{PlanID: 6, PlanDate: &planDate, UserID: 3} + m := getEmptyModel() + m.AddPlan(p, 3) + router := routes.NewPlanRouter(m) + req, _ := http.NewRequestWithContext(context.Background(), "GET", "/1", nil) + + rr := httptest.NewRecorder() + + // function under test + router.ServeHTTP(rr, req) + + // check results + status := rr.Code + assert.Equal(http.StatusUnauthorized, status) + +} + +func TestPureJSONEmptyContext(t *testing.T) { + // set up + assert := assert.New(t) + m := getEmptyModel() + router := routes.NewPlanRouter(m) + data := []byte(`{ + "plan_date": "2021-01-01T00:00:00Z", + "plan_id": 1, + "user_id": 3 + }`) + req, _ := http.NewRequestWithContext(context.Background(), "POST", "/", bytes.NewBuffer(data)) + req.Header.Set("Content-Type", "application/json") + + rr := httptest.NewRecorder() + // function under test + router.ServeHTTP(rr, req) + + // check results + status := rr.Code + assert.Equal(http.StatusUnauthorized, status) +} diff --git a/routes/post_action_test.go b/routes/post_action_test.go index d89aba3..0e87b07 100644 --- a/routes/post_action_test.go +++ b/routes/post_action_test.go @@ -24,8 +24,9 @@ func TestPureJSONPostAction(t *testing.T) { CompletedChunks: 2, ActionDescription: "here's an action", } - m := getModel([]*models.Plan{}, []*models.Action{a}) - router := routes.NewRouter(m) + m := getEmptyModel() + m.AddAction(a) + router := routes.NewActionRouter(m) data := []byte(`{ "action_description": "here's an action", "estimated_chunks": 3, @@ -33,7 +34,7 @@ func TestPureJSONPostAction(t *testing.T) { "completed_on": "2021-01-01T00:00:00Z", "plan_id": 5 }`) - req, _ := http.NewRequest("POST", "/actions", bytes.NewBuffer(data)) + req, _ := http.NewRequest("POST", "/", bytes.NewBuffer(data)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() @@ -51,9 +52,9 @@ func TestPureJSONPostAction(t *testing.T) { "completed_chunks": 2, "completed_on": "2021-01-01T00:00:00Z", "plan_id": 5, - "action_id": 0 + "action_id": 2 }, - "id": 0 + "id": 2 }` assert.JSONEq(expected, rr.Body.String()) contentType := rr.Header().Get("Content-Type") @@ -65,14 +66,15 @@ func TestExtraFieldActionPostJSON(t *testing.T) { assert := assert.New(t) planDate, _ := time.Parse("2006-01-02", "2021-01-01") p := &models.Plan{PlanID: 6, PlanDate: &planDate} - m := getModel([]*models.Plan{p}, []*models.Action{}) - router := routes.NewRouter(m) + m := getEmptyModel() + m.AddPlan(p, 3) + router := routes.NewActionRouter(m) data := []byte(`{ "completed_on": "2021-01-01T00:00:00Z", "plan_id": 5, "sabotage": "omg" }`) - req, _ := http.NewRequest("POST", "/actions", bytes.NewBuffer(data)) + req, _ := http.NewRequest("POST", "/", bytes.NewBuffer(data)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() @@ -91,10 +93,11 @@ func TestEmptyBodyActionPost(t *testing.T) { assert := assert.New(t) planDate, _ := time.Parse("2006-01-02", "2021-01-01") p := &models.Plan{PlanID: 6, PlanDate: &planDate} - m := getModel([]*models.Plan{p}, []*models.Action{}) - router := routes.NewRouter(m) + m := getEmptyModel() + m.AddPlan(p, 3) + router := routes.NewActionRouter(m) data := []byte(``) - req, _ := http.NewRequest("POST", "/actions", bytes.NewBuffer(data)) + req, _ := http.NewRequest("POST", "/", bytes.NewBuffer(data)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() @@ -114,14 +117,15 @@ func TestTwoBodyActionPost(t *testing.T) { assert := assert.New(t) planDate, _ := time.Parse("2006-01-02", "2021-01-01") p := &models.Plan{PlanID: 6, PlanDate: &planDate} - m := getModel([]*models.Plan{p}, []*models.Action{}) - router := routes.NewRouter(m) + m := getEmptyModel() + m.AddPlan(p, 3) + router := routes.NewActionRouter(m) data := []byte(`{ "plan_id": 5 }, { "plan_id": 6 }`) - req, _ := http.NewRequest("POST", "/actions", bytes.NewBuffer(data)) + req, _ := http.NewRequest("POST", "/", bytes.NewBuffer(data)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() @@ -142,10 +146,10 @@ func TestErrorCreateAction(t *testing.T) { m := getErrorModel("Model always errors") - router := routes.NewRouter(m) + router := routes.NewActionRouter(m) a := &models.Action{PlanID: 6} data, _ := json.Marshal(a) - req, _ := http.NewRequest("POST", "/actions", bytes.NewBuffer(data)) + req, _ := http.NewRequest("POST", "/", bytes.NewBuffer(data)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() @@ -167,10 +171,10 @@ func TestErrorOnRetrieveCreateAction(t *testing.T) { m := getErrorOnGetModel("Model always errors") - router := routes.NewRouter(m) + router := routes.NewActionRouter(m) a := &models.Action{PlanID: 6} data, _ := json.Marshal(a) - req, _ := http.NewRequest("POST", "/actions", bytes.NewBuffer(data)) + req, _ := http.NewRequest("POST", "/", bytes.NewBuffer(data)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() @@ -190,11 +194,12 @@ func TestErrorWriterCreateAction(t *testing.T) { assert := assert.New(t) a := &models.Action{PlanID: 6} - m := getModel([]*models.Plan{}, []*models.Action{a}) + m := getEmptyModel() + m.AddAction(a) - router := routes.NewRouter(m) + router := routes.NewActionRouter(m) data, _ := json.Marshal(a) - req, _ := http.NewRequest("POST", "/actions", bytes.NewBuffer(data)) + req, _ := http.NewRequest("POST", "/", bytes.NewBuffer(data)) req.Header.Set("Content-Type", "application/json") rr := NewBadWriter() diff --git a/routes/post_plan_test.go b/routes/post_plan_test.go index 8a194e2..699c948 100644 --- a/routes/post_plan_test.go +++ b/routes/post_plan_test.go @@ -17,11 +17,13 @@ func TestCreatePlanRoute(t *testing.T) { // set up assert := assert.New(t) planDate, _ := time.Parse("2006-01-02", "2021-01-01") - p := &models.Plan{PlanID: 6, PlanDate: &planDate} - m := getModel([]*models.Plan{p}, []*models.Action{}) - router := routes.NewRouter(m) + userID := 3 + p := &models.Plan{PlanID: 6, PlanDate: &planDate, UserID: int64(userID)} + m := getEmptyModel() + m.AddPlan(p, userID) + router := routes.NewPlanRouter(m) data, _ := json.Marshal(p) - req, _ := http.NewRequest("POST", "/plans", bytes.NewBuffer(data)) + req, _ := http.NewRequestWithContext(sampleContext, "POST", "/", bytes.NewBuffer(data)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() @@ -34,10 +36,11 @@ func TestCreatePlanRoute(t *testing.T) { // We pass in the date as a time.time so it makes sense that it comes back with a midnight timestamp. expected := `{ "created_plan": { - "plan_id": 6, - "plan_date": "2021-01-01T00:00:00Z" + "plan_id": 2, + "plan_date": "2021-01-01T00:00:00Z", + "user_id": 3 }, - "id": 0 + "id": 2 }` assert.JSONEq(expected, rr.Body.String()) contentType := rr.Header().Get("Content-Type") @@ -48,14 +51,17 @@ func TestPureJSON(t *testing.T) { // set up assert := assert.New(t) planDate, _ := time.Parse("2006-01-02", "2021-01-01") - p := &models.Plan{PlanID: 6, PlanDate: &planDate} - m := getModel([]*models.Plan{p}, []*models.Action{}) - router := routes.NewRouter(m) + userID := 3 + p := &models.Plan{PlanID: 1, PlanDate: &planDate, UserID: int64(userID)} + m := getEmptyModel() + m.AddPlan(p, userID) + router := routes.NewPlanRouter(m) data := []byte(`{ "plan_date": "2021-01-01T00:00:00Z", - "plan_id": 5 + "plan_id": 1, + "user_id": 3 }`) - req, _ := http.NewRequest("POST", "/plans", bytes.NewBuffer(data)) + req, _ := http.NewRequestWithContext(sampleContext, "POST", "/", bytes.NewBuffer(data)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() @@ -68,10 +74,11 @@ func TestPureJSON(t *testing.T) { // We pass in the date as a time.time so it makes sense that it comes back with a midnight timestamp. expected := `{ "created_plan": { - "plan_id": 6, + "plan_id": 2, + "user_id": 3, "plan_date": "2021-01-01T00:00:00Z" }, - "id": 0 + "id": 2 }` assert.JSONEq(expected, rr.Body.String()) contentType := rr.Header().Get("Content-Type") @@ -82,15 +89,17 @@ func TestExtraFieldJSON(t *testing.T) { // set up assert := assert.New(t) planDate, _ := time.Parse("2006-01-02", "2021-01-01") - p := &models.Plan{PlanID: 6, PlanDate: &planDate} - m := getModel([]*models.Plan{p}, []*models.Action{}) - router := routes.NewRouter(m) + userID := 3 + p := &models.Plan{PlanID: 6, PlanDate: &planDate, UserID: int64(userID)} + m := getEmptyModel() + m.AddPlan(p, userID) + router := routes.NewPlanRouter(m) data := []byte(`{ "plan_date": "2021-01-01T00:00:00Z", "plan_id": 5, "plan_sabotage": "omg" }`) - req, _ := http.NewRequest("POST", "/plans", bytes.NewBuffer(data)) + req, _ := http.NewRequestWithContext(sampleContext, "POST", "/", bytes.NewBuffer(data)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() @@ -108,11 +117,13 @@ func TestEmptyBody(t *testing.T) { // set up assert := assert.New(t) planDate, _ := time.Parse("2006-01-02", "2021-01-01") + userID := 3 p := &models.Plan{PlanID: 6, PlanDate: &planDate} - m := getModel([]*models.Plan{p}, []*models.Action{}) - router := routes.NewRouter(m) + m := getEmptyModel() + m.AddPlan(p, userID) + router := routes.NewPlanRouter(m) data := []byte(``) - req, _ := http.NewRequest("POST", "/plans", bytes.NewBuffer(data)) + req, _ := http.NewRequestWithContext(sampleContext, "POST", "/", bytes.NewBuffer(data)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() @@ -132,8 +143,9 @@ func TestTwoBody(t *testing.T) { assert := assert.New(t) planDate, _ := time.Parse("2006-01-02", "2021-01-01") p := &models.Plan{PlanID: 6, PlanDate: &planDate} - m := getModel([]*models.Plan{p}, []*models.Action{}) - router := routes.NewRouter(m) + m := getEmptyModel() + m.AddPlan(p, 3) + router := routes.NewPlanRouter(m) data := []byte(`{ "plan_date": "2021-01-01T00:00:00Z", "plan_id": 5 @@ -141,7 +153,7 @@ func TestTwoBody(t *testing.T) { "plan_date": "2021-01-01T00:00:00Z", "plan_id": 6 }`) - req, _ := http.NewRequest("POST", "/plans", bytes.NewBuffer(data)) + req, _ := http.NewRequestWithContext(sampleContext, "POST", "/", bytes.NewBuffer(data)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() @@ -162,11 +174,11 @@ func TestErrorCreatePlan(t *testing.T) { m := getErrorModel("Model always errors") - router := routes.NewRouter(m) + router := routes.NewPlanRouter(m) planDate, _ := time.Parse("2006-01-02", "2021-01-01") p := &models.Plan{PlanID: 6, PlanDate: &planDate} data, _ := json.Marshal(p) - req, _ := http.NewRequest("POST", "/plans", bytes.NewBuffer(data)) + req, _ := http.NewRequestWithContext(sampleContext, "POST", "/", bytes.NewBuffer(data)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() @@ -188,11 +200,11 @@ func TestErrorOnRetrieveCreatePlan(t *testing.T) { m := getErrorOnGetModel("Model always errors") - router := routes.NewRouter(m) + router := routes.NewPlanRouter(m) planDate, _ := time.Parse("2006-01-02", "2021-01-01") p := &models.Plan{PlanID: 6, PlanDate: &planDate} data, _ := json.Marshal(p) - req, _ := http.NewRequest("POST", "/plans", bytes.NewBuffer(data)) + req, _ := http.NewRequestWithContext(sampleContext, "POST", "/", bytes.NewBuffer(data)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() @@ -213,11 +225,11 @@ func TestErrorWriterCreatePlan(t *testing.T) { planDate, _ := time.Parse("2006-01-02", "2021-01-01") p := &models.Plan{PlanID: 6, PlanDate: &planDate} - m := getModel([]*models.Plan{p}, []*models.Action{}) - - router := routes.NewRouter(m) + m := getEmptyModel() + m.AddPlan(p, 3) + router := routes.NewPlanRouter(m) data, _ := json.Marshal(p) - req, _ := http.NewRequest("POST", "/plans", bytes.NewBuffer(data)) + req, _ := http.NewRequestWithContext(sampleContext, "POST", "/", bytes.NewBuffer(data)) req.Header.Set("Content-Type", "application/json") rr := NewBadWriter() diff --git a/routes/put_action_test.go b/routes/put_action_test.go index 5dc03b4..3599f1d 100644 --- a/routes/put_action_test.go +++ b/routes/put_action_test.go @@ -20,12 +20,13 @@ func TestPureJSONPutAction(t *testing.T) { a := &models.Action{ PlanID: 5, CompletedOn: &compOn, - EstimatedChunks: 3, - CompletedChunks: 2, - ActionDescription: "here's an action", + EstimatedChunks: 1, + CompletedChunks: 1, + ActionDescription: "hn", } - m := getModel([]*models.Plan{}, []*models.Action{a}) - router := routes.NewRouter(m) + m := getEmptyModel() + m.AddAction(a) + router := routes.NewActionRouter(m) data := []byte(`{ "action_description": "here's an action", "estimated_chunks": 3, @@ -33,7 +34,7 @@ func TestPureJSONPutAction(t *testing.T) { "completed_on": "2021-01-01T00:00:00Z", "plan_id": 5 }`) - req, _ := http.NewRequest("PUT", "/actions/0", bytes.NewBuffer(data)) + req, _ := http.NewRequest("PUT", "/1", bytes.NewBuffer(data)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() @@ -51,9 +52,9 @@ func TestPureJSONPutAction(t *testing.T) { "completed_chunks": 2, "completed_on": "2021-01-01T00:00:00Z", "plan_id": 5, - "action_id": 0 + "action_id": 1 }, - "id": 0 + "id": 1 }` assert.JSONEq(expected, rr.Body.String()) contentType := rr.Header().Get("Content-Type") @@ -63,16 +64,14 @@ func TestPureJSONPutAction(t *testing.T) { func TestExtraFieldActionPutJSON(t *testing.T) { // set up assert := assert.New(t) - planDate, _ := time.Parse("2006-01-02", "2021-01-01") - p := &models.Plan{PlanID: 6, PlanDate: &planDate} - m := getModel([]*models.Plan{p}, []*models.Action{}) - router := routes.NewRouter(m) + m := getEmptyModel() + router := routes.NewActionRouter(m) data := []byte(`{ "completed_on": "2021-01-01T00:00:00Z", "plan_id": 5, "sabotage": "omg" }`) - req, _ := http.NewRequest("PUT", "/actions/1", bytes.NewBuffer(data)) + req, _ := http.NewRequest("PUT", "/1", bytes.NewBuffer(data)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() @@ -89,12 +88,10 @@ func TestExtraFieldActionPutJSON(t *testing.T) { func TestEmptyBodyActionPut(t *testing.T) { // set up assert := assert.New(t) - planDate, _ := time.Parse("2006-01-02", "2021-01-01") - p := &models.Plan{PlanID: 6, PlanDate: &planDate} - m := getModel([]*models.Plan{p}, []*models.Action{}) - router := routes.NewRouter(m) + m := getEmptyModel() + router := routes.NewActionRouter(m) data := []byte(``) - req, _ := http.NewRequest("PUT", "/actions/1", bytes.NewBuffer(data)) + req, _ := http.NewRequest("PUT", "/1", bytes.NewBuffer(data)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() @@ -112,16 +109,14 @@ func TestEmptyBodyActionPut(t *testing.T) { func TestTwoBodyActionPut(t *testing.T) { // set up assert := assert.New(t) - planDate, _ := time.Parse("2006-01-02", "2021-01-01") - p := &models.Plan{PlanID: 6, PlanDate: &planDate} - m := getModel([]*models.Plan{p}, []*models.Action{}) - router := routes.NewRouter(m) + m := getEmptyModel() + router := routes.NewActionRouter(m) data := []byte(`{ "plan_id": 5 }, { "plan_id": 6 }`) - req, _ := http.NewRequest("PUT", "/actions/1", bytes.NewBuffer(data)) + req, _ := http.NewRequest("PUT", "/1", bytes.NewBuffer(data)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() @@ -139,16 +134,14 @@ func TestTwoBodyActionPut(t *testing.T) { func TestBadActionIDPut(t *testing.T) { // set up assert := assert.New(t) - planDate, _ := time.Parse("2006-01-02", "2021-01-01") - p := &models.Plan{PlanID: 6, PlanDate: &planDate} - m := getModel([]*models.Plan{p}, []*models.Action{}) - router := routes.NewRouter(m) + m := getEmptyModel() + router := routes.NewActionRouter(m) data := []byte(`{ "plan_id": 5 }, { "plan_id": 6 }`) - req, _ := http.NewRequest("PUT", "/actions/text", bytes.NewBuffer(data)) + req, _ := http.NewRequest("PUT", "/text", bytes.NewBuffer(data)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() @@ -169,10 +162,10 @@ func TestErrorUpdateAction(t *testing.T) { m := getErrorModel("Model always errors") - router := routes.NewRouter(m) + router := routes.NewActionRouter(m) a := &models.Action{PlanID: 6} data, _ := json.Marshal(a) - req, _ := http.NewRequest("PUT", "/actions/1", bytes.NewBuffer(data)) + req, _ := http.NewRequest("PUT", "/1", bytes.NewBuffer(data)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() @@ -194,10 +187,10 @@ func TestErrorOnRetrieveUpdateAction(t *testing.T) { m := getErrorOnGetModel("Model always errors") - router := routes.NewRouter(m) + router := routes.NewActionRouter(m) a := &models.Action{PlanID: 6} data, _ := json.Marshal(a) - req, _ := http.NewRequest("PUT", "/actions/1", bytes.NewBuffer(data)) + req, _ := http.NewRequest("PUT", "/1", bytes.NewBuffer(data)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() @@ -217,11 +210,12 @@ func TestErrorWriterUpdateAction(t *testing.T) { assert := assert.New(t) a := &models.Action{PlanID: 6} - m := getModel([]*models.Plan{}, []*models.Action{a}) + m := getEmptyModel() + m.AddAction(a) - router := routes.NewRouter(m) + router := routes.NewActionRouter(m) data, _ := json.Marshal(a) - req, _ := http.NewRequest("PUT", "/actions/1", bytes.NewBuffer(data)) + req, _ := http.NewRequest("PUT", "/1", bytes.NewBuffer(data)) req.Header.Set("Content-Type", "application/json") rr := NewBadWriter() diff --git a/routes/route_model_test.go b/routes/route_model_test.go index d0d98f8..1f50614 100644 --- a/routes/route_model_test.go +++ b/routes/route_model_test.go @@ -1,170 +1,22 @@ package routes_test import ( - "fmt" "gitea.deepak.science/deepak/gogmagog/models" + "gitea.deepak.science/deepak/gogmagog/store" ) -type multiStore struct { - actions []*models.Action - plans []*models.Plan -} - -type testNotFoundError struct { - error -} - -func (t *testNotFoundError) NotFound() bool { - return true -} - -func (ms *multiStore) SelectActions() ([]*models.Action, error) { - return ms.actions, nil -} - -func (ms *multiStore) SelectActionByID(id int) (*models.Action, error) { - if len(ms.actions) < 1 { - err := &testNotFoundError{fmt.Errorf("too small")} - return nil, err - } - return ms.actions[0], nil -} - -func (ms *multiStore) InsertAction(action *models.Action) (int, error) { - return int(action.ActionID), nil -} - -func (ms *multiStore) UpdateAction(action *models.Action) error { - return nil -} - -func (ms *multiStore) SelectPlans() ([]*models.Plan, error) { - return ms.plans, nil -} - -func (ms *multiStore) SelectPlanByID(id int) (*models.Plan, error) { - if len(ms.plans) < 1 { - err := &testNotFoundError{fmt.Errorf("too small")} - return nil, err - } - return ms.plans[0], nil -} - -func (ms *multiStore) InsertPlan(plan *models.Plan) (int, error) { - return int(plan.PlanID), nil -} - -func (ms *multiStore) SelectActionsByPlanID(plan *models.Plan) ([]*models.Action, error) { - return ms.actions, nil -} - -func (ms *multiStore) ConnectionLive() error { - return nil -} - func getEmptyModel() *models.Model { - ss := &multiStore{ - []*models.Action{}, - []*models.Plan{}, - } - m := models.New(ss) - return m -} -func getModel(plns []*models.Plan, acts []*models.Action) *models.Model { - ss := &multiStore{ - actions: acts, - plans: plns, - } - m := models.New(ss) + str, _ := store.GetInMemoryStore() + m := models.New(str) return m } -func (e *errorStore) SelectActions() ([]*models.Action, error) { - return nil, e.error -} - -func (e *errorStore) SelectActionByID(id int) (*models.Action, error) { - return nil, e.error -} - -func (e *errorStore) InsertAction(action *models.Action) (int, error) { - return 0, e.error -} - -func (e *errorStore) UpdateAction(action *models.Action) error { - return e.error -} - -func (e *errorStore) SelectPlans() ([]*models.Plan, error) { - return nil, e.error -} - -func (e *errorStore) SelectPlanByID(id int) (*models.Plan, error) { - return nil, e.error -} - -func (e *errorStore) InsertPlan(plan *models.Plan) (int, error) { - return 0, e.error -} - -func (e *errorStore) SelectActionsByPlanID(plan *models.Plan) ([]*models.Action, error) { - return nil, e.error -} - -func (e *errorStore) ConnectionLive() error { - return e.error -} - -type errorStore struct { - error error -} - -func getErrorModel(errorMsg string) *models.Model { - e := &errorStore{error: fmt.Errorf(errorMsg)} - return models.New(e) -} - -func (e *onlyCreateStore) SelectActions() ([]*models.Action, error) { - return nil, e.error -} - -func (e *onlyCreateStore) SelectActionByID(id int) (*models.Action, error) { - return nil, e.error -} - -func (e *onlyCreateStore) InsertAction(action *models.Action) (int, error) { - return int(action.ActionID), nil -} - -func (e *onlyCreateStore) UpdateAction(action *models.Action) error { - return nil -} - -func (e *onlyCreateStore) SelectPlans() ([]*models.Plan, error) { - return nil, e.error -} - -func (e *onlyCreateStore) SelectPlanByID(id int) (*models.Plan, error) { - return nil, e.error -} - -func (e *onlyCreateStore) InsertPlan(plan *models.Plan) (int, error) { - return int(plan.PlanID), nil -} - -func (e *onlyCreateStore) SelectActionsByPlanID(plan *models.Plan) ([]*models.Action, error) { - return nil, e.error -} - -func (e *onlyCreateStore) ConnectionLive() error { - return e.error -} - -type onlyCreateStore struct { - error error +func getErrorModel(message string) *models.Model { + str := store.GetErrorStore(message, true) + return models.New(str) } func getErrorOnGetModel(errorMsg string) *models.Model { - e := &onlyCreateStore{error: fmt.Errorf(errorMsg)} - return models.New(e) + str := store.GetErrorStore(errorMsg, false) + return models.New(str) } diff --git a/routes/routes.go b/routes/routes.go index 57e189e..dbcd846 100644 --- a/routes/routes.go +++ b/routes/routes.go @@ -3,17 +3,23 @@ package routes import ( "encoding/json" "gitea.deepak.science/deepak/gogmagog/models" + "gitea.deepak.science/deepak/gogmagog/tokens" "github.com/go-chi/chi" "net/http" ) // NewRouter returns a router powered by the provided model. -func NewRouter(m *models.Model) http.Handler { +func NewRouter(m *models.Model, tok tokens.Toker) http.Handler { router := chi.NewRouter() router.MethodNotAllowed(methodNotAllowedHandler) router.NotFound(notFoundHandler) - router.Mount("/plans", newPlanRouter(m)) - router.Mount("/actions", newActionRouter(m)) + router.Group(func(r chi.Router) { + r.Use(tok.Authenticator) + r.Mount("/actions", NewActionRouter(m)) + r.Mount("/plans", NewPlanRouter(m)) + r.Mount("/me", NewCurrentUserRouter(m)) + }) + router.Mount("/auth", NewAuthRouter(m, tok)) router.Mount("/health", newHealthRouter(m)) router.Get("/ping", ping) return router diff --git a/routes/routes_test.go b/routes/routes_test.go index 0086bd1..b0574b3 100644 --- a/routes/routes_test.go +++ b/routes/routes_test.go @@ -2,6 +2,7 @@ package routes_test import ( "gitea.deepak.science/deepak/gogmagog/routes" + "gitea.deepak.science/deepak/gogmagog/tokens" "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" @@ -13,7 +14,7 @@ func TestPingHandler(t *testing.T) { // set up assert := assert.New(t) m := getEmptyModel() - router := routes.NewRouter(m) + router := routes.NewRouter(m, tokens.New("whatever")) req, _ := http.NewRequest("GET", "/ping", nil) rr := httptest.NewRecorder() @@ -34,7 +35,7 @@ func TestPingPostHandler(t *testing.T) { // set up assert := assert.New(t) m := getEmptyModel() - router := routes.NewRouter(m) + router := routes.NewRouter(m, tokens.New("whatever")) req, _ := http.NewRequest("POST", "/ping", nil) rr := httptest.NewRecorder() @@ -53,7 +54,7 @@ func TestNotFoundHandler(t *testing.T) { // set up assert := assert.New(t) m := getEmptyModel() - router := routes.NewRouter(m) + router := routes.NewRouter(m, tokens.New("whatever")) req, _ := http.NewRequest("POST", "/null", nil) rr := httptest.NewRecorder() @@ -72,7 +73,7 @@ func TestPingError(t *testing.T) { // set up assert := assert.New(t) m := getEmptyModel() - router := routes.NewRouter(m) + router := routes.NewRouter(m, tokens.New("whatever")) req, _ := http.NewRequest("GET", "/ping", nil) rr := NewBadWriter() diff --git a/store/errorStore.go b/store/errorStore.go new file mode 100644 index 0000000..8cadfcf --- /dev/null +++ b/store/errorStore.go @@ -0,0 +1,79 @@ +package store + +import ( + "fmt" + "gitea.deepak.science/deepak/gogmagog/models" +) + +func (e *errorStore) SelectActions() ([]*models.Action, error) { + return nil, e.error +} + +func (e *errorStore) SelectActionByID(id int) (*models.Action, error) { + return nil, e.error +} + +func (e *errorStore) InsertAction(action *models.Action) (int, error) { + if e.errorOnInsert { + return 0, e.error + } + return 0, nil +} + +func (e *errorStore) UpdateAction(action *models.Action) error { + if e.errorOnInsert { + return e.error + } + return nil +} + +func (e *errorStore) SelectPlans(userID int) ([]*models.Plan, error) { + return nil, e.error +} + +func (e *errorStore) SelectPlanByID(id int, userID int) (*models.Plan, error) { + return nil, e.error +} + +func (e *errorStore) InsertPlan(plan *models.Plan, userID int) (int, error) { + if e.errorOnInsert { + return 0, e.error + } + return 0, nil +} + +func (e *errorStore) SelectActionsByPlanID(plan *models.Plan) ([]*models.Action, error) { + return nil, e.error +} + +func (e *errorStore) SelectUserByUsername(name string) (*models.User, error) { + return nil, e.error +} + +func (e *errorStore) InsertUser(user *models.User) (int, error) { + if e.errorOnInsert { + return 0, e.error + } + return 0, nil +} + +func (e *errorStore) ConnectionLive() error { + return e.error +} + +type errorStore struct { + error error + errorOnInsert bool +} + +// GetErrorStore returns a models.Store that always errors. This is useful for testing purposes. +func GetErrorStore(errorMsg string, errorOnInsert bool) models.Store { + e := &errorStore{error: fmt.Errorf(errorMsg), errorOnInsert: errorOnInsert} + return e +} + +// GetErrorStoreForError returns a models.Store that always errors with the provided error. +func GetErrorStoreForError(err error, errorOnInsert bool) models.Store { + e := &errorStore{error: err, errorOnInsert: errorOnInsert} + return e +} diff --git a/store/errorStore_test.go b/store/errorStore_test.go new file mode 100644 index 0000000..31c4322 --- /dev/null +++ b/store/errorStore_test.go @@ -0,0 +1,83 @@ +package store_test + +import ( + "fmt" + "gitea.deepak.science/deepak/gogmagog/models" + "gitea.deepak.science/deepak/gogmagog/store" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestErrorActionMethods(t *testing.T) { + assert := assert.New(t) + str := store.GetErrorStore("error message sample", true) + str2 := store.GetErrorStore("error message sample", false) + str3 := store.GetErrorStoreForError(fmt.Errorf("test error"), false) + + _, err := str.InsertAction(&models.Action{}) + assert.NotNil(err) + _, err = str2.InsertAction(&models.Action{}) + assert.Nil(err) + _, err = str3.InsertAction(&models.Action{}) + assert.Nil(err) + + _, err = str.SelectActionByID(8) + assert.NotNil(err) + _, err = str2.SelectActionByID(8) + assert.NotNil(err) + + _, err = str.SelectActions() + assert.NotNil(err) + + _, err = str.SelectActionsByPlanID(&models.Plan{}) + assert.NotNil(err) + + replacementAction := &models.Action{} + err = str.UpdateAction(replacementAction) + assert.NotNil(err) + err = str2.UpdateAction(replacementAction) + assert.Nil(err) + +} + +func TestErrorPlanMethods(t *testing.T) { + assert := assert.New(t) + str := store.GetErrorStore("sntahoeu", true) + str2 := store.GetErrorStore("sntahoeu", false) + + _, err := str.SelectPlans(3) + assert.NotNil(err) + + _, err = str.InsertPlan(&models.Plan{}, 3) + assert.NotNil(err) + _, err = str2.InsertPlan(&models.Plan{}, 3) + assert.Nil(err) + + _, err = str.SelectPlanByID(5, 3) + assert.NotNil(err) + +} + +func TestErrorLive(t *testing.T) { + assert := assert.New(t) + str := store.GetErrorStore("error", true) + + err := str.ConnectionLive() + assert.NotNil(err) +} + +func TestErrorUserMethods(t *testing.T) { + assert := assert.New(t) + str := store.GetErrorStore("error", true) + str2 := store.GetErrorStore("error", false) + + u := &models.User{} + + _, err := str.InsertUser(u) + assert.NotNil(err) + _, err = str2.InsertUser(u) + assert.Nil(err) + + _, err = str.SelectUserByUsername("snth") + assert.NotNil(err) +} diff --git a/store/inmemory.go b/store/inmemory.go new file mode 100644 index 0000000..26c49bf --- /dev/null +++ b/store/inmemory.go @@ -0,0 +1,113 @@ +package store + +import ( + "database/sql" + "gitea.deepak.science/deepak/gogmagog/models" +) + +type inMemoryStore struct { + actions []*models.Action + plans []*models.Plan + users []*models.User +} + +// GetInMemoryStore provides a purely in memory store, for testing purposes only, with no persistence. +func GetInMemoryStore() (models.Store, error) { + return &inMemoryStore{ + actions: make([]*models.Action, 0), + plans: make([]*models.Plan, 0), + users: make([]*models.User, 0), + }, nil +} + +func (store *inMemoryStore) SelectActions() ([]*models.Action, error) { + return store.actions, nil +} + +func (store *inMemoryStore) SelectActionsByPlanID(plan *models.Plan) ([]*models.Action, error) { + ret := make([]*models.Action, 0) + for _, action := range store.actions { + if int(plan.PlanID) == int(action.PlanID) { + ret = append(ret, action) + } + } + return ret, nil +} + +func (store *inMemoryStore) SelectActionByID(id int) (*models.Action, error) { + for _, action := range store.actions { + if id == int(action.ActionID) { + return action, nil + } + } + return nil, sql.ErrNoRows +} + +func (store *inMemoryStore) InsertAction(action *models.Action) (int, error) { + id := len(store.actions) + 1 + action.ActionID = int64(id) + store.actions = append(store.actions, action) + return id, nil +} + +func (store *inMemoryStore) UpdateAction(action *models.Action) error { + currentAction, err := store.SelectActionByID(int(action.ActionID)) + if err != nil { + return err + } + currentAction.ActionDescription = action.ActionDescription + currentAction.EstimatedChunks = action.EstimatedChunks + currentAction.CompletedChunks = action.CompletedChunks + currentAction.PlanID = action.PlanID + + return nil + +} + +func (store *inMemoryStore) SelectPlans(userID int) ([]*models.Plan, error) { + ret := make([]*models.Plan, 0) + for _, plan := range store.plans { + if int(plan.UserID) == userID { + ret = append(ret, plan) + } + } + return ret, nil +} + +func (store *inMemoryStore) SelectPlanByID(id int, userID int) (*models.Plan, error) { + for _, plan := range store.plans { + if id == int(plan.PlanID) && (userID == int(plan.UserID)) { + return plan, nil + } + } + return nil, sql.ErrNoRows +} + +func (store *inMemoryStore) InsertPlan(plan *models.Plan, userID int) (int, error) { + id := len(store.plans) + 1 + plan.PlanID = int64(id) + plan.UserID = int64(userID) + store.plans = append(store.plans, plan) + return id, nil +} + +func (store *inMemoryStore) ConnectionLive() error { + return nil +} + +func (store *inMemoryStore) SelectUserByUsername(username string) (*models.User, error) { + for _, user := range store.users { + if username == user.Username { + return user, nil + } + } + return nil, sql.ErrNoRows +} + +// inMemoryStore.InsertUser will not enforce unique usernames, which is ok. +func (store *inMemoryStore) InsertUser(user *models.User) (int, error) { + id := len(store.users) + 1 + user.UserID = int64(id) + store.users = append(store.users, user) + return id, nil +} diff --git a/store/inmemory_test.go b/store/inmemory_test.go new file mode 100644 index 0000000..7b3bf51 --- /dev/null +++ b/store/inmemory_test.go @@ -0,0 +1,104 @@ +package store_test + +import ( + "gitea.deepak.science/deepak/gogmagog/models" + "gitea.deepak.science/deepak/gogmagog/store" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestInMemoryActionMethods(t *testing.T) { + assert := assert.New(t) + str, _ := store.GetInMemoryStore() + + sampleplanid := 8 + + act := &models.Action{} + a2 := &models.Action{PlanID: sampleplanid} + + id, _ := str.InsertAction(act) + assert.EqualValues(1, id) + + receivedAction, err := str.SelectActionByID(id) + assert.Nil(err) + assert.EqualValues(act, receivedAction) + + allactions, err := str.SelectActions() + assert.Nil(err) + assert.EqualValues(1, len(allactions)) + + str.InsertAction(a2) + allactions, err = str.SelectActions() + assert.Nil(err) + assert.EqualValues(2, len(allactions)) + + planactions, err := str.SelectActionsByPlanID(&models.Plan{PlanID: int64(sampleplanid)}) + assert.Nil(err) + assert.EqualValues(1, len(planactions)) + assert.Equal(a2, planactions[0]) + + _, err = str.SelectActionByID(151) + assert.NotNil(err) + + sampleDescription := "snth" + replacementAction := &models.Action{ActionID: 1, ActionDescription: sampleDescription} + err = str.UpdateAction(replacementAction) + assert.Nil(err) + assert.Equal(sampleDescription, act.ActionDescription) + + replacementAction = &models.Action{ActionID: 1235122, ActionDescription: sampleDescription} + err = str.UpdateAction(replacementAction) + assert.NotNil(err) + +} + +func TestInMemoryPlanMethods(t *testing.T) { + assert := assert.New(t) + str, _ := store.GetInMemoryStore() + userID := 1 + p := &models.Plan{} + plans, err := str.SelectPlans(userID) + + assert.Nil(err) + assert.EqualValues(0, len(plans)) + + id, err := str.InsertPlan(p, userID) + plans, err = str.SelectPlans(userID) + + assert.Nil(err) + assert.EqualValues(1, len(plans)) + + retrievedPlan, err := str.SelectPlanByID(id, userID) + assert.Nil(err) + assert.Equal(retrievedPlan, p) + + _, err = str.SelectPlanByID(135135, userID) + assert.NotNil(err) +} + +func TestLive(t *testing.T) { + assert := assert.New(t) + str, _ := store.GetInMemoryStore() + + err := str.ConnectionLive() + assert.Nil(err) +} + +func TestInMemoryUserMethods(t *testing.T) { + assert := assert.New(t) + str, _ := store.GetInMemoryStore() + + uname := "hiimauser" + + u := &models.User{Username: uname} + + id, err := str.InsertUser(u) + assert.Nil(err) + + retrievedUser, err := str.SelectUserByUsername(uname) + assert.Nil(err) + assert.EqualValues(id, retrievedUser.UserID) + + _, err = str.SelectUserByUsername("bad username") + assert.NotNil(err) +} diff --git a/store/migrations/000001_create_action_table.down.sql b/store/migrations/000001_create_action_table.down.sql index 0b978f2..6a2c85b 100644 --- a/store/migrations/000001_create_action_table.down.sql +++ b/store/migrations/000001_create_action_table.down.sql @@ -1,4 +1,5 @@ DROP TABLE IF EXISTS actions; DROP TABLE IF EXISTS plans; +DROP TABLE IF EXISTS users; DROP FUNCTION IF EXISTS trigger_set_timestamp; diff --git a/store/migrations/000001_create_action_table.up.sql b/store/migrations/000001_create_action_table.up.sql index b00d482..7c13be1 100644 --- a/store/migrations/000001_create_action_table.up.sql +++ b/store/migrations/000001_create_action_table.up.sql @@ -1,6 +1,16 @@ +CREATE TABLE IF NOT EXISTS users( + user_id serial PRIMARY KEY, + username VARCHAR(50) NOT NULL UNIQUE, + display_name VARCHAR (100) NOT NULL, + password bytea, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL, + updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL +); + CREATE TABLE IF NOT EXISTS plans( plan_id serial PRIMARY KEY, plan_date DATE NOT NULL, + user_id int REFERENCES users(user_id), created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL, updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL ); @@ -16,7 +26,6 @@ CREATE TABLE IF NOT EXISTS actions( plan_id int REFERENCES plans(plan_id) ); - CREATE OR REPLACE FUNCTION trigger_set_timestamp() RETURNS TRIGGER AS $set_updated$ BEGIN @@ -34,3 +43,8 @@ CREATE TRIGGER set_updated BEFORE UPDATE ON plans FOR EACH ROW EXECUTE PROCEDURE trigger_set_timestamp(); + +CREATE TRIGGER set_updated +BEFORE UPDATE ON users +FOR EACH ROW +EXECUTE PROCEDURE trigger_set_timestamp(); diff --git a/store/postgres.go b/store/postgres.go index a443f40..667b4d7 100644 --- a/store/postgres.go +++ b/store/postgres.go @@ -97,29 +97,30 @@ func (store *postgresStore) UpdateAction(action *models.Action) error { } -func (store *postgresStore) SelectPlans() ([]*models.Plan, error) { +func (store *postgresStore) SelectPlans(userID int) ([]*models.Plan, error) { + queryString := store.db.Rebind("SELECT plan_id, plan_date, user_id FROM plans WHERE user_id = ?") plans := make([]*models.Plan, 0) - err := store.db.Select(&plans, "SELECT plan_id, plan_date FROM plans") + err := store.db.Select(&plans, queryString, userID) if err != nil { return nil, err } return plans, nil } -func (store *postgresStore) SelectPlanByID(id int) (*models.Plan, error) { +func (store *postgresStore) SelectPlanByID(id int, userID int) (*models.Plan, error) { plan := models.Plan{} - err := store.db.Get(&plan, store.db.Rebind("SELECT plan_id, plan_date FROM plans WHERE plan_id = ?"), id) + err := store.db.Get(&plan, store.db.Rebind("SELECT plan_id, plan_date, user_id FROM plans WHERE plan_id = ? AND user_id = ?"), id, userID) if err != nil { return nil, err } return &plan, nil } -func (store *postgresStore) InsertPlan(plan *models.Plan) (int, error) { - queryString := store.db.Rebind("INSERT INTO plans (plan_date) VALUES (?) RETURNING plan_id") +func (store *postgresStore) InsertPlan(plan *models.Plan, userID int) (int, error) { + queryString := store.db.Rebind("INSERT INTO plans (plan_date, user_id) VALUES (?, ?) RETURNING plan_id") tx := store.db.MustBegin() var id int - err := tx.Get(&id, queryString, plan.PlanDate) + err := tx.Get(&id, queryString, plan.PlanDate, userID) if err != nil { tx.Rollback() return -1, err @@ -134,3 +135,28 @@ func (store *postgresStore) InsertPlan(plan *models.Plan) (int, error) { func (store *postgresStore) ConnectionLive() error { return store.db.Ping() } + +func (store *postgresStore) SelectUserByUsername(username string) (*models.User, error) { + user := models.User{} + err := store.db.Get(&user, store.db.Rebind("SELECT user_id, username, display_name, password FROM users WHERE username = ?"), username) + if err != nil { + return nil, err + } + return &user, nil +} + +func (store *postgresStore) InsertUser(user *models.User) (int, error) { + queryString := store.db.Rebind("INSERT INTO users (username, display_name, password) VALUES (?, ?, ?) RETURNING user_id") + tx := store.db.MustBegin() + var id int + err := tx.Get(&id, queryString, user.Username, user.DisplayName, user.Password) + if err != nil { + tx.Rollback() + return -1, err + } + err = tx.Commit() + if err != nil { + return -1, err + } + return id, nil +} diff --git a/store/postgres_plan_test.go b/store/postgres_plan_test.go new file mode 100644 index 0000000..fdf1ed8 --- /dev/null +++ b/store/postgres_plan_test.go @@ -0,0 +1,186 @@ +package store_test + +import ( + "fmt" + "gitea.deepak.science/deepak/gogmagog/models" + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestSelectPlans(t *testing.T) { + assert := assert.New(t) + + currentTime := time.Now() + idToUse := 1 + userIDToUse := 2 + + str, mock := getDbMock(t) + + rows := sqlmock.NewRows([]string{"plan_id", "plan_date", "user_id"}).AddRow(idToUse, currentTime, userIDToUse) + mock.ExpectQuery(`^SELECT plan_id, plan_date, user_id FROM plans WHERE user_id = \$1`). + WithArgs(userIDToUse). + WillReturnRows(rows) + + plans, err := str.SelectPlans(userIDToUse) + assert.Nil(err) + assert.Equal(1, len(plans)) + plan := plans[0] + assert.EqualValues(idToUse, plan.PlanID) + assert.Equal(currentTime, *plan.PlanDate) + assert.EqualValues(userIDToUse, plan.UserID) + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %s", err) + } +} + +func TestSelectPlanByID(t *testing.T) { + assert := assert.New(t) + + currentTime := time.Now() + idToUse := 1 + userIDToUse := 2 + + str, mock := getDbMock(t) + + rows := sqlmock.NewRows([]string{"plan_id", "plan_date", "user_id"}).AddRow(idToUse, currentTime, userIDToUse) + mock.ExpectQuery(`^SELECT plan_id, plan_date, user_id FROM plans WHERE plan_id = \$1 AND user_id = \$2$`). + WithArgs(idToUse, userIDToUse). + WillReturnRows(rows) + + plan, err := str.SelectPlanByID(idToUse, userIDToUse) + assert.Nil(err) + assert.EqualValues(idToUse, plan.PlanID) + assert.Equal(currentTime, *plan.PlanDate) + assert.EqualValues(userIDToUse, plan.UserID) + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %s", err) + } +} + +func TestInsertPlan(t *testing.T) { + // setup + assert := assert.New(t) + + str, mock := getDbMock(t) + planDate, _ := time.Parse("2006-01-02", "2021-01-01") + userID := 2 + badUserID := 7 + + plan := &models.Plan{PlanDate: &planDate, UserID: int64(badUserID)} + + idToUse := 8 + + rows := sqlmock.NewRows([]string{"plan_id"}).AddRow(8) + + mock.ExpectBegin() + mock.ExpectQuery(`^INSERT INTO plans \(plan_date, user_id\) VALUES \(\$1, \$2\) RETURNING plan_id$`). + WithArgs(planDate, userID). + WillReturnRows(rows) + mock.ExpectCommit() + + // function under test + insertedId, err := str.InsertPlan(plan, userID) + // check results + assert.Nil(err) + assert.EqualValues(idToUse, insertedId) + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %s", err) + } + +} + +func TestInsertPlanErr(t *testing.T) { + // setup + assert := assert.New(t) + + str, mock := getDbMock(t) + planDate, _ := time.Parse("2006-01-02", "2021-01-01") + userID := 2 + badUserID := 7 + plan := &models.Plan{PlanDate: &planDate, UserID: int64(badUserID)} + + mock.ExpectBegin() + mock.ExpectQuery(`^INSERT INTO plans \(plan_date, user_id\) VALUES \(\$1, \$2\) RETURNING plan_id$`). + WithArgs(planDate, userID). + WillReturnError(fmt.Errorf("example error")) + mock.ExpectRollback() + + // function under test + _, err := str.InsertPlan(plan, userID) + // check results + assert.NotNil(err) + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %s", err) + } + +} + +func TestInsertPlanCommitErr(t *testing.T) { + // setup + assert := assert.New(t) + + str, mock := getDbMock(t) + planDate, _ := time.Parse("2006-01-02", "2021-01-01") + userID := 2 + plan := &models.Plan{PlanDate: &planDate, UserID: int64(userID)} + idToUse := 8 + + rows := sqlmock.NewRows([]string{"plan_id"}).AddRow(idToUse) + + mock.ExpectBegin() + mock.ExpectQuery(`^INSERT INTO plans \(plan_date, user_id\) VALUES \(\$1, \$2\) RETURNING plan_id$`). + WithArgs(planDate, userID). + WillReturnRows(rows) + mock.ExpectCommit().WillReturnError(fmt.Errorf("another error example")) + + // function under test + _, err := str.InsertPlan(plan, userID) + // check results + assert.NotNil(err) + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %s", err) + } + +} + +func TestErrPlanByID(t *testing.T) { + assert := assert.New(t) + + idToUse := 1 + + str, mock := getDbMock(t) + + mock.ExpectQuery(`^SELECT plan_id, plan_date, user_id FROM plans WHERE plan_id = \$1 AND user_id = \$2$`). + WithArgs(idToUse, 8). + WillReturnError(fmt.Errorf("example error")) + + plan, err := str.SelectPlanByID(idToUse, 8) + assert.NotNil(err) + assert.Nil(plan) + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %s", err) + } +} + +func TestErrPlans(t *testing.T) { + // set up tests + assert := assert.New(t) + str, mock := getDbMock(t) + + mock.ExpectQuery(`^SELECT plan_id, plan_date, user_id FROM plans WHERE user_id = \$1$`). + WithArgs(8). + WillReturnError(fmt.Errorf("example error")) + // function under test + plans, err := str.SelectPlans(8) + // test results + assert.Nil(plans) + assert.NotNil(err) + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %s", err) + } +} diff --git a/store/postgres_test.go b/store/postgres_test.go index 4fe3371..2a309aa 100644 --- a/store/postgres_test.go +++ b/store/postgres_test.go @@ -24,149 +24,6 @@ func getDbMock(t *testing.T) (models.Store, sqlmock.Sqlmock) { return str, mock } -func TestSelectPlans(t *testing.T) { - assert := assert.New(t) - - currentTime := time.Now() - idToUse := 1 - - str, mock := getDbMock(t) - - rows := sqlmock.NewRows([]string{"plan_id", "plan_date"}).AddRow(idToUse, currentTime) - mock.ExpectQuery("^SELECT plan_id, plan_date FROM plans$").WillReturnRows(rows) - - plans, err := str.SelectPlans() - assert.Nil(err) - assert.Equal(1, len(plans)) - plan := plans[0] - assert.EqualValues(idToUse, plan.PlanID) - assert.Equal(currentTime, *plan.PlanDate) - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %s", err) - } -} - -func TestSelectPlanByID(t *testing.T) { - assert := assert.New(t) - - currentTime := time.Now() - idToUse := 1 - - str, mock := getDbMock(t) - - rows := sqlmock.NewRows([]string{"plan_id", "plan_date"}).AddRow(idToUse, currentTime) - mock.ExpectQuery("^SELECT plan_id, plan_date FROM plans WHERE plan_id = \\$1$").WithArgs(idToUse).WillReturnRows(rows) - - plan, err := str.SelectPlanByID(idToUse) - assert.Nil(err) - assert.EqualValues(idToUse, plan.PlanID) - assert.Equal(currentTime, *plan.PlanDate) - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %s", err) - } -} - -func TestInsertPlan(t *testing.T) { - // setup - assert := assert.New(t) - - str, mock := getDbMock(t) - planDate, _ := time.Parse("2006-01-02", "2021-01-01") - plan := &models.Plan{PlanDate: &planDate} - - idToUse := 8 - - rows := sqlmock.NewRows([]string{"plan_id"}).AddRow(8) - - mock.ExpectBegin() - mock.ExpectQuery("^INSERT INTO plans \\(plan_date\\) VALUES \\(\\$1\\) RETURNING plan_id$"). - WithArgs(planDate). - WillReturnRows(rows) - mock.ExpectCommit() - - // function under test - insertedId, err := str.InsertPlan(plan) - // check results - assert.Nil(err) - assert.EqualValues(idToUse, insertedId) - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %s", err) - } - -} - -func TestInsertPlanErr(t *testing.T) { - // setup - assert := assert.New(t) - - str, mock := getDbMock(t) - planDate, _ := time.Parse("2006-01-02", "2021-01-01") - plan := &models.Plan{PlanDate: &planDate} - - mock.ExpectBegin() - mock.ExpectQuery("^INSERT INTO plans \\(plan_date\\) VALUES \\(\\$1\\) RETURNING plan_id$"). - WithArgs(planDate). - WillReturnError(fmt.Errorf("example error")) - mock.ExpectRollback() - - // function under test - _, err := str.InsertPlan(plan) - // check results - assert.NotNil(err) - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %s", err) - } - -} - -func TestInsertPlanCommitErr(t *testing.T) { - // setup - assert := assert.New(t) - - str, mock := getDbMock(t) - planDate, _ := time.Parse("2006-01-02", "2021-01-01") - plan := &models.Plan{PlanDate: &planDate} - - idToUse := 8 - - rows := sqlmock.NewRows([]string{"plan_id"}).AddRow(idToUse) - - mock.ExpectBegin() - mock.ExpectQuery("^INSERT INTO plans \\(plan_date\\) VALUES \\(\\$1\\) RETURNING plan_id$"). - WithArgs(planDate). - WillReturnRows(rows) - mock.ExpectCommit().WillReturnError(fmt.Errorf("another error example")) - - // function under test - _, err := str.InsertPlan(plan) - // check results - assert.NotNil(err) - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %s", err) - } - -} - -func TestErrPlanByID(t *testing.T) { - assert := assert.New(t) - - idToUse := 1 - - str, mock := getDbMock(t) - - mock.ExpectQuery("^SELECT plan_id, plan_date FROM plans WHERE plan_id = \\$1$").WithArgs(idToUse).WillReturnError(fmt.Errorf("example error")) - - plan, err := str.SelectPlanByID(idToUse) - assert.NotNil(err) - assert.Nil(plan) - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %s", err) - } -} - func TestSelectActions(t *testing.T) { // set up test assert := assert.New(t) @@ -338,22 +195,6 @@ func TestSelectActionById(t *testing.T) { } } -func TestErrPlans(t *testing.T) { - // set up tests - assert := assert.New(t) - str, mock := getDbMock(t) - - mock.ExpectQuery("^SELECT plan_id, plan_date FROM plans$").WillReturnError(fmt.Errorf("example error")) - // function under test - plans, err := str.SelectPlans() - // test results - assert.Nil(plans) - assert.NotNil(err) - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %s", err) - } -} - func TestErrActions(t *testing.T) { // set up tests assert := assert.New(t) diff --git a/store/postgres_user_test.go b/store/postgres_user_test.go new file mode 100644 index 0000000..83a5d55 --- /dev/null +++ b/store/postgres_user_test.go @@ -0,0 +1,150 @@ +package store_test + +import ( + "fmt" + "gitea.deepak.science/deepak/gogmagog/models" + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestSelectUserByUsername(t *testing.T) { + // set up test + assert := assert.New(t) + + id := 1 + username := "test" + displayName := "Tom Est" + password := []byte("ABC€") + + str, mock := getDbMock(t) + + rows := sqlmock.NewRows([]string{ + "user_id", + "username", + "display_name", + "password", + }). + AddRow(id, username, displayName, password) + + mock.ExpectQuery(`^SELECT user_id, username, display_name, password FROM users WHERE username = \$1`). + WithArgs(username). + WillReturnRows(rows) + + // function under test + user, err := str.SelectUserByUsername(username) + + // test results + assert.Nil(err) + assert.EqualValues(id, user.UserID) + assert.Equal(username, user.Username) + assert.Equal(displayName, user.DisplayName) + assert.Equal(password, user.Password) + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %s", err) + } +} + +func TestErrUserByID(t *testing.T) { + assert := assert.New(t) + + str, mock := getDbMock(t) + username := "snth" + mock.ExpectQuery(`^SELECT user_id, username, display_name, password FROM users WHERE username = \$1`).WithArgs(username).WillReturnError(fmt.Errorf("example error")) + + user, err := str.SelectUserByUsername(username) + assert.NotNil(err) + assert.Nil(user) + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %s", err) + } +} + +func TestInsertUser(t *testing.T) { + // setup + assert := assert.New(t) + + str, mock := getDbMock(t) + username := "test" + displayName := "Tom Est" + password := []byte("ABC€") + usr := &models.User{Username: username, DisplayName: displayName, Password: password} + + idToUse := 8 + + rows := sqlmock.NewRows([]string{"user_id"}).AddRow(8) + + mock.ExpectBegin() + mock.ExpectQuery(`^INSERT INTO users \(username, display_name, password\) VALUES \(\$1, \$2, \$3\) RETURNING user_id$`). + WithArgs(username, displayName, password). + WillReturnRows(rows) + mock.ExpectCommit() + + // function under test + insertedId, err := str.InsertUser(usr) + // check results + assert.Nil(err) + assert.EqualValues(idToUse, insertedId) + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %s", err) + } + +} + +func TestInsertUserErr(t *testing.T) { + // setup + assert := assert.New(t) + + str, mock := getDbMock(t) + username := "test" + displayName := "Tom Est" + password := []byte("ABC€") + usr := &models.User{Username: username, DisplayName: displayName, Password: password} + + mock.ExpectBegin() + mock.ExpectQuery(`^INSERT INTO users \(username, display_name, password\) VALUES \(\$1, \$2, \$3\) RETURNING user_id$`). + WithArgs(username, displayName, password). + WillReturnError(fmt.Errorf("example error")) + mock.ExpectRollback() + + // function under test + _, err := str.InsertUser(usr) + // check results + assert.NotNil(err) + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %s", err) + } + +} + +func TestInsertUserCommitErr(t *testing.T) { + // setup + assert := assert.New(t) + + str, mock := getDbMock(t) + username := "test" + displayName := "Tom Est" + password := []byte("ABC€") + usr := &models.User{Username: username, DisplayName: displayName, Password: password} + + idToUse := 8 + + rows := sqlmock.NewRows([]string{"user_id"}).AddRow(idToUse) + + mock.ExpectBegin() + mock.ExpectQuery(`^INSERT INTO users \(username, display_name, password\) VALUES \(\$1, \$2, \$3\) RETURNING user_id$`). + WithArgs(username, displayName, password). + WillReturnRows(rows) + mock.ExpectCommit().WillReturnError(fmt.Errorf("another error example")) + + // function under test + _, err := str.InsertUser(usr) + // check results + assert.NotNil(err) + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %s", err) + } + +} diff --git a/store/store.go b/store/store.go index 97f87fb..3e1036a 100644 --- a/store/store.go +++ b/store/store.go @@ -59,6 +59,10 @@ func createPostgresDB(dbConf *config.DBConfig) (*sqlx.DB, error) { log.Print("Could not perform migration", err) return nil, err } + if dbConf.DropOnStart { + log.Print("Going down") + m.Down() + } if err := m.Up(); err != nil { if err == migrate.ErrNoChange { log.Print("No migration needed.") @@ -66,6 +70,8 @@ func createPostgresDB(dbConf *config.DBConfig) (*sqlx.DB, error) { log.Printf("An error occurred while syncing the database.. %v", err) return nil, err } + } else { + log.Print("Performed database migration") } return db, nil } diff --git a/tokens/deterministicToker.go b/tokens/deterministicToker.go new file mode 100644 index 0000000..96cd9ed --- /dev/null +++ b/tokens/deterministicToker.go @@ -0,0 +1,53 @@ +package tokens + +import ( + "context" + "encoding/json" + "gitea.deepak.science/deepak/gogmagog/models" + "log" + "net/http" +) + +type deterministicToker struct{} + +// GetDeterministicToker returns a zero security toker for testing purposes. +// Do not use in production. +func GetDeterministicToker() Toker { + return &deterministicToker{} +} + +func (d *deterministicToker) EncodeUser(user *models.UserNoPassword) string { + tok := &UserToken{ID: user.UserID, Username: user.Username} + ret, _ := json.Marshal(tok) + return string(ret) +} + +func (d *deterministicToker) DecodeTokenString(tokenString string) (*UserToken, error) { + var tok UserToken + err := json.Unmarshal([]byte(tokenString), &tok) + return &tok, err +} + +func (d *deterministicToker) Authenticator(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tokenString := TokenFromHeader(r) + if tokenString == "" { + log.Print("No valid token found") + unauthorized(w, r) + return + } + + userToken, err := d.DecodeTokenString(tokenString) + if err != nil { + log.Printf("Error while verifying token: %s", err) + unauthorized(w, r) + return + } + + log.Printf("Got user with ID: [%d]", userToken.ID) + ctx := context.WithValue(r.Context(), userIDCtxKey, userToken.ID) + ctx = context.WithValue(ctx, usernameCtxKey, userToken.Username) + // Authenticated + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} diff --git a/tokens/deterministic_toker_middleware_test.go b/tokens/deterministic_toker_middleware_test.go new file mode 100644 index 0000000..e3c6f03 --- /dev/null +++ b/tokens/deterministic_toker_middleware_test.go @@ -0,0 +1,80 @@ +package tokens_test + +import ( + "gitea.deepak.science/deepak/gogmagog/models" + "gitea.deepak.science/deepak/gogmagog/tokens" + "github.com/stretchr/testify/assert" + "log" + "net/http" + "net/http/httptest" + "testing" +) + +var dtMiddlewareURL string = "/" + +func dtRequestAuth(header string) *http.Request { + req, _ := http.NewRequest("GET", dtMiddlewareURL, nil) + req.Header.Add(authKey, header) + + return req +} + +func verifyingHandlerdt(t *testing.T, username string, userID int) http.Handler { + assert := assert.New(t) + toker := tokens.GetDeterministicToker() + dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + receivedID, _ := tokens.GetUserID(ctx) + receivedUsername, _ := tokens.GetUsername(ctx) + assert.EqualValues(userID, receivedID) + assert.Equal(username, receivedUsername) + }) + return toker.Authenticator(dummyHandler) +} + +func TestMiddlewareNoTokendt(t *testing.T) { + assert := assert.New(t) + + req := httptest.NewRequest(http.MethodGet, dtMiddlewareURL, nil) + rr := httptest.NewRecorder() + + middlewareHandler := verifyingHandlerdt(t, "", 0) + middlewareHandler.ServeHTTP(rr, req) + + status := rr.Code + assert.Equal(http.StatusUnauthorized, status) +} + +func TestMiddlewareBadTokendt(t *testing.T) { + assert := assert.New(t) + + req := mwRequestAuth("Bearer bad") + rr := httptest.NewRecorder() + + middlewareHandler := verifyingHandlerdt(t, "", 0) + middlewareHandler.ServeHTTP(rr, req) + + status := rr.Code + assert.Equal(http.StatusUnauthorized, status) +} + +func TestMiddlewareGoodTokendt(t *testing.T) { + assert := assert.New(t) + + idToUse := 3 + username := "username" + displayName := "display name" + user := &models.UserNoPassword{UserID: int64(idToUse), Username: username, DisplayName: displayName} + + toker := tokens.GetDeterministicToker() + validToken := toker.EncodeUser(user) + log.Print(validToken) + req := mwRequestAuth("Bearer " + validToken) + rr := httptest.NewRecorder() + + middlewareHandler := verifyingHandlerdt(t, username, idToUse) + middlewareHandler.ServeHTTP(rr, req) + + status := rr.Code + assert.Equal(http.StatusOK, status) +} diff --git a/tokens/middleware.go b/tokens/middleware.go new file mode 100644 index 0000000..46b971e --- /dev/null +++ b/tokens/middleware.go @@ -0,0 +1,88 @@ +package tokens + +import ( + "context" + "fmt" + "log" + "net/http" + "strings" +) + +type contextKey struct { + name string +} + +var userIDCtxKey = &contextKey{"UserID"} +var usernameCtxKey = &contextKey{"Username"} + +func unauthorized(w http.ResponseWriter, r *http.Request) { + code := http.StatusUnauthorized + http.Error(w, http.StatusText(code), code) +} + +// TokenFromHeader tries to retreive the token string from the +// "Authorization" reqeust header: "Authorization: BEARER T". +func TokenFromHeader(r *http.Request) string { + // Get token from authorization header. + bearer := r.Header.Get("Authorization") + if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" { + return bearer[7:] + } + return "" +} + +func (tok *jwtToker) Authenticator(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tokenString := TokenFromHeader(r) + if tokenString == "" { + log.Print("No valid token found") + unauthorized(w, r) + return + } + + userToken, err := tok.DecodeTokenString(tokenString) + if err != nil { + log.Printf("Error while verifying token: %s", err) + unauthorized(w, r) + return + } + + log.Printf("Got user with ID: [%d]", userToken.ID) + ctx := context.WithValue(r.Context(), userIDCtxKey, userToken.ID) + ctx = context.WithValue(ctx, usernameCtxKey, userToken.Username) + // Authenticated + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// GetUserID is a convenience method that gets the user ID from the context. +// I hate the fact that we're passing user ID on the context, but it is more +// idiomatic Go than any type shenanigans. +func GetUserID(ctx context.Context) (int, error) { + userID, ok := ctx.Value(userIDCtxKey).(int64) + if !ok { + return -1, fmt.Errorf("Could not parse user ID [%s] from context", ctx.Value(userIDCtxKey)) + + } + return int(userID), nil +} + +// SetUserID sets the username field on a context, necessary because the key is an unexported custom type. +func SetUserID(ctx context.Context, id int) context.Context { + return context.WithValue(ctx, userIDCtxKey, int64(id)) +} + +// GetUsername does something similar to GetUserID. +func GetUsername(ctx context.Context) (string, error) { + username, ok := ctx.Value(usernameCtxKey).(string) + if !ok { + return "", fmt.Errorf("Could not parse username [%s] from context", ctx.Value(usernameCtxKey)) + } + return username, nil +} + +// GetContextForUserValues is a test helper method that creates a context with user ID set. +func GetContextForUserValues(userID int, username string) context.Context { + ctx := context.WithValue(context.Background(), userIDCtxKey, int64(userID)) + return context.WithValue(ctx, usernameCtxKey, username) +} diff --git a/tokens/middleware_context_test.go b/tokens/middleware_context_test.go new file mode 100644 index 0000000..06e2847 --- /dev/null +++ b/tokens/middleware_context_test.go @@ -0,0 +1,49 @@ +package tokens_test + +import ( + "context" + "gitea.deepak.science/deepak/gogmagog/tokens" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestGoodContext(t *testing.T) { + assert := assert.New(t) + + idToUse := 3 + username := "username" + + ctx := tokens.GetContextForUserValues(idToUse, username) + + receivedID, err := tokens.GetUserID(ctx) + assert.Nil(err) + assert.EqualValues(idToUse, receivedID) + + receivedUsername, err := tokens.GetUsername(ctx) + assert.Nil(err) + assert.Equal(username, receivedUsername) + +} + +func TestBadContext(t *testing.T) { + assert := assert.New(t) + + ctx := context.Background() + + _, err := tokens.GetUserID(ctx) + assert.NotNil(err) + + _, err = tokens.GetUsername(ctx) + assert.NotNil(err) + +} + +func TestSetContext(t *testing.T) { + assert := assert.New(t) + + idToUse := 3 + ctx := tokens.SetUserID(context.Background(), 3) + receivedID, err := tokens.GetUserID(ctx) + assert.Nil(err) + assert.EqualValues(idToUse, receivedID) +} diff --git a/tokens/middleware_http_test.go b/tokens/middleware_http_test.go new file mode 100644 index 0000000..4ec0f17 --- /dev/null +++ b/tokens/middleware_http_test.go @@ -0,0 +1,78 @@ +package tokens_test + +import ( + "gitea.deepak.science/deepak/gogmagog/models" + "gitea.deepak.science/deepak/gogmagog/tokens" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "testing" +) + +var middlewareURL string = "/" + +func mwRequestAuth(header string) *http.Request { + req, _ := http.NewRequest("GET", middlewareURL, nil) + req.Header.Add(authKey, header) + + return req +} + +func verifyingHandler(t *testing.T, username string, userID int) http.Handler { + assert := assert.New(t) + toker := tokens.New("secret") + dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + receivedID, _ := tokens.GetUserID(ctx) + receivedUsername, _ := tokens.GetUsername(ctx) + assert.EqualValues(userID, receivedID) + assert.Equal(username, receivedUsername) + }) + return toker.Authenticator(dummyHandler) +} + +func TestMiddlewareNoToken(t *testing.T) { + assert := assert.New(t) + + req := httptest.NewRequest(http.MethodGet, middlewareURL, nil) + rr := httptest.NewRecorder() + + middlewareHandler := verifyingHandler(t, "", 0) + middlewareHandler.ServeHTTP(rr, req) + + status := rr.Code + assert.Equal(http.StatusUnauthorized, status) +} + +func TestMiddlewareBadToken(t *testing.T) { + assert := assert.New(t) + + req := mwRequestAuth("Bearer bad") + rr := httptest.NewRecorder() + + middlewareHandler := verifyingHandler(t, "", 0) + middlewareHandler.ServeHTTP(rr, req) + + status := rr.Code + assert.Equal(http.StatusUnauthorized, status) +} + +func TestMiddlewareGoodToken(t *testing.T) { + assert := assert.New(t) + + idToUse := 3 + username := "username" + displayName := "display name" + user := &models.UserNoPassword{UserID: int64(idToUse), Username: username, DisplayName: displayName} + + toker := tokens.New("secret") + validToken := toker.EncodeUser(user) + req := mwRequestAuth("Bearer " + validToken) + rr := httptest.NewRecorder() + + middlewareHandler := verifyingHandler(t, username, idToUse) + middlewareHandler.ServeHTTP(rr, req) + + status := rr.Code + assert.Equal(http.StatusOK, status) +} diff --git a/tokens/middleware_test.go b/tokens/middleware_test.go new file mode 100644 index 0000000..f2ba184 --- /dev/null +++ b/tokens/middleware_test.go @@ -0,0 +1,56 @@ +package tokens_test + +import ( + "gitea.deepak.science/deepak/gogmagog/tokens" + "github.com/stretchr/testify/assert" + "net/http" + "testing" +) + +var ( + url = "" + authKey = "Authorization" +) + +func requestWithAuth(header string) *http.Request { + req, _ := http.NewRequest("GET", url, nil) + req.Header.Add(authKey, header) + + return req +} + +func TestHeaderParseBasic(t *testing.T) { + assert := assert.New(t) + + header := "Bearer testing" + req := requestWithAuth(header) + + assert.Equal("testing", tokens.TokenFromHeader(req)) +} + +func TestHeaderParseNoSpace(t *testing.T) { + assert := assert.New(t) + + header := "Bearerxtesting" + req := requestWithAuth(header) + + assert.Equal("testing", tokens.TokenFromHeader(req)) +} + +func TestHeaderParseUnicode(t *testing.T) { + assert := assert.New(t) + + header := "Bearer 🌸" + req := requestWithAuth(header) + + assert.Equal("🌸", tokens.TokenFromHeader(req)) +} + +func TestHeaderParseMalformed(t *testing.T) { + assert := assert.New(t) + + header := "testing" + req := requestWithAuth(header) + + assert.Equal("", tokens.TokenFromHeader(req)) +} diff --git a/tokens/tokens.go b/tokens/tokens.go new file mode 100644 index 0000000..72ed0af --- /dev/null +++ b/tokens/tokens.go @@ -0,0 +1,86 @@ +package tokens + +import ( + "fmt" + "gitea.deepak.science/deepak/gogmagog/models" + "github.com/go-chi/jwtauth" + "github.com/lestrrat-go/jwx/jwt" + "net/http" + "time" +) + +// Toker represents a tokenizer, capable of encoding and verifying tokens. +type Toker interface { + EncodeUser(user *models.UserNoPassword) string + DecodeTokenString(tokenString string) (*UserToken, error) + Authenticator(http.Handler) http.Handler +} + +type jwtToker struct { + tokenAuth *jwtauth.JWTAuth +} + +// New returns a default Toker for a given secret key. +func New(key string) Toker { + return &jwtToker{tokenAuth: jwtauth.New("HS256", []byte(key), nil)} +} + +func (tok *jwtToker) EncodeUser(user *models.UserNoPassword) string { + claims := map[string]interface{}{ + "user_id": user.UserID, + "username": user.Username, + "display_name": user.DisplayName, + "iss": "gogmagog.deepak.science", + "aud": "gogmagog.deepak.science", + } + jwtauth.SetIssuedNow(claims) + jwtauth.SetExpiryIn(claims, 2*time.Hour) + _, tokenString, _ := tok.tokenAuth.Encode(claims) + return tokenString +} + +// UserToken represents a decoded jwt token. +type UserToken struct { + ID int64 + Username string +} + +func (tok *jwtToker) DecodeTokenString(tokenString string) (*UserToken, error) { + token, err := tok.tokenAuth.Decode(tokenString) + if err != nil { + return nil, fmt.Errorf("Error decoding token") + } + + // Should never happen, remove soon. + // if token == nil { + // return nil, fmt.Errorf("Token was nil") + // } + + err = jwt.Validate( + token, + jwt.WithIssuer("gogmagog.deepak.science"), + jwt.WithAudience("gogmagog.deepak.science"), + ) + if err != nil { + return nil, err + } + + userIDRaw, ok := token.Get("user_id") + if !ok { + return nil, fmt.Errorf("error finding user_id claim") + } + userID, ok := userIDRaw.(float64) + if !ok { + return nil, fmt.Errorf("Could not parse [%s] as userID", userIDRaw) + } + usernameRaw, ok := token.Get("username") + if !ok { + return nil, fmt.Errorf("error finding username claim") + } + username, ok := usernameRaw.(string) + if !ok { + return nil, fmt.Errorf("Could not parse [%s] as username", usernameRaw) + } + + return &UserToken{ID: int64(userID), Username: username}, nil +} diff --git a/tokens/tokens_test.go b/tokens/tokens_test.go new file mode 100644 index 0000000..76265c1 --- /dev/null +++ b/tokens/tokens_test.go @@ -0,0 +1,165 @@ +package tokens_test + +import ( + "gitea.deepak.science/deepak/gogmagog/models" + "gitea.deepak.science/deepak/gogmagog/tokens" + "github.com/go-chi/jwtauth" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestBasic(t *testing.T) { + assert := assert.New(t) + toker := tokens.New("secret") + idToUse := int64(3) + usernameToUse := "test" + usr := &models.UserNoPassword{ + UserID: idToUse, + Username: usernameToUse, + DisplayName: "Ted Est III", + } + token := toker.EncodeUser(usr) + + userToken, err := toker.DecodeTokenString(token) + assert.Nil(err) + assert.Equal(usernameToUse, userToken.Username) + assert.Equal(idToUse, userToken.ID) + _, err = tokens.New("bad secret").DecodeTokenString(token) + assert.NotNil(err) +} + +func getTokenString(claims map[string]interface{}) string { + auth := jwtauth.New("HS256", []byte("secret"), nil) + + jwtauth.SetIssuedNow(claims) + jwtauth.SetExpiryIn(claims, 2*time.Hour) + _, tokenString, _ := auth.Encode(claims) + + return tokenString +} + +func TestDecodeBadIssuer(t *testing.T) { + assert := assert.New(t) + toker := tokens.New("secret") + + idToUse := 3 + username := "test" + gog := "gogmagog.deepak.science" + + claims := map[string]interface{}{ + "user_id": int64(idToUse), + "username": username, + "display_name": "display_name", + "iss": gog, + "aud": "bad", + } + + token := getTokenString(claims) + _, err := toker.DecodeTokenString(token) + assert.NotNil(err) + +} + +func TestDecodeBadAudience(t *testing.T) { + assert := assert.New(t) + toker := tokens.New("secret") + + idToUse := 3 + username := "test" + gog := "gogmagog.deepak.science" + + claims := map[string]interface{}{ + "user_id": int64(idToUse), + "username": username, + "display_name": "display_name", + "iss": "bad", + "aud": gog, + } + + token := getTokenString(claims) + _, err := toker.DecodeTokenString(token) + assert.NotNil(err) + +} + +func TestDecodeMissingUserID(t *testing.T) { + assert := assert.New(t) + toker := tokens.New("secret") + + username := "test" + gog := "gogmagog.deepak.science" + + claims := map[string]interface{}{ + "username": username, + "display_name": "display_name", + "iss": gog, + "aud": gog, + } + + token := getTokenString(claims) + _, err := toker.DecodeTokenString(token) + assert.NotNil(err) + +} + +func TestDecodeBadUserID(t *testing.T) { + assert := assert.New(t) + toker := tokens.New("secret") + + username := "test" + gog := "gogmagog.deepak.science" + + claims := map[string]interface{}{ + "username": username, + "user_id": "id", + "display_name": "display_name", + "iss": gog, + "aud": gog, + } + + token := getTokenString(claims) + _, err := toker.DecodeTokenString(token) + assert.NotNil(err) + +} + +func TestDecodeMissingUsername(t *testing.T) { + assert := assert.New(t) + toker := tokens.New("secret") + + idToUse := 3 + gog := "gogmagog.deepak.science" + + claims := map[string]interface{}{ + "user_id": int64(idToUse), + "display_name": "display_name", + "iss": gog, + "aud": gog, + } + + token := getTokenString(claims) + _, err := toker.DecodeTokenString(token) + assert.NotNil(err) + +} + +func TestDecodeBadUsername(t *testing.T) { + assert := assert.New(t) + toker := tokens.New("secret") + + gog := "gogmagog.deepak.science" + + claims := map[string]interface{}{ + "username": 5, + "user_id": 3, + "display_name": "display_name", + "iss": gog, + "aud": gog, + } + + token := getTokenString(claims) + _, err := toker.DecodeTokenString(token) + assert.NotNil(err) + +}