aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/src/infini/generator.py
blob: e2730ca3cb2ae447d55bc6dedc32eec677c68896 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from infini.output import Output
from infini.typing import Dict, Callable, Union, Optional
from infini.exceptions import UnknownEvent, UnknownEventType
from infini.injector import Injector
from jinja2 import Template

import abc


class BaseGenerator(abc.ABC):
    """Abstract interface shared by every output generator.

    Concrete subclasses declare which ``Output.type`` they handle via
    ``type``, keep their event templates in ``events`` and expose shared
    template variables through ``global_variables``.
    """

    # The Output.type value this generator is responsible for.
    type: str
    # Event name -> template source.
    events: Dict[str, str]
    # Variables available to every event; callables are resolved lazily.
    global_variables: Dict[str, Union[str, Callable]]

    @abc.abstractmethod
    def output(self, output: Output, injector: Injector) -> str:
        """Produce the final string for ``output``, using ``injector``."""
        raise NotImplementedError

    @abc.abstractmethod
    def match(self, output: Output) -> Template:
        """Look up the template that should render ``output``."""
        raise NotImplementedError


class TextGenerator(BaseGenerator):
    """Generator that renders ``"text"`` outputs from Jinja2 templates.

    ``events`` maps an event name to Jinja2 template source, and
    ``global_variables`` holds variables made available to every template.
    Both start empty here; a dispatching ``Generator`` replaces them with its
    own dicts before each call to ``output``.
    """

    type = "text"

    def __init__(self) -> None:
        self.events = {}
        self.global_variables = {}

    def output(self, output: Output, injector: Injector) -> str:
        """Render the text for ``output``.

        Global variables are merged with ``output.variables`` (the latter win
        on key collisions). Every callable value is then resolved through
        ``injector.output`` before the event's template is rendered.

        Raises:
            AssertionError: if ``output.type`` is not ``"text"``.
            UnknownEvent: if no template is registered for ``output.name``.
        """
        assert (
            output.type == self.type
        ), f"文本生成器应当传入类型为 '{self.type}' 的 Output 实例"
        variables = self.global_variables.copy()
        variables.update(output.variables)
        for name, variable in variables.items():
            if callable(variable):
                # Only values are replaced (never keys), so mutating during
                # iteration is safe. The injector receives the partially
                # resolved dict: entries earlier in iteration order have
                # already been replaced by their resolved values.
                variables[name] = injector.output(variable, variables)
        return self.match(output).render(variables)

    def match(self, output: Output) -> Template:
        """Return the compiled template registered for ``output.name``.

        Raises:
            UnknownEvent: if no template is registered for the event.
        """
        context = self.events.get(output.name)
        # Test against None rather than truthiness: an event registered with
        # an empty template must render "" instead of being reported missing.
        if context is None:
            raise UnknownEvent(f"事件不存在: {output.name}")
        return Template(context)


class Generator:
    """Dispatch an ``Output`` to the generator registered for its ``type``.

    ``events`` and ``global_variables`` are handed by reference to whichever
    generator handles an output. The same dict objects are shared, not
    copied.
    """

    generators: Dict[str, BaseGenerator]
    events: Dict[str, str]
    global_variables: Dict[str, Union[str, Callable]]

    def __init__(self) -> None:
        self.generators = {"text": TextGenerator()}
        # Give these real defaults so ``output`` does not raise
        # AttributeError when a caller never assigns them explicitly.
        self.events = {}
        self.global_variables = {}

    def output(self, output: Output, injector: Injector) -> str:
        """Render ``output`` with the generator registered for its type.

        Raises:
            AssertionError: if ``output.type`` is ``"workflow"``.
            UnknownEventType: if no generator is registered for the type.
        """
        assert (
            output.type != "workflow"
        ), "生成器应当传入类型为非 'workflow' 的 Output 实例"
        generator = self.match(output)
        if generator is None:
            raise UnknownEventType(f"没有为事件类型 '{output.type}' 注册生成器")

        # Share this dispatcher's event table and globals with the generator.
        generator.events = self.events
        generator.global_variables = self.global_variables
        return generator.output(output, injector)

    def match(self, output: Output) -> Optional[BaseGenerator]:
        """Return the generator registered for ``output.type``, or None."""
        return self.generators.get(output.type)