/internal/api/external/auth/auth.go

https://github.com/brocaar/chirpstack-application-server · Go · 192 lines · 127 code · 38 blank · 27 comment · 34 complexity · 5a69ff76c2d736c41f950613d054c831 MD5 · raw file

  1. package auth
  2. import (
  3. "fmt"
  4. "regexp"
  5. "github.com/gofrs/uuid"
  6. jwt "github.com/golang-jwt/jwt/v4"
  7. "github.com/jmoiron/sqlx"
  8. "github.com/pkg/errors"
  9. log "github.com/sirupsen/logrus"
  10. "golang.org/x/net/context"
  11. "google.golang.org/grpc/metadata"
  12. "github.com/brocaar/chirpstack-application-server/internal/storage"
  13. )
  // validAuthorizationRegexp matches an Authorization header value of the form
  // "Bearer <credentials>" (scheme matched case-insensitively) and captures the
  // credentials in group 1. RFC 6750 §2.1 allows one or more spaces between the
  // scheme and the token, so any run of whitespace is accepted there. Trailing
  // whitespace is excluded from the capture so it cannot leak into the JWT that
  // is parsed afterwards.
  var validAuthorizationRegexp = regexp.MustCompile(`(?i)^bearer\s+(.*?)\s*$`)
  15. // Claims defines the struct containing the token claims.
  16. type Claims struct {
  17. jwt.StandardClaims
  18. // Username defines the identity of the user.
  19. Username string `json:"username"`
  20. // UserID defines the ID of th user.
  21. UserID int64 `json:"user_id"`
  22. // APIKeyID defines the API key ID.
  23. APIKeyID uuid.UUID `json:"api_key_id"`
  24. }
  25. // Validator defines the interface a validator needs to implement.
  26. type Validator interface {
  27. // Validate validates the given set of validators against the given context.
  28. // Must return after the first validator function either returns true or
  29. // and error. The way how the validation must be seens is:
  30. // if validatorFunc1 || validatorFunc2 || validatorFunc3 ...
  31. // In case multiple validators must validate to true, then a validator
  32. // func needs to be implemented which validates a given set of funcs as:
  33. // if validatorFunc1 && validatorFunc2 && ValidatorFunc3 ...
  34. Validate(context.Context, ...ValidatorFunc) error
  35. // GetSubject returns the claim subject.
  36. GetSubject(context.Context) (string, error)
  37. // GetUser returns the user object.
  38. GetUser(context.Context) (storage.User, error)
  39. // GetAPIKey returns the API key ID.
  40. GetAPIKeyID(context.Context) (uuid.UUID, error)
  41. }
  42. // ValidatorFunc defines the signature of a claim validator function.
  43. // It returns a bool indicating if the validation passed or failed and an
  44. // error in case an error occurred (e.g. db connectivity).
  45. type ValidatorFunc func(sqlx.Queryer, *Claims) (bool, error)
  46. // JWTValidator validates JWT tokens.
  47. type JWTValidator struct {
  48. db sqlx.Ext
  49. secret string
  50. algorithm string
  51. }
  52. // NewJWTValidator creates a new JWTValidator.
  53. func NewJWTValidator(db sqlx.Ext, algorithm, secret string) *JWTValidator {
  54. return &JWTValidator{
  55. db: db,
  56. secret: secret,
  57. algorithm: algorithm,
  58. }
  59. }
  60. // Validate validates the token from the given context against the given
  61. // validator funcs.
  62. func (v JWTValidator) Validate(ctx context.Context, funcs ...ValidatorFunc) error {
  63. claims, err := v.getClaims(ctx)
  64. if err != nil {
  65. return err
  66. }
  67. if claims.Audience != "as" {
  68. return ErrNotAuthorized
  69. }
  70. for _, f := range funcs {
  71. ok, err := f(v.db, claims)
  72. if err != nil {
  73. return errors.Wrap(err, "validator func error")
  74. }
  75. if ok {
  76. return nil
  77. }
  78. }
  79. return ErrNotAuthorized
  80. }
  81. // GetSubject returns the subject of the claim.
  82. func (v JWTValidator) GetSubject(ctx context.Context) (string, error) {
  83. claims, err := v.getClaims(ctx)
  84. if err != nil {
  85. return "", err
  86. }
  87. return claims.Subject, nil
  88. }
  89. // GetAPIKeyID returns the API key of the token.
  90. func (v JWTValidator) GetAPIKeyID(ctx context.Context) (uuid.UUID, error) {
  91. claims, err := v.getClaims(ctx)
  92. if err != nil {
  93. return uuid.Nil, err
  94. }
  95. return claims.APIKeyID, nil
  96. }
  97. // GetUser returns the user object.
  98. func (v JWTValidator) GetUser(ctx context.Context) (storage.User, error) {
  99. claims, err := v.getClaims(ctx)
  100. if err != nil {
  101. return storage.User{}, err
  102. }
  103. if claims.Subject != SubjectUser {
  104. return storage.User{}, errors.New("subject must be user")
  105. }
  106. if claims.UserID != 0 {
  107. return storage.GetUser(ctx, v.db, claims.UserID)
  108. }
  109. if claims.Username != "" {
  110. return storage.GetUserByEmail(ctx, v.db, claims.Username)
  111. }
  112. return storage.User{}, errors.New("no username or user_id in claims")
  113. }
  114. func (v JWTValidator) getClaims(ctx context.Context) (*Claims, error) {
  115. tokenStr, err := getTokenFromContext(ctx)
  116. if err != nil {
  117. return nil, errors.Wrap(err, "get token from context error")
  118. }
  119. token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(token *jwt.Token) (interface{}, error) {
  120. if token.Header["alg"] != v.algorithm {
  121. return nil, ErrInvalidAlgorithm
  122. }
  123. return []byte(v.secret), nil
  124. })
  125. if err != nil {
  126. return nil, errors.Wrap(err, "jwt parse error")
  127. }
  128. if !token.Valid {
  129. return nil, ErrInvalidToken
  130. }
  131. claims, ok := token.Claims.(*Claims)
  132. if !ok {
  133. // no need to use a static error, this should never happen
  134. return nil, fmt.Errorf("api/auth: expected *Claims, got %T", token.Claims)
  135. }
  136. return claims, nil
  137. }
  138. func getTokenFromContext(ctx context.Context) (string, error) {
  139. md, ok := metadata.FromIncomingContext(ctx)
  140. if !ok {
  141. return "", ErrNoMetadataInContext
  142. }
  143. token, ok := md["authorization"]
  144. if !ok || len(token) == 0 {
  145. return "", ErrNoAuthorizationInMetadata
  146. }
  147. match := validAuthorizationRegexp.FindStringSubmatch(token[0])
  148. // authorization header should respect RFC1945
  149. if len(match) == 0 {
  150. log.Warning("Deprecated Authorization header : RFC1945 format expected : Authorization: <type> <credentials>")
  151. return token[0], nil
  152. }
  153. return match[1], nil
  154. }