aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/hydroroll/plugins/HydroRoll_plugin_base/__init__.py
blob: a050ae1b494c93f5c92e7d72d176f110f0db81c7 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import re
from abc import ABC, abstractmethod
from typing import Type, Union, Generic, TypeVar

from iamai import Plugin
from iamai.typing import T_State
from iamai.adapter.cqhttp.event import GroupMessageEvent, PrivateMessageEvent

from .config import BasePluginConfig, RegexPluginConfig, CommandPluginConfig

# Type variables used to parameterise the plugin base classes below by their
# configuration type. Each one is bounded by the corresponding config class
# imported from ``.config``.
T_Config = TypeVar("T_Config", bound=BasePluginConfig)
T_RegexPluginConfig = TypeVar("T_RegexPluginConfig", bound=RegexPluginConfig)
T_CommandPluginConfig = TypeVar("T_CommandPluginConfig", bound=CommandPluginConfig)


class BasePlugin(
    Plugin[Union[PrivateMessageEvent, GroupMessageEvent], T_State, T_Config],
    ABC,
    Generic[T_State, T_Config],
):
    """Abstract base for plugins that react to cqhttp message events.

    ``rule`` filters incoming events by adapter, event type and the switches
    read from ``self.config``, then delegates the text test to ``str_match``,
    which concrete subclasses must implement.
    """

    Config: Type[T_Config] = BasePluginConfig

    def format_str(self, format_str: str, message_str: str = "") -> str:
        """Fill a template with the message text and the sender's identity.

        Args:
            format_str: Template passed to ``str.format``; it may use the
                ``{message}``, ``{user_name}`` and ``{user_id}`` fields.
            message_str: Value substituted for ``{message}``.

        Returns:
            The formatted string.
        """
        return format_str.format(
            message=message_str,
            user_name=self.event.sender.nickname,
            user_id=self.event.sender.user_id,
        )

    async def rule(self) -> bool:
        """Decide whether this plugin should handle the current event.

        Returns:
            ``True`` only when the event is a cqhttp message, passes the
            at-mention gate below, is of a kind enabled in ``self.config``
            and its text is accepted by ``str_match``.
        """
        # NOTE(review): hard-coded to True, so the at-mention gate below is
        # always active. Presumably this is meant to become a runtime switch;
        # confirm the intended source of this flag.
        is_bot_off = True

        if self.event.adapter.name != "cqhttp":
            return False
        if self.event.type != "message":
            return False
        match_str = self.event.message.get_plain_text()
        if is_bot_off:
            # Only messages that begin by at-mentioning the bot itself get
            # through; the mention code is stripped from the text to match.
            if self.event.message.startswith(f'[CQ:at,qq={self.event.self_id}]'):
                match_str = re.sub(fr'^\[CQ:at,qq={self.event.self_id}\]', '', match_str)
            else:
                return False

        if self.config.handle_all_message:
            return self.str_match(match_str)
        # The private and group switches are checked independently. Chaining
        # them with ``elif`` would make an enabled ``handle_friend_message``
        # silently disable group handling even when it is enabled as well.
        if (
            self.config.handle_friend_message
            and self.event.message_type == "private"
        ):
            return self.str_match(match_str)
        if (
            self.config.handle_group_message
            and self.event.message_type == "group"
            and (
                # ``accept_group`` of None means no group restriction.
                self.config.accept_group is None
                or self.event.group_id in self.config.accept_group
            )
        ):
            return self.str_match(match_str)
        return False

    @abstractmethod
    def str_match(self, msg_str: str) -> bool:
        """Return whether ``msg_str`` should trigger this plugin.

        Args:
            msg_str: Message text prepared by ``rule``.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        # ``NotImplemented`` is a comparison sentinel, not an exception;
        # raising it would produce a TypeError instead.
        raise NotImplementedError


class RegexPluginBase(BasePlugin[T_State, T_RegexPluginConfig], ABC):
    """Plugin base whose trigger is a full regular-expression match.

    ``re_pattern`` is only declared here; it is expected to be provided
    elsewhere (e.g. by a subclass) before ``str_match`` runs.
    """

    msg_match: re.Match
    re_pattern: re.Pattern
    Config: Type[T_RegexPluginConfig] = RegexPluginConfig

    def str_match(self, msg_str: str) -> bool:
        """Test the stripped message against ``re_pattern``.

        The result of ``fullmatch`` (a match object or ``None``) is stored
        on ``self.msg_match`` for later use.

        Args:
            msg_str: Message text to test.

        Returns:
            ``True`` if the whole stripped text matches, ``False`` otherwise.
        """
        candidate = msg_str.strip()
        found = self.re_pattern.fullmatch(candidate)
        self.msg_match = found
        return found is not None


class CommandPluginBase(RegexPluginBase[T_State, T_CommandPluginConfig], ABC):
    """Plugin base triggered by ``<prefix><command> <args>`` style messages.

    The command pattern is built from ``self.config.command_prefix``,
    ``self.config.command`` and ``self.config.ignore_case`` on first use and
    cached on the instance. The text after the command is then matched
    against ``re_pattern``.
    """

    command_match: re.Match
    command_re_pattern: re.Pattern
    Config: Type[T_CommandPluginConfig] = CommandPluginConfig

    def str_match(self, msg_str: str) -> bool:
        """Match a prefixed command and then its arguments.

        On a command match the match object is stored on
        ``self.command_match``; the result of matching the captured
        ``command_args`` group against ``re_pattern`` is stored on
        ``self.msg_match``.

        Args:
            msg_str: Message text to test; surrounding whitespace is ignored.

        Returns:
            ``True`` when both the command and its arguments match.
        """
        if not hasattr(self, "command_re_pattern"):
            # Escape every prefix so characters such as ']', '^', '-' or '\\'
            # are treated literally inside the character class instead of
            # breaking the pattern or changing its meaning.
            prefix_chars = "".join(
                re.escape(prefix) for prefix in self.config.command_prefix
            )
            # An empty character class ("[]") does not compile; with no
            # prefixes configured the command is matched without a prefix.
            prefix_part = f"[{prefix_chars}]" if prefix_chars else ""
            # Command entries are inserted verbatim as regex alternatives.
            command_part = f'({"|".join(self.config.command)})'
            self.command_re_pattern = re.compile(
                prefix_part + command_part + r"\s*(?P<command_args>.*)",
                flags=re.I if self.config.ignore_case else 0,
            )
        msg_str = msg_str.strip()
        self.command_match = self.command_re_pattern.fullmatch(msg_str)
        if not self.command_match:
            return False
        self.msg_match = self.re_pattern.fullmatch(
            self.command_match.group("command_args")
        )
        return bool(self.msg_match)