From 255da7b209afba4623b63d681585eb0d587af90f Mon Sep 17 00:00:00 2001 From: Nikhil Benesch Date: Sat, 1 Jan 2022 12:56:09 -0500 Subject: [PATCH 1/7] ore/retry: create a separate RetryStream type Having a separate `RetryStream` cleanly separates the specification of a retry policy from the state required to execute that policy. A forthcoming commit will add a synchronous retry API which makes the separation of policy from implementation more important. This commit makes the new `RetryStream` an internal implementation detail rather than part of the public API, as it doesn't seem useful outside of `Retry::retry` and `RetryReader`. If we *do* discover a use for it, it's easy to slap a `pub` on the new `RetryStream` type down the road. --- src/ore/src/retry.rs | 91 ++++++++++++++++++++++++-------------------- 1 file changed, 49 insertions(+), 42 deletions(-) diff --git a/src/ore/src/retry.rs b/src/ore/src/retry.rs index f046dfa47db25..5f45826c81be4 100644 --- a/src/ore/src/retry.rs +++ b/src/ore/src/retry.rs @@ -60,12 +60,11 @@ use std::cmp; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; -use std::time::Instant; use futures::{ready, Stream, StreamExt}; use pin_project::pin_project; use tokio::io::{AsyncRead, ReadBuf}; -use tokio::time::Duration; +use tokio::time::{self, Duration, Instant, Sleep}; // TODO(benesch): remove this if the `duration_constants` feature stabilizes. // See: https://github.com/rust-lang/rust/issues/57391 @@ -93,12 +92,6 @@ pub struct Retry { factor: f64, clamp_backoff: Duration, limit: RetryLimit, - - i: usize, - next_backoff: Option, - #[pin] - sleep: tokio::time::Sleep, - start: Instant, } impl Retry { @@ -162,16 +155,6 @@ impl Retry { self } - /// Resets the start time and try counters of this Retry instance to their initial values to - /// allow re-using it for another retryable operation - pub fn reset(self: Pin<&mut Self>) { - let mut this = self.project(); - *this.i = 0; - *this.next_backoff = None; - *this.start = Instant::now(); - this.sleep.set(tokio::time::sleep(Duration::default())); - } - /// Retries the asynchronous, fallible operation `f` according to the /// configured policy. 
/// @@ -196,10 +179,10 @@ impl Retry { F: FnMut(RetryState) -> U, U: Future>, { - let this = self; - tokio::pin!(this); + let stream = self.into_retry_stream(); + tokio::pin!(stream); let mut err = None; - while let Some(state) = this.next().await { + while let Some(state) = stream.next().await { match f(state).await { Ok(v) => return Ok(v), Err(e) => err = Some(e), @@ -207,6 +190,16 @@ impl Retry { } Err(err.expect("retry produces at least one element")) } + + fn into_retry_stream(self) -> RetryStream { + RetryStream { + retry: self, + start: Instant::now(), + i: 0, + next_backoff: None, + sleep: time::sleep(Duration::default()), + } + } } impl Default for Retry { @@ -218,35 +211,49 @@ impl Default for Retry { factor: 2.0, clamp_backoff: MAX_DURATION, limit: RetryLimit::Duration(Duration::from_secs(30)), - - i: 0, - next_backoff: None, - start: Instant::now(), - sleep: tokio::time::sleep(Duration::default()), } } } -impl Stream for Retry { +#[pin_project] +#[derive(Debug)] +struct RetryStream { + retry: Retry, + start: Instant, + i: usize, + next_backoff: Option, + #[pin] + sleep: Sleep, +} + +impl RetryStream { + fn reset(self: Pin<&mut Self>) { + let this = self.project(); + *this.start = Instant::now(); + *this.i = 0; + *this.next_backoff = None; + } +} + +impl Stream for RetryStream { type Item = RetryState; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.project(); - if *this.i == 0 { - *this.start = Instant::now(); - *this.next_backoff = Some(cmp::min(*this.initial_backoff, *this.clamp_backoff)); - } else { - match this.next_backoff { - None => return Poll::Ready(None), - Some(next_backoff) => { - ready!(Pin::new(&mut this.sleep).poll(cx)); - *next_backoff = - cmp::min(next_backoff.mul_f64(*this.factor), *this.clamp_backoff); - } + let retry = this.retry; + + match this.next_backoff { + None if *this.i == 0 => { + *this.next_backoff = Some(cmp::min(retry.initial_backoff, retry.clamp_backoff)); + } + None => return Poll::Ready(None), + Some(next_backoff) => { + ready!(this.sleep.as_mut().poll(cx)); + *next_backoff = cmp::min(next_backoff.mul_f64(retry.factor), retry.clamp_backoff); } } - match *this.limit { + match retry.limit { RetryLimit::Tries(max_tries) if *this.i + 1 >= max_tries => *this.next_backoff = None, RetryLimit::Duration(max_duration) => { let elapsed = this.start.elapsed(); @@ -264,7 +271,7 @@ impl Stream for Retry { next_backoff: *this.next_backoff, }; if let Some(d) = *this.next_backoff { - this.sleep.reset(tokio::time::Instant::now() + d); + this.sleep.reset(Instant::now() + d); } *this.i += 1; Poll::Ready(Some(state)) @@ -280,7 +287,7 @@ pub struct RetryReader { offset: usize, error: Option, #[pin] - retry: Retry, + retry: RetryStream, #[pin] state: RetryReaderState, } @@ -316,7 +323,7 @@ where factory, offset: 0, error: None, - retry, + retry: retry.into_retry_stream(), state: RetryReaderState::Waiting, } } From 6561cb0a3d73024d6d1f8056db3180371f59e8a9 Mon Sep 17 00:00:00 2001 From: Nikhil Benesch Date: Sat, 1 Jan 2022 13:03:09 -0500 Subject: [PATCH 2/7] ore/retry: rename Retry::retry to Retry::retry_async To make space for a synchronous Retry::retry method. 
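The rename is mechanical at every call site: `.retry(...)` invoked with an
async closure becomes `.retry_async(...)`. A minimal sketch of the change,
assuming a Tokio runtime is available (it mirrors the crate's doc example):

    use ore::retry::Retry;

    // Before: Retry::default().retry(|state| async move { ... }).await
    let res = Retry::default()
        .retry_async(|state| async move {
            if state.i == 3 {
                Ok(())
            } else {
                Err("contrived failure")
            }
        })
        .await;
    assert_eq!(res, Ok(()));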
--- demo/billing/src/mz.rs | 2 +- src/coord/src/coord.rs | 4 +++- src/dataflow/src/source/s3.rs | 2 +- src/interchange/src/avro/schema.rs | 2 +- src/kafka-util/src/admin.rs | 2 +- src/materialized/src/telemetry.rs | 2 +- src/ore/src/retry.rs | 22 +++++++++---------- src/testdrive/src/action.rs | 4 ++-- src/testdrive/src/action/avro_ocf.rs | 2 +- .../src/action/kafka/add_partitions.rs | 2 +- .../src/action/kinesis/create_stream.rs | 2 +- src/testdrive/src/action/kinesis/ingest.rs | 2 +- .../src/action/kinesis/update_shards.rs | 2 +- .../src/action/postgres/verify_slot.rs | 2 +- src/testdrive/src/action/schema_registry.rs | 2 +- src/testdrive/src/action/sql.rs | 4 ++-- .../src/action/verify_timestamp_compaction.rs | 2 +- test/metabase/smoketest/src/main.rs | 6 ++--- test/perf-kinesis/src/kinesis.rs | 2 +- 19 files changed, 35 insertions(+), 33 deletions(-) diff --git a/demo/billing/src/mz.rs b/demo/billing/src/mz.rs index a8368e3f03986..e7665a5d585ff 100644 --- a/demo/billing/src/mz.rs +++ b/demo/billing/src/mz.rs @@ -155,7 +155,7 @@ pub async fn validate_sink( let count_input_view_query = format!("SELECT count(*) from {}", input_view); Retry::default() - .retry(|_| async { + .retry_async(|_| async { let count_check_sink: i64 = mz_client .query_one(&*count_check_sink_query, &[]) .await? diff --git a/src/coord/src/coord.rs b/src/coord/src/coord.rs index 2fa1596b2b47c..ff5602c06b9be 100644 --- a/src/coord/src/coord.rs +++ b/src/coord/src/coord.rs @@ -3913,7 +3913,9 @@ where for (conn, slot_names) in replication_slots_to_drop { // Try to drop the replication slots, but give up after a while. let _ = Retry::default() - .retry(|_state| postgres_util::drop_replication_slots(&conn, &slot_names)) + .retry_async(|_state| { + postgres_util::drop_replication_slots(&conn, &slot_names) + }) .await; } }); diff --git a/src/dataflow/src/source/s3.rs b/src/dataflow/src/source/s3.rs index 0b67497e894df..297e1a3f433af 100644 --- a/src/dataflow/src/source/s3.rs +++ b/src/dataflow/src/source/s3.rs @@ -312,7 +312,7 @@ async fn scan_bucket_task( let mut continuation_token = None; loop { let response = Retry::default() - .retry(|_| { + .retry_async(|_| { client .list_objects_v2() .bucket(&bucket) diff --git a/src/interchange/src/avro/schema.rs b/src/interchange/src/avro/schema.rs index 76928d03a7f87..3cf05ec1ddc91 100644 --- a/src/interchange/src/avro/schema.rs +++ b/src/interchange/src/avro/schema.rs @@ -330,7 +330,7 @@ impl SchemaCache { let ccsr_client = &self.ccsr_client; let response = Retry::default() .max_duration(Duration::from_secs(30)) - .retry(|state| async move { + .retry_async(|state| async move { let res = ccsr_client.get_schema_by_id(id).await; match res { Err(e) => { diff --git a/src/kafka-util/src/admin.rs b/src/kafka-util/src/admin.rs index 40bcc9ec6ad1d..9ff3abd3f4a8d 100644 --- a/src/kafka-util/src/admin.rs +++ b/src/kafka-util/src/admin.rs @@ -60,7 +60,7 @@ where // created with the default number partitions, and not the number of // partitions requested in `new_topic`. Retry::default() - .retry(|_| async { + .retry_async(|_| async { let metadata = client .inner() // N.B. 
It is extremely important not to ask specifically diff --git a/src/materialized/src/telemetry.rs b/src/materialized/src/telemetry.rs index 2908771bfaca0..d6e4196e969a1 100644 --- a/src/materialized/src/telemetry.rs +++ b/src/materialized/src/telemetry.rs @@ -112,7 +112,7 @@ async fn report_one(config: &Config) -> Result { let response: V1VersionResponse = Retry::default() .initial_backoff(Duration::from_secs(1)) .max_duration(config.interval) - .retry(|_state| async { + .retry_async(|_state| async { let query_result = config .coord_client .system_execute_one(&make_telemetry_query(config)) diff --git a/src/ore/src/retry.rs b/src/ore/src/retry.rs index 5f45826c81be4..50e25c2618bb7 100644 --- a/src/ore/src/retry.rs +++ b/src/ore/src/retry.rs @@ -28,7 +28,7 @@ //! use std::time::Duration; //! use ore::retry::Retry; //! -//! let res = Retry::default().retry(|state| async move { +//! let res = Retry::default().retry_async(|state| async move { //! if state.i == 3 { //! Ok(()) //! } else { @@ -46,7 +46,7 @@ //! use std::time::Duration; //! use ore::retry::Retry; //! -//! let res = Retry::default().max_tries(2).retry(|state| async move { +//! let res = Retry::default().max_tries(2).retry_async(|state| async move { //! if state.i == 3 { //! Ok(()) //! } else { @@ -174,7 +174,7 @@ impl Retry { /// The operation does not attempt to forcibly time out `f`, even if there /// is a maximum duration. If there is the possibility of `f` blocking /// forever, consider adding a timeout internally. - pub async fn retry(self, mut f: F) -> Result + pub async fn retry_async(self, mut f: F) -> Result where F: FnMut(RetryState) -> U, U: Future>, @@ -397,11 +397,11 @@ mod tests { use super::*; #[tokio::test] - async fn test_retry_success() { + async fn test_retry_async_success() { let mut states = vec![]; let res = Retry::default() .initial_backoff(Duration::from_millis(1)) - .retry(|state| { + .retry_async(|state| { states.push(state); async move { if state.i == 2 { @@ -433,12 +433,12 @@ mod tests { } #[tokio::test] - async fn test_retry_fail_max_tries() { + async fn test_retry_async_fail_max_tries() { let mut states = vec![]; let res = Retry::default() .initial_backoff(Duration::from_millis(1)) .max_tries(3) - .retry(|state| { + .retry_async(|state| { states.push(state); async { Err::<(), _>("injected") } }) @@ -464,12 +464,12 @@ mod tests { } #[tokio::test] - async fn test_retry_fail_max_duration() { + async fn test_retry_async_fail_max_duration() { let mut states = vec![]; let res = Retry::default() .initial_backoff(Duration::from_millis(5)) .max_duration(Duration::from_millis(10)) - .retry(|state| { + .retry_async(|state| { states.push(state); async { Err::<(), _>("injected") } }) @@ -504,13 +504,13 @@ mod tests { } #[tokio::test] - async fn test_retry_fail_clamp_backoff() { + async fn test_retry_async_fail_clamp_backoff() { let mut states = vec![]; let res = Retry::default() .initial_backoff(Duration::from_millis(1)) .clamp_backoff(Duration::from_millis(1)) .max_tries(4) - .retry(|state| { + .retry_async(|state| { states.push(state); async { Err::<(), _>("injected") } }) diff --git a/src/testdrive/src/action.rs b/src/testdrive/src/action.rs index 103aa67ee8ee8..e14e48feeb006 100644 --- a/src/testdrive/src/action.rs +++ b/src/testdrive/src/action.rs @@ -247,7 +247,7 @@ impl State { async fn delete_bucket_objects(&self, bucket: String) -> Result<(), Error> { Retry::default() .max_duration(self.default_timeout) - .retry(|_| async { + .retry_async(|_| async { // loop until error or response has no continuation token 
let mut continuation_token = None; loop { @@ -286,7 +286,7 @@ impl State { pub async fn reset_sqs(&self) -> Result<(), Error> { Retry::default() .max_duration(self.default_timeout) - .retry(|_| async { + .retry_async(|_| async { for queue_url in &self.sqs_queues_created { self.sqs_client .delete_queue() diff --git a/src/testdrive/src/action/avro_ocf.rs b/src/testdrive/src/action/avro_ocf.rs index 64519ede446ce..9508bdec12fb8 100644 --- a/src/testdrive/src/action/avro_ocf.rs +++ b/src/testdrive/src/action/avro_ocf.rs @@ -154,7 +154,7 @@ impl Action for VerifyAction { async fn redo(&self, state: &mut State) -> Result<(), String> { let path = Retry::default() .max_duration(state.default_timeout) - .retry(|_| async { + .retry_async(|_| async { let row = state .pgclient .query_one( diff --git a/src/testdrive/src/action/kafka/add_partitions.rs b/src/testdrive/src/action/kafka/add_partitions.rs index 221ec1a699e93..910f6d92f9c2a 100644 --- a/src/testdrive/src/action/kafka/add_partitions.rs +++ b/src/testdrive/src/action/kafka/add_partitions.rs @@ -85,7 +85,7 @@ impl Action for AddPartitionsAction { Retry::default() .max_duration(state.default_timeout) - .retry(|_| async { + .retry_async(|_| async { let metadata = state .kafka_producer .client() diff --git a/src/testdrive/src/action/kinesis/create_stream.rs b/src/testdrive/src/action/kinesis/create_stream.rs index e2e7854d7cebd..1f7ea524b2fb4 100644 --- a/src/testdrive/src/action/kinesis/create_stream.rs +++ b/src/testdrive/src/action/kinesis/create_stream.rs @@ -56,7 +56,7 @@ impl Action for CreateStreamAction { Retry::default() .max_duration(cmp::max(state.default_timeout, Duration::from_secs(60))) - .retry(|_| async { + .retry_async(|_| async { let description = state .kinesis_client .describe_stream() diff --git a/src/testdrive/src/action/kinesis/ingest.rs b/src/testdrive/src/action/kinesis/ingest.rs index 547a4e1d14eaf..4a951c562d743 100644 --- a/src/testdrive/src/action/kinesis/ingest.rs +++ b/src/testdrive/src/action/kinesis/ingest.rs @@ -59,7 +59,7 @@ impl Action for IngestAction { // be prepared to back off. Retry::default() .max_duration(state.default_timeout) - .retry(|_| async { + .retry_async(|_| async { match state .kinesis_client .put_record() diff --git a/src/testdrive/src/action/kinesis/update_shards.rs b/src/testdrive/src/action/kinesis/update_shards.rs index 7df4b8572b27b..f76dc4d19f191 100644 --- a/src/testdrive/src/action/kinesis/update_shards.rs +++ b/src/testdrive/src/action/kinesis/update_shards.rs @@ -60,7 +60,7 @@ impl Action for UpdateShardCountAction { // Verify the current shard count. Retry::default() .max_duration(cmp::max(state.default_timeout, Duration::from_secs(60))) - .retry(|_| async { + .retry_async(|_| async { // Wait for shards to stop updating. 
let description = state .kinesis_client diff --git a/src/testdrive/src/action/postgres/verify_slot.rs b/src/testdrive/src/action/postgres/verify_slot.rs index 24457781e56e1..192041dc24e3d 100644 --- a/src/testdrive/src/action/postgres/verify_slot.rs +++ b/src/testdrive/src/action/postgres/verify_slot.rs @@ -54,7 +54,7 @@ impl Action for VerifySlotAction { Retry::default() .initial_backoff(Duration::from_millis(50)) .max_duration(cmp::max(state.default_timeout, Duration::from_secs(3))) - .retry(|_| async { + .retry_async(|_| async { println!(">> checking for postgres replication slot {}", &self.slot); let rows = client .query( diff --git a/src/testdrive/src/action/schema_registry.rs b/src/testdrive/src/action/schema_registry.rs index 01c3ff614e534..e9469c31598d9 100644 --- a/src/testdrive/src/action/schema_registry.rs +++ b/src/testdrive/src/action/schema_registry.rs @@ -95,7 +95,7 @@ impl Action for WaitSchemaAction { .initial_backoff(Duration::from_millis(50)) .factor(1.5) .max_duration(self.context.timeout) - .retry(|_| async { + .retry_async(|_| async { state .ccsr_client .get_schema_by_subject(&self.schema) diff --git a/src/testdrive/src/action/sql.rs b/src/testdrive/src/action/sql.rs index 4ed34bbc007af..31de16f445e11 100644 --- a/src/testdrive/src/action/sql.rs +++ b/src/testdrive/src/action/sql.rs @@ -133,7 +133,7 @@ impl Action for SqlAction { .max_duration(self.context.timeout), false => Retry::default().max_tries(1), } - .retry(|retry_state| async move { + .retry_async(|retry_state| async move { match self.try_redo(pgclient, &query).await { Ok(()) => { if retry_state.i != 0 { @@ -366,7 +366,7 @@ impl Action for FailSqlAction { .factor(self.context.backoff_factor) .max_duration(self.context.timeout), false => Retry::default().max_tries(1), - }.retry(|retry_state| async move { + }.retry_async(|retry_state| async move { match self.try_redo(pgclient, &query).await { Ok(()) => { if retry_state.i != 0 { diff --git a/src/testdrive/src/action/verify_timestamp_compaction.rs b/src/testdrive/src/action/verify_timestamp_compaction.rs index a511b9c679d12..33bf15617d259 100644 --- a/src/testdrive/src/action/verify_timestamp_compaction.rs +++ b/src/testdrive/src/action/verify_timestamp_compaction.rs @@ -65,7 +65,7 @@ impl Action for VerifyTimestampsAction { Retry::default() .initial_backoff(Duration::from_secs(1)) .max_duration(Duration::from_secs(10)) - .retry(|retry_state| { + .retry_async(|retry_state| { let initial_highest = initial_highest_base.clone(); async move { let mut catalog = Catalog::open_debug(path, NOW_ZERO.clone()) diff --git a/test/metabase/smoketest/src/main.rs b/test/metabase/smoketest/src/main.rs index 2d08f6129d73f..42be0d7f04d00 100644 --- a/test/metabase/smoketest/src/main.rs +++ b/test/metabase/smoketest/src/main.rs @@ -25,7 +25,7 @@ const DUMMY_PASSWORD: &str = "dummydummy1"; async fn connect_materialized() -> Result { Retry::default() - .retry(|_| async { + .retry_async(|_| async { let res = TcpStream::connect("materialized:6875").await; if let Err(e) = &res { log::debug!("error connecting to materialized: {}", e); @@ -52,7 +52,7 @@ async fn connect_metabase() -> Result { metabase::Client::new("http://metabase:3000").context("failed creating metabase client")?; let setup_token = Retry::default() .max_duration(Duration::from_secs(30)) - .retry(|_| async { + .retry_async(|_| async { let res = client.session_properties().await; if let Err(e) = &res { log::debug!("error connecting to metabase: {}", e); @@ -157,7 +157,7 @@ async fn main() -> Result<(), anyhow::Error> { // 
expose when it is complete, so just retry a few times waiting for the // metadata we expect. Retry::default() - .retry(|_| async { + .retry_async(|_| async { let mut metadata = metabase_client.database_metadata(mzdb.id).await?; metadata.tables.retain(|t| t.schema == "public"); metadata.tables.sort_by(|a, b| a.name.cmp(&b.name)); diff --git a/test/perf-kinesis/src/kinesis.rs b/test/perf-kinesis/src/kinesis.rs index 46d2867297581..c7d0b08566a49 100644 --- a/test/perf-kinesis/src/kinesis.rs +++ b/test/perf-kinesis/src/kinesis.rs @@ -39,7 +39,7 @@ pub async fn create_stream( let stream_arn = Retry::default() .max_duration(Duration::from_secs(120)) - .retry(|_| async { + .retry_async(|_| async { let description = &kinesis_client .describe_stream() .stream_name(stream_name) From d6aab5e1f047ed82e3470984606b6df32052ca9f Mon Sep 17 00:00:00 2001 From: Nikhil Benesch Date: Sat, 1 Jan 2022 13:11:23 -0500 Subject: [PATCH 3/7] ore/retry: add a synchronous Retry::retry operation This operates identically to `Retry::retry_async` but uses `std::thread::sleep` to wait rather than Tokio timers. --- src/ore/src/retry.rs | 191 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 181 insertions(+), 10 deletions(-) diff --git a/src/ore/src/retry.rs b/src/ore/src/retry.rs index 50e25c2618bb7..a8d34ca13423f 100644 --- a/src/ore/src/retry.rs +++ b/src/ore/src/retry.rs @@ -24,42 +24,40 @@ //! Retry a contrived fallible operation until it succeeds: //! //! ``` -//! # tokio::runtime::Runtime::new().unwrap().block_on(async { //! use std::time::Duration; //! use ore::retry::Retry; //! -//! let res = Retry::default().retry_async(|state| async move { +//! let res = Retry::default().retry(|state| { //! if state.i == 3 { //! Ok(()) //! } else { //! Err("contrived failure") //! } -//! }).await; +//! }); //! assert_eq!(res, Ok(())); -//! # }); //! ``` //! //! Limit the number of retries such that success is never observed: //! //! ``` -//! # tokio::runtime::Runtime::new().unwrap().block_on(async { //! use std::time::Duration; //! use ore::retry::Retry; //! -//! let res = Retry::default().max_tries(2).retry_async(|state| async move { +//! let res = Retry::default().max_tries(2).retry(|state| { //! if state.i == 3 { //! Ok(()) //! } else { //! Err("contrived failure") //! } -//! }).await; +//! }); //! assert_eq!(res, Err("contrived failure")); -//! # }); +//! ``` use std::cmp; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; +use std::thread; use futures::{ready, Stream, StreamExt}; use pin_project::pin_project; @@ -155,8 +153,7 @@ impl Retry { self } - /// Retries the asynchronous, fallible operation `f` according to the - /// configured policy. + /// Retries the fallible operation `f` according to the configured policy. /// /// The `retry` method invokes `f` repeatedly until it succeeds or until the /// maximum duration or tries have been reached, as configured via @@ -174,6 +171,43 @@ impl Retry { /// The operation does not attempt to forcibly time out `f`, even if there /// is a maximum duration. If there is the possibility of `f` blocking /// forever, consider adding a timeout internally. 
+ pub fn retry(self, mut f: F) -> Result + where + F: FnMut(RetryState) -> Result, + { + let start = Instant::now(); + let mut i = 0; + let mut next_backoff = Some(cmp::min(self.initial_backoff, self.clamp_backoff)); + loop { + match self.limit { + RetryLimit::Tries(max_tries) if i + 1 >= max_tries => next_backoff = None, + RetryLimit::Duration(max_duration) => { + let elapsed = start.elapsed(); + if elapsed > max_duration { + next_backoff = None; + } else if elapsed + next_backoff.unwrap() > max_duration { + next_backoff = Some(max_duration - elapsed); + } + } + _ => (), + } + let state = RetryState { i, next_backoff }; + match f(state) { + Ok(t) => return Ok(t), + Err(e) => match &mut next_backoff { + None => return Err(e), + Some(next_backoff) => { + thread::sleep(*next_backoff); + *next_backoff = + cmp::min(next_backoff.mul_f64(self.factor), self.clamp_backoff); + } + }, + } + i += 1; + } + } + + /// Like [`Retry::retry`] but for asynchronous operations. pub async fn retry_async(self, mut f: F) -> Result where F: FnMut(RetryState) -> U, @@ -396,6 +430,39 @@ enum RetryLimit { mod tests { use super::*; + #[test] + fn test_retry_success() { + let mut states = vec![]; + let res = Retry::default() + .initial_backoff(Duration::from_millis(1)) + .retry(|state| { + states.push(state); + if state.i == 2 { + Ok(()) + } else { + Err::<(), _>("injected") + } + }); + assert_eq!(res, Ok(())); + assert_eq!( + states, + &[ + RetryState { + i: 0, + next_backoff: Some(Duration::from_millis(1)) + }, + RetryState { + i: 1, + next_backoff: Some(Duration::from_millis(2)) + }, + RetryState { + i: 2, + next_backoff: Some(Duration::from_millis(4)) + }, + ] + ); + } + #[tokio::test] async fn test_retry_async_success() { let mut states = vec![]; @@ -432,6 +499,36 @@ mod tests { ); } + #[tokio::test] + async fn test_retry_fail_max_tries() { + let mut states = vec![]; + let res = Retry::default() + .initial_backoff(Duration::from_millis(1)) + .max_tries(3) + .retry(|state| { + states.push(state); + Err::<(), _>("injected") + }); + assert_eq!(res, Err("injected")); + assert_eq!( + states, + &[ + RetryState { + i: 0, + next_backoff: Some(Duration::from_millis(1)) + }, + RetryState { + i: 1, + next_backoff: Some(Duration::from_millis(2)) + }, + RetryState { + i: 2, + next_backoff: None + }, + ] + ); + } + #[tokio::test] async fn test_retry_async_fail_max_tries() { let mut states = vec![]; @@ -463,6 +560,45 @@ mod tests { ); } + #[test] + fn test_retry_fail_max_duration() { + let mut states = vec![]; + let res = Retry::default() + .initial_backoff(Duration::from_millis(5)) + .max_duration(Duration::from_millis(10)) + .retry(|state| { + states.push(state); + Err::<(), _>("injected") + }); + assert_eq!(res, Err("injected")); + + // The first try should indicate a next backoff of exactly 5ms. + assert_eq!( + states[0], + RetryState { + i: 0, + next_backoff: Some(Duration::from_millis(5)) + }, + ); + + // The next try should indicate a next backoff of between 0 and 5ms. The + // exact value depends on how long it took for the first try itself to + // execute. + assert_eq!(states[1].i, 1); + let backoff = states[1].next_backoff.unwrap(); + assert!(backoff > Duration::from_millis(0) && backoff < Duration::from_millis(5)); + + // The final try should indicate that the operation is complete with + // a next backoff of None. 
+ assert_eq!(
+ states[2],
+ RetryState {
+ i: 2,
+ next_backoff: None,
+ },
+ );
+ }
+
 #[tokio::test]
 async fn test_retry_async_fail_max_duration() {
 let mut states = vec![];
@@ -503,13 +639,41 @@ mod tests {
 );
 }

+ #[test]
+ fn test_retry_fail_clamp_backoff() {
+ let mut states = vec![];
+ let res = Retry::default()
+ .initial_backoff(Duration::from_millis(1))
+ .clamp_backoff(Duration::from_millis(1))
+ .max_tries(4)
+ .retry(|state| {
+ states.push(state);
+ Err::<(), _>("injected")
+ });
+ assert_eq!(res, Err("injected"));
+ assert_eq!(
+ states,
+ &[
+ RetryState {
+ i: 0,
+ next_backoff: Some(Duration::from_millis(1))
+ },
+ RetryState {
+ i: 1,
+ next_backoff: Some(Duration::from_millis(1))
+ },
+ RetryState {
+ i: 2,
+ next_backoff: Some(Duration::from_millis(1))
+ },
+ RetryState {
+ i: 3,
+ next_backoff: None
+ },
+ ]
+ );
+ }
+
 #[tokio::test]
 async fn test_retry_async_fail_clamp_backoff() {
 let mut states = vec![];

From d65340a431ce3c03abc45c072fc09efff5d69617 Mon Sep 17 00:00:00 2001
From: Nikhil Benesch
Date: Sat, 1 Jan 2022 16:31:06 -0500
Subject: [PATCH 4/7] ore/retry: use newly-stabilized Duration::MAX constant

The duration_constants feature actually isn't stable yet, but Duration::MAX
got stabilized separately.

h/t @umanwizard
---
 src/ore/src/retry.rs | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/src/ore/src/retry.rs b/src/ore/src/retry.rs
index a8d34ca13423f..da01e07fdb668 100644
--- a/src/ore/src/retry.rs
+++ b/src/ore/src/retry.rs
@@ -64,10 +64,6 @@ use pin_project::pin_project;
 use tokio::io::{AsyncRead, ReadBuf};
 use tokio::time::{self, Duration, Instant, Sleep};

-// TODO(benesch): remove this if the `duration_constants` feature stabilizes.
-// See: https://github.com/rust-lang/rust/issues/57391
-const MAX_DURATION: Duration = Duration::from_secs(u64::MAX);
-
 /// The state of a retry operation constructed with [`Retry`].
 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
 pub struct RetryState {
@@ -243,7 +239,7 @@ impl Default for Retry {
 Retry {
 initial_backoff: Duration::from_millis(125),
 factor: 2.0,
- clamp_backoff: MAX_DURATION,
+ clamp_backoff: Duration::MAX,
 limit: RetryLimit::Duration(Duration::from_secs(30)),
 }
 }

From f6a6c66e3689fd80d3180028623c3505b859eeff Mon Sep 17 00:00:00 2001
From: Nikhil Benesch
Date: Sat, 1 Jan 2022 13:17:38 -0500
Subject: [PATCH 5/7] kgen: use exponential backoff to retry when queue is full

This should improve throughput when the rdkafka producer queue fills up.
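A minimal sketch of the new pattern, where `try_send` stands in for the
queue-full-aware call to `producer.send`:

    use std::time::Duration;
    use ore::retry::Retry;

    Retry::default()
        // Sleep 125ms, 250ms, 500ms, then 1s between subsequent tries.
        .clamp_backoff(Duration::from_secs(1))
        .retry(|_| try_send())?;

Note that `Retry::default()` also imposes a 30-second maximum duration, so
unlike the old loop, which slept a flat second and retried forever, a
persistently full queue now surfaces an error after roughly 30 seconds.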
--- src/kafka-util/src/bin/kgen.rs | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/kafka-util/src/bin/kgen.rs b/src/kafka-util/src/bin/kgen.rs index 637ef82240059..fdfc27674ba20 100644 --- a/src/kafka-util/src/bin/kgen.rs +++ b/src/kafka-util/src/bin/kgen.rs @@ -13,7 +13,6 @@ use std::convert::{TryFrom, TryInto}; use std::iter; use std::ops::Add; use std::rc::Rc; -use std::thread; use std::time::Duration; use anyhow::bail; @@ -37,6 +36,7 @@ use mz_avro::schema::{SchemaNode, SchemaPiece, SchemaPieceOrNamed}; use mz_avro::types::{DecimalValue, Value}; use mz_avro::Schema; use ore::cast::CastFrom; +use ore::retry::Retry; struct RandomAvroGenerator<'a> { // generators @@ -706,20 +706,20 @@ async fn main() -> anyhow::Result<()> { if args.partitions_round_robin != 0 { rec = rec.partition((i % args.partitions_round_robin) as i32); } + let mut rec = Some(rec); - loop { - match producer.send(rec) { - Ok(()) => break, - Err((KafkaError::MessageProduction(RDKafkaErrorCode::QueueFull), rec2)) => { - rec = rec2; - thread::sleep(Duration::from_secs(1)); + Retry::default() + .clamp_backoff(Duration::from_secs(1)) + .retry(|_| match producer.send(rec.take().unwrap()) { + Ok(()) => Ok(()), + Err((e @ KafkaError::MessageProduction(RDKafkaErrorCode::QueueFull), r)) => { + rec = Some(r); + Err(e) } - Err((e, _)) => { - return Err(e.into()); - } - } - } + Err((e, _)) => Err(e.into()), + })?; } + producer.flush(Timeout::Never); Ok(()) } From 1e444ab14c288c3b3db5c597136e3aa0ec78ffbc Mon Sep 17 00:00:00 2001 From: Nikhil Benesch Date: Sat, 1 Jan 2022 14:22:26 -0500 Subject: [PATCH 6/7] kgen: parallelize Teach kgen to optionally spawn multiple threads, defaulting to the number of physical CPUs available on the machine. Thread safety made this surprisingly irritating. This commit refactors the Avro generator so that the ThreadRng is only ever passed as a parameter, never stored, as otherwise the Avro generator does not implement `Send`. It also introduces a rather goofy `Generator` trait whose only purpose is to make it possible to clone the generator closures. 
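The trick, stripped to its essence, is the usual "clone box" pattern; this
mirrors the trait added in the diff below:

    trait Generator<R>: FnMut(&mut ThreadRng) -> R + Send + Sync {
        fn clone_box(&self) -> Box<dyn Generator<R>>;
    }

    // Every suitable closure gets a clone_box implementation for free...
    impl<F, R> Generator<R> for F
    where
        F: FnMut(&mut ThreadRng) -> R + Clone + Send + Sync + 'static,
    {
        fn clone_box(&self) -> Box<dyn Generator<R>> {
            Box::new(self.clone())
        }
    }

    // ...which in turn makes the boxed trait object itself cloneable.
    impl<R: 'static> Clone for Box<dyn Generator<R>> {
        fn clone(&self) -> Box<dyn Generator<R>> {
            (**self).clone_box()
        }
    }

Passing the `ThreadRng` as an argument rather than storing it in the
generator is what keeps these closures `Send`.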
--- Cargo.lock | 37 +++- src/kafka-util/Cargo.toml | 2 + src/kafka-util/src/bin/kgen.rs | 316 ++++++++++++++++++--------------- 3 files changed, 201 insertions(+), 154 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3c001265a3c1e..71e10242f8d4d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -873,12 +873,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "const_fn" -version = "0.4.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28b9d6de7f49e22cf97ad17fc4036ece69300032f45f78f30b4a4482cdc3f4a6" - [[package]] name = "coord" version = "0.0.0" @@ -1038,6 +1032,20 @@ dependencies = [ "itertools", ] +[[package]] +name = "crossbeam" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ae5588f6b3c3cb05239e90bd110f257254aecd01e4635400391aeae07497845" +dependencies = [ + "cfg-if", + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + [[package]] name = "crossbeam-channel" version = "0.5.1" @@ -1061,18 +1069,27 @@ dependencies = [ [[package]] name = "crossbeam-epoch" -version = "0.9.1" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1aaa739f95311c2c7887a76863f500026092fb1dce0161dab577e559ef3569d" +checksum = "4ec02e091aa634e2c3ada4a392989e7c3116673ef0ac5b72232439094d73b7fd" dependencies = [ "cfg-if", - "const_fn", "crossbeam-utils", "lazy_static", "memoffset", "scopeguard", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b10ddc024425c88c2ad148c1b0fd53f4c6d38db9697c9f1588381212fa657c9" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.5" @@ -2417,8 +2434,10 @@ dependencies = [ "ccsr", "chrono", "clap", + "crossbeam", "futures", "mz-avro", + "num_cpus", "ore", "rand", "rdkafka", diff --git a/src/kafka-util/Cargo.toml b/src/kafka-util/Cargo.toml index e00787c208968..c4effd03b0841 100644 --- a/src/kafka-util/Cargo.toml +++ b/src/kafka-util/Cargo.toml @@ -10,8 +10,10 @@ anyhow = "1.0.52" ccsr = { path = "../ccsr" } chrono = { version = "0.4.0", default-features = false, features = ["std"] } clap = "2.34.0" +crossbeam = "0.8.1" futures = "0.3.19" mz-avro = { path = "../avro" } +num_cpus = "1.13.1" ore = { path = "../ore", features = ["network"] } rand = "0.8.4" rdkafka = { git = "https://github.com/fede1024/rust-rdkafka.git", features = ["cmake-build", "libz-static"] } diff --git a/src/kafka-util/src/bin/kgen.rs b/src/kafka-util/src/bin/kgen.rs index fdfc27674ba20..07b26e451856f 100644 --- a/src/kafka-util/src/bin/kgen.rs +++ b/src/kafka-util/src/bin/kgen.rs @@ -7,17 +7,17 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. 
-use std::cell::RefCell; use std::collections::HashMap; use std::convert::{TryFrom, TryInto}; use std::iter; use std::ops::Add; -use std::rc::Rc; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; use anyhow::bail; use chrono::{NaiveDate, NaiveDateTime}; use clap::arg_enum; +use crossbeam::thread; use rand::distributions::{ uniform::SampleUniform, Alphanumeric, Bernoulli, Uniform, WeightedIndex, }; @@ -38,56 +38,80 @@ use mz_avro::Schema; use ore::cast::CastFrom; use ore::retry::Retry; +trait Generator: FnMut(&mut ThreadRng) -> R + Send + Sync { + fn clone_box(&self) -> Box>; +} + +impl Generator for F +where + F: FnMut(&mut ThreadRng) -> R + Clone + Send + Sync + 'static, +{ + fn clone_box(&self) -> Box> { + Box::new(self.clone()) + } +} + +impl Clone for Box> +where + R: 'static, +{ + fn clone(&self) -> Box> { + (**self).clone_box() + } +} + +#[derive(Clone)] struct RandomAvroGenerator<'a> { - // generators - ints: HashMap<*const SchemaPiece, Box i32>>, - longs: HashMap<*const SchemaPiece, Box i64>>, - strings: HashMap<*const SchemaPiece, Box)>>, - bytes: HashMap<*const SchemaPiece, Box)>>, - unions: HashMap<*const SchemaPiece, Box usize>>, - enums: HashMap<*const SchemaPiece, Box usize>>, - bools: HashMap<*const SchemaPiece, Box bool>>, - floats: HashMap<*const SchemaPiece, Box f32>>, - doubles: HashMap<*const SchemaPiece, Box f64>>, - decimals: HashMap<*const SchemaPiece, Box Vec>>, - array_lens: HashMap<*const SchemaPiece, Box usize>>, - _map_keys: HashMap<*const SchemaPiece, Box)>>, + // Generator functions for each piece of the schema. These map keys are + // morally `*const SchemaPiece`s, but represented as `usize`s so that they + // implement `Send`. + ints: HashMap>>, + longs: HashMap>>, + strings: HashMap>>>, + bytes: HashMap>>>, + unions: HashMap>>, + enums: HashMap>>, + bools: HashMap>>, + floats: HashMap>>, + doubles: HashMap>>, + decimals: HashMap>>>, + array_lens: HashMap>>, schema: SchemaNode<'a>, } impl<'a> RandomAvroGenerator<'a> { - fn gen_inner(&mut self, node: SchemaNode) -> Value { - let p: *const _ = &*node.inner; + fn gen_inner(&mut self, node: SchemaNode, rng: &mut ThreadRng) -> Value { + let p = &*node.inner as *const _ as usize; match node.inner { SchemaPiece::Null => Value::Null, SchemaPiece::Boolean => { - let val = self.bools.get_mut(&p).unwrap()(); + let val = self.bools.get_mut(&p).unwrap()(rng); Value::Boolean(val) } SchemaPiece::Int => { - let val = self.ints.get_mut(&p).unwrap()(); + let val = self.ints.get_mut(&p).unwrap()(rng); Value::Int(val) } SchemaPiece::Long => { - let val = self.longs.get_mut(&p).unwrap()(); + let val = self.longs.get_mut(&p).unwrap()(rng); Value::Long(val) } SchemaPiece::Float => { - let val = self.floats.get_mut(&p).unwrap()(); + let val = self.floats.get_mut(&p).unwrap()(rng); Value::Float(val) } SchemaPiece::Double => { - let val = self.doubles.get_mut(&p).unwrap()(); + let val = self.doubles.get_mut(&p).unwrap()(rng); Value::Double(val) } SchemaPiece::Date => { - let days = self.ints.get_mut(&p).unwrap()(); + let days = self.ints.get_mut(&p).unwrap()(rng); let val = NaiveDate::from_ymd(1970, 1, 1).add(chrono::Duration::days(days as i64)); Value::Date(val) } SchemaPiece::TimestampMilli => { - let millis = self.longs.get_mut(&p).unwrap()(); + let millis = self.longs.get_mut(&p).unwrap()(rng); let seconds = millis / 1000; let fraction = (millis % 1000) as u32; @@ -95,7 +119,7 @@ impl<'a> RandomAvroGenerator<'a> { Value::Timestamp(val) } SchemaPiece::TimestampMicro => { - let micros = 
self.longs.get_mut(&p).unwrap()(); + let micros = self.longs.get_mut(&p).unwrap()(rng); let seconds = micros / 1_000_000; let fraction = (micros % 1_000_000) as u32; @@ -107,8 +131,7 @@ impl<'a> RandomAvroGenerator<'a> { scale, fixed_size: _, } => { - let f = self.decimals.get_mut(&p).unwrap(); - let unscaled = f(); + let unscaled = self.decimals.get_mut(&p).unwrap()(rng); Value::Decimal(DecimalValue { unscaled, precision: *precision, @@ -116,24 +139,20 @@ impl<'a> RandomAvroGenerator<'a> { }) } SchemaPiece::Bytes => { - let f = self.bytes.get_mut(&p).unwrap(); - let mut val = vec![]; - f(&mut val); + let val = self.bytes.get_mut(&p).unwrap()(rng); Value::Bytes(val) } SchemaPiece::String => { - let f = self.strings.get_mut(&p).unwrap(); - let mut buf = vec![]; - f(&mut buf); + let buf = self.strings.get_mut(&p).unwrap()(rng); let val = String::from_utf8(buf).unwrap(); Value::String(val) } SchemaPiece::Json => unreachable!(), SchemaPiece::Uuid => unreachable!(), SchemaPiece::Array(inner) => { - let len = self.array_lens.get_mut(&p).unwrap()(); + let len = self.array_lens.get_mut(&p).unwrap()(rng); let next = node.step(&**inner); - let inner_vals = (0..len).map(move |_| self.gen_inner(next)).collect(); + let inner_vals = (0..len).map(move |_| self.gen_inner(next, rng)).collect(); Value::Array(inner_vals) } SchemaPiece::Map(_inner) => { @@ -153,13 +172,13 @@ impl<'a> RandomAvroGenerator<'a> { unreachable!() } SchemaPiece::Union(us) => { - let index = self.unions.get_mut(&p).unwrap()(); + let index = self.unions.get_mut(&p).unwrap()(rng); let next = node.step(&us.variants()[index]); let null_variant = us .variants() .iter() .position(|v| v == &SchemaPieceOrNamed::Piece(SchemaPiece::Null)); - let inner = Box::new(self.gen_inner(next)); + let inner = Box::new(self.gen_inner(next, rng)); Value::Union { index, inner, @@ -189,21 +208,21 @@ impl<'a> RandomAvroGenerator<'a> { .map(|f| { let k = f.name.clone(); let next = node.step(&f.schema); - let v = self.gen_inner(next); + let v = self.gen_inner(next, rng); (k, v) }) .collect(); Value::Record(fields) } SchemaPiece::Enum { symbols, .. 
} => { - let i = self.enums.get_mut(&p).unwrap()(); + let i = self.enums.get_mut(&p).unwrap()(rng); Value::Enum(i, symbols[i].clone()) } SchemaPiece::Fixed { size: _ } => unreachable!(), } } - pub fn gen(&mut self) -> Value { - self.gen_inner(self.schema) + pub fn gen(&mut self, rng: &mut ThreadRng) -> Value { + self.gen_inner(self.schema, rng) } fn new_inner( &mut self, @@ -211,21 +230,15 @@ impl<'a> RandomAvroGenerator<'a> { annotations: &Map, field_name: Option<&str>, ) { - let rng = Rc::new(RefCell::new(thread_rng())); - fn bool_dist( - json: &serde_json::Value, - rng: Rc>, - ) -> impl FnMut() -> bool { + fn bool_dist(json: &serde_json::Value) -> impl FnMut(&mut ThreadRng) -> bool + Clone { let x = json.as_f64().unwrap(); let dist = Bernoulli::new(x).unwrap(); - move || dist.sample(&mut *rng.borrow_mut()) + move |rng| dist.sample(rng) } - fn integral_dist( - json: &serde_json::Value, - rng: Rc>, - ) -> impl FnMut() -> T + fn integral_dist(json: &serde_json::Value) -> impl FnMut(&mut ThreadRng) -> T + Clone where - T: SampleUniform + TryFrom, + T: SampleUniform + TryFrom + Clone, + T::Sampler: Clone, >::Error: std::fmt::Debug, { let x = json.as_array().unwrap(); @@ -234,57 +247,42 @@ impl<'a> RandomAvroGenerator<'a> { x[1].as_i64().unwrap().try_into().unwrap(), ); let dist = Uniform::new_inclusive(min, max); - move || dist.sample(&mut *rng.borrow_mut()) + move |rng| dist.sample(rng) } - fn float_dist( - json: &serde_json::Value, - rng: Rc>, - ) -> impl FnMut() -> f32 { + fn float_dist(json: &serde_json::Value) -> impl FnMut(&mut ThreadRng) -> f32 + Clone { let x = json.as_array().unwrap(); let (min, max) = (x[0].as_f64().unwrap() as f32, x[1].as_f64().unwrap() as f32); let dist = Uniform::new_inclusive(min, max); - move || dist.sample(&mut *rng.borrow_mut()) + move |rng| dist.sample(rng) } - fn double_dist( - json: &serde_json::Value, - rng: Rc>, - ) -> impl FnMut() -> f64 { + fn double_dist(json: &serde_json::Value) -> impl FnMut(&mut ThreadRng) -> f64 + Clone { let x = json.as_array().unwrap(); let (min, max) = (x[0].as_f64().unwrap(), x[1].as_f64().unwrap()); let dist = Uniform::new_inclusive(min, max); - move || dist.sample(&mut *rng.borrow_mut()) + move |rng| dist.sample(rng) } - fn string_dist( - json: &serde_json::Value, - rng: Rc>, - ) -> impl FnMut(&mut Vec) { - let mut len = integral_dist::(json, rng.clone()); - move |v| { - let len = len(); + fn string_dist(json: &serde_json::Value) -> impl FnMut(&mut ThreadRng) -> Vec + Clone { + let mut len = integral_dist::(json); + move |rng| { + let len = len(rng); let cd = Alphanumeric; - let sample = || cd.sample(&mut *rng.borrow_mut()) as u8; - v.clear(); - v.extend(iter::repeat_with(sample).take(len)); + iter::repeat_with(|| cd.sample(rng) as u8) + .take(len) + .collect() } } - fn bytes_dist( - json: &serde_json::Value, - rng: Rc>, - ) -> impl FnMut(&mut Vec) { - let mut len = integral_dist::(json, rng.clone()); - move |v| { - let len = len(); + fn bytes_dist(json: &serde_json::Value) -> impl FnMut(&mut ThreadRng) -> Vec + Clone { + let mut len = integral_dist::(json); + move |rng| { + let len = len(rng); let bd = Uniform::new_inclusive(0, 255); - let sample = || bd.sample(&mut *rng.borrow_mut()); - v.clear(); - v.extend(iter::repeat_with(sample).take(len)); + iter::repeat_with(|| bd.sample(rng)).take(len).collect() } } fn decimal_dist( json: &serde_json::Value, - rng: Rc>, precision: usize, - ) -> impl FnMut() -> Vec { + ) -> impl FnMut(&mut ThreadRng) -> Vec + Clone { let x = json.as_array().unwrap(); let (min, max): (i64, i64) = 
(x[0].as_i64().unwrap(), x[1].as_i64().unwrap()); // Ensure values fit within precision bounds. @@ -304,9 +302,9 @@ impl<'a> RandomAvroGenerator<'a> { precision ); let dist = Uniform::::new_inclusive(min, max); - move || dist.sample(&mut *rng.borrow_mut()).to_be_bytes().to_vec() + move |rng| dist.sample(rng).to_be_bytes().to_vec() } - let p: *const _ = &*node.inner; + let p = &*node.inner as *const _ as usize; let dist_json = field_name.and_then(|fn_| annotations.get(fn_)); let err = format!( @@ -316,23 +314,23 @@ impl<'a> RandomAvroGenerator<'a> { match node.inner { SchemaPiece::Null => {} SchemaPiece::Boolean => { - let dist = bool_dist(dist_json.expect(&err), rng); + let dist = bool_dist(dist_json.expect(&err)); self.bools.insert(p, Box::new(dist)); } SchemaPiece::Int => { - let dist = integral_dist(dist_json.expect(&err), rng); + let dist = integral_dist(dist_json.expect(&err)); self.ints.insert(p, Box::new(dist)); } SchemaPiece::Long => { - let dist = integral_dist(dist_json.expect(&err), rng); + let dist = integral_dist(dist_json.expect(&err)); self.longs.insert(p, Box::new(dist)); } SchemaPiece::Float => { - let dist = float_dist(dist_json.expect(&err), rng); + let dist = float_dist(dist_json.expect(&err)); self.floats.insert(p, Box::new(dist)); } SchemaPiece::Double => { - let dist = double_dist(dist_json.expect(&err), rng); + let dist = double_dist(dist_json.expect(&err)); self.doubles.insert(p, Box::new(dist)); } SchemaPiece::Date => {} @@ -343,21 +341,21 @@ impl<'a> RandomAvroGenerator<'a> { scale: _, fixed_size: _, } => { - let dist = decimal_dist(dist_json.expect(&err), rng, *precision); + let dist = decimal_dist(dist_json.expect(&err), *precision); self.decimals.insert(p, Box::new(dist)); } SchemaPiece::Bytes => { let len_dist_json = annotations .get(&format!("{}.len", field_name.unwrap())) .unwrap(); - let dist = bytes_dist(len_dist_json, rng); + let dist = bytes_dist(len_dist_json); self.bytes.insert(p, Box::new(dist)); } SchemaPiece::String => { let len_dist_json = annotations .get(&format!("{}.len", field_name.unwrap())) .unwrap(); - let dist = string_dist(len_dist_json, rng); + let dist = string_dist(len_dist_json); self.strings.insert(p, Box::new(dist)); } SchemaPiece::Json => unimplemented!(), @@ -365,7 +363,7 @@ impl<'a> RandomAvroGenerator<'a> { SchemaPiece::Array(inner) => { let fn_ = field_name.unwrap(); let len_dist_json = annotations.get(&format!("{}.len", fn_)).unwrap(); - let len = integral_dist::(len_dist_json, rng); + let len = integral_dist::(len_dist_json); self.array_lens.insert(p, Box::new(len)); let item_fn = format!("{}[]", fn_); self.new_inner(node.step(&**inner), annotations, Some(&item_fn)) @@ -376,8 +374,7 @@ impl<'a> RandomAvroGenerator<'a> { assert!(variant_jsons.len() == us.variants().len()); let probabilities = variant_jsons.iter().map(|v| v.as_f64().unwrap()); let dist = WeightedIndex::new(probabilities).unwrap(); - let rng = rng; - let f = move || dist.sample(&mut *rng.borrow_mut()); + let f = move |rng: &mut ThreadRng| dist.sample(rng); self.unions.insert(p, Box::new(f)); let fn_ = field_name.unwrap(); for (i, v) in us.variants().iter().enumerate() { @@ -431,7 +428,6 @@ impl<'a> RandomAvroGenerator<'a> { doubles: Default::default(), decimals: Default::default(), array_lens: Default::default(), - _map_keys: Default::default(), schema: schema.top_node(), }; self_.new_inner(schema.top_node(), annotations.as_object().unwrap(), None); @@ -439,11 +435,11 @@ impl<'a> RandomAvroGenerator<'a> { } } +#[derive(Clone)] enum ValueGenerator<'a> { 
UniformBytes { len: Uniform, bytes: Uniform, - rng: ThreadRng, }, RandomAvro { inner: RandomAvroGenerator<'a>, @@ -453,9 +449,9 @@ enum ValueGenerator<'a> { } impl<'a> ValueGenerator<'a> { - pub fn next_value(&mut self, out: &mut Vec) { + pub fn next_value(&mut self, out: &mut Vec, rng: &mut ThreadRng) { match self { - ValueGenerator::UniformBytes { len, bytes, rng } => { + ValueGenerator::UniformBytes { len, bytes } => { let len = len.sample(rng); let sample = || bytes.sample(rng); out.clear(); @@ -466,7 +462,7 @@ impl<'a> ValueGenerator<'a> { schema, schema_id, } => { - let value = inner.gen(); + let value = inner.gen(rng); out.clear(); out.push(0); for b in schema_id.to_be_bytes().iter() { @@ -520,6 +516,11 @@ struct Args { /// instead. #[structopt(long, default_value = "0")] partitions_round_robin: usize, + /// The number of threads to use. + /// + /// If zero, uses the number of physical CPUs on the machine. + #[structopt(long, default_value = "0")] + threads: usize, // == Key arguments. == /// Format in which to generate keys. @@ -583,7 +584,7 @@ struct Args { async fn main() -> anyhow::Result<()> { let args: Args = ore::cli::parse_args(); - let mut value_gen = match args.value_format { + let value_gen = match args.value_format { ValueFormat::Bytes => { // Clap may one day be able to do this validation automatically. // See: https://github.com/clap-rs/clap/discussions/2039 @@ -596,9 +597,8 @@ async fn main() -> anyhow::Result<()> { let len = Uniform::new_inclusive(args.min_value_size.unwrap(), args.max_value_size.unwrap()); let bytes = Uniform::new_inclusive(0, 255); - let rng = thread_rng(); - ValueGenerator::UniformBytes { len, bytes, rng } + ValueGenerator::UniformBytes { len, bytes } } ValueFormat::Avro => { // Clap may one day be able to do this validation automatically. @@ -629,7 +629,7 @@ async fn main() -> anyhow::Result<()> { } }; - let mut key_gen = match args.key_format { + let key_gen = match args.key_format { KeyFormat::Avro => { // Clap may one day be able to do this validation automatically. 
// See: https://github.com/clap-rs/clap/discussions/2039 @@ -678,48 +678,74 @@ async fn main() -> anyhow::Result<()> { None }; - let producer: ThreadedProducer = ClientConfig::new() - .set("bootstrap.servers", args.bootstrap_server.to_string()) - .create()?; - let mut key_buf = vec![]; - let mut value_buf = vec![]; - let mut rng = thread_rng(); - for i in 0..args.num_records { - if !args.quiet && i % 10000 == 0 { - eprintln!("Generating message {}", i); - } + let threads = if args.threads == 0 { + num_cpus::get_physical() + } else { + args.threads + }; + println!("Using {} threads...", threads); - value_gen.next_value(&mut value_buf); - if let Some(key_gen) = key_gen.as_mut() { - key_gen.next_value(&mut key_buf); - } else if let Some(key_dist) = key_dist.as_ref() { - key_buf.clear(); - key_buf.extend(key_dist.sample(&mut rng).to_be_bytes().iter()) - } else { - key_buf.clear(); - key_buf.extend(u64::cast_from(i).to_be_bytes().iter()) - }; + let counter = AtomicUsize::new(0); + thread::scope(|scope| { + for thread in 0..threads { + let counter = &counter; + let topic = &args.topic; + let mut key_gen = key_gen.clone(); + let mut value_gen = value_gen.clone(); + let producer: ThreadedProducer = ClientConfig::new() + .set("bootstrap.servers", args.bootstrap_server.to_string()) + .create() + .unwrap(); + let mut key_buf = vec![]; + let mut value_buf = vec![]; + let mut n = args.num_records / threads; + if thread < args.num_records % threads { + n += 1; + } + scope.spawn(move |_| { + let mut rng = thread_rng(); + for _ in 0..n { + let i = counter.fetch_add(1, Ordering::Relaxed); + if !args.quiet && i % 100_000 == 0 { + eprintln!("Generating message {}", i); + } + value_gen.next_value(&mut value_buf, &mut rng); + if let Some(key_gen) = key_gen.as_mut() { + key_gen.next_value(&mut key_buf, &mut rng); + } else if let Some(key_dist) = key_dist.as_ref() { + key_buf.clear(); + key_buf.extend(key_dist.sample(&mut rng).to_be_bytes().iter()) + } else { + key_buf.clear(); + key_buf.extend(u64::cast_from(i).to_be_bytes().iter()) + }; - let mut rec = BaseRecord::to(&args.topic) - .key(&key_buf) - .payload(&value_buf); - if args.partitions_round_robin != 0 { - rec = rec.partition((i % args.partitions_round_robin) as i32); - } - let mut rec = Some(rec); + let mut rec = BaseRecord::to(&topic).key(&key_buf).payload(&value_buf); + if args.partitions_round_robin != 0 { + rec = rec.partition((i % args.partitions_round_robin) as i32); + } + let mut rec = Some(rec); - Retry::default() - .clamp_backoff(Duration::from_secs(1)) - .retry(|_| match producer.send(rec.take().unwrap()) { - Ok(()) => Ok(()), - Err((e @ KafkaError::MessageProduction(RDKafkaErrorCode::QueueFull), r)) => { - rec = Some(r); - Err(e) + Retry::default() + .clamp_backoff(Duration::from_secs(1)) + .retry(|_| match producer.send(rec.take().unwrap()) { + Ok(()) => Ok(()), + Err(( + e @ KafkaError::MessageProduction(RDKafkaErrorCode::QueueFull), + r, + )) => { + rec = Some(r); + Err(e) + } + Err((e, _)) => panic!("unexpected Kafka error: {}", e), + }) + .expect("unable to produce to Kafka"); } - Err((e, _)) => Err(e.into()), - })?; - } + producer.flush(Timeout::Never); + }); + } + }) + .unwrap(); - producer.flush(Timeout::Never); Ok(()) } From fcb0f7db6a8740b41feacbbc8a5a552375e4cd32 Mon Sep 17 00:00:00 2001 From: Nikhil Benesch Date: Fri, 31 Dec 2021 21:34:46 -0500 Subject: [PATCH 7/7] benches/avro_ingest: use kgen directly rather than kafka-avro-generator kafka-avro-generator is going away soon. Use kgen directly instead. 
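For reference, the kgen invocation assembled in the Python below corresponds
roughly to the following command line, with `$HOST` standing in for
`args.confluent_host`:

    kgen --num-records=$RECORDS \
        --bootstrap-server=$HOST:9092 \
        --schema-registry-url=http://$HOST:8081 \
        --topic=bench_data \
        --keys=avro --values=avro \
        --avro-schema="$VALUE_SCHEMA" \
        --avro-distribution="$VALUE_DISTRIBUTION" \
        --avro-key-schema="$KEY_SCHEMA" \
        --avro-key-distribution="$KEY_DISTRIBUTION"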
--- .../python/materialize/benches/avro_ingest.py | 139 +++++++++++++++--- 1 file changed, 120 insertions(+), 19 deletions(-) diff --git a/misc/python/materialize/benches/avro_ingest.py b/misc/python/materialize/benches/avro_ingest.py index fac1eef3fca92..2c9edda3e2854 100644 --- a/misc/python/materialize/benches/avro_ingest.py +++ b/misc/python/materialize/benches/avro_ingest.py @@ -10,6 +10,7 @@ """Ingest some Avro records, and report how long it takes""" import argparse +import json import os import time from typing import IO, NamedTuple @@ -87,13 +88,6 @@ def main() -> None: type=int, help="Number of Avro records to generate", ) - parser.add_argument( - "-d", - "--distribution", - default="benchmark", - type=str, - help="Distribution to use in kafka-avro-generator", - ) args = parser.parse_args() os.chdir(ROOT) @@ -101,7 +95,7 @@ def main() -> None: wait_for_confluent(args.confluent_host) - images = ["kafka-avro-generator", "materialized"] + images = ["kgen", "materialized"] deps = repo.resolve_dependencies([repo.images[name] for name in images]) deps.acquire() @@ -114,18 +108,18 @@ def main() -> None: ) docker_client.containers.run( - deps["kafka-avro-generator"].spec(), + deps["kgen"].spec(), [ - "-n", - str(args.records), - "-b", - f"{args.confluent_host}:9092", - "-r", - f"http://{args.confluent_host}:8081", - "-t", - "bench_data", - "-d", - args.distribution, + f"--num-records={args.records}", + f"--bootstrap-server={args.confluent_host}:9092", + f"--schema-registry-url=http://{args.confluent_host}:8081", + "--topic=bench_data", + "--keys=avro", + "--values=avro", + f"--avro-schema={VALUE_SCHEMA}", + f"--avro-distribution={VALUE_DISTRIBUTION}", + f"--avro-key-schema={KEY_SCHEMA}", + f"--avro-key-distribution={KEY_DISTRIBUTION}", ], network_mode="host", ) @@ -158,5 +152,112 @@ def main() -> None: prev = print_stats(mz_container, prev, results_file) +KEY_SCHEMA = json.dumps( + { + "name": "testrecordkey", + "type": "record", + "namespace": "com.acme.avro", + "fields": [{"name": "Key1", "type": "long"}, {"name": "Key2", "type": "long"}], + } +) + +KEY_DISTRIBUTION = json.dumps( + { + "com.acme.avro.testrecordkey::Key1": [0, 100], + "com.acme.avro.testrecordkey::Key2": [0, 250000], + } +) + +VALUE_SCHEMA = json.dumps( + { + "name": "testrecord", + "type": "record", + "namespace": "com.acme.avro", + "fields": [ + {"name": "Key1Unused", "type": "long"}, + {"name": "Key2Unused", "type": "long"}, + { + "name": "OuterRecord", + "type": { + "name": "OuterRecord", + "type": "record", + "fields": [ + { + "name": "Record1", + "type": { + "name": "Record1", + "type": "record", + "fields": [ + { + "name": "InnerRecord1", + "type": { + "name": "InnerRecord1", + "type": "record", + "fields": [ + {"name": "Point", "type": "long"} + ], + }, + }, + { + "name": "InnerRecord2", + "type": { + "name": "InnerRecord2", + "type": "record", + "fields": [ + {"name": "Point", "type": "long"} + ], + }, + }, + ], + }, + }, + { + "name": "Record2", + "type": { + "name": "Record2", + "type": "record", + "fields": [ + { + "name": "InnerRecord3", + "type": { + "name": "InnerRecord3", + "type": "record", + "fields": [ + {"name": "Point", "type": "long"} + ], + }, + }, + { + "name": "InnerRecord4", + "type": { + "name": "InnerRecord4", + "type": "record", + "fields": [ + {"name": "Point", "type": "long"} + ], + }, + }, + ], + }, + }, + ], + }, + }, + ], + } +) + +VALUE_DISTRIBUTION = json.dumps( + { + "com.acme.avro.testrecord::Key1Unused": [0, 100], + "com.acme.avro.testrecord::Key2Unused": [0, 250000], + 
"com.acme.avro.InnerRecord1::Point": [10000, 1000000000], + "com.acme.avro.InnerRecord2::Point": [10000, 1000000000], + "com.acme.avro.InnerRecord3::Point": [10000, 1000000000], + "com.acme.avro.InnerRecord4::Point": [10000, 10000000000], + } +) + + if __name__ == "__main__": main()