/test/lunit.lua
Lua | 670 lines | 576 code | 61 blank | 33 comment | 53 complexity | d36b2a4021fdc2da41b1b438a43c122b MD5 | raw file
--[[--------------------------------------------------------------------------

    This file is part of lunit 0.5.

    For Details about lunit look at: http://www.mroth.net/lunit/

    Author: Michael Roth <mroth@nessie.de>

    Copyright (c) 2004, 2006-2009 Michael Roth <mroth@nessie.de>

    Permission is hereby granted, free of charge, to any person
    obtaining a copy of this software and associated documentation
    files (the "Software"), to deal in the Software without restriction,
    including without limitation the rights to use, copy, modify, merge,
    publish, distribute, sublicense, and/or sell copies of the Software,
    and to permit persons to whom the Software is furnished to do so,
    subject to the following conditions:

    The above copyright notice and this permission notice shall be
    included in all copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
    MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
    IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
    CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
    TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
    SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

--]]--------------------------------------------------------------------------



-- Keep the builtin assert: this module defines its own `assert` below,
-- which counts assertions and raises lunit failures instead.
local orig_assert = assert

local pairs = pairs
local ipairs = ipairs
local next = next
local type = type
local error = error
local tostring = tostring

local string_sub = string.sub
local string_format = string.format


-- After this call, unqualified globals in this file (stats, fail, run, ...)
-- live in the `lunit` module table.
module("lunit", package.seeall) -- FIXME: Remove package.seeall

local lunit = _M

local __failure__ = {} -- Type tag for failed assertions

local typenames = { "nil", "boolean", "number", "string", "table", "function", "thread", "userdata" }



local traceback_hide -- Traceback function which hides lunit internals
local mypcall        -- Protected call to a function with own traceback
do
  -- Functions registered here are left out of error tracebacks.
  -- Weak keys, so registration does not keep a function alive.
  local _tb_hide = setmetatable( {}, {__mode="k"} )

  -- Marks `func` as an lunit internal that should not appear in tracebacks.
  function traceback_hide(func)
    _tb_hide[func] = true
  end

  -- xpcall message handler.
  -- * lunit failures: records the failing assertion's location in
  --   errobj.where and returns the same failure object.
  -- * any other error: returns { msg = tostring(err), tb = {lines...} },
  --   where tb is a traceback that skips functions registered with
  --   traceback_hide.
  local function my_traceback(errobj)
    if is_table(errobj) and errobj.type == __failure__ then
      -- Level of the test code that called the assert_* function.
      local info = debug.getinfo(5, "Sl") -- FIXME: Hardcoded integers are bad...
      -- getinfo returns nil when that level does not exist. Guard against it
      -- so the message handler does not raise a second error.
      if is_table(info) then
        errobj.where = string_format( "%s:%d", info.short_src, info.currentline)
      end
    else
      errobj = { msg = tostring(errobj) }
      errobj.tb = {}
      local i = 2
      while true do
        local info = debug.getinfo(i, "Snlf")
        if not is_table(info) then
          break
        end
        if not _tb_hide[info.func] then
          local line = {} -- Ripped from ldblib.c...
          line[#line+1] = string_format("%s:", info.short_src)
          if info.currentline > 0 then
            line[#line+1] = string_format("%d:", info.currentline)
          end
          if info.namewhat ~= "" then
            line[#line+1] = string_format(" in function '%s'", info.name)
          else
            if info.what == "main" then
              line[#line+1] = " in main chunk"
            elseif info.what == "C" or info.what == "tail" then
              line[#line+1] = " ?"
            else
              line[#line+1] = string_format(" in function <%s:%d>", info.short_src, info.linedefined)
            end
          end
          errobj.tb[#errobj.tb+1] = table.concat(line)
        end
        i = i + 1
      end
    end
    return errobj
  end

  -- Calls `func` with no arguments under xpcall.
  -- Returns nothing on success, or the error object built by my_traceback.
  function mypcall(func)
    orig_assert( is_function(func) )
    local ok, errobj = xpcall(func, my_traceback)
    if not ok then
      return errobj
    end
  end
  traceback_hide(mypcall)
end


-- Type check functions: lunit.is_nil, lunit.is_boolean, ..., lunit.is_userdata

for _, typename in ipairs(typenames) do
  lunit["is_"..typename] = function(x)
    return type(x) == typename
  end
end

local is_nil = is_nil
local is_boolean = is_boolean
local is_number = is_number
local is_string = is_string
local is_table = is_table
local is_function = is_function
local is_thread = is_thread
local is_userdata = is_userdata


-- Raises an lunit failure object.
-- `defaultmsg` is a string.format template filled with the varargs;
-- `usermsg` is the optional message supplied by the test author.
local function failure(name, usermsg, defaultmsg, ...)
  local errobj = {
    type = __failure__,
    name = name,
    msg = string_format(defaultmsg,...),
    usermsg = usermsg
  }
  error(errobj, 0)
end
traceback_hide( failure )


-- Renders a value for failure messages: strings are quoted; nil, numbers
-- and booleans are shown plainly; anything else is shown in brackets.
local function format_arg(arg)
  local argtype = type(arg)
  if argtype == "string" then
    return "'"..arg.."'"
  elseif argtype == "number" or argtype == "boolean" or argtype == "nil" then
    return tostring(arg)
  else
    return "["..tostring(arg).."]"
  end
end


-- Unconditionally fails the current test.
function fail(msg)
  stats.assertions = stats.assertions + 1
  failure( "fail", msg, "failure" )
end
traceback_hide( fail )


-- Fails unless `assertion` is truthy; returns `assertion` otherwise.
function assert(assertion, msg)
  stats.assertions = stats.assertions + 1
  if not assertion then
    failure( "assert", msg, "assertion failed" )
  end
  return assertion
end
traceback_hide( assert )


-- Fails unless `actual` is exactly the boolean true.
function assert_true(actual, msg)
  stats.assertions = stats.assertions + 1
  local actualtype = type(actual)
  if actualtype ~= "boolean" then
    failure( "assert_true", msg, "true expected but was a "..actualtype )
  end
  if actual ~= true then
    failure( "assert_true", msg, "true expected but was false" )
  end
  return actual
end
traceback_hide( assert_true )


-- Fails unless `actual` is exactly the boolean false.
function assert_false(actual, msg)
  stats.assertions = stats.assertions + 1
  local actualtype = type(actual)
  if actualtype ~= "boolean" then
    failure( "assert_false", msg, "false expected but was a "..actualtype )
  end
  if actual ~= false then
    failure( "assert_false", msg, "false expected but was true" )
  end
  return actual
end
traceback_hide( assert_false )


-- Fails unless expected == actual (plain `==`, no deep comparison).
function assert_equal(expected, actual, msg)
  stats.assertions = stats.assertions + 1
  if expected ~= actual then
    failure( "assert_equal", msg, "expected %s but was %s", format_arg(expected), format_arg(actual) )
  end
  return actual
end
traceback_hide( assert_equal )


-- Fails if unexpected == actual.
function assert_not_equal(unexpected, actual, msg)
  stats.assertions = stats.assertions + 1
  if unexpected == actual then
    failure( "assert_not_equal", msg, "%s not expected but was one", format_arg(unexpected) )
  end
  return actual
end
traceback_hide( assert_not_equal )


-- Fails unless the string `actual` matches the Lua pattern `pattern`.
function assert_match(pattern, actual, msg)
  stats.assertions = stats.assertions + 1
  local patterntype = type(pattern)
  if patterntype ~= "string" then
    failure( "assert_match", msg, "expected the pattern as a string but was a "..patterntype )
  end
  local actualtype = type(actual)
  if actualtype ~= "string" then
    failure( "assert_match", msg, "expected a string to match pattern '%s' but was a %s", pattern, actualtype )
  end
  if not string.find(actual, pattern) then
    failure( "assert_match", msg, "expected '%s' to match pattern '%s' but doesn't", actual, pattern )
  end
  return actual
end
traceback_hide( assert_match )


-- Fails if the string `actual` matches the Lua pattern `pattern`.
function assert_not_match(pattern, actual, msg)
  stats.assertions = stats.assertions + 1
  local patterntype = type(pattern)
  if patterntype ~= "string" then
    failure( "assert_not_match", msg, "expected the pattern as a string but was a "..patterntype )
  end
  local actualtype = type(actual)
  if actualtype ~= "string" then
    failure( "assert_not_match", msg, "expected a string to not match pattern '%s' but was a %s", pattern, actualtype )
  end
  if string.find(actual, pattern) then
    failure( "assert_not_match", msg, "expected '%s' to not match pattern '%s' but it does", actual, pattern )
  end
  return actual
end
traceback_hide( assert_not_match )


-- Fails unless calling `func` raises an error.
-- `msg` is optional: assert_error(func) and assert_error(msg, func) both work.
function assert_error(msg, func)
  stats.assertions = stats.assertions + 1
  if func == nil then
    func, msg = msg, nil
  end
  local functype = type(func)
  if functype ~= "function" then
    failure( "assert_error", msg, "expected a function as last argument but was a "..functype )
  end
  local ok, errmsg = pcall(func)
  if ok then
    failure( "assert_error", msg, "error expected but no error occurred" )
  end
end
traceback_hide( assert_error )


-- Fails unless calling `func` raises a string error matching `pattern`.
-- `msg` is optional: assert_error_match(pattern, func) is also accepted.
function assert_error_match(msg, pattern, func)
  stats.assertions = stats.assertions + 1
  if func == nil then
    msg, pattern, func = nil, msg, pattern
  end
  local patterntype = type(pattern)
  if patterntype ~= "string" then
    failure( "assert_error_match", msg, "expected the pattern as a string but was a "..patterntype )
  end
  local functype = type(func)
  if functype ~= "function" then
    failure( "assert_error_match", msg, "expected a function as last argument but was a "..functype )
  end
  local ok, errmsg = pcall(func)
  if ok then
    failure( "assert_error_match", msg, "error expected but no error occurred" )
  end
  local errmsgtype = type(errmsg)
  if errmsgtype ~= "string" then
    failure( "assert_error_match", msg, "error as string expected but was a "..errmsgtype )
  end
  if not string.find(errmsg, pattern) then
    failure( "assert_error_match", msg, "expected error '%s' to match pattern '%s' but doesn't", errmsg, pattern )
  end
end
traceback_hide( assert_error_match )


-- Fails if calling `func` raises an error.
-- `msg` is optional: assert_pass(func) and assert_pass(msg, func) both work.
function assert_pass(msg, func)
  stats.assertions = stats.assertions + 1
  if func == nil then
    func, msg = msg, nil
  end
  local functype = type(func)
  if functype ~= "function" then
    failure( "assert_pass", msg, "expected a function as last argument but was a %s", functype )
  end
  local ok, errmsg = pcall(func)
  if not ok then
    -- The error value need not be a string. A nested lunit failure, for
    -- example, is a table. Lua 5.1's string.format("%s") rejects non-strings,
    -- which would replace this failure with an unrelated error.
    -- Convert the value to text first.
    local errtext
    if is_table(errmsg) and errmsg.type == __failure__ then
      errtext = errmsg.msg
    else
      errtext = tostring(errmsg)
    end
    failure( "assert_pass", msg, "no error expected but error was: '%s'", errtext )
  end
end
traceback_hide( assert_pass )


-- lunit.assert_typename functions: fail unless type(actual) == typename.

for _, typename in ipairs(typenames) do
  local assert_typename = "assert_"..typename
  lunit[assert_typename] = function(actual, msg)
    stats.assertions = stats.assertions + 1
    local actualtype = type(actual)
    if actualtype ~= typename then
      failure( assert_typename, msg, typename.." expected but was a "..actualtype )
    end
    return actual
  end
  traceback_hide( lunit[assert_typename] )
end


-- lunit.assert_not_typename functions: fail if type(actual) == typename.

for _, typename in ipairs(typenames) do
  local assert_not_typename = "assert_not_"..typename
  lunit[assert_not_typename] = function(actual, msg)
    stats.assertions = stats.assertions + 1
    if type(actual) == typename then
      failure( assert_not_typename, msg, typename.." not expected but was one" )
    end
  end
  traceback_hide( lunit[assert_not_typename] )
end


-- Resets the counters in lunit.stats.
function lunit.clearstats()
  stats = {
    assertions = 0;
    passed = 0;
    failed = 0;
    errors = 0;
  }
end


local report, reporterrobj
do
  local testrunner

  -- Installs `newrunner` (a table of event handlers, or nil).
  -- Returns the previously installed runner.
  function lunit.setrunner(newrunner)
    if not ( is_table(newrunner) or is_nil(newrunner) ) then
      return error("lunit.setrunner: Invalid argument", 0)
    end
    local oldrunner = testrunner
    testrunner = newrunner
    return oldrunner
  end

  -- Requires the module `name` and installs it as the test runner.
  function lunit.loadrunner(name)
    if not is_string(name) then
      return error("lunit.loadrunner: Invalid argument", 0)
    end
    local ok, runner = pcall( require, name )
    if not ok then
      -- The require error value may be a non-string; make concatenation safe.
      return error("lunit.loadrunner: Can't load test runner: "..tostring(runner), 0)
    end
    return setrunner(runner)
  end

  -- Calls testrunner[event](...) if that handler exists. Errors raised by the
  -- runner are deliberately ignored so a broken runner cannot abort a run.
  function report(event, ...)
    local f = testrunner and testrunner[event]
    if is_function(f) then
      pcall(f, ...)
    end
  end

  -- Counts an error object from a setup/test/teardown call and reports it:
  -- lunit failures as "fail", anything else as "err".
  function reporterrobj(context, tcname, testname, errobj)
    local fullname = tcname .. "." .. testname
    if context == "setup" then
      fullname = fullname .. ":" .. setupname(tcname, testname)
    elseif context == "teardown" then
      fullname = fullname .. ":" .. teardownname(tcname, testname)
    end
    if errobj.type == __failure__ then
      stats.failed = stats.failed + 1
      report("fail", fullname, errobj.where, errobj.msg, errobj.usermsg)
    else
      stats.errors = stats.errors + 1
      report("err", fullname, errobj.msg, errobj.tb)
    end
  end
end



-- Stateless iterator yielding only the keys of a table.
local function key_iter(t, k)
  return (next(t,k))
end


local testcase
do
  -- All registered testcases, keyed by module name.
  local _testcases = {}

  -- Marks a module as a testcase.
  -- Applied over a module from module("xyz", lunit.testcase).
  function lunit.testcase(m)
    orig_assert( is_table(m) )
    --orig_assert( m._M == m )
    orig_assert( is_string(m._NAME) )
    --orig_assert( is_string(m._PACKAGE) )

    -- Register the module as a testcase
    _testcases[m._NAME] = m

    -- Import lunit, fail, assert* and is_* functions into the module/testcase
    m.lunit = lunit
    m.fail = lunit.fail
    for funcname, func in pairs(lunit) do
      if "assert" == string_sub(funcname, 1, 6) or "is_" == string_sub(funcname, 1, 3) then
        m[funcname] = func
      end
    end
  end

  -- Iterator (testcasename) over all Testcases
  function lunit.testcases()
    -- Make a copy of testcases to prevent confusing the iterator when
    -- new testcases are defined
    local _testcases2 = {}
    for k,v in pairs(_testcases) do
      _testcases2[k] = true
    end
    return key_iter, _testcases2, nil
  end

  -- Returns the registered testcase module named `tcname`, or nil.
  function testcase(tcname)
    return _testcases[tcname]
  end
end


do
  -- Finds a function in a testcase by case-insensitive name.
  -- `name` must be given in lower case.
  local function findfuncname(tcname, name)
    for key, value in pairs(testcase(tcname)) do
      if is_string(key) and is_function(value) and string.lower(key) == name then
        return key
      end
    end
  end

  -- Actual key of the testcase's setup function (any capitalisation), or nil.
  function lunit.setupname(tcname)
    return findfuncname(tcname, "setup")
  end

  -- Actual key of the testcase's teardown function (any capitalisation), or nil.
  function lunit.teardownname(tcname)
    return findfuncname(tcname, "teardown")
  end

  -- Iterator over all test names in a testcase: functions whose lower-cased
  -- name starts or ends with "test".
  -- Have to collect the names first in case one of the test
  -- functions creates a new global and throws off the iteration.
  function lunit.tests(tcname)
    local testnames = {}
    for key, value in pairs(testcase(tcname)) do
      if is_string(key) and is_function(value) then
        local lfn = string.lower(key)
        if string.sub(lfn, 1, 4) == "test" or string.sub(lfn, -4) == "test" then
          testnames[key] = true
        end
      end
    end
    return key_iter, testnames, nil
  end
end




-- Runs one test: setup, test, teardown.
-- The test and teardown are skipped if setup fails. The test counts as
-- passed only if all three phases succeed.
function lunit.runtest(tcname, testname)
  orig_assert( is_string(tcname) )
  orig_assert( is_string(testname) )

  local function callit(context, func)
    if func then
      local err = mypcall(func)
      if err then
        reporterrobj(context, tcname, testname, err)
        return false
      end
    end
    return true
  end
  traceback_hide(callit)

  report("run", tcname, testname)

  local tc = testcase(tcname)
  local setup = tc[setupname(tcname)]
  local test = tc[testname]
  local teardown = tc[teardownname(tcname)]

  local setup_ok    =              callit( "setup", setup )
  local test_ok     = setup_ok and callit( "test", test )
  local teardown_ok = setup_ok and callit( "teardown", teardown )

  if setup_ok and test_ok and teardown_ok then
    stats.passed = stats.passed + 1
    report("pass", tcname, testname)
  end
end
traceback_hide(runtest)



-- Converts an lunit test pattern into an anchored Lua pattern.
-- '?' matches any single character and '*' matches any sequence; every
-- other Lua magic character is matched literally.
-- Defined before lunit.run so run can use it as an upvalue.
local lunitpat2luapat
do
  local conv = {
    ["^"] = "%^",
    ["$"] = "%$",
    ["("] = "%(",
    [")"] = "%)",
    ["%"] = "%%",
    ["."] = "%.",
    ["["] = "%[",
    ["]"] = "%]",
    ["+"] = "%+",
    ["-"] = "%-",
    ["?"] = ".",
    ["*"] = ".*"
  }
  function lunitpat2luapat(str)
    return "^" .. string.gsub(str, "%W", conv) .. "$"
  end
end



-- True if map[name] == true, or if `name` matches one of the Lua patterns
-- in the array part of `map`.
local function in_patternmap(map, name)
  if map[name] == true then
    return true
  else
    for _, pat in ipairs(map) do
      if string.find(name, pat) then
        return true
      end
    end
  end
  return false
end


-- Turns a list of lunit test patterns into a pattern map for in_patternmap.
-- Returns nil ("select every test") when no patterns are given.
local function make_patternmap(testpatterns)
  if testpatterns == nil or #testpatterns == 0 then
    return nil
  end
  local map = {}
  for _, pat in ipairs(testpatterns) do
    map[#map+1] = lunitpat2luapat(pat)
  end
  return map
end



-- Runs all registered tests and returns lunit.stats.
-- `testpatterns` is optional: a list of lunit patterns (see
-- lunitpat2luapat). When given, only tests whose name or
-- "testcase.testname" matches one of the patterns are run.
function lunit.run(testpatterns)
  clearstats()
  local patternmap = make_patternmap(testpatterns)
  report("begin")
  for testcasename in lunit.testcases() do
    -- Run tests in the testcases
    for testname in lunit.tests(testcasename) do
      if patternmap == nil
        or in_patternmap(patternmap, testname)
        or in_patternmap(patternmap, testcasename .. "." .. testname) then
        runtest(testcasename, testname)
      end
    end
  end
  report("done")
  return stats
end
traceback_hide(run)


-- Reports an empty run (begin/done only) without executing any tests.
function lunit.loadonly()
  clearstats()
  report("begin")
  report("done")
  return stats
end



-- Called from 'lunit' shell script.
-- Options:
--   --loadonly          load the test files without running tests
--   --runner, -r NAME   test runner module (default "lunit-console")
--   --test, -t PATTERN  run only tests matching PATTERN (repeatable)
--   --                  treat all remaining arguments as test files
-- Any other argument is loaded as a test file.
function main(argv)
  argv = argv or {}

  -- FIXME: Error handling and error messages aren't nice.

  local function checkarg(optname, arg)
    if not is_string(arg) then
      return error("lunit.main: option "..optname..": argument missing.", 0)
    end
  end

  local function loadtestcase(filename)
    if not is_string(filename) then
      return error("lunit.main: invalid argument")
    end
    local chunk, err = loadfile(filename)
    if err then
      return error(err)
    else
      chunk()
    end
  end

  local testpatterns = nil
  local doloadonly = false
  local runner = nil

  local i = 0
  while i < #argv do
    i = i + 1
    local arg = argv[i]
    if arg == "--loadonly" then
      doloadonly = true
    elseif arg == "--runner" or arg == "-r" then
      local optname = arg; i = i + 1; arg = argv[i]
      checkarg(optname, arg)
      runner = arg
    elseif arg == "--test" or arg == "-t" then
      local optname = arg; i = i + 1; arg = argv[i]
      checkarg(optname, arg)
      testpatterns = testpatterns or {}
      testpatterns[#testpatterns+1] = arg
    elseif arg == "--" then
      while i < #argv do
        i = i + 1; arg = argv[i]
        loadtestcase(arg)
      end
    else
      loadtestcase(arg)
    end
  end

  loadrunner(runner or "lunit-console")

  if doloadonly then
    return loadonly()
  else
    return run(testpatterns)
  end
end

clearstats()