PageRenderTime 34ms CodeModel.GetById 12ms app.highlight 17ms RepoModel.GetById 1ms app.codeStats 0ms

/src/orbit/model.lua

http://github.com/keplerproject/orbit
Lua | 526 lines | 468 code | 58 blank | 0 comment | 102 complexity | 2868fe005fe5d76c4e31a067b8d038b9 MD5 | raw file
  1local lpeg = require "lpeg"
  2local re   = require "re"
  3
  4module("orbit.model", package.seeall)
  5
  6model_methods = {}
  7
  8dao_methods = {}
  9
 10local type_names = {}
 11
 12local function log_query(sql)
 13  io.stderr:write("[orbit.model] " .. sql .. "\n")
 14end
 15
 16function type_names.sqlite3(t)
 17  return string.lower(string.match(t, "(%a+)"))
 18end
 19
 20function type_names.mysql(t)
 21  if t == "number(1)" then
 22    return "boolean"
 23  else
 24    return string.lower(string.match(t, "(%a+)"))
 25  end
 26end
 27
 28function type_names.postgres(t)
 29  if t == "bool" then
 30    return "boolean"
 31  else
 32    return string.lower(string.match(t, "(%a+)"))
 33  end
 34end
 35
 36local convert = {}
 37
 38function convert.real(v)
 39  return tonumber(v)
 40end
 41
 42function convert.float(v)
 43  return tonumber(v)
 44end
 45
 46function convert.integer(v)
 47  return tonumber(v)
 48end
 49
 50function convert.int(v)
 51  return tonumber(v)
 52end
 53
 54function convert.number(v)
 55  return tonumber(v)
 56end
 57
 58function convert.numeric(v)
 59  return tonumber(v)
 60end
 61
 62function convert.varchar(v)
 63  return tostring(v)
 64end
 65
 66function convert.string(v)
 67  return tostring(v)
 68end
 69
 70function convert.text(v)
 71  return tostring(v)
 72end
 73
 74function convert.boolean(v, driver)
 75  if driver == "sqlite3" then
 76    return v == "t"
 77  elseif driver == "mysql" then
 78    return tonumber(v) == 1
 79  elseif driver == "postgres" then
 80    return v == "t"
 81  else
 82    error("driver not supported")
 83  end
 84end
 85
 86function convert.binary(v)
 87  return convert.text(v)
 88end
 89
 90function convert.datetime(v)
 91  local year, month, day, hour, min, sec =
 92    string.match(v, "(%d+)%-(%d+)%-(%d+) (%d+):(%d+):(%d+)")
 93  return os.time({ year = tonumber(year), month = tonumber(month),
 94		   day = tonumber(day), hour = tonumber(hour),
 95		   min = tonumber(min), sec = tonumber(sec) })
 96end
 97convert.timestamp = convert.datetime
 98
 99local function convert_types(row, meta, driver)
100  for k, v in pairs(row) do
101    if meta[k] then
102      local conv = convert[meta[k].type]
103      if conv then
104        row[k] = conv(v, driver)
105      else
106        error("no conversion for type " .. meta[k].type)
107      end
108    end
109  end
110end
111
-- Escapers that render a Lua value as an SQL literal. Keys are the
-- normalized type names produced by `type_names`; each escaper is called as
-- esc(value, driver, conn), where the string-like escapers need
-- conn:escape(string) to quote the text.
local escape = {}

-- Renders a value destined for a numeric column. Numbers and numeric strings
-- are emitted unchanged via tostring(); anything tonumber() rejects raises an
-- error, so arbitrary text (e.g. an unvalidated request parameter passed to a
-- finder) can never be spliced unquoted into the SQL.
local function escape_number(v)
  if tonumber(v) == nil then
    error("invalid numeric value: " .. tostring(v))
  end
  return tostring(v)
end

function escape.real(v)
  return escape_number(v)
end

function escape.float(v)
  return escape_number(v)
end

function escape.integer(v)
  return escape_number(v)
end

function escape.int(v)
  return escape_number(v)
end

function escape.number(v)
  return escape.integer(v)
end

function escape.numeric(v)
  return escape.integer(v)
end

function escape.varchar(v, driver, conn)
  return "'" .. conn:escape(v) .. "'"
end

function escape.string(v, driver, conn)
  return escape.varchar(v, driver, conn)
end

function escape.text(v, driver, conn)
  return "'" .. conn:escape(v) .. "'"
end

-- os.time() values are written as 'YYYY-MM-DD hh:mm:ss' (local time).
function escape.datetime(v)
  return "'" .. os.date("%Y-%m-%d %H:%M:%S", v) .. "'"
end
escape.timestamp = escape.datetime

-- sqlite3 and postgres get quoted 't'/'f'; other drivers get the plain
-- tostring() of the value (true/false).
function escape.boolean(v, driver)
  if driver == "sqlite3" or driver == "postgres" then
    if v then return "'t'" else return "'f'" end
  end
  return tostring(v)
end

function escape.binary(v, driver, conn)
  return escape.text(v, driver, conn)
end
166
-- Builds a table mapping each column listed in `row.meta` to its SQL literal:
-- "NULL" when the row's field is nil, otherwise the result of the column
-- type's escaper called with (value, row.driver, row.model.conn).
-- Raises an error for a column type without an escaper.
local function escape_values(row)
  local literals = {}
  for _, column in ipairs(row.meta) do
    local name = column.name
    local value = row[name]
    if value ~= nil then
      local esc = escape[column.type]
      if not esc then
        error("no escape function for type " .. column.type)
      end
      literals[name] = esc(value, row.driver, row.model.conn)
    else
      literals[name] = "NULL"
    end
  end
  return literals
end
183
-- Runs `sql` on the DAO's connection and returns the first result row: a
-- table keyed by column name, converted with convert_types using dao.meta
-- and given `dao` as its __index. Returns the cursor's end-of-data value
-- (nil) when there is no row. The cursor is always closed; a failed
-- query raises the driver's error message.
local function fetch_row(dao, sql)
  local conn = dao.model.conn
  local cursor, err = conn:execute(sql)
  if not cursor then
    error(err)
  end
  local first = cursor:fetch({}, "a")
  cursor:close()
  if first then
    convert_types(first, dao.meta, dao.driver)
    setmetatable(first, { __index = dao })
  end
  return first
end
195
-- Runs `sql` on the DAO's connection and returns an array of result rows,
-- each converted with convert_types using dao.meta and given `dao` as its
-- __index. When `count` is a number or numeric string, at most that many
-- rows are returned; otherwise all rows are returned.
-- All rows are fetched and the cursor is closed before any conversion runs,
-- so a conversion error cannot leave the cursor open.
-- A failed query raises the driver's error message.
local function fetch_rows(dao, sql, count)
  -- tonumber() so that a numeric-string count works as a cap instead of
  -- failing the number/string comparison.
  local limit = tonumber(count)
  local cursor, err = dao.model.conn:execute(sql)
  if not cursor then error(err) end
  local rows = {}
  while not limit or #rows < limit do
    local row = cursor:fetch({}, "a")
    if not row then break end
    rows[#rows + 1] = row
  end
  cursor:close()
  for _, row in ipairs(rows) do
    convert_types(row, dao.meta, dao.driver)
    setmetatable(row, { __index = dao })
  end
  return rows
end
210
-- Splits a dynamic-finder suffix such as "name_and_age_or_id" into a list
-- like { "name", "and", "age", "or", "id" }. Field names are captured as-is,
-- and the "_and_"/"_or_" separators are replaced by the strings "and"/"or".
-- A column whose name itself contains "_and_" or "_or_" therefore cannot be
-- used in a find_by_*/find_all_by_* name.
local by_condition_parser = re.compile([[
  fields <- ({(!conective .)+} (conective {(!conective .)+})*) -> {}
  conective <- and / or
  and <- "_and_" -> "and"
  or <- "_or_" -> "or"
]])
217
-- Turns a finder suffix `condition` (e.g. "name_and_age") into a list of SQL
-- fragments, consuming one entry of `args` per field, in order:
--   nil        -> "<field> is null"
--   table      -> "<field> IN (<v1>, <v2>, ...)"
--   any other  -> "<field> = <value>"
-- The "and"/"or" connectives are passed through unchanged. Values are
-- escaped with the escaper for the field's column type in dao.meta.
-- Raises an error naming the field when it is not a column of `dao`, and an
-- error naming the type when that type has no escaper.
local function parse_condition(dao, condition, args)
  local parts = by_condition_parser:match(condition)

  -- Escapes `value` with the escaper registered for `field`'s column type.
  local function literal(field, value)
    local column = dao.meta[field]
    if not column then
      error("unknown field '" .. field .. "' in condition '" .. condition .. "'")
    end
    local esc = escape[column.type]
    if not esc then
      error("no escape function for type " .. tostring(column.type))
    end
    return esc(value, dao.driver, dao.model.conn)
  end

  local j = 0
  for i, part in ipairs(parts) do
    if part ~= "or" and part ~= "and" then
      j = j + 1
      local arg = args[j]
      if arg == nil then
        parts[i] = part .. " is null"
      elseif type(arg) == "table" then
        local values = {}
        for k, value in ipairs(arg) do
          values[k] = literal(part, value)
        end
        parts[i] = part .. " IN (" .. table.concat(values, ", ") .. ")"
      else
        parts[i] = part .. " = " .. literal(part, arg)
      end
    end
  end
  return parts
end
241
-- Builds the pieces of a join between `dao`'s table and `inject.model`'s
-- table. Returns three strings:
--   1. the select list: the projected columns of `dao` (`project`, or every
--      column listed in dao.meta), each aliased to its own name, followed by
--      each `inject.fields` column aliased as "<model.name>_<field>";
--   2. the table list "<dao table>, <injected table>";
--   3. the join condition "<model.name>_id = <injected table>.id".
-- Side effect: replaces dao.meta's metatable so that the aliased injected
-- columns resolve to the injected model's column metadata.
local function build_inject(project, inject, dao)
  local own_table = dao.table_name
  local select_list = {}

  local function own_column(column_name)
    return own_table .. "." .. column_name .. " as " .. column_name
  end

  if project then
    for _, column_name in ipairs(project) do
      select_list[#select_list + 1] = own_column(column_name)
    end
  else
    for _, colinfo in ipairs(dao.meta) do
      select_list[#select_list + 1] = own_column(colinfo.name)
    end
  end

  local other = inject.model
  local aliased_meta = {}
  for _, column_name in ipairs(inject.fields) do
    local alias = other.name .. "_" .. column_name
    aliased_meta[alias] = other.meta[column_name]
    select_list[#select_list + 1] = other.table_name .. "." .. column_name .. " as " .. alias
  end
  setmetatable(dao.meta, { __index = aliased_meta })

  local join = other.name .. "_id = " .. other.table_name .. ".id"
  return table.concat(select_list, ", "), own_table .. ", " .. other.table_name, join
end
265
-- Builds the SELECT statement for a dynamic finder (find_by_* /
-- find_all_by_*). `condition` is the finder suffix, turned into WHERE
-- fragments by parse_condition. `args` supplies the positional values plus
-- these options:
--   distinct, count (LIMIT), offset (OFFSET), order (ORDER BY, inserted
--   verbatim), fields (projected columns), inject (join via build_inject).
-- Logs the statement when dao.model.logging is set.
local function build_query_by(dao, condition, args)
  local where_parts = parse_condition(dao, condition, args)
  local columns, tables
  if args.inject then
    if #where_parts > 0 then
      where_parts[#where_parts + 1] = "and"
    end
    local join
    columns, tables, join = build_inject(args.fields, args.inject, dao)
    where_parts[#where_parts + 1] = join
  else
    columns = args.fields and table.concat(args.fields, ", ") or "*"
    tables = dao.table_name
  end
  local count, offset = tonumber(args.count), tonumber(args.offset)
  local sql = table.concat({
    args.distinct and "select distinct " or "select ",
    columns, " from ", tables,
    " where ", table.concat(where_parts, " "),
    args.order and (" order by " .. args.order) or "",
    count and (" limit " .. count) or "",
    offset and (" offset " .. offset) or "",
  })
  if dao.model.logging then log_query(sql) end
  return sql
end
291
-- Backends of the dynamic finders dao:find_by_<cond>(args) and
-- dao:find_all_by_<cond>(args). `args` holds the positional values for the
-- fields named in `condition` plus query options; it is optional, and a
-- missing table is treated as empty (rather than crashing on a nil index).
local function find_by(dao, condition, args)
  args = args or {}
  return fetch_row(dao, build_query_by(dao, condition, args))
end

-- Like find_by, but returns every matching row, capped at args.count.
local function find_all_by(dao, condition, args)
  args = args or {}
  return fetch_rows(dao, build_query_by(dao, condition, args), args.count)
end
299
-- __index handler for DAOs. Returns the shared dao_methods entry for `name`
-- if there is one. Otherwise it synthesizes dynamic finders:
-- "find_by_<cond>" and "find_all_by_<cond>" yield functions that call
-- find_by / find_all_by with <cond> as the condition. Any other name gives
-- nil. A new closure is created on every lookup.
local function dao_index(dao, name)
  local method = dao_methods[name]
  if method then
    return method
  end
  local condition = string.match(name, "^find_by_(.+)$")
  if condition then
    return function (self, args) return find_by(self, condition, args) end
  end
  condition = string.match(name, "^find_all_by_(.+)$")
  if condition then
    return function (self, args) return find_all_by(self, condition, args) end
  end
  return nil
end
316
-- Creates a DAO for the table "<table_prefix><name>", reusing the table
-- passed as `dao` if given. It sets dao.model, name, table_name, meta and
-- driver, and installs dao_index as the DAO's __index handler. dao.meta is
-- filled from a "limit 0" query: meta[i] and meta[column_name] both hold
-- { name = column_name, type = normalized type name }.
-- Raises an error if the model's driver has no type-name mapping, or with
-- the driver's message if the probe query fails.
function model_methods:new(name, dao)
  local type_name = type_names[self.driver]
  if not type_name then
    -- Fail up front with a clear message instead of "attempt to call a nil
    -- value" while reading the column types.
    error("driver not supported: " .. tostring(self.driver))
  end
  dao = dao or {}
  dao.model, dao.name, dao.table_name, dao.meta, dao.driver =
    self, name, self.table_prefix .. name, {}, self.driver
  setmetatable(dao, { __index = dao_index })
  local sql = "select * from " .. dao.table_name .. " limit 0"
  if self.logging then log_query(sql) end
  local cursor, err = self.conn:execute(sql)
  if not cursor then error(err) end
  local names, types = cursor:getcolnames(), cursor:getcoltypes()
  cursor:close()
  for i = 1, #names do
    local colinfo = { name = names[i], type = type_name(types[i]) }
    dao.meta[i] = colinfo
    dao.meta[colinfo.name] = colinfo
  end
  return dao
end
336
-- Wraps the connection factory `fresh_conn` in a proxy that replaces the
-- connection once it is older than `timeout` (default 20000, in os.time()
-- units, i.e. seconds on POSIX systems).
-- Every method call on the proxy is forwarded to the current connection.
-- When that connection has expired, it is first closed (best-effort; close
-- errors are ignored) and replaced by a new `fresh_conn()` result.
-- The age is reset only after a replacement connection has been obtained.
-- If reconnecting fails, the next call therefore retries instead of reusing
-- the closed connection for a whole timeout period.
-- Forwarding functions are cached on the proxy, one per method name.
function recycle(fresh_conn, timeout)
  timeout = timeout or 20000
  local conn = fresh_conn()
  local created_at = os.time()

  -- Returns a live connection, reconnecting first if the current one expired.
  local function current()
    if created_at + timeout < os.time() then
      pcall(conn.close, conn)
      conn = fresh_conn()
      created_at = os.time()
    end
    return conn
  end

  return setmetatable({}, {
    __index = function (proxy, method)
      local forward = function (_, ...)
        local c = current()
        return c[method](c, ...)
      end
      proxy[method] = forward
      return forward
    end,
  })
end
354
-- Creates a model object, whose methods (such as model_methods:new for
-- building DAOs) come from model_methods.
-- table_prefix defaults to "" and driver to "sqlite3"; `conn` is the
-- database connection and `logging` enables query logging.
function new(table_prefix, conn, driver, logging)
  local app_model = {
    table_prefix = table_prefix or "",
    conn = conn,
    driver = driver or "sqlite3",
    logging = logging,
    models = {},
  }
  return setmetatable(app_model, { __index = model_methods })
end
361
-- Returns the row of this DAO's table whose id equals `id`, or nil if none.
-- `id` must be a number or a numeric string; anything else raises
-- "find error: id must be a number", which also keeps arbitrary text out of
-- the generated SQL. The `inject` parameter is accepted but ignored.
-- Logs the query when the model has logging enabled.
function dao_methods.find(dao, id, inject)
  -- Validate with tonumber() but splice the original value, so numeric
  -- strings (e.g. request parameters) keep working with full precision.
  if tonumber(id) == nil then
    error("find error: id must be a number")
  end
  local sql = "select * from " .. dao.table_name ..
    " where id=" .. id
  -- The logging flag lives on the model, not on the DAO.
  if dao.model.logging then log_query(sql) end
  return fetch_row(dao, sql)
end
371
-- Grammar used by build_query to rewrite find_first/find_all condition
-- strings. It is matched as condition_parser:match(str, 1, f). lpeg.Carg(1)
-- (%func) supplies `f`, and every "<field> <op> ?" comparison is replaced by
-- apply(f, field, op), i.e. by f(field, op).
-- Other text is copied through: parentheses, and/or connectives
-- (case-insensitive), "<field> <op> <field>" and "<field> <op>" forms.
-- Whitespace runs matched by `s` are collapsed to a single space.
-- Fields may be qualified ("table.column"). Operators are either symbolic
-- ([!<>=~]+) or one or more words (e.g. "is", "like").
-- NOTE(review): `top` is not anchored to the end of the input, so trailing
-- text the grammar cannot parse is silently dropped from the result.
condition_parser = re.compile([[
				  top <- {~ <condition>* ~}
				  s <- %s+ -> ' ' / ''
				  condition <- (<s> '(' <s> <condition> <s> ')' <s> / <simple>) (<conective> <condition>)*
				  simple <- <s> (%func <field> <op> '?') -> apply <s> / <s> <field> <op> <field> <s> /
				            <s> <field> <op> <s>
				  field <- !<conective> {[_%w]+('.'[_%w]+)?}
				  op <- {~ <s> [!<>=~]+ <s> / ((%s+ -> ' ') !<conective> %w+)+ <s> ~}
				  conective <- [aA][nN][dD] / [oO][rR]
			      ]], { func = lpeg.Carg(1) , apply = function (f, field, op) return f(field, op) end })
382
-- Builds the SELECT statement used by find_first/find_all.
-- `condition` is an SQL condition string whose "?" placeholders are filled,
-- in order, from the positional entries of `args`. Alternatively `condition`
-- may itself be the options table; there is then no condition string.
-- Each "<field> <op> ?" placeholder becomes, depending on its argument:
--   nil                     -> "id=id" (the comparison is disabled)
--   { type = "query", sql } -> "<field> <op> (sql)" (subquery, verbatim)
--   array of values         -> "(<field> <op> v1 or <field> <op> v2 ...)"
--   anything else, including false
--                           -> "<field> <op> <escaped value>"
-- Options in `args`:
--   distinct, count (LIMIT), offset (OFFSET), order (ORDER BY, verbatim),
--   fields (projected columns), from (extra tables), inject (join via
--   build_inject).
-- Logs the statement when dao.model.logging is set.
local function build_query(dao, condition, args)
  args = args or {}
  condition = condition or ""
  if type(condition) == "table" then
    args = condition
    condition = ""
  end
  if condition ~= "" then
    local i = 0
    -- Called by condition_parser for each "<field> <op> ?" placeholder.
    local function fill(field, op)
      i = i + 1
      local arg = args[i]
      -- Only a missing value disables the comparison. `false` is a real
      -- value (e.g. for boolean columns) and is escaped like any other.
      if arg == nil then
        return "id=id"
      elseif type(arg) == "table" and arg.type == "query" then
        return field .. " " .. op .. " (" .. arg[1] .. ")"
      elseif type(arg) == "table" then
        local values = {}
        for _, value in ipairs(arg) do
          values[#values + 1] = field .. " " .. op .. " " ..
            escape[dao.meta[field].type](value, dao.driver, dao.model.conn)
        end
        return "(" .. table.concat(values, " or ") .. ")"
      else
        return field .. " " .. op .. " " ..
          escape[dao.meta[field].type](arg, dao.driver, dao.model.conn)
      end
    end
    condition = " where " .. condition_parser:match(condition, 1, fill)
  end

  local field_list, table_list
  if args.inject then
    local inject_condition
    field_list, table_list, inject_condition = build_inject(args.fields, args.inject, dao)
    if condition == "" then
      condition = " where " .. inject_condition
    else
      condition = condition .. " and " .. inject_condition
    end
  else
    field_list = args.fields and table.concat(args.fields, ", ") or "*"
    -- table.unpack on Lua 5.2+, the global unpack on 5.1.
    local unpack_list = table.unpack or unpack
    table_list = table.concat({ dao.table_name, unpack_list(args.from or {}) }, ", ")
  end

  local count, offset = tonumber(args.count), tonumber(args.offset)
  local sql = (args.distinct and "select distinct " or "select ") ..
    field_list .. " from " .. table_list .. condition ..
    (args.order and (" order by " .. args.order) or "") ..
    (count and (" limit " .. count) or "") ..
    (offset and (" offset " .. offset) or "")
  if dao.model.logging then log_query(sql) end
  return sql
end
441
-- Returns the first row matching `condition`/`args` (as interpreted by
-- build_query), or nil when nothing matches.
function dao_methods.find_first(dao, condition, args)
  local sql = build_query(dao, condition, args)
  return fetch_row(dao, sql)
end

-- Returns every row matching `condition`/`args` (as interpreted by
-- build_query). The row cap given to fetch_rows is args.count. It falls
-- back to condition.count when `condition` is itself the options table.
function dao_methods.find_all(dao, condition, args)
  local sql = build_query(dao, condition, args)
  local count = args and args.count
  if not count and type(condition) == "table" then
    count = condition.count
  end
  return fetch_rows(dao, sql, count)
end
450
-- Turns `row` (or a fresh empty table) into an unsaved record of this DAO by
-- making the DAO its __index. Returns the row.
function dao_methods.new(dao, row)
  return setmetatable(row or {}, { __index = dao })
end
456
-- Writes every column of `row` back with "update <table> set ... where
-- id = <row.id>". If the table has an updated_at column, the current time is
-- stored both on the row and in the statement. Logs the statement when the
-- model has logging enabled; raises the driver's error if it fails.
local function update(row)
  local literals = escape_values(row)
  if row.meta["updated_at"] then
    local stamp = os.time()
    row.updated_at = stamp
    literals.updated_at = escape.datetime(stamp, row.driver)
  end
  local assignments = {}
  for column, literal in pairs(literals) do
    assignments[#assignments + 1] = column .. "=" .. literal
  end
  local sql = "update " .. row.table_name .. " set " ..
    table.concat(assignments, ", ") .. " where id = " .. row.id
  if row.model.logging then log_query(sql) end
  local ok, err = row.model.conn:execute(sql)
  if not ok then
    error(err)
  end
end
474
-- Inserts `row` into its table and stores the generated id on the row,
-- unless the row already had one. On postgres the id comes from
-- "returning id"; on other drivers it comes from conn:getlastautoid().
-- If the table has created_at / updated_at columns, each is set to the
-- current time unless the row already carries a value. The value written to
-- the database is always the one left on the row, so the two cannot
-- disagree. Logs the statement when the model has logging enabled; raises
-- the driver's error if it fails.
local function insert(row)
  local literals = escape_values(row)
  local now = os.time()
  for _, stamp in ipairs({ "created_at", "updated_at" }) do
    if row.meta[stamp] then
      row[stamp] = row[stamp] or now
      literals[stamp] = escape.datetime(row[stamp], row.driver)
    end
  end

  local columns, values = {}, {}
  for column, literal in pairs(literals) do
    -- On postgres this skips the id column AND every NULL column, letting
    -- column defaults apply.
    -- NOTE(review): a caller-supplied id is therefore not inserted on
    -- postgres, while row.id keeps it -- confirm this is intended.
    if row.driver ~= "postgres" or (column ~= "id" and literal ~= "NULL") then
      columns[#columns + 1] = column
      values[#values + 1] = literal
    end
  end

  local sql = "insert into " .. row.table_name ..
    " (" .. table.concat(columns, ", ") .. ") values (" ..
    table.concat(values, ", ") .. ")"
  if row.driver == "postgres" then sql = sql .. " returning id" end
  if row.model.logging then log_query(sql) end

  local result, err = row.model.conn:execute(sql)
  if not result then
    error(err)
  end
  if row.driver == "postgres" then
    row.id = row.id or tonumber(result:fetch())
    result:close()
  else
    row.id = row.id or row.model.conn:getlastautoid()
  end
end
510
-- Persists `row`. It INSERTs when the row has no id or `force_insert` is
-- set, and UPDATEs the existing record otherwise.
function dao_methods.save(row, force_insert)
  if not row.id or force_insert then
    insert(row)
  else
    update(row)
  end
end
518
-- Deletes `row`'s record (by id) and then clears row.id. Rows without an id
-- are left alone. Logs the statement when the model has logging enabled;
-- raises the driver's error if it fails.
function dao_methods.delete(row)
  if not row.id then
    return
  end
  local sql = "delete from " .. row.table_name .. " where id = " .. row.id
  if row.model.logging then log_query(sql) end
  local ok, err = row.model.conn:execute(sql)
  if not ok then
    error(err)
  end
  row.id = nil
end