aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/src/base_model_trpgner/inference/__init__.py
diff options
context:
space:
mode:
authorHsiangNianian <i@jyunko.cn>2025-12-30 20:16:05 +0800
committerHsiangNianian <i@jyunko.cn>2025-12-30 20:16:05 +0800
commit5dd166366b8a2f4699c1841ebd7fceabcd9868a4 (patch)
tree85d78772054529579176547c00aee9559cffff37 /src/base_model_trpgner/inference/__init__.py
parentdd55c70225367dec9e8d88821b4d65fcd24edd65 (diff)
downloadbase-model-5dd166366b8a2f4699c1841ebd7fceabcd9868a4.tar.gz
base-model-5dd166366b8a2f4699c1841ebd7fceabcd9868a4.zip
refactor: Refactor TRPG NER model SDK: restructure codebase into base_model_trpgner package, implement training and inference modules, and add model download functionality. Remove legacy training and utils modules. Enhance documentation and examples for better usability.
Diffstat (limited to 'src/base_model_trpgner/inference/__init__.py')
-rw-r--r--src/base_model_trpgner/inference/__init__.py292
1 files changed, 292 insertions, 0 deletions
diff --git a/src/base_model_trpgner/inference/__init__.py b/src/base_model_trpgner/inference/__init__.py
new file mode 100644
index 0000000..93a185f
--- /dev/null
+++ b/src/base_model_trpgner/inference/__init__.py
@@ -0,0 +1,292 @@
+"""
+ONNX 推理模块
+
+提供基于 ONNX 的 TRPG 日志命名实体识别推理功能。
+"""
+
+import os
+from typing import List, Dict, Any, Optional
+from pathlib import Path
+
+try:
+ import numpy as np
+ import onnxruntime as ort
+ from transformers import AutoTokenizer
+except ImportError as e:
+ raise ImportError(
+ "依赖未安装。请运行: pip install onnxruntime transformers numpy"
+ ) from e
+
+
+# 默认模型路径(相对于包安装位置)
+DEFAULT_MODEL_DIR = Path(__file__).parent.parent.parent.parent / "models" / "trpg-final"
+# 远程模型 URL(用于自动下载)
+MODEL_URL = "https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx"
+
+
+class TRPGParser:
+ """
+ TRPG 日志解析器(基于 ONNX)
+
+ Args:
+ model_path: ONNX 模型路径,默认使用内置模型
+ tokenizer_path: tokenizer 配置路径,默认与 model_path 相同
+ device: 推理设备,"cpu" 或 "cuda"
+
+ Examples:
+ >>> parser = TRPGParser()
+ >>> result = parser.parse("风雨 2024-06-08 21:44:59 剧烈的疼痛...")
+ >>> print(result['metadata']['speaker'])
+ '风雨'
+ """
+
+ def __init__(
+ self,
+ model_path: Optional[str] = None,
+ tokenizer_path: Optional[str] = None,
+ device: str = "cpu",
+ ):
+ # 确定模型路径
+ if model_path is None:
+ model_path = self._get_default_model_path()
+
+ if tokenizer_path is None:
+ tokenizer_path = Path(model_path).parent
+
+ self.model_path = Path(model_path)
+ self.tokenizer_path = Path(tokenizer_path)
+ self.device = device
+
+ # 加载模型
+ self._load_model()
+
+ def _get_default_model_path(self) -> str:
+ """获取默认模型路径"""
+ # 1. 尝试相对于项目根目录
+ project_root = Path(__file__).parent.parent.parent.parent
+ local_model = project_root / "models" / "trpg-final" / "model.onnx"
+ if local_model.exists():
+ return str(local_model)
+
+ # 2. 尝试用户数据目录
+ from pathlib import Path
+ user_model_dir = Path.home() / ".cache" / "basemodel" / "models" / "trpg-final"
+ user_model = user_model_dir / "model.onnx"
+ if user_model.exists():
+ return str(user_model)
+
+ # 3. 抛出错误,提示下载
+ raise FileNotFoundError(
+ f"模型文件未找到。请从 {MODEL_URL} 下载模型到 {user_model_dir}\n"
+ f"或运行: python -m basemodel.download_model"
+ )
+
+ def _load_model(self):
+ """加载 ONNX 模型和 Tokenizer"""
+ # 加载 tokenizer
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ str(self.tokenizer_path),
+ local_files_only=True,
+ )
+
+ # 加载 ONNX 模型
+ providers = ["CPUExecutionProvider"]
+ if self.device == "cuda" and "CUDAExecutionProvider" in ort.get_available_providers():
+ providers.insert(0, "CUDAExecutionProvider")
+
+ self.session = ort.InferenceSession(
+ str(self.model_path),
+ providers=providers,
+ )
+
+ # 加载标签映射
+ import json
+ config_path = self.tokenizer_path / "config.json"
+ if config_path.exists():
+ with open(config_path, "r", encoding="utf-8") as f:
+ config = json.load(f)
+ self.id2label = {int(k): v for k, v in config.get("id2label", {}).items()}
+ else:
+ # 默认标签
+ self.id2label = {
+ 0: "O", 1: "B-action", 2: "I-action", 3: "B-comment", 4: "I-comment",
+ 5: "B-dialogue", 6: "I-dialogue", 7: "B-speaker", 8: "I-speaker",
+ 9: "B-timestamp", 10: "I-timestamp",
+ }
+
+ def parse(self, text: str) -> Dict[str, Any]:
+ """
+ 解析单条 TRPG 日志
+
+ Args:
+ text: 待解析的日志文本
+
+ Returns:
+ 包含 metadata 和 content 的字典
+ - metadata: speaker, timestamp
+ - content: dialogue, action, comment 列表
+
+ Examples:
+ >>> parser = TRPGParser()
+ >>> result = parser.parse("风雨 2024-06-08 21:44:59 剧烈的疼痛...")
+ >>> result['metadata']['speaker']
+ '风雨'
+ """
+ # Tokenize
+ inputs = self.tokenizer(
+ text,
+ return_tensors="np",
+ return_offsets_mapping=True,
+ padding="max_length",
+ truncation=True,
+ max_length=128,
+ )
+
+ # 推理
+ outputs = self.session.run(
+ ["logits"],
+ {
+ "input_ids": inputs["input_ids"].astype(np.int64),
+ "attention_mask": inputs["attention_mask"].astype(np.int64),
+ },
+ )
+
+ # 后处理
+ logits = outputs[0][0]
+ predictions = np.argmax(logits, axis=-1)
+ offsets = inputs["offset_mapping"][0]
+
+ # 聚合实体
+ entities = self._group_entities(predictions, offsets, logits)
+
+ # 构建结果
+ result = {"metadata": {}, "content": []}
+ for ent in entities:
+ if ent["start"] >= len(text) or ent["end"] > len(text):
+ continue
+
+ raw_text = text[ent["start"]: ent["end"]]
+ clean_text = self._clean_text(raw_text, ent["type"])
+
+ if not clean_text.strip():
+ continue
+
+ if ent["type"] in ["timestamp", "speaker"]:
+ result["metadata"][ent["type"]] = clean_text
+ elif ent["type"] in ["dialogue", "action", "comment"]:
+ result["content"].append({
+ "type": ent["type"],
+ "content": clean_text,
+ "confidence": round(ent["score"], 3),
+ })
+
+ return result
+
+ def _group_entities(self, predictions, offsets, logits):
+ """将 token 级别的预测聚合为实体"""
+ entities = []
+ current = None
+
+ for i in range(len(predictions)):
+ start, end = offsets[i]
+ if start == end: # special tokens
+ continue
+
+ pred_id = int(predictions[i])
+ label = self.id2label.get(pred_id, "O")
+
+ if label == "O":
+ if current:
+ entities.append(current)
+ current = None
+ continue
+
+ tag_type = label[2:] if len(label) > 2 else "O"
+
+ if label.startswith("B-"):
+ if current:
+ entities.append(current)
+ current = {
+ "type": tag_type,
+ "start": int(start),
+ "end": int(end),
+ "score": float(np.max(logits[i])),
+ }
+ elif label.startswith("I-") and current and current["type"] == tag_type:
+ current["end"] = int(end)
+ else:
+ if current:
+ entities.append(current)
+ current = None
+
+ if current:
+ entities.append(current)
+
+ return entities
+
+ def _clean_text(self, text: str, group: str) -> str:
+ """清理提取的文本"""
+ import re
+
+ text = text.strip()
+
+ # 移除周围符号
+ if group == "comment":
+ text = re.sub(r"^[((]+|[))]+$", "", text)
+ elif group == "dialogue":
+ text = re.sub(r'^[""''「」『』]+|[""""」』『』]+$', "", text)
+ elif group == "action":
+ text = re.sub(r"^[*#]+|[*#]+$", "", text)
+
+ # 修复时间戳
+ if group == "timestamp" and text and text[0].isdigit():
+ if len(text) > 2 and text[2] == "-":
+ text = "20" + text
+
+ return text
+
+ def parse_batch(self, texts: List[str]) -> List[Dict[str, Any]]:
+ """
+ 批量解析多条日志
+
+ Args:
+ texts: 日志文本列表
+
+ Returns:
+ 解析结果列表
+ """
+ return [self.parse(text) for text in texts]
+
+
+# 便捷函数
+def parse_line(text: str, model_path: Optional[str] = None) -> Dict[str, Any]:
+ """
+ 解析单条日志的便捷函数
+
+ Args:
+ text: 日志文本
+ model_path: 可选的模型路径
+
+ Returns:
+ 解析结果字典
+ """
+ parser = TRPGParser(model_path=model_path)
+ return parser.parse(text)
+
+
+def parse_lines(texts: List[str], model_path: Optional[str] = None) -> List[Dict[str, Any]]:
+ """
+ 批量解析日志的便捷函数
+
+ Args:
+ texts: 日志文本列表
+ model_path: 可选的模型路径
+
+ Returns:
+ 解析结果列表
+ """
+ parser = TRPGParser(model_path=model_path)
+ return parser.parse_batch(texts)
+
+
+__all__ = ["TRPGParser", "parse_line", "parse_lines"]