Diffstat (limited to 'src/base_model_trpgner/training/__init__.py')
-rw-r--r--  src/base_model_trpgner/training/__init__.py  205
1 file changed, 205 insertions(+), 0 deletions(-)
diff --git a/src/base_model_trpgner/training/__init__.py b/src/base_model_trpgner/training/__init__.py
new file mode 100644
index 0000000..ccf3c03
--- /dev/null
+++ b/src/base_model_trpgner/training/__init__.py
@@ -0,0 +1,205 @@
+"""
+训练模块
+
+提供 TRPG NER 模型训练功能。
+
+注意: 使用此模块需要安装训练依赖:
+ pip install base-model-trpgner[train]
+"""
+
+import os
+from typing import Optional
+
+
+def train_ner_model(
+ conll_data: str,
+ model_name_or_path: str = "hfl/minirbt-h256",
+ output_dir: str = "./models/trpg-ner-v1",
+ num_train_epochs: int = 20,
+ per_device_train_batch_size: int = 4,
+ learning_rate: float = 5e-5,
+ max_length: int = 128,
+ resume_from_checkpoint: Optional[str] = None,
+) -> None:
+ """
+ 训练 NER 模型
+
+ Args:
+ conll_data: CoNLL 格式数据文件或目录
+ model_name_or_path: 基础模型名称或路径
+ output_dir: 模型输出目录
+ num_train_epochs: 训练轮数
+ per_device_train_batch_size: 批处理大小
+ learning_rate: 学习率
+ max_length: 最大序列长度
+ resume_from_checkpoint: 恢复检查点路径
+
+ Examples:
+ >>> from basemodeltrpgner.training import train_ner_model
+ >>> train_ner_model(
+ ... conll_data="./data",
+ ... output_dir="./my_model",
+ ... epochs=10
+ ... )
+ """
+ try:
+ import torch
+ from transformers import (
+ AutoTokenizer,
+ AutoModelForTokenClassification,
+ TrainingArguments,
+ Trainer,
+ )
+ from datasets import Dataset
+        from tqdm.auto import tqdm  # noqa: F401 (dependency check only)
+ except ImportError as e:
+ raise ImportError(
+ "训练依赖未安装。请运行: pip install base-model-trpgner[train]"
+ ) from e
+
+    # Import the data-processing helpers
+ from base_model_trpgner.utils.conll import load_conll_dataset, tokenize_and_align_labels
+
+ print(f"🚀 Starting training...")
+
+    # Load data and build the label mappings
+ dataset, label_list = load_conll_dataset(conll_data)
+ label2id = {label: i for i, label in enumerate(label_list)}
+ id2label = {i: label for i, label in enumerate(label_list)}
+
+    # Initialize tokenizer and model
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
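+    # Some tokenizers report a very large sentinel instead of a real limit;
+    # clamp model_max_length so padding/truncation respect max_length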
+ if tokenizer.model_max_length > 1000:
+ tokenizer.model_max_length = max_length
+
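+    # ignore_mismatched_sizes re-initializes the classification head when the
+    # base checkpoint's label count differs from this dataset's label set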
+ model = AutoModelForTokenClassification.from_pretrained(
+ model_name_or_path,
+ num_labels=len(label_list),
+ id2label=id2label,
+ label2id=label2id,
+ ignore_mismatched_sizes=True,
+ )
+
+    # Tokenize and align character-level labels with wordpiece tokens
+ tokenized_dataset = dataset.map(
+ lambda ex: tokenize_and_align_labels(ex, tokenizer, label2id, max_length),
+ batched=True,
+ remove_columns=["text", "char_labels"],
+ )
+
+    # Training arguments
+ training_args = TrainingArguments(
+ output_dir=output_dir,
+ learning_rate=learning_rate,
+ per_device_train_batch_size=per_device_train_batch_size,
+ num_train_epochs=num_train_epochs,
+ logging_steps=5,
+ save_steps=200,
+ save_total_limit=2,
+ do_eval=False,
+ report_to="none",
+ no_cuda=not torch.cuda.is_available(),
+ load_best_model_at_end=False,
+ push_to_hub=False,
+ fp16=torch.cuda.is_available(),
+ )
+
+ trainer = Trainer(
+ model=model,
+ args=training_args,
+ train_dataset=tokenized_dataset,
+ tokenizer=tokenizer,
+ )
+
+    # Start training (optionally resuming from a checkpoint)
+ print("🚀 Starting training...")
+ trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+
+    # Save the final model and tokenizer
+ print("💾 Saving final model...")
+ trainer.save_model(output_dir)
+ tokenizer.save_pretrained(output_dir)
+
+ print(f"✅ Training finished. Model saved to {output_dir}")
+
+
+def export_to_onnx(
+ model_dir: str,
+ onnx_path: str,
+ max_length: int = 128,
+) -> bool:
+ """
+ 将训练好的模型导出为 ONNX 格式
+
+ Args:
+ model_dir: 模型目录
+ onnx_path: ONNX 输出路径
+ max_length: 最大序列长度
+
+ Returns:
+ 是否成功
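+
+    Examples:
+        >>> from base_model_trpgner.training import export_to_onnx
+        >>> export_to_onnx(
+        ...     model_dir="./models/trpg-ner-v1",
+        ...     onnx_path="./models/trpg-ner-v1/model.onnx"
+        ... )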
+ """
+ try:
+ import torch
+ from torch.onnx import export as onnx_export
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
+ import onnx
+ except ImportError as e:
+ raise ImportError(
+ "ONNX 导出依赖未安装。请运行: pip install onnx"
+ ) from e
+
+ print(f"📤 Exporting model from {model_dir} to {onnx_path}...")
+
+ model_dir = os.path.abspath(model_dir)
+ if not os.path.exists(model_dir):
+ raise FileNotFoundError(f"Model directory not found: {model_dir}")
+
+    # Load the model (local files only)
+ tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True)
+ model = AutoModelForTokenClassification.from_pretrained(
+ model_dir, local_files_only=True
+ )
+ model.eval()
+
+    # Build a dummy input for tracing; any text padded to max_length works
+ dummy_text = "莎莎 2024-06-08 21:46:26"
+ inputs = tokenizer(
+ dummy_text,
+ return_tensors="pt",
+ padding="max_length",
+ truncation=True,
+ max_length=max_length,
+ )
+
+    # Ensure the output directory exists (onnx_path may be a bare filename)
+    out_dir = os.path.dirname(onnx_path)
+    if out_dir:
+        os.makedirs(out_dir, exist_ok=True)
+
+    # Export to ONNX
+ onnx_export(
+ model,
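+        # the args tuple maps positionally onto forward(input_ids, attention_mask)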
+ (inputs["input_ids"], inputs["attention_mask"]),
+ onnx_path,
+ export_params=True,
+ opset_version=18,
+ do_constant_folding=True,
+ input_names=["input_ids", "attention_mask"],
+ output_names=["logits"],
+ dynamic_axes={
+ "input_ids": {0: "batch_size", 1: "sequence_length"},
+ "attention_mask": {0: "batch_size", 1: "sequence_length"},
+ "logits": {0: "batch_size", 1: "sequence_length"},
+ },
+ )
+
+    # Validate the exported ONNX graph
+ onnx_model = onnx.load(onnx_path)
+ onnx.checker.check_model(onnx_model)
+
+ size_mb = os.path.getsize(onnx_path) / 1024 / 1024
+ print(f"✅ ONNX export successful! Size: {size_mb:.2f} MB")
+ return True
+
+
+__all__ = ["train_ner_model", "export_to_onnx"]