~ubuntu-branches/ubuntu/wily/pymongo/wily-proposed

« back to all changes in this revision

Viewing changes to test/__init__.py

  • Committer: Package Import Robot
  • Author(s): Federico Ceratto
  • Date: 2015-04-26 22:43:13 UTC
  • mfrom: (24.1.5 sid)
  • Revision ID: package-import@ubuntu.com-20150426224313-0hga2jphvf0rrmfe
Tags: 3.0.1-1
New upstream release.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright 2010-2014 MongoDB, Inc.
 
1
# Copyright 2010-2015 MongoDB, Inc.
2
2
#
3
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
4
# you may not use this file except in compliance with the License.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
 
15
 
"""Clean up databases after running `nosetests`.
 
15
"""Test suite for pymongo, bson, and gridfs.
16
16
"""
17
17
 
18
18
import os
 
19
import socket
 
20
import sys
 
21
from pymongo.common import partition_node
 
22
 
 
23
if sys.version_info[:2] == (2, 6):
 
24
    import unittest2 as unittest
 
25
    from unittest2 import SkipTest
 
26
else:
 
27
    import unittest
 
28
    from unittest import SkipTest
19
29
import warnings
20
30
 
 
31
from functools import wraps
 
32
 
21
33
import pymongo
22
 
from pymongo.errors import ConnectionFailure
23
 
 
24
 
# hostnames retrieved by MongoReplicaSetClient from isMaster will be of unicode
25
 
# type in Python 2, so ensure these hostnames are unicodes, too. It makes tests
26
 
# like `test_repr` predictable.
27
 
host = unicode(os.environ.get("DB_IP", 'localhost'))
 
34
import pymongo.errors
 
35
 
 
36
from bson.py3compat import _unicode
 
37
from pymongo import common
 
38
from test.version import Version
 
39
 
 
40
# hostnames retrieved from isMaster will be of unicode type in Python 2,
 
41
# so ensure these hostnames are unicodes, too. It makes tests like
 
42
# `test_repr` predictable.
 
43
host = _unicode(os.environ.get("DB_IP", 'localhost'))
28
44
port = int(os.environ.get("DB_PORT", 27017))
29
45
pair = '%s:%d' % (host, port)
30
46
 
31
 
host2 = unicode(os.environ.get("DB_IP2", 'localhost'))
 
47
host2 = _unicode(os.environ.get("DB_IP2", 'localhost'))
32
48
port2 = int(os.environ.get("DB_PORT2", 27018))
33
49
 
34
 
host3 = unicode(os.environ.get("DB_IP3", 'localhost'))
 
50
host3 = _unicode(os.environ.get("DB_IP3", 'localhost'))
35
51
port3 = int(os.environ.get("DB_PORT3", 27019))
36
52
 
37
 
# Make sure warnings are always raised, regardless of
38
 
# python version.
 
53
db_user = _unicode(os.environ.get("DB_USER", "user"))
 
54
db_pwd = _unicode(os.environ.get("DB_PASSWORD", "password"))
 
55
 
 
56
 
 
57
class client_knobs(object):
 
58
    def __init__(
 
59
            self,
 
60
            heartbeat_frequency=None,
 
61
            kill_cursor_frequency=None):
 
62
        self.heartbeat_frequency = heartbeat_frequency
 
63
        self.kill_cursor_frequency = kill_cursor_frequency
 
64
 
 
65
        self.old_heartbeat_frequency = None
 
66
        self.old_kill_cursor_frequency = None
 
67
 
 
68
    def enable(self):
 
69
        self.old_heartbeat_frequency = common.HEARTBEAT_FREQUENCY
 
70
        self.old_kill_cursor_frequency = common.KILL_CURSOR_FREQUENCY
 
71
 
 
72
        if self.heartbeat_frequency is not None:
 
73
            common.HEARTBEAT_FREQUENCY = self.heartbeat_frequency
 
74
 
 
75
        if self.kill_cursor_frequency is not None:
 
76
            common.KILL_CURSOR_FREQUENCY = self.kill_cursor_frequency
 
77
 
 
78
    def __enter__(self):
 
79
        self.enable()
 
80
 
 
81
    def disable(self):
 
82
        common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency
 
83
        common.KILL_CURSOR_FREQUENCY = self.old_kill_cursor_frequency
 
84
 
 
85
    def __exit__(self, exc_type, exc_val, exc_tb):
 
86
        self.disable()
 
87
 
 
88
 
 
89
class ClientContext(object):
 
90
 
 
91
    def __init__(self):
 
92
        """Create a client and grab essential information from the server."""
 
93
        self.connected = False
 
94
        self.ismaster = {}
 
95
        self.w = None
 
96
        self.nodes = set()
 
97
        self.replica_set_name = None
 
98
        self.rs_client = None
 
99
        self.cmd_line = None
 
100
        self.version = Version(-1)  # Needs to be comparable with Version
 
101
        self.auth_enabled = False
 
102
        self.test_commands_enabled = False
 
103
        self.is_mongos = False
 
104
        self.is_rs = False
 
105
        self.has_ipv6 = False
 
106
 
 
107
        try:
 
108
            client = pymongo.MongoClient(host, port,
 
109
                                         serverSelectionTimeoutMS=100)
 
110
            client.admin.command('ismaster')  # Can we connect?
 
111
            
 
112
            # If so, then reset client to defaults.
 
113
            self.client = pymongo.MongoClient(host, port)
 
114
 
 
115
        except pymongo.errors.ConnectionFailure:
 
116
            self.client = None
 
117
        else:
 
118
            self.connected = True
 
119
            self.ismaster = self.client.admin.command('ismaster')
 
120
            self.w = len(self.ismaster.get("hosts", [])) or 1
 
121
            self.nodes = set([(host, port)])
 
122
            self.replica_set_name = self.ismaster.get('setName', '')
 
123
            self.rs_client = None
 
124
            self.version = Version.from_client(self.client)
 
125
            if self.replica_set_name:
 
126
                self.is_rs = True
 
127
                self.rs_client = pymongo.MongoClient(
 
128
                    pair, replicaSet=self.replica_set_name)
 
129
 
 
130
                self.nodes = set([partition_node(node)
 
131
                                  for node in self.ismaster.get('hosts', [])])
 
132
 
 
133
            self.rs_or_standalone_client = self.rs_client or self.client
 
134
 
 
135
            try:
 
136
                self.cmd_line = self.client.admin.command('getCmdLineOpts')
 
137
            except pymongo.errors.OperationFailure as e:
 
138
                msg = e.details.get('errmsg', '')
 
139
                if e.code == 13 or 'unauthorized' in msg or 'login' in msg:
 
140
                    # Unauthorized.
 
141
                    self.auth_enabled = True
 
142
                else:
 
143
                    raise
 
144
            else:
 
145
                self.auth_enabled = self._server_started_with_auth()
 
146
 
 
147
            if self.auth_enabled:
 
148
                # See if db_user already exists.
 
149
                self.user_provided = self._check_user_provided()
 
150
                if not self.user_provided:
 
151
                    roles = {}
 
152
                    if self.version.at_least(2, 5, 3, -1):
 
153
                        roles = {'roles': ['root']}
 
154
                    self.client.admin.add_user(db_user, db_pwd, **roles)
 
155
                    self.client.admin.authenticate(db_user, db_pwd)
 
156
 
 
157
                if self.rs_client:
 
158
                    self.rs_client.admin.authenticate(db_user, db_pwd)
 
159
 
 
160
                # May not have this if OperationFailure was raised earlier.
 
161
                self.cmd_line = self.client.admin.command('getCmdLineOpts')
 
162
 
 
163
            if 'enableTestCommands=1' in self.cmd_line['argv']:
 
164
                self.test_commands_enabled = True
 
165
            elif 'parsed' in self.cmd_line:
 
166
                params = self.cmd_line['parsed'].get('setParameter', [])
 
167
                if 'enableTestCommands=1' in params:
 
168
                    self.test_commands_enabled = True
 
169
 
 
170
            self.is_mongos = (self.ismaster.get('msg') == 'isdbgrid')
 
171
            self.has_ipv6 = self._server_started_with_ipv6()
 
172
 
 
173
    def _check_user_provided(self):
 
174
        try:
 
175
            self.client.admin.authenticate(db_user, db_pwd)
 
176
            return True
 
177
        except pymongo.errors.OperationFailure as e:
 
178
            msg = e.details.get('errmsg', '')
 
179
            if e.code == 18 or 'auth fails' in msg:
 
180
                # Auth failed.
 
181
                return False
 
182
            else:
 
183
                raise
 
184
 
 
185
    def _server_started_with_auth(self):
 
186
        # MongoDB >= 2.0
 
187
        if 'parsed' in self.cmd_line:
 
188
            parsed = self.cmd_line['parsed']
 
189
            # MongoDB >= 2.6
 
190
            if 'security' in parsed:
 
191
                security = parsed['security']
 
192
                # >= rc3
 
193
                if 'authorization' in security:
 
194
                    return security['authorization'] == 'enabled'
 
195
                # < rc3
 
196
                return (security.get('auth', False) or
 
197
                        bool(security.get('keyFile')))
 
198
            return parsed.get('auth', False) or bool(parsed.get('keyFile'))
 
199
        # Legacy
 
200
        argv = self.cmd_line['argv']
 
201
        return '--auth' in argv or '--keyFile' in argv
 
202
 
 
203
    def _server_started_with_ipv6(self):
 
204
        if not socket.has_ipv6:
 
205
            return False
 
206
 
 
207
        if 'parsed' in self.cmd_line:
 
208
            if not self.cmd_line['parsed'].get('net', {}).get('ipv6'):
 
209
                return False
 
210
        else:
 
211
            if '--ipv6' not in self.cmd_line['argv']:
 
212
                return False
 
213
 
 
214
        # The server was started with --ipv6. Is there an IPv6 route to it?
 
215
        try:
 
216
            for info in socket.getaddrinfo(host, port):
 
217
                if info[0] == socket.AF_INET6:
 
218
                    return True
 
219
        except socket.error:
 
220
            pass
 
221
 
 
222
        return False
 
223
 
 
224
    def _require(self, condition, msg, func=None):
 
225
        def make_wrapper(f):
 
226
            @wraps(f)
 
227
            def wrap(*args, **kwargs):
 
228
                # Always raise SkipTest if we can't connect to MongoDB
 
229
                if not self.connected:
 
230
                    raise SkipTest("Cannot connect to MongoDB on %s" % pair)
 
231
                if condition:
 
232
                    return f(*args, **kwargs)
 
233
                raise SkipTest(msg)
 
234
            return wrap
 
235
 
 
236
        if func is None:
 
237
            def decorate(f):
 
238
                return make_wrapper(f)
 
239
            return decorate
 
240
        return make_wrapper(func)
 
241
 
 
242
    def require_connection(self, func):
 
243
        """Run a test only if we can connect to MongoDB."""
 
244
        return self._require(self.connected,
 
245
                             "Cannot connect to MongoDB on %s" % pair,
 
246
                             func=func)
 
247
 
 
248
    def require_version_min(self, *ver):
 
249
        """Run a test only if the server version is at least ``version``."""
 
250
        other_version = Version(*ver)
 
251
        return self._require(self.version >= other_version,
 
252
                             "Server version must be at least %s"
 
253
                             % str(other_version))
 
254
 
 
255
    def require_version_max(self, *ver):
 
256
        """Run a test only if the server version is at most ``version``."""
 
257
        other_version = Version(*ver)
 
258
        return self._require(self.version <= other_version,
 
259
                             "Server version must be at most %s"
 
260
                             % str(other_version))
 
261
 
 
262
    def require_auth(self, func):
 
263
        """Run a test only if the server is running with auth enabled."""
 
264
        return self.check_auth_with_sharding(
 
265
            self._require(self.auth_enabled,
 
266
                          "Authentication is not enabled on the server",
 
267
                          func=func))
 
268
 
 
269
    def require_no_auth(self, func):
 
270
        """Run a test only if the server is running without auth enabled."""
 
271
        return self._require(not self.auth_enabled,
 
272
                             "Authentication must not be enabled on the server",
 
273
                             func=func)
 
274
 
 
275
    def require_replica_set(self, func):
 
276
        """Run a test only if the client is connected to a replica set."""
 
277
        return self._require(self.is_rs,
 
278
                             "Not connected to a replica set",
 
279
                             func=func)
 
280
 
 
281
    def require_no_replica_set(self, func):
 
282
        """Run a test if the client is *not* connected to a replica set."""
 
283
        return self._require(
 
284
            not self.is_rs,
 
285
            "Connected to a replica set, not a standalone mongod",
 
286
            func=func)
 
287
 
 
288
    def require_ipv6(self, func):
 
289
        """Run a test only if the client can connect to a server via IPv6."""
 
290
        return self._require(self.has_ipv6,
 
291
                             "No IPv6",
 
292
                             func=func)
 
293
 
 
294
    def require_no_mongos(self, func):
 
295
        """Run a test only if the client is not connected to a mongos."""
 
296
        return self._require(not self.is_mongos,
 
297
                             "Must be connected to a mongod, not a mongos",
 
298
                             func=func)
 
299
 
 
300
    def require_mongos(self, func):
 
301
        """Run a test only if the client is connected to a mongos."""
 
302
        return self._require(self.is_mongos,
 
303
                             "Must be connected to a mongos",
 
304
                             func=func)
 
305
 
 
306
    def check_auth_with_sharding(self, func):
 
307
        """Skip a test when connected to mongos < 2.0 and running with auth."""
 
308
        condition = not (self.auth_enabled and
 
309
                         self.is_mongos and self.version < (2,))
 
310
        return self._require(condition,
 
311
                             "Auth with sharding requires MongoDB >= 2.0.0",
 
312
                             func=func)
 
313
 
 
314
    def require_test_commands(self, func):
 
315
        """Run a test only if the server has test commands enabled."""
 
316
        return self._require(self.test_commands_enabled,
 
317
                             "Test commands must be enabled",
 
318
                             func=func)
 
319
 
 
320
 
 
321
# Reusable client context
 
322
client_context = ClientContext()
 
323
 
 
324
 
 
325
class IntegrationTest(unittest.TestCase):
 
326
    """Base class for TestCases that need a connection to MongoDB to pass."""
 
327
 
 
328
    @classmethod
 
329
    @client_context.require_connection
 
330
    def setUpClass(cls):
 
331
        cls.client = client_context.rs_or_standalone_client
 
332
        cls.db = cls.client.pymongo_test
 
333
 
 
334
 
 
335
class MockClientTest(unittest.TestCase):
 
336
    """Base class for TestCases that use MockClient.
 
337
 
 
338
    This class is *not* an IntegrationTest: if properly written, MockClient
 
339
    tests do not require a running server.
 
340
 
 
341
    The class temporarily overrides HEARTBEAT_FREQUENCY to speed up tests.
 
342
    """
 
343
 
 
344
    def setUp(self):
 
345
        super(MockClientTest, self).setUp()
 
346
 
 
347
        self.client_knobs = client_knobs(
 
348
            heartbeat_frequency=0.001)
 
349
 
 
350
        self.client_knobs.enable()
 
351
 
 
352
    def tearDown(self):
 
353
        self.client_knobs.disable()
 
354
        super(MockClientTest, self).tearDown()
 
355
 
 
356
 
39
357
def setup():
40
358
    warnings.resetwarnings()
41
359
    warnings.simplefilter("always")
42
360
 
43
361
 
44
362
def teardown():
45
 
    try:
46
 
        c = pymongo.MongoClient(host, port)
47
 
    except ConnectionFailure:
48
 
        # Tests where ssl=True can cause connection failures here.
49
 
        # Ignore and continue.
50
 
        return
51
 
 
 
363
    c = client_context.client
52
364
    c.drop_database("pymongo-pooling-tests")
53
365
    c.drop_database("pymongo_test")
54
366
    c.drop_database("pymongo_test1")
55
367
    c.drop_database("pymongo_test2")
56
368
    c.drop_database("pymongo_test_mike")
57
369
    c.drop_database("pymongo_test_bernie")
 
370
    if client_context.auth_enabled and not client_context.user_provided:
 
371
        c.admin.remove_user(db_user)
 
372
 
 
373
 
 
374
class PymongoTestRunner(unittest.TextTestRunner):
 
375
    def run(self, test):
 
376
        setup()
 
377
        result = super(PymongoTestRunner, self).run(test)
 
378
        try:
 
379
            teardown()
 
380
        finally:
 
381
            return result
 
382
 
 
383
 
 
384
def test_cases(suite):
 
385
    """Iterator over all TestCases within a TestSuite."""
 
386
    for suite_or_case in suite._tests:
 
387
        if isinstance(suite_or_case, unittest.TestCase):
 
388
            # unittest.TestCase
 
389
            yield suite_or_case
 
390
        else:
 
391
            # unittest.TestSuite
 
392
            for case in test_cases(suite_or_case):
 
393
                yield case