summary | refs | log | tree | commit | diff | stats | homepage
diff options
context:
space:
mode:
-rw-r--r-- src/base_model_trpgner/training/__init__.py 32
1 files changed, 10 insertions, 22 deletions
diff --git a/src/base_model_trpgner/training/__init__.py b/src/base_model_trpgner/training/__init__.py
index 4f8e30d..72ea3b3 100644
--- a/src/base_model_trpgner/training/__init__.py
+++ b/src/base_model_trpgner/training/__init__.py
@@ -51,17 +51,13 @@ def train_ner_model(
TrainingArguments,
Trainer,
)
- from datasets import Dataset
- from tqdm.auto import tqdm
except ImportError as e:
- raise ImportError(
- "训练依赖未安装。请运行: pip install base-model-trpgner[train]"
- ) from e
+ raise ImportError("训练依赖未安装。请运行: pip install base-model-trpgner[train]") from e
# 导入数据处理函数
- from base_model_trpgner.utils.conll import load_conll_dataset, tokenize_and_align_labels
+ from base_model_trpgner.utils import load_conll_dataset, tokenize_and_align_labels
- print(f"🚀 Starting training...")
+ print("Starting training...")
# 加载数据
dataset, label_list = load_conll_dataset(conll_data)
@@ -113,15 +109,15 @@ def train_ner_model(
)
# 开始训练
- print("🚀 Starting training...")
+ print("Starting training...")
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
# 保存模型
- print("💾 Saving final model...")
+ print("Saving final model...")
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)
- print(f"✅ Training finished. Model saved to {output_dir}")
+ print(f"Training finished. Model saved to {output_dir}")
def export_to_onnx(
@@ -141,29 +137,22 @@ def export_to_onnx(
是否成功
"""
try:
- import torch
from torch.onnx import export as onnx_export
from transformers import AutoTokenizer, AutoModelForTokenClassification
import onnx
except ImportError as e:
- raise ImportError(
- "ONNX 导出依赖未安装。请运行: pip install onnx"
- ) from e
+ raise ImportError("ONNX 导出依赖未安装。请运行: pip install onnx") from e
- print(f"📤 Exporting model from {model_dir} to {onnx_path}...")
+ print(f"Exporting model from {model_dir} to {onnx_path}...")
model_dir = os.path.abspath(model_dir)
if not os.path.exists(model_dir):
raise FileNotFoundError(f"Model directory not found: {model_dir}")
- # 加载模型
tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True)
- model = AutoModelForTokenClassification.from_pretrained(
- model_dir, local_files_only=True
- )
+ model = AutoModelForTokenClassification.from_pretrained(model_dir, local_files_only=True)
model.eval()
- # 创建虚拟输入
dummy_text = "莎莎 2024-06-08 21:46:26"
inputs = tokenizer(
dummy_text,
@@ -173,7 +162,6 @@ def export_to_onnx(
max_length=max_length,
)
- # 确保目录存在
os.makedirs(os.path.dirname(onnx_path), exist_ok=True)
# 导出 ONNX
@@ -198,7 +186,7 @@ def export_to_onnx(
onnx.checker.check_model(onnx_model)
size_mb = os.path.getsize(onnx_path) / 1024 / 1024
- print(f"✅ ONNX export successful! Size: {size_mb:.2f} MB")
+ print(f"ONNX export successful! Size: {size_mb:.2f} MB")
return True