diff --git a/backend/db/entries.go b/backend/db/entries.go index 9a2c28f..1b4b8cd 100644 --- a/backend/db/entries.go +++ b/backend/db/entries.go @@ -40,13 +40,13 @@ func (w *WordStatus) UnmarshalJSON(src []byte) error { return nil } -func (w WordStatus) Valid(extra ...WordStatus) bool { +func (w WordStatus) Valid(extra CustomPreferences) bool { if w == StatusFavourite || w == StatusOkay || w == StatusJokingly || w == StatusFriendsOnly || w == StatusAvoid { return true } - for i := range extra { - if w == extra[i] { + for k := range extra { + if string(w) == k { return true } } @@ -58,7 +58,7 @@ type FieldEntry struct { Status WordStatus `json:"status"` } -func (fe FieldEntry) Validate() string { +func (fe FieldEntry) Validate(custom CustomPreferences) string { if fe.Value == "" { return "value cannot be empty" } @@ -67,7 +67,7 @@ func (fe FieldEntry) Validate() string { return fmt.Sprintf("name must be %d characters or less, is %d", FieldEntryMaxLength, len([]rune(fe.Value))) } - if !fe.Status.Valid() { + if !fe.Status.Valid(custom) { return "status is invalid" } @@ -80,7 +80,7 @@ type PronounEntry struct { Status WordStatus `json:"status"` } -func (p PronounEntry) Validate() string { +func (p PronounEntry) Validate(custom CustomPreferences) string { if p.Pronouns == "" { return "pronouns cannot be empty" } @@ -95,7 +95,7 @@ func (p PronounEntry) Validate() string { return fmt.Sprintf("pronouns must be %d characters or less, is %d", FieldEntryMaxLength, len([]rune(p.Pronouns))) } - if !p.Status.Valid() { + if !p.Status.Valid(custom) { return "status is invalid" } diff --git a/backend/db/field.go b/backend/db/field.go index 69a2b2e..285d5d4 100644 --- a/backend/db/field.go +++ b/backend/db/field.go @@ -24,7 +24,7 @@ type Field struct { } // Validate validates this field. If it is invalid, a non-empty string is returned as error message. -func (f Field) Validate() string { +func (f Field) Validate(custom CustomPreferences) string { if f.Name == "" { return "name cannot be empty" } @@ -42,7 +42,7 @@ func (f Field) Validate() string { return fmt.Sprintf("entries.%d: max length is %d characters, length is %d", i, FieldEntryMaxLength, length) } - if !entry.Status.Valid() { + if !entry.Status.Valid(custom) { return fmt.Sprintf("entries.%d: status is invalid", i) } } diff --git a/backend/db/user.go b/backend/db/user.go index dc27b12..6d8e24f 100644 --- a/backend/db/user.go +++ b/backend/db/user.go @@ -498,8 +498,9 @@ func (db *DB) UpdateUser( memberTitle *string, listPrivate *bool, links *[]string, avatar *string, + customPreferences *CustomPreferences, ) (u User, err error) { - if displayName == nil && bio == nil && links == nil && avatar == nil && memberTitle == nil && listPrivate == nil { + if displayName == nil && bio == nil && links == nil && avatar == nil && memberTitle == nil && listPrivate == nil && customPreferences == nil { sql, args, err := sq.Select("*").From("users").Where("id = ?", id).ToSql() if err != nil { return u, errors.Wrap(err, "building sql") @@ -541,6 +542,9 @@ func (db *DB) UpdateUser( if listPrivate != nil { builder = builder.Set("list_private", *listPrivate) } + if customPreferences != nil { + builder = builder.Set("custom_preferences", *customPreferences) + } if avatar != nil { if *avatar == "" { diff --git a/backend/routes/member/create_member.go b/backend/routes/member/create_member.go index b18f8c8..3724cd6 100644 --- a/backend/routes/member/create_member.go +++ b/backend/routes/member/create_member.go @@ -103,15 +103,15 @@ func (s *Server) createMember(w http.ResponseWriter, r *http.Request) (err error } } - if err := validateSlicePtr("name", &cmr.Names); err != nil { + if err := validateSlicePtr("name", &cmr.Names, u.CustomPreferences); err != nil { return *err } - if err := validateSlicePtr("pronoun", &cmr.Pronouns); err != nil { + if err := validateSlicePtr("pronoun", &cmr.Pronouns, u.CustomPreferences); err != nil { return *err } - if err := validateSlicePtr("field", &cmr.Fields); err != nil { + if err := validateSlicePtr("field", &cmr.Fields, u.CustomPreferences); err != nil { return *err } @@ -186,12 +186,12 @@ func (s *Server) createMember(w http.ResponseWriter, r *http.Request) (err error } type validator interface { - Validate() string + Validate(custom db.CustomPreferences) string } // validateSlicePtr validates a slice of validators. // If the slice is nil, a nil error is returned (assuming that the field is not required) -func validateSlicePtr[T validator](typ string, slice *[]T) *server.APIError { +func validateSlicePtr[T validator](typ string, slice *[]T, custom db.CustomPreferences) *server.APIError { if slice == nil { return nil } @@ -211,7 +211,7 @@ func validateSlicePtr[T validator](typ string, slice *[]T) *server.APIError { // validate all fields for i, pronouns := range *slice { - if s := pronouns.Validate(); s != "" { + if s := pronouns.Validate(custom); s != "" { return &server.APIError{ Code: server.ErrBadRequest, Details: fmt.Sprintf("%s %d: %s", typ, i+1, s), diff --git a/backend/routes/member/patch_member.go b/backend/routes/member/patch_member.go index a61d620..4621401 100644 --- a/backend/routes/member/patch_member.go +++ b/backend/routes/member/patch_member.go @@ -41,6 +41,11 @@ func (s *Server) patchMember(w http.ResponseWriter, r *http.Request) error { return server.APIError{Code: server.ErrMemberNotFound} } + u, err := s.DB.User(ctx, claims.UserID) + if err != nil { + return errors.Wrap(err, "getting user") + } + m, err := s.DB.Member(ctx, id) if err != nil { if err == db.ErrMemberNotFound { @@ -148,15 +153,15 @@ func (s *Server) patchMember(w http.ResponseWriter, r *http.Request) error { } } - if err := validateSlicePtr("name", req.Names); err != nil { + if err := validateSlicePtr("name", req.Names, u.CustomPreferences); err != nil { return *err } - if err := validateSlicePtr("pronoun", req.Pronouns); err != nil { + if err := validateSlicePtr("pronoun", req.Pronouns, u.CustomPreferences); err != nil { return *err } - if err := validateSlicePtr("field", req.Fields); err != nil { + if err := validateSlicePtr("field", req.Fields, u.CustomPreferences); err != nil { return *err } @@ -271,11 +276,6 @@ func (s *Server) patchMember(w http.ResponseWriter, r *http.Request) error { return err } - u, err := s.DB.User(ctx, claims.UserID) - if err != nil { - return errors.Wrap(err, "getting user") - } - // echo the updated member back on success render.JSON(w, r, dbMemberToMember(u, m, fields, true)) return nil diff --git a/backend/routes/user/patch_user.go b/backend/routes/user/patch_user.go index 0d3f9f8..5507b45 100644 --- a/backend/routes/user/patch_user.go +++ b/backend/routes/user/patch_user.go @@ -59,7 +59,8 @@ func (s *Server) patchUser(w http.ResponseWriter, r *http.Request) error { req.Fields == nil && req.Names == nil && req.Pronouns == nil && - req.Avatar == nil { + req.Avatar == nil && + req.CustomPreferences == nil { return server.APIError{ Code: server.ErrBadRequest, Details: "Data must not be empty", @@ -105,15 +106,15 @@ func (s *Server) patchUser(w http.ResponseWriter, r *http.Request) error { } } - if err := validateSlicePtr("name", req.Names); err != nil { + if err := validateSlicePtr("name", req.Names, u.CustomPreferences); err != nil { return *err } - if err := validateSlicePtr("pronoun", req.Pronouns); err != nil { + if err := validateSlicePtr("pronoun", req.Pronouns, u.CustomPreferences); err != nil { return *err } - if err := validateSlicePtr("field", req.Fields); err != nil { + if err := validateSlicePtr("field", req.Fields, u.CustomPreferences); err != nil { return *err } @@ -201,7 +202,7 @@ func (s *Server) patchUser(w http.ResponseWriter, r *http.Request) error { } } - u, err = s.DB.UpdateUser(ctx, tx, claims.UserID, req.DisplayName, req.Bio, req.MemberTitle, req.ListPrivate, req.Links, avatarHash) + u, err = s.DB.UpdateUser(ctx, tx, claims.UserID, req.DisplayName, req.Bio, req.MemberTitle, req.ListPrivate, req.Links, avatarHash, req.CustomPreferences) if err != nil && errors.Cause(err) != db.ErrNothingToUpdate { log.Errorf("updating user: %v", err) return err @@ -278,12 +279,12 @@ func (s *Server) patchUser(w http.ResponseWriter, r *http.Request) error { } type validator interface { - Validate() string + Validate(custom db.CustomPreferences) string } // validateSlicePtr validates a slice of validators. // If the slice is nil, a nil error is returned (assuming that the field is not required) -func validateSlicePtr[T validator](typ string, slice *[]T) *server.APIError { +func validateSlicePtr[T validator](typ string, slice *[]T, custom db.CustomPreferences) *server.APIError { if slice == nil { return nil } @@ -303,7 +304,7 @@ func validateSlicePtr[T validator](typ string, slice *[]T) *server.APIError { // validate all fields for i, pronouns := range *slice { - if s := pronouns.Validate(); s != "" { + if s := pronouns.Validate(custom); s != "" { return &server.APIError{ Code: server.ErrBadRequest, Details: fmt.Sprintf("%s %d: %s", typ, i+1, s), diff --git a/go.mod b/go.mod index 9b0a4f9..8cd6ce4 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/go-chi/render v1.0.2 github.com/gobwas/glob v0.2.3 github.com/golang-jwt/jwt/v4 v4.5.0 + github.com/google/uuid v1.3.0 github.com/jackc/pgx/v5 v5.3.1 github.com/joho/godotenv v1.5.1 github.com/mediocregopher/radix/v4 v4.1.2 @@ -40,7 +41,6 @@ require ( github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/google/s2a-go v0.1.0 // indirect - github.com/google/uuid v1.3.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect github.com/gorilla/websocket v1.5.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect diff --git a/scripts/seeddb/main.go b/scripts/seeddb/main.go index acb91a2..a430e72 100644 --- a/scripts/seeddb/main.go +++ b/scripts/seeddb/main.go @@ -48,7 +48,7 @@ func run(c *cli.Context) error { return err } - _, err = pg.UpdateUser(ctx, tx, u.ID, ptr("testing"), ptr("This is a bio!"), nil, ptr(false), &[]string{"https://pronouns.cc"}, nil) + _, err = pg.UpdateUser(ctx, tx, u.ID, ptr("testing"), ptr("This is a bio!"), nil, ptr(false), &[]string{"https://pronouns.cc"}, nil, nil) if err != nil { fmt.Println("error setting user info:", err) return err