diff options
| author | 2026-01-05 14:33:10 +0800 | |
|---|---|---|
| committer | 2026-01-05 14:33:10 +0800 | |
| commit | df94eb6c125279a9c32bc85de8633371d50afbed (patch) | |
| tree | fa7e99e5078bbcb62e7b9dcc181f4fbc282129c4 | |
| parent | 25380fb4de77966a0f3d00681be25857c27b0869 (diff) | |
| download | base-model-df94eb6c125279a9c32bc85de8633371d50afbed.tar.gz base-model-df94eb6c125279a9c32bc85de8633371d50afbed.zip | |
feat: update max_length parameter for TRPGParser and onnx_infer to improve text parsing capabilities
| -rw-r--r-- | src/base_model_trpgner/inference/__init__.py | 6 | ||||
| -rw-r--r-- | tests/onnx_infer.py | 2 |
2 files changed, 4 insertions, 4 deletions
```diff
diff --git a/src/base_model_trpgner/inference/__init__.py b/src/base_model_trpgner/inference/__init__.py
index 3d1d720..41f3504 100644
--- a/src/base_model_trpgner/inference/__init__.py
+++ b/src/base_model_trpgner/inference/__init__.py
@@ -214,13 +214,13 @@ class TRPGParser:
         9: "B-timestamp",
         10: "I-timestamp",
     }
-    def parse(self, text: str) -> Dict[str, Any]:
+    def parse(self, text: str, max_length: int = 512) -> Dict[str, Any]:
         """
         解析单条 TRPG 日志
 
         Args:
             text: 待解析的日志文本
-
+            max_length: 最大序列长度,大小512以内
         Returns:
             包含 metadata 和 content 的字典
             - metadata: speaker, timestamp
@@ -239,7 +239,7 @@ class TRPGParser:
             return_offsets_mapping=True,
             padding="max_length",
             truncation=True,
-            max_length=128,
+            max_length=max_length or 512,
         )
 
         # 推理
diff --git a/tests/onnx_infer.py b/tests/onnx_infer.py
index 4ffca25..f4c7f9d 100644
--- a/tests/onnx_infer.py
+++ b/tests/onnx_infer.py
@@ -19,7 +19,7 @@ def softmax(x):
     return e / e.sum(axis=-1, keepdims=True)
 
 text = sys.argv[1] if len(sys.argv) > 1 else "风雨 2024-06-08 21:44:59 剧烈的疼痛..."
-inputs = tok(text, return_tensors="np", return_offsets_mapping=True, padding="max_length", truncation=True, max_length=128)
+inputs = tok(text, return_tensors="np", return_offsets_mapping=True, padding="max_length", truncation=True, max_length=512)
 
 feed = {}
 for inp in sess.get_inputs():
```
