/db/sessions.go

https://github.com/starkandwayne/shield · Go · 236 lines · 201 code · 35 blank · 0 comment · 48 complexity · 1c77aa9ff2b019092ebc6e784b691047 MD5 · raw file

  1. package db
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "regexp"
  6. "strings"
  7. "time"
  8. )
  9. type Session struct {
  10. UUID string `json:"uuid"`
  11. UserUUID string `json:"user_uuid"`
  12. CreatedAt int64 `json:"created_at"`
  13. LastSeen int64 `json:"last_seen_at"`
  14. Token string `json:"token_uuid"`
  15. Name string `json:"name"`
  16. IP string `json:"ip_addr"`
  17. UserAgent string `json:"user_agent"`
  18. UserAccount string `json:"user_account"`
  19. CurrentSession bool `json:"current_session"`
  20. }
  21. type SessionFilter struct {
  22. Name string
  23. ExactMatch bool
  24. UUID string
  25. UserUUID string
  26. Limit int
  27. IP string
  28. IsToken bool
  29. }
  30. func (f *SessionFilter) Query() (string, []interface{}) {
  31. wheres := []string{"s.uuid = s.uuid"}
  32. var args []interface{}
  33. if f.UUID != "" {
  34. wheres = append(wheres, "s.uuid = ?")
  35. args = append(args, f.UUID)
  36. }
  37. if f.UserUUID != "" {
  38. wheres = append(wheres, "s.user_uuid = ?")
  39. args = append(args, f.UserUUID)
  40. }
  41. if f.Name != "" {
  42. if f.ExactMatch {
  43. wheres = append(wheres, "s.name = ?")
  44. args = append(args, Pattern(f.Name))
  45. } else {
  46. wheres = append(wheres, "s.name LIKE ?")
  47. args = append(args, f.Name)
  48. }
  49. }
  50. if f.IP != "" {
  51. wheres = append(wheres, "s.ip_addr = ?")
  52. args = append(args, f.IP)
  53. }
  54. if !f.IsToken {
  55. wheres = append(wheres, "s.token IS NULL")
  56. }
  57. limit := ""
  58. if f.Limit > 0 {
  59. limit = " LIMIT ?"
  60. args = append(args, f.Limit)
  61. }
  62. return `
  63. SELECT s.uuid, s.user_uuid, s.created_at, s.last_seen, s.token, s.name, s.ip_addr, s.user_agent, u.account, u.backend
  64. FROM sessions s
  65. INNER JOIN users u ON u.uuid = s.user_uuid
  66. WHERE ` + strings.Join(wheres, " AND ") + `
  67. ` + limit, args
  68. }
  69. func (db *DB) GetAllSessions(filter *SessionFilter) ([]*Session, error) {
  70. if filter == nil {
  71. filter = &SessionFilter{}
  72. }
  73. l := []*Session{}
  74. query, args := filter.Query()
  75. db.exclusive.Lock()
  76. defer db.exclusive.Unlock()
  77. r, err := db.query(query, args...)
  78. if err != nil {
  79. return l, err
  80. }
  81. defer r.Close()
  82. for r.Next() {
  83. s := &Session{}
  84. var (
  85. backend string
  86. last *int64
  87. token sql.NullString
  88. )
  89. if err := r.Scan(&s.UUID, &s.UserUUID, &s.CreatedAt, &last, &token, &s.Name, &s.IP, &s.UserAgent, &s.UserAccount, &backend); err != nil {
  90. return nil, err
  91. }
  92. s.UserAccount = s.UserAccount + "@" + backend
  93. if last != nil {
  94. s.LastSeen = *last
  95. }
  96. if token.Valid {
  97. s.Token = token.String
  98. }
  99. l = append(l, s)
  100. }
  101. return l, nil
  102. }
  103. func (db *DB) GetSession(id string) (*Session, error) {
  104. db.exclusive.Lock()
  105. defer db.exclusive.Unlock()
  106. r, err := db.query(`
  107. SELECT s.uuid, s.user_uuid, s.created_at, s.last_seen, s.token,
  108. s.name, s.ip_addr, s.user_agent, u.account, u.backend
  109. FROM sessions s
  110. INNER JOIN users u ON u.uuid = s.user_uuid
  111. WHERE s.uuid = ?`, id)
  112. if err != nil {
  113. return nil, fmt.Errorf("failed to retrieve session: %s", err)
  114. }
  115. defer r.Close()
  116. if !r.Next() {
  117. return nil, nil
  118. }
  119. s := &Session{}
  120. var (
  121. backend string
  122. last *int64
  123. token sql.NullString
  124. )
  125. if err := r.Scan(&s.UUID, &s.UserUUID, &s.CreatedAt, &last, &token,
  126. &s.Name, &s.IP, &s.UserAgent, &s.UserAccount, &backend); err != nil {
  127. return nil, err
  128. }
  129. s.UserAccount = s.UserAccount + "@" + backend
  130. if token.Valid {
  131. s.Token = token.String
  132. }
  133. if last != nil {
  134. s.LastSeen = *last
  135. }
  136. return s, nil
  137. }
  138. func (db *DB) GetUserForSession(id string) (*User, error) {
  139. db.exclusive.Lock()
  140. defer db.exclusive.Unlock()
  141. r, err := db.query(`
  142. SELECT u.uuid, u.name, u.account, u.backend, u.sysrole,
  143. u.pwhash, u.default_tenant
  144. FROM sessions s
  145. INNER JOIN users u ON u.uuid = s.user_uuid
  146. WHERE s.uuid = ?`, id)
  147. if err != nil {
  148. return nil, err
  149. }
  150. defer r.Close()
  151. if !r.Next() {
  152. return nil, nil
  153. }
  154. u := &User{}
  155. var pwhash sql.NullString
  156. if err := r.Scan(&u.UUID, &u.Name, &u.Account, &u.Backend, &u.SysRole,
  157. &pwhash, &u.DefaultTenant); err != nil {
  158. return nil, err
  159. }
  160. if pwhash.Valid {
  161. u.pwhash = pwhash.String
  162. }
  163. return u, nil
  164. }
  165. func (db *DB) CreateSession(session *Session) (*Session, error) {
  166. if session == nil {
  167. return nil, fmt.Errorf("cannot create an empty (user-less) session")
  168. }
  169. id := RandomID()
  170. err := db.Exec(`
  171. INSERT INTO sessions (uuid, user_uuid, created_at, last_seen, ip_addr, user_agent)
  172. VALUES ( ?, ?, ?, ?, ?, ?)`,
  173. id, session.UserUUID, time.Now().Unix(), time.Now().Unix(), stripIP(session.IP), session.UserAgent)
  174. if err != nil {
  175. return nil, err
  176. }
  177. return db.GetSession(id)
  178. }
  179. func (db *DB) ClearAllSessions() error {
  180. return db.Exec(`DELETE FROM sessions`)
  181. }
  182. func (db *DB) ClearExpiredSessions(expiration_threshold time.Time) error {
  183. return db.Exec(`DELETE FROM sessions WHERE token IS NULL AND last_seen < ?`, expiration_threshold.Unix())
  184. }
  185. func (db *DB) ClearSession(id string) error {
  186. return db.Exec(`DELETE FROM sessions WHERE token IS NULL AND uuid = ?`, id)
  187. }
  188. func (db *DB) PokeSession(session *Session) error {
  189. return db.Exec(`
  190. UPDATE sessions SET last_seen = ?, user_uuid = ?, ip_addr = ?, user_agent = ?
  191. WHERE uuid = ?`, time.Now().Unix(), session.UserUUID, session.IP, session.UserAgent, session.UUID)
  192. }
  193. func stripIP(raw_ip string) string {
  194. return regexp.MustCompile(":[^:]+$").ReplaceAllString(raw_ip, "")
  195. }
  196. func (s *Session) Expired(lifetime int) bool {
  197. return s.Token == "" && time.Now().Unix() > s.LastSeen+(int64)(lifetime*3600)
  198. }