From 39efbe66cd2b35947b5554e4467f0ae28faa6d72 Mon Sep 17 00:00:00 2001 From: Pierre Tholoniat Date: Tue, 20 Jan 2026 14:53:23 -0800 Subject: [PATCH] Use interior mutability for Decryptor PRNG too; add constructor; clean up imports. PiperOrigin-RevId: 858764839 --- shell_wrapper/status_macros.h | 5 ++ willow/benches/shell_benchmarks.rs | 14 ++--- willow/src/api/client.rs | 6 +- willow/src/shell/BUILD | 42 +++++++++++++ willow/src/shell/parameters_utils.cc | 48 +++++++++++++++ willow/src/shell/parameters_utils.h | 36 +++++++++++ willow/src/shell/parameters_utils.rs | 56 ++++++++++++++++- willow/src/shell/parameters_utils_test.cc | 61 +++++++++++++++++++ .../testing_utils/shell_testing_decryptor.rs | 24 +++++--- willow/src/traits/client.rs | 2 +- willow/src/traits/decryptor.rs | 4 +- willow/src/willow_v1/BUILD | 9 +-- willow/src/willow_v1/client.rs | 38 +++++++----- willow/src/willow_v1/decryptor.rs | 29 +++++---- willow/src/willow_v1/server.rs | 10 +-- willow/src/willow_v1/verifier.rs | 10 +-- willow/tests/willow_v1_shell.rs | 54 ++++------------ 17 files changed, 324 insertions(+), 124 deletions(-) create mode 100644 willow/src/shell/parameters_utils.cc create mode 100644 willow/src/shell/parameters_utils.h create mode 100644 willow/src/shell/parameters_utils_test.cc diff --git a/shell_wrapper/status_macros.h b/shell_wrapper/status_macros.h index a3ec64a..fb63d66 100644 --- a/shell_wrapper/status_macros.h +++ b/shell_wrapper/status_macros.h @@ -49,4 +49,9 @@ return status; \ } +// Internal helper to handle results from Rust FFI calls, which return a +// secure_aggregation::FfiStatus instead of a absl::Status. +#define SECAGG_RETURN_IF_FFI_ERROR(expr) \ + SECAGG_RETURN_IF_ERROR(secure_aggregation::UnwrapFfiStatus(expr)) + #endif // SECURE_AGGREGATION_SHELL_WRAPPER_STATUS_MACROS_H_ diff --git a/willow/benches/shell_benchmarks.rs b/willow/benches/shell_benchmarks.rs index 1933297..5a31783 100644 --- a/willow/benches/shell_benchmarks.rs +++ b/willow/benches/shell_benchmarks.rs @@ -28,9 +28,7 @@ use messages::{ PartialDecryptionRequest, }; use parameters_shell::{create_shell_ahe_config, create_shell_kahe_config}; -use prng_traits::SecurePrng; use server_traits::SecureAggregationServer; -use single_thread_hkdf::SingleThreadHkdfPrng; use testing_utils::{generate_random_nonce, generate_random_unsigned_vector}; use vahe_shell::ShellVahe; use verifier_traits::SecureAggregationVerifier; @@ -131,16 +129,12 @@ fn setup_base(args: &Args) -> BaseInputs { // Create client. let kahe = ShellKahe::new(kahe_config.clone(), CONTEXT_STRING).unwrap(); let vahe = ShellVahe::new(ahe_config.clone(), CONTEXT_STRING).unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); - let client = WillowV1Client { kahe, vahe, prng }; + let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe).unwrap(); // Create decryptor. let vahe = ShellVahe::new(ahe_config.clone(), CONTEXT_STRING).unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); let mut decryptor_state = DecryptorState::default(); - let mut decryptor = WillowV1Decryptor { vahe, prng }; + let decryptor = WillowV1Decryptor::new_with_randomly_generated_seed(vahe).unwrap(); // Create server. let kahe = ShellKahe::new(kahe_config.clone(), CONTEXT_STRING).unwrap(); @@ -218,7 +212,7 @@ struct VerifierInputs { } fn setup_verifier_verify_client_message(args: &Args) -> VerifierInputs { - let mut inputs = setup_base(args); + let inputs = setup_base(args); let mut decryption_request_contributions = vec![]; for _ in 0..args.n_iterations { // Generates a plaintext and encrypts. @@ -257,7 +251,7 @@ fn run_verifier_verify_client_message(inputs: &mut VerifierInputs) { } fn setup_server_handle_client_message(args: &Args) -> ServerInputs { - let mut inputs = setup_base(args); + let inputs = setup_base(args); let mut ciphertext_contributions = vec![]; for _ in 0..args.n_iterations { // Generates a plaintext and encrypts. diff --git a/willow/src/api/client.rs b/willow/src/api/client.rs index 5c0835d..1c9811c 100644 --- a/willow/src/api/client.rs +++ b/willow/src/api/client.rs @@ -22,11 +22,9 @@ use client_traits::SecureAggregationClient; use kahe_shell::ShellKahe; use kahe_traits::KaheBase; use parameters_shell::create_shell_configs; -use prng_traits::SecurePrng; use proto_serialization_traits::{FromProto, ToProto}; use protobuf::prelude::*; use shell_ciphertexts_rust_proto::ShellAhePublicKey; -use single_thread_hkdf::SingleThreadHkdfPrng; use status::ffi::FfiStatus; use status::StatusError; use std::collections::HashMap; @@ -86,9 +84,7 @@ impl WillowShellClient { let context_bytes = aggregation_config.compute_context_bytes()?; let kahe = ShellKahe::new(kahe_config, &context_bytes)?; let vahe = ShellVahe::new(ahe_config, &context_bytes)?; - let client_seed = SingleThreadHkdfPrng::generate_seed()?; - let prng = SingleThreadHkdfPrng::create(&client_seed)?; - let client = WillowV1Client { kahe, vahe, prng }; + let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe)?; Ok(WillowShellClient(client)) } diff --git a/willow/src/shell/BUILD b/willow/src/shell/BUILD index 8efd481..2806681 100644 --- a/willow/src/shell/BUILD +++ b/willow/src/shell/BUILD @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@cxx.rs//tools/bazel:rust_cxx_bridge.bzl", "rust_cxx_bridge") +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") load("@rules_rust//rust:defs.bzl", "rust_library", "rust_test") package( @@ -113,10 +115,50 @@ rust_library( crate_root = "parameters_utils.rs", deps = [ ":kahe_shell", + ":parameters_shell", "@protobuf//rust:protobuf", + "@cxx.rs//:cxx", "//shell_wrapper:kahe", "//shell_wrapper:status", "//willow/proto/shell:shell_parameters_rust_proto", + "//willow/proto/willow:aggregation_config_rust_proto", + "//willow/src/api:aggregation_config", + "//willow/src/traits:proto_serialization_traits", + ], +) + +rust_cxx_bridge( + name = "shell_parameters_utils_cxx", + src = "parameters_utils.rs", + deps = [ + ":shell_parameters_utils", + "//shell_wrapper:status_cxx", + ], +) + +cc_library( + name = "shell_parameters_utils_cc", + srcs = ["parameters_utils.cc"], + hdrs = ["parameters_utils.h"], + deps = [ + ":shell_parameters_utils_cxx", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "@cxx.rs//:core", + "//shell_wrapper:status_cc", + "//shell_wrapper:status_macros", + "//willow/proto/willow:aggregation_config_cc_proto", + ], +) + +cc_test( + name = "parameters_utils_test", + srcs = ["parameters_utils_test.cc"], + deps = [ + ":shell_parameters_utils_cc", + "@googletest//:gtest_main", + "@abseil-cpp//absl/status", + "//willow/proto/willow:aggregation_config_cc_proto", ], ) diff --git a/willow/src/shell/parameters_utils.cc b/willow/src/shell/parameters_utils.cc new file mode 100644 index 0000000..f91b911 --- /dev/null +++ b/willow/src/shell/parameters_utils.cc @@ -0,0 +1,48 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "willow/src/shell/parameters_utils.h" + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "include/cxx.h" +#include "shell_wrapper/status_macros.h" +#include "willow/proto/willow/aggregation_config.pb.h" +#include "willow/src/shell/parameters_utils.rs.h" + +namespace secure_aggregation { +namespace willow { + +absl::StatusOr CreateHumanReadableShellConfig( + const AggregationConfigProto& config) { + std::string serialized_config = config.SerializeAsString(); + rust::Vec result; + + SECAGG_RETURN_IF_FFI_ERROR( + secure_aggregation::create_human_readable_shell_config( + std::make_unique(std::move(serialized_config)), + &result)); + + return std::string(reinterpret_cast(result.data()), + result.size()); +} + +} // namespace willow +} // namespace secure_aggregation diff --git a/willow/src/shell/parameters_utils.h b/willow/src/shell/parameters_utils.h new file mode 100644 index 0000000..020791d --- /dev/null +++ b/willow/src/shell/parameters_utils.h @@ -0,0 +1,36 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SECURE_AGGREGATION_WILLOW_SRC_SHELL_PARAMETER_UTILS_H_ +#define SECURE_AGGREGATION_WILLOW_SRC_SHELL_PARAMETER_UTILS_H_ + +#include + +#include "absl/status/statusor.h" +#include "willow/proto/willow/aggregation_config.pb.h" + +namespace secure_aggregation { +namespace willow { + +// Returns the ShellKaheConfig and ShellAheConfig as a human-readable string, +// for the given AggregationConfigProto. +absl::StatusOr CreateHumanReadableShellConfig( + const AggregationConfigProto& config); + +} // namespace willow +} // namespace secure_aggregation + +#endif // SECURE_AGGREGATION_WILLOW_SRC_SHELL_PARAMETER_UTILS_H_ diff --git a/willow/src/shell/parameters_utils.rs b/willow/src/shell/parameters_utils.rs index cf3bd11..76a9d95 100644 --- a/willow/src/shell/parameters_utils.rs +++ b/willow/src/shell/parameters_utils.rs @@ -12,17 +12,67 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! This file contains some utility functions for working with Willow parameters: +//! - Conversions between Rust structs and their corresponding protos. +//! - FFI bridge to generate and print Shell parameters from C++. + +use aggregation_config::AggregationConfig; +use aggregation_config_rust_proto::AggregationConfigProto; +use cxx::{CxxString, UniquePtr}; use kahe::PackedVectorConfig; use kahe_shell::ShellKaheConfig; -use protobuf::{proto, ProtoStr}; +use parameters_shell::create_shell_configs; +use proto_serialization_traits::FromProto; +use protobuf::{proto, Parse, ProtoStr}; use shell_parameters_rust_proto::{ PackedVectorConfigProto, PackedVectorConfigProtoView, ShellKaheConfigProto, ShellKaheConfigProtoView, }; use std::collections::BTreeMap; -/// This file contains some utility functions for working with Willow parameters: -/// - Conversions between Rust structs and their corresponding protos. +#[cxx::bridge] +pub mod ffi { + // Re-define FfiStatus since CXX requires shared structs to be defined in the same module + // (https://github.com/dtolnay/cxx/issues/297#issuecomment-727042059) + unsafe extern "C++" { + include!("shell_wrapper/status.rs.h"); + type FfiStatus = status::ffi::FfiStatus; + } + + #[namespace = "secure_aggregation"] + extern "Rust" { + unsafe fn create_human_readable_shell_config( + aggregation_config_proto: UniquePtr, + out: *mut Vec, + ) -> FfiStatus; + } +} + +fn create_human_readable_shell_config_impl( + aggregation_config_proto: UniquePtr, +) -> Result, status::StatusError> { + let config_proto = + AggregationConfigProto::parse(aggregation_config_proto.as_bytes()).map_err(|e| { + status::invalid_argument(format!("Failed to parse AggregationConfigProto: {}", e)) + })?; + let config = AggregationConfig::from_proto(config_proto, ())?; + let (kahe_config, ahe_config) = create_shell_configs(&config)?; + let kahe_config_string = format!("{:#?}", kahe_config); + let ahe_config_string = format!("{:#?}", ahe_config); + let result = + format!("ShellKaheConfig: {}\nShellAheConfig: {}", kahe_config_string, ahe_config_string); + Ok(result.into_bytes()) +} + +/// SAFETY: `out` must not be null. +unsafe fn create_human_readable_shell_config( + aggregation_config_proto: UniquePtr, + out: *mut Vec, +) -> ffi::FfiStatus { + create_human_readable_shell_config_impl(aggregation_config_proto) + .map(|result| *out = result) + .into() +} /// Convert a rust struct `PackedVectorConfig` to the corresponding proto. pub fn packed_vector_config_to_proto(config: &PackedVectorConfig) -> PackedVectorConfigProto { diff --git a/willow/src/shell/parameters_utils_test.cc b/willow/src/shell/parameters_utils_test.cc new file mode 100644 index 0000000..3366e68 --- /dev/null +++ b/willow/src/shell/parameters_utils_test.cc @@ -0,0 +1,61 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "willow/src/shell/parameters_utils.h" + +#include "absl/status/status.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "willow/proto/willow/aggregation_config.pb.h" + +namespace secure_aggregation { +namespace willow { +namespace { + +TEST(ParameterUtilsTest, CreateHumanReadableShellConfigTest) { + AggregationConfigProto config; + VectorConfig vector_config; + vector_config.set_length(10); + vector_config.set_bound(100); + (*config.mutable_vector_configs())["test_vector"] = vector_config; + config.set_max_number_of_decryptors(1); + config.set_max_number_of_clients(10); + config.set_session_id("test_session"); + + auto result = CreateHumanReadableShellConfig(config); + + ASSERT_TRUE(result.ok()); + EXPECT_THAT(*result, ::testing::HasSubstr("ShellKaheConfig")); + EXPECT_THAT(*result, ::testing::HasSubstr("ShellAheConfig")); +} + +TEST(ParameterUtilsTest, CreateHumanReadableShellConfigInvalidConfigTest) { + AggregationConfigProto config; + config.set_max_number_of_decryptors(1); + config.set_max_number_of_clients(10); + config.set_session_id("test_session"); + + auto result = CreateHumanReadableShellConfig(config); + + EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT( + result.status().message(), + ::testing::HasSubstr("empty vector configs in aggregation config")); +} + +} // namespace +} // namespace willow +} // namespace secure_aggregation diff --git a/willow/src/testing_utils/shell_testing_decryptor.rs b/willow/src/testing_utils/shell_testing_decryptor.rs index 29314cf..2f8aaa9 100644 --- a/willow/src/testing_utils/shell_testing_decryptor.rs +++ b/willow/src/testing_utils/shell_testing_decryptor.rs @@ -31,6 +31,7 @@ use protobuf::prelude::*; use single_thread_hkdf::SingleThreadHkdfPrng; use status::ffi::FfiStatus; use status::{StatusError, StatusErrorCode}; +use std::cell::RefCell; use vahe_shell::ShellVahe; use vahe_traits::Recover; use vahe_traits::{HasVahe, VaheBase}; @@ -41,7 +42,7 @@ use vahe_traits::{HasVahe, VaheBase}; pub struct ShellTestingDecryptor { kahe: ShellKahe, vahe: ShellVahe, - prng: SingleThreadHkdfPrng, + prng: RefCell, secret_key: Option<::SecretKeyShare>, } @@ -64,14 +65,14 @@ impl ShellTestingDecryptor { let vahe = ShellVahe::new(ahe_config, context_string)?; let seed = SingleThreadHkdfPrng::generate_seed()?; let prng = SingleThreadHkdfPrng::create(&seed)?; - Ok(ShellTestingDecryptor { kahe, vahe, prng, secret_key: None }) + Ok(ShellTestingDecryptor { kahe, vahe, prng: RefCell::new(prng), secret_key: None }) } /// Generates a new AHE public key, and stores the corresponding secret key. pub fn generate_public_key( &mut self, ) -> Result<::PublicKey, StatusError> { - let (sk_share, pk_share, _) = self.vahe.key_gen(&mut self.prng)?; + let (sk_share, pk_share, _) = self.vahe.key_gen(&mut self.prng.borrow_mut())?; self.secret_key = Some(sk_share); let public_key = self.vahe.aggregate_public_key_shares(&[pk_share])?; Ok(public_key) @@ -81,7 +82,7 @@ impl ShellTestingDecryptor { /// the AHE ciphertext and then decrypting the KAHE ciphertext. Does not verify the client proof /// contained in the message. pub fn decrypt( - &mut self, + &self, client_message: &ClientMessage, ) -> Result<::Plaintext, StatusError> { let partial_dec_ciphertext = @@ -94,8 +95,11 @@ impl ShellTestingDecryptor { "No secret key available", )), Some(sk_share) => { - let partial_decryption = - self.vahe.partial_decrypt(&partial_dec_ciphertext, sk_share, &mut self.prng)?; + let partial_decryption = self.vahe.partial_decrypt( + &partial_dec_ciphertext, + sk_share, + &mut self.prng.borrow_mut(), + )?; let decrypted_kahe_key = self.vahe.recover(&partial_decryption, &rest_of_ciphertext, None)?; let decrypted_kahe_key = self.kahe.try_secret_key_from(decrypted_kahe_key)?; @@ -134,7 +138,7 @@ impl ShellTestingDecryptor { } fn decrypt_serialized( - &mut self, + &self, contribution: &[u8], ) -> Result, StatusError> { let client_message_proto = ClientMessageProto::parse(contribution) @@ -161,7 +165,7 @@ impl ShellTestingDecryptor { /// SAFETY: all pointer arguments (`out`, `out_status_message`) must be valid for writes. unsafe fn decrypt_ffi( - &mut self, + &self, contribution: &[u8], out: *mut Vec, out_status_message: *mut cxx::UniquePtr, @@ -192,7 +196,7 @@ impl ShellTestingDecryptor { let partial_decryption = self.vahe.partial_decrypt( &request.partial_dec_ciphertext, sk_share, - &mut self.prng, + &mut self.prng.borrow_mut(), )?; Ok(PartialDecryptionResponse { partial_decryption }) } @@ -267,7 +271,7 @@ pub mod ffi { #[rust_name = "decrypt_ffi"] unsafe fn decrypt( - self: &mut ShellTestingDecryptor, + self: &ShellTestingDecryptor, contribution: &[u8], out: *mut Vec, out_status_message: *mut UniquePtr, diff --git a/willow/src/traits/client.rs b/willow/src/traits/client.rs index e4fca66..28476db 100644 --- a/willow/src/traits/client.rs +++ b/willow/src/traits/client.rs @@ -26,7 +26,7 @@ pub trait SecureAggregationClient: HasKahe + HasVahe { /// Creates a client message to be sent to the Server. /// nonce is used for the VAHE encryption, has to be unique. fn create_client_message( - &mut self, + &self, plaintext: &Self::PlaintextSlice<'_>, signed_public_key: &DecryptorPublicKey<::Vahe>, nonce: &[u8], diff --git a/willow/src/traits/decryptor.rs b/willow/src/traits/decryptor.rs index 31f8934..b0d7dae 100644 --- a/willow/src/traits/decryptor.rs +++ b/willow/src/traits/decryptor.rs @@ -24,14 +24,14 @@ pub trait SecureAggregationDecryptor: HasVahe { /// Creates a public key share to be sent to the Server, updating the /// decryptor state. fn create_public_key_share( - &mut self, + &self, decryptor_state: &mut Self::DecryptorState, ) -> Result::Vahe>, StatusError>; /// Handles a partial decryption request received from the Server. Returns a /// partial decryption to the Server. fn handle_partial_decryption_request( - &mut self, + &self, partial_decryption_request: PartialDecryptionRequest<::Vahe>, decryptor_state: &Self::DecryptorState, ) -> Result::Vahe>, StatusError>; diff --git a/willow/src/willow_v1/BUILD b/willow/src/willow_v1/BUILD index edd7f8a..be432b5 100644 --- a/willow/src/willow_v1/BUILD +++ b/willow/src/willow_v1/BUILD @@ -44,12 +44,10 @@ rust_test( "//willow/src/api:aggregation_config", "//willow/src/shell:kahe_shell", "//willow/src/shell:parameters_shell", - "//willow/src/shell:single_thread_hkdf", "//willow/src/shell:vahe_shell", "//willow/src/testing_utils", "//willow/src/testing_utils:shell_testing_decryptor", "//willow/src/testing_utils:shell_testing_parameters", - "//willow/src/traits:prng_traits", ], ) @@ -59,11 +57,9 @@ rust_test( deps = [ "@crate_index//:googletest", "//willow/src/shell:parameters_shell", - "//willow/src/shell:single_thread_hkdf", "//willow/src/shell:vahe_shell", "//willow/src/traits:ahe_traits", "//willow/src/traits:decryptor_traits", - "//willow/src/traits:prng_traits", "//willow/src/traits:proto_serialization_traits", ], ) @@ -81,6 +77,7 @@ rust_library( "//willow/src/traits:ahe_traits", "//willow/src/traits:decryptor_traits", "//willow/src/traits:messages", + "//willow/src/traits:prng_traits", "//willow/src/traits:proto_serialization_traits", "//willow/src/traits:vahe_traits", ], @@ -96,13 +93,11 @@ rust_test( "@crate_index//:googletest", "//willow/src/shell:kahe_shell", "//willow/src/shell:parameters_shell", - "//willow/src/shell:single_thread_hkdf", "//willow/src/shell:vahe_shell", "//willow/src/testing_utils", "//willow/src/traits:ahe_traits", "//willow/src/traits:client_traits", "//willow/src/traits:decryptor_traits", - "//willow/src/traits:prng_traits", "//willow/src/traits:proto_serialization_traits", "//willow/src/traits:server_traits", "//willow/src/traits:verifier_traits", @@ -158,7 +153,6 @@ rust_test( "//shell_wrapper:status_matchers_rs", "//willow/src/shell:kahe_shell", "//willow/src/shell:parameters_shell", - "//willow/src/shell:single_thread_hkdf", "//willow/src/shell:vahe_shell", "//willow/src/testing_utils", "//willow/src/testing_utils:shell_testing_parameters", @@ -166,7 +160,6 @@ rust_test( "//willow/src/traits:client_traits", "//willow/src/traits:decryptor_traits", "//willow/src/traits:kahe_traits", - "//willow/src/traits:prng_traits", "//willow/src/traits:proto_serialization_traits", "//willow/src/traits:server_traits", "//willow/src/traits:vahe_traits", diff --git a/willow/src/willow_v1/client.rs b/willow/src/willow_v1/client.rs index baf5fb8..6e9d12f 100644 --- a/willow/src/willow_v1/client.rs +++ b/willow/src/willow_v1/client.rs @@ -15,13 +15,15 @@ use client_traits::SecureAggregationClient; use kahe_traits::{HasKahe, KaheBase, KaheEncrypt, KaheKeygen, TrySecretKeyInto}; use messages::{ClientMessage, DecryptorPublicKey}; +use prng_traits::SecurePrng; +use std::cell::RefCell; use vahe_traits::{HasVahe, VaheBase, VerifiableEncrypt}; /// Lightweight client directly exposing KAHE/VAHE types. pub struct WillowV1Client { pub kahe: Kahe, pub vahe: Vahe, - pub prng: Kahe::Rng, // Using a single PRNG for both VAHE and KAHE. + pub prng: RefCell, // Using a single PRNG for both VAHE and KAHE. } impl HasKahe for WillowV1Client { @@ -38,6 +40,17 @@ impl HasVahe for WillowV1Client { } } +impl WillowV1Client { + pub fn new_with_randomly_generated_seed( + kahe: Kahe, + vahe: Vahe, + ) -> Result { + let seed = Kahe::Rng::generate_seed()?; + let prng = RefCell::new(Kahe::Rng::create(&seed)?); + Ok(Self { kahe, vahe, prng }) + } +} + /// Implementation of the `SecureAggregationClient` trait for the generic /// KAHE/VAHE client, using WillowCommon as the common types (e.g. protocol /// messages are directly the AHE public key and ciphertexts). @@ -51,16 +64,17 @@ where type PlaintextSlice<'a> = ::PlaintextSlice<'a>; fn create_client_message( - &mut self, + &self, plaintext: &Self::PlaintextSlice<'_>, signed_public_key: &DecryptorPublicKey, nonce: &[u8], ) -> Result, status::StatusError> { // Generate a new KAHE key. - let kahe_secret_key = self.kahe.key_gen(&mut self.prng)?; + let kahe_secret_key = self.kahe.key_gen(&mut self.prng.borrow_mut())?; // Encrypt long plaintext with KAHE. - let kahe_ciphertext = self.kahe.encrypt(plaintext, &kahe_secret_key, &mut self.prng)?; + let kahe_ciphertext = + self.kahe.encrypt(plaintext, &kahe_secret_key, &mut self.prng.borrow_mut())?; // Convert KAHE secret key into short AHE plaintext. let ahe_plaintext: Vahe::Plaintext = self.kahe.try_secret_key_into(kahe_secret_key)?; @@ -70,7 +84,7 @@ where &ahe_plaintext, signed_public_key, nonce, - &mut self.prng, + &mut self.prng.borrow_mut(), )?; // Keep a copy of the nonce so the message can be forwarded as-is. @@ -88,9 +102,7 @@ mod test { use googletest::{gtest, verify_eq, verify_that}; use kahe_shell::ShellKahe; use parameters_shell::create_shell_configs; - use prng_traits::SecurePrng; use shell_testing_decryptor::ShellTestingDecryptor; - use single_thread_hkdf::SingleThreadHkdfPrng; use std::collections::HashMap; use testing_utils::generate_random_nonce; use vahe_shell::ShellVahe; @@ -112,9 +124,7 @@ mod test { let (kahe_config, ahe_config) = create_shell_configs(&aggregation_config)?; let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?; let vahe = ShellVahe::new(ahe_config, CONTEXT_STRING)?; - let client_seed = SingleThreadHkdfPrng::generate_seed()?; - let prng = SingleThreadHkdfPrng::create(&client_seed)?; - let mut client = WillowV1Client { kahe, vahe, prng }; + let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe)?; // Generate AHE keys. let mut testing_decryptor = @@ -153,17 +163,13 @@ mod test { let (kahe_config, ahe_config) = create_shell_configs(&aggregation_config)?; let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?; let vahe = ShellVahe::new(ahe_config, CONTEXT_STRING)?; - let client1_seed = SingleThreadHkdfPrng::generate_seed()?; - let prng = SingleThreadHkdfPrng::create(&client1_seed)?; - let mut client1 = WillowV1Client { kahe, vahe, prng }; + let client1 = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe)?; // Create a second client. let (kahe_config, ahe_config) = create_shell_configs(&aggregation_config)?; let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?; let vahe = ShellVahe::new(ahe_config, CONTEXT_STRING)?; - let client2_seed = SingleThreadHkdfPrng::generate_seed()?; - let prng = SingleThreadHkdfPrng::create(&client2_seed)?; - let mut client2 = WillowV1Client { kahe, vahe, prng }; + let client2 = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe)?; // Generate AHE keys. let mut testing_decryptor = diff --git a/willow/src/willow_v1/decryptor.rs b/willow/src/willow_v1/decryptor.rs index 9301df0..81b023e 100644 --- a/willow/src/willow_v1/decryptor.rs +++ b/willow/src/willow_v1/decryptor.rs @@ -16,17 +16,19 @@ use ahe_traits::{AheKeygen, PartialDec}; use decryptor_traits::SecureAggregationDecryptor; use messages::{DecryptorPublicKeyShare, PartialDecryptionRequest, PartialDecryptionResponse}; use messages_rust_proto::DecryptorStateProto; +use prng_traits::SecurePrng; use proto_serialization_traits::{FromProto, ToProto}; -use protobuf::{proto, AsView}; +use protobuf::AsView; use shell_ciphertexts_rust_proto::ShellAheSecretKeyShare; use status::StatusError; +use std::cell::RefCell; use vahe_traits::{EncryptVerify, HasVahe, VaheBase}; /// Lightweight decryptor directly exposing KAHE/VAHE types. It verifies only the client proofs, /// does not provide verifiable partial decryptions. pub struct WillowV1Decryptor { pub vahe: Vahe, - pub prng: Vahe::Rng, + pub prng: RefCell, } impl HasVahe for WillowV1Decryptor { @@ -36,6 +38,14 @@ impl HasVahe for WillowV1Decryptor { } } +impl WillowV1Decryptor { + pub fn new_with_randomly_generated_seed(vahe: Vahe) -> Result { + let seed = Vahe::Rng::generate_seed()?; + let prng = RefCell::new(Vahe::Rng::create(&seed)?); + Ok(Self { vahe, prng }) + } +} + pub struct DecryptorState { sk_share: Option, } @@ -97,10 +107,10 @@ where /// Creates a public key share to be sent to the Server, updating the /// decryptor state. fn create_public_key_share( - &mut self, + &self, decryptor_state: &mut Self::DecryptorState, ) -> Result, status::StatusError> { - let (sk_share, pk_share, _) = self.vahe.key_gen(&mut self.prng)?; + let (sk_share, pk_share, _) = self.vahe.key_gen(&mut self.prng.borrow_mut())?; decryptor_state.sk_share = Some(sk_share); Ok(pk_share) } @@ -108,7 +118,7 @@ where /// Handles a partial decryption request received from the Server. Returns a /// partial decryption to the Server. fn handle_partial_decryption_request( - &mut self, + &self, partial_decryption_request: PartialDecryptionRequest, decryptor_state: &Self::DecryptorState, ) -> Result, status::StatusError> { @@ -121,7 +131,7 @@ where let pd = self.vahe.partial_decrypt( &partial_decryption_request.partial_dec_ciphertext, sk_share, - &mut self.prng, + &mut self.prng.borrow_mut(), )?; Ok(PartialDecryptionResponse { partial_decryption: pd }) } @@ -129,15 +139,12 @@ where #[cfg(test)] mod tests { - use super::*; use crate::{DecryptorState, WillowV1Decryptor}; use ahe_traits::AheBase; use decryptor_traits::SecureAggregationDecryptor; use googletest::{gtest, verify_true}; use parameters_shell::create_shell_ahe_config; - use prng_traits::SecurePrng; use proto_serialization_traits::{FromProto, ToProto}; - use single_thread_hkdf::SingleThreadHkdfPrng; use vahe_shell::ShellVahe; const CONTEXT_STRING: &[u8] = b"testing_context_string"; @@ -145,9 +152,7 @@ mod tests { #[gtest] fn decryptor_state_serialization_roundtrip() -> googletest::Result<()> { let vahe = ShellVahe::new(create_shell_ahe_config(1).unwrap(), CONTEXT_STRING).unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed()?; - let prng = SingleThreadHkdfPrng::create(&seed)?; - let mut decryptor = WillowV1Decryptor { vahe, prng }; + let decryptor = WillowV1Decryptor::new_with_randomly_generated_seed(vahe)?; let mut decryptor_state = DecryptorState::default(); // Check empty state serialization. diff --git a/willow/src/willow_v1/server.rs b/willow/src/willow_v1/server.rs index ff3ec8b..d2110aa 100644 --- a/willow/src/willow_v1/server.rs +++ b/willow/src/willow_v1/server.rs @@ -362,10 +362,8 @@ mod tests { use googletest::{gtest, verify_true}; use kahe_shell::ShellKahe; use parameters_shell::{create_shell_ahe_config, create_shell_kahe_config}; - use prng_traits::SecurePrng; use proto_serialization_traits::{FromProto, ToProto}; use server_traits::SecureAggregationServer; - use single_thread_hkdf::SingleThreadHkdfPrng; use std::collections::HashMap; use testing_utils::{generate_aggregation_config, generate_random_nonce}; use vahe_shell::ShellVahe; @@ -392,9 +390,7 @@ mod tests { CONTEXT_STRING, ) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed()?; - let prng = SingleThreadHkdfPrng::create(&seed)?; - let mut client = WillowV1Client { kahe, vahe, prng }; + let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe)?; // Create decryptor. let vahe = ShellVahe::new( @@ -402,10 +398,8 @@ mod tests { CONTEXT_STRING, ) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed()?; - let prng = SingleThreadHkdfPrng::create(&seed)?; let mut decryptor_state = DecryptorState::default(); - let mut decryptor = WillowV1Decryptor { vahe, prng }; + let decryptor = WillowV1Decryptor::new_with_randomly_generated_seed(vahe)?; // Create server. let kahe = diff --git a/willow/src/willow_v1/verifier.rs b/willow/src/willow_v1/verifier.rs index f9006fa..c887bc0 100644 --- a/willow/src/willow_v1/verifier.rs +++ b/willow/src/willow_v1/verifier.rs @@ -271,10 +271,8 @@ mod tests { use kahe_shell::ShellKahe; use kahe_traits::KaheBase; use parameters_shell::{create_shell_ahe_config, create_shell_kahe_config}; - use prng_traits::SecurePrng; use proto_serialization_traits::{FromProto, ToProto}; use server_traits::SecureAggregationServer; - use single_thread_hkdf::SingleThreadHkdfPrng; use status_matchers_rs::status_is; use std::collections::HashMap; use testing_utils::{generate_aggregation_config, generate_random_nonce}; @@ -305,9 +303,7 @@ mod tests { CONTEXT_STRING, ) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed()?; - let prng = SingleThreadHkdfPrng::create(&seed)?; - let mut client = WillowV1Client { kahe, vahe, prng }; + let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe)?; // Create decryptor, which needs its own `vahe` (with same public polynomials // generated from the seeds) and `prng`. @@ -316,10 +312,8 @@ mod tests { CONTEXT_STRING, ) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed()?; - let prng = SingleThreadHkdfPrng::create(&seed)?; let mut decryptor_state = DecryptorState::default(); - let mut decryptor = WillowV1Decryptor { vahe, prng }; + let decryptor = WillowV1Decryptor::new_with_randomly_generated_seed(vahe)?; // Create server. let kahe = diff --git a/willow/tests/willow_v1_shell.rs b/willow/tests/willow_v1_shell.rs index 9e497fd..56ff8c9 100644 --- a/willow/tests/willow_v1_shell.rs +++ b/willow/tests/willow_v1_shell.rs @@ -24,10 +24,8 @@ use messages::{ PartialDecryptionRequest, PartialDecryptionResponse, }; use parameters_shell::{create_shell_ahe_config, create_shell_kahe_config}; -use prng_traits::SecurePrng; use proto_serialization_traits::{FromProto, ToProto}; use server_traits::SecureAggregationServer; -use single_thread_hkdf::SingleThreadHkdfPrng; use status::StatusErrorCode; use status_matchers_rs::status_is; use std::collections::HashMap; @@ -57,19 +55,15 @@ fn encrypt_decrypt_one() -> googletest::Result<()> { let kahe = ShellKahe::new(create_shell_kahe_config(&aggregation_config).unwrap(), CONTEXT_STRING) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); - let mut client = WillowV1Client { kahe, vahe, prng }; + let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe).unwrap(); // Create decryptor, which needs its own `vahe` (with same public polynomials // generated from the seeds) and `prng`. let vahe = ShellVahe::new(create_shell_ahe_config(max_number_of_decryptors).unwrap(), CONTEXT_STRING) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); let mut decryptor_state = DecryptorState::default(); - let mut decryptor = WillowV1Decryptor { vahe, prng }; + let decryptor = WillowV1Decryptor::new_with_randomly_generated_seed(vahe).unwrap(); // Create server. let vahe = @@ -147,19 +141,15 @@ fn encrypt_decrypt_one_serialized() -> googletest::Result<()> { let vahe = ShellVahe::new(create_shell_ahe_config(max_number_of_decryptors).unwrap(), CONTEXT_STRING) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); - let mut client = WillowV1Client { kahe, vahe, prng }; + let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe).unwrap(); // Create decryptor, which needs its own `vahe` (with same public polynomials // generated from the seeds) and `prng`. let vahe = ShellVahe::new(create_shell_ahe_config(max_number_of_decryptors).unwrap(), CONTEXT_STRING) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); let mut decryptor_state = DecryptorState::default(); - let mut decryptor = WillowV1Decryptor { vahe, prng }; + let decryptor = WillowV1Decryptor::new_with_randomly_generated_seed(vahe).unwrap(); // Create server. let kahe = @@ -287,9 +277,7 @@ fn encrypt_decrypt_multiple_clients() -> googletest::Result<()> { let kahe = ShellKahe::new(create_shell_kahe_config(&aggregation_config).unwrap(), CONTEXT_STRING) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); - let client = WillowV1Client { kahe, vahe, prng }; + let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe).unwrap(); clients.push(client); } @@ -298,10 +286,8 @@ fn encrypt_decrypt_multiple_clients() -> googletest::Result<()> { let vahe = ShellVahe::new(create_shell_ahe_config(max_number_of_decryptors).unwrap(), CONTEXT_STRING) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); let mut decryptor_state = DecryptorState::default(); - let mut decryptor = WillowV1Decryptor { vahe, prng }; + let decryptor = WillowV1Decryptor::new_with_randomly_generated_seed(vahe).unwrap(); // Create server. let vahe = @@ -420,9 +406,7 @@ fn encrypt_decrypt_multiple_clients_including_invalid_proofs() -> googletest::Re let kahe = ShellKahe::new(create_shell_kahe_config(&aggregation_config).unwrap(), CONTEXT_STRING) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); - let client = WillowV1Client { kahe, vahe, prng }; + let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe).unwrap(); good_clients.push(client); } @@ -437,9 +421,7 @@ fn encrypt_decrypt_multiple_clients_including_invalid_proofs() -> googletest::Re let kahe = ShellKahe::new(create_shell_kahe_config(&aggregation_config).unwrap(), CONTEXT_STRING) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); - let client = WillowV1Client { kahe, vahe, prng }; + let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe).unwrap(); bad_clients.push(client); } @@ -448,10 +430,8 @@ fn encrypt_decrypt_multiple_clients_including_invalid_proofs() -> googletest::Re let vahe = ShellVahe::new(create_shell_ahe_config(max_number_of_decryptors).unwrap(), CONTEXT_STRING) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); let mut decryptor_state = DecryptorState::default(); - let mut decryptor = WillowV1Decryptor { vahe, prng }; + let decryptor = WillowV1Decryptor::new_with_randomly_generated_seed(vahe).unwrap(); // Create server. let vahe = @@ -616,10 +596,8 @@ fn encrypt_decrypt_many_clients_decryptors() -> googletest::Result<()> { CONTEXT_STRING, ) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); let mut decryptor_state = DecryptorState::default(); - let mut decryptor = WillowV1Decryptor { vahe, prng }; + let decryptor = WillowV1Decryptor::new_with_randomly_generated_seed(vahe).unwrap(); // Decryptor generates public key share. let public_key_share = decryptor.create_public_key_share(&mut decryptor_state).unwrap(); @@ -652,9 +630,7 @@ fn encrypt_decrypt_many_clients_decryptors() -> googletest::Result<()> { let kahe = ShellKahe::new(create_shell_kahe_config(&aggregation_config).unwrap(), CONTEXT_STRING) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); - let mut client = WillowV1Client { kahe, vahe, prng }; + let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe).unwrap(); let client_input_values = generate_random_unsigned_vector(INPUT_LENGTH as usize, INPUT_DOMAIN as u64); @@ -730,9 +706,7 @@ fn encrypt_decrypt_no_dropout() -> googletest::Result<()> { let kahe = ShellKahe::new(create_shell_kahe_config(&aggregation_config).unwrap(), CONTEXT_STRING) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); - let client = WillowV1Client { kahe, vahe, prng }; + let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe).unwrap(); clients.push(client); } @@ -746,10 +720,8 @@ fn encrypt_decrypt_no_dropout() -> googletest::Result<()> { CONTEXT_STRING, ) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); let decryptor_state = DecryptorState::default(); - let decryptor = WillowV1Decryptor { vahe, prng }; + let decryptor = WillowV1Decryptor::new_with_randomly_generated_seed(vahe).unwrap(); decryptor_states.push(decryptor_state); decryptors.push(decryptor); }