diff options
| author | 2025-12-30 19:54:08 +0800 | |
|---|---|---|
| committer | 2025-12-30 19:54:08 +0800 | |
| commit | 575114661ef9afb95df2a211e1d8498686340e6b (patch) | |
| tree | 91f1646cececb1597a9246865e89b52e059d3cfa | |
| parent | 7ac684f1f82023c6284cd7d7efde11b8dc98c149 (diff) | |
| download | base-model-575114661ef9afb95df2a211e1d8498686340e6b.tar.gz base-model-575114661ef9afb95df2a211e1d8498686340e6b.zip | |
feat: Refactor and enhance TRPG NER model SDK
- Removed deprecated `word_conll_to_char_conll.py` utility and integrated its functionality into the new `utils` module.
- Introduced a comprehensive GitHub Actions workflow for automated publishing to PyPI and GitHub Releases.
- Added `__init__.py` files to establish package structure for `basemodel`, `inference`, `training`, and `utils` modules.
- Implemented model downloading functionality in `download_model.py` to fetch pre-trained ONNX models.
- Developed `TRPGParser` class for ONNX-based inference, including methods for parsing TRPG logs.
- Created training utilities in `training/__init__.py` for NER model training with Hugging Face Transformers.
- Enhanced utility functions for CoNLL file parsing and dataset creation.
- Added command-line interface for converting CoNLL files to datasets with validation options.
| -rw-r--r-- | .github/workflows/publish.yml | 201 | ||||
| -rw-r--r-- | .gitignore | 2 | ||||
| -rw-r--r-- | pyproject.toml | 116 | ||||
| -rw-r--r-- | src/basemodel/__init__.py | 36 | ||||
| -rw-r--r-- | src/basemodel/download_model.py | 70 | ||||
| -rw-r--r-- | src/basemodel/inference/__init__.py | 292 | ||||
| -rw-r--r-- | src/basemodel/training/__init__.py | 205 | ||||
| -rw-r--r-- | src/basemodel/utils/__init__.py | 192 | ||||
| -rw-r--r-- | utils/conll_to_dataset.py (renamed from src/utils/conll_to_dataset.py) | 0 | ||||
| -rw-r--r-- | utils/word_conll_to_char_conll.py (renamed from src/utils/word_conll_to_char_conll.py) | 0 |
10 files changed, 1104 insertions, 10 deletions
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..2004132 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,201 @@ +name: Publish to PyPI & GitHub Release + +on: + push: + tags: + - 'v*.*.*' + workflow_dispatch: + inputs: + create_test: + description: 'Publish to Test PyPI' + required: false + default: false + type: boolean + +permissions: + contents: write + id-token: write + +jobs: + build: + name: Build distribution + runs-on: ubuntu-latest + + outputs: + version: ${{ steps.version.outputs.version }} + + steps: + - uses: actions/checkout@v6 + with: + fetch-depth: 0 + + - name: Extract version from tag + id: version + run: | + VERSION=${{ github.ref_name }} + echo "version=${VERSION#v}" >> $GITHUB_OUTPUT + echo "Tag: $VERSION" + echo "Version: ${VERSION#v}" + + - name: Install uv + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + echo "$HOME/.local/bin" >> $GITHUB_PATH + + - name: Build with uv + run: | + uv build + + - name: Check distribution + run: | + uv pip install twine + twine check dist/* + + - uses: actions/upload-artifact@v4 + with: + name: dist + path: dist/ + + - name: Prepare ONNX artifact + run: | + mkdir -p onnx-artifact + cp models/trpg-final/model.onnx onnx-artifact/ + cp models/trpg-final/model.onnx.data onnx-artifact/ || true + ls -lh onnx-artifact/ + + - uses: actions/upload-artifact@v4 + with: + name: onnx-model + path: onnx-artifact/ + + publish-test-pypi: + name: Publish to Test PyPI + needs: build + runs-on: ubuntu-latest + if: github.event_name == 'workflow_dispatch' && inputs.create_test == true + + environment: + name: test-pypi + url: https://test.pypi.org/p/base-model + + permissions: + id-token: write + + steps: + - name: Download dist + uses: actions/download-artifact@v4 + with: + name: dist + path: dist/ + + - name: Publish to Test PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + + publish-pypi: + name: Publish to PyPI + needs: build + runs-on: ubuntu-latest + if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') + + environment: + name: pypi + url: https://pypi.org/p/base-model + + permissions: + id-token: write + + steps: + - name: Download dist + uses: actions/download-artifact@v4 + with: + name: dist + path: dist/ + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + + create-release: + name: Create GitHub Release with ONNX + needs: [build, publish-pypi] + runs-on: ubuntu-latest + if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') + + permissions: + contents: write + + steps: + - uses: actions/checkout@v6 + with: + fetch-depth: 0 + + - name: Generate CHANGELOG + id: changelog + uses: requarks/changelog-action@v1 + with: + token: ${{ github.token }} + tag: ${{ github.ref_name }} + includeInvalidCommits: true + changelogFilePath: CHANGELOG.md + writeToFile: true + useGitmojis: false + + - name: Download artifacts + uses: actions/download-artifact@v4 + with: + pattern: '*' + path: artifacts/ + merge-multiple: true + + - name: Create Release with ONNX + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + VERSION="${{ github.ref_name }}" + + cat > release_notes.md << 'EOF' + ## 📦 安装 + + ### pip 安装 + ```bash + pip install base-model + ``` + + ### 使用 uv(推荐) + ```bash + uv pip install base-model + ``` + + ### 训练模式 + ```bash + pip install base-model[train] + ``` + + ## 🚀 快速开始 + ```python + from basemodel import TRPGParser + + parser = TRPGParser() + result = parser.parse("风雨 2024-06-08 21:44:59 剧烈的疼痛...") + print(result) + ``` + + --- + + ${{ steps.changelog.outputs.changes }} + EOF + + # 上传 Python 包和 ONNX 模型 + gh release create "${VERSION}" \ + --repo "${{ github.repository }}" \ + --notes-file release_notes.md \ + --title "🚀 ${VERSION}" \ + artifacts/dist/* \ + artifacts/onnx-artifact/* || true + + - name: Commit CHANGELOG.md + if: hashFiles('CHANGELOG.md') != '' + uses: stefanzweifel/git-auto-commit-action@v7 + with: + branch: main + commit_message: "docs: update CHANGELOG.md for ${{ github.ref_name }} [skip ci]" + file_pattern: CHANGELOG.md @@ -165,4 +165,6 @@ uv.lock # model models/ + +# dataset dataset/
\ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 0e9a82b..db858a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,19 +1,115 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + [project] -name = "base-model" +name = "base-model-trpgner" version = "0.1.0" -description = "Add your description here" +description = "HydroRoll TRPG NER 模型 - 桌上角色扮演游戏日志命名实体识别" +authors = [ + { name = "HsiangNianian", email = "leader@hydroroll.team" } +] readme = "README.md" requires-python = ">=3.12" +license = { text = "AFL-3.0" } +keywords = ["hydroroll", "trpg", "nlp", "ner", "chinese", "onnx", "robot framework"] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: Academic Free License (AFL)", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Text Processing :: Linguistic", +] + dependencies = [ - "gradio>=6.2.0", - "onnx>=1.20.0", + "numpy>=1.24.0", "onnxruntime>=1.23.2", - "onnxscript>=0.5.7", - "pynvml>=13.0.1", - "torch>=2.9.1", "transformers>=4.57.3", ] -[[tool.uv.index]] -url = "https://mirrors.aliyun.com/pypi/simple" -default = true +[project.optional-dependencies] +train = [ + "torch>=2.9.1", + "datasets>=2.18.0", + "accelerate>=0.27.0", + "tqdm>=4.66.0", +] +dev = [ + "base-model-trpgner[train]", + "pytest>=8.0.0", + "black>=24.0.0", + "ruff>=0.1.0", +] +webui = [ + "base-model-trpgner[train]", + "gradio>=6.2.0", + "scikit-learn>=1.4.0", +] +all = [ + "base-model-trpgner[train,webui,dev]", +] + +[project.urls] +Homepage = "https://ailab.hydroroll.team/" +Repository = "https://github.com/HydroRoll-Team/base-model" +Documentation = "https://ailab.hydroroll.team/" +Issues = "https://github.com/HydroRoll-Team/base-model/issues" + +[tool.hatch.build.targets.wheel] +packages = ["src/basemodel"] +# 只包含 ONNX 推理需要的文件(约 41MB) +artifacts = [ + "src/basemodel/**/*.py", + "models/trpg-final/model.onnx", + "models/trpg-final/model.onnx.data", + "models/trpg-final/config.json", + "models/trpg-final/tokenizer.json", + "models/trpg-final/tokenizer_config.json", + "models/trpg-final/special_tokens_map.json", + "models/trpg-final/vocab.txt", +] +# 共享数据:模型文件安装位置 +[tool.hatch.build.targets.wheel.shared-data] +"models/trpg-final" = "basemodel/models/trpg-final" + +[tool.hatch.build.targets.sdist] +include = [ + "/src", + "/models/trpg-final", + "/README.md", + "/COPYING", +] + +[tool.ruff] +exclude = [ + ".bzr", ".direnv", ".eggs", ".git", ".git-rewrite", ".hg", + ".ipynb_checkpoints", ".mypy_cache", ".nox", ".pants.d", + ".pyenv", ".pytest_cache", ".pytype", ".ruff_cache", ".svn", + ".tox", ".venv", ".vscode", "__pypackages__", "_build", + "buck-out", "build", "dist", "node_modules", "site-packages", "venv", +] +line-length = 100 +indent-width = 4 +target-version = "py312" + +[tool.ruff.lint] +select = ["E4", "E7", "E9", "F", "I", "N", "W"] +ignore = ["E501"] +fixable = ["ALL"] +unfixable = [] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" + +[tool.black] +line-length = 100 +target-version = ["py312"] + +[tool.uv] +dev-dependencies = []
\ No newline at end of file diff --git a/src/basemodel/__init__.py b/src/basemodel/__init__.py new file mode 100644 index 0000000..7287df4 --- /dev/null +++ b/src/basemodel/__init__.py @@ -0,0 +1,36 @@ +""" +base-model - HydroRoll TRPG NER 模型 SDK + +这是一个用于 TRPG(桌上角色扮演游戏)日志命名实体识别的 Python SDK。 + +基本用法: + >>> from basemodel import TRPGParser + >>> parser = TRPGParser() + >>> result = parser.parse("风雨 2024-06-08 21:44:59 剧烈的疼痛...") + >>> print(result) + {'metadata': {'speaker': '风雨', 'timestamp': '2024-06-08 21:44:59'}, 'content': [...]} + +训练功能(需要额外安装): + >>> pip install base-model[train] + >>> from basemodel.training import train_ner_model + >>> train_ner_model(conll_data="./data", output_dir="./model") +""" + +from basemodel.inference import TRPGParser, parse_line, parse_lines + +try: + from importlib.metadata import version + __version__ = version("base-model") +except Exception: + __version__ = "0.1.0.dev" + +__all__ = [ + "__version__", + "TRPGParser", + "parse_line", + "parse_lines", +] + + +def get_version(): + return __version__ diff --git a/src/basemodel/download_model.py b/src/basemodel/download_model.py new file mode 100644 index 0000000..2d65099 --- /dev/null +++ b/src/basemodel/download_model.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +""" +模型下载脚本 + +自动下载预训练的 ONNX 模型到用户缓存目录。 +""" + +import os +import sys +from pathlib import Path +import urllib.request + + +def download_model( + output_dir: str = None, + url: str = "https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx" +): + """ + 下载 ONNX 模型 + + Args: + output_dir: 输出目录,默认为 ~/.cache/basemodel/models/trpg-final/ + url: 模型下载 URL + """ + if output_dir is None: + output_dir = Path.home() / ".cache" / "basemodel" / "models" / "trpg-final" + else: + output_dir = Path(output_dir) + + output_dir.mkdir(parents=True, exist_ok=True) + output_path = output_dir / "model.onnx" + + if output_path.exists(): + print(f"✅ 模型已存在: {output_path}") + return str(output_path) + + print(f"📥 正在下载模型到 {output_path}...") + print(f" URL: {url}") + + try: + urllib.request.urlretrieve(url, output_path) + print(f"✅ 模型下载成功!") + return str(output_path) + except Exception as e: + print(f"❌ 下载失败: {e}") + print(f" 请手动从以下地址下载模型:") + print(f" {url}") + sys.exit(1) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="下载 base-model ONNX 模型") + parser.add_argument( + "--output-dir", + type=str, + default=None, + help="模型输出目录(默认: ~/.cache/basemodel/models/trpg-final/)" + ) + parser.add_argument( + "--url", + type=str, + default="https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx", + help="模型下载 URL" + ) + + args = parser.parse_args() + + download_model(args.output_dir, args.url) diff --git a/src/basemodel/inference/__init__.py b/src/basemodel/inference/__init__.py new file mode 100644 index 0000000..93a185f --- /dev/null +++ b/src/basemodel/inference/__init__.py @@ -0,0 +1,292 @@ +""" +ONNX 推理模块 + +提供基于 ONNX 的 TRPG 日志命名实体识别推理功能。 +""" + +import os +from typing import List, Dict, Any, Optional +from pathlib import Path + +try: + import numpy as np + import onnxruntime as ort + from transformers import AutoTokenizer +except ImportError as e: + raise ImportError( + "依赖未安装。请运行: pip install onnxruntime transformers numpy" + ) from e + + +# 默认模型路径(相对于包安装位置) +DEFAULT_MODEL_DIR = Path(__file__).parent.parent.parent.parent / "models" / "trpg-final" +# 远程模型 URL(用于自动下载) +MODEL_URL = "https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx" + + +class TRPGParser: + """ + TRPG 日志解析器(基于 ONNX) + + Args: + model_path: ONNX 模型路径,默认使用内置模型 + tokenizer_path: tokenizer 配置路径,默认与 model_path 相同 + device: 推理设备,"cpu" 或 "cuda" + + Examples: + >>> parser = TRPGParser() + >>> result = parser.parse("风雨 2024-06-08 21:44:59 剧烈的疼痛...") + >>> print(result['metadata']['speaker']) + '风雨' + """ + + def __init__( + self, + model_path: Optional[str] = None, + tokenizer_path: Optional[str] = None, + device: str = "cpu", + ): + # 确定模型路径 + if model_path is None: + model_path = self._get_default_model_path() + + if tokenizer_path is None: + tokenizer_path = Path(model_path).parent + + self.model_path = Path(model_path) + self.tokenizer_path = Path(tokenizer_path) + self.device = device + + # 加载模型 + self._load_model() + + def _get_default_model_path(self) -> str: + """获取默认模型路径""" + # 1. 尝试相对于项目根目录 + project_root = Path(__file__).parent.parent.parent.parent + local_model = project_root / "models" / "trpg-final" / "model.onnx" + if local_model.exists(): + return str(local_model) + + # 2. 尝试用户数据目录 + from pathlib import Path + user_model_dir = Path.home() / ".cache" / "basemodel" / "models" / "trpg-final" + user_model = user_model_dir / "model.onnx" + if user_model.exists(): + return str(user_model) + + # 3. 抛出错误,提示下载 + raise FileNotFoundError( + f"模型文件未找到。请从 {MODEL_URL} 下载模型到 {user_model_dir}\n" + f"或运行: python -m basemodel.download_model" + ) + + def _load_model(self): + """加载 ONNX 模型和 Tokenizer""" + # 加载 tokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + str(self.tokenizer_path), + local_files_only=True, + ) + + # 加载 ONNX 模型 + providers = ["CPUExecutionProvider"] + if self.device == "cuda" and "CUDAExecutionProvider" in ort.get_available_providers(): + providers.insert(0, "CUDAExecutionProvider") + + self.session = ort.InferenceSession( + str(self.model_path), + providers=providers, + ) + + # 加载标签映射 + import json + config_path = self.tokenizer_path / "config.json" + if config_path.exists(): + with open(config_path, "r", encoding="utf-8") as f: + config = json.load(f) + self.id2label = {int(k): v for k, v in config.get("id2label", {}).items()} + else: + # 默认标签 + self.id2label = { + 0: "O", 1: "B-action", 2: "I-action", 3: "B-comment", 4: "I-comment", + 5: "B-dialogue", 6: "I-dialogue", 7: "B-speaker", 8: "I-speaker", + 9: "B-timestamp", 10: "I-timestamp", + } + + def parse(self, text: str) -> Dict[str, Any]: + """ + 解析单条 TRPG 日志 + + Args: + text: 待解析的日志文本 + + Returns: + 包含 metadata 和 content 的字典 + - metadata: speaker, timestamp + - content: dialogue, action, comment 列表 + + Examples: + >>> parser = TRPGParser() + >>> result = parser.parse("风雨 2024-06-08 21:44:59 剧烈的疼痛...") + >>> result['metadata']['speaker'] + '风雨' + """ + # Tokenize + inputs = self.tokenizer( + text, + return_tensors="np", + return_offsets_mapping=True, + padding="max_length", + truncation=True, + max_length=128, + ) + + # 推理 + outputs = self.session.run( + ["logits"], + { + "input_ids": inputs["input_ids"].astype(np.int64), + "attention_mask": inputs["attention_mask"].astype(np.int64), + }, + ) + + # 后处理 + logits = outputs[0][0] + predictions = np.argmax(logits, axis=-1) + offsets = inputs["offset_mapping"][0] + + # 聚合实体 + entities = self._group_entities(predictions, offsets, logits) + + # 构建结果 + result = {"metadata": {}, "content": []} + for ent in entities: + if ent["start"] >= len(text) or ent["end"] > len(text): + continue + + raw_text = text[ent["start"]: ent["end"]] + clean_text = self._clean_text(raw_text, ent["type"]) + + if not clean_text.strip(): + continue + + if ent["type"] in ["timestamp", "speaker"]: + result["metadata"][ent["type"]] = clean_text + elif ent["type"] in ["dialogue", "action", "comment"]: + result["content"].append({ + "type": ent["type"], + "content": clean_text, + "confidence": round(ent["score"], 3), + }) + + return result + + def _group_entities(self, predictions, offsets, logits): + """将 token 级别的预测聚合为实体""" + entities = [] + current = None + + for i in range(len(predictions)): + start, end = offsets[i] + if start == end: # special tokens + continue + + pred_id = int(predictions[i]) + label = self.id2label.get(pred_id, "O") + + if label == "O": + if current: + entities.append(current) + current = None + continue + + tag_type = label[2:] if len(label) > 2 else "O" + + if label.startswith("B-"): + if current: + entities.append(current) + current = { + "type": tag_type, + "start": int(start), + "end": int(end), + "score": float(np.max(logits[i])), + } + elif label.startswith("I-") and current and current["type"] == tag_type: + current["end"] = int(end) + else: + if current: + entities.append(current) + current = None + + if current: + entities.append(current) + + return entities + + def _clean_text(self, text: str, group: str) -> str: + """清理提取的文本""" + import re + + text = text.strip() + + # 移除周围符号 + if group == "comment": + text = re.sub(r"^[((]+|[))]+$", "", text) + elif group == "dialogue": + text = re.sub(r'^[""''「」『』]+|[""""」』『』]+$', "", text) + elif group == "action": + text = re.sub(r"^[*#]+|[*#]+$", "", text) + + # 修复时间戳 + if group == "timestamp" and text and text[0].isdigit(): + if len(text) > 2 and text[2] == "-": + text = "20" + text + + return text + + def parse_batch(self, texts: List[str]) -> List[Dict[str, Any]]: + """ + 批量解析多条日志 + + Args: + texts: 日志文本列表 + + Returns: + 解析结果列表 + """ + return [self.parse(text) for text in texts] + + +# 便捷函数 +def parse_line(text: str, model_path: Optional[str] = None) -> Dict[str, Any]: + """ + 解析单条日志的便捷函数 + + Args: + text: 日志文本 + model_path: 可选的模型路径 + + Returns: + 解析结果字典 + """ + parser = TRPGParser(model_path=model_path) + return parser.parse(text) + + +def parse_lines(texts: List[str], model_path: Optional[str] = None) -> List[Dict[str, Any]]: + """ + 批量解析日志的便捷函数 + + Args: + texts: 日志文本列表 + model_path: 可选的模型路径 + + Returns: + 解析结果列表 + """ + parser = TRPGParser(model_path=model_path) + return parser.parse_batch(texts) + + +__all__ = ["TRPGParser", "parse_line", "parse_lines"] diff --git a/src/basemodel/training/__init__.py b/src/basemodel/training/__init__.py new file mode 100644 index 0000000..5671c42 --- /dev/null +++ b/src/basemodel/training/__init__.py @@ -0,0 +1,205 @@ +""" +训练模块 + +提供 TRPG NER 模型训练功能。 + +注意: 使用此模块需要安装训练依赖: + pip install base-model-trpgner[train] +""" + +import os +from typing import Optional, List +from pathlib import Path + + +def train_ner_model( + conll_data: str, + model_name_or_path: str = "hfl/minirbt-h256", + output_dir: str = "./models/trpg-ner-v1", + num_train_epochs: int = 20, + per_device_train_batch_size: int = 4, + learning_rate: float = 5e-5, + max_length: int = 128, + resume_from_checkpoint: Optional[str] = None, +) -> None: + """ + 训练 NER 模型 + + Args: + conll_data: CoNLL 格式数据文件或目录 + model_name_or_path: 基础模型名称或路径 + output_dir: 模型输出目录 + num_train_epochs: 训练轮数 + per_device_train_batch_size: 批处理大小 + learning_rate: 学习率 + max_length: 最大序列长度 + resume_from_checkpoint: 恢复检查点路径 + + Examples: + >>> from basemodel.training import train_ner_model + >>> train_ner_model( + ... conll_data="./data", + ... output_dir="./my_model", + ... epochs=10 + ... ) + """ + try: + import torch + from transformers import ( + AutoTokenizer, + AutoModelForTokenClassification, + TrainingArguments, + Trainer, + ) + from datasets import Dataset + from tqdm.auto import tqdm + except ImportError as e: + raise ImportError( + "训练依赖未安装。请运行: pip install base-model-trpgner[train]" + ) from e + + # 导入数据处理函数 + from basemodel.utils.conll import load_conll_dataset, tokenize_and_align_labels + + print(f"🚀 Starting training...") + + # 加载数据 + dataset, label_list = load_conll_dataset(conll_data) + label2id = {label: i for i, label in enumerate(label_list)} + id2label = {i: label for i, label in enumerate(label_list)} + + # 初始化模型 + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + if tokenizer.model_max_length > 1000: + tokenizer.model_max_length = max_length + + model = AutoModelForTokenClassification.from_pretrained( + model_name_or_path, + num_labels=len(label_list), + id2label=id2label, + label2id=label2id, + ignore_mismatched_sizes=True, + ) + + # Tokenize + tokenized_dataset = dataset.map( + lambda ex: tokenize_and_align_labels(ex, tokenizer, label2id, max_length), + batched=True, + remove_columns=["text", "char_labels"], + ) + + # 训练参数 + training_args = TrainingArguments( + output_dir=output_dir, + learning_rate=learning_rate, + per_device_train_batch_size=per_device_train_batch_size, + num_train_epochs=num_train_epochs, + logging_steps=5, + save_steps=200, + save_total_limit=2, + do_eval=False, + report_to="none", + no_cuda=not torch.cuda.is_available(), + load_best_model_at_end=False, + push_to_hub=False, + fp16=torch.cuda.is_available(), + ) + + trainer = Trainer( + model=model, + args=training_args, + train_dataset=tokenized_dataset, + tokenizer=tokenizer, + ) + + # 开始训练 + print("🚀 Starting training...") + trainer.train(resume_from_checkpoint=resume_from_checkpoint) + + # 保存模型 + print("💾 Saving final model...") + trainer.save_model(output_dir) + tokenizer.save_pretrained(output_dir) + + print(f"✅ Training finished. Model saved to {output_dir}") + + +def export_to_onnx( + model_dir: str, + onnx_path: str, + max_length: int = 128, +) -> bool: + """ + 将训练好的模型导出为 ONNX 格式 + + Args: + model_dir: 模型目录 + onnx_path: ONNX 输出路径 + max_length: 最大序列长度 + + Returns: + 是否成功 + """ + try: + import torch + from torch.onnx import export as onnx_export + from transformers import AutoTokenizer, AutoModelForTokenClassification + import onnx + except ImportError as e: + raise ImportError( + "ONNX 导出依赖未安装。请运行: pip install onnx" + ) from e + + print(f"📤 Exporting model from {model_dir} to {onnx_path}...") + + model_dir = os.path.abspath(model_dir) + if not os.path.exists(model_dir): + raise FileNotFoundError(f"Model directory not found: {model_dir}") + + # 加载模型 + tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True) + model = AutoModelForTokenClassification.from_pretrained( + model_dir, local_files_only=True + ) + model.eval() + + # 创建虚拟输入 + dummy_text = "莎莎 2024-06-08 21:46:26" + inputs = tokenizer( + dummy_text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_length, + ) + + # 确保目录存在 + os.makedirs(os.path.dirname(onnx_path), exist_ok=True) + + # 导出 ONNX + onnx_export( + model, + (inputs["input_ids"], inputs["attention_mask"]), + onnx_path, + export_params=True, + opset_version=18, + do_constant_folding=True, + input_names=["input_ids", "attention_mask"], + output_names=["logits"], + dynamic_axes={ + "input_ids": {0: "batch_size", 1: "sequence_length"}, + "attention_mask": {0: "batch_size", 1: "sequence_length"}, + "logits": {0: "batch_size", 1: "sequence_length"}, + }, + ) + + # 验证 ONNX 模型 + onnx_model = onnx.load(onnx_path) + onnx.checker.check_model(onnx_model) + + size_mb = os.path.getsize(onnx_path) / 1024 / 1024 + print(f"✅ ONNX export successful! Size: {size_mb:.2f} MB") + return True + + +__all__ = ["train_ner_model", "export_to_onnx"] diff --git a/src/basemodel/utils/__init__.py b/src/basemodel/utils/__init__.py new file mode 100644 index 0000000..12a3ef4 --- /dev/null +++ b/src/basemodel/utils/__init__.py @@ -0,0 +1,192 @@ +""" +工具模块 + +提供数据加载、CoNLL 格式处理等工具函数。 +""" + +import os +import glob +from typing import List, Dict, Any, Tuple +from datasets import Dataset +from tqdm.auto import tqdm + + +def word_to_char_labels(text: str, word_labels: List[Tuple[str, str]]) -> List[str]: + """Convert word-level labels to char-level""" + char_labels = ["O"] * len(text) + pos = 0 + + for token, label in word_labels: + if pos >= len(text): + break + + while pos < len(text) and text[pos] != token[0]: + pos += 1 + if pos >= len(text): + break + + if text[pos: pos + len(token)] == token: + for i in range(len(token)): + idx = pos + i + if idx < len(char_labels): + if i == 0 and label.startswith("B-"): + char_labels[idx] = label + elif label.startswith("B-"): + char_labels[idx] = "I" + label[1:] + else: + char_labels[idx] = label + pos += len(token) + else: + pos += 1 + + return char_labels + + +def parse_conll_file(filepath: str) -> List[Dict[str, Any]]: + """Parse .conll → [{"text": str, "char_labels": List[str]}]""" + with open(filepath, "r", encoding="utf-8") as f: + lines = [line.rstrip("\n") for line in f.readlines()] + + # 检测 word-level + is_word_level = any( + len(line.split()[0]) > 1 + for line in lines + if line.strip() and not line.startswith("-DOCSTART-") and len(line.split()) >= 4 + ) + + samples = [] + if is_word_level: + current_text_parts = [] + current_word_labels = [] + + for line in lines: + if not line or line.startswith("-DOCSTART-"): + if current_text_parts: + text = "".join(current_text_parts) + char_labels = word_to_char_labels(text, current_word_labels) + samples.append({"text": text, "char_labels": char_labels}) + current_text_parts = [] + current_word_labels = [] + continue + + parts = line.split() + if len(parts) >= 4: + token, label = parts[0], parts[3] + current_text_parts.append(token) + current_word_labels.append((token, label)) + + if current_text_parts: + text = "".join(current_text_parts) + char_labels = word_to_char_labels(text, current_word_labels) + samples.append({"text": text, "char_labels": char_labels}) + else: + current_text = [] + current_labels = [] + + for line in lines: + if line.startswith("-DOCSTART-"): + if current_text: + samples.append({ + "text": "".join(current_text), + "char_labels": current_labels.copy(), + }) + current_text, current_labels = [], [] + continue + + if not line: + if current_text: + samples.append({ + "text": "".join(current_text), + "char_labels": current_labels.copy(), + }) + current_text, current_labels = [], [] + continue + + parts = line.split() + if len(parts) >= 4: + char = parts[0].replace("\\n", "\n") + label = parts[3] + current_text.append(char) + current_labels.append(label) + + if current_text: + samples.append({ + "text": "".join(current_text), + "char_labels": current_labels.copy(), + }) + + return samples + + +def load_conll_dataset(conll_dir_or_files: str) -> Tuple[Dataset, List[str]]: + """Load .conll files → Dataset""" + filepaths = [] + if os.path.isdir(conll_dir_or_files): + filepaths = sorted(glob.glob(os.path.join(conll_dir_or_files, "*.conll"))) + elif conll_dir_or_files.endswith(".conll"): + filepaths = [conll_dir_or_files] + else: + raise ValueError("conll_dir_or_files must be .conll file or directory") + + if not filepaths: + raise FileNotFoundError(f"No .conll files found in {conll_dir_or_files}") + + print(f"Loading {len(filepaths)} conll files: {filepaths}") + + all_samples = [] + label_set = {"O"} + + for fp in tqdm(filepaths, desc="Parsing .conll"): + samples = parse_conll_file(fp) + for s in samples: + all_samples.append(s) + label_set.update(s["char_labels"]) + + # Build label list + label_list = ["O"] + for label in sorted(label_set - {"O"}): + if label.startswith("B-") or label.startswith("I-"): + label_list.append(label) + for label in list(label_list): + if label.startswith("B-"): + i_label = "I" + label[1:] + if i_label not in label_list: + label_list.append(i_label) + print(f"⚠️ Added missing {i_label} for {label}") + + print(f"✅ Loaded {len(all_samples)} samples, {len(label_list)} labels: {label_list}") + return Dataset.from_list(all_samples), label_list + + +def tokenize_and_align_labels(examples, tokenizer, label2id, max_length=128): + """Tokenize and align labels with tokenizer""" + tokenized = tokenizer( + examples["text"], + truncation=True, + padding=True, + max_length=max_length, + return_offsets_mapping=True, + return_tensors=None, + ) + + labels = [] + for i, label_seq in enumerate(examples["char_labels"]): + offsets = tokenized["offset_mapping"][i] + label_ids = [] + for start, end in offsets: + if start == end: + label_ids.append(-100) + else: + label_ids.append(label2id[label_seq[start]]) + labels.append(label_ids) + + tokenized["labels"] = labels + return tokenized + + +__all__ = [ + "word_to_char_labels", + "parse_conll_file", + "load_conll_dataset", + "tokenize_and_align_labels", +] diff --git a/src/utils/conll_to_dataset.py b/utils/conll_to_dataset.py index 2ea5469..2ea5469 100644 --- a/src/utils/conll_to_dataset.py +++ b/utils/conll_to_dataset.py diff --git a/src/utils/word_conll_to_char_conll.py b/utils/word_conll_to_char_conll.py index e52405f..e52405f 100644 --- a/src/utils/word_conll_to_char_conll.py +++ b/utils/word_conll_to_char_conll.py |
