/pkg/database/database.go

https://github.com/stashapp/stash · Go · 232 lines · 174 code · 42 blank · 16 comment · 37 complexity · eb9835a5bfc7f681825cca0722bfeb84 MD5 · raw file

  1. package database
  2. import (
  3. "database/sql"
  4. "errors"
  5. "fmt"
  6. "os"
  7. "time"
  8. "github.com/gobuffalo/packr/v2"
  9. "github.com/golang-migrate/migrate/v4"
  10. sqlite3mig "github.com/golang-migrate/migrate/v4/database/sqlite3"
  11. "github.com/golang-migrate/migrate/v4/source"
  12. "github.com/jmoiron/sqlx"
  13. sqlite3 "github.com/mattn/go-sqlite3"
  14. "github.com/stashapp/stash/pkg/logger"
  15. "github.com/stashapp/stash/pkg/utils"
  16. )
  17. var DB *sqlx.DB
  18. var dbPath string
  19. var appSchemaVersion uint = 12
  20. var databaseSchemaVersion uint
  21. const sqlite3Driver = "sqlite3ex"
  22. func init() {
  23. // register custom driver with regexp function
  24. registerCustomDriver()
  25. }
  26. // Initialize initializes the database. If the database is new, then it
  27. // performs a full migration to the latest schema version. Otherwise, any
  28. // necessary migrations must be run separately using RunMigrations.
  29. // Returns true if the database is new.
  30. func Initialize(databasePath string) bool {
  31. dbPath = databasePath
  32. if err := getDatabaseSchemaVersion(); err != nil {
  33. panic(err)
  34. }
  35. if databaseSchemaVersion == 0 {
  36. // new database, just run the migrations
  37. if err := RunMigrations(); err != nil {
  38. panic(err)
  39. }
  40. // RunMigrations calls Initialise. Just return
  41. return true
  42. } else {
  43. if databaseSchemaVersion > appSchemaVersion {
  44. panic(fmt.Sprintf("Database schema version %d is incompatible with required schema version %d", databaseSchemaVersion, appSchemaVersion))
  45. }
  46. // if migration is needed, then don't open the connection
  47. if NeedsMigration() {
  48. logger.Warnf("Database schema version %d does not match required schema version %d.", databaseSchemaVersion, appSchemaVersion)
  49. return false
  50. }
  51. }
  52. const disableForeignKeys = false
  53. DB = open(databasePath, disableForeignKeys)
  54. return false
  55. }
  56. func open(databasePath string, disableForeignKeys bool) *sqlx.DB {
  57. // https://github.com/mattn/go-sqlite3
  58. url := "file:" + databasePath
  59. if !disableForeignKeys {
  60. url += "?_fk=true"
  61. }
  62. conn, err := sqlx.Open(sqlite3Driver, url)
  63. conn.SetMaxOpenConns(25)
  64. conn.SetMaxIdleConns(4)
  65. if err != nil {
  66. logger.Fatalf("db.Open(): %q\n", err)
  67. }
  68. return conn
  69. }
  70. func Reset(databasePath string) error {
  71. err := DB.Close()
  72. if err != nil {
  73. return errors.New("Error closing database: " + err.Error())
  74. }
  75. err = os.Remove(databasePath)
  76. if err != nil {
  77. return errors.New("Error removing database: " + err.Error())
  78. }
  79. Initialize(databasePath)
  80. return nil
  81. }
  82. // Backup the database
  83. func Backup(backupPath string) error {
  84. db, err := sqlx.Connect(sqlite3Driver, "file:"+dbPath+"?_fk=true")
  85. if err != nil {
  86. return fmt.Errorf("Open database %s failed:%s", dbPath, err)
  87. }
  88. defer db.Close()
  89. logger.Infof("Backing up database into: %s", backupPath)
  90. _, err = db.Exec(`VACUUM INTO "` + backupPath + `"`)
  91. if err != nil {
  92. return fmt.Errorf("Vacuum failed: %s", err)
  93. }
  94. return nil
  95. }
  96. func RestoreFromBackup(backupPath string) error {
  97. logger.Infof("Restoring backup database %s into %s", backupPath, dbPath)
  98. return os.Rename(backupPath, dbPath)
  99. }
  100. // Migrate the database
  101. func NeedsMigration() bool {
  102. return databaseSchemaVersion != appSchemaVersion
  103. }
  104. func AppSchemaVersion() uint {
  105. return appSchemaVersion
  106. }
  107. func DatabaseBackupPath() string {
  108. return fmt.Sprintf("%s.%d.%s", dbPath, databaseSchemaVersion, time.Now().Format("20060102_150405"))
  109. }
  110. func Version() uint {
  111. return databaseSchemaVersion
  112. }
  113. func getMigrate() (*migrate.Migrate, error) {
  114. migrationsBox := packr.New("Migrations Box", "./migrations")
  115. packrSource := &Packr2Source{
  116. Box: migrationsBox,
  117. Migrations: source.NewMigrations(),
  118. }
  119. databasePath := utils.FixWindowsPath(dbPath)
  120. s, _ := WithInstance(packrSource)
  121. const disableForeignKeys = true
  122. conn := open(databasePath, disableForeignKeys)
  123. driver, err := sqlite3mig.WithInstance(conn.DB, &sqlite3mig.Config{})
  124. if err != nil {
  125. return nil, err
  126. }
  127. // use sqlite3Driver so that migration has access to durationToTinyInt
  128. return migrate.NewWithInstance(
  129. "packr2",
  130. s,
  131. databasePath,
  132. driver,
  133. )
  134. }
  135. func getDatabaseSchemaVersion() error {
  136. m, err := getMigrate()
  137. if err != nil {
  138. return err
  139. }
  140. databaseSchemaVersion, _, _ = m.Version()
  141. m.Close()
  142. return nil
  143. }
  144. // Migrate the database
  145. func RunMigrations() error {
  146. m, err := getMigrate()
  147. if err != nil {
  148. panic(err.Error())
  149. }
  150. databaseSchemaVersion, _, _ = m.Version()
  151. stepNumber := appSchemaVersion - databaseSchemaVersion
  152. if stepNumber != 0 {
  153. logger.Infof("Migrating database from version %d to %d", databaseSchemaVersion, appSchemaVersion)
  154. err = m.Steps(int(stepNumber))
  155. if err != nil {
  156. // migration failed
  157. logger.Errorf("Error migrating database: %s", err.Error())
  158. m.Close()
  159. return err
  160. }
  161. }
  162. m.Close()
  163. // re-initialise the database
  164. Initialize(dbPath)
  165. // run a vacuum on the database
  166. logger.Info("Performing vacuum on database")
  167. _, err = DB.Exec("VACUUM")
  168. if err != nil {
  169. logger.Warnf("error while performing post-migration vacuum: %s", err.Error())
  170. }
  171. return nil
  172. }
  173. func registerCustomDriver() {
  174. sql.Register(sqlite3Driver,
  175. &sqlite3.SQLiteDriver{
  176. ConnectHook: func(conn *sqlite3.SQLiteConn) error {
  177. funcs := map[string]interface{}{
  178. "regexp": regexFn,
  179. "durationToTinyInt": durationToTinyIntFn,
  180. }
  181. for name, fn := range funcs {
  182. if err := conn.RegisterFunc(name, fn, true); err != nil {
  183. return fmt.Errorf("Error registering function %s: %s", name, err.Error())
  184. }
  185. }
  186. return nil
  187. },
  188. },
  189. )
  190. }