PageRenderTime 149ms CodeModel.GetById 30ms RepoModel.GetById 0ms app.codeStats 1ms

/templates/go/go.go

https://github.com/xo/xo
Go | 1961 lines | 1675 code | 87 blank | 199 comment | 309 complexity | 7cd217d447e5d84997b89f4e176c0fe3 MD5 | raw file
  1. //go:build xotpl
  2. package gotpl
  3. import (
  4. "bytes"
  5. "context"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "io/fs"
  10. "io/ioutil"
  11. "os"
  12. "path/filepath"
  13. "regexp"
  14. "strconv"
  15. "strings"
  16. "text/template"
  17. "github.com/kenshaw/inflector"
  18. "github.com/kenshaw/snaker"
  19. "github.com/xo/xo/loader"
  20. xo "github.com/xo/xo/types"
  21. "golang.org/x/tools/imports"
  22. "mvdan.cc/gofumpt/format"
  23. )
  24. var (
  25. ErrNoSingle = errors.New("in query exec mode, the --single or -S must be provided")
  26. )
  27. // Init registers the template.
  28. func Init(ctx context.Context, f func(xo.TemplateType)) error {
  29. knownTypes := map[string]bool{
  30. "bool": true,
  31. "string": true,
  32. "byte": true,
  33. "rune": true,
  34. "int": true,
  35. "int16": true,
  36. "int32": true,
  37. "int64": true,
  38. "uint": true,
  39. "uint8": true,
  40. "uint16": true,
  41. "uint32": true,
  42. "uint64": true,
  43. "float32": true,
  44. "float64": true,
  45. "Slice": true,
  46. "StringSlice": true,
  47. }
  48. shorts := map[string]string{
  49. "bool": "b",
  50. "string": "s",
  51. "byte": "b",
  52. "rune": "r",
  53. "int": "i",
  54. "int16": "i",
  55. "int32": "i",
  56. "int64": "i",
  57. "uint": "u",
  58. "uint8": "u",
  59. "uint16": "u",
  60. "uint32": "u",
  61. "uint64": "u",
  62. "float32": "f",
  63. "float64": "f",
  64. "Slice": "s",
  65. "StringSlice": "ss",
  66. }
  67. f(xo.TemplateType{
  68. Modes: []string{"query", "schema"},
  69. Flags: []xo.Flag{
  70. {
  71. ContextKey: AppendKey,
  72. Type: "bool",
  73. Desc: "enable append mode",
  74. Short: "a",
  75. Aliases: []string{"append"},
  76. },
  77. {
  78. ContextKey: NotFirstKey,
  79. Type: "bool",
  80. Desc: "disable package file (ie. not first generated file)",
  81. Short: "2",
  82. Default: "false",
  83. },
  84. {
  85. ContextKey: Int32Key,
  86. Type: "string",
  87. Desc: "int32 type",
  88. Default: "int",
  89. },
  90. {
  91. ContextKey: Uint32Key,
  92. Type: "string",
  93. Desc: "uint32 type",
  94. Default: "uint",
  95. },
  96. {
  97. ContextKey: PkgKey,
  98. Type: "string",
  99. Desc: "package name",
  100. },
  101. {
  102. ContextKey: TagKey,
  103. Type: "[]string",
  104. Desc: "build tags",
  105. },
  106. {
  107. ContextKey: ImportKey,
  108. Type: "[]string",
  109. Desc: "package imports",
  110. },
  111. {
  112. ContextKey: UUIDKey,
  113. Type: "string",
  114. Desc: "uuid type package",
  115. Default: "github.com/google/uuid",
  116. },
  117. {
  118. ContextKey: CustomKey,
  119. Type: "string",
  120. Desc: "package name for custom types",
  121. },
  122. {
  123. ContextKey: ConflictKey,
  124. Type: "string",
  125. Desc: "name conflict suffix",
  126. Default: "Val",
  127. },
  128. {
  129. ContextKey: InitialismKey,
  130. Type: "[]string",
  131. Desc: "add initialism (e.g. ID, API, URI, ...)",
  132. },
  133. {
  134. ContextKey: EscKey,
  135. Type: "[]string",
  136. Desc: "escape fields",
  137. Default: "none",
  138. Enums: []string{"none", "schema", "table", "column", "all"},
  139. },
  140. {
  141. ContextKey: FieldTagKey,
  142. Type: "string",
  143. Desc: "field tag",
  144. Short: "g",
  145. Default: `json:"{{ .SQLName }}"`,
  146. },
  147. {
  148. ContextKey: ContextKey,
  149. Type: "string",
  150. Desc: "context mode",
  151. Default: "only",
  152. Enums: []string{"disable", "both", "only"},
  153. },
  154. {
  155. ContextKey: InjectKey,
  156. Type: "string",
  157. Desc: "insert code into generated file headers",
  158. Default: "",
  159. },
  160. {
  161. ContextKey: InjectFileKey,
  162. Type: "string",
  163. Desc: "insert code into generated file headers from a file",
  164. Default: "",
  165. },
  166. {
  167. ContextKey: LegacyKey,
  168. Type: "bool",
  169. Desc: "enables legacy v1 template funcs",
  170. Default: "false",
  171. },
  172. },
  173. Funcs: func(ctx context.Context, _ string) (template.FuncMap, error) {
  174. funcs, err := NewFuncs(ctx)
  175. if err != nil {
  176. return nil, err
  177. }
  178. if Legacy(ctx) {
  179. addLegacyFuncs(ctx, funcs)
  180. }
  181. return funcs, nil
  182. },
  183. NewContext: func(ctx context.Context, _ string) context.Context {
  184. ctx = context.WithValue(ctx, KnownTypesKey, knownTypes)
  185. ctx = context.WithValue(ctx, ShortsKey, shorts)
  186. return ctx
  187. },
  188. Order: func(ctx context.Context, mode string) []string {
  189. base := []string{"header", "db"}
  190. switch mode {
  191. case "query":
  192. return append(base, "typedef", "query")
  193. case "schema":
  194. return append(base, "enum", "proc", "typedef", "query", "index", "foreignkey")
  195. }
  196. return nil
  197. },
  198. Pre: func(ctx context.Context, mode string, set *xo.Set, out fs.FS, emit func(xo.Template)) error {
  199. if err := addInitialisms(ctx); err != nil {
  200. return err
  201. }
  202. files, err := fileNames(ctx, mode, set)
  203. if err != nil {
  204. return err
  205. }
  206. // If -2 is provided, skip package template outputs as requested.
  207. // If -a is provided, skip to avoid duplicating the template.
  208. if !NotFirst(ctx) && !Append(ctx) {
  209. emit(xo.Template{
  210. Partial: "db",
  211. Dest: "db.xo.go",
  212. })
  213. // If --single is provided, don't generate header for db.xo.go.
  214. if xo.Single(ctx) == "" {
  215. files["db.xo.go"] = true
  216. }
  217. }
  218. if Append(ctx) {
  219. for filename := range files {
  220. f, err := out.Open(filename)
  221. switch {
  222. case errors.Is(err, os.ErrNotExist):
  223. continue
  224. case err != nil:
  225. return err
  226. }
  227. defer f.Close()
  228. data, err := io.ReadAll(f)
  229. if err != nil {
  230. return err
  231. }
  232. emit(xo.Template{
  233. Src: "{{.Data}}",
  234. Partial: "header", // ordered first
  235. Data: string(data),
  236. Dest: filename,
  237. })
  238. delete(files, filename)
  239. }
  240. }
  241. for filename := range files {
  242. emit(xo.Template{
  243. Partial: "header",
  244. Dest: filename,
  245. })
  246. }
  247. return nil
  248. },
  249. Process: func(ctx context.Context, mode string, set *xo.Set, emit func(xo.Template)) error {
  250. if mode == "query" {
  251. for _, query := range set.Queries {
  252. if err := emitQuery(ctx, query, emit); err != nil {
  253. return err
  254. }
  255. }
  256. } else {
  257. for _, schema := range set.Schemas {
  258. if err := emitSchema(ctx, schema, emit); err != nil {
  259. return err
  260. }
  261. }
  262. }
  263. return nil
  264. },
  265. Post: func(ctx context.Context, mode string, files map[string][]byte, emit func(string, []byte)) error {
  266. for file, content := range files {
  267. // Run goimports.
  268. buf, err := imports.Process("", content, nil)
  269. if err != nil {
  270. return fmt.Errorf("%s:%w", file, err)
  271. }
  272. // Run gofumpt.
  273. formatted, err := format.Source(buf, format.Options{
  274. ExtraRules: true,
  275. })
  276. if err != nil {
  277. return err
  278. }
  279. emit(file, formatted)
  280. }
  281. return nil
  282. },
  283. })
  284. return nil
  285. }
  286. // fileNames returns a list of file names that will be generated by the
  287. // template based on the parameters and schema.
  288. func fileNames(ctx context.Context, mode string, set *xo.Set) (map[string]bool, error) {
  289. // In single mode, only the specified file be generated.
  290. singleFile := xo.Single(ctx)
  291. if singleFile != "" {
  292. return map[string]bool{
  293. singleFile: true,
  294. }, nil
  295. }
  296. // Otherwise, infer filenames from set.
  297. files := make(map[string]bool)
  298. addFile := func(filename string) {
  299. // Filenames are always lowercase.
  300. filename = strings.ToLower(filename)
  301. files[filename+ext] = true
  302. }
  303. switch mode {
  304. case "schema":
  305. for _, schema := range set.Schemas {
  306. for _, e := range schema.Enums {
  307. addFile(camelExport(e.Name))
  308. }
  309. for _, p := range schema.Procs {
  310. goName := camelExport(p.Name)
  311. if p.Type == "function" {
  312. addFile("sf_" + goName)
  313. } else {
  314. addFile("sp_" + goName)
  315. }
  316. }
  317. for _, t := range schema.Tables {
  318. addFile(camelExport(singularize(t.Name)))
  319. }
  320. for _, v := range schema.Views {
  321. addFile(camelExport(singularize(v.Name)))
  322. }
  323. }
  324. case "query":
  325. for _, query := range set.Queries {
  326. addFile(query.Type)
  327. if query.Exec {
  328. // Single mode is handled at the start of the function but it
  329. // must be used for Exec queries.
  330. return nil, ErrNoSingle
  331. }
  332. }
  333. default:
  334. panic("unknown mode: " + mode)
  335. }
  336. return files, nil
  337. }
  338. // emitQuery emits the query.
  339. func emitQuery(ctx context.Context, query xo.Query, emit func(xo.Template)) error {
  340. var table Table
  341. // build type if needed
  342. if !query.Exec {
  343. var err error
  344. if table, err = buildQueryType(ctx, query); err != nil {
  345. return err
  346. }
  347. }
  348. // emit type definition
  349. if !query.Exec && !query.Flat && !Append(ctx) {
  350. emit(xo.Template{
  351. Partial: "typedef",
  352. Dest: strings.ToLower(table.GoName) + ext,
  353. SortType: query.Type,
  354. SortName: query.Name,
  355. Data: table,
  356. })
  357. }
  358. // build query params
  359. var params []QueryParam
  360. for _, param := range query.Params {
  361. params = append(params, QueryParam{
  362. Name: param.Name,
  363. Type: param.Type.Type,
  364. Interpolate: param.Interpolate,
  365. Join: param.Join,
  366. })
  367. }
  368. // emit query
  369. emit(xo.Template{
  370. Partial: "query",
  371. Dest: strings.ToLower(table.GoName) + ext,
  372. SortType: query.Type,
  373. SortName: query.Name,
  374. Data: Query{
  375. Name: buildQueryName(query),
  376. Query: query.Query,
  377. Comments: query.Comments,
  378. Params: params,
  379. One: query.Exec || query.Flat || query.One,
  380. Flat: query.Flat,
  381. Exec: query.Exec,
  382. Interpolate: query.Interpolate,
  383. Type: table,
  384. Comment: query.Comment,
  385. },
  386. })
  387. return nil
  388. }
  389. func buildQueryType(ctx context.Context, query xo.Query) (Table, error) {
  390. tf := camelExport
  391. if query.Flat {
  392. tf = camel
  393. }
  394. var fields []Field
  395. for _, z := range query.Fields {
  396. f, err := convertField(ctx, tf, z)
  397. if err != nil {
  398. return Table{}, err
  399. }
  400. // dont use convertField; the types are already provided by the user
  401. if query.ManualFields {
  402. f = Field{
  403. GoName: z.Name,
  404. SQLName: snake(z.Name),
  405. Type: z.Type.Type,
  406. }
  407. }
  408. fields = append(fields, f)
  409. }
  410. sqlName := snake(query.Type)
  411. return Table{
  412. GoName: query.Type,
  413. SQLName: sqlName,
  414. Fields: fields,
  415. Comment: query.TypeComment,
  416. }, nil
  417. }
  418. // buildQueryName builds a name for the query.
  419. func buildQueryName(query xo.Query) string {
  420. if query.Name != "" {
  421. return query.Name
  422. }
  423. // generate name if not specified
  424. name := query.Type
  425. if !query.One {
  426. name = inflector.Pluralize(name)
  427. }
  428. // add params
  429. if len(query.Params) == 0 {
  430. name = "Get" + name
  431. } else {
  432. name += "By"
  433. for _, p := range query.Params {
  434. name += camelExport(p.Name)
  435. }
  436. }
  437. return name
  438. }
  439. // emitSchema emits the xo schema for the template set.
  440. func emitSchema(ctx context.Context, schema xo.Schema, emit func(xo.Template)) error {
  441. // emit enums
  442. for _, e := range schema.Enums {
  443. enum := convertEnum(e)
  444. emit(xo.Template{
  445. Partial: "enum",
  446. Dest: strings.ToLower(enum.GoName) + ext,
  447. SortName: enum.GoName,
  448. Data: enum,
  449. })
  450. }
  451. // build procs
  452. overloadMap := make(map[string][]Proc)
  453. // procOrder ensures procs are always emitted in alphabetic order for
  454. // consistency in single mode
  455. var procOrder []string
  456. for _, p := range schema.Procs {
  457. var err error
  458. if procOrder, err = convertProc(ctx, overloadMap, procOrder, p); err != nil {
  459. return err
  460. }
  461. }
  462. // emit procs
  463. for _, name := range procOrder {
  464. procs := overloadMap[name]
  465. prefix := "sp_"
  466. if procs[0].Type == "function" {
  467. prefix = "sf_"
  468. }
  469. // Set flag to change name to their overloaded versions if needed.
  470. for i := range procs {
  471. procs[i].Overloaded = len(procs) > 1
  472. }
  473. emit(xo.Template{
  474. Dest: prefix + strings.ToLower(name) + ext,
  475. Partial: "procs",
  476. SortName: prefix + name,
  477. Data: procs,
  478. })
  479. }
  480. // emit tables
  481. for _, t := range append(schema.Tables, schema.Views...) {
  482. table, err := convertTable(ctx, t)
  483. if err != nil {
  484. return err
  485. }
  486. emit(xo.Template{
  487. Dest: strings.ToLower(table.GoName) + ext,
  488. Partial: "typedef",
  489. SortType: table.Type,
  490. SortName: table.GoName,
  491. Data: table,
  492. })
  493. // emit indexes
  494. for _, i := range t.Indexes {
  495. index, err := convertIndex(ctx, table, i)
  496. if err != nil {
  497. return err
  498. }
  499. emit(xo.Template{
  500. Dest: strings.ToLower(table.GoName) + ext,
  501. Partial: "index",
  502. SortType: table.Type,
  503. SortName: index.SQLName,
  504. Data: index,
  505. })
  506. }
  507. // emit fkeys
  508. for _, fk := range t.ForeignKeys {
  509. fkey, err := convertFKey(ctx, table, fk)
  510. if err != nil {
  511. return err
  512. }
  513. emit(xo.Template{
  514. Dest: strings.ToLower(table.GoName) + ext,
  515. Partial: "foreignkey",
  516. SortType: table.Type,
  517. SortName: fkey.SQLName,
  518. Data: fkey,
  519. })
  520. }
  521. }
  522. return nil
  523. }
  524. // convertEnum converts a xo.Enum.
  525. func convertEnum(e xo.Enum) Enum {
  526. var vals []EnumValue
  527. goName := camelExport(e.Name)
  528. for _, v := range e.Values {
  529. name := camelExport(strings.ToLower(v.Name))
  530. if strings.HasSuffix(name, goName) && goName != name {
  531. name = strings.TrimSuffix(name, goName)
  532. }
  533. vals = append(vals, EnumValue{
  534. GoName: name,
  535. SQLName: v.Name,
  536. ConstValue: *v.ConstValue,
  537. })
  538. }
  539. return Enum{
  540. GoName: goName,
  541. SQLName: e.Name,
  542. Values: vals,
  543. }
  544. }
  545. // convertProc converts a xo.Proc.
  546. func convertProc(ctx context.Context, overloadMap map[string][]Proc, order []string, p xo.Proc) ([]string, error) {
  547. _, _, schema := xo.DriverDbSchema(ctx)
  548. proc := Proc{
  549. Type: p.Type,
  550. GoName: camelExport(p.Name),
  551. SQLName: p.Name,
  552. Signature: fmt.Sprintf("%s.%s", schema, p.Name),
  553. Void: p.Void,
  554. }
  555. // proc params
  556. var types []string
  557. for _, z := range p.Params {
  558. f, err := convertField(ctx, camel, z)
  559. if err != nil {
  560. return nil, err
  561. }
  562. proc.Params = append(proc.Params, f)
  563. types = append(types, z.Type.Type)
  564. }
  565. // add to signature, generate name
  566. proc.Signature += "(" + strings.Join(types, ", ") + ")"
  567. proc.OverloadedName = overloadedName(types, proc)
  568. types = nil
  569. // proc return
  570. for _, z := range p.Returns {
  571. f, err := convertField(ctx, camel, z)
  572. if err != nil {
  573. return nil, err
  574. }
  575. proc.Returns = append(proc.Returns, f)
  576. types = append(types, z.Type.Type)
  577. }
  578. // append signature
  579. if !p.Void {
  580. format := " (%s)"
  581. if len(p.Returns) == 1 {
  582. format = " %s"
  583. }
  584. proc.Signature += fmt.Sprintf(format, strings.Join(types, ", "))
  585. }
  586. // add proc
  587. procs, ok := overloadMap[proc.GoName]
  588. if !ok {
  589. order = append(order, proc.GoName)
  590. }
  591. overloadMap[proc.GoName] = append(procs, proc)
  592. return order, nil
  593. }
  594. // convertTable converts a xo.Table to a Table.
  595. func convertTable(ctx context.Context, t xo.Table) (Table, error) {
  596. var cols, pkCols []Field
  597. for _, z := range t.Columns {
  598. f, err := convertField(ctx, camelExport, z)
  599. if err != nil {
  600. return Table{}, err
  601. }
  602. cols = append(cols, f)
  603. if z.IsPrimary {
  604. pkCols = append(pkCols, f)
  605. }
  606. }
  607. return Table{
  608. GoName: camelExport(singularize(t.Name)),
  609. SQLName: t.Name,
  610. Fields: cols,
  611. PrimaryKeys: pkCols,
  612. Manual: t.Manual,
  613. }, nil
  614. }
  615. func convertIndex(ctx context.Context, t Table, i xo.Index) (Index, error) {
  616. var fields []Field
  617. for _, z := range i.Fields {
  618. f, err := convertField(ctx, camelExport, z)
  619. if err != nil {
  620. return Index{}, err
  621. }
  622. fields = append(fields, f)
  623. }
  624. return Index{
  625. SQLName: i.Name,
  626. Func: camelExport(i.Func),
  627. Table: t,
  628. Fields: fields,
  629. IsUnique: i.IsUnique,
  630. IsPrimary: i.IsPrimary,
  631. }, nil
  632. }
  633. func convertFKey(ctx context.Context, t Table, fk xo.ForeignKey) (ForeignKey, error) {
  634. var fields, refFields []Field
  635. // convert fields
  636. for _, f := range fk.Fields {
  637. field, err := convertField(ctx, camelExport, f)
  638. if err != nil {
  639. return ForeignKey{}, err
  640. }
  641. fields = append(fields, field)
  642. }
  643. // convert ref fields
  644. for _, f := range fk.RefFields {
  645. refField, err := convertField(ctx, camelExport, f)
  646. if err != nil {
  647. return ForeignKey{}, err
  648. }
  649. refFields = append(refFields, refField)
  650. }
  651. return ForeignKey{
  652. GoName: camelExport(fk.Func),
  653. SQLName: fk.Name,
  654. Table: t,
  655. Fields: fields,
  656. RefTable: camelExport(singularize(fk.RefTable)),
  657. RefFields: refFields,
  658. RefFunc: camelExport(fk.RefFunc),
  659. }, nil
  660. }
  661. func overloadedName(sqlTypes []string, proc Proc) string {
  662. if len(proc.Params) == 0 {
  663. return proc.GoName
  664. }
  665. var names []string
  666. // build parameters for proc.
  667. // if the proc's parameter has no name, use the types of the proc instead
  668. for i, f := range proc.Params {
  669. if f.SQLName == fmt.Sprintf("p%d", i) {
  670. names = append(names, camelExport(strings.Split(sqlTypes[i], " ")...))
  671. continue
  672. }
  673. names = append(names, camelExport(f.GoName))
  674. }
  675. if len(names) == 1 {
  676. return fmt.Sprintf("%sBy%s", proc.GoName, names[0])
  677. }
  678. front, last := strings.Join(names[:len(names)-1], ""), names[len(names)-1]
  679. return fmt.Sprintf("%sBy%sAnd%s", proc.GoName, front, last)
  680. }
  681. func convertField(ctx context.Context, tf transformFunc, f xo.Field) (Field, error) {
  682. typ, zero, err := goType(ctx, f.Type)
  683. if err != nil {
  684. return Field{}, err
  685. }
  686. return Field{
  687. Type: typ,
  688. GoName: tf(f.Name),
  689. SQLName: f.Name,
  690. Zero: zero,
  691. IsPrimary: f.IsPrimary,
  692. IsSequence: f.IsSequence,
  693. }, nil
  694. }
  695. func goType(ctx context.Context, typ xo.Type) (string, string, error) {
  696. driver, _, schema := xo.DriverDbSchema(ctx)
  697. var f func(xo.Type, string, string, string) (string, string, error)
  698. switch driver {
  699. case "mysql":
  700. f = loader.MysqlGoType
  701. case "oracle":
  702. f = loader.OracleGoType
  703. case "postgres":
  704. f = loader.PostgresGoType
  705. case "sqlite3":
  706. f = loader.Sqlite3GoType
  707. case "sqlserver":
  708. f = loader.SqlserverGoType
  709. default:
  710. return "", "", fmt.Errorf("unknown driver %q", driver)
  711. }
  712. return f(typ, schema, Int32(ctx), Uint32(ctx))
  713. }
  714. type transformFunc func(...string) string
  715. func snake(names ...string) string {
  716. return snaker.CamelToSnake(strings.Join(names, "_"))
  717. }
  718. func camel(names ...string) string {
  719. return snaker.ForceLowerCamelIdentifier(strings.Join(names, "_"))
  720. }
  721. func camelExport(names ...string) string {
  722. return snaker.ForceCamelIdentifier(strings.Join(names, "_"))
  723. }
  724. const ext = ".xo.go"
  725. // Funcs is a set of template funcs.
  726. type Funcs struct {
  727. driver string
  728. schema string
  729. nth func(int) string
  730. first bool
  731. pkg string
  732. tags []string
  733. imports []string
  734. conflict string
  735. custom string
  736. escSchema bool
  737. escTable bool
  738. escColumn bool
  739. fieldtag *template.Template
  740. context string
  741. inject string
  742. // knownTypes is the collection of known Go types.
  743. knownTypes map[string]bool
  744. // shorts is the collection of Go style short names for types, mainly
  745. // used for use with declaring a func receiver on a type.
  746. shorts map[string]string
  747. }
  748. // NewFuncs creates custom template funcs for the context.
  749. func NewFuncs(ctx context.Context) (template.FuncMap, error) {
  750. first := !NotFirst(ctx)
  751. // parse field tag template
  752. fieldtag, err := template.New("fieldtag").Parse(FieldTag(ctx))
  753. if err != nil {
  754. return nil, err
  755. }
  756. // load inject
  757. inject := Inject(ctx)
  758. if s := InjectFile(ctx); s != "" {
  759. buf, err := ioutil.ReadFile(s)
  760. if err != nil {
  761. return nil, fmt.Errorf("unable to read file: %v", err)
  762. }
  763. inject = string(buf)
  764. }
  765. driver, _, schema := xo.DriverDbSchema(ctx)
  766. nth, err := loader.NthParam(ctx)
  767. if err != nil {
  768. return nil, err
  769. }
  770. funcs := &Funcs{
  771. first: first,
  772. driver: driver,
  773. schema: schema,
  774. nth: nth,
  775. pkg: Pkg(ctx),
  776. tags: Tags(ctx),
  777. imports: Imports(ctx),
  778. conflict: Conflict(ctx),
  779. custom: Custom(ctx),
  780. escSchema: Esc(ctx, "schema"),
  781. escTable: Esc(ctx, "table"),
  782. escColumn: Esc(ctx, "column"),
  783. fieldtag: fieldtag,
  784. context: Context(ctx),
  785. inject: inject,
  786. knownTypes: KnownTypes(ctx),
  787. shorts: Shorts(ctx),
  788. }
  789. return funcs.FuncMap(), nil
  790. }
  791. // FuncMap returns the func map.
  792. func (f *Funcs) FuncMap() template.FuncMap {
  793. return template.FuncMap{
  794. // general
  795. "first": f.firstfn,
  796. "driver": f.driverfn,
  797. "schema": f.schemafn,
  798. "pkg": f.pkgfn,
  799. "tags": f.tagsfn,
  800. "imports": f.importsfn,
  801. "inject": f.injectfn,
  802. // context
  803. "context": f.contextfn,
  804. "context_both": f.context_both,
  805. "context_disable": f.context_disable,
  806. // func and query
  807. "func_name_context": f.func_name_context,
  808. "func_name": f.func_name_none,
  809. "func_context": f.func_context,
  810. "func": f.func_none,
  811. "recv_context": f.recv_context,
  812. "recv": f.recv_none,
  813. "foreign_key_context": f.foreign_key_context,
  814. "foreign_key": f.foreign_key_none,
  815. "db": f.db,
  816. "db_prefix": f.db_prefix,
  817. "db_update": f.db_update,
  818. "db_named": f.db_named,
  819. "named": f.named,
  820. "logf": f.logf,
  821. "logf_pkeys": f.logf_pkeys,
  822. "logf_update": f.logf_update,
  823. // type
  824. "names": f.names,
  825. "names_all": f.names_all,
  826. "names_ignore": f.names_ignore,
  827. "params": f.params,
  828. "zero": f.zero,
  829. "type": f.typefn,
  830. "field": f.field,
  831. "short": f.short,
  832. // sqlstr funcs
  833. "querystr": f.querystr,
  834. "sqlstr": f.sqlstr,
  835. // helpers
  836. "check_name": checkName,
  837. "eval": eval,
  838. }
  839. }
  840. func (f *Funcs) firstfn() bool {
  841. if f.first {
  842. f.first = false
  843. return true
  844. }
  845. return false
  846. }
  847. // driverfn returns true if the driver is any of the passed drivers.
  848. func (f *Funcs) driverfn(drivers ...string) bool {
  849. for _, driver := range drivers {
  850. if f.driver == driver {
  851. return true
  852. }
  853. }
  854. return false
  855. }
  856. // schemafn takes a series of names and joins them with the schema name.
  857. func (f *Funcs) schemafn(names ...string) string {
  858. s := f.schema
  859. // escape table names
  860. if f.escTable {
  861. for i, name := range names {
  862. names[i] = escfn(name)
  863. }
  864. }
  865. n := strings.Join(names, ".")
  866. switch {
  867. case s == "" && n == "":
  868. return ""
  869. case f.driver == "sqlite3" && n == "":
  870. return f.schema
  871. case f.driver == "sqlite3":
  872. return n
  873. case s != "" && n != "":
  874. if f.escSchema {
  875. s = escfn(s)
  876. }
  877. s += "."
  878. }
  879. return s + n
  880. }
  881. // pkgfn returns the package name.
  882. func (f *Funcs) pkgfn() string {
  883. return f.pkg
  884. }
  885. // tagsfn returns the tags.
  886. func (f *Funcs) tagsfn() []string {
  887. return f.tags
  888. }
  889. // importsfn returns the imports.
  890. func (f *Funcs) importsfn() []PackageImport {
  891. var imports []PackageImport
  892. for _, s := range f.imports {
  893. alias, pkg := "", s
  894. if i := strings.Index(pkg, " "); i != -1 {
  895. alias, pkg = pkg[:i], strings.TrimSpace(pkg[i:])
  896. }
  897. imports = append(imports, PackageImport{
  898. Alias: alias,
  899. Pkg: strconv.Quote(pkg),
  900. })
  901. }
  902. return imports
  903. }
  904. // contextfn returns true when the context mode is both or only.
  905. func (f *Funcs) contextfn() bool {
  906. return f.context == "both" || f.context == "only"
  907. }
  908. // context_both returns true with the context mode is both.
  909. func (f *Funcs) context_both() bool {
  910. return f.context == "both"
  911. }
  912. // context_disable returns true with the context mode is both.
  913. func (f *Funcs) context_disable() bool {
  914. return f.context == "disable"
  915. }
  916. // injectfn returns the injected content provided from args.
  917. func (f *Funcs) injectfn() string {
  918. return f.inject
  919. }
  920. // func_name_none builds a func name.
  921. func (f *Funcs) func_name_none(v interface{}) string {
  922. switch x := v.(type) {
  923. case string:
  924. return x
  925. case Query:
  926. return x.Name
  927. case Table:
  928. return x.GoName
  929. case ForeignKey:
  930. return x.GoName
  931. case Proc:
  932. n := x.GoName
  933. if x.Overloaded {
  934. n = x.OverloadedName
  935. }
  936. return n
  937. case Index:
  938. return x.Func
  939. }
  940. return fmt.Sprintf("[[ UNSUPPORTED TYPE 1: %T ]]", v)
  941. }
  942. // func_name_context generates a name for the func.
  943. func (f *Funcs) func_name_context(v interface{}) string {
  944. switch x := v.(type) {
  945. case string:
  946. return nameContext(f.context_both(), x)
  947. case Query:
  948. return nameContext(f.context_both(), x.Name)
  949. case Table:
  950. return nameContext(f.context_both(), x.GoName)
  951. case ForeignKey:
  952. return nameContext(f.context_both(), x.GoName)
  953. case Proc:
  954. n := x.GoName
  955. if x.Overloaded {
  956. n = x.OverloadedName
  957. }
  958. return nameContext(f.context_both(), n)
  959. case Index:
  960. return nameContext(f.context_both(), x.Func)
  961. }
  962. return fmt.Sprintf("[[ UNSUPPORTED TYPE 2: %T ]]", v)
  963. }
  964. // funcfn builds a func definition.
  965. func (f *Funcs) funcfn(name string, context bool, v interface{}) string {
  966. var p, r []string
  967. if context {
  968. p = append(p, "ctx context.Context")
  969. }
  970. p = append(p, "db DB")
  971. switch x := v.(type) {
  972. case Query:
  973. // params
  974. for _, z := range x.Params {
  975. p = append(p, fmt.Sprintf("%s %s", z.Name, z.Type))
  976. }
  977. // returns
  978. switch {
  979. case x.Exec:
  980. r = append(r, "sql.Result")
  981. case x.Flat:
  982. for _, z := range x.Type.Fields {
  983. r = append(r, f.typefn(z.Type))
  984. }
  985. case x.One:
  986. r = append(r, "*"+x.Type.GoName)
  987. default:
  988. r = append(r, "[]*"+x.Type.GoName)
  989. }
  990. case Proc:
  991. // params
  992. p = append(p, f.params(x.Params, true))
  993. // returns
  994. if !x.Void {
  995. for _, ret := range x.Returns {
  996. r = append(r, f.typefn(ret.Type))
  997. }
  998. }
  999. case Index:
  1000. // params
  1001. p = append(p, f.params(x.Fields, true))
  1002. // returns
  1003. rt := "*" + x.Table.GoName
  1004. if !x.IsUnique {
  1005. rt = "[]" + rt
  1006. }
  1007. r = append(r, rt)
  1008. default:
  1009. return fmt.Sprintf("[[ UNSUPPORTED TYPE 3: %T ]]", v)
  1010. }
  1011. r = append(r, "error")
  1012. return fmt.Sprintf("func %s(%s) (%s)", name, strings.Join(p, ", "), strings.Join(r, ", "))
  1013. }
  1014. // func_context generates a func signature for v with context determined by the
  1015. // context mode.
  1016. func (f *Funcs) func_context(v interface{}) string {
  1017. return f.funcfn(f.func_name_context(v), f.contextfn(), v)
  1018. }
  1019. // func_none genarates a func signature for v without context.
  1020. func (f *Funcs) func_none(v interface{}) string {
  1021. return f.funcfn(f.func_name_none(v), false, v)
  1022. }
  1023. // recv builds a receiver func definition.
  1024. func (f *Funcs) recv(name string, context bool, t Table, v interface{}) string {
  1025. short := f.short(t)
  1026. var p, r []string
  1027. // determine params and return type
  1028. if context {
  1029. p = append(p, "ctx context.Context")
  1030. }
  1031. p = append(p, "db DB")
  1032. switch x := v.(type) {
  1033. case ForeignKey:
  1034. r = append(r, "*"+x.RefTable)
  1035. }
  1036. r = append(r, "error")
  1037. return fmt.Sprintf("func (%s *%s) %s(%s) (%s)", short, t.GoName, name, strings.Join(p, ", "), strings.Join(r, ", "))
  1038. }
  1039. // recv_context builds a receiver func definition with context determined by
  1040. // the context mode.
  1041. func (f *Funcs) recv_context(typ interface{}, v interface{}) string {
  1042. switch x := typ.(type) {
  1043. case Table:
  1044. return f.recv(f.func_name_context(v), f.contextfn(), x, v)
  1045. }
  1046. return fmt.Sprintf("[[ UNSUPPORTED TYPE 4: %T ]]", typ)
  1047. }
  1048. // recv_none builds a receiver func definition without context.
  1049. func (f *Funcs) recv_none(typ interface{}, v interface{}) string {
  1050. switch x := typ.(type) {
  1051. case Table:
  1052. return f.recv(f.func_name_none(v), false, x, v)
  1053. }
  1054. return fmt.Sprintf("[[ UNSUPPORTED TYPE 5: %T ]]", typ)
  1055. }
  1056. func (f *Funcs) foreign_key_context(v interface{}) string {
  1057. var name string
  1058. var p []string
  1059. if f.contextfn() {
  1060. p = append(p, "ctx")
  1061. }
  1062. switch x := v.(type) {
  1063. case ForeignKey:
  1064. name = x.RefFunc
  1065. if f.context_both() {
  1066. name += "Context"
  1067. }
  1068. // add params
  1069. p = append(p, "db", f.convertTypes(x))
  1070. default:
  1071. return fmt.Sprintf("[[ UNSUPPORTED TYPE 6: %T ]]", v)
  1072. }
  1073. return fmt.Sprintf("%s(%s)", name, strings.Join(p, ", "))
  1074. }
  1075. func (f *Funcs) foreign_key_none(v interface{}) string {
  1076. var name string
  1077. var p []string
  1078. switch x := v.(type) {
  1079. case ForeignKey:
  1080. name = x.RefFunc
  1081. p = append(p, "context.Background()", "db", f.convertTypes(x))
  1082. default:
  1083. return fmt.Sprintf("[[ UNSUPPORTED TYPE 7: %T ]]", v)
  1084. }
  1085. return fmt.Sprintf("%s(%s)", name, strings.Join(p, ", "))
  1086. }
  1087. // db generates a db.<name>Context(ctx, sqlstr, ...)
  1088. func (f *Funcs) db(name string, v ...interface{}) string {
  1089. // params
  1090. var p []interface{}
  1091. if f.contextfn() {
  1092. name += "Context"
  1093. p = append(p, "ctx")
  1094. }
  1095. p = append(p, "sqlstr")
  1096. return fmt.Sprintf("db.%s(%s)", name, f.names("", append(p, v...)...))
  1097. }
  1098. // db_prefix generates a db.<name>Context(ctx, sqlstr, <prefix>.param, ...).
  1099. //
  1100. // Will skip the specific parameters based on the type provided.
  1101. func (f *Funcs) db_prefix(name string, skip bool, vs ...interface{}) string {
  1102. var prefix string
  1103. var params []interface{}
  1104. for i, v := range vs {
  1105. var ignore []string
  1106. switch x := v.(type) {
  1107. case string:
  1108. params = append(params, x)
  1109. case Table:
  1110. prefix = f.short(x.GoName) + "."
  1111. // skip primary keys
  1112. if skip {
  1113. for _, field := range x.Fields {
  1114. if field.IsSequence {
  1115. ignore = append(ignore, field.GoName)
  1116. }
  1117. }
  1118. }
  1119. p := f.names_ignore(prefix, v, ignore...)
  1120. // p is "" when no columns are present except for primary key
  1121. // params
  1122. if p != "" {
  1123. params = append(params, p)
  1124. }
  1125. default:
  1126. return fmt.Sprintf("[[ UNSUPPORTED TYPE 8 (%d): %T ]]", i, v)
  1127. }
  1128. }
  1129. return f.db(name, params...)
  1130. }
  1131. // db_update generates a db.<name>Context(ctx, sqlstr, regularparams,
  1132. // primaryparams)
  1133. func (f *Funcs) db_update(name string, v interface{}) string {
  1134. var ignore, p []string
  1135. switch x := v.(type) {
  1136. case Table:
  1137. prefix := f.short(x.GoName) + "."
  1138. for _, pk := range x.PrimaryKeys {
  1139. ignore = append(ignore, pk.GoName)
  1140. }
  1141. p = append(p, f.names_ignore(prefix, x, ignore...), f.names(prefix, x.PrimaryKeys))
  1142. default:
  1143. return fmt.Sprintf("[[ UNSUPPORTED TYPE 9: %T ]]", v)
  1144. }
  1145. return f.db(name, strings.Join(p, ", "))
  1146. }
  1147. // db_named generates a db.<name>Context(ctx, sql.Named(name, res)...)
  1148. func (f *Funcs) db_named(name string, v interface{}) string {
  1149. var p []string
  1150. switch x := v.(type) {
  1151. case Proc:
  1152. for _, z := range x.Params {
  1153. p = append(p, f.named(z.SQLName, z.GoName, false))
  1154. }
  1155. for _, z := range x.Returns {
  1156. p = append(p, f.named(z.SQLName, "&"+z.GoName, true))
  1157. }
  1158. default:
  1159. return fmt.Sprintf("[[ UNSUPPORTED TYPE 10: %T ]]", v)
  1160. }
  1161. return f.db(name, strings.Join(p, ", "))
  1162. }
  1163. func (f *Funcs) named(name, value string, out bool) string {
  1164. if out {
  1165. return fmt.Sprintf("sql.Named(%q, sql.Out{Dest: %s})", name, value)
  1166. }
  1167. return fmt.Sprintf("sql.Named(%q, %s)", name, value)
  1168. }
  1169. func (f *Funcs) logf_pkeys(v interface{}) string {
  1170. p := []string{"sqlstr"}
  1171. switch x := v.(type) {
  1172. case Table:
  1173. p = append(p, f.names(f.short(x.GoName)+".", x.PrimaryKeys))
  1174. }
  1175. return fmt.Sprintf("logf(%s)", strings.Join(p, ", "))
  1176. }
  1177. func (f *Funcs) logf(v interface{}, ignore ...interface{}) string {
  1178. var ignoreNames []string
  1179. p := []string{"sqlstr"}
  1180. // build ignore list
  1181. for i, x := range ignore {
  1182. switch z := x.(type) {
  1183. case string:
  1184. ignoreNames = append(ignoreNames, z)
  1185. case Field:
  1186. ignoreNames = append(ignoreNames, z.GoName)
  1187. case []Field:
  1188. for _, f := range z {
  1189. ignoreNames = append(ignoreNames, f.GoName)
  1190. }
  1191. default:
  1192. return fmt.Sprintf("[[ UNSUPPORTED TYPE 11 (%d): %T ]]", i, x)
  1193. }
  1194. }
  1195. // add fields
  1196. switch x := v.(type) {
  1197. case Table:
  1198. p = append(p, f.names_ignore(f.short(x.GoName)+".", x, ignoreNames...))
  1199. default:
  1200. return fmt.Sprintf("[[ UNSUPPORTED TYPE 12: %T ]]", v)
  1201. }
  1202. return fmt.Sprintf("logf(%s)", strings.Join(p, ", "))
  1203. }
  1204. func (f *Funcs) logf_update(v interface{}) string {
  1205. var ignore []string
  1206. p := []string{"sqlstr"}
  1207. switch x := v.(type) {
  1208. case Table:
  1209. prefix := f.short(x.GoName) + "."
  1210. for _, pk := range x.PrimaryKeys {
  1211. ignore = append(ignore, pk.GoName)
  1212. }
  1213. p = append(p, f.names_ignore(prefix, x, ignore...), f.names(prefix, x.PrimaryKeys))
  1214. default:
  1215. return fmt.Sprintf("[[ UNSUPPORTED TYPE 13: %T ]]", v)
  1216. }
  1217. return fmt.Sprintf("logf(%s)", strings.Join(p, ", "))
  1218. }
  1219. // names generates a list of names.
  1220. func (f *Funcs) namesfn(all bool, prefix string, z ...interface{}) string {
  1221. var names []string
  1222. for i, v := range z {
  1223. switch x := v.(type) {
  1224. case string:
  1225. names = append(names, x)
  1226. case Query:
  1227. for _, p := range x.Params {
  1228. if !all && p.Interpolate {
  1229. continue
  1230. }
  1231. names = append(names, prefix+p.Name)
  1232. }
  1233. case Table:
  1234. for _, p := range x.Fields {
  1235. names = append(names, prefix+checkName(p.GoName))
  1236. }
  1237. case []Field:
  1238. for _, p := range x {
  1239. names = append(names, prefix+checkName(p.GoName))
  1240. }
  1241. case Proc:
  1242. if params := f.params(x.Params, false); params != "" {
  1243. names = append(names, params)
  1244. }
  1245. case Index:
  1246. names = append(names, f.params(x.Fields, false))
  1247. default:
  1248. names = append(names, fmt.Sprintf("/* UNSUPPORTED TYPE 14 (%d): %T */", i, v))
  1249. }
  1250. }
  1251. return strings.Join(names, ", ")
  1252. }
  1253. // names generates a list of names (excluding certain ones such as interpolated
  1254. // names).
  1255. func (f *Funcs) names(prefix string, z ...interface{}) string {
  1256. return f.namesfn(false, prefix, z...)
  1257. }
  1258. // names_all generates a list of all names.
  1259. func (f *Funcs) names_all(prefix string, z ...interface{}) string {
  1260. return f.namesfn(true, prefix, z...)
  1261. }
  1262. // names_ignore generates a list of all names, ignoring fields that match the value in ignore.
  1263. func (f *Funcs) names_ignore(prefix string, v interface{}, ignore ...string) string {
  1264. m := make(map[string]bool)
  1265. for _, n := range ignore {
  1266. m[n] = true
  1267. }
  1268. var vals []Field
  1269. switch x := v.(type) {
  1270. case Table:
  1271. for _, p := range x.Fields {
  1272. if m[p.GoName] {
  1273. continue
  1274. }
  1275. vals = append(vals, p)
  1276. }
  1277. case []Field:
  1278. for _, p := range x {
  1279. if m[p.GoName] {
  1280. continue
  1281. }
  1282. vals = append(vals, p)
  1283. }
  1284. default:
  1285. return fmt.Sprintf("[[ UNSUPPORTED TYPE 15: %T ]]", v)
  1286. }
  1287. return f.namesfn(true, prefix, vals)
  1288. }
  1289. // querystr generates a querystr for the specified query and any accompanying
  1290. // comments.
  1291. func (f *Funcs) querystr(v interface{}) string {
  1292. var interpolate bool
  1293. var query, comments []string
  1294. switch x := v.(type) {
  1295. case Query:
  1296. interpolate, query, comments = x.Interpolate, x.Query, x.Comments
  1297. default:
  1298. return fmt.Sprintf("const sqlstr = [[ UNSUPPORTED TYPE 16: %T ]]", v)
  1299. }
  1300. typ := "const"
  1301. if interpolate {
  1302. typ = "var"
  1303. }
  1304. var lines []string
  1305. for i := 0; i < len(query); i++ {
  1306. line := "`" + query[i] + "`"
  1307. if i != len(query)-1 {
  1308. line += " + "
  1309. }
  1310. if s := strings.TrimSpace(comments[i]); s != "" {
  1311. line += "// " + s
  1312. }
  1313. lines = append(lines, line)
  1314. }
  1315. sqlstr := stripRE.ReplaceAllString(strings.Join(lines, "\n"), " ")
  1316. return fmt.Sprintf("%s sqlstr = %s", typ, sqlstr)
  1317. }
  // stripRE matches a "+" concatenation followed by an empty raw string
// literal (``), so that querystr can drop the useless empty pieces.
var stripRE = regexp.MustCompile("\\s+\\+\\s+``")
  1319. func (f *Funcs) sqlstr(typ string, v interface{}) string {
  1320. var lines []string
  1321. switch typ {
  1322. case "insert_manual":
  1323. lines = f.sqlstr_insert_manual(v)
  1324. case "insert":
  1325. lines = f.sqlstr_insert(v)
  1326. case "update":
  1327. lines = f.sqlstr_update(v)
  1328. case "upsert":
  1329. lines = f.sqlstr_upsert(v)
  1330. case "delete":
  1331. lines = f.sqlstr_delete(v)
  1332. case "proc":
  1333. lines = f.sqlstr_proc(v)
  1334. case "index":
  1335. lines = f.sqlstr_index(v)
  1336. default:
  1337. return fmt.Sprintf("const sqlstr = `UNKNOWN QUERY TYPE: %s`", typ)
  1338. }
  1339. return fmt.Sprintf("const sqlstr = `%s`", strings.Join(lines, "` +\n\t`"))
  1340. }
  1341. // sqlstr_insert_base builds an INSERT query
  1342. // If not all, sequence columns are skipped.
  1343. func (f *Funcs) sqlstr_insert_base(all bool, v interface{}) []string {
  1344. switch x := v.(type) {
  1345. case Table:
  1346. // build names and values
  1347. var n int
  1348. var fields, vals []string
  1349. for _, z := range x.Fields {
  1350. if z.IsSequence && !all {
  1351. continue
  1352. }
  1353. fields, vals = append(fields, f.colname(z)), append(vals, f.nth(n))
  1354. n++
  1355. }
  1356. return []string{
  1357. "INSERT INTO " + f.schemafn(x.SQLName) + " (",
  1358. strings.Join(fields, ", "),
  1359. ") VALUES (",
  1360. strings.Join(vals, ", "),
  1361. ")",
  1362. }
  1363. }
  1364. return []string{fmt.Sprintf("[[ UNSUPPORTED TYPE 17: %T ]]", v)}
  1365. }
  1366. // sqlstr_insert_manual builds an INSERT query that inserts all fields.
  1367. func (f *Funcs) sqlstr_insert_manual(v interface{}) []string {
  1368. return f.sqlstr_insert_base(true, v)
  1369. }
  1370. // sqlstr_insert builds an INSERT query, skipping the sequence field with
  1371. // applicable RETURNING clause for generated primary key fields.
  1372. func (f *Funcs) sqlstr_insert(v interface{}) []string {
  1373. switch x := v.(type) {
  1374. case Table:
  1375. var seq Field
  1376. for _, field := range x.Fields {
  1377. if field.IsSequence {
  1378. seq = field
  1379. }
  1380. }
  1381. lines := f.sqlstr_insert_base(false, v)
  1382. // add return clause
  1383. switch f.driver {
  1384. case "oracle":
  1385. lines[len(lines)-1] += ` RETURNING ` + f.colname(seq) + ` /*LASTINSERTID*/ INTO :pk`
  1386. case "postgres":
  1387. lines[len(lines)-1] += ` RETURNING ` + f.colname(seq)
  1388. case "sqlserver":
  1389. lines[len(lines)-1] += "; SELECT ID = CONVERT(BIGINT, SCOPE_IDENTITY())"
  1390. }
  1391. return lines
  1392. }
  1393. return []string{fmt.Sprintf("[[ UNSUPPORTED TYPE 18: %T ]]", v)}
  1394. }
  1395. // sqlstr_update_base builds an UPDATE query, using primary key fields as the WHERE
  1396. // clause, adding prefix.
  1397. //
  1398. // When prefix is empty, the WHERE clause will be in the form of name = $1.
  1399. // When prefix is non-empty, the WHERE clause will be in the form of name = <PREFIX>name.
  1400. //
  1401. // Similarly, when prefix is empty, the table's name is added after UPDATE,
  1402. // otherwise it is omitted.
  1403. func (f *Funcs) sqlstr_update_base(prefix string, v interface{}) (int, []string) {
  1404. switch x := v.(type) {
  1405. case Table:
  1406. // build names and values
  1407. var n int
  1408. var list []string
  1409. for _, z := range x.Fields {
  1410. if z.IsPrimary {
  1411. continue
  1412. }
  1413. name, param := f.colname(z), f.nth(n)
  1414. if prefix != "" {
  1415. param = prefix + name
  1416. }
  1417. list = append(list, fmt.Sprintf("%s = %s", name, param))
  1418. n++
  1419. }
  1420. name := ""
  1421. if prefix == "" {
  1422. name = f.schemafn(x.SQLName) + " "
  1423. }
  1424. return n, []string{
  1425. "UPDATE " + name + "SET ",
  1426. strings.Join(list, ", ") + " ",
  1427. }
  1428. }
  1429. return 0, []string{fmt.Sprintf("[[ UNSUPPORTED TYPE 19: %T ]]", v)}
  1430. }
  1431. // sqlstr_update builds an UPDATE query, using primary key fields as the WHERE
  1432. // clause.
  1433. func (f *Funcs) sqlstr_update(v interface{}) []string {
  1434. // build pkey vals
  1435. switch x := v.(type) {
  1436. case Table:
  1437. var list []string
  1438. n, lines := f.sqlstr_update_base("", v)
  1439. for i, z := range x.PrimaryKeys {
  1440. list = append(list, fmt.Sprintf("%s = %s", f.colname(z), f.nth(n+i)))
  1441. }
  1442. return append(lines, "WHERE "+strings.Join(list, " AND "))
  1443. }
  1444. return []string{fmt.Sprintf("[[ UNSUPPORTED TYPE 20: %T ]]", v)}
  1445. }
  1446. func (f *Funcs) sqlstr_upsert(v interface{}) []string {
  1447. switch x := v.(type) {
  1448. case Table:
  1449. // build insert
  1450. lines := f.sqlstr_insert_base(true, x)
  1451. switch f.driver {
  1452. case "postgres", "sqlite3":
  1453. return append(lines, f.sqlstr_upsert_postgres_sqlite(x)...)
  1454. case "mysql":
  1455. return append(lines, f.sqlstr_upsert_mysql(x)...)
  1456. case "sqlserver", "oracle":
  1457. return f.sqlstr_upsert_sqlserver_oracle(x)
  1458. }
  1459. }
  1460. return []string{fmt.Sprintf("[[ UNSUPPORTED TYPE 21 %s: %T ]]", f.driver, v)}
  1461. }
  1462. // sqlstr_upsert_postgres_sqlite builds an uspert query for postgres and sqlite
  1463. //
  1464. // INSERT (..) VALUES (..) ON CONFLICT DO UPDATE SET ...
  1465. func (f *Funcs) sqlstr_upsert_postgres_sqlite(v interface{}) []string {
  1466. switch x := v.(type) {
  1467. case Table:
  1468. // add conflict and update
  1469. var conflicts []string
  1470. for _, f := range x.PrimaryKeys {
  1471. conflicts = append(conflicts, f.SQLName)
  1472. }
  1473. lines := []string{" ON CONFLICT (" + strings.Join(conflicts, ", ") + ") DO "}
  1474. _, update := f.sqlstr_update_base("EXCLUDED.", v)
  1475. return append(lines, update...)
  1476. }
  1477. return []string{fmt.Sprintf("[[ UNSUPPORTED TYPE 22: %T ]]", v)}
  1478. }
  1479. // sqlstr_upsert_mysql builds an uspert query for mysql
  1480. //
  1481. // INSERT (..) VALUES (..) ON DUPLICATE KEY UPDATE SET ...
  1482. func (f *Funcs) sqlstr_upsert_mysql(v interface{}) []string {
  1483. switch x := v.(type) {
  1484. case Table:
  1485. lines := []string{" ON DUPLICATE KEY UPDATE "}
  1486. var list []string
  1487. i := len(x.Fields)
  1488. for _, z := range x.Fields {
  1489. if z.IsSequence {
  1490. continue
  1491. }
  1492. name := f.colname(z)
  1493. list = append(list, fmt.Sprintf("%s = VALUES(%s)", name, name))
  1494. i++
  1495. }
  1496. return append(lines, strings.Join(list, ", "))
  1497. }
  1498. return []string{fmt.Sprintf("[[ UNSUPPORTED TYPE 23: %T ]]", v)}
  1499. }
  1500. // sqlstr_upsert_sqlserver_oracle builds an upsert query for sqlserver
  1501. //
  1502. // MERGE [table] AS target USING (SELECT [pkeys]) AS source ...
  1503. func (f *Funcs) sqlstr_upsert_sqlserver_oracle(v interface{}) []string {
  1504. switch x := v.(type) {
  1505. case Table:
  1506. var lines []string
  1507. // merge [table]...
  1508. switch f.driver {
  1509. case "sqlserver":
  1510. lines = []string{"MERGE " + f.schemafn(x.SQLName) + " AS t "}
  1511. case "oracle":
  1512. lines = []string{"MERGE " + f.schemafn(x.SQLName) + "t "}
  1513. }
  1514. // using (select ..)
  1515. var fields, predicate []string
  1516. for i, field := range x.Fields {
  1517. fields = append(fields, fmt.Sprintf("%s %s", f.nth(i), field.SQLName))
  1518. }
  1519. for _, field := range x.PrimaryKeys {
  1520. predicate = append(predicate, fmt.Sprintf("s.%s = t.%s", field.SQLName, field.SQLName))
  1521. }
  1522. // closing part for select
  1523. var closing string
  1524. switch f.driver {
  1525. case "sqlserver":
  1526. closing = `) AS s `
  1527. case "oracle":
  1528. closing = `FROM DUAL ) s `
  1529. }
  1530. lines = append(lines, `USING (`,
  1531. `SELECT `+strings.Join(fields, ", ")+" ",
  1532. closing,
  1533. `ON `+strings.Join(predicate, " AND ")+" ")
  1534. // build param lists
  1535. var updateParams, insertParams, insertVals []string
  1536. for _, field := range x.Fields {
  1537. // sequences are always managed by db
  1538. if field.IsSequence {
  1539. continue
  1540. }
  1541. // primary keys
  1542. if !field.IsPrimary {
  1543. updateParams = append(updateParams, fmt.Sprintf("t.%s = s.%s", field.SQLName, field.SQLName))
  1544. }
  1545. insertParams = append(insertParams, field.SQLName)
  1546. insertVals = append(insertVals, "s."+field.SQLName)
  1547. }
  1548. // when matched then update...
  1549. lines = append(lines,
  1550. `WHEN MATCHED THEN `, `UPDATE SET `,
  1551. strings.Join(updateParams, ", ")+" ",
  1552. `WHEN NOT MATCHED THEN `,
  1553. `INSERT (`,
  1554. strings.Join(insertParams, ", "),
  1555. `) VALUES (`,
  1556. strings.Join(insertVals, ", "),
  1557. `);`,
  1558. )
  1559. return lines
  1560. }
  1561. return []string{fmt.Sprintf("[[ UNSUPPORTED TYPE 24: %T ]]", v)}
  1562. }
  1563. // sqlstr_delete builds a DELETE query for the primary keys.
  1564. func (f *Funcs) sqlstr_delete(v interface{}) []string {
  1565. switch x := v.(type) {
  1566. case Table:
  1567. // names and values
  1568. var list []string
  1569. for i, z := range x.PrimaryKeys {
  1570. list = append(list, fmt.Sprintf("%s = %s", f.colname(z), f.nth(i)))
  1571. }
  1572. return []string{
  1573. "DELETE FROM " + f.schemafn(x.SQLName) + " ",
  1574. "WHERE " + strings.Join(list, " AND "),
  1575. }
  1576. }
  1577. return []string{fmt.Sprintf("[[ UNSUPPORTED TYPE 25: %T ]]", v)}
  1578. }
  1579. // sqlstr_index builds a index fields.
  1580. func (f *Funcs) sqlstr_index(v interface{}) []string {
  1581. switch x := v.(type) {
  1582. case Index:
  1583. // build table fieldnames
  1584. var fields []string
  1585. for _, z := range x.Table.Fields {
  1586. fields = append(fields, f.colname(z))
  1587. }
  1588. // index fields
  1589. var list []string
  1590. for i, z := range x.Fields {
  1591. list = append(list, fmt.Sprintf("%s = %s", f.colname(z), f.nth(i)))
  1592. }
  1593. return []string{
  1594. "SELECT ",
  1595. strings.Join(fields, ", ") + " ",
  1596. "FROM " + f.schemafn(x.Table.SQLName) + " ",
  1597. "WHERE " + strings.Join(list, " AND "),
  1598. }
  1599. }
  1600. return []string{fmt.Sprintf("[[ UNSUPPORTED TYPE 26: %T ]]", v)}
  1601. }
  1602. // sqlstr_proc builds a stored procedure call.
  1603. func (f *Funcs) sqlstr_proc(v interface{}) []string {
  1604. switch x := v.(type) {
  1605. case Proc:
  1606. if x.Type == "function" {
  1607. return f.sqlstr_func(v)
  1608. }
  1609. // sql string format
  1610. var format string
  1611. switch f.driver {
  1612. case "postgres", "mysql":
  1613. format = "CALL %s(%s)"
  1614. case "sqlserver":
  1615. format = "%[1]s"
  1616. case "oracle":
  1617. format = "BEGIN %s(%s); END;"
  1618. }
  1619. // build params list; add return fields for orcle
  1620. l := x.Params
  1621. if f.driver == "oracle" {
  1622. l = append(l, x.Returns...)
  1623. }
  1624. var list []string
  1625. for i, field := range l {
  1626. s := f.nth(i)
  1627. if f.driver == "oracle" {
  1628. s = ":" + field.SQLName
  1629. }
  1630. list = append(list, s)
  1631. }
  1632. // dont prefix with schema for oracle
  1633. name := f.schemafn(x.SQLName)
  1634. if f.driver == "oracle" {
  1635. name = x.SQLName
  1636. }
  1637. return []string{
  1638. fmt.Sprintf(format, name, strings.Join(list, ", ")),
  1639. }
  1640. }
  1641. return []string{fmt.Sprintf("[[ UNSUPPORTED TYPE 27: %T ]]", v)}
  1642. }
  1643. func (f *Funcs) sqlstr_func(v interface{}) []string {
  1644. switch x := v.(type) {
  1645. case Proc:
  1646. var format string
  1647. switch f.driver {
  1648. case "postgres":
  1649. format = "SELECT * FROM %s(%s)"
  1650. case "mysql":
  1651. format = "SELECT %s(%s)"
  1652. case "sqlserver":
  1653. format = "SELECT %s(%s) AS OUT"
  1654. case "oracle":
  1655. format = "SELECT %s(%s) FROM dual"
  1656. }
  1657. var list []string
  1658. l := x.Params
  1659. for i := range l {
  1660. list = append(list, f.nth(i))
  1661. }
  1662. return []string{
  1663. fmt.Sprintf(format, f.schemafn(x.SQLName), strings.Join(list, ", ")),
  1664. }
  1665. }
  1666. return []string{fmt.Sprintf("[[ UNSUPPORTED TYPE 28: %T ]]", v)}
  1667. }
  1668. // convertTypes generates the conversions to convert the foreign key field
  1669. // types to their respective referenced field types.
  1670. func (f *Funcs) convertTypes(fkey ForeignKey) string {
  1671. var p []string
  1672. for i := range fkey.Fields {
  1673. field := fkey.Fields[i]
  1674. refField := fkey.RefFields[i]
  1675. expr := f.short(fkey.Table) + "." + field.GoName
  1676. // types match, can match
  1677. if field.Type == refField.Type {
  1678. p = append(p, expr)
  1679. continue
  1680. }
  1681. // convert types
  1682. typ, refType := field.Type, refField.Type
  1683. if strings.HasPrefix(typ, "sql.Null") {
  1684. expr = expr + "." + typ[8:]
  1685. typ = strings.ToLower(typ[8:])
  1686. }
  1687. if strings.ToLower(refType) != typ {
  1688. expr = refType + "(" + expr + ")"
  1689. }
  1690. p = append(p, expr)
  1691. }
  1692. return strings.Join(p, ", ")
  1693. }
  1694. // params converts a list of fields into their named Go parameters, skipping
  1695. // any Field with Name contained in ignore. addType will cause the go Type to
  1696. // be added after each variable name. addPrefix will cause the returned string
  1697. // to be prefixed with ", " if the generated string is not empty.
  1698. //
  1699. // Any field name encountered will be checked against goReservedNames, and will
  1700. // have its name substituted by its corresponding looked up value.
  1701. //
  1702. // Used to present a comma separated list of Go variable names for use with as
  1703. // either a Go func parameter list, or in a call to another Go func.
  1704. // (ie, ", a, b, c, ..." or ", a T1, b T2, c T3, ...").
  1705. func (f *Funcs) params(fields []Field, addType bool) string {
  1706. var vals []string
  1707. for _, field := range fields {
  1708. vals = append(vals, f.param(field, addType))
  1709. }
  1710. return strings.Join(vals, ", ")
  1711. }
  1712. func (f *Funcs) param(field Field, addType bool) string {
  1713. n := strings.Split(snaker.CamelToSnake(field.GoName), "_")
  1714. s := strings.ToLower(n[0]) + field.GoName[len(n[0]):]
  1715. // check go reserved names
  1716. if r, ok := goReservedNames[strings.ToLower(s)]; ok {
  1717. s = r
  1718. }
  1719. // add the go type
  1720. if addType {
  1721. s += " " + f.typefn(field.Type)
  1722. }
  1723. // add to vals
  1724. return s
  1725. }
  1726. // zero generates a zero list.
  1727. func (f *Funcs) zero(z ...interface{}) string {
  1728. var zeroes []string
  1729. for i, v := range z {
  1730. switch x := v.(type) {
  1731. case string:
  1732. zeroes = append(zeroes, x)
  1733. case Table:
  1734. for _, p := range x.Fields {
  1735. zeroes = append(zeroes, f.zero(p))
  1736. }
  1737. case []Field:
  1738. for _, p := range x {
  1739. zeroes = append(zeroes, f.zero(p))
  1740. }
  1741. case Field:
  1742. if _, ok := f.knownTypes[x.Type]; ok || x.Zero == "nil" {
  1743. zeroes = append(zeroes, x.Zero)
  1744. break
  1745. }
  1746. zeroes = append(zeroes, f.typefn(x.Type)+"{}")
  1747. default:
  1748. zeroes = append(zeroes, fmt.Sprintf("/* UNSUPPORTED TYPE 29 (%d): %T */", i, v))
  1749. }
  1750. }
  1751. return strings.Join(zeroes, ", ")
  1752. }
  1753. // typefn generates the Go type, prefixing the custom package name if applicable.
  1754. func (f *Funcs) typefn(typ string) string {
  1755. if strings.Contains(typ, ".") {
  1756. return typ
  1757. }
  1758. var prefix string
  1759. for strings.HasPrefix(typ, "[]") {
  1760. typ = typ[2:]
  1761. prefix += "[]"
  1762. }
  1763. if _, ok := f.knownTypes[typ]; ok || f.custom == "" {
  1764. return prefix + typ
  1765. }
  1766. return prefix + f.custom + "." + typ
  1767. }
  1768. // field generates a field definition for a struct.
  1769. func (f *Funcs) field(field Field) (string, error) {
  1770. buf := new(bytes.Buffer)
  1771. if err := f.fieldtag.Funcs(f.FuncMap()).Execute(buf, field); err != nil {
  1772. return "", err
  1773. }
  1774. var tag string
  1775. if s := buf.String(); s != "" {
  1776. tag = " `" + s + "`"
  1777. }
  1778. return fmt.Sprintf("\t%s %s%s // %s", field.GoName, f.typefn(field.Type), tag, field.SQLName), nil
  1779. }
  1780. // short generates a safe Go identifier for typ. typ is first checked
  1781. // against shorts, and if not found, then the value is calculated and
  1782. // stored in the shorts for future use.
  1783. //
  1784. // A short is the concatenation of the lowercase of the first character in
  1785. // the words comprising the name. For example, "MyCustomName" will have have
  1786. // the short of "mcn".
  1787. //
  1788. // If a generated short conflicts with a Go reserved name or a name used in
  1789. // the templates, then the corresponding value in goReservedNames map will be
  1790. // used.
  1791. //
  1792. // Generated shorts that have conflicts with any scopeConflicts member will
  1793. // have nameConflictSuffix appended.
  1794. func (f *Funcs) short(v interface{}) string {
  1795. var n string
  1796. switch x := v.(type) {
  1797. case string:
  1798. n = x
  1799. case Table:
  1800. n = x.GoName
  1801. default:
  1802. return fmt.Sprintf("[[ UNSUPPORTED TYPE 30: %T ]]", v)
  1803. }
  1804. // check short name map
  1805. name, ok := f.shorts[n]
  1806. if !ok {
  1807. // calc the short name
  1808. var u []string
  1809. for _, s := range strings.Split(strings.ToLower(snaker.CamelToSnake(n)), "_") {
  1810. if len(s) > 0 && s != "id" {
  1811. u = append(u, s[:1])
  1812. }
  1813. }
  1814. // ensure no name conflict
  1815. name = checkName(strings.Join(u, ""))
  1816. // store back to short name map
  1817. f.shorts[n] = name
  1818. }
  1819. // append suffix if conflict exists
  1820. if _, ok := templateReservedNames[name]; ok {
  1821. name += f.conflict
  1822. }
  1823. return name
  1824. }
  1825. // colname returns the ColumnName of a field escaped if needed.
  1826. func (f *Funcs) colname(z Field) string {
  1827. if f.escColumn {
  1828. return escfn(z.SQLName)
  1829. }
  1830. return z.SQLName
  1831. }
  1832. func checkName(name string) string {
  1833. if n, ok := goReservedNames[name]; ok {
  1834. return n
  1835. }
  1836. return name
  1837. }
  // escfn quotes s as a SQL delimited identifier ("name"). Any embedded double
// quote is doubled, as the SQL standard requires. The result is therefore
// always a single well-formed identifier, even for names that contain quotes.
func escfn(s string) string {
	return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
}
  // eval evaluates the text/template source s against v and returns the
// rendered output. Parse and execution errors are returned with an empty
// string.
func eval(v interface{}, s string) (string, error) {
	var out bytes.Buffer
	tpl, err := template.New(fmt.Sprintf("[EVAL %q]", s)).Parse(s)
	if err == nil {
		err = tpl.Execute(&out, v)
	}
	if err != nil {
		return "", err
	}
	return out.String(), nil
}
  1854. // templateReservedNames are the template reserved names.
  1855. var templateReservedNames = map[string]bool{
  1856. // variables
  1857. "ctx": true,
  1858. "db": true,
  1859. "err": true,
  1860. "log": true,
  1861. "logf": true,
  1862. "res": true,
  1863. "rows": true,
  1864. // packages
  1865. "context": true,
  1866. "csv": true,
  1867. "driver": true,
  1868. "errors": true,
  1869. "fmt": true,
  1870. "hstore": true,
  1871. "regexp": true,
  1872. "sql": true,
  1873. "strings": true,
  1874. "time