10
from twisted.test.proto_helpers import FakeDatagramTransport
11
from twisted.internet.defer import succeed, fail
12
from twisted.internet.error import DNSLookupError
14
from landscape.log import format_object
class InvalidID(Exception):
    """Raised when an invalid ID is used with reactor.cancel_call()."""
class CallHookError(Exception):
    """Raised when hooking on a reactor incorrectly."""
class EventID(object):
    """Handle returned by C{call_on}; pass it to C{cancel_call} to unregister.

    @param event_type: The event type the handler was registered for.
    @param pair: The C{(handler, priority)} tuple stored in the handler list
        for C{event_type}.
    """

    def __init__(self, event_type, pair):
        self._event_type = event_type
        # cancel_call() removes exactly this tuple from the handler list,
        # so it must be kept on the ID.
        self._pair = pair
class EventHandlingReactorMixin(object):
    """Mixin providing named-event registration and dispatch.

    Handlers are registered with L{call_on}, invoked by L{fire}, and
    unregistered by passing the L{EventID} returned from L{call_on} to
    L{cancel_call}.
    """

    def __init__(self):
        super(EventHandlingReactorMixin, self).__init__()
        # Maps event type -> list of (handler, priority) pairs, kept sorted
        # by ascending priority.
        self._event_handlers = {}

    def call_on(self, event_type, handler, priority=0):
        """Register C{handler} to be called whenever C{event_type} fires.

        @param event_type: The event to listen for.
        @param handler: Callable invoked with the arguments given to L{fire}.
        @param priority: Lower values run first; handlers with equal
            priority run in registration order.
        @return: An L{EventID} that can be passed to L{cancel_call}.
        """
        pair = (handler, priority)

        handlers = self._event_handlers.setdefault(event_type, [])
        handlers.append(pair)
        # list.sort is stable, so equal priorities keep registration order.
        handlers.sort(key=lambda entry: entry[1])

        return EventID(event_type, pair)

    def fire(self, event_type, *args, **kwargs):
        """Call every handler registered for C{event_type}.

        An exception raised by a handler is logged and does not stop the
        remaining handlers. A L{KeyboardInterrupt} is logged, the reactor is
        stopped and the interrupt is re-raised.

        @return: A list with the return value of every handler that completed
            without raising, in the order the handlers were called.
        """
        logging.debug("Started firing %s.", event_type)
        results = []
        # Iterate over a snapshot: a handler may call call_on() or
        # cancel_call() for this same event while it is being fired.
        handlers = list(self._event_handlers.get(event_type, ()))
        for handler, priority in handlers:
            try:
                logging.debug("Calling %s for %s with priority %d.",
                              format_object(handler), event_type, priority)
                results.append(handler(*args, **kwargs))
            except KeyboardInterrupt:
                logging.exception("Keyboard interrupt while running event "
                                  "handler %s for event type %r with "
                                  "args %r %r.", format_object(handler),
                                  event_type, args, kwargs)
                # NOTE(review): assumes the concrete reactor defines stop(),
                # as Reactor, FakeReactor and TwistedReactor appear to.
                self.stop()
                raise
            except Exception:
                logging.exception("Error running event handler %s for "
                                  "event type %r with args %r %r.",
                                  format_object(handler), event_type,
                                  args, kwargs)
        logging.debug("Finished firing %s.", event_type)
        return results

    def cancel_call(self, id):
        """Unregister the handler identified by C{id}.

        @param id: An L{EventID} previously returned by L{call_on}.
        @raise InvalidID: If C{id} is not an L{EventID}.
        @raise ValueError: If that handler has already been unregistered.
        """
        if isinstance(id, EventID):
            self._event_handlers[id._event_type].remove(id._pair)
        else:
            raise InvalidID("EventID instance expected, received %r" % id)
class ThreadedCallsReactorMixin(object):
    """Mixin for running work in threads and handing results back.

    Callbacks queued with L{call_in_main} are run by
    L{_run_threaded_callbacks}. L{_hook_threaded_callbacks} schedules that
    method every 0.5 seconds through the concrete reactor's C{call_every}.
    """

    def __init__(self):
        super(ThreadedCallsReactorMixin, self).__init__()
        # Zero-argument callables waiting to be run by
        # _run_threaded_callbacks(), in FIFO order.
        self._threaded_callbacks = []

    def call_in_main(self, f, *args, **kwargs):
        """Queue C{f(*args, **kwargs)} for the next callback pass."""
        self._threaded_callbacks.append(lambda: f(*args, **kwargs))

    def call_in_thread(self, callback, errback, f, *args, **kwargs):
        """Run C{f(*args, **kwargs)} in a new thread.

        On success, C{callback(result)} is queued with L{call_in_main}. On
        failure, C{errback(type, value, traceback)} is queued instead, or
        the error is logged if C{errback} is C{None}.
        """
        thread.start_new_thread(self._in_thread,
                                (callback, errback, f, args, kwargs))

    def _in_thread(self, callback, errback, f, args, kwargs):
        """Thread body for L{call_in_thread}; results go via L{call_in_main}."""
        try:
            result = f(*args, **kwargs)
        except Exception:
            # sys.exc_info() keeps this valid on every Python version
            # (no "except X, e" / "except X as e" syntax needed).
            exc_info = sys.exc_info()
            if errback is None:
                self.call_in_main(logging.error, exc_info[1],
                                  exc_info=exc_info)
            else:
                self.call_in_main(errback, *exc_info)
        else:
            if callback is not None:
                self.call_in_main(callback, result)

    def _run_threaded_callbacks(self):
        """Run queued callbacks in FIFO order until the queue is empty."""
        while self._threaded_callbacks:
            try:
                self._threaded_callbacks.pop(0)()
            except Exception:
                # One failing callback must not stop the rest of the queue
                # from being drained.
                logging.exception("Error running threaded callback.")

    def _hook_threaded_callbacks(self):
        """Start polling the callback queue every 0.5 seconds."""
        id = self.call_every(0.5, self._run_threaded_callbacks)
        self._run_threaded_callbacks_id = id

    def _unhook_threaded_callbacks(self):
        """Stop the polling started by L{_hook_threaded_callbacks}."""
        self.cancel_call(self._run_threaded_callbacks_id)
class ReactorID(object):
    """Handle returned by L{Reactor.call_later} and L{Reactor.call_every}.

    @param timeout: The C{gobject.Timeout} source that L{Reactor.cancel_call}
        destroys to cancel the call.
    """

    def __init__(self, timeout):
        self._timeout = timeout
124
# NOTE(review): this copy of Reactor is garbled. Indentation is gone,
# original line numbers are interleaved as bare integers, and lines are
# missing: the __init__ header, the nested `fake_function` definitions and
# their return values, the `else:` in cancel_call, and the headers of the
# methods that run and stop the main loop. Restore from version control
# before changing any logic here.
# Reactor: gobject-based reactor whose MainLoop runs on this instance's
# own MainContext.
class Reactor(EventHandlingReactorMixin,
125
ThreadedCallsReactorMixin):
128
super(Reactor, self).__init__()
129
self._context = gobject.MainContext()
130
self._mainloop = gobject.MainLoop(context=self._context)
132
# call_later: intended to run function(*args, **kwargs) once, `timeout`
# seconds from now. It uses a gobject.Timeout (interval truncated to whole
# milliseconds) attached to this reactor's context. `timeout` is rebound
# from the delay in seconds to the Timeout source object.
# NOTE(review): for a one-shot timer the (not visible) fake_function must
# return a false value so GLib removes the source -- confirm.
def call_later(self, timeout, function, *args, **kwargs):
134
function(*args, **kwargs)
136
timeout = gobject.Timeout(int(timeout*1000))
137
timeout.set_callback(fake_function)
138
timeout.attach(self._context)
139
return ReactorID(timeout)
141
# cancel_call: destroys the Timeout source behind a ReactorID. Other IDs
# presumably go to EventHandlingReactorMixin.cancel_call via an `else:` on
# original line 144, which is not visible here.
def cancel_call(self, id):
142
if type(id) is ReactorID:
143
id._timeout.destroy()
145
super(Reactor, self).cancel_call(id)
147
# call_every: same setup as call_later.
# NOTE(review): for the call to repeat, the (not visible) fake_function
# must return a true value so GLib keeps the source -- confirm.
def call_every(self, timeout, function, *args, **kwargs):
149
function(*args, **kwargs)
151
timeout = gobject.Timeout(int(timeout*1000))
152
timeout.set_callback(fake_function)
153
timeout.attach(self._context)
154
return ReactorID(timeout)
158
# These lines belong to a method whose header is not visible (presumably
# the one that runs the main loop). Threaded-callback polling is hooked
# before original line 159 (not visible, presumably the loop call) and
# unhooked after it.
self._hook_threaded_callbacks()
160
self._unhook_threaded_callbacks()
164
# Quits the gobject main loop. The enclosing method header (presumably
# stop()) is not visible here.
self._mainloop.quit()
class FakeReactorID(object):
    """Handle for a call scheduled on a L{FakeReactor}.

    @param data: The C{(scheduled_time, f, args, kwargs)} tuple stored in
        the fake reactor's list of pending calls.
    """

    def __init__(self, data):
        # FakeReactor.cancel_call looks this tuple up in its pending calls,
        # and FakeReactor.call_every rebinds it after each reschedule.
        self._data = data
174
# NOTE(review): this copy of FakeReactor is garbled. Indentation is gone,
# original line numbers are interleaved as bare integers, and the numbering
# skips many lines: the class docstring delimiters, several method headers,
# and parts of __init__, cancel_call, call_every, advance and resolve.
# Restore it from version control before changing any logic here.
# FakeReactor: in-memory reactor for unit tests. Time is advanced by
# advance(), and pending calls are kept in self._calls.
class FakeReactor(EventHandlingReactorMixin,
175
ThreadedCallsReactorMixin):
177
@ivar udp_transports: dict of {port: (protocol, transport)}
178
@ivar hosts: Dict of {hostname: ip}. Users should populate this
179
and L{resolve} will use it.
182
# __init__ body (its header is not visible): the fake clock starts at 0 and
# no UDP transports are registered. self._calls and self.hosts are used
# below, but their initialisation is not visible in this copy.
super(FakeReactor, self).__init__()
183
self._current_time = 0
185
self.udp_transports = {}
189
# Returns the fake clock as a float. The method header is not visible.
return float(self._current_time)
191
# call_later: queue (scheduled_time, f, args, kwargs) in self._calls.
# bisect.insort_left keeps the list sorted, so the earliest call is at
# index 0.
# NOTE(review): when two calls share a scheduled_time, the tuple comparison
# falls through to comparing `f` (then args and kwargs). On Python 3 that
# raises TypeError for functions -- consider a tie-breaking sequence number.
def call_later(self, seconds, f, *args, **kwargs):
192
scheduled_time = self._current_time + seconds
193
call = (scheduled_time, f, args, kwargs)
194
bisect.insort_left(self._calls, call)
195
return FakeReactorID(call)
197
# cancel_call: for a FakeReactorID, drop its pending call if it is still
# queued. Other IDs presumably reach the superclass cancel_call through an
# `else:` on lines not visible here (original 201-202).
def cancel_call(self, id):
198
if type(id) is FakeReactorID:
199
if id._data in self._calls:
200
self._calls.remove(id._data)
203
super(FakeReactor, self).cancel_call(id)
205
# call_every: schedules `fake`. On each run, `fake` re-schedules itself via
# call_later and copies the new call's _data into the same FakeReactorID,
# so one ID keeps cancelling the current occurrence.
# NOTE(review): the definition of `fake`, the call to f and the return
# statement are on lines not visible in this copy (206, 211-214, 216, 218).
def call_every(self, seconds, f, *args, **kwargs):
207
# update the call so that cancellation will continue
208
# working with the same ID. And do it *before* the call
209
# because the call might cancel it!
210
call._data = self.call_later(seconds, fake)._data
215
self.cancel_call(call)
217
call = self.call_later(seconds, fake)
220
# call_in_thread: runs f synchronously (no thread is started), then
# immediately drains the queued threaded callbacks.
def call_in_thread(self, callback, errback, f, *args, **kwargs):
221
self._in_thread(callback, errback, f, args, kwargs)
223
# Running threaded callbacks here doesn't reflect reality, since
224
# they're usually run while the main reactor loop is active.
225
# At the same time, this is convenient as it means we don't need
226
# to run the reactor with all registered handlers to test for
227
# actions performed on completion of specific events (e.g. firing
228
# exchange will fire exchange-done when ready). IOW, it's easier
229
# to test things synchronously.
230
self._run_threaded_callbacks()
232
# advance: pop and run every call scheduled at or before
# current_time + seconds. The clock is moved to each call's time before the
# call runs, then moved by whatever remains of `seconds`.
# NOTE(review): further down, advance(self._calls[0][0]) passes an absolute
# scheduled time where advance() expects a relative number of seconds, so
# it overshoots once _current_time > 0. Passing
# self._calls[0][0] - self._current_time looks intended -- confirm.
def advance(self, seconds):
233
"""Advance this reactor C{seconds} into the future.
235
This is the preferred method for advancing time in your unit tests.
237
while (self._calls and self._calls[0][0]
238
<= self._current_time + seconds):
239
call = self._calls.pop(0)
240
# If we find a call within the time we're advancing,
241
# before calling it, let's advance the time *just* to
242
# when that call is expecting to be run, so that if it
243
# schedules any calls itself they will be relative to
245
seconds -= call[0] - self._current_time
246
self._current_time = call[0]
248
call[1](*call[2], **call[3])
251
self._current_time += seconds
254
"""Continuously advance this reactor until reactor.stop() is called."""
258
self.advance(self._calls[0][0])
262
self._running = False
264
def listen_udp(self, port, protocol):
266
Connect the given protocol with a fake transport, and keep the
267
transport in C{self.udp_transports}.
269
transport = FakeDatagramTransport()
270
self.udp_transports[port] = (protocol, transport)
271
protocol.makeConnection(transport)
274
def resolve(self, hostname):
275
"""Look up the hostname in C{self.hosts}.
277
@return: A Deferred resulting in the IP address.
280
# is it an IP address?
281
socket.inet_aton(hostname)
282
except socket.error: # no
283
if hostname in self.hosts:
284
return succeed(self.hosts[hostname])
286
return fail(DNSLookupError(hostname))
288
return succeed(hostname)
292
# NOTE(review): this copy of TwistedReactor is garbled. Indentation is
# gone, original line numbers are interleaved as bare integers, and lines
# are missing: the constructor's `def` line, the body of the delayed-call
# cleanup loop, call_every's return, most of cancel_call, and the headers
# of the methods around reactor.crash(). The class may also continue past
# the end of this chunk. Restore from version control before changing
# any logic here.
# TwistedReactor: exposes the global twisted.internet.reactor through the
# same call_later/call_every/cancel_call API as the other reactors here.
class TwistedReactor(EventHandlingReactorMixin,
293
ThreadedCallsReactorMixin):
296
# Twisted is imported inside the constructor rather than at module level,
# presumably so that importing this module does not install the global
# reactor -- confirm before moving these imports.
from twisted.internet import reactor
297
from twisted.internet.task import LoopingCall
298
self._LoopingCall = LoopingCall
299
self._reactor = reactor
302
super(TwistedReactor, self).__init__()
305
# Since the reactor is global, we should clean it up when we
306
# initialize one of our wrappers.
307
# Walks the global reactor's pending delayed calls. What is done with each
# one (original lines 308-309) is not visible in this copy.
for call in self._reactor.getDelayedCalls():
311
# call_later: direct pass-through; returns whatever reactor.callLater
# returns.
def call_later(self, *args, **kwargs):
312
return self._reactor.callLater(*args, **kwargs)
314
# call_every: wraps f in a LoopingCall started with now=False, so the
# first call happens after `seconds`, not immediately.
# NOTE(review): the return statement (original line 317, presumably the
# LoopingCall that cancel_call checks for) is not visible here.
def call_every(self, seconds, f, *args, **kwargs):
315
lc = self._LoopingCall(f, *args, **kwargs)
316
lc.start(seconds, now=False)
319
# cancel_call: EventIDs go to EventHandlingReactorMixin.cancel_call.
# LoopingCall instances (and presumably delayed calls) are handled on
# lines not visible in this copy.
def cancel_call(self, id):
320
if isinstance(id, EventID):
321
return EventHandlingReactorMixin.cancel_call(self, id)
322
if isinstance(id, self._LoopingCall):
327
# call_in_main: hands f to reactor.callFromThread instead of queueing it
# on self._threaded_callbacks as ThreadedCallsReactorMixin does.
def call_in_main(self, f, *args, **kwargs):
328
self._reactor.callFromThread(f, *args, **kwargs)
336
# reactor.crash() stops Twisted's main loop immediately, without firing
# shutdown triggers. The enclosing method header is not visible here.
self._reactor.crash()
343
# listen_udp: delegates to reactor.listenUDP and returns its result.
def listen_udp(self, port, protocol):
344
"""Connect the given protocol with a UDP transport.
346
See L{twisted.internet.interfaces.IReactorUDP.listenUDP}.
348
return self._reactor.listenUDP(port, protocol)
350
def resolve(self, host):
351
"""Look up the IP of the given host.
353
See L{twisted.internet.interfaces.IReactorCore.resolve}.
355
@return: A Deferred resulting in the hostname.
357
return self._reactor.resolve(host)