PageRenderTime 156ms CodeModel.GetById 36ms RepoModel.GetById 0ms app.codeStats 1ms

/loader/postgres.go

https://github.com/xo/xo
Go | 218 lines | 184 code | 11 blank | 23 comment | 36 complexity | eadeb90f1544f9830a0d36191daec398 MD5 | raw file
  1. package loader
  2. import (
  3. "context"
  4. "fmt"
  5. "regexp"
  6. "strconv"
  7. "strings"
  8. "github.com/xo/xo/models"
  9. xo "github.com/xo/xo/types"
  10. )
  11. func init() {
  12. Register("postgres", Loader{
  13. Mask: "$%d",
  14. Flags: PostgresFlags,
  15. Schema: models.PostgresSchema,
  16. Enums: models.PostgresEnums,
  17. EnumValues: models.PostgresEnumValues,
  18. Procs: models.PostgresProcs,
  19. ProcParams: models.PostgresProcParams,
  20. Tables: models.PostgresTables,
  21. TableColumns: PostgresTableColumns,
  22. TableSequences: models.PostgresTableSequences,
  23. TableForeignKeys: models.PostgresTableForeignKeys,
  24. TableIndexes: models.PostgresTableIndexes,
  25. IndexColumns: PostgresIndexColumns,
  26. ViewCreate: models.PostgresViewCreate,
  27. ViewSchema: models.PostgresViewSchema,
  28. ViewDrop: models.PostgresViewDrop,
  29. ViewStrip: PostgresViewStrip,
  30. })
  31. }
  32. // PostgresFlags returnss the postgres flags.
  33. func PostgresFlags() []xo.Flag {
  34. return []xo.Flag{
  35. {
  36. ContextKey: OidsKey,
  37. Type: "bool",
  38. Desc: "enable postgres OIDs",
  39. Default: "false",
  40. },
  41. }
  42. }
  43. // PostgresGoType parse a type into a Go type based on the database type
  44. // definition.
  45. func PostgresGoType(d xo.Type, schema, itype, utype string) (string, string, error) {
  46. // SETOF -> []T
  47. if strings.HasPrefix(d.Type, "SETOF ") {
  48. d.Type = d.Type[len("SETOF "):]
  49. goType, _, err := PostgresGoType(d, schema, itype, utype)
  50. if err != nil {
  51. return "", "", err
  52. }
  53. return "[]" + goType, "nil", nil
  54. }
  55. // special type handling
  56. typ := d.Type
  57. switch {
  58. case typ == `"char"`:
  59. typ = "char"
  60. case strings.HasPrefix(typ, "information_schema."):
  61. switch strings.TrimPrefix(typ, "information_schema.") {
  62. case "cardinal_number":
  63. typ = "integer"
  64. case "character_data", "sql_identifier", "yes_or_no":
  65. typ = "character varying"
  66. case "time_stamp":
  67. typ = "timestamp with time zone"
  68. }
  69. }
  70. var goType, zero string
  71. switch typ {
  72. case "boolean":
  73. goType, zero = "bool", "false"
  74. if d.Nullable {
  75. goType, zero = "sql.NullBool", "sql.NullBool{}"
  76. }
  77. case "bpchar", "character varying", "character", "inet", "money", "text", "name":
  78. goType, zero = "string", `""`
  79. if d.Nullable {
  80. goType, zero = "sql.NullString", "sql.NullString{}"
  81. }
  82. case "smallint":
  83. goType, zero = "int16", "0"
  84. if d.Nullable {
  85. goType, zero = "sql.NullInt64", "sql.NullInt64{}"
  86. }
  87. case "integer":
  88. goType, zero = itype, "0"
  89. if d.Nullable {
  90. goType, zero = "sql.NullInt64", "sql.NullInt64{}"
  91. }
  92. case "bigint":
  93. goType, zero = "int64", "0"
  94. if d.Nullable {
  95. goType, zero = "sql.NullInt64", "sql.NullInt64{}"
  96. }
  97. case "real":
  98. goType, zero = "float32", "0.0"
  99. if d.Nullable {
  100. goType, zero = "sql.NullFloat64", "sql.NullFloat64{}"
  101. }
  102. case "double precision", "numeric":
  103. goType, zero = "float64", "0.0"
  104. if d.Nullable {
  105. goType, zero = "sql.NullFloat64", "sql.NullFloat64{}"
  106. }
  107. case "date", "timestamp with time zone", "time with time zone", "time without time zone", "timestamp without time zone":
  108. goType, zero = "time.Time", "time.Time{}"
  109. if d.Nullable {
  110. goType, zero = "sql.NullTime", "sql.NullTime{}"
  111. }
  112. case "bit":
  113. goType, zero = "uint8", "0"
  114. if d.Nullable {
  115. goType, zero = "*uint8", "nil"
  116. }
  117. case "any", "bit varying", "bytea", "interval", "json", "jsonb", "xml":
  118. // TODO: write custom type for interval marshaling
  119. // TODO: marshalling for json types
  120. goType, zero = "[]byte", "nil"
  121. case "hstore":
  122. goType, zero = "hstore.Hstore", "nil"
  123. case "uuid":
  124. goType, zero = "uuid.UUID", "uuid.UUID{}"
  125. if d.Nullable {
  126. goType, zero = "uuid.NullUUID", "uuid.NullUUID{}"
  127. }
  128. default:
  129. goType, zero = schemaType(d.Type, d.Nullable, schema)
  130. }
  131. // handle slices
  132. switch {
  133. case d.IsArray && goType == "string":
  134. return "StringSlice", "StringSlice{}", nil
  135. case d.IsArray:
  136. return "[]" + goType, "nil", nil
  137. }
  138. return goType, zero, nil
  139. }
  140. // PostgresTableColumns returns the columns for a table.
  141. func PostgresTableColumns(ctx context.Context, db models.DB, schema string, table string) ([]*models.Column, error) {
  142. return models.PostgresTableColumns(ctx, db, schema, table, EnableOids(ctx))
  143. }
  144. // PostgresIndexColumns returns the column list for an index.
  145. //
  146. // FIXME: rewrite this using SQL exclusively using OVER
  147. func PostgresIndexColumns(ctx context.Context, db models.DB, schema string, table string, index string) ([]*models.IndexColumn, error) {
  148. // load columns
  149. cols, err := models.PostgresIndexColumns(ctx, db, schema, index)
  150. if err != nil {
  151. return nil, err
  152. }
  153. // load col order
  154. colOrd, err := models.PostgresGetColOrder(ctx, db, schema, index)
  155. if err != nil {
  156. return nil, err
  157. }
  158. // build schema name used in errors
  159. s := schema
  160. if s != "" {
  161. s += "."
  162. }
  163. // put cols in order using colOrder
  164. var ret []*models.IndexColumn
  165. for _, v := range strings.Split(colOrd.Ord, " ") {
  166. cid, err := strconv.Atoi(v)
  167. if err != nil {
  168. return nil, fmt.Errorf("could not convert %s%s index %s column %s to int", s, table, index, v)
  169. }
  170. // find column
  171. found := false
  172. var c *models.IndexColumn
  173. for _, ic := range cols {
  174. if cid == ic.Cid {
  175. found, c = true, ic
  176. break
  177. }
  178. }
  179. // sanity check
  180. if !found {
  181. return nil, fmt.Errorf("could not find %s%s index %s column id %d", s, table, index, cid)
  182. }
  183. ret = append(ret, c)
  184. }
  185. return ret, nil
  186. }
  187. // PostgresViewStrip strips '::type AS name' in queries.
  188. func PostgresViewStrip(query, inspect []string) ([]string, []string, []string, error) {
  189. comments := make([]string, len(query))
  190. for i, line := range query {
  191. if pos := stripRE.FindStringIndex(line); pos != nil {
  192. query[i] = line[:pos[0]] + line[pos[1]:]
  193. comments[i] = line[pos[0]:pos[1]]
  194. }
  195. }
  196. return query, inspect, comments, nil
  197. }
  // stripRE matches the '::type AS name' portion of a query line. Such
// fragments are a quirk/requirement of generating queries for postgres.
//
// The match is case-insensitive. The type and the name each start with a
// letter, followed by zero or more letters, digits, '_' or '.'. Single-letter
// names such as `count(*)::integer AS n` therefore match; the previous `+`
// quantifier silently required at least two characters.
var stripRE = regexp.MustCompile(`(?i)::[a-z][a-z0-9_\.]*\s+AS\s+[a-z][a-z0-9_\.]*`)
  201. // OidsKey is the oids context key.
  202. const OidsKey xo.ContextKey = "oids"
  203. // EnableOids returns the EnableOids value from the context.
  204. func EnableOids(ctx context.Context) bool {
  205. b, _ := ctx.Value(OidsKey).(bool)
  206. return b
  207. }