PageRenderTime 14ms CodeModel.GetById 1ms app.highlight 9ms RepoModel.GetById 1ms app.codeStats 0ms

/src/luarocks/type_check.lua

http://github.com/keplerproject/luarocks
Lua | 213 lines | 148 code | 21 blank | 44 comment | 44 complexity | fe37d6ba09249e0bc1a16946de4ad742 MD5 | raw file
  1
  2local type_check = {}
  3
  4local cfg = require("luarocks.core.cfg")
  5local fun = require("luarocks.fun")
  6local util = require("luarocks.util")
  7local vers = require("luarocks.core.vers")
  8--------------------------------------------------------------------------------
  9
 10-- A magic constant that is not used anywhere in a schema definition
 11-- and retains equality when the table is deep-copied.
 12type_check.MAGIC_PLATFORMS = 0xEBABEFAC
 13
 14do
 15   local function fill_in_version(tbl, version)
 16      for _, v in pairs(tbl) do
 17         if type(v) == "table" then
 18            if v._version == nil then
 19               v._version = version
 20            end
 21            fill_in_version(v)
 22         end
 23      end
 24   end
 25   
 26   local function expand_magic_platforms(tbl)
 27      for k,v in pairs(tbl) do
 28         if v == type_check.MAGIC_PLATFORMS then
 29            tbl[k] = {
 30               _any = util.deep_copy(tbl)
 31            }
 32            tbl[k]._any[k] = nil
 33         elseif type(v) == "table" then
 34            expand_magic_platforms(v)
 35         end
 36      end
 37   end
 38   
 39   -- Build a table of schemas.
 40   -- @param versions a table where each key is a version number as a string,
 41   -- and the value is a schema specification. Schema versions are considered
 42   -- incremental: version "2.0" only needs to specify what's new/changed from
 43   -- version "1.0".
 44   function type_check.declare_schemas(inputs)
 45      local schemas = {}
 46      local parent_version
 47   
 48      local versions = fun.reverse_in(fun.sort_in(util.keys(inputs), vers.compare_versions))
 49
 50      for _, version in ipairs(versions) do
 51         local schema = inputs[version]
 52         if parent_version ~= nil then
 53            local copy = util.deep_copy(schemas[parent_version])
 54            util.deep_merge(copy, schema)
 55            schema = copy
 56         end
 57         fill_in_version(schema, version)
 58         expand_magic_platforms(schema)
 59         parent_version = version
 60         schemas[version] = schema
 61      end
 62
 63      return schemas, versions
 64   end
 65end
 66
 67--------------------------------------------------------------------------------
 68
 69local function check_version(version, typetbl, context)
 70   local typetbl_version = typetbl._version or "1.0"
 71   if vers.compare_versions(typetbl_version, version) then
 72      if context == "" then
 73         return nil, "Invalid rockspec_format version number in rockspec? Please fix rockspec accordingly."
 74      else
 75         return nil, context.." is not supported in rockspec format "..version.." (requires version "..typetbl_version.."), please fix the rockspec_format field accordingly."
 76      end
 77   end
 78   return true
 79end
 80
 81--- Type check an object.
 82-- The object is compared against an archetypical value
 83-- matching the expected type -- the actual values don't matter,
 84-- only their types. Tables are type checked recursively.
 85-- @param version string: The version of the item.
 86-- @param item any: The object being checked.
 87-- @param typetbl any: The type-checking table for the object.
 88-- @param context string: A string indicating the "context" where the
 89-- error occurred (the full table path), for error messages.
 90-- @return boolean or (nil, string): true if type checking
 91-- succeeded, or nil and an error message if it failed.
 92-- @see type_check_table
 93local function type_check_item(version, item, typetbl, context)
 94   assert(type(version) == "string")
 95
 96   if typetbl._version and typetbl._version ~= "1.0" then
 97      local ok, err = check_version(version, typetbl, context)
 98      if not ok then
 99         return nil, err
100      end
101   end
102   
103   local item_type = type(item) or "nil"
104   local expected_type = typetbl._type or "table"
105   
106   if expected_type == "number" then
107      if not tonumber(item) then
108         return nil, "Type mismatch on field "..context..": expected a number"
109      end
110   elseif expected_type == "string" then
111      if item_type ~= "string" then
112         return nil, "Type mismatch on field "..context..": expected a string, got "..item_type
113      end
114      local pattern = typetbl._pattern
115      if pattern then
116         if not item:match("^"..pattern.."$") then
117            local what = typetbl._name or ("'"..pattern.."'")
118            return nil, "Type mismatch on field "..context..": invalid value '"..item.."' does not match " .. what
119         end
120      end
121   elseif expected_type == "table" then
122      if item_type ~= expected_type then
123         return nil, "Type mismatch on field "..context..": expected a table"
124      else
125         return type_check.type_check_table(version, item, typetbl, context)
126      end
127   elseif item_type ~= expected_type then
128      return nil, "Type mismatch on field "..context..": expected "..expected_type
129   end
130   return true
131end
132
--- Build the display path of a field, for error messages.
-- @param context string: path of the enclosing table ("" at the root).
-- @param field any: the key being described.
-- @return string: `field` alone at the root, `context.field` for string
-- keys, and `context[field]` for any other key.
local function mkfield(context, field)
   if context ~= "" then
      if type(field) == "string" then
         return context .. "." .. field
      end
      return context .. "[" .. tostring(field) .. "]"
   end
   return tostring(field)
end

--- Type check the contents of a table against a schema node.
-- Each key of `tbl` is looked up in `typetbl`; when absent there, the
-- `_any` entry (if present) is used instead, and the value is checked with
-- type_check_item. Keys matched by neither are accepted when `typetbl._more`
-- is set or `cfg.accept_unknown_fields` is on, and rejected otherwise.
-- Every entry of `typetbl` not prefixed with "_" whose schema node has
-- `_mandatory` set must be present (non-nil) in `tbl`.
-- Tables are type checked recursively.
-- @param version string: The version of tbl.
-- @param tbl table: The table to be type checked.
-- @param typetbl table: The schema node describing tbl.
-- @param context string: The path of tbl ("" for the root), used in
-- error messages.
-- @return boolean or (nil, string): true if type checking
-- succeeded, or nil and an error message if it failed.
function type_check.type_check_table(version, tbl, typetbl, context)
   assert(type(version) == "string")
   assert(type(tbl) == "table")
   assert(type(typetbl) == "table")

   local ok, err = check_version(version, typetbl, context)
   if not ok then
      return nil, err
   end

   for k, v in pairs(tbl) do
      local t = typetbl[k] or typetbl._any
      if t then
         local item_ok, item_err = type_check_item(version, v, t, mkfield(context, k))
         if not item_ok then return nil, item_err end
      elseif not (typetbl._more or cfg.accept_unknown_fields) then
         -- tostring: keys need not be strings, and ".." would raise an
         -- error on e.g. a boolean key instead of reporting it.
         return nil, "Unknown field "..tostring(k)
      end
   end

   for k, v in pairs(typetbl) do
      -- Skip "_"-prefixed schema directives, and anything that is not a
      -- schema node (the indexing below would fail on a number).
      if type(k) == "string" and k:sub(1, 1) ~= "_"
         and type(v) == "table" and v._mandatory then
         -- Test for nil rather than falsiness: a present `false` value was
         -- already type checked in the loop above and is not missing.
         if tbl[k] == nil then
            return nil, "Mandatory field "..mkfield(context, k).." is missing."
         end
      end
   end
   return true
end

--- Report names in `globals` that the schema does not declare.
-- @param globals table: keys are the variable names to check.
-- @param typetbl table: the schema; a name counts as declared when
-- `typetbl[name]` or the legacy key `typetbl["MUST_" .. name]` is set.
-- @return true when every name is declared, or nil and an error message
-- listing the undeclared names in sorted order.
function type_check.check_undeclared_globals(globals, typetbl)
   local undeclared = {}
   for glob in pairs(globals) do
      if not (typetbl[glob] or typetbl["MUST_"..glob]) then
         undeclared[#undeclared + 1] = glob
      end
   end
   if #undeclared == 0 then
      return true
   end
   -- pairs() order is unspecified, so sort to keep the message stable
   -- across runs; tostring() keeps the comparison valid for mixed key types.
   table.sort(undeclared, function(a, b) return tostring(a) < tostring(b) end)
   if #undeclared == 1 then
      return nil, "Unknown variable: "..undeclared[1]
   end
   return nil, "Unknown variables: "..table.concat(undeclared, ", ")
end

return type_check