about | summary | refs | log | tree | commit | diff | stats | homepage
path: root/src/base_model_trpgner
diff options
context:
space:
mode:
Diffstat (limited to 'src/base_model_trpgner')
-rw-r--r-- src/base_model_trpgner/inference/__init__.py 48
1 file changed, 25 insertions, 23 deletions
diff --git a/src/base_model_trpgner/inference/__init__.py b/src/base_model_trpgner/inference/__init__.py
index 5824666..9c8d99c 100644
--- a/src/base_model_trpgner/inference/__init__.py
+++ b/src/base_model_trpgner/inference/__init__.py
@@ -15,9 +15,7 @@ try:
import onnxruntime as ort
from transformers import AutoTokenizer
except ImportError as e:
- raise ImportError(
- "依赖未安装。请运行: pip install onnxruntime transformers numpy"
- ) from e
+ raise ImportError("依赖未安装。请运行: pip install onnxruntime transformers numpy") from e
# GitHub 仓库信息
@@ -83,29 +81,23 @@ def download_model_files(version: Optional[str] = None, force: bool = False) ->
print(f"正在下载模型 {version}...")
- # 下载 model.zip
zip_url = f"https://github.com/{REPO_OWNER}/{REPO_NAME}/releases/download/{version}/model.zip"
try:
with urllib.request.urlopen(zip_url, timeout=120) as response:
- # 下载到临时文件
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_file:
shutil.copyfileobj(response, tmp_file)
tmp_path = tmp_file.name
- # 解压到目标目录
- print(f"解压中...")
- with zipfile.ZipFile(tmp_path, 'r') as zip_ref:
+ print("解压中...")
+ with zipfile.ZipFile(tmp_path, "r") as zip_ref:
zip_ref.extractall(model_dir)
-
- # 清理临时文件
os.unlink(tmp_path)
- print(f"✅ 模型下载完成: {model_dir}")
+ print(f"模型下载完成: {model_dir}")
except urllib.error.HTTPError as e:
raise RuntimeError(
- f"无法下载模型文件。请检查版本 {version} 是否存在。\n"
- f"下载地址: {zip_url}"
+ f"无法下载模型文件。请检查版本 {version} 是否存在。\n" f"下载地址: {zip_url}"
) from e
# 写入版本标记
@@ -209,9 +201,17 @@ class TRPGParser:
else:
# 默认标签
self.id2label = {
- 0: "O", 1: "B-action", 2: "I-action", 3: "B-comment", 4: "I-comment",
- 5: "B-dialogue", 6: "I-dialogue", 7: "B-speaker", 8: "I-speaker",
- 9: "B-timestamp", 10: "I-timestamp",
+ 0: "O",
+ 1: "B-action",
+ 2: "I-action",
+ 3: "B-comment",
+ 4: "I-comment",
+ 5: "B-dialogue",
+ 6: "I-dialogue",
+ 7: "B-speaker",
+ 8: "I-speaker",
+ 9: "B-timestamp",
+ 10: "I-timestamp",
}
def parse(self, text: str, max_length: int = 512) -> Dict[str, Any]:
@@ -268,7 +268,7 @@ class TRPGParser:
if ent["start"] >= len(text) or ent["end"] > len(text):
continue
- raw_text = text[ent["start"]: ent["end"]]
+ raw_text = text[ent["start"] : ent["end"]]
clean_text = self._clean_text(raw_text, ent["type"])
if not clean_text.strip():
@@ -277,11 +277,13 @@ class TRPGParser:
if ent["type"] in ["timestamp", "speaker"]:
result["metadata"][ent["type"]] = clean_text
elif ent["type"] in ["dialogue", "action", "comment"]:
- result["content"].append({
- "type": ent["type"],
- "content": clean_text,
- "confidence": round(ent["score"], 3),
- })
+ result["content"].append(
+ {
+ "type": ent["type"],
+ "content": clean_text,
+ "confidence": round(ent["score"], 3),
+ }
+ )
return result
@@ -337,7 +339,7 @@ class TRPGParser:
if group == "comment":
text = re.sub(r"^[((]+|[))]+$", "", text)
elif group == "dialogue":
- text = re.sub(r'^[""''「」『』]+|[""""」』『』]+$', "", text)
+ text = re.sub(r'^[""' '「」『』]+|[""""」』『』]+$', "", text)
elif group == "action":
text = re.sub(r"^[*#]+|[*#]+$", "", text)