From 1cce0defca6d1a72dbfd81852c80c76ab6ac3591 Mon Sep 17 00:00:00 2001 From: sam Date: Sun, 20 Aug 2023 22:45:14 +0200 Subject: [PATCH] feat(backend): make snowflake IDs usable in /users/{id}, /users/{id}/members --- backend/db/member.go | 15 ++++++++++++++- backend/db/user.go | 20 ++++++++++++++++++++ backend/routes/v1/member/get_member.go | 12 +++++++++++- backend/routes/v1/user/get_user.go | 9 +++++++++ 4 files changed, 54 insertions(+), 2 deletions(-) diff --git a/backend/db/member.go b/backend/db/member.go index 5de82aa..53b6a9f 100644 --- a/backend/db/member.go +++ b/backend/db/member.go @@ -73,9 +73,22 @@ func (db *DB) Member(ctx context.Context, id xid.ID) (m Member, err error) { return m, nil } +func (db *DB) MemberBySnowflake(ctx context.Context, id common.MemberID) (m Member, err error) { + sql, args, err := sq.Select("*").From("members").Where("snowflake_id = ?", id).ToSql() + if err != nil { + return m, errors.Wrap(err, "building sql") + } + + err = pgxscan.Get(ctx, db, &m, sql, args...) + if err != nil { + return m, errors.Wrap(err, "executing query") + } + return m, nil +} + // UserMember returns a member scoped by user. func (db *DB) UserMember(ctx context.Context, userID xid.ID, memberRef string) (m Member, err error) { - sql, args, err := sq.Select("*").From("members").Where("user_id = ?", userID).Where("(id = ? or name = ?)", memberRef, memberRef).ToSql() + sql, args, err := sq.Select("*").From("members").Where("user_id = ?", userID).Where("(id = ? or name = ? or snowflake_id = ?)", memberRef, memberRef, memberRef).ToSql() if err != nil { return m, errors.Wrap(err, "building sql") } diff --git a/backend/db/user.go b/backend/db/user.go index 8b9a0ca..80d380d 100644 --- a/backend/db/user.go +++ b/backend/db/user.go @@ -495,6 +495,26 @@ func (db *DB) User(ctx context.Context, id xid.ID) (u User, err error) { return u, nil } +// UserBySnowflake gets a user by their snowflake ID. +func (db *DB) UserBySnowflake(ctx context.Context, id common.UserID) (u User, err error) { + sql, args, err := sq.Select("*", "(SELECT instance FROM fediverse_apps WHERE id = users.fediverse_app_id) AS fediverse_instance"). + From("users").Where("snowflake_id = ?", id).ToSql() + if err != nil { + return u, errors.Wrap(err, "building sql") + } + + err = pgxscan.Get(ctx, db, &u, sql, args...) + if err != nil { + if errors.Cause(err) == pgx.ErrNoRows { + return u, ErrUserNotFound + } + + return u, errors.Wrap(err, "getting user from db") + } + + return u, nil +} + // Username gets a user by username. func (db *DB) Username(ctx context.Context, name string) (u User, err error) { sql, args, err := sq.Select("*").From("users").Where("username = ?", name).ToSql() diff --git a/backend/routes/v1/member/get_member.go b/backend/routes/v1/member/get_member.go index f87bbe9..60fa573 100644 --- a/backend/routes/v1/member/get_member.go +++ b/backend/routes/v1/member/get_member.go @@ -191,12 +191,22 @@ func (s *Server) getMeMember(w http.ResponseWriter, r *http.Request) error { } func (s *Server) parseUser(ctx context.Context, userRef string) (u db.User, err error) { - if id, err := xid.FromString(userRef); err != nil { + // check xid first + if id, err := xid.FromString(userRef); err == nil { u, err := s.DB.User(ctx, id) if err == nil { return u, nil } } + // if not an xid, check by snowflake + if id, err := common.ParseSnowflake(userRef); err == nil { + u, err := s.DB.UserBySnowflake(ctx, common.UserID(id)) + if err == nil { + return u, nil + } + } + + // else, use username return s.DB.Username(ctx, userRef) } diff --git a/backend/routes/v1/user/get_user.go b/backend/routes/v1/user/get_user.go index 4826ed0..af43a4b 100644 --- a/backend/routes/v1/user/get_user.go +++ b/backend/routes/v1/user/get_user.go @@ -129,6 +129,15 @@ func (s *Server) getUser(w http.ResponseWriter, r *http.Request) (err error) { } } + if u.ID.IsNil() { + if id, err := common.ParseSnowflake(userRef); err == nil { + u, err = s.DB.UserBySnowflake(ctx, common.UserID(id)) + if err != nil { + log.Errorf("getting user by snowflake: %v", err) + } + } + } + if u.ID.IsNil() { u, err = s.DB.Username(ctx, userRef) if err == db.ErrUserNotFound {