1
# Copyright 2012, 2013 Canonical Ltd. This software is licensed under the
2
# GNU Affero General Public License version 3 (see the file LICENSE).
4
"""Test custom commissioning scripts."""
6
from __future__ import (
15
from inspect import getsource
16
from io import BytesIO
22
from random import randint
24
from subprocess import (
30
from textwrap import dedent
33
from maasserver.fields import MAC
34
from maasserver.testing import reload_object
35
from maasserver.testing.factory import factory
36
from maasserver.testing.testcase import MAASServerTestCase
37
from maastesting.matchers import ContainsAll
38
from maastesting.utils import sample_binary_data
39
from metadataserver.fields import Bin
40
from metadataserver.models import (
42
commissioningscript as cs_module,
44
from metadataserver.models.commissioningscript import (
46
extract_router_mac_addresses,
47
make_function_call_script,
51
from testtools.content import text_content
54
def open_tarfile(content):
    """Open a tar archive held entirely in memory.

    :param content: The raw bytes of a tar archive.
    :return: A `tarfile.TarFile` reading from an in-memory buffer that
        wraps `content`.
    """
    return tarfile.open(fileobj=BytesIO(content))
59
def make_script_name(base_name=None, number=None):
    """Make up a name for a commissioning script.

    :param base_name: Prefix passed to `factory.make_name` to produce the
        descriptive part of the name.
    :param number: Number the name starts with, rendered with at least
        two digits.  A random number in the range 0..99 is picked when
        this is None.
    :return: The result of `factory.make_name` for the "NN-name" prefix.
    """
    # NOTE(review): this extract skips original lines 61-63.  Without a
    # guard the caller's `number` was always overwritten; the guard below
    # restores the parameter.  Any default that the skipped lines gave to
    # `base_name` is not visible here, so it is passed through unchanged
    # -- confirm against the full file.
    if number is None:
        number = randint(0, 99)
    return factory.make_name(
        '%0.2d-%s' % (number, factory.make_name(base_name)))
69
class TestCommissioningScriptManager(MAASServerTestCase):
    """Tests for `CommissioningScript.objects.get_archive`."""

    def test_get_archive_wraps_scripts_in_tar(self):
        script = factory.make_commissioning_script()
        path = os.path.join(ARCHIVE_PREFIX, script.name)
        archive = open_tarfile(CommissioningScript.objects.get_archive())
        self.assertTrue(archive.getmember(path).isfile())
        self.assertEqual(script.content, archive.extractfile(path).read())

    def test_get_archive_wraps_all_scripts(self):
        scripts = {factory.make_commissioning_script() for counter in range(3)}
        archive = open_tarfile(CommissioningScript.objects.get_archive())
        # NOTE(review): the extract dropped the lines around the path
        # expression.  This assertion is rebuilt from the otherwise unused
        # `scripts` set and the imported `ContainsAll` matcher -- confirm
        # against the full file.
        expected_paths = {
            os.path.join(ARCHIVE_PREFIX, script.name)
            for script in scripts
            }
        self.assertThat(archive.getnames(), ContainsAll(expected_paths))

    def test_get_archive_supports_binary_scripts(self):
        script = factory.make_commissioning_script(content=sample_binary_data)
        path = os.path.join(ARCHIVE_PREFIX, script.name)
        archive = open_tarfile(CommissioningScript.objects.get_archive())
        self.assertEqual(script.content, archive.extractfile(path).read())

    def test_get_archive_includes_builtin_scripts(self):
        name = factory.make_name('00-maas')
        path = os.path.join(ARCHIVE_PREFIX, name)
        content = factory.getRandomString().encode('ascii')
        data = dict(name=name, content=content, hook='hook')
        # Swap the built-in scripts for a single known one, so the test
        # can look for exactly that entry in the archive.
        self.patch(cs_module, 'BUILTIN_COMMISSIONING_SCRIPTS', {name: data})
        archive = open_tarfile(CommissioningScript.objects.get_archive())
        self.assertIn(path, archive.getnames())
        self.assertEqual(content, archive.extractfile(path).read())

    def test_get_archive_sets_sensible_mode(self):
        for counter in range(3):
            factory.make_commissioning_script()
        archive = open_tarfile(CommissioningScript.objects.get_archive())
        # `0o755` is the octal spelling accepted by Python 2.6+ and 3;
        # the bare `0755` form is a syntax error on Python 3.
        self.assertEqual({0o755}, {info.mode for info in archive.getmembers()})

    def test_get_archive_initializes_file_timestamps(self):
        # The mtime on a file inside the tarball is reasonable.
        # It would otherwise default to the Epoch, and GNU tar warns
        # annoyingly about improbably old files.
        start_time = floor(time.time())
        script = factory.make_commissioning_script()
        path = os.path.join(ARCHIVE_PREFIX, script.name)
        archive = open_tarfile(CommissioningScript.objects.get_archive())
        timestamp = archive.getmember(path).mtime
        end_time = ceil(time.time())
        # The timestamp must fall inside the window in which the archive
        # was generated, rounded outward to whole seconds.
        self.assertGreaterEqual(timestamp, start_time)
        self.assertLessEqual(timestamp, end_time)
124
class TestCommissioningScript(MAASServerTestCase):
    """Tests for the `CommissioningScript` model."""

    def test_scripts_may_be_binary(self):
        # Store binary content wrapped in `Bin`, then read the script back
        # by name and compare against the raw bytes.
        name = make_script_name()
        CommissioningScript.objects.create(
            name=name, content=Bin(sample_binary_data))
        stored_script = CommissioningScript.objects.get(name=name)
        self.assertEqual(sample_binary_data, stored_script.content)
134
class TestMakeFunctionCallScript(MAASServerTestCase):
    """Tests for `make_function_call_script`."""

    def run_script(self, script):
        """Write `script` to an executable file, run it, return its output.

        Standard error is folded into the returned output.  If the script
        exits with an error, its output is attached to the test as a
        detail before the error propagates.
        """
        script_filename = self.make_file("test.py", script)
        # `0o700` is the octal spelling accepted by Python 2.6+ and 3.
        os.chmod(script_filename, 0o700)
        # NOTE(review): the `try:` and the statement that followed
        # `addDetail` were dropped from this extract.  A bare `raise` is
        # assumed, so a failing script still fails the test -- confirm
        # against the full file.
        try:
            return check_output((script_filename,), stderr=STDOUT)
        except CalledProcessError as error:
            self.addDetail("output", text_content(error.output))
            raise

    def test_basic(self):
        def example_function():
            print("Hello, World!", end="")
        script = make_function_call_script(example_function)
        self.assertEqual(b"Hello, World!", self.run_script(script))

    def test_positional_args_get_passed_through(self):
        def example_function(a, b):
            print("a=%s, b=%d" % (a, b), end="")
        script = make_function_call_script(example_function, "foo", 12345)
        self.assertEqual(b"a=foo, b=12345", self.run_script(script))

    def test_keyword_args_get_passed_through(self):
        def example_function(a, b):
            print("a=%s, b=%d" % (a, b), end="")
        script = make_function_call_script(example_function, a="foo", b=12345)
        self.assertEqual(b"a=foo, b=12345", self.run_script(script))

    def test_positional_and_keyword_args_get_passed_through(self):
        def example_function(a, b):
            print("a=%s, b=%d" % (a, b), end="")
        script = make_function_call_script(example_function, "foo", b=12345)
        self.assertEqual(b"a=foo, b=12345", self.run_script(script))

    def test_non_ascii_positional_args_are_passed_without_corruption(self):
        def example_function(text):
            print(repr(text), end="")
        script = make_function_call_script(example_function, "abc\u1234")
        self.assertEqual(b"u'abc\\u1234'", self.run_script(script))

    def test_non_ascii_keyword_args_are_passed_without_corruption(self):
        def example_function(text):
            print(repr(text), end="")
        script = make_function_call_script(example_function, text="abc\u1234")
        self.assertEqual(b"u'abc\\u1234'", self.run_script(script))

    def test_structured_arguments_are_passed_though_too(self):
        # Anything that can be JSON serialized can be passed.
        # NOTE(review): the "Equal" branch and the `else:` were dropped
        # from this extract; they are rebuilt from the expected output
        # asserted below.
        def example_function(arg):
            if arg == {"123": "foo", "bar": [4, 5, 6]}:
                print("Equal")
            else:
                print("Unequal, got %s" % repr(arg))
        script = make_function_call_script(
            example_function, {"123": "foo", "bar": [4, 5, 6]})
        self.assertEqual(b"Equal\n", self.run_script(script))
193
def isolate_function(function):
    """Recompile the given function in an empty namespace.

    The function's source is fetched with `getsource`, dedented (so that
    nested functions and methods compile at top level), and executed in a
    fresh dict.  The copy therefore sees none of the globals of the
    module where `function` was defined.

    :param function: A function whose source code `getsource` can find.
    :return: The function object of the same name from the new namespace.
    """
    source = dedent(getsource(function))
    modcode = compile(source, "lldpd.py", "exec")
    # The line defining `namespace` was missing from this extract; the
    # docstring calls for an empty one.
    namespace = {}
    exec(modcode, namespace)
    return namespace[function.__name__]
202
class TestLLDPScripts(MAASServerTestCase):
    """Tests for the LLDP helper functions in the commissioning module.

    Each helper is recompiled with `isolate_function` before it is
    called.
    """

    def test_install_script_installs_configures_and_restarts(self):
        config_file = self.make_file("config", "# ...")
        check_call = self.patch(subprocess, "check_call")
        lldpd_install = isolate_function(cs_module.lldpd_install)
        lldpd_install(config_file)
        # lldpd is installed and restarted.
        # NOTE(review): the opening of this assertion was dropped from
        # the extract; `assertEqual` is assumed.
        self.assertEqual(
            check_call.call_args_list,
            [call(("apt-get", "install", "--yes", "lldpd")),
             call(("service", "lldpd", "restart"))])
        # lldpd's config was updated to include an updated DAEMON_ARGS
        # setting. Note that the new comment is on a new line, and
        # does not interfere with existing config.
        # NOTE(review): the first line and the terminator of this literal
        # were dropped from the extract.  The first line is assumed to be
        # the config's original "# ..." content.  The expectation is
        # encoded because the file is read back in binary mode.
        config_expected = dedent("""\
            # ...
            # Configured by MAAS:
            DAEMON_ARGS="-c -f -s -e -r"
            """).encode("ascii")
        with open(config_file, "rb") as fd:
            config_observed = fd.read()
        self.assertEqual(config_expected, config_observed)

    def test_wait_script_waits_for_lldpd(self):
        self.patch(os.path, "getmtime").return_value = 10.65
        self.patch(time, "time").return_value = 14.12
        self.patch(time, "sleep")
        reference_file = self.make_file("reference")
        time_delay = 8.98  # seconds
        lldpd_wait = isolate_function(cs_module.lldpd_wait)
        lldpd_wait(reference_file, time_delay)
        # lldpd_wait checks the mtime of the reference file,
        os.path.getmtime.assert_called_once_with(reference_file)
        # and gets the current time,
        time.time.assert_called_once_with()
        # then sleeps until time_delay seconds has passed since the
        # mtime of the reference file.
        time.sleep.assert_called_once_with(
            os.path.getmtime.return_value + time_delay -
            time.time.return_value)

    def test_capture_calls_lldpdctl(self):
        check_call = self.patch(subprocess, "check_call")
        lldpd_capture = isolate_function(cs_module.lldpd_capture)
        # NOTE(review): the call and the opening of the assertion were
        # dropped from the extract; a no-argument call followed by
        # `assertEqual` is assumed.
        lldpd_capture()
        self.assertEqual(
            check_call.call_args_list,
            [call(("lldpctl", "-f", "xml"))])
253
# Skeleton of an `lldpctl -f xml` style document.  The single "%s" is
# filled with zero or more rendered interface sections.
# NOTE(review): the tail of this literal (placeholder, closing tag,
# terminator) was dropped from the extract; it is rebuilt from the way
# `make_lldp_output` formats the template.
lldp_output_template = """
<?xml version="1.0" encoding="UTF-8"?>
<lldp label="LLDP neighbors">
%s
</lldp>
"""

# One interface section.  The single "%s" is the chassis ID (a MAC).
# NOTE(review): the closing tags and terminator were dropped from the
# extract; they are rebuilt from the opening tags.
lldp_output_interface_template = """
<interface label="Interface" name="eth1" via="LLDP">
<chassis label="Chassis">
<id label="ChassisID" type="mac">%s</id>
<name label="SysName">switch-name</name>
<descr label="SysDescr">HDFD5BG7J</descr>
<mgmt-ip label="MgmtIP">192.168.9.9</mgmt-ip>
<capability label="Capability" type="Bridge" enabled="on"/>
<capability label="Capability" type="Router" enabled="off"/>
</chassis>
</interface>
"""


def make_lldp_output(macs):
    """Return an example raw lldp output containing the given MACs.

    :param macs: Iterable of MAC address strings.  One interface section
        is rendered per entry, in order; an empty iterable gives a
        document with no interfaces.
    :return: The document, encoded as UTF-8 bytes.
    """
    interfaces = '\n'.join(
        lldp_output_interface_template % mac
        for mac in macs
        )
    script = (lldp_output_template % interfaces).encode('utf8')
    return script
284
class TestExtractRouters(MAASServerTestCase):
    """Tests for `extract_router_mac_addresses`."""

    def test_extract_router_mac_addresses_returns_None_when_empty_input(self):
        self.assertIsNone(extract_router_mac_addresses(''))

    def test_extract_router_mac_addresses_returns_empty_list(self):
        # A document with no interface sections yields no routers.
        lldp_output = make_lldp_output([])
        self.assertItemsEqual([], extract_router_mac_addresses(lldp_output))

    def test_extract_router_mac_addresses_returns_routers_list(self):
        macs = ["11:22:33:44:55:66", "aa:bb:cc:dd:ee:ff"]
        lldp_output = make_lldp_output(macs)
        routers = extract_router_mac_addresses(lldp_output)
        self.assertItemsEqual(macs, routers)
300
class TestSetNodeRouters(MAASServerTestCase):
    """Tests for `set_node_routers`."""

    def test_set_node_routers_updates_node(self):
        node = factory.make_node(routers=None)
        macs = ["11:22:33:44:55:66", "aa:bb:cc:dd:ee:ff"]
        lldp_output = make_lldp_output(macs)
        set_node_routers(node, lldp_output)
        # Reload the node so the assertion sees what was actually saved.
        self.assertItemsEqual(
            [MAC(mac) for mac in macs], reload_object(node).routers)

    def test_set_node_routers_updates_node_if_no_routers(self):
        node = factory.make_node()
        lldp_output = make_lldp_output([])
        set_node_routers(node, lldp_output)
        self.assertItemsEqual([], reload_object(node).routers)