~zulcss/nova/nova-precise-g3

« back to all changes in this revision

Viewing changes to .pc/upstream/0002-Stop-libvirt-test-from-deleting-instances-dir.patch/nova/tests/test_libvirt.py

  • Committer: Package Import Robot
  • Author(s): Chuck Short, Adam Gandelman, Chuck Short
  • Date: 2012-04-12 14:14:29 UTC
  • Revision ID: package-import@ubuntu.com-20120412141429-dt69y6cd5e0uqbmk
Tags: 2012.1-0ubuntu2
[ Adam Gandelman ]
* debian/rules: Properly create empty doc/build/man dir for builds that
  skip doc building
* debian/control: Set 'Conflicts: nova-compute-hypervisor' for the various
  nova-compute-$type packages. (LP: #975616)
* debian/control: Set 'Breaks: nova-api' for the various nova-api-$service
  sub-packages. (LP: #966115)

[ Chuck Short ]
* Resynchronize with stable/essex:
  - b1d11b8 Use project_id in ec2.cloud._format_image()
  - 6e988ed Fixes image publication using deprecated auth. (LP: #977765)
  - 6e988ed Populate image properties with project_id again
  - 3b14c74 Fixed bug 962840, added a test case.
  - d4e96fe Allow unprivileged RADOS users to access rbd volumes.
  - 4acfab6 Stop libvirt test from deleting instances dir
  - 155c7b2 fix bug where nova ignores glance host in imageref
* debian/nova.conf: Enabled ec2_private_dns_show_ip so that juju can
  connect to openstack instances.
* debian/patches/fix-docs-build-without-network.patch: Fix docs build
  when there is no network access.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# vim: tabstop=4 shiftwidth=4 softtabstop=4
 
2
#
 
3
#    Copyright 2010 OpenStack LLC
 
4
#
 
5
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
 
6
#    not use this file except in compliance with the License. You may obtain
 
7
#    a copy of the License at
 
8
#
 
9
#         http://www.apache.org/licenses/LICENSE-2.0
 
10
#
 
11
#    Unless required by applicable law or agreed to in writing, software
 
12
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 
13
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 
14
#    License for the specific language governing permissions and limitations
 
15
#    under the License.
 
16
 
 
17
import copy
 
18
import errno
 
19
import eventlet
 
20
import mox
 
21
import os
 
22
import re
 
23
import shutil
 
24
import sys
 
25
import tempfile
 
26
 
 
27
from xml.etree import ElementTree
 
28
from xml.dom import minidom
 
29
 
 
30
from nova import context
 
31
from nova import db
 
32
from nova import exception
 
33
from nova import flags
 
34
from nova import log as logging
 
35
from nova import test
 
36
from nova import utils
 
37
from nova.api.ec2 import cloud
 
38
from nova.compute import instance_types
 
39
from nova.compute import power_state
 
40
from nova.compute import utils as compute_utils
 
41
from nova.compute import vm_states
 
42
from nova.virt import images
 
43
from nova.virt import driver
 
44
from nova.virt import firewall as base_firewall
 
45
from nova.virt.libvirt import connection
 
46
from nova.virt.libvirt import firewall
 
47
from nova.virt.libvirt import volume
 
48
from nova.volume import driver as volume_driver
 
49
from nova.virt.libvirt import utils as libvirt_utils
 
50
from nova.tests import fake_network
 
51
from nova.tests import fake_libvirt_utils
 
52
 
 
53
 
 
54
try:
 
55
    import libvirt
 
56
    connection.libvirt = libvirt
 
57
except ImportError:
 
58
    libvirt = None
 
59
 
 
60
 
 
61
# Global nova flag registry (tests override values via self.flags()).
FLAGS = flags.FLAGS
LOG = logging.getLogger(__name__)

# Short module-level aliases for the fake-network helpers.
_fake_network_info = fake_network.fake_get_instance_nw_info
_fake_stub_out_get_nw_info = fake_network.stub_out_nw_api_get_instance_nw_info
_ipv4_like = fake_network.ipv4_like
 
67
 
 
68
 
 
69
def _concurrency(wait, done, target):
 
70
    wait.wait()
 
71
    done.send()
 
72
 
 
73
 
 
74
class FakeVirDomainSnapshot(object):
    """Minimal fake of a libvirt domain-snapshot object.

    It only remembers the domain it was created for, and deleting it
    does nothing.
    """

    def __init__(self, dom=None):
        # Domain this snapshot belongs to; None when not supplied.
        self.dom = dom

    def delete(self, flags):
        """Accept and ignore a snapshot deletion request."""
        return None
 
81
 
 
82
 
 
83
class FakeVirtDomain(object):
    """Minimal fake of a libvirt domain object.

    It always reports itself as running and hands back a fixed XML
    description, which callers can override through ``fake_xml``.
    """

    # Description used when no custom XML is given: a KVM domain with a
    # single file-backed disk.
    _DEFAULT_XML = """
                <domain type='kvm'>
                    <devices>
                        <disk type='file'>
                            <source file='filename'/>
                        </disk>
                    </devices>
                </domain>
            """

    def __init__(self, fake_xml=None):
        # Any falsy value (None, '') falls back to the default description.
        self._fake_dom_xml = fake_xml or self._DEFAULT_XML

    def name(self):
        return "fake-domain %s" % self

    def info(self):
        # Power state first; the remaining info fields are left as None.
        return [power_state.RUNNING, None, None, None, None]

    def create(self):
        """Pretend to start the domain."""

    def managedSave(self, *args):
        """Pretend to save the domain state."""

    def createWithFlags(self, launch_flags):
        """Pretend to start the domain with the given launch flags."""

    def XMLDesc(self, *args):
        """Return the configured XML description; arguments are ignored."""
        return self._fake_dom_xml
 
116
 
 
117
 
 
118
class LibvirtVolumeTestCase(test.TestCase):
    """Tests for the libvirt iSCSI and network (sheepdog/RBD) volume drivers.

    ``utils.execute`` is stubbed so every command a driver would run is
    recorded in ``self.executes`` instead of being executed.
    """

    def setUp(self):
        super(LibvirtVolumeTestCase, self).setUp()
        self.executes = []

        def fake_execute(*cmd, **kwargs):
            # Record the command and pretend it produced no output.
            self.executes.append(cmd)
            return None, None

        self.stubs.Set(utils, 'execute', fake_execute)

        class FakeLibvirtConnection(object):
            """Just enough of a libvirt connection for the volume drivers."""

            def __init__(self, hyperv="QEMU"):
                self.hyperv = hyperv

            def get_hypervisor_type(self):
                return self.hyperv

            def get_all_block_devices(self):
                return []

        self.fake_conn = FakeLibvirtConnection()
        self.connr = {
            'ip': '127.0.0.1',
            'initiator': 'fake_initiator'
        }

    @staticmethod
    def _iscsi_volume():
        """Describe the fake iSCSI volume shared by the iSCSI tests.

        Returns a ``(location, iqn, vol)`` tuple.
        """
        location = '10.0.2.15:3260'
        name = 'volume-00000001'
        iqn = 'iqn.2010-10.org.openstack:%s' % name
        vol = {'id': 1,
               'name': name,
               'provider_auth': None,
               'provider_location': '%s,fake %s' % (location, iqn)}
        return location, iqn, vol

    def _attach_and_detach_iscsi(self):
        """Connect, verify and disconnect the fake iSCSI volume.

        The volume must be described by libvirt as the LUN 0 block device
        of its target. Returns ``(location, iqn)`` for further checks.
        """
        # NOTE(vish) exists is to make driver assume connecting worked
        self.stubs.Set(os.path, 'exists', lambda x: True)
        vol_driver = volume_driver.ISCSIDriver()
        libvirt_driver = volume.LibvirtISCSIVolumeDriver(self.fake_conn)
        location, iqn, vol = self._iscsi_volume()
        connection_info = vol_driver.initialize_connection(vol, self.connr)
        mount_device = "vde"
        xml = libvirt_driver.connect_volume(connection_info, mount_device)

        tree = ElementTree.fromstring(xml)
        dev_str = '/dev/disk/by-path/ip-%s-iscsi-%s-lun-0' % (location, iqn)
        self.assertEqual(tree.get('type'), 'block')
        self.assertEqual(tree.find('./source').get('dev'), dev_str)

        libvirt_driver.disconnect_volume(connection_info, mount_device)
        vol_driver.terminate_connection(vol, self.connr)
        return location, iqn

    def test_libvirt_iscsi_driver(self):
        location, iqn = self._attach_and_detach_iscsi()
        base = ('iscsiadm', '-m', 'node', '-T', iqn, '-p', location)
        startup = ('--op', 'update', '-n', 'node.startup', '-v')
        expected_commands = [
            base,
            base + ('--login',),
            base + startup + ('automatic',),
            base + startup + ('manual',),
            base + ('--logout',),
            base + ('--op', 'delete'),
        ]
        self.assertEqual(self.executes, expected_commands)

    def test_libvirt_iscsi_driver_still_in_use(self):
        location, iqn, _vol = self._iscsi_volume()
        # Another LUN of the same target is still attached, so detaching
        # must stop after the login/startup commands (no logout/delete).
        devs = ['/dev/disk/by-path/ip-%s-iscsi-%s-lun-1' % (location, iqn)]
        self.stubs.Set(self.fake_conn, 'get_all_block_devices', lambda: devs)
        self._attach_and_detach_iscsi()
        base = ('iscsiadm', '-m', 'node', '-T', iqn, '-p', location)
        expected_commands = [
            base,
            base + ('--login',),
            base + ('--op', 'update', '-n', 'node.startup',
                    '-v', 'automatic'),
        ]
        self.assertEqual(self.executes, expected_commands)

    def test_libvirt_sheepdog_driver(self):
        vol_driver = volume_driver.SheepdogDriver()
        libvirt_driver = volume.LibvirtNetVolumeDriver(self.fake_conn)
        name = 'volume-00000001'
        vol = {'id': 1, 'name': name}
        connection_info = vol_driver.initialize_connection(vol, self.connr)
        mount_device = "vde"
        xml = libvirt_driver.connect_volume(connection_info, mount_device)
        tree = ElementTree.fromstring(xml)
        self.assertEqual(tree.get('type'), 'network')
        source = tree.find('./source')
        self.assertEqual(source.get('protocol'), 'sheepdog')
        self.assertEqual(source.get('name'), name)
        libvirt_driver.disconnect_volume(connection_info, mount_device)
        vol_driver.terminate_connection(vol, self.connr)

    def test_libvirt_rbd_driver(self):
        vol_driver = volume_driver.RBDDriver()
        libvirt_driver = volume.LibvirtNetVolumeDriver(self.fake_conn)
        name = 'volume-00000001'
        vol = {'id': 1, 'name': name}
        connection_info = vol_driver.initialize_connection(vol, self.connr)
        mount_device = "vde"
        xml = libvirt_driver.connect_volume(connection_info, mount_device)
        tree = ElementTree.fromstring(xml)
        self.assertEqual(tree.get('type'), 'network')
        source = tree.find('./source')
        self.assertEqual(source.get('protocol'), 'rbd')
        # RBD volumes are addressed as "<pool>/<volume name>".
        self.assertEqual(source.get('name'),
                         '%s/%s' % (FLAGS.rbd_pool, name))
        libvirt_driver.disconnect_volume(connection_info, mount_device)
        vol_driver.terminate_connection(vol, self.connr)

    def test_libvirt_lxc_volume(self):
        self._attach_and_detach_iscsi()
 
266
 
 
267
 
 
268
class CacheConcurrencyTestCase(test.TestCase):
    """Check how LibvirtConnection._cache_image handles concurrent calls.

    Calls for the same fname must run one after another, while calls for
    different fnames may run at the same time.
    """

    def setUp(self):
        super(CacheConcurrencyTestCase, self).setUp()
        self.flags(instances_path='nova.compute.manager')

        def fake_exists(fname):
            # Report only the "_base" image cache directory as existing.
            return fname == os.path.join(FLAGS.instances_path, '_base')

        self.stubs.Set(os.path, 'exists', fake_exists)
        # Never run external commands and never resize disks.
        self.stubs.Set(utils, 'execute', lambda *args, **kwargs: None)
        self.stubs.Set(connection.disk, 'extend', lambda image, size: None)
        connection.libvirt_utils = fake_libvirt_utils

    def tearDown(self):
        # Put back the real libvirt_utils swapped out in setUp.
        connection.libvirt_utils = libvirt_utils
        super(CacheConcurrencyTestCase, self).tearDown()

    @staticmethod
    def _start_cache(fname):
        """Run _cache_image for ``fname`` in a new green thread.

        The callback is _concurrency, which blocks on the returned ``wait``
        event and fires the returned ``done`` event once released.
        """
        wait = eventlet.event.Event()
        done = eventlet.event.Event()
        eventlet.spawn(connection.LibvirtConnection._cache_image,
                       _concurrency, 'target', fname, False, None,
                       wait, done)
        return wait, done

    def test_same_fname_concurrency(self):
        """Ensures that caching the same fname runs sequentially"""
        wait1, done1 = self._start_cache('fname')
        wait2, done2 = self._start_cache('fname')
        wait2.send()
        eventlet.sleep(0)
        try:
            # The second call must still be queued behind the first one.
            self.assertFalse(done2.ready())
        finally:
            wait1.send()
        done1.wait()
        eventlet.sleep(0)
        self.assertTrue(done2.ready())

    def test_different_fname_concurrency(self):
        """Ensures that two different fname caches are concurrent"""
        wait1, _ = self._start_cache('fname2')
        wait2, done2 = self._start_cache('fname1')
        wait2.send()
        eventlet.sleep(0)
        try:
            # The second call finishes without waiting for the first one.
            self.assertTrue(done2.ready())
        finally:
            wait1.send()
            eventlet.sleep(0)
 
333
 
 
334
 
 
335
class FakeVolumeDriver(object):
    """Volume driver stub: every operation is a no-op.

    ``get_xml`` yields an empty string, so no disk XML is contributed.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any constructor arguments."""

    def attach_volume(self, *args):
        """Pretend to attach a volume."""

    def detach_volume(self, *args):
        """Pretend to detach a volume."""

    def get_xml(self, *args):
        return ""
 
347
 
 
348
 
 
349
def missing_libvirt():
    """Tell whether the libvirt Python bindings failed to import.

    The module-level import sets ``libvirt`` to None on ImportError; this
    is used to skip tests that need the real bindings.
    """
    bindings = libvirt
    return bindings is None
 
351
 
 
352
 
 
353
class LibvirtConnTestCase(test.TestCase):
 
354
 
 
355
    def setUp(self):
 
356
        super(LibvirtConnTestCase, self).setUp()
 
357
        connection._late_load_cheetah()
 
358
        self.flags(fake_call=True)
 
359
        self.user_id = 'fake'
 
360
        self.project_id = 'fake'
 
361
        self.context = context.get_admin_context()
 
362
        self.flags(instances_path='')
 
363
        self.call_libvirt_dependant_setup = False
 
364
        connection.libvirt_utils = fake_libvirt_utils
 
365
 
 
366
        def fake_extend(image, size):
 
367
            pass
 
368
 
 
369
        self.stubs.Set(connection.disk, 'extend', fake_extend)
 
370
 
 
371
    def tearDown(self):
 
372
        connection.libvirt_utils = libvirt_utils
 
373
        super(LibvirtConnTestCase, self).tearDown()
 
374
 
 
375
    test_instance = {'memory_kb': '1024000',
 
376
                     'basepath': '/some/path',
 
377
                     'bridge_name': 'br100',
 
378
                     'vcpus': 2,
 
379
                     'project_id': 'fake',
 
380
                     'bridge': 'br101',
 
381
                     'image_ref': '155d900f-4e14-4e4c-a73d-069cbf4541e6',
 
382
                     'root_gb': 10,
 
383
                     'ephemeral_gb': 20,
 
384
                     'instance_type_id': '5'}  # m1.small
 
385
 
 
386
    def create_fake_libvirt_mock(self, **kwargs):
 
387
        """Defining mocks for LibvirtConnection(libvirt is not used)."""
 
388
 
 
389
        # A fake libvirt.virConnect
 
390
        class FakeLibvirtConnection(object):
 
391
            def defineXML(self, xml):
 
392
                return FakeVirtDomain()
 
393
 
 
394
        # Creating mocks
 
395
        volume_driver = 'iscsi=nova.tests.test_libvirt.FakeVolumeDriver'
 
396
        self.flags(libvirt_volume_drivers=[volume_driver])
 
397
        fake = FakeLibvirtConnection()
 
398
        # Customizing above fake if necessary
 
399
        for key, val in kwargs.items():
 
400
            fake.__setattr__(key, val)
 
401
 
 
402
        self.flags(image_service='nova.image.fake.FakeImageService')
 
403
        self.flags(libvirt_vif_driver="nova.tests.fake_network.FakeVIFDriver")
 
404
 
 
405
        self.mox.StubOutWithMock(connection.LibvirtConnection, '_conn')
 
406
        connection.LibvirtConnection._conn = fake
 
407
 
 
408
    def fake_lookup(self, instance_name):
 
409
        return FakeVirtDomain()
 
410
 
 
411
    def fake_execute(self, *args):
 
412
        open(args[-1], "a").close()
 
413
 
 
414
    def create_service(self, **kwargs):
 
415
        service_ref = {'host': kwargs.get('host', 'dummy'),
 
416
                       'binary': 'nova-compute',
 
417
                       'topic': 'compute',
 
418
                       'report_count': 0,
 
419
                       'availability_zone': 'zone'}
 
420
 
 
421
        return db.service_create(context.get_admin_context(), service_ref)
 
422
 
 
423
    def test_get_connector(self):
 
424
        initiator = 'fake.initiator.iqn'
 
425
        ip = 'fakeip'
 
426
        self.flags(my_ip=ip)
 
427
 
 
428
        conn = connection.LibvirtConnection(True)
 
429
        expected = {
 
430
            'ip': ip,
 
431
            'initiator': initiator
 
432
        }
 
433
        volume = {
 
434
            'id': 'fake'
 
435
        }
 
436
        result = conn.get_volume_connector(volume)
 
437
        self.assertDictMatch(expected, result)
 
438
 
 
439
    def test_preparing_xml_info(self):
 
440
        conn = connection.LibvirtConnection(True)
 
441
        instance_ref = db.instance_create(self.context, self.test_instance)
 
442
 
 
443
        result = conn._prepare_xml_info(instance_ref,
 
444
                                        _fake_network_info(self.stubs, 1),
 
445
                                        None, False)
 
446
        self.assertTrue(len(result['nics']) == 1)
 
447
 
 
448
        result = conn._prepare_xml_info(instance_ref,
 
449
                                        _fake_network_info(self.stubs, 2),
 
450
                                        None, False)
 
451
        self.assertTrue(len(result['nics']) == 2)
 
452
 
 
453
    def test_xml_and_uri_no_ramdisk_no_kernel(self):
 
454
        instance_data = dict(self.test_instance)
 
455
        self._check_xml_and_uri(instance_data,
 
456
                                expect_kernel=False, expect_ramdisk=False)
 
457
 
 
458
    def test_xml_and_uri_no_ramdisk(self):
 
459
        instance_data = dict(self.test_instance)
 
460
        instance_data['kernel_id'] = 'aki-deadbeef'
 
461
        self._check_xml_and_uri(instance_data,
 
462
                                expect_kernel=True, expect_ramdisk=False)
 
463
 
 
464
    def test_xml_and_uri_no_kernel(self):
 
465
        instance_data = dict(self.test_instance)
 
466
        instance_data['ramdisk_id'] = 'ari-deadbeef'
 
467
        self._check_xml_and_uri(instance_data,
 
468
                                expect_kernel=False, expect_ramdisk=False)
 
469
 
 
470
    def test_xml_and_uri(self):
 
471
        instance_data = dict(self.test_instance)
 
472
        instance_data['ramdisk_id'] = 'ari-deadbeef'
 
473
        instance_data['kernel_id'] = 'aki-deadbeef'
 
474
        self._check_xml_and_uri(instance_data,
 
475
                                expect_kernel=True, expect_ramdisk=True)
 
476
 
 
477
    def test_xml_and_uri_rescue(self):
 
478
        instance_data = dict(self.test_instance)
 
479
        instance_data['ramdisk_id'] = 'ari-deadbeef'
 
480
        instance_data['kernel_id'] = 'aki-deadbeef'
 
481
        self._check_xml_and_uri(instance_data, expect_kernel=True,
 
482
                                expect_ramdisk=True, rescue=True)
 
483
 
 
484
    def test_xml_uuid(self):
 
485
        instance_data = dict(self.test_instance)
 
486
        self._check_xml_and_uuid(instance_data)
 
487
 
 
488
    def test_lxc_container_and_uri(self):
 
489
        instance_data = dict(self.test_instance)
 
490
        self._check_xml_and_container(instance_data)
 
491
 
 
492
    def test_xml_disk_prefix(self):
 
493
        instance_data = dict(self.test_instance)
 
494
        self._check_xml_and_disk_prefix(instance_data)
 
495
 
 
496
    def test_xml_disk_driver(self):
 
497
        instance_data = dict(self.test_instance)
 
498
        self._check_xml_and_disk_driver(instance_data)
 
499
 
 
500
    def test_xml_disk_bus_virtio(self):
 
501
        self._check_xml_and_disk_bus({"disk_format": "raw"},
 
502
                                     "disk", "virtio")
 
503
 
 
504
    def test_xml_disk_bus_ide(self):
 
505
        self._check_xml_and_disk_bus({"disk_format": "iso"},
 
506
                                     "cdrom", "ide")
 
507
 
 
508
    def test_list_instances(self):
 
509
        self.mox.StubOutWithMock(connection.LibvirtConnection, '_conn')
 
510
        connection.LibvirtConnection._conn.lookupByID = self.fake_lookup
 
511
        connection.LibvirtConnection._conn.listDomainsID = lambda: [0, 1]
 
512
 
 
513
        self.mox.ReplayAll()
 
514
        conn = connection.LibvirtConnection(False)
 
515
        instances = conn.list_instances()
 
516
        # Only one should be listed, since domain with ID 0 must be skiped
 
517
        self.assertEquals(len(instances), 1)
 
518
 
 
519
    def test_get_all_block_devices(self):
 
520
        xml = [
 
521
            # NOTE(vish): id 0 is skipped
 
522
            None,
 
523
            """
 
524
                <domain type='kvm'>
 
525
                    <devices>
 
526
                        <disk type='file'>
 
527
                            <source file='filename'/>
 
528
                        </disk>
 
529
                        <disk type='block'>
 
530
                            <source dev='/path/to/dev/1'/>
 
531
                        </disk>
 
532
                    </devices>
 
533
                </domain>
 
534
            """,
 
535
            """
 
536
                <domain type='kvm'>
 
537
                    <devices>
 
538
                        <disk type='file'>
 
539
                            <source file='filename'/>
 
540
                        </disk>
 
541
                    </devices>
 
542
                </domain>
 
543
            """,
 
544
            """
 
545
                <domain type='kvm'>
 
546
                    <devices>
 
547
                        <disk type='file'>
 
548
                            <source file='filename'/>
 
549
                        </disk>
 
550
                        <disk type='block'>
 
551
                            <source dev='/path/to/dev/3'/>
 
552
                        </disk>
 
553
                    </devices>
 
554
                </domain>
 
555
            """,
 
556
        ]
 
557
 
 
558
        def fake_lookup(id):
 
559
            return FakeVirtDomain(xml[id])
 
560
 
 
561
        self.mox.StubOutWithMock(connection.LibvirtConnection, '_conn')
 
562
        connection.LibvirtConnection._conn.listDomainsID = lambda: range(4)
 
563
        connection.LibvirtConnection._conn.lookupByID = fake_lookup
 
564
 
 
565
        self.mox.ReplayAll()
 
566
        conn = connection.LibvirtConnection(False)
 
567
        devices = conn.get_all_block_devices()
 
568
        self.assertEqual(devices, ['/path/to/dev/1', '/path/to/dev/3'])
 
569
 
 
570
    @test.skip_if(missing_libvirt(), "Test requires libvirt")
 
571
    def test_snapshot_in_ami_format(self):
 
572
        self.flags(image_service='nova.image.fake.FakeImageService')
 
573
 
 
574
        # Start test
 
575
        image_service = utils.import_object(FLAGS.image_service)
 
576
 
 
577
        # Assign different image_ref from nova/images/fakes for testing ami
 
578
        test_instance = copy.deepcopy(self.test_instance)
 
579
        test_instance["image_ref"] = 'c905cedb-7281-47e4-8a62-f26bc5fc4c77'
 
580
 
 
581
        # Assuming that base image already exists in image_service
 
582
        instance_ref = db.instance_create(self.context, test_instance)
 
583
        properties = {'instance_id': instance_ref['id'],
 
584
                      'user_id': str(self.context.user_id)}
 
585
        snapshot_name = 'test-snap'
 
586
        sent_meta = {'name': snapshot_name, 'is_public': False,
 
587
                     'status': 'creating', 'properties': properties}
 
588
        # Create new image. It will be updated in snapshot method
 
589
        # To work with it from snapshot, the single image_service is needed
 
590
        recv_meta = image_service.create(context, sent_meta)
 
591
 
 
592
        self.mox.StubOutWithMock(connection.LibvirtConnection, '_conn')
 
593
        connection.LibvirtConnection._conn.lookupByName = self.fake_lookup
 
594
        self.mox.StubOutWithMock(connection.utils, 'execute')
 
595
        connection.utils.execute = self.fake_execute
 
596
 
 
597
        self.mox.ReplayAll()
 
598
 
 
599
        conn = connection.LibvirtConnection(False)
 
600
        conn.snapshot(self.context, instance_ref, recv_meta['id'])
 
601
 
 
602
        snapshot = image_service.show(context, recv_meta['id'])
 
603
        self.assertEquals(snapshot['properties']['image_state'], 'available')
 
604
        self.assertEquals(snapshot['status'], 'active')
 
605
        self.assertEquals(snapshot['disk_format'], 'ami')
 
606
        self.assertEquals(snapshot['name'], snapshot_name)
 
607
 
 
608
    @test.skip_if(missing_libvirt(), "Test requires libvirt")
 
609
    def test_snapshot_in_raw_format(self):
 
610
        self.flags(image_service='nova.image.fake.FakeImageService')
 
611
 
 
612
        # Start test
 
613
        image_service = utils.import_object(FLAGS.image_service)
 
614
 
 
615
        # Assuming that base image already exists in image_service
 
616
        instance_ref = db.instance_create(self.context, self.test_instance)
 
617
        properties = {'instance_id': instance_ref['id'],
 
618
                      'user_id': str(self.context.user_id)}
 
619
        snapshot_name = 'test-snap'
 
620
        sent_meta = {'name': snapshot_name, 'is_public': False,
 
621
                     'status': 'creating', 'properties': properties}
 
622
        # Create new image. It will be updated in snapshot method
 
623
        # To work with it from snapshot, the single image_service is needed
 
624
        recv_meta = image_service.create(context, sent_meta)
 
625
 
 
626
        self.mox.StubOutWithMock(connection.LibvirtConnection, '_conn')
 
627
        connection.LibvirtConnection._conn.lookupByName = self.fake_lookup
 
628
        self.mox.StubOutWithMock(connection.utils, 'execute')
 
629
        connection.utils.execute = self.fake_execute
 
630
 
 
631
        self.mox.ReplayAll()
 
632
 
 
633
        conn = connection.LibvirtConnection(False)
 
634
        conn.snapshot(self.context, instance_ref, recv_meta['id'])
 
635
 
 
636
        snapshot = image_service.show(context, recv_meta['id'])
 
637
        self.assertEquals(snapshot['properties']['image_state'], 'available')
 
638
        self.assertEquals(snapshot['status'], 'active')
 
639
        self.assertEquals(snapshot['disk_format'], 'raw')
 
640
        self.assertEquals(snapshot['name'], snapshot_name)
 
641
 
 
642
    @test.skip_if(missing_libvirt(), "Test requires libvirt")
 
643
    def test_snapshot_in_qcow2_format(self):
 
644
        self.flags(image_service='nova.image.fake.FakeImageService')
 
645
        self.flags(snapshot_image_format='qcow2')
 
646
 
 
647
        # Start test
 
648
        image_service = utils.import_object(FLAGS.image_service)
 
649
 
 
650
        # Assuming that base image already exists in image_service
 
651
        instance_ref = db.instance_create(self.context, self.test_instance)
 
652
        properties = {'instance_id': instance_ref['id'],
 
653
                      'user_id': str(self.context.user_id)}
 
654
        snapshot_name = 'test-snap'
 
655
        sent_meta = {'name': snapshot_name, 'is_public': False,
 
656
                     'status': 'creating', 'properties': properties}
 
657
        # Create new image. It will be updated in snapshot method
 
658
        # To work with it from snapshot, the single image_service is needed
 
659
        recv_meta = image_service.create(context, sent_meta)
 
660
 
 
661
        self.mox.StubOutWithMock(connection.LibvirtConnection, '_conn')
 
662
        connection.LibvirtConnection._conn.lookupByName = self.fake_lookup
 
663
        self.mox.StubOutWithMock(connection.utils, 'execute')
 
664
        connection.utils.execute = self.fake_execute
 
665
 
 
666
        self.mox.ReplayAll()
 
667
 
 
668
        conn = connection.LibvirtConnection(False)
 
669
        conn.snapshot(self.context, instance_ref, recv_meta['id'])
 
670
 
 
671
        snapshot = image_service.show(context, recv_meta['id'])
 
672
        self.assertEquals(snapshot['properties']['image_state'], 'available')
 
673
        self.assertEquals(snapshot['status'], 'active')
 
674
        self.assertEquals(snapshot['disk_format'], 'qcow2')
 
675
        self.assertEquals(snapshot['name'], snapshot_name)
 
676
 
 
677
    @test.skip_if(missing_libvirt(), "Test requires libvirt")
 
678
    def test_snapshot_no_image_architecture(self):
 
679
        self.flags(image_service='nova.image.fake.FakeImageService')
 
680
 
 
681
        # Start test
 
682
        image_service = utils.import_object(FLAGS.image_service)
 
683
 
 
684
        # Assign different image_ref from nova/images/fakes for
 
685
        # testing different base image
 
686
        test_instance = copy.deepcopy(self.test_instance)
 
687
        test_instance["image_ref"] = '76fa36fc-c930-4bf3-8c8a-ea2a2420deb6'
 
688
 
 
689
        # Assuming that base image already exists in image_service
 
690
        instance_ref = db.instance_create(self.context, test_instance)
 
691
        properties = {'instance_id': instance_ref['id'],
 
692
                      'user_id': str(self.context.user_id)}
 
693
        snapshot_name = 'test-snap'
 
694
        sent_meta = {'name': snapshot_name, 'is_public': False,
 
695
                     'status': 'creating', 'properties': properties}
 
696
        # Create new image. It will be updated in snapshot method
 
697
        # To work with it from snapshot, the single image_service is needed
 
698
        recv_meta = image_service.create(context, sent_meta)
 
699
 
 
700
        self.mox.StubOutWithMock(connection.LibvirtConnection, '_conn')
 
701
        connection.LibvirtConnection._conn.lookupByName = self.fake_lookup
 
702
        self.mox.StubOutWithMock(connection.utils, 'execute')
 
703
        connection.utils.execute = self.fake_execute
 
704
 
 
705
        self.mox.ReplayAll()
 
706
 
 
707
        conn = connection.LibvirtConnection(False)
 
708
        conn.snapshot(self.context, instance_ref, recv_meta['id'])
 
709
 
 
710
        snapshot = image_service.show(context, recv_meta['id'])
 
711
        self.assertEquals(snapshot['properties']['image_state'], 'available')
 
712
        self.assertEquals(snapshot['status'], 'active')
 
713
        self.assertEquals(snapshot['name'], snapshot_name)
 
714
 
 
715
    @test.skip_if(missing_libvirt(), "Test requires libvirt")
 
716
    def test_snapshot_no_original_image(self):
 
717
        self.flags(image_service='nova.image.fake.FakeImageService')
 
718
 
 
719
        # Start test
 
720
        image_service = utils.import_object(FLAGS.image_service)
 
721
 
 
722
        # Assign a non-existent image
 
723
        test_instance = copy.deepcopy(self.test_instance)
 
724
        test_instance["image_ref"] = '661122aa-1234-dede-fefe-babababababa'
 
725
 
 
726
        instance_ref = db.instance_create(self.context, test_instance)
 
727
        properties = {'instance_id': instance_ref['id'],
 
728
                      'user_id': str(self.context.user_id)}
 
729
        snapshot_name = 'test-snap'
 
730
        sent_meta = {'name': snapshot_name, 'is_public': False,
 
731
                     'status': 'creating', 'properties': properties}
 
732
        recv_meta = image_service.create(context, sent_meta)
 
733
 
 
734
        self.mox.StubOutWithMock(connection.LibvirtConnection, '_conn')
 
735
        connection.LibvirtConnection._conn.lookupByName = self.fake_lookup
 
736
        self.mox.StubOutWithMock(connection.utils, 'execute')
 
737
        connection.utils.execute = self.fake_execute
 
738
 
 
739
        self.mox.ReplayAll()
 
740
 
 
741
        conn = connection.LibvirtConnection(False)
 
742
        conn.snapshot(self.context, instance_ref, recv_meta['id'])
 
743
 
 
744
        snapshot = image_service.show(context, recv_meta['id'])
 
745
        self.assertEquals(snapshot['properties']['image_state'], 'available')
 
746
        self.assertEquals(snapshot['status'], 'active')
 
747
        self.assertEquals(snapshot['name'], snapshot_name)
 
748
 
 
749
    def test_attach_invalid_volume_type(self):
 
750
        self.create_fake_libvirt_mock()
 
751
        connection.LibvirtConnection._conn.lookupByName = self.fake_lookup
 
752
        self.mox.ReplayAll()
 
753
        conn = connection.LibvirtConnection(False)
 
754
        self.assertRaises(exception.VolumeDriverNotFound,
 
755
                          conn.attach_volume,
 
756
                          {"driver_volume_type": "badtype"},
 
757
                           "fake",
 
758
                           "/dev/fake")
 
759
 
 
760
    def test_multi_nic(self):
 
761
        instance_data = dict(self.test_instance)
 
762
        network_info = _fake_network_info(self.stubs, 2)
 
763
        conn = connection.LibvirtConnection(True)
 
764
        instance_ref = db.instance_create(self.context, instance_data)
 
765
        xml = conn.to_xml(instance_ref, network_info, None, False)
 
766
        tree = ElementTree.fromstring(xml)
 
767
        interfaces = tree.findall("./devices/interface")
 
768
        self.assertEquals(len(interfaces), 2)
 
769
        parameters = interfaces[0].findall('./filterref/parameter')
 
770
        self.assertEquals(interfaces[0].get('type'), 'bridge')
 
771
        self.assertEquals(parameters[0].get('name'), 'IP')
 
772
        self.assertTrue(_ipv4_like(parameters[0].get('value'), '192.168'))
 
773
        self.assertEquals(parameters[1].get('name'), 'DHCPSERVER')
 
774
        self.assertTrue(_ipv4_like(parameters[1].get('value'), '192.168.*.1'))
 
775
 
 
776
    def _check_xml_and_container(self, instance):
 
777
        user_context = context.RequestContext(self.user_id,
 
778
                                              self.project_id)
 
779
        instance_ref = db.instance_create(user_context, instance)
 
780
 
 
781
        self.flags(libvirt_type='lxc')
 
782
        conn = connection.LibvirtConnection(True)
 
783
 
 
784
        self.assertEquals(conn.uri, 'lxc:///')
 
785
 
 
786
        network_info = _fake_network_info(self.stubs, 1)
 
787
        xml = conn.to_xml(instance_ref, network_info)
 
788
        tree = ElementTree.fromstring(xml)
 
789
 
 
790
        check = [
 
791
        (lambda t: t.find('.').get('type'), 'lxc'),
 
792
        (lambda t: t.find('./os/type').text, 'exe'),
 
793
        (lambda t: t.find('./devices/filesystem/target').get('dir'), '/')]
 
794
 
 
795
        for i, (check, expected_result) in enumerate(check):
 
796
            self.assertEqual(check(tree),
 
797
                             expected_result,
 
798
                             '%s failed common check %d' % (xml, i))
 
799
 
 
800
        target = tree.find('./devices/filesystem/source').get('dir')
 
801
        self.assertTrue(len(target) > 0)
 
802
 
 
803
    def _check_xml_and_disk_prefix(self, instance):
 
804
        user_context = context.RequestContext(self.user_id,
 
805
                                              self.project_id)
 
806
        instance_ref = db.instance_create(user_context, instance)
 
807
 
 
808
        type_disk_map = {
 
809
            'qemu': [
 
810
               (lambda t: t.find('.').get('type'), 'qemu'),
 
811
               (lambda t: t.find('./devices/disk/target').get('dev'), 'vda')],
 
812
            'xen': [
 
813
               (lambda t: t.find('.').get('type'), 'xen'),
 
814
               (lambda t: t.find('./devices/disk/target').get('dev'), 'sda')],
 
815
            'kvm': [
 
816
               (lambda t: t.find('.').get('type'), 'kvm'),
 
817
               (lambda t: t.find('./devices/disk/target').get('dev'), 'vda')],
 
818
            'uml': [
 
819
               (lambda t: t.find('.').get('type'), 'uml'),
 
820
               (lambda t: t.find('./devices/disk/target').get('dev'), 'ubda')]
 
821
            }
 
822
 
 
823
        for (libvirt_type, checks) in type_disk_map.iteritems():
 
824
            self.flags(libvirt_type=libvirt_type)
 
825
            conn = connection.LibvirtConnection(True)
 
826
 
 
827
            network_info = _fake_network_info(self.stubs, 1)
 
828
            xml = conn.to_xml(instance_ref, network_info)
 
829
            tree = ElementTree.fromstring(xml)
 
830
 
 
831
            for i, (check, expected_result) in enumerate(checks):
 
832
                self.assertEqual(check(tree),
 
833
                                 expected_result,
 
834
                                 '%s != %s failed check %d' %
 
835
                                 (check(tree), expected_result, i))
 
836
 
 
837
    def _check_xml_and_disk_driver(self, image_meta):
 
838
        os_open = os.open
 
839
        directio_supported = True
 
840
 
 
841
        def os_open_stub(path, flags, *args, **kwargs):
 
842
            if flags & os.O_DIRECT:
 
843
                if not directio_supported:
 
844
                    raise OSError(errno.EINVAL,
 
845
                                  '%s: %s' % (os.strerror(errno.EINVAL), path))
 
846
                flags &= ~os.O_DIRECT
 
847
            return os_open(path, flags, *args, **kwargs)
 
848
 
 
849
        self.stubs.Set(os, 'open', os_open_stub)
 
850
 
 
851
        user_context = context.RequestContext(self.user_id, self.project_id)
 
852
        instance_ref = db.instance_create(user_context, self.test_instance)
 
853
        network_info = _fake_network_info(self.stubs, 1)
 
854
 
 
855
        xml = connection.LibvirtConnection(True).to_xml(instance_ref,
 
856
                                                        network_info,
 
857
                                                        image_meta)
 
858
        tree = ElementTree.fromstring(xml)
 
859
        disks = tree.findall('./devices/disk/driver')
 
860
        for disk in disks:
 
861
            self.assertEqual(disk.get("cache"), "none")
 
862
 
 
863
        directio_supported = False
 
864
 
 
865
        # The O_DIRECT availability is cached on first use in
 
866
        # LibvirtConnection, hence we re-create it here
 
867
        xml = connection.LibvirtConnection(True).to_xml(instance_ref,
 
868
                                                        network_info,
 
869
                                                        image_meta)
 
870
        tree = ElementTree.fromstring(xml)
 
871
        disks = tree.findall('./devices/disk/driver')
 
872
        for disk in disks:
 
873
            self.assertEqual(disk.get("cache"), "writethrough")
 
874
 
 
875
    def _check_xml_and_disk_bus(self, image_meta, device_type, bus):
 
876
        user_context = context.RequestContext(self.user_id, self.project_id)
 
877
        instance_ref = db.instance_create(user_context, self.test_instance)
 
878
        network_info = _fake_network_info(self.stubs, 1)
 
879
 
 
880
        xml = connection.LibvirtConnection(True).to_xml(instance_ref,
 
881
                                                        network_info,
 
882
                                                        image_meta)
 
883
        tree = ElementTree.fromstring(xml)
 
884
        self.assertEqual(tree.find('./devices/disk').get('device'),
 
885
                         device_type)
 
886
        self.assertEqual(tree.find('./devices/disk/target').get('bus'), bus)
 
887
 
 
888
    def _check_xml_and_uuid(self, image_meta):
 
889
        user_context = context.RequestContext(self.user_id, self.project_id)
 
890
        instance_ref = db.instance_create(user_context, self.test_instance)
 
891
        network_info = _fake_network_info(self.stubs, 1)
 
892
 
 
893
        xml = connection.LibvirtConnection(True).to_xml(instance_ref,
 
894
                                                        network_info,
 
895
                                                        image_meta)
 
896
        tree = ElementTree.fromstring(xml)
 
897
        self.assertEqual(tree.find('./uuid').text,
 
898
                         instance_ref['uuid'])
 
899
 
 
900
    def _check_xml_and_uri(self, instance, expect_ramdisk, expect_kernel,
 
901
                           rescue=False):
 
902
        user_context = context.RequestContext(self.user_id, self.project_id)
 
903
        instance_ref = db.instance_create(user_context, instance)
 
904
        network_ref = db.project_get_networks(context.get_admin_context(),
 
905
                                             self.project_id)[0]
 
906
 
 
907
        type_uri_map = {'qemu': ('qemu:///system',
 
908
                             [(lambda t: t.find('.').get('type'), 'qemu'),
 
909
                              (lambda t: t.find('./os/type').text, 'hvm'),
 
910
                              (lambda t: t.find('./devices/emulator'), None)]),
 
911
                        'kvm': ('qemu:///system',
 
912
                             [(lambda t: t.find('.').get('type'), 'kvm'),
 
913
                              (lambda t: t.find('./os/type').text, 'hvm'),
 
914
                              (lambda t: t.find('./devices/emulator'), None)]),
 
915
                        'uml': ('uml:///system',
 
916
                             [(lambda t: t.find('.').get('type'), 'uml'),
 
917
                              (lambda t: t.find('./os/type').text, 'uml')]),
 
918
                        'xen': ('xen:///',
 
919
                             [(lambda t: t.find('.').get('type'), 'xen'),
 
920
                              (lambda t: t.find('./os/type').text, 'linux')]),
 
921
                              }
 
922
 
 
923
        for hypervisor_type in ['qemu', 'kvm', 'xen']:
 
924
            check_list = type_uri_map[hypervisor_type][1]
 
925
 
 
926
            if rescue:
 
927
                check = (lambda t: t.find('./os/kernel').text.split('/')[1],
 
928
                         'kernel.rescue')
 
929
                check_list.append(check)
 
930
                check = (lambda t: t.find('./os/initrd').text.split('/')[1],
 
931
                         'ramdisk.rescue')
 
932
                check_list.append(check)
 
933
            else:
 
934
                if expect_kernel:
 
935
                    check = (lambda t: t.find('./os/kernel').text.split(
 
936
                        '/')[1], 'kernel')
 
937
                else:
 
938
                    check = (lambda t: t.find('./os/kernel'), None)
 
939
                check_list.append(check)
 
940
 
 
941
                if expect_ramdisk:
 
942
                    check = (lambda t: t.find('./os/initrd').text.split(
 
943
                        '/')[1], 'ramdisk')
 
944
                else:
 
945
                    check = (lambda t: t.find('./os/initrd'), None)
 
946
                check_list.append(check)
 
947
 
 
948
            if hypervisor_type in ['qemu', 'kvm']:
 
949
                check = (lambda t: t.findall('./devices/serial')[0].get(
 
950
                        'type'), 'file')
 
951
                check_list.append(check)
 
952
                check = (lambda t: t.findall('./devices/serial')[1].get(
 
953
                        'type'), 'pty')
 
954
                check_list.append(check)
 
955
                check = (lambda t: t.findall('./devices/serial/source')[0].get(
 
956
                        'path').split('/')[1], 'console.log')
 
957
                check_list.append(check)
 
958
            else:
 
959
                check = (lambda t: t.find('./devices/console').get(
 
960
                        'type'), 'pty')
 
961
                check_list.append(check)
 
962
 
 
963
        parameter = './devices/interface/filterref/parameter'
 
964
        common_checks = [
 
965
            (lambda t: t.find('.').tag, 'domain'),
 
966
            (lambda t: t.find(parameter).get('name'), 'IP'),
 
967
            (lambda t: _ipv4_like(t.find(parameter).get('value'), '192.168'),
 
968
             True),
 
969
            (lambda t: t.findall(parameter)[1].get('name'), 'DHCPSERVER'),
 
970
            (lambda t: _ipv4_like(t.findall(parameter)[1].get('value'),
 
971
                                  '192.168.*.1'), True),
 
972
            (lambda t: t.find('./memory').text, '2097152')]
 
973
        if rescue:
 
974
            common_checks += [
 
975
                (lambda t: t.findall('./devices/disk/source')[0].get(
 
976
                    'file').split('/')[1], 'disk.rescue'),
 
977
                (lambda t: t.findall('./devices/disk/source')[1].get(
 
978
                    'file').split('/')[1], 'disk')]
 
979
        else:
 
980
            common_checks += [(lambda t: t.findall(
 
981
                './devices/disk/source')[0].get('file').split('/')[1],
 
982
                               'disk')]
 
983
            common_checks += [(lambda t: t.findall(
 
984
                './devices/disk/source')[1].get('file').split('/')[1],
 
985
                               'disk.local')]
 
986
 
 
987
        for (libvirt_type, (expected_uri, checks)) in type_uri_map.iteritems():
 
988
            self.flags(libvirt_type=libvirt_type)
 
989
            conn = connection.LibvirtConnection(True)
 
990
 
 
991
            self.assertEquals(conn.uri, expected_uri)
 
992
 
 
993
            network_info = _fake_network_info(self.stubs, 1)
 
994
            xml = conn.to_xml(instance_ref, network_info, None, rescue)
 
995
            tree = ElementTree.fromstring(xml)
 
996
            for i, (check, expected_result) in enumerate(checks):
 
997
                self.assertEqual(check(tree),
 
998
                                 expected_result,
 
999
                                 '%s != %s failed check %d' %
 
1000
                                 (check(tree), expected_result, i))
 
1001
 
 
1002
            for i, (check, expected_result) in enumerate(common_checks):
 
1003
                self.assertEqual(check(tree),
 
1004
                                 expected_result,
 
1005
                                 '%s != %s failed common check %d' %
 
1006
                                 (check(tree), expected_result, i))
 
1007
 
 
1008
        # This test is supposed to make sure we don't
 
1009
        # override a specifically set uri
 
1010
        #
 
1011
        # Deliberately not just assigning this string to FLAGS.libvirt_uri and
 
1012
        # checking against that later on. This way we make sure the
 
1013
        # implementation doesn't fiddle around with the FLAGS.
 
1014
        testuri = 'something completely different'
 
1015
        self.flags(libvirt_uri=testuri)
 
1016
        for (libvirt_type, (expected_uri, checks)) in type_uri_map.iteritems():
 
1017
            self.flags(libvirt_type=libvirt_type)
 
1018
            conn = connection.LibvirtConnection(True)
 
1019
            self.assertEquals(conn.uri, testuri)
 
1020
        db.instance_destroy(user_context, instance_ref['id'])
 
1021
 
 
1022
    @test.skip_if(missing_libvirt(), "Test requires libvirt")
 
1023
    def test_ensure_filtering_rules_for_instance_timeout(self):
 
1024
        """ensure_filtering_fules_for_instance() finishes with timeout."""
 
1025
        # Preparing mocks
 
1026
        def fake_none(self, *args):
 
1027
            return
 
1028
 
 
1029
        def fake_raise(self):
 
1030
            raise libvirt.libvirtError('ERR')
 
1031
 
 
1032
        class FakeTime(object):
 
1033
            def __init__(self):
 
1034
                self.counter = 0
 
1035
 
 
1036
            def sleep(self, t):
 
1037
                self.counter += t
 
1038
 
 
1039
        fake_timer = FakeTime()
 
1040
 
 
1041
        # _fake_network_info must be called before create_fake_libvirt_mock(),
 
1042
        # as _fake_network_info calls utils.import_class() and
 
1043
        # create_fake_libvirt_mock() mocks utils.import_class().
 
1044
        network_info = _fake_network_info(self.stubs, 1)
 
1045
        self.create_fake_libvirt_mock()
 
1046
        instance_ref = db.instance_create(self.context, self.test_instance)
 
1047
 
 
1048
        # Start test
 
1049
        self.mox.ReplayAll()
 
1050
        try:
 
1051
            conn = connection.LibvirtConnection(False)
 
1052
            self.stubs.Set(conn.firewall_driver,
 
1053
                           'setup_basic_filtering',
 
1054
                           fake_none)
 
1055
            self.stubs.Set(conn.firewall_driver,
 
1056
                           'prepare_instance_filter',
 
1057
                           fake_none)
 
1058
            self.stubs.Set(conn.firewall_driver,
 
1059
                           'instance_filter_exists',
 
1060
                           fake_none)
 
1061
            conn.ensure_filtering_rules_for_instance(instance_ref,
 
1062
                                                     network_info,
 
1063
                                                     time=fake_timer)
 
1064
        except exception.Error, e:
 
1065
            c1 = (0 <= e.message.find('Timeout migrating for'))
 
1066
        self.assertTrue(c1)
 
1067
 
 
1068
        self.assertEqual(29, fake_timer.counter, "Didn't wait the expected "
 
1069
                                                 "amount of time")
 
1070
 
 
1071
        db.instance_destroy(self.context, instance_ref['id'])
 
1072
 
 
1073
    @test.skip_if(missing_libvirt(), "Test requires libvirt")
 
1074
    def test_live_migration_raises_exception(self):
 
1075
        """Confirms recover method is called when exceptions are raised."""
 
1076
        # Preparing data
 
1077
        self.compute = utils.import_object(FLAGS.compute_manager)
 
1078
        instance_dict = {'host': 'fake',
 
1079
                         'power_state': power_state.RUNNING,
 
1080
                         'vm_state': vm_states.ACTIVE}
 
1081
        instance_ref = db.instance_create(self.context, self.test_instance)
 
1082
        instance_ref = db.instance_update(self.context, instance_ref['id'],
 
1083
                                          instance_dict)
 
1084
        vol_dict = {'status': 'migrating', 'size': 1}
 
1085
        volume_ref = db.volume_create(self.context, vol_dict)
 
1086
        db.volume_attached(self.context, volume_ref['id'], instance_ref['id'],
 
1087
                           '/dev/fake')
 
1088
 
 
1089
        # Preparing mocks
 
1090
        vdmock = self.mox.CreateMock(libvirt.virDomain)
 
1091
        self.mox.StubOutWithMock(vdmock, "migrateToURI")
 
1092
        _bandwidth = FLAGS.live_migration_bandwidth
 
1093
        vdmock.migrateToURI(FLAGS.live_migration_uri % 'dest',
 
1094
                            mox.IgnoreArg(),
 
1095
                            None,
 
1096
                            _bandwidth).AndRaise(libvirt.libvirtError('ERR'))
 
1097
 
 
1098
        def fake_lookup(instance_name):
 
1099
            if instance_name == instance_ref.name:
 
1100
                return vdmock
 
1101
 
 
1102
        self.create_fake_libvirt_mock(lookupByName=fake_lookup)
 
1103
        self.mox.StubOutWithMock(self.compute, "rollback_live_migration")
 
1104
        self.compute.rollback_live_migration(self.context, instance_ref,
 
1105
                                            'dest', False)
 
1106
 
 
1107
        #start test
 
1108
        self.mox.ReplayAll()
 
1109
        conn = connection.LibvirtConnection(False)
 
1110
        self.assertRaises(libvirt.libvirtError,
 
1111
                      conn._live_migration,
 
1112
                      self.context, instance_ref, 'dest', False,
 
1113
                      self.compute.rollback_live_migration)
 
1114
 
 
1115
        instance_ref = db.instance_get(self.context, instance_ref['id'])
 
1116
        self.assertTrue(instance_ref['vm_state'] == vm_states.ACTIVE)
 
1117
        self.assertTrue(instance_ref['power_state'] == power_state.RUNNING)
 
1118
        volume_ref = db.volume_get(self.context, volume_ref['id'])
 
1119
        self.assertTrue(volume_ref['status'] == 'in-use')
 
1120
 
 
1121
        db.volume_destroy(self.context, volume_ref['id'])
 
1122
        db.instance_destroy(self.context, instance_ref['id'])
 
1123
 
 
1124
    def test_pre_live_migration_works_correctly(self):
 
1125
        """Confirms pre_block_migration works correctly."""
 
1126
        # Creating testdata
 
1127
        vol = {'block_device_mapping': [
 
1128
                  {'connection_info': 'dummy', 'mount_device': '/dev/sda'},
 
1129
                  {'connection_info': 'dummy', 'mount_device': '/dev/sdb'}]}
 
1130
        conn = connection.LibvirtConnection(False)
 
1131
 
 
1132
        # Creating mocks
 
1133
        self.mox.StubOutWithMock(driver, "block_device_info_get_mapping")
 
1134
        driver.block_device_info_get_mapping(vol
 
1135
            ).AndReturn(vol['block_device_mapping'])
 
1136
        self.mox.StubOutWithMock(conn, "volume_driver_method")
 
1137
        for v in vol['block_device_mapping']:
 
1138
            conn.volume_driver_method('connect_volume',
 
1139
                                     v['connection_info'], v['mount_device'])
 
1140
 
 
1141
        # Starting test
 
1142
        self.mox.ReplayAll()
 
1143
        self.assertEqual(conn.pre_live_migration(vol), None)
 
1144
 
 
1145
    @test.skip_if(missing_libvirt(), "Test requires libvirt")
 
1146
    def test_pre_block_migration_works_correctly(self):
 
1147
        """Confirms pre_block_migration works correctly."""
 
1148
        # Replace instances_path since this testcase creates tmpfile
 
1149
        with utils.tempdir() as tmpdir:
 
1150
            self.flags(instances_path=tmpdir)
 
1151
 
 
1152
            # Test data
 
1153
            instance_ref = db.instance_create(self.context, self.test_instance)
 
1154
            dummyjson = ('[{"path": "%s/disk", "disk_size": "10737418240",'
 
1155
                         ' "type": "raw", "backing_file": ""}]')
 
1156
 
 
1157
            # Preparing mocks
 
1158
            # qemu-img should be mockd since test environment might not have
 
1159
            # large disk space.
 
1160
            self.mox.ReplayAll()
 
1161
            conn = connection.LibvirtConnection(False)
 
1162
            conn.pre_block_migration(self.context, instance_ref,
 
1163
                                     dummyjson % tmpdir)
 
1164
 
 
1165
            self.assertTrue(os.path.exists('%s/%s/' %
 
1166
                                           (tmpdir, instance_ref.name)))
 
1167
 
 
1168
        db.instance_destroy(self.context, instance_ref['id'])
 
1169
 
 
1170
    @test.skip_if(missing_libvirt(), "Test requires libvirt")
 
1171
    def test_get_instance_disk_info_works_correctly(self):
 
1172
        """Confirms pre_block_migration works correctly."""
 
1173
        # Test data
 
1174
        instance_ref = db.instance_create(self.context, self.test_instance)
 
1175
        dummyxml = ("<domain type='kvm'><name>instance-0000000a</name>"
 
1176
                    "<devices>"
 
1177
                    "<disk type='file'><driver name='qemu' type='raw'/>"
 
1178
                    "<source file='/test/disk'/>"
 
1179
                    "<target dev='vda' bus='virtio'/></disk>"
 
1180
                    "<disk type='file'><driver name='qemu' type='qcow2'/>"
 
1181
                    "<source file='/test/disk.local'/>"
 
1182
                    "<target dev='vdb' bus='virtio'/></disk>"
 
1183
                    "</devices></domain>")
 
1184
 
 
1185
        ret = ("image: /test/disk\n"
 
1186
               "file format: raw\n"
 
1187
               "virtual size: 20G (21474836480 bytes)\n"
 
1188
               "disk size: 3.1G\n"
 
1189
               "cluster_size: 2097152\n"
 
1190
               "backing file: /test/dummy (actual path: /backing/file)\n")
 
1191
 
 
1192
        # Preparing mocks
 
1193
        vdmock = self.mox.CreateMock(libvirt.virDomain)
 
1194
        self.mox.StubOutWithMock(vdmock, "XMLDesc")
 
1195
        vdmock.XMLDesc(0).AndReturn(dummyxml)
 
1196
 
 
1197
        def fake_lookup(instance_name):
 
1198
            if instance_name == instance_ref.name:
 
1199
                return vdmock
 
1200
        self.create_fake_libvirt_mock(lookupByName=fake_lookup)
 
1201
 
 
1202
        GB = 1024 * 1024 * 1024
 
1203
        fake_libvirt_utils.disk_sizes['/test/disk'] = 10 * GB
 
1204
        fake_libvirt_utils.disk_sizes['/test/disk.local'] = 20 * GB
 
1205
        fake_libvirt_utils.disk_backing_files['/test/disk.local'] = 'file'
 
1206
 
 
1207
        self.mox.StubOutWithMock(os.path, "getsize")
 
1208
        os.path.getsize('/test/disk').AndReturn((10737418240))
 
1209
 
 
1210
        self.mox.StubOutWithMock(utils, "execute")
 
1211
        utils.execute('qemu-img', 'info',
 
1212
                      '/test/disk.local').AndReturn((ret, ''))
 
1213
 
 
1214
        os.path.getsize('/test/disk.local').AndReturn((21474836480))
 
1215
 
 
1216
        self.mox.ReplayAll()
 
1217
        conn = connection.LibvirtConnection(False)
 
1218
        info = conn.get_instance_disk_info(instance_ref.name)
 
1219
        info = utils.loads(info)
 
1220
        self.assertEquals(info[0]['type'], 'raw')
 
1221
        self.assertEquals(info[0]['path'], '/test/disk')
 
1222
        self.assertEquals(info[0]['disk_size'], 10737418240)
 
1223
        self.assertEquals(info[0]['backing_file'], "")
 
1224
        self.assertEquals(info[1]['type'], 'qcow2')
 
1225
        self.assertEquals(info[1]['path'], '/test/disk.local')
 
1226
        self.assertEquals(info[1]['virt_disk_size'], 21474836480)
 
1227
        self.assertEquals(info[1]['backing_file'], "file")
 
1228
 
 
1229
        db.instance_destroy(self.context, instance_ref['id'])
 
1230
 
 
1231
    @test.skip_if(missing_libvirt(), "Test requires libvirt")
 
1232
    def test_spawn_with_network_info(self):
 
1233
        # Preparing mocks
 
1234
        def fake_none(self, instance):
 
1235
            return
 
1236
 
 
1237
        # _fake_network_info must be called before create_fake_libvirt_mock(),
 
1238
        # as _fake_network_info calls utils.import_class() and
 
1239
        # create_fake_libvirt_mock() mocks utils.import_class().
 
1240
        network_info = _fake_network_info(self.stubs, 1)
 
1241
        self.create_fake_libvirt_mock()
 
1242
 
 
1243
        instance_ref = self.test_instance
 
1244
        instance_ref['image_ref'] = 123456  # we send an int to test sha1 call
 
1245
        instance = db.instance_create(self.context, instance_ref)
 
1246
 
 
1247
        # Start test
 
1248
        self.mox.ReplayAll()
 
1249
        conn = connection.LibvirtConnection(False)
 
1250
        self.stubs.Set(conn.firewall_driver,
 
1251
                       'setup_basic_filtering',
 
1252
                       fake_none)
 
1253
        self.stubs.Set(conn.firewall_driver,
 
1254
                       'prepare_instance_filter',
 
1255
                       fake_none)
 
1256
 
 
1257
        try:
 
1258
            conn.spawn(self.context, instance, None, network_info)
 
1259
        except Exception, e:
 
1260
            # assert that no exception is raised due to sha1 receiving an int
 
1261
            self.assertEqual(-1, str(e.message).find('must be string or buffer'
 
1262
                                                     ', not int'))
 
1263
            count = (0 <= str(e.message).find('Unexpected method call'))
 
1264
 
 
1265
        path = os.path.join(FLAGS.instances_path, instance.name)
 
1266
        if os.path.isdir(path):
 
1267
            shutil.rmtree(path)
 
1268
 
 
1269
        path = os.path.join(FLAGS.instances_path, '_base')
 
1270
        if os.path.isdir(path):
 
1271
            shutil.rmtree(os.path.join(FLAGS.instances_path, '_base'))
 
1272
 
 
1273
    def test_get_host_ip_addr(self):
 
1274
        conn = connection.LibvirtConnection(False)
 
1275
        ip = conn.get_host_ip_addr()
 
1276
        self.assertEquals(ip, FLAGS.my_ip)
 
1277
 
 
1278
    @test.skip_if(missing_libvirt(), "Test requires libvirt")
 
1279
    def test_broken_connection(self):
 
1280
        for (error, domain) in (
 
1281
                (libvirt.VIR_ERR_SYSTEM_ERROR, libvirt.VIR_FROM_REMOTE),
 
1282
                (libvirt.VIR_ERR_SYSTEM_ERROR, libvirt.VIR_FROM_RPC)):
 
1283
 
 
1284
            conn = connection.LibvirtConnection(False)
 
1285
 
 
1286
            self.mox.StubOutWithMock(conn, "_wrapped_conn")
 
1287
            self.mox.StubOutWithMock(conn._wrapped_conn, "getCapabilities")
 
1288
            self.mox.StubOutWithMock(libvirt.libvirtError, "get_error_code")
 
1289
            self.mox.StubOutWithMock(libvirt.libvirtError, "get_error_domain")
 
1290
 
 
1291
            conn._wrapped_conn.getCapabilities().AndRaise(
 
1292
                    libvirt.libvirtError("fake failure"))
 
1293
 
 
1294
            libvirt.libvirtError.get_error_code().AndReturn(error)
 
1295
            libvirt.libvirtError.get_error_domain().AndReturn(domain)
 
1296
 
 
1297
            self.mox.ReplayAll()
 
1298
 
 
1299
            self.assertFalse(conn._test_connection())
 
1300
 
 
1301
            self.mox.UnsetStubs()
 
1302
 
 
1303
    def test_volume_in_mapping(self):
 
1304
        conn = connection.LibvirtConnection(False)
 
1305
        swap = {'device_name': '/dev/sdb',
 
1306
                'swap_size': 1}
 
1307
        ephemerals = [{'num': 0,
 
1308
                       'virtual_name': 'ephemeral0',
 
1309
                       'device_name': '/dev/sdc1',
 
1310
                       'size': 1},
 
1311
                      {'num': 2,
 
1312
                       'virtual_name': 'ephemeral2',
 
1313
                       'device_name': '/dev/sdd',
 
1314
                       'size': 1}]
 
1315
        block_device_mapping = [{'mount_device': '/dev/sde',
 
1316
                                 'device_path': 'fake_device'},
 
1317
                                {'mount_device': '/dev/sdf',
 
1318
                                 'device_path': 'fake_device'}]
 
1319
        block_device_info = {
 
1320
                'root_device_name': '/dev/sda',
 
1321
                'swap': swap,
 
1322
                'ephemerals': ephemerals,
 
1323
                'block_device_mapping': block_device_mapping}
 
1324
 
 
1325
        def _assert_volume_in_mapping(device_name, true_or_false):
 
1326
            self.assertEquals(conn._volume_in_mapping(device_name,
 
1327
                                                      block_device_info),
 
1328
                              true_or_false)
 
1329
 
 
1330
        _assert_volume_in_mapping('sda', False)
 
1331
        _assert_volume_in_mapping('sdb', True)
 
1332
        _assert_volume_in_mapping('sdc1', True)
 
1333
        _assert_volume_in_mapping('sdd', True)
 
1334
        _assert_volume_in_mapping('sde', True)
 
1335
        _assert_volume_in_mapping('sdf', True)
 
1336
        _assert_volume_in_mapping('sdg', False)
 
1337
        _assert_volume_in_mapping('sdh1', False)
 
1338
 
 
1339
    @test.skip_if(missing_libvirt(), "Test requires libvirt")
 
1340
    def test_immediate_delete(self):
 
1341
        conn = connection.LibvirtConnection(False)
 
1342
        self.mox.StubOutWithMock(connection.LibvirtConnection, '_conn')
 
1343
        connection.LibvirtConnection._conn.lookupByName = lambda x: None
 
1344
 
 
1345
        instance = db.instance_create(self.context, self.test_instance)
 
1346
        conn.destroy(instance, {})
 
1347
 
 
1348
    @test.skip_if(missing_libvirt(), "Test requires libvirt")
 
1349
    def test_destroy_saved(self):
 
1350
        """Ensure destroy calls managedSaveRemove for saved instance"""
 
1351
        mock = self.mox.CreateMock(libvirt.virDomain)
 
1352
        mock.destroy()
 
1353
        mock.hasManagedSaveImage(0).AndReturn(1)
 
1354
        mock.managedSaveRemove(0)
 
1355
        mock.undefine()
 
1356
 
 
1357
        self.mox.ReplayAll()
 
1358
 
 
1359
        def fake_lookup_by_name(instance_name):
 
1360
            return mock
 
1361
 
 
1362
        conn = connection.LibvirtConnection(False)
 
1363
        self.stubs.Set(conn, '_lookup_by_name', fake_lookup_by_name)
 
1364
        instance = {"name": "instancename", "id": "instanceid",
 
1365
                    "uuid": "875a8070-d0b9-4949-8b31-104d125c9a64"}
 
1366
        conn.destroy(instance, [])
 
1367
 
 
1368
    def test_available_least_handles_missing(self):
 
1369
        """Ensure destroy calls managedSaveRemove for saved instance"""
 
1370
        conn = connection.LibvirtConnection(False)
 
1371
 
 
1372
        def list_instances():
 
1373
            return ['fake']
 
1374
        self.stubs.Set(conn, 'list_instances', list_instances)
 
1375
 
 
1376
        def get_info(instance_name):
 
1377
            raise exception.InstanceNotFound()
 
1378
        self.stubs.Set(conn, 'get_instance_disk_info', get_info)
 
1379
 
 
1380
        result = conn.get_disk_available_least()
 
1381
        space = fake_libvirt_utils.get_fs_info(FLAGS.instances_path)['free']
 
1382
        self.assertEqual(result, space / 1024 ** 3)
 
1383
 
 
1384
 
 
1385
class HostStateTestCase(test.TestCase):
    """Checks the stats connection.HostState gathers from its driver."""

    cpu_info = ('{"vendor": "Intel", "model": "pentium", "arch": "i686", '
                 '"features": ["ssse3", "monitor", "pni", "sse2", "sse", '
                 '"fxsr", "clflush", "pse36", "pat", "cmov", "mca", "pge", '
                 '"mtrr", "sep", "apic"], '
                 '"topology": {"cores": "1", "threads": "1", "sockets": "1"}}')

    class FakeConnection(object):
        """Driver stand-in that reports fixed host resource figures."""

        # --- CPU ---
        def get_vcpu_total(self):
            return 1

        def get_vcpu_used(self):
            return 0

        def get_cpu_info(self):
            return HostStateTestCase.cpu_info

        # --- local disk ---
        def get_local_gb_total(self):
            return 100

        def get_local_gb_used(self):
            return 20

        # --- memory ---
        def get_memory_mb_total(self):
            return 497

        def get_memory_mb_used(self):
            return 88

        # --- hypervisor ---
        def get_hypervisor_type(self):
            return 'QEMU'

        def get_hypervisor_version(self):
            return 13091

        def get_disk_available_least(self):
            return 13091

    def test_update_status(self):
        """HostState._stats mirrors the fake driver's figures.

        disk_available and host_memory_free are expected to be the
        total minus used values reported by FakeConnection.
        """
        self.mox.StubOutWithMock(connection, 'get_connection')
        connection.get_connection(True).AndReturn(self.FakeConnection())

        self.mox.ReplayAll()
        stats = connection.HostState(True)._stats

        expected_cpu_info = {
            "vendor": "Intel", "model": "pentium", "arch": "i686",
            "features": ["ssse3", "monitor", "pni", "sse2", "sse",
                         "fxsr", "clflush", "pse36", "pat", "cmov",
                         "mca", "pge", "mtrr", "sep", "apic"],
            "topology": {"cores": "1", "threads": "1", "sockets": "1"},
        }
        expected_stats = (
            ("vcpus", 1),
            ("vcpus_used", 0),
            ("cpu_info", expected_cpu_info),
            ("disk_total", 100),
            ("disk_used", 20),
            ("disk_available", 80),
            ("host_memory_total", 497),
            ("host_memory_free", 409),
            ("hypervisor_type", 'QEMU'),
            ("hypervisor_version", 13091),
        )
        for stat_name, expected_value in expected_stats:
            self.assertEquals(stats[stat_name], expected_value)
 
1449
 
 
1450
 
 
1451
class NWFilterFakes:
    """In-memory stand-in for libvirt's nwfilter define/lookup calls.

    Defined filters live in ``self.filters``, keyed by filter name.
    """

    def __init__(self):
        self.filters = {}

    def nwfilterLookupByName(self, name):
        """Return the filter named ``name`` or raise libvirtError."""
        if name not in self.filters:
            raise libvirt.libvirtError('Filter Not Found')
        return self.filters[name]

    def filterDefineXMLMock(self, xml):
        """Register the filter named by the root element of ``xml``.

        Redefining an existing name keeps the original fake object.
        Always returns True.
        """

        class FakeNWFilterInternal:
            def __init__(self, parent, name):
                self.parent = parent
                self.name = name

            def undefine(self):
                # Raises KeyError if this filter was already undefined.
                self.parent.filters.pop(self.name)

        filter_name = ElementTree.fromstring(xml).get('name')
        if filter_name not in self.filters:
            self.filters[filter_name] = FakeNWFilterInternal(self, filter_name)
        return True
 
1474
 
 
1475
 
 
1476
class IptablesFirewallTestCase(test.TestCase):
 
1477
    def setUp(self):
 
1478
        super(IptablesFirewallTestCase, self).setUp()
 
1479
 
 
1480
        self.user_id = 'fake'
 
1481
        self.project_id = 'fake'
 
1482
        self.context = context.RequestContext(self.user_id, self.project_id)
 
1483
 
 
1484
        class FakeLibvirtConnection(object):
 
1485
            def nwfilterDefineXML(*args, **kwargs):
 
1486
                """setup_basic_rules in nwfilter calls this."""
 
1487
                pass
 
1488
        self.fake_libvirt_connection = FakeLibvirtConnection()
 
1489
        self.fw = firewall.IptablesFirewallDriver(
 
1490
                      get_connection=lambda: self.fake_libvirt_connection)
 
1491
 
 
1492
    in_nat_rules = [
 
1493
      '# Generated by iptables-save v1.4.10 on Sat Feb 19 00:03:19 2011',
 
1494
      '*nat',
 
1495
      ':PREROUTING ACCEPT [1170:189210]',
 
1496
      ':INPUT ACCEPT [844:71028]',
 
1497
      ':OUTPUT ACCEPT [5149:405186]',
 
1498
      ':POSTROUTING ACCEPT [5063:386098]',
 
1499
    ]
 
1500
 
 
1501
    in_filter_rules = [
 
1502
      '# Generated by iptables-save v1.4.4 on Mon Dec  6 11:54:13 2010',
 
1503
      '*filter',
 
1504
      ':INPUT ACCEPT [969615:281627771]',
 
1505
      ':FORWARD ACCEPT [0:0]',
 
1506
      ':OUTPUT ACCEPT [915599:63811649]',
 
1507
      ':nova-block-ipv4 - [0:0]',
 
1508
      '-A INPUT -i virbr0 -p tcp -m tcp --dport 67 -j ACCEPT ',
 
1509
      '-A FORWARD -d 192.168.122.0/24 -o virbr0 -m state --state RELATED'
 
1510
      ',ESTABLISHED -j ACCEPT ',
 
1511
      '-A FORWARD -s 192.168.122.0/24 -i virbr0 -j ACCEPT ',
 
1512
      '-A FORWARD -i virbr0 -o virbr0 -j ACCEPT ',
 
1513
      '-A FORWARD -o virbr0 -j REJECT --reject-with icmp-port-unreachable ',
 
1514
      '-A FORWARD -i virbr0 -j REJECT --reject-with icmp-port-unreachable ',
 
1515
      'COMMIT',
 
1516
      '# Completed on Mon Dec  6 11:54:13 2010',
 
1517
    ]
 
1518
 
 
1519
    in6_filter_rules = [
 
1520
      '# Generated by ip6tables-save v1.4.4 on Tue Jan 18 23:47:56 2011',
 
1521
      '*filter',
 
1522
      ':INPUT ACCEPT [349155:75810423]',
 
1523
      ':FORWARD ACCEPT [0:0]',
 
1524
      ':OUTPUT ACCEPT [349256:75777230]',
 
1525
      'COMMIT',
 
1526
      '# Completed on Tue Jan 18 23:47:56 2011',
 
1527
    ]
 
1528
 
 
1529
    def _create_instance_ref(self):
 
1530
        return db.instance_create(self.context,
 
1531
                                  {'user_id': 'fake',
 
1532
                                   'project_id': 'fake',
 
1533
                                   'instance_type_id': 1})
 
1534
 
 
1535
    def test_static_filters(self):
 
1536
        instance_ref = self._create_instance_ref()
 
1537
        src_instance_ref = self._create_instance_ref()
 
1538
 
 
1539
        admin_ctxt = context.get_admin_context()
 
1540
        secgroup = db.security_group_create(admin_ctxt,
 
1541
                                            {'user_id': 'fake',
 
1542
                                             'project_id': 'fake',
 
1543
                                             'name': 'testgroup',
 
1544
                                             'description': 'test group'})
 
1545
 
 
1546
        src_secgroup = db.security_group_create(admin_ctxt,
 
1547
                                                {'user_id': 'fake',
 
1548
                                                 'project_id': 'fake',
 
1549
                                                 'name': 'testsourcegroup',
 
1550
                                                 'description': 'src group'})
 
1551
 
 
1552
        db.security_group_rule_create(admin_ctxt,
 
1553
                                      {'parent_group_id': secgroup['id'],
 
1554
                                       'protocol': 'icmp',
 
1555
                                       'from_port': -1,
 
1556
                                       'to_port': -1,
 
1557
                                       'cidr': '192.168.11.0/24'})
 
1558
 
 
1559
        db.security_group_rule_create(admin_ctxt,
 
1560
                                      {'parent_group_id': secgroup['id'],
 
1561
                                       'protocol': 'icmp',
 
1562
                                       'from_port': 8,
 
1563
                                       'to_port': -1,
 
1564
                                       'cidr': '192.168.11.0/24'})
 
1565
 
 
1566
        db.security_group_rule_create(admin_ctxt,
 
1567
                                      {'parent_group_id': secgroup['id'],
 
1568
                                       'protocol': 'tcp',
 
1569
                                       'from_port': 80,
 
1570
                                       'to_port': 81,
 
1571
                                       'cidr': '192.168.10.0/24'})
 
1572
 
 
1573
        db.security_group_rule_create(admin_ctxt,
 
1574
                                      {'parent_group_id': secgroup['id'],
 
1575
                                       'protocol': 'tcp',
 
1576
                                       'from_port': 80,
 
1577
                                       'to_port': 81,
 
1578
                                       'group_id': src_secgroup['id']})
 
1579
 
 
1580
        db.instance_add_security_group(admin_ctxt, instance_ref['uuid'],
 
1581
                                       secgroup['id'])
 
1582
        db.instance_add_security_group(admin_ctxt, src_instance_ref['uuid'],
 
1583
                                       src_secgroup['id'])
 
1584
        instance_ref = db.instance_get(admin_ctxt, instance_ref['id'])
 
1585
        src_instance_ref = db.instance_get(admin_ctxt, src_instance_ref['id'])
 
1586
 
 
1587
#        self.fw.add_instance(instance_ref)
 
1588
        def fake_iptables_execute(*cmd, **kwargs):
 
1589
            process_input = kwargs.get('process_input', None)
 
1590
            if cmd == ('ip6tables-save', '-t', 'filter'):
 
1591
                return '\n'.join(self.in6_filter_rules), None
 
1592
            if cmd == ('iptables-save', '-t', 'filter'):
 
1593
                return '\n'.join(self.in_filter_rules), None
 
1594
            if cmd == ('iptables-save', '-t', 'nat'):
 
1595
                return '\n'.join(self.in_nat_rules), None
 
1596
            if cmd == ('iptables-restore',):
 
1597
                lines = process_input.split('\n')
 
1598
                if '*filter' in lines:
 
1599
                    self.out_rules = lines
 
1600
                return '', ''
 
1601
            if cmd == ('ip6tables-restore',):
 
1602
                lines = process_input.split('\n')
 
1603
                if '*filter' in lines:
 
1604
                    self.out6_rules = lines
 
1605
                return '', ''
 
1606
            print cmd, kwargs
 
1607
 
 
1608
        network_model = _fake_network_info(self.stubs, 1, spectacular=True)
 
1609
 
 
1610
        from nova.network import linux_net
 
1611
        linux_net.iptables_manager.execute = fake_iptables_execute
 
1612
 
 
1613
        _fake_stub_out_get_nw_info(self.stubs, lambda *a, **kw: network_model)
 
1614
 
 
1615
        network_info = compute_utils.legacy_network_info(network_model)
 
1616
        self.fw.prepare_instance_filter(instance_ref, network_info)
 
1617
        self.fw.apply_instance_filter(instance_ref, network_info)
 
1618
 
 
1619
        in_rules = filter(lambda l: not l.startswith('#'),
 
1620
                          self.in_filter_rules)
 
1621
        for rule in in_rules:
 
1622
            if not 'nova' in rule:
 
1623
                self.assertTrue(rule in self.out_rules,
 
1624
                                'Rule went missing: %s' % rule)
 
1625
 
 
1626
        instance_chain = None
 
1627
        for rule in self.out_rules:
 
1628
            # This is pretty crude, but it'll do for now
 
1629
            # last two octets change
 
1630
            if re.search('-d 192.168.[0-9]{1,3}.[0-9]{1,3} -j', rule):
 
1631
                instance_chain = rule.split(' ')[-1]
 
1632
                break
 
1633
        self.assertTrue(instance_chain, "The instance chain wasn't added")
 
1634
 
 
1635
        security_group_chain = None
 
1636
        for rule in self.out_rules:
 
1637
            # This is pretty crude, but it'll do for now
 
1638
            if '-A %s -j' % instance_chain in rule:
 
1639
                security_group_chain = rule.split(' ')[-1]
 
1640
                break
 
1641
        self.assertTrue(security_group_chain,
 
1642
                        "The security group chain wasn't added")
 
1643
 
 
1644
        regex = re.compile('-A .* -j ACCEPT -p icmp -s 192.168.11.0/24')
 
1645
        self.assertTrue(len(filter(regex.match, self.out_rules)) > 0,
 
1646
                        "ICMP acceptance rule wasn't added")
 
1647
 
 
1648
        regex = re.compile('-A .* -j ACCEPT -p icmp -m icmp --icmp-type 8'
 
1649
                           ' -s 192.168.11.0/24')
 
1650
        self.assertTrue(len(filter(regex.match, self.out_rules)) > 0,
 
1651
                        "ICMP Echo Request acceptance rule wasn't added")
 
1652
 
 
1653
        for ip in network_model.fixed_ips():
 
1654
            if ip['version'] != 4:
 
1655
                continue
 
1656
            regex = re.compile('-A .* -j ACCEPT -p tcp -m multiport '
 
1657
                               '--dports 80:81 -s %s' % ip['address'])
 
1658
            self.assertTrue(len(filter(regex.match, self.out_rules)) > 0,
 
1659
                            "TCP port 80/81 acceptance rule wasn't added")
 
1660
 
 
1661
        regex = re.compile('-A .* -j ACCEPT -p tcp '
 
1662
                           '-m multiport --dports 80:81 -s 192.168.10.0/24')
 
1663
        self.assertTrue(len(filter(regex.match, self.out_rules)) > 0,
 
1664
                        "TCP port 80/81 acceptance rule wasn't added")
 
1665
        db.instance_destroy(admin_ctxt, instance_ref['id'])
 
1666
 
 
1667
    def test_filters_for_instance_with_ip_v6(self):
 
1668
        self.flags(use_ipv6=True)
 
1669
        network_info = _fake_network_info(self.stubs, 1)
 
1670
        rulesv4, rulesv6 = self.fw._filters_for_instance("fake", network_info)
 
1671
        self.assertEquals(len(rulesv4), 2)
 
1672
        self.assertEquals(len(rulesv6), 1)
 
1673
 
 
1674
    def test_filters_for_instance_without_ip_v6(self):
 
1675
        self.flags(use_ipv6=False)
 
1676
        network_info = _fake_network_info(self.stubs, 1)
 
1677
        rulesv4, rulesv6 = self.fw._filters_for_instance("fake", network_info)
 
1678
        self.assertEquals(len(rulesv4), 2)
 
1679
        self.assertEquals(len(rulesv6), 0)
 
1680
 
 
1681
    def test_multinic_iptables(self):
 
1682
        ipv4_rules_per_addr = 1
 
1683
        ipv4_addr_per_network = 2
 
1684
        ipv6_rules_per_addr = 1
 
1685
        ipv6_addr_per_network = 1
 
1686
        networks_count = 5
 
1687
        instance_ref = self._create_instance_ref()
 
1688
        network_info = _fake_network_info(self.stubs, networks_count,
 
1689
                                                      ipv4_addr_per_network)
 
1690
        ipv4_len = len(self.fw.iptables.ipv4['filter'].rules)
 
1691
        ipv6_len = len(self.fw.iptables.ipv6['filter'].rules)
 
1692
        inst_ipv4, inst_ipv6 = self.fw.instance_rules(instance_ref,
 
1693
                                                      network_info)
 
1694
        self.fw.prepare_instance_filter(instance_ref, network_info)
 
1695
        ipv4 = self.fw.iptables.ipv4['filter'].rules
 
1696
        ipv6 = self.fw.iptables.ipv6['filter'].rules
 
1697
        ipv4_network_rules = len(ipv4) - len(inst_ipv4) - ipv4_len
 
1698
        ipv6_network_rules = len(ipv6) - len(inst_ipv6) - ipv6_len
 
1699
        self.assertEquals(ipv4_network_rules,
 
1700
                  ipv4_rules_per_addr * ipv4_addr_per_network * networks_count)
 
1701
        self.assertEquals(ipv6_network_rules,
 
1702
                  ipv6_rules_per_addr * ipv6_addr_per_network * networks_count)
 
1703
 
 
1704
    def test_do_refresh_security_group_rules(self):
 
1705
        instance_ref = self._create_instance_ref()
 
1706
        self.mox.StubOutWithMock(self.fw,
 
1707
                                 'add_filters_for_instance',
 
1708
                                 use_mock_anything=True)
 
1709
        self.fw.prepare_instance_filter(instance_ref, mox.IgnoreArg())
 
1710
        self.fw.instances[instance_ref['id']] = instance_ref
 
1711
        self.mox.ReplayAll()
 
1712
        self.fw.do_refresh_security_group_rules("fake")
 
1713
 
 
1714
    @test.skip_if(missing_libvirt(), "Test requires libvirt")
 
1715
    def test_unfilter_instance_undefines_nwfilter(self):
 
1716
        admin_ctxt = context.get_admin_context()
 
1717
 
 
1718
        fakefilter = NWFilterFakes()
 
1719
        _xml_mock = fakefilter.filterDefineXMLMock
 
1720
        self.fw.nwfilter._conn.nwfilterDefineXML = _xml_mock
 
1721
        _lookup_name = fakefilter.nwfilterLookupByName
 
1722
        self.fw.nwfilter._conn.nwfilterLookupByName = _lookup_name
 
1723
        instance_ref = self._create_instance_ref()
 
1724
 
 
1725
        network_info = _fake_network_info(self.stubs, 1)
 
1726
        self.fw.setup_basic_filtering(instance_ref, network_info)
 
1727
        self.fw.prepare_instance_filter(instance_ref, network_info)
 
1728
        self.fw.apply_instance_filter(instance_ref, network_info)
 
1729
        original_filter_count = len(fakefilter.filters)
 
1730
        self.fw.unfilter_instance(instance_ref, network_info)
 
1731
 
 
1732
        # should undefine just the instance filter
 
1733
        self.assertEqual(original_filter_count - len(fakefilter.filters), 1)
 
1734
 
 
1735
        db.instance_destroy(admin_ctxt, instance_ref['id'])
 
1736
 
 
1737
    def test_provider_firewall_rules(self):
 
1738
        # setup basic instance data
 
1739
        instance_ref = self._create_instance_ref()
 
1740
        # FRAGILE: peeks at how the firewall names chains
 
1741
        chain_name = 'inst-%s' % instance_ref['id']
 
1742
 
 
1743
        # create a firewall via setup_basic_filtering like libvirt_conn.spawn
 
1744
        # should have a chain with 0 rules
 
1745
        network_info = _fake_network_info(self.stubs, 1)
 
1746
        self.fw.setup_basic_filtering(instance_ref, network_info)
 
1747
        self.assertTrue('provider' in self.fw.iptables.ipv4['filter'].chains)
 
1748
        rules = [rule for rule in self.fw.iptables.ipv4['filter'].rules
 
1749
                      if rule.chain == 'provider']
 
1750
        self.assertEqual(0, len(rules))
 
1751
 
 
1752
        admin_ctxt = context.get_admin_context()
 
1753
        # add a rule and send the update message, check for 1 rule
 
1754
        provider_fw0 = db.provider_fw_rule_create(admin_ctxt,
 
1755
                                                  {'protocol': 'tcp',
 
1756
                                                   'cidr': '10.99.99.99/32',
 
1757
                                                   'from_port': 1,
 
1758
                                                   'to_port': 65535})
 
1759
        self.fw.refresh_provider_fw_rules()
 
1760
        rules = [rule for rule in self.fw.iptables.ipv4['filter'].rules
 
1761
                      if rule.chain == 'provider']
 
1762
        self.assertEqual(1, len(rules))
 
1763
 
 
1764
        # Add another, refresh, and make sure number of rules goes to two
 
1765
        provider_fw1 = db.provider_fw_rule_create(admin_ctxt,
 
1766
                                                  {'protocol': 'udp',
 
1767
                                                   'cidr': '10.99.99.99/32',
 
1768
                                                   'from_port': 1,
 
1769
                                                   'to_port': 65535})
 
1770
        self.fw.refresh_provider_fw_rules()
 
1771
        rules = [rule for rule in self.fw.iptables.ipv4['filter'].rules
 
1772
                      if rule.chain == 'provider']
 
1773
        self.assertEqual(2, len(rules))
 
1774
 
 
1775
        # create the instance filter and make sure it has a jump rule
 
1776
        self.fw.prepare_instance_filter(instance_ref, network_info)
 
1777
        self.fw.apply_instance_filter(instance_ref, network_info)
 
1778
        inst_rules = [rule for rule in self.fw.iptables.ipv4['filter'].rules
 
1779
                           if rule.chain == chain_name]
 
1780
        jump_rules = [rule for rule in inst_rules if '-j' in rule.rule]
 
1781
        provjump_rules = []
 
1782
        # IptablesTable doesn't make rules unique internally
 
1783
        for rule in jump_rules:
 
1784
            if 'provider' in rule.rule and rule not in provjump_rules:
 
1785
                provjump_rules.append(rule)
 
1786
        self.assertEqual(1, len(provjump_rules))
 
1787
 
 
1788
        # remove a rule from the db, cast to compute to refresh rule
 
1789
        db.provider_fw_rule_destroy(admin_ctxt, provider_fw1['id'])
 
1790
        self.fw.refresh_provider_fw_rules()
 
1791
        rules = [rule for rule in self.fw.iptables.ipv4['filter'].rules
 
1792
                      if rule.chain == 'provider']
 
1793
        self.assertEqual(1, len(rules))
 
1794
 
 
1795
 
 
1796
class NWFilterTestCase(test.TestCase):
    """Tests for firewall.NWFilterFirewall with a fake libvirt connection."""

    def setUp(self):
        super(NWFilterTestCase, self).setUp()

        class Mock(object):
            pass

        self.user_id = 'fake'
        self.project_id = 'fake'
        self.context = context.RequestContext(self.user_id, self.project_id)

        # Bare attribute holder; tests attach the nwfilter* callables they
        # need before exercising the firewall.
        self.fake_libvirt_connection = Mock()

        self.fw = firewall.NWFilterFirewall(
                                         lambda: self.fake_libvirt_connection)

    def test_cidr_rule_nwfilter_xml(self):
        cloud_controller = cloud.CloudController()
        cloud_controller.create_security_group(self.context,
                                               'testgroup',
                                               'test group description')
        cloud_controller.authorize_security_group_ingress(self.context,
                                                          'testgroup',
                                                          from_port='80',
                                                          to_port='81',
                                                          ip_protocol='tcp',
                                                          cidr_ip='0.0.0.0/0')

        security_group = db.security_group_get_by_name(self.context,
                                                       'fake',
                                                       'testgroup')
        # The lookup result used to be discarded; check it so the test
        # actually asserts that the group was created.
        self.assertEqual(security_group['name'], 'testgroup')
        self.teardown_security_group()

    def teardown_security_group(self):
        cloud_controller = cloud.CloudController()
        cloud_controller.delete_security_group(self.context, 'testgroup')

    def setup_and_return_security_group(self):
        cloud_controller = cloud.CloudController()
        cloud_controller.create_security_group(self.context,
                                               'testgroup',
                                               'test group description')
        cloud_controller.authorize_security_group_ingress(self.context,
                                                          'testgroup',
                                                          from_port='80',
                                                          to_port='81',
                                                          ip_protocol='tcp',
                                                          cidr_ip='0.0.0.0/0')

        return db.security_group_get_by_name(self.context, 'fake', 'testgroup')

    def _create_instance(self):
        return db.instance_create(self.context,
                                  {'user_id': 'fake',
                                   'project_id': 'fake',
                                   'instance_type_id': 1})

    def _create_instance_type(self, params=None):
        """Create a test instance type and return its id.

        :param params: optional dict of column values that override the
                       m1.small-style defaults set below.
        """
        if not params:
            params = {}

        # Named so it does not shadow the module-level ``context`` import.
        elevated_ctxt = self.context.elevated()
        inst = {}
        inst['name'] = 'm1.small'
        inst['memory_mb'] = '1024'
        inst['vcpus'] = '1'
        inst['root_gb'] = '10'
        inst['ephemeral_gb'] = '20'
        inst['flavorid'] = '1'
        inst['swap'] = '2048'
        inst['rxtx_factor'] = 1
        inst.update(params)
        return db.instance_type_create(elevated_ctxt, inst)['id']

    def test_creates_base_rule_first(self):
        # These come pre-defined by libvirt
        self.defined_filters = ['no-mac-spoofing',
                                'no-ip-spoofing',
                                'no-arp-spoofing',
                                'allow-dhcp-server']

        self.recursive_depends = {}
        for f in self.defined_filters:
            self.recursive_depends[f] = []

        def _filterDefineXMLMock(xml):
            dom = minidom.parseString(xml)
            name = dom.firstChild.getAttribute('name')
            self.recursive_depends[name] = []
            for f in dom.getElementsByTagName('filterref'):
                ref = f.getAttribute('filter')
                self.assertTrue(ref in self.defined_filters,
                                ('%s referenced filter that does ' +
                                'not yet exist: %s') % (name, ref))
                dependencies = [ref] + self.recursive_depends[ref]
                self.recursive_depends[name] += dependencies

            self.defined_filters.append(name)
            return True

        self.fake_libvirt_connection.nwfilterDefineXML = _filterDefineXMLMock

        instance_ref = self._create_instance()
        inst_id = instance_ref['id']
        inst_uuid = instance_ref['uuid']

        def _ensure_all_called(mac):
            # Strip colons from the MAC.  str.replace works for both byte
            # and unicode strings, unlike the Python 2-only two-argument
            # str.translate(None, ':').
            instance_filter = 'nova-instance-%s-%s' % (instance_ref['name'],
                                                       mac.replace(':', ''))
            for required in ['allow-dhcp-server',
                             'no-arp-spoofing', 'no-ip-spoofing',
                             'no-mac-spoofing']:
                self.assertTrue(required in
                                self.recursive_depends[instance_filter],
                                "Instance's filter does not include %s" %
                                required)

        self.security_group = self.setup_and_return_security_group()

        db.instance_add_security_group(self.context, inst_uuid,
                                       self.security_group.id)
        instance = db.instance_get(self.context, inst_id)

        network_info = _fake_network_info(self.stubs, 1)
        # since there is one (network_info) there is one vif
        # pass this vif's mac to _ensure_all_called()
        # to set the instance_filter properly
        mac = network_info[0][1]['mac']

        self.fw.setup_basic_filtering(instance, network_info)
        _ensure_all_called(mac)
        db.instance_remove_security_group(self.context, inst_uuid,
                                          self.security_group.id)
        self.teardown_security_group()
        db.instance_destroy(context.get_admin_context(), instance_ref['id'])

    def test_unfilter_instance_undefines_nwfilters(self):
        admin_ctxt = context.get_admin_context()

        fakefilter = NWFilterFakes()
        self.fw._conn.nwfilterDefineXML = fakefilter.filterDefineXMLMock
        self.fw._conn.nwfilterLookupByName = fakefilter.nwfilterLookupByName

        instance_ref = self._create_instance()
        inst_id = instance_ref['id']
        inst_uuid = instance_ref['uuid']

        self.security_group = self.setup_and_return_security_group()

        db.instance_add_security_group(self.context, inst_uuid,
                                       self.security_group.id)

        instance = db.instance_get(self.context, inst_id)

        network_info = _fake_network_info(self.stubs, 1)
        self.fw.setup_basic_filtering(instance, network_info)
        original_filter_count = len(fakefilter.filters)
        self.fw.unfilter_instance(instance, network_info)
        self.assertEqual(original_filter_count - len(fakefilter.filters), 1)

        db.instance_destroy(admin_ctxt, instance_ref['id'])
 
1958
 
 
1959
 
 
1960
class LibvirtUtilsTestCase(test.TestCase):
 
1961
    def test_get_iscsi_initiator(self):
 
1962
        self.mox.StubOutWithMock(utils, 'execute')
 
1963
        initiator = 'fake.initiator.iqn'
 
1964
        rval = ("junk\nInitiatorName=%s\njunk\n" % initiator, None)
 
1965
        utils.execute('cat', '/etc/iscsi/initiatorname.iscsi',
 
1966
                      run_as_root=True).AndReturn(rval)
 
1967
        # Start test
 
1968
        self.mox.ReplayAll()
 
1969
        result = libvirt_utils.get_iscsi_initiator()
 
1970
        self.assertEqual(initiator, result)
 
1971
 
 
1972
    def test_create_image(self):
 
1973
        self.mox.StubOutWithMock(utils, 'execute')
 
1974
        utils.execute('qemu-img', 'create', '-f', 'raw',
 
1975
                      '/some/path', '10G')
 
1976
        utils.execute('qemu-img', 'create', '-f', 'qcow2',
 
1977
                      '/some/stuff', '1234567891234')
 
1978
        # Start test
 
1979
        self.mox.ReplayAll()
 
1980
        libvirt_utils.create_image('raw', '/some/path', '10G')
 
1981
        libvirt_utils.create_image('qcow2', '/some/stuff', '1234567891234')
 
1982
 
 
1983
    def test_create_cow_image(self):
 
1984
        self.mox.StubOutWithMock(utils, 'execute')
 
1985
        utils.execute('qemu-img', 'create', '-f', 'qcow2',
 
1986
                      '-o', 'cluster_size=2M,backing_file=/some/path',
 
1987
                      '/the/new/cow')
 
1988
        # Start test
 
1989
        self.mox.ReplayAll()
 
1990
        libvirt_utils.create_cow_image('/some/path', '/the/new/cow')
 
1991
 
 
1992
    def test_get_disk_size(self):
 
1993
        self.mox.StubOutWithMock(utils, 'execute')
 
1994
        utils.execute('qemu-img',
 
1995
                      'info',
 
1996
                      '/some/path').AndReturn(('''image: 00000001
 
1997
file format: raw
 
1998
virtual size: 4.4M (4592640 bytes)
 
1999
disk size: 4.4M''', ''))
 
2000
 
 
2001
        # Start test
 
2002
        self.mox.ReplayAll()
 
2003
        self.assertEquals(libvirt_utils.get_disk_size('/some/path'), 4592640)
 
2004
 
 
2005
    def test_copy_image(self):
 
2006
        dst_fd, dst_path = tempfile.mkstemp()
 
2007
        try:
 
2008
            os.close(dst_fd)
 
2009
 
 
2010
            src_fd, src_path = tempfile.mkstemp()
 
2011
            try:
 
2012
                with os.fdopen(src_fd, 'w') as fp:
 
2013
                    fp.write('canary')
 
2014
 
 
2015
                libvirt_utils.copy_image(src_path, dst_path)
 
2016
                with open(dst_path, 'r') as fp:
 
2017
                    self.assertEquals(fp.read(), 'canary')
 
2018
            finally:
 
2019
                os.unlink(src_path)
 
2020
        finally:
 
2021
            os.unlink(dst_path)
 
2022
 
 
2023
    def test_mkfs(self):
 
2024
        self.mox.StubOutWithMock(utils, 'execute')
 
2025
        utils.execute('mkfs', '-t', 'ext4', '/my/block/dev')
 
2026
        utils.execute('mkswap', '/my/swap/block/dev')
 
2027
        self.mox.ReplayAll()
 
2028
 
 
2029
        libvirt_utils.mkfs('ext4', '/my/block/dev')
 
2030
        libvirt_utils.mkfs('swap', '/my/swap/block/dev')
 
2031
 
 
2032
    def test_ensure_tree(self):
 
2033
        with utils.tempdir() as tmpdir:
 
2034
            testdir = '%s/foo/bar/baz' % (tmpdir,)
 
2035
            libvirt_utils.ensure_tree(testdir)
 
2036
            self.assertTrue(os.path.isdir(testdir))
 
2037
 
 
2038
    def test_write_to_file(self):
 
2039
        dst_fd, dst_path = tempfile.mkstemp()
 
2040
        try:
 
2041
            os.close(dst_fd)
 
2042
 
 
2043
            libvirt_utils.write_to_file(dst_path, 'hello')
 
2044
            with open(dst_path, 'r') as fp:
 
2045
                self.assertEquals(fp.read(), 'hello')
 
2046
        finally:
 
2047
            os.unlink(dst_path)
 
2048
 
 
2049
    def test_write_to_file_with_umask(self):
 
2050
        dst_fd, dst_path = tempfile.mkstemp()
 
2051
        try:
 
2052
            os.close(dst_fd)
 
2053
            os.unlink(dst_path)
 
2054
 
 
2055
            libvirt_utils.write_to_file(dst_path, 'hello', umask=0277)
 
2056
            with open(dst_path, 'r') as fp:
 
2057
                self.assertEquals(fp.read(), 'hello')
 
2058
            mode = os.stat(dst_path).st_mode
 
2059
            self.assertEquals(mode & 0277, 0)
 
2060
        finally:
 
2061
            os.unlink(dst_path)
 
2062
 
 
2063
    def test_chown(self):
 
2064
        self.mox.StubOutWithMock(utils, 'execute')
 
2065
        utils.execute('chown', 'soren', '/some/path', run_as_root=True)
 
2066
        self.mox.ReplayAll()
 
2067
        libvirt_utils.chown('/some/path', 'soren')
 
2068
 
 
2069
    def test_extract_snapshot(self):
 
2070
        self.mox.StubOutWithMock(utils, 'execute')
 
2071
        utils.execute('qemu-img', 'convert', '-f', 'qcow2', '-O', 'raw',
 
2072
                      '-s', 'snap1', '/path/to/disk/image', '/extracted/snap')
 
2073
 
 
2074
        # Start test
 
2075
        self.mox.ReplayAll()
 
2076
        libvirt_utils.extract_snapshot('/path/to/disk/image', 'qcow2',
 
2077
                                       'snap1', '/extracted/snap', 'raw')
 
2078
 
 
2079
    def test_load_file(self):
 
2080
        dst_fd, dst_path = tempfile.mkstemp()
 
2081
        try:
 
2082
            os.close(dst_fd)
 
2083
 
 
2084
            # We have a test for write_to_file. If that is sound, this suffices
 
2085
            libvirt_utils.write_to_file(dst_path, 'hello')
 
2086
            self.assertEquals(libvirt_utils.load_file(dst_path), 'hello')
 
2087
        finally:
 
2088
            os.unlink(dst_path)
 
2089
 
 
2090
    def test_file_open(self):
 
2091
        dst_fd, dst_path = tempfile.mkstemp()
 
2092
        try:
 
2093
            os.close(dst_fd)
 
2094
 
 
2095
            # We have a test for write_to_file. If that is sound, this suffices
 
2096
            libvirt_utils.write_to_file(dst_path, 'hello')
 
2097
            with libvirt_utils.file_open(dst_path, 'r') as fp:
 
2098
                self.assertEquals(fp.read(), 'hello')
 
2099
        finally:
 
2100
            os.unlink(dst_path)
 
2101
 
 
2102
    def test_get_fs_info(self):
 
2103
 
 
2104
        class FakeStatResult(object):
 
2105
 
 
2106
            def __init__(self):
 
2107
                self.f_bsize = 4096
 
2108
                self.f_frsize = 4096
 
2109
                self.f_blocks = 2000
 
2110
                self.f_bfree = 1000
 
2111
                self.f_bavail = 900
 
2112
                self.f_files = 2000
 
2113
                self.f_ffree = 1000
 
2114
                self.f_favail = 900
 
2115
                self.f_flag = 4096
 
2116
                self.f_namemax = 255
 
2117
 
 
2118
        self.path = None
 
2119
 
 
2120
        def fake_statvfs(path):
 
2121
            self.path = path
 
2122
            return FakeStatResult()
 
2123
 
 
2124
        self.stubs.Set(os, 'statvfs', fake_statvfs)
 
2125
 
 
2126
        fs_info = libvirt_utils.get_fs_info('/some/file/path')
 
2127
        self.assertEquals('/some/file/path', self.path)
 
2128
        self.assertEquals(8192000, fs_info['total'])
 
2129
        self.assertEquals(3686400, fs_info['free'])
 
2130
        self.assertEquals(4096000, fs_info['used'])
 
2131
 
 
2132
    def test_fetch_image(self):
 
2133
        self.mox.StubOutWithMock(images, 'fetch_to_raw')
 
2134
 
 
2135
        context = 'opaque context'
 
2136
        target = '/tmp/targetfile'
 
2137
        image_id = '4'
 
2138
        user_id = 'fake'
 
2139
        project_id = 'fake'
 
2140
        images.fetch_to_raw(context, image_id, target, user_id, project_id)
 
2141
 
 
2142
        self.mox.ReplayAll()
 
2143
        libvirt_utils.fetch_image(context, target, image_id,
 
2144
                                  user_id, project_id)
 
2145
 
 
2146
 
 
2147
class LibvirtConnectionTestCase(test.TestCase):
 
2148
    """Test for nova.virt.libvirt.connection.LibvirtConnection."""
 
2149
    def setUp(self):
 
2150
        super(LibvirtConnectionTestCase, self).setUp()
 
2151
 
 
2152
        self.libvirtconnection = connection.LibvirtConnection(read_only=True)
 
2153
 
 
2154
        self.temp_path = os.path.join(flags.FLAGS.instances_path,
 
2155
                                      'instance-00000001/', '')
 
2156
        try:
 
2157
            os.makedirs(self.temp_path)
 
2158
        except Exception:
 
2159
            print 'testcase init error'
 
2160
            pass
 
2161
 
 
2162
    def tearDown(self):
 
2163
        super(LibvirtConnectionTestCase, self).tearDown()
 
2164
 
 
2165
        try:
 
2166
            shutil.rmtree(flags.FLAGS.instances_path)
 
2167
        except Exception:
 
2168
            pass
 
2169
 
 
2170
    def _create_instance(self, params=None):
 
2171
        """Create a test instance"""
 
2172
        if not params:
 
2173
            params = {}
 
2174
 
 
2175
        inst = {}
 
2176
        inst['image_ref'] = '1'
 
2177
        inst['reservation_id'] = 'r-fakeres'
 
2178
        inst['launch_time'] = '10'
 
2179
        inst['user_id'] = 'fake'
 
2180
        inst['project_id'] = 'fake'
 
2181
        type_id = instance_types.get_instance_type_by_name('m1.tiny')['id']
 
2182
        inst['instance_type_id'] = type_id
 
2183
        inst['ami_launch_index'] = 0
 
2184
        inst['host'] = 'host1'
 
2185
        inst['root_gb'] = 10
 
2186
        inst['ephemeral_gb'] = 20
 
2187
        inst['config_drive'] = 1
 
2188
        inst['kernel_id'] = 2
 
2189
        inst['ramdisk_id'] = 3
 
2190
        inst['config_drive_id'] = 1
 
2191
        inst['key_data'] = 'ABCDEFG'
 
2192
 
 
2193
        inst.update(params)
 
2194
        return db.instance_create(context.get_admin_context(), inst)
 
2195
 
 
2196
    def test_migrate_disk_and_power_off_exception(self):
 
2197
        """Test for nova.virt.libvirt.connection.LivirtConnection
 
2198
        .migrate_disk_and_power_off. """
 
2199
 
 
2200
        self.counter = 0
 
2201
 
 
2202
        def fake_get_instance_disk_info(instance):
 
2203
            return '[]'
 
2204
 
 
2205
        def fake_destroy(instance, network_info, cleanup=True):
 
2206
            pass
 
2207
 
 
2208
        def fake_get_host_ip_addr():
 
2209
            return '10.0.0.1'
 
2210
 
 
2211
        def fake_execute(*args, **kwargs):
 
2212
            self.counter += 1
 
2213
            if self.counter == 1:
 
2214
                assert False, "intentional failure"
 
2215
 
 
2216
        def fake_os_path_exists(path):
 
2217
            return True
 
2218
 
 
2219
        self.stubs.Set(self.libvirtconnection, 'get_instance_disk_info',
 
2220
                       fake_get_instance_disk_info)
 
2221
        self.stubs.Set(self.libvirtconnection, '_destroy', fake_destroy)
 
2222
        self.stubs.Set(self.libvirtconnection, 'get_host_ip_addr',
 
2223
                       fake_get_host_ip_addr)
 
2224
        self.stubs.Set(utils, 'execute', fake_execute)
 
2225
        self.stubs.Set(os.path, 'exists', fake_os_path_exists)
 
2226
 
 
2227
        ins_ref = self._create_instance()
 
2228
 
 
2229
        self.assertRaises(AssertionError,
 
2230
                          self.libvirtconnection.migrate_disk_and_power_off,
 
2231
                          None, ins_ref, '10.0.0.2', None, None)
 
2232
 
 
2233
    def test_migrate_disk_and_power_off(self):
 
2234
        """Test for nova.virt.libvirt.connection.LivirtConnection
 
2235
        .migrate_disk_and_power_off. """
 
2236
 
 
2237
        disk_info = [{'type': 'qcow2', 'path': '/test/disk',
 
2238
                      'virt_disk_size': '10737418240',
 
2239
                      'backing_file': '/base/disk',
 
2240
                      'disk_size':'83886080'},
 
2241
                     {'type': 'raw', 'path': '/test/disk.local',
 
2242
                      'virt_disk_size': '10737418240',
 
2243
                      'backing_file': '/base/disk.local',
 
2244
                      'disk_size':'83886080'}]
 
2245
        disk_info_text = utils.dumps(disk_info)
 
2246
 
 
2247
        def fake_get_instance_disk_info(instance):
 
2248
            return disk_info_text
 
2249
 
 
2250
        def fake_destroy(instance, network_info, cleanup=True):
 
2251
            pass
 
2252
 
 
2253
        def fake_get_host_ip_addr():
 
2254
            return '10.0.0.1'
 
2255
 
 
2256
        def fake_execute(*args, **kwargs):
 
2257
            pass
 
2258
 
 
2259
        self.stubs.Set(self.libvirtconnection, 'get_instance_disk_info',
 
2260
                       fake_get_instance_disk_info)
 
2261
        self.stubs.Set(self.libvirtconnection, '_destroy', fake_destroy)
 
2262
        self.stubs.Set(self.libvirtconnection, 'get_host_ip_addr',
 
2263
                       fake_get_host_ip_addr)
 
2264
        self.stubs.Set(utils, 'execute', fake_execute)
 
2265
 
 
2266
        ins_ref = self._create_instance()
 
2267
        """ dest is different host case """
 
2268
        out = self.libvirtconnection.migrate_disk_and_power_off(
 
2269
               None, ins_ref, '10.0.0.2', None, None)
 
2270
        self.assertEquals(out, disk_info_text)
 
2271
 
 
2272
        """ dest is same host case """
 
2273
        out = self.libvirtconnection.migrate_disk_and_power_off(
 
2274
               None, ins_ref, '10.0.0.1', None, None)
 
2275
        self.assertEquals(out, disk_info_text)
 
2276
 
 
2277
    def test_wait_for_running(self):
 
2278
        """Test for nova.virt.libvirt.connection.LivirtConnection
 
2279
        ._wait_for_running. """
 
2280
 
 
2281
        def fake_get_info(instance):
 
2282
            if instance['name'] == "not_found":
 
2283
                raise exception.NotFound
 
2284
            elif instance['name'] == "running":
 
2285
                return {'state': power_state.RUNNING}
 
2286
            else:
 
2287
                return {'state': power_state.SHUTOFF}
 
2288
 
 
2289
        self.stubs.Set(self.libvirtconnection, 'get_info',
 
2290
                       fake_get_info)
 
2291
 
 
2292
        """ instance not found case """
 
2293
        self.assertRaises(utils.LoopingCallDone,
 
2294
                self.libvirtconnection._wait_for_running,
 
2295
                    {'name': 'not_found',
 
2296
                     'uuid': 'not_found_uuid'})
 
2297
 
 
2298
        """ instance is running case """
 
2299
        self.assertRaises(utils.LoopingCallDone,
 
2300
                self.libvirtconnection._wait_for_running,
 
2301
                    {'name': 'running',
 
2302
                     'uuid': 'running_uuid'})
 
2303
 
 
2304
        """ else case """
 
2305
        self.libvirtconnection._wait_for_running({'name': 'else',
 
2306
                                                  'uuid': 'other_uuid'})
 
2307
 
 
2308
    def test_finish_migration(self):
 
2309
        """Test for nova.virt.libvirt.connection.LivirtConnection
 
2310
        .finish_migration. """
 
2311
 
 
2312
        disk_info = [{'type': 'qcow2', 'path': '/test/disk',
 
2313
                      'local_gb': 10, 'backing_file': '/base/disk'},
 
2314
                     {'type': 'raw', 'path': '/test/disk.local',
 
2315
                      'local_gb': 10, 'backing_file': '/base/disk.local'}]
 
2316
        disk_info_text = utils.dumps(disk_info)
 
2317
 
 
2318
        def fake_extend(path, size):
 
2319
            pass
 
2320
 
 
2321
        def fake_to_xml(instance, network_info):
 
2322
            return ""
 
2323
 
 
2324
        def fake_plug_vifs(instance, network_info):
 
2325
            pass
 
2326
 
 
2327
        def fake_create_image(context, inst, libvirt_xml, suffix='',
 
2328
                      disk_images=None, network_info=None,
 
2329
                      block_device_info=None):
 
2330
            pass
 
2331
 
 
2332
        def fake_create_new_domain(xml):
 
2333
            return None
 
2334
 
 
2335
        def fake_execute(*args, **kwargs):
 
2336
            pass
 
2337
 
 
2338
        self.flags(use_cow_images=True)
 
2339
        self.stubs.Set(connection.disk, 'extend', fake_extend)
 
2340
        self.stubs.Set(self.libvirtconnection, 'to_xml', fake_to_xml)
 
2341
        self.stubs.Set(self.libvirtconnection, 'plug_vifs', fake_plug_vifs)
 
2342
        self.stubs.Set(self.libvirtconnection, '_create_image',
 
2343
                       fake_create_image)
 
2344
        self.stubs.Set(self.libvirtconnection, '_create_new_domain',
 
2345
                       fake_create_new_domain)
 
2346
        self.stubs.Set(utils, 'execute', fake_execute)
 
2347
        fw = base_firewall.NoopFirewallDriver()
 
2348
        self.stubs.Set(self.libvirtconnection, 'firewall_driver', fw)
 
2349
 
 
2350
        ins_ref = self._create_instance()
 
2351
 
 
2352
        ref = self.libvirtconnection.finish_migration(
 
2353
                      context.get_admin_context(), None, ins_ref,
 
2354
                      disk_info_text, None, None, None)
 
2355
        self.assertTrue(isinstance(ref, eventlet.event.Event))
 
2356
 
 
2357
    def test_finish_revert_migration(self):
 
2358
        """Test for nova.virt.libvirt.connection.LivirtConnection
 
2359
        .finish_revert_migration. """
 
2360
 
 
2361
        def fake_execute(*args, **kwargs):
 
2362
            pass
 
2363
 
 
2364
        def fake_plug_vifs(instance, network_info):
 
2365
            pass
 
2366
 
 
2367
        def fake_create_new_domain(xml):
 
2368
            return None
 
2369
 
 
2370
        self.stubs.Set(self.libvirtconnection, 'plug_vifs', fake_plug_vifs)
 
2371
        self.stubs.Set(utils, 'execute', fake_execute)
 
2372
        fw = base_firewall.NoopFirewallDriver()
 
2373
        self.stubs.Set(self.libvirtconnection, 'firewall_driver', fw)
 
2374
        self.stubs.Set(self.libvirtconnection, '_create_new_domain',
 
2375
                       fake_create_new_domain)
 
2376
 
 
2377
        ins_ref = self._create_instance()
 
2378
        libvirt_xml_path = os.path.join(flags.FLAGS.instances_path,
 
2379
                                        ins_ref['name'], 'libvirt.xml')
 
2380
        f = open(libvirt_xml_path, 'w')
 
2381
        f.close()
 
2382
 
 
2383
        ref = self.libvirtconnection.finish_revert_migration(ins_ref, None)
 
2384
        self.assertTrue(isinstance(ref, eventlet.event.Event))