"""
ONNX inference module.

Provides ONNX-based named-entity-recognition inference for TRPG logs.
"""
import os
from typing import List, Dict, Any, Optional
from pathlib import Path
try:
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer
except ImportError as e:
raise ImportError(
"依赖未安装。请运行: pip install onnxruntime transformers numpy"
) from e
# Default model directory: the "models/trpg-final" folder located four
# directory levels above this file.
DEFAULT_MODEL_DIR = Path(__file__).parent.parent.parent.parent / "models" / "trpg-final"
# Remote URL of the released ONNX model. The original note says it is meant for
# automatic download; in the code visible here it is only shown in the
# "model not found" error hint.
MODEL_URL = "https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx"
class TRPGParser:
    """
    ONNX-based parser for TRPG (tabletop RPG) log lines.

    Runs a token-classification model over a log line and groups the
    predicted B-/I- tags into ``speaker`` / ``timestamp`` metadata and
    ``dialogue`` / ``action`` / ``comment`` content spans.

    Args:
        model_path: Path to the ONNX model file. When omitted, the default
            locations are searched (see ``_get_default_model_path``).
        tokenizer_path: Directory holding the tokenizer files (and optionally
            ``config.json``). Defaults to the directory containing
            ``model_path``.
        device: Inference device, "cpu" or "cuda". CUDA is used only when
            onnxruntime reports ``CUDAExecutionProvider`` as available;
            otherwise inference runs on the CPU provider.

    Examples:
        >>> parser = TRPGParser()
        >>> result = parser.parse("风雨 2024-06-08 21:44:59 剧烈的疼痛...")
        >>> result['metadata']['speaker']
        '风雨'
    """

    # Label map used when config.json is absent or carries no "id2label".
    _DEFAULT_ID2LABEL = {
        0: "O", 1: "B-action", 2: "I-action", 3: "B-comment", 4: "I-comment",
        5: "B-dialogue", 6: "I-dialogue", 7: "B-speaker", 8: "I-speaker",
        9: "B-timestamp", 10: "I-timestamp",
    }

    # Token budget per line; longer inputs are truncated by the tokenizer.
    _MAX_LENGTH = 128

    # Wrapper characters stripped from extracted spans. Written as escapes so
    # the ASCII and full-width / curly variants stay distinguishable:
    # ( and full-width (, ) and full-width ).
    _COMMENT_OPENERS = "(\uff08"
    _COMMENT_CLOSERS = ")\uff09"
    # " ' and curly double/single quotes, plus CJK corner brackets.
    _DIALOGUE_QUOTES = "\"'\u201c\u201d\u2018\u2019\u300c\u300d\u300e\u300f"
    _ACTION_MARKERS = "*#"

    def __init__(
        self,
        model_path: Optional[str] = None,
        tokenizer_path: Optional[str] = None,
        device: str = "cpu",
    ):
        # Resolve the model location first; the tokenizer defaults to the
        # model's own directory.
        if model_path is None:
            model_path = self._get_default_model_path()
        if tokenizer_path is None:
            tokenizer_path = Path(model_path).parent

        self.model_path = Path(model_path)
        self.tokenizer_path = Path(tokenizer_path)
        self.device = device

        self._load_model()

    def _get_default_model_path(self) -> str:
        """
        Locate the default ``model.onnx``.

        Search order: the project-level ``DEFAULT_MODEL_DIR``, then
        ``~/.cache/basemodel/models/trpg-final``.

        Returns:
            The path of the first existing model file, as a string.

        Raises:
            FileNotFoundError: If neither location contains ``model.onnx``.
        """
        # NOTE: no function-local ``import Path`` here. Such an import makes
        # ``Path`` a local name for the whole function, so any use before the
        # import line raises UnboundLocalError.
        user_model_dir = Path.home() / ".cache" / "basemodel" / "models" / "trpg-final"
        candidates = (
            DEFAULT_MODEL_DIR / "model.onnx",  # 1. relative to the project root
            user_model_dir / "model.onnx",     # 2. per-user cache directory
        )
        for candidate in candidates:
            if candidate.exists():
                return str(candidate)

        # 3. Nothing found: tell the user where to get the model.
        raise FileNotFoundError(
            f"模型文件未找到。请从 {MODEL_URL} 下载模型到 {user_model_dir}\n"
            f"或运行: python -m basemodel.download_model"
        )

    @classmethod
    def _load_id2label(cls, config_path: Path) -> Dict[int, str]:
        """
        Read the id -> label mapping from a ``config.json`` file.

        Falls back to the built-in default labels when the file does not
        exist or has no (or an empty) ``id2label`` entry, so a partial config
        cannot silently turn every prediction into "O".
        """
        import json

        if config_path.exists():
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
            mapping = {int(k): v for k, v in (config.get("id2label") or {}).items()}
            if mapping:
                return mapping
        # Copy so per-instance edits never leak into the class-level default.
        return dict(cls._DEFAULT_ID2LABEL)

    def _load_model(self):
        """Load the tokenizer, the ONNX inference session and the label map."""
        # Tokenizer is loaded strictly from local files (no network access).
        self.tokenizer = AutoTokenizer.from_pretrained(
            str(self.tokenizer_path),
            local_files_only=True,
        )

        # CPU is always available as a fallback; CUDA is preferred only when
        # requested and reported as available by onnxruntime.
        providers = ["CPUExecutionProvider"]
        if self.device == "cuda" and "CUDAExecutionProvider" in ort.get_available_providers():
            providers.insert(0, "CUDAExecutionProvider")
        self.session = ort.InferenceSession(
            str(self.model_path),
            providers=providers,
        )

        self.id2label = self._load_id2label(self.tokenizer_path / "config.json")

    def parse(self, text: str) -> Dict[str, Any]:
        """
        Parse a single TRPG log line.

        Args:
            text: Log text to parse. Input is truncated to 128 tokens by the
                tokenizer call below.

        Returns:
            A dictionary with two keys:
            - ``metadata``: maps "speaker" / "timestamp" to the cleaned text
              (a later span of the same type overwrites an earlier one).
            - ``content``: list of ``{"type", "content", "confidence"}``
              dictionaries for dialogue / action / comment spans.
              ``confidence`` is the rounded maximum raw logit of the span's
              first token, not a normalized probability.

        Examples:
            >>> parser = TRPGParser()
            >>> result = parser.parse("风雨 2024-06-08 21:44:59 剧烈的疼痛...")
            >>> result['metadata']['speaker']
            '风雨'
        """
        result: Dict[str, Any] = {"metadata": {}, "content": []}
        # An empty string can never yield an in-bounds span, so skip the
        # tokenizer and model entirely.
        if not text:
            return result

        # Tokenize, keeping character offsets so spans map back onto `text`.
        inputs = self.tokenizer(
            text,
            return_tensors="np",
            return_offsets_mapping=True,
            padding="max_length",
            truncation=True,
            max_length=self._MAX_LENGTH,
        )

        # Inference.
        outputs = self.session.run(
            ["logits"],
            {
                "input_ids": inputs["input_ids"].astype(np.int64),
                "attention_mask": inputs["attention_mask"].astype(np.int64),
            },
        )

        # Post-process the first (only) sequence of the batch.
        logits = outputs[0][0]
        predictions = np.argmax(logits, axis=-1)
        offsets = inputs["offset_mapping"][0]

        text_len = len(text)
        for ent in self._group_entities(predictions, offsets, logits):
            # Discard spans whose offsets fall outside the input text.
            if ent["start"] >= text_len or ent["end"] > text_len:
                continue

            clean_text = self._clean_text(text[ent["start"]: ent["end"]], ent["type"])
            if not clean_text.strip():
                continue

            if ent["type"] in ("timestamp", "speaker"):
                result["metadata"][ent["type"]] = clean_text
            elif ent["type"] in ("dialogue", "action", "comment"):
                result["content"].append({
                    "type": ent["type"],
                    "content": clean_text,
                    "confidence": round(ent["score"], 3),
                })

        return result

    def _group_entities(self, predictions, offsets, logits):
        """
        Merge token-level predictions into entity spans.

        Args:
            predictions: Predicted label id per token.
            offsets: ``(start, end)`` character offsets per token.
            logits: Per-token logits; used only to record a score.

        Returns:
            A list of ``{"type", "start", "end", "score"}`` dictionaries in
            order of appearance. ``score`` is the maximum logit of the token
            that opened the span.
        """
        entities = []
        current = None

        for (start, end), pred, token_logits in zip(offsets, predictions, logits):
            if start == end:  # zero-width offsets mark special/padding tokens
                continue

            label = self.id2label.get(int(pred), "O")
            tag_type = label[2:] if len(label) > 2 else "O"

            # A matching I- tag extends the open span.
            if label.startswith("I-") and current and current["type"] == tag_type:
                current["end"] = int(end)
                continue

            # Every other label ("O", a new B-, an orphan or mismatched I-)
            # closes whatever span is open.
            if current:
                entities.append(current)
                current = None

            # Only a B- tag opens a new span.
            if label.startswith("B-"):
                current = {
                    "type": tag_type,
                    "start": int(start),
                    "end": int(end),
                    "score": float(np.max(token_logits)),
                }

        if current:
            entities.append(current)

        return entities

    def _clean_text(self, text: str, group: str) -> str:
        """
        Clean an extracted span.

        Strips surrounding whitespace, then removes the wrapper characters
        typical for the entity type (parentheses for comments, quotes and
        corner brackets for dialogue, ``*`` / ``#`` for actions). Timestamps
        with a two-digit year ("24-06-08 ...") are prefixed with "20".

        Args:
            text: Raw span text.
            group: Entity type ("comment", "dialogue", "action", "timestamp", ...).

        Returns:
            The cleaned text.
        """
        text = text.strip()

        # Remove wrapper characters from both ends.
        if group == "comment":
            text = text.lstrip(self._COMMENT_OPENERS).rstrip(self._COMMENT_CLOSERS)
        elif group == "dialogue":
            text = text.strip(self._DIALOGUE_QUOTES)
        elif group == "action":
            text = text.strip(self._ACTION_MARKERS)

        # Expand a two-digit year: "24-06-08" -> "2024-06-08".
        if (
            group == "timestamp"
            and len(text) > 2
            and text[0].isdigit()
            and text[2] == "-"
        ):
            text = "20" + text

        return text

    def parse_batch(self, texts: List[str]) -> List[Dict[str, Any]]:
        """
        Parse several log lines.

        Each line is parsed independently with ``parse``.

        Args:
            texts: Log lines to parse.

        Returns:
            One result dictionary per input line, in input order.
        """
        return [self.parse(text) for text in texts]
# Convenience functions
def parse_line(text: str, model_path: Optional[str] = None) -> Dict[str, Any]:
    """
    Parse one log line with a freshly constructed parser.

    A new ``TRPGParser`` is built on every call; to parse many lines, create
    a single ``TRPGParser`` and reuse it.

    Args:
        text: Log text to parse.
        model_path: Optional path to the ONNX model, forwarded to ``TRPGParser``.

    Returns:
        The dictionary returned by ``TRPGParser.parse``.
    """
    return TRPGParser(model_path=model_path).parse(text)
def parse_lines(texts: List[str], model_path: Optional[str] = None) -> List[Dict[str, Any]]:
    """
    Parse several log lines with one freshly constructed parser.

    A single ``TRPGParser`` is built per call and used for all lines.

    Args:
        texts: Log lines to parse.
        model_path: Optional path to the ONNX model, forwarded to ``TRPGParser``.

    Returns:
        The list returned by ``TRPGParser.parse_batch``.
    """
    return TRPGParser(model_path=model_path).parse_batch(texts)
# Public names exported by a star-import of this module.
__all__ = ["TRPGParser", "parse_line", "parse_lines"]
|