{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE InterruptibleFFI #-}
{-# LANGUAGE OverloadedStrings #-}

module System.IO.Uniform.Targets (
  TlsSettings(..),
  UniformIO(..),
  SocketIO,
  FileIO,
  StdIO,
  TlsIO,
  SomeIO(..),
  BoundedPort,
  connectTo,
  connectToHost,
  bindPort,
  accept,
  closePort,
  openFile,
  getPeer,
  mapOverInput
  ) where

import Foreign
import Foreign.C.Types
import Foreign.C.String
import Foreign.C.Error
import qualified Data.IP as IP
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.List as L
import Control.Exception
import Control.Applicative ((<$>))
import qualified Network.Socket as Soc
import System.IO.Error
import Data.Default.Class
import System.Posix.Types (Fd(..))

-- | Settings for the 'startTls' functions.
--
-- Each field is a file name, marshalled to a C string and handed to the
-- C TLS setup routine by the 'SocketIO' instance of 'startTls'.
data TlsSettings = TlsSettings {
  tlsPrivateKeyFile :: String,
  tlsCertificateChainFile :: String,
  tlsDHParametersFile :: String
  } deriving (Read, Show)

-- | All paths empty.
instance Default TlsSettings where
  def = TlsSettings "" "" ""

-- |
-- Typeclass for uniform IO targets.
class UniformIO a where
  -- | uRead fd n
  --
  -- Reads a block of at most n bytes of data from the IO target.
  -- Reading will block if there's no data available, but will return immediately
  -- if any amount of data is available.
  --
  -- Must throw an EOF error (see 'System.IO.Error.isEOFError') if reading beyond EOF.
  uRead :: a -> Int -> IO ByteString
  -- | uPut fd text
  --
  -- Writes all the bytes of text into the IO target. Takes care of retrying if needed.
  uPut :: a -> ByteString -> IO ()
  -- | uClose fd
  --
  -- Closes the IO target, releasing any allocated resource. Resources may leak if not called
  -- for every opened fd.
  uClose :: a -> IO ()
  -- | startTls settings fd
  --
  -- Starts a TLS connection over the IO target.
  startTls :: TlsSettings -> a -> IO TlsIO
  -- | isSecure fd
  --
  -- Indicates whether the data written or read from fd is secure at transport.
  isSecure :: a -> Bool

-- | A type that wraps any type in the UniformIO class.
data SomeIO = forall a. (UniformIO a) => SomeIO a

-- Every method forwards unchanged to the wrapped target.
instance UniformIO SomeIO where
  uRead (SomeIO s) n = uRead s n
  uPut (SomeIO s) t = uPut s t
  uClose (SomeIO s) = uClose s
  startTls set (SomeIO s) = startTls set s
  isSecure (SomeIO s) = isSecure s

-- Empty types used only as tags for pointers to C-side structures.
data Nethandler

-- | A bound IP port from where to accept SocketIO connections.
newtype BoundedPort = BoundedPort {lis :: Ptr Nethandler}

data Ds

newtype SocketIO = SocketIO {sock :: Ptr Ds}

newtype FileIO = FileIO {fd :: Ptr Ds}

data TlsDs

newtype TlsIO = TlsIO {tls :: Ptr TlsDs}

-- NOTE(review): StdIO has no constructors, so a value can only be obtained
-- via something like 'undefined'; its instance never inspects the argument.
data StdIO

-- | UniformIO IP connections.
instance UniformIO SocketIO where
  -- NOTE(review): a zero-byte count from c_recv is returned as an empty
  -- ByteString, not an EOF error; confirm recvDs reports EOF itself.
  uRead s n = allocaArray n $ \b -> do
    count <- c_recv (sock s) b (fromIntegral n)
    if count < 0
      then throwErrno "could not read"
      else BS.packCStringLen (b, fromIntegral count)
  -- NOTE(review): one c_send call is made and only a negative result is an
  -- error; a short write is not retried here. Confirm sendDs writes it all.
  uPut s t = BS.useAsCStringLen t $ \(str, n) -> do
    count <- c_send (sock s) str (fromIntegral n)
    if count < 0
      then throwErrno "could not write"
      else return ()
  -- The C side hands back the raw descriptor, which is then closed here.
  uClose s = do
    f <- Fd <$> c_prepareToClose (sock s)
    closeFd f
  startTls st s =
    withCString (tlsCertificateChainFile st) $ \cert ->
    withCString (tlsPrivateKeyFile st) $ \key ->
    withCString (tlsDHParametersFile st) $ \para -> do
      r <- c_startSockTls (sock s) cert key para
      if r == nullPtr
        then throwErrno "could not start TLS"
        else return (TlsIO r)
  isSecure _ = False

-- | UniformIO that reads from stdin and writes to stdout.
instance UniformIO StdIO where
  uRead _ n = allocaArray n $ \b -> do
    count <- c_recvStd b (fromIntegral n)
    if count < 0
      then throwErrno "could not read"
      else BS.packCStringLen (b, fromIntegral count)
  uPut _ t = BS.useAsCStringLen t $ \(str, n) -> do
    count <- c_sendStd str (fromIntegral n)
    if count < 0
      then throwErrno "could not write"
      else return ()
  uClose _ = return ()
  -- TLS over stdin/stdout is not supported. Throw an illegal-operation
  -- IOError instead of returning a TlsIO around a null pointer, which the
  -- TlsIO methods would otherwise pass straight to the C TLS routines.
  startTls _ _ = throwIO $ mkIOError illegalOperationErrorType
    "startTls is not supported on StdIO" Nothing Nothing
  isSecure _ = False

-- | UniformIO type for file IO.
instance UniformIO FileIO where
  -- Reads and writes go through the same C routines as sockets (recvDs/sendDs).
  uRead s n = allocaArray n $ \b -> do
    count <- c_recv (fd s) b (fromIntegral n)
    if count < 0
      then throwErrno "could not read"
      else BS.packCStringLen (b, fromIntegral count)
  uPut s t = BS.useAsCStringLen t $ \(str, n) -> do
    count <- c_send (fd s) str (fromIntegral n)
    if count < 0
      then throwErrno "could not write"
      else return ()
  uClose s = do
    f <- Fd <$> c_prepareToClose (fd s)
    closeFd f
  -- TLS over files is not implemented yet. Throw an illegal-operation
  -- IOError instead of returning a TlsIO around a null pointer, which the
  -- TlsIO methods would otherwise pass straight to the C TLS routines.
  startTls _ _ = throwIO $ mkIOError illegalOperationErrorType
    "startTls is not supported on FileIO" Nothing Nothing
  isSecure _ = False

-- | UniformIO wrapper that applies TLS to communication on IO target.
-- This type is constructed by calling startTls on other targets.
instance UniformIO TlsIO where
  uRead s n = allocaArray n $ \b -> do
    count <- c_recvTls (tls s) b (fromIntegral n)
    if count < 0
      then throwErrno "could not read"
      else BS.packCStringLen (b, fromIntegral count)
  uPut s t = BS.useAsCStringLen t $ \(str, n) -> do
    count <- c_sendTls (tls s) str (fromIntegral n)
    if count < 0
      then throwErrno "could not write"
      else return ()
  -- Tear down the TLS layer first; it yields the underlying descriptor
  -- structure, whose raw fd is then closed.
  uClose s = do
    d <- c_closeTls (tls s)
    f <- Fd <$> c_prepareToClose d
    closeFd f
  -- Already secured: starting TLS again returns the same target.
  startTls _ s = return s
  isSecure _ = True

-- | connectToHost hostName port
--
-- Connects to the given host and port.
-- Only the first address returned by name resolution is tried.
connectToHost :: String -> Int -> IO SocketIO
connectToHost host port = do
  ip <- getAddr
  connectTo ip port
  where
    getAddr :: IO IP.IP
    getAddr = do
      add <- Soc.getAddrInfo Nothing (Just host) Nothing
      case add of
        [] -> throwIO $ mkIOError doesNotExistErrorType "host not found" Nothing Nothing
        (a:_) -> case Soc.addrAddress a of
          Soc.SockAddrInet _ a' -> return . IP.IPv4 . IP.fromHostAddress $ a'
          Soc.SockAddrInet6 _ _ a' _ -> return . IP.IPv6 . IP.fromHostAddress6 $ a'
          _ -> throwIO $ mkIOError doesNotExistErrorType "host not found" Nothing Nothing

-- | connectTo ipAddress port
--
-- Connects to the given port of the host at the given IP address.
connectTo :: IP.IP -> Int -> IO SocketIO
connectTo host port = do
  conn <- dial host
  if sock conn /= nullPtr
    then return conn
    else throwErrno "could not connect to host"
  where
    cport :: CInt
    cport = fromIntegral port

    -- Opens the C-side connection for either address family; a null
    -- pointer result is treated as failure by the check above.
    dial :: IP.IP -> IO SocketIO
    dial (IP.IPv4 v4) =
      SocketIO <$> c_connect4 (fromIntegral (IP.toHostAddress v4)) cport
    dial (IP.IPv6 v6) =
      SocketIO <$> withArray (ipv6Bytes v6) (\buf -> c_connect6 buf cport)

    -- The 16 address bytes, each 32-bit group emitted most significant byte first.
    ipv6Bytes :: IP.IPv6 -> [CUChar]
    ipv6Bytes addr =
      let (g0, g1, g2, g3) = IP.toHostAddress6 addr
      in concatMap bytesBE [g0, g1, g2, g3]

    bytesBE :: Word32 -> [CUChar]
    bytesBE w = [fromIntegral ((w `shiftR` k) .&. 0xff) | k <- [24, 16, 8, 0]]

-- | bindPort port
--
-- Binds to the given IP port, becoming ready to accept connections on it.
-- Binding to port numbers under 1024 will fail unless performed by the superuser;
-- once bound, a process can reduce its privileges and still accept clients on that port.
bindPort :: Int -> IO BoundedPort
bindPort port = do
  bound <- BoundedPort <$> c_getPort (fromIntegral port)
  if lis bound /= nullPtr
    then return bound
    else throwErrno "could not bind to port"

-- | accept port
--
-- Accepts a client on a port previously bound with 'bindPort'.
accept :: BoundedPort -> IO SocketIO
accept port = do
  client <- SocketIO <$> c_accept (lis port)
  if sock client /= nullPtr
    then return client
    else throwErrno "could not accept connection"

-- | Open a file for bidirectional IO.
openFile :: String -> IO FileIO
openFile fileName = do
  file <- FileIO <$> withCString fileName c_createFile
  if fd file /= nullPtr
    then return file
    else throwErrno "could not open file"

-- | Gets the address of the peer socket of an internet connection.
getPeer :: SocketIO -> IO (IP.IP, Int)
getPeer s =
  -- c_getPeer fills either the 16-byte IPv6 buffer or the IPv4 word and
  -- writes the address family into iptype (1 is read as IPv6 here). A result
  -- of -1 means failure; any other result becomes the second tuple component
  -- (presumably the peer port -- confirm against the C getPeer).
  allocaArray 16 $ \p6 ->
  alloca $ \p4 ->
  alloca $ \iptype -> do
    p <- c_getPeer (sock s) p4 p6 iptype
    if p == -1
      then throwErrno "could not get peer address"
      else do
        iptp <- peek iptype
        if iptp == 1
          then do -- IPv6
            add <- peekArray 16 p6
            return (IP.IPv6 . IP.toIPv6b $ map fromIntegral add, fromIntegral p)
          else do -- IPv4
            add <- peek p4
            return (IP.IPv4 . IP.fromHostAddress . fromIntegral $ add, fromIntegral p)

-- Closes a raw file descriptor through the C closeFd routine.
closeFd :: Fd -> IO ()
closeFd (Fd f) = c_closeFd f

-- | Closes a BoundedPort, and releases any resource used by it.
closePort :: BoundedPort -> IO ()
closePort p = c_closePort (lis p)

-- | mapOverInput io block_size f initial
--
-- Reads io until the end of file, evaluating a(i) <- f a(i-1) read_data
-- where a(0) = initial and the last value after io reaches EOF is returned.
--
-- Notice that the length of read_data might not be equal to block_size.
-- Any IO error raised by 'uRead' other than EOF is rethrown.
mapOverInput :: UniformIO io => io -> Int -> (a -> ByteString -> IO a) -> a -> IO a
mapOverInput io block f initial = loop initial
  where
    -- The step action must run (and its result be bound) before recursing:
    -- f returns an IO action, not a plain accumulator value.
    -- NOTE(review): if a target returns an empty ByteString instead of
    -- throwing EOF, this loops forever; confirm the C recv routines signal EOF.
    loop acc = do
      r <- tryIOError (uRead io block)
      case r of
        Left e
          | isEOFError e -> return acc
          | otherwise -> throwIO e
        Right dt -> f acc dt >>= loop

foreign import ccall interruptible "getPort" c_getPort :: CInt -> IO (Ptr Nethandler)
foreign import ccall interruptible "createFromHandler" c_accept :: Ptr Nethandler -> IO (Ptr Ds)
foreign import ccall safe "createFromFileName" c_createFile :: CString -> IO (Ptr Ds)
foreign import ccall interruptible "createToIPv4Host" c_connect4 :: CUInt -> CInt -> IO (Ptr Ds)
foreign import ccall interruptible "createToIPv6Host" c_connect6 :: Ptr CUChar -> CInt -> IO (Ptr Ds)
foreign import ccall interruptible "startSockTls" c_startSockTls :: Ptr Ds -> CString -> CString -> CString -> IO (Ptr TlsDs)
foreign import ccall safe "getPeer" c_getPeer :: Ptr Ds -> Ptr CUInt -> Ptr CUChar -> Ptr CInt -> IO (CInt)
--foreign import ccall safe "getFd" c_getFd :: Ptr Ds -> IO CInt
--foreign import ccall safe "getTlsFd" c_getTlsFd :: Ptr TlsDs -> IO CInt
foreign import ccall safe "closeFd" c_closeFd :: CInt -> IO ()
foreign import ccall safe "prepareToClose" c_prepareToClose :: Ptr Ds -> IO CInt
foreign import ccall safe "closeHandler" c_closePort :: Ptr Nethandler -> IO ()
foreign import ccall safe "closeTls" c_closeTls :: Ptr TlsDs -> IO (Ptr Ds)

foreign import ccall interruptible "sendDs" c_send :: Ptr Ds -> Ptr CChar -> CInt -> IO CInt
foreign import ccall interruptible "stdDsSend" c_sendStd :: Ptr CChar -> CInt -> IO CInt
foreign import ccall interruptible "tlsDsSend" c_sendTls :: Ptr TlsDs -> Ptr CChar -> CInt -> IO CInt
foreign import ccall interruptible "recvDs" c_recv :: Ptr Ds -> Ptr CChar -> CInt -> IO CInt
foreign import ccall interruptible "stdDsRecv" c_recvStd :: Ptr CChar -> CInt -> IO CInt
foreign import ccall interruptible "tlsDsRecv" c_recvTls :: Ptr TlsDs -> Ptr CChar -> CInt -> IO CInt