diff options
| -rw-r--r-- | .gitignore | 6 | ||||
| -rw-r--r-- | README.md | 296 | ||||
| -rw-r--r-- | main.py | 632 | ||||
| -rw-r--r-- | pyproject.toml | 5 | ||||
| -rw-r--r-- | requirements.txt | 136 | ||||
| -rw-r--r-- | src/utils/conll_to_dataset.py | 263 | ||||
| -rw-r--r-- | src/utils/word_conll_to_char_conll.py | 55 | ||||
| -rw-r--r-- | tests/onnx_infer.py | 115 |
8 files changed, 1502 insertions, 6 deletions
@@ -161,4 +161,8 @@ cython_debug/ # uv .python-version -uv.lock
\ No newline at end of file +uv.lock + +# model +models/ +dataset/
\ No newline at end of file @@ -1,2 +1,294 @@ -# base-model -Base NLP Model for HydroRoll. +# TRPG NER 模型 - HydroRoll 基础 NLP 模型 + +基于 MiniRBT (hfl/minirbt-h256) 的中文 TRPG(桌上角色扮演游戏)日志命名实体识别系统,支持训练、推理、ONNX 导出和 WebUI 可视化标注平台。 + +## 功能特性 + +- **NER 实体识别**: 自动识别 TRPG 日志中的发言者、时间戳、对话、动作、注释等实体 +- **灵活训练**: 支持初次训练、增量训练、标签维度扩展/缩减 +- **ONNX 导出**: 支持 ONNX 格式导出,实现 CPU 高速推理 +- **WebUI 平台**: Label Studio 风格的可视化标注界面 (WIP) +- **数据转换**: 支持 Word-level/Char-level CoNLL 格式互转 +- **自动修复**: 智能修复截断的时间戳和发言者名称 + +## 实体类型 + +| 标签 | 说明 | 示例 | +| ------------- | --------- | ------------------------------------- | +| `speaker` | 发言者 | "风雨" | +| `timestamp` | 时间戳 | "2024-06-08 21:44:59" | +| `dialogue` | 对话内容 | ""呜哇..."" | +| `action` | 动作描述 | "剧烈的疼痛从头颅深处一波波地涌出..." | +| `comment` | 注释/旁白 | "(红木家具上刻着一行小字)" | + +## 安装 + +### 环境要求 + +- Python >= 3.12 +- CUDA (可选,用于 GPU 加速) + +### 安装依赖 + +```bash +# 使用 uv (推荐) +uv sync + +# 或使用 pip +pip install -r requirements.txt +``` + +``` + +``` + +### 1. 数据准备 + +将训练数据准备为 CoNLL 格式,支持两种格式: + +#### Char-level 格式 (推荐) + +``` +-DOCSTART- -X- O +风 -X- _ B-speaker +雨 -X- _ I-speaker + -X- _ O +2 -X- _ B-timestamp +0 -X- _ I-timestamp +... +``` + +#### Word-level 格式 + +``` +风雨 O O B-speaker +2024-06-08 O O B-timestamp +21:44:59 O O I-timestamp +``` + +### 2. 初次训练 + +```bash +# 基础训练(使用默认参数) +uv run main.py --train --conll ./data + +# 完整参数训练 +uv run main.py --train \ + --conll ./data \ + --model hfl/minirbt-h256 \ + --output ./models/trpg-final \ + --epochs 20 \ + --batch 4 +``` + +#### 参数说明 + +| 参数 | 说明 | 默认值 | +| ------------ | -------------------- | ------------------------ | +| `--train` | 启用训练模式 | - | +| `--conll` | CoNLL 文件或目录路径 | `./data` | +| `--model` | 基础模型名称 | `hfl/minirbt-h256` | +| `--output` | 模型输出目录 | `./models/trpg-ner-v1` | +| `--epochs` | 训练轮数 | `20` | +| `--batch` | 批处理大小 | `4` | +| `--resume` | 恢复检查点路径 | `None` | + +### 3. 推理测试 + +```bash +# 单文本测试 +uv run main.py --test "风雨 2024-06-08 21:44:59 剧烈的疼痛从头颅深处一波波地涌出..." + +# 多文本测试 +uv run main.py --test \ + "莎莎 2024-06-08 21:46:26 \"呜哇...\" 下意识去拿法杖" \ + "BOT 2024-06-08 21:50:03 莎莎 的出目是 D10+7=6+7=13" + +``` + +#### 输出格式 + +```json +{ + "metadata": { + "speaker": "风雨", + "timestamp": "2024-06-08 21:44:59" + }, + "content": [ + { + "type": "comment", + "content": "剧烈的疼痛从头颅深处一波波地涌出...", + "confidence": 0.952 + } + ] +} +``` + +## 增量训练 + +在已有模型基础上继续训练新数据: + +```bash +# 继续训练(自动加载最新检查点) +uv run main.py --train \ + --conll ./new_data \ + --output ./models/trpg-final \ + --epochs 5 + +# 从指定检查点恢复 +uv run main.py --train \ + --conll ./new_data \ + --output ./models/trpg-final \ + --resume ./models/trpg-final/checkpoint-200 \ + --epochs 5 +``` + +## 标签维度管理 + +### 添加新标签训练 + +在 CoNLL 数据中添加新的标签类型(如 `emotion`),模型会自动适配: + +```bash +# 添加新标签后重新训练(使用 ignore_mismatched_sizes) +uv run main.py --train \ + --conll ./data_with_emotion \ + --output ./models/trpg-final \ + --epochs 15 +``` + +### 减少标签维度训练 + +如果需要减少标签类型,建议从头训练: + +```bash +# 从基础模型重新训练(使用缩减后的标签集) +uv run main.py --train \ + --conll ./data_reduced_labels \ + --model hfl/minirbt-h256 \ + --output ./models/trpg-reduced \ + --epochs 20 +``` + +## ONNX 导出 + +将训练好的模型导出为 ONNX 格式,用于 CPU 推理加速。 + +> **注意**: ONNX 导出功能会自动使用推理时找到的模型目录。如果模型目录不在默认位置,请先���用 `--output` 指定正确的模型目录。 + +```bash +# 方式一:从默认模型目录导出(自动查找 models/trpg-final) +uv run main.py --export_onnx \ + --onnx_path ./models/trpg-final/model.onnx + +# 方式二:指定模型目录导出 +uv run main.py --export_onnx \ + --output ./models/trpg-final \ + --onnx_path ./models/trpg-final/model.onnx + +# 方式三:导出到其他路径 +uv run main.py --export_onnx \ + --onnx_path ./models/trpg-optimized.onnx +``` + +### ONNX 模型特点 + +- 支持 CPU 推理(无需 GPU) +- 模型大小约 50-100 MB +- 推理速度约 10-50 ms/句(取决于硬件) +- 兼容 Windows/Linux/macOS/Raspberry Pi + +## ONNX 模型测试 + +### 基础推理测试 + +```bash +# 使用 ONNX 模型进行推理 +uv run tests/onnx_infer.py "风雨 2024-06-08 21:44:59 剧烈的疼痛..." +``` + +#### 性能指标 + +``` +Performance Results (n=100): + Average latency: 25.32 ms + P95 latency: 31.45 ms + Max RAM usage: 120.5 MB + Throughput: 39.5 sentences/sec +``` + +## 数据转换工具 + +### CoNLL 转 Dataset + +```bash +# 转换单个文件 +uv run src/utils/conll_to_dataset.py data.conll \ + --output ./dataset/trpg \ + --format jsonl \ + --validate + +# 转换整个目录 +uv run src/utils/conll_to_dataset.py ./data \ + --output ./dataset/trpg \ + --format both \ + --validate +``` + +#### 输出格式 + +- `--format jsonl`: 输出 JSONL 格式 +- `--format dataset`: 输出 HuggingFace Dataset 格式 +- `--format both`: 同时输出两种格式 + +### Word-level 转 Char-level CoNLL + +```bash +uv run src/utils/word_conll_to_char_conll.py \ + input_word.conll \ + output_char.conll +``` + +## 高级功能 + +### 自定义标签 + +在 `src/webui/utils.py` 中修改 `DEFAULT_LABELS`: + +```python +DEFAULT_LABELS = [ + {"name": "timestamp", "color": "#87CEEB", "type": "text"}, + {"name": "speaker", "color": "#90EE90", "type": "text"}, + {"name": "dialogue", "color": "#FFB6C1", "type": "text"}, + {"name": "action", "color": "#DDA0DD", "type": "text"}, + {"name": "comment", "color": "#FFD700", "type": "text"}, + # 添加新标签 + {"name": "emotion", "color": "#FFA07A", "type": "text"}, +] +``` + +## 系统要求 + +### 训练环境 + +- CPU: Intel Core i5 或同等性能 +- RAM: 8 GB+ +- GPU: NVIDIA GPU(可选,用于加速) +- 存储: 5 GB+ + +### ONNX 推理环境 + +- CPU: Intel Core i3 或同等性能 +- RAM: 2 GB+ +- 存储: 100 MB +- 支持 Raspberry Pi 4 (4GB RAM) + +## 开源协议 + +本项目采用 [AFL-3.0](COPYING) 协议开源。 + +## 相关链接 + +- [MiniRBT 模型](https://huggingface.co/hfl/minirbt-h256) +- [Transformers 文档](https://huggingface.co/docs/transformers) +- [ONNX Runtime](https://onnxruntime.ai/) @@ -1,4 +1,630 @@ -from transformers import BertTokenizer, BertModel +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +TRPG NER 训练与推理脚本 (Robust Edition) +- 自动探测模型路径(支持 safetensors/pytorch) +- 统一 tokenizer 行为(offset_mapping) +- 智能修复 timestamp/speaker 截断 +- 兼容 transformers <4.5 +""" -tokenizer = BertTokenizer.from_pretrained("hfl/minirbt-h256") -model = BertModel.from_pretrained("hfl/minirbt-h256")
\ No newline at end of file +import os +import re +import glob +import sys +from typing import List, Dict, Any, Tuple, Optional +from pathlib import Path + +import torch +from transformers import ( + AutoTokenizer, + AutoModelForTokenClassification, + TrainingArguments, + Trainer, + pipeline, + logging as hf_logging, +) +from datasets import Dataset +from tqdm.auto import tqdm + +# 抑制 transformers 警告 +hf_logging.set_verbosity_error() + + +# =========================== +# 1. CoNLL 解析器(自动 word→char 转换) +# =========================== + + +def word_to_char_labels(text: str, word_labels: List[Tuple[str, str]]) -> List[str]: + """Convert word-level labels to char-level""" + char_labels = ["O"] * len(text) + pos = 0 + + for token, label in word_labels: + if pos >= len(text): + break + + # 定位 token(跳过空格/换行) + while pos < len(text) and text[pos] != token[0]: + pos += 1 + if pos >= len(text): + break + + if text[pos : pos + len(token)] == token: + for i, _ in enumerate(token): + idx = pos + i + if idx < len(char_labels): + if i == 0 and label.startswith("B-"): + char_labels[idx] = label + elif label.startswith("B-"): + char_labels[idx] = "I" + label[1:] + else: + char_labels[idx] = label + pos += len(token) + else: + pos += 1 + + return char_labels + + +def parse_conll_file(filepath: str) -> List[Dict[str, Any]]: + """Parse .conll → [{"text": str, "char_labels": List[str]}]""" + with open(filepath, "r", encoding="utf-8") as f: + lines = [line.rstrip("\n") for line in f.readlines()] + + # 检测 word-level + is_word_level = any( + len(line.split()[0]) > 1 + for line in lines + if line.strip() and not line.startswith("-DOCSTART-") and len(line.split()) >= 4 + ) + + samples = [] + if is_word_level: + # Word-level parsing + current_text_parts = [] + current_word_labels = [] + + for line in lines: + if not line or line.startswith("-DOCSTART-"): + if current_text_parts: + text = "".join(current_text_parts) + char_labels = word_to_char_labels(text, current_word_labels) + samples.append({"text": text, "char_labels": char_labels}) + current_text_parts = [] + current_word_labels = [] + continue + + parts = line.split() + if len(parts) >= 4: + token, label = parts[0], parts[3] + current_text_parts.append(token) + current_word_labels.append((token, label)) + + if current_text_parts: + text = "".join(current_text_parts) + char_labels = word_to_char_labels(text, current_word_labels) + samples.append({"text": text, "char_labels": char_labels}) + + else: + # Char-level parsing + current_text = [] + current_labels = [] + + for line in lines: + if line.startswith("-DOCSTART-"): + if current_text: + samples.append( + { + "text": "".join(current_text), + "char_labels": current_labels.copy(), + } + ) + current_text, current_labels = [], [] + continue + + if not line: + if current_text: + samples.append( + { + "text": "".join(current_text), + "char_labels": current_labels.copy(), + } + ) + current_text, current_labels = [], [] + continue + + parts = line.split() + if len(parts) >= 4: + char = parts[0].replace("\\n", "\n") + label = parts[3] + current_text.append(char) + current_labels.append(label) + + if current_text: + samples.append( + {"text": "".join(current_text), "char_labels": current_labels.copy()} + ) + + return samples + + +def load_conll_dataset(conll_dir_or_files: str) -> Tuple[Dataset, List[str]]: + """Load .conll files → Dataset""" + filepaths = [] + if os.path.isdir(conll_dir_or_files): + filepaths = sorted(glob.glob(os.path.join(conll_dir_or_files, "*.conll"))) + elif conll_dir_or_files.endswith(".conll"): + filepaths = [conll_dir_or_files] + else: + raise ValueError("conll_dir_or_files must be .conll file or directory") + + if not filepaths: + raise FileNotFoundError(f"No .conll files found in {conll_dir_or_files}") + + print(f"Loading {len(filepaths)} conll files: {filepaths}") + + all_samples = [] + label_set = {"O"} + + for fp in tqdm(filepaths, desc="Parsing .conll"): + samples = parse_conll_file(fp) + for s in samples: + all_samples.append(s) + label_set.update(s["char_labels"]) + + # Build label list + label_list = ["O"] + for label in sorted(label_set - {"O"}): + if label.startswith("B-") or label.startswith("I-"): + label_list.append(label) + for label in list(label_list): + if label.startswith("B-"): + i_label = "I" + label[1:] + if i_label not in label_list: + label_list.append(i_label) + print(f"⚠️ Added missing {i_label} for {label}") + + print( + f"✅ Loaded {len(all_samples)} samples, {len(label_list)} labels: {label_list}" + ) + return Dataset.from_list(all_samples), label_list + + +# =========================== +# 2. 数据预处理(offset_mapping 对齐) +# =========================== + + +def tokenize_and_align_labels(examples, tokenizer, label2id, max_length=128): + tokenized = tokenizer( + examples["text"], + truncation=True, + padding=True, + max_length=max_length, + return_offsets_mapping=True, + return_tensors=None, + ) + + labels = [] + for i, label_seq in enumerate(examples["char_labels"]): + offsets = tokenized["offset_mapping"][i] + label_ids = [] + for start, end in offsets: + if start == end: # special tokens + label_ids.append(-100) + else: + label_ids.append(label2id[label_seq[start]]) + labels.append(label_ids) + + tokenized["labels"] = labels + return tokenized + + +# =========================== +# 3. 训练流程 +# =========================== + + +def train_ner_model( + conll_data: str, + model_name_or_path: str = "hfl/minirbt-h256", + output_dir: str = "./models/trpg-ner-v1", + num_train_epochs: int = 15, + per_device_train_batch_size: int = 4, + learning_rate: float = 5e-5, + max_length: int = 128, + resume_from_checkpoint: Optional[str] = None, +): + # Step 1: Load data + dataset, label_list = load_conll_dataset(conll_data) + label2id = {label: i for i, label in enumerate(label_list)} + id2label = {i: label for i, label in enumerate(label_list)} + + # Step 2: Init model + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + if tokenizer.model_max_length > 1000: + tokenizer.model_max_length = max_length + + model = AutoModelForTokenClassification.from_pretrained( + model_name_or_path, + num_labels=len(label_list), + id2label=id2label, + label2id=label2id, + ignore_mismatched_sizes=True, + ) + + # Step 3: Tokenize + tokenized_dataset = dataset.map( + lambda ex: tokenize_and_align_labels(ex, tokenizer, label2id, max_length), + batched=True, + remove_columns=["text", "char_labels"], + ) + + # Step 4: Training args (compatible with old transformers) + training_args = TrainingArguments( + output_dir=output_dir, + learning_rate=learning_rate, + per_device_train_batch_size=per_device_train_batch_size, + num_train_epochs=num_train_epochs, + logging_steps=5, + save_steps=200, + save_total_limit=2, + do_eval=False, + report_to="none", + no_cuda=not torch.cuda.is_available(), + load_best_model_at_end=False, + push_to_hub=False, + fp16=torch.cuda.is_available(), + ) + + trainer = Trainer( + model=model, + args=training_args, + train_dataset=tokenized_dataset, + tokenizer=tokenizer, + ) + + print("🚀 Starting training...") + trainer.train(resume_from_checkpoint=resume_from_checkpoint) + + print("💾 Saving final model...") + trainer.save_model(output_dir) + tokenizer.save_pretrained(output_dir) + return model, tokenizer, label_list + + +# =========================== +# 4. 推理函数(增强版) +# =========================== + + +def fix_timestamp(ts: str) -> str: + """Fix truncated timestamp: '4-06-08' → '2024-06-08'""" + if not ts: + return ts + m = re.match(r"^(\d{1,2})-(\d{2})-(\d{2})(.*)", ts) + if m: + year_short, month, day, rest = m.groups() + if len(year_short) == 1: + year = "202" + year_short + elif len(year_short) == 2: + year = "20" + year_short + else: + year = year_short + return f"{year}-{month}-{day}{rest}" + return ts + + +def fix_speaker(spk: str) -> str: + """Fix truncated speaker name""" + if not spk: + return spk + spk = re.sub(r"[^\w\s\u4e00-\u9fff]+$", "", spk) + # Ensure at least 2 chars for Chinese names + if len(spk) == 1 and re.match(r"^[风雷电雨雪火水木金]", spk): + return spk + "某" + return spk + + +class TRPGParser: + def __init__(self, model_dir: str): + self.tokenizer = AutoTokenizer.from_pretrained(model_dir) + self.model = AutoModelForTokenClassification.from_pretrained(model_dir) + self.id2label = self.model.config.id2label + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.model.to(self.device) + + self.nlp = pipeline( + "token-classification", + model=self.model, + tokenizer=self.tokenizer, + aggregation_strategy="simple", + device=0 if torch.cuda.is_available() else -1, + ) + + def parse_line(self, text: str) -> Dict[str, Any]: + ents = self.nlp(text) + print(f"[DEBUG] Raw entities: {ents}") + out = {"metadata": {}, "content": []} + + # Merge adjacent entities + merged = [] + for e in sorted(ents, key=lambda x: x["start"]): + if merged and merged[-1]["entity_group"] == e["entity_group"]: + if e["start"] <= merged[-1]["end"]: + merged[-1]["end"] = max(merged[-1]["end"], e["end"]) + merged[-1]["score"] = min(merged[-1]["score"], e["score"]) + continue + merged.append(e) + + for e in merged: + group = e["entity_group"] + raw_text = text[e["start"] : e["end"]] + clean_text = re.sub( + r"^[<\[\"“「*(#\s]+|[>\]\"”」*)\s]+$", "", raw_text + ).strip() + if not clean_text: + clean_text = raw_text + + # Special fixes + if group == "timestamp": + clean_text = fix_timestamp(clean_text) + elif group == "speaker": + clean_text = fix_speaker(clean_text) + + if group in ["timestamp", "speaker"] and clean_text: + out["metadata"][group] = clean_text + elif group in ["dialogue", "action", "comment"] and clean_text: + out["content"].append( + { + "type": group, + "content": clean_text, + "confidence": round(float(e["score"]), 3), + } + ) + return out + + def parse_lines(self, texts: List[str]) -> List[Dict[str, Any]]: + return [self.parse_line(text) for text in texts] + + +# =========================== +# 5. 模型路径探测器(关键修复!) +# =========================== + + +def find_model_dir(requested_path: str, default_paths: List[str]) -> str: + """Robustly find model directory""" + # Check requested path first + candidates = [requested_path] + default_paths + + for path in candidates: + if not os.path.isdir(path): + continue + + # Check required files + required = ["config.json", "tokenizer.json"] + has_required = all((Path(path) / f).exists() for f in required) + + # Check model files (safetensors or pytorch) + model_files = ["model.safetensors", "pytorch_model.bin"] + has_model = any((Path(path) / f).exists() for f in model_files) + + if has_required and has_model: + return path + + # If not found, try subdirectories + for path in candidates: + if not os.path.isdir(path): + continue + for root, dirs, files in os.walk(path): + for d in dirs: + full_path = os.path.join(root, d) + has_required = all( + (Path(full_path) / f).exists() + for f in ["config.json", "tokenizer.json"] + ) + has_model = any( + (Path(full_path) / f).exists() + for f in ["model.safetensors", "pytorch_model.bin"] + ) + if has_required and has_model: + return full_path + + raise FileNotFoundError( + f"Model not found in any of: {candidates}\n" + "Required files: config.json, tokenizer.json, and (model.safetensors or pytorch_model.bin)\n" + "👉 Run training first: --train --conll ./data" + ) + + +def export_to_onnx(model_dir: str, onnx_path: str, max_length: int = 128): + """Export model to ONNX format (fixed for local paths)""" + try: + from transformers import AutoTokenizer, AutoModelForTokenClassification + import torch + from torch.onnx import export as onnx_export + import os + + print(f"📤 Exporting model from {model_dir} to {onnx_path}...") + + # ✅ 修复1:确保路径是绝对路径 + model_dir = os.path.abspath(model_dir) + if not os.path.exists(model_dir): + raise FileNotFoundError(f"Model directory not found: {model_dir}") + + # ✅ 修复2:显式指定 local_files_only=True + tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True) + model = AutoModelForTokenClassification.from_pretrained( + model_dir, local_files_only=True + ) + model.eval() + + # Create dummy input + dummy_text = "莎莎 2024-06-08 21:46:26" + inputs = tokenizer( + dummy_text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_length, + ) + + # Ensure directory exists + os.makedirs(os.path.dirname(onnx_path), exist_ok=True) + + # Export to ONNX (使用 opset 18 以兼容现代 PyTorch) + onnx_export( + model, + (inputs["input_ids"], inputs["attention_mask"]), + onnx_path, + export_params=True, + opset_version=18, + do_constant_folding=True, + input_names=["input_ids", "attention_mask"], + output_names=["logits"], + dynamic_axes={ + "input_ids": {0: "batch_size", 1: "sequence_length"}, + "attention_mask": {0: "batch_size", 1: "sequence_length"}, + "logits": {0: "batch_size", 1: "sequence_length"}, + }, + ) + + # Verify ONNX model + import onnx + + onnx_model = onnx.load(onnx_path) + onnx.checker.check_model(onnx_model) + + print( + f"✅ ONNX export successful! Size: {os.path.getsize(onnx_path) / 1024 / 1024:.2f} MB" + ) + return True + + except Exception as e: + print(f"❌ ONNX export failed: {e}") + import traceback + + traceback.print_exc() + return False + + +# =========================== +# 6. CLI 入口 +# =========================== + + +def main(): + import argparse + + parser = argparse.ArgumentParser(description="TRPG NER: Train & Infer") + parser.add_argument("--train", action="store_true", help="Run training") + parser.add_argument( + "--conll", type=str, default="./data", help="Path to .conll files or dir" + ) + parser.add_argument( + "--model", type=str, default="hfl/minirbt-h256", help="Base model" + ) + parser.add_argument( + "--output", type=str, default="./models/trpg-ner-v1", help="Model output dir" + ) + parser.add_argument("--epochs", type=int, default=20) + parser.add_argument("--batch", type=int, default=4) + parser.add_argument( + "--resume", type=str, default=None, help="Resume from checkpoint" + ) + parser.add_argument("--test", type=str, nargs="*", help="Test texts") + parser.add_argument( + "--export_onnx", action="store_true", help="Export model to ONNX" + ) + parser.add_argument( + "--onnx_path", + type=str, + default="./models/trpg-final/model.onnx", + help="ONNX output path", + ) + + args = parser.parse_args() + + if args.train: + try: + train_ner_model( + conll_data=args.conll, + model_name_or_path=args.model, + output_dir=args.output, + num_train_epochs=args.epochs, + per_device_train_batch_size=args.batch, + resume_from_checkpoint=args.resume, + ) + print(f"✅ Training finished. Model saved to {args.output}") + except Exception as e: + print(f"❌ Training failed: {e}") + sys.exit(1) + + # Inference setup + default_model_paths = [ + args.output, + "./models/trpg-ner-v1", + "./models/trpg-ner-v2", + "./models/trpg-ner-v3", + "./cvrp-ner-model", + "./models", + ] + + try: + model_dir = find_model_dir(args.output, default_model_paths) + print(f"Using model from: {model_dir}") + parser = TRPGParser(model_dir=model_dir) + except FileNotFoundError as e: + print(f"{e}") + sys.exit(1) + except Exception as e: + print(f"Failed to load model: {e}") + sys.exit(1) + + # Run inference + if args.test: + for t in args.test: + print(f"\nInput: {t}") + result = parser.parse_line(t) + print("Parse:", result) + else: + # Demo + demo_texts = [ + "风雨 2024-06-08 21:44:59\n剧烈的疼痛从头颅深处一波波地涌出...", + "莎莎 2024-06-08 21:46:26\n“呜哇...”#下意识去拿法杖", + "白麗 霊夢 2024-06-08 21:50:03\n莎莎 的出目是 D10+7=6+7=13", + ] + print("\n🧪 Demo inference (using model from", model_dir, "):") + for i, t in enumerate(demo_texts, 1): + print(f"\nDemo {i}: {t[:50]}...") + result = parser.parse_line(t) + meta = result["metadata"] + content = result["content"] + print(f" Speaker: {meta.get('speaker', 'N/A')}") + print(f" Timestamp: {meta.get('timestamp', 'N/A')}") + if content: + first = content[0] + print( + f" Content: {first['type']}='{first['content'][:40]}{'...' if len(first['content'])>40 else ''}' (conf={first['confidence']})" + ) + + # ONNX export + if args.export_onnx: + # 智能确定模型目录:优先使用推理时找到的目录 + onnx_model_dir = model_dir + success = export_to_onnx( + model_dir=onnx_model_dir, + onnx_path=args.onnx_path, + max_length=128, + ) + if success: + print(f"🎉 ONNX model saved to {args.onnx_path}") + else: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index f6fb85a..0e9a82b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,11 @@ description = "Add your description here" readme = "README.md" requires-python = ">=3.12" dependencies = [ + "gradio>=6.2.0", + "onnx>=1.20.0", + "onnxruntime>=1.23.2", + "onnxscript>=0.5.7", + "pynvml>=13.0.1", "torch>=2.9.1", "transformers>=4.57.3", ] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..bd6d100 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,136 @@ +accelerate==1.12.0 +aiofiles==24.1.0 +aiohappyeyeballs==2.6.1 +aiohttp==3.13.2 +aiosignal==1.4.0 +annotated-doc==0.0.4 +annotated-types==0.7.0 +anyio==4.12.0 +asttokens==3.0.1 +attrs==25.4.0 +brotli==1.2.0 +certifi==2025.11.12 +charset-normalizer==3.4.4 +click==8.3.1 +coloredlogs==15.0.1 +comm==0.2.3 +datasets==4.4.2 +debugpy==1.8.19 +decorator==5.2.1 +dill==0.4.0 +executing==2.2.1 +fastapi==0.128.0 +ffmpy==1.0.0 +filelock==3.20.1 +flatbuffers==25.12.19 +frozenlist==1.8.0 +fsspec==2025.12.0 +gradio==6.2.0 +gradio-client==2.0.2 +groovy==0.1.2 +h11==0.16.0 +hf-xet==1.2.0 +httpcore==1.0.9 +httpx==0.28.1 +huggingface-hub==0.36.0 +humanfriendly==10.0 +idna==3.11 +ipykernel==7.1.0 +ipython==9.8.0 +ipython-pygments-lexers==1.1.1 +jedi==0.19.2 +jinja2==3.1.6 +joblib==1.5.3 +jupyter-client==8.7.0 +jupyter-core==5.9.1 +markdown-it-py==4.0.0 +markupsafe==3.0.3 +matplotlib-inline==0.2.1 +maturin==1.10.2 +mdurl==0.1.2 +ml-dtypes==0.5.4 +mpmath==1.3.0 +multidict==6.7.0 +multiprocess==0.70.18 +nest-asyncio==1.6.0 +networkx==3.6.1 +numpy==2.4.0 +nvidia-cublas-cu12==12.8.4.1 +nvidia-cuda-cupti-cu12==12.8.90 +nvidia-cuda-nvrtc-cu12==12.8.93 +nvidia-cuda-runtime-cu12==12.8.90 +nvidia-cudnn-cu12==9.10.2.21 +nvidia-cufft-cu12==11.3.3.83 +nvidia-cufile-cu12==1.13.1.3 +nvidia-curand-cu12==10.3.9.90 +nvidia-cusolver-cu12==11.7.3.90 +nvidia-cusparse-cu12==12.5.8.93 +nvidia-cusparselt-cu12==0.7.1 +nvidia-ml-py==13.590.44 +nvidia-nccl-cu12==2.27.5 +nvidia-nvjitlink-cu12==12.8.93 +nvidia-nvshmem-cu12==3.3.20 +nvidia-nvtx-cu12==12.8.90 +onnx==1.20.0 +onnx-ir==0.1.13 +onnxruntime==1.23.2 +onnxscript==0.5.7 +orjson==3.11.5 +packaging==25.0 +pandas==2.3.3 +parso==0.8.5 +pexpect==4.9.0 +pillow==12.0.0 +platformdirs==4.5.1 +prompt-toolkit==3.0.52 +propcache==0.4.1 +protobuf==6.33.2 +psutil==7.2.1 +ptyprocess==0.7.0 +pure-eval==0.2.3 +pyarrow==22.0.0 +pydantic==2.12.5 +pydantic-core==2.41.5 +pydub==0.25.1 +pygments==2.19.2 +pynvml==13.0.1 +python-dateutil==2.9.0.post0 +python-multipart==0.0.21 +pytz==2025.2 +pyyaml==6.0.3 +pyzmq==27.1.0 +regex==2025.11.3 +requests==2.32.5 +rich==14.2.0 +safehttpx==0.1.7 +safetensors==0.7.0 +scikit-learn==1.8.0 +scipy==1.16.3 +semantic-version==2.10.0 +seqeval==1.2.2 +setuptools==80.9.0 +shellingham==1.5.4 +six==1.17.0 +stack-data==0.6.3 +starlette==0.50.0 +sympy==1.14.0 +threadpoolctl==3.6.0 +tokenizers==0.22.1 +tomlkit==0.13.3 +torch==2.9.1 +torchaudio==2.9.1+cpu +torchvision==0.24.1+cpu +tornado==6.5.4 +tqdm==4.67.1 +traitlets==5.14.3 +transformers==4.57.3 +triton==3.5.1 +typer==0.21.0 +typing-extensions==4.15.0 +typing-inspection==0.4.2 +tzdata==2025.3 +urllib3==2.6.2 +uvicorn==0.40.0 +wcwidth==0.2.14 +xxhash==3.6.0 +yarl==1.22.0 diff --git a/src/utils/conll_to_dataset.py b/src/utils/conll_to_dataset.py new file mode 100644 index 0000000..2ea5469 --- /dev/null +++ b/src/utils/conll_to_dataset.py @@ -0,0 +1,263 @@ +""" +TRPG CoNLL 转 Dataset 工具 +- 自动检测 word-level / char-level +- 生成 {"text": str, "char_labels": List[str]} +- 支持多文档、跨行实体 +""" + +import os +import re +import json +import argparse +from pathlib import Path +from typing import List, Dict, Any, Tuple +from datasets import Dataset + +def word_to_char_labels(text: str, word_labels: List[Tuple[str, str]]) -> List[str]: + """ + 将 word-level 标注转为 char-level labels + Args: + text: 原始文本 (e.g., "风雨 2024-06-08") + word_labels: [("风雨", "B-speaker"), ("2024-06-08", "B-timestamp"), ...] + Returns: + char_labels: ["B-speaker", "I-speaker", "O", "B-timestamp", ...] + """ + char_labels = ["O"] * len(text) + pos = 0 + + for token, label in word_labels: + if pos >= len(text): + break + + # 在文本中定位 token(处理空格/换行) + while pos < len(text) and text[pos] != token[0]: + pos += 1 + if pos >= len(text): + break + + # 匹配 token + if text[pos:pos+len(token)] == token: + # 标注 B/I + for i, char in enumerate(token): + idx = pos + i + if idx < len(char_labels): + if i == 0 and label.startswith("B-"): + char_labels[idx] = label + elif label.startswith("B-"): + char_labels[idx] = "I" + label[1:] + else: + char_labels[idx] = label + pos += len(token) + else: + pos += 1 + + return char_labels + +def parse_conll_to_samples(filepath: str) -> List[Dict[str, Any]]: + """ + 解析 .conll → [{"text": "...", "char_labels": [...]}, ...] + 自动处理: + - -DOCSTART- 文档边界 + - 空行句子边界 + - word-level → char-level 转换 + """ + samples = [] + current_lines = [] # 存储原始行用于检测粒度 + + with open(filepath, 'r', encoding='utf-8') as f: + for line in f: + current_lines.append(line.rstrip('\n')) + + # 检测是否 word-level + is_word_level = False + for line in current_lines: + if line.strip() and not line.startswith("-DOCSTART-"): + parts = line.split() + if len(parts) >= 4: + token = parts[0] + # 如果 token 长度 >1 且非标点 → 可能是 word-level + if len(token) > 1 and not re.match(r'^[^\w\s\u4e00-\u9fff]+$', token): + is_word_level = True + break + + if is_word_level: + print(f"Detected word-level CoNLL, converting to char-level...") + return _parse_word_conll(filepath) + else: + print(f"Detected char-level CoNLL, parsing directly...") + return _parse_char_conll(filepath) + +def _parse_word_conll(filepath: str) -> List[Dict[str, Any]]: + """解析 word-level .conll(如您提供的原始格式)""" + samples = [] + current_text_parts = [] + current_word_labels = [] + + with open(filepath, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if not line or line.startswith("-DOCSTART-"): + if current_text_parts: + # 合并文本 + text = "".join(current_text_parts) + # 生成 char-level labels + char_labels = word_to_char_labels(text, current_word_labels) + samples.append({ + "text": text, + "char_labels": char_labels + }) + current_text_parts = [] + current_word_labels = [] + continue + + parts = line.split() + if len(parts) < 4: + continue + + token, label = parts[0], parts[3] + current_text_parts.append(token) + current_word_labels.append((token, label)) + + # 处理末尾 + if current_text_parts: + text = "".join(current_text_parts) + char_labels = word_to_char_labels(text, current_word_labels) + samples.append({ + "text": text, + "char_labels": char_labels + }) + + return samples + +def _parse_char_conll(filepath: str) -> List[Dict[str, Any]]: + """解析 char-level .conll""" + samples = [] + current_text = [] + current_labels = [] + + with open(filepath, 'r', encoding='utf-8') as f: + for line in f: + line = line.rstrip('\n') + if line.startswith("-DOCSTART-"): + if current_text: + samples.append({ + "text": "".join(current_text), + "char_labels": current_labels.copy() + }) + current_text, current_labels = [], [] + continue + + if not line: + if current_text: + samples.append({ + "text": "".join(current_text), + "char_labels": current_labels.copy() + }) + current_text, current_labels = [], [] + continue + + parts = line.split() + if len(parts) < 4: + continue + + char = parts[0].replace("\\n", "\n") + label = parts[3] + current_text.append(char) + current_labels.append(label) + + # 末尾处理 + if current_text: + samples.append({ + "text": "".join(current_text), + "char_labels": current_labels.copy() + }) + + return samples + +def save_dataset(samples: List[Dict[str, Any]], output_path: str, format: str = "jsonl"): + """保存数据集""" + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + + if format == "jsonl": + with open(output_path, 'w', encoding='utf-8') as f: + for sample in samples: + f.write(json.dumps(sample, ensure_ascii=False) + '\n') + print(f"Saved {len(samples)} samples to {output_path} (JSONL)") + + elif format == "dataset": + dataset = Dataset.from_list(samples) + dataset.save_to_disk(output_path) + print(f"Saved {len(samples)} samples to {output_path} (Hugging Face Dataset)") + + elif format == "both": + jsonl_path = output_path + ".jsonl" + with open(jsonl_path, 'w', encoding='utf-8') as f: + for sample in samples: + f.write(json.dumps(sample, ensure_ascii=False) + '\n') + print(f"Saved JSONL to {jsonl_path}") + + dataset_path = output_path + "_dataset" + dataset = Dataset.from_list(samples) + dataset.save_to_disk(dataset_path) + print(f"Saved Dataset to {dataset_path}") + +def validate_samples(samples: List[Dict[str, Any]]) -> bool: + """验证样本一致性""" + for i, sample in enumerate(samples): + if len(sample["text"]) != len(sample["char_labels"]): + print(f"Sample {i}: text len={len(sample['text'])}, labels len={len(sample['char_labels'])}") + return False + print(f"All {len(samples)} samples validated: text & labels length match") + return True + +def main(): + parser = argparse.ArgumentParser(description="Convert CoNLL to TRPG Dataset") + parser.add_argument("input", type=str, help="Input .conll file or directory") + parser.add_argument("--output", type=str, default="./dataset/trpg", + help="Output path (without extension)") + parser.add_argument("--format", choices=["jsonl", "dataset", "both"], + default="jsonl", help="Output format") + parser.add_argument("--validate", action="store_true", + help="Validate samples after conversion") + + args = parser.parse_args() + + filepaths = [] + if os.path.isdir(args.input): + filepaths = sorted(Path(args.input).glob("*.conll")) + elif args.input.endswith(".conll"): + filepaths = [Path(args.input)] + else: + raise ValueError("Input must be .conll file or directory") + + if not filepaths: + raise FileNotFoundError(f"No .conll files found in {args.input}") + + print(f"Processing {len(filepaths)} files: {[f.name for f in filepaths]}") + + all_samples = [] + for fp in filepaths: + print(f"\nProcessing {fp.name}...") + samples = parse_conll_to_samples(str(fp)) + print(f" → {len(samples)} samples") + all_samples.extend(samples) + + print(f"\nTotal: {len(all_samples)} samples") + + if args.validate: + if not validate_samples(all_samples): + exit(1) + + save_dataset(all_samples, args.output, args.format) + + label_counts = {} + for sample in all_samples: + for label in sample["char_labels"]: + label_counts[label] = label_counts.get(label, 0) + 1 + + print("\nLabel distribution:") + for label in sorted(label_counts.keys()): + print(f" {label}: {label_counts[label]}") + +if __name__ == "__main__": + main()
\ No newline at end of file diff --git a/src/utils/word_conll_to_char_conll.py b/src/utils/word_conll_to_char_conll.py new file mode 100644 index 0000000..e52405f --- /dev/null +++ b/src/utils/word_conll_to_char_conll.py @@ -0,0 +1,55 @@ +def word_conll_to_char_conll(word_conll_lines: list[str]) -> list[str]: + char_lines = [] + in_new_sample = True # 下一行是否应视为新样本开始 + + for line in word_conll_lines: + stripped = line.strip() + if not stripped: + # 空行 → 标记下一句为新样本 + in_new_sample = True + char_lines.append("") + continue + + parts = stripped.split() + if len(parts) < 4: + char_lines.append(line.rstrip()) + continue + + token, label = parts[0], parts[3] + + # 检测新发言:B-speaker 出现 → 新样本 + if label == "B-speaker" and in_new_sample: + char_lines.append("-DOCSTART- -X- O") + in_new_sample = False + + # 转换 token → char labels(同前) + if label == "O": + for c in token: + char_lines.append(f"{c} -X- _ O") + else: + bio_prefix = label[:2] + tag = label[2:] + for i, c in enumerate(token): + char_label = f"B-{tag}" if (bio_prefix == "B-" and i == 0) else f"I-{tag}" + char_lines.append(f"{c} -X- _ {char_label}") + + return char_lines + +if __name__ == "__main__": + import sys + if len(sys.argv) < 3: + print("Usage: python word_conll_to_char_conll.py <input_word.conll> <output_char.conll>") + sys.exit(1) + + input_fp = sys.argv[1] + output_fp = sys.argv[2] + + with open(input_fp, "r", encoding="utf-8") as f: + word_conll_lines = f.readlines() + + char_conll_lines = word_conll_to_char_conll(word_conll_lines) + + with open(output_fp, "w", encoding="utf-8") as f: + f.write("\n".join(char_conll_lines) + "\n") + + print(f"Converted {input_fp} to character-level CoNLL format at {output_fp}")
\ No newline at end of file diff --git a/tests/onnx_infer.py b/tests/onnx_infer.py new file mode 100644 index 0000000..4ffca25 --- /dev/null +++ b/tests/onnx_infer.py @@ -0,0 +1,115 @@ +import os, sys, json, re +import numpy as np +import onnxruntime as ort +from transformers import AutoTokenizer + +MODEL_DIR = "models/trpg-final" +ONNX_PATH = os.path.join(MODEL_DIR, "model.onnx") + +tok = AutoTokenizer.from_pretrained(MODEL_DIR, local_files_only=True) + +providers = [p for p in ("CUDAExecutionProvider", "CPUExecutionProvider") if p in ort.get_available_providers()] +so = ort.SessionOptions() +so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL +sess = ort.InferenceSession(ONNX_PATH, sess_options=so, providers=providers) + +def softmax(x): + x = x - x.max(axis=-1, keepdims=True) + e = np.exp(x) + return e / e.sum(axis=-1, keepdims=True) + +text = sys.argv[1] if len(sys.argv) > 1 else "风雨 2024-06-08 21:44:59 剧烈的疼痛..." +inputs = tok(text, return_tensors="np", return_offsets_mapping=True, padding="max_length", truncation=True, max_length=128) + +feed = {} +for inp in sess.get_inputs(): + if inp.name in inputs: + feed[inp.name] = inputs[inp.name] + +outs = sess.run(None, feed) +logits = np.asarray(outs[0]) # (batch, seq_len, num_labels) +probs = softmax(logits) + +ids = inputs["input_ids"][0] +offsets = inputs["offset_mapping"][0] +attn = inputs["attention_mask"][0] +tokens = tok.convert_ids_to_tokens(ids) + +print("Raw logits shape:", logits.shape) +# print("\nPer-token raw logits (token : [..first 8 logits..])") +# for i, (t, l, a) in enumerate(zip(tokens, logits[0], attn)): +# if not a: +# continue +# print(f"{i:03d}", t, "->", np.around(l[:8], 4).tolist()) + +pred_ids = logits.argmax(-1)[0] +pred_probs = probs[0, np.arange(probs.shape[1]), pred_ids] + +with open(os.path.join(MODEL_DIR, "config.json"), "r", encoding="utf-8") as f: + cfg = json.load(f) +id2label = {int(k): v for k, v in cfg.get("id2label", {}).items()} + +print("\nPer-token predictions (token \\t label \\t prob):") +for i, (t, pid, pprob, a) in enumerate(zip(tokens, pred_ids, pred_probs, attn)): + if not a: + continue + lab = id2label.get(int(pid), "O") + print(f"{t}\t{lab}\t{pprob:.3f}") + +# 聚合实体 +entities = [] +cur = None +for i, (pid, pprob, off, a) in enumerate(zip(pred_ids, pred_probs, offsets, attn)): + if not a or (off[0] == off[1] == 0): + if cur: + entities.append(cur); cur = None + continue + label = id2label.get(int(pid), "O") + if label == "O": + if cur: + entities.append(cur); cur = None + continue + if label.startswith("B-") or cur is None or label[2:] != cur["type"]: + if cur: + entities.append(cur) + cur = {"type": label[2:], "tokens": [i], "start": int(off[0]), "end": int(off[1]), "probs":[float(pprob)]} + else: + cur["tokens"].append(i) + cur["end"] = int(off[1]) + cur["probs"].append(float(pprob)) +if cur: + entities.append(cur) + +def fix_timestamp(ts): + if not ts: return ts + m = re.match(r"^(\d{1,2})-(\d{2})-(\d{2})(.*)", ts) + if m: + y, mo, d, rest = m.groups() + if len(y)==1: y="202"+y + elif len(y)==2: y="20"+y + return f"{y}-{mo}-{d}{rest}" + return ts +def fix_speaker(spk): + if not spk: return spk + spk = re.sub(r"[^\w\s\u4e00-\u9fff]+$", "", spk) + if len(spk)==1 and re.match(r"^[风雷电雨雪火水木金]", spk): + return spk+"某" + return spk + +out = {"metadata": {}, "content": []} +for e in entities: + s, epos = e["start"], e["end"] + ent_text = text[s:epos] + conf = round(float(np.mean(e["probs"])), 3) + typ = e["type"] + if typ in ("timestamp", "speaker"): + if typ=="timestamp": + ent_text = fix_timestamp(ent_text) + else: + ent_text = fix_speaker(ent_text) + out["metadata"][typ] = ent_text + else: + out["content"].append({"type": typ, "content": ent_text, "confidence": conf}) + +print("\nConstructed JSON:") +print(json.dumps(out, ensure_ascii=False, indent=2))
\ No newline at end of file |
