diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/base_model_trpgner/inference/__init__.py | 85 |
1 file changed, 57 insertions, 28 deletions
diff --git a/src/base_model_trpgner/inference/__init__.py b/src/base_model_trpgner/inference/__init__.py index d70cb23..e90a49b 100644 --- a/src/base_model_trpgner/inference/__init__.py +++ b/src/base_model_trpgner/inference/__init__.py @@ -52,6 +52,8 @@ def download_model_files(version: Optional[str] = None, force: bool = False) -> """ 从 GitHub Release 下载模型文件 + 优先下载压缩包(如果存在),否则逐个下载文件。 + Args: version: Release 版本(如 v0.1.0),None 表示最新版本 force: 是否强制重新下载(即使文件已存在) @@ -61,6 +63,8 @@ def download_model_files(version: Optional[str] = None, force: bool = False) -> """ import urllib.request import urllib.error + import tempfile + import zipfile if version is None: version = get_latest_release_url() @@ -79,40 +83,65 @@ def download_model_files(version: Optional[str] = None, force: bool = False) -> print(f"正在下载模型 {version}...") - # 需要下载的文件 base_url = f"https://github.com/{REPO_OWNER}/{REPO_NAME}/releases/download/{version}" - files_to_download = [ - "model.onnx", - "model.onnx.data", - "config.json", - "tokenizer.json", - "tokenizer_config.json", - "special_tokens_map.json", - "vocab.txt", - ] - - for filename in files_to_download: - url = f"{base_url}/{filename}" - dest_path = model_dir / filename - - try: - print(f"下载 {filename}...") - with urllib.request.urlopen(url, timeout=60) as response: - with open(dest_path, "wb") as f: - shutil.copyfileobj(response, f) - except urllib.error.HTTPError as e: - print(f"下载 {filename} 失败: {e}") - # model.onnx.data 不是必需的(某些小模型可能没有) - if filename == "model.onnx.data": - continue - else: - raise RuntimeError(f"无法下载模型文件: {filename}") from e + + # 方式1: 尝试下载压缩包(更快) + zip_url = f"{base_url}/model.zip" + try: + print(f"尝试下载压缩包...") + with urllib.request.urlopen(zip_url, timeout=120) as response: + # 下载到临时文件 + with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_file: + shutil.copyfileobj(response, tmp_file) + tmp_path = tmp_file.name + + # 解压到目标目录 + print(f"解压中...") + with zipfile.ZipFile(tmp_path, 'r') as zip_ref: + 
zip_ref.extractall(model_dir) + + # 清理临时文件 + os.unlink(tmp_path) + + print(f"✅ 模型下载完成(压缩包): {model_dir}") + + except urllib.error.HTTPError: + # 方式2: 压缩包不存在,逐个下载文件 + print(f"压缩包不存在,逐个下载文件...") + + files_to_download = [ + "model.onnx", + "model.onnx.data", + "config.json", + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", + "vocab.txt", + ] + + for filename in files_to_download: + url = f"{base_url}/{filename}" + dest_path = model_dir / filename + + try: + print(f"下载 {filename}...") + with urllib.request.urlopen(url, timeout=60) as response: + with open(dest_path, "wb") as f: + shutil.copyfileobj(response, f) + except urllib.error.HTTPError as e: + print(f"下载 {filename} 失败: {e}") + # model.onnx.data 不是必需的(某些小模型可能没有) + if filename == "model.onnx.data": + continue + else: + raise RuntimeError(f"无法下载模型文件: {filename}") from e + + print(f"✅ 模型下载完成: {model_dir}") # 写入版本标记 with open(marker_file, "w") as f: f.write(version) - print(f"✅ 模型下载完成: {model_dir}") return model_dir |
