Targets.hs 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. {-# LANGUAGE OverloadedStrings #-}
  2. {-# LANGUAGE ExistentialQuantification #-}
  3. {-# LANGUAGE ForeignFunctionInterface #-}
  4. {-# LANGUAGE InterruptibleFFI #-}
  5. {-# LANGUAGE EmptyDataDecls #-}
  6. module System.IO.Uniform.Targets (
  7. TlsSettings(..),
  8. UniformIO(..),
  9. SocketIO, FileIO, StdIO, TlsIO, SomeIO(..), ByteStringIO,
  10. BoundedPort, connectTo, connectToHost, bindPort, accept, closePort,
  11. openFile, getPeer,
  12. withByteStringIO, withByteStringIO',
  13. mapOverInput)
  14. where
  15. import Foreign
  16. import Foreign.C.Types
  17. import Foreign.C.String
  18. import Foreign.C.Error
  19. import qualified Data.IP as IP
  20. import Data.ByteString (ByteString)
  21. import qualified Data.ByteString as BS
  22. import qualified Data.ByteString.Lazy as LBS
  23. import qualified Data.ByteString.Builder as BSBuild
  24. import qualified Data.List as L
  25. import Control.Exception
  26. import Control.Applicative ((<$>))
  27. import Data.Monoid (mappend)
  28. import qualified Network.Socket as Soc
  29. import System.IO.Error
  30. import Control.Concurrent.MVar
  31. import Data.Default.Class
  32. import System.Posix.Types (Fd(..))
  -- | Settings consumed by 'startTls'.
  --
  -- Each field is passed unchanged, as a C string, to the native TLS setup
  -- routine. Judging by the field names these are file paths (TODO confirm
  -- against the C side).
  data TlsSettings = TlsSettings
    { tlsPrivateKeyFile :: String
      -- ^ Private key file.
    , tlsCertificateChainFile :: String
      -- ^ Certificate chain file.
    , tlsDHParametersFile :: String
      -- ^ Diffie-Hellman parameters file.
    } deriving (Read, Show)

  -- | Default settings: every field is the empty string.
  instance Default TlsSettings where
    def = TlsSettings
      { tlsPrivateKeyFile = ""
      , tlsCertificateChainFile = ""
      , tlsDHParametersFile = ""
      }
  -- | Typeclass for uniform IO targets: anything that can be read from,
  -- written to, closed, and optionally upgraded to TLS.
  class UniformIO a where
    -- | @uRead fd n@
    --
    -- Reads a block of at most @n@ bytes of data from the IO target.
    -- Reading blocks if there is no data available, but returns immediately
    -- if any amount of data is available.
    --
    -- Must throw an EOF 'IOError' (one satisfying
    -- 'System.IO.Error.isEOFError') when reading beyond EOF;
    -- 'mapOverInput' relies on this to terminate.
    uRead :: a -> Int -> IO ByteString
    -- | @uPut fd text@
    --
    -- Writes all the bytes of @text@ into the IO target, retrying if needed.
    uPut :: a -> ByteString -> IO ()
    -- | @uClose fd@
    --
    -- Closes the IO target, releasing any allocated resource. Resources may
    -- leak if this is not called for every opened target.
    uClose :: a -> IO ()
    -- | @startTls settings fd@
    --
    -- Starts a TLS connection over the IO target, using @settings@.
    startTls :: TlsSettings -> a -> IO TlsIO
    -- | @isSecure fd@
    --
    -- Indicates whether the data written to or read from @fd@ is secured at
    -- the transport level.
    isSecure :: a -> Bool
  -- | Existential wrapper hiding which 'UniformIO' target is in use, so that
  -- different kinds of targets can be handled through a single type.
  data SomeIO = forall a. (UniformIO a) => SomeIO a

  -- | Every method unwraps the hidden target and delegates to it.
  instance UniformIO SomeIO where
    isSecure (SomeIO inner) = isSecure inner
    startTls settings (SomeIO inner) = startTls settings inner
    uClose (SomeIO inner) = uClose inner
    uPut (SomeIO inner) = uPut inner
    uRead (SomeIO inner) = uRead inner
  -- | Opaque listener state owned by the C layer; only handled via 'Ptr'.
  data Nethandler
  -- | A bound IP port from which 'SocketIO' connections are accepted.
  newtype BoundedPort = BoundedPort {lis :: Ptr Nethandler}
  -- | Opaque descriptor state owned by the C layer; shared by sockets and
  -- files.
  data Ds
  -- | A network connection backed by a C-layer descriptor.
  newtype SocketIO = SocketIO {sock :: Ptr Ds}
  -- | A file opened for IO, backed by a C-layer descriptor.
  newtype FileIO = FileIO {fd :: Ptr Ds}
  -- | Opaque TLS session state owned by the C layer.
  data TlsDs
  -- | A TLS-wrapped target, backed by a C-layer TLS descriptor.
  newtype TlsIO = TlsIO {tls :: Ptr TlsDs}
  -- | Marker type for standard input/output. It has no constructors; its
  -- 'UniformIO' instance never inspects the value.
  data StdIO
  -- | 'UniformIO' over IP connections.
  --
  -- Each operation forwards to the C layer. A negative count or a null
  -- pointer from C is reported through 'throwErrno'.
  instance UniformIO SocketIO where
    -- NOTE(review): a zero count from c_recv is returned as an empty
    -- ByteString, not as an EOF IOError. Confirm the C side signals EOF
    -- another way; otherwise 'mapOverInput' may never terminate on a socket.
    uRead s n = allocaArray n $ \buf -> do
      got <- c_recv (sock s) buf (fromIntegral n)
      if got < 0
        then throwErrno "could not read"
        else BS.packCStringLen (buf, fromIntegral got)
    -- NOTE(review): a short, non-negative send count is not retried here.
    -- Presumably sendDs loops until everything is written; confirm.
    uPut s t = BS.useAsCStringLen t $ \(ptr, len) -> do
      sent <- c_send (sock s) ptr (fromIntegral len)
      if sent < 0
        then throwErrno "could not write"
        else return ()
    -- The C layer hands back the raw descriptor, which is then closed.
    uClose s = c_prepareToClose (sock s) >>= closeFd . Fd
    startTls st s =
      withCString (tlsCertificateChainFile st) $ \cert ->
      withCString (tlsPrivateKeyFile st) $ \key ->
      withCString (tlsDHParametersFile st) $ \para -> do
        session <- c_startSockTls (sock s) cert key para
        if session == nullPtr
          then throwErrno "could not start TLS"
          else return (TlsIO session)
    isSecure _ = False
  -- | 'UniformIO' that reads from stdin and writes to stdout.
  --
  -- The 'StdIO' argument is never inspected by any method. A negative count
  -- from the C layer is reported through 'throwErrno'.
  instance UniformIO StdIO where
    -- NOTE(review): a zero count is returned as an empty ByteString rather
    -- than an EOF IOError; confirm how stdDsRecv reports end of input.
    uRead _ n = allocaArray n $ \b -> do
      count <- c_recvStd b (fromIntegral n)
      if count < 0
        then throwErrno "could not read"
        else BS.packCStringLen (b, fromIntegral count)
    uPut _ t = BS.useAsCStringLen t $ \(str, n) -> do
      count <- c_sendStd str (fromIntegral n)
      if count < 0
        then throwErrno "could not write"
        else return ()
    -- The standard streams are left open.
    uClose _ = return ()
    -- TLS is not available on the standard streams. Returning
    -- @TlsIO nullPtr@ would hand out a target whose 'isSecure' is True while
    -- every read, write or close on it passes a null pointer to the C TLS
    -- routines. Fail with an 'IOError' instead.
    startTls _ _ = ioError $ mkIOError illegalOperationErrorType
      "startTls: TLS is not supported on standard IO" Nothing Nothing
    isSecure _ = False
  -- | 'UniformIO' type for file IO.
  --
  -- Reads and writes go through the same C descriptor routines used by
  -- 'SocketIO'. A negative count from C is reported through 'throwErrno'.
  instance UniformIO FileIO where
    -- NOTE(review): a zero count is returned as an empty ByteString rather
    -- than an EOF IOError; confirm how recvDs reports end of file.
    uRead s n = allocaArray n $ \b -> do
      count <- c_recv (fd s) b (fromIntegral n)
      if count < 0
        then throwErrno "could not read"
        else BS.packCStringLen (b, fromIntegral count)
    uPut s t = BS.useAsCStringLen t $ \(str, n) -> do
      count <- c_send (fd s) str (fromIntegral n)
      if count < 0
        then throwErrno "could not write"
        else return ()
    -- The C layer hands back the raw descriptor, which is then closed.
    uClose s = c_prepareToClose (fd s) >>= closeFd . Fd
    -- Not implemented yet. Returning @TlsIO nullPtr@ would hand out a target
    -- whose 'isSecure' is True while every read, write or close on it passes
    -- a null pointer to the C TLS routines. Fail with an 'IOError' instead.
    startTls _ _ = ioError $ mkIOError illegalOperationErrorType
      "startTls: TLS is not supported on files" Nothing Nothing
    isSecure _ = False
  -- | 'UniformIO' wrapper that applies TLS to communication on an IO target.
  -- Values of this type are produced by calling 'startTls' on other targets.
  instance UniformIO TlsIO where
    uRead conn size = allocaArray size $ \buf -> do
      got <- c_recvTls (tls conn) buf (fromIntegral size)
      if got < 0
        then throwErrno "could not read"
        else BS.packCStringLen (buf, fromIntegral got)
    uPut conn bytes = BS.useAsCStringLen bytes $ \(ptr, len) -> do
      sent <- c_sendTls (tls conn) ptr (fromIntegral len)
      if sent < 0
        then throwErrno "could not write"
        else return ()
    -- Tear down the TLS session first. The C layer returns the underlying
    -- descriptor, which is then prepared for closing and closed.
    uClose conn = c_closeTls (tls conn) >>= c_prepareToClose >>= closeFd . Fd
    -- Already under TLS: hand back the same target unchanged.
    startTls _ conn = return conn
    isSecure _ = True
  -- | In-memory 'UniformIO': reads consume a fixed input and writes append to
  -- an output builder. Values are created by 'withByteStringIO'.
  data ByteStringIO = ByteStringIO
    { bsioinput :: MVar (ByteString, Bool)
      -- ^ Remaining input, and whether EOF has already been reported.
    , bsiooutput :: MVar BSBuild.Builder
      -- ^ Everything written so far.
    }

  instance UniformIO ByteStringIO where
    -- Returns up to @n@ bytes.
    --
    -- Once the input is exhausted, the first read with @n > 0@ returns an
    -- empty string and records EOF; every later read throws an EOF
    -- 'IOError'. 'mapOverInput' therefore sees one empty chunk before it
    -- stops.
    --
    -- 'modifyMVar' restores the state if an exception interrupts the update.
    -- A bare takeMVar/putMVar pair could leave the MVar empty and deadlock
    -- every later reader.
    uRead s n = do
      chunk <- modifyMVar (bsioinput s) $ \(i, eof) ->
        if eof
          then return ((i, eof), Nothing)
          else do
            let (r, i') = BS.splitAt n i
            return ((i', BS.null r && n > 0), Just r)
      case chunk of
        Nothing -> ioError $ mkIOError eofErrorType "read past end of input" Nothing Nothing
        Just r -> return r
    -- Appends to the output builder; 'modifyMVar_' keeps the MVar filled
    -- even if an exception arrives mid-update.
    uPut s t = modifyMVar_ (bsiooutput s) $ \o ->
      return (mappend o (BSBuild.byteString t))
    uClose _ = return ()
    -- TLS is not available in memory. Returning @TlsIO nullPtr@ would hand
    -- out a target whose every read, write or close passes a null pointer to
    -- the C TLS routines. Fail with an 'IOError' instead.
    startTls _ _ = ioError $ mkIOError illegalOperationErrorType
      "startTls: TLS is not supported on ByteStringIO" Nothing Nothing
    isSecure _ = True
  -- | @connectToHost hostName port@
  --
  -- Resolves @hostName@ with 'Soc.getAddrInfo', then connects to @port@ on
  -- the first address returned (see 'connectTo').
  --
  -- Throws a does-not-exist 'IOError' ("host not found") in two cases: the
  -- lookup yields no results, or the first result is neither IPv4 nor IPv6.
  -- Exceptions raised by 'Soc.getAddrInfo' itself propagate unchanged.
  connectToHost :: String -> Int -> IO SocketIO
  connectToHost host port = resolve >>= \ip -> connectTo ip port
    where
      notFound :: IO a
      notFound = throwIO $ mkIOError doesNotExistErrorType "host not found" Nothing Nothing

      resolve :: IO IP.IP
      resolve = do
        infos <- Soc.getAddrInfo Nothing (Just host) Nothing
        case map Soc.addrAddress infos of
          Soc.SockAddrInet _ v4 : _ -> return (IP.IPv4 (IP.fromHostAddress v4))
          Soc.SockAddrInet6 _ _ v6 _ : _ -> return (IP.IPv6 (IP.fromHostAddress6 v6))
          _ -> notFound
  -- | @connectTo ipAddress port@
  --
  -- Connects to the given port of the host at the given IP address.
  -- Throws an errno-based 'IOError' ("could not connect to host") when the C
  -- connect routine returns a null descriptor.
  connectTo :: IP.IP -> Int -> IO SocketIO
  connectTo host port = do
    conn <- case host of
      IP.IPv4 v4 ->
        SocketIO <$> c_connect4 (fromIntegral (IP.toHostAddress v4)) cPort
      IP.IPv6 v6 ->
        SocketIO <$> withArray (ipToArray v6) (\addr -> c_connect6 addr cPort)
    if sock conn == nullPtr
      then throwErrno "could not connect to host"
      else return conn
    where
      cPort :: CInt
      cPort = fromIntegral port

      -- Flattens the four 32-bit words of an IPv6 address into 16 bytes,
      -- each word's bytes most significant first.
      ipToArray :: IP.IPv6 -> [CUChar]
      ipToArray ip =
        let (w0, w1, w2, w3) = IP.toHostAddress6 ip
        in concatMap wordBytes [w0, w1, w2, w3]

      -- Big-endian bytes of a word; fromIntegral to CUChar keeps the low 8
      -- bits of each shifted value.
      wordBytes :: Word32 -> [CUChar]
      wordBytes w = [fromIntegral (w `shiftR` sh) | sh <- [24, 16, 8, 0]]
  -- | @bindPort port@
  --
  -- Binds to the given IP port, becoming ready to accept connections on it.
  -- Binding to port numbers under 1024 will fail unless performed by the
  -- superuser. Once bound, a process can reduce its privileges and still
  -- accept clients on that port.
  --
  -- Throws an errno-based 'IOError' ("could not bind to port") when the C
  -- routine returns a null handler.
  bindPort :: Int -> IO BoundedPort
  bindPort port = do
    listener <- c_getPort (fromIntegral port)
    if listener == nullPtr
      then throwErrno "could not bind to port"
      else return (BoundedPort listener)
  -- | @accept port@
  --
  -- Accepts a client on a port previously bound with 'bindPort'.
  -- Throws an errno-based 'IOError' ("could not accept connection") when the
  -- C routine returns a null descriptor.
  accept :: BoundedPort -> IO SocketIO
  accept port = c_accept (lis port) >>= \conn ->
    if conn == nullPtr
      then throwErrno "could not accept connection"
      else return (SocketIO conn)
  -- | Opens a file for bidirectional IO.
  --
  -- Throws an errno-based 'IOError' ("could not open file") when the C
  -- routine returns a null descriptor.
  openFile :: String -> IO FileIO
  openFile fileName = do
    -- The null check happens after withCString has returned, i.e. outside
    -- the marshalled-string scope.
    raw <- withCString fileName c_createFile
    if raw == nullPtr
      then throwErrno "could not open file"
      else return (FileIO raw)
  -- | Gets the address and port of the peer of an internet connection.
  --
  -- The C routine receives three buffers: a 16-byte IPv6 buffer, an IPv4
  -- slot and a family flag. A flag of 1 selects IPv6; any other value is
  -- read as IPv4. The routine's return value is used as the port. A return
  -- of -1 raises an errno-based 'IOError' ("could not get peer address").
  getPeer :: SocketIO -> IO (IP.IP, Int)
  getPeer s =
    allocaArray 16 $ \v6buf ->
    alloca $ \v4buf ->
    alloca $ \familyPtr -> do
      result <- c_getPeer (sock s) v4buf v6buf familyPtr
      if result == -1
        then throwErrno "could not get peer address"
        else do
          family <- peek familyPtr
          let portNum = fromIntegral result
          case family of
            1 -> do
              bytes <- peekArray 16 v6buf
              return (IP.IPv6 (IP.toIPv6b (map fromIntegral bytes)), portNum)
            _ -> do
              v4 <- peek v4buf
              return (IP.IPv4 (IP.fromHostAddress (fromIntegral v4)), portNum)
  -- | Closes a raw file descriptor through the C layer.
  closeFd :: Fd -> IO ()
  closeFd descriptor = case descriptor of
    Fd raw -> c_closeFd raw
  -- | Closes a 'BoundedPort', releasing any resource used by it.
  closePort :: BoundedPort -> IO ()
  closePort = c_closePort . lis
  -- | @withByteStringIO input f@
  --
  -- Runs @f@ with a 'ByteStringIO' whose reads consume @input@. Returns
  -- @f@'s result together with everything @f@ wrote, as a lazy ByteString.
  withByteStringIO :: ByteString -> (ByteStringIO -> IO a) -> IO (a, LBS.ByteString)
  withByteStringIO input f = do
    inVar <- newMVar (input, False)
    bsio <- ByteStringIO inVar <$> newMVar (BSBuild.byteString BS.empty)
    result <- f bsio
    -- This leaves the output MVar empty, so any later uPut on this
    -- ByteStringIO (e.g. from a thread started by f) would block.
    written <- takeMVar (bsiooutput bsio)
    return (result, BSBuild.toLazyByteString written)
  -- | The same as 'withByteStringIO', but returns the output as a strict
  -- 'ByteString'.
  withByteStringIO' :: ByteString -> (ByteStringIO -> IO a) -> IO (a, ByteString)
  withByteStringIO' input f = fmap toStrictOutput (withByteStringIO input f)
    where
      toStrictOutput (result, out) = (result, LBS.toStrict out)
  -- | @mapOverInput io blockSize f initial@
  --
  -- Reads @io@ until the end of file, evaluating @a(i) <- f a(i-1) chunk@ for
  -- each chunk read, where @a(0) = initial@. Returns the last value once @io@
  -- reaches EOF.
  --
  -- Chunks are requested with 'uRead' and may be shorter than @blockSize@.
  -- Only an EOF 'IOError' from 'uRead' ends the loop; any other 'IOError' it
  -- raises is rethrown. Exceptions raised by @f@ itself are not caught.
  mapOverInput :: forall a io. UniformIO io => io -> Int -> (a -> ByteString -> IO a) -> a -> IO a
  mapOverInput io block f = go
    where
      go acc = do
        attempt <- tryIOError (uRead io block)
        case attempt of
          Left e
            | isEOFError e -> return acc
            -- Rethrow with ioError (i.e. throwIO): 'throw' is meant for
            -- pure code and gives no ordering guarantee inside IO.
            | otherwise -> ioError e
          Right chunk -> f acc chunk >>= go
  -- Bindings to the C support layer. Imports marked @interruptible@ can be
  -- interrupted when the calling Haskell thread receives an asynchronous
  -- exception (GHC InterruptibleFFI); the rest are ordinary @safe@ calls.

  -- Listening, connecting and peer lookup.
  foreign import ccall interruptible "getPort" c_getPort :: CInt -> IO (Ptr Nethandler)
  foreign import ccall interruptible "createFromHandler" c_accept :: Ptr Nethandler -> IO (Ptr Ds)
  foreign import ccall interruptible "createToIPv4Host" c_connect4 :: CUInt -> CInt -> IO (Ptr Ds)
  foreign import ccall interruptible "createToIPv6Host" c_connect6 :: Ptr CUChar -> CInt -> IO (Ptr Ds)
  foreign import ccall safe "getPeer" c_getPeer :: Ptr Ds -> Ptr CUInt -> Ptr CUChar -> Ptr CInt -> IO CInt
  foreign import ccall safe "closeHandler" c_closePort :: Ptr Nethandler -> IO ()

  -- Files.
  foreign import ccall safe "createFromFileName" c_createFile :: CString -> IO (Ptr Ds)

  -- TLS sessions.
  foreign import ccall interruptible "startSockTls" c_startSockTls :: Ptr Ds -> CString -> CString -> CString -> IO (Ptr TlsDs)
  foreign import ccall safe "closeTls" c_closeTls :: Ptr TlsDs -> IO (Ptr Ds)

  -- Descriptor teardown (the two commented-out bindings are unused).
  --foreign import ccall safe "getFd" c_getFd :: Ptr Ds -> IO CInt
  --foreign import ccall safe "getTlsFd" c_getTlsFd :: Ptr TlsDs -> IO CInt
  foreign import ccall safe "closeFd" c_closeFd :: CInt -> IO ()
  foreign import ccall safe "prepareToClose" c_prepareToClose :: Ptr Ds -> IO CInt

  -- Sending.
  foreign import ccall interruptible "sendDs" c_send :: Ptr Ds -> Ptr CChar -> CInt -> IO CInt
  foreign import ccall interruptible "stdDsSend" c_sendStd :: Ptr CChar -> CInt -> IO CInt
  foreign import ccall interruptible "tlsDsSend" c_sendTls :: Ptr TlsDs -> Ptr CChar -> CInt -> IO CInt

  -- Receiving.
  foreign import ccall interruptible "recvDs" c_recv :: Ptr Ds -> Ptr CChar -> CInt -> IO CInt
  foreign import ccall interruptible "stdDsRecv" c_recvStd :: Ptr CChar -> CInt -> IO CInt
  foreign import ccall interruptible "tlsDsRecv" c_recvTls :: Ptr TlsDs -> Ptr CChar -> CInt -> IO CInt