aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/src
diff options
context:
space:
mode:
authorHsiangNianian <i@jyunko.cn>2025-12-30 20:54:06 +0800
committerHsiangNianian <i@jyunko.cn>2025-12-30 20:54:06 +0800
commit18c946aac2b0e16ec4e66bb4c40c62403af6f205 (patch)
tree6b4320d0e9efc4566f733eda003cd241b2f94d96 /src
parenta4dd04f6e3af86ce3f96c7f7ebc88e195db366f4 (diff)
downloadbase-model-18c946aac2b0e16ec4e66bb4c40c62403af6f205.tar.gz
base-model-18c946aac2b0e16ec4e66bb4c40c62403af6f205.zip
refactor: Clean up publish workflow by removing Test PyPI steps and improving artifact packaging
Diffstat (limited to 'src')
-rw-r--r--src/base_model_trpgner/inference/__init__.py48
1 files changed, 9 insertions, 39 deletions
diff --git a/src/base_model_trpgner/inference/__init__.py b/src/base_model_trpgner/inference/__init__.py
index e90a49b..3d1d720 100644
--- a/src/base_model_trpgner/inference/__init__.py
+++ b/src/base_model_trpgner/inference/__init__.py
@@ -52,7 +52,7 @@ def download_model_files(version: Optional[str] = None, force: bool = False) ->
"""
从 GitHub Release 下载模型文件
- 优先下载压缩包(如果存在),否则逐个下载文件。
+ 下载 model.zip 压缩包并解压。
Args:
version: Release 版本(如 v0.1.0),None 表示最新版本
@@ -83,12 +83,9 @@ def download_model_files(version: Optional[str] = None, force: bool = False) ->
print(f"正在下载模型 {version}...")
- base_url = f"https://github.com/{REPO_OWNER}/{REPO_NAME}/releases/download/{version}"
-
- # 方式1: 尝试下载压缩包(更快)
- zip_url = f"{base_url}/model.zip"
+ # 下载 model.zip
+ zip_url = f"https://github.com/{REPO_OWNER}/{REPO_NAME}/releases/download/{version}/model.zip"
try:
- print(f"尝试下载压缩包...")
with urllib.request.urlopen(zip_url, timeout=120) as response:
# 下载到临时文件
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_file:
@@ -103,41 +100,14 @@ def download_model_files(version: Optional[str] = None, force: bool = False) ->
# 清理临时文件
os.unlink(tmp_path)
- print(f"✅ 模型下载完成(压缩包): {model_dir}")
-
- except urllib.error.HTTPError:
- # 方式2: 压缩包不存在,逐个下载文件
- print(f"压缩包不存在,逐个下载文件...")
-
- files_to_download = [
- "model.onnx",
- "model.onnx.data",
- "config.json",
- "tokenizer.json",
- "tokenizer_config.json",
- "special_tokens_map.json",
- "vocab.txt",
- ]
-
- for filename in files_to_download:
- url = f"{base_url}/{filename}"
- dest_path = model_dir / filename
-
- try:
- print(f"下载 {filename}...")
- with urllib.request.urlopen(url, timeout=60) as response:
- with open(dest_path, "wb") as f:
- shutil.copyfileobj(response, f)
- except urllib.error.HTTPError as e:
- print(f"下载 {filename} 失败: {e}")
- # model.onnx.data 不是必需的(某些小模型可能没有)
- if filename == "model.onnx.data":
- continue
- else:
- raise RuntimeError(f"无法下载模型文件: {filename}") from e
-
print(f"✅ 模型下载完成: {model_dir}")
+ except urllib.error.HTTPError as e:
+ raise RuntimeError(
+ f"无法下载模型文件。请检查版本 {version} 是否存在。\n"
+ f"下载地址: {zip_url}"
+ ) from e
+
# 写入版本标记
with open(marker_file, "w") as f:
f.write(version)