12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
15
"""Clean up databases after running `nosetests`.
15
"""Test suite for pymongo, bson, and gridfs.
21
from pymongo.common import partition_node
23
if sys.version_info[:2] == (2, 6):
24
import unittest2 as unittest
25
from unittest2 import SkipTest
28
from unittest import SkipTest
31
from functools import wraps
22
from pymongo.errors import ConnectionFailure
24
# hostnames retrieved by MongoReplicaSetClient from isMaster will be of unicode
25
# type in Python 2, so ensure these hostnames are unicodes, too. It makes tests
26
# like `test_repr` predictable.
27
host = unicode(os.environ.get("DB_IP", 'localhost'))
36
from bson.py3compat import _unicode
37
from pymongo import common
38
from test.version import Version
40
# hostnames retrieved from isMaster will be of unicode type in Python 2,
41
# so ensure these hostnames are unicodes, too. It makes tests like
42
# `test_repr` predictable.
43
host = _unicode(os.environ.get("DB_IP", 'localhost'))
28
44
port = int(os.environ.get("DB_PORT", 27017))
29
45
pair = '%s:%d' % (host, port)
31
host2 = unicode(os.environ.get("DB_IP2", 'localhost'))
47
host2 = _unicode(os.environ.get("DB_IP2", 'localhost'))
32
48
port2 = int(os.environ.get("DB_PORT2", 27018))
34
host3 = unicode(os.environ.get("DB_IP3", 'localhost'))
50
host3 = _unicode(os.environ.get("DB_IP3", 'localhost'))
35
51
port3 = int(os.environ.get("DB_PORT3", 27019))
37
# Make sure warnings are always raised, regardless of
53
db_user = _unicode(os.environ.get("DB_USER", "user"))
54
db_pwd = _unicode(os.environ.get("DB_PASSWORD", "password"))
57
class client_knobs(object):
60
heartbeat_frequency=None,
61
kill_cursor_frequency=None):
62
self.heartbeat_frequency = heartbeat_frequency
63
self.kill_cursor_frequency = kill_cursor_frequency
65
self.old_heartbeat_frequency = None
66
self.old_kill_cursor_frequency = None
69
self.old_heartbeat_frequency = common.HEARTBEAT_FREQUENCY
70
self.old_kill_cursor_frequency = common.KILL_CURSOR_FREQUENCY
72
if self.heartbeat_frequency is not None:
73
common.HEARTBEAT_FREQUENCY = self.heartbeat_frequency
75
if self.kill_cursor_frequency is not None:
76
common.KILL_CURSOR_FREQUENCY = self.kill_cursor_frequency
82
common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency
83
common.KILL_CURSOR_FREQUENCY = self.old_kill_cursor_frequency
85
def __exit__(self, exc_type, exc_val, exc_tb):
89
class ClientContext(object):
92
"""Create a client and grab essential information from the server."""
93
self.connected = False
97
self.replica_set_name = None
100
self.version = Version(-1) # Needs to be comparable with Version
101
self.auth_enabled = False
102
self.test_commands_enabled = False
103
self.is_mongos = False
105
self.has_ipv6 = False
108
client = pymongo.MongoClient(host, port,
109
serverSelectionTimeoutMS=100)
110
client.admin.command('ismaster') # Can we connect?
112
# If so, then reset client to defaults.
113
self.client = pymongo.MongoClient(host, port)
115
except pymongo.errors.ConnectionFailure:
118
self.connected = True
119
self.ismaster = self.client.admin.command('ismaster')
120
self.w = len(self.ismaster.get("hosts", [])) or 1
121
self.nodes = set([(host, port)])
122
self.replica_set_name = self.ismaster.get('setName', '')
123
self.rs_client = None
124
self.version = Version.from_client(self.client)
125
if self.replica_set_name:
127
self.rs_client = pymongo.MongoClient(
128
pair, replicaSet=self.replica_set_name)
130
self.nodes = set([partition_node(node)
131
for node in self.ismaster.get('hosts', [])])
133
self.rs_or_standalone_client = self.rs_client or self.client
136
self.cmd_line = self.client.admin.command('getCmdLineOpts')
137
except pymongo.errors.OperationFailure as e:
138
msg = e.details.get('errmsg', '')
139
if e.code == 13 or 'unauthorized' in msg or 'login' in msg:
141
self.auth_enabled = True
145
self.auth_enabled = self._server_started_with_auth()
147
if self.auth_enabled:
148
# See if db_user already exists.
149
self.user_provided = self._check_user_provided()
150
if not self.user_provided:
152
if self.version.at_least(2, 5, 3, -1):
153
roles = {'roles': ['root']}
154
self.client.admin.add_user(db_user, db_pwd, **roles)
155
self.client.admin.authenticate(db_user, db_pwd)
158
self.rs_client.admin.authenticate(db_user, db_pwd)
160
# May not have this if OperationFailure was raised earlier.
161
self.cmd_line = self.client.admin.command('getCmdLineOpts')
163
if 'enableTestCommands=1' in self.cmd_line['argv']:
164
self.test_commands_enabled = True
165
elif 'parsed' in self.cmd_line:
166
params = self.cmd_line['parsed'].get('setParameter', [])
167
if 'enableTestCommands=1' in params:
168
self.test_commands_enabled = True
170
self.is_mongos = (self.ismaster.get('msg') == 'isdbgrid')
171
self.has_ipv6 = self._server_started_with_ipv6()
173
def _check_user_provided(self):
175
self.client.admin.authenticate(db_user, db_pwd)
177
except pymongo.errors.OperationFailure as e:
178
msg = e.details.get('errmsg', '')
179
if e.code == 18 or 'auth fails' in msg:
185
def _server_started_with_auth(self):
187
if 'parsed' in self.cmd_line:
188
parsed = self.cmd_line['parsed']
190
if 'security' in parsed:
191
security = parsed['security']
193
if 'authorization' in security:
194
return security['authorization'] == 'enabled'
196
return (security.get('auth', False) or
197
bool(security.get('keyFile')))
198
return parsed.get('auth', False) or bool(parsed.get('keyFile'))
200
argv = self.cmd_line['argv']
201
return '--auth' in argv or '--keyFile' in argv
203
def _server_started_with_ipv6(self):
204
if not socket.has_ipv6:
207
if 'parsed' in self.cmd_line:
208
if not self.cmd_line['parsed'].get('net', {}).get('ipv6'):
211
if '--ipv6' not in self.cmd_line['argv']:
214
# The server was started with --ipv6. Is there an IPv6 route to it?
216
for info in socket.getaddrinfo(host, port):
217
if info[0] == socket.AF_INET6:
224
def _require(self, condition, msg, func=None):
227
def wrap(*args, **kwargs):
228
# Always raise SkipTest if we can't connect to MongoDB
229
if not self.connected:
230
raise SkipTest("Cannot connect to MongoDB on %s" % pair)
232
return f(*args, **kwargs)
238
return make_wrapper(f)
240
return make_wrapper(func)
242
def require_connection(self, func):
243
"""Run a test only if we can connect to MongoDB."""
244
return self._require(self.connected,
245
"Cannot connect to MongoDB on %s" % pair,
248
def require_version_min(self, *ver):
249
"""Run a test only if the server version is at least ``version``."""
250
other_version = Version(*ver)
251
return self._require(self.version >= other_version,
252
"Server version must be at least %s"
253
% str(other_version))
255
def require_version_max(self, *ver):
256
"""Run a test only if the server version is at most ``version``."""
257
other_version = Version(*ver)
258
return self._require(self.version <= other_version,
259
"Server version must be at most %s"
260
% str(other_version))
262
def require_auth(self, func):
263
"""Run a test only if the server is running with auth enabled."""
264
return self.check_auth_with_sharding(
265
self._require(self.auth_enabled,
266
"Authentication is not enabled on the server",
269
def require_no_auth(self, func):
270
"""Run a test only if the server is running without auth enabled."""
271
return self._require(not self.auth_enabled,
272
"Authentication must not be enabled on the server",
275
def require_replica_set(self, func):
276
"""Run a test only if the client is connected to a replica set."""
277
return self._require(self.is_rs,
278
"Not connected to a replica set",
281
def require_no_replica_set(self, func):
282
"""Run a test if the client is *not* connected to a replica set."""
283
return self._require(
285
"Connected to a replica set, not a standalone mongod",
288
def require_ipv6(self, func):
289
"""Run a test only if the client can connect to a server via IPv6."""
290
return self._require(self.has_ipv6,
294
def require_no_mongos(self, func):
295
"""Run a test only if the client is not connected to a mongos."""
296
return self._require(not self.is_mongos,
297
"Must be connected to a mongod, not a mongos",
300
def require_mongos(self, func):
301
"""Run a test only if the client is connected to a mongos."""
302
return self._require(self.is_mongos,
303
"Must be connected to a mongos",
306
def check_auth_with_sharding(self, func):
307
"""Skip a test when connected to mongos < 2.0 and running with auth."""
308
condition = not (self.auth_enabled and
309
self.is_mongos and self.version < (2,))
310
return self._require(condition,
311
"Auth with sharding requires MongoDB >= 2.0.0",
314
def require_test_commands(self, func):
315
"""Run a test only if the server has test commands enabled."""
316
return self._require(self.test_commands_enabled,
317
"Test commands must be enabled",
321
# Reusable client context
322
client_context = ClientContext()
325
class IntegrationTest(unittest.TestCase):
326
"""Base class for TestCases that need a connection to MongoDB to pass."""
329
@client_context.require_connection
331
cls.client = client_context.rs_or_standalone_client
332
cls.db = cls.client.pymongo_test
335
class MockClientTest(unittest.TestCase):
336
"""Base class for TestCases that use MockClient.
338
This class is *not* an IntegrationTest: if properly written, MockClient
339
tests do not require a running server.
341
The class temporarily overrides HEARTBEAT_FREQUENCY to speed up tests.
345
super(MockClientTest, self).setUp()
347
self.client_knobs = client_knobs(
348
heartbeat_frequency=0.001)
350
self.client_knobs.enable()
353
self.client_knobs.disable()
354
super(MockClientTest, self).tearDown()
40
358
warnings.resetwarnings()
41
359
warnings.simplefilter("always")
46
c = pymongo.MongoClient(host, port)
47
except ConnectionFailure:
48
# Tests where ssl=True can cause connection failures here.
49
# Ignore and continue.
363
c = client_context.client
52
364
c.drop_database("pymongo-pooling-tests")
53
365
c.drop_database("pymongo_test")
54
366
c.drop_database("pymongo_test1")
55
367
c.drop_database("pymongo_test2")
56
368
c.drop_database("pymongo_test_mike")
57
369
c.drop_database("pymongo_test_bernie")
370
if client_context.auth_enabled and not client_context.user_provided:
371
c.admin.remove_user(db_user)
374
class PymongoTestRunner(unittest.TextTestRunner):
377
result = super(PymongoTestRunner, self).run(test)
384
def test_cases(suite):
385
"""Iterator over all TestCases within a TestSuite."""
386
for suite_or_case in suite._tests:
387
if isinstance(suite_or_case, unittest.TestCase):
392
for case in test_cases(suite_or_case):