Skip to content

Get rid of fmapCoerce by ensuring our functors are representational #64

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: functor-stt-stm
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 30 additions & 36 deletions src/Control/Monad/Trans/Control.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
, FunctionalDependencies
, FlexibleInstances
, UndecidableInstances
, MultiParamTypeClasses #-}
, MultiParamTypeClasses
, QuantifiedConstraints #-}

-- We would be Safe if Data.Coerce were.
{-# LANGUAGE Trustworthy #-}
Expand Down Expand Up @@ -142,6 +143,8 @@ import qualified Control.Monad.Trans.Writer.Strict as Strict ( WriterT(WriterT),
-- from transformers-base:
import Control.Monad.Base ( MonadBase )

class (Monad m, forall a b. (Coercible a b) => Coercible (m a) (m b)) => RepresentationalMonad m where
instance (Monad m, forall a b. (Coercible a b) => Coercible (m a) (m b)) => RepresentationalMonad m where

--------------------------------------------------------------------------------
-- MonadTransControl type class
Expand Down Expand Up @@ -247,7 +250,7 @@ class MonadTrans t => MonadTransControl t where
-- liftWith :: 'Monad' m => (('Monad' n => 'MaybeT' n b -> n ('Maybe' b)) -> m a) -> 'MaybeT' m a
-- liftWith f = 'MaybeT' ('fmap' 'Just' (f 'runMaybeT'))
-- @
liftWith :: Monad m => (Run t -> m a) -> t m a
liftWith :: RepresentationalMonad m => (Run t -> m a) -> t m a

-- | Construct a @t@ computation from the monadic state of @t@ that is
-- returned from a 'Run' function.
Expand Down Expand Up @@ -286,7 +289,7 @@ class MonadTrans t => MonadTransControl t where
-- restoreT :: ('Monad' m, 'Monoid' w) => m (a, w) -> 'WriterT' w m a
-- restoreT :: ('Monad' m, 'Monoid' w) => m (a, s, w) -> 'RWST' r w s m a
-- @
restoreT :: Monad m => m (StT t a) -> t m a
restoreT :: RepresentationalMonad m => m (StT t a) -> t m a

-- | A function that runs a transformed monad @t n@ on the monadic state that
-- was captured by 'liftWith'
Expand Down Expand Up @@ -316,16 +319,7 @@ class MonadTrans t => MonadTransControl t where
-- 'flip' 'runStateT' :: s -> Run ('StateT' s)
-- 'runMaybeT' :: Run 'MaybeT'
-- @
type Run t = forall n b. Monad n => t n b -> n (StT t b)

-------------------------------------------------------------------------------
-- fmapCoerce
-------------------------------------------------------------------------------

-- In some future this might be runtime zero-cost, but not yet.
fmapCoerce :: (Functor f, Coercible a b) => f a -> f b
fmapCoerce = fmap coerce
{-# INLINE fmapCoerce #-}
type Run t = forall n b. RepresentationalMonad n => t n b -> n (StT t b)

--------------------------------------------------------------------------------
-- Defaults for MonadTransControl
Expand Down Expand Up @@ -353,10 +347,10 @@ fmapCoerce = fmap coerce

-- | A function like 'Run' that runs a monad transformer @t@ which wraps the
-- monad transformer @t'@. This is used in 'defaultLiftWith'.
type RunDefault t t' = forall n b. Monad n => t n b -> n (StT t' b)
type RunDefault t t' = forall n b. RepresentationalMonad n => t n b -> n (StT t' b)

-- | Default definition for the 'liftWith' method.
defaultLiftWith :: (Monad m, MonadTransControl n)
defaultLiftWith :: (RepresentationalMonad m, MonadTransControl n)
=> (forall b. n m b -> t m b) -- ^ Monad constructor
-> (forall o b. t o b -> n o b) -- ^ Monad deconstructor
-> (RunDefault t n -> m a)
Expand All @@ -365,7 +359,7 @@ defaultLiftWith t unT = \f -> t $ liftWith $ \run -> f $ run . unT
{-# INLINABLE defaultLiftWith #-}

-- | Default definition for the 'restoreT' method.
defaultRestoreT :: (Monad m, MonadTransControl n)
defaultRestoreT :: (RepresentationalMonad m, MonadTransControl n)
=> (n m a -> t m a) -- ^ Monad constructor
-> m (StT n a)
-> t m a
Expand Down Expand Up @@ -395,10 +389,10 @@ defaultRestoreT t = t . restoreT

-- | A function like 'Run' that runs a monad transformer @t@ which wraps the
-- monad transformers @n@ and @n'@. This is used in 'defaultLiftWith2'.
type RunDefault2 t n n' = forall m b. (Monad m, Monad (n' m)) => t m b -> m (StT n' (StT n b))
type RunDefault2 t n n' = forall m b. (RepresentationalMonad m, RepresentationalMonad (n' m)) => t m b -> m (StT n' (StT n b))

-- | Default definition for the 'liftWith' method.
defaultLiftWith2 :: (Monad m, Monad (n' m), MonadTransControl n, MonadTransControl n')
defaultLiftWith2 :: (RepresentationalMonad m, RepresentationalMonad (n' m), MonadTransControl n, MonadTransControl n')
=> (forall b. n (n' m) b -> t m b) -- ^ Monad constructor
-> (forall o b. t o b -> n (n' o) b) -- ^ Monad deconstructor
-> (RunDefault2 t n n' -> m a)
Expand All @@ -407,7 +401,7 @@ defaultLiftWith2 t unT = \f -> t $ liftWith $ \run -> liftWith $ \run' -> f $ ru
{-# INLINABLE defaultLiftWith2 #-}

-- | Default definition for the 'restoreT' method for double 'MonadTransControl'.
defaultRestoreT2 :: (Monad m, Monad (n' m), MonadTransControl n, MonadTransControl n')
defaultRestoreT2 :: (RepresentationalMonad m, RepresentationalMonad (n' m), MonadTransControl n, MonadTransControl n')
=> (n (n' m) a -> t m a) -- ^ Monad constructor
-> m (StT n' (StT n a))
-> t m a
Expand All @@ -420,8 +414,8 @@ defaultRestoreT2 t = t . restoreT . restoreT

instance MonadTransControl IdentityT where
type StT IdentityT = Identity
liftWith f = IdentityT $ f $ fmapCoerce . runIdentityT
restoreT = IdentityT . fmapCoerce
liftWith f = IdentityT $ f $ coerce . runIdentityT
restoreT = IdentityT . coerce
{-# INLINABLE liftWith #-}
{-# INLINABLE restoreT #-}

Expand Down Expand Up @@ -455,8 +449,8 @@ instance MonadTransControl ListT where

instance MonadTransControl (ReaderT r) where
type StT (ReaderT r) = Identity
liftWith f = ReaderT $ \r -> f $ \t -> fmapCoerce $ runReaderT t r
restoreT = ReaderT . const . fmapCoerce
liftWith f = ReaderT $ \r -> f $ \t -> coerce $ runReaderT t r
restoreT = ReaderT . const . coerce
{-# INLINABLE liftWith #-}
{-# INLINABLE restoreT #-}

Expand Down Expand Up @@ -484,8 +478,8 @@ instance MonadTransControl (Strict.StateT s) where
instance Monoid w => MonadTransControl (WriterT w) where
type StT (WriterT w) = WriterStT w
liftWith f = WriterT $ fmap (\x -> (x, mempty))
(f $ fmapCoerce . runWriterT)
restoreT = WriterT . fmapCoerce
(f $ coerce . runWriterT)
restoreT = WriterT . coerce
{-# INLINABLE liftWith #-}
{-# INLINABLE restoreT #-}

Expand All @@ -495,16 +489,16 @@ newtype WriterStT w a = WriterStT { getWriterStT :: (a, w) }
instance Monoid w => MonadTransControl (Strict.WriterT w) where
type StT (Strict.WriterT w) = WriterStT w
liftWith f = Strict.WriterT $ fmap (\x -> (x, mempty))
(f $ fmapCoerce . Strict.runWriterT)
restoreT = Strict.WriterT . fmapCoerce
(f $ coerce . Strict.runWriterT)
restoreT = Strict.WriterT . coerce
{-# INLINABLE liftWith #-}
{-# INLINABLE restoreT #-}

instance Monoid w => MonadTransControl (RWST r w s) where
type StT (RWST r w s) = RwsStT w s
liftWith f = RWST $ \r s -> fmap (\x -> (x, s, mempty))
(f $ \t -> fmapCoerce $ runRWST t r s)
restoreT mSt = RWST $ \_ _ -> fmapCoerce $ mSt
(f $ \t -> coerce $ runRWST t r s)
restoreT mSt = RWST $ \_ _ -> coerce $ mSt
{-# INLINABLE liftWith #-}
{-# INLINABLE restoreT #-}

Expand All @@ -514,8 +508,8 @@ instance Monoid w => MonadTransControl (Strict.RWST r w s) where
type StT (Strict.RWST r w s) = RwsStT w s
liftWith f =
Strict.RWST $ \r s -> fmap (\x -> (x, s, mempty))
(f $ \t -> fmapCoerce $ Strict.runRWST t r s)
restoreT mSt = Strict.RWST $ \_ _ -> fmapCoerce $ mSt
(f $ \t -> coerce $ Strict.runRWST t r s)
restoreT mSt = Strict.RWST $ \_ _ -> coerce $ mSt
{-# INLINABLE liftWith #-}
{-# INLINABLE restoreT #-}

Expand All @@ -531,7 +525,7 @@ instance Monoid w => MonadTransControl (Strict.RWST r w s) where
-- for the base monad, and @MonadTransControl T@ instances for every transformer
-- @T@. Instances for @'MonadBaseControl'@ are then simply implemented using
-- @'ComposeSt'@, @'defaultLiftBaseWith'@, @'defaultRestoreM'@.
class MonadBase b m => MonadBaseControl b m | m -> b where
class (MonadBase b m, RepresentationalMonad m, RepresentationalMonad b) => MonadBaseControl b m | m -> b where
-- | Monadic state that @m@ adds to the base monad @b@.
--
-- For all base (non-transformed) monads, @StM m a ~ a@:
Expand Down Expand Up @@ -637,7 +631,7 @@ type RunInBase m b = forall a. m a -> b (StM m a)
#define BASE(M) \
instance MonadBaseControl (M) (M) where { \
type StM (M) = Identity; \
liftBaseWith f = f fmapCoerce; \
liftBaseWith f = f coerce; \
restoreM = return . coerce; \
{-# INLINABLE liftBaseWith #-}; \
{-# INLINABLE restoreM #-}}
Expand Down Expand Up @@ -772,7 +766,7 @@ control f = liftBaseWith f >>= restoreM

-- | Lift a computation and restore the monadic state immediately:
-- @controlT f = 'liftWith' f >>= 'restoreT' . return@.
controlT :: (MonadTransControl t, Monad (t m), Monad m)
controlT :: (MonadTransControl t, RepresentationalMonad (t m), RepresentationalMonad m)
=> (Run t -> m (StT t a)) -> t m a
controlT f = liftWith f >>= restoreT . return
{-# INLINABLE controlT #-}
Expand All @@ -790,7 +784,7 @@ embed_ f = liftBaseWith $ \runInBase -> return (void . runInBase . f)
{-# INLINABLE embed_ #-}

-- | Capture the current state of a transformer
captureT :: (MonadTransControl t, Monad (t m), Monad m) => t m (StT t ())
captureT :: (MonadTransControl t, RepresentationalMonad (t m), RepresentationalMonad m) => t m (StT t ())
captureT = liftWith $ \runInM -> runInM (return ())
{-# INLINABLE captureT #-}

Expand Down Expand Up @@ -879,7 +873,7 @@ liftBaseOpDiscard f g = liftBaseWith $ \runInBase -> f $ void . runInBase . g

-- | Transform an action in @t m@ using a transformer that operates on the underlying monad @m@
liftThrough
:: (MonadTransControl t, Monad (t m), Monad m)
:: (MonadTransControl t, RepresentationalMonad (t m), RepresentationalMonad m)
=> (m (StT t a) -> m (StT t b)) -- ^
-> t m a -> t m b
liftThrough f t = do
Expand Down