/_unsorted/experimental/orm.py
Python | 336 lines | 320 code | 1 blank | 15 comment | 2 complexity | 6602744726047c512f33b65c9201386f MD5 | raw file
Possible License(s): BSD-3-Clause
- """
- We'll start out with a very specific approach and generalize from there.
- """
from collections import namedtuple
from collections.abc import Mapping, MutableMapping, MutableSequence
import sqlite3
# Alias for the connection base class; DBMeta derives from its metaclass
# (``type(BASE_DB)``, which is ``type`` for sqlite3.Connection).
# NOTE(review): DB subclasses sqlite3.Connection directly rather than
# BASE_DB, so rebinding this alias alone does not change DB's base class.
BASE_DB = sqlite3.Connection
class RowView(MutableMapping):
    """Live, dict-like view of a single row of an SQLite table.

    Keys are the table's column names.  Every read (except ``id``, which
    is cached at construction) and every write goes straight to the
    database.  The ``id`` column is read-only and columns cannot be
    deleted.

    Args:
        db: an open ``sqlite3.Connection``; its row_factory does not matter.
        table: name of the table holding the row.
        id: value used to locate the row.
        key: column compared against ``id`` to locate the row
            (defaults to ``"id"``).

    Raises:
        KeyError: if no row has ``key == id``.
        ValueError: if the table has no ``id`` column.
    """

    @staticmethod
    def _quote(identifier):
        """Quote a table/column name for direct inclusion in SQL."""
        # Only values can be bound as ``?`` parameters, never identifiers,
        # so names are interpolated with any embedded quotes doubled.
        return '"{}"'.format(str(identifier).replace('"', '""'))

    def _execute(self, query, params):
        """Execute ``query`` on a cursor that always yields plain tuples."""
        cursor = self._db.cursor()
        # Ignore whatever row_factory the connection uses.
        cursor.row_factory = None
        return cursor.execute(query, params)

    def __init__(self, db, table, id, key="id"):
        self._db = db
        self.table = table
        query = "SELECT * FROM {} WHERE {} = ?".format(
            self._quote(table), self._quote(key))
        cursor = self._execute(query, (id,))
        row = cursor.fetchone()
        if row is None:
            raise KeyError(id)
        self.COLUMNS = tuple(desc[0] for desc in cursor.description)
        if "id" not in self.COLUMNS:
            raise ValueError("table %r has no 'id' column" % (table,))
        self.id = row[self.COLUMNS.index("id")]

    def _raw(self):
        """Fetch the row afresh as a ``{column: value}`` dict."""
        query = "SELECT * FROM {} WHERE id = ?".format(self._quote(self.table))
        cursor = self._execute(query, (self.id,))
        row = cursor.fetchone()
        if row is None:
            # The row was deleted after this view was created.
            raise KeyError(self.id)
        return dict(zip((desc[0] for desc in cursor.description), row))

    def _check(self, key):
        """Raise unless ``key`` is a writable column of this row."""
        if key not in self.COLUMNS:
            raise KeyError(key)
        if key == "id":
            raise TypeError("'id' is read-only")

    def __iter__(self):
        # Mappings iterate over their keys, i.e. the column names.
        return iter(self.COLUMNS)

    def __contains__(self, key):
        return key in self.COLUMNS

    def __len__(self):
        return len(self.COLUMNS)

    def __getitem__(self, key):
        if key == "id":
            return self.id
        if key not in self.COLUMNS:
            raise KeyError(key)
        return self._raw()[key]

    def __setitem__(self, key, value):
        self._check(key)
        query = "UPDATE {} SET {} = ? WHERE id = ?".format(
            self._quote(self.table), self._quote(key))
        self._db.execute(query, (value, self.id))

    def __delitem__(self, key):
        self._check(key)
        raise TypeError("Currently unsupported")
class CSVItem(MutableSequence):
    """A list-like view over a comma-separated column of a row.

    The values live as a single ``"a,b,c"`` string in ``row[column]``.
    Every read re-splits that string, and every write joins the list back
    together and stores it in the row.  The first value is considered the
    primary one and is exposed as ``primary``.

    Args:
        row: mapping-like row supporting ``row[key]`` reads and writes
            (e.g. a RowView); must have an ``"id"`` entry.
        column: name of the column holding the comma-separated values.
    """

    # Value ``primary`` returns when the column holds no values.
    DEFAULT = None

    def __init__(self, row, column):
        self._row = row
        self.id = row["id"]
        self.column = column

    def _raw(self):
        """Return the column's current values as a fresh list."""
        value = self._row[self.column]
        # An empty or NULL column means "no values"; "".split(",") would
        # otherwise yield [""] (one empty value).
        if value is None or value == "":
            return []
        return value.split(",")

    def _store(self, values):
        """Write ``values`` back to the row as a comma-separated string."""
        self._row[self.column] = ",".join(values)

    class PrimaryType(object):
        """Descriptor for the first value, or the class ``DEFAULT`` if empty."""

        def __get__(self, obj, cls):
            if obj is None:
                return self
            try:
                return obj[0]
            except IndexError:
                return cls.DEFAULT

        def __set__(self, obj, value):
            # Replace the first value.  An empty sequence has nothing to
            # replace, so insert instead.  XXX always insert at 0 instead?
            if len(obj):
                obj[0] = value
            else:
                obj.insert(0, value)

    primary = PrimaryType()
    del PrimaryType

    def __iter__(self):
        return iter(self._raw())

    def __contains__(self, value):
        return value in self._raw()

    def __len__(self):
        return len(self._raw())

    def __getitem__(self, index):
        return self._raw()[index]

    def __setitem__(self, index, value):
        values = self._raw()
        values[index] = value
        self._store(values)

    def __delitem__(self, index):
        values = self._raw()
        del values[index]
        self._store(values)

    def insert(self, index, value):
        """Insert ``value`` before ``index`` (required by MutableSequence)."""
        values = self._raw()
        values.insert(index, value)
        self._store(values)
class Row(sqlite3.Row):
    """``sqlite3.Row`` with attribute access and write-back to the database.

    ``row.col`` reads ``row["col"]``.  Assigning ``row.col = v`` or
    ``row["col"] = v`` issues an ``UPDATE`` through the connection the row
    was fetched from.  The row object itself is an immutable snapshot and
    is NOT refreshed by the update.

    Writes need a ``TABLE`` class attribute, which Row itself does not
    define: subclasses must supply an object with ``name`` and ``primary``
    (primary-key column) attributes.  Usable directly as a row_factory.
    """

    @staticmethod
    def _quote(identifier):
        """Quote a table/column name for direct inclusion in SQL."""
        # Identifiers cannot be bound as parameters, only values can.
        return '"{}"'.format(str(identifier).replace('"', '""'))

    def __new__(cls, cursor, row):
        obj = super().__new__(cls, cursor, row)
        # Remember the originating connection for __setitem__.
        # object.__setattr__ bypasses our own __setattr__, which would
        # otherwise try to UPDATE a column named "conn".
        object.__setattr__(obj, "conn", cursor.connection)
        return obj

    def __setitem__(self, key, value):
        table = self.TABLE
        idkey = table.primary
        idvalue = self[idkey]
        query = "UPDATE {} SET {} = ? WHERE {} = ?".format(
            self._quote(table.name), self._quote(key), self._quote(idkey))
        self.conn.execute(query, (value, idvalue))

    def __getattr__(self, name):
        # Only called when normal lookup fails.  Convert a missing column
        # into AttributeError so hasattr(), getattr(..., default) and copy
        # behave correctly.
        try:
            return self[name]
        except (IndexError, KeyError):
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        self[name] = value
class ColumnDefinition(namedtuple("BaseColumn", "name type constraints")):
    """Immutable ``(name, type, constraints)`` description of a column.

    When bound as an sqlite3 query parameter, the column adapts to its name.
    """

    def __conform__(self, protocol):
        # sqlite3 asks unknown objects to adapt themselves; answer only its
        # PrepareProtocol and decline anything else with None.
        return self.name if protocol is sqlite3.PrepareProtocol else None
class TableDefinition(namedtuple("BaseTable", "name primary columns")):
    """Immutable description of a table: name, primary-key column, columns.

    Note that the constructor argument order ``(name, columns, primary)``
    differs from the stored field order ``(name, primary, columns)``.

    Args:
        name: table name.
        columns: the table's columns, either as column-name strings or as
            objects with a ``name`` attribute (e.g. ColumnDefinition).
        primary: name of the primary-key column (default ``"id"``).

    Raises:
        TypeError: if ``primary`` names none of ``columns``.
    """

    def __new__(cls, name, columns, primary="id"):
        # Compare against column *names*: a bare string never equals a
        # ColumnDefinition tuple, so a plain ``in columns`` always failed
        # for column objects.
        names = [getattr(column, "name", column) for column in columns]
        if primary not in names:
            raise TypeError("%s not in %s" % (primary, columns))
        # Zero-argument super(): super(cls, cls) re-enters this __new__
        # (with shuffled arguments) whenever the class is subclassed.
        return super().__new__(cls, name, primary, columns)

    def __conform__(self, protocol):
        # When bound as an sqlite3 parameter, adapt to the table name.
        if protocol is sqlite3.PrepareProtocol:
            return self.name
class TableMeta(type):
    """Metaclass meant to hold alternate constructors for table classes.

    NOTE(review): Table assigns ``__metaclass__ = TableMeta``, which only
    has an effect on Python 2.  On Python 3 (this module imports
    ``collections.abc``) Table is not created with this metaclass, so
    ``Table.from_sql`` is not reachable through it.
    """
    def from_sql(cls, conn, sql):
        """Placeholder constructor; the body is ``...``, so it returns None.

        DB._fresh_tables calls it as ``tableclass.from_sql(db, sql)`` with
        the ``sql`` column of ``sqlite_master``.  Presumably it is meant to
        build a table object from that CREATE statement — TODO confirm.
        """
        ...
class Table(Mapping):
    """Mapping view of one SQLite table, keyed by its primary-key column.

    ``table[key]`` returns the row whose primary-key column equals ``key``
    (built with ``ROW_CLASS``).  Iteration yields primary-key values, as
    the Mapping protocol requires; ``all()`` returns the full rows.

    Args:
        conn: the owning connection.  It must provide ``execute`` and a
            ``using_rowclass(cls)`` context manager that temporarily swaps
            the row factory (as DB does).
        definition: object with ``name``, ``primary`` and ``columns``
            attributes (e.g. a TableDefinition).
    """

    # NOTE(review): ``__metaclass__`` only has an effect on Python 2.  On
    # Python 3 it is an inert attribute and Table's metaclass is Mapping's
    # ABCMeta.  Writing ``metaclass=TableMeta`` would raise a metaclass
    # conflict unless TableMeta also derives from ABCMeta.
    __metaclass__ = TableMeta
    ROW_CLASS = Row

    def __init__(self, conn, definition):
        self.conn = conn
        self.definition = definition

    @staticmethod
    def _quote(identifier):
        """Quote a table/column name for direct inclusion in SQL."""
        # Only values can be bound as ``?`` parameters, never identifiers,
        # so names are interpolated with any embedded quotes doubled.
        return '"{}"'.format(str(identifier).replace('"', '""'))

    @property
    def name(self):
        """The table's name, taken from the definition."""
        return self.definition.name

    @property
    def primary(self):
        """Name of the primary-key column, taken from the definition."""
        return self.definition.primary

    @property
    def columns(self):
        """The table's columns, taken from the definition."""
        return self.definition.columns

    def _select(self, query, values):
        """Run ``query`` and return all rows built with ``ROW_CLASS``."""
        with self.conn.using_rowclass(self.ROW_CLASS):
            return self.conn.execute(query, values).fetchall()

    def keys(self):
        """Return a generator over every row's primary-key value."""
        primary = self.primary
        query = "SELECT {} FROM {}".format(
            self._quote(primary), self._quote(self.name))
        return (row[primary] for row in self._select(query, ()))

    def one(self, *primary, **pair):
        """Return the first row matching a single-column lookup.

        Call as ``one(value)`` to match on the primary-key column, or as
        ``one(column=value)`` to match on another column.

        Returns:
            The matching row (built with ``ROW_CLASS``), or None when no row
            matches or no lookup value was given.

        Raises:
            TypeError: if more than one lookup value is given.
        """
        total = len(primary) + len(pair)
        if total == 0:
            return None
        if total > 1:
            msg = "only one column may be queried, received %s"
            raise TypeError(msg % total)
        if pair:
            # dict views are not indexable on Python 3.
            key, value = next(iter(pair.items()))
        else:
            key, value = self.primary, primary[0]
        query = "SELECT * FROM {} WHERE {} = ?".format(
            self._quote(self.name), self._quote(key))
        with self.conn.using_rowclass(self.ROW_CLASS):
            return self.conn.execute(query, (value,)).fetchone()

    def all(self):
        """Return every row of the table, built with ``ROW_CLASS``."""
        return self._select("SELECT * FROM {}".format(self._quote(self.name)), ())

    def __iter__(self):
        # Mapping's keys()/items()/values()/== assume __iter__ yields keys,
        # so iterate primary keys rather than whole rows.
        return iter(self.keys())

    def __len__(self):
        query = "SELECT count(*) FROM {}".format(self._quote(self.name))
        return self.conn.execute(query).fetchone()[0]

    def __contains__(self, name):
        # Same lookup as __getitem__, so ``k in t`` holds exactly when
        # ``t[k]`` succeeds.
        return self.one(name) is not None

    def __getitem__(self, key):
        row = self.one(key)
        if row is None:
            raise KeyError(key)
        return row

    def add(self, row):
        """Placeholder, presumably for inserting ``row``; currently a no-op."""
        ...
class RowClassContext:
    """Context manager that temporarily replaces ``db.row_factory``.

    On entry the current factory is saved as ``oldclass`` and ``rowclass``
    is installed.  On exit, normal or exceptional, the saved factory is
    restored.  Exceptions are never suppressed, and ``with ... as x``
    binds None.
    """

    def __init__(self, db, rowclass):
        self.db, self.cls = db, rowclass

    def __enter__(self):
        connection = self.db
        self.oldclass = connection.row_factory
        connection.row_factory = self.cls

    def __exit__(self, *args, **kwargs):
        self.db.row_factory = self.oldclass
class TableClassContext:
    """Context manager that temporarily replaces ``db.table_factory``.

    On entry the current factory is saved as ``oldclass`` and
    ``tableclass`` is installed.  On exit, normal or exceptional, the saved
    factory is restored.  Exceptions are never suppressed, and
    ``with ... as x`` binds None.
    """

    def __init__(self, db, tableclass):
        self.db, self.cls = db, tableclass

    def __enter__(self):
        target = self.db
        self.oldclass = target.table_factory
        target.table_factory = self.cls

    def __exit__(self, *args, **kwargs):
        self.db.table_factory = self.oldclass
class DBMeta(type(BASE_DB)):
    """Metaclass for DB: a default connection location and a class builder.

    NOTE(review): DB assigns ``__metaclass__ = DBMeta``, which Python 3
    ignores.  Writing ``metaclass=DBMeta`` there would also conflict with
    Mapping's ABCMeta, so these methods only apply to classes actually
    created by DBMeta.
    """

    def __call__(cls, location=None, **kwargs):
        """Open a connection of class ``cls``.

        Args:
            location: database path; defaults to ``cls.LOCATION``.
            **kwargs: extra keyword arguments for the connection constructor.
                ``isolation_level`` defaults to None (sqlite3 autocommit).
        """
        if location is None:
            location = cls.LOCATION
        kwargs.setdefault("isolation_level", None)
        # super() resolves through the *metaclass* MRO (ending at type),
        # i.e. normal instance construction.  super(cls, cls) searched the
        # class's own bases instead, reaching sqlite3.Connection.__call__.
        return super().__call__(location, **kwargs)

    def from_class(cls, definition):
        """Class decorator to turn a bare class into a DB subclass.

        Ignores bases and all "private" (underscore-prefixed) attributes,
        except that ``__name__``, ``__module__`` and ``__doc__`` are carried
        over to the new class.  LOCATION and ROW_CLASS are passed through as
        class attributes.  Every other attribute is treated as a
        TableDefinition and collected, by attribute name, into the new
        class's ``TABLES`` dict.

        Returns:
            A new subclass of ``cls`` built with ``cls``'s metaclass.
        """
        namespace = {
            "__module__": definition.__module__,
            "__doc__": definition.__doc__,
        }
        tables = {}
        # Build fresh dicts instead of deleting from the one being iterated.
        for attr, value in vars(definition).items():
            if attr.startswith("_"):
                continue
            if attr in ("LOCATION", "ROW_CLASS"):
                namespace[attr] = value
            else:
                tables[attr] = value
        namespace["TABLES"] = tables
        # __name__ lives on the type object, not in its __dict__.
        return type(cls)(definition.__name__, (cls,), namespace)
class DB(sqlite3.Connection, Mapping):
    """A wrapper providing defaults to sqlite3.Connection.

    The connection is also a read-only mapping of table name -> table
    object (built with ``table_factory``).  The registry is built by
    :meth:`sync` during construction.
    """

    # NOTE(review): ``__metaclass__`` has no effect on Python 3, and
    # ``metaclass=DBMeta`` would conflict with Mapping's ABCMeta.
    __metaclass__ = DBMeta
    LOCATION = ":memory:"
    ROW_CLASS = Row
    TABLE_CLASS = Table
    # Optional {table name: definition} schema that sync() enforces.
    TABLES = None

    def __init__(self, database, *args, **kwargs):
        # Zero-argument super(): super(type(self), self) recurses forever
        # once DB is subclassed.
        super().__init__(database, *args, **kwargs)
        self.row_factory = self.ROW_CLASS
        self.table_factory = self.TABLE_CLASS
        self.location = database

        self.sync()

    def __repr__(self):
        return "%s(%r)" % (type(self).__name__, self.location)

    def __iter__(self):
        return iter(self._tables)

    def __len__(self):
        return len(self._tables)

    def __contains__(self, key):
        return key in self._tables

    def __getitem__(self, key):
        return self._tables[key]

    def keys(self):
        """Return the registered table names as a tuple."""
        return tuple(self._tables)

    def using_tableclass(self, cls):
        """Context manager temporarily replacing ``table_factory`` with cls."""
        return TableClassContext(self, cls)

    def using_rowclass(self, cls):
        """Context manager temporarily replacing ``row_factory`` with cls."""
        return RowClassContext(self, cls)

    def _fresh_tables(self, tableclass=None):
        """Yield a table object, via ``from_sql``, for every user table."""
        if tableclass is None:
            tableclass = self.table_factory
        with self.using_rowclass(sqlite3.Row):
            # sqlite_master also lists indexes, views and triggers; only
            # table definitions should become table objects.
            rows = self.execute(
                "SELECT * FROM sqlite_master WHERE type = 'table'").fetchall()
        for row in rows:
            # Skip SQLite's internal tables (sqlite_*).
            if row["name"].startswith("sqlite"):
                continue
            yield tableclass.from_sql(self, row["sql"])

    def sync(self):
        """Rebuild the table registry from the database.

        With ``TABLES`` None, every existing table is registered as-is.
        Otherwise each existing table must be declared in ``TABLES`` with an
        equal definition.  Declared tables missing from the database are
        built with ``table_factory.from_definition``.

        Raises:
            TypeError: an existing table is undeclared or out of sync.
        """
        cls = self.table_factory
        tables = self._fresh_tables(cls)
        if self.TABLES is None:
            self._tables = {t.name: t for t in tables}
            return
        self._tables = {}
        pending = list(self.TABLES)
        for table in tables:
            if table.name not in self.TABLES:
                msg = "Table %s exists, but is not defined."
                raise TypeError(msg % table.name)
            if table.definition != self.TABLES[table.name]:
                msg = "Table %s out of sync"
                raise TypeError(msg % table.name)
            # list.pop() takes an index, not a value.
            pending.remove(table.name)
            self._tables[table.name] = table
        for name in pending:
            self._tables[name] = cls.from_definition(self, self.TABLES[name])