/src/luarocks/type_check.lua

http://github.com/keplerproject/luarocks · Lua · 213 lines · 148 code · 21 blank · 44 comment · 65 complexity · fe37d6ba09249e0bc1a16946de4ad742 MD5 · raw file

  1. local type_check = {}
  2. local cfg = require("luarocks.core.cfg")
  3. local fun = require("luarocks.fun")
  4. local util = require("luarocks.util")
  5. local vers = require("luarocks.core.vers")
  6. --------------------------------------------------------------------------------
-- A magic constant that is not used anywhere in a schema definition
-- and retains equality when the table is deep-copied.
-- expand_magic_platforms() looks for schema values equal to this constant
-- and replaces each one with `{ _any = <deep copy of the enclosing table> }`
-- (the copy has the placeholder key itself removed).
type_check.MAGIC_PLATFORMS = 0xEBABEFAC
  10. do
  11. local function fill_in_version(tbl, version)
  12. for _, v in pairs(tbl) do
  13. if type(v) == "table" then
  14. if v._version == nil then
  15. v._version = version
  16. end
  17. fill_in_version(v)
  18. end
  19. end
  20. end
  21. local function expand_magic_platforms(tbl)
  22. for k,v in pairs(tbl) do
  23. if v == type_check.MAGIC_PLATFORMS then
  24. tbl[k] = {
  25. _any = util.deep_copy(tbl)
  26. }
  27. tbl[k]._any[k] = nil
  28. elseif type(v) == "table" then
  29. expand_magic_platforms(v)
  30. end
  31. end
  32. end
  33. -- Build a table of schemas.
  34. -- @param versions a table where each key is a version number as a string,
  35. -- and the value is a schema specification. Schema versions are considered
  36. -- incremental: version "2.0" only needs to specify what's new/changed from
  37. -- version "1.0".
  38. function type_check.declare_schemas(inputs)
  39. local schemas = {}
  40. local parent_version
  41. local versions = fun.reverse_in(fun.sort_in(util.keys(inputs), vers.compare_versions))
  42. for _, version in ipairs(versions) do
  43. local schema = inputs[version]
  44. if parent_version ~= nil then
  45. local copy = util.deep_copy(schemas[parent_version])
  46. util.deep_merge(copy, schema)
  47. schema = copy
  48. end
  49. fill_in_version(schema, version)
  50. expand_magic_platforms(schema)
  51. parent_version = version
  52. schemas[version] = schema
  53. end
  54. return schemas, versions
  55. end
  56. end
  57. --------------------------------------------------------------------------------
  58. local function check_version(version, typetbl, context)
  59. local typetbl_version = typetbl._version or "1.0"
  60. if vers.compare_versions(typetbl_version, version) then
  61. if context == "" then
  62. return nil, "Invalid rockspec_format version number in rockspec? Please fix rockspec accordingly."
  63. else
  64. return nil, context.." is not supported in rockspec format "..version.." (requires version "..typetbl_version.."), please fix the rockspec_format field accordingly."
  65. end
  66. end
  67. return true
  68. end
  69. --- Type check an object.
  70. -- The object is compared against an archetypical value
  71. -- matching the expected type -- the actual values don't matter,
  72. -- only their types. Tables are type checked recursively.
  73. -- @param version string: The version of the item.
  74. -- @param item any: The object being checked.
  75. -- @param typetbl any: The type-checking table for the object.
  76. -- @param context string: A string indicating the "context" where the
  77. -- error occurred (the full table path), for error messages.
  78. -- @return boolean or (nil, string): true if type checking
  79. -- succeeded, or nil and an error message if it failed.
  80. -- @see type_check_table
  81. local function type_check_item(version, item, typetbl, context)
  82. assert(type(version) == "string")
  83. if typetbl._version and typetbl._version ~= "1.0" then
  84. local ok, err = check_version(version, typetbl, context)
  85. if not ok then
  86. return nil, err
  87. end
  88. end
  89. local item_type = type(item) or "nil"
  90. local expected_type = typetbl._type or "table"
  91. if expected_type == "number" then
  92. if not tonumber(item) then
  93. return nil, "Type mismatch on field "..context..": expected a number"
  94. end
  95. elseif expected_type == "string" then
  96. if item_type ~= "string" then
  97. return nil, "Type mismatch on field "..context..": expected a string, got "..item_type
  98. end
  99. local pattern = typetbl._pattern
  100. if pattern then
  101. if not item:match("^"..pattern.."$") then
  102. local what = typetbl._name or ("'"..pattern.."'")
  103. return nil, "Type mismatch on field "..context..": invalid value '"..item.."' does not match " .. what
  104. end
  105. end
  106. elseif expected_type == "table" then
  107. if item_type ~= expected_type then
  108. return nil, "Type mismatch on field "..context..": expected a table"
  109. else
  110. return type_check.type_check_table(version, item, typetbl, context)
  111. end
  112. elseif item_type ~= expected_type then
  113. return nil, "Type mismatch on field "..context..": expected "..expected_type
  114. end
  115. return true
  116. end
  117. local function mkfield(context, field)
  118. if context == "" then
  119. return tostring(field)
  120. elseif type(field) == "string" then
  121. return context.."."..field
  122. else
  123. return context.."["..tostring(field).."]"
  124. end
  125. end
  126. --- Type check the contents of a table.
  127. -- The table's contents are compared against a reference table,
  128. -- which contains the recognized fields, with archetypical values
  129. -- matching the expected types -- the actual values of items in the
  130. -- reference table don't matter, only their types (ie, for field x
  131. -- in tbl that is correctly typed, type(tbl.x) == type(types.x)).
  132. -- If the reference table contains a field called MORE, then
  133. -- unknown fields in the checked table are accepted.
  134. -- If it contains a field called ANY, then its type will be
  135. -- used to check any unknown fields. If a field is prefixed
  136. -- with MUST_, it is mandatory; its absence from the table is
  137. -- a type error.
  138. -- Tables are type checked recursively.
  139. -- @param version string: The version of tbl.
  140. -- @param tbl table: The table to be type checked.
  141. -- @param typetbl table: The type-checking table, containing
  142. -- values for recognized fields in the checked table.
  143. -- @param context string: A string indicating the "context" where the
  144. -- error occurred (such as the name of the table the item is a part of),
  145. -- to be used by error messages.
  146. -- @return boolean or (nil, string): true if type checking
  147. -- succeeded, or nil and an error message if it failed.
  148. function type_check.type_check_table(version, tbl, typetbl, context)
  149. assert(type(version) == "string")
  150. assert(type(tbl) == "table")
  151. assert(type(typetbl) == "table")
  152. local ok, err = check_version(version, typetbl, context)
  153. if not ok then
  154. return nil, err
  155. end
  156. for k, v in pairs(tbl) do
  157. local t = typetbl[k] or typetbl._any
  158. if t then
  159. local ok, err = type_check_item(version, v, t, mkfield(context, k))
  160. if not ok then return nil, err end
  161. elseif typetbl._more then
  162. -- Accept unknown field
  163. else
  164. if not cfg.accept_unknown_fields then
  165. return nil, "Unknown field "..k
  166. end
  167. end
  168. end
  169. for k, v in pairs(typetbl) do
  170. if k:sub(1,1) ~= "_" and v._mandatory then
  171. if not tbl[k] then
  172. return nil, "Mandatory field "..mkfield(context, k).." is missing."
  173. end
  174. end
  175. end
  176. return true
  177. end
  178. function type_check.check_undeclared_globals(globals, typetbl)
  179. local undeclared = {}
  180. for glob, _ in pairs(globals) do
  181. if not (typetbl[glob] or typetbl["MUST_"..glob]) then
  182. table.insert(undeclared, glob)
  183. end
  184. end
  185. if #undeclared == 1 then
  186. return nil, "Unknown variable: "..undeclared[1]
  187. elseif #undeclared > 1 then
  188. return nil, "Unknown variables: "..table.concat(undeclared, ", ")
  189. end
  190. return true
  191. end
  192. return type_check