diff options
| author | 2025-12-30 20:16:05 +0800 | |
|---|---|---|
| committer | 2025-12-30 20:16:05 +0800 | |
| commit | 5dd166366b8a2f4699c1841ebd7fceabcd9868a4 (patch) | |
| tree | 85d78772054529579176547c00aee9559cffff37 /src/base_model_trpgner/download_model.py | |
| parent | dd55c70225367dec9e8d88821b4d65fcd24edd65 (diff) | |
| download | base-model-5dd166366b8a2f4699c1841ebd7fceabcd9868a4.tar.gz base-model-5dd166366b8a2f4699c1841ebd7fceabcd9868a4.zip | |
refactor: Refactor TRPG NER model SDK: restructure codebase into base_model_trpgner package, implement training and inference modules, and add model download functionality. Remove legacy training and utils modules. Enhance documentation and examples for better usability.
Diffstat (limited to 'src/base_model_trpgner/download_model.py')
| -rw-r--r-- | src/base_model_trpgner/download_model.py | 70 |
1 files changed, 70 insertions, 0 deletions
diff --git a/src/base_model_trpgner/download_model.py b/src/base_model_trpgner/download_model.py new file mode 100644 index 0000000..2d65099 --- /dev/null +++ b/src/base_model_trpgner/download_model.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +""" +模型下载脚本 + +自动下载预训练的 ONNX 模型到用户缓存目录。 +""" + +import os +import sys +from pathlib import Path +import urllib.request + + +def download_model( + output_dir: str = None, + url: str = "https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx" +): + """ + 下载 ONNX 模型 + + Args: + output_dir: 输出目录,默认为 ~/.cache/basemodel/models/trpg-final/ + url: 模型下载 URL + """ + if output_dir is None: + output_dir = Path.home() / ".cache" / "basemodel" / "models" / "trpg-final" + else: + output_dir = Path(output_dir) + + output_dir.mkdir(parents=True, exist_ok=True) + output_path = output_dir / "model.onnx" + + if output_path.exists(): + print(f"✅ 模型已存在: {output_path}") + return str(output_path) + + print(f"📥 正在下载模型到 {output_path}...") + print(f" URL: {url}") + + try: + urllib.request.urlretrieve(url, output_path) + print(f"✅ 模型下载成功!") + return str(output_path) + except Exception as e: + print(f"❌ 下载失败: {e}") + print(f" 请手动从以下地址下载模型:") + print(f" {url}") + sys.exit(1) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="下载 base-model ONNX 模型") + parser.add_argument( + "--output-dir", + type=str, + default=None, + help="模型输出目录(默认: ~/.cache/basemodel/models/trpg-final/)" + ) + parser.add_argument( + "--url", + type=str, + default="https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx", + help="模型下载 URL" + ) + + args = parser.parse_args() + + download_model(args.output_dir, args.url) |
