diff --git a/src/Language/SystemF/Expression.hs b/src/Language/SystemF/Expression.hs index 0bfa1675d821268730be95772b481e69767106ef..940f22eecfa14525c11765e7a71bea354d5fd02f 100644 --- a/src/Language/SystemF/Expression.hs +++ b/src/Language/SystemF/Expression.hs @@ -18,6 +18,7 @@ data SystemFExpr name ty data Ty name = TyVar name -- Type variable (T) | TyArrow (Ty name) (Ty name) -- Type arrow (T -> U) + | TyForAll name (Ty name) -- Universal type (forall T. X) deriving (Eq, Show) -- Pretty printing @@ -93,6 +94,7 @@ pprTy :: PrettyPrint n -> PDoc String pprTy pdoc space (TyVar n) = prettyPrint n `add` pdoc pprTy pdoc space (TyArrow a b) = pprTyArrow pdoc space a b +pprTy pdoc _ (TyForAll n t) = pprTyForAll pdoc n t pprTyArrow :: PrettyPrint n => PDoc String @@ -113,6 +115,14 @@ pprTyArrow' space a b = a <> arrow <> b where arrow | space = " -> " `add` empty | otherwise = "->" `add` empty +pprTyForAll :: PrettyPrint n + => PDoc String + -> n + -> Ty n + -> PDoc String +pprTyForAll pdoc n t = prefix <> prettyPrint t `add` pdoc + where prefix = between (prettyPrint n `add` empty) "forall " ". " empty + -- Pretty print a type abstraction pprTAbs :: (PrettyPrint n, PrettyPrint t) => PDoc String diff --git a/test/Language/SystemF/ExpressionSpec.hs b/test/Language/SystemF/ExpressionSpec.hs index 21e23d49e5638dceb82c82075b8fc0cdb7241c3f..24726a736c925983f4d32966daedcd355ef8125d 100644 --- a/test/Language/SystemF/ExpressionSpec.hs +++ b/test/Language/SystemF/ExpressionSpec.hs @@ -58,6 +58,9 @@ spec = describe "prettyPrint" $ do it "print simple arrow types" $ prettyPrint (TyArrow (TyVar "A") (TyVar "B")) `shouldBe` "A -> B" + it "prints simple forall types" $ + prettyPrint (TyForAll "X" (TyVar "X")) `shouldBe` "forall X. X" + it "prints chained arrow types" $ prettyPrint (TyArrow (TyVar "X") (TyArrow (TyVar "Y") (TyVar "Z"))) `shouldBe` "X -> Y -> Z" @@ -65,3 +68,8 @@ spec = describe "prettyPrint" $ do it "prints nested arrow types" $ prettyPrint (TyArrow (TyArrow (TyVar "T") (TyVar "U")) (TyVar "V")) `shouldBe` "(T -> U) -> V" + + it "prints complex forall types" $ + prettyPrint (TyForAll "A" (TyArrow (TyVar "A") (TyVar "A"))) + `shouldBe` "forall A. A -> A" +