aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/tests/plugins
diff options
context:
space:
mode:
Diffstat (limited to 'tests/plugins')
-rw-r--r--tests/plugins/HydroRoll/__init__.py35
1 files changed, 21 insertions, 14 deletions
diff --git a/tests/plugins/HydroRoll/__init__.py b/tests/plugins/HydroRoll/__init__.py
index e1252f0..2f883a8 100644
--- a/tests/plugins/HydroRoll/__init__.py
+++ b/tests/plugins/HydroRoll/__init__.py
@@ -17,6 +17,7 @@ HYDRO_DIR = dirname(abspath(__file__))
class HydroRoll(Plugin):
"""中间件"""
+
class Config(ConfigModel):
__config_name__ = "HydroRoll"
@@ -28,7 +29,7 @@ class HydroRoll(Plugin):
def __post_init__(self):
self.state = {}
self.model_path_list = []
- self.bot.global_state['HydroRoll'] = {}
+ self.bot.global_state["HydroRoll"] = {}
self.model_dict = Models().get_models_dict()
self.model_path_list.append(join(BASE_DIR, "models"))
@@ -58,34 +59,41 @@ class HydroRoll(Plugin):
@BODY: lexer module will return a list of tokens, parser module will parse the tokens into a tree, and executor module will execute the tokens with a stack with a bool return value.
"""
logger.info("loading psi...")
- if not self.bot.global_state['HydroRoll'].get('hola') and self.event.type == "message" and self.event.message_type == "private" and not os.path.exists(join(BASE_DIR, "HydroRoll")):
- hola = self.models['hola']
- _, max_similarity = find_max_similarity(
- self.event.message.get_plain_text(), hola)
+ if (
+ not self.bot.global_state["HydroRoll"].get("hola")
+ and self.event.type == "message"
+ and self.event.message_type == "private"
+ and not os.path.exists(join(BASE_DIR, "HydroRoll"))
+ ):
+ # hola = self.models["hola"]
+ # _, max_similarity = find_max_similarity(
+ # self.event.message.get_plain_text(), hola
+ # )
+ max_similarity = 1
if max_similarity > 0.51:
self.init_directory()
- self.bot.global_state['HydroRoll']['hola'] = True
+ self.bot.global_state["HydroRoll"]["hola"] = True
await self.event.reply(f"验证成功√ 正在初始化水系目录...")
logger.info(GlobalConfig._copyright)
- return self.event.adapter.name in ['cqhttp', 'kook', 'console', 'mirai']
+ return self.event.adapter.name in ["cqhttp", "kook", "console", "mirai"]
- def _init_directory(self, _prefix: str = ''):
+ def _init_directory(self, _prefix: str = ""):
"""初始化水系目录"""
for _ in Directory(BASE_DIR).get_dice_dir_list(_prefix):
if not os.path.exists(_):
os.makedirs(_)
- def _init_file(self, _prefix: str = ''):
+ def _init_file(self, _prefix: str = ""):
"""初始化文件"""
- def init_directory(self, _prefix: str = 'HydroRoll'):
+ def init_directory(self, _prefix: str = "HydroRoll"):
"""在指定目录生成水系文件结构"""
self._init_directory(_prefix=_prefix)
def _load_model(self, path: str, model_file: str):
if model_file is None:
- model_file = ''
- return joblib.load(join(path, f'{model_file}'))
+ model_file = ""
+ return joblib.load(join(path, f"{model_file}"))
def _load_models(self, model_path_list: list, model_dict: dict) -> dict:
"""加载指定模型, 当然也可能是数据集"""
@@ -94,8 +102,7 @@ class HydroRoll(Plugin):
for model_name, model_file in model_dict.items():
if os.path.exists(join(path, model_file)):
models[model_name] = self._load_model(path, model_file)
- logger.success(
- f'Succeeded to load model "{model_name}"')
+ logger.success(f'Succeeded to load model "{model_name}"')
return models
def load_models(self):