From 5dd166366b8a2f4699c1841ebd7fceabcd9868a4 Mon Sep 17 00:00:00 2001 From: HsiangNianian Date: Tue, 30 Dec 2025 20:16:05 +0800 Subject: refactor: Refactor TRPG NER model SDK: restructure codebase into base_model_trpgner package, implement training and inference modules, and add model download functionality. Remove legacy training and utils modules. Enhance documentation and examples for better usability. --- src/base_model_trpgner/__init__.py | 36 ++++ src/base_model_trpgner/download_model.py | 70 +++++++ src/base_model_trpgner/inference/__init__.py | 292 +++++++++++++++++++++++++++ src/base_model_trpgner/training/__init__.py | 205 +++++++++++++++++++ src/base_model_trpgner/utils/__init__.py | 192 ++++++++++++++++++ src/basemodel/__init__.py | 36 ---- src/basemodel/download_model.py | 70 ------- src/basemodel/inference/__init__.py | 292 --------------------------- src/basemodel/training/__init__.py | 205 ------------------- src/basemodel/utils/__init__.py | 192 ------------------ 10 files changed, 795 insertions(+), 795 deletions(-) create mode 100644 src/base_model_trpgner/__init__.py create mode 100644 src/base_model_trpgner/download_model.py create mode 100644 src/base_model_trpgner/inference/__init__.py create mode 100644 src/base_model_trpgner/training/__init__.py create mode 100644 src/base_model_trpgner/utils/__init__.py delete mode 100644 src/basemodel/__init__.py delete mode 100644 src/basemodel/download_model.py delete mode 100644 src/basemodel/inference/__init__.py delete mode 100644 src/basemodel/training/__init__.py delete mode 100644 src/basemodel/utils/__init__.py (limited to 'src') diff --git a/src/base_model_trpgner/__init__.py b/src/base_model_trpgner/__init__.py new file mode 100644 index 0000000..9796c83 --- /dev/null +++ b/src/base_model_trpgner/__init__.py @@ -0,0 +1,36 @@ +""" +base-model-trpgner - HydroRoll TRPG NER 模型 SDK + +这是一个用于 TRPG(桌上角色扮演游戏)日志命名实体识别的 Python SDK。 + +基本用法: + >>> from base_model_trpgner import TRPGParser + >>> parser = TRPGParser() + >>> result = parser.parse("风雨 2024-06-08 21:44:59 剧烈的疼痛...") + >>> print(result) + {'metadata': {'speaker': '风雨', 'timestamp': '2024-06-08 21:44:59'}, 'content': [...]} + +训练功能(需要额外安装): + >>> pip install base-model-trpgner[train] + >>> from base_model_trpgner.training import train_ner_model + >>> train_ner_model(conll_data="./data", output_dir="./model") +""" + +from base_model_trpgner.inference import TRPGParser, parse_line, parse_lines + +try: + from importlib.metadata import version + __version__ = version("base_model_trpgner") +except Exception: + __version__ = "0.1.1.dev" + +__all__ = [ + "__version__", + "TRPGParser", + "parse_line", + "parse_lines", +] + + +def get_version(): + return __version__ diff --git a/src/base_model_trpgner/download_model.py b/src/base_model_trpgner/download_model.py new file mode 100644 index 0000000..2d65099 --- /dev/null +++ b/src/base_model_trpgner/download_model.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +""" +模型下载脚本 + +自动下载预训练的 ONNX 模型到用户缓存目录。 +""" + +import os +import sys +from pathlib import Path +import urllib.request + + +def download_model( + output_dir: str = None, + url: str = "https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx" +): + """ + 下载 ONNX 模型 + + Args: + output_dir: 输出目录,默认为 ~/.cache/basemodel/models/trpg-final/ + url: 模型下载 URL + """ + if output_dir is None: + output_dir = Path.home() / ".cache" / "basemodel" / "models" / "trpg-final" + else: + output_dir = Path(output_dir) + + output_dir.mkdir(parents=True, exist_ok=True) + output_path = output_dir / "model.onnx" + + if output_path.exists(): + print(f"✅ 模型已存在: {output_path}") + return str(output_path) + + print(f"📥 正在下载模型到 {output_path}...") + print(f" URL: {url}") + + try: + urllib.request.urlretrieve(url, output_path) + print(f"✅ 模型下载成功!") + return str(output_path) + except Exception as e: + print(f"❌ 下载失败: {e}") + print(f" 请手动从以下地址下载模型:") + print(f" {url}") + sys.exit(1) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="下载 base-model ONNX 模型") + parser.add_argument( + "--output-dir", + type=str, + default=None, + help="模型输出目录(默认: ~/.cache/basemodel/models/trpg-final/)" + ) + parser.add_argument( + "--url", + type=str, + default="https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx", + help="模型下载 URL" + ) + + args = parser.parse_args() + + download_model(args.output_dir, args.url) diff --git a/src/base_model_trpgner/inference/__init__.py b/src/base_model_trpgner/inference/__init__.py new file mode 100644 index 0000000..93a185f --- /dev/null +++ b/src/base_model_trpgner/inference/__init__.py @@ -0,0 +1,292 @@ +""" +ONNX 推理模块 + +提供基于 ONNX 的 TRPG 日志命名实体识别推理功能。 +""" + +import os +from typing import List, Dict, Any, Optional +from pathlib import Path + +try: + import numpy as np + import onnxruntime as ort + from transformers import AutoTokenizer +except ImportError as e: + raise ImportError( + "依赖未安装。请运行: pip install onnxruntime transformers numpy" + ) from e + + +# 默认模型路径(相对于包安装位置) +DEFAULT_MODEL_DIR = Path(__file__).parent.parent.parent.parent / "models" / "trpg-final" +# 远程模型 URL(用于自动下载) +MODEL_URL = "https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx" + + +class TRPGParser: + """ + TRPG 日志解析器(基于 ONNX) + + Args: + model_path: ONNX 模型路径,默认使用内置模型 + tokenizer_path: tokenizer 配置路径,默认与 model_path 相同 + device: 推理设备,"cpu" 或 "cuda" + + Examples: + >>> parser = TRPGParser() + >>> result = parser.parse("风雨 2024-06-08 21:44:59 剧烈的疼痛...") + >>> print(result['metadata']['speaker']) + '风雨' + """ + + def __init__( + self, + model_path: Optional[str] = None, + tokenizer_path: Optional[str] = None, + device: str = "cpu", + ): + # 确定模型路径 + if model_path is None: + model_path = self._get_default_model_path() + + if tokenizer_path is None: + tokenizer_path = Path(model_path).parent + + self.model_path = Path(model_path) + self.tokenizer_path = Path(tokenizer_path) + self.device = device + + # 加载模型 + self._load_model() + + def _get_default_model_path(self) -> str: + """获取默认模型路径""" + # 1. 尝试相对于项目根目录 + project_root = Path(__file__).parent.parent.parent.parent + local_model = project_root / "models" / "trpg-final" / "model.onnx" + if local_model.exists(): + return str(local_model) + + # 2. 尝试用户数据目录 + from pathlib import Path + user_model_dir = Path.home() / ".cache" / "basemodel" / "models" / "trpg-final" + user_model = user_model_dir / "model.onnx" + if user_model.exists(): + return str(user_model) + + # 3. 抛出错误,提示下载 + raise FileNotFoundError( + f"模型文件未找到。请从 {MODEL_URL} 下载模型到 {user_model_dir}\n" + f"或运行: python -m basemodel.download_model" + ) + + def _load_model(self): + """加载 ONNX 模型和 Tokenizer""" + # 加载 tokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + str(self.tokenizer_path), + local_files_only=True, + ) + + # 加载 ONNX 模型 + providers = ["CPUExecutionProvider"] + if self.device == "cuda" and "CUDAExecutionProvider" in ort.get_available_providers(): + providers.insert(0, "CUDAExecutionProvider") + + self.session = ort.InferenceSession( + str(self.model_path), + providers=providers, + ) + + # 加载标签映射 + import json + config_path = self.tokenizer_path / "config.json" + if config_path.exists(): + with open(config_path, "r", encoding="utf-8") as f: + config = json.load(f) + self.id2label = {int(k): v for k, v in config.get("id2label", {}).items()} + else: + # 默认标签 + self.id2label = { + 0: "O", 1: "B-action", 2: "I-action", 3: "B-comment", 4: "I-comment", + 5: "B-dialogue", 6: "I-dialogue", 7: "B-speaker", 8: "I-speaker", + 9: "B-timestamp", 10: "I-timestamp", + } + + def parse(self, text: str) -> Dict[str, Any]: + """ + 解析单条 TRPG 日志 + + Args: + text: 待解析的日志文本 + + Returns: + 包含 metadata 和 content 的字典 + - metadata: speaker, timestamp + - content: dialogue, action, comment 列表 + + Examples: + >>> parser = TRPGParser() + >>> result = parser.parse("风雨 2024-06-08 21:44:59 剧烈的疼痛...") + >>> result['metadata']['speaker'] + '风雨' + """ + # Tokenize + inputs = self.tokenizer( + text, + return_tensors="np", + return_offsets_mapping=True, + padding="max_length", + truncation=True, + max_length=128, + ) + + # 推理 + outputs = self.session.run( + ["logits"], + { + "input_ids": inputs["input_ids"].astype(np.int64), + "attention_mask": inputs["attention_mask"].astype(np.int64), + }, + ) + + # 后处理 + logits = outputs[0][0] + predictions = np.argmax(logits, axis=-1) + offsets = inputs["offset_mapping"][0] + + # 聚合实体 + entities = self._group_entities(predictions, offsets, logits) + + # 构建结果 + result = {"metadata": {}, "content": []} + for ent in entities: + if ent["start"] >= len(text) or ent["end"] > len(text): + continue + + raw_text = text[ent["start"]: ent["end"]] + clean_text = self._clean_text(raw_text, ent["type"]) + + if not clean_text.strip(): + continue + + if ent["type"] in ["timestamp", "speaker"]: + result["metadata"][ent["type"]] = clean_text + elif ent["type"] in ["dialogue", "action", "comment"]: + result["content"].append({ + "type": ent["type"], + "content": clean_text, + "confidence": round(ent["score"], 3), + }) + + return result + + def _group_entities(self, predictions, offsets, logits): + """将 token 级别的预测聚合为实体""" + entities = [] + current = None + + for i in range(len(predictions)): + start, end = offsets[i] + if start == end: # special tokens + continue + + pred_id = int(predictions[i]) + label = self.id2label.get(pred_id, "O") + + if label == "O": + if current: + entities.append(current) + current = None + continue + + tag_type = label[2:] if len(label) > 2 else "O" + + if label.startswith("B-"): + if current: + entities.append(current) + current = { + "type": tag_type, + "start": int(start), + "end": int(end), + "score": float(np.max(logits[i])), + } + elif label.startswith("I-") and current and current["type"] == tag_type: + current["end"] = int(end) + else: + if current: + entities.append(current) + current = None + + if current: + entities.append(current) + + return entities + + def _clean_text(self, text: str, group: str) -> str: + """清理提取的文本""" + import re + + text = text.strip() + + # 移除周围符号 + if group == "comment": + text = re.sub(r"^[((]+|[))]+$", "", text) + elif group == "dialogue": + text = re.sub(r'^[""''「」『』]+|[""""」』『』]+$', "", text) + elif group == "action": + text = re.sub(r"^[*#]+|[*#]+$", "", text) + + # 修复时间戳 + if group == "timestamp" and text and text[0].isdigit(): + if len(text) > 2 and text[2] == "-": + text = "20" + text + + return text + + def parse_batch(self, texts: List[str]) -> List[Dict[str, Any]]: + """ + 批量解析多条日志 + + Args: + texts: 日志文本列表 + + Returns: + 解析结果列表 + """ + return [self.parse(text) for text in texts] + + +# 便捷函数 +def parse_line(text: str, model_path: Optional[str] = None) -> Dict[str, Any]: + """ + 解析单条日志的便捷函数 + + Args: + text: 日志文本 + model_path: 可选的模型路径 + + Returns: + 解析结果字典 + """ + parser = TRPGParser(model_path=model_path) + return parser.parse(text) + + +def parse_lines(texts: List[str], model_path: Optional[str] = None) -> List[Dict[str, Any]]: + """ + 批量解析日志的便捷函数 + + Args: + texts: 日志文本列表 + model_path: 可选的模型路径 + + Returns: + 解析结果列表 + """ + parser = TRPGParser(model_path=model_path) + return parser.parse_batch(texts) + + +__all__ = ["TRPGParser", "parse_line", "parse_lines"] diff --git a/src/base_model_trpgner/training/__init__.py b/src/base_model_trpgner/training/__init__.py new file mode 100644 index 0000000..ccf3c03 --- /dev/null +++ b/src/base_model_trpgner/training/__init__.py @@ -0,0 +1,205 @@ +""" +训练模块 + +提供 TRPG NER 模型训练功能。 + +注意: 使用此模块需要安装训练依赖: + pip install base-model-trpgner[train] +""" + +import os +from typing import Optional, List +from pathlib import Path + + +def train_ner_model( + conll_data: str, + model_name_or_path: str = "hfl/minirbt-h256", + output_dir: str = "./models/trpg-ner-v1", + num_train_epochs: int = 20, + per_device_train_batch_size: int = 4, + learning_rate: float = 5e-5, + max_length: int = 128, + resume_from_checkpoint: Optional[str] = None, +) -> None: + """ + 训练 NER 模型 + + Args: + conll_data: CoNLL 格式数据文件或目录 + model_name_or_path: 基础模型名称或路径 + output_dir: 模型输出目录 + num_train_epochs: 训练轮数 + per_device_train_batch_size: 批处理大小 + learning_rate: 学习率 + max_length: 最大序列长度 + resume_from_checkpoint: 恢复检查点路径 + + Examples: + >>> from basemodeltrpgner.training import train_ner_model + >>> train_ner_model( + ... conll_data="./data", + ... output_dir="./my_model", + ... epochs=10 + ... ) + """ + try: + import torch + from transformers import ( + AutoTokenizer, + AutoModelForTokenClassification, + TrainingArguments, + Trainer, + ) + from datasets import Dataset + from tqdm.auto import tqdm + except ImportError as e: + raise ImportError( + "训练依赖未安装。请运行: pip install base-model-trpgner[train]" + ) from e + + # 导入数据处理函数 + from base_model_trpgner.utils.conll import load_conll_dataset, tokenize_and_align_labels + + print(f"🚀 Starting training...") + + # 加载数据 + dataset, label_list = load_conll_dataset(conll_data) + label2id = {label: i for i, label in enumerate(label_list)} + id2label = {i: label for i, label in enumerate(label_list)} + + # 初始化模型 + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + if tokenizer.model_max_length > 1000: + tokenizer.model_max_length = max_length + + model = AutoModelForTokenClassification.from_pretrained( + model_name_or_path, + num_labels=len(label_list), + id2label=id2label, + label2id=label2id, + ignore_mismatched_sizes=True, + ) + + # Tokenize + tokenized_dataset = dataset.map( + lambda ex: tokenize_and_align_labels(ex, tokenizer, label2id, max_length), + batched=True, + remove_columns=["text", "char_labels"], + ) + + # 训练参数 + training_args = TrainingArguments( + output_dir=output_dir, + learning_rate=learning_rate, + per_device_train_batch_size=per_device_train_batch_size, + num_train_epochs=num_train_epochs, + logging_steps=5, + save_steps=200, + save_total_limit=2, + do_eval=False, + report_to="none", + no_cuda=not torch.cuda.is_available(), + load_best_model_at_end=False, + push_to_hub=False, + fp16=torch.cuda.is_available(), + ) + + trainer = Trainer( + model=model, + args=training_args, + train_dataset=tokenized_dataset, + tokenizer=tokenizer, + ) + + # 开始训练 + print("🚀 Starting training...") + trainer.train(resume_from_checkpoint=resume_from_checkpoint) + + # 保存模型 + print("💾 Saving final model...") + trainer.save_model(output_dir) + tokenizer.save_pretrained(output_dir) + + print(f"✅ Training finished. Model saved to {output_dir}") + + +def export_to_onnx( + model_dir: str, + onnx_path: str, + max_length: int = 128, +) -> bool: + """ + 将训练好的模型导出为 ONNX 格式 + + Args: + model_dir: 模型目录 + onnx_path: ONNX 输出路径 + max_length: 最大序列长度 + + Returns: + 是否成功 + """ + try: + import torch + from torch.onnx import export as onnx_export + from transformers import AutoTokenizer, AutoModelForTokenClassification + import onnx + except ImportError as e: + raise ImportError( + "ONNX 导出依赖未安装。请运行: pip install onnx" + ) from e + + print(f"📤 Exporting model from {model_dir} to {onnx_path}...") + + model_dir = os.path.abspath(model_dir) + if not os.path.exists(model_dir): + raise FileNotFoundError(f"Model directory not found: {model_dir}") + + # 加载模型 + tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True) + model = AutoModelForTokenClassification.from_pretrained( + model_dir, local_files_only=True + ) + model.eval() + + # 创建虚拟输入 + dummy_text = "莎莎 2024-06-08 21:46:26" + inputs = tokenizer( + dummy_text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_length, + ) + + # 确保目录存在 + os.makedirs(os.path.dirname(onnx_path), exist_ok=True) + + # 导出 ONNX + onnx_export( + model, + (inputs["input_ids"], inputs["attention_mask"]), + onnx_path, + export_params=True, + opset_version=18, + do_constant_folding=True, + input_names=["input_ids", "attention_mask"], + output_names=["logits"], + dynamic_axes={ + "input_ids": {0: "batch_size", 1: "sequence_length"}, + "attention_mask": {0: "batch_size", 1: "sequence_length"}, + "logits": {0: "batch_size", 1: "sequence_length"}, + }, + ) + + # 验证 ONNX 模型 + onnx_model = onnx.load(onnx_path) + onnx.checker.check_model(onnx_model) + + size_mb = os.path.getsize(onnx_path) / 1024 / 1024 + print(f"✅ ONNX export successful! Size: {size_mb:.2f} MB") + return True + + +__all__ = ["train_ner_model", "export_to_onnx"] diff --git a/src/base_model_trpgner/utils/__init__.py b/src/base_model_trpgner/utils/__init__.py new file mode 100644 index 0000000..12a3ef4 --- /dev/null +++ b/src/base_model_trpgner/utils/__init__.py @@ -0,0 +1,192 @@ +""" +工具模块 + +提供数据加载、CoNLL 格式处理等工具函数。 +""" + +import os +import glob +from typing import List, Dict, Any, Tuple +from datasets import Dataset +from tqdm.auto import tqdm + + +def word_to_char_labels(text: str, word_labels: List[Tuple[str, str]]) -> List[str]: + """Convert word-level labels to char-level""" + char_labels = ["O"] * len(text) + pos = 0 + + for token, label in word_labels: + if pos >= len(text): + break + + while pos < len(text) and text[pos] != token[0]: + pos += 1 + if pos >= len(text): + break + + if text[pos: pos + len(token)] == token: + for i in range(len(token)): + idx = pos + i + if idx < len(char_labels): + if i == 0 and label.startswith("B-"): + char_labels[idx] = label + elif label.startswith("B-"): + char_labels[idx] = "I" + label[1:] + else: + char_labels[idx] = label + pos += len(token) + else: + pos += 1 + + return char_labels + + +def parse_conll_file(filepath: str) -> List[Dict[str, Any]]: + """Parse .conll → [{"text": str, "char_labels": List[str]}]""" + with open(filepath, "r", encoding="utf-8") as f: + lines = [line.rstrip("\n") for line in f.readlines()] + + # 检测 word-level + is_word_level = any( + len(line.split()[0]) > 1 + for line in lines + if line.strip() and not line.startswith("-DOCSTART-") and len(line.split()) >= 4 + ) + + samples = [] + if is_word_level: + current_text_parts = [] + current_word_labels = [] + + for line in lines: + if not line or line.startswith("-DOCSTART-"): + if current_text_parts: + text = "".join(current_text_parts) + char_labels = word_to_char_labels(text, current_word_labels) + samples.append({"text": text, "char_labels": char_labels}) + current_text_parts = [] + current_word_labels = [] + continue + + parts = line.split() + if len(parts) >= 4: + token, label = parts[0], parts[3] + current_text_parts.append(token) + current_word_labels.append((token, label)) + + if current_text_parts: + text = "".join(current_text_parts) + char_labels = word_to_char_labels(text, current_word_labels) + samples.append({"text": text, "char_labels": char_labels}) + else: + current_text = [] + current_labels = [] + + for line in lines: + if line.startswith("-DOCSTART-"): + if current_text: + samples.append({ + "text": "".join(current_text), + "char_labels": current_labels.copy(), + }) + current_text, current_labels = [], [] + continue + + if not line: + if current_text: + samples.append({ + "text": "".join(current_text), + "char_labels": current_labels.copy(), + }) + current_text, current_labels = [], [] + continue + + parts = line.split() + if len(parts) >= 4: + char = parts[0].replace("\\n", "\n") + label = parts[3] + current_text.append(char) + current_labels.append(label) + + if current_text: + samples.append({ + "text": "".join(current_text), + "char_labels": current_labels.copy(), + }) + + return samples + + +def load_conll_dataset(conll_dir_or_files: str) -> Tuple[Dataset, List[str]]: + """Load .conll files → Dataset""" + filepaths = [] + if os.path.isdir(conll_dir_or_files): + filepaths = sorted(glob.glob(os.path.join(conll_dir_or_files, "*.conll"))) + elif conll_dir_or_files.endswith(".conll"): + filepaths = [conll_dir_or_files] + else: + raise ValueError("conll_dir_or_files must be .conll file or directory") + + if not filepaths: + raise FileNotFoundError(f"No .conll files found in {conll_dir_or_files}") + + print(f"Loading {len(filepaths)} conll files: {filepaths}") + + all_samples = [] + label_set = {"O"} + + for fp in tqdm(filepaths, desc="Parsing .conll"): + samples = parse_conll_file(fp) + for s in samples: + all_samples.append(s) + label_set.update(s["char_labels"]) + + # Build label list + label_list = ["O"] + for label in sorted(label_set - {"O"}): + if label.startswith("B-") or label.startswith("I-"): + label_list.append(label) + for label in list(label_list): + if label.startswith("B-"): + i_label = "I" + label[1:] + if i_label not in label_list: + label_list.append(i_label) + print(f"⚠️ Added missing {i_label} for {label}") + + print(f"✅ Loaded {len(all_samples)} samples, {len(label_list)} labels: {label_list}") + return Dataset.from_list(all_samples), label_list + + +def tokenize_and_align_labels(examples, tokenizer, label2id, max_length=128): + """Tokenize and align labels with tokenizer""" + tokenized = tokenizer( + examples["text"], + truncation=True, + padding=True, + max_length=max_length, + return_offsets_mapping=True, + return_tensors=None, + ) + + labels = [] + for i, label_seq in enumerate(examples["char_labels"]): + offsets = tokenized["offset_mapping"][i] + label_ids = [] + for start, end in offsets: + if start == end: + label_ids.append(-100) + else: + label_ids.append(label2id[label_seq[start]]) + labels.append(label_ids) + + tokenized["labels"] = labels + return tokenized + + +__all__ = [ + "word_to_char_labels", + "parse_conll_file", + "load_conll_dataset", + "tokenize_and_align_labels", +] diff --git a/src/basemodel/__init__.py b/src/basemodel/__init__.py deleted file mode 100644 index 7287df4..0000000 --- a/src/basemodel/__init__.py +++ /dev/null @@ -1,36 +0,0 @@ -""" -base-model - HydroRoll TRPG NER 模型 SDK - -这是一个用于 TRPG(桌上角色扮演游戏)日志命名实体识别的 Python SDK。 - -基本用法: - >>> from basemodel import TRPGParser - >>> parser = TRPGParser() - >>> result = parser.parse("风雨 2024-06-08 21:44:59 剧烈的疼痛...") - >>> print(result) - {'metadata': {'speaker': '风雨', 'timestamp': '2024-06-08 21:44:59'}, 'content': [...]} - -训练功能(需要额外安装): - >>> pip install base-model[train] - >>> from basemodel.training import train_ner_model - >>> train_ner_model(conll_data="./data", output_dir="./model") -""" - -from basemodel.inference import TRPGParser, parse_line, parse_lines - -try: - from importlib.metadata import version - __version__ = version("base-model") -except Exception: - __version__ = "0.1.0.dev" - -__all__ = [ - "__version__", - "TRPGParser", - "parse_line", - "parse_lines", -] - - -def get_version(): - return __version__ diff --git a/src/basemodel/download_model.py b/src/basemodel/download_model.py deleted file mode 100644 index 2d65099..0000000 --- a/src/basemodel/download_model.py +++ /dev/null @@ -1,70 +0,0 @@ -#!/usr/bin/env python3 -""" -模型下载脚本 - -自动下载预训练的 ONNX 模型到用户缓存目录。 -""" - -import os -import sys -from pathlib import Path -import urllib.request - - -def download_model( - output_dir: str = None, - url: str = "https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx" -): - """ - 下载 ONNX 模型 - - Args: - output_dir: 输出目录,默认为 ~/.cache/basemodel/models/trpg-final/ - url: 模型下载 URL - """ - if output_dir is None: - output_dir = Path.home() / ".cache" / "basemodel" / "models" / "trpg-final" - else: - output_dir = Path(output_dir) - - output_dir.mkdir(parents=True, exist_ok=True) - output_path = output_dir / "model.onnx" - - if output_path.exists(): - print(f"✅ 模型已存在: {output_path}") - return str(output_path) - - print(f"📥 正在下载模型到 {output_path}...") - print(f" URL: {url}") - - try: - urllib.request.urlretrieve(url, output_path) - print(f"✅ 模型下载成功!") - return str(output_path) - except Exception as e: - print(f"❌ 下载失败: {e}") - print(f" 请手动从以下地址下载模型:") - print(f" {url}") - sys.exit(1) - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser(description="下载 base-model ONNX 模型") - parser.add_argument( - "--output-dir", - type=str, - default=None, - help="模型输出目录(默认: ~/.cache/basemodel/models/trpg-final/)" - ) - parser.add_argument( - "--url", - type=str, - default="https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx", - help="模型下载 URL" - ) - - args = parser.parse_args() - - download_model(args.output_dir, args.url) diff --git a/src/basemodel/inference/__init__.py b/src/basemodel/inference/__init__.py deleted file mode 100644 index 93a185f..0000000 --- a/src/basemodel/inference/__init__.py +++ /dev/null @@ -1,292 +0,0 @@ -""" -ONNX 推理模块 - -提供基于 ONNX 的 TRPG 日志命名实体识别推理功能。 -""" - -import os -from typing import List, Dict, Any, Optional -from pathlib import Path - -try: - import numpy as np - import onnxruntime as ort - from transformers import AutoTokenizer -except ImportError as e: - raise ImportError( - "依赖未安装。请运行: pip install onnxruntime transformers numpy" - ) from e - - -# 默认模型路径(相对于包安装位置) -DEFAULT_MODEL_DIR = Path(__file__).parent.parent.parent.parent / "models" / "trpg-final" -# 远程模型 URL(用于自动下载) -MODEL_URL = "https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx" - - -class TRPGParser: - """ - TRPG 日志解析器(基于 ONNX) - - Args: - model_path: ONNX 模型路径,默认使用内置模型 - tokenizer_path: tokenizer 配置路径,默认与 model_path 相同 - device: 推理设备,"cpu" 或 "cuda" - - Examples: - >>> parser = TRPGParser() - >>> result = parser.parse("风雨 2024-06-08 21:44:59 剧烈的疼痛...") - >>> print(result['metadata']['speaker']) - '风雨' - """ - - def __init__( - self, - model_path: Optional[str] = None, - tokenizer_path: Optional[str] = None, - device: str = "cpu", - ): - # 确定模型路径 - if model_path is None: - model_path = self._get_default_model_path() - - if tokenizer_path is None: - tokenizer_path = Path(model_path).parent - - self.model_path = Path(model_path) - self.tokenizer_path = Path(tokenizer_path) - self.device = device - - # 加载模型 - self._load_model() - - def _get_default_model_path(self) -> str: - """获取默认模型路径""" - # 1. 尝试相对于项目根目录 - project_root = Path(__file__).parent.parent.parent.parent - local_model = project_root / "models" / "trpg-final" / "model.onnx" - if local_model.exists(): - return str(local_model) - - # 2. 尝试用户数据目录 - from pathlib import Path - user_model_dir = Path.home() / ".cache" / "basemodel" / "models" / "trpg-final" - user_model = user_model_dir / "model.onnx" - if user_model.exists(): - return str(user_model) - - # 3. 抛出错误,提示下载 - raise FileNotFoundError( - f"模型文件未找到。请从 {MODEL_URL} 下载模型到 {user_model_dir}\n" - f"或运行: python -m basemodel.download_model" - ) - - def _load_model(self): - """加载 ONNX 模型和 Tokenizer""" - # 加载 tokenizer - self.tokenizer = AutoTokenizer.from_pretrained( - str(self.tokenizer_path), - local_files_only=True, - ) - - # 加载 ONNX 模型 - providers = ["CPUExecutionProvider"] - if self.device == "cuda" and "CUDAExecutionProvider" in ort.get_available_providers(): - providers.insert(0, "CUDAExecutionProvider") - - self.session = ort.InferenceSession( - str(self.model_path), - providers=providers, - ) - - # 加载标签映射 - import json - config_path = self.tokenizer_path / "config.json" - if config_path.exists(): - with open(config_path, "r", encoding="utf-8") as f: - config = json.load(f) - self.id2label = {int(k): v for k, v in config.get("id2label", {}).items()} - else: - # 默认标签 - self.id2label = { - 0: "O", 1: "B-action", 2: "I-action", 3: "B-comment", 4: "I-comment", - 5: "B-dialogue", 6: "I-dialogue", 7: "B-speaker", 8: "I-speaker", - 9: "B-timestamp", 10: "I-timestamp", - } - - def parse(self, text: str) -> Dict[str, Any]: - """ - 解析单条 TRPG 日志 - - Args: - text: 待解析的日志文本 - - Returns: - 包含 metadata 和 content 的字典 - - metadata: speaker, timestamp - - content: dialogue, action, comment 列表 - - Examples: - >>> parser = TRPGParser() - >>> result = parser.parse("风雨 2024-06-08 21:44:59 剧烈的疼痛...") - >>> result['metadata']['speaker'] - '风雨' - """ - # Tokenize - inputs = self.tokenizer( - text, - return_tensors="np", - return_offsets_mapping=True, - padding="max_length", - truncation=True, - max_length=128, - ) - - # 推理 - outputs = self.session.run( - ["logits"], - { - "input_ids": inputs["input_ids"].astype(np.int64), - "attention_mask": inputs["attention_mask"].astype(np.int64), - }, - ) - - # 后处理 - logits = outputs[0][0] - predictions = np.argmax(logits, axis=-1) - offsets = inputs["offset_mapping"][0] - - # 聚合实体 - entities = self._group_entities(predictions, offsets, logits) - - # 构建结果 - result = {"metadata": {}, "content": []} - for ent in entities: - if ent["start"] >= len(text) or ent["end"] > len(text): - continue - - raw_text = text[ent["start"]: ent["end"]] - clean_text = self._clean_text(raw_text, ent["type"]) - - if not clean_text.strip(): - continue - - if ent["type"] in ["timestamp", "speaker"]: - result["metadata"][ent["type"]] = clean_text - elif ent["type"] in ["dialogue", "action", "comment"]: - result["content"].append({ - "type": ent["type"], - "content": clean_text, - "confidence": round(ent["score"], 3), - }) - - return result - - def _group_entities(self, predictions, offsets, logits): - """将 token 级别的预测聚合为实体""" - entities = [] - current = None - - for i in range(len(predictions)): - start, end = offsets[i] - if start == end: # special tokens - continue - - pred_id = int(predictions[i]) - label = self.id2label.get(pred_id, "O") - - if label == "O": - if current: - entities.append(current) - current = None - continue - - tag_type = label[2:] if len(label) > 2 else "O" - - if label.startswith("B-"): - if current: - entities.append(current) - current = { - "type": tag_type, - "start": int(start), - "end": int(end), - "score": float(np.max(logits[i])), - } - elif label.startswith("I-") and current and current["type"] == tag_type: - current["end"] = int(end) - else: - if current: - entities.append(current) - current = None - - if current: - entities.append(current) - - return entities - - def _clean_text(self, text: str, group: str) -> str: - """清理提取的文本""" - import re - - text = text.strip() - - # 移除周围符号 - if group == "comment": - text = re.sub(r"^[((]+|[))]+$", "", text) - elif group == "dialogue": - text = re.sub(r'^[""''「」『』]+|[""""」』『』]+$', "", text) - elif group == "action": - text = re.sub(r"^[*#]+|[*#]+$", "", text) - - # 修复时间戳 - if group == "timestamp" and text and text[0].isdigit(): - if len(text) > 2 and text[2] == "-": - text = "20" + text - - return text - - def parse_batch(self, texts: List[str]) -> List[Dict[str, Any]]: - """ - 批量解析多条日志 - - Args: - texts: 日志文本列表 - - Returns: - 解析结果列表 - """ - return [self.parse(text) for text in texts] - - -# 便捷函数 -def parse_line(text: str, model_path: Optional[str] = None) -> Dict[str, Any]: - """ - 解析单条日志的便捷函数 - - Args: - text: 日志文本 - model_path: 可选的模型路径 - - Returns: - 解析结果字典 - """ - parser = TRPGParser(model_path=model_path) - return parser.parse(text) - - -def parse_lines(texts: List[str], model_path: Optional[str] = None) -> List[Dict[str, Any]]: - """ - 批量解析日志的便捷函数 - - Args: - texts: 日志文本列表 - model_path: 可选的模型路径 - - Returns: - 解析结果列表 - """ - parser = TRPGParser(model_path=model_path) - return parser.parse_batch(texts) - - -__all__ = ["TRPGParser", "parse_line", "parse_lines"] diff --git a/src/basemodel/training/__init__.py b/src/basemodel/training/__init__.py deleted file mode 100644 index 5671c42..0000000 --- a/src/basemodel/training/__init__.py +++ /dev/null @@ -1,205 +0,0 @@ -""" -训练模块 - -提供 TRPG NER 模型训练功能。 - -注意: 使用此模块需要安装训练依赖: - pip install base-model-trpgner[train] -""" - -import os -from typing import Optional, List -from pathlib import Path - - -def train_ner_model( - conll_data: str, - model_name_or_path: str = "hfl/minirbt-h256", - output_dir: str = "./models/trpg-ner-v1", - num_train_epochs: int = 20, - per_device_train_batch_size: int = 4, - learning_rate: float = 5e-5, - max_length: int = 128, - resume_from_checkpoint: Optional[str] = None, -) -> None: - """ - 训练 NER 模型 - - Args: - conll_data: CoNLL 格式数据文件或目录 - model_name_or_path: 基础模型名称或路径 - output_dir: 模型输出目录 - num_train_epochs: 训练轮数 - per_device_train_batch_size: 批处理大小 - learning_rate: 学习率 - max_length: 最大序列长度 - resume_from_checkpoint: 恢复检查点路径 - - Examples: - >>> from basemodel.training import train_ner_model - >>> train_ner_model( - ... conll_data="./data", - ... output_dir="./my_model", - ... epochs=10 - ... ) - """ - try: - import torch - from transformers import ( - AutoTokenizer, - AutoModelForTokenClassification, - TrainingArguments, - Trainer, - ) - from datasets import Dataset - from tqdm.auto import tqdm - except ImportError as e: - raise ImportError( - "训练依赖未安装。请运行: pip install base-model-trpgner[train]" - ) from e - - # 导入数据处理函数 - from basemodel.utils.conll import load_conll_dataset, tokenize_and_align_labels - - print(f"🚀 Starting training...") - - # 加载数据 - dataset, label_list = load_conll_dataset(conll_data) - label2id = {label: i for i, label in enumerate(label_list)} - id2label = {i: label for i, label in enumerate(label_list)} - - # 初始化模型 - tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) - if tokenizer.model_max_length > 1000: - tokenizer.model_max_length = max_length - - model = AutoModelForTokenClassification.from_pretrained( - model_name_or_path, - num_labels=len(label_list), - id2label=id2label, - label2id=label2id, - ignore_mismatched_sizes=True, - ) - - # Tokenize - tokenized_dataset = dataset.map( - lambda ex: tokenize_and_align_labels(ex, tokenizer, label2id, max_length), - batched=True, - remove_columns=["text", "char_labels"], - ) - - # 训练参数 - training_args = TrainingArguments( - output_dir=output_dir, - learning_rate=learning_rate, - per_device_train_batch_size=per_device_train_batch_size, - num_train_epochs=num_train_epochs, - logging_steps=5, - save_steps=200, - save_total_limit=2, - do_eval=False, - report_to="none", - no_cuda=not torch.cuda.is_available(), - load_best_model_at_end=False, - push_to_hub=False, - fp16=torch.cuda.is_available(), - ) - - trainer = Trainer( - model=model, - args=training_args, - train_dataset=tokenized_dataset, - tokenizer=tokenizer, - ) - - # 开始训练 - print("🚀 Starting training...") - trainer.train(resume_from_checkpoint=resume_from_checkpoint) - - # 保存模型 - print("💾 Saving final model...") - trainer.save_model(output_dir) - tokenizer.save_pretrained(output_dir) - - print(f"✅ Training finished. Model saved to {output_dir}") - - -def export_to_onnx( - model_dir: str, - onnx_path: str, - max_length: int = 128, -) -> bool: - """ - 将训练好的模型导出为 ONNX 格式 - - Args: - model_dir: 模型目录 - onnx_path: ONNX 输出路径 - max_length: 最大序列长度 - - Returns: - 是否成功 - """ - try: - import torch - from torch.onnx import export as onnx_export - from transformers import AutoTokenizer, AutoModelForTokenClassification - import onnx - except ImportError as e: - raise ImportError( - "ONNX 导出依赖未安装。请运行: pip install onnx" - ) from e - - print(f"📤 Exporting model from {model_dir} to {onnx_path}...") - - model_dir = os.path.abspath(model_dir) - if not os.path.exists(model_dir): - raise FileNotFoundError(f"Model directory not found: {model_dir}") - - # 加载模型 - tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True) - model = AutoModelForTokenClassification.from_pretrained( - model_dir, local_files_only=True - ) - model.eval() - - # 创建虚拟输入 - dummy_text = "莎莎 2024-06-08 21:46:26" - inputs = tokenizer( - dummy_text, - return_tensors="pt", - padding="max_length", - truncation=True, - max_length=max_length, - ) - - # 确保目录存在 - os.makedirs(os.path.dirname(onnx_path), exist_ok=True) - - # 导出 ONNX - onnx_export( - model, - (inputs["input_ids"], inputs["attention_mask"]), - onnx_path, - export_params=True, - opset_version=18, - do_constant_folding=True, - input_names=["input_ids", "attention_mask"], - output_names=["logits"], - dynamic_axes={ - "input_ids": {0: "batch_size", 1: "sequence_length"}, - "attention_mask": {0: "batch_size", 1: "sequence_length"}, - "logits": {0: "batch_size", 1: "sequence_length"}, - }, - ) - - # 验证 ONNX 模型 - onnx_model = onnx.load(onnx_path) - onnx.checker.check_model(onnx_model) - - size_mb = os.path.getsize(onnx_path) / 1024 / 1024 - print(f"✅ ONNX export successful! Size: {size_mb:.2f} MB") - return True - - -__all__ = ["train_ner_model", "export_to_onnx"] diff --git a/src/basemodel/utils/__init__.py b/src/basemodel/utils/__init__.py deleted file mode 100644 index 12a3ef4..0000000 --- a/src/basemodel/utils/__init__.py +++ /dev/null @@ -1,192 +0,0 @@ -""" -工具模块 - -提供数据加载、CoNLL 格式处理等工具函数。 -""" - -import os -import glob -from typing import List, Dict, Any, Tuple -from datasets import Dataset -from tqdm.auto import tqdm - - -def word_to_char_labels(text: str, word_labels: List[Tuple[str, str]]) -> List[str]: - """Convert word-level labels to char-level""" - char_labels = ["O"] * len(text) - pos = 0 - - for token, label in word_labels: - if pos >= len(text): - break - - while pos < len(text) and text[pos] != token[0]: - pos += 1 - if pos >= len(text): - break - - if text[pos: pos + len(token)] == token: - for i in range(len(token)): - idx = pos + i - if idx < len(char_labels): - if i == 0 and label.startswith("B-"): - char_labels[idx] = label - elif label.startswith("B-"): - char_labels[idx] = "I" + label[1:] - else: - char_labels[idx] = label - pos += len(token) - else: - pos += 1 - - return char_labels - - -def parse_conll_file(filepath: str) -> List[Dict[str, Any]]: - """Parse .conll → [{"text": str, "char_labels": List[str]}]""" - with open(filepath, "r", encoding="utf-8") as f: - lines = [line.rstrip("\n") for line in f.readlines()] - - # 检测 word-level - is_word_level = any( - len(line.split()[0]) > 1 - for line in lines - if line.strip() and not line.startswith("-DOCSTART-") and len(line.split()) >= 4 - ) - - samples = [] - if is_word_level: - current_text_parts = [] - current_word_labels = [] - - for line in lines: - if not line or line.startswith("-DOCSTART-"): - if current_text_parts: - text = "".join(current_text_parts) - char_labels = word_to_char_labels(text, current_word_labels) - samples.append({"text": text, "char_labels": char_labels}) - current_text_parts = [] - current_word_labels = [] - continue - - parts = line.split() - if len(parts) >= 4: - token, label = parts[0], parts[3] - current_text_parts.append(token) - current_word_labels.append((token, label)) - - if current_text_parts: - text = "".join(current_text_parts) - char_labels = word_to_char_labels(text, current_word_labels) - samples.append({"text": text, "char_labels": char_labels}) - else: - current_text = [] - current_labels = [] - - for line in lines: - if line.startswith("-DOCSTART-"): - if current_text: - samples.append({ - "text": "".join(current_text), - "char_labels": current_labels.copy(), - }) - current_text, current_labels = [], [] - continue - - if not line: - if current_text: - samples.append({ - "text": "".join(current_text), - "char_labels": current_labels.copy(), - }) - current_text, current_labels = [], [] - continue - - parts = line.split() - if len(parts) >= 4: - char = parts[0].replace("\\n", "\n") - label = parts[3] - current_text.append(char) - current_labels.append(label) - - if current_text: - samples.append({ - "text": "".join(current_text), - "char_labels": current_labels.copy(), - }) - - return samples - - -def load_conll_dataset(conll_dir_or_files: str) -> Tuple[Dataset, List[str]]: - """Load .conll files → Dataset""" - filepaths = [] - if os.path.isdir(conll_dir_or_files): - filepaths = sorted(glob.glob(os.path.join(conll_dir_or_files, "*.conll"))) - elif conll_dir_or_files.endswith(".conll"): - filepaths = [conll_dir_or_files] - else: - raise ValueError("conll_dir_or_files must be .conll file or directory") - - if not filepaths: - raise FileNotFoundError(f"No .conll files found in {conll_dir_or_files}") - - print(f"Loading {len(filepaths)} conll files: {filepaths}") - - all_samples = [] - label_set = {"O"} - - for fp in tqdm(filepaths, desc="Parsing .conll"): - samples = parse_conll_file(fp) - for s in samples: - all_samples.append(s) - label_set.update(s["char_labels"]) - - # Build label list - label_list = ["O"] - for label in sorted(label_set - {"O"}): - if label.startswith("B-") or label.startswith("I-"): - label_list.append(label) - for label in list(label_list): - if label.startswith("B-"): - i_label = "I" + label[1:] - if i_label not in label_list: - label_list.append(i_label) - print(f"⚠️ Added missing {i_label} for {label}") - - print(f"✅ Loaded {len(all_samples)} samples, {len(label_list)} labels: {label_list}") - return Dataset.from_list(all_samples), label_list - - -def tokenize_and_align_labels(examples, tokenizer, label2id, max_length=128): - """Tokenize and align labels with tokenizer""" - tokenized = tokenizer( - examples["text"], - truncation=True, - padding=True, - max_length=max_length, - return_offsets_mapping=True, - return_tensors=None, - ) - - labels = [] - for i, label_seq in enumerate(examples["char_labels"]): - offsets = tokenized["offset_mapping"][i] - label_ids = [] - for start, end in offsets: - if start == end: - label_ids.append(-100) - else: - label_ids.append(label2id[label_seq[start]]) - labels.append(label_ids) - - tokenized["labels"] = labels - return tokenized - - -__all__ = [ - "word_to_char_labels", - "parse_conll_file", - "load_conll_dataset", - "tokenize_and_align_labels", -] -- cgit v1.2.3-70-g09d2