diff --git a/Cargo.lock b/Cargo.lock index d09e201..00a9b0e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -568,6 +568,21 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + [[package]] name = "crossterm" version = "0.25.0" @@ -1544,6 +1559,30 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" +[[package]] +name = "httptest" +version = "0.16.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bde82de3ef9bd882493c6a5edbc3363ad928925b30ccecc0f2ddeb42601b3021" +dependencies = [ + "bstr", + "bytes", + "crossbeam-channel", + "form_urlencoded", + "futures", + "http 1.3.1", + "http-body-util", + "hyper 1.6.0", + "hyper-util", + "log", + "once_cell", + "regex", + "serde", + "serde_json", + "serde_urlencoded", + "tokio", +] + [[package]] name = "hyper" version = "0.14.32" @@ -1581,6 +1620,7 @@ dependencies = [ "http 1.3.1", "http-body 1.0.1", "httparse", + "httpdate", "itoa", "pin-project-lite", "smallvec", @@ -2026,6 +2066,7 @@ dependencies = [ "async-trait", "env_logger", "futures-util", + "httptest", "log", "octocrab", "pretty_assertions", diff --git a/k3d/Cargo.toml b/k3d/Cargo.toml index 633859d..1124d75 100644 --- a/k3d/Cargo.toml +++ b/k3d/Cargo.toml @@ -6,27 +6,17 @@ readme.workspace = true license.workspace = true [dependencies] -#serde = { version = "1.0.123", features = [ "derive" ] } log = { workspace = true } -env_logger = { workspace = true } -#russh = { workspace = true } -#russh-keys = { workspace = true } -#thiserror = "1.0" async-trait = { workspace = true } tokio = { workspace = true } octocrab = "0.44.0" regex = "1.11.1" reqwest = { version = "0.12", features = ["stream"] } -#hyper-rustls = "0.27.5" -#hyper = { version = "1", features = [ "client" ] } -#hyper = { version = "1", features = ["full"] } -#http-body-util = "0.1" -#hyper-util = { version = "0.1", features = ["full"] } url.workspace = true sha2 = "0.10.8" futures-util = "0.3.31" -#bytes = "1.10.1" -#serde_json = "1.0.133" [dev-dependencies] +env_logger = { workspace = true } +httptest = "0.16.3" pretty_assertions = "1.4.1" diff --git a/k3d/src/downloadable_asset.rs b/k3d/src/downloadable_asset.rs index 53de329..ababc77 100644 --- a/k3d/src/downloadable_asset.rs +++ b/k3d/src/downloadable_asset.rs @@ -8,6 +8,33 @@ use tokio::fs::File; use tokio::io::AsyncWriteExt; use url::Url; +const CHECKSUM_FAILED_MSG: &str = "Downloaded file failed checksum verification"; + +/// Represents an asset that can be downloaded from a URL with checksum verification. +/// +/// This struct facilitates secure downloading of files from remote URLs by +/// verifying the integrity of the downloaded content using SHA-256 checksums. +/// It handles downloading the file, saving it to disk, and verifying the checksum matches +/// the expected value. +/// +/// # Examples +/// +/// ```compile_fail +/// # use url::Url; +/// # use std::path::PathBuf; +/// +/// # async fn example() -> Result<(), String> { +/// let asset = DownloadableAsset { +/// url: Url::parse("https://example.com/file.zip").unwrap(), +/// file_name: "file.zip".to_string(), +/// checksum: "a1b2c3d4e5f6...".to_string(), +/// }; +/// +/// let download_dir = PathBuf::from("/tmp/downloads"); +/// let file_path = asset.download_to_path(download_dir).await?; +/// # Ok(()) +/// # } +/// ``` #[derive(Debug)] pub(crate) struct DownloadableAsset { pub(crate) url: Url, @@ -55,6 +82,30 @@ impl DownloadableAsset { calculated_hash == self.checksum } + /// Downloads the asset to the specified directory, verifying its checksum. + /// + /// This function will: + /// 1. Create the target directory if it doesn't exist + /// 2. Check if the file already exists with the correct checksum + /// 3. If not, download the file from the URL + /// 4. Verify the downloaded file's checksum matches the expected value + /// + /// # Arguments + /// + /// * `folder` - The directory path where the file should be saved + /// + /// # Returns + /// + /// * `Ok(PathBuf)` - The path to the downloaded file on success + /// * `Err(String)` - A descriptive error message if the download or verification fails + /// + /// # Errors + /// + /// This function will return an error if: + /// - The network request fails + /// - The server responds with a non-success status code + /// - Writing to disk fails + /// - The checksum verification fails pub(crate) async fn download_to_path(&self, folder: PathBuf) -> Result { if !folder.exists() { fs::create_dir_all(&folder) @@ -101,7 +152,7 @@ impl DownloadableAsset { drop(file); if !self.verify_checksum(target_file_path.clone()) { - panic!("Downloaded file failed checksum verification"); + return Err(CHECKSUM_FAILED_MSG.to_string()); } info!( @@ -115,54 +166,20 @@ impl DownloadableAsset { #[cfg(test)] mod tests { use super::*; - use std::io::Write; - use std::net::TcpListener; - use std::sync::OnceLock; - use std::thread; + use httptest::{ + matchers::{self, request}, + responders, Expectation, Server, + }; const BASE_TEST_PATH: &str = "/tmp/harmony-test-k3d-download"; - const TEST_SERVER_PORT: u16 = 18452; const TEST_CONTENT: &str = "This is a test file."; const TEST_CONTENT_HASH: &str = "f29bc64a9d3732b4b9035125fdb3285f5b6455778edca72414671e0ca3b2e0de"; - struct TestContext { - download_path: String, - domain: String, - } - - static TEST_SERVER: OnceLock<()> = OnceLock::new(); - - fn init_logs() { + fn setup_test() -> (PathBuf, Server) { let _ = env_logger::builder().try_init(); - } - - fn setup_test() -> TestContext { - init_logs(); - - TEST_SERVER.get_or_init(|| { - let listener = TcpListener::bind(format!("127.0.0.1:{}", TEST_SERVER_PORT)).unwrap(); - - thread::spawn(move || { - for stream in listener.incoming() { - thread::spawn(move || { - let mut stream = stream.expect("Stream opened correctly"); - let mut buffer = [0; 1024]; - let _ = stream.read(&mut buffer); - - let response = format!( - "HTTP/1.1 200 OK\r\nContent-Type: application/octet-stream\r\nContent-Length: {}\r\n\r\n{}", - TEST_CONTENT.len(), - TEST_CONTENT - ); - - stream.write_all(response.as_bytes()).expect("Can write to stream"); - stream.flush().expect("Can flush stream"); - }); - } - }); - }); + // Create unique test directory let test_id = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() @@ -170,51 +187,44 @@ mod tests { let download_path = format!("{}/test_{}", BASE_TEST_PATH, test_id); std::fs::create_dir_all(&download_path).unwrap(); - assert!(wait_for_server_ready(1000), "Test server failed to start"); - - TestContext { - download_path, - domain: format!("127.0.0.1:{}", TEST_SERVER_PORT), - } - } - - fn wait_for_server_ready(timeout_ms: u64) -> bool { - let start = std::time::Instant::now(); - let timeout = std::time::Duration::from_millis(timeout_ms); - - while start.elapsed() < timeout { - if std::net::TcpStream::connect(format!("127.0.0.1:{}", TEST_SERVER_PORT)).is_ok() { - return true; - } - std::thread::sleep(std::time::Duration::from_millis(50)); - } - false + (PathBuf::from(download_path), Server::run()) } #[tokio::test] async fn test_download_to_path_success() { - let test = setup_test(); + let (folder, server) = setup_test(); + + server.expect( + Expectation::matching(request::method_path("GET", "/test.txt")) + .respond_with(responders::status_code(200).body(TEST_CONTENT)), + ); let asset = DownloadableAsset { - url: Url::parse(&format!("http://{}/test.txt", test.domain)).unwrap(), + url: Url::parse(&server.url("/test.txt").to_string()).unwrap(), file_name: "test.txt".to_string(), checksum: TEST_CONTENT_HASH.to_string(), }; - let folder = PathBuf::from(&test.download_path); - let result = asset.download_to_path(folder).await.unwrap(); - + let result = asset + .download_to_path(folder.join("success")) + .await + .unwrap(); let downloaded_content = std::fs::read_to_string(result).unwrap(); assert_eq!(downloaded_content, TEST_CONTENT); } #[tokio::test] async fn test_download_to_path_already_exists() { - let test = setup_test(); - let folder = PathBuf::from(&test.download_path); + let (folder, server) = setup_test(); + + server.expect( + Expectation::matching(matchers::any()) + .times(0) + .respond_with(responders::status_code(200).body(TEST_CONTENT)), + ); let asset = DownloadableAsset { - url: Url::parse(&format!("http://{}/test.txt", test.domain)).unwrap(), + url: Url::parse(&server.url("/test.txt").to_string()).unwrap(), file_name: "test.txt".to_string(), checksum: TEST_CONTENT_HASH.to_string(), }; @@ -228,18 +238,66 @@ mod tests { } #[tokio::test] - async fn test_download_to_path_failure() { - let test = setup_test(); + async fn test_download_to_path_server_error() { + let (folder, server) = setup_test(); + + server.expect( + Expectation::matching(matchers::any()).respond_with(responders::status_code(404)), + ); let asset = DownloadableAsset { - url: Url::parse("http://127.0.0.1:9999/test.txt").unwrap(), + url: Url::parse(&server.url("/test.txt").to_string()).unwrap(), file_name: "test.txt".to_string(), - checksum: "some_checksum".to_string(), + checksum: TEST_CONTENT_HASH.to_string(), }; - let result = asset - .download_to_path(PathBuf::from(&test.download_path)) - .await; + let result = asset.download_to_path(folder.join("error")).await; assert!(result.is_err()); + assert!(result.unwrap_err().contains("status: 404")); + } + + #[tokio::test] + async fn test_download_to_path_checksum_failure() { + let (folder, server) = setup_test(); + + let invalid_content = "This is NOT the expected content"; + server.expect( + Expectation::matching(matchers::any()) + .respond_with(responders::status_code(200).body(invalid_content)), + ); + + let asset = DownloadableAsset { + url: Url::parse(&server.url("/test.txt").to_string()).unwrap(), + file_name: "test.txt".to_string(), + checksum: TEST_CONTENT_HASH.to_string(), + }; + + let join_handle = + tokio::spawn(async move { asset.download_to_path(folder.join("failure")).await }); + + assert_eq!( + join_handle.await.unwrap().err().unwrap(), + CHECKSUM_FAILED_MSG + ); + } + + #[tokio::test] + async fn test_download_with_specific_path_matcher() { + let (folder, server) = setup_test(); + + server.expect( + Expectation::matching(matchers::request::path("/specific/path.txt")) + .respond_with(responders::status_code(200).body(TEST_CONTENT)), + ); + + let asset = DownloadableAsset { + url: Url::parse(&server.url("/specific/path.txt").to_string()).unwrap(), + file_name: "path.txt".to_string(), + checksum: TEST_CONTENT_HASH.to_string(), + }; + + let result = asset.download_to_path(folder).await.unwrap(); + let downloaded_content = std::fs::read_to_string(result).unwrap(); + assert_eq!(downloaded_content, TEST_CONTENT); } }