PageRenderTime 49ms CodeModel.GetById 9ms app.highlight 30ms RepoModel.GetById 1ms app.codeStats 1ms

/thirdparty/breakpad/third_party/protobuf/protobuf/python/mox.py

http://github.com/tomahawk-player/tomahawk
Python | 1401 lines | 1347 code | 2 blank | 52 comment | 1 complexity | 9a0e09fca3ac9adb52ac737505f50404 MD5 | raw file
   1#!/usr/bin/python2.4
   2#
   3# Copyright 2008 Google Inc.
   4#
   5# Licensed under the Apache License, Version 2.0 (the "License");
   6# you may not use this file except in compliance with the License.
   7# You may obtain a copy of the License at
   8#
   9#      http://www.apache.org/licenses/LICENSE-2.0
  10#
  11# Unless required by applicable law or agreed to in writing, software
  12# distributed under the License is distributed on an "AS IS" BASIS,
  13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14# See the License for the specific language governing permissions and
  15# limitations under the License.
  16
  17# This file is used for testing.  The original is at:
  18#   http://code.google.com/p/pymox/
  19
  20"""Mox, an object-mocking framework for Python.
  21
  22Mox works in the record-replay-verify paradigm.  When you first create
  23a mock object, it is in record mode.  You then programmatically set
  24the expected behavior of the mock object (what methods are to be
  25called on it, with what parameters, what they should return, and in
  26what order).
  27
  28Once you have set up the expected mock behavior, you put it in replay
  29mode.  Now the mock responds to method calls just as you told it to.
  30If an unexpected method (or an expected method with unexpected
  31parameters) is called, then an exception will be raised.
  32
  33Once you are done interacting with the mock, you need to verify that
  34all the expected interactions occurred.  (Maybe your code exited
  35prematurely without calling some cleanup method!)  The verify phase
  36ensures that every expected method was called; otherwise, an exception
  37will be raised.
  38
  39Suggested usage / workflow:
  40
  41  # Create Mox factory
  42  my_mox = Mox()
  43
  44  # Create a mock data access object
  45  mock_dao = my_mox.CreateMock(DAOClass)
  46
  47  # Set up expected behavior
  48  mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person)
  49  mock_dao.DeletePerson(person)
  50
  51  # Put mocks in replay mode
  52  my_mox.ReplayAll()
  53
  54  # Inject mock object and run test
  55  controller.SetDao(mock_dao)
  56  controller.DeletePersonById('1')
  57
  58  # Verify all methods were called as expected
  59  my_mox.VerifyAll()
  60"""
  61
  62from collections import deque
  63import re
  64import types
  65import unittest
  66
  67import stubout
  68
  69class Error(AssertionError):
  70  """Base exception for this module."""
  71
  72  pass
  73
  74
  75class ExpectedMethodCallsError(Error):
  76  """Raised when Verify() is called before all expected methods have been called
  77  """
  78
  79  def __init__(self, expected_methods):
  80    """Init exception.
  81
  82    Args:
  83      # expected_methods: A sequence of MockMethod objects that should have been
  84      #   called.
  85      expected_methods: [MockMethod]
  86
  87    Raises:
  88      ValueError: if expected_methods contains no methods.
  89    """
  90
  91    if not expected_methods:
  92      raise ValueError("There must be at least one expected method")
  93    Error.__init__(self)
  94    self._expected_methods = expected_methods
  95
  96  def __str__(self):
  97    calls = "\n".join(["%3d.  %s" % (i, m)
  98                       for i, m in enumerate(self._expected_methods)])
  99    return "Verify: Expected methods never called:\n%s" % (calls,)
 100
 101
 102class UnexpectedMethodCallError(Error):
 103  """Raised when an unexpected method is called.
 104
 105  This can occur if a method is called with incorrect parameters, or out of the
 106  specified order.
 107  """
 108
 109  def __init__(self, unexpected_method, expected):
 110    """Init exception.
 111
 112    Args:
 113      # unexpected_method: MockMethod that was called but was not at the head of
 114      #   the expected_method queue.
 115      # expected: MockMethod or UnorderedGroup the method should have
 116      #   been in.
 117      unexpected_method: MockMethod
 118      expected: MockMethod or UnorderedGroup
 119    """
 120
 121    Error.__init__(self)
 122    self._unexpected_method = unexpected_method
 123    self._expected = expected
 124
 125  def __str__(self):
 126    return "Unexpected method call: %s.  Expecting: %s" % \
 127      (self._unexpected_method, self._expected)
 128
 129
 130class UnknownMethodCallError(Error):
 131  """Raised if an unknown method is requested of the mock object."""
 132
 133  def __init__(self, unknown_method_name):
 134    """Init exception.
 135
 136    Args:
 137      # unknown_method_name: Method call that is not part of the mocked class's
 138      #   public interface.
 139      unknown_method_name: str
 140    """
 141
 142    Error.__init__(self)
 143    self._unknown_method_name = unknown_method_name
 144
 145  def __str__(self):
 146    return "Method called is not a member of the object: %s" % \
 147      self._unknown_method_name
 148
 149
 150class Mox(object):
 151  """Mox: a factory for creating mock objects."""
 152
 153  # A list of types that should be stubbed out with MockObjects (as
 154  # opposed to MockAnythings).
 155  _USE_MOCK_OBJECT = [types.ClassType, types.InstanceType, types.ModuleType,
 156                      types.ObjectType, types.TypeType]
 157
 158  def __init__(self):
 159    """Initialize a new Mox."""
 160
 161    self._mock_objects = []
 162    self.stubs = stubout.StubOutForTesting()
 163
 164  def CreateMock(self, class_to_mock):
 165    """Create a new mock object.
 166
 167    Args:
 168      # class_to_mock: the class to be mocked
 169      class_to_mock: class
 170
 171    Returns:
 172      MockObject that can be used as the class_to_mock would be.
 173    """
 174
 175    new_mock = MockObject(class_to_mock)
 176    self._mock_objects.append(new_mock)
 177    return new_mock
 178
 179  def CreateMockAnything(self):
 180    """Create a mock that will accept any method calls.
 181
 182    This does not enforce an interface.
 183    """
 184
 185    new_mock = MockAnything()
 186    self._mock_objects.append(new_mock)
 187    return new_mock
 188
 189  def ReplayAll(self):
 190    """Set all mock objects to replay mode."""
 191
 192    for mock_obj in self._mock_objects:
 193      mock_obj._Replay()
 194
 195
 196  def VerifyAll(self):
 197    """Call verify on all mock objects created."""
 198
 199    for mock_obj in self._mock_objects:
 200      mock_obj._Verify()
 201
 202  def ResetAll(self):
 203    """Call reset on all mock objects.  This does not unset stubs."""
 204
 205    for mock_obj in self._mock_objects:
 206      mock_obj._Reset()
 207
 208  def StubOutWithMock(self, obj, attr_name, use_mock_anything=False):
 209    """Replace a method, attribute, etc. with a Mock.
 210
 211    This will replace a class or module with a MockObject, and everything else
 212    (method, function, etc) with a MockAnything.  This can be overridden to
 213    always use a MockAnything by setting use_mock_anything to True.
 214
 215    Args:
 216      obj: A Python object (class, module, instance, callable).
 217      attr_name: str.  The name of the attribute to replace with a mock.
 218      use_mock_anything: bool. True if a MockAnything should be used regardless
 219        of the type of attribute.
 220    """
 221
 222    attr_to_replace = getattr(obj, attr_name)
 223    if type(attr_to_replace) in self._USE_MOCK_OBJECT and not use_mock_anything:
 224      stub = self.CreateMock(attr_to_replace)
 225    else:
 226      stub = self.CreateMockAnything()
 227
 228    self.stubs.Set(obj, attr_name, stub)
 229
 230  def UnsetStubs(self):
 231    """Restore stubs to their original state."""
 232
 233    self.stubs.UnsetAll()
 234
 235def Replay(*args):
 236  """Put mocks into Replay mode.
 237
 238  Args:
 239    # args is any number of mocks to put into replay mode.
 240  """
 241
 242  for mock in args:
 243    mock._Replay()
 244
 245
 246def Verify(*args):
 247  """Verify mocks.
 248
 249  Args:
 250    # args is any number of mocks to be verified.
 251  """
 252
 253  for mock in args:
 254    mock._Verify()
 255
 256
 257def Reset(*args):
 258  """Reset mocks.
 259
 260  Args:
 261    # args is any number of mocks to be reset.
 262  """
 263
 264  for mock in args:
 265    mock._Reset()
 266
 267
 268class MockAnything:
 269  """A mock that can be used to mock anything.
 270
 271  This is helpful for mocking classes that do not provide a public interface.
 272  """
 273
 274  def __init__(self):
 275    """ """
 276    self._Reset()
 277
 278  def __getattr__(self, method_name):
 279    """Intercept method calls on this object.
 280
 281     A new MockMethod is returned that is aware of the MockAnything's
 282     state (record or replay).  The call will be recorded or replayed
 283     by the MockMethod's __call__.
 284
 285    Args:
 286      # method name: the name of the method being called.
 287      method_name: str
 288
 289    Returns:
 290      A new MockMethod aware of MockAnything's state (record or replay).
 291    """
 292
 293    return self._CreateMockMethod(method_name)
 294
 295  def _CreateMockMethod(self, method_name):
 296    """Create a new mock method call and return it.
 297
 298    Args:
 299      # method name: the name of the method being called.
 300      method_name: str
 301
 302    Returns:
 303      A new MockMethod aware of MockAnything's state (record or replay).
 304    """
 305
 306    return MockMethod(method_name, self._expected_calls_queue,
 307                      self._replay_mode)
 308
 309  def __nonzero__(self):
 310    """Return 1 for nonzero so the mock can be used as a conditional."""
 311
 312    return 1
 313
 314  def __eq__(self, rhs):
 315    """Provide custom logic to compare objects."""
 316
 317    return (isinstance(rhs, MockAnything) and
 318            self._replay_mode == rhs._replay_mode and
 319            self._expected_calls_queue == rhs._expected_calls_queue)
 320
 321  def __ne__(self, rhs):
 322    """Provide custom logic to compare objects."""
 323
 324    return not self == rhs
 325
 326  def _Replay(self):
 327    """Start replaying expected method calls."""
 328
 329    self._replay_mode = True
 330
 331  def _Verify(self):
 332    """Verify that all of the expected calls have been made.
 333
 334    Raises:
 335      ExpectedMethodCallsError: if there are still more method calls in the
 336        expected queue.
 337    """
 338
 339    # If the list of expected calls is not empty, raise an exception
 340    if self._expected_calls_queue:
 341      # The last MultipleTimesGroup is not popped from the queue.
 342      if (len(self._expected_calls_queue) == 1 and
 343          isinstance(self._expected_calls_queue[0], MultipleTimesGroup) and
 344          self._expected_calls_queue[0].IsSatisfied()):
 345        pass
 346      else:
 347        raise ExpectedMethodCallsError(self._expected_calls_queue)
 348
 349  def _Reset(self):
 350    """Reset the state of this mock to record mode with an empty queue."""
 351
 352    # Maintain a list of method calls we are expecting
 353    self._expected_calls_queue = deque()
 354
 355    # Make sure we are in setup mode, not replay mode
 356    self._replay_mode = False
 357
 358
 359class MockObject(MockAnything, object):
 360  """A mock object that simulates the public/protected interface of a class."""
 361
 362  def __init__(self, class_to_mock):
 363    """Initialize a mock object.
 364
 365    This determines the methods and properties of the class and stores them.
 366
 367    Args:
 368      # class_to_mock: class to be mocked
 369      class_to_mock: class
 370    """
 371
 372    # This is used to hack around the mixin/inheritance of MockAnything, which
 373    # is not a proper object (it can be anything. :-)
 374    MockAnything.__dict__['__init__'](self)
 375
 376    # Get a list of all the public and special methods we should mock.
 377    self._known_methods = set()
 378    self._known_vars = set()
 379    self._class_to_mock = class_to_mock
 380    for method in dir(class_to_mock):
 381      if callable(getattr(class_to_mock, method)):
 382        self._known_methods.add(method)
 383      else:
 384        self._known_vars.add(method)
 385
 386  def __getattr__(self, name):
 387    """Intercept attribute request on this object.
 388
 389    If the attribute is a public class variable, it will be returned and not
 390    recorded as a call.
 391
 392    If the attribute is not a variable, it is handled like a method
 393    call. The method name is checked against the set of mockable
 394    methods, and a new MockMethod is returned that is aware of the
 395    MockObject's state (record or replay).  The call will be recorded
 396    or replayed by the MockMethod's __call__.
 397
 398    Args:
 399      # name: the name of the attribute being requested.
 400      name: str
 401
 402    Returns:
 403      Either a class variable or a new MockMethod that is aware of the state
 404      of the mock (record or replay).
 405
 406    Raises:
 407      UnknownMethodCallError if the MockObject does not mock the requested
 408          method.
 409    """
 410
 411    if name in self._known_vars:
 412      return getattr(self._class_to_mock, name)
 413
 414    if name in self._known_methods:
 415      return self._CreateMockMethod(name)
 416
 417    raise UnknownMethodCallError(name)
 418
 419  def __eq__(self, rhs):
 420    """Provide custom logic to compare objects."""
 421
 422    return (isinstance(rhs, MockObject) and
 423            self._class_to_mock == rhs._class_to_mock and
 424            self._replay_mode == rhs._replay_mode and
 425            self._expected_calls_queue == rhs._expected_calls_queue)
 426
 427  def __setitem__(self, key, value):
 428    """Provide custom logic for mocking classes that support item assignment.
 429
 430    Args:
 431      key: Key to set the value for.
 432      value: Value to set.
 433
 434    Returns:
 435      Expected return value in replay mode.  A MockMethod object for the
 436      __setitem__ method that has already been called if not in replay mode.
 437
 438    Raises:
 439      TypeError if the underlying class does not support item assignment.
 440      UnexpectedMethodCallError if the object does not expect the call to
 441        __setitem__.
 442
 443    """
 444    setitem = self._class_to_mock.__dict__.get('__setitem__', None)
 445
 446    # Verify the class supports item assignment.
 447    if setitem is None:
 448      raise TypeError('object does not support item assignment')
 449
 450    # If we are in replay mode then simply call the mock __setitem__ method.
 451    if self._replay_mode:
 452      return MockMethod('__setitem__', self._expected_calls_queue,
 453                        self._replay_mode)(key, value)
 454
 455
 456    # Otherwise, create a mock method __setitem__.
 457    return self._CreateMockMethod('__setitem__')(key, value)
 458
 459  def __getitem__(self, key):
 460    """Provide custom logic for mocking classes that are subscriptable.
 461
 462    Args:
 463      key: Key to return the value for.
 464
 465    Returns:
 466      Expected return value in replay mode.  A MockMethod object for the
 467      __getitem__ method that has already been called if not in replay mode.
 468
 469    Raises:
 470      TypeError if the underlying class is not subscriptable.
 471      UnexpectedMethodCallError if the object does not expect the call to
 472        __setitem__.
 473
 474    """
 475    getitem = self._class_to_mock.__dict__.get('__getitem__', None)
 476
 477    # Verify the class supports item assignment.
 478    if getitem is None:
 479      raise TypeError('unsubscriptable object')
 480
 481    # If we are in replay mode then simply call the mock __getitem__ method.
 482    if self._replay_mode:
 483      return MockMethod('__getitem__', self._expected_calls_queue,
 484                        self._replay_mode)(key)
 485
 486
 487    # Otherwise, create a mock method __getitem__.
 488    return self._CreateMockMethod('__getitem__')(key)
 489
 490  def __call__(self, *params, **named_params):
 491    """Provide custom logic for mocking classes that are callable."""
 492
 493    # Verify the class we are mocking is callable
 494    callable = self._class_to_mock.__dict__.get('__call__', None)
 495    if callable is None:
 496      raise TypeError('Not callable')
 497
 498    # Because the call is happening directly on this object instead of a method,
 499    # the call on the mock method is made right here
 500    mock_method = self._CreateMockMethod('__call__')
 501    return mock_method(*params, **named_params)
 502
 503  @property
 504  def __class__(self):
 505    """Return the class that is being mocked."""
 506
 507    return self._class_to_mock
 508
 509
 510class MockMethod(object):
 511  """Callable mock method.
 512
 513  A MockMethod should act exactly like the method it mocks, accepting parameters
 514  and returning a value, or throwing an exception (as specified).  When this
 515  method is called, it can optionally verify whether the called method (name and
 516  signature) matches the expected method.
 517  """
 518
 519  def __init__(self, method_name, call_queue, replay_mode):
 520    """Construct a new mock method.
 521
 522    Args:
 523      # method_name: the name of the method
 524      # call_queue: deque of calls, verify this call against the head, or add
 525      #     this call to the queue.
 526      # replay_mode: False if we are recording, True if we are verifying calls
 527      #     against the call queue.
 528      method_name: str
 529      call_queue: list or deque
 530      replay_mode: bool
 531    """
 532
 533    self._name = method_name
 534    self._call_queue = call_queue
 535    if not isinstance(call_queue, deque):
 536      self._call_queue = deque(self._call_queue)
 537    self._replay_mode = replay_mode
 538
 539    self._params = None
 540    self._named_params = None
 541    self._return_value = None
 542    self._exception = None
 543    self._side_effects = None
 544
 545  def __call__(self, *params, **named_params):
 546    """Log parameters and return the specified return value.
 547
 548    If the Mock(Anything/Object) associated with this call is in record mode,
 549    this MockMethod will be pushed onto the expected call queue.  If the mock
 550    is in replay mode, this will pop a MockMethod off the top of the queue and
 551    verify this call is equal to the expected call.
 552
 553    Raises:
 554      UnexpectedMethodCall if this call is supposed to match an expected method
 555        call and it does not.
 556    """
 557
 558    self._params = params
 559    self._named_params = named_params
 560
 561    if not self._replay_mode:
 562      self._call_queue.append(self)
 563      return self
 564
 565    expected_method = self._VerifyMethodCall()
 566
 567    if expected_method._side_effects:
 568      expected_method._side_effects(*params, **named_params)
 569
 570    if expected_method._exception:
 571      raise expected_method._exception
 572
 573    return expected_method._return_value
 574
 575  def __getattr__(self, name):
 576    """Raise an AttributeError with a helpful message."""
 577
 578    raise AttributeError('MockMethod has no attribute "%s". '
 579        'Did you remember to put your mocks in replay mode?' % name)
 580
 581  def _PopNextMethod(self):
 582    """Pop the next method from our call queue."""
 583    try:
 584      return self._call_queue.popleft()
 585    except IndexError:
 586      raise UnexpectedMethodCallError(self, None)
 587
 588  def _VerifyMethodCall(self):
 589    """Verify the called method is expected.
 590
 591    This can be an ordered method, or part of an unordered set.
 592
 593    Returns:
 594      The expected mock method.
 595
 596    Raises:
 597      UnexpectedMethodCall if the method called was not expected.
 598    """
 599
 600    expected = self._PopNextMethod()
 601
 602    # Loop here, because we might have a MethodGroup followed by another
 603    # group.
 604    while isinstance(expected, MethodGroup):
 605      expected, method = expected.MethodCalled(self)
 606      if method is not None:
 607        return method
 608
 609    # This is a mock method, so just check equality.
 610    if expected != self:
 611      raise UnexpectedMethodCallError(self, expected)
 612
 613    return expected
 614
 615  def __str__(self):
 616    params = ', '.join(
 617        [repr(p) for p in self._params or []] +
 618        ['%s=%r' % x for x in sorted((self._named_params or {}).items())])
 619    desc = "%s(%s) -> %r" % (self._name, params, self._return_value)
 620    return desc
 621
 622  def __eq__(self, rhs):
 623    """Test whether this MockMethod is equivalent to another MockMethod.
 624
 625    Args:
 626      # rhs: the right hand side of the test
 627      rhs: MockMethod
 628    """
 629
 630    return (isinstance(rhs, MockMethod) and
 631            self._name == rhs._name and
 632            self._params == rhs._params and
 633            self._named_params == rhs._named_params)
 634
 635  def __ne__(self, rhs):
 636    """Test whether this MockMethod is not equivalent to another MockMethod.
 637
 638    Args:
 639      # rhs: the right hand side of the test
 640      rhs: MockMethod
 641    """
 642
 643    return not self == rhs
 644
 645  def GetPossibleGroup(self):
 646    """Returns a possible group from the end of the call queue or None if no
 647    other methods are on the stack.
 648    """
 649
 650    # Remove this method from the tail of the queue so we can add it to a group.
 651    this_method = self._call_queue.pop()
 652    assert this_method == self
 653
 654    # Determine if the tail of the queue is a group, or just a regular ordered
 655    # mock method.
 656    group = None
 657    try:
 658      group = self._call_queue[-1]
 659    except IndexError:
 660      pass
 661
 662    return group
 663
 664  def _CheckAndCreateNewGroup(self, group_name, group_class):
 665    """Checks if the last method (a possible group) is an instance of our
 666    group_class. Adds the current method to this group or creates a new one.
 667
 668    Args:
 669
 670      group_name: the name of the group.
 671      group_class: the class used to create instance of this new group
 672    """
 673    group = self.GetPossibleGroup()
 674
 675    # If this is a group, and it is the correct group, add the method.
 676    if isinstance(group, group_class) and group.group_name() == group_name:
 677      group.AddMethod(self)
 678      return self
 679
 680    # Create a new group and add the method.
 681    new_group = group_class(group_name)
 682    new_group.AddMethod(self)
 683    self._call_queue.append(new_group)
 684    return self
 685
 686  def InAnyOrder(self, group_name="default"):
 687    """Move this method into a group of unordered calls.
 688
 689    A group of unordered calls must be defined together, and must be executed
 690    in full before the next expected method can be called.  There can be
 691    multiple groups that are expected serially, if they are given
 692    different group names.  The same group name can be reused if there is a
 693    standard method call, or a group with a different name, spliced between
 694    usages.
 695
 696    Args:
 697      group_name: the name of the unordered group.
 698
 699    Returns:
 700      self
 701    """
 702    return self._CheckAndCreateNewGroup(group_name, UnorderedGroup)
 703
 704  def MultipleTimes(self, group_name="default"):
 705    """Move this method into group of calls which may be called multiple times.
 706
 707    A group of repeating calls must be defined together, and must be executed in
 708    full before the next expected mehtod can be called.
 709
 710    Args:
 711      group_name: the name of the unordered group.
 712
 713    Returns:
 714      self
 715    """
 716    return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup)
 717
 718  def AndReturn(self, return_value):
 719    """Set the value to return when this method is called.
 720
 721    Args:
 722      # return_value can be anything.
 723    """
 724
 725    self._return_value = return_value
 726    return return_value
 727
 728  def AndRaise(self, exception):
 729    """Set the exception to raise when this method is called.
 730
 731    Args:
 732      # exception: the exception to raise when this method is called.
 733      exception: Exception
 734    """
 735
 736    self._exception = exception
 737
 738  def WithSideEffects(self, side_effects):
 739    """Set the side effects that are simulated when this method is called.
 740
 741    Args:
 742      side_effects: A callable which modifies the parameters or other relevant
 743        state which a given test case depends on.
 744
 745    Returns:
 746      Self for chaining with AndReturn and AndRaise.
 747    """
 748    self._side_effects = side_effects
 749    return self
 750
 751class Comparator:
 752  """Base class for all Mox comparators.
 753
 754  A Comparator can be used as a parameter to a mocked method when the exact
 755  value is not known.  For example, the code you are testing might build up a
 756  long SQL string that is passed to your mock DAO. You're only interested that
 757  the IN clause contains the proper primary keys, so you can set your mock
 758  up as follows:
 759
 760  mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
 761
 762  Now whatever query is passed in must contain the string 'IN (1, 2, 4, 5)'.
 763
 764  A Comparator may replace one or more parameters, for example:
 765  # return at most 10 rows
 766  mock_dao.RunQuery(StrContains('SELECT'), 10)
 767
 768  or
 769
 770  # Return some non-deterministic number of rows
 771  mock_dao.RunQuery(StrContains('SELECT'), IsA(int))
 772  """
 773
 774  def equals(self, rhs):
 775    """Special equals method that all comparators must implement.
 776
 777    Args:
 778      rhs: any python object
 779    """
 780
 781    raise NotImplementedError, 'method must be implemented by a subclass.'
 782
 783  def __eq__(self, rhs):
 784    return self.equals(rhs)
 785
 786  def __ne__(self, rhs):
 787    return not self.equals(rhs)
 788
 789
 790class IsA(Comparator):
 791  """This class wraps a basic Python type or class.  It is used to verify
 792  that a parameter is of the given type or class.
 793
 794  Example:
 795  mock_dao.Connect(IsA(DbConnectInfo))
 796  """
 797
 798  def __init__(self, class_name):
 799    """Initialize IsA
 800
 801    Args:
 802      class_name: basic python type or a class
 803    """
 804
 805    self._class_name = class_name
 806
 807  def equals(self, rhs):
 808    """Check to see if the RHS is an instance of class_name.
 809
 810    Args:
 811      # rhs: the right hand side of the test
 812      rhs: object
 813
 814    Returns:
 815      bool
 816    """
 817
 818    try:
 819      return isinstance(rhs, self._class_name)
 820    except TypeError:
 821      # Check raw types if there was a type error.  This is helpful for
 822      # things like cStringIO.StringIO.
 823      return type(rhs) == type(self._class_name)
 824
 825  def __repr__(self):
 826    return str(self._class_name)
 827
 828class IsAlmost(Comparator):
 829  """Comparison class used to check whether a parameter is nearly equal
 830  to a given value.  Generally useful for floating point numbers.
 831
 832  Example mock_dao.SetTimeout((IsAlmost(3.9)))
 833  """
 834
 835  def __init__(self, float_value, places=7):
 836    """Initialize IsAlmost.
 837
 838    Args:
 839      float_value: The value for making the comparison.
 840      places: The number of decimal places to round to.
 841    """
 842
 843    self._float_value = float_value
 844    self._places = places
 845
 846  def equals(self, rhs):
 847    """Check to see if RHS is almost equal to float_value
 848
 849    Args:
 850      rhs: the value to compare to float_value
 851
 852    Returns:
 853      bool
 854    """
 855
 856    try:
 857      return round(rhs-self._float_value, self._places) == 0
 858    except TypeError:
 859      # This is probably because either float_value or rhs is not a number.
 860      return False
 861
 862  def __repr__(self):
 863    return str(self._float_value)
 864
 865class StrContains(Comparator):
 866  """Comparison class used to check whether a substring exists in a
 867  string parameter.  This can be useful in mocking a database with SQL
 868  passed in as a string parameter, for example.
 869
 870  Example:
 871  mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
 872  """
 873
 874  def __init__(self, search_string):
 875    """Initialize.
 876
 877    Args:
 878      # search_string: the string you are searching for
 879      search_string: str
 880    """
 881
 882    self._search_string = search_string
 883
 884  def equals(self, rhs):
 885    """Check to see if the search_string is contained in the rhs string.
 886
 887    Args:
 888      # rhs: the right hand side of the test
 889      rhs: object
 890
 891    Returns:
 892      bool
 893    """
 894
 895    try:
 896      return rhs.find(self._search_string) > -1
 897    except Exception:
 898      return False
 899
 900  def __repr__(self):
 901    return '<str containing \'%s\'>' % self._search_string
 902
 903
 904class Regex(Comparator):
 905  """Checks if a string matches a regular expression.
 906
 907  This uses a given regular expression to determine equality.
 908  """
 909
 910  def __init__(self, pattern, flags=0):
 911    """Initialize.
 912
 913    Args:
 914      # pattern is the regular expression to search for
 915      pattern: str
 916      # flags passed to re.compile function as the second argument
 917      flags: int
 918    """
 919
 920    self.regex = re.compile(pattern, flags=flags)
 921
 922  def equals(self, rhs):
 923    """Check to see if rhs matches regular expression pattern.
 924
 925    Returns:
 926      bool
 927    """
 928
 929    return self.regex.search(rhs) is not None
 930
 931  def __repr__(self):
 932    s = '<regular expression \'%s\'' % self.regex.pattern
 933    if self.regex.flags:
 934      s += ', flags=%d' % self.regex.flags
 935    s += '>'
 936    return s
 937
 938
 939class In(Comparator):
 940  """Checks whether an item (or key) is in a list (or dict) parameter.
 941
 942  Example:
 943  mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result)
 944  """
 945
 946  def __init__(self, key):
 947    """Initialize.
 948
 949    Args:
 950      # key is any thing that could be in a list or a key in a dict
 951    """
 952
 953    self._key = key
 954
 955  def equals(self, rhs):
 956    """Check to see whether key is in rhs.
 957
 958    Args:
 959      rhs: dict
 960
 961    Returns:
 962      bool
 963    """
 964
 965    return self._key in rhs
 966
 967  def __repr__(self):
 968    return '<sequence or map containing \'%s\'>' % self._key
 969
 970
 971class ContainsKeyValue(Comparator):
 972  """Checks whether a key/value pair is in a dict parameter.
 973
 974  Example:
 975  mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info))
 976  """
 977
 978  def __init__(self, key, value):
 979    """Initialize.
 980
 981    Args:
 982      # key: a key in a dict
 983      # value: the corresponding value
 984    """
 985
 986    self._key = key
 987    self._value = value
 988
 989  def equals(self, rhs):
 990    """Check whether the given key/value pair is in the rhs dict.
 991
 992    Returns:
 993      bool
 994    """
 995
 996    try:
 997      return rhs[self._key] == self._value
 998    except Exception:
 999      return False
1000
1001  def __repr__(self):
1002    return '<map containing the entry \'%s: %s\'>' % (self._key, self._value)
1003
1004
class SameElementsAs(Comparator):
  """Checks whether iterables contain the same elements (ignoring order).

  Example:
  mock_dao.ProcessUsers(SameElementsAs(['stevepm', 'salomaki']))
  """

  def __init__(self, expected_seq):
    """Initialize.

    Args:
      expected_seq: a sequence
    """

    self._expected_seq = expected_seq

  def equals(self, actual_seq):
    """Check to see whether actual_seq has same elements as expected_seq.

    When every element is hashable the comparison is set-like: order and
    duplicate counts are ignored.  If any element is unhashable, both sides
    are sorted and compared as lists instead, so duplicates then matter.

    Args:
      actual_seq: sequence (any iterable is accepted; it is read only once)

    Returns:
      bool; False when actual_seq is not iterable.
    """

    try:
      # Materialize once: the hashable attempt below and the list fallback
      # would otherwise each iterate the input, so an iterator/generator
      # would be compared after it had already been exhausted.
      actual_list = list(actual_seq)
    except TypeError:
      # Not iterable at all, so it cannot contain the expected elements.
      return False
    expected_list = list(self._expected_seq)

    try:
      expected = dict([(element, None) for element in expected_list])
      actual = dict([(element, None) for element in actual_list])
    except TypeError:
      # Fall back to slower list-compare if any of the objects are unhashable.
      expected = expected_list
      actual = actual_list
      expected.sort()
      actual.sort()
    return expected == actual

  def __repr__(self):
    # Wrap in a 1-tuple so a tuple expected_seq is not unpacked by %.
    return '<sequence with same elements as \'%s\'>' % (self._expected_seq,)
1044
1045
class And(Comparator):
  """Evaluates one or more Comparators on RHS and returns an AND of the results.
  """

  def __init__(self, *args):
    """Store the comparators to combine.

    Args:
      *args: One or more Comparator
    """

    self._comparators = args

  def equals(self, rhs):
    """Return True only if every stored comparator accepts rhs.

    Comparators are consulted in order and evaluation stops at the first one
    that rejects rhs.

    Args:
      # rhs: can be anything

    Returns:
      bool
    """

    for each in self._comparators:
      if not each.equals(rhs):
        break
    else:
      return True
    return False

  def __repr__(self):
    return '<AND %s>' % (self._comparators,)
1077
1078
class Or(Comparator):
  """Evaluates one or more Comparators on RHS and returns an OR of the results.
  """

  def __init__(self, *args):
    """Store the alternative comparators.

    Args:
      *args: One or more Mox comparators
    """

    self._comparators = args

  def equals(self, rhs):
    """Return True as soon as any stored comparator accepts rhs.

    Comparators are consulted in order; later ones are not evaluated once a
    match is found.

    Args:
      # rhs: can be anything

    Returns:
      bool
    """

    for each in self._comparators:
      if each.equals(rhs):
        break
    else:
      return False
    return True

  def __repr__(self):
    return '<OR %s>' % (self._comparators,)
1110
1111
class Func(Comparator):
  """Call a function that should verify the parameter passed in is correct.

  You may need the ability to perform more advanced operations on the parameter
  in order to validate it.  You can use this to have a callable validate any
  parameter. The callable should return either True or False.

  Example:

  def myParamValidator(param):
    # Advanced logic here
    return True

  mock_dao.DoSomething(Func(myParamValidator), True)
  """

  def __init__(self, func):
    """Wrap a validation callable.

    Args:
      func: callable that takes one parameter and returns a bool
    """

    self._func = func

  def equals(self, rhs):
    """Run the wrapped callable on rhs and hand back whatever it returns.

    Args:
      rhs: any python object

    Returns:
      the result of func(rhs), unmodified (it is not coerced to bool)
    """

    validator = self._func
    return validator(rhs)

  def __repr__(self):
    description = str(self._func)
    return description
1154
1155
class IgnoreArg(Comparator):
  """Ignore an argument.

  Matches any value, so it can stand in for an argument of a method call
  whose value the test does not care about.

  Example:
  # Check if CastMagic is called with 3 as first arg and 'disappear' as third.
  mymock.CastMagic(3, IgnoreArg(), 'disappear')
  """

  def equals(self, unused_rhs):
    """Accept any argument.

    Args:
      unused_rhs: any python object; never inspected

    Returns:
      True, unconditionally
    """

    matches_anything = True
    return matches_anything

  def __repr__(self):
    return '<IgnoreArg>'
1180
1181
class MethodGroup(object):
  """Base class containing common behaviour for MethodGroups.

  Subclasses must implement AddMethod, MethodCalled and IsSatisfied; the
  versions defined here only raise NotImplementedError.
  """

  def __init__(self, group_name):
    self._group_name = group_name

  def group_name(self):
    """Return the name this group was created with."""
    return self._group_name

  def __str__(self):
    template = '<%s "%s">'
    return template % (self.__class__.__name__, self._group_name)

  def AddMethod(self, mock_method):
    """Register mock_method as part of the group (abstract)."""
    raise NotImplementedError

  def MethodCalled(self, mock_method):
    """Handle a call of mock_method against the group (abstract)."""
    raise NotImplementedError

  def IsSatisfied(self):
    """Return whether the group's expectations are met (abstract)."""
    raise NotImplementedError
1202
class UnorderedGroup(MethodGroup):
  """UnorderedGroup holds a set of method calls that may occur in any order.

  This construct is helpful for non-deterministic events, such as iterating
  over the keys of a dict.
  """

  def __init__(self, group_name):
    super(UnorderedGroup, self).__init__(group_name)
    # Expected calls that have not happened yet.
    self._methods = []

  def AddMethod(self, mock_method):
    """Append mock_method to the outstanding expected calls.

    Args:
      mock_method: A mock method to be added to this group.
    """

    self._methods.append(mock_method)

  def MethodCalled(self, mock_method):
    """Consume the first outstanding expectation that equals mock_method.

    While expectations remain afterwards, the group is pushed back onto the
    front of mock_method's call queue so the next call is checked against it
    again.

    Args:
      mock_method: a mock method that should be equal to a method in the group.

    Returns:
      A 2-tuple (this group, the matching expected method).

    Raises:
      UnexpectedMethodCallError if the mock_method was not in the group.
    """

    for candidate in self._methods:
      if candidate == mock_method:
        break
    else:
      raise UnexpectedMethodCallError(mock_method, self)

    # Remove the called mock_method instead of the method in the group.
    # The called method will match any comparators when equality is checked
    # during removal.  The method in the group could pass a comparator to
    # another comparator during the equality check.
    self._methods.remove(mock_method)

    if not self.IsSatisfied():
      mock_method._call_queue.appendleft(self)

    return self, candidate

  def IsSatisfied(self):
    """Return True if there are not any methods in this group."""

    return not self._methods
1261
1262
class MultipleTimesGroup(MethodGroup):
  """MultipleTimesGroup holds methods that may be called any number of times.

  Note: Each method must be called at least once.

  This is helpful, if you don't know or care how many times a method is called.
  """

  def __init__(self, group_name):
    super(MultipleTimesGroup, self).__init__(group_name)
    # Expected methods registered via AddMethod.
    self._methods = set()
    # Actual calls that matched one of the expected methods.
    self._methods_called = set()

  def AddMethod(self, mock_method):
    """Add a method to this group.

    Args:
      mock_method: A mock method to be added to this group.
    """

    self._methods.add(mock_method)

  def MethodCalled(self, mock_method):
    """Record a call against the group.

    If mock_method matches a method in the group, the call is recorded and the
    group is pushed back onto the front of mock_method's call queue, since more
    calls may follow.  If it does not match but the group is already
    satisfied, the decision is delegated to mock_method._PopNextMethod().
    Otherwise the call is unexpected.

    Args:
      mock_method: a mock method that should be equal to a method in the group.

    Returns:
      A 2-tuple.  On a match: (this group, the matching expected method).
      When the group is satisfied and mock_method does not match:
      (mock_method._PopNextMethod(), None) -- presumably the next queued
      expectation for the caller to check instead; verify against caller.

    Raises:
      UnexpectedMethodCallError if the mock_method was not in the group and
      the group is not yet satisfied.
    """

    for method in self._methods:
      if method == mock_method:
        self._methods_called.add(mock_method)
        # Always put this group back on top of the queue, because we don't know
        # when we are done.
        mock_method._call_queue.appendleft(self)
        return self, method

    if not self.IsSatisfied():
      raise UnexpectedMethodCallError(mock_method, self)

    next_method = mock_method._PopNextMethod()
    return next_method, None

  def IsSatisfied(self):
    """Return True if all methods in this group are called at least once.

    An empty group is trivially satisfied.
    """
    # NOTE(psycho): We can't use the simple set difference here because we want
    # to match different parameters which are considered the same e.g. IsA(str)
    # and some string. This solution is O(n^2) but n should be small.
    unmatched = self._methods.copy()
    if not unmatched:
      return True
    for called in self._methods_called:
      for expected in unmatched:
        if called == expected:
          # Safe despite mutating the set being iterated: we break right away.
          unmatched.remove(expected)
          if not unmatched:
            return True
          break
    return False
1332
1333
class MoxMetaTestBase(type):
  """Metaclass to add mox cleanup and verification to every test.

  As the mox unit testing class is being constructed (MoxTestBase or a
  subclass), this metaclass will modify all test functions to call the
  CleanUpMox method of the test class after they finish. This means that
  unstubbing and verifying will happen for every test with no additional code,
  and any failures will result in test failures as opposed to errors.
  """

  def __init__(cls, name, bases, d):
    type.__init__(cls, name, bases, d)

    # Also get the attributes from the base classes, to account for a test
    # class that is not the immediate child of MoxTestBase.  Attributes defined
    # on this class itself take precedence (and earlier bases win over later
    # ones, as in normal attribute lookup): a base attribute must never
    # replace a test method the subclass overrides.  Work on a copy so the
    # caller's namespace dict is left untouched.
    attrs = dict(d)
    for base in bases:
      for attr_name in dir(base):
        if attr_name not in attrs:
          attrs[attr_name] = getattr(base, attr_name)

    for func_name, func in attrs.items():
      if func_name.startswith('test') and callable(func):
        setattr(cls, func_name, MoxMetaTestBase.CleanUpTest(cls, func))

  @staticmethod
  def CleanUpTest(cls, func):
    """Adds Mox cleanup code to any MoxTestBase method.

    Always unsets stubs after a test (even if it raised). Will verify all mocks
    only for tests that otherwise pass. Cleanup happens only when the test
    instance has a `mox` attribute that is a Mox instance.

    Args:
      cls: MoxTestBase or subclass; the class whose test method we are altering.
      func: method; the method of the MoxTestBase test class we wish to alter.

    Returns:
      The modified method (its return value is discarded).
    """
    def new_method(self, *args, **kwargs):
      mox_obj = getattr(self, 'mox', None)
      cleanup_mox = False
      if mox_obj and isinstance(mox_obj, Mox):
        cleanup_mox = True
      try:
        func(self, *args, **kwargs)
      finally:
        if cleanup_mox:
          mox_obj.UnsetStubs()
      # Only reached when func returned normally.
      if cleanup_mox:
        mox_obj.VerifyAll()
    new_method.__name__ = func.__name__
    new_method.__doc__ = func.__doc__
    new_method.__module__ = func.__module__
    return new_method
1387
1388
class MoxTestBase(unittest.TestCase):
  """Convenience test class to make stubbing easier.

  Sets up a "mox" attribute which is an instance of Mox - any mox tests will
  want this. Also automatically unsets any stubs and verifies that all mock
  methods have been called at the end of each test, eliminating boilerplate
  code.
  """

  __metaclass__ = MoxMetaTestBase

  def setUp(self):
    # Chain to the next setUp in the MRO so mixins combined with this class
    # still get their own setUp run.
    super(MoxTestBase, self).setUp()
    self.mox = Mox()