pronounsfu/backend/routes/auth/routes.go

126 lines
2.8 KiB
Go
Raw Normal View History

2022-05-02 08:19:37 -07:00
package auth
import (
"net/http"
"os"
2022-05-02 08:19:37 -07:00
"codeberg.org/u1f320/pronouns.cc/backend/db"
"codeberg.org/u1f320/pronouns.cc/backend/log"
"codeberg.org/u1f320/pronouns.cc/backend/server"
"emperror.dev/errors"
2022-05-02 08:19:37 -07:00
"github.com/go-chi/chi/v5"
"github.com/go-chi/render"
"github.com/rs/xid"
2022-05-02 08:19:37 -07:00
)
type Server struct {
*server.Server
RequireInvite bool
2022-05-02 08:19:37 -07:00
}
// userResponse is the JSON representation of a user returned by the auth
// endpoints. Pointer fields are serialized as null when unset.
type userResponse struct {
	ID          xid.ID  `json:"id"`
	Username    string  `json:"username"`
	DisplayName *string `json:"display_name"`
	Bio         *string `json:"bio"`
	AvatarURL   *string `json:"avatar_url"`

	Links []string `json:"links"`

	Discord         *string `json:"discord"`
	DiscordUsername *string `json:"discord_username"`
}
// dbUserToUserResponse copies the fields of a database user into a newly
// allocated userResponse and returns a pointer to it.
func dbUserToUserResponse(u db.User) *userResponse {
	resp := new(userResponse)

	resp.ID = u.ID
	resp.Username = u.Username
	resp.DisplayName = u.DisplayName
	resp.Bio = u.Bio
	resp.AvatarURL = u.AvatarURL
	resp.Links = u.Links
	resp.Discord = u.Discord
	resp.DiscordUsername = u.DiscordUsername

	return resp
}
2022-05-02 08:19:37 -07:00
func Mount(srv *server.Server, r chi.Router) {
s := &Server{
Server: srv,
RequireInvite: os.Getenv("REQUIRE_INVITE") == "true",
}
2022-05-02 08:19:37 -07:00
r.Route("/auth", func(r chi.Router) {
// check if username is taken
r.Get("/username", server.WrapHandler(s.usernameTaken))
// generate csrf token, returns all supported OAuth provider URLs
2022-05-12 07:41:32 -07:00
r.Post("/urls", server.WrapHandler(s.oauthURLs))
2022-05-02 08:19:37 -07:00
r.Route("/discord", func(r chi.Router) {
// takes code + state, validates it, returns token OR discord signup ticket
2022-05-12 07:41:32 -07:00
r.Post("/callback", server.WrapHandler(s.discordCallback))
// takes discord signup ticket to register account
r.Post("/signup", nil)
2022-05-02 08:19:37 -07:00
})
})
}
// oauthURLsRequest is the JSON request body decoded by the oauthURLs handler.
type oauthURLsRequest struct {
	// CallbackDomain is prepended to "/login/discord" to form the OAuth
	// redirect URL.
	CallbackDomain string `json:"callback_domain"`
}
// oauthURLsResponse is the JSON body written by the oauthURLs handler.
type oauthURLsResponse struct {
	// Discord is the Discord authorization URL built for the generated
	// CSRF state.
	Discord string `json:"discord"`
}
func (s *Server) oauthURLs(w http.ResponseWriter, r *http.Request) error {
req, err := Decode[oauthURLsRequest](r)
if err != nil {
2022-05-12 07:41:32 -07:00
log.Error(err)
return server.APIError{Code: server.ErrBadRequest}
}
// generate CSRF state
state, err := s.setCSRFState(r.Context())
if err != nil {
return errors.Wrap(err, "setting CSRF state")
}
// copy Discord config and set redirect url
discordCfg := discordOAuthConfig
2022-05-12 07:41:32 -07:00
discordCfg.RedirectURL = req.CallbackDomain + "/login/discord"
render.JSON(w, r, oauthURLsResponse{
Discord: discordCfg.AuthCodeURL(state),
})
return nil
}
// usernameTaken reports whether the username given in the "username" form
// value is valid and whether it is already taken.
//
// An empty or missing username yields {"valid": false, "taken": false}
// without querying the database. Errors from the database lookup are returned
// unchanged.
func (s *Server) usernameTaken(w http.ResponseWriter, r *http.Request) error {
	type result struct {
		Valid bool `json:"valid"`
		Taken bool `json:"taken"`
	}

	// The zero value (not valid, not taken) is the answer for an empty name.
	var out result

	if username := r.FormValue("username"); username != "" {
		valid, taken, err := s.DB.UsernameTaken(r.Context(), username)
		if err != nil {
			return err
		}
		out.Valid, out.Taken = valid, taken
	}

	render.JSON(w, r, out)
	return nil
}