PageRenderTime 58ms CodeModel.GetById 7ms app.highlight 44ms RepoModel.GetById 1ms app.codeStats 0ms

/Lib/test/test_iter.py

http://unladen-swallow.googlecode.com/
Python | 886 lines | 775 code | 72 blank | 39 comment | 61 complexity | ab9d4cf7288196b2759758c34c729b36 MD5 | raw file
  1# Test iterators.
  2
  3import unittest
  4from test.test_support import run_unittest, TESTFN, unlink, have_unicode
  5
# Expected result of a triple loop over range(3) (too big to inline): every
# (i, j, k) triple in the order the nested loops produce them.  Kept as a
# literal and compared against by the nested-loop and nested-comprehension
# tests below.
TRIPLETS = [(0, 0, 0), (0, 0, 1), (0, 0, 2),
            (0, 1, 0), (0, 1, 1), (0, 1, 2),
            (0, 2, 0), (0, 2, 1), (0, 2, 2),

            (1, 0, 0), (1, 0, 1), (1, 0, 2),
            (1, 1, 0), (1, 1, 1), (1, 1, 2),
            (1, 2, 0), (1, 2, 1), (1, 2, 2),

            (2, 0, 0), (2, 0, 1), (2, 0, 2),
            (2, 1, 0), (2, 1, 1), (2, 1, 2),
            (2, 2, 0), (2, 2, 1), (2, 2, 2)]
 18
 19# Helper classes
 20
 21class BasicIterClass:
 22    def __init__(self, n):
 23        self.n = n
 24        self.i = 0
 25    def next(self):
 26        res = self.i
 27        if res >= self.n:
 28            raise StopIteration
 29        self.i = res + 1
 30        return res
 31
 32class IteratingSequenceClass:
 33    def __init__(self, n):
 34        self.n = n
 35    def __iter__(self):
 36        return BasicIterClass(self.n)
 37
 38class SequenceClass:
 39    def __init__(self, n):
 40        self.n = n
 41    def __getitem__(self, i):
 42        if 0 <= i < self.n:
 43            return i
 44        else:
 45            raise IndexError
 46
 47# Main test suite
 48
 49class TestCase(unittest.TestCase):
 50
 51    # Helper to check that an iterator returns a given sequence
 52    def check_iterator(self, it, seq):
 53        res = []
 54        while 1:
 55            try:
 56                val = it.next()
 57            except StopIteration:
 58                break
 59            res.append(val)
 60        self.assertEqual(res, seq)
 61
 62    # Helper to check that a for loop generates a given sequence
 63    def check_for_loop(self, expr, seq):
 64        res = []
 65        for val in expr:
 66            res.append(val)
 67        self.assertEqual(res, seq)
 68
 69    # Test basic use of iter() function
    def test_iter_basic(self):
        # check_iterator drives the iterator through explicit next() calls.
        self.check_iterator(iter(range(10)), range(10))
 72
 73    # Test that iter(iter(x)) is the same as iter(x)
    def test_iter_idempotency(self):
        seq = range(10)
        it = iter(seq)
        # Calling iter() on an iterator must hand back that very object,
        # hence the identity (not equality) check.
        it2 = iter(it)
        self.assert_(it is it2)
 79
 80    # Test that for loops over iterators work
    def test_iter_for_loop(self):
        # check_for_loop consumes the iterator with a for statement.
        self.check_for_loop(iter(range(10)), range(10))
 83
 84    # Test several independent iterators over the same list
    def test_iter_independence(self):
        seq = range(3)
        res = []
        # Iterators over the same list are live at three nesting levels at
        # once; the expected TRIPLETS only come out if each keeps its own
        # position.
        for i in iter(seq):
            for j in iter(seq):
                for k in iter(seq):
                    res.append((i, j, k))
        self.assertEqual(res, TRIPLETS)
 93
 94    # Test triple list comprehension using iterators
    def test_nested_comprehensions_iter(self):
        seq = range(3)
        # Same triple nesting as test_iter_independence, expressed as one
        # list comprehension with an explicit iter() per clause.
        res = [(i, j, k)
               for i in iter(seq) for j in iter(seq) for k in iter(seq)]
        self.assertEqual(res, TRIPLETS)
100
101    # Test triple list comprehension without iterators
    def test_nested_comprehensions_for(self):
        seq = range(3)
        # Each for clause iterates the list directly (no explicit iter()).
        res = [(i, j, k) for i in seq for j in seq for k in seq]
        self.assertEqual(res, TRIPLETS)
106
107    # Test a class with __iter__ in a for loop
    def test_iter_class_for(self):
        # IteratingSequenceClass.__iter__ returns a BasicIterClass, which
        # has next() but no __iter__ of its own.
        self.check_for_loop(IteratingSequenceClass(10), range(10))
110
111    # Test a class with __iter__ with explicit iter()
    def test_iter_class_iter(self):
        # iter() goes through __iter__; check_iterator then calls next()
        # on the returned object directly.
        self.check_iterator(iter(IteratingSequenceClass(10)), range(10))
114
115    # Test for loop on a sequence class without __iter__
    def test_seq_class_for(self):
        # SequenceClass only has __getitem__; its IndexError ends the loop.
        self.check_for_loop(SequenceClass(10), range(10))
118
119    # Test iter() on a sequence class without __iter__
    def test_seq_class_iter(self):
        # iter() must build an iterator from __getitem__ alone.
        self.check_iterator(iter(SequenceClass(10)), range(10))
122
123    # Test two-argument iter() with callable instance
    def test_iter_callable(self):
        # Instances of C return 0, 1, 2, ... on successive calls.
        class C:
            def __init__(self):
                self.i = 0
            def __call__(self):
                i = self.i
                self.i = i + 1
                if i > 100:
                    raise IndexError # Emergency stop
                return i
        # iter(callable, sentinel): iteration ends when the call returns 10,
        # and the sentinel itself is not part of the expected output.
        self.check_iterator(iter(C(), 10), range(10))
135
136    # Test two-argument iter() with function
    def test_iter_function(self):
        # The mutable default argument serves as the counter that persists
        # between calls to spam().
        def spam(state=[0]):
            i = state[0]
            state[0] = i+1
            return i
        self.check_iterator(iter(spam, 10), range(10))
143
144    # Test two-argument iter() with function that raises StopIteration
    def test_iter_function_stop(self):
        def spam(state=[0]):
            i = state[0]
            # Raised before the sentinel (20) is ever returned, so this is
            # what ends the iteration.
            if i == 10:
                raise StopIteration
            state[0] = i+1
            return i
        self.check_iterator(iter(spam, 20), range(10))
153
154    # Test exception propagation through function iterator
    def test_exception_function(self):
        # spam() returns 0..9 and raises RuntimeError on the eleventh call,
        # well before the sentinel value 20.
        def spam(state=[0]):
            i = state[0]
            state[0] = i+1
            if i == 10:
                raise RuntimeError
            return i
        res = []
        try:
            for x in iter(spam, 20):
                res.append(x)
        except RuntimeError:
            # The error must escape the loop, after the first ten values
            # were delivered.
            self.assertEqual(res, range(10))
        else:
            self.fail("should have raised RuntimeError")
170
171    # Test exception propagation through sequence iterator
    def test_exception_sequence(self):
        # Like SequenceClass, except that index 10 raises RuntimeError.
        class MySequenceClass(SequenceClass):
            def __getitem__(self, i):
                if i == 10:
                    raise RuntimeError
                return SequenceClass.__getitem__(self, i)
        res = []
        try:
            for x in MySequenceClass(20):
                res.append(x)
        except RuntimeError:
            # Only IndexError/StopIteration may end the loop quietly; the
            # RuntimeError must propagate after items 0..9.
            self.assertEqual(res, range(10))
        else:
            self.fail("should have raised RuntimeError")
186
187    # Test for StopIteration from __getitem__
    def test_stop_sequence(self):
        # __getitem__ raises StopIteration (rather than IndexError) at index
        # 10; the loop is expected to end cleanly there.
        class MySequenceClass(SequenceClass):
            def __getitem__(self, i):
                if i == 10:
                    raise StopIteration
                return SequenceClass.__getitem__(self, i)
        self.check_for_loop(MySequenceClass(20), range(10))
195
196    # Test a big range
    def test_iter_big_range(self):
        # Same as the basic for-loop test, with 10000 elements.
        self.check_for_loop(iter(range(10000)), range(10000))
199
200    # Test an empty list
    def test_iter_empty(self):
        # The loop body must never run.
        self.check_for_loop(iter([]), [])
203
204    # Test a tuple
    def test_iter_tuple(self):
        # Tuple iterator compared against the equivalent list.
        self.check_for_loop(iter((0,1,2,3,4,5,6,7,8,9)), range(10))
207
208    # Test an xrange
    def test_iter_xrange(self):
        # xrange iterator compared against the equivalent list.
        self.check_for_loop(iter(xrange(10)), range(10))
211
212    # Test a string
    def test_iter_string(self):
        # Iterating a string yields one-character strings.
        self.check_for_loop(iter("abcde"), ["a", "b", "c", "d", "e"])
215
216    # Test a Unicode string
    # Only defined when test_support reports Unicode support.
    if have_unicode:
        def test_iter_unicode(self):
            # Iterating a unicode string yields one-character unicode strings.
            self.check_for_loop(iter(unicode("abcde")),
                                [unicode("a"), unicode("b"), unicode("c"),
                                 unicode("d"), unicode("e")])
222
    # Test a dictionary
224    def test_iter_dict(self):
225        dict = {}
226        for i in range(10):
227            dict[i] = None
228        self.check_for_loop(dict, dict.keys())
229
230    # Test a file
    def test_iter_file(self):
        # Write five numbered lines to the scratch file.
        f = open(TESTFN, "w")
        try:
            for i in range(5):
                f.write("%d\n" % i)
        finally:
            f.close()
        f = open(TESTFN, "r")
        try:
            self.check_for_loop(f, ["0\n", "1\n", "2\n", "3\n", "4\n"])
            # No seek in between: the file is at EOF, so a second loop over
            # the same object must yield nothing.
            self.check_for_loop(f, [])
        finally:
            f.close()
            # Best-effort removal of the scratch file.
            try:
                unlink(TESTFN)
            except OSError:
                pass
248
249    # Test list()'s use of iterators.
    def test_builtin_list(self):
        # Objects that only support __getitem__, plus plain containers.
        self.assertEqual(list(SequenceClass(5)), range(5))
        self.assertEqual(list(SequenceClass(0)), [])
        self.assertEqual(list(()), [])
        self.assertEqual(list(range(10, -1, -1)), range(10, -1, -1))

        # list(dict) iterates the keys.
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(list(d), d.keys())

        # Non-iterable arguments: the list type itself and an int.
        self.assertRaises(TypeError, list, list)
        self.assertRaises(TypeError, list, 42)

        f = open(TESTFN, "w")
        try:
            for i in range(5):
                f.write("%d\n" % i)
        finally:
            f.close()
        f = open(TESTFN, "r")
        try:
            self.assertEqual(list(f), ["0\n", "1\n", "2\n", "3\n", "4\n"])
            # Rewind so the second list() sees the same lines again.
            f.seek(0, 0)
            self.assertEqual(list(f),
                             ["0\n", "1\n", "2\n", "3\n", "4\n"])
        finally:
            f.close()
            # Best-effort removal of the scratch file.
            try:
                unlink(TESTFN)
            except OSError:
                pass
280
    # Test tuple()'s use of iterators.
    def test_builtin_tuple(self):
        # Objects that only support __getitem__, plus plain containers.
        self.assertEqual(tuple(SequenceClass(5)), (0, 1, 2, 3, 4))
        self.assertEqual(tuple(SequenceClass(0)), ())
        self.assertEqual(tuple([]), ())
        self.assertEqual(tuple(()), ())
        self.assertEqual(tuple("abc"), ("a", "b", "c"))

        # tuple(dict) iterates the keys.
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(tuple(d), tuple(d.keys()))

        # Non-iterable arguments: the list type and an int.
        self.assertRaises(TypeError, tuple, list)
        self.assertRaises(TypeError, tuple, 42)

        f = open(TESTFN, "w")
        try:
            for i in range(5):
                f.write("%d\n" % i)
        finally:
            f.close()
        f = open(TESTFN, "r")
        try:
            self.assertEqual(tuple(f), ("0\n", "1\n", "2\n", "3\n", "4\n"))
            # Rewind so the second tuple() sees the same lines again.
            f.seek(0, 0)
            self.assertEqual(tuple(f),
                             ("0\n", "1\n", "2\n", "3\n", "4\n"))
        finally:
            f.close()
            # Best-effort removal of the scratch file.
            try:
                unlink(TESTFN)
            except OSError:
                pass
313
314    # Test filter()'s use of iterators.
    def test_builtin_filter(self):
        # filter(None, ...) drops false items, so the 0 from SequenceClass(5)
        # is removed.  Tuple and string inputs are expected to give back a
        # tuple and a string respectively.
        self.assertEqual(filter(None, SequenceClass(5)), range(1, 5))
        self.assertEqual(filter(None, SequenceClass(0)), [])
        self.assertEqual(filter(None, ()), ())
        self.assertEqual(filter(None, "abc"), "abc")

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(filter(None, d), d.keys())

        # Non-iterable arguments: the list type and an int.
        self.assertRaises(TypeError, filter, None, list)
        self.assertRaises(TypeError, filter, None, 42)

        # Object whose truth value is whatever was passed to the constructor.
        class Boolean:
            def __init__(self, truth):
                self.truth = truth
            def __nonzero__(self):
                return self.truth
        bTrue = Boolean(1)
        bFalse = Boolean(0)

        # Iterable whose __iter__ returns a separate iterator object; that
        # iterator returns itself from its own __iter__.
        class Seq:
            def __init__(self, *args):
                self.vals = args
            def __iter__(self):
                class SeqIter:
                    def __init__(self, vals):
                        self.vals = vals
                        self.i = 0
                    def __iter__(self):
                        return self
                    def next(self):
                        i = self.i
                        self.i = i + 1
                        if i < len(self.vals):
                            return self.vals[i]
                        else:
                            raise StopIteration
                return SeqIter(self.vals)

        # 25 alternating true/false pairs; only the false ones survive the
        # "not x" predicate, whether filter gets the iterable or an iterator.
        seq = Seq(*([bTrue, bFalse] * 25))
        self.assertEqual(filter(lambda x: not x, seq), [bFalse]*25)
        self.assertEqual(filter(lambda x: not x, iter(seq)), [bFalse]*25)
357
358    # Test max() and min()'s use of iterators.
    def test_builtin_max_min(self):
        # Single-iterable form, followed by the multi-argument form.
        self.assertEqual(max(SequenceClass(5)), 4)
        self.assertEqual(min(SequenceClass(5)), 0)
        self.assertEqual(max(8, -1), 8)
        self.assertEqual(min(8, -1), -1)

        # Iterating the dict compares the string keys; itervalues()
        # compares the ints.
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(max(d), "two")
        self.assertEqual(min(d), "one")
        self.assertEqual(max(d.itervalues()), 3)
        self.assertEqual(min(iter(d.itervalues())), 1)

        f = open(TESTFN, "w")
        try:
            f.write("medium line\n")
            f.write("xtra large line\n")
            f.write("itty-bitty line\n")
        finally:
            f.close()
        f = open(TESTFN, "r")
        try:
            # Lines compare as strings: "i..." sorts first, "x..." last.
            self.assertEqual(min(f), "itty-bitty line\n")
            # min() consumed the file; rewind before max().
            f.seek(0, 0)
            self.assertEqual(max(f), "xtra large line\n")
        finally:
            f.close()
            # Best-effort removal of the scratch file.
            try:
                unlink(TESTFN)
            except OSError:
                pass
389
390    # Test map()'s use of iterators.
    def test_builtin_map(self):
        # map(None, x) is expected to return the items of x as a list.
        self.assertEqual(map(None, SequenceClass(5)), range(5))
        self.assertEqual(map(lambda x: x+1, SequenceClass(5)), range(1, 6))

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(map(None, d), d.keys())
        self.assertEqual(map(lambda k, d=d: (k, d[k]), d), d.items())
        dkeys = d.keys()
        # Three iterables of lengths 3, 5 and 3 are mapped together.  The
        # and/or expressions select the i-th key while i < len(d) and None
        # afterwards, i.e. the shorter inputs are expected to be padded
        # with None.
        expected = [(i < len(d) and dkeys[i] or None,
                     i,
                     i < len(d) and dkeys[i] or None)
                    for i in range(5)]
        self.assertEqual(map(None, d,
                                   SequenceClass(5),
                                   iter(d.iterkeys())),
                         expected)

        f = open(TESTFN, "w")
        try:
            for i in range(10):
                f.write("xy" * i + "\n") # line i has len 2*i+1
        finally:
            f.close()
        f = open(TESTFN, "r")
        try:
            # Line lengths are therefore 1, 3, ..., 19.
            self.assertEqual(map(len, f), range(1, 21, 2))
        finally:
            f.close()
            # Best-effort removal of the scratch file.
            try:
                unlink(TESTFN)
            except OSError:
                pass
423
424    # Test zip()'s use of iterators.
    def test_builtin_zip(self):
        # No arguments at all, passed directly and via an empty star-list.
        self.assertEqual(zip(), [])
        self.assertEqual(zip(*[]), [])
        self.assertEqual(zip(*[(1, 2), 'ab']), [(1, 'a'), (2, 'b')])

        # Any non-iterable argument is a TypeError.
        self.assertRaises(TypeError, zip, None)
        self.assertRaises(TypeError, zip, range(10), 42)
        self.assertRaises(TypeError, zip, range(10), zip)

        # Single argument: an __iter__-based and a __getitem__-based object.
        self.assertEqual(zip(IteratingSequenceClass(3)),
                         [(0,), (1,), (2,)])
        self.assertEqual(zip(SequenceClass(3)),
                         [(0,), (1,), (2,)])

        # Keys and values iterated side by side must pair up as items().
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(d.items(), zip(d, d.itervalues()))

        # Generate all ints starting at constructor arg.
        class IntsFrom:
            def __init__(self, start):
                self.i = start

            def __iter__(self):
                return self

            def next(self):
                i = self.i
                self.i = i+1
                return i

        f = open(TESTFN, "w")
        try:
            f.write("a\n" "bbb\n" "cc\n")
        finally:
            f.close()
        f = open(TESTFN, "r")
        try:
            # IntsFrom never raises StopIteration, so the three-line file
            # is the only input that can end the zip.
            self.assertEqual(zip(IntsFrom(0), f, IntsFrom(-100)),
                             [(0, "a\n", -100),
                              (1, "bbb\n", -99),
                              (2, "cc\n", -98)])
        finally:
            f.close()
            # Best-effort removal of the scratch file.
            try:
                unlink(TESTFN)
            except OSError:
                pass

        self.assertEqual(zip(xrange(5)), [(i,) for i in range(5)])

        # Classes that lie about their lengths.
        class NoGuessLen5:
            def __getitem__(self, i):
                if i >= 5:
                    raise IndexError
                return i

        class Guess3Len5(NoGuessLen5):
            def __len__(self):
                return 3

        class Guess30Len5(NoGuessLen5):
            def __len__(self):
                return 30

        # All three really hold five items; zip must produce five tuples
        # whatever __len__ claims.
        self.assertEqual(len(Guess3Len5()), 3)
        self.assertEqual(len(Guess30Len5()), 30)
        self.assertEqual(zip(NoGuessLen5()), zip(range(5)))
        self.assertEqual(zip(Guess3Len5()), zip(range(5)))
        self.assertEqual(zip(Guess30Len5()), zip(range(5)))

        # Every pairing of the three classes, in both argument positions.
        expected = [(i, i) for i in range(5)]
        for x in NoGuessLen5(), Guess3Len5(), Guess30Len5():
            for y in NoGuessLen5(), Guess3Len5(), Guess30Len5():
                self.assertEqual(zip(x, y), expected)
500
    # Test reduce()'s use of iterators.
    def test_builtin_reduce(self):
        from operator import add
        # 0+1+2+3+4 == 10, without and with an initial value of 42.
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        # An empty input with no initial value is an error; with an initial
        # value, that value is the result.
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        # Single-element input: the lone element is 0.
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        # Reducing a dict concatenates its string keys in iteration order.
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d.keys()))
513
514    # This test case will be removed if we don't have Unicode
    def test_unicode_join_endcase(self):

        # This class inserts a Unicode object into its argument's natural
        # iteration, in the 3rd position.
        class OhPhooey:
            def __init__(self, seq):
                self.it = iter(seq)
                self.i = 0

            def __iter__(self):
                return self

            def next(self):
                i = self.i
                self.i = i+1
                if i == 2:
                    return unicode("fooled you!")
                return self.it.next()

        f = open(TESTFN, "w")
        try:
            f.write("a\n" + "b\n" + "c\n")
        finally:
            f.close()

        f = open(TESTFN, "r")
        # Nasty:  string.join(s) can't know whether unicode.join() is needed
        # until it's seen all of s's elements.  But in this case, f's
        # iterator cannot be restarted.  So what we're testing here is
        # whether string.join() can manage to remember everything it's seen
        # and pass that on to unicode.join().
        try:
            got = " - ".join(OhPhooey(f))
            # The separator is a byte string, but the expected result is
            # unicode because of the injected unicode element.
            self.assertEqual(got, unicode("a\n - b\n - fooled you! - c\n"))
        finally:
            f.close()
            # Best-effort removal of the scratch file.
            try:
                unlink(TESTFN)
            except OSError:
                pass
    # Without Unicode support, rebind the test defined just above to a no-op
    # in the class body.
    if not have_unicode:
        def test_unicode_join_endcase(self): pass
557
558    # Test iterators with 'x in y' and 'x not in y'.
    def test_in_and_not_in(self):
        # Neither helper class defines __contains__, so membership has to be
        # decided by iterating.
        for sc5 in IteratingSequenceClass(5), SequenceClass(5):
            for i in range(5):
                self.assert_(i in sc5)
            # A mix of types, none of which equals any of 0..4 (the last
            # candidate is the container itself).
            for i in "abc", -1, 5, 42.42, (3, 4), [], {1: 1}, 3-12j, sc5:
                self.assert_(i not in sc5)

        # The right-hand operand is not iterable: an int, a builtin function.
        self.assertRaises(TypeError, lambda: 3 in 12)
        self.assertRaises(TypeError, lambda: 3 not in map)

        # Keys and values are disjoint in this dict, so a key is never
        # found among the values and vice versa.
        d = {"one": 1, "two": 2, "three": 3, 1j: 2j}
        for k in d:
            self.assert_(k in d)
            self.assert_(k not in d.itervalues())
        for v in d.values():
            self.assert_(v in d.itervalues())
            self.assert_(v not in d)
        for k, v in d.iteritems():
            self.assert_((k, v) in d.iteritems())
            self.assert_((v, k) not in d.iteritems())

        f = open(TESTFN, "w")
        try:
            f.write("a\n" "b\n" "c\n")
        finally:
            f.close()
        f = open(TESTFN, "r")
        try:
            for chunk in "abc":
                # Rewind before each membership test.  Lines include their
                # newline, so the bare letter is not found but letter +
                # "\n" is.
                f.seek(0, 0)
                self.assert_(chunk not in f)
                f.seek(0, 0)
                self.assert_((chunk + "\n") in f)
        finally:
            f.close()
            # Best-effort removal of the scratch file.
            try:
                unlink(TESTFN)
            except OSError:
                pass
598
599    # Test iterators with operator.countOf (PySequence_Count).
    def test_countOf(self):
        from operator import countOf
        # Built-in sequences: list, tuple and string.
        self.assertEqual(countOf([1,2,2,3,2,5], 2), 3)
        self.assertEqual(countOf((1,2,2,3,2,5), 2), 3)
        self.assertEqual(countOf("122325", "2"), 3)
        self.assertEqual(countOf("122325", "6"), 0)

        # The first argument is not iterable.
        self.assertRaises(TypeError, countOf, 42, 1)
        self.assertRaises(TypeError, countOf, countOf, countOf)

        # Counting over the dict itself counts keys (each occurs once);
        # counting over itervalues() counts values.
        d = {"one": 3, "two": 3, "three": 3, 1j: 2j}
        for k in d:
            self.assertEqual(countOf(d, k), 1)
        self.assertEqual(countOf(d.itervalues(), 3), 3)
        self.assertEqual(countOf(d.itervalues(), 2j), 1)
        self.assertEqual(countOf(d.itervalues(), 1j), 0)

        f = open(TESTFN, "w")
        try:
            f.write("a\n" "b\n" "c\n" "b\n")
        finally:
            f.close()
        f = open(TESTFN, "r")
        try:
            for letter, count in ("a", 1), ("b", 2), ("c", 1), ("d", 0):
                # Rewind so each count scans the whole file.
                f.seek(0, 0)
                self.assertEqual(countOf(f, letter + "\n"), count)
        finally:
            f.close()
            # Best-effort removal of the scratch file.
            try:
                unlink(TESTFN)
            except OSError:
                pass
633
634    # Test iterators with operator.indexOf (PySequence_Index).
    def test_indexOf(self):
        from operator import indexOf
        # indexOf reports the position of the first match and raises
        # ValueError when there is none.
        self.assertEqual(indexOf([1,2,2,3,2,5], 1), 0)
        self.assertEqual(indexOf((1,2,2,3,2,5), 2), 1)
        self.assertEqual(indexOf((1,2,2,3,2,5), 3), 3)
        self.assertEqual(indexOf((1,2,2,3,2,5), 5), 5)
        self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 0)
        self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 6)

        self.assertEqual(indexOf("122325", "2"), 1)
        self.assertEqual(indexOf("122325", "5"), 5)
        self.assertRaises(ValueError, indexOf, "122325", "6")

        # The first argument is not iterable.
        self.assertRaises(TypeError, indexOf, 42, 1)
        self.assertRaises(TypeError, indexOf, indexOf, indexOf)

        f = open(TESTFN, "w")
        try:
            f.write("a\n" "b\n" "c\n" "d\n" "e\n")
        finally:
            f.close()
        f = open(TESTFN, "r")
        try:
            # One iterator is shared by all four calls and never rewound, so
            # each search resumes after the previous match and indices are
            # relative to that point: "b" is 1 of a,b; "d" is 1 of c,d;
            # "e" is 0 of e; after that nothing is left to find "a" in.
            fiter = iter(f)
            self.assertEqual(indexOf(fiter, "b\n"), 1)
            self.assertEqual(indexOf(fiter, "d\n"), 1)
            self.assertEqual(indexOf(fiter, "e\n"), 0)
            self.assertRaises(ValueError, indexOf, fiter, "a\n")
        finally:
            f.close()
            # Best-effort removal of the scratch file.
            try:
                unlink(TESTFN)
            except OSError:
                pass

        # IteratingSequenceClass hands out a fresh iterator per call, so
        # every search starts from the beginning.
        iclass = IteratingSequenceClass(3)
        for i in range(3):
            self.assertEqual(indexOf(iclass, i), i)
        self.assertRaises(ValueError, indexOf, iclass, -1)
674
675    # Test iterators with file.writelines().
    def test_writelines(self):
        f = file(TESTFN, "w")

        try:
            # Non-iterable arguments are rejected.
            self.assertRaises(TypeError, f.writelines, None)
            self.assertRaises(TypeError, f.writelines, 42)

            # A list, a tuple, a one-key dict (iterated by key) and an empty
            # dict: together these write the lines "1" through "5".
            f.writelines(["1\n", "2\n"])
            f.writelines(("3\n", "4\n"))
            f.writelines({'5\n': None})
            f.writelines({})

            # Try a big chunk too.
            class Iterator:
                def __init__(self, start, finish):
                    self.start = start
                    self.finish = finish
                    self.i = self.start

                def next(self):
                    if self.i >= self.finish:
                        raise StopIteration
                    result = str(self.i) + '\n'
                    self.i += 1
                    return result

                def __iter__(self):
                    return self

            class Whatever:
                def __init__(self, start, finish):
                    self.start = start
                    self.finish = finish

                def __iter__(self):
                    return Iterator(self.start, self.finish)

            # Appends the lines "6" through "2005".
            f.writelines(Whatever(6, 6+2000))
            f.close()

            # Read everything back: the file must hold 1..2005, one per line.
            f = file(TESTFN)
            expected = [str(i) + "\n" for i in range(1, 2006)]
            self.assertEqual(list(f), expected)

        finally:
            f.close()
            # Best-effort removal of the scratch file.
            try:
                unlink(TESTFN)
            except OSError:
                pass
726
727
728    # Test iterators on RHS of unpacking assignments.
    def test_unpack_iter(self):
        # Baseline: plain tuple unpacking.
        a, b = 1, 2
        self.assertEqual((a, b), (1, 2))

        # Unpacking from an object that only offers __iter__.
        a, b, c = IteratingSequenceClass(3)
        self.assertEqual((a, b, c), (0, 1, 2))

        try:    # too many values
            a, b = IteratingSequenceClass(3)
        except ValueError:
            pass
        else:
            self.fail("should have raised ValueError")

        try:    # not enough values
            a, b, c = IteratingSequenceClass(2)
        except ValueError:
            pass
        else:
            self.fail("should have raised ValueError")

        try:    # not iterable
            a, b, c = len
        except TypeError:
            pass
        else:
            self.fail("should have raised TypeError")

        # All values are equal, so the order itervalues() uses is irrelevant.
        a, b, c = {1: 42, 2: 42, 3: 42}.itervalues()
        self.assertEqual((a, b, c), (42, 42, 42))

        f = open(TESTFN, "w")
        lines = ("a\n", "bb\n", "ccc\n")
        try:
            for line in lines:
                f.write(line)
        finally:
            f.close()
        f = open(TESTFN, "r")
        try:
            # Unpack straight from the file object, one line per target.
            a, b, c = f
            self.assertEqual((a, b, c), lines)
        finally:
            f.close()
            # Best-effort removal of the scratch file.
            try:
                unlink(TESTFN)
            except OSError:
                pass

        # Nested targets: the inner patterns unpack an iterable object and a
        # one-key dict (which yields its key).
        (a, b), (c,) = IteratingSequenceClass(2), {42: 24}
        self.assertEqual((a, b, c), (0, 1, 42))

        # Test reference count behavior

        # C.count tracks live instances: incremented in __new__,
        # decremented in __del__.
        class C(object):
            count = 0
            def __new__(cls):
                cls.count += 1
                return object.__new__(cls)
            def __del__(self):
                cls = self.__class__
                assert cls.count > 0
                cls.count -= 1
        x = C()
        self.assertEqual(C.count, 1)
        del x
        self.assertEqual(C.count, 0)
        l = [C(), C(), C()]
        self.assertEqual(C.count, 3)
        # Three items into two targets fails with ValueError; the failed
        # unpack must not keep any of the items alive.
        try:
            a, b = iter(l)
        except ValueError:
            pass
        del l
        self.assertEqual(C.count, 0)
804
805
806    # Make sure StopIteration is a "sink state".
807    # This tests various things that weren't sink states in Python 2.2.1,
808    # plus various things that always were fine.
809
    def test_sinkstate_list(self):
        # This used to fail
        a = range(5)
        b = iter(a)
        self.assertEqual(list(b), range(5))
        # Growing the list after exhaustion must not revive the iterator.
        a.extend(range(5, 10))
        self.assertEqual(list(b), [])
817
    def test_sinkstate_tuple(self):
        a = (0, 1, 2, 3, 4)
        b = iter(a)
        self.assertEqual(list(b), range(5))
        # A second pass over the exhausted iterator yields nothing.
        self.assertEqual(list(b), [])
823
    def test_sinkstate_string(self):
        a = "abcde"
        b = iter(a)
        self.assertEqual(list(b), ['a', 'b', 'c', 'd', 'e'])
        # A second pass over the exhausted iterator yields nothing.
        self.assertEqual(list(b), [])
829
    def test_sinkstate_sequence(self):
        # This used to fail
        a = SequenceClass(5)
        b = iter(a)
        self.assertEqual(list(b), range(5))
        # Indices 5..9 become valid after exhaustion; the iterator must
        # stay finished regardless.
        a.n = 10
        self.assertEqual(list(b), [])
837
    def test_sinkstate_callable(self):
        # This used to fail
        def spam(state=[0]):
            i = state[0]
            state[0] = i+1
            # Reached only if spam() keeps being called after it has
            # returned the sentinel.
            if i == 10:
                raise AssertionError, "shouldn't have gotten this far"
            return i
        b = iter(spam, 5)
        self.assertEqual(list(b), range(5))
        # Once the sentinel has been seen the iterator must stay finished.
        self.assertEqual(list(b), [])
849
850    def test_sinkstate_dict(self):
851        # XXX For a more thorough test, see towards the end of:
852        # http://mail.python.org/pipermail/python-dev/2002-July/026512.html
853        a = {1:1, 2:2, 0:0, 4:4, 3:3}
854        for b in iter(a), a.iterkeys(), a.iteritems(), a.itervalues():
855            b = iter(a)
856            self.assertEqual(len(list(b)), 5)
857            self.assertEqual(list(b), [])
858
    def test_sinkstate_yield(self):
        def gen():
            for i in range(5):
                yield i
        b = gen()
        self.assertEqual(list(b), range(5))
        # A finished generator yields nothing on a second pass.
        self.assertEqual(list(b), [])
866
    def test_sinkstate_range(self):
        a = xrange(5)
        b = iter(a)
        self.assertEqual(list(b), range(5))
        # A second pass over the exhausted iterator yields nothing.
        self.assertEqual(list(b), [])
872
    def test_sinkstate_enumerate(self):
        a = range(5)
        e = enumerate(a)
        b = iter(e)
        # enumerate over 0..4 pairs every index with an equal value.
        self.assertEqual(list(b), zip(range(5), range(5)))
        # A second pass over the exhausted iterator yields nothing.
        self.assertEqual(list(b), [])
879
880
# Entry point: hands the TestCase class to test_support.run_unittest.
def test_main():
    run_unittest(TestCase)
883
884
# Run the suite when this file is executed as a script.
if __name__ == "__main__":
    test_main()