package tokens import ( "context" "fmt" "log" "net/http" "strings" ) type contextKey struct { name string } var userIDCtxKey = &contextKey{"UserID"} var usernameCtxKey = &contextKey{"Username"} func unauthorized(w http.ResponseWriter, r *http.Request) { code := http.StatusUnauthorized http.Error(w, http.StatusText(code), code) } // TokenFromHeader tries to retreive the token string from the // "Authorization" reqeust header: "Authorization: BEARER T". func TokenFromHeader(r *http.Request) string { // Get token from authorization header. bearer := r.Header.Get("Authorization") if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" { return bearer[7:] } return "" } func (tok *jwtToker) Authenticator(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tokenString := TokenFromHeader(r) if tokenString == "" { log.Print("No valid token found") unauthorized(w, r) return } userToken, err := tok.DecodeTokenString(tokenString) if err != nil { log.Printf("Error while verifying token: %s", err) unauthorized(w, r) return } log.Printf("Got user with ID: [%d]", userToken.ID) ctx := context.WithValue(r.Context(), userIDCtxKey, userToken.ID) ctx = context.WithValue(ctx, usernameCtxKey, userToken.Username) // Authenticated next.ServeHTTP(w, r.WithContext(ctx)) }) } // GetUserID is a convenience method that gets the user ID from the context. // I hate the fact that we're passing user ID on the context, but it is more // idiomatic Go than any type shenanigans. func GetUserID(ctx context.Context) (int, error) { userID, ok := ctx.Value(userIDCtxKey).(int64) if !ok { return -1, fmt.Errorf("Could not parse user ID [%s] from context", ctx.Value(userIDCtxKey)) } return int(userID), nil } // SetUserID sets the username field on a context, necessary because the key is an unexported custom type. 
func SetUserID(ctx context.Context, id int) context.Context {
	return context.WithValue(ctx, userIDCtxKey, int64(id))
}

// GetUsername returns the username stored on ctx (for example by
// Authenticator or GetContextForUserValues). It returns "" and an error
// when no username is present or the stored value is not a string.
func GetUsername(ctx context.Context) (string, error) {
	username, ok := ctx.Value(usernameCtxKey).(string)
	if !ok {
		// %v, not %s: the stored value is typically nil or a non-string.
		return "", fmt.Errorf("Could not parse username [%v] from context", ctx.Value(usernameCtxKey))
	}
	return username, nil
}

// GetContextForUserValues is a test helper that returns a background context
// carrying the given user ID and username. The ID is stored through
// SetUserID, so it has the int64 type that GetUserID expects.
func GetContextForUserValues(userID int, username string) context.Context {
	ctx := SetUserID(context.Background(), userID)
	return context.WithValue(ctx, usernameCtxKey, username)
}