/translators/mysql.go

https://github.com/gobuffalo/fizz · Go · 300 lines · 257 code · 40 blank · 3 comment · 64 complexity · 2d7b2c79f65af710e57a6eac9c47dd0b MD5 · raw file

  1. package translators
  2. import (
  3. "fmt"
  4. "regexp"
  5. "strings"
  6. "github.com/gobuffalo/fizz"
  7. )
  8. // MySQL is a MySQL-specific translator.
  9. type MySQL struct {
  10. Schema SchemaQuery
  11. strDefaultSize int
  12. }
  13. // NewMySQL constructs a new MySQL translator.
  14. func NewMySQL(url, name string) *MySQL {
  15. schema := &mysqlSchema{Schema{URL: url, Name: name, schema: map[string]*fizz.Table{}}}
  16. schema.Builder = schema
  17. return &MySQL{
  18. Schema: schema,
  19. strDefaultSize: 255,
  20. }
  21. }
  22. func (MySQL) Name() string {
  23. return "mysql"
  24. }
  25. // CreateTable translates a fizz Table to its MySQL SQL definition.
  26. func (p *MySQL) CreateTable(t fizz.Table) (string, error) {
  27. sql := []string{}
  28. cols := []string{}
  29. for _, c := range t.Columns {
  30. cols = append(cols, p.buildColumn(c))
  31. if c.Primary {
  32. cols = append(cols, fmt.Sprintf("PRIMARY KEY(`%s`)", c.Name))
  33. }
  34. }
  35. for _, fk := range t.ForeignKeys {
  36. cols = append(cols, p.buildForeignKey(t, fk, true))
  37. }
  38. primaryKeys := t.PrimaryKeys()
  39. if len(primaryKeys) > 1 {
  40. pks := make([]string, len(primaryKeys))
  41. for i, pk := range primaryKeys {
  42. pks[i] = fmt.Sprintf("`%s`", pk)
  43. }
  44. cols = append(cols, fmt.Sprintf("PRIMARY KEY(%s)", strings.Join(pks, ", ")))
  45. }
  46. s := fmt.Sprintf("CREATE TABLE %s (\n%s\n) ENGINE=InnoDB;", p.escapeIdentifier(t.Name), strings.Join(cols, ",\n"))
  47. sql = append(sql, s)
  48. for _, i := range t.Indexes {
  49. s, err := p.AddIndex(fizz.Table{
  50. Name: t.Name,
  51. Indexes: []fizz.Index{i},
  52. })
  53. if err != nil {
  54. return "", err
  55. }
  56. sql = append(sql, s)
  57. }
  58. return strings.Join(sql, "\n"), nil
  59. }
  60. func (p *MySQL) DropTable(t fizz.Table) (string, error) {
  61. return fmt.Sprintf("DROP TABLE %s;", p.escapeIdentifier(t.Name)), nil
  62. }
  63. func (p *MySQL) RenameTable(t []fizz.Table) (string, error) {
  64. if len(t) < 2 {
  65. return "", fmt.Errorf("not enough table names supplied")
  66. }
  67. return fmt.Sprintf("ALTER TABLE %s RENAME TO %s;", p.escapeIdentifier(t[0].Name), p.escapeIdentifier(t[1].Name)), nil
  68. }
  69. func (p *MySQL) ChangeColumn(t fizz.Table) (string, error) {
  70. if len(t.Columns) == 0 {
  71. return "", fmt.Errorf("not enough columns supplied")
  72. }
  73. c := t.Columns[0]
  74. s := fmt.Sprintf("ALTER TABLE %s MODIFY %s;", p.escapeIdentifier(t.Name), p.buildColumn(c))
  75. return s, nil
  76. }
  77. func (p *MySQL) AddColumn(t fizz.Table) (string, error) {
  78. if len(t.Columns) == 0 {
  79. return "", fmt.Errorf("not enough columns supplied")
  80. }
  81. if _, ok := t.Columns[0].Options["first"]; ok {
  82. c := t.Columns[0]
  83. s := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s FIRST;", p.escapeIdentifier(t.Name), p.buildColumn(c))
  84. return s, nil
  85. }
  86. if val, ok := t.Columns[0].Options["after"]; ok {
  87. c := t.Columns[0]
  88. s := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s AFTER `%s`;", p.escapeIdentifier(t.Name), p.buildColumn(c), val)
  89. return s, nil
  90. }
  91. c := t.Columns[0]
  92. s := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s;", p.escapeIdentifier(t.Name), p.buildColumn(c))
  93. return s, nil
  94. }
  95. func (p *MySQL) DropColumn(t fizz.Table) (string, error) {
  96. if len(t.Columns) == 0 {
  97. return "", fmt.Errorf("not enough columns supplied")
  98. }
  99. c := t.Columns[0]
  100. return fmt.Sprintf("ALTER TABLE %s DROP COLUMN `%s`;", p.escapeIdentifier(t.Name), c.Name), nil
  101. }
  102. func (p *MySQL) RenameColumn(t fizz.Table) (string, error) {
  103. if len(t.Columns) < 2 {
  104. return "", fmt.Errorf("not enough columns supplied")
  105. }
  106. oc := t.Columns[0]
  107. nc := t.Columns[1]
  108. ti, err := p.Schema.TableInfo(t.Name)
  109. if err != nil {
  110. return "", err
  111. }
  112. var c fizz.Column
  113. for _, c = range ti.Columns {
  114. if c.Name == oc.Name {
  115. break
  116. }
  117. }
  118. col := p.buildColumn(c)
  119. col = strings.Replace(col, oc.Name, fmt.Sprintf("%s` `%s", oc.Name, nc.Name), -1)
  120. s := fmt.Sprintf("ALTER TABLE %s CHANGE %s;", p.escapeIdentifier(t.Name), col)
  121. return s, nil
  122. }
  123. func (p *MySQL) AddIndex(t fizz.Table) (string, error) {
  124. if len(t.Indexes) == 0 {
  125. return "", fmt.Errorf("not enough indexes supplied")
  126. }
  127. i := t.Indexes[0]
  128. cols := []string{}
  129. for _, c := range i.Columns {
  130. cols = append(cols, fmt.Sprintf("`%s`", c))
  131. }
  132. s := fmt.Sprintf("CREATE INDEX `%s` ON %s (%s);", i.Name, p.escapeIdentifier(t.Name), strings.Join(cols, ", "))
  133. if i.Unique {
  134. s = strings.Replace(s, "CREATE", "CREATE UNIQUE", 1)
  135. }
  136. return s, nil
  137. }
  138. func (p *MySQL) DropIndex(t fizz.Table) (string, error) {
  139. if len(t.Indexes) == 0 {
  140. return "", fmt.Errorf("not enough indexes supplied")
  141. }
  142. i := t.Indexes[0]
  143. return fmt.Sprintf("DROP INDEX `%s` ON %s;", i.Name, p.escapeIdentifier(t.Name)), nil
  144. }
  145. func (p *MySQL) RenameIndex(t fizz.Table) (string, error) {
  146. schema := p.Schema.(*mysqlSchema)
  147. version, err := schema.Version()
  148. if err != nil {
  149. return "", err
  150. }
  151. if version.LessThan(mysql57Version) {
  152. return "", fmt.Errorf("renaming indexes on MySQL versions less than 5.7 is not supported by fizz; use raw SQL instead")
  153. }
  154. ix := t.Indexes
  155. if len(ix) < 2 {
  156. return "", fmt.Errorf("not enough indexes supplied")
  157. }
  158. oi := ix[0]
  159. ni := ix[1]
  160. return fmt.Sprintf("ALTER TABLE %s RENAME INDEX `%s` TO `%s`;", p.escapeIdentifier(t.Name), oi.Name, ni.Name), nil
  161. }
  162. func (p *MySQL) AddForeignKey(t fizz.Table) (string, error) {
  163. if len(t.ForeignKeys) == 0 {
  164. return "", fmt.Errorf("not enough foreign keys supplied")
  165. }
  166. return p.buildForeignKey(t, t.ForeignKeys[0], false), nil
  167. }
  168. func (p *MySQL) DropForeignKey(t fizz.Table) (string, error) {
  169. if len(t.ForeignKeys) == 0 {
  170. return "", fmt.Errorf("not enough foreign keys supplied")
  171. }
  172. fk := t.ForeignKeys[0]
  173. var ifExists string
  174. if v, ok := fk.Options["if_exists"]; ok && v.(bool) {
  175. ifExists = "IF EXISTS "
  176. }
  177. s := fmt.Sprintf("ALTER TABLE %s DROP FOREIGN KEY %s`%s`;", p.escapeIdentifier(t.Name), ifExists, fk.Name)
  178. return s, nil
  179. }
  180. func (p *MySQL) buildColumn(c fizz.Column) string {
  181. s := fmt.Sprintf("`%s` %s", c.Name, p.colType(c))
  182. if c.Options["null"] == nil || c.Primary {
  183. s = fmt.Sprintf("%s NOT NULL", s)
  184. }
  185. if c.Options["default"] != nil {
  186. d := fmt.Sprintf("%#v", c.Options["default"])
  187. re := regexp.MustCompile("^(\")(.+)(\")$")
  188. d = re.ReplaceAllString(d, "'$2'")
  189. s = fmt.Sprintf("%s DEFAULT %s", s, d)
  190. }
  191. if c.Options["default_raw"] != nil {
  192. d := fmt.Sprintf("%s", c.Options["default_raw"])
  193. s = fmt.Sprintf("%s DEFAULT %s", s, d)
  194. }
  195. if c.Primary && (c.ColType == "integer" || strings.ToLower(c.ColType) == "int") {
  196. s = fmt.Sprintf("%s AUTO_INCREMENT", s)
  197. }
  198. return s
  199. }
  200. func (p *MySQL) colType(c fizz.Column) string {
  201. switch strings.ToLower(c.ColType) {
  202. case "string":
  203. s := fmt.Sprintf("%d", p.strDefaultSize)
  204. if c.Options["size"] != nil {
  205. s = fmt.Sprintf("%d", c.Options["size"])
  206. }
  207. return fmt.Sprintf("VARCHAR (%s)", s)
  208. case "uuid":
  209. return "char(36)"
  210. case "timestamp", "time", "datetime":
  211. return "DATETIME"
  212. case "blob", "[]byte":
  213. return "BLOB"
  214. case "int", "integer":
  215. return "INTEGER"
  216. case "float", "double", "decimal", "numeric":
  217. colType := strings.ToUpper(c.ColType)
  218. if c.Options["precision"] != nil {
  219. precision := c.Options["precision"]
  220. if c.Options["scale"] != nil {
  221. scale := c.Options["scale"]
  222. return fmt.Sprintf("%s(%d,%d)", colType, precision, scale)
  223. }
  224. return fmt.Sprintf("%s(%d)", colType, precision)
  225. }
  226. return colType
  227. case "json":
  228. return "JSON"
  229. default:
  230. return c.ColType
  231. }
  232. }
  233. func (p *MySQL) buildForeignKey(t fizz.Table, fk fizz.ForeignKey, onCreate bool) string {
  234. rcols := []string{}
  235. for _, c := range fk.References.Columns {
  236. rcols = append(rcols, fmt.Sprintf("`%s`", c))
  237. }
  238. refs := fmt.Sprintf("%s (%s)", p.escapeIdentifier(fk.References.Table), strings.Join(rcols, ", "))
  239. s := fmt.Sprintf("FOREIGN KEY (`%s`) REFERENCES %s", fk.Column, refs)
  240. if onUpdate, ok := fk.Options["on_update"]; ok {
  241. s += fmt.Sprintf(" ON UPDATE %s", onUpdate)
  242. }
  243. if onDelete, ok := fk.Options["on_delete"]; ok {
  244. s += fmt.Sprintf(" ON DELETE %s", onDelete)
  245. }
  246. if !onCreate {
  247. s = fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT `%s` %s;", p.escapeIdentifier(t.Name), fk.Name, s)
  248. }
  249. return s
  250. }
  251. func (MySQL) escapeIdentifier(s string) string {
  252. if !strings.ContainsRune(s, '.') {
  253. return fmt.Sprintf("`%s`", s)
  254. }
  255. parts := strings.Split(s, ".")
  256. for _, p := range parts {
  257. p = fmt.Sprintf("`%s`", p)
  258. }
  259. return strings.Join(parts, ".")
  260. }