PageRenderTime 50ms CodeModel.GetById 21ms RepoModel.GetById 0ms app.codeStats 0ms

/staging/src/k8s.io/code-generator/cmd/go-to-protobuf/protobuf/parser.go

https://gitlab.com/unofficial-mirrors/kubernetes
Go | 452 lines | 371 code | 33 blank | 48 comment | 145 complexity | ffe1c81f6b50b7e8b30529a328c228dc MD5 | raw file
  1. /*
  2. Copyright 2015 The Kubernetes Authors.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package protobuf
  14. import (
  15. "bytes"
  16. "errors"
  17. "fmt"
  18. "go/ast"
  19. "go/format"
  20. "go/parser"
  21. "go/printer"
  22. "go/token"
  23. "io/ioutil"
  24. "os"
  25. "reflect"
  26. "strings"
  27. customreflect "k8s.io/code-generator/third_party/forked/golang/reflect"
  28. )
  // rewriteFile parses the Go source file at name, hands the parsed AST to
  // rewriteFn so it can be mutated in place, and then overwrites the file with
  // header followed by the printed and gofmt-formatted AST.
  //
  // The first error encountered is returned unchanged. Nothing is written to
  // disk unless reading, parsing, rewriting, printing and formatting have all
  // succeeded.
  func rewriteFile(name string, header []byte, rewriteFn func(*token.FileSet, *ast.File) error) error {
  	original, err := ioutil.ReadFile(name)
  	if err != nil {
  		return err
  	}

  	// Comments are parsed so that they survive the round trip through the printer.
  	const mode = parser.DeclarationErrors | parser.ParseComments
  	fileSet := token.NewFileSet()
  	tree, err := parser.ParseFile(fileSet, name, original, mode)
  	if err != nil {
  		return err
  	}

  	if err = rewriteFn(fileSet, tree); err != nil {
  		return err
  	}

  	// Assemble header + printed AST in memory, then normalize the formatting.
  	var out bytes.Buffer
  	out.Write(header)
  	if err = printer.Fprint(&out, fileSet, tree); err != nil {
  		return err
  	}
  	formatted, err := format.Source(out.Bytes())
  	if err != nil {
  		return err
  	}

  	// The file is opened without O_CREATE: it must still exist at this point.
  	dst, err := os.OpenFile(name, os.O_WRONLY|os.O_TRUNC, 0644)
  	if err != nil {
  		return err
  	}
  	// The deferred Close covers the early return on a failed write; the explicit
  	// Close below reports a close error on the success path.
  	defer dst.Close()
  	if _, err = dst.Write(formatted); err != nil {
  		return err
  	}
  	return dst.Close()
  }
  // ExtractFunc is a callback invoked with a type declaration. It may record
  // whatever information it needs from spec, and it reports whether that type
  // should be removed from the destination file.
  type ExtractFunc func(spec *ast.TypeSpec) bool
  // OptionalFunc reports whether the given local type name refers to a type
  // marked protobuf.nullable=true, i.e. one whose marshal functions should be
  // adjusted to remove the 'Items' accessor.
  type OptionalFunc func(typeName string) bool
  67. func RewriteGeneratedGogoProtobufFile(name string, extractFn ExtractFunc, optionalFn OptionalFunc, header []byte) error {
  68. return rewriteFile(name, header, func(fset *token.FileSet, file *ast.File) error {
  69. cmap := ast.NewCommentMap(fset, file, file.Comments)
  70. // transform methods that point to optional maps or slices
  71. for _, d := range file.Decls {
  72. rewriteOptionalMethods(d, optionalFn)
  73. }
  74. // remove types that are already declared
  75. decls := []ast.Decl{}
  76. for _, d := range file.Decls {
  77. if dropExistingTypeDeclarations(d, extractFn) {
  78. continue
  79. }
  80. if dropEmptyImportDeclarations(d) {
  81. continue
  82. }
  83. decls = append(decls, d)
  84. }
  85. file.Decls = decls
  86. // remove unmapped comments
  87. file.Comments = cmap.Filter(file).Comments()
  88. return nil
  89. })
  90. }
  91. // rewriteOptionalMethods makes specific mutations to marshaller methods that belong to types identified
  92. // as being "optional" (they may be nil on the wire). This allows protobuf to serialize a map or slice and
  93. // properly discriminate between empty and nil (which is not possible in protobuf).
  94. // TODO: move into upstream gogo-protobuf once https://github.com/gogo/protobuf/issues/181
  95. // has agreement
  96. func rewriteOptionalMethods(decl ast.Decl, isOptional OptionalFunc) {
  97. switch t := decl.(type) {
  98. case *ast.FuncDecl:
  99. ident, ptr, ok := receiver(t)
  100. if !ok {
  101. return
  102. }
  103. // correct initialization of the form `m.Field = &OptionalType{}` to
  104. // `m.Field = OptionalType{}`
  105. if t.Name.Name == "Unmarshal" {
  106. ast.Walk(optionalAssignmentVisitor{fn: isOptional}, t.Body)
  107. }
  108. if !isOptional(ident.Name) {
  109. return
  110. }
  111. switch t.Name.Name {
  112. case "Unmarshal":
  113. ast.Walk(&optionalItemsVisitor{}, t.Body)
  114. case "MarshalTo", "Size", "String":
  115. ast.Walk(&optionalItemsVisitor{}, t.Body)
  116. fallthrough
  117. case "Marshal":
  118. // if the method has a pointer receiver, set it back to a normal receiver
  119. if ptr {
  120. t.Recv.List[0].Type = ident
  121. }
  122. }
  123. }
  124. }
  125. type optionalAssignmentVisitor struct {
  126. fn OptionalFunc
  127. }
  128. // Visit walks the provided node, transforming field initializations of the form
  129. // m.Field = &OptionalType{} -> m.Field = OptionalType{}
  130. func (v optionalAssignmentVisitor) Visit(n ast.Node) ast.Visitor {
  131. switch t := n.(type) {
  132. case *ast.AssignStmt:
  133. if len(t.Lhs) == 1 && len(t.Rhs) == 1 {
  134. if !isFieldSelector(t.Lhs[0], "m", "") {
  135. return nil
  136. }
  137. unary, ok := t.Rhs[0].(*ast.UnaryExpr)
  138. if !ok || unary.Op != token.AND {
  139. return nil
  140. }
  141. composite, ok := unary.X.(*ast.CompositeLit)
  142. if !ok || composite.Type == nil || len(composite.Elts) != 0 {
  143. return nil
  144. }
  145. if ident, ok := composite.Type.(*ast.Ident); ok && v.fn(ident.Name) {
  146. t.Rhs[0] = composite
  147. }
  148. }
  149. return nil
  150. }
  151. return v
  152. }
  153. type optionalItemsVisitor struct{}
  154. // Visit walks the provided node, looking for specific patterns to transform that match
  155. // the effective outcome of turning struct{ map[x]y || []x } into map[x]y or []x.
  156. func (v *optionalItemsVisitor) Visit(n ast.Node) ast.Visitor {
  157. switch t := n.(type) {
  158. case *ast.RangeStmt:
  159. if isFieldSelector(t.X, "m", "Items") {
  160. t.X = &ast.Ident{Name: "m"}
  161. }
  162. case *ast.AssignStmt:
  163. if len(t.Lhs) == 1 && len(t.Rhs) == 1 {
  164. switch lhs := t.Lhs[0].(type) {
  165. case *ast.IndexExpr:
  166. if isFieldSelector(lhs.X, "m", "Items") {
  167. lhs.X = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
  168. }
  169. default:
  170. if isFieldSelector(t.Lhs[0], "m", "Items") {
  171. t.Lhs[0] = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
  172. }
  173. }
  174. switch rhs := t.Rhs[0].(type) {
  175. case *ast.CallExpr:
  176. if ident, ok := rhs.Fun.(*ast.Ident); ok && ident.Name == "append" {
  177. ast.Walk(v, rhs)
  178. if len(rhs.Args) > 0 {
  179. switch arg := rhs.Args[0].(type) {
  180. case *ast.Ident:
  181. if arg.Name == "m" {
  182. rhs.Args[0] = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
  183. }
  184. }
  185. }
  186. return nil
  187. }
  188. }
  189. }
  190. case *ast.IfStmt:
  191. switch cond := t.Cond.(type) {
  192. case *ast.BinaryExpr:
  193. if cond.Op == token.EQL {
  194. if isFieldSelector(cond.X, "m", "Items") && isIdent(cond.Y, "nil") {
  195. cond.X = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
  196. }
  197. }
  198. }
  199. if t.Init != nil {
  200. // Find form:
  201. // if err := m[len(m.Items)-1].Unmarshal(data[iNdEx:postIndex]); err != nil {
  202. // return err
  203. // }
  204. switch s := t.Init.(type) {
  205. case *ast.AssignStmt:
  206. if call, ok := s.Rhs[0].(*ast.CallExpr); ok {
  207. if sel, ok := call.Fun.(*ast.SelectorExpr); ok {
  208. if x, ok := sel.X.(*ast.IndexExpr); ok {
  209. // m[] -> (*m)[]
  210. if sel2, ok := x.X.(*ast.SelectorExpr); ok {
  211. if ident, ok := sel2.X.(*ast.Ident); ok && ident.Name == "m" {
  212. x.X = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
  213. }
  214. }
  215. // len(m.Items) -> len(*m)
  216. if bin, ok := x.Index.(*ast.BinaryExpr); ok {
  217. if call2, ok := bin.X.(*ast.CallExpr); ok && len(call2.Args) == 1 {
  218. if isFieldSelector(call2.Args[0], "m", "Items") {
  219. call2.Args[0] = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
  220. }
  221. }
  222. }
  223. }
  224. }
  225. }
  226. }
  227. }
  228. case *ast.IndexExpr:
  229. if isFieldSelector(t.X, "m", "Items") {
  230. t.X = &ast.Ident{Name: "m"}
  231. return nil
  232. }
  233. case *ast.CallExpr:
  234. changed := false
  235. for i := range t.Args {
  236. if isFieldSelector(t.Args[i], "m", "Items") {
  237. t.Args[i] = &ast.Ident{Name: "m"}
  238. changed = true
  239. }
  240. }
  241. if changed {
  242. return nil
  243. }
  244. }
  245. return v
  246. }
  247. func isFieldSelector(n ast.Expr, name, field string) bool {
  248. s, ok := n.(*ast.SelectorExpr)
  249. if !ok || s.Sel == nil || (field != "" && s.Sel.Name != field) {
  250. return false
  251. }
  252. return isIdent(s.X, name)
  253. }
  // isIdent reports whether n is an identifier whose name equals value.
  func isIdent(n ast.Expr, value string) bool {
  	if id, ok := n.(*ast.Ident); ok {
  		return id.Name == value
  	}
  	return false
  }
  // receiver returns the identifier naming the receiver type of f and whether
  // the receiver is declared as a pointer to that type. ok is false when f has
  // no receiver, has more than one receiver field, or when the receiver type is
  // neither a plain identifier nor a pointer to one.
  func receiver(f *ast.FuncDecl) (ident *ast.Ident, pointer bool, ok bool) {
  	if f.Recv == nil || len(f.Recv.List) != 1 {
  		return nil, false, false
  	}

  	// Peel off a single leading `*`, remembering that it was there.
  	recvType := f.Recv.List[0].Type
  	if star, isStar := recvType.(*ast.StarExpr); isStar {
  		recvType, pointer = star.X, true
  	}

  	typeName, isName := recvType.(*ast.Ident)
  	if !isName {
  		return nil, false, false
  	}
  	return typeName, pointer, true
  }
  274. // dropExistingTypeDeclarations removes any type declaration for which extractFn returns true. The function
  275. // returns true if the entire declaration should be dropped.
  276. func dropExistingTypeDeclarations(decl ast.Decl, extractFn ExtractFunc) bool {
  277. switch t := decl.(type) {
  278. case *ast.GenDecl:
  279. if t.Tok != token.TYPE {
  280. return false
  281. }
  282. specs := []ast.Spec{}
  283. for _, s := range t.Specs {
  284. switch spec := s.(type) {
  285. case *ast.TypeSpec:
  286. if extractFn(spec) {
  287. continue
  288. }
  289. specs = append(specs, spec)
  290. }
  291. }
  292. if len(specs) == 0 {
  293. return true
  294. }
  295. t.Specs = specs
  296. }
  297. return false
  298. }
  // dropEmptyImportDeclarations removes blank (`_`) imports from an import
  // declaration so that generated code cannot introduce import side-effects.
  // It returns true when no import spec remains, meaning the whole declaration
  // should be dropped; in that case decl itself is left unmodified.
  // Declarations that are not import declarations are ignored and yield false.
  func dropEmptyImportDeclarations(decl ast.Decl) bool {
  	gen, isGen := decl.(*ast.GenDecl)
  	if !isGen || gen.Tok != token.IMPORT {
  		return false
  	}

  	remaining := make([]ast.Spec, 0, len(gen.Specs))
  	for _, s := range gen.Specs {
  		imp, isImport := s.(*ast.ImportSpec)
  		if !isImport {
  			continue
  		}
  		if imp.Name != nil && imp.Name.Name == "_" {
  			continue
  		}
  		remaining = append(remaining, imp)
  	}

  	if len(remaining) == 0 {
  		return true
  	}
  	gen.Specs = remaining
  	return false
  }
  325. func RewriteTypesWithProtobufStructTags(name string, structTags map[string]map[string]string) error {
  326. return rewriteFile(name, []byte{}, func(fset *token.FileSet, file *ast.File) error {
  327. allErrs := []error{}
  328. // set any new struct tags
  329. for _, d := range file.Decls {
  330. if errs := updateStructTags(d, structTags, []string{"protobuf"}); len(errs) > 0 {
  331. allErrs = append(allErrs, errs...)
  332. }
  333. }
  334. if len(allErrs) > 0 {
  335. var s string
  336. for _, err := range allErrs {
  337. s += err.Error() + "\n"
  338. }
  339. return errors.New(s)
  340. }
  341. return nil
  342. })
  343. }
  344. func updateStructTags(decl ast.Decl, structTags map[string]map[string]string, toCopy []string) []error {
  345. var errs []error
  346. t, ok := decl.(*ast.GenDecl)
  347. if !ok {
  348. return nil
  349. }
  350. if t.Tok != token.TYPE {
  351. return nil
  352. }
  353. for _, s := range t.Specs {
  354. spec, ok := s.(*ast.TypeSpec)
  355. if !ok {
  356. continue
  357. }
  358. typeName := spec.Name.Name
  359. fieldTags, ok := structTags[typeName]
  360. if !ok {
  361. continue
  362. }
  363. st, ok := spec.Type.(*ast.StructType)
  364. if !ok {
  365. continue
  366. }
  367. for i := range st.Fields.List {
  368. f := st.Fields.List[i]
  369. var name string
  370. if len(f.Names) == 0 {
  371. switch t := f.Type.(type) {
  372. case *ast.Ident:
  373. name = t.Name
  374. case *ast.SelectorExpr:
  375. name = t.Sel.Name
  376. default:
  377. errs = append(errs, fmt.Errorf("unable to get name for tag from struct %q, field %#v", spec.Name.Name, t))
  378. continue
  379. }
  380. } else {
  381. name = f.Names[0].Name
  382. }
  383. value, ok := fieldTags[name]
  384. if !ok {
  385. continue
  386. }
  387. var tags customreflect.StructTags
  388. if f.Tag != nil {
  389. oldTags, err := customreflect.ParseStructTags(strings.Trim(f.Tag.Value, "`"))
  390. if err != nil {
  391. errs = append(errs, fmt.Errorf("unable to read struct tag from struct %q, field %q: %v", spec.Name.Name, name, err))
  392. continue
  393. }
  394. tags = oldTags
  395. }
  396. for _, name := range toCopy {
  397. // don't overwrite existing tags
  398. if tags.Has(name) {
  399. continue
  400. }
  401. // append new tags
  402. if v := reflect.StructTag(value).Get(name); len(v) > 0 {
  403. tags = append(tags, customreflect.StructTag{Name: name, Value: v})
  404. }
  405. }
  406. if len(tags) == 0 {
  407. continue
  408. }
  409. if f.Tag == nil {
  410. f.Tag = &ast.BasicLit{}
  411. }
  412. f.Tag.Value = tags.String()
  413. }
  414. }
  415. return errs
  416. }