aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/src/base_model_trpgner/download_model.py
diff options
context:
space:
mode:
authorHsiangNianian <i@jyunko.cn>2025-12-30 20:16:05 +0800
committerHsiangNianian <i@jyunko.cn>2025-12-30 20:16:05 +0800
commit5dd166366b8a2f4699c1841ebd7fceabcd9868a4 (patch)
tree85d78772054529579176547c00aee9559cffff37 /src/base_model_trpgner/download_model.py
parentdd55c70225367dec9e8d88821b4d65fcd24edd65 (diff)
downloadbase-model-5dd166366b8a2f4699c1841ebd7fceabcd9868a4.tar.gz
base-model-5dd166366b8a2f4699c1841ebd7fceabcd9868a4.zip
refactor: Refactor TRPG NER model SDK: restructure codebase into base_model_trpgner package, implement training and inference modules, and add model download functionality. Remove legacy training and utils modules. Enhance documentation and examples for better usability.
Diffstat (limited to 'src/base_model_trpgner/download_model.py')
-rw-r--r--src/base_model_trpgner/download_model.py70
1 files changed, 70 insertions, 0 deletions
diff --git a/src/base_model_trpgner/download_model.py b/src/base_model_trpgner/download_model.py
new file mode 100644
index 0000000..2d65099
--- /dev/null
+++ b/src/base_model_trpgner/download_model.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+"""
+模型下载脚本
+
+自动下载预训练的 ONNX 模型到用户缓存目录。
+"""
+
+import os
+import sys
+from pathlib import Path
+import urllib.request
+
+
+def download_model(
+ output_dir: str = None,
+ url: str = "https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx"
+):
+ """
+ 下载 ONNX 模型
+
+ Args:
+ output_dir: 输出目录,默认为 ~/.cache/basemodel/models/trpg-final/
+ url: 模型下载 URL
+ """
+ if output_dir is None:
+ output_dir = Path.home() / ".cache" / "basemodel" / "models" / "trpg-final"
+ else:
+ output_dir = Path(output_dir)
+
+ output_dir.mkdir(parents=True, exist_ok=True)
+ output_path = output_dir / "model.onnx"
+
+ if output_path.exists():
+ print(f"✅ 模型已存在: {output_path}")
+ return str(output_path)
+
+ print(f"📥 正在下载模型到 {output_path}...")
+ print(f" URL: {url}")
+
+ try:
+ urllib.request.urlretrieve(url, output_path)
+ print(f"✅ 模型下载成功!")
+ return str(output_path)
+ except Exception as e:
+ print(f"❌ 下载失败: {e}")
+ print(f" 请手动从以下地址下载模型:")
+ print(f" {url}")
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(description="下载 base-model ONNX 模型")
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ default=None,
+ help="模型输出目录(默认: ~/.cache/basemodel/models/trpg-final/)"
+ )
+ parser.add_argument(
+ "--url",
+ type=str,
+ default="https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx",
+ help="模型下载 URL"
+ )
+
+ args = parser.parse_args()
+
+ download_model(args.output_dir, args.url)