~cloud-init-dev/cloud-init/trunk

« back to all changes in this revision

Viewing changes to tests/unittests/test_datasource/test_azure.py

  • Committer: Scott Moser
  • Date: 2016-08-10 15:06:15 UTC
  • Revision ID: smoser@ubuntu.com-20160810150615-ma2fv107w3suy1ma
README: Mention move of revision control to git.

cloud-init development has moved its revision control to git.
It is available at 
  https://code.launchpad.net/cloud-init

Clone with 
  git clone https://git.launchpad.net/cloud-init
or
  git clone git+ssh://git.launchpad.net/cloud-init

For more information see
  https://git.launchpad.net/cloud-init/tree/HACKING.rst

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
from cloudinit import helpers
2
 
from cloudinit.util import b64e, decode_binary, load_file
3
 
from cloudinit.sources import DataSourceAzure
4
 
 
5
 
from ..helpers import TestCase, populate_dir, mock, ExitStack, PY26, SkipTest
6
 
 
7
 
import crypt
8
 
import os
9
 
import shutil
10
 
import stat
11
 
import tempfile
12
 
import xml.etree.ElementTree as ET
13
 
import yaml
14
 
 
15
 
 
16
 
def construct_valid_ovf_env(data=None, pubkeys=None, userdata=None):
17
 
    if data is None:
18
 
        data = {'HostName': 'FOOHOST'}
19
 
    if pubkeys is None:
20
 
        pubkeys = {}
21
 
 
22
 
    content = """<?xml version="1.0" encoding="utf-8"?>
23
 
<Environment xmlns="http://schemas.dmtf.org/ovf/environment/1"
24
 
 xmlns:oe="http://schemas.dmtf.org/ovf/environment/1"
25
 
 xmlns:wa="http://schemas.microsoft.com/windowsazure"
26
 
 xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
27
 
 
28
 
 <wa:ProvisioningSection><wa:Version>1.0</wa:Version>
29
 
 <LinuxProvisioningConfigurationSet
30
 
  xmlns="http://schemas.microsoft.com/windowsazure"
31
 
  xmlns:i="http://www.w3.org/2001/XMLSchema-instance">
32
 
  <ConfigurationSetType>LinuxProvisioningConfiguration</ConfigurationSetType>
33
 
    """
34
 
    for key, dval in data.items():
35
 
        if isinstance(dval, dict):
36
 
            val = dval.get('text')
37
 
            attrs = ' ' + ' '.join(["%s='%s'" % (k, v) for k, v in dval.items()
38
 
                                    if k != 'text'])
39
 
        else:
40
 
            val = dval
41
 
            attrs = ""
42
 
        content += "<%s%s>%s</%s>\n" % (key, attrs, val, key)
43
 
 
44
 
    if userdata:
45
 
        content += "<UserData>%s</UserData>\n" % (b64e(userdata))
46
 
 
47
 
    if pubkeys:
48
 
        content += "<SSH><PublicKeys>\n"
49
 
        for fp, path, value in pubkeys:
50
 
            content += " <PublicKey>"
51
 
            if fp and path:
52
 
                content += ("<Fingerprint>%s</Fingerprint><Path>%s</Path>" %
53
 
                            (fp, path))
54
 
            if value:
55
 
                content += "<Value>%s</Value>" % value
56
 
            content += "</PublicKey>\n"
57
 
        content += "</PublicKeys></SSH>"
58
 
    content += """
59
 
 </LinuxProvisioningConfigurationSet>
60
 
 </wa:ProvisioningSection>
61
 
 <wa:PlatformSettingsSection><wa:Version>1.0</wa:Version>
62
 
 <PlatformSettings xmlns="http://schemas.microsoft.com/windowsazure"
63
 
  xmlns:i="http://www.w3.org/2001/XMLSchema-instance">
64
 
 <KmsServerHostname>kms.core.windows.net</KmsServerHostname>
65
 
 <ProvisionGuestAgent>false</ProvisionGuestAgent>
66
 
 <GuestAgentPackageName i:nil="true" />
67
 
 </PlatformSettings></wa:PlatformSettingsSection>
68
 
</Environment>
69
 
    """
70
 
 
71
 
    return content
72
 
 
73
 
 
74
 
class TestAzureDataSource(TestCase):
75
 
 
76
 
    def setUp(self):
77
 
        super(TestAzureDataSource, self).setUp()
78
 
        if PY26:
79
 
            raise SkipTest("Does not work on python 2.6")
80
 
        self.tmp = tempfile.mkdtemp()
81
 
        self.addCleanup(shutil.rmtree, self.tmp)
82
 
 
83
 
        # patch cloud_dir, so our 'seed_dir' is guaranteed empty
84
 
        self.paths = helpers.Paths({'cloud_dir': self.tmp})
85
 
        self.waagent_d = os.path.join(self.tmp, 'var', 'lib', 'waagent')
86
 
 
87
 
        self.patches = ExitStack()
88
 
        self.addCleanup(self.patches.close)
89
 
 
90
 
        super(TestAzureDataSource, self).setUp()
91
 
 
92
 
    def apply_patches(self, patches):
93
 
        for module, name, new in patches:
94
 
            self.patches.enter_context(mock.patch.object(module, name, new))
95
 
 
96
 
    def _get_ds(self, data):
97
 
 
98
 
        def dsdevs():
99
 
            return data.get('dsdevs', [])
100
 
 
101
 
        def _invoke_agent(cmd):
102
 
            data['agent_invoked'] = cmd
103
 
 
104
 
        def _wait_for_files(flist, _maxwait=None, _naplen=None):
105
 
            data['waited'] = flist
106
 
            return []
107
 
 
108
 
        def _pubkeys_from_crt_files(flist):
109
 
            data['pubkey_files'] = flist
110
 
            return ["pubkey_from: %s" % f for f in flist]
111
 
 
112
 
        if data.get('ovfcontent') is not None:
113
 
            populate_dir(os.path.join(self.paths.seed_dir, "azure"),
114
 
                         {'ovf-env.xml': data['ovfcontent']})
115
 
 
116
 
        mod = DataSourceAzure
117
 
        mod.BUILTIN_DS_CONFIG['data_dir'] = self.waagent_d
118
 
 
119
 
        self.get_metadata_from_fabric = mock.MagicMock(return_value={
120
 
            'public-keys': [],
121
 
        })
122
 
 
123
 
        self.instance_id = 'test-instance-id'
124
 
 
125
 
        self.apply_patches([
126
 
            (mod, 'list_possible_azure_ds_devs', dsdevs),
127
 
            (mod, 'invoke_agent', _invoke_agent),
128
 
            (mod, 'wait_for_files', _wait_for_files),
129
 
            (mod, 'pubkeys_from_crt_files', _pubkeys_from_crt_files),
130
 
            (mod, 'perform_hostname_bounce', mock.MagicMock()),
131
 
            (mod, 'get_hostname', mock.MagicMock()),
132
 
            (mod, 'set_hostname', mock.MagicMock()),
133
 
            (mod, 'get_metadata_from_fabric', self.get_metadata_from_fabric),
134
 
            (mod.util, 'read_dmi_data', mock.MagicMock(
135
 
                return_value=self.instance_id)),
136
 
        ])
137
 
 
138
 
        dsrc = mod.DataSourceAzureNet(
139
 
            data.get('sys_cfg', {}), distro=None, paths=self.paths)
140
 
 
141
 
        return dsrc
142
 
 
143
 
    def xml_equals(self, oxml, nxml):
144
 
        """Compare two sets of XML to make sure they are equal"""
145
 
 
146
 
        def create_tag_index(xml):
147
 
            et = ET.fromstring(xml)
148
 
            ret = {}
149
 
            for x in et.iter():
150
 
                ret[x.tag] = x
151
 
            return ret
152
 
 
153
 
        def tags_exists(x, y):
154
 
            for tag in x.keys():
155
 
                self.assertIn(tag, y)
156
 
            for tag in y.keys():
157
 
                self.assertIn(tag, x)
158
 
 
159
 
        def tags_equal(x, y):
160
 
            for x_tag, x_val in x.items():
161
 
                y_val = y.get(x_val.tag)
162
 
                self.assertEqual(x_val.text, y_val.text)
163
 
 
164
 
        old_cnt = create_tag_index(oxml)
165
 
        new_cnt = create_tag_index(nxml)
166
 
        tags_exists(old_cnt, new_cnt)
167
 
        tags_equal(old_cnt, new_cnt)
168
 
 
169
 
    def xml_notequals(self, oxml, nxml):
170
 
        try:
171
 
            self.xml_equals(oxml, nxml)
172
 
        except AssertionError:
173
 
            return
174
 
        raise AssertionError("XML is the same")
175
 
 
176
 
    def test_basic_seed_dir(self):
177
 
        odata = {'HostName': "myhost", 'UserName': "myuser"}
178
 
        data = {'ovfcontent': construct_valid_ovf_env(data=odata),
179
 
                'sys_cfg': {}}
180
 
 
181
 
        dsrc = self._get_ds(data)
182
 
        ret = dsrc.get_data()
183
 
        self.assertTrue(ret)
184
 
        self.assertEqual(dsrc.userdata_raw, "")
185
 
        self.assertEqual(dsrc.metadata['local-hostname'], odata['HostName'])
186
 
        self.assertTrue(os.path.isfile(
187
 
            os.path.join(self.waagent_d, 'ovf-env.xml')))
188
 
 
189
 
    def test_waagent_d_has_0700_perms(self):
190
 
        # we expect /var/lib/waagent to be created 0700
191
 
        dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()})
192
 
        ret = dsrc.get_data()
193
 
        self.assertTrue(ret)
194
 
        self.assertTrue(os.path.isdir(self.waagent_d))
195
 
        self.assertEqual(stat.S_IMODE(os.stat(self.waagent_d).st_mode), 0o700)
196
 
 
197
 
    def test_user_cfg_set_agent_command_plain(self):
198
 
        # set dscfg in via plaintext
199
 
        # we must have friendly-to-xml formatted plaintext in yaml_cfg
200
 
        # not all plaintext is expected to work.
201
 
        yaml_cfg = "{agent_command: my_command}\n"
202
 
        cfg = yaml.safe_load(yaml_cfg)
203
 
        odata = {'HostName': "myhost", 'UserName': "myuser",
204
 
                 'dscfg': {'text': yaml_cfg, 'encoding': 'plain'}}
205
 
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}
206
 
 
207
 
        dsrc = self._get_ds(data)
208
 
        ret = dsrc.get_data()
209
 
        self.assertTrue(ret)
210
 
        self.assertEqual(data['agent_invoked'], cfg['agent_command'])
211
 
 
212
 
    def test_user_cfg_set_agent_command(self):
213
 
        # set dscfg in via base64 encoded yaml
214
 
        cfg = {'agent_command': "my_command"}
215
 
        odata = {'HostName': "myhost", 'UserName': "myuser",
216
 
                 'dscfg': {'text': b64e(yaml.dump(cfg)),
217
 
                           'encoding': 'base64'}}
218
 
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}
219
 
 
220
 
        dsrc = self._get_ds(data)
221
 
        ret = dsrc.get_data()
222
 
        self.assertTrue(ret)
223
 
        self.assertEqual(data['agent_invoked'], cfg['agent_command'])
224
 
 
225
 
    def test_sys_cfg_set_agent_command(self):
226
 
        sys_cfg = {'datasource': {'Azure': {'agent_command': '_COMMAND'}}}
227
 
        data = {'ovfcontent': construct_valid_ovf_env(data={}),
228
 
                'sys_cfg': sys_cfg}
229
 
 
230
 
        dsrc = self._get_ds(data)
231
 
        ret = dsrc.get_data()
232
 
        self.assertTrue(ret)
233
 
        self.assertEqual(data['agent_invoked'], '_COMMAND')
234
 
 
235
 
    def test_username_used(self):
236
 
        odata = {'HostName': "myhost", 'UserName': "myuser"}
237
 
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}
238
 
 
239
 
        dsrc = self._get_ds(data)
240
 
        ret = dsrc.get_data()
241
 
        self.assertTrue(ret)
242
 
        self.assertEqual(dsrc.cfg['system_info']['default_user']['name'],
243
 
                         "myuser")
244
 
 
245
 
    def test_password_given(self):
246
 
        odata = {'HostName': "myhost", 'UserName': "myuser",
247
 
                 'UserPassword': "mypass"}
248
 
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}
249
 
 
250
 
        dsrc = self._get_ds(data)
251
 
        ret = dsrc.get_data()
252
 
        self.assertTrue(ret)
253
 
        self.assertTrue('default_user' in dsrc.cfg['system_info'])
254
 
        defuser = dsrc.cfg['system_info']['default_user']
255
 
 
256
 
        # default user should be updated username and should not be locked.
257
 
        self.assertEqual(defuser['name'], odata['UserName'])
258
 
        self.assertFalse(defuser['lock_passwd'])
259
 
        # passwd is crypt formated string $id$salt$encrypted
260
 
        # encrypting plaintext with salt value of everything up to final '$'
261
 
        # should equal that after the '$'
262
 
        pos = defuser['passwd'].rfind("$") + 1
263
 
        self.assertEqual(defuser['passwd'],
264
 
                         crypt.crypt(odata['UserPassword'],
265
 
                                     defuser['passwd'][0:pos]))
266
 
 
267
 
    def test_userdata_plain(self):
268
 
        mydata = "FOOBAR"
269
 
        odata = {'UserData': {'text': mydata, 'encoding': 'plain'}}
270
 
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}
271
 
 
272
 
        dsrc = self._get_ds(data)
273
 
        ret = dsrc.get_data()
274
 
        self.assertTrue(ret)
275
 
        self.assertEqual(decode_binary(dsrc.userdata_raw), mydata)
276
 
 
277
 
    def test_userdata_found(self):
278
 
        mydata = "FOOBAR"
279
 
        odata = {'UserData': {'text': b64e(mydata), 'encoding': 'base64'}}
280
 
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}
281
 
 
282
 
        dsrc = self._get_ds(data)
283
 
        ret = dsrc.get_data()
284
 
        self.assertTrue(ret)
285
 
        self.assertEqual(dsrc.userdata_raw, mydata.encode('utf-8'))
286
 
 
287
 
    def test_no_datasource_expected(self):
288
 
        # no source should be found if no seed_dir and no devs
289
 
        data = {}
290
 
        dsrc = self._get_ds({})
291
 
        ret = dsrc.get_data()
292
 
        self.assertFalse(ret)
293
 
        self.assertFalse('agent_invoked' in data)
294
 
 
295
 
    def test_cfg_has_pubkeys_fingerprint(self):
296
 
        odata = {'HostName': "myhost", 'UserName': "myuser"}
297
 
        mypklist = [{'fingerprint': 'fp1', 'path': 'path1', 'value': ''}]
298
 
        pubkeys = [(x['fingerprint'], x['path'], x['value']) for x in mypklist]
299
 
        data = {'ovfcontent': construct_valid_ovf_env(data=odata,
300
 
                                                      pubkeys=pubkeys)}
301
 
 
302
 
        dsrc = self._get_ds(data)
303
 
        ret = dsrc.get_data()
304
 
        self.assertTrue(ret)
305
 
        for mypk in mypklist:
306
 
            self.assertIn(mypk, dsrc.cfg['_pubkeys'])
307
 
            self.assertIn('pubkey_from', dsrc.metadata['public-keys'][-1])
308
 
 
309
 
    def test_cfg_has_pubkeys_value(self):
310
 
        # make sure that provided key is used over fingerprint
311
 
        odata = {'HostName': "myhost", 'UserName': "myuser"}
312
 
        mypklist = [{'fingerprint': 'fp1', 'path': 'path1', 'value': 'value1'}]
313
 
        pubkeys = [(x['fingerprint'], x['path'], x['value']) for x in mypklist]
314
 
        data = {'ovfcontent': construct_valid_ovf_env(data=odata,
315
 
                                                      pubkeys=pubkeys)}
316
 
 
317
 
        dsrc = self._get_ds(data)
318
 
        ret = dsrc.get_data()
319
 
        self.assertTrue(ret)
320
 
 
321
 
        for mypk in mypklist:
322
 
            self.assertIn(mypk, dsrc.cfg['_pubkeys'])
323
 
            self.assertIn(mypk['value'], dsrc.metadata['public-keys'])
324
 
 
325
 
    def test_cfg_has_no_fingerprint_has_value(self):
326
 
        # test value is used when fingerprint not provided
327
 
        odata = {'HostName': "myhost", 'UserName': "myuser"}
328
 
        mypklist = [{'fingerprint': None, 'path': 'path1', 'value': 'value1'}]
329
 
        pubkeys = [(x['fingerprint'], x['path'], x['value']) for x in mypklist]
330
 
        data = {'ovfcontent': construct_valid_ovf_env(data=odata,
331
 
                                                      pubkeys=pubkeys)}
332
 
 
333
 
        dsrc = self._get_ds(data)
334
 
        ret = dsrc.get_data()
335
 
        self.assertTrue(ret)
336
 
 
337
 
        for mypk in mypklist:
338
 
            self.assertIn(mypk['value'], dsrc.metadata['public-keys'])
339
 
 
340
 
    def test_default_ephemeral(self):
341
 
        # make sure the ephemeral device works
342
 
        odata = {}
343
 
        data = {'ovfcontent': construct_valid_ovf_env(data=odata),
344
 
                'sys_cfg': {}}
345
 
 
346
 
        dsrc = self._get_ds(data)
347
 
        ret = dsrc.get_data()
348
 
        self.assertTrue(ret)
349
 
        cfg = dsrc.get_config_obj()
350
 
 
351
 
        self.assertEqual(dsrc.device_name_to_device("ephemeral0"),
352
 
                         "/dev/sdb")
353
 
        assert 'disk_setup' in cfg
354
 
        assert 'fs_setup' in cfg
355
 
        self.assertIsInstance(cfg['disk_setup'], dict)
356
 
        self.assertIsInstance(cfg['fs_setup'], list)
357
 
 
358
 
    def test_provide_disk_aliases(self):
359
 
        # Make sure that user can affect disk aliases
360
 
        dscfg = {'disk_aliases': {'ephemeral0': '/dev/sdc'}}
361
 
        odata = {'HostName': "myhost", 'UserName': "myuser",
362
 
                 'dscfg': {'text': b64e(yaml.dump(dscfg)),
363
 
                           'encoding': 'base64'}}
364
 
        usercfg = {'disk_setup': {'/dev/sdc': {'something': '...'},
365
 
                                  'ephemeral0': False}}
366
 
        userdata = '#cloud-config' + yaml.dump(usercfg) + "\n"
367
 
 
368
 
        ovfcontent = construct_valid_ovf_env(data=odata, userdata=userdata)
369
 
        data = {'ovfcontent': ovfcontent, 'sys_cfg': {}}
370
 
 
371
 
        dsrc = self._get_ds(data)
372
 
        ret = dsrc.get_data()
373
 
        self.assertTrue(ret)
374
 
        cfg = dsrc.get_config_obj()
375
 
        self.assertTrue(cfg)
376
 
 
377
 
    def test_userdata_arrives(self):
378
 
        userdata = "This is my user-data"
379
 
        xml = construct_valid_ovf_env(data={}, userdata=userdata)
380
 
        data = {'ovfcontent': xml}
381
 
        dsrc = self._get_ds(data)
382
 
        dsrc.get_data()
383
 
 
384
 
        self.assertEqual(userdata.encode('us-ascii'), dsrc.userdata_raw)
385
 
 
386
 
    def test_password_redacted_in_ovf(self):
387
 
        odata = {'HostName': "myhost", 'UserName': "myuser",
388
 
                 'UserPassword': "mypass"}
389
 
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}
390
 
        dsrc = self._get_ds(data)
391
 
        ret = dsrc.get_data()
392
 
 
393
 
        self.assertTrue(ret)
394
 
        ovf_env_path = os.path.join(self.waagent_d, 'ovf-env.xml')
395
 
 
396
 
        # The XML should not be same since the user password is redacted
397
 
        on_disk_ovf = load_file(ovf_env_path)
398
 
        self.xml_notequals(data['ovfcontent'], on_disk_ovf)
399
 
 
400
 
        # Make sure that the redacted password on disk is not used by CI
401
 
        self.assertNotEqual(dsrc.cfg.get('password'),
402
 
                            DataSourceAzure.DEF_PASSWD_REDACTION)
403
 
 
404
 
        # Make sure that the password was really encrypted
405
 
        et = ET.fromstring(on_disk_ovf)
406
 
        for elem in et.iter():
407
 
            if 'UserPassword' in elem.tag:
408
 
                self.assertEqual(DataSourceAzure.DEF_PASSWD_REDACTION,
409
 
                                 elem.text)
410
 
 
411
 
    def test_ovf_env_arrives_in_waagent_dir(self):
412
 
        xml = construct_valid_ovf_env(data={}, userdata="FOODATA")
413
 
        dsrc = self._get_ds({'ovfcontent': xml})
414
 
        dsrc.get_data()
415
 
 
416
 
        # 'data_dir' is '/var/lib/waagent' (walinux-agent's state dir)
417
 
        # we expect that the ovf-env.xml file is copied there.
418
 
        ovf_env_path = os.path.join(self.waagent_d, 'ovf-env.xml')
419
 
        self.assertTrue(os.path.exists(ovf_env_path))
420
 
        self.xml_equals(xml, load_file(ovf_env_path))
421
 
 
422
 
    def test_ovf_can_include_unicode(self):
423
 
        xml = construct_valid_ovf_env(data={})
424
 
        xml = u'\ufeff{0}'.format(xml)
425
 
        dsrc = self._get_ds({'ovfcontent': xml})
426
 
        dsrc.get_data()
427
 
 
428
 
    def test_exception_fetching_fabric_data_doesnt_propagate(self):
429
 
        ds = self._get_ds({'ovfcontent': construct_valid_ovf_env()})
430
 
        ds.ds_cfg['agent_command'] = '__builtin__'
431
 
        self.get_metadata_from_fabric.side_effect = Exception
432
 
        self.assertFalse(ds.get_data())
433
 
 
434
 
    def test_fabric_data_included_in_metadata(self):
435
 
        ds = self._get_ds({'ovfcontent': construct_valid_ovf_env()})
436
 
        ds.ds_cfg['agent_command'] = '__builtin__'
437
 
        self.get_metadata_from_fabric.return_value = {'test': 'value'}
438
 
        ret = ds.get_data()
439
 
        self.assertTrue(ret)
440
 
        self.assertEqual('value', ds.metadata['test'])
441
 
 
442
 
    def test_instance_id_from_dmidecode_used(self):
443
 
        ds = self._get_ds({'ovfcontent': construct_valid_ovf_env()})
444
 
        ds.get_data()
445
 
        self.assertEqual(self.instance_id, ds.metadata['instance-id'])
446
 
 
447
 
    def test_instance_id_from_dmidecode_used_for_builtin(self):
448
 
        ds = self._get_ds({'ovfcontent': construct_valid_ovf_env()})
449
 
        ds.ds_cfg['agent_command'] = '__builtin__'
450
 
        ds.get_data()
451
 
        self.assertEqual(self.instance_id, ds.metadata['instance-id'])
452
 
 
453
 
 
454
 
class TestAzureBounce(TestCase):
455
 
 
456
 
    def mock_out_azure_moving_parts(self):
457
 
        self.patches.enter_context(
458
 
            mock.patch.object(DataSourceAzure, 'invoke_agent'))
459
 
        self.patches.enter_context(
460
 
            mock.patch.object(DataSourceAzure, 'wait_for_files'))
461
 
        self.patches.enter_context(
462
 
            mock.patch.object(DataSourceAzure, 'list_possible_azure_ds_devs',
463
 
                              mock.MagicMock(return_value=[])))
464
 
        self.patches.enter_context(
465
 
            mock.patch.object(DataSourceAzure,
466
 
                              'find_fabric_formatted_ephemeral_disk',
467
 
                              mock.MagicMock(return_value=None)))
468
 
        self.patches.enter_context(
469
 
            mock.patch.object(DataSourceAzure,
470
 
                              'find_fabric_formatted_ephemeral_part',
471
 
                              mock.MagicMock(return_value=None)))
472
 
        self.patches.enter_context(
473
 
            mock.patch.object(DataSourceAzure, 'get_metadata_from_fabric',
474
 
                              mock.MagicMock(return_value={})))
475
 
        self.patches.enter_context(
476
 
            mock.patch.object(DataSourceAzure.util, 'read_dmi_data',
477
 
                              mock.MagicMock(return_value='test-instance-id')))
478
 
 
479
 
    def setUp(self):
480
 
        super(TestAzureBounce, self).setUp()
481
 
        self.tmp = tempfile.mkdtemp()
482
 
        self.waagent_d = os.path.join(self.tmp, 'var', 'lib', 'waagent')
483
 
        self.paths = helpers.Paths({'cloud_dir': self.tmp})
484
 
        self.addCleanup(shutil.rmtree, self.tmp)
485
 
        DataSourceAzure.BUILTIN_DS_CONFIG['data_dir'] = self.waagent_d
486
 
        self.patches = ExitStack()
487
 
        self.mock_out_azure_moving_parts()
488
 
        self.get_hostname = self.patches.enter_context(
489
 
            mock.patch.object(DataSourceAzure, 'get_hostname'))
490
 
        self.set_hostname = self.patches.enter_context(
491
 
            mock.patch.object(DataSourceAzure, 'set_hostname'))
492
 
        self.subp = self.patches.enter_context(
493
 
            mock.patch('cloudinit.sources.DataSourceAzure.util.subp'))
494
 
 
495
 
    def tearDown(self):
496
 
        self.patches.close()
497
 
 
498
 
    def _get_ds(self, ovfcontent=None):
499
 
        if ovfcontent is not None:
500
 
            populate_dir(os.path.join(self.paths.seed_dir, "azure"),
501
 
                         {'ovf-env.xml': ovfcontent})
502
 
        return DataSourceAzure.DataSourceAzureNet(
503
 
            {}, distro=None, paths=self.paths)
504
 
 
505
 
    def get_ovf_env_with_dscfg(self, hostname, cfg):
506
 
        odata = {
507
 
            'HostName': hostname,
508
 
            'dscfg': {
509
 
                'text': b64e(yaml.dump(cfg)),
510
 
                'encoding': 'base64'
511
 
            }
512
 
        }
513
 
        return construct_valid_ovf_env(data=odata)
514
 
 
515
 
    def test_disabled_bounce_does_not_change_hostname(self):
516
 
        cfg = {'hostname_bounce': {'policy': 'off'}}
517
 
        self._get_ds(self.get_ovf_env_with_dscfg('test-host', cfg)).get_data()
518
 
        self.assertEqual(0, self.set_hostname.call_count)
519
 
 
520
 
    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
521
 
    def test_disabled_bounce_does_not_perform_bounce(
522
 
            self, perform_hostname_bounce):
523
 
        cfg = {'hostname_bounce': {'policy': 'off'}}
524
 
        self._get_ds(self.get_ovf_env_with_dscfg('test-host', cfg)).get_data()
525
 
        self.assertEqual(0, perform_hostname_bounce.call_count)
526
 
 
527
 
    def test_same_hostname_does_not_change_hostname(self):
528
 
        host_name = 'unchanged-host-name'
529
 
        self.get_hostname.return_value = host_name
530
 
        cfg = {'hostname_bounce': {'policy': 'yes'}}
531
 
        self._get_ds(self.get_ovf_env_with_dscfg(host_name, cfg)).get_data()
532
 
        self.assertEqual(0, self.set_hostname.call_count)
533
 
 
534
 
    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
535
 
    def test_unchanged_hostname_does_not_perform_bounce(
536
 
            self, perform_hostname_bounce):
537
 
        host_name = 'unchanged-host-name'
538
 
        self.get_hostname.return_value = host_name
539
 
        cfg = {'hostname_bounce': {'policy': 'yes'}}
540
 
        self._get_ds(self.get_ovf_env_with_dscfg(host_name, cfg)).get_data()
541
 
        self.assertEqual(0, perform_hostname_bounce.call_count)
542
 
 
543
 
    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
544
 
    def test_force_performs_bounce_regardless(self, perform_hostname_bounce):
545
 
        host_name = 'unchanged-host-name'
546
 
        self.get_hostname.return_value = host_name
547
 
        cfg = {'hostname_bounce': {'policy': 'force'}}
548
 
        self._get_ds(self.get_ovf_env_with_dscfg(host_name, cfg)).get_data()
549
 
        self.assertEqual(1, perform_hostname_bounce.call_count)
550
 
 
551
 
    def test_different_hostnames_sets_hostname(self):
552
 
        expected_hostname = 'azure-expected-host-name'
553
 
        self.get_hostname.return_value = 'default-host-name'
554
 
        self._get_ds(
555
 
            self.get_ovf_env_with_dscfg(expected_hostname, {})).get_data()
556
 
        self.assertEqual(expected_hostname,
557
 
                         self.set_hostname.call_args_list[0][0][0])
558
 
 
559
 
    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
560
 
    def test_different_hostnames_performs_bounce(
561
 
            self, perform_hostname_bounce):
562
 
        expected_hostname = 'azure-expected-host-name'
563
 
        self.get_hostname.return_value = 'default-host-name'
564
 
        self._get_ds(
565
 
            self.get_ovf_env_with_dscfg(expected_hostname, {})).get_data()
566
 
        self.assertEqual(1, perform_hostname_bounce.call_count)
567
 
 
568
 
    def test_different_hostnames_sets_hostname_back(self):
569
 
        initial_host_name = 'default-host-name'
570
 
        self.get_hostname.return_value = initial_host_name
571
 
        self._get_ds(
572
 
            self.get_ovf_env_with_dscfg('some-host-name', {})).get_data()
573
 
        self.assertEqual(initial_host_name,
574
 
                         self.set_hostname.call_args_list[-1][0][0])
575
 
 
576
 
    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
577
 
    def test_failure_in_bounce_still_resets_host_name(
578
 
            self, perform_hostname_bounce):
579
 
        perform_hostname_bounce.side_effect = Exception
580
 
        initial_host_name = 'default-host-name'
581
 
        self.get_hostname.return_value = initial_host_name
582
 
        self._get_ds(
583
 
            self.get_ovf_env_with_dscfg('some-host-name', {})).get_data()
584
 
        self.assertEqual(initial_host_name,
585
 
                         self.set_hostname.call_args_list[-1][0][0])
586
 
 
587
 
    def test_environment_correct_for_bounce_command(self):
588
 
        interface = 'int0'
589
 
        hostname = 'my-new-host'
590
 
        old_hostname = 'my-old-host'
591
 
        self.get_hostname.return_value = old_hostname
592
 
        cfg = {'hostname_bounce': {'interface': interface, 'policy': 'force'}}
593
 
        data = self.get_ovf_env_with_dscfg(hostname, cfg)
594
 
        self._get_ds(data).get_data()
595
 
        self.assertEqual(1, self.subp.call_count)
596
 
        bounce_env = self.subp.call_args[1]['env']
597
 
        self.assertEqual(interface, bounce_env['interface'])
598
 
        self.assertEqual(hostname, bounce_env['hostname'])
599
 
        self.assertEqual(old_hostname, bounce_env['old_hostname'])
600
 
 
601
 
    def test_default_bounce_command_used_by_default(self):
602
 
        cmd = 'default-bounce-command'
603
 
        DataSourceAzure.BUILTIN_DS_CONFIG['hostname_bounce']['command'] = cmd
604
 
        cfg = {'hostname_bounce': {'policy': 'force'}}
605
 
        data = self.get_ovf_env_with_dscfg('some-hostname', cfg)
606
 
        self._get_ds(data).get_data()
607
 
        self.assertEqual(1, self.subp.call_count)
608
 
        bounce_args = self.subp.call_args[1]['args']
609
 
        self.assertEqual(cmd, bounce_args)
610
 
 
611
 
    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
612
 
    def test_set_hostname_option_can_disable_bounce(
613
 
            self, perform_hostname_bounce):
614
 
        cfg = {'set_hostname': False, 'hostname_bounce': {'policy': 'force'}}
615
 
        data = self.get_ovf_env_with_dscfg('some-hostname', cfg)
616
 
        self._get_ds(data).get_data()
617
 
 
618
 
        self.assertEqual(0, perform_hostname_bounce.call_count)
619
 
 
620
 
    def test_set_hostname_option_can_disable_hostname_set(self):
621
 
        cfg = {'set_hostname': False, 'hostname_bounce': {'policy': 'force'}}
622
 
        data = self.get_ovf_env_with_dscfg('some-hostname', cfg)
623
 
        self._get_ds(data).get_data()
624
 
 
625
 
        self.assertEqual(0, self.set_hostname.call_count)
626
 
 
627
 
 
628
 
class TestReadAzureOvf(TestCase):
629
 
    def test_invalid_xml_raises_non_azure_ds(self):
630
 
        invalid_xml = "<foo>" + construct_valid_ovf_env(data={})
631
 
        self.assertRaises(DataSourceAzure.BrokenAzureDataSource,
632
 
                          DataSourceAzure.read_azure_ovf, invalid_xml)
633
 
 
634
 
    def test_load_with_pubkeys(self):
635
 
        mypklist = [{'fingerprint': 'fp1', 'path': 'path1', 'value': ''}]
636
 
        pubkeys = [(x['fingerprint'], x['path'], x['value']) for x in mypklist]
637
 
        content = construct_valid_ovf_env(pubkeys=pubkeys)
638
 
        (_md, _ud, cfg) = DataSourceAzure.read_azure_ovf(content)
639
 
        for mypk in mypklist:
640
 
            self.assertIn(mypk, cfg['_pubkeys'])