dramforever

coding, thoughts and stuff irl

表达式归一化与 Free Monad

2016-05-22

(注:原文正文部分代码里带下划线的标识符是链接,指向对应的 Haddock 文档)

各种定义

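我们的目标是支持以下操作:

- fresh:获得一个新的变量;
- record:把一个(可以含有变量的)表达式记录下来,返回代表它的变量;
- unify:指定两个变量相等;
- report:给出一个变量目前已知的、尽量展开的表达式。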

例如,通过 fresh 获得变量,构造两个表达式 f[#A, u[#B], #C] 和 #D[#E, #F, #G[v]],并分别 record 它们;之后 unify 两次 record 得到的变量,再 report 其中一个,应该得到 f[#E,u[#B],#G[v]]。

表达式的代码实现

我昨天晚上想到,其实三种表达式(完整的表达式、叶子可以是变量的部分表达式、并查集里存储的单层表达式)可以用同一个结构实现:

-- | One layer of an expression; @a@ is the type of the sub-expressions.
data ExprF a
  = Atom Identifier
  | Cons a [a]

-- Fully concrete expressions.
type Expression = Fix ExprF
-- Expressions whose leaves may be variables.
type PartialExpr = Free ExprF Var
-- A single layer whose children are variables.
type UnionFindExpr = ExprF Var

第一个(Expression)之后不会再出现,我们不管它。大家可以体会一下 PartialExpr:它就是叶子可以是变量 Var 的表达式树。以下语法糖可能有帮助:

-- | An atom as a partial expression.
atom :: Identifier -> PartialExpr
atom = Free . Atom

-- | A head expression together with its list of argument expressions.
cons :: PartialExpr -> [PartialExpr] -> PartialExpr
cons u v = Free (Cons u v)

-- | Lets string literals stand for atoms.
instance IsString PartialExpr where
  fromString = atom . fromString

信息维护 Monad 的实现

-- | Union-find state (its fields are given in the full listing).
data UnionFindState
-- | Monads that carry the union-find state.
type MonadUFS m = MonadState UnionFindState m

-- | Allocate a new variable.
fresh :: MonadUFS m => m Var

-- | Representative of a variable's class, and the expression it is bound to (if any).
find :: MonadUFS m => Var -> m (Var, Maybe UnionFindExpr)

-- | Ways in which 'unify' can fail.
data UnificationError
  = AtomMismatch Identifier Identifier
  | AtomNotCons Identifier Var [Var]
  | ConsLengthMismatch UnionFindExpr UnionFindExpr

-- | Declare two variables equal, failing if their expressions cannot match.
unify :: (MonadError UnificationError m, MonadUFS m)
      => Var -> Var -> m ()

具体内容略去,可以参见结尾的完整代码。这里多了一个内部使用的函数 find,其作用是找到与给定变量相等的一个代表变量。这个代表变量具有如下性质:它没有被指定等于另一个变量,也就是顺着 Parent 一直走到顶端得到的那个变量(find 途中还会顺便做路径压缩)。另外,find 还返回这个代表变量被指定等于的 UnionFindExpr(如果有的话)。

record 与 report 的实现

-- | Store a partial expression and return the variable standing for its root.
record :: MonadUFS m => PartialExpr -> m Var
record = iterA go where
  -- Record the children first (sequence), then link a fresh variable to this layer.
  go f = do
    u <- sequence f
    v <- fresh
    ufsMap . at v .= Just (Linked u)
    pure v
其中 iterA 的类型是:
-- Signature as exported by Control.Monad.Free.
iterA :: (Applicative p, Functor f)
      => (f (p a) -> p a) -> Free f a -> p a
在这里它被特化为:
-- The same function instantiated with p = m, f = ExprF, a = Var.
iterA :: MonadUFS m
      => (ExprF (m Var) -> m Var)
      -> PartialExpr
      -> m Var

这里 ufsMap . at v .= Just (...) 是 lens 的用法,意思是在维护信息中把变量 v 指定为等于给定的单层表达式。

我们注意到,只要在定义 ExprF 时加上 deriving (Functor, Foldable, Traversable),sequence 在这里的类型就是:

-- 'sequence' from the derived Traversable instance, used at element type m Var.
sequence :: MonadUFS m => ExprF (m Var) -> m UnionFindExpr

也就是说,将传入参数 sequence 后得到的正好是我们需要创建的变量所指定相等的值!所以,我们成功地在不依赖于 ExprF 具体定义的情况下实现了 record

report 也非常简单,用 unfoldM 一层层展开即可:

-- | Expand a variable into the most specific expression currently known for it.
report :: MonadUFS m => Var -> m PartialExpr
report = unfoldM go where
  -- Unbound representative: stop with a variable leaf; bound: expand one layer.
  go v = find v >>= \case
    (u, Nothing) -> pure (Left u)
    (_, Just t) -> pure (Right t)

至此,一个简易的表达式归一化程序就完成了。

总结

在这个归一化程序中,我们很明显地发现,Free monad 的使用简化了表达式的定义和一些琐碎的功能的实现,有效避免了乏味而易错的代码重复。

附:完整实现代码

{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module Play where

import qualified Data.Text as T
import Control.Monad.Free
import qualified Data.Map as M
import Control.Lens hiding (cons)
import Control.Monad.State
import Control.Monad.Except
import Data.List (intercalate)
import Data.String
import Control.Applicative
import Data.Functor.Foldable (Fix(..))

-- | The name of an atom such as @f@ or @v@. String literals can be used
-- directly thanks to the derived 'IsString' instance.
newtype Identifier = Identifier
  { getIdentifier :: T.Text
  }
  deriving (IsString)

-- | Shows the bare name, without quotes.
instance Show Identifier where
  show = T.unpack . getIdentifier

-- | One layer of an expression: either an atom, or a head with a list of
-- arguments. The parameter @a@ is the type of the sub-expressions. The derived
-- 'Traversable' instance is what 'record' (via 'sequence') and 'report'
-- (via 'unfoldM') rely on to work one layer at a time.
data ExprF a
  = Atom Identifier
  | Cons a [a]
  deriving (Functor, Foldable, Traversable)

-- | Fully concrete expressions (not used elsewhere in this listing).
type Expression = Fix ExprF
-- | Expressions whose leaves may be variables ('Pure' leaves).
type PartialExpr = Free ExprF Var
-- | A single layer whose children are variables, as stored in the union-find map.
type UnionFindExpr = ExprF Var

-- | Render a partial expression as text, e.g. @f[#E,u[#B],#G[v]]@.
--
-- Variable leaves are rendered with their 'Show' instance, atoms by name, and
-- a 'Cons' node as its head followed by the comma-separated arguments in
-- square brackets.
printPartial :: PartialExpr -> String
printPartial expr = iter render (show <$> expr)
  where
    render = \case
      Atom name -> show name
      Cons hd args -> concat [hd, "[", intercalate "," args, "]"]

-- | Shows a single layer: an atom by its bare name, a 'Cons' node as
-- @head[arg,arg,...]@ using the children's own 'Show' instances.
instance Show u => Show (ExprF u) where
  show = \case
    Atom (Identifier name) -> T.unpack name
    Cons hd args -> concat [show hd, "[", intercalate "," (map show args), "]"]

-- | A unification variable, identified by its name.
newtype Var = Var
  { getVar :: T.Text
  }
  deriving (Eq, Ord)

-- | Variables are displayed with a leading @#@, e.g. @#A@.
instance Show Var where
  show v = '#' : T.unpack (getVar v)

-- | What the union-find map records for a variable: either it has been made
-- equal to another variable ('Parent'), or it is a class representative bound
-- to a one-layer expression ('Linked'). A variable with no entry in the map
-- is an unbound representative.
data UnionFindPointer
  = Parent Var
  | Linked UnionFindExpr

-- Generates the prisms _Parent and _Linked.
$(makePrisms ''UnionFindPointer)

-- | State threaded through all union-find operations.
data UnionFindState
  = UnionFindState
    { _ufsSupply :: [Var]
      -- ^ Names handed out by 'fresh'. Its head is the name most recently
      -- returned (initially a placeholder that is never returned); 'fresh'
      -- assumes this list does not run out.
    , _ufsMap :: M.Map Var UnionFindPointer
      -- ^ Pointer for every variable that has been bound or merged.
    }

-- Generates the lenses ufsSupply and ufsMap.
$(makeLenses ''UnionFindState)

-- | Constraint for monads that carry the union-find state.
type MonadUFS m = MonadState UnionFindState m

-- | Allocate a fresh variable from the supply.
--
-- The head of the supply is always the most recently returned variable (in
-- 'run' it starts out as the empty name, which is therefore never handed out).
-- We drop it and return the new head, which stays in the supply until the
-- next call.
--
-- The supply is expected to be infinite. Running out is a programming error;
-- it is reported with a descriptive message instead of the anonymous failure
-- of a partial 'head' / 'tail'.
fresh :: MonadUFS m => m Var
fresh = do
  supply <- use ufsSupply
  case drop 1 supply of
    rest@(v : _) -> v <$ (ufsSupply .= rest)
    [] -> error "Play.fresh: variable supply exhausted"

-- | Resolve a variable to the representative of its equivalence class.
--
-- Follows 'Parent' pointers from @u@ until reaching a variable that is either
-- 'Linked' to an expression or absent from the map. It then re-points every
-- variable on the traversed path directly at that representative (path
-- compression). Returns the representative together with the expression it is
-- linked to, if any.
find :: MonadUFS m => Var -> m (Var, Maybe UnionFindExpr)
find u = do
  (root, bound) <- climb u
  flatten root u
  pure (root, bound)
  where
    -- Walk up the Parent chain without modifying anything.
    climb here = use (ufsMap . at here) >>= \case
      Just (Parent up) -> climb up
      Just (Linked ex) -> pure (here, Just ex)
      Nothing -> pure (here, Nothing)
    -- Point every node on the path from `node` up to `root` directly at `root`.
    -- Every non-root node on that path was just seen to be a Parent pointer.
    flatten root node
      | node == root = pure ()
      | otherwise = use (ufsMap . at node) >>= \case
          Just (Parent up) -> do
            ufsMap . at node .= Just (Parent root)
            flatten root up
          _ -> error "Internal error: find: Can't happen!"

-- | Reasons 'unify' can fail.
data UnificationError
  = AtomMismatch Identifier Identifier
    -- ^ Two different atoms were required to be equal.
  | AtomNotCons Identifier Var [Var]
    -- ^ An atom was required to equal a 'Cons' node (given as its head and
    -- argument variables).
  | ConsLengthMismatch UnionFindExpr UnionFindExpr
    -- ^ Two 'Cons' nodes have different numbers of arguments.
  deriving (Show)

-- | Record that two variables denote equal expressions.
--
-- Both variables are first resolved to their class representatives with
-- 'find'. If either class is still unbound, it is simply pointed at the other
-- one. If both are bound, the two one-layer expressions must have the same
-- shape. Otherwise an 'UnificationError' is thrown before anything at this
-- level is modified.
--
-- The representative @xu@ is linked to @xv@ /before/ the children are unified,
-- and equal atoms are merged too. If the children lead back to the same pair
-- of classes, the recursive call therefore sees @xu == xv@ and stops. This
-- happens e.g. after a variable has been unified with a term that contains
-- it. Linking after the recursion (the previous order) looped forever in that
-- situation.
unify :: (MonadError UnificationError m, MonadUFS m) => Var -> Var -> m ()
unify u1 v1 = do
  (xu, eu) <- find u1
  (xv, ev) <- find v1
  when (xu /= xv) $ case (eu, ev) of
    (Nothing, _) -> ufsMap . at xu .= Just (Parent xv)
    (Just _, Nothing) -> ufsMap . at xv .= Just (Parent xu)
    (Just p, Just q) -> do
      -- Check the shapes first, so a mismatch throws without touching state.
      pairs <- matchShapes p q
      -- Merge the two classes, then make corresponding children equal.
      ufsMap . at xu .= Just (Parent xv)
      mapM_ (uncurry unify) pairs
  where
    -- The pairs of child variables that must be unified for the two layers
    -- to be equal, or an error if the layers can never be equal.
    matchShapes (Atom x) (Atom y)
      | getIdentifier x == getIdentifier y = pure []
      | otherwise = throwError (AtomMismatch x y)
    matchShapes (Atom x) (Cons y ys) = throwError (AtomNotCons x y ys)
    matchShapes m@Cons{} n@Atom{} = matchShapes n m
    matchShapes m@(Cons x xs) n@(Cons y ys)
      | length xs == length ys = pure ((x, y) : zip xs ys)
      | otherwise = throwError (ConsLengthMismatch m n)

-- | Store a partial expression in the union-find state and return the variable
-- standing for its root.
--
-- 'iterA' processes the 'Free' layers bottom-up. For each layer, the children
-- are recorded first (left to right). Then a fresh variable is allocated and
-- linked to the layer, with its children replaced by their variables. 'Pure'
-- leaves are already variables.
record :: MonadUFS m => PartialExpr -> m Var
record = iterA $ \layer -> do
  children <- sequenceA layer
  node <- fresh
  ufsMap . at node .= Just (Linked children)
  pure node

-- | Rebuild the most specific expression currently known for a variable.
--
-- Each variable is resolved with 'find'. An unbound class becomes a 'Pure'
-- leaf holding the class /representative/ (not necessarily the variable asked
-- about). A bound class is expanded by one layer, and its children are
-- unfolded in turn.
--
-- NOTE(review): nothing here detects cycles. If unification has made a
-- variable equal to a term containing itself, this does not terminate.
report :: MonadUFS m => Var -> m PartialExpr
report = unfoldM $ \v -> do
  (root, bound) <- find v
  pure (maybe (Left root) Right bound)

-- | Run a union-find computation from an empty map.
--
-- Variable names come from the infinite sequence "", "A" .. "Z", "AA", "BA",
-- ..., "ZA", "AB", ... Because 'fresh' discards the head of the supply before
-- returning, the empty name is never handed out. A unification failure is
-- returned as 'Left'.
run :: StateT UnionFindState (Except UnificationError) a
    -> Either UnificationError a
run action = runExcept (evalStateT action initial)
  where
    initial = UnionFindState supply M.empty
    supply = map (Var . T.pack) names
    names = "" : [letter : suffix | suffix <- names, letter <- ['A' .. 'Z']]

-- | A partial expression consisting of a single atom.
atom :: Identifier -> PartialExpr
atom name = Free (Atom name)

-- | A head expression with a list of argument expressions, printed as
-- @head[arg,...]@.
cons :: PartialExpr -> [PartialExpr] -> PartialExpr
cons hd = Free . Cons hd

-- | String literals denote atoms, so @"f" :: PartialExpr@ is @atom "f"@.
instance IsString PartialExpr where
  fromString s = atom (fromString s)

以下是一个简单的测试

a1, a2 :: MonadUFS m => m PartialExpr

-- f[#A, u[#B], #C]
-- The three fresh variables are allocated left to right.
a1 = build <$> fresh <*> fresh <*> fresh
  where
    build a b c = cons "f" [pure a, cons "u" [pure b], pure c]

-- #D[#E, #F, #G[v]]
-- The four fresh variables are allocated left to right.
a2 = build <$> fresh <*> fresh <*> fresh <*> fresh
  where
    build d e f g = cons (pure d) [pure e, pure f, cons (pure g) ["v"]]

-- | Build the two sample expressions, record them (first, then second), unify
-- the resulting root variables, and report the first one.
test :: (MonadError UnificationError m, MonadUFS m) => m PartialExpr
test = do
  (lhs, rhs) <- liftA2 (,) a1 a2
  (x, y) <- liftA2 (,) (record lhs) (record rhs)
  unify x y
  report x

-- ghci> let Right u = run (printPartial <$> test) in putStrLn u
-- f[#E,u[#B],#G[v]]
--
-- Note: The variable names are generated automatically. As you can see
-- the definitions of a1 and a2 do not specify names.