Diffstat (limited to 'src/basemodel/utils')
-rw-r--r--  src/basemodel/utils/__init__.py  192
1 files changed, 0 insertions, 192 deletions
diff --git a/src/basemodel/utils/__init__.py b/src/basemodel/utils/__init__.py
deleted file mode 100644
index 12a3ef4..0000000
--- a/src/basemodel/utils/__init__.py
+++ /dev/null
@@ -1,192 +0,0 @@
-"""
-工具模块
-
-提供数据加载、CoNLL 格式处理等工具函数。
-"""
-
-import os
-import glob
-from typing import List, Dict, Any, Tuple
-from datasets import Dataset
-from tqdm.auto import tqdm
-
-
-def word_to_char_labels(text: str, word_labels: List[Tuple[str, str]]) -> List[str]:
-    """Convert word-level (token, label) pairs to character-level BIO labels."""
-    char_labels = ["O"] * len(text)
-    pos = 0
-
-    for token, label in word_labels:
-        if pos >= len(text):
-            break
-
-        # Advance until the token's first character lines up with the text
-        while pos < len(text) and text[pos] != token[0]:
-            pos += 1
-        if pos >= len(text):
-            break
-
-        # Label the span only when the full token matches at this position
-        if text[pos: pos + len(token)] == token:
-            for i in range(len(token)):
-                idx = pos + i
-                if idx < len(char_labels):
-                    # The first char of an entity keeps B-; the rest become I-
-                    if i == 0 and label.startswith("B-"):
-                        char_labels[idx] = label
-                    elif label.startswith("B-"):
-                        char_labels[idx] = "I" + label[1:]
-                    else:
-                        char_labels[idx] = label
-            pos += len(token)
-        else:
-            pos += 1
-
-    return char_labels
-
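-# Illustrative check (editor's example, not part of the original file): a
-# word-level annotation of "张三在北京" maps to per-character BIO tags.
-#
-# >>> word_to_char_labels("张三在北京",
-# ...                     [("张三", "B-PER"), ("在", "O"), ("北京", "B-LOC")])
-# ['B-PER', 'I-PER', 'O', 'B-LOC', 'I-LOC']
-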
-
-def parse_conll_file(filepath: str) -> List[Dict[str, Any]]:
-    """Parse a .conll file → [{"text": str, "char_labels": List[str]}]."""
-    with open(filepath, "r", encoding="utf-8") as f:
-        lines = [line.rstrip("\n") for line in f]
-
-    # Detect word-level files: any data line whose first column holds more than one character
-    is_word_level = any(
-        len(line.split()[0]) > 1
-        for line in lines
-        if line.strip() and not line.startswith("-DOCSTART-") and len(line.split()) >= 4
-    )
-
-    samples = []
-    if is_word_level:
-        current_text_parts = []
-        current_word_labels = []
-
-        for line in lines:
-            if not line or line.startswith("-DOCSTART-"):
-                if current_text_parts:
-                    text = "".join(current_text_parts)
-                    char_labels = word_to_char_labels(text, current_word_labels)
-                    samples.append({"text": text, "char_labels": char_labels})
-                    current_text_parts = []
-                    current_word_labels = []
-                continue
-
-            parts = line.split()
-            if len(parts) >= 4:
-                token, label = parts[0], parts[3]
-                current_text_parts.append(token)
-                current_word_labels.append((token, label))
-
-        if current_text_parts:
-            text = "".join(current_text_parts)
-            char_labels = word_to_char_labels(text, current_word_labels)
-            samples.append({"text": text, "char_labels": char_labels})
-    else:
-        current_text = []
-        current_labels = []
-
-        for line in lines:
-            if line.startswith("-DOCSTART-"):
-                if current_text:
-                    samples.append({
-                        "text": "".join(current_text),
-                        "char_labels": current_labels.copy(),
-                    })
-                    current_text, current_labels = [], []
-                continue
-
-            if not line:
-                if current_text:
-                    samples.append({
-                        "text": "".join(current_text),
-                        "char_labels": current_labels.copy(),
-                    })
-                    current_text, current_labels = [], []
-                continue
-
-            parts = line.split()
-            if len(parts) >= 4:
-                char = parts[0].replace("\\n", "\n")  # un-escape stored "\n"
-                label = parts[3]
-                current_text.append(char)
-                current_labels.append(label)
-
-        if current_text:
-            samples.append({
-                "text": "".join(current_text),
-                "char_labels": current_labels.copy(),
-            })
-
-    return samples
-
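-# Expected line shape, inferred from the parser above: four or more
-# whitespace-separated columns, token in column 1 and BIO tag in column 4;
-# blank lines or -DOCSTART- end a sample. Char-level example (a word-level
-# file would carry multi-character tokens in column 1):
-#
-#   张 _ _ B-PER
-#   三 _ _ I-PER
-#   在 _ _ O
-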
-
-def load_conll_dataset(conll_dir_or_files: str) -> Tuple[Dataset, List[str]]:
-    """Load .conll file(s) → (Dataset, label list)."""
-    filepaths = []
-    if os.path.isdir(conll_dir_or_files):
-        filepaths = sorted(glob.glob(os.path.join(conll_dir_or_files, "*.conll")))
-    elif conll_dir_or_files.endswith(".conll"):
-        filepaths = [conll_dir_or_files]
-    else:
-        raise ValueError("conll_dir_or_files must be a .conll file or a directory")
-
-    if not filepaths:
-        raise FileNotFoundError(f"No .conll files found in {conll_dir_or_files}")
-
-    print(f"Loading {len(filepaths)} conll files: {filepaths}")
-
-    all_samples = []
-    label_set = {"O"}
-
-    for fp in tqdm(filepaths, desc="Parsing .conll"):
-        samples = parse_conll_file(fp)
-        for s in samples:
-            all_samples.append(s)
-            label_set.update(s["char_labels"])
-
-    # Build the label list: "O" first, then the sorted B-/I- tags seen in the
-    # data, then any missing I- counterpart for each B- tag
-    label_list = ["O"]
-    for label in sorted(label_set - {"O"}):
-        if label.startswith("B-") or label.startswith("I-"):
-            label_list.append(label)
-    for label in list(label_list):
-        if label.startswith("B-"):
-            i_label = "I" + label[1:]
-            if i_label not in label_list:
-                label_list.append(i_label)
-                print(f"⚠️ Added missing {i_label} for {label}")
-
-    print(f"✅ Loaded {len(all_samples)} samples, {len(label_list)} labels: {label_list}")
-    return Dataset.from_list(all_samples), label_list
-
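-# Minimal usage sketch (the path and variable names are hypothetical):
-#
-# >>> dataset, label_list = load_conll_dataset("data/ner/train.conll")
-# >>> label2id = {label: i for i, label in enumerate(label_list)}
-# >>> id2label = {i: label for label, i in label2id.items()}
-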
-
-def tokenize_and_align_labels(examples, tokenizer, label2id, max_length=128):
-    """Tokenize texts and align char-level labels to sub-word tokens."""
-    tokenized = tokenizer(
-        examples["text"],
-        truncation=True,
-        padding=True,
-        max_length=max_length,
-        return_offsets_mapping=True,
-        return_tensors=None,
-    )
-
-    labels = []
-    for i, label_seq in enumerate(examples["char_labels"]):
-        offsets = tokenized["offset_mapping"][i]
-        label_ids = []
-        for start, end in offsets:
-            if start == end:
-                # Zero-width offsets mark special/padding tokens → ignored by the loss
-                label_ids.append(-100)
-            else:
-                # Each sub-word token takes the label of its first character
-                label_ids.append(label2id[label_seq[start]])
-        labels.append(label_ids)
-
-    # Offsets were only needed for alignment; drop them so they are not
-    # carried along as a model input column.
-    tokenized.pop("offset_mapping")
-    tokenized["labels"] = labels
-    return tokenized
-
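-# Sketch of wiring this into datasets.Dataset.map (tokenizer and label2id are
-# assumed to come from the sketches above); batched=True so the function
-# receives lists of texts:
-#
-# >>> tokenized_ds = dataset.map(
-# ...     lambda ex: tokenize_and_align_labels(ex, tokenizer, label2id),
-# ...     batched=True,
-# ...     remove_columns=["text", "char_labels"],
-# ... )
-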
-
-__all__ = [
-    "word_to_char_labels",
-    "parse_conll_file",
-    "load_conll_dataset",
-    "tokenize_and_align_labels",
-]