diff options
Diffstat (limited to 'src-tauri/src/core/downloader.rs')
| -rw-r--r-- | src-tauri/src/core/downloader.rs | 164 |
1 files changed, 126 insertions, 38 deletions
diff --git a/src-tauri/src/core/downloader.rs b/src-tauri/src/core/downloader.rs index 09101c9..d33c44d 100644 --- a/src-tauri/src/core/downloader.rs +++ b/src-tauri/src/core/downloader.rs @@ -2,6 +2,7 @@ use futures::StreamExt; use serde::{Deserialize, Serialize}; use sha1::Digest as Sha1Digest; use std::path::PathBuf; +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; use tauri::{Emitter, Window}; use tokio::io::AsyncWriteExt; @@ -23,6 +24,9 @@ pub struct ProgressEvent { pub downloaded: u64, pub total: u64, pub status: String, // "Downloading", "Verifying", "Finished", "Error" + pub completed_files: usize, + pub total_files: usize, + pub total_downloaded_bytes: u64, } /// calculate SHA256 hash of data @@ -51,9 +55,96 @@ pub fn verify_checksum(data: &[u8], sha256: Option<&str>, sha1: Option<&str>) -> true } -pub async fn download_files(window: Window, tasks: Vec<DownloadTask>) -> Result<(), String> { +/// Snapshot of global progress state +struct ProgressSnapshot { + completed_files: usize, + total_files: usize, + total_downloaded_bytes: u64, +} + +/// Centralized progress tracking with atomic counters +struct GlobalProgress { + completed_files: AtomicUsize, + total_downloaded_bytes: AtomicU64, + total_files: usize, +} + +impl GlobalProgress { + fn new(total_files: usize) -> Self { + Self { + completed_files: AtomicUsize::new(0), + total_downloaded_bytes: AtomicU64::new(0), + total_files, + } + } + + /// Get current progress snapshot without modification + fn snapshot(&self) -> ProgressSnapshot { + ProgressSnapshot { + completed_files: self.completed_files.load(Ordering::Relaxed), + total_files: self.total_files, + total_downloaded_bytes: self.total_downloaded_bytes.load(Ordering::Relaxed), + } + } + + /// Increment completed files counter and return updated snapshot + fn inc_completed(&self) -> ProgressSnapshot { + let completed = self.completed_files.fetch_add(1, Ordering::Relaxed) + 1; + ProgressSnapshot { + completed_files: completed, + total_files: self.total_files, + total_downloaded_bytes: self.total_downloaded_bytes.load(Ordering::Relaxed), + } + } + + /// Add downloaded bytes and return updated snapshot + fn add_bytes(&self, delta: u64) -> ProgressSnapshot { + let total_bytes = self + .total_downloaded_bytes + .fetch_add(delta, Ordering::Relaxed) + + delta; + ProgressSnapshot { + completed_files: self.completed_files.load(Ordering::Relaxed), + total_files: self.total_files, + total_downloaded_bytes: total_bytes, + } + } +} + +/// Emit a progress event to the frontend +fn emit_progress( + window: &Window, + file_name: &str, + status: &str, + downloaded: u64, + total: u64, + snapshot: &ProgressSnapshot, +) { + let _ = window.emit( + "download-progress", + ProgressEvent { + file: file_name.to_string(), + downloaded, + total, + status: status.into(), + completed_files: snapshot.completed_files, + total_files: snapshot.total_files, + total_downloaded_bytes: snapshot.total_downloaded_bytes, + }, + ); +} + +pub async fn download_files( + window: Window, + tasks: Vec<DownloadTask>, + max_concurrent: usize, +) -> Result<(), String> { + // Clamp max_concurrent to a valid range (1-128) to prevent edge cases + let max_concurrent = max_concurrent.clamp(1, 128); + let client = reqwest::Client::new(); - let semaphore = Arc::new(Semaphore::new(10)); // Max 10 concurrent downloads + let semaphore = Arc::new(Semaphore::new(max_concurrent)); + let progress = Arc::new(GlobalProgress::new(tasks.len())); // Notify start (total files) let _ = window.emit("download-start", tasks.len()); @@ -62,6 +153,7 @@ pub async fn download_files(window: Window, tasks: Vec<DownloadTask>) -> Result< let client = client.clone(); let window = window.clone(); let semaphore = semaphore.clone(); + let progress = progress.clone(); async move { let _permit = semaphore.acquire().await.unwrap(); @@ -69,15 +161,7 @@ pub async fn download_files(window: Window, tasks: Vec<DownloadTask>) -> Result< // 1. Check if file exists and verify checksum if task.path.exists() { - let _ = window.emit( - "download-progress", - ProgressEvent { - file: file_name.clone(), - downloaded: 0, - total: 0, - status: "Verifying".into(), - }, - ); + emit_progress(&window, &file_name, "Verifying", 0, 0, &progress.snapshot()); if task.sha256.is_some() || task.sha1.is_some() { if let Ok(data) = tokio::fs::read(&task.path).await { @@ -86,15 +170,21 @@ pub async fn download_files(window: Window, tasks: Vec<DownloadTask>) -> Result< task.sha256.as_deref(), task.sha1.as_deref(), ) { - // Already valid - let _ = window.emit( - "download-progress", - ProgressEvent { - file: file_name.clone(), - downloaded: 0, - total: 0, - status: "Skipped".into(), - }, + // Already valid, skip download + let skipped_size = tokio::fs::metadata(&task.path) + .await + .map(|m| m.len()) + .unwrap_or(0); + if skipped_size > 0 { + let _ = progress.add_bytes(skipped_size); + } + emit_progress( + &window, + &file_name, + "Skipped", + 0, + 0, + &progress.inc_completed(), ); return Ok(()); } @@ -123,14 +213,14 @@ pub async fn download_files(window: Window, tasks: Vec<DownloadTask>) -> Result< return Err(format!("Write error: {}", e)); } downloaded += chunk.len() as u64; - let _ = window.emit( - "download-progress", - ProgressEvent { - file: file_name.clone(), - downloaded, - total: total_size, - status: "Downloading".into(), - }, + let snapshot = progress.add_bytes(chunk.len() as u64); + emit_progress( + &window, + &file_name, + "Downloading", + downloaded, + total_size, + &snapshot, ); } Ok(None) => break, @@ -141,23 +231,21 @@ pub async fn download_files(window: Window, tasks: Vec<DownloadTask>) -> Result< Err(e) => return Err(format!("Request error: {}", e)), } - let _ = window.emit( - "download-progress", - ProgressEvent { - file: file_name.clone(), - downloaded: 0, - total: 0, - status: "Finished".into(), - }, + emit_progress( + &window, + &file_name, + "Finished", + 0, + 0, + &progress.inc_completed(), ); - Ok(()) } }); // Buffer unordered to run concurrently tasks_stream - .buffer_unordered(10) + .buffer_unordered(max_concurrent) .collect::<Vec<Result<(), String>>>() .await; |