module Main where

import Data.Iteratee as I
import Criterion.Main
import Control.Monad.Identity


-- | Feed a pure list of 'Int's through an enumeratee into an iteratee and
-- return the iteratee's final result.
--
-- The list is delivered in 5-element chunks via 'enumPureNChunk'. The
-- enumeratee is attached to the iteratee explicitly with 'joinI'.
runner
  :: Enumeratee [Int] xs Identity a
  -> [Int]
  -> Iteratee xs Identity a
  -> a
runner etee xs iter = runIdentity (fed >>= I.run)
  where
    -- The enumeratee wrapped around the iteratee, flattened back to the
    -- outer stream type.
    nested = joinI (etee iter)
    -- Push the whole input list into the nested iteratee.
    fed = enumPureNChunk xs 5 nested

-- | Like 'runner', but attaches the enumeratee to the iteratee with '=$'
-- instead of @joinI . etee@. This exercises fusion of enumeratee/iteratee
-- composition.
runner2
  :: Enumeratee [Int] xs Identity a
  -> [Int]
  -> Iteratee xs Identity a
  -> a
runner2 etee xs iter =
  runIdentity (I.run =<< enumPureNChunk xs 5 (etee =$ iter))

-- | Like 'runner', but first composes the enumerator with the enumeratee
-- using '$=', then applies the result to the iteratee. This exercises fusion
-- of enumerator/enumeratee composition.
runner3
  :: Enumeratee [Int] xs Identity a
  -> [Int]
  -> Iteratee xs Identity a
  -> a
runner3 etee xs iter = runIdentity $ do
  finished <- (enumPureNChunk xs 5 $= etee) iter
  I.run finished

-- | Two 'mapChunks' stages composed with '><>': one applies 'id' to each
-- chunk, the other applies @map (+1)@ to each chunk.
--
-- The INLINE pragma presumably keeps this composition visible at use sites so
-- fusion can happen there. NOTE(review): confirm against the library's
-- rewrite rules before changing the expression's shape.
m2 :: Enumeratee [Int] [Int] Identity a
m2 = mapChunks id ><> mapChunks (map (+1))
{-# INLINE m2 #-}

-- | Three stages composed with '><>': @mapChunks id@, @mapChunks (map (+1))@
-- and @I.filter even@.
--
-- This mixes the 'mapChunks' and 'filter' enumeratees for the
-- "mapChunks/filter fusion" benchmark. The elements are presumably
-- incremented before the evenness test; NOTE(review): confirm the data-flow
-- direction of '><>' in the iteratee docs.
m3 :: Enumeratee [Int] [Int] Identity a
m3 = mapChunks id ><> mapChunks (map (+1)) ><> I.filter (even)
{-# INLINE m3 #-}

-- | Two copies of 'm2' composed with '><>', giving four 'mapChunks' stages.
--
-- NOTE(review): unlike 'm2' and 'm3', this has no INLINE pragma. Confirm
-- whether that is deliberate for the "nested" benchmark.
m4 :: Enumeratee [Int] [Int] Identity a
m4 = m2 ><> m2

-- | 'm3', 'm2', 'm3' and 'm2' composed with '><>'. That is ten primitive
-- stages in total (3 + 2 + 3 + 2), hence the name.
--
-- NOTE(review): this has no INLINE pragma, unlike 'm2' and 'm3'. Confirm
-- whether that is intended.
m10 :: Enumeratee [Int] [Int] Identity a
m10 = m3 ><> m2 ><> m3 ><> m2

-- | An enumeratee built with 'convStream' whose converter is just 'getChunk'.
--
-- It presumably passes each chunk through unchanged, which makes it a
-- 'convStream'-based counterpart to @mapChunks id@. NOTE(review): confirm
-- with the 'convStream' documentation.
m' :: Enumeratee [a] [a] Identity x
m' = convStream (getChunk)
{-# INLINE m' #-}

-- | Like 'm'', but each chunk read by 'getChunk' has @map (+1)@ applied to it
-- before being passed on.
--
-- NOTE(review): this is not referenced by any benchmark in this module.
-- Either add a benchmark for it or remove it.
m'2 :: Num a => Enumeratee [a] [a] Identity x
m'2 = convStream (liftM (map (+1)) getChunk)
{-# INLINE m'2 #-}

-- | Run the integers 1..100 through 'm2' into the supplied iteratee.
fusedMap :: Iteratee [Int] Identity a -> a
fusedMap = runner m2 input
  where
    input = enumFromTo 1 100

-- | Run the integers 1..100 through 'm2' followed by one 'm'' stage into the
-- supplied iteratee.
fusedMap' :: Iteratee [Int] Identity a -> a
fusedMap' = runner (m2 ><> m') input
  where
    input = enumFromTo 1 100

-- | Run the integers 1..100 through 'm2' followed by two 'm'' stages into the
-- supplied iteratee.
fusedMap'2 :: Iteratee [Int] Identity a -> a
fusedMap'2 = runner (m2 ><> m' ><> m') input
  where
    input = enumFromTo 1 100

-- | Run the integers 1..100 through 'm3' (map stages plus an evenness
-- filter) into the supplied iteratee.
fusedMap3 :: Iteratee [Int] Identity a -> a
fusedMap3 = runner m3 input
  where
    input = enumFromTo 1 100

-- | Run the integers 1..100 through 'm4' (two nested copies of 'm2') into
-- the supplied iteratee.
fusedMap4 :: Iteratee [Int] Identity a -> a
fusedMap4 = runner m4 input
  where
    input = enumFromTo 1 100

-- | Run the integers 1..100 through the ten-stage 'm10' pipeline into the
-- supplied iteratee.
fusedMap10 :: Iteratee [Int] Identity a -> a
fusedMap10 = runner m10 input
  where
    input = enumFromTo 1 100

-- | One Criterion benchmark per fused pipeline. Each benchmark sums the
-- pipeline's output with 'I.sum' and evaluates the result to weak head
-- normal form.
fusionBenches :: [Benchmark]
fusionBenches = map summing
  [ ("mapChunks/mapChunks fusion",        fusedMap)
  , ("mapChunks/filter fusion",           fusedMap3)
  , ("nested mapChunks/mapChunks fusion", fusedMap4)
  , ("highly nested fusion",              fusedMap10)
  , ("mapChunks/mapChunks/convStream",    fusedMap')
  , ("mapChunks/mapChunks/convStream2",   fusedMap'2)
  ]
  where
    -- Build the benchmark for a single (label, pipeline) pair.
    summing :: (String, Iteratee [Int] Identity Int -> Int) -> Benchmark
    summing (label, pipeline) = bench label (whnf pipeline I.sum)

-- | Print the summed result of every benchmarked pipeline once, then run the
-- Criterion benchmarks.
--
-- The printout covers all six pipelines in 'fusionBenches'. Previously it
-- skipped 'fusedMap'' and 'fusedMap'2'. Labels are written with 'putStrLn'
-- so they appear without the quotes that 'print' adds.
main :: IO ()
main = do
    sanity "fusedMap"        (fusedMap I.sum)
    sanity "fusedMap/filter" (fusedMap3 I.sum)
    sanity "fusedMap4"       (fusedMap4 I.sum)
    sanity "fusedMap10"      (fusedMap10 I.sum)
    sanity "fusedMap'"       (fusedMap' I.sum)
    sanity "fusedMap'2"      (fusedMap'2 I.sum)

    defaultMain fusionBenches
  where
    -- Print a label line followed by the pipeline's result, so each value
    -- can be matched to its pipeline.
    sanity :: String -> Int -> IO ()
    sanity label result = do
      putStrLn label
      print result
