summaryrefslogtreecommitdiffstatshomepage
path: root/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'main.py')
-rw-r--r--main.py632
1 files changed, 629 insertions, 3 deletions
diff --git a/main.py b/main.py
index de058e7..2f96207 100644
--- a/main.py
+++ b/main.py
@@ -1,4 +1,630 @@
-from transformers import BertTokenizer, BertModel
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+TRPG NER 训练与推理脚本 (Robust Edition)
+- 自动探测模型路径(支持 safetensors/pytorch)
+- 统一 tokenizer 行为(offset_mapping)
+- 智能修复 timestamp/speaker 截断
+- 兼容 transformers <4.5
+"""
-tokenizer = BertTokenizer.from_pretrained("hfl/minirbt-h256")
-model = BertModel.from_pretrained("hfl/minirbt-h256") \ No newline at end of file
+import os
+import re
+import glob
+import sys
+from typing import List, Dict, Any, Tuple, Optional
+from pathlib import Path
+
+import torch
+from transformers import (
+ AutoTokenizer,
+ AutoModelForTokenClassification,
+ TrainingArguments,
+ Trainer,
+ pipeline,
+ logging as hf_logging,
+)
+from datasets import Dataset
+from tqdm.auto import tqdm
+
+# 抑制 transformers 警告
+hf_logging.set_verbosity_error()
+
+
+# ===========================
+# 1. CoNLL 解析器(自动 word→char 转换)
+# ===========================
+
+
def word_to_char_labels(text: str, word_labels: List[Tuple[str, str]]) -> List[str]:
    """Convert word-level BIO labels to character-level BIO labels.

    Walks ``text`` left to right, locating each labelled token in order and
    skipping separator characters (spaces/newlines) that sit between tokens.
    For a ``B-X`` word label the first character keeps ``B-X`` and the rest
    become ``I-X``; any other label (``I-X``/``O``) is copied to every
    character of the token.

    Args:
        text: The full text the tokens were taken from.
        word_labels: ``(token, label)`` pairs in document order.

    Returns:
        A list of per-character labels, ``len(text)`` long; characters not
        covered by any token stay ``"O"``.
    """
    char_labels = ["O"] * len(text)
    pos = 0

    for token, label in word_labels:
        # Fixed: an empty token crashed on token[0] below; it can never be
        # located in the text, so skip it.
        if not token:
            continue
        if pos >= len(text):
            break

        # Locate the token's first character, skipping separators.
        while pos < len(text) and text[pos] != token[0]:
            pos += 1
        if pos >= len(text):
            break

        if text[pos : pos + len(token)] == token:
            for i, _ in enumerate(token):
                idx = pos + i
                if idx < len(char_labels):
                    if i == 0 and label.startswith("B-"):
                        char_labels[idx] = label
                    elif label.startswith("B-"):
                        # Continuation characters of a B- word become I-.
                        char_labels[idx] = "I" + label[1:]
                    else:
                        char_labels[idx] = label
            pos += len(token)
        else:
            # First character matched but the full token did not; advance one
            # character and move on to the next (token, label) pair.
            pos += 1

    return char_labels
+
+
def parse_conll_file(filepath: str) -> List[Dict[str, Any]]:
    """Parse a .conll file into ``[{"text": str, "char_labels": List[str]}]``.

    Two layouts are auto-detected:
      * word-level: multi-character tokens, label in column 4; labels are
        expanded to char-level via ``word_to_char_labels()``.
      * char-level: one character per line (newlines stored as the
        two-character escape ``\\n``), label in column 4.

    Samples are delimited by blank lines and/or ``-DOCSTART-`` markers.
    """
    with open(filepath, "r", encoding="utf-8") as f:
        lines = [line.rstrip("\n") for line in f]

    # Detect word-level input: any token longer than one character.
    # Fixed: the char-level newline escape token "\\n" is two characters
    # long and previously tripped this check, misrouting char-level files
    # into the word-level parser.
    is_word_level = any(
        len(line.split()[0]) > 1 and line.split()[0] != "\\n"
        for line in lines
        if line.strip() and not line.startswith("-DOCSTART-") and len(line.split()) >= 4
    )

    samples = []
    if is_word_level:
        # Word-level parsing: accumulate (token, label) pairs per sample.
        current_text_parts = []
        current_word_labels = []

        for line in lines:
            if not line or line.startswith("-DOCSTART-"):
                if current_text_parts:
                    text = "".join(current_text_parts)
                    char_labels = word_to_char_labels(text, current_word_labels)
                    samples.append({"text": text, "char_labels": char_labels})
                    current_text_parts = []
                    current_word_labels = []
                continue

            parts = line.split()
            if len(parts) >= 4:
                token, label = parts[0], parts[3]
                current_text_parts.append(token)
                current_word_labels.append((token, label))

        # Flush the trailing sample (file may not end with a blank line).
        if current_text_parts:
            text = "".join(current_text_parts)
            char_labels = word_to_char_labels(text, current_word_labels)
            samples.append({"text": text, "char_labels": char_labels})

    else:
        # Char-level parsing: one character (or "\\n" escape) per line.
        current_text = []
        current_labels = []

        for line in lines:
            if line.startswith("-DOCSTART-"):
                if current_text:
                    samples.append(
                        {
                            "text": "".join(current_text),
                            "char_labels": current_labels.copy(),
                        }
                    )
                    current_text, current_labels = [], []
                continue

            if not line:
                if current_text:
                    samples.append(
                        {
                            "text": "".join(current_text),
                            "char_labels": current_labels.copy(),
                        }
                    )
                    current_text, current_labels = [], []
                continue

            parts = line.split()
            if len(parts) >= 4:
                # Undo the newline escape used when the corpus was written.
                char = parts[0].replace("\\n", "\n")
                label = parts[3]
                current_text.append(char)
                current_labels.append(label)

        # Flush the trailing sample.
        if current_text:
            samples.append(
                {"text": "".join(current_text), "char_labels": current_labels.copy()}
            )

    return samples
+
+
def load_conll_dataset(conll_dir_or_files: str) -> Tuple[Dataset, List[str]]:
    """Load one or more .conll files and build a labelled Dataset.

    Accepts either a directory (every ``*.conll`` file inside, sorted) or a
    single ``.conll`` file path.

    Returns:
        ``(dataset, label_list)`` where ``label_list`` starts with ``"O"``,
        contains every B-/I- tag seen in sorted order, and is completed with
        any missing I- counterpart of a B- tag.

    Raises:
        ValueError: if the path is neither a directory nor a .conll file.
        FileNotFoundError: if no .conll files are found.
    """
    if os.path.isdir(conll_dir_or_files):
        filepaths = sorted(glob.glob(os.path.join(conll_dir_or_files, "*.conll")))
    elif conll_dir_or_files.endswith(".conll"):
        filepaths = [conll_dir_or_files]
    else:
        raise ValueError("conll_dir_or_files must be .conll file or directory")

    if not filepaths:
        raise FileNotFoundError(f"No .conll files found in {conll_dir_or_files}")

    print(f"Loading {len(filepaths)} conll files: {filepaths}")

    all_samples: List[Dict[str, Any]] = []
    seen_labels = {"O"}
    for fp in tqdm(filepaths, desc="Parsing .conll"):
        for sample in parse_conll_file(fp):
            all_samples.append(sample)
            seen_labels.update(sample["char_labels"])

    # "O" first, then every BIO tag in sorted order.
    label_list = ["O"] + [
        lab for lab in sorted(seen_labels - {"O"}) if lab.startswith(("B-", "I-"))
    ]
    # Guarantee each B-X has a matching I-X so the model can tag spans > 1.
    for lab in list(label_list):
        if lab.startswith("B-"):
            i_label = "I" + lab[1:]
            if i_label not in label_list:
                label_list.append(i_label)
                print(f"⚠️ Added missing {i_label} for {lab}")

    print(
        f"✅ Loaded {len(all_samples)} samples, {len(label_list)} labels: {label_list}"
    )
    return Dataset.from_list(all_samples), label_list
+
+
+# ===========================
+# 2. 数据预处理(offset_mapping 对齐)
+# ===========================
+
+
def tokenize_and_align_labels(examples, tokenizer, label2id, max_length=128):
    """Tokenize a batch of texts and align char-level labels to tokens.

    Relies on the fast tokenizer's offset mapping: tokens with an empty
    offset span (special/padding tokens) get ``-100`` so the loss ignores
    them; every real token takes the label of its first character.

    Args:
        examples: Batch dict with ``"text"`` and ``"char_labels"`` columns.
        tokenizer: A callable tokenizer that returns ``offset_mapping``.
        label2id: Mapping from label string to integer id.
        max_length: Truncation length.

    Returns:
        The tokenizer output dict with an added ``"labels"`` column.
    """
    tokenized = tokenizer(
        examples["text"],
        truncation=True,
        padding=True,
        max_length=max_length,
        return_offsets_mapping=True,
        return_tensors=None,
    )

    aligned = []
    for char_labels, offsets in zip(
        examples["char_labels"], tokenized["offset_mapping"]
    ):
        aligned.append(
            [
                -100 if start == end else label2id[char_labels[start]]
                for start, end in offsets
            ]
        )

    tokenized["labels"] = aligned
    return tokenized
+
+
+# ===========================
+# 3. 训练流程
+# ===========================
+
+
def train_ner_model(
    conll_data: str,
    model_name_or_path: str = "hfl/minirbt-h256",
    output_dir: str = "./models/trpg-ner-v1",
    num_train_epochs: int = 15,
    per_device_train_batch_size: int = 4,
    learning_rate: float = 5e-5,
    max_length: int = 128,
    resume_from_checkpoint: Optional[str] = None,
):
    """Fine-tune a token-classification model on CoNLL data.

    Args:
        conll_data: A .conll file or a directory of .conll files.
        model_name_or_path: Base checkpoint to fine-tune.
        output_dir: Where checkpoints and the final model are written.
        num_train_epochs: Number of training epochs.
        per_device_train_batch_size: Per-device batch size.
        learning_rate: AdamW learning rate.
        max_length: Maximum token sequence length.
        resume_from_checkpoint: Optional checkpoint directory to resume from.

    Returns:
        ``(model, tokenizer, label_list)``.
    """
    # Function-scope import; the file already depends on transformers.
    from transformers import DataCollatorForTokenClassification

    # Step 1: Load data
    dataset, label_list = load_conll_dataset(conll_data)
    label2id = {label: i for i, label in enumerate(label_list)}
    id2label = {i: label for i, label in enumerate(label_list)}

    # Step 2: Init model
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    if tokenizer.model_max_length > 1000:
        # Some checkpoints report a huge sentinel max length; clamp it.
        tokenizer.model_max_length = max_length

    model = AutoModelForTokenClassification.from_pretrained(
        model_name_or_path,
        num_labels=len(label_list),
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    )

    # Step 3: Tokenize
    tokenized_dataset = dataset.map(
        lambda ex: tokenize_and_align_labels(ex, tokenizer, label2id, max_length),
        batched=True,
        remove_columns=["text", "char_labels"],
    )

    # Step 4: Training args (compatible with old transformers)
    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=learning_rate,
        per_device_train_batch_size=per_device_train_batch_size,
        num_train_epochs=num_train_epochs,
        logging_steps=5,
        save_steps=200,
        save_total_limit=2,
        do_eval=False,
        report_to="none",
        no_cuda=not torch.cuda.is_available(),
        load_best_model_at_end=False,
        push_to_hub=False,
        fp16=torch.cuda.is_available(),
    )

    # Fixed: sequences are padded per map() chunk above, so different chunks
    # can yield different lengths; the default collator cannot pad the ragged
    # "labels" column and fails on mixed-length batches.  The token-
    # classification collator pads both inputs and labels (labels with -100).
    data_collator = DataCollatorForTokenClassification(tokenizer)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    print("🚀 Starting training...")
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    print("💾 Saving final model...")
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    return model, tokenizer, label_list
+
+
+# ===========================
+# 4. 推理函数(增强版)
+# ===========================
+
+
def fix_timestamp(ts: str) -> str:
    """Restore a truncated year in a timestamp: '4-06-08' → '2024-06-08'.

    One-digit years are assumed to be truncations of 202X, two-digit years
    of 20XX.  Anything that does not start with a short-year ``Y-MM-DD``
    pattern (including already-complete timestamps) is returned unchanged.
    """
    if not ts:
        return ts
    match = re.match(r"^(\d{1,2})-(\d{2})-(\d{2})(.*)", ts)
    if not match:
        return ts
    year, month, day, tail = match.groups()
    # The regex only allows 1- or 2-digit years, so the default is unused.
    prefix = {1: "202", 2: "20"}.get(len(year), "")
    return f"{prefix}{year}-{month}-{day}{tail}"
+
+
def fix_speaker(spk: str) -> str:
    """Clean a truncated or garbled speaker name.

    Trailing punctuation is stripped; a lone character from a known set of
    single-char handles gets a '某' placeholder appended so the name reads
    as a two-character Chinese name.
    """
    if not spk:
        return spk
    cleaned = re.sub(r"[^\w\s\u4e00-\u9fff]+$", "", spk)
    lone_handle = len(cleaned) == 1 and re.match(r"^[风雷电雨雪火水木金]", cleaned)
    return cleaned + "某" if lone_handle else cleaned
+
+
class TRPGParser:
    """NER-based parser for TRPG chat-log lines.

    Wraps a trained token-classification model in a HF pipeline and turns
    its entity spans into ``{"metadata": {...}, "content": [...]}`` records.
    """

    def __init__(self, model_dir: str, debug: bool = False):
        """Load the model and tokenizer from ``model_dir``.

        Args:
            model_dir: Local directory containing a trained model.
            debug: When True, print the raw pipeline entities per line.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForTokenClassification.from_pretrained(model_dir)
        self.id2label = self.model.config.id2label
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        # Fixed: the raw-entity dump used to print unconditionally on every
        # parse; it is now opt-in via this flag.
        self.debug = debug

        self.nlp = pipeline(
            "token-classification",
            model=self.model,
            tokenizer=self.tokenizer,
            aggregation_strategy="simple",
            device=0 if torch.cuda.is_available() else -1,
        )

    def parse_line(self, text: str) -> Dict[str, Any]:
        """Parse one chat line into metadata (speaker/timestamp) and content."""
        ents = self.nlp(text)
        if getattr(self, "debug", False):
            print(f"[DEBUG] Raw entities: {ents}")
        out = {"metadata": {}, "content": []}

        # Merge adjacent/overlapping entities of the same group, keeping the
        # most pessimistic (minimum) confidence for the merged span.
        merged = []
        for e in sorted(ents, key=lambda x: x["start"]):
            if merged and merged[-1]["entity_group"] == e["entity_group"]:
                if e["start"] <= merged[-1]["end"]:
                    merged[-1]["end"] = max(merged[-1]["end"], e["end"])
                    merged[-1]["score"] = min(merged[-1]["score"], e["score"])
                    continue
            merged.append(e)

        for e in merged:
            group = e["entity_group"]
            raw_text = text[e["start"] : e["end"]]
            # Strip surrounding quotes/brackets/markers; fall back to the
            # raw span if stripping removed everything.
            clean_text = re.sub(
                r"^[<\[\"“「*(#\s]+|[>\]\"”」*)\s]+$", "", raw_text
            ).strip()
            if not clean_text:
                clean_text = raw_text

            # Heuristic repairs for spans the model tends to truncate.
            if group == "timestamp":
                clean_text = fix_timestamp(clean_text)
            elif group == "speaker":
                clean_text = fix_speaker(clean_text)

            if group in ["timestamp", "speaker"] and clean_text:
                out["metadata"][group] = clean_text
            elif group in ["dialogue", "action", "comment"] and clean_text:
                out["content"].append(
                    {
                        "type": group,
                        "content": clean_text,
                        "confidence": round(float(e["score"]), 3),
                    }
                )
        return out

    def parse_lines(self, texts: List[str]) -> List[Dict[str, Any]]:
        """Parse multiple lines; see :meth:`parse_line`."""
        return [self.parse_line(text) for text in texts]
+
+
+# ===========================
+# 5. 模型路径探测器(关键修复!)
+# ===========================
+
+
def find_model_dir(requested_path: str, default_paths: List[str]) -> str:
    """Locate a usable model directory.

    Checks ``requested_path`` first, then each default path, then every
    subdirectory beneath those paths.  A directory qualifies when it holds
    all tokenizer/config files plus at least one weight file (safetensors
    or pytorch).

    Raises:
        FileNotFoundError: if no candidate directory qualifies.
    """
    REQUIRED = ("config.json", "tokenizer.json")
    WEIGHTS = ("model.safetensors", "pytorch_model.bin")

    def _is_model_dir(path: str) -> bool:
        # Needs every required file and at least one weight file.
        base = Path(path)
        return all((base / f).exists() for f in REQUIRED) and any(
            (base / f).exists() for f in WEIGHTS
        )

    candidates = [requested_path] + default_paths

    # Pass 1: the candidate directories themselves.
    for path in candidates:
        if os.path.isdir(path) and _is_model_dir(path):
            return path

    # Pass 2: any subdirectory beneath a candidate (e.g. a checkpoint dir).
    for path in candidates:
        if not os.path.isdir(path):
            continue
        for root, dirs, _files in os.walk(path):
            for d in dirs:
                full_path = os.path.join(root, d)
                if _is_model_dir(full_path):
                    return full_path

    raise FileNotFoundError(
        f"Model not found in any of: {candidates}\n"
        "Required files: config.json, tokenizer.json, and (model.safetensors or pytorch_model.bin)\n"
        "👉 Run training first: --train --conll ./data"
    )
+
+
def export_to_onnx(model_dir: str, onnx_path: str, max_length: int = 128):
    """Export a trained token-classification model to ONNX.

    Loads the model strictly from local files, traces it with a dummy
    input, writes ``onnx_path`` and validates the result with onnx.checker.

    Args:
        model_dir: Local directory with the trained model.
        onnx_path: Destination .onnx file path.
        max_length: Sequence length used for the dummy tracing input.

    Returns:
        True on success, False on any failure (error printed to stdout).
    """
    try:
        from torch.onnx import export as onnx_export

        print(f"📤 Exporting model from {model_dir} to {onnx_path}...")

        # Resolve to an absolute path so HF doesn't mistake it for a hub id.
        model_dir = os.path.abspath(model_dir)
        if not os.path.exists(model_dir):
            raise FileNotFoundError(f"Model directory not found: {model_dir}")

        # local_files_only prevents an accidental hub download attempt.
        tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True)
        model = AutoModelForTokenClassification.from_pretrained(
            model_dir, local_files_only=True
        )
        model.eval()

        # Dummy input for tracing.
        dummy_text = "莎莎 2024-06-08 21:46:26"
        inputs = tokenizer(
            dummy_text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=max_length,
        )

        # Fixed: os.makedirs("") raises when onnx_path has no directory
        # component (e.g. "model.onnx"); only create a dir when one exists.
        out_dir = os.path.dirname(onnx_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # opset 18 for compatibility with modern PyTorch exporters.
        onnx_export(
            model,
            (inputs["input_ids"], inputs["attention_mask"]),
            onnx_path,
            export_params=True,
            opset_version=18,
            do_constant_folding=True,
            input_names=["input_ids", "attention_mask"],
            output_names=["logits"],
            dynamic_axes={
                "input_ids": {0: "batch_size", 1: "sequence_length"},
                "attention_mask": {0: "batch_size", 1: "sequence_length"},
                "logits": {0: "batch_size", 1: "sequence_length"},
            },
        )

        # Validate the exported graph before declaring success.
        import onnx

        onnx_model = onnx.load(onnx_path)
        onnx.checker.check_model(onnx_model)

        print(
            f"✅ ONNX export successful! Size: {os.path.getsize(onnx_path) / 1024 / 1024:.2f} MB"
        )
        return True

    except Exception as e:
        print(f"❌ ONNX export failed: {e}")
        import traceback

        traceback.print_exc()
        return False
+
+
+# ===========================
+# 6. CLI 入口
+# ===========================
+
+
def main() -> None:
    """CLI entry point.

    Optionally trains (``--train``), then always loads a model for
    inference (running ``--test`` texts or the built-in demo lines), and
    optionally exports the loaded model to ONNX (``--export_onnx``).
    Exits with status 1 on training, model-loading, or export failure.
    """
    import argparse

    parser = argparse.ArgumentParser(description="TRPG NER: Train & Infer")
    parser.add_argument("--train", action="store_true", help="Run training")
    parser.add_argument(
        "--conll", type=str, default="./data", help="Path to .conll files or dir"
    )
    parser.add_argument(
        "--model", type=str, default="hfl/minirbt-h256", help="Base model"
    )
    parser.add_argument(
        "--output", type=str, default="./models/trpg-ner-v1", help="Model output dir"
    )
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--batch", type=int, default=4)
    parser.add_argument(
        "--resume", type=str, default=None, help="Resume from checkpoint"
    )
    parser.add_argument("--test", type=str, nargs="*", help="Test texts")
    parser.add_argument(
        "--export_onnx", action="store_true", help="Export model to ONNX"
    )
    parser.add_argument(
        "--onnx_path",
        type=str,
        default="./models/trpg-final/model.onnx",
        help="ONNX output path",
    )

    args = parser.parse_args()

    if args.train:
        try:
            train_ner_model(
                conll_data=args.conll,
                model_name_or_path=args.model,
                output_dir=args.output,
                num_train_epochs=args.epochs,
                per_device_train_batch_size=args.batch,
                resume_from_checkpoint=args.resume,
            )
            print(f"✅ Training finished. Model saved to {args.output}")
        except Exception as e:
            print(f"❌ Training failed: {e}")
            sys.exit(1)

    # Inference setup: candidate model directories, preferred ones first.
    default_model_paths = [
        args.output,
        "./models/trpg-ner-v1",
        "./models/trpg-ner-v2",
        "./models/trpg-ner-v3",
        "./cvrp-ner-model",
        "./models",
    ]

    try:
        model_dir = find_model_dir(args.output, default_model_paths)
        print(f"Using model from: {model_dir}")
        # NOTE: this rebinding shadows the argparse `parser` above; args has
        # already been parsed, so it is harmless.
        parser = TRPGParser(model_dir=model_dir)
    except FileNotFoundError as e:
        print(f"{e}")
        sys.exit(1)
    except Exception as e:
        print(f"Failed to load model: {e}")
        sys.exit(1)

    # Run inference on user-supplied texts, or fall back to demo lines.
    if args.test:
        for t in args.test:
            print(f"\nInput: {t}")
            result = parser.parse_line(t)
            print("Parse:", result)
    else:
        # Demo
        demo_texts = [
            "风雨 2024-06-08 21:44:59\n剧烈的疼痛从头颅深处一波波地涌出...",
            "莎莎 2024-06-08 21:46:26\n“呜哇...”#下意识去拿法杖",
            "白麗 霊夢 2024-06-08 21:50:03\n莎莎 的出目是 D10+7=6+7=13",
        ]
        print("\n🧪 Demo inference (using model from", model_dir, "):")
        for i, t in enumerate(demo_texts, 1):
            print(f"\nDemo {i}: {t[:50]}...")
            result = parser.parse_line(t)
            meta = result["metadata"]
            content = result["content"]
            print(f" Speaker: {meta.get('speaker', 'N/A')}")
            print(f" Timestamp: {meta.get('timestamp', 'N/A')}")
            if content:
                first = content[0]
                print(
                    f" Content: {first['type']}='{first['content'][:40]}{'...' if len(first['content'])>40 else ''}' (conf={first['confidence']})"
                )

    # ONNX export
    if args.export_onnx:
        # Prefer the model directory that was resolved for inference above.
        onnx_model_dir = model_dir
        success = export_to_onnx(
            model_dir=onnx_model_dir,
            onnx_path=args.onnx_path,
            max_length=128,
        )
        if success:
            print(f"🎉 ONNX model saved to {args.onnx_path}")
        else:
            sys.exit(1)
+
+
+if __name__ == "__main__":
+ main()