Difference between revisions of "Amb"

From HaskellWiki
Jump to navigation Jump to search
m
m (update)
Line 10: Line 10:
 
import Control.Monad.Identity
 
import Control.Monad.Identity
   
type Point r a m = () -> AmbT r a m a
+
type Point r s m = () -> AmbT r s m s
newtype AmbT r a m b = AmbT { unAmbT :: StateT [Point r a m] (ContT r m) b }
+
newtype AmbT r s m a = AmbT { unAmbT :: StateT [Point r s m] (ContT r m) a }
type Amb r a = AmbT r a Identity
+
type Amb r s = AmbT r s Identity
   
instance MonadTrans (AmbT r a) where
+
instance MonadTrans (AmbT r s) where
 
lift = AmbT . lift . lift
 
lift = AmbT . lift . lift
   
instance (Monad m) => Monad (AmbT r a m) where
+
instance (Monad m) => Monad (AmbT r s m) where
 
AmbT a >>= b = AmbT $ a >>= unAmbT . b
 
AmbT a >>= b = AmbT $ a >>= unAmbT . b
 
return = AmbT . return
 
return = AmbT . return
   
backtrack :: (Monad m) => AmbT r a m b
+
backtrack :: (Monad m) => AmbT r s m a
 
backtrack = do xss <- AmbT get
 
backtrack = do xss <- AmbT get
 
case xss of
 
case xss of
Line 29: Line 29:
 
return undefined
 
return undefined
   
addPoint :: (Monad m) => Point r a m -> AmbT r a m ()
+
addPoint :: (Monad m) => Point r s m -> AmbT r s m ()
 
addPoint x = AmbT $ modify (x:)
 
addPoint x = AmbT $ modify (x:)
   
amb :: (Monad m) => [b] -> AmbT r a m b
+
amb :: (Monad m) => [a] -> AmbT r s m a
 
amb [] = backtrack
 
amb [] = backtrack
 
amb (x:xs) = ambCC $ \exit -> do
 
amb (x:xs) = ambCC $ \exit -> do
Line 39: Line 39:
 
where ambCC f = AmbT $ callCC $ \k -> unAmbT $ f $ AmbT . k
 
where ambCC f = AmbT $ callCC $ \k -> unAmbT $ f $ AmbT . k
   
cut :: (Monad m) => AmbT r a m ()
+
cut :: (Monad m) => AmbT r s m ()
 
cut = AmbT $ put []
 
cut = AmbT $ put []
   
runAmbT :: (Monad m) => AmbT r a m r -> m r
+
runAmbT :: (Monad m) => AmbT r s m r -> m r
 
runAmbT (AmbT a) = runContT (evalStateT a []) return
 
runAmbT (AmbT a) = runContT (evalStateT a []) return
   
runAmb :: Amb r a r -> r
+
runAmb :: Amb r s r -> r
 
runAmb = runIdentity . runAmbT
 
runAmb = runIdentity . runAmbT
 
</haskell>
 
</haskell>

Revision as of 15:23, 28 March 2008

This is an implementation of the amb operator in Haskell. Interestingly, it is identical to the list monad: remove 'amb' and the examples below work fine (apart, of course, from the IO one).

Notably, AmbT could be considered ListT done right.

module Amb (AmbT, Amb, amb, cut, runAmbT, runAmb) where

import Control.Monad.Cont
import Control.Monad.State
import Control.Monad.Identity

type Point r s m = () -> AmbT r s m s
newtype AmbT r s m a = AmbT { unAmbT :: StateT [Point r s m] (ContT r m) a }
type Amb r s = AmbT r s Identity

instance MonadTrans (AmbT r s) where
    lift = AmbT . lift . lift

instance (Monad m) => Monad (AmbT r s m) where
    AmbT a >>= b = AmbT $ a >>= unAmbT . b
    return = AmbT . return

backtrack :: (Monad m) => AmbT r s m a
backtrack = do xss <- AmbT get
               case xss of
                 [] -> fail "amb tree exhausted"
                 (f:xs) -> do AmbT $ put xs
                              f ()
                              return undefined

addPoint :: (Monad m) => Point r s m -> AmbT r s m ()
addPoint x = AmbT $ modify (x:)

amb :: (Monad m) => [a] -> AmbT r s m a
amb []     = backtrack
amb (x:xs) = ambCC $ \exit -> do
               ambCC $ \k -> addPoint k >> exit x
               amb xs
    where ambCC f = AmbT $ callCC $ \k -> unAmbT $ f $ AmbT . k

cut :: (Monad m) => AmbT r s m ()
cut = AmbT $ put []

runAmbT :: (Monad m) => AmbT r s m r -> m r
runAmbT (AmbT a) = runContT (evalStateT a []) return

runAmb :: Amb r s r -> r
runAmb = runIdentity . runAmbT

And some examples:

example :: Amb r Integer (Integer,Integer)
example = do x <- amb [1,2,3]
             y <- amb [4,5,6]
             if x*y == 8
               then return (x,y)
               else amb []

factor :: Integer -> Amb r Integer (Integer,Integer)
factor a = do x <- amb [2..]
              y <- amb [2..x]
              if x*y == a
                then return (x,y)
                else amb []

factorIO :: Integer -> AmbT r Integer IO (Integer,Integer)
factorIO a = do lift $ putStrLn $ "Factoring " ++ show a
                x <- amb [2..]
                y <- amb [2..x]
                lift $ putStrLn $ "Trying " ++ show x ++ " and " ++ show y
                if x*y == a
                  then do lift $ putStrLn "Found it!"
                          return (x,y)
                  else do lift $ putStrLn $ "Nope (" ++ show (x*y) ++ ")"
                          amb []