diff options
Diffstat (limited to 'tests/onnx_infer.py')
| -rw-r--r-- | tests/onnx_infer.py | 55 |
1 files changed, 42 insertions, 13 deletions
diff --git a/tests/onnx_infer.py b/tests/onnx_infer.py index f4c7f9d..d6c00d2 100644 --- a/tests/onnx_infer.py +++ b/tests/onnx_infer.py @@ -8,18 +8,31 @@ ONNX_PATH = os.path.join(MODEL_DIR, "model.onnx") tok = AutoTokenizer.from_pretrained(MODEL_DIR, local_files_only=True) -providers = [p for p in ("CUDAExecutionProvider", "CPUExecutionProvider") if p in ort.get_available_providers()] +providers = [ + p + for p in ("CUDAExecutionProvider", "CPUExecutionProvider") + if p in ort.get_available_providers() +] so = ort.SessionOptions() so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL sess = ort.InferenceSession(ONNX_PATH, sess_options=so, providers=providers) + def softmax(x): x = x - x.max(axis=-1, keepdims=True) e = np.exp(x) return e / e.sum(axis=-1, keepdims=True) + text = sys.argv[1] if len(sys.argv) > 1 else "风雨 2024-06-08 21:44:59 剧烈的疼痛..." -inputs = tok(text, return_tensors="np", return_offsets_mapping=True, padding="max_length", truncation=True, max_length=512) +inputs = tok( + text, + return_tensors="np", + return_offsets_mapping=True, + padding="max_length", + truncation=True, + max_length=512, +) feed = {} for inp in sess.get_inputs(): @@ -62,17 +75,25 @@ cur = None for i, (pid, pprob, off, a) in enumerate(zip(pred_ids, pred_probs, offsets, attn)): if not a or (off[0] == off[1] == 0): if cur: - entities.append(cur); cur = None + entities.append(cur) + cur = None continue label = id2label.get(int(pid), "O") if label == "O": if cur: - entities.append(cur); cur = None + entities.append(cur) + cur = None continue if label.startswith("B-") or cur is None or label[2:] != cur["type"]: if cur: entities.append(cur) - cur = {"type": label[2:], "tokens": [i], "start": int(off[0]), "end": int(off[1]), "probs":[float(pprob)]} + cur = { + "type": label[2:], + "tokens": [i], + "start": int(off[0]), + "end": int(off[1]), + "probs": [float(pprob)], + } else: cur["tokens"].append(i) cur["end"] = int(off[1]) @@ -80,22 +101,30 @@ for i, (pid, pprob, off, a) in enumerate(zip(pred_ids, pred_probs, offsets, attn if cur: entities.append(cur) + def fix_timestamp(ts): - if not ts: return ts + if not ts: + return ts m = re.match(r"^(\d{1,2})-(\d{2})-(\d{2})(.*)", ts) if m: y, mo, d, rest = m.groups() - if len(y)==1: y="202"+y - elif len(y)==2: y="20"+y + if len(y) == 1: + y = "202" + y + elif len(y) == 2: + y = "20" + y return f"{y}-{mo}-{d}{rest}" return ts + + def fix_speaker(spk): - if not spk: return spk + if not spk: + return spk spk = re.sub(r"[^\w\s\u4e00-\u9fff]+$", "", spk) - if len(spk)==1 and re.match(r"^[风雷电雨雪火水木金]", spk): - return spk+"某" + if len(spk) == 1 and re.match(r"^[风雷电雨雪火水木金]", spk): + return spk + "某" return spk + out = {"metadata": {}, "content": []} for e in entities: s, epos = e["start"], e["end"] @@ -103,7 +132,7 @@ for e in entities: conf = round(float(np.mean(e["probs"])), 3) typ = e["type"] if typ in ("timestamp", "speaker"): - if typ=="timestamp": + if typ == "timestamp": ent_text = fix_timestamp(ent_text) else: ent_text = fix_speaker(ent_text) @@ -112,4 +141,4 @@ for e in entities: out["content"].append({"type": typ, "content": ent_text, "confidence": conf}) print("\nConstructed JSON:") -print(json.dumps(out, ensure_ascii=False, indent=2))
\ No newline at end of file +print(json.dumps(out, ensure_ascii=False, indent=2)) |
