diff options
Diffstat (limited to 'src/base_model_trpgner/inference/__init__.py')
| -rw-r--r-- | src/base_model_trpgner/inference/__init__.py | 48 |
1 files changed, 9 insertions, 39 deletions
diff --git a/src/base_model_trpgner/inference/__init__.py b/src/base_model_trpgner/inference/__init__.py index e90a49b..3d1d720 100644 --- a/src/base_model_trpgner/inference/__init__.py +++ b/src/base_model_trpgner/inference/__init__.py @@ -52,7 +52,7 @@ def download_model_files(version: Optional[str] = None, force: bool = False) -> """ 从 GitHub Release 下载模型文件 - 优先下载压缩包(如果存在),否则逐个下载文件。 + 下载 model.zip 压缩包并解压。 Args: version: Release 版本(如 v0.1.0),None 表示最新版本 @@ -83,12 +83,9 @@ def download_model_files(version: Optional[str] = None, force: bool = False) -> print(f"正在下载模型 {version}...") - base_url = f"https://github.com/{REPO_OWNER}/{REPO_NAME}/releases/download/{version}" - - # 方式1: 尝试下载压缩包(更快) - zip_url = f"{base_url}/model.zip" + # 下载 model.zip + zip_url = f"https://github.com/{REPO_OWNER}/{REPO_NAME}/releases/download/{version}/model.zip" try: - print(f"尝试下载压缩包...") with urllib.request.urlopen(zip_url, timeout=120) as response: # 下载到临时文件 with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_file: @@ -103,41 +100,14 @@ def download_model_files(version: Optional[str] = None, force: bool = False) -> # 清理临时文件 os.unlink(tmp_path) - print(f"✅ 模型下载完成(压缩包): {model_dir}") - - except urllib.error.HTTPError: - # 方式2: 压缩包不存在,逐个下载文件 - print(f"压缩包不存在,逐个下载文件...") - - files_to_download = [ - "model.onnx", - "model.onnx.data", - "config.json", - "tokenizer.json", - "tokenizer_config.json", - "special_tokens_map.json", - "vocab.txt", - ] - - for filename in files_to_download: - url = f"{base_url}/{filename}" - dest_path = model_dir / filename - - try: - print(f"下载 {filename}...") - with urllib.request.urlopen(url, timeout=60) as response: - with open(dest_path, "wb") as f: - shutil.copyfileobj(response, f) - except urllib.error.HTTPError as e: - print(f"下载 {filename} 失败: {e}") - # model.onnx.data 不是必需的(某些小模型可能没有) - if filename == "model.onnx.data": - continue - else: - raise RuntimeError(f"无法下载模型文件: {filename}") from e - print(f"✅ 模型下载完成: {model_dir}") + except urllib.error.HTTPError as e: + raise RuntimeError( + f"无法下载模型文件。请检查版本 {version} 是否存在。\n" + f"下载地址: {zip_url}" + ) from e + # 写入版本标记 with open(marker_file, "w") as f: f.write(version) |
