From fd1f123d531e25ac066dae6a1ea8dc19fd1c0964 Mon Sep 17 00:00:00 2001 From: 简律纯 Date: Sat, 7 Oct 2023 02:50:20 +0800 Subject: feat: BREAKING CHANGES MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 白咕咕 Co-authored-by: kenichiLyon --- src/hydrorollcore/utils.py | 192 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 192 insertions(+) create mode 100644 src/hydrorollcore/utils.py (limited to 'src/hydrorollcore/utils.py') diff --git a/src/hydrorollcore/utils.py b/src/hydrorollcore/utils.py new file mode 100644 index 00000000..c69fdd48 --- /dev/null +++ b/src/hydrorollcore/utils.py @@ -0,0 +1,192 @@ +import os +import json +import asyncio +import inspect +import os.path +import pkgutil +import importlib +import traceback +import dataclasses +from abc import ABC +from types import ModuleType +from functools import partial +from typing_extensions import ParamSpec +from importlib.abc import MetaPathFinder +from importlib.machinery import PathFinder +from typing import Any, List, Type, Tuple, TypeVar, Callable, Iterable, Coroutine + +from HydroRollCore.config import ConfigModel + +__all__ = [ + "ModulePathFinder", + "is_config_class", + "get_classes_from_module", + "get_classes_from_module_name", + "get_classes_from_dir", + "DataclassEncoder", + "samefile", + "sync_func_wrapper", +] + +_T = TypeVar("_T") +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +class ModulePathFinder(MetaPathFinder): + """用于查找 HydroRollCore 组件的元路径查找器。""" + + path: List[str] = [] + + def find_spec(self, fullname, path=None, target=None): + if path is None: + path = [] + return PathFinder.find_spec(fullname, self.path + list(path), target) + + +def is_config_class(config_class: Any) -> bool: + """判断一个对象是否是配置类。 + + Args: + config_class: 待判断的对象。 + + Returns: + 返回是否是配置类。 + """ + return ( + inspect.isclass(config_class) + and issubclass(config_class, ConfigModel) + and isinstance(getattr(config_class, "__config_name__", None), str) + and ABC not in config_class.__bases__ + and not inspect.isabstract(config_class) + ) + + +def get_classes_from_module( + module: ModuleType, super_class: Type[_T] +) -> List[Type[_T]]: + """从模块中查找指定类型的类。 + + Args: + module: Python 模块。 + super_class: 要查找的类的超类。 + + Returns: + 返回符合条件的类的列表。 + """ + classes: List[Type[_T]] = [] + for _, module_attr in inspect.getmembers(module, inspect.isclass): + module_attr: type + if ( + (inspect.getmodule(module_attr) or module) is module + and issubclass(module_attr, super_class) + and module_attr != super_class + and ABC not in module_attr.__bases__ + and not inspect.isabstract(module_attr) + ): + classes.append(module_attr) + return classes + + +def get_classes_from_module_name( + name: str, super_class: Type[_T] +) -> List[Tuple[Type[_T], ModuleType]]: + """从指定名称的模块中查找指定类型的类。 + + Args: + name: 模块名称,格式和 Python `import` 语句相同。 + super_class: 要查找的类的超类。 + + Returns: + 返回由符合条件的类和模块组成的元组的列表。 + + Raises: + ImportError: 当导入模块过程中出现错误。 + """ + try: + importlib.invalidate_caches() + module = importlib.import_module(name) + importlib.reload(module) + return list( + map(lambda x: (x, module), get_classes_from_module(module, super_class)) + ) + except BaseException as e: + # 不捕获 KeyboardInterrupt + # 捕获 KeyboardInterrupt 会阻止用户关闭 Python 当正在导入的模块陷入死循环时 + if isinstance(e, KeyboardInterrupt): + raise e + raise ImportError(e, traceback.format_exc()) from e + + +def get_classes_from_dir( + dirs: Iterable[str], super_class: Type[_T] +) -> List[Tuple[Type[_T], ModuleType]]: + """从指定路径列表中的所有模块中查找指定类型的类,以 `_` 开头的插件不会被导入。路径可以是相对路径或绝对路径。 + + Args: + dirs: 储存模块的路径的列表。 + super_class: 要查找的类的超类。 + + Returns: + 返回由符合条件的类和模块组成的元组的列表。 + """ + classes: List[Tuple[Type[_T], ModuleType]] = [] + for module_info in pkgutil.iter_modules(dirs): + if not module_info.name.startswith("_"): + try: + classes.extend( + get_classes_from_module_name(module_info.name, super_class) + ) + except ImportError: + continue + return classes + + +class DataclassEncoder(json.JSONEncoder): + """用于解析 MessageSegment 的 JSONEncoder 类。""" + + def default(self, o): + return o.as_dict() if dataclasses.is_dataclass(o) else super().default(o) + + +def samefile(path1: str, path2: str) -> bool: + """一个 `os.path.samefile` 的简单包装。 + + Args: + path1: 路径1。 + path2: 路径2。 + + Returns: + 如果两个路径是否指向相同的文件或目录。 + """ + try: + return path1 == path2 or os.path.samefile(path1, path2) + except OSError: + return False + + +def sync_func_wrapper( + func: Callable[_P, _R], to_thread: bool = False +) -> Callable[_P, Coroutine[None, None, _R]]: + """包装一个同步函数为异步函数 + + Args: + func: 待包装的同步函数。 + to_thread: 在独立的线程中运行同步函数。 + + Returns: + 异步函数。 + """ + if to_thread: + + async def _wrapper(*args: _P.args, **kwargs: _P.kwargs): + loop = asyncio.get_running_loop() + func_call = partial(func, *args, **kwargs) + return await loop.run_in_executor(None, func_call) + + else: + + async def _wrapper(*args: _P.args, **kwargs: _P.kwargs): + return func(*args, **kwargs) + + return _wrapper \ No newline at end of file -- cgit v1.2.3-70-g09d2