1
"""Deals with the socket communication between the PIMD and driver code.
3
Copyright (C) 2013, Joshua More and Michele Ceriotti
5
This program is free software: you can redistribute it and/or modify
6
it under the terms of the GNU General Public License as published by
7
the Free Software Foundation, either version 3 of the License, or
8
(at your option) any later version.
10
This program is distributed in the hope that it will be useful,
11
but WITHOUT ANY WARRANTY; without even the implied warranty of
12
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
GNU General Public License for more details.
15
You should have received a copy of the GNU General Public License
16
along with this program. If not, see <http.//www.gnu.org/licenses/>.
19
Deals with creating the socket, transmitting and receiving data, accepting and
20
removing different driver routines and the parallelization of the force
24
Status: Simple class to keep track of the status, uses bitwise or to give
25
combinations of different status options.
26
DriverSocket: Class to deal with communication between a client and
28
InterfaceSocket: Host server class. Deals with distribution of all the jobs
29
between the different client servers.
32
Message: Sends a header string through the socket.
35
Disconnected: Raised if client has been disconnected.
36
InvalidStatus: Raised if client has the wrong status. Shouldn't have to be
37
used if the structure of the program is correct.
40
__all__ = ['InterfaceSocket']
44
import socket, select, threading, signal, string, time
45
from ipi.utils.depend import depstrip
46
from ipi.utils.messages import verbosity, warning, info
47
from ipi.utils.softexit import softexit
53
SERVERTIMEOUT = 2.0*TIMEOUT
57
"""Returns a header of standard length HDRLEN."""
59
return string.ljust(string.upper(mystr), HDRLEN)
62
class Disconnected(Exception):
63
"""Disconnected: Raised if client has been disconnected."""
67
class InvalidSize(Exception):
68
"""Disconnected: Raised if client returns forces with inconsistent number of atoms."""
72
class InvalidStatus(Exception):
73
"""InvalidStatus: Raised if client has the wrong status.
75
Shouldn't have to be used if the structure of the program is correct.
81
"""Simple class used to keep track of the status of the client.
83
Uses bitwise or to give combinations of different status options.
84
i.e. Status.Up | Status.Ready would be understood to mean that the client
85
was connected and ready to receive the position and cell data.
88
Disconnected: Flag for if the client has disconnected.
89
Up: Flag for if the client is running.
90
Ready: Flag for if the client has ready to receive position and cell data.
91
NeedsInit: Flag for if the client is ready to receive forcefield
93
HasData: Flag for if the client is ready to send force data.
94
Busy: Flag for if the client is busy.
95
Timeout: Flag for if the connection has timed out.
107
class DriverSocket(socket.socket):
108
"""Deals with communication between the client and driver code.
110
Deals with sending and receiving the data from the driver code. Keeps track
111
of the status of the driver. Initialises the driver forcefield, sends the
112
position and cell data, and receives the force data.
115
_buf: A string buffer to hold the reply from the driver.
116
status: Keeps track of the status of the driver.
117
lastreq: The ID of the last request processed by the client.
118
locked: Flag to mark if the client has been working consistently on one image.
121
def __init__(self, socket):
122
"""Initialises DriverSocket.
125
socket: A socket through which the communication should be done.
128
super(DriverSocket,self).__init__(_sock=socket)
129
self._buf = np.zeros(0,np.byte)
130
self.peername = self.getpeername()
131
self.status = Status.Up
132
self.waitstatus = False
136
def shutdown(self, how=socket.SHUT_RDWR):
138
self.sendall(Message("exit"))
139
self.status = Status.Disconnected
140
super(DriverSocket,self).shutdown(how)
143
"""Waits for driver status."""
145
self.status = Status.Disconnected # sets disconnected as failsafe status, in case _getstatus fails and exceptions are ignored upstream
146
self.status = self._getstatus()
148
def _getstatus(self):
149
"""Gets driver status.
152
An integer labelling the status via bitwise or of the relevant members
156
if not self.waitstatus:
158
readable, writable, errored = select.select([], [self], [])
160
self.sendall(Message("status"))
161
self.waitstatus = True
163
return Status.Disconnected
166
reply = self.recv(HDRLEN)
167
self.waitstatus = False # got status reply
168
except socket.timeout:
169
warning(" @SOCKET: Timeout in status recv!", verbosity.debug )
170
return Status.Up | Status.Busy | Status.Timeout
172
return Status.Disconnected
174
if not len(reply) == HDRLEN:
175
return Status.Disconnected
176
elif reply == Message("ready"):
177
return Status.Up | Status.Ready
178
elif reply == Message("needinit"):
179
return Status.Up | Status.NeedsInit
180
elif reply == Message("havedata"):
181
return Status.Up | Status.HasData
183
warning(" @SOCKET: Unrecognized reply: " + str(reply), verbosity.low )
186
def recvall(self, dest):
187
"""Gets the potential energy, force and virial from the driver.
190
dest: Object to be read into.
193
Disconnected: Raised if client is disconnected.
196
The data read from the socket to be read into dest.
199
blen = dest.itemsize*dest.size
200
if (blen > len(self._buf)):
201
self._buf.resize(blen)
211
bpart = self.recv(blen - bpos)
212
if len(bpart) == 0: raise socket.timeout # There is a problem if this returns no data
213
self._buf[bpos:bpos + len(bpart)] = np.fromstring(bpart, np.byte)
214
except socket.timeout:
215
warning(" @SOCKET: Timeout in status recvall, trying again!", verbosity.low)
218
if ntimeout > NTIMEOUT:
219
warning(" @SOCKET: Couldn't receive within %5d attempts. Time to give up!" % (NTIMEOUT), verbosity.low)
222
if (not timeout and bpart == 0):
226
# post-2.5 version: slightly more compact for modern python versions
229
# bpart = self.recv_into(self._buf[bpos:], blen-bpos)
230
# except socket.timeout:
231
# print " @SOCKET: Timeout in status recvall, trying again!"
234
# if (not timeout and bpart == 0):
235
# raise Disconnected()
237
#TODO this Disconnected() exception currently just causes the program to hang.
238
#This should do something more graceful
240
if np.isscalar(dest):
241
return np.fromstring(self._buf[0:blen], dest.dtype)[0]
243
return np.fromstring(self._buf[0:blen], dest.dtype).reshape(dest.shape)
245
def initialize(self, rid, pars):
246
"""Sends the initialisation string to the driver.
249
rid: The index of the request, i.e. the replica that
250
the force calculation is for.
251
pars: The parameter string to be sent to the driver.
254
InvalidStatus: Raised if the status is not NeedsInit.
257
if self.status & Status.NeedsInit:
259
self.sendall(Message("init"))
260
self.sendall(np.int32(rid))
261
self.sendall(np.int32(len(pars)))
267
raise InvalidStatus("Status in init was " + self.status)
269
def sendpos(self, pos, cell):
270
"""Sends the position and cell data to the driver.
273
pos: An array containing the atom positions.
274
cell: A cell object giving the system box.
277
InvalidStatus: Raised if the status is not Ready.
280
if (self.status & Status.Ready):
282
self.sendall(Message("posdata"))
283
self.sendall(cell.h, 9*8)
284
self.sendall(cell.ih, 9*8)
285
self.sendall(np.int32(len(pos)/3))
286
self.sendall(pos, len(pos)*8)
291
raise InvalidStatus("Status in sendpos was " + self.status)
294
"""Gets the potential energy, force and virial from the driver.
297
InvalidStatus: Raised if the status is not HasData.
298
Disconnected: Raised if the driver has disconnected.
301
A list of the form [potential, force, virial, extra].
304
if (self.status & Status.HasData):
305
self.sendall(Message("getforce"));
309
reply = self.recv(HDRLEN)
310
except socket.timeout:
311
warning(" @SOCKET: Timeout in getforce, trying again!", verbosity.low)
313
if reply == Message("forceready"):
316
warning(" @SOCKET: Unexpected getforce reply: %s" % (reply), verbosity.low)
320
raise InvalidStatus("Status in getforce was " + self.status)
323
mu = self.recvall(mu)
326
mlen = self.recvall(mlen)
327
mf = np.zeros(3*mlen,np.float64)
328
mf = self.recvall(mf)
330
mvir = np.zeros((3,3),np.float64)
331
mvir = self.recvall(mvir)
333
#! Machinery to return a string as an "extra" field. Comment if you are using a old patched driver that does not return anything!
335
mlen = self.recvall(mlen)
337
mxtra = np.zeros(mlen,np.character)
338
mxtra = self.recvall(mxtra)
339
mxtra = "".join(mxtra)
343
#!TODO must set up a machinery to intercept the "extra" return field
344
return [mu, mf, mvir, mxtra]
347
class InterfaceSocket(object):
348
"""Host server class.
350
Deals with distribution of all the jobs between the different client servers
351
and both initially and as clients either finish or are disconnected.
352
Deals with cleaning up after all calculations are done. Also deals with the
353
threading mechanism, and cleaning up if the interface is killed.
356
address: A string giving the name of the host network.
357
port: An integer giving the port the socket will be using.
358
slots: An integer giving the maximum allowed backlog of queued clients.
359
mode: A string giving the type of socket used.
360
latency: A float giving the number of seconds the interface will wait
361
before updating the client list.
362
timeout: A float giving a timeout limit for considering a calculation dead
363
and dropping the connection.
364
dopbc: A boolean which decides whether or not to fold the bead positions
365
back into the unit cell before passing them to the client code.
366
server: The socket used for data transmition.
367
clients: A list of the driver clients connected to the server.
368
requests: A list of all the jobs required in the current PIMD step.
369
jobs: A list of all the jobs currently running.
370
_poll_thread: The thread the poll loop is running on.
371
_prev_kill: Holds the signals to be sent to clean up the main thread
372
when a kill signal is sent.
373
_poll_true: A boolean giving whether the thread is alive.
374
_poll_iter: An integer used to decide whether or not to check for
375
client connections. It is used as a counter, once it becomes higher
376
than the pre-defined number of steps between checks the socket will
377
update the list of clients and then be reset to zero.
380
def __init__(self, address="localhost", port=31415, slots=4, mode="unix", latency=1e-3, timeout=1.0, dopbc=True):
381
"""Initialises interface.
384
address: An optional string giving the name of the host server.
385
Defaults to 'localhost'.
386
port: An optional integer giving the port number. Defaults to 31415.
387
slots: An optional integer giving the maximum allowed backlog of
388
queueing clients. Defaults to 4.
389
mode: An optional string giving the type of socket. Defaults to 'unix'.
390
latency: An optional float giving the time in seconds the socket will
391
wait before updating the client list. Defaults to 1e-3.
392
timeout: Length of time waiting for data from a client before we assume
393
the connection is dead and disconnect the client.
394
dopbc: A boolean which decides whether or not to fold the bead positions
395
back into the unit cell before passing them to the client code.
398
NameError: Raised if mode is not 'unix' or 'inet'.
401
self.address = address
405
self.latency = latency
406
self.timeout = timeout
408
self._poll_thread = None
410
self._poll_true = False
414
"""Creates a new socket.
416
Used so that we can create a interface object without having to also
417
create the associated socket object.
420
if self.mode == "unix":
421
self.server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
423
self.server.bind("/tmp/ipi_" + self.address)
424
info("Created unix socket with address " + self.address, verbosity.medium)
426
raise ValueError("Error opening unix socket. Check if a file " + ("/tmp/ipi_" + self.address) + " exists, and remove it if unused.")
428
elif self.mode == "inet":
429
self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
430
self.server.bind((self.address,self.port))
431
info("Created inet socket with address " + self.address + " and port number " + str(self.port), verbosity.medium)
433
raise NameError("InterfaceSocket mode " + self.mode + " is not implemented (should be unix/inet)")
435
self.server.listen(self.slots)
436
self.server.settimeout(SERVERTIMEOUT)
442
"""Closes down the socket."""
444
info(" @SOCKET: Shutting down the driver interface.", verbosity.low )
446
for c in self.clients[:]:
447
if (c.status & Status.Up):
448
c.shutdown(socket.SHUT_RDWR)
450
self.server.shutdown(socket.SHUT_RDWR)
452
if self.mode == "unix":
453
os.unlink("/tmp/ipi_" + self.address)
455
def queue(self, atoms, cell, pars=None, reqid=0):
458
Note that the pars dictionary need to be sent as a string of a
459
standard format so that the initialisation of the driver can be done.
462
atoms: An Atoms object giving the atom positions.
463
cell: A Cell object giving the system box.
464
pars: An optional dictionary giving the parameters to be sent to the
465
driver for initialisation. Defaults to {}.
466
reqid: An optional integer that identifies requests of the same type,
470
A list giving the status of the request of the form {'atoms': Atoms
471
object giving the atom positions, 'cell': Cell object giving the
472
system box, 'pars': parameter string, 'result': holds the result as a
473
list once the computation is done, 'status': a string labelling the
474
status, 'id': the id of the request, usually the bead number, 'start':
475
the starting time for the calculation, used to check for timeouts.}.
481
for k,v in pars.items():
482
par_str += k + " : " + str(v) + " , "
486
# APPLY PBC -- this is useful for codes such as LAMMPS that don't do full PBC when computing distances
487
pbcpos = depstrip(atoms.q).copy()
489
cell.array_pbc(pbcpos)
491
newreq = {"pos": pbcpos, "cell": cell, "pars": par_str,
492
"result": None, "status": "Queued", "id": reqid,
495
self.requests.append(newreq)
498
def release(self, request):
499
"""Empties the list of requests once finished.
502
request: A list of requests that are done.
505
if request in self.requests:
506
self.requests.remove(request)
508
def pool_update(self):
509
"""Deals with keeping the pool of client drivers up-to-date during a
510
force calculation step.
512
Deals with maintaining the client list. Clients that have
513
disconnected are removed and their jobs removed from the list of
514
running jobs and new clients are connected to the server.
517
for c in self.clients[:]:
518
if not (c.status & Status.Up):
520
warning(" @SOCKET: Client " + str(c.peername) +" died or got unresponsive(C). Removing from the list.", verbosity.low)
521
c.shutdown(socket.SHUT_RDWR)
525
c.status = Status.Disconnected
526
self.clients.remove(c)
527
for [k,j] in self.jobs[:]:
529
self.jobs = [ w for w in self.jobs if not ( w[0] is k and w[1] is j ) ] # removes pair in a robust way
530
#self.jobs.remove([k,j])
531
k["status"] = "Queued"
536
readable, writable, errored = select.select([self.server], [], [], 0.0)
537
if self.server in readable:
538
client, address = self.server.accept()
539
client.settimeout(TIMEOUT)
540
driver = DriverSocket(client)
541
info(" @SOCKET: Client asked for connection from "+ str( address ) +". Now hand-shaking.", verbosity.low)
543
if (driver.status | Status.Up):
544
self.clients.append(driver)
545
info(" @SOCKET: Handshaking was successful. Added to the client list.", verbosity.low)
547
warning(" @SOCKET: Handshaking failed. Dropping connection.", verbosity.low)
548
client.shutdown(socket.SHUT_RDWR)
553
def pool_distribute(self):
554
"""Deals with keeping the list of jobs up-to-date during a force
557
Deals with maintaining the jobs list. Gets data from drivers that have
558
finished their calculation and removes that job from the list of running
559
jobs, adds jobs to free clients and initialises the forcefields of new
563
for c in self.clients:
564
if c.status == Status.Disconnected : # client disconnected. force a pool_update
565
self._poll_iter = UPDATEFREQ
567
if not c.status & ( Status.Ready | Status.NeedsInit ):
570
for [r,c] in self.jobs[:]:
571
if c.status & Status.HasData:
573
r["result"] = c.getforce()
574
if len(r["result"][1]) != len(r["pos"]):
577
c.status = Status.Disconnected
580
warning(" @SOCKET: Client returned an inconsistent number of forces. Will mark as disconnected and try to carry on.", verbosity.low)
581
c.status = Status.Disconnected
584
warning(" @SOCKET: Client got in a awkward state during getforce. Will mark as disconnected and try to carry on.", verbosity.low)
585
c.status = Status.Disconnected
588
while c.status & Status.Busy: # waits, but check if we got stuck.
589
if self.timeout > 0 and r["start"] > 0 and time.time() - r["start"] > self.timeout:
590
warning(" @SOCKET: Timeout! HASDATA for bead " + str(r["id"]) + " has been running for " + str(time.time() - r["start"]) + " sec.", verbosity.low)
591
warning(" @SOCKET: Client " + str(c.peername) + " died or got unresponsive(A). Disconnecting.", verbosity.low)
593
c.shutdown(socket.SHUT_RDWR)
597
c.status = Status.Disconnected
600
if not (c.status & Status.Up):
601
warning(" @SOCKET: Client died a horrible death while getting forces. Will try to cleanup.", verbosity.low)
604
c.lastreq = r["id"] # saves the ID of the request that the client has just processed
605
self.jobs = [ w for w in self.jobs if not ( w[0] is r and w[1] is c ) ] # removes pair in a robust way
607
if self.timeout > 0 and c.status != Status.Disconnected and r["start"] > 0 and time.time() - r["start"] > self.timeout:
608
warning(" @SOCKET: Timeout! Request for bead " + str( r["id"]) + " has been running for " + str(time.time() - r["start"]) + " sec.", verbosity.low)
609
warning(" @SOCKET: Client " + str(c.peername) + " died or got unresponsive(B). Disconnecting.",verbosity.low)
611
c.shutdown(socket.SHUT_RDWR)
614
warning(" @SOCKET: could not shut down cleanly the socket. %s: %s in file '%s' on line %d" % (e[0].__name__, e[1], os.path.basename(e[2].tb_frame.f_code.co_filename), e[2].tb_lineno), verbosity.low )
617
c.status = Status.Disconnected
619
freec = self.clients[:]
620
for [r2, c] in self.jobs:
623
pendr = self.requests[:]
624
pendr = [ r for r in self.requests if r["status"] == "Queued" ]
628
# first, makes sure that the client is REALLY free
629
if not (fc.status & Status.Up):
630
self.clients.remove(fc) # if fc is in freec it can't be associated with a job (we just checked for that above)
632
if fc.status & Status.HasData:
634
if not (fc.status & (Status.Ready | Status.NeedsInit | Status.Busy) ):
635
warning(" @SOCKET: Client " + str(fc.peername) + " is in an unexpected status " + str(fc.status) + " at (1). Will try to keep calm and carry on.", verbosity.low)
637
for match_ids in ( "match", "none", "free", "any" ):
639
if match_ids == "match" and not fc.lastreq is r["id"]:
641
elif match_ids == "none" and not fc.lastreq is None:
643
elif match_ids == "free" and fc.locked:
646
info(" @SOCKET: Assigning [%5s] request id %4s to client with last-id %4s (% 3d/% 3d : %s)" % (match_ids, str(r["id"]), str(fc.lastreq), self.clients.index(fc), len(self.clients), str(fc.peername) ), verbosity.high )
648
while fc.status & Status.Busy:
650
if fc.status & Status.NeedsInit:
651
fc.initialize(r["id"], r["pars"])
653
while fc.status & Status.Busy: # waits for initialization to finish. hopefully this is fast
655
if fc.status & Status.Ready:
656
fc.sendpos(r["pos"], r["cell"])
657
r["status"] = "Running"
658
r["start"] = time.time() # sets start time for the request
660
self.jobs.append([r,fc])
661
fc.locked = (fc.lastreq is r["id"])
663
# removes r from the list of pending jobs
664
pendr = [nr for nr in pendr if (not nr is r)]
667
warning(" @SOCKET: Client " + str(fc.peername) + " is in an unexpected status " + str(fc.status) + " at (2). Will try to keep calm and carry on.", verbosity.low)
669
break # doesn't do a second (or third) round if it managed
672
def _kill_handler(self, signal, frame):
673
"""Deals with handling a kill call gracefully.
675
Prevents any of the threads becoming zombies, by intercepting a
676
kill signal using the standard python function signal.signal() and
677
then closing the socket and the spawned threads before closing the main
678
thread. Called when signals SIG_INT and SIG_TERM are received.
681
signal: An integer giving the signal number of the signal received
683
frame: Current stack frame.
686
warning(" @SOCKET: Kill signal. Trying to make a clean exit.", verbosity.low)
689
softexit.trigger(" @SOCKET: Kill signal received")
695
if signal in self._prev_kill:
696
self._prev_kill[signal](signal, frame)
698
def _poll_loop(self):
699
"""The main thread loop.
701
Runs until either the program finishes or a kill call is sent. Updates
702
the pool of clients every UPDATEFREQ loops and loops every latency
703
seconds until _poll_true becomes false.
706
info(" @SOCKET: Starting the polling thread main loop.", verbosity.low)
707
self._poll_iter = UPDATEFREQ
708
while self._poll_true:
709
time.sleep(self.latency)
710
# makes sure to remove the last dead client as soon as possible -- and to get clients if we are dry
711
if self._poll_iter >= UPDATEFREQ or len(self.clients)==0 or (len(self.clients) > 0 and not(self.clients[0].status & Status.Up)):
715
self.pool_distribute()
717
if os.path.exists("EXIT"): # softexit
718
info(" @SOCKET: Soft exit request from file EXIT. Flushing job queue.", verbosity.low)
719
# releases all pending requests
720
for r in self.requests:
722
for c in self.clients:
724
c.shutdown(socket.SHUT_RDWR)
728
# flush it all down the drain
731
self._poll_thread = None
734
"""Returns a boolean specifying whether the thread has started yet."""
736
return (not self._poll_thread is None)
738
def start_thread(self):
739
"""Spawns a new thread.
741
Splits the main program into two threads, one that runs the polling loop
742
which updates the client list, and one which gets the data. Also sets up
743
the machinery to deal with a kill call, in the case of a Ctrl-C or
744
similar signal the signal is intercepted by the _kill_handler function,
745
which cleans up the spawned thread before closing the main thread.
748
NameError: Raised if the polling thread already exists.
752
if not self._poll_thread is None:
753
raise NameError("Polling thread already started")
754
self._poll_thread = threading.Thread(target=self._poll_loop, name="poll_" + self.address)
755
self._poll_thread.daemon = True
756
self._prev_kill[signal.SIGINT] = signal.signal(signal.SIGINT, self._kill_handler)
757
self._prev_kill[signal.SIGTERM] = signal.signal(signal.SIGTERM, self._kill_handler)
758
self._poll_true = True
759
self._poll_thread.start()
761
def end_thread(self):
762
"""Closes the spawned thread.
764
Deals with cleaning up the spawned thread cleanly. First sets
765
_poll_true to false to indicate that the poll_loop should be exited, then
766
closes the spawned thread and removes it.
769
self._poll_true = False
770
if not self._poll_thread is None:
771
self._poll_thread.join()
772
self._poll_thread = None