From 575114661ef9afb95df2a211e1d8498686340e6b Mon Sep 17 00:00:00 2001
From: HsiangNianian
Date: Tue, 30 Dec 2025 19:54:08 +0800
Subject: feat: Refactor and enhance TRPG NER model SDK

- Removed the deprecated `word_conll_to_char_conll.py` utility and integrated
  its functionality into the new `utils` module.
- Introduced a comprehensive GitHub Actions workflow for automated publishing
  to PyPI and GitHub Releases.
- Added `__init__.py` files to establish the package structure for the
  `basemodel`, `inference`, `training`, and `utils` modules.
- Implemented model downloading in `download_model.py` to fetch pre-trained
  ONNX models.
- Developed the `TRPGParser` class for ONNX-based inference, including methods
  for parsing TRPG logs.
- Created training utilities in `training/__init__.py` for NER model training
  with Hugging Face Transformers.
- Enhanced utility functions for CoNLL file parsing and dataset creation.
- Added a command-line interface for converting CoNLL files to datasets, with
  validation options.
---
 src/utils/conll_to_dataset.py         | 263 ----------------------------------
 src/utils/word_conll_to_char_conll.py |  55 -------
 2 files changed, 318 deletions(-)
 delete mode 100644 src/utils/conll_to_dataset.py
 delete mode 100644 src/utils/word_conll_to_char_conll.py
(limited to 'src/utils')
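[Note: for orientation, a minimal sketch of how the refactored SDK described
above is meant to fit together. The module paths and call signatures
(`download_model`, `TRPGParser.parse`) are assumptions inferred from the
commit message; that code is not part of this diff and may differ.]

# Hypothetical end-to-end usage of the refactored SDK (names assumed, see note above).
from inference.download_model import download_model  # assumed: fetches a pre-trained ONNX model
from inference import TRPGParser                      # assumed: ONNX-based inference wrapper

model_path = download_model()            # assumed to return a local path to the .onnx file
trpg_parser = TRPGParser(model_path)
entities = trpg_parser.parse("风雨 2024-06-08 22:15:33 我推门而入。")
print(entities)  # expected shape: spans labeled speaker / timestamp / etc.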
diff --git a/src/utils/conll_to_dataset.py b/src/utils/conll_to_dataset.py
deleted file mode 100644
index 2ea5469..0000000
--- a/src/utils/conll_to_dataset.py
+++ /dev/null
@@ -1,263 +0,0 @@
-"""
-TRPG CoNLL to Dataset conversion tool
-- Auto-detects word-level vs. char-level input
-- Produces {"text": str, "char_labels": List[str]}
-- Supports multiple documents and entities spanning lines
-"""
-
-import os
-import re
-import json
-import argparse
-from pathlib import Path
-from typing import List, Dict, Any, Tuple
-from datasets import Dataset
-
-def word_to_char_labels(text: str, word_labels: List[Tuple[str, str]]) -> List[str]:
-    """
-    Convert word-level annotations to char-level labels.
-    Args:
-        text: original text (e.g., "风雨 2024-06-08")
-        word_labels: [("风雨", "B-speaker"), ("2024-06-08", "B-timestamp"), ...]
-    Returns:
-        char_labels: ["B-speaker", "I-speaker", "O", "B-timestamp", ...]
-    """
-    char_labels = ["O"] * len(text)
-    pos = 0
-
-    for token, label in word_labels:
-        if pos >= len(text):
-            break
-
-        # Locate the token in the text (skipping spaces/newlines)
-        while pos < len(text) and text[pos] != token[0]:
-            pos += 1
-        if pos >= len(text):
-            break
-
-        # Token found at this position
-        if text[pos:pos+len(token)] == token:
-            # Assign B-/I- labels
-            for i, char in enumerate(token):
-                idx = pos + i
-                if idx < len(char_labels):
-                    if i == 0 and label.startswith("B-"):
-                        char_labels[idx] = label
-                    elif label.startswith("B-"):
-                        char_labels[idx] = "I" + label[1:]
-                    else:
-                        char_labels[idx] = label
-            pos += len(token)
-        else:
-            pos += 1
-
-    return char_labels
-
-def parse_conll_to_samples(filepath: str) -> List[Dict[str, Any]]:
-    """
-    Parse a .conll file into [{"text": "...", "char_labels": [...]}, ...]
-    Handles automatically:
-    - -DOCSTART- document boundaries
-    - blank-line sentence boundaries
-    - word-level to char-level conversion
-    """
-    samples = []
-    current_lines = []  # keep raw lines to detect annotation granularity
-
-    with open(filepath, 'r', encoding='utf-8') as f:
-        for line in f:
-            current_lines.append(line.rstrip('\n'))
-
-    # Detect whether the file is word-level
-    is_word_level = False
-    for line in current_lines:
-        if line.strip() and not line.startswith("-DOCSTART-"):
-            parts = line.split()
-            if len(parts) >= 4:
-                token = parts[0]
-                # A token longer than one char that is not pure punctuation suggests word-level
-                if len(token) > 1 and not re.match(r'^[^\w\s\u4e00-\u9fff]+$', token):
-                    is_word_level = True
-                    break
-
-    if is_word_level:
-        print("Detected word-level CoNLL, converting to char-level...")
-        return _parse_word_conll(filepath)
-    else:
-        print("Detected char-level CoNLL, parsing directly...")
-        return _parse_char_conll(filepath)
-
-def _parse_word_conll(filepath: str) -> List[Dict[str, Any]]:
-    """Parse a word-level .conll file (the original raw annotation format)."""
-    samples = []
-    current_text_parts = []
-    current_word_labels = []
-
-    with open(filepath, 'r', encoding='utf-8') as f:
-        for line in f:
-            line = line.strip()
-            if not line or line.startswith("-DOCSTART-"):
-                if current_text_parts:
-                    # Join the accumulated text
-                    text = "".join(current_text_parts)
-                    # Generate char-level labels
-                    char_labels = word_to_char_labels(text, current_word_labels)
-                    samples.append({
-                        "text": text,
-                        "char_labels": char_labels
-                    })
-                    current_text_parts = []
-                    current_word_labels = []
-                continue
-
-            parts = line.split()
-            if len(parts) < 4:
-                continue
-
-            token, label = parts[0], parts[3]
-            current_text_parts.append(token)
-            current_word_labels.append((token, label))
-
-    # Flush the trailing sample
-    if current_text_parts:
-        text = "".join(current_text_parts)
-        char_labels = word_to_char_labels(text, current_word_labels)
-        samples.append({
-            "text": text,
-            "char_labels": char_labels
-        })
-
-    return samples
-
-def _parse_char_conll(filepath: str) -> List[Dict[str, Any]]:
-    """Parse a char-level .conll file."""
-    samples = []
-    current_text = []
-    current_labels = []
-
-    with open(filepath, 'r', encoding='utf-8') as f:
-        for line in f:
-            line = line.rstrip('\n')
-            if line.startswith("-DOCSTART-"):
-                if current_text:
-                    samples.append({
-                        "text": "".join(current_text),
-                        "char_labels": current_labels.copy()
-                    })
-                    current_text, current_labels = [], []
-                continue
-
-            if not line:
-                if current_text:
-                    samples.append({
-                        "text": "".join(current_text),
-                        "char_labels": current_labels.copy()
-                    })
-                    current_text, current_labels = [], []
-                continue
-
-            parts = line.split()
-            if len(parts) < 4:
-                continue
-
-            char = parts[0].replace("\\n", "\n")
-            label = parts[3]
-            current_text.append(char)
-            current_labels.append(label)
-
-    # Flush the trailing sample
-    if current_text:
-        samples.append({
-            "text": "".join(current_text),
-            "char_labels": current_labels.copy()
-        })
-
-    return samples
-
-def save_dataset(samples: List[Dict[str, Any]], output_path: str, format: str = "jsonl"):
-    """Save the dataset in the requested format."""
-    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
-
-    if format == "jsonl":
-        with open(output_path, 'w', encoding='utf-8') as f:
-            for sample in samples:
-                f.write(json.dumps(sample, ensure_ascii=False) + '\n')
-        print(f"Saved {len(samples)} samples to {output_path} (JSONL)")
-
-    elif format == "dataset":
-        dataset = Dataset.from_list(samples)
-        dataset.save_to_disk(output_path)
-        print(f"Saved {len(samples)} samples to {output_path} (Hugging Face Dataset)")
-
-    elif format == "both":
-        jsonl_path = output_path + ".jsonl"
-        with open(jsonl_path, 'w', encoding='utf-8') as f:
-            for sample in samples:
-                f.write(json.dumps(sample, ensure_ascii=False) + '\n')
-        print(f"Saved JSONL to {jsonl_path}")
-
-        dataset_path = output_path + "_dataset"
-        dataset = Dataset.from_list(samples)
-        dataset.save_to_disk(dataset_path)
-        print(f"Saved Dataset to {dataset_path}")
-
-def validate_samples(samples: List[Dict[str, Any]]) -> bool:
-    """Check that every sample's text and labels have matching lengths."""
-    for i, sample in enumerate(samples):
-        if len(sample["text"]) != len(sample["char_labels"]):
-            print(f"Sample {i}: text len={len(sample['text'])}, labels len={len(sample['char_labels'])}")
-            return False
-    print(f"All {len(samples)} samples validated: text & labels length match")
-    return True
-
-def main():
-    parser = argparse.ArgumentParser(description="Convert CoNLL to TRPG Dataset")
-    parser.add_argument("input", type=str, help="Input .conll file or directory")
-    parser.add_argument("--output", type=str, default="./dataset/trpg",
-                        help="Output path (without extension)")
-    parser.add_argument("--format", choices=["jsonl", "dataset", "both"],
-                        default="jsonl", help="Output format")
-    parser.add_argument("--validate", action="store_true",
-                        help="Validate samples after conversion")
-
-    args = parser.parse_args()
-
-    filepaths = []
-    if os.path.isdir(args.input):
-        filepaths = sorted(Path(args.input).glob("*.conll"))
-    elif args.input.endswith(".conll"):
-        filepaths = [Path(args.input)]
-    else:
-        raise ValueError("Input must be a .conll file or a directory")
-
-    if not filepaths:
-        raise FileNotFoundError(f"No .conll files found in {args.input}")
-
-    print(f"Processing {len(filepaths)} files: {[f.name for f in filepaths]}")
-
-    all_samples = []
-    for fp in filepaths:
-        print(f"\nProcessing {fp.name}...")
-        samples = parse_conll_to_samples(str(fp))
-        print(f"  → {len(samples)} samples")
-        all_samples.extend(samples)
-
-    print(f"\nTotal: {len(all_samples)} samples")
-
-    if args.validate:
-        if not validate_samples(all_samples):
-            exit(1)
-
-    save_dataset(all_samples, args.output, args.format)
-
-    label_counts = {}
-    for sample in all_samples:
-        for label in sample["char_labels"]:
-            label_counts[label] = label_counts.get(label, 0) + 1
-
-    print("\nLabel distribution:")
-    for label in sorted(label_counts.keys()):
-        print(f"  {label}: {label_counts[label]}")
-
-if __name__ == "__main__":
-    main()
\ No newline at end of file
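[Note: for reference, a short sketch of how the removed conll_to_dataset.py
was driven programmatically; the input filename is illustrative and the import
assumes the old module is still on the path. The equivalent CLI call was:
    python src/utils/conll_to_dataset.py logs/ --output ./dataset/trpg --format both --validate]

# Minimal programmatic use of the removed module (input filename hypothetical).
from conll_to_dataset import parse_conll_to_samples, validate_samples, save_dataset

samples = parse_conll_to_samples("session01.conll")  # auto-detects word- vs. char-level
if validate_samples(samples):                        # text/char_labels lengths must match
    save_dataset(samples, "./dataset/trpg", format="both")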
"r", encoding="utf-8") as f: - word_conll_lines = f.readlines() - - char_conll_lines = word_conll_to_char_conll(word_conll_lines) - - with open(output_fp, "w", encoding="utf-8") as f: - f.write("\n".join(char_conll_lines) + "\n") - - print(f"Converted {input_fp} to character-level CoNLL format at {output_fp}") \ No newline at end of file -- cgit v1.2.3-70-g09d2