author     HsiangNianian <i@jyunko.cn>  2026-01-05 14:39:21 +0800
committer  HsiangNianian <i@jyunko.cn>  2026-01-05 14:40:15 +0800
commit     f51ebaf36593dffb066ad3c4f7f98a0827d8f8e9 (patch)
tree       9b78e55197bd89f09d8faf21f7d839b9aacfc2d5
parent     c8072765f7abb4bb3fce88a074b0c8f5a25e7ece (diff)
download   base-model-f51ebaf36593dffb066ad3c4f7f98a0827d8f8e9.tar.gz
           base-model-f51ebaf36593dffb066ad3c4f7f98a0827d8f8e9.zip

feat: improve code formatting and readability in conll_to_dataset.py

-rw-r--r--  utils/conll_to_dataset.py | 142
 1 file changed, 71 insertions(+), 71 deletions(-)
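
Note: the diff below is a pure formatting pass — single quotes become double quotes, slice expressions gain spaces, trailing whitespace is stripped, multi-line literals are collapsed or re-wrapped, and top-level definitions get two blank lines between them. These changes are consistent with an auto-formatter run; a command along the following lines would reproduce them (the use of black and the 100-character line length are inferred from the wrapped lines, not stated in the commit):

    black --line-length 100 utils/conll_to_dataset.py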
diff --git a/utils/conll_to_dataset.py b/utils/conll_to_dataset.py
index 2ea5469..fc868c5 100644
--- a/utils/conll_to_dataset.py
+++ b/utils/conll_to_dataset.py
@@ -13,6 +13,7 @@ from pathlib import Path
 from typing import List, Dict, Any, Tuple
 from datasets import Dataset
+
 def word_to_char_labels(text: str, word_labels: List[Tuple[str, str]]) -> List[str]:
     """
     Convert word-level annotations into char-level labels
@@ -24,19 +25,19 @@ def word_to_char_labels(text: str, word_labels: List[Tuple[str, str]]) -> List[str]:
"""
char_labels = ["O"] * len(text)
pos = 0
-
+
for token, label in word_labels:
if pos >= len(text):
break
-
+
# 在文本中定位 token(处理空格/换行)
while pos < len(text) and text[pos] != token[0]:
pos += 1
if pos >= len(text):
break
-
+
# 匹配 token
- if text[pos:pos+len(token)] == token:
+ if text[pos : pos + len(token)] == token:
# 标注 B/I
for i, char in enumerate(token):
idx = pos + i
@@ -50,9 +51,10 @@ def word_to_char_labels(text: str, word_labels: List[Tuple[str, str]]) -> List[str]:
             pos += len(token)
         else:
             pos += 1
-
+
     return char_labels
 
+
 def parse_conll_to_samples(filepath: str) -> List[Dict[str, Any]]:
     """
     Parse .conll → [{"text": "...", "char_labels": [...]}, ...]
@@ -63,11 +65,11 @@ def parse_conll_to_samples(filepath: str) -> List[Dict[str, Any]]:
"""
samples = []
current_lines = [] # 存储原始行用于检测粒度
-
- with open(filepath, 'r', encoding='utf-8') as f:
+
+ with open(filepath, "r", encoding="utf-8") as f:
for line in f:
- current_lines.append(line.rstrip('\n'))
-
+ current_lines.append(line.rstrip("\n"))
+
# 检测是否 word-level
is_word_level = False
for line in current_lines:
@@ -76,10 +78,10 @@ def parse_conll_to_samples(filepath: str) -> List[Dict[str, Any]]:
         if len(parts) >= 4:
             token = parts[0]
             # Token longer than one char and not pure punctuation → probably word-level
-            if len(token) > 1 and not re.match(r'^[^\w\s\u4e00-\u9fff]+$', token):
+            if len(token) > 1 and not re.match(r"^[^\w\s\u4e00-\u9fff]+$", token):
                 is_word_level = True
                 break
-
+
     if is_word_level:
         print(f"Detected word-level CoNLL, converting to char-level...")
         return _parse_word_conll(filepath)
@@ -87,13 +89,14 @@ def parse_conll_to_samples(filepath: str) -> List[Dict[str, Any]]:
print(f"Detected char-level CoNLL, parsing directly...")
return _parse_char_conll(filepath)
+
def _parse_word_conll(filepath: str) -> List[Dict[str, Any]]:
"""解析 word-level .conll(如您提供的原始格式)"""
samples = []
current_text_parts = []
current_word_labels = []
-
- with open(filepath, 'r', encoding='utf-8') as f:
+
+ with open(filepath, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line or line.startswith("-DOCSTART-"):
@@ -102,126 +105,122 @@ def _parse_word_conll(filepath: str) -> List[Dict[str, Any]]:
text = "".join(current_text_parts)
# 生成 char-level labels
char_labels = word_to_char_labels(text, current_word_labels)
- samples.append({
- "text": text,
- "char_labels": char_labels
- })
+ samples.append({"text": text, "char_labels": char_labels})
current_text_parts = []
current_word_labels = []
continue
-
+
parts = line.split()
if len(parts) < 4:
continue
-
+
token, label = parts[0], parts[3]
current_text_parts.append(token)
current_word_labels.append((token, label))
-
+
# 处理末尾
if current_text_parts:
text = "".join(current_text_parts)
char_labels = word_to_char_labels(text, current_word_labels)
- samples.append({
- "text": text,
- "char_labels": char_labels
- })
-
+ samples.append({"text": text, "char_labels": char_labels})
+
return samples
+
 def _parse_char_conll(filepath: str) -> List[Dict[str, Any]]:
     """Parse char-level .conll"""
     samples = []
     current_text = []
     current_labels = []
-
-    with open(filepath, 'r', encoding='utf-8') as f:
+
+    with open(filepath, "r", encoding="utf-8") as f:
         for line in f:
-            line = line.rstrip('\n')
+            line = line.rstrip("\n")
             if line.startswith("-DOCSTART-"):
                 if current_text:
-                    samples.append({
-                        "text": "".join(current_text),
-                        "char_labels": current_labels.copy()
-                    })
+                    samples.append(
+                        {"text": "".join(current_text), "char_labels": current_labels.copy()}
+                    )
                     current_text, current_labels = [], []
                 continue
-
+
             if not line:
                 if current_text:
-                    samples.append({
-                        "text": "".join(current_text),
-                        "char_labels": current_labels.copy()
-                    })
+                    samples.append(
+                        {"text": "".join(current_text), "char_labels": current_labels.copy()}
+                    )
                     current_text, current_labels = [], []
                 continue
-
+
             parts = line.split()
             if len(parts) < 4:
                 continue
-
+
             char = parts[0].replace("\\n", "\n")
             label = parts[3]
             current_text.append(char)
             current_labels.append(label)
-
+
     # Handle any trailing sample
     if current_text:
-        samples.append({
-            "text": "".join(current_text),
-            "char_labels": current_labels.copy()
-        })
-
+        samples.append({"text": "".join(current_text), "char_labels": current_labels.copy()})
+
     return samples
+
 def save_dataset(samples: List[Dict[str, Any]], output_path: str, format: str = "jsonl"):
     """Save the dataset"""
     Path(output_path).parent.mkdir(parents=True, exist_ok=True)
-
+
     if format == "jsonl":
-        with open(output_path, 'w', encoding='utf-8') as f:
+        with open(output_path, "w", encoding="utf-8") as f:
             for sample in samples:
-                f.write(json.dumps(sample, ensure_ascii=False) + '\n')
+                f.write(json.dumps(sample, ensure_ascii=False) + "\n")
         print(f"Saved {len(samples)} samples to {output_path} (JSONL)")
-
+
     elif format == "dataset":
         dataset = Dataset.from_list(samples)
         dataset.save_to_disk(output_path)
         print(f"Saved {len(samples)} samples to {output_path} (Hugging Face Dataset)")
-
+
     elif format == "both":
         jsonl_path = output_path + ".jsonl"
-        with open(jsonl_path, 'w', encoding='utf-8') as f:
+        with open(jsonl_path, "w", encoding="utf-8") as f:
             for sample in samples:
-                f.write(json.dumps(sample, ensure_ascii=False) + '\n')
+                f.write(json.dumps(sample, ensure_ascii=False) + "\n")
         print(f"Saved JSONL to {jsonl_path}")
-
+
         dataset_path = output_path + "_dataset"
         dataset = Dataset.from_list(samples)
         dataset.save_to_disk(dataset_path)
         print(f"Saved Dataset to {dataset_path}")
+
 def validate_samples(samples: List[Dict[str, Any]]) -> bool:
     """Validate sample consistency"""
     for i, sample in enumerate(samples):
         if len(sample["text"]) != len(sample["char_labels"]):
-            print(f"Sample {i}: text len={len(sample['text'])}, labels len={len(sample['char_labels'])}")
+            print(
+                f"Sample {i}: text len={len(sample['text'])}, labels len={len(sample['char_labels'])}"
+            )
             return False
     print(f"All {len(samples)} samples validated: text & labels length match")
     return True
+
 def main():
     parser = argparse.ArgumentParser(description="Convert CoNLL to TRPG Dataset")
     parser.add_argument("input", type=str, help="Input .conll file or directory")
-    parser.add_argument("--output", type=str, default="./dataset/trpg",
-                        help="Output path (without extension)")
-    parser.add_argument("--format", choices=["jsonl", "dataset", "both"],
-                        default="jsonl", help="Output format")
-    parser.add_argument("--validate", action="store_true",
-                        help="Validate samples after conversion")
-
+    parser.add_argument(
+        "--output", type=str, default="./dataset/trpg", help="Output path (without extension)"
+    )
+    parser.add_argument(
+        "--format", choices=["jsonl", "dataset", "both"], default="jsonl", help="Output format"
+    )
+    parser.add_argument("--validate", action="store_true", help="Validate samples after conversion")
+
     args = parser.parse_args()
-
+
     filepaths = []
     if os.path.isdir(args.input):
         filepaths = sorted(Path(args.input).glob("*.conll"))
@@ -229,35 +228,36 @@ def main():
         filepaths = [Path(args.input)]
     else:
         raise ValueError("Input must be .conll file or directory")
-
+
     if not filepaths:
         raise FileNotFoundError(f"No .conll files found in {args.input}")
-
+
     print(f"Processing {len(filepaths)} files: {[f.name for f in filepaths]}")
-
+
     all_samples = []
     for fp in filepaths:
         print(f"\nProcessing {fp.name}...")
         samples = parse_conll_to_samples(str(fp))
         print(f" → {len(samples)} samples")
         all_samples.extend(samples)
-
+
     print(f"\nTotal: {len(all_samples)} samples")
-
+
     if args.validate:
         if not validate_samples(all_samples):
             exit(1)
-
+
     save_dataset(all_samples, args.output, args.format)
-
+
     label_counts = {}
     for sample in all_samples:
         for label in sample["char_labels"]:
             label_counts[label] = label_counts.get(label, 0) + 1
-
+
     print("\nLabel distribution:")
     for label in sorted(label_counts.keys()):
         print(f" {label}: {label_counts[label]}")
+
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
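
For context, a minimal sketch of how the reformatted module behaves (not part of the commit; the sample text, label names, and import path are hypothetical, and it assumes the repository root is on sys.path):

    # Hypothetical usage sketch for word_to_char_labels: word-level (token, label)
    # pairs are projected onto per-character B/I labels over the joined text.
    from utils.conll_to_dataset import word_to_char_labels

    text = "调查员进入阿卡姆"  # made-up sample: "The investigator enters Arkham"
    word_labels = [("调查员", "B-ROLE"), ("进入", "O"), ("阿卡姆", "B-LOC")]

    # Given the B/I tagging logic in word_to_char_labels, this should print:
    # ['B-ROLE', 'I-ROLE', 'I-ROLE', 'O', 'O', 'B-LOC', 'I-LOC', 'I-LOC']
    print(word_to_char_labels(text, word_labels))

The script is normally driven from the command line; based on the argparse setup in the diff, an invocation would look like the following (the input path is illustrative):

    python utils/conll_to_dataset.py data/annotations --output ./dataset/trpg --format both --validate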