aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
-rw-r--r--pyproject.toml2
-rw-r--r--src/infini/injector.py15
-rw-r--r--tests/test_injector.py16
3 files changed, 28 insertions, 5 deletions
diff --git a/pyproject.toml b/pyproject.toml
index 36e8e168..0f1a5b92 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "infini"
-version = "2.1.10"
+version = "2.1.11"
description = "Infini 内容输入输出流框架"
authors = [
{ name = "苏向夜", email = "fu050409@163.com" },
diff --git a/src/infini/injector.py b/src/infini/injector.py
index 462664a8..73d96566 100644
--- a/src/infini/injector.py
+++ b/src/infini/injector.py
@@ -1,5 +1,6 @@
+import typing
from infini.typing import Callable, T, Optional, Dict, Any
-from typing import get_origin
+from typing import get_args, get_origin
import inspect
@@ -19,9 +20,15 @@ class Injector:
for param_name, param in signature.parameters.items():
default = None if param.default == inspect._empty else param.default
if param_name in parameters:
- if type(parameters[param_name]) != (
- get_origin(param.annotation) or param.annotation
- ):
+ origin = get_origin(param.annotation)
+ if isinstance(origin, typing._SpecialForm):
+ param_types = get_args(param.annotation)
+ elif not origin:
+ param_types = (param.annotation,)
+ else:
+ param_types = (origin,)
+
+ if type(parameters[param_name]) not in param_types:
raise ValueError(
f"Parameter with name '{param_name}' has a mismatch type, "
f"expected '{param.annotation!r}' but got '{type(parameters[param_name])!r}'."
diff --git a/tests/test_injector.py b/tests/test_injector.py
index c41dd732..819810e1 100644
--- a/tests/test_injector.py
+++ b/tests/test_injector.py
@@ -1,3 +1,4 @@
+from typing import Optional
from infini.handler import Handler
from infini.injector import Injector
from infini.input import Input
@@ -31,6 +32,16 @@ def test_handler_injector():
"text": plain_text,
},
)
+
+ def absolute_2(input: Input[str], plain_text: Optional[str]) -> Output:
+ return input.output(
+ "text",
+ "absolute",
+ block=False,
+ variables={
+ "text": plain_text,
+ },
+ )
handler = Handler()
handler.handlers = [
@@ -39,6 +50,11 @@ def test_handler_injector():
"router": Startswith(""),
"handler": absolute,
},
+ {
+ "priority": 2,
+ "router": Startswith(""),
+ "handler": absolute_2,
+ },
]
core = Loader().into_core()