2022-05-04 07:27:16 -07:00
|
|
|
package auth
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"crypto/rand"
|
|
|
|
"encoding/base64"
|
|
|
|
|
|
|
|
"github.com/mediocregopher/radix/v4"
|
|
|
|
)
|
|
|
|
|
|
|
|
// numStates is the number of CSRF states stored in Redis at any one time.
|
|
|
|
// This must be an integer.
|
|
|
|
const numStates = "1000"
|
|
|
|
|
|
|
|
// setCSRFState generates a random string to use as state, then stores that in Redis.
|
|
|
|
func (s *Server) setCSRFState(ctx context.Context) (string, error) {
|
2022-05-17 13:35:26 -07:00
|
|
|
state := RandBase64(32)
|
2022-05-04 07:27:16 -07:00
|
|
|
|
2022-05-17 13:35:26 -07:00
|
|
|
err := s.DB.MultiCmd(ctx,
|
2022-05-04 07:27:16 -07:00
|
|
|
radix.Cmd(nil, "LPUSH", "csrf", state),
|
|
|
|
radix.Cmd(nil, "LTRIM", "csrf", "0", numStates),
|
|
|
|
)
|
|
|
|
return state, err
|
|
|
|
}
|
|
|
|
|
|
|
|
// validateCSRFState checks if the given state exists in Redis.
|
|
|
|
func (s *Server) validateCSRFState(ctx context.Context, state string) (matched bool, err error) {
|
|
|
|
var num int
|
|
|
|
err = s.DB.Redis.Do(ctx, radix.Cmd(&num, "LREM", "csrf", "1", state))
|
|
|
|
if err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
return num > 0, nil
|
|
|
|
}
|
2022-05-17 13:35:26 -07:00
|
|
|
|
|
|
|
// RandBase64 returns size bytes read from crypto/rand, encoded with the
// unpadded, URL-safe base64 alphabet (base64.RawURLEncoding).
//
// It panics if the random source returns an error.
func RandBase64(size int) string {
	buf := make([]byte, size)
	if _, err := rand.Read(buf); err != nil {
		panic(err)
	}
	return base64.RawURLEncoding.EncodeToString(buf)
}
|