/src/main/scala/ru/yandex/mysqlDiff/script/ScriptSerializer.scala

https://bitbucket.org/stepancheg/mysql-diff/ · Scala · 401 lines · 320 code · 70 blank · 11 comment · 17 complexity · d39993d5fceaffa565c622612c4d9201 MD5 · raw file

  1. package ru.yandex.mysqlDiff
  2. package script
  3. import scala.collection.mutable.ArrayBuffer
  4. import scala.util.Sorting._
  5. import model._
  6. // XXX: drop MySQL
  7. import vendor.mysql._
  8. import Implicits._
  9. object ScriptSerializer {
  10. case class Options(multiline: Boolean = false) {
  11. def stmtJoin =
  12. if (multiline) "\n"
  13. else " "
  14. }
  15. object Options {
  16. val singleline = new Options()
  17. val multiline = singleline.copy(true)
  18. val default = singleline
  19. }
  20. }
  21. /**
  22. * Serialize Script objects to text (String).
  23. */
  24. class ScriptSerializer(context: Context) {
  25. import context._
  26. import TableDdlStatement._
  27. import ScriptSerializer._
  28. /** Serializer options */
  29. def serialize(stmts: Seq[ScriptElement], options: Options): String = {
  30. def serializeInList(stmt: ScriptElement) = {
  31. val tail = stmt match {
  32. case _: ScriptStatement => ";"
  33. case _ => ""
  34. }
  35. serialize(stmt, options) + tail
  36. }
  37. val stmtTail =
  38. if (options.multiline)
  39. "\n"
  40. else
  41. ""
  42. stmts.map(serializeInList _).mkString(options.stmtJoin) + stmtTail
  43. }
  44. def serialize(stmt: ScriptElement, options: Options): String = stmt match {
  45. case s: ScriptStatement => serializeStatement(s, options)
  46. case Unparsed(q) => q
  47. case CommentElement(c) => c
  48. }
  49. def serialize(stmt: ScriptElement): String = serialize(stmt, Options.singleline)
  50. def isKeyword(name: String) =
  51. ScriptConstants.isSql2003Keyword(name)
  52. def quoteName(name: String) = '"' + name + '"'
  53. def serializeName(name: String) =
  54. if (isKeyword(name)) quoteName(name)
  55. else name
  56. def serializeString(string: String) =
  57. "'" + string.replace("'", "''") + "'"
  58. def serializeTableDdlStatement(stmt: TableDdlStatement, options: Options): String = stmt match {
  59. case st: CreateTableStatement => serializeCreateTable(st, options)
  60. case dt: DropTableStatement => serializeDropTable(dt)
  61. case st: AlterTableStatement => serializeChangeTable(st)
  62. }
  63. def serializeSequenceDdlStatement(stmt: SequenceDdlStatement, options: Options): String = stmt match {
  64. case CreateSequenceStatement(name: String) => "CREATE SEQUENCE " + serializeName(name)
  65. case DropSequenceStatement(name: String) => "DROP SEQUENCE " + serializeName(name)
  66. }
  67. def serializeCreateIndexStatement(stmt: CreateIndexStatement) = {
  68. val CreateIndexStatement(name, tableName, columns) = stmt
  69. "CREATE INDEX " + serializeName(name) +
  70. " ON " + serializeName(tableName) + " (" + columns.map(serializeIndexColumn _).mkString(", ") + ")"
  71. }
  72. def serializeDropIndexStatement(stmt: DropIndexStatement) =
  73. "DROP INDEX " + serializeName(stmt.name)
  74. def serializeIndexDdlStatement(stmt: IndexDdlStatement, options: Options): String = stmt match {
  75. case stmt: CreateIndexStatement => serializeCreateIndexStatement(stmt)
  76. case stmt: DropIndexStatement => serializeDropIndexStatement(stmt)
  77. }
  78. def serializeDdlStatement(stmt: DdlStatement, options: Options): String = stmt match {
  79. case stmt: TableDdlStatement => serializeTableDdlStatement(stmt, options)
  80. case stmt: SequenceDdlStatement => serializeSequenceDdlStatement(stmt, options)
  81. case stmt: IndexDdlStatement => serializeIndexDdlStatement(stmt, options)
  82. }
  83. def serializeStatement(stmt: ScriptStatement, options: Options): String = stmt match {
  84. case ts: DdlStatement => serializeDdlStatement(ts, options)
  85. case is: InsertStatement => serializeInsert(is)
  86. }
  87. def serializeValue(value: SqlValue): String = value match {
  88. case NullValue => "NULL"
  89. case NumberValue(number) => number.toString
  90. case StringValue(string) => serializeString(string)
  91. case BooleanValue(true) => "TRUE"
  92. case BooleanValue(false) => "FALSE"
  93. case NowValue => "NOW()"
  94. case t: TemporalValue =>
  95. val n = t match {
  96. case _: TimestampValue => "TIMESTAMP"
  97. case _: TimestampWithTimeZoneValue => "TIMESTAMP WITHOUT TIME ZONE"
  98. case _: TimeValue => "TIME"
  99. case _: TimeWithTimeZoneValue => "TIME WITHOUT TIME ZONE"
  100. case _: DateValue => "DATE"
  101. }
  102. n + " '" + t.value + "'"
  103. }
  104. def serializeCast(cast: CastExpr) = {
  105. val CastExpr(e, as) = cast
  106. "CAST(" + serializeExpr(e) + " AS " + serializeDataType(as) + ")"
  107. }
  108. def serializeFunctionCall(fc: FunctionCallExpr) =
  109. fc.name + "(" + fc.params.map(serializeExpr _).mkString(", ") + ")"
  110. def serializeExpr(expr: SqlExpr): String = expr match {
  111. case v: SqlValue => serializeValue(v)
  112. case s: CastExpr => serializeCast(s)
  113. case f: FunctionCallExpr => serializeFunctionCall(f)
  114. }
  115. // XXX: drop mysql
  116. def serializeModelColumnProperty(cp: ColumnProperty): Option[String] = cp match {
  117. case vendor.mysql.MysqlAutoIncrement(true) => Some("AUTO_INCREMENT")
  118. case vendor.mysql.MysqlAutoIncrement(false) => None
  119. case Nullability(true) => Some("NULL")
  120. case Nullability(false) => Some("NOT NULL")
  121. case DefaultValue(value) => Some("DEFAULT " + serializeExpr(value))
  122. }
  123. def serializeImportedKeyRule(p: ImportedKeyRule) = p match {
  124. case ImportedKeyNoAction => "NO ACTION"
  125. case ImportedKeyCascade => "CASCADE"
  126. case ImportedKeySetNull => "SET NULL"
  127. case ImportedKeySetDefault => "SET DEFAULT"
  128. }
  129. def serializeColumnProperty(cp: ColumnPropertyDecl): Option[String] = cp match {
  130. case ModelColumnProperty(cp) => serializeModelColumnProperty(cp)
  131. case InlineUnique => Some("UNIQUE")
  132. case InlinePrimaryKey => Some("PRIMARY KEY")
  133. case InlineReferences(References(table, Seq(column), updateRule, deleteRule)) =>
  134. val words = new ArrayBuffer[String]
  135. words += "REFERENCES " + serializeName(table) + "(" + serializeName(column) + ")"
  136. words ++= updateRule.map(p => "ON UPDATE " + serializeImportedKeyRule(p))
  137. words ++= deleteRule.map(p => "ON DELETE " + serializeImportedKeyRule(p))
  138. Some(words.mkString(" "))
  139. }
  140. def serializeColumnProperty(cp: ColumnPropertyDecl, c: Column): Option[String] =
  141. cp match {
  142. // MySQL does not support NOT NULL DEFAULT NULL
  143. case ModelColumnProperty(DefaultValue(NullValue)) if c.isNotNull => None
  144. case _ => serializeColumnProperty(cp)
  145. }
  146. def serializeConstraint(c: Constraint) =
  147. c match {
  148. case PrimaryKey(pk) => serializePrimaryKey(pk)
  149. case ForeignKey(fk) => serializeForeignKey(fk)
  150. case UniqueKey(u) => serializeUniqueKey(u)
  151. }
  152. def serializeTableElement(e: TableElement): String = e match {
  153. case c @ Column(name, dataType, attrs) =>
  154. serializeName(name) + " " + serializeDataType(dataType) +
  155. (if (attrs.isEmpty) ""
  156. else " " + attrs.flatMap(cp => serializeColumnProperty(cp, c)).mkString(" "))
  157. case Index(index) => serializeIndex(index)
  158. case LikeClause(name) => "LIKE " + serializeName(name)
  159. case c: Constraint => serializeConstraint(c)
  160. }
  161. def serializeCreateTable(ct: CreateTableStatement, options: Options): String = {
  162. val CreateTableStatement(name, ifNotExists, TableElementList(elements), tableOptions) = ct
  163. def mapTableElement(e: TableElement) =
  164. serializeTableElement(e)
  165. val l = elements.map(mapTableElement _).reverse
  166. val (afterComma, indent) =
  167. if (options.multiline) ("\n", " ")
  168. else (" ", "")
  169. val lines = (List(l.head) ++ l.drop(1).map(_ + "," + afterComma)).reverse.map(indent + _)
  170. val firstLine = "CREATE TABLE " + serializeName(name) + " ("
  171. val lastLine = ")" +
  172. (tableOptions.flatMap(serializeCreateTableTableOption _) match {
  173. case Seq() => ""
  174. case l => " " + l.mkString(" ")
  175. })
  176. (List(firstLine) ++ lines ++ List(lastLine)).mkString(if (options.multiline) "\n" else "")
  177. }
  178. protected def serializeCreateTableTableOption(opt: TableOption): Option[String] =
  179. Some(serializeTableOption(opt))
  180. protected def serializeTableOption(opt: TableOption): String =
  181. throw new MysqlDiffException("unknown table option: "+ opt)
  182. def serializeInsert(is: InsertStatement) = {
  183. val r = new ArrayBuffer[String]
  184. r += "INSERT"
  185. if (is.ignore) r += "IGNORE"
  186. r += "INTO"
  187. r += serializeName(is.table)
  188. if (is.columns.isDefined)
  189. r += ("(" + is.columns.get.mkString(", ") + ")")
  190. r += "VALUES"
  191. r += is.data.map(row => "(" + row.map(serializeExpr _).mkString(", ") + ")").mkString(", ")
  192. r.mkString(" ")
  193. // XXX: untested
  194. }
  195. def serializeDropTable(dt: DropTableStatement) = {
  196. val DropTableStatement(name, ifExists) = dt
  197. val words = new ArrayBuffer[String]
  198. words += "DROP TABLE"
  199. if (ifExists) words += "IF EXISTS"
  200. words += serializeName(name)
  201. words.mkString(" ")
  202. }
  203. def serializeChangeTable(st: AlterTableStatement) =
  204. "ALTER TABLE " + serializeName(st.name) + " " +
  205. st.ops.map(serializeAlterTableOperation(_)).mkString(", ")
  206. def serializeColumnPosition(position: ColumnPosition) = position match {
  207. case ColumnFirst => "FIRST"
  208. case ColumnAfter(column) => "AFTER " + serializeName(column)
  209. }
  210. def serializeAddColumn(ac: AddColumn) = {
  211. val AddColumn(column, pos) = ac
  212. val words = new ArrayBuffer[String]
  213. words += "ADD COLUMN"
  214. words += serializeTableElement(column)
  215. words ++= pos.map(serializeColumnPosition _)
  216. words.mkString(" ")
  217. }
  218. def serializeChangeColumn(cc: ChangeColumn) = {
  219. val ChangeColumn(oldName, column, pos) = cc
  220. val words = new ArrayBuffer[String]
  221. words += "CHANGE COLUMN"
  222. words += serializeName(oldName)
  223. words += serializeColumn(column)
  224. words ++= pos.map(serializeColumnPosition _)
  225. words.mkString(" ")
  226. }
  227. def serializeModifyColumn(mc: ModifyColumn) = {
  228. val ModifyColumn(column, pos) = mc
  229. val words = new ArrayBuffer[String]
  230. words += "MODIFY COLUMN"
  231. words += serializeColumn(column)
  232. words ++= pos.map(serializeColumnPosition _)
  233. words.mkString(" ")
  234. }
  235. def serializeAlterColumnOperation(op: AlterColumnOperation) = op match {
  236. case SetNotNull(true) => "SET NOT NULL"
  237. case SetNotNull(false) => "DROP NOT NULL"
  238. case SetDefault(Some(value)) => "SET DEFAULT " + serializeExpr(value)
  239. case SetDefault(None) => "DROP DEFAULT"
  240. }
  241. def serializeAlterTableOperation(op: Operation) = op match {
  242. case ac: AddColumn => serializeAddColumn(ac)
  243. case AddExtra(e) => "ADD " + serializeTableElement(e)
  244. case cc: ChangeColumn => serializeChangeColumn(cc)
  245. case mc: ModifyColumn => serializeModifyColumn(mc)
  246. case DropColumn(name) => "DROP COLUMN " + serializeName(name)
  247. case AlterColumn(name, op) => "ALTER COLUMN " + serializeName(name) + " " + serializeAlterColumnOperation(op)
  248. case DropConstraint(name) => "DROP CONSTRAINT " + serializeName(name)
  249. case DropPrimaryKey => "DROP PRIMARY KEY"
  250. case DropIndex(name) => "DROP INDEX " + serializeName(name)
  251. case DropForeignKey(name) => "DROP FOREIGN KEY " + serializeName(name)
  252. case DropUniqueKey(name) => "DROP KEY " + serializeName(name)
  253. case ChangeTableOption(o) => serializeTableOption(o)
  254. }
  255. def serializeDefaultDataType(dataType: DefaultDataType) =
  256. dataType.name + dataType.length.map("(" + _ + ")").getOrElse("")
  257. def serializeDataType(dataType: DataType) = (dataType: @unchecked) match {
  258. case dataType: DefaultDataType => serializeDefaultDataType(dataType)
  259. case NumericDataType(None, None) => "NUMERIC"
  260. case NumericDataType(Some(precision), None) => "NUMERIC(" + precision + ")"
  261. case NumericDataType(Some(precision), Some(scale)) => "NUMERIC(" + precision + ", " + scale + ")"
  262. }
  263. def serializeColumn(model: ColumnModel) =
  264. serializeTableElement(modelSerializer.serializeColumn(model))
  265. def serializeIndexColumn(ik: IndexColumn) = {
  266. val words = new ArrayBuffer[String]
  267. words += serializeName(ik.name) + ik.length.map("(" + _ + ")").getOrElse("")
  268. if (!ik.asc) words += "DESC"
  269. words.mkString(" ")
  270. }
  271. def serializePrimaryKey(pk: PrimaryKeyModel) = {
  272. val words = new ArrayBuffer[String]
  273. if (pk.name.isDefined) words += "CONSTRAINT " + serializeName(pk.name.get)
  274. words += "PRIMARY KEY"
  275. words += ("(" + pk.columns.map(serializeIndexColumn _).mkString(", ") + ")")
  276. words.mkString(" ")
  277. }
  278. // XXX: handle MySQL specific stuff
  279. def serializeForeignKey(fk: ForeignKeyModel) = {
  280. val ForeignKeyModel(name, localColumns, externalTable, externalColumns, updateRule, deleteRule) = fk
  281. val words = new ArrayBuffer[String]
  282. if (name.isDefined) words += "CONSTRAINT " + serializeName(fk.name.get)
  283. words += "FOREIGN KEY"
  284. words += ("(" + localColumns.map(serializeIndexColumn _).mkString(", ") + ")")
  285. words += "REFERENCES"
  286. words += serializeName(externalTable)
  287. words += ("(" + externalColumns.mkString(", ") + ")")
  288. words ++= updateRule.map(p => "ON UPDATE " + serializeImportedKeyRule(p))
  289. words ++= deleteRule.map(p => "ON DELETE " + serializeImportedKeyRule(p))
  290. words.mkString(" ")
  291. }
  292. def serializeUniqueKey(uk: UniqueKeyModel) = {
  293. val words = new ArrayBuffer[String]
  294. if (uk.name.isDefined) words += "CONSTRAINT " + serializeName(uk.name.get)
  295. words += "UNIQUE (" + uk.columns.map(serializeIndexColumn _).mkString(", ") + ")"
  296. words.mkString(" ")
  297. }
  298. def serializeIndex(index: IndexModel) = {
  299. val words = new ArrayBuffer[String]
  300. words += "INDEX"
  301. words ++= index.name.map(serializeName _)
  302. words += ("(" + index.columns.map(serializeIndexColumn _).mkString(", ") + ")")
  303. words.mkString(" ")
  304. }
  305. }
  306. object ScriptSerializerTests extends MySpecification {
  307. val context = Environment.defaultContext
  308. import context._
  309. import ScriptSerializer._
  310. "serialize semi singleline" in {
  311. val dt = DropTableStatement("users", false)
  312. val c = CommentElement("/* h */")
  313. val script = List(dt, c, dt, dt, c, c, dt)
  314. val options = Options.singleline
  315. val serialized = scriptSerializer.serialize(script, options)
  316. //println("'" + serialized + "'")
  317. serialized must_== "DROP TABLE users; /* h */ DROP TABLE users; DROP TABLE users; /* h */ /* h */ DROP TABLE users;"
  318. }
  319. "serialize default value" in {
  320. scriptSerializer.serializeModelColumnProperty(DefaultValue(NumberValue(15))).get must_== "DEFAULT 15"
  321. }
  322. "serialize value" in {
  323. scriptSerializer.serializeValue(NullValue) must_== "NULL"
  324. scriptSerializer.serializeValue(NumberValue(15)) must_== "15"
  325. scriptSerializer.serializeValue(StringValue("hello")) must_== "'hello'"
  326. scriptSerializer.serializeValue(StringValue("'hello world'")) must_== "'''hello world'''"
  327. scriptSerializer.serializeValue(NowValue) must_== "NOW()" // XXX: or CURRENT_TIMESTAMP
  328. scriptSerializer.serializeValue(DateValue("2009-03-09")) must_== "DATE '2009-03-09'"
  329. }
  330. }
  331. // vim: set ts=4 sw=4 et: