use futures_util::StreamExt; use log::{debug, info, warn}; use sha2::{Digest, Sha256}; use std::io::Read; use std::path::PathBuf; use tokio::fs; use tokio::fs::File; use tokio::io::AsyncWriteExt; use url::Url; #[derive(Debug)] pub(crate) struct DownloadableAsset { pub(crate) url: Url, pub(crate) file_name: String, pub(crate) checksum: String, } impl DownloadableAsset { fn verify_checksum(&self, file: PathBuf) -> bool { if !file.exists() { warn!("File does not exist: {:?}", file); return false; } let mut file = match std::fs::File::open(&file) { Ok(file) => file, Err(e) => { warn!("Failed to open file for checksum verification: {:?}", e); return false; } }; let mut hasher = Sha256::new(); let mut buffer = [0; 1024 * 1024]; // 1MB buffer loop { let bytes_read = match file.read(&mut buffer) { Ok(0) => break, Ok(n) => n, Err(e) => { warn!("Error reading file for checksum: {:?}", e); return false; } }; hasher.update(&buffer[..bytes_read]); } let result = hasher.finalize(); let calculated_hash = format!("{:x}", result); debug!("Expected checksum: {}", self.checksum); debug!("Calculated checksum: {}", calculated_hash); calculated_hash == self.checksum } pub(crate) async fn download_to_path(&self, folder: PathBuf) -> Result { if !folder.exists() { fs::create_dir_all(&folder) .await .expect("Failed to create download directory"); } let target_file_path = folder.join(&self.file_name); debug!("Downloading to path: {:?}", target_file_path); if self.verify_checksum(target_file_path.clone()) { debug!("File already exists with correct checksum, skipping download"); return Ok(target_file_path); } debug!("Downloading from URL: {}", self.url); let client = reqwest::Client::new(); let response = client .get(self.url.clone()) .send() .await .map_err(|e| format!("Failed to download file: {e}"))?; if !response.status().is_success() { return Err(format!( "Failed to download file, status: {}", response.status() )); } let mut file = File::create(&target_file_path) .await .expect("Failed to create target file"); let mut stream = response.bytes_stream(); while let Some(chunk_result) = stream.next().await { let chunk = chunk_result.expect("Error while downloading file"); file.write_all(&chunk) .await .expect("Failed to write data to file"); } file.flush().await.expect("Failed to flush file"); drop(file); if !self.verify_checksum(target_file_path.clone()) { panic!("Downloaded file failed checksum verification"); } info!( "File downloaded and verified successfully: {}", target_file_path.to_string_lossy() ); Ok(target_file_path) } } #[cfg(test)] mod tests { use super::*; use std::io::Write; use std::net::TcpListener; use std::sync::OnceLock; use std::thread; const BASE_TEST_PATH: &str = "/tmp/harmony-test-k3d-download"; const TEST_SERVER_PORT: u16 = 18452; const TEST_CONTENT: &str = "This is a test file."; const TEST_CONTENT_HASH: &str = "f29bc64a9d3732b4b9035125fdb3285f5b6455778edca72414671e0ca3b2e0de"; struct TestContext { download_path: String, domain: String, } static TEST_SERVER: OnceLock<()> = OnceLock::new(); fn init_logs() { let _ = env_logger::builder().try_init(); } fn setup_test() -> TestContext { init_logs(); TEST_SERVER.get_or_init(|| { let listener = TcpListener::bind(format!("127.0.0.1:{}", TEST_SERVER_PORT)).unwrap(); thread::spawn(move || { for stream in listener.incoming() { thread::spawn(move || { let mut stream = stream.expect("Stream opened correctly"); let mut buffer = [0; 1024]; let _ = stream.read(&mut buffer); let response = format!( "HTTP/1.1 200 OK\r\nContent-Type: application/octet-stream\r\nContent-Length: {}\r\n\r\n{}", TEST_CONTENT.len(), TEST_CONTENT ); stream.write_all(response.as_bytes()).expect("Can write to stream"); stream.flush().expect("Can flush stream"); }); } }); }); let test_id = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() .as_millis(); let download_path = format!("{}/test_{}", BASE_TEST_PATH, test_id); std::fs::create_dir_all(&download_path).unwrap(); assert!(wait_for_server_ready(1000), "Test server failed to start"); TestContext { download_path, domain: format!("127.0.0.1:{}", TEST_SERVER_PORT), } } fn wait_for_server_ready(timeout_ms: u64) -> bool { let start = std::time::Instant::now(); let timeout = std::time::Duration::from_millis(timeout_ms); while start.elapsed() < timeout { if std::net::TcpStream::connect(format!("127.0.0.1:{}", TEST_SERVER_PORT)).is_ok() { return true; } std::thread::sleep(std::time::Duration::from_millis(50)); } false } #[tokio::test] async fn test_download_to_path_success() { let test = setup_test(); let asset = DownloadableAsset { url: Url::parse(&format!("http://{}/test.txt", test.domain)).unwrap(), file_name: "test.txt".to_string(), checksum: TEST_CONTENT_HASH.to_string(), }; let folder = PathBuf::from(&test.download_path); let result = asset.download_to_path(folder).await.unwrap(); let downloaded_content = std::fs::read_to_string(result).unwrap(); assert_eq!(downloaded_content, TEST_CONTENT); } #[tokio::test] async fn test_download_to_path_already_exists() { let test = setup_test(); let folder = PathBuf::from(&test.download_path); let asset = DownloadableAsset { url: Url::parse(&format!("http://{}/test.txt", test.domain)).unwrap(), file_name: "test.txt".to_string(), checksum: TEST_CONTENT_HASH.to_string(), }; let target_file_path = folder.join(&asset.file_name); std::fs::write(&target_file_path, TEST_CONTENT).unwrap(); let result = asset.download_to_path(folder).await.unwrap(); let content = std::fs::read_to_string(result).unwrap(); assert_eq!(content, TEST_CONTENT); } #[tokio::test] async fn test_download_to_path_failure() { let test = setup_test(); let asset = DownloadableAsset { url: Url::parse("http://127.0.0.1:9999/test.txt").unwrap(), file_name: "test.txt".to_string(), checksum: "some_checksum".to_string(), }; let result = asset .download_to_path(PathBuf::from(&test.download_path)) .await; assert!(result.is_err()); } }