4
#include <sys/socket.h>
7
#include <netinet/tcp.h>
8
#include <netinet/in.h>
10
#include <sys/types.h>
16
#include "nbd-proxy.h"
// sighandler: signal handler that main installs with signal(SIGINT, sighandler).
// sig: number of the delivered signal.
// NOTE(review): this listing omits part of the function (the statements after
// the printf and the closing brace); only the visible lines are described here.
19
* Handles the different signal during the execution
20
* sig -- the type of signal emitted
21
* SIGINT is handled, it terminates the program
23
void sighandler(int sig) {
// NOTE(review): printf() is not async-signal-safe (POSIX signal-safety list);
// write(STDOUT_FILENO, ...) followed by _exit() is the safe pattern. main also
// closes STDOUT_FILENO after daemonizing, so this message is likely never seen.
26
printf("Ctrl+c pressed, exiting\n");
// client_to_server: thread body that reads fixed-size nbd_request headers from
// the client socket and forwards NBD_CMD_READ requests to the server.
// data: the struct thread_data pointer shared with server_to_client (main
// passes th_d1 to both threads).
// NOTE(review): this listing omits several original lines of this function
// (declarations, the loop header, error paths and closing braces); the
// comments below describe only the visible statements.
33
* This thread acts as a proxy from client to server
34
* data -- struct containing shared data between threads (thread_data)
36
void *client_to_server(void *data) {
37
struct thread_data *infos = (struct thread_data*) data;
38
// nbd_request size : 28 bytes
39
char recv_buf[sizeof(struct nbd_request)];
42
PRINT_DEBUG("[t_client] Init mainloop\n");
// MSG_WAITALL: block until a whole request header has arrived, or until
// EOF or an error occurs.
44
bytes_read = recv(infos->client_socket, recv_buf, sizeof(recv_buf), MSG_WAITALL);
46
PRINT_DEBUG("[t_client] Client disconnected on recv(). Dying...\n");
// Heap copy of the request. On the READ path it is handed to the request
// queue (add_nbd_request below). server_to_client later removes it with
// rm_nbd_request().
// NOTE(review): the malloc() result is not checked before memcpy().
51
struct nbd_request *new_req = (struct nbd_request*) malloc(sizeof(struct nbd_request));
52
memcpy(new_req, recv_buf, sizeof(recv_buf));
55
// Checking if data is a valid nbd_request
// Fields stay in network byte order and the constants are converted instead.
// ntohl() and htonl() perform the same swap, but htonl() states the intent.
56
if(new_req->magic == ntohl(NBD_REQUEST_MAGIC)) {
57
// NBD_CMD_READ from client
58
if(new_req->type == ntohl(NBD_CMD_READ)) {
// NOTE(review): %lu assumes ntohll() yields unsigned long; if it returns
// uint64_t, PRIu64 from inttypes.h is the portable specifier.
59
PRINT_DEBUG("[t_client] Got nbd_request : handle(%s) of len(%u) and from(%lu)\n",
60
handle_to_string(new_req->handle), ntohl(new_req->len), ntohll(new_req->from));
// net_lock is held across enqueue + send. reconnect_server() takes the same
// lock, so the server socket cannot be swapped out in between.
62
pthread_mutex_lock(&net_lock);
63
// Adding nbd_request received to queue (thread safe)
// The request is queued before it is sent. A failed send therefore leaves it
// in the queue, where resend_all_nbd_requests() can replay it after a reconnect.
64
add_nbd_request(new_req, infos->reqs);
65
if(send_to_server(infos, recv_buf, bytes_read, new_req) == -1) {
66
PRINT_DEBUG("[t_client] Failed sending nbd_request to server\n");
68
pthread_mutex_unlock(&net_lock);
69
// NBD_CMD_DISC from client
70
} else if(new_req->type == ntohl(NBD_CMD_DISC)) {
71
PRINT_DEBUG("[t_client] NBD_DISCONNECT from client. Cleaning\n");
72
// On thin client infrastructure, this should not happen. Quitting properly
// NOTE(review): only NBD_CMD_READ and NBD_CMD_DISC are visibly handled.
// Confirm that other command types (e.g. writes) and packets with a bad magic
// are dealt with, and that new_req is freed on those paths; it is only
// enqueued on the READ path.
78
PRINT_DEBUG("[t_client] WTF client_to_server outside while\n");
// server_to_client: thread body that reads replies (an nbd_reply header and
// then the payload) from the server and forwards them to the client. It also
// keeps the request queue up to date so outstanding reads can be replayed
// after a reconnect.
// data: the struct thread_data pointer shared with client_to_server.
// NOTE(review): this listing omits many original lines of this function
// (declarations of bytes_read/r_bytes/nr, the loop header, else-branches and
// closing braces); the comments below describe only the visible statements.
82
* This thread acts as a proxy from server to client
83
* data -- struct containing shared data between threads (thread_data)
85
void *server_to_client(void *data) {
86
struct thread_data *infos = (struct thread_data*) data;
// Request whose reply payload is currently being received.
88
struct nbd_request *current_nr = NULL;
// Request re-sent right after a reconnect. When it completes, the rest of the
// queue is replayed (see the rm_nbd_request() branch below).
89
struct nbd_request *flag_disc = NULL;
90
char recv_buf[SRV_RECV_BUF * SEND_BUF_FACTOR];
// Number of leading bytes of send_buf to skip when forwarding to the client.
// It is set to sizeof(struct nbd_reply) after a reconnect, presumably because
// the client already received the reply header for the re-sent request.
93
int discard_reply_flag = 0;
// Byte count requested by the next recv(): a reply header first, then the
// remaining length of the current request.
95
int size_recv_buf = sizeof(struct nbd_reply);
96
char send_buf[SRV_RECV_BUF * SEND_BUF_FACTOR];
// Bytes accumulated in send_buf that have not yet been forwarded.
97
ssize_t total_size = 0;
100
PRINT_DEBUG("[t_server] Init mainloop\n");
102
PRINT_DEBUG("[t_server] recv mode\n");
// NOTE(review): size_recv_buf is later set from the client-supplied request
// length (ntohl(current_nr->len)). Unless an omitted line clamps it, a length
// larger than sizeof(recv_buf) makes this recv() overflow recv_buf.
103
bytes_read = recv(infos->server_socket, recv_buf, size_recv_buf, MSG_WAITALL);
104
// Keep track of bytes read for used to update nbd_request len
105
r_bytes = bytes_read;
106
if(bytes_read <= 0) {
107
PRINT_DEBUG("[t_server] Server disconnected on recv() (bytes_read = %d). Reconnecting\n",
109
reconnect_server(infos);
110
// Sending last nbd_request modified
// The in-flight request is re-sent first. Its len/from were already advanced
// by the update further below, so the server is asked only for the remainder.
111
if(current_nr != NULL) {
112
PRINT_DEBUG("[t_server] nbd_request count : %d\n", count_nbd_request(infos->reqs));
// NOTE(review): this send runs after reconnect_server() has released net_lock,
// so it can interleave with client_to_server's sends; its result is ignored.
113
send_to_server(infos, (char *) current_nr, sizeof(struct nbd_request), NULL);
114
PRINT_DEBUG("[t_server] Last known nbd_request sent\n");
115
PRINT_DEBUG("|-- nbd_request : handle(%s) of len(%u) and from(%lu)\n",
116
handle_to_string(current_nr->handle), ntohl(current_nr->len),
117
ntohll(current_nr->from));
119
flag_disc = current_nr;
120
discard_reply_flag = sizeof(struct nbd_reply);
// Presumably the branch taken when no request was in flight (the separating
// line is omitted from this listing): replay the whole queue immediately.
122
PRINT_DEBUG("[t_server] Resending all nbd_request after recv error from server\n");
123
resend_all_nbd_requests(infos);
// After a reconnect, the next recv() expects a fresh reply header.
127
size_recv_buf = sizeof(struct nbd_reply);
131
//PRINT_DEBUG("[t_server] Bytes read : %ld\n", bytes_read);
// nr is declared in an omitted line; presumably a struct nbd_reply.
134
memcpy(&nr, recv_buf, sizeof(struct nbd_reply));
135
// If the packet received contain a valid nbd_reply
// NOTE(review): unless an omitted condition restricts it, this check also runs
// on payload chunks. Disk data that happens to start with the reply magic
// would then be misread as a header. Tracking header-vs-payload state avoids
// this. The reply's error field is not visibly checked either; per the NBD
// protocol an error reply carries no payload, so waiting for len bytes stalls.
136
if(nr.magic == ntohl(NBD_REPLY_MAGIC)) {
137
PRINT_DEBUG("[t_server] Got nbd_reply : handle(%s)\n", handle_to_string(nr.handle));
138
// If already data in send_buf, it means that the last nbd_request is over
141
current_nr = get_nbd_request_by_handle(nr.handle, infos->reqs);
// NOTE(review): for an unknown handle only a debug line is printed, and the
// next visible statement dereferences current_nr. Unless an omitted line
// bails out here, this is a NULL-pointer dereference.
143
if(current_nr == NULL)
144
PRINT_DEBUG("[t_server] nbd_reply handle unknown\n");
146
// Adapting recv buffer size
147
size_recv_buf = ntohl(current_nr->len);
148
// Ignoring nbd_reply size for len of nbd_request
149
r_bytes -= sizeof(struct nbd_reply);
151
} else if(current_nr == NULL) {
152
PRINT_DEBUG("[t_server] Fatal error: No nbd_reply received and no nbd_request to serve\n");
// NOTE(review): %ld for ssize_t values; %zd is the matching specifier.
158
PRINT_DEBUG("Copy %ld bytes to send_buf at %ld pos of send_buf\n",
159
bytes_read, total_size);
// NOTE(review): no visible check that total_size + bytes_read fits in send_buf.
160
memcpy(send_buf + total_size, recv_buf, bytes_read);
161
total_size += bytes_read;
163
PRINT_DEBUG("[t_server] nbd_request in queue : %d\n", count_nbd_request(infos->reqs));
164
// Sending to client when all nbd_request's data received
166
PRINT_DEBUG("[t_server] Sending %d bytes to client\n",(int)total_size - discard_reply_flag);
167
// Sending data to client (P -> C). On client disconnect, nbd proxy STOP!
168
send_to_client(infos, send_buf + discard_reply_flag, total_size - discard_reply_flag);
170
discard_reply_flag = 0;
174
// Updating current nbd_request.len of received bytes (r_bytes)
// len and from are kept in network byte order: convert, adjust, convert back.
// Shrinking len and advancing from lets a later re-send ask only for the
// data still missing.
175
if(current_nr != NULL) {
176
current_nr->len = htonl(ntohl(current_nr->len) - r_bytes);
177
current_nr->from = htonll(ntohll(current_nr->from) + r_bytes);
// Comparing with 0 does not depend on byte order.
179
if((current_nr->len) == 0) {
180
// Removing nbd_request from queue. Not useful anymore (atomic action)
181
rm_nbd_request(current_nr, infos->reqs);
// NOTE(review): if rm_nbd_request() frees current_nr, this comparison uses a
// dangling pointer value. Compare with flag_disc before removing. flag_disc
// is also not visibly reset afterwards.
182
if(current_nr == flag_disc) {
183
PRINT_DEBUG("[t_server] Last known nbd_request done. Resending queue (count : %d)\n",
184
count_nbd_request(infos->reqs));
185
resend_all_nbd_requests(infos);
// Request finished: the next recv() expects a new reply header.
189
size_recv_buf = sizeof(struct nbd_reply);
193
PRINT_DEBUG("[t_server] WTF server_to_client outside while\n");
// nbd_connect: reads the server's negotiation data from sock into a newly
// malloc'd struct nbd_init_data. Fields are read in this order: init_passwd,
// magic, size, flags, zeros. magic is checked against cliserv_magic.
// sock: socket already connected to the NBD server.
// Returns the struct; magic, size, flags and zeros are left in the byte order
// they arrived in, so client_connect() can replay them unchanged.
// NOTE(review): the return statement and any early exits are omitted from this
// listing. Every visible failure is only logged via PRINT_DEBUG; confirm
// whether the omitted lines abort the negotiation.
197
* Connect to server with the specific negotiation protocol of nbd
198
* sock -- socket to nbd server
200
* Return nbd_init_data* containing all the information from server. This
201
* needs to be resend to the client
203
struct nbd_init_data *nbd_connect(int sock) {
// NOTE(review): malloc() result is not checked, and the memory is not zeroed.
204
struct nbd_init_data *nid = (struct nbd_init_data *) malloc(sizeof(struct nbd_init_data));
206
PRINT_DEBUG("[nbd_connect] Negotiation: ");
// NOTE(review): each read() below checks only for < 0. On a TCP socket read()
// may return fewer bytes than requested, or 0 at EOF. recv(..., MSG_WAITALL),
// as used by the proxy threads, or a read loop, would guarantee whole fields.
207
/* Read INIT_PASSWD */
208
if (read(sock, &(nid->init_passwd) , sizeof(nid->init_passwd)) < 0)
209
PRINT_DEBUG("[nbd_connect] Failed/1: %m\n");
// NOTE(review): nid comes from malloc() (not zeroed), so after an EOF read the
// buffer holds garbage and strlen() is not a reliable "closed" test; check
// read()'s return value for 0 instead. strlen()/strcmp() also need a NUL
// terminator inside init_passwd. Confirm the field is sized for one
// (nbd-proxy.h), or compare with memcmp() over the fixed length.
210
if (strlen(nid->init_passwd)==0)
211
PRINT_DEBUG("[nbd_connect] Server closed connection\n");
212
if (strcmp(nid->init_passwd, INIT_PASSWD))
213
PRINT_DEBUG("[nbd_connect] INIT_PASSWD bad\n");
216
/* Read cliserv_magic */
217
if (read(sock, &(nid->magic), sizeof(nid->magic)) < 0)
218
PRINT_DEBUG("[nbd_connect] Failed/2: %m\n");
// Convert to host order for the comparison.
219
nid->magic = ntohll(nid->magic);
220
if (nid->magic != cliserv_magic)
221
PRINT_DEBUG("[nbd_connect] Not enough cliserv_magic\n");
// Swap back to wire order so client_connect() can send it to the client as-is.
222
nid->magic = ntohll(nid->magic);
226
if (read(sock, &(nid->size), sizeof(nid->size)) < 0)
227
PRINT_DEBUG("[nbd_connect] Failed/3: %m\n");
231
if (read(sock, &(nid->flags), sizeof(nid->flags)) < 0)
232
PRINT_DEBUG("[nbd_connect] Failed/4: %m\n");
236
if (read(sock, &(nid->zeros), sizeof(nid->zeros)) < 0)
237
PRINT_DEBUG("[nbd_connect] Failed/5: %m\n");
// create_connect_sock: opens an IPv4 TCP socket and connects it to addr:port.
// port: remote port in host byte order (converted with htons()).
// addr: remote IPv4 address as a dotted-decimal string (parsed by inet_addr()).
// Note: the header below says "specific and point"; it means "specific end point".
// NOTE(review): the error branches only log in the visible lines. The return
// statement and any cleanup are omitted from this listing.
243
/* create_connect_sock
244
* Create a socket connected to a specific and point.
245
* port -- which port to connect to
246
* addr -- remote IP address
248
* Return socket file descriptor
250
int create_connect_sock(int port, char *addr) {
254
struct sockaddr_in struct_addr;
256
sock = socket(PF_INET, SOCK_STREAM, 0);
258
PRINT_DEBUG("Unable to create connect socket\n");
// NOTE(review): SO_REUSEADDR governs local address reuse at bind() time and has
// little effect on an outgoing, auto-bound client socket.
262
if((setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt))) == -1)
263
perror("setsockopt");
265
struct_addr.sin_family = AF_INET;
266
struct_addr.sin_port = htons(port);
// NOTE(review): inet_addr() returns INADDR_NONE for malformed input, which is
// indistinguishable from 255.255.255.255 and is not checked here. inet_pton()
// reports errors explicitly.
267
struct_addr.sin_addr.s_addr = inet_addr(addr);
268
memset(struct_addr.sin_zero, 0, sizeof(struct_addr.sin_zero));
270
PRINT_DEBUG("[create_connect_sock] Connect socket to %s\n", addr);
271
err = connect(sock, (struct sockaddr *) &struct_addr, sizeof(struct_addr));
// NOTE(review): confirm the omitted error path closes sock; otherwise each
// failed reconnect leaks a descriptor.
273
PRINT_DEBUG("Server unable to connect\n");
276
PRINT_DEBUG("[create_connect_sock] Connected and ready.\n");
// create_listen_sock: binds an IPv4 TCP socket to addr:port and listens with a
// backlog of 1. It sends SIGHUP to the parent process, then blocks in accept()
// for a single client connection.
// port: local port in host byte order.
// addr: IPv4 address in host byte order (main passes INADDR_LOOPBACK).
// Returns: presumably the accepted connection (newfd), not the listening
// socket; the return statement is omitted from this listing.
280
/* create_listen_sock
281
* Create a socket listening on a port and bind
282
* port -- which port to listen on
283
* addr -- IP addr to bind on
285
* Return socket file descriptor
287
int create_listen_sock(int port, int addr) {
291
struct sockaddr_in struct_addr;
294
if((sock = socket(AF_INET, SOCK_STREAM, 0)) == -1) {
296
PRINT_DEBUG("Unable to create listen socket\n");
300
struct_addr.sin_family = AF_INET;
301
struct_addr.sin_port = htons(port);
// NOTE(review): addr is in host order, so htonl() is the intended conversion.
// ntohl() performs the identical swap, so the result is the same; htonl()
// simply states the intent.
302
struct_addr.sin_addr.s_addr = ntohl(addr);
303
memset(struct_addr.sin_zero, 0, sizeof(struct_addr.sin_zero));
305
if((setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt))) == -1)
306
perror("setsockopt");
// The log text says localhost whatever addr is; main passes INADDR_LOOPBACK.
308
PRINT_DEBUG("[create_listen_sock] Binding socket to localhost\n");
309
if(bind(sock, (struct sockaddr *) &struct_addr, sizeof(struct_addr)) == -1) {
311
PRINT_DEBUG("Unable to bind socket\n");
315
PRINT_DEBUG("[create_listen_sock] Listining to localhost\n");
316
if(listen(sock,1) == -1) {
318
PRINT_DEBUG("Unable to listen\n");
// The signal is sent once listen() has succeeded and before accept().
// Presumably this is how the parent waiting in main ("Wait until the child
// process is ready") learns that the proxy is ready.
322
/* Send SIGHUP to detach the parent process */
323
kill(getppid(), SIGHUP);
327
PRINT_DEBUG("[create_listen_sock] Accepting socket\n");
328
s_size = sizeof(struct_addr);
// Only one client is ever accepted.
// NOTE(review): the listening socket is not visibly closed after accept(),
// which leaks a descriptor unless an omitted line closes it.
329
if((newfd = accept(sock, (struct sockaddr *) &struct_addr, &s_size)) == -1) {
331
PRINT_DEBUG("Accept() failed\n");
// NOTE(review): SO_REUSEADDR on an accepted connection has no useful effect.
// It only matters on the listening socket before bind(), where it is already set.
335
if((setsockopt(newfd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt))) == -1)
336
perror("setsockopt");
337
PRINT_DEBUG("[create_listen_sock] Socket bound and ready. Returning\n");
// resend_all_nbd_requests: under data_lock, walks the request list whose head
// is *(infos->reqs) (nodes have nr and next). Each queued nbd_request is sent
// to the server again, except the one equal to infos->except_nr.
// Note: except_nr, listed in the header below, is not a parameter. It is read
// from infos->except_nr.
// NOTE(review): the do-loop header and the statement directly after the
// except_nr test are omitted from this listing. Per the "Don't send" comment,
// the omitted statement presumably skips that entry (e.g. continue, which in a
// do-while jumps to the condition and still advances). Without it, the if
// would guard the send and invert the logic.
// NOTE(review): sends happen while data_lock is held (blocking I/O under the
// queue lock) but without net_lock, whereas client_to_server sends under
// net_lock. Sends from the two threads can therefore interleave on the server
// socket. send_to_server's return value is ignored.
342
/* resend_all_nbd_requests
343
* Resend all nbd_requests in memory to nbd server
344
* infos -- struct thread_data
345
* except_nr -- nbd_request NOT to send to server
347
void resend_all_nbd_requests(struct thread_data *infos) {
348
pthread_mutex_lock(&data_lock);
349
struct proxy_nbd_request *current_pnr = *(infos->reqs);
350
if(current_pnr == NULL) {
351
PRINT_DEBUG("[resend] No nbd_request in queue\n");
352
pthread_mutex_unlock(&data_lock);
356
// Don't send specific nbd_request
357
if(current_pnr->nr == infos->except_nr)
359
send_to_server(infos, (char*) current_pnr->nr, sizeof(struct nbd_request), NULL);
360
PRINT_DEBUG("[resend] nbd_request : handle(%s) of len(%u) and from(%lu)\n",
361
handle_to_string(current_pnr->nr->handle), ntohl(current_pnr->nr->len),
362
ntohll(current_pnr->nr->from));
363
} while((current_pnr = current_pnr->next) != NULL);
364
pthread_mutex_unlock(&data_lock);
// reconnect_server: while holding net_lock, closes the current server socket,
// opens a new one with create_connect_sock() and renegotiates via
// server_connect(). client_to_server sends only under net_lock, so it cannot
// use the socket mid-swap.
// Note: no request-queue cleanup is visible here, despite the header's "Clean
// data structure".
// NOTE(review): there is a single attempt, and create_connect_sock()'s result
// is not checked before renegotiating. server_connect() also replaces
// infos->nid with a fresh allocation (nbd_connect mallocs one per call), so
// the previous struct leaks. The new export size/flags are not compared with
// those already sent to the client.
368
* Clean data structure and reconnect to server
369
* infos -- struct thread_data
371
void reconnect_server(struct thread_data *infos) {
372
pthread_mutex_lock(&net_lock);
373
// Close current server socket
374
close(infos->server_socket);
376
infos->server_socket = create_connect_sock(infos->server_port, infos->server_addr);
377
server_connect(infos);
378
pthread_mutex_unlock(&net_lock);
379
PRINT_DEBUG("[reconnect_server] Reconnected to server\n");
// server_connect: runs the NBD negotiation (nbd_connect) on the
// already-connected infos->server_socket and stores the result in infos->nid.
// It does not open the socket itself; reconnect_server() does that first.
// NOTE(review): the previous infos->nid is overwritten without being freed,
// and nbd_connect's result is not checked.
383
* Establish connection to nbd server
384
* infos -- struct thread_data
386
void server_connect(struct thread_data *infos) {
387
// Saving nid information to thread_data
388
infos->nid = nbd_connect(infos->server_socket);
389
PRINT_DEBUG("Server Connected\n");
// client_connect: replays the negotiation data saved in td->nid to the client
// socket. The password and magic are sent first (magic is already back in wire
// order, see nbd_connect), then the remaining fields in a single send().
// NOTE(review): the last send() starts at &nid->size and spans size, flags and
// zeros minus 4 bytes. This relies on the three fields being contiguous with
// no padding, and the -4 is unexplained. Confirm against nbd-proxy.h that the
// span equals the 136 bytes the oldstyle handshake carries after the magic
// (8-byte size, 4-byte flags, 124 zero bytes). Sending field by field would
// remove the layout dependency.
// NOTE(review): send() results are ignored, so failures and partial sends go
// unnoticed.
393
* Connect client to server (nbd point of view)
394
* td -- struct thread_data containing thread informations
396
void client_connect(struct thread_data *td) {
397
struct nbd_init_data *nid = td->nid;
398
send(td->client_socket, &(nid->init_passwd), sizeof(nid->init_passwd), 0);
399
send(td->client_socket, &(nid->magic), sizeof(nid->magic), 0);
400
send(td->client_socket, &(nid->size), sizeof(nid->size) + sizeof(nid->flags) + sizeof(nid->zeros) - 4, 0);
401
PRINT_DEBUG("Client Connected\n");
// send_to_client: sends size bytes of buf on infos->client_socket and logs if
// send() fails.
// Note: the function returns void, so the header's "Return -1" does not apply.
// What happens after the failure log is omitted from this listing;
// server_to_client's comment says the proxy stops when the client disconnects.
// NOTE(review): send() may transmit fewer than size bytes, and that case is
// not handled (loop until all bytes are sent). Sending to a closed peer can
// also raise SIGPIPE, which by default kills the process before -1 is seen.
// Only a SIGINT handler is visibly installed; consider MSG_NOSIGNAL or
// ignoring SIGPIPE.
405
* Send data to client with safe control
406
* infos -- thread_data
407
* buf -- data to send
408
* size -- size of data to send
409
* Return -1 if error detected
412
void send_to_client(struct thread_data *infos, char *buf, size_t size) {
414
if((send(infos->client_socket, buf, size, 0) == -1)) {
415
PRINT_DEBUG("Client disconnected on send(). Dying...\n");
// send_to_server: sends size bytes of buf on infos->server_socket. The header
// states it returns -1 on error; client_to_server tests for == -1.
// nr: request associated with buf. Callers pass new_req (client_to_server) or
// NULL (re-sends). Its use is only partly visible here: infos->except_nr is
// reset to NULL.
// NOTE(review): most of the body is omitted from this listing. Partial sends
// are not visibly handled, and sending to a closed server can raise SIGPIPE
// (see the MSG_NOSIGNAL option of send()). infos->server_socket is read
// without net_lock when called from server_to_client/resend paths.
422
* Send data to server with safe control
423
* infos -- thread_data
424
* buf -- data to send
425
* size -- size of data to send
426
* Return -1 if error detected
429
int send_to_server(struct thread_data *infos, char *buf, size_t size, struct nbd_request *nr) {
432
if((send(infos->server_socket, buf, size, 0) == -1)) {
434
infos->except_nr = NULL;
// main: usage is nbd-proxy server_address server_port listening_port.
// Steps:
//   1. Install the SIGINT handler.
//   2. Daemonize (fork, setsid, chdir("/"), close stdout/stderr).
//   3. Connect to and negotiate with the NBD server.
//   4. Accept one local client on 127.0.0.1 (INADDR_LOOPBACK).
//   5. Replay the negotiation data to the client.
//   6. Run the two proxy threads and join them.
// NOTE(review): many lines (the argc check, fork/setsid calls, assignments of
// th_d1->nid/reqs/except_nr, the status declaration) are omitted from this
// listing.
442
int main(int argc, char *argv[]) {
443
int server_port, listen_port, client_socket, server_socket;
444
int client_th, server_th, rc;
445
char *server_address;
// threads[0] = client_to_server, threads[1] = server_to_client, so
// NUM_THREADS must be at least 2.
447
pthread_t threads[NUM_THREADS];
448
struct nbd_init_data *nid;
451
printf("Usage : nbd-proxy server_address server_port listening_port\n");
455
server_address = argv[1];
// NOTE(review): atoi() cannot report malformed or out-of-range ports; strtol()
// with an end-pointer check can.
456
server_port = atoi(argv[2]);
457
listen_port = atoi(argv[3]);
// NOTE(review): the malloc() result is unchecked and the struct is not zeroed.
// Any field not assigned in the omitted lines (e.g. except_nr, read by
// resend_all_nbd_requests) starts out indeterminate; calloc() would avoid that.
459
struct thread_data *th_d1 = (struct thread_data *) malloc(sizeof(struct thread_data));
// Presumably the request-queue head; resend_all_nbd_requests dereferences
// infos->reqs as a pointer to the head, so th_d1->reqs is likely &pnr.
460
struct proxy_nbd_request *pnr = NULL;
462
signal(SIGINT, sighandler);
465
/* Our process ID and Session ID */
468
/* Fork off the parent process */
473
/* If we got a good PID, then
474
we can exit the parent process. */
476
/* Wait until the child process is ready to process the requests and exit */
481
/* Change the file mode mask */
484
/* Create a new SID for the child process */
487
/* Log the failure */
491
/* Change the current working directory */
492
if ((chdir("/")) < 0) {
493
/* Log the failure */
// NOTE(review): after this point, printf/perror output (e.g. perror in the
// socket helpers, printf in sighandler) is lost. PRINT_DEBUG output is lost
// too if it writes to stdout/stderr.
497
/* Close out the standard file descriptors */
499
close(STDOUT_FILENO);
500
close(STDERR_FILENO);
503
PRINT_DEBUG("[main] Creating sockets\n");
505
server_socket = create_connect_sock(server_port, server_address);
507
PRINT_DEBUG("[main] nbd_connect\n");
508
// Negotiate with nbd server
509
nid = nbd_connect(server_socket);
// Blocks until one client connects. create_listen_sock also sends SIGHUP to
// the parent before accept().
512
client_socket = create_listen_sock(listen_port, INADDR_LOOPBACK);
514
th_d1->client_socket = client_socket;
515
th_d1->server_socket = server_socket;
516
th_d1->server_port = server_port;
517
th_d1->server_addr = server_address;
518
th_d1->listen_port = listen_port;
522
// Connect client with server negotiation data
523
client_connect(th_d1);
// NOTE(review): the pthread_create() return codes are stored in
// server_th/client_th but not visibly checked.
525
PRINT_DEBUG("[main] Spawning thread server\n");
526
server_th = pthread_create(&threads[1], NULL, server_to_client, (void *) th_d1);
527
PRINT_DEBUG("[main] Spawning thread client\n");
528
client_th = pthread_create(&threads[0], NULL, client_to_server, (void *) th_d1);
530
rc = pthread_join(threads[0], &status);
532
PRINT_DEBUG("ERROR, return code from pthread_join is %d\n", rc);
534
rc = pthread_join(threads[1], &status);
536
PRINT_DEBUG("ERROR, return code from pthread_join is %d\n", rc);
539
PRINT_DEBUG("[main] main is dying\n");