aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'main.py')
-rw-r--r--main.py83
1 files changed, 20 insertions, 63 deletions
diff --git a/main.py b/main.py
index 2f96207..62c9027 100644
--- a/main.py
+++ b/main.py
@@ -1,5 +1,3 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
"""
TRPG NER 训练与推理脚本 (Robust Edition)
- 自动探测模型路径(支持 safetensors/pytorch)
@@ -27,13 +25,8 @@ from transformers import (
from datasets import Dataset
from tqdm.auto import tqdm
-# 抑制 transformers 警告
-hf_logging.set_verbosity_error()
-
-# ===========================
-# 1. CoNLL 解析器(自动 word→char 转换)
-# ===========================
+hf_logging.set_verbosity_error()
def word_to_char_labels(text: str, word_labels: List[Tuple[str, str]]) -> List[str]:
@@ -143,9 +136,7 @@ def parse_conll_file(filepath: str) -> List[Dict[str, Any]]:
current_labels.append(label)
if current_text:
- samples.append(
- {"text": "".join(current_text), "char_labels": current_labels.copy()}
- )
+ samples.append({"text": "".join(current_text), "char_labels": current_labels.copy()})
return samples
@@ -186,9 +177,7 @@ def load_conll_dataset(conll_dir_or_files: str) -> Tuple[Dataset, List[str]]:
label_list.append(i_label)
print(f"⚠️ Added missing {i_label} for {label}")
- print(
- f"✅ Loaded {len(all_samples)} samples, {len(label_list)} labels: {label_list}"
- )
+ print(f"✅ Loaded {len(all_samples)} samples, {len(label_list)} labels: {label_list}")
return Dataset.from_list(all_samples), label_list
@@ -286,20 +275,15 @@ def train_ner_model(
tokenizer=tokenizer,
)
- print("🚀 Starting training...")
+ print("Starting training...")
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
- print("💾 Saving final model...")
+ print("Saving final model...")
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)
return model, tokenizer, label_list
-# ===========================
-# 4. 推理函数(增强版)
-# ===========================
-
-
def fix_timestamp(ts: str) -> str:
"""Fix truncated timestamp: '4-06-08' → '2024-06-08'"""
if not ts:
@@ -362,9 +346,7 @@ class TRPGParser:
for e in merged:
group = e["entity_group"]
raw_text = text[e["start"] : e["end"]]
- clean_text = re.sub(
- r"^[<\[\"“「*(#\s]+|[>\]\"”」*)\s]+$", "", raw_text
- ).strip()
+ clean_text = re.sub(r"^[<\[\"“「*(#\s]+|[>\]\"”」*)\s]+$", "", raw_text).strip()
if not clean_text:
clean_text = raw_text
@@ -390,11 +372,6 @@ class TRPGParser:
return [self.parse_line(text) for text in texts]
-# ===========================
-# 5. 模型路径探测器(关键修复!)
-# ===========================
-
-
def find_model_dir(requested_path: str, default_paths: List[str]) -> str:
"""Robustly find model directory"""
# Check requested path first
@@ -423,8 +400,7 @@ def find_model_dir(requested_path: str, default_paths: List[str]) -> str:
for d in dirs:
full_path = os.path.join(root, d)
has_required = all(
- (Path(full_path) / f).exists()
- for f in ["config.json", "tokenizer.json"]
+ (Path(full_path) / f).exists() for f in ["config.json", "tokenizer.json"]
)
has_model = any(
(Path(full_path) / f).exists()
@@ -436,7 +412,7 @@ def find_model_dir(requested_path: str, default_paths: List[str]) -> str:
raise FileNotFoundError(
f"Model not found in any of: {candidates}\n"
"Required files: config.json, tokenizer.json, and (model.safetensors or pytorch_model.bin)\n"
- "👉 Run training first: --train --conll ./data"
+ "Run training first: --train --conll ./data"
)
@@ -448,18 +424,14 @@ def export_to_onnx(model_dir: str, onnx_path: str, max_length: int = 128):
from torch.onnx import export as onnx_export
import os
- print(f"📤 Exporting model from {model_dir} to {onnx_path}...")
+ print(f"Exporting model from {model_dir} to {onnx_path}...")
- # ✅ 修复1:确保路径是绝对路径
model_dir = os.path.abspath(model_dir)
if not os.path.exists(model_dir):
raise FileNotFoundError(f"Model directory not found: {model_dir}")
- # ✅ 修复2:显式指定 local_files_only=True
tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True)
- model = AutoModelForTokenClassification.from_pretrained(
- model_dir, local_files_only=True
- )
+ model = AutoModelForTokenClassification.from_pretrained(model_dir, local_files_only=True)
model.eval()
# Create dummy input
@@ -498,47 +470,32 @@ def export_to_onnx(model_dir: str, onnx_path: str, max_length: int = 128):
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
- print(
- f"✅ ONNX export successful! Size: {os.path.getsize(onnx_path) / 1024 / 1024:.2f} MB"
- )
+ print(f"ONNX export successful! Size: {os.path.getsize(onnx_path) / 1024 / 1024:.2f} MB")
return True
except Exception as e:
- print(f"❌ ONNX export failed: {e}")
+ print(f"ONNX export failed: {e}")
import traceback
traceback.print_exc()
return False
-# ===========================
-# 6. CLI 入口
-# ===========================
-
-
def main():
import argparse
parser = argparse.ArgumentParser(description="TRPG NER: Train & Infer")
parser.add_argument("--train", action="store_true", help="Run training")
- parser.add_argument(
- "--conll", type=str, default="./data", help="Path to .conll files or dir"
- )
- parser.add_argument(
- "--model", type=str, default="hfl/minirbt-h256", help="Base model"
- )
+ parser.add_argument("--conll", type=str, default="./data", help="Path to .conll files or dir")
+ parser.add_argument("--model", type=str, default="hfl/minirbt-h256", help="Base model")
parser.add_argument(
"--output", type=str, default="./models/trpg-ner-v1", help="Model output dir"
)
parser.add_argument("--epochs", type=int, default=20)
parser.add_argument("--batch", type=int, default=4)
- parser.add_argument(
- "--resume", type=str, default=None, help="Resume from checkpoint"
- )
+ parser.add_argument("--resume", type=str, default=None, help="Resume from checkpoint")
parser.add_argument("--test", type=str, nargs="*", help="Test texts")
- parser.add_argument(
- "--export_onnx", action="store_true", help="Export model to ONNX"
- )
+ parser.add_argument("--export_onnx", action="store_true", help="Export model to ONNX")
parser.add_argument(
"--onnx_path",
type=str,
@@ -558,9 +515,9 @@ def main():
per_device_train_batch_size=args.batch,
resume_from_checkpoint=args.resume,
)
- print(f"✅ Training finished. Model saved to {args.output}")
+ print(f"Training finished. Model saved to {args.output}")
except Exception as e:
- print(f"❌ Training failed: {e}")
+ print(f"Training failed: {e}")
sys.exit(1)
# Inference setup
@@ -597,7 +554,7 @@ def main():
"莎莎 2024-06-08 21:46:26\n“呜哇...”#下意识去拿法杖",
"白麗 霊夢 2024-06-08 21:50:03\n莎莎 的出目是 D10+7=6+7=13",
]
- print("\n🧪 Demo inference (using model from", model_dir, "):")
+ print("\nDemo inference (using model from", model_dir, "):")
for i, t in enumerate(demo_texts, 1):
print(f"\nDemo {i}: {t[:50]}...")
result = parser.parse_line(t)
@@ -621,7 +578,7 @@ def main():
max_length=128,
)
if success:
- print(f"🎉 ONNX model saved to {args.onnx_path}")
+ print(f"ONNX model saved to {args.onnx_path}")
else:
sys.exit(1)