/tests/store/base.py
Python | 6163 lines | 6028 code | 81 blank | 54 comment | 12 complexity | 7aaf102afe85e3b7417bd8e4385980cb MD5 | raw file
Possible License(s): LGPL-2.1
- # -*- coding: utf-8 -*-
- #
- # Copyright (c) 2006, 2007 Canonical
- #
- # Written by Gustavo Niemeyer <gustavo@niemeyer.net>
- #
- # This file is part of Storm Object Relational Mapper.
- #
- # Storm is free software; you can redistribute it and/or modify
- # it under the terms of the GNU Lesser General Public License as
- # published by the Free Software Foundation; either version 2.1 of
- # the License, or (at your option) any later version.
- #
- # Storm is distributed in the hope that it will be useful,
- # but WITHOUT ANY WARRANTY; without even the implied warranty of
- # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- # GNU Lesser General Public License for more details.
- #
- # You should have received a copy of the GNU Lesser General Public License
- # along with this program. If not, see <http://www.gnu.org/licenses/>.
- #
- from cStringIO import StringIO
- import decimal
- import gc
- import operator
- import weakref
- from storm.references import Reference, ReferenceSet, Proxy
- from storm.database import Result
- from storm.properties import Int, Float, RawStr, Unicode, Property, Pickle
- from storm.properties import PropertyPublisherMeta, Decimal
- from storm.variables import PickleVariable
- from storm.expr import (
- Asc, Desc, Select, LeftJoin, SQL, Count, Sum, Avg, And, Or, Eq, Lower, Alias)
- from storm.variables import Variable, UnicodeVariable, IntVariable
- from storm.info import get_obj_info, ClassAlias
- from storm.exceptions import (
- ClosedError, ConnectionBlockedError, FeatureError, LostObjectError,
- NoStoreError, NotFlushedError, NotOneError, OrderLoopError, UnorderedError,
- WrongStoreError)
- from storm.cache import Cache
- from storm.store import AutoReload, EmptyResultSet, Store, ResultSet
- from storm.tracer import debug
- from tests.info import Wrapper
- from tests.helper import TestHelper
class Foo(object):
    """Model for the ``foo`` table: integer primary key plus a unicode title."""
    __storm_table__ = "foo"
    id = Int(primary=True)
    title = Unicode()
class Bar(object):
    """Model for the ``bar`` table; each row points at a Foo via foo_id."""
    __storm_table__ = "bar"
    id = Int(primary=True)
    title = Unicode()
    foo_id = Int()
    # Resolves foo_id to the Foo row whose id matches.
    foo = Reference(foo_id, Foo.id)
class Blob(object):
    """Model for the ``bin`` table, holding a raw byte-string column."""
    __storm_table__ = "bin"
    id = Int(primary=True)
    bin = RawStr()
class Link(object):
    """Model for the ``link`` join table between foo and bar rows."""
    __storm_table__ = "link"
    # Composite primary key: a link is identified by the (foo_id, bar_id) pair.
    __storm_primary__ = "foo_id", "bar_id"
    foo_id = Int()
    bar_id = Int()
class SelfRef(object):
    """Model for ``selfref``, a table whose rows reference rows of itself."""
    __storm_table__ = "selfref"
    id = Int(primary=True)
    title = Unicode()
    selfref_id = Int()
    # This row's selfref_id -> another row's id.
    selfref = Reference(selfref_id, id)
    # Reverse direction (on_remote=True): presumably the row whose
    # selfref_id equals this row's id.
    selfref_on_remote = Reference(id, selfref_id, on_remote=True)
class FooRef(Foo):
    """Foo with a single Reference to the Bar whose foo_id equals Foo.id."""
    bar = Reference(Foo.id, Bar.foo_id)
class FooRefSet(Foo):
    """Foo with an unordered ReferenceSet of all Bars whose foo_id matches."""
    bars = ReferenceSet(Foo.id, Bar.foo_id)
class FooRefSetOrderID(Foo):
    """Like FooRefSet, but the related Bars are ordered by Bar.id."""
    bars = ReferenceSet(Foo.id, Bar.foo_id, order_by=Bar.id)
class FooRefSetOrderTitle(Foo):
    """Like FooRefSet, but the related Bars are ordered by Bar.title."""
    bars = ReferenceSet(Foo.id, Bar.foo_id, order_by=Bar.title)
class FooIndRefSet(Foo):
    """Foo whose Bars are reached indirectly through the ``link`` table."""
    # Foo.id -> Link.foo_id, then Link.bar_id -> Bar.id.
    bars = ReferenceSet(Foo.id, Link.foo_id, Link.bar_id, Bar.id)
class FooIndRefSetOrderID(Foo):
    """Like FooIndRefSet, but the related Bars are ordered by Bar.id."""
    bars = ReferenceSet(Foo.id, Link.foo_id, Link.bar_id, Bar.id,
                        order_by=Bar.id)
class FooIndRefSetOrderTitle(Foo):
    """Like FooIndRefSet, but the related Bars are ordered by Bar.title."""
    bars = ReferenceSet(Foo.id, Link.foo_id, Link.bar_id, Bar.id,
                        order_by=Bar.title)
class FooValue(object):
    """Model for ``foovalue``: two integer values attached to a foo row."""
    __storm_table__ = "foovalue"
    id = Int(primary=True)
    foo_id = Int()
    value1 = Int()
    value2 = Int()
class BarProxy(object):
    """Alternative model for ``bar`` that proxies the referenced Foo's title."""
    __storm_table__ = "bar"
    id = Int(primary=True)
    title = Unicode()
    foo_id = Int()
    foo = Reference(foo_id, Foo.id)
    # Exposes self.foo.title directly as an attribute of this object.
    foo_title = Proxy(foo, Foo.title)
class Money(object):
    """Model for the ``money`` table, whose value column is a Decimal."""
    __storm_table__ = "money"
    id = Int(primary=True)
    value = Decimal()
class DecorateVariable(Variable):
    """Variable that tags each value with the direction it travelled.

    parse_set wraps incoming values as ``from_db(...)`` or ``from_py(...)``
    and parse_get wraps outgoing ones as ``to_db(...)`` or ``to_py(...)``,
    so tests can see exactly which conversion hooks were applied.
    """

    def parse_get(self, value, to_db):
        if to_db:
            target = "db"
        else:
            target = "py"
        return u"to_%s(%s)" % (target, value)

    def parse_set(self, value, from_db):
        source = "db" if from_db else "py"
        return u"from_%s(%s)" % (source, value)
class FooVariable(Foo):
    """Foo whose title is stored in a DecorateVariable, so reads and writes
    of the title pass through that variable's parse_get/parse_set hooks."""
    title = Property(variable_class=DecorateVariable)
class DummyDatabase(object):
    """Database stub for tests that never talk to a real backend."""

    def connect(self, event=None):
        """Pretend to connect; there is no real connection to return."""
        connection = None
        return connection
class StoreCacheTest(TestHelper):
    """White-box checks of the cache a Store is given or creates."""

    def test_wb_custom_cache(self):
        my_cache = Cache(25)
        new_store = Store(DummyDatabase(), cache=my_cache)
        self.assertEquals(new_store._cache, my_cache)

    def test_wb_default_cache_size(self):
        new_store = Store(DummyDatabase())
        default_size = new_store._cache._size
        self.assertEquals(default_size, 1000)
class StoreDatabaseTest(TestHelper):
    """Checks the link from a Store back to the database it was built on."""

    def test_store_has_reference_to_its_database(self):
        db = DummyDatabase()
        self.assertIdentical(Store(db).get_database(), db)
- class StoreTest(object):
def setUp(self):
    """Prepare a database, schema, sample rows and the default store.

    The backend comes from the subclass hooks create_database() and
    create_tables(). Each step below relies on the ones before it
    (for example, the sample data needs the tables), so the order matters.
    """
    self.store = None
    self.stores = []
    self.create_database()
    # Raw connection used by the schema/sample-data helpers below.
    self.connection = self.database.connect()
    # Drop first, presumably so leftovers from an aborted earlier run
    # don't interfere with create_tables().
    self.drop_tables()
    self.create_tables()
    self.create_sample_data()
    self.create_store()
def tearDown(self):
    """Undo setUp(): release the stores, then remove data, tables and DB.

    The raw connection is closed last because drop_tables() still uses it.
    """
    self.drop_store()
    self.drop_sample_data()
    self.drop_tables()
    self.drop_database()
    self.connection.close()
def create_database(self):
    """Hook for backend subclasses: must set ``self.database``."""
    raise NotImplementedError()
def create_tables(self):
    """Hook for backend subclasses: must create the tables these tests use
    (foo, bar, bin, link, money, selfref and foovalue)."""
    raise NotImplementedError()
def create_sample_data(self):
    """Insert the fixture rows every test starts from, then commit once."""
    statements = (
        "INSERT INTO foo (id, title) VALUES (10, 'Title 30')",
        "INSERT INTO foo (id, title) VALUES (20, 'Title 20')",
        "INSERT INTO foo (id, title) VALUES (30, 'Title 10')",
        "INSERT INTO bar (id, foo_id, title) VALUES (100, 10, 'Title 300')",
        "INSERT INTO bar (id, foo_id, title) VALUES (200, 20, 'Title 200')",
        "INSERT INTO bar (id, foo_id, title) VALUES (300, 30, 'Title 100')",
        "INSERT INTO bin (id, bin) VALUES (10, 'Blob 30')",
        "INSERT INTO bin (id, bin) VALUES (20, 'Blob 20')",
        "INSERT INTO bin (id, bin) VALUES (30, 'Blob 10')",
        "INSERT INTO link (foo_id, bar_id) VALUES (10, 100)",
        "INSERT INTO link (foo_id, bar_id) VALUES (10, 200)",
        "INSERT INTO link (foo_id, bar_id) VALUES (10, 300)",
        "INSERT INTO link (foo_id, bar_id) VALUES (20, 100)",
        "INSERT INTO link (foo_id, bar_id) VALUES (20, 200)",
        "INSERT INTO link (foo_id, bar_id) VALUES (30, 300)",
        "INSERT INTO money (id, value) VALUES (10, '12.3455')",
        "INSERT INTO selfref (id, title, selfref_id)"
        " VALUES (15, 'SelfRef 15', NULL)",
        "INSERT INTO selfref (id, title, selfref_id)"
        " VALUES (25, 'SelfRef 25', NULL)",
        "INSERT INTO selfref (id, title, selfref_id)"
        " VALUES (35, 'SelfRef 35', 15)",
        "INSERT INTO foovalue (id, foo_id, value1, value2)"
        " VALUES (1, 10, 2, 1)",
        "INSERT INTO foovalue (id, foo_id, value1, value2)"
        " VALUES (2, 10, 2, 1)",
        "INSERT INTO foovalue (id, foo_id, value1, value2)"
        " VALUES (3, 10, 2, 1)",
        "INSERT INTO foovalue (id, foo_id, value1, value2)"
        " VALUES (4, 10, 2, 2)",
        "INSERT INTO foovalue (id, foo_id, value1, value2)"
        " VALUES (5, 20, 1, 3)",
        "INSERT INTO foovalue (id, foo_id, value1, value2)"
        " VALUES (6, 20, 1, 3)",
        "INSERT INTO foovalue (id, foo_id, value1, value2)"
        " VALUES (7, 20, 1, 4)",
        "INSERT INTO foovalue (id, foo_id, value1, value2)"
        " VALUES (8, 20, 1, 4)",
        "INSERT INTO foovalue (id, foo_id, value1, value2)"
        " VALUES (9, 20, 1, 2)",
    )
    conn = self.connection
    for statement in statements:
        conn.execute(statement)
    conn.commit()
def create_store(self):
    """Build a new Store on ``self.database`` and register it for teardown.

    The first store ever created also becomes ``self.store``.
    """
    new_store = Store(self.database)
    self.stores.append(new_store)
    if self.store is None:
        # Only the first store becomes the default one.
        self.store = new_store
    return new_store
def drop_store(self):
    """Roll back and close every store registered by create_store().

    Each store is closed even if its rollback fails. If a rollback or close
    raises, the remaining stores are still closed (without rolling back) so
    their connections aren't leaked. The original error is then propagated.
    """
    remaining = list(self.stores)
    try:
        while remaining:
            store = remaining.pop(0)
            try:
                store.rollback()
            finally:
                # Closing the store is needed because testcase objects are
                # all instantiated at once, and thus connections are kept
                # open.
                store.close()
    finally:
        # Non-empty only if something above raised: release the rest.
        for store in remaining:
            store.close()
def drop_sample_data(self):
    """Hook for subclasses; nothing to clean up by default."""
def drop_tables(self):
    """Drop every table used by these tests, skipping ones that fail.

    Each DROP is committed on its own. If it fails (e.g. because the table
    doesn't exist yet, as happens when setUp() runs this before
    create_tables()), the transaction is rolled back so the next DROP can
    still run. Only ``Exception`` subclasses are swallowed, so
    KeyboardInterrupt and SystemExit still abort the run.
    """
    for table in ["foo", "bar", "bin", "link", "money", "selfref",
                  "foovalue"]:
        try:
            self.connection.execute("DROP TABLE %s" % table)
            self.connection.commit()
        except Exception:
            self.connection.rollback()
def drop_database(self):
    """Hook for subclasses; no database teardown is needed by default."""
def get_items(self):
    """Return all ``foo`` rows ordered by id, read on the store's own
    connection so that pending changes are not flushed first."""
    # Going through the raw connection bypasses the store's implicit flush.
    return list(self.store._connection.execute("SELECT * FROM foo ORDER BY id"))
def get_committed_items(self):
    """Return all ``foo`` rows ordered by id, as seen by a new connection.

    A separate connection is used so that the store's own uncommitted work
    is (presumably, per the backend's isolation) invisible. The connection
    is closed afterwards so repeated calls don't leak it.
    """
    connection = self.database.connect()
    try:
        result = connection.execute("SELECT * FROM foo ORDER BY id")
        # Materialize before closing: the result may be backed by a cursor.
        return list(result)
    finally:
        connection.close()
def get_cache(self, store):
    """Return *store*'s object cache; there is no public accessor yet."""
    cache = store._cache
    return cache
def test_execute(self):
    """execute() returns a Result, or None when noresult=True."""
    res = self.store.execute("SELECT 1")
    self.assertTrue(isinstance(res, Result))
    self.assertEquals(res.get_one(), (1,))
    no_res = self.store.execute("SELECT 1", noresult=True)
    self.assertEquals(no_res, None)
def test_execute_params(self):
    """Positional parameters passed to execute() are bound into the query."""
    outcome = self.store.execute("SELECT ?", [1])
    self.assertTrue(isinstance(outcome, Result))
    self.assertEquals(outcome.get_one(), (1,))
def test_execute_flushes(self):
    """Pending object changes are visible to a raw execute() query."""
    obj = self.store.get(Foo, 10)
    obj.title = u"New Title"
    row = self.store.execute("SELECT title FROM foo WHERE id=10").get_one()
    self.assertEquals(row, ("New Title",))
def test_close(self):
    """execute() on a closed store raises ClosedError."""
    closed_store = Store(self.database)
    closed_store.close()
    self.assertRaises(ClosedError, closed_store.execute, "SELECT 1")
def test_get(self):
    """get() loads objects by primary key and returns None when absent."""
    for key, expected_title in [(10, "Title 30"), (20, "Title 20")]:
        obj = self.store.get(Foo, key)
        self.assertEquals(obj.id, key)
        self.assertEquals(obj.title, expected_title)
    self.assertEquals(self.store.get(Foo, 40), None)
def test_get_cached(self):
    """Repeated get() calls for one key return the very same object."""
    first = self.store.get(Foo, 10)
    second = self.store.get(Foo, 10)
    self.assertTrue(second is first)
def test_wb_get_cached_doesnt_need_connection(self):
    """A cached object is returned by get() without using the connection.

    The store's connection is temporarily replaced with None, so any
    database access would fail. It is restored in a ``finally`` so that a
    failure here doesn't also break the store.rollback() done in teardown.
    """
    foo = self.store.get(Foo, 10)
    connection = self.store._connection
    self.store._connection = None
    try:
        cached = self.store.get(Foo, 10)
    finally:
        self.store._connection = connection
    self.assertTrue(cached is foo)
def test_cache_cleanup(self):
    """An object dropped from Python is rebuilt fresh by the next get().

    With the strong-reference cache disabled, deleting the only reference
    and collecting garbage lets the store forget the object, so the ad-hoc
    ``taint`` attribute must not be present on the reloaded object.
    """
    # Disable the cache, which holds strong references.
    self.get_cache(self.store).set_size(0)
    foo = self.store.get(Foo, 10)
    foo.taint = True
    del foo
    gc.collect()
    foo = self.store.get(Foo, 10)
    self.assertFalse(getattr(foo, "taint", False))
def test_add_returns_object(self):
    """
    Store.add() hands back the object it was given, so code such as:
        thing = Thing()
        store.add(thing)
        return thing
    can be shortened to:
        return store.add(Thing())
    """
    new_foo = Foo()
    returned = self.store.add(new_foo)
    self.assertEquals(returned, new_foo)
def test_add_and_stop_referencing(self):
    """An added object is still stored even if Python drops every reference."""
    # After adding an object, no references should be needed in
    # python for it still to be added to the database.
    foo = Foo()
    foo.title = u"live"
    self.store.add(foo)
    del foo
    gc.collect()
    # find() flushes pending changes, so the row must be found.
    self.assertTrue(self.store.find(Foo, title=u"live").one())
def test_obj_info_with_deleted_object(self):
    """A collected object is rebuilt by find() even while its obj_info is
    still referenced, and the rebuild runs __storm_loaded__ again."""
    # Let's try to put Storm in trouble by killing the object
    # while still holding a reference to the obj_info.
    # Disable the cache, which holds strong references.
    self.get_cache(self.store).set_size(0)
    class MyFoo(Foo):
        loaded = False
        def __storm_loaded__(self):
            self.loaded = True
    foo = self.store.get(MyFoo, 20)
    foo.tainted = True
    obj_info = get_obj_info(foo)
    del foo
    gc.collect()
    # The obj_info survives, but the object it pointed to is gone.
    self.assertEquals(obj_info.get_obj(), None)
    foo = self.store.find(MyFoo, id=20).one()
    self.assertTrue(foo)
    self.assertFalse(getattr(foo, "tainted", False))
    # The object was rebuilt, so the loaded hook must have run.
    self.assertTrue(foo.loaded)
def test_obj_info_with_deleted_object_and_changed_event(self):
    """
    When an object is collected, the variables disable change notification
    to not create a leak. If we're holding a reference to the obj_info and
    rebuild the object, it should re-enable change notification.
    """
    class PickleBlob(Blob):
        bin = Pickle()
    # Disable the cache, which holds strong references.
    self.get_cache(self.store).set_size(0)
    # Store a raw protocol-2 pickle (of {'a': 1}) through the plain RawStr
    # model so that PickleBlob can later load it as a mutable value.
    blob = self.store.get(Blob, 20)
    blob.bin = "\x80\x02}q\x01U\x01aK\x01s."
    self.store.flush()
    del blob
    gc.collect()
    pickle_blob = self.store.get(PickleBlob, 20)
    obj_info = get_obj_info(pickle_blob)
    del pickle_blob
    gc.collect()
    self.assertEquals(obj_info.get_obj(), None)
    # Rebuild the object while the old obj_info is still referenced.
    pickle_blob = self.store.get(PickleBlob, 20)
    pickle_blob.bin = "foobin"
    events = []
    obj_info.event.hook("changed", lambda *args: events.append(args))
    self.store.flush()
    # Exactly one "changed" event means notification was re-enabled.
    self.assertEquals(len(events), 1)
def test_wb_flush_event_with_deleted_object_before_flush(self):
    """
    When an object is deleted before flush and it contains mutable
    variables, those variables unhook from the global event system to
    prevent a leak.
    """
    class PickleBlob(Blob):
        bin = Pickle()
    # Disable the cache, which holds strong references.
    self.get_cache(self.store).set_size(0)
    # Seed row 20 with a raw pickle so PickleBlob can load it.
    blob = self.store.get(Blob, 20)
    blob.bin = "\x80\x02}q\x01U\x01aK\x01s."
    self.store.flush()
    del blob
    gc.collect()
    pickle_blob = self.store.get(PickleBlob, 20)
    pickle_blob.bin = "foobin"
    del pickle_blob
    self.store.flush()
    # White-box: no "flush" hooks may remain registered on the store.
    self.assertEquals(self.store._event._hooks["flush"], set())
def test_mutable_variable_detect_change_from_alive(self):
    """
    Changes in a mutable variable like a L{PickleVariable} are correctly
    detected, even if the object comes from the alive cache.
    """
    class PickleBlob(Blob):
        bin = Pickle()
    blob = PickleBlob()
    blob.bin = {"k": "v"}
    blob.id = 4000
    self.store.add(blob)
    self.store.commit()
    # Fetch the same row again; the in-place dict mutation below must be
    # noticed and written by the next commit.
    blob = self.store.find(PickleBlob, PickleBlob.id == 4000).one()
    blob.bin["k1"] = "v1"
    self.store.commit()
    blob = self.store.find(PickleBlob, PickleBlob.id == 4000).one()
    self.assertEquals(blob.bin, {"k1": "v1", "k": "v"})
def test_wb_checkpoint_doesnt_override_changed(self):
    """
    This test ensures that we don't uselessly checkpoint when getting back
    objects from the alive cache, which would hide changed values from the
    store.
    """
    foo = self.store.get(Foo, 20)
    foo.title = u"changed"
    # Block flushes so the lookup can't write the change out early.
    self.store.block_implicit_flushes()
    try:
        foo2 = self.store.find(Foo, Foo.id == 20).one()
    finally:
        # Always unblock, even if the lookup fails, so later steps and
        # teardown run with normal flushing.
        self.store.unblock_implicit_flushes()
    # The lookup must have been served from the alive cache.
    self.assertTrue(foo2 is foo)
    self.store.commit()
    foo3 = self.store.find(Foo, Foo.id == 20).one()
    self.assertEquals(foo3.title, u"changed")
def test_obj_info_with_deleted_object_with_get(self):
    """A collected object is rebuilt by get() while its obj_info is held."""
    # Same thing, but using get rather than find.
    # Disable the cache, which holds strong references.
    self.get_cache(self.store).set_size(0)
    foo = self.store.get(Foo, 20)
    foo.tainted = True
    obj_info = get_obj_info(foo)
    del foo
    gc.collect()
    self.assertEquals(obj_info.get_obj(), None)
    foo = self.store.get(Foo, 20)
    self.assertTrue(foo)
    # A freshly built object must not carry the old ad-hoc attribute.
    self.assertFalse(getattr(foo, "tainted", False))
def test_delete_object_when_obj_info_is_dirty(self):
    """Object should stay in memory if dirty."""
    # Disable the cache, which holds strong references.
    self.get_cache(self.store).set_size(0)
    foo = self.store.get(Foo, 20)
    # An unflushed change makes the object dirty.
    foo.title = u"Changed"
    foo.tainted = True
    obj_info = get_obj_info(foo)
    del foo
    gc.collect()
    # Even with no Python references left, the dirty object is still alive.
    self.assertTrue(obj_info.get_obj())
def test_get_tuple(self):
    """get() accepts a tuple key ordered like __storm_primary__."""
    class MyFoo(Foo):
        __storm_primary__ = "title", "id"
    found = self.store.get(MyFoo, (u"Title 30", 10))
    self.assertEquals(found.id, 10)
    self.assertEquals(found.title, "Title 30")
    missing = self.store.get(MyFoo, (u"Title 20", 10))
    self.assertEquals(missing, None)
def test_of(self):
    """Store.of() gives an object's store, or None if it has none."""
    loaded = self.store.get(Foo, 10)
    self.assertEquals(Store.of(loaded), self.store)
    for unattached in (Foo(), object()):
        self.assertEquals(Store.of(unattached), None)
def test_is_empty(self):
    """is_empty() is True only when no row matches."""
    self.assertEquals(self.store.find(Foo, id=300).is_empty(), True)
    self.assertEquals(self.store.find(Foo, id=30).is_empty(), False)
def test_is_empty_strips_order_by(self):
    """
    L{ResultSet.is_empty} leaves out any C{ORDER BY} clause: ordering is
    irrelevant to whether a row exists, and skipping it avoids sorting a
    potentially large result set.
    """
    log = StringIO()
    self.addCleanup(debug, False)
    debug(True, log)
    ordered = self.store.find(Foo, Foo.id == 300)
    ordered.order_by(Foo.id)
    self.assertEqual(True, ordered.is_empty())
    self.assertNotIn("ORDER BY", log.getvalue())
def test_is_empty_with_composed_key(self):
    """is_empty() works for classes with a composite primary key."""
    missing = self.store.find(Link, foo_id=300, bar_id=3000)
    present = self.store.find(Link, foo_id=30, bar_id=300)
    self.assertEquals(missing.is_empty(), True)
    self.assertEquals(present.is_empty(), False)
def test_is_empty_with_expression_find(self):
    """is_empty() works when find() selects a column instead of a class."""
    for condition, expected in ((Foo.id == 300, True), (Foo.id == 30, False)):
        self.assertEquals(self.store.find(Foo.title, condition).is_empty(),
                          expected)
def test_find_iter(self):
    """Iterating a ResultSet yields every matching object."""
    pairs = sorted((foo.id, foo.title) for foo in self.store.find(Foo))
    self.assertEquals(pairs, [(10, "Title 30"),
                              (20, "Title 20"),
                              (30, "Title 10")])
def test_find_from_cache(self):
    """find() returns the already-loaded object for a cached row."""
    cached = self.store.get(Foo, 10)
    found = self.store.find(Foo, id=10).one()
    self.assertTrue(found is cached)
def test_find_expr(self):
    """Every expression passed to find() must hold for a row to match."""
    def pairs(*conditions):
        return [(foo.id, foo.title)
                for foo in self.store.find(Foo, *conditions)]
    self.assertEquals(pairs(Foo.id == 20, Foo.title == u"Title 20"),
                      [(20, "Title 20")])
    self.assertEquals(pairs(Foo.id == 10, Foo.title == u"Title 20"), [])
def test_find_sql(self):
    """find() accepts a raw SQL() condition."""
    match = self.store.find(Foo, SQL("foo.id = 20")).one()
    self.assertEquals(match.title, "Title 20")
def test_find_str(self):
    """find() accepts a plain string as a raw SQL condition."""
    match = self.store.find(Foo, "foo.id = 20").one()
    self.assertEquals(match.title, "Title 20")
def test_find_keywords(self):
    """Keyword arguments to find() act as equality conditions."""
    def pairs(**conditions):
        return [(foo.id, foo.title)
                for foo in self.store.find(Foo, **conditions)]
    self.assertEquals(pairs(id=20, title=u"Title 20"), [(20, u"Title 20")])
    self.assertEquals(pairs(id=10, title=u"Title 20"), [])
def test_find_order_by(self, *args):
    """order_by() with a bare column sorts ascending."""
    ordered = self.store.find(Foo).order_by(Foo.title)
    self.assertEquals([(foo.id, foo.title) for foo in ordered],
                      [(30, "Title 10"), (20, "Title 20"), (10, "Title 30")])
def test_find_order_asc(self, *args):
    """order_by(Asc(column)) sorts ascending."""
    ordered = self.store.find(Foo).order_by(Asc(Foo.title))
    self.assertEquals([(foo.id, foo.title) for foo in ordered],
                      [(30, "Title 10"), (20, "Title 20"), (10, "Title 30")])
def test_find_order_desc(self, *args):
    """order_by(Desc(column)) sorts descending."""
    ordered = self.store.find(Foo).order_by(Desc(Foo.title))
    self.assertEquals([(foo.id, foo.title) for foo in ordered],
                      [(10, "Title 30"), (20, "Title 20"), (30, "Title 10")])
def test_find_default_order_asc(self):
    """A __storm_order__ column name gives an ascending default order."""
    class MyFoo(Foo):
        __storm_order__ = "title"
    expected = [(30, "Title 10"), (20, "Title 20"), (10, "Title 30")]
    found = [(foo.id, foo.title) for foo in self.store.find(MyFoo)]
    self.assertEquals(found, expected)
def test_find_default_order_desc(self):
    """A "-" prefix in __storm_order__ gives a descending default order."""
    class MyFoo(Foo):
        __storm_order__ = "-title"
    expected = [(10, "Title 30"), (20, "Title 20"), (30, "Title 10")]
    found = [(foo.id, foo.title) for foo in self.store.find(MyFoo)]
    self.assertEquals(found, expected)
def test_find_default_order_with_tuple(self):
    """__storm_order__ may be a tuple of column names, each optionally
    prefixed with "-" for descending order."""
    class MyLink(Link):
        __storm_order__ = ("foo_id", "-bar_id")
    found = [(link.foo_id, link.bar_id) for link in self.store.find(MyLink)]
    self.assertEquals(found, [(10, 300), (10, 200), (10, 100),
                              (20, 200), (20, 100), (30, 300)])
def test_find_default_order_with_tuple_and_expr(self):
    """A __storm_order__ tuple may mix column names and expressions."""
    class MyLink(Link):
        __storm_order__ = ("foo_id", Desc(Link.bar_id))
    found = [(link.foo_id, link.bar_id) for link in self.store.find(MyLink)]
    self.assertEquals(found, [(10, 300), (10, 200), (10, 100),
                              (20, 200), (20, 100), (30, 300)])
def test_find_index(self):
    """
    L{ResultSet.__getitem__} returns the object at the specified index.
    If a slice is used, a new L{ResultSet} is returned configured with the
    appropriate offset and limit.
    """
    expected = [(30, "Title 10"), (20, "Title 20"), (10, "Title 30")]
    for position, (foo_id, title) in enumerate(expected):
        item = self.store.find(Foo).order_by(Foo.title)[position]
        self.assertEquals(item.id, foo_id)
        self.assertEquals(item.title, title)
    item = self.store.find(Foo).order_by(Foo.title)[1:][1]
    self.assertEquals(item.id, 10)
    self.assertEquals(item.title, "Title 30")
    ordered = self.store.find(Foo).order_by(Foo.title)
    self.assertRaises(IndexError, ordered.__getitem__, 3)
def test_find_slice(self):
    """A [start:stop] slice applies both an offset and a limit."""
    window = self.store.find(Foo).order_by(Foo.title)[1:2]
    self.assertEquals([(foo.id, foo.title) for foo in window],
                      [(20, "Title 20")])
def test_find_slice_offset(self):
    """A [start:] slice skips the first rows."""
    tail = self.store.find(Foo).order_by(Foo.title)[1:]
    self.assertEquals([(foo.id, foo.title) for foo in tail],
                      [(20, "Title 20"), (10, "Title 30")])
def test_find_slice_offset_any(self):
    """any() on an offset slice honours the offset."""
    picked = self.store.find(Foo).order_by(Foo.title)[1:].any()
    self.assertEquals(picked.id, 20)
    self.assertEquals(picked.title, "Title 20")
def test_find_slice_offset_one(self):
    """one() on a single-row slice returns that row."""
    picked = self.store.find(Foo).order_by(Foo.title)[1:2].one()
    self.assertEquals(picked.id, 20)
    self.assertEquals(picked.title, "Title 20")
def test_find_slice_offset_first(self):
    """first() on an offset slice returns the first row after the offset."""
    picked = self.store.find(Foo).order_by(Foo.title)[1:].first()
    self.assertEquals(picked.id, 20)
    self.assertEquals(picked.title, "Title 20")
def test_find_slice_offset_last(self):
    """last() works on a slice that has only an offset."""
    picked = self.store.find(Foo).order_by(Foo.title)[1:].last()
    self.assertEquals(picked.id, 10)
    self.assertEquals(picked.title, "Title 30")
def test_find_slice_limit(self):
    """A [:stop] slice limits the number of rows."""
    head = self.store.find(Foo).order_by(Foo.title)[:2]
    self.assertEquals([(foo.id, foo.title) for foo in head],
                      [(30, "Title 10"), (20, "Title 20")])
def test_find_slice_limit_last(self):
    """last() is unsupported on a slice that has a limit."""
    head = self.store.find(Foo).order_by(Foo.title)[:2]
    self.assertRaises(FeatureError, head.last)
def test_find_slice_slice(self):
    """Slicing an already sliced ResultSet combines offsets and limits."""
    def ordered():
        return self.store.find(Foo).order_by(Foo.title)
    def pairs(result):
        return [(foo.id, foo.title) for foo in result]
    only_second = [(20, "Title 20")]
    self.assertEquals(pairs(ordered()[0:2][1:3]), only_second)
    self.assertEquals(pairs(ordered()[:2][1:3]), only_second)
    self.assertEquals(pairs(ordered()[1:3][0:1]), only_second)
    self.assertEquals(pairs(ordered()[1:3][:1]), only_second)
    self.assertEquals(pairs(ordered()[5:5][1:1]), [])
def test_find_slice_order_by(self):
    """A sliced ResultSet can no longer be re-ordered."""
    for sliced in (self.store.find(Foo)[2:], self.store.find(Foo)[:2]):
        self.assertRaises(FeatureError, sliced.order_by, None)
def test_find_slice_remove(self):
    """remove() is refused on a sliced ResultSet."""
    for sliced in (self.store.find(Foo)[2:], self.store.find(Foo)[:2]):
        self.assertRaises(FeatureError, sliced.remove)
def test_find_contains(self):
    """``in`` tells whether an object satisfies the result's conditions."""
    target = self.store.get(Foo, 10)
    self.assertEquals(target in self.store.find(Foo), True)
    for condition in (Foo.id == 20, "foo.id = 20"):
        self.assertEquals(target in self.store.find(Foo, condition), False)
def test_find_contains_wrong_type(self):
    """``in`` raises TypeError when the candidate doesn't match the shape
    of what the result set yields."""
    foo = self.store.get(Foo, 10)
    bar = self.store.get(Bar, 200)
    cases = [
        (Foo, bar),
        ((Foo,), foo),
        (Foo, (foo,)),
        ((Foo, Bar), (bar, foo)),
    ]
    for spec, candidate in cases:
        self.assertRaises(TypeError, operator.contains,
                          self.store.find(spec), candidate)
def test_find_contains_does_not_use_iter(self):
    """Membership tests don't fall back to iterating the result set."""
    def exploding_iter(self):
        raise RuntimeError()
    saved_iter = ResultSet.__iter__
    ResultSet.__iter__ = exploding_iter
    try:
        target = self.store.get(Foo, 10)
        everything = self.store.find(Foo)
        self.assertEquals(target in everything, True)
    finally:
        ResultSet.__iter__ = saved_iter
def test_find_contains_with_composed_key(self):
    """``in`` works for objects with a composite primary key."""
    target = self.store.get(Link, (10, 100))
    self.assertEquals(target in self.store.find(Link, Link.foo_id == 10), True)
    self.assertEquals(target in self.store.find(Link, Link.bar_id == 200),
                      False)
def test_find_contains_with_set_expression(self):
    """``in`` works on union/intersection/difference result sets."""
    target = self.store.get(Foo, 10)
    matching = self.store.find(Foo, Foo.id == 10)
    others = self.store.find(Foo, Foo.id != 10)
    self.assertEquals(target in matching.union(others), True)
    # The intersection/difference checks are skipped for MySQL test
    # classes (presumably because that backend lacks INTERSECT/EXCEPT).
    if not self.__class__.__name__.startswith("MySQL"):
        self.assertEquals(target in matching.intersection(others), False)
        self.assertEquals(target in matching.intersection(matching), True)
        self.assertEquals(target in matching.difference(others), True)
        self.assertEquals(target in matching.difference(matching), False)
def test_find_any(self, *args):
    """
    L{ResultSet.any} returns an arbitrary object from the result set, or
    None when nothing matches.
    """
    some = self.store.find(Foo).any()
    self.assertNotEqual(None, some)
    none_found = self.store.find(Foo, id=40).any()
    self.assertEqual(None, none_found)
def test_find_any_strips_order_by(self):
    """
    L{ResultSet.any} leaves out any C{ORDER BY} clause: an arbitrary row
    doesn't need ordering, and skipping it avoids sorting a potentially
    large result set.
    """
    log = StringIO()
    self.addCleanup(debug, False)
    debug(True, log)
    ordered = self.store.find(Foo, Foo.id == 300)
    ordered.order_by(Foo.id)
    ordered.any()
    self.assertNotIn("ORDER BY", log.getvalue())
def test_find_first(self, *args):
    """first() needs an ordering; it returns the first row or None."""
    self.assertRaises(UnorderedError, self.store.find(Foo).first)
    for order, (foo_id, title) in ((Foo.title, (30, "Title 10")),
                                   (Foo.id, (10, "Title 30"))):
        head = self.store.find(Foo).order_by(order).first()
        self.assertEquals(head.id, foo_id)
        self.assertEquals(head.title, title)
    nothing = self.store.find(Foo, id=40).order_by(Foo.id).first()
    self.assertEquals(nothing, None)
def test_find_last(self, *args):
    """last() needs an ordering; it returns the last row or None."""
    self.assertRaises(UnorderedError, self.store.find(Foo).last)
    for order, (foo_id, title) in ((Foo.title, (10, "Title 30")),
                                   (Foo.id, (30, "Title 10"))):
        tail = self.store.find(Foo).order_by(order).last()
        self.assertEquals(tail.id, foo_id)
        self.assertEquals(tail.title, title)
    nothing = self.store.find(Foo, id=40).order_by(Foo.id).last()
    self.assertEquals(nothing, None)
def test_find_last_desc(self, *args):
    """last() handles explicit Desc()/Asc() orderings."""
    for order in (Desc(Foo.title), Asc(Foo.id)):
        tail = self.store.find(Foo).order_by(order).last()
        self.assertEquals(tail.id, 30)
        self.assertEquals(tail.title, "Title 10")
def test_find_one(self, *args):
    """one() raises NotOneError for several rows, returns None for none."""
    self.assertRaises(NotOneError, self.store.find(Foo).one)
    single = self.store.find(Foo, id=10).one()
    self.assertEquals(single.id, 10)
    self.assertEquals(single.title, "Title 30")
    self.assertEquals(self.store.find(Foo, id=40).one(), None)
def test_find_count(self):
    """count() with no arguments counts all matching rows."""
    total = self.store.find(Foo).count()
    self.assertEquals(total, 3)
def test_find_count_after_slice(self):
    """
    When we slice a ResultSet obtained after a set operation (like union),
    we get a fresh select that doesn't modify the limit and offset
    attribute of the original ResultSet.
    """
    result1 = self.store.find(Foo, Foo.id == 10)
    result2 = self.store.find(Foo, Foo.id == 20)
    result3 = result1.union(result2)
    result3.order_by(Foo.id)
    self.assertEquals(result3.count(), 2)
    # Materialize a slice; it must yield the rows without altering result3.
    self.assertEquals(len(list(result3[:2])), 2)
    self.assertEquals(result3.count(), 2)
def test_find_count_column(self):
    """count(column) counts every one of the six link rows."""
    link_count = self.store.find(Link).count(Link.foo_id)
    self.assertEquals(link_count, 6)
def test_find_count_column_distinct(self):
    """count(column, distinct=True) counts distinct values only."""
    self.assertEquals(
        self.store.find(Link).count(Link.foo_id, distinct=True), 3)
def test_find_limit_count(self):
    """count() honours a limit set through config()."""
    limited = self.store.find(Link.foo_id)
    limited.config(limit=2)
    self.assertEquals(limited.count(), 2)
def test_find_offset_count(self):
    """count() honours an offset set through config() (6 rows - 3)."""
    shifted = self.store.find(Link.foo_id)
    shifted.config(offset=3)
    self.assertEquals(shifted.count(), 3)
def test_find_sliced_count(self):
    """count() on a slice counts only the rows inside the slice."""
    window = self.store.find(Link.foo_id)[2:4]
    self.assertEquals(window.count(), 2)
def test_find_distinct_count(self):
    """count() on a distinct single-column result counts unique values."""
    unique = self.store.find(Link.foo_id)
    unique.config(distinct=True)
    self.assertEquals(unique.count(), 3)
def test_find_distinct_order_by_limit_count(self):
    """count() still works when distinct, limit and ordering are combined."""
    combined = self.store.find(Foo)
    combined.order_by(Foo.title)
    combined.config(distinct=True, limit=3)
    self.assertEquals(combined.count(), 3)
def test_find_distinct_count_multiple_columns(self):
    """Distinct count over two columns counts unique pairs."""
    pairs = self.store.find((Link.foo_id, Link.bar_id))
    pairs.config(distinct=True)
    self.assertEquals(pairs.count(), 6)
def test_find_count_column_with_implicit_distinct(self):
    """count(column) on a distinct result of whole Link objects counts
    every one of the six link rows."""
    links = self.store.find(Link)
    links.config(distinct=True)
    self.assertEquals(links.count(Link.foo_id), 6)
def test_find_max(self):
    """max(column) returns the largest value."""
    highest = self.store.find(Foo).max(Foo.id)
    self.assertEquals(highest, 30)
def test_find_max_expr(self):
    """max() accepts arithmetic expressions."""
    highest = self.store.find(Foo).max(Foo.id + 1)
    self.assertEquals(highest, 31)
def test_find_max_unicode(self):
    """max() over a Unicode column returns a unicode value."""
    highest = self.store.find(Foo).max(Foo.title)
    self.assertEquals(highest, "Title 30")
    self.assertTrue(isinstance(highest, unicode))
def test_find_max_with_empty_result_and_disallow_none(self):
    """max() over no rows is None even for an allow_none=False column."""
    class Bar(object):
        __storm_table__ = "bar"
        id = Int(primary=True)
        foo_id = Int(allow_none=False)
    empty = self.store.find(Bar, Bar.id > 1000)
    self.assertTrue(empty.is_empty())
    self.assertEquals(empty.max(Bar.foo_id), None)
def test_find_min(self):
    """min(column) returns the smallest value."""
    lowest = self.store.find(Foo).min(Foo.id)
    self.assertEquals(lowest, 10)
def test_find_min_expr(self):
    """min() accepts arithmetic expressions."""
    lowest = self.store.find(Foo).min(Foo.id - 1)
    self.assertEquals(lowest, 9)
def test_find_min_unicode(self):
    """min() over a Unicode column returns a unicode value."""
    lowest = self.store.find(Foo).min(Foo.title)
    self.assertEquals(lowest, "Title 10")
    self.assertTrue(isinstance(lowest, unicode))
def test_find_min_with_empty_result_and_disallow_none(self):
    """min() over no rows is None even for an allow_none=False column."""
    class Bar(object):
        __storm_table__ = "bar"
        id = Int(primary=True)
        foo_id = Int(allow_none=False)
    empty = self.store.find(Bar, Bar.id > 1000)
    self.assertTrue(empty.is_empty())
    self.assertEquals(empty.min(Bar.foo_id), None)
def test_find_avg(self):
    """avg(column) returns the mean of 10, 20 and 30."""
    mean = self.store.find(Foo).avg(Foo.id)
    self.assertEquals(mean, 20)
def test_find_avg_expr(self):
    """avg() accepts arithmetic expressions."""
    mean = self.store.find(Foo).avg(Foo.id + 10)
    self.assertEquals(mean, 30)
def test_find_avg_float(self):
    """avg() returns a fractional mean when the average isn't integral."""
    extra = Foo()
    extra.id = 15
    extra.title = u"Title 15"
    self.store.add(extra)
    # (10 + 15 + 20 + 30) / 4 == 18.75
    self.assertEquals(self.store.find(Foo).avg(Foo.id), 18.75)
def test_find_sum(self):
    """sum(column) adds every value."""
    total = self.store.find(Foo).sum(Foo.id)
    self.assertEquals(total, 60)
def test_find_sum_expr(self):
    """sum() accepts arithmetic expressions."""
    total = self.store.find(Foo).sum(Foo.id * 2)
    self.assertEquals(total, 120)
def test_find_sum_with_empty_result_and_disallow_none(self):
    """sum() over no rows is None even for an allow_none=False column."""
    class Bar(object):
        __storm_table__ = "bar"
        id = Int(primary=True)
        foo_id = Int(allow_none=False)
    empty = self.store.find(Bar, Bar.id > 1000)
    self.assertTrue(empty.is_empty())
    self.assertEquals(empty.sum(Bar.foo_id), None)
def test_find_max_order_by(self):
    """Interaction between order by and aggregation shouldn't break."""
    ordered = self.store.find(Foo).order_by(Foo.id)
    self.assertEquals(ordered.max(Foo.id), 30)
def test_find_get_select_expr_without_columns(self):
    """
    Calling L{ResultSet.get_select_expr} with no L{Column}s raises
    L{FeatureError}.
    """
    self.assertRaises(FeatureError, self.store.find(Foo).get_select_expr)
def test_find_get_select_expr(self):
    """
    L{ResultSet.get_select_expr} builds a L{Select} holding only the
    requested L{Column}s, which can then be used as a subselect.
    """
    expected = self.store.get(Foo, 10)
    subselect = self.store.find(Foo, Foo.id <= 10).get_select_expr(Foo.id)
    self.assertEqual((Foo.id,), subselect.columns)
    matches = self.store.find(Foo, Foo.id.is_in(subselect))
    self.assertEqual([expected], list(matches))
def test_find_get_select_expr_with_set_expression(self):
    """
    L{ResultSet.get_select_expr} raises L{FeatureError} when the
    L{ResultSet} is a set expression such as a union.
    """
    union = self.store.find(Foo, Foo.id == 10).union(
        self.store.find(Foo, Foo.id == 20))
    self.assertRaises(FeatureError, union.get_select_expr, Foo.id)
def test_find_values(self):
    """values() yields plain column values in result order."""
    ordered = self.store.find(Foo).order_by(Foo.id)
    self.assertEquals(list(ordered.values(Foo.id)), [10, 20, 30])
    titles = list(self.store.find(Foo).order_by(Foo.id).values(Foo.title))
    self.assertEquals(titles, ["Title 30", "Title 20", "Title 10"])
    self.assertEquals([type(title) for title in titles],
                      [unicode, unicode, unicode])
def test_find_multiple_values(self):
    """values() with several columns yields one tuple per row."""
    rows = self.store.find(Foo).order_by(Foo.id).values(Foo.id, Foo.title)
    expected = [(10, "Title 30"),
                (20, "Title 20"),
                (30, "Title 10")]
    self.assertEquals(list(rows), expected)
def test_find_values_with_no_arguments(self):
    """values() needs at least one column; the error shows up on iteration."""
    lazy_values = self.store.find(Foo).order_by(Foo.id).values()
    self.assertRaises(FeatureError, lazy_values.next)
def test_find_slice_values(self):
    """values() honours a slice applied to the result set."""
    window = self.store.find(Foo).order_by(Foo.id)[1:2]
    self.assertEquals(list(window.values(Foo.id)), [20])
def test_find_values_with_set_expression(self):
    """
    L{ResultSet.values} raises L{FeatureError} when the L{ResultSet} is a
    set expression such as a union.
    """
    first = self.store.find(Foo, Foo.id == 10)
    second = self.store.find(Foo, Foo.id == 20)
    union = first.union(second)
    self.assertRaises(FeatureError, list, union.values(Foo.id))
def test_find_remove(self):
    """remove() on a result set deletes the matching rows."""
    doomed = self.store.find(Foo, Foo.id == 20)
    doomed.remove()
    remaining = [
        (10, "Title 30"),
        (30, "Title 10"),
    ]
    self.assertEquals(self.get_items(), remaining)
def test_find_cached(self):
    """cached() reports loaded objects of the requested class only."""
    loaded_foo = self.store.get(Foo, 20)
    loaded_bar = self.store.get(Bar, 200)
    self.assertTrue(loaded_foo)
    self.assertTrue(loaded_bar)
    cached = self.store.find(Foo).cached()
    self.assertEquals(cached, [loaded_foo])
def test_find_cached_where(self):
    """cached() applies keyword filters to the cached objects."""
    first_foo = self.store.get(Foo, 10)
    second_foo = self.store.get(Foo, 20)
    some_bar = self.store.get(Bar, 200)
    for loaded in (first_foo, second_foo, some_bar):
        self.assertTrue(loaded)
    matches = self.store.find(Foo, title=u"Title 20").cached()
    self.assertEquals(matches, [second_foo])
def test_find_cached_invalidated(self):
    """An invalidated object is still reported by cached()."""
    stale = self.store.get(Foo, 20)
    self.store.invalidate(stale)
    cached = self.store.find(Foo).cached()
    self.assertEquals(cached, [stale])
def test_find_cached_invalidated_and_deleted(self):
    """cached() drops an invalidated object whose row was deleted."""
    stale = self.store.get(Foo, 20)
    self.store.execute("DELETE FROM foo WHERE id=20")
    self.store.invalidate(stale)
    # Filter on title rather than on the primary key (id): the key is
    # known without touching the database, the title is not.
    matches = self.store.find(Foo, title=u"Title 20").cached()
    self.assertEquals(matches, [])
def test_find_cached_with_info_alive_and_object_dead(self):
    """cached() copes with an obj_info whose object was garbage collected."""
    # Disable the cache, which holds strong references.
    self.get_cache(self.store).set_size(0)
    doomed = self.store.get(Foo, 20)
    doomed.tainted = True
    # Deliberately keep the obj_info alive while the object itself dies.
    obj_info = get_obj_info(doomed)
    del doomed
    gc.collect()
    # Hold the returned list in a local so its contents stay referenced
    # while the object is fetched again below.
    alive = self.store.find(Foo).cached()
    self.assertEquals(len(alive), 1)
    # The object fetched now must not be the tainted one.
    refetched = self.store.get(Foo, 20)
    self.assertFalse(hasattr(refetched, "tainted"))
def test_using_find_join(self):
    """find() through an explicit LEFT JOIN yields None for unmatched rows."""
    detached = self.store.get(Bar, 100)
    detached.foo_id = None
    tables = self.store.using(Foo, LeftJoin(Bar, Bar.foo_id == Foo.id))
    result = tables.find(Bar).order_by(Foo.id, Bar.id)
    rows = [row and (row.id, row.title) for row in result]
    expected = [
        None,
        (200, u"Title 200"),
        (300, u"Title 100"),
    ]
    self.assertEquals(rows, expected)
def test_using_find_with_strings(self):
    """using() accepts plain table names as strings."""
    single = self.store.using("foo").find(Foo, id=10).one()
    self.assertEquals(single.title, "Title 30")
    joined = self.store.using("foo", "bar").find(Foo, id=10).any()
    self.assertEquals(joined.title, "Title 30")
def test_using_find_join_with_strings(self):
    """A LEFT JOIN built from raw table names and SQL text works too."""
    detached = self.store.get(Bar, 100)
    detached.foo_id = None
    join = LeftJoin("foo", "bar", "bar.foo_id = foo.id")
    result = self.store.using(join).find(Bar).order_by(Foo.id, Bar.id)
    rows = [row and (row.id, row.title) for row in result]
    expected = [
        None,
        (200, u"Title 200"),
        (300, u"Title 100"),
    ]
    self.assertEquals(rows, expected)
def test_find_tuple(self):
    """find() with a tuple of classes yields tuples of joined objects."""
    orphan = self.store.get(Bar, 200)
    orphan.foo_id = None
    result = self.store.find((Foo, Bar), Bar.foo_id == Foo.id)
    result = result.order_by(Foo.id)
    pairs = [(f and (f.id, f.title), b and (b.id, b.title))
             for (f, b) in result]
    expected = [
        ((10, u"Title 30"), (100, u"Title 300")),
        ((30, u"Title 10"), (300, u"Title 100")),
    ]
    self.assertEquals(pairs, expected)
def test_find_tuple_using(self):
    """Tuple finds over a LEFT JOIN yield None for the missing side."""
    orphan = self.store.get(Bar, 200)
    orphan.foo_id = None
    joined = self.store.using(Foo, LeftJoin(Bar, Bar.foo_id == Foo.id))
    result = joined.find((Foo, Bar)).order_by(Foo.id)
    pairs = [(f and (f.id, f.title), b and (b.id, b.title))
             for (f, b) in result]
    expected = [
        ((10, u"Title 30"), (100, u"Title 300")),
        ((20, u"Title 20"), None),
        ((30, u"Title 10"), (300, u"Title 100")),
    ]
    self.assertEquals(pairs, expected)
def test_find_tuple_using_with_disallow_none(self):
    """A missing joined row yields None even if its key has allow_none=False."""
    class Bar(object):
        __storm_table__ = "bar"
        id = Int(primary=True, allow_none=False)
        title = Unicode()
        foo_id = Int()
        foo = Reference(foo_id, Foo.id)

    self.store.remove(self.store.get(Bar, 200))
    joined = self.store.using(Foo, LeftJoin(Bar, Bar.foo_id == Foo.id))
    result = joined.find((Foo, Bar)).order_by(Foo.id)
    pairs = [(f and (f.id, f.title), b and (b.id, b.title))
             for (f, b) in result]
    expected = [
        ((10, u"Title 30"), (100, u"Title 300")),
        ((20, u"Title 20"), None),
        ((30, u"Title 10"), (300, u"Title 100")),
    ]
    self.assertEquals(pairs, expected)
def test_find_tuple_using_skip_when_none(self):
    """Chained LEFT JOINs yield None for every class with no matching row."""
    orphan = self.store.get(Bar, 200)
    orphan.foo_id = None
    joined = self.store.using(Foo,
                              LeftJoin(Bar, Bar.foo_id == Foo.id),
                              LeftJoin(Link, Link.bar_id == Bar.id))
    result = joined.find((Bar, Link)).order_by(Foo.id, Bar.id, Link.foo_id)
    rows = [(b and (b.id, b.title), l and (l.bar_id, l.foo_id))
            for (b, l) in result]
    expected = [
        ((100, u"Title 300"), (100, 10)),
        ((100, u"Title 300"), (100, 20)),
        (None, None),
        ((300, u"Title 100"), (300, 10)),
        ((300, u"Title 100"), (300, 30)),
    ]
    self.assertEquals(rows, expected)
def test_find_tuple_contains(self):
    """The 'in' operator works on tuple result sets."""
    foo = self.store.get(Foo, 10)
    linked_bar = self.store.get(Bar, 100)
    other_bar = self.store.get(Bar, 200)
    result = self.store.find((Foo, Bar), Bar.foo_id == Foo.id)
    self.assertTrue((foo, linked_bar) in result)
    self.assertFalse((foo, other_bar) in result)
def test_find_tuple_contains_with_set_expression(self):
    """The 'in' operator works on set expressions over tuple results."""
    foo = self.store.get(Foo, 10)
    linked_bar = self.store.get(Bar, 100)
    other_bar = self.store.get(Bar, 200)
    left = self.store.find((Foo, Bar), Bar.foo_id == Foo.id)
    right = self.store.find((Foo, Bar), Bar.foo_id == Foo.id)
    self.assertTrue((foo, linked_bar) in left.union(right))
    # NOTE(review): intersection/difference are skipped for MySQL test
    # classes, presumably because of missing INTERSECT/EXCEPT support.
    if not self.__class__.__name__.startswith("MySQL"):
        self.assertTrue((foo, linked_bar) in left.intersection(right))
        self.assertFalse((foo, linked_bar) in left.difference(right))
def test_find_tuple_any(self):
    """any() on an ordered tuple result returns a (Foo, Bar) pair."""
    unlinked = self.store.get(Bar, 200)
    unlinked.foo_id = None
    result = self.store.find((Foo, Bar), Bar.foo_id == Foo.id)
    found_foo, found_bar = result.order_by(Foo.id).any()
    self.assertEquals(found_foo.id, 10)
    self.assertEquals(found_foo.title, u"Title 30")
    self.assertEquals(found_bar.id, 100)
    self.assertEquals(found_bar.title, u"Title 300")
def test_find_tuple_first(self):
    """first() on an ordered tuple result returns the lowest Foo's pair."""
    unlinked = self.store.get(Bar, 200)
    unlinked.foo_id = None
    result = self.store.find((Foo, Bar), Bar.foo_id == Foo.id)
    head_foo, head_bar = result.order_by(Foo.id).first()
    self.assertEquals(head_foo.id, 10)
    self.assertEquals(head_foo.title, u"Title 30")
    self.assertEquals(head_bar.id, 100)
    self.assertEquals(head_bar.title, u"Title 300")
def test_find_tuple_last(self):
    """last() on an ordered tuple result returns the highest Foo's pair."""
    unlinked = self.store.get(Bar, 200)
    unlinked.foo_id = None
    result = self.store.find((Foo, Bar), Bar.foo_id == Foo.id)
    tail_foo, tail_bar = result.order_by(Foo.id).last()
    self.assertEquals(tail_foo.id, 30)
    self.assertEquals(tail_foo.title, u"Title 10")
    self.assertEquals(tail_bar.id, 300)
    self.assertEquals(tail_bar.title, u"Title 100")
def test_find_tuple_one(self):
    """one() on a single-row tuple result returns that pair."""
    unlinked = self.store.get(Bar, 200)
    unlinked.foo_id = None
    result = self.store.find((Foo, Bar),
                             Bar.foo_id == Foo.id, Foo.id == 10)
    only_foo, only_bar = result.order_by(Foo.id).one()
    self.assertEquals(only_foo.id, 10)
    self.assertEquals(only_foo.title, u"Title 30")
    self.assertEquals(only_bar.id, 100)
    self.assertEquals(only_bar.title, u"Title 300")
def test_find_tuple_count(self):
    """count() on a tuple result counts the joined rows."""
    unlinked = self.store.get(Bar, 200)
    unlinked.foo_id = None
    joined = self.store.find((Foo, Bar), Bar.foo_id == Foo.id)
    self.assertEquals(joined.count(), 2)
def test_find_tuple_remove(self):
    """remove() is not supported on tuple results."""
    self.assertRaises(FeatureError, self.store.find((Foo, Bar)).remove)
def test_find_tuple_set(self):
    """set() is not supported on tuple results."""
    pairs = self.store.find((Foo, Bar))
    self.assertRaises(FeatureError, pairs.set, title=u"Title 40")
def test_find_tuple_kwargs(self):
    """Keyword filters are rejected when finding a tuple of classes."""
    find = self.store.find
    self.assertRaises(FeatureError, find, (Foo, Bar), title=u"Title 10")
def test_find_tuple_cached(self):
    """cached() is not supported on tuple results."""
    self.assertRaises(FeatureError, self.store.find((Foo, Bar)).cached)
def test_find_using_cached(self):
    """cached() is not supported on results built with using()."""
    via_using = self.store.using(Foo, Bar).find(Foo)
    self.assertRaises(FeatureError, via_using.cached)
def test_find_with_expr(self):
    """find() accepts a bare column and yields its values."""
    titles = sorted(self.store.find(Foo.title))
    self.assertEquals(titles, [u"Title 10", u"Title 20", u"Title 30"])
def test_find_with_expr_uses_variable_set(self):
    """Column values pass through the column's variable conversions."""
    titles = list(self.store.find(FooVariable.title, FooVariable.id == 10))
    self.assertEquals(titles, [u"to_py(from_db(Title 30))"])
def test_find_tuple_with_expr(self):
    """Tuples may mix a whole class with individual columns."""
    result = self.store.find((Foo, Bar.id, Bar.title),
                             Bar.foo_id == Foo.id).order_by(Foo.id)
    rows = [(f.id, f.title, bid, btitle) for f, bid, btitle in result]
    self.assertEquals(rows,
                      [(10, u"Title 30", 100, u"Title 300"),
                       (20, u"Title 20", 200, u"Title 200"),
                       (30, u"Title 10", 300, u"Title 100")])
def test_find_using_with_expr(self):
    """using() combined with a bare column yields its values."""
    titles = sorted(self.store.using(Foo).find(Foo.title))
    self.assertEquals(titles, [u"Title 10", u"Title 20", u"Title 30"])
def test_find_with_expr_contains(self):
    """The 'in' operator works on column results."""
    titles = self.store.find(Foo.title)
    self.assertTrue(u"Title 10" in titles)
    self.assertFalse(u"Title 42" in titles)
def test_find_tuple_with_expr_contains(self):
    """The 'in' operator works on mixed class/column tuple results."""
    foo = self.store.get(Foo, 10)
    rows = self.store.find((Foo, Bar.title), Bar.foo_id == Foo.id)
    self.assertTrue((foo, u"Title 300") in rows)
    self.assertFalse((foo, u"Title 100") in rows)
def test_find_with_expr_contains_with_set_expression(self):
    """The 'in' operator works on set expressions over column results."""
    left = self.store.find(Foo.title)
    right = self.store.find(Foo.title)
    self.assertTrue(u"Title 10" in left.union(right))
    # NOTE(review): intersection/difference are skipped for MySQL test
    # classes, presumably because of missing INTERSECT/EXCEPT support.
    if not self.__class__.__name__.startswith("MySQL"):
        self.assertTrue(u"Title 10" in left.intersection(right))
        self.assertFalse(u"Title 10" in left.difference(right))
def test_find_with_expr_remove_unsupported(self):
    """remove() is rejected on column results."""
    self.assertRaises(FeatureError, self.store.find(Foo.title).remove)
def test_find_tuple_with_expr_remove_unsupported(self):
    """remove() is rejected on mixed class/column tuple results."""
    rows = self.store.find((Foo, Bar.title), Bar.foo_id == Foo.id)
    self.assertRaises(FeatureError, rows.remove)
def test_find_with_expr_count(self):
    """count() works on column results."""
    titles = self.store.find(Foo.title)
    self.assertEquals(titles.count(), 3)
def test_find_tuple_with_expr_count(self):
    """count() works on mixed class/column tuple results."""
    rows = self.store.find((Foo, Bar.title), Bar.foo_id == Foo.id)
    self.assertEquals(rows.count(), 3)
def test_find_with_expr_values(self):
    """values() works on column results."""
    titles = self.store.find(Foo.title)
    extracted = sorted(titles.values(Foo.title))
    self.assertEquals(extracted, [u"Title 10", u"Title 20", u"Title 30"])
def test_find_tuple_with_expr_values(self):
    """values() works on mixed class/column tuple results."""
    rows = self.store.find((Foo, Bar.title), Bar.foo_id == Foo.id)
    extracted = sorted(rows.values(Foo.title))
    self.assertEquals(extracted, [u"Title 10", u"Title 20", u"Title 30"])
def test_find_with_expr_set_unsupported(self):
    """set() is rejected on column results."""
    self.assertRaises(FeatureError, self.store.find(Foo.title).set)
def test_find_tuple_with_expr_set_unsupported(self):
    """set() is rejected on mixed class/column tuple results."""
    rows = self.store.find((Foo, Bar.title), Bar.foo_id == Foo.id)
    self.assertRaises(FeatureError, rows.set)
def test_find_with_expr_cached_unsupported(self):
    """cached() is rejected on column results."""
    self.assertRaises(FeatureError, self.store.find(Foo.title).cached)
def test_find_tuple_with_expr_cached_unsupported(self):
    """cached() is rejected on mixed class/column tuple results."""
    rows = self.store.find((Foo, Bar.title), Bar.foo_id == Foo.id)
    self.assertRaises(FeatureError, rows.cached)
def test_find_with_expr_union(self):
    """Column results can be combined with union()."""
    matching = self.store.find(Foo.title, Foo.id == 10)
    others = self.store.find(Foo.title, Foo.id != 10)
    titles = sorted(matching.union(others))
    self.assertEquals(titles, [u"Title 10", u"Title 20", u"Title 30"])
def test_find_alias_with_expr_union(self):
    """Aliased columns from different tables can be combined by union()."""
    from_foo = self.store.find(Alias(Foo.title, 'title'),
                               Foo.title == u"Title 30")
    from_bar = self.store.find(Alias(Bar.title, 'title'),
                               Bar.title == u"Title 100")
    titles = sorted(from_foo.union(from_bar))
    self.assertEquals(titles, [u'Title 100', u'Title 30'])
def test_find_tuple_with_expr_union(self):
    """Mixed class/column tuple results can be combined by union()."""
    first = self.store.find(
        (Foo, Bar.title), Bar.foo_id == Foo.id, Bar.title == u"Title 100")
    second = self.store.find(
        (Foo, Bar.title), Bar.foo_id == Foo.id, Bar.title == u"Title 200")
    combined = first.union(second)
    rows = sorted((f.id, bar_title) for (f, bar_title) in combined)
    self.assertEquals(rows, [(20, u"Title 200"), (30, u"Title 100")])
def test_get_does_not_validate(self):
    """Loading an object with get() must not run property validators."""
    def validator(object, attr, value):
        self.fail("validator called with arguments (%r, %r, %r)" %
                  (object, attr, value))

    class Foo(object):
        __storm_table__ = "foo"
        id = Int(primary=True)
        title = Unicode(validator=validator)

    loaded = self.store.get(Foo, 10)
    self.assertEqual(loaded.title, "Title 30")
def test_get_does_not_validate_default_value(self):
    """get() must not validate, even for a property with a default."""
    def validator(object, attr, value):
        self.fail("validator called with arguments (%r, %r, %r)" %
                  (object, attr, value))

    class Foo(object):
        __storm_table__ = "foo"
        id = Int(primary=True)
        title = Unicode(validator=validator, default=u"default value")

    loaded = self.store.get(Foo, 10)
    self.assertEqual(loaded.title, "Title 30")
def test_find_does_not_validate(self):
    """Loading an object with find() must not run property validators."""
    def validator(object, attr, value):
        self.fail("validator called with arguments (%r, %r, %r)" %
                  (object, attr, value))

    class Foo(object):
        __storm_table__ = "foo"
        id = Int(primary=True)
        title = Unicode(validator=validator)

    loaded = self.store.find(Foo, Foo.id == 10).one()
    self.assertEqual(loaded.title, "Title 30")
- def test_find_group_by(self):
- result = self.store.find((Count(FooValue.id), Sum(FooValue.value1)))
- result.group_by(FooValue.value2)
- result.order_by(Count(FooValue.id), Sum(FooValue.value1))
- result = list(result)
- self.assertEquals(result, [(2L, 2L), (2L, 2L), (2L, 3L), (3L, 6L)])
def test_find_group_by_table(self):
    """group_by() accepts a whole class as the grouping target."""
    grouped = self.store.find(
        (Sum(FooValue.value2), Foo), Foo.id == FooValue.foo_id)
    grouped.group_by(Foo)
    first = self.store.get(Foo, 10)
    second = self.store.get(Foo, 20)
    self.assertEquals(list(grouped), [(5, first), (16, second)])
def test_find_group_by_table_contains(self):
    """The 'in' operator works on results grouped by a class."""
    grouped = self.store.find(
        (Sum(FooValue.value2), Foo), Foo.id == FooValue.foo_id)
    grouped.group_by(Foo)
    first = self.store.get(Foo, 10)
    self.assertTrue((5, first) in grouped)
def test_find_group_by_multiple_tables(self):
    """Grouping works across a join of foo and foo_value tables."""
    # A single aggregate, grouped by Foo.id.
    sums = self.store.find(Sum(FooValue.value2), Foo.id == FooValue.foo_id)
    sums.group_by(Foo.id)
    sums.order_by(Sum(FooValue.value2))
    self.assertEquals(list(sums), [5, 16])

    # An aggregate paired with the whole Foo object.
    with_objects = self.store.find(
        (Sum(FooValue.value2), Foo), Foo.id == FooValue.foo_id)
    with_objects.group_by(Foo)
    with_objects.order_by(Sum(FooValue.value2))
    rows = list(with_objects)
    first = self.store.get(Foo, 10)
    second = self.store.get(Foo, 20)
    self.assertEquals(rows, [(5, first), (16, second)])

    # Several aggregates next to the grouping column.
    stats = self.store.find(
        (Foo.id, Sum(FooValue.value2), Avg(FooValue.value1)),
        Foo.id == FooValue.foo_id)
    stats.group_by(Foo.id)
    stats.order_by(Foo.id)
    self.assertEquals(list(stats), [(10, 5, 2),
                                    (20, 16, 1)])
def test_find_group_by_having(self):
    """having() filters groups by an aggregate condition."""
    by_sum = self.store.find(
        Sum(FooValue.value2), Foo.id == FooValue.foo_id)
    by_sum.group_by(Foo.id)
    by_sum.having(Sum(FooValue.value2) == 5)
    self.assertEquals(list(by_sum), [5])

    by_count = self.store.find(
        Sum(FooValue.value2), Foo.id == FooValue.foo_id)
    by_count.group_by(Foo.id)
    by_count.having(Count() == 5)
    self.assertEquals(list(by_count), [16])
def test_find_having_without_group_by(self):
    """having() is rejected unless group_by() was used first."""
    ungrouped = self.store.find(FooValue)
    self.assertRaises(FeatureError, ungrouped.having, FooValue.value1 == 1)
def test_find_group_by_multiple_having(self):
    """Several having() conditions are all applied."""
    grouped = self.store.find((Count(), FooValue.value2))
    grouped.group_by(FooValue.value2)
    grouped.having(Count() == 2, FooValue.value2 >= 3)
    grouped.order_by(Count(), FooValue.value2)
    self.assertEquals(list(grouped), [(2, 3), (2, 4)])
def test_find_successive_group_by(self):
    """A second group_by() call replaces the earlier grouping."""
    counts = self.store.find(Count())
    counts.group_by(FooValue.value2)
    counts.order_by(Count())
    self.assertEquals(list(counts), [2, 2, 2, 3])
    counts.group_by(FooValue.value1)
    self.assertEquals(list(counts), [4, 5])
def test_find_multiple_group_by(self):
    """group_by() accepts several grouping expressions at once."""
    counts = self.store.find(Count())
    counts.group_by(FooValue.value2, FooValue.value1)
    counts.order_by(Count())
    self.assertEquals(list(counts), [1, 1, 2, 2, 3])
def test_find_multiple_group_by_with_having(self):
    """having() can be chained directly onto group_by()."""
    grouped = self.store.find((Count(), FooValue.value2))
    grouped.group_by(FooValue.value2, FooValue.value1).having(Count() == 2)
    grouped.order_by(Count(), FooValue.value2)
    self.assertEquals(list(grouped), [(2, 3), (2, 4)])
def test_find_group_by_avg(self):
    """avg() is rejected on a grouped result."""
    grouped = self.store.find((Count(FooValue.id), Sum(FooValue.value1)))
    grouped.group_by(FooValue.value2)
    self.assertRaises(FeatureError, grouped.avg, FooValue.value2)
def test_find_group_by_values(self):
    """values() works on a result grouped by a class."""
    grouped = self.store.find(
        (Sum(FooValue.value2), Foo), Foo.id == FooValue.foo_id)
    grouped.group_by(Foo)
    grouped.order_by(Foo.title)
    titles = list(grouped.values(Foo.title))
    self.assertEquals(titles, [u'Title 20', u'Title 30'])
def test_find_group_by_union(self):
    """group_by() is rejected on a union."""
    union = self.store.find(Foo, id=30).union(self.store.find(Foo, id=10))
    self.assertRaises(FeatureError, union.group_by, Foo.title)
def test_find_group_by_remove(self):
    """remove() is rejected on a grouped result."""
    grouped = self.store.find((Count(FooValue.id), Sum(FooValue.value1)))
    grouped.group_by(FooValue.value2)
    self.assertRaises(FeatureError, grouped.remove)
def test_find_group_by_set(self):
    """set() is rejected on a grouped result."""
    grouped = self.store.find((Count(FooValue.id), Sum(FooValue.value1)))
    grouped.group_by(FooValue.value2)
    self.assertRaises(FeatureError, grouped.set, FooValue.value1 == 1)
def test_add_commit(self):
    """An added object reaches the committed data only after commit()."""
    before = [
        (10, "Title 30"),
        (20, "Title 20"),
        (30, "Title 10"),
    ]
    new_foo = Foo()
    new_foo.id = 40
    new_foo.title = u"Title 40"
    self.store.add(new_foo)
    self.assertEquals(self.get_committed_items(), before)
    self.store.commit()
    self.assertEquals(self.get_committed_items(),
                      before + [(40, "Title 40")])
def test_add_rollback_commit(self):
    """Rolling back discards a pending add; a later commit stores nothing.

    The object is added with id 40, so that is the id that must be gone
    after the rollback.
    """
    foo = Foo()
    foo.id = 40
    foo.title = u"Title 40"
    self.store.add(foo)
    self.store.rollback()
    # Look up the id that was actually added. An id that never existed
    # would make this assertion pass trivially.
    self.assertEquals(self.store.get(Foo, 40), None)
    self.assertEquals(self.get_committed_items(), [
        (10, "Title 30"),
        (20, "Title 20"),
        (30, "Title 10"),
    ])
    self.store.commit()
    self.assertEquals(self.get_committed_items(), [
        (10, "Title 30"),
        (20, "Title 20"),
        (30, "Title 10"),
    ])
def test_add_get(self):
    """get() returns the very object that was added, before any flush."""
    added = Foo()
    added.id = 40
    added.title = u"Title 40"
    self.store.add(added)
    fetched = self.store.get(Foo, 40)
    self.assertEquals(fetched.id, 40)
    self.assertEquals(fetched.title, "Title 40")
    self.assertTrue(fetched is added)
def test_add_find(self):
    """find() returns the very object that was added."""
    added = Foo()
    added.id = 40
    added.title = u"Title 40"
    self.store.add(added)
    fetched = self.store.find(Foo, Foo.id == 40).one()
    self.assertEquals(fetched.id, 40)
    self.assertEquals(fetched.title, "Title 40")
    self.assertTrue(fetched is added)
def test_add_twice(self):
    """Adding the same object twice to one store is harmless."""
    foo = Foo()
    for _ in range(2):
        self.store.add(foo)
    self.assertEquals(Store.of(foo), self.store)
def test_add_loaded(self):
    """Adding an object that was loaded from the store is harmless."""
    loaded = self.store.get(Foo, 10)
    self.store.add(loaded)
    self.assertEquals(Store.of(loaded), self.store)
def test_add_twice_to_wrong_store(self):
    """Adding an object owned by one store to another raises WrongStoreError.

    The second store is closed afterwards so its connection is not leaked.
    """
    foo = Foo()
    self.store.add(foo)
    other_store = Store(self.database)
    try:
        self.assertRaises(WrongStoreError, other_store.add, foo)
    finally:
        other_store.close()
def test_add_checkpoints(self):
    """A flush after add() checkpoints, so later flushes write only changes."""
    bar = Bar()
    self.store.add(bar)
    bar.id = 400
    bar.title = u"Title 400"
    bar.foo_id = 40
    self.store.flush()
    # Change the title behind Storm's back.
    self.store.execute("UPDATE bar SET title='Title 500' "
                       "WHERE id=400")
    bar.foo_id = 400
    # Without checkpointing, this flush would write the old title again.
    self.store.flush()
    self.store.reload(bar)
    self.assertEquals(bar.title, "Title 500")
def test_add_completely_undefined(self):
    """An object added with no values gets database defaults on flush."""
    blank = Foo()
    self.store.add(blank)
    self.store.flush()
    self.assertEquals(type(blank.id), int)
    self.assertEquals(blank.title, u"Default Title")
def test_remove_commit(self):
    """A removed object leaves the store on flush and the data on commit."""
    everything = [
        (10, "Title 30"),
        (20, "Title 20"),
        (30, "Title 10"),
    ]
    without_20 = [
        (10, "Title 30"),
        (30, "Title 10"),
    ]
    foo = self.store.get(Foo, 20)
    self.store.remove(foo)
    # Until flushed, the object still belongs to the store.
    self.assertEquals(Store.of(foo), self.store)
    self.store.flush()
    self.assertEquals(Store.of(foo), None)
    self.assertEquals(self.get_items(), without_20)
    self.assertEquals(self.get_committed_items(), everything)
    self.store.commit()
    self.assertEquals(Store.of(foo), None)
    self.assertEquals(self.get_committed_items(), without_20)
def test_remove_rollback_update(self):
    """After a rolled-back remove(), the object can still be updated."""
    foo = self.store.get(Foo, 20)
    self.store.remove(foo)
    self.store.rollback()
    foo.title = u"Title 200"
    self.store.flush()
    expected = [
        (10, "Title 30"),
        (20, "Title 200"),
        (30, "Title 10"),
    ]
    self.assertEquals(self.get_items(), expected)
def test_remove_flush_rollback_update(self):
    """After a flushed then rolled-back remove(), updates are not written."""
    foo = self.store.get(Foo, 20)
    self.store.remove(foo)
    self.store.flush()
    self.store.rollback()
    foo.title = u"Title 200"
    self.store.flush()
    unchanged = [
        (10, "Title 30"),
        (20, "Title 20"),
        (30, "Title 10"),
    ]
    self.assertEquals(self.get_items(), unchanged)
def test_remove_add_update(self):
    """Removing then re-adding an object keeps it updatable."""
    foo = self.store.get(Foo, 20)
    self.store.remove(foo)
    self.store.add(foo)
    foo.title = u"Title 200"
    self.store.flush()
    expected = [
        (10, "Title 30"),
        (20, "Title 200"),
        (30, "Title 10"),
    ]
    self.assertEquals(self.get_items(), expected)
def test_remove_flush_add_update(self):
    """An object removed and flushed can be re-added and updated."""
    foo = self.store.get(Foo, 20)
    self.store.remove(foo)
    self.store.flush()
    self.store.add(foo)
    foo.title = u"Title 200"
    self.store.flush()
    expected = [
        (10, "Title 30"),
        (20, "Title 200"),
        (30, "Title 10"),
    ]
    self.assertEquals(self.get_items(), expected)
def test_remove_twice(self):
    """Removing the same object twice does not raise."""
    foo = self.store.get(Foo, 10)
    for _ in range(2):
        self.store.remove(foo)
def test_remove_unknown(self):
    """Removing an object the store never saw raises WrongStoreError."""
    stranger = Foo()
    self.assertRaises(WrongStoreError, self.store.remove, stranger)
def test_remove_from_wrong_store(self):
    """Removing an object through a store that doesn't own it raises
    WrongStoreError.

    The second store is closed afterwards so its connection is not leaked.
    """
    foo = self.store.get(Foo, 20)
    other_store = Store(self.database)
    try:
        self.assertRaises(WrongStoreError, other_store.remove, foo)
    finally:
        other_store.close()
def test_wb_remove_flush_update_isnt_dirty(self):
    """Changing an object after its removal was flushed doesn't dirty it."""
    foo = self.store.get(Foo, 20)
    info = get_obj_info(foo)
    self.store.remove(foo)
    self.store.flush()
    foo.title = u"Title 200"
    self.assertFalse(info in self.store._dirty)
def test_wb_remove_rollback_isnt_dirty(self):
    """Rolling back a pending remove() leaves the object clean."""
    foo = self.store.get(Foo, 20)
    info = get_obj_info(foo)
    self.store.remove(foo)
    self.store.rollback()
    self.assertFalse(info in self.store._dirty)
def test_wb_remove_flush_rollback_isnt_dirty(self):
    """Rolling back a flushed remove() leaves the object clean."""
    foo = self.store.get(Foo, 20)
    info = get_obj_info(foo)
    self.store.remove(foo)
    self.store.flush()
    self.store.rollback()
    self.assertFalse(info in self.store._dirty)
def test_add_rollback_not_in_store(self):
    """After rollback, an added-but-uncommitted object has no store."""
    pending = Foo()
    pending.id = 40
    pending.title = u"Title 40"
    self.store.add(pending)
    self.store.rollback()
    self.assertEquals(Store.of(pending), None)
def test_update_flush_commit(self):
    """A change is visible after flush() and committed after commit()."""
    original = [
        (10, "Title 30"),
        (20, "Title 20"),
        (30, "Title 10"),
    ]
    modified = [
        (10, "Title 30"),
        (20, "Title 200"),
        (30, "Title 10"),
    ]
    foo = self.store.get(Foo, 20)
    foo.title = u"Title 200"
    # Nothing has been written before flushing.
    self.assertEquals(self.get_items(), original)
    self.assertEquals(self.get_committed_items(), original)
    self.store.flush()
    # Flushed, but not committed yet.
    self.assertEquals(self.get_items(), modified)
    self.assertEquals(self.get_committed_items(), original)
    self.store.commit()
    self.assertEquals(self.get_committed_items(), modified)
def test_update_flush_reload_rollback(self):
    """Rollback restores the old value even after flush() and reload()."""
    foo = self.store.get(Foo, 20)
    foo.title = u"Title 200"
    self.store.flush()
    self.store.reload(foo)
    self.store.rollback()
    self.assertEquals(foo.title, "Title 20")
def test_update_commit(self):
    """commit() flushes and persists pending changes."""
    foo = self.store.get(Foo, 20)
    foo.title = u"Title 200"
    self.store.commit()
    expected = [
        (10, "Title 30"),
        (20, "Title 200"),
        (30, "Title 10"),
    ]
    self.assertEquals(self.get_committed_items(), expected)
def test_update_commit_twice(self):
    """An object stays tracked for changes across consecutive commits."""
    foo = self.store.get(Foo, 20)
    for new_title in (u"Title 200", u"Title 2000"):
        foo.title = new_title
        self.store.commit()
    expected = [
        (10, "Title 30"),
        (20, "Title 2000"),
        (30, "Title 10"),
    ]
    self.assertEquals(self.get_committed_items(), expected)
def test_update_checkpoints(self):
    """A flush checkpoints, so later flushes write only new changes."""
    bar = self.store.get(Bar, 200)
    bar.title = u"Title 400"
    self.store.flush()
    # Change the title behind Storm's back.
    self.store.execute("UPDATE bar SET title='Title 500' "
                       "WHERE id=200")
    bar.foo_id = 40
    # Without checkpointing, this flush would write the old title again.
    self.store.flush()
    self.store.reload(bar)
    self.assertEquals(bar.title, "Title 500")
def test_update_primary_key(self):
    """A primary-key change is persisted, and get() follows the new key."""
    foo = self.store.get(Foo, 20)
    # Change the key twice, to check that the store's notion of the key
    # for this existing object is updated each time.
    for new_id in (25, 27):
        foo.id = new_id
        self.store.commit()
        self.assertEquals(self.get_committed_items(), [
            (10, "Title 30"),
            (new_id, "Title 20"),
            (30, "Title 10"),
        ])
    # Only the current key may resolve to the object.
    self.assertTrue(self.store.get(Foo, 27) is foo)
    self.assertTrue(self.store.get(Foo, 25) is None)
    self.assertTrue(self.store.get(Foo, 20) is None)
def test_update_primary_key_exchange(self):
    """Two objects can swap primary keys through an intermediate value."""
    low = self.store.get(Foo, 10)
    high = self.store.get(Foo, 30)
    low.id = 40
    self.store.flush()
    high.id = 10
    self.store.flush()
    low.id = 30
    self.assertTrue(self.store.get(Foo, 30) is low)
    self.assertTrue(self.store.get(Foo, 10) is high)
    self.store.commit()
    swapped = [
        (10, "Title 10"),
        (20, "Title 20"),
        (30, "Title 30"),
    ]
    self.assertEquals(self.get_committed_items(), swapped)
def test_wb_update_not_dirty_after_flush(self):
    """flush() clears the dirty flag of the objects it wrote."""
    foo = self.store.get(Foo, 20)
    foo.title = u"Title 200"
    self.store.flush()
    # With change notification disabled, the next change can only be
    # written if the dirty flag was left set by the previous flush.
    self.store._disable_change_notification(get_obj_info(foo))
    foo.title = u"Title 2000"
    self.store.flush()
    expected = [
        (10, "Title 30"),
        (20, "Title 200"),
        (30, "Title 10"),
    ]
    self.assertEquals(self.get_items(), expected)
def test_update_find(self):
    """find() sees unflushed changes and returns the same object."""
    foo = self.store.get(Foo, 20)
    foo.title = u"Title 200"
    match = self.store.find(Foo, Foo.title == u"Title 200").one()
    self.assertTrue(match is foo)
def test_update_get(self):
    """get() finds an object under its new, unflushed primary key."""
    foo = self.store.get(Foo, 20)
    foo.id = 200
    fetched = self.store.get(Foo, 200)
    self.assertTrue(fetched is foo)
def test_add_update(self):
    """Changes made after add() but before flush() are written."""
    added = Foo()
    added.id = 40
    added.title = u"Title 40"
    self.store.add(added)
    added.title = u"Title 400"
    self.store.flush()
    expected = [
        (10, "Title 30"),
        (20, "Title 20"),
        (30, "Title 10"),
        (40, "Title 400"),
    ]
    self.assertEquals(self.get_items(), expected)
def test_add_remove_add(self):
    """add(), remove(), add() again: the final state is what gets stored."""
    foo = Foo()
    foo.id = 40
    foo.title = u"Title 40"
    self.store.add(foo)
    self.store.remove(foo)
    self.assertEquals(Store.of(foo), None)
    foo.title = u"Title 400"
    self.store.add(foo)
    foo.id = 400
    self.store.commit()
    self.assertEquals(Store.of(foo), self.store)
    expected = [
        (10, "Title 30"),
        (20, "Title 20"),
        (30, "Title 10"),
        (400, "Title 400"),
    ]
    self.assertEquals(self.get_items(), expected)
    self.assertTrue(self.store.get(Foo, 400) is foo)
def test_wb_add_remove_add(self):
    """The dirty set follows add()/remove()/add() on a new object."""
    foo = Foo()
    info = get_obj_info(foo)
    dirty = self.store._dirty
    self.store.add(foo)
    self.assertTrue(info in dirty)
    self.store.remove(foo)
    self.assertFalse(info in dirty)
    self.store.add(foo)
    self.assertTrue(info in dirty)
    self.assertTrue(Store.of(foo) is self.store)
def test_wb_update_remove_add(self):
    """A modified object removed and re-added is still dirty."""
    foo = self.store.get(Foo, 20)
    foo.title = u"Title 200"
    info = get_obj_info(foo)
    self.store.remove(foo)
    self.store.add(foo)
    self.assertTrue(info in self.store._dirty)
def test_commit_autoreloads(self):
    """Attributes are reloaded from the database after commit()."""
    foo = self.store.get(Foo, 20)
    self.assertEquals(foo.title, "Title 20")
    self.store.execute("UPDATE foo SET title='New Title' WHERE id=20")
    # The raw UPDATE is not visible on the object before commit.
    self.assertEquals(foo.title, "Title 20")
    self.store.commit()
    self.assertEquals(foo.title, "New Title")
def test_commit_invalidates(self):
    """After commit(), get() notices rows deleted behind the store's back."""
    foo = self.store.get(Foo, 20)
    self.assertTrue(foo)
    self.store.execute("DELETE FROM foo WHERE id=20")
    # The cached object is still returned before commit.
    self.assertEquals(self.store.get(Foo, 20), foo)
    self.store.commit()
    self.assertEquals(self.store.get(Foo, 20), None)
def test_rollback_autoreloads(self):
    """Attributes are reloaded from the database after rollback()."""
    foo = self.store.get(Foo, 20)
    self.assertEquals(foo.title, "Title 20")
    self.store.rollback()
    self.store.execute("UPDATE foo SET title='New Title' WHERE id=20")
    self.assertEquals(foo.title, "New Title")
def test_rollback_invalidates(self):
    """After rollback(), get() notices rows that have since been deleted."""
    foo = self.store.get(Foo, 20)
    self.assertTrue(foo)
    self.assertEquals(self.store.get(Foo, 20), foo)
    self.store.rollback()
    self.store.execute("DELETE FROM foo WHERE id=20")
    self.assertEquals(self.store.get(Foo, 20), None)
def test_sub_class(self):
    """A subclass can redeclare a property with a different type."""
    class SubFoo(Foo):
        id = Float(primary=True)

    base = self.store.get(Foo, 20)
    derived = self.store.get(SubFoo, 20)
    self.assertEquals(base.id, 20)
    self.assertEquals(derived.id, 20)
    self.assertEquals(type(base.id), int)
    self.assertEquals(type(derived.id), float)
def test_join(self):
    """An implicit join returns one row per match; DISTINCT is off by default.

    Two bars share the title "Title 20", so Foo 20 matches twice and must
    come back twice.
    """
    class Bar(object):
        __storm_table__ = "bar"
        id = Int(primary=True)
        title = Unicode()

    bar = Bar()
    bar.id = 40
    bar.title = u"Title 20"
    self.store.add(bar)
    # Add another bar with the same title, so the join matches Foo 20
    # twice. No DISTINCT is applied by default, so both rows must show up.
    bar = Bar()
    bar.id = 400
    bar.title = u"Title 20"
    self.store.add(bar)
    result = self.store.find(Foo, Foo.title == Bar.title)
    self.assertEquals([(foo.id, foo.title) for foo in result], [
        (20, "Title 20"),
        (20, "Title 20"),
    ])
- def test_join_distinct(self):
- class Bar(object):
- __storm_table__ = "bar"
- id = Int(primary=True)
- title = Unicode()
- bar = Bar()
- bar.id = 40
- bar.title = u"Title 20"
- self.store.add(bar)
- # Add a bar object with the same title to ensure DISTINCT
- # is in place.
- bar = Bar()
- bar.id = 400
- bar.title = u"Title 20"
- self.store.add(bar)
- result = self.store.find(Foo, Foo.title == Bar.title)
- result.config(distinct=True)
- # Make sure that it won't unset it, and that it's returning itself.
- config = result.config()
- self.assertEquals([(foo.id, foo.title) for foo in result], [
- (20, "Title 20"),
- ])
- def test_sub_select(self):
- foo = self.store.find(Foo, Foo.id == Select(SQL("20"))).one()
- self.assertTrue(foo)
- self.assertEquals(foo.id, 20)
- self.assertEquals(foo.title, "Title 20")
def test_cache_has_improper_object(self):
    # Once a removal is committed, re-inserting the row with raw SQL must
    # not hand back the removed object from the cache.
    store = self.store
    removed = store.get(Foo, 20)
    store.remove(removed)
    store.commit()
    store.execute("INSERT INTO foo VALUES (20, 'Title 20')")
    self.assertTrue(store.get(Foo, 20) is not removed)

def test_cache_has_improper_object_readded(self):
    # A new object added under the key of a flushed removal is the one
    # get() returns after commit.
    store = self.store
    # old_foo stays referenced until the end of the test, so the removed
    # object cannot be garbage collected in the meantime.
    old_foo = store.get(Foo, 20)
    store.remove(old_foo)
    store.flush()
    replacement = Foo()
    replacement.id = 20
    replacement.title = u"Readded"
    store.add(replacement)
    store.commit()
    self.assertTrue(store.get(Foo, 20) is replacement)
def test_loaded_hook(self):
    # Loading calls __storm_loaded__ (never __init__); changes made in the
    # hook are flushed like ordinary changes.
    calls = []

    class MyFoo(Foo):
        def __init__(self):
            calls.append("NO!")

        def __storm_loaded__(self):
            calls.append((self.id, self.title))
            self.title = u"Title 200"
            self.some_attribute = 1

    store = self.store
    obj = store.get(MyFoo, 20)
    self.assertEquals(calls, [(20, "Title 20")])
    self.assertEquals(obj.title, "Title 200")
    self.assertEquals(obj.some_attribute, 1)
    obj.some_attribute = 2
    store.flush()
    expected = [
        (10, "Title 30"),
        (20, "Title 200"),
        (30, "Title 10"),
    ]
    self.assertEquals(self.get_items(), expected)
    store.rollback()
    # The column goes back to the database value; the plain attribute
    # keeps what was assigned to it.
    self.assertEquals(obj.title, "Title 20")
    self.assertEquals(obj.some_attribute, 2)

def test_flush_hook(self):
    # __storm_pre_flush__ runs for dirty objects only. Changes it makes are
    # flushed as well, and nothing happens once the object is clean.
    class MyFoo(Foo):
        counter = 0

        def __storm_pre_flush__(self):
            if not self.counter:
                self.title = u"Flushing: %s" % self.title
            self.counter += 1

    store = self.store
    obj = store.get(MyFoo, 20)
    self.assertEquals(obj.title, "Title 20")
    store.flush()
    # The object was clean, so the hook left the title alone.
    self.assertEquals(obj.title, "Title 20")
    obj.title = u"Something"
    store.flush()
    self.assertEquals(obj.title, "Flushing: Something")
    # The hook dirtied the object again while flushing, so it was flushed a
    # second time and the hook's title reached the database.
    expected = [
        (10, "Title 30"),
        (20, "Flushing: Something"),
        (30, "Title 10"),
    ]
    self.assertEquals(self.get_items(), expected)
    # Resetting the counter must have no effect: the object is clean now.
    obj.counter = 0
    store.flush()
    self.assertEquals(obj.title, "Flushing: Something")

def test_flush_hook_all(self):
    # Each object's pre-flush hook changes the *other* object, so the hook
    # must also run for objects that became dirty during the flush.
    class MyFoo(Foo):
        def __storm_pre_flush__(self):
            if self is first:
                partner = second
            else:
                partner = first
            partner.title = u"Changed in hook: " + partner.title

    first = self.store.get(MyFoo, 10)
    second = self.store.get(MyFoo, 20)
    first.title = u"Changed"
    self.store.flush()
    self.assertEquals(first.title, "Changed in hook: Changed")
    self.assertEquals(second.title, "Changed in hook: Title 20")

def test_flushed_hook(self):
    # __storm_flushed__ runs after a dirty object is flushed; a change made
    # there dirties the object again and is flushed too.
    class MyFoo(Foo):
        done = False

        def __storm_flushed__(self):
            if self.done:
                return
            self.done = True
            self.title = u"Flushed: %s" % self.title

    store = self.store
    obj = store.get(MyFoo, 20)
    self.assertEquals(obj.title, "Title 20")
    store.flush()
    # The object was clean, so the hook left the title alone.
    self.assertEquals(obj.title, "Title 20")
    obj.title = u"Something"
    store.flush()
    self.assertEquals(obj.title, "Flushed: Something")
    # The hook dirtied the object again after it was flushed, so it was
    # flushed a second time and the hook's title reached the database.
    expected = [
        (10, "Title 30"),
        (20, "Flushed: Something"),
        (30, "Title 10"),
    ]
    self.assertEquals(self.get_items(), expected)
    # Re-arming the hook must have no effect: the object is clean now.
    obj.done = False
    store.flush()
    self.assertEquals(obj.title, "Flushed: Something")
def test_retrieve_default_primary_key(self):
    # A primary key left unset is filled in when the object is flushed.
    store = self.store
    obj = Foo()
    obj.title = u"Title 40"
    store.add(obj)
    store.flush()
    self.assertNotEquals(obj.id, None)
    self.assertTrue(store.get(Foo, obj.id) is obj)

def test_retrieve_default_value(self):
    # A column left unset picks up its default value after the flush.
    store = self.store
    obj = Foo()
    obj.id = 40
    store.add(obj)
    store.flush()
    self.assertEquals(obj.title, "Default Title")

def test_retrieve_null_when_no_default(self):
    # A column left unset with no default comes back as None.
    store = self.store
    obj = Bar()
    obj.id = 400
    store.add(obj)
    store.flush()
    self.assertEquals(obj.title, None)

def test_wb_remove_prop_not_dirty(self):
    # Deleting an attribute does not by itself make the object dirty.
    obj = self.store.get(Foo, 20)
    info = get_obj_info(obj)
    del obj.title
    self.assertTrue(info not in self.store._dirty)

def test_flush_with_removed_prop(self):
    # Flushing after deleting an attribute leaves the row untouched.
    obj = self.store.get(Foo, 20)
    del obj.title
    self.store.flush()
    unchanged = [
        (10, "Title 30"),
        (20, "Title 20"),
        (30, "Title 10"),
    ]
    self.assertEquals(self.get_items(), unchanged)

def test_flush_with_removed_prop_forced_dirty(self):
    # Changing the key and setting it back forces the object dirty; the
    # deleted attribute must still not be written as NULL.
    obj = self.store.get(Foo, 20)
    del obj.title
    obj.id = 40
    obj.id = 20
    self.store.flush()
    unchanged = [
        (10, "Title 30"),
        (20, "Title 20"),
        (30, "Title 10"),
    ]
    self.assertEquals(self.get_items(), unchanged)

def test_flush_with_removed_prop_really_dirty(self):
    # A real key change is flushed while the deleted attribute keeps its
    # stored value.
    obj = self.store.get(Foo, 20)
    del obj.title
    obj.id = 25
    self.store.flush()
    expected = [
        (10, "Title 30"),
        (25, "Title 20"),
        (30, "Title 10"),
    ]
    self.assertEquals(self.get_items(), expected)
def test_wb_block_implicit_flushes(self):
    # Replace flush() with a function that fails loudly, so any implicit
    # flush is detected.
    def flush():
        raise RuntimeError("Flush called")
    self.store.flush = flush
    # While implicit flushes are blocked, none of these operations may
    # call flush().
    self.store.block_implicit_flushes()
    self.store.get(Foo, 20)
    self.store.find(Foo, Foo.id == 20).one()
    self.store.execute("SELECT title FROM foo WHERE id = 20")
    self.store.unblock_implicit_flushes()
    # Once unblocked, get() flushes implicitly again.
    self.assertRaises(RuntimeError, self.store.get, Foo, 20)

def test_wb_block_implicit_flushes_is_recursive(self):
    # Replace flush() with a function that fails loudly, so any implicit
    # flush is detected.
    def flush():
        raise RuntimeError("Flush called")
    self.store.flush = flush
    # Blocking twice requires unblocking twice.
    self.store.block_implicit_flushes()
    self.store.block_implicit_flushes()
    self.store.unblock_implicit_flushes()
    # Implicit flushes stay blocked until unblock is called again, so this
    # get() must not flush.
    self.store.get(Foo, 20)
    self.store.unblock_implicit_flushes()
    self.assertRaises(RuntimeError, self.store.get, Foo, 20)

def test_block_access(self):
    """Access to the store is blocked by block_access()."""
    # block_access() makes statement execution and commit() fail with
    # ConnectionBlockedError.
    self.store.block_access()
    self.assertRaises(ConnectionBlockedError,
                      self.store.execute, "SELECT 1")
    self.assertRaises(ConnectionBlockedError, self.store.commit)
    # The rollback method is not blocked.
    self.store.rollback()
    # After unblock_access(), statements can be executed again.
    self.store.unblock_access()
    self.store.execute("SELECT 1")
def test_reload(self):
    # reload() refreshes an object with changes made behind its back.
    store = self.store
    obj = store.get(Foo, 20)
    store.execute("UPDATE foo SET title='Title 40' WHERE id=20")
    self.assertEquals(obj.title, "Title 20")
    store.reload(obj)
    self.assertEquals(obj.title, "Title 40")

def test_reload_not_changed(self):
    # After reload(), no variable of the object reports a change.
    store = self.store
    obj = store.get(Foo, 20)
    store.execute("UPDATE foo SET title='Title 40' WHERE id=20")
    store.reload(obj)
    variables = get_obj_info(obj).variables
    for var in variables.values():
        self.assertFalse(var.has_changed())

def test_reload_new(self):
    # An object that was never added to a store cannot be reloaded.
    obj = Foo()
    obj.id = 40
    obj.title = u"Title 40"
    self.assertRaises(WrongStoreError, self.store.reload, obj)

def test_reload_new_unflushed(self):
    # An added but not yet flushed object cannot be reloaded.
    obj = Foo()
    obj.id = 40
    obj.title = u"Title 40"
    self.store.add(obj)
    self.assertRaises(NotFlushedError, self.store.reload, obj)

def test_reload_removed(self):
    # Once its removal is flushed, the object no longer belongs to the store.
    store = self.store
    obj = store.get(Foo, 20)
    store.remove(obj)
    store.flush()
    self.assertRaises(WrongStoreError, store.reload, obj)

def test_reload_unknown(self):
    # A store cannot reload an object that belongs to another store.
    obj = self.store.get(Foo, 20)
    other_store = self.create_store()
    self.assertRaises(WrongStoreError, other_store.reload, obj)

def test_wb_reload_not_dirty(self):
    # reload() discards pending changes, so the object is no longer dirty.
    obj = self.store.get(Foo, 20)
    info = get_obj_info(obj)
    obj.title = u"Title 40"
    self.store.reload(obj)
    self.assertTrue(info not in self.store._dirty)
def test_find_set_empty(self):
    # set() with no arguments leaves the matched rows untouched.
    matched = self.store.find(Foo, title=u"Title 20")
    matched.set()
    self.assertEquals(self.store.get(Foo, 20).title, "Title 20")

def test_find_set(self):
    # set() with a keyword argument assigns a plain value.
    matched = self.store.find(Foo, title=u"Title 20")
    matched.set(title=u"Title 40")
    self.assertEquals(self.store.get(Foo, 20).title, "Title 40")

def test_find_set_with_func_expr(self):
    # A keyword argument may carry an SQL function expression.
    matched = self.store.find(Foo, title=u"Title 20")
    matched.set(title=Lower(u"Title 40"))
    self.assertEquals(self.store.get(Foo, 20).title, "title 40")

def test_find_set_equality_with_func_expr(self):
    # An equality expression may carry an SQL function expression.
    matched = self.store.find(Foo, title=u"Title 20")
    matched.set(Foo.title == Lower(u"Title 40"))
    self.assertEquals(self.store.get(Foo, 20).title, "title 40")

def test_find_set_column(self):
    # A keyword argument may reference another column.
    matched = self.store.find(Bar, title=u"Title 200")
    matched.set(foo_id=Bar.id)
    self.assertEquals(self.store.get(Bar, 200).foo_id, 200)

def test_find_set_expr(self):
    # An equality expression assigns a plain value.
    matched = self.store.find(Foo, title=u"Title 20")
    matched.set(Foo.title == u"Title 40")
    self.assertEquals(self.store.get(Foo, 20).title, "Title 40")

def test_find_set_none(self):
    # Setting a column to None writes NULL.
    matched = self.store.find(Foo, title=u"Title 20")
    matched.set(title=None)
    self.assertEquals(self.store.get(Foo, 20).title, None)

def test_find_set_expr_column(self):
    # An equality expression may reference another column.
    self.store.find(Bar, id=200).set(Bar.foo_id == Bar.id)
    row = self.store.get(Bar, 200)
    self.assertEquals(row.id, 200)
    self.assertEquals(row.foo_id, 200)

def test_find_set_on_cached(self):
    # Cached objects matching the condition are updated in memory too;
    # other cached objects are left alone.
    hit = self.store.get(Foo, 20)
    miss = self.store.get(Foo, 30)
    self.store.find(Foo, id=20).set(id=40)
    self.assertEquals(hit.id, 40)
    self.assertEquals(miss.id, 30)

def test_find_set_expr_on_cached(self):
    # A column-to-column assignment is reflected on the cached object.
    cached = self.store.get(Bar, 200)
    self.store.find(Bar, id=200).set(Bar.foo_id == Bar.id)
    self.assertEquals(cached.id, 200)
    self.assertEquals(cached.foo_id, 200)

def test_find_set_none_on_cached(self):
    # Setting None is reflected on the cached object.
    cached = self.store.get(Foo, 20)
    self.store.find(Foo, title=u"Title 20").set(title=None)
    self.assertEquals(cached.title, None)

def test_find_set_on_cached_but_removed(self):
    # A cached object already removed from the store is not updated.
    removed = self.store.get(Foo, 20)
    other = self.store.get(Foo, 30)
    self.store.remove(removed)
    self.store.find(Foo, id=20).set(id=40)
    self.assertEquals(removed.id, 20)
    self.assertEquals(other.id, 30)

def test_find_set_on_cached_unsupported_python_expr(self):
    # Cached objects see the new value even when the condition (here a
    # sub-select) cannot be checked in Python.
    hit = self.store.get(Foo, 20)
    miss = self.store.get(Foo, 30)
    condition = Foo.id == Select(SQL("20"))
    self.store.find(Foo, condition).set(title=u"Title 40")
    self.assertEquals(hit.title, "Title 40")
    self.assertEquals(miss.title, "Title 10")

def test_find_set_expr_unsupported(self):
    # Only equality expressions are accepted by set().
    matched = self.store.find(Foo, title=u"Title 20")
    self.assertRaises(FeatureError, matched.set, Foo.title > u"Title 40")

def test_find_set_expr_unsupported_without_column(self):
    # The left side of the equality must be a column.
    matched = self.store.find(Foo, title=u"Title 20")
    self.assertRaises(FeatureError, matched.set,
                      Eq(object(), IntVariable(1)))

def test_find_set_expr_unsupported_without_expr_or_variable(self):
    # The right side of the equality must be an expression or a variable.
    matched = self.store.find(Foo, title=u"Title 20")
    self.assertRaises(FeatureError, matched.set, Eq(Foo.id, object()))
def test_find_set_expr_unsupported_autoreloads(self):
    # When the condition cannot be evaluated against cached objects, the
    # assigned column of every cached Bar is marked AutoReload; the other
    # columns are left alone.
    first = self.store.get(Bar, 200)
    second = self.store.get(Bar, 300)
    self.store.find(Bar, id=Select(SQL("200"))).set(title=u"Title 400")
    all_vars = (get_obj_info(first).variables,
                get_obj_info(second).variables)
    for variables in all_vars:
        self.assertEquals(variables[Bar.title].get_lazy(), AutoReload)
    for variables in all_vars:
        self.assertEquals(variables[Bar.foo_id].get_lazy(), None)
    self.assertEquals(first.title, "Title 400")
    self.assertEquals(second.title, "Title 100")

def test_find_set_expr_unsupported_mixed_autoreloads(self):
    # When the expression does not compile (e.g. ResultSet.cached() raises
    # CompileError), the cached entries' columns are set to AutoReload.
    # With objects of different classes in the cache, this used to raise
    # KeyError for objects lacking the column. See Bug #328603.
    foo_obj = self.store.get(Foo, 20)
    bar_obj = self.store.get(Bar, 200)
    self.store.find(Bar, id=Select(SQL("200"))).set(title=u"Title 400")
    foo_vars = get_obj_info(foo_obj).variables
    bar_vars = get_obj_info(bar_obj).variables
    self.assertNotEquals(foo_vars[Foo.title].get_lazy(), AutoReload)
    self.assertEquals(bar_vars[Bar.title].get_lazy(), AutoReload)
    self.assertEquals(bar_vars[Bar.foo_id].get_lazy(), None)
    self.assertEquals(foo_obj.title, "Title 20")
    self.assertEquals(bar_obj.title, "Title 400")

def test_find_set_autoreloads_with_func_expr(self):
    # An earlier, buggy fix for this case evaluated the expression twice.
    # The expression increments by one, so a doubled update would show up
    # as an increment of two.
    obj = self.store.get(FooValue, 1)
    self.assertEquals(obj.value1, 2)
    self.store.find(FooValue, id=1).set(value1=SQL("value1 + 1"))
    variable = get_obj_info(obj).variables[FooValue.value1]
    self.assertEquals(variable.get_lazy(), AutoReload)
    self.assertEquals(obj.value1, 3)

def test_find_set_equality_autoreloads_with_func_expr(self):
    # Same check as above, using the equality form of set().
    obj = self.store.get(FooValue, 1)
    self.assertEquals(obj.value1, 2)
    self.store.find(FooValue, id=1).set(
        FooValue.value1 == SQL("value1 + 1"))
    variable = get_obj_info(obj).variables[FooValue.value1]
    self.assertEquals(variable.get_lazy(), AutoReload)
    self.assertEquals(obj.value1, 3)

def test_wb_find_set_checkpoints(self):
    # set() checkpoints the cached object; a later flush must not write the
    # title again over a change made directly on the connection.
    store = self.store
    cached = store.get(Bar, 200)
    store.find(Bar, id=200).set(title=u"Title 400")
    store._connection.execute("UPDATE bar SET "
                              "title='Title 500' "
                              "WHERE id=200")
    # Without the checkpoint, this flush would set the title once more.
    store.flush()
    store.reload(cached)
    self.assertEquals(cached.title, "Title 500")

def test_find_set_with_info_alive_and_object_dead(self):
    # The cache holds strong references, so disable it first.
    self.get_cache(self.store).set_size(0)
    obj = self.store.get(Foo, 20)
    obj.tainted = True
    # Keep only the obj_info alive and let the object itself be collected.
    info = get_obj_info(obj)
    del obj
    gc.collect()
    self.store.find(Foo, title=u"Title 20").set(title=u"Title 40")
    fresh = self.store.get(Foo, 20)
    self.assertFalse(hasattr(fresh, "tainted"))
    self.assertEquals(fresh.title, "Title 40")
def test_reference(self):
    # A Reference resolves to the related object.
    row = self.store.get(Bar, 100)
    self.assertTrue(row.foo)
    self.assertEquals(row.foo.title, "Title 30")

def test_reference_explicitly_with_wrapper(self):
    # The descriptor also works when handed a wrapper around the object.
    row = self.store.get(Bar, 100)
    target = Bar.foo.__get__(Wrapper(row))
    self.assertTrue(target)
    self.assertEquals(target.title, "Title 30")

def test_reference_break_on_local_diverged(self):
    # Changing the local key breaks the link to the old remote object.
    row = self.store.get(Bar, 100)
    self.assertTrue(row.foo)
    row.foo_id = 40
    self.assertEquals(row.foo, None)

def test_reference_break_on_remote_diverged(self):
    # Changing the remote key also breaks the link.
    row = self.store.get(Bar, 100)
    row.foo.id = 40
    self.assertEquals(row.foo, None)

def test_reference_break_on_local_diverged_by_lazy(self):
    # A lazy SQL value on the local key is resolved before following it.
    row = self.store.get(Bar, 100)
    self.assertEquals(row.foo.id, 10)
    row.foo_id = SQL("20")
    self.assertEquals(row.foo.id, 20)

def test_reference_remote_leak_on_flush_with_changed(self):
    """
    The "changed" event keeps only a weak reference to the remote info
    object, so unhooking it does not leak.
    """
    self.get_cache(self.store).set_size(0)
    row = self.store.get(Bar, 100)
    row.foo.title = u"Changed title"
    row_info_ref = weakref.ref(get_obj_info(row))
    # Keep the remote object alive while the local one goes away.
    remote = row.foo
    del row
    self.store.flush()
    gc.collect()
    self.assertEquals(row_info_ref(), None)

def test_reference_remote_leak_on_flush_with_removed(self):
    """
    The "removed" event keeps only a weak reference to the remote info
    object, so unhooking it does not leak.
    """
    self.get_cache(self.store).set_size(0)
    class MyFoo(Foo):
        bar = Reference(Foo.id, Bar.foo_id, on_remote=True)
    row = self.store.get(MyFoo, 10)
    row.bar.title = u"Changed title"
    row_info_ref = weakref.ref(get_obj_info(row))
    # Keep the remote object alive while the local one goes away.
    remote = row.bar
    del row
    self.store.flush()
    gc.collect()
    self.assertEquals(row_info_ref(), None)

def test_reference_break_on_remote_diverged_by_lazy(self):
    # A lazy SQL value assigned to the remote key breaks the link and is
    # written to the database.
    class MyBar(Bar):
        pass
    MyBar.foo = Reference(MyBar.title, Foo.title)
    store = self.store
    row = store.get(MyBar, 100)
    row.title = u"Title 30"
    store.flush()
    self.assertEquals(row.foo.id, 10)
    row.foo.title = SQL("'Title 40'")
    self.assertEquals(row.foo, None)
    self.assertEquals(store.find(Foo, title=u"Title 30").one(), None)
    self.assertEquals(store.get(Foo, 10).title, u"Title 40")

def test_reference_on_non_primary_key(self):
    # A Reference may join on a column that is not the primary key.
    store = self.store
    store.execute("INSERT INTO bar VALUES (400, 40, 'Title 30')")
    class MyBar(Bar):
        foo = Reference(Bar.title, Foo.title)
    plain = store.get(Bar, 400)
    self.assertEquals(plain.title, "Title 30")
    self.assertEquals(plain.foo, None)
    linked = store.get(MyBar, 400)
    self.assertEquals(linked.title, "Title 30")
    self.assertNotEquals(linked.foo, None)
    self.assertEquals(linked.foo.id, 10)
    self.assertEquals(linked.foo.title, "Title 30")
def test_new_reference(self):
    # A reference on an object outside any store resolves to None; once
    # the object is added, it resolves.
    row = Bar()
    row.id = 400
    row.title = u"Title 400"
    row.foo_id = 10
    self.assertEquals(row.foo, None)
    self.store.add(row)
    self.assertTrue(row.foo)
    self.assertEquals(row.foo.title, "Title 30")

def test_set_reference(self):
    # Assigning an object to a reference updates the local key.
    store = self.store
    row = store.get(Bar, 100)
    self.assertEquals(row.foo.id, 10)
    target = store.get(Foo, 30)
    row.foo = target
    self.assertEquals(row.foo.id, 30)
    result = store.execute("SELECT foo_id FROM bar WHERE id=100")
    self.assertEquals(result.get_one(), (30,))

def test_set_reference_explicitly_with_wrapper(self):
    # The descriptor's setter accepts wrapped objects on both sides.
    store = self.store
    row = store.get(Bar, 100)
    self.assertEquals(row.foo.id, 10)
    target = store.get(Foo, 30)
    Bar.foo.__set__(Wrapper(row), Wrapper(target))
    self.assertEquals(row.foo.id, 30)
    result = store.execute("SELECT foo_id FROM bar WHERE id=100")
    self.assertEquals(result.get_one(), (30,))

def test_reference_assign_remote_key(self):
    # Assigning a bare key value to a reference sets the local key.
    store = self.store
    row = store.get(Bar, 100)
    self.assertEquals(row.foo.id, 10)
    row.foo = 30
    self.assertEquals(row.foo_id, 30)
    self.assertEquals(row.foo.id, 30)
    result = store.execute("SELECT foo_id FROM bar WHERE id=100")
    self.assertEquals(result.get_one(), (30,))

def test_reference_on_added(self):
    # Linking to an added but unflushed object picks up its key on flush.
    store = self.store
    target = Foo()
    target.title = u"Title 40"
    store.add(target)
    row = Bar()
    row.id = 400
    row.title = u"Title 400"
    row.foo = target
    store.add(row)
    self.assertEquals(row.foo.id, None)
    self.assertEquals(row.foo.title, "Title 40")
    store.flush()
    self.assertTrue(row.foo.id)
    self.assertEquals(row.foo.title, "Title 40")
    result = store.execute("SELECT foo.title FROM foo, bar "
                           "WHERE bar.id=400 AND "
                           "foo.id = bar.foo_id")
    self.assertEquals(result.get_one(), ("Title 40",))

def test_reference_on_added_with_autoreload_key(self):
    # Same as above, but the remote key is explicitly set to AutoReload.
    store = self.store
    target = Foo()
    target.title = u"Title 40"
    store.add(target)
    row = Bar()
    row.id = 400
    row.title = u"Title 400"
    row.foo = target
    store.add(row)
    self.assertEquals(row.foo.id, None)
    self.assertEquals(row.foo.title, "Title 40")
    target.id = AutoReload
    # The variable must not have been reloaded yet.
    key_variable = get_obj_info(target).variables[Foo.id]
    self.assertEquals(key_variable.get_lazy(), AutoReload)
    self.assertEquals(type(target.id), int)
    store.flush()
    self.assertTrue(row.foo.id)
    self.assertEquals(row.foo.title, "Title 40")
    result = store.execute("SELECT foo.title FROM foo, bar "
                           "WHERE bar.id=400 AND "
                           "foo.id = bar.foo_id")
    self.assertEquals(result.get_one(), ("Title 40",))

def test_reference_assign_none(self):
    # Unlinking with None, even twice, keeps the remote object out of the
    # store.
    target = Foo()
    target.title = u"Title 40"
    row = Bar()
    row.id = 400
    row.title = u"Title 400"
    row.foo = target
    for _ in range(2):
        row.foo = None
    self.store.add(row)
    self.store.flush()
    self.assertEquals(type(row.id), int)
    self.assertEquals(target.id, None)

def test_reference_assign_none_with_unseen(self):
    # Assigning None works even if the reference was never loaded.
    row = self.store.get(Bar, 200)
    row.foo = None
    self.assertEquals(row.foo, None)
def test_reference_on_added_composed_key(self):
    # A composed-key reference to an added object copies every key part
    # and fills in the generated id on flush.
    class Bar(object):
        __storm_table__ = "bar"
        id = Int(primary=True)
        foo_id = Int()
        title = Unicode()
        foo = Reference((foo_id, title), (Foo.id, Foo.title))
    store = self.store
    target = Foo()
    target.title = u"Title 40"
    store.add(target)
    row = Bar()
    row.id = 400
    row.foo = target
    store.add(row)
    self.assertEquals(row.foo.id, None)
    self.assertEquals(row.foo.title, "Title 40")
    self.assertEquals(row.title, "Title 40")
    store.flush()
    self.assertTrue(row.foo.id)
    self.assertEquals(row.foo.title, "Title 40")
    result = store.execute("SELECT foo.title FROM foo, bar "
                           "WHERE bar.id=400 AND "
                           "foo.id = bar.foo_id")
    self.assertEquals(result.get_one(), ("Title 40",))

def test_reference_assign_composed_remote_key(self):
    # A tuple assigned to a composed-key reference sets each local column.
    class Bar(object):
        __storm_table__ = "bar"
        id = Int(primary=True)
        foo_id = Int()
        title = Unicode()
        foo = Reference((foo_id, title), (Foo.id, Foo.title))
    row = Bar()
    row.id = 400
    row.foo = (20, u"Title 20")
    self.store.add(row)
    self.assertEquals(row.foo_id, 20)
    self.assertEquals(row.foo.id, 20)
    self.assertEquals(row.title, "Title 20")
    self.assertEquals(row.foo.title, "Title 20")

def test_reference_on_added_unlink_on_flush(self):
    # Key changes on the remote object propagate until the flush; after
    # that the local key no longer follows.
    target = Foo()
    target.title = u"Title 40"
    self.store.add(target)
    row = Bar()
    row.id = 400
    row.foo = target
    row.title = u"Title 400"
    self.store.add(row)
    for key in (40, 50, 60):
        target.id = key
        self.assertEquals(row.foo_id, key)
    self.store.flush()
    target.id = 70
    self.assertEquals(row.foo_id, 60)

def test_reference_on_added_unsets_original_key(self):
    # Linking to an object without a key clears the previous local key.
    target = Foo()
    self.store.add(target)
    row = Bar()
    row.id = 400
    row.foo_id = 40
    row.foo = target
    self.assertEquals(row.foo_id, None)

def test_reference_on_two_added(self):
    # Only the most recently linked object drives the local key.
    first = Foo()
    second = Foo()
    for candidate in (first, second):
        candidate.title = u"Title 40"
    self.store.add(first)
    self.store.add(second)
    row = Bar()
    row.id = 400
    row.title = u"Title 400"
    row.foo = first
    row.foo = second
    self.store.add(row)
    first.id = 40
    self.assertEquals(row.foo_id, None)
    second.id = 50
    self.assertEquals(row.foo_id, 50)

def test_reference_on_added_and_changed_manually(self):
    # Setting the local key by hand detaches it from the remote object.
    target = Foo()
    target.title = u"Title 40"
    self.store.add(target)
    row = Bar()
    row.id = 400
    row.title = u"Title 400"
    row.foo = target
    self.store.add(row)
    row.foo_id = 40
    target.id = 50
    self.assertEquals(row.foo_id, 40)

def test_reference_on_added_composed_key_changed_manually(self):
    # Changing one part of a composed local key by hand breaks the link.
    class Bar(object):
        __storm_table__ = "bar"
        id = Int(primary=True)
        foo_id = Int()
        title = Unicode()
        foo = Reference((foo_id, title), (Foo.id, Foo.title))
    target = Foo()
    target.title = u"Title 40"
    self.store.add(target)
    row = Bar()
    row.id = 400
    row.foo = target
    self.store.add(row)
    row.title = u"Title 50"
    self.assertEquals(row.foo, None)
    target.id = 40
    self.assertEquals(row.foo_id, None)
def test_reference_on_added_no_local_store(self):
    # Linking a store-less object to a stored one pulls it into the store.
    store = self.store
    target = Foo()
    target.title = u"Title 40"
    store.add(target)
    row = Bar()
    row.id = 400
    row.title = u"Title 400"
    row.foo = target
    self.assertEquals(Store.of(row), store)
    self.assertEquals(Store.of(target), store)

def test_reference_on_added_no_remote_store(self):
    # Linking a stored object to a store-less one pulls the latter in.
    store = self.store
    target = Foo()
    target.title = u"Title 40"
    row = Bar()
    row.id = 400
    row.title = u"Title 400"
    store.add(row)
    row.foo = target
    self.assertEquals(Store.of(row), store)
    self.assertEquals(Store.of(target), store)

def test_reference_on_added_no_store(self):
    # Adding the local side of a link adds the remote side as well.
    store = self.store
    target = Foo()
    target.title = u"Title 40"
    row = Bar()
    row.id = 400
    row.title = u"Title 400"
    row.foo = target
    store.add(row)
    self.assertEquals(Store.of(row), store)
    self.assertEquals(Store.of(target), store)
    store.flush()
    self.assertEquals(type(row.foo_id), int)

def test_reference_on_added_no_store_2(self):
    # Adding the remote side of a link adds the local side as well.
    store = self.store
    target = Foo()
    target.title = u"Title 40"
    row = Bar()
    row.id = 400
    row.title = u"Title 400"
    row.foo = target
    store.add(target)
    self.assertEquals(Store.of(row), store)
    self.assertEquals(Store.of(target), store)
    store.flush()
    self.assertEquals(type(row.foo_id), int)

def test_reference_on_added_wrong_store(self):
    # Objects living in different stores cannot be linked.
    other_store = self.create_store()
    target = Foo()
    target.title = u"Title 40"
    other_store.add(target)
    row = Bar()
    row.id = 400
    row.title = u"Title 400"
    self.store.add(row)
    self.assertRaises(WrongStoreError, setattr, row, "foo", target)

def test_reference_on_added_no_store_unlink_before_adding(self):
    # Once unlinked, the two objects can go into different stores.
    target = Foo()
    target.title = u"Title 40"
    row = Bar()
    row.id = 400
    row.title = u"Title 400"
    row.foo = target
    row.foo = None
    self.store.add(row)
    other_store = self.create_store()
    other_store.add(target)
    self.assertEquals(Store.of(row), self.store)
    self.assertEquals(Store.of(target), other_store)

def test_reference_on_removed_wont_add_back(self):
    # Following a reference on a removed object must not add it back.
    store = self.store
    row = store.get(Bar, 200)
    target = store.get(Foo, row.foo_id)
    store.remove(row)
    self.assertEquals(row.foo, target)
    store.flush()
    self.assertEquals(Store.of(row), None)
    self.assertEquals(Store.of(target), store)
def test_reference_equals(self):
    # find() accepts either the referenced object or its key value for a
    # reference.
    target = self.store.get(Foo, 10)
    for criterion in (target, target.id):
        found = self.store.find(Bar, foo=criterion).one()
        self.assertTrue(found)
        self.assertEquals(found.foo, target)

def test_reference_equals_none(self):
    # Comparing a reference with None matches unlinked rows.
    rows = list(self.store.find(SelfRef, selfref=None))
    self.assertEquals(len(rows), 2)
    for row in rows:
        self.assertEquals(row.selfref, None)

def test_reference_equals_with_composed_key(self):
    # An interesting case: a composed-key reference pointing at itself.
    class LinkWithRef(Link):
        myself = Reference((Link.foo_id, Link.bar_id),
                           (Link.foo_id, Link.bar_id))
    link = self.store.find(LinkWithRef, foo_id=10, bar_id=100).one()
    for criterion in (link, (link.foo_id, link.bar_id)):
        myself = self.store.find(LinkWithRef, myself=criterion).one()
        self.assertEquals(link, myself)

def test_reference_equals_with_wrapped(self):
    # A wrapped object works as the comparison value as well.
    target = self.store.get(Foo, 10)
    found = self.store.find(Bar, foo=Wrapper(target)).one()
    self.assertTrue(found)
    self.assertEquals(found.foo, target)

def test_reference_not_equals(self):
    # "!=" on a reference excludes rows linked to the given object.
    target = self.store.get(Foo, 10)
    matches = self.store.find(Bar, Bar.foo != target)
    ids = sorted(row.id for row in matches)
    self.assertEquals([200, 300], ids)

def test_reference_not_equals_none(self):
    # "!= None" builds a Storm expression here, so it must not be rewritten
    # as "is not None".
    row = self.store.find(SelfRef, SelfRef.selfref != None).one()
    self.assertTrue(row)
    self.assertNotEquals(row.selfref, None)

def test_reference_not_equals_with_composed_key(self):
    # "!=" works on composed keys, given an object or a key tuple.
    class LinkWithRef(Link):
        myself = Reference((Link.foo_id, Link.bar_id),
                           (Link.foo_id, Link.bar_id))
    link = self.store.find(LinkWithRef, foo_id=10, bar_id=100).one()
    for criterion in (link, (link.foo_id, link.bar_id)):
        result = list(self.store.find(LinkWithRef,
                                      LinkWithRef.myself != criterion))
        self.assertTrue(link not in result,
                        "%r not in %r" % (link, result))

def test_reference_self(self):
    # A class may hold a reference to its own table.
    obj = self.store.add(SelfRef())
    obj.id = 400
    obj.title = u"Title 400"
    obj.selfref_id = 25
    self.assertEquals(obj.selfref.id, 25)
    self.assertEquals(obj.selfref.title, "SelfRef 25")
def get_bar_200_title(self):
    """Return the title of bar 200, read directly from the connection.

    Going through the raw connection instead of the store avoids flushing
    any pending changes on the store.
    """
    connection = self.store._connection
    result = connection.execute("SELECT title FROM bar WHERE id=200")
    return result.get_one()[0]

def test_reference_wont_touch_store_when_key_is_none(self):
    # Resolving a reference whose key is None must not flush the store.
    bar = self.store.get(Bar, 200)
    bar.foo_id = None
    bar.title = u"Don't flush this!"
    self.assertEquals(bar.foo, None)
    # Bypass the store to prevent flushing.
    self.assertEquals(self.get_bar_200_title(), "Title 200")

def test_reference_wont_touch_store_when_key_is_unset(self):
    # Resolving a reference whose key was deleted must not flush the store.
    bar = self.store.get(Bar, 200)
    del bar.foo_id
    bar.title = u"Don't flush this!"
    self.assertEquals(bar.foo, None)
    # Bypass the store to prevent flushing.
    self.assertEquals(self.get_bar_200_title(), "Title 200")

def test_reference_wont_touch_store_with_composed_key_none(self):
    # Same for a composed key whose parts are all None.
    class Bar(object):
        __storm_table__ = "bar"
        id = Int(primary=True)
        foo_id = Int()
        title = Unicode()
        foo = Reference((foo_id, title), (Foo.id, Foo.title))
    bar = self.store.get(Bar, 200)
    bar.foo_id = None
    bar.title = None
    self.assertEquals(bar.foo, None)
    # Bypass the store to prevent flushing.
    self.assertEquals(self.get_bar_200_title(), "Title 200")

def test_reference_will_resolve_auto_reload(self):
    # A local key set to AutoReload is reloaded before following it.
    bar = self.store.get(Bar, 200)
    bar.foo_id = AutoReload
    self.assertTrue(bar.foo)
- def test_back_reference(self):
- class MyFoo(Foo):
- bar = Reference(Foo.id, Bar.foo_id, on_remote=True)
- foo = self.store.get(MyFoo, 10)
- self.assertTrue(foo.bar)
- self.assertEquals(foo.bar.title, "Title 300")
- def test_back_reference_setting(self):
- class MyFoo(Foo):
- bar = Reference(Foo.id, Bar.foo_id, on_remote=True)
- bar = Bar()
- bar.title = u"Title 400"
- self.store.add(bar)
- foo = MyFoo()
- foo.bar = bar
- foo.title = u"Title 40"
- self.store.add(foo)
- self.store.flush()
- self.assertTrue(foo.id)
- self.assertEquals(bar.foo_id, foo.id)
- result = self.store.execute("SELECT bar.title "
- "FROM foo, bar "
- "WHERE foo.id = bar.foo_id AND "
- "foo.title = 'Title 40'")
- self.assertEquals(result.get_one(), ("Title 400",))
- def test_back_reference_setting_changed_manually(self):
- class MyFoo(Foo):
- bar = Reference(Foo.id, Bar.foo_id, on_remote=True)
- bar = Bar()
- bar.title = u"Title 400"
- self.store.add(bar)
- foo = MyFoo()
- foo.bar = bar
- foo.title = u"Title 40"
- self.store.add(foo)
- foo.id = 40
- self.assertEquals(foo.bar, bar)
- self.store.flush()
- self.assertEquals(foo.id, 40)
- self.assertEquals(bar.foo_id, 40)
- result = self.store.execute("SELECT bar.title "
- "FROM foo, bar "
- "WHERE foo.id = bar.foo_id AND "
- "foo.title = 'Title 40'")
- self.assertEquals(result.get_one(), ("Title 400",))
- def test_back_reference_assign_none_with_unseen(self):
- class MyFoo(Foo):
- bar = Reference(Foo.id, Bar.foo_id, on_remote=True)
- foo = self.store.get(MyFoo, 20)
- foo.bar = None
- self.assertEquals(foo.bar, None)
- def test_back_reference_assign_none_from_none(self):
- class MyFoo(Foo):
- bar = Reference(Foo.id, Bar.foo_id, on_remote=True)
- self.store.execute("INSERT INTO foo (id, title)"
- " VALUES (40, 'Title 40')")
- self.store.commit()
- foo = self.store.get(MyFoo, 40)
- foo.bar = None
- self.assertEquals(foo.bar, None)
- def test_back_reference_on_added_unsets_original_key(self):
- class MyFoo(Foo):
- bar = Reference(Foo.id, Bar.foo_id, on_remote=True)
- foo = MyFoo()
- bar = Bar()
- bar.id = 400
- bar.foo_id = 40
- foo.bar = bar
- self.assertEquals(bar.foo_id, None)
def test_back_reference_on_added_no_store(self):
    """
    Adding only the remote Bar to the store also brings in the MyFoo linked
    to it through an on_remote reference; after flushing, the Bar's foo_id
    is an int.
    """
    class MyFoo(Foo):
        bar = Reference(Foo.id, Bar.foo_id, on_remote=True)

    remote = Bar()
    remote.title = u"Title 400"
    local = MyFoo()
    local.bar = remote
    local.title = u"Title 40"

    self.store.add(remote)
    self.assertEquals(Store.of(remote), self.store)
    self.assertEquals(Store.of(local), self.store)

    self.store.flush()
    self.assertEquals(type(remote.foo_id), int)
def test_back_reference_on_added_no_store_2(self):
    """
    Adding only the local MyFoo to the store also brings in the Bar it
    links to through an on_remote reference; after flushing, the Bar's
    foo_id is an int.
    """
    class MyFoo(Foo):
        bar = Reference(Foo.id, Bar.foo_id, on_remote=True)

    remote = Bar()
    remote.title = u"Title 400"
    local = MyFoo()
    local.bar = remote
    local.title = u"Title 40"

    self.store.add(local)
    self.assertEquals(Store.of(remote), self.store)
    self.assertEquals(Store.of(local), self.store)

    self.store.flush()
    self.assertEquals(type(remote.foo_id), int)
def test_back_reference_remove_remote(self):
    """
    Removing the remote Bar of a flushed on_remote reference makes the
    reference read as None.
    """
    class MyFoo(Foo):
        bar = Reference(Foo.id, Bar.foo_id, on_remote=True)

    remote = Bar()
    remote.title = u"Title 400"
    local = MyFoo()
    local.title = u"Title 40"
    local.bar = remote
    self.store.add(local)
    self.store.flush()

    self.assertEquals(local.bar, remote)
    self.store.remove(remote)
    self.assertEquals(local.bar, None)
def test_back_reference_remove_remote_pending_add(self):
    """
    Removing the remote Bar of an on_remote reference before anything is
    flushed makes the reference read as None.
    """
    class MyFoo(Foo):
        bar = Reference(Foo.id, Bar.foo_id, on_remote=True)

    remote = Bar()
    remote.title = u"Title 400"
    local = MyFoo()
    local.title = u"Title 40"
    local.bar = remote
    self.store.add(local)

    self.assertEquals(local.bar, remote)
    self.store.remove(remote)
    self.assertEquals(local.bar, None)
def test_reference_loop_with_undefined_keys_fails(self):
    """
    Two new objects that reference each other, with no keys assigned yet,
    form a loop that makes flush() raise OrderLoopError.
    """
    first = SelfRef()
    self.store.add(first)
    second = SelfRef()
    second.selfref = first
    first.selfref = second
    self.assertRaises(OrderLoopError, self.store.flush)
def test_reference_loop_with_dirty_keys_fails(self):
    """
    Explicitly assigned but not yet flushed keys do not resolve a reference
    loop between two new objects: flush() raises OrderLoopError.
    """
    first = SelfRef()
    self.store.add(first)
    first.id = 42
    second = SelfRef()
    second.id = 43
    second.selfref = first
    first.selfref = second
    self.assertRaises(OrderLoopError, self.store.flush)
def test_reference_loop_with_dirty_keys_changed_later_fails(self):
    """
    Linking two already-flushed objects in a loop and then changing both of
    their keys makes flush() raise OrderLoopError.
    """
    first, second = SelfRef(), SelfRef()
    self.store.add(first)
    self.store.add(second)
    self.store.flush()

    second.selfref = first
    first.selfref = second
    first.id = 42
    second.id = 43
    self.assertRaises(OrderLoopError, self.store.flush)
def test_reference_loop_with_dirty_keys_on_remote_fails(self):
    """
    Same as the dirty-keys loop case, but using the on_remote flavour of
    the self reference: flush() raises OrderLoopError.
    """
    first = SelfRef()
    self.store.add(first)
    first.id = 42
    second = SelfRef()
    second.id = 43
    second.selfref_on_remote = first
    first.selfref_on_remote = second
    self.assertRaises(OrderLoopError, self.store.flush)
def test_reference_loop_with_dirty_keys_on_remote_changed_later_fails(self):
    """
    Linking a flushed object and a new one in an on_remote loop and then
    changing both keys makes flush() raise OrderLoopError.
    """
    first, second = SelfRef(), SelfRef()
    self.store.add(first)
    self.store.flush()

    second.selfref_on_remote = first
    first.selfref_on_remote = second
    first.id = 42
    second.id = 43
    self.assertRaises(OrderLoopError, self.store.flush)
def test_reference_loop_with_unchanged_keys_succeeds(self):
    """
    A reference loop between two objects that were already flushed, and
    whose keys have not changed since, can be flushed.
    """
    ref1 = SelfRef()
    self.store.add(ref1)
    ref1.id = 42
    ref2 = SelfRef()
    self.store.add(ref2)
    # Give each object its own explicit key.
    ref2.id = 43
    self.store.flush()
    # Both objects have now been flushed to the database with known keys,
    # so linking them in a loop can be flushed as well.
    ref2.selfref = ref1
    ref1.selfref = ref2
    self.store.flush()
def test_reference_loop_with_one_unchanged_key_succeeds(self):
    """
    A reference loop can be flushed when one of its objects was already
    flushed and its key has not changed since.
    """
    ref1 = SelfRef()
    self.store.add(ref1)
    self.store.flush()
    ref2 = SelfRef()
    ref2.selfref = ref1
    ref1.selfref = ref2
    # Only ref1 has been flushed, and ref2 is new. Because ref1's key is
    # already known and unchanged, the loop does not block the flush.
    self.store.flush()
def test_reference_loop_with_key_changed_later_succeeds(self):
    """
    A flushed object may reference a new one whose key is assigned only
    afterwards; flush() still succeeds.
    """
    flushed = SelfRef()
    self.store.add(flushed)
    self.store.flush()

    pending = SelfRef()
    flushed.selfref = pending
    pending.id = 42
    self.store.flush()
def test_reference_loop_with_key_changed_later_on_remote_succeeds(self):
    """
    A new object may point at a flushed one through the on_remote self
    reference and have its key assigned afterwards; flush() still succeeds.
    """
    flushed = SelfRef()
    self.store.add(flushed)
    self.store.flush()

    pending = SelfRef()
    pending.selfref_on_remote = flushed
    pending.id = 42
    self.store.flush()
def test_reference_loop_with_undefined_and_changed_keys_fails(self):
    """
    A loop between a flushed object whose key was changed beforehand and a
    new object with no key makes flush() raise OrderLoopError.
    """
    flushed = SelfRef()
    self.store.add(flushed)
    self.store.flush()
    flushed.id = 400

    fresh = SelfRef()
    fresh.selfref = flushed
    flushed.selfref = fresh
    self.assertRaises(OrderLoopError, self.store.flush)
def test_reference_loop_with_undefined_and_changed_keys_fails2(self):
    """
    Like the previous case, but the flushed object's key is changed after
    the loop is formed; flush() still raises OrderLoopError.
    """
    flushed = SelfRef()
    self.store.add(flushed)
    self.store.flush()

    fresh = SelfRef()
    fresh.selfref = flushed
    flushed.selfref = fresh
    flushed.id = 400
    self.assertRaises(OrderLoopError, self.store.flush)
def test_reference_loop_broken_by_set(self):
    """
    Setting one side of a reference loop back to None before flushing
    breaks the loop, so flush() succeeds.
    """
    first, second = SelfRef(), SelfRef()
    first.selfref = second
    second.selfref = first
    self.store.add(first)

    # Unlinking one direction removes the cycle.
    first.selfref = None
    self.store.flush()
def test_reference_loop_set_only_removes_own_flush_order(self):
    """
    Unsetting a reference drops only the flush ordering that the reference
    itself introduced; loops created explicitly with add_flush_order()
    remain in place.
    """
    first, second = SelfRef(), SelfRef()
    self.store.add(second)
    self.store.flush()

    # Linking the two objects this way is not by itself a flush-order
    # loop (another test covers this case).
    first.selfref = second
    second.selfref = first

    # Layer an explicit flush-order loop on top.
    self.store.add_flush_order(first, second)
    self.store.add_flush_order(second, first)

    # Breaking the reference must leave the explicit ordering loop intact,
    # so flushing still fails.
    first.selfref = None
    self.assertRaises(OrderLoopError, self.store.flush)
def add_reference_set_bar_400(self):
    """Add a Bar with id 400 and title u"Title 100" that points at foo 20."""
    extra = Bar()
    extra.id = 400
    extra.foo_id = 20
    extra.title = u"Title 100"
    self.store.add(extra)
def test_reference_set(self):
    """Iterating a ReferenceSet yields the Bar objects linked to foo 20."""
    self.add_reference_set_bar_400()
    foo = self.store.get(FooRefSet, 20)

    rows = sorted((bar.id, bar.foo_id, bar.title) for bar in foo.bars)
    self.assertEquals(rows, [
        (200, 20, "Title 200"),
        (400, 20, "Title 100"),
    ])
def test_reference_set_assign_fails(self):
    """Assigning directly to a ReferenceSet attribute raises FeatureError."""
    foo = self.store.get(FooRefSet, 20)
    raised = False
    try:
        foo.bars = []
    except FeatureError:
        raised = True
    if not raised:
        self.fail("FeatureError not raised")
def test_reference_set_explicitly_with_wrapper(self):
    """
    Binding the ReferenceSet descriptor by hand against a Wrapper around the
    object yields the same Bars as normal attribute access.
    """
    self.add_reference_set_bar_400()
    foo = self.store.get(FooRefSet, 20)

    bound = FooRefSet.bars.__get__(Wrapper(foo))
    rows = sorted((bar.id, bar.foo_id, bar.title) for bar in bound)
    self.assertEquals(rows, [
        (200, 20, "Title 200"),
        (400, 20, "Title 100"),
    ])
def test_reference_set_with_added(self):
    """
    Bars added to the ReferenceSet of a new FooRefSet get their foo_id
    filled in once the set is queried after the Foo joins the store.
    """
    bar1, bar2 = Bar(), Bar()
    for bar, bar_id, title in ((bar1, 400, u"Title 400"),
                               (bar2, 500, u"Title 500")):
        bar.id = bar_id
        bar.title = title

    foo = FooRefSet()
    foo.title = u"Title 40"
    foo.bars.add(bar1)
    foo.bars.add(bar2)
    self.store.add(foo)

    # Keys stay unset until the reference set is queried below.
    self.assertEquals(foo.id, None)
    self.assertEquals(bar1.foo_id, None)
    self.assertEquals(bar2.foo_id, None)

    self.assertEquals(list(foo.bars.order_by(Bar.id)), [bar1, bar2])
    self.assertEquals(type(foo.id), int)
    self.assertEquals(foo.id, bar1.foo_id)
    self.assertEquals(foo.id, bar2.foo_id)
def test_reference_set_composed(self):
    """
    A ReferenceSet keyed on several columns (id/title against
    foo_id/title) only yields Bars matching all of them.
    """
    self.add_reference_set_bar_400()
    bar400 = self.store.get(Bar, 400)
    bar400.title = u"Title 20"

    class FooRefSetComposed(Foo):
        bars = ReferenceSet((Foo.id, Foo.title),
                            (Bar.foo_id, Bar.title))

    foo = self.store.get(FooRefSetComposed, 20)
    self.assertEquals([(b.id, b.foo_id, b.title) for b in foo.bars],
                      [(400, 20, "Title 20")])

    bar200 = self.store.get(Bar, 200)
    bar200.title = u"Title 20"
    self.assertEquals(sorted((b.id, b.foo_id, b.title) for b in foo.bars), [
        (200, 20, "Title 20"),
        (400, 20, "Title 20"),
    ])
def test_reference_set_contains(self):
    """
    A membership test on a ReferenceSet works even when iterating the
    bound set is made to fail, so `in` does not fall back to iteration.
    """
    from storm.references import BoundReferenceSetBase

    def fail_on_iter(bound_set):
        raise RuntimeError()

    saved_iter = BoundReferenceSetBase.__iter__
    BoundReferenceSetBase.__iter__ = fail_on_iter
    try:
        foo = self.store.get(FooRefSet, 20)
        bar = self.store.get(Bar, 200)
        self.assertEquals(bar in foo.bars, True)
    finally:
        BoundReferenceSetBase.__iter__ = saved_iter
def test_reference_set_find(self):
    """
    ReferenceSet.find() with no arguments yields every linked Bar. Extra
    criteria, given as expressions or keyword arguments, narrow the result
    without escaping the reference.
    """
    self.add_reference_set_bar_400()
    foo = self.store.get(FooRefSet, 20)

    items = sorted((bar.id, bar.foo_id, bar.title)
                   for bar in foo.bars.find())
    self.assertEquals(items, [
        (200, 20, "Title 200"),
        (400, 20, "Title 100"),
    ])

    # Fetch the expected object explicitly. A loop variable's final value
    # would depend on the (unspecified) order the rows were returned in.
    bar400 = self.store.get(Bar, 400)

    # Notice that there's another item with this title in the base,
    # which isn't part of the reference.
    objects = list(foo.bars.find(Bar.title == u"Title 100"))
    self.assertEquals(len(objects), 1)
    self.assertTrue(objects[0] is bar400)

    objects = list(foo.bars.find(title=u"Title 100"))
    self.assertEquals(len(objects), 1)
    self.assertTrue(objects[0] is bar400)
def test_reference_set_clear(self):
    """clear() empties a ReferenceSet without removing the Bars themselves."""
    foo = self.store.get(FooRefSet, 20)
    foo.bars.clear()
    self.assertEquals(list(foo.bars), [])

    # The formerly linked Bar must still be retrievable.
    self.assertTrue(self.store.get(Bar, 200))
def test_reference_set_clear_cached(self):
    """clear() also unsets foo_id on Bars that are already loaded."""
    foo = self.store.get(FooRefSet, 20)
    cached = self.store.get(Bar, 200)
    self.assertEquals(cached.foo_id, 20)

    foo.bars.clear()
    self.assertEquals(cached.foo_id, None)
def test_reference_set_clear_where(self):
    """
    clear() on a ReferenceSet accepts criteria, as an expression or as
    keyword arguments, that restrict which Bars are unlinked.
    """
    self.add_reference_set_bar_400()
    foo = self.store.get(FooRefSet, 20)

    def linked():
        return [(bar.id, bar.foo_id, bar.title) for bar in foo.bars]

    foo.bars.clear(Bar.id > 200)
    self.assertEquals(linked(), [(200, 20, "Title 200")])

    # Re-link bar 400, then unlink only bar 200.
    bar = self.store.get(Bar, 400)
    bar.foo_id = 20
    foo.bars.clear(id=200)
    self.assertEquals(linked(), [(400, 20, "Title 100")])
def test_reference_set_count(self):
    """count() on a ReferenceSet returns the number of linked Bars."""
    self.add_reference_set_bar_400()
    self.assertEquals(self.store.get(FooRefSet, 20).bars.count(), 2)
def test_reference_set_order_by(self):
    """order_by() on a ReferenceSet sorts the linked Bars by the column given."""
    self.add_reference_set_bar_400()
    foo = self.store.get(FooRefSet, 20)

    by_id = [(b.id, b.foo_id, b.title) for b in foo.bars.order_by(Bar.id)]
    self.assertEquals(by_id, [
        (200, 20, "Title 200"),
        (400, 20, "Title 100"),
    ])

    by_title = [(b.id, b.foo_id, b.title)
                for b in foo.bars.order_by(Bar.title)]
    self.assertEquals(by_title, [
        (400, 20, "Title 100"),
        (200, 20, "Title 200"),
    ])
def test_reference_set_default_order_by(self):
    """
    ReferenceSets with a default ordering (FooRefSetOrderID,
    FooRefSetOrderTitle) yield Bars in that order both when iterated
    directly and through find().
    """
    self.add_reference_set_bar_400()

    def rows(bars):
        return [(bar.id, bar.foo_id, bar.title) for bar in bars]

    expected_by_id = [
        (200, 20, "Title 200"),
        (400, 20, "Title 100"),
    ]
    expected_by_title = [
        (400, 20, "Title 100"),
        (200, 20, "Title 200"),
    ]

    foo = self.store.get(FooRefSetOrderID, 20)
    self.assertEquals(rows(foo.bars), expected_by_id)
    self.assertEquals(rows(foo.bars.find()), expected_by_id)

    foo = self.store.get(FooRefSetOrderTitle, 20)
    self.assertEquals(rows(foo.bars), expected_by_title)
    self.assertEquals(rows(foo.bars.find()), expected_by_title)
def test_reference_set_first_last(self):
    """
    first()/last() on an ordered ReferenceSet honour its default order and
    optional criteria; on an unordered one they raise UnorderedError.
    """
    self.add_reference_set_bar_400()

    owner = self.store.get(FooRefSetOrderID, 20)
    self.assertEquals(owner.bars.first().id, 200)
    self.assertEquals(owner.bars.last().id, 400)

    owner = self.store.get(FooRefSetOrderTitle, 20)
    self.assertEquals(owner.bars.first().id, 400)
    self.assertEquals(owner.bars.last().id, 200)

    owner = self.store.get(FooRefSetOrderTitle, 20)
    self.assertEquals(owner.bars.first(Bar.id > 400), None)
    self.assertEquals(owner.bars.last(Bar.id > 400), None)

    owner = self.store.get(FooRefSetOrderTitle, 20)
    self.assertEquals(owner.bars.first(Bar.id < 400).id, 200)
    self.assertEquals(owner.bars.last(Bar.id < 400).id, 200)

    owner = self.store.get(FooRefSetOrderTitle, 20)
    self.assertEquals(owner.bars.first(id=200).id, 200)
    self.assertEquals(owner.bars.last(id=200).id, 200)

    owner = self.store.get(FooRefSet, 20)
    self.assertRaises(UnorderedError, owner.bars.first)
    self.assertRaises(UnorderedError, owner.bars.last)
def test_reference_set_any(self):
    """
    L{BoundReferenceSet.any} returns an arbitrary object from the set of
    referenced objects.
    """
    foo = self.store.get(FooRefSet, 20)
    self.assertNotEqual(None, foo.bars.any())
def test_reference_set_any_filtered(self):
    """
    L{BoundReferenceSet.any} optionally takes a list of filtering criteria
    to narrow the set of objects to search. When provided, the criteria
    are used to filter the set before returning an arbitrary object.
    """
    self.add_reference_set_bar_400()
    foo = self.store.get(FooRefSetOrderTitle, 20)
    self.assertEquals(foo.bars.any(Bar.id > 400), None)
    foo = self.store.get(FooRefSetOrderTitle, 20)
    self.assertEquals(foo.bars.any(Bar.id < 400).id, 200)
    foo = self.store.get(FooRefSetOrderTitle, 20)
    self.assertEquals(foo.bars.any(id=200).id, 200)
def test_reference_set_one(self):
    """
    one() on a ReferenceSet raises NotOneError for several matches, and
    otherwise returns the single match or None.
    """
    self.add_reference_set_bar_400()

    owner = self.store.get(FooRefSetOrderID, 20)
    self.assertRaises(NotOneError, owner.bars.one)

    owner = self.store.get(FooRefSetOrderID, 30)
    self.assertEquals(owner.bars.one().id, 300)

    owner = self.store.get(FooRefSetOrderID, 20)
    self.assertEquals(owner.bars.one(Bar.id > 400), None)

    owner = self.store.get(FooRefSetOrderID, 20)
    self.assertEquals(owner.bars.one(Bar.id < 400).id, 200)

    owner = self.store.get(FooRefSetOrderID, 20)
    self.assertEquals(owner.bars.one(id=200).id, 200)
def test_reference_set_remove(self):
    """Removing each Bar from a ReferenceSet unsets its foo_id and empties the set."""
    self.add_reference_set_bar_400()
    foo = self.store.get(FooRefSet, 20)

    # Snapshot the members first. Removing while iterating the live set
    # mutates the collection being traversed, and an empty result would
    # let the per-item assertions pass vacuously.
    bars = list(foo.bars)
    self.assertEquals(len(bars), 2)
    for bar in bars:
        foo.bars.remove(bar)
        self.assertEquals(bar.foo_id, None)
    self.assertEquals(list(foo.bars), [])
def test_reference_set_after_object_removed(self):
    """
    A Bar removed from the store no longer shows up in a ReferenceSet,
    even when its foo_id column does not allow None.
    """
    class MyBar(Bar):
        # Make sure that this works even with allow_none=False.
        foo_id = Int(allow_none=False)

    class MyFoo(Foo):
        bars = ReferenceSet(Foo.id, MyBar.foo_id)

    foo = self.store.get(MyFoo, 20)
    victim = foo.bars.any()
    self.store.remove(victim)
    self.assertFalse(victim in list(foo.bars))
def test_reference_set_add(self):
    """
    Adding a new Bar to a stored Foo's ReferenceSet sets its foo_id and
    places it in the same store.
    """
    newcomer = Bar()
    newcomer.id = 400
    newcomer.title = u"Title 100"

    foo = self.store.get(FooRefSet, 20)
    foo.bars.add(newcomer)
    self.assertEquals(newcomer.foo_id, 20)
    self.assertEquals(Store.of(newcomer), self.store)
def test_reference_set_add_no_store(self):
    """
    Adding the Foo to the store also brings in Bars already placed in its
    ReferenceSet; after flushing, their foo_id is an int.
    """
    member = Bar()
    member.id = 400
    member.title = u"Title 400"
    owner = FooRefSet()
    owner.title = u"Title 40"
    owner.bars.add(member)

    self.store.add(owner)
    self.assertEquals(Store.of(owner), self.store)
    self.assertEquals(Store.of(member), self.store)

    self.store.flush()
    self.assertEquals(type(member.foo_id), int)
def test_reference_set_add_no_store_2(self):
    """
    Adding a member Bar to the store also brings in the Foo whose
    ReferenceSet contains it; after flushing, the Bar's foo_id is an int.
    """
    member = Bar()
    member.id = 400
    member.title = u"Title 400"
    owner = FooRefSet()
    owner.title = u"Title 40"
    owner.bars.add(member)

    self.store.add(member)
    self.assertEquals(Store.of(owner), self.store)
    self.assertEquals(Store.of(member), self.store)

    self.store.flush()
    self.assertEquals(type(member.foo_id), int)
def test_reference_set_add_no_store_unlink_after_adding(self):
    """
    A Bar removed from a new Foo's ReferenceSet before the Foo joins a store
    is not pulled into that store, and can be added to a different one.
    """
    bar1, bar2 = Bar(), Bar()
    for bar, bar_id, title in ((bar1, 400, u"Title 400"),
                               (bar2, 500, u"Title 500")):
        bar.id = bar_id
        bar.title = title

    foo = FooRefSet()
    foo.title = u"Title 40"
    foo.bars.add(bar1)
    foo.bars.add(bar2)
    foo.bars.remove(bar1)
    self.store.add(foo)

    other_store = self.create_store()
    other_store.add(bar1)
    self.assertEquals(Store.of(foo), self.store)
    self.assertEquals(Store.of(bar1), other_store)
    self.assertEquals(Store.of(bar2), self.store)
def test_reference_set_values(self):
    """values() on an ordered ReferenceSet yields column tuples in order."""
    self.add_reference_set_bar_400()
    owner = self.store.get(FooRefSetOrderID, 20)

    self.assertEquals(list(owner.bars.values(Bar.id, Bar.foo_id, Bar.title)),
                      [(200, 20, "Title 200"), (400, 20, "Title 100")])
def test_indirect_reference_set(self):
    """Iterating an indirect ReferenceSet yields the Bars linked to foo 20."""
    owner = self.store.get(FooIndRefSet, 20)
    rows = sorted((bar.id, bar.title) for bar in owner.bars)
    self.assertEquals(rows, [
        (100, "Title 300"),
        (200, "Title 200"),
    ])
def test_indirect_reference_set_with_added(self):
    """
    Stored Bars added to the indirect ReferenceSet of a new Foo appear in
    it once the Foo joins the store, and all keys end up as ints.
    """
    bar1, bar2 = Bar(), Bar()
    for bar, bar_id, title in ((bar1, 400, u"Title 400"),
                               (bar2, 500, u"Title 500")):
        bar.id = bar_id
        bar.title = title
    self.store.add(bar1)
    self.store.add(bar2)

    foo = FooIndRefSet()
    foo.title = u"Title 40"
    foo.bars.add(bar1)
    foo.bars.add(bar2)
    self.assertEquals(foo.id, None)

    self.store.add(foo)
    self.assertEquals(foo.id, None)
    self.assertEquals(bar1.foo_id, None)
    self.assertEquals(bar2.foo_id, None)

    self.assertEquals(list(foo.bars.order_by(Bar.id)), [bar1, bar2])
    self.assertEquals(type(foo.id), int)
    self.assertEquals(type(bar1.id), int)
    self.assertEquals(type(bar2.id), int)
def test_indirect_reference_set_find(self):
    """find() on an indirect ReferenceSet filters the linked Bars."""
    owner = self.store.get(FooIndRefSet, 20)
    matches = owner.bars.find(Bar.title == u"Title 300")
    self.assertEquals(sorted((bar.id, bar.title) for bar in matches),
                      [(100, "Title 300")])
def test_indirect_reference_set_clear(self):
    """clear() empties an indirect ReferenceSet."""
    owner = self.store.get(FooIndRefSet, 20)
    owner.bars.clear()
    self.assertEquals(list(owner.bars), [])
def test_indirect_reference_set_clear_where(self):
    """
    clear() on an indirect ReferenceSet accepts criteria, as an expression
    or as keyword arguments, that restrict which Bars are dropped from it.
    """
    foo = self.store.get(FooIndRefSet, 20)
    # FooIndRefSet has no default ordering (first()/last() raise
    # UnorderedError on it), so sort before comparing several rows.
    items = sorted((bar.id, bar.foo_id, bar.title) for bar in foo.bars)
    self.assertEquals(items, [
        (100, 10, "Title 300"),
        (200, 20, "Title 200"),
    ])

    # Clear entries from foo 30's set. The assertions below show foo 20's
    # set is unaffected by this.
    foo = self.store.get(FooIndRefSet, 30)
    foo.bars.clear(Bar.id < 300)
    foo.bars.clear(id=200)

    foo = self.store.get(FooIndRefSet, 20)
    foo.bars.clear(Bar.id < 200)
    items = [(bar.id, bar.foo_id, bar.title) for bar in foo.bars]
    self.assertEquals(items, [
        (200, 20, "Title 200"),
    ])

    foo.bars.clear(id=200)
    items = [(bar.id, bar.foo_id, bar.title) for bar in foo.bars]
    self.assertEquals(items, [])
def test_indirect_reference_set_count(self):
    """count() on an indirect ReferenceSet returns the number of linked Bars."""
    self.assertEquals(self.store.get(FooIndRefSet, 20).bars.count(), 2)
def test_indirect_reference_set_order_by(self):
    """order_by() on an indirect ReferenceSet sorts by the column given."""
    owner = self.store.get(FooIndRefSet, 20)

    by_title = [(b.id, b.title) for b in owner.bars.order_by(Bar.title)]
    self.assertEquals(by_title, [
        (200, "Title 200"),
        (100, "Title 300"),
    ])

    by_id = [(b.id, b.title) for b in owner.bars.order_by(Bar.id)]
    self.assertEquals(by_id, [
        (100, "Title 300"),
        (200, "Title 200"),
    ])
def test_indirect_reference_set_default_order_by(self):
    """
    Indirect ReferenceSets with a default ordering yield Bars in that order
    both when iterated directly and through find().
    """
    def rows(bars):
        return [(bar.id, bar.title) for bar in bars]

    expected_by_title = [
        (200, "Title 200"),
        (100, "Title 300"),
    ]
    expected_by_id = [
        (100, "Title 300"),
        (200, "Title 200"),
    ]

    owner = self.store.get(FooIndRefSetOrderTitle, 20)
    self.assertEquals(rows(owner.bars), expected_by_title)
    self.assertEquals(rows(owner.bars.find()), expected_by_title)

    owner = self.store.get(FooIndRefSetOrderID, 20)
    self.assertEquals(rows(owner.bars), expected_by_id)
    self.assertEquals(rows(owner.bars.find()), expected_by_id)
def test_indirect_reference_set_first_last(self):
    """
    first()/last() on an ordered indirect ReferenceSet honour its default
    order and optional criteria; on an unordered one they raise
    UnorderedError.
    """
    owner = self.store.get(FooIndRefSetOrderID, 20)
    self.assertEquals(owner.bars.first().id, 100)
    self.assertEquals(owner.bars.last().id, 200)

    owner = self.store.get(FooIndRefSetOrderTitle, 20)
    self.assertEquals(owner.bars.first().id, 200)
    self.assertEquals(owner.bars.last().id, 100)

    owner = self.store.get(FooIndRefSetOrderTitle, 20)
    self.assertEquals(owner.bars.first(Bar.id > 200), None)
    self.assertEquals(owner.bars.last(Bar.id > 200), None)

    owner = self.store.get(FooIndRefSetOrderTitle, 20)
    self.assertEquals(owner.bars.first(Bar.id < 200).id, 100)
    self.assertEquals(owner.bars.last(Bar.id < 200).id, 100)

    owner = self.store.get(FooIndRefSetOrderTitle, 20)
    self.assertEquals(owner.bars.first(id=200).id, 200)
    self.assertEquals(owner.bars.last(id=200).id, 200)

    owner = self.store.get(FooIndRefSet, 20)
    self.assertRaises(UnorderedError, owner.bars.first)
    self.assertRaises(UnorderedError, owner.bars.last)
def test_indirect_reference_set_any(self):
    """
    L{BoundIndirectReferenceSet.any} returns an arbitrary object from the
    set of referenced objects.
    """
    picked = self.store.get(FooIndRefSet, 20).bars.any()
    self.assertNotEqual(None, picked)
def test_indirect_reference_set_any_filtered(self):
    """
    L{BoundIndirectReferenceSet.any} optionally takes filtering criteria
    that narrow the set before an arbitrary object is picked from it.
    """
    owner = self.store.get(FooIndRefSetOrderTitle, 20)
    self.assertEquals(owner.bars.any(Bar.id > 200), None)

    owner = self.store.get(FooIndRefSetOrderTitle, 20)
    self.assertEquals(owner.bars.any(Bar.id < 200).id, 100)

    owner = self.store.get(FooIndRefSetOrderTitle, 20)
    self.assertEquals(owner.bars.any(id=200).id, 200)
def test_indirect_reference_set_one(self):
    """
    one() on an indirect ReferenceSet raises NotOneError for several
    matches, and otherwise returns the single match or None.
    """
    owner = self.store.get(FooIndRefSetOrderID, 20)
    self.assertRaises(NotOneError, owner.bars.one)

    owner = self.store.get(FooIndRefSetOrderID, 30)
    self.assertEquals(owner.bars.one().id, 300)

    owner = self.store.get(FooIndRefSetOrderID, 20)
    self.assertEquals(owner.bars.one(Bar.id > 200), None)

    owner = self.store.get(FooIndRefSetOrderID, 20)
    self.assertEquals(owner.bars.one(Bar.id < 200).id, 100)

    owner = self.store.get(FooIndRefSetOrderID, 20)
    self.assertEquals(owner.bars.one(id=200).id, 200)
def test_indirect_reference_set_add(self):
    """add() links a stored Bar into an indirect ReferenceSet."""
    owner = self.store.get(FooIndRefSet, 20)
    extra = self.store.get(Bar, 300)
    owner.bars.add(extra)

    rows = sorted((bar.id, bar.title) for bar in owner.bars)
    self.assertEquals(rows, [
        (100, "Title 300"),
        (200, "Title 200"),
        (300, "Title 100"),
    ])
def test_indirect_reference_set_remove(self):
    """remove() unlinks a Bar from an indirect ReferenceSet."""
    owner = self.store.get(FooIndRefSet, 20)
    unwanted = self.store.get(Bar, 200)
    owner.bars.remove(unwanted)

    rows = sorted((bar.id, bar.title) for bar in owner.bars)
    self.assertEquals(rows, [
        (100, "Title 300"),
    ])
def test_indirect_reference_set_add_remove(self):
    """
    Adding and then removing the same Bar leaves an indirect ReferenceSet
    unchanged.
    """
    owner = self.store.get(FooIndRefSet, 20)
    transient = self.store.get(Bar, 300)
    owner.bars.add(transient)
    owner.bars.remove(transient)

    rows = sorted((bar.id, bar.title) for bar in owner.bars)
    self.assertEquals(rows, [
        (100, "Title 300"),
        (200, "Title 200"),
    ])
def test_indirect_reference_set_add_remove_with_wrapper(self):
    """add() and remove() on an indirect ReferenceSet accept Wrapper objects."""
    owner = self.store.get(FooIndRefSet, 20)
    bar300 = self.store.get(Bar, 300)
    bar200 = self.store.get(Bar, 200)
    owner.bars.add(Wrapper(bar300))
    owner.bars.remove(Wrapper(bar200))

    rows = sorted((bar.id, bar.title) for bar in owner.bars)
    self.assertEquals(rows, [
        (100, "Title 300"),
        (300, "Title 100"),
    ])
def test_indirect_reference_set_add_remove_with_added(self):
    """
    add()/remove() work on an indirect ReferenceSet whose owner and Bars
    were all added to the store but not yet flushed.
    """
    foo = FooIndRefSet()
    foo.id = 40
    bar1, bar2 = Bar(), Bar()
    for bar, bar_id, title in ((bar1, 400, u"Title 400"),
                               (bar2, 500, u"Title 500")):
        bar.id = bar_id
        bar.title = title

    self.store.add(foo)
    self.store.add(bar1)
    self.store.add(bar2)

    foo.bars.add(bar1)
    foo.bars.add(bar2)
    foo.bars.remove(bar1)

    rows = sorted((bar.id, bar.title) for bar in foo.bars)
    self.assertEquals(rows, [
        (500, "Title 500"),
    ])
def test_indirect_reference_set_with_added_no_store(self):
    """
    Adding one linked Bar to the store brings in the owning Foo and the
    other linked Bar through an indirect ReferenceSet.
    """
    bar1, bar2 = Bar(), Bar()
    for bar, bar_id, title in ((bar1, 400, u"Title 400"),
                               (bar2, 500, u"Title 500")):
        bar.id = bar_id
        bar.title = title

    foo = FooIndRefSet()
    foo.title = u"Title 40"
    foo.bars.add(bar1)
    foo.bars.add(bar2)

    self.store.add(bar1)
    self.assertEquals(Store.of(foo), self.store)
    self.assertEquals(Store.of(bar1), self.store)
    self.assertEquals(Store.of(bar2), self.store)

    self.assertEquals(foo.id, None)
    self.assertEquals(bar1.foo_id, None)
    self.assertEquals(bar2.foo_id, None)

    self.assertEquals(list(foo.bars.order_by(Bar.id)), [bar1, bar2])
def test_indirect_reference_set_values(self):
    """values() on an ordered indirect ReferenceSet yields column tuples in order."""
    owner = self.store.get(FooIndRefSetOrderID, 20)

    self.assertEquals(list(owner.bars.values(Bar.id, Bar.foo_id, Bar.title)),
                      [(100, 10, "Title 300"), (200, 20, "Title 200")])
def test_references_raise_nostore(self):
    """
    Query-style operations on the reference sets of objects not in any
    store raise NoStoreError.
    """
    direct = FooRefSet()
    indirect = FooIndRefSet()

    for method_name in ("__iter__", "find", "order_by", "count", "clear"):
        for owner in (direct, indirect):
            self.assertRaises(NoStoreError,
                              getattr(owner.bars, method_name))

    self.assertRaises(NoStoreError, indirect.bars.remove, object())
def test_string_reference(self):
    """
    A Reference whose columns are given as strings is resolved lazily, so
    it may name a class defined after the referencing one.
    """
    class Base(object):
        __metaclass__ = PropertyPublisherMeta

    class MyBar(Base):
        __storm_table__ = "bar"
        id = Int(primary=True)
        title = Unicode()
        foo_id = Int()
        foo = Reference("foo_id", "MyFoo.id")

    class MyFoo(Base):
        __storm_table__ = "foo"
        id = Int(primary=True)
        title = Unicode()

    referenced = self.store.get(MyBar, 100).foo
    self.assertTrue(referenced)
    self.assertEquals(referenced.title, "Title 30")
    self.assertEquals(type(referenced), MyFoo)
def test_string_indirect_reference_set(self):
    """
    A L{ReferenceSet} can have its reference keys specified as strings
    when the class it is a member of uses the L{PropertyPublisherMeta}
    metaclass. Resolving the properties lazily makes it possible to work
    around circular dependencies between classes.
    """
    class Base(object):
        __metaclass__ = PropertyPublisherMeta

    class MyFoo(Base):
        __storm_table__ = "foo"
        id = Int(primary=True)
        title = Unicode()
        bars = ReferenceSet("id", "MyLink.foo_id",
                            "MyLink.bar_id", "MyBar.id")

    class MyBar(Base):
        __storm_table__ = "bar"
        id = Int(primary=True)
        title = Unicode()

    class MyLink(Base):
        __storm_table__ = "link"
        __storm_primary__ = "foo_id", "bar_id"
        foo_id = Int()
        bar_id = Int()

    owner = self.store.get(MyFoo, 20)
    rows = sorted((bar.id, bar.title) for bar in owner.bars)
    self.assertEquals(rows, [
        (100, "Title 300"),
        (200, "Title 200"),
    ])
def test_string_reference_set_order_by(self):
    """
    A L{ReferenceSet} can have its default order_by specified as a string
    when the class it is a member of uses the L{PropertyPublisherMeta}
    metaclass. Resolving the ordering column lazily makes it possible to
    work around circular dependencies between classes.
    """
    class Base(object):
        __metaclass__ = PropertyPublisherMeta

    class MyFoo(Base):
        __storm_table__ = "foo"
        id = Int(primary=True)
        title = Unicode()
        bars = ReferenceSet("id", "MyLink.foo_id",
                            "MyLink.bar_id", "MyBar.id",
                            order_by="MyBar.title")

    class MyBar(Base):
        __storm_table__ = "bar"
        id = Int(primary=True)
        title = Unicode()

    class MyLink(Base):
        __storm_table__ = "link"
        __storm_primary__ = "foo_id", "bar_id"
        foo_id = Int()
        bar_id = Int()

    owner = self.store.get(MyFoo, 20)
    self.assertEquals([(bar.id, bar.title) for bar in owner.bars],
                      [(200, "Title 200"), (100, "Title 300")])
def test_flush_order(self):
    """
    add_flush_order() constraints control insertion order. A cycle of them
    makes flush() raise OrderLoopError, and a constraint added twice must
    be removed twice before it stops counting.
    """
    objects = []
    for number in range(1, 6):
        obj = Foo()
        obj.title = u"Object %d" % number
        self.store.add(obj)
        objects.append(obj)
    foo1, foo2, foo3, foo4, foo5 = objects

    # Chain foo2 -> foo4 -> foo1 -> foo3 -> foo5, then close the cycle back
    # to foo2 twice.
    self.store.add_flush_order(foo2, foo4)
    self.store.add_flush_order(foo4, foo1)
    self.store.add_flush_order(foo1, foo3)
    self.store.add_flush_order(foo3, foo5)
    self.store.add_flush_order(foo5, foo2)
    self.store.add_flush_order(foo5, foo2)
    self.assertRaises(OrderLoopError, self.store.flush)

    # One removal still leaves the duplicated constraint in place.
    self.store.remove_flush_order(foo5, foo2)
    self.assertRaises(OrderLoopError, self.store.flush)

    self.store.remove_flush_order(foo5, foo2)
    self.store.flush()

    # Ids are assigned in flush order.
    self.assertTrue(foo2.id < foo4.id)
    self.assertTrue(foo4.id < foo1.id)
    self.assertTrue(foo1.id < foo3.id)
    self.assertTrue(foo3.id < foo5.id)
def test_variable_filter_on_load(self):
    """Loading a FooVariable runs the title through its load-side filters."""
    loaded = self.store.get(FooVariable, 20)
    self.assertEquals(loaded.title, "to_py(from_db(Title 20))")
def test_variable_filter_on_update(self):
    """Updating a FooVariable runs the new title through its save-side filters."""
    loaded = self.store.get(FooVariable, 20)
    loaded.title = u"Title 20"
    self.store.flush()

    expected = [
        (10, "Title 30"),
        (20, "to_db(from_py(Title 20))"),
        (30, "Title 10"),
    ]
    self.assertEquals(self.get_items(), expected)
def test_variable_filter_on_update_unchanged(self):
    """Flushing an untouched FooVariable does not write its filtered title back."""
    # Keep a reference so the loaded object stays alive across the flush.
    loaded = self.store.get(FooVariable, 20)
    self.store.flush()

    expected = [
        (10, "Title 30"),
        (20, "Title 20"),
        (30, "Title 10"),
    ]
    self.assertEquals(self.get_items(), expected)
def test_variable_filter_on_insert(self):
    """Inserting a FooVariable runs its title through the save-side filters."""
    fresh = FooVariable()
    fresh.id = 40
    fresh.title = u"Title 40"
    self.store.add(fresh)
    self.store.flush()

    expected = [
        (10, "Title 30"),
        (20, "Title 20"),
        (30, "Title 10"),
        (40, "to_db(from_py(Title 40))"),
    ]
    self.assertEquals(self.get_items(), expected)
def test_variable_filter_on_missing_values(self):
    """
    A title never set on a new FooVariable is read back from the database
    through the load-side filters after flushing.
    """
    fresh = FooVariable()
    fresh.id = 40
    self.store.add(fresh)
    self.store.flush()
    self.assertEquals(fresh.title, "to_py(from_db(Default Title))")
def test_variable_filter_on_set(self):
    """
    ResultSet.set() with keyword arguments runs the new value through the
    variable's save-side filters.
    """
    self.store.find(FooVariable, id=20).set(title=u"Title 20")
    self.assertEquals(self.get_items(), [
        (10, "Title 30"),
        (20, "to_db(from_py(Title 20))"),
        (30, "Title 10"),
    ])
def test_variable_filter_on_set_expr(self):
    """
    ResultSet.set() with an equality expression runs the new value through
    the variable's save-side filters.
    """
    result = self.store.find(FooVariable, id=20)
    result.set(FooVariable.title == u"Title 20")
    self.assertEquals(self.get_items(), [
        (10, "Title 30"),
        (20, "to_db(from_py(Title 20))"),
        (30, "Title 10"),
    ])
def test_wb_result_set_variable(self):
    """
    Objects loaded by the store get their values through the connection's
    result_factory.set_variable() hook (white-box test).
    """
    original_factory = self.store._connection.result_factory

    class MyResult(original_factory):
        def set_variable(self, variable, value):
            var_class = variable.__class__
            if var_class is UnicodeVariable:
                value = u"set_variable(%s)" % value
            elif var_class is IntVariable:
                value = value + 1
            variable.set(value)

    self.store._connection.result_factory = MyResult
    try:
        loaded = self.store.get(Foo, 20)
    finally:
        self.store._connection.result_factory = original_factory

    self.assertEquals(loaded.id, 21)
    self.assertEquals(loaded.title, "set_variable(Title 20)")
def test_default(self):
    """A property's default= value is inserted and reflected on the object."""
    class MyFoo(Foo):
        title = Unicode(default=u"Some default value")

    fresh = MyFoo()
    self.store.add(fresh)
    self.store.flush()

    row = self.store.execute("SELECT title FROM foo WHERE id=?",
                             (fresh.id,)).get_one()
    self.assertEquals(row, ("Some default value",))
    self.assertEquals(fresh.title, "Some default value")
def test_default_factory(self):
    """A property's default_factory= result is inserted and reflected on the object."""
    def make_default():
        return u"Some default value"

    class MyFoo(Foo):
        title = Unicode(default_factory=make_default)

    fresh = MyFoo()
    self.store.add(fresh)
    self.store.flush()

    row = self.store.execute("SELECT title FROM foo WHERE id=?",
                             (fresh.id,)).get_one()
    self.assertEquals(row, ("Some default value",))
    self.assertEquals(fresh.title, "Some default value")
def test_pickle_variable(self):
    """
    A Pickle property unpickles the stored bytes, and in-place mutation of
    the unpickled value is written back on flush.
    """
    class PickleBlob(Blob):
        bin = Pickle()

    raw_blob = self.store.get(Blob, 20)
    raw_blob.bin = "\x80\x02}q\x01U\x01aK\x01s."
    self.store.flush()

    pickled = self.store.get(PickleBlob, 20)
    self.assertEquals(pickled.bin["a"], 1)
    pickled.bin["b"] = 2
    self.store.flush()

    self.store.reload(raw_blob)
    self.assertEquals(raw_blob.bin,
                      "\x80\x02}q\x01(U\x01aK\x01U\x01bK\x02u.")
def test_pickle_variable_remove(self):
    """
    When an object is removed from a store, it should unhook from the
    "flush" event emitted by the store, so it does not emit a "changed"
    event when its content changes and the store is flushed.
    """
    class PickleBlob(Blob):
        bin = Pickle()

    raw_blob = self.store.get(Blob, 20)
    raw_blob.bin = "\x80\x02}q\x01U\x01aK\x01s."
    self.store.flush()

    pickled = self.store.get(PickleBlob, 20)
    self.store.remove(pickled)
    self.store.flush()

    # Modify the removed object, then watch for "changed" events.
    pickled.bin = "foobin"
    seen = []
    get_obj_info(pickled).event.hook("changed",
                                     lambda *args: seen.append(args))

    self.store.flush()
    self.assertEquals(seen, [])
def test_pickle_variable_unhook(self):
    """
    A variable instance must unhook itself from the store event system when
    the store invalidates its objects.
    """
    # Use a PickleVariable subclass without __slots__ so that it can be
    # weakly referenced. Per the original note, PickleVariable itself
    # defines __slots__ and derives from the C implementation of Variable,
    # which rules that out.
    class CustomPickleVariable(PickleVariable):
        pass

    class CustomPickle(Pickle):
        variable_class = CustomPickleVariable

    class PickleBlob(Blob):
        bin = CustomPickle()

    raw_blob = self.store.get(Blob, 20)
    raw_blob.bin = "\x80\x02}q\x01U\x01aK\x01s."
    self.store.flush()

    pickled = self.store.get(PickleBlob, 20)
    self.store.flush()
    self.store.invalidate()

    info = get_obj_info(pickled)
    variable = info.variables[PickleBlob.bin]
    var_ref = weakref.ref(variable)

    # Drop every local strong reference; only the store's hooks could
    # still keep the variable alive.
    del variable, raw_blob, pickled, info
    gc.collect()
    self.assertTrue(var_ref() is None)
def test_pickle_variable_referenceset(self):
    """
    A variable instance must unhook itself from the store event system
    explicitly when the store invalidates its objects: it's particularly
    important when a ReferenceSet is used, because it keeps strong
    references to objects involved.
    """
    # Subclass without __slots__ so the variable can be weak-referenced.
    class CustomPickleVariable(PickleVariable):
        pass
    class CustomPickle(Pickle):
        variable_class = CustomPickleVariable
    class PickleBlob(Blob):
        bin = CustomPickle()
        foo_id = Int()
    class FooBlobRefSet(Foo):
        blobs = ReferenceSet(Foo.id, PickleBlob.foo_id)
    blob = self.store.get(Blob, 20)
    blob.bin = "\x80\x02}q\x01U\x01aK\x01s."
    self.store.flush()
    pickle_blob = self.store.get(PickleBlob, 20)
    foo = self.store.get(FooBlobRefSet, 10)
    foo.blobs.add(pickle_blob)
    self.store.flush()
    self.store.invalidate()
    obj_info = get_obj_info(pickle_blob)
    variable = obj_info.variables[PickleBlob.bin]
    var_ref = weakref.ref(variable)
    # Drop every local strong reference, including the ReferenceSet owner;
    # the variable must then be collectable.
    del variable, blob, pickle_blob, obj_info, foo
    gc.collect()
    self.assertTrue(var_ref() is None)
def test_pickle_variable_referenceset_several_transactions(self):
    """
    A pickle variable still fires its "changed" event after the store
    has been invalidated and the object reloaded, i.e. across several
    transactions.
    """
    class PickleBlob(Blob):
        bin = Pickle()
        foo_id = Int()

    class FooBlobRefSet(Foo):
        blobs = ReferenceSet(Foo.id, PickleBlob.foo_id)

    raw = self.store.get(Blob, 20)
    raw.bin = "\x80\x02}q\x01U\x01aK\x01s."
    self.store.flush()

    pickled = self.store.get(PickleBlob, 20)
    owner = self.store.get(FooBlobRefSet, 10)
    owner.blobs.add(pickled)
    self.store.flush()

    # Start again from an invalidated state for the same object.
    self.store.invalidate()
    self.store.reload(pickled)
    pickled.bin = "foo"

    info = get_obj_info(pickled)
    changes = []

    def on_changed(*args):
        changes.append(args)

    info.event.hook("changed", on_changed)
    self.store.flush()
    self.assertEquals(len(changes), 1)
def test_undefined_variables_filled_on_find(self):
    """
    Check that when data is fetched from the database on a find,
    it is used to fill up any undefined variables.
    """
    # Keep the first batch of objects referenced so their object infos
    # stay in memory and are reused by the second find.
    kept_alive = list(self.store.find(Foo, title=u"Title 20"))
    # Committing invalidates them, resetting variables to AutoReload.
    self.store.commit()
    for reused in self.store.find(Foo, title=u"Title 20"):
        # The find's own SELECT already carried every column value.
        variables = get_obj_info(reused).variables
        for variable in variables.values():
            self.assertTrue(variable.is_defined())
def test_storm_loaded_after_define(self):
    """
    C{__storm_loaded__} is only called once all the variables are correctly
    defined in the object. If the object is in the alive cache but
    disappeared, it used to be called without its variables defined.
    """
    # Disable the cache, which holds strong references.
    self.get_cache(self.store).set_size(0)
    loaded = []
    class MyFoo(Foo):
        def __storm_loaded__(oself):
            # `oself` is the loaded object; `self` is the test case,
            # captured from the enclosing scope for its assertions.
            loaded.append(None)
            obj_info = get_obj_info(oself)
            for column in obj_info.variables:
                self.assertTrue(obj_info.variables[column].is_defined())
    foo = self.store.get(MyFoo, 20)
    obj_info = get_obj_info(foo)
    # Drop the only strong reference so the object itself is collected
    # while its obj_info is kept around.
    del foo
    gc.collect()
    self.assertEquals(obj_info.get_obj(), None)
    # Commit so that all foos are invalidated and variables are
    # set back to AutoReload.
    self.store.commit()
    foo = self.store.find(MyFoo, title=u"Title 20").one()
    self.assertEquals(foo.id, 20)
    # One call for the initial get() and one for the object rebuilt by
    # find(); both ran the is_defined() checks above.
    self.assertEquals(len(loaded), 2)
def test_defined_variables_not_overridden_on_find(self):
    """
    The keep_defined=True setting in _load_object() must be in place:
    values already defined in memory are not replaced when a find brings
    in fresh data from the database.
    """
    obj = self.store.get(Blob, 20)
    obj.bin = "\x80\x02}q\x01U\x01aK\x01s."

    class PickleBlob(object):
        __storm_table__ = "bin"
        id = Int(primary=True)
        pickle = Pickle("bin")

    obj = self.store.get(PickleBlob, 20)
    original = obj.pickle
    # The find must keep the very same in-memory value object.
    obj = self.store.find(PickleBlob, id=20).one()
    self.assertTrue(obj.pickle is original)
def test_pickle_variable_with_deleted_object(self):
    """
    An in-place change to a Pickle value is still flushed after the last
    local reference to the owning object has been deleted and a garbage
    collection has run.
    """
    class PickleBlob(Blob):
        bin = Pickle()
    blob = self.store.get(Blob, 20)
    blob.bin = "\x80\x02}q\x01U\x01aK\x01s."
    self.store.flush()
    pickle_blob = self.store.get(PickleBlob, 20)
    self.assertEquals(pickle_blob.bin["a"], 1)
    # Mutate the unpickled dict in place, then drop our reference.
    pickle_blob.bin["b"] = 2
    del pickle_blob
    gc.collect()
    self.store.flush()
    # Read the raw column back through the plain Blob object.
    self.store.reload(blob)
    self.assertEquals(blob.bin, "\x80\x02}q\x01(U\x01aK\x01U\x01bK\x02u.")
def test_unhashable_object(self):
    """Store works with classes that are unhashable (dict subclasses)."""
    class DictFoo(Foo, dict):
        pass

    existing = self.store.get(DictFoo, 20)
    existing["a"] = 1
    self.assertEquals(existing.items(), [("a", 1)])

    created = DictFoo()
    created.id = 40
    created.title = u"My Title"
    self.store.add(created)
    self.store.commit()
    fetched = self.store.get(DictFoo, 40)
    self.assertTrue(fetched is created)
def test_wrapper(self):
    """Removing through a Wrapper removes the wrapped object."""
    target = self.store.get(Foo, 20)
    wrapped = Wrapper(target)
    self.store.remove(wrapped)
    self.store.flush()
    self.assertEquals(self.store.get(Foo, 20), None)
def test_rollback_loaded_and_still_in_cached(self):
    # Regression check for an interaction between caching, commits and
    # rollbacks: the instance loaded before a commit/rollback pair must
    # still be the one handed out afterwards.
    first = self.store.get(Foo, 20)
    self.store.commit()
    self.store.rollback()
    second = self.store.get(Foo, 20)
    self.assertTrue(first is second)
def test_class_alias(self):
    """A ClassAlias can be compared against the class it aliases."""
    FooAlias = ClassAlias(Foo)
    rows = self.store.find(FooAlias, FooAlias.id < Foo.id)
    pairs = [(obj.id, obj.title) for obj in rows if type(obj) is Foo]
    expected = [(10, "Title 30"),
                (10, "Title 30"),
                (20, "Title 20")]
    self.assertEquals(pairs, expected)
def test_expr_values(self):
    """An SQL expression assigned to an attribute is applied on flush."""
    before = [(10, "Title 30"), (20, "Title 20"), (30, "Title 10")]
    after = [(10, "Title 30"), (20, "New title"), (30, "Title 10")]
    obj = self.store.get(Foo, 20)
    obj.title = SQL("'New title'")
    # Nothing has been flushed yet.
    self.assertEquals(self.get_items(), before)
    self.store.flush()
    # The expression's result is now in the database.
    self.assertEquals(self.get_items(), after)
    self.assertEquals(obj.title, "New title")
def test_expr_values_flush_on_demand(self):
    """Reading a pending SQL expression value flushes it on demand."""
    before = [(10, "Title 30"), (20, "Title 20"), (30, "Title 10")]
    after = [(10, "Title 30"), (20, "New title"), (30, "Title 10")]
    obj = self.store.get(Foo, 20)
    obj.title = SQL("'New title'")
    # Nothing has been flushed yet.
    self.assertEquals(self.get_items(), before)
    # Reading the attribute resolves the expression...
    self.assertEquals(obj.title, "New title")
    # ...and the change has reached the database.
    self.assertEquals(self.get_items(), after)
def test_expr_values_flush_and_load_in_separate_steps(self):
    """After a flush, an SQL expression value is reloaded lazily."""
    after = [(10, "Title 30"), (20, "New title"), (30, "Title 10")]
    obj = self.store.get(Foo, 20)
    obj.title = SQL("'New title'")
    self.store.flush()
    # The database already holds the new value.
    self.assertEquals(self.get_items(), after)
    # Locally, the variable only carries an AutoReload marker...
    title_variable = get_obj_info(obj).variables[Foo.title]
    self.assertTrue(title_variable.get_lazy() is AutoReload)
    # ...which is resolved when the attribute is read.
    self.assertEquals(obj.title, u"New title")
def test_expr_values_flush_on_demand_with_added(self):
    """An added object's SQL expression value is flushed on read."""
    before = [(10, "Title 30"), (20, "Title 20"), (30, "Title 10")]
    after = [(10, "Title 30"), (20, "Title 20"), (30, "Title 10"),
             (40, "New title")]
    created = Foo()
    created.id = 40
    created.title = SQL("'New title'")
    self.store.add(created)
    # Nothing has been flushed yet.
    self.assertEquals(self.get_items(), before)
    self.assertEquals(created.title, "New title")
    # Reading the value inserted the new row.
    self.assertEquals(self.get_items(), after)
def test_expr_values_flush_on_demand_with_removed_and_added(self):
    """A removed-then-re-added object still flushes its expression."""
    before = [(10, "Title 30"), (20, "Title 20"), (30, "Title 10")]
    after = [(10, "Title 30"), (20, "New title"), (30, "Title 10")]
    obj = self.store.get(Foo, 20)
    obj.title = SQL("'New title'")
    self.store.remove(obj)
    self.store.add(obj)
    # Nothing has been flushed yet.
    self.assertEquals(self.get_items(), before)
    self.assertEquals(obj.title, "New title")
    # Reading the value pushed the update to the database.
    self.assertEquals(self.get_items(), after)
def test_expr_values_flush_on_demand_with_removed_and_rollbacked(self):
    """A removal undone by rollback does not block on-demand flushing."""
    before = [(10, "Title 30"), (20, "Title 20"), (30, "Title 10")]
    after = [(10, "Title 30"), (20, "New title"), (30, "Title 10")]
    obj = self.store.get(Foo, 20)
    self.store.remove(obj)
    self.store.rollback()
    obj.title = SQL("'New title'")
    # Nothing has been flushed yet.
    self.assertEquals(self.get_items(), before)
    self.assertEquals(obj.title, "New title")
    # Reading the value pushed the update to the database.
    self.assertEquals(self.get_items(), after)
def test_expr_values_flush_on_demand_with_added_and_removed(self):
    # Tries to trigger a problem in several ways at once: the object
    # reuses the id of an existing row and is added and then removed.
    # It must never reach the database, must not update the existing
    # row, and reading its lazy value must not flush other pending
    # changes.
    unchanged = [(10, "Title 30"), (20, "Title 20"), (30, "Title 10")]
    obj = Foo()
    obj.id = 20
    dependency = Foo()
    dependency.id = 50
    self.store.add(obj)
    self.store.add(dependency)
    obj.title = SQL("'New title'")
    # A flush ordering constraint, to provoke incorrect flushing if any.
    self.store.add_flush_order(dependency, obj)
    self.store.remove(obj)
    self.assertEquals(self.get_items(), unchanged)
    self.assertEquals(obj.title, None)
    # The dependency has no reason to be flushed either.
    self.assertEquals(self.get_items(), unchanged)
def test_expr_values_flush_on_demand_with_removed(self):
    # Same scenario as the added-and-removed case, but the removed object
    # is one that already exists in the database.
    unchanged = [(10, "Title 30"), (20, "Title 20"), (30, "Title 10")]
    obj = self.store.get(Foo, 20)
    dependency = Foo()
    dependency.id = 50
    self.store.add(dependency)
    obj.title = SQL("'New title'")
    # A flush ordering constraint, to provoke incorrect flushing if any.
    self.store.add_flush_order(dependency, obj)
    self.store.remove(obj)
    self.assertEquals(self.get_items(), unchanged)
    self.assertEquals(obj.title, None)
    # The dependency has no reason to be flushed either.
    self.assertEquals(self.get_items(), unchanged)
def test_lazy_value_preserved_with_subsequent_object_initialization(self):
    """
    A lazy value set on an object that is then initialized again from the
    database is preserved, and the object is initialized properly
    (regression test for bug #620615).
    """
    # A fully loaded object.
    obj = self.store.get(Foo, 20)
    # Prepare the result set up front, so that obtaining it does not
    # flush anything once the lazy value exists.
    pending = self.store.find(Foo, Foo.id == 20)
    # An unflushed lazy value on one attribute.
    obj.title = SQL("'New title'")
    # Loading the row again is where the original bug blew up.
    obj = pending.one()
    self.assertEquals(obj.title, "New title")
def test_lazy_value_discarded_on_reload(self):
    """
    Counterpart to test_lazy_value_preserved_with_subsequent_object_
    initialization, also related to bug #620615: on an explicit reload,
    a pending lazy value must be discarded.
    """
    # Retrieve an object, fully loaded.
    foo = self.store.get(Foo, 20)
    # Build a result set up front, mirroring the setup of the
    # lazy-value-preserved test. It is intentionally never consumed here.
    result = self.store.find(Foo, Foo.id == 20)
    # Now, set an unflushed lazy value on an attribute.
    foo.title = SQL("'New title'")
    # Give up on this and reload the original object.
    self.store.reload(foo)
    # The pending SQL expression was dropped, so the database value is
    # seen again instead of "New title".
    self.assertEquals(foo.title, "Title 20")
def test_expr_values_with_columns(self):
    """An expression built from columns is evaluated for the row."""
    row = self.store.get(Bar, 200)
    row.foo_id = Bar.id + 1
    self.assertEquals(row.foo_id, 201)
def test_autoreload_attribute(self):
    """Assigning AutoReload to an attribute reloads it from the database."""
    obj = self.store.get(Foo, 20)
    self.store.execute("UPDATE foo SET title='New Title' WHERE id=20")
    # The in-memory value is stale until it is explicitly reloaded.
    self.assertEquals(obj.title, "Title 20")
    obj.title = AutoReload
    self.assertEquals(obj.title, "New Title")
    title_variable = get_obj_info(obj).variables[Foo.title]
    self.assertFalse(title_variable.has_changed())
def test_autoreload_attribute_with_changed_primary_key(self):
    """AutoReload works after the primary key was changed locally."""
    obj = self.store.get(Foo, 20)
    self.store.execute("UPDATE foo SET title='New Title' WHERE id=20")
    self.assertEquals(obj.title, "Title 20")
    obj.id = 40
    obj.title = AutoReload
    self.assertEquals(obj.title, "New Title")
    self.assertEquals(obj.id, 40)
def test_autoreload_object(self):
    """store.autoreload(obj) makes the object pick up database changes."""
    obj = self.store.get(Foo, 20)
    self.store.execute("UPDATE foo SET title='New Title' WHERE id=20")
    self.assertEquals(obj.title, "Title 20")
    self.store.autoreload(obj)
    self.assertEquals(obj.title, "New Title")
def test_autoreload_primary_key_of_unflushed_object(self):
    """An AutoReload primary key on a new object resolves to an integer."""
    pending = Foo()
    self.store.add(pending)
    pending.id = AutoReload
    pending.title = u"New Title"
    self.assertTrue(isinstance(pending.id, (int, long)))
    self.assertEquals(pending.title, "New Title")
def test_autoreload_primary_key_doesnt_reload_everything_else(self):
    """Reading the primary key after autoreload() leaves other columns lazy."""
    obj = self.store.get(Foo, 20)
    self.store.autoreload(obj)
    variables = get_obj_info(obj).variables

    def assert_lazy_state():
        # After autoreload(), the id is not lazy but the title is.
        self.assertEquals(variables[Foo.id].get_lazy(), None)
        self.assertEquals(variables[Foo.title].get_lazy(), AutoReload)

    assert_lazy_state()
    self.assertEquals(obj.id, 20)
    # Reading the id must not have resolved the title.
    assert_lazy_state()
def test_autoreload_all_objects(self):
    """store.autoreload() with no argument affects every object."""
    obj = self.store.get(Foo, 20)
    self.store.execute("UPDATE foo SET title='New Title' WHERE id=20")
    self.assertEquals(obj.title, "Title 20")
    self.store.autoreload()
    self.assertEquals(obj.title, "New Title")
def test_autoreload_and_get_will_not_reload(self):
    """A plain get() does not resolve values pending an autoreload."""
    obj = self.store.get(Foo, 20)
    self.store.execute("UPDATE foo SET title='New Title' WHERE id=20")
    self.store.autoreload(obj)
    variables = get_obj_info(obj).variables
    self.assertEquals(variables[Foo.title].get_lazy(), AutoReload)
    self.store.get(Foo, 20)
    # Still pending after the get() ...
    self.assertEquals(variables[Foo.title].get_lazy(), AutoReload)
    # ... and resolved only when read.
    self.assertEquals(obj.title, "New Title")
def test_autoreload_object_doesnt_tag_as_dirty(self):
    """Marking an object for autoreload does not make it dirty."""
    obj = self.store.get(Foo, 20)
    self.store.autoreload(obj)
    self.assertFalse(get_obj_info(obj) in self.store._dirty)
def test_autoreload_missing_columns_on_insertion(self):
    """Columns not set on insertion are left to be reloaded lazily."""
    obj = Foo()
    self.store.add(obj)
    self.store.flush()
    title_variable = get_obj_info(obj).variables[Foo.title]
    self.assertEquals(title_variable.get_lazy(), AutoReload)
    self.assertEquals(obj.title, u"Default Title")
def test_reference_break_on_local_diverged_doesnt_autoreload(self):
    """Breaking a reference does not resolve the old target's lazy values."""
    target = self.store.get(Foo, 10)
    self.store.autoreload(target)
    bar = self.store.get(Bar, 100)
    self.assertTrue(bar.foo)
    # Point the local key elsewhere, breaking the reference ...
    bar.foo_id = 40
    self.assertEquals(bar.foo, None)
    # ... while the old target stays pending reload.
    lazy = get_obj_info(target).variables[Foo.title].get_lazy()
    self.assertEquals(lazy, AutoReload)
def test_primary_key_reference(self):
    """
    When an object references another one through its primary key, the
    reference notices, after a commit, that the target row was removed
    behind its back.
    """
    class BarOnRemote(object):
        __storm_table__ = "bar"
        foo_id = Int(primary=True)
        foo = Reference(foo_id, Foo.id, on_remote=True)

    referenced = self.store.get(Foo, 10)
    holder = self.store.get(BarOnRemote, 10)
    self.assertEqual(holder.foo, referenced)
    # Delete the row directly, then end the transaction.
    self.store.execute("DELETE FROM foo WHERE id = 10")
    self.store.commit()
    self.assertEqual(holder.foo, None)
def test_invalidate_and_get_object(self):
    """get() and find() return the same instance after invalidation."""
    obj = self.store.get(Foo, 20)
    self.store.invalidate(obj)
    self.assertEquals(self.store.get(Foo, 20), obj)
    found = self.store.find(Foo, id=20).one()
    self.assertEquals(found, obj)
def test_invalidate_and_get_removed_object(self):
    """An invalidated object deleted from the database is not returned."""
    obj = self.store.get(Foo, 20)
    self.store.execute("DELETE FROM foo WHERE id=20")
    self.store.invalidate(obj)
    self.assertEquals(self.store.get(Foo, 20), None)
    found = self.store.find(Foo, id=20).one()
    self.assertEquals(found, None)
def test_invalidate_and_validate_with_find(self):
    """A find() revalidates an invalidated object."""
    obj = self.store.get(Foo, 20)
    self.store.invalidate(obj)
    found = self.store.find(Foo, id=20).one()
    self.assertEquals(found, obj)
    # The cached object is valid again, so a direct delete goes unnoticed.
    self.store.execute("DELETE FROM foo WHERE id=20")
    self.assertEquals(self.store.get(Foo, 20), obj)
def test_invalidate_object_gets_validated(self):
    """A get() revalidates an invalidated object."""
    obj = self.store.get(Foo, 20)
    self.store.invalidate(obj)
    self.assertEquals(self.store.get(Foo, 20), obj)
    # The object is valid again, so deleting the row directly does not
    # affect what the cache hands out.
    self.store.execute("DELETE FROM foo WHERE id=20")
    self.assertEquals(self.store.get(Foo, 20), obj)
def test_invalidate_object_with_only_primary_key(self):
    """Invalidation works for objects whose columns are all primary."""
    key = (20, 200)
    link = self.store.get(Link, key)
    self.store.execute("DELETE FROM link WHERE foo_id=20 AND bar_id=200")
    self.store.invalidate(link)
    self.assertEquals(self.store.get(Link, key), None)
def test_invalidate_added_object(self):
    """An object added after invalidation is valid once flushed."""
    created = Foo()
    self.store.add(created)
    self.store.invalidate(created)
    created.id = 40
    created.title = u"Title 40"
    self.store.flush()
    # Having just been inserted, the object's cache entry must be valid.
    self.store.execute("DELETE FROM foo WHERE id=40")
    self.assertEquals(self.store.get(Foo, 40), created)
def test_invalidate_and_update(self):
    """Updating an invalidated object whose row is gone raises."""
    obj = self.store.get(Foo, 20)
    self.store.execute("DELETE FROM foo WHERE id=20")
    self.store.invalidate(obj)
    self.assertRaises(LostObjectError,
                      setattr, obj, "title", u"Title 40")
def test_invalidate_and_get_returns_autoreloaded(self):
    """get() on an invalidated object leaves its columns lazy."""
    obj = self.store.get(Foo, 20)
    self.store.invalidate(obj)
    obj = self.store.get(Foo, 20)
    lazy = get_obj_info(obj).variables[Foo.title].get_lazy()
    self.assertEquals(lazy, AutoReload)
    self.assertEquals(obj.title, "Title 20")
def test_invalidated_hook(self):
    """__storm_invalidated__ runs on invalidate(), not on autoreload()."""
    calls = []

    class MyFoo(Foo):
        def __storm_invalidated__(self):
            calls.append(True)

    obj = self.store.get(MyFoo, 20)
    self.assertEquals(calls, [])
    self.store.autoreload(obj)
    self.assertEquals(calls, [])
    self.store.invalidate(obj)
    self.assertEquals(calls, [True])
def test_invalidated_hook_called_after_all_invalidated(self):
    """
    Invalidated hooks only run once every object has already been marked
    as invalidated. See comment in store.py:_mark_autoreload.
    """
    seen = []

    class MyFoo(Foo):
        def __storm_invalidated__(self):
            # Record the state observed by the first hook call only.
            if seen:
                return
            seen.append(get_obj_info(first).get("invalidated"))
            seen.append(get_obj_info(second).get("invalidated"))

    first = self.store.get(MyFoo, 10)
    second = self.store.get(MyFoo, 20)
    self.store.invalidate()
    self.assertEquals(seen, [True, True])
def test_reset_recreates_objects(self):
    """
    After store.reset(), queries build fresh objects even if older objects
    for the same rows are still in memory.
    """
    stale = self.store.get(Foo, 10)
    stale.dirty = True
    self.store.reset()
    fresh = self.store.get(Foo, 10)
    self.assertFalse(hasattr(fresh, "dirty"))
    self.assertNotIdentical(fresh, stale)
def test_reset_unmarks_dirty(self):
    """
    If an object was dirty when store.reset() is called, its pending
    changes are discarded: a later flush does not write them to the
    database.
    """
    foo1 = self.store.get(Foo, 10)
    foo1_title = foo1.title
    foo1.title = u"radix wuz here"
    self.store.reset()
    self.store.flush()
    # A fresh load still shows the title from before the local change.
    new_foo1 = self.store.get(Foo, 10)
    self.assertEquals(new_foo1.title, foo1_title)
def test_reset_clears_cache(self):
    """store.reset() empties the object cache."""
    cache = self.get_cache(self.store)
    obj = self.store.get(Foo, 10)
    self.assertTrue(get_obj_info(obj) in cache.get_cached())
    self.store.reset()
    self.assertEquals(cache.get_cached(), [])
def test_reset_breaks_store_reference(self):
    """
    After store.reset(), previously loaded objects no longer belong to
    that store.
    """
    obj = self.store.get(Foo, 10)
    self.store.reset()
    owner = Store.of(obj)
    self.assertIdentical(owner, None)
def test_result_find(self):
    """ResultSet.find() narrows an existing result set."""
    narrowed = self.store.find(Foo, Foo.id <= 20).find(Foo.id > 10)
    obj = narrowed.one()
    self.assertTrue(obj)
    self.assertEqual(obj.id, 20)
def test_result_find_kwargs(self):
    """ResultSet.find() accepts keyword conditions."""
    narrowed = self.store.find(Foo, Foo.id <= 20).find(id=20)
    obj = narrowed.one()
    self.assertTrue(obj)
    self.assertEqual(obj.id, 20)
def test_result_find_introduce_join(self):
    """ResultSet.find() may bring in another table through a join."""
    base = self.store.find(Foo, Foo.id <= 20)
    joined = base.find(Foo.id == Bar.foo_id, Bar.title == u"Title 300")
    obj = joined.one()
    self.assertTrue(obj)
    self.assertEqual(obj.id, 10)
def test_result_find_tuple(self):
    """ResultSet.find() works on result sets of tuples."""
    pairs = self.store.find((Foo, Bar), Foo.id == Bar.foo_id)
    row = pairs.find(Bar.title == u"Title 100").one()
    self.assertTrue(row)
    left, right = row
    self.assertEqual(left.id, 30)
    self.assertEqual(right.id, 300)
def test_result_find_undef_where(self):
    """Chained finds work when either side has no where clause."""
    builders = [
        lambda: self.store.find(Foo, Foo.id == 20).find(),
        lambda: self.store.find(Foo).find(Foo.id == 20),
    ]
    for build in builders:
        obj = build().one()
        self.assertTrue(obj)
        self.assertEqual(obj.id, 20)
def test_result_find_fails_on_set_expr(self):
    """find() is not supported on a set-expression result."""
    combined = self.store.find(Foo).union(self.store.find(Foo))
    self.assertRaises(FeatureError, combined.find, Foo.id == 20)
def test_result_find_fails_on_slice(self):
    """find() is not supported on a sliced result."""
    sliced = self.store.find(Foo)[1:2]
    self.assertRaises(FeatureError, sliced.find, Foo.id == 20)
def test_result_find_fails_on_group_by(self):
    """find() is not supported on a grouped result."""
    grouped = self.store.find(Foo)
    grouped.group_by(Foo)
    self.assertRaises(FeatureError, grouped.find, Foo.id == 20)
def test_result_union(self):
    """A union can be ordered in either direction."""
    combined = self.store.find(Foo, id=30).union(
        self.store.find(Foo, id=10))
    ascending = [(30, "Title 10"), (10, "Title 30")]
    cases = [(Foo.title, ascending),
             (Desc(Foo.title), list(reversed(ascending)))]
    for ordering, expected in cases:
        combined.order_by(ordering)
        self.assertEquals([(obj.id, obj.title) for obj in combined],
                          expected)
def test_result_union_duplicated(self):
    """A plain union collapses duplicate rows."""
    combined = self.store.find(Foo, id=30).union(
        self.store.find(Foo, id=30))
    pairs = [(obj.id, obj.title) for obj in combined]
    self.assertEquals(pairs, [(30, "Title 10")])
def test_result_union_duplicated_with_all(self):
    """union(all=True) keeps duplicate rows."""
    combined = self.store.find(Foo, id=30).union(
        self.store.find(Foo, id=30), all=True)
    pairs = [(obj.id, obj.title) for obj in combined]
    self.assertEquals(pairs, [(30, "Title 10")] * 2)
def test_result_union_with_empty(self):
    """A union with an EmptyResultSet yields the other side unchanged."""
    combined = self.store.find(Foo, id=30).union(EmptyResultSet())
    pairs = [(obj.id, obj.title) for obj in combined]
    self.assertEquals(pairs, [(30, "Title 10")])
def test_result_union_unsupported_methods(self):
    """set() and remove() are not supported on a union."""
    combined = self.store.find(Foo, id=30).union(
        self.store.find(Foo, id=10))
    self.assertRaises(FeatureError, combined.set, title=u"Title 40")
    self.assertRaises(FeatureError, combined.remove)
def test_result_union_count(self):
    """count() on a union(all=True) counts duplicates."""
    combined = self.store.find(Foo, id=30).union(
        self.store.find(Foo, id=30), all=True)
    self.assertEquals(combined.count(), 2)
def test_result_difference(self):
    """A difference can be ordered in either direction."""
    # Not run on MySQL backends (presumably no EXCEPT support there).
    backend = self.__class__.__name__
    if backend.startswith("MySQL"):
        return
    remaining = self.store.find(Foo).difference(
        self.store.find(Foo, id=20))
    ascending = [(30, "Title 10"), (10, "Title 30")]
    cases = [(Foo.title, ascending),
             (Desc(Foo.title), list(reversed(ascending)))]
    for ordering, expected in cases:
        remaining.order_by(ordering)
        self.assertEquals([(obj.id, obj.title) for obj in remaining],
                          expected)
def test_result_difference_with_empty(self):
    """Subtracting an EmptyResultSet leaves the result unchanged."""
    # Not run on MySQL backends (presumably no EXCEPT support there).
    backend = self.__class__.__name__
    if backend.startswith("MySQL"):
        return
    remaining = self.store.find(Foo, id=30).difference(EmptyResultSet())
    pairs = [(obj.id, obj.title) for obj in remaining]
    self.assertEquals(pairs, [(30, "Title 10")])
def test_result_difference_count(self):
    """count() works on a difference."""
    # Not run on MySQL backends (presumably no EXCEPT support there).
    backend = self.__class__.__name__
    if backend.startswith("MySQL"):
        return
    remaining = self.store.find(Foo).difference(
        self.store.find(Foo, id=20))
    self.assertEquals(remaining.count(), 2)
def test_is_in_empty_result_set(self):
    """is_in() against a result set with no rows matches nothing."""
    nothing = self.store.find(Foo, Foo.id < 10)
    matches = self.store.find(Foo,
                              Or(Foo.id > 20, Foo.id.is_in(nothing)))
    self.assertEquals(matches.count(), 1)
def test_is_in_empty_list(self):
    """is_in([]) is false, so its negation matches every row."""
    condition = Eq(False, And(True, Foo.id.is_in([])))
    self.assertEquals(self.store.find(Foo, condition).count(), 3)
def test_result_intersection(self):
    """An intersection can be ordered in either direction."""
    # Not run on MySQL backends (presumably no INTERSECT support there).
    backend = self.__class__.__name__
    if backend.startswith("MySQL"):
        return
    common = self.store.find(Foo).intersection(
        self.store.find(Foo, Foo.id.is_in((10, 30))))
    ascending = [(30, "Title 10"), (10, "Title 30")]
    cases = [(Foo.title, ascending),
             (Desc(Foo.title), list(reversed(ascending)))]
    for ordering, expected in cases:
        common.order_by(ordering)
        self.assertEquals([(obj.id, obj.title) for obj in common],
                          expected)
def test_result_intersection_with_empty(self):
    """Intersecting with an EmptyResultSet yields no rows."""
    # Not run on MySQL backends (presumably no INTERSECT support there).
    backend = self.__class__.__name__
    if backend.startswith("MySQL"):
        return
    common = self.store.find(Foo, id=30).intersection(EmptyResultSet())
    self.assertEquals(len(list(common)), 0)
def test_result_intersection_count(self):
    """count() works on an intersection."""
    # Not run on MySQL backends (presumably no INTERSECT support there).
    backend = self.__class__.__name__
    if backend.startswith("MySQL"):
        return
    common = self.store.find(Foo, Foo.id.is_in((10, 20))).intersection(
        self.store.find(Foo, Foo.id.is_in((10, 30))))
    self.assertEquals(common.count(), 1)
def test_proxy(self):
    """A Proxy reads the attribute of the referenced object."""
    proxied = self.store.get(BarProxy, 200)
    self.assertEquals(proxied.foo_title, "Title 20")
def test_proxy_equals(self):
    """A Proxy can be used in a comparison expression."""
    matches = self.store.find(BarProxy, BarProxy.foo_title == u"Title 20")
    proxied = matches.one()
    self.assertTrue(proxied)
    self.assertEquals(proxied.id, 200)
def test_proxy_as_column(self):
    """A Proxy can be selected as a column via values()."""
    matches = self.store.find(BarProxy, BarProxy.id == 200)
    titles = list(matches.values(BarProxy.foo_title))
    self.assertEquals(titles, ["Title 20"])
def test_proxy_set(self):
    """Assigning through a Proxy updates the referenced object."""
    proxied = self.store.get(BarProxy, 200)
    proxied.foo_title = u"New Title"
    target = self.store.get(Foo, 20)
    self.assertEquals(target.title, "New Title")
def get_bar_proxy_with_string(self):
    """
    Build and return a (MyBarProxy, MyFoo) pair of classes whose Reference
    and Proxy are declared with string names ("foo_id", "MyFoo.id",
    "MyFoo.title") instead of column objects.
    """
    # A fresh base class using PropertyPublisherMeta, so that the string
    # names below can be resolved against the classes defined here.
    class Base(object):
        __metaclass__ = PropertyPublisherMeta
    class MyBarProxy(Base):
        __storm_table__ = "bar"
        id = Int(primary=True)
        foo_id = Int()
        # "MyFoo" is defined after this class; string names allow the
        # forward reference.
        foo = Reference("foo_id", "MyFoo.id")
        foo_title = Proxy(foo, "MyFoo.title")
    class MyFoo(Base):
        __storm_table__ = "foo"
        id = Int(primary=True)
        title = Unicode()
    return MyBarProxy, MyFoo
def test_proxy_with_string(self):
    """A Proxy declared with string names reads the remote attribute."""
    MyBarProxy, MyFoo = self.get_bar_proxy_with_string()
    proxied = self.store.get(MyBarProxy, 200)
    self.assertEquals(proxied.foo_title, "Title 20")
def test_proxy_with_string_variable_factory_attribute(self):
    """A string-declared Proxy exposes the remote column's variable type."""
    # MyFoo must exist for "MyFoo.title" to resolve, even if unused here.
    MyBarProxy, MyFoo = self.get_bar_proxy_with_string()
    factory = MyBarProxy.foo_title.variable_factory
    variable = factory(value=u"Hello")
    self.assertTrue(isinstance(variable, UnicodeVariable))
def test_proxy_with_extra_table(self):
    """
    Proxies use a join on auto_tables. It should work even if we have
    more tables in the query.
    """
    rows = list(self.store.find((BarProxy, Link),
                                BarProxy.foo_title == u"Title 20",
                                BarProxy.foo_id == Link.foo_id))
    self.assertEquals(len(rows), 2)
    for proxied, link in rows:
        self.assertEquals(proxied.id, 200)
        self.assertEquals(proxied.foo_title, u"Title 20")
        self.assertEquals(proxied.foo_id, 20)
        self.assertEquals(link.foo_id, 20)
def test_get_decimal_property(self):
    """A Decimal column is read back as a decimal.Decimal."""
    expected = decimal.Decimal("12.3455")
    self.assertEquals(self.store.get(Money, 10).value, expected)
def test_set_decimal_property(self):
    """A decimal.Decimal written to a Decimal column can be queried."""
    amount = decimal.Decimal("12.3456")
    account = self.store.get(Money, 10)
    account.value = amount
    self.store.flush()
    matches = self.store.find(Money, value=amount)
    self.assertEquals(matches.one(), account)
def test_fill_missing_primary_key_with_lazy_value(self):
    """An SQL expression assigned to a primary key re-keys the object."""
    obj = self.store.get(Foo, 10)
    obj.id = SQL("40")
    self.store.flush()
    self.assertEquals(obj.id, 40)
    # The object is only reachable under its new key.
    self.assertEquals(self.store.get(Foo, 10), None)
    self.assertEquals(self.store.get(Foo, 40), obj)
def test_fill_missing_primary_key_with_lazy_value_on_creation(self):
    """A new object's primary key may be given as an SQL expression."""
    created = Foo()
    created.id = SQL("40")
    self.store.add(created)
    self.store.flush()
    self.assertEquals(created.id, 40)
    self.assertEquals(self.store.get(Foo, 40), created)
def test_preset_primary_key(self):
    """
    A connection's preset_primary_key hook is called with the primary
    columns and the (still undefined) primary variables of a new object
    during flush; the value it sets ends up as the object's primary key.
    """
    check = []
    def preset_primary_key(primary_columns, primary_variables):
        check.append([(variable.is_defined(), variable.get_lazy())
                      for variable in primary_variables])
        check.append([column.name for column in primary_columns])
        primary_variables[0].set(SQL("40"))

    class DatabaseWrapper(object):
        """Wrapper to inject our custom preset_primary_key hook."""
        def __init__(self, database):
            self.database = database
        def connect(self, event=None):
            connection = self.database.connect(event)
            connection.preset_primary_key = preset_primary_key
            return connection

    store = Store(DatabaseWrapper(self.database))
    try:
        # add()/flush() live inside the try so the private store is
        # closed even when they fail, instead of leaking its connection.
        foo = store.add(Foo())
        store.flush()
        self.assertEquals(check, [[(False, None)], ["id"]])
        self.assertEquals(foo.id, 40)
    finally:
        store.close()
def test_strong_cache_used(self):
    """
    Objects should be referenced in the cache if not referenced
    in application code.
    """
    foo = self.store.get(Foo, 20)
    # Tag the instance so we can later tell that the same Python object
    # came back, rather than a freshly loaded one.
    foo.tainted = True
    # NOTE(review): obj_info is never read again below.
    obj_info = get_obj_info(foo)
    # Drop the only application-level reference; the strong cache must
    # keep the object alive through garbage collection.
    del foo
    gc.collect()
    cached = self.store.find(Foo).cached()
    self.assertEquals(len(cached), 1)
    foo = self.store.get(Foo, 20)
    self.assertEquals(cached, [foo])
    # The tag survived, so this is the original instance.
    self.assertTrue(hasattr(foo, "tainted"))
def test_strong_cache_cleared_on_invalidate_all(self):
    """store.invalidate() with no argument empties the strong cache."""
    cache = self.get_cache(self.store)
    obj = self.store.get(Foo, 20)
    self.assertEquals(cache.get_cached(), [get_obj_info(obj)])
    self.store.invalidate()
    self.assertEquals(cache.get_cached(), [])
def test_strong_cache_loses_object_on_invalidate(self):
    """Invalidating an object drops it from the strong cache."""
    cache = self.get_cache(self.store)
    obj = self.store.get(Foo, 20)
    self.assertEquals(cache.get_cached(), [get_obj_info(obj)])
    self.store.invalidate(obj)
    self.assertEquals(cache.get_cached(), [])
def test_strong_cache_loses_object_on_remove(self):
    """
    An object removed from the store is also dropped from the strong
    reference cache once the removal is flushed.
    """
    cache = self.get_cache(self.store)
    obj = self.store.get(Foo, 20)
    self.assertEquals(cache.get_cached(), [get_obj_info(obj)])
    self.store.remove(obj)
    self.store.flush()
    self.assertEquals(cache.get_cached(), [])
def test_strong_cache_renews_object_on_get(self):
    """Fetching an object again with get() moves it to the cache front."""
    cache = self.get_cache(self.store)
    first = self.store.get(Foo, 10)
    second = self.store.get(Foo, 20)
    first = self.store.get(Foo, 10)
    expected = [get_obj_info(first), get_obj_info(second)]
    self.assertEquals(cache.get_cached(), expected)
def test_strong_cache_renews_object_on_find(self):
    """
    Loading an already cached object again through find() renews it, so
    it is listed first by get_cached().
    """
    strong_cache = self.get_cache(self.store)
    first = self.store.find(Foo, id=10).one()
    second = self.store.find(Foo, id=20).one()
    # Look the first object up once more.
    first = self.store.find(Foo, id=10).one()
    expected = [get_obj_info(first), get_obj_info(second)]
    self.assertEquals(strong_cache.get_cached(), expected)
- def test_unicode(self):
- class MyFoo(Foo):
- pass
- foo = self.store.get(Foo, 20)
- myfoo = self.store.get(MyFoo, 20)
- for title in [u'Cừơng', u'Đức', u'Hạnh']:
- foo.title = title
- self.store.commit()
- try:
- self.assertEquals(myfoo.title, title)
- except AssertionError, e:
- raise AssertionError(str(e) +
- " (ensure your database was created with CREATE DATABASE"
- " ... CHARACTER SET utf8)")
def test_creation_order_is_preserved_when_possible(self):
    """
    Objects added to the store are flushed in the order they were added,
    so the ids they receive on flush increase in that same order.
    """
    foos = [self.store.add(Foo()) for _ in range(10)]
    self.store.flush()
    # Compare each object with its immediate successor.
    for previous, following in zip(foos, foos[1:]):
        self.assertTrue(previous.id < following.id)
def test_update_order_is_preserved_when_possible(self):
    """Modified objects are flushed in the order they were changed."""
    class MyFoo(Foo):
        sequence = 0
        def __storm_flushed__(self):
            # Remember this object's position in the flush sequence.
            self.flush_order = MyFoo.sequence
            MyFoo.sequence += 1
    foos = [self.store.add(MyFoo()) for _ in range(10)]
    self.store.flush()

    # Restart the counter so only the update flush below is observed.
    MyFoo.sequence = 0
    for foo in foos:
        foo.title = u"Changed Title"
    self.store.flush()

    for expected_position, foo in enumerate(foos):
        self.assertEquals(foo.flush_order, expected_position)
def test_removal_order_is_preserved_when_possible(self):
    """Removed objects are flushed in the order they were removed."""
    class MyFoo(Foo):
        sequence = 0
        def __storm_flushed__(self):
            # Remember this object's position in the flush sequence.
            self.flush_order = MyFoo.sequence
            MyFoo.sequence += 1
    foos = [self.store.add(MyFoo()) for _ in range(10)]
    self.store.flush()

    # Restart the counter so only the removal flush below is observed.
    MyFoo.sequence = 0
    for foo in foos:
        self.store.remove(foo)
    self.store.flush()

    for expected_position, foo in enumerate(foos):
        self.assertEquals(foo.flush_order, expected_position)
def test_cache_poisoning(self):
    """
    When an object sets a field back to a previous value that is still in
    its cache, the new value must nevertheless be written to the database.

    Change detection has broken this in the past; see bug #277095 in
    Launchpad.
    """
    store = self.create_store()
    foo2 = store.get(Foo, 10)
    self.assertEquals(foo2.title, u"Title 30")
    store.commit()

    # Change the row behind foo2's back through the main store.
    foo1 = self.store.get(Foo, 10)
    foo1.title = u"Title 40"
    self.store.commit()

    # Assigning the value foo2 originally loaded must still be persisted.
    foo2.title = u"Title 30"
    store.commit()
    self.assertEquals(foo2.title, u"Title 30")
def test_execute_sends_event(self):
    """Running a statement emits the register-transaction event."""
    owners = []
    def on_register_transaction(owner):
        owners.append(owner)
    self.store._event.hook("register-transaction", on_register_transaction)

    self.store.execute("SELECT 1")

    # Exactly one emission, and it must come from our store.
    self.assertEqual(owners, [self.store])
def test_add_sends_event(self):
    """Adding an object to the store emits the register-transaction event."""
    owners = []
    def on_register_transaction(owner):
        owners.append(owner)
    self.store._event.hook("register-transaction", on_register_transaction)

    new_foo = Foo()
    new_foo.title = u"Foo"
    self.store.add(new_foo)

    # Exactly one emission, and it must come from our store.
    self.assertEqual(owners, [self.store])
def test_remove_sends_event(self):
    """Removing an object emits the register-transaction event."""
    calls = []
    def register_transaction(owner):
        calls.append(owner)
    self.store._event.hook("register-transaction", register_transaction)
    foo = self.store.get(Foo, 10)
    # Forget anything emitted while loading the object; only the removal
    # itself is under test.
    del calls[:]
    self.store.remove(foo)
    self.assertEqual(len(calls), 1)
    self.assertEqual(calls[0], self.store)
def test_change_invalidated_object_sends_event(self):
    """
    Modifying an object that was loaded in an earlier transaction emits
    the register-transaction event.
    """
    owners = []
    def on_register_transaction(owner):
        owners.append(owner)
    self.store._event.hook("register-transaction", on_register_transaction)

    foo = self.store.get(Foo, 10)
    # End the transaction foo was loaded in, then discard whatever the
    # lookup itself emitted.
    self.store.rollback()
    del owners[:]

    foo.title = u"New title"
    self.assertEqual(owners, [self.store])
def test_rowcount_remove(self):
    """ResultSet.remove() reports how many rows it deleted."""
    # So far every supported backend provides a rowcount.
    doomed = self.store.find(Foo, Foo.id <= 30)
    removed_count = doomed.remove()
    self.assertEquals(removed_count, 3)
- class EmptyResultSetTest(object):
def setUp(self):
    """
    Build a fresh schema and store, plus the two result sets that nearly
    every test in this class compares against each other.
    """
    self.create_database()
    self.connection = self.database.connect()
    # drop_tables() ignores failures, so this is safe on a clean database.
    self.drop_tables()
    self.create_tables()
    self.create_store()
    # Most tests run the same operation on self.empty and on self.result
    # to check that EmptyResultSet mirrors a real, empty ResultSet.
    self.empty = EmptyResultSet()
    self.result = self.store.find(Foo)
def tearDown(self):
    """
    Dispose of the store, the tables and the database, then close the
    raw connection opened in setUp().

    The connection is closed in a ``finally`` clause so that it is not
    leaked when one of the cleanup steps raises; that error still
    propagates to the test runner.
    """
    try:
        self.drop_store()
        self.drop_tables()
        self.drop_database()
    finally:
        self.connection.close()
def create_database(self):
    """
    Backend hook: subclasses must override this; setUp() uses
    ``self.database`` right after calling it.
    """
    raise NotImplementedError()
def create_tables(self):
    """Backend hook: subclasses must override this to create the schema."""
    raise NotImplementedError()
def create_store(self):
    """Open the store used by the tests on top of ``self.database``."""
    store = Store(self.database)
    self.store = store
def drop_database(self):
    """Backend hook for database disposal; does nothing by default."""
    return None
def drop_tables(self):
    """
    Drop every table the tests may have created, one at a time.

    Each DROP is committed on its own. If it fails (presumably because the
    table does not exist yet), the transaction is rolled back and the next
    table is tried. Only ``Exception`` subclasses are swallowed, so
    KeyboardInterrupt and SystemExit still abort the run.
    """
    for table in ["foo", "bar", "bin", "link"]:
        try:
            self.connection.execute("DROP TABLE %s" % table)
            self.connection.commit()
        except Exception:
            self.connection.rollback()
def drop_store(self):
    """Roll back and then close the store opened by create_store()."""
    store = self.store
    store.rollback()
    # Explicit close: test case objects are all instantiated up front, so
    # their connections would otherwise stay open.
    store.close()
def test_iter(self):
    """Iterating both result sets yields the same (empty) sequence."""
    from_result = list(self.result)
    from_empty = list(self.empty)
    self.assertEquals(from_result, from_empty)
def test_copy(self):
    """copy() returns a different object with the same contents."""
    result_copy = self.result.copy()
    self.assertNotEquals(result_copy, self.result)
    empty_copy = self.empty.copy()
    self.assertNotEquals(empty_copy, self.empty)
    self.assertEquals(list(self.result.copy()), list(self.empty.copy()))
def test_config(self):
    """Both result sets accept the same config() options."""
    for result_set in (self.result, self.empty):
        result_set.config(distinct=True, offset=1, limit=1)
    self.assertEquals(list(self.result), list(self.empty))
def test_slice(self):
    """A full slice of either result set is empty."""
    for result_set in (self.result, self.empty):
        self.assertEquals(list(result_set[:]), [])
def test_contains(self):
    """A new object is not contained in an EmptyResultSet."""
    is_contained = Foo() in self.empty
    self.assertEquals(is_contained, False)
def test_is_empty(self):
    """Both result sets report themselves as empty."""
    for result_set in (self.result, self.empty):
        self.assertEquals(result_set.is_empty(), True)
def test_any(self):
    """any() finds nothing in either result set."""
    for result_set in (self.result, self.empty):
        self.assertEquals(result_set.any(), None)
def test_first_unordered(self):
    """first() without an ordering raises UnorderedError on both sets."""
    for result_set in (self.result, self.empty):
        self.assertRaises(UnorderedError, result_set.first)
def test_first_ordered(self):
    """Once ordered, first() gives None on both sets."""
    for result_set in (self.result, self.empty):
        result_set.order_by(Foo.title)
        self.assertEquals(result_set.first(), None)
def test_last_unordered(self):
    """last() without an ordering raises UnorderedError on both sets."""
    for result_set in (self.result, self.empty):
        self.assertRaises(UnorderedError, result_set.last)
def test_last_ordered(self):
    """Once ordered, last() gives None on both sets."""
    for result_set in (self.result, self.empty):
        result_set.order_by(Foo.title)
        self.assertEquals(result_set.last(), None)
def test_one(self):
    """one() gives None on both sets."""
    for result_set in (self.result, self.empty):
        self.assertEquals(result_set.one(), None)
def test_order_by(self):
    """order_by() returns the result set it was called on."""
    for result_set in (self.result, self.empty):
        self.assertEquals(result_set.order_by(Foo.title), result_set)
def test_remove(self):
    """remove() on either set reports zero affected rows."""
    for result_set in (self.result, self.empty):
        self.assertEquals(result_set.remove(), 0)
def test_count(self):
    """count() is zero on both sets, whatever options are passed."""
    for result_set in (self.result, self.empty):
        self.assertEquals(result_set.count(), 0)
    for options in ({"expr": "abc"}, {"distinct": True}):
        self.assertEquals(self.empty.count(**options), 0)
def test_max(self):
    """max() over no rows gives None."""
    for result_set in (self.result, self.empty):
        self.assertEquals(result_set.max(Foo.id), None)
def test_min(self):
    """min() over no rows gives None."""
    for result_set in (self.result, self.empty):
        self.assertEquals(result_set.min(Foo.id), None)
def test_avg(self):
    """avg() over no rows gives None."""
    for result_set in (self.result, self.empty):
        self.assertEquals(result_set.avg(Foo.id), None)
def test_sum(self):
    """sum() over no rows gives None."""
    for result_set in (self.result, self.empty):
        self.assertEquals(result_set.sum(Foo.id), None)
def test_get_select_expr_without_columns(self):
    """
    Calling L{EmptyResultSet.get_select_expr} (or its L{ResultSet}
    counterpart) with no L{Column}s raises L{FeatureError}.
    """
    for result_set in (self.result, self.empty):
        self.assertRaises(FeatureError, result_set.get_select_expr)
def test_get_select_expr_(self):
    """
    When given L{Column}s, L{ResultSet.get_select_expr} and
    L{EmptyResultSet.get_select_expr} return a subselect whose columns are
    exactly the ones requested. That subselect can then be used with
    C{is_in()} in another query.
    """
    for result_set in (self.result, self.empty):
        subselect = result_set.get_select_expr(Foo.id)
        self.assertEqual((Foo.id,), subselect.columns)
        result = self.store.find(Foo, Foo.id.is_in(subselect))
        self.assertEquals(list(result), [])
def test_values_no_columns(self):
    """Consuming values() without columns raises FeatureError."""
    for result_set in (self.result, self.empty):
        self.assertRaises(FeatureError, list, result_set.values())
def test_values(self):
    """values() for a column yields nothing on either set."""
    for result_set in (self.result, self.empty):
        self.assertEquals(list(result_set.values(Foo.title)), [])
def test_set_no_args(self):
    """set() with no arguments returns None on both sets."""
    for result_set in (self.result, self.empty):
        self.assertEquals(result_set.set(), None)
def test_cached(self):
    """cached() lists no objects for either set."""
    for result_set in (self.result, self.empty):
        self.assertEquals(result_set.cached(), [])
def test_find(self):
    """Refining either set with find() still yields nothing."""
    for result_set in (self.result, self.empty):
        self.assertEquals(list(result_set.find(Foo.title == u"foo")), [])
def test_union(self):
    """
    The union of two empty sets is the empty set itself. A union that
    involves a real result set, in either order, has that set's type.
    """
    self.assertEquals(self.empty.union(self.empty), self.empty)
    result_type = type(self.result)
    for left, right in ((self.empty, self.result),
                        (self.result, self.empty)):
        self.assertEquals(type(left.union(right)), result_type)
def test_difference(self):
    """
    Subtracting anything from the empty set leaves the empty set, and
    subtracting the empty set from a result set leaves it unchanged.
    """
    cases = [
        (self.empty, self.empty, self.empty),
        (self.empty, self.result, self.empty),
        (self.result, self.empty, self.result),
    ]
    for left, right, expected in cases:
        self.assertEquals(left.difference(right), expected)
def test_intersection(self):
    """Any intersection involving the empty set is the empty set."""
    pairs = [
        (self.empty, self.empty),
        (self.empty, self.result),
        (self.result, self.empty),
    ]
    for left, right in pairs:
        self.assertEquals(left.intersection(right), self.empty)