1
/* Copyright (c) 2007-2012 Sam Trenholme
2
* IPv6 code contributed by Jean-Jacques Sarton in 2007
6
* Redistribution and use in source and binary forms, with or without
7
* modification, are permitted provided that the following conditions
10
* 1. Redistributions of source code must retain the above copyright
11
* notice, this list of conditions and the following disclaimer.
12
* 2. Redistributions in binary form must reproduce the above copyright
13
* notice, this list of conditions and the following disclaimer in the
14
* documentation and/or other materials provided with the distribution.
16
* This software is provided 'as is' with no guarantees of correctness or
17
* fitness for purpose.
21
#include "DwTcpSocket.h"
26
/* One parameter that may eventually become a dwood2rc parameter */
27
#define TCP_BUFFERSIZE 1024
29
/* Mararc parameters that are set in DwMararc.c */
30
extern dw_str *key_s[];
31
extern dw_str *key_d[];
32
extern int32_t key_n[];
34
/* Parameters set in DwSys.c */
35
extern int64_t the_time;
36
extern dwr_rg *rng_seed;
38
/* List of addresses we will bind to */
39
extern ip_addr_T bind_address[];
40
extern ip_addr_T upstream_address[];
42
/* Some global variables */
43
extern int max_tcp_procs;
44
extern int timeout_seconds;
45
extern int timeout_seconds_tcp;
47
extern int upstream_port;
48
extern int num_retries;
51
SOCKET tcp_b_local[DW_MAXIPS + 1]; /* Local sockets */
53
/* The following is needed because Winsock's "make this socket non-blocking"
54
* uses a pointer to a number as one of its arguments */
56
extern u_long dont_block;
57
extern void windows_socket_start();
60
/* The upstream server we will connect to (round robin rotated) */
62
/* Allocate the memory for the list of the open remote TCP connections.
63
* This memory is never freed once allocated, because this data is always
64
* used by the program */
65
void malloc_tcp_pend() {
66
tcp_pend = dw_malloc((max_tcp_procs + 1) * sizeof(tcp_pend_T));
68
dw_alog_3strings("Fatal: Could not allocate tcp_pend","","");
73
/* Initialize the values of all the open remote TCP connections */
74
void init_tcp_b_pend() {
76
for(a = 0; a < max_tcp_procs; a++) {
81
/* TCP bind to all IP addresses we are to bind to and return the number of
82
* IP addresses we got */
86
if(key_n[DWM_N_tcp_listen] != 1) {
89
for(a = 0; a < DW_MAXIPS; a++) {
90
if(bind_address[a].len != 0) {
91
tcp_b_local[a] = do_bind(&bind_address[a],SOCK_STREAM);
92
if(tcp_b_local[a] != -1) {
104
/* Find a free pending TCP connection to use; return -1 if
105
* there isn't one (we're overloaded) */
106
int32_t find_free_tcp_pend() {
108
for(a = 0; a < max_tcp_procs; a++) {
109
if(tcp_pend[a].local == INVALID_SOCKET) {
113
return -1; /* None available (we're overloaded) */
116
/* Set up a TCP server that we will use to connect to a remote host */
117
SOCKET setup_tcp_server(sockaddr_all_T *server, dw_str *query, int b) {
118
SOCKET remote = INVALID_SOCKET;
119
ip_addr_T rem_ip = {0,{0,0},0,0};
121
rem_ip = get_upstream_ip(query,b);
122
if(rem_ip.glueless != 0) {
123
dw_destroy(rem_ip.glueless);
125
if(rem_ip.len == 4) {
126
server->V4.sin_family = AF_INET;
127
server->V4.sin_port = htons(upstream_port);
128
memcpy(&(server->V4.sin_addr),rem_ip.ip,4);
129
remote = socket(AF_INET,SOCK_STREAM,0);
131
} else if(rem_ip.len == 16) {
132
server->V6.sin6_family = AF_INET6;
133
server->V6.sin6_port = htons(upstream_port);
134
memcpy(&(server->V6.sin6_addr),rem_ip.ip,16);
135
remote = socket(AF_INET6,SOCK_STREAM,0);
138
return INVALID_SOCKET;
143
/* Given a tcp socket s, accept the connection on that socket, then
144
* prepare things so we can get <len><DNS packet>, send the query in their
145
* packet upstream, then send <len><DNS reply> back to the TCP client
147
void local_tcp_accept(SOCKET s) {
148
sockaddr_all_T client;
150
socklen_t len = sizeof(struct sockaddr_in);
154
b = find_free_tcp_pend();
155
if(b == -1) { /* Out of active TCP connections */
159
len = sizeof(client);
160
local = accept(s,(struct sockaddr *)&client,&len);
161
make_socket_nonblock(local);
163
if(local == INVALID_SOCKET) { /* accept() error */
167
/* This is where we do ip-based packet rejection */
168
get_from_ip_port(&from_ip,&client);
169
if(check_ip_acl(&from_ip) != 1) {
174
/* At this point, we want to get the 2-byte
175
* length of the DNS packet, followed by getting the DNS packet;
176
* we then want to be able to send UDP queries upstream to get
177
* the information we want, then we went to send the reply back
180
tcp_pend[b].buffer = dw_malloc(3); /* To put the two bytes we want */
181
if(tcp_pend[b].buffer == 0) {
186
tcp_pend[b].local = local;
187
tcp_pend[b].wanted = 2; /* We want to get the two-byte DNS length
188
* header from the client */
189
tcp_pend[b].die = get_time() + ((int64_t)timeout_seconds_tcp << 8);
192
/* For a given pending TCP connection, see if we have all the bytes we
193
* want. If we don't, try to get the data we want */
194
void tcp_get_wanted(int b) {
198
toget = tcp_pend[b].wanted - tcp_pend[b].got;
199
if(toget > 0 && tcp_pend[b].state == 0) {
200
buffer = dw_malloc(toget + 1);
204
len = recv(tcp_pend[b].local,buffer,toget,MSG_DONTWAIT);
205
/* Add the bytes we get to the end of the buffer of wanted
207
if(len > toget || len < 0) {
211
memcpy(tcp_pend[b].buffer + tcp_pend[b].got, buffer, len);
212
tcp_pend[b].got += len;
213
tcp_pend[b].die = get_time() +
214
((int64_t)timeout_seconds_tcp << 8);
219
/* For a given TCP connection, if we have all the bytes we want, do the
221
void tcp_process_data(int b) {
223
if(tcp_pend[b].wanted != tcp_pend[b].got || tcp_pend[b].buffer == 0
224
|| tcp_pend[b].state != 0) {
227
if(tcp_pend[b].wanted == 2) { /* If we wanted the length of the DNS
229
/* Based on the length of the DNS packet wanted, we next
230
* try to get the DNS packet */
232
wanted = tcp_pend[b].buffer[0];
234
wanted |= tcp_pend[b].buffer[1];
235
free(tcp_pend[b].buffer);
236
tcp_pend[b].buffer = 0;
238
closesocket(tcp_pend[b].local);
242
tcp_pend[b].wanted = wanted;
243
tcp_pend[b].buffer = dw_malloc(wanted + 1);
244
tcp_pend[b].die = get_time() +
245
((int64_t)timeout_seconds_tcp << 8);
246
} else if(tcp_pend[b].wanted >= 12) {
252
/* Convert a TCP packet on a connection in to a reply we either get from
253
* the cache or send upstream via UDP */
254
void tcp_to_udp(int b) {
255
int32_t local_id = -1;
256
dw_str *query = 0, *orig_query = 0;
258
local_id = get_dns_qid((void *)tcp_pend[b].buffer, tcp_pend[b].wanted,
261
closesocket(tcp_pend[b].local);
266
/* See if the data is cached */
267
query = dw_get_dname_type((void *)tcp_pend[b].buffer,12,
270
closesocket(tcp_pend[b].local);
274
orig_query = dw_copy(query);
275
dwc_lower_case(query);
277
if(get_reply_from_cache(query,0,0,local_id,0,b,orig_query,0) != 1) {
278
/* If not cached, make the buffer a UDP connection upstream */
279
forward_local_udp_packet(1,local_id,0,0,
280
(void *)tcp_pend[b].buffer,tcp_pend[b].wanted,b,
282
tcp_pend[b].state = 1; /* Awaiting UDP reply */
284
/* "<< 10" instead of "<< 8" because we need more time to
285
* get a reply upstream */
286
tcp_pend[b].die = get_time() +
287
((int64_t)timeout_seconds_tcp << 10);
291
dw_destroy(orig_query);
294
/* Called from the "UDP" code, this tells Deadwood to buffer a TCP
295
* packet to send back to the client */
296
void tcp_return_reply(int b, char *packet, int len) {
297
if(tcp_pend[b].buffer != 0) {
298
free(tcp_pend[b].buffer);
299
tcp_pend[b].buffer = 0;
301
tcp_pend[b].state = 2; /* Send TCP reply back to client */
302
tcp_pend[b].buffer = dw_malloc(len + 3);
303
/* 2-byte length header */
304
tcp_pend[b].buffer[0] = ((len & 0xff00) >> 8);
305
tcp_pend[b].buffer[1] = (len & 0xff);
306
memcpy(tcp_pend[b].buffer + 2, packet, len);
307
tcp_pend[b].wanted = len + 2;
309
tcp_pend[b].die = get_time() + ((int64_t)timeout_seconds_tcp << 8);
312
/* This code sends back buffered data to the client who sent us the original
314
void tcp_send_wanted(int b) {
317
if(tcp_pend[b].state != 2) { /* Data to return to client */
320
tosend = tcp_pend[b].wanted - tcp_pend[b].got;
321
if(tosend > 0 && tcp_pend[b].state == 2) {
322
len = send(tcp_pend[b].local,tcp_pend[b].buffer +
323
tcp_pend[b].got,tosend,MSG_DONTWAIT);
324
tcp_pend[b].got += len;
325
tcp_pend[b].die = get_time() +
326
((int64_t)timeout_seconds_tcp << 8);
328
closesocket(tcp_pend[b].local);
333
/* Create a DNS query packet, given a raw DNS query, as a dw_string object */
334
dw_str *make_dns_query_packet(dw_str *query, int id, int is_upstream) {
337
/* Convert the query in to a DNS packet to send */
338
/* 0x0180: QR = 0; Opcode = 0; AA = 0; TC = 0; RD = 1; RA = 1;
339
* Z = 0; RCODE = 0 ; 0x0080: Same but RD = 0 */
340
if(is_upstream == 1) {
341
out = make_dns_header(id,0x0180,0,0,0); /* Header */
343
out = make_dns_header(id,0x0080,0,0,0); /* Header */
346
goto catch_make_dns_query_packet;
348
if(dw_append(query,out) == -1) /* Question */ {
349
goto catch_make_dns_query_packet;
351
if(dw_put_u16(out,1,-1) == -1) /* "class" (internet) */ {
352
goto catch_make_dns_query_packet;
357
catch_make_dns_query_packet:
364
/* If we get a "truncated" UDP DNS packet upstream, and have connected via
365
* TCP to make our original DNS query, connect via TCP to the upstream
366
* server to try and get the non-truncated reply */
367
void tcp_truncated_retry(int b, dw_str *query, int id, int udp_id, int is_up) {
369
sockaddr_all_T server;
370
socklen_t len = sizeof(struct sockaddr_in);
372
if(tcp_pend[b].buffer != 0) {
373
free(tcp_pend[b].buffer);
374
tcp_pend[b].buffer = 0;
377
/* Prepare packet to send */
378
tmp = make_dns_query_packet(query,id,is_up);
380
goto catch_tcp_truncated_retry;
382
tcp_pend[b].buffer = dw_malloc(tmp->len + 3);
383
if(tcp_pend[b].buffer == 0) {
384
goto catch_tcp_truncated_retry;
386
tcp_pend[b].buffer[0] = (tmp->len & 0xff00) >> 8; /* Header byte 1 */
387
tcp_pend[b].buffer[1] = tmp->len & 0xff; /* Header byte 2 */
388
memcpy(tcp_pend[b].buffer + 2, tmp->str, tmp->len); /* DNS query */
389
tcp_pend[b].state = 3; /* Send buffer upstream */
390
tcp_pend[b].got = 0; /* No bytes sent */
391
tcp_pend[b].wanted = tmp->len + 2; /* Send entire packet */
393
/* Connect to upstream server over TCP */
394
tcp_pend[b].upstream = setup_tcp_server(&server,query,udp_id);
395
if(tcp_pend[b].upstream == INVALID_SOCKET) {
396
goto catch_tcp_truncated_retry;
398
make_socket_nonblock(tcp_pend[b].upstream);
400
if (server.Family == AF_INET6)
401
len = sizeof(struct sockaddr_in6);
403
if(connect(tcp_pend[b].upstream,(struct sockaddr *)&server,len) == -1
404
&& SCKT_ERR != EINPROGRESS) {
405
closesocket(tcp_pend[b].upstream);
406
goto catch_tcp_truncated_retry;
413
catch_tcp_truncated_retry:
417
closesocket(tcp_pend[b].local);
421
/* Send data via TCP to upstream DNS server */
422
void tcp_upstream_send(int b) {
425
if(tcp_pend[b].state != 3) {
429
if(tcp_pend[b].wanted < tcp_pend[b].got) {
430
closesocket(tcp_pend[b].local);
431
closesocket(tcp_pend[b].upstream);
435
len = send(tcp_pend[b].upstream,tcp_pend[b].buffer + tcp_pend[b].got,
436
tcp_pend[b].wanted - tcp_pend[b].got,MSG_DONTWAIT);
438
if(len == -1) { /* Nothing sent, try later */
442
tcp_pend[b].got += len;
443
tcp_pend[b].die = get_time() + ((int64_t)timeout_seconds << 8);
444
if(tcp_pend[b].got >= tcp_pend[b].wanted) { /* All sent, get ready
446
free(tcp_pend[b].buffer);
447
tcp_pend[b].buffer = 0;
448
tcp_pend[b].state = 4; /* Get packet length from upstream */
452
/* Prepare things to get the length upstream */
453
void tcp_prepare_upstream_len(int b) {
454
if(tcp_pend[b].state != 4) {
457
tcp_pend[b].buffer = dw_malloc(3);
458
if(tcp_pend[b].buffer == 0) {
459
closesocket(tcp_pend[b].local);
460
closesocket(tcp_pend[b].upstream);
463
tcp_pend[b].wanted = 2;
465
tcp_pend[b].state = 5; /* Getting length from upstream */
468
/* Get the two-byte length packet from upstream and allocate memory
469
* to store the up-and-coming packet */
470
void tcp_get_upstream_len(int b) {
475
if(tcp_pend[b].state != 5) {
478
toget = tcp_pend[b].wanted - tcp_pend[b].got;
480
len = recv(tcp_pend[b].upstream,
481
tcp_pend[b].buffer + tcp_pend[b].got,
482
toget, MSG_DONTWAIT);
486
tcp_pend[b].got += len;
487
tcp_pend[b].die = get_time() +
488
((int64_t)timeout_seconds_tcp << 8);
489
} else if(toget == 0) {
490
wanted = tcp_pend[b].buffer[0] & 0xff;
492
wanted |= tcp_pend[b].buffer[1] & 0xff;
493
free(tcp_pend[b].buffer);
494
tcp_pend[b].buffer = 0;
495
tcp_pend[b].wanted = wanted;
496
tcp_pend[b].buffer = dw_malloc(wanted + 3);
497
tcp_pend[b].die = get_time() +
498
((int64_t)timeout_seconds_tcp << 8);
499
tcp_pend[b].state = 6;
500
tcp_pend[b].buffer[0] = (wanted & 0xff00) >> 8;
501
tcp_pend[b].buffer[1] = (wanted & 0xff);
506
/* Forward data from upstream DNS server locally; nearly identical to
507
* Deadwood 2.3's tcp_local2remote */
508
void tcp_downstream_forward(int b) {
512
if(tcp_pend[b].state != 6) {
516
if(tcp_pend[b].got >= tcp_pend[b].wanted ||
517
tcp_pend[b].buffer == 0 ||
518
tcp_pend[b].sent >= tcp_pend[b].wanted + 2) {
519
closesocket(tcp_pend[b].local);
520
closesocket(tcp_pend[b].upstream);
525
/* The "2" you see is the 2-byte length header */
526
len = recv(tcp_pend[b].upstream,
527
tcp_pend[b].buffer + 2 + tcp_pend[b].got,
528
tcp_pend[b].wanted - tcp_pend[b].got,MSG_DONTWAIT);
530
if(len != (tcp_pend[b].wanted - tcp_pend[b].got)) {
532
return; /* Try again later */
534
tcp_pend[b].die = get_time() +
535
((int64_t)timeout_seconds_tcp << 8);
539
tcp_pend[b].got += len;
541
/* Again, the '2' is the 2-byte length header */
542
actual = send(tcp_pend[b].local,tcp_pend[b].buffer + tcp_pend[b].sent,
543
tcp_pend[b].got - tcp_pend[b].sent + 2,
548
} else if(actual != len) { /* Partial sends not supported */
549
tcp_pend[b].die = get_time() +
550
((int64_t)timeout_seconds_tcp << 8);
551
tcp_pend[b].sent += actual;
553
} else if(actual == len && tcp_pend[b].wanted == tcp_pend[b].got) {
554
/* All data sent, success */
555
closesocket(tcp_pend[b].local);
556
closesocket(tcp_pend[b].upstream);
561
tcp_pend[b].sent += actual;
565
/* Handle all TCP connections with data pending to be sent */
566
void tcp_handle_all(int b) {
567
if(key_n[DWM_N_tcp_listen] == 1) {
571
tcp_upstream_send(b);
572
tcp_prepare_upstream_len(b);
573
tcp_get_upstream_len(b);
574
tcp_downstream_forward(b);
578
/* Disconnect idle TCP connections */
579
void kill_tcp_expired() {
581
for(a = 0; a < max_tcp_procs; a++) {
582
if(tcp_pend[a].die > 0 && tcp_pend[a].die < get_time()) {
583
closesocket(tcp_pend[a].local);
589
/* Process any pending connections which select() caught */
590
void tcp_process_results(int a, fd_set *rx_fd) {
593
/* Find the pending connection */
594
while(a > 0 && z < 10000) {
595
/* Handle new connections */
596
for(b = 0; b < DW_MAXIPS; b++) {
597
if(tcp_b_local[b] != INVALID_SOCKET &&
598
FD_ISSET(tcp_b_local[b],rx_fd)) {
599
local_tcp_accept(tcp_b_local[b]);