aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/src
diff options
context:
space:
mode:
author苏向夜 <fu050409@163.com>2024-04-01 11:06:17 +0800
committer苏向夜 <fu050409@163.com>2024-04-01 11:06:17 +0800
commit8821abdfd1e69aa065dd38eba81c8e29618ea8a2 (patch)
tree57b8475ffd153410e2a336b9511437ea3adb8398 /src
parenta81364221c55528910cbe8932041c4b0ea3b10d2 (diff)
downloadinfini-2.1.11.tar.gz
infini-2.1.11.zip
fix(injector): fix injector error when complex subscrib annocationv2.1.11
Diffstat (limited to 'src')
-rw-r--r--src/infini/injector.py15
1 files changed, 11 insertions, 4 deletions
diff --git a/src/infini/injector.py b/src/infini/injector.py
index 462664a8..73d96566 100644
--- a/src/infini/injector.py
+++ b/src/infini/injector.py
@@ -1,5 +1,6 @@
+import typing
from infini.typing import Callable, T, Optional, Dict, Any
-from typing import get_origin
+from typing import get_args, get_origin
import inspect
@@ -19,9 +20,15 @@ class Injector:
for param_name, param in signature.parameters.items():
default = None if param.default == inspect._empty else param.default
if param_name in parameters:
- if type(parameters[param_name]) != (
- get_origin(param.annotation) or param.annotation
- ):
+ origin = get_origin(param.annotation)
+ if isinstance(origin, typing._SpecialForm):
+ param_types = get_args(param.annotation)
+ elif not origin:
+ param_types = (param.annotation,)
+ else:
+ param_types = (origin,)
+
+ if type(parameters[param_name]) not in param_types:
raise ValueError(
f"Parameter with name '{param_name}' has a mismatch type, "
f"expected '{param.annotation!r}' but got '{type(parameters[param_name])!r}'."