aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
author苏向夜 <fu050409@163.com>2024-02-28 22:45:16 +0800
committer苏向夜 <fu050409@163.com>2024-02-28 22:45:16 +0800
commitb30be75b29d8320da812ecf7accd96af03424ad4 (patch)
treef2b0872e62c75fd3dcae7581f9c6451ac6164579
parent4a18d2fb508caf84f93712341fad4005cbfbc269 (diff)
downloadinfini-b30be75b29d8320da812ecf7accd96af03424ad4.tar.gz
infini-b30be75b29d8320da812ecf7accd96af03424ad4.zip
feat(injector): add injector feature
-rw-r--r--src/infini/core.py4
-rw-r--r--src/infini/generator.py8
-rw-r--r--src/infini/injector.py40
-rw-r--r--src/infini/loader.py12
-rw-r--r--src/infini/register.py12
5 files changed, 67 insertions, 9 deletions
diff --git a/src/infini/core.py b/src/infini/core.py
index 8fb46752..fe99f4b3 100644
--- a/src/infini/core.py
+++ b/src/infini/core.py
@@ -2,6 +2,7 @@ from infini.input import Input
from infini.interceptor import Interceptor
from infini.generator import TextGenerator
from infini.handler import Handler
+from infini.injector import Injector
from infini.output import Output
from infini.typing import Any, Generator, Union
from infini.exceptions import ValueError
@@ -12,6 +13,7 @@ class Core:
handler: Handler
generator: TextGenerator
interceptor: Interceptor
+ injector: Injector
def input(
self, input: Input
@@ -60,7 +62,7 @@ class Core:
yield output
def generate(self, output: Output) -> str:
- return self.generator.output(output)
+ return self.generator.output(output, self.injector)
def intercept(self, output_text: str) -> Generator[Union[str, Output], Any, None]:
return self.interceptor.output(output_text)
diff --git a/src/infini/generator.py b/src/infini/generator.py
index 53ca1a69..a8946949 100644
--- a/src/infini/generator.py
+++ b/src/infini/generator.py
@@ -1,6 +1,7 @@
from infini.output import Output
from infini.typing import Dict, Callable, Union
from infini.exceptions import UnknownEvent
+from infini.injector import Injector
from jinja2 import Template
@@ -12,10 +13,13 @@ class TextGenerator: # TODO 兼容多类型事件
self.events = {}
self.global_variables = {}
- def output(self, output: Output) -> str:
- assert output.type != "workflow", "Workflow 事件无法产出文本"
+ def output(self, output: Output, injector: Injector) -> str:
+ assert output.type == "text", "文本生成器应当传入类型为 'text' 的 Output 实例"
variables = self.global_variables.copy()
variables.update(output.variables)
+ for name, variable in variables.items():
+ if callable(variable):
+ variables[name] = injector.output(variable, output.variables)
return self.match(output).render(variables)
def match(self, output: Output) -> Template:
diff --git a/src/infini/injector.py b/src/infini/injector.py
new file mode 100644
index 00000000..b4c2bea4
--- /dev/null
+++ b/src/infini/injector.py
@@ -0,0 +1,40 @@
+from infini.typing import Callable, T, Optional, Dict, Any
+
+import inspect
+
+
+class Injector:
+ def __init__(self) -> None:
+ self.parameters: Dict[str, Any] = {}
+
+ def inject(
+ self, func: Callable[..., T], parameters: Optional[Dict[str, Any]] = None
+ ) -> Callable[[], T]:
+ signature = inspect.signature(func)
+ _parameters = {} if parameters is None else parameters
+ parameters = self.parameters.copy()
+ parameters.update(_parameters)
+ inject_params = {}
+ for param_name, param in signature.parameters.items():
+ default = None if param.default == inspect._empty else param.default
+ if param_name in parameters:
+ if not isinstance(parameters[param_name], param.annotation):
+ raise ValueError(
+ f"Parameter with name '{param_name}' has a mismatch type."
+ )
+ inject_params[param_name] = parameters[param_name]
+ else:
+ for parameter in parameters:
+ if isinstance(parameter, param.annotation):
+ inject_params[param_name] = parameter
+ break
+ else:
+ inject_params[param_name] = default
+ bound_args = signature.bind(**inject_params)
+ bound_args.apply_defaults()
+ return lambda: func(*bound_args.args, **bound_args.kwargs)
+
+ def output(
+ self, func: Callable[..., T], parameters: Optional[Dict[str, Any]] = None
+ ) -> T:
+ return self.inject(func, parameters)()
diff --git a/src/infini/loader.py b/src/infini/loader.py
index 4dbc392e..061ed8b9 100644
--- a/src/infini/loader.py
+++ b/src/infini/loader.py
@@ -2,6 +2,7 @@ from importlib.util import spec_from_file_location
from infini.core import Core
from infini.generator import TextGenerator
from infini.handler import Handler
+from infini.injector import Injector
from infini.interceptor import Interceptor
from infini.register import Register
from infini.typing import (
@@ -174,15 +175,18 @@ class Loader:
handler = Handler()
generator = TextGenerator()
interceptor = Interceptor()
+ injector = Injector()
self.inject_pre_interceptor(pre_interceptor)
self.inject_handler(handler)
self.inject_generator(generator)
self.inject_interceptor(interceptor)
+ self.inject_injector(injector)
core.pre_interceptor = pre_interceptor
core.handler = handler
core.generator = generator
core.interceptor = interceptor
+ core.injector = injector
def into_core(self) -> Core:
core = Core()
@@ -233,3 +237,11 @@ class Loader:
pre_interceptor = Interceptor()
self.inject_pre_interceptor(pre_interceptor)
return pre_interceptor
+
+ def inject_injector(self, injector: Injector):
+ injector.parameters = self.global_variables
+
+ def into_injector(self) -> Injector:
+ injector = Injector()
+ self.inject_injector(injector)
+ return injector
diff --git a/src/infini/register.py b/src/infini/register.py
index a2addc2b..0586ae7e 100644
--- a/src/infini/register.py
+++ b/src/infini/register.py
@@ -22,8 +22,8 @@ class Register:
def pre_interceptor(self, router: Union[Router, str], priority: int = 0):
def decorator(func):
@wraps(func)
- def wrapper(input: Input) -> Union[Input, Output]:
- return func(input)
+ def wrapper(*args, **kwargs) -> Union[Input, Output]:
+ return func(*args, **kwargs)
self.pre_interceptors.append(
{
@@ -39,8 +39,8 @@ class Register:
def handler(self, router: Union[Router, str], priority: int = 0):
def decorator(func):
@wraps(func)
- def wrapper(input: Input) -> Output:
- return func(input)
+ def wrapper(*args, **kwargs) -> Output:
+ return func(*args, **kwargs)
self.handlers.append(
{
@@ -73,8 +73,8 @@ class Register:
def interceptor(self, router: Union[Router, str], priority: int = 0):
def decorator(func):
@wraps(func)
- def wrapper(input: Input) -> Union[Input, Output]:
- return func(input)
+ def wrapper(*args, **kwargs) -> Union[Input, Output]:
+ return func(*args, **kwargs)
self.interceptors.append(
{