diff options
Diffstat (limited to 'src/base_model_trpgner/inference/__init__.py')
| -rw-r--r-- | src/base_model_trpgner/inference/__init__.py | 48 |
1 file changed, 25 insertions, 23 deletions
diff --git a/src/base_model_trpgner/inference/__init__.py b/src/base_model_trpgner/inference/__init__.py index 5824666..9c8d99c 100644 --- a/src/base_model_trpgner/inference/__init__.py +++ b/src/base_model_trpgner/inference/__init__.py @@ -15,9 +15,7 @@ try: import onnxruntime as ort from transformers import AutoTokenizer except ImportError as e: - raise ImportError( - "依赖未安装。请运行: pip install onnxruntime transformers numpy" - ) from e + raise ImportError("依赖未安装。请运行: pip install onnxruntime transformers numpy") from e # GitHub 仓库信息 @@ -83,29 +81,23 @@ def download_model_files(version: Optional[str] = None, force: bool = False) -> print(f"正在下载模型 {version}...") - # 下载 model.zip zip_url = f"https://github.com/{REPO_OWNER}/{REPO_NAME}/releases/download/{version}/model.zip" try: with urllib.request.urlopen(zip_url, timeout=120) as response: - # 下载到临时文件 with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_file: shutil.copyfileobj(response, tmp_file) tmp_path = tmp_file.name - # 解压到目标目录 - print(f"解压中...") - with zipfile.ZipFile(tmp_path, 'r') as zip_ref: + print("解压中...") + with zipfile.ZipFile(tmp_path, "r") as zip_ref: zip_ref.extractall(model_dir) - - # 清理临时文件 os.unlink(tmp_path) - print(f"✅ 模型下载完成: {model_dir}") + print(f"模型下载完成: {model_dir}") except urllib.error.HTTPError as e: raise RuntimeError( - f"无法下载模型文件。请检查版本 {version} 是否存在。\n" - f"下载地址: {zip_url}" + f"无法下载模型文件。请检查版本 {version} 是否存在。\n" f"下载地址: {zip_url}" ) from e # 写入版本标记 @@ -209,9 +201,17 @@ class TRPGParser: else: # 默认标签 self.id2label = { - 0: "O", 1: "B-action", 2: "I-action", 3: "B-comment", 4: "I-comment", - 5: "B-dialogue", 6: "I-dialogue", 7: "B-speaker", 8: "I-speaker", - 9: "B-timestamp", 10: "I-timestamp", + 0: "O", + 1: "B-action", + 2: "I-action", + 3: "B-comment", + 4: "I-comment", + 5: "B-dialogue", + 6: "I-dialogue", + 7: "B-speaker", + 8: "I-speaker", + 9: "B-timestamp", + 10: "I-timestamp", } def parse(self, text: str, max_length: int = 512) -> Dict[str, Any]: @@ -268,7 +268,7 
@@ class TRPGParser: if ent["start"] >= len(text) or ent["end"] > len(text): continue - raw_text = text[ent["start"]: ent["end"]] + raw_text = text[ent["start"] : ent["end"]] clean_text = self._clean_text(raw_text, ent["type"]) if not clean_text.strip(): @@ -277,11 +277,13 @@ class TRPGParser: if ent["type"] in ["timestamp", "speaker"]: result["metadata"][ent["type"]] = clean_text elif ent["type"] in ["dialogue", "action", "comment"]: - result["content"].append({ - "type": ent["type"], - "content": clean_text, - "confidence": round(ent["score"], 3), - }) + result["content"].append( + { + "type": ent["type"], + "content": clean_text, + "confidence": round(ent["score"], 3), + } + ) return result @@ -337,7 +339,7 @@ class TRPGParser: if group == "comment": text = re.sub(r"^[((]+|[))]+$", "", text) elif group == "dialogue": - text = re.sub(r'^[""''「」『』]+|[""""」』『』]+$', "", text) + text = re.sub(r'^[""' '「」『』]+|[""""」』『』]+$', "", text) elif group == "action": text = re.sub(r"^[*#]+|[*#]+$", "", text) |
