feat: add invites to backend

This commit is contained in:
Sam 2022-11-18 15:27:52 +01:00
parent 47ed36d24c
commit 6237ea940f
7 changed files with 234 additions and 30 deletions

111
backend/db/invites.go Normal file
View File

@ -0,0 +1,111 @@
package db
import (
"context"
"crypto/rand"
"encoding/base64"
"time"
"emperror.dev/errors"
"github.com/georgysavva/scany/pgxscan"
"github.com/jackc/pgx/v4"
"github.com/rs/xid"
)
type Invite struct {
UserID xid.ID
Code string
Created time.Time
Used bool
}
func (db *DB) UserInvites(ctx context.Context, userID xid.ID) (is []Invite, err error) {
sql, args, err := sq.Select("*").From("invites").Where("user_id = ?", userID).OrderBy("created").ToSql()
if err != nil {
return nil, errors.Wrap(err, "building sql")
}
err = pgxscan.Select(ctx, db, &is, sql, args...)
if err != nil {
return nil, errors.Wrap(err, "querying database")
}
if len(is) == 0 {
is = []Invite{}
}
return is, nil
}
const ErrTooManyInvites = errors.Sentinel("user invite limit reached")
func (db *DB) CreateInvite(ctx context.Context, userID xid.ID) (i Invite, err error) {
tx, err := db.Begin(ctx)
if err != nil {
return i, errors.Wrap(err, "beginning transaction")
}
defer tx.Rollback(ctx)
var maxInvites, inviteCount int
err = tx.QueryRow(ctx, "SELECT max_invites FROM users WHERE id = $1", userID).Scan(&maxInvites)
if err != nil {
return i, errors.Wrap(err, "querying invite limit")
}
err = tx.QueryRow(ctx, "SELECT count(*) FROM invites WHERE user_id = $1", userID).Scan(&inviteCount)
if err != nil {
return i, errors.Wrap(err, "querying current invite count")
}
if inviteCount >= maxInvites {
return i, ErrTooManyInvites
}
b := make([]byte, 32)
_, err = rand.Read(b)
if err != nil {
panic(err)
}
code := base64.RawURLEncoding.EncodeToString(b)
sql, args, err := sq.Insert("invites").Columns("user_id", "code").Values(userID, code).Suffix("RETURNING *").ToSql()
if err != nil {
return i, errors.Wrap(err, "building insert invite sql")
}
err = pgxscan.Get(ctx, db, &i, sql, args...)
if err != nil {
return i, errors.Wrap(err, "inserting invite")
}
err = tx.Commit(ctx)
if err != nil {
return i, errors.Wrap(err, "committing transaction")
}
return i, nil
}
func (db *DB) InvalidateInvite(ctx context.Context, tx pgx.Tx, code string) (valid, alreadyUsed bool, err error) {
err = tx.QueryRow(ctx, "SELECT used FROM invites WHERE code = $1", code).Scan(&alreadyUsed)
if err != nil {
if errors.Cause(err) == pgx.ErrNoRows {
return false, false, nil
}
return false, false, errors.Wrap(err, "checking if invite exists and is used")
}
// valid: true, already used: true
if alreadyUsed {
return true, true, nil
}
// invite is valid, not already used
_, err = tx.Exec(ctx, "UPDATE invites SET used = true WHERE code = $1", code)
if err != nil {
return false, false, errors.Wrap(err, "updating invite usage")
}
// valid: true, already used: false
return true, false, nil
}

View File

@ -24,6 +24,8 @@ type User struct {
Discord *string
DiscordUsername *string
MaxInvites int
}
// usernames must match this regex

View File

@ -182,17 +182,18 @@ func (s *Server) discordSignup(w http.ResponseWriter, r *http.Request) error {
}
if s.RequireInvite {
// TODO: check invites, invalidate invite when done
inviteValid := true
if !inviteValid {
err = tx.Rollback(ctx)
valid, used, err := s.DB.InvalidateInvite(ctx, tx, req.InviteCode)
if err != nil {
return errors.Wrap(err, "rolling back transaction")
return errors.Wrap(err, "checking and invalidating invite")
}
if !valid {
return server.APIError{Code: server.ErrInviteRequired}
}
if used {
return server.APIError{Code: server.ErrInviteAlreadyUsed}
}
}
// delete sign up ticket

View File

@ -0,0 +1,68 @@
package auth
import (
"net/http"
"time"
"codeberg.org/u1f320/pronouns.cc/backend/db"
"codeberg.org/u1f320/pronouns.cc/backend/server"
"emperror.dev/errors"
"github.com/go-chi/render"
)
type inviteResponse struct {
Code string `json:"string"`
Created time.Time `json:"created"`
Used bool `json:"used"`
}
func dbInviteToResponse(i db.Invite) inviteResponse {
return inviteResponse{
Code: i.Code,
Created: i.Created,
Used: i.Used,
}
}
func (s *Server) getInvites(w http.ResponseWriter, r *http.Request) error {
if !s.RequireInvite {
return server.APIError{Code: server.ErrInvitesDisabled}
}
ctx := r.Context()
claims, _ := server.ClaimsFromContext(ctx)
is, err := s.DB.UserInvites(ctx, claims.UserID)
if err != nil {
return errors.Wrap(err, "getting user invites")
}
resps := make([]inviteResponse, len(is))
for i := range is {
resps[i] = dbInviteToResponse(is[i])
}
render.JSON(w, r, resps)
return nil
}
func (s *Server) createInvite(w http.ResponseWriter, r *http.Request) error {
if !s.RequireInvite {
return server.APIError{Code: server.ErrInvitesDisabled}
}
ctx := r.Context()
claims, _ := server.ClaimsFromContext(ctx)
inv, err := s.DB.CreateInvite(ctx, claims.UserID)
if err != nil {
if err == db.ErrTooManyInvites {
return server.APIError{Code: server.ErrInviteLimitReached}
}
return errors.Wrap(err, "creating invite")
}
render.JSON(w, r, dbInviteToResponse(inv))
return nil
}

View File

@ -63,6 +63,10 @@ func Mount(srv *server.Server, r chi.Router) {
// takes discord signup ticket to register account
r.Post("/signup", server.WrapHandler(s.discordSignup))
})
// invite routes
r.With(server.MustAuth).Get("/invites", server.WrapHandler(s.getInvites))
r.With(server.MustAuth).Post("/invites", server.WrapHandler(s.createInvite))
})
}

View File

@ -80,6 +80,9 @@ const (
ErrInvalidTicket = 1005 // invalid signup ticket
ErrInvalidUsername = 1006 // invalid username (when signing up)
ErrUsernameTaken = 1007 // username taken (when signing up)
ErrInvitesDisabled = 1008 // invites are disabled (unneeded)
ErrInviteLimitReached = 1009 // invite limit reached (when creating invites)
ErrInviteAlreadyUsed = 1010 // invite already used (when signing up)
// User-related error codes
ErrUserNotFound = 2001
@ -107,6 +110,9 @@ var errCodeMessages = map[int]string{
ErrInvalidTicket: "Invalid signup ticket",
ErrInvalidUsername: "Invalid username",
ErrUsernameTaken: "Username is already taken",
ErrInvitesDisabled: "Invites are disabled",
ErrInviteLimitReached: "Your account has reached the invite limit",
ErrInviteAlreadyUsed: "That invite code has already been used",
ErrUserNotFound: "User not found",
@ -131,6 +137,9 @@ var errCodeStatuses = map[int]int{
ErrInvalidTicket: http.StatusBadRequest,
ErrInvalidUsername: http.StatusBadRequest,
ErrUsernameTaken: http.StatusBadRequest,
ErrInvitesDisabled: http.StatusForbidden,
ErrInviteLimitReached: http.StatusForbidden,
ErrInviteAlreadyUsed: http.StatusBadRequest,
ErrUserNotFound: http.StatusNotFound,

View File

@ -12,7 +12,9 @@ create table users (
links text[],
discord text unique, -- for Discord oauth
discord_username text
discord_username text,
max_invites int default 10
);
create table user_names (
@ -80,3 +82,10 @@ create table member_fields (
friends_only text[],
avoid text[]
);
create table invites (
user_id text not null references users (id) on delete cascade,
code text primary key,
created timestamp not null default (current_timestamp at time zone 'utc'),
used boolean not null default false
);