aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/src/basemodel/download_model.py
diff options
context:
space:
mode:
authorHsiangNianian <i@jyunko.cn>2025-12-30 19:54:08 +0800
committerHsiangNianian <i@jyunko.cn>2025-12-30 19:54:08 +0800
commit575114661ef9afb95df2a211e1d8498686340e6b (patch)
tree91f1646cececb1597a9246865e89b52e059d3cfa /src/basemodel/download_model.py
parent7ac684f1f82023c6284cd7d7efde11b8dc98c149 (diff)
downloadbase-model-575114661ef9afb95df2a211e1d8498686340e6b.tar.gz
base-model-575114661ef9afb95df2a211e1d8498686340e6b.zip
feat: Refactor and enhance TRPG NER model SDK
- Removed deprecated `word_conll_to_char_conll.py` utility and integrated its functionality into the new `utils` module.
- Introduced a comprehensive GitHub Actions workflow for automated publishing to PyPI and GitHub Releases.
- Added `__init__.py` files to establish package structure for `basemodel`, `inference`, `training`, and `utils` modules.
- Implemented model downloading functionality in `download_model.py` to fetch pre-trained ONNX models.
- Developed `TRPGParser` class for ONNX-based inference, including methods for parsing TRPG logs.
- Created training utilities in `training/__init__.py` for NER model training with Hugging Face Transformers.
- Enhanced utility functions for CoNLL file parsing and dataset creation.
- Added command-line interface for converting CoNLL files to datasets with validation options.
Diffstat (limited to 'src/basemodel/download_model.py')
-rw-r--r--src/basemodel/download_model.py70
1 files changed, 70 insertions, 0 deletions
diff --git a/src/basemodel/download_model.py b/src/basemodel/download_model.py
new file mode 100644
index 0000000..2d65099
--- /dev/null
+++ b/src/basemodel/download_model.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+"""
+模型下载脚本
+
+自动下载预训练的 ONNX 模型到用户缓存目录。
+"""
+
+import os
+import sys
+from pathlib import Path
+import urllib.request
+
+
def download_model(
    output_dir: str = None,
    url: str = "https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx"
):
    """Download the ONNX model into a local directory unless it is already there.

    The file is first fetched into a sibling ``model.onnx.part`` file and only
    moved to ``model.onnx`` once the transfer has finished, so an interrupted
    or truncated download can never be mistaken for a cached model on the
    next call.

    Args:
        output_dir: Directory that receives ``model.onnx``. When ``None``,
            ``~/.cache/basemodel/models/trpg-final/`` is used. The directory
            is created if it does not exist.
        url: Location the model is fetched from.

    Returns:
        The path of the local ``model.onnx`` file, as a string.

    Raises:
        SystemExit: With exit code 1 when the download fails; a message with
            the URL for manual download is printed first.
    """
    if output_dir is None:
        target_dir = Path.home() / ".cache" / "basemodel" / "models" / "trpg-final"
    else:
        target_dir = Path(output_dir)

    target_dir.mkdir(parents=True, exist_ok=True)
    output_path = target_dir / "model.onnx"

    # An existing file is treated as a finished download and reused as-is.
    if output_path.exists():
        print(f"✅ 模型已存在: {output_path}")
        return str(output_path)

    print(f"📥 正在下载模型到 {output_path}...")
    print(f"   URL: {url}")

    # Download next to the final location so the rename below stays on the
    # same filesystem and replaces the target in a single step.
    partial_path = output_path.with_name(output_path.name + ".part")
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, output_path)
        print("✅ 模型下载成功!")
        return str(output_path)
    except Exception as e:
        print(f"❌ 下载失败: {e}")
        print("   请手动从以下地址下载模型:")
        print(f"   {url}")
        sys.exit(1)
    finally:
        # Best-effort cleanup: after a successful rename the partial file is
        # already gone; after a failure or Ctrl-C it must not be left behind.
        try:
            partial_path.unlink()
        except OSError:
            pass
+
+
if __name__ == "__main__":
    import argparse

    # Command-line entry point: both options are optional and are passed
    # straight through to download_model().
    cli = argparse.ArgumentParser(description="下载 base-model ONNX 模型")
    cli.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="模型输出目录(默认: ~/.cache/basemodel/models/trpg-final/)",
    )
    cli.add_argument(
        "--url",
        type=str,
        default="https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx",
        help="模型下载 URL",
    )

    options = cli.parse_args()
    download_model(options.output_dir, options.url)