/templates/go/go.go
Go | 1961 lines | 1675 code | 87 blank | 199 comment | 309 complexity | 7cd217d447e5d84997b89f4e176c0fe3 MD5 | raw file
- //go:build xotpl
- package gotpl
- import (
- "bytes"
- "context"
- "errors"
- "fmt"
- "io"
- "io/fs"
- "io/ioutil"
- "os"
- "path/filepath"
- "regexp"
- "strconv"
- "strings"
- "text/template"
- "github.com/kenshaw/inflector"
- "github.com/kenshaw/snaker"
- "github.com/xo/xo/loader"
- xo "github.com/xo/xo/types"
- "golang.org/x/tools/imports"
- "mvdan.cc/gofumpt/format"
- )
// ErrNoSingle is the error returned by fileNames when an exec query is
// encountered in query mode and no single output file was requested.
var ErrNoSingle = errors.New("in query exec mode, the --single or -S must be provided")
- // Init registers the template.
- func Init(ctx context.Context, f func(xo.TemplateType)) error {
- knownTypes := map[string]bool{
- "bool": true,
- "string": true,
- "byte": true,
- "rune": true,
- "int": true,
- "int16": true,
- "int32": true,
- "int64": true,
- "uint": true,
- "uint8": true,
- "uint16": true,
- "uint32": true,
- "uint64": true,
- "float32": true,
- "float64": true,
- "Slice": true,
- "StringSlice": true,
- }
- shorts := map[string]string{
- "bool": "b",
- "string": "s",
- "byte": "b",
- "rune": "r",
- "int": "i",
- "int16": "i",
- "int32": "i",
- "int64": "i",
- "uint": "u",
- "uint8": "u",
- "uint16": "u",
- "uint32": "u",
- "uint64": "u",
- "float32": "f",
- "float64": "f",
- "Slice": "s",
- "StringSlice": "ss",
- }
- f(xo.TemplateType{
- Modes: []string{"query", "schema"},
- Flags: []xo.Flag{
- {
- ContextKey: AppendKey,
- Type: "bool",
- Desc: "enable append mode",
- Short: "a",
- Aliases: []string{"append"},
- },
- {
- ContextKey: NotFirstKey,
- Type: "bool",
- Desc: "disable package file (ie. not first generated file)",
- Short: "2",
- Default: "false",
- },
- {
- ContextKey: Int32Key,
- Type: "string",
- Desc: "int32 type",
- Default: "int",
- },
- {
- ContextKey: Uint32Key,
- Type: "string",
- Desc: "uint32 type",
- Default: "uint",
- },
- {
- ContextKey: PkgKey,
- Type: "string",
- Desc: "package name",
- },
- {
- ContextKey: TagKey,
- Type: "[]string",
- Desc: "build tags",
- },
- {
- ContextKey: ImportKey,
- Type: "[]string",
- Desc: "package imports",
- },
- {
- ContextKey: UUIDKey,
- Type: "string",
- Desc: "uuid type package",
- Default: "github.com/google/uuid",
- },
- {
- ContextKey: CustomKey,
- Type: "string",
- Desc: "package name for custom types",
- },
- {
- ContextKey: ConflictKey,
- Type: "string",
- Desc: "name conflict suffix",
- Default: "Val",
- },
- {
- ContextKey: InitialismKey,
- Type: "[]string",
- Desc: "add initialism (e.g. ID, API, URI, ...)",
- },
- {
- ContextKey: EscKey,
- Type: "[]string",
- Desc: "escape fields",
- Default: "none",
- Enums: []string{"none", "schema", "table", "column", "all"},
- },
- {
- ContextKey: FieldTagKey,
- Type: "string",
- Desc: "field tag",
- Short: "g",
- Default: `json:"{{ .SQLName }}"`,
- },
- {
- ContextKey: ContextKey,
- Type: "string",
- Desc: "context mode",
- Default: "only",
- Enums: []string{"disable", "both", "only"},
- },
- {
- ContextKey: InjectKey,
- Type: "string",
- Desc: "insert code into generated file headers",
- Default: "",
- },
- {
- ContextKey: InjectFileKey,
- Type: "string",
- Desc: "insert code into generated file headers from a file",
- Default: "",
- },
- {
- ContextKey: LegacyKey,
- Type: "bool",
- Desc: "enables legacy v1 template funcs",
- Default: "false",
- },
- },
- Funcs: func(ctx context.Context, _ string) (template.FuncMap, error) {
- funcs, err := NewFuncs(ctx)
- if err != nil {
- return nil, err
- }
- if Legacy(ctx) {
- addLegacyFuncs(ctx, funcs)
- }
- return funcs, nil
- },
- NewContext: func(ctx context.Context, _ string) context.Context {
- ctx = context.WithValue(ctx, KnownTypesKey, knownTypes)
- ctx = context.WithValue(ctx, ShortsKey, shorts)
- return ctx
- },
- Order: func(ctx context.Context, mode string) []string {
- base := []string{"header", "db"}
- switch mode {
- case "query":
- return append(base, "typedef", "query")
- case "schema":
- return append(base, "enum", "proc", "typedef", "query", "index", "foreignkey")
- }
- return nil
- },
- Pre: func(ctx context.Context, mode string, set *xo.Set, out fs.FS, emit func(xo.Template)) error {
- if err := addInitialisms(ctx); err != nil {
- return err
- }
- files, err := fileNames(ctx, mode, set)
- if err != nil {
- return err
- }
- // If -2 is provided, skip package template outputs as requested.
- // If -a is provided, skip to avoid duplicating the template.
- if !NotFirst(ctx) && !Append(ctx) {
- emit(xo.Template{
- Partial: "db",
- Dest: "db.xo.go",
- })
- // If --single is provided, don't generate header for db.xo.go.
- if xo.Single(ctx) == "" {
- files["db.xo.go"] = true
- }
- }
- if Append(ctx) {
- for filename := range files {
- f, err := out.Open(filename)
- switch {
- case errors.Is(err, os.ErrNotExist):
- continue
- case err != nil:
- return err
- }
- defer f.Close()
- data, err := io.ReadAll(f)
- if err != nil {
- return err
- }
- emit(xo.Template{
- Src: "{{.Data}}",
- Partial: "header", // ordered first
- Data: string(data),
- Dest: filename,
- })
- delete(files, filename)
- }
- }
- for filename := range files {
- emit(xo.Template{
- Partial: "header",
- Dest: filename,
- })
- }
- return nil
- },
- Process: func(ctx context.Context, mode string, set *xo.Set, emit func(xo.Template)) error {
- if mode == "query" {
- for _, query := range set.Queries {
- if err := emitQuery(ctx, query, emit); err != nil {
- return err
- }
- }
- } else {
- for _, schema := range set.Schemas {
- if err := emitSchema(ctx, schema, emit); err != nil {
- return err
- }
- }
- }
- return nil
- },
- Post: func(ctx context.Context, mode string, files map[string][]byte, emit func(string, []byte)) error {
- for file, content := range files {
- // Run goimports.
- buf, err := imports.Process("", content, nil)
- if err != nil {
- return fmt.Errorf("%s:%w", file, err)
- }
- // Run gofumpt.
- formatted, err := format.Source(buf, format.Options{
- ExtraRules: true,
- })
- if err != nil {
- return err
- }
- emit(file, formatted)
- }
- return nil
- },
- })
- return nil
- }
- // fileNames returns a list of file names that will be generated by the
- // template based on the parameters and schema.
- func fileNames(ctx context.Context, mode string, set *xo.Set) (map[string]bool, error) {
- // In single mode, only the specified file be generated.
- singleFile := xo.Single(ctx)
- if singleFile != "" {
- return map[string]bool{
- singleFile: true,
- }, nil
- }
- // Otherwise, infer filenames from set.
- files := make(map[string]bool)
- addFile := func(filename string) {
- // Filenames are always lowercase.
- filename = strings.ToLower(filename)
- files[filename+ext] = true
- }
- switch mode {
- case "schema":
- for _, schema := range set.Schemas {
- for _, e := range schema.Enums {
- addFile(camelExport(e.Name))
- }
- for _, p := range schema.Procs {
- goName := camelExport(p.Name)
- if p.Type == "function" {
- addFile("sf_" + goName)
- } else {
- addFile("sp_" + goName)
- }
- }
- for _, t := range schema.Tables {
- addFile(camelExport(singularize(t.Name)))
- }
- for _, v := range schema.Views {
- addFile(camelExport(singularize(v.Name)))
- }
- }
- case "query":
- for _, query := range set.Queries {
- addFile(query.Type)
- if query.Exec {
- // Single mode is handled at the start of the function but it
- // must be used for Exec queries.
- return nil, ErrNoSingle
- }
- }
- default:
- panic("unknown mode: " + mode)
- }
- return files, nil
- }
- // emitQuery emits the query.
- func emitQuery(ctx context.Context, query xo.Query, emit func(xo.Template)) error {
- var table Table
- // build type if needed
- if !query.Exec {
- var err error
- if table, err = buildQueryType(ctx, query); err != nil {
- return err
- }
- }
- // emit type definition
- if !query.Exec && !query.Flat && !Append(ctx) {
- emit(xo.Template{
- Partial: "typedef",
- Dest: strings.ToLower(table.GoName) + ext,
- SortType: query.Type,
- SortName: query.Name,
- Data: table,
- })
- }
- // build query params
- var params []QueryParam
- for _, param := range query.Params {
- params = append(params, QueryParam{
- Name: param.Name,
- Type: param.Type.Type,
- Interpolate: param.Interpolate,
- Join: param.Join,
- })
- }
- // emit query
- emit(xo.Template{
- Partial: "query",
- Dest: strings.ToLower(table.GoName) + ext,
- SortType: query.Type,
- SortName: query.Name,
- Data: Query{
- Name: buildQueryName(query),
- Query: query.Query,
- Comments: query.Comments,
- Params: params,
- One: query.Exec || query.Flat || query.One,
- Flat: query.Flat,
- Exec: query.Exec,
- Interpolate: query.Interpolate,
- Type: table,
- Comment: query.Comment,
- },
- })
- return nil
- }
- func buildQueryType(ctx context.Context, query xo.Query) (Table, error) {
- tf := camelExport
- if query.Flat {
- tf = camel
- }
- var fields []Field
- for _, z := range query.Fields {
- f, err := convertField(ctx, tf, z)
- if err != nil {
- return Table{}, err
- }
- // dont use convertField; the types are already provided by the user
- if query.ManualFields {
- f = Field{
- GoName: z.Name,
- SQLName: snake(z.Name),
- Type: z.Type.Type,
- }
- }
- fields = append(fields, f)
- }
- sqlName := snake(query.Type)
- return Table{
- GoName: query.Type,
- SQLName: sqlName,
- Fields: fields,
- Comment: query.TypeComment,
- }, nil
- }
- // buildQueryName builds a name for the query.
- func buildQueryName(query xo.Query) string {
- if query.Name != "" {
- return query.Name
- }
- // generate name if not specified
- name := query.Type
- if !query.One {
- name = inflector.Pluralize(name)
- }
- // add params
- if len(query.Params) == 0 {
- name = "Get" + name
- } else {
- name += "By"
- for _, p := range query.Params {
- name += camelExport(p.Name)
- }
- }
- return name
- }
- // emitSchema emits the xo schema for the template set.
- func emitSchema(ctx context.Context, schema xo.Schema, emit func(xo.Template)) error {
- // emit enums
- for _, e := range schema.Enums {
- enum := convertEnum(e)
- emit(xo.Template{
- Partial: "enum",
- Dest: strings.ToLower(enum.GoName) + ext,
- SortName: enum.GoName,
- Data: enum,
- })
- }
- // build procs
- overloadMap := make(map[string][]Proc)
- // procOrder ensures procs are always emitted in alphabetic order for
- // consistency in single mode
- var procOrder []string
- for _, p := range schema.Procs {
- var err error
- if procOrder, err = convertProc(ctx, overloadMap, procOrder, p); err != nil {
- return err
- }
- }
- // emit procs
- for _, name := range procOrder {
- procs := overloadMap[name]
- prefix := "sp_"
- if procs[0].Type == "function" {
- prefix = "sf_"
- }
- // Set flag to change name to their overloaded versions if needed.
- for i := range procs {
- procs[i].Overloaded = len(procs) > 1
- }
- emit(xo.Template{
- Dest: prefix + strings.ToLower(name) + ext,
- Partial: "procs",
- SortName: prefix + name,
- Data: procs,
- })
- }
- // emit tables
- for _, t := range append(schema.Tables, schema.Views...) {
- table, err := convertTable(ctx, t)
- if err != nil {
- return err
- }
- emit(xo.Template{
- Dest: strings.ToLower(table.GoName) + ext,
- Partial: "typedef",
- SortType: table.Type,
- SortName: table.GoName,
- Data: table,
- })
- // emit indexes
- for _, i := range t.Indexes {
- index, err := convertIndex(ctx, table, i)
- if err != nil {
- return err
- }
- emit(xo.Template{
- Dest: strings.ToLower(table.GoName) + ext,
- Partial: "index",
- SortType: table.Type,
- SortName: index.SQLName,
- Data: index,
- })
- }
- // emit fkeys
- for _, fk := range t.ForeignKeys {
- fkey, err := convertFKey(ctx, table, fk)
- if err != nil {
- return err
- }
- emit(xo.Template{
- Dest: strings.ToLower(table.GoName) + ext,
- Partial: "foreignkey",
- SortType: table.Type,
- SortName: fkey.SQLName,
- Data: fkey,
- })
- }
- }
- return nil
- }
- // convertEnum converts a xo.Enum.
- func convertEnum(e xo.Enum) Enum {
- var vals []EnumValue
- goName := camelExport(e.Name)
- for _, v := range e.Values {
- name := camelExport(strings.ToLower(v.Name))
- if strings.HasSuffix(name, goName) && goName != name {
- name = strings.TrimSuffix(name, goName)
- }
- vals = append(vals, EnumValue{
- GoName: name,
- SQLName: v.Name,
- ConstValue: *v.ConstValue,
- })
- }
- return Enum{
- GoName: goName,
- SQLName: e.Name,
- Values: vals,
- }
- }
- // convertProc converts a xo.Proc.
- func convertProc(ctx context.Context, overloadMap map[string][]Proc, order []string, p xo.Proc) ([]string, error) {
- _, _, schema := xo.DriverDbSchema(ctx)
- proc := Proc{
- Type: p.Type,
- GoName: camelExport(p.Name),
- SQLName: p.Name,
- Signature: fmt.Sprintf("%s.%s", schema, p.Name),
- Void: p.Void,
- }
- // proc params
- var types []string
- for _, z := range p.Params {
- f, err := convertField(ctx, camel, z)
- if err != nil {
- return nil, err
- }
- proc.Params = append(proc.Params, f)
- types = append(types, z.Type.Type)
- }
- // add to signature, generate name
- proc.Signature += "(" + strings.Join(types, ", ") + ")"
- proc.OverloadedName = overloadedName(types, proc)
- types = nil
- // proc return
- for _, z := range p.Returns {
- f, err := convertField(ctx, camel, z)
- if err != nil {
- return nil, err
- }
- proc.Returns = append(proc.Returns, f)
- types = append(types, z.Type.Type)
- }
- // append signature
- if !p.Void {
- format := " (%s)"
- if len(p.Returns) == 1 {
- format = " %s"
- }
- proc.Signature += fmt.Sprintf(format, strings.Join(types, ", "))
- }
- // add proc
- procs, ok := overloadMap[proc.GoName]
- if !ok {
- order = append(order, proc.GoName)
- }
- overloadMap[proc.GoName] = append(procs, proc)
- return order, nil
- }
- // convertTable converts a xo.Table to a Table.
- func convertTable(ctx context.Context, t xo.Table) (Table, error) {
- var cols, pkCols []Field
- for _, z := range t.Columns {
- f, err := convertField(ctx, camelExport, z)
- if err != nil {
- return Table{}, err
- }
- cols = append(cols, f)
- if z.IsPrimary {
- pkCols = append(pkCols, f)
- }
- }
- return Table{
- GoName: camelExport(singularize(t.Name)),
- SQLName: t.Name,
- Fields: cols,
- PrimaryKeys: pkCols,
- Manual: t.Manual,
- }, nil
- }
- func convertIndex(ctx context.Context, t Table, i xo.Index) (Index, error) {
- var fields []Field
- for _, z := range i.Fields {
- f, err := convertField(ctx, camelExport, z)
- if err != nil {
- return Index{}, err
- }
- fields = append(fields, f)
- }
- return Index{
- SQLName: i.Name,
- Func: camelExport(i.Func),
- Table: t,
- Fields: fields,
- IsUnique: i.IsUnique,
- IsPrimary: i.IsPrimary,
- }, nil
- }
- func convertFKey(ctx context.Context, t Table, fk xo.ForeignKey) (ForeignKey, error) {
- var fields, refFields []Field
- // convert fields
- for _, f := range fk.Fields {
- field, err := convertField(ctx, camelExport, f)
- if err != nil {
- return ForeignKey{}, err
- }
- fields = append(fields, field)
- }
- // convert ref fields
- for _, f := range fk.RefFields {
- refField, err := convertField(ctx, camelExport, f)
- if err != nil {
- return ForeignKey{}, err
- }
- refFields = append(refFields, refField)
- }
- return ForeignKey{
- GoName: camelExport(fk.Func),
- SQLName: fk.Name,
- Table: t,
- Fields: fields,
- RefTable: camelExport(singularize(fk.RefTable)),
- RefFields: refFields,
- RefFunc: camelExport(fk.RefFunc),
- }, nil
- }
- func overloadedName(sqlTypes []string, proc Proc) string {
- if len(proc.Params) == 0 {
- return proc.GoName
- }
- var names []string
- // build parameters for proc.
- // if the proc's parameter has no name, use the types of the proc instead
- for i, f := range proc.Params {
- if f.SQLName == fmt.Sprintf("p%d", i) {
- names = append(names, camelExport(strings.Split(sqlTypes[i], " ")...))
- continue
- }
- names = append(names, camelExport(f.GoName))
- }
- if len(names) == 1 {
- return fmt.Sprintf("%sBy%s", proc.GoName, names[0])
- }
- front, last := strings.Join(names[:len(names)-1], ""), names[len(names)-1]
- return fmt.Sprintf("%sBy%sAnd%s", proc.GoName, front, last)
- }
- func convertField(ctx context.Context, tf transformFunc, f xo.Field) (Field, error) {
- typ, zero, err := goType(ctx, f.Type)
- if err != nil {
- return Field{}, err
- }
- return Field{
- Type: typ,
- GoName: tf(f.Name),
- SQLName: f.Name,
- Zero: zero,
- IsPrimary: f.IsPrimary,
- IsSequence: f.IsSequence,
- }, nil
- }
- func goType(ctx context.Context, typ xo.Type) (string, string, error) {
- driver, _, schema := xo.DriverDbSchema(ctx)
- var f func(xo.Type, string, string, string) (string, string, error)
- switch driver {
- case "mysql":
- f = loader.MysqlGoType
- case "oracle":
- f = loader.OracleGoType
- case "postgres":
- f = loader.PostgresGoType
- case "sqlite3":
- f = loader.Sqlite3GoType
- case "sqlserver":
- f = loader.SqlserverGoType
- default:
- return "", "", fmt.Errorf("unknown driver %q", driver)
- }
- return f(typ, schema, Int32(ctx), Uint32(ctx))
- }
- type transformFunc func(...string) string // transformFunc turns one or more name parts into a single name (e.g. camel, camelExport).
- func snake(names ...string) string {
- return snaker.CamelToSnake(strings.Join(names, "_"))
- }
- func camel(names ...string) string {
- return snaker.ForceLowerCamelIdentifier(strings.Join(names, "_"))
- }
- func camelExport(names ...string) string {
- return snaker.ForceCamelIdentifier(strings.Join(names, "_"))
- }
- const ext = ".xo.go" // ext is the suffix appended to generated file names.
// Funcs holds the per-run state behind the template funcs returned by
// FuncMap. It is populated from the context by NewFuncs.
type Funcs struct {
	// driver and schema come from the context's driver/db/schema triple.
	driver, schema string
	// nth is obtained from loader.NthParam.
	nth func(int) string
	// first is reported (and then cleared) by the "first" template func.
	first bool
	pkg   string
	// tags and imports are returned by the "tags" and "imports" funcs.
	tags, imports    []string
	conflict, custom string
	// escSchema, escTable and escColumn are the escape flag results for
	// "schema", "table" and "column".
	escSchema, escTable, escColumn bool
	// fieldtag is the parsed field tag template.
	fieldtag *template.Template
	// context is the context mode; inject is the injected header content.
	context, inject string
	// knownTypes is the collection of known Go types.
	knownTypes map[string]bool
	// shorts is the collection of Go style short names for types, mainly
	// used for use with declaring a func receiver on a type.
	shorts map[string]string
}
- // NewFuncs creates custom template funcs for the context.
- func NewFuncs(ctx context.Context) (template.FuncMap, error) {
- first := !NotFirst(ctx)
- // parse field tag template
- fieldtag, err := template.New("fieldtag").Parse(FieldTag(ctx))
- if err != nil {
- return nil, err
- }
- // load inject
- inject := Inject(ctx)
- if s := InjectFile(ctx); s != "" {
- buf, err := ioutil.ReadFile(s)
- if err != nil {
- return nil, fmt.Errorf("unable to read file: %v", err)
- }
- inject = string(buf)
- }
- driver, _, schema := xo.DriverDbSchema(ctx)
- nth, err := loader.NthParam(ctx)
- if err != nil {
- return nil, err
- }
- funcs := &Funcs{
- first: first,
- driver: driver,
- schema: schema,
- nth: nth,
- pkg: Pkg(ctx),
- tags: Tags(ctx),
- imports: Imports(ctx),
- conflict: Conflict(ctx),
- custom: Custom(ctx),
- escSchema: Esc(ctx, "schema"),
- escTable: Esc(ctx, "table"),
- escColumn: Esc(ctx, "column"),
- fieldtag: fieldtag,
- context: Context(ctx),
- inject: inject,
- knownTypes: KnownTypes(ctx),
- shorts: Shorts(ctx),
- }
- return funcs.FuncMap(), nil
- }
- // FuncMap returns the func map.
- func (f *Funcs) FuncMap() template.FuncMap {
- return template.FuncMap{
- // general
- "first": f.firstfn,
- "driver": f.driverfn,
- "schema": f.schemafn,
- "pkg": f.pkgfn,
- "tags": f.tagsfn,
- "imports": f.importsfn,
- "inject": f.injectfn,
- // context
- "context": f.contextfn,
- "context_both": f.context_both,
- "context_disable": f.context_disable,
- // func and query
- "func_name_context": f.func_name_context,
- "func_name": f.func_name_none,
- "func_context": f.func_context,
- "func": f.func_none,
- "recv_context": f.recv_context,
- "recv": f.recv_none,
- "foreign_key_context": f.foreign_key_context,
- "foreign_key": f.foreign_key_none,
- "db": f.db,
- "db_prefix": f.db_prefix,
- "db_update": f.db_update,
- "db_named": f.db_named,
- "named": f.named,
- "logf": f.logf,
- "logf_pkeys": f.logf_pkeys,
- "logf_update": f.logf_update,
- // type
- "names": f.names,
- "names_all": f.names_all,
- "names_ignore": f.names_ignore,
- "params": f.params,
- "zero": f.zero,
- "type": f.typefn,
- "field": f.field,
- "short": f.short,
- // sqlstr funcs
- "querystr": f.querystr,
- "sqlstr": f.sqlstr,
- // helpers
- "check_name": checkName,
- "eval": eval,
- }
- }
- func (f *Funcs) firstfn() bool {
- if f.first {
- f.first = false
- return true
- }
- return false
- }
- // driverfn returns true if the driver is any of the passed drivers.
- func (f *Funcs) driverfn(drivers ...string) bool {
- for _, driver := range drivers {
- if f.driver == driver {
- return true
- }
- }
- return false
- }
- // schemafn takes a series of names and joins them with the schema name.
- func (f *Funcs) schemafn(names ...string) string {
- s := f.schema
- // escape table names
- if f.escTable {
- for i, name := range names {
- names[i] = escfn(name)
- }
- }
- n := strings.Join(names, ".")
- switch {
- case s == "" && n == "":
- return ""
- case f.driver == "sqlite3" && n == "":
- return f.schema
- case f.driver == "sqlite3":
- return n
- case s != "" && n != "":
- if f.escSchema {
- s = escfn(s)
- }
- s += "."
- }
- return s + n
- }
- // pkgfn returns the package name stored in f.pkg (template func "pkg").
- func (f *Funcs) pkgfn() string {
- return f.pkg
- }
- // tagsfn returns the tags stored in f.tags (template func "tags").
- func (f *Funcs) tagsfn() []string {
- return f.tags
- }
- // importsfn returns the imports.
- func (f *Funcs) importsfn() []PackageImport {
- var imports []PackageImport
- for _, s := range f.imports {
- alias, pkg := "", s
- if i := strings.Index(pkg, " "); i != -1 {
- alias, pkg = pkg[:i], strings.TrimSpace(pkg[i:])
- }
- imports = append(imports, PackageImport{
- Alias: alias,
- Pkg: strconv.Quote(pkg),
- })
- }
- return imports
- }
- // contextfn returns true when the context mode is "both" or "only".
- func (f *Funcs) contextfn() bool {
- return f.context == "both" || f.context == "only"
- }
- // context_both returns true when the context mode is "both".
- func (f *Funcs) context_both() bool {
- return f.context == "both"
- }
- // context_disable returns true when the context mode is "disable".
- func (f *Funcs) context_disable() bool {
- return f.context == "disable"
- }
- // injectfn returns the injected content stored in f.inject (template func "inject").
- func (f *Funcs) injectfn() string {
- return f.inject
- }
- // func_name_none builds a func name.
- func (f *Funcs) func_name_none(v interface{}) string {
- switch x := v.(type) {
- case string:
- return x
- case Query:
- return x.Name
- case Table:
- return x.GoName
- case ForeignKey:
- return x.GoName
- case Proc:
- n := x.GoName
- if x.Overloaded {
- n = x.OverloadedName
- }
- return n
- case Index:
- return x.Func
- }
- return fmt.Sprintf("[[ UNSUPPORTED TYPE 1: %T ]]", v)
- }
- // func_name_context generates a name for the func.
- func (f *Funcs) func_name_context(v interface{}) string {
- switch x := v.(type) {
- case string:
- return nameContext(f.context_both(), x)
- case Query:
- return nameContext(f.context_both(), x.Name)
- case Table:
- return nameContext(f.context_both(), x.GoName)
- case ForeignKey:
- return nameContext(f.context_both(), x.GoName)
- case Proc:
- n := x.GoName
- if x.Overloaded {
- n = x.OverloadedName
- }
- return nameContext(f.context_both(), n)
- case Index:
- return nameContext(f.context_both(), x.Func)
- }
- return fmt.Sprintf("[[ UNSUPPORTED TYPE 2: %T ]]", v)
- }
- // funcfn builds a func definition.
- func (f *Funcs) funcfn(name string, context bool, v interface{}) string {
- var p, r []string
- if context {
- p = append(p, "ctx context.Context")
- }
- p = append(p, "db DB")
- switch x := v.(type) {
- case Query:
- // params
- for _, z := range x.Params {
- p = append(p, fmt.Sprintf("%s %s", z.Name, z.Type))
- }
- // returns
- switch {
- case x.Exec:
- r = append(r, "sql.Result")
- case x.Flat:
- for _, z := range x.Type.Fields {
- r = append(r, f.typefn(z.Type))
- }
- case x.One:
- r = append(r, "*"+x.Type.GoName)
- default:
- r = append(r, "[]*"+x.Type.GoName)
- }
- case Proc:
- // params
- p = append(p, f.params(x.Params, true))
- // returns
- if !x.Void {
- for _, ret := range x.Returns {
- r = append(r, f.typefn(ret.Type))
- }
- }
- case Index:
- // params
- p = append(p, f.params(x.Fields, true))
- // returns
- rt := "*" + x.Table.GoName
- if !x.IsUnique {
- rt = "[]" + rt
- }
- r = append(r, rt)
- default:
- return fmt.Sprintf("[[ UNSUPPORTED TYPE 3: %T ]]", v)
- }
- r = append(r, "error")
- return fmt.Sprintf("func %s(%s) (%s)", name, strings.Join(p, ", "), strings.Join(r, ", "))
- }
- // func_context generates a func signature for v, named by func_name_context
- // and including a ctx parameter when contextfn reports true.
- func (f *Funcs) func_context(v interface{}) string {
- return f.funcfn(f.func_name_context(v), f.contextfn(), v)
- }
- // func_none generates a func signature for v, named by func_name_none and without a ctx parameter.
- func (f *Funcs) func_none(v interface{}) string {
- return f.funcfn(f.func_name_none(v), false, v)
- }
- // recv builds a receiver func definition.
- func (f *Funcs) recv(name string, context bool, t Table, v interface{}) string {
- short := f.short(t)
- var p, r []string
- // determine params and return type
- if context {
- p = append(p, "ctx context.Context")
- }
- p = append(p, "db DB")
- switch x := v.(type) {
- case ForeignKey:
- r = append(r, "*"+x.RefTable)
- }
- r = append(r, "error")
- return fmt.Sprintf("func (%s *%s) %s(%s) (%s)", short, t.GoName, name, strings.Join(p, ", "), strings.Join(r, ", "))
- }
- // recv_context builds a receiver func definition with context determined by
- // the context mode.
- func (f *Funcs) recv_context(typ interface{}, v interface{}) string {
- switch x := typ.(type) {
- case Table:
- return f.recv(f.func_name_context(v), f.contextfn(), x, v)
- }
- return fmt.Sprintf("[[ UNSUPPORTED TYPE 4: %T ]]", typ)
- }
- // recv_none builds a receiver func definition without context.
- func (f *Funcs) recv_none(typ interface{}, v interface{}) string {
- switch x := typ.(type) {
- case Table:
- return f.recv(f.func_name_none(v), false, x, v)
- }
- return fmt.Sprintf("[[ UNSUPPORTED TYPE 5: %T ]]", typ)
- }
- func (f *Funcs) foreign_key_context(v interface{}) string {
- var name string
- var p []string
- if f.contextfn() {
- p = append(p, "ctx")
- }
- switch x := v.(type) {
- case ForeignKey:
- name = x.RefFunc
- if f.context_both() {
- name += "Context"
- }
- // add params
- p = append(p, "db", f.convertTypes(x))
- default:
- return fmt.Sprintf("[[ UNSUPPORTED TYPE 6: %T ]]", v)
- }
- return fmt.Sprintf("%s(%s)", name, strings.Join(p, ", "))
- }
- func (f *Funcs) foreign_key_none(v interface{}) string {
- var name string
- var p []string
- switch x := v.(type) {
- case ForeignKey:
- name = x.RefFunc
- p = append(p, "context.Background()", "db", f.convertTypes(x))
- default:
- return fmt.Sprintf("[[ UNSUPPORTED TYPE 7: %T ]]", v)
- }
- return fmt.Sprintf("%s(%s)", name, strings.Join(p, ", "))
- }
- // db generates a db.<name>Context(ctx, sqlstr, ...)
- func (f *Funcs) db(name string, v ...interface{}) string {
- // params
- var p []interface{}
- if f.contextfn() {
- name += "Context"
- p = append(p, "ctx")
- }
- p = append(p, "sqlstr")
- return fmt.Sprintf("db.%s(%s)", name, f.names("", append(p, v...)...))
- }
- // db_prefix generates a db.<name>Context(ctx, sqlstr, <prefix>.param, ...).
- //
- // Will skip the specific parameters based on the type provided.
- func (f *Funcs) db_prefix(name string, skip bool, vs ...interface{}) string {
- var prefix string
- var params []interface{}
- for i, v := range vs {
- var ignore []string
- switch x := v.(type) {
- case string:
- params = append(params, x)
- case Table:
- prefix = f.short(x.GoName) + "."
- // skip primary keys
- if skip {
- for _, field := range x.Fields {
- if field.IsSequence {
- ignore = append(ignore, field.GoName)
- }
- }
- }
- p := f.names_ignore(prefix, v, ignore...)
- // p is "" when no columns are present except for primary key
- // params
- if p != "" {
- params = append(params, p)
- }
- default:
- return fmt.Sprintf("[[ UNSUPPORTED TYPE 8 (%d): %T ]]", i, v)
- }
- }
- return f.db(name, params...)
- }
- // db_update generates a db.<name>Context(ctx, sqlstr, regularparams,
- // primaryparams)
- func (f *Funcs) db_update(name string, v interface{}) string {
- var ignore, p []string
- switch x := v.(type) {
- case Table:
- prefix := f.short(x.GoName) + "."
- for _, pk := range x.PrimaryKeys {
- ignore = append(ignore, pk.GoName)
- }
- p = append(p, f.names_ignore(prefix, x, ignore...), f.names(prefix, x.PrimaryKeys))
- default:
- return fmt.Sprintf("[[ UNSUPPORTED TYPE 9: %T ]]", v)
- }
- return f.db(name, strings.Join(p, ", "))
- }
- // db_named generates a db.<name>Context(ctx, sql.Named(name, res)...)
- func (f *Funcs) db_named(name string, v interface{}) string {
- var p []string
- switch x := v.(type) {
- case Proc:
- for _, z := range x.Params {
- p = append(p, f.named(z.SQLName, z.GoName, false))
- }
- for _, z := range x.Returns {
- p = append(p, f.named(z.SQLName, "&"+z.GoName, true))
- }
- default:
- return fmt.Sprintf("[[ UNSUPPORTED TYPE 10: %T ]]", v)
- }
- return f.db(name, strings.Join(p, ", "))
- }
- func (f *Funcs) named(name, value string, out bool) string {
- if out {
- return fmt.Sprintf("sql.Named(%q, sql.Out{Dest: %s})", name, value)
- }
- return fmt.Sprintf("sql.Named(%q, %s)", name, value)
- }
- func (f *Funcs) logf_pkeys(v interface{}) string {
- p := []string{"sqlstr"}
- switch x := v.(type) {
- case Table:
- p = append(p, f.names(f.short(x.GoName)+".", x.PrimaryKeys))
- }
- return fmt.Sprintf("logf(%s)", strings.Join(p, ", "))
- }
- func (f *Funcs) logf(v interface{}, ignore ...interface{}) string {
- var ignoreNames []string
- p := []string{"sqlstr"}
- // build ignore list
- for i, x := range ignore {
- switch z := x.(type) {
- case string:
- ignoreNames = append(ignoreNames, z)
- case Field:
- ignoreNames = append(ignoreNames, z.GoName)
- case []Field:
- for _, f := range z {
- ignoreNames = append(ignoreNames, f.GoName)
- }
- default:
- return fmt.Sprintf("[[ UNSUPPORTED TYPE 11 (%d): %T ]]", i, x)
- }
- }
- // add fields
- switch x := v.(type) {
- case Table:
- p = append(p, f.names_ignore(f.short(x.GoName)+".", x, ignoreNames...))
- default:
- return fmt.Sprintf("[[ UNSUPPORTED TYPE 12: %T ]]", v)
- }
- return fmt.Sprintf("logf(%s)", strings.Join(p, ", "))
- }
- func (f *Funcs) logf_update(v interface{}) string {
- var ignore []string
- p := []string{"sqlstr"}
- switch x := v.(type) {
- case Table:
- prefix := f.short(x.GoName) + "."
- for _, pk := range x.PrimaryKeys {
- ignore = append(ignore, pk.GoName)
- }
- p = append(p, f.names_ignore(prefix, x, ignore...), f.names(prefix, x.PrimaryKeys))
- default:
- return fmt.Sprintf("[[ UNSUPPORTED TYPE 13: %T ]]", v)
- }
- return fmt.Sprintf("logf(%s)", strings.Join(p, ", "))
- }
- // names generates a list of names.
- func (f *Funcs) namesfn(all bool, prefix string, z ...interface{}) string {
- var names []string
- for i, v := range z {
- switch x := v.(type) {
- case string:
- names = append(names, x)
- case Query:
- for _, p := range x.Params {
- if !all && p.Interpolate {
- continue
- }
- names = append(names, prefix+p.Name)
- }
- case Table:
- for _, p := range x.Fields {
- names = append(names, prefix+checkName(p.GoName))
- }
- case []Field:
- for _, p := range x {
- names = append(names, prefix+checkName(p.GoName))
- }
- case Proc:
- if params := f.params(x.Params, false); params != "" {
- names = append(names, params)
- }
- case Index:
- names = append(names, f.params(x.Fields, false))
- default:
- names = append(names, fmt.Sprintf("/* UNSUPPORTED TYPE 14 (%d): %T */", i, v))
- }
- }
- return strings.Join(names, ", ")
- }
- // names generates a list of names via namesfn with all set to false, which
- // leaves out interpolated query params.
- func (f *Funcs) names(prefix string, z ...interface{}) string {
- return f.namesfn(false, prefix, z...)
- }
- // names_all generates a list of all names via namesfn with all set to true, keeping interpolated query params.
- func (f *Funcs) names_all(prefix string, z ...interface{}) string {
- return f.namesfn(true, prefix, z...)
- }
- // names_ignore generates a list of all names, ignoring fields that match the value in ignore.
- func (f *Funcs) names_ignore(prefix string, v interface{}, ignore ...string) string {
- m := make(map[string]bool)
- for _, n := range ignore {
- m[n] = true
- }
- var vals []Field
- switch x := v.(type) {
- case Table:
- for _, p := range x.Fields {
- if m[p.GoName] {
- continue
- }
- vals = append(vals, p)
- }
- case []Field:
- for _, p := range x {
- if m[p.GoName] {
- continue
- }
- vals = append(vals, p)
- }
- default:
- return fmt.Sprintf("[[ UNSUPPORTED TYPE 15: %T ]]", v)
- }
- return f.namesfn(true, prefix, vals)
- }
- // querystr generates a querystr for the specified query and any accompanying
- // comments.
- func (f *Funcs) querystr(v interface{}) string {
- var interpolate bool
- var query, comments []string
- switch x := v.(type) {
- case Query:
- interpolate, query, comments = x.Interpolate, x.Query, x.Comments
- default:
- return fmt.Sprintf("const sqlstr = [[ UNSUPPORTED TYPE 16: %T ]]", v)
- }
- typ := "const"
- if interpolate {
- typ = "var"
- }
- var lines []string
- for i := 0; i < len(query); i++ {
- line := "`" + query[i] + "`"
- if i != len(query)-1 {
- line += " + "
- }
- if s := strings.TrimSpace(comments[i]); s != "" {
- line += "// " + s
- }
- lines = append(lines, line)
- }
- sqlstr := stripRE.ReplaceAllString(strings.Join(lines, "\n"), " ")
- return fmt.Sprintf("%s sqlstr = %s", typ, sqlstr)
- }
// stripRE matches a "+" surrounded by whitespace that is followed by an empty
// raw string literal (two backticks). querystr replaces each match with a
// single space.
var stripRE = regexp.MustCompile(`\s+\+\s+` + "``")
- func (f *Funcs) sqlstr(typ string, v interface{}) string {
- var lines []string
- switch typ {
- case "insert_manual":
- lines = f.sqlstr_insert_manual(v)
- case "insert":
- lines = f.sqlstr_insert(v)
- case "update":
- lines = f.sqlstr_update(v)
- case "upsert":
- lines = f.sqlstr_upsert(v)
- case "delete":
- lines = f.sqlstr_delete(v)
- case "proc":
- lines = f.sqlstr_proc(v)
- case "index":
- lines = f.sqlstr_index(v)
- default:
- return fmt.Sprintf("const sqlstr = `UNKNOWN QUERY TYPE: %s`", typ)
- }
- return fmt.Sprintf("const sqlstr = `%s`", strings.Join(lines, "` +\n\t`"))
- }
- // sqlstr_insert_base builds an INSERT query
- // If not all, sequence columns are skipped.
- func (f *Funcs) sqlstr_insert_base(all bool, v interface{}) []string {
- switch x := v.(type) {
- case Table:
- // build names and values
- var n int
- var fields, vals []string
- for _, z := range x.Fields {
- if z.IsSequence && !all {
- continue
- }
- fields, vals = append(fields, f.colname(z)), append(vals, f.nth(n))
- n++
- }
- return []string{
- "INSERT INTO " + f.schemafn(x.SQLName) + " (",
- strings.Join(fields, ", "),
- ") VALUES (",
- strings.Join(vals, ", "),
- ")",
- }
- }
- return []string{fmt.Sprintf("[[ UNSUPPORTED TYPE 17: %T ]]", v)}
- }
- // sqlstr_insert_manual builds an INSERT query that inserts all fields.
- func (f *Funcs) sqlstr_insert_manual(v interface{}) []string {
- return f.sqlstr_insert_base(true, v)
- }
- // sqlstr_insert builds an INSERT query, skipping the sequence field with
- // applicable RETURNING clause for generated primary key fields.
- func (f *Funcs) sqlstr_insert(v interface{}) []string {
- switch x := v.(type) {
- case Table:
- var seq Field
- for _, field := range x.Fields {
- if field.IsSequence {
- seq = field
- }
- }
- lines := f.sqlstr_insert_base(false, v)
- // add return clause
- switch f.driver {
- case "oracle":
- lines[len(lines)-1] += ` RETURNING ` + f.colname(seq) + ` /*LASTINSERTID*/ INTO :pk`
- case "postgres":
- lines[len(lines)-1] += ` RETURNING ` + f.colname(seq)
- case "sqlserver":
- lines[len(lines)-1] += "; SELECT ID = CONVERT(BIGINT, SCOPE_IDENTITY())"
- }
- return lines
- }
- return []string{fmt.Sprintf("[[ UNSUPPORTED TYPE 18: %T ]]", v)}
- }
- // sqlstr_update_base builds an UPDATE query, using primary key fields as the WHERE
- // clause, adding prefix.
- //
- // When prefix is empty, the WHERE clause will be in the form of name = $1.
- // When prefix is non-empty, the WHERE clause will be in the form of name = <PREFIX>name.
- //
- // Similarly, when prefix is empty, the table's name is added after UPDATE,
- // otherwise it is omitted.
- func (f *Funcs) sqlstr_update_base(prefix string, v interface{}) (int, []string) {
- switch x := v.(type) {
- case Table:
- // build names and values
- var n int
- var list []string
- for _, z := range x.Fields {
- if z.IsPrimary {
- continue
- }
- name, param := f.colname(z), f.nth(n)
- if prefix != "" {
- param = prefix + name
- }
- list = append(list, fmt.Sprintf("%s = %s", name, param))
- n++
- }
- name := ""
- if prefix == "" {
- name = f.schemafn(x.SQLName) + " "
- }
- return n, []string{
- "UPDATE " + name + "SET ",
- strings.Join(list, ", ") + " ",
- }
- }
- return 0, []string{fmt.Sprintf("[[ UNSUPPORTED TYPE 19: %T ]]", v)}
- }
- // sqlstr_update builds an UPDATE query, using primary key fields as the WHERE
- // clause.
- func (f *Funcs) sqlstr_update(v interface{}) []string {
- // build pkey vals
- switch x := v.(type) {
- case Table:
- var list []string
- n, lines := f.sqlstr_update_base("", v)
- for i, z := range x.PrimaryKeys {
- list = append(list, fmt.Sprintf("%s = %s", f.colname(z), f.nth(n+i)))
- }
- return append(lines, "WHERE "+strings.Join(list, " AND "))
- }
- return []string{fmt.Sprintf("[[ UNSUPPORTED TYPE 20: %T ]]", v)}
- }
- func (f *Funcs) sqlstr_upsert(v interface{}) []string {
- switch x := v.(type) {
- case Table:
- // build insert
- lines := f.sqlstr_insert_base(true, x)
- switch f.driver {
- case "postgres", "sqlite3":
- return append(lines, f.sqlstr_upsert_postgres_sqlite(x)...)
- case "mysql":
- return append(lines, f.sqlstr_upsert_mysql(x)...)
- case "sqlserver", "oracle":
- return f.sqlstr_upsert_sqlserver_oracle(x)
- }
- }
- return []string{fmt.Sprintf("[[ UNSUPPORTED TYPE 21 %s: %T ]]", f.driver, v)}
- }
- // sqlstr_upsert_postgres_sqlite builds an uspert query for postgres and sqlite
- //
- // INSERT (..) VALUES (..) ON CONFLICT DO UPDATE SET ...
- func (f *Funcs) sqlstr_upsert_postgres_sqlite(v interface{}) []string {
- switch x := v.(type) {
- case Table:
- // add conflict and update
- var conflicts []string
- for _, f := range x.PrimaryKeys {
- conflicts = append(conflicts, f.SQLName)
- }
- lines := []string{" ON CONFLICT (" + strings.Join(conflicts, ", ") + ") DO "}
- _, update := f.sqlstr_update_base("EXCLUDED.", v)
- return append(lines, update...)
- }
- return []string{fmt.Sprintf("[[ UNSUPPORTED TYPE 22: %T ]]", v)}
- }
- // sqlstr_upsert_mysql builds an uspert query for mysql
- //
- // INSERT (..) VALUES (..) ON DUPLICATE KEY UPDATE SET ...
- func (f *Funcs) sqlstr_upsert_mysql(v interface{}) []string {
- switch x := v.(type) {
- case Table:
- lines := []string{" ON DUPLICATE KEY UPDATE "}
- var list []string
- i := len(x.Fields)
- for _, z := range x.Fields {
- if z.IsSequence {
- continue
- }
- name := f.colname(z)
- list = append(list, fmt.Sprintf("%s = VALUES(%s)", name, name))
- i++
- }
- return append(lines, strings.Join(list, ", "))
- }
- return []string{fmt.Sprintf("[[ UNSUPPORTED TYPE 23: %T ]]", v)}
- }
- // sqlstr_upsert_sqlserver_oracle builds an upsert query for sqlserver
- //
- // MERGE [table] AS target USING (SELECT [pkeys]) AS source ...
- func (f *Funcs) sqlstr_upsert_sqlserver_oracle(v interface{}) []string {
- switch x := v.(type) {
- case Table:
- var lines []string
- // merge [table]...
- switch f.driver {
- case "sqlserver":
- lines = []string{"MERGE " + f.schemafn(x.SQLName) + " AS t "}
- case "oracle":
- lines = []string{"MERGE " + f.schemafn(x.SQLName) + "t "}
- }
- // using (select ..)
- var fields, predicate []string
- for i, field := range x.Fields {
- fields = append(fields, fmt.Sprintf("%s %s", f.nth(i), field.SQLName))
- }
- for _, field := range x.PrimaryKeys {
- predicate = append(predicate, fmt.Sprintf("s.%s = t.%s", field.SQLName, field.SQLName))
- }
- // closing part for select
- var closing string
- switch f.driver {
- case "sqlserver":
- closing = `) AS s `
- case "oracle":
- closing = `FROM DUAL ) s `
- }
- lines = append(lines, `USING (`,
- `SELECT `+strings.Join(fields, ", ")+" ",
- closing,
- `ON `+strings.Join(predicate, " AND ")+" ")
- // build param lists
- var updateParams, insertParams, insertVals []string
- for _, field := range x.Fields {
- // sequences are always managed by db
- if field.IsSequence {
- continue
- }
- // primary keys
- if !field.IsPrimary {
- updateParams = append(updateParams, fmt.Sprintf("t.%s = s.%s", field.SQLName, field.SQLName))
- }
- insertParams = append(insertParams, field.SQLName)
- insertVals = append(insertVals, "s."+field.SQLName)
- }
- // when matched then update...
- lines = append(lines,
- `WHEN MATCHED THEN `, `UPDATE SET `,
- strings.Join(updateParams, ", ")+" ",
- `WHEN NOT MATCHED THEN `,
- `INSERT (`,
- strings.Join(insertParams, ", "),
- `) VALUES (`,
- strings.Join(insertVals, ", "),
- `);`,
- )
- return lines
- }
- return []string{fmt.Sprintf("[[ UNSUPPORTED TYPE 24: %T ]]", v)}
- }
- // sqlstr_delete builds a DELETE query for the primary keys.
- func (f *Funcs) sqlstr_delete(v interface{}) []string {
- switch x := v.(type) {
- case Table:
- // names and values
- var list []string
- for i, z := range x.PrimaryKeys {
- list = append(list, fmt.Sprintf("%s = %s", f.colname(z), f.nth(i)))
- }
- return []string{
- "DELETE FROM " + f.schemafn(x.SQLName) + " ",
- "WHERE " + strings.Join(list, " AND "),
- }
- }
- return []string{fmt.Sprintf("[[ UNSUPPORTED TYPE 25: %T ]]", v)}
- }
- // sqlstr_index builds a index fields.
- func (f *Funcs) sqlstr_index(v interface{}) []string {
- switch x := v.(type) {
- case Index:
- // build table fieldnames
- var fields []string
- for _, z := range x.Table.Fields {
- fields = append(fields, f.colname(z))
- }
- // index fields
- var list []string
- for i, z := range x.Fields {
- list = append(list, fmt.Sprintf("%s = %s", f.colname(z), f.nth(i)))
- }
- return []string{
- "SELECT ",
- strings.Join(fields, ", ") + " ",
- "FROM " + f.schemafn(x.Table.SQLName) + " ",
- "WHERE " + strings.Join(list, " AND "),
- }
- }
- return []string{fmt.Sprintf("[[ UNSUPPORTED TYPE 26: %T ]]", v)}
- }
- // sqlstr_proc builds a stored procedure call.
- func (f *Funcs) sqlstr_proc(v interface{}) []string {
- switch x := v.(type) {
- case Proc:
- if x.Type == "function" {
- return f.sqlstr_func(v)
- }
- // sql string format
- var format string
- switch f.driver {
- case "postgres", "mysql":
- format = "CALL %s(%s)"
- case "sqlserver":
- format = "%[1]s"
- case "oracle":
- format = "BEGIN %s(%s); END;"
- }
- // build params list; add return fields for orcle
- l := x.Params
- if f.driver == "oracle" {
- l = append(l, x.Returns...)
- }
- var list []string
- for i, field := range l {
- s := f.nth(i)
- if f.driver == "oracle" {
- s = ":" + field.SQLName
- }
- list = append(list, s)
- }
- // dont prefix with schema for oracle
- name := f.schemafn(x.SQLName)
- if f.driver == "oracle" {
- name = x.SQLName
- }
- return []string{
- fmt.Sprintf(format, name, strings.Join(list, ", ")),
- }
- }
- return []string{fmt.Sprintf("[[ UNSUPPORTED TYPE 27: %T ]]", v)}
- }
- func (f *Funcs) sqlstr_func(v interface{}) []string {
- switch x := v.(type) {
- case Proc:
- var format string
- switch f.driver {
- case "postgres":
- format = "SELECT * FROM %s(%s)"
- case "mysql":
- format = "SELECT %s(%s)"
- case "sqlserver":
- format = "SELECT %s(%s) AS OUT"
- case "oracle":
- format = "SELECT %s(%s) FROM dual"
- }
- var list []string
- l := x.Params
- for i := range l {
- list = append(list, f.nth(i))
- }
- return []string{
- fmt.Sprintf(format, f.schemafn(x.SQLName), strings.Join(list, ", ")),
- }
- }
- return []string{fmt.Sprintf("[[ UNSUPPORTED TYPE 28: %T ]]", v)}
- }
- // convertTypes generates the conversions to convert the foreign key field
- // types to their respective referenced field types.
- func (f *Funcs) convertTypes(fkey ForeignKey) string {
- var p []string
- for i := range fkey.Fields {
- field := fkey.Fields[i]
- refField := fkey.RefFields[i]
- expr := f.short(fkey.Table) + "." + field.GoName
- // types match, can match
- if field.Type == refField.Type {
- p = append(p, expr)
- continue
- }
- // convert types
- typ, refType := field.Type, refField.Type
- if strings.HasPrefix(typ, "sql.Null") {
- expr = expr + "." + typ[8:]
- typ = strings.ToLower(typ[8:])
- }
- if strings.ToLower(refType) != typ {
- expr = refType + "(" + expr + ")"
- }
- p = append(p, expr)
- }
- return strings.Join(p, ", ")
- }
- // params converts a list of fields into their named Go parameters, skipping
- // any Field with Name contained in ignore. addType will cause the go Type to
- // be added after each variable name. addPrefix will cause the returned string
- // to be prefixed with ", " if the generated string is not empty.
- //
- // Any field name encountered will be checked against goReservedNames, and will
- // have its name substituted by its corresponding looked up value.
- //
- // Used to present a comma separated list of Go variable names for use with as
- // either a Go func parameter list, or in a call to another Go func.
- // (ie, ", a, b, c, ..." or ", a T1, b T2, c T3, ...").
- func (f *Funcs) params(fields []Field, addType bool) string {
- var vals []string
- for _, field := range fields {
- vals = append(vals, f.param(field, addType))
- }
- return strings.Join(vals, ", ")
- }
- func (f *Funcs) param(field Field, addType bool) string {
- n := strings.Split(snaker.CamelToSnake(field.GoName), "_")
- s := strings.ToLower(n[0]) + field.GoName[len(n[0]):]
- // check go reserved names
- if r, ok := goReservedNames[strings.ToLower(s)]; ok {
- s = r
- }
- // add the go type
- if addType {
- s += " " + f.typefn(field.Type)
- }
- // add to vals
- return s
- }
- // zero generates a zero list.
- func (f *Funcs) zero(z ...interface{}) string {
- var zeroes []string
- for i, v := range z {
- switch x := v.(type) {
- case string:
- zeroes = append(zeroes, x)
- case Table:
- for _, p := range x.Fields {
- zeroes = append(zeroes, f.zero(p))
- }
- case []Field:
- for _, p := range x {
- zeroes = append(zeroes, f.zero(p))
- }
- case Field:
- if _, ok := f.knownTypes[x.Type]; ok || x.Zero == "nil" {
- zeroes = append(zeroes, x.Zero)
- break
- }
- zeroes = append(zeroes, f.typefn(x.Type)+"{}")
- default:
- zeroes = append(zeroes, fmt.Sprintf("/* UNSUPPORTED TYPE 29 (%d): %T */", i, v))
- }
- }
- return strings.Join(zeroes, ", ")
- }
- // typefn generates the Go type, prefixing the custom package name if applicable.
- func (f *Funcs) typefn(typ string) string {
- if strings.Contains(typ, ".") {
- return typ
- }
- var prefix string
- for strings.HasPrefix(typ, "[]") {
- typ = typ[2:]
- prefix += "[]"
- }
- if _, ok := f.knownTypes[typ]; ok || f.custom == "" {
- return prefix + typ
- }
- return prefix + f.custom + "." + typ
- }
- // field generates a field definition for a struct.
- func (f *Funcs) field(field Field) (string, error) {
- buf := new(bytes.Buffer)
- if err := f.fieldtag.Funcs(f.FuncMap()).Execute(buf, field); err != nil {
- return "", err
- }
- var tag string
- if s := buf.String(); s != "" {
- tag = " `" + s + "`"
- }
- return fmt.Sprintf("\t%s %s%s // %s", field.GoName, f.typefn(field.Type), tag, field.SQLName), nil
- }
- // short generates a safe Go identifier for typ. typ is first checked
- // against shorts, and if not found, then the value is calculated and
- // stored in the shorts for future use.
- //
- // A short is the concatenation of the lowercase of the first character in
- // the words comprising the name. For example, "MyCustomName" will have have
- // the short of "mcn".
- //
- // If a generated short conflicts with a Go reserved name or a name used in
- // the templates, then the corresponding value in goReservedNames map will be
- // used.
- //
- // Generated shorts that have conflicts with any scopeConflicts member will
- // have nameConflictSuffix appended.
- func (f *Funcs) short(v interface{}) string {
- var n string
- switch x := v.(type) {
- case string:
- n = x
- case Table:
- n = x.GoName
- default:
- return fmt.Sprintf("[[ UNSUPPORTED TYPE 30: %T ]]", v)
- }
- // check short name map
- name, ok := f.shorts[n]
- if !ok {
- // calc the short name
- var u []string
- for _, s := range strings.Split(strings.ToLower(snaker.CamelToSnake(n)), "_") {
- if len(s) > 0 && s != "id" {
- u = append(u, s[:1])
- }
- }
- // ensure no name conflict
- name = checkName(strings.Join(u, ""))
- // store back to short name map
- f.shorts[n] = name
- }
- // append suffix if conflict exists
- if _, ok := templateReservedNames[name]; ok {
- name += f.conflict
- }
- return name
- }
- // colname returns the ColumnName of a field escaped if needed.
- func (f *Funcs) colname(z Field) string {
- if f.escColumn {
- return escfn(z.SQLName)
- }
- return z.SQLName
- }
- func checkName(name string) string {
- if n, ok := goReservedNames[name]; ok {
- return n
- }
- return name
- }
// escfn escapes s as a double quoted SQL identifier. Any double quote
// embedded in s is doubled, so that it does not terminate the quoted
// identifier early.
func escfn(s string) string {
	return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
}
// eval evaluates the template s against v, returning the rendered output.
// Returns an empty string along with the error when s fails to parse or the
// template fails to execute.
func eval(v interface{}, s string) (string, error) {
	name := fmt.Sprintf("[EVAL %q]", s)
	tpl, err := template.New(name).Parse(s)
	if err != nil {
		return "", err
	}
	var out bytes.Buffer
	if err = tpl.Execute(&out, v); err != nil {
		return "", err
	}
	return out.String(), nil
}
- // templateReservedNames are the template reserved names.
- var templateReservedNames = map[string]bool{
- // variables
- "ctx": true,
- "db": true,
- "err": true,
- "log": true,
- "logf": true,
- "res": true,
- "rows": true,
- // packages
- "context": true,
- "csv": true,
- "driver": true,
- "errors": true,
- "fmt": true,
- "hstore": true,
- "regexp": true,
- "sql": true,
- "strings": true,
- "time