State Monad

Numbering the nodes of a tree with a counter

We have a tree data type like this:

data Tree a = Tree a [Tree a] deriving Show

And we wish to write a function that assigns a number to each node of the tree, from an incrementing counter:

tag :: Tree a -> Tree (a, Int)

First we’ll do it the long way around, since it illustrates the State monad’s low-level mechanics quite nicely.

import Control.Monad.State

-- Function that numbers the nodes of a `Tree`.
tag :: Tree a -> Tree (a, Int)
tag tree =
  -- tagStep is where the action happens. This just gets the ball
  -- rolling, with `0` as the initial counter value.
  evalState (tagStep tree) 0

-- This is one monadic "step" of the calculation. It assumes that
-- it has access to the current counter value implicitly.
tagStep :: Tree a -> State Int (Tree (a, Int))
tagStep (Tree a subtrees) = do
  -- The `get :: State s s` action accesses the implicit state
  -- parameter of the State monad. Here we bind that value to
  -- the variable `counter`.
  counter <- get
  -- The `put :: s -> State s ()` action sets the implicit state
  -- parameter of the `State` monad. The next `get` that we execute
  -- will see the value of `counter + 1` (assuming no other puts in
  -- between).
  put (counter + 1)
  -- Recurse into the subtrees. `mapM` is a utility function
  -- for executing a monadic action (like `tagStep`) on each element
  -- of a list, and producing the list of results. Each execution of
  -- `tagStep` starts with the counter value that resulted from the
  -- previous list element's execution.
  subtrees' <- mapM tagStep subtrees
  return $ Tree (a, counter) subtrees'
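
To see the numbering order concretely, here is a minimal sketch that runs the function above on a small tree. The `sample` tree and the extra `Eq` instance are assumptions made for illustration; `Control.Monad.State` is assumed to come from the mtl package.

```haskell
import Control.Monad.State (State, evalState, get, put)

data Tree a = Tree a [Tree a] deriving (Show, Eq)

tag :: Tree a -> Tree (a, Int)
tag tree = evalState (tagStep tree) 0

tagStep :: Tree a -> State Int (Tree (a, Int))
tagStep (Tree a subtrees) = do
  counter <- get
  put (counter + 1)
  subtrees' <- mapM tagStep subtrees
  return $ Tree (a, counter) subtrees'

-- Hypothetical sample tree, not taken from the original text.
sample :: Tree Char
sample = Tree 'a' [Tree 'b' [Tree 'd' []], Tree 'c' []]

main :: IO ()
main = print (tag sample)
-- prints: Tree ('a',0) [Tree ('b',1) [Tree ('d',2) []],Tree ('c',3) []]
```

Each node takes its number before its subtrees are visited, so the result is a pre-order numbering: 'd' receives 2 because the whole subtree under 'b' is finished before 'c' is reached.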

Split out the counter into a postIncrement action

The bit where we are getting the current counter and then putting counter + 1 can be split off into a postIncrement action, similar to what many C-style languages provide:

postIncrement :: Enum s => State s s
postIncrement = do
  result <- get
  modify succ
  return result
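
A quick way to check the behaviour is to run the action on its own. The sketch below assumes `Control.Monad.State` from the mtl package; the starting values are arbitrary choices for illustration.

```haskell
import Control.Monad (replicateM)
import Control.Monad.State (State, evalState, get, modify, runState)

postIncrement :: Enum s => State s s
postIncrement = do
  result <- get
  modify succ
  return result

main :: IO ()
main = do
  -- runState returns (result, final state): the old value and its successor.
  print (runState postIncrement (5 :: Int))                   -- (5,6)
  -- The Enum constraint means the counter need not be a number.
  print (runState postIncrement 'a')                          -- ('a','b')
  -- Three increments in a row see three consecutive counter values.
  print (evalState (replicateM 3 postIncrement) (10 :: Int))  -- [10,11,12]
```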

Split out the tree walk into a higher-order function

The tree walk logic can be split out into its own function, like this:

mapTreeM :: Monad m => (a -> m b) -> Tree a -> m (Tree b)
mapTreeM action (Tree a subtrees) = do
  a' <- action a
  subtrees' <- mapM (mapTreeM action) subtrees
  return $ Tree a' subtrees'
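
Because `mapTreeM` only asks for a `Monad`, the same walk works with monads other than `State`. As a rough sketch, here it is with `Maybe`; the `positive` check is a hypothetical validation step invented for this example.

```haskell
data Tree a = Tree a [Tree a] deriving (Show, Eq)

mapTreeM :: Monad m => (a -> m b) -> Tree a -> m (Tree b)
mapTreeM action (Tree a subtrees) = do
  a' <- action a
  subtrees' <- mapM (mapTreeM action) subtrees
  return $ Tree a' subtrees'

-- Hypothetical validation step: accept only positive labels.
positive :: Int -> Maybe Int
positive x
  | x > 0 = Just x
  | otherwise = Nothing

main :: IO ()
main = do
  print (mapTreeM positive (Tree 1 [Tree 2 [], Tree 3 []]))
  -- Just (Tree 1 [Tree 2 [],Tree 3 []])
  print (mapTreeM positive (Tree 1 [Tree (-2) [], Tree 3 []]))
  -- Nothing
```

A single failing node makes the whole walk return `Nothing`, which is the `Maybe` monad's notion of sequencing, just as threading the counter is `State`'s.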

With this and the postIncrement function we can rewrite tagStep:

tagStep :: Tree a -> State Int (Tree (a, Int))
tagStep = mapTreeM step
  where
    step :: a -> State Int (a, Int)
    step a = do
      counter <- postIncrement
      return (a, counter)

The mapTreeM solution above can be easily rewritten into an instance of the Traversable class:

instance Traversable Tree where
  traverse action (Tree a subtrees) =
    Tree <$> action a <*> traverse (traverse action) subtrees

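Note that this required us to use Applicative (the <*> operator) instead of Monad. Traversable also has Functor and Foldable as superclasses, so Tree needs instances of those two classes as well.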
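
One way to supply the Functor and Foldable instances that Traversable requires is to define them in terms of `traverse`, using `fmapDefault` and `foldMapDefault` from `Data.Traversable`. This is only one possible sketch; writing the two instances by hand works equally well.

```haskell
import Data.Traversable (fmapDefault, foldMapDefault)

data Tree a = Tree a [Tree a] deriving (Show, Eq)

-- Both superclass instances are derived from `traverse` below.
instance Functor Tree where
  fmap = fmapDefault

instance Foldable Tree where
  foldMap = foldMapDefault

instance Traversable Tree where
  traverse action (Tree a subtrees) =
    -- The inner traverse walks one subtree; the outer one walks the list.
    Tree <$> action a <*> traverse (traverse action) subtrees

main :: IO ()
main = do
  let t = Tree 1 [Tree 2 [], Tree 3 []] :: Tree Int
  print (fmap (* 10) t)  -- Tree 10 [Tree 20 [],Tree 30 []]
  print (sum t)          -- 6
```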

With that, now we can write tag like a pro:

tag :: Traversable t => t a -> t (a, Int)
tag t = evalState (traverse step t) 0
  where
    step a = do
      n <- postIncrement
      return (a, n)

Note that this works for any Traversable type, not just our Tree type!
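
As an illustration of that generality, the sketch below applies the same `tag` to a list and to a `Maybe`, both of which have Traversable instances in base. `Control.Monad.State` is assumed to come from the mtl package, and the inputs are arbitrary.

```haskell
import Control.Monad.State (State, evalState, get, modify)

postIncrement :: Enum s => State s s
postIncrement = do
  result <- get
  modify succ
  return result

tag :: Traversable t => t a -> t (a, Int)
tag t = evalState (traverse step t) 0
  where
    step a = do
      n <- postIncrement
      return (a, n)

main :: IO ()
main = do
  print (tag "abc")        -- [('a',0),('b',1),('c',2)]
  print (tag (Just True))  -- Just (True,0)
```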

Getting rid of the Traversable boilerplate

GHC has a DeriveTraversable extension that eliminates the need for writing the instance above:

{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable #-}

data Tree a = Tree a [Tree a]
  deriving (Show, Functor, Foldable, Traversable)
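
With the derived instances in place, the usual Functor and Foldable functions work on `Tree` directly. The short sketch below uses a made-up sample tree and also derives `Eq`, which is an addition for this example only. The derived instances visit a node's label before its subtrees, which matches the pre-order numbering used earlier.

```haskell
{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable #-}

import Data.Foldable (toList)

data Tree a = Tree a [Tree a]
  deriving (Show, Eq, Functor, Foldable, Traversable)

-- Hypothetical sample tree, not taken from the original text.
sample :: Tree Char
sample = Tree 'a' [Tree 'b' [Tree 'd' []], Tree 'c' []]

main :: IO ()
main = do
  print (toList sample)  -- "abdc"
  print (length sample)  -- 4
  print (fmap succ sample)
  -- Tree 'b' [Tree 'c' [Tree 'e' []],Tree 'd' []]
```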

Newcomers to Haskell often shy away from the State monad and treat it as taboo: if the claimed benefit of functional programming is the avoidance of state, don’t you lose that when you use State? A more nuanced view is that:

  • State can be useful in small, controlled doses;
  • The State type provides the ability to control the dose very precisely.

The reason is that if you have action :: State s a, the type tells you that:

  • action is special because it depends on a state;
  • The state has type s, so action cannot be influenced by any old value in your program, only by an s or some value reachable from some s;
  • runState :: State s a -> s -> (a, s) puts a “barrier” around the stateful action, so that its effectfulness cannot be observed from outside that barrier.

So this is a good set of criteria for whether to use State in a particular scenario. You want to see that your code is minimizing the scope of the state, both by choosing a narrow type for s and by putting runState as close to “the bottom” as possible, so that your actions can be influenced by as few things as possible.
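
The “barrier” can be observed directly. In the sketch below (assuming `Control.Monad.State` from the mtl package, with arbitrary inputs), `runState` exposes the final counter only to its immediate caller, while `tag` hides the counter behind a pure type signature, so two calls cannot affect each other.

```haskell
import Control.Monad.State (State, evalState, get, modify, runState)

postIncrement :: Enum s => State s s
postIncrement = do
  result <- get
  modify succ
  return result

step :: a -> State Int (a, Int)
step a = do
  n <- postIncrement
  return (a, n)

-- The State computation is started and finished inside `tag`;
-- callers see an ordinary pure function.
tag :: Traversable t => t a -> t (a, Int)
tag t = evalState (traverse step t) 0

main :: IO ()
main = do
  -- runState returns the result together with the final counter.
  print (runState (traverse step "ab") 0)  -- ([('a',0),('b',1)],2)
  -- The counter does not leak between calls: both start from 0.
  print (tag "xy")                         -- [('x',0),('y',1)]
  print (tag "xy")                         -- [('x',0),('y',1)]
```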