diff options
| author | 2026-01-05 14:39:47 +0800 | |
|---|---|---|
| committer | 2026-01-05 14:40:15 +0800 | |
| commit | a9e98ae197a49b8a6629601e3be7b9d0507eb6da (patch) | |
| tree | de83a8802de5d3c177779d56963b0f960c80feab /src | |
| parent | 65f48da74e446df81b17d0cc9bf203b75947fff1 (diff) | |
| download | base-model-a9e98ae197a49b8a6629601e3be7b9d0507eb6da.tar.gz base-model-a9e98ae197a49b8a6629601e3be7b9d0507eb6da.zip | |
refactor: streamline import statements and improve print messages in training module
Diffstat (limited to 'src')
| -rw-r--r-- | src/base_model_trpgner/training/__init__.py | 32 |
1 file changed, 10 insertions, 22 deletions
```diff
diff --git a/src/base_model_trpgner/training/__init__.py b/src/base_model_trpgner/training/__init__.py
index 4f8e30d..72ea3b3 100644
--- a/src/base_model_trpgner/training/__init__.py
+++ b/src/base_model_trpgner/training/__init__.py
@@ -51,17 +51,13 @@ def train_ner_model(
             TrainingArguments,
             Trainer,
         )
-        from datasets import Dataset
-        from tqdm.auto import tqdm
     except ImportError as e:
-        raise ImportError(
-            "训练依赖未安装。请运行: pip install base-model-trpgner[train]"
-        ) from e
+        raise ImportError("训练依赖未安装。请运行: pip install base-model-trpgner[train]") from e
 
     # 导入数据处理函数
-    from base_model_trpgner.utils.conll import load_conll_dataset, tokenize_and_align_labels
+    from base_model_trpgner.utils import load_conll_dataset, tokenize_and_align_labels
 
-    print(f"🚀 Starting training...")
+    print("Starting training...")
 
     # 加载数据
     dataset, label_list = load_conll_dataset(conll_data)
@@ -113,15 +109,15 @@ def train_ner_model(
     )
 
     # 开始训练
-    print("🚀 Starting training...")
+    print("Starting training...")
     trainer.train(resume_from_checkpoint=resume_from_checkpoint)
 
     # 保存模型
-    print("💾 Saving final model...")
+    print("Saving final model...")
     trainer.save_model(output_dir)
     tokenizer.save_pretrained(output_dir)
 
-    print(f"✅ Training finished. Model saved to {output_dir}")
+    print(f"Training finished. Model saved to {output_dir}")
 
 
 def export_to_onnx(
@@ -141,29 +137,22 @@ def export_to_onnx(
         是否成功
     """
     try:
-        import torch
         from torch.onnx import export as onnx_export
         from transformers import AutoTokenizer, AutoModelForTokenClassification
         import onnx
     except ImportError as e:
-        raise ImportError(
-            "ONNX 导出依赖未安装。请运行: pip install onnx"
-        ) from e
+        raise ImportError("ONNX 导出依赖未安装。请运行: pip install onnx") from e
 
-    print(f"📤 Exporting model from {model_dir} to {onnx_path}...")
+    print(f"Exporting model from {model_dir} to {onnx_path}...")
 
     model_dir = os.path.abspath(model_dir)
     if not os.path.exists(model_dir):
         raise FileNotFoundError(f"Model directory not found: {model_dir}")
 
-    # 加载模型
     tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True)
-    model = AutoModelForTokenClassification.from_pretrained(
-        model_dir, local_files_only=True
-    )
+    model = AutoModelForTokenClassification.from_pretrained(model_dir, local_files_only=True)
     model.eval()
 
-    # 创建虚拟输入
     dummy_text = "莎莎 2024-06-08 21:46:26"
     inputs = tokenizer(
         dummy_text,
@@ -173,7 +162,6 @@ def export_to_onnx(
         max_length=max_length,
     )
 
-    # 确保目录存在
     os.makedirs(os.path.dirname(onnx_path), exist_ok=True)
 
     # 导出 ONNX
@@ -198,7 +186,7 @@ def export_to_onnx(
     onnx.checker.check_model(onnx_model)
 
     size_mb = os.path.getsize(onnx_path) / 1024 / 1024
-    print(f"✅ ONNX export successful! Size: {size_mb:.2f} MB")
+    print(f"ONNX export successful! Size: {size_mb:.2f} MB")
 
     return True
 
```
