// File: pronounsfu/backend/server/auth.go (Go, 90 lines, 2.2 KiB).

package server
import (
"context"
"net/http"
"strings"
"codeberg.org/u1f320/pronouns.cc/backend/log"
"codeberg.org/u1f320/pronouns.cc/backend/server/auth"
"github.com/go-chi/render"
)
// maybeAuth is a globally-used middleware.
func (s *Server) maybeAuth(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
token := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
2022-05-02 08:19:37 -07:00
if token == "" {
next.ServeHTTP(w, r)
return
}
claims, err := s.Auth.Claims(token)
if err != nil {
2022-05-14 12:57:44 -07:00
render.Status(r, errCodeStatuses[ErrInvalidToken])
render.JSON(w, r, APIError{
2022-05-14 12:57:44 -07:00
Code: ErrInvalidToken,
Message: errCodeMessages[ErrInvalidToken],
})
return
2022-05-02 08:19:37 -07:00
}
// "valid" here refers to existence and expiry date, not whether the token is known
valid, err := s.DB.TokenValid(r.Context(), claims.UserID, claims.TokenID)
if err != nil {
log.Errorf("validating token for user %v: %v", claims.UserID, err)
render.Status(r, errCodeStatuses[ErrInternalServerError])
render.JSON(w, r, APIError{
Code: ErrInternalServerError,
Message: errCodeMessages[ErrInternalServerError],
})
return
}
if !valid {
render.Status(r, errCodeStatuses[ErrInvalidToken])
render.JSON(w, r, APIError{
Code: ErrInvalidToken,
Message: errCodeMessages[ErrInvalidToken],
})
return
}
2022-05-02 08:19:37 -07:00
ctx := context.WithValue(r.Context(), ctxKeyClaims, claims)
next.ServeHTTP(w, r.WithContext(ctx))
}
return http.HandlerFunc(fn)
}
// MustAuth makes a valid token required
func MustAuth(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
_, ok := ClaimsFromContext(r.Context())
if !ok {
render.Status(r, errCodeStatuses[ErrForbidden])
render.JSON(w, r, APIError{
Code: ErrForbidden,
Message: errCodeMessages[ErrForbidden],
})
return
2022-05-02 08:19:37 -07:00
}
next.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
// ClaimsFromContext looks up the auth.Claims stored in ctx under ctxKeyClaims.
// The boolean result is true only if such a value exists and has type
// auth.Claims. When it is false, the returned claims are the zero value.
func ClaimsFromContext(ctx context.Context) (auth.Claims, bool) {
	// Asserting on a nil interface value yields (zero value, false), so the
	// "nothing stored" case needs no separate check.
	claims, ok := ctx.Value(ctxKeyClaims).(auth.Claims)
	return claims, ok
}