summary | refs | log | tree | commit | diff | stats | homepage
path: root/src
diff options
context:
space:
mode:
author: HsiangNianian <i@jyunko.cn> 2025-05-28 00:44:21 +0800
committer: HsiangNianian <i@jyunko.cn> 2025-05-28 00:44:21 +0800
commit: aa366526c2a60a281aeb52daa2a214970a9d2464 (patch)
tree: 2dc41a0c002fe115679bd317baad4fa393be53f3 /src
parent: e31aa3046e2e99f7afab707a505a89d5f0e342a3 (diff)
download: soon-aa366526c2a60a281aeb52daa2a214970a9d2464.tar.gz
soon-aa366526c2a60a281aeb52daa2a214970a9d2464.zip
feat: Add n-gram support and caching commands in soon CLI
Diffstat (limited to 'src')
-rw-r--r-- src/main.rs 178
1 files changed, 134 insertions, 44 deletions
diff --git a/src/main.rs b/src/main.rs
index 52888f3..3fcd343 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -3,8 +3,8 @@ use colored::*;
use counter::Counter;
use std::collections::HashMap;
use std::env;
-use std::fs::File;
-use std::io::{BufRead, BufReader};
+use std::fs::{File, OpenOptions};
+use std::io::{BufRead, BufReader, Write};
use std::path::PathBuf;
#[derive(Parser, Debug)]
@@ -17,6 +17,8 @@ struct Cli {
command: Option<Commands>,
#[arg(long)]
shell: Option<String>,
+ #[arg(long, default_value_t = 3)]
+ ngram: usize, // 新增参数,控制n-gram长度
}
#[derive(Subcommand, Debug)]
@@ -33,6 +35,13 @@ enum Commands {
Version,
/// Update self [WIP]
Update,
+ /// Show cached main commands
+ ShowCache,
+ /// Cache a command to soon cache (for testing)
+ Cache {
+ #[arg()]
+ cmd: String,
+ },
}
fn detect_shell() -> String {
@@ -129,56 +138,131 @@ fn load_history(shell: &str) -> Vec<HistoryItem> {
result
}
-fn predict_next_command(history: &[HistoryItem], cwd: &str) -> Option<String> {
- let mut dir_cmds: HashMap<String, Vec<String>> = HashMap::new();
- let mut last_dir: Option<String> = None;
+// Extract the main command: the first whitespace-separated token of `cmd`, or "" when `cmd` has no tokens.
+fn main_cmd(cmd: &str) -> &str {
+ cmd.split_whitespace().next().unwrap_or("")
+}
+
+// Read the last n main commands from `.soon_cache` in the home directory; an unreadable file yields an empty list, and a missing home directory panics on the unwrap.
+fn read_soon_cache(n: usize) -> Vec<String> {
+ let path = dirs::home_dir().unwrap().join(".soon_cache");
+ let mut cmds: Vec<String> = std::fs::read_to_string(path)
+ .unwrap_or_default()
+ .lines()
+ .map(|l| main_cmd(l).to_string())
+ .collect();
+ // Treat n == 0 as the default cache size of 10
+ let n = if n == 0 { 10 } else { n };
+ if cmds.len() > n {
+ cmds = cmds[cmds.len()-n..].to_vec();
+ }
+ cmds
+}
- for item in history {
- let cmd = item.cmd.trim();
- if let Some(rest) = cmd.strip_prefix("cd ") {
- let dir = rest.trim().to_string();
- last_dir = Some(dir);
- continue;
+// Show the cached commands (displays the last n main commands from history). NOTE(review): this lists the tail of the shell history obtained via detect_shell/load_history; it never reads `.soon_cache`.
+fn soon_show_cache(ngram: usize) {
+ let shell = detect_shell();
+ let history = load_history(&shell);
+ let history_main: Vec<String> = history.iter().map(|h| main_cmd(&h.cmd).to_string()).collect();
+ let n = if ngram == 0 { 10 } else { ngram };
+ let len = history_main.len();
+ let start = if len > n { len - n } else { 0 };
+ let cmds = &history_main[start..];
+
+ println!("{}", "🗂️ Cached main commands (from history):".cyan().bold());
+ if cmds.is_empty() {
+ println!("{}", "No cached commands.".yellow());
+ } else {
+ for (i, cmd) in cmds.iter().enumerate() {
+ println!("{:>2}: {}", i + 1, cmd);
}
- if let Some(ref dir) = last_dir {
- dir_cmds
- .entry(dir.clone())
- .or_default()
- .push(cmd.to_string());
- last_dir = None;
+ }
+}
+
+// Append the main command of `cmd` as one line to `.soon_cache` in the home directory (file is created if missing); panics if the home directory is unknown or the file cannot be opened or written.
+fn cache_main_cmd(cmd: &str) {
+ let path = dirs::home_dir().unwrap().join(".soon_cache");
+ let mut file = OpenOptions::new()
+ .append(true)
+ .create(true)
+ .open(path)
+ .unwrap();
+ writeln!(file, "{}", main_cmd(cmd)).unwrap();
+}
+
+// n-gram match prediction (with relevance scoring): slides the cached commands over the main commands of `history`, scores each window by the fraction of positions that match, and suggests the command that follows a chosen window. NOTE(review): the window slice below panics when `history` has fewer entries than the cache, and `best_idx` is assigned but never read.
+fn predict_next_command(history: &[HistoryItem], ngram: usize) -> Option<String> {
+ let cache_cmds = read_soon_cache(ngram);
+ if cache_cmds.is_empty() { return None; }
+
+ let history_main: Vec<&str> = history.iter().map(|h| main_cmd(&h.cmd)).collect();
+ let mut best_score = 0.0;
+ let mut best_idx = None;
+ let mut scores = Vec::new();
+
+ for i in 0..=history_main.len().saturating_sub(cache_cmds.len()) {
+ let window = &history_main[i..i+cache_cmds.len()];
+ let matches = window.iter().zip(&cache_cmds).filter(|(a, b)| a == &b).count();
+ let score = matches as f64 / cache_cmds.len() as f64;
+ scores.push((i, score));
+ if score > best_score {
+ best_score = score;
+ best_idx = Some(i + cache_cmds.len());
}
}
- let cwd_name = std::path::Path::new(cwd)
- .file_name()
- .and_then(|s| s.to_str())
- .unwrap_or("");
+ // Collect every window with relevance >= 60% and predict from the highest-scoring one
+ let mut filtered: Vec<_> = scores.iter()
+ .filter(|(_, score)| *score >= 0.6)
+ .collect();
+ filtered.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
- if let Some(cmds) = dir_cmds.get(cwd_name) {
- let mut counter = Counter::<&String, i32>::new();
- for cmd in cmds {
- counter.update([cmd]);
+ if let Some(&(idx, score)) = filtered.first() {
+ let next_idx = idx + cache_cmds.len();
+ if next_idx < history_main.len() {
+ let next = history_main[next_idx];
+ if next != "soon" && !cache_cmds.contains(&next.to_string()) {
+ return Some(format!("{} (match: {:.0}%)", next, score * 100.0));
+ }
}
- if let Some((cmd, _)) = counter.most_common().into_iter().next() {
- return Some(cmd.clone());
+ }
+
+ // If the 60% tier returned nothing, fall back to the highest-scoring window with relevance >= 40%
+ let mut filtered_40: Vec<_> = scores.iter()
+ .filter(|(_, score)| *score >= 0.4)
+ .collect();
+ filtered_40.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
+ if let Some(&(idx, score)) = filtered_40.first() {
+ let next_idx = idx + cache_cmds.len();
+ if next_idx < history_main.len() {
+ let next = history_main[next_idx];
+ if next != "soon" && !cache_cmds.contains(&next.to_string()) {
+ return Some(format!("{} (match: {:.0}%)", next, score * 100.0));
+ }
}
}
- let mut counter = Counter::<&String, i32>::new();
- for item in history {
- let cmd = item.cmd.trim();
- if !cmd.starts_with("cd ") {
- counter.update([&item.cmd]);
+ // If even the best relevance is below 10%, give no suggestion
+ if best_score < 0.1 {
+ return None;
+ }
+
+ // Otherwise use the window whose relevance is closest to 40%
+ let closest = scores.iter().min_by_key(|(_, score)| ((score - 0.4).abs() * 1000.0) as i32);
+ if let Some(&(idx, score)) = closest {
+ let next_idx = idx + cache_cmds.len();
+ if next_idx < history_main.len() {
+ let next = history_main[next_idx];
+ if next != "soon" && !cache_cmds.contains(&next.to_string()) {
+ return Some(format!("{} (match: {:.0}%)", next, score * 100.0));
+ }
}
}
- counter
- .most_common()
- .into_iter()
- .next()
- .map(|(cmd, _)| cmd.clone())
+
+ None
}
-fn soon_now(shell: &str) {
+fn soon_now(shell: &str, ngram: usize) {
let history = load_history(shell);
if history.is_empty() {
eprintln!(
@@ -187,9 +271,7 @@ fn soon_now(shell: &str) {
);
std::process::exit(1);
}
- let cwd = env::current_dir().unwrap_or_default();
- let cwd = cwd.to_string_lossy();
- let suggestion = predict_next_command(&history, &cwd);
+ let suggestion = predict_next_command(&history, ngram);
println!("\n{}", "🔮 You might run next:".magenta().bold());
if let Some(cmd) = suggestion {
println!("{} {}", "👉".green().bold(), cmd.green().bold());
@@ -260,6 +342,12 @@ fn soon_update() {
"🔄 [soon update] feature under development...".yellow()
);
}
+
+fn soon_cache(cmd: &str) { // Handler for the `Cache` subcommand: store the main command, then report it
+ cache_main_cmd(cmd); // appends main_cmd(cmd) as a line to the soon cache file
+ println!("Cached main command: {}", main_cmd(cmd));
+}
+
fn main() {
let cli = Cli::parse();
let shell = cli.shell.clone().unwrap_or_else(detect_shell);
@@ -270,14 +358,16 @@ fn main() {
}
match cli.command {
- Some(Commands::Now) => soon_now(&shell),
+ Some(Commands::Now) => soon_now(&shell, cli.ngram),
Some(Commands::Stats) => soon_stats(&shell),
Some(Commands::Learn) => soon_learn(&shell),
Some(Commands::Which) => soon_which(&shell),
Some(Commands::Version) => soon_version(),
Some(Commands::Update) => soon_update(),
+ Some(Commands::ShowCache) => soon_show_cache(cli.ngram),
+ Some(Commands::Cache { cmd }) => soon_cache(&cmd),
None => {
- soon_now(&shell);
+ soon_now(&shell, cli.ngram);
}
}
-}
+} \ No newline at end of file