aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/src/base_model_trpgner/download_model.py
diff options
context:
space:
mode:
authorHsiangNianian <i@jyunko.cn>2025-12-30 20:39:34 +0800
committerHsiangNianian <i@jyunko.cn>2025-12-30 20:39:34 +0800
commit298035052b3e3d083b57f5dbac0e86de4f94efba (patch)
tree944f38d734f752a5a0f71033ebece38fc5c35839 /src/base_model_trpgner/download_model.py
parent92a647ffbb3452a0ed49601177f290e20a88413e (diff)
downloadbase-model-298035052b3e3d083b57f5dbac0e86de4f94efba.tar.gz
base-model-298035052b3e3d083b57f5dbac0e86de4f94efba.zip
refactor: Update model download functionality and improve inference module to support automatic model retrieval from GitHub releases
Diffstat (limited to 'src/base_model_trpgner/download_model.py')
-rw-r--r--src/base_model_trpgner/download_model.py70
1 files changed, 0 insertions, 70 deletions
diff --git a/src/base_model_trpgner/download_model.py b/src/base_model_trpgner/download_model.py
deleted file mode 100644
index 2d65099..0000000
--- a/src/base_model_trpgner/download_model.py
+++ /dev/null
@@ -1,70 +0,0 @@
-#!/usr/bin/env python3
-"""
-模型下载脚本
-
-自动下载预训练的 ONNX 模型到用户缓存目录。
-"""
-
-import os
-import sys
-from pathlib import Path
-import urllib.request
-
-
-def download_model(
- output_dir: str = None,
- url: str = "https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx"
-):
- """
- 下载 ONNX 模型
-
- Args:
- output_dir: 输出目录,默认为 ~/.cache/basemodel/models/trpg-final/
- url: 模型下载 URL
- """
- if output_dir is None:
- output_dir = Path.home() / ".cache" / "basemodel" / "models" / "trpg-final"
- else:
- output_dir = Path(output_dir)
-
- output_dir.mkdir(parents=True, exist_ok=True)
- output_path = output_dir / "model.onnx"
-
- if output_path.exists():
- print(f"✅ 模型已存在: {output_path}")
- return str(output_path)
-
- print(f"📥 正在下载模型到 {output_path}...")
- print(f" URL: {url}")
-
- try:
- urllib.request.urlretrieve(url, output_path)
- print(f"✅ 模型下载成功!")
- return str(output_path)
- except Exception as e:
- print(f"❌ 下载失败: {e}")
- print(f" 请手动从以下地址下载模型:")
- print(f" {url}")
- sys.exit(1)
-
-
-if __name__ == "__main__":
- import argparse
-
- parser = argparse.ArgumentParser(description="下载 base-model ONNX 模型")
- parser.add_argument(
- "--output-dir",
- type=str,
- default=None,
- help="模型输出目录(默认: ~/.cache/basemodel/models/trpg-final/)"
- )
- parser.add_argument(
- "--url",
- type=str,
- default="https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx",
- help="模型下载 URL"
- )
-
- args = parser.parse_args()
-
- download_model(args.output_dir, args.url)