Diffstat (limited to 'src/basemodel/utils')
| -rw-r--r-- | src/basemodel/utils/__init__.py | 192 |
1 file changed, 192 insertions, 0 deletions
diff --git a/src/basemodel/utils/__init__.py b/src/basemodel/utils/__init__.py
new file mode 100644
index 0000000..12a3ef4
--- /dev/null
+++ b/src/basemodel/utils/__init__.py
@@ -0,0 +1,192 @@
+"""
+Utility module.
+
+Provides helper functions for data loading, CoNLL-format processing, and related tasks.
+"""
+
+import os
+import glob
+from typing import List, Dict, Any, Tuple
+from datasets import Dataset
+from tqdm.auto import tqdm
+
+
+def word_to_char_labels(text: str, word_labels: List[Tuple[str, str]]) -> List[str]:
+    """Convert word-level labels to char-level."""
+    char_labels = ["O"] * len(text)
+    pos = 0
+
+    for token, label in word_labels:
+        if pos >= len(text):
+            break
+
+        while pos < len(text) and text[pos] != token[0]:
+            pos += 1
+        if pos >= len(text):
+            break
+
+        if text[pos: pos + len(token)] == token:
+            for i in range(len(token)):
+                idx = pos + i
+                if idx < len(char_labels):
+                    if i == 0 and label.startswith("B-"):
+                        char_labels[idx] = label
+                    elif label.startswith("B-"):
+                        char_labels[idx] = "I" + label[1:]
+                    else:
+                        char_labels[idx] = label
+            pos += len(token)
+        else:
+            pos += 1
+
+    return char_labels
+
+
+def parse_conll_file(filepath: str) -> List[Dict[str, Any]]:
+    """Parse .conll → [{"text": str, "char_labels": List[str]}]"""
+    with open(filepath, "r", encoding="utf-8") as f:
+        lines = [line.rstrip("\n") for line in f.readlines()]
+
+    # Detect word-level annotation
+    is_word_level = any(
+        len(line.split()[0]) > 1
+        for line in lines
+        if line.strip() and not line.startswith("-DOCSTART-") and len(line.split()) >= 4
+    )
+
+    samples = []
+    if is_word_level:
+        current_text_parts = []
+        current_word_labels = []
+
+        for line in lines:
+            if not line or line.startswith("-DOCSTART-"):
+                if current_text_parts:
+                    text = "".join(current_text_parts)
+                    char_labels = word_to_char_labels(text, current_word_labels)
+                    samples.append({"text": text, "char_labels": char_labels})
+                current_text_parts = []
+                current_word_labels = []
+                continue
+
+            parts = line.split()
+            if len(parts) >= 4:
+                token, label = parts[0], parts[3]
+                current_text_parts.append(token)
+                current_word_labels.append((token, label))
+
+        if current_text_parts:
+            text = "".join(current_text_parts)
+            char_labels = word_to_char_labels(text, current_word_labels)
+            samples.append({"text": text, "char_labels": char_labels})
+    else:
+        current_text = []
+        current_labels = []
+
+        for line in lines:
+            if line.startswith("-DOCSTART-"):
+                if current_text:
+                    samples.append({
+                        "text": "".join(current_text),
+                        "char_labels": current_labels.copy(),
+                    })
+                current_text, current_labels = [], []
+                continue
+
+            if not line:
+                if current_text:
+                    samples.append({
+                        "text": "".join(current_text),
+                        "char_labels": current_labels.copy(),
+                    })
+                current_text, current_labels = [], []
+                continue
+
+            parts = line.split()
+            if len(parts) >= 4:
+                char = parts[0].replace("\\n", "\n")
+                label = parts[3]
+                current_text.append(char)
+                current_labels.append(label)
+
+        if current_text:
+            samples.append({
+                "text": "".join(current_text),
+                "char_labels": current_labels.copy(),
+            })
+
+    return samples
+
+
+def load_conll_dataset(conll_dir_or_files: str) -> Tuple[Dataset, List[str]]:
+    """Load .conll files → Dataset"""
+    filepaths = []
+    if os.path.isdir(conll_dir_or_files):
+        filepaths = sorted(glob.glob(os.path.join(conll_dir_or_files, "*.conll")))
+    elif conll_dir_or_files.endswith(".conll"):
+        filepaths = [conll_dir_or_files]
+    else:
+        raise ValueError("conll_dir_or_files must be a .conll file or a directory")
+
+    if not filepaths:
+        raise FileNotFoundError(f"No .conll files found in {conll_dir_or_files}")
+
+    print(f"Loading {len(filepaths)} conll files: {filepaths}")
+
+    all_samples = []
+    label_set = {"O"}
+
+    for fp in tqdm(filepaths, desc="Parsing .conll"):
+        samples = parse_conll_file(fp)
+        for s in samples:
+            all_samples.append(s)
+            label_set.update(s["char_labels"])
+
+    # Build label list
+    label_list = ["O"]
+    for label in sorted(label_set - {"O"}):
+        if label.startswith("B-") or label.startswith("I-"):
+            label_list.append(label)
+    for label in list(label_list):
+        if label.startswith("B-"):
+            i_label = "I" + label[1:]
+            if i_label not in label_list:
+                label_list.append(i_label)
+                print(f"⚠️ Added missing {i_label} for {label}")
+
+    print(f"✅ Loaded {len(all_samples)} samples, {len(label_list)} labels: {label_list}")
+    return Dataset.from_list(all_samples), label_list
+
+
+def tokenize_and_align_labels(examples, tokenizer, label2id, max_length=128):
+    """Tokenize and align char-level labels via the tokenizer's offsets."""
+    tokenized = tokenizer(
+        examples["text"],
+        truncation=True,
+        padding=True,
+        max_length=max_length,
+        return_offsets_mapping=True,
+        return_tensors=None,
+    )
+
+    labels = []
+    for i, label_seq in enumerate(examples["char_labels"]):
+        offsets = tokenized["offset_mapping"][i]
+        label_ids = []
+        for start, end in offsets:
+            if start == end:
+                label_ids.append(-100)
+            else:
+                label_ids.append(label2id[label_seq[start]])
+        labels.append(label_ids)
+
+    tokenized["labels"] = labels
+    return tokenized
+
+
+__all__ = [
+    "word_to_char_labels",
+    "parse_conll_file",
+    "load_conll_dataset",
+    "tokenize_and_align_labels",
+]

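And a small sanity check of `word_to_char_labels` on a made-up word-level sample; the text and tags below are hypothetical:

```python
from basemodel.utils import word_to_char_labels

# Hypothetical sample: "北京欢迎你" with "北京" tagged as a location.
text = "北京欢迎你"
word_labels = [("北京", "B-LOC"), ("欢迎", "O"), ("你", "O")]

print(word_to_char_labels(text, word_labels))
# ['B-LOC', 'I-LOC', 'O', 'O', 'O'] -- the B- tag lands on the first
# character of the word and the remaining characters become I-.
```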