Browse Source

Passes tests. GHC conc does not work.

Marcos Dumay de Medeiros 6 years ago
parent
commit
c7aea152ec
7 changed files with 75 additions and 71 deletions
  1. 1 0
      .gitignore
  2. 55 49
      src/System/IO/Uniform/Targets.hs
  3. 15 18
      src/System/IO/Uniform/ds.c
  4. 1 0
      src/System/IO/Uniform/ds.h
  5. 2 2
      test/Targets.hs
  6. 0 1
      test/testFile
  7. 1 1
      uniform-io.cabal

+ 1 - 0
.gitignore

@@ -1,6 +1,7 @@
 dist/
 .cabal-sandbox/
 cabal.sandbox.config
+test/testFile
 *~
 *.[ao]
 **/*~

+ 55 - 49
src/System/IO/Uniform/Targets.hs

@@ -21,7 +21,6 @@ import System.IO.Error
 
 import Data.Default.Class
 
-import GHC.Conc (closeFdWith, threadWaitRead, threadWaitWrite)
 import System.Posix.Types (Fd(..))
 
 -- | Settings for starttls functions.
@@ -78,23 +77,25 @@ newtype TlsStream = TlsStream {tls :: (Ptr TlsDs)}
 
 -- | UniformIO IP connections.
 instance UniformIO SocketIO where
-  uRead s n = allocaArray n (
-    \b -> do
-      count <- c_recv (sock s) b (fromIntegral n)
-      if count < 0
-        then throwErrno "could not read"
-        else BS.packCStringLen (b, fromIntegral count)
-    )
-  uPut s t = BS.useAsCStringLen t (
-    \(str, n) -> do
-      count <- c_send (sock s) str $ fromIntegral n
-      if count < 0
-        then throwErrno "could not write"
-        else return ()
-    )
+  uRead s n = do
+    allocaArray n (
+      \b -> do
+        count <- c_recv (sock s) b (fromIntegral n)
+        if count < 0
+          then throwErrno "could not read"
+          else BS.packCStringLen (b, fromIntegral count)
+      )
+  uPut s t = do
+    BS.useAsCStringLen t (
+      \(str, n) -> do
+        count <- c_send (sock s) str $ fromIntegral n
+        if count < 0
+          then throwErrno "could not write"
+          else return ()
+      )
   uClose s = do
     f <- Fd <$> c_prepareToClose (sock s)
-    closeFdWith closeFd f
+    closeFd f
   startTls st s = withCString (tlsCertificateChainFile st) (
     \cert -> withCString (tlsPrivateKeyFile st) (
       \key -> withCString (tlsDHParametersFile st) (
@@ -110,23 +111,25 @@ instance UniformIO SocketIO where
   
 -- | UniformIO type for file IO.
 instance UniformIO FileIO where
-  uRead s n = allocaArray n (
-    \b -> do
-      count <- c_recv (fd s) b $ fromIntegral n
-      if count < 0
-        then throwErrno "could not read"
-        else  BS.packCStringLen (b, fromIntegral count)
-    )
-  uPut s t = BS.useAsCStringLen t (
-    \(str, n) -> do
-      count <- c_send (fd s) str $ fromIntegral n
-      if count < 0
-        then throwErrno "could not write"
-        else return ()
-    )
+  uRead s n = do
+    allocaArray n (
+      \b -> do
+        count <- c_recv (fd s) b $ fromIntegral n
+        if count < 0
+          then throwErrno "could not read"
+          else  BS.packCStringLen (b, fromIntegral count)
+      )
+  uPut s t = do
+    BS.useAsCStringLen t (
+      \(str, n) -> do
+        count <- c_send (fd s) str $ fromIntegral n
+        if count < 0
+          then throwErrno "could not write"
+          else return ()
+      )
   uClose s = do
     f <- Fd <$> c_prepareToClose (fd s)
-    closeFdWith closeFd f
+    closeFd f
   -- Not implemented yet.
   startTls _ _ = return . TlsStream $ nullPtr
   isSecure _ = False
@@ -134,24 +137,26 @@ instance UniformIO FileIO where
 -- | UniformIO wrapper that applies TLS to communication on IO target.
 -- This type is constructed by calling startTls on other targets.
 instance UniformIO TlsStream where
-  uRead s n = allocaArray n (
-    \b -> do
-      count <- c_recvTls (tls s) b $ fromIntegral n
-      if count < 0
-        then throwErrno "could not read"
-        else BS.packCStringLen (b, fromIntegral count)
-    )
-  uPut s t = BS.useAsCStringLen t (
-    \(str, n) -> do
-      count <- c_sendTls (tls s) str $ fromIntegral n
-      if count < 0
-        then throwErrno "could not write"
-        else return ()
-    )
+  uRead s n = do
+    allocaArray n (
+      \b -> do
+        count <- c_recvTls (tls s) b $ fromIntegral n
+        if count < 0
+          then throwErrno "could not read"
+          else BS.packCStringLen (b, fromIntegral count)
+      )
+  uPut s t = do
+    BS.useAsCStringLen t (
+      \(str, n) -> do
+        count <- c_sendTls (tls s) str $ fromIntegral n
+        if count < 0
+          then throwErrno "could not write"
+          else return ()
+      )
   uClose s = do
     d <- c_closeTls (tls s)
     f <- Fd <$> c_prepareToClose d
-    closeFdWith closeFd f
+    closeFd f
   startTls _ s = return s
   isSecure _ = True
 
@@ -271,7 +276,8 @@ foreign import ccall interruptible "createToIPv6Host" c_connect6 :: Ptr CUChar -
 foreign import ccall interruptible "startSockTls" c_startSockTls :: Ptr Ds -> CString -> CString -> CString -> IO (Ptr TlsDs)
 foreign import ccall safe "getPeer" c_getPeer :: Ptr Ds -> Ptr CUInt -> Ptr CUChar -> Ptr CInt -> IO (CInt)
 
-foreign import ccall safe "getFd" c_getFd :: Ptr Ds -> CInt
+--foreign import ccall safe "getFd" c_getFd :: Ptr Ds -> IO CInt
+--foreign import ccall safe "getTlsFd" c_getTlsFd :: Ptr TlsDs -> IO CInt
 foreign import ccall safe "closeFd" c_closeFd :: CInt -> IO ()
 
 foreign import ccall safe "prepareToClose" c_prepareToClose :: Ptr Ds -> IO CInt
@@ -279,9 +285,9 @@ foreign import ccall safe "closeHandler" c_closePort :: Ptr Nethandler -> IO ()
 foreign import ccall safe "closeTls" c_closeTls :: Ptr TlsDs -> IO (Ptr Ds)
 
 foreign import ccall interruptible "sendDs" c_send :: Ptr Ds -> Ptr CChar -> CInt -> IO CInt
-foreign import ccall interruptible "stdDsSend" c_sendStd :: Ptr CChar -> CInt -> IO CInt
+--foreign import ccall interruptible "stdDsSend" c_sendStd :: Ptr CChar -> CInt -> IO CInt
 foreign import ccall interruptible "tlsDsSend" c_sendTls :: Ptr TlsDs -> Ptr CChar -> CInt -> IO CInt
 
 foreign import ccall interruptible "recvDs" c_recv :: Ptr Ds -> Ptr CChar -> CInt -> IO CInt
-foreign import ccall interruptible "stdDsRecv" c_recvStd :: Ptr CChar -> CInt -> IO CInt
+--foreign import ccall interruptible "stdDsRecv" c_recvStd :: Ptr CChar -> CInt -> IO CInt
 foreign import ccall interruptible "tlsDsRecv" c_recvTls :: Ptr TlsDs -> Ptr CChar -> CInt -> IO CInt

+ 15 - 18
src/System/IO/Uniform/ds.c

@@ -11,6 +11,7 @@
 #include <openssl/bio.h>
 #include <openssl/ssl.h>
 #include <openssl/err.h>
+#include <pthread.h>
 
 #include "ds.h"
 
@@ -25,7 +26,12 @@ void *clear(void *ptr){
   return NULL;
 }
 
+pthread_mutex_t loadLock;
+
 void loadOpenSSL(const char *dh){
+  if(openSslLoaded)
+    return;
+  pthread_mutex_lock(&loadLock);
   if(!openSslLoaded){
     SSL_load_error_strings();
     ERR_load_BIO_strings();
@@ -34,6 +40,7 @@ void loadOpenSSL(const char *dh){
     OpenSSL_add_all_algorithms();    
     openSslLoaded = 1;    
   }
+  pthread_mutex_unlock(&loadLock);
 }
 
 void copy6addr(unsigned char d[16], const unsigned char s[16]){
@@ -217,9 +224,7 @@ void closeHandler(nethandler h){
 }
 
 tlsDs startSockTls(ds d, const char *cert, const char *key, const char *dh){
-  fprintf(stderr, "Starting TLS\n");
   loadOpenSSL(dh);
-  fprintf(stderr, "OpenSSL loaded\n");
   SSL_CTX * ctx = NULL;
   if(d->server)
     ctx = SSL_CTX_new(TLSv1_server_method());
@@ -227,13 +232,9 @@ tlsDs startSockTls(ds d, const char *cert, const char *key, const char *dh){
     ctx = SSL_CTX_new(TLSv1_client_method());
   if(!ctx)
     return NULL;
-  fprintf(stderr, "Got CTX\n");
   if(d->server){
     FILE *dhfile = fopen(dh, "r");
-    fprintf(stderr, "dh is %s\n", dh);
-    fprintf(stderr, "dhfile is %x\n", dhfile);
     DH *dhdt = PEM_read_DHparams(dhfile, NULL, NULL, NULL);
-    fprintf(stderr, "dhdt is %x\n", dhdt);
     fclose(dhfile);
     if(SSL_CTX_set_tmp_dh(ctx, dhdt) <= 0){
       int f = prepareToClose(d);
@@ -241,25 +242,20 @@ tlsDs startSockTls(ds d, const char *cert, const char *key, const char *dh){
       clear(dhdt);
       return clear(ctx);    
     }
-    fprintf(stderr, "Set DH parameters\n");
   }
   SSL_CTX_set_options(ctx, SSL_OP_SINGLE_DH_USE);
-  fprintf(stderr, "Set CTX options\n");
-  fprintf(stderr, "Set options\n");
   if(cert)
     if(SSL_CTX_use_certificate_chain_file(ctx, cert) != 1){
       int f = prepareToClose(d);
       closeFd(f);
       return clear(ctx);
     }
-  fprintf(stderr, "Set cert\n");
   if(key)
     if(SSL_CTX_use_PrivateKey_file(ctx, key, SSL_FILETYPE_PEM) != 1){
       int f = prepareToClose(d);
       closeFd(f);
       return clear(ctx);
     }
-  fprintf(stderr, "Set key\n");
   tlsDs t = (tlsDs)malloc(sizeof(s_tlsDs));
   t->original = d;
   if(!(t->s = SSL_new(ctx))){
@@ -268,12 +264,10 @@ tlsDs startSockTls(ds d, const char *cert, const char *key, const char *dh){
     clear(ctx);
     return clear(t);
   }
-  fprintf(stderr, "Got SSL\n");
   if(!SSL_set_fd(t->s, d->fd)){
     closeTls(t);
     return NULL;
   }
-  fprintf(stderr, "Set fd\n");
   int retry = 1;
   int e;
   while(retry){
@@ -287,10 +281,9 @@ tlsDs startSockTls(ds d, const char *cert, const char *key, const char *dh){
     }
     if(e <= 0){
       unsigned long erval = SSL_get_error(t->s, e);
-      char ertxt[300];
-      ERR_error_string(erval, ertxt);
-      fprintf(stderr, "SSL Error: %s\n", ertxt);
-      ERR_print_errors(t->s->bbio);
+      //char ertxt[300];
+      //ERR_error_string(erval, ertxt);
+      //fprintf(stderr, "SSL Error: %s\n", ertxt);
       if((erval == SSL_ERROR_WANT_READ) || (erval == SSL_ERROR_WANT_WRITE)){
 	//Here goes support to non-blocking IO, once it's supported
 	//retry = 1;
@@ -300,7 +293,6 @@ tlsDs startSockTls(ds d, const char *cert, const char *key, const char *dh){
       }
     }
   }
-  fprintf(stderr, "TLS started\n");
   return t;
 }
 
@@ -308,6 +300,11 @@ int getFd(ds d){
   return d->fd;
 }
 
+int getTlsFd(tlsDs t){
+  ds d = t->original;
+  return d->fd;
+}
+
 void closeFd(int fd){
   close(fd);
 }

+ 1 - 0
src/System/IO/Uniform/ds.h

@@ -54,4 +54,5 @@ int stdDsSend(const char[const], const int);
 int stdDsRecv(char[], const int);
 
 int getFd(ds);
+int getTlsFd(tlsDs);
 void closeFd(int);

+ 2 - 2
test/Targets.hs

@@ -12,10 +12,10 @@ import qualified Data.ByteString.Char8 as C8
 tests :: IO [Test]
 tests = return [
   simpleTest "network" testNetwork,
-  simpleTest "file" testFile
+  simpleTest "file" testFile,
   --Test framework fails on this test
   --actual script works as expected
-  --simpleTest "network TLS" testTls
+  simpleTest "network TLS" testTls
   ]
 
 testNetwork :: IO Progress

+ 0 - 1
test/testFile

@@ -1 +0,0 @@
-abcde

+ 1 - 1
uniform-io.cabal

@@ -111,7 +111,7 @@ library
   includes: ds.h
   install-includes: ds.h
   C-Sources: src/System/IO/Uniform/ds.c
-  extra-libraries: ssl
+  extra-libraries: ssl, pthread
 
 Test-suite targets
   type: detailed-0.9