diff options
| author | 2025-12-30 20:16:05 +0800 | |
|---|---|---|
| committer | 2025-12-30 20:16:05 +0800 | |
| commit | 5dd166366b8a2f4699c1841ebd7fceabcd9868a4 (patch) | |
| tree | 85d78772054529579176547c00aee9559cffff37 /src/basemodel/inference | |
| parent | dd55c70225367dec9e8d88821b4d65fcd24edd65 (diff) | |
| download | base-model-5dd166366b8a2f4699c1841ebd7fceabcd9868a4.tar.gz base-model-5dd166366b8a2f4699c1841ebd7fceabcd9868a4.zip | |
refactor: Refactor TRPG NER model SDK: restructure codebase into base_model_trpgner package, implement training and inference modules, and add model download functionality. Remove legacy training and utils modules. Enhance documentation and examples for better usability.
Diffstat (limited to 'src/basemodel/inference')
| -rw-r--r-- | src/basemodel/inference/__init__.py | 292 |
1 file changed, 0 insertions, 292 deletions
"""ONNX inference module.

Provides ONNX-based named-entity-recognition (NER) inference for TRPG
(tabletop role-playing game) chat logs: given a raw log line, it extracts
the speaker, timestamp, and dialogue/action/comment segments.
"""

import os
from pathlib import Path
from typing import Any, Dict, List, Optional

try:
    import numpy as np
    import onnxruntime as ort
    from transformers import AutoTokenizer
except ImportError as e:
    raise ImportError(
        "依赖未安装。请运行: pip install onnxruntime transformers numpy"
    ) from e


# Default model location (relative to where the package is installed).
DEFAULT_MODEL_DIR = Path(__file__).parent.parent.parent.parent / "models" / "trpg-final"
# Remote model URL, surfaced in error messages so users know where to download from.
MODEL_URL = "https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx"


class TRPGParser:
    """TRPG log parser backed by an ONNX token-classification model.

    Args:
        model_path: Path to the ONNX model file. Defaults to the bundled
            model (project checkout, then the per-user cache directory).
        tokenizer_path: Path to the tokenizer configuration. Defaults to
            the directory containing ``model_path``.
        device: Inference device, ``"cpu"`` or ``"cuda"``.

    Raises:
        FileNotFoundError: If no model file can be located and none was given.

    Examples:
        >>> parser = TRPGParser()
        >>> result = parser.parse("风雨 2024-06-08 21:44:59 剧烈的疼痛...")
        >>> print(result['metadata']['speaker'])
        '风雨'
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        tokenizer_path: Optional[str] = None,
        device: str = "cpu",
    ):
        # Resolve the model path: explicit argument wins, otherwise search
        # the project checkout and the user cache (see _get_default_model_path).
        if model_path is None:
            model_path = self._get_default_model_path()

        # Tokenizer files are assumed to live next to the model unless
        # a separate path is supplied.
        if tokenizer_path is None:
            tokenizer_path = Path(model_path).parent

        self.model_path = Path(model_path)
        self.tokenizer_path = Path(tokenizer_path)
        self.device = device

        # Eagerly load tokenizer, ONNX session, and label mapping.
        self._load_model()

    def _get_default_model_path(self) -> str:
        """Return the default ONNX model path, searching known locations.

        Search order:
            1. ``<project root>/models/trpg-final/model.onnx``
            2. ``~/.cache/basemodel/models/trpg-final/model.onnx``

        Raises:
            FileNotFoundError: If neither location contains the model; the
                message tells the user where to download it from.
        """
        # 1. Try relative to the project root (development checkout layout).
        project_root = Path(__file__).parent.parent.parent.parent
        local_model = project_root / "models" / "trpg-final" / "model.onnx"
        if local_model.exists():
            return str(local_model)

        # 2. Try the per-user cache directory.
        # (Fix: dropped a redundant local `from pathlib import Path` —
        # Path is already imported at module level.)
        user_model_dir = Path.home() / ".cache" / "basemodel" / "models" / "trpg-final"
        user_model = user_model_dir / "model.onnx"
        if user_model.exists():
            return str(user_model)

        # 3. Nothing found: fail with download instructions.
        raise FileNotFoundError(
            f"模型文件未找到。请从 {MODEL_URL} 下载模型到 {user_model_dir}\n"
            f"或运行: python -m basemodel.download_model"
        )

    def _load_model(self):
        """Load the tokenizer, the ONNX inference session, and id2label map."""
        # Tokenizer is loaded strictly from local files; no hub download here.
        self.tokenizer = AutoTokenizer.from_pretrained(
            str(self.tokenizer_path),
            local_files_only=True,
        )

        # Prefer CUDA only when requested AND actually available in this build.
        providers = ["CPUExecutionProvider"]
        if self.device == "cuda" and "CUDAExecutionProvider" in ort.get_available_providers():
            providers.insert(0, "CUDAExecutionProvider")

        self.session = ort.InferenceSession(
            str(self.model_path),
            providers=providers,
        )

        # Load the id -> label mapping from the exported HF config if present.
        import json
        config_path = self.tokenizer_path / "config.json"
        if config_path.exists():
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
            # JSON keys are strings; convert back to int ids.
            self.id2label = {int(k): v for k, v in config.get("id2label", {}).items()}
        else:
            # Fallback BIO label set matching the shipped model's head.
            self.id2label = {
                0: "O", 1: "B-action", 2: "I-action", 3: "B-comment", 4: "I-comment",
                5: "B-dialogue", 6: "I-dialogue", 7: "B-speaker", 8: "I-speaker",
                9: "B-timestamp", 10: "I-timestamp",
            }

    def parse(self, text: str) -> Dict[str, Any]:
        """Parse a single TRPG log line.

        Args:
            text: The raw log line to parse.

        Returns:
            A dict with ``metadata`` and ``content``:
            - ``metadata``: ``speaker`` and/or ``timestamp`` strings.
            - ``content``: list of ``{"type", "content", "confidence"}`` items
              for dialogue, action, and comment spans.

        Examples:
            >>> parser = TRPGParser()
            >>> result = parser.parse("风雨 2024-06-08 21:44:59 剧烈的疼痛...")
            >>> result['metadata']['speaker']
            '风雨'
        """
        # Tokenize with offset mapping so predictions can be projected back
        # onto the original string. Inputs are truncated/padded to 128 tokens,
        # so text beyond ~128 tokens is silently ignored.
        inputs = self.tokenizer(
            text,
            return_tensors="np",
            return_offsets_mapping=True,
            padding="max_length",
            truncation=True,
            max_length=128,
        )

        # Run the ONNX session; the exported graph exposes a "logits" output.
        outputs = self.session.run(
            ["logits"],
            {
                "input_ids": inputs["input_ids"].astype(np.int64),
                "attention_mask": inputs["attention_mask"].astype(np.int64),
            },
        )

        # Post-process: argmax over the label dimension, batch index 0.
        logits = outputs[0][0]
        predictions = np.argmax(logits, axis=-1)
        offsets = inputs["offset_mapping"][0]

        # Merge token-level BIO tags into character-span entities.
        entities = self._group_entities(predictions, offsets, logits)

        # Build the structured result.
        result = {"metadata": {}, "content": []}
        for ent in entities:
            # Guard against offsets past the end of the original text
            # (can happen with padding/truncation artifacts).
            if ent["start"] >= len(text) or ent["end"] > len(text):
                continue

            raw_text = text[ent["start"]: ent["end"]]
            clean_text = self._clean_text(raw_text, ent["type"])

            if not clean_text.strip():
                continue

            if ent["type"] in ["timestamp", "speaker"]:
                # Single-valued metadata; a later span overwrites an earlier one.
                result["metadata"][ent["type"]] = clean_text
            elif ent["type"] in ["dialogue", "action", "comment"]:
                result["content"].append({
                    "type": ent["type"],
                    "content": clean_text,
                    "confidence": round(ent["score"], 3),
                })

        return result

    def _group_entities(self, predictions, offsets, logits):
        """Merge token-level BIO predictions into character-span entities.

        Args:
            predictions: per-token predicted label ids.
            offsets: per-token (start, end) character offsets; (0, 0) for
                special tokens.
            logits: per-token raw logits, used for the span score.

        Returns:
            List of dicts with ``type``, ``start``, ``end``, ``score``.
        """
        entities = []
        current = None

        for i in range(len(predictions)):
            start, end = offsets[i]
            # Special tokens ([CLS]/[SEP]/padding) have zero-width offsets.
            if start == end:
                continue

            pred_id = int(predictions[i])
            label = self.id2label.get(pred_id, "O")

            if label == "O":
                # Outside any entity: close the open span, if any.
                if current:
                    entities.append(current)
                    current = None
                continue

            # Strip the "B-"/"I-" prefix to get the entity type.
            tag_type = label[2:] if len(label) > 2 else "O"

            if label.startswith("B-"):
                # Begin a new entity (closing any open one first).
                if current:
                    entities.append(current)
                current = {
                    "type": tag_type,
                    "start": int(start),
                    "end": int(end),
                    # NOTE(review): "score" is the raw max logit of the
                    # B- token, not a softmax probability — values are not
                    # bounded to [0, 1]. Confirm before presenting as a
                    # calibrated confidence.
                    "score": float(np.max(logits[i])),
                }
            elif label.startswith("I-") and current and current["type"] == tag_type:
                # Continuation of the open entity: extend its end offset.
                current["end"] = int(end)
            else:
                # Orphan I- tag (no matching open entity): close and reset.
                if current:
                    entities.append(current)
                current = None

        # Flush the trailing open entity, if any.
        if current:
            entities.append(current)

        return entities

    def _clean_text(self, text: str, group: str) -> str:
        """Strip decorative punctuation from an extracted span.

        Args:
            text: the raw extracted substring.
            group: the entity type ("comment", "dialogue", "action",
                "timestamp", ...), which determines what is stripped.
        """
        import re

        text = text.strip()

        # Strip surrounding markers by entity type.
        # NOTE(review): the character classes below were reconstructed from a
        # mangled source dump; the intent is ASCII plus full-width/CJK
        # equivalents — verify against the original file.
        if group == "comment":
            # Parentheses (ASCII and full-width).
            text = re.sub(r"^[(（]+|[)）]+$", "", text)
        elif group == "dialogue":
            # Curly quotes and CJK corner brackets.
            text = re.sub(r"^[“”‘’「」『』]+|[“”‘’「」『』]+$", "", text)
        elif group == "action":
            # Markdown-ish emphasis markers.
            text = re.sub(r"^[*#]+|[*#]+$", "", text)

        # Repair two-digit years: "24-06-08..." -> "2024-06-08...".
        if group == "timestamp" and text and text[0].isdigit():
            if len(text) > 2 and text[2] == "-":
                text = "20" + text

        return text

    def parse_batch(self, texts: List[str]) -> List[Dict[str, Any]]:
        """Parse multiple log lines.

        Args:
            texts: list of raw log lines.

        Returns:
            List of per-line parse results (see :meth:`parse`).
        """
        return [self.parse(text) for text in texts]


# Convenience functions
def parse_line(text: str, model_path: Optional[str] = None) -> Dict[str, Any]:
    """Parse a single log line with a throwaway parser.

    Note: constructs a new :class:`TRPGParser` (model load included) on every
    call; reuse a ``TRPGParser`` instance for repeated parsing.

    Args:
        text: the raw log line.
        model_path: optional model path override.

    Returns:
        The parse result dict (see :meth:`TRPGParser.parse`).
    """
    parser = TRPGParser(model_path=model_path)
    return parser.parse(text)


def parse_lines(texts: List[str], model_path: Optional[str] = None) -> List[Dict[str, Any]]:
    """Parse multiple log lines with a single shared parser instance.

    Args:
        texts: list of raw log lines.
        model_path: optional model path override.

    Returns:
        List of parse result dicts.
    """
    parser = TRPGParser(model_path=model_path)
    return parser.parse_batch(texts)


__all__ = ["TRPGParser", "parse_line", "parse_lines"]
