summaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
author苏向夜 <fu050409@163.com>2024-01-27 13:26:36 +0800
committer苏向夜 <fu050409@163.com>2024-01-27 13:26:36 +0800
commit28e83e03185992c857c9fbba5bafe50c5d615ed7 (patch)
treea595973ac2612066cfab6fdcb22eae14d5a046de
parent26c56c456764baffbefa804fd5e80be2fb5a936d (diff)
downloadinfini-28e83e03185992c857c9fbba5bafe50c5d615ed7.tar.gz
infini-28e83e03185992c857c9fbba5bafe50c5d615ed7.zip
:sparkles: feat(loader): add loader method
-rw-r--r--src/infini/loader.py152
-rw-r--r--src/infini/typing.py3
-rw-r--r--tests/test_loader.py46
3 files changed, 199 insertions, 2 deletions
diff --git a/src/infini/loader.py b/src/infini/loader.py
index d5a95699..0686aaeb 100644
--- a/src/infini/loader.py
+++ b/src/infini/loader.py
@@ -1,2 +1,152 @@
+from importlib.util import spec_from_file_location
+from infini.core import Core
+from infini.generator import Generator
+from infini.handler import Handler
+from infini.interceptor import Interceptor
+from infini.register import Register
+from infini.typing import List, Dict, Sequence, ModuleType, RouterType, Callable
+from pathlib import Path
+
+import inspect
+import sys
+import importlib
+import importlib.abc
+
+
+class InfiniMetaFinder(importlib.abc.MetaPathFinder):
+ def find_spec(self, fullname: str, path: Sequence[str] | None, target=None):
+ default_entries = [
+ Path.cwd() / "src",
+ Path.home() / ".ipm" / "src",
+ ]
+
+ entries: List[Path] = (
+ [Path(catch_path).resolve() for catch_path in path] + default_entries
+ if path
+ else default_entries
+ )
+
+ if "." in fullname:
+ *_, name = fullname.split(".")
+ else:
+ name = fullname
+
+ for entry in entries:
+ if (entry / name).is_dir():
+ filename = entry / name / "src" / "__init__.py"
+ submodule_locations = [entry / name / "src"]
+ if not filename.exists():
+ filename = entry / name / "src" / (name + ".py")
+ else:
+ continue
+ if not filename.exists():
+ continue
+
+ return spec_from_file_location(
+ fullname,
+ filename,
+ loader=InfiniLoader(str(filename)),
+ submodule_search_locations=[
+ str(submodule_location)
+ for submodule_location in submodule_locations
+ ]
+ if submodule_locations
+ else None,
+ )
+
+ return None
+
+
+class InfiniLoader(importlib.abc.Loader):
+ def __init__(self, filename):
+ self.filename = filename
+
+ def create_module(self, _):
+ return None
+
+ def exec_module(self, module):
+ with open(self.filename) as f:
+ data = f.read()
+
+ exec(data, vars(module))
+
+
+def install():
+ sys.meta_path.insert(0, InfiniMetaFinder())
+
+
+def uninstall():
+ for meta_path in sys.meta_path:
+ if isinstance(meta_path, InfiniMetaFinder):
+ sys.meta_path.remove(meta_path)
+ break
+
+
class Loader:
- ...
+ pre_interceptors: List[RouterType]
+ handlers: List[RouterType]
+ events: Dict[str, str]
+ global_variables: Dict[str, str | Callable]
+ interceptors: List[RouterType]
+
+ def __init__(self) -> None:
+ self.pre_interceptors = []
+ self.handlers = []
+ self.events = {}
+ self.global_variables = {}
+ self.interceptors = []
+ self.prepare()
+
+ def _find_register_variables(self, module: ModuleType) -> List[Register]:
+ module_variables = inspect.getmembers(module)
+ register_variables = [
+ var for _, var in module_variables if isinstance(var, Register)
+ ]
+ return register_variables
+
+ def prepare(self) -> None:
+ install()
+
+ def load(self, name: str) -> ModuleType:
+ self.prepare()
+
+ module = importlib.import_module(name)
+ registers = self._find_register_variables(module)
+ for register in registers:
+ self.load_from_register(register)
+ if not registers:
+ # TODO 警告内容
+ ...
+
+ return module
+
+ def load_from_register(self, register: Register):
+ self.pre_interceptors.extend(register.pre_interceptors)
+ self.handlers.extend(register.handlers)
+ self.events.update(register.events)
+ self.global_variables.update(register.global_variables)
+ self.interceptors.extend(register.interceptors)
+
+ def close(self):
+ uninstall()
+
+ def inject(self, core: Core):
+ pre_interceptor = Interceptor()
+ pre_interceptor.interceptors = self.pre_interceptors
+ handler = Handler()
+ handler.handlers = self.handlers
+ generator = Generator()
+ generator.events = self.events
+ generator.global_variables = self.global_variables
+ interceptor = Interceptor()
+ interceptor.interceptors = self.interceptors
+
+ core.pre_interceptor = pre_interceptor
+ core.handler = handler
+ core.generator = generator
+ core.interceptor = interceptor
+
+ def output(self) -> Core:
+ core = Core()
+ self.inject(core)
+ return core
diff --git a/src/infini/typing.py b/src/infini/typing.py
index 0a0d3292..a1273f02 100644
--- a/src/infini/typing.py
+++ b/src/infini/typing.py
@@ -5,12 +5,13 @@ from typing import (
Generic as Generic,
Callable as Callable,
Literal as Literal,
+ Sequence as Sequence,
overload as overload,
TypeVar,
TypedDict,
Union,
)
-from types import ModuleType as ModuleType
+from types import ModuleType as ModuleType, GeneratorType as GeneratorType
from . import router, input, output
T = TypeVar("T")
diff --git a/tests/test_loader.py b/tests/test_loader.py
new file mode 100644
index 00000000..552223d3
--- /dev/null
+++ b/tests/test_loader.py
@@ -0,0 +1,46 @@
+from infini.input import Input
+from infini.loader import Loader
+from infini.output import Output
+from infini.register import Register
+
+
+def test_loader():
+ blocked_god_input = Input("这是苏向夜的杰作.")
+ snh_input = Input("撅少年狐!")
+
+ register = Register()
+
+ @register.pre_interceptor("苏向夜", priority=0)
+ def test_pre_interceptor(_: Input):
+ return Output("text", "block.sxy", block=True)
+
+ @register.handler("撅少年狐")
+ def test_handler(_: Input):
+ return Output("text", "block.snh", block=True)
+
+ register.regist_textevent("block.sxy", "不可直呼{{ sxy_id }}的ID")
+ register.regist_textevent("block.snh", "不许撅{{ get_snh_id() }}")
+
+ register.regist_variable("sxy_id", "苏向夜")
+
+ @register.dynamic_variable()
+ def get_snh_id():
+ return "少年狐"
+
+ @register.interceptor("苏向夜", priority=0)
+ def test_interceptor(_: Input):
+ return Output("text", "block.sxy", block=True)
+
+ loader = Loader()
+ loader.load_from_register(register)
+ core = loader.output()
+
+ for output in core.input(blocked_god_input):
+ assert output == "不可直呼苏向夜的ID"
+
+ for output in core.input(snh_input):
+ assert output == "不许撅少年狐"
+
+
+if __name__ == "__main__":
+ test_loader()