author     HsiangNianian <i@jyunko.cn>  2025-12-30 20:39:34 +0800
committer  HsiangNianian <i@jyunko.cn>  2025-12-30 20:39:34 +0800
commit     298035052b3e3d083b57f5dbac0e86de4f94efba (patch)
tree       944f38d734f752a5a0f71033ebece38fc5c35839 /src/base_model_trpgner/inference/__init__.py
parent     92a647ffbb3452a0ed49601177f290e20a88413e (diff)
download   base-model-298035052b3e3d083b57f5dbac0e86de4f94efba.tar.gz
           base-model-298035052b3e3d083b57f5dbac0e86de4f94efba.zip
refactor: Update model download functionality and improve inference module to support automatic model retrieval from GitHub releases
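
A minimal usage sketch of the new flow (illustrative only, not part of this patch); it relies on TRPGParser, the auto_download flag, parse_batch, and download_model_files exactly as introduced in the diff below:

    from base_model_trpgner.inference import TRPGParser, download_model_files

    # First run: fetches the latest release into
    # ~/.cache/base_model_trpgner/models/trpg-final/
    parser = TRPGParser()                 # auto_download=True by default
    results = parser.parse_batch(["GM: You push open the heavy door."])

    # Pin a specific release, or refresh an existing cache
    download_model_files(version="v0.1.0", force=True)

    # Strictly offline: raises FileNotFoundError if no local model exists
    offline_parser = TRPGParser(auto_download=False)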
Diffstat (limited to 'src/base_model_trpgner/inference/__init__.py')
-rw-r--r--  src/base_model_trpgner/inference/__init__.py  137
1 file changed, 119 insertions(+), 18 deletions(-)
diff --git a/src/base_model_trpgner/inference/__init__.py b/src/base_model_trpgner/inference/__init__.py
index 93a185f..d70cb23 100644
--- a/src/base_model_trpgner/inference/__init__.py
+++ b/src/base_model_trpgner/inference/__init__.py
@@ -5,6 +5,8 @@ ONNX inference module
"""
import os
+import json
+import shutil
from typing import List, Dict, Any, Optional
from pathlib import Path
@@ -18,20 +20,113 @@ except ImportError as e:
    ) from e
-# Default model path (relative to the package install location)
-DEFAULT_MODEL_DIR = Path(__file__).parent.parent.parent.parent / "models" / "trpg-final"
-# Remote model URL (used for automatic download)
-MODEL_URL = "https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx"
+# GitHub repository information
+REPO_OWNER = "HydroRoll-Team"
+REPO_NAME = "base-model"
+# User cache directory for downloaded models
+USER_MODEL_DIR = Path.home() / ".cache" / "base_model_trpgner" / "models" / "trpg-final"
+
+
+def get_latest_release_url() -> str:
+    """
+    Resolve the tag name of the latest GitHub Release.
+
+    Returns:
+        Tag name of the latest Release (e.g. v0.1.0); falls back to v0.1.0
+        if the GitHub API cannot be reached.
+    """
+    import urllib.request
+    import urllib.error
+
+    try:
+        # Query the GitHub API for the latest release
+        api_url = f"https://api.github.com/repos/{REPO_OWNER}/{REPO_NAME}/releases/latest"
+        with urllib.request.urlopen(api_url, timeout=10) as response:
+            data = json.load(response)
+            return data.get("tag_name", "v0.1.0")
+    except (urllib.error.URLError, json.JSONDecodeError, KeyError):
+        # Fall back to the default version on failure
+        return "v0.1.0"
+
+def download_model_files(version: Optional[str] = None, force: bool = False) -> Path:
+    """
+    Download the model files from a GitHub Release.
+
+    Args:
+        version: Release version (e.g. v0.1.0); None means the latest release
+        force: whether to force a re-download even if the files already exist
+
+    Returns:
+        Directory the model files were saved to
+    """
+    import urllib.request
+    import urllib.error
+
+    if version is None:
+        version = get_latest_release_url()
+
+    model_dir = USER_MODEL_DIR
+    model_dir.mkdir(parents=True, exist_ok=True)
+
+    # Check whether this version has already been downloaded
+    marker_file = model_dir / ".version"
+    if not force and marker_file.exists():
+        with open(marker_file, "r") as f:
+            current_version = f.read().strip()
+        if current_version == version:
+            print(f"Model already present (version: {version})")
+            return model_dir
+
+    print(f"Downloading model {version}...")
+
+    # Files to download
+    base_url = f"https://github.com/{REPO_OWNER}/{REPO_NAME}/releases/download/{version}"
+    files_to_download = [
+        "model.onnx",
+        "model.onnx.data",
+        "config.json",
+        "tokenizer.json",
+        "tokenizer_config.json",
+        "special_tokens_map.json",
+        "vocab.txt",
+    ]
+
+    for filename in files_to_download:
+        url = f"{base_url}/{filename}"
+        dest_path = model_dir / filename
+
+        try:
+            print(f"Downloading {filename}...")
+            with urllib.request.urlopen(url, timeout=60) as response:
+                with open(dest_path, "wb") as f:
+                    shutil.copyfileobj(response, f)
+        except urllib.error.HTTPError as e:
+            print(f"Failed to download {filename}: {e}")
+            # model.onnx.data is optional (smaller models may not have one)
+            if filename == "model.onnx.data":
+                continue
+            else:
+                raise RuntimeError(f"Unable to download model file: {filename}") from e
+
+    # Write the version marker
+    with open(marker_file, "w") as f:
+        f.write(version)
+
+    print(f"✅ Model download complete: {model_dir}")
+    return model_dir
class TRPGParser:
    """
    TRPG log parser (ONNX-based)
+    On first run, the latest model is automatically downloaded from the GitHub Release.
+
    Args:
-        model_path: path to the ONNX model; defaults to the bundled model
+        model_path: path to the ONNX model; defaults to the automatically downloaded model
        tokenizer_path: path to the tokenizer config; defaults to the directory of model_path
        device: inference device, "cpu" or "cuda"
+        auto_download: whether to download the model automatically (default: True)
    Examples:
        >>> parser = TRPGParser()
@@ -45,10 +140,11 @@ class TRPGParser:
        model_path: Optional[str] = None,
        tokenizer_path: Optional[str] = None,
        device: str = "cpu",
+        auto_download: bool = True,
    ):
        # Determine the model path
        if model_path is None:
-            model_path = self._get_default_model_path()
+            model_path = self._get_default_model_path(auto_download)
        if tokenizer_path is None:
            tokenizer_path = Path(model_path).parent
@@ -60,25 +156,31 @@ class TRPGParser:
        # Load the model
        self._load_model()
-    def _get_default_model_path(self) -> str:
-        """Get the default model path"""
-        # 1. Try relative to the project root
+    def _get_default_model_path(self, auto_download: bool) -> str:
+        """Get the default model path, downloading the model if necessary"""
+        # 1. Check the local development environment
        project_root = Path(__file__).parent.parent.parent.parent
        local_model = project_root / "models" / "trpg-final" / "model.onnx"
        if local_model.exists():
            return str(local_model)
-        # 2. Try the user data directory
-        from pathlib import Path
-        user_model_dir = Path.home() / ".cache" / "basemodel" / "models" / "trpg-final"
-        user_model = user_model_dir / "model.onnx"
+        # 2. Check the user cache directory
+        user_model = USER_MODEL_DIR / "model.onnx"
        if user_model.exists():
            return str(user_model)
-        # 3. Raise an error with download instructions
+        # 3. Download automatically
+        if auto_download:
+            print("Model not found, downloading from GitHub Release...")
+            download_model_files()
+            return str(user_model)
+
+        # 4. Raise an error
        raise FileNotFoundError(
-            f"Model file not found. Please download it from {MODEL_URL} to {user_model_dir}\n"
-            f"or run: python -m basemodel.download_model"
+            f"Model file not found.\n"
+            f"Enable automatic download: TRPGParser(auto_download=True)\n"
+            f"Or download it manually to: {USER_MODEL_DIR}\n"
+            f"Download page: https://github.com/{REPO_OWNER}/{REPO_NAME}/releases/latest"
        )
    def _load_model(self):
@@ -100,7 +202,6 @@ class TRPGParser:
)
        # Load the label mapping
-        import json
        config_path = self.tokenizer_path / "config.json"
        if config_path.exists():
            with open(config_path, "r", encoding="utf-8") as f:
@@ -289,4 +390,4 @@ def parse_lines(texts: List[str], model_path: Optional[str] = None) -> List[Dict
    return parser.parse_batch(texts)
-__all__ = ["TRPGParser", "parse_line", "parse_lines"]
+__all__ = ["TRPGParser", "parse_line", "parse_lines", "download_model_files"]
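
For machines without GitHub access, the cache can also be pre-seeded by hand. A sketch based on the layout this patch establishes; the SOURCE_DIR path is hypothetical, while the file list, cache directory, and the .version marker mirror download_model_files():

    from pathlib import Path
    import shutil

    CACHE_DIR = Path.home() / ".cache" / "base_model_trpgner" / "models" / "trpg-final"
    SOURCE_DIR = Path("/mnt/share/trpg-final")  # hypothetical: a manually downloaded release

    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    for name in ["model.onnx", "model.onnx.data", "config.json", "tokenizer.json",
                 "tokenizer_config.json", "special_tokens_map.json", "vocab.txt"]:
        src = SOURCE_DIR / name
        if src.exists():                        # model.onnx.data is optional
            shutil.copy2(src, CACHE_DIR / name)

    # Record the release tag so download_model_files() treats the cache as current
    (CACHE_DIR / ".version").write_text("v0.1.0")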