aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/hrc/core.py
diff options
context:
space:
mode:
Diffstat (limited to 'hrc/core.py')
-rw-r--r--hrc/core.py689
1 files changed, 688 insertions, 1 deletions
diff --git a/hrc/core.py b/hrc/core.py
index 020ccb9..862b8d1 100644
--- a/hrc/core.py
+++ b/hrc/core.py
@@ -1 +1,688 @@
-class Core: ...
+import asyncio
+import json
+import pkgutil
+import signal
+import sys
+import threading
+import time
+from collections import defaultdict
+from contextlib import AsyncExitStack
+from itertools import chain
+from pathlib import Path
+from typing import (
+ Any,
+ Awaitable,
+ Callable,
+ Dict,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ Type,
+ Union,
+ overload,
+)
+
+from pydantic import ValidationError, create_model
+
+from .config import ConfigModel, MainConfig, RuleConfig
+from .dependencies import solve_dependencies
+from .log import logger
+from .rules import Rule
+from .utils import (
+ ModulePathFinder,
+ get_classes_from_module_name,
+ is_config_class,
+ samefile,
+ wrap_get_func,
+)
+
+if sys.version_info >= (3, 11): # pragma: no cover
+ import tomllib
+else: # pragma: no cover
+ import tomli as tomllib
+
+HANDLED_SIGNALS = (
+ signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
+ signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
+)
+
+
+class Core:
+
+ should_exit: asyncio.Event
+ rules_priority_dict: Dict[int, List[Type[Rule[Any, Any, Any]]]]
+
+ _condition: (asyncio.Condition)
+ _current_event: Optional[Event[Any]]
+ _restart_flag: bool
+ _module_path_finder: ModulePathFinder
+ _raw_config_dict: Dict[str, Any]
+ _handle_event_tasks: Set[
+ "asyncio.Task[None]"
+ ] # Event handling task, used to keep a reference to the adapter task
+ # The following properties are not cleared on reboot
+ _config_file: Optional[str] # Configuration file
+ _config_dict: Optional[Dict[str, Any]] # Configuration dictionary
+ _hot_reload: bool # Hot-Reload
+ _extend_rules: List[
+ Union[Type[Rule[Any, Any, Any]], str, Path]
+ ] # A list of rules loaded programmatically using the ``load_rules()`` method
+ _extend_rule_dirs: List[
+ Path
+ ] # List of rule paths loaded programmatically using the ``load_rules_from_dirs()`` method
+ _core_run_hooks: List[CoreHook]
+ _core_exit_hooks: List[CoreHook]
+ _event_pre_processor_hooks: List[EventHook]
+ _event_post_processor_hooks: List[EventHook]
+
+ def __init__(
+ self,
+ *,
+ config_file: Optional[str] = "config.toml",
+ config_dict: Optional[Dict[str, Any]] = None,
+ hot_reload: bool = False,
+ ) -> None:
+ self.config = MainConfig()
+ self.rules_priority_dict = defaultdict(list)
+ self._current_event = None
+ self._restart_flag = False
+ self._module_path_finder = ModulePathFinder()
+ self._raw_config_dict = {}
+ self._handle_event_tasks = set()
+
+ self._config_file = config_file
+ self._config_dict = config_dict
+ self._hot_reload = hot_reload
+
+ self._extend_rules = []
+ self._extend_rule_dirs = []
+ self._core_run_hooks = []
+ self._core_exit_hooks = []
+ self._event_pre_processor_hooks = []
+ self._event_post_processor_hooks = []
+
+ sys.meta_path.insert(0, self._module_path_finder)
+
+ @property
+ def rules(self) -> List[Type[Rule[Any, Any, Any]]]:
+ """List of currently loaded rules."""
+ return list(chain(*self.rules_priority_dict.values()))
+
+ def run(self) -> None:
+ self._restart_flag = True
+ while self._restart_flag:
+ self._restart_flag = False
+ asyncio.run(self._run())
+ if self._restart_flag:
+ self._load_plugins_from_dirs(*self._extend_rule_dirs)
+ self._load_plugins(*self._extend_rules)
+
+ def restart(self) -> None:
+ logger.info("Restarting...")
+ self._restart_flag = True
+ self.should_exit.set()
+
+ async def _run(self) -> None:
+ self.should_exit = asyncio.Event()
+ self._condition = asyncio.Condition()
+
+ # Monitor and intercept system exit signals to complete some aftermath work before closing the program
+ if threading.current_thread() is threading.main_thread(): # pragma: no cover
+ # Signals can only be processed in the main thread
+ try:
+ loop = asyncio.get_running_loop()
+ for sig in HANDLED_SIGNALS:
+ loop.add_signal_handler(sig, self._handle_exit)
+ except NotImplementedError:
+ # add_signal_handler is only available under Unix, below for Windows
+ for sig in HANDLED_SIGNALS:
+ signal.signal(sig, self._handle_exit)
+
+ # Load configuration file
+ self._reload_config_dict()
+
+ self._load_rules_from_dirs(*self.config.bot.rule_dirs)
+ self._load_rules(*self.config.bot.rules)
+ self._update_config()
+
+ logger.info("Running...")
+
+ hot_reload_task = None
+ if self._hot_reload: # pragma: no cover
+ hot_reload_task = asyncio.create_task(self._run_hot_reload())
+
+ for core_run_hook_func in self._core_run_hooks:
+ await core_run_hook_func(self)
+
+ self.rules_priority_dict.clear()
+ self._module_path_finder.path.clear()
+
+ def _remove_rule_by_path(
+ self, file: Path
+ ) -> List[Type[Rule[Any, Any, Any]]]: # pragma: no cover
+ removed_rules: List[Type[Rule[Any, Any, Any]]] = []
+ for plugins in self.plugins_priority_dict.values():
+ _removed_rules = list(
+ filter(
+ lambda x: x.__rule_load_type__ != RuleLoadType.CLASS
+ and x.__rule_file_path__ is not None
+ and samefile(x.__rule_file_path__, file),
+ rules,
+ )
+ )
+ removed_rules.extend(_removed_rules)
+ for rule_ in _removed_rules:
+ rules.remove(rule_)
+ logger.info(
+ "Succeeded to remove rule "
+ f'"{rule_.__name__}" from file "{file}"'
+ )
+ return removed_rules
+
+ async def _run_hot_reload(self) -> None: # pragma: no cover
+ """Hot reload."""
+ try:
+ from watchfiles import Change, awatch
+ except ImportError:
+ logger.warning(
+ 'Hot reload needs to install "watchfiles", try "pip install watchfiles"'
+ )
+ return
+
+ logger.info("Hot reload is working!")
+ async for changes in awatch(
+ *(
+ x.resolve()
+ for x in set(self._extend_rule_dirs)
+ .union(self.config.core.rule_dirs)
+ .union(
+ {Path(self._config_file)}
+ if self._config_dict is None and self._config_file is not None
+ else set()
+ )
+ ),
+ stop_event=self.should_exit,
+ ):
+ # Processed in the order of Change.deleted, Change.modified, Change.added
+ # To ensure that when renaming occurs, deletions are processed first and then additions are processed
+ for change_type, file_ in sorted(changes, key=lambda x: x[0], reverse=True):
+ file = Path(file_)
+ # Change configuration file
+ if (
+ self._config_file is not None
+ and samefile(self._config_file, file)
+ and change_type == change_type.modified
+ ):
+ logger.info(f'Reload config file "{self._config_file}"')
+ old_config = self.config
+ self._reload_config_dict()
+ if (
+ self.config.bot != old_config.bot
+ or self.config.adapter != old_config.adapter
+ ):
+ self.restart()
+ continue
+
+ # Change rule folder
+ if change_type == Change.deleted:
+ # Special handling for deletion operations
+ if file.suffix != ".py":
+ file = file / "__init__.py"
+ else:
+ if file.is_dir() and (file / "__init__.py").is_file():
+ # When a new directory is added and this directory contains the ``__init__.py`` file
+ # It means that what happens at this time is that a Python package is added, and the ``__init__.py`` file of this package is deemed to be added
+ file = file / "__init__.py"
+ if not (file.is_file() and file.suffix == ".py"):
+ continue
+
+ if change_type == Change.added:
+ logger.info(f"Hot reload: Added file: {file}")
+ self._load_plugins(
+ Path(file), rule_load_type=RuleLoadType.DIR, reload=True
+ )
+ self._update_config()
+ continue
+ if change_type == Change.deleted:
+ logger.info(f"Hot reload: Deleted file: {file}")
+ self._remove_rule_by_path(file)
+ self._update_config()
+ elif change_type == Change.modified:
+ logger.info(f"Hot reload: Modified file: {file}")
+ self._remove_rule_by_path(file)
+ self._load_plugins(
+ Path(file), rule_load_type=RuleLoadType.DIR, reload=True
+ )
+ self._update_config()
+
+ def _update_config(self) -> None:
+ def update_config(
+ source: List[Type[Rule[Any, Any, Any]]],
+ name: str,
+ base: Type[ConfigModel],
+ ) -> Tuple[Type[ConfigModel], ConfigModel]:
+ config_update_dict: Dict[str, Any] = {}
+ for i in source:
+ config_class = getattr(i, "Config", None)
+ if is_config_class(config_class):
+ default_value: Any
+ try:
+ default_value = config_class()
+ except ValidationError:
+ default_value = ...
+ config_update_dict[config_class.__config_name__] = (
+ config_class,
+ default_value,
+ )
+ config_model = create_model(
+ name, **config_update_dict, __base__=base)
+ return config_model, config_model()
+
+ self.config = create_model(
+ "Config",
+ rule=update_config(self.rules, "RuleConfig", RuleConfig),
+ __base__=MainConfig,
+ )(**self._raw_config_dict)
+ # Update the level of logging
+ logger.remove()
+ logger.add(sys.stderr, level=self.config.bot.log.level)
+
+ def _reload_config_dict(self) -> None:
+ """Reload the configuration file."""
+ self._raw_config_dict = {}
+
+ if self._config_dict is not None:
+ self._raw_config_dict = self._config_dict
+ elif self._config_file is not None:
+ try:
+ with Path(self._config_file).open("rb") as f:
+ if self._config_file.endswith(".json"):
+ self._raw_config_dict = json.load(f)
+ elif self._config_file.endswith(".toml"):
+ self._raw_config_dict = tomllib.load(f)
+ else:
+ self.error_or_exception(
+ "Read config file failed:",
+ OSError("Unable to determine config file type"),
+ )
+ except OSError as e:
+ self.error_or_exception("Can not open config file:", e)
+ except (ValueError, json.JSONDecodeError, tomllib.TOMLDecodeError) as e:
+ self.error_or_exception("Read config file failed:", e)
+
+ try:
+ self.config = MainConfig(**self._raw_config_dict)
+ except ValidationError as e:
+ self.config = MainConfig()
+ self.error_or_exception("Config dict parse error:", e)
+ self._update_config()
+
+ def reload_rules(self) -> None:
+ self.rules_priority_dict.clear()
+ self._load_rules(*self.config.core.rules)
+ self._load_rules_from_dirs(*self.config.core.rule_dirs)
+ self._load_rules(*self._extend_rules)
+ self._load_rules_from_dirs(*self._extend_rule_dirs)
+ self._update_config()
+
+ def _handle_exit(self, *_args: Any) -> None: # pragma: no cover
+ """When the robot receives the exit signal, it will handle it according to the situation."""
+ logger.info("Stopping...")
+ if self.should_exit.is_set():
+ logger.warning("Force Exit...")
+ sys.exit()
+ else:
+ self.should_exit.set()
+
+ async def handle_event(
+ self,
+ current_event: Event[Any],
+ *,
+ handle_get: bool = True,
+ show_log: bool = True,
+ ) -> None:
+ if show_log:
+ logger.info(
+ f"Rule {current_event.rule.name} received: {current_event!r}"
+ )
+
+ if handle_get:
+ _handle_event_task = asyncio.create_task(self._handle_event())
+ self._handle_event_tasks.add(_handle_event_task)
+ _handle_event_task.add_done_callback(
+ self._handle_event_tasks.discard)
+ await asyncio.sleep(0)
+ async with self._condition:
+ self._current_event = current_event
+ self._condition.notify_all()
+ else:
+ _handle_event_task = asyncio.create_task(
+ self._handle_event(current_event))
+ self._handle_event_tasks.add(_handle_event_task)
+ _handle_event_task.add_done_callback(
+ self._handle_event_tasks.discard)
+
+ async def _handle_event(self, current_event: Optional[Event[Any]] = None) -> None:
+ if current_event is None:
+ async with self._condition:
+ await self._condition.wait()
+ assert self._current_event is not None
+ current_event = self._current_event
+ if current_event.__handled__:
+ return
+
+ for _hook_func in self._event_pre_processor_hooks:
+ await _hook_func(current_event)
+
+ for rule_priority in sorted(self.rules_priority_dict.keys()):
+ logger.debug(
+ f"Checking for matching rules with priority {rule_priority!r}"
+ )
+ stop = False
+ for rule in self.rules_priority_dict[rule_priority]:
+ try:
+ async with AsyncExitStack() as stack:
+ _rule = await solve_dependencies(
+ rule,
+ use_cache=True,
+ stack=stack,
+ dependency_cache={
+ Core: self,
+ Event: current_event,
+ },
+ )
+ if await _rule.rule():
+ logger.info(f"Event will be handled by {_rule!r}")
+ try:
+ await _rule.handle()
+ finally:
+ if _rule.block:
+ stop = True
+ except SkipException:
+ # The plug-in requires that it skips itself and continues the current event propagation
+ continue
+ except StopException:
+ # Plugin requires stopping current event propagation
+ stop = True
+ except Exception as e:
+ self.error_or_exception(f'Exception in rule "{rule}":', e)
+ if stop:
+ break
+
+ for _hook_func in self._event_post_processor_hooks:
+ await _hook_func(current_event)
+
+ logger.info("Event Finished")
+
+ @overload
+ async def get(
+ self,
+ func: Optional[Callable[[Event[Any]],
+ Union[bool, Awaitable[bool]]]] = None,
+ *,
+ event_type: None = None,
+ max_try_times: Optional[int] = None,
+ timeout: Optional[Union[int, float]] = None,
+ ) -> Event[Any]: ...
+
+ @overload
+ async def get(
+ self,
+ func: Optional[Callable[[EventT],
+ Union[bool, Awaitable[bool]]]] = None,
+ *,
+ event_type: None = None,
+ max_try_times: Optional[int] = None,
+ timeout: Optional[Union[int, float]] = None,
+ ) -> EventT: ...
+
+ @overload
+ async def get(
+ self,
+ func: Optional[Callable[[EventT],
+ Union[bool, Awaitable[bool]]]] = None,
+ *,
+ event_type: Type[EventT],
+ max_try_times: Optional[int] = None,
+ timeout: Optional[Union[int, float]] = None,
+ ) -> EventT: ...
+
+ async def get(
+ self,
+ func: Optional[Callable[[Any], Union[bool, Awaitable[bool]]]] = None,
+ *,
+ event_type: Optional[Type[Event[Any]]] = None,
+ max_try_times: Optional[int] = None,
+ timeout: Optional[Union[int, float]] = None,
+ ) -> Event[Any]:
+ """Get events that meet the specified conditions. The coroutine will wait until the adapter receives events that meet the conditions, exceeds the maximum number of events, or times out.
+
+ Args:
+ func: Coroutine or function, the function will be automatically packaged as a coroutine for execution.
+ Requires an event to be accepted as a parameter and returns a Boolean value. Returns the current event when the coroutine returns ``True``.
+ When ``None`` is equivalent to the input coroutine returning true for any event, that is, returning the next event received by the adapter.
+ event_type: When specified, only events of the specified type are accepted, taking effect before the func condition. Defaults to ``None``.
+ adapter_type: When specified, only events generated by the specified adapter will be accepted, taking effect before the func condition. Defaults to ``None``.
+ max_try_times: Maximum number of events.
+ timeout: timeout period.
+
+ Returns:
+ Returns events that satisfy the condition of ``func``.
+
+ Raises:
+ GetEventTimeout: Maximum number of events exceeded or timeout.
+ """
+ _func = wrap_get_func(func)
+
+ try_times = 0
+ start_time = time.time()
+ while not self.should_exit.is_set():
+ if max_try_times is not None and try_times > max_try_times:
+ break
+ if timeout is not None and time.time() - start_time > timeout:
+ break
+
+ async with self._condition:
+ if timeout is None:
+ await self._condition.wait()
+ else:
+ try:
+ await asyncio.wait_for(
+ self._condition.wait(),
+ timeout=start_time + timeout - time.time(),
+ )
+ except asyncio.TimeoutError:
+ break
+
+ if (
+ self._current_event is not None
+ and not self._current_event.__handled__
+ and (
+ event_type is None
+ or isinstance(self._current_event, event_type)
+ )
+ and await _func(self._current_event)
+ ):
+ self._current_event.__handled__ = True
+ return self._current_event
+
+ try_times += 1
+
+ raise GetEventTimeout
+
+ def _load_rule_class(
+ self,
+ rule_class: Type[Rule[Any, Any, Any]],
+ rule_load_type: RuleLoadType,
+ rule_file_path: Optional[str],
+ ) -> None:
+ """Load a rule class"""
+ priority = getattr(rule_class, "priority", None)
+ if isinstance(priority, int) and priority >= 0:
+ for _rule in self.rules:
+ if _rule.__name__ == rule_class.__name__:
+ logger.warning(
+ f'Already have a same name rule pack "{
+ _rule.__name__}"'
+ )
+ rule_class.__rule_load_type__ = rule_load_type
+ rule_class.__rule_file_path__ = rule_file_path
+ self.rules_priority_dict[priority].append(rule_class)
+ logger.info(
+ f'Succeeded to load rule "{rule_class.__name__}" '
+ f'from class "{rule_class!r}"'
+ )
+ else:
+ self.error_or_exception(
+ f'Load rule from class "{rule_class!r}" failed:',
+ LoadModuleError(
+ f'Rule priority incorrect in the class "{
+ rule_class!r}"'
+ ),
+ )
+
+ def _load_rules_from_module_name(
+ self,
+ module_name: str,
+ *,
+ rule_load_type: RuleLoadType,
+ reload: bool = False,
+ ) -> None:
+ """Load rules from the given module."""
+ try:
+ rule_classes = get_classes_from_module_name(
+ module_name, Rule, reload=reload
+ )
+ except ImportError as e:
+ self.error_or_exception(
+ f'Import module "{module_name}" failed:', e)
+ else:
+ for rule_class, module in rule_classes:
+ self._load_rule_class(
+ rule_class, # type: ignore
+ rule_load_type,
+ module.__file__,
+ )
+
+ def _load_rules(
+ self,
+ *rules: Union[Type[Rule[Any, Any, Any]], str, Path],
+ rule_load_type: Optional[RuleLoadType] = None,
+ reload: bool = False,
+ ) -> None:
+ for rule_ in rules:
+ try:
+ if isinstance(rule_, type) and issubclass(rule_, Rule):
+ self._load_plugin_class(
+ rule_, rule_load_type or RuleLoadType.CLASS, None
+ )
+ elif isinstance(rule_, str):
+ logger.info(f'Loading rules from module "{rule_}"')
+ self._load_rules_from_module_name(
+ rule_,
+ rule_load_type=rule_load_type or RuleLoadType.NAME,
+ reload=reload,
+ )
+ elif isinstance(rule_, Path):
+ logger.info(f'Loading rules from path "{rule_}"')
+ if not rule_.is_file():
+ raise LoadModuleError( # noqa: TRY301
+ f'The rule path "{rule_}" must be a file'
+ )
+
+ if rule_.suffix != ".py":
+ raise LoadModuleError( # noqa: TRY301
+ f'The path "{rule_}" must endswith ".py"'
+ )
+
+ rule_module_name = None
+ for path in self._module_path_finder.path:
+ try:
+ if rule_.stem == "__init__":
+ if rule_.resolve().parent.parent.samefile(Path(path)):
+ rule_module_name = rule_.resolve().parent.name
+ break
+ elif rule_.resolve().parent.samefile(Path(path)):
+ rule_module_name = rule_.stem
+ break
+ except OSError:
+ continue
+ if rule_module_name is None:
+ rel_path = rule_.resolve().relative_to(Path().resolve())
+ if rel_path.stem == "__init__":
+ rule_module_name = ".".join(rel_path.parts[:-1])
+ else:
+ rule_module_name = ".".join(
+ rel_path.parts[:-1] + (rel_path.stem,)
+ )
+
+ self._load_rules_from_module_name(
+ rule_module_name,
+ rule_load_type=rule_load_type or RuleLoadType.FILE,
+ reload=reload,
+ )
+ else:
+ raise TypeError( # noqa: TRY301
+ f"{rule_} can not be loaded as rule"
+ )
+ except Exception as e:
+ self.error_or_exception(f'Load rule "{rule_}" failed:', e)
+
+ def load_rules(
+ self, *rules: Union[Type[Rule[Any, Any, Any]], str, Path]
+ ) -> None:
+ self._extend_plugins.extend(rules)
+
+ return self._load_plugins(*rules)
+
+ def _load_rules_from_dirs(self, *dirs: Path) -> None:
+ dir_list = [str(x.resolve()) for x in dirs]
+ logger.info(f'Loading rules from dirs "{
+ ", ".join(map(str, dir_list))}"')
+ self._module_path_finder.path.extend(dir_list)
+ for module_info in pkgutil.iter_modules(dir_list):
+ if not module_info.name.startswith("_"):
+ self._load_rules_from_module_name(
+ module_info.name, rule_load_type=RuleLoadType.DIR
+ )
+
+ def load_rules_from_dirs(self, *dirs: Path) -> None:
+ self._extend_rule_dirs.extend(dirs)
+ self._load_rules_from_dirs(*dirs)
+
+ def get_plugin(self, name: str) -> Type[Rule[Any, Any, Any]]:
+ for _rule in self.rules:
+ if _rule.__name__ == name:
+ return _rule
+ raise LookupError(f'Can not find rule named "{name}"')
+
+ def error_or_exception(
+ self, message: str, exception: Exception
+ ) -> None: # pragma: no cover
+ """Output error or exception logs based on the current Bot configuration.
+
+ Args:
+ message: message.
+ exception: Exception.
+ """
+ if self.config.bot.log.verbose_exception:
+ logger.exception(message)
+ else:
+ logger.error(f"{message} {exception!r}")
+
+ def core_run_hook(self, func: CoreHook) -> CoreHook:
+ self._core_run_hooks.append(func)
+ return func
+
+ def core_exit_hook(self, func: CoreHook) -> CoreHook:
+ self._core_exit_hooks.append(func)
+ return func
+
+ def event_pre_processor_hook(self, func: EventHook) -> EventHook:
+ self._event_preprocessor_hooks.append(func)
+ return func
+
+ def event_post_processor_hook(self, func: EventHook) -> EventHook:
+ self._event_post_processor_hooks.append(func)
+ return func