PageRenderTime 26ms CodeModel.GetById 16ms RepoModel.GetById 0ms app.codeStats 0ms

/migrate.go

https://bitbucket.org/snormore/goose
Go | 278 lines | 193 code | 59 blank | 26 comment | 59 complexity | 398c08adb50df481e1c85d1d18024289 MD5 | raw file
  1. package main
  2. import (
  3. "database/sql"
  4. "errors"
  5. "fmt"
  6. _ "github.com/lib/pq"
  7. _ "github.com/ziutek/mymysql/godrv"
  8. "log"
  9. "os"
  10. "path/filepath"
  11. "sort"
  12. "strconv"
  13. "strings"
  14. "time"
  15. )
  // ErrTableDoesNotExist is the sentinel error that ensureDBVersion compares
// the result of the dialect's dbVersionQuery against; when it matches, the
// version table is created instead of being read.
var ErrTableDoesNotExist = errors.New("table does not exist")
  // MigrationRecord is one row read back from the version table.
//
// ensureDBVersion scans only VersionId and IsApplied; TStamp is not read by
// the code in this file.
type MigrationRecord struct {
	VersionId int64     // migration version the row refers to
	TStamp    time.Time // timestamp of the row
	IsApplied bool      // was this a result of up() or down()
}
  // Migration describes a single migration script found on disk.
type Migration struct {
	Version  int64  // version parsed from the script's file name
	Next     int64  // next version, or -1 if none
	Previous int64  // previous version, -1 if none
	Source   string // .go or .sql script
}
  28. type MigrationSlice []Migration
  29. // helpers so we can use pkg sort
  30. func (s MigrationSlice) Len() int { return len(s) }
  31. func (s MigrationSlice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
  32. func (s MigrationSlice) Less(i, j int) bool { return s[i].Version < s[j].Version }
  33. type MigrationMap struct {
  34. Migrations MigrationSlice // migrations, sorted according to Direction
  35. Direction bool // sort direction: true -> Up, false -> Down
  36. }
  37. func runMigrations(conf *DBConf, migrationsDir string, target int64) {
  38. db, err := sql.Open(conf.Driver.Name, conf.Driver.OpenStr)
  39. if err != nil {
  40. log.Fatal("couldn't open DB:", err)
  41. }
  42. defer db.Close()
  43. current, e := ensureDBVersion(conf, db)
  44. if e != nil {
  45. log.Fatalf("couldn't get DB version: %v", e)
  46. }
  47. mm, err := collectMigrations(migrationsDir, current, target)
  48. if err != nil {
  49. log.Fatal(err)
  50. }
  51. if len(mm.Migrations) == 0 {
  52. fmt.Printf("goose: no migrations to run. current version: %d\n", current)
  53. return
  54. }
  55. mm.Sort(current < target)
  56. fmt.Printf("goose: migrating db environment '%v', current version: %d, target: %d\n",
  57. conf.Env, current, target)
  58. for _, m := range mm.Migrations {
  59. var e error
  60. switch filepath.Ext(m.Source) {
  61. case ".go":
  62. e = runGoMigration(conf, m.Source, m.Version, mm.Direction)
  63. case ".sql":
  64. e = runSQLMigration(db, m.Source, m.Version, mm.Direction)
  65. }
  66. if e != nil {
  67. log.Fatalf("FAIL %v, quitting migration", e)
  68. }
  69. fmt.Println("OK ", filepath.Base(m.Source))
  70. }
  71. }
  72. // collect all the valid looking migration scripts in the
  73. // migrations folder, and key them by version
  74. func collectMigrations(dirpath string, current, target int64) (mm *MigrationMap, err error) {
  75. mm = &MigrationMap{}
  76. // extract the numeric component of each migration,
  77. // filter out any uninteresting files,
  78. // and ensure we only have one file per migration version.
  79. filepath.Walk(dirpath, func(name string, info os.FileInfo, err error) error {
  80. if v, e := numericComponent(name); e == nil {
  81. for _, m := range mm.Migrations {
  82. if v == m.Version {
  83. log.Fatalf("more than one file specifies the migration for version %d (%s and %s)",
  84. v, m.Source, filepath.Join(dirpath, name))
  85. }
  86. }
  87. if versionFilter(v, current, target) {
  88. mm.Append(v, name)
  89. }
  90. }
  91. return nil
  92. })
  93. return mm, nil
  94. }
  // versionFilter reports whether migration version v lies in the range to run
// when moving the database from version current to version target.
//
// When target > current the range is (current, target]. When
// target < current it is (target, current]. When they are equal nothing is
// selected.
func versionFilter(v, current, target int64) bool {
	switch {
	case current < target:
		return current < v && v <= target
	case current > target:
		return target < v && v <= current
	default:
		return false
	}
}
  104. func (mm *MigrationMap) Append(v int64, source string) {
  105. mm.Migrations = append(mm.Migrations, Migration{
  106. Version: v,
  107. Next: -1,
  108. Previous: -1,
  109. Source: source,
  110. })
  111. }
  112. func (mm *MigrationMap) Sort(direction bool) {
  113. sort.Sort(mm.Migrations)
  114. // set direction, and reverse order if need be
  115. mm.Direction = direction
  116. if mm.Direction == false {
  117. for i, j := 0, len(mm.Migrations)-1; i < j; i, j = i+1, j-1 {
  118. mm.Migrations[i], mm.Migrations[j] = mm.Migrations[j], mm.Migrations[i]
  119. }
  120. }
  121. // now that we're sorted in the appropriate direction,
  122. // populate next and previous for each migration
  123. for i, m := range mm.Migrations {
  124. prev := int64(-1)
  125. if i > 0 {
  126. prev = mm.Migrations[i-1].Version
  127. mm.Migrations[i-1].Next = m.Version
  128. }
  129. mm.Migrations[i].Previous = prev
  130. }
  131. }
  // numericComponent extracts the version number from a migration script name
// of the form XXX_descriptivename.ext. XXX is the version and ext (.go or
// .sql) is the kind of migration. Any leading directories in name are ignored.
//
// It returns an error if the extension is not .go or .sql, if there is no "_"
// separator, if XXX does not parse as a base-10 int64 (the error is the one
// from strconv.ParseInt), or if the version is not greater than zero.
func numericComponent(name string) (int64, error) {
	base := filepath.Base(name)

	switch filepath.Ext(base) {
	case ".go", ".sql":
		// recognized migration type
	default:
		return 0, errors.New("not a recognized migration file type")
	}

	sep := strings.Index(base, "_")
	if sep == -1 {
		return 0, errors.New("no separator found")
	}

	version, err := strconv.ParseInt(base[:sep], 10, 64)
	if err != nil {
		return version, err
	}
	if version <= 0 {
		return 0, errors.New("migration IDs must be greater than zero")
	}
	return version, nil
}
  151. // retrieve the current version for this DB.
  152. // Create and initialize the DB version table if it doesn't exist.
  153. func ensureDBVersion(conf *DBConf, db *sql.DB) (int64, error) {
  154. rows, err := conf.Driver.Dialect.dbVersionQuery(db)
  155. if err != nil {
  156. if err == ErrTableDoesNotExist {
  157. return 0, createVersionTable(conf, db)
  158. }
  159. return 0, err
  160. }
  161. // The most recent record for each migration specifies
  162. // whether it has been applied or rolled back.
  163. // The first version we find that has been applied is the current version.
  164. toSkip := make([]int64, 0)
  165. for rows.Next() {
  166. var row MigrationRecord
  167. if err = rows.Scan(&row.VersionId, &row.IsApplied); err != nil {
  168. log.Fatal("error scanning rows:", err)
  169. }
  170. // have we already marked this version to be skipped?
  171. skip := false
  172. for _, v := range toSkip {
  173. if v == row.VersionId {
  174. skip = true
  175. break
  176. }
  177. }
  178. // if version has been applied and not marked to be skipped, we're done
  179. if row.IsApplied && !skip {
  180. return row.VersionId, nil
  181. }
  182. // version is either not applied, or we've already seen a more
  183. // recent version of it that was not applied.
  184. if !skip {
  185. toSkip = append(toSkip, row.VersionId)
  186. }
  187. }
  188. panic("failure in ensureDBVersion()")
  189. }
  190. // Create the goose_db_version table
  191. // and insert the initial 0 value into it
  192. func createVersionTable(conf *DBConf, db *sql.DB) error {
  193. txn, err := db.Begin()
  194. if err != nil {
  195. return err
  196. }
  197. d := conf.Driver.Dialect
  198. for _, str := range []string{d.createVersionTableSql(), d.insertVersionSql()} {
  199. if _, err := txn.Exec(str); err != nil {
  200. txn.Rollback()
  201. return err
  202. }
  203. }
  204. return txn.Commit()
  205. }
  206. // wrapper for ensureDBVersion for callers that don't already have
  207. // their own DB instance
  208. func getDBVersion(conf *DBConf) int64 {
  209. db, err := sql.Open(conf.Driver.Name, conf.Driver.OpenStr)
  210. if err != nil {
  211. log.Fatal("couldn't open DB:", err)
  212. }
  213. defer db.Close()
  214. version, err := ensureDBVersion(conf, db)
  215. if err != nil {
  216. log.Fatalf("couldn't get DB version: %v", err)
  217. }
  218. return version
  219. }