PageRenderTime 29ms CodeModel.GetById 23ms RepoModel.GetById 0ms app.codeStats 0ms

/exp/cmd/errfix/errgo.go

https://code.google.com/p/rog-go/
Go | 361 lines | 308 code | 25 blank | 28 comment | 74 complexity | 1118a4c063c5cbac7a08c2c03c0673ec MD5 | raw file
  1. package main
  2. import (
  3. "fmt"
  4. "go/ast"
  5. "go/token"
  6. "log"
  7. "path"
  8. "strconv"
  9. "strings"
  10. )
  11. func init() {
  12. register(causeFix)
  13. register(maskFix)
  14. register(newFix)
  15. }
  16. const errgoPkgPath = "github.com/juju/errgo"
  17. var maskFix = fix{
  18. "errgo-mask",
  19. "2014-02-10",
  20. errgoMask,
  21. `wrap all returned errors; use errgo for all error creation functions
  22. `,
  23. }
  24. var causeFix = fix{
  25. "errgo-cause",
  26. "2014-02-14",
  27. errgoCause,
  28. `use Cause when comparing errors
  29. `,
  30. }
  31. var newFix = fix{
  32. "errgo-new",
  33. "2014-03-03",
  34. errgoNew,
  35. `use errgo.Newf instead of fmt.Errorf, and errgo.New instead of errors.New
  36. `,
  37. }
  38. type errgoFixContext struct {
  39. pathToIdent map[string]string
  40. gocheckIdent string
  41. importsOldErrgo bool
  42. }
  43. const errgoIdent = "errgo"
  44. func newErrgoFixContext(f *ast.File) *errgoFixContext {
  45. ctxt := &errgoFixContext{
  46. pathToIdent: importPathToIdentMap(f),
  47. }
  48. ctxt.gocheckIdent = ctxt.pathToIdent["launchpad.net/gocheck"]
  49. // If we import from any */errors package path,
  50. // import as errgo to save name clashes.
  51. for _, imp := range f.Imports {
  52. if importPath(imp) == "github.com/errgo/errgo" {
  53. ctxt.importsOldErrgo = true
  54. }
  55. }
  56. return ctxt
  57. }
  58. func importPathToIdentMap(f *ast.File) map[string]string {
  59. m := make(map[string]string)
  60. for _, imp := range f.Imports {
  61. ipath := importPath(imp)
  62. if imp.Name != nil {
  63. m[ipath] = imp.Name.Name
  64. } else {
  65. _, name := path.Split(ipath)
  66. m[ipath] = name
  67. }
  68. }
  69. return m
  70. }
  71. func errgoNew(f *ast.File) bool {
  72. ctxt := newErrgoFixContext(f)
  73. fixed := false
  74. walk(f, func(n interface{}) {
  75. warning := func(format string, arg ...interface{}) {
  76. pos := fset.Position(n.(ast.Node).Pos())
  77. log.Printf("warning: %s: %s", pos, fmt.Sprintf(format, arg...))
  78. }
  79. switch n := n.(type) {
  80. case *ast.CallExpr:
  81. switch {
  82. case isPkgDot(n.Fun, "fmt", "Errorf"):
  83. if len(n.Args) == 0 {
  84. warning("Errorf with no args")
  85. break
  86. }
  87. lit, ok := n.Args[0].(*ast.BasicLit)
  88. if !ok {
  89. warning("Errorf with non-constant first arg")
  90. break
  91. }
  92. if lit.Kind != token.STRING {
  93. warning("Errorf with non-string literal first arg")
  94. break
  95. }
  96. format, err := strconv.Unquote(lit.Value)
  97. if err != nil {
  98. warning("Errorf with invalid quoted string literal: %v", err)
  99. break
  100. }
  101. if !strings.HasSuffix(format, ": %v") || len(n.Args) < 2 || !isName(n.Args[len(n.Args)-1], "err") {
  102. // fmt.Errorf("foo %s", x) ->
  103. // errgo.Newf("foo %s", x)
  104. n.Fun = &ast.SelectorExpr{
  105. X: ast.NewIdent(errgoIdent),
  106. Sel: ast.NewIdent("Newf"),
  107. }
  108. fixed = true
  109. break
  110. }
  111. // fmt.Errorf("format: %v", args..., err) ->
  112. // errgo.Maskf(err, "format", args...)
  113. newArgs := []ast.Expr{
  114. n.Args[len(n.Args)-1],
  115. &ast.BasicLit{
  116. Kind: token.STRING,
  117. Value: fmt.Sprintf("%q", strings.TrimSuffix(format, ": %v")),
  118. },
  119. }
  120. newArgs = append(newArgs, n.Args[1:len(n.Args)-1]...)
  121. n.Args = newArgs
  122. n.Fun = &ast.SelectorExpr{
  123. X: ast.NewIdent(errgoIdent),
  124. Sel: ast.NewIdent("Notef"),
  125. }
  126. fixed = true
  127. case ctxt.importsOldErrgo && isPkgDot(n.Fun, "errgo", "Annotate"):
  128. n.Fun = &ast.SelectorExpr{
  129. X: ast.NewIdent(errgoIdent),
  130. Sel: ast.NewIdent("NoteMask"),
  131. }
  132. fixed = true
  133. case ctxt.importsOldErrgo && isPkgDot(n.Fun, "errgo", "Annotatef"):
  134. n.Fun = &ast.SelectorExpr{
  135. X: ast.NewIdent(errgoIdent),
  136. Sel: ast.NewIdent("Notef"),
  137. }
  138. fixed = true
  139. case ctxt.importsOldErrgo && isPkgDot(n.Fun, "errgo", "New"):
  140. n.Fun = &ast.SelectorExpr{
  141. X: ast.NewIdent(errgoIdent),
  142. Sel: ast.NewIdent("Newf"),
  143. }
  144. fixed = true
  145. case isPkgDot(n.Fun, ctxt.pathToIdent["errors"], "New"):
  146. n.Fun = &ast.SelectorExpr{
  147. X: ast.NewIdent(errgoIdent),
  148. Sel: ast.NewIdent("New"),
  149. }
  150. fixed = true
  151. }
  152. }
  153. })
  154. fixed = deleteImport(f, "github.com/errgo/errgo") || fixed
  155. fixed = rewriteImports(ctxt, f, fixed) || fixed
  156. return fixed
  157. }
  158. func errgoMask(f *ast.File) bool {
  159. ctxt := newErrgoFixContext(f)
  160. fixed := false
  161. walk(f, func(n interface{}) {
  162. switch n := n.(type) {
  163. case *ast.IfStmt:
  164. if ok := fixIfErrNotEqualNil(n); ok {
  165. fixed = true
  166. break
  167. }
  168. }
  169. })
  170. fixed = deleteImport(f, "github.com/errgo/errgo") || fixed
  171. fixed = rewriteImports(ctxt, f, fixed) || fixed
  172. return fixed
  173. }
  174. func errgoCause(f *ast.File) bool {
  175. ctxt := newErrgoFixContext(f)
  176. fixed := false
  177. walk(f, func(n interface{}) {
  178. switch n := n.(type) {
  179. case *ast.IfStmt:
  180. if ok := fixIfErrEqualSomething(n, errgoIdent); ok {
  181. fixed = true
  182. break
  183. }
  184. case *ast.TypeAssertExpr:
  185. if isName(n.X, "err") {
  186. n.X = causeExpr(errgoIdent, "err")
  187. fixed = true
  188. }
  189. case *ast.CallExpr:
  190. fixed = fixGocheck(n, errgoIdent, ctxt.gocheckIdent) || fixed
  191. }
  192. })
  193. if fixed {
  194. rewriteImports(ctxt, f, fixed)
  195. }
  196. return fixed
  197. }
  198. func rewriteImports(ctxt *errgoFixContext, f *ast.File, usingErrgo bool) bool {
  199. // If there was already an "errors" import, then we can
  200. // rewrite it to use errgo
  201. fixed := false
  202. if ctxt.pathToIdent["errors"] != "" {
  203. // We've already imported the errors package;
  204. // change it to refer to errgo.
  205. for _, imp := range f.Imports {
  206. if importPath(imp) == "errors" {
  207. fixed = true
  208. imp.Name = nil
  209. imp.EndPos = imp.End()
  210. imp.Path.Value = strconv.Quote(errgoPkgPath)
  211. }
  212. }
  213. } else if usingErrgo {
  214. fixed = addImport(f, errgoPkgPath, errgoIdent, false)
  215. }
  216. return fixed
  217. }
  218. func fixIfErrNotEqualNil(n *ast.IfStmt) bool {
  219. // if stmt; err != nil {
  220. // return [..., ]err
  221. // }
  222. // ->
  223. // if stmt; err != nil {
  224. // return [..., ]errgo.Mask(err)
  225. // }
  226. cond, ok := n.Cond.(*ast.BinaryExpr)
  227. if !ok {
  228. return false
  229. }
  230. if !isName(cond.X, "err") {
  231. return false
  232. }
  233. if !isName(cond.Y, "nil") {
  234. // comparison of errors against anything
  235. // other than nil - use errgo.Cause.
  236. }
  237. if cond.Op != token.NEQ {
  238. return false
  239. }
  240. if len(n.Body.List) != 1 {
  241. return false
  242. }
  243. returnStmt, ok := n.Body.List[0].(*ast.ReturnStmt)
  244. if !ok {
  245. return false
  246. }
  247. if len(returnStmt.Results) == 0 {
  248. return false
  249. }
  250. lastResult := &returnStmt.Results[len(returnStmt.Results)-1]
  251. if !isName(*lastResult, "err") {
  252. return false
  253. }
  254. *lastResult = &ast.CallExpr{
  255. Fun: &ast.SelectorExpr{
  256. X: ast.NewIdent(errgoIdent),
  257. Sel: ast.NewIdent("Mask"),
  258. },
  259. Args: []ast.Expr{ast.NewIdent("err")},
  260. }
  261. return true
  262. }
  263. func fixIfErrEqualSomething(n *ast.IfStmt, errgoIdent string) bool {
  264. // if stmt; err == something-but-not-nil
  265. // ->
  266. // if stmt; errgo.Cause(err) == something-but-not-nil
  267. cond, ok := n.Cond.(*ast.BinaryExpr)
  268. if !ok {
  269. return false
  270. }
  271. if !isName(cond.X, "err") {
  272. return false
  273. }
  274. if cond.Op != token.EQL {
  275. return false
  276. }
  277. if isName(cond.Y, "nil") {
  278. return false
  279. }
  280. cond.X = causeExpr(errgoIdent, "err")
  281. return true
  282. }
  283. func fixGocheck(n *ast.CallExpr, errgoIdent, gocheckIdent string) bool {
  284. // gc.Check(err, gc.Equals, foo-not-nil)
  285. // ->
  286. // gc.Check(errgo.Cause(err), gc.Equals, foo-not-nil)
  287. // gc.Check(err, gc.Not(gc.Equals), foo-not-nil)
  288. // ->
  289. // gc.Check(errgo.Cause(err), gc.Not(gc.Equals), foo-not-nil)
  290. if gocheckIdent == "" {
  291. return false
  292. }
  293. sel, ok := n.Fun.(*ast.SelectorExpr)
  294. if !ok {
  295. return false
  296. }
  297. if !isName(sel.X, "c") {
  298. return false
  299. }
  300. if s := sel.Sel.String(); s != "Check" && s != "Assert" {
  301. return false
  302. }
  303. if len(n.Args) < 3 {
  304. return false
  305. }
  306. if !isName(n.Args[0], "err") {
  307. return false
  308. }
  309. if condCall, ok := n.Args[1].(*ast.CallExpr); ok {
  310. if !isPkgDot(condCall.Fun, gocheckIdent, "Not") {
  311. return false
  312. }
  313. if len(condCall.Args) != 1 {
  314. return false
  315. }
  316. if !isPkgDot(condCall.Args[0], gocheckIdent, "Equals") {
  317. return false
  318. }
  319. } else if !isPkgDot(n.Args[1], gocheckIdent, "Equals") {
  320. return false
  321. }
  322. if isName(n.Args[2], "nil") {
  323. return false
  324. }
  325. n.Args[0] = causeExpr(errgoIdent, "err")
  326. return true
  327. }
  // causeExpr builds the call expression errgoIdent.Cause(ident), for
  // example errgo.Cause(err). Each call returns freshly allocated nodes.
  func causeExpr(errgoIdent string, ident string) *ast.CallExpr {
  	cause := &ast.SelectorExpr{X: ast.NewIdent(errgoIdent), Sel: ast.NewIdent("Cause")}
  	args := []ast.Expr{ast.NewIdent(ident)}
  	return &ast.CallExpr{Fun: cause, Args: args}
  }