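-- |
-- This module contains the Relapse comparison expressions:
-- equal, not equal, greater than, greater than or equal, less than, and less than or equal.
-- Every comparison evaluates to False, rather than to an error, when either operand fails to evaluate.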
module Data.Katydid.Relapse.Exprs.Compare
  ( mkEqExpr
  , eqExpr
  , mkNeExpr
  , neExpr
  , mkGeExpr
  , geExpr
  , mkLeExpr
  , leExpr
  , mkGtExpr
  , gtExpr
  , mkLtExpr
  , ltExpr
  )
where

import           Data.Katydid.Relapse.Expr

-- |
-- mkEqExpr dynamically creates an eq (equal) expression from a list of arguments.
-- The list is first checked by 'assertArgs2'. The type of the first argument selects
-- the comparison type, and the second argument is checked against that same type by
-- the matching assert helper.
-- Supported types are bool, int, uint, double, string and bytes. Any other argument
-- type produces a 'Left' error instead of a pattern-match failure.
mkEqExpr :: [AnyExpr] -> Either String AnyExpr
mkEqExpr es = do
  (e1, e2) <- assertArgs2 "eq" es
  case e1 of
    (AnyExpr _ (BoolFunc _)) -> mkEqExpr' <$> assertBool e1 <*> assertBool e2
    (AnyExpr _ (IntFunc  _)) -> mkEqExpr' <$> assertInt e1 <*> assertInt e2
    (AnyExpr _ (UintFunc _)) -> mkEqExpr' <$> assertUint e1 <*> assertUint e2
    (AnyExpr _ (DoubleFunc _)) ->
      mkEqExpr' <$> assertDouble e1 <*> assertDouble e2
    (AnyExpr _ (StringFunc _)) ->
      mkEqExpr' <$> assertString e1 <*> assertString e2
    (AnyExpr _ (BytesFunc _)) ->
      mkEqExpr' <$> assertBytes e1 <*> assertBytes e2
    -- Without this branch an unsupported type (e.g. a list) would crash at runtime.
    _ ->
      Left
        "eq: unsupported argument type (expected bool, int, uint, double, string or bytes)"

-- |
-- mkEqExpr' builds an eq expression over two operands of the same type and
-- wraps the boolean result with 'mkBoolExpr'.
mkEqExpr' :: (Eq a) => Expr a -> Expr a -> AnyExpr
mkEqExpr' lhs = mkBoolExpr . eqExpr lhs

-- |
-- eqExpr creates an eq (equal) expression. Both operands are evaluated on the same input.
-- The result is True only when both evaluations succeed and the values are equal.
-- If either operand evaluates to an error, the result is False rather than an error.
-- The constructed expression is passed through 'trimBool'.
eqExpr :: (Eq a) => Expr a -> Expr a -> Expr Bool
eqExpr a b = trimBool Expr { desc = eqDesc, eval = evalEq }
  where
    eqDesc = mkDesc "eq" [desc a, desc b]
    evalEq input = eq (eval a input) (eval b input)

-- |
-- eq compares two evaluation results with (==).
-- If either result is an error, the error is discarded and False is returned, so
-- this never produces 'Left'. The second argument is not forced when the first is
-- an error.
eq :: (Eq a) => Either String a -> Either String a -> Either String Bool
eq lhs rhs = Right . either (const False) id $ (==) <$> lhs <*> rhs

-- |
-- mkNeExpr dynamically creates a ne (not equal) expression from a list of arguments.
-- The list is first checked by 'assertArgs2'. The type of the first argument selects
-- the comparison type, and the second argument is checked against that same type by
-- the matching assert helper.
-- Supported types are bool, int, uint, double, string and bytes. Any other argument
-- type produces a 'Left' error instead of a pattern-match failure.
mkNeExpr :: [AnyExpr] -> Either String AnyExpr
mkNeExpr es = do
  (e1, e2) <- assertArgs2 "ne" es
  case e1 of
    (AnyExpr _ (BoolFunc _)) -> mkNeExpr' <$> assertBool e1 <*> assertBool e2
    (AnyExpr _ (IntFunc  _)) -> mkNeExpr' <$> assertInt e1 <*> assertInt e2
    (AnyExpr _ (UintFunc _)) -> mkNeExpr' <$> assertUint e1 <*> assertUint e2
    (AnyExpr _ (DoubleFunc _)) ->
      mkNeExpr' <$> assertDouble e1 <*> assertDouble e2
    (AnyExpr _ (StringFunc _)) ->
      mkNeExpr' <$> assertString e1 <*> assertString e2
    (AnyExpr _ (BytesFunc _)) ->
      mkNeExpr' <$> assertBytes e1 <*> assertBytes e2
    -- Without this branch an unsupported type (e.g. a list) would crash at runtime.
    _ ->
      Left
        "ne: unsupported argument type (expected bool, int, uint, double, string or bytes)"

-- |
-- mkNeExpr' builds a ne expression over two operands of the same type and
-- wraps the boolean result with 'mkBoolExpr'.
mkNeExpr' :: (Eq a) => Expr a -> Expr a -> AnyExpr
mkNeExpr' lhs = mkBoolExpr . neExpr lhs

-- |
-- neExpr creates a ne (not equal) expression. Both operands are evaluated on the same input.
-- The result is True only when both evaluations succeed and the values differ.
-- If either operand evaluates to an error, the result is False (not True, and not an error).
-- This is therefore not simply the negation of 'eqExpr'.
-- The constructed expression is passed through 'trimBool'.
neExpr :: (Eq a) => Expr a -> Expr a -> Expr Bool
neExpr a b = trimBool Expr { desc = neDesc, eval = evalNe }
  where
    neDesc = mkDesc "ne" [desc a, desc b]
    evalNe input = ne (eval a input) (eval b input)

-- |
-- ne compares two evaluation results with (/=).
-- If either result is an error, the error is discarded and False is returned, so
-- this never produces 'Left'. The second argument is not forced when the first is
-- an error.
ne :: (Eq a) => Either String a -> Either String a -> Either String Bool
ne lhs rhs = Right . either (const False) id $ (/=) <$> lhs <*> rhs

-- |
-- mkGeExpr dynamically creates a ge (greater than or equal) expression from a list of arguments.
-- The list is first checked by 'assertArgs2'. The type of the first argument selects
-- the comparison type, and the second argument is checked against that same type by
-- the matching assert helper.
-- Supported types are int, uint, double and bytes. Any other argument type
-- (e.g. bool or string) produces a 'Left' error instead of a pattern-match failure.
mkGeExpr :: [AnyExpr] -> Either String AnyExpr
mkGeExpr es = do
  (e1, e2) <- assertArgs2 "ge" es
  case e1 of
    (AnyExpr _ (IntFunc  _)) -> mkGeExpr' <$> assertInt e1 <*> assertInt e2
    (AnyExpr _ (UintFunc _)) -> mkGeExpr' <$> assertUint e1 <*> assertUint e2
    (AnyExpr _ (DoubleFunc _)) ->
      mkGeExpr' <$> assertDouble e1 <*> assertDouble e2
    (AnyExpr _ (BytesFunc _)) ->
      mkGeExpr' <$> assertBytes e1 <*> assertBytes e2
    -- Without this branch e.g. a string argument would crash at runtime.
    _ ->
      Left "ge: unsupported argument type (expected int, uint, double or bytes)"

-- |
-- mkGeExpr' builds a ge expression over two operands of the same type and
-- wraps the boolean result with 'mkBoolExpr'.
mkGeExpr' :: (Ord a) => Expr a -> Expr a -> AnyExpr
mkGeExpr' lhs = mkBoolExpr . geExpr lhs

-- |
-- geExpr creates a ge (greater than or equal) expression. Both operands are evaluated on the same input.
-- The result is True only when both evaluations succeed and the first value is
-- greater than or equal to the second.
-- If either operand evaluates to an error, the result is False rather than an error.
-- The constructed expression is passed through 'trimBool'.
geExpr :: (Ord a) => Expr a -> Expr a -> Expr Bool
geExpr a b = trimBool Expr { desc = geDesc, eval = evalGe }
  where
    geDesc = mkDesc "ge" [desc a, desc b]
    evalGe input = ge (eval a input) (eval b input)

-- |
-- ge compares two evaluation results with (>=).
-- If either result is an error, the error is discarded and False is returned, so
-- this never produces 'Left'. The second argument is not forced when the first is
-- an error.
ge :: (Ord a) => Either String a -> Either String a -> Either String Bool
ge lhs rhs = Right . either (const False) id $ (>=) <$> lhs <*> rhs

-- |
-- mkGtExpr dynamically creates a gt (greater than) expression from a list of arguments.
-- The list is first checked by 'assertArgs2'. The type of the first argument selects
-- the comparison type, and the second argument is checked against that same type by
-- the matching assert helper.
-- Supported types are int, uint, double and bytes. Any other argument type
-- (e.g. bool or string) produces a 'Left' error instead of a pattern-match failure.
mkGtExpr :: [AnyExpr] -> Either String AnyExpr
mkGtExpr es = do
  (e1, e2) <- assertArgs2 "gt" es
  case e1 of
    (AnyExpr _ (IntFunc  _)) -> mkGtExpr' <$> assertInt e1 <*> assertInt e2
    (AnyExpr _ (UintFunc _)) -> mkGtExpr' <$> assertUint e1 <*> assertUint e2
    (AnyExpr _ (DoubleFunc _)) ->
      mkGtExpr' <$> assertDouble e1 <*> assertDouble e2
    (AnyExpr _ (BytesFunc _)) ->
      mkGtExpr' <$> assertBytes e1 <*> assertBytes e2
    -- Without this branch e.g. a string argument would crash at runtime.
    _ ->
      Left "gt: unsupported argument type (expected int, uint, double or bytes)"

-- |
-- mkGtExpr' builds a gt expression over two operands of the same type and
-- wraps the boolean result with 'mkBoolExpr'.
mkGtExpr' :: (Ord a) => Expr a -> Expr a -> AnyExpr
mkGtExpr' lhs = mkBoolExpr . gtExpr lhs

-- |
-- gtExpr creates a gt (greater than) expression. Both operands are evaluated on the same input.
-- The result is True only when both evaluations succeed and the first value is
-- strictly greater than the second.
-- If either operand evaluates to an error, the result is False rather than an error.
-- The constructed expression is passed through 'trimBool'.
gtExpr :: (Ord a) => Expr a -> Expr a -> Expr Bool
gtExpr a b = trimBool Expr { desc = gtDesc, eval = evalGt }
  where
    gtDesc = mkDesc "gt" [desc a, desc b]
    evalGt input = gt (eval a input) (eval b input)

-- |
-- gt compares two evaluation results with (>).
-- If either result is an error, the error is discarded and False is returned, so
-- this never produces 'Left'. The second argument is not forced when the first is
-- an error.
gt :: (Ord a) => Either String a -> Either String a -> Either String Bool
gt lhs rhs = Right . either (const False) id $ (>) <$> lhs <*> rhs

-- |
-- mkLeExpr dynamically creates a le (less than or equal) expression from a list of arguments.
-- The list is first checked by 'assertArgs2'. The type of the first argument selects
-- the comparison type, and the second argument is checked against that same type by
-- the matching assert helper.
-- Supported types are int, uint, double and bytes. Any other argument type
-- (e.g. bool or string) produces a 'Left' error instead of a pattern-match failure.
mkLeExpr :: [AnyExpr] -> Either String AnyExpr
mkLeExpr es = do
  (e1, e2) <- assertArgs2 "le" es
  case e1 of
    (AnyExpr _ (IntFunc  _)) -> mkLeExpr' <$> assertInt e1 <*> assertInt e2
    (AnyExpr _ (UintFunc _)) -> mkLeExpr' <$> assertUint e1 <*> assertUint e2
    (AnyExpr _ (DoubleFunc _)) ->
      mkLeExpr' <$> assertDouble e1 <*> assertDouble e2
    (AnyExpr _ (BytesFunc _)) ->
      mkLeExpr' <$> assertBytes e1 <*> assertBytes e2
    -- Without this branch e.g. a string argument would crash at runtime.
    _ ->
      Left "le: unsupported argument type (expected int, uint, double or bytes)"

-- |
-- mkLeExpr' builds a le expression over two operands of the same type and
-- wraps the boolean result with 'mkBoolExpr'.
mkLeExpr' :: (Ord a) => Expr a -> Expr a -> AnyExpr
mkLeExpr' lhs = mkBoolExpr . leExpr lhs

-- |
-- leExpr creates a le (less than or equal) expression. Both operands are evaluated on the same input.
-- The result is True only when both evaluations succeed and the first value is
-- less than or equal to the second.
-- If either operand evaluates to an error, the result is False rather than an error.
-- The constructed expression is passed through 'trimBool'.
leExpr :: (Ord a) => Expr a -> Expr a -> Expr Bool
leExpr a b = trimBool Expr { desc = leDesc, eval = evalLe }
  where
    leDesc = mkDesc "le" [desc a, desc b]
    evalLe input = le (eval a input) (eval b input)

-- |
-- le compares two evaluation results with (<=).
-- If either result is an error, the error is discarded and False is returned, so
-- this never produces 'Left'. The second argument is not forced when the first is
-- an error.
le :: (Ord a) => Either String a -> Either String a -> Either String Bool
le lhs rhs = Right . either (const False) id $ (<=) <$> lhs <*> rhs

-- |
-- mkLtExpr dynamically creates a lt (less than) expression from a list of arguments.
-- The list is first checked by 'assertArgs2'. The type of the first argument selects
-- the comparison type, and the second argument is checked against that same type by
-- the matching assert helper.
-- Supported types are int, uint, double and bytes. Any other argument type
-- (e.g. bool or string) produces a 'Left' error instead of a pattern-match failure.
mkLtExpr :: [AnyExpr] -> Either String AnyExpr
mkLtExpr es = do
  (e1, e2) <- assertArgs2 "lt" es
  case e1 of
    (AnyExpr _ (IntFunc  _)) -> mkLtExpr' <$> assertInt e1 <*> assertInt e2
    (AnyExpr _ (UintFunc _)) -> mkLtExpr' <$> assertUint e1 <*> assertUint e2
    (AnyExpr _ (DoubleFunc _)) ->
      mkLtExpr' <$> assertDouble e1 <*> assertDouble e2
    (AnyExpr _ (BytesFunc _)) ->
      mkLtExpr' <$> assertBytes e1 <*> assertBytes e2
    -- Without this branch e.g. a string argument would crash at runtime.
    _ ->
      Left "lt: unsupported argument type (expected int, uint, double or bytes)"

-- |
-- mkLtExpr' builds a lt expression over two operands of the same type and
-- wraps the boolean result with 'mkBoolExpr'.
mkLtExpr' :: (Ord a) => Expr a -> Expr a -> AnyExpr
mkLtExpr' lhs = mkBoolExpr . ltExpr lhs

-- |
-- ltExpr creates a lt (less than) expression. Both operands are evaluated on the same input.
-- The result is True only when both evaluations succeed and the first value is
-- strictly less than the second.
-- If either operand evaluates to an error, the result is False rather than an error.
-- The constructed expression is passed through 'trimBool'.
ltExpr :: (Ord a) => Expr a -> Expr a -> Expr Bool
ltExpr a b = trimBool Expr { desc = ltDesc, eval = evalLt }
  where
    ltDesc = mkDesc "lt" [desc a, desc b]
    evalLt input = lt (eval a input) (eval b input)

-- |
-- lt compares two evaluation results with (<).
-- If either result is an error, the error is discarded and False is returned, so
-- this never produces 'Left'. The second argument is not forced when the first is
-- an error.
lt :: (Ord a) => Either String a -> Either String a -> Either String Bool
lt lhs rhs = Right . either (const False) id $ (<) <$> lhs <*> rhs