diff options
| -rw-r--r-- | src/infini/loader.py | 15 |
1 files changed, 10 insertions, 5 deletions
diff --git a/src/infini/loader.py b/src/infini/loader.py index 4c1d38fe..20d392c9 100644 --- a/src/infini/loader.py +++ b/src/infini/loader.py @@ -1,5 +1,6 @@ from importlib.util import spec_from_file_location from infini.core import Core +from infini.doc import Doc from infini.generator import BaseGenerator, Generator, TextGenerator from infini.handler import Handler from infini.injector import Injector @@ -98,6 +99,7 @@ class Loader: global_variables: Dict[str, Union[str, Callable]] interceptors: List[RouterType] generators: Dict[str, BaseGenerator] + doc: Doc _core: Core @@ -108,6 +110,7 @@ class Loader: self.global_variables = {} self.interceptors = [] self.generators = {} + self.doc = Doc() self.prepare() def __enter__(self) -> "Loader": @@ -119,7 +122,7 @@ class Loader: if exc_type is not None: raise exc_type(exc_value) - def _find_register_variables(self, module: ModuleType) -> List[Register]: + def find_register_variables(self, module: ModuleType) -> List[Register]: module_variables = inspect.getmembers(module) register_variables = [ var for _, var in module_variables if isinstance(var, Register) @@ -134,17 +137,18 @@ class Loader: self.prepare() module = importlib.import_module(name) - vars(module)["__infini__"] = {"core": self._core, "loader": self} + self.load_from_module(module) + return module - registers = self._find_register_variables(module) + def load_from_module(self, module: ModuleType) -> None: + vars(module)["__infini__"] = {"core": self._core, "loader": self} + registers = self.find_register_variables(module) self.load_from_registers(registers) if not registers: logger.warning( f"Infini 装载器未能在规则包 [bold green]{module.__name__}[/bold green] 中找到注册器." ) - return module - def load_from_registers(self, registers: Sequence[Register]): for register in registers: self.load_from_register(register) @@ -157,6 +161,7 @@ class Loader: self.events.update(register.events) self.global_variables.update(register.global_variables) self.interceptors = self._update_list(self.interceptors, register.interceptors) + self.doc.update(register.doc) def _update_list(self, old_list: List[RouterType], new_list: List[RouterType]): list = old_list.copy() |
