/src/orbit/model.lua

http://github.com/keplerproject/orbit · Lua · 526 lines · 487 code · 39 blank · 0 comment · 42 complexity · 2868fe005fe5d76c4e31a067b8d038b9 MD5 · raw file

  1. local lpeg = require "lpeg"
  2. local re = require "re"
  3. module("orbit.model", package.seeall)
  4. model_methods = {}
  5. dao_methods = {}
  6. local type_names = {}
  7. local function log_query(sql)
  8. io.stderr:write("[orbit.model] " .. sql .. "\n")
  9. end
  10. function type_names.sqlite3(t)
  11. return string.lower(string.match(t, "(%a+)"))
  12. end
  13. function type_names.mysql(t)
  14. if t == "number(1)" then
  15. return "boolean"
  16. else
  17. return string.lower(string.match(t, "(%a+)"))
  18. end
  19. end
  20. function type_names.postgres(t)
  21. if t == "bool" then
  22. return "boolean"
  23. else
  24. return string.lower(string.match(t, "(%a+)"))
  25. end
  26. end
  27. local convert = {}
  28. function convert.real(v)
  29. return tonumber(v)
  30. end
  31. function convert.float(v)
  32. return tonumber(v)
  33. end
  34. function convert.integer(v)
  35. return tonumber(v)
  36. end
  37. function convert.int(v)
  38. return tonumber(v)
  39. end
  40. function convert.number(v)
  41. return tonumber(v)
  42. end
  43. function convert.numeric(v)
  44. return tonumber(v)
  45. end
  46. function convert.varchar(v)
  47. return tostring(v)
  48. end
  49. function convert.string(v)
  50. return tostring(v)
  51. end
  52. function convert.text(v)
  53. return tostring(v)
  54. end
  55. function convert.boolean(v, driver)
  56. if driver == "sqlite3" then
  57. return v == "t"
  58. elseif driver == "mysql" then
  59. return tonumber(v) == 1
  60. elseif driver == "postgres" then
  61. return v == "t"
  62. else
  63. error("driver not supported")
  64. end
  65. end
  66. function convert.binary(v)
  67. return convert.text(v)
  68. end
  69. function convert.datetime(v)
  70. local year, month, day, hour, min, sec =
  71. string.match(v, "(%d+)%-(%d+)%-(%d+) (%d+):(%d+):(%d+)")
  72. return os.time({ year = tonumber(year), month = tonumber(month),
  73. day = tonumber(day), hour = tonumber(hour),
  74. min = tonumber(min), sec = tonumber(sec) })
  75. end
  76. convert.timestamp = convert.datetime
  77. local function convert_types(row, meta, driver)
  78. for k, v in pairs(row) do
  79. if meta[k] then
  80. local conv = convert[meta[k].type]
  81. if conv then
  82. row[k] = conv(v, driver)
  83. else
  84. error("no conversion for type " .. meta[k].type)
  85. end
  86. end
  87. end
  88. end
  89. local escape = {}
  90. function escape.real(v)
  91. return tostring(v)
  92. end
  93. function escape.float(v)
  94. return tostring(v)
  95. end
  96. function escape.integer(v)
  97. return tostring(v)
  98. end
  99. function escape.int(v)
  100. return tostring(v)
  101. end
  102. function escape.number(v)
  103. return escape.integer(v)
  104. end
  105. function escape.numeric(v)
  106. return escape.integer(v)
  107. end
  108. function escape.varchar(v, driver, conn)
  109. return "'" .. conn:escape(v) .. "'"
  110. end
  111. function escape.string(v, driver, conn)
  112. return escape.varchar(v, driver, conn)
  113. end
  114. function escape.text(v, driver, conn)
  115. return "'" .. conn:escape(v) .. "'"
  116. end
  117. function escape.datetime(v)
  118. return "'" .. os.date("%Y-%m-%d %H:%M:%S", v) .. "'"
  119. end
  120. escape.timestamp = escape.datetime
  121. function escape.boolean(v, driver)
  122. if v then
  123. if driver == "sqlite3" or driver == "postgres" then return "'t'" else return tostring(v) end
  124. else
  125. if driver == "sqlite3" or driver == "postgres" then return "'f'" else return tostring(v) end
  126. end
  127. end
  128. function escape.binary(v, driver, conn)
  129. return escape.text(v, driver, conn)
  130. end
  131. local function escape_values(row)
  132. local row_escaped = {}
  133. for i, m in ipairs(row.meta) do
  134. if row[m.name] == nil then
  135. row_escaped[m.name] = "NULL"
  136. else
  137. local esc = escape[m.type]
  138. if esc then
  139. row_escaped[m.name] = esc(row[m.name], row.driver, row.model.conn)
  140. else
  141. error("no escape function for type " .. m.type)
  142. end
  143. end
  144. end
  145. return row_escaped
  146. end
  147. local function fetch_row(dao, sql)
  148. local cursor, err = dao.model.conn:execute(sql)
  149. if not cursor then error(err) end
  150. local row = cursor:fetch({}, "a")
  151. cursor:close()
  152. if row then
  153. convert_types(row, dao.meta, dao.driver)
  154. setmetatable(row, { __index = dao })
  155. end
  156. return row
  157. end
  158. local function fetch_rows(dao, sql, count)
  159. local rows = {}
  160. local cursor, err = dao.model.conn:execute(sql)
  161. if not cursor then error(err) end
  162. local row, fetched = cursor:fetch({}, "a"), 1
  163. while row and (not count or fetched <= count) do
  164. convert_types(row, dao.meta, dao.driver)
  165. setmetatable(row, { __index = dao })
  166. rows[#rows + 1] = row
  167. row, fetched = cursor:fetch({}, "a"), fetched + 1
  168. end
  169. cursor:close()
  170. return rows
  171. end
  172. local by_condition_parser = re.compile([[
  173. fields <- ({(!conective .)+} (conective {(!conective .)+})*) -> {}
  174. conective <- and / or
  175. and <- "_and_" -> "and"
  176. or <- "_or_" -> "or"
  177. ]])
  178. local function parse_condition(dao, condition, args)
  179. local parts = by_condition_parser:match(condition)
  180. local j = 0
  181. for i, part in ipairs(parts) do
  182. if part ~= "or" and part ~= "and" then
  183. j = j + 1
  184. local value
  185. if args[j] == nil then
  186. parts[i] = part .. " is null"
  187. elseif type(args[j]) == "table" then
  188. local values = {}
  189. for _, value in ipairs(args[j]) do
  190. values[#values + 1] = escape[dao.meta[part].type](value, dao.driver, dao.model.conn)
  191. end
  192. parts[i] = part .. " IN (" .. table.concat(values,", ") .. ")"
  193. else
  194. value = escape[dao.meta[part].type](args[j], dao.driver, dao.model.conn)
  195. parts[i] = part .. " = " .. value
  196. end
  197. end
  198. end
  199. return parts
  200. end
  201. local function build_inject(project, inject, dao)
  202. local fields = {}
  203. if project then
  204. for i, field in ipairs(project) do
  205. fields[i] = dao.table_name .. "." .. field .. " as " .. field
  206. end
  207. else
  208. for i, field in ipairs(dao.meta) do
  209. fields[i] = dao.table_name .. "." .. field.name .. " as " .. field.name
  210. end
  211. end
  212. local inject_fields = {}
  213. local model = inject.model
  214. for _, field in ipairs(inject.fields) do
  215. inject_fields[model.name .. "_" .. field] =
  216. model.meta[field]
  217. fields[#fields + 1] = model.table_name .. "." .. field .. " as " ..
  218. model.name .. "_" .. field
  219. end
  220. setmetatable(dao.meta, { __index = inject_fields })
  221. return table.concat(fields, ", "), dao.table_name .. ", " ..
  222. model.table_name, model.name .. "_id = " .. model.table_name .. ".id"
  223. end
  224. local function build_query_by(dao, condition, args)
  225. local parts = parse_condition(dao, condition, args)
  226. local order = ""
  227. local field_list, table_list, select, limit, offset
  228. if args.distinct then select = "select distinct " else select = "select " end
  229. if tonumber(args.count) then limit = " limit " .. tonumber(args.count) else limit = "" end
  230. if tonumber(args.offset) then offset = " offset " .. tonumber(args.offset) else offset = "" end
  231. if args.order then order = " order by " .. args.order end
  232. if args.inject then
  233. if #parts > 0 then parts[#parts + 1] = "and" end
  234. field_list, table_list, parts[#parts + 1] = build_inject(args.fields, args.inject,
  235. dao)
  236. else
  237. if args.fields then
  238. field_list = table.concat(args.fields, ", ")
  239. else
  240. field_list = "*"
  241. end
  242. table_list = dao.table_name
  243. end
  244. local sql = select .. field_list .. " from " .. table_list ..
  245. " where " .. table.concat(parts, " ") .. order .. limit .. offset
  246. if dao.model.logging then log_query(sql) end
  247. return sql
  248. end
  249. local function find_by(dao, condition, args)
  250. return fetch_row(dao, build_query_by(dao, condition, args))
  251. end
  252. local function find_all_by(dao, condition, args)
  253. return fetch_rows(dao, build_query_by(dao, condition, args), args.count)
  254. end
  255. local function dao_index(dao, name)
  256. local m = dao_methods[name]
  257. if m then
  258. return m
  259. else
  260. local match = string.match(name, "^find_by_(.+)$")
  261. if match then
  262. return function (dao, args) return find_by(dao, match, args) end
  263. end
  264. local match = string.match(name, "^find_all_by_(.+)$")
  265. if match then
  266. return function (dao, args) return find_all_by(dao, match, args) end
  267. end
  268. return nil
  269. end
  270. end
  271. function model_methods:new(name, dao)
  272. dao = dao or {}
  273. dao.model, dao.name, dao.table_name, dao.meta, dao.driver = self, name,
  274. self.table_prefix .. name, {}, self.driver
  275. setmetatable(dao, { __index = dao_index })
  276. local sql = "select * from " .. dao.table_name .. " limit 0"
  277. if self.logging then log_query(sql) end
  278. local cursor, err = self.conn:execute(sql)
  279. if not cursor then error(err) end
  280. local names, types = cursor:getcolnames(), cursor:getcoltypes()
  281. cursor:close()
  282. for i = 1, #names do
  283. local colinfo = { name = names[i],
  284. type = type_names[self.driver](types[i]) }
  285. dao.meta[i] = colinfo
  286. dao.meta[colinfo.name] = colinfo
  287. end
  288. return dao
  289. end
  290. function recycle(fresh_conn, timeout)
  291. local created_at = os.time()
  292. local conn = fresh_conn()
  293. timeout = timeout or 20000
  294. return setmetatable({}, { __index = function (tab, meth)
  295. tab[meth] = function (tab, ...)
  296. if created_at + timeout < os.time() then
  297. created_at = os.time()
  298. pcall(conn.close, conn)
  299. conn = fresh_conn()
  300. end
  301. return conn[meth](conn, ...)
  302. end
  303. return tab[meth]
  304. end
  305. })
  306. end
  307. function new(table_prefix, conn, driver, logging)
  308. driver = driver or "sqlite3"
  309. local app_model = { table_prefix = table_prefix or "", conn = conn, driver = driver or "sqlite3", logging = logging, models = {} }
  310. setmetatable(app_model, { __index = model_methods })
  311. return app_model
  312. end
  313. function dao_methods.find(dao, id, inject)
  314. if not type(id) == "number" then
  315. error("find error: id must be a number")
  316. end
  317. local sql = "select * from " .. dao.table_name ..
  318. " where id=" .. id
  319. if dao.logging then log_query(sql) end
  320. return fetch_row(dao, sql)
  321. end
  322. condition_parser = re.compile([[
  323. top <- {~ <condition>* ~}
  324. s <- %s+ -> ' ' / ''
  325. condition <- (<s> '(' <s> <condition> <s> ')' <s> / <simple>) (<conective> <condition>)*
  326. simple <- <s> (%func <field> <op> '?') -> apply <s> / <s> <field> <op> <field> <s> /
  327. <s> <field> <op> <s>
  328. field <- !<conective> {[_%w]+('.'[_%w]+)?}
  329. op <- {~ <s> [!<>=~]+ <s> / ((%s+ -> ' ') !<conective> %w+)+ <s> ~}
  330. conective <- [aA][nN][dD] / [oO][rR]
  331. ]], { func = lpeg.Carg(1) , apply = function (f, field, op) return f(field, op) end })
  332. local function build_query(dao, condition, args)
  333. local i = 0
  334. args = args or {}
  335. condition = condition or ""
  336. if type(condition) == "table" then
  337. args = condition
  338. condition = ""
  339. end
  340. if condition ~= "" then
  341. condition = " where " ..
  342. condition_parser:match(condition, 1,
  343. function (field, op)
  344. i = i + 1
  345. if not args[i] then
  346. return "id=id"
  347. elseif type(args[i]) == "table" and args[i].type == "query" then
  348. return field .. " " .. op .. " (" .. args[i][1] .. ")"
  349. elseif type(args[i]) == "table" then
  350. local values = {}
  351. for j, value in ipairs(args[i]) do
  352. values[#values + 1] = field .. " " .. op .. " " ..
  353. escape[dao.meta[field].type](value, dao.driver, dao.model.conn)
  354. end
  355. return "(" .. table.concat(values, " or ") .. ")"
  356. else
  357. return field .. " " .. op .. " " ..
  358. escape[dao.meta[field].type](args[i], dao.driver, dao.model.conn)
  359. end
  360. end)
  361. end
  362. local order = ""
  363. if args.order then order = " order by " .. args.order end
  364. local field_list, table_list, select, limit, offset
  365. if args.distinct then select = "select distinct " else select = "select " end
  366. if tonumber(args.count) then limit = " limit " .. tonumber(args.count) else limit = "" end
  367. if tonumber(args.offset) then offset = " offset " .. tonumber(args.offset) else offset = "" end
  368. if args.inject then
  369. local inject_condition
  370. field_list, table_list, inject_condition = build_inject(args.fields, args.inject,
  371. dao)
  372. if condition == "" then
  373. condition = " where " .. inject_condition
  374. else
  375. condition = condition .. " and " .. inject_condition
  376. end
  377. else
  378. if args.fields then
  379. field_list = table.concat(args.fields, ", ")
  380. else
  381. field_list = "*"
  382. end
  383. table_list = table.concat({ dao.table_name, unpack(args.from or {}) }, ", ")
  384. end
  385. local sql = select .. field_list .. " from " .. table_list ..
  386. condition .. order .. limit .. offset
  387. if dao.model.logging then log_query(sql) end
  388. return sql
  389. end
  390. function dao_methods.find_first(dao, condition, args)
  391. return fetch_row(dao, build_query(dao, condition, args))
  392. end
  393. function dao_methods.find_all(dao, condition, args)
  394. return fetch_rows(dao, build_query(dao, condition, args),
  395. (args and args.count) or (condition and condition.count))
  396. end
  397. function dao_methods.new(dao, row)
  398. row = row or {}
  399. setmetatable(row, { __index = dao })
  400. return row
  401. end
  402. local function update(row)
  403. local row_escaped = escape_values(row)
  404. local updates = {}
  405. if row.meta["updated_at"] then
  406. local now = os.time()
  407. row.updated_at = now
  408. row_escaped.updated_at = escape.datetime(now, row.driver)
  409. end
  410. for k, v in pairs(row_escaped) do
  411. table.insert(updates, k .. "=" .. v)
  412. end
  413. local sql = "update " .. row.table_name .. " set " ..
  414. table.concat(updates, ", ") .. " where id = " .. row.id
  415. if row.model.logging then log_query(sql) end
  416. local ok, err = row.model.conn:execute(sql)
  417. if not ok then error(err) end
  418. end
  419. local function insert(row)
  420. local row_escaped = escape_values(row)
  421. local now = os.time()
  422. if row.meta["created_at"] then
  423. row.created_at = row.created_at or now
  424. row_escaped.created_at = escape.datetime(now, row.driver)
  425. end
  426. if row.meta["updated_at"] then
  427. row.updated_at = row.updated_at or now
  428. row_escaped.updated_at = escape.datetime(now, row.driver)
  429. end
  430. local columns, values = {}, {}
  431. for k, v in pairs(row_escaped) do
  432. if row.driver ~= "postgres" or k ~= "id" and v ~= "NULL" then
  433. table.insert(columns, k)
  434. table.insert(values, v)
  435. end
  436. end
  437. local sql = "insert into " .. row.table_name ..
  438. " (" .. table.concat(columns, ", ") .. ") values (" ..
  439. table.concat(values, ", ") .. ")"
  440. if row.driver == "postgres" then sql = sql .. " returning id" end
  441. if row.model.logging then log_query(sql) end
  442. local ok, err = row.model.conn:execute(sql)
  443. if ok then
  444. if row.driver ~= "postgres" then
  445. row.id = row.id or row.model.conn:getlastautoid()
  446. else
  447. row.id = row.id or tonumber( ok:fetch() )
  448. ok:close()
  449. end
  450. else
  451. error(err)
  452. end
  453. end
  454. function dao_methods.save(row, force_insert)
  455. if row.id and (not force_insert) then
  456. update(row)
  457. else
  458. insert(row)
  459. end
  460. end
  461. function dao_methods.delete(row)
  462. if row.id then
  463. local sql = "delete from " .. row.table_name .. " where id = " .. row.id
  464. if row.model.logging then log_query(sql) end
  465. local ok, err = row.model.conn:execute(sql)
  466. if ok then row.id = nil else error(err) end
  467. end
  468. end