-- MonadThrow / MonadCatch
{-# LANGUAGE
    ConstraintKinds
  , DefaultSignatures
  , FlexibleInstances
  , FunctionalDependencies
  , LambdaCase
  , MultiParamTypeClasses
  , TypeFamilies
  , UndecidableInstances #-}
{-# OPTIONS_GHC -Wall #-}
import Control.Monad.Error hiding (MonadError)
import Control.Monad.Trans.Reader
import Control.Monad.Trans.State

-- | Monads in which an error of type @e@ can be raised.
--
-- The functional dependency @m -> e@ fixes a single error type per monad.
--
-- For a transformer stack @t n@, the default implementation lifts 'throw'
-- from the inner monad @n@. Instances for such stacks can therefore be
-- written with an empty body.
--
-- The stack shape is stated with the equality constraint @m ~ t n@ instead
-- of giving the default the type @e -> t m a@. Since GHC 8.2, a default
-- signature may differ from the method's own type only in its context, and
-- the @t m a@ form is rejected. The equality constraint needs the
-- @TypeFamilies@ (or @GADTs@) extension.
class Monad m => MonadThrow e m | m -> e where
  -- | Abort the current computation with the given error.
  throw :: e -> m a
  default throw :: (m ~ t n, MonadTrans t, MonadThrow e n) => e -> m a
  throw = lift . throw

-- | Handling errors of type @e@ that are raised by a @t m@ computation.
--
-- The handler, and therefore the overall result, runs in a possibly
-- different stack @t' m'@. For example, a handler may move from
-- @ErrorT e@ to @ErrorT e'@.
--
-- The functional dependencies have two parts. The failing stack @t m@
-- determines the error type @e@. The result stack together with @e@
-- determines the failing stack.
class (Monad m, Monad m', MonadTrans t, MonadTrans t')
      => MonadCatch e m m' t t'
       | t m -> e
       , t' m' e -> t m
  where
    catch
      :: t m a           -- ^ computation that may throw an @e@
      -> (e -> t' m' a)  -- ^ handler, run in the result stack
      -> t' m' a

-- | In 'ErrorT', throwing is mtl's 'throwError'.
instance (Monad m, Error e) => MonadThrow e (ErrorT e m) where
  throw = throwError
-- | Catching in 'ErrorT'. The failing computation is run in the base monad
-- @m@.
--
-- * On 'Left', the error goes to the handler. The handler's own 'ErrorT'
--   (possibly with a different error type @e'@) supplies the result.
-- * On 'Right', the value is passed through unchanged.
instance (Monad m, Error e, Error e')
      => MonadCatch e m m (ErrorT e) (ErrorT e') where
  catch action handler = ErrorT $ do
    outcome <- runErrorT action
    case outcome of
      Left err -> runErrorT (handler err)
      Right x  -> return (Right x)

-- | Throwing under 'ReaderT' lifts 'throw' from the inner monad @m@.
instance MonadThrow e m => MonadThrow e (ReaderT r m) where
  throw = lift . throw
-- | Catching under 'ReaderT'. The computation and the handler both read the
-- same environment, and the actual catching is delegated to the inner
-- stacks' 'catch'.
instance ( MonadCatch e n n' u u'
         , Monad (u n)
         , Monad (u' n')
         ) => MonadCatch e (u n) (u' n') (ReaderT r) (ReaderT r) where
  catch action handler = ReaderT $ \env ->
    let recover err = runReaderT (handler err) env
    in  runReaderT action env `catch` recover

-- | Throwing under 'StateT' lifts 'throw' from the inner monad @m@.
instance MonadThrow e m => MonadThrow e (StateT s m) where
  throw = lift . throw
-- | Catching under 'StateT', delegated to the inner stacks' 'catch'.
--
-- The handler starts from the state that was current when the computation
-- began. Any state changes made by the failed computation are therefore
-- discarded.
instance ( MonadCatch e n n' u u'
         , Monad (u n)
         , Monad (u' n')
         ) => MonadCatch e (u n) (u' n') (StateT s) (StateT s) where
  catch action handler = StateT $ \s0 ->
    let recover err = runStateT (handler err) s0
    in  runStateT action s0 `catch` recover