89 lines
2.6 KiB
Go
89 lines
2.6 KiB
Go
package tokens
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"strings"
|
|
)
|
|
|
|
// contextKey is the unexported type of the keys this package uses to store
// values on a context.Context. Keys are used by pointer, so two keys are
// distinct even if they carry the same name; name only labels the key.
type contextKey struct{ name string }
|
|
|
|
var userIDCtxKey = &contextKey{"UserID"}
|
|
var usernameCtxKey = &contextKey{"Username"}
|
|
|
|
func unauthorized(w http.ResponseWriter, r *http.Request) {
|
|
code := http.StatusUnauthorized
|
|
http.Error(w, http.StatusText(code), code)
|
|
}
|
|
|
|
// TokenFromHeader tries to retreive the token string from the
|
|
// "Authorization" reqeust header: "Authorization: BEARER T".
|
|
func TokenFromHeader(r *http.Request) string {
|
|
// Get token from authorization header.
|
|
bearer := r.Header.Get("Authorization")
|
|
if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" {
|
|
return bearer[7:]
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func (tok *jwtToker) Authenticator(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
tokenString := TokenFromHeader(r)
|
|
if tokenString == "" {
|
|
log.Print("No valid token found")
|
|
unauthorized(w, r)
|
|
return
|
|
}
|
|
|
|
userToken, err := tok.DecodeTokenString(tokenString)
|
|
if err != nil {
|
|
log.Printf("Error while verifying token: %s", err)
|
|
unauthorized(w, r)
|
|
return
|
|
}
|
|
|
|
log.Printf("Got user with ID: [%d]", userToken.ID)
|
|
ctx := context.WithValue(r.Context(), userIDCtxKey, userToken.ID)
|
|
ctx = context.WithValue(ctx, usernameCtxKey, userToken.Username)
|
|
// Authenticated
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
// GetUserID is a convenience method that gets the user ID from the context.
|
|
// I hate the fact that we're passing user ID on the context, but it is more
|
|
// idiomatic Go than any type shenanigans.
|
|
func GetUserID(ctx context.Context) (int, error) {
|
|
userID, ok := ctx.Value(userIDCtxKey).(int64)
|
|
if !ok {
|
|
return -1, fmt.Errorf("Could not parse user ID [%s] from context", ctx.Value(userIDCtxKey))
|
|
|
|
}
|
|
return int(userID), nil
|
|
}
|
|
|
|
// SetUserID sets the username field on a context, necessary because the key is an unexported custom type.
|
|
func SetUserID(ctx context.Context, id int) context.Context {
|
|
return context.WithValue(ctx, userIDCtxKey, int64(id))
|
|
}
|
|
|
|
// GetUsername does something similar to GetUserID.
|
|
func GetUsername(ctx context.Context) (string, error) {
|
|
username, ok := ctx.Value(usernameCtxKey).(string)
|
|
if !ok {
|
|
return "", fmt.Errorf("Could not parse username [%s] from context", ctx.Value(usernameCtxKey))
|
|
}
|
|
return username, nil
|
|
}
|
|
|
|
// GetContextForUserValues is a test helper method that creates a context with user ID set.
|
|
func GetContextForUserValues(userID int, username string) context.Context {
|
|
ctx := context.WithValue(context.Background(), userIDCtxKey, int64(userID))
|
|
return context.WithValue(ctx, usernameCtxKey, username)
|
|
}
|