-- |
-- This module is a simple implementation of the internal derivative algorithm.
--
-- It is intended to be used for explanatory purposes.
--
-- This means that it gives up speed for readability.
--
-- Thus it does not use any form of memoization.

module Data.Katydid.Relapse.Derive
  ( derive
  , calls
  , returns
  , zipderive
    -- * Internal functions
    -- | These functions are exposed for testing purposes.
  , removeOneForEach
  )
where

import           Data.Foldable                  ( foldlM )
import           Data.List.Index                ( imap )

import           Data.Katydid.Parser.Parser

import           Data.Katydid.Relapse.Smart
import           Data.Katydid.Relapse.Simplify
import           Data.Katydid.Relapse.Zip
import           Data.Katydid.Relapse.IfExprs

-- |
-- calls returns a compiled if expression tree.
-- Each if expression returns a child pattern, given the input value.
-- In other words, calls is effectively of type:
--
-- @
--   Grammar -> [Pattern] -> Value -> [Pattern]
-- @
--
-- , where the resulting list of patterns are the child patterns
-- that need to be derived given the tree's child values.
calls :: Grammar -> [Pattern] -> IfExprs
calls g ps = compileIfExprs collected
  where
    -- The if expressions of each input pattern, concatenated in input order.
    collected = [ e | p <- ps, e <- deriveCall g p [] ]

-- |
-- deriveCall prepends, onto the accumulator, one if expression for every
-- Node pattern that may consume the next child tree of the given pattern.
--
-- The traversal order here must agree with the order in which
-- 'deriveReturn' consumes nullability results; for example, the right side
-- of a Concat is only visited when its left side is nullable.
deriveCall :: Grammar -> Pattern -> [IfExpr] -> [IfExpr]
deriveCall g current acc = case current of
  Empty                          -> acc
  ZAny                           -> acc
  Node { expr = v, pat = p }     -> newIfExpr v p emptySet : acc
  Concat { left = l, right = r }
    | nullable l                 -> deriveCall g l (deriveCall g r acc)
    | otherwise                  -> deriveCall g l acc
  Or { pats = ps }               -> foldr (deriveCall g) acc ps
  And { pats = ps }              -> foldr (deriveCall g) acc ps
  Interleave { pats = ps }       -> foldr (deriveCall g) acc ps
  ZeroOrMore { pat = p }         -> deriveCall g p acc
  Reference { refName = name }   -> deriveCall g (lookupRef g name) acc
  Not { pat = p }                -> deriveCall g p acc
  Contains { pat = p }           -> deriveCall g p acc
  Optional { pat = p }           -> deriveCall g p acc

-- |
-- returns takes a list of patterns and a list of bools.
-- The list of bools represents the nullability of the derived child patterns.
-- Each bool replaces one Node pattern with either Empty (when the bool is
-- True) or EmptySet (when it is False).
-- The lists do not need to be the same length, because each Pattern can
-- contain an arbitrary number of Node Patterns.
--
-- Every bool must be consumed by some Node pattern.  Bools left over after
-- all patterns have been processed mean that the bools do not line up with
-- the patterns, which is reported with an explicit error instead of a bare
-- pattern-match failure.
returns :: Grammar -> ([Pattern], [Bool]) -> [Pattern]
returns _ ([], []) = []
returns _ ([], _ : _) = error
  "Data.Katydid.Relapse.Derive.returns: more nullability values than Node patterns"
returns g (p : tailps, ns) =
  let (dp, tailns) = deriveReturn g p ns in dp : returns g (tailps, tailns)

-- |
-- mapReturn runs 'deriveReturn' over each pattern from left to right.
-- The bools left over by one pattern are handed on to the next one.
--
-- The derived patterns are returned in REVERSE order relative to the input
-- patterns.  The Interleave case of 'deriveReturn' compensates for this.
mapReturn :: Grammar -> [Pattern] -> [Bool] -> ([Pattern], [Bool])
mapReturn g ps ns = go [] ns ps
  where
    go derived rest []            = (derived, rest)
    go derived rest (p : pending) =
      let (dp, rest') = deriveReturn g p rest
      in  go (dp : derived) rest' pending

-- |
-- deriveReturn computes the derivative of one pattern from the nullability
-- of its derived Node children.
--
-- The bools are consumed in the same order in which 'deriveCall' emitted if
-- expressions for the pattern.  The bools that were not consumed are
-- returned together with the derived pattern, so the next pattern can
-- continue where this one left off.
deriveReturn :: Grammar -> Pattern -> [Bool] -> (Pattern, [Bool])
deriveReturn _ Empty ns = (emptySet, ns)
deriveReturn _ ZAny  ns = (zanyPat, ns)
-- A Node consumes exactly one bool: whether its derived child was nullable.
deriveReturn _ Node{} (n : tailns) =
  (if n then emptyPat else emptySet, tailns)
deriveReturn _ Node{} [] = error
  "Data.Katydid.Relapse.Derive.deriveReturn: no nullability value left for a Node pattern"
deriveReturn g Concat { left = l, right = r } ns
  -- Mirrors deriveCall: the right pattern is only derived (and only consumes
  -- bools) when the left pattern is nullable.
  | nullable l
  = let (dl, ltail) = deriveReturn g l ns
        (dr, rtail) = deriveReturn g r ltail
    in  (orPat (concatPat dl r) dr, rtail)
  | otherwise
  = let (dl, ltail) = deriveReturn g l ns in (concatPat dl r, ltail)
deriveReturn g Or { pats = ps } ns =
  let (dps, tailns) = mapReturn g ps ns in (foldl1 orPat dps, tailns)
deriveReturn g And { pats = ps } ns =
  let (dps, tailns) = mapReturn g ps ns in (foldl1 andPat dps, tailns)
deriveReturn g Interleave { pats = ps } ns =
  -- mapReturn yields the derivatives in reverse order, so the list of
  -- "all other patterns" is reversed too.  That keeps each derivative paired
  -- with the patterns it must be interleaved with.
  let (dps, tailns) = mapReturn g ps ns
      pps           = reverse $ removeOneForEach ps
      ips           = zipWith (:) dps pps
      ors           = map (foldl1 interleavePat) ips
  in  (foldl1 orPat ors, tailns)
deriveReturn g z@ZeroOrMore { pat = p } ns =
  let (dp, tailns) = deriveReturn g p ns in (concatPat dp z, tailns)
deriveReturn g Reference { refName = name } ns =
  deriveReturn g (lookupRef g name) ns
deriveReturn g Not { pat = p } ns =
  let (dp, tailns) = deriveReturn g p ns in (notPat dp, tailns)
deriveReturn g c@Contains { pat = p } ns =
  -- Contains p is shorthand for (*, p, *).  Its derivative is therefore
  -- c | (dp, *): either the leading * consumes this child, or p starts right
  -- here and must continue contiguously.  (Using (*, dp, *) instead would
  -- allow arbitrary siblings in the middle of p.)  When p is nullable, c
  -- already matches every sequence, so no extra term is needed.
  let (dp, tailns) = deriveReturn g p ns
  in  (orPat c (concatPat dp zanyPat), tailns)
-- Optional p is p | Empty, and the derivative of Empty is EmptySet, so only
-- the derivative of p remains.
deriveReturn g Optional { pat = p } ns = deriveReturn g p ns

-- | For internal testing.
-- removeOneForEach creates N copies of the list, where copy n has the n'th
-- element removed.
--
-- >>> removeOneForEach [1, 2, 3]
-- [[2,3],[1,3],[1,2]]
removeOneForEach :: [a] -> [[a]]
removeOneForEach xs = imap dropNth (replicate (length xs) xs)
  where
    -- n is always a valid index into copy, so drop 1 removes exactly the
    -- n'th element.
    dropNth n copy =
      let (before, after) = splitAt n copy in before ++ drop 1 after

-- |
-- derive is the classic derivative implementation for trees.
--
-- It starts from the grammar's main pattern and derives it with respect to
-- each tree in the list, from left to right.  It returns the single
-- resulting pattern.  It returns Left with an error message when evaluating
-- the if expressions fails, or when the derivation does not end with exactly
-- one pattern.
derive :: Tree t => Grammar -> [t] -> Either String Pattern
derive g ts = do
  ps <- foldlM (deriv g) [lookupMain g] ts
  -- Match on a singleton list instead of length/head: total and O(1).
  case ps of
    [p] -> return p
    _   -> Left $ "Number of patterns is not one, but " ++ show ps

-- |
-- deriv derives a list of patterns with respect to a single tree.
--
-- If every pattern is unescapable, the patterns are returned unchanged and
-- the tree is not inspected.  Otherwise:
--
-- * the tree's label selects the child patterns;
-- * those are derived against the tree's children;
-- * their nullability is fed back into 'returns'.
deriv :: Tree t => Grammar -> [Pattern] -> t -> Either String [Pattern]
deriv g ps tree
  | all unescapable ps = return ps
  | otherwise = do
    childps  <- evalIfExprs (calls g ps) (getLabel tree)
    childres <- foldlM (deriv g) childps (getChildren tree)
    return (returns g (ps, map nullable childres))

-- |
-- zipderive is a slightly optimized version of 'derive'.
-- It zips its intermediate pattern lists to reduce the state space.
--
-- Like 'derive', it returns the single resulting pattern.  It returns Left
-- when evaluating the if expressions fails, or when the derivation does not
-- end with exactly one pattern.
zipderive :: Tree t => Grammar -> [t] -> Either String Pattern
zipderive g ts = do
  ps <- foldlM (zipderiv g) [lookupMain g] ts
  -- Match on a singleton list instead of length/head: total and O(1).
  case ps of
    [p] -> return p
    _   -> Left $ "Number of patterns is not one, but " ++ show ps

-- |
-- zipderiv is the zipping counterpart of 'deriv'.
--
-- It derives a list of patterns with respect to a single tree.  Before
-- descending into the tree's children, it compresses the child pattern list
-- with 'zippy'.  If every pattern is unescapable, the patterns are returned
-- unchanged and the tree is not inspected.
zipderiv :: Tree t => Grammar -> [Pattern] -> t -> Either String [Pattern]
zipderiv g ps tree
  | all unescapable ps = return ps
  | otherwise = do
    childps <- evalIfExprs (calls g ps) (getLabel tree)
    -- zippy is pure, so its result is bound with let rather than being
    -- wrapped in return and bound with <-.
    let (zchildps, zipper) = zippy childps
    childres <- foldlM (zipderiv g) zchildps (getChildren tree)
    -- Expand the nullability results with the zipper.  The assumption is
    -- that unzipby undoes zippy's compression, so that the bools line up
    -- with the uncompressed child patterns that returns expects.
    return $ returns g (ps, unzipby zipper (map nullable childres))