diff options
| author | 2025-12-30 20:39:34 +0800 | |
|---|---|---|
| committer | 2025-12-30 20:39:34 +0800 | |
| commit | 298035052b3e3d083b57f5dbac0e86de4f94efba (patch) | |
| tree | 944f38d734f752a5a0f71033ebece38fc5c35839 | |
| parent | 92a647ffbb3452a0ed49601177f290e20a88413e (diff) | |
| download | base-model-298035052b3e3d083b57f5dbac0e86de4f94efba.tar.gz base-model-298035052b3e3d083b57f5dbac0e86de4f94efba.zip | |
refactor: Update model download functionality and improve inference module to support automatic model retrieval from GitHub releases
| -rw-r--r-- | .github/workflows/publish.yml | 4 | ||||
| -rw-r--r-- | pyproject.toml | 17 | ||||
| -rw-r--r-- | src/base_model_trpgner/__init__.py | 2 | ||||
| -rw-r--r-- | src/base_model_trpgner/download_model.py | 70 | ||||
| -rw-r--r-- | src/base_model_trpgner/inference/__init__.py | 137 | ||||
| -rw-r--r-- | src/base_model_trpgner/training/__init__.py | 4 | ||||
| -rw-r--r-- | tests/test_onnx_only_infer.py | 198 |
7 files changed, 324 insertions, 108 deletions
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 01c6060..c6b5fcf 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -4,8 +4,6 @@ on: push: tags: - 'v*.*.*' - branches: - - main workflow_dispatch: inputs: create_test: @@ -201,7 +199,7 @@ jobs: ## 🚀 快速开始 ```python - from basemodeltrpgner import TRPGParser + from base_model_trpgner import TRPGParser parser = TRPGParser() result = parser.parse("风雨 2024-06-08 21:44:59 剧烈的疼痛...") diff --git a/pyproject.toml b/pyproject.toml index b92fae3..f9e8216 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,22 +56,11 @@ Issues = "https://github.com/HydroRoll-Team/base-model/issues" [tool.hatch.build.targets.wheel] packages = ["src/base_model_trpgner"] -artifacts = [ - "src/base_model_trpgner/**/*.py", - "models/trpg-final/model.onnx", - "models/trpg-final/model.onnx.data", - "models/trpg-final/config.json", - "models/trpg-final/tokenizer.json", - "models/trpg-final/tokenizer_config.json", - "models/trpg-final/special_tokens_map.json", - "models/trpg-final/vocab.txt", -] - -[tool.hatch.build.targets.wheel.shared-data] -"models/trpg-final" = "base_model_trpgner/models/trpg-final" +# 仅包含 Python 源码,模型文件首次运行时自动从 GitHub Release 下载 +artifacts = ["src/base_model_trpgner/**/*.py"] [tool.hatch.build.targets.sdist] -include = ["/src", "/models/trpg-final", "/README.md", "/COPYING"] +include = ["/src", "/README.md", "/COPYING"] [tool.ruff] exclude = [ diff --git a/src/base_model_trpgner/__init__.py b/src/base_model_trpgner/__init__.py index 9796c83..8549a69 100644 --- a/src/base_model_trpgner/__init__.py +++ b/src/base_model_trpgner/__init__.py @@ -22,7 +22,7 @@ try: from importlib.metadata import version __version__ = version("base_model_trpgner") except Exception: - __version__ = "0.1.1.dev" + __version__ = "0.1.2.dev" __all__ = [ "__version__", diff --git a/src/base_model_trpgner/download_model.py b/src/base_model_trpgner/download_model.py deleted file mode 100644 index 2d65099..0000000 --- a/src/base_model_trpgner/download_model.py +++ /dev/null @@ -1,70 +0,0 @@ -#!/usr/bin/env python3 -""" -模型下载脚本 - -自动下载预训练的 ONNX 模型到用户缓存目录。 -""" - -import os -import sys -from pathlib import Path -import urllib.request - - -def download_model( - output_dir: str = None, - url: str = "https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx" -): - """ - 下载 ONNX 模型 - - Args: - output_dir: 输出目录,默认为 ~/.cache/basemodel/models/trpg-final/ - url: 模型下载 URL - """ - if output_dir is None: - output_dir = Path.home() / ".cache" / "basemodel" / "models" / "trpg-final" - else: - output_dir = Path(output_dir) - - output_dir.mkdir(parents=True, exist_ok=True) - output_path = output_dir / "model.onnx" - - if output_path.exists(): - print(f"✅ 模型已存在: {output_path}") - return str(output_path) - - print(f"📥 正在下载模型到 {output_path}...") - print(f" URL: {url}") - - try: - urllib.request.urlretrieve(url, output_path) - print(f"✅ 模型下载成功!") - return str(output_path) - except Exception as e: - print(f"❌ 下载失败: {e}") - print(f" 请手动从以下地址下载模型:") - print(f" {url}") - sys.exit(1) - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser(description="下载 base-model ONNX 模型") - parser.add_argument( - "--output-dir", - type=str, - default=None, - help="模型输出目录(默认: ~/.cache/basemodel/models/trpg-final/)" - ) - parser.add_argument( - "--url", - type=str, - default="https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx", - help="模型下载 URL" - ) - - args = parser.parse_args() - - download_model(args.output_dir, args.url) diff --git a/src/base_model_trpgner/inference/__init__.py b/src/base_model_trpgner/inference/__init__.py index 93a185f..d70cb23 100644 --- a/src/base_model_trpgner/inference/__init__.py +++ b/src/base_model_trpgner/inference/__init__.py @@ -5,6 +5,8 @@ ONNX 推理模块 """ import os +import json +import shutil from typing import List, Dict, Any, Optional from pathlib import Path @@ -18,20 +20,113 @@ except ImportError as e: ) from e -# 默认模型路径(相对于包安装位置) -DEFAULT_MODEL_DIR = Path(__file__).parent.parent.parent.parent / "models" / "trpg-final" -# 远程模型 URL(用于自动下载) -MODEL_URL = "https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx" +# GitHub 仓库信息 +REPO_OWNER = "HydroRoll-Team" +REPO_NAME = "base-model" +# 用户数据目录 +USER_MODEL_DIR = Path.home() / ".cache" / "base_model_trpgner" / "models" / "trpg-final" + + +def get_latest_release_url() -> str: + """ + 获取 GitHub 最新 Release 的下载 URL + + Returns: + 最新 Release 的标签名(如 v0.1.0) + """ + import urllib.request + import urllib.error + + try: + # 使用 GitHub API 获取最新 release + api_url = f"https://api.github.com/repos/{REPO_OWNER}/{REPO_NAME}/releases/latest" + with urllib.request.urlopen(api_url, timeout=10) as response: + data = json.load(response) + return data.get("tag_name", "v0.1.0") + except (urllib.error.URLError, json.JSONDecodeError, KeyError): + # 失败时返回默认版本 + return "v0.1.0" + + +def download_model_files(version: Optional[str] = None, force: bool = False) -> Path: + """ + 从 GitHub Release 下载模型文件 + + Args: + version: Release 版本(如 v0.1.0),None 表示最新版本 + force: 是否强制重新下载(即使文件已存在) + + Returns: + 模型文件保存目录 + """ + import urllib.request + import urllib.error + + if version is None: + version = get_latest_release_url() + + model_dir = USER_MODEL_DIR + model_dir.mkdir(parents=True, exist_ok=True) + + # 检查是否已下载 + marker_file = model_dir / ".version" + if not force and marker_file.exists(): + with open(marker_file, "r") as f: + current_version = f.read().strip() + if current_version == version: + print(f"模型已存在 (版本: {version})") + return model_dir + + print(f"正在下载模型 {version}...") + + # 需要下载的文件 + base_url = f"https://github.com/{REPO_OWNER}/{REPO_NAME}/releases/download/{version}" + files_to_download = [ + "model.onnx", + "model.onnx.data", + "config.json", + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", + "vocab.txt", + ] + + for filename in files_to_download: + url = f"{base_url}/{filename}" + dest_path = model_dir / filename + + try: + print(f"下载 {filename}...") + with urllib.request.urlopen(url, timeout=60) as response: + with open(dest_path, "wb") as f: + shutil.copyfileobj(response, f) + except urllib.error.HTTPError as e: + print(f"下载 {filename} 失败: {e}") + # model.onnx.data 不是必需的(某些小模型可能没有) + if filename == "model.onnx.data": + continue + else: + raise RuntimeError(f"无法下载模型文件: {filename}") from e + + # 写入版本标记 + with open(marker_file, "w") as f: + f.write(version) + + print(f"✅ 模型下载完成: {model_dir}") + return model_dir class TRPGParser: """ TRPG 日志解析器(基于 ONNX) + 首次运行时会自动从 GitHub Release 下载最新模型。 + Args: - model_path: ONNX 模型路径,默认使用内置模型 + model_path: ONNX 模型路径,默认使用自动下载的模型 tokenizer_path: tokenizer 配置路径,默认与 model_path 相同 device: 推理设备,"cpu" 或 "cuda" + auto_download: 是否自动下载模型(默认 True) Examples: >>> parser = TRPGParser() @@ -45,10 +140,11 @@ class TRPGParser: model_path: Optional[str] = None, tokenizer_path: Optional[str] = None, device: str = "cpu", + auto_download: bool = True, ): # 确定模型路径 if model_path is None: - model_path = self._get_default_model_path() + model_path = self._get_default_model_path(auto_download) if tokenizer_path is None: tokenizer_path = Path(model_path).parent @@ -60,25 +156,31 @@ class TRPGParser: # 加载模型 self._load_model() - def _get_default_model_path(self) -> str: - """获取默认模型路径""" - # 1. 尝试相对于项目根目录 + def _get_default_model_path(self, auto_download: bool) -> str: + """获取默认模型路径,必要时自动下载""" + # 1. 检查本地开发环境 project_root = Path(__file__).parent.parent.parent.parent local_model = project_root / "models" / "trpg-final" / "model.onnx" if local_model.exists(): return str(local_model) - # 2. 尝试用户数据目录 - from pathlib import Path - user_model_dir = Path.home() / ".cache" / "basemodel" / "models" / "trpg-final" - user_model = user_model_dir / "model.onnx" + # 2. 检查用户缓存目录 + user_model = USER_MODEL_DIR / "model.onnx" if user_model.exists(): return str(user_model) - # 3. 抛出错误,提示下载 + # 3. 自动下载 + if auto_download: + print("模型未找到,正在从 GitHub Release 下载...") + download_model_files() + return str(user_model) + + # 4. 抛出错误 raise FileNotFoundError( - f"模型文件未找到。请从 {MODEL_URL} 下载模型到 {user_model_dir}\n" - f"或运行: python -m basemodel.download_model" + f"模型文件未找到。\n" + f"请开启自动下载: TRPGParser(auto_download=True)\n" + f"或手动下载到: {USER_MODEL_DIR}\n" + f"下载地址: https://github.com/{REPO_OWNER}/{REPO_NAME}/releases/latest" ) def _load_model(self): @@ -100,7 +202,6 @@ class TRPGParser: ) # 加载标签映射 - import json config_path = self.tokenizer_path / "config.json" if config_path.exists(): with open(config_path, "r", encoding="utf-8") as f: @@ -289,4 +390,4 @@ def parse_lines(texts: List[str], model_path: Optional[str] = None) -> List[Dict return parser.parse_batch(texts) -__all__ = ["TRPGParser", "parse_line", "parse_lines"] +__all__ = ["TRPGParser", "parse_line", "parse_lines", "download_model_files"] diff --git a/src/base_model_trpgner/training/__init__.py b/src/base_model_trpgner/training/__init__.py index ccf3c03..4f8e30d 100644 --- a/src/base_model_trpgner/training/__init__.py +++ b/src/base_model_trpgner/training/__init__.py @@ -36,11 +36,11 @@ def train_ner_model( resume_from_checkpoint: 恢复检查点路径 Examples: - >>> from basemodeltrpgner.training import train_ner_model + >>> from base_model_trpgner.training import train_ner_model >>> train_ner_model( ... conll_data="./data", ... output_dir="./my_model", - ... epochs=10 + ... num_train_epochs=10 ... ) """ try: diff --git a/tests/test_onnx_only_infer.py b/tests/test_onnx_only_infer.py new file mode 100644 index 0000000..69c72be --- /dev/null +++ b/tests/test_onnx_only_infer.py @@ -0,0 +1,198 @@ +""" +Minimal ONNX-only inference using only: + - models/trpg-final/model.onnx + - models/trpg-final/config.json + +NOTE: 使用自制字符级 tokenizer(非训练时 tokenizer),结果可能与原模型输出不一致, +但可在没有 tokenizer 文件时完成端到端推理演示。 +""" + +import os, sys, json, re +import numpy as np +import onnxruntime as ort + +MODEL_DIR = "models/trpg-final" +ONNX_PATH = os.path.join(MODEL_DIR, "model.onnx") +CFG_PATH = os.path.join(MODEL_DIR, "config.json") +MAX_LEN = 128 + +# load id2label & vocab_size +with open(CFG_PATH, "r", encoding="utf-8") as f: + cfg = json.load(f) +id2label = {int(k): v for k, v in cfg.get("id2label", {}).items()} +vocab_size = int(cfg.get("vocab_size", 30000)) +pad_id = int(cfg.get("pad_token_id", 0)) + +# simple char-level tokenizer (adds [CLS]=101, [SEP]=102, pads with pad_id) +CLS_ID = 101 +SEP_ID = 102 + + +def char_tokenize(text, max_length=MAX_LEN): + chars = list(text) + # reserve 2 for CLS and SEP + max_chars = max_length - 2 + chars = chars[:max_chars] + ids = [CLS_ID] + [100 + (ord(c) % (vocab_size - 200)) for c in chars] + [SEP_ID] + attn = [1] * len(ids) + # pad + pad_len = max_length - len(ids) + ids += [pad_id] * pad_len + attn += [0] * pad_len + # offsets: for CLS/SEP/pad use (0,0); for char tokens map to character positions + offsets = [(0, 0)] + pos = 0 + for c in chars: + offsets.append((pos, pos + 1)) + pos += 1 + offsets.append((0, 0)) # SEP + offsets += [(0, 0)] * pad_len + return { + "input_ids": np.array([ids], dtype=np.int64), + "attention_mask": np.array([attn], dtype=np.int64), + "offset_mapping": np.array([offsets], dtype=np.int64), + "text": text, + } + + +# onnx runtime session +providers = [ + p + for p in ("CUDAExecutionProvider", "CPUExecutionProvider") + if p in ort.get_available_providers() +] +so = ort.SessionOptions() +so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL +sess = ort.InferenceSession(ONNX_PATH, sess_options=so, providers=providers) + + +def softmax(x): + x = x - x.max(axis=-1, keepdims=True) + e = np.exp(x) + return e / e.sum(axis=-1, keepdims=True) + + +text = sys.argv[1] if len(sys.argv) > 1 else "风雨 2024-06-08 21:44:59 剧烈的疼痛..." +inp = char_tokenize(text, MAX_LEN) + +# build feed dict matching session inputs +feed = {} +for s_in in sess.get_inputs(): + name = s_in.name + if name in inp: + feed[name] = inp[name] + +outs = sess.run(None, feed) +logits = np.asarray(outs[0]) # (batch, seq_len, num_labels) +probs = softmax(logits) + +ids = inp["input_ids"][0] +offsets = inp["offset_mapping"][0] +attn = inp["attention_mask"][0] + +# reconstruct token strings (CLS, each char, SEP) +tokens = [] +for i, idv in enumerate(ids): + if i == 0: + tokens.append("[CLS]") + else: + if offsets[i][0] == 0 and offsets[i][1] == 0: + # SEP or pad + if attn[i] == 1: + tokens.append("[SEP]") + else: + tokens.append("[PAD]") + else: + s, e = offsets[i] + tokens.append(text[s:e]) + +# print raw logits shape and a small slice for inspection +print("Raw logits shape:", logits.shape) +print("\nPer-token logits (index token -> first 6 logits):") +for i, (t, l, a) in enumerate(zip(tokens, logits[0], attn)): + if not a: + continue + print(f"{i:03d} {t:>6} ->", np.around(l[:6], 3).tolist()) + +# predictions & probs +pred_ids = logits.argmax(-1)[0] +pred_probs = probs[0, np.arange(probs.shape[1]), pred_ids] + +print("\nPer-token predictions (token \\t label \\t prob):") +for i, (t, pid, pprob, a) in enumerate(zip(tokens, pred_ids, pred_probs, attn)): + if not a: + continue + lab = id2label.get(int(pid), "O") + print(f"{t}\t{lab}\t{pprob:.3f}") + +# merge BIO into entities using offsets +entities = [] +cur = None +for i, (pid, pprob, off, a) in enumerate(zip(pred_ids, pred_probs, offsets, attn)): + if not a or (off[0] == off[1] == 0): + if cur: + entities.append(cur) + cur = None + continue + label = id2label.get(int(pid), "O") + if label == "O": + if cur: + entities.append(cur) + cur = None + continue + if label.startswith("B-") or cur is None or label[2:] != cur["type"]: + if cur: + entities.append(cur) + cur = { + "type": label[2:], + "start": int(off[0]), + "end": int(off[1]), + "probs": [float(pprob)], + } + else: + cur["end"] = int(off[1]) + cur["probs"].append(float(pprob)) +if cur: + entities.append(cur) + + +# small fixes (timestamp/speaker) like main.py +def fix_timestamp(ts): + if not ts: + return ts + m = re.match(r"^(\d{1,2})-(\d{2})-(\d{2})(.*)", ts) + if m: + y, mo, d, rest = m.groups() + if len(y) == 1: + y = "202" + y + elif len(y) == 2: + y = "20" + y + return f"{y}-{mo}-{d}{rest}" + return ts + + +def fix_speaker(spk): + if not spk: + return spk + spk = re.sub(r"[^\w\s\u4e00-\u9fff]+$", "", spk) + if len(spk) == 1 and re.match(r"^[风雷电雨雪火水木金]", spk): + return spk + "某" + return spk + + +out = {"metadata": {}, "content": []} +for e in entities: + s, e_pos = e["start"], e["end"] + ent_text = text[s:e_pos] + conf = round(float(np.mean(e["probs"])), 3) + typ = e["type"] + if typ in ("timestamp", "speaker"): + ent_text = ( + fix_timestamp(ent_text) if typ == "timestamp" else fix_speaker(ent_text) + ) + out["metadata"][typ] = ent_text + else: + out["content"].append({"type": typ, "content": ent_text, "confidence": conf}) + +print("\nConstructed JSON:") +print(json.dumps(out, ensure_ascii=False, indent=2)) |
