summaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
author苏向夜 <fu050409@163.com>2024-01-27 19:07:06 +0800
committer苏向夜 <fu050409@163.com>2024-01-27 19:07:06 +0800
commita02512aa689631426f3df640cf56a9e772f0e616 (patch)
tree6ec881fef485ffcf451598b646ee07d57eecbcc9
parent7f997f04fb2acc0a7f535f4522ccec22905de859 (diff)
downloadinfini-a02512aa689631426f3df640cf56a9e772f0e616.tar.gz
infini-a02512aa689631426f3df640cf56a9e772f0e616.zip
:sparkles: feat(core|interceptor): supports output for interceptors
-rw-r--r--src/infini/core.py40
-rw-r--r--src/infini/interceptor.py60
2 files changed, 72 insertions, 28 deletions
diff --git a/src/infini/core.py b/src/infini/core.py
index cbe90924..40ed6f4c 100644
--- a/src/infini/core.py
+++ b/src/infini/core.py
@@ -1,26 +1,40 @@
from infini.input import Input
from infini.interceptor import Interceptor
-from infini.generator import Generator
+from infini.generator import TextGenerator
from infini.handler import Handler
from infini.output import Output
+from infini.typing import Any, Generator
class Core:
pre_interceptor: Interceptor
handler: Handler
- generator: Generator
+ generator: TextGenerator
interceptor: Interceptor
- def input(self, input: Input):
- if isinstance(pre_intercepted_stream := self.pre_intercept(input), Output):
- yield self.generate(pre_intercepted_stream)
- return
- for handled_stream in self.handle(pre_intercepted_stream):
+ def input(self, input: Input) -> Generator[str, Any, None]:
+ for pre_intercepted_stream in self.pre_intercept(input):
+ if isinstance(pre_intercepted_stream, Output):
+ yield self.generate(pre_intercepted_stream)
+ if pre_intercepted_stream.block or pre_intercepted_stream.is_empty():
+ return
+ else:
+ input = pre_intercepted_stream
+
+ for handled_stream in self.handle(input):
if handled_stream.is_empty():
return
- yield self.intercept(self.generate(handled_stream))
+ outcome = self.generate(handled_stream)
+ for stream in self.intercept(outcome):
+ if isinstance(stream, Output):
+ yield self.generate(stream)
+ continue
+ outcome = stream
+ if handled_stream.block:
+ return
+ yield outcome
- def pre_intercept(self, input: Input) -> Input | Output:
+ def pre_intercept(self, input: Input) -> Generator[Output | Input, Any, None]:
return self.pre_interceptor.input(input)
def handle(self, input: Input):
@@ -31,9 +45,5 @@ class Core:
def generate(self, output: Output) -> str:
return self.generator.output(output)
- def intercept(self, output: str) -> str:
- return (
- self.generate(callback)
- if isinstance(callback := self.interceptor.output(output), Output)
- else callback
- )
+ def intercept(self, output_text: str) -> Generator[Output | str, Any, None]:
+ return self.interceptor.output(output_text)
diff --git a/src/infini/interceptor.py b/src/infini/interceptor.py
index 28a3f869..c8866ec0 100644
--- a/src/infini/interceptor.py
+++ b/src/infini/interceptor.py
@@ -1,32 +1,66 @@
from infini.input import Input
from infini.output import Output
-from infini.typing import List, RouterType, Callable
+from infini.typing import List, Any, RouterType, Callable, Generator
from infini.queue import EventQueue
class Interceptor:
interceptors: List[RouterType]
- def input(self, input: Input) -> Input | Output:
+ def input(self, input: Input) -> Generator[Output | Input, Any, None]:
queue = self.match(input.get_plain_text())
while not queue.is_empty():
- if isinstance(intercepted := queue.pop()(input), Output):
- return intercepted # TODO 允许拦截器产出文本
+ if isinstance(stream := queue.pop()(input), Generator):
+ for outcome in stream:
+ if isinstance(outcome, Input):
+ input = outcome
+ break
+ yield outcome
+ if outcome.block:
+ return
else:
- input = intercepted
- return input
+ if stream is None:
+ yield Output.empty()
+ return
+ if isinstance(stream, Output):
+ yield stream
+ if stream.block:
+ return
+ continue
+ input = stream
+ yield input
- def output(self, output_text: str) -> str | Output:
- queue = self.match(output_text) # TODO 需要测试输出拦截情况
+ def output(
+ self, output_text: str
+ ) -> Generator[Output | str, Any, None]:
input = Input(output_text)
+ queue = self.match(input.get_plain_text())
while not queue.is_empty():
- if isinstance(intercepted := queue.pop()(input), Output):
- return intercepted
+ if isinstance(stream := queue.pop()(input), Generator):
+ for outcome in stream:
+ if isinstance(outcome, Input):
+ input = outcome
+ break
+ yield outcome
+ if outcome.block:
+ return
else:
- input = intercepted
- return output_text
+ if stream is None:
+ yield Output.empty()
+ return
+ if isinstance(stream, Output):
+ yield stream
+ if stream.block:
+ return
+ continue
+ input = stream
+ yield input.get_plain_text()
- def match(self, text: str) -> EventQueue[Callable[[Input], Input | Output]]:
+ def match(
+ self, text: str
+ ) -> EventQueue[
+ Callable[[Input], Input | Output | Generator[Input | Output, Any, None]]
+ ]:
queue = EventQueue()
for interceptor in self.interceptors: