{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE LambdaCase            #-}
{-# LANGUAGE QuasiQuotes           #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# OPTIONS_GHC -fno-warn-type-defaults #-}
module Control.Monad.Foil.TH.MkInstancesFoil where

import           Language.Haskell.TH

import qualified Control.Monad.Foil         as Foil
import           Control.Monad.Foil.TH.Util
import           Data.List                  (nub)

-- | Generate 'Foil.Sinkable' and 'Foil.CoSinkable' instances.
--
-- Two 'Foil.Sinkable' instances are produced: one for the @Foil@-prefixed
-- counterpart of the raw scoped-term type, and one for that of the raw term
-- type. The declarations produced by 'deriveCoSinkable' for the raw pattern
-- type are appended after them.
mkInstancesFoil
  :: Name -- ^ Type name for raw terms.
  -> Name -- ^ Type name for raw variable identifiers.
  -> Name -- ^ Type name for raw scoped terms.
  -> Name -- ^ Type name for raw patterns.
  -> Q [Dec]
mkInstancesFoil termT nameT scopeT patternT = do
  -- Only plain @data@ declarations are accepted; any other shape makes these
  -- pattern binds fail inside 'Q'.
  TyConI (DataD _ _ scopeBinders _ scopeCons _) <- reify scopeT
  TyConI (DataD _ _ termBinders _ termCons _) <- reify termT
  patternDecs <- deriveCoSinkable nameT patternT
  pure
    ( sinkableInstance scopeT scopeBinders (map clauseScopedTerm scopeCons)
    : sinkableInstance termT termBinders (map clauseTerm termCons)
    : patternDecs
    )
  where
    -- Name of the scope-safe counterpart of a raw type or constructor.
    foilName :: Name -> Name
    foilName raw = mkName ("Foil" ++ nameBase raw)

    -- Generated variable names: field @i@, the field re-bound inside a
    -- pattern continuation, and the renaming handed to that continuation.
    fieldVar, fieldVar', renameVar :: Integer -> Name
    fieldVar i  = mkName ("x" ++ show i)
    fieldVar' i = mkName ("x" ++ show i ++ "'")
    renameVar i = mkName ("rename" ++ show i)

    -- @instance Sinkable (FoilT a1 ... ak)@ whose 'Foil.sinkabilityProof'
    -- consists of the given clauses.
    sinkableInstance :: Name -> [TyVarBndr flag] -> [Clause] -> Dec
    sinkableInstance rawT binders clauses =
      InstanceD Nothing []
        (AppT (ConT ''Foil.Sinkable)
              (PeelConT (foilName rawT) [VarT (tvarName b) | b <- binders]))
        [FunD 'Foil.sinkabilityProof clauses]

    clauseScopedTerm :: Con -> Clause
    clauseScopedTerm = clauseTerm

    -- One 'Foil.sinkabilityProof' equation per constructor. Only ordinary
    -- prefix constructors ('NormalC') are supported.
    clauseTerm :: Con -> Clause
    clauseTerm (NormalC conName params) =
      Clause
        [ VarP renameArg
        , ConP foilCon [] [VarP (fieldVar i) | (i, _) <- zip [1 ..] params]
        ]
        (NormalB (sinkFields 1 (VarE renameArg) (ConE foilCon) params))
        []
      where
        foilCon = foilName conName
        -- Underscore-prefixed, presumably so constructors whose fields never
        -- use the renaming do not trigger an unused-variable warning.
        renameArg = mkName "_rename"

        -- Apply the constructor expression to its fields, left to right.
        -- @ren@ starts as the outer renaming; after a pattern field it is
        -- replaced by the renaming bound in the 'Foil.coSinkabilityProof'
        -- continuation, and only scoped-term fields use it.
        sinkFields :: Integer -> Exp -> Exp -> [BangType] -> Exp
        sinkFields _ _ acc [] = acc
        sinkFields i ren acc ((_, PeelConT tyName _) : rest)
          | tyName == nameT    = advance (AppE (VarE renameArg) xi)
          | tyName == termT    = advance (sinkWith (VarE renameArg))
          | tyName == scopeT   = advance (sinkWith ren)
          | tyName == patternT =
              AppE
                (AppE (AppE (VarE 'Foil.coSinkabilityProof) ren) xi)
                (LamE [VarP (renameVar i), VarP (fieldVar' i)]
                  (sinkFields (i + 1) (VarE (renameVar i)) (AppE acc (VarE (fieldVar' i))) rest))
          where
            xi = VarE (fieldVar i)
            advance arg = sinkFields (i + 1) ren (AppE acc arg) rest
            sinkWith r = AppE (AppE (VarE 'Foil.sinkabilityProof) r) xi
        -- Any other field is passed through untouched.
        sinkFields i ren acc (_ : rest) =
          sinkFields (i + 1) ren (AppE acc (VarE (fieldVar i))) rest
    clauseTerm RecC{}     = error "Record constructors (RecC) are not supported yet!"
    clauseTerm InfixC{}   = error "Infix constructors (InfixC) are not supported yet!"
    clauseTerm ForallC{}  = error "Existential constructors (ForallC) are not supported yet!"
    clauseTerm GadtC{}    = error "GADT constructors (GadtC) are not supported yet!"
    clauseTerm RecGadtC{} = error "Record GADT constructors (RecGadtC) are not supported yet!"

-- | Generate a 'Foil.CoSinkable' instance for the scope-safe (@Foil@-prefixed)
-- counterpart of a raw pattern type.
--
-- Both 'Foil.coSinkabilityProof' and 'Foil.withPattern' get one equation per
-- constructor. Fields whose type is the raw identifier type or the raw
-- pattern type are treated as binders. They are processed left to right, and
-- each one is nested inside the continuation of the previous binder. All other
-- fields are passed through unchanged.
deriveCoSinkable
  :: Name -- ^ Type name for raw variable identifiers.
  -> Name -- ^ Type name for raw patterns.
  -> Q [Dec]
deriveCoSinkable nameT patternT = do
  TyConI (DataD _ctx _name patternTVars _kind patternCons _deriv) <- reify patternT

  return
    [ InstanceD Nothing []
        (AppT (ConT ''Foil.CoSinkable)
              (PeelConT foilPatternT (map (VarT . tvarName) patternTVars)))
        [ FunD 'Foil.coSinkabilityProof (map clausePattern patternCons)
        , FunD 'Foil.withPattern (map clauseWithPattern patternCons) ]
    ]

  where
    foilPatternT = mkName ("Foil" ++ nameBase patternT)

    -- Equation for 'Foil.coSinkabilityProof':
    --   coSinkabilityProof rename (FoilC x1 .. xk) cont = ...
    clausePattern :: Con -> Clause
    clausePattern RecC{} = error "Record constructors (RecC) are not supported yet!"
    clausePattern InfixC{} = error "Infix constructors (InfixC) are not supported yet!"
    clausePattern ForallC{} = error "Existential constructors (ForallC) are not supported yet!"
    clausePattern GadtC{} = error "GADT constructors (GadtC) are not supported yet!"
    clausePattern RecGadtC{} = error "Record GADT constructors (RecGadtC) are not supported yet!"
    clausePattern (NormalC conName params) =
      Clause
        [VarP rename, ConP foilConName [] conParamPatterns, VarP cont]
        (NormalB (go 1 (VarE rename) (ConE foilConName) params))
        []
      where
        foilConName = mkName ("Foil" ++ nameBase conName)
        rename = mkName "rename"
        cont = mkName "cont"
        conParamPatterns = zipWith mkConParamPattern params [1..]
        mkConParamPattern _ i = VarP (mkName ("x" ++ show i))

        -- After the last field, hand the current renaming and the rebuilt
        -- pattern to the continuation.
        go _i rename' p [] = AppE (AppE (VarE cont) rename') p
        go i rename' p ((_bang, PeelConT tyName _tyParams) : conParams)
          | tyName == nameT || tyName == patternT =
              -- Binder field: sink it, and continue inside the continuation
              -- with the renaming and the binder it provides.
              AppE
                (AppE (AppE (VarE 'Foil.coSinkabilityProof) rename') (VarE xi))
                (LamE [VarP renamei, VarP xi']
                  (go (i + 1) (VarE renamei) (AppE p (VarE xi')) conParams))
          where
            xi = mkName ("x" ++ show i)
            xi' = mkName ("x" ++ show i ++ "'")
            renamei = mkName ("rename" ++ show i)
        go i rename' p (_ : conPatterns) =
          go (i + 1) rename' (AppE p (VarE xi)) conPatterns
          where
            xi = mkName ("x" ++ show i)

    -- Equation for 'Foil.withPattern':
    --   withPattern withNameBinder id' comp scope (FoilC x1 .. xk) cont = ...
    clauseWithPattern :: Con -> Clause
    clauseWithPattern RecC{} = error "Record constructors (RecC) are not supported yet!"
    clauseWithPattern InfixC{} = error "Infix constructors (InfixC) are not supported yet!"
    clauseWithPattern ForallC{} = error "Existential constructors (ForallC) are not supported yet!"
    clauseWithPattern GadtC{} = error "GADT constructors (GadtC) are not supported yet!"
    clauseWithPattern RecGadtC{} = error "Record GADT constructors (RecGadtC) are not supported yet!"
    clauseWithPattern (NormalC conName params) =
      Clause
        [ VarP withNameBinder, VarP id', VarP comp, VarP scope
        , ConP foilConName [] conParamPatterns, VarP cont ]
        (NormalB (go 1 scope (VarE id') (ConE foilConName) params))
        []
      where
        foilConName = mkName ("Foil" ++ nameBase conName)
        -- Underscore-prefixed names are only referenced when the constructor
        -- has binder fields, presumably to avoid unused-variable warnings.
        withNameBinder = mkName "_withNameBinder"
        id' = mkName "id'"
        comp = mkName "_comp"
        scope = mkName "_scope"
        cont = mkName "cont"
        conParamPatterns = zipWith mkConParamPattern params [1..]
        mkConParamPattern _ i = VarP (mkName ("x" ++ show i))

        -- @scope'@ is the scope in force for the current field; @rename'@ is
        -- the composition of all results collected so far (starting at @id'@).
        go _i _scope' rename' p [] = AppE (AppE (VarE cont) rename') p
        go i scope' rename' p ((_bang, PeelConT tyName _tyParams) : conParams)
          | tyName == nameT || tyName == patternT =
              AppE
                (foldl AppE (VarE 'Foil.withPattern)
                  [VarE withNameBinder, VarE id', VarE comp, VarE scope', VarE xi])
                (LamE [VarP renamei, VarP xi']
                  -- Extend the scope with this binder before processing the
                  -- remaining fields.
                  (LetE [ValD (VarP scopei)
                               (NormalB (AppE (AppE (VarE 'Foil.extendScopePattern) (VarE xi')) (VarE scope')))
                               []]
                    (go (i + 1) scopei
                        (foldl AppE (VarE comp) [rename', VarE renamei])
                        (AppE p (VarE xi'))
                        conParams)))
          where
            xi = mkName ("x" ++ show i)
            xi' = mkName ("x" ++ show i ++ "'")
            renamei = mkName ("f" ++ show i)
            scopei = mkName ("_scope" ++ show i)
        go i scope' rename' p (_ : conPatterns) =
          go (i + 1) scope' rename' (AppE p (VarE xi)) conPatterns
          where
            xi = mkName ("x" ++ show i)

-- | Generate a 'Foil.UnifiablePattern' instance for the scope-safe
-- (@Foil@-prefixed) counterpart of a raw pattern type.
--
-- For every constructor, 'Foil.unifyPatterns' gets an equation that matches
-- that constructor on both sides:
--
-- * Fields whose type is the raw identifier type or the raw pattern type are
--   binders. Both sides are first matched against 'Foil.assertDistinct'
--   evidence. They are then unified with 'Foil.unifyNameBinders'
--   (identifiers) or 'Foil.unifyPatterns' (nested patterns), chained left to
--   right with 'Foil.andThenUnifyNameBinders' / 'Foil.andThenUnifyPatterns'.
--
-- * Every other field is compared with 'Foil.unifyInPattern', and the
--   equation yields 'Foil.NotUnifiable' on a mismatch. When such a field's
--   type is one of the pattern type's own type variables, a
--   'Foil.UnifiableInPattern' constraint for it is added to the instance
--   context.
--
-- A final catch-all equation returns 'Foil.NotUnifiable' for pairs built with
-- different constructors. It is emitted only when such pairs can exist.
deriveUnifiablePattern
  :: Name -- ^ Type name for raw variable identifiers.
  -> Name -- ^ Type name for raw patterns.
  -> Q [Dec]
deriveUnifiablePattern nameT patternT = do
  TyConI (DataD _ctx _name patternTVars _kind patternCons _deriv) <- reify patternT

  -- The tuple monad accumulates the non-binder field types of all constructors.
  let (eqTypes, clauses) = mapM clauseUnifyPatterns patternCons
      ctx = nub
        [ AppT (ConT ''Foil.UnifiableInPattern) type_
        | type_ <- eqTypes
        , type_ `elem` map (VarT . tvarName) patternTVars
        ]
      -- With exactly one constructor its equation already matches every pair
      -- of arguments, so a catch-all equation would be unreachable (and may be
      -- reported as a redundant equation at the splice site).
      fallbackClauses = case patternCons of
        [_] -> []
        _   -> [notUnifiableClause]
  return
    [ InstanceD Nothing ctx
        (AppT (ConT ''Foil.UnifiablePattern)
              (PeelConT foilPatternT (map (VarT . tvarName) patternTVars)))
        [ FunD 'Foil.unifyPatterns (clauses ++ fallbackClauses) ]
    ]

  where
    foilPatternT = mkName ("Foil" ++ nameBase patternT)

    -- unifyPatterns _ _ = NotUnifiable
    notUnifiableClause :: Clause
    notUnifiableClause = Clause [WildP, WildP] (NormalB (ConE 'Foil.NotUnifiable)) []

    -- Per-constructor equation, paired with the types of its non-binder fields.
    clauseUnifyPatterns :: Con -> ([Type], Clause)
    clauseUnifyPatterns RecC{} = error "Record constructors (RecC) are not supported yet!"
    clauseUnifyPatterns InfixC{} = error "Infix constructors (InfixC) are not supported yet!"
    clauseUnifyPatterns ForallC{} = error "Existential constructors (ForallC) are not supported yet!"
    clauseUnifyPatterns GadtC{} = error "GADT constructors (GadtC) are not supported yet!"
    clauseUnifyPatterns RecGadtC{} = error "Record GADT constructors (RecGadtC) are not supported yet!"
    clauseUnifyPatterns (NormalC conName params) =
      case go 1 [] [] params of
        (body, eqTypes) ->
          ( eqTypes
          , Clause
              [ConP foilConName [] paramsL, ConP foilConName [] paramsR]
              (NormalB body)
              []
          )
      where
        foilConName = mkName ("Foil" ++ nameBase conName)
        paramsL = zipWith (mkConParamPattern "l") params [1..]
        paramsR = zipWith (mkConParamPattern "r") params [1..]
        mkConParamPattern s _ i = VarP (mkName (s ++ show i))

        -- Chain the unification of all binder pairs. The list is in reverse
        -- field order, so the head is the last binder and is unified last.
        mkUnifyAllPairsRev :: [(Bool, Exp, Exp)] -> Exp
        mkUnifyAllPairsRev [] =
          AppE (ConE 'Foil.SameNameBinders) (VarE 'Foil.emptyNameBinders)
        mkUnifyAllPairsRev [(isNameBinder, l, r)]
          | isNameBinder = AppE (AppE (VarE 'Foil.unifyNameBinders) l) r
          | otherwise    = AppE (AppE (VarE 'Foil.unifyPatterns) l) r
        mkUnifyAllPairsRev ((isNameBinder, l, r) : pairs) =
          InfixE
            (Just (mkUnifyAllPairsRev pairs))
            (VarE (if isNameBinder then 'Foil.andThenUnifyNameBinders else 'Foil.andThenUnifyPatterns))
            (Just (TupE [Just l, Just r]))

        -- Walk the fields left to right. Binder pairs are collected (reversed)
        -- for the final chain. Other fields wrap the rest in an equality test,
        -- and their types are reported back.
        go :: Integer -> [Type] -> [(Bool, Exp, Exp)] -> [BangType] -> (Exp, [Type])
        go _i eqTypes pairsRev [] = (mkUnifyAllPairsRev pairsRev, eqTypes)
        go i eqTypes pairsRev ((_bang, PeelConT tyName _tyParams) : conParams)
          | tyName == nameT || tyName == patternT =
            case go (i + 1) eqTypes ((isNameBinder, l, r) : pairsRev) conParams of
              (next, eqTypes') ->
                ( CaseE
                    (TupE [ Just (AppE (VarE 'Foil.assertDistinct) l)
                          , Just (AppE (VarE 'Foil.assertDistinct) r) ])
                    [ Match
                        (TupP [ConP 'Foil.Distinct [] [], ConP 'Foil.Distinct [] []])
                        (NormalB next)
                        [] ]
                , eqTypes'
                )
          where
            isNameBinder = tyName == nameT
            l = VarE (mkName ("l" ++ show i))
            r = VarE (mkName ("r" ++ show i))
        go i eqTypes pairsRev ((_bang, type_) : conPatterns) =
          case go (i + 1) eqTypes pairsRev conPatterns of
            (next, eqTypes') ->
              ( CondE
                  (InfixE (Just l) (VarE 'Foil.unifyInPattern) (Just r))
                  next
                  (ConE 'Foil.NotUnifiable)
              , type_ : eqTypes'
              )
          where
            l = VarE (mkName ("l" ++ show i))
            r = VarE (mkName ("r" ++ show i))