| author | 2026-01-06 14:48:40 +0800 |
|---|---|
| committer | 2026-01-06 14:48:47 +0800 |
| commit | e209a4ad05fd0d87f6a953663a93eec6709c3e63 (patch) |
| tree | 50a509b5ad9068227f721a513cda3faa8cf7d1e7 |
| parent | 6ed8fb2a7dc1fed79030a3c4c549e4eede332fe7 (diff) |
| download | base-model-e209a4ad05fd0d87f6a953663a93eec6709c3e63.tar.gz, base-model-e209a4ad05fd0d87f6a953663a93eec6709c3e63.zip |
feat: enhance log processing with multiple output versions and test set generation
| -rw-r--r-- | utils/llm_seri.py | 167 |
|---|---|---|
| -rw-r--r-- | utils/process_log.py | 111 |

2 files changed, 238 insertions(+), 40 deletions(-)
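Before reading the diffs: the core change in `utils/process_log.py` is that paragraphs are now kept raw and then cleaned into three parallel datasets. A condensed sketch of the three cleaning rules, paraphrased from the diff below (the toy `text` and the `v1`/`v2`/`v3` names are illustrative only):

```python
import re

# Toy paragraph; "(digits)" mimics the QQ number that follows a speaker name.
text = "风雨(1231287491) 2024-06-08 21:44:59\n“快走!”他低声说。"

v1 = re.sub(r"\(\d+\)", "", text)            # processed_logs.json: numeric markers stripped
v2 = text                                    # processed_logs_sensitive.json: kept verbatim
v3 = re.sub(r"[#“”‘’「」『』【】]", "", v1)  # processed_logs_clean.json: quotes/brackets stripped too
```

Note that, as committed, the sensitive version keeps the text verbatim, numeric markers included, even though its inline comment ("仅去除数字标记") suggests they are stripped.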
````diff
diff --git a/utils/llm_seri.py b/utils/llm_seri.py
index 17fee40..ba2cf1e 100644
--- a/utils/llm_seri.py
+++ b/utils/llm_seri.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 """
 使用 LLM 对游戏日志进行自动标注(支持高并发)
 标注格式:speaker、timestamp、dialogue、action、comment
@@ -6,6 +5,7 @@
 
 import json
 import os
+import re
 import time
 import threading
 from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -29,15 +29,18 @@ def get_annotation_prompt(text: str) -> str:
 1. **speaker**:说话人/玩家名字
    - 通常位于文本开头
    - 后面紧跟空格和时间戳
-   - 格式特征:`名字 时间戳`
+   - 格式特征:`名字(QQ号) 时间戳`
 
 2. **timestamp**:时间戳
    - 时间格式:`YYYY-MM-DD HH:MM:SS`
+   - 时间格式也有可能是其他的变体,如 `YYYY/MM/DD HH:MM` 等
    - 紧跟在 speaker 之后
 
-3. **action**:骰子/游戏指令
-   - 以点号 `.` 开头的指令
-   - 例如:`.rd10+7`、`.ww12a9+1`、`.rst`
+3. **action**:骰子/游戏指令或者人物动作词描述
+   - 以点号 `.` 或者 `/` 或者 `!` 或者 `。` 开头的指令
+   - 例如:`.rd10+7`、`。ww12a9+1`、`!roll` 等
+   - 也可以是描述角色动作的简短词语
+   - 例如:`站起身来`、`掏出法杖` 等
 
 4. **dialogue**:角色对话/说话内容
    - **判断依据:是否为角色口中说出的话**
@@ -47,7 +50,6 @@ def get_annotation_prompt(text: str) -> str:
    - 关键:这段文字是角色"说"出来的,而不是"做"的动作描述
 
 5. **comment**:其他所有内容
-   - 动作描述(角色做了什么)
    - 场景描写
    - 系统消息(如骰子结果)
    - 心理活动(非说话形式)
@@ -58,13 +60,13 @@
 
 - **按文本出现顺序标注**
 - **根据语义判断类型**,不要仅依赖格式特征
-- **不要遗漏任何内容**,文本的所有部分都必须被标注
+- **不要遗漏任何有效内容**,文本的有效部分都必须被标注,即使是系统消息或动作描述,不过你需要自己判断哪些是 action 动作哪些是 comment 描述或者旁白,说话人后面的QQ号和括号不需要标注
 - **保持文本原样**,标注的 text 必须与原文完全一致
 
 ## 标注示例
 
 ### 示例 1:纯动作描述
-文本:`风雨 2024-06-08 21:44:59\\n剧烈的疼痛从头颅深处一波波地涌出,仿佛每一次脉搏的跳动都在击打你的头骨。`
+文本:`风雨(1231287491) 2024-06-08 21:44:59\\n剧烈的疼痛从头颅深处一波波地涌出,仿佛每一次脉搏的跳动都在击打你的头骨。`
 标注:
 ```json
 {{
@@ -77,7 +79,7 @@ def get_annotation_prompt(text: str) -> str:
 ```
 
 ### 示例 2:有引号的对话 + 动作
-文本:`莎莎 2024-06-08 21:46:26\\n"呜哇..."#下意识去拿法杖,但启动施法起手后大脑里一片空白...`
+文本:`莎莎(123125124) 2024-06-08 21:46:26\\n"呜哇..."#下意识去拿法杖,但启动施法起手后大脑里一片空白...`
 标注:
 ```json
 {{
@@ -91,7 +93,7 @@ def get_annotation_prompt(text: str) -> str:
 ```
 
 ### 示例 3:无引号的对话(语义判断)
-文本:`风雨 2024-06-08 21:50:15\\n我不行了,��带我离开这里`
+文本:`风雨(1231287491) 2024-06-08 21:50:15\\n我不行了,��带我离开这里`
 标注:
 ```json
 {{
@@ -104,7 +106,7 @@ def get_annotation_prompt(text: str) -> str:
 ```
 
 ### 示例 4:对话 + 动作混合
-文本:`白麗 霊夢 2024-06-08 21:51:00\\n好的,我明白了。他点点头,转身离开了房间。`
+文本:`白麗 霊夢(12345678921) 2024-06-08 21:51:00\\n好的,我明白了。他点点头,转身离开了房间。`
 标注:
 ```json
 {{
@@ -118,7 +120,7 @@ def get_annotation_prompt(text: str) -> str:
 ```
 
 ### 示例 5:纯动作指令
-文本:`莎莎 2024-06-08 21:49:51\\n.rd10+7`
+文本:`莎莎(1251124512) 2024-06-08 21:49:51\\n.rd10+7`
 标注:
 ```json
 {{
@@ -131,7 +133,7 @@ def get_annotation_prompt(text: str) -> str:
 ```
 
 ### 示例 6:系统消息
-文本:`白麗 霊夢 2024-06-08 21:49:51\\n莎莎 的出目是\\nD10+7=6+7=13`
+文本:`白麗 霊夢(12345678921) 2024-06-08 21:49:51\\n莎莎 的出目是\\nD10+7=6+7=13`
 标注:
 ```json
 {{
@@ -144,7 +146,7 @@ def get_annotation_prompt(text: str) -> str:
 ```
 
 ### 示例 7:多段对话混合描述
-文本:`白麗 霊夢 2024-06-08 21:52:00\\n等等,这是什么?他指着地上的物品,疑惑地问道。这是...魔法道具吗?`
+文本:`白麗 霊夢(12345678921) 2024-06-08 21:52:00\\n等等,这是什么?他指着地上的物品,疑惑地问道。这是...魔法道具吗?`
 标注:
 ```json
 {{
@@ -162,7 +164,9 @@
 
 - **dialogue 的判断核心是"这是角色说的话吗"**,而不是"有没有引号"
 - 如果文本是角色直接说出的内容,即使没有引号也应标注为 dialogue
-- 如果文本是动作、场景、心理描写等非说话内容,应标注为 comment
+- 如果文本是场景、心理描写等非说话内容,应标注为 comment
+- 如果文本是指令或者任务的动作词,应标注为 action
+- 请严格按照上述 JSON 格式返回标注结果
 - 只返回 JSON,不要添加任何其他解释性文字
 
 ## 待标注文本
@@ -444,24 +448,131 @@ def process_logs(input_path: str, output_path: str, concurrency: int = 10, batch
     print(f"结果已保存到: {output_path}")
 
 
-def main():
+def post_process_labels(task_data: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    对标注结果进行后处理,去除非 speaker 和 timestamp 标签中的特殊符号并重新计算位置
+
+    Args:
+        task_data: 单个任务的数据
+
+    Returns:
+        后处理后的任务数据
     """
-    主函数
+    original_text = task_data["data"]["text"]
+    result_list = task_data["annotations"][0]["result"]
+
+    # 定义需要去除的正则表达式模式
+    pattern = r"[#“”「」『』【】]"
+
+    for result_item in result_list:
+        labels = result_item["value"]["labels"]
+        text = result_item["value"]["text"]
+        print(f"Processing label: {labels} with text: {text}")
+
+        # 只处理非 speaker 和 timestamp 的标签
+        if "speaker" not in labels and "timestamp" not in labels:
+            # 记录原始信息
+            original_start = result_item["value"]["start"]
+            original_end = result_item["value"]["end"]
+            print(f"Original positions: start={original_start}, end={original_end}")
+
+            processed_text = re.sub(pattern, "", text)
+            print(f"Processed text: {processed_text}")
+
+            # 如果文本发生变化
+            if processed_text != text:
+                # 在原始文本中查找处理后的文本位置
+                # 首先从原始位置附近开始查找
+                search_start = max(0, original_start - 10)
+                search_end = min(len(original_text), original_end + 10)
+
+                # 在搜索范围内查找处理后的文本
+                pos = original_text.find(processed_text, search_start, search_end)
+
+                # 如果在附近没找到,在整个文本中查找
+                if pos == -1:
+                    pos = original_text.find(processed_text)
+
+                # 更新结果
+                if pos != -1:
+                    result_item["value"]["text"] = processed_text
+                    result_item["value"]["start"] = pos
+                    result_item["value"]["end"] = pos + len(processed_text)
+                else:
+                    print(
+                        f"警告:无法找到处理后的文本位置,原始文本: '{text}', 处理后: '{processed_text}'"
+                    )
+
+    return task_data
+
+
+def process_input_file(input_path: str):
     """
-    # 使用默认路径
-    input_path = "dataset/processed_logs/processed_logs.json"
-    output_path = f"dataset/llm_annotated_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
+    Process input file and perform post-processing on all annotation results
+
+    Args:
+        input_path: Path to the input annotation result file
+    """
+    print(f"Reading input file: {input_path}")
+
+    try:
+        with open(input_path, "r", encoding="utf-8") as f:
+            data = json.load(f)
 
-    print(f"输入文件: {input_path}")
-    print(f"输出文件: {output_path}")
+        print(f"Total {len(data)} tasks to process")
 
-    # 检查输入文件是否存在
-    if not os.path.exists(input_path):
-        print(f"错误:输入文件不存在: {input_path}")
-        return
+        # Process each task
+        for i, task_data in enumerate(data):
+            print(f"\n[{i+1}/{len(data)}] Processing task ID: {task_data.get('id', 'N/A')}")
 
-    # 开始处理
-    process_logs(input_path, output_path, concurrency=10, batch_size=50)
+            # Perform post-processing
+            processed_task = post_process_labels(task_data)
+
+            # Update original data
+            data[i] = processed_task
+
+            print(f"Completed, processed text length: {len(processed_task['data']['text'])}")
+
+        # Save processed file
+        output_path = input_path.replace(".json", "_post_processed.json")
+        output_path = (
+            f"dataset/llm_annotated_post_processed_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
+        )
+
+        with open(output_path, "w", encoding="utf-8") as f:
+            json.dump(data, f, ensure_ascii=False, indent=2)
+
+        print(f"\nPost-processing completed! Results saved to: {output_path}")
+
+    except Exception as e:
+        print(f"Error processing file: {e}")
+
+
+def main():
+    """Main function"""
+    import sys
+
+    # Check command line arguments
+    if len(sys.argv) > 1:
+        # If file path is provided, execute post-processing
+        input_file = sys.argv[1]
+        print(f"File path detected, executing post-processing mode...")
+        process_input_file(input_file)
+    else:
+        # Use default path, execute normal annotation process
+        input_path = input("请输入待处理的日志文件路径(processed_logs.json):").strip()
+        output_path = f"dataset/llm_annotated_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
+
+        print(f"Input file: {input_path}")
+        print(f"Output file: {output_path}")
+
+        # Check if input file exists
+        if not os.path.exists(input_path):
+            print(f"Error: Input file does not exist: {input_path}")
+            return
+
+        # Start processing
+        process_logs(input_path, output_path, concurrency=10, batch_size=50)
 
 
 if __name__ == "__main__":
````
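The new `post_process_labels` strips decorative symbols (`#`, curly quotes, CJK brackets) from every non-`speaker`/`timestamp` span and then repairs the span's `start`/`end` offsets. A minimal self-contained sketch of that re-anchoring idea, using a toy line and span rather than real annotation data:

```python
import re

# Toy illustration of the re-anchoring step in post_process_labels:
# strip decorative symbols from an annotated span, then re-locate the
# cleaned span in the original text, searching near the old offsets first.
original_text = '"呜哇..."#下意识去拿法杖'          # hypothetical log line
span = {"start": 0, "end": 8, "text": original_text[0:8]}  # covers '"呜哇..."#'

pattern = r"[#“”「」『』【】]"                      # same pattern as in the diff
cleaned = re.sub(pattern, "", span["text"])

if cleaned != span["text"]:
    # Look in a ±10-character window around the old position first,
    # falling back to a whole-text search, as the committed code does.
    lo, hi = max(0, span["start"] - 10), min(len(original_text), span["end"] + 10)
    pos = original_text.find(cleaned, lo, hi)
    if pos == -1:
        pos = original_text.find(cleaned)
    if pos != -1:
        span.update(text=cleaned, start=pos, end=pos + len(cleaned))

print(span)  # {'start': 0, 'end': 7, 'text': '"呜哇..."'}
```

Two behavioral details are visible in the diff itself: `main()` now dispatches on `sys.argv` (a file argument triggers post-processing; otherwise the interactive annotation flow runs), and in `process_input_file` the `_post_processed.json` name derived from the input path is immediately overwritten by the timestamped `dataset/llm_annotated_post_processed_*.json` path, so only the latter is ever written.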
````diff
diff --git a/utils/process_log.py b/utils/process_log.py
index fffc359..3a577f2 100644
--- a/utils/process_log.py
+++ b/utils/process_log.py
@@ -1,16 +1,22 @@
 import glob
 import json
+import os
+import random
 import re
 
 
-def process_g_files():
-    files = glob.glob("g*.txt")
+def process_g_files(directory="."):
+    # 获取绝对路径
+    abs_directory = os.path.abspath(directory)
+    # 在指定目录中查找以g开头的txt文件
+    pattern = os.path.join(abs_directory, "g*.txt")
+    files = glob.glob(pattern)
 
     if not files:
         print("未找到以'g'开头的txt文件")
         return
 
-    print(f"找到 {len(files)} 个文件: {', '.join(files)}")
+    print(f"在目录 {abs_directory} 中找到 {len(files)} 个文件: {', '.join(os.path.basename(f) for f in files)}")
 
     all_entries = []
 
@@ -27,27 +33,108 @@
                     else:
                         if current_paragraph:
                             paragraph_text = "\n".join(current_paragraph)
-                            cleaned_text = re.sub(r"\(\d+\)", "", paragraph_text)
-                            all_entries.append({"text": cleaned_text})
+                            all_entries.append(paragraph_text)
                             current_paragraph = []
 
             if current_paragraph:
                 paragraph_text = "\n".join(current_paragraph)
-                cleaned_text = re.sub(r"\(\d+\)", "", paragraph_text)
-                all_entries.append({"text": cleaned_text})
+                all_entries.append(paragraph_text)
 
             print(f"处理文件 {file_path} 完成")
 
         except Exception as e:
             print(f"处理文件 {file_path} 时出错: {e}")
 
-    output_file = "processed_logs.json"
-    with open(output_file, "w", encoding="utf-8") as f:
-        json.dump(all_entries, f, ensure_ascii=False, indent=2)
+    # 生成三个版本的日志文件
+
+    # 版本1:仅去除数字标记 (原有版本)
+    version1_entries = []
+    for text in all_entries:
+        cleaned_text = re.sub(r"\(\d+\)", "", text)
+        version1_entries.append({"text": cleaned_text})
+
+    with open("processed_logs.json", "w", encoding="utf-8") as f:
+        json.dump(version1_entries, f, ensure_ascii=False, indent=2)
+
+    print(f"已生成 processed_logs.json (仅去除数字标记)")
+
+    # 版本2:保留所有敏感信息,仅去除数字标记
+    version2_entries = []
+    for text in all_entries:
+        version2_entries.append({"text": text})
+
+    with open("processed_logs_sensitive.json", "w", encoding="utf-8") as f:
+        json.dump(version2_entries, f, ensure_ascii=False, indent=2)
+
+    print(f"已生成 processed_logs_sensitive.json (保留敏感信息)")
+
+    # 版本3:去除数字标记和所有特定标点符号
+    punctuation_pattern = r'[#“”‘’「」『』【】]'
+    version3_entries = []
+    for text in all_entries:
+        # 先去除数字标记
+        cleaned_text = re.sub(r"\(\d+\)", "", text)
+        # 再去除特定标点符号
+        cleaned_text = re.sub(punctuation_pattern, "", cleaned_text)
+        version3_entries.append({"text": cleaned_text})
+
+    with open("processed_logs_clean.json", "w", encoding="utf-8") as f:
+        json.dump(version3_entries, f, ensure_ascii=False, indent=2)
+
+    print(f"已生成 processed_logs_clean.json (去除数字标记和标点符号)")
+
+    # 为每个版本生成20%的测试集
+    random.seed(42)  # 设置随机种子确保可重复性
+
+    # 版本1的测试集 (processed_logs.json -> processed_logs_test.json)
+    test_size1 = max(1, int(len(version1_entries) * 0.2))
+    test_set1 = random.sample(version1_entries, test_size1)
+    train_set1 = [entry for entry in version1_entries if entry not in test_set1]
+
+    with open("processed_logs_train.json", "w", encoding="utf-8") as f:
+        json.dump(train_set1, f, ensure_ascii=False, indent=2)
+    with open("processed_logs_test.json", "w", encoding="utf-8") as f:
+        json.dump(test_set1, f, ensure_ascii=False, indent=2)
+
+    print(f"已生成 processed_logs_train.json ({len(train_set1)} 条)")
+    print(f"已生成 processed_logs_test.json ({len(test_set1)} 条, 20%)")
+
+    # 版本2的测试集 (processed_logs_sensitive.json -> processed_logs_sensitive_test.json)
+    test_size2 = max(1, int(len(version2_entries) * 0.2))
+    test_set2 = random.sample(version2_entries, test_size2)
+    train_set2 = [entry for entry in version2_entries if entry not in test_set2]
+
+    with open("processed_logs_sensitive_train.json", "w", encoding="utf-8") as f:
+        json.dump(train_set2, f, ensure_ascii=False, indent=2)
+    with open("processed_logs_sensitive_test.json", "w", encoding="utf-8") as f:
+        json.dump(test_set2, f, ensure_ascii=False, indent=2)
+
+    print(f"已生成 processed_logs_sensitive_train.json ({len(train_set2)} 条)")
+    print(f"已生成 processed_logs_sensitive_test.json ({len(test_set2)} 条, 20%)")
+
+    # 版本3的测试集 (processed_logs_clean.json -> processed_logs_clean_test.json)
+    test_size3 = max(1, int(len(version3_entries) * 0.2))
+    test_set3 = random.sample(version3_entries, test_size3)
+    train_set3 = [entry for entry in version3_entries if entry not in test_set3]
+
+    with open("processed_logs_clean_train.json", "w", encoding="utf-8") as f:
+        json.dump(train_set3, f, ensure_ascii=False, indent=2)
+    with open("processed_logs_clean_test.json", "w", encoding="utf-8") as f:
+        json.dump(test_set3, f, ensure_ascii=False, indent=2)
+
+    print(f"已生成 processed_logs_clean_train.json ({len(train_set3)} 条)")
+    print(f"已生成 processed_logs_clean_test.json ({len(test_set3)} 条, 20%)")
 
     print(f"\n处理完成! 共处理 {len(all_entries)} 个段落")
-    print(f"结果已保存到 {output_file}")
+    print(f"生成了3个版本的数据集,每个版本都有对应的训练集和测试集")
 
 
 if __name__ == "__main__":
-    process_g_files()
+    import sys
+    if len(sys.argv) > 1:
+        directory = sys.argv[1]
+        print(f"处理目录: {directory}")
+        process_g_files(directory)
+    else:
+        print("处理当前目录")
+        process_g_files(".")
````
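One caveat on the split logic above: `train_setN = [entry for entry in ... if entry not in test_setN]` compares dicts by value, so it scans the whole test list for every entry and, if the corpus contains duplicate paragraphs, every copy of a sampled paragraph is excluded from the training set. A sketch of an index-based split that sidesteps both issues; `split_train_test` is a hypothetical helper, not part of this commit:

```python
import random

def split_train_test(entries, test_ratio=0.2, seed=42):
    """Index-based 80/20 split: O(n), keeps duplicate entries distinct."""
    rng = random.Random(seed)            # local RNG instead of the global random.seed
    indices = list(range(len(entries)))
    rng.shuffle(indices)
    n_test = max(1, int(len(entries) * test_ratio))
    test_idx = set(indices[:n_test])
    train = [e for i, e in enumerate(entries) if i not in test_idx]
    test = [entries[i] for i in sorted(test_idx)]
    return train, test
```

Note also that the committed code seeds once (`random.seed(42)`) before three consecutive `random.sample` calls, so each version's test set is drawn from a different RNG state; the three `_test` files therefore cover different paragraphs rather than the same rows under three cleanings.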
