diff options
Diffstat (limited to 'main.py')
| -rw-r--r-- | main.py | 83 |
1 files changed, 20 insertions, 63 deletions
@@ -1,5 +1,3 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- """ TRPG NER 训练与推理脚本 (Robust Edition) - 自动探测模型路径(支持 safetensors/pytorch) @@ -27,13 +25,8 @@ from transformers import ( from datasets import Dataset from tqdm.auto import tqdm -# 抑制 transformers 警告 -hf_logging.set_verbosity_error() - -# =========================== -# 1. CoNLL 解析器(自动 word→char 转换) -# =========================== +hf_logging.set_verbosity_error() def word_to_char_labels(text: str, word_labels: List[Tuple[str, str]]) -> List[str]: @@ -143,9 +136,7 @@ def parse_conll_file(filepath: str) -> List[Dict[str, Any]]: current_labels.append(label) if current_text: - samples.append( - {"text": "".join(current_text), "char_labels": current_labels.copy()} - ) + samples.append({"text": "".join(current_text), "char_labels": current_labels.copy()}) return samples @@ -186,9 +177,7 @@ def load_conll_dataset(conll_dir_or_files: str) -> Tuple[Dataset, List[str]]: label_list.append(i_label) print(f"⚠️ Added missing {i_label} for {label}") - print( - f"✅ Loaded {len(all_samples)} samples, {len(label_list)} labels: {label_list}" - ) + print(f"✅ Loaded {len(all_samples)} samples, {len(label_list)} labels: {label_list}") return Dataset.from_list(all_samples), label_list @@ -286,20 +275,15 @@ def train_ner_model( tokenizer=tokenizer, ) - print("🚀 Starting training...") + print("Starting training...") trainer.train(resume_from_checkpoint=resume_from_checkpoint) - print("💾 Saving final model...") + print("Saving final model...") trainer.save_model(output_dir) tokenizer.save_pretrained(output_dir) return model, tokenizer, label_list -# =========================== -# 4. 推理函数(增强版) -# =========================== - - def fix_timestamp(ts: str) -> str: """Fix truncated timestamp: '4-06-08' → '2024-06-08'""" if not ts: @@ -362,9 +346,7 @@ class TRPGParser: for e in merged: group = e["entity_group"] raw_text = text[e["start"] : e["end"]] - clean_text = re.sub( - r"^[<\[\"“「*(#\s]+|[>\]\"”」*)\s]+$", "", raw_text - ).strip() + clean_text = re.sub(r"^[<\[\"“「*(#\s]+|[>\]\"”」*)\s]+$", "", raw_text).strip() if not clean_text: clean_text = raw_text @@ -390,11 +372,6 @@ class TRPGParser: return [self.parse_line(text) for text in texts] -# =========================== -# 5. 模型路径探测器(关键修复!) -# =========================== - - def find_model_dir(requested_path: str, default_paths: List[str]) -> str: """Robustly find model directory""" # Check requested path first @@ -423,8 +400,7 @@ def find_model_dir(requested_path: str, default_paths: List[str]) -> str: for d in dirs: full_path = os.path.join(root, d) has_required = all( - (Path(full_path) / f).exists() - for f in ["config.json", "tokenizer.json"] + (Path(full_path) / f).exists() for f in ["config.json", "tokenizer.json"] ) has_model = any( (Path(full_path) / f).exists() @@ -436,7 +412,7 @@ def find_model_dir(requested_path: str, default_paths: List[str]) -> str: raise FileNotFoundError( f"Model not found in any of: {candidates}\n" "Required files: config.json, tokenizer.json, and (model.safetensors or pytorch_model.bin)\n" - "👉 Run training first: --train --conll ./data" + "Run training first: --train --conll ./data" ) @@ -448,18 +424,14 @@ def export_to_onnx(model_dir: str, onnx_path: str, max_length: int = 128): from torch.onnx import export as onnx_export import os - print(f"📤 Exporting model from {model_dir} to {onnx_path}...") + print(f"Exporting model from {model_dir} to {onnx_path}...") - # ✅ 修复1:确保路径是绝对路径 model_dir = os.path.abspath(model_dir) if not os.path.exists(model_dir): raise FileNotFoundError(f"Model directory not found: {model_dir}") - # ✅ 修复2:显式指定 local_files_only=True tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True) - model = AutoModelForTokenClassification.from_pretrained( - model_dir, local_files_only=True - ) + model = AutoModelForTokenClassification.from_pretrained(model_dir, local_files_only=True) model.eval() # Create dummy input @@ -498,47 +470,32 @@ def export_to_onnx(model_dir: str, onnx_path: str, max_length: int = 128): onnx_model = onnx.load(onnx_path) onnx.checker.check_model(onnx_model) - print( - f"✅ ONNX export successful! Size: {os.path.getsize(onnx_path) / 1024 / 1024:.2f} MB" - ) + print(f"ONNX export successful! Size: {os.path.getsize(onnx_path) / 1024 / 1024:.2f} MB") return True except Exception as e: - print(f"❌ ONNX export failed: {e}") + print(f"ONNX export failed: {e}") import traceback traceback.print_exc() return False -# =========================== -# 6. CLI 入口 -# =========================== - - def main(): import argparse parser = argparse.ArgumentParser(description="TRPG NER: Train & Infer") parser.add_argument("--train", action="store_true", help="Run training") - parser.add_argument( - "--conll", type=str, default="./data", help="Path to .conll files or dir" - ) - parser.add_argument( - "--model", type=str, default="hfl/minirbt-h256", help="Base model" - ) + parser.add_argument("--conll", type=str, default="./data", help="Path to .conll files or dir") + parser.add_argument("--model", type=str, default="hfl/minirbt-h256", help="Base model") parser.add_argument( "--output", type=str, default="./models/trpg-ner-v1", help="Model output dir" ) parser.add_argument("--epochs", type=int, default=20) parser.add_argument("--batch", type=int, default=4) - parser.add_argument( - "--resume", type=str, default=None, help="Resume from checkpoint" - ) + parser.add_argument("--resume", type=str, default=None, help="Resume from checkpoint") parser.add_argument("--test", type=str, nargs="*", help="Test texts") - parser.add_argument( - "--export_onnx", action="store_true", help="Export model to ONNX" - ) + parser.add_argument("--export_onnx", action="store_true", help="Export model to ONNX") parser.add_argument( "--onnx_path", type=str, @@ -558,9 +515,9 @@ def main(): per_device_train_batch_size=args.batch, resume_from_checkpoint=args.resume, ) - print(f"✅ Training finished. Model saved to {args.output}") + print(f"Training finished. Model saved to {args.output}") except Exception as e: - print(f"❌ Training failed: {e}") + print(f"Training failed: {e}") sys.exit(1) # Inference setup @@ -597,7 +554,7 @@ def main(): "莎莎 2024-06-08 21:46:26\n“呜哇...”#下意识去拿法杖", "白麗 霊夢 2024-06-08 21:50:03\n莎莎 的出目是 D10+7=6+7=13", ] - print("\n🧪 Demo inference (using model from", model_dir, "):") + print("\nDemo inference (using model from", model_dir, "):") for i, t in enumerate(demo_texts, 1): print(f"\nDemo {i}: {t[:50]}...") result = parser.parse_line(t) @@ -621,7 +578,7 @@ def main(): max_length=128, ) if success: - print(f"🎉 ONNX model saved to {args.onnx_path}") + print(f"ONNX model saved to {args.onnx_path}") else: sys.exit(1) |
