瀏覽代碼

Added SafeIO module

Marcos Dumay de Medeiros 8 年之前
父節點
當前提交
bbb6c63b3f
共有 3 個文件被更改,包括 70 次插入9 次删除
  1. 6 1
      interruptible.cabal
  2. 25 0
      src/Control/Monad/Trans/SafeIO.hs
  3. 39 8
      test/Test.hs

+ 6 - 1
interruptible.cabal

@@ -28,12 +28,16 @@ build-type:          Simple
 cabal-version:       >=1.10
 
 library
-  exposed-modules:     Control.Monad.Trans.Interruptible
+  exposed-modules:
+      Control.Monad.Trans.Interruptible
+      Control.Monad.Trans.SafeIO
   other-modules:       Control.Monad.Trans.Interruptible.Class
   other-extensions:    TypeFamilies
   build-depends:
       base >=4.7 && <4.9,
       transformers,
+      monad-control,
+      lifted-base,
       either
   hs-source-dirs:      src
   default-language:    Haskell2010
@@ -47,6 +51,7 @@ Test-suite all
     base >=4.7 && <5.0,
     Cabal >= 1.9.2,
     either,
+    transformers,
     interruptible
   ghc-options: -Wall -fno-warn-unused-do-bind -fwarn-incomplete-patterns -threaded
   default-language: Haskell2010

+ 25 - 0
src/Control/Monad/Trans/SafeIO.hs

@@ -0,0 +1,25 @@
+{-# LANGUAGE FlexibleContexts #-}
+{- |
+
+-}
+module Control.Monad.Trans.SafeIO where
+
+import System.IO.Error
+import Control.Monad.IO.Class
+import Control.Monad.Trans.Class
+import Control.Monad.Trans.Either
+import Control.Monad.Trans.Control
+import qualified Control.Exception.Lifted as Lift
+
+class IOErrorDerivation e where
+  coerceIOError :: IOError -> e
+
+safeIO :: (MonadIO m, IOErrorDerivation e) => IO a -> EitherT e m a
+safeIO io = (liftIO $ tryIOError io) >>= hoistResult
+
+safeCT :: (MonadBaseControl IO m, IOErrorDerivation e) => m a -> EitherT e m a
+safeCT f = (lift $ Lift.try f) >>= hoistResult
+
+hoistResult :: (IOErrorDerivation e, Monad m) => Either IOError a -> EitherT e m a
+hoistResult (Left e) = left . coerceIOError $ e
+hoistResult (Right v) = right v

+ 39 - 8
test/Test.hs

@@ -4,8 +4,11 @@ import Distribution.TestSuite
 import System.IO.Error
 
 import Control.Monad.Trans.Either
+import Control.Monad.Trans.State
 import Data.Either.Combinators
 import Control.Monad.Trans.Interruptible
+import Control.Monad.IO.Class
+import Control.Monad.Trans.SafeIO
 
 simpleTest :: String -> IO Progress -> Test
 simpleTest n t = 
@@ -25,13 +28,15 @@ simpleTest n t =
 
 tests :: IO [Test]
 tests = return [
-  simpleTest "resume" (tres),
-  simpleTest "resume2" (tres2),
-  simpleTest "resume3" (tres3),
-  simpleTest "resume4" (tres4),
-  simpleTest "resume5" (tres5),
-  simpleTest "intercalate1" (int1),
-  simpleTest "intercalate5" (int5)
+  simpleTest "resume" tres,
+  simpleTest "resume2" tres2,
+  simpleTest "resume3" tres3,
+  simpleTest "resume4" tres4,
+  simpleTest "resume5" tres5,
+  simpleTest "intercalate1" int1,
+  simpleTest "intercalate5" int5,
+  simpleTest "safeIO" tSafeIO,
+  simpleTest "safeCL" tSafeCT
   ]
 
 tres :: IO Progress
@@ -81,4 +86,30 @@ int5 = do
   let f = (\x y -> return $ x + y) :: Int -> Int -> EitherT () (EitherT () (EitherT () (EitherT () (EitherT () IO)))) Int
   r <- intercalateWith resume5 f [1, 2, 3] (map (Right . Right . Right . Right . Right) [10, 20])
   let v = map (fromRight 0 . fromRight (Left ()) . fromRight (Left ()) . fromRight (Left ()) . fromRight (Left ())) r
-  Finished <$> if v == [16, 26] then return Pass else return $ Fail $ "Wrong value: " ++ show v
+  Finished <$> if v == [16, 26] then return Pass else return . Fail $ "Wrong value: " ++ show v
+
+newtype Txt = Txt String
+instance IOErrorDerivation Txt where
+  coerceIOError = Txt . show
+
+tSafeIO :: IO Progress
+tSafeIO = do
+  let msg = "test"
+      err = show . userError $ msg
+  r <- runEitherT (safeIO . ioError . userError $ msg)
+  case r of
+    Left (Txt msg') -> Finished <$> if err == msg' then return Pass else return . Fail $ "Wrong error: " ++ msg'
+    Right _ -> return . Finished . Fail $ "Throwing error didn't create an error!"
+
+tSafeCT :: IO Progress
+tSafeCT = do
+  let msg = "test"
+      err = show . userError $ msg
+  r <- fst <$> runStateT (runEitherT (safeCT . stateError $ msg)) ()
+  case r of
+    Left (Txt msg') -> Finished <$> if err == msg' then return Pass else return . Fail $ "Wrong error: " ++ msg'
+    Right _ -> return . Finished . Fail $ "Throwing error didn't create an error!"
+  where
+    stateError :: String -> StateT () IO ()
+    stateError = liftIO . ioError . userError
+