diff --git a/.clippy.toml b/.clippy.toml new file mode 100644 index 0000000..0906815 --- /dev/null +++ b/.clippy.toml @@ -0,0 +1,4 @@ +# High-performance Rust clippy configuration +cognitive-complexity-threshold = 25 +type-complexity-threshold = 600 +too-many-arguments-threshold = 10 diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..a778f02 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,11 @@ +# EditorConfig is awesome: https://editorconfig.org + +# top-most EditorConfig file +root = true + +# Unix-style newlines with a newline ending every file +[*] +end_of_line = lf +insert_final_newline = true +indent_style = space +indent_size = 4 \ No newline at end of file diff --git a/.github/workflows/chat-test.yml b/.github/workflows/chat-test.yml new file mode 100644 index 0000000..c969966 --- /dev/null +++ b/.github/workflows/chat-test.yml @@ -0,0 +1,136 @@ +name: chat ci + +on: + push: + branches: [ main, master ] + pull_request: + branches: [ main, master ] + +env: + CARGO_TERM_COLOR: always + RUST_VERSION: 1.91.1 + +jobs: + test-chat: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ env.RUST_VERSION }} + components: rustfmt, clippy + + - name: Cache cargo registry + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo- + + - name: Check formatting + run: | + cargo fmt --all -- --check + + - name: Check for clippy + run: | + cargo clippy --workspace --all-targets --all-features -- -D warnings + + - name: Run all tests + run: | + cargo test --all + + - name: Build the project + run: | + cargo build --workspace + + - name: Verify binaries can be executed + run: | + ./target/debug/chat-server --help + ./target/debug/chat-client --help + + - name: Start chat server + run: | + nohup cargo run --bin chat-server -- --host 127.0.0.1 --port 8080 --max-connections 100 > server.log 2>&1 & + SERVER_PID=$! + + # save the server PID for cleanup in later phase + echo "SERVER_PID=$SERVER_PID" >> $GITHUB_ENV + + # Wait a moment for the server to start (technically not required as it's rust ๐Ÿ˜‰) + sleep 3 + + # Check if server is running and listening on the port + if kill -0 $SERVER_PID 2>/dev/null; then + echo "Server started successfully with PID $SERVER_PID" + if lsof -i :8080 >/dev/null 2>&1; then + echo "Server is listening on port 8080" + else + echo "Server is not listening on port 8080" + cat server.log + exit 1 + fi + else + echo "Server failed to start" + cat server.log + exit 1 + fi + + - name: Test client connection and message sending + run: | + echo "Testing first client connection..." + { + echo "send Hello from test client" + echo "leave" + } | cargo run --bin chat-client -- --host 127.0.0.1 --port 8080 --username testuser1 + + echo "Testing second client connection..." + { + echo "send Hello from second client" + echo "leave" + } | cargo run --bin chat-client -- --host 127.0.0.1 --port 8080 --username testuser2 + + # Test 3: Test username collision detection + echo "Testing username collision..." + { + echo "send Message from duplicate user" + echo "leave" + } | cargo run --bin chat-client -- --host 127.0.0.1 --port 8080 --username testuser1 & + + sleep 1 + wait + + echo "Everything is fine" + + - name: Stop chat server + run: | + if [ ! -z "$SERVER_PID" ]; then + kill $SERVER_PID + wait $SERVER_PID 2>/dev/null || true + echo "server stopped" + + echo "=== Server Logs ===" + cat server.log + echo "=== End Server Logs ===" + fi + + - name: Verify no processes left running + run: | + if command -v lsof >/dev/null 2>&1; then + if lsof -i :8080 >/dev/null 2>&1; then + echo "Warning: Port 8080 still in use after server shutdown" + lsof -i :8080 + exit 1 + else + echo "Port 8080 is free" + fi + else + echo "lsof not available, skipping port check" + fi diff --git a/.rustfmt.toml b/.rustfmt.toml new file mode 100644 index 0000000..0b8dd82 --- /dev/null +++ b/.rustfmt.toml @@ -0,0 +1,14 @@ +# High-performance Rust formatting configuration +tab_spaces = 2 +hard_tabs = false + +# Code organization +reorder_imports = true +reorder_modules = true + +# Performance-oriented settings +force_explicit_abi = false + +# Readability +use_field_init_shorthand = true +use_try_shorthand = true diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..ce072c8 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "editor.tabSize": 2 +} diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..a85b264 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,27 @@ +[workspace] +resolver = "3" +members = ["chat-client", "chat-core", "chat-server"] +default-members = ["chat-client", "chat-core", "chat-server"] +exclude = ["developer-tools"] + +[workspace.package] +version = "0.1.0" +edition = "2024" +rust-version = "1.91.1" + +[workspace.dependencies] +tokio = { version = "1", features = ["full"] } +bincode = "2" +uuid = { version = "1", features = ["v4"] } +serde = { version = "1" } +dashmap = "6" +anyhow = "1" +thiserror = "2" +tracing = "0.1" +clap = "4" +tracing-subscriber = "0.3" +bytes = "1" + +# Testing and benchmarking +criterion = "0.7" +tokio-test = "0.4" diff --git a/README.md b/README.md index 8c4d4e1..7d52440 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,19 @@ # Simple Chat +## ๐ŸŽฅ Demo Video +[Watch Demo (vimeo)](https://vimeo.com/1141420619?fl=pl&fe=sh) + +## ๐Ÿชถ Features +- **Protocol**: Custom binary protocol over TCP for effciency and throghput avoided json over websockets +- **Least Cloning**: efficient broadcasting, optimized for least message cloning regardless of scale +- **Least Serialization effort**: O(1) effort for message serialization regardless of amount of receivers with smart caching +- **Resource efficient**: Easy on memory and cpu +- **Graceful shutdowns**: graceful shutdowns for server and client + +## ๐Ÿซฃ Out of Scope + +- Encryption in transit +- State machines implemented at protocol levels + ## Summary diff --git a/chat-client/Cargo.toml b/chat-client/Cargo.toml new file mode 100644 index 0000000..1d979d4 --- /dev/null +++ b/chat-client/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "chat-client" +version.workspace = true +edition.workspace = true +rust-version.workspace = true + +[dependencies] +chat-core = { path = "../chat-core" } +clap = { workspace = true, features = ["derive"] } +anyhow = { workspace = true } +tokio = { workspace = true } +tracing = { workspace = true } +tracing-subscriber = { workspace = true } diff --git a/chat-client/src/client.rs b/chat-client/src/client.rs new file mode 100644 index 0000000..509d05c --- /dev/null +++ b/chat-client/src/client.rs @@ -0,0 +1,598 @@ +use anyhow::{Result, anyhow}; +use chat_core::protocol::{ClientMessage, ServerMessage, encode_message}; +use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}; +use tokio::net::TcpStream; +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; +use tokio::select; +use tokio::signal; +use tracing::{error, info, warn}; + +pub struct ChatClient { + username: String, + original_username: String, + reader: BufReader, + writer: OwnedWriteHalf, + join_attempts: u8, +} + +impl ChatClient { + pub async fn new(host: &str, port: u16, username: &str) -> Result { + let (reader, writer) = TcpStream::connect(format!("{}:{}", host, port)) + .await? + .into_split(); + info!("Connected to server at {}:{}", host, port); + + let mut client = Self { + reader: BufReader::new(reader), + writer, + username: username.to_string(), + original_username: username.to_string(), + join_attempts: 0, + }; + + client.join_with_retry().await?; + + Ok(client) + } + + pub async fn run(&mut self) -> Result<()> { + let mut stdin = tokio::io::BufReader::new(tokio::io::stdin()); + let mut buffer = vec![0u8; 4096]; + let mut stdin_buffer = String::new(); + let mut shutdown_signal = Box::pin(signal::ctrl_c()); + let mut should_shutdown = false; + + println!(); + println!("Welcome to the chat! Commands: 'send ', 'leave'"); + info!("Client event loop started for user: {}", self.username); + + loop { + select! { + + _ = shutdown_signal.as_mut() => { + info!("Received shutdown signal (Ctrl+C)"); + should_shutdown = true; + break; + } + + + result = self.reader.read(&mut buffer) => { + if should_shutdown { + break; + } + match result { + Ok(0) => { + info!("Server disconnected unexpectedly"); + println!(); + println!("โŒโŒ: Server disconnected!โŒโŒ press ENTER โ†ฉ๏ธ to exit!"); + println!(); + break; + } + Ok(n) => { + if let Err(e) = crate::message_handler::handle_server_message(&buffer[..n]).await { + error!("Error handling server message: {}", e); + } + } + Err(e) => { + error!("Error reading from server: {}", e); + break; + } + } + } + + + result = stdin.read_line(&mut stdin_buffer) => { + if should_shutdown { + break; + } + match result { + Ok(0) => { + info!("Stdin closed unexpectedly"); + println!("Stdin closed"); + break; + } + Ok(_) => { + let input = stdin_buffer.trim(); + if let Err(e) = self.handle_user_input(input).await { + if e.to_string().contains("Client requested leave") { + info!("User requested graceful shutdown via 'leave' command"); + should_shutdown = true; + break; + } else { + error!("Error handling user input: {}", e); + } + } + stdin_buffer.clear(); + } + Err(e) => { + error!("Error reading from stdin: {}", e); + break; + } + } + } + } + } + + if should_shutdown { + info!("Initiating graceful shutdown for user: {}", self.username); + if let Err(e) = self.graceful_leave().await { + error!("Failed to send leave message during shutdown: {}", e); + } else { + info!( + "Successfully sent leave message for user: {}", + self.username + ); + } + } else { + warn!( + "Client exiting without graceful shutdown for user: {}", + self.username + ); + } + + Ok(()) + } + + async fn handle_user_input(&mut self, input: &str) -> Result<()> { + let parts: Vec<&str> = input.splitn(2, ' ').collect(); + + match parts[0] { + "send" if parts.len() == 1 => { + println!("No message provided. Use 'send '"); + } + "send" if parts.len() > 1 => { + info!("User {} sending message: {}", self.username, parts[1]); + let message = ClientMessage::Message { + username: self.username.clone(), + content: parts[1].to_string(), + }; + self.send_client_message(message).await?; + } + "leave" => { + info!("User {} requested leave via command", self.username); + println!("Leaving chat..."); + let message = ClientMessage::Leave { + username: self.username.clone(), + }; + self.send_client_message(message).await?; + return Err(anyhow!("Client requested leave")); + } + _ => { + println!("Unknown command. Use 'send ' or 'leave'"); + } + } + + Ok(()) + } + + pub async fn send_client_message(&mut self, message: ClientMessage) -> Result<()> { + let frame = encode_message(&message)?; + self.writer.write_all(&frame).await?; + Ok(()) + } + + pub async fn graceful_leave(&mut self) -> Result<()> { + info!("Sending leave message for user: {}", self.username); + let message = ClientMessage::Leave { + username: self.username.clone(), + }; + self.send_client_message(message).await?; + + self.writer.flush().await?; + + info!( + "Leave message sent successfully for user: {}", + self.username + ); + + match self.writer.shutdown().await { + Ok(()) => info!("Writer shutdown successful for user: {}", self.username), + Err(e) => warn!("Writer shutdown failed for user: {}: {}", self.username, e), + } + + Ok(()) + } + + /// Join the chat with retry + async fn join_with_retry(&mut self) -> Result<()> { + const MAX_JOIN_ATTEMPTS: u8 = 3; + + loop { + self.join_attempts += 1; + + info!( + "Attempting to join with username: {} (attempt {}/{})", + self.username, self.join_attempts, MAX_JOIN_ATTEMPTS + ); + + let message = ClientMessage::Join { + username: self.username.clone(), + }; + self.send_client_message(message).await?; + + let mut buffer = vec![0u8; 4096]; + + let reader = self.reader.get_mut(); + match chat_core::transport_layer::read_message_from_stream(reader, &mut buffer).await { + Ok(ServerMessage::Success { message }) => { + info!("Successfully joined chat: {}", message); + println!("โœ… {}", message); + return Ok(()); + } + Ok(ServerMessage::UserNameAlreadyTaken { username }) => { + warn!("username '{}' is already taken", username); + + if self.join_attempts >= MAX_JOIN_ATTEMPTS { + return Err(anyhow!( + "Failed to join after {} attempts. username '{}' is taken and no alternatives were accepted.", + MAX_JOIN_ATTEMPTS, + username + )); + } + + let suggested_username = crate::username_handler::generate_alternative_username( + &self.username, + self.join_attempts, + ); + match crate::username_handler::prompt_username_selection_loop( + &self.original_username, + &suggested_username, + ) + .await? + { + Some(new_username) => { + self.username = new_username; + } + None => { + return Err(anyhow!("User cancelled join process")); + } + } + } + Ok(ServerMessage::Error { reason }) => { + return Err(anyhow!("Server error during join: {}", reason)); + } + Ok(_) => { + return Err(anyhow!("Unexpected server response during join process")); + } + Err(e) => { + return Err(anyhow!("Application error: {}", e)); + } + } + } + } +} + +#[cfg(test)] +mod tests { + + use chat_core::protocol::{ClientMessage, ServerMessage, decode_message, encode_message}; + + #[test] + fn test_client_message_username_extraction() { + let join_message = ClientMessage::Join { + username: "alice".to_string(), + }; + assert_eq!(join_message.username(), Some("alice")); + + let leave_message = ClientMessage::Leave { + username: "bob".to_string(), + }; + assert_eq!(leave_message.username(), Some("bob")); + + let message_message = ClientMessage::Message { + username: "charlie".to_string(), + content: "Hello".to_string(), + }; + assert_eq!(message_message.username(), Some("charlie")); + } + + #[tokio::test] + async fn test_client_message_serialization() { + let messages = vec![ + ClientMessage::Join { + username: "alice".to_string(), + }, + ClientMessage::Leave { + username: "bob".to_string(), + }, + ClientMessage::Message { + username: "charlie".to_string(), + content: "Hello, world!".to_string(), + }, + ]; + + for message in messages { + let encoded = encode_message(&message).expect("Failed to encode client message"); + assert!( + encoded.len() > 4, + "Encoded message should have length prefix and payload" + ); + + let payload = &encoded[4..]; + + let decoded: ClientMessage = + decode_message(payload).expect("Failed to decode client message"); + + match (&message, decoded) { + ( + ClientMessage::Join { + username: orig_user, + }, + ClientMessage::Join { username: dec_user }, + ) => { + assert_eq!(orig_user, &dec_user); + } + ( + ClientMessage::Leave { + username: orig_user, + }, + ClientMessage::Leave { username: dec_user }, + ) => { + assert_eq!(orig_user, &dec_user); + } + ( + ClientMessage::Message { + username: orig_user, + content: orig_content, + }, + ClientMessage::Message { + username: dec_user, + content: dec_content, + }, + ) => { + assert_eq!(orig_user, &dec_user); + assert_eq!(orig_content, &dec_content); + } + _ => panic!("Message type mismatch after encoding/decoding"), + } + } + } + + #[tokio::test] + async fn test_server_message_handling() { + let test_cases = vec![ + ServerMessage::Message { + username: "alice".to_string(), + content: "Hello everyone!".to_string(), + }, + ServerMessage::Error { + reason: "Username already taken".to_string(), + }, + ServerMessage::Success { + message: "Welcome to the chat!".to_string(), + }, + ServerMessage::UserJoined { + username: "newuser".to_string(), + }, + ServerMessage::UserLeft { + username: "leavinguser".to_string(), + }, + ]; + + for message in test_cases { + let encoded = encode_message(&message).expect("Failed to encode server message"); + assert!( + encoded.len() > 4, + "Encoded message should have length prefix and payload" + ); + + let payload = &encoded[4..]; + + let decoded: ServerMessage = + decode_message(payload).expect("Failed to decode server message"); + + match (&message, decoded) { + (ServerMessage::Message { .. }, ServerMessage::Message { .. }) => {} + (ServerMessage::Error { .. }, ServerMessage::Error { .. }) => {} + (ServerMessage::Success { .. }, ServerMessage::Success { .. }) => {} + (ServerMessage::UserJoined { .. }, ServerMessage::UserJoined { .. }) => {} + (ServerMessage::UserLeft { .. }, ServerMessage::UserLeft { .. }) => {} + _ => panic!("Server message type mismatch after encoding/decoding"), + } + } + } + + #[test] + fn test_message_length_prefix() { + let test_message = ClientMessage::Message { + username: "test".to_string(), + content: "Hello".to_string(), + }; + + let encoded = encode_message(&test_message).expect("Failed to encode test message"); + + assert!( + encoded.len() > 4, + "Message should have length prefix and payload" + ); + + let length_bytes = &encoded[0..4]; + let length = u32::from_be_bytes([ + length_bytes[0], + length_bytes[1], + length_bytes[2], + length_bytes[3], + ]) as usize; + + assert_eq!( + length, + encoded.len() - 4, + "Length prefix should match payload size" + ); + } + + #[test] + fn test_large_message_serialization() { + let large_content = "x".repeat(10000); + let test_message = ClientMessage::Message { + username: "largeuser".to_string(), + content: large_content, + }; + + let encoded = encode_message(&test_message).expect("Failed to encode large message"); + assert!( + encoded.len() > 10000, + "Large message should result in substantial encoded size" + ); + + let payload = &encoded[4..]; + + let decoded: ClientMessage = decode_message(payload).expect("Failed to decode large message"); + + match decoded { + ClientMessage::Message { username, content } => { + assert_eq!(username, "largeuser"); + assert_eq!(content.len(), 10000); + } + _ => panic!("Expected large ClientMessage::Message"), + } + } + + #[test] + fn test_empty_message_serialization() { + let test_message = ClientMessage::Message { + username: "".to_string(), + content: "".to_string(), + }; + + let encoded = encode_message(&test_message).expect("Failed to encode empty message"); + assert!( + encoded.len() >= 4, + "Message should have at least length prefix" + ); + + let payload = &encoded[4..]; + + let decoded: ClientMessage = decode_message(payload).expect("Failed to decode empty message"); + + match decoded { + ClientMessage::Message { username, content } => { + assert_eq!(username, ""); + assert_eq!(content, ""); + } + _ => panic!("Expected empty ClientMessage::Message"), + } + } + + #[test] + fn test_message_equality() { + let msg1 = ClientMessage::Message { + username: "alice".to_string(), + content: "Hello".to_string(), + }; + + let msg2 = ClientMessage::Message { + username: "alice".to_string(), + content: "Hello".to_string(), + }; + + let encoded1 = encode_message(&msg1).expect("Failed to encode first message"); + let encoded2 = encode_message(&msg2).expect("Failed to encode second message"); + + assert_eq!( + encoded1, encoded2, + "Identical messages should produce identical encoded data" + ); + } + + #[test] + fn test_username_validation_edge_cases() { + let valid_usernames: Vec = vec![ + "alice".to_string(), + "user123".to_string(), + "test_user".to_string(), + "a".repeat(50), + ]; + + let edge_case_usernames: Vec = vec!["".to_string(), "a".repeat(1000)]; + + for username in valid_usernames { + let message = ClientMessage::Join { username }; + let encoded = encode_message(&message).expect("Should encode valid username"); + assert!(encoded.len() > 4); + } + + for username in edge_case_usernames { + let message = ClientMessage::Join { username }; + let encoded = encode_message(&message).expect("Should encode edge case username"); + assert!(encoded.len() > 4); + } + } + + #[test] + fn test_content_validation_edge_cases() { + let valid_contents: Vec = vec![ + "Hello world".to_string(), + "".to_string(), + "a".repeat(1000), + "Special chars: !@#$%^&*()".to_string(), + "Unicode: ๐Ÿš€โœจ๐ŸŽ‰".to_string(), + ]; + + for content in valid_contents { + let message = ClientMessage::Message { + username: "testuser".to_string(), + content: content.clone(), + }; + let encoded = encode_message(&message).expect("Should encode valid content"); + assert!(encoded.len() > 4); + + let payload = &encoded[4..]; + let decoded: ClientMessage = decode_message(payload).expect("Should decode message"); + + match decoded { + ClientMessage::Message { + content: decoded_content, + .. + } => { + assert_eq!(content, decoded_content); + } + _ => panic!("Expected Message variant"), + } + } + } + + #[test] + fn test_command_parsing_logic() { + let test_cases = vec![ + ("send Hello world", Some(("send", "Hello world"))), + ("send ", Some(("send", ""))), + ("leave", None), + ("unknown", None), + ("", None), + ("send", None), + ]; + + for (input, expected) in test_cases { + let parts: Vec<&str> = input.splitn(2, ' ').collect(); + + match expected { + Some((command, content)) => { + assert_eq!( + parts.len(), + 2, + "Input '{}' should split into 2 parts", + input + ); + assert_eq!( + parts[0], command, + "First part should be command for input '{}'", + input + ); + assert_eq!( + parts[1], content, + "Second part should be content for input '{}'", + input + ); + } + None => { + assert!( + !parts.is_empty(), + "Input '{}' should have at least 1 part", + input + ); + } + } + } + } +} diff --git a/chat-client/src/lib.rs b/chat-client/src/lib.rs new file mode 100644 index 0000000..0fef2b7 --- /dev/null +++ b/chat-client/src/lib.rs @@ -0,0 +1,3 @@ +pub mod client; +pub mod message_handler; +pub mod username_handler; diff --git a/chat-client/src/main.rs b/chat-client/src/main.rs new file mode 100644 index 0000000..e5f39c5 --- /dev/null +++ b/chat-client/src/main.rs @@ -0,0 +1,33 @@ +use anyhow::Result; +use chat_client::client::ChatClient; +use chat_core::utils::validate_username; +use clap::Parser; +use tracing::Level; + +#[derive(Parser, Debug)] +#[command(version, about, long_about = None)] +struct Args { + #[arg(long, default_value = "127.0.0.1")] + host: String, + + #[arg(long, default_value_t = 8080)] + port: u16, + + #[arg(short, long)] + username: String, +} + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::fmt::fmt() + .with_max_level(Level::ERROR) + .init(); + + let args = Args::parse(); + + validate_username(&args.username)?; + let mut client = ChatClient::new(&args.host, args.port, &args.username).await?; + client.run().await?; + + Ok(()) +} diff --git a/chat-client/src/message_handler.rs b/chat-client/src/message_handler.rs new file mode 100644 index 0000000..a7b894c --- /dev/null +++ b/chat-client/src/message_handler.rs @@ -0,0 +1,125 @@ +use anyhow::anyhow; +use chat_core::protocol::{LENGTH_PREFIX, ServerMessage, decode_message}; + +pub async fn handle_server_message(data: &[u8]) -> anyhow::Result<()> { + if data.len() < LENGTH_PREFIX { + return Err(anyhow!("Message too short")); + } + + let length = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize; + if data.len() < LENGTH_PREFIX + length { + return Err(anyhow!("Incomplete message")); + } + + let message = decode_message(&data[LENGTH_PREFIX..LENGTH_PREFIX + length]) + .map_err(|e| anyhow!("Failed to decode server message: {}", e))?; + + match message { + ServerMessage::Message { username, content } => { + println!("๐Ÿ—จ๏ธ: {}: {}", username, content); + } + ServerMessage::Error { reason } => { + println!("โŒ: {}", reason); + } + ServerMessage::UserNameAlreadyTaken { username } => { + println!("โŒ: username `{}` not available", username); + } + ServerMessage::Success { message } => { + println!("๐Ÿ’: {}", message); + } + ServerMessage::UserJoined { username } => { + println!("๐Ÿ“ข: `{}` joined the chat", username); + } + ServerMessage::UserLeft { username } => { + println!("๐Ÿ“ข: `{}` left the chat", username); + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use chat_core::protocol::{ServerMessage, encode_message}; + + #[tokio::test] + async fn test_handle_server_message_message() { + let message = ServerMessage::Message { + username: "alice".to_string(), + content: "Hello everyone!".to_string(), + }; + + let encoded = encode_message(&message).unwrap(); + let result = handle_server_message(&encoded).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_handle_server_message_error() { + let message = ServerMessage::Error { + reason: "Test error".to_string(), + }; + + let encoded = encode_message(&message).unwrap(); + let result = handle_server_message(&encoded).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_handle_server_message_success() { + let message = ServerMessage::Success { + message: "Welcome to chat!".to_string(), + }; + + let encoded = encode_message(&message).unwrap(); + let result = handle_server_message(&encoded).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_handle_server_message_user_joined() { + let message = ServerMessage::UserJoined { + username: "bob".to_string(), + }; + + let encoded = encode_message(&message).unwrap(); + let result = handle_server_message(&encoded).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_handle_server_message_user_left() { + let message = ServerMessage::UserLeft { + username: "charlie".to_string(), + }; + + let encoded = encode_message(&message).unwrap(); + let result = handle_server_message(&encoded).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_handle_server_message_username_not_found() { + let message = ServerMessage::UserNameAlreadyTaken { + username: "takenuser".to_string(), + }; + + let encoded = encode_message(&message).unwrap(); + let result = handle_server_message(&encoded).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_handle_server_message_invalid_data() { + let invalid_data = vec![0xFF, 0xFF, 0xFF, 0xFF]; + let result = handle_server_message(&invalid_data).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_handle_server_message_short_data() { + let short_data = vec![0x00, 0x00]; + let result = handle_server_message(&short_data).await; + assert!(result.is_err()); + } +} diff --git a/chat-client/src/username_handler.rs b/chat-client/src/username_handler.rs new file mode 100644 index 0000000..2c5a885 --- /dev/null +++ b/chat-client/src/username_handler.rs @@ -0,0 +1,97 @@ +use anyhow::anyhow; + +/// Generate alternative usernames by appending numbers +pub fn generate_alternative_username(username: &str, attempts: u8) -> String { + format!("{}_{:02}", username, attempts + 1) +} + +/// Read user input with a prompt +async fn read_user_input(prompt: &str) -> anyhow::Result { + print!("{}", prompt); + std::io::Write::flush(&mut std::io::stdout())?; + + let mut input = String::new(); + std::io::stdin() + .read_line(&mut input) + .map_err(|e| anyhow!("Failed to read input: {}", e))?; + + Ok(input.trim().to_string()) +} + +pub async fn prompt_username_selection_loop( + taken: &str, + suggested: &str, +) -> anyhow::Result> { + loop { + println!(); + println!("Username '{}' is already taken.", taken); + println!("Suggested alternative: '{}'", suggested); + println!(); + println!("Options:"); + println!(" 1. Use suggested username: '{}'", suggested); + println!(" 2. Enter a custom username"); + println!(" 3. Cancel and exit"); + print!("Please choose (1-3) or enter custom username directly: "); + std::io::Write::flush(&mut std::io::stdout())?; + + let choice = read_user_input("").await?; + + match choice.as_str() { + "1" => { + println!("Using suggested username: '{}'", suggested); + return Ok(Some(suggested.to_string())); + } + "2" => { + let custom = read_user_input("Enter your desired username: ").await?; + if !custom.is_empty() { + println!("Using custom username: '{}'", custom); + return Ok(Some(custom)); + } + println!("Username cannot be empty. Please try again."); + } + "3" => { + println!("Join cancelled by user."); + return Ok(None); + } + _ if !choice.is_empty() => { + println!("Using custom username: '{}'", choice); + return Ok(Some(choice)); + } + _ => { + println!("Invalid choice. Please try again."); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generate_alternative_username() { + let username = "testuser"; + let alternative = generate_alternative_username(username, 0); + assert_eq!(alternative, "testuser_01"); + + let alternative = generate_alternative_username(username, 1); + assert_eq!(alternative, "testuser_02"); + + let alternative = generate_alternative_username(username, 9); + assert_eq!(alternative, "testuser_10"); + } + + #[test] + fn test_generate_alternative_username_with_special_chars() { + let username = "user@domain"; + let alternative = generate_alternative_username(username, 0); + assert_eq!(alternative, "user@domain_01"); + } + + #[test] + fn test_generate_alternative_username_empty() { + let username = ""; + let alternative = generate_alternative_username(username, 0); + assert_eq!(alternative, "_01"); + } +} diff --git a/chat-client/tests/integration_client.rs b/chat-client/tests/integration_client.rs new file mode 100644 index 0000000..a62bf31 --- /dev/null +++ b/chat-client/tests/integration_client.rs @@ -0,0 +1,351 @@ +use anyhow::Result; +use chat_client::client::ChatClient; +use chat_core::protocol::{ClientMessage, ServerMessage, decode_message, encode_message}; +use std::time::Duration; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::time::sleep; + +#[tokio::test] +async fn test_client_connection_and_basic_communication() -> Result<()> { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (mut reader, mut writer) = stream.into_split(); + + let mut buffer = vec![0u8; 1024]; + let _n = reader.read(&mut buffer).await.unwrap(); + + let length = u32::from_be_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]) as usize; + let payload = &buffer[4..4 + length]; + let message: ClientMessage = decode_message(payload).unwrap(); + + match message { + ClientMessage::Join { username } => { + assert_eq!(username, "testuser"); + + let success_msg = ServerMessage::Success { + message: "Welcome to the chat!".to_string(), + }; + let response = encode_message(&success_msg).unwrap(); + writer.write_all(&response).await.unwrap(); + writer.flush().await.unwrap(); + } + _ => panic!("Expected join message"), + } + + loop { + match reader.read(&mut buffer).await { + Ok(0) => { + break; + } + Ok(n) => { + if n < 4 { + continue; + } + + let length = u32::from_be_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]) as usize; + if n < 4 + length { + continue; + } + + let payload = &buffer[4..4 + length]; + let message: ClientMessage = decode_message(payload).unwrap(); + + match message { + ClientMessage::Message { username, content } => { + assert_eq!(username, "testuser"); + assert_eq!(content, "Hello, server!"); + + let server_msg = ServerMessage::Message { + username: "server".to_string(), + content: "Echo: Hello, server!".to_string(), + }; + let response = encode_message(&server_msg).unwrap(); + writer.write_all(&response).await.unwrap(); + writer.flush().await.unwrap(); + } + ClientMessage::Leave { username } => { + assert_eq!(username, "testuser"); + break; + } + _ => {} + } + } + Err(_) => { + break; + } + } + } + }); + + sleep(Duration::from_millis(100)).await; + + let mut client = ChatClient::new("127.0.0.1", addr.port(), "testuser").await?; + + let test_message = ClientMessage::Message { + username: "testuser".to_string(), + content: "Hello, server!".to_string(), + }; + client.send_client_message(test_message).await?; + + client.graceful_leave().await?; + + server_handle.await.unwrap(); + + Ok(()) +} + +#[tokio::test] +async fn test_client_message_handling_and_parsing() -> Result<()> { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (mut reader, mut writer) = stream.into_split(); + + let mut buffer = vec![0u8; 1024]; + let _n = reader.read(&mut buffer).await.unwrap(); + + let success_msg = ServerMessage::Success { + message: "Welcome!".to_string(), + }; + let response = encode_message(&success_msg).unwrap(); + writer.write_all(&response).await.unwrap(); + + let test_messages = vec![ + ServerMessage::Message { + username: "alice".to_string(), + content: "Hello everyone!".to_string(), + }, + ServerMessage::UserJoined { + username: "bob".to_string(), + }, + ServerMessage::UserLeft { + username: "alice".to_string(), + }, + ServerMessage::Error { + reason: "Invalid command".to_string(), + }, + ]; + + for msg in test_messages { + let response = encode_message(&msg).unwrap(); + writer.write_all(&response).await.unwrap(); + writer.flush().await.unwrap(); + sleep(Duration::from_millis(50)).await; + } + + let _n = reader.read(&mut buffer).await.unwrap(); + }); + + sleep(Duration::from_millis(100)).await; + + let mut client = ChatClient::new("127.0.0.1", addr.port(), "testuser2").await?; + + sleep(Duration::from_millis(300)).await; + + client.graceful_leave().await?; + + server_handle.await.unwrap(); + + Ok(()) +} + +#[tokio::test] +async fn test_client_error_handling_and_edge_cases() -> Result<()> { + let connection_result = ChatClient::new("127.0.0.1", 9999, "testuser").await; + assert!( + connection_result.is_err(), + "Should fail to connect to non-existent server" + ); + + let listener = TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + + drop(stream); + }); + + sleep(Duration::from_millis(100)).await; + + let client_result = ChatClient::new("127.0.0.1", addr.port(), "testuser").await; + + if let Ok(mut client) = client_result { + let test_message = ClientMessage::Message { + username: "testuser".to_string(), + content: "Test message to closed connection".to_string(), + }; + + let send_result = client.send_client_message(test_message).await; + + match send_result { + Ok(_) => { + println!("โœ“ Message sent successfully (connection was still open when message was sent)"); + } + Err(e) => { + println!( + "Expected I/O error when sending to closed connection: {}", + e + ); + + let error_msg = e.to_string().to_lowercase(); + assert!( + error_msg.contains("connection") + || error_msg.contains("broken") + || error_msg.contains("reset") + || error_msg.contains("pipe") + || error_msg.contains("eof"), + "Expected network-related error, got: {}", + e + ); + } + } + + let leave_result = client.graceful_leave().await; + match leave_result { + Ok(_) => { + println!("โœ“ Graceful leave succeeded (connection was handled properly)"); + } + Err(e) => { + println!("โœ“ Graceful leave encountered expected error: {}", e); + } + } + } + + server_handle.await.unwrap(); + + Ok(()) +} + +#[tokio::test] +async fn test_client_protocol_compliance_and_message_serialization() -> Result<()> { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (mut reader, mut writer) = stream.into_split(); + + let mut buffer = vec![0u8; 4096]; + + let mut length_bytes = [0u8; 4]; + reader.read_exact(&mut length_bytes).await.unwrap(); + let length = u32::from_be_bytes(length_bytes) as usize; + + if buffer.len() < length { + buffer.resize(length, 0); + } + reader.read_exact(&mut buffer[..length]).await.unwrap(); + + let join_message: ClientMessage = decode_message(&buffer[..length]).unwrap(); + + match join_message { + ClientMessage::Join { username } => { + assert_eq!(username, "protocol_test"); + } + _ => panic!("Expected join message"), + } + + let success_msg = ServerMessage::Success { + message: "Join successful".to_string(), + }; + let response = encode_message(&success_msg).unwrap(); + writer.write_all(&response).await.unwrap(); + + for i in 0..4 { + let mut length_bytes = [0u8; 4]; + reader.read_exact(&mut length_bytes).await.unwrap(); + let length = u32::from_be_bytes(length_bytes) as usize; + + if buffer.len() < length { + buffer.resize(length, 0); + } + reader.read_exact(&mut buffer[..length]).await.unwrap(); + + let message: ClientMessage = decode_message(&buffer[..length]).unwrap(); + + match message { + ClientMessage::Message { username, content } => { + assert_eq!(username, "protocol_test"); + assert_eq!(content, format!("Test message {}", i + 1)); + } + ClientMessage::Leave { username } => { + assert_eq!(username, "protocol_test"); + + assert_eq!(i, 3, "Leave message should be the last message"); + } + _ => panic!("Unexpected message type"), + } + } + }); + + sleep(Duration::from_millis(100)).await; + + let mut client = ChatClient::new("127.0.0.1", addr.port(), "protocol_test").await?; + + for i in 0..3 { + let message = ClientMessage::Message { + username: "protocol_test".to_string(), + content: format!("Test message {}", i + 1), + }; + client.send_client_message(message).await?; + } + + let leave_message = ClientMessage::Leave { + username: "protocol_test".to_string(), + }; + client.send_client_message(leave_message).await?; + + server_handle.await.unwrap(); + + Ok(()) +} + +#[tokio::test] +async fn test_client_connection_cleanup() -> Result<()> { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (mut reader, mut writer) = stream.into_split(); + + let mut buffer = vec![0u8; 1024]; + + let _n = reader.read(&mut buffer).await.unwrap(); + + let success_msg = ServerMessage::Success { + message: "Welcome to the chat!".to_string(), + }; + let response = encode_message(&success_msg).unwrap(); + writer.write_all(&response).await.unwrap(); + writer.flush().await.unwrap(); + + let _n = reader.read(&mut buffer).await.unwrap_or(0); + + let _ = writer.shutdown().await; + }); + + sleep(Duration::from_millis(100)).await; + + let mut client = ChatClient::new("127.0.0.1", addr.port(), "cleanup_test").await?; + + let test_message = ClientMessage::Message { + username: "cleanup_test".to_string(), + content: "Testing connection cleanup".to_string(), + }; + + client.send_client_message(test_message).await?; + + client.graceful_leave().await?; + + server_handle.await.unwrap(); + + Ok(()) +} diff --git a/chat-core/Cargo.toml b/chat-core/Cargo.toml new file mode 100644 index 0000000..e5fb8ba --- /dev/null +++ b/chat-core/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "chat-core" +version.workspace = true +edition.workspace = true +rust-version.workspace = true + +[dependencies] +dashmap = { workspace = true } +tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } +bytes = {workspace = true} +thiserror = { workspace = true } +bincode = { workspace = true} +serde = { workspace = true } +tracing = {workspace = true} diff --git a/chat-core/src/error.rs b/chat-core/src/error.rs new file mode 100644 index 0000000..03a84b0 --- /dev/null +++ b/chat-core/src/error.rs @@ -0,0 +1,216 @@ +use std::io; + +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum ApplicationError { + #[error("Encoding error: {0}")] + Encoding(#[from] bincode::error::EncodeError), + #[error("Decoding error: {0}")] + Decoding(#[from] bincode::error::DecodeError), + #[error("Stream I/O error: {0}")] + StreamIoError(#[from] io::Error), + #[error("Client stream closed")] + ClientReadStreamClosed, + #[error("Incomplete length prefix")] + IncompleteLengthPrefix, + #[error("Incomplete payload")] + IncompletePyaload, + #[error("username not found")] + UsernameNotFound, + #[error("Message too large: {size} bytes exceeds maximum of {max_size} bytes")] + MessageTooLarge { size: usize, max_size: usize }, + #[error("Invalid ussername: {0}")] + InvalidUsername(String), + #[error("Connection error: {0}")] + ConnectionError(String), + #[error("Broadcast error: {0}")] + BroadcastError(String), +} + +impl ApplicationError { + pub fn message_too_large(size: usize, max_size: usize) -> Self { + ApplicationError::MessageTooLarge { size, max_size } + } + pub fn invalid_username(reason: String) -> Self { + ApplicationError::InvalidUsername(reason) + } + pub fn connection_error(reason: String) -> Self { + ApplicationError::ConnectionError(reason) + } + pub fn broadcast_error(reason: String) -> Self { + ApplicationError::BroadcastError(reason) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_error_derivation_for_encoding() { + let test_data = vec![1, 2, 3, 4]; + let encode_result = bincode::encode_to_vec(&test_data, bincode::config::standard()); + + let encode_error = match encode_result { + Ok(_) => return, + Err(e) => e, + }; + let app_error: ApplicationError = encode_error.into(); + + assert!(matches!(app_error, ApplicationError::Encoding(_))); + assert!(app_error.to_string().contains("Error in encoding")); + } + + #[test] + fn test_error_derivation_for_decoding() { + let invalid_data = vec![0xFF, 0xFF, 0xFF, 0xFF]; + + let decode_result: Result<(String, usize), bincode::error::DecodeError> = + bincode::decode_from_slice(&invalid_data, bincode::config::standard()); + + if let Err(decode_error) = decode_result { + let app_error: ApplicationError = decode_error.into(); + + assert!(matches!(app_error, ApplicationError::Decoding(_))); + assert!(app_error.to_string().contains("Decoding error")); + } + } + + #[test] + fn test_error_derivation_for_io_error() { + let io_error = std::io::Error::new(std::io::ErrorKind::ConnectionReset, "Connection lost"); + let app_error: ApplicationError = io_error.into(); + + assert!(matches!(app_error, ApplicationError::StreamIoError(_))); + assert!(app_error.to_string().contains("Stream I/O error")); + } + + #[test] + fn test_client_read_stream_closed_error() { + let error = ApplicationError::ClientReadStreamClosed; + assert!(error.to_string().contains("Client stream closed")); + } + + #[test] + fn test_incomplete_length_prefix_error() { + let error = ApplicationError::IncompleteLengthPrefix; + assert!(error.to_string().contains("Incomplete length prefix")); + } + + #[test] + fn test_incomplete_payload_error() { + let error = ApplicationError::IncompletePyaload; + assert!(error.to_string().contains("Incomplete payload")); + } + + #[test] + fn test_username_not_found_error() { + let error = ApplicationError::UsernameNotFound; + assert!(error.to_string().contains("username not found")); + } + + #[test] + fn test_error_display_formatting() { + let io_error = std::io::Error::new(std::io::ErrorKind::PermissionDenied, "Access denied"); + let app_error: ApplicationError = io_error.into(); + + let display_string = format!("{}", app_error); + assert!(display_string.contains("Stream I/O error")); + } + + #[test] + fn test_error_kind_discrimination() { + let errors = [ + ApplicationError::ClientReadStreamClosed, + ApplicationError::IncompleteLengthPrefix, + ApplicationError::IncompletePyaload, + ApplicationError::UsernameNotFound, + ]; + + assert!(matches!( + errors[0], + ApplicationError::ClientReadStreamClosed + )); + assert!(matches!( + errors[1], + ApplicationError::IncompleteLengthPrefix + )); + assert!(matches!(errors[2], ApplicationError::IncompletePyaload)); + assert!(matches!(errors[3], ApplicationError::UsernameNotFound)); + + assert!(!matches!( + errors[0], + ApplicationError::IncompleteLengthPrefix + )); + assert!(!matches!(errors[1], ApplicationError::UsernameNotFound)); + assert!(!matches!( + errors[2], + ApplicationError::ClientReadStreamClosed + )); + assert!(!matches!(errors[3], ApplicationError::IncompletePyaload)); + } + + #[test] + fn test_error_source_chaining() { + let io_error = std::io::Error::other("Underlying IO error"); + let app_error: ApplicationError = io_error.into(); + + use std::error::Error; + if let Some(source) = app_error.source() { + assert!(source.to_string().contains("Underlying IO error")); + } else { + panic!("Expected error source to be preserved"); + } + } + + #[test] + fn test_complex_error_hierarchy() { + let nested_io_error = std::io::Error::other("Deep nested error"); + let app_error: ApplicationError = nested_io_error.into(); + + assert!(matches!(app_error, ApplicationError::StreamIoError(_))); + + if let ApplicationError::StreamIoError(io_err) = &app_error { + assert_eq!(io_err.to_string(), "Deep nested error"); + } + } + + #[test] + fn test_error_debug_formatting() { + let error = ApplicationError::UsernameNotFound; + let debug_string = format!("{:?}", error); + assert!(debug_string.contains("UsernameNotFound")); + + let io_error = std::io::Error::other("Debug test"); + let app_error: ApplicationError = io_error.into(); + let debug_string = format!("{:?}", app_error); + assert!(debug_string.contains("StreamIoError")); + } + + #[test] + fn test_error_edge_cases() { + let empty_io_error = std::io::Error::other(""); + let app_error: ApplicationError = empty_io_error.into(); + assert!(app_error.to_string().contains("Stream I/O error")); + + let long_message = "A".repeat(1000); + let long_io_error = std::io::Error::other(long_message); + let app_error: ApplicationError = long_io_error.into(); + assert!(app_error.to_string().contains("Stream I/O error")); + } + + #[test] + fn test_bincode_error_conversion() { + let test_data = vec![1, 2, 3, 4]; + if let Err(encode_error) = bincode::encode_to_vec(&test_data, bincode::config::standard()) { + let app_error: ApplicationError = encode_error.into(); + assert!(matches!(app_error, ApplicationError::Encoding(_))); + assert!(app_error.to_string().contains("Error in encoding")); + } + + let invalid_data = vec![0xFF, 0xFF, 0xFF, 0xFF]; + if bincode::decode_from_slice::(&invalid_data, bincode::config::standard()).is_ok() { + } + } +} diff --git a/chat-core/src/lib.rs b/chat-core/src/lib.rs new file mode 100644 index 0000000..268e082 --- /dev/null +++ b/chat-core/src/lib.rs @@ -0,0 +1,5 @@ +pub mod error; +pub mod message_cahce; +pub mod protocol; +pub mod transport_layer; +pub mod utils; diff --git a/chat-core/src/message_cahce.rs b/chat-core/src/message_cahce.rs new file mode 100644 index 0000000..ecf38c8 --- /dev/null +++ b/chat-core/src/message_cahce.rs @@ -0,0 +1,133 @@ +use std::{ + collections::HashMap, + hash::{DefaultHasher, Hash, Hasher}, + sync::Arc, +}; + +use bincode::Encode; +use bytes::Bytes; +use tokio::sync::RwLock; + +use crate::protocol::BINCODE_STANDADRD_CONFIG; + +pub struct MessageCache { + cache: Arc>>, + capacity: usize, +} + +impl MessageCache { + pub fn new(capacity: usize) -> Self { + Self { + cache: Arc::new(RwLock::new(HashMap::new())), + capacity, + } + } + + pub async fn get(&self, key: u64) -> Option { + let cache = self.cache.read().await; + cache.get(&key).cloned() + } + + pub async fn put(&self, key: u64, value: Bytes) { + let mut cache = self.cache.write().await; + if cache.len() >= self.capacity + && let Some(key) = cache.keys().next().cloned() + { + cache.remove(&key); + } + cache.insert(key, value); + } + + pub fn hash_message(message: &T) -> Result { + let encoded = bincode::encode_to_vec(message, BINCODE_STANDADRD_CONFIG)?; + let mut hasher = DefaultHasher::new(); + encoded.hash(&mut hasher); + Ok(hasher.finish()) + } +} + +impl Clone for MessageCache { + fn clone(&self) -> Self { + Self { + cache: Arc::clone(&self.cache), + capacity: self.capacity, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_message_cache_new() { + let cache = MessageCache::new(10); + assert_eq!(cache.capacity, 10); + } + + #[tokio::test] + async fn test_message_cache_get_put() { + let cache = MessageCache::new(2); + let test_message = "test message".to_string(); + + let hash = MessageCache::hash_message(&test_message).unwrap(); + + cache.put(hash, Bytes::from(test_message.clone())).await; + + let result = cache.get(hash).await; + assert!(result.is_some()); + assert_eq!(result.unwrap(), Bytes::from(test_message)); + } + + #[tokio::test] + async fn test_message_cache_capacity_limits() { + let cache = MessageCache::new(2); + + let message1 = "message1".to_string(); + let message2 = "message2".to_string(); + let message3 = "message3".to_string(); + + let hash1 = MessageCache::hash_message(&message1).unwrap(); + let hash2 = MessageCache::hash_message(&message2).unwrap(); + let hash3 = MessageCache::hash_message(&message3).unwrap(); + + cache.put(hash1, Bytes::from(message1.clone())).await; + cache.put(hash2, Bytes::from(message2.clone())).await; + + assert!(cache.get(hash1).await.is_some()); + assert!(cache.get(hash2).await.is_some()); + + cache.put(hash3, Bytes::from(message3.clone())).await; + + let result1 = cache.get(hash1).await; + let result2 = cache.get(hash2).await; + let result3 = cache.get(hash3).await; + + assert!(result3.is_some()); + assert_eq!(result3.unwrap(), Bytes::from(message3)); + + assert!(result1.is_none() || result2.is_none()); + } + + #[test] + fn test_hash_message() { + let message1 = "test message".to_string(); + let message2 = "different message".to_string(); + + let hash1 = MessageCache::hash_message(&message1).unwrap(); + let hash2 = MessageCache::hash_message(&message2).unwrap(); + + assert_ne!(hash1, hash2); + } + + #[test] + fn test_hash_message_same_content() { + let message1 = "test message".to_string(); + let message2 = "test message".to_string(); + + let hash1 = MessageCache::hash_message(&message1).unwrap(); + let hash2 = MessageCache::hash_message(&message2).unwrap(); + + assert_eq!(hash1, hash2); + } +} diff --git a/chat-core/src/protocol.rs b/chat-core/src/protocol.rs new file mode 100644 index 0000000..62c4633 --- /dev/null +++ b/chat-core/src/protocol.rs @@ -0,0 +1,636 @@ +use std::{ops::Deref, sync::Arc}; + +use bincode::{self, Decode, Encode, config::Configuration}; +use bytes::{Bytes, BytesMut}; + +use crate::{error::ApplicationError, message_cahce::MessageCache}; + +pub const BINCODE_STANDADRD_CONFIG: Configuration = bincode::config::standard(); +/// max allowed message size 1 Mibibyte +pub const MAX_MESSAGE_SIZE: usize = 1024 * 1024; +/// max allowed username length +pub const MAX_USERNAME_LENGTH: usize = 30; + +#[derive(Clone, Encode, Decode)] +pub enum ClientMessage { + Join { username: String }, + Leave { username: String }, + Message { username: String, content: String }, +} + +impl ClientMessage { + pub fn username(&self) -> Option<&str> { + match self { + ClientMessage::Join { username, .. } + | ClientMessage::Leave { username } + | ClientMessage::Message { + username, + content: _, + } => Some(username), + } + } + pub fn join(username: String) -> Self { + Self::Join { username } + } + pub fn leave(username: String) -> Self { + Self::Leave { username } + } + pub fn message(username: String, content: String) -> Self { + Self::Message { username, content } + } +} + +#[derive(Clone, Encode, Decode)] +pub enum ServerMessage { + Success { message: String }, + Error { reason: String }, + UserNameAlreadyTaken { username: String }, + Message { username: String, content: String }, + UserJoined { username: String }, + UserLeft { username: String }, +} + +impl ServerMessage { + pub fn username(&self) -> Option<&str> { + match self { + ServerMessage::Message { username, .. } + | ServerMessage::UserNameAlreadyTaken { username } + | ServerMessage::UserJoined { username } + | ServerMessage::UserLeft { username } => Some(username), + _ => None, + } + } + pub fn success(message: String) -> Self { + Self::Success { message } + } + pub fn error(reason: String) -> Self { + Self::Error { reason } + } + pub fn message(username: String, content: String) -> Self { + Self::Message { username, content } + } + pub fn user_name_already_taken(username: String) -> Self { + Self::UserNameAlreadyTaken { username } + } + pub fn user_joined(username: String) -> Self { + Self::UserJoined { username } + } + pub fn user_left(username: String) -> Self { + Self::UserLeft { username } + } +} + +#[derive(Encode, Decode)] +pub struct SharedServerMessage(pub Arc); + +impl SharedServerMessage { + pub fn new(message: ServerMessage) -> Self { + Self(Arc::new(message)) + } +} + +impl Deref for SharedServerMessage { + type Target = ServerMessage; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} + +impl Clone for SharedServerMessage { + fn clone(&self) -> Self { + Self(Arc::clone(&self.0)) + } +} + +pub const LENGTH_PREFIX: usize = 4; + +/// encode content which doesn't need to be cached +pub fn encode_message(message: &T) -> Result { + let payload = Bytes::from(bincode::encode_to_vec(message, BINCODE_STANDADRD_CONFIG)?); + let mut frame = BytesMut::with_capacity(LENGTH_PREFIX + payload.len()); + frame.extend_from_slice(&(payload.len() as u32).to_be_bytes()); + frame.extend_from_slice(&payload); + Ok(frame.freeze()) +} + +/// encode content which needs to be fetched from cache +/// this is userful specially for ServerMessages +pub async fn encode_message_with_cache( + message: &T, + cache: &MessageCache, +) -> Result { + let hash = MessageCache::hash_message(message)?; + + if let Some(cached) = cache.get(hash).await { + return Ok(cached); + } + let frame_bytes = encode_message(message)?; + cache.put(hash, frame_bytes.clone()).await; + + Ok(frame_bytes) +} + +pub fn decode_message>(buf: &[u8]) -> Result { + Ok(bincode::decode_from_slice(buf, BINCODE_STANDADRD_CONFIG).map(|(body, _)| body)?) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_client_message_username_extraction_join() { + let msg = ClientMessage::Join { + username: "alice".to_string(), + }; + assert_eq!(msg.username(), Some("alice")); + } + + #[test] + fn test_client_message_username_extraction_leave() { + let msg = ClientMessage::Leave { + username: "bob".to_string(), + }; + assert_eq!(msg.username(), Some("bob")); + } + + #[test] + fn test_client_message_username_extraction_message() { + let msg = ClientMessage::Message { + username: "charlie".to_string(), + content: "Hello, world!".to_string(), + }; + assert_eq!(msg.username(), Some("charlie")); + } + + #[test] + fn test_server_message_username_extraction_message() { + let msg = ServerMessage::Message { + username: "alice".to_string(), + content: "Hi everyone!".to_string(), + }; + assert_eq!(msg.username(), Some("alice")); + } + + #[test] + fn test_server_message_username_extraction_user_joined() { + let msg = ServerMessage::UserJoined { + username: "bob".to_string(), + }; + assert_eq!(msg.username(), Some("bob")); + } + + #[test] + fn test_server_message_username_extraction_user_left() { + let msg = ServerMessage::UserLeft { + username: "charlie".to_string(), + }; + assert_eq!(msg.username(), Some("charlie")); + } + + #[test] + fn test_server_message_username_extraction_success() { + let msg = ServerMessage::Success { + message: "Welcome to the chat!".to_string(), + }; + assert_eq!(msg.username(), None); + } + + #[test] + fn test_server_message_username_extraction_error() { + let msg = ServerMessage::Error { + reason: "Invalid username".to_string(), + }; + assert_eq!(msg.username(), None); + } + + #[tokio::test] + async fn test_client_message_serialization_and_deserialization() { + let join_msg = ClientMessage::Join { + username: "testuser".to_string(), + }; + let cache = MessageCache::new(10); + let encoded = encode_message_with_cache(&join_msg, &cache) + .await + .expect("Failed to encode join message"); + let decoded: ClientMessage = + decode_message(&encoded[LENGTH_PREFIX..]).expect("Failed to decode join message"); + assert!(matches!(decoded, ClientMessage::Join { .. })); + if let ClientMessage::Join { username } = decoded { + assert_eq!(username, "testuser"); + } + + let leave_msg = ClientMessage::Leave { + username: "testuser".to_string(), + }; + let encoded = encode_message_with_cache(&leave_msg, &cache) + .await + .expect("Failed to encode leave message"); + let decoded: ClientMessage = + decode_message(&encoded[LENGTH_PREFIX..]).expect("Failed to decode leave message"); + assert!(matches!(decoded, ClientMessage::Leave { .. })); + if let ClientMessage::Leave { username } = decoded { + assert_eq!(username, "testuser"); + } + + let message_content = "Hello, this is a test message!"; + let message_msg = ClientMessage::Message { + username: "testuser".to_string(), + content: message_content.to_string(), + }; + let encoded = encode_message_with_cache(&message_msg, &cache) + .await + .expect("Failed to encode message"); + let decoded: ClientMessage = + decode_message(&encoded[LENGTH_PREFIX..]).expect("Failed to decode message"); + assert!(matches!(decoded, ClientMessage::Message { .. })); + if let ClientMessage::Message { username, content } = decoded { + assert_eq!(username, "testuser"); + assert_eq!(content, message_content); + } + } + + #[tokio::test] + async fn test_server_message_serialization_and_deserialization() { + let success_msg = ServerMessage::Success { + message: "Connection successful!".to_string(), + }; + let cache = MessageCache::new(10); + let encoded = encode_message_with_cache(&success_msg, &cache) + .await + .expect("Failed to encode success message"); + let decoded: ServerMessage = + decode_message(&encoded[LENGTH_PREFIX..]).expect("Failed to decode success message"); + assert!(matches!(decoded, ServerMessage::Success { .. })); + if let ServerMessage::Success { message } = decoded { + assert_eq!(message, "Connection successful!"); + } + + let error_msg = ServerMessage::Error { + reason: "Username already taken".to_string(), + }; + let encoded = encode_message_with_cache(&error_msg, &cache) + .await + .expect("Failed to encode error message"); + let decoded: ServerMessage = + decode_message(&encoded[LENGTH_PREFIX..]).expect("Failed to decode error message"); + assert!(matches!(decoded, ServerMessage::Error { .. })); + if let ServerMessage::Error { reason } = decoded { + assert_eq!(reason, "Username already taken"); + } + + let server_message = ServerMessage::Message { + username: "alice".to_string(), + content: "Hello from server!".to_string(), + }; + let encoded = encode_message_with_cache(&server_message, &cache) + .await + .expect("Failed to encode server message"); + let decoded: ServerMessage = + decode_message(&encoded[LENGTH_PREFIX..]).expect("Failed to decode server message"); + assert!(matches!(decoded, ServerMessage::Message { .. })); + if let ServerMessage::Message { username, content } = decoded { + assert_eq!(username, "alice"); + assert_eq!(content, "Hello from server!"); + } + + let user_joined_msg = ServerMessage::UserJoined { + username: "newuser".to_string(), + }; + let encoded = encode_message_with_cache(&user_joined_msg, &cache) + .await + .expect("Failed to encode user joined message"); + let decoded: ServerMessage = + decode_message(&encoded[LENGTH_PREFIX..]).expect("Failed to decode user joined message"); + assert!(matches!(decoded, ServerMessage::UserJoined { .. })); + if let ServerMessage::UserJoined { username } = decoded { + assert_eq!(username, "newuser"); + } + + let user_left_msg = ServerMessage::UserLeft { + username: "leavinguser".to_string(), + }; + let encoded = encode_message_with_cache(&user_left_msg, &cache) + .await + .expect("Failed to encode user left message"); + let decoded: ServerMessage = + decode_message(&encoded[LENGTH_PREFIX..]).expect("Failed to decode user left message"); + assert!(matches!(decoded, ServerMessage::UserLeft { .. })); + if let ServerMessage::UserLeft { username } = decoded { + assert_eq!(username, "leavinguser"); + } + } + + #[tokio::test] + async fn test_message_framing_length_prefix() { + let msg = ClientMessage::Join { + username: "test".to_string(), + }; + let cache = MessageCache::new(10); + + let frame = encode_message_with_cache(&msg, &cache) + .await + .expect("Failed to encode message"); + + assert!( + frame.len() > LENGTH_PREFIX, + "Frame should be longer than length prefix" + ); + + let length_bytes = &frame[0..LENGTH_PREFIX]; + let length = u32::from_be_bytes([ + length_bytes[0], + length_bytes[1], + length_bytes[2], + length_bytes[3], + ]) as usize; + + assert_eq!(length, frame.len() - LENGTH_PREFIX); + + let payload = &frame[LENGTH_PREFIX..]; + let decoded: ClientMessage = decode_message(payload).expect("Failed to decode payload"); + assert!(matches!(decoded, ClientMessage::Join { .. })); + } + + #[tokio::test] + async fn test_empty_strings_in_messages() { + let msg = ClientMessage::Message { + username: "".to_string(), + content: "Message with empty username".to_string(), + }; + let cache = MessageCache::new(10); + + let encoded = encode_message_with_cache(&msg, &cache) + .await + .expect("Failed to encode message with empty username"); + let decoded: ClientMessage = decode_message(&encoded[LENGTH_PREFIX..]) + .expect("Failed to decode message with empty username"); + + if let ClientMessage::Message { username, content } = decoded { + assert_eq!(username, ""); + assert_eq!(content, "Message with empty username"); + } + + let msg = ClientMessage::Message { + username: "user".to_string(), + content: "".to_string(), + }; + + let encoded = encode_message_with_cache(&msg, &cache) + .await + .expect("Failed to encode message with empty content"); + let decoded: ClientMessage = decode_message(&encoded[LENGTH_PREFIX..]) + .expect("Failed to decode message with empty content"); + + if let ClientMessage::Message { username, content } = decoded { + assert_eq!(username, "user"); + assert_eq!(content, ""); + } + } + + #[tokio::test] + async fn test_long_strings_in_messages() { + let long_username = "a".repeat(1000); + let long_content = "b".repeat(10000); + + let msg = ClientMessage::Message { + username: long_username.clone(), + content: long_content.clone(), + }; + let cache = MessageCache::new(10); + + let encoded = encode_message_with_cache(&msg, &cache) + .await + .expect("Failed to encode message with long strings"); + let decoded: ClientMessage = decode_message(&encoded[LENGTH_PREFIX..]) + .expect("Failed to decode message with long strings"); + + if let ClientMessage::Message { username, content } = decoded { + assert_eq!(username, long_username); + assert_eq!(content, long_content); + } + } + + #[test] + fn test_decode_message_with_invalid_data() { + let result: Result = decode_message::(&[]); + assert!(result.is_err(), "Should fail to decode empty data"); + + let invalid_data = vec![ + 0xDE, 0xAD, 0xBE, 0xEF, 0xCA, 0xFE, 0xBA, 0xBE, 0x12, 0x34, 0x56, 0x78, + ]; + let result: Result = decode_message(&invalid_data); + + assert!( + result.is_ok() || result.is_err(), + "Decoding should either succeed or fail gracefully" + ); + } + + #[tokio::test] + async fn test_encode_message_edge_cases() { + let small_msg = ClientMessage::Join { + username: "a".to_string(), + }; + let cache = MessageCache::new(10); + let encoded = encode_message_with_cache(&small_msg, &cache) + .await + .expect("Failed to encode small message"); + assert!(encoded.len() > LENGTH_PREFIX); + + let large_msg = ClientMessage::Message { + username: "user".to_string(), + content: "x".repeat(100000), + }; + let encoded = encode_message_with_cache(&large_msg, &cache) + .await + .expect("Failed to encode large message"); + assert!(encoded.len() > 100000); + + let decoded: ClientMessage = + decode_message(&encoded[LENGTH_PREFIX..]).expect("Failed to decode large message"); + if let ClientMessage::Message { username, content } = decoded { + assert_eq!(username, "user"); + assert_eq!(content.len(), 100000); + } + } + + #[tokio::test] + async fn test_message_type_discrimination() { + let messages = [ + ClientMessage::Join { + username: "user1".to_string(), + }, + ClientMessage::Leave { + username: "user2".to_string(), + }, + ClientMessage::Message { + username: "user3".to_string(), + content: "test".to_string(), + }, + ]; + let cache = MessageCache::new(10); + + for (i, original_msg) in messages.iter().enumerate() { + let encoded = encode_message_with_cache(original_msg, &cache) + .await + .expect("Failed to encode message"); + let decoded: ClientMessage = + decode_message(&encoded[LENGTH_PREFIX..]).expect("Failed to decode message"); + + match (original_msg, decoded) { + ( + ClientMessage::Join { + username: orig_user, + }, + ClientMessage::Join { username: dec_user }, + ) => { + assert_eq!( + *orig_user, dec_user, + "Join message {} should have matching usernames", + i + ); + } + ( + ClientMessage::Leave { + username: orig_user, + }, + ClientMessage::Leave { username: dec_user }, + ) => { + assert_eq!( + *orig_user, dec_user, + "Leave message {} should have matching usernames", + i + ); + } + ( + ClientMessage::Message { + username: orig_user, + content: orig_content, + }, + ClientMessage::Message { + username: dec_user, + content: dec_content, + }, + ) => { + assert_eq!( + *orig_user, dec_user, + "Message {} should have matching usernames", + i + ); + assert_eq!( + *orig_content, dec_content, + "Message {} should have matching content", + i + ); + } + _ => panic!("Message {} type mismatch after encoding/decoding", i), + } + } + } + + #[tokio::test] + async fn test_length_prefix_constants() { + assert_eq!(LENGTH_PREFIX, 4, "Length prefix should be 4 bytes for u32"); + + let msg = ClientMessage::Join { + username: "test".to_string(), + }; + let cache = MessageCache::new(10); + let encoded = encode_message_with_cache(&msg, &cache) + .await + .expect("Failed to encode test message"); + + assert!( + encoded.len() > LENGTH_PREFIX, + "Encoded message should be longer than length prefix" + ); + + let length_bytes = &encoded[0..LENGTH_PREFIX]; + let length = u32::from_be_bytes([ + length_bytes[0], + length_bytes[1], + length_bytes[2], + length_bytes[3], + ]); + + assert_eq!(length as usize, encoded.len() - LENGTH_PREFIX); + } + + #[test] + fn test_client_message_factory_methods() { + let join_msg = ClientMessage::join("alice".to_string()); + match join_msg { + ClientMessage::Join { username } => assert_eq!(username, "alice"), + _ => panic!("Expected Join message"), + } + + let leave_msg = ClientMessage::leave("bob".to_string()); + match leave_msg { + ClientMessage::Leave { username } => assert_eq!(username, "bob"), + _ => panic!("Expected Leave message"), + } + + let message_msg = ClientMessage::message("charlie".to_string(), "Hello".to_string()); + match message_msg { + ClientMessage::Message { username, content } => { + assert_eq!(username, "charlie"); + assert_eq!(content, "Hello"); + } + _ => panic!("Expected Message message"), + } + } + + #[test] + fn test_server_message_factory_methods() { + let success_msg = ServerMessage::success("Welcome!".to_string()); + match success_msg { + ServerMessage::Success { message } => assert_eq!(message, "Welcome!"), + _ => panic!("Expected Success message"), + } + + let error_msg = ServerMessage::error("Error occurred".to_string()); + match error_msg { + ServerMessage::Error { reason } => assert_eq!(reason, "Error occurred"), + _ => panic!("Expected Error message"), + } + + let message_msg = ServerMessage::message("alice".to_string(), "Hello".to_string()); + match message_msg { + ServerMessage::Message { username, content } => { + assert_eq!(username, "alice"); + assert_eq!(content, "Hello"); + } + _ => panic!("Expected Message message"), + } + + let taken_msg = ServerMessage::user_name_already_taken("takenuser".to_string()); + match taken_msg { + ServerMessage::UserNameAlreadyTaken { username } => assert_eq!(username, "takenuser"), + _ => panic!("Expected UserNameAlreadyTaken message"), + } + + let joined_msg = ServerMessage::user_joined("newuser".to_string()); + match joined_msg { + ServerMessage::UserJoined { username } => assert_eq!(username, "newuser"), + _ => panic!("Expected UserJoined message"), + } + + let left_msg = ServerMessage::user_left("leavinguser".to_string()); + match left_msg { + ServerMessage::UserLeft { username } => assert_eq!(username, "leavinguser"), + _ => panic!("Expected UserLeft message"), + } + } + + #[test] + fn test_shared_server_message_new() { + let message = ServerMessage::success("Test message".to_string()); + let shared_message = SharedServerMessage::new(message); + + if let Some(username) = shared_message.username() { + panic!("Success message should not have username: {}", username) + } + } +} diff --git a/chat-core/src/transport_layer.rs b/chat-core/src/transport_layer.rs new file mode 100644 index 0000000..9e3cd06 --- /dev/null +++ b/chat-core/src/transport_layer.rs @@ -0,0 +1,504 @@ +use bincode::{Decode, Encode}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::tcp::OwnedReadHalf, +}; + +use crate::{ + error::ApplicationError, + message_cahce::MessageCache, + protocol::{ + LENGTH_PREFIX, MAX_MESSAGE_SIZE, decode_message, encode_message, encode_message_with_cache, + }, +}; + +pub async fn read_message_from_stream>( + reader: &mut OwnedReadHalf, + buffer: &mut Vec, +) -> Result { + let n = reader.read(&mut buffer[..LENGTH_PREFIX]).await?; + if n == 0 { + return Err(ApplicationError::ClientReadStreamClosed); + } + + if n < LENGTH_PREFIX { + return Err(ApplicationError::IncompleteLengthPrefix); + } + + let length = u32::from_be_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]) as usize; + + if length > MAX_MESSAGE_SIZE { + return Err(ApplicationError::message_too_large( + length, + MAX_MESSAGE_SIZE, + )); + } + + if buffer.len() < length { + buffer.resize(length, 0); + } + + reader.read_exact(&mut buffer[..length]).await?; + decode_message(&buffer[..length]) +} + +pub async fn write_message_to_stream_with_cache( + writer: &mut tokio::net::tcp::OwnedWriteHalf, + message: &T, + cache: &MessageCache, +) -> Result<(), ApplicationError> { + let frame = encode_message_with_cache(message, cache).await?; + writer.write_all(&frame).await?; + writer.flush().await?; + Ok(()) +} + +pub async fn write_message_to_stream( + writer: &mut tokio::net::tcp::OwnedWriteHalf, + message: &T, +) -> Result<(), ApplicationError> { + let frame = encode_message(message)?; + writer.write_all(&frame).await?; + writer.flush().await?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + use tokio::{ + io::AsyncWriteExt, + net::{TcpListener, TcpStream}, + }; + + #[tokio::test] + async fn test_read_message_from_stream_success() { + let test_message = crate::protocol::ClientMessage::Join { + username: "testuser".to_string(), + }; + let cache = MessageCache::new(1); + let encoded_message = crate::protocol::encode_message_with_cache(&test_message, &cache) + .await + .expect("Failed to encode test message"); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (mut reader, mut writer) = stream.into_split(); + + writer.write_all(&encoded_message).await.unwrap(); + writer.shutdown().await.unwrap(); + + let mut buffer = vec![0u8; 1024]; + let _ = reader.read(&mut buffer).await.unwrap(); + + Arc::new(test_message) + }); + + let client_handle = tokio::spawn(async move { + let stream = TcpStream::connect(addr).await.unwrap(); + let (mut reader, _) = stream.into_split(); + + let mut buffer = vec![0u8; 4096]; + let received_message: crate::protocol::ClientMessage = + read_message_from_stream(&mut reader, &mut buffer) + .await + .expect("Failed to read message from stream"); + + received_message + }); + + let (server_result, client_result) = tokio::join!(server_handle, client_handle); + let sent_message = server_result.unwrap(); + let received_message = client_result.unwrap(); + + match (&*sent_message, received_message) { + ( + crate::protocol::ClientMessage::Join { + username: sent_username, + }, + crate::protocol::ClientMessage::Join { + username: received_username, + }, + ) => { + assert_eq!(sent_username, &received_username); + } + _ => panic!("Message type mismatch"), + } + } + + #[tokio::test] + async fn test_read_message_from_stream_client_closed() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + drop(stream); + }); + + let client_handle = tokio::spawn(async move { + let stream = TcpStream::connect(addr).await.unwrap(); + let (mut reader, _) = stream.into_split(); + + let mut buffer = vec![0u8; 4096]; + let result: Result = + read_message_from_stream(&mut reader, &mut buffer).await; + + result + }); + + server_handle.await.unwrap(); + let client_result = client_handle.await.unwrap(); + + assert!(matches!( + client_result, + Err(crate::error::ApplicationError::ClientReadStreamClosed) + )); + } + + #[tokio::test] + async fn test_read_message_from_stream_incomplete_length_prefix() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (_, mut writer) = stream.into_split(); + + writer.write_all(&[0x00, 0x00]).await.unwrap(); + writer.shutdown().await.unwrap(); + }); + + let client_handle = tokio::spawn(async move { + let stream = TcpStream::connect(addr).await.unwrap(); + let (mut reader, _) = stream.into_split(); + + let mut buffer = vec![0u8; 4096]; + let result: Result = + read_message_from_stream(&mut reader, &mut buffer).await; + + result + }); + + server_handle.await.unwrap(); + let client_result = client_handle.await.unwrap(); + + assert!(matches!( + client_result, + Err(crate::error::ApplicationError::IncompleteLengthPrefix) + )); + } + + #[tokio::test] + async fn test_read_message_from_stream_incomplete_payload() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (_, mut writer) = stream.into_split(); + + writer.write_all(&[0x00, 0x00, 0x00, 0x64]).await.unwrap(); + writer.write_all(&[0xFF; 10]).await.unwrap(); + writer.shutdown().await.unwrap(); + }); + + let client_handle = tokio::spawn(async move { + let stream = TcpStream::connect(addr).await.unwrap(); + let (mut reader, _) = stream.into_split(); + + let mut buffer = vec![0u8; 4096]; + let result: Result = + read_message_from_stream(&mut reader, &mut buffer).await; + + result + }); + + server_handle.await.unwrap(); + let client_result = client_handle.await.unwrap(); + + assert!(client_result.is_err()); + } + + #[tokio::test] + async fn test_write_message_to_stream_success() { + let test_message = crate::protocol::ServerMessage::Success { + message: "Test message".to_string(), + }; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (_, mut writer) = stream.into_split(); + + write_message_to_stream(&mut writer, &test_message) + .await + .expect("Failed to write message to stream"); + + writer.shutdown().await.unwrap(); + }); + + let client_handle = tokio::spawn(async move { + let stream = TcpStream::connect(addr).await.unwrap(); + let (mut reader, _) = stream.into_split(); + + let mut buffer = vec![0u8; 4096]; + let bytes_read = reader.read(&mut buffer).await.unwrap(); + buffer.truncate(bytes_read); + + buffer + }); + + server_handle.await.unwrap(); + let received_bytes = client_handle.await.unwrap(); + + let expected_message = crate::protocol::ServerMessage::Success { + message: "Test message".to_string(), + }; + let expected_bytes = + crate::protocol::encode_message(&expected_message).expect("Failed to encode test message"); + + assert_eq!(received_bytes, expected_bytes); + } + + #[tokio::test] + async fn test_write_message_to_stream_with_large_message() { + let large_content = "x".repeat(10000); + let test_message = crate::protocol::ServerMessage::Message { + username: "largeuser".to_string(), + content: large_content, + }; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (_, mut writer) = stream.into_split(); + + write_message_to_stream(&mut writer, &test_message) + .await + .expect("Failed to write large message to stream"); + + writer.shutdown().await.unwrap(); + }); + + let client_handle = tokio::spawn(async move { + let stream = TcpStream::connect(addr).await.unwrap(); + let (mut reader, _) = stream.into_split(); + + let mut buffer = vec![0u8; 20480]; + let mut total_bytes = Vec::new(); + let mut offset = 0; + + loop { + let bytes_read = reader.read(&mut buffer[offset..]).await.unwrap(); + if bytes_read == 0 { + break; + } + offset += bytes_read; + total_bytes.extend_from_slice(&buffer[..bytes_read]); + } + + total_bytes + }); + + server_handle.await.unwrap(); + let received_bytes = client_handle.await.unwrap(); + + let expected_large_content = "x".repeat(10000); + let expected_message = crate::protocol::ServerMessage::Message { + username: "largeuser".to_string(), + content: expected_large_content, + }; + let expected_bytes = crate::protocol::encode_message(&expected_message) + .expect("Failed to encode large test message"); + + assert_eq!(received_bytes, expected_bytes); + assert!(received_bytes.len() > 10000); + } + + #[tokio::test] + async fn test_write_message_to_stream_with_empty_message() { + let test_message = crate::protocol::ServerMessage::Message { + username: "".to_string(), + content: "".to_string(), + }; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (_, mut writer) = stream.into_split(); + + write_message_to_stream(&mut writer, &test_message) + .await + .expect("Failed to write empty message to stream"); + + writer.shutdown().await.unwrap(); + }); + + let client_handle = tokio::spawn(async move { + let stream = TcpStream::connect(addr).await.unwrap(); + let (mut reader, _) = stream.into_split(); + + let mut buffer = vec![0u8; 4096]; + let bytes_read = reader.read(&mut buffer).await.unwrap(); + buffer.truncate(bytes_read); + + buffer + }); + + server_handle.await.unwrap(); + let received_bytes = client_handle.await.unwrap(); + + let expected_message = crate::protocol::ServerMessage::Message { + username: "".to_string(), + content: "".to_string(), + }; + let expected_bytes = crate::protocol::encode_message(&expected_message) + .expect("Failed to encode empty test message"); + + assert_eq!(received_bytes, expected_bytes); + + assert!(received_bytes.len() >= 4); + } + + #[tokio::test] + async fn test_round_trip_encoding_decoding() { + let test_messages = vec![ + crate::protocol::ClientMessage::Join { + username: "user1".to_string(), + }, + crate::protocol::ClientMessage::Leave { + username: "user2".to_string(), + }, + crate::protocol::ClientMessage::Message { + username: "user3".to_string(), + content: "Hello, world!".to_string(), + }, + ]; + + for (i, original_message) in test_messages.into_iter().enumerate() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let message_clone = original_message.clone(); + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (_, mut writer) = stream.into_split(); + + write_message_to_stream(&mut writer, &message_clone) + .await + .expect("Failed to write message to stream"); + + writer.shutdown().await.unwrap(); + }); + + let client_handle = tokio::spawn(async move { + let stream = TcpStream::connect(addr).await.unwrap(); + let (mut reader, _) = stream.into_split(); + + let mut buffer = vec![0u8; 4096]; + let received_message: crate::protocol::ClientMessage = + read_message_from_stream(&mut reader, &mut buffer) + .await + .expect("Failed to read message from stream"); + + received_message + }); + + server_handle.await.unwrap(); + let received_message = client_handle.await.unwrap(); + + match (original_message, received_message) { + ( + crate::protocol::ClientMessage::Join { + username: orig_user, + }, + crate::protocol::ClientMessage::Join { + username: recv_user, + }, + ) => { + assert_eq!( + *orig_user, recv_user, + "Join message {} username mismatch", + i + ); + } + ( + crate::protocol::ClientMessage::Leave { + username: orig_user, + }, + crate::protocol::ClientMessage::Leave { + username: recv_user, + }, + ) => { + assert_eq!( + *orig_user, recv_user, + "Leave message {} username mismatch", + i + ); + } + ( + crate::protocol::ClientMessage::Message { + username: orig_user, + content: orig_content, + }, + crate::protocol::ClientMessage::Message { + username: recv_user, + content: recv_content, + }, + ) => { + assert_eq!(*orig_user, recv_user, "Message {} username mismatch", i); + assert_eq!( + *orig_content, recv_content, + "Message {} content mismatch", + i + ); + } + _ => panic!("Message {} type mismatch", i), + } + } + } + + #[test] + fn test_buffer_reuse() { + let mut buffer = vec![0u8; 10]; + + buffer[0] = 0x00; + buffer[1] = 0x00; + buffer[2] = 0x01; + buffer[3] = 0x00; + + if buffer.len() < 256 { + buffer.resize(256, 0); + } + + assert!( + buffer.len() >= 256, + "Buffer should be resized to accommodate large messages" + ); + + buffer[0] = 0x42; + let old_size = buffer.len(); + buffer.resize(old_size + 100, 0); + + assert_eq!( + buffer[0], 0x42, + "Buffer resize should preserve existing data" + ); + assert_eq!( + buffer.len(), + old_size + 100, + "Buffer should be resized correctly" + ); + } +} diff --git a/chat-core/src/utils.rs b/chat-core/src/utils.rs new file mode 100644 index 0000000..0e991dc --- /dev/null +++ b/chat-core/src/utils.rs @@ -0,0 +1,87 @@ +use crate::{ + error::ApplicationError, + protocol::{MAX_MESSAGE_SIZE, MAX_USERNAME_LENGTH}, +}; + +pub fn generate_unique_id() -> String { + let ts = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("invalid time") + .as_nanos(); + + format!("id_{}", ts) +} + +pub fn validate_username(username: &str) -> Result<(), ApplicationError> { + if username.is_empty() || username.len() > MAX_USERNAME_LENGTH { + return Err(ApplicationError::UsernameNotFound); + } + + if username.len() > 20 { + return Err(ApplicationError::message_too_large( + username.len(), + MAX_MESSAGE_SIZE, + )); + } + Ok(()) +} + +pub fn is_valid_message(message: &str) -> bool { + !message.is_empty() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generate_unique_id() { + let id1 = generate_unique_id(); + let id2 = generate_unique_id(); + + assert!(id1.starts_with("id_")); + assert!(id2.starts_with("id_")); + assert_ne!(id1, id2); + } + + #[test] + fn test_validate_username_valid() { + let valid_usernames = vec!["alice", "user123", "test_user", "a"]; + + for username in valid_usernames { + assert!(validate_username(username).is_ok()); + } + } + + #[test] + fn test_validate_username_empty() { + let result = validate_username(""); + assert!(result.is_err()); + } + + #[test] + fn test_validate_username_too_long() { + let long_username = "a".repeat(31); + let result = validate_username(&long_username); + assert!(result.is_err()); + } + + #[test] + fn test_validate_username_over_20_chars() { + let long_username = "a".repeat(21); + let result = validate_username(&long_username); + assert!(result.is_err()); + } + + #[test] + fn test_is_valid_message_valid() { + assert!(is_valid_message("Hello")); + assert!(is_valid_message("a")); + assert!(is_valid_message("This is a test message")); + } + + #[test] + fn test_is_valid_message_empty() { + assert!(!is_valid_message("")); + } +} diff --git a/chat-core/tests/integration_protocol.rs b/chat-core/tests/integration_protocol.rs new file mode 100644 index 0000000..5f0519b --- /dev/null +++ b/chat-core/tests/integration_protocol.rs @@ -0,0 +1,502 @@ +use chat_core::{ + error::ApplicationError, + protocol::{ClientMessage, LENGTH_PREFIX, ServerMessage, decode_message, encode_message}, + transport_layer::{read_message_from_stream, write_message_to_stream}, +}; +use tokio::io::AsyncWriteExt; +use tokio::net::{TcpListener, TcpStream}; + +#[tokio::test] +async fn test_protocol_message_round_trip() { + let client_test_messages = [ + ClientMessage::Join { + username: "alice".to_string(), + }, + ClientMessage::Leave { + username: "bob".to_string(), + }, + ClientMessage::Message { + username: "charlie".to_string(), + content: "Hello, world!".to_string(), + }, + ]; + + for (i, original_message) in client_test_messages.iter().enumerate() { + let encoded = encode_message(original_message) + .unwrap_or_else(|e| panic!("Failed to encode client message {}: {}", i, e)); + + assert!( + encoded.len() > LENGTH_PREFIX, + "Encoded client message {} should be longer than length prefix", + i + ); + + let payload = &encoded[LENGTH_PREFIX..]; + let decoded: ClientMessage = decode_message(payload) + .unwrap_or_else(|e| panic!("Failed to decode client message {}: {}", i, e)); + + match (original_message, decoded) { + ( + ClientMessage::Join { + username: orig_user, + }, + ClientMessage::Join { username: dec_user }, + ) => { + assert_eq!(orig_user, &dec_user, "Join message {} username mismatch", i); + } + ( + ClientMessage::Leave { + username: orig_user, + }, + ClientMessage::Leave { username: dec_user }, + ) => { + assert_eq!( + orig_user, &dec_user, + "Leave message {} username mismatch", + i + ); + } + ( + ClientMessage::Message { + username: orig_user, + content: orig_content, + }, + ClientMessage::Message { + username: dec_user, + content: dec_content, + }, + ) => { + assert_eq!(orig_user, &dec_user, "Message {} username mismatch", i); + assert_eq!(orig_content, &dec_content, "Message {} content mismatch", i); + } + _ => panic!("Client message {} type mismatch after encoding/decoding", i), + } + } + + let server_test_messages = [ + ServerMessage::Success { + message: "Welcome to the chat!".to_string(), + }, + ServerMessage::Error { + reason: "Username already taken".to_string(), + }, + ServerMessage::Message { + username: "dave".to_string(), + content: "Test message".to_string(), + }, + ServerMessage::UserJoined { + username: "newuser".to_string(), + }, + ServerMessage::UserLeft { + username: "leavinguser".to_string(), + }, + ]; + + for (i, original_message) in server_test_messages.iter().enumerate() { + let encoded = encode_message(original_message) + .unwrap_or_else(|e| panic!("Failed to encode server message {}: {}", i, e)); + + assert!( + encoded.len() > LENGTH_PREFIX, + "Encoded server message {} should be longer than length prefix", + i + ); + + let payload = &encoded[LENGTH_PREFIX..]; + let decoded: ServerMessage = decode_message(payload) + .unwrap_or_else(|e| panic!("Failed to decode server message {}: {}", i, e)); + + match (original_message, decoded) { + ( + ServerMessage::Success { message: orig_msg }, + ServerMessage::Success { message: dec_msg }, + ) => { + assert_eq!(orig_msg, &dec_msg, "Success message {} mismatch", i); + } + ( + ServerMessage::Error { + reason: orig_reason, + }, + ServerMessage::Error { reason: dec_reason }, + ) => { + assert_eq!(orig_reason, &dec_reason, "Error message {} mismatch", i); + } + ( + ServerMessage::Message { + username: orig_user, + content: orig_content, + }, + ServerMessage::Message { + username: dec_user, + content: dec_content, + }, + ) => { + assert_eq!( + orig_user, &dec_user, + "Server message {} username mismatch", + i + ); + assert_eq!( + orig_content, &dec_content, + "Server message {} content mismatch", + i + ); + } + ( + ServerMessage::UserJoined { + username: orig_user, + }, + ServerMessage::UserJoined { username: dec_user }, + ) => { + assert_eq!( + orig_user, &dec_user, + "UserJoined message {} username mismatch", + i + ); + } + ( + ServerMessage::UserLeft { + username: orig_user, + }, + ServerMessage::UserLeft { username: dec_user }, + ) => { + assert_eq!( + orig_user, &dec_user, + "UserLeft message {} username mismatch", + i + ); + } + _ => panic!("Server message {} type mismatch after encoding/decoding", i), + } + } +} + +#[tokio::test] +async fn test_transport_layer_tcp_communication() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let test_message = ClientMessage::Message { + username: "testuser".to_string(), + content: "Hello via TCP!".to_string(), + }; + + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (mut reader, mut writer) = stream.into_split(); + + let mut buffer = vec![0u8; 4096]; + let received_message: ClientMessage = read_message_from_stream(&mut reader, &mut buffer) + .await + .unwrap(); + + let response = ServerMessage::Success { + message: format!( + "Received message from: {}", + received_message.username().unwrap() + ), + }; + write_message_to_stream(&mut writer, &response) + .await + .unwrap(); + + writer.shutdown().await.unwrap(); + + received_message + }); + + let client_handle = tokio::spawn(async move { + let stream = TcpStream::connect(addr).await.unwrap(); + let (mut reader, mut writer) = stream.into_split(); + + write_message_to_stream(&mut writer, &test_message) + .await + .unwrap(); + + writer.shutdown().await.unwrap(); + + let mut buffer = vec![0u8; 4096]; + let response: ServerMessage = read_message_from_stream(&mut reader, &mut buffer) + .await + .unwrap(); + + response + }); + + let (received_message, response) = tokio::join!(server_handle, client_handle); + let received_message = received_message.unwrap(); + let response = response.unwrap(); + + match received_message { + ClientMessage::Message { username, content } => { + assert_eq!(username, "testuser"); + assert_eq!(content, "Hello via TCP!"); + } + _ => panic!("Expected ClientMessage::Message"), + } + + match response { + ServerMessage::Success { message } => { + assert!(message.contains("testuser")); + } + _ => panic!("Expected ServerMessage::Success"), + } +} + +#[tokio::test] +async fn test_transport_layer_error_handling() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + drop(stream); + }); + + let client_handle = tokio::spawn(async move { + let stream = TcpStream::connect(addr).await.unwrap(); + let (mut reader, _) = stream.into_split(); + + let mut buffer = vec![0u8; 4096]; + let result: Result = + read_message_from_stream(&mut reader, &mut buffer).await; + + result + }); + + server_handle.await.unwrap(); + let read_result = client_handle.await.unwrap(); + + assert!(matches!( + read_result, + Err(ApplicationError::ClientReadStreamClosed) + )); +} + +#[tokio::test] +async fn test_large_message_handling() { + let large_content = "x".repeat(50_000); + let test_message = ClientMessage::Message { + username: "largeuser".to_string(), + content: large_content.clone(), + }; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (mut reader, _) = stream.into_split(); + + let mut buffer = vec![0u8; 65536]; + let received_message: ClientMessage = read_message_from_stream(&mut reader, &mut buffer) + .await + .unwrap(); + + received_message + }); + + let client_handle = tokio::spawn(async move { + let stream = TcpStream::connect(addr).await.unwrap(); + let (_, mut writer) = stream.into_split(); + + write_message_to_stream(&mut writer, &test_message) + .await + .unwrap(); + + writer.shutdown().await.unwrap(); + }); + + let (received_message, _) = tokio::join!(server_handle, client_handle); + + match received_message.unwrap() { + ClientMessage::Message { username, content } => { + assert_eq!(username, "largeuser"); + assert_eq!(content.len(), 50_000); + assert_eq!(content, large_content); + } + _ => panic!("Expected large ClientMessage::Message"), + } +} + +#[tokio::test] +async fn test_concurrent_message_transfer() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let num_clients = 5; + let messages_per_client = 3; + + let server_handle = tokio::spawn(async move { + let mut expected_messages = Vec::new(); + + for i in 0..num_clients { + let (stream, _) = listener.accept().await.unwrap(); + let (mut reader, _) = stream.into_split(); + + for j in 0..messages_per_client { + let mut buffer = vec![0u8; 4096]; + let received_message: ClientMessage = read_message_from_stream(&mut reader, &mut buffer) + .await + .unwrap(); + + expected_messages.push(format!("client{}_msg{}", i, j)); + + match &received_message { + ClientMessage::Message { username, content } => { + assert_eq!(username, &format!("client{}", i)); + assert_eq!(content, &format!("message {}", j)); + } + _ => panic!("Expected ClientMessage::Message"), + } + } + } + + expected_messages + }); + + let mut client_handles = Vec::new(); + for i in 0..num_clients { + let client_handle = tokio::spawn(async move { + let stream = TcpStream::connect(addr).await.unwrap(); + let (_, mut writer) = stream.into_split(); + + for j in 0..messages_per_client { + let message = ClientMessage::Message { + username: format!("client{}", i), + content: format!("message {}", j), + }; + + write_message_to_stream(&mut writer, &message) + .await + .unwrap(); + } + }); + client_handles.push(client_handle); + } + + for handle in client_handles { + handle.await.unwrap(); + } + + let received_messages = server_handle.await.unwrap(); + assert_eq!(received_messages.len(), num_clients * messages_per_client); +} + +#[tokio::test] +async fn test_message_cache_integration() { + use chat_core::message_cahce::MessageCache; + + let cache = MessageCache::new(5); + let test_messages = [ + ClientMessage::Join { + username: "user1".to_string(), + }, + ClientMessage::Message { + username: "user1".to_string(), + content: "Hello".to_string(), + }, + ClientMessage::Leave { + username: "user1".to_string(), + }, + ]; + + for message in test_messages.iter() { + let hash = MessageCache::hash_message(message).unwrap(); + + assert!(cache.get(hash).await.is_none()); + + let encoded = encode_message(message).unwrap(); + cache.put(hash, encoded.clone()).await; + + let cached = cache.get(hash).await; + assert!(cached.is_some()); + assert_eq!(cached.unwrap(), encoded); + } +} + +#[tokio::test] +async fn test_error_propagation_integration() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (mut reader, _) = stream.into_split(); + + let mut buffer = vec![0u8; 4096]; + let result: Result = + read_message_from_stream(&mut reader, &mut buffer).await; + + match result { + Ok(_) => { + println!("Unexpectedly received valid message from malformed data"); + } + Err(ApplicationError::Decoding(_)) => {} + Err(ApplicationError::Encoding(_)) => {} + Err(ApplicationError::StreamIoError(_)) => {} + Err(e) => { + println!("Got unexpected error: {:?}", e); + } + } + }); + + let client_handle = tokio::spawn(async move { + let stream = TcpStream::connect(addr).await.unwrap(); + let (_, mut writer) = stream.into_split(); + + let malformed_data = vec![0xFF, 0xFF, 0xFF, 0xFF, 0xDE, 0xAD, 0xBE, 0xEF]; + let _ = writer.write_all(&malformed_data).await; + + let _ = writer.shutdown().await; + }); + + server_handle.await.unwrap(); + client_handle.await.unwrap(); +} + +#[tokio::test] +async fn test_message_size_limits_integration() { + let huge_content = "x".repeat(10_000); + let test_message = ClientMessage::Message { + username: "biguser".to_string(), + content: huge_content.clone(), + }; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (mut reader, _) = stream.into_split(); + + let mut buffer = vec![0u8; 15_000]; + let received_message: ClientMessage = read_message_from_stream(&mut reader, &mut buffer) + .await + .unwrap(); + + received_message + }); + + let client_handle = tokio::spawn(async move { + let stream = TcpStream::connect(addr).await.unwrap(); + let (_, mut writer) = stream.into_split(); + + write_message_to_stream(&mut writer, &test_message) + .await + .unwrap(); + + writer.shutdown().await.unwrap(); + }); + + let (received_message, _) = tokio::join!(server_handle, client_handle); + + match received_message.unwrap() { + ClientMessage::Message { username, content } => { + assert_eq!(username, "biguser"); + assert_eq!(content.len(), 10_000); + assert_eq!(content, huge_content); + } + _ => panic!("Expected large ClientMessage::Message"), + } +} diff --git a/chat-server/Cargo.toml b/chat-server/Cargo.toml new file mode 100644 index 0000000..562e1e6 --- /dev/null +++ b/chat-server/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "chat-server" +version.workspace = true +edition.workspace = true +rust-version.workspace = true + +[dependencies] +chat-core = { path = "../chat-core" } +dashmap = { workspace = true } +anyhow = { workspace = true } +clap = { workspace = true, features = ["derive"] } +tracing = { workspace = true } +tracing-subscriber = { workspace = true } +tokio = { workspace = true } diff --git a/chat-server/src/broadcast.rs b/chat-server/src/broadcast.rs new file mode 100644 index 0000000..08e4040 --- /dev/null +++ b/chat-server/src/broadcast.rs @@ -0,0 +1,273 @@ +use chat_core::{ + error::ApplicationError, + message_cahce::MessageCache, + protocol::{ServerMessage, SharedServerMessage}, + transport_layer::{write_message_to_stream, write_message_to_stream_with_cache}, +}; +use tokio::{ + net::tcp::OwnedWriteHalf, + sync::broadcast::{Receiver, error::RecvError}, + task::JoinHandle, +}; +use tracing::{info, warn}; + +pub fn spawn_broadcast_dispatcher( + mut rec: Receiver, + channel_owner: String, + mut writer: OwnedWriteHalf, + cache: MessageCache, +) -> JoinHandle> { + tokio::spawn(async move { + let success_msg = ServerMessage::Success { + message: format!("Welcome to the chat ๐Ÿ™ `{}`", channel_owner), + }; + write_message_to_stream(&mut writer, &success_msg).await?; + + info!("User {} joined the chat", channel_owner); + + loop { + match rec.recv().await { + Ok(message) => { + if let Some(message_username) = message.username() + && message_username == channel_owner + { + continue; + } + + write_message_to_stream_with_cache(&mut writer, &message, &cache).await?; + } + Err(RecvError::Lagged(lagged_by)) => { + warn!("reciever lagged by {lagged_by} messages"); + } + Err(RecvError::Closed) => { + warn!("Broadcast channel has been closed, Exiting"); + break; + } + } + } + + let success_msg = + ServerMessage::success(format!("Disconnected: Good bye `{}`!", channel_owner)); + write_message_to_stream(&mut writer, &success_msg).await?; + Ok(()) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use chat_core::protocol::encode_message; + use tokio::sync::broadcast; + + #[test] + fn test_message_username_extraction() { + let messages_with_usernames = vec![ + ServerMessage::Message { + username: "alice".to_string(), + content: "Hello".to_string(), + }, + ServerMessage::UserJoined { + username: "bob".to_string(), + }, + ServerMessage::UserLeft { + username: "charlie".to_string(), + }, + ]; + + let messages_without_usernames = vec![ + ServerMessage::Success { + message: "Welcome!".to_string(), + }, + ServerMessage::Error { + reason: "Error occurred".to_string(), + }, + ]; + + for message in &messages_with_usernames { + assert!( + message.username().is_some(), + "Message should have a username" + ); + } + + for message in &messages_without_usernames { + assert!( + message.username().is_none(), + "Message should not have a username" + ); + } + } + + #[test] + fn test_message_serialization() { + let test_messages = vec![ + ServerMessage::Message { + username: "alice".to_string(), + content: "Hello, world!".to_string(), + }, + ServerMessage::Error { + reason: "Test error".to_string(), + }, + ServerMessage::Success { + message: "Welcome to the chat!".to_string(), + }, + ServerMessage::UserJoined { + username: "newuser".to_string(), + }, + ServerMessage::UserLeft { + username: "leavinguser".to_string(), + }, + ]; + + for message in test_messages { + let encoded = encode_message(&message).expect("Failed to encode message"); + assert!( + encoded.len() > 4, + "Encoded message should have length prefix and payload" + ); + + let payload = &encoded[4..]; + + let decoded: ServerMessage = + chat_core::protocol::decode_message(payload).expect("Failed to decode message"); + + match (&message, decoded) { + ( + ServerMessage::Message { + username: orig_user, + content: orig_content, + }, + ServerMessage::Message { + username: dec_user, + content: dec_content, + }, + ) => { + assert_eq!(orig_user, &dec_user); + assert_eq!(orig_content, &dec_content); + } + ( + ServerMessage::Error { + reason: orig_reason, + }, + ServerMessage::Error { reason: dec_reason }, + ) => { + assert_eq!(orig_reason, &dec_reason); + } + ( + ServerMessage::Success { message: orig_msg }, + ServerMessage::Success { message: dec_msg }, + ) => { + assert_eq!(orig_msg, &dec_msg); + } + ( + ServerMessage::UserJoined { + username: orig_user, + }, + ServerMessage::UserJoined { username: dec_user }, + ) => { + assert_eq!(orig_user, &dec_user); + } + ( + ServerMessage::UserLeft { + username: orig_user, + }, + ServerMessage::UserLeft { username: dec_user }, + ) => { + assert_eq!(orig_user, &dec_user); + } + _ => panic!("Message type mismatch after encoding/decoding"), + } + } + } + + #[test] + fn test_large_message_serialization() { + let large_content = "x".repeat(10000); + let test_message = ServerMessage::Message { + username: "largeuser".to_string(), + content: large_content, + }; + + let encoded = encode_message(&test_message).expect("Failed to encode large message"); + assert!( + encoded.len() > 10000, + "Large message should result in substantial encoded size" + ); + + let payload = &encoded[4..]; + + let decoded: ServerMessage = + chat_core::protocol::decode_message(payload).expect("Failed to decode large message"); + + match decoded { + ServerMessage::Message { username, content } => { + assert_eq!(username, "largeuser"); + assert_eq!(content.len(), 10000); + } + _ => panic!("Expected large ServerMessage::Message"), + } + } + + #[test] + fn test_empty_message_serialization() { + let test_message = ServerMessage::Message { + username: "".to_string(), + content: "".to_string(), + }; + + let encoded = encode_message(&test_message).expect("Failed to encode empty message"); + assert!( + encoded.len() >= 4, + "Message should have at least length prefix" + ); + + let payload = &encoded[4..]; + + let decoded: ServerMessage = + chat_core::protocol::decode_message(payload).expect("Failed to decode empty message"); + + match decoded { + ServerMessage::Message { username, content } => { + assert_eq!(username, ""); + assert_eq!(content, ""); + } + _ => panic!("Expected empty ServerMessage::Message"), + } + } + + #[test] + fn test_broadcast_dispatcher_creation() { + let (tx, _rx) = broadcast::channel(10); + + let test_message = ServerMessage::Success { + message: "Test".to_string(), + }; + + let result = tx.send(test_message); + assert!( + result.is_ok(), + "Should be able to send message through broadcast channel" + ); + } + + #[test] + fn test_message_equality() { + let msg1 = ServerMessage::Message { + username: "alice".to_string(), + content: "Hello".to_string(), + }; + + let msg2 = ServerMessage::Message { + username: "alice".to_string(), + content: "Hello".to_string(), + }; + + let encoded1 = encode_message(&msg1).expect("Failed to encode first message"); + let encoded2 = encode_message(&msg2).expect("Failed to encode second message"); + + assert_eq!( + encoded1, encoded2, + "Identical messages should produce identical encoded data" + ); + } +} diff --git a/chat-server/src/broadcast_pool.rs b/chat-server/src/broadcast_pool.rs new file mode 100644 index 0000000..ecd9110 --- /dev/null +++ b/chat-server/src/broadcast_pool.rs @@ -0,0 +1,116 @@ +use std::{collections::HashMap, sync::Arc}; + +use chat_core::{ + error::ApplicationError, message_cahce::MessageCache, protocol::SharedServerMessage, +}; +use dashmap::DashMap; +use tokio::{net::tcp::OwnedWriteHalf, sync::broadcast::Sender, task::JoinHandle}; +use tracing::warn; + +pub struct BroadcastPool { + broadcaster: Arc>, + dispatchers: Arc>>>, +} + +impl BroadcastPool { + pub fn new(broadcaster: Arc>) -> Self { + Self { + broadcaster, + dispatchers: Arc::new(DashMap::new()), + } + } + + pub fn create_dispatcher(&self, username: String, writer: OwnedWriteHalf, cache: MessageCache) { + let receiver = self.broadcaster.subscribe(); + let dispatcher = + crate::broadcast::spawn_broadcast_dispatcher(receiver, username.clone(), writer, cache); + self.dispatchers.insert(username, dispatcher); + } + + pub fn has(&self, username: &str) -> bool { + self.dispatchers.contains_key(username) + } + + pub fn destroy_dispatcher(&self, username: &str) { + if let Some((_, dispatcher)) = self.dispatchers.remove(username) { + dispatcher.abort(); + } else { + warn!("dispatcher not found for user {username}"); + } + } + + pub async fn broadcast_message( + &self, + message: SharedServerMessage, + ) -> Result<(), ApplicationError> { + if self.broadcaster.send(message).is_err() { + warn!("no recievers to recive broadcst"); + } + Ok(()) + } + + pub async fn shutdown(&self) -> Result<(), ApplicationError> { + drop(self.broadcaster.clone()); + + let mut dispatchers = HashMap::new(); + for entry in self.dispatchers.iter() { + if let Some((username, dispatcher)) = self.dispatchers.remove(entry.key()) { + dispatchers.insert(username, dispatcher); + } + } + + for (_, dispatcher) in dispatchers { + dispatcher.abort(); + } + Ok(()) + } +} + +impl Clone for BroadcastPool { + fn clone(&self) -> Self { + Self { + broadcaster: Arc::clone(&self.broadcaster), + dispatchers: Arc::clone(&self.dispatchers), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::sync::broadcast; + + #[test] + fn test_broadcast_pool_new() { + let (tx, _) = broadcast::channel(10); + let broadcaster = Arc::new(tx); + let pool = BroadcastPool::new(broadcaster.clone()); + + assert!(Arc::ptr_eq(&pool.broadcaster, &broadcaster)); + } + + #[test] + fn test_broadcast_pool_has() { + let (tx, _) = broadcast::channel(10); + let pool = BroadcastPool::new(Arc::new(tx)); + + assert!(!pool.has("testuser")); + } + + #[test] + fn test_broadcast_pool_destroy_dispatcher() { + let (tx, _) = broadcast::channel(10); + let pool = BroadcastPool::new(Arc::new(tx)); + + pool.destroy_dispatcher("nonexistent"); + } + + #[tokio::test] + async fn test_broadcast_pool_shutdown() { + let (tx, _) = broadcast::channel(10); + let pool = BroadcastPool::new(Arc::new(tx)); + + let result = pool.shutdown().await; + assert!(result.is_ok()); + } +} diff --git a/chat-server/src/client_handler.rs b/chat-server/src/client_handler.rs new file mode 100644 index 0000000..74b108d --- /dev/null +++ b/chat-server/src/client_handler.rs @@ -0,0 +1,92 @@ +use chat_core::{ + error::ApplicationError, + message_cahce::MessageCache, + protocol::{ClientMessage, ServerMessage, SharedServerMessage, encode_message}, + transport_layer::read_message_from_stream, + utils, +}; +use tokio::{io::AsyncWriteExt, net::TcpStream}; +use tracing::info; + +use crate::broadcast_pool::BroadcastPool; + +pub(super) async fn handle_client( + stream: TcpStream, + broadcast_pool: BroadcastPool, + cache: MessageCache, +) -> Result<(), ApplicationError> { + let (mut reader, mut writer) = stream.into_split(); + + let mut buffer = vec![0u8; 4096]; + let username: Option; + + loop { + match read_message_from_stream(&mut reader, &mut buffer).await { + Ok(ClientMessage::Join { username: un }) => { + // validate username + utils::validate_username(&un)?; + if broadcast_pool.has(&un) { + let error_msg = ServerMessage::user_name_already_taken(un.to_string()); + let frame = encode_message(&error_msg)?; + writer.write_all(&frame).await?; + continue; + } + username = Some(un); + break; + } + Ok(_) => { + let error_msg = ServerMessage::Error { + reason: "You have to join the group before doing any other operation".to_string(), + }; + let frame = encode_message(&error_msg)?; + writer.write_all(&frame).await?; + } + Err(_) => (), + } + } + + let user = username.ok_or(ApplicationError::UsernameNotFound)?; + + broadcast_pool.create_dispatcher(user.clone(), writer, cache); + + let join_notification = SharedServerMessage::new(ServerMessage::user_joined(user.clone())); + broadcast_pool.broadcast_message(join_notification).await?; + + loop { + tokio::select! { + message_result = read_message_from_stream(&mut reader, &mut buffer) => { + match message_result { + Ok(ClientMessage::Join { username: _ }) => {} + Ok(ClientMessage::Leave { username: user }) => { + broadcast_pool.destroy_dispatcher(&user); + info!("User {} left the chat", user); + + let msg = SharedServerMessage::new(ServerMessage::user_left(user.clone())); + broadcast_pool.broadcast_message(msg).await?; + break; + } + Ok(ClientMessage::Message { + username: user, + content, + }) => { + let msg = SharedServerMessage::new(ServerMessage::message(user.clone(), content.clone())); + broadcast_pool.broadcast_message(msg).await?; + info!("Message from {}: {}", user, content); + } + Err(ApplicationError::ClientReadStreamClosed) => { + info!("Client disconnected"); + break; + } + Err(_) => { + + break; + } + } + } + } + } + + broadcast_pool.destroy_dispatcher(&user); + + Ok(()) +} diff --git a/chat-server/src/lib.rs b/chat-server/src/lib.rs new file mode 100644 index 0000000..a426e18 --- /dev/null +++ b/chat-server/src/lib.rs @@ -0,0 +1,4 @@ +pub mod broadcast; +pub mod broadcast_pool; +pub mod client_handler; +pub mod server; diff --git a/chat-server/src/main.rs b/chat-server/src/main.rs new file mode 100644 index 0000000..107ad8c --- /dev/null +++ b/chat-server/src/main.rs @@ -0,0 +1,75 @@ +use anyhow::Result; +use chat_server::server::ChatServer; +use clap::Parser; +use tokio::signal; +use tokio::sync::oneshot; + +#[derive(Parser, Debug)] +#[command(version, about, long_about = None)] +struct Args { + #[arg(long, default_value = "127.0.0.1")] + host: String, + + #[arg(long, default_value_t = 8080)] + port: u16, + + #[arg(long, default_value_t = 1000)] + max_connections: usize, +} + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::fmt::init(); + + let args = Args::parse(); + + let server = ChatServer::new(args.max_connections); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + let shutdown_signal = shutdown_signal(shutdown_tx); + + let server_task = + tokio::spawn(async move { server.run(&args.host, args.port, shutdown_rx).await }); + + tokio::select! { + _ = shutdown_signal => { + tracing::info!("Shutdown signal received, server will shutdown gracefully"); + } + result = server_task => { + match result { + Ok(_) => tracing::info!("Server completed normally"), + Err(e) => tracing::error!("Server error: {}", e), + } + } + } + + Ok(()) +} + +async fn shutdown_signal(shutdown_tx: oneshot::Sender<()>) { + let ctrl_c = async { + signal::ctrl_c() + .await + .expect("failed to install Ctrl+C handler"); + }; + + #[cfg(unix)] + let terminate = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to install signal handler") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + tokio::select! { + _ = ctrl_c => {}, + _ = terminate => {}, + } + + tracing::info!("Shutdown signal received, starting graceful shutdown..."); + + let _ = shutdown_tx.send(()); +} diff --git a/chat-server/src/server.rs b/chat-server/src/server.rs new file mode 100644 index 0000000..724202c --- /dev/null +++ b/chat-server/src/server.rs @@ -0,0 +1,194 @@ +use anyhow::Result; +use chat_core::message_cahce::MessageCache; +use std::sync::Arc; +use tokio::net::TcpListener; +use tokio::sync::{ + broadcast::{self}, + oneshot, +}; +use tracing::{error, info}; + +use crate::broadcast_pool::BroadcastPool; + +pub struct ChatServer { + max_connections: usize, + broadcast_pool: BroadcastPool, + cache: MessageCache, +} + +pub type ShutdownSignal = oneshot::Receiver<()>; + +impl ChatServer { + pub fn new(max_connections: usize) -> Self { + let (tx, _) = broadcast::channel(10_000); + Self { + broadcast_pool: BroadcastPool::new(Arc::new(tx)), + cache: MessageCache::new(10_0000), + max_connections, + } + } + + pub async fn run(&self, host: &str, port: u16, mut shutdown_rx: ShutdownSignal) -> Result<()> { + let listener = TcpListener::bind(format!("{}:{}", host, port)).await?; + info!("Chat server listening on {}:{}", host, port); + + let connection_limiter = Arc::new(tokio::sync::Semaphore::new(self.max_connections)); + + loop { + tokio::select! { + _ = &mut shutdown_rx => { + info!("Shutdown signal received, starting graceful shutdown..."); + self.shutdown().await?; + return Ok(()); + } + accept_result = listener.accept() => { + match accept_result { + Ok((stream, addr)) => { + info!("New connection from {}", addr); + let permit = connection_limiter.clone().acquire_owned().await?; + + let pool = self.broadcast_pool.clone(); + let cache = self.cache.clone(); + tokio::spawn(async move { + if let Err(e) = crate::client_handler::handle_client(stream, pool, cache).await { + error!("Client handler error: {}", e); + } + drop(permit); + }); + } + Err(e) => { + error!("Accept error: {}", e); + tracing::debug!("Accept failed, releasing permit"); + } + } + } + } + } + } + + async fn shutdown(&self) -> Result<()> { + info!("Starting graceful shutdown..."); + self.broadcast_pool.shutdown().await?; + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + info!("Server shutdown complete"); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use chat_core::{error::ApplicationError, protocol::ServerMessage}; + + #[tokio::test] + async fn test_server_creation() { + let server = ChatServer::new(100); + assert_eq!(server.max_connections, 100); + } + + #[tokio::test] + async fn test_server_creation_with_zero_connections() { + let server = ChatServer::new(0); + assert_eq!(server.max_connections, 0); + } + + #[tokio::test] + async fn test_graceful_shutdown() { + let server = ChatServer::new(10); + + let result = server.shutdown().await; + assert!(result.is_ok()); + } + + #[test] + fn test_application_error_variants() { + let errors = vec![ + ApplicationError::ClientReadStreamClosed, + ApplicationError::IncompleteLengthPrefix, + ApplicationError::IncompletePyaload, + ApplicationError::UsernameNotFound, + ]; + + for error in errors { + let error_string = format!("{}", error); + assert!(!error_string.is_empty(), "Error should have a description"); + + let _debug_string = format!("{:?}", error); + } + } + + #[tokio::test] + async fn test_multiple_server_instances() { + let server1 = ChatServer::new(50); + let server2 = ChatServer::new(100); + + assert_eq!(server1.max_connections, 50); + assert_eq!(server2.max_connections, 100); + } + + #[test] + fn test_connection_limiter_creation() { + let server = ChatServer::new(10); + + let connection_limiter = tokio::sync::Semaphore::new(server.max_connections); + assert_eq!(connection_limiter.available_permits(), 10); + + let permit1 = connection_limiter.try_acquire().unwrap(); + assert_eq!(connection_limiter.available_permits(), 9); + + let permit2 = connection_limiter.try_acquire().unwrap(); + assert_eq!(connection_limiter.available_permits(), 8); + + drop(permit1); + drop(permit2); + + assert_eq!(connection_limiter.available_permits(), 10); + } + + #[test] + fn test_large_server_capacity() { + let server = ChatServer::new(1_000_000); + assert_eq!(server.max_connections, 1_000_000); + } + + #[test] + fn test_server_message_username_extraction_comprehensive() { + let messages_with_usernames = vec![ + ServerMessage::Message { + username: "alice".to_string(), + content: "Hello".to_string(), + }, + ServerMessage::UserJoined { + username: "bob".to_string(), + }, + ServerMessage::UserLeft { + username: "charlie".to_string(), + }, + ]; + + let messages_without_usernames = vec![ + ServerMessage::Success { + message: "Welcome!".to_string(), + }, + ServerMessage::Error { + reason: "Error occurred".to_string(), + }, + ]; + + for message in &messages_with_usernames { + assert!(message.username().is_some(), "Message should have username"); + assert!( + !message.username().unwrap().is_empty(), + "Username should not be empty" + ); + } + + for message in &messages_without_usernames { + assert!( + message.username().is_none(), + "Message should not have username" + ); + } + } +} diff --git a/chat-server/tests/integration_server_client.rs b/chat-server/tests/integration_server_client.rs new file mode 100644 index 0000000..faf2d37 --- /dev/null +++ b/chat-server/tests/integration_server_client.rs @@ -0,0 +1,420 @@ +use chat_core::{ + error::ApplicationError, + protocol::{ClientMessage, ServerMessage}, + transport_layer::{read_message_from_stream, write_message_to_stream}, +}; +use chat_server::server::ChatServer; +use tokio::{ + io::AsyncWriteExt, + net::TcpStream, + sync::oneshot, + time::{Duration, timeout}, +}; + +async fn simulate_client_connection( + host: &str, + port: u16, + _username: &str, +) -> Result> { + let stream = TcpStream::connect(format!("{}:{}", host, port)).await?; + Ok(stream) +} + +#[tokio::test] +async fn test_server_basic_functionality() { + let server = ChatServer::new(10); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let port = 12345; + + let server_handle = tokio::spawn(async move { server.run("127.0.0.1", port, shutdown_rx).await }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let stream_result = simulate_client_connection("127.0.0.1", port, "testuser").await; + assert!(stream_result.is_ok(), "Should be able to connect to server"); + + let stream = stream_result.unwrap(); + + let join_message = ClientMessage::Join { + username: "testuser".to_string(), + }; + + let (_reader, mut writer) = stream.into_split(); + let join_result = write_message_to_stream(&mut writer, &join_message).await; + assert!(join_result.is_ok(), "Should be able to send join message"); + + let _ = writer.shutdown().await; + + let _ = shutdown_tx.send(()); + + let result = timeout(Duration::from_secs(5), server_handle).await; + match result { + Ok(handle) => { + let _ = handle; + } + Err(_) => panic!("Server shutdown timed out"), + } +} + +#[tokio::test] +async fn test_server_graceful_shutdown() { + let server = ChatServer::new(10); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let port = 12347; + + let server_handle = tokio::spawn(async move { server.run("127.0.0.1", port, shutdown_rx).await }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let client_result = simulate_client_connection("127.0.0.1", port, "testuser").await; + assert!(client_result.is_ok(), "Client should connect successfully"); + + let _ = shutdown_tx.send(()); + + let server_result = server_handle.await; + assert!(server_result.is_ok(), "Server should shut down gracefully"); +} + +#[tokio::test] +async fn test_server_multiple_clients() { + let server = ChatServer::new(10); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let port = 12348; + + let server_handle = tokio::spawn(async move { server.run("127.0.0.1", port, shutdown_rx).await }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let mut client_tasks = Vec::new(); + + for i in 0..5 { + let task = tokio::spawn(async move { + simulate_client_connection("127.0.0.1", port, &format!("user{}", i)).await + }); + client_tasks.push(task); + } + + let mut connection_results = Vec::new(); + for task in client_tasks { + connection_results.push(task.await.unwrap()); + } + + for result in &connection_results { + assert!(result.is_ok(), "All client connections should succeed"); + } + + let _ = shutdown_tx.send(()); + + let result = timeout(Duration::from_secs(5), server_handle).await; + match result { + Ok(handle) => { + let _ = handle; + } + Err(_) => panic!("Server shutdown timed out"), + } +} + +#[tokio::test] +async fn test_server_message_handling() -> Result<(), ApplicationError> { + let server = ChatServer::new(5); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + let port = 12349; + + let server_handle = tokio::spawn(async move { server.run("127.0.0.1", port, shutdown_rx).await }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let client_handle = tokio::spawn(async move { + let stream = TcpStream::connect(format!("127.0.0.1:{}", 12349)) + .await + .unwrap(); + let (mut reader, mut writer) = stream.into_split(); + + let join_message = ClientMessage::Join { + username: "testuser".to_string(), + }; + + write_message_to_stream(&mut writer, &join_message) + .await + .unwrap(); + + let mut buffer = vec![0u8; 4096]; + let response: Result = timeout( + Duration::from_millis(500), + read_message_from_stream(&mut reader, &mut buffer), + ) + .await + .unwrap_or(Err(ApplicationError::ClientReadStreamClosed)); + + if response.is_err() { + eprintln!("Server did not respond within timeout, but this is acceptable for this test"); + } + + let test_message = ClientMessage::Message { + username: "testuser".to_string(), + content: "Hello, server!".to_string(), + }; + + let send_result = write_message_to_stream(&mut writer, &test_message).await; + + assert!( + send_result.is_ok(), + "Should be able to send message to server" + ); + + // Important: Properly shutdown the writer to avoid resource leaks + let _ = writer.shutdown().await; + + // Clean up the reader + drop(reader); + }); + + // Wait for client to complete with timeout + let client_result = timeout(Duration::from_secs(10), client_handle).await; + match client_result { + Ok(handle) => { + let _ = handle; + } + Err(_) => panic!("Client operation timed out"), + } + + let _ = shutdown_tx.send(()); + + let result = timeout(Duration::from_secs(5), server_handle).await; + + match result { + Ok(handle) => { + let _ = handle; + } + Err(_) => panic!("Server shutdown timed out"), + } + Ok(()) +} +#[tokio::test] +async fn test_server_error_recovery() -> Result<(), ApplicationError> { + let server = ChatServer::new(10); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + let port = 12349; + + let server_handle = tokio::spawn(async move { server.run("127.0.0.1", port, shutdown_rx).await }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + for i in 0..3 { + let stream_result = TcpStream::connect(format!("127.0.0.1:{}", port)).await; + assert!(stream_result.is_ok(), "Connection {} should succeed", i); + + let stream = stream_result.unwrap(); + let (_reader, mut writer) = stream.into_split(); + + let _ = writer.shutdown().await; + + tokio::time::sleep(Duration::from_millis(10)).await; + } + + let malformed_stream = TcpStream::connect(format!("127.0.0.1:{}", port)) + .await + .unwrap(); + let (_reader, mut writer) = malformed_stream.into_split(); + + let random_data = vec![0xDE, 0xAD, 0xBE, 0xEF, 0xCA, 0xFE, 0xBA, 0xBE]; + let _ = writer.write_all(&random_data).await; + + let _ = writer.shutdown().await; + + let _ = shutdown_tx.send(()); + let _ = server_handle.await; + Ok(()) +} + +#[tokio::test] +async fn test_server_broadcast_functionality() -> Result<(), ApplicationError> { + let server = ChatServer::new(10); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let port = 12350; + + let server_handle = tokio::spawn(async move { server.run("127.0.0.1", port, shutdown_rx).await }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let mut client_handles = Vec::new(); + + for i in 0..3 { + let client_handle = tokio::spawn(async move { + let stream = TcpStream::connect(format!("127.0.0.1:{}", 12350)) + .await + .unwrap(); + let (mut reader, mut writer) = stream.into_split(); + + let join_message = ClientMessage::Join { + username: format!("user{}", i), + }; + + write_message_to_stream(&mut writer, &join_message) + .await + .unwrap(); + + let mut buffer = vec![0u8; 4096]; + let response: Result = timeout( + Duration::from_secs(3), + read_message_from_stream(&mut reader, &mut buffer), + ) + .await + .unwrap_or(Err(ApplicationError::ClientReadStreamClosed)); + + let _ = writer.shutdown().await; + + if let Ok(ServerMessage::Success { message }) = response { + assert!(message.contains("Welcome")); + } + + drop(reader); + }); + + client_handles.push(client_handle); + tokio::time::sleep(Duration::from_millis(10)).await; + } + + for handle in client_handles { + let _ = handle.await; + } + + let _ = shutdown_tx.send(()); + let result = timeout(Duration::from_secs(5), server_handle).await; + + match result { + Ok(_) => {} + Err(_) => panic!("Server shutdown timed out"), + } + + Ok(()) +} + +#[tokio::test] +async fn test_server_connection_limit_enforcement() -> Result<(), ApplicationError> { + let server = ChatServer::new(2); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let port = 12351; + + let server_handle = tokio::spawn(async move { server.run("127.0.0.1", port, shutdown_rx).await }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let client1 = TcpStream::connect(format!("127.0.0.1:{}", port)).await; + let client2 = TcpStream::connect(format!("127.0.0.1:{}", port)).await; + + assert!(client1.is_ok(), "First client should connect"); + assert!(client2.is_ok(), "Second client should connect"); + + tokio::time::sleep(Duration::from_millis(100)).await; + + if let Ok(stream) = client1 { + let (_, mut writer) = stream.into_split(); + writer.shutdown().await?; + } + + if let Ok(stream) = client2 { + let (_, mut writer) = stream.into_split(); + writer.shutdown().await?; + } + + let _ = shutdown_tx.send(()); + let result = timeout(Duration::from_secs(5), server_handle).await; + + match result { + Ok(_) => {} + Err(_) => panic!("Server shutdown timed out"), + } + + Ok(()) +} + +#[tokio::test] +async fn test_server_username_validation_integration() -> Result<(), ApplicationError> { + let server = ChatServer::new(5); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let port = 12352; + + let server_handle = tokio::spawn(async move { server.run("127.0.0.1", port, shutdown_rx).await }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let client1_handle = tokio::spawn(async move { + let stream = TcpStream::connect(format!("127.0.0.1:{}", 12352)) + .await + .unwrap(); + let (mut reader1, mut writer1) = stream.into_split(); + + let valid_join = ClientMessage::Join { + username: "validuser".to_string(), + }; + + write_message_to_stream(&mut writer1, &valid_join) + .await + .unwrap(); + + let mut buffer = vec![0u8; 4096]; + let response1: Result = timeout( + Duration::from_secs(3), + read_message_from_stream(&mut reader1, &mut buffer), + ) + .await + .unwrap_or(Err(ApplicationError::ClientReadStreamClosed)); + + let _ = writer1.shutdown().await; + + assert!(response1.is_ok(), "Valid username should be accepted"); + + drop(reader1); + }); + + tokio::time::sleep(Duration::from_millis(50)).await; + + let client2_handle = tokio::spawn(async move { + let stream2 = TcpStream::connect(format!("127.0.0.1:{}", 12352)) + .await + .unwrap(); + let (mut reader2, mut writer2) = stream2.into_split(); + + let invalid_join = ClientMessage::Join { + username: "".to_string(), + }; + + write_message_to_stream(&mut writer2, &invalid_join) + .await + .unwrap(); + + let mut buffer2 = vec![0u8; 4096]; + let response2: Result = timeout( + Duration::from_secs(3), + read_message_from_stream(&mut reader2, &mut buffer2), + ) + .await + .unwrap_or(Err(ApplicationError::ClientReadStreamClosed)); + + if let Ok(resp2) = response2 + && let ServerMessage::Error { reason } = resp2 + { + assert!(reason.contains("username") || reason.contains("invalid")); + } + + let _ = writer2.shutdown().await; + + drop(reader2); + }); + + let _ = client1_handle.await; + let _ = client2_handle.await; + + let _ = shutdown_tx.send(()); + let result = timeout(Duration::from_secs(5), server_handle).await; + + match result { + Ok(_) => {} + Err(_) => panic!("Server shutdown timed out"), + } + + Ok(()) +} diff --git a/developer-tools/git/hooks/commit-msg b/developer-tools/git/hooks/commit-msg new file mode 100644 index 0000000..5e1a4ec --- /dev/null +++ b/developer-tools/git/hooks/commit-msg @@ -0,0 +1,47 @@ +#!/bin/sh + +# Conventional Commit message validation hook +# This hook will validate that commit messages follow the conventional commit format + +commit_msg=$(cat "$1") + +# Check if message is empty +if [ -z "$commit_msg" ]; then + echo "โŒ Commit message cannot be empty" + exit 1 +fi + +# Extract the first line (header) +commit_header=$(echo "$commit_msg" | head -n1) + +# Check conventional commit format: type(scope): description +if ! echo "$commit_header" | grep -qE '^(feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert)\([a-z0-9\-]+\): [a-z].*'; then + echo "โŒ Commit message must follow conventional commit format:" + echo " (): " + echo "" + echo " Types: feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert" + echo " Example: feat(auth): add login functionality" + echo "" + echo " Current message: $commit_header" + exit 1 +fi + +# Check subject length (should be <= 50 characters) +subject_length=$(echo "$commit_header" | wc -c) +if [ "$subject_length" -gt 51 ]; then + echo "โŒ Subject line must be 50 characters or less" + echo " Current length: $((subject_length - 1)) characters" + echo " Current message: $commit_header" + exit 1 +fi + +# Check for proper capitalization (should start with lowercase) +first_char=$(echo "$commit_header" | cut -c1) +if [[ "$first_char" =~ [A-Z] ]]; then + echo "โŒ Subject line should start with lowercase letter" + echo " Current message: $commit_header" + exit 1 +fi + +echo "โœ… Commit message format is valid" +exit 0 diff --git a/developer-tools/git/hooks/pre-commit b/developer-tools/git/hooks/pre-commit new file mode 100755 index 0000000..986c49f --- /dev/null +++ b/developer-tools/git/hooks/pre-commit @@ -0,0 +1,84 @@ +#!/bin/bash + +set -e + +echo "๐Ÿ” Running pre-commit checks ..." + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +# Function to print colored output +print_status() { + echo -e "${GREEN}โœ“${NC} $1" +} + +print_warning() { + echo -e "${YELLOW}โš ${NC} $1" +} + +print_error() { + echo -e "${RED}โœ—${NC} $1" +} + +# Check if rust is installed +if ! command -v cargo &> /dev/null; then + print_error "Cargo/Rust not found. Please install Rust: https://rustup.rs/" + exit 1 +fi + +# Get list of staged Rust files +RUST_FILES=$(git diff --cached --name-only --diff-filter=ACM | grep -E '\.rs$' || true) + +if [ -z "$RUST_FILES" ]; then + print_warning "No Rust files staged for commit" + exit 0 +fi + +echo "๐Ÿ“ Staged Rust files:" +echo "$RUST_FILES" | sed 's/^/ /' + +# Check formatting +echo "" +echo "๐Ÿ”ค Checking code formatting..." +if ! cargo fmt --all -- --check; then + print_error "Code formatting issues found. Run 'cargo fmt --all' to fix them." + echo "" + echo "๐Ÿ’ก To automatically format your code:" + echo " cargo fmt --all" + exit 1 +fi +print_status "Code formatting is correct" + +# Check for compilation errors +echo "" +echo "๐Ÿ”จ Checking compilation..." +if ! cargo check --all --all-features; then + print_error "Compilation failed. Please fix compilation errors." + exit 1 +fi +print_status "Code compiles successfully" + +# Run clippy for linting +echo "" +echo "๐Ÿ” Running clippy linter..." +if ! cargo clippy --workspace --all-targets --all-features -- -D warnings; then + print_error "Clippy found issues. Please fix them or use 'allow' if appropriate." + echo "" + echo "๐Ÿ’ก To see clippy suggestions:" + echo " cargo clippy --all --all-features" + exit 1 +fi +print_status "No clippy issues found" + +# Run basic tests on changed files +echo "" +echo "๐Ÿงช Running tests..." +if ! cargo test --all --all-features; then + print_error "Tests failed. Please fix failing tests." + exit 1 +fi +print_status "All tests pass" +exit 0 diff --git a/developer-tools/git/template/.gitmessage b/developer-tools/git/template/.gitmessage new file mode 100644 index 0000000..ae5de6d --- /dev/null +++ b/developer-tools/git/template/.gitmessage @@ -0,0 +1,41 @@ +# (): +# +# +# +#