blob: 2d6509938b2ca81ddeb9f9f33b07022f0ea0372c (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
|
#!/usr/bin/env python3
"""
模型下载脚本
自动下载预训练的 ONNX 模型到用户缓存目录。
"""
import os
import sys
from pathlib import Path
import urllib.request
def download_model(
    output_dir: str = None,
    url: str = "https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx"
) -> str:
    """
    Download the ONNX model unless it is already present locally.

    The payload is first written to a sibling ``model.onnx.part`` file and
    only renamed to ``model.onnx`` once the transfer has completed, so an
    interrupted download can never be mistaken for a finished model on the
    next run.

    Args:
        output_dir: Target directory; when ``None`` the default
            ``~/.cache/basemodel/models/trpg-final/`` is used. The directory
            is created if it does not exist.
        url: URL the model is fetched from.

    Returns:
        The path of ``model.onnx`` as a string.

    Raises:
        SystemExit: With exit code 1 when the download fails.
    """
    if output_dir is None:
        output_dir = Path.home() / ".cache" / "basemodel" / "models" / "trpg-final"
    else:
        output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    output_path = output_dir / "model.onnx"
    if output_path.exists():
        print(f"✅ 模型已存在: {output_path}")
        return str(output_path)

    print(f"📥 正在下载模型到 {output_path}...")
    print(f" URL: {url}")

    # Download next to the final location so the rename below stays on the
    # same filesystem and is therefore atomic.
    partial_path = output_dir / "model.onnx.part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, output_path)
    except Exception as e:
        # Best-effort cleanup: never leave a truncated file behind.
        try:
            partial_path.unlink()
        except OSError:
            pass
        print(f"❌ 下载失败: {e}")
        print(" 请手动从以下地址下载模型:")
        print(f" {url}")
        sys.exit(1)

    print("✅ 模型下载成功!")
    return str(output_path)
if __name__ == "__main__":
    import argparse

    # Command-line entry point: both options are optional and are handed
    # straight to download_model().
    cli = argparse.ArgumentParser(description="下载 base-model ONNX 模型")
    cli.add_argument(
        "--output-dir",
        default=None,
        type=str,
        help="模型输出目录(默认: ~/.cache/basemodel/models/trpg-final/)",
    )
    cli.add_argument(
        "--url",
        default="https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx",
        type=str,
        help="模型下载 URL",
    )
    options = cli.parse_args()
    download_model(output_dir=options.output_dir, url=options.url)
|