author      HsiangNianian <i@jyunko.cn>  2026-01-05 14:39:30 +0800
committer   HsiangNianian <i@jyunko.cn>  2026-01-05 14:40:15 +0800
commit      6713e38407bdbc1495692c7e297c027a1dc3f612 (patch)
tree        c9e93c58e46583962807114de81c75be04a6dbef
parent      f51ebaf36593dffb066ad3c4f7f98a0827d8f8e9 (diff)
download    base-model-6713e38407bdbc1495692c7e297c027a1dc3f612.tar.gz
            base-model-6713e38407bdbc1495692c7e297c027a1dc3f612.zip
feat: enhance code readability and formatting in onnx_infer.py
-rw-r--r--   tests/onnx_infer.py   55
1 file changed, 42 insertions(+), 13 deletions(-)
diff --git a/tests/onnx_infer.py b/tests/onnx_infer.py
index f4c7f9d..d6c00d2 100644
--- a/tests/onnx_infer.py
+++ b/tests/onnx_infer.py
@@ -8,18 +8,31 @@ ONNX_PATH = os.path.join(MODEL_DIR, "model.onnx")
tok = AutoTokenizer.from_pretrained(MODEL_DIR, local_files_only=True)
-providers = [p for p in ("CUDAExecutionProvider", "CPUExecutionProvider") if p in ort.get_available_providers()]
+providers = [
+ p
+ for p in ("CUDAExecutionProvider", "CPUExecutionProvider")
+ if p in ort.get_available_providers()
+]
so = ort.SessionOptions()
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess = ort.InferenceSession(ONNX_PATH, sess_options=so, providers=providers)
+
def softmax(x):
x = x - x.max(axis=-1, keepdims=True)
e = np.exp(x)
return e / e.sum(axis=-1, keepdims=True)
+
text = sys.argv[1] if len(sys.argv) > 1 else "风雨 2024-06-08 21:44:59 剧烈的疼痛..."
-inputs = tok(text, return_tensors="np", return_offsets_mapping=True, padding="max_length", truncation=True, max_length=512)
+inputs = tok(
+ text,
+ return_tensors="np",
+ return_offsets_mapping=True,
+ padding="max_length",
+ truncation=True,
+ max_length=512,
+)
feed = {}
for inp in sess.get_inputs():
@@ -62,17 +75,25 @@ cur = None
for i, (pid, pprob, off, a) in enumerate(zip(pred_ids, pred_probs, offsets, attn)):
if not a or (off[0] == off[1] == 0):
if cur:
- entities.append(cur); cur = None
+ entities.append(cur)
+ cur = None
continue
label = id2label.get(int(pid), "O")
if label == "O":
if cur:
- entities.append(cur); cur = None
+ entities.append(cur)
+ cur = None
continue
if label.startswith("B-") or cur is None or label[2:] != cur["type"]:
if cur:
entities.append(cur)
- cur = {"type": label[2:], "tokens": [i], "start": int(off[0]), "end": int(off[1]), "probs":[float(pprob)]}
+ cur = {
+ "type": label[2:],
+ "tokens": [i],
+ "start": int(off[0]),
+ "end": int(off[1]),
+ "probs": [float(pprob)],
+ }
else:
cur["tokens"].append(i)
cur["end"] = int(off[1])
@@ -80,22 +101,30 @@ for i, (pid, pprob, off, a) in enumerate(zip(pred_ids, pred_probs, offsets, attn
if cur:
entities.append(cur)
+
def fix_timestamp(ts):
- if not ts: return ts
+ if not ts:
+ return ts
m = re.match(r"^(\d{1,2})-(\d{2})-(\d{2})(.*)", ts)
if m:
y, mo, d, rest = m.groups()
- if len(y)==1: y="202"+y
- elif len(y)==2: y="20"+y
+ if len(y) == 1:
+ y = "202" + y
+ elif len(y) == 2:
+ y = "20" + y
return f"{y}-{mo}-{d}{rest}"
return ts
+
+
def fix_speaker(spk):
- if not spk: return spk
+ if not spk:
+ return spk
spk = re.sub(r"[^\w\s\u4e00-\u9fff]+$", "", spk)
- if len(spk)==1 and re.match(r"^[风雷电雨雪火水木金]", spk):
- return spk+"某"
+ if len(spk) == 1 and re.match(r"^[风雷电雨雪火水木金]", spk):
+ return spk + "某"
return spk
+
out = {"metadata": {}, "content": []}
for e in entities:
s, epos = e["start"], e["end"]
@@ -103,7 +132,7 @@ for e in entities:
conf = round(float(np.mean(e["probs"])), 3)
typ = e["type"]
if typ in ("timestamp", "speaker"):
- if typ=="timestamp":
+ if typ == "timestamp":
ent_text = fix_timestamp(ent_text)
else:
ent_text = fix_speaker(ent_text)
@@ -112,4 +141,4 @@ for e in entities:
out["content"].append({"type": typ, "content": ent_text, "confidence": conf})
print("\nConstructed JSON:")
-print(json.dumps(out, ensure_ascii=False, indent=2))
\ No newline at end of file
+print(json.dumps(out, ensure_ascii=False, indent=2))
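
For reference, the script reads an optional input string from the command line (falling back to the built-in sample text shown above) and prints the constructed JSON. A typical invocation, assuming the tokenizer files and model.onnx have already been exported into MODEL_DIR, might look like:

    python tests/onnx_infer.py "风雨 2024-06-08 21:44:59 剧烈的疼痛..."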