{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE DataKinds         #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs             #-}
{-# LANGUAGE KindSignatures    #-}
{-# LANGUAGE LambdaCase        #-}
{-# LANGUAGE PatternSynonyms   #-}
{-# LANGUAGE RankNTypes        #-}
{-# LANGUAGE TemplateHaskell   #-}
-- | Free foil implementation of the \(\lambda\Pi\)-calculus (with pairs).
--
-- Free foil provides __general__ definitions or implementations for the following:
--
-- 1. Freely generated (from a simple signature) scope-safe AST.
-- 2. Correct capture-avoiding substitution (see 'substitute').
-- 3. Correct \(\alpha\)-equivalence checks (see 'alphaEquiv' and 'alphaEquivRefreshed') as well as \(\alpha\)-normalization (see 'refreshAST').
-- 4. Conversion helpers (see 'convertToAST' and 'convertFromAST').
--
-- The following is __generated__ using Template Haskell:
--
-- 1. Convenient pattern synonyms.
-- 2. 'ZipMatch' instances (enabling general \(\alpha\)-equivalence).
-- 3. Conversion between scope-safe and raw term representation.
--
-- The following is implemented __manually__ in this module:
--
-- 1. Computation of weak head normal form (WHNF), see 'whnf'.
-- 2. Entry point, gluing everything together. See 'defaultMain'.
--
-- __Note:__ pattern matching during evaluation (see 'matchPattern') handles
-- wildcard patterns, variable patterns, and pair patterns.
module Language.LambdaPi.Impl.FreeFoilTH where

import qualified Control.Monad.Foil              as Foil
import           Control.Monad.Foil.TH
import           Control.Monad.Free.Foil
import           Control.Monad.Free.Foil.TH
import           Data.Bifunctor.TH
import           Data.Map                        (Map)
import qualified Data.Map                        as Map
import           Data.String                     (IsString (..))
import qualified Language.LambdaPi.Syntax.Abs    as Raw
import qualified Language.LambdaPi.Syntax.Layout as Raw
import qualified Language.LambdaPi.Syntax.Lex    as Raw
import qualified Language.LambdaPi.Syntax.Par    as Raw
import qualified Language.LambdaPi.Syntax.Print  as Raw
import           System.Exit                     (exitFailure)

-- $setup
-- >>> :set -XOverloadedStrings
-- >>> :set -XDataKinds
-- >>> import qualified Control.Monad.Foil as Foil
-- >>> import Control.Monad.Free.Foil
-- >>> import Data.String (fromString)

-- * Generated code

-- ** Signature
-- Build the signature type 'Term'Sig' from the raw BNFC syntax
-- (terms, identifiers, scoped terms and patterns), then derive the
-- instances used by the free foil machinery: 'ZipMatch' (enabling general
-- alpha-equivalence) and the bifunctor / bifoldable / bitraversable structure.
mkSignature ''Raw.Term' ''Raw.VarIdent ''Raw.ScopedTerm' ''Raw.Pattern'
deriveZipMatch ''Term'Sig
deriveBifunctor ''Term'Sig
deriveBifoldable ''Term'Sig
deriveBitraversable ''Term'Sig

-- ** Pattern synonyms
-- Generate pattern synonyms over the scope-safe AST; presumably these are
-- the 'App', 'Lam', 'Pair', 'First' and 'Second' synonyms matched in 'whnf'.
mkPatternSynonyms ''Term'Sig

-- ** Conversion helpers

-- Generate conversions between the raw syntax and 'Term'Sig'. Presumably
-- these provide 'convertToTerm'Sig' and 'getTerm'FromScopedTerm'' (used by
-- 'toTerm'') and 'convertFromTerm'Sig' (used by 'fromTerm'').
mkConvertToFreeFoil ''Raw.Term' ''Raw.VarIdent ''Raw.ScopedTerm' ''Raw.Pattern'
mkConvertFromFreeFoil ''Raw.Term' ''Raw.VarIdent ''Raw.ScopedTerm' ''Raw.Pattern'

-- ** Scope-safe patterns

-- Generate the scope-safe pattern type 'FoilPattern'' from 'Raw.Pattern'',
-- its 'CoSinkable' instance, and conversions to and from raw patterns
-- (presumably 'toFoilPattern'' and 'fromFoilPattern'', used further below).
mkFoilPattern ''Raw.VarIdent ''Raw.Pattern'
deriveCoSinkable ''Raw.VarIdent ''Raw.Pattern'
mkToFoilPattern ''Raw.VarIdent ''Raw.Pattern'
mkFromFoilPattern ''Raw.VarIdent ''Raw.Pattern'

-- | Ignoring location information when unifying patterns:
-- any two source positions are treated as compatible, so patterns that
-- differ only in their position annotations can still be unified.
instance Foil.UnifiableInPattern Raw.BNFC'Position where
  unifyInPattern _ _ = True

-- Derive pattern unification for the generated scope-safe patterns; the
-- position annotations inside them are handled by the instance above.
deriveUnifiablePattern ''Raw.VarIdent ''Raw.Pattern'

-- * User-defined code

-- | Generic annotated scope-safe \(\lambda\Pi\)-terms with patterns.
-- The annotation type is @a@; the remaining parameter is the scope
-- (e.g. @Term' a Foil.VoidS@ for closed terms).
type Term' a = AST (FoilPattern' a) (Term'Sig a)

-- | Scope-safe \(\lambda\Pi\)-terms annotated with source code position.
type Term = Term' Raw.BNFC'Position

-- | Scope-safe patterns annotated with source code position.
type FoilPattern = FoilPattern' Raw.BNFC'Position

-- ** Conversion helpers

-- | Convert 'Raw.Term'' into a scope-safe term.
-- This is a special case of 'convertToAST', instantiated with the
-- Template Haskell-generated helpers for this language.
--
-- The map presumably resolves raw identifiers of free variables to names
-- in scope @n@; how identifiers missing from the map are treated is
-- decided by 'convertToAST'.
toTerm' :: Foil.Distinct n => Foil.Scope n -> Map Raw.VarIdent (Foil.Name n) -> Raw.Term' a -> Term' a n
toTerm' = convertToAST convertToTerm'Sig toFoilPattern' getTerm'FromScopedTerm'

-- | Convert 'Raw.Term'' into a closed scope-safe term.
-- This is a special case of 'toTerm'', starting from the empty scope with
-- no identifiers in scope.
toTerm'Closed :: Raw.Term' a -> Term' a Foil.VoidS
toTerm'Closed = toTerm' Foil.emptyScope Map.empty

-- | Convert a scope-safe representation back into 'Raw.Term''.
-- This is a special case of 'convertFromAST'.
--
-- 'Raw.VarIdent' names are generated as @x\<n\>@ from the 'Int' identifiers
-- supplied by 'convertFromAST' (based on the raw identifiers in the
-- underlying foil representation).
--
-- This function does not recover location information for variables, patterns,
-- or scoped terms. In particular, the location fields of the generated
-- 'Raw.Var' and 'Raw.AScopedTerm' nodes are 'error' thunks: pretty-printing
-- is fine, but forcing those fields fails with @"location missing"@.
fromTerm' :: Term' a n -> Raw.Term' a
fromTerm' = convertFromAST
  convertFromTerm'Sig
  (Raw.Var (error "location missing"))
  fromFoilPattern'
  (Raw.AScopedTerm (error "location missing"))
  (\n -> Raw.VarIdent ("x" ++ show n))

-- | Parse scope-safe terms via raw representation.
--
-- Parsing is partial: malformed input raises an 'error' whose message
-- contains the offending input and the parser's error message.
--
-- >>> fromString "λx.λy.λx.x" :: Term Foil.VoidS
-- λ x0 . λ x1 . λ x2 . x2
instance IsString (AST FoilPattern (Term'Sig Raw.BNFC'Position) Foil.VoidS) where
  fromString input = case Raw.pTerm (Raw.tokens input) of
    Left err   -> error ("could not parse λΠ-term: " <> input <> "\n  " <> err)
    Right term -> toTerm'Closed term

-- | Pretty-print closed scope-safe terms via raw representation:
-- the term is converted back with 'fromTerm'' and printed with
-- 'Raw.printTree'.
instance Show (AST (FoilPattern' a) (Term'Sig a) Foil.VoidS) where
  show = Raw.printTree . fromTerm'

-- ** Evaluation

-- | Match a pattern against a term, producing a substitution for the
-- variables bound by the pattern.
--
-- * A wildcard binds nothing.
-- * A variable is bound to the whole term.
-- * A pair pattern matches its left sub-pattern against @First loc term@
--   and its right sub-pattern against @Second loc term@. The term therefore
--   does not need to be a literal pair; the projections can later be
--   reduced by 'whnf'.
matchPattern :: FoilPattern n l -> Term n -> Foil.Substitution Term l n
matchPattern pat term = go pat term Foil.identitySubst
  where
    -- Extend a substitution defined on scope @i@ with the bindings
    -- introduced by a (sub)pattern that extends @i@ to @l@.
    go :: FoilPattern i l -> Term n -> Foil.Substitution Term i n -> Foil.Substitution Term l n
    go (FoilPatternWildcard _loc) _ = id
    go (FoilPatternVar _loc x) e    = \subst -> Foil.addSubst subst x e
    -- The left component is bound first (i -> n2), then the right (n2 -> l).
    go (FoilPatternPair loc l r) e  = go r (Second loc e) . go l (First loc e)

-- | Compute weak head normal form (WHNF) of a \(\lambda\Pi\)-term.
--
-- Only the head of the term is reduced:
--
-- * β-redexes in head position are contracted.
-- * 'First' and 'Second' projections of pairs are reduced.
--
-- Arguments of stuck applications and bodies of λ-abstractions are left untouched.
--
-- >>> whnf Foil.emptyScope "(λx.(λ_.x)(λy.x))(λ(y,z).z)"
-- λ (x0, x1) . x1
--
-- >>> whnf Foil.emptyScope "(λs.λz.s(s(z)))(λs.λz.s(s(z)))"
-- λ x1 . (λ x0 . λ x1 . x0 (x0 x1)) ((λ x0 . λ x1 . x0 (x0 x1)) x1)
--
-- Note that during computation bound variables can become unordered
-- in the sense that binders may easily repeat or decrease. For example,
-- in the following expression, the inner binder has a lower index than the outer one:
--
-- >>> whnf Foil.emptyScope "(λx.λy.x)(λx.x)"
-- λ x1 . λ x0 . x0
--
-- At the same time, without substitution, we get regular, increasing binder indices:
--
-- >>> "λx.λy.y" :: Term Foil.VoidS
-- λ x0 . λ x1 . x1
--
-- To compare terms for \(\alpha\)-equivalence, we may use 'alphaEquiv':
--
-- >>> alphaEquiv Foil.emptyScope (whnf Foil.emptyScope "(λx.λy.x)(λx.x)") "λx.λy.y"
-- True
--
-- We may also normalize binders using 'refreshAST':
--
-- >>> refreshAST Foil.emptyScope (whnf Foil.emptyScope "(λx.λy.x)(λx.x)")
-- λ x0 . λ x1 . x1
whnf :: Foil.Distinct n => Foil.Scope n -> Term n -> Term n
whnf scope = \case
  -- Reduce the head. If it is a λ, match the argument against the λ's
  -- pattern, substitute into the body and keep reducing.
  App loc f x ->
    case whnf scope f of
      Lam _loc binder body ->
        let subst = matchPattern binder x
         in whnf scope (substitute scope subst body)
      f' -> App loc f' x
  -- Projections reduce only when their argument reduces to a pair.
  First loc t ->
    case whnf scope t of
      Pair _loc l _r -> whnf scope l
      t'             -> First loc t'
  Second loc t ->
    case whnf scope t of
      Pair _loc _l r -> whnf scope r
      t'             -> Second loc t'
  -- All other terms (e.g. variables, λ-abstractions, pairs) are returned unchanged.
  t -> t

-- ** \(\lambda\Pi\)-interpreter

-- | Interpret a \(\lambda\Pi\) command.
--
-- * 'Raw.CommandCompute': the term is converted to a closed scope-safe term,
--   reduced to WHNF and printed. Its type annotation is currently ignored.
-- * 'Raw.CommandCheck': not implemented yet; a notice is printed instead.
interpretCommand :: Raw.Command -> IO ()
interpretCommand (Raw.CommandCompute _loc term _type) =
  putStrLn ("  ↦ " ++ show (whnf Foil.emptyScope (toTerm'Closed term)))
-- #TODO: add typeCheck
interpretCommand (Raw.CommandCheck _loc _term _type) =
  putStrLn "check is not yet implemented"

-- | Interpret a \(\lambda\Pi\) program by interpreting each of its commands
-- in order (see 'interpretCommand').
interpretProgram :: Raw.Program -> IO ()
interpretProgram (Raw.AProgram _loc typedTerms) = mapM_ interpretCommand typedTerms

-- | A \(\lambda\Pi\) interpreter implemented via the free foil.
--
-- Reads the whole program from standard input and parses it, applying
-- 'Raw.resolveLayout' to the token stream. The program is then interpreted
-- command by command.
--
-- On a parse error, the parser's message is printed to standard output and
-- the process exits with 'exitFailure'.
defaultMain :: IO ()
defaultMain = do
  input <- getContents
  case Raw.pProgram (Raw.resolveLayout True (Raw.tokens input)) of
    Left err -> do
      putStrLn err
      exitFailure
    Right program -> interpretProgram program