-- |
-- This module is an efficient implementation of the derivative algorithm for trees.
--
-- It is intended to be used for production purposes.
--
-- This means that it gives up some readability for speed.
--
-- This module provides memoization of the calls and returns functions;
-- nullable is used as provided by "Data.Katydid.Relapse.Smart".

module Data.Katydid.Relapse.MemDerive
  ( derive
  , Mem
  , newMem
  , validate
  )
where

import qualified Data.Map.Strict               as M
import           Control.Monad.State            ( State
                                                , runState
                                                , lift
                                                , state
                                                )
import           Control.Monad.Trans.Except     ( ExceptT(..)
                                                , runExceptT
                                                )

import           Data.Katydid.Parser.Parser

import qualified Data.Katydid.Relapse.Derive   as Derive
import           Data.Katydid.Relapse.Smart     ( Grammar
                                                , Pattern
                                                , lookupRef
                                                , nullable
                                                , lookupMain
                                                )
import           Data.Katydid.Relapse.IfExprs
import           Data.Katydid.Relapse.Expr
import           Data.Katydid.Relapse.Zip

-- |
-- mem is a generic memoization step over a Map.
--
-- If the key is already present, the cached value is returned together with
-- the unchanged table. Otherwise the value is computed with the given function
-- and returned together with the table extended by the new entry.
mem :: Ord k => (k -> v) -> k -> M.Map k v -> (v, M.Map k v)
mem f k m = case M.lookup k m of
  -- A single total lookup: no separate membership test, no partial (!).
  Just cached -> (cached, m)
  Nothing     -> let res = f k in (res, M.insert k res m)

-- |
-- Calls is the memo table for the calls function:
-- it maps a list of current patterns to the IfExprs computed for them.
type Calls = M.Map [Pattern] IfExprs
-- |
-- Returns is the memo table for the returns function:
-- it maps the current patterns, paired with the nullability of the child
-- results, to the resulting patterns.
type Returns = M.Map ([Pattern], [Bool]) [Pattern]

-- |
-- Mem is the object used to store memoized results of the calls and returns functions.
newtype Mem = Mem (Calls, Returns)

-- |
-- newMem creates an empty object used for memoization by the validate function.
--
-- Each grammar should create its own memoize object: the memoized entries are
-- keyed on patterns only, not on the grammar they were computed for.
newMem :: Mem
newMem = Mem (noCalls, noReturns)
 where
  noCalls   = M.empty
  noReturns = M.empty

-- |
-- calls is the memoized version of Derive.calls.
-- Only the calls table inside Mem is consulted and possibly extended;
-- the returns table is passed through untouched.
calls :: Grammar -> [Pattern] -> State Mem IfExprs
calls g key = state update
 where
  update (Mem (callsMemo, returnsMemo)) =
    let (ifs, callsMemo') = mem (Derive.calls g) key callsMemo
    in  (ifs, Mem (callsMemo', returnsMemo))

-- |
-- returns is the memoized version of Derive.returns.
-- Only the returns table inside Mem is consulted and possibly extended;
-- the calls table is passed through untouched.
returns :: Grammar -> ([Pattern], [Bool]) -> State Mem [Pattern]
returns g key = state update
 where
  update (Mem (callsMemo, returnsMemo)) =
    let (ps, returnsMemo') = mem (Derive.returns g) key returnsMemo
    in  (ps, Mem (callsMemo, returnsMemo'))

-- |
-- mderive takes the derivative of a list of patterns over a list of sibling
-- trees, threading the memo table through the State monad.
--
-- An empty list of trees leaves the patterns unchanged. Otherwise, for the
-- first tree:
--
--   1. the memoized calls function turns the current patterns into IfExprs;
--   2. the IfExprs are evaluated against the tree's label to obtain the child
--      patterns; an evaluation error aborts the derivation with Left;
--   3. zippy yields the patterns to derive over the tree's children, plus a
--      zipper that is used afterwards to map the results back;
--   4. the nullability of every child result is mapped back with unzipby and
--      handed, together with the current patterns, to the memoized returns
--      function.
--
-- The resulting patterns are then derived over the remaining trees.
mderive
  :: Tree t
  => Grammar
  -> [Pattern]
  -> [t]
  -> ExceptT String (State Mem) [Pattern]
mderive _ ps []          = return ps
mderive g ps (tree : ts) = do
  ifs     <- lift $ calls g ps
  childps <- hoistExcept $ evalIfExprs ifs (getLabel tree)
  -- zippy is pure, so a let binding suffices; no monadic bind is needed.
  let (zchildps, zipper) = zippy childps
  childres <- mderive g zchildps (getChildren tree)
  let nulls   = map nullable childres
      unzipns = unzipby zipper nulls
  rs <- lift $ returns g (ps, unzipns)
  mderive g rs ts

-- |
-- hoistExcept lifts a pure Either into ExceptT over any monad:
-- Left becomes a thrown error and Right becomes a successful result.
hoistExcept :: (Monad m) => Either e a -> ExceptT e m a
hoistExcept result = ExceptT (pure result)

-- |
-- derive is the classic derivative implementation for trees.
--
-- It takes the derivative of the grammar's main pattern over the given trees,
-- starting from a fresh memo table ('newMem') that is discarded afterwards.
-- The result is Left when the derivation itself fails, or when it does not
-- end in exactly one pattern.
derive :: Tree t => Grammar -> [t] -> Either String Pattern
derive g ts = outcome >>= single
 where
  derivation = runExceptT (mderive g [lookupMain g] ts)
  outcome    = fst (runState derivation newMem)
  single [p] = Right p
  single ps  = Left $ "not a single pattern: " ++ show ps

-- |
-- validate uses the memoized derivative implementation for trees and
-- returns whether the trees are valid, given the input grammar and start pattern.
--
-- The trees are valid only when the derivation succeeds, ends in exactly one
-- pattern, and that pattern is nullable; any error yields False.
-- The memo table is threaded through the State monad, so the caller decides
-- which Mem to run it with.
validate :: Tree t => Grammar -> Pattern -> [t] -> (State Mem) Bool
validate g start tree = accepted <$> runExceptT (mderive g [start] tree)
 where
  accepted (Right [final]) = nullable final
  accepted _               = False