From 1cbd9bd2dd16ddaa263f2eb69970dd8e45dfa370 Mon Sep 17 00:00:00 2001 From: 苏向夜 Date: Mon, 1 Apr 2024 15:42:47 +0800 Subject: feat(injector): allow complex and instance annocation --- src/infini/injector.py | 6 +++++- tests/test_injector.py | 39 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/src/infini/injector.py b/src/infini/injector.py index 73d96566..62032bb5 100644 --- a/src/infini/injector.py +++ b/src/infini/injector.py @@ -28,7 +28,11 @@ class Injector: else: param_types = (origin,) - if type(parameters[param_name]) not in param_types: + if not any( + isinstance(parameters[param_name], param_type) + for param_type in param_types + if not isinstance(param_type, typing._SpecialForm) + ): raise ValueError( f"Parameter with name '{param_name}' has a mismatch type, " f"expected '{param.annotation!r}' but got '{type(parameters[param_name])!r}'." diff --git a/tests/test_injector.py b/tests/test_injector.py index 819810e1..e6ae2e56 100644 --- a/tests/test_injector.py +++ b/tests/test_injector.py @@ -1,4 +1,5 @@ -from typing import Optional +from typing import Dict, List, Optional +from unittest.mock import Base from infini.handler import Handler from infini.injector import Injector from infini.input import Input @@ -32,7 +33,7 @@ def test_handler_injector(): "text": plain_text, }, ) - + def absolute_2(input: Input[str], plain_text: Optional[str]) -> Output: return input.output( "text", @@ -65,3 +66,37 @@ def test_handler_injector(): for output in core.input(input): assert output == "test_message" + + +def test_instance_injector(): + class BaseClass: ... + + class Class(BaseClass): + value = 10 + + def test( + a: int, + base: BaseClass, + b: int = 0, + cls: Optional[BaseClass] = None, + ): + assert isinstance(base, Class) + assert isinstance(cls, Class) + assert cls.value == 10 + return a + b + + injector = Injector() + injector.parameters = {"a": 12, "b": 20, "c": 0, "cls": Class(), "base": Class()} + + assert injector.output(test) == 32 + + +def test_complex_injector(): + def test(data: Dict[str, Dict[str, List[str]]]): + assert isinstance(data, dict) + return data["value"] + + injector = Injector() + injector.parameters = {"data": {"value": 32}} + + assert injector.output(test) == 32 -- cgit v1.2.3-70-g09d2