Browse Source

Added ByteStringIO

Marcos Dumay de Medeiros 8 years ago
parent
commit
0c43866b17
2 changed files with 79 additions and 12 deletions
  1. 50 3
      src/System/IO/Uniform/Targets.hs
  2. 29 9
      test/Targets.hs

+ 50 - 3
src/System/IO/Uniform/Targets.hs

@@ -7,9 +7,10 @@
 module System.IO.Uniform.Targets (
   TlsSettings(..),
   UniformIO(..),
-  SocketIO, FileIO, StdIO, TlsIO, SomeIO(..),
+  SocketIO, FileIO, StdIO, TlsIO, SomeIO(..), ByteStringIO,
   BoundedPort, connectTo, connectToHost, bindPort, accept, closePort,
   openFile, getPeer,
+  withByteStringIO, withByteStringIO',
   mapOverInput)
        where
 
@@ -20,11 +21,15 @@ import Foreign.C.Error
 import qualified Data.IP as IP
 import Data.ByteString (ByteString)
 import qualified Data.ByteString as BS
+import qualified Data.ByteString.Lazy as LBS
+import qualified Data.ByteString.Builder as BSBuild
 import qualified Data.List as L
 import Control.Exception
 import Control.Applicative ((<$>))
+import Data.Monoid (mappend)
 import qualified Network.Socket as Soc
 import System.IO.Error
+import Control.Concurrent.MVar
 
 import Data.Default.Class
 
@@ -192,6 +197,28 @@ instance UniformIO TlsIO where
   startTls _ s = return s
   isSecure _ = True
 
+-- | Wrapper that does UniformIO that reads and writes on the memory.
+data ByteStringIO = ByteStringIO {bsioinput :: MVar (ByteString, Bool), bsiooutput :: MVar BSBuild.Builder}
+instance UniformIO ByteStringIO where
+  uRead s n = do
+    (i, eof) <- takeMVar . bsioinput $ s
+    if eof
+    then do
+      putMVar (bsioinput s) (i, eof)
+      ioError $ mkIOError eofErrorType "read past end of input" Nothing Nothing
+    else do
+      let (r, i') = BS.splitAt n i
+      let eof' = (BS.null r && n > 0)
+      putMVar (bsioinput s) (i', eof')
+      return r
+  uPut s t = do
+    o <- takeMVar . bsiooutput $ s
+    let o' = mappend o $ BSBuild.byteString t
+    putMVar (bsiooutput s) o'
+  uClose _ = return ()
+  startTls _ _ = return . TlsIO $ nullPtr
+  isSecure _ = True
+
 -- | connectToHost hostName port
 --
 --  Connects to the given host and port.
@@ -299,17 +326,37 @@ closeFd (Fd f) = c_closeFd f
 closePort :: BoundedPort -> IO ()
 closePort p = c_closePort (lis p)
 
+-- | withByteStringIO input f
+--   Runs f with a ByteStringIO that has the given input, returns f's output and
+--   the ByteStringIO output.
+withByteStringIO :: ByteString -> (ByteStringIO -> IO a) -> IO (a, LBS.ByteString)
+withByteStringIO input f = do
+  ivar <- newMVar (input, False)
+  ovar <- newMVar . BSBuild.byteString $ BS.empty
+  let bsio = ByteStringIO ivar ovar
+  a <- f bsio
+  out <- takeMVar . bsiooutput $ bsio
+  return (a, BSBuild.toLazyByteString out)
+
+-- | The same as withByteStringIO, but returns an strict ByteString
+withByteStringIO' :: ByteString -> (ByteStringIO -> IO a) -> IO (a, ByteString)
+withByteStringIO' input f = do
+  (a, t) <- withByteStringIO input f
+  return (a, LBS.toStrict t)
+
 -- | mapOverInput io block_size f initial
 --   Reads io untill the end of file, evaluating a(i) <- f a(i-1) read_data
 --   where a(0) = initial and the last value after io reaches EOF is returned.
 --
 --   Notice that the length of read_data might not be equal block_size.
-mapOverInput :: UniformIO io => io -> Int -> (a -> ByteString -> IO a) -> a -> IO a
+mapOverInput :: forall a io. UniformIO io => io -> Int -> (a -> ByteString -> IO a) -> a -> IO a
 mapOverInput io block f initial = do
   a <- tryIOError $ uRead io block
   case a of
     Left e -> if isEOFError e then return initial else throw e -- EOF
-    Right dt -> mapOverInput io block f (f initial dt)
+    Right dt -> do
+      i <- f initial dt
+      mapOverInput io block f i
 
 
 foreign import ccall interruptible "getPort" c_getPort :: CInt -> IO (Ptr Nethandler)

+ 29 - 9
test/Targets.hs

@@ -8,12 +8,15 @@ import Control.Concurrent(forkIO)
 import qualified System.IO.Uniform as U
 import System.Timeout (timeout)
 import qualified Data.ByteString.Char8 as C8
+import Data.ByteString (ByteString)
+import qualified Data.ByteString as BS
 
 tests :: IO [Test]
 tests = return [
   simpleTest "network" testNetwork,
   simpleTest "file" testFile,
-  simpleTest "network TLS" testTls
+  simpleTest "network TLS" testTls,
+  simpleTest "byte string" testBS
   ]
 
 testNetwork :: IO Progress
@@ -26,14 +29,14 @@ testNetwork = do
     U.uClose s
     return ()
   r' <- timeout 1000000 $ do
-        s <- U.connectToHost "127.0.0.1" 8888
-        let l = "abcdef\n"
-        U.uPut s l
-        l' <- U.uRead s 100
-        U.uClose s
-        if l == l'
-          then return . Finished $ Pass
-          else return . Finished . Fail . C8.unpack $ l'
+    s <- U.connectToHost "127.0.0.1" 8888
+    let l = "abcdef\n"
+    U.uPut s l
+    l' <- U.uRead s 100
+    U.uClose s
+    if l == l'
+      then return . Finished $ Pass
+      else return . Finished . Fail . C8.unpack $ l'
   U.closePort recv
   case r' of
     Just r -> return r
@@ -78,3 +81,20 @@ testTls = do
   case r' of
     Just r -> return r
     Nothing -> return . Finished . Fail $ "Execution blocked"
+
+testBS :: IO Progress
+testBS = do
+  let dt = "Some data to test ByteString"
+  (len, echo) <- U.withByteStringIO' dt (
+    \io -> let
+      count = countAndEcho io :: Int -> ByteString -> IO Int
+      in U.mapOverInput io 2 count 0
+    ) :: IO (Int, ByteString)
+  if dt /= echo || BS.length dt /= len
+    then return . Finished . Fail $ "Failure on ByteStringIO test"
+    else return . Finished $ Pass
+  where
+    countAndEcho :: U.UniformIO io => io -> Int -> ByteString -> IO Int
+    countAndEcho io initial dt = do
+      U.uPut io dt
+      return $ initial + BS.length dt