ds.c 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  #include <stdio.h>
  #include <stdlib.h>
  #include <malloc.h>
  #include <string.h>
  #include <errno.h>
  #include <unistd.h>
  #include <fcntl.h>
  //#include <sys/select.h>
  #include <sys/types.h>
  #include <sys/socket.h>
  #include <sys/stat.h>
  #include <netinet/in.h>
  #include <pthread.h>
  #include <openssl/bio.h>
  #include <openssl/dh.h>
  #include <openssl/err.h>
  #include <openssl/pem.h>
  #include <openssl/ssl.h>
  #include "ds.h"
  /* Non-zero once loadOpenSSL() has run the one-time OpenSSL library initialization. */
  int openSslLoaded = 0;

  /*
   * Releases ptr (NULL is accepted) and always returns NULL, so error paths can be
   * written as `return clear(p);`. errno is saved and restored around the free() so
   * the caller's failure reason survives the cleanup.
   */
  void *clear(void *ptr){
      const int savedErrno = errno;
      free(ptr); /* free(NULL) is a no-op, so no guard is needed */
      errno = savedErrno;
      return NULL;
  }
  25. pthread_mutex_t loadLock;
  26. void loadOpenSSL(const char *dh){
  27. if(openSslLoaded)
  28. return;
  29. pthread_mutex_lock(&loadLock);
  30. if(!openSslLoaded){
  31. SSL_load_error_strings();
  32. ERR_load_BIO_strings();
  33. ERR_load_crypto_strings();
  34. SSL_library_init();
  35. OpenSSL_add_all_algorithms();
  36. openSslLoaded = 1;
  37. }
  38. pthread_mutex_unlock(&loadLock);
  39. }
  /* Copies the 16 bytes of an IPv6 address from s into d, front to back. */
  void copy6addr(unsigned char d[16], const unsigned char s[16]){
      const unsigned char *const end = s + 16;
      while(s < end)
          *d++ = *s++;
  }
  /* Sets all 16 bytes of d to zero, i.e. the IPv6 unspecified address "::". */
  void zero6addr(unsigned char d[16]){
      unsigned char *const end = d + 16;
      while(d < end)
          *d++ = 0;
  }
  50. nethandler getNethandler(const int ipv6, const int port){
  51. nethandler h = (nethandler)malloc(sizeof(s_nethandler));
  52. h->ipv6 = ipv6;
  53. if(ipv6){
  54. h->fd = socket(AF_INET6, SOCK_STREAM, 0);
  55. }else{
  56. h->fd = socket(AF_INET, SOCK_STREAM, 0);
  57. }
  58. int optval = 1;
  59. setsockopt(h->fd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval));
  60. int e, en;
  61. if(ipv6){
  62. struct sockaddr_in6 add;
  63. add.sin6_family = AF_INET6;
  64. zero6addr(add.sin6_addr.s6_addr);
  65. add.sin6_port = htons(port);
  66. e = bind(h->fd, (struct sockaddr*) &add, sizeof(add));
  67. }else{
  68. struct sockaddr_in add;
  69. add.sin_family = AF_INET;
  70. add.sin_addr.s_addr = INADDR_ANY;
  71. add.sin_port = htons(port);
  72. e = bind(h->fd, (struct sockaddr*) &add, sizeof(add));
  73. }
  74. if(e)
  75. return clear(h);
  76. e = listen(h->fd, DEFAULT_LISTENNING_QUEUE);
  77. if(e)
  78. return clear(h);
  79. return h;
  80. }
  81. nethandler getIPv4Port(const int port){
  82. return getNethandler(0, port);
  83. }
  84. nethandler getPort(const int port){
  85. return getNethandler(1, port);
  86. }
  87. ds createFromFile(int f){
  88. ds d = (ds)malloc(sizeof(s_ds));
  89. d->tp = file;
  90. d->fd = f;
  91. return d;
  92. }
  93. ds createFromFileName(const char *f){
  94. int fd = open(f, O_CREAT | O_RDWR, 0666);
  95. if(fd == -1){
  96. return NULL;
  97. }
  98. return createFromFile(fd);
  99. }
  100. ds createFromHandler(nethandler h){
  101. ds d = (ds)malloc(sizeof(s_ds));
  102. d->tp = sock;
  103. unsigned int s = sizeof(d->peer);
  104. d->fd = accept(h->fd, (struct sockaddr*)&(d->peer), &s);
  105. if(d->fd <= 0)
  106. return clear(d);
  107. d->ipv6 = d->peer.ss_family == AF_INET6;
  108. d->server = 1;
  109. return d;
  110. }
  111. ds createToHost(struct sockaddr *add, const int add_size, const int ipv6){
  112. ds d = (ds)malloc(sizeof(s_ds));
  113. d->tp = sock;
  114. if(ipv6){
  115. d->fd = socket(AF_INET6, SOCK_STREAM, 0);
  116. }else{
  117. d->fd = socket(AF_INET, SOCK_STREAM, 0);
  118. }
  119. if(connect(d->fd, add, add_size) < 0){
  120. int e = errno;
  121. free(d);
  122. errno = e;
  123. return NULL;
  124. }
  125. d->server = 0;
  126. return d;
  127. }
  128. ds createToIPv4Host(const unsigned long host, const int port){
  129. struct sockaddr_in add;
  130. add.sin_family = AF_INET;
  131. add.sin_port = htons(port);
  132. add.sin_addr.s_addr = host;
  133. return createToHost((struct sockaddr*) &add, sizeof(add), 0);
  134. }
  135. ds createToIPv6Host(const unsigned char host[16], const int port){
  136. struct sockaddr_in6 add;
  137. add.sin6_family = AF_INET6;
  138. add.sin6_port = htons(port);
  139. add.sin6_flowinfo = 0;
  140. copy6addr(add.sin6_addr.s6_addr, host);
  141. add.sin6_scope_id = 0;
  142. return createToHost((struct sockaddr*) &add, sizeof(add), 1);
  143. }
  144. int getPeer(ds d, unsigned long *ipv4peer, unsigned char ipv6peer[16], int *ipv6){
  145. int port = 0;
  146. struct sockaddr_storage peer;
  147. int peer_size = sizeof(peer);
  148. if(getpeername(d->fd, (struct sockaddr*)&peer, &peer_size)){
  149. return 0;
  150. }
  151. if(peer.ss_family == AF_INET){
  152. struct sockaddr_in *a = (struct sockaddr_in*)&(peer);
  153. zero6addr(ipv6peer);
  154. *ipv6 = -1;
  155. *ipv4peer = a->sin_addr.s_addr;
  156. port = a->sin_port;
  157. }else{
  158. struct sockaddr_in6 *a = (struct sockaddr_in6*)&(peer);
  159. *ipv4peer = 0;
  160. *ipv6 = 1;
  161. copy6addr(ipv6peer, a->sin6_addr.s6_addr);
  162. port = a->sin6_port;
  163. }
  164. return port;
  165. }
  166. int sendDs(ds d, const char *b, const int s){
  167. return write(d->fd, b, s);
  168. }
  169. int tlsDsSend(tlsDs d, const char *b, const int s){
  170. return SSL_write(d->s, b, s);
  171. }
  /* Writes up to s bytes of b to standard output (fd 1); returns the write() result as int. */
  int stdDsSend(const char *b, const int s){
      const int stdoutFd = 1;
      ssize_t written = write(stdoutFd, b, s);
      return (int)written;
  }
  175. int recvDs(ds d, char *b, const int s){
  176. return read(d->fd, b, s);
  177. }
  178. int tlsDsRecv(tlsDs d, char *b, const int s){
  179. return SSL_read(d->s, b, s);
  180. }
  /* Reads up to s bytes from standard input (fd 0) into b; returns the read() result as int. */
  int stdDsRecv(char *b, const int s){
      const int stdinFd = 0;
      ssize_t got = read(stdinFd, b, s);
      return (int)got;
  }
  184. int prepareToClose(ds d){
  185. int fd = d->fd;
  186. free(d);
  187. return fd;
  188. }
  189. ds closeTls(tlsDs d){
  190. ds original = d->original;
  191. SSL_shutdown(d->s);
  192. //No bidirectional shutdown supported
  193. //SSL_shutdown(d->s);
  194. SSL_free(d->s);
  195. free(d);
  196. return original;
  197. }
  198. void closeHandler(nethandler h){
  199. close(h->fd);
  200. free(h);
  201. }
  202. tlsDs startSockTls(ds d, const char *cert, const char *key, const char *dh){
  203. loadOpenSSL(dh);
  204. SSL_CTX * ctx = NULL;
  205. if(d->server)
  206. ctx = SSL_CTX_new(TLSv1_server_method());
  207. else
  208. ctx = SSL_CTX_new(TLSv1_client_method());
  209. if(!ctx)
  210. return NULL;
  211. if(d->server){
  212. FILE *dhfile = fopen(dh, "r");
  213. DH *dhdt = PEM_read_DHparams(dhfile, NULL, NULL, NULL);
  214. fclose(dhfile);
  215. if(SSL_CTX_set_tmp_dh(ctx, dhdt) <= 0){
  216. int f = prepareToClose(d);
  217. closeFd(f);
  218. clear(dhdt);
  219. return clear(ctx);
  220. }
  221. }
  222. SSL_CTX_set_options(ctx, SSL_OP_SINGLE_DH_USE);
  223. if(cert)
  224. if(SSL_CTX_use_certificate_chain_file(ctx, cert) != 1){
  225. int f = prepareToClose(d);
  226. closeFd(f);
  227. return clear(ctx);
  228. }
  229. if(key)
  230. if(SSL_CTX_use_PrivateKey_file(ctx, key, SSL_FILETYPE_PEM) != 1){
  231. int f = prepareToClose(d);
  232. closeFd(f);
  233. return clear(ctx);
  234. }
  235. tlsDs t = (tlsDs)malloc(sizeof(s_tlsDs));
  236. t->original = d;
  237. if(!(t->s = SSL_new(ctx))){
  238. int f = prepareToClose(d);
  239. closeFd(f);
  240. clear(ctx);
  241. return clear(t);
  242. }
  243. if(!SSL_set_fd(t->s, d->fd)){
  244. closeTls(t);
  245. return NULL;
  246. }
  247. int retry = 1;
  248. int e;
  249. while(retry){
  250. retry = 0;
  251. if(d->server){
  252. SSL_set_accept_state(t->s);
  253. e = SSL_accept(t->s);
  254. }else{
  255. SSL_set_connect_state(t->s);
  256. e = SSL_connect(t->s);
  257. }
  258. if(e <= 0){
  259. unsigned long erval = SSL_get_error(t->s, e);
  260. //char ertxt[300];
  261. //ERR_error_string(erval, ertxt);
  262. //fprintf(stderr, "SSL Error: %s\n", ertxt);
  263. if((erval == SSL_ERROR_WANT_READ) || (erval == SSL_ERROR_WANT_WRITE)){
  264. //Here goes support to non-blocking IO, once it's supported
  265. //retry = 1;
  266. }else{
  267. closeTls(t);
  268. return NULL;
  269. }
  270. }
  271. }
  272. return t;
  273. }
  274. int getFd(ds d){
  275. return d->fd;
  276. }
  277. int getTlsFd(tlsDs t){
  278. ds d = t->original;
  279. return d->fd;
  280. }
  /* Closes fd; close()'s return value (and any errno it sets) is discarded. */
  void closeFd(int fd){
      (void)close(fd);
  }