-rw-r--r--   .github/workflows/publish.yml                  33
-rw-r--r--   src/base_model_trpgner/inference/__init__.py   85
2 files changed, 88 insertions, 30 deletions
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index c6b5fcf..520fdb2 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -77,13 +77,29 @@ jobs:
           mkdir -p onnx-artifact
           cp models/trpg-final/model.onnx onnx-artifact/
           cp models/trpg-final/model.onnx.data onnx-artifact/ || true
+          cp models/trpg-final/config.json onnx-artifact/
+          cp models/trpg-final/tokenizer.json onnx-artifact/
+          cp models/trpg-final/tokenizer_config.json onnx-artifact/
+          cp models/trpg-final/special_tokens_map.json onnx-artifact/
+          cp models/trpg-final/vocab.txt onnx-artifact/
+
+          # Create an archive (convenient for users to download)
+          cd onnx-artifact
+          zip -r ../model.zip .
+          cd ..
           ls -lh onnx-artifact/
+          ls -lh model.zip

       - uses: actions/upload-artifact@v4
         with:
           name: onnx-model
           path: onnx-artifact/

+      - uses: actions/upload-artifact@v4
+        with:
+          name: model-zip
+          path: model.zip
+
   publish-test-pypi:
     name: Publish to Test PyPI
     needs: build
@@ -168,6 +184,17 @@ jobs:
           path: artifacts/
           merge-multiple: true

+      - name: Verify artifact structure
+        run: |
+          echo "📁 Artifact directory structure:"
+          ls -la artifacts/
+          echo ""
+          echo "📦 dist contents:"
+          ls -la artifacts/dist/ || echo "dist not found"
+          echo ""
+          echo "🧠 onnx-model contents:"
+          ls -la artifacts/onnx-model/ || echo "onnx-model not found"
+
       - name: Create Release with ONNX
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
@@ -231,7 +258,8 @@ jobs:
             # Upload the new assets
             gh release upload "${VERSION}" \
               artifacts/dist/* \
-              artifacts/onnx-artifact/* \
+              artifacts/onnx-model/* \
+              artifacts/model.zip \
              --repo "${{ github.repository }}" --clobber
           else
             echo "✨ Creating new release ${VERSION}..."
@@ -240,7 +268,8 @@ jobs:
               --notes-file release_notes.md \
               --title "🚀 ${VERSION}" \
               artifacts/dist/* \
-              artifacts/onnx-artifact/*
+              artifacts/onnx-model/* \
+              artifacts/model.zip
           fi

       - name: Commit CHANGELOG.md
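The zip produced by that step is exactly what the Python downloader below consumes, so it helps to confirm the archive carries every file the fallback path would otherwise fetch one by one. A minimal local sanity check, assuming model.zip has already been built or downloaded into the current directory; the expected names mirror the files_to_download list in the change below.

# Hypothetical local check: list model.zip and compare against the files
# the downloader expects. Assumes model.zip is in the working directory.
import zipfile

expected = {
    "config.json",
    "model.onnx",
    "special_tokens_map.json",
    "tokenizer.json",
    "tokenizer_config.json",
    "vocab.txt",
}

with zipfile.ZipFile("model.zip") as zf:
    # `zip -r ../model.zip .` may store entries with a leading "./"; normalize it
    names = {n[2:] if n.startswith("./") else n for n in zf.namelist()}

missing = expected - names
print("missing files:", sorted(missing) if missing else "none")
# model.onnx.data is optional; it only exists when weights are stored externally
print("has model.onnx.data:", "model.onnx.data" in names)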
diff --git a/src/base_model_trpgner/inference/__init__.py b/src/base_model_trpgner/inference/__init__.py
index d70cb23..e90a49b 100644
--- a/src/base_model_trpgner/inference/__init__.py
+++ b/src/base_model_trpgner/inference/__init__.py
@@ -52,6 +52,8 @@ def download_model_files(version: Optional[str] = None, force: bool = False) ->
     """
     Download model files from the GitHub Release

+    Prefer downloading the zip archive (if it exists); otherwise download the files one by one.
+
     Args:
         version: Release version (e.g. v0.1.0); None means the latest release
         force: Whether to force a re-download (even if the files already exist)
@@ -61,6 +63,8 @@ def download_model_files(version: Optional[str] = None, force: bool = False) ->
     import urllib.request
     import urllib.error
+    import tempfile
+    import zipfile

     if version is None:
         version = get_latest_release_url()
@@ -79,40 +83,65 @@ def download_model_files(version: Optional[str] = None, force: bool = False) ->
     print(f"Downloading model {version}...")

-    # Files to download
     base_url = f"https://github.com/{REPO_OWNER}/{REPO_NAME}/releases/download/{version}"
-    files_to_download = [
-        "model.onnx",
-        "model.onnx.data",
-        "config.json",
-        "tokenizer.json",
-        "tokenizer_config.json",
-        "special_tokens_map.json",
-        "vocab.txt",
-    ]
-
-    for filename in files_to_download:
-        url = f"{base_url}/{filename}"
-        dest_path = model_dir / filename
-
-        try:
-            print(f"Downloading {filename}...")
-            with urllib.request.urlopen(url, timeout=60) as response:
-                with open(dest_path, "wb") as f:
-                    shutil.copyfileobj(response, f)
-        except urllib.error.HTTPError as e:
-            print(f"Failed to download {filename}: {e}")
-            # model.onnx.data is not required (some small models may not have it)
-            if filename == "model.onnx.data":
-                continue
-            else:
-                raise RuntimeError(f"Failed to download model file: {filename}") from e
+
+    # Option 1: try the zip archive first (faster)
+    zip_url = f"{base_url}/model.zip"
+    try:
+        print(f"Trying to download the zip archive...")
+        with urllib.request.urlopen(zip_url, timeout=120) as response:
+            # Download to a temporary file
+            with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_file:
+                shutil.copyfileobj(response, tmp_file)
+                tmp_path = tmp_file.name
+
+        # Extract into the target directory
+        print(f"Extracting...")
+        with zipfile.ZipFile(tmp_path, 'r') as zip_ref:
+            zip_ref.extractall(model_dir)
+
+        # Clean up the temporary file
+        os.unlink(tmp_path)
+
+        print(f"✅ Model download complete (zip archive): {model_dir}")
+
+    except urllib.error.HTTPError:
+        # Option 2: the zip archive does not exist, download the files one by one
+        print(f"Zip archive not found, downloading files individually...")
+
+        files_to_download = [
+            "model.onnx",
+            "model.onnx.data",
+            "config.json",
+            "tokenizer.json",
+            "tokenizer_config.json",
+            "special_tokens_map.json",
+            "vocab.txt",
+        ]
+
+        for filename in files_to_download:
+            url = f"{base_url}/{filename}"
+            dest_path = model_dir / filename
+
+            try:
+                print(f"Downloading {filename}...")
+                with urllib.request.urlopen(url, timeout=60) as response:
+                    with open(dest_path, "wb") as f:
+                        shutil.copyfileobj(response, f)
+            except urllib.error.HTTPError as e:
+                print(f"Failed to download {filename}: {e}")
+                # model.onnx.data is not required (some small models may not have it)
+                if filename == "model.onnx.data":
+                    continue
+                else:
+                    raise RuntimeError(f"Failed to download model file: {filename}") from e
+
+        print(f"✅ Model download complete: {model_dir}")

     # Write the version marker
     with open(marker_file, "w") as f:
         f.write(version)

-    print(f"✅ Model download complete: {model_dir}")
-
     return model_dir
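From the caller's side the zip-first logic is transparent: download_model_files() still returns the directory holding the extracted files. A minimal usage sketch, assuming the package is installed; the onnxruntime call at the end is illustrative and not part of this change.

# Sketch: fetch the model for the latest release (cached after the first call),
# then load it. Pass version=... to pin a release or force=True to re-download.
from base_model_trpgner.inference import download_model_files

model_dir = download_model_files()
print(sorted(p.name for p in model_dir.iterdir()))

import onnxruntime as ort  # assumption: available alongside the package
session = ort.InferenceSession(str(model_dir / "model.onnx"))
print([inp.name for inp in session.get_inputs()])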
