/internal/test_db.go

https://github.com/go-reform/reform · Go · 155 lines · 123 code · 26 blank · 6 comment · 25 complexity · 19ae93929043b80855d4947ebd037289 MD5 · raw file

  1. package internal
  2. import (
  3. "database/sql"
  4. "log"
  5. "os"
  6. "strings"
  7. "sync"
  8. "time"
  9. sqlite3Driver "github.com/mattn/go-sqlite3"
  10. "gopkg.in/reform.v1"
  11. "gopkg.in/reform.v1/dialects"
  12. "gopkg.in/reform.v1/dialects/mssql" //nolint:staticcheck
  13. "gopkg.in/reform.v1/dialects/mysql"
  14. "gopkg.in/reform.v1/dialects/postgresql"
  15. "gopkg.in/reform.v1/dialects/sqlite3"
  16. "gopkg.in/reform.v1/dialects/sqlserver"
  17. )
  18. //nolint:gochecknoglobals
  19. var (
  20. sqlite3RegisterOnce sync.Once
  21. inspectOnce sync.Once
  22. )
  23. // ConnectToTestDB returns open and prepared connection to test DB.
  24. func ConnectToTestDB() *reform.DB {
  25. driver := strings.TrimSpace(os.Getenv("REFORM_TEST_DRIVER"))
  26. source := strings.TrimSpace(os.Getenv("REFORM_TEST_SOURCE"))
  27. if driver == "" || source == "" {
  28. log.Fatal("no driver or source, set REFORM_TEST_DRIVER and REFORM_TEST_SOURCE")
  29. }
  30. // register custom function "sleep" for context tests
  31. if driver == "sqlite3" {
  32. driver = "sqlite3_with_sleep"
  33. sqlite3RegisterOnce.Do(func() {
  34. sleep := func(nsec int64) (int64, error) {
  35. time.Sleep(time.Duration(nsec))
  36. return nsec, nil
  37. }
  38. sql.Register(driver, &sqlite3Driver.SQLiteDriver{
  39. ConnectHook: func(conn *sqlite3Driver.SQLiteConn) error {
  40. return conn.RegisterFunc("sleep", sleep, false)
  41. },
  42. })
  43. })
  44. }
  45. db, err := sql.Open(driver, source)
  46. if err != nil {
  47. log.Fatal(err)
  48. }
  49. // Use single connection so various session-related variables work.
  50. // For example: "PRAGMA foreign_keys" for SQLite3, "SET IDENTITY_INSERT" for MS SQL, etc.
  51. db.SetMaxIdleConns(1)
  52. db.SetMaxOpenConns(1)
  53. db.SetConnMaxLifetime(0)
  54. if err = db.Ping(); err != nil {
  55. log.Fatal(err)
  56. }
  57. now := time.Now()
  58. // select dialect for driver
  59. dialect := dialects.ForDriver(driver)
  60. switch dialect {
  61. case postgresql.Dialect:
  62. inspectOnce.Do(func() {
  63. log.Printf("driver = %q, source = %q", driver, source)
  64. log.Printf("time.Now() = %s", now)
  65. log.Printf("time.Now().UTC() = %s", now.UTC())
  66. var version, tz string
  67. if err = db.QueryRow("SHOW server_version").Scan(&version); err != nil {
  68. log.Fatal(err)
  69. }
  70. if err = db.QueryRow("SHOW TimeZone").Scan(&tz); err != nil {
  71. log.Fatal(err)
  72. }
  73. log.Printf("PostgreSQL version = %q", version)
  74. log.Printf("PostgreSQL TimeZone = %q", tz)
  75. })
  76. case mysql.Dialect:
  77. inspectOnce.Do(func() {
  78. log.Printf("driver = %q, source = %q", driver, source)
  79. log.Printf("time.Now() = %s", now)
  80. log.Printf("time.Now().UTC() = %s", now.UTC())
  81. q := "SELECT @@version, @@sql_mode, @@autocommit, @@time_zone"
  82. var version, mode, autocommit, tz string
  83. if err = db.QueryRow(q).Scan(&version, &mode, &autocommit, &tz); err != nil {
  84. log.Fatal(err)
  85. }
  86. log.Printf("MySQL version = %q", version)
  87. log.Printf("MySQL sql_mode = %q", mode)
  88. log.Printf("MySQL autocommit = %q", autocommit)
  89. log.Printf("MySQL time_zone = %q", tz)
  90. })
  91. case sqlite3.Dialect:
  92. if _, err = db.Exec("PRAGMA foreign_keys = ON"); err != nil {
  93. log.Fatal(err)
  94. }
  95. inspectOnce.Do(func() {
  96. log.Printf("driver = %q, source = %q", driver, source)
  97. log.Printf("time.Now() = %s", now)
  98. log.Printf("time.Now().UTC() = %s", now.UTC())
  99. var version, sourceID string
  100. if err = db.QueryRow("SELECT sqlite_version(), sqlite_source_id()").Scan(&version, &sourceID); err != nil {
  101. log.Fatal(err)
  102. }
  103. log.Printf("SQLite3 version = %q", version)
  104. log.Printf("SQLite3 source = %q", sourceID)
  105. })
  106. case mssql.Dialect: //nolint:staticcheck
  107. fallthrough
  108. case sqlserver.Dialect:
  109. inspectOnce.Do(func() {
  110. log.Printf("driver = %q, source = %q", driver, source)
  111. log.Printf("time.Now() = %s", now)
  112. log.Printf("time.Now().UTC() = %s", now.UTC())
  113. var version string
  114. var options uint16
  115. if err = db.QueryRow("SELECT @@VERSION, @@OPTIONS").Scan(&version, &options); err != nil {
  116. log.Fatal(err)
  117. }
  118. xact := "ON"
  119. if options&0x4000 == 0 {
  120. xact = "OFF"
  121. }
  122. log.Printf("MS SQL VERSION = %s", version)
  123. log.Printf("MS SQL OPTIONS = %#4x (XACT_ABORT %s)", options, xact)
  124. })
  125. default:
  126. log.Fatalf("reform: no dialect for driver %s", driver)
  127. }
  128. return reform.NewDB(db, dialect, nil)
  129. }