aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
author苏向夜 <fu050409@163.com>2024-04-01 15:42:47 +0800
committer苏向夜 <fu050409@163.com>2024-04-01 15:42:47 +0800
commit1cbd9bd2dd16ddaa263f2eb69970dd8e45dfa370 (patch)
treef3bd007163341037caed344a1c79c970a31fb215
parent4634be1aef800bf4c435d04bb11a2e31e847e604 (diff)
downloadinfini-1cbd9bd2dd16ddaa263f2eb69970dd8e45dfa370.tar.gz
infini-1cbd9bd2dd16ddaa263f2eb69970dd8e45dfa370.zip
feat(injector): allow complex and instance annocation
-rw-r--r--src/infini/injector.py6
-rw-r--r--tests/test_injector.py39
2 files changed, 42 insertions, 3 deletions
diff --git a/src/infini/injector.py b/src/infini/injector.py
index 73d96566..62032bb5 100644
--- a/src/infini/injector.py
+++ b/src/infini/injector.py
@@ -28,7 +28,11 @@ class Injector:
else:
param_types = (origin,)
- if type(parameters[param_name]) not in param_types:
+ if not any(
+ isinstance(parameters[param_name], param_type)
+ for param_type in param_types
+ if not isinstance(param_type, typing._SpecialForm)
+ ):
raise ValueError(
f"Parameter with name '{param_name}' has a mismatch type, "
f"expected '{param.annotation!r}' but got '{type(parameters[param_name])!r}'."
diff --git a/tests/test_injector.py b/tests/test_injector.py
index 819810e1..e6ae2e56 100644
--- a/tests/test_injector.py
+++ b/tests/test_injector.py
@@ -1,4 +1,5 @@
-from typing import Optional
+from typing import Dict, List, Optional
+from unittest.mock import Base
from infini.handler import Handler
from infini.injector import Injector
from infini.input import Input
@@ -32,7 +33,7 @@ def test_handler_injector():
"text": plain_text,
},
)
-
+
def absolute_2(input: Input[str], plain_text: Optional[str]) -> Output:
return input.output(
"text",
@@ -65,3 +66,37 @@ def test_handler_injector():
for output in core.input(input):
assert output == "test_message"
+
+
+def test_instance_injector():
+ class BaseClass: ...
+
+ class Class(BaseClass):
+ value = 10
+
+ def test(
+ a: int,
+ base: BaseClass,
+ b: int = 0,
+ cls: Optional[BaseClass] = None,
+ ):
+ assert isinstance(base, Class)
+ assert isinstance(cls, Class)
+ assert cls.value == 10
+ return a + b
+
+ injector = Injector()
+ injector.parameters = {"a": 12, "b": 20, "c": 0, "cls": Class(), "base": Class()}
+
+ assert injector.output(test) == 32
+
+
+def test_complex_injector():
+ def test(data: Dict[str, Dict[str, List[str]]]):
+ assert isinstance(data, dict)
+ return data["value"]
+
+ injector = Injector()
+ injector.parameters = {"data": {"value": 32}}
+
+ assert injector.output(test) == 32