Skip to content

Commit dcb507d

Browse files
committed
feat: Improve OAuth error handling with custom AuthError type and better timeout management
1 parent ddeefb5 commit dcb507d

File tree

1 file changed

+53
-17
lines changed

1 file changed

+53
-17
lines changed

src/oauth.rs

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ const REDDIT_ANDROID_OAUTH_CLIENT_ID: &str = "ohXpoqrZYub1kg";
1515

1616
const AUTH_ENDPOINT: &str = "https://www.reddit.com";
1717

18+
const OAUTH_TIMEOUT: Duration = Duration::from_secs(5);
19+
1820
// Spoofed client for Android devices
1921
#[derive(Debug, Clone, Default)]
2022
pub struct Oauth {
@@ -32,24 +34,30 @@ impl Oauth {
3234
loop {
3335
let attempt = Self::new_with_timeout().await;
3436
match attempt {
35-
Ok(Some(oauth)) => {
37+
Ok(Ok(oauth)) => {
3638
info!("[✅] Successfully created OAuth client");
3739
return oauth;
3840
}
39-
Ok(None) => {
40-
error!("Failed to create OAuth client. Retrying in 5 seconds...");
41+
Ok(Err(e)) => {
42+
error!("Failed to create OAuth client: {}. Retrying in 5 seconds...", {
43+
match e {
44+
AuthError::Hyper(error) => error.to_string(),
45+
AuthError::SerdeDeserialize(error) => error.to_string(),
46+
AuthError::Field((value, error)) => format!("{error}\n{value}"),
47+
}
48+
});
4149
}
42-
Err(duration) => {
43-
error!("Failed to create OAuth client in {duration:?}. Retrying in 5 seconds...");
50+
Err(_) => {
51+
error!("Failed to create OAuth client before timeout. Retrying in 5 seconds...");
4452
}
4553
}
46-
tokio::time::sleep(Duration::from_secs(5)).await;
54+
tokio::time::sleep(OAUTH_TIMEOUT).await;
4755
}
4856
}
4957

50-
async fn new_with_timeout() -> Result<Option<Self>, Elapsed> {
58+
async fn new_with_timeout() -> Result<Result<Self, AuthError>, Elapsed> {
5159
let mut oauth = Self::default();
52-
timeout(Duration::from_secs(5), oauth.login()).await.map(|result| result.map(|_| oauth))
60+
timeout(OAUTH_TIMEOUT, oauth.login()).await.map(|result: Result<(), AuthError>| result.map(|_| oauth))
5361
}
5462

5563
pub(crate) fn default() -> Self {
@@ -66,7 +74,7 @@ impl Oauth {
6674
device,
6775
}
6876
}
69-
async fn login(&mut self) -> Option<()> {
77+
async fn login(&mut self) -> Result<(), AuthError> {
7078
// Construct URL for OAuth token
7179
let url = format!("{AUTH_ENDPOINT}/auth/v2/oauth/access-token/loid");
7280
let mut builder = Request::builder().method(Method::POST).uri(&url);
@@ -95,7 +103,7 @@ impl Oauth {
95103

96104
// Send request
97105
let client: &once_cell::sync::Lazy<client::Client<_, Body>> = &CLIENT;
98-
let resp = client.request(request).await.ok()?;
106+
let resp = client.request(request).await?;
99107

100108
trace!("Received response with status {} and length {:?}", resp.status(), resp.headers().get("content-length"));
101109
trace!("OAuth headers: {:#?}", resp.headers());
@@ -106,30 +114,58 @@ impl Oauth {
106114
// Not worried about the privacy implications, since this is randomly changed
107115
// and really only as privacy-concerning as the OAuth token itself.
108116
if let Some(header) = resp.headers().get("x-reddit-loid") {
109-
self.headers_map.insert("x-reddit-loid".to_owned(), header.to_str().ok()?.to_string());
117+
self.headers_map.insert("x-reddit-loid".to_owned(), header.to_str().unwrap().to_string());
110118
}
111119

112120
// Same with x-reddit-session
113121
if let Some(header) = resp.headers().get("x-reddit-session") {
114-
self.headers_map.insert("x-reddit-session".to_owned(), header.to_str().ok()?.to_string());
122+
self.headers_map.insert("x-reddit-session".to_owned(), header.to_str().unwrap().to_string());
115123
}
116124

117125
trace!("Serializing response...");
118126

119127
// Serialize response
120-
let body_bytes = hyper::body::to_bytes(resp.into_body()).await.ok()?;
121-
let json: serde_json::Value = serde_json::from_slice(&body_bytes).ok()?;
128+
let body_bytes = hyper::body::to_bytes(resp.into_body()).await?;
129+
let json: serde_json::Value = serde_json::from_slice(&body_bytes)?;
122130

123131
trace!("Accessing relevant fields...");
124132

125133
// Save token and expiry
126-
self.token = json.get("access_token")?.as_str()?.to_string();
127-
self.expires_in = json.get("expires_in")?.as_u64()?;
134+
self.token = json
135+
.get("access_token")
136+
.ok_or_else(|| AuthError::Field((json.clone(), "access_token")))?
137+
.as_str()
138+
.ok_or_else(|| AuthError::Field((json.clone(), "access_token: as_str")))?
139+
.to_string();
140+
self.expires_in = json
141+
.get("expires_in")
142+
.ok_or_else(|| AuthError::Field((json.clone(), "expires_in")))?
143+
.as_u64()
144+
.ok_or_else(|| AuthError::Field((json.clone(), "expires_in: as_u64")))?;
128145
self.headers_map.insert("Authorization".to_owned(), format!("Bearer {}", self.token));
129146

130147
info!("[✅] Success - Retrieved token \"{}...\", expires in {}", &self.token[..32], self.expires_in);
131148

132-
Some(())
149+
Ok(())
150+
}
151+
}
152+
153+
#[derive(Debug)]
154+
enum AuthError {
155+
Hyper(hyper::Error),
156+
SerdeDeserialize(serde_json::Error),
157+
Field((serde_json::Value, &'static str)),
158+
}
159+
160+
impl From<hyper::Error> for AuthError {
161+
fn from(err: hyper::Error) -> Self {
162+
AuthError::Hyper(err)
163+
}
164+
}
165+
166+
impl From<serde_json::Error> for AuthError {
167+
fn from(err: serde_json::Error) -> Self {
168+
AuthError::SerdeDeserialize(err)
133169
}
134170
}
135171

0 commit comments

Comments
 (0)