summaryrefslogtreecommitdiffstatshomepage
path: root/src/base_model_trpgner/download_model.py
blob: 2d6509938b2ca81ddeb9f9f33b07022f0ea0372c (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/usr/bin/env python3
"""
模型下载脚本

自动下载预训练的 ONNX 模型到用户缓存目录。
"""

import os
import sys
from pathlib import Path
import urllib.request


def download_model(
    output_dir: "str | os.PathLike[str] | None" = None,
    url: str = "https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx",
) -> str:
    """
    Download the ONNX model unless it is already present.

    The data is first written to ``model.onnx.part`` next to the target and is
    only renamed to ``model.onnx`` after the transfer completed. A failed or
    interrupted download therefore never leaves a truncated ``model.onnx``
    that a later call would mistake for an already-downloaded model.

    Args:
        output_dir: Directory to store the model in. Defaults to
            ``~/.cache/basemodel/models/trpg-final/``. Created if missing.
        url: URL the model is downloaded from.

    Returns:
        Path of ``model.onnx`` inside ``output_dir``, as a string.

    Raises:
        SystemExit: With exit code 1 if the download fails.
    """
    if output_dir is None:
        output_dir = Path.home() / ".cache" / "basemodel" / "models" / "trpg-final"
    else:
        output_dir = Path(output_dir)

    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / "model.onnx"

    if output_path.exists():
        print(f"✅ 模型已存在: {output_path}")
        return str(output_path)

    print(f"📥 正在下载模型到 {output_path}...")
    print(f"   URL: {url}")

    # Stage the download in a sibling file (same directory, hence same
    # filesystem) so the final rename onto model.onnx is atomic.
    part_path = output_path.with_name(output_path.name + ".part")
    try:
        urllib.request.urlretrieve(url, part_path)
        part_path.replace(output_path)
    except Exception as e:
        # Failure diagnostics go to stderr; progress messages stay on stdout.
        print(f"❌ 下载失败: {e}", file=sys.stderr)
        print("   请手动从以下地址下载模型:", file=sys.stderr)
        print(f"   {url}", file=sys.stderr)
        sys.exit(1)
    finally:
        # Runs on failure, SystemExit and KeyboardInterrupt alike; after a
        # successful replace the .part file no longer exists.
        if part_path.exists():
            part_path.unlink()

    print("✅ 模型下载成功!")
    return str(output_path)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="下载 base-model ONNX 模型")
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="模型输出目录(默认: ~/.cache/basemodel/models/trpg-final/)",
    )
    # Deliberately no URL default here: when --url is omitted, download_model()
    # uses its own default, so the pinned release URL lives in one place only.
    parser.add_argument(
        "--url",
        type=str,
        default=None,
        help="模型下载 URL",
    )

    args = parser.parse_args()

    # Forward --url only when it was given so the function default applies otherwise.
    url_kwargs = {} if args.url is None else {"url": args.url}
    download_model(args.output_dir, **url_kwargs)