PageRenderTime 72ms CodeModel.GetById 23ms RepoModel.GetById 1ms app.codeStats 0ms

/IPython/parallel/controller/sqlitedb.py

https://github.com/cboos/ipython
Python | 408 lines | 360 code | 17 blank | 31 comment | 19 complexity | 28af4883b7fd58767f52a11e1c9428ca MD5 | raw file
  1. """A TaskRecord backend using sqlite3
  2. Authors:
  3. * Min RK
  4. """
  5. #-----------------------------------------------------------------------------
  6. # Copyright (C) 2011 The IPython Development Team
  7. #
  8. # Distributed under the terms of the BSD License. The full license is in
  9. # the file COPYING, distributed as part of this software.
  10. #-----------------------------------------------------------------------------
  11. import json
  12. import os
  13. import cPickle as pickle
  14. from datetime import datetime
  15. try:
  16. import sqlite3
  17. except ImportError:
  18. sqlite3 = None
  19. from zmq.eventloop import ioloop
  20. from IPython.utils.traitlets import Unicode, Instance, List, Dict
  21. from .dictdb import BaseDB
  22. from IPython.utils.jsonutil import date_default, extract_dates, squash_dates
  23. #-----------------------------------------------------------------------------
  24. # SQLite operators, adapters, and converters
  25. #-----------------------------------------------------------------------------
  26. try:
  27. buffer
  28. except NameError:
  29. # py3k
  30. buffer = memoryview
  31. operators = {
  32. '$lt' : "<",
  33. '$gt' : ">",
  34. # null is handled weird with ==,!=
  35. '$eq' : "=",
  36. '$ne' : "!=",
  37. '$lte': "<=",
  38. '$gte': ">=",
  39. '$in' : ('=', ' OR '),
  40. '$nin': ('!=', ' AND '),
  41. # '$all': None,
  42. # '$mod': None,
  43. # '$exists' : None
  44. }
  45. null_operators = {
  46. '=' : "IS NULL",
  47. '!=' : "IS NOT NULL",
  48. }
  49. def _adapt_dict(d):
  50. return json.dumps(d, default=date_default)
  51. def _convert_dict(ds):
  52. if ds is None:
  53. return ds
  54. else:
  55. if isinstance(ds, bytes):
  56. # If I understand the sqlite doc correctly, this will always be utf8
  57. ds = ds.decode('utf8')
  58. return extract_dates(json.loads(ds))
  59. def _adapt_bufs(bufs):
  60. # this is *horrible*
  61. # copy buffers into single list and pickle it:
  62. if bufs and isinstance(bufs[0], (bytes, buffer)):
  63. return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
  64. elif bufs:
  65. return bufs
  66. else:
  67. return None
  68. def _convert_bufs(bs):
  69. if bs is None:
  70. return []
  71. else:
  72. return pickle.loads(bytes(bs))
  73. #-----------------------------------------------------------------------------
  74. # SQLiteDB class
  75. #-----------------------------------------------------------------------------
  76. class SQLiteDB(BaseDB):
  77. """SQLite3 TaskRecord backend."""
  78. filename = Unicode('tasks.db', config=True,
  79. help="""The filename of the sqlite task database. [default: 'tasks.db']""")
  80. location = Unicode('', config=True,
  81. help="""The directory containing the sqlite task database. The default
  82. is to use the cluster_dir location.""")
  83. table = Unicode("", config=True,
  84. help="""The SQLite Table to use for storing tasks for this session. If unspecified,
  85. a new table will be created with the Hub's IDENT. Specifying the table will result
  86. in tasks from previous sessions being available via Clients' db_query and
  87. get_result methods.""")
  88. if sqlite3 is not None:
  89. _db = Instance('sqlite3.Connection')
  90. else:
  91. _db = None
  92. # the ordered list of column names
  93. _keys = List(['msg_id' ,
  94. 'header' ,
  95. 'content',
  96. 'buffers',
  97. 'submitted',
  98. 'client_uuid' ,
  99. 'engine_uuid' ,
  100. 'started',
  101. 'completed',
  102. 'resubmitted',
  103. 'result_header' ,
  104. 'result_content' ,
  105. 'result_buffers' ,
  106. 'queue' ,
  107. 'pyin' ,
  108. 'pyout',
  109. 'pyerr',
  110. 'stdout',
  111. 'stderr',
  112. ])
  113. # sqlite datatypes for checking that db is current format
  114. _types = Dict({'msg_id' : 'text' ,
  115. 'header' : 'dict text',
  116. 'content' : 'dict text',
  117. 'buffers' : 'bufs blob',
  118. 'submitted' : 'timestamp',
  119. 'client_uuid' : 'text',
  120. 'engine_uuid' : 'text',
  121. 'started' : 'timestamp',
  122. 'completed' : 'timestamp',
  123. 'resubmitted' : 'timestamp',
  124. 'result_header' : 'dict text',
  125. 'result_content' : 'dict text',
  126. 'result_buffers' : 'bufs blob',
  127. 'queue' : 'text',
  128. 'pyin' : 'text',
  129. 'pyout' : 'text',
  130. 'pyerr' : 'text',
  131. 'stdout' : 'text',
  132. 'stderr' : 'text',
  133. })
  134. def __init__(self, **kwargs):
  135. super(SQLiteDB, self).__init__(**kwargs)
  136. if sqlite3 is None:
  137. raise ImportError("SQLiteDB requires sqlite3")
  138. if not self.table:
  139. # use session, and prefix _, since starting with # is illegal
  140. self.table = '_'+self.session.replace('-','_')
  141. if not self.location:
  142. # get current profile
  143. from IPython.core.application import BaseIPythonApplication
  144. if BaseIPythonApplication.initialized():
  145. app = BaseIPythonApplication.instance()
  146. if app.profile_dir is not None:
  147. self.location = app.profile_dir.location
  148. else:
  149. self.location = u'.'
  150. else:
  151. self.location = u'.'
  152. self._init_db()
  153. # register db commit as 2s periodic callback
  154. # to prevent clogging pipes
  155. # assumes we are being run in a zmq ioloop app
  156. loop = ioloop.IOLoop.instance()
  157. pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
  158. pc.start()
  159. def _defaults(self, keys=None):
  160. """create an empty record"""
  161. d = {}
  162. keys = self._keys if keys is None else keys
  163. for key in keys:
  164. d[key] = None
  165. return d
  166. def _check_table(self):
  167. """Ensure that an incorrect table doesn't exist
  168. If a bad (old) table does exist, return False
  169. """
  170. cursor = self._db.execute("PRAGMA table_info(%s)"%self.table)
  171. lines = cursor.fetchall()
  172. if not lines:
  173. # table does not exist
  174. return True
  175. types = {}
  176. keys = []
  177. for line in lines:
  178. keys.append(line[1])
  179. types[line[1]] = line[2]
  180. if self._keys != keys:
  181. # key mismatch
  182. self.log.warn('keys mismatch')
  183. return False
  184. for key in self._keys:
  185. if types[key] != self._types[key]:
  186. self.log.warn(
  187. 'type mismatch: %s: %s != %s'%(key,types[key],self._types[key])
  188. )
  189. return False
  190. return True
  191. def _init_db(self):
  192. """Connect to the database and get new session number."""
  193. # register adapters
  194. sqlite3.register_adapter(dict, _adapt_dict)
  195. sqlite3.register_converter('dict', _convert_dict)
  196. sqlite3.register_adapter(list, _adapt_bufs)
  197. sqlite3.register_converter('bufs', _convert_bufs)
  198. # connect to the db
  199. dbfile = os.path.join(self.location, self.filename)
  200. self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
  201. # isolation_level = None)#,
  202. cached_statements=64)
  203. # print dir(self._db)
  204. first_table = self.table
  205. i=0
  206. while not self._check_table():
  207. i+=1
  208. self.table = first_table+'_%i'%i
  209. self.log.warn(
  210. "Table %s exists and doesn't match db format, trying %s"%
  211. (first_table,self.table)
  212. )
  213. self._db.execute("""CREATE TABLE IF NOT EXISTS %s
  214. (msg_id text PRIMARY KEY,
  215. header dict text,
  216. content dict text,
  217. buffers bufs blob,
  218. submitted timestamp,
  219. client_uuid text,
  220. engine_uuid text,
  221. started timestamp,
  222. completed timestamp,
  223. resubmitted timestamp,
  224. result_header dict text,
  225. result_content dict text,
  226. result_buffers bufs blob,
  227. queue text,
  228. pyin text,
  229. pyout text,
  230. pyerr text,
  231. stdout text,
  232. stderr text)
  233. """%self.table)
  234. self._db.commit()
  235. def _dict_to_list(self, d):
  236. """turn a mongodb-style record dict into a list."""
  237. return [ d[key] for key in self._keys ]
  238. def _list_to_dict(self, line, keys=None):
  239. """Inverse of dict_to_list"""
  240. keys = self._keys if keys is None else keys
  241. d = self._defaults(keys)
  242. for key,value in zip(keys, line):
  243. d[key] = value
  244. return d
  245. def _render_expression(self, check):
  246. """Turn a mongodb-style search dict into an SQL query."""
  247. expressions = []
  248. args = []
  249. skeys = set(check.keys())
  250. skeys.difference_update(set(self._keys))
  251. skeys.difference_update(set(['buffers', 'result_buffers']))
  252. if skeys:
  253. raise KeyError("Illegal testing key(s): %s"%skeys)
  254. for name,sub_check in check.iteritems():
  255. if isinstance(sub_check, dict):
  256. for test,value in sub_check.iteritems():
  257. try:
  258. op = operators[test]
  259. except KeyError:
  260. raise KeyError("Unsupported operator: %r"%test)
  261. if isinstance(op, tuple):
  262. op, join = op
  263. if value is None and op in null_operators:
  264. expr = "%s %s"%null_operators[op]
  265. else:
  266. expr = "%s %s ?"%(name, op)
  267. if isinstance(value, (tuple,list)):
  268. if op in null_operators and any([v is None for v in value]):
  269. # equality tests don't work with NULL
  270. raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
  271. expr = '( %s )'%( join.join([expr]*len(value)) )
  272. args.extend(value)
  273. else:
  274. args.append(value)
  275. expressions.append(expr)
  276. else:
  277. # it's an equality check
  278. if sub_check is None:
  279. expressions.append("%s IS NULL")
  280. else:
  281. expressions.append("%s = ?"%name)
  282. args.append(sub_check)
  283. expr = " AND ".join(expressions)
  284. return expr, args
  285. def add_record(self, msg_id, rec):
  286. """Add a new Task Record, by msg_id."""
  287. d = self._defaults()
  288. d.update(rec)
  289. d['msg_id'] = msg_id
  290. line = self._dict_to_list(d)
  291. tups = '(%s)'%(','.join(['?']*len(line)))
  292. self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
  293. # self._db.commit()
  294. def get_record(self, msg_id):
  295. """Get a specific Task Record, by msg_id."""
  296. cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
  297. line = cursor.fetchone()
  298. if line is None:
  299. raise KeyError("No such msg: %r"%msg_id)
  300. return self._list_to_dict(line)
  301. def update_record(self, msg_id, rec):
  302. """Update the data in an existing record."""
  303. query = "UPDATE %s SET "%self.table
  304. sets = []
  305. keys = sorted(rec.keys())
  306. values = []
  307. for key in keys:
  308. sets.append('%s = ?'%key)
  309. values.append(rec[key])
  310. query += ', '.join(sets)
  311. query += ' WHERE msg_id == ?'
  312. values.append(msg_id)
  313. self._db.execute(query, values)
  314. # self._db.commit()
  315. def drop_record(self, msg_id):
  316. """Remove a record from the DB."""
  317. self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
  318. # self._db.commit()
  319. def drop_matching_records(self, check):
  320. """Remove a record from the DB."""
  321. expr,args = self._render_expression(check)
  322. query = "DELETE FROM %s WHERE %s"%(self.table, expr)
  323. self._db.execute(query,args)
  324. # self._db.commit()
  325. def find_records(self, check, keys=None):
  326. """Find records matching a query dict, optionally extracting subset of keys.
  327. Returns list of matching records.
  328. Parameters
  329. ----------
  330. check: dict
  331. mongodb-style query argument
  332. keys: list of strs [optional]
  333. if specified, the subset of keys to extract. msg_id will *always* be
  334. included.
  335. """
  336. if keys:
  337. bad_keys = [ key for key in keys if key not in self._keys ]
  338. if bad_keys:
  339. raise KeyError("Bad record key(s): %s"%bad_keys)
  340. if keys:
  341. # ensure msg_id is present and first:
  342. if 'msg_id' in keys:
  343. keys.remove('msg_id')
  344. keys.insert(0, 'msg_id')
  345. req = ', '.join(keys)
  346. else:
  347. req = '*'
  348. expr,args = self._render_expression(check)
  349. query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
  350. cursor = self._db.execute(query, args)
  351. matches = cursor.fetchall()
  352. records = []
  353. for line in matches:
  354. rec = self._list_to_dict(line, keys)
  355. records.append(rec)
  356. return records
  357. def get_history(self):
  358. """get all msg_ids, ordered by time submitted."""
  359. query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
  360. cursor = self._db.execute(query)
  361. # will be a list of length 1 tuples
  362. return [ tup[0] for tup in cursor.fetchall()]
  363. __all__ = ['SQLiteDB']