1
# Copyright 2014 Canonical Ltd. This software is licensed under the
2
# GNU Affero General Public License version 3 (see the file LICENSE).
4
"""Helper functions for writing Juju charms in Python."""
23
'wait_for_page_contents',
28
from collections import namedtuple
31
from shelltoolbox import (
40
from subprocess import CalledProcessError
44
# Immutable record with three fields: uid, gid and home.
Env = namedtuple('Env', ['uid', 'gid', 'home'])
45
# We create a juju_status Command here because it makes testing much,
47
def juju_status():
    """Run ``juju status`` and return the command's result.

    A fresh ``command('juju')`` wrapper is built on every call. Callers in
    this module parse the returned text with ``yaml.safe_load``.
    """
    juju = command('juju')
    return juju('status')
50
def log(message, juju_log=command('juju-log')):
    """Send *message* to the ``juju-log`` hook tool.

    :param message: the text to log.
    :param juju_log: callable that runs ``juju-log``. The default is a
        shelltoolbox command created once, when this function is defined;
        callers may pass their own callable instead.
    :return: whatever *juju_log* returns.
    """
    # '--' is placed ahead of the message, presumably to end option parsing
    # so that a message starting with '-' is not taken as a flag.
    argv = ['--', message]
    return juju_log(*argv)
55
log("--> Entering {}".format(script_name()))
59
log("<-- Exiting {}".format(script_name()))
63
_config_get = command('config-get', '--format=json')
64
return json.loads(_config_get())
67
def relation_get(attribute=None, unit=None, rid=None):
68
cmd = command('relation-get')
69
if attribute is None and unit is None and rid is None:
75
if attribute is not None:
76
_args.append(attribute)
79
return cmd(*_args).strip()
82
def relation_set(**kwargs):
83
cmd = command('relation-set')
84
args = ['{}={}'.format(k, v) for k, v in kwargs.items()]
88
def relation_ids(relation_name):
    """Return the ids reported by ``relation-ids`` for *relation_name*.

    :param relation_name: name of the relation to query.
    :return: the tool's output split on whitespace, as a list of strings.
    """
    relation_ids_tool = command('relation-ids')
    output = relation_ids_tool(relation_name)
    return output.split()
94
def relation_list(rid=None):
95
cmd = command('relation-list')
100
return cmd(*args).split()
103
def config_get(attribute):
104
cmd = command('config-get')
106
return cmd(*args).strip()
109
def unit_get(attribute):
110
cmd = command('unit-get')
112
return cmd(*args).strip()
115
def open_port(port, protocol="TCP"):
116
cmd = command('open-port')
117
args = ['{}/{}'.format(port, protocol)]
121
def close_port(port, protocol="TCP"):
122
cmd = command('close-port')
123
args = ['{}/{}'.format(port, protocol)]
132
def service_control(service_name, action):
133
cmd = command('service')
134
args = [service_name, action]
136
if action == RESTART:
139
except CalledProcessError:
140
service_control(service_name, START)
143
except CalledProcessError:
144
log("Failed to perform {} on service {}".format(action, service_name))
147
def configure_source(update=False):
148
source = config_get('source')
149
if (source.startswith('ppa:') or
150
source.startswith('cloud:') or
151
source.startswith('http:')):
152
run('add-apt-repository', source)
153
if source.startswith("http:"):
154
run('apt-key', 'import', config_get('key'))
156
run('apt-get', 'update')
159
def make_charm_config_file(charm_config):
    """Serialise *charm_config* as YAML into a named temporary file.

    :param charm_config: object to dump with ``yaml.dump``.
    :return: the open, already-flushed ``tempfile.NamedTemporaryFile``
        holding the YAML document.
    """
    charm_config_file = tempfile.NamedTemporaryFile()
    # NamedTemporaryFile() opens in binary mode ('w+b'). With an explicit
    # encoding, yaml.dump returns bytes on both Python 2 (where 'utf-8' is
    # already its default) and Python 3. Without it, Python 3's yaml.dump
    # returns str and this write raises TypeError.
    charm_config_file.write(yaml.dump(charm_config, encoding='utf-8'))
    charm_config_file.flush()
    # The NamedTemporaryFile instance is returned instead of just the name
    # because we want to take advantage of garbage collection-triggered
    # deletion of the temp file when it goes out of scope in the caller.
    return charm_config_file
169
def unit_info(service_name, item_name, data=None, unit=None):
171
data = yaml.safe_load(juju_status())
172
service = data['services'].get(service_name)
174
# XXX 2012-02-08 gmb:
175
# This allows us to cope with the race condition that we
176
# have between deploying a service and having it come up in
177
# `juju status`. We could probably do with cleaning it up so
178
# that it fails a bit more noisily after a while.
180
units = service['units']
182
item = units[unit][item_name]
184
# It might seem odd to sort the units here, but we do it to
185
# ensure that when no unit is specified, the first unit for the
186
# service (or at least the one with the lowest number) is the
187
# one whose data gets returned.
188
sorted_unit_names = sorted(units.keys())
189
item = units[sorted_unit_names[0]][item_name]
193
def get_machine_data():
    """Return the ``machines`` entry of the parsed ``juju status`` output.

    The status text is parsed with ``yaml.safe_load``. A ``KeyError`` is
    raised if the parsed document has no ``machines`` key.
    """
    status = yaml.safe_load(juju_status())
    machines = status['machines']
    return machines
197
def wait_for_machine(num_machines=1, timeout=300):
198
"""Wait `timeout` seconds for `num_machines` machines to come up.
200
This wait_for... function can be called by other wait_for functions
201
whose timeouts might be too short in situations where only a bare
202
Juju setup has been bootstrapped.
204
:return: A tuple of (num_machines, time_taken). This is used for
207
# You may think this is a hack, and you'd be right. The easiest way
208
# to tell what environment we're working in (LXC vs EC2) is to check
209
# the dns-name of the first machine. If it's localhost we're in LXC
210
# and we can just return here.
211
if get_machine_data()[0]['dns-name'] == 'localhost':
213
start_time = time.time()
215
# Drop the first machine, since it's the Zookeeper and that's
216
# not a machine that we need to wait for. This will only work
217
# for EC2 environments, which is why we return early above if
219
machine_data = get_machine_data()
220
non_zookeeper_machines = [
221
machine_data[key] for key in machine_data.keys()[1:]]
222
if len(non_zookeeper_machines) >= num_machines:
223
all_machines_running = True
224
for machine in non_zookeeper_machines:
225
if machine.get('instance-state') != 'running':
226
all_machines_running = False
228
if all_machines_running:
230
if time.time() - start_time >= timeout:
231
raise RuntimeError('timeout waiting for service to start')
232
time.sleep(SLEEP_AMOUNT)
233
return num_machines, time.time() - start_time
236
def wait_for_unit(service_name, timeout=480):
237
"""Wait `timeout` seconds for a given service name to come up."""
238
wait_for_machine(num_machines=1)
239
start_time = time.time()
241
state = unit_info(service_name, 'agent-state')
242
if 'error' in state or state == 'started':
244
if time.time() - start_time >= timeout:
245
raise RuntimeError('timeout waiting for service to start')
246
time.sleep(SLEEP_AMOUNT)
247
if state != 'started':
248
raise RuntimeError('unit did not start, agent-state: ' + state)
251
def wait_for_relation(service_name, relation_name, timeout=120):
252
"""Wait `timeout` seconds for a given relation to come up."""
253
start_time = time.time()
255
relation = unit_info(service_name, 'relations').get(relation_name)
256
if relation is not None and relation['state'] == 'up':
258
if time.time() - start_time >= timeout:
259
raise RuntimeError('timeout waiting for relation to be up')
260
time.sleep(SLEEP_AMOUNT)
263
def wait_for_page_contents(url, contents, timeout=120, validate=None):
265
validate = operator.contains
266
start_time = time.time()
269
stream = urllib2.urlopen(url)
270
except (urllib2.HTTPError, urllib2.URLError):
274
if validate(page, contents):
276
if time.time() - start_time >= timeout:
277
raise RuntimeError('timeout waiting for contents of ' + url)
278
time.sleep(SLEEP_AMOUNT)