fix(backend): fix sql errors in CreateUser and User.UpdateFromDiscord

This commit is contained in:
Sam 2023-02-25 22:16:22 +01:00
parent b92ced7d1a
commit b41ca0b753
No known key found for this signature in database
GPG Key ID: B4EF20DDE721CAA1
1 changed files with 20 additions and 10 deletions

View File

@ -65,15 +65,18 @@ func (db *DB) CreateUser(ctx context.Context, tx pgx.Tx, username string) (u Use
return u, ErrInvalidUsername
}
sql, args, err := sq.Insert("users").Columns("id", "username").Values(xid.New(), username).Suffix("RETURNING *").ToSql()
sql, args, err := sq.Insert("users").Columns("id", "username").Values(xid.New(), username).Suffix("RETURNING id").ToSql()
if err != nil {
return u, errors.Wrap(err, "building sql")
}
err = pgxscan.Get(ctx, tx, &u, sql, args...)
var id xid.ID
err = tx.QueryRow(ctx, sql, args...).Scan(&id)
if err != nil {
if v, ok := errors.Cause(err).(*pgconn.PgError); ok {
if v.Code == "23505" { // unique constraint violation
pge := &pgconn.PgError{}
if errors.As(err, &pge) {
// unique constraint violation
if pge.Code == "23505" {
return u, ErrUsernameTaken
}
}
@ -81,7 +84,7 @@ func (db *DB) CreateUser(ctx context.Context, tx pgx.Tx, username string) (u Use
return u, errors.Cause(err)
}
return u, nil
return db.getUser(ctx, tx, id)
}
// DiscordUser fetches a user by Discord user ID.
@ -103,18 +106,25 @@ func (db *DB) DiscordUser(ctx context.Context, discordID string) (u User, err er
}
func (u *User) UpdateFromDiscord(ctx context.Context, db querier, du *discordgo.User) error {
builder := sq.Update("users").
sql, args, err := sq.Update("users").
Set("discord", du.ID).
Set("discord_username", du.String()).
Where("id = ?", u.ID).
Suffix("RETURNING *")
sql, args, err := builder.ToSql()
ToSql()
if err != nil {
return errors.Wrap(err, "building sql")
}
return pgxscan.Get(ctx, db, u, sql, args...)
_, err = db.Exec(ctx, sql, args...)
if err != nil {
return errors.Wrap(err, "executing query")
}
u.Discord = &du.ID
username := du.String()
u.DiscordUsername = &username
return nil
}
func (db *DB) getUser(ctx context.Context, q querier, id xid.ID) (u User, err error) {