Diffstat (limited to 'src/base_model_trpgner/utils')
-rw-r--r--  src/base_model_trpgner/utils/__init__.py  192
1 file changed, 192 insertions(+), 0 deletions(-)
diff --git a/src/base_model_trpgner/utils/__init__.py b/src/base_model_trpgner/utils/__init__.py
new file mode 100644
index 0000000..12a3ef4
--- /dev/null
+++ b/src/base_model_trpgner/utils/__init__.py
@@ -0,0 +1,192 @@
+"""
+Utility module.
+
+Provides utility functions for data loading, CoNLL-format processing, and related tasks.
+"""
+
+import os
+import glob
+from typing import List, Dict, Any, Tuple
+from datasets import Dataset
+from tqdm.auto import tqdm
+
+
+def word_to_char_labels(text: str, word_labels: List[Tuple[str, str]]) -> List[str]:
+ """Convert word-level labels to char-level"""
+ char_labels = ["O"] * len(text)
+ pos = 0
+
+ for token, label in word_labels:
+ if pos >= len(text):
+ break
+
+ while pos < len(text) and text[pos] != token[0]:
+ pos += 1
+ if pos >= len(text):
+ break
+
+ if text[pos: pos + len(token)] == token:
+ for i in range(len(token)):
+ idx = pos + i
+ if idx < len(char_labels):
+ if i == 0 and label.startswith("B-"):
+ char_labels[idx] = label
+ elif label.startswith("B-"):
+ char_labels[idx] = "I" + label[1:]
+ else:
+ char_labels[idx] = label
+ pos += len(token)
+ else:
+ pos += 1
+
+ return char_labels
+
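+# Example of the alignment (a hypothetical illustration, not data from this repo):
+#
+#   >>> word_to_char_labels("张三去北京", [("张三", "B-PER"), ("去", "O"), ("北京", "B-LOC")])
+#   ['B-PER', 'I-PER', 'O', 'B-LOC', 'I-LOC']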
+
+def parse_conll_file(filepath: str) -> List[Dict[str, Any]]:
+ """Parse .conll → [{"text": str, "char_labels": List[str]}]"""
+    with open(filepath, "r", encoding="utf-8") as f:
+        lines = [line.rstrip("\n") for line in f]
+
+    # Detect word-level format: any first-column token longer than one character.
+    # The "\\n" escape used by char-level files to encode newlines is excluded.
+    is_word_level = any(
+        len(line.split()[0]) > 1 and line.split()[0] != "\\n"
+        for line in lines
+        if line.strip() and not line.startswith("-DOCSTART-") and len(line.split()) >= 4
+    )
+
+ samples = []
+ if is_word_level:
+ current_text_parts = []
+ current_word_labels = []
+
+ for line in lines:
+ if not line or line.startswith("-DOCSTART-"):
+ if current_text_parts:
+ text = "".join(current_text_parts)
+ char_labels = word_to_char_labels(text, current_word_labels)
+ samples.append({"text": text, "char_labels": char_labels})
+ current_text_parts = []
+ current_word_labels = []
+ continue
+
+ parts = line.split()
+ if len(parts) >= 4:
+ token, label = parts[0], parts[3]
+ current_text_parts.append(token)
+ current_word_labels.append((token, label))
+
+ if current_text_parts:
+ text = "".join(current_text_parts)
+ char_labels = word_to_char_labels(text, current_word_labels)
+ samples.append({"text": text, "char_labels": char_labels})
+ else:
+ current_text = []
+ current_labels = []
+
+ for line in lines:
+ if line.startswith("-DOCSTART-"):
+ if current_text:
+ samples.append({
+ "text": "".join(current_text),
+ "char_labels": current_labels.copy(),
+ })
+ current_text, current_labels = [], []
+ continue
+
+ if not line:
+ if current_text:
+ samples.append({
+ "text": "".join(current_text),
+ "char_labels": current_labels.copy(),
+ })
+ current_text, current_labels = [], []
+ continue
+
+ parts = line.split()
+ if len(parts) >= 4:
+                char = parts[0].replace("\\n", "\n")  # restore escaped newlines
+ label = parts[3]
+ current_text.append(char)
+ current_labels.append(label)
+
+ if current_text:
+ samples.append({
+ "text": "".join(current_text),
+ "char_labels": current_labels.copy(),
+ })
+
+ return samples
+
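+# Input format sketch (inferred from the parser above; column values other than
+# the first and fourth are illustrative placeholders):
+#
+#   张 _ _ B-PER
+#   三 _ _ I-PER
+#   去 _ _ O
+#
+# Four or more whitespace-separated columns per line, surface form in column 1
+# and BIO label in column 4; blank lines and -DOCSTART- lines delimit samples,
+# and char-level files encode newlines as the "\\n" escape.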
+
+def load_conll_dataset(conll_dir_or_files: str) -> Tuple[Dataset, List[str]]:
+ """Load .conll files → Dataset"""
+    if os.path.isdir(conll_dir_or_files):
+        filepaths = sorted(glob.glob(os.path.join(conll_dir_or_files, "*.conll")))
+    elif conll_dir_or_files.endswith(".conll"):
+        filepaths = [conll_dir_or_files]
+    else:
+        raise ValueError("conll_dir_or_files must be a .conll file or a directory")
+
+ if not filepaths:
+ raise FileNotFoundError(f"No .conll files found in {conll_dir_or_files}")
+
+ print(f"Loading {len(filepaths)} conll files: {filepaths}")
+
+ all_samples = []
+ label_set = {"O"}
+
+ for fp in tqdm(filepaths, desc="Parsing .conll"):
+ samples = parse_conll_file(fp)
+ for s in samples:
+ all_samples.append(s)
+ label_set.update(s["char_labels"])
+
+    # Build the label list: "O" first, then sorted B-/I- labels,
+    # adding an I- counterpart for any B- label that lacks one.
+ label_list = ["O"]
+ for label in sorted(label_set - {"O"}):
+ if label.startswith("B-") or label.startswith("I-"):
+ label_list.append(label)
+ for label in list(label_list):
+ if label.startswith("B-"):
+ i_label = "I" + label[1:]
+ if i_label not in label_list:
+ label_list.append(i_label)
+ print(f"⚠️ Added missing {i_label} for {label}")
+
+ print(f"✅ Loaded {len(all_samples)} samples, {len(label_list)} labels: {label_list}")
+ return Dataset.from_list(all_samples), label_list
+
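+# Usage sketch (the path is a placeholder, not a file shipped with this repo):
+#
+#   dataset, label_list = load_conll_dataset("data/conll/")
+#   label2id = {label: i for i, label in enumerate(label_list)}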
+
+def tokenize_and_align_labels(examples, tokenizer, label2id, max_length=128):
+ """Tokenize and align labels with tokenizer"""
+ tokenized = tokenizer(
+ examples["text"],
+ truncation=True,
+ padding=True,
+ max_length=max_length,
+        return_offsets_mapping=True,  # requires a fast (Rust-backed) tokenizer
+ return_tensors=None,
+ )
+
+ labels = []
+ for i, label_seq in enumerate(examples["char_labels"]):
+ offsets = tokenized["offset_mapping"][i]
+ label_ids = []
+        for start, end in offsets:
+            if start == end:
+                # Special and padding tokens have empty (0, 0) offsets: mask from the loss.
+                label_ids.append(-100)
+            else:
+                # Multi-character tokens take the label of their first character.
+                label_ids.append(label2id[label_seq[start]])
+ labels.append(label_ids)
+
+ tokenized["labels"] = labels
+ return tokenized
+
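+# Usage sketch with datasets.Dataset.map; the checkpoint name is an assumption,
+# any fast tokenizer that supports offset mapping works:
+#
+#   from transformers import AutoTokenizer
+#   tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
+#   tokenized = dataset.map(
+#       lambda batch: tokenize_and_align_labels(batch, tokenizer, label2id),
+#       batched=True,
+#       remove_columns=dataset.column_names,
+#   )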
+
+__all__ = [
+ "word_to_char_labels",
+ "parse_conll_file",
+ "load_conll_dataset",
+ "tokenize_and_align_labels",
+]