summaryrefslogtreecommitdiffstatshomepage
path: root/src/base_model_trpgner/inference/__init__.py
blob: 3d1d7207c3e3883a3764ef07bede2bb5be177a09 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
"""
ONNX inference module.

Provides ONNX-based named-entity-recognition (NER) inference for TRPG
(tabletop role-playing game) session logs.
"""

import os
import json
import shutil
from typing import List, Dict, Any, Optional
from pathlib import Path

try:
    import numpy as np
    import onnxruntime as ort
    from transformers import AutoTokenizer
except ImportError as e:
    raise ImportError(
        "依赖未安装。请运行: pip install onnxruntime transformers numpy"
    ) from e


# GitHub repository hosting the released model artifacts.
REPO_OWNER = "HydroRoll-Team"
REPO_NAME = "base-model"
# Per-user cache directory where the downloaded model is unpacked.
USER_MODEL_DIR = Path.home() / ".cache" / "base_model_trpgner" / "models" / "trpg-final"


def get_latest_release_url() -> str:
    """
    Query the GitHub API for the latest release tag of the model repository.

    NOTE(review): despite the name, this returns the release *tag*
    (e.g. "v0.1.0"), not a download URL; the name is kept so existing
    callers keep working.

    Returns:
        The latest release tag name, or "v0.1.0" when the API is
        unreachable, returns invalid JSON, or carries no usable tag.
    """
    import urllib.request
    import urllib.error

    try:
        # Ask the GitHub REST API for the most recent release.
        api_url = f"https://api.github.com/repos/{REPO_OWNER}/{REPO_NAME}/releases/latest"
        with urllib.request.urlopen(api_url, timeout=10) as response:
            data = json.load(response)
            # `or` also covers an explicit `"tag_name": null` in the payload,
            # which .get() with a default would pass through as None.
            return data.get("tag_name") or "v0.1.0"
    except (urllib.error.URLError, json.JSONDecodeError, KeyError):
        # Best-effort: fall back to a known-good default version.
        return "v0.1.0"


def download_model_files(version: Optional[str] = None, force: bool = False) -> Path:
    """
    Download the model archive from a GitHub Release and unpack it.

    Fetches ``model.zip`` for the given release tag, extracts it into the
    user cache directory, and records the installed version in a
    ``.version`` marker file so repeated calls become no-ops.

    Args:
        version: Release tag (e.g. "v0.1.0"); None means the latest release.
        force: Re-download even if the marker says this version is present.

    Returns:
        The directory the model files were extracted into.

    Raises:
        RuntimeError: when the release asset cannot be downloaded (HTTP error).
    """
    import urllib.request
    import urllib.error
    import tempfile
    import zipfile

    if version is None:
        version = get_latest_release_url()

    model_dir = USER_MODEL_DIR
    model_dir.mkdir(parents=True, exist_ok=True)

    # Short-circuit when the marker file already records this version.
    marker_file = model_dir / ".version"
    if not force and marker_file.exists():
        with open(marker_file, "r") as f:
            current_version = f.read().strip()
        if current_version == version:
            print(f"模型已存在 (版本: {version})")
            return model_dir

    print(f"正在下载模型 {version}...")

    # Download model.zip attached to the requested release tag.
    zip_url = f"https://github.com/{REPO_OWNER}/{REPO_NAME}/releases/download/{version}/model.zip"
    try:
        # Stream the archive into a temporary file first, so a partial
        # download never clobbers the target directory; close the HTTP
        # connection before doing the (possibly slow) extraction.
        with urllib.request.urlopen(zip_url, timeout=120) as response:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_file:
                shutil.copyfileobj(response, tmp_file)
                tmp_path = tmp_file.name

        try:
            print("解压中...")
            with zipfile.ZipFile(tmp_path, 'r') as zip_ref:
                zip_ref.extractall(model_dir)
        finally:
            # Always remove the temporary archive — the previous version
            # leaked it when extraction failed on a corrupt zip.
            os.unlink(tmp_path)

        print(f"✅ 模型下载完成: {model_dir}")

    except urllib.error.HTTPError as e:
        raise RuntimeError(
            f"无法下载模型文件。请检查版本 {version} 是否存在。\n"
            f"下载地址: {zip_url}"
        ) from e

    # Record the installed version for the short-circuit check above.
    with open(marker_file, "w") as f:
        f.write(version)

    return model_dir


class TRPGParser:
    """
    TRPG log parser (ONNX-based).

    On first use, the latest model is downloaded automatically from the
    project's GitHub Releases.

    Args:
        model_path: Path to the ONNX model file; defaults to the
            auto-downloaded model.
        tokenizer_path: Path to the tokenizer config; defaults to the
            directory containing ``model_path``.
        device: Inference device, "cpu" or "cuda".
        auto_download: Whether to download the model automatically
            (default True).

    Examples:
        >>> parser = TRPGParser()
        >>> result = parser.parse("风雨 2024-06-08 21:44:59 剧烈的疼痛...")
        >>> print(result['metadata']['speaker'])
        '风雨'
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        tokenizer_path: Optional[str] = None,
        device: str = "cpu",
        auto_download: bool = True,
    ):
        # Resolve the model path (dev checkout, user cache, or download).
        if model_path is None:
            model_path = self._get_default_model_path(auto_download)

        # Tokenizer files are expected to live next to the model by default.
        if tokenizer_path is None:
            tokenizer_path = Path(model_path).parent

        self.model_path = Path(model_path)
        self.tokenizer_path = Path(tokenizer_path)
        self.device = device

        # Load the tokenizer, ONNX session, and label mapping.
        self._load_model()

    def _get_default_model_path(self, auto_download: bool) -> str:
        """Locate the default model file, downloading it when necessary.

        Search order: (1) a development checkout's ``models/`` directory,
        (2) the per-user cache, (3) auto-download from GitHub Releases.

        Raises:
            FileNotFoundError: when no model exists and auto_download is False.
        """
        # 1. Development checkout: models/ four levels above this file.
        project_root = Path(__file__).parent.parent.parent.parent
        local_model = project_root / "models" / "trpg-final" / "model.onnx"
        if local_model.exists():
            return str(local_model)

        # 2. Per-user cache directory.
        user_model = USER_MODEL_DIR / "model.onnx"
        if user_model.exists():
            return str(user_model)

        # 3. Fetch from GitHub Releases into the user cache.
        if auto_download:
            print("模型未找到,正在从 GitHub Release 下载...")
            download_model_files()
            return str(user_model)

        # 4. Nothing found and downloading disabled.
        raise FileNotFoundError(
            f"模型文件未找到。\n"
            f"请开启自动下载: TRPGParser(auto_download=True)\n"
            f"或手动下载到: {USER_MODEL_DIR}\n"
            f"下载地址: https://github.com/{REPO_OWNER}/{REPO_NAME}/releases/latest"
        )

    def _load_model(self):
        """Load the ONNX model, tokenizer, and id->label mapping."""
        # Load the tokenizer from local files only (no Hub download).
        self.tokenizer = AutoTokenizer.from_pretrained(
            str(self.tokenizer_path),
            local_files_only=True,
        )

        # Prefer CUDA only when requested AND actually available;
        # CPU always remains as a fallback provider.
        providers = ["CPUExecutionProvider"]
        if self.device == "cuda" and "CUDAExecutionProvider" in ort.get_available_providers():
            providers.insert(0, "CUDAExecutionProvider")

        self.session = ort.InferenceSession(
            str(self.model_path),
            providers=providers,
        )

        # Label mapping: read id2label from the HF config when present,
        # otherwise fall back to the default BIO label set below.
        config_path = self.tokenizer_path / "config.json"
        if config_path.exists():
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
                self.id2label = {int(k): v for k, v in config.get("id2label", {}).items()}
        else:
            # Default BIO labels for the five entity types.
            self.id2label = {
                0: "O", 1: "B-action", 2: "I-action", 3: "B-comment", 4: "I-comment",
                5: "B-dialogue", 6: "I-dialogue", 7: "B-speaker", 8: "I-speaker",
                9: "B-timestamp", 10: "I-timestamp",
            }

    def parse(self, text: str) -> Dict[str, Any]:
        """
        Parse a single TRPG log line.

        Args:
            text: The log text to parse.

        Returns:
            A dict with ``metadata`` and ``content`` keys:
            - metadata: speaker, timestamp
            - content: list of dialogue / action / comment entries

        Examples:
            >>> parser = TRPGParser()
            >>> result = parser.parse("风雨 2024-06-08 21:44:59 剧烈的疼痛...")
            >>> result['metadata']['speaker']
            '风雨'
        """
        # Tokenize; offsets map tokens back to character spans in `text`.
        # NOTE(review): inputs longer than 128 tokens are truncated.
        inputs = self.tokenizer(
            text,
            return_tensors="np",
            return_offsets_mapping=True,
            padding="max_length",
            truncation=True,
            max_length=128,
        )

        # Run ONNX inference on the single sequence.
        outputs = self.session.run(
            ["logits"],
            {
                "input_ids": inputs["input_ids"].astype(np.int64),
                "attention_mask": inputs["attention_mask"].astype(np.int64),
            },
        )

        # Post-process: per-token argmax over the label logits.
        logits = outputs[0][0]
        predictions = np.argmax(logits, axis=-1)
        offsets = inputs["offset_mapping"][0]

        # Merge BIO token predictions into character-span entities.
        entities = self._group_entities(predictions, offsets, logits)

        # Build the result dict from the entity spans.
        result = {"metadata": {}, "content": []}
        for ent in entities:
            # Skip spans that fall outside the original text (can happen
            # with padding/truncation artifacts).
            if ent["start"] >= len(text) or ent["end"] > len(text):
                continue

            raw_text = text[ent["start"]: ent["end"]]
            clean_text = self._clean_text(raw_text, ent["type"])

            if not clean_text.strip():
                continue

            # speaker/timestamp are single-valued metadata (last one wins);
            # dialogue/action/comment accumulate as content entries.
            if ent["type"] in ["timestamp", "speaker"]:
                result["metadata"][ent["type"]] = clean_text
            elif ent["type"] in ["dialogue", "action", "comment"]:
                result["content"].append({
                    "type": ent["type"],
                    "content": clean_text,
                    "confidence": round(ent["score"], 3),
                })

        return result

    def _group_entities(self, predictions, offsets, logits):
        """Aggregate token-level BIO predictions into entity spans.

        Returns a list of dicts with type, start, end (character offsets)
        and score (max logit of the entity's first token).
        """
        entities = []
        current = None

        for i in range(len(predictions)):
            start, end = offsets[i]
            # Special tokens ([CLS], [SEP], padding) have zero-width offsets.
            if start == end:  # special tokens
                continue

            pred_id = int(predictions[i])
            label = self.id2label.get(pred_id, "O")

            # "O" closes any entity in progress.
            if label == "O":
                if current:
                    entities.append(current)
                    current = None
                continue

            # Strip the "B-"/"I-" prefix to get the entity type.
            tag_type = label[2:] if len(label) > 2 else "O"

            if label.startswith("B-"):
                # B- starts a new entity (flushing any open one).
                if current:
                    entities.append(current)
                current = {
                    "type": tag_type,
                    "start": int(start),
                    "end": int(end),
                    "score": float(np.max(logits[i])),
                }
            elif label.startswith("I-") and current and current["type"] == tag_type:
                # I- of the same type extends the open entity.
                current["end"] = int(end)
            else:
                # Inconsistent I- tag (no open entity / type mismatch):
                # flush and discard rather than guess.
                if current:
                    entities.append(current)
                current = None

        # Flush a trailing open entity.
        if current:
            entities.append(current)

        return entities

    def _clean_text(self, text: str, group: str) -> str:
        """Clean an extracted entity string for the given entity type."""
        import re

        text = text.strip()

        # Strip the surrounding marker characters typical of each type:
        # parentheses for comments, quote brackets for dialogue,
        # asterisks/hashes for actions.
        if group == "comment":
            text = re.sub(r"^[((]+|[))]+$", "", text)
        elif group == "dialogue":
            text = re.sub(r'^[""''「」『』]+|[""""」』『』]+$', "", text)
        elif group == "action":
            text = re.sub(r"^[*#]+|[*#]+$", "", text)

        # Repair two-digit years in timestamps, e.g. "24-06-08" -> "2024-06-08".
        if group == "timestamp" and text and text[0].isdigit():
            if len(text) > 2 and text[2] == "-":
                text = "20" + text

        return text

    def parse_batch(self, texts: List[str]) -> List[Dict[str, Any]]:
        """
        Parse multiple log lines.

        Args:
            texts: List of log texts.

        Returns:
            List of parse results, one per input text.
        """
        return [self.parse(text) for text in texts]


# 便捷函数
def parse_line(text: str, model_path: Optional[str] = None) -> Dict[str, Any]:
    """
    Convenience wrapper: parse one log line with a throwaway parser.

    Args:
        text: The log text to parse.
        model_path: Optional path to the ONNX model.

    Returns:
        The parse-result dictionary.
    """
    return TRPGParser(model_path=model_path).parse(text)


def parse_lines(texts: List[str], model_path: Optional[str] = None) -> List[Dict[str, Any]]:
    """
    Convenience wrapper: parse many log lines with a throwaway parser.

    Args:
        texts: List of log texts.
        model_path: Optional path to the ONNX model.

    Returns:
        List of parse-result dictionaries, one per input text.
    """
    return TRPGParser(model_path=model_path).parse_batch(texts)


# Public API of this inference module.
__all__ = ["TRPGParser", "parse_line", "parse_lines", "download_model_files"]