/IPython/parallel/controller/sqlitedb.py
Python | 408 lines | 360 code | 17 blank | 31 comment | 19 complexity | 28af4883b7fd58767f52a11e1c9428ca MD5 | raw file
- """A TaskRecord backend using sqlite3
- Authors:
- * Min RK
- """
- #-----------------------------------------------------------------------------
- # Copyright (C) 2011 The IPython Development Team
- #
- # Distributed under the terms of the BSD License. The full license is in
- # the file COPYING, distributed as part of this software.
- #-----------------------------------------------------------------------------
- import json
- import os
- import cPickle as pickle
- from datetime import datetime
- try:
- import sqlite3
- except ImportError:
- sqlite3 = None
- from zmq.eventloop import ioloop
- from IPython.utils.traitlets import Unicode, Instance, List, Dict
- from .dictdb import BaseDB
- from IPython.utils.jsonutil import date_default, extract_dates, squash_dates
- #-----------------------------------------------------------------------------
- # SQLite operators, adapters, and converters
- #-----------------------------------------------------------------------------
- try:
- buffer
- except NameError:
- # py3k
- buffer = memoryview
- operators = {
- '$lt' : "<",
- '$gt' : ">",
- # null is handled weird with ==,!=
- '$eq' : "=",
- '$ne' : "!=",
- '$lte': "<=",
- '$gte': ">=",
- '$in' : ('=', ' OR '),
- '$nin': ('!=', ' AND '),
- # '$all': None,
- # '$mod': None,
- # '$exists' : None
- }
- null_operators = {
- '=' : "IS NULL",
- '!=' : "IS NOT NULL",
- }
- def _adapt_dict(d):
- return json.dumps(d, default=date_default)
- def _convert_dict(ds):
- if ds is None:
- return ds
- else:
- if isinstance(ds, bytes):
- # If I understand the sqlite doc correctly, this will always be utf8
- ds = ds.decode('utf8')
- return extract_dates(json.loads(ds))
- def _adapt_bufs(bufs):
- # this is *horrible*
- # copy buffers into single list and pickle it:
- if bufs and isinstance(bufs[0], (bytes, buffer)):
- return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
- elif bufs:
- return bufs
- else:
- return None
- def _convert_bufs(bs):
- if bs is None:
- return []
- else:
- return pickle.loads(bytes(bs))
- #-----------------------------------------------------------------------------
- # SQLiteDB class
- #-----------------------------------------------------------------------------
- class SQLiteDB(BaseDB):
- """SQLite3 TaskRecord backend."""
- filename = Unicode('tasks.db', config=True,
- help="""The filename of the sqlite task database. [default: 'tasks.db']""")
- location = Unicode('', config=True,
- help="""The directory containing the sqlite task database. The default
- is to use the cluster_dir location.""")
- table = Unicode("", config=True,
- help="""The SQLite Table to use for storing tasks for this session. If unspecified,
- a new table will be created with the Hub's IDENT. Specifying the table will result
- in tasks from previous sessions being available via Clients' db_query and
- get_result methods.""")
- if sqlite3 is not None:
- _db = Instance('sqlite3.Connection')
- else:
- _db = None
- # the ordered list of column names
- _keys = List(['msg_id' ,
- 'header' ,
- 'content',
- 'buffers',
- 'submitted',
- 'client_uuid' ,
- 'engine_uuid' ,
- 'started',
- 'completed',
- 'resubmitted',
- 'result_header' ,
- 'result_content' ,
- 'result_buffers' ,
- 'queue' ,
- 'pyin' ,
- 'pyout',
- 'pyerr',
- 'stdout',
- 'stderr',
- ])
- # sqlite datatypes for checking that db is current format
- _types = Dict({'msg_id' : 'text' ,
- 'header' : 'dict text',
- 'content' : 'dict text',
- 'buffers' : 'bufs blob',
- 'submitted' : 'timestamp',
- 'client_uuid' : 'text',
- 'engine_uuid' : 'text',
- 'started' : 'timestamp',
- 'completed' : 'timestamp',
- 'resubmitted' : 'timestamp',
- 'result_header' : 'dict text',
- 'result_content' : 'dict text',
- 'result_buffers' : 'bufs blob',
- 'queue' : 'text',
- 'pyin' : 'text',
- 'pyout' : 'text',
- 'pyerr' : 'text',
- 'stdout' : 'text',
- 'stderr' : 'text',
- })
- def __init__(self, **kwargs):
- super(SQLiteDB, self).__init__(**kwargs)
- if sqlite3 is None:
- raise ImportError("SQLiteDB requires sqlite3")
- if not self.table:
- # use session, and prefix _, since starting with # is illegal
- self.table = '_'+self.session.replace('-','_')
- if not self.location:
- # get current profile
- from IPython.core.application import BaseIPythonApplication
- if BaseIPythonApplication.initialized():
- app = BaseIPythonApplication.instance()
- if app.profile_dir is not None:
- self.location = app.profile_dir.location
- else:
- self.location = u'.'
- else:
- self.location = u'.'
- self._init_db()
- # register db commit as 2s periodic callback
- # to prevent clogging pipes
- # assumes we are being run in a zmq ioloop app
- loop = ioloop.IOLoop.instance()
- pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
- pc.start()
- def _defaults(self, keys=None):
- """create an empty record"""
- d = {}
- keys = self._keys if keys is None else keys
- for key in keys:
- d[key] = None
- return d
- def _check_table(self):
- """Ensure that an incorrect table doesn't exist
- If a bad (old) table does exist, return False
- """
- cursor = self._db.execute("PRAGMA table_info(%s)"%self.table)
- lines = cursor.fetchall()
- if not lines:
- # table does not exist
- return True
- types = {}
- keys = []
- for line in lines:
- keys.append(line[1])
- types[line[1]] = line[2]
- if self._keys != keys:
- # key mismatch
- self.log.warn('keys mismatch')
- return False
- for key in self._keys:
- if types[key] != self._types[key]:
- self.log.warn(
- 'type mismatch: %s: %s != %s'%(key,types[key],self._types[key])
- )
- return False
- return True
- def _init_db(self):
- """Connect to the database and get new session number."""
- # register adapters
- sqlite3.register_adapter(dict, _adapt_dict)
- sqlite3.register_converter('dict', _convert_dict)
- sqlite3.register_adapter(list, _adapt_bufs)
- sqlite3.register_converter('bufs', _convert_bufs)
- # connect to the db
- dbfile = os.path.join(self.location, self.filename)
- self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
- # isolation_level = None)#,
- cached_statements=64)
- # print dir(self._db)
- first_table = self.table
- i=0
- while not self._check_table():
- i+=1
- self.table = first_table+'_%i'%i
- self.log.warn(
- "Table %s exists and doesn't match db format, trying %s"%
- (first_table,self.table)
- )
- self._db.execute("""CREATE TABLE IF NOT EXISTS %s
- (msg_id text PRIMARY KEY,
- header dict text,
- content dict text,
- buffers bufs blob,
- submitted timestamp,
- client_uuid text,
- engine_uuid text,
- started timestamp,
- completed timestamp,
- resubmitted timestamp,
- result_header dict text,
- result_content dict text,
- result_buffers bufs blob,
- queue text,
- pyin text,
- pyout text,
- pyerr text,
- stdout text,
- stderr text)
- """%self.table)
- self._db.commit()
- def _dict_to_list(self, d):
- """turn a mongodb-style record dict into a list."""
- return [ d[key] for key in self._keys ]
- def _list_to_dict(self, line, keys=None):
- """Inverse of dict_to_list"""
- keys = self._keys if keys is None else keys
- d = self._defaults(keys)
- for key,value in zip(keys, line):
- d[key] = value
- return d
- def _render_expression(self, check):
- """Turn a mongodb-style search dict into an SQL query."""
- expressions = []
- args = []
- skeys = set(check.keys())
- skeys.difference_update(set(self._keys))
- skeys.difference_update(set(['buffers', 'result_buffers']))
- if skeys:
- raise KeyError("Illegal testing key(s): %s"%skeys)
- for name,sub_check in check.iteritems():
- if isinstance(sub_check, dict):
- for test,value in sub_check.iteritems():
- try:
- op = operators[test]
- except KeyError:
- raise KeyError("Unsupported operator: %r"%test)
- if isinstance(op, tuple):
- op, join = op
- if value is None and op in null_operators:
- expr = "%s %s"%null_operators[op]
- else:
- expr = "%s %s ?"%(name, op)
- if isinstance(value, (tuple,list)):
- if op in null_operators and any([v is None for v in value]):
- # equality tests don't work with NULL
- raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
- expr = '( %s )'%( join.join([expr]*len(value)) )
- args.extend(value)
- else:
- args.append(value)
- expressions.append(expr)
- else:
- # it's an equality check
- if sub_check is None:
- expressions.append("%s IS NULL")
- else:
- expressions.append("%s = ?"%name)
- args.append(sub_check)
- expr = " AND ".join(expressions)
- return expr, args
- def add_record(self, msg_id, rec):
- """Add a new Task Record, by msg_id."""
- d = self._defaults()
- d.update(rec)
- d['msg_id'] = msg_id
- line = self._dict_to_list(d)
- tups = '(%s)'%(','.join(['?']*len(line)))
- self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
- # self._db.commit()
- def get_record(self, msg_id):
- """Get a specific Task Record, by msg_id."""
- cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
- line = cursor.fetchone()
- if line is None:
- raise KeyError("No such msg: %r"%msg_id)
- return self._list_to_dict(line)
- def update_record(self, msg_id, rec):
- """Update the data in an existing record."""
- query = "UPDATE %s SET "%self.table
- sets = []
- keys = sorted(rec.keys())
- values = []
- for key in keys:
- sets.append('%s = ?'%key)
- values.append(rec[key])
- query += ', '.join(sets)
- query += ' WHERE msg_id == ?'
- values.append(msg_id)
- self._db.execute(query, values)
- # self._db.commit()
- def drop_record(self, msg_id):
- """Remove a record from the DB."""
- self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
- # self._db.commit()
- def drop_matching_records(self, check):
- """Remove a record from the DB."""
- expr,args = self._render_expression(check)
- query = "DELETE FROM %s WHERE %s"%(self.table, expr)
- self._db.execute(query,args)
- # self._db.commit()
- def find_records(self, check, keys=None):
- """Find records matching a query dict, optionally extracting subset of keys.
- Returns list of matching records.
- Parameters
- ----------
- check: dict
- mongodb-style query argument
- keys: list of strs [optional]
- if specified, the subset of keys to extract. msg_id will *always* be
- included.
- """
- if keys:
- bad_keys = [ key for key in keys if key not in self._keys ]
- if bad_keys:
- raise KeyError("Bad record key(s): %s"%bad_keys)
- if keys:
- # ensure msg_id is present and first:
- if 'msg_id' in keys:
- keys.remove('msg_id')
- keys.insert(0, 'msg_id')
- req = ', '.join(keys)
- else:
- req = '*'
- expr,args = self._render_expression(check)
- query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
- cursor = self._db.execute(query, args)
- matches = cursor.fetchall()
- records = []
- for line in matches:
- rec = self._list_to_dict(line, keys)
- records.append(rec)
- return records
- def get_history(self):
- """get all msg_ids, ordered by time submitted."""
- query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
- cursor = self._db.execute(query)
- # will be a list of length 1 tuples
- return [ tup[0] for tup in cursor.fetchall()]
- __all__ = ['SQLiteDB']