Przeglądaj źródła

Targets now must have EOF signaling

Marcos Dumay de Medeiros 8 lat temu
rodzic
commit
78b2cd72ee

+ 30 - 3
cbits/ds.c

@@ -100,6 +100,7 @@ nethandler getPort(const int port){
 
 ds createFromFile(int f){
   ds d = (ds)malloc(sizeof(s_ds));
+  d->eof = 0;
   d->tp = file;
   d->fd = f;
   return d;
@@ -115,6 +116,7 @@ ds createFromFileName(const char *f){
 
 ds createFromHandler(nethandler h){
   ds d = (ds)malloc(sizeof(s_ds));
+  d->eof = 0;
   d->tp = sock;
   unsigned int s = sizeof(d->peer);
   d->fd = accept(h->fd, (struct sockaddr*)&(d->peer), &s);
@@ -127,6 +129,7 @@ ds createFromHandler(nethandler h){
 
 ds createToHost(struct sockaddr *add, const int add_size, const int ipv6){
   ds d = (ds)malloc(sizeof(s_ds));
+  d->eof = 0;
   d->tp = sock;
   if(ipv6){
     d->fd = socket(AF_INET6, SOCK_STREAM, 0);
@@ -184,6 +187,13 @@ int getPeer(ds d, unsigned long *ipv4peer, unsigned char ipv6peer[16], int *ipv6
   return port;
 }
 
+int *getStd(){
+  return (int*) malloc(sizeof(int));
+}
+void closeStd(int *d){
+  free(d);
+}
+
 int sendDs(ds d, const char *b, const int s){
   return write(d->fd, b, s);
 }
@@ -195,13 +205,17 @@ int stdDsSend(const char *b, const int s){
 }
 
 int recvDs(ds d, char *b, const int s){
-  return read(d->fd, b, s);
+  int v = read(d->fd, b, s);
+  d->eof = v == 0;
+  return v;
 }
 int tlsDsRecv(tlsDs d, char *b, const int s){
   return SSL_read(d->s, b, s);
 }
-int stdDsRecv(char *b, const int s){
-  return read(0, b, s);
+int stdDsRecv(int *d, char *b, const int s){
+  int v = read(0, b, s);
+  *d = v == 0;
+  return v;
 }
 
 
@@ -265,6 +279,7 @@ tlsDs startSockTls(ds d, const char *cert, const char *key, const char *dh){
       return clear(ctx);
   }
   tlsDs t = (tlsDs)malloc(sizeof(s_tlsDs));
+  t->eof = 0;
   t->original = d;
   if(!(t->s = SSL_new(ctx))){
     int f = prepareToClose(d);
@@ -316,3 +331,15 @@ int getTlsFd(tlsDs t){
 void closeFd(int fd){
   close(fd);
 }
+
+int isStdEof(int *d){
+  return *d;
+}
+
+int isDsEof(ds d){
+  return d-> eof;
+}
+
+int isTlsEof(tlsDs d){
+  return d-> eof;
+}

+ 8 - 1
cbits/ds.h

@@ -8,6 +8,7 @@ typedef enum {
 
 typedef struct {
   int fd;
+  int eof;
   dstype tp;
   int ipv6;
   int server;
@@ -25,6 +26,7 @@ typedef struct {
   dstype tp;
   void *original;
   SSL *s;
+  int eof;
 } *tlsDs, s_tlsDs;
 
 nethandler getIPv4Port(const int port);
@@ -50,9 +52,14 @@ int recvDs(ds, char[], const int);
 int tlsDsSend(tlsDs, const char[const], const int);
 int tlsDsRecv(tlsDs, char[], const int);
 
+int *getStd();
 int stdDsSend(const char[const], const int);
-int stdDsRecv(char[], const int);
+int stdDsRecv(int*, char[], const int);
+void closeStd(int*);
 
 int getFd(ds);
 int getTlsFd(tlsDs);
+int isStdEOF(int*);
+int isDsEOF(ds);
+int isTlsEOF(tlsDs);
 void closeFd(int);

+ 27 - 16
src/System/IO/Uniform.hs

@@ -9,17 +9,18 @@ module System.IO.Uniform (
   UniformIO(..),
   TlsSettings(..),
   SomeIO(..),
-  mapOverInput,
+  foldOverInput,
   uGetContents
   ) where
 
 import Data.ByteString (ByteString)
 import qualified Data.ByteString.Lazy as LBS
-import Control.Exception
-import System.IO.Error
+import Control.Monad.IO.Class
 
 import Data.Default.Class
 
+import Debug.Trace
+
 -- | Typeclass for uniform IO targets.
 class UniformIO a where
   {- |
@@ -49,7 +50,9 @@ class UniformIO a where
   --
   --  Indicates whether the data written or read from fd is secure at transport.
   isSecure :: a -> Bool
-  
+  -- | True when the target is at end of file
+  isEOF :: a -> IO Bool
+
 -- | A type that wraps any type in the UniformIO class.
 data SomeIO = forall a. (UniformIO a) => SomeIO a
 
@@ -59,6 +62,7 @@ instance UniformIO SomeIO where
   uClose (SomeIO s) = uClose s
   startTls set (SomeIO s) = SomeIO <$> startTls set s
   isSecure (SomeIO s) = isSecure s
+  isEOF (SomeIO s) = isEOF s
 
 -- | Settings for starttls functions.
 data TlsSettings = TlsSettings {tlsPrivateKeyFile :: String, tlsCertificateChainFile :: String, tlsDHParametersFile :: String} deriving (Read, Show)
@@ -67,28 +71,35 @@ instance Default TlsSettings where
   def = TlsSettings "" "" ""
   
 {- |
-mapOverInput io block_size f initial
+foldOverInput io block_size f initial
 
 Reads io untill the end of file, evaluating a(i) <- f a(i-1) read_data
 where a(0) = initial and the last value after io reaches EOF is returned.
 
 Notice that the length of read_data might not be equal block_size.
 -}
-mapOverInput :: forall a io. UniformIO io => io -> Int -> (a -> ByteString -> IO a) -> a -> IO a
-mapOverInput io block f initial = do
-  a <- tryIOError $ uRead io block
-  case a of
-    Left e -> if isEOFError e then return initial else throw e -- EOF
-    Right dt -> do
-      i <- f initial dt
-      mapOverInput io block f i
+foldOverInput :: forall a io. UniformIO io => io -> Int -> (a -> ByteString -> IO a) -> a -> IO a
+foldOverInput io block f initial = do
+  eof <- liftIO $ isEOF io
+  if eof
+    then return initial
+    else do
+    dt <- uRead io block
+    traceShowM dt
+    i <- f initial dt
+    foldOverInput io block f i
 
 {- |
 Returns the entire contents recieved from this target.
 -}
 uGetContents :: UniformIO io => io -> Int -> IO LBS.ByteString
-uGetContents io block = LBS.fromChunks <$> mapOverInput io block atEnd []
+uGetContents io block = LBS.fromChunks <$> concatData
   where
-    atEnd :: [ByteString] -> ByteString -> IO [ByteString]
-    atEnd bb b = return $ bb ++ [b]
+    concatData = do
+      eof <- liftIO $ isEOF io
+      if eof
+        then return []
+        else do
+        dt <- uRead io block
+        (dt :) <$> concatData
 

+ 14 - 14
src/System/IO/Uniform/ByteString.hs

@@ -12,27 +12,26 @@ import Data.ByteString (ByteString)
 import qualified Data.ByteString as BS
 import qualified Data.ByteString.Lazy as LBS
 import qualified Data.ByteString.Builder as BSBuild
-import System.IO.Error
 import Control.Concurrent.MVar
 
 --import Data.Default.Class
 
 --import System.Posix.Types (Fd(..))
 
--- | Wrapper that does UniformIO that reads and writes on the memory.
-data ByteStringIO = ByteStringIO {bsioinput :: MVar (ByteString, Bool), bsiooutput :: MVar BSBuild.Builder}
+{- |
+Wrapper that does UniformIO that reads and writes on the memory.
+
+Input and output may be queried and replaced during the execution of
+the target, with obviously undefined results in case of concurrent
+execution.
+-}
+data ByteStringIO = ByteStringIO {bsioinput :: MVar ByteString, bsiooutput :: MVar BSBuild.Builder}
 instance UniformIO ByteStringIO where
   uRead s n = do
-    (i, eof) <- takeMVar . bsioinput $ s
-    if eof
-    then do
-      putMVar (bsioinput s) (i, eof)
-      ioError $ mkIOError eofErrorType "read past end of input" Nothing Nothing
-    else do
-      let (r, i') = BS.splitAt n i
-      let eof' = BS.null r && n > 0
-      putMVar (bsioinput s) (i', eof')
-      return r
+    i <- takeMVar . bsioinput $ s
+    let (r, i') = BS.splitAt n i
+    putMVar (bsioinput s) i'
+    return r
   uPut s t = do
     o <- takeMVar . bsiooutput $ s
     let o' = mappend o $ BSBuild.byteString t
@@ -40,13 +39,14 @@ instance UniformIO ByteStringIO where
   uClose _ = return ()
   startTls _ = return
   isSecure _ = True
+  isEOF t = withMVar (bsioinput t) $ return . BS.null
 
 -- | withByteStringIO' input f
 --   Runs f with a ByteStringIO that has the given input, returns f's output and
 --   the ByteStringIO output.
 withByteStringIO' :: ByteString -> (ByteStringIO -> IO a) -> IO (a, LBS.ByteString)
 withByteStringIO' input f = do
-  ivar <- newMVar (input, False)
+  ivar <- newMVar input
   ovar <- newMVar . BSBuild.byteString $ BS.empty
   let bsio = ByteStringIO ivar ovar
   a <- f bsio

+ 9 - 2
src/System/IO/Uniform/External.hs

@@ -19,7 +19,7 @@ data SocketIO = SocketIO {sock :: Ptr Ds} | TlsSocketIO {bio :: Ptr TlsDs}
 -- | UniformIO type for file IO.
 newtype FileIO = FileIO {fd :: Ptr Ds}
 -- | UniformIO that reads from stdin and writes to stdout.
-data StdIO = StdIO
+data StdIO = StdIO {eofMark :: Ptr CInt}
 
 closeFd :: Fd -> IO ()
 closeFd (Fd f) = c_closeFd f
@@ -46,10 +46,17 @@ foreign import ccall safe "prepareToClose" c_prepareToClose :: Ptr Ds -> IO CInt
 foreign import ccall safe "closeHandler" c_closePort :: Ptr Nethandler -> IO ()
 foreign import ccall safe "closeTls" c_closeTls :: Ptr TlsDs -> IO (Ptr Ds)
 
+foreign import ccall safe "getStd" c_getStd :: IO (Ptr CInt)
+foreign import ccall safe "closeStd" c_closeStd :: Ptr CInt -> IO ()
+
+foreign import ccall safe "isStdEof" c_isStdEof :: Ptr CInt -> IO Bool
+foreign import ccall safe "isDsEof" c_isDsEof :: Ptr Ds -> IO Bool
+foreign import ccall safe "isTlsEof" c_isTlsEof :: Ptr TlsDs -> IO Bool
+
 foreign import ccall interruptible "sendDs" c_send :: Ptr Ds -> Ptr CChar -> CInt -> IO CInt
 foreign import ccall interruptible "stdDsSend" c_sendStd :: Ptr CChar -> CInt -> IO CInt
 foreign import ccall interruptible "tlsDsSend" c_sendTls :: Ptr TlsDs -> Ptr CChar -> CInt -> IO CInt
 
 foreign import ccall interruptible "recvDs" c_recv :: Ptr Ds -> Ptr CChar -> CInt -> IO CInt
-foreign import ccall interruptible "stdDsRecv" c_recvStd :: Ptr CChar -> CInt -> IO CInt
+foreign import ccall interruptible "stdDsRecv" c_recvStd :: Ptr CInt -> Ptr CChar -> CInt -> IO CInt
 foreign import ccall interruptible "tlsDsRecv" c_recvTls :: Ptr TlsDs -> Ptr CChar -> CInt -> IO CInt

+ 1 - 0
src/System/IO/Uniform/File.hs

@@ -35,6 +35,7 @@ instance UniformIO FileIO where
     closeFd f
   startTls _ = return
   isSecure _ = True
+  isEOF s = c_isDsEof . fd $ s
   
   
 -- | Open a file for bidirectional IO.

+ 2 - 1
src/System/IO/Uniform/HandlePair.hs

@@ -7,7 +7,7 @@ module System.IO.Uniform.HandlePair (
   fromHandles
   ) where
 
-import System.IO (Handle, hClose)
+import System.IO (Handle, hClose, hIsEOF)
 import System.IO.Uniform
 import qualified Data.ByteString as BS
 
@@ -34,3 +34,4 @@ instance UniformIO HandlePair where
     hClose o
   startTls _ = return
   isSecure _ = True
+  isEOF (HandlePair i _) = hIsEOF i

+ 2 - 1
src/System/IO/Uniform/Network.hs

@@ -74,7 +74,8 @@ instance UniformIO SocketIO where
   startTls _ s@(TlsSocketIO _) = return s
   isSecure (SocketIO _) = False
   isSecure (TlsSocketIO _) = True
-
+  isEOF (SocketIO s) = c_isDsEof s
+  isEOF (TlsSocketIO t) = c_isTlsEof t
 
 -- | connectToHost hostName port
 --

+ 10 - 4
src/System/IO/Uniform/Std.hs

@@ -1,6 +1,7 @@
 -- | UniformIO over stdin and stdout
 module System.IO.Uniform.Std (
-  StdIO(StdIO)
+  StdIO,
+  getStdIO
   ) where
 
 import System.IO.Uniform
@@ -13,9 +14,9 @@ import Control.Monad
 
 -- | UniformIO that reads from stdin and writes to stdout.
 instance UniformIO StdIO where
-  uRead _ n = allocaArray n (
+  uRead s n = allocaArray n (
     \b -> do
-      count <- c_recvStd b (fromIntegral n)
+      count <- c_recvStd (eofMark s) b (fromIntegral n)
       if count < 0
         then throwErrno "could not read"
         else BS.packCStringLen (b, fromIntegral count)
@@ -25,6 +26,11 @@ instance UniformIO StdIO where
       count <- c_sendStd str $ fromIntegral n
       when (count < 0) $ throwErrno "could not write"
     )
-  uClose _ = return ()
+  uClose s = c_closeStd . eofMark $ s
   startTls _ = return
   isSecure _ = True
+  isEOF s = c_isStdEof . eofMark $ s
+
+
+getStdIO :: IO StdIO
+getStdIO = StdIO <$> c_getStd

+ 4 - 11
src/System/IO/Uniform/Streamline.hs

@@ -46,7 +46,7 @@ module System.IO.Uniform.Streamline (
 import System.IO (stdout, Handle)
 import qualified System.IO.Uniform as S
 import qualified System.IO.Uniform.Network as N
-import qualified System.IO.Uniform.Std as Std
+import qualified System.IO.Uniform.Null as Null
 import System.IO.Uniform (UniformIO, SomeIO(..), TlsSettings)
 import System.IO.Uniform.Streamline.Scanner
 import Data.Default.Class
@@ -57,7 +57,6 @@ import Control.Monad.Trans.Control
 import Control.Monad
 import Control.Monad.Base
 import Control.Monad.IO.Class
-import System.IO.Error
 import Data.ByteString (ByteString)
 import qualified Data.ByteString as BS
 import qualified Data.ByteString.Lazy as LBS
@@ -66,13 +65,10 @@ import Data.IP (IP)
 
 import qualified Data.Attoparsec.ByteString as A
 
-import Debug.Trace
-
 -- | Internal state for a Streamline monad
 data StreamlineState = StreamlineState {str :: SomeIO, buff :: ByteString, targetEOF :: Bool, echo :: Maybe Handle, inLimit :: Int}
 instance Default StreamlineState where
-  -- | Will open StdIO
-  def = StreamlineState (SomeIO Std.StdIO) BS.empty False Nothing (-1)
+  def = StreamlineState (SomeIO Null.NullIO) BS.empty False Nothing (-1)
 
 -- | Monad that emulates character stream IO over block IO.
 newtype Streamline m a = Streamline {withTarget' :: StreamlineState -> m (a, StreamlineState)}
@@ -106,7 +102,7 @@ takeBuff = do
   readF
   Streamline $ \cl -> 
     let lim = inLimit cl
-        eof = targetEOF cl
+        --eof = targetEOF cl
         b = buff cl
     in if lim < 0 then return (b, cl{buff=""})
        else let (r, b') = BS.splitAt lim b
@@ -315,7 +311,7 @@ runAttoparsecAndReturn p = do
     continueResult c d dd = case c of
       A.Fail i _ msg -> do
         pushBuff $ BS.concat (reverse dd) `BS.append` i
-        return (BS.take (BS.length d - BS.length i) d, Left msg)
+        return ("", Left msg)
       A.Done i r -> do
         pushBuff i
         return (BS.concat (reverse dd) `BS.append`
@@ -381,9 +377,6 @@ Setting to Nothing will disable echo.
 echoTo :: Monad m => Maybe Handle -> Streamline m ()
 echoTo h = Streamline $ \cl -> return ((), cl{echo=h})
 
-eofError :: MonadIO m => String -> m a
-eofError msg = liftIO . ioError $ mkIOError eofErrorType msg Nothing Nothing
-
 instance Interruptible Streamline where
   type RSt Streamline a = (a, StreamlineState)
   resume f (a, st) = withTarget' (f a) st

+ 2 - 0
src/System/IO/Uniform/Timeout.hs

@@ -42,6 +42,7 @@ instance UniformIO FixedTimeout where
       Just r -> return $ FixedTimeout t r
       Nothing -> doTimeout
   isSecure (FixedTimeout _ u) = isSecure u
+  isEOF (FixedTimeout _ u) = isEOF u
 
 {- |
 Variable timeout, set at runtime.
@@ -71,6 +72,7 @@ instance UniformIO MVarTimeout where
       Just r -> return $ MVarTimeout t' r
       Nothing -> doTimeout
   isSecure (MVarTimeout _ u) = isSecure u
+  isEOF (MVarTimeout _ u) = isEOF u
 
 doTimeout :: IO a
 doTimeout = ioError $ userError "Timeout"

+ 1 - 1
test/Targets.hs

@@ -102,7 +102,7 @@ testBS = do
   (len, echo) <- withByteStringIO dt (
     \io -> let
       count = countAndEcho io :: Int -> ByteString -> IO Int
-      in mapOverInput io 2 count 0
+      in foldOverInput io 2 count 0
     ) :: IO (Int, ByteString)
   if dt /= echo || BS.length dt /= len
     then return . Finished . Fail $ "Failure on ByteStringIO test"

+ 2 - 0
uniform-io.cabal

@@ -41,6 +41,7 @@ library
       System.IO.Uniform.Network,
       System.IO.Uniform.File,
       System.IO.Uniform.Std,
+      System.IO.Uniform.Null,
       System.IO.Uniform.ByteString,
       System.IO.Uniform.HandlePair,
       System.IO.Uniform.Timeout,
@@ -76,6 +77,7 @@ library
       data-default-class >= 0.0.1,
       monad-control,
       transformers-base,
+      conduit,          
       interruptible
   
   -- Directories containing source files.