aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
-rw-r--r--pyproject.toml2
-rw-r--r--src/infini/injector.py6
-rw-r--r--tests/test_injector.py39
3 files changed, 43 insertions, 4 deletions
diff --git a/pyproject.toml b/pyproject.toml
index 58b2abbb..94e2eefd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "infini"
-version = "2.1.12"
+version = "2.1.13"
description = "Infini 内容输入输出流框架"
authors = [
{ name = "苏向夜", email = "fu050409@163.com" },
diff --git a/src/infini/injector.py b/src/infini/injector.py
index 73d96566..62032bb5 100644
--- a/src/infini/injector.py
+++ b/src/infini/injector.py
@@ -28,7 +28,11 @@ class Injector:
else:
param_types = (origin,)
- if type(parameters[param_name]) not in param_types:
+ if not any(
+ isinstance(parameters[param_name], param_type)
+ for param_type in param_types
+ if not isinstance(param_type, typing._SpecialForm)
+ ):
raise ValueError(
f"Parameter with name '{param_name}' has a mismatch type, "
f"expected '{param.annotation!r}' but got '{type(parameters[param_name])!r}'."
diff --git a/tests/test_injector.py b/tests/test_injector.py
index 819810e1..e6ae2e56 100644
--- a/tests/test_injector.py
+++ b/tests/test_injector.py
@@ -1,4 +1,5 @@
-from typing import Optional
+from typing import Dict, List, Optional
+from unittest.mock import Base
from infini.handler import Handler
from infini.injector import Injector
from infini.input import Input
@@ -32,7 +33,7 @@ def test_handler_injector():
"text": plain_text,
},
)
-
+
def absolute_2(input: Input[str], plain_text: Optional[str]) -> Output:
return input.output(
"text",
@@ -65,3 +66,37 @@ def test_handler_injector():
for output in core.input(input):
assert output == "test_message"
+
+
+def test_instance_injector():
+ class BaseClass: ...
+
+ class Class(BaseClass):
+ value = 10
+
+ def test(
+ a: int,
+ base: BaseClass,
+ b: int = 0,
+ cls: Optional[BaseClass] = None,
+ ):
+ assert isinstance(base, Class)
+ assert isinstance(cls, Class)
+ assert cls.value == 10
+ return a + b
+
+ injector = Injector()
+ injector.parameters = {"a": 12, "b": 20, "c": 0, "cls": Class(), "base": Class()}
+
+ assert injector.output(test) == 32
+
+
+def test_complex_injector():
+ def test(data: Dict[str, Dict[str, List[str]]]):
+ assert isinstance(data, dict)
+ return data["value"]
+
+ injector = Injector()
+ injector.parameters = {"data": {"value": 32}}
+
+ assert injector.output(test) == 32