-- /test/lunit.lua

-- http://github.com/u1tnk/kah.lua · Lua · 670 lines · 576 code · 61 blank · 33 comment · 75 complexity · d36b2a4021fdc2da41b1b438a43c122b MD5 · raw file

  --[[--------------------------------------------------------------------------
  This file is part of lunit 0.5.
  For Details about lunit look at: http://www.mroth.net/lunit/
  Author: Michael Roth <mroth@nessie.de>
  Copyright (c) 2004, 2006-2009 Michael Roth <mroth@nessie.de>
  Permission is hereby granted, free of charge, to any person
  obtaining a copy of this software and associated documentation
  files (the "Software"), to deal in the Software without restriction,
  including without limitation the rights to use, copy, modify, merge,
  publish, distribute, sublicense, and/or sell copies of the Software,
  and to permit persons to whom the Software is furnished to do so,
  subject to the following conditions:
  The above copyright notice and this permission notice shall be
  included in all copies or substantial portions of the Software.
  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
  EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
  MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
  IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
  CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
  TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
  SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
  --]]--------------------------------------------------------------------------
  -- Capture the standard functions lunit needs *before* module() swaps the
  -- chunk environment. lunit later defines its own global `assert`, so the
  -- builtin one is kept here under the name `orig_assert`.
  local orig_assert, pairs, ipairs, next = assert, pairs, ipairs, next
  local type, error, tostring = type, error, tostring
  local string_sub, string_format = string.sub, string.format

  module("lunit", package.seeall) -- FIXME: Remove package.seeall

  -- The module table itself; "global" definitions below are stored in it.
  local lunit = _M

  -- Unique sentinel put into the `type` field of assertion-failure objects,
  -- so failures can be told apart from ordinary runtime errors.
  local __failure__ = {}

  -- Lua type names; is_<type>, assert_<type> and assert_not_<type>
  -- functions are generated from this list.
  local typenames = {
    "nil", "boolean", "number", "string",
    "table", "function", "thread", "userdata",
  }

  -- Forward declarations; both are assigned in the traceback do-block below.
  local traceback_hide -- marks a function as hidden from error tracebacks
  local mypcall        -- protected call returning an error object on failure
  -- Traceback support: assigns the forward-declared traceback_hide() and
  -- mypcall() locals.
  do
    -- Set of functions whose frames are omitted from error tracebacks.
    -- Weak keys ("k" mode), so registering a function does not keep it alive.
    local _tb_hide = setmetatable( {}, {__mode="k"} )

    -- Register `func` so that my_traceback skips its stack frames.
    function traceback_hide(func)
      _tb_hide[func] = true
    end

    -- Message handler passed to xpcall by mypcall. Returns the error object
    -- that mypcall hands back to its caller:
    --  * lunit failure objects (errobj.type == __failure__) are returned with
    --    a `where` field ("short_src:currentline") added;
    --  * any other error value is replaced by a new table
    --    { msg = tostring(errobj), tb = { <traceback line>, ... } }.
    -- is_table here resolves through the module environment to the generated
    -- lunit.is_table (the local alias of that name is declared further down).
    local function my_traceback(errobj)
      if is_table(errobj) and errobj.type == __failure__ then
        -- NOTE(review): level 5 appears to assume the stack
        --   1 my_traceback, 2 error, 3 failure, 4 assert_*/fail, 5 caller,
        -- so `where` points at the test line that called the assertion.
        -- Calling an assertion through an extra helper reports the helper's
        -- line instead, and a shallower stack would leave `info` nil.
        local info = debug.getinfo(5, "Sl") -- FIXME: Hardcoded integers are bad...
        errobj.where = string_format( "%s:%d", info.short_src, info.currentline)
      else
        errobj = { msg = tostring(errobj) }
        errobj.tb = {}
        -- Level 1 is my_traceback itself; start at its caller.
        local i = 2
        while true do
          local info = debug.getinfo(i, "Snlf")
          if not is_table(info) then
            break -- walked past the top of the stack
          end
          if not _tb_hide[info.func] then
            -- Build one traceback line, "src:line: in function 'name'" or a
            -- variant, following the format in ldblib.c.
            local line = {} -- Ripped from ldblib.c...
            line[#line+1] = string_format("%s:", info.short_src)
            if info.currentline > 0 then
              line[#line+1] = string_format("%d:", info.currentline)
            end
            if info.namewhat ~= "" then
              line[#line+1] = string_format(" in function '%s'", info.name)
            else
              if info.what == "main" then
                line[#line+1] = " in main chunk"
              elseif info.what == "C" or info.what == "tail" then
                -- C functions and (Lua 5.1) tail-call frames have no usable location.
                line[#line+1] = " ?"
              else
                line[#line+1] = string_format(" in function <%s:%d>", info.short_src, info.linedefined)
              end
            end
            errobj.tb[#errobj.tb+1] = table.concat(line)
          end
          i = i + 1
        end
      end
      return errobj
    end

    -- Call `func` (which must be a function) with no arguments under xpcall.
    -- Returns nil on success, or the error object built by my_traceback.
    function mypcall(func)
      orig_assert( is_function(func) )
      local ok, errobj = xpcall(func, my_traceback)
      if not ok then
        return errobj
      end
    end
    traceback_hide(mypcall)
  end
  -- Type predicates lunit.is_nil, lunit.is_boolean, ..., lunit.is_userdata:
  -- each returns true exactly when type(x) equals its type name.
  for index = 1, #typenames do
    local wanted = typenames[index]
    lunit["is_" .. wanted] = function(x)
      return type(x) == wanted
    end
  end

  -- File-local aliases of the generated predicates. The right-hand side is
  -- evaluated before the new locals exist, so it reads the module functions.
  local is_nil, is_boolean, is_number, is_string =
        is_nil, is_boolean, is_number, is_string
  local is_table, is_function, is_thread, is_userdata =
        is_table, is_function, is_thread, is_userdata
  -- Raise an lunit assertion failure.
  --   name       : name of the failing assertion (e.g. "assert_equal")
  --   usermsg    : optional message supplied by the test author
  --   defaultmsg : string.format template describing the failure
  --   ...        : arguments for that template
  -- The failure table is raised with level 0 (no position prefix). Its
  -- `where` field is expected to be filled in by the xpcall message handler.
  local function failure(name, usermsg, defaultmsg, ...)
    error({
      type    = __failure__,
      name    = name,
      msg     = string_format(defaultmsg, ...),
      usermsg = usermsg,
    }, 0)
  end
  traceback_hide( failure )
  -- Render a value for a failure message. Strings are wrapped in single
  -- quotes. Numbers, booleans and nil use tostring(). Any other value
  -- (table, function, ...) is tostring()'d inside square brackets.
  local function format_arg(arg)
    local kind = type(arg)
    if kind == "string" then
      return "'" .. arg .. "'"
    end
    if kind == "number" or kind == "boolean" or kind == "nil" then
      return tostring(arg)
    end
    return "[" .. tostring(arg) .. "]"
  end
  -- Count one assertion and fail the current test unconditionally.
  -- `msg` is an optional user message attached to the failure.
  function fail(msg)
    local counters = stats
    counters.assertions = counters.assertions + 1
    failure( "fail", msg, "failure" )
  end
  traceback_hide( fail )
  -- lunit's assert, which replaces the builtin inside the module (the builtin
  -- is still available as orig_assert). Counts one assertion and fails when
  -- `assertion` is nil or false. Otherwise it returns `assertion` unchanged.
  function assert(assertion, msg)
    stats.assertions = stats.assertions + 1
    if assertion then
      return assertion
    end
    failure( "assert", msg, "assertion failed" )
    return assertion
  end
  traceback_hide( assert )
  -- Fail unless `actual` is the boolean true. Truthy non-booleans such as 1
  -- or "yes" also fail. Returns `actual`.
  function assert_true(actual, msg)
    stats.assertions = stats.assertions + 1
    local kind = type(actual)
    if kind ~= "boolean" then
      failure( "assert_true", msg, "true expected but was a "..kind )
    elseif not actual then
      failure( "assert_true", msg, "true expected but was false" )
    end
    return actual
  end
  traceback_hide( assert_true )
  -- Fail unless `actual` is the boolean false. nil and other falsy-looking
  -- values fail as well. Returns `actual`.
  function assert_false(actual, msg)
    stats.assertions = stats.assertions + 1
    local kind = type(actual)
    if kind ~= "boolean" then
      failure( "assert_false", msg, "false expected but was a "..kind )
    elseif actual then
      failure( "assert_false", msg, "false expected but was true" )
    end
    return actual
  end
  traceback_hide( assert_false )
  -- Fail unless `expected == actual` (Lua's == operator). Returns `actual`.
  function assert_equal(expected, actual, msg)
    stats.assertions = stats.assertions + 1
    if expected == actual then
      return actual
    end
    failure( "assert_equal", msg, "expected %s but was %s", format_arg(expected), format_arg(actual) )
    return actual
  end
  traceback_hide( assert_equal )
  -- Fail when `unexpected == actual` (Lua's == operator). Returns `actual`.
  function assert_not_equal(unexpected, actual, msg)
    stats.assertions = stats.assertions + 1
    if unexpected ~= actual then
      return actual
    end
    failure( "assert_not_equal", msg, "%s not expected but was one", format_arg(unexpected) )
    return actual
  end
  traceback_hide( assert_not_equal )
  -- Fail unless string `actual` contains a match for Lua pattern `pattern`
  -- (string.find semantics, so unanchored unless the pattern uses ^ or $).
  -- Both arguments must be strings. Returns `actual`.
  function assert_match(pattern, actual, msg)
    stats.assertions = stats.assertions + 1
    local pattern_kind, actual_kind = type(pattern), type(actual)
    if pattern_kind ~= "string" then
      failure( "assert_match", msg, "expected the pattern as a string but was a "..pattern_kind )
    end
    if actual_kind ~= "string" then
      failure( "assert_match", msg, "expected a string to match pattern '%s' but was a %s", pattern, actual_kind )
    end
    local matched = string.find(actual, pattern)
    if not matched then
      failure( "assert_match", msg, "expected '%s' to match pattern '%s' but doesn't", actual, pattern )
    end
    return actual
  end
  traceback_hide( assert_match )
  -- Fail when string `actual` contains a match for Lua pattern `pattern`
  -- (string.find semantics). Both arguments must be strings. Returns `actual`.
  function assert_not_match(pattern, actual, msg)
    stats.assertions = stats.assertions + 1
    local pattern_kind, actual_kind = type(pattern), type(actual)
    if pattern_kind ~= "string" then
      failure( "assert_not_match", msg, "expected the pattern as a string but was a "..pattern_kind )
    end
    if actual_kind ~= "string" then
      failure( "assert_not_match", msg, "expected a string to not match pattern '%s' but was a %s", pattern, actual_kind )
    end
    local matched = string.find(actual, pattern)
    if matched then
      failure( "assert_not_match", msg, "expected '%s' to not match pattern '%s' but it does", actual, pattern )
    end
    return actual
  end
  traceback_hide( assert_not_match )
  -- Fail unless calling `func` raises an error.
  -- Call forms: assert_error(func) or assert_error(msg, func).
  function assert_error(msg, func)
    stats.assertions = stats.assertions + 1
    if func == nil then
      -- One-argument form: the only argument is the function.
      msg, func = nil, msg
    end
    local func_kind = type(func)
    if func_kind ~= "function" then
      failure( "assert_error", msg, "expected a function as last argument but was a "..func_kind )
    end
    if pcall(func) then
      failure( "assert_error", msg, "error expected but no error occurred" )
    end
  end
  traceback_hide( assert_error )
  -- Fail unless calling `func` raises a *string* error that matches Lua
  -- pattern `pattern` (string.find semantics).
  -- Call forms: assert_error_match(pattern, func) or
  --             assert_error_match(msg, pattern, func).
  function assert_error_match(msg, pattern, func)
    stats.assertions = stats.assertions + 1
    if func == nil then
      -- Two-argument form: shift the arguments one slot to the right.
      msg, pattern, func = nil, msg, pattern
    end
    local pattern_kind = type(pattern)
    if pattern_kind ~= "string" then
      failure( "assert_error_match", msg, "expected the pattern as a string but was a "..pattern_kind )
    end
    local func_kind = type(func)
    if func_kind ~= "function" then
      failure( "assert_error_match", msg, "expected a function as last argument but was a "..func_kind )
    end
    local succeeded, raised = pcall(func)
    if succeeded then
      failure( "assert_error_match", msg, "error expected but no error occurred" )
    end
    local raised_kind = type(raised)
    if raised_kind ~= "string" then
      failure( "assert_error_match", msg, "error as string expected but was a "..raised_kind )
    end
    local matched = string.find(raised, pattern)
    if not matched then
      failure( "assert_error_match", msg, "expected error '%s' to match pattern '%s' but doesn't", raised, pattern )
    end
  end
  traceback_hide( assert_error_match )
  -- Fail if calling `func` raises an error.
  -- Call forms: assert_pass(func) or assert_pass(msg, func).
  -- The failure message includes the raised error. For a failure raised by a
  -- nested lunit assertion, it includes that failure's message.
  function assert_pass(msg, func)
    stats.assertions = stats.assertions + 1
    if func == nil then
      func, msg = msg, nil
    end
    local functype = type(func)
    if functype ~= "function" then
      failure( "assert_pass", msg, "expected a function as last argument but was a %s", functype )
    end
    local ok, errmsg = pcall(func)
    if not ok then
      local detail = errmsg
      if type(detail) == "table" and detail.type == __failure__ then
        -- A nested lunit assertion failed inside func; report its message.
        detail = detail.msg
      end
      -- tostring() is required: the error value may be a table or nil, and
      -- Lua 5.1's "%s" only accepts strings/numbers. Passing it directly
      -- would raise a string.format error instead of an assert_pass failure.
      failure( "assert_pass", msg, "no error expected but error was: '%s'", tostring(detail) )
    end
  end
  traceback_hide( assert_pass )
  -- Generate lunit.assert_<typename>(actual, msg) for every entry of
  -- typenames. Each fails unless type(actual) equals that name, and
  -- otherwise returns `actual`.
  for index = 1, #typenames do
    local typename = typenames[index]
    local fname = "assert_" .. typename
    local function checker(actual, msg)
      stats.assertions = stats.assertions + 1
      local actualtype = type(actual)
      if actualtype ~= typename then
        failure( fname, msg, typename.." expected but was a "..actualtype )
      end
      return actual
    end
    lunit[fname] = checker
    traceback_hide( checker )
  end
  -- Generate lunit.assert_not_<typename>(actual, msg) for every entry of
  -- typenames. Each fails when type(actual) equals that name and returns
  -- nothing.
  for index = 1, #typenames do
    local typename = typenames[index]
    local fname = "assert_not_" .. typename
    local function checker(actual, msg)
      stats.assertions = stats.assertions + 1
      if type(actual) ~= typename then
        return
      end
      failure( fname, msg, typename.." not expected but was one" )
    end
    lunit[fname] = checker
    traceback_hide( checker )
  end
  -- Replace the module-level `stats` table with a fresh one whose counters
  -- are all zero.
  function lunit.clearstats()
    stats = { assertions = 0, passed = 0, failed = 0, errors = 0 }
  end
  -- Runner plumbing. `report` and `reporterrobj` are file-local helpers
  -- assigned in the do-block below; the active runner is private to it.
  local report, reporterrobj
  do
    local testrunner -- currently installed runner table, or nil

    -- Install `newrunner` (a table, or nil to uninstall) as the test runner.
    -- Returns the previously installed runner. Raises an error for any other
    -- argument type.
    function lunit.setrunner(newrunner)
      if not ( is_table(newrunner) or is_nil(newrunner) ) then
        return error("lunit.setrunner: Invalid argument", 0)
      end
      local oldrunner = testrunner
      testrunner = newrunner
      return oldrunner
    end

    -- Load runner module `name` with require() and install it. Returns the
    -- previously installed runner. Raises an error if `name` is not a string
    -- or the module cannot be loaded.
    function lunit.loadrunner(name)
      if not is_string(name) then
        return error("lunit.loadrunner: Invalid argument", 0)
      end
      local ok, runner = pcall( require, name )
      if not ok then
        -- The error value is not guaranteed to be a string (a module may
        -- raise a table). tostring() keeps the concatenation from failing
        -- with an unrelated error.
        return error("lunit.loadrunner: Can't load test runner: "..tostring(runner), 0)
      end
      return setrunner(runner)
    end

    -- Forward `event` and its arguments to testrunner[event] if the runner
    -- defines it as a function. Errors raised by the callback are caught by
    -- pcall and discarded.
    function report(event, ...)
      local f = testrunner and testrunner[event]
      if is_function(f) then
        pcall(f, ...)
      end
    end

    -- Record a failed phase of test `tcname.testname`. `context` is
    -- "setup", "test" or "teardown". For setup/teardown, the fixture
    -- function's name is appended as ":<name>". An lunit assertion failure
    -- increments stats.failed and reports "fail". Any other error increments
    -- stats.errors and reports "err" with the traceback lines.
    function reporterrobj(context, tcname, testname, errobj)
      local fullname = tcname .. "." .. testname
      if context == "setup" then
        fullname = fullname .. ":" .. setupname(tcname, testname)
      elseif context == "teardown" then
        fullname = fullname .. ":" .. teardownname(tcname, testname)
      end
      if errobj.type == __failure__ then
        stats.failed = stats.failed + 1
        report("fail", fullname, errobj.where, errobj.msg, errobj.usermsg)
      else
        stats.errors = stats.errors + 1
        report("err", fullname, errobj.msg, errobj.tb)
      end
    end
  end
  -- Stateless iterator over a table's keys: returns only the key following
  -- `k` and drops the value that next() also produces.
  local function key_iter(t, k)
    local following = next(t, k)
    return following
  end
  -- Testcase registry. The file-local `testcase(name)` looks up a registered
  -- testcase module by name.
  local testcase
  do
    -- Registered testcase modules, keyed by their _NAME.
    local _testcases = {}

    -- Mark module `m` as a testcase, applied via module("xyz", lunit.testcase).
    -- Registers `m` under m._NAME and copies `lunit`, `fail`, and every
    -- lunit entry whose name starts with "assert" or "is_" into it.
    function lunit.testcase(m)
      orig_assert( is_table(m) )
      --orig_assert( m._M == m )
      orig_assert( is_string(m._NAME) )
      --orig_assert( is_string(m._PACKAGE) )
      _testcases[m._NAME] = m
      m.lunit = lunit
      m.fail = lunit.fail
      for name, fn in pairs(lunit) do
        local exported = string_sub(name, 1, 6) == "assert" or string_sub(name, 1, 3) == "is_"
        if exported then
          m[name] = fn
        end
      end
    end

    -- Iterate over the names of all registered testcases. Iteration runs
    -- over a snapshot, so testcases defined during the loop do not disturb it.
    function lunit.testcases()
      local snapshot = {}
      for name in pairs(_testcases) do
        snapshot[name] = true
      end
      return key_iter, snapshot, nil
    end

    -- Registered testcase module named `tcname`, or nil.
    function testcase(tcname)
      return _testcases[tcname]
    end
  end
  -- Name lookups inside a registered testcase.
  do
    -- Return the key of the function in testcase `tcname` whose lower-cased
    -- name equals `name` (compared as-is, so `name` must be lower-case),
    -- or nil if there is none.
    local function findfuncname(tcname, name)
      local tc = testcase(tcname)
      for key, value in pairs(tc) do
        if is_string(key) and is_function(value) then
          if string.lower(key) == name then
            return key
          end
        end
      end
    end

    -- Name of the setup function of `tcname` (case-insensitive), or nil.
    function lunit.setupname(tcname)
      return findfuncname(tcname, "setup")
    end

    -- Name of the teardown function of `tcname` (case-insensitive), or nil.
    function lunit.teardownname(tcname)
      return findfuncname(tcname, "teardown")
    end

    -- Iterate over the test names of testcase `tcname`: string-keyed
    -- functions whose lower-cased name starts or ends with "test". The names
    -- are collected first, so a test that creates new globals in the
    -- testcase cannot disturb the iteration.
    function lunit.tests(tcname)
      local found = {}
      for key, value in pairs(testcase(tcname)) do
        if is_string(key) and is_function(value) then
          local lname = string.lower(key)
          local prefixed = string.sub(lname, 1, 4) == "test"
          if prefixed or string.sub(lname, -4) == "test" then
            found[key] = true
          end
        end
      end
      return key_iter, found, nil
    end
  end
  -- Run test `testname` of testcase `tcname` in three phases: setup, test,
  -- teardown. A failing setup skips both the test and teardown. A failing
  -- test still runs teardown. Each failure is passed to reporterrobj. The
  -- test counts as passed, and "pass" is reported, only if every phase
  -- succeeds.
  function lunit.runtest(tcname, testname)
    orig_assert( is_string(tcname) )
    orig_assert( is_string(testname) )

    -- Run one phase under mypcall. A missing (nil) phase counts as success.
    local function callit(context, func)
      if not func then
        return true
      end
      local err = mypcall(func)
      if err then
        reporterrobj(context, tcname, testname, err)
        return false
      end
      return true
    end
    traceback_hide(callit)

    report("run", tcname, testname)

    local tc = testcase(tcname)
    local setup = tc[setupname(tcname)]
    local test = tc[testname]
    local teardown = tc[teardownname(tcname)]

    local setup_ok = callit( "setup", setup )
    local test_ok, teardown_ok = false, false
    if setup_ok then
      test_ok = callit( "test", test )
      teardown_ok = callit( "teardown", teardown )
    end

    if setup_ok and test_ok and teardown_ok then
      stats.passed = stats.passed + 1
      report("pass", tcname, testname)
    end
  end
  traceback_hide(runtest)
  -- Reset the statistics, then run every test of every registered testcase
  -- between the "begin" and "done" runner events. Returns the `stats` table.
  -- NOTE(review): main() calls run(testpatterns), but this function declares
  -- no parameters, so --test/-t patterns are currently ignored.
  function lunit.run()
    clearstats()
    report("begin")
    for tcname in lunit.testcases() do
      for name in lunit.tests(tcname) do
        runtest(tcname, name)
      end
    end
    report("done")
    return stats
  end
  traceback_hide(run)
  -- Reset the statistics and emit the "begin" and "done" runner events
  -- without running any test. Returns the fresh `stats` table.
  function lunit.loadonly()
    clearstats()
    for _, event in ipairs({ "begin", "done" }) do
      report(event)
    end
    return stats
  end
  -- Convert an lunit glob into an anchored Lua pattern: "*" matches any run
  -- of characters, "?" matches any single character, and every other Lua
  -- magic character is escaped so it matches literally.
  local lunitpat2luapat
  do
    local conv = {
      ["^"] = "%^", ["$"] = "%$",
      ["("] = "%(", [")"] = "%)",
      ["%"] = "%%", ["."] = "%.",
      ["["] = "%[", ["]"] = "%]",
      ["+"] = "%+", ["-"] = "%-",
      ["?"] = ".",  ["*"] = ".*",
    }
    function lunitpat2luapat(str)
      -- Characters missing from `conv` (e.g. "_") are left unchanged by gsub.
      local body = string.gsub(str, "%W", conv)
      return "^" .. body .. "$"
    end
  end
  -- True if `name` is a key of `map` whose value is exactly true, or if it
  -- matches (string.find) any Lua pattern in the sequence part of `map`.
  local function in_patternmap(map, name)
    local found = map[name] == true
    if not found then
      for _, pattern in ipairs(map) do
        if string.find(name, pattern) then
          found = true
          break
        end
      end
    end
    return found
  end
  -- Command-line entry point, called from the 'lunit' shell script.
  --
  -- argv is an array of arguments (nil is treated as {}):
  --   --loadonly         load the test files but do not run any tests
  --   --runner, -r NAME  runner module to load (default "lunit-console")
  --   --test, -t PAT     collect a test-name pattern (may be repeated)
  --   --                 every following argument is a test file
  --   anything else      a test file to load and execute
  -- NOTE(review): the collected patterns are passed to run(), but
  -- lunit.run() as defined in this file declares no parameters, so they
  -- appear to be ignored. Confirm before relying on --test.
  --
  -- Returns whatever loadonly() or run() returns. Raises an error, with no
  -- lunit-internal position prefix, for a missing option argument or a test
  -- file that cannot be loaded.
  function main(argv)
    argv = argv or {}
    -- FIXME: Error handling and error messages aren't nice.

    local function checkarg(optname, arg)
      if not is_string(arg) then
        return error("lunit.main: option "..optname..": argument missing.", 0)
      end
    end

    -- Compile and execute one test file. Errors use level 0, like checkarg:
    -- loadfile's message already names the offending file, so prefixing a
    -- position inside lunit itself would only add noise.
    local function loadtestcase(filename)
      if not is_string(filename) then
        return error("lunit.main: invalid argument", 0)
      end
      local chunk, err = loadfile(filename)
      if not chunk then
        return error(err, 0)
      end
      chunk()
    end

    local testpatterns = nil
    local doloadonly = false
    local runner = nil
    local i = 0
    while i < #argv do
      i = i + 1
      local arg = argv[i]
      if arg == "--loadonly" then
        doloadonly = true
      elseif arg == "--runner" or arg == "-r" then
        local optname = arg; i = i + 1; arg = argv[i]
        checkarg(optname, arg)
        runner = arg
      elseif arg == "--test" or arg == "-t" then
        local optname = arg; i = i + 1; arg = argv[i]
        checkarg(optname, arg)
        testpatterns = testpatterns or {}
        testpatterns[#testpatterns+1] = arg
      elseif arg == "--" then
        while i < #argv do
          i = i + 1; arg = argv[i]
          loadtestcase(arg)
        end
      else
        loadtestcase(arg)
      end
    end

    loadrunner(runner or "lunit-console")
    if doloadonly then
      return loadonly()
    else
      return run(testpatterns)
    end
  end
  -- Create the initial statistics table when the module is loaded.
  lunit.clearstats()