~cjwatson/rabbitfixture/rabbitmq-server-3.8.10

« back to all changes in this revision

Viewing changes to rabbitfixture/server.py

  • Committer: Gavin Panella
  • Date: 2013-09-16 16:45:09 UTC
  • mfrom: (29.1.5 untangle-host)
  • Revision ID: gavin.panella@canonical.com-20130916164509-hbs81wfxdd3s0k37
[r=rvba][bug=1225980] Fix port reuse issues when restarting the fixture.

Previously RabbitMQ was started on all addresses, the fq_hostname
property was incorrectly using gethostname() regardless of the
hostname given, and allocate_ports() only ever checked ports on
localhost.

Show diffs side-by-side

added added

removed removed

Lines of Context:
41
41
    signal.signal(signal.SIGPIPE, signal.SIG_DFL)
42
42
 
43
43
 
44
 
def allocate_ports(n=1):
45
 
    """Allocate `n` unused ports.
 
44
def get_port(socket):
 
45
    """Return the port to which a socket is bound."""
 
46
    addr, port = socket.getsockname()
 
47
    return port
 
48
 
 
49
 
 
50
def allocate_ports(*addrs):
 
51
    """Allocate `len(addrs)` unused ports.
 
52
 
 
53
    A port is allocated for each element in `addrs`.
46
54
 
47
55
    There is a small race condition here (between the time we allocate the
48
56
    port, and the time it actually gets used), but for the purposes for which
49
57
    this function gets used it isn't a problem in practice.
50
58
    """
51
 
    sockets = map(lambda _: socket.socket(), xrange(n))
 
59
    sockets = [socket.socket() for addr in addrs]
52
60
    try:
53
 
        for s in sockets:
54
 
            s.bind(('localhost', 0))
55
 
        return map(lambda s: s.getsockname()[1], sockets)
 
61
        for addr, sock in zip(addrs, sockets):
 
62
            sock.bind((addr, 0))
 
63
        return [get_port(sock) for sock in sockets]
56
64
    finally:
57
 
        for s in sockets:
58
 
            s.close()
 
65
        for sock in sockets:
 
66
            sock.close()
59
67
 
60
68
 
61
69
# Pattern to parse rabbitctl status output to find the nodename of a running
107
115
        if self.hostname is None:
108
116
            self.hostname = 'localhost'
109
117
        if self.port is None:
110
 
            [self.port] = allocate_ports(1)
 
118
            [self.port] = allocate_ports(self.hostname)
111
119
        if self.homedir is None:
112
120
            self.homedir = self.useFixture(TempDir()).path
113
121
        if self.mnesiadir is None:
121
129
    @property
122
130
    def fq_nodename(self):
123
131
        """The node of the RabbitMQ that is being exported."""
124
 
        # Note that socket.gethostname is recommended by the rabbitctl manpage
125
 
        # even though we're always on localhost, its what the erlang cluster
126
 
        # code wants.
127
 
        return "%s@%s" % (self.nodename, socket.gethostname())
 
132
        return "%s@%s" % (self.nodename, self.hostname)
128
133
 
129
134
 
130
135
class RabbitServerEnvironment(Fixture):
134
139
 
135
140
    - ``RABBITMQ_MNESIA_BASE``
136
141
    - ``RABBITMQ_LOG_BASE``
 
142
    - ``RABBITMQ_NODE_IP_ADDRESS``
137
143
    - ``RABBITMQ_NODE_PORT``
138
144
    - ``RABBITMQ_NODENAME``
139
145
    - ``RABBITMQ_PLUGINS_DIR``
156
162
        self.useFixture(EnvironmentVariableFixture(
157
163
            "RABBITMQ_LOG_BASE", self.config.homedir))
158
164
        self.useFixture(EnvironmentVariableFixture(
 
165
            "RABBITMQ_NODE_IP_ADDRESS",
 
166
            socket.gethostbyname(self.config.hostname)))
 
167
        self.useFixture(EnvironmentVariableFixture(
159
168
            "RABBITMQ_NODE_PORT", str(self.config.port)))
160
169
        self.useFixture(EnvironmentVariableFixture(
161
170
            "RABBITMQ_NODENAME", self.config.fq_nodename))
162
171
        self.useFixture(EnvironmentVariableFixture(
163
172
            "RABBITMQ_PLUGINS_DIR", self.config.pluginsdir))
164
173
        self._errors = []
165
 
        self.addDetail('rabbit-errors',
166
 
            Content(UTF8_TEXT, self._get_errors))
 
174
        self.addDetail('rabbit-errors', Content(
 
175
            UTF8_TEXT, self._get_errors))
167
176
 
168
177
    def _get_errors(self):
169
178
        """Yield all errors as UTF-8 encoded text."""