feat(backend): make snowflake IDs usable in /users/{id}, /users/{id}/members

This commit is contained in:
sam 2023-08-20 22:45:14 +02:00
parent b1a7ef89ca
commit 1cce0defca
No known key found for this signature in database
GPG Key ID: B4EF20DDE721CAA1
4 changed files with 54 additions and 2 deletions

View File

@ -73,9 +73,22 @@ func (db *DB) Member(ctx context.Context, id xid.ID) (m Member, err error) {
return m, nil return m, nil
} }
func (db *DB) MemberBySnowflake(ctx context.Context, id common.MemberID) (m Member, err error) {
sql, args, err := sq.Select("*").From("members").Where("snowflake_id = ?", id).ToSql()
if err != nil {
return m, errors.Wrap(err, "building sql")
}
err = pgxscan.Get(ctx, db, &m, sql, args...)
if err != nil {
return m, errors.Wrap(err, "executing query")
}
return m, nil
}
// UserMember returns a member scoped by user. // UserMember returns a member scoped by user.
func (db *DB) UserMember(ctx context.Context, userID xid.ID, memberRef string) (m Member, err error) { func (db *DB) UserMember(ctx context.Context, userID xid.ID, memberRef string) (m Member, err error) {
sql, args, err := sq.Select("*").From("members").Where("user_id = ?", userID).Where("(id = ? or name = ?)", memberRef, memberRef).ToSql() sql, args, err := sq.Select("*").From("members").Where("user_id = ?", userID).Where("(id = ? or name = ? or snowflake_id = ?)", memberRef, memberRef, memberRef).ToSql()
if err != nil { if err != nil {
return m, errors.Wrap(err, "building sql") return m, errors.Wrap(err, "building sql")
} }

View File

@ -495,6 +495,26 @@ func (db *DB) User(ctx context.Context, id xid.ID) (u User, err error) {
return u, nil return u, nil
} }
// UserBySnowflake gets a user by their snowflake ID.
func (db *DB) UserBySnowflake(ctx context.Context, id common.UserID) (u User, err error) {
sql, args, err := sq.Select("*", "(SELECT instance FROM fediverse_apps WHERE id = users.fediverse_app_id) AS fediverse_instance").
From("users").Where("snowflake_id = ?", id).ToSql()
if err != nil {
return u, errors.Wrap(err, "building sql")
}
err = pgxscan.Get(ctx, db, &u, sql, args...)
if err != nil {
if errors.Cause(err) == pgx.ErrNoRows {
return u, ErrUserNotFound
}
return u, errors.Wrap(err, "getting user from db")
}
return u, nil
}
// Username gets a user by username. // Username gets a user by username.
func (db *DB) Username(ctx context.Context, name string) (u User, err error) { func (db *DB) Username(ctx context.Context, name string) (u User, err error) {
sql, args, err := sq.Select("*").From("users").Where("username = ?", name).ToSql() sql, args, err := sq.Select("*").From("users").Where("username = ?", name).ToSql()

View File

@ -191,12 +191,22 @@ func (s *Server) getMeMember(w http.ResponseWriter, r *http.Request) error {
} }
func (s *Server) parseUser(ctx context.Context, userRef string) (u db.User, err error) { func (s *Server) parseUser(ctx context.Context, userRef string) (u db.User, err error) {
if id, err := xid.FromString(userRef); err != nil { // check xid first
if id, err := xid.FromString(userRef); err == nil {
u, err := s.DB.User(ctx, id) u, err := s.DB.User(ctx, id)
if err == nil { if err == nil {
return u, nil return u, nil
} }
} }
// if not an xid, check by snowflake
if id, err := common.ParseSnowflake(userRef); err == nil {
u, err := s.DB.UserBySnowflake(ctx, common.UserID(id))
if err == nil {
return u, nil
}
}
// else, use username
return s.DB.Username(ctx, userRef) return s.DB.Username(ctx, userRef)
} }

View File

@ -129,6 +129,15 @@ func (s *Server) getUser(w http.ResponseWriter, r *http.Request) (err error) {
} }
} }
if u.ID.IsNil() {
if id, err := common.ParseSnowflake(userRef); err == nil {
u, err = s.DB.UserBySnowflake(ctx, common.UserID(id))
if err != nil {
log.Errorf("getting user by snowflake: %v", err)
}
}
}
if u.ID.IsNil() { if u.ID.IsNil() {
u, err = s.DB.Username(ctx, userRef) u, err = s.DB.Username(ctx, userRef)
if err == db.ErrUserNotFound { if err == db.ErrUserNotFound {