diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..1c1f508 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,9 @@ +[workspace] +resolver = "2" + +members = [ + "server", + "cli-client", + "common", + "integration_tests" +] diff --git a/README.md b/README.md index 8c4d4e1..7f0800f 100644 --- a/README.md +++ b/README.md @@ -69,3 +69,17 @@ without error, and is free of clippy errors. send a message to the server from the client. Make sure that niether the server or client exit with a failure. This action should be run anytime new code is pushed to a branch or landed on the main branch. + +## Start a server +`cargo run -p server -- --ip 127.0.0.1 --port 8090` + +## Start a client +`cargo run -p cli-client -- --host 127.0.0.1 --port 8090 --username ` +or +`SIMPLE_CHAT_SERVER_HOST=127.0.0.1 SIMPLE_CHAT_SERVER_PORT=8090 cargo run -p cli-client -- --username ` + +## Run tests +`cargo test` + +## Demo video +https://github.com/user-attachments/assets/a4c163e2-d29b-40ab-97e2-037e86164486 diff --git a/cli-client/Cargo.toml b/cli-client/Cargo.toml new file mode 100644 index 0000000..08b7ea0 --- /dev/null +++ b/cli-client/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "cli-client" +version = "0.1.0" +edition = "2021" + +[dependencies] +tokio = { version = "1.41.0", features = ["full"] } +clap = { version = "4.5.20", features = ["derive"] } +common = { path = "../common" } \ No newline at end of file diff --git a/cli-client/src/client.rs b/cli-client/src/client.rs new file mode 100644 index 0000000..1b61662 --- /dev/null +++ b/cli-client/src/client.rs @@ -0,0 +1,145 @@ +//! This module contains types and functions to connect with the server. + +use std::io; + +use common::{extract_parts, messages}; +use tokio::{ + io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter}, + net::TcpStream, +}; + +/// A struct to encapsulate functionalities related to connect and messaging with the server. +pub struct Client { + /// Server host + pub host: String, + /// Server port + pub port: String, + /// Username representing human that uses this client + pub username: String, +} + +/// Enum to represent the commands that can be entered by the user +#[derive(PartialEq)] +enum ConsoleCommand { + /// Command to leave the chat room + Leave, + /// Command to send a message to the chat room + Send, + InvalidCommand, +} + +impl From for ConsoleCommand { + fn from(str: String) -> Self { + match str.as_str() { + "leave" => ConsoleCommand::Leave, + "send" => ConsoleCommand::Send, + _ => ConsoleCommand::InvalidCommand, + } + } +} + +impl From<&str> for ConsoleCommand { + fn from(command_str: &str) -> Self { + match command_str.to_lowercase().as_str() { + "leave" => ConsoleCommand::Leave, + "send" => ConsoleCommand::Send, + _ => ConsoleCommand::InvalidCommand, + } + } +} + +impl Client { + pub fn new(host: String, port: String, username: String) -> Self { + Client { + host, + port, + username, + } + } + /// Starts the connection with the server and handles the communication between the client and the server. + pub async fn start(&self) -> io::Result<()> { + // Connect to the server + let mut stream = TcpStream::connect(format!("{}:{}", self.host, self.port)).await?; + + // Disable Nagle's algorithm to send data immediately + stream.set_nodelay(true).unwrap(); + let (reader, writer) = stream.split(); + + // Create a buffered writer and reader for network communication + let mut writer = BufWriter::new(writer); + let mut reader = BufReader::new(reader); + let mut line = String::new(); // Buffer to store received data + + // Join the default room using supplied username + writer + .write_all(format!("<{}> {}\n", messages::JOIN_USER, self.username).as_bytes()) + .await + .expect("ERROR: Unable to write to server"); + writer.flush().await.expect("ERROR: Unable to flush writer"); + + reader + .read_line(&mut line) + .await + .expect("ERROR: Unable to read from server"); + + let (command, _, _) = extract_parts(&line); + if command == messages::DUPLICATE_USER { + eprintln!( + "ERROR: Username already in use. Please try again with a different username.\n" + ); + return Ok(()); + } + + let mut input = String::new(); + + // Create a buffered writer and reader for stdin/stdout communication + let mut console_reader = BufReader::new(tokio::io::stdin()); + + line.clear(); + loop { + tokio::select! { + _result = console_reader.read_line(&mut input) => { + + // Handle sending here + input = input.trim().to_string(); + if input.is_empty() || input == "\n" { + continue; + } + + let user_input = input.split(" ").collect::>(); + + // Extract command from user input + let original_command = user_input[0].to_lowercase(); + let command = ConsoleCommand::from(original_command.clone()); + + if command == ConsoleCommand::Leave { + writer.write_all(format!("<{}> {}\n", messages::LEAVE_USER, self.username).as_bytes()).await.expect("Unable to write to server"); + writer.flush().await.expect("Unable to write to server"); + return Ok(()); + } else if command == ConsoleCommand::Send{ + let usr_msg = &input[original_command.len() + 1..]; + let usr_msg = format!("<{}> {} {}\n", messages::USER_MSG, self.username, usr_msg); + writer.write_all(usr_msg.as_bytes()).await.expect("Unable to write to server"); + writer.flush().await.expect("Unable to write to server"); + } + + input.clear(); + } + + result = reader.read_line(&mut line) => { + if result.expect("ERROR: Unable to read from server") == 0 { + eprintln!("Server closed the connection."); + return Ok(()); + } + let (command, username, data) = extract_parts(&line); + if command == messages::USER_MSG { + println!("{}> {}", username, data); + } else if command == messages::INVALID_CMD { + eprintln!("ERROR: Invalid command received from server"); + } + line.clear(); + } + } + } + } +} diff --git a/cli-client/src/main.rs b/cli-client/src/main.rs new file mode 100644 index 0000000..2abd21a --- /dev/null +++ b/cli-client/src/main.rs @@ -0,0 +1,48 @@ +use std::{env, io, process::exit}; + +use clap::Parser; +use client::Client; + +mod client; + +/// Struct to represent command line args +#[derive(Clone, Debug, Parser)] +#[command(version, about, long_about = None)] +struct Args { + /// A uniqie username + #[arg(short, long)] + username: String, + + /// Host or IP to listen to + #[arg(short = 'o', long)] + host: Option, + + /// Port + #[arg(short, long)] + port: Option, +} + +#[tokio::main] +async fn main() -> io::Result<()> { + // Arg parsing. Priority is given to the command line arguments if provided. + let args = Args::parse(); + + let mut host = env::var("SIMPLE_CHAT_SERVER_HOST").unwrap_or_default(); + let mut port = env::var("SIMPLE_CHAT_SERVER_PORT").unwrap_or_default(); + + host = args.host.unwrap_or(host); + port = args.port.unwrap_or(port); + let username = args.username; + + let client = Client::new(host, port, username); + + // Return the appripriate code based on error. + match client.start().await { + Ok(_) => {} + Err(e) => { + eprintln!("=>ERROR: {}", e); + exit(1); + } + }; + Ok(()) +} diff --git a/common/Cargo.toml b/common/Cargo.toml new file mode 100644 index 0000000..e020b17 --- /dev/null +++ b/common/Cargo.toml @@ -0,0 +1,6 @@ +[package] +name = "common" +version = "0.1.0" +edition = "2021" + +[dependencies] diff --git a/common/src/lib.rs b/common/src/lib.rs new file mode 100644 index 0000000..d6d692a --- /dev/null +++ b/common/src/lib.rs @@ -0,0 +1,81 @@ +//! This module contains functionalities which are common to both the server and the client. + +pub mod messages; + +// Extract various parts from the string message +pub fn extract_parts(line: &str) -> (u16, String, String) { + let lines = line.split(" ").collect::>(); + let command = lines[0].trim().to_lowercase(); + + let mut data = String::new(); + let mut username = String::new(); + + match lines.len() { + 2 => { + data = line[command.len()..].trim().to_string(); + } + 3.. => { + username = lines[1].trim().to_lowercase(); + data = line[command.len() + username.len() + 2..] + .trim() + .to_string(); + } + _ => {} + } + + let command = command + .strip_prefix("<") + .unwrap() + .strip_suffix(">") + .unwrap(); + // We can trust that command is something we can parse to u16. So we can use unwrap() here safely. + (command.parse::().unwrap(), username, data) +} + +#[cfg(test)] +mod tests { + + use super::*; + + // Test extraction of message with more than three parts. This should return 3 parts. + #[test] + fn test_extract_more_than_three_parts() { + let input = "<107> testuser Hey this is sample message from a testuser"; + let result = extract_parts(input); + assert_eq!(result.0, 107); + assert_eq!(result.1, "testuser"); + assert_eq!(result.2, "Hey this is sample message from a testuser"); + } + + // Test extraction of message with three parts + #[test] + fn test_extract_three_parts() { + let input = "<107> testuser Hey"; + let result = extract_parts(input); + assert_eq!(result.0, 107); + assert_eq!(result.1, "testuser"); + assert_eq!(result.2, "Hey"); + } + + // Test extraction of message with two parts + #[test] + fn test_extract_two_parts() { + let input = "<101> testuser"; + let result = extract_parts(input); + println!("{:?}", &result); + assert_eq!(result.0, 101); + assert_eq!(result.1, ""); + assert_eq!(result.2, "testuser"); + } + + // Test extraction of message with two parts + #[test] + fn test_extract_one_part() { + let input = "<103>"; + let result = extract_parts(input); + println!("{:?}", &result); + assert_eq!(result.0, 103); + assert_eq!(result.1, ""); + assert_eq!(result.2, ""); + } +} diff --git a/common/src/messages.rs b/common/src/messages.rs new file mode 100644 index 0000000..d8b6502 --- /dev/null +++ b/common/src/messages.rs @@ -0,0 +1,10 @@ +//! This module conntains messages used for communication between Client and the Server. + +pub const JOIN_USER: u16 = 101; +pub const USER_JOINED: u16 = 102; +pub const LEAVE_USER: u16 = 103; +pub const USER_LEFT: u16 = 104; +pub const DUPLICATE_USER: u16 = 105; +pub const INVALID_CMD: u16 = 106; +pub const USER_MSG: u16 = 107; +pub const WELCOME_MSG: u16 = 108; diff --git a/integration_tests/Cargo.toml b/integration_tests/Cargo.toml new file mode 100644 index 0000000..018f300 --- /dev/null +++ b/integration_tests/Cargo.toml @@ -0,0 +1,6 @@ +[package] +name = "integration_tests" +version = "0.1.0" +edition = "2021" + +[dependencies] diff --git a/integration_tests/src/lib.rs b/integration_tests/src/lib.rs new file mode 100644 index 0000000..913adce --- /dev/null +++ b/integration_tests/src/lib.rs @@ -0,0 +1,189 @@ +//! This module contains tests for the integration of the server and the client. + +#[cfg(test)] +mod tests { + use std::{ + io::{BufRead, BufReader, Read, Write}, + net::TcpStream, + process::{Command, Stdio}, + thread::sleep, + time::Duration, + }; + + #[test] + fn test_user_exists() { + // Start the server + let mut server = Command::new("../target/release/server") + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .spawn() + .expect("Failed to start server"); + + // Wait until the server is ready + let mut attempts = 0; + while attempts < 10 { + if TcpStream::connect("127.0.0.1:8090").is_ok() { + println!("Server is ready."); + break; + } + attempts += 1; + sleep(Duration::from_secs(1)); + } + + // Run the client + let mut client1 = Command::new("../target/release/cli-client") + .args(["--username", "user1"]) + .args(["--host", "127.0.0.1"]) + .args(["--port", "8090"]) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .spawn() + .expect("Failed to start client1"); + + let mut client2 = Command::new("../target/release/cli-client") + .args(["--username", "user1"]) + .args(["--host", "127.0.0.1"]) + .args(["--port", "8090"]) + .stdout(Stdio::null()) + .stderr(Stdio::piped()) + .spawn() + .expect("Failed to start client2"); + + // Capture the output of client2 + let mut client2_stderr = String::new(); + if let Some(stdout) = client2.stderr.take() { + let mut reader = BufReader::new(stdout); + reader + .read_to_string(&mut client2_stderr) + .expect("Failed to read stderr"); + } + + // Assert the output + assert_eq!( + client2_stderr.trim(), + "ERROR: Username already in use. Please try again with a different username." + ); + + server.kill().expect("Failed to kill server"); + client1.kill().expect("Failed to kill client1"); + // server.wait().expect("Failed to wait for server"); + client2.kill().expect("Failed to kill client2"); + } + + // Attempted but not working as expected. + #[ignore] + #[test] + fn test_messaging() { + // Start the server + println!("Starting server"); + let mut server = Command::new("../target/release/server") + .args(["--port", "8091"]) + .stdout(Stdio::null()) + .spawn() + .expect("Failed to start server"); + + // Wait until the server is ready + let mut attempts = 0; + while attempts < 10 { + if TcpStream::connect("127.0.0.1:8091").is_ok() { + println!("Server is ready."); + break; + } + attempts += 1; + sleep(Duration::from_secs(1)); + } + + // Run the client + let mut client1 = Command::new("../target/release/cli-client") + .args(["--username", "user1"]) + .args(["--host", "127.0.0.1"]) + .args(["--port", "8091"]) + .stdout(Stdio::piped()) + .stdin(Stdio::piped()) + .stderr(Stdio::null()) + .spawn() + .expect("Failed to start client1"); + + let mut client2 = Command::new("../target/release/cli-client") + .args(["--username", "user2"]) + .args(["--host", "127.0.0.1"]) + .args(["--port", "8091"]) + .stdout(Stdio::piped()) + .stdin(Stdio::piped()) + .stderr(Stdio::null()) + .spawn() + .expect("Failed to start client2"); + + // Write to Stdins of the clients + println!("Sending messages"); + let client1_stdin = client1 + .stdin + .as_mut() + .expect("Failed to open client1 stdin"); + client1_stdin + .write_all(b"send Hey Everyone. How are you doing?\n") + .expect("Failed to write to client1 stdin"); + client1_stdin + .flush() + .expect("Failed to flush client1 stdin"); + + let client2_stdin = client2 + .stdin + .as_mut() + .expect("Failed to open client2 stdin"); + client2_stdin + .write_all(b"send Hey User1. I am fine. Thanks\n") + .expect("Failed to write to client2 stdin"); + client2_stdin + .flush() + .expect("Failed to flush client1 stdin"); + + client1_stdin + .write_all(b"send Cool!\n") + .expect("Failed to write to client1 stdin"); + client1_stdin + .flush() + .expect("Failed to flush client1 stdin"); + + let mut client1_stdout: Vec = Vec::new(); + let mut client2_stdout: Vec = Vec::new(); + + if let Some(stdout) = client1.stdout.take() { + let reader = BufReader::new(stdout); + for (index, line) in reader.lines().enumerate() { + let line = line.expect("Failed to read line from client1 stdout"); + client1_stdout.push(line); + + // We know that there are two lines so we need to break after the second line. + if index == 1 { + break; + } + } + } + + if let Some(stdout) = client2.stdout.take() { + let reader = BufReader::new(stdout); + for (index, line) in reader.lines().enumerate() { + let line = line.expect("Failed to read line from client2 stdout"); + // println!("{}", &line); + client2_stdout.push(line); + + // We know that there are two lines so we need to break after the second line. + if index == 1 { + break; + } + } + } + + // Assert the output for client1 + assert_eq!(client1_stdout[0], "user2> Hey User1. I am fine. Thanks"); + + // Assert the output for client2 + assert_eq!(client2_stdout[0], "user1> Cool!\n"); + assert_eq!(client2_stdout[1], "user1> Hey Everyone. How are you doing?"); + + client1.kill().expect("Failed to kill client1"); + client2.kill().expect("Failed to kill client2"); + server.kill().expect("Failed to kill server"); + } +} diff --git a/server/Cargo.toml b/server/Cargo.toml new file mode 100644 index 0000000..8590308 --- /dev/null +++ b/server/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "server" +version = "0.1.0" +edition = "2021" + +[dependencies] +tokio = { version = "1.41.0", features = ["full"] } +common = { path = "../common"} +clap = { version = "4.5.20", features = ["derive"] } +tracing = "0.1.40" +tracing-subscriber = { version = "0.3.18", features = ["env-filter"]} diff --git a/server/src/main.rs b/server/src/main.rs new file mode 100644 index 0000000..eb6d99d --- /dev/null +++ b/server/src/main.rs @@ -0,0 +1,49 @@ +use std::process::exit; + +use clap::{command, Parser}; +use server::SimpleChatServer; +use tracing::subscriber; +use tracing_subscriber::EnvFilter; + +mod server; + +/// Struct to represent command line args +#[derive(Clone, Debug, Parser)] +#[command(version, about, long_about = None)] +struct Args { + #[arg(short, long, default_value = "127.0.0.1")] + ip: Option, + + #[arg(short, long, default_value = "8090")] + port: Option, +} + +#[tokio::main] +async fn main() { + let args = Args::parse(); + + let server = + SimpleChatServer::new(format!("{}:{}", args.ip.unwrap(), args.port.unwrap()).to_string()); + + let subscriber = tracing_subscriber::fmt() + .with_thread_ids(true) + .compact() + .with_file(true) + .with_line_number(true) + .with_target(false) + .with_env_filter(EnvFilter::new("info")) + .finish(); + + // Sets this subscriber as the global default for the duration of the entire program. + subscriber::set_global_default(subscriber).expect("Error in setting logging mechanism"); + + match server.start().await { + Ok(_) => { + tracing::info!("Server started successfully"); + } + Err(e) => { + tracing::error!("Error starting server: {}", e); + exit(1); + } + } +} diff --git a/server/src/server.rs b/server/src/server.rs new file mode 100644 index 0000000..710e878 --- /dev/null +++ b/server/src/server.rs @@ -0,0 +1,458 @@ +//! This module contains types and functions to handle messages from the client. + +use std::{collections::HashMap, net::SocketAddr, sync::Arc}; + +use common::{extract_parts, messages}; +use tokio::{ + io::{self, AsyncBufReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter}, + net::{TcpListener, TcpStream}, + sync::{ + mpsc::{channel, Receiver, Sender}, + Mutex, + }, +}; +use tracing::{event, span, Level, Span}; + +/// The server instance +pub struct SimpleChatServer { + /// The address the server is listening on + pub address: Arc, + /// The users connected to the server + pub users: Arc>>>, +} + +impl SimpleChatServer { + /// Create a new server instance + pub fn new(address: String) -> SimpleChatServer { + let address = Arc::new(address); + + SimpleChatServer { + address, + users: Arc::new(Mutex::new(HashMap::new())), + } + } + + /// Start the server + pub async fn start(&self) -> io::Result<()> { + let address = self.address.clone(); + let listener = TcpListener::bind(address.to_string()).await?; + + tracing::info!("Starting server on {}", self.address); + + while let Ok((stream, client_address)) = listener.accept().await { + let users = self.users.clone(); + + tokio::spawn({ + async move { + let _ = Self::handle_connection(stream, client_address, users).await; + } + }); + } + + Ok(()) + } + + // Handles joining of a new user. Returns true/false based on user joining success. + async fn handle_join_command( + users: Arc>>>, + writer: &mut T, + line: &mut String, + my_username: &mut String, + data: &str, + tx: Sender, + parent_span: &Span, + ) -> bool + where + T: AsyncWrite + std::marker::Unpin, + { + let span = span!(parent: parent_span, Level::INFO, "handle_join_command"); + let _guard = span.enter(); + + if users.lock().await.contains_key(data) { + let _ = writer + .write_all(format!("<{}>\n", messages::DUPLICATE_USER).as_bytes()) + .await; + let _ = writer.flush().await; + + event!(Level::WARN, "Duplicate username received: {}", data); + line.clear(); // Clear the buffer for safety. + return false; + } + users.lock().await.insert(data.to_string(), tx.clone()); + *my_username = data.to_string(); + event!(Level::INFO, "User {} has joined the chat", data); + + // Send ACK for user joining. + let _ = writer + .write_all(format!("<{}> {}\n", messages::USER_JOINED, &data).as_bytes()) + .await; + let _ = writer.flush().await; + + line.clear(); // Clear the buffer for safety. + drop(_guard); + true + } + + // Handle leaving of the existing user + async fn handle_leave_command( + users: Arc>>>, + rx: &mut Receiver, + data: &str, + parent_span: &Span, + ) { + let span = span!(parent: parent_span, Level::INFO, "handle_leave_command"); + let _guard = span.enter(); + + users.lock().await.remove(data); + + rx.close(); + event!(Level::INFO, "User {:?} has left the chat", &data); + } + + // Handle user messages + async fn handle_user_messages( + users: Arc>>>, + my_username: &str, + writer: &mut T, + line: &mut String, + data: &str, + sender_username: &str, + parent_span: &Span, + ) where + T: AsyncWrite + std::marker::Unpin, + { + let span = span!(parent: parent_span, Level::INFO, "handle_user_messages"); + let _guard = span.enter(); + + // User should not be able to send message if not joined. + if my_username.is_empty() { + event!( + Level::WARN, + "Connection is trying to send message without joining the chat" + ); + + let _ = writer + .write_all("Please join the chat first.\n".as_bytes()) + .await; + let _ = writer.flush().await; + line.clear(); // Clear the buffer for safety. + return; + } + + event!( + Level::DEBUG, + "Sending message {:?} to all users from {}", + &line, + &sender_username + ); + + for (username, sender) in users.lock().await.iter() { + if username != my_username { + let usr_msg = format!("<{}> {} {}\n", messages::USER_MSG, sender_username, data); + let _ = sender.send(usr_msg).await; + } + } + } + + // Primary function to handle the connection. Called for each new user connection. It manages the user joining, leaving and messages. + async fn handle_connection( + mut stream: TcpStream, + client_address: SocketAddr, + users: Arc>>>, + ) -> io::Result<()> { + let handle_connection_span = + span!(Level::INFO, "handle_connection", client_address = %client_address); + let _guard = handle_connection_span.enter(); + + event!(Level::INFO, "Connected"); + stream.set_nodelay(true)?; + + let (reader, writer) = stream.split(); + let mut reader = BufReader::new(reader); + let mut writer = BufWriter::new(writer); + + // Send welcome message to the client + // let _ = writer + // .write_all(format!("{} \n", messages::WELCOME_MSG).as_bytes()) + // .await; + // let _ = writer.flush().await; + + let (tx, mut rx) = channel::(1000); + + let mut line = String::new(); + let mut my_username = String::new(); + + /* + IMP: + reader.read_line() method expects \n at the end of the message to mark it as line. Without it, it will wait indefinitely. + */ + event!(Level::TRACE, "Starting loop"); + loop { + tokio::select! { + result = reader.read_line(&mut line) => { + match result { + Ok(n) => { + event!(Level::TRACE, "Line received {:?}", &line); + if n == 0 { + event!(Level::TRACE, "Connection closed"); + break Ok(()); + } + + if line.trim() == "\n" { + continue; + } + + let (command, username, data) = extract_parts(&line); + event!(Level::TRACE, "Command: {:?}, Username: {:?}, Data: {:?}", &command, &username, &data); + + // Handle joining of the user + if command == messages::JOIN_USER { + Self::handle_join_command(users.clone(), &mut writer, &mut line, &mut my_username, &data, tx.clone(), &handle_connection_span).await; + } + // Handle leaving of the user + else if command == messages::LEAVE_USER { + Self::handle_leave_command(users.clone(), &mut rx, &data, &handle_connection_span).await; + } + // Handle user messages + else if command == messages::USER_MSG { + Self::handle_user_messages(users.clone(), &my_username, &mut writer, &mut line, &data, &username, &handle_connection_span).await; + } else { + event!(Level::WARN, "Invalid command received: {:?}", &line); + let _ = writer.write_all(format!("{}", messages::INVALID_CMD).as_bytes()).await; + let _ = writer.flush().await; + } + + line.clear(); // Clear the buffer for safety. + + } + Err(e) => { + event!(Level::ERROR, "ERROR: Failed to read from socket; error={:?}", e); + break Ok(()); + } + } + } + result = rx.recv() => { + match result { + Some(msg) => { + // Received message from other user(s). Send it to this user if this user is not the sender. + let _ = writer.write_all(msg.as_bytes()).await; + let _ = writer.flush().await; + } + None => { + event!(Level::ERROR, "Channel closed"); + break Ok(()); + } + } + } + } + } + } +} + +#[cfg(test)] +mod tests { + + use super::*; + + #[tokio::test] + async fn test_handle_leave_command() { + /*********** Preparation **********/ + let data: &str = "testuser"; + let (tx, mut rx) = channel::(1); + let mut users_map = HashMap::new(); + users_map.insert(data.to_string(), tx); + let span = span!(Level::INFO, "test_handle_leave_command"); + let users = Arc::new(Mutex::new(users_map)); + + /*********** Call **********/ + SimpleChatServer::handle_leave_command(users.clone(), &mut rx, data, &span).await; + + /*********** Assertion **********/ + assert!(rx.is_closed()); + assert_eq!(users.lock().await.len(), 0); + } + + // Tests the scenario where a new user joins the chat. The username of a new user doesn't exists. + #[tokio::test] + async fn test_handle_join_command_user_not_exists() { + /*********** Preparation **********/ + let existing_username: &str = "testuser"; + let (tx, _rx) = channel::(1); + + let (client, mut _server) = io::duplex(64); + let mut buf_writer = BufWriter::new(client); + let mut line = String::from("test line from user"); + let mut my_username = String::new(); + let users = Arc::new(Mutex::new(HashMap::new())); + + // This is really not important as we are not using it in the test. + let span = span!(Level::INFO, "test_handle_join_command_user_not_exists"); + + /*********** Call **********/ + SimpleChatServer::handle_join_command( + users.clone(), + &mut buf_writer, + &mut line, + &mut my_username, + existing_username, + tx, + &span, + ) + .await; + + /*********** Assertion **********/ + assert!(users.lock().await.contains_key(existing_username)); + assert_eq!(my_username, existing_username); + assert!(line.is_empty()); + } + + // Tests the scenario where a new user joins the chat. The username of a new user already exists. + #[tokio::test] + async fn test_handle_join_command_user_exists() { + /*********** Preparation **********/ + let existing_username: &str = "testuser"; + let (client, server) = io::duplex(64); + + let mut buf_writer = BufWriter::new(client); + let mut buf_reader = BufReader::new(server); + + let mut line = String::from("test line from user"); + let mut my_username = String::new(); + + // Add user already + let mut users_map: HashMap> = HashMap::new(); + let (tx, _rx) = channel::(1); + users_map.insert(existing_username.to_string(), tx.clone()); + + let users = Arc::new(Mutex::new(users_map)); + + // This is really not important as we are not using it in the test. + let span = span!(Level::INFO, "test_handle_join_command_user_exists"); + + /*********** Call **********/ + SimpleChatServer::handle_join_command( + users.clone(), + &mut buf_writer, + &mut line, + &mut my_username, + existing_username, + tx, + &span, + ) + .await; + + /*********** Assertion **********/ + assert!(my_username.is_empty()); + let mut output = String::new(); + let _ = buf_reader.read_line(&mut output).await; + assert_eq!(output, format!("<{}>\n", messages::DUPLICATE_USER)); + assert!(line.is_empty()); + } + + // Tests the scenario where a connection is trying to send message without joining the chat. + #[tokio::test] + async fn test_handle_user_msgs_without_joining() { + /*********** Preparation **********/ + let (client, server) = io::duplex(64); + + let mut buf_writer = BufWriter::new(client); + let mut buf_reader = BufReader::new(server); + + let mut line = String::from("test line from user"); + + // We send blank username to simulate the scenario where user has not joined the chat. + let mut my_username = String::new(); + let mut output = String::new(); + + let users = Arc::new(Mutex::new(HashMap::new())); + + // This is really not important as we are not using it in the test. + let span = span!(Level::INFO, "test_handle_join_command_user_exists"); + + /*********** Call **********/ + SimpleChatServer::handle_user_messages( + users.clone(), + &mut my_username, + &mut buf_writer, + &mut line, + "", + "", + &span, + ) + .await; + + /*********** Assertion **********/ + let _ = buf_reader.read_line(&mut output).await; + assert_eq!(output, "Please join the chat first.\n"); + assert!(line.is_empty()); + } + + // Tests the scenario where a connection is trying to send message without joining the chat. + #[tokio::test] + async fn test_handle_user_msgs_after_joining() { + /*********** Preparation **********/ + let (client, _) = io::duplex(64); + + let mut buf_writer = BufWriter::new(client); + + let mut line = String::from("test line from user"); + + let mut my_username = String::from("user1"); + let sender_username = "user1"; + + // Insert users to simulate the scenario where users have already joined the chat. + let (tx1, mut rx1) = channel::(5); + let (tx2, mut rx2) = channel::(5); + let (tx3, mut rx3) = channel::(5); + let mut users_map: HashMap> = HashMap::new(); + users_map.insert("user1".to_string(), tx1); + users_map.insert("user2".to_string(), tx2); + users_map.insert("user3".to_string(), tx3); + + let user_msg = "Hello All"; + + let users = Arc::new(Mutex::new(users_map)); + + // This is really not important as we are not using it in the test. + let span = span!(Level::INFO, "test_handle_join_command_user_exists"); + + /*********** Call **********/ + SimpleChatServer::handle_user_messages( + users.clone(), + &mut my_username, + &mut buf_writer, + &mut line, + user_msg, + sender_username, + &span, + ) + .await; + + /*********** Assertion **********/ + // All users except the sender user1 should receive the message. + // Check for user2 + assert_eq!( + rx2.recv().await.unwrap(), + format!( + "<{}> {} {}\n", + messages::USER_MSG, + sender_username, + user_msg + ) + ); + + // Check for user3 + assert_eq!( + rx3.recv().await.unwrap(), + format!( + "<{}> {} {}\n", + messages::USER_MSG, + sender_username, + user_msg + ) + ); + + // users1 should not receive message from herself + assert!(rx1.try_recv().is_err()); + } +}