Browse Source

Changed startTls into a more extensible interface

Marcos Dumay de Medeiros 8 years ago
parent
commit
0375faee38

+ 3 - 46
src/System/IO/Uniform.hs

@@ -11,33 +11,17 @@
 module System.IO.Uniform (
   UniformIO(..),
   TlsSettings(..),
-  SomeIO(..), TlsIO,
+  SomeIO(..),
   mapOverInput
   ) where
 
-import System.IO.Uniform.External
-
-import Foreign
---import Foreign.C.Types
---import Foreign.C.String
-import Foreign.C.Error
---import qualified Data.IP as IP
 import Data.ByteString (ByteString)
-import qualified Data.ByteString as BS
---import qualified Data.ByteString.Lazy as LBS
---import qualified Data.ByteString.Builder as BSBuild
---import qualified Data.List as L
 import Control.Exception
 import Control.Applicative ((<$>))
---import Data.Monoid (mappend)
---import qualified Network.Socket as Soc
 import System.IO.Error
---import Control.Concurrent.MVar
 
 import Data.Default.Class
 
-import System.Posix.Types (Fd(..))
-
 -- |
 -- Typeclass for uniform IO targets.
 class UniformIO a where
@@ -61,7 +45,7 @@ class UniformIO a where
   -- | startTLS fd
   --
   --  Starts a TLS connection over the IO target.
-  startTls :: TlsSettings -> a -> IO TlsIO
+  startTls :: TlsSettings -> a -> IO a
   -- | isSecure fd
   --
   --  Indicates whether the data written or read from fd is secure at transport.
@@ -74,7 +58,7 @@ instance UniformIO SomeIO where
   uRead (SomeIO s) n = uRead s n
   uPut (SomeIO s) t  = uPut s t
   uClose (SomeIO s) = uClose s
-  startTls set (SomeIO s) = startTls set s
+  startTls set (SomeIO s) = SomeIO <$> startTls set s
   isSecure (SomeIO s) = isSecure s
 
 -- | Settings for starttls functions.
@@ -83,33 +67,6 @@ data TlsSettings = TlsSettings {tlsPrivateKeyFile :: String, tlsCertificateChain
 instance Default TlsSettings where
   def = TlsSettings "" "" ""
   
--- | UniformIO wrapper that applies TLS to communication on IO target.
--- This type is constructed by calling startTls on other targets.
-instance UniformIO TlsIO where
-  uRead s n = do
-    allocaArray n (
-      \b -> do
-        count <- c_recvTls (tls s) b $ fromIntegral n
-        if count < 0
-          then throwErrno "could not read"
-          else BS.packCStringLen (b, fromIntegral count)
-      )
-  uPut s t = do
-    BS.useAsCStringLen t (
-      \(str, n) -> do
-        count <- c_sendTls (tls s) str $ fromIntegral n
-        if count < 0
-          then throwErrno "could not write"
-          else return ()
-      )
-  uClose s = do
-    d <- c_closeTls (tls s)
-    f <- Fd <$> c_prepareToClose d
-    closeFd f
-  startTls _ s = return s
-  isSecure _ = True
-
-
 -- | mapOverInput io block_size f initial
 --   Reads io untill the end of file, evaluating a(i) <- f a(i-1) read_data
 --   where a(0) = initial and the last value after io reaches EOF is returned.

+ 1 - 7
src/System/IO/Uniform/ByteString.hs

@@ -10,13 +10,7 @@ module System.IO.Uniform.ByteString (
   ) where
 
 import System.IO.Uniform
-import System.IO.Uniform.External
 
-import Foreign
---import Foreign.C.Types
---import Foreign.C.String
---import Foreign.C.Error
--- import qualified Data.IP as IP
 import Data.ByteString (ByteString)
 import qualified Data.ByteString as BS
 import qualified Data.ByteString.Lazy as LBS
@@ -52,7 +46,7 @@ instance UniformIO ByteStringIO where
     let o' = mappend o $ BSBuild.byteString t
     putMVar (bsiooutput s) o'
   uClose _ = return ()
-  startTls _ _ = return . TlsIO $ nullPtr
+  startTls _ a = return a
   isSecure _ = True
 
 -- | withByteStringIO input f

+ 2 - 3
src/System/IO/Uniform/External.hs

@@ -14,10 +14,9 @@ data Nethandler
 -- | A bounded IP port from where to accept SocketIO connections.
 newtype BoundedPort = BoundedPort {lis :: (Ptr Nethandler)}
 data Ds
-newtype SocketIO = SocketIO {sock :: (Ptr Ds)}
-newtype FileIO = FileIO {fd :: (Ptr Ds)}
 data TlsDs
-newtype TlsIO = TlsIO {tls :: (Ptr TlsDs)}
+data SocketIO = SocketIO {sock :: (Ptr Ds)} | TlsSocketIO {bio :: (Ptr TlsDs)}
+newtype FileIO = FileIO {fd :: (Ptr Ds)}
 data StdIO
 
 closeFd :: Fd -> IO ()

+ 2 - 3
src/System/IO/Uniform/File.hs

@@ -55,9 +55,8 @@ instance UniformIO FileIO where
   uClose s = do
     f <- Fd <$> c_prepareToClose (fd s)
     closeFd f
-  -- Not implemented yet.
-  startTls _ _ = return . TlsIO $ nullPtr
-  isSecure _ = False
+  startTls _ f = return f
+  isSecure _ = True
   
   
 -- | Open a file for bidirectional IO.

+ 32 - 10
src/System/IO/Uniform/Network.hs

@@ -41,37 +41,59 @@ import System.Posix.Types (Fd(..))
 
 -- | UniformIO IP connections.
 instance UniformIO SocketIO where
-  uRead s n = do
+  uRead (SocketIO s) n = do
     allocaArray n (
       \b -> do
-        count <- c_recv (sock s) b (fromIntegral n)
+        count <- c_recv s b (fromIntegral n)
         if count < 0
           then throwErrno "could not read"
           else BS.packCStringLen (b, fromIntegral count)
       )
-  uPut s t = do
+  uRead (TlsSocketIO s) n = do
+    allocaArray n (
+      \b -> do
+        count <- c_recvTls s b $ fromIntegral n
+        if count < 0
+          then throwErrno "could not read"
+          else BS.packCStringLen (b, fromIntegral count)
+      )
+  uPut (SocketIO s) t = do
     BS.useAsCStringLen t (
       \(str, n) -> do
-        count <- c_send (sock s) str $ fromIntegral n
+        count <- c_send s str $ fromIntegral n
         if count < 0
           then throwErrno "could not write"
           else return ()
       )
-  uClose s = do
-    f <- Fd <$> c_prepareToClose (sock s)
+  uPut (TlsSocketIO s) t = do
+    BS.useAsCStringLen t (
+      \(str, n) -> do
+        count <- c_sendTls s str $ fromIntegral n
+        if count < 0
+          then throwErrno "could not write"
+          else return ()
+      )
+  uClose (SocketIO s) = do
+    f <- Fd <$> c_prepareToClose s
+    closeFd f
+  uClose (TlsSocketIO s) = do
+    d <- c_closeTls s
+    f <- Fd <$> c_prepareToClose d
     closeFd f
-  startTls st s = withCString (tlsCertificateChainFile st) (
+  startTls st (SocketIO s) = withCString (tlsCertificateChainFile st) (
     \cert -> withCString (tlsPrivateKeyFile st) (
       \key -> withCString (tlsDHParametersFile st) (
         \para -> do
-          r <- c_startSockTls (sock s) cert key para
+          r <- c_startSockTls s cert key para
           if r == nullPtr
             then throwErrno "could not start TLS"
-            else return . TlsIO $ r
+            else return . TlsSocketIO $ r
         )
       )
     )
-  isSecure _ = False
+  startTls _ s@(TlsSocketIO _) = return s
+  isSecure (SocketIO _) = False
+  isSecure (TlsSocketIO _) = True
 
 
 -- | connectToHost hostName port

+ 2 - 2
src/System/IO/Uniform/Std.hs

@@ -50,5 +50,5 @@ instance UniformIO StdIO where
           else return ()
       )
   uClose _ = return ()
-  startTls _ _ = return . TlsIO $ nullPtr
-  isSecure _ = False
+  startTls _ a = return a
+  isSecure _ = True