summaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
authorHsiangNianian <i@jyunko.cn>2026-01-04 15:59:54 +0800
committerHsiangNianian <i@jyunko.cn>2026-01-04 15:59:54 +0800
commitf4f9c541e9917fa614e6e1b8e737167f44c89c43 (patch)
tree92b362d525035f26979fa16396b9d52464201415
parent910333fda7fc32ed426e96f11f01c76d6e95544b (diff)
downloadbase-model-f4f9c541e9917fa614e6e1b8e737167f44c89c43.tar.gz
base-model-f4f9c541e9917fa614e6e1b8e737167f44c89c43.zip
feat: add log processing and LLM annotation functionality
-rw-r--r--pyproject.toml3
-rw-r--r--utils/llm_seri.py396
-rw-r--r--utils/process_log.py53
3 files changed, 452 insertions, 0 deletions
diff --git a/pyproject.toml b/pyproject.toml
index 279136b..c95592d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,7 +34,9 @@ classifiers = [
]
dependencies = [
+ "aiohttp>=3.13.2",
"numpy>=1.24.0",
+ "ollama>=0.6.1",
"onnxruntime>=1.23.2",
"transformers>=4.57.3",
]
@@ -52,6 +54,7 @@ dev = [
"pytest>=8.0.0",
"black>=24.0.0",
"ruff>=0.1.0",
+ "dotenv>=0.9.9",
]
webui = ["base-model-trpgner[train]", "gradio>=6.2.0", "scikit-learn>=1.4.0"]
all = ["base-model-trpgner[train,webui,dev]"]
diff --git a/utils/llm_seri.py b/utils/llm_seri.py
new file mode 100644
index 0000000..ac8174b
--- /dev/null
+++ b/utils/llm_seri.py
@@ -0,0 +1,396 @@
#!/usr/bin/env python3
"""
Auto-annotate TRPG game logs with a local LLM (supports high concurrency).

Annotation label types: speaker, timestamp, dialogue, action, comment.
"""

import json
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timezone
from typing import Any, Dict, List

from dotenv import load_dotenv
from ollama import chat

# Load environment variables (e.g. OLLAMA_MODEL) from a local .env file, if present.
load_dotenv()
+
+
def get_annotation_prompt(text: str) -> str:
    """
    Build the LLM annotation prompt for one log entry.

    The prompt (kept in Chinese, matching the log language) asks the model to
    return only annotation types and texts; character offsets are computed
    locally afterwards by ``calculate_annotation_positions``.
    """
    return f"""你是一个专业的文本标注助手。请对以下 TRPG 游戏日志进行标注,标注格式为 JSON。

## 标签类型及规则

1. **speaker**: 说话人/玩家名字(位于文本开头,后跟空格和时间戳)
2. **timestamp**: 时间戳(格式如:2024-06-08 21:44:59)
3. **dialogue**: 角色对话(用引号包裹的说话内容,如 ""..."")
4. **action**: 动作/指令(以 . 开头的骰子指令,如 .rd10+7、.ww12a9+1)
5. **comment**: 其他描述性内容(角色扮演描述、系统消息、动作描述等)

## 标注示例

### 示例 1:
文本:`风雨 2024-06-08 21:44:59\\n剧烈的疼痛从头颅深处一波波地涌出,仿佛每一次脉搏的跳动都在击打你的头骨。`
标注:
```json
{{
  "annotations": [
    {{"type": "speaker", "text": "风雨"}},
    {{"type": "timestamp", "text": "2024-06-08 21:44:59"}},
    {{"type": "comment", "text": "剧烈的疼痛从头颅深处一波波地涌出,仿佛每一次脉搏的跳动都在击打你的头骨。"}}
  ]
}}
```

### 示例 2:
文本:`莎莎 2024-06-08 21:46:26\\n"呜哇..."#下意识去拿法杖,但启动施法起手后大脑里一片空白...`
标注:
```json
{{
  "annotations": [
    {{"type": "speaker", "text": "莎莎"}},
    {{"type": "timestamp", "text": "2024-06-08 21:46:26"}},
    {{"type": "dialogue", "text": ""呜哇...""}},
    {{"type": "comment", "text": "#下意识去拿法杖,但启动施法起手后大脑里一片空白..."}}
  ]
}}
```

### 示例 3:
文本:`莎莎 2024-06-08 21:49:51\\n.rd10+7`
标注:
```json
{{
  "annotations": [
    {{"type": "speaker", "text": "莎莎"}},
    {{"type": "timestamp", "text": "2024-06-08 21:49:51"}},
    {{"type": "action", "text": ".rd10+7"}}
  ]
}}
```

### 示例 4:
文本:`白麗 霊夢 2024-06-08 21:49:51\\n莎莎 的出目是\\nD10+7=6+7=13`
标注:
```json
{{
  "annotations": [
    {{"type": "speaker", "text": "白麗 霊夢"}},
    {{"type": "timestamp", "text": "2024-06-08 21:49:51"}},
    {{"type": "comment", "text": "莎莎 的出目是\\nD10+7=6+7=13"}}
  ]
}}
```

## 注意事项

- 只返回标注的类型(type)和文本内容(text),不需要返回位置信息
- 确保标注的文本内容与原文本完全一致
- 只返回 JSON,不要添加任何其他解释性文字
- 如果文本中不包含某种标签类型,就不要包含该标签

## 待标注文本

{text}

## 请返回标注结果(只返回 JSON,不要其他内容):"""
+
+
def call_llm_api(prompt: str, index: int, total: int) -> Dict[str, Any]:
    """
    Send one annotation prompt to the local Ollama model, with retries.

    Args:
        prompt: Fully rendered annotation prompt for a single log entry.
        index: 1-based position of this entry (used for progress output only).
        total: Total number of entries being processed.

    Returns:
        The parsed JSON payload from the model, or ``{"annotations": []}``
        when the model keeps failing or returns unparsable output.
    """
    model_name = os.getenv("OLLAMA_MODEL", "qwen3:8b")

    conversation = [
        {
            "role": "system",
            "content": "你是一个专业的文本标注助手,严格按照 JSON 格式返回标注结果,不要添加任何其他内容。",
        },
        {"role": "user", "content": prompt},
    ]

    attempts = 3
    delay = 1  # base retry delay in seconds

    for attempt in range(attempts):
        last_try = attempt == attempts - 1
        try:
            reply = chat(
                model=model_name,
                messages=conversation,
                think=False,
                stream=False,
            )
            payload = reply.message.content

            if not payload:
                print(f"[{index}/{total}] API 返回空内容")
                if last_try:
                    return {"annotations": []}
                time.sleep(delay)
                continue

            # Peel optional markdown code fences off the JSON body.
            payload = payload.strip()
            for fence in ("```json", "```"):
                if payload.startswith(fence):
                    payload = payload[len(fence):]
            if payload.endswith("```"):
                payload = payload[:-3]

            print(f"[{index}/{total}] API 调用成功")
            return json.loads(payload.strip())
        except json.JSONDecodeError as e:
            if last_try:
                print(f"[{index}/{total}] JSON 解析失败,达到最大重试次数")
                return {"annotations": []}
            print(f"[{index}/{total}] JSON 解析失败: {e},重试中...")
            time.sleep(delay)
        except Exception as e:
            if last_try:
                print(f"[{index}/{total}] API 调用失败: {e},达到最大重试次数")
                return {"annotations": []}
            print(f"[{index}/{total}] API 调用失败: {e},重试中...")
            # Exponential backoff only for transport/API-level failures.
            time.sleep(delay * (2**attempt))

    return {"annotations": []}
+
+
def calculate_annotation_positions(original_text: str, llm_annotations: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Locate each LLM-returned annotation span inside the original text.

    Args:
        original_text: The raw log text that was sent to the LLM.
        llm_annotations: Annotation dicts from the LLM, each carrying
            ``type`` and ``text`` keys (no position information).

    Returns:
        Annotation dicts augmented with ``start``/``end`` character offsets.
        Entries lacking a type/text, or whose text cannot be found in the
        original, are dropped.
    """
    located: List[Dict[str, Any]] = []
    cursor = 0  # search resumes here so repeated snippets map to successive occurrences

    for item in llm_annotations:
        kind = item.get("type")
        snippet = item.get("text", "")
        if not (kind and snippet):
            continue

        # Prefer a match at or after the cursor; fall back to a global
        # search to tolerate out-of-order annotations from the model.
        start = original_text.find(snippet, cursor)
        if start < 0:
            start = original_text.find(snippet)
        if start < 0:
            continue

        end = start + len(snippet)
        located.append({
            "type": kind,
            "start": start,
            "end": end,
            "text": snippet,
        })
        cursor = end

    return located
+
+
def convert_to_label_studio_format(
    task_id: int,
    text: str,
    llm_annotations: List[Dict[str, Any]],
    *,
    project_id: int = 2,
    user_id: int = 1,
) -> Dict[str, Any]:
    """
    Convert LLM annotations for one text into a Label Studio task dict.

    Args:
        task_id: Identifier reused for the task id, annotation id and inner id.
        text: Original log text the annotations refer to.
        llm_annotations: Position-free annotations (``type`` + ``text``) as
            returned by the LLM.
        project_id: Label Studio project the task belongs to. Defaults to 2,
            matching the previously hard-coded value.
        user_id: Label Studio user recorded as author/updater. Defaults to 1,
            matching the previously hard-coded value.

    Returns:
        A task dict in Label Studio's JSON import/export format.
    """
    import uuid  # kept local, as in the original, to avoid touching module scope

    annotation_id = str(uuid.uuid4())

    # Compute character offsets for each annotation within the text.
    # Entries without a valid type/text are already filtered out here,
    # so no further None-checking is needed below.
    annotations = calculate_annotation_positions(text, llm_annotations)

    # One Label Studio "result" entry per located annotation span.
    results = [
        {
            "value": {
                "start": ann["start"],
                "end": ann["end"],
                "text": ann["text"],
                "labels": [ann["type"]],
            },
            "id": str(uuid.uuid4()),
            "from_name": "label",
            "to_name": "text",
            "type": "labels",
            "origin": "manual",
        }
        for ann in annotations
    ]

    # Single timezone-aware timestamp shared by all date fields.
    now = datetime.now(timezone.utc).isoformat()

    return {
        "id": task_id,
        "annotations": [
            {
                "id": task_id,
                "completed_by": user_id,
                "result": results,
                "was_cancelled": False,
                "ground_truth": False,
                "created_at": now,
                "updated_at": now,
                "draft_created_at": now,
                "lead_time": 0.0,
                "prediction": {},
                "result_count": len(results),
                "unique_id": annotation_id,
                "import_id": None,
                "last_action": None,
                "bulk_created": False,
                "task": task_id,
                "project": project_id,
                "updated_by": user_id,
                "parent_prediction": None,
                "parent_annotation": None,
                "last_created_by": None,
            }
        ],
        "file_upload": "llm-auto-annotated.json",
        "drafts": [],
        "predictions": [],
        "data": {"text": text},
        "meta": {},
        "created_at": now,
        "updated_at": now,
        "allow_skip": True,
        "inner_id": task_id,
        "total_annotations": 1,
        "cancelled_annotations": 0,
        "total_predictions": 0,
        "comment_count": 0,
        "unresolved_comment_count": 0,
        "last_comment_updated_at": None,
        "project": project_id,
        "updated_by": user_id,
        "comment_authors": [],
    }
+
+
def process_logs(input_path: str, output_path: str, concurrency: int = 5, batch_size: int = 50):
    """
    Annotate every log entry in *input_path* with the LLM, concurrently.

    Args:
        input_path: Path to the processed_logs.json produced by process_log.py.
        output_path: Where the Label Studio-formatted results are written.
        concurrency: Number of worker threads issuing LLM calls.
        batch_size: Checkpoint partial results every this many completed tasks.
    """
    # Load the input entries.
    print(f"读取输入文件: {input_path}")
    with open(input_path, "r", encoding="utf-8") as f:
        logs = json.load(f)

    total = len(logs)
    print(f"总共 {total} 条日志需要标注")
    print(f"并发数: {concurrency}")

    # Completed results keyed by 1-based input index, so output order can be
    # restored regardless of completion order.
    results_dict: Dict[int, Dict[str, Any]] = {}

    def process_single_log(index: int, log_entry: Dict[str, Any]):
        # Annotate one entry; returns None for blank texts.
        text = log_entry.get("text", "")
        if not text:
            print(f"[{index}/{total}] 跳过空文本")
            return None

        print(f"\n[{index}/{total}] 处理文本: {text[:50]}...")
        prompt = get_annotation_prompt(text)
        llm_result = call_llm_api(prompt, index, total)
        return convert_to_label_studio_format(
            task_id=index, text=text, llm_annotations=llm_result.get("annotations", [])
        )

    def save_snapshot() -> List[Dict[str, Any]]:
        # Persist everything finished so far, in original log order.
        ordered = [results_dict[i] for i in sorted(results_dict)]
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(ordered, f, ensure_ascii=False, indent=2)
        return ordered

    completed = 0
    with ThreadPoolExecutor(max_workers=concurrency) as executor:
        future_to_index = {
            executor.submit(process_single_log, i, log_entry): i
            for i, log_entry in enumerate(logs, 1)
        }

        for future in as_completed(future_to_index):
            index = future_to_index[future]
            try:
                result = future.result()
                if result is not None:
                    results_dict[index] = result
            except Exception as e:
                print(f"[{index}/{total}] 处理失败: {e}")

            # BUGFIX: the original "incremental save" ran in a loop executed
            # only after every future had already finished (so it checkpointed
            # nothing), and keyed on the task index, which can skip the
            # trigger entirely when blank entries leave index gaps. Checkpoint
            # on the count of completed futures instead, while work is ongoing.
            completed += 1
            if completed % batch_size == 0:
                save_snapshot()
                print(f"\n已保存 {len(results_dict)} 条结果到 {output_path}")

    # Final ordered save.
    print(f"\n保存最终结果到: {output_path}")
    results = save_snapshot()

    print(f"完成!共标注 {len(results)} 条日志")
+
+
def main():
    """Entry point: annotate the default processed-log dump with the LLM."""
    # Default input location plus a timestamped output filename.
    source = "dataset/processed_logs/processed_logs.json"
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    target = f"dataset/llm_annotated_{stamp}.json"

    print(f"输入文件: {source}")
    print(f"输出文件: {target}")

    # Bail out early when there is nothing to process.
    if not os.path.exists(source):
        print(f"错误:输入文件不存在: {source}")
        return

    # Kick off the concurrent annotation run.
    process_logs(source, target, concurrency=5, batch_size=50)


if __name__ == "__main__":
    main()
diff --git a/utils/process_log.py b/utils/process_log.py
new file mode 100644
index 0000000..fffc359
--- /dev/null
+++ b/utils/process_log.py
@@ -0,0 +1,53 @@
+import glob
+import json
+import re
+
+
def process_g_files(pattern: str = "g*.txt", output_file: str = "processed_logs.json"):
    """
    Split matching log files into blank-line-separated paragraphs.

    Each paragraph is stripped of inline "(123)"-style sequence markers and
    collected as a ``{"text": ...}`` entry; all entries are dumped as JSON.

    Args:
        pattern: Glob pattern selecting the input files. Defaults to the
            previously hard-coded "g*.txt".
        output_file: Path of the JSON file to write. Defaults to the
            previously hard-coded "processed_logs.json".
    """
    files = glob.glob(pattern)

    if not files:
        print("未找到以'g'开头的txt文件")
        return

    print(f"找到 {len(files)} 个文件: {', '.join(files)}")

    all_entries = []

    def flush(paragraph):
        # Join buffered lines, strip "(N)" markers, and store the paragraph.
        if paragraph:
            cleaned = re.sub(r"\(\d+\)", "", "\n".join(paragraph))
            all_entries.append({"text": cleaned})

    for file_path in files:
        try:
            with open(file_path, "r", encoding="utf-8") as file:
                current_paragraph = []

                for line in file:
                    stripped_line = line.rstrip("\n")
                    if stripped_line.strip():
                        current_paragraph.append(stripped_line)
                    else:
                        # Blank line ends the current paragraph.
                        flush(current_paragraph)
                        current_paragraph = []

                # Trailing paragraph with no terminating blank line.
                flush(current_paragraph)

            print(f"处理文件 {file_path} 完成")

        except Exception as e:
            print(f"处理文件 {file_path} 时出错: {e}")

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(all_entries, f, ensure_ascii=False, indent=2)

    print(f"\n处理完成! 共处理 {len(all_entries)} 个段落")
    print(f"结果已保存到 {output_file}")


if __name__ == "__main__":
    process_g_files()