Targets.hs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. {-# LANGUAGE OverloadedStrings #-}
  2. {-# LANGUAGE ExistentialQuantification #-}
  3. {-# LANGUAGE ForeignFunctionInterface #-}
  4. {-# LANGUAGE InterruptibleFFI #-}
  5. {-# LANGUAGE EmptyDataDecls #-}
  6. module System.IO.Uniform.Targets (
  7. TlsSettings(..),
  8. UniformIO(..),
  9. SocketIO, FileIO, StdIO, TlsIO, SomeIO(..),
  10. BoundedPort, connectTo, connectToHost, bindPort, accept, closePort,
  11. openFile, getPeer,
  12. mapOverInput)
  13. where
  14. import Foreign
  15. import Foreign.C.Types
  16. import Foreign.C.String
  17. import Foreign.C.Error
  18. import qualified Data.IP as IP
  19. import Data.ByteString (ByteString)
  20. import qualified Data.ByteString as BS
  21. import qualified Data.List as L
  22. import Control.Exception
  23. import Control.Applicative ((<$>))
  24. import qualified Network.Socket as Soc
  25. import System.IO.Error
  26. import Data.Default.Class
  27. import System.Posix.Types (Fd(..))
  -- | Settings for the starttls functions.
--
-- The three fields hold file names (as their names indicate) that the
-- 'SocketIO' 'startTls' hands to the C TLS setup: the private key, the
-- certificate chain and the Diffie-Hellman parameters.
data TlsSettings = TlsSettings
  { tlsPrivateKeyFile :: String
  , tlsCertificateChainFile :: String
  , tlsDHParametersFile :: String
  }
  deriving (Read, Show)

-- | The default settings leave every file name empty.
instance Default TlsSettings where
  def =
    TlsSettings
      { tlsPrivateKeyFile = ""
      , tlsCertificateChainFile = ""
      , tlsDHParametersFile = ""
      }
  -- |
-- Typeclass for uniform IO targets: one read/write/close/TLS interface
-- shared by sockets, files, the standard streams and TLS sessions.
class UniformIO a where
  -- | uRead fd n
  --
  -- Reads a block of at most n bytes of data from the IO target.
  -- Reading will block if there's no data available, but will return
  -- immediately if any amount of data is available.
  --
  -- Must throw an 'IOError' satisfying 'System.IO.Error.isEOFError' when
  -- reading beyond EOF.
  --
  -- NOTE(review): the instances in this module only throw when the C call
  -- reports a negative count; a zero-byte result is returned as an empty
  -- 'ByteString'. Confirm how the C layer signals EOF.
  uRead :: a -> Int -> IO ByteString
  -- | uPut fd text
  --
  -- Writes all the bytes of text into the IO target. Takes care of
  -- retrying if needed.
  uPut :: a -> ByteString -> IO ()
  -- | uClose fd
  --
  -- Closes the IO target, releasing any allocated resource. Resources may
  -- leak if this is not called for every opened target.
  uClose :: a -> IO ()
  -- | startTls settings fd
  --
  -- Starts a TLS connection over the IO target, using the key,
  -- certificate chain and DH parameter files named in the settings.
  startTls :: TlsSettings -> a -> IO TlsIO
  -- | isSecure fd
  --
  -- Indicates whether the data written to or read from fd is secure at
  -- the transport level.
  isSecure :: a -> Bool
  -- | A type that wraps any type in the 'UniformIO' class, hiding the
-- concrete target so different targets can be handled through one type.
data SomeIO = forall a. (UniformIO a) => SomeIO a

-- | Every method unwraps the hidden target and delegates to its instance.
instance UniformIO SomeIO where
  uRead (SomeIO inner) = uRead inner
  uPut (SomeIO inner) = uPut inner
  uClose (SomeIO inner) = uClose inner
  startTls settings (SomeIO inner) = startTls settings inner
  isSecure (SomeIO inner) = isSecure inner
  -- Constructor-less types used only as the targets of the 'Ptr' handles
-- exchanged with the C layer.
data Nethandler

data Ds

data TlsDs

-- | A bound IP port from where to accept 'SocketIO' connections.
newtype BoundedPort = BoundedPort {lis :: Ptr Nethandler}

-- | A network connection, held as a C @Ds@ handle.
newtype SocketIO = SocketIO {sock :: Ptr Ds}

-- | A file opened for IO, held as a C @Ds@ handle.
newtype FileIO = FileIO {fd :: Ptr Ds}

-- | A TLS session, held as a C @TlsDs@ handle.
newtype TlsIO = TlsIO {tls :: Ptr TlsDs}

-- | Marker type for the standard input/output target; it has no
-- constructors.
data StdIO
  -- | 'UniformIO' instance for IP connections.
instance UniformIO SocketIO where
  -- A negative count from the C layer becomes an errno-based IOError; any
  -- other count is the number of bytes placed in the buffer.
  uRead s n = allocaArray n $ \buf -> do
    got <- c_recv (sock s) buf (fromIntegral n)
    if got >= 0
      then BS.packCStringLen (buf, fromIntegral got)
      else throwErrno "could not read"

  -- NOTE(review): a non-negative count from c_send is not compared with the
  -- requested length; if the C side can report short writes, the remaining
  -- bytes would be dropped. Confirm against sendDs.
  uPut s t = BS.useAsCStringLen t $ \(bytes, len) -> do
    sent <- c_send (sock s) bytes (fromIntegral len)
    if sent >= 0
      then return ()
      else throwErrno "could not write"

  -- c_prepareToClose yields a raw descriptor, which is then closed.
  uClose s = c_prepareToClose (sock s) >>= closeFd . Fd

  -- Hands the three file names from the settings to the C TLS setup; a
  -- null session pointer is reported as an errno-based IOError.
  startTls st s =
    withCString (tlsCertificateChainFile st) $ \cert ->
      withCString (tlsPrivateKeyFile st) $ \key ->
        withCString (tlsDHParametersFile st) $ \dhParams -> do
          session <- c_startSockTls (sock s) cert key dhParams
          if session == nullPtr
            then throwErrno "could not start TLS"
            else return (TlsIO session)

  isSecure = const False
  -- | 'UniformIO' target that reads from stdin and writes to stdout.
instance UniformIO StdIO where
  -- Reads at most n bytes; a negative count from the C layer becomes an
  -- errno-based IOError.
  uRead _ n = allocaArray n $ \buf -> do
    count <- c_recvStd buf (fromIntegral n)
    if count < 0
      then throwErrno "could not read"
      else BS.packCStringLen (buf, fromIntegral count)

  -- Passes the whole buffer to a single C call; only a negative count is
  -- treated as an error.
  uPut _ t = BS.useAsCStringLen t $ \(str, len) -> do
    count <- c_sendStd str (fromIntegral len)
    if count < 0
      then throwErrno "could not write"
      else return ()

  -- Nothing is released: the standard streams are left open.
  uClose _ = return ()

  -- TLS over the standard streams is not implemented. Returning a 'TlsIO'
  -- that wraps 'nullPtr' would give callers a value whose 'isSecure' is
  -- True while every read, write or close on it passes a null pointer to
  -- the C TLS functions, so refuse explicitly instead.
  startTls _ _ =
    ioError $
      mkIOError
        illegalOperationErrorType
        "startTls: TLS is not supported on StdIO"
        Nothing
        Nothing

  isSecure _ = False
  -- | 'UniformIO' target for file IO.
instance UniformIO FileIO where
  -- Reads at most n bytes; a negative count from the C layer becomes an
  -- errno-based IOError.
  uRead s n = allocaArray n $ \buf -> do
    count <- c_recv (fd s) buf (fromIntegral n)
    if count < 0
      then throwErrno "could not read"
      else BS.packCStringLen (buf, fromIntegral count)

  -- Passes the whole buffer to a single C call; only a negative count is
  -- treated as an error.
  uPut s t = BS.useAsCStringLen t $ \(str, len) -> do
    count <- c_send (fd s) str (fromIntegral len)
    if count < 0
      then throwErrno "could not write"
      else return ()

  -- c_prepareToClose yields a raw descriptor, which is then closed.
  uClose s = do
    raw <- c_prepareToClose (fd s)
    closeFd (Fd raw)

  -- TLS over files is not implemented yet. Returning a 'TlsIO' that wraps
  -- 'nullPtr' would give callers a value whose 'isSecure' is True while
  -- every read, write or close on it passes a null pointer to the C TLS
  -- functions, so refuse explicitly instead.
  startTls _ _ =
    ioError $
      mkIOError
        illegalOperationErrorType
        "startTls: TLS is not supported on FileIO"
        Nothing
        Nothing

  isSecure _ = False
  -- | 'UniformIO' wrapper that applies TLS to the communication on an IO
-- target. Values of this type are produced by calling 'startTls' on other
-- targets.
instance UniformIO TlsIO where
  -- A negative count from the C layer becomes an errno-based IOError.
  uRead t n = allocaArray n $ \buf -> do
    got <- c_recvTls (tls t) buf (fromIntegral n)
    if got >= 0
      then BS.packCStringLen (buf, fromIntegral got)
      else throwErrno "could not read"

  uPut t bytes = BS.useAsCStringLen bytes $ \(buf, len) -> do
    sent <- c_sendTls (tls t) buf (fromIntegral len)
    if sent >= 0
      then return ()
      else throwErrno "could not write"

  -- Closing the TLS session hands back the underlying Ds handle, whose raw
  -- descriptor is then closed.
  uClose t = do
    inner <- c_closeTls (tls t)
    raw <- c_prepareToClose inner
    closeFd (Fd raw)

  -- Already a TLS session: returned unchanged.
  startTls _ = return

  isSecure = const True
  -- | connectToHost hostName port
--
-- Connects to the given host and port. The host name is resolved with
-- 'Soc.getAddrInfo' and only the first address returned is used; if there
-- is none, or it is neither IPv4 nor IPv6, a does-not-exist 'IOError' is
-- thrown.
connectToHost :: String -> Int -> IO SocketIO
connectToHost host port = resolve >>= flip connectTo port
  where
    notFound :: IO b
    notFound =
      throwIO $ mkIOError doesNotExistErrorType "host not found" Nothing Nothing

    resolve :: IO IP.IP
    resolve = do
      infos <- Soc.getAddrInfo Nothing (Just host) Nothing
      case map Soc.addrAddress infos of
        Soc.SockAddrInet _ v4 : _ -> return (IP.IPv4 (IP.fromHostAddress v4))
        Soc.SockAddrInet6 _ _ v6 _ : _ -> return (IP.IPv6 (IP.fromHostAddress6 v6))
        _ -> notFound
  -- | connectTo ipAddress port
--
-- Connects to the given port of the host at the given IP address. A null
-- handle from the C layer is reported as an errno-based 'IOError'.
connectTo :: IP.IP -> Int -> IO SocketIO
connectTo host port = do
  conn <- SocketIO <$> dial host
  if sock conn /= nullPtr
    then return conn
    else throwErrno "could not connect to host"
  where
    cport :: CInt
    cport = fromIntegral port

    dial (IP.IPv4 v4) = c_connect4 (fromIntegral (IP.toHostAddress v4)) cport
    dial (IP.IPv6 v6) = withArray (ipv6Bytes v6) (\bytes -> c_connect6 bytes cport)

    -- The 16 address bytes: each of the four 32-bit words, most
    -- significant byte first.
    ipv6Bytes :: IP.IPv6 -> [CUChar]
    ipv6Bytes addr =
      let (w0, w1, w2, w3) = IP.toHostAddress6 addr
       in concatMap wordBytes [w0, w1, w2, w3]

    wordBytes :: Word32 -> [CUChar]
    wordBytes w = [fromIntegral ((w `shiftR` bits) .&. 0xff) | bits <- [24, 16, 8, 0]]
  -- | bindPort port
--
-- Binds to the given IP port, becoming ready to accept connections on it.
-- Binding to port numbers under 1024 will fail unless performed by the
-- superuser; once bound, a process can reduce its privileges and still
-- accept clients on that port. A null handle from the C layer is reported
-- as an errno-based 'IOError'.
bindPort :: Int -> IO BoundedPort
bindPort port = do
  handler <- c_getPort (fromIntegral port)
  if handler /= nullPtr
    then return (BoundedPort handler)
    else throwErrno "could not bind to port"
  -- | accept port
--
-- Accepts a client on a port previously bound with 'bindPort'. A null
-- handle from the C layer is reported as an errno-based 'IOError'.
accept :: BoundedPort -> IO SocketIO
accept port = do
  conn <- c_accept (lis port)
  if conn /= nullPtr
    then return (SocketIO conn)
    else throwErrno "could not accept connection"
  -- | Open a file for bidirectional IO.
--
-- A null handle from the C layer is reported as an errno-based 'IOError'.
openFile :: String -> IO FileIO
openFile fileName = do
  ds <- withCString fileName c_createFile
  if ds /= nullPtr
    then return (FileIO ds)
    else throwErrno "could not open file"
  -- | Gets the address of the peer socket of an internet connection.
--
-- The C call fills either a 16-byte IPv6 buffer or a 32-bit IPv4 slot and
-- reports which one through a type flag (1 meaning IPv6). Its return value
-- becomes the 'Int' half of the result (presumably the peer port — confirm
-- against the C getPeer). A return of -1 is an errno-based 'IOError'.
getPeer :: SocketIO -> IO (IP.IP, Int)
getPeer s =
  allocaArray 16 $ \p6 ->
    alloca $ \p4 ->
      alloca $ \iptype -> do
        port <- c_getPeer (sock s) p4 p6 iptype
        if port == -1
          then throwErrno "could not get peer address"
          else do
            addrType <- peek iptype
            addr <-
              if addrType == 1
                then IP.IPv6 . IP.toIPv6b . map fromIntegral <$> peekArray 16 p6
                else IP.IPv4 . IP.fromHostAddress . fromIntegral <$> peek p4
            return (addr, fromIntegral port)
  -- Closes a raw file descriptor through the C @closeFd@ function.
closeFd :: Fd -> IO ()
closeFd descriptor = case descriptor of
  Fd raw -> c_closeFd raw
  -- | Closes a 'BoundedPort' and releases any resource used by it.
closePort :: BoundedPort -> IO ()
closePort = c_closePort . lis
  -- | mapOverInput io block_size f initial
--
-- Reads io until the end of file, evaluating @a(i) <- f a(i-1) read_data@
-- where @a(0) = initial@, and returns the last value once io reaches EOF.
--
-- EOF is recognised either by 'uRead' throwing an 'IOError' that satisfies
-- 'isEOFError', or by 'uRead' returning an empty block. The targets in this
-- module only raise errno-based errors, which never satisfy 'isEOFError',
-- so the empty-block check is what lets the loop terminate on them. Any
-- other 'IOError' is rethrown.
--
-- The accumulator is forced to WHNF at every step so long inputs do not
-- build up a chain of unevaluated thunks.
--
-- Notice that the length of read_data might not equal block_size.
mapOverInput :: UniformIO io => io -> Int -> (a -> ByteString -> a) -> a -> IO a
mapOverInput io block f = go
  where
    go acc = do
      result <- tryIOError (uRead io block)
      case result of
        Left e
          | isEOFError e -> return acc
          | otherwise -> ioError e
        Right dt
          | BS.null dt -> return acc
          | otherwise -> go $! f acc dt
  -- Entry points of the C layer. Imports marked "interruptible" let a Haskell
-- thread that is inside the call still receive asynchronous exceptions
-- (GHC InterruptibleFFI); the others are ordinary safe calls.

-- Listening ports and connection setup.
foreign import ccall interruptible "getPort"
  c_getPort :: CInt -> IO (Ptr Nethandler)
foreign import ccall interruptible "createFromHandler"
  c_accept :: Ptr Nethandler -> IO (Ptr Ds)
foreign import ccall safe "createFromFileName"
  c_createFile :: CString -> IO (Ptr Ds)
foreign import ccall interruptible "createToIPv4Host"
  c_connect4 :: CUInt -> CInt -> IO (Ptr Ds)
foreign import ccall interruptible "createToIPv6Host"
  c_connect6 :: Ptr CUChar -> CInt -> IO (Ptr Ds)
foreign import ccall interruptible "startSockTls"
  c_startSockTls :: Ptr Ds -> CString -> CString -> CString -> IO (Ptr TlsDs)
foreign import ccall safe "getPeer"
  c_getPeer :: Ptr Ds -> Ptr CUInt -> Ptr CUChar -> Ptr CInt -> IO CInt

-- Currently unused:
--foreign import ccall safe "getFd" c_getFd :: Ptr Ds -> IO CInt
--foreign import ccall safe "getTlsFd" c_getTlsFd :: Ptr TlsDs -> IO CInt

-- Teardown.
foreign import ccall safe "closeFd"
  c_closeFd :: CInt -> IO ()
foreign import ccall safe "prepareToClose"
  c_prepareToClose :: Ptr Ds -> IO CInt
foreign import ccall safe "closeHandler"
  c_closePort :: Ptr Nethandler -> IO ()
foreign import ccall safe "closeTls"
  c_closeTls :: Ptr TlsDs -> IO (Ptr Ds)

-- Data transfer: each returns a byte count, negative on failure.
foreign import ccall interruptible "sendDs"
  c_send :: Ptr Ds -> Ptr CChar -> CInt -> IO CInt
foreign import ccall interruptible "stdDsSend"
  c_sendStd :: Ptr CChar -> CInt -> IO CInt
foreign import ccall interruptible "tlsDsSend"
  c_sendTls :: Ptr TlsDs -> Ptr CChar -> CInt -> IO CInt
foreign import ccall interruptible "recvDs"
  c_recv :: Ptr Ds -> Ptr CChar -> CInt -> IO CInt
foreign import ccall interruptible "stdDsRecv"
  c_recvStd :: Ptr CChar -> CInt -> IO CInt
foreign import ccall interruptible "tlsDsRecv"
  c_recvTls :: Ptr TlsDs -> Ptr CChar -> CInt -> IO CInt