Browse Source

Simplified the C code, and changing Target to use GHC green threads.

Marcos Dumay de Medeiros 8 years ago
parent
commit
d6b7db6bd3
3 changed files with 114 additions and 108 deletions
  1. 39 24
      src/System/IO/Uniform/Targets.hs
  2. 54 58
      src/System/IO/Uniform/ds.c
  3. 21 26
      src/System/IO/Uniform/ds.h

+ 39 - 24
src/System/IO/Uniform/Targets.hs

@@ -14,11 +14,15 @@ import Data.ByteString (ByteString)
 import qualified Data.ByteString as BS
 import qualified Data.List as L
 import Control.Exception
+import Control.Applicative ((<$>))
 import qualified Network.Socket as Soc
 import System.IO.Error
 
 import Data.Default.Class
 
+import GHC.Conc (closeFdWith, threadWaitRead, threadWaitWrite)
+import System.Posix.Types (Fd(..))
+
 -- | Settings for starttls functions.
 data TlsSettings = TlsSettings {tlsPrivateKeyFile :: String, tlsCertificateChainFile :: String} deriving (Read, Show)
 
@@ -65,10 +69,9 @@ instance UniformIO SomeIO where
 data Nethandler
 -- | A bounded IP port from where to accept SocketIO connections.
 newtype BoundedPort = BoundedPort {lis :: (Ptr Nethandler)}
-data SockDs
-newtype SocketIO = SocketIO {sock :: (Ptr SockDs)}
-data FileDs
-newtype FileIO = FileIO {fd :: (Ptr FileDs)}
+data Ds
+newtype SocketIO = SocketIO {sock :: (Ptr Ds)}
+newtype FileIO = FileIO {fd :: (Ptr Ds)}
 data TlsDs
 newtype TlsStream = TlsStream {tls :: (Ptr TlsDs)}
 
@@ -76,19 +79,21 @@ newtype TlsStream = TlsStream {tls :: (Ptr TlsDs)}
 instance UniformIO SocketIO where
   uRead s n = allocaArray n (
     \b -> do
-      count <- c_recvSock (sock s) b (fromIntegral n)
+      count <- c_recv (sock s) b (fromIntegral n)
       if count < 0
         then throwErrno "could not read"
         else BS.packCStringLen (b, fromIntegral count)
     )
   uPut s t = BS.useAsCStringLen t (
     \(str, n) -> do
-      count <- c_sendSock (sock s) str $ fromIntegral n
+      count <- c_send (sock s) str $ fromIntegral n
       if count < 0
         then throwErrno "could not write"
         else return ()
     )
-  uClose s = c_closeSock (sock s)
+  uClose s = do
+    f <- Fd <$> c_prepareToClose (sock s)
+    closeFdWith closeFd f
   startTls st s = withCString (tlsCertificateChainFile st) (
     \cert -> withCString (tlsPrivateKeyFile st) (
       \key -> do
@@ -104,19 +109,21 @@ instance UniformIO SocketIO where
 instance UniformIO FileIO where
   uRead s n = allocaArray n (
     \b -> do
-      count <- c_recvFile (fd s) b $ fromIntegral n
+      count <- c_recv (fd s) b $ fromIntegral n
       if count < 0
         then throwErrno "could not read"
         else  BS.packCStringLen (b, fromIntegral count)
     )
   uPut s t = BS.useAsCStringLen t (
     \(str, n) -> do
-      count <- c_sendFile (fd s) str $ fromIntegral n
+      count <- c_send (fd s) str $ fromIntegral n
       if count < 0
         then throwErrno "could not write"
         else return ()
     )
-  uClose s = c_closeFile (fd s)
+  uClose s = do
+    f <- Fd <$> c_prepareToClose (fd s)
+    closeFdWith closeFd f
   -- Not implemented yet.
   startTls _ _ = return . TlsStream $ nullPtr
   isSecure _ = False
@@ -138,7 +145,10 @@ instance UniformIO TlsStream where
         then throwErrno "could not write"
         else return ()
     )
-  uClose s = c_closeTls (tls s)
+  uClose s = do
+    d <- c_closeTls (tls s)
+    f <- Fd <$> c_prepareToClose d
+    closeFdWith closeFd f
   startTls _ s = return s
   isSecure _ = True
 
@@ -242,28 +252,33 @@ getPeer s = allocaArray 16 (
     )
   )
     
+closeFd :: Fd -> IO ()
+closeFd (Fd f) = c_closeFd f
+            
 -- | Closes a BoundedPort, and releases any resource used by it.
 closePort :: BoundedPort -> IO ()
 closePort p = c_closePort (lis p)
 
 foreign import ccall safe "getPort" c_getPort :: CInt -> IO (Ptr Nethandler)
-foreign import ccall safe "createFromHandler" c_accept :: Ptr Nethandler -> IO (Ptr SockDs)
-foreign import ccall safe "createFromFileName" c_createFile :: CString -> IO (Ptr FileDs)
-foreign import ccall safe "createToIPv4Host" c_connect4 :: CUInt -> CInt -> IO (Ptr SockDs)
-foreign import ccall safe "createToIPv6Host" c_connect6 :: Ptr CUChar -> CInt -> IO (Ptr SockDs)
+foreign import ccall safe "createFromHandler" c_accept :: Ptr Nethandler -> IO (Ptr Ds)
+foreign import ccall safe "createFromFileName" c_createFile :: CString -> IO (Ptr Ds)
+foreign import ccall safe "createToIPv4Host" c_connect4 :: CUInt -> CInt -> IO (Ptr Ds)
+foreign import ccall safe "createToIPv6Host" c_connect6 :: Ptr CUChar -> CInt -> IO (Ptr Ds)
+
+foreign import ccall safe "startSockTls" c_startSockTls :: Ptr Ds -> CString -> CString -> IO (Ptr TlsDs)
+foreign import ccall safe "getPeer" c_getPeer :: Ptr Ds -> Ptr CUInt -> Ptr CUChar -> Ptr CInt -> IO (CInt)
 
-foreign import ccall safe "startSockTls" c_startSockTls :: Ptr SockDs -> CString -> CString -> IO (Ptr TlsDs)
-foreign import ccall safe "getPeer" c_getPeer :: Ptr SockDs -> Ptr CUInt -> Ptr CUChar -> Ptr CInt -> IO (CInt)
+foreign import ccall safe "getFd" c_getFd :: Ptr Ds -> CInt
+foreign import ccall safe "closeFd" c_closeFd :: CInt -> IO ()
 
-foreign import ccall safe "closeSockDs" c_closeSock :: Ptr SockDs -> IO ()
-foreign import ccall safe "closeFileDs" c_closeFile :: Ptr FileDs -> IO ()
+foreign import ccall safe "prepareToClose" c_prepareToClose :: Ptr Ds -> IO CInt
 foreign import ccall safe "closeHandler" c_closePort :: Ptr Nethandler -> IO ()
-foreign import ccall safe "closeTlsDs" c_closeTls :: Ptr TlsDs -> IO ()
+foreign import ccall safe "closeTlsDs" c_closeTls :: Ptr TlsDs -> IO (Ptr Ds)
 
-foreign import ccall interruptible "fileDsSend" c_sendFile :: Ptr FileDs -> Ptr CChar -> CInt -> IO CInt
-foreign import ccall interruptible "sockDsSend" c_sendSock :: Ptr SockDs -> Ptr CChar -> CInt -> IO CInt
+foreign import ccall interruptible "sendDs" c_send :: Ptr Ds -> Ptr CChar -> CInt -> IO CInt
+foreign import ccall interruptible "stdDsSend" c_sendStd :: Ptr CChar -> CInt -> IO CInt
 foreign import ccall interruptible "tlsDsSend" c_sendTls :: Ptr TlsDs -> Ptr CChar -> CInt -> IO CInt
 
-foreign import ccall interruptible "fileDsRecv" c_recvFile :: Ptr FileDs -> Ptr CChar -> CInt -> IO CInt
-foreign import ccall interruptible "sockDsRecv" c_recvSock :: Ptr SockDs -> Ptr CChar -> CInt -> IO CInt
+foreign import ccall interruptible "recvDs" c_recv :: Ptr Ds -> Ptr CChar -> CInt -> IO CInt
+foreign import ccall interruptible "stdDsRecv" c_recvStd :: Ptr CChar -> CInt -> IO CInt
 foreign import ccall interruptible "tlsDsRecv" c_recvTls :: Ptr TlsDs -> Ptr CChar -> CInt -> IO CInt

+ 54 - 58
src/System/IO/Uniform/ds.c

@@ -52,29 +52,29 @@ nethandler getNethandler(const int ipv6, const int port){
   nethandler h = (nethandler)malloc(sizeof(s_nethandler));
   h->ipv6 = ipv6;
   if(ipv6){
-    h->s = socket(AF_INET6, SOCK_STREAM, 0);
+    h->fd = socket(AF_INET6, SOCK_STREAM, 0);
   }else{
-    h->s = socket(AF_INET, SOCK_STREAM, 0);
+    h->fd = socket(AF_INET, SOCK_STREAM, 0);
   }
   int optval = 1;
-  setsockopt(h->s, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval));
+  setsockopt(h->fd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval));
   int e, en;
   if(ipv6){
     struct sockaddr_in6 add;
     add.sin6_family = AF_INET6;
     zero6addr(add.sin6_addr.s6_addr);
     add.sin6_port = htons(port);
-    e = bind(h->s, (struct sockaddr*) &add, sizeof(add));
+    e = bind(h->fd, (struct sockaddr*) &add, sizeof(add));
   }else{
     struct sockaddr_in add;
     add.sin_family = AF_INET;
     add.sin_addr.s_addr = INADDR_ANY;
     add.sin_port = htons(port);
-    e = bind(h->s, (struct sockaddr*) &add, sizeof(add));
+    e = bind(h->fd, (struct sockaddr*) &add, sizeof(add));
   }
   if(e)
     return clear(h);
-  e = listen(h->s, DEFAULT_LISTENNING_QUEUE);
+  e = listen(h->fd, DEFAULT_LISTENNING_QUEUE);
   if(e)
     return clear(h);
   return h;
@@ -88,13 +88,14 @@ nethandler getPort(const int port){
   return getNethandler(1, port);
 }
 
-fileDs createFromFile(int f){
-  fileDs d = (fileDs)malloc(sizeof(s_fileDs));
-  d->f = f;
+ds createFromFile(int f){
+  ds d = (ds)malloc(sizeof(s_ds));
+  d->tp = file;
+  d->fd = f;
   return d;
 }
 
-fileDs createFromFileName(const char *f){
+ds createFromFileName(const char *f){
   int fd = open(f, O_CREAT | O_RDWR);
   if(fd == -1){
     return NULL;
@@ -102,25 +103,27 @@ fileDs createFromFileName(const char *f){
   return createFromFile(fd);
 }
 
-sockDs createFromHandler(nethandler h){
-  sockDs d = (sockDs)malloc(sizeof(s_sockDs));
+ds createFromHandler(nethandler h){
+  ds d = (ds)malloc(sizeof(s_ds));
+  d->tp = sock;
   unsigned int s = sizeof(d->peer);
-  d->s = accept(h->s, (struct sockaddr*)&(d->peer), &s);
-  if(d->s <= 0)
+  d->fd = accept(h->fd, (struct sockaddr*)&(d->peer), &s);
+  if(d->fd <= 0)
     return clear(d);
   d->ipv6 = d->peer.ss_family == AF_INET6;
   d->server = 1;
   return d;
 }
 
-sockDs createToHost(struct sockaddr *add, const int add_size, const int ipv6){
-  sockDs d = (sockDs)malloc(sizeof(s_sockDs));
+ds createToHost(struct sockaddr *add, const int add_size, const int ipv6){
+  ds d = (ds)malloc(sizeof(s_ds));
+  d->tp = sock;
   if(ipv6){
-    d->s = socket(AF_INET6, SOCK_STREAM, 0);
+    d->fd = socket(AF_INET6, SOCK_STREAM, 0);
   }else{
-    d->s = socket(AF_INET, SOCK_STREAM, 0);
+    d->fd = socket(AF_INET, SOCK_STREAM, 0);
   }
-  if(connect(d->s, add, add_size) < 0){
+  if(connect(d->fd, add, add_size) < 0){
     int e = errno;
     free(d);
     errno = e;
@@ -130,7 +133,7 @@ sockDs createToHost(struct sockaddr *add, const int add_size, const int ipv6){
   return d;
 }
 
-sockDs createToIPv4Host(const unsigned long host, const int port){
+ds createToIPv4Host(const unsigned long host, const int port){
   struct sockaddr_in add;
   add.sin_family = AF_INET;
   add.sin_port = htons(port);
@@ -138,7 +141,7 @@ sockDs createToIPv4Host(const unsigned long host, const int port){
   return createToHost((struct sockaddr*) &add, sizeof(add), 0);
 }
 
-sockDs createToIPv6Host(const unsigned char host[16], const int port){
+ds createToIPv6Host(const unsigned char host[16], const int port){
   struct sockaddr_in6 add;
   add.sin6_family = AF_INET6;
   add.sin6_port = htons(port);
@@ -148,11 +151,11 @@ sockDs createToIPv6Host(const unsigned char host[16], const int port){
   return createToHost((struct sockaddr*) &add, sizeof(add), 1);
 }
 
-int getPeer(sockDs d, unsigned long *ipv4peer, unsigned char ipv6peer[16], int *ipv6){
+int getPeer(ds d, unsigned long *ipv4peer, unsigned char ipv6peer[16], int *ipv6){
   int port = 0;
   struct sockaddr_storage peer;
   int peer_size = sizeof(peer);
-  if(getpeername(d->s, (struct sockaddr*)&peer, &peer_size)){
+  if(getpeername(d->fd, (struct sockaddr*)&peer, &peer_size)){
     return 0;
   }
   if(peer.ss_family == AF_INET){
@@ -171,11 +174,8 @@ int getPeer(sockDs d, unsigned long *ipv4peer, unsigned char ipv6peer[16], int *
   return port;
 }
 
-int fileDsSend(fileDs d, const char *b, const int s){
-  return write(d->f, b, s);
-}
-int sockDsSend(sockDs d, const char *b, const int s){
-  return write(d->s, b, s);
+int sendDs(ds d, const char *b, const int s){
+  return write(d->fd, b, s);
 }
 int tlsDsSend(tlsDs d, const char *b, const int s){
   return SSL_write(d->s, b, s);
@@ -184,11 +184,8 @@ int stdDsSend(const char *b, const int s){
   return write(1, b, s);
 }
 
-int fileDsRecv(fileDs d, char *b, const int s){
-  return read(d->f, b, s);
-}
-int sockDsRecv(sockDs d, char *b, const int s){
-  return read(d->s, b, s);
+int recvDs(ds d, char *b, const int s){
+  return read(d->fd, b, s);
 }
 int tlsDsRecv(tlsDs d, char *b, const int s){
   return SSL_read(d->s, b, s);
@@ -198,36 +195,27 @@ int stdDsRecv(char *b, const int s){
 }
 
 
-void closeFileDs(fileDs d){
-  close(d->f);
-  free(d);
-}
-void closeSockDs(sockDs d){
-  close(d->s);
+int prepareToClose(ds d){
+  int fd = d->fd;
   free(d);
+  return fd;
 }
 
-void closeTlsDs(tlsDs d){
+ds closeTlsDs(tlsDs d){
+  ds original = d->original;
   SSL_shutdown(d->s);
   SSL_shutdown(d->s);
   SSL_free(d->s);
-  switch(d->tp){
-  case file:
-    closeFileDs(d->original);
-    break;
-  case sock:
-    closeSockDs(d->original);
-    break;
-  }
   free(d);
+  return original;
 }
 
 void closeHandler(nethandler h){
-  close(h->s);
+  close(h->fd);
   free(h);
 }
 
-tlsDs startSockTls(sockDs d, const char *cert, const char *key){
+tlsDs startSockTls(ds d, const char *cert, const char *key){
   loadOpenSSL();
   SSL_CTX * ctx = NULL;
   if(d->server)
@@ -239,26 +227,28 @@ tlsDs startSockTls(sockDs d, const char *cert, const char *key){
   SSL_CTX_set_options(ctx, SSL_OP_SINGLE_DH_USE);
   if(cert)
     if(SSL_CTX_use_certificate_chain_file(ctx, cert) != 1){
-      closeSockDs(d);
+      int f = prepareToClose(d);
+      closeFd(f);
       return clear(ctx);
     }
   if(key)
     if(SSL_CTX_use_PrivateKey_file(ctx, key, SSL_FILETYPE_PEM) != 1){
-      closeSockDs(d);
+      int f = prepareToClose(d);
+      closeFd(f);
       return clear(ctx);
     }
   tlsDs t = (tlsDs)malloc(sizeof(s_tlsDs));
   t->original = d;
   if(!(t->s = SSL_new(ctx))){
-    closeSockDs(d);
+    int f = prepareToClose(d);
+    closeFd(f);
     clear(ctx);
     return clear(t);
   }
-  if(!SSL_set_fd(t->s, d->s)){
+  if(!SSL_set_fd(t->s, d->fd)){
     closeTlsDs(t);
     return NULL;
   }
-  printf("Starting handshake\n");
   int retry = 1;
   int e;
   while(retry){
@@ -273,13 +263,19 @@ tlsDs startSockTls(sockDs d, const char *cert, const char *key){
       if((erval == SSL_ERROR_WANT_READ) || (erval == SSL_ERROR_WANT_WRITE)){
 	
       }else{
-	printf("Error\n");
-	ERR_print_errors(t->s->bbio);
+	//ERR_print_errors(t->s->bbio);
 	closeTlsDs(t);
 	return NULL;
       }
     }
   }
-  printf("Success\n");
   return t;
 }
+
+int getFd(ds d){
+  return d->fd;
+}
+
+void closeFd(int fd){
+  close(fd);
+}

+ 21 - 26
src/System/IO/Uniform/ds.h

@@ -2,28 +2,25 @@
 #include <netinet/in.h>
 #include <openssl/ssl.h>
 
+typedef enum {
+  file, std, sock
+} dstype;
+
 typedef struct {
-  int s;
+  int fd;
+  dstype tp;
   int ipv6;
   int server;
   struct sockaddr_storage peer;
-} *sockDs, s_sockDs;
-
-typedef struct {
-  int f;
-} *fileDs, s_fileDs;
+} *ds, s_ds;
 
 #define DEFAULT_LISTENNING_QUEUE 5
 
 typedef struct{
-  int s;
+  int fd;
   int ipv6;
 } *nethandler, s_nethandler;
 
-typedef enum {
-  file, sock
-} dstype;
-
 typedef struct {
   dstype tp;
   void *original;
@@ -33,26 +30,22 @@ typedef struct {
 nethandler getIPv4Port(const int port);
 nethandler getPort(const int port);
 
-fileDs createFromFile(int);
-fileDs createFromFileName(const char*);
-sockDs createFromHandler(nethandler);
-sockDs createToIPv4Host(const unsigned long, const int);
-sockDs createToIPv6Host(const unsigned char[16], const int);
+ds createFromFile(int);
+ds createFromFileName(const char*);
+ds createFromHandler(nethandler);
+ds createToIPv4Host(const unsigned long, const int);
+ds createToIPv6Host(const unsigned char[16], const int);
 
-tlsDs startSockTls(sockDs, const char*, const char*);
+tlsDs startSockTls(ds, const char*, const char*);
 
-int getPeer(sockDs, unsigned long*, unsigned char[16], int*);
+int getPeer(ds, unsigned long*, unsigned char[16], int*);
 
-void closeSockDs(sockDs);
-void closeFileDs(fileDs);
+int closeDs(ds);
 void closeHandler(nethandler);
-void closeTlsDs(tlsDs);
-
-int fileDsSend(fileDs, const char[const], const int);
-int fileDsRecv(fileDs, char[], const int);
+ds closeTls(tlsDs);
 
-int sockDsSend(sockDs, const char[const], const int);
-int sockDsRecv(sockDs, char[], const int);
+int sendDs(ds, const char[const], const int);
+int recvDs(ds, char[], const int);
 
 int tlsDsSend(tlsDs, const char[const], const int);
 int tlsDsRecv(tlsDs, char[], const int);
@@ -60,3 +53,5 @@ int tlsDsRecv(tlsDs, char[], const int);
 int stdDsSend(const char[const], const int);
 int stdDsRecv(char[], const int);
 
+int getFd(ds);
+void closeFd(int);