aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/src/base_model_trpgner/download_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/base_model_trpgner/download_model.py')
-rw-r--r--src/base_model_trpgner/download_model.py70
1 files changed, 70 insertions, 0 deletions
diff --git a/src/base_model_trpgner/download_model.py b/src/base_model_trpgner/download_model.py
new file mode 100644
index 0000000..2d65099
--- /dev/null
+++ b/src/base_model_trpgner/download_model.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+"""
+模型下载脚本
+
+自动下载预训练的 ONNX 模型到用户缓存目录。
+"""
+
+import os
+import sys
+from pathlib import Path
+import urllib.request
+
+
+def download_model(
+ output_dir: str = None,
+ url: str = "https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx"
+):
+ """
+ 下载 ONNX 模型
+
+ Args:
+ output_dir: 输出目录,默认为 ~/.cache/basemodel/models/trpg-final/
+ url: 模型下载 URL
+ """
+ if output_dir is None:
+ output_dir = Path.home() / ".cache" / "basemodel" / "models" / "trpg-final"
+ else:
+ output_dir = Path(output_dir)
+
+ output_dir.mkdir(parents=True, exist_ok=True)
+ output_path = output_dir / "model.onnx"
+
+ if output_path.exists():
+ print(f"✅ 模型已存在: {output_path}")
+ return str(output_path)
+
+ print(f"📥 正在下载模型到 {output_path}...")
+ print(f" URL: {url}")
+
+ try:
+ urllib.request.urlretrieve(url, output_path)
+ print(f"✅ 模型下载成功!")
+ return str(output_path)
+ except Exception as e:
+ print(f"❌ 下载失败: {e}")
+ print(f" 请手动从以下地址下载模型:")
+ print(f" {url}")
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(description="下载 base-model ONNX 模型")
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ default=None,
+ help="模型输出目录(默认: ~/.cache/basemodel/models/trpg-final/)"
+ )
+ parser.add_argument(
+ "--url",
+ type=str,
+ default="https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx",
+ help="模型下载 URL"
+ )
+
+ args = parser.parse_args()
+
+ download_model(args.output_dir, args.url)