summaryrefslogtreecommitdiffstatshomepage
path: root/src/utils
diff options
context:
space:
mode:
authorHsiangNianian <i@jyunko.cn>2025-12-30 19:54:08 +0800
committerHsiangNianian <i@jyunko.cn>2025-12-30 19:54:08 +0800
commit575114661ef9afb95df2a211e1d8498686340e6b (patch)
tree91f1646cececb1597a9246865e89b52e059d3cfa /src/utils
parent7ac684f1f82023c6284cd7d7efde11b8dc98c149 (diff)
downloadbase-model-575114661ef9afb95df2a211e1d8498686340e6b.tar.gz
base-model-575114661ef9afb95df2a211e1d8498686340e6b.zip
feat: Refactor and enhance TRPG NER model SDK
- Removed deprecated `word_conll_to_char_conll.py` utility and integrated its functionality into the new `utils` module. - Introduced a comprehensive GitHub Actions workflow for automated publishing to PyPI and GitHub Releases. - Added `__init__.py` files to establish package structure for `basemodel`, `inference`, `training`, and `utils` modules. - Implemented model downloading functionality in `download_model.py` to fetch pre-trained ONNX models. - Developed `TRPGParser` class for ONNX-based inference, including methods for parsing TRPG logs. - Created training utilities in `training/__init__.py` for NER model training with Hugging Face Transformers. - Enhanced utility functions for CoNLL file parsing and dataset creation. - Added command-line interface for converting CoNLL files to datasets with validation options.
Diffstat (limited to 'src/utils')
-rw-r--r--src/utils/conll_to_dataset.py263
-rw-r--r--src/utils/word_conll_to_char_conll.py55
2 files changed, 0 insertions, 318 deletions
diff --git a/src/utils/conll_to_dataset.py b/src/utils/conll_to_dataset.py
deleted file mode 100644
index 2ea5469..0000000
--- a/src/utils/conll_to_dataset.py
+++ /dev/null
@@ -1,263 +0,0 @@
-"""
-TRPG CoNLL 转 Dataset 工具
-- 自动检测 word-level / char-level
-- 生成 {"text": str, "char_labels": List[str]}
-- 支持多文档、跨行实体
-"""
-
-import os
-import re
-import json
-import argparse
-from pathlib import Path
-from typing import List, Dict, Any, Tuple
-from datasets import Dataset
-
-def word_to_char_labels(text: str, word_labels: List[Tuple[str, str]]) -> List[str]:
- """
- 将 word-level 标注转为 char-level labels
- Args:
- text: 原始文本 (e.g., "风雨 2024-06-08")
- word_labels: [("风雨", "B-speaker"), ("2024-06-08", "B-timestamp"), ...]
- Returns:
- char_labels: ["B-speaker", "I-speaker", "O", "B-timestamp", ...]
- """
- char_labels = ["O"] * len(text)
- pos = 0
-
- for token, label in word_labels:
- if pos >= len(text):
- break
-
- # 在文本中定位 token(处理空格/换行)
- while pos < len(text) and text[pos] != token[0]:
- pos += 1
- if pos >= len(text):
- break
-
- # 匹配 token
- if text[pos:pos+len(token)] == token:
- # 标注 B/I
- for i, char in enumerate(token):
- idx = pos + i
- if idx < len(char_labels):
- if i == 0 and label.startswith("B-"):
- char_labels[idx] = label
- elif label.startswith("B-"):
- char_labels[idx] = "I" + label[1:]
- else:
- char_labels[idx] = label
- pos += len(token)
- else:
- pos += 1
-
- return char_labels
-
-def parse_conll_to_samples(filepath: str) -> List[Dict[str, Any]]:
- """
- 解析 .conll → [{"text": "...", "char_labels": [...]}, ...]
- 自动处理:
- - -DOCSTART- 文档边界
- - 空行句子边界
- - word-level → char-level 转换
- """
- samples = []
- current_lines = [] # 存储原始行用于检测粒度
-
- with open(filepath, 'r', encoding='utf-8') as f:
- for line in f:
- current_lines.append(line.rstrip('\n'))
-
- # 检测是否 word-level
- is_word_level = False
- for line in current_lines:
- if line.strip() and not line.startswith("-DOCSTART-"):
- parts = line.split()
- if len(parts) >= 4:
- token = parts[0]
- # 如果 token 长度 >1 且非标点 → 可能是 word-level
- if len(token) > 1 and not re.match(r'^[^\w\s\u4e00-\u9fff]+$', token):
- is_word_level = True
- break
-
- if is_word_level:
- print(f"Detected word-level CoNLL, converting to char-level...")
- return _parse_word_conll(filepath)
- else:
- print(f"Detected char-level CoNLL, parsing directly...")
- return _parse_char_conll(filepath)
-
-def _parse_word_conll(filepath: str) -> List[Dict[str, Any]]:
- """解析 word-level .conll(如您提供的原始格式)"""
- samples = []
- current_text_parts = []
- current_word_labels = []
-
- with open(filepath, 'r', encoding='utf-8') as f:
- for line in f:
- line = line.strip()
- if not line or line.startswith("-DOCSTART-"):
- if current_text_parts:
- # 合并文本
- text = "".join(current_text_parts)
- # 生成 char-level labels
- char_labels = word_to_char_labels(text, current_word_labels)
- samples.append({
- "text": text,
- "char_labels": char_labels
- })
- current_text_parts = []
- current_word_labels = []
- continue
-
- parts = line.split()
- if len(parts) < 4:
- continue
-
- token, label = parts[0], parts[3]
- current_text_parts.append(token)
- current_word_labels.append((token, label))
-
- # 处理末尾
- if current_text_parts:
- text = "".join(current_text_parts)
- char_labels = word_to_char_labels(text, current_word_labels)
- samples.append({
- "text": text,
- "char_labels": char_labels
- })
-
- return samples
-
-def _parse_char_conll(filepath: str) -> List[Dict[str, Any]]:
- """解析 char-level .conll"""
- samples = []
- current_text = []
- current_labels = []
-
- with open(filepath, 'r', encoding='utf-8') as f:
- for line in f:
- line = line.rstrip('\n')
- if line.startswith("-DOCSTART-"):
- if current_text:
- samples.append({
- "text": "".join(current_text),
- "char_labels": current_labels.copy()
- })
- current_text, current_labels = [], []
- continue
-
- if not line:
- if current_text:
- samples.append({
- "text": "".join(current_text),
- "char_labels": current_labels.copy()
- })
- current_text, current_labels = [], []
- continue
-
- parts = line.split()
- if len(parts) < 4:
- continue
-
- char = parts[0].replace("\\n", "\n")
- label = parts[3]
- current_text.append(char)
- current_labels.append(label)
-
- # 末尾处理
- if current_text:
- samples.append({
- "text": "".join(current_text),
- "char_labels": current_labels.copy()
- })
-
- return samples
-
-def save_dataset(samples: List[Dict[str, Any]], output_path: str, format: str = "jsonl"):
- """保存数据集"""
- Path(output_path).parent.mkdir(parents=True, exist_ok=True)
-
- if format == "jsonl":
- with open(output_path, 'w', encoding='utf-8') as f:
- for sample in samples:
- f.write(json.dumps(sample, ensure_ascii=False) + '\n')
- print(f"Saved {len(samples)} samples to {output_path} (JSONL)")
-
- elif format == "dataset":
- dataset = Dataset.from_list(samples)
- dataset.save_to_disk(output_path)
- print(f"Saved {len(samples)} samples to {output_path} (Hugging Face Dataset)")
-
- elif format == "both":
- jsonl_path = output_path + ".jsonl"
- with open(jsonl_path, 'w', encoding='utf-8') as f:
- for sample in samples:
- f.write(json.dumps(sample, ensure_ascii=False) + '\n')
- print(f"Saved JSONL to {jsonl_path}")
-
- dataset_path = output_path + "_dataset"
- dataset = Dataset.from_list(samples)
- dataset.save_to_disk(dataset_path)
- print(f"Saved Dataset to {dataset_path}")
-
-def validate_samples(samples: List[Dict[str, Any]]) -> bool:
- """验证样本一致性"""
- for i, sample in enumerate(samples):
- if len(sample["text"]) != len(sample["char_labels"]):
- print(f"Sample {i}: text len={len(sample['text'])}, labels len={len(sample['char_labels'])}")
- return False
- print(f"All {len(samples)} samples validated: text & labels length match")
- return True
-
-def main():
- parser = argparse.ArgumentParser(description="Convert CoNLL to TRPG Dataset")
- parser.add_argument("input", type=str, help="Input .conll file or directory")
- parser.add_argument("--output", type=str, default="./dataset/trpg",
- help="Output path (without extension)")
- parser.add_argument("--format", choices=["jsonl", "dataset", "both"],
- default="jsonl", help="Output format")
- parser.add_argument("--validate", action="store_true",
- help="Validate samples after conversion")
-
- args = parser.parse_args()
-
- filepaths = []
- if os.path.isdir(args.input):
- filepaths = sorted(Path(args.input).glob("*.conll"))
- elif args.input.endswith(".conll"):
- filepaths = [Path(args.input)]
- else:
- raise ValueError("Input must be .conll file or directory")
-
- if not filepaths:
- raise FileNotFoundError(f"No .conll files found in {args.input}")
-
- print(f"Processing {len(filepaths)} files: {[f.name for f in filepaths]}")
-
- all_samples = []
- for fp in filepaths:
- print(f"\nProcessing {fp.name}...")
- samples = parse_conll_to_samples(str(fp))
- print(f" → {len(samples)} samples")
- all_samples.extend(samples)
-
- print(f"\nTotal: {len(all_samples)} samples")
-
- if args.validate:
- if not validate_samples(all_samples):
- exit(1)
-
- save_dataset(all_samples, args.output, args.format)
-
- label_counts = {}
- for sample in all_samples:
- for label in sample["char_labels"]:
- label_counts[label] = label_counts.get(label, 0) + 1
-
- print("\nLabel distribution:")
- for label in sorted(label_counts.keys()):
- print(f" {label}: {label_counts[label]}")
-
-if __name__ == "__main__":
- main() \ No newline at end of file
diff --git a/src/utils/word_conll_to_char_conll.py b/src/utils/word_conll_to_char_conll.py
deleted file mode 100644
index e52405f..0000000
--- a/src/utils/word_conll_to_char_conll.py
+++ /dev/null
@@ -1,55 +0,0 @@
-def word_conll_to_char_conll(word_conll_lines: list[str]) -> list[str]:
- char_lines = []
- in_new_sample = True # 下一行是否应视为新样本开始
-
- for line in word_conll_lines:
- stripped = line.strip()
- if not stripped:
- # 空行 → 标记下一句为新样本
- in_new_sample = True
- char_lines.append("")
- continue
-
- parts = stripped.split()
- if len(parts) < 4:
- char_lines.append(line.rstrip())
- continue
-
- token, label = parts[0], parts[3]
-
- # 检测新发言:B-speaker 出现 → 新样本
- if label == "B-speaker" and in_new_sample:
- char_lines.append("-DOCSTART- -X- O")
- in_new_sample = False
-
- # 转换 token → char labels(同前)
- if label == "O":
- for c in token:
- char_lines.append(f"{c} -X- _ O")
- else:
- bio_prefix = label[:2]
- tag = label[2:]
- for i, c in enumerate(token):
- char_label = f"B-{tag}" if (bio_prefix == "B-" and i == 0) else f"I-{tag}"
- char_lines.append(f"{c} -X- _ {char_label}")
-
- return char_lines
-
-if __name__ == "__main__":
- import sys
- if len(sys.argv) < 3:
- print("Usage: python word_conll_to_char_conll.py <input_word.conll> <output_char.conll>")
- sys.exit(1)
-
- input_fp = sys.argv[1]
- output_fp = sys.argv[2]
-
- with open(input_fp, "r", encoding="utf-8") as f:
- word_conll_lines = f.readlines()
-
- char_conll_lines = word_conll_to_char_conll(word_conll_lines)
-
- with open(output_fp, "w", encoding="utf-8") as f:
- f.write("\n".join(char_conll_lines) + "\n")
-
- print(f"Converted {input_fp} to character-level CoNLL format at {output_fp}") \ No newline at end of file