PageRenderTime 541ms CodeModel.GetById 35ms app.highlight 470ms RepoModel.GetById 1ms app.codeStats 1ms

/SQLAlchemy-0.7.8/test/ext/test_associationproxy.py

#
Python | 1369 lines | 1062 code | 290 blank | 17 comment | 51 complexity | bedaf670fede93d244fb797b82068e92 MD5 | raw file
   1from test.lib.testing import eq_, assert_raises
   2import copy
   3import pickle
   4
   5from sqlalchemy import *
   6from sqlalchemy.orm import *
   7from sqlalchemy.orm.collections import collection, attribute_mapped_collection
   8from sqlalchemy.ext.associationproxy import *
   9from sqlalchemy.ext.associationproxy import _AssociationList
  10from test.lib import *
  11from test.lib.testing import assert_raises_message
  12from test.lib.util import gc_collect
  13from sqlalchemy.sql import not_
  14from test.lib import fixtures
  15
  16
  17class DictCollection(dict):
  18    @collection.appender
  19    def append(self, obj):
  20        self[obj.foo] = obj
  21    @collection.remover
  22    def remove(self, obj):
  23        del self[obj.foo]
  24
  25
  26class SetCollection(set):
  27    pass
  28
  29
  30class ListCollection(list):
  31    pass
  32
  33
  34class ObjectCollection(object):
  35    def __init__(self):
  36        self.values = list()
  37    @collection.appender
  38    def append(self, obj):
  39        self.values.append(obj)
  40    @collection.remover
  41    def remove(self, obj):
  42        self.values.remove(obj)
  43    def __iter__(self):
  44        return iter(self.values)
  45
  46
  47class _CollectionOperations(fixtures.TestBase):
  48    def setup(self):
  49        collection_class = self.collection_class
  50
  51        metadata = MetaData(testing.db)
  52
  53        parents_table = Table('Parent', metadata,
  54                              Column('id', Integer, primary_key=True,
  55                                     test_needs_autoincrement=True),
  56                              Column('name', String(128)))
  57        children_table = Table('Children', metadata,
  58                               Column('id', Integer, primary_key=True,
  59                                      test_needs_autoincrement=True),
  60                               Column('parent_id', Integer,
  61                                      ForeignKey('Parent.id')),
  62                               Column('foo', String(128)),
  63                               Column('name', String(128)))
  64
  65        class Parent(object):
  66            children = association_proxy('_children', 'name')
  67
  68            def __init__(self, name):
  69                self.name = name
  70
  71        class Child(object):
  72            if collection_class and issubclass(collection_class, dict):
  73                def __init__(self, foo, name):
  74                    self.foo = foo
  75                    self.name = name
  76            else:
  77                def __init__(self, name):
  78                    self.name = name
  79
  80        mapper(Parent, parents_table, properties={
  81            '_children': relationship(Child, lazy='joined',
  82                                  collection_class=collection_class)})
  83        mapper(Child, children_table)
  84
  85        metadata.create_all()
  86
  87        self.metadata = metadata
  88        self.session = create_session()
  89        self.Parent, self.Child = Parent, Child
  90
  91    def teardown(self):
  92        self.metadata.drop_all()
  93
  94    def roundtrip(self, obj):
  95        if obj not in self.session:
  96            self.session.add(obj)
  97        self.session.flush()
  98        id, type_ = obj.id, type(obj)
  99        self.session.expunge_all()
 100        return self.session.query(type_).get(id)
 101
 102    def _test_sequence_ops(self):
 103        Parent, Child = self.Parent, self.Child
 104
 105        p1 = Parent('P1')
 106
 107        self.assert_(not p1._children)
 108        self.assert_(not p1.children)
 109
 110        ch = Child('regular')
 111        p1._children.append(ch)
 112
 113        self.assert_(ch in p1._children)
 114        self.assert_(len(p1._children) == 1)
 115
 116        self.assert_(p1.children)
 117        self.assert_(len(p1.children) == 1)
 118        self.assert_(ch not in p1.children)
 119        self.assert_('regular' in p1.children)
 120
 121        p1.children.append('proxied')
 122
 123        self.assert_('proxied' in p1.children)
 124        self.assert_('proxied' not in p1._children)
 125        self.assert_(len(p1.children) == 2)
 126        self.assert_(len(p1._children) == 2)
 127
 128        self.assert_(p1._children[0].name == 'regular')
 129        self.assert_(p1._children[1].name == 'proxied')
 130
 131        del p1._children[1]
 132
 133        self.assert_(len(p1._children) == 1)
 134        self.assert_(len(p1.children) == 1)
 135        self.assert_(p1._children[0] == ch)
 136
 137        del p1.children[0]
 138
 139        self.assert_(len(p1._children) == 0)
 140        self.assert_(len(p1.children) == 0)
 141
 142        p1.children = ['a','b','c']
 143        self.assert_(len(p1._children) == 3)
 144        self.assert_(len(p1.children) == 3)
 145
 146        del ch
 147        p1 = self.roundtrip(p1)
 148
 149        self.assert_(len(p1._children) == 3)
 150        self.assert_(len(p1.children) == 3)
 151
 152        popped = p1.children.pop()
 153        self.assert_(len(p1.children) == 2)
 154        self.assert_(popped not in p1.children)
 155        p1 = self.roundtrip(p1)
 156        self.assert_(len(p1.children) == 2)
 157        self.assert_(popped not in p1.children)
 158
 159        p1.children[1] = 'changed-in-place'
 160        self.assert_(p1.children[1] == 'changed-in-place')
 161        inplace_id = p1._children[1].id
 162        p1 = self.roundtrip(p1)
 163        self.assert_(p1.children[1] == 'changed-in-place')
 164        assert p1._children[1].id == inplace_id
 165
 166        p1.children.append('changed-in-place')
 167        self.assert_(p1.children.count('changed-in-place') == 2)
 168
 169        p1.children.remove('changed-in-place')
 170        self.assert_(p1.children.count('changed-in-place') == 1)
 171
 172        p1 = self.roundtrip(p1)
 173        self.assert_(p1.children.count('changed-in-place') == 1)
 174
 175        p1._children = []
 176        self.assert_(len(p1.children) == 0)
 177
 178        after = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
 179        p1.children = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
 180        self.assert_(len(p1.children) == 10)
 181        self.assert_([c.name for c in p1._children] == after)
 182
 183        p1.children[2:6] = ['x'] * 4
 184        after = ['a', 'b', 'x', 'x', 'x', 'x', 'g', 'h', 'i', 'j']
 185        self.assert_(p1.children == after)
 186        self.assert_([c.name for c in p1._children] == after)
 187
 188        p1.children[2:6] = ['y']
 189        after = ['a', 'b', 'y', 'g', 'h', 'i', 'j']
 190        self.assert_(p1.children == after)
 191        self.assert_([c.name for c in p1._children] == after)
 192
 193        p1.children[2:3] = ['z'] * 4
 194        after = ['a', 'b', 'z', 'z', 'z', 'z', 'g', 'h', 'i', 'j']
 195        self.assert_(p1.children == after)
 196        self.assert_([c.name for c in p1._children] == after)
 197
 198        p1.children[2::2] = ['O'] * 4
 199        after = ['a', 'b', 'O', 'z', 'O', 'z', 'O', 'h', 'O', 'j']
 200        self.assert_(p1.children == after)
 201        self.assert_([c.name for c in p1._children] == after)
 202
 203        assert_raises(TypeError, set, [p1.children])
 204
 205        p1.children *= 0
 206        after = []
 207        self.assert_(p1.children == after)
 208        self.assert_([c.name for c in p1._children] == after)
 209
 210        p1.children += ['a', 'b']
 211        after = ['a', 'b']
 212        self.assert_(p1.children == after)
 213        self.assert_([c.name for c in p1._children] == after)
 214
 215        p1.children += ['c']
 216        after = ['a', 'b', 'c']
 217        self.assert_(p1.children == after)
 218        self.assert_([c.name for c in p1._children] == after)
 219
 220        p1.children *= 1
 221        after = ['a', 'b', 'c']
 222        self.assert_(p1.children == after)
 223        self.assert_([c.name for c in p1._children] == after)
 224
 225        p1.children *= 2
 226        after = ['a', 'b', 'c', 'a', 'b', 'c']
 227        self.assert_(p1.children == after)
 228        self.assert_([c.name for c in p1._children] == after)
 229
 230        p1.children = ['a']
 231        after = ['a']
 232        self.assert_(p1.children == after)
 233        self.assert_([c.name for c in p1._children] == after)
 234
 235        self.assert_((p1.children * 2) == ['a', 'a'])
 236        self.assert_((2 * p1.children) == ['a', 'a'])
 237        self.assert_((p1.children * 0) == [])
 238        self.assert_((0 * p1.children) == [])
 239
 240        self.assert_((p1.children + ['b']) == ['a', 'b'])
 241        self.assert_((['b'] + p1.children) == ['b', 'a'])
 242
 243        try:
 244            p1.children + 123
 245            assert False
 246        except TypeError:
 247            assert True
 248
 249class DefaultTest(_CollectionOperations):
 250    def __init__(self, *args, **kw):
 251        super(DefaultTest, self).__init__(*args, **kw)
 252        self.collection_class = None
 253
 254    def test_sequence_ops(self):
 255        self._test_sequence_ops()
 256
 257
 258class ListTest(_CollectionOperations):
 259    def __init__(self, *args, **kw):
 260        super(ListTest, self).__init__(*args, **kw)
 261        self.collection_class = list
 262
 263    def test_sequence_ops(self):
 264        self._test_sequence_ops()
 265
 266class CustomListTest(ListTest):
 267    def __init__(self, *args, **kw):
 268        super(CustomListTest, self).__init__(*args, **kw)
 269        self.collection_class = list
 270
 271# No-can-do until ticket #213
 272class DictTest(_CollectionOperations):
 273    pass
 274
 275class CustomDictTest(DictTest):
 276    def __init__(self, *args, **kw):
 277        super(DictTest, self).__init__(*args, **kw)
 278        self.collection_class = DictCollection
 279
 280    def test_mapping_ops(self):
 281        Parent, Child = self.Parent, self.Child
 282
 283        p1 = Parent('P1')
 284
 285        self.assert_(not p1._children)
 286        self.assert_(not p1.children)
 287
 288        ch = Child('a', 'regular')
 289        p1._children.append(ch)
 290
 291        self.assert_(ch in p1._children.values())
 292        self.assert_(len(p1._children) == 1)
 293
 294        self.assert_(p1.children)
 295        self.assert_(len(p1.children) == 1)
 296        self.assert_(ch not in p1.children)
 297        self.assert_('a' in p1.children)
 298        self.assert_(p1.children['a'] == 'regular')
 299        self.assert_(p1._children['a'] == ch)
 300
 301        p1.children['b'] = 'proxied'
 302
 303        self.assert_('proxied' in p1.children.values())
 304        self.assert_('b' in p1.children)
 305        self.assert_('proxied' not in p1._children)
 306        self.assert_(len(p1.children) == 2)
 307        self.assert_(len(p1._children) == 2)
 308
 309        self.assert_(p1._children['a'].name == 'regular')
 310        self.assert_(p1._children['b'].name == 'proxied')
 311
 312        del p1._children['b']
 313
 314        self.assert_(len(p1._children) == 1)
 315        self.assert_(len(p1.children) == 1)
 316        self.assert_(p1._children['a'] == ch)
 317
 318        del p1.children['a']
 319
 320        self.assert_(len(p1._children) == 0)
 321        self.assert_(len(p1.children) == 0)
 322
 323        p1.children = {'d': 'v d', 'e': 'v e', 'f': 'v f'}
 324        self.assert_(len(p1._children) == 3)
 325        self.assert_(len(p1.children) == 3)
 326
 327        self.assert_(set(p1.children) == set(['d','e','f']))
 328
 329        del ch
 330        p1 = self.roundtrip(p1)
 331        self.assert_(len(p1._children) == 3)
 332        self.assert_(len(p1.children) == 3)
 333
 334        p1.children['e'] = 'changed-in-place'
 335        self.assert_(p1.children['e'] == 'changed-in-place')
 336        inplace_id = p1._children['e'].id
 337        p1 = self.roundtrip(p1)
 338        self.assert_(p1.children['e'] == 'changed-in-place')
 339        self.assert_(p1._children['e'].id == inplace_id)
 340
 341        p1._children = {}
 342        self.assert_(len(p1.children) == 0)
 343
 344        try:
 345            p1._children = []
 346            self.assert_(False)
 347        except TypeError:
 348            self.assert_(True)
 349
 350        try:
 351            p1._children = None
 352            self.assert_(False)
 353        except TypeError:
 354            self.assert_(True)
 355
 356        assert_raises(TypeError, set, [p1.children])
 357
 358
 359class SetTest(_CollectionOperations):
 360    def __init__(self, *args, **kw):
 361        super(SetTest, self).__init__(*args, **kw)
 362        self.collection_class = set
 363
 364    def test_set_operations(self):
 365        Parent, Child = self.Parent, self.Child
 366
 367        p1 = Parent('P1')
 368
 369        self.assert_(not p1._children)
 370        self.assert_(not p1.children)
 371
 372        ch1 = Child('regular')
 373        p1._children.add(ch1)
 374
 375        self.assert_(ch1 in p1._children)
 376        self.assert_(len(p1._children) == 1)
 377
 378        self.assert_(p1.children)
 379        self.assert_(len(p1.children) == 1)
 380        self.assert_(ch1 not in p1.children)
 381        self.assert_('regular' in p1.children)
 382
 383        p1.children.add('proxied')
 384
 385        self.assert_('proxied' in p1.children)
 386        self.assert_('proxied' not in p1._children)
 387        self.assert_(len(p1.children) == 2)
 388        self.assert_(len(p1._children) == 2)
 389
 390        self.assert_(set([o.name for o in p1._children]) ==
 391                     set(['regular', 'proxied']))
 392
 393        ch2 = None
 394        for o in p1._children:
 395            if o.name == 'proxied':
 396                ch2 = o
 397                break
 398
 399        p1._children.remove(ch2)
 400
 401        self.assert_(len(p1._children) == 1)
 402        self.assert_(len(p1.children) == 1)
 403        self.assert_(p1._children == set([ch1]))
 404
 405        p1.children.remove('regular')
 406
 407        self.assert_(len(p1._children) == 0)
 408        self.assert_(len(p1.children) == 0)
 409
 410        p1.children = ['a','b','c']
 411        self.assert_(len(p1._children) == 3)
 412        self.assert_(len(p1.children) == 3)
 413
 414        del ch1
 415        p1 = self.roundtrip(p1)
 416
 417        self.assert_(len(p1._children) == 3)
 418        self.assert_(len(p1.children) == 3)
 419
 420        self.assert_('a' in p1.children)
 421        self.assert_('b' in p1.children)
 422        self.assert_('d' not in p1.children)
 423
 424        self.assert_(p1.children == set(['a','b','c']))
 425
 426        try:
 427            p1.children.remove('d')
 428            self.fail()
 429        except KeyError:
 430            pass
 431
 432        self.assert_(len(p1.children) == 3)
 433        p1.children.discard('d')
 434        self.assert_(len(p1.children) == 3)
 435        p1 = self.roundtrip(p1)
 436        self.assert_(len(p1.children) == 3)
 437
 438        popped = p1.children.pop()
 439        self.assert_(len(p1.children) == 2)
 440        self.assert_(popped not in p1.children)
 441        p1 = self.roundtrip(p1)
 442        self.assert_(len(p1.children) == 2)
 443        self.assert_(popped not in p1.children)
 444
 445        p1.children = ['a','b','c']
 446        p1 = self.roundtrip(p1)
 447        self.assert_(p1.children == set(['a','b','c']))
 448
 449        p1.children.discard('b')
 450        p1 = self.roundtrip(p1)
 451        self.assert_(p1.children == set(['a', 'c']))
 452
 453        p1.children.remove('a')
 454        p1 = self.roundtrip(p1)
 455        self.assert_(p1.children == set(['c']))
 456
 457        p1._children = set()
 458        self.assert_(len(p1.children) == 0)
 459
 460        try:
 461            p1._children = []
 462            self.assert_(False)
 463        except TypeError:
 464            self.assert_(True)
 465
 466        try:
 467            p1._children = None
 468            self.assert_(False)
 469        except TypeError:
 470            self.assert_(True)
 471
 472        assert_raises(TypeError, set, [p1.children])
 473
 474
 475    def test_set_comparisons(self):
 476        Parent, Child = self.Parent, self.Child
 477
 478        p1 = Parent('P1')
 479        p1.children = ['a','b','c']
 480        control = set(['a','b','c'])
 481
 482        for other in (set(['a','b','c']), set(['a','b','c','d']),
 483                      set(['a']), set(['a','b']),
 484                      set(['c','d']), set(['e', 'f', 'g']),
 485                      set()):
 486
 487            eq_(p1.children.union(other),
 488                             control.union(other))
 489            eq_(p1.children.difference(other),
 490                             control.difference(other))
 491            eq_((p1.children - other),
 492                             (control - other))
 493            eq_(p1.children.intersection(other),
 494                             control.intersection(other))
 495            eq_(p1.children.symmetric_difference(other),
 496                             control.symmetric_difference(other))
 497            eq_(p1.children.issubset(other),
 498                             control.issubset(other))
 499            eq_(p1.children.issuperset(other),
 500                             control.issuperset(other))
 501
 502            self.assert_((p1.children == other)  ==  (control == other))
 503            self.assert_((p1.children != other)  ==  (control != other))
 504            self.assert_((p1.children < other)   ==  (control < other))
 505            self.assert_((p1.children <= other)  ==  (control <= other))
 506            self.assert_((p1.children > other)   ==  (control > other))
 507            self.assert_((p1.children >= other)  ==  (control >= other))
 508
 509    def test_set_mutation(self):
 510        Parent, Child = self.Parent, self.Child
 511
 512        # mutations
 513        for op in ('update', 'intersection_update',
 514                   'difference_update', 'symmetric_difference_update'):
 515            for base in (['a', 'b', 'c'], []):
 516                for other in (set(['a','b','c']), set(['a','b','c','d']),
 517                              set(['a']), set(['a','b']),
 518                              set(['c','d']), set(['e', 'f', 'g']),
 519                              set()):
 520                    p = Parent('p')
 521                    p.children = base[:]
 522                    control = set(base[:])
 523
 524                    getattr(p.children, op)(other)
 525                    getattr(control, op)(other)
 526                    try:
 527                        self.assert_(p.children == control)
 528                    except:
 529                        print 'Test %s.%s(%s):' % (set(base), op, other)
 530                        print 'want', repr(control)
 531                        print 'got', repr(p.children)
 532                        raise
 533
 534                    p = self.roundtrip(p)
 535
 536                    try:
 537                        self.assert_(p.children == control)
 538                    except:
 539                        print 'Test %s.%s(%s):' % (base, op, other)
 540                        print 'want', repr(control)
 541                        print 'got', repr(p.children)
 542                        raise
 543
 544        # in-place mutations
 545        for op in ('|=', '-=', '&=', '^='):
 546            for base in (['a', 'b', 'c'], []):
 547                for other in (set(['a','b','c']), set(['a','b','c','d']),
 548                              set(['a']), set(['a','b']),
 549                              set(['c','d']), set(['e', 'f', 'g']),
 550                              frozenset(['e', 'f', 'g']),
 551                              set()):
 552                    p = Parent('p')
 553                    p.children = base[:]
 554                    control = set(base[:])
 555
 556                    exec "p.children %s other" % op
 557                    exec "control %s other" % op
 558
 559                    try:
 560                        self.assert_(p.children == control)
 561                    except:
 562                        print 'Test %s %s %s:' % (set(base), op, other)
 563                        print 'want', repr(control)
 564                        print 'got', repr(p.children)
 565                        raise
 566
 567                    p = self.roundtrip(p)
 568
 569                    try:
 570                        self.assert_(p.children == control)
 571                    except:
 572                        print 'Test %s %s %s:' % (base, op, other)
 573                        print 'want', repr(control)
 574                        print 'got', repr(p.children)
 575                        raise
 576
 577
 578class CustomSetTest(SetTest):
 579    def __init__(self, *args, **kw):
 580        super(CustomSetTest, self).__init__(*args, **kw)
 581        self.collection_class = SetCollection
 582
 583class CustomObjectTest(_CollectionOperations):
 584    def __init__(self, *args, **kw):
 585        super(CustomObjectTest, self).__init__(*args, **kw)
 586        self.collection_class = ObjectCollection
 587
 588    def test_basic(self):
 589        Parent, Child = self.Parent, self.Child
 590
 591        p = Parent('p1')
 592        self.assert_(len(list(p.children)) == 0)
 593
 594        p.children.append('child')
 595        self.assert_(len(list(p.children)) == 1)
 596
 597        p = self.roundtrip(p)
 598        self.assert_(len(list(p.children)) == 1)
 599
 600        # We didn't provide an alternate _AssociationList implementation
 601        # for our ObjectCollection, so indexing will fail.
 602
 603        try:
 604            v = p.children[1]
 605            self.fail()
 606        except TypeError:
 607            pass
 608
 609class ProxyFactoryTest(ListTest):
 610    def setup(self):
 611        metadata = MetaData(testing.db)
 612
 613        parents_table = Table('Parent', metadata,
 614                              Column('id', Integer, primary_key=True,
 615                                     test_needs_autoincrement=True),
 616                              Column('name', String(128)))
 617        children_table = Table('Children', metadata,
 618                               Column('id', Integer, primary_key=True,
 619                                      test_needs_autoincrement=True),
 620                               Column('parent_id', Integer,
 621                                      ForeignKey('Parent.id')),
 622                               Column('foo', String(128)),
 623                               Column('name', String(128)))
 624
 625        class CustomProxy(_AssociationList):
 626            def __init__(
 627                self,
 628                lazy_collection,
 629                creator,
 630                value_attr,
 631                parent,
 632                ):
 633                getter, setter = parent._default_getset(lazy_collection)
 634                _AssociationList.__init__(
 635                    self,
 636                    lazy_collection,
 637                    creator,
 638                    getter,
 639                    setter,
 640                    parent,
 641                    )
 642
 643        class Parent(object):
 644            children = association_proxy('_children', 'name', 
 645                        proxy_factory=CustomProxy, 
 646                        proxy_bulk_set=CustomProxy.extend
 647                    )
 648
 649            def __init__(self, name):
 650                self.name = name
 651
 652        class Child(object):
 653            def __init__(self, name):
 654                self.name = name
 655
 656        mapper(Parent, parents_table, properties={
 657            '_children': relationship(Child, lazy='joined',
 658                                  collection_class=list)})
 659        mapper(Child, children_table)
 660
 661        metadata.create_all()
 662
 663        self.metadata = metadata
 664        self.session = create_session()
 665        self.Parent, self.Child = Parent, Child
 666
 667    def test_sequence_ops(self):
 668        self._test_sequence_ops()
 669
 670
 671class ScalarTest(fixtures.TestBase):
 672    def test_scalar_proxy(self):
 673        metadata = MetaData(testing.db)
 674
 675        parents_table = Table('Parent', metadata,
 676                              Column('id', Integer, primary_key=True,
 677                                     test_needs_autoincrement=True),
 678                              Column('name', String(128)))
 679        children_table = Table('Children', metadata,
 680                               Column('id', Integer, primary_key=True,
 681                                      test_needs_autoincrement=True),
 682                               Column('parent_id', Integer,
 683                                      ForeignKey('Parent.id')),
 684                               Column('foo', String(128)),
 685                               Column('bar', String(128)),
 686                               Column('baz', String(128)))
 687
 688        class Parent(object):
 689            foo = association_proxy('child', 'foo')
 690            bar = association_proxy('child', 'bar',
 691                                    creator=lambda v: Child(bar=v))
 692            baz = association_proxy('child', 'baz',
 693                                    creator=lambda v: Child(baz=v))
 694
 695            def __init__(self, name):
 696                self.name = name
 697
 698        class Child(object):
 699            def __init__(self, **kw):
 700                for attr in kw:
 701                    setattr(self, attr, kw[attr])
 702
 703        mapper(Parent, parents_table, properties={
 704            'child': relationship(Child, lazy='joined',
 705                              backref='parent', uselist=False)})
 706        mapper(Child, children_table)
 707
 708        metadata.create_all()
 709        session = create_session()
 710
 711        def roundtrip(obj):
 712            if obj not in session:
 713                session.add(obj)
 714            session.flush()
 715            id, type_ = obj.id, type(obj)
 716            session.expunge_all()
 717            return session.query(type_).get(id)
 718
 719        p = Parent('p')
 720
 721        # No child
 722        try:
 723            v = p.foo
 724            self.fail()
 725        except:
 726            pass
 727
 728        p.child = Child(foo='a', bar='b', baz='c')
 729
 730        self.assert_(p.foo == 'a')
 731        self.assert_(p.bar == 'b')
 732        self.assert_(p.baz == 'c')
 733
 734        p.bar = 'x'
 735        self.assert_(p.foo == 'a')
 736        self.assert_(p.bar == 'x')
 737        self.assert_(p.baz == 'c')
 738
 739        p = roundtrip(p)
 740
 741        self.assert_(p.foo == 'a')
 742        self.assert_(p.bar == 'x')
 743        self.assert_(p.baz == 'c')
 744
 745        p.child = None
 746
 747        # No child again
 748        try:
 749            v = p.foo
 750            self.fail()
 751        except:
 752            pass
 753
 754        # Bogus creator for this scalar type
 755        try:
 756            p.foo = 'zzz'
 757            self.fail()
 758        except TypeError:
 759            pass
 760
 761        p.bar = 'yyy'
 762
 763        self.assert_(p.foo is None)
 764        self.assert_(p.bar == 'yyy')
 765        self.assert_(p.baz is None)
 766
 767        del p.child
 768
 769        p = roundtrip(p)
 770
 771        self.assert_(p.child is None)
 772
 773        p.baz = 'xxx'
 774
 775        self.assert_(p.foo is None)
 776        self.assert_(p.bar is None)
 777        self.assert_(p.baz == 'xxx')
 778
 779        p = roundtrip(p)
 780
 781        self.assert_(p.foo is None)
 782        self.assert_(p.bar is None)
 783        self.assert_(p.baz == 'xxx')
 784
 785        # Ensure an immediate __set__ works.
 786        p2 = Parent('p2')
 787        p2.bar = 'quux'
 788
 789
 790class LazyLoadTest(fixtures.TestBase):
 791    def setup(self):
 792        metadata = MetaData(testing.db)
 793
 794        parents_table = Table('Parent', metadata,
 795                              Column('id', Integer, primary_key=True,
 796                                     test_needs_autoincrement=True),
 797                              Column('name', String(128)))
 798        children_table = Table('Children', metadata,
 799                               Column('id', Integer, primary_key=True,
 800                                      test_needs_autoincrement=True),
 801                               Column('parent_id', Integer,
 802                                      ForeignKey('Parent.id')),
 803                               Column('foo', String(128)),
 804                               Column('name', String(128)))
 805
 806        class Parent(object):
 807            children = association_proxy('_children', 'name')
 808
 809            def __init__(self, name):
 810                self.name = name
 811
 812        class Child(object):
 813            def __init__(self, name):
 814                self.name = name
 815
 816
 817        mapper(Child, children_table)
 818        metadata.create_all()
 819
 820        self.metadata = metadata
 821        self.session = create_session()
 822        self.Parent, self.Child = Parent, Child
 823        self.table = parents_table
 824
 825    def teardown(self):
 826        self.metadata.drop_all()
 827
 828    def roundtrip(self, obj):
 829        self.session.add(obj)
 830        self.session.flush()
 831        id, type_ = obj.id, type(obj)
 832        self.session.expunge_all()
 833        return self.session.query(type_).get(id)
 834
 835    def test_lazy_list(self):
 836        Parent, Child = self.Parent, self.Child
 837
 838        mapper(Parent, self.table, properties={
 839            '_children': relationship(Child, lazy='select',
 840                                  collection_class=list)})
 841
 842        p = Parent('p')
 843        p.children = ['a','b','c']
 844
 845        p = self.roundtrip(p)
 846
 847        # Is there a better way to ensure that the association_proxy
 848        # didn't convert a lazy load to an eager load?  This does work though.
 849        self.assert_('_children' not in p.__dict__)
 850        self.assert_(len(p._children) == 3)
 851        self.assert_('_children' in p.__dict__)
 852
 853    def test_eager_list(self):
 854        Parent, Child = self.Parent, self.Child
 855
 856        mapper(Parent, self.table, properties={
 857            '_children': relationship(Child, lazy='joined',
 858                                  collection_class=list)})
 859
 860        p = Parent('p')
 861        p.children = ['a','b','c']
 862
 863        p = self.roundtrip(p)
 864
 865        self.assert_('_children' in p.__dict__)
 866        self.assert_(len(p._children) == 3)
 867
 868    def test_lazy_scalar(self):
 869        Parent, Child = self.Parent, self.Child
 870
 871        mapper(Parent, self.table, properties={
 872            '_children': relationship(Child, lazy='select', uselist=False)})
 873
 874
 875        p = Parent('p')
 876        p.children = 'value'
 877
 878        p = self.roundtrip(p)
 879
 880        self.assert_('_children' not in p.__dict__)
 881        self.assert_(p._children is not None)
 882
 883    def test_eager_scalar(self):
 884        Parent, Child = self.Parent, self.Child
 885
 886        mapper(Parent, self.table, properties={
 887            '_children': relationship(Child, lazy='joined', uselist=False)})
 888
 889
 890        p = Parent('p')
 891        p.children = 'value'
 892
 893        p = self.roundtrip(p)
 894
 895        self.assert_('_children' in p.__dict__)
 896        self.assert_(p._children is not None)
 897
 898
 899class Parent(object):
 900    def __init__(self, name):
 901        self.name = name
 902
 903class Child(object):
 904    def __init__(self, name):
 905        self.name = name
 906
 907class KVChild(object):
 908    def __init__(self, name, value):
 909        self.name = name
 910        self.value = value
 911
 912class ReconstitutionTest(fixtures.TestBase):
 913
 914    def setup(self):
 915        metadata = MetaData(testing.db)
 916        parents = Table('parents', metadata, Column('id', Integer,
 917                        primary_key=True,
 918                        test_needs_autoincrement=True), Column('name',
 919                        String(30)))
 920        children = Table('children', metadata, Column('id', Integer,
 921                         primary_key=True,
 922                         test_needs_autoincrement=True),
 923                         Column('parent_id', Integer,
 924                         ForeignKey('parents.id')), Column('name',
 925                         String(30)))
 926        metadata.create_all()
 927        parents.insert().execute(name='p1')
 928        self.metadata = metadata
 929        self.parents = parents
 930        self.children = children
 931        Parent.kids = association_proxy('children', 'name')
 932
 933    def teardown(self):
 934        self.metadata.drop_all()
 935        clear_mappers()
 936
 937    def test_weak_identity_map(self):
 938        mapper(Parent, self.parents,
 939               properties=dict(children=relationship(Child)))
 940        mapper(Child, self.children)
 941        session = create_session(weak_identity_map=True)
 942
 943        def add_child(parent_name, child_name):
 944            parent = \
 945                session.query(Parent).filter_by(name=parent_name).one()
 946            parent.kids.append(child_name)
 947
 948        add_child('p1', 'c1')
 949        gc_collect()
 950        add_child('p1', 'c2')
 951        session.flush()
 952        p = session.query(Parent).filter_by(name='p1').one()
 953        assert set(p.kids) == set(['c1', 'c2']), p.kids
 954
 955    def test_copy(self):
 956        mapper(Parent, self.parents,
 957               properties=dict(children=relationship(Child)))
 958        mapper(Child, self.children)
 959        p = Parent('p1')
 960        p.kids.extend(['c1', 'c2'])
 961        p_copy = copy.copy(p)
 962        del p
 963        gc_collect()
 964        assert set(p_copy.kids) == set(['c1', 'c2']), p.kids
 965
 966    def test_pickle_list(self):
 967        mapper(Parent, self.parents,
 968               properties=dict(children=relationship(Child)))
 969        mapper(Child, self.children)
 970        p = Parent('p1')
 971        p.kids.extend(['c1', 'c2'])
 972        r1 = pickle.loads(pickle.dumps(p))
 973        assert r1.kids == ['c1', 'c2']
 974        r2 = pickle.loads(pickle.dumps(p.kids))
 975        assert r2 == ['c1', 'c2']
 976
 977    def test_pickle_set(self):
 978        mapper(Parent, self.parents,
 979               properties=dict(children=relationship(Child,
 980               collection_class=set)))
 981        mapper(Child, self.children)
 982        p = Parent('p1')
 983        p.kids.update(['c1', 'c2'])
 984        r1 = pickle.loads(pickle.dumps(p))
 985        assert r1.kids == set(['c1', 'c2'])
 986        r2 = pickle.loads(pickle.dumps(p.kids))
 987        assert r2 == set(['c1', 'c2'])
 988
 989    def test_pickle_dict(self):
 990        mapper(Parent, self.parents,
 991               properties=dict(children=relationship(KVChild,
 992               collection_class=
 993                    collections.mapped_collection(PickleKeyFunc('name')))))
 994        mapper(KVChild, self.children)
 995        p = Parent('p1')
 996        p.kids.update({'c1': 'v1', 'c2': 'v2'})
 997        assert p.kids == {'c1': 'c1', 'c2': 'c2'}
 998        r1 = pickle.loads(pickle.dumps(p))
 999        assert r1.kids == {'c1': 'c1', 'c2': 'c2'}
1000        r2 = pickle.loads(pickle.dumps(p.kids))
1001        assert r2 == {'c1': 'c1', 'c2': 'c2'}
1002
1003class PickleKeyFunc(object):
1004    def __init__(self, name):
1005        self.name = name
1006
1007    def __call__(self, obj):
1008        return getattr(obj, self.name)
1009
1010class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
1011    __dialect__ = 'default'
1012
1013    run_inserts = 'once'
1014    run_deletes = None
1015    run_setup_mappers = 'once'
1016    run_setup_classes = 'once'
1017
1018    @classmethod
1019    def define_tables(cls, metadata):
1020        Table('userkeywords', metadata, 
1021          Column('keyword_id', Integer,ForeignKey('keywords.id'), primary_key=True),
1022          Column('user_id', Integer, ForeignKey('users.id'))
1023        )
1024        Table('users', metadata, 
1025            Column('id', Integer,
1026              primary_key=True, test_needs_autoincrement=True),
1027            Column('name', String(64)),
1028            Column('singular_id', Integer, ForeignKey('singular.id'))
1029        )
1030        Table('keywords', metadata, 
1031            Column('id', Integer,
1032              primary_key=True, test_needs_autoincrement=True),
1033            Column('keyword', String(64)),
1034            Column('singular_id', Integer, ForeignKey('singular.id'))
1035        )
1036        Table('singular', metadata,
1037            Column('id', Integer,
1038              primary_key=True, test_needs_autoincrement=True),
1039        )
1040
1041    @classmethod
1042    def setup_classes(cls):
1043        class User(cls.Comparable):
1044            def __init__(self, name):
1045                self.name = name
1046
1047            # o2m -> m2o
1048            # uselist -> nonuselist
1049            keywords = association_proxy('user_keywords', 'keyword',
1050                    creator=lambda k: UserKeyword(keyword=k))
1051
1052            # m2o -> o2m
1053            # nonuselist -> uselist
1054            singular_keywords = association_proxy('singular', 'keywords')
1055
1056        class Keyword(cls.Comparable):
1057            def __init__(self, keyword):
1058                self.keyword = keyword
1059
1060            # o2o -> m2o
1061            # nonuselist -> nonuselist
1062            user = association_proxy('user_keyword', 'user')
1063
1064        class UserKeyword(cls.Comparable):
1065            def __init__(self, user=None, keyword=None):
1066                self.user = user
1067                self.keyword = keyword
1068
1069        class Singular(cls.Comparable):
1070            def __init__(self, value=None):
1071                self.value = value
1072
1073    @classmethod
1074    def setup_mappers(cls):
1075        users, Keyword, UserKeyword, singular, \
1076            userkeywords, User, keywords, Singular = (cls.tables.users,
1077                                cls.classes.Keyword,
1078                                cls.classes.UserKeyword,
1079                                cls.tables.singular,
1080                                cls.tables.userkeywords,
1081                                cls.classes.User,
1082                                cls.tables.keywords,
1083                                cls.classes.Singular)
1084
1085        mapper(User, users, properties={
1086            'singular':relationship(Singular)
1087        })
1088        mapper(Keyword, keywords, properties={
1089            'user_keyword':relationship(UserKeyword, uselist=False)
1090        })
1091
1092        mapper(UserKeyword, userkeywords, properties={
1093            'user' : relationship(User, backref='user_keywords'), 
1094            'keyword' : relationship(Keyword)
1095        })
1096        mapper(Singular, singular, properties={
1097            'keywords': relationship(Keyword)
1098        })
1099
1100    @classmethod
1101    def insert_data(cls):
1102        UserKeyword, User, Keyword, Singular = (cls.classes.UserKeyword,
1103                                cls.classes.User,
1104                                cls.classes.Keyword,
1105                                cls.classes.Singular)
1106
1107        session = sessionmaker()()
1108        words = (
1109            'quick', 'brown',
1110            'fox', 'jumped', 'over',
1111            'the', 'lazy',
1112            )
1113        for ii in range(4):
1114            user = User('user%d' % ii)
1115            user.singular = Singular()
1116            session.add(user)
1117            for jj in words[ii:ii + 3]:
1118                k = Keyword(jj)
1119                user.keywords.append(k)
1120                user.singular.keywords.append(k)
1121        orphan = Keyword('orphan')
1122        orphan.user_keyword = UserKeyword(keyword=orphan, user=None)
1123        session.add(orphan)
1124        session.commit()
1125        cls.u = user
1126        cls.kw = user.keywords[0]
1127        cls.session = session
1128
1129    def _equivalent(self, q_proxy, q_direct):
1130        eq_(q_proxy.all(), q_direct.all())
1131
1132    def test_filter_any_kwarg_ul_nul(self):
1133        UserKeyword, User = self.classes.UserKeyword, self.classes.User
1134
1135        self._equivalent(self.session.query(User).
1136                    filter(User.keywords.any(keyword='jumped'
1137                         )),
1138                         self.session.query(User).filter(
1139                                User.user_keywords.any(
1140                            UserKeyword.keyword.has(keyword='jumped'
1141                         ))))
1142
1143    def test_filter_has_kwarg_nul_nul(self):
1144        UserKeyword, Keyword = self.classes.UserKeyword, self.classes.Keyword
1145
1146        self._equivalent(self.session.query(Keyword).
1147                    filter(Keyword.user.has(name='user2'
1148                         )),
1149                         self.session.query(Keyword).
1150                            filter(Keyword.user_keyword.has(
1151                            UserKeyword.user.has(name='user2'
1152                         ))))
1153
1154    def test_filter_has_kwarg_nul_ul(self):
1155        User, Singular = self.classes.User, self.classes.Singular
1156
1157        self._equivalent(
1158            self.session.query(User).\
1159                        filter(User.singular_keywords.any(keyword='jumped')),
1160            self.session.query(User).\
1161                        filter(
1162                            User.singular.has(
1163                                Singular.keywords.any(keyword='jumped')
1164                            )
1165                        )
1166        )
1167
1168    def test_filter_any_criterion_ul_nul(self):
1169        UserKeyword, User, Keyword = (self.classes.UserKeyword,
1170                                self.classes.User,
1171                                self.classes.Keyword)
1172
1173        self._equivalent(self.session.query(User).
1174                    filter(User.keywords.any(Keyword.keyword
1175                         == 'jumped')),
1176                         self.session.query(User).
1177                            filter(User.user_keywords.any(
1178                            UserKeyword.keyword.has(Keyword.keyword
1179                         == 'jumped'))))
1180
1181    def test_filter_has_criterion_nul_nul(self):
1182        UserKeyword, User, Keyword = (self.classes.UserKeyword,
1183                                self.classes.User,
1184                                self.classes.Keyword)
1185
1186        self._equivalent(self.session.query(Keyword).
1187                filter(Keyword.user.has(User.name
1188                         == 'user2')),
1189                         self.session.query(Keyword).
1190                            filter(Keyword.user_keyword.has(
1191                                UserKeyword.user.has(User.name
1192                         == 'user2'))))
1193
1194    def test_filter_any_criterion_nul_ul(self):
1195        User, Keyword, Singular = (self.classes.User,
1196                                self.classes.Keyword,
1197                                self.classes.Singular)
1198
1199        self._equivalent(
1200            self.session.query(User).\
1201                        filter(User.singular_keywords.any(Keyword.keyword=='jumped')),
1202            self.session.query(User).\
1203                        filter(
1204                            User.singular.has(
1205                                Singular.keywords.any(Keyword.keyword=='jumped')
1206                            )
1207                        )
1208        )
1209
1210    def test_filter_contains_ul_nul(self):
1211        User = self.classes.User
1212
1213        self._equivalent(self.session.query(User).
1214        filter(User.keywords.contains(self.kw)),
1215                         self.session.query(User).
1216                         filter(User.user_keywords.any(keyword=self.kw)))
1217
1218    def test_filter_contains_nul_ul(self):
1219        User, Singular = self.classes.User, self.classes.Singular
1220
1221        self._equivalent(
1222            self.session.query(User).filter(
1223                            User.singular_keywords.contains(self.kw)
1224            ),
1225            self.session.query(User).filter(
1226                            User.singular.has(
1227                                Singular.keywords.contains(self.kw)
1228                            )
1229            ),
1230        )
1231
1232    def test_filter_eq_nul_nul(self):
1233        Keyword = self.classes.Keyword
1234
1235        self._equivalent(self.session.query(Keyword).filter(Keyword.user
1236                         == self.u),
1237                         self.session.query(Keyword).
1238                         filter(Keyword.user_keyword.has(user=self.u)))
1239
1240    def test_filter_ne_nul_nul(self):
1241        Keyword = self.classes.Keyword
1242
1243        self._equivalent(self.session.query(Keyword).filter(Keyword.user
1244                         != self.u),
1245                         self.session.query(Keyword).
1246                         filter(not_(Keyword.user_keyword.has(user=self.u))))
1247
1248    def test_filter_eq_null_nul_nul(self):
1249        UserKeyword, Keyword = self.classes.UserKeyword, self.classes.Keyword
1250
1251        self._equivalent(self.session.query(Keyword).filter(Keyword.user
1252                         == None),
1253                         self.session.query(Keyword).
1254                            filter(Keyword.user_keyword.has(UserKeyword.user
1255                         == None)))
1256
1257    def test_filter_scalar_contains_fails_nul_nul(self):
1258        Keyword = self.classes.Keyword
1259
1260        assert_raises(exceptions.InvalidRequestError, lambda : \
1261                      Keyword.user.contains(self.u))
1262
1263    def test_filter_scalar_any_fails_nul_nul(self):
1264        Keyword = self.classes.Keyword
1265
1266        assert_raises(exceptions.InvalidRequestError, lambda : \
1267                      Keyword.user.any(name='user2'))
1268
1269    def test_filter_collection_has_fails_ul_nul(self):
1270        User = self.classes.User
1271
1272        assert_raises(exceptions.InvalidRequestError, lambda : \
1273                      User.keywords.has(keyword='quick'))
1274
1275    def test_filter_collection_eq_fails_ul_nul(self):
1276        User = self.classes.User
1277
1278        assert_raises(exceptions.InvalidRequestError, lambda : \
1279                      User.keywords == self.kw)
1280
1281    def test_filter_collection_ne_fails_ul_nul(self):
1282        User = self.classes.User
1283
1284        assert_raises(exceptions.InvalidRequestError, lambda : \
1285                      User.keywords != self.kw)
1286
1287    def test_join_separate_attr(self):
1288        User = self.classes.User
1289        self.assert_compile(
1290            self.session.query(User).join(
1291                        User.keywords.local_attr, 
1292                        User.keywords.remote_attr),
1293            "SELECT users.id AS users_id, users.name AS users_name, "
1294            "users.singular_id AS users_singular_id "
1295            "FROM users JOIN userkeywords ON users.id = "
1296            "userkeywords.user_id JOIN keywords ON keywords.id = "
1297            "userkeywords.keyword_id"
1298        )
1299
1300    def test_join_single_attr(self):
1301        User = self.classes.User
1302        self.assert_compile(
1303            self.session.query(User).join(
1304                        *User.keywords.attr),
1305            "SELECT users.id AS users_id, users.name AS users_name, "
1306            "users.singular_id AS users_singular_id "
1307            "FROM users JOIN userkeywords ON users.id = "
1308            "userkeywords.user_id JOIN keywords ON keywords.id = "
1309            "userkeywords.keyword_id"
1310        )
1311
1312class DictOfTupleUpdateTest(fixtures.TestBase):
1313    def setup(self):
1314        class B(object):
1315            def __init__(self, key, elem):
1316                self.key = key
1317                self.elem = elem
1318
1319        class A(object):
1320            elements = association_proxy("orig", "elem", creator=B)
1321
1322        m = MetaData()
1323        a = Table('a', m, Column('id', Integer, primary_key=True))
1324        b = Table('b', m, Column('id', Integer, primary_key=True), 
1325                    Column('aid', Integer, ForeignKey('a.id')))
1326        mapper(A, a, properties={
1327            'orig':relationship(B, collection_class=attribute_mapped_collection('key'))
1328        })
1329        mapper(B, b)
1330        self.A = A
1331        self.B = B
1332
1333    def test_update_one_elem_dict(self):
1334        a1 = self.A()
1335        a1.elements.update({("B", 3): 'elem2'})
1336        eq_(a1.elements, {("B",3):'elem2'})
1337
1338    def test_update_multi_elem_dict(self):
1339        a1 = self.A()
1340        a1.elements.update({("B", 3): 'elem2', ("C", 4): "elem3"})
1341        eq_(a1.elements, {("B",3):'elem2', ("C", 4): "elem3"})
1342
1343    def test_update_one_elem_list(self):
1344        a1 = self.A()
1345        a1.elements.update([(("B", 3), 'elem2')])
1346        eq_(a1.elements, {("B",3):'elem2'})
1347
1348    def test_update_multi_elem_list(self):
1349        a1 = self.A()
1350        a1.elements.update([(("B", 3), 'elem2'), (("C", 4), "elem3")])
1351        eq_(a1.elements, {("B",3):'elem2', ("C", 4): "elem3"})
1352
1353    def test_update_one_elem_varg(self):
1354        a1 = self.A()
1355        assert_raises_message(
1356            ValueError,
1357            "dictionary update sequence requires "
1358            "2-element tuples",
1359            a1.elements.update, (("B", 3), 'elem2')
1360        )
1361
1362    def test_update_multi_elem_varg(self):
1363        a1 = self.A()
1364        assert_raises_message(
1365            TypeError,
1366            "update expected at most 1 arguments, got 2",
1367            a1.elements.update,
1368            (("B", 3), 'elem2'), (("C", 4), "elem3")
1369        )