| author | 2025-12-30 20:39:34 +0800 |
|---|---|
| committer | 2025-12-30 20:39:34 +0800 |
| commit | 298035052b3e3d083b57f5dbac0e86de4f94efba (patch) |
| tree | 944f38d734f752a5a0f71033ebece38fc5c35839 /tests |
| parent | 92a647ffbb3452a0ed49601177f290e20a88413e (diff) |
| download | base-model-298035052b3e3d083b57f5dbac0e86de4f94efba.tar.gz, base-model-298035052b3e3d083b57f5dbac0e86de4f94efba.zip |
refactor: Update model download functionality and improve inference module to support automatic model retrieval from GitHub releases
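The retrieval half of this change is not visible here (the diffstat below is limited to 'tests'). For orientation, a minimal sketch of fetching the model files from a GitHub release follows; the repository name, release tag, and `ensure_model` helper are hypothetical, not taken from this commit, and only the `models/trpg-final` paths match the test below:

```python
# Hypothetical sketch of release-based model retrieval; REPO and TAG are
# placeholders, not values from this commit.
import urllib.request
from pathlib import Path

REPO = "owner/base-model"  # hypothetical GitHub repository
TAG = "v1.0"               # hypothetical release tag
MODEL_DIR = Path("models/trpg-final")
ASSETS = ("model.onnx", "config.json")


def ensure_model():
    """Download any missing release assets into MODEL_DIR."""
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    for name in ASSETS:
        dest = MODEL_DIR / name
        if dest.exists():
            continue  # keep files that are already present
        url = f"https://github.com/{REPO}/releases/download/{TAG}/{name}"
        print(f"fetching {url} -> {dest}")
        urllib.request.urlretrieve(url, str(dest))


if __name__ == "__main__":
    ensure_model()
```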
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/test_onnx_only_infer.py | 198 |
1 file changed, 198 insertions, 0 deletions
diff --git a/tests/test_onnx_only_infer.py b/tests/test_onnx_only_infer.py
new file mode 100644
index 0000000..69c72be
--- /dev/null
+++ b/tests/test_onnx_only_infer.py
@@ -0,0 +1,198 @@
+"""
+Minimal ONNX-only inference using only:
+  - models/trpg-final/model.onnx
+  - models/trpg-final/config.json
+
+NOTE: This uses a home-made character-level tokenizer (not the tokenizer used
+at training time), so results may not match the original model's output, but
+it allows an end-to-end inference demo when no tokenizer files are available.
+"""
+
+import os, sys, json, re
+import numpy as np
+import onnxruntime as ort
+
+MODEL_DIR = "models/trpg-final"
+ONNX_PATH = os.path.join(MODEL_DIR, "model.onnx")
+CFG_PATH = os.path.join(MODEL_DIR, "config.json")
+MAX_LEN = 128
+
+# load id2label & vocab_size
+with open(CFG_PATH, "r", encoding="utf-8") as f:
+    cfg = json.load(f)
+id2label = {int(k): v for k, v in cfg.get("id2label", {}).items()}
+vocab_size = int(cfg.get("vocab_size", 30000))
+pad_id = int(cfg.get("pad_token_id", 0))
+
+# simple char-level tokenizer (adds [CLS]=101, [SEP]=102, pads with pad_id)
+CLS_ID = 101
+SEP_ID = 102
+
+
+def char_tokenize(text, max_length=MAX_LEN):
+    chars = list(text)
+    # reserve 2 for CLS and SEP
+    max_chars = max_length - 2
+    chars = chars[:max_chars]
+    ids = [CLS_ID] + [100 + (ord(c) % (vocab_size - 200)) for c in chars] + [SEP_ID]
+    attn = [1] * len(ids)
+    # pad
+    pad_len = max_length - len(ids)
+    ids += [pad_id] * pad_len
+    attn += [0] * pad_len
+    # offsets: for CLS/SEP/pad use (0,0); for char tokens map to character positions
+    offsets = [(0, 0)]
+    pos = 0
+    for c in chars:
+        offsets.append((pos, pos + 1))
+        pos += 1
+    offsets.append((0, 0))  # SEP
+    offsets += [(0, 0)] * pad_len
+    return {
+        "input_ids": np.array([ids], dtype=np.int64),
+        "attention_mask": np.array([attn], dtype=np.int64),
+        "offset_mapping": np.array([offsets], dtype=np.int64),
+        "text": text,
+    }
+
+
+# onnx runtime session
+providers = [
+    p
+    for p in ("CUDAExecutionProvider", "CPUExecutionProvider")
+    if p in ort.get_available_providers()
+]
+so = ort.SessionOptions()
+so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+sess = ort.InferenceSession(ONNX_PATH, sess_options=so, providers=providers)
+
+
+def softmax(x):
+    x = x - x.max(axis=-1, keepdims=True)
+    e = np.exp(x)
+    return e / e.sum(axis=-1, keepdims=True)
+
+
+text = sys.argv[1] if len(sys.argv) > 1 else "风雨 2024-06-08 21:44:59 剧烈的疼痛..."
+inp = char_tokenize(text, MAX_LEN)
+
+# build feed dict matching session inputs
+feed = {}
+for s_in in sess.get_inputs():
+    name = s_in.name
+    if name in inp:
+        feed[name] = inp[name]
+
+outs = sess.run(None, feed)
+logits = np.asarray(outs[0])  # (batch, seq_len, num_labels)
+probs = softmax(logits)
+
+ids = inp["input_ids"][0]
+offsets = inp["offset_mapping"][0]
+attn = inp["attention_mask"][0]
+
+# reconstruct token strings (CLS, each char, SEP)
+tokens = []
+for i, idv in enumerate(ids):
+    if i == 0:
+        tokens.append("[CLS]")
+    else:
+        if offsets[i][0] == 0 and offsets[i][1] == 0:
+            # SEP or pad
+            if attn[i] == 1:
+                tokens.append("[SEP]")
+            else:
+                tokens.append("[PAD]")
+        else:
+            s, e = offsets[i]
+            tokens.append(text[s:e])
+
+# print raw logits shape and a small slice for inspection
+print("Raw logits shape:", logits.shape)
+print("\nPer-token logits (index token -> first 6 logits):")
+for i, (t, l, a) in enumerate(zip(tokens, logits[0], attn)):
+    if not a:
+        continue
+    print(f"{i:03d} {t:>6} ->", np.around(l[:6], 3).tolist())
+
+# predictions & probs
+pred_ids = logits.argmax(-1)[0]
+pred_probs = probs[0, np.arange(probs.shape[1]), pred_ids]
+
+print("\nPer-token predictions (token \\t label \\t prob):")
+for i, (t, pid, pprob, a) in enumerate(zip(tokens, pred_ids, pred_probs, attn)):
+    if not a:
+        continue
+    lab = id2label.get(int(pid), "O")
+    print(f"{t}\t{lab}\t{pprob:.3f}")
+
+# merge BIO into entities using offsets
+entities = []
+cur = None
+for i, (pid, pprob, off, a) in enumerate(zip(pred_ids, pred_probs, offsets, attn)):
+    if not a or (off[0] == off[1] == 0):
+        if cur:
+            entities.append(cur)
+            cur = None
+        continue
+    label = id2label.get(int(pid), "O")
+    if label == "O":
+        if cur:
+            entities.append(cur)
+            cur = None
+        continue
+    if label.startswith("B-") or cur is None or label[2:] != cur["type"]:
+        if cur:
+            entities.append(cur)
+        cur = {
+            "type": label[2:],
+            "start": int(off[0]),
+            "end": int(off[1]),
+            "probs": [float(pprob)],
+        }
+    else:
+        cur["end"] = int(off[1])
+        cur["probs"].append(float(pprob))
+if cur:
+    entities.append(cur)
+
+
+# small fixes (timestamp/speaker) like main.py
+def fix_timestamp(ts):
+    if not ts:
+        return ts
+    m = re.match(r"^(\d{1,2})-(\d{2})-(\d{2})(.*)", ts)
+    if m:
+        y, mo, d, rest = m.groups()
+        if len(y) == 1:
+            y = "202" + y
+        elif len(y) == 2:
+            y = "20" + y
+        return f"{y}-{mo}-{d}{rest}"
+    return ts
+
+
+def fix_speaker(spk):
+    if not spk:
+        return spk
+    spk = re.sub(r"[^\w\s\u4e00-\u9fff]+$", "", spk)
+    if len(spk) == 1 and re.match(r"^[风雷电雨雪火水木金]", spk):
+        return spk + "某"
+    return spk
+
+
+out = {"metadata": {}, "content": []}
+for e in entities:
+    s, e_pos = e["start"], e["end"]
+    ent_text = text[s:e_pos]
+    conf = round(float(np.mean(e["probs"])), 3)
+    typ = e["type"]
+    if typ in ("timestamp", "speaker"):
+        ent_text = (
+            fix_timestamp(ent_text) if typ == "timestamp" else fix_speaker(ent_text)
+        )
+        out["metadata"][typ] = ent_text
+    else:
+        out["content"].append({"type": typ, "content": ent_text, "confidence": conf})
+
+print("\nConstructed JSON:")
+print(json.dumps(out, ensure_ascii=False, indent=2))
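As a usage note, the test runs directly as `python tests/test_onnx_only_infer.py "your text"`. The snippet below is illustrative only; it reproduces the id scheme from `char_tokenize` above and shows why those ids cannot match the training tokenizer's ids and can even collide, which is exactly the mismatch the module docstring warns about:

```python
# Illustrative only: the id scheme used by char_tokenize above.
# With vocab_size = 30000, ids land in [100, 29899], inside the vocabulary,
# but they are unrelated to the ids the model was trained with.
vocab_size = 30000


def char_id(c):
    return 100 + (ord(c) % (vocab_size - 200))


for c in "风雨A1":
    print(c, ord(c), char_id(c))

# Distinct characters whose code points differ by a multiple of
# (vocab_size - 200) collide onto the same id:
print(char_id(chr(100)), char_id(chr(100 + 29800)))  # same id, two characters
```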
