Diffstat (limited to 'main.py')
| -rw-r--r-- | main.py | 632 |
1 file changed, 629 insertions, 3 deletions
@@ -1,4 +1,630 @@
-from transformers import BertTokenizer, BertModel
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+TRPG NER training & inference script (Robust Edition)
+- Auto-detects the model path (supports safetensors/pytorch)
+- Unified tokenizer behavior (offset_mapping)
+- Heuristic repair of truncated timestamp/speaker fields
+- Compatible with transformers <4.5
+"""
-
-tokenizer = BertTokenizer.from_pretrained("hfl/minirbt-h256")
-model = BertModel.from_pretrained("hfl/minirbt-h256")
\ No newline at end of file
+import os
+import re
+import glob
+import sys
+from typing import List, Dict, Any, Tuple, Optional
+from pathlib import Path
+
+import torch
+from transformers import (
+    AutoTokenizer,
+    AutoModelForTokenClassification,
+    TrainingArguments,
+    Trainer,
+    pipeline,
+    logging as hf_logging,
+)
+from datasets import Dataset
+from tqdm.auto import tqdm
+
+# Suppress transformers warnings
+hf_logging.set_verbosity_error()
+
+
+# ===========================
+# 1. CoNLL parser (automatic word→char conversion)
+# ===========================
+
+
+def word_to_char_labels(text: str, word_labels: List[Tuple[str, str]]) -> List[str]:
+    """Convert word-level labels to char-level"""
+    char_labels = ["O"] * len(text)
+    pos = 0
+
+    for token, label in word_labels:
+        if pos >= len(text):
+            break
+
+        # Locate the token (skipping spaces/newlines)
+        while pos < len(text) and text[pos] != token[0]:
+            pos += 1
+        if pos >= len(text):
+            break
+
+        if text[pos : pos + len(token)] == token:
+            for i, _ in enumerate(token):
+                idx = pos + i
+                if idx < len(char_labels):
+                    if i == 0 and label.startswith("B-"):
+                        char_labels[idx] = label
+                    elif label.startswith("B-"):
+                        char_labels[idx] = "I" + label[1:]
+                    else:
+                        char_labels[idx] = label
+            pos += len(token)
+        else:
+            pos += 1
+
+    return char_labels
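+
+# Illustrative example (labels assumed, not taken from the real data):
+#   word_to_char_labels("风雨你好", [("风雨", "B-speaker"), ("你好", "O")])
+#   → ["B-speaker", "I-speaker", "O", "O"]
+# Only the first char of a B- token keeps its B- tag; the rest are demoted
+# to I- so the char-level sequence stays valid BIO.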
+
+
+def parse_conll_file(filepath: str) -> List[Dict[str, Any]]:
+    """Parse .conll → [{"text": str, "char_labels": List[str]}]"""
+    with open(filepath, "r", encoding="utf-8") as f:
+        lines = [line.rstrip("\n") for line in f.readlines()]
+
+    # Detect word-level format ("\\n" is the escaped-newline pseudo-token used
+    # by char-level files, so it must not count as a multi-char word)
+    is_word_level = any(
+        len(line.split()[0]) > 1 and line.split()[0] != "\\n"
+        for line in lines
+        if line.strip() and not line.startswith("-DOCSTART-") and len(line.split()) >= 4
+    )
+
+    samples = []
+    if is_word_level:
+        # Word-level parsing
+        current_text_parts = []
+        current_word_labels = []
+
+        for line in lines:
+            if not line or line.startswith("-DOCSTART-"):
+                if current_text_parts:
+                    text = "".join(current_text_parts)
+                    char_labels = word_to_char_labels(text, current_word_labels)
+                    samples.append({"text": text, "char_labels": char_labels})
+                    current_text_parts = []
+                    current_word_labels = []
+                continue
+
+            parts = line.split()
+            if len(parts) >= 4:
+                token, label = parts[0], parts[3]
+                current_text_parts.append(token)
+                current_word_labels.append((token, label))
+
+        if current_text_parts:
+            text = "".join(current_text_parts)
+            char_labels = word_to_char_labels(text, current_word_labels)
+            samples.append({"text": text, "char_labels": char_labels})
+
+    else:
+        # Char-level parsing
+        current_text = []
+        current_labels = []
+
+        for line in lines:
+            if line.startswith("-DOCSTART-"):
+                if current_text:
+                    samples.append(
+                        {
+                            "text": "".join(current_text),
+                            "char_labels": current_labels.copy(),
+                        }
+                    )
+                    current_text, current_labels = [], []
+                continue
+
+            if not line:
+                if current_text:
+                    samples.append(
+                        {
+                            "text": "".join(current_text),
+                            "char_labels": current_labels.copy(),
+                        }
+                    )
+                    current_text, current_labels = [], []
+                continue
+
+            parts = line.split()
+            if len(parts) >= 4:
+                char = parts[0].replace("\\n", "\n")
+                label = parts[3]
+                current_text.append(char)
+                current_labels.append(label)
+
+        if current_text:
+            samples.append(
+                {"text": "".join(current_text), "char_labels": current_labels.copy()}
+            )
+
+    return samples
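+
+# Expected column layout (assumed from the parts[0]/parts[3] access above):
+# four whitespace-separated fields with the token in column 1 and the BIO tag
+# in column 4. A hypothetical char-level sample:
+#   风 _ _ B-speaker
+#   雨 _ _ I-speaker
+#   \n _ _ O
+# Blank lines or -DOCSTART- markers close the current sample.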
+
+
+def load_conll_dataset(conll_dir_or_files: str) -> Tuple[Dataset, List[str]]:
+    """Load .conll files → Dataset"""
+    filepaths = []
+    if os.path.isdir(conll_dir_or_files):
+        filepaths = sorted(glob.glob(os.path.join(conll_dir_or_files, "*.conll")))
+    elif conll_dir_or_files.endswith(".conll"):
+        filepaths = [conll_dir_or_files]
+    else:
+        raise ValueError("conll_dir_or_files must be a .conll file or a directory")
+
+    if not filepaths:
+        raise FileNotFoundError(f"No .conll files found in {conll_dir_or_files}")
+
+    print(f"Loading {len(filepaths)} conll files: {filepaths}")
+
+    all_samples = []
+    label_set = {"O"}
+
+    for fp in tqdm(filepaths, desc="Parsing .conll"):
+        samples = parse_conll_file(fp)
+        for s in samples:
+            all_samples.append(s)
+            label_set.update(s["char_labels"])
+
+    # Build the label list, adding any I- tag whose B- counterpart exists
+    label_list = ["O"]
+    for label in sorted(label_set - {"O"}):
+        if label.startswith("B-") or label.startswith("I-"):
+            label_list.append(label)
+    for label in list(label_list):
+        if label.startswith("B-"):
+            i_label = "I" + label[1:]
+            if i_label not in label_list:
+                label_list.append(i_label)
+                print(f"⚠️ Added missing {i_label} for {label}")
+
+    print(
+        f"✅ Loaded {len(all_samples)} samples, {len(label_list)} labels: {label_list}"
+    )
+    return Dataset.from_list(all_samples), label_list
+
+
+# ===========================
+# 2. Preprocessing (offset_mapping alignment)
+# ===========================
+
+
+def tokenize_and_align_labels(examples, tokenizer, label2id, max_length=128):
+    tokenized = tokenizer(
+        examples["text"],
+        truncation=True,
+        padding="max_length",  # fixed-length padding so label rows collate cleanly
+        max_length=max_length,
+        return_offsets_mapping=True,
+        return_tensors=None,
+    )
+
+    labels = []
+    for i, label_seq in enumerate(examples["char_labels"]):
+        offsets = tokenized["offset_mapping"][i]
+        label_ids = []
+        for start, end in offsets:
+            if start == end:  # special and padding tokens
+                label_ids.append(-100)
+            else:
+                label_ids.append(label2id[label_seq[start]])
+        labels.append(label_ids)
+
+    tokenized["labels"] = labels
+    return tokenized
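+
+# Alignment sketch, assuming a BERT-style fast tokenizer on Chinese text
+# (trailing padding offsets omitted):
+#   text = "风雨", offsets = [(0, 0), (0, 1), (1, 2), (0, 0)]  # [CLS] 风 雨 [SEP]
+#   labels = [-100, label2id["B-speaker"], label2id["I-speaker"], -100]
+# Each kept token inherits the label of its first character; (0, 0) offsets
+# mark special/padding tokens and are masked out with -100.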
+
+
+# ===========================
+# 3. Training pipeline
+# ===========================
+
+
+def train_ner_model(
+    conll_data: str,
+    model_name_or_path: str = "hfl/minirbt-h256",
+    output_dir: str = "./models/trpg-ner-v1",
+    num_train_epochs: int = 15,
+    per_device_train_batch_size: int = 4,
+    learning_rate: float = 5e-5,
+    max_length: int = 128,
+    resume_from_checkpoint: Optional[str] = None,
+):
+    # Step 1: Load data
+    dataset, label_list = load_conll_dataset(conll_data)
+    label2id = {label: i for i, label in enumerate(label_list)}
+    id2label = {i: label for i, label in enumerate(label_list)}
+
+    # Step 2: Init model
+    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+    if tokenizer.model_max_length > 1000:
+        tokenizer.model_max_length = max_length
+
+    model = AutoModelForTokenClassification.from_pretrained(
+        model_name_or_path,
+        num_labels=len(label_list),
+        id2label=id2label,
+        label2id=label2id,
+        ignore_mismatched_sizes=True,
+    )
+
+    # Step 3: Tokenize
+    tokenized_dataset = dataset.map(
+        lambda ex: tokenize_and_align_labels(ex, tokenizer, label2id, max_length),
+        batched=True,
+        remove_columns=["text", "char_labels"],
+    )
+
+    # Step 4: Training args (compatible with old transformers)
+    training_args = TrainingArguments(
+        output_dir=output_dir,
+        learning_rate=learning_rate,
+        per_device_train_batch_size=per_device_train_batch_size,
+        num_train_epochs=num_train_epochs,
+        logging_steps=5,
+        save_steps=200,
+        save_total_limit=2,
+        do_eval=False,
+        report_to="none",
+        no_cuda=not torch.cuda.is_available(),
+        load_best_model_at_end=False,
+        push_to_hub=False,
+        fp16=torch.cuda.is_available(),
+    )
+
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=tokenized_dataset,
+        tokenizer=tokenizer,
+    )
+
+    print("🚀 Starting training...")
+    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+
+    print("💾 Saving final model...")
+    trainer.save_model(output_dir)
+    tokenizer.save_pretrained(output_dir)
+    return model, tokenizer, label_list
+
+
+# ===========================
+# 4. Inference functions (enhanced)
+# ===========================
+
+
+def fix_timestamp(ts: str) -> str:
+    """Fix a truncated timestamp: '4-06-08' → '2024-06-08'"""
+    if not ts:
+        return ts
+    m = re.match(r"^(\d{1,2})-(\d{2})-(\d{2})(.*)", ts)
+    if m:
+        year_short, month, day, rest = m.groups()
+        if len(year_short) == 1:
+            year = "202" + year_short
+        elif len(year_short) == 2:
+            year = "20" + year_short
+        else:
+            year = year_short
+        return f"{year}-{month}-{day}{rest}"
+    return ts
+
+
+def fix_speaker(spk: str) -> str:
+    """Fix a truncated speaker name"""
+    if not spk:
+        return spk
+    spk = re.sub(r"[^\w\s\u4e00-\u9fff]+$", "", spk)
+    # Ensure at least 2 chars for Chinese names
+    if len(spk) == 1 and re.match(r"^[风雷电雨雪火水木金]", spk):
+        return spk + "某"
+    return spk
+
+
+class TRPGParser:
+    def __init__(self, model_dir: str):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
+        self.model = AutoModelForTokenClassification.from_pretrained(model_dir)
+        self.id2label = self.model.config.id2label
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model.to(self.device)
+
+        self.nlp = pipeline(
+            "token-classification",
+            model=self.model,
+            tokenizer=self.tokenizer,
+            aggregation_strategy="simple",
+            device=0 if torch.cuda.is_available() else -1,
+        )
+
+    def parse_line(self, text: str) -> Dict[str, Any]:
+        ents = self.nlp(text)
+        print(f"[DEBUG] Raw entities: {ents}")
+        out = {"metadata": {}, "content": []}
+
+        # Merge overlapping entities of the same group
+        merged = []
+        for e in sorted(ents, key=lambda x: x["start"]):
+            if merged and merged[-1]["entity_group"] == e["entity_group"]:
+                if e["start"] <= merged[-1]["end"]:
+                    merged[-1]["end"] = max(merged[-1]["end"], e["end"])
+                    merged[-1]["score"] = min(merged[-1]["score"], e["score"])
+                    continue
+            merged.append(e)
+
+        for e in merged:
+            group = e["entity_group"]
+            raw_text = text[e["start"] : e["end"]]
+            clean_text = re.sub(
+                r"^[<\[\"“「*(#\s]+|[>\]\"”」*)\s]+$", "", raw_text
+            ).strip()
+            if not clean_text:
+                clean_text = raw_text
+
+            # Special fixes
+            if group == "timestamp":
+                clean_text = fix_timestamp(clean_text)
+            elif group == "speaker":
+                clean_text = fix_speaker(clean_text)
+
+            if group in ["timestamp", "speaker"] and clean_text:
+                out["metadata"][group] = clean_text
+            elif group in ["dialogue", "action", "comment"] and clean_text:
+                out["content"].append(
+                    {
+                        "type": group,
+                        "content": clean_text,
+                        "confidence": round(float(e["score"]), 3),
+                    }
+                )
+        return out
+
+    def parse_lines(self, texts: List[str]) -> List[Dict[str, Any]]:
+        return [self.parse_line(text) for text in texts]
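+
+# Minimal usage sketch (model path assumed; entity types depend on training):
+#   p = TRPGParser("./models/trpg-ner-v1")
+#   p.parse_line("风雨 2024-06-08 21:44:59\n剧烈的疼痛从头颅深处一波波地涌出...")
+#   → {"metadata": {"speaker": "风雨", "timestamp": "2024-06-08 21:44:59"},
+#      "content": [{"type": "action", "content": "剧烈的疼痛...", "confidence": ...}]}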
+
+
+# ===========================
+# 5. Model path detection (the key fix!)
+# ===========================
+
+
+def find_model_dir(requested_path: str, default_paths: List[str]) -> str:
+    """Robustly find a usable model directory"""
+    # Check the requested path first
+    candidates = [requested_path] + default_paths
+
+    for path in candidates:
+        if not os.path.isdir(path):
+            continue
+
+        # Check required files
+        required = ["config.json", "tokenizer.json"]
+        has_required = all((Path(path) / f).exists() for f in required)
+
+        # Check model weights (safetensors or pytorch)
+        model_files = ["model.safetensors", "pytorch_model.bin"]
+        has_model = any((Path(path) / f).exists() for f in model_files)
+
+        if has_required and has_model:
+            return path
+
+    # If not found, try subdirectories
+    for path in candidates:
+        if not os.path.isdir(path):
+            continue
+        for root, dirs, files in os.walk(path):
+            for d in dirs:
+                full_path = os.path.join(root, d)
+                has_required = all(
+                    (Path(full_path) / f).exists()
+                    for f in ["config.json", "tokenizer.json"]
+                )
+                has_model = any(
+                    (Path(full_path) / f).exists()
+                    for f in ["model.safetensors", "pytorch_model.bin"]
+                )
+                if has_required and has_model:
+                    return full_path
+
+    raise FileNotFoundError(
+        f"Model not found in any of: {candidates}\n"
+        "Required files: config.json, tokenizer.json, and (model.safetensors or pytorch_model.bin)\n"
+        "👉 Run training first: --train --conll ./data"
+    )
+
+
+def export_to_onnx(model_dir: str, onnx_path: str, max_length: int = 128):
+    """Export the model to ONNX format (fixed for local paths)"""
+    try:
+        from torch.onnx import export as onnx_export
+
+        print(f"📤 Exporting model from {model_dir} to {onnx_path}...")
+
+        # Fix 1: make sure the path is absolute
+        model_dir = os.path.abspath(model_dir)
+        if not os.path.exists(model_dir):
+            raise FileNotFoundError(f"Model directory not found: {model_dir}")
+
+        # Fix 2: pass local_files_only=True explicitly
+        tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True)
+        model = AutoModelForTokenClassification.from_pretrained(
+            model_dir, local_files_only=True
+        )
+        model.eval()
+
+        # Create a dummy input
+        dummy_text = "莎莎 2024-06-08 21:46:26"
+        inputs = tokenizer(
+            dummy_text,
+            return_tensors="pt",
+            padding="max_length",
+            truncation=True,
+            max_length=max_length,
+        )
+
+        # Ensure the output directory exists
+        os.makedirs(os.path.dirname(onnx_path), exist_ok=True)
+
+        # Export to ONNX (opset 18 for modern PyTorch compatibility)
+        onnx_export(
+            model,
+            (inputs["input_ids"], inputs["attention_mask"]),
+            onnx_path,
+            export_params=True,
+            opset_version=18,
+            do_constant_folding=True,
+            input_names=["input_ids", "attention_mask"],
+            output_names=["logits"],
+            dynamic_axes={
+                "input_ids": {0: "batch_size", 1: "sequence_length"},
+                "attention_mask": {0: "batch_size", 1: "sequence_length"},
+                "logits": {0: "batch_size", 1: "sequence_length"},
+            },
+        )
+
+        # Verify the ONNX model
+        import onnx
+
+        onnx_model = onnx.load(onnx_path)
+        onnx.checker.check_model(onnx_model)
+
+        print(
+            f"✅ ONNX export successful! Size: {os.path.getsize(onnx_path) / 1024 / 1024:.2f} MB"
+        )
+        return True
+
+    except Exception as e:
+        print(f"❌ ONNX export failed: {e}")
+        import traceback
+
+        traceback.print_exc()
+        return False
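+
+# Loading the exported graph with onnxruntime (illustrative; assumes
+# `pip install onnxruntime` and int64 numpy inputs of shape [batch, seq]):
+#   import onnxruntime as ort
+#   sess = ort.InferenceSession("./models/trpg-final/model.onnx")
+#   (logits,) = sess.run(["logits"], {"input_ids": ids, "attention_mask": mask})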
+
+
+# ===========================
+# 6. CLI entry point
+# ===========================
+
+
+def main():
+    import argparse
+
+    parser = argparse.ArgumentParser(description="TRPG NER: Train & Infer")
+    parser.add_argument("--train", action="store_true", help="Run training")
+    parser.add_argument(
+        "--conll", type=str, default="./data", help="Path to .conll files or dir"
+    )
+    parser.add_argument(
+        "--model", type=str, default="hfl/minirbt-h256", help="Base model"
+    )
+    parser.add_argument(
+        "--output", type=str, default="./models/trpg-ner-v1", help="Model output dir"
+    )
+    parser.add_argument("--epochs", type=int, default=20)
+    parser.add_argument("--batch", type=int, default=4)
+    parser.add_argument(
+        "--resume", type=str, default=None, help="Resume from checkpoint"
+    )
+    parser.add_argument("--test", type=str, nargs="*", help="Test texts")
+    parser.add_argument(
+        "--export_onnx", action="store_true", help="Export model to ONNX"
+    )
+    parser.add_argument(
+        "--onnx_path",
+        type=str,
+        default="./models/trpg-final/model.onnx",
+        help="ONNX output path",
+    )
+
+    args = parser.parse_args()
+
+    if args.train:
+        try:
+            train_ner_model(
+                conll_data=args.conll,
+                model_name_or_path=args.model,
+                output_dir=args.output,
+                num_train_epochs=args.epochs,
+                per_device_train_batch_size=args.batch,
+                resume_from_checkpoint=args.resume,
+            )
+            print(f"✅ Training finished. Model saved to {args.output}")
+        except Exception as e:
+            print(f"❌ Training failed: {e}")
+            sys.exit(1)
+
+    # Inference setup
+    default_model_paths = [
+        args.output,
+        "./models/trpg-ner-v1",
+        "./models/trpg-ner-v2",
+        "./models/trpg-ner-v3",
+        "./cvrp-ner-model",
+        "./models",
+    ]
+
+    try:
+        model_dir = find_model_dir(args.output, default_model_paths)
+        print(f"Using model from: {model_dir}")
+        trpg_parser = TRPGParser(model_dir=model_dir)  # don't shadow the argparse parser
+    except FileNotFoundError as e:
+        print(e)
+        sys.exit(1)
+    except Exception as e:
+        print(f"Failed to load model: {e}")
+        sys.exit(1)
+
+    # Run inference
+    if args.test:
+        for t in args.test:
+            print(f"\nInput: {t}")
+            result = trpg_parser.parse_line(t)
+            print("Parse:", result)
+    else:
+        # Demo
+        demo_texts = [
+            "风雨 2024-06-08 21:44:59\n剧烈的疼痛从头颅深处一波波地涌出...",
+            "莎莎 2024-06-08 21:46:26\n“呜哇...”#下意识去拿法杖",
+            "白麗 霊夢 2024-06-08 21:50:03\n莎莎 的出目是 D10+7=6+7=13",
+        ]
+        print("\n🧪 Demo inference (using model from", model_dir, "):")
+        for i, t in enumerate(demo_texts, 1):
+            print(f"\nDemo {i}: {t[:50]}...")
+            result = trpg_parser.parse_line(t)
+            meta = result["metadata"]
+            content = result["content"]
+            print(f"  Speaker: {meta.get('speaker', 'N/A')}")
+            print(f"  Timestamp: {meta.get('timestamp', 'N/A')}")
+            if content:
+                first = content[0]
+                print(
+                    f"  Content: {first['type']}='{first['content'][:40]}{'...' if len(first['content']) > 40 else ''}' (conf={first['confidence']})"
+                )
+
+    # ONNX export
+    if args.export_onnx:
+        # Prefer the model directory resolved during inference setup
+        onnx_model_dir = model_dir
+        success = export_to_onnx(
+            model_dir=onnx_model_dir,
+            onnx_path=args.onnx_path,
+            max_length=128,
+        )
+        if success:
+            print(f"🎉 ONNX model saved to {args.onnx_path}")
+        else:
+            sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
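+
+# CLI usage sketch (flags as defined in main(); paths are examples):
+#   python main.py --train --conll ./data --output ./models/trpg-ner-v1
+#   python main.py --test "风雨 2024-06-08 21:44:59 ..."
+#   python main.py --export_onnx --onnx_path ./models/trpg-final/model.onnx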
