1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
|
import inspect
from contextlib import AsyncExitStack, asynccontextmanager, contextmanager
from typing import (
Any,
AsyncContextManager,
AsyncGenerator,
Callable,
ContextManager,
Dict,
Generator,
Optional,
Type,
TypeVar,
Union,
cast,
)
from .utils import get_annotations, sync_ctx_manager_wrapper
_T = TypeVar("_T")

# Everything solve_dependencies() accepts as a dependency producing a value of
# type ``_T``:
Dependency = Union[
    # a class: its instance is the value, or a sync/async context manager
    # whose entered value is used
    Type[Union[_T, AsyncContextManager[_T], ContextManager[_T]]],
    # an argument-less async generator function (entered as an async
    # context manager)
    Callable[[], AsyncGenerator[_T, None]],
    # an argument-less generator function (entered as a sync context manager)
    Callable[[], Generator[_T, None, None]],
]
# Public API of this module: only the ``Depends`` marker factory is exported
# by ``from ... import *``.
__all__ = ["Depends"]
class InnerDepends:
    """Marker object returned by ``Depends()`` and stored as a class attribute.

    ``solve_dependencies`` looks for these markers on a class and replaces each
    one with the resolved value of its dependency.
    """

    # What to resolve; ``None`` means "use the attribute's annotation".
    dependency: Optional[Dependency[Any]]
    # Whether an already-resolved value in the dependency cache may be reused.
    use_cache: bool

    def __init__(
        self,
        dependency: Optional[Dependency[Any]] = None,
        *,
        use_cache: bool = True,
    ) -> None:
        self.dependency, self.use_cache = dependency, use_cache

    def __repr__(self) -> str:
        target = self.dependency
        # Prefer the dependency's own name; fall back to its type's name
        # (e.g. "NoneType" when no dependency was given).
        label = getattr(target, "__name__", type(target).__name__)
        suffix = ", use_cache=False" if not self.use_cache else ""
        return f"InnerDepends({label}{suffix})"
def Depends(  # noqa: N802 # pylint: disable=invalid-name
    dependency: Optional[Dependency[_T]] = None, *, use_cache: bool = True
) -> _T:
    """Declare a class attribute as an injected dependency.

    Args:
        dependency: The class or generator function to resolve. When omitted,
            the attribute's type annotation is used instead.
        use_cache: Whether a value already resolved for the same dependency
            may be reused.

    Returns:
        An ``InnerDepends`` marker. The return annotation is ``_T`` rather
        than ``InnerDepends`` (hence the ``type: ignore``), so the attribute is
        typed as the resolved dependency.
    """
    marker = InnerDepends(dependency, use_cache=use_cache)
    return marker  # type: ignore
def _lookup_dependency_annotation(
    cls: type, name: str, own_annotations: Dict[str, Any]
) -> Any:
    """Return the annotation declared for attribute ``name`` of ``cls``.

    ``own_annotations`` must be ``get_annotations(cls)``; it is consulted
    first. If it does not declare ``name``, the remaining classes of
    ``cls.__mro__`` are searched in resolution order, because a ``Depends()``
    marker may be inherited together with its annotation from a base class.
    Returns ``None`` when no class declares ``name``.
    """
    if name in own_annotations:
        return own_annotations[name]
    for base in cls.__mro__[1:]:
        base_annotations = get_annotations(base)
        if name in base_annotations:
            return base_annotations[name]
    return None


async def solve_dependencies(
    dependent: Dependency[_T],
    *,
    use_cache: bool,
    stack: AsyncExitStack,
    dependency_cache: Dict[Dependency[Any], Any],
) -> _T:
    """Resolve ``dependent`` to a value, recursively injecting sub-dependencies.

    Args:
        dependent: A class, an async generator function, or a generator
            function.
        use_cache: When true and ``dependent`` is already present in
            ``dependency_cache``, the cached value is returned as-is.
        stack: Exit stack onto which context managers are entered; they are
            exited when the stack is closed.
        dependency_cache: Maps dependencies to resolved values. The freshly
            resolved value is always stored, even when ``use_cache`` is false.

    Returns:
        For a class: a new instance whose ``Depends()`` attributes have been
        filled in or, if that instance is a (sync or async) context manager,
        the value obtained by entering it. For a generator function: the
        value it yields.

    Raises:
        TypeError: If a ``Depends()`` attribute has no explicit dependency and
            no class in ``dependent``'s MRO annotates it, or if ``dependent``
            is neither a class nor a (async) generator function.
    """
    if use_cache and dependent in dependency_cache:
        return dependency_cache[dependent]

    if isinstance(dependent, type):
        # type of dependent is Type[T]
        values: Dict[str, Any] = {}
        ann = get_annotations(dependent)
        for name, sub_dependent in inspect.getmembers(
            dependent, lambda x: isinstance(x, InnerDepends)
        ):
            assert isinstance(sub_dependent, InnerDepends)
            # Resolve into a local instead of writing the annotation back onto
            # the marker: the marker is a class attribute that subclasses
            # share, so mutating it made the outcome depend on which class
            # happened to be solved first.
            sub_dependency = sub_dependent.dependency
            if sub_dependency is None:
                sub_dependency = _lookup_dependency_annotation(dependent, name, ann)
                if sub_dependency is None:
                    raise TypeError("can not solve dependent")
            values[name] = await solve_dependencies(
                sub_dependency,
                use_cache=sub_dependent.use_cache,
                stack=stack,
                dependency_cache=dependency_cache,
            )
        depend_obj = cast(
            Union[_T, AsyncContextManager[_T], ContextManager[_T]],
            dependent.__new__(dependent),  # pyright: ignore[reportGeneralTypeIssues]
        )
        # Injected attributes are assigned before __init__ runs, so __init__
        # can already read them.
        for key, value in values.items():
            setattr(depend_obj, key, value)
        depend_obj.__init__()  # type: ignore[misc] # pylint: disable=unnecessary-dunder-call
        if isinstance(depend_obj, AsyncContextManager):
            depend = await stack.enter_async_context(
                depend_obj  # pyright: ignore[reportUnknownArgumentType]
            )
        elif isinstance(depend_obj, ContextManager):
            # A sync context manager is wrapped so that it can be entered on
            # the async exit stack.
            depend = await stack.enter_async_context(
                sync_ctx_manager_wrapper(
                    depend_obj  # pyright: ignore[reportUnknownArgumentType]
                )
            )
        else:
            depend = depend_obj
    elif inspect.isasyncgenfunction(dependent):
        # type of dependent is Callable[[], AsyncGenerator[T, None]]
        cm = asynccontextmanager(dependent)()
        depend = cast(_T, await stack.enter_async_context(cm))
    elif inspect.isgeneratorfunction(dependent):
        # type of dependent is Callable[[], Generator[T, None, None]]
        cm = sync_ctx_manager_wrapper(contextmanager(dependent)())
        depend = cast(_T, await stack.enter_async_context(cm))
    else:
        raise TypeError("dependent is not a class or generator function")

    dependency_cache[dependent] = depend
    return depend
|