diff --git a/.changeset/busy-cloths-search.md b/.changeset/busy-cloths-search.md
new file mode 100644
index 00000000000..7aed840f79e
--- /dev/null
+++ b/.changeset/busy-cloths-search.md
@@ -0,0 +1,56 @@
+---
+'hive-console-sdk-rs': minor
+---
+
+Breaking changes now, to avoid future breaking changes:
+
+Switch to the [Builder](https://rust-unofficial.github.io/patterns/patterns/creational/builder.html) pattern for the `SupergraphFetcher`, `PersistedDocumentsManager` and `UsageAgent` structs.
+
+The `try_new`, `try_new_async` and `try_new_sync` functions are gone; use the `SupergraphFetcherBuilder`, `PersistedDocumentsManagerBuilder` and `UsageAgentBuilder` structs to create instances instead.
+
+Benefits:
+
+- No need to provide all parameters at once when creating an instance; omitted options fall back to their defaults.
+
+Example:
+```rust
+// Before
+let fetcher = SupergraphFetcher::try_new_async(
+    "SOME_ENDPOINT", // endpoint
+    "SOME_KEY",
+    "MyUserAgent/1.0".to_string(),
+    Duration::from_secs(5), // connect_timeout
+    Duration::from_secs(10), // request_timeout
+    false, // accept_invalid_certs
+    3, // retry_count
+)?;
+
+// After
+// No need to provide all parameters at once; default values are used for the rest
+let fetcher = SupergraphFetcherBuilder::new()
+    .endpoint("SOME_ENDPOINT".to_string())
+    .key("SOME_KEY".to_string())
+    .build_async()?;
+```
+
+- Easier to add new configuration options in the future without breaking existing code.
+
+Example:
+
+```rust
+let fetcher = SupergraphFetcher::try_new_async(
+    "SOME_ENDPOINT", // endpoint
+    "SOME_KEY",
+    "MyUserAgent/1.0".to_string(),
+    Duration::from_secs(5), // connect_timeout
+    Duration::from_secs(10), // request_timeout
+    false, // accept_invalid_certs
+    3, // retry_count
+    circuit_breaker_config, // Breaking Change -> new parameter added
+)?;
+
+let fetcher = SupergraphFetcherBuilder::new()
+    .endpoint("SOME_ENDPOINT".to_string())
+    .key("SOME_KEY".to_string())
+    .build_async()?; // No breaking change; circuit_breaker_config can be added later if needed
+```
\ No newline at end of file
diff --git a/.changeset/light-walls-vanish.md b/.changeset/light-walls-vanish.md
new file mode 100644
index 00000000000..7a1f4f89105
--- /dev/null
+++ b/.changeset/light-walls-vanish.md
@@ -0,0 +1,20 @@
+---
+'hive-console-sdk-rs': patch
+---
+
+Circuit Breaker Implementation and Multiple Endpoints Support
+
+This release implements circuit breakers in the Hive Console Rust SDK; you can learn more [here](https://the-guild.dev/graphql/hive/product-updates/2025-12-04-cdn-mirror-and-circuit-breaker).
+
+Breaking Changes:
+
+The `endpoint` configuration now accepts multiple endpoints for `SupergraphFetcherBuilder` and `PersistedDocumentsManager`; register each one with `add_endpoint`.
+
+```diff
+SupergraphFetcherBuilder::default()
+- .endpoint(endpoint)
++ .add_endpoint(endpoint1)
++ .add_endpoint(endpoint2)
+```
+
+This change requires updating the configuration structure to accommodate multiple endpoints.
diff --git a/.changeset/violet-waves-happen.md b/.changeset/violet-waves-happen.md
new file mode 100644
index 00000000000..67aabe639d1
--- /dev/null
+++ b/.changeset/violet-waves-happen.md
@@ -0,0 +1,17 @@
+---
+'hive-apollo-router-plugin': major
+---
+
+- Multiple endpoints support for `HiveRegistry` and `PersistedOperationsPlugin`
+
+Breaking Changes:
+- The `endpoint` field in the configuration has been replaced with `endpoints`, an array of strings. You are not affected if you use environment variables to set the endpoint.
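+
+Expressed as a complete call, the migrated configuration looks roughly like this — a sketch only: apart from `endpoints`, the field names (`key`, `poll_interval`, `accept_invalid_certs`) and their values are assumed from the plugin source and may differ in your version:
+
+```rust
+// Hypothetical setup; adjust endpoints, key and values to your deployment.
+HiveRegistry::new(Some(HiveRegistryConfig {
+    // Endpoints are tried in order until one responds successfully.
+    endpoints: vec![
+        String::from("CDN_ENDPOINT1"),
+        String::from("CDN_ENDPOINT2"),
+    ],
+    key: Some(String::from("CDN_ACCESS_KEY")),
+    poll_interval: None,        // falls back to the default of 10
+    accept_invalid_certs: None,
+}))?;
+```
+
+The minimal migration itself is the one-line change shown below.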
+ +```diff +HiveRegistry::new( + Some( + HiveRegistryConfig { +- endpoint: String::from("CDN_ENDPOINT"), ++ endpoints: vec![String::from("CDN_ENDPOINT1"), String::from("CDN_ENDPOINT2")], + ) +) diff --git a/configs/cargo/Cargo.lock b/configs/cargo/Cargo.lock index 8f84a6aa761..40d56619324 100644 --- a/configs/cargo/Cargo.lock +++ b/configs/cargo/Cargo.lock @@ -2613,15 +2613,20 @@ dependencies = [ "anyhow", "async-trait", "axum-core 0.5.5", + "futures-util", "graphql-parser", "graphql-tools", "md5", "mockito", "moka", + "once_cell", + "recloser", + "regex-automata", "regress", "reqwest", "reqwest-middleware", "reqwest-retry 0.8.0", + "retry-policies 0.5.1", "serde", "serde_json", "sha2", @@ -4623,6 +4628,16 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "recloser" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40ac0d06281c3556fea72cef9e5372d9ac172335be0d71c3b4f3db900483e0eb" +dependencies = [ + "crossbeam-epoch", + "pin-project", +] + [[package]] name = "redis-protocol" version = "6.0.0" diff --git a/packages/libraries/router/src/persisted_documents.rs b/packages/libraries/router/src/persisted_documents.rs index 5391f1a6646..1fa5daaf30c 100644 --- a/packages/libraries/router/src/persisted_documents.rs +++ b/packages/libraries/router/src/persisted_documents.rs @@ -32,7 +32,7 @@ pub static PERSISTED_DOCUMENT_HASH_KEY: &str = "hive::persisted_document_hash"; pub struct Config { pub enabled: Option, /// GraphQL Hive persisted documents CDN endpoint URL. - pub endpoint: Option, + pub endpoint: Option, /// GraphQL Hive persisted documents CDN access token. pub key: Option, /// Whether arbitrary documents should be allowed along-side persisted documents. @@ -57,6 +57,25 @@ pub struct Config { pub cache_size: Option, } +#[derive(Clone, Debug, Deserialize, JsonSchema)] +#[serde(untagged)] +pub enum EndpointConfig { + Single(String), + Multiple(Vec), +} + +impl From<&str> for EndpointConfig { + fn from(value: &str) -> Self { + EndpointConfig::Single(value.into()) + } +} + +impl From<&[&str]> for EndpointConfig { + fn from(value: &[&str]) -> Self { + EndpointConfig::Multiple(value.iter().map(|s| s.to_string()).collect()) + } +} + pub struct PersistedDocumentsPlugin { persisted_documents_manager: Option>, allow_arbitrary_documents: bool, @@ -72,11 +91,14 @@ impl PersistedDocumentsPlugin { allow_arbitrary_documents, }); } - let endpoint = match &config.endpoint { - Some(ep) => ep.clone(), + let endpoints = match &config.endpoint { + Some(ep) => match ep { + EndpointConfig::Single(url) => vec![url.clone()], + EndpointConfig::Multiple(urls) => urls.clone(), + }, None => { if let Ok(ep) = env::var("HIVE_CDN_ENDPOINT") { - ep + vec![ep] } else { return Err( "Endpoint for persisted documents CDN is not configured. Please set it via the plugin configuration or HIVE_CDN_ENDPOINT environment variable." 
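For the `PersistedOperationsPlugin`, the new untagged `EndpointConfig` shown above accepts either a single endpoint or a list of endpoints. A minimal sketch of how the two shapes deserialize — `serde_json` and the example URLs are used purely for illustration; the router normally reads this value from its own configuration file:

```rust
use serde::Deserialize;

// Mirrors the EndpointConfig definition added above.
#[derive(Clone, Debug, Deserialize)]
#[serde(untagged)]
enum EndpointConfig {
    Single(String),
    Multiple(Vec<String>),
}

fn main() {
    // A plain string deserializes into EndpointConfig::Single(...)
    let single: EndpointConfig =
        serde_json::from_str(r#""https://cdn.example.com/artifacts/v1/TARGET_ID""#).unwrap();

    // An array of strings deserializes into EndpointConfig::Multiple(...)
    let multiple: EndpointConfig = serde_json::from_str(
        r#"["https://cdn-a.example.com/artifacts/v1/TARGET_ID", "https://cdn-b.example.com/artifacts/v1/TARGET_ID"]"#,
    )
    .unwrap();

    println!("{:?}\n{:?}", single, multiple);
}
```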
@@ -100,17 +122,41 @@ impl PersistedDocumentsPlugin { } }; + let mut persisted_documents_manager = PersistedDocumentsManager::builder() + .key(key) + .user_agent(format!("hive-apollo-router/{}", PLUGIN_VERSION)); + + for endpoint in endpoints { + persisted_documents_manager = persisted_documents_manager.add_endpoint(endpoint); + } + + if let Some(connect_timeout) = config.connect_timeout { + persisted_documents_manager = + persisted_documents_manager.connect_timeout(Duration::from_secs(connect_timeout)); + } + + if let Some(request_timeout) = config.request_timeout { + persisted_documents_manager = + persisted_documents_manager.request_timeout(Duration::from_secs(request_timeout)); + } + + if let Some(retry_count) = config.retry_count { + persisted_documents_manager = persisted_documents_manager.max_retries(retry_count); + } + + if let Some(accept_invalid_certs) = config.accept_invalid_certs { + persisted_documents_manager = + persisted_documents_manager.accept_invalid_certs(accept_invalid_certs); + } + + if let Some(cache_size) = config.cache_size { + persisted_documents_manager = persisted_documents_manager.cache_size(cache_size); + } + + let persisted_documents_manager = persisted_documents_manager.build()?; + Ok(PersistedDocumentsPlugin { - persisted_documents_manager: Some(Arc::new(PersistedDocumentsManager::new( - key, - endpoint, - config.accept_invalid_certs.unwrap_or(false), - Duration::from_secs(config.connect_timeout.unwrap_or(5)), - Duration::from_secs(config.request_timeout.unwrap_or(15)), - config.retry_count.unwrap_or(3), - config.cache_size.unwrap_or(1000), - format!("hive-apollo-router/{}", PLUGIN_VERSION), - ))), + persisted_documents_manager: Some(Arc::new(persisted_documents_manager)), allow_arbitrary_documents, }) } @@ -344,8 +390,8 @@ mod hive_persisted_documents_tests { Self { server } } - fn endpoint(&self) -> String { - self.server.url("") + fn endpoint(&self) -> EndpointConfig { + EndpointConfig::Single(self.server.url("")) } /// Registers a valid artifact URL with an actual GraphQL document diff --git a/packages/libraries/router/src/registry.rs b/packages/libraries/router/src/registry.rs index 243c160cbc7..fb48b76bc35 100644 --- a/packages/libraries/router/src/registry.rs +++ b/packages/libraries/router/src/registry.rs @@ -1,14 +1,13 @@ use crate::consts::PLUGIN_VERSION; use crate::registry_logger::Logger; use anyhow::{anyhow, Result}; +use hive_console_sdk::supergraph_fetcher::sync::SupergraphFetcherSyncState; use hive_console_sdk::supergraph_fetcher::SupergraphFetcher; -use hive_console_sdk::supergraph_fetcher::SupergraphFetcherSyncState; use sha2::Digest; use sha2::Sha256; use std::env; use std::io::Write; use std::thread; -use std::time::Duration; #[derive(Debug)] pub struct HiveRegistry { @@ -18,7 +17,7 @@ pub struct HiveRegistry { } pub struct HiveRegistryConfig { - endpoint: Option, + endpoints: Vec, key: Option, poll_interval: Option, accept_invalid_certs: Option, @@ -29,7 +28,7 @@ impl HiveRegistry { #[allow(clippy::new_ret_no_self)] pub fn new(user_config: Option) -> Result<()> { let mut config = HiveRegistryConfig { - endpoint: None, + endpoints: vec![], key: None, poll_interval: None, accept_invalid_certs: Some(true), @@ -38,7 +37,7 @@ impl HiveRegistry { // Pass values from user's config if let Some(user_config) = user_config { - config.endpoint = user_config.endpoint; + config.endpoints = user_config.endpoints; config.key = user_config.key; config.poll_interval = user_config.poll_interval; config.accept_invalid_certs = 
user_config.accept_invalid_certs; @@ -47,9 +46,9 @@ impl HiveRegistry { // Pass values from environment variables if they are not set in the user's config - if config.endpoint.is_none() { + if config.endpoints.is_empty() { if let Ok(endpoint) = env::var("HIVE_CDN_ENDPOINT") { - config.endpoint = Some(endpoint); + config.endpoints.push(endpoint); } } @@ -86,7 +85,7 @@ impl HiveRegistry { } // Resolve values - let endpoint = config.endpoint.unwrap_or_default(); + let endpoint = config.endpoints; let key = config.key.unwrap_or_default(); let poll_interval: u64 = config.poll_interval.unwrap_or(10); let accept_invalid_certs = config.accept_invalid_certs.unwrap_or(false); @@ -120,19 +119,23 @@ impl HiveRegistry { .to_string_lossy() .to_string(), ); - env::set_var("APOLLO_ROUTER_SUPERGRAPH_PATH", file_name.clone()); - env::set_var("APOLLO_ROUTER_HOT_RELOAD", "true"); - - let fetcher = SupergraphFetcher::try_new_sync( - endpoint, - &key, - format!("hive-apollo-router/{}", PLUGIN_VERSION), - Duration::from_secs(5), - Duration::from_secs(60), - accept_invalid_certs, - 3, - ) - .map_err(|e| anyhow!("Failed to create SupergraphFetcher: {}", e))?; + unsafe { + env::set_var("APOLLO_ROUTER_SUPERGRAPH_PATH", file_name.clone()); + env::set_var("APOLLO_ROUTER_HOT_RELOAD", "true"); + } + + let mut fetcher = SupergraphFetcher::builder() + .key(key) + .user_agent(format!("hive-apollo-router/{}", PLUGIN_VERSION)) + .accept_invalid_certs(accept_invalid_certs); + + for ep in endpoint { + fetcher = fetcher.add_endpoint(ep); + } + + let fetcher = fetcher + .build_sync() + .map_err(|e| anyhow!("Failed to create SupergraphFetcher: {}", e))?; let registry = HiveRegistry { fetcher, diff --git a/packages/libraries/router/src/usage.rs b/packages/libraries/router/src/usage.rs index 1964a6e04e5..0d28575675e 100644 --- a/packages/libraries/router/src/usage.rs +++ b/packages/libraries/router/src/usage.rs @@ -8,8 +8,7 @@ use core::ops::Drop; use futures::StreamExt; use graphql_parser::parse_schema; use graphql_parser::schema::Document; -use hive_console_sdk::agent::UsageAgentExt; -use hive_console_sdk::agent::{ExecutionReport, UsageAgent}; +use hive_console_sdk::agent::usage_agent::{ExecutionReport, UsageAgent, UsageAgentExt}; use http::HeaderValue; use rand::Rng; use schemars::JsonSchema; @@ -244,20 +243,27 @@ impl Plugin for UsagePlugin { .expect("Failed to parse schema") .into_static(); + let token = token.expect("token is set"); + let agent = if enabled { let flush_interval = Duration::from_secs(flush_interval); - let agent = UsageAgent::try_new( - &token.expect("token is set"), - endpoint, - target_id, - buffer_size, - Duration::from_secs(connect_timeout), - Duration::from_secs(request_timeout), - accept_invalid_certs, - flush_interval, - format!("hive-apollo-router/{}", PLUGIN_VERSION), - ) - .map_err(Box::new)?; + + let mut agent = UsageAgent::builder() + .token(token) + .endpoint(endpoint) + .buffer_size(buffer_size) + .connect_timeout(Duration::from_secs(connect_timeout)) + .request_timeout(Duration::from_secs(request_timeout)) + .accept_invalid_certs(accept_invalid_certs) + .user_agent(format!("hive-apollo-router/{}", PLUGIN_VERSION)) + .flush_interval(flush_interval); + + if let Some(target_id) = target_id { + agent = agent.target_id(target_id); + } + + let agent = agent.build().map_err(Box::new)?; + start_flush_interval(agent.clone()); Some(agent) } else { diff --git a/packages/libraries/sdk-rs/Cargo.toml b/packages/libraries/sdk-rs/Cargo.toml index a5fae8037f9..0964230d0ce 100644 --- 
a/packages/libraries/sdk-rs/Cargo.toml +++ b/packages/libraries/sdk-rs/Cargo.toml @@ -32,6 +32,11 @@ serde_json = "1" moka = { version = "0.12.10", features = ["future", "sync"] } sha2 = { version = "0.10.8", features = ["std"] } tokio-util = "0.7.16" +regex-automata = "0.4.10" +once_cell = "1.21.3" +retry-policies = "0.5.0" +recloser = "1.3.1" +futures-util = "0.3.31" typify = "0.5.0" regress = "0.10.5" diff --git a/packages/libraries/sdk-rs/src/agent/builder.rs b/packages/libraries/sdk-rs/src/agent/builder.rs new file mode 100644 index 00000000000..c0824e1f711 --- /dev/null +++ b/packages/libraries/sdk-rs/src/agent/builder.rs @@ -0,0 +1,216 @@ +use std::{sync::Arc, time::Duration}; + +use once_cell::sync::Lazy; +use recloser::AsyncRecloser; +use reqwest::header::{HeaderMap, HeaderValue}; +use reqwest_middleware::ClientBuilder; +use reqwest_retry::RetryTransientMiddleware; + +use crate::agent::usage_agent::{non_empty_string, AgentError, Buffer, UsageAgent}; +use crate::agent::utils::OperationProcessor; +use crate::circuit_breaker; +use retry_policies::policies::ExponentialBackoff; + +pub struct UsageAgentBuilder { + token: Option, + endpoint: String, + target_id: Option, + buffer_size: usize, + connect_timeout: Duration, + request_timeout: Duration, + accept_invalid_certs: bool, + flush_interval: Duration, + retry_policy: ExponentialBackoff, + user_agent: Option, + circuit_breaker: Option, +} + +pub static DEFAULT_HIVE_USAGE_ENDPOINT: &str = "https://app.graphql-hive.com/usage"; + +impl Default for UsageAgentBuilder { + fn default() -> Self { + Self { + endpoint: DEFAULT_HIVE_USAGE_ENDPOINT.to_string(), + token: None, + target_id: None, + buffer_size: 1000, + connect_timeout: Duration::from_secs(5), + request_timeout: Duration::from_secs(15), + accept_invalid_certs: false, + flush_interval: Duration::from_secs(5), + retry_policy: ExponentialBackoff::builder().build_with_max_retries(3), + user_agent: None, + circuit_breaker: None, + } + } +} + +fn is_legacy_token(token: &str) -> bool { + !token.starts_with("hvo1/") && !token.starts_with("hvu1/") && !token.starts_with("hvp1/") +} + +impl UsageAgentBuilder { + /// Your [Registry Access Token](https://the-guild.dev/graphql/hive/docs/management/targets#registry-access-tokens) with write permission. + pub fn token(mut self, token: String) -> Self { + self.token = non_empty_string(Some(token)); + self + } + /// For self-hosting, you can override `/usage` endpoint (defaults to `https://app.graphql-hive.com/usage`). + pub fn endpoint(mut self, endpoint: String) -> Self { + if let Some(endpoint) = non_empty_string(Some(endpoint)) { + self.endpoint = endpoint; + } + self + } + /// A target ID, this can either be a slug following the format “$organizationSlug/$projectSlug/$targetSlug” (e.g “the-guild/graphql-hive/staging”) or an UUID (e.g. “a0f4c605-6541-4350-8cfe-b31f21a4bf80”). To be used when the token is configured with an organization access token. 
+ pub fn target_id(mut self, target_id: String) -> Self { + self.target_id = non_empty_string(Some(target_id)); + self + } + /// A maximum number of operations to hold in a buffer before sending to Hive Console + /// Default: 1000 + pub fn buffer_size(mut self, buffer_size: usize) -> Self { + self.buffer_size = buffer_size; + self + } + /// A timeout for only the connect phase of a request to Hive Console + /// Default: 5 seconds + pub fn connect_timeout(mut self, connect_timeout: Duration) -> Self { + self.connect_timeout = connect_timeout; + self + } + /// A timeout for the entire request to Hive Console + /// Default: 15 seconds + pub fn request_timeout(mut self, request_timeout: Duration) -> Self { + self.request_timeout = request_timeout; + self + } + /// Accepts invalid SSL certificates + /// Default: false + pub fn accept_invalid_certs(mut self, accept_invalid_certs: bool) -> Self { + self.accept_invalid_certs = accept_invalid_certs; + self + } + /// Frequency of flushing the buffer to the server + /// Default: 5 seconds + pub fn flush_interval(mut self, flush_interval: Duration) -> Self { + self.flush_interval = flush_interval; + self + } + /// User-Agent header to be sent with each request + pub fn user_agent(mut self, user_agent: String) -> Self { + self.user_agent = non_empty_string(Some(user_agent)); + self + } + /// Retry policy for sending reports + /// Default: ExponentialBackoff with max 3 retries + pub fn retry_policy(mut self, retry_policy: ExponentialBackoff) -> Self { + self.retry_policy = retry_policy; + self + } + /// Maximum number of retries for sending reports + /// Default: ExponentialBackoff with max 3 retries + pub fn max_retries(mut self, max_retries: u32) -> Self { + self.retry_policy = ExponentialBackoff::builder().build_with_max_retries(max_retries); + self + } + pub fn build(self) -> Result, AgentError> { + let mut default_headers = HeaderMap::new(); + + default_headers.insert("X-Usage-API-Version", HeaderValue::from_static("2")); + + if let Some(token) = self.token { + let mut authorization_header = HeaderValue::from_str(&format!("Bearer {}", token)) + .map_err(|_| AgentError::InvalidToken)?; + + authorization_header.set_sensitive(true); + + default_headers.insert(reqwest::header::AUTHORIZATION, authorization_header); + + default_headers.insert( + reqwest::header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + + let mut reqwest_agent = reqwest::Client::builder() + .danger_accept_invalid_certs(self.accept_invalid_certs) + .connect_timeout(self.connect_timeout) + .timeout(self.request_timeout) + .default_headers(default_headers); + + if let Some(user_agent) = &self.user_agent { + reqwest_agent = reqwest_agent.user_agent(user_agent); + } + + let reqwest_agent = reqwest_agent + .build() + .map_err(AgentError::HTTPClientCreationError)?; + let client = ClientBuilder::new(reqwest_agent) + .with(RetryTransientMiddleware::new_with_policy(self.retry_policy)) + .build(); + + let mut endpoint = self.endpoint; + + match self.target_id { + Some(_) if is_legacy_token(&token) => { + return Err(AgentError::TargetIdWithLegacyToken) + } + Some(target_id) if !is_legacy_token(&token) => { + let target_id = validate_target_id(&target_id)?; + endpoint.push_str(&format!("/{}", target_id)); + } + None if !is_legacy_token(&token) => return Err(AgentError::MissingTargetId), + _ => {} + } + + let circuit_breaker = if let Some(cb) = self.circuit_breaker { + cb + } else { + circuit_breaker::CircuitBreakerBuilder::default() + .build_async() + 
.map_err(AgentError::CircuitBreakerCreationError)? + }; + + Ok(Arc::new(UsageAgent { + endpoint, + buffer: Buffer::new(self.buffer_size), + processor: OperationProcessor::new(), + client, + flush_interval: self.flush_interval, + circuit_breaker, + })) + } else { + Err(AgentError::MissingToken) + } + } +} + +// Target ID regexp for validation: slug format +static SLUG_REGEX: Lazy = Lazy::new(|| { + regex_automata::meta::Regex::new(r"^[a-zA-Z0-9-_]+\/[a-zA-Z0-9-_]+\/[a-zA-Z0-9-_]+$").unwrap() +}); +// Target ID regexp for validation: UUID format +static UUID_REGEX: Lazy = Lazy::new(|| { + regex_automata::meta::Regex::new( + r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$", + ) + .unwrap() +}); + +fn validate_target_id(target_id: &str) -> Result<&str, AgentError> { + let trimmed_s = target_id.trim(); + if trimmed_s.is_empty() { + Err(AgentError::InvalidTargetId("".to_string())) + } else { + if SLUG_REGEX.is_match(trimmed_s) { + return Ok(trimmed_s); + } + if UUID_REGEX.is_match(trimmed_s) { + return Ok(trimmed_s); + } + Err(AgentError::InvalidTargetId(format!( + "Invalid target_id format: '{}'. It must be either in slug format '$organizationSlug/$projectSlug/$targetSlug' or UUID format 'a0f4c605-6541-4350-8cfe-b31f21a4bf80'", + trimmed_s + ))) + } +} diff --git a/packages/libraries/sdk-rs/src/agent/mod.rs b/packages/libraries/sdk-rs/src/agent/mod.rs new file mode 100644 index 00000000000..2e8250459f7 --- /dev/null +++ b/packages/libraries/sdk-rs/src/agent/mod.rs @@ -0,0 +1,3 @@ +pub mod builder; +pub mod usage_agent; +pub mod utils; diff --git a/packages/libraries/sdk-rs/src/agent.rs b/packages/libraries/sdk-rs/src/agent/usage_agent.rs similarity index 81% rename from packages/libraries/sdk-rs/src/agent.rs rename to packages/libraries/sdk-rs/src/agent/usage_agent.rs index 9344ae18363..b22d7f56998 100644 --- a/packages/libraries/sdk-rs/src/agent.rs +++ b/packages/libraries/sdk-rs/src/agent/usage_agent.rs @@ -1,8 +1,6 @@ -use super::graphql::OperationProcessor; use graphql_parser::schema::Document; -use reqwest::header::{HeaderMap, HeaderValue}; -use reqwest_middleware::{ClientBuilder, ClientWithMiddleware}; -use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; +use recloser::AsyncRecloser; +use reqwest_middleware::ClientWithMiddleware; use std::{ collections::{hash_map::Entry, HashMap, VecDeque}, sync::{Arc, Mutex}, @@ -11,6 +9,9 @@ use std::{ use thiserror::Error; use tokio_util::sync::CancellationToken; +use crate::agent::builder::UsageAgentBuilder; +use crate::agent::utils::OperationProcessor; + #[derive(Debug, Clone)] pub struct ExecutionReport { pub schema: Arc>, @@ -28,18 +29,26 @@ pub struct ExecutionReport { typify::import_types!(schema = "./usage-report-v2.schema.json"); #[derive(Debug, Default)] -pub struct Buffer(Mutex>); +pub struct Buffer { + queue: Mutex>, + size: usize, +} impl Buffer { - fn new() -> Self { - Self(Mutex::new(VecDeque::new())) + pub fn new(size: usize) -> Self { + Self { + queue: Mutex::new(VecDeque::new()), + size, + } } fn lock_buffer( &self, ) -> Result>, AgentError> { - let buffer: Result>, AgentError> = - self.0.lock().map_err(|e| AgentError::Lock(e.to_string())); + let buffer: Result>, AgentError> = self + .queue + .lock() + .map_err(|e| AgentError::Lock(e.to_string())); buffer } @@ -56,15 +65,15 @@ impl Buffer { } } pub struct UsageAgent { - buffer_size: usize, - endpoint: String, - buffer: Buffer, - processor: OperationProcessor, - client: ClientWithMiddleware, - flush_interval: Duration, + 
pub(crate) endpoint: String, + pub(crate) buffer: Buffer, + pub(crate) processor: OperationProcessor, + pub(crate) client: ClientWithMiddleware, + pub(crate) flush_interval: Duration, + pub(crate) circuit_breaker: AsyncRecloser, } -fn non_empty_string(value: Option) -> Option { +pub fn non_empty_string(value: Option) -> Option { value.filter(|str| !str.is_empty()) } @@ -78,75 +87,30 @@ pub enum AgentError { Forbidden, #[error("unable to send report: rate limited")] RateLimited, - #[error("invalid token provided: {0}")] - InvalidToken(String), + #[error("missing token")] + MissingToken, + #[error("your access token requires providing a 'target_id' option.")] + MissingTargetId, + #[error("using 'target_id' with legacy tokens is not supported")] + TargetIdWithLegacyToken, + #[error("invalid token provided")] + InvalidToken, + #[error("invalid target id provided: {0}, it should be either a slug like \"$organizationSlug/$projectSlug/$targetSlug\" or an UUID")] + InvalidTargetId(String), #[error("unable to instantiate the http client for reports sending: {0}")] HTTPClientCreationError(reqwest::Error), + #[error("unable to create circuit breaker: {0}")] + CircuitBreakerCreationError(#[from] crate::circuit_breaker::CircuitBreakerError), + #[error("rejected by the circuit breaker")] + CircuitBreakerRejected, #[error("unable to send report: {0}")] Unknown(String), } impl UsageAgent { - #[allow(clippy::too_many_arguments)] - pub fn try_new( - token: &str, - endpoint: String, - target_id: Option, - buffer_size: usize, - connect_timeout: Duration, - request_timeout: Duration, - accept_invalid_certs: bool, - flush_interval: Duration, - user_agent: String, - ) -> Result, AgentError> { - let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3); - - let mut default_headers = HeaderMap::new(); - - default_headers.insert("X-Usage-API-Version", HeaderValue::from_static("2")); - - let mut authorization_header = HeaderValue::from_str(&format!("Bearer {}", token)) - .map_err(|_| AgentError::InvalidToken(token.to_string()))?; - - authorization_header.set_sensitive(true); - - default_headers.insert(reqwest::header::AUTHORIZATION, authorization_header); - - default_headers.insert( - reqwest::header::CONTENT_TYPE, - HeaderValue::from_static("application/json"), - ); - - let reqwest_agent = reqwest::Client::builder() - .danger_accept_invalid_certs(accept_invalid_certs) - .connect_timeout(connect_timeout) - .timeout(request_timeout) - .user_agent(user_agent) - .default_headers(default_headers) - .build() - .map_err(AgentError::HTTPClientCreationError)?; - let client = ClientBuilder::new(reqwest_agent) - .with(RetryTransientMiddleware::new_with_policy(retry_policy)) - .build(); - - let mut endpoint = endpoint; - - if token.starts_with("hvo1/") || token.starts_with("hvu1/") || token.starts_with("hvp1/") { - if let Some(target_id) = target_id { - endpoint.push_str(&format!("/{}", target_id)); - } - } - - Ok(Arc::new(Self { - buffer_size, - endpoint, - buffer: Buffer::new(), - processor: OperationProcessor::new(), - client, - flush_interval, - })) + pub fn builder() -> UsageAgentBuilder { + UsageAgentBuilder::default() } - fn produce_report(&self, reports: Vec) -> Result { let mut report = Report { size: 0, @@ -233,21 +197,28 @@ impl UsageAgent { Ok(report) } - pub async fn send_report(&self, report: Report) -> Result<(), AgentError> { + async fn send_report(&self, report: Report) -> Result<(), AgentError> { if report.size == 0 { return Ok(()); } let report_body = serde_json::to_vec(&report).map_err(|e| 
AgentError::Unknown(e.to_string()))?; // Based on https://the-guild.dev/graphql/hive/docs/specs/usage-reports#data-structure - let resp = self + let resp_fut = self .client .post(&self.endpoint) .header(reqwest::header::CONTENT_LENGTH, report_body.len()) .body(report_body) - .send() + .send(); + + let resp = self + .circuit_breaker + .call(resp_fut) .await - .map_err(|e| AgentError::Unknown(e.to_string()))?; + .map_err(|e| match e { + recloser::Error::Inner(e) => AgentError::Unknown(e.to_string()), + recloser::Error::Rejected => AgentError::CircuitBreakerRejected, + })?; match resp.status() { reqwest::StatusCode::OK => Ok(()), @@ -307,7 +278,7 @@ pub trait UsageAgentExt { impl UsageAgentExt for Arc { fn flush_if_full(&self, size: usize) -> Result<(), AgentError> { - if size >= self.buffer_size { + if size >= self.buffer.size { let cloned_self = self.clone(); tokio::task::spawn(async move { cloned_self.flush().await; @@ -333,7 +304,7 @@ mod tests { use graphql_parser::{parse_query, parse_schema}; use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, USER_AGENT}; - use crate::agent::{ExecutionReport, Report, UsageAgent, UsageAgentExt}; + use crate::agent::usage_agent::{ExecutionReport, Report, UsageAgent, UsageAgentExt}; const CONTENT_TYPE_VALUE: &'static str = "application/json"; const GRAPHQL_CLIENT_NAME: &'static str = "Hive Client"; @@ -489,18 +460,17 @@ mod tests { ) .expect("Failed to parse query"); - let usage_agent = UsageAgent::try_new( - token, - format!("{}/200", server_url), - None, - 10, - Duration::from_millis(500), - Duration::from_millis(500), - false, - Duration::from_millis(10), - user_agent, - ) - .expect("Failed to create UsageAgent"); + let usage_agent = UsageAgent::builder() + .token(token.to_string()) + .endpoint(format!("{}/200", server_url)) + .buffer_size(10) + .connect_timeout(Duration::from_millis(500)) + .request_timeout(Duration::from_millis(500)) + .accept_invalid_certs(false) + .flush_interval(Duration::from_millis(10)) + .user_agent(user_agent.clone()) + .build() + .expect("Failed to create UsageAgent"); usage_agent .add_report(ExecutionReport { @@ -509,7 +479,7 @@ mod tests { operation_name: Some("deleteProject".to_string()), client_name: Some(GRAPHQL_CLIENT_NAME.to_string()), client_version: Some(GRAPHQL_CLIENT_VERSION.to_string()), - timestamp: timestamp.try_into().unwrap(), + timestamp, duration, ok: true, errors: 0, diff --git a/packages/libraries/sdk-rs/src/graphql.rs b/packages/libraries/sdk-rs/src/agent/utils.rs similarity index 100% rename from packages/libraries/sdk-rs/src/graphql.rs rename to packages/libraries/sdk-rs/src/agent/utils.rs diff --git a/packages/libraries/sdk-rs/src/circuit_breaker.rs b/packages/libraries/sdk-rs/src/circuit_breaker.rs new file mode 100644 index 00000000000..0dbd4529c2f --- /dev/null +++ b/packages/libraries/sdk-rs/src/circuit_breaker.rs @@ -0,0 +1,67 @@ +use std::time::Duration; + +use recloser::{AsyncRecloser, Recloser}; + +#[derive(Clone)] +pub struct CircuitBreakerBuilder { + error_threshold: f32, + volume_threshold: usize, + reset_timeout: Duration, +} + +impl Default for CircuitBreakerBuilder { + fn default() -> Self { + Self { + error_threshold: 0.5, + volume_threshold: 5, + reset_timeout: Duration::from_secs(30), + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum CircuitBreakerError { + #[error("Invalid error threshold: {0}. It must be between 0.0 and 1.0")] + InvalidErrorThreshold(f32), +} + +impl CircuitBreakerBuilder { + /// Percentage after what the circuit breaker should kick in. 
+ /// Default: .5 + pub fn error_threshold(mut self, percentage: f32) -> Self { + self.error_threshold = percentage; + self + } + /// Count of requests before starting evaluating. + /// Default: 5 + pub fn volume_threshold(mut self, threshold: usize) -> Self { + self.volume_threshold = threshold; + self + } + /// After what time the circuit breaker is attempting to retry sending requests in milliseconds. + /// Default: 30s + pub fn reset_timeout(mut self, timeout: Duration) -> Self { + self.reset_timeout = timeout; + self + } + + pub fn build_async(self) -> Result { + let recloser = self.build_sync()?; + Ok(AsyncRecloser::from(recloser)) + } + pub fn build_sync(self) -> Result { + let error_threshold = if self.error_threshold < 0.0 || self.error_threshold > 1.0 { + return Err(CircuitBreakerError::InvalidErrorThreshold( + self.error_threshold, + )); + } else { + self.error_threshold + }; + let recloser = Recloser::custom() + .error_rate(error_threshold) + .closed_len(self.volume_threshold) + .open_wait(self.reset_timeout) + .build(); + Ok(recloser) + } +} diff --git a/packages/libraries/sdk-rs/src/lib.rs b/packages/libraries/sdk-rs/src/lib.rs index ec6f97886e0..0201c9cc2ca 100644 --- a/packages/libraries/sdk-rs/src/lib.rs +++ b/packages/libraries/sdk-rs/src/lib.rs @@ -1,4 +1,4 @@ pub mod agent; -pub mod graphql; +pub mod circuit_breaker; pub mod persisted_documents; pub mod supergraph_fetcher; diff --git a/packages/libraries/sdk-rs/src/persisted_documents.rs b/packages/libraries/sdk-rs/src/persisted_documents.rs index 02ce66132c8..4a5aab95ccb 100644 --- a/packages/libraries/sdk-rs/src/persisted_documents.rs +++ b/packages/libraries/sdk-rs/src/persisted_documents.rs @@ -1,18 +1,22 @@ use std::time::Duration; +use crate::agent::usage_agent::non_empty_string; +use crate::circuit_breaker::CircuitBreakerBuilder; use moka::future::Cache; +use recloser::AsyncRecloser; use reqwest::header::HeaderMap; use reqwest::header::HeaderValue; use reqwest_middleware::ClientBuilder; use reqwest_middleware::ClientWithMiddleware; -use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; +use reqwest_retry::RetryTransientMiddleware; +use retry_policies::policies::ExponentialBackoff; use tracing::{debug, info, warn}; #[derive(Debug)] pub struct PersistedDocumentsManager { - agent: ClientWithMiddleware, + client: ClientWithMiddleware, cache: Cache, - endpoint: String, + endpoints_with_circuit_breakers: Vec<(String, AsyncRecloser)>, } #[derive(Debug, thiserror::Error)] @@ -31,6 +35,18 @@ pub enum PersistedDocumentsError { FailedToReadCDNResponse(reqwest::Error), #[error("No persisted document provided, or document id cannot be resolved.")] PersistedDocumentRequired, + #[error("Missing required configuration option: {0}")] + MissingConfigurationOption(String), + #[error("Invalid CDN key {0}")] + InvalidCDNKey(String), + #[error("Failed to create HTTP client: {0}")] + HTTPClientCreationError(reqwest::Error), + #[error("unable to create circuit breaker: {0}")] + CircuitBreakerCreationError(#[from] crate::circuit_breaker::CircuitBreakerError), + #[error("rejected by the circuit breaker")] + CircuitBreakerRejected, + #[error("unknown error")] + Unknown, } impl PersistedDocumentsError { @@ -51,47 +67,75 @@ impl PersistedDocumentsError { PersistedDocumentsError::PersistedDocumentRequired => { "PERSISTED_DOCUMENT_REQUIRED".into() } + PersistedDocumentsError::MissingConfigurationOption(_) => { + "MISSING_CONFIGURATION_OPTION".into() + } + PersistedDocumentsError::InvalidCDNKey(_) => 
"INVALID_CDN_KEY".into(), + PersistedDocumentsError::HTTPClientCreationError(_) => { + "HTTP_CLIENT_CREATION_ERROR".into() + } + PersistedDocumentsError::CircuitBreakerCreationError(_) => { + "CIRCUIT_BREAKER_CREATION_ERROR".into() + } + PersistedDocumentsError::CircuitBreakerRejected => "CIRCUIT_BREAKER_REJECTED".into(), + PersistedDocumentsError::Unknown => "UNKNOWN_ERROR".into(), } } } impl PersistedDocumentsManager { - #[allow(clippy::too_many_arguments)] - pub fn new( - key: String, - endpoint: String, - accept_invalid_certs: bool, - connect_timeout: Duration, - request_timeout: Duration, - retry_count: u32, - cache_size: u64, - user_agent: String, - ) -> Self { - let retry_policy = ExponentialBackoff::builder().build_with_max_retries(retry_count); + pub fn builder() -> PersistedDocumentsManagerBuilder { + PersistedDocumentsManagerBuilder::default() + } + async fn resolve_from_endpoint( + &self, + endpoint: &str, + document_id: &str, + circuit_breaker: &AsyncRecloser, + ) -> Result { + let cdn_document_id = str::replace(document_id, "~", "/"); + let cdn_artifact_url = format!("{}/apps/{}", endpoint, cdn_document_id); + info!( + "Fetching document {} from CDN: {}", + document_id, cdn_artifact_url + ); + let response_fut = self.client.get(cdn_artifact_url).send(); - let mut default_headers = HeaderMap::new(); - default_headers.insert("X-Hive-CDN-Key", HeaderValue::from_str(&key).unwrap()); - let reqwest_agent = reqwest::Client::builder() - .danger_accept_invalid_certs(accept_invalid_certs) - .connect_timeout(connect_timeout) - .timeout(request_timeout) - .user_agent(user_agent) - .default_headers(default_headers) - .build() - .expect("Failed to create reqwest client"); - let agent = ClientBuilder::new(reqwest_agent) - .with(RetryTransientMiddleware::new_with_policy(retry_policy)) - .build(); + let response = circuit_breaker + .call(response_fut) + .await + .map_err(|e| match e { + recloser::Error::Inner(e) => PersistedDocumentsError::FailedToFetchFromCDN(e), + recloser::Error::Rejected => PersistedDocumentsError::CircuitBreakerRejected, + })?; - let cache = Cache::::new(cache_size); + if response.status().is_success() { + let document = response + .text() + .await + .map_err(PersistedDocumentsError::FailedToReadCDNResponse)?; + debug!( + "Document fetched from CDN: {}, storing in local cache", + document + ); + self.cache + .insert(document_id.into(), document.clone()) + .await; - Self { - agent, - cache, - endpoint, + return Ok(document); } - } + warn!( + "Document fetch from CDN failed: HTTP {}, Body: {:?}", + response.status(), + response + .text() + .await + .unwrap_or_else(|_| "Unavailable".to_string()) + ); + + Err(PersistedDocumentsError::DocumentNotFound) + } /// Resolves the document from the cache, or from the CDN pub async fn resolve_document( &self, @@ -110,50 +154,173 @@ impl PersistedDocumentsManager { "Document {} not found in cache. 
Fetching from CDN", document_id ); - let cdn_document_id = str::replace(document_id, "~", "/"); - let cdn_artifact_url = format!("{}/apps/{}", &self.endpoint, cdn_document_id); - info!( - "Fetching document {} from CDN: {}", - document_id, cdn_artifact_url - ); - let cdn_response = self.agent.get(cdn_artifact_url).send().await; - - match cdn_response { - Ok(response) => { - if response.status().is_success() { - let document = response - .text() - .await - .map_err(PersistedDocumentsError::FailedToReadCDNResponse)?; - debug!( - "Document fetched from CDN: {}, storing in local cache", - document - ); - self.cache - .insert(document_id.into(), document.clone()) - .await; - - return Ok(document); + let mut last_error: Option = None; + for (endpoint, circuit_breaker) in &self.endpoints_with_circuit_breakers { + let result = self + .resolve_from_endpoint(endpoint, document_id, circuit_breaker) + .await; + match result { + Ok(document) => return Ok(document), + Err(e) => { + last_error = Some(e); } + } + } + match last_error { + Some(e) => Err(e), + None => Err(PersistedDocumentsError::Unknown), + } + } + } + } +} - warn!( - "Document fetch from CDN failed: HTTP {}, Body: {:?}", - response.status(), - response - .text() - .await - .unwrap_or_else(|_| "Unavailable".to_string()) - ); +pub struct PersistedDocumentsManagerBuilder { + key: Option, + endpoints: Vec, + accept_invalid_certs: bool, + connect_timeout: Duration, + request_timeout: Duration, + retry_policy: ExponentialBackoff, + cache_size: u64, + user_agent: Option, + circuit_breaker: CircuitBreakerBuilder, +} - Err(PersistedDocumentsError::DocumentNotFound) - } - Err(e) => { - warn!("Failed to fetch document from CDN: {:?}", e); +impl Default for PersistedDocumentsManagerBuilder { + fn default() -> Self { + Self { + key: None, + endpoints: vec![], + accept_invalid_certs: false, + connect_timeout: Duration::from_secs(5), + request_timeout: Duration::from_secs(15), + retry_policy: ExponentialBackoff::builder().build_with_max_retries(3), + cache_size: 10_000, + user_agent: None, + circuit_breaker: CircuitBreakerBuilder::default(), + } + } +} - Err(PersistedDocumentsError::FailedToFetchFromCDN(e)) - } - } +impl PersistedDocumentsManagerBuilder { + /// The CDN Access Token with from the Hive Console target. + pub fn key(mut self, key: String) -> Self { + self.key = non_empty_string(Some(key)); + self + } + + /// The CDN endpoint from Hive Console target. + pub fn add_endpoint(mut self, endpoint: String) -> Self { + if let Some(endpoint) = non_empty_string(Some(endpoint)) { + self.endpoints.push(endpoint); + } + self + } + + /// Accept invalid SSL certificates + /// default: false + pub fn accept_invalid_certs(mut self, accept_invalid_certs: bool) -> Self { + self.accept_invalid_certs = accept_invalid_certs; + self + } + + /// Connection timeout for the Hive Console CDN requests. + /// Default: 5 seconds + pub fn connect_timeout(mut self, connect_timeout: Duration) -> Self { + self.connect_timeout = connect_timeout; + self + } + + /// Request timeout for the Hive Console CDN requests. 
+ /// Default: 15 seconds + pub fn request_timeout(mut self, request_timeout: Duration) -> Self { + self.request_timeout = request_timeout; + self + } + + /// Retry policy for fetching persisted documents + /// Default: ExponentialBackoff with max 3 retries + pub fn retry_policy(mut self, retry_policy: ExponentialBackoff) -> Self { + self.retry_policy = retry_policy; + self + } + + /// Maximum number of retries for fetching persisted documents + /// Default: ExponentialBackoff with max 3 retries + pub fn max_retries(mut self, max_retries: u32) -> Self { + self.retry_policy = ExponentialBackoff::builder().build_with_max_retries(max_retries); + self + } + + /// Size of the in-memory cache for persisted documents + /// Default: 10,000 entries + pub fn cache_size(mut self, cache_size: u64) -> Self { + self.cache_size = cache_size; + self + } + + /// User-Agent header to be sent with each request + pub fn user_agent(mut self, user_agent: String) -> Self { + self.user_agent = non_empty_string(Some(user_agent)); + self + } + + pub fn build(self) -> Result { + let mut default_headers = HeaderMap::new(); + let key = match self.key { + Some(key) => key, + None => { + return Err(PersistedDocumentsError::MissingConfigurationOption( + "key".to_string(), + )); } + }; + default_headers.insert( + "X-Hive-CDN-Key", + HeaderValue::from_str(&key) + .map_err(|e| PersistedDocumentsError::InvalidCDNKey(e.to_string()))?, + ); + let mut reqwest_agent = reqwest::Client::builder() + .danger_accept_invalid_certs(self.accept_invalid_certs) + .connect_timeout(self.connect_timeout) + .timeout(self.request_timeout) + .default_headers(default_headers); + + if let Some(user_agent) = self.user_agent { + reqwest_agent = reqwest_agent.user_agent(user_agent); } + + let reqwest_agent = reqwest_agent + .build() + .map_err(PersistedDocumentsError::HTTPClientCreationError)?; + let client = ClientBuilder::new(reqwest_agent) + .with(RetryTransientMiddleware::new_with_policy(self.retry_policy)) + .build(); + + let cache = Cache::::new(self.cache_size); + + if self.endpoints.is_empty() { + return Err(PersistedDocumentsError::MissingConfigurationOption( + "endpoints".to_string(), + )); + } + + Ok(PersistedDocumentsManager { + client, + cache, + endpoints_with_circuit_breakers: self + .endpoints + .into_iter() + .map(move |endpoint| { + let circuit_breaker = self + .circuit_breaker + .clone() + .build_async() + .map_err(PersistedDocumentsError::CircuitBreakerCreationError)?; + Ok((endpoint, circuit_breaker)) + }) + .collect::, PersistedDocumentsError>>()?, + }) } } diff --git a/packages/libraries/sdk-rs/src/supergraph_fetcher.rs b/packages/libraries/sdk-rs/src/supergraph_fetcher.rs deleted file mode 100644 index 98c2540eed7..00000000000 --- a/packages/libraries/sdk-rs/src/supergraph_fetcher.rs +++ /dev/null @@ -1,260 +0,0 @@ -use std::fmt::Display; -use std::sync::RwLock; -use std::time::Duration; -use std::time::SystemTime; - -use reqwest::header::HeaderMap; -use reqwest::header::HeaderValue; -use reqwest::header::InvalidHeaderValue; -use reqwest::header::IF_NONE_MATCH; -use reqwest_middleware::ClientBuilder; -use reqwest_middleware::ClientWithMiddleware; -use reqwest_retry::policies::ExponentialBackoff; -use reqwest_retry::RetryDecision; -use reqwest_retry::RetryPolicy; -use reqwest_retry::RetryTransientMiddleware; - -#[derive(Debug)] -pub struct SupergraphFetcher { - client: SupergraphFetcherAsyncOrSyncClient, - endpoint: String, - etag: RwLock>, - state: std::marker::PhantomData, -} - -#[derive(Debug)] -pub struct 
SupergraphFetcherAsyncState; -#[derive(Debug)] -pub struct SupergraphFetcherSyncState; - -#[derive(Debug)] -enum SupergraphFetcherAsyncOrSyncClient { - Async { - reqwest_client: ClientWithMiddleware, - }, - Sync { - reqwest_client: reqwest::blocking::Client, - retry_policy: ExponentialBackoff, - }, -} - -pub enum SupergraphFetcherError { - FetcherCreationError(reqwest::Error), - NetworkError(reqwest_middleware::Error), - NetworkResponseError(reqwest::Error), - Lock(String), - InvalidKey(InvalidHeaderValue), -} - -impl Display for SupergraphFetcherError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - SupergraphFetcherError::FetcherCreationError(e) => { - write!(f, "Creating fetcher failed: {}", e) - } - SupergraphFetcherError::NetworkError(e) => write!(f, "Network error: {}", e), - SupergraphFetcherError::NetworkResponseError(e) => { - write!(f, "Network response error: {}", e) - } - SupergraphFetcherError::Lock(e) => write!(f, "Lock error: {}", e), - SupergraphFetcherError::InvalidKey(e) => write!(f, "Invalid CDN key: {}", e), - } - } -} - -fn prepare_client_config( - mut endpoint: String, - key: &str, - retry_count: u32, -) -> Result<(String, HeaderMap, ExponentialBackoff), SupergraphFetcherError> { - if !endpoint.ends_with("/supergraph") { - if endpoint.ends_with("/") { - endpoint.push_str("supergraph"); - } else { - endpoint.push_str("/supergraph"); - } - } - - let mut headers = HeaderMap::new(); - let mut cdn_key_header = - HeaderValue::from_str(key).map_err(SupergraphFetcherError::InvalidKey)?; - cdn_key_header.set_sensitive(true); - headers.insert("X-Hive-CDN-Key", cdn_key_header); - - let retry_policy = ExponentialBackoff::builder().build_with_max_retries(retry_count); - - Ok((endpoint, headers, retry_policy)) -} - -impl SupergraphFetcher { - #[allow(clippy::too_many_arguments)] - pub fn try_new_sync( - endpoint: String, - key: &str, - user_agent: String, - connect_timeout: Duration, - request_timeout: Duration, - accept_invalid_certs: bool, - retry_count: u32, - ) -> Result { - let (endpoint, headers, retry_policy) = prepare_client_config(endpoint, key, retry_count)?; - - Ok(Self { - client: SupergraphFetcherAsyncOrSyncClient::Sync { - reqwest_client: reqwest::blocking::Client::builder() - .danger_accept_invalid_certs(accept_invalid_certs) - .connect_timeout(connect_timeout) - .timeout(request_timeout) - .user_agent(user_agent) - .default_headers(headers) - .build() - .map_err(SupergraphFetcherError::FetcherCreationError)?, - retry_policy, - }, - endpoint, - etag: RwLock::new(None), - state: std::marker::PhantomData, - }) - } - - pub fn fetch_supergraph(&self) -> Result, SupergraphFetcherError> { - let request_start_time = SystemTime::now(); - // Implementing retry logic for sync client - let mut n_past_retries = 0; - let (reqwest_client, retry_policy) = match &self.client { - SupergraphFetcherAsyncOrSyncClient::Sync { - reqwest_client, - retry_policy, - } => (reqwest_client, retry_policy), - _ => unreachable!(), - }; - let resp = loop { - let mut req = reqwest_client.get(&self.endpoint); - let etag = self.get_latest_etag()?; - if let Some(etag) = etag { - req = req.header(IF_NONE_MATCH, etag); - } - let response = req.send(); - - match response { - Ok(resp) => break resp, - Err(e) => match retry_policy.should_retry(request_start_time, n_past_retries) { - RetryDecision::DoNotRetry => { - return Err(SupergraphFetcherError::NetworkError( - reqwest_middleware::Error::Reqwest(e), - )); - } - RetryDecision::Retry { execute_after } => { - 
n_past_retries += 1; - if let Ok(duration) = execute_after.elapsed() { - std::thread::sleep(duration); - } - } - }, - } - }; - - if resp.status().as_u16() == 304 { - return Ok(None); - } - - let etag = resp.headers().get("etag"); - self.update_latest_etag(etag)?; - - let text = resp - .text() - .map_err(SupergraphFetcherError::NetworkResponseError)?; - - Ok(Some(text)) - } -} - -impl SupergraphFetcher { - #[allow(clippy::too_many_arguments)] - pub fn try_new_async( - endpoint: String, - key: &str, - user_agent: String, - connect_timeout: Duration, - request_timeout: Duration, - accept_invalid_certs: bool, - retry_count: u32, - ) -> Result { - let (endpoint, headers, retry_policy) = prepare_client_config(endpoint, key, retry_count)?; - - let reqwest_agent = reqwest::Client::builder() - .danger_accept_invalid_certs(accept_invalid_certs) - .connect_timeout(connect_timeout) - .timeout(request_timeout) - .default_headers(headers) - .user_agent(user_agent) - .build() - .map_err(SupergraphFetcherError::FetcherCreationError)?; - let reqwest_client = ClientBuilder::new(reqwest_agent) - .with(RetryTransientMiddleware::new_with_policy(retry_policy)) - .build(); - - Ok(Self { - client: SupergraphFetcherAsyncOrSyncClient::Async { reqwest_client }, - endpoint, - etag: RwLock::new(None), - state: std::marker::PhantomData, - }) - } - pub async fn fetch_supergraph(&self) -> Result, SupergraphFetcherError> { - let reqwest_client = match &self.client { - SupergraphFetcherAsyncOrSyncClient::Async { reqwest_client } => reqwest_client, - _ => unreachable!(), - }; - let mut req = reqwest_client.get(&self.endpoint); - let etag = self.get_latest_etag()?; - if let Some(etag) = etag { - req = req.header(IF_NONE_MATCH, etag); - } - - let resp = req - .send() - .await - .map_err(SupergraphFetcherError::NetworkError)?; - - if resp.status().as_u16() == 304 { - return Ok(None); - } - - let etag = resp.headers().get("etag"); - self.update_latest_etag(etag)?; - - let text = resp - .text() - .await - .map_err(SupergraphFetcherError::NetworkResponseError)?; - - Ok(Some(text)) - } -} - -impl SupergraphFetcher { - fn get_latest_etag(&self) -> Result, SupergraphFetcherError> { - let guard: std::sync::RwLockReadGuard<'_, Option> = - self.etag.try_read().map_err(|e| { - SupergraphFetcherError::Lock(format!("Failed to read the etag record: {:?}", e)) - })?; - - Ok(guard.clone()) - } - - fn update_latest_etag(&self, etag: Option<&HeaderValue>) -> Result<(), SupergraphFetcherError> { - let mut guard: std::sync::RwLockWriteGuard<'_, Option> = - self.etag.try_write().map_err(|e| { - SupergraphFetcherError::Lock(format!("Failed to update the etag record: {:?}", e)) - })?; - - if let Some(etag_value) = etag { - *guard = Some(etag_value.clone()); - } else { - *guard = None; - } - - Ok(()) - } -} diff --git a/packages/libraries/sdk-rs/src/supergraph_fetcher/async_.rs b/packages/libraries/sdk-rs/src/supergraph_fetcher/async_.rs new file mode 100644 index 00000000000..55994d392fb --- /dev/null +++ b/packages/libraries/sdk-rs/src/supergraph_fetcher/async_.rs @@ -0,0 +1,149 @@ +use futures_util::TryFutureExt; +use reqwest::header::{HeaderValue, IF_NONE_MATCH}; +use reqwest_middleware::ClientBuilder; +use reqwest_retry::RetryTransientMiddleware; +use tokio::sync::RwLock; + +use crate::supergraph_fetcher::{ + builder::SupergraphFetcherBuilder, SupergraphFetcher, SupergraphFetcherAsyncOrSyncClient, + SupergraphFetcherError, +}; + +#[derive(Debug)] +pub struct SupergraphFetcherAsyncState; + +impl SupergraphFetcher { + pub async fn 
fetch_supergraph(&self) -> Result, SupergraphFetcherError> { + let (endpoints_with_circuit_breakers, reqwest_client) = match &self.client { + SupergraphFetcherAsyncOrSyncClient::Async { + endpoints_with_circuit_breakers, + reqwest_client, + } => (endpoints_with_circuit_breakers, reqwest_client), + _ => unreachable!("Called async fetcher on sync client"), + }; + let mut last_error: Option = None; + let mut last_resp = None; + for (endpoint, circuit_breaker) in endpoints_with_circuit_breakers { + let mut req = reqwest_client.get(endpoint); + let etag = self.get_latest_etag().await; + if let Some(etag) = etag { + req = req.header(IF_NONE_MATCH, etag); + } + let resp_fut = async { + let mut resp = req + .send() + .await + .map_err(SupergraphFetcherError::NetworkError); + // Server errors (5xx) are considered errors + if let Ok(ok_res) = resp { + resp = if ok_res.status().is_server_error() { + return Err(SupergraphFetcherError::NetworkError( + reqwest_middleware::Error::Middleware(anyhow::anyhow!( + "Server error: {}", + ok_res.status() + )), + )); + } else { + Ok(ok_res) + } + } + resp + }; + let resp = circuit_breaker + .call(resp_fut) + // Map recloser errors to SupergraphFetcherError + .map_err(|e| match e { + recloser::Error::Inner(e) => e, + recloser::Error::Rejected => SupergraphFetcherError::RejectedByCircuitBreaker, + }) + .await; + match resp { + Err(err) => { + last_error = Some(err); + continue; + } + Ok(resp) => { + last_resp = Some(resp); + break; + } + } + } + + if let Some(last_resp) = last_resp { + let etag = last_resp.headers().get("etag"); + self.update_latest_etag(etag).await; + let text = last_resp + .text() + .await + .map_err(SupergraphFetcherError::NetworkResponseError)?; + Ok(Some(text)) + } else if let Some(error) = last_error { + Err(error) + } else { + Ok(None) + } + } + async fn get_latest_etag(&self) -> Option { + let guard = self.etag.read().await; + + guard.clone() + } + async fn update_latest_etag(&self, etag: Option<&HeaderValue>) -> () { + let mut guard = self.etag.write().await; + + if let Some(etag_value) = etag { + *guard = Some(etag_value.clone()); + } else { + *guard = None; + } + } +} + +impl SupergraphFetcherBuilder { + /// Builds an asynchronous SupergraphFetcher + pub fn build_async( + self, + ) -> Result, SupergraphFetcherError> { + self.validate_endpoints()?; + + let headers = self.prepare_headers()?; + + let mut reqwest_agent = reqwest::Client::builder() + .danger_accept_invalid_certs(self.accept_invalid_certs) + .connect_timeout(self.connect_timeout) + .timeout(self.request_timeout) + .default_headers(headers); + + if let Some(user_agent) = self.user_agent { + reqwest_agent = reqwest_agent.user_agent(user_agent); + } + + let reqwest_agent = reqwest_agent + .build() + .map_err(SupergraphFetcherError::FetcherCreationError)?; + let reqwest_client = ClientBuilder::new(reqwest_agent) + .with(RetryTransientMiddleware::new_with_policy(self.retry_policy)) + .build(); + + Ok(SupergraphFetcher { + client: SupergraphFetcherAsyncOrSyncClient::Async { + reqwest_client, + endpoints_with_circuit_breakers: self + .endpoints + .into_iter() + .map(|endpoint| { + let circuit_breaker = self + .circuit_breaker + .clone() + .unwrap_or_default() + .build_async() + .map_err(SupergraphFetcherError::CircuitBreakerCreationError); + circuit_breaker.map(|cb| (endpoint, cb)) + }) + .collect::, _>>()?, + }, + etag: RwLock::new(None), + state: std::marker::PhantomData, + }) + } +} diff --git a/packages/libraries/sdk-rs/src/supergraph_fetcher/builder.rs 
b/packages/libraries/sdk-rs/src/supergraph_fetcher/builder.rs new file mode 100644 index 00000000000..adddc011232 --- /dev/null +++ b/packages/libraries/sdk-rs/src/supergraph_fetcher/builder.rs @@ -0,0 +1,135 @@ +use std::time::Duration; + +use reqwest::header::{HeaderMap, HeaderValue}; +use retry_policies::policies::ExponentialBackoff; + +use crate::{ + agent::usage_agent::non_empty_string, circuit_breaker::CircuitBreakerBuilder, + supergraph_fetcher::SupergraphFetcherError, +}; + +pub struct SupergraphFetcherBuilder { + pub(crate) endpoints: Vec, + pub(crate) key: Option, + pub(crate) user_agent: Option, + pub(crate) connect_timeout: Duration, + pub(crate) request_timeout: Duration, + pub(crate) accept_invalid_certs: bool, + pub(crate) retry_policy: ExponentialBackoff, + pub(crate) circuit_breaker: Option, +} + +impl Default for SupergraphFetcherBuilder { + fn default() -> Self { + Self { + endpoints: vec![], + key: None, + user_agent: None, + connect_timeout: Duration::from_secs(5), + request_timeout: Duration::from_secs(60), + accept_invalid_certs: false, + retry_policy: ExponentialBackoff::builder().build_with_max_retries(3), + circuit_breaker: None, + } + } +} + +impl SupergraphFetcherBuilder { + pub fn new() -> Self { + Self::default() + } + + /// The CDN endpoint from Hive Console target. + pub fn add_endpoint(mut self, endpoint: String) -> Self { + if let Some(mut endpoint) = non_empty_string(Some(endpoint)) { + if !endpoint.ends_with("/supergraph") { + if endpoint.ends_with("/") { + endpoint.push_str("supergraph"); + } else { + endpoint.push_str("/supergraph"); + } + } + self.endpoints.push(endpoint); + } + self + } + + /// The CDN Access Token with from the Hive Console target. + pub fn key(mut self, key: String) -> Self { + self.key = Some(key); + self + } + + /// User-Agent header to be sent with each request + pub fn user_agent(mut self, user_agent: String) -> Self { + self.user_agent = Some(user_agent); + self + } + + /// Connection timeout for the Hive Console CDN requests. + /// Default: 5 seconds + pub fn connect_timeout(mut self, timeout: Duration) -> Self { + self.connect_timeout = timeout; + self + } + + /// Request timeout for the Hive Console CDN requests. + /// Default: 60 seconds + pub fn request_timeout(mut self, timeout: Duration) -> Self { + self.request_timeout = timeout; + self + } + + pub fn accept_invalid_certs(mut self, accept: bool) -> Self { + self.accept_invalid_certs = accept; + self + } + + /// Policy for retrying failed requests. + /// + /// By default, an exponential backoff retry policy is used, with 10 attempts. + pub fn retry_policy(mut self, retry_policy: ExponentialBackoff) -> Self { + self.retry_policy = retry_policy; + self + } + + /// Maximum number of retries for failed requests. + /// + /// By default, an exponential backoff retry policy is used, with 10 attempts. 
+    pub fn max_retries(mut self, max_retries: u32) -> Self {
+        self.retry_policy = ExponentialBackoff::builder().build_with_max_retries(max_retries);
+        self
+    }
+
+    pub fn circuit_breaker(&mut self, builder: CircuitBreakerBuilder) -> &mut Self {
+        self.circuit_breaker = Some(builder);
+        self
+    }
+
+    pub(crate) fn validate_endpoints(&self) -> Result<(), SupergraphFetcherError> {
+        if self.endpoints.is_empty() {
+            return Err(SupergraphFetcherError::MissingConfigurationOption(
+                "endpoint".to_string(),
+            ));
+        }
+        Ok(())
+    }
+
+    pub(crate) fn prepare_headers(&self) -> Result<HeaderMap, SupergraphFetcherError> {
+        let key = match &self.key {
+            Some(key) => key,
+            None => {
+                return Err(SupergraphFetcherError::MissingConfigurationOption(
+                    "key".to_string(),
+                ))
+            }
+        };
+        let mut headers = HeaderMap::new();
+        let mut cdn_key_header =
+            HeaderValue::from_str(key).map_err(SupergraphFetcherError::InvalidKey)?;
+        // Mark the CDN key as sensitive so it is not exposed in debug output.
+        cdn_key_header.set_sensitive(true);
+        headers.insert("X-Hive-CDN-Key", cdn_key_header);
+
+        Ok(headers)
+    }
+}
diff --git a/packages/libraries/sdk-rs/src/supergraph_fetcher/mod.rs b/packages/libraries/sdk-rs/src/supergraph_fetcher/mod.rs
new file mode 100644
index 00000000000..dba4cddcdb2
--- /dev/null
+++ b/packages/libraries/sdk-rs/src/supergraph_fetcher/mod.rs
@@ -0,0 +1,78 @@
+use std::fmt::Display;
+use tokio::sync::RwLock;
+
+use crate::circuit_breaker::CircuitBreakerError;
+use crate::supergraph_fetcher::async_::SupergraphFetcherAsyncState;
+use recloser::AsyncRecloser;
+use recloser::Recloser;
+use reqwest::header::HeaderValue;
+use reqwest::header::InvalidHeaderValue;
+use reqwest_middleware::ClientWithMiddleware;
+use retry_policies::policies::ExponentialBackoff;
+
+pub mod async_;
+pub mod builder;
+pub mod sync;
+
+#[derive(Debug)]
+pub struct SupergraphFetcher<State> {
+    client: SupergraphFetcherAsyncOrSyncClient,
+    etag: RwLock<Option<HeaderValue>>,
+    state: std::marker::PhantomData<State>,
+}
+
+#[derive(Debug)]
+enum SupergraphFetcherAsyncOrSyncClient {
+    Async {
+        endpoints_with_circuit_breakers: Vec<(String, AsyncRecloser)>,
+        reqwest_client: ClientWithMiddleware,
+    },
+    Sync {
+        endpoints_with_circuit_breakers: Vec<(String, Recloser)>,
+        reqwest_client: reqwest::blocking::Client,
+        retry_policy: ExponentialBackoff,
+    },
+}
+
+// Doesn't matter which one we implement this for, both have the same builder
+impl SupergraphFetcher<SupergraphFetcherAsyncState> {
+    pub fn builder() -> builder::SupergraphFetcherBuilder {
+        builder::SupergraphFetcherBuilder::default()
+    }
+}
+
+pub enum SupergraphFetcherError {
+    FetcherCreationError(reqwest::Error),
+    NetworkError(reqwest_middleware::Error),
+    NetworkResponseError(reqwest::Error),
+    Lock(String),
+    InvalidKey(InvalidHeaderValue),
+    MissingConfigurationOption(String),
+    RejectedByCircuitBreaker,
+    CircuitBreakerCreationError(CircuitBreakerError),
+}
+
+impl Display for SupergraphFetcherError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            SupergraphFetcherError::FetcherCreationError(e) => {
+                write!(f, "Creating fetcher failed: {}", e)
+            }
+            SupergraphFetcherError::NetworkError(e) => write!(f, "Network error: {}", e),
+            SupergraphFetcherError::NetworkResponseError(e) => {
+                write!(f, "Network response error: {}", e)
+            }
+            SupergraphFetcherError::Lock(e) => write!(f, "Lock error: {}", e),
+            SupergraphFetcherError::InvalidKey(e) => write!(f, "Invalid CDN key: {}", e),
+            SupergraphFetcherError::MissingConfigurationOption(e) => {
+                write!(f, "Missing configuration option: {}", e)
+            }
+            SupergraphFetcherError::RejectedByCircuitBreaker => {
+                write!(f, "Request rejected by circuit breaker")
+            }
+            SupergraphFetcherError::CircuitBreakerCreationError(e) => {
+                write!(f, "Creating circuit breaker failed: {}", e)
+            }
+        }
+    }
+}
diff --git a/packages/libraries/sdk-rs/src/supergraph_fetcher/sync.rs b/packages/libraries/sdk-rs/src/supergraph_fetcher/sync.rs
new file mode 100644
index 00000000000..726b354d6d5
--- /dev/null
+++ b/packages/libraries/sdk-rs/src/supergraph_fetcher/sync.rs
@@ -0,0 +1,185 @@
+use std::time::SystemTime;
+
+use reqwest::header::{HeaderValue, IF_NONE_MATCH};
+use reqwest_retry::{RetryDecision, RetryPolicy};
+use tokio::sync::RwLock;
+
+use crate::supergraph_fetcher::{
+    builder::SupergraphFetcherBuilder, SupergraphFetcher, SupergraphFetcherAsyncOrSyncClient,
+    SupergraphFetcherError,
+};
+
+#[derive(Debug)]
+pub struct SupergraphFetcherSyncState;
+
+impl SupergraphFetcher<SupergraphFetcherSyncState> {
+    pub fn fetch_supergraph(&self) -> Result<Option<String>, SupergraphFetcherError> {
+        let (endpoints_with_circuit_breakers, reqwest_client, retry_policy) = match &self.client {
+            SupergraphFetcherAsyncOrSyncClient::Sync {
+                endpoints_with_circuit_breakers,
+                reqwest_client,
+                retry_policy,
+            } => (
+                endpoints_with_circuit_breakers,
+                reqwest_client,
+                retry_policy,
+            ),
+            _ => unreachable!("Called sync fetcher on async client"),
+        };
+        let mut last_error: Option<SupergraphFetcherError> = None;
+        let mut last_resp = None;
+        // Try each endpoint in order, falling back to the next one when a request
+        // fails or the endpoint's circuit breaker rejects the call.
+        for (endpoint, circuit_breaker) in endpoints_with_circuit_breakers {
+            let resp = {
+                circuit_breaker
+                    .call(|| {
+                        let request_start_time = SystemTime::now();
+                        // Implementing retry logic for sync client
+                        let mut n_past_retries = 0;
+                        loop {
+                            let mut req = reqwest_client.get(endpoint);
+                            let etag = self.get_latest_etag()?;
+                            if let Some(etag) = etag {
+                                req = req.header(IF_NONE_MATCH, etag);
+                            }
+                            let mut response = req.send().map_err(|err| {
+                                SupergraphFetcherError::NetworkError(
+                                    reqwest_middleware::Error::Reqwest(err),
+                                )
+                            });
+
+                            // Server errors (5xx) are considered retryable
+                            if let Ok(ok_res) = response {
+                                response = if ok_res.status().is_server_error() {
+                                    Err(SupergraphFetcherError::NetworkError(
+                                        reqwest_middleware::Error::Middleware(anyhow::anyhow!(
+                                            "Server error: {}",
+                                            ok_res.status()
+                                        )),
+                                    ))
+                                } else {
+                                    Ok(ok_res)
+                                }
+                            }
+
+                            match response {
+                                Ok(resp) => break Ok(resp),
+                                Err(e) => {
+                                    match retry_policy
+                                        .should_retry(request_start_time, n_past_retries)
+                                    {
+                                        RetryDecision::DoNotRetry => {
+                                            return Err(e);
+                                        }
+                                        RetryDecision::Retry { execute_after } => {
+                                            n_past_retries += 1;
+                                            // Sleep until the scheduled retry time before the next attempt.
+                                            if let Ok(duration) =
+                                                execute_after.duration_since(SystemTime::now())
+                                            {
+                                                std::thread::sleep(duration);
+                                            }
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                    })
+                    // Map recloser errors to SupergraphFetcherError
+                    .map_err(|e| match e {
+                        recloser::Error::Inner(e) => e,
+                        recloser::Error::Rejected => {
+                            SupergraphFetcherError::RejectedByCircuitBreaker
+                        }
+                    })
+            };
+            match resp {
+                Err(e) => {
+                    last_error = Some(e);
+                    continue;
+                }
+                Ok(resp) => {
+                    last_resp = Some(resp);
+                    break;
+                }
+            }
+        }
+
+        if let Some(last_resp) = last_resp {
+            // A 304 Not Modified response means the previously fetched supergraph is still current.
+            if last_resp.status().as_u16() == 304 {
+                return Ok(None);
+            }
+            self.update_latest_etag(last_resp.headers().get("etag"))?;
+            let text = last_resp
+                .text()
+                .map_err(SupergraphFetcherError::NetworkResponseError)?;
+            Ok(Some(text))
+        } else if let Some(error) = last_error {
+            Err(error)
+        } else {
+            Ok(None)
+        }
+    }
+    fn get_latest_etag(&self) -> Result<Option<HeaderValue>, SupergraphFetcherError> {
+        let guard = self.etag.try_read().map_err(|e| {
+            SupergraphFetcherError::Lock(format!("Failed to read the etag record: {:?}", e))
+        })?;
+
+        Ok(guard.clone())
+    }
+    fn update_latest_etag(&self, etag: Option<&HeaderValue>) -> Result<(), SupergraphFetcherError> {
+        let mut guard = self.etag.try_write().map_err(|e| {
+            SupergraphFetcherError::Lock(format!("Failed to update the etag record: {:?}", e))
+        })?;
+
+        if let Some(etag_value) = etag {
+            *guard = Some(etag_value.clone());
+        } else {
+            *guard = None;
+        }
+
+        Ok(())
+    }
+}
+
+impl SupergraphFetcherBuilder {
+    /// Builds a synchronous SupergraphFetcher
+    pub fn build_sync(
+        self,
+    ) -> Result<SupergraphFetcher<SupergraphFetcherSyncState>, SupergraphFetcherError> {
+        self.validate_endpoints()?;
+        let headers = self.prepare_headers()?;
+
+        let mut reqwest_client = reqwest::blocking::Client::builder()
+            .danger_accept_invalid_certs(self.accept_invalid_certs)
+            .connect_timeout(self.connect_timeout)
+            .timeout(self.request_timeout)
+            .default_headers(headers);
+
+        if let Some(user_agent) = &self.user_agent {
+            reqwest_client = reqwest_client.user_agent(user_agent);
+        }
+
+        let reqwest_client = reqwest_client
+            .build()
+            .map_err(SupergraphFetcherError::FetcherCreationError)?;
+        let fetcher: SupergraphFetcher<SupergraphFetcherSyncState> = SupergraphFetcher {
+            client: SupergraphFetcherAsyncOrSyncClient::Sync {
+                reqwest_client,
+                retry_policy: self.retry_policy,
+                endpoints_with_circuit_breakers: self
+                    .endpoints
+                    .into_iter()
+                    .map(|endpoint| {
+                        // Each endpoint gets its own circuit breaker instance.
+                        let circuit_breaker = self
+                            .circuit_breaker
+                            .clone()
+                            .unwrap_or_default()
+                            .build_sync()
+                            .map_err(SupergraphFetcherError::CircuitBreakerCreationError);
+                        circuit_breaker.map(|cb| (endpoint, cb))
+                    })
+                    .collect::<Result<Vec<_>, _>>()?,
+            },
+            etag: RwLock::new(None),
+            state: std::marker::PhantomData,
+        };
+        Ok(fetcher)
+    }
+}
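
For reference, a minimal usage sketch of the async builder API shown in this diff. The `use` paths, endpoint URLs, and access token below are placeholders/assumptions, not values from this PR; adjust them to the actual crate layout and your Hive Console target.

```rust
// Sketch only: module paths are assumed; the endpoints and token are placeholders.
use std::time::Duration;

use hive_console_sdk_rs::circuit_breaker::CircuitBreakerBuilder;
use hive_console_sdk_rs::supergraph_fetcher::builder::SupergraphFetcherBuilder;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Register a primary CDN endpoint and a mirror; the fetcher fails over to the
    // next endpoint when a request errors out or its circuit breaker rejects the call.
    let mut builder = SupergraphFetcherBuilder::new()
        .add_endpoint("https://cdn.example.com/artifacts/v1/TARGET_ID".to_string())
        .add_endpoint("https://cdn-mirror.example.com/artifacts/v1/TARGET_ID".to_string())
        .key("HIVE_CDN_ACCESS_TOKEN".to_string())
        .user_agent("MyGateway/1.0".to_string())
        .connect_timeout(Duration::from_secs(5))
        .request_timeout(Duration::from_secs(30))
        .max_retries(3);
    // `circuit_breaker` takes `&mut self`, so it is applied outside the consuming chain.
    builder.circuit_breaker(CircuitBreakerBuilder::default());

    let fetcher = builder.build_async().map_err(|e| e.to_string())?;
    // `Ok(None)` means nothing changed since the last fetch (304 Not Modified).
    if let Some(supergraph_sdl) = fetcher.fetch_supergraph().await.map_err(|e| e.to_string())? {
        println!("fetched supergraph ({} bytes)", supergraph_sdl.len());
    }
    Ok(())
}
```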