feat: add invites to backend
This commit is contained in:
parent
47ed36d24c
commit
6237ea940f
|
@ -0,0 +1,111 @@
|
||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"emperror.dev/errors"
|
||||||
|
"github.com/georgysavva/scany/pgxscan"
|
||||||
|
"github.com/jackc/pgx/v4"
|
||||||
|
"github.com/rs/xid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Invite struct {
|
||||||
|
UserID xid.ID
|
||||||
|
Code string
|
||||||
|
Created time.Time
|
||||||
|
Used bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) UserInvites(ctx context.Context, userID xid.ID) (is []Invite, err error) {
|
||||||
|
sql, args, err := sq.Select("*").From("invites").Where("user_id = ?", userID).OrderBy("created").ToSql()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "building sql")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pgxscan.Select(ctx, db, &is, sql, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "querying database")
|
||||||
|
}
|
||||||
|
if len(is) == 0 {
|
||||||
|
is = []Invite{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return is, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const ErrTooManyInvites = errors.Sentinel("user invite limit reached")
|
||||||
|
|
||||||
|
func (db *DB) CreateInvite(ctx context.Context, userID xid.ID) (i Invite, err error) {
|
||||||
|
tx, err := db.Begin(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return i, errors.Wrap(err, "beginning transaction")
|
||||||
|
}
|
||||||
|
defer tx.Rollback(ctx)
|
||||||
|
|
||||||
|
var maxInvites, inviteCount int
|
||||||
|
err = tx.QueryRow(ctx, "SELECT max_invites FROM users WHERE id = $1", userID).Scan(&maxInvites)
|
||||||
|
if err != nil {
|
||||||
|
return i, errors.Wrap(err, "querying invite limit")
|
||||||
|
}
|
||||||
|
err = tx.QueryRow(ctx, "SELECT count(*) FROM invites WHERE user_id = $1", userID).Scan(&inviteCount)
|
||||||
|
if err != nil {
|
||||||
|
return i, errors.Wrap(err, "querying current invite count")
|
||||||
|
}
|
||||||
|
|
||||||
|
if inviteCount >= maxInvites {
|
||||||
|
return i, ErrTooManyInvites
|
||||||
|
}
|
||||||
|
|
||||||
|
b := make([]byte, 32)
|
||||||
|
|
||||||
|
_, err = rand.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
code := base64.RawURLEncoding.EncodeToString(b)
|
||||||
|
|
||||||
|
sql, args, err := sq.Insert("invites").Columns("user_id", "code").Values(userID, code).Suffix("RETURNING *").ToSql()
|
||||||
|
if err != nil {
|
||||||
|
return i, errors.Wrap(err, "building insert invite sql")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pgxscan.Get(ctx, db, &i, sql, args...)
|
||||||
|
if err != nil {
|
||||||
|
return i, errors.Wrap(err, "inserting invite")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Commit(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return i, errors.Wrap(err, "committing transaction")
|
||||||
|
}
|
||||||
|
return i, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) InvalidateInvite(ctx context.Context, tx pgx.Tx, code string) (valid, alreadyUsed bool, err error) {
|
||||||
|
err = tx.QueryRow(ctx, "SELECT used FROM invites WHERE code = $1", code).Scan(&alreadyUsed)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Cause(err) == pgx.ErrNoRows {
|
||||||
|
return false, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, false, errors.Wrap(err, "checking if invite exists and is used")
|
||||||
|
}
|
||||||
|
|
||||||
|
// valid: true, already used: true
|
||||||
|
if alreadyUsed {
|
||||||
|
return true, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// invite is valid, not already used
|
||||||
|
_, err = tx.Exec(ctx, "UPDATE invites SET used = true WHERE code = $1", code)
|
||||||
|
if err != nil {
|
||||||
|
return false, false, errors.Wrap(err, "updating invite usage")
|
||||||
|
}
|
||||||
|
|
||||||
|
// valid: true, already used: false
|
||||||
|
return true, false, nil
|
||||||
|
}
|
|
@ -24,6 +24,8 @@ type User struct {
|
||||||
|
|
||||||
Discord *string
|
Discord *string
|
||||||
DiscordUsername *string
|
DiscordUsername *string
|
||||||
|
|
||||||
|
MaxInvites int
|
||||||
}
|
}
|
||||||
|
|
||||||
// usernames must match this regex
|
// usernames must match this regex
|
||||||
|
|
|
@ -182,17 +182,18 @@ func (s *Server) discordSignup(w http.ResponseWriter, r *http.Request) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.RequireInvite {
|
if s.RequireInvite {
|
||||||
// TODO: check invites, invalidate invite when done
|
valid, used, err := s.DB.InvalidateInvite(ctx, tx, req.InviteCode)
|
||||||
inviteValid := true
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "checking and invalidating invite")
|
||||||
if !inviteValid {
|
}
|
||||||
err = tx.Rollback(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "rolling back transaction")
|
|
||||||
}
|
|
||||||
|
|
||||||
|
if !valid {
|
||||||
return server.APIError{Code: server.ErrInviteRequired}
|
return server.APIError{Code: server.ErrInviteRequired}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if used {
|
||||||
|
return server.APIError{Code: server.ErrInviteAlreadyUsed}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// delete sign up ticket
|
// delete sign up ticket
|
||||||
|
|
|
@ -0,0 +1,68 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"codeberg.org/u1f320/pronouns.cc/backend/db"
|
||||||
|
"codeberg.org/u1f320/pronouns.cc/backend/server"
|
||||||
|
"emperror.dev/errors"
|
||||||
|
"github.com/go-chi/render"
|
||||||
|
)
|
||||||
|
|
||||||
|
type inviteResponse struct {
|
||||||
|
Code string `json:"string"`
|
||||||
|
Created time.Time `json:"created"`
|
||||||
|
Used bool `json:"used"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func dbInviteToResponse(i db.Invite) inviteResponse {
|
||||||
|
return inviteResponse{
|
||||||
|
Code: i.Code,
|
||||||
|
Created: i.Created,
|
||||||
|
Used: i.Used,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getInvites(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
if !s.RequireInvite {
|
||||||
|
return server.APIError{Code: server.ErrInvitesDisabled}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
claims, _ := server.ClaimsFromContext(ctx)
|
||||||
|
|
||||||
|
is, err := s.DB.UserInvites(ctx, claims.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "getting user invites")
|
||||||
|
}
|
||||||
|
|
||||||
|
resps := make([]inviteResponse, len(is))
|
||||||
|
for i := range is {
|
||||||
|
resps[i] = dbInviteToResponse(is[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
render.JSON(w, r, resps)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) createInvite(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
if !s.RequireInvite {
|
||||||
|
return server.APIError{Code: server.ErrInvitesDisabled}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
claims, _ := server.ClaimsFromContext(ctx)
|
||||||
|
|
||||||
|
inv, err := s.DB.CreateInvite(ctx, claims.UserID)
|
||||||
|
if err != nil {
|
||||||
|
if err == db.ErrTooManyInvites {
|
||||||
|
return server.APIError{Code: server.ErrInviteLimitReached}
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors.Wrap(err, "creating invite")
|
||||||
|
}
|
||||||
|
|
||||||
|
render.JSON(w, r, dbInviteToResponse(inv))
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -63,6 +63,10 @@ func Mount(srv *server.Server, r chi.Router) {
|
||||||
// takes discord signup ticket to register account
|
// takes discord signup ticket to register account
|
||||||
r.Post("/signup", server.WrapHandler(s.discordSignup))
|
r.Post("/signup", server.WrapHandler(s.discordSignup))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// invite routes
|
||||||
|
r.With(server.MustAuth).Get("/invites", server.WrapHandler(s.getInvites))
|
||||||
|
r.With(server.MustAuth).Post("/invites", server.WrapHandler(s.createInvite))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -73,13 +73,16 @@ const (
|
||||||
ErrInternalServerError = 500 // catch-all code for unknown errors
|
ErrInternalServerError = 500 // catch-all code for unknown errors
|
||||||
|
|
||||||
// Login/authorize error codes
|
// Login/authorize error codes
|
||||||
ErrInvalidState = 1001
|
ErrInvalidState = 1001
|
||||||
ErrInvalidOAuthCode = 1002
|
ErrInvalidOAuthCode = 1002
|
||||||
ErrInvalidToken = 1003 // a token was supplied, but it is invalid
|
ErrInvalidToken = 1003 // a token was supplied, but it is invalid
|
||||||
ErrInviteRequired = 1004
|
ErrInviteRequired = 1004
|
||||||
ErrInvalidTicket = 1005 // invalid signup ticket
|
ErrInvalidTicket = 1005 // invalid signup ticket
|
||||||
ErrInvalidUsername = 1006 // invalid username (when signing up)
|
ErrInvalidUsername = 1006 // invalid username (when signing up)
|
||||||
ErrUsernameTaken = 1007 // username taken (when signing up)
|
ErrUsernameTaken = 1007 // username taken (when signing up)
|
||||||
|
ErrInvitesDisabled = 1008 // invites are disabled (unneeded)
|
||||||
|
ErrInviteLimitReached = 1009 // invite limit reached (when creating invites)
|
||||||
|
ErrInviteAlreadyUsed = 1010 // invite already used (when signing up)
|
||||||
|
|
||||||
// User-related error codes
|
// User-related error codes
|
||||||
ErrUserNotFound = 2001
|
ErrUserNotFound = 2001
|
||||||
|
@ -100,13 +103,16 @@ var errCodeMessages = map[int]string{
|
||||||
ErrTooManyRequests: "Rate limit reached",
|
ErrTooManyRequests: "Rate limit reached",
|
||||||
ErrMethodNotAllowed: "Method not allowed",
|
ErrMethodNotAllowed: "Method not allowed",
|
||||||
|
|
||||||
ErrInvalidState: "Invalid OAuth state",
|
ErrInvalidState: "Invalid OAuth state",
|
||||||
ErrInvalidOAuthCode: "Invalid OAuth code",
|
ErrInvalidOAuthCode: "Invalid OAuth code",
|
||||||
ErrInvalidToken: "Supplied token was invalid",
|
ErrInvalidToken: "Supplied token was invalid",
|
||||||
ErrInviteRequired: "A valid invite code is required",
|
ErrInviteRequired: "A valid invite code is required",
|
||||||
ErrInvalidTicket: "Invalid signup ticket",
|
ErrInvalidTicket: "Invalid signup ticket",
|
||||||
ErrInvalidUsername: "Invalid username",
|
ErrInvalidUsername: "Invalid username",
|
||||||
ErrUsernameTaken: "Username is already taken",
|
ErrUsernameTaken: "Username is already taken",
|
||||||
|
ErrInvitesDisabled: "Invites are disabled",
|
||||||
|
ErrInviteLimitReached: "Your account has reached the invite limit",
|
||||||
|
ErrInviteAlreadyUsed: "That invite code has already been used",
|
||||||
|
|
||||||
ErrUserNotFound: "User not found",
|
ErrUserNotFound: "User not found",
|
||||||
|
|
||||||
|
@ -124,13 +130,16 @@ var errCodeStatuses = map[int]int{
|
||||||
ErrTooManyRequests: http.StatusTooManyRequests,
|
ErrTooManyRequests: http.StatusTooManyRequests,
|
||||||
ErrMethodNotAllowed: http.StatusMethodNotAllowed,
|
ErrMethodNotAllowed: http.StatusMethodNotAllowed,
|
||||||
|
|
||||||
ErrInvalidState: http.StatusBadRequest,
|
ErrInvalidState: http.StatusBadRequest,
|
||||||
ErrInvalidOAuthCode: http.StatusForbidden,
|
ErrInvalidOAuthCode: http.StatusForbidden,
|
||||||
ErrInvalidToken: http.StatusUnauthorized,
|
ErrInvalidToken: http.StatusUnauthorized,
|
||||||
ErrInviteRequired: http.StatusBadRequest,
|
ErrInviteRequired: http.StatusBadRequest,
|
||||||
ErrInvalidTicket: http.StatusBadRequest,
|
ErrInvalidTicket: http.StatusBadRequest,
|
||||||
ErrInvalidUsername: http.StatusBadRequest,
|
ErrInvalidUsername: http.StatusBadRequest,
|
||||||
ErrUsernameTaken: http.StatusBadRequest,
|
ErrUsernameTaken: http.StatusBadRequest,
|
||||||
|
ErrInvitesDisabled: http.StatusForbidden,
|
||||||
|
ErrInviteLimitReached: http.StatusForbidden,
|
||||||
|
ErrInviteAlreadyUsed: http.StatusBadRequest,
|
||||||
|
|
||||||
ErrUserNotFound: http.StatusNotFound,
|
ErrUserNotFound: http.StatusNotFound,
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,9 @@ create table users (
|
||||||
links text[],
|
links text[],
|
||||||
|
|
||||||
discord text unique, -- for Discord oauth
|
discord text unique, -- for Discord oauth
|
||||||
discord_username text
|
discord_username text,
|
||||||
|
|
||||||
|
max_invites int default 10
|
||||||
);
|
);
|
||||||
|
|
||||||
create table user_names (
|
create table user_names (
|
||||||
|
@ -80,3 +82,10 @@ create table member_fields (
|
||||||
friends_only text[],
|
friends_only text[],
|
||||||
avoid text[]
|
avoid text[]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
create table invites (
|
||||||
|
user_id text not null references users (id) on delete cascade,
|
||||||
|
code text primary key,
|
||||||
|
created timestamp not null default (current_timestamp at time zone 'utc'),
|
||||||
|
used boolean not null default false
|
||||||
|
);
|
||||||
|
|
Loading…
Reference in New Issue