author     HsiangNianian <i@jyunko.cn>  2025-12-30 19:54:08 +0800
committer  HsiangNianian <i@jyunko.cn>  2025-12-30 19:54:08 +0800
commit     575114661ef9afb95df2a211e1d8498686340e6b (patch)
tree       91f1646cececb1597a9246865e89b52e059d3cfa /src
parent     7ac684f1f82023c6284cd7d7efde11b8dc98c149 (diff)
feat: Refactor and enhance TRPG NER model SDK
- Removed deprecated `word_conll_to_char_conll.py` utility and integrated its functionality into the new `utils` module.
- Introduced a comprehensive GitHub Actions workflow for automated publishing to PyPI and GitHub Releases.
- Added `__init__.py` files to establish package structure for `basemodel`, `inference`, `training`, and `utils` modules.
- Implemented model downloading functionality in `download_model.py` to fetch pre-trained ONNX models.
- Developed `TRPGParser` class for ONNX-based inference, including methods for parsing TRPG logs.
- Created training utilities in `training/__init__.py` for NER model training with Hugging Face Transformers.
- Enhanced utility functions for CoNLL file parsing and dataset creation.
- Added command-line interface for converting CoNLL files to datasets with validation options.
Diffstat (limited to 'src')
-rw-r--r--  src/basemodel/__init__.py               36
-rw-r--r--  src/basemodel/download_model.py         70
-rw-r--r--  src/basemodel/inference/__init__.py    292
-rw-r--r--  src/basemodel/training/__init__.py     205
-rw-r--r--  src/basemodel/utils/__init__.py        192
-rw-r--r--  src/utils/conll_to_dataset.py          263
-rw-r--r--  src/utils/word_conll_to_char_conll.py   55
7 files changed, 795 insertions, 318 deletions
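
For orientation, a minimal usage sketch of the API introduced in this commit, assembled from the docstrings added below. It assumes the released model.onnx and its tokenizer files sit together in the cache directory; the sample log line is only illustrative.

    # Usage sketch only: names come from the modules added in this commit.
    from basemodel import TRPGParser
    from basemodel.download_model import download_model

    # Fetch the released model.onnx into ~/.cache/basemodel/models/trpg-final/ if missing.
    # (The tokenizer config is assumed to live next to the model file.)
    model_path = download_model()

    # Reuse one parser instance; each call to parse() returns {"metadata": ..., "content": ...}.
    parser = TRPGParser(model_path=model_path)
    result = parser.parse("风雨 2024-06-08 21:44:59 剧烈的疼痛...")
    print(result["metadata"], result["content"])

    # Training and ONNX export (needs the optional training extras):
    # from basemodel.training import train_ner_model, export_to_onnx
    # train_ner_model(conll_data="./data", output_dir="./models/trpg-ner-v1")
    # export_to_onnx("./models/trpg-ner-v1", "./models/trpg-final/model.onnx")
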
diff --git a/src/basemodel/__init__.py b/src/basemodel/__init__.py
new file mode 100644
index 0000000..7287df4
--- /dev/null
+++ b/src/basemodel/__init__.py
@@ -0,0 +1,36 @@
+"""
+base-model - HydroRoll TRPG NER 模型 SDK
+
+这是一个用于 TRPG(桌上角色扮演游戏)日志命名实体识别的 Python SDK。
+
+基本用法:
+ >>> from basemodel import TRPGParser
+ >>> parser = TRPGParser()
+ >>> result = parser.parse("风雨 2024-06-08 21:44:59 剧烈的疼痛...")
+ >>> print(result)
+ {'metadata': {'speaker': '风雨', 'timestamp': '2024-06-08 21:44:59'}, 'content': [...]}
+
+训练功能(需要额外安装):
+ >>> pip install base-model[train]
+ >>> from basemodel.training import train_ner_model
+ >>> train_ner_model(conll_data="./data", output_dir="./model")
+"""
+
+from basemodel.inference import TRPGParser, parse_line, parse_lines
+
+try:
+ from importlib.metadata import version
+ __version__ = version("base-model")
+except Exception:
+ __version__ = "0.1.0.dev"
+
+__all__ = [
+ "__version__",
+ "TRPGParser",
+ "parse_line",
+ "parse_lines",
+]
+
+
+def get_version():
+ return __version__
diff --git a/src/basemodel/download_model.py b/src/basemodel/download_model.py
new file mode 100644
index 0000000..2d65099
--- /dev/null
+++ b/src/basemodel/download_model.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+"""
+模型下载脚本
+
+自动下载预训练的 ONNX 模型到用户缓存目录。
+"""
+
+from typing import Optional
+import sys
+from pathlib import Path
+import urllib.request
+
+
+def download_model(
+    output_dir: Optional[str] = None,
+ url: str = "https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx"
+):
+ """
+ 下载 ONNX 模型
+
+ Args:
+ output_dir: 输出目录,默认为 ~/.cache/basemodel/models/trpg-final/
+ url: 模型下载 URL
+ """
+ if output_dir is None:
+ output_dir = Path.home() / ".cache" / "basemodel" / "models" / "trpg-final"
+ else:
+ output_dir = Path(output_dir)
+
+ output_dir.mkdir(parents=True, exist_ok=True)
+ output_path = output_dir / "model.onnx"
+
+ if output_path.exists():
+ print(f"✅ 模型已存在: {output_path}")
+ return str(output_path)
+
+ print(f"📥 正在下载模型到 {output_path}...")
+ print(f" URL: {url}")
+
+ try:
+ urllib.request.urlretrieve(url, output_path)
+ print(f"✅ 模型下载成功!")
+ return str(output_path)
+ except Exception as e:
+ print(f"❌ 下载失败: {e}")
+ print(f" 请手动从以下地址下载模型:")
+ print(f" {url}")
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ import argparse
+
+    parser = argparse.ArgumentParser(description="Download the base-model ONNX model")
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ default=None,
+ help="模型输出目录(默认: ~/.cache/basemodel/models/trpg-final/)"
+ )
+ parser.add_argument(
+ "--url",
+ type=str,
+ default="https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx",
+ help="模型下载 URL"
+ )
+
+ args = parser.parse_args()
+
+ download_model(args.output_dir, args.url)
diff --git a/src/basemodel/inference/__init__.py b/src/basemodel/inference/__init__.py
new file mode 100644
index 0000000..93a185f
--- /dev/null
+++ b/src/basemodel/inference/__init__.py
@@ -0,0 +1,292 @@
+"""
+ONNX 推理模块
+
+提供基于 ONNX 的 TRPG 日志命名实体识别推理功能。
+"""
+
+import os
+from typing import List, Dict, Any, Optional
+from pathlib import Path
+
+try:
+ import numpy as np
+ import onnxruntime as ort
+ from transformers import AutoTokenizer
+except ImportError as e:
+ raise ImportError(
+ "依赖未安装。请运行: pip install onnxruntime transformers numpy"
+ ) from e
+
+
+# Default model path (relative to the package install location)
+DEFAULT_MODEL_DIR = Path(__file__).parent.parent.parent.parent / "models" / "trpg-final"
+# Remote model URL (used for automatic download)
+MODEL_URL = "https://github.com/HydroRoll-Team/base-model/releases/download/v0.1.0/model.onnx"
+
+
+class TRPGParser:
+ """
+ TRPG 日志解析器(基于 ONNX)
+
+ Args:
+ model_path: ONNX 模型路径,默认使用内置模型
+ tokenizer_path: tokenizer 配置路径,默认与 model_path 相同
+ device: 推理设备,"cpu" 或 "cuda"
+
+ Examples:
+ >>> parser = TRPGParser()
+ >>> result = parser.parse("风雨 2024-06-08 21:44:59 剧烈的疼痛...")
+ >>> print(result['metadata']['speaker'])
+ '风雨'
+ """
+
+ def __init__(
+ self,
+ model_path: Optional[str] = None,
+ tokenizer_path: Optional[str] = None,
+ device: str = "cpu",
+ ):
+        # Resolve the model path
+ if model_path is None:
+ model_path = self._get_default_model_path()
+
+ if tokenizer_path is None:
+ tokenizer_path = Path(model_path).parent
+
+ self.model_path = Path(model_path)
+ self.tokenizer_path = Path(tokenizer_path)
+ self.device = device
+
+        # Load the model
+ self._load_model()
+
+ def _get_default_model_path(self) -> str:
+ """获取默认模型路径"""
+ # 1. 尝试相对于项目根目录
+ project_root = Path(__file__).parent.parent.parent.parent
+ local_model = project_root / "models" / "trpg-final" / "model.onnx"
+ if local_model.exists():
+ return str(local_model)
+
+        # 2. Try the user cache directory
+ from pathlib import Path
+ user_model_dir = Path.home() / ".cache" / "basemodel" / "models" / "trpg-final"
+ user_model = user_model_dir / "model.onnx"
+ if user_model.exists():
+ return str(user_model)
+
+        # 3. Not found: raise an error with download instructions
+ raise FileNotFoundError(
+ f"模型文件未找到。请从 {MODEL_URL} 下载模型到 {user_model_dir}\n"
+ f"或运行: python -m basemodel.download_model"
+ )
+
+ def _load_model(self):
+ """加载 ONNX 模型和 Tokenizer"""
+ # 加载 tokenizer
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ str(self.tokenizer_path),
+ local_files_only=True,
+ )
+
+        # Load the ONNX model
+ providers = ["CPUExecutionProvider"]
+ if self.device == "cuda" and "CUDAExecutionProvider" in ort.get_available_providers():
+ providers.insert(0, "CUDAExecutionProvider")
+
+ self.session = ort.InferenceSession(
+ str(self.model_path),
+ providers=providers,
+ )
+
+        # Load the label mapping
+ import json
+ config_path = self.tokenizer_path / "config.json"
+ if config_path.exists():
+ with open(config_path, "r", encoding="utf-8") as f:
+ config = json.load(f)
+ self.id2label = {int(k): v for k, v in config.get("id2label", {}).items()}
+ else:
+            # Fallback default labels
+ self.id2label = {
+ 0: "O", 1: "B-action", 2: "I-action", 3: "B-comment", 4: "I-comment",
+ 5: "B-dialogue", 6: "I-dialogue", 7: "B-speaker", 8: "I-speaker",
+ 9: "B-timestamp", 10: "I-timestamp",
+ }
+
+ def parse(self, text: str) -> Dict[str, Any]:
+ """
+ 解析单条 TRPG 日志
+
+ Args:
+ text: 待解析的日志文本
+
+ Returns:
+ 包含 metadata 和 content 的字典
+ - metadata: speaker, timestamp
+ - content: dialogue, action, comment 列表
+
+ Examples:
+ >>> parser = TRPGParser()
+ >>> result = parser.parse("风雨 2024-06-08 21:44:59 剧烈的疼痛...")
+ >>> result['metadata']['speaker']
+ '风雨'
+ """
+ # Tokenize
+ inputs = self.tokenizer(
+ text,
+ return_tensors="np",
+ return_offsets_mapping=True,
+ padding="max_length",
+ truncation=True,
+ max_length=128,
+ )
+
+        # Run inference
+ outputs = self.session.run(
+ ["logits"],
+ {
+ "input_ids": inputs["input_ids"].astype(np.int64),
+ "attention_mask": inputs["attention_mask"].astype(np.int64),
+ },
+ )
+
+        # Post-process
+ logits = outputs[0][0]
+ predictions = np.argmax(logits, axis=-1)
+ offsets = inputs["offset_mapping"][0]
+
+        # Aggregate entities
+ entities = self._group_entities(predictions, offsets, logits)
+
+        # Build the result
+ result = {"metadata": {}, "content": []}
+ for ent in entities:
+ if ent["start"] >= len(text) or ent["end"] > len(text):
+ continue
+
+ raw_text = text[ent["start"]: ent["end"]]
+ clean_text = self._clean_text(raw_text, ent["type"])
+
+ if not clean_text.strip():
+ continue
+
+ if ent["type"] in ["timestamp", "speaker"]:
+ result["metadata"][ent["type"]] = clean_text
+ elif ent["type"] in ["dialogue", "action", "comment"]:
+ result["content"].append({
+ "type": ent["type"],
+ "content": clean_text,
+ "confidence": round(ent["score"], 3),
+ })
+
+ return result
+
+ def _group_entities(self, predictions, offsets, logits):
+ """将 token 级别的预测聚合为实体"""
+ entities = []
+ current = None
+
+ for i in range(len(predictions)):
+ start, end = offsets[i]
+ if start == end: # special tokens
+ continue
+
+ pred_id = int(predictions[i])
+ label = self.id2label.get(pred_id, "O")
+
+ if label == "O":
+ if current:
+ entities.append(current)
+ current = None
+ continue
+
+ tag_type = label[2:] if len(label) > 2 else "O"
+
+ if label.startswith("B-"):
+ if current:
+ entities.append(current)
+ current = {
+ "type": tag_type,
+ "start": int(start),
+ "end": int(end),
+ "score": float(np.max(logits[i])),
+ }
+ elif label.startswith("I-") and current and current["type"] == tag_type:
+ current["end"] = int(end)
+ else:
+ if current:
+ entities.append(current)
+ current = None
+
+ if current:
+ entities.append(current)
+
+ return entities
+
+ def _clean_text(self, text: str, group: str) -> str:
+ """清理提取的文本"""
+ import re
+
+ text = text.strip()
+
+        # Strip surrounding punctuation
+ if group == "comment":
+ text = re.sub(r"^[((]+|[))]+$", "", text)
+ elif group == "dialogue":
+ text = re.sub(r'^[""''「」『』]+|[""""」』『』]+$', "", text)
+ elif group == "action":
+ text = re.sub(r"^[*#]+|[*#]+$", "", text)
+
+        # Repair truncated timestamps (e.g. "24-06-08" -> "2024-06-08")
+ if group == "timestamp" and text and text[0].isdigit():
+ if len(text) > 2 and text[2] == "-":
+ text = "20" + text
+
+ return text
+
+ def parse_batch(self, texts: List[str]) -> List[Dict[str, Any]]:
+ """
+ 批量解析多条日志
+
+ Args:
+ texts: 日志文本列表
+
+ Returns:
+ 解析结果列表
+ """
+ return [self.parse(text) for text in texts]
+
+
+# Convenience functions
+def parse_line(text: str, model_path: Optional[str] = None) -> Dict[str, Any]:
+ """
+ 解析单条日志的便捷函数
+
+ Args:
+ text: 日志文本
+ model_path: 可选的模型路径
+
+ Returns:
+ 解析结果字典
+ """
+ parser = TRPGParser(model_path=model_path)
+ return parser.parse(text)
+
+
+def parse_lines(texts: List[str], model_path: Optional[str] = None) -> List[Dict[str, Any]]:
+ """
+ 批量解析日志的便捷函数
+
+ Args:
+ texts: 日志文本列表
+ model_path: 可选的模型路径
+
+ Returns:
+ 解析结果列表
+ """
+ parser = TRPGParser(model_path=model_path)
+ return parser.parse_batch(texts)
+
+
+__all__ = ["TRPGParser", "parse_line", "parse_lines"]
diff --git a/src/basemodel/training/__init__.py b/src/basemodel/training/__init__.py
new file mode 100644
index 0000000..5671c42
--- /dev/null
+++ b/src/basemodel/training/__init__.py
@@ -0,0 +1,205 @@
+"""
+训练模块
+
+提供 TRPG NER 模型训练功能。
+
+注意: 使用此模块需要安装训练依赖:
+ pip install base-model-trpgner[train]
+"""
+
+import os
+from typing import Optional, List
+from pathlib import Path
+
+
+def train_ner_model(
+ conll_data: str,
+ model_name_or_path: str = "hfl/minirbt-h256",
+ output_dir: str = "./models/trpg-ner-v1",
+ num_train_epochs: int = 20,
+ per_device_train_batch_size: int = 4,
+ learning_rate: float = 5e-5,
+ max_length: int = 128,
+ resume_from_checkpoint: Optional[str] = None,
+) -> None:
+ """
+ 训练 NER 模型
+
+ Args:
+ conll_data: CoNLL 格式数据文件或目录
+ model_name_or_path: 基础模型名称或路径
+ output_dir: 模型输出目录
+ num_train_epochs: 训练轮数
+ per_device_train_batch_size: 批处理大小
+ learning_rate: 学习率
+ max_length: 最大序列长度
+ resume_from_checkpoint: 恢复检查点路径
+
+ Examples:
+ >>> from basemodel.training import train_ner_model
+ >>> train_ner_model(
+ ... conll_data="./data",
+ ... output_dir="./my_model",
+ ... epochs=10
+ ... )
+ """
+ try:
+ import torch
+ from transformers import (
+ AutoTokenizer,
+ AutoModelForTokenClassification,
+ TrainingArguments,
+ Trainer,
+ )
+ from datasets import Dataset
+ from tqdm.auto import tqdm
+ except ImportError as e:
+ raise ImportError(
+ "训练依赖未安装。请运行: pip install base-model-trpgner[train]"
+ ) from e
+
+    # Import the data-processing helpers
+    from basemodel.utils import load_conll_dataset, tokenize_and_align_labels
+
+ print(f"🚀 Starting training...")
+
+    # Load the data
+ dataset, label_list = load_conll_dataset(conll_data)
+ label2id = {label: i for i, label in enumerate(label_list)}
+ id2label = {i: label for i, label in enumerate(label_list)}
+
+    # Initialize the model
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+ if tokenizer.model_max_length > 1000:
+ tokenizer.model_max_length = max_length
+
+ model = AutoModelForTokenClassification.from_pretrained(
+ model_name_or_path,
+ num_labels=len(label_list),
+ id2label=id2label,
+ label2id=label2id,
+ ignore_mismatched_sizes=True,
+ )
+
+ # Tokenize
+ tokenized_dataset = dataset.map(
+ lambda ex: tokenize_and_align_labels(ex, tokenizer, label2id, max_length),
+ batched=True,
+ remove_columns=["text", "char_labels"],
+ )
+
+    # Training arguments
+ training_args = TrainingArguments(
+ output_dir=output_dir,
+ learning_rate=learning_rate,
+ per_device_train_batch_size=per_device_train_batch_size,
+ num_train_epochs=num_train_epochs,
+ logging_steps=5,
+ save_steps=200,
+ save_total_limit=2,
+ do_eval=False,
+ report_to="none",
+ no_cuda=not torch.cuda.is_available(),
+ load_best_model_at_end=False,
+ push_to_hub=False,
+ fp16=torch.cuda.is_available(),
+ )
+
+ trainer = Trainer(
+ model=model,
+ args=training_args,
+ train_dataset=tokenized_dataset,
+ tokenizer=tokenizer,
+ )
+
+    # Start training
+ print("🚀 Starting training...")
+ trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+
+    # Save the model
+ print("💾 Saving final model...")
+ trainer.save_model(output_dir)
+ tokenizer.save_pretrained(output_dir)
+
+ print(f"✅ Training finished. Model saved to {output_dir}")
+
+
+def export_to_onnx(
+ model_dir: str,
+ onnx_path: str,
+ max_length: int = 128,
+) -> bool:
+ """
+ 将训练好的模型导出为 ONNX 格式
+
+ Args:
+ model_dir: 模型目录
+ onnx_path: ONNX 输出路径
+ max_length: 最大序列长度
+
+ Returns:
+ 是否成功
+ """
+ try:
+ import torch
+ from torch.onnx import export as onnx_export
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
+ import onnx
+ except ImportError as e:
+ raise ImportError(
+ "ONNX 导出依赖未安装。请运行: pip install onnx"
+ ) from e
+
+ print(f"📤 Exporting model from {model_dir} to {onnx_path}...")
+
+ model_dir = os.path.abspath(model_dir)
+ if not os.path.exists(model_dir):
+ raise FileNotFoundError(f"Model directory not found: {model_dir}")
+
+    # Load the model
+ tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True)
+ model = AutoModelForTokenClassification.from_pretrained(
+ model_dir, local_files_only=True
+ )
+ model.eval()
+
+    # Create a dummy input for tracing
+ dummy_text = "莎莎 2024-06-08 21:46:26"
+ inputs = tokenizer(
+ dummy_text,
+ return_tensors="pt",
+ padding="max_length",
+ truncation=True,
+ max_length=max_length,
+ )
+
+    # Make sure the output directory exists (fall back to "." for bare filenames)
+    os.makedirs(os.path.dirname(onnx_path) or ".", exist_ok=True)
+
+    # Export to ONNX
+ onnx_export(
+ model,
+ (inputs["input_ids"], inputs["attention_mask"]),
+ onnx_path,
+ export_params=True,
+ opset_version=18,
+ do_constant_folding=True,
+ input_names=["input_ids", "attention_mask"],
+ output_names=["logits"],
+ dynamic_axes={
+ "input_ids": {0: "batch_size", 1: "sequence_length"},
+ "attention_mask": {0: "batch_size", 1: "sequence_length"},
+ "logits": {0: "batch_size", 1: "sequence_length"},
+ },
+ )
+
+    # Validate the exported ONNX model
+ onnx_model = onnx.load(onnx_path)
+ onnx.checker.check_model(onnx_model)
+
+ size_mb = os.path.getsize(onnx_path) / 1024 / 1024
+ print(f"✅ ONNX export successful! Size: {size_mb:.2f} MB")
+ return True
+
+
+__all__ = ["train_ner_model", "export_to_onnx"]
diff --git a/src/basemodel/utils/__init__.py b/src/basemodel/utils/__init__.py
new file mode 100644
index 0000000..12a3ef4
--- /dev/null
+++ b/src/basemodel/utils/__init__.py
@@ -0,0 +1,192 @@
+"""
+工具模块
+
+提供数据加载、CoNLL 格式处理等工具函数。
+"""
+
+import os
+import glob
+from typing import List, Dict, Any, Tuple
+from datasets import Dataset
+from tqdm.auto import tqdm
+
+
+def word_to_char_labels(text: str, word_labels: List[Tuple[str, str]]) -> List[str]:
+ """Convert word-level labels to char-level"""
+ char_labels = ["O"] * len(text)
+ pos = 0
+
+ for token, label in word_labels:
+ if pos >= len(text):
+ break
+
+ while pos < len(text) and text[pos] != token[0]:
+ pos += 1
+ if pos >= len(text):
+ break
+
+ if text[pos: pos + len(token)] == token:
+ for i in range(len(token)):
+ idx = pos + i
+ if idx < len(char_labels):
+ if i == 0 and label.startswith("B-"):
+ char_labels[idx] = label
+ elif label.startswith("B-"):
+ char_labels[idx] = "I" + label[1:]
+ else:
+ char_labels[idx] = label
+ pos += len(token)
+ else:
+ pos += 1
+
+ return char_labels
+
+
+def parse_conll_file(filepath: str) -> List[Dict[str, Any]]:
+ """Parse .conll → [{"text": str, "char_labels": List[str]}]"""
+ with open(filepath, "r", encoding="utf-8") as f:
+ lines = [line.rstrip("\n") for line in f.readlines()]
+
+    # Detect word-level annotations
+ is_word_level = any(
+ len(line.split()[0]) > 1
+ for line in lines
+ if line.strip() and not line.startswith("-DOCSTART-") and len(line.split()) >= 4
+ )
+
+ samples = []
+ if is_word_level:
+ current_text_parts = []
+ current_word_labels = []
+
+ for line in lines:
+ if not line or line.startswith("-DOCSTART-"):
+ if current_text_parts:
+ text = "".join(current_text_parts)
+ char_labels = word_to_char_labels(text, current_word_labels)
+ samples.append({"text": text, "char_labels": char_labels})
+ current_text_parts = []
+ current_word_labels = []
+ continue
+
+ parts = line.split()
+ if len(parts) >= 4:
+ token, label = parts[0], parts[3]
+ current_text_parts.append(token)
+ current_word_labels.append((token, label))
+
+ if current_text_parts:
+ text = "".join(current_text_parts)
+ char_labels = word_to_char_labels(text, current_word_labels)
+ samples.append({"text": text, "char_labels": char_labels})
+ else:
+ current_text = []
+ current_labels = []
+
+ for line in lines:
+ if line.startswith("-DOCSTART-"):
+ if current_text:
+ samples.append({
+ "text": "".join(current_text),
+ "char_labels": current_labels.copy(),
+ })
+ current_text, current_labels = [], []
+ continue
+
+ if not line:
+ if current_text:
+ samples.append({
+ "text": "".join(current_text),
+ "char_labels": current_labels.copy(),
+ })
+ current_text, current_labels = [], []
+ continue
+
+ parts = line.split()
+ if len(parts) >= 4:
+ char = parts[0].replace("\\n", "\n")
+ label = parts[3]
+ current_text.append(char)
+ current_labels.append(label)
+
+ if current_text:
+ samples.append({
+ "text": "".join(current_text),
+ "char_labels": current_labels.copy(),
+ })
+
+ return samples
+
+
+def load_conll_dataset(conll_dir_or_files: str) -> Tuple[Dataset, List[str]]:
+ """Load .conll files → Dataset"""
+ filepaths = []
+ if os.path.isdir(conll_dir_or_files):
+ filepaths = sorted(glob.glob(os.path.join(conll_dir_or_files, "*.conll")))
+ elif conll_dir_or_files.endswith(".conll"):
+ filepaths = [conll_dir_or_files]
+ else:
+ raise ValueError("conll_dir_or_files must be .conll file or directory")
+
+ if not filepaths:
+ raise FileNotFoundError(f"No .conll files found in {conll_dir_or_files}")
+
+ print(f"Loading {len(filepaths)} conll files: {filepaths}")
+
+ all_samples = []
+ label_set = {"O"}
+
+ for fp in tqdm(filepaths, desc="Parsing .conll"):
+ samples = parse_conll_file(fp)
+ for s in samples:
+ all_samples.append(s)
+ label_set.update(s["char_labels"])
+
+ # Build label list
+ label_list = ["O"]
+ for label in sorted(label_set - {"O"}):
+ if label.startswith("B-") or label.startswith("I-"):
+ label_list.append(label)
+ for label in list(label_list):
+ if label.startswith("B-"):
+ i_label = "I" + label[1:]
+ if i_label not in label_list:
+ label_list.append(i_label)
+ print(f"⚠️ Added missing {i_label} for {label}")
+
+ print(f"✅ Loaded {len(all_samples)} samples, {len(label_list)} labels: {label_list}")
+ return Dataset.from_list(all_samples), label_list
+
+
+def tokenize_and_align_labels(examples, tokenizer, label2id, max_length=128):
+ """Tokenize and align labels with tokenizer"""
+ tokenized = tokenizer(
+ examples["text"],
+ truncation=True,
+ padding=True,
+ max_length=max_length,
+ return_offsets_mapping=True,
+ return_tensors=None,
+ )
+
+ labels = []
+ for i, label_seq in enumerate(examples["char_labels"]):
+ offsets = tokenized["offset_mapping"][i]
+ label_ids = []
+ for start, end in offsets:
+ if start == end:
+ label_ids.append(-100)
+ else:
+ label_ids.append(label2id[label_seq[start]])
+ labels.append(label_ids)
+
+ tokenized["labels"] = labels
+ return tokenized
+
+
+__all__ = [
+ "word_to_char_labels",
+ "parse_conll_file",
+ "load_conll_dataset",
+ "tokenize_and_align_labels",
+]
diff --git a/src/utils/conll_to_dataset.py b/src/utils/conll_to_dataset.py
deleted file mode 100644
index 2ea5469..0000000
--- a/src/utils/conll_to_dataset.py
+++ /dev/null
@@ -1,263 +0,0 @@
-"""
-TRPG CoNLL 转 Dataset 工具
-- 自动检测 word-level / char-level
-- 生成 {"text": str, "char_labels": List[str]}
-- 支持多文档、跨行实体
-"""
-
-import os
-import re
-import json
-import argparse
-from pathlib import Path
-from typing import List, Dict, Any, Tuple
-from datasets import Dataset
-
-def word_to_char_labels(text: str, word_labels: List[Tuple[str, str]]) -> List[str]:
- """
- 将 word-level 标注转为 char-level labels
- Args:
- text: 原始文本 (e.g., "风雨 2024-06-08")
- word_labels: [("风雨", "B-speaker"), ("2024-06-08", "B-timestamp"), ...]
- Returns:
- char_labels: ["B-speaker", "I-speaker", "O", "B-timestamp", ...]
- """
- char_labels = ["O"] * len(text)
- pos = 0
-
- for token, label in word_labels:
- if pos >= len(text):
- break
-
-        # Locate the token in the text (handles spaces/newlines)
- while pos < len(text) and text[pos] != token[0]:
- pos += 1
- if pos >= len(text):
- break
-
-        # Match the token
-        if text[pos:pos+len(token)] == token:
-            # Tag B/I
- for i, char in enumerate(token):
- idx = pos + i
- if idx < len(char_labels):
- if i == 0 and label.startswith("B-"):
- char_labels[idx] = label
- elif label.startswith("B-"):
- char_labels[idx] = "I" + label[1:]
- else:
- char_labels[idx] = label
- pos += len(token)
- else:
- pos += 1
-
- return char_labels
-
-def parse_conll_to_samples(filepath: str) -> List[Dict[str, Any]]:
- """
- 解析 .conll → [{"text": "...", "char_labels": [...]}, ...]
- 自动处理:
- - -DOCSTART- 文档边界
- - 空行句子边界
- - word-level → char-level 转换
- """
- samples = []
-    current_lines = []  # keep the raw lines to detect the annotation granularity
-
- with open(filepath, 'r', encoding='utf-8') as f:
- for line in f:
- current_lines.append(line.rstrip('\n'))
-
-    # Detect whether the file is word-level
- is_word_level = False
- for line in current_lines:
- if line.strip() and not line.startswith("-DOCSTART-"):
- parts = line.split()
- if len(parts) >= 4:
- token = parts[0]
-                # Token longer than one char and not punctuation → probably word-level
- if len(token) > 1 and not re.match(r'^[^\w\s\u4e00-\u9fff]+$', token):
- is_word_level = True
- break
-
- if is_word_level:
- print(f"Detected word-level CoNLL, converting to char-level...")
- return _parse_word_conll(filepath)
- else:
- print(f"Detected char-level CoNLL, parsing directly...")
- return _parse_char_conll(filepath)
-
-def _parse_word_conll(filepath: str) -> List[Dict[str, Any]]:
- """解析 word-level .conll(如您提供的原始格式)"""
- samples = []
- current_text_parts = []
- current_word_labels = []
-
- with open(filepath, 'r', encoding='utf-8') as f:
- for line in f:
- line = line.strip()
- if not line or line.startswith("-DOCSTART-"):
- if current_text_parts:
-                # Join the text parts
-                text = "".join(current_text_parts)
-                # Generate char-level labels
- char_labels = word_to_char_labels(text, current_word_labels)
- samples.append({
- "text": text,
- "char_labels": char_labels
- })
- current_text_parts = []
- current_word_labels = []
- continue
-
- parts = line.split()
- if len(parts) < 4:
- continue
-
- token, label = parts[0], parts[3]
- current_text_parts.append(token)
- current_word_labels.append((token, label))
-
-    # Handle the trailing sample
- if current_text_parts:
- text = "".join(current_text_parts)
- char_labels = word_to_char_labels(text, current_word_labels)
- samples.append({
- "text": text,
- "char_labels": char_labels
- })
-
- return samples
-
-def _parse_char_conll(filepath: str) -> List[Dict[str, Any]]:
- """解析 char-level .conll"""
- samples = []
- current_text = []
- current_labels = []
-
- with open(filepath, 'r', encoding='utf-8') as f:
- for line in f:
- line = line.rstrip('\n')
- if line.startswith("-DOCSTART-"):
- if current_text:
- samples.append({
- "text": "".join(current_text),
- "char_labels": current_labels.copy()
- })
- current_text, current_labels = [], []
- continue
-
- if not line:
- if current_text:
- samples.append({
- "text": "".join(current_text),
- "char_labels": current_labels.copy()
- })
- current_text, current_labels = [], []
- continue
-
- parts = line.split()
- if len(parts) < 4:
- continue
-
- char = parts[0].replace("\\n", "\n")
- label = parts[3]
- current_text.append(char)
- current_labels.append(label)
-
-    # Handle the trailing sample
- if current_text:
- samples.append({
- "text": "".join(current_text),
- "char_labels": current_labels.copy()
- })
-
- return samples
-
-def save_dataset(samples: List[Dict[str, Any]], output_path: str, format: str = "jsonl"):
- """保存数据集"""
- Path(output_path).parent.mkdir(parents=True, exist_ok=True)
-
- if format == "jsonl":
- with open(output_path, 'w', encoding='utf-8') as f:
- for sample in samples:
- f.write(json.dumps(sample, ensure_ascii=False) + '\n')
- print(f"Saved {len(samples)} samples to {output_path} (JSONL)")
-
- elif format == "dataset":
- dataset = Dataset.from_list(samples)
- dataset.save_to_disk(output_path)
- print(f"Saved {len(samples)} samples to {output_path} (Hugging Face Dataset)")
-
- elif format == "both":
- jsonl_path = output_path + ".jsonl"
- with open(jsonl_path, 'w', encoding='utf-8') as f:
- for sample in samples:
- f.write(json.dumps(sample, ensure_ascii=False) + '\n')
- print(f"Saved JSONL to {jsonl_path}")
-
- dataset_path = output_path + "_dataset"
- dataset = Dataset.from_list(samples)
- dataset.save_to_disk(dataset_path)
- print(f"Saved Dataset to {dataset_path}")
-
-def validate_samples(samples: List[Dict[str, Any]]) -> bool:
- """验证样本一致性"""
- for i, sample in enumerate(samples):
- if len(sample["text"]) != len(sample["char_labels"]):
- print(f"Sample {i}: text len={len(sample['text'])}, labels len={len(sample['char_labels'])}")
- return False
- print(f"All {len(samples)} samples validated: text & labels length match")
- return True
-
-def main():
- parser = argparse.ArgumentParser(description="Convert CoNLL to TRPG Dataset")
- parser.add_argument("input", type=str, help="Input .conll file or directory")
- parser.add_argument("--output", type=str, default="./dataset/trpg",
- help="Output path (without extension)")
- parser.add_argument("--format", choices=["jsonl", "dataset", "both"],
- default="jsonl", help="Output format")
- parser.add_argument("--validate", action="store_true",
- help="Validate samples after conversion")
-
- args = parser.parse_args()
-
- filepaths = []
- if os.path.isdir(args.input):
- filepaths = sorted(Path(args.input).glob("*.conll"))
- elif args.input.endswith(".conll"):
- filepaths = [Path(args.input)]
- else:
- raise ValueError("Input must be .conll file or directory")
-
- if not filepaths:
- raise FileNotFoundError(f"No .conll files found in {args.input}")
-
- print(f"Processing {len(filepaths)} files: {[f.name for f in filepaths]}")
-
- all_samples = []
- for fp in filepaths:
- print(f"\nProcessing {fp.name}...")
- samples = parse_conll_to_samples(str(fp))
- print(f" → {len(samples)} samples")
- all_samples.extend(samples)
-
- print(f"\nTotal: {len(all_samples)} samples")
-
- if args.validate:
- if not validate_samples(all_samples):
- exit(1)
-
- save_dataset(all_samples, args.output, args.format)
-
- label_counts = {}
- for sample in all_samples:
- for label in sample["char_labels"]:
- label_counts[label] = label_counts.get(label, 0) + 1
-
- print("\nLabel distribution:")
- for label in sorted(label_counts.keys()):
- print(f" {label}: {label_counts[label]}")
-
-if __name__ == "__main__":
-    main()
\ No newline at end of file
diff --git a/src/utils/word_conll_to_char_conll.py b/src/utils/word_conll_to_char_conll.py
deleted file mode 100644
index e52405f..0000000
--- a/src/utils/word_conll_to_char_conll.py
+++ /dev/null
@@ -1,55 +0,0 @@
-def word_conll_to_char_conll(word_conll_lines: list[str]) -> list[str]:
- char_lines = []
-    in_new_sample = True  # whether the next line should start a new sample
-
- for line in word_conll_lines:
- stripped = line.strip()
- if not stripped:
-            # Blank line → mark the next sentence as a new sample
- in_new_sample = True
- char_lines.append("")
- continue
-
- parts = stripped.split()
- if len(parts) < 4:
- char_lines.append(line.rstrip())
- continue
-
- token, label = parts[0], parts[3]
-
-        # Detect a new utterance: a B-speaker tag starts a new sample
- if label == "B-speaker" and in_new_sample:
- char_lines.append("-DOCSTART- -X- O")
- in_new_sample = False
-
-        # Convert token → char labels (same as above)
- if label == "O":
- for c in token:
- char_lines.append(f"{c} -X- _ O")
- else:
- bio_prefix = label[:2]
- tag = label[2:]
- for i, c in enumerate(token):
- char_label = f"B-{tag}" if (bio_prefix == "B-" and i == 0) else f"I-{tag}"
- char_lines.append(f"{c} -X- _ {char_label}")
-
- return char_lines
-
-if __name__ == "__main__":
- import sys
- if len(sys.argv) < 3:
- print("Usage: python word_conll_to_char_conll.py <input_word.conll> <output_char.conll>")
- sys.exit(1)
-
- input_fp = sys.argv[1]
- output_fp = sys.argv[2]
-
- with open(input_fp, "r", encoding="utf-8") as f:
- word_conll_lines = f.readlines()
-
- char_conll_lines = word_conll_to_char_conll(word_conll_lines)
-
- with open(output_fp, "w", encoding="utf-8") as f:
- f.write("\n".join(char_conll_lines) + "\n")
-
- print(f"Converted {input_fp} to character-level CoNLL format at {output_fp}") \ No newline at end of file