/storage/db/db.go

https://github.com/markphelps/flipt · Go · 174 lines · 133 code · 34 blank · 7 comment · 27 complexity · 46d1799f8c39ccddb8c0ced13b0ab3b2 MD5 · raw file

  1. package db
  2. import (
  3. "database/sql"
  4. "database/sql/driver"
  5. "fmt"
  6. "net/url"
  7. "github.com/go-sql-driver/mysql"
  8. "github.com/lib/pq"
  9. "github.com/luna-duclos/instrumentedsql"
  10. "github.com/luna-duclos/instrumentedsql/opentracing"
  11. "github.com/markphelps/flipt/config"
  12. "github.com/mattn/go-sqlite3"
  13. "github.com/xo/dburl"
  14. )
  15. // Open opens a connection to the db
  16. func Open(cfg config.Config) (*sql.DB, Driver, error) {
  17. sql, driver, err := open(cfg, false)
  18. if err != nil {
  19. return nil, 0, err
  20. }
  21. sql.SetMaxIdleConns(cfg.Database.MaxIdleConn)
  22. if cfg.Database.MaxOpenConn > 0 {
  23. sql.SetMaxOpenConns(cfg.Database.MaxOpenConn)
  24. }
  25. if cfg.Database.ConnMaxLifetime > 0 {
  26. sql.SetConnMaxLifetime(cfg.Database.ConnMaxLifetime)
  27. }
  28. registerMetrics(driver, sql)
  29. return sql, driver, nil
  30. }
  31. func open(cfg config.Config, migrate bool) (*sql.DB, Driver, error) {
  32. d, url, err := parse(cfg, migrate)
  33. if err != nil {
  34. return nil, 0, err
  35. }
  36. driverName := fmt.Sprintf("instrumented-%s", d)
  37. var dr driver.Driver
  38. switch d {
  39. case SQLite:
  40. dr = &sqlite3.SQLiteDriver{}
  41. case Postgres:
  42. dr = &pq.Driver{}
  43. case MySQL:
  44. dr = &mysql.MySQLDriver{}
  45. }
  46. registered := false
  47. for _, dd := range sql.Drivers() {
  48. if dd == driverName {
  49. registered = true
  50. break
  51. }
  52. }
  53. if !registered {
  54. sql.Register(driverName, instrumentedsql.WrapDriver(dr, instrumentedsql.WithTracer(opentracing.NewTracer(false))))
  55. }
  56. db, err := sql.Open(driverName, url.DSN)
  57. if err != nil {
  58. return nil, 0, fmt.Errorf("opening db for driver: %s %w", d, err)
  59. }
  60. return db, d, nil
  61. }
  // Driver identifies one of the SQL backends this package can open.
  //
  // The zero value is intentionally not a valid driver. A lookup in
  // stringToDriver that misses therefore yields 0, which callers can test for.
  type Driver uint8

  // Supported drivers. The first iota value is discarded so that valid
  // drivers start at 1.
  const (
  	_ Driver = iota
  	// SQLite is the SQLite backend, named "sqlite3".
  	SQLite
  	// Postgres is the PostgreSQL backend, named "postgres".
  	Postgres
  	// MySQL is the MySQL backend, named "mysql".
  	MySQL
  )

  var (
  	// driverToString maps each Driver to its textual name.
  	driverToString = map[Driver]string{
  		SQLite:   "sqlite3",
  		Postgres: "postgres",
  		MySQL:    "mysql",
  	}

  	// stringToDriver is the inverse of driverToString: it resolves a textual
  	// driver name back to its Driver.
  	stringToDriver = map[string]Driver{
  		"sqlite3":  SQLite,
  		"postgres": Postgres,
  		"mysql":    MySQL,
  	}
  )

  // String returns the textual name of d, or "" when d is not a known driver.
  func (d Driver) String() string {
  	if name, ok := driverToString[d]; ok {
  		return name
  	}
  	return ""
  }
  88. func parse(cfg config.Config, migrate bool) (Driver, *dburl.URL, error) {
  89. u := cfg.Database.URL
  90. if u == "" {
  91. host := cfg.Database.Host
  92. if cfg.Database.Port > 0 {
  93. host = fmt.Sprintf("%s:%d", host, cfg.Database.Port)
  94. }
  95. uu := url.URL{
  96. Scheme: cfg.Database.Protocol.String(),
  97. Host: host,
  98. Path: cfg.Database.Name,
  99. }
  100. if cfg.Database.User != "" {
  101. if cfg.Database.Password != "" {
  102. uu.User = url.UserPassword(cfg.Database.User, cfg.Database.Password)
  103. } else {
  104. uu.User = url.User(cfg.Database.User)
  105. }
  106. }
  107. u = uu.String()
  108. }
  109. errURL := func(rawurl string, err error) error {
  110. return fmt.Errorf("error parsing url: %q, %v", rawurl, err)
  111. }
  112. url, err := dburl.Parse(u)
  113. if err != nil {
  114. return 0, nil, errURL(u, err)
  115. }
  116. driver := stringToDriver[url.Driver]
  117. if driver == 0 {
  118. return 0, nil, fmt.Errorf("unknown database driver for: %q", url.Driver)
  119. }
  120. switch driver {
  121. case MySQL:
  122. v := url.Query()
  123. v.Set("multiStatements", "true")
  124. v.Set("parseTime", "true")
  125. if !migrate {
  126. v.Set("sql_mode", "ANSI")
  127. }
  128. url.RawQuery = v.Encode()
  129. // we need to re-parse since we modified the query params
  130. url, err = dburl.Parse(url.URL.String())
  131. case SQLite:
  132. v := url.Query()
  133. v.Set("cache", "shared")
  134. v.Set("_fk", "true")
  135. url.RawQuery = v.Encode()
  136. // we need to re-parse since we modified the query params
  137. url, err = dburl.Parse(url.URL.String())
  138. }
  139. return driver, url, err
  140. }