diff options
| -rw-r--r-- | .github/workflows/publish.yml | 103 | ||||
| -rw-r--r-- | src/base_model_trpgner/inference/__init__.py | 48 |
2 files changed, 24 insertions, 127 deletions
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 520fdb2..bdc6cfc 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -3,18 +3,13 @@ name: Publish to PyPI & GitHub Release on: push: tags: - - 'v*.*.*' + - "v*.*.*" workflow_dispatch: inputs: - create_test: - description: 'Publish to Test PyPI' - required: false - default: false - type: boolean tag_name: - description: 'Tag name to release (e.g., v1.0.0). Use this to re-run release with fixed code.' + description: "Tag name to release (e.g., v1.0.0)" required: true - default: 'v1.0.0' + default: "v1.0.0" permissions: contents: write @@ -33,24 +28,18 @@ jobs: uses: actions/checkout@v6 with: fetch-depth: 0 - # 最佳实践: Tag 触发时使用 main 分支代码,而不是 Tag 快照 - # 这样修复代码后可以手动重新触发 workflow 而无需重新打 tag ref: main - name: Extract version from tag id: version run: | - # Tag 触发时使用 github.ref_name, 手动触发时使用 inputs.tag_name if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then VERSION="${{ github.event.inputs.tag_name }}" - echo "🔧 Manual trigger mode: Using tag $VERSION from input" else VERSION="${{ github.ref_name }}" - echo "🏷️ Tag trigger mode: Using tag $VERSION from GitHub" fi echo "version=${VERSION#v}" >> $GITHUB_OUTPUT echo "Tag: $VERSION" - echo "Version: ${VERSION#v}" - name: Install uv run: | @@ -72,63 +61,25 @@ jobs: name: dist path: dist/ - - name: Prepare ONNX artifact + - name: Package model files run: | - mkdir -p onnx-artifact - cp models/trpg-final/model.onnx onnx-artifact/ - cp models/trpg-final/model.onnx.data onnx-artifact/ || true - cp models/trpg-final/config.json onnx-artifact/ - cp models/trpg-final/tokenizer.json onnx-artifact/ - cp models/trpg-final/tokenizer_config.json onnx-artifact/ - cp models/trpg-final/special_tokens_map.json onnx-artifact/ - cp models/trpg-final/vocab.txt onnx-artifact/ - - # 创建压缩包(方便用户下载) - cd onnx-artifact - zip -r ../model.zip . - cd .. - ls -lh onnx-artifact/ + # 直接打包 models/trpg-final 整个目录 + cd models/trpg-final + zip -r ../../model.zip . + cd ../.. ls -lh model.zip - uses: actions/upload-artifact@v4 with: - name: onnx-model - path: onnx-artifact/ - - - uses: actions/upload-artifact@v4 - with: name: model-zip path: model.zip - publish-test-pypi: - name: Publish to Test PyPI - needs: build - runs-on: ubuntu-latest - if: github.event_name == 'workflow_dispatch' && inputs.create_test == true - - environment: - name: test-pypi - url: https://test.pypi.org/p/base-model-trpgner - - permissions: - id-token: write - - steps: - - name: Download dist - uses: actions/download-artifact@v4 - with: - name: dist - path: dist/ - - - name: Publish to Test PyPI - uses: pypa/gh-action-pypi-publish@release/v1 - publish-pypi: name: Publish to PyPI needs: build runs-on: ubuntu-latest - # Tag 推送时自动发布, 或手动触发且未指定测试模式时发布 - if: (github.event_name == 'push' && startsWith(github.ref, 'refs/tags/')) || (github.event_name == 'workflow_dispatch' && inputs.create_test != true) + # Tag 推送或手动触发时发布 + if: (github.event_name == 'push' && startsWith(github.ref, 'refs/tags/')) || github.event_name == 'workflow_dispatch' environment: name: pypi @@ -151,8 +102,8 @@ jobs: name: Create GitHub Release with ONNX needs: [build, publish-pypi] runs-on: ubuntu-latest - # Tag 推送时自动创建, 或手动触发且未指定测试模式时创建 - if: (github.event_name == 'push' && startsWith(github.ref, 'refs/tags/')) || (github.event_name == 'workflow_dispatch' && inputs.create_test != true) + # Tag 推送或手动触发时创建 + if: (github.event_name == 'push' && startsWith(github.ref, 'refs/tags/')) || github.event_name == 'workflow_dispatch' permissions: contents: write @@ -170,7 +121,6 @@ jobs: uses: requarks/changelog-action@v1 with: token: ${{ github.token }} - # 根据触发类型选择正确的 tag tag: ${{ github.event_name == 'workflow_dispatch' && inputs.tag_name || github.ref_name }} includeInvalidCommits: true changelogFilePath: CHANGELOG.md @@ -180,26 +130,14 @@ jobs: - name: Download artifacts uses: actions/download-artifact@v4 with: - pattern: '*' + pattern: "*" path: artifacts/ merge-multiple: true - - name: Verify artifact structure - run: | - echo "📁 Artifact directory structure:" - ls -la artifacts/ - echo "" - echo "📦 dist contents:" - ls -la artifacts/dist/ || echo "dist not found" - echo "" - echo "🧠 onnx-model contents:" - ls -la artifacts/onnx-model/ || echo "onnx-model not found" - - - name: Create Release with ONNX + - name: Create Release env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | - # 根据触发类型选择正确的 tag if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then VERSION="${{ github.event.inputs.tag_name }}" else @@ -238,27 +176,18 @@ jobs: ${{ steps.changelog.outputs.changes }} EOF - # 检查 release 是否已存在 if gh release view "${VERSION}" --repo "${{ github.repository }}" >/dev/null 2>&1; then - echo "📝 Release ${VERSION} 已存在,更新 release 资源..." - - # 删除旧的 assets 以便上传新的 + echo "📝 Release ${VERSION} 已存在,更新..." ASSETS=$(gh release view "${VERSION}" --repo "${{ github.repository }}" --json assets -q '.assets[].name') for asset in $ASSETS; do - echo " 删除旧资源: $asset" gh release delete-asset "${VERSION}" "$asset" --repo "${{ github.repository }}" || true done - - # 更新 release notes 和 assets gh release edit "${VERSION}" \ --repo "${{ github.repository }}" \ --notes-file release_notes.md \ --title "🚀 ${VERSION}" - - # 上传新的资源 gh release upload "${VERSION}" \ artifacts/dist/* \ - artifacts/onnx-model/* \ artifacts/model.zip \ --repo "${{ github.repository }}" --clobber else @@ -268,7 +197,6 @@ jobs: --notes-file release_notes.md \ --title "🚀 ${VERSION}" \ artifacts/dist/* \ - artifacts/onnx-model/* \ artifacts/model.zip fi @@ -277,6 +205,5 @@ jobs: uses: stefanzweifel/git-auto-commit-action@v7 with: branch: main - # 根据触发类型选择正确的 tag commit_message: "docs: update CHANGELOG.md for ${{ github.event_name == 'workflow_dispatch' && inputs.tag_name || github.ref_name }} [skip ci]" file_pattern: CHANGELOG.md diff --git a/src/base_model_trpgner/inference/__init__.py b/src/base_model_trpgner/inference/__init__.py index e90a49b..3d1d720 100644 --- a/src/base_model_trpgner/inference/__init__.py +++ b/src/base_model_trpgner/inference/__init__.py @@ -52,7 +52,7 @@ def download_model_files(version: Optional[str] = None, force: bool = False) -> """ 从 GitHub Release 下载模型文件 - 优先下载压缩包(如果存在),否则逐个下载文件。 + 下载 model.zip 压缩包并解压。 Args: version: Release 版本(如 v0.1.0),None 表示最新版本 @@ -83,12 +83,9 @@ def download_model_files(version: Optional[str] = None, force: bool = False) -> print(f"正在下载模型 {version}...") - base_url = f"https://github.com/{REPO_OWNER}/{REPO_NAME}/releases/download/{version}" - - # 方式1: 尝试下载压缩包(更快) - zip_url = f"{base_url}/model.zip" + # 下载 model.zip + zip_url = f"https://github.com/{REPO_OWNER}/{REPO_NAME}/releases/download/{version}/model.zip" try: - print(f"尝试下载压缩包...") with urllib.request.urlopen(zip_url, timeout=120) as response: # 下载到临时文件 with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_file: @@ -103,41 +100,14 @@ def download_model_files(version: Optional[str] = None, force: bool = False) -> # 清理临时文件 os.unlink(tmp_path) - print(f"✅ 模型下载完成(压缩包): {model_dir}") - - except urllib.error.HTTPError: - # 方式2: 压缩包不存在,逐个下载文件 - print(f"压缩包不存在,逐个下载文件...") - - files_to_download = [ - "model.onnx", - "model.onnx.data", - "config.json", - "tokenizer.json", - "tokenizer_config.json", - "special_tokens_map.json", - "vocab.txt", - ] - - for filename in files_to_download: - url = f"{base_url}/{filename}" - dest_path = model_dir / filename - - try: - print(f"下载 {filename}...") - with urllib.request.urlopen(url, timeout=60) as response: - with open(dest_path, "wb") as f: - shutil.copyfileobj(response, f) - except urllib.error.HTTPError as e: - print(f"下载 {filename} 失败: {e}") - # model.onnx.data 不是必需的(某些小模型可能没有) - if filename == "model.onnx.data": - continue - else: - raise RuntimeError(f"无法下载模型文件: {filename}") from e - print(f"✅ 模型下载完成: {model_dir}") + except urllib.error.HTTPError as e: + raise RuntimeError( + f"无法下载模型文件。请检查版本 {version} 是否存在。\n" + f"下载地址: {zip_url}" + ) from e + # 写入版本标记 with open(marker_file, "w") as f: f.write(version) |
