// pronounsfu/backend/server/auth.go
//
// 90 lines
// 2.2 KiB
// Go

package server
import (
"context"
"net/http"
"strings"
"codeberg.org/u1f320/pronouns.cc/backend/log"
"codeberg.org/u1f320/pronouns.cc/backend/server/auth"
"github.com/go-chi/render"
)
// maybeAuth is a globally-used middleware that authenticates a request when
// it carries a token and passes it through untouched when it does not.
//
// The token is taken from the Authorization header with a leading "Bearer "
// stripped if present. When the result is empty, next is called with the
// original request. Otherwise the token is parsed with s.Auth.Claims and
// checked with s.DB.TokenValid:
//   - a parse error or a not-valid result responds with ErrInvalidToken,
//   - an error from the validity check is logged and responds with
//     ErrInternalServerError,
//   - on success the claims are stored in the request context under
//     ctxKeyClaims and next is called with the derived request.
//
// next is never called after an error response has been written.
func (s *Server) maybeAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// reject writes an API error response with the given HTTP status.
		reject := func(status int, body APIError) {
			render.Status(r, status)
			render.JSON(w, r, body)
		}

		header := r.Header.Get("Authorization")
		token := strings.TrimPrefix(header, "Bearer ")
		if token == "" {
			// No credentials supplied: continue without claims in the context.
			next.ServeHTTP(w, r)
			return
		}

		claims, err := s.Auth.Claims(token)
		if err != nil {
			reject(errCodeStatuses[ErrInvalidToken], APIError{
				Code:    ErrInvalidToken,
				Message: errCodeMessages[ErrInvalidToken],
			})
			return
		}

		// "valid" here refers to existence and expiry date, not whether the token is known
		valid, err := s.DB.TokenValid(r.Context(), claims.UserID, claims.TokenID)
		switch {
		case err != nil:
			log.Errorf("validating token for user %v: %v", claims.UserID, err)
			reject(errCodeStatuses[ErrInternalServerError], APIError{
				Code:    ErrInternalServerError,
				Message: errCodeMessages[ErrInternalServerError],
			})
		case !valid:
			reject(errCodeStatuses[ErrInvalidToken], APIError{
				Code:    ErrInvalidToken,
				Message: errCodeMessages[ErrInvalidToken],
			})
		default:
			// Expose the claims to downstream handlers via the request context.
			authed := r.WithContext(context.WithValue(r.Context(), ctxKeyClaims, claims))
			next.ServeHTTP(w, authed)
		}
	})
}
// MustAuth is a middleware that makes a valid token required: it only calls
// next when ClaimsFromContext finds claims in the request context. Requests
// without claims get an ErrForbidden API error response and next is not called.
func MustAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if _, authed := ClaimsFromContext(r.Context()); authed {
			next.ServeHTTP(w, r)
			return
		}

		render.Status(r, errCodeStatuses[ErrForbidden])
		render.JSON(w, r, APIError{
			Code:    ErrForbidden,
			Message: errCodeMessages[ErrForbidden],
		})
	})
}
// ClaimsFromContext returns the auth.Claims stored in ctx under ctxKeyClaims,
// if any. The boolean is false, and the returned claims are the zero value,
// when nothing is stored under that key or the stored value is not an
// auth.Claims.
func ClaimsFromContext(ctx context.Context) (auth.Claims, bool) {
	// The comma-ok assertion yields (zero value, false) for a nil interface,
	// so a missing key needs no separate check.
	stored, found := ctx.Value(ctxKeyClaims).(auth.Claims)
	return stored, found
}