11{-# LANGUAGE AllowAmbiguousTypes #-}
2+ {-# LANGUAGE CPP #-}
23{-# LANGUAGE DataKinds #-}
34{-# LANGUAGE DerivingVia #-}
45{-# LANGUAGE GADTs #-}
56{-# LANGUAGE LambdaCase #-}
67{-# LANGUAGE OverloadedStrings #-}
78{-# LANGUAGE PatternSynonyms #-}
9+ {-# LANGUAGE ScopedTypeVariables #-}
810{-# LANGUAGE ViewPatterns #-}
911
1012-- | Generate example CBOR given a CDDL specification
1113module Codec.CBOR.Cuddle.CBOR.Gen (generateCBORTerm , generateCBORTerm' ) where
1214
15+ import qualified Control.Monad.State.Strict as MTL
1316import Capability.Reader
1417import Capability.Sink (HasSink )
1518import Capability.Source (HasSource , MonadState (.. ))
16- import Capability.State (HasState , get , modify , state )
19+ import Capability.State (HasState , get , modify )
1720import Codec.CBOR.Cuddle.CDDL (
1821 Name (.. ),
1922 OccurrenceIndicator (.. ),
@@ -45,11 +48,9 @@ import Data.Word (Word32, Word64)
4548import GHC.Generics (Generic )
4649import System.Random.Stateful (
4750 Random ,
48- RandomGen (genShortByteString , genWord32 , genWord64 ),
49- RandomGenM ,
50- StatefulGen (.. ),
51+ RandomGen (.. ),
52+ StateGenM (.. ),
5153 UniformRange (uniformRM ),
52- applyRandomGenM ,
5354 randomM ,
5455 uniformByteStringM ,
5556 )
@@ -59,10 +60,8 @@ import System.Random.Stateful (
5960--------------------------------------------------------------------------------
6061
6162-- | Generator context, parametrised over the type of the random seed
62- data GenEnv g = GenEnv
63+ newtype GenEnv = GenEnv
6364 { cddl :: CTreeRoot' Identity MonoRef
64- , fakeSeed :: CapGenM g
65- -- ^ Access the "fake" seed, necessary to recursively call generators
6665 }
6766 deriving (Generic )
6867
@@ -77,63 +76,55 @@ data GenState g = GenState
7776 }
7877 deriving (Generic )
7978
80- newtype M g a = M { runM :: StateT (GenState g ) (Reader (GenEnv g )) a }
81- deriving (Functor , Applicative , Monad )
79+ instance RandomGen g => RandomGen (GenState g ) where
80+ genWord8 = withRandomSeed genWord8
81+ genWord16 = withRandomSeed genWord16
82+ genWord32 = withRandomSeed genWord32
83+ genWord64 = withRandomSeed genWord64
84+ split s =
85+ case split (randomSeed s) of
86+ (gen', gen) -> (s {randomSeed = gen'}, s {randomSeed = gen})
87+
88+ withRandomSeed :: (t -> (a , g )) -> GenState t -> (a , GenState g )
89+ withRandomSeed f s =
90+ case f (randomSeed s) of
91+ (r, gen) -> (r, s {randomSeed = gen})
92+
93+ newtype M g a = M { runM :: StateT (GenState g ) (Reader GenEnv ) a }
94+ deriving (Functor , Applicative , Monad , MTL.MonadState (GenState g))
8295 deriving
8396 (HasSource " randomSeed" g , HasSink " randomSeed" g , HasState " randomSeed" g )
8497 via Field
8598 " randomSeed"
8699 ()
87- (MonadState (StateT (GenState g ) (Reader ( GenEnv g ) )))
100+ (MonadState (StateT (GenState g ) (Reader GenEnv )))
88101 deriving
89102 (HasSource " depth" Int , HasSink " depth" Int , HasState " depth" Int )
90103 via Field
91104 " depth"
92105 ()
93- (MonadState (StateT (GenState g ) (Reader ( GenEnv g ) )))
106+ (MonadState (StateT (GenState g ) (Reader GenEnv )))
94107 deriving
95108 ( HasSource " cddl" (CTreeRoot' Identity MonoRef )
96109 , HasReader " cddl" (CTreeRoot' Identity MonoRef )
97110 )
98111 via Field
99112 " cddl"
100113 ()
101- (Lift (StateT (GenState g ) (MonadReader (Reader (GenEnv g )))))
102- deriving
103- (HasSource " fakeSeed" (CapGenM g ), HasReader " fakeSeed" (CapGenM g ))
104- via Field
105- " fakeSeed"
106- ()
107- (Lift (StateT (GenState g ) (MonadReader (Reader (GenEnv g )))))
108-
109- -- | Opaque type carrying the type of a pure PRNG inside a capability-style
110- -- state monad.
111- data CapGenM g = CapGenM
114+ (Lift (StateT (GenState g ) (MonadReader (Reader GenEnv ))))
112115
113- instance RandomGen g => StatefulGen (CapGenM g ) (M g ) where
114- uniformWord64 _ = state @ " randomSeed" genWord64
115- uniformWord32 _ = state @ " randomSeed" genWord32
116-
117- uniformShortByteString n _ = state @ " randomSeed" (genShortByteString n)
118-
119- instance RandomGen r => RandomGenM (CapGenM r ) r (M r ) where
120- applyRandomGenM f _ = state @ " randomSeed" f
121-
122- runGen :: M g a -> GenEnv g -> GenState g -> (a , GenState g )
116+ runGen :: M g a -> GenEnv -> GenState g -> (a , GenState g )
123117runGen m env st = runReader (runStateT (runM m) st) env
124118
125- evalGen :: M g a -> GenEnv g -> GenState g -> a
119+ evalGen :: M g a -> GenEnv -> GenState g -> a
126120evalGen m env = fst . runGen m env
127121
128- asksM :: forall tag r m a . HasReader tag r m => (r -> m a ) -> m a
129- asksM f = f =<< ask @ tag
130-
131122--------------------------------------------------------------------------------
132123-- Wrappers around some Random function in Gen
133124--------------------------------------------------------------------------------
134125
135126genUniformRM :: forall a g . (UniformRange a , RandomGen g ) => (a , a ) -> M g a
136- genUniformRM = asksM @ " fakeSeed " . uniformRM
127+ genUniformRM r = uniformRM r ( StateGenM @ ( GenState g ))
137128
138129-- | Generate a random number in a given range, biased increasingly towards the
139130-- lower end as the depth parameter increases.
@@ -143,9 +134,8 @@ genDepthBiasedRM ::
143134 (a , a ) ->
144135 M g a
145136genDepthBiasedRM bounds = do
146- fs <- ask @ " fakeSeed"
147137 d <- get @ " depth"
148- samples <- replicateM d (uniformRM bounds fs )
138+ samples <- replicateM d (genUniformRM bounds)
149139 pure $ minimum samples
150140
151141-- | Generates a bool, increasingly likely to be 'False' as the depth increases.
@@ -155,10 +145,10 @@ genDepthBiasedBool = do
155145 and <$> replicateM d genRandomM
156146
157147genRandomM :: forall g a . (Random a , RandomGen g ) => M g a
158- genRandomM = asksM @ " fakeSeed " randomM
148+ genRandomM = randomM ( StateGenM @ ( GenState g ))
159149
160150genBytes :: forall g . RandomGen g => Int -> M g ByteString
161- genBytes n = asksM @ " fakeSeed " $ uniformByteStringM n
151+ genBytes n = uniformByteStringM n ( StateGenM @ ( GenState g ))
162152
163153genText :: forall g . RandomGen g => Int -> M g Text
164154genText n = pure $ T. pack . take n . join $ repeat [' a' .. ' z' ]
@@ -436,12 +426,12 @@ genValueVariant (VBool b) = pure $ TBool b
436426
437427generateCBORTerm :: RandomGen g => CTreeRoot' Identity MonoRef -> Name -> g -> Term
438428generateCBORTerm cddl n stdGen =
439- let genEnv = GenEnv {cddl, fakeSeed = CapGenM }
429+ let genEnv = GenEnv {cddl}
440430 genState = GenState {randomSeed = stdGen, depth = 1 }
441431 in evalGen (genForName n) genEnv genState
442432
443433generateCBORTerm' :: RandomGen g => CTreeRoot' Identity MonoRef -> Name -> g -> (Term , g )
444434generateCBORTerm' cddl n stdGen =
445- let genEnv = GenEnv {cddl, fakeSeed = CapGenM }
435+ let genEnv = GenEnv {cddl}
446436 genState = GenState {randomSeed = stdGen, depth = 1 }
447437 in second randomSeed $ runGen (genForName n) genEnv genState
0 commit comments