3
# Copyright 2008 Google Inc.
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
# you may not use this file except in compliance with the License.
7
# You may obtain a copy of the License at
9
# http://www.apache.org/licenses/LICENSE-2.0
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS,
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
# See the License for the specific language governing permissions and
15
# limitations under the License.
17
# This file is used for testing. The original is at:
18
# http://code.google.com/p/pymox/
20
"""Mox, an object-mocking framework for Python.
22
Mox works in the record-replay-verify paradigm. When you first create
23
a mock object, it is in record mode. You then programmatically set
24
the expected behavior of the mock object (what methods are to be
25
called on it, with what parameters, what they should return, and in
28
Once you have set up the expected mock behavior, you put it in replay
29
mode. Now the mock responds to method calls just as you told it to.
30
If an unexpected method (or an expected method with unexpected
31
parameters) is called, then an exception will be raised.
33
Once you are done interacting with the mock, you need to verify that
34
all the expected interactions occurred. (Maybe your code exited
35
prematurely without calling some cleanup method!) The verify phase
36
ensures that every expected method was called; otherwise, an exception
39
Suggested usage / workflow:
44
# Create a mock data access object
45
mock_dao = my_mox.CreateMock(DAOClass)
47
# Set up expected behavior
48
mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person)
49
mock_dao.DeletePerson(person)
51
# Put mocks in replay mode
54
# Inject mock object and run test
55
controller.SetDao(mock_dao)
56
controller.DeletePersonById('1')
58
# Verify all methods were called as expected
62
from collections import deque
69
class Error(AssertionError):
70
"""Base exception for this module."""
75
class ExpectedMethodCallsError(Error):
76
"""Raised when Verify() is called before all expected methods have been called
79
def __init__(self, expected_methods):
83
# expected_methods: A sequence of MockMethod objects that should have been
85
expected_methods: [MockMethod]
88
ValueError: if expected_methods contains no methods.
91
if not expected_methods:
92
raise ValueError("There must be at least one expected method")
94
self._expected_methods = expected_methods
97
calls = "\n".join(["%3d. %s" % (i, m)
98
for i, m in enumerate(self._expected_methods)])
99
return "Verify: Expected methods never called:\n%s" % (calls,)
102
class UnexpectedMethodCallError(Error):
103
"""Raised when an unexpected method is called.
105
This can occur if a method is called with incorrect parameters, or out of the
109
def __init__(self, unexpected_method, expected):
113
# unexpected_method: MockMethod that was called but was not at the head of
114
# the expected_method queue.
115
# expected: MockMethod or UnorderedGroup the method should have
117
unexpected_method: MockMethod
118
expected: MockMethod or UnorderedGroup
122
self._unexpected_method = unexpected_method
123
self._expected = expected
126
return "Unexpected method call: %s. Expecting: %s" % \
127
(self._unexpected_method, self._expected)
130
class UnknownMethodCallError(Error):
131
"""Raised if an unknown method is requested of the mock object."""
133
def __init__(self, unknown_method_name):
137
# unknown_method_name: Method call that is not part of the mocked class's
139
unknown_method_name: str
143
self._unknown_method_name = unknown_method_name
146
return "Method called is not a member of the object: %s" % \
147
self._unknown_method_name
151
"""Mox: a factory for creating mock objects."""
153
# A list of types that should be stubbed out with MockObjects (as
154
# opposed to MockAnythings).
155
_USE_MOCK_OBJECT = [types.ClassType, types.InstanceType, types.ModuleType,
156
types.ObjectType, types.TypeType]
159
"""Initialize a new Mox."""
161
self._mock_objects = []
162
self.stubs = stubout.StubOutForTesting()
164
def CreateMock(self, class_to_mock):
165
"""Create a new mock object.
168
# class_to_mock: the class to be mocked
172
MockObject that can be used as the class_to_mock would be.
175
new_mock = MockObject(class_to_mock)
176
self._mock_objects.append(new_mock)
179
def CreateMockAnything(self):
180
"""Create a mock that will accept any method calls.
182
This does not enforce an interface.
185
new_mock = MockAnything()
186
self._mock_objects.append(new_mock)
190
"""Set all mock objects to replay mode."""
192
for mock_obj in self._mock_objects:
197
"""Call verify on all mock objects created."""
199
for mock_obj in self._mock_objects:
203
"""Call reset on all mock objects. This does not unset stubs."""
205
for mock_obj in self._mock_objects:
208
def StubOutWithMock(self, obj, attr_name, use_mock_anything=False):
209
"""Replace a method, attribute, etc. with a Mock.
211
This will replace a class or module with a MockObject, and everything else
212
(method, function, etc) with a MockAnything. This can be overridden to
213
always use a MockAnything by setting use_mock_anything to True.
216
obj: A Python object (class, module, instance, callable).
217
attr_name: str. The name of the attribute to replace with a mock.
218
use_mock_anything: bool. True if a MockAnything should be used regardless
219
of the type of attribute.
222
attr_to_replace = getattr(obj, attr_name)
223
if type(attr_to_replace) in self._USE_MOCK_OBJECT and not use_mock_anything:
224
stub = self.CreateMock(attr_to_replace)
226
stub = self.CreateMockAnything()
228
self.stubs.Set(obj, attr_name, stub)
230
def UnsetStubs(self):
  """Restore stubbed attributes to their original state.

  All work is delegated to the stubout helper held in self.stubs, whose
  UnsetAll() undoes the replacements made through it.
  """
  stub_registry = self.stubs
  stub_registry.UnsetAll()
236
"""Put mocks into Replay mode.
239
# args is any number of mocks to put into replay mode.
250
# args is any number of mocks to be verified.
261
# args is any number of mocks to be reset.
269
"""A mock that can be used to mock anything.
271
This is helpful for mocking classes that do not provide a public interface.
278
def __getattr__(self, method_name):
279
"""Intercept method calls on this object.
281
A new MockMethod is returned that is aware of the MockAnything's
282
state (record or replay). The call will be recorded or replayed
283
by the MockMethod's __call__.
286
# method name: the name of the method being called.
290
A new MockMethod aware of MockAnything's state (record or replay).
293
return self._CreateMockMethod(method_name)
295
def _CreateMockMethod(self, method_name):
296
"""Create a new mock method call and return it.
299
# method name: the name of the method being called.
303
A new MockMethod aware of MockAnything's state (record or replay).
306
return MockMethod(method_name, self._expected_calls_queue,
309
def __nonzero__(self):
310
"""Return 1 for nonzero so the mock can be used as a conditional."""
314
def __eq__(self, rhs):
  """Compare two MockAnything objects by mode and recorded expectations.

  rhs must also be a MockAnything, be in the same record/replay mode, and
  hold an expected-call queue that compares equal to this one.
  """
  if not isinstance(rhs, MockAnything):
    return False
  same_mode = self._replay_mode == rhs._replay_mode
  return same_mode and self._expected_calls_queue == rhs._expected_calls_queue
321
def __ne__(self, rhs):
  """Return the negation of ==, keeping != consistent with __eq__."""
  equal = (self == rhs)
  return not equal
327
"""Start replaying expected method calls."""
329
self._replay_mode = True
332
"""Verify that all of the expected calls have been made.
335
ExpectedMethodCallsError: if there are still more method calls in the
339
# If the list of expected calls is not empty, raise an exception
340
if self._expected_calls_queue:
341
# The last MultipleTimesGroup is not popped from the queue.
342
if (len(self._expected_calls_queue) == 1 and
343
isinstance(self._expected_calls_queue[0], MultipleTimesGroup) and
344
self._expected_calls_queue[0].IsSatisfied()):
347
raise ExpectedMethodCallsError(self._expected_calls_queue)
350
"""Reset the state of this mock to record mode with an empty queue."""
352
# Maintain a list of method calls we are expecting
353
self._expected_calls_queue = deque()
355
# Make sure we are in setup mode, not replay mode
356
self._replay_mode = False
359
class MockObject(MockAnything, object):
360
"""A mock object that simulates the public/protected interface of a class."""
362
def __init__(self, class_to_mock):
363
"""Initialize a mock object.
365
This determines the methods and properties of the class and stores them.
368
# class_to_mock: class to be mocked
372
# This is used to hack around the mixin/inheritance of MockAnything, which
373
# is not a proper object (it can be anything. :-)
374
MockAnything.__dict__['__init__'](self)
376
# Get a list of all the public and special methods we should mock.
377
self._known_methods = set()
378
self._known_vars = set()
379
self._class_to_mock = class_to_mock
380
for method in dir(class_to_mock):
381
if callable(getattr(class_to_mock, method)):
382
self._known_methods.add(method)
384
self._known_vars.add(method)
386
def __getattr__(self, name):
387
"""Intercept attribute request on this object.
389
If the attribute is a public class variable, it will be returned and not
392
If the attribute is not a variable, it is handled like a method
393
call. The method name is checked against the set of mockable
394
methods, and a new MockMethod is returned that is aware of the
395
MockObject's state (record or replay). The call will be recorded
396
or replayed by the MockMethod's __call__.
399
# name: the name of the attribute being requested.
403
Either a class variable or a new MockMethod that is aware of the state
404
of the mock (record or replay).
407
UnknownMethodCallError if the MockObject does not mock the requested
411
if name in self._known_vars:
412
return getattr(self._class_to_mock, name)
414
if name in self._known_methods:
415
return self._CreateMockMethod(name)
417
raise UnknownMethodCallError(name)
419
def __eq__(self, rhs):
  """Compare two MockObjects by mocked class, mode and expectations.

  rhs must be a MockObject that mocks an equal class, is in the same
  record/replay mode, and has an equal expected-call queue.
  """
  if not isinstance(rhs, MockObject):
    return False
  return (self._class_to_mock == rhs._class_to_mock
          and self._replay_mode == rhs._replay_mode
          and self._expected_calls_queue == rhs._expected_calls_queue)
427
def __setitem__(self, key, value):
428
"""Provide custom logic for mocking classes that support item assignment.
431
key: Key to set the value for.
435
Expected return value in replay mode. A MockMethod object for the
436
__setitem__ method that has already been called if not in replay mode.
439
TypeError if the underlying class does not support item assignment.
440
UnexpectedMethodCallError if the object does not expect the call to
444
setitem = self._class_to_mock.__dict__.get('__setitem__', None)
446
# Verify the class supports item assignment.
448
raise TypeError('object does not support item assignment')
450
# If we are in replay mode then simply call the mock __setitem__ method.
451
if self._replay_mode:
452
return MockMethod('__setitem__', self._expected_calls_queue,
453
self._replay_mode)(key, value)
456
# Otherwise, create a mock method __setitem__.
457
return self._CreateMockMethod('__setitem__')(key, value)
459
def __getitem__(self, key):
460
"""Provide custom logic for mocking classes that are subscriptable.
463
key: Key to return the value for.
466
Expected return value in replay mode. A MockMethod object for the
467
__getitem__ method that has already been called if not in replay mode.
470
TypeError if the underlying class is not subscriptable.
471
UnexpectedMethodCallError if the object does not expect the call to
475
getitem = self._class_to_mock.__dict__.get('__getitem__', None)
477
# Verify the class supports item assignment.
479
raise TypeError('unsubscriptable object')
481
# If we are in replay mode then simply call the mock __getitem__ method.
482
if self._replay_mode:
483
return MockMethod('__getitem__', self._expected_calls_queue,
484
self._replay_mode)(key)
487
# Otherwise, create a mock method __getitem__.
488
return self._CreateMockMethod('__getitem__')(key)
490
def __call__(self, *params, **named_params):
491
"""Provide custom logic for mocking classes that are callable."""
493
# Verify the class we are mocking is callable
494
callable = self._class_to_mock.__dict__.get('__call__', None)
496
raise TypeError('Not callable')
498
# Because the call is happening directly on this object instead of a method,
499
# the call on the mock method is made right here
500
mock_method = self._CreateMockMethod('__call__')
501
return mock_method(*params, **named_params)
505
"""Return the class that is being mocked."""
507
return self._class_to_mock
510
class MockMethod(object):
511
"""Callable mock method.
513
A MockMethod should act exactly like the method it mocks, accepting parameters
514
and returning a value, or throwing an exception (as specified). When this
515
method is called, it can optionally verify whether the called method (name and
516
signature) matches the expected method.
519
def __init__(self, method_name, call_queue, replay_mode):
520
"""Construct a new mock method.
523
# method_name: the name of the method
524
# call_queue: deque of calls, verify this call against the head, or add
525
# this call to the queue.
526
# replay_mode: False if we are recording, True if we are verifying calls
527
# against the call queue.
529
call_queue: list or deque
533
self._name = method_name
534
self._call_queue = call_queue
535
if not isinstance(call_queue, deque):
536
self._call_queue = deque(self._call_queue)
537
self._replay_mode = replay_mode
540
self._named_params = None
541
self._return_value = None
542
self._exception = None
543
self._side_effects = None
545
def __call__(self, *params, **named_params):
546
"""Log parameters and return the specified return value.
548
If the Mock(Anything/Object) associated with this call is in record mode,
549
this MockMethod will be pushed onto the expected call queue. If the mock
550
is in replay mode, this will pop a MockMethod off the top of the queue and
551
verify this call is equal to the expected call.
554
UnexpectedMethodCall if this call is supposed to match an expected method
555
call and it does not.
558
self._params = params
559
self._named_params = named_params
561
if not self._replay_mode:
562
self._call_queue.append(self)
565
expected_method = self._VerifyMethodCall()
567
if expected_method._side_effects:
568
expected_method._side_effects(*params, **named_params)
570
if expected_method._exception:
571
raise expected_method._exception
573
return expected_method._return_value
575
def __getattr__(self, name):
576
"""Raise an AttributeError with a helpful message."""
578
raise AttributeError('MockMethod has no attribute "%s". '
579
'Did you remember to put your mocks in replay mode?' % name)
581
def _PopNextMethod(self):
582
"""Pop the next method from our call queue."""
584
return self._call_queue.popleft()
586
raise UnexpectedMethodCallError(self, None)
588
def _VerifyMethodCall(self):
589
"""Verify the called method is expected.
591
This can be an ordered method, or part of an unordered set.
594
The expected mock method.
597
UnexpectedMethodCall if the method called was not expected.
600
expected = self._PopNextMethod()
602
# Loop here, because we might have a MethodGroup followed by another
604
while isinstance(expected, MethodGroup):
605
expected, method = expected.MethodCalled(self)
606
if method is not None:
609
# This is a mock method, so just check equality.
611
raise UnexpectedMethodCallError(self, expected)
617
[repr(p) for p in self._params or []] +
618
['%s=%r' % x for x in sorted((self._named_params or {}).items())])
619
desc = "%s(%s) -> %r" % (self._name, params, self._return_value)
622
def __eq__(self, rhs):
  """Tell whether rhs describes the same call as this MockMethod.

  Two MockMethods match when they share a name, positional parameters and
  keyword parameters. Return values, exceptions and side effects are not
  part of the comparison.

  Args:
    # rhs: the right hand side of the test
    rhs: MockMethod
  """
  if not isinstance(rhs, MockMethod):
    return False
  return (self._name == rhs._name
          and self._params == rhs._params
          and self._named_params == rhs._named_params)
635
def __ne__(self, rhs):
  """Test whether this MockMethod is not equivalent to another MockMethod.

  This is the negation of __eq__.

  Args:
    # rhs: the right hand side of the test
    rhs: MockMethod
  """
  is_equal = self == rhs
  return not is_equal
645
def GetPossibleGroup(self):
646
"""Returns a possible group from the end of the call queue or None if no
647
other methods are on the stack.
650
# Remove this method from the tail of the queue so we can add it to a group.
651
this_method = self._call_queue.pop()
652
assert this_method == self
654
# Determine if the tail of the queue is a group, or just a regular ordered
658
group = self._call_queue[-1]
664
def _CheckAndCreateNewGroup(self, group_name, group_class):
665
"""Checks if the last method (a possible group) is an instance of our
666
group_class. Adds the current method to this group or creates a new one.
670
group_name: the name of the group.
671
group_class: the class used to create instance of this new group
673
group = self.GetPossibleGroup()
675
# If this is a group, and it is the correct group, add the method.
676
if isinstance(group, group_class) and group.group_name() == group_name:
677
group.AddMethod(self)
680
# Create a new group and add the method.
681
new_group = group_class(group_name)
682
new_group.AddMethod(self)
683
self._call_queue.append(new_group)
686
def InAnyOrder(self, group_name="default"):
687
"""Move this method into a group of unordered calls.
689
A group of unordered calls must be defined together, and must be executed
690
in full before the next expected method can be called. There can be
691
multiple groups that are expected serially, if they are given
692
different group names. The same group name can be reused if there is a
693
standard method call, or a group with a different name, spliced between
697
group_name: the name of the unordered group.
702
return self._CheckAndCreateNewGroup(group_name, UnorderedGroup)
704
def MultipleTimes(self, group_name="default"):
705
"""Move this method into group of calls which may be called multiple times.
707
A group of repeating calls must be defined together, and must be executed in
708
full before the next expected method can be called.
711
group_name: the name of the unordered group.
716
return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup)
718
def AndReturn(self, return_value):
719
"""Set the value to return when this method is called.
722
# return_value can be anything.
725
self._return_value = return_value
728
def AndRaise(self, exception):
729
"""Set the exception to raise when this method is called.
732
# exception: the exception to raise when this method is called.
736
self._exception = exception
738
def WithSideEffects(self, side_effects):
739
"""Set the side effects that are simulated when this method is called.
742
side_effects: A callable which modifies the parameters or other relevant
743
state which a given test case depends on.
746
Self for chaining with AndReturn and AndRaise.
748
self._side_effects = side_effects
752
"""Base class for all Mox comparators.
754
A Comparator can be used as a parameter to a mocked method when the exact
755
value is not known. For example, the code you are testing might build up a
756
long SQL string that is passed to your mock DAO. You're only interested that
757
the IN clause contains the proper primary keys, so you can set your mock
760
mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
762
Now whatever query is passed in must contain the string 'IN (1, 2, 4, 5)'.
764
A Comparator may replace one or more parameters, for example:
765
# return at most 10 rows
766
mock_dao.RunQuery(StrContains('SELECT'), 10)
770
# Return some non-deterministic number of rows
771
mock_dao.RunQuery(StrContains('SELECT'), IsA(int))
774
def equals(self, rhs):
  """Special equals method that all comparators must implement.

  Comparator.__eq__ and __ne__ delegate to this hook, so every concrete
  comparator has to override it.

  Args:
    rhs: any python object

  Raises:
    NotImplementedError: always, in this base class.
  """
  # Call form instead of the Python-2-only "raise E, msg" syntax; this form
  # is valid on both Python 2 and Python 3 and raises the same message.
  raise NotImplementedError('method must be implemented by a subclass.')
783
def __eq__(self, rhs):
  """Delegate == to the comparator's equals() hook."""
  matches = self.equals(rhs)
  return matches
786
def __ne__(self, rhs):
  """Delegate != to the negation of the comparator's equals() hook."""
  matches = self.equals(rhs)
  return not matches
790
class IsA(Comparator):
791
"""This class wraps a basic Python type or class. It is used to verify
792
that a parameter is of the given type or class.
795
mock_dao.Connect(IsA(DbConnectInfo))
798
def __init__(self, class_name):
802
class_name: basic python type or a class
805
self._class_name = class_name
807
def equals(self, rhs):
808
"""Check to see if the RHS is an instance of class_name.
811
# rhs: the right hand side of the test
819
return isinstance(rhs, self._class_name)
821
# Check raw types if there was a type error. This is helpful for
822
# things like cStringIO.StringIO.
823
return type(rhs) == type(self._class_name)
826
return str(self._class_name)
828
class IsAlmost(Comparator):
829
"""Comparison class used to check whether a parameter is nearly equal
830
to a given value. Generally useful for floating point numbers.
832
Example mock_dao.SetTimeout((IsAlmost(3.9)))
835
def __init__(self, float_value, places=7):
836
"""Initialize IsAlmost.
839
float_value: The value for making the comparison.
840
places: The number of decimal places to round to.
843
self._float_value = float_value
844
self._places = places
846
def equals(self, rhs):
847
"""Check to see if RHS is almost equal to float_value
850
rhs: the value to compare to float_value
857
return round(rhs-self._float_value, self._places) == 0
859
# This is probably because either float_value or rhs is not a number.
863
return str(self._float_value)
865
class StrContains(Comparator):
866
"""Comparison class used to check whether a substring exists in a
867
string parameter. This can be useful in mocking a database with SQL
868
passed in as a string parameter, for example.
871
mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
874
def __init__(self, search_string):
878
# search_string: the string you are searching for
882
self._search_string = search_string
884
def equals(self, rhs):
885
"""Check to see if the search_string is contained in the rhs string.
888
# rhs: the right hand side of the test
896
return rhs.find(self._search_string) > -1
901
return '<str containing \'%s\'>' % self._search_string
904
class Regex(Comparator):
905
"""Checks if a string matches a regular expression.
907
This uses a given regular expression to determine equality.
910
def __init__(self, pattern, flags=0):
914
# pattern is the regular expression to search for
916
# flags passed to re.compile function as the second argument
920
self.regex = re.compile(pattern, flags=flags)
922
def equals(self, rhs):
923
"""Check to see if rhs matches regular expression pattern.
929
return self.regex.search(rhs) is not None
932
s = '<regular expression \'%s\'' % self.regex.pattern
934
s += ', flags=%d' % self.regex.flags
939
class In(Comparator):
940
"""Checks whether an item (or key) is in a list (or dict) parameter.
943
mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result)
946
def __init__(self, key):
950
# key is any thing that could be in a list or a key in a dict
955
def equals(self, rhs):
956
"""Check to see whether key is in rhs.
965
return self._key in rhs
968
return '<sequence or map containing \'%s\'>' % self._key
971
class ContainsKeyValue(Comparator):
972
"""Checks whether a key/value pair is in a dict parameter.
975
mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info))
978
def __init__(self, key, value):
982
# key: a key in a dict
983
# value: the corresponding value
989
def equals(self, rhs):
990
"""Check whether the given key/value pair is in the rhs dict.
997
return rhs[self._key] == self._value
1002
return '<map containing the entry \'%s: %s\'>' % (self._key, self._value)
1005
class SameElementsAs(Comparator):
1006
"""Checks whether iterables contain the same elements (ignoring order).
1009
mock_dao.ProcessUsers(SameElementsAs('stevepm', 'salomaki'))
1012
def __init__(self, expected_seq):
1016
expected_seq: a sequence
1019
self._expected_seq = expected_seq
1021
def equals(self, actual_seq):
1022
"""Check to see whether actual_seq has same elements as expected_seq.
1025
actual_seq: sequence
1032
expected = dict([(element, None) for element in self._expected_seq])
1033
actual = dict([(element, None) for element in actual_seq])
1035
# Fall back to slower list-compare if any of the objects are unhashable.
1036
expected = list(self._expected_seq)
1037
actual = list(actual_seq)
1040
return expected == actual
1043
return '<sequence with same elements as \'%s\'>' % self._expected_seq
1046
class And(Comparator):
1047
"""Evaluates one or more Comparators on RHS and returns an AND of the results.
1050
def __init__(self, *args):
1054
*args: One or more Comparator
1057
self._comparators = args
1059
def equals(self, rhs):
1060
"""Checks whether all Comparators are equal to rhs.
1063
# rhs: can be anything
1069
for comparator in self._comparators:
1070
if not comparator.equals(rhs):
1076
return '<AND %s>' % str(self._comparators)
1079
class Or(Comparator):
1080
"""Evaluates one or more Comparators on RHS and returns an OR of the results.
1083
def __init__(self, *args):
1087
*args: One or more Mox comparators
1090
self._comparators = args
1092
def equals(self, rhs):
1093
"""Checks whether any Comparator is equal to rhs.
1096
# rhs: can be anything
1102
for comparator in self._comparators:
1103
if comparator.equals(rhs):
1109
return '<OR %s>' % str(self._comparators)
1112
class Func(Comparator):
1113
"""Call a function that should verify the parameter passed in is correct.
1115
You may need the ability to perform more advanced operations on the parameter
1116
in order to validate it. You can use this to have a callable validate any
1117
parameter. The callable should return either True or False.
1122
def myParamValidator(param):
1123
# Advanced logic here
1126
mock_dao.DoSomething(Func(myParamValidator), true)
1129
def __init__(self, func):
1133
func: callable that takes one parameter and returns a bool
1138
def equals(self, rhs):
1139
"""Test whether rhs passes the function test.
1141
rhs is passed into func.
1144
rhs: any python object
1147
the result of func(rhs)
1150
return self._func(rhs)
1153
return str(self._func)
1156
class IgnoreArg(Comparator):
1157
"""Ignore an argument.
1159
This can be used when we don't care about an argument of a method call.
1162
# Check if CastMagic is called with 3 as first arg and 'disappear' as third.
1163
mymock.CastMagic(3, IgnoreArg(), 'disappear')
1166
def equals(self, unused_rhs):
1167
"""Ignores arguments and returns True.
1170
unused_rhs: any python object
1179
return '<IgnoreArg>'
1182
class MethodGroup(object):
1183
"""Base class containing common behaviour for MethodGroups."""
1185
def __init__(self, group_name):
  """Remember the name that identifies this group of expected calls.

  Args:
    group_name: the group's name, later returned by group_name().
  """
  self._group_name = group_name
1188
def group_name(self):
  """Return the name this group was created with."""
  name = self._group_name
  return name
1192
return '<%s "%s">' % (self.__class__.__name__, self._group_name)
1194
def AddMethod(self, mock_method):
  """Add mock_method to this group; concrete groups must override this.

  Raises:
    NotImplementedError: always, in this base class.
  """
  raise NotImplementedError()
1197
def MethodCalled(self, mock_method):
  """Record that mock_method was called; concrete groups must override this.

  Raises:
    NotImplementedError: always, in this base class.
  """
  raise NotImplementedError()
1200
def IsSatisfied(self):
  """Report whether the group's expectations are met; subclasses override.

  Raises:
    NotImplementedError: always, in this base class.
  """
  raise NotImplementedError()
1203
class UnorderedGroup(MethodGroup):
1204
"""UnorderedGroup holds a set of method calls that may occur in any order.
1206
This construct is helpful for non-deterministic events, such as iterating
1207
over the keys of a dict.
1210
def __init__(self, group_name):
1211
super(UnorderedGroup, self).__init__(group_name)
1214
def AddMethod(self, mock_method):
1215
"""Add a method to this group.
1218
mock_method: A mock method to be added to this group.
1221
self._methods.append(mock_method)
1223
def MethodCalled(self, mock_method):
1224
"""Remove a method call from the group.
1226
If the method is not in the set, an UnexpectedMethodCallError will be
1230
mock_method: a mock method that should be equal to a method in the group.
1233
The mock method from the group
1236
UnexpectedMethodCallError if the mock_method was not in the group.
1239
# Check to see if this method exists, and if so, remove it from the set
1241
for method in self._methods:
1242
if method == mock_method:
1243
# Remove the called mock_method instead of the method in the group.
1244
# The called method will match any comparators when equality is checked
1245
# during removal. The method in the group could pass a comparator to
1246
# another comparator during the equality check.
1247
self._methods.remove(mock_method)
1249
# If this group is not empty, put it back at the head of the queue.
1250
if not self.IsSatisfied():
1251
mock_method._call_queue.appendleft(self)
1255
raise UnexpectedMethodCallError(mock_method, self)
1257
def IsSatisfied(self):
  """Return True once every method in this unordered group has been called.

  MethodCalled removes each matched call from self._methods, so the group is
  satisfied exactly when that collection has no entries left.
  """
  return not len(self._methods)
1263
class MultipleTimesGroup(MethodGroup):
1264
"""MultipleTimesGroup holds methods that may be called any number of times.
1266
Note: Each method must be called at least once.
1268
This is helpful, if you don't know or care how many times a method is called.
1271
def __init__(self, group_name):
  """Create an empty group of calls that may repeat.

  Args:
    group_name: the name identifying this group.
  """
  super(MultipleTimesGroup, self).__init__(group_name)
  # self._methods: every method expected in the group.
  # self._methods_called: the calls actually seen so far.
  # IsSatisfied compares these two sets.
  self._methods, self._methods_called = set(), set()
1276
def AddMethod(self, mock_method):
1277
"""Add a method to this group.
1280
mock_method: A mock method to be added to this group.
1283
self._methods.add(mock_method)
1285
def MethodCalled(self, mock_method):
1286
"""Remove a method call from the group.
1288
If the method is not in the set, an UnexpectedMethodCallError will be
1292
mock_method: a mock method that should be equal to a method in the group.
1295
The mock method from the group
1298
UnexpectedMethodCallError if the mock_method was not in the group.
1301
# Check to see if this method exists, and if so add it to the set of
1304
for method in self._methods:
1305
if method == mock_method:
1306
self._methods_called.add(mock_method)
1307
# Always put this group back on top of the queue, because we don't know
1309
mock_method._call_queue.appendleft(self)
1312
if self.IsSatisfied():
1313
next_method = mock_method._PopNextMethod();
1314
return next_method, None
1316
raise UnexpectedMethodCallError(mock_method, self)
1318
def IsSatisfied(self):
1319
"""Return True if all methods in this group are called at least once."""
1320
# NOTE(psycho): We can't use the simple set difference here because we want
1321
# to match different parameters which are considered the same e.g. IsA(str)
1322
# and some string. This solution is O(n^2) but n should be small.
1323
tmp = self._methods.copy()
1324
for called in self._methods_called:
1325
for expected in tmp:
1326
if called == expected:
1327
tmp.remove(expected)
1334
class MoxMetaTestBase(type):
1335
"""Metaclass to add mox cleanup and verification to every test.
1337
As the mox unit testing class is being constructed (MoxTestBase or a
1338
subclass), this metaclass will modify all test functions to call the
1339
CleanUpMox method of the test class after they finish. This means that
1340
unstubbing and verifying will happen for every test with no additional code,
1341
and any failures will result in test failures as opposed to errors.
1344
def __init__(cls, name, bases, d):
1345
type.__init__(cls, name, bases, d)
1347
# also get all the attributes from the base classes to account
1348
# for a case when test class is not the immediate child of MoxTestBase
1350
for attr_name in dir(base):
1351
d[attr_name] = getattr(base, attr_name)
1353
for func_name, func in d.items():
1354
if func_name.startswith('test') and callable(func):
1355
setattr(cls, func_name, MoxMetaTestBase.CleanUpTest(cls, func))
1358
def CleanUpTest(cls, func):
1359
"""Adds Mox cleanup code to any MoxTestBase method.
1361
Always unsets stubs after a test. Will verify all mocks for tests that
1365
cls: MoxTestBase or subclass; the class whose test method we are altering.
1366
func: method; the method of the MoxTestBase test class we wish to alter.
1369
The modified method.
1371
def new_method(self, *args, **kwargs):
1372
mox_obj = getattr(self, 'mox', None)
1374
if mox_obj and isinstance(mox_obj, Mox):
1377
func(self, *args, **kwargs)
1380
mox_obj.UnsetStubs()
1383
new_method.__name__ = func.__name__
1384
new_method.__doc__ = func.__doc__
1385
new_method.__module__ = func.__module__
1389
class MoxTestBase(unittest.TestCase):
1390
"""Convenience test class to make stubbing easier.
1392
Sets up a "mox" attribute which is an instance of Mox - any mox tests will
1393
want this. Also automatically unsets any stubs and verifies that all mock
1394
methods have been called at the end of each test, eliminating boilerplate
1398
__metaclass__ = MoxMetaTestBase