aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/hrc
diff options
context:
space:
mode:
author简律纯 <i@jyunko.cn>2024-07-06 09:07:19 +0800
committer简律纯 <i@jyunko.cn>2024-07-06 09:07:19 +0800
commit0ed10486f719c23ab7e0e84d2e119a7fa5f70475 (patch)
tree095b77f6220fdf929de0a1de29332b1912aa30e1 /hrc
parentc0518c138914b321d0fa2d7b0d1377f78ff85b3c (diff)
downloadHydroRollCore-0ed10486f719c23ab7e0e84d2e119a7fa5f70475.tar.gz
HydroRollCore-0ed10486f719c23ab7e0e84d2e119a7fa5f70475.zip
refactor!: rewrite core business logic
Diffstat (limited to 'hrc')
-rw-r--r--hrc/core.py322
-rw-r--r--hrc/typing.py1
2 files changed, 154 insertions, 169 deletions
diff --git a/hrc/core.py b/hrc/core.py
index 729f069..20d9f74 100644
--- a/hrc/core.py
+++ b/hrc/core.py
@@ -30,7 +30,7 @@ from .dependencies import solve_dependencies
from .log import logger
from .rule import Rule, RuleLoadType
from .event import Event
-from .typing import CoreHook, EventHook, EventT
+from .typing import CoreHook, EventHook, EventT, RuleHook
from .utils import (
ModulePathFinder,
get_classes_from_module_name,
@@ -53,31 +53,22 @@ HANDLED_SIGNALS = (
class Core:
- should_exit: asyncio.Event
- rules_priority_dict: Dict[int, List[Type[Rule[Any, Any, Any]]]]
-
- _condition: asyncio.Condition
+ config: MainConfig
_current_event: Optional[Event[Any]]
- _restart_flag: bool
_module_path_finder: ModulePathFinder
- _raw_config_dict: Dict[str, Any]
- _handle_event_tasks: Set[
- "asyncio.Task[None]"
- ] # Event handling task, used to keep a reference to the adapter task
- # The following properties are not cleared on reboot
+ _hot_reload: bool
+ # pyright: ignore[reportUninitializedInstanceVariable]
+ should_exit: asyncio.Event
+ _restart_flag: bool # Restart flag
+ _extend_rules: List[Union[Type[Rule[Any, Any, Any]], str, Path]]
+ _extend_rule_dirs: List[Path]
+ rules_priority_dict: Dict[int, List[Type[Rule[Any, Any, Any]]]]
_config_file: Optional[str] # Configuration file
_config_dict: Optional[Dict[str, Any]] # Configuration dictionary
- _hot_reload: bool # Hot-Reload
- _extend_rules: List[
- Union[Type[Rule[Any, Any, Any]], str, Path]
- ] # A list of rules loaded programmatically using the ``load_rules()`` method
- _extend_rule_dirs: List[
- Path
- ] # List of rule paths loaded programmatically using the ``load_rules_from_dirs()`` method
- _core_run_hooks: List[CoreHook]
- _core_exit_hooks: List[CoreHook]
- _event_pre_processor_hooks: List[EventHook]
- _event_post_processor_hooks: List[EventHook]
+
+ _condition: (
+ asyncio.Condition
+ ) # Condition used to handle get # pyright: ignore[reportUninitializedInstanceVariable]
def __init__(
self,
@@ -87,29 +78,26 @@ class Core:
hot_reload: bool = False,
) -> None:
self.config = MainConfig()
- self.rules_priority_dict = defaultdict(list)
self._current_event = None
- self._restart_flag = False
- self._module_path_finder = ModulePathFinder()
- self._raw_config_dict = {}
- self._handle_event_tasks = set()
-
self._config_file = config_file
self._config_dict = config_dict
self._hot_reload = hot_reload
-
- self._extend_rules = []
- self._extend_rule_dirs = []
+ self._restart_flag = False
+ self._module_path_finder = ModulePathFinder()
+ self.rules_priority_dict = defaultdict(list)
+
self._core_run_hooks = []
self._core_exit_hooks = []
- self._event_pre_processor_hooks = []
- self._event_post_processor_hooks = []
+ self._rule_enable_hooks = []
+ self._rule_run_hooks = []
+ self._rule_disable_hooks = []
+ self._event_preprocessor_hooks = []
+ self._event_postprocessor_hooks = []
sys.meta_path.insert(0, self._module_path_finder)
@property
def rules(self) -> List[Type[Rule[Any, Any, Any]]]:
- """List of currently loaded rules."""
return list(chain(*self.rules_priority_dict.values()))
def run(self) -> None:
@@ -118,8 +106,8 @@ class Core:
self._restart_flag = False
asyncio.run(self._run())
if self._restart_flag:
- self._load_plugins_from_dirs(*self._extend_rule_dirs)
- self._load_plugins(*self._extend_rules)
+ self._load_rules_from_dirs(*self._extend_rule_dirs)
+ self._load_rules(*self._extend_rules)
def restart(self) -> None:
logger.info("Restarting...")
@@ -145,27 +133,61 @@ class Core:
# Load configuration file
self._reload_config_dict()
- self._load_rules_from_dirs(*self.config.bot.rule_dirs)
- self._load_rules(*self.config.bot.rules)
+ self._load_rules_from_dirs(*self.config.core.rule_dirs)
+ self._load_rules(*self.config.core.rules)
self._update_config()
logger.info("Running...")
hot_reload_task = None
if self._hot_reload: # pragma: no cover
- hot_reload_task = asyncio.create_task(self._run_hot_reload()) # noqa: F841
+ hot_reload_task = asyncio.create_task(self._run_core_reload())
for core_run_hook_func in self._core_run_hooks:
await core_run_hook_func(self)
- self.rules_priority_dict.clear()
- self._module_path_finder.path.clear()
+ try:
+ for _rule in self.rules:
+ for rule_enable_hook_func in self._rule_enable_hooks:
+ await rule_enable_hook_func(_rule)
+ try:
+ await _rule.enable()
+ except Exception as e:
+ self.error_or_exception(
+ f"Enable rule {_rule!r} failed:", e)
+
+ for _rule in self.rules:
+ for rule_run_hook_func in self._rule_run_hooks:
+ await rule_run_hook_func(_rule)
+ _rule_task = asyncio.create_task(_rule.safe_run())
+ self._rule_tasks.add(_rule_task)
+ _rule_task.add_done_callback(self._rule_tasks.discard)
+
+ await self.should_exit.wait()
+
+ if hot_reload_task is not None: # pragma: no cover
+ await hot_reload_task
+ finally:
+ for _rule in self.rules:
+ for rule_shutdown_hook_func in self._rule_shutdown_hooks:
+ await rule_shutdown_hook_func(_rule)
+ await _rule.disable()
+
+ while self._rule_tasks:
+ await asyncio.sleep(0)
+
+ for core_exit_hook_func in self._core_exit_hooks:
+ await core_exit_hook_func(self)
+
+ self.rules.clear()
+ self.rules_priority_dict.clear()
+ self._module_path_finder.path.clear()
def _remove_rule_by_path(
self, file: Path
) -> List[Type[Rule[Any, Any, Any]]]: # pragma: no cover
removed_rules: List[Type[Rule[Any, Any, Any]]] = []
- for rules in self.plugins_priority_dict.values():
+ for rules in self.rules_priority_dict.values():
_removed_rules = list(
filter(
lambda x: x.__rule_load_type__ != RuleLoadType.CLASS
@@ -220,8 +242,8 @@ class Core:
old_config = self.config
self._reload_config_dict()
if (
- self.config.bot != old_config.bot
- or self.config.adapter != old_config.adapter
+ self.config.core != old_config.core
+ or self.config.rule != old_config.rule
):
self.restart()
continue
@@ -241,7 +263,7 @@ class Core:
if change_type == Change.added:
logger.info(f"Hot reload: Added file: {file}")
- self._load_plugins(
+ self._load_rules(
Path(file), rule_load_type=RuleLoadType.DIR, reload=True
)
self._update_config()
@@ -253,7 +275,7 @@ class Core:
elif change_type == Change.modified:
logger.info(f"Hot reload: Modified file: {file}")
self._remove_rule_by_path(file)
- self._load_plugins(
+ self._load_rules(
Path(file), rule_load_type=RuleLoadType.DIR, reload=True
)
self._update_config()
@@ -277,7 +299,8 @@ class Core:
config_class,
default_value,
)
- config_model = create_model(name, **config_update_dict, __base__=base)
+ config_model = create_model(
+ name, **config_update_dict, __base__=base)
return config_model, config_model()
self.config = create_model(
@@ -287,7 +310,7 @@ class Core:
)(**self._raw_config_dict)
# Update the level of logging
logger.remove()
- logger.add(sys.stderr, level=self.config.bot.log.level)
+ logger.add(sys.stderr, level=self.config.core.log.level)
def _reload_config_dict(self) -> None:
"""Reload the configuration file."""
@@ -320,6 +343,7 @@ class Core:
self._update_config()
def reload_rules(self) -> None:
+ """Manually reload all rules."""
self.rules_priority_dict.clear()
self._load_rules(*self.config.core.rules)
self._load_rules_from_dirs(*self.config.core.rule_dirs)
@@ -344,20 +368,24 @@ class Core:
show_log: bool = True,
) -> None:
if show_log:
- logger.info(f"Rule {current_event.rule.name} received: {current_event!r}")
+ logger.info(
+ f"Rule {current_event.rule.name} received: {current_event!r}")
if handle_get:
_handle_event_task = asyncio.create_task(self._handle_event())
self._handle_event_tasks.add(_handle_event_task)
- _handle_event_task.add_done_callback(self._handle_event_tasks.discard)
+ _handle_event_task.add_done_callback(
+ self._handle_event_tasks.discard)
await asyncio.sleep(0)
async with self._condition:
self._current_event = current_event
self._condition.notify_all()
else:
- _handle_event_task = asyncio.create_task(self._handle_event(current_event))
+ _handle_event_task = asyncio.create_task(
+ self._handle_event(current_event))
self._handle_event_tasks.add(_handle_event_task)
- _handle_event_task.add_done_callback(self._handle_event_tasks.discard)
+ _handle_event_task.add_done_callback(
+ self._handle_event_tasks.discard)
async def _handle_event(self, current_event: Optional[Event[Any]] = None) -> None:
if current_event is None:
@@ -368,11 +396,13 @@ class Core:
if current_event.__handled__:
return
- for _hook_func in self._event_pre_processor_hooks:
+ for _hook_func in self._event_preprocessor_hooks:
await _hook_func(current_event)
for rule_priority in sorted(self.rules_priority_dict.keys()):
- logger.debug(f"Checking for matching rules with priority {rule_priority!r}")
+ logger.debug(
+ f"Checking for matching rules with priority {rule_priority!r}"
+ )
stop = False
for rule in self.rules_priority_dict[rule_priority]:
try:
@@ -386,6 +416,10 @@ class Core:
Event: current_event,
},
)
+ if _rule.name not in self.rule_state:
+ rule_state = _rule.__init_state__()
+ if rule_state is not None:
+ self.rule_state[_rule.name] = rule_state
if await _rule.rule():
logger.info(f"Event will be handled by {_rule!r}")
try:
@@ -397,111 +431,18 @@ class Core:
# The plug-in requires that it skips itself and continues the current event propagation
continue
except StopException:
- # Plugin requires stopping current event propagation
+ # rule requires stopping current event propagation
stop = True
except Exception as e:
self.error_or_exception(f'Exception in rule "{rule}":', e)
if stop:
break
- for _hook_func in self._event_post_processor_hooks:
+ for _hook_func in self._event_postprocessor_hooks:
await _hook_func(current_event)
logger.info("Event Finished")
- @overload
- async def get(
- self,
- func: Optional[Callable[[Event[Any]], Union[bool, Awaitable[bool]]]] = None,
- *,
- event_type: None = None,
- max_try_times: Optional[int] = None,
- timeout: Optional[Union[int, float]] = None,
- ) -> Event[Any]: ...
-
- @overload
- async def get(
- self,
- func: Optional[Callable[[EventT], Union[bool, Awaitable[bool]]]] = None,
- *,
- event_type: None = None,
- max_try_times: Optional[int] = None,
- timeout: Optional[Union[int, float]] = None,
- ) -> EventT: ...
-
- @overload
- async def get(
- self,
- func: Optional[Callable[[EventT], Union[bool, Awaitable[bool]]]] = None,
- *,
- event_type: Type[EventT],
- max_try_times: Optional[int] = None,
- timeout: Optional[Union[int, float]] = None,
- ) -> EventT: ...
-
- async def get(
- self,
- func: Optional[Callable[[Any], Union[bool, Awaitable[bool]]]] = None,
- *,
- event_type: Optional[Type[Event[Any]]] = None,
- max_try_times: Optional[int] = None,
- timeout: Optional[Union[int, float]] = None,
- ) -> Event[Any]:
- """Get events that meet the specified conditions. The coroutine will wait until the adapter receives events that meet the conditions, exceeds the maximum number of events, or times out.
-
- Args:
- func: Coroutine or function, the function will be automatically packaged as a coroutine for execution.
- Requires an event to be accepted as a parameter and returns a Boolean value. Returns the current event when the coroutine returns ``True``.
- When ``None`` is equivalent to the input coroutine returning true for any event, that is, returning the next event received by the adapter.
- event_type: When specified, only events of the specified type are accepted, taking effect before the func condition. Defaults to ``None``.
- adapter_type: When specified, only events generated by the specified adapter will be accepted, taking effect before the func condition. Defaults to ``None``.
- max_try_times: Maximum number of events.
- timeout: timeout period.
-
- Returns:
- Returns events that satisfy the condition of ``func``.
-
- Raises:
- GetEventTimeout: Maximum number of events exceeded or timeout.
- """
- _func = wrap_get_func(func)
-
- try_times = 0
- start_time = time.time()
- while not self.should_exit.is_set():
- if max_try_times is not None and try_times > max_try_times:
- break
- if timeout is not None and time.time() - start_time > timeout:
- break
-
- async with self._condition:
- if timeout is None:
- await self._condition.wait()
- else:
- try:
- await asyncio.wait_for(
- self._condition.wait(),
- timeout=start_time + timeout - time.time(),
- )
- except asyncio.TimeoutError:
- break
-
- if (
- self._current_event is not None
- and not self._current_event.__handled__
- and (
- event_type is None
- or isinstance(self._current_event, event_type)
- )
- and await _func(self._current_event)
- ):
- self._current_event.__handled__ = True
- return self._current_event
-
- try_times += 1
-
- raise GetEventTimeout
-
def _load_rule_class(
self,
rule_class: Type[Rule[Any, Any, Any]],
@@ -514,7 +455,7 @@ class Core:
for _rule in self.rules:
if _rule.__name__ == rule_class.__name__:
logger.warning(
- f'Already have a same name rule pack "{_rule.__name__}"'
+ f'Already have a same name rule "{_rule.__name__}"'
)
rule_class.__rule_load_type__ = rule_load_type
rule_class.__rule_file_path__ = rule_file_path
@@ -527,8 +468,7 @@ class Core:
self.error_or_exception(
f'Load rule from class "{rule_class!r}" failed:',
LoadModuleError(
- f'Rule priority incorrect in the class "{
- rule_class!r}"'
+ f'rule priority incorrect in the class "{rule_class!r}"'
),
)
@@ -545,7 +485,8 @@ class Core:
module_name, Rule, reload=reload
)
except ImportError as e:
- self.error_or_exception(f'Import module "{module_name}" failed:', e)
+ self.error_or_exception(
+ f'Import module "{module_name}" failed:', e)
else:
for rule_class, module in rule_classes:
self._load_rule_class(
@@ -560,10 +501,22 @@ class Core:
rule_load_type: Optional[RuleLoadType] = None,
reload: bool = False,
) -> None:
+ """Load rules.
+
+ Args:
+ *rules: plug-in class, plug-in module name or plug-in module file path. Type can be ``Type[rule]``, ``str`` or ``pathlib.Path``.
+ If it is ``Type[rule]``, it will be loaded as a plug-in class.
+ If it is of type ``str``, it will be loaded as the plug-in module name, and the format is the same as the Python ``import`` statement.
+ For example: ``path.of.rule``.
+ If it is of type ``pathlib.Path``, it will be loaded as the plug-in module file path.
+ For example: ``pathlib.Path("path/of/rule")``.
+ rule_load_type: Plug-in loading type, if it is ``None``, it will be automatically determined, otherwise the specified type will be used.
+ reload: Whether to reload the module.
+ """
for rule_ in rules:
try:
if isinstance(rule_, type) and issubclass(rule_, Rule):
- self._load_plugin_class(
+ self._load_rule_class(
rule_, rule_load_type or RuleLoadType.CLASS, None
)
elif isinstance(rule_, str):
@@ -618,15 +571,34 @@ class Core:
except Exception as e:
self.error_or_exception(f'Load rule "{rule_}" failed:', e)
- def load_rules(self, *rules: Union[Type[Rule[Any, Any, Any]], str, Path]) -> None:
- self._extend_plugins.extend(rules)
+ def load_rules(
+ self, *rules: Union[Type[Rule[Any, Any, Any]], str, Path]
+ ) -> None:
+ """Load the rule.
+
+ Args:
+ *rules: ``rule`` class, rule module name or plug-in module file path.
+ Type can be ``Type[rule]``, ``str`` or ``pathlib.Path``.
+ If it is ``Type[rule]``, it will be loaded as a plug-in class.
+ If it is of type ``str``, it will be loaded as the plug-in module name, and the format is the same as the Python ``import`` statement.
+ For example: ``path.of.rule``.
+ If it is of type ``pathlib.Path``, it will be loaded as the plug-in module file path.
+ For example: ``pathlib.Path("path/of/rule")``.
+ """
+ self._extend_rules.extend(rules)
- return self._load_plugins(*rules)
+ return self._load_rules(*rules)
def _load_rules_from_dirs(self, *dirs: Path) -> None:
+ """Load plug-ins from the directory. Plug-ins in modules starting with ``_`` will not be imported. The path can be a relative path or an absolute path.
+
+ Args:
+ *dirs: Module paths that store modules containing rules.
+ For example: ``pathlib.Path("path/of/rules/")`` .
+ """
dir_list = [str(x.resolve()) for x in dirs]
- logger.info(f'Loading rules from dirs "{
- ", ".join(map(str, dir_list))}"')
+ logger.info(
+ f'Loading rules from dirs "{", ".join(map(str, dir_list))}"')
self._module_path_finder.path.extend(dir_list)
for module_info in pkgutil.iter_modules(dir_list):
if not module_info.name.startswith("_"):
@@ -635,6 +607,12 @@ class Core:
)
def load_rules_from_dirs(self, *dirs: Path) -> None:
+ """Load plug-ins from the directory. Plug-ins in modules starting with ``_`` will not be imported. The path can be a relative path or an absolute path.
+
+ Args:
+ *dirs: Module paths that store modules containing rules.
+ For example: ``pathlib.Path("path/of/rules/")`` .
+ """
self._extend_rule_dirs.extend(dirs)
self._load_rules_from_dirs(*dirs)
@@ -647,13 +625,7 @@ class Core:
def error_or_exception(
self, message: str, exception: Exception
) -> None: # pragma: no cover
- """Output error or exception logs based on the current Bot configuration.
-
- Args:
- message: message.
- exception: Exception.
- """
- if self.config.bot.log.verbose_exception:
+ if self.config.core.log.verbose_exception:
logger.exception(message)
else:
logger.error(f"{message} {exception!r}")
@@ -666,10 +638,22 @@ class Core:
self._core_exit_hooks.append(func)
return func
- def event_pre_processor_hook(self, func: EventHook) -> EventHook:
+ def rule_enable_hook(self, func: RuleHook) -> RuleHook:
+ self._rule_enable_hooks.append(func)
+ return func
+
+ def rule_run_hook(self, func: RuleHook) -> RuleHook:
+ self._rule_run_hooks.append(func)
+ return func
+
+ def rule_disable_hook(self, func: RuleHook) -> RuleHook:
+ self._rule_disable_hooks.append(func)
+ return func
+
+ def event_preprocessor_hook(self, func: EventHook) -> EventHook:
self._event_preprocessor_hooks.append(func)
return func
- def event_post_processor_hook(self, func: EventHook) -> EventHook:
- self._event_post_processor_hooks.append(func)
+ def event_postprocessor_hook(self, func: EventHook) -> EventHook:
+ self._event_postprocessor_hooks.append(func)
return func
diff --git a/hrc/typing.py b/hrc/typing.py
index a873194..d74fd26 100644
--- a/hrc/typing.py
+++ b/hrc/typing.py
@@ -16,4 +16,5 @@ RuleT = TypeVar("RuleT", bound="Rule[Any, Any, Any]")
ConfigT = TypeVar("ConfigT", bound=Optional["ConfigModel"])
CoreHook = Callable[["Core"], Awaitable[None]]
+RuleHook = Callable[["Rule"], Awaitable[None]]
EventHook = Callable[["Event[Any]"], Awaitable[None]]