-rw-r--r--  .github/workflows/publish.yml                 |   4
-rw-r--r--  pyproject.toml                                |  17
-rw-r--r--  src/base_model_trpgner/__init__.py            |   2
-rw-r--r--  src/base_model_trpgner/download_model.py      |  70
-rw-r--r--  src/base_model_trpgner/inference/__init__.py  | 137
-rw-r--r--  src/base_model_trpgner/training/__init__.py   |   4
-rw-r--r--  tests/test_onnx_only_infer.py                 | 198
7 files changed, 324 insertions(+), 108 deletions(-)
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index 01c6060..c6b5fcf 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -4,8 +4,6 @@ on:
push:
tags:
- 'v*.*.*'
- branches:
- - main
workflow_dispatch:
inputs:
create_test:
@@ -201,7 +199,7 @@ jobs:
## 🚀 Quick Start
```python
- from basemodeltrpgner import TRPGParser
+ from base_model_trpgner import TRPGParser
parser = TRPGParser()
result = parser.parse("风雨 2024-06-08 21:44:59 剧烈的疼痛...")
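For context, `result` here is a plain dict. Its exact shape isn't shown in this commit, but judging from the JSON assembled at the end of the new `tests/test_onnx_only_infer.py` below, it plausibly looks like this (entity types and confidence values are purely illustrative):
```python
# Plausible shape of `result`, inferred from the new test script (illustrative values):
{
    "metadata": {"speaker": "风雨", "timestamp": "2024-06-08 21:44:59"},
    "content": [
        {"type": "action", "content": "剧烈的疼痛...", "confidence": 0.97},
    ],
}
```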
diff --git a/pyproject.toml b/pyproject.toml
index b92fae3..f9e8216 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,22 +56,11 @@ Issues = "https://github.com/HydroRoll-Team/base-model/issues"
[tool.hatch.build.targets.wheel]
packages = ["src/base_model_trpgner"]
-artifacts = [
- "src/base_model_trpgner/**/*.py",
- "models/trpg-final/model.onnx",
- "models/trpg-final/model.onnx.data",
- "models/trpg-final/config.json",
- "models/trpg-final/tokenizer.json",
- "models/trpg-final/tokenizer_config.json",
- "models/trpg-final/special_tokens_map.json",
- "models/trpg-final/vocab.txt",
-]
-
-[tool.hatch.build.targets.wheel.shared-data]
-"models/trpg-final" = "base_model_trpgner/models/trpg-final"
+# Ship Python sources only; model files are downloaded automatically from the GitHub Release on first run
+artifacts = ["src/base_model_trpgner/**/*.py"]
[tool.hatch.build.targets.sdist]
-include = ["/src", "/models/trpg-final", "/README.md", "/COPYING"]
+include = ["/src", "/README.md", "/COPYING"]
[tool.ruff]
exclude = [
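With the model artifacts dropped from both the wheel and the sdist, a quick way to verify that a freshly built wheel contains only Python sources (a sketch; assumes the wheel was built into `dist/`, e.g. via `hatch build`):
```python
# Sketch: assert the freshly built wheel ships no model artifacts.
import glob
import zipfile

wheel = glob.glob("dist/base_model_trpgner-*.whl")[0]  # assumes `hatch build` ran first
names = zipfile.ZipFile(wheel).namelist()
leaked = [n for n in names if n.endswith((".onnx", ".onnx.data", "vocab.txt"))]
assert not leaked, f"model files unexpectedly bundled: {leaked}"
```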
diff --git a/src/base_model_trpgner/__init__.py b/src/base_model_trpgner/__init__.py
index 9796c83..8549a69 100644
--- a/src/base_model_trpgner/__init__.py
+++ b/src/base_model_trpgner/__init__.py
@@ -22,7 +22,7 @@ try:
from importlib.metadata import version
__version__ = version("base_model_trpgner")
except Exception:
- __version__ = "0.1.1.dev"
+ __version__ = "0.1.2.dev"
__all__ = [
"__version__",
diff --git a/src/base_model_trpgner/download_model.py b/src/base_model_trpgner/download_model.py
deleted file mode 100644
index 2d65099..0000000
--- a/src/base_model_trpgner/download_model.py
+++ /dev/null
@@ -1,70 +0,0 @@
-#!/usr/bin/env python3
-"""
-模型下载脚本
-
-自动下载预训练的 ONNX 模型到用户缓存目录。
-"""
-
-import os
-import sys
-from pathlib import Path
-import urllib.request
-
-
-def download_model(
- output_dir: str = None,
- url: str = "https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx"
-):
- """
- 下载 ONNX 模型
-
- Args:
- output_dir: 输出目录,默认为 ~/.cache/basemodel/models/trpg-final/
- url: 模型下载 URL
- """
- if output_dir is None:
- output_dir = Path.home() / ".cache" / "basemodel" / "models" / "trpg-final"
- else:
- output_dir = Path(output_dir)
-
- output_dir.mkdir(parents=True, exist_ok=True)
- output_path = output_dir / "model.onnx"
-
- if output_path.exists():
- print(f"✅ 模型已存在: {output_path}")
- return str(output_path)
-
- print(f"📥 正在下载模型到 {output_path}...")
- print(f" URL: {url}")
-
- try:
- urllib.request.urlretrieve(url, output_path)
- print(f"✅ 模型下载成功!")
- return str(output_path)
- except Exception as e:
- print(f"❌ 下载失败: {e}")
- print(f" 请手动从以下地址下载模型:")
- print(f" {url}")
- sys.exit(1)
-
-
-if __name__ == "__main__":
- import argparse
-
- parser = argparse.ArgumentParser(description="Download the base-model ONNX model")
- parser.add_argument(
- "--output-dir",
- type=str,
- default=None,
- help="模型输出目录(默认: ~/.cache/basemodel/models/trpg-final/)"
- )
- parser.add_argument(
- "--url",
- type=str,
- default="https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx",
- help="模型下载 URL"
- )
-
- args = parser.parse_args()
-
- download_model(args.output_dir, args.url)
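With this standalone script gone, its role is taken over by `download_model_files` in the inference package (added below); a minimal equivalent of the old CLI looks like:
```python
# Minimal replacement for the removed script, using the new API added below:
from base_model_trpgner.inference import download_model_files

# version=None resolves the latest release tag; force=True re-downloads:
model_dir = download_model_files(version=None, force=False)
print(f"Model files in: {model_dir}")
```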
diff --git a/src/base_model_trpgner/inference/__init__.py b/src/base_model_trpgner/inference/__init__.py
index 93a185f..d70cb23 100644
--- a/src/base_model_trpgner/inference/__init__.py
+++ b/src/base_model_trpgner/inference/__init__.py
@@ -5,6 +5,8 @@ ONNX inference module
"""
import os
+import json
+import shutil
from typing import List, Dict, Any, Optional
from pathlib import Path
@@ -18,20 +20,113 @@ except ImportError as e:
) from e
-# Default model path (relative to the package install location)
-DEFAULT_MODEL_DIR = Path(__file__).parent.parent.parent.parent / "models" / "trpg-final"
-# Remote model URL (for auto-download)
-MODEL_URL = "https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx"
+# GitHub repository info
+REPO_OWNER = "HydroRoll-Team"
+REPO_NAME = "base-model"
+# User data directory
+USER_MODEL_DIR = Path.home() / ".cache" / "base_model_trpgner" / "models" / "trpg-final"
+
+
+def get_latest_release_url() -> str:
+ """
+ 获取 GitHub 最新 Release 的下载 URL
+
+ Returns:
+ 最新 Release 的标签名(如 v0.1.0)
+ """
+ import urllib.request
+ import urllib.error
+
+ try:
+ # Query the GitHub API for the latest release
+ api_url = f"https://api.github.com/repos/{REPO_OWNER}/{REPO_NAME}/releases/latest"
+ with urllib.request.urlopen(api_url, timeout=10) as response:
+ data = json.load(response)
+ return data.get("tag_name", "v0.1.0")
+ except (urllib.error.URLError, json.JSONDecodeError, KeyError):
+ # Fall back to the default version on failure
+ return "v0.1.0"
+
+
+def download_model_files(version: Optional[str] = None, force: bool = False) -> Path:
+ """
+ 从 GitHub Release 下载模型文件
+
+ Args:
+ version: Release 版本(如 v0.1.0),None 表示最新版本
+ force: 是否强制重新下载(即使文件已存在)
+
+ Returns:
+ 模型文件保存目录
+ """
+ import urllib.request
+ import urllib.error
+
+ if version is None:
+ version = get_latest_release_url()
+
+ model_dir = USER_MODEL_DIR
+ model_dir.mkdir(parents=True, exist_ok=True)
+
+ # Check whether the files were already downloaded
+ marker_file = model_dir / ".version"
+ if not force and marker_file.exists():
+ with open(marker_file, "r") as f:
+ current_version = f.read().strip()
+ if current_version == version:
+ print(f"模型已存在 (版本: {version})")
+ return model_dir
+
+ print(f"正在下载模型 {version}...")
+
+ # Files to download
+ base_url = f"https://github.com/{REPO_OWNER}/{REPO_NAME}/releases/download/{version}"
+ files_to_download = [
+ "model.onnx",
+ "model.onnx.data",
+ "config.json",
+ "tokenizer.json",
+ "tokenizer_config.json",
+ "special_tokens_map.json",
+ "vocab.txt",
+ ]
+
+ for filename in files_to_download:
+ url = f"{base_url}/{filename}"
+ dest_path = model_dir / filename
+
+ try:
+ print(f"下载 {filename}...")
+ with urllib.request.urlopen(url, timeout=60) as response:
+ with open(dest_path, "wb") as f:
+ shutil.copyfileobj(response, f)
+ except urllib.error.HTTPError as e:
+ print(f"下载 {filename} 失败: {e}")
+ # model.onnx.data 不是必需的(某些小模型可能没有)
+ if filename == "model.onnx.data":
+ continue
+ else:
+ raise RuntimeError(f"无法下载模型文件: {filename}") from e
+
+ # Write the version marker
+ with open(marker_file, "w") as f:
+ f.write(version)
+
+ print(f"✅ 模型下载完成: {model_dir}")
+ return model_dir
class TRPGParser:
"""
TRPG log parser (ONNX-based)
+ On first run, the latest model is downloaded automatically from the GitHub Release.
+
Args:
- model_path: path to the ONNX model; defaults to the bundled model
+ model_path: path to the ONNX model; defaults to the auto-downloaded model
tokenizer_path: path to the tokenizer config; defaults to the same as model_path
device: inference device, "cpu" or "cuda"
+ auto_download: whether to download the model automatically (default: True)
Examples:
>>> parser = TRPGParser()
@@ -45,10 +140,11 @@ class TRPGParser:
model_path: Optional[str] = None,
tokenizer_path: Optional[str] = None,
device: str = "cpu",
+ auto_download: bool = True,
):
# Determine the model path
if model_path is None:
- model_path = self._get_default_model_path()
+ model_path = self._get_default_model_path(auto_download)
if tokenizer_path is None:
tokenizer_path = Path(model_path).parent
@@ -60,25 +156,31 @@ class TRPGParser:
# Load the model
self._load_model()
- def _get_default_model_path(self) -> str:
- """Get the default model path"""
- # 1. Try relative to the project root
+ def _get_default_model_path(self, auto_download: bool) -> str:
+ """Get the default model path, downloading it if necessary"""
+ # 1. Check the local development environment
project_root = Path(__file__).parent.parent.parent.parent
local_model = project_root / "models" / "trpg-final" / "model.onnx"
if local_model.exists():
return str(local_model)
- # 2. Try the user data directory
- from pathlib import Path
- user_model_dir = Path.home() / ".cache" / "basemodel" / "models" / "trpg-final"
- user_model = user_model_dir / "model.onnx"
+ # 2. Check the user cache directory
+ user_model = USER_MODEL_DIR / "model.onnx"
if user_model.exists():
return str(user_model)
- # 3. Raise an error prompting the user to download
+ # 3. Download automatically
+ if auto_download:
+ print("Model not found; downloading from the GitHub Release...")
+ download_model_files()
+ return str(user_model)
+
+ # 4. Raise an error
raise FileNotFoundError(
- f"模型文件未找到。请从 {MODEL_URL} 下载模型到 {user_model_dir}\n"
- f"或运行: python -m basemodel.download_model"
+ f"模型文件未找到。\n"
+ f"请开启自动下载: TRPGParser(auto_download=True)\n"
+ f"或手动下载到: {USER_MODEL_DIR}\n"
+ f"下载地址: https://github.com/{REPO_OWNER}/{REPO_NAME}/releases/latest"
)
def _load_model(self):
@@ -100,7 +202,6 @@ class TRPGParser:
)
# Load the label mapping
- import json
config_path = self.tokenizer_path / "config.json"
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as f:
@@ -289,4 +390,4 @@ def parse_lines(texts: List[str], model_path: Optional[str] = None) -> List[Dict
return parser.parse_batch(texts)
-__all__ = ["TRPGParser", "parse_line", "parse_lines"]
+__all__ = ["TRPGParser", "parse_line", "parse_lines", "download_model_files"]
diff --git a/src/base_model_trpgner/training/__init__.py b/src/base_model_trpgner/training/__init__.py
index ccf3c03..4f8e30d 100644
--- a/src/base_model_trpgner/training/__init__.py
+++ b/src/base_model_trpgner/training/__init__.py
@@ -36,11 +36,11 @@ def train_ner_model(
resume_from_checkpoint: checkpoint path to resume from
Examples:
- >>> from basemodeltrpgner.training import train_ner_model
+ >>> from base_model_trpgner.training import train_ner_model
>>> train_ner_model(
... conll_data="./data",
... output_dir="./my_model",
- ... epochs=10
+ ... num_train_epochs=10
... )
"""
try:
diff --git a/tests/test_onnx_only_infer.py b/tests/test_onnx_only_infer.py
new file mode 100644
index 0000000..69c72be
--- /dev/null
+++ b/tests/test_onnx_only_infer.py
@@ -0,0 +1,198 @@
+"""
+Minimal ONNX-only inference using only:
+ - models/trpg-final/model.onnx
+ - models/trpg-final/config.json
+
+NOTE: uses a home-made character-level tokenizer (not the training-time tokenizer),
+so results may differ from the original model's output, but it demonstrates
+end-to-end inference when the tokenizer files are unavailable.
+"""
+
+import os, sys, json, re
+import numpy as np
+import onnxruntime as ort
+
+MODEL_DIR = "models/trpg-final"
+ONNX_PATH = os.path.join(MODEL_DIR, "model.onnx")
+CFG_PATH = os.path.join(MODEL_DIR, "config.json")
+MAX_LEN = 128
+
+# load id2label & vocab_size
+with open(CFG_PATH, "r", encoding="utf-8") as f:
+ cfg = json.load(f)
+id2label = {int(k): v for k, v in cfg.get("id2label", {}).items()}
+vocab_size = int(cfg.get("vocab_size", 30000))
+pad_id = int(cfg.get("pad_token_id", 0))
+
+# simple char-level tokenizer (adds [CLS]=101, [SEP]=102, pads with pad_id)
+CLS_ID = 101
+SEP_ID = 102
+
+
+def char_tokenize(text, max_length=MAX_LEN):
+ chars = list(text)
+ # reserve 2 for CLS and SEP
+ max_chars = max_length - 2
+ chars = chars[:max_chars]
+ ids = [CLS_ID] + [100 + (ord(c) % (vocab_size - 200)) for c in chars] + [SEP_ID]
+ attn = [1] * len(ids)
+ # pad
+ pad_len = max_length - len(ids)
+ ids += [pad_id] * pad_len
+ attn += [0] * pad_len
+ # offsets: for CLS/SEP/pad use (0,0); for char tokens map to character positions
+ offsets = [(0, 0)]
+ pos = 0
+ for c in chars:
+ offsets.append((pos, pos + 1))
+ pos += 1
+ offsets.append((0, 0)) # SEP
+ offsets += [(0, 0)] * pad_len
+ return {
+ "input_ids": np.array([ids], dtype=np.int64),
+ "attention_mask": np.array([attn], dtype=np.int64),
+ "offset_mapping": np.array([offsets], dtype=np.int64),
+ "text": text,
+ }
+
+
+# onnx runtime session
+providers = [
+ p
+ for p in ("CUDAExecutionProvider", "CPUExecutionProvider")
+ if p in ort.get_available_providers()
+]
+so = ort.SessionOptions()
+so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+sess = ort.InferenceSession(ONNX_PATH, sess_options=so, providers=providers)
+
+
+def softmax(x):
+ x = x - x.max(axis=-1, keepdims=True)
+ e = np.exp(x)
+ return e / e.sum(axis=-1, keepdims=True)
+
+
+text = sys.argv[1] if len(sys.argv) > 1 else "风雨 2024-06-08 21:44:59 剧烈的疼痛..."
+inp = char_tokenize(text, MAX_LEN)
+
+# build feed dict matching session inputs
+feed = {}
+for s_in in sess.get_inputs():
+ name = s_in.name
+ if name in inp:
+ feed[name] = inp[name]
+
+outs = sess.run(None, feed)
+logits = np.asarray(outs[0]) # (batch, seq_len, num_labels)
+probs = softmax(logits)
+
+ids = inp["input_ids"][0]
+offsets = inp["offset_mapping"][0]
+attn = inp["attention_mask"][0]
+
+# reconstruct token strings (CLS, each char, SEP)
+tokens = []
+for i, idv in enumerate(ids):
+ if i == 0:
+ tokens.append("[CLS]")
+ else:
+ if offsets[i][0] == 0 and offsets[i][1] == 0:
+ # SEP or pad
+ if attn[i] == 1:
+ tokens.append("[SEP]")
+ else:
+ tokens.append("[PAD]")
+ else:
+ s, e = offsets[i]
+ tokens.append(text[s:e])
+
+# print raw logits shape and a small slice for inspection
+print("Raw logits shape:", logits.shape)
+print("\nPer-token logits (index token -> first 6 logits):")
+for i, (t, l, a) in enumerate(zip(tokens, logits[0], attn)):
+ if not a:
+ continue
+ print(f"{i:03d} {t:>6} ->", np.around(l[:6], 3).tolist())
+
+# predictions & probs
+pred_ids = logits.argmax(-1)[0]
+pred_probs = probs[0, np.arange(probs.shape[1]), pred_ids]
+
+print("\nPer-token predictions (token \\t label \\t prob):")
+for i, (t, pid, pprob, a) in enumerate(zip(tokens, pred_ids, pred_probs, attn)):
+ if not a:
+ continue
+ lab = id2label.get(int(pid), "O")
+ print(f"{t}\t{lab}\t{pprob:.3f}")
+
+# merge BIO into entities using offsets
+entities = []
+cur = None
+for i, (pid, pprob, off, a) in enumerate(zip(pred_ids, pred_probs, offsets, attn)):
+ if not a or (off[0] == off[1] == 0):
+ if cur:
+ entities.append(cur)
+ cur = None
+ continue
+ label = id2label.get(int(pid), "O")
+ if label == "O":
+ if cur:
+ entities.append(cur)
+ cur = None
+ continue
+ if label.startswith("B-") or cur is None or label[2:] != cur["type"]:
+ if cur:
+ entities.append(cur)
+ cur = {
+ "type": label[2:],
+ "start": int(off[0]),
+ "end": int(off[1]),
+ "probs": [float(pprob)],
+ }
+ else:
+ cur["end"] = int(off[1])
+ cur["probs"].append(float(pprob))
+if cur:
+ entities.append(cur)
+
+
+# small fixes (timestamp/speaker) like main.py
+def fix_timestamp(ts):
+ if not ts:
+ return ts
+ m = re.match(r"^(\d{1,2})-(\d{2})-(\d{2})(.*)", ts)
+ if m:
+ y, mo, d, rest = m.groups()
+ if len(y) == 1:
+ y = "202" + y
+ elif len(y) == 2:
+ y = "20" + y
+ return f"{y}-{mo}-{d}{rest}"
+ return ts
+
+
+def fix_speaker(spk):
+ if not spk:
+ return spk
+ spk = re.sub(r"[^\w\s\u4e00-\u9fff]+$", "", spk)
+ if len(spk) == 1 and re.match(r"^[风雷电雨雪火水木金]", spk):
+ return spk + "某"
+ return spk
+
+
+out = {"metadata": {}, "content": []}
+for e in entities:
+ s, e_pos = e["start"], e["end"]
+ ent_text = text[s:e_pos]
+ conf = round(float(np.mean(e["probs"])), 3)
+ typ = e["type"]
+ if typ in ("timestamp", "speaker"):
+ ent_text = (
+ fix_timestamp(ent_text) if typ == "timestamp" else fix_speaker(ent_text)
+ )
+ out["metadata"][typ] = ent_text
+ else:
+ out["content"].append({"type": typ, "content": ent_text, "confidence": conf})
+
+print("\nConstructed JSON:")
+print(json.dumps(out, ensure_ascii=False, indent=2))
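The entity merging in the script above is a standard BIO decode over `(label, offset)` pairs; a self-contained toy version of the same loop, runnable without the model, makes the grouping rule easy to verify:
```python
# Toy BIO decode mirroring the merge loop in the script above (no model needed):
labels = ["O", "B-speaker", "I-speaker", "O", "B-timestamp"]
offsets = [(0, 0), (0, 1), (1, 2), (2, 3), (3, 13)]

entities, cur = [], None
for label, (start, end) in zip(labels, offsets):
    if label == "O":
        if cur:
            entities.append(cur)
        cur = None
        continue
    # A "B-" tag, a type switch, or no open span starts a new entity:
    if label.startswith("B-") or cur is None or label[2:] != cur["type"]:
        if cur:
            entities.append(cur)
        cur = {"type": label[2:], "start": start, "end": end}
    else:  # an "I-" tag continuing the open span just extends its end offset
        cur["end"] = end
if cur:
    entities.append(cur)

print(entities)
# [{'type': 'speaker', 'start': 0, 'end': 2}, {'type': 'timestamp', 'start': 3, 'end': 13}]
```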