diff --git a/backend/db/member.go b/backend/db/member.go index 08e75e7..2cb377a 100644 --- a/backend/db/member.go +++ b/backend/db/member.go @@ -38,6 +38,11 @@ const ( var memberNameRegex = regexp.MustCompile("^[^@\\?!#/\\\\[\\]\"\\{\\}'$%&()+<=>^|~`,]{1,100}$") func MemberNameValid(name string) bool { + // These two names will break routing, but periods should still be allowed in names otherwise. + if name == "." || name == ".." { + return false + } + return memberNameRegex.MatchString(name) } diff --git a/backend/db/user.go b/backend/db/user.go index f7bdfaa..63b8173 100644 --- a/backend/db/user.go +++ b/backend/db/user.go @@ -113,6 +113,24 @@ func (u User) NumProviders() (numProviders int) { // usernames must match this regex var usernameRegex = regexp.MustCompile(`^[\w-.]{2,40}$`) +func UsernameValid(username string) (err error) { + // This name would break routing, but periods should still be allowed in names otherwise. + if username == ".." { + return ErrInvalidUsername + } + + if !usernameRegex.MatchString(username) { + if len(username) < 2 { + return ErrUsernameTooShort + } else if len(username) > 40 { + return ErrUsernameTooLong + } + + return ErrInvalidUsername + } + return nil +} + const ( ErrUserNotFound = errors.Sentinel("user not found") @@ -139,14 +157,8 @@ const ( func (db *DB) CreateUser(ctx context.Context, tx pgx.Tx, username string) (u User, err error) { // check if the username is valid // if not, return an error depending on what failed - if !usernameRegex.MatchString(username) { - if len(username) < 2 { - return u, ErrUsernameTooShort - } else if len(username) > 40 { - return u, ErrUsernameTooLong - } - - return u, ErrInvalidUsername + if err := UsernameValid(username); err != nil { + return u, err } sql, args, err := sq.Insert("users").Columns("id", "username").Values(xid.New(), username).Suffix("RETURNING *").ToSql() @@ -458,7 +470,7 @@ func (db *DB) Username(ctx context.Context, name string) (u User, err error) { // UsernameTaken checks if the given username is already taken. func (db *DB) UsernameTaken(ctx context.Context, username string) (valid, taken bool, err error) { - if !usernameRegex.MatchString(username) { + if err := UsernameValid(username); err != nil { return false, false, nil } @@ -468,8 +480,8 @@ func (db *DB) UsernameTaken(ctx context.Context, username string) (valid, taken // UpdateUsername validates the given username, then updates the given user's name to it if valid. func (db *DB) UpdateUsername(ctx context.Context, tx pgx.Tx, id xid.ID, newName string) error { - if !usernameRegex.MatchString(newName) { - return ErrInvalidUsername + if err := UsernameValid(newName); err != nil { + return err } sql, args, err := sq.Update("users").Set("username", newName).Where("id = ?", id).ToSql()