11{-# LANGUAGE AllowAmbiguousTypes #-}
2+ {-# LANGUAGE CPP #-}
23{-# LANGUAGE DataKinds #-}
34{-# LANGUAGE DerivingVia #-}
45{-# LANGUAGE GADTs #-}
56{-# LANGUAGE LambdaCase #-}
67{-# LANGUAGE OverloadedStrings #-}
78{-# LANGUAGE PatternSynonyms #-}
9+ {-# LANGUAGE ScopedTypeVariables #-}
810{-# LANGUAGE ViewPatterns #-}
911
12+ #if MIN_VERSION_random(1,3,0)
13+ {-# OPTIONS_GHC -Wno-deprecations #-} -- Due to usage of `split`
14+ #endif
1015-- | Generate example CBOR given a CDDL specification
1116module Codec.CBOR.Cuddle.CBOR.Gen (generateCBORTerm , generateCBORTerm' ) where
1217
1318import Capability.Reader
1419import Capability.Sink (HasSink )
1520import Capability.Source (HasSource , MonadState (.. ))
16- import Capability.State (HasState , get , modify , state )
21+ import Capability.State (HasState , get , modify )
1722import Codec.CBOR.Cuddle.CDDL (
1823 Name (.. ),
1924 OccurrenceIndicator (.. ),
@@ -31,6 +36,7 @@ import Codec.CBOR.Write qualified as CBOR
3136import Control.Monad (join , replicateM , (<=<) )
3237import Control.Monad.Reader (Reader , runReader )
3338import Control.Monad.State.Strict (StateT , runStateT )
39+ import Control.Monad.State.Strict qualified as MTL
3440import Data.Bifunctor (second )
3541import Data.ByteString (ByteString )
3642import Data.ByteString.Base16 qualified as Base16
@@ -45,24 +51,25 @@ import Data.Word (Word32, Word64)
4551import GHC.Generics (Generic )
4652import System.Random.Stateful (
4753 Random ,
48- RandomGen (genShortByteString , genWord32 , genWord64 ),
49- RandomGenM ,
50- StatefulGen (.. ),
54+ RandomGen (.. ),
55+ StateGenM (.. ),
5156 UniformRange (uniformRM ),
52- applyRandomGenM ,
5357 randomM ,
5458 uniformByteStringM ,
5559 )
60+ #if MIN_VERSION_random(1,3,0)
61+ import System.Random.Stateful (
62+ SplitGen (.. )
63+ )
64+ #endif
5665
5766--------------------------------------------------------------------------------
5867-- Generator infrastructure
5968--------------------------------------------------------------------------------
6069
6170-- | Generator context, parametrised over the type of the random seed
62- data GenEnv g = GenEnv
71+ newtype GenEnv = GenEnv
6372 { cddl :: CTreeRoot' Identity MonoRef
64- , fakeSeed :: CapGenM g
65- -- ^ Access the "fake" seed, necessary to recursively call generators
6673 }
6774 deriving (Generic )
6875
@@ -77,63 +84,63 @@ data GenState g = GenState
7784 }
7885 deriving (Generic )
7986
80- newtype M g a = M { runM :: StateT (GenState g ) (Reader (GenEnv g )) a }
81- deriving (Functor , Applicative , Monad )
87+ instance RandomGen g => RandomGen (GenState g ) where
88+ genWord8 = withRandomSeed genWord8
89+ genWord16 = withRandomSeed genWord16
90+ genWord32 = withRandomSeed genWord32
91+ genWord64 = withRandomSeed genWord64
92+ split = splitGenStateWith split
93+
94+ #if MIN_VERSION_random(1,3,0)
95+ instance SplitGen g => SplitGen (GenState g ) where
96+ splitGen = splitGenStateWith splitGen
97+ #endif
98+
99+ splitGenStateWith :: (g -> (g , g )) -> GenState g -> (GenState g , GenState g )
100+ splitGenStateWith f s =
101+ case f (randomSeed s) of
102+ (gen', gen) -> (s {randomSeed = gen'}, s {randomSeed = gen})
103+
104+ withRandomSeed :: (t -> (a , g )) -> GenState t -> (a , GenState g )
105+ withRandomSeed f s =
106+ case f (randomSeed s) of
107+ (r, gen) -> (r, s {randomSeed = gen})
108+
109+ newtype M g a = M { runM :: StateT (GenState g ) (Reader GenEnv ) a }
110+ deriving (Functor , Applicative , Monad , MTL.MonadState (GenState g))
82111 deriving
83112 (HasSource " randomSeed" g , HasSink " randomSeed" g , HasState " randomSeed" g )
84113 via Field
85114 " randomSeed"
86115 ()
87- (MonadState (StateT (GenState g ) (Reader ( GenEnv g ) )))
116+ (MonadState (StateT (GenState g ) (Reader GenEnv )))
88117 deriving
89118 (HasSource " depth" Int , HasSink " depth" Int , HasState " depth" Int )
90119 via Field
91120 " depth"
92121 ()
93- (MonadState (StateT (GenState g ) (Reader ( GenEnv g ) )))
122+ (MonadState (StateT (GenState g ) (Reader GenEnv )))
94123 deriving
95124 ( HasSource " cddl" (CTreeRoot' Identity MonoRef )
96125 , HasReader " cddl" (CTreeRoot' Identity MonoRef )
97126 )
98127 via Field
99128 " cddl"
100129 ()
101- (Lift (StateT (GenState g ) (MonadReader (Reader (GenEnv g )))))
102- deriving
103- (HasSource " fakeSeed" (CapGenM g ), HasReader " fakeSeed" (CapGenM g ))
104- via Field
105- " fakeSeed"
106- ()
107- (Lift (StateT (GenState g ) (MonadReader (Reader (GenEnv g )))))
108-
109- -- | Opaque type carrying the type of a pure PRNG inside a capability-style
110- -- state monad.
111- data CapGenM g = CapGenM
130+ (Lift (StateT (GenState g ) (MonadReader (Reader GenEnv ))))
112131
113- instance RandomGen g => StatefulGen (CapGenM g ) (M g ) where
114- uniformWord64 _ = state @ " randomSeed" genWord64
115- uniformWord32 _ = state @ " randomSeed" genWord32
116-
117- uniformShortByteString n _ = state @ " randomSeed" (genShortByteString n)
118-
119- instance RandomGen r => RandomGenM (CapGenM r ) r (M r ) where
120- applyRandomGenM f _ = state @ " randomSeed" f
121-
122- runGen :: M g a -> GenEnv g -> GenState g -> (a , GenState g )
132+ runGen :: M g a -> GenEnv -> GenState g -> (a , GenState g )
123133runGen m env st = runReader (runStateT (runM m) st) env
124134
125- evalGen :: M g a -> GenEnv g -> GenState g -> a
135+ evalGen :: M g a -> GenEnv -> GenState g -> a
126136evalGen m env = fst . runGen m env
127137
128- asksM :: forall tag r m a . HasReader tag r m => (r -> m a ) -> m a
129- asksM f = f =<< ask @ tag
130-
131138--------------------------------------------------------------------------------
132139-- Wrappers around some Random function in Gen
133140--------------------------------------------------------------------------------
134141
135142genUniformRM :: forall a g . (UniformRange a , RandomGen g ) => (a , a ) -> M g a
136- genUniformRM = asksM @ " fakeSeed " . uniformRM
143+ genUniformRM r = uniformRM r ( StateGenM @ ( GenState g ))
137144
138145-- | Generate a random number in a given range, biased increasingly towards the
139146-- lower end as the depth parameter increases.
@@ -143,9 +150,8 @@ genDepthBiasedRM ::
143150 (a , a ) ->
144151 M g a
145152genDepthBiasedRM bounds = do
146- fs <- ask @ " fakeSeed"
147153 d <- get @ " depth"
148- samples <- replicateM d (uniformRM bounds fs )
154+ samples <- replicateM d (genUniformRM bounds)
149155 pure $ minimum samples
150156
151157-- | Generates a bool, increasingly likely to be 'False' as the depth increases.
@@ -155,10 +161,10 @@ genDepthBiasedBool = do
155161 and <$> replicateM d genRandomM
156162
157163genRandomM :: forall g a . (Random a , RandomGen g ) => M g a
158- genRandomM = asksM @ " fakeSeed " randomM
164+ genRandomM = randomM ( StateGenM @ ( GenState g ))
159165
160166genBytes :: forall g . RandomGen g => Int -> M g ByteString
161- genBytes n = asksM @ " fakeSeed " $ uniformByteStringM n
167+ genBytes n = uniformByteStringM n ( StateGenM @ ( GenState g ))
162168
163169genText :: forall g . RandomGen g => Int -> M g Text
164170genText n = pure $ T. pack . take n . join $ repeat [' a' .. ' z' ]
@@ -436,12 +442,12 @@ genValueVariant (VBool b) = pure $ TBool b
436442
437443generateCBORTerm :: RandomGen g => CTreeRoot' Identity MonoRef -> Name -> g -> Term
438444generateCBORTerm cddl n stdGen =
439- let genEnv = GenEnv {cddl, fakeSeed = CapGenM }
445+ let genEnv = GenEnv {cddl}
440446 genState = GenState {randomSeed = stdGen, depth = 1 }
441447 in evalGen (genForName n) genEnv genState
442448
443449generateCBORTerm' :: RandomGen g => CTreeRoot' Identity MonoRef -> Name -> g -> (Term , g )
444450generateCBORTerm' cddl n stdGen =
445- let genEnv = GenEnv {cddl, fakeSeed = CapGenM }
451+ let genEnv = GenEnv {cddl}
446452 genState = GenState {randomSeed = stdGen, depth = 1 }
447453 in second randomSeed $ runGen (genForName n) genEnv genState
0 commit comments