From 0250c8373c8ea12c6d12cf399a15c57f8690b032 Mon Sep 17 00:00:00 2001 From: 苏向夜 Date: Fri, 26 Jan 2024 15:04:21 +0800 Subject: :recycle: refactor(output): add __init__ method in Output class in order to generate output mehtod in initialization --- src/infini/interceptor.py | 2 +- src/infini/output.py | 19 ++++++++++++++----- tests/test_handlers.py | 40 ++++++++++++++-------------------------- tests/test_interceptor.py | 7 +------ 4 files changed, 30 insertions(+), 38 deletions(-) diff --git a/src/infini/interceptor.py b/src/infini/interceptor.py index 54387281..c227fef7 100644 --- a/src/infini/interceptor.py +++ b/src/infini/interceptor.py @@ -1,6 +1,6 @@ from infini.input import Input from infini.output import Output -from infini.typing import List, RouterType, Callable, Generic, T, overload +from infini.typing import List, RouterType, Callable from infini.queue import EventQueue diff --git a/src/infini/output.py b/src/infini/output.py index ddb7f0a5..92913b1e 100644 --- a/src/infini/output.py +++ b/src/infini/output.py @@ -7,13 +7,22 @@ class Output: status: int block: bool + def __init__( + self, + type: Literal["null", "text", "workflow"], + name: str, + *, + status: int = 0, + block: bool = False, + ) -> None: + self.type = type + self.name = name + self.status = status + self.block = block + @classmethod def empty(cls) -> "Output": - output = cls() - output.type = "null" - output.status = 0 - output.block = True - return output + return cls("null", "null", status=0, block=True) def is_empty(self) -> bool: return self.type == "null" diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 8d2d2894..bf9dd9a0 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -8,21 +8,15 @@ def test_handler(): input = Input(".add 1 2") def add(input: Input) -> Output: - a, b = map(int, input.get_plain_text().lstrip(".add").split()) - output = Output() - output.block = False - output.status = 0 - output.type = "text" - output.name = str(a + b) - return output + return Output( + "text", + str(sum(list(map(int, input.get_plain_text().lstrip(".add").split())))), + status=0, + block=False, + ) def cmd(_: Input) -> Output: - output = Output() - output.block = False - output.status = 0 - output.type = "text" - output.name = "cmd" - return output + return Output("text", "cmd", status=0, block=False) handler = Handler() handler.handlers = [ @@ -48,21 +42,15 @@ def test_handler_block(): input = Input(".add 1 2") def add(input: Input) -> Output: - a, b = map(int, input.get_plain_text().lstrip(".add").split()) - output = Output() - output.block = False - output.status = 0 - output.type = "text" - output.name = str(a + b) - return output + return Output( + "text", + str(sum(list(map(int, input.get_plain_text().lstrip(".add").split())))), + status=0, + block=False, + ) def cmd(_: Input) -> Output: - output = Output() - output.block = True - output.status = 0 - output.type = "text" - output.name = "cmd" - return output + return Output("text", "cmd", status=0, block=True) handler = Handler() handler.handlers = [ diff --git a/tests/test_interceptor.py b/tests/test_interceptor.py index df5b6750..fb12beac 100644 --- a/tests/test_interceptor.py +++ b/tests/test_interceptor.py @@ -9,12 +9,7 @@ def test_interceptor(): valid_input = Input("这个叫苏向夜.") def intercept(_: Input) -> Input | Output: - output = Output() - output.block = True # TODO 拦截器阻塞标识 - output.name = "block.jianlvchun" - output.status = 0 - output.type = "text" - return output + return Output("text", "block.jianlvchun", block=True) # TODO 拦截器阻塞标识 interceptor = Interceptor() interceptor.interceptors = [ -- cgit v1.2.3-70-g09d2