diff options
| author | 2025-05-28 00:44:21 +0800 | |
|---|---|---|
| committer | 2025-05-28 00:44:21 +0800 | |
| commit | aa366526c2a60a281aeb52daa2a214970a9d2464 (patch) | |
| tree | 2dc41a0c002fe115679bd317baad4fa393be53f3 /src/main.rs | |
| parent | e31aa3046e2e99f7afab707a505a89d5f0e342a3 (diff) | |
| download | soon-aa366526c2a60a281aeb52daa2a214970a9d2464.tar.gz soon-aa366526c2a60a281aeb52daa2a214970a9d2464.zip | |
feat: Add n-gram support and caching commands in soon CLI
Diffstat (limited to 'src/main.rs')
| -rw-r--r-- | src/main.rs | 178 |
1 files changed, 134 insertions, 44 deletions
diff --git a/src/main.rs b/src/main.rs index 52888f3..3fcd343 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,8 +3,8 @@ use colored::*; use counter::Counter; use std::collections::HashMap; use std::env; -use std::fs::File; -use std::io::{BufRead, BufReader}; +use std::fs::{File, OpenOptions}; +use std::io::{BufRead, BufReader, Write}; use std::path::PathBuf; #[derive(Parser, Debug)] @@ -17,6 +17,8 @@ struct Cli { command: Option<Commands>, #[arg(long)] shell: Option<String>, + #[arg(long, default_value_t = 3)] + ngram: usize, // 新增参数,控制n-gram长度 } #[derive(Subcommand, Debug)] @@ -33,6 +35,13 @@ enum Commands { Version, /// Update self [WIP] Update, + /// Show cached main commands + ShowCache, + /// Cache a command to soon cache (for testing) + Cache { + #[arg()] + cmd: String, + }, } fn detect_shell() -> String { @@ -129,56 +138,131 @@ fn load_history(shell: &str) -> Vec<HistoryItem> { result } -fn predict_next_command(history: &[HistoryItem], cwd: &str) -> Option<String> { - let mut dir_cmds: HashMap<String, Vec<String>> = HashMap::new(); - let mut last_dir: Option<String> = None; +// 提取主要指令 +fn main_cmd(cmd: &str) -> &str { + cmd.split_whitespace().next().unwrap_or("") +} + +// 读取 soon 缓存的最近 n 条主要指令 +fn read_soon_cache(n: usize) -> Vec<String> { + let path = dirs::home_dir().unwrap().join(".soon_cache"); + let mut cmds: Vec<String> = std::fs::read_to_string(path) + .unwrap_or_default() + .lines() + .map(|l| main_cmd(l).to_string()) + .collect(); + // 默认缓存条数为 10 + let n = if n == 0 { 10 } else { n }; + if cmds.len() > n { + cmds = cmds[cmds.len()-n..].to_vec(); + } + cmds +} - for item in history { - let cmd = item.cmd.trim(); - if let Some(rest) = cmd.strip_prefix("cd ") { - let dir = rest.trim().to_string(); - last_dir = Some(dir); - continue; +// 展示缓存的指令(应显示 history 中倒数 n 条主要指令) +fn soon_show_cache(ngram: usize) { + let shell = detect_shell(); + let history = load_history(&shell); + let history_main: Vec<String> = history.iter().map(|h| main_cmd(&h.cmd).to_string()).collect(); + let n = if ngram == 0 { 10 } else { ngram }; + let len = history_main.len(); + let start = if len > n { len - n } else { 0 }; + let cmds = &history_main[start..]; + + println!("{}", "🗂️ Cached main commands (from history):".cyan().bold()); + if cmds.is_empty() { + println!("{}", "No cached commands.".yellow()); + } else { + for (i, cmd) in cmds.iter().enumerate() { + println!("{:>2}: {}", i + 1, cmd); } - if let Some(ref dir) = last_dir { - dir_cmds - .entry(dir.clone()) - .or_default() - .push(cmd.to_string()); - last_dir = None; + } +} + +// 写入 soon 缓存 +fn cache_main_cmd(cmd: &str) { + let path = dirs::home_dir().unwrap().join(".soon_cache"); + let mut file = OpenOptions::new() + .append(true) + .create(true) + .open(path) + .unwrap(); + writeln!(file, "{}", main_cmd(cmd)).unwrap(); +} + +// n-gram 匹配预测(带相关度判定) +fn predict_next_command(history: &[HistoryItem], ngram: usize) -> Option<String> { + let cache_cmds = read_soon_cache(ngram); + if cache_cmds.is_empty() { return None; } + + let history_main: Vec<&str> = history.iter().map(|h| main_cmd(&h.cmd)).collect(); + let mut best_score = 0.0; + let mut best_idx = None; + let mut scores = Vec::new(); + + for i in 0..=history_main.len().saturating_sub(cache_cmds.len()) { + let window = &history_main[i..i+cache_cmds.len()]; + let matches = window.iter().zip(&cache_cmds).filter(|(a, b)| a == &b).count(); + let score = matches as f64 / cache_cmds.len() as f64; + scores.push((i, score)); + if score > best_score { + best_score = score; + best_idx = Some(i + cache_cmds.len()); } } - let cwd_name = std::path::Path::new(cwd) - .file_name() - .and_then(|s| s.to_str()) - .unwrap_or(""); + // 找到所有相关度大于60%的,选择最大相关度的预测 + let mut filtered: Vec<_> = scores.iter() + .filter(|(_, score)| *score >= 0.6) + .collect(); + filtered.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); - if let Some(cmds) = dir_cmds.get(cwd_name) { - let mut counter = Counter::<&String, i32>::new(); - for cmd in cmds { - counter.update([cmd]); + if let Some(&(idx, score)) = filtered.first() { + let next_idx = idx + cache_cmds.len(); + if next_idx < history_main.len() { + let next = history_main[next_idx]; + if next != "soon" && !cache_cmds.contains(&next.to_string()) { + return Some(format!("{} (match: {:.0}%)", next, score * 100.0)); + } } - if let Some((cmd, _)) = counter.most_common().into_iter().next() { - return Some(cmd.clone()); + } + + // 如果都小于60%,找最大相关度且>=40% + let mut filtered_40: Vec<_> = scores.iter() + .filter(|(_, score)| *score >= 0.4) + .collect(); + filtered_40.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + if let Some(&(idx, score)) = filtered_40.first() { + let next_idx = idx + cache_cmds.len(); + if next_idx < history_main.len() { + let next = history_main[next_idx]; + if next != "soon" && !cache_cmds.contains(&next.to_string()) { + return Some(format!("{} (match: {:.0}%)", next, score * 100.0)); + } } } - let mut counter = Counter::<&String, i32>::new(); - for item in history { - let cmd = item.cmd.trim(); - if !cmd.starts_with("cd ") { - counter.update([&item.cmd]); + // 如果都小于10%,输出No suggestion + if best_score < 0.1 { + return None; + } + + // 否则输出最接近40%的 + let closest = scores.iter().min_by_key(|(_, score)| ((score - 0.4).abs() * 1000.0) as i32); + if let Some(&(idx, score)) = closest { + let next_idx = idx + cache_cmds.len(); + if next_idx < history_main.len() { + let next = history_main[next_idx]; + if next != "soon" && !cache_cmds.contains(&next.to_string()) { + return Some(format!("{} (match: {:.0}%)", next, score * 100.0)); + } } } - counter - .most_common() - .into_iter() - .next() - .map(|(cmd, _)| cmd.clone()) + + None } -fn soon_now(shell: &str) { +fn soon_now(shell: &str, ngram: usize) { let history = load_history(shell); if history.is_empty() { eprintln!( @@ -187,9 +271,7 @@ fn soon_now(shell: &str) { ); std::process::exit(1); } - let cwd = env::current_dir().unwrap_or_default(); - let cwd = cwd.to_string_lossy(); - let suggestion = predict_next_command(&history, &cwd); + let suggestion = predict_next_command(&history, ngram); println!("\n{}", "🔮 You might run next:".magenta().bold()); if let Some(cmd) = suggestion { println!("{} {}", "👉".green().bold(), cmd.green().bold()); @@ -260,6 +342,12 @@ fn soon_update() { "🔄 [soon update] feature under development...".yellow() ); } + +fn soon_cache(cmd: &str) { + cache_main_cmd(cmd); + println!("Cached main command: {}", main_cmd(cmd)); +} + fn main() { let cli = Cli::parse(); let shell = cli.shell.clone().unwrap_or_else(detect_shell); @@ -270,14 +358,16 @@ fn main() { } match cli.command { - Some(Commands::Now) => soon_now(&shell), + Some(Commands::Now) => soon_now(&shell, cli.ngram), Some(Commands::Stats) => soon_stats(&shell), Some(Commands::Learn) => soon_learn(&shell), Some(Commands::Which) => soon_which(&shell), Some(Commands::Version) => soon_version(), Some(Commands::Update) => soon_update(), + Some(Commands::ShowCache) => soon_show_cache(cli.ngram), + Some(Commands::Cache { cmd }) => soon_cache(&cmd), None => { - soon_now(&shell); + soon_now(&shell, cli.ngram); } } -} +}
\ No newline at end of file |