{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE QuasiQuotes           #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# OPTIONS_GHC -Wno-incomplete-patterns #-}
{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}
module Control.Monad.Foil.TH.MkFoilData where

import           Language.Haskell.TH
import Language.Haskell.TH.Syntax (addModFinalizer)

import qualified Control.Monad.Foil.Internal as Foil
import Control.Monad.Foil.TH.Util

-- | Generate scope-safe variants given names of types for the raw representation.
--
-- For raw types @Term@ and @ScopedTerm@ this emits @data@ declarations whose
-- names are the raw names prefixed with @Foil@. Each keeps the raw type
-- parameters and adds one extra trailing scope parameter of kind 'Foil.S'.
-- The scope-safe pattern type is produced by 'mkFoilPattern' and appended to
-- the result. Every generated constructor is the raw constructor name prefixed
-- with @Foil@ and gets a Haddock note pointing back to its raw counterpart.
--
-- The raw term and scoped-term types must be @data@ or @newtype@ declarations
-- whose constructors use prefix syntax (plain or record; record field names
-- are dropped, since reusing them would clash with the raw selectors).
-- Anything else is reported as a compile-time error at the splice site
-- instead of an opaque pattern-match failure.
mkFoilData
  :: Name -- ^ Type name for raw terms.
  -> Name -- ^ Type name for raw variable identifiers.
  -> Name -- ^ Type name for raw scoped terms.
  -> Name -- ^ Type name for raw patterns.
  -> Q [Dec]
mkFoilData termT nameT scopeT patternT = do
  n <- newName "n"
  l <- newName "l"
  (scopeTVars, scopeCons) <- reifyDataType scopeT
  (termTVars, termCons) <- reifyDataType termT

  foilScopeCons <- mapM (toScopeCon n) scopeCons
  foilTermCons <- mapM (toTermCon termTVars n l) termCons

  patternD <- mkFoilPattern nameT patternT
  addModFinalizer $ putDoc (DeclDoc foilTermT)
    ("/Generated/ with '" ++ show 'mkFoilData ++ "'. A scope-safe version of '" ++ show termT ++ "'.")
  addModFinalizer $ putDoc (DeclDoc foilScopeT)
    ("/Generated/ with '" ++ show 'mkFoilData ++ "'. A scope-safe version of '" ++ show scopeT ++ "'.")
  -- Both generated types are indexed by the same trailing scope parameter.
  let scopeTV = KindedTV n BndrReq (PromotedT ''Foil.S)
  return $
    [ DataD [] foilTermT (termTVars ++ [scopeTV]) Nothing foilTermCons []
    , DataD [] foilScopeT (scopeTVars ++ [scopeTV]) Nothing foilScopeCons []
    ] ++ patternD
  where
    foilTermT, foilScopeT, foilPatternT :: Name
    foilTermT = mkName ("Foil" ++ nameBase termT)
    foilScopeT = mkName ("Foil" ++ nameBase scopeT)
    foilPatternT = mkName ("Foil" ++ nameBase patternT)

    -- | Look up a raw @data@/@newtype@ declaration and return its type
    -- parameters and constructors; fail with a readable message otherwise.
    reifyDataType :: Name -> Q ([TyVarBndr BndrVis], [Con])
    reifyDataType typeName = do
      info <- reify typeName
      case info of
        TyConI (DataD _ctx _name tvars _kind cons _deriv) -> return (tvars, cons)
        TyConI (NewtypeD _ctx _name tvars _kind con _deriv) -> return (tvars, [con])
        _ -> fail ("mkFoilData: expected '" ++ show typeName ++ "' to be a data or newtype declaration")

    -- | Extract the name and positional fields of a prefix constructor
    -- (plain or record syntax; record field names are dropped).
    prefixConFields :: Con -> Q (Name, [BangType])
    prefixConFields (NormalC conName params) = return (conName, params)
    prefixConFields (RecC conName fields) =
      return (conName, [ (bang_, type_) | (_field, bang_, type_) <- fields ])
    prefixConFields con =
      fail ("mkFoilData: unsupported constructor (expected prefix or record syntax): " ++ pprint con)

    -- | Name of the scope-safe constructor generated for a raw constructor.
    foilConNameFor :: Name -> Name
    foilConNameFor conName = mkName ("Foil" ++ nameBase conName)

    -- | Attach a Haddock note to a generated constructor pointing to its raw counterpart.
    docFoilCon :: Name -> Name -> Q ()
    docFoilCon foilConName conName = addModFinalizer $
      putDoc (DeclDoc foilConName) ("Corresponds to '" ++ show conName ++ "'.")

    -- | Convert a constructor declaration for a raw scoped term
    -- into a constructor for the scope-safe scoped term.
    -- Fields of the raw term type are re-indexed by the scope @n@;
    -- all other fields are kept unchanged.
    toScopeCon :: Name -> Con -> Q Con
    toScopeCon n con = do
      (conName, params) <- prefixConFields con
      let foilConName = foilConNameFor conName
      docFoilCon foilConName conName
      return (NormalC foilConName (map toScopeParam params))
      where
        toScopeParam :: BangType -> BangType
        toScopeParam (bang_, PeelConT tyName tyParams)
          | tyName == termT = (bang_, PeelConT foilTermT (tyParams ++ [VarT n]))
        toScopeParam bangType = bangType

    -- | Convert a constructor declaration for a raw term
    -- into a (GADT-syntax) constructor for the scope-safe term.
    --
    -- @n@ is the scope of the term itself and @l@ is the scope extended by
    -- a pattern. Pattern fields go from @n@ to @l@, and scoped-term fields
    -- live in @l@. Because @l@ does not occur in the result type, it acts as
    -- an existential scope of the constructor.
    -- NOTE(review): all binders of one constructor share this single @l@.
    toTermCon :: [TyVarBndr BndrVis] -> Name -> Name -> Con -> Q Con
    toTermCon tvars n l con = do
      (conName, params) <- prefixConFields con
      let foilConName = foilConNameFor conName
          resultType = PeelConT foilTermT (map (VarT . tvarName) tvars ++ [VarT n])
      docFoilCon foilConName conName
      return (GadtC [foilConName] (map toTermParam params) resultType)
      where
        toTermParam :: BangType -> BangType
        toTermParam (bang_, PeelConT tyName tyParams)
          | tyName == patternT = (bang_, PeelConT foilPatternT (tyParams ++ [VarT n, VarT l]))
          | tyName == nameT    = (bang_, AppT (ConT ''Foil.Name) (VarT n))
          | tyName == scopeT   = (bang_, PeelConT foilScopeT (tyParams ++ [VarT l]))
          | tyName == termT    = (bang_, PeelConT foilTermT (tyParams ++ [VarT n]))
        toTermParam bangType = bangType

-- | Generate just the scope-safe patterns.
--
-- For a raw pattern type @Pattern@ this emits a @data@ declaration named
-- @Foil@ ++ the raw name. It keeps the raw type parameters and adds two
-- trailing parameters of kind 'Foil.S': the scope before the pattern and
-- the scope extended by all of its binders. Every generated constructor is
-- the raw constructor name prefixed with @Foil@ and gets a Haddock note
-- pointing back to its raw counterpart.
--
-- The raw pattern type must be a @data@ or @newtype@ declaration whose
-- constructors use prefix syntax (plain or record; record field names are
-- dropped). Anything else is reported as a compile-time error at the splice
-- site instead of an opaque pattern-match failure.
mkFoilPattern
  :: Name -- ^ Type name for raw variable identifiers.
  -> Name -- ^ Type name for raw patterns.
  -> Q [Dec]
mkFoilPattern nameT patternT = do
  n <- newName "n"
  l <- newName "l"
  (patternTVars, patternCons) <- reifyDataType patternT

  foilPatternCons <- mapM (toPatternCon patternTVars n) patternCons

  addModFinalizer $ putDoc (DeclDoc foilPatternT)
    ("/Generated/ with '" ++ show 'mkFoilPattern ++ "'. A scope-safe version of '" ++ show patternT ++ "'.")
  return
    [ DataD [] foilPatternT (patternTVars ++ [scopeTV n, scopeTV l]) Nothing foilPatternCons []
    ]
  where
    foilPatternT :: Name
    foilPatternT = mkName ("Foil" ++ nameBase patternT)

    -- | A required type parameter of kind 'Foil.S'.
    scopeTV :: Name -> TyVarBndr BndrVis
    scopeTV name = KindedTV name BndrReq (PromotedT ''Foil.S)

    -- | Look up a raw @data@/@newtype@ declaration and return its type
    -- parameters and constructors; fail with a readable message otherwise.
    reifyDataType :: Name -> Q ([TyVarBndr BndrVis], [Con])
    reifyDataType typeName = do
      info <- reify typeName
      case info of
        TyConI (DataD _ctx _name tvars _kind cons _deriv) -> return (tvars, cons)
        TyConI (NewtypeD _ctx _name tvars _kind con _deriv) -> return (tvars, [con])
        _ -> fail ("mkFoilPattern: expected '" ++ show typeName ++ "' to be a data or newtype declaration")

    -- | Extract the name and positional fields of a prefix constructor
    -- (plain or record syntax; record field names are dropped).
    prefixConFields :: Con -> Q (Name, [BangType])
    prefixConFields (NormalC conName params) = return (conName, params)
    prefixConFields (RecC conName fields) =
      return (conName, [ (bang_, type_) | (_field, bang_, type_) <- fields ])
    prefixConFields con =
      fail ("mkFoilPattern: unsupported constructor (expected prefix or record syntax): " ++ pprint con)

    -- | Convert a constructor declaration for a raw pattern type
    -- into a constructor for the scope-safe pattern type.
    toPatternCon
      :: [TyVarBndr BndrVis]
      -> Name   -- ^ Name for the starting scope type variable.
      -> Con    -- ^ Raw pattern constructor.
      -> Q Con
    toPatternCon tvars n con = do
      (conName, params) <- prefixConFields con
      (lastScopeName, foilParams) <- toPatternConParams 1 n params
      let foilConName = mkName ("Foil" ++ nameBase conName)
          resultType = PeelConT foilPatternT
            (map (VarT . tvarName) tvars ++ [VarT n, VarT lastScopeName])
      addModFinalizer $ putDoc (DeclDoc foilConName) ("Corresponds to '" ++ show conName ++ "'.")
      return (GadtC [foilConName] foilParams resultType)

    -- | Process type parameters of a pattern,
    -- introducing (existential) type variables for the intermediate scopes,
    -- if necessary.
    toPatternConParams
      :: Int                  -- ^ Index of the component in the constructor.
      -> Name                 -- ^ Current scope (after processing any previous bindings).
      -> [BangType]           -- ^ Leftover pattern components.
      -> Q (Name, [BangType]) -- ^ Resulting extended scope and a list of corresponding scope-safe components.
    toPatternConParams _ p [] = return (p, [])
    toPatternConParams i p ((bang_, type_) : conParams) =
      case scopeSafeBinder type_ of
        -- a binding component extends the current scope @p@
        -- to a fresh intermediate scope named after the component index
        Just toBinder -> do
          p' <- newName ("n" <> show i)
          (lastScope, conParams') <- toPatternConParams (i + 1) p' conParams
          return (lastScope, (bang_, toBinder p p') : conParams')
        -- otherwise, keep the component unchanged as non-binding
        Nothing -> do
          (lastScope, conParams') <- toPatternConParams (i + 1) p conParams
          return (lastScope, (bang_, type_) : conParams')

    -- | Scope-safe replacement for a binding component, given the scope
    -- before and after the binder; 'Nothing' for non-binding components.
    scopeSafeBinder :: Type -> Maybe (Name -> Name -> Type)
    scopeSafeBinder (PeelConT tyName tyParams)
      -- a variable identifier becomes a single name binder (see 'Foil.NameBinder')
      | tyName == nameT =
          Just (\from to -> AppT (AppT (ConT ''Foil.NameBinder) (VarT from)) (VarT to))
      -- a raw pattern becomes a nested scope-safe pattern
      | tyName == patternT =
          Just (\from to -> PeelConT foilPatternT (tyParams ++ [VarT from, VarT to]))
    scopeSafeBinder _ = Nothing