From 57c7a0f4defa5506900710947bcf0786a0fc63db Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 16 Jun 2022 14:54:15 +0200 Subject: [PATCH] feat(api): add PATCH /users/@me, remove PATCH /users/@me/fields --- backend/db/db.go | 2 + backend/db/field.go | 13 +-- backend/db/user.go | 54 +++++++++++++ backend/routes/user/fields.go | 54 ------------- backend/routes/user/patch_user.go | 130 ++++++++++++++++++++++++++++++ backend/routes/user/routes.go | 2 +- 6 files changed, 188 insertions(+), 67 deletions(-) delete mode 100644 backend/routes/user/fields.go create mode 100644 backend/routes/user/patch_user.go diff --git a/backend/db/db.go b/backend/db/db.go index 8a13124..f39e69b 100644 --- a/backend/db/db.go +++ b/backend/db/db.go @@ -14,6 +14,8 @@ import ( var sq = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar) +const ErrNothingToUpdate = errors.Sentinel("nothing to update") + type DB struct { *pgxpool.Pool diff --git a/backend/db/field.go b/backend/db/field.go index 45d1137..2723ed1 100644 --- a/backend/db/field.go +++ b/backend/db/field.go @@ -91,13 +91,7 @@ func (db *DB) UserFields(ctx context.Context, id xid.ID) (fs []Field, err error) } // SetUserFields updates the fields for the given user. -func (db *DB) SetUserFields(ctx context.Context, userID xid.ID, fields []Field) (err error) { - tx, err := db.Begin(ctx) - if err != nil { - return errors.Wrap(err, "building sql") - } - defer tx.Rollback(ctx) - +func (db *DB) SetUserFields(ctx context.Context, tx pgx.Tx, userID xid.ID, fields []Field) (err error) { sql, args, err := sq.Delete("user_fields").Where("user_id = ?", userID).ToSql() if err != nil { return errors.Wrap(err, "building sql") @@ -124,10 +118,5 @@ func (db *DB) SetUserFields(ctx context.Context, userID xid.ID, fields []Field) if err != nil { return errors.Wrap(err, "inserting new fields") } - - err = tx.Commit(ctx) - if err != nil { - return errors.Wrap(err, "committing transaction") - } return nil } diff --git a/backend/db/user.go b/backend/db/user.go index 368eec1..b2cd671 100644 --- a/backend/db/user.go +++ b/backend/db/user.go @@ -38,6 +38,14 @@ const ( ErrUsernameTooLong = errors.Sentinel("username is too long") ) +const ( + MaxUsernameLength = 40 + MaxDisplayNameLength = 100 + MaxUserBioLength = 1000 + MaxUserLinksLength = 25 + MaxLinkLength = 256 +) + // CreateUser creates a user with the given username. func (db *DB) CreateUser(ctx context.Context, username string) (u User, err error) { // check if the username is valid @@ -146,3 +154,49 @@ func (db *DB) UsernameTaken(ctx context.Context, username string) (valid, taken err = db.QueryRow(ctx, "select exists (select id from users where username = $1)", username).Scan(&taken) return true, taken, err } + +func (db *DB) UpdateUser( + ctx context.Context, + tx pgx.Tx, id xid.ID, + displayName, bio *string, + links *[]string, +) (u User, err error) { + if displayName == nil && bio == nil && links == nil { + return u, ErrNothingToUpdate + } + + builder := sq.Update("users").Where("id = ?", id) + if displayName != nil { + if *displayName == "" { + builder = builder.Set("display_name", nil) + } else { + builder = builder.Set("display_name", *displayName) + } + } + if bio != nil { + if *bio == "" { + builder = builder.Set("bio", nil) + } else { + builder = builder.Set("bio", *bio) + } + } + if links != nil { + if len(*links) == 0 { + builder = builder.Set("links", nil) + } else { + builder = builder.Set("links", *links) + } + } + + sql, args, err := builder.Suffix("RETURNING *").ToSql() + if err != nil { + return u, errors.Wrap(err, "building sql") + } + + err = pgxscan.Get(ctx, tx, &u, sql, args...) + if err != nil { + return u, errors.Wrap(err, "executing sql") + } + + return u, nil +} diff --git a/backend/routes/user/fields.go b/backend/routes/user/fields.go deleted file mode 100644 index bd58ba2..0000000 --- a/backend/routes/user/fields.go +++ /dev/null @@ -1,54 +0,0 @@ -package user - -import ( - "fmt" - "net/http" - - "codeberg.org/u1f320/pronouns.cc/backend/db" - "codeberg.org/u1f320/pronouns.cc/backend/log" - "codeberg.org/u1f320/pronouns.cc/backend/server" - "github.com/go-chi/render" -) - -type PatchFieldsRequest struct { - Fields []db.Field `json:"fields"` -} - -func (s *Server) patchUserFields(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - claims, _ := server.ClaimsFromContext(ctx) - - var req PatchFieldsRequest - err := render.Decode(r, &req) - if err != nil { - return server.APIError{Code: server.ErrBadRequest} - } - - // max 25 fields - if len(req.Fields) > db.MaxFields { - return server.APIError{ - Code: server.ErrBadRequest, - Details: fmt.Sprintf("Too many fields (max %d, current %d)", db.MaxFields, len(req.Fields)), - } - } - - // validate all fields - for i, field := range req.Fields { - if s := field.Validate(); s != "" { - return server.APIError{ - Code: server.ErrBadRequest, - Details: fmt.Sprintf("field %d: %s", i, s), - } - } - } - - err = s.DB.SetUserFields(ctx, claims.UserID, req.Fields) - if err != nil { - log.Errorf("setting fields for user %v: %v", claims.UserID, err) - return err - } - - // echo the fields back on success - render.JSON(w, r, req) - return nil -} diff --git a/backend/routes/user/patch_user.go b/backend/routes/user/patch_user.go new file mode 100644 index 0000000..168b437 --- /dev/null +++ b/backend/routes/user/patch_user.go @@ -0,0 +1,130 @@ +package user + +import ( + "fmt" + "net/http" + + "codeberg.org/u1f320/pronouns.cc/backend/db" + "codeberg.org/u1f320/pronouns.cc/backend/log" + "codeberg.org/u1f320/pronouns.cc/backend/server" + "emperror.dev/errors" + "github.com/go-chi/render" +) + +type PatchUserRequest struct { + DisplayName *string `json:"display_name"` + Bio *string `json:"bio"` + Links *[]string `json:"links"` + Fields *[]db.Field `json:"fields"` +} + +// patchUser parses a PatchUserRequest and updates the user with the given ID. +func (s *Server) patchUser(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + + claims, _ := server.ClaimsFromContext(ctx) + + var req PatchUserRequest + err := render.Decode(r, &req) + if err != nil { + return server.APIError{Code: server.ErrBadRequest} + } + + // validate that *something* is set + if req.DisplayName == nil && req.Bio == nil && req.Links == nil && req.Fields == nil { + return server.APIError{ + Code: server.ErrBadRequest, + Details: "Data must not be empty", + } + } + + // validate display name/bio + if req.DisplayName != nil && len(*req.DisplayName) > db.MaxDisplayNameLength { + return server.APIError{ + Code: server.ErrBadRequest, + Details: fmt.Sprintf("Display name too long (max %d, current %d)", db.MaxDisplayNameLength, len(*req.DisplayName)), + } + } + if req.Bio != nil && len(*req.Bio) > db.MaxUserBioLength { + return server.APIError{ + Code: server.ErrBadRequest, + Details: fmt.Sprintf("Bio too long (max %d, current %d)", db.MaxUserBioLength, len(*req.Bio)), + } + } + + // validate links + if req.Links != nil { + if len(*req.Links) > db.MaxUserLinksLength { + return server.APIError{ + Code: server.ErrBadRequest, + Details: fmt.Sprintf("Too many links (max %d, current %d)", db.MaxUserLinksLength, len(*req.Links)), + } + } + + for i, link := range *req.Links { + if len(link) > db.MaxLinkLength { + return server.APIError{ + Code: server.ErrBadRequest, + Details: fmt.Sprintf("Link %d too long (max %d, current %d)", i, db.MaxLinkLength, len(link)), + } + } + } + } + + // max 25 fields + if len(*req.Fields) > db.MaxFields { + return server.APIError{ + Code: server.ErrBadRequest, + Details: fmt.Sprintf("Too many fields (max %d, current %d)", db.MaxFields, len(*req.Fields)), + } + } + + // validate all fields + for i, field := range *req.Fields { + if s := field.Validate(); s != "" { + return server.APIError{ + Code: server.ErrBadRequest, + Details: fmt.Sprintf("field %d: %s", i, s), + } + } + } + + // start transaction + tx, err := s.DB.Begin(ctx) + if err != nil { + log.Errorf("creating transaction: %v", err) + return err + } + defer tx.Rollback(ctx) + + u, err := s.DB.UpdateUser(ctx, tx, claims.UserID, req.DisplayName, req.Bio, req.Links) + if err != nil && errors.Cause(err) != db.ErrNothingToUpdate { + log.Errorf("updating user: %v", err) + return err + } + + var fields []db.Field + if req.Fields != nil { + err = s.DB.SetUserFields(ctx, tx, claims.UserID, *req.Fields) + if err != nil { + log.Errorf("setting fields for user %v: %v", claims.UserID, err) + return err + } + } else { + fields, err = s.DB.UserFields(ctx, claims.UserID) + if err != nil { + log.Errorf("getting fields for user %v: %v", claims.UserID, err) + return err + } + } + + err = tx.Commit(ctx) + if err != nil { + log.Errorf("committing transaction: %v", err) + return err + } + + // echo the updated user back on success + render.JSON(w, r, dbUserToResponse(u, fields)) + return nil +} diff --git a/backend/routes/user/routes.go b/backend/routes/user/routes.go index 4d17031..76b428c 100644 --- a/backend/routes/user/routes.go +++ b/backend/routes/user/routes.go @@ -17,7 +17,7 @@ func Mount(srv *server.Server, r chi.Router) { r.With(server.MustAuth).Group(func(r chi.Router) { r.Get("/@me", server.WrapHandler(s.getMeUser)) - r.Patch("/@me/fields", server.WrapHandler(s.patchUserFields)) + r.Patch("/@me", server.WrapHandler(s.patchUser)) }) }) }