about summary refs log tree commit diff stats homepage
diff options
context:
space:
mode:
-rw-r--r-- .github/workflows/publish.yml 33
-rw-r--r-- src/base_model_trpgner/inference/__init__.py 85
2 files changed, 88 insertions, 30 deletions
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index c6b5fcf..520fdb2 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -77,13 +77,29 @@ jobs:
mkdir -p onnx-artifact
cp models/trpg-final/model.onnx onnx-artifact/
cp models/trpg-final/model.onnx.data onnx-artifact/ || true
+ cp models/trpg-final/config.json onnx-artifact/
+ cp models/trpg-final/tokenizer.json onnx-artifact/
+ cp models/trpg-final/tokenizer_config.json onnx-artifact/
+ cp models/trpg-final/special_tokens_map.json onnx-artifact/
+ cp models/trpg-final/vocab.txt onnx-artifact/
+
+ # 创建压缩包(方便用户下载)
+ cd onnx-artifact
+ zip -r ../model.zip .
+ cd ..
ls -lh onnx-artifact/
+ ls -lh model.zip
- uses: actions/upload-artifact@v4
with:
name: onnx-model
path: onnx-artifact/
+ - uses: actions/upload-artifact@v4
+ with:
+ name: model-zip
+ path: model.zip
+
publish-test-pypi:
name: Publish to Test PyPI
needs: build
@@ -168,6 +184,17 @@ jobs:
path: artifacts/
merge-multiple: true
+ - name: Verify artifact structure
+ run: |
+ echo "📁 Artifact directory structure:"
+ ls -la artifacts/
+ echo ""
+ echo "📦 dist contents:"
+ ls -la artifacts/dist/ || echo "dist not found"
+ echo ""
+ echo "🧠 onnx-model contents:"
+ ls -la artifacts/onnx-model/ || echo "onnx-model not found"
+
- name: Create Release with ONNX
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
@@ -231,7 +258,8 @@ jobs:
# 上传新的资源
gh release upload "${VERSION}" \
artifacts/dist/* \
- artifacts/onnx-artifact/* \
+ artifacts/onnx-model/* \
+ artifacts/model.zip \
--repo "${{ github.repository }}" --clobber
else
echo "✨ 创建新 release ${VERSION}..."
@@ -240,7 +268,8 @@ jobs:
--notes-file release_notes.md \
--title "🚀 ${VERSION}" \
artifacts/dist/* \
- artifacts/onnx-artifact/*
+ artifacts/onnx-model/* \
+ artifacts/model.zip
fi
- name: Commit CHANGELOG.md
diff --git a/src/base_model_trpgner/inference/__init__.py b/src/base_model_trpgner/inference/__init__.py
index d70cb23..e90a49b 100644
--- a/src/base_model_trpgner/inference/__init__.py
+++ b/src/base_model_trpgner/inference/__init__.py
@@ -52,6 +52,8 @@ def download_model_files(version: Optional[str] = None, force: bool = False) ->
"""
从 GitHub Release 下载模型文件
+ 优先下载压缩包(如果存在),否则逐个下载文件。
+
Args:
version: Release 版本(如 v0.1.0),None 表示最新版本
force: 是否强制重新下载(即使文件已存在)
@@ -61,6 +63,8 @@ def download_model_files(version: Optional[str] = None, force: bool = False) ->
"""
import urllib.request
import urllib.error
+ import tempfile
+ import zipfile
if version is None:
version = get_latest_release_url()
@@ -79,40 +83,65 @@ def download_model_files(version: Optional[str] = None, force: bool = False) ->
print(f"正在下载模型 {version}...")
- # 需要下载的文件
base_url = f"https://github.com/{REPO_OWNER}/{REPO_NAME}/releases/download/{version}"
- files_to_download = [
- "model.onnx",
- "model.onnx.data",
- "config.json",
- "tokenizer.json",
- "tokenizer_config.json",
- "special_tokens_map.json",
- "vocab.txt",
- ]
-
- for filename in files_to_download:
- url = f"{base_url}/{filename}"
- dest_path = model_dir / filename
-
- try:
- print(f"下载 {filename}...")
- with urllib.request.urlopen(url, timeout=60) as response:
- with open(dest_path, "wb") as f:
- shutil.copyfileobj(response, f)
- except urllib.error.HTTPError as e:
- print(f"下载 {filename} 失败: {e}")
- # model.onnx.data 不是必需的(某些小模型可能没有)
- if filename == "model.onnx.data":
- continue
- else:
- raise RuntimeError(f"无法下载模型文件: {filename}") from e
+
+ # 方式1: 尝试下载压缩包(更快)
+ zip_url = f"{base_url}/model.zip"
+ try:
+ print(f"尝试下载压缩包...")
+ with urllib.request.urlopen(zip_url, timeout=120) as response:
+ # 下载到临时文件
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_file:
+ shutil.copyfileobj(response, tmp_file)
+ tmp_path = tmp_file.name
+
+ # 解压到目标目录
+ print(f"解压中...")
+ with zipfile.ZipFile(tmp_path, 'r') as zip_ref:
+ zip_ref.extractall(model_dir)
+
+ # 清理临时文件
+ os.unlink(tmp_path)
+
+ print(f"✅ 模型下载完成(压缩包): {model_dir}")
+
+ except urllib.error.HTTPError:
+ # 方式2: 压缩包不存在,逐个下载文件
+ print(f"压缩包不存在,逐个下载文件...")
+
+ files_to_download = [
+ "model.onnx",
+ "model.onnx.data",
+ "config.json",
+ "tokenizer.json",
+ "tokenizer_config.json",
+ "special_tokens_map.json",
+ "vocab.txt",
+ ]
+
+ for filename in files_to_download:
+ url = f"{base_url}/{filename}"
+ dest_path = model_dir / filename
+
+ try:
+ print(f"下载 {filename}...")
+ with urllib.request.urlopen(url, timeout=60) as response:
+ with open(dest_path, "wb") as f:
+ shutil.copyfileobj(response, f)
+ except urllib.error.HTTPError as e:
+ print(f"下载 {filename} 失败: {e}")
+ # model.onnx.data 不是必需的(某些小模型可能没有)
+ if filename == "model.onnx.data":
+ continue
+ else:
+ raise RuntimeError(f"无法下载模型文件: {filename}") from e
+
+ print(f"✅ 模型下载完成: {model_dir}")
# 写入版本标记
with open(marker_file, "w") as f:
f.write(version)
- print(f"✅ 模型下载完成: {model_dir}")
return model_dir