New monads/MonadExit

From HaskellWiki
Jump to navigation Jump to search

The Exit monad provides short-circuiting for complex program flow logic.

If you are using CPS, MonadCont, or LogicT only for this purpose, the Exit monad will likely simplify your program considerably.

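Note: Now that the restriction on the Left type has been removed (the Monad instance for Either e used to require an Error constraint on e), the standard Either type can be used for this purpose, and a separate Exit monad is no longer needed. For a monad transformer, use the version of EitherT defined in the either package; in newer code, ExceptT from the transformers package fills the same role.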

The code

{-# OPTIONS_GHC -fglasgow-exts #-}

-- A monad that provides short-circuiting for complex program flow logic.

module Control.Monad.Exit (
  MonadExit(exitWith),
  Exit,
  runExit,
  runExitMaybe,
  ExitT,
  runExitT,
  runExitTMaybe,
  module Control.Monad,
  module Control.Monad.Trans
) where

import Control.Applicative (Applicative (..))
import Control.Monad
import Control.Monad.Error
import Control.Monad.List
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Trans
import Control.Monad.Writer
import qualified System.Exit as Sys (exitWith, ExitCode)

-- The MonadExit class

-- | Monads whose computations can stop early with an exit value of type @e@.
--
-- The functional dependency @m -> e@ means the monad determines the exit
-- type, so @exitWith@ calls need no extra type annotation.
class Monad m => MonadExit e m | m -> e where
  -- | Stop the computation, using the argument as the exit value. The result
  -- type @a@ is fully polymorphic; none of the instances in this module ever
  -- produce a value of that type.
  exitWith :: e -> m a

-- | In 'IO' the exit value is a process 'Sys.ExitCode'; exiting simply
-- delegates to 'System.Exit.exitWith'.
instance MonadExit Sys.ExitCode IO where
  exitWith code = Sys.exitWith code

-- The Exit monad

-- | Result of an 'Exit' computation step: either it is still running and
-- carries a value ('Continue'), or it has stopped early with an exit value
-- ('Exit'). Once an 'Exit' is produced, '>>=' passes it through unchanged
-- and never calls the continuation.
data Exit e a = Continue a | Exit e

-- | Extract the exit value from an 'Exit' computation.
-- Partial: calls 'error' if the computation finished with 'Continue',
-- i.e. it never called 'exitWith'.
runExit :: Exit e a -> e
runExit m = case m of
  Exit code  -> code
  Continue _ -> error "Exit monad did not exit."

-- | Total counterpart of 'runExit': @Just@ the exit value, or @Nothing@
-- when the computation finished without exiting.
runExitMaybe :: Exit e b -> Maybe e
runExitMaybe m = case m of
  Exit code  -> Just code
  Continue _ -> Nothing

-- | Apply a function to a continuing value; an exited computation keeps its
-- exit value and is otherwise left alone.
instance Functor (Exit e) where
  fmap f m = case m of
    Continue v -> Continue (f v)
    Exit code  -> Exit code

-- | 'Applicative' is a superclass of 'Monad' since base 4.8 (the
-- Applicative-Monad Proposal); without this instance the 'Monad' instance
-- below is rejected. An 'Exit' on the left short-circuits, matching '>>='.
instance Applicative (Exit e) where
  pure = Continue
  Continue f <*> x = fmap f x
  Exit     x <*> _ = Exit x

-- | Sequencing: a 'Continue' feeds its value to the next step, while an
-- 'Exit' skips every remaining step and propagates its exit value.
instance Monad (Exit e) where
  -- Canonical definition; also avoids -Wnoncanonical-monad-instances.
  return = pure
  (Continue x) >>= f = f x
  (Exit     x) >>= _ = Exit x

-- | In the pure 'Exit' monad, exiting just builds the 'Exit' constructor.
instance MonadExit e (Exit e) where
  exitWith code = Exit code

-- The ExitT monad

-- | Transformer version of 'Exit': an inner-monad action whose result says
-- whether the computation continues or has exited. Inner effects run in
-- order up to the first 'Exit'; later steps are skipped.
newtype ExitT e m a = ExitT (m (Exit e a))

-- | Run an 'ExitT' computation and return its exit value in the inner monad.
-- Partial: yields 'error' if the computation finished without exiting.
runExitT :: Monad m => ExitT e m a -> m e
runExitT (ExitT inner) = inner >>= finish
  where
    finish (Exit code) = return code
    finish _           = error "ExitT monad did not exit."

-- | Total counterpart of 'runExitT': the inner monad yields @Just@ the exit
-- value, or @Nothing@ if the computation never exited.
runExitTMaybe :: Monad m => ExitT e m a -> m (Maybe e)
runExitTMaybe (ExitT inner) = inner >>= return . runExitMaybe

-- | Map over the continuing value after the inner action runs; an exit value
-- is passed through unchanged.
instance Monad m => Functor (ExitT e m) where
  fmap f (ExitT inner) = ExitT (inner >>= step)
    where
      step (Continue v) = return (Continue (f v))
      step (Exit code)  = return (Exit code)

-- | 'Applicative' is a superclass of 'Monad' since base 4.8 (the
-- Applicative-Monad Proposal); without this instance the 'Monad' instance
-- below is rejected. '<*>' is defined via 'ap', so inner effects are
-- sequenced and short-circuit exactly as in '>>='.
instance Monad m => Applicative (ExitT e m) where
  pure  = ExitT . return . Continue
  (<*>) = ap

-- | Sequencing: run the inner action; on 'Continue' hand the value to the
-- next step, on 'Exit' stop and skip the rest of the computation.
instance Monad m => Monad (ExitT e m) where
  -- Canonical definition; also avoids -Wnoncanonical-monad-instances.
  return = pure
  (ExitT x) >>= f = ExitT $ do
    y <- x
    case y of
      Continue z -> let ExitT w = f z in w
      Exit     z -> return $ Exit z

-- | Exiting in 'ExitT' returns an 'Exit' result from the inner monad
-- without running any further inner effects.
instance Monad m => MonadExit e (ExitT e m) where
  exitWith code = ExitT (return (Exit code))

-- | Lifting runs the inner action and marks its result as continuing.
instance MonadTrans (ExitT e) where
  lift action = ExitT (liftM Continue action)

-- Lifted instances of other monad classes from inside ExitT

-- TODO: Put a MonadFix instance here.

-- | IO actions are lifted into the inner monad, then into 'ExitT'.
instance MonadIO m => MonadIO (ExitT e m) where
  liftIO io = lift (liftIO io)

-- | Choice is delegated to the inner monad's 'mplus' on the wrapped actions.
--
-- NOTE(review): since base 4.8, 'Alternative' is a superclass of
-- 'MonadPlus', so modern GHC also needs an 'Alternative (ExitT e m)'
-- instance (e.g. empty = lift mzero, (<|>) = mplus on the wrapped actions).
-- Confirm against the target GHC/base version.
instance MonadPlus m => MonadPlus (ExitT e m) where
  mzero = lift mzero
  mplus (ExitT left) (ExitT right) = ExitT (mplus left right)

-- | State operations lifted from an underlying 'State' monad.
--
-- NOTE(review): in mtl 2.x, @State s@ is a synonym for
-- @StateT s Identity@, so this instance overlaps the 'StateT' instance that
-- follows. It appears to have been written for mtl 1.x, where 'State' was a
-- distinct type; confirm against the mtl version in use.
instance MonadState s (ExitT e (State s)) where
  get       = lift get
  put state = lift (put state)

-- | State operations lifted from an underlying 'StateT' transformer.
instance Monad m => MonadState s (ExitT e (StateT s m)) where
  get       = lift get
  put state = lift (put state)

-- | Errors from an underlying 'Either': throwing is lifted, and catching
-- runs the handler's unwrapped 'ExitT' action inside the inner 'catchError'.
instance Error err => MonadError err (ExitT e (Either err)) where
  throwError err = lift (throwError err)
  catchError (ExitT action) handler = ExitT (catchError action (unwrap . handler))
    where
      unwrap (ExitT inner) = inner

-- | Errors from an underlying 'ErrorT': throwing is lifted, and catching
-- runs the handler's unwrapped 'ExitT' action inside the inner 'catchError'.
instance (Error err, Monad m) => MonadError err (ExitT e (ErrorT err m)) where
  throwError err = lift (throwError err)
  catchError (ExitT action) handler = ExitT (catchError action (unwrap . handler))
    where
      unwrap (ExitT inner) = inner

-- MonadExit instances for other monad transformers

-- | Exiting from a 'StateT' layered over 'Exit' lifts the exit through
-- the state layer.
instance MonadExit e (StateT s (Exit e)) where
  exitWith code = lift (exitWith code)

-- | Exiting from a 'StateT' layered over 'ExitT' lifts the exit through
-- the state layer.
instance Monad m => MonadExit e (StateT s (ExitT e m)) where
  exitWith code = lift (exitWith code)

-- | Exiting from a 'ListT' layered over 'Exit' lifts the exit through
-- the list layer.
instance MonadExit e (ListT (Exit e)) where
  exitWith code = lift (exitWith code)

-- | Exiting from a 'ListT' layered over 'ExitT' lifts the exit through
-- the list layer.
instance Monad m => MonadExit e (ListT (ExitT e m)) where
  exitWith code = lift (exitWith code)

-- | Exiting from a 'ReaderT' layered over 'Exit' lifts the exit through
-- the environment layer.
instance MonadExit e (ReaderT r (Exit e)) where
  exitWith code = lift (exitWith code)

-- | Exiting from a 'ReaderT' layered over 'ExitT' lifts the exit through
-- the environment layer.
instance Monad m => MonadExit e (ReaderT r (ExitT e m)) where
  exitWith code = lift (exitWith code)

-- | Exiting from a 'WriterT' layered over 'Exit' lifts the exit through
-- the output layer.
instance Monoid w => MonadExit e (WriterT w (Exit e)) where
  exitWith code = lift (exitWith code)

-- | Exiting from a 'WriterT' layered over 'ExitT' lifts the exit through
-- the output layer.
instance (Monoid w, Monad m) => MonadExit e (WriterT w (ExitT e m)) where
  exitWith code = lift (exitWith code)

-- | Exiting from an 'ErrorT' layered over 'Exit' lifts the exit through
-- the error layer.
instance Error err => MonadExit e (ErrorT err (Exit e)) where
  exitWith code = lift (exitWith code)

-- | Exiting from an 'ErrorT' layered over 'ExitT' lifts the exit through
-- the error layer.
instance (Error err, Monad m) => MonadExit e (ErrorT err (ExitT e m)) where
  exitWith code = lift (exitWith code)