From 94008907d4d7830250cf58cc3bf9c82c1021f83f Mon Sep 17 00:00:00 2001 From: Pierre Tholoniat Date: Thu, 22 Jan 2026 09:44:50 -0800 Subject: [PATCH] Use interior mutability for the Client's PRNG; add constructor. PiperOrigin-RevId: 859651327 --- willow/benches/shell_benchmarks.rs | 8 +++---- willow/src/api/client.rs | 4 +--- willow/src/traits/client.rs | 2 +- willow/src/willow_v1/client.rs | 36 ++++++++++++++++++------------ willow/src/willow_v1/server.rs | 4 +--- willow/src/willow_v1/verifier.rs | 4 +--- willow/tests/willow_v1_shell.rs | 28 ++++++----------------- 7 files changed, 36 insertions(+), 50 deletions(-) diff --git a/willow/benches/shell_benchmarks.rs b/willow/benches/shell_benchmarks.rs index 1933297..e2da727 100644 --- a/willow/benches/shell_benchmarks.rs +++ b/willow/benches/shell_benchmarks.rs @@ -131,9 +131,7 @@ fn setup_base(args: &Args) -> BaseInputs { // Create client. let kahe = ShellKahe::new(kahe_config.clone(), CONTEXT_STRING).unwrap(); let vahe = ShellVahe::new(ahe_config.clone(), CONTEXT_STRING).unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); - let client = WillowV1Client { kahe, vahe, prng }; + let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe).unwrap(); // Create decryptor. let vahe = ShellVahe::new(ahe_config.clone(), CONTEXT_STRING).unwrap(); @@ -218,7 +216,7 @@ struct VerifierInputs { } fn setup_verifier_verify_client_message(args: &Args) -> VerifierInputs { - let mut inputs = setup_base(args); + let inputs = setup_base(args); let mut decryption_request_contributions = vec![]; for _ in 0..args.n_iterations { // Generates a plaintext and encrypts. @@ -257,7 +255,7 @@ fn run_verifier_verify_client_message(inputs: &mut VerifierInputs) { } fn setup_server_handle_client_message(args: &Args) -> ServerInputs { - let mut inputs = setup_base(args); + let inputs = setup_base(args); let mut ciphertext_contributions = vec![]; for _ in 0..args.n_iterations { // Generates a plaintext and encrypts. diff --git a/willow/src/api/client.rs b/willow/src/api/client.rs index 5c0835d..5667921 100644 --- a/willow/src/api/client.rs +++ b/willow/src/api/client.rs @@ -86,9 +86,7 @@ impl WillowShellClient { let context_bytes = aggregation_config.compute_context_bytes()?; let kahe = ShellKahe::new(kahe_config, &context_bytes)?; let vahe = ShellVahe::new(ahe_config, &context_bytes)?; - let client_seed = SingleThreadHkdfPrng::generate_seed()?; - let prng = SingleThreadHkdfPrng::create(&client_seed)?; - let client = WillowV1Client { kahe, vahe, prng }; + let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe)?; Ok(WillowShellClient(client)) } diff --git a/willow/src/traits/client.rs b/willow/src/traits/client.rs index e4fca66..28476db 100644 --- a/willow/src/traits/client.rs +++ b/willow/src/traits/client.rs @@ -26,7 +26,7 @@ pub trait SecureAggregationClient: HasKahe + HasVahe { /// Creates a client message to be sent to the Server. /// nonce is used for the VAHE encryption, has to be unique. fn create_client_message( - &mut self, + &self, plaintext: &Self::PlaintextSlice<'_>, signed_public_key: &DecryptorPublicKey<::Vahe>, nonce: &[u8], diff --git a/willow/src/willow_v1/client.rs b/willow/src/willow_v1/client.rs index baf5fb8..0bdecc5 100644 --- a/willow/src/willow_v1/client.rs +++ b/willow/src/willow_v1/client.rs @@ -15,13 +15,15 @@ use client_traits::SecureAggregationClient; use kahe_traits::{HasKahe, KaheBase, KaheEncrypt, KaheKeygen, TrySecretKeyInto}; use messages::{ClientMessage, DecryptorPublicKey}; +use prng_traits::SecurePrng; +use std::cell::RefCell; use vahe_traits::{HasVahe, VaheBase, VerifiableEncrypt}; /// Lightweight client directly exposing KAHE/VAHE types. pub struct WillowV1Client { pub kahe: Kahe, pub vahe: Vahe, - pub prng: Kahe::Rng, // Using a single PRNG for both VAHE and KAHE. + pub prng: RefCell, // Using a single PRNG for both VAHE and KAHE. } impl HasKahe for WillowV1Client { @@ -38,6 +40,17 @@ impl HasVahe for WillowV1Client { } } +impl WillowV1Client { + pub fn new_with_randomly_generated_seed( + kahe: Kahe, + vahe: Vahe, + ) -> Result { + let seed = Kahe::Rng::generate_seed()?; + let prng = RefCell::new(Kahe::Rng::create(&seed)?); + Ok(Self { kahe, vahe, prng }) + } +} + /// Implementation of the `SecureAggregationClient` trait for the generic /// KAHE/VAHE client, using WillowCommon as the common types (e.g. protocol /// messages are directly the AHE public key and ciphertexts). @@ -51,16 +64,17 @@ where type PlaintextSlice<'a> = ::PlaintextSlice<'a>; fn create_client_message( - &mut self, + &self, plaintext: &Self::PlaintextSlice<'_>, signed_public_key: &DecryptorPublicKey, nonce: &[u8], ) -> Result, status::StatusError> { // Generate a new KAHE key. - let kahe_secret_key = self.kahe.key_gen(&mut self.prng)?; + let kahe_secret_key = self.kahe.key_gen(&mut self.prng.borrow_mut())?; // Encrypt long plaintext with KAHE. - let kahe_ciphertext = self.kahe.encrypt(plaintext, &kahe_secret_key, &mut self.prng)?; + let kahe_ciphertext = + self.kahe.encrypt(plaintext, &kahe_secret_key, &mut self.prng.borrow_mut())?; // Convert KAHE secret key into short AHE plaintext. let ahe_plaintext: Vahe::Plaintext = self.kahe.try_secret_key_into(kahe_secret_key)?; @@ -70,7 +84,7 @@ where &ahe_plaintext, signed_public_key, nonce, - &mut self.prng, + &mut self.prng.borrow_mut(), )?; // Keep a copy of the nonce so the message can be forwarded as-is. @@ -112,9 +126,7 @@ mod test { let (kahe_config, ahe_config) = create_shell_configs(&aggregation_config)?; let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?; let vahe = ShellVahe::new(ahe_config, CONTEXT_STRING)?; - let client_seed = SingleThreadHkdfPrng::generate_seed()?; - let prng = SingleThreadHkdfPrng::create(&client_seed)?; - let mut client = WillowV1Client { kahe, vahe, prng }; + let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe)?; // Generate AHE keys. let mut testing_decryptor = @@ -153,17 +165,13 @@ mod test { let (kahe_config, ahe_config) = create_shell_configs(&aggregation_config)?; let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?; let vahe = ShellVahe::new(ahe_config, CONTEXT_STRING)?; - let client1_seed = SingleThreadHkdfPrng::generate_seed()?; - let prng = SingleThreadHkdfPrng::create(&client1_seed)?; - let mut client1 = WillowV1Client { kahe, vahe, prng }; + let client1 = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe)?; // Create a second client. let (kahe_config, ahe_config) = create_shell_configs(&aggregation_config)?; let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?; let vahe = ShellVahe::new(ahe_config, CONTEXT_STRING)?; - let client2_seed = SingleThreadHkdfPrng::generate_seed()?; - let prng = SingleThreadHkdfPrng::create(&client2_seed)?; - let mut client2 = WillowV1Client { kahe, vahe, prng }; + let client2 = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe)?; // Generate AHE keys. let mut testing_decryptor = diff --git a/willow/src/willow_v1/server.rs b/willow/src/willow_v1/server.rs index ff3ec8b..a9d07b4 100644 --- a/willow/src/willow_v1/server.rs +++ b/willow/src/willow_v1/server.rs @@ -392,9 +392,7 @@ mod tests { CONTEXT_STRING, ) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed()?; - let prng = SingleThreadHkdfPrng::create(&seed)?; - let mut client = WillowV1Client { kahe, vahe, prng }; + let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe)?; // Create decryptor. let vahe = ShellVahe::new( diff --git a/willow/src/willow_v1/verifier.rs b/willow/src/willow_v1/verifier.rs index f9006fa..aaecf34 100644 --- a/willow/src/willow_v1/verifier.rs +++ b/willow/src/willow_v1/verifier.rs @@ -305,9 +305,7 @@ mod tests { CONTEXT_STRING, ) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed()?; - let prng = SingleThreadHkdfPrng::create(&seed)?; - let mut client = WillowV1Client { kahe, vahe, prng }; + let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe)?; // Create decryptor, which needs its own `vahe` (with same public polynomials // generated from the seeds) and `prng`. diff --git a/willow/tests/willow_v1_shell.rs b/willow/tests/willow_v1_shell.rs index 9e497fd..c0e2d6d 100644 --- a/willow/tests/willow_v1_shell.rs +++ b/willow/tests/willow_v1_shell.rs @@ -57,9 +57,7 @@ fn encrypt_decrypt_one() -> googletest::Result<()> { let kahe = ShellKahe::new(create_shell_kahe_config(&aggregation_config).unwrap(), CONTEXT_STRING) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); - let mut client = WillowV1Client { kahe, vahe, prng }; + let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe).unwrap(); // Create decryptor, which needs its own `vahe` (with same public polynomials // generated from the seeds) and `prng`. @@ -147,9 +145,7 @@ fn encrypt_decrypt_one_serialized() -> googletest::Result<()> { let vahe = ShellVahe::new(create_shell_ahe_config(max_number_of_decryptors).unwrap(), CONTEXT_STRING) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); - let mut client = WillowV1Client { kahe, vahe, prng }; + let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe).unwrap(); // Create decryptor, which needs its own `vahe` (with same public polynomials // generated from the seeds) and `prng`. @@ -287,9 +283,7 @@ fn encrypt_decrypt_multiple_clients() -> googletest::Result<()> { let kahe = ShellKahe::new(create_shell_kahe_config(&aggregation_config).unwrap(), CONTEXT_STRING) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); - let client = WillowV1Client { kahe, vahe, prng }; + let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe).unwrap(); clients.push(client); } @@ -420,9 +414,7 @@ fn encrypt_decrypt_multiple_clients_including_invalid_proofs() -> googletest::Re let kahe = ShellKahe::new(create_shell_kahe_config(&aggregation_config).unwrap(), CONTEXT_STRING) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); - let client = WillowV1Client { kahe, vahe, prng }; + let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe).unwrap(); good_clients.push(client); } @@ -437,9 +429,7 @@ fn encrypt_decrypt_multiple_clients_including_invalid_proofs() -> googletest::Re let kahe = ShellKahe::new(create_shell_kahe_config(&aggregation_config).unwrap(), CONTEXT_STRING) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); - let client = WillowV1Client { kahe, vahe, prng }; + let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe).unwrap(); bad_clients.push(client); } @@ -652,9 +642,7 @@ fn encrypt_decrypt_many_clients_decryptors() -> googletest::Result<()> { let kahe = ShellKahe::new(create_shell_kahe_config(&aggregation_config).unwrap(), CONTEXT_STRING) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); - let mut client = WillowV1Client { kahe, vahe, prng }; + let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe).unwrap(); let client_input_values = generate_random_unsigned_vector(INPUT_LENGTH as usize, INPUT_DOMAIN as u64); @@ -730,9 +718,7 @@ fn encrypt_decrypt_no_dropout() -> googletest::Result<()> { let kahe = ShellKahe::new(create_shell_kahe_config(&aggregation_config).unwrap(), CONTEXT_STRING) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); - let client = WillowV1Client { kahe, vahe, prng }; + let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe).unwrap(); clients.push(client); }