telegrinder 0.1.dev171__py3-none-any.whl → 0.2.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of telegrinder might be problematic. Click here for more details.
- telegrinder/__init__.py +2 -2
- telegrinder/api/__init__.py +1 -2
- telegrinder/api/api.py +3 -3
- telegrinder/api/token.py +36 -0
- telegrinder/bot/__init__.py +12 -6
- telegrinder/bot/bot.py +12 -5
- telegrinder/bot/cute_types/__init__.py +7 -7
- telegrinder/bot/cute_types/base.py +17 -43
- telegrinder/bot/cute_types/callback_query.py +5 -6
- telegrinder/bot/cute_types/chat_join_request.py +4 -5
- telegrinder/bot/cute_types/chat_member_updated.py +3 -4
- telegrinder/bot/cute_types/inline_query.py +3 -4
- telegrinder/bot/cute_types/message.py +9 -10
- telegrinder/bot/cute_types/update.py +13 -13
- telegrinder/bot/cute_types/utils.py +1 -1
- telegrinder/bot/dispatch/__init__.py +9 -9
- telegrinder/bot/dispatch/abc.py +2 -2
- telegrinder/bot/dispatch/context.py +11 -2
- telegrinder/bot/dispatch/dispatch.py +18 -33
- telegrinder/bot/dispatch/handler/__init__.py +3 -3
- telegrinder/bot/dispatch/handler/abc.py +3 -3
- telegrinder/bot/dispatch/handler/func.py +17 -12
- telegrinder/bot/dispatch/handler/message_reply.py +6 -7
- telegrinder/bot/dispatch/middleware/__init__.py +1 -1
- telegrinder/bot/dispatch/process.py +30 -11
- telegrinder/bot/dispatch/return_manager/__init__.py +4 -4
- telegrinder/bot/dispatch/return_manager/callback_query.py +1 -2
- telegrinder/bot/dispatch/return_manager/inline_query.py +1 -2
- telegrinder/bot/dispatch/return_manager/message.py +1 -2
- telegrinder/bot/dispatch/view/__init__.py +8 -8
- telegrinder/bot/dispatch/view/abc.py +9 -4
- telegrinder/bot/dispatch/view/box.py +2 -2
- telegrinder/bot/dispatch/view/callback_query.py +1 -2
- telegrinder/bot/dispatch/view/chat_join_request.py +1 -2
- telegrinder/bot/dispatch/view/chat_member.py +16 -2
- telegrinder/bot/dispatch/view/inline_query.py +1 -2
- telegrinder/bot/dispatch/view/message.py +1 -2
- telegrinder/bot/dispatch/view/raw.py +8 -10
- telegrinder/bot/dispatch/waiter_machine/__init__.py +3 -3
- telegrinder/bot/dispatch/waiter_machine/machine.py +10 -6
- telegrinder/bot/dispatch/waiter_machine/short_state.py +2 -2
- telegrinder/bot/polling/abc.py +1 -1
- telegrinder/bot/polling/polling.py +3 -3
- telegrinder/bot/rules/__init__.py +20 -20
- telegrinder/bot/rules/abc.py +50 -40
- telegrinder/bot/rules/adapter/__init__.py +5 -5
- telegrinder/bot/rules/adapter/abc.py +6 -3
- telegrinder/bot/rules/adapter/errors.py +2 -1
- telegrinder/bot/rules/adapter/event.py +27 -15
- telegrinder/bot/rules/adapter/node.py +28 -22
- telegrinder/bot/rules/adapter/raw_update.py +13 -5
- telegrinder/bot/rules/callback_data.py +12 -6
- telegrinder/bot/rules/chat_join.py +4 -4
- telegrinder/bot/rules/func.py +1 -1
- telegrinder/bot/rules/inline.py +3 -3
- telegrinder/bot/rules/markup.py +3 -1
- telegrinder/bot/rules/message_entities.py +1 -1
- telegrinder/bot/rules/text.py +1 -2
- telegrinder/bot/rules/update.py +1 -2
- telegrinder/bot/scenario/abc.py +2 -2
- telegrinder/bot/scenario/checkbox.py +1 -2
- telegrinder/bot/scenario/choice.py +1 -2
- telegrinder/model.py +9 -2
- telegrinder/msgspec_utils.py +55 -55
- telegrinder/node/__init__.py +1 -3
- telegrinder/node/base.py +20 -85
- telegrinder/node/command.py +3 -3
- telegrinder/node/composer.py +71 -74
- telegrinder/node/container.py +3 -3
- telegrinder/node/event.py +45 -33
- telegrinder/node/me.py +3 -4
- telegrinder/node/polymorphic.py +12 -6
- telegrinder/node/rule.py +1 -9
- telegrinder/node/scope.py +9 -1
- telegrinder/node/source.py +11 -0
- telegrinder/node/update.py +6 -2
- telegrinder/rules.py +59 -0
- telegrinder/tools/error_handler/abc.py +2 -2
- telegrinder/tools/error_handler/error_handler.py +5 -5
- telegrinder/tools/global_context/global_context.py +1 -1
- telegrinder/tools/keyboard.py +1 -1
- telegrinder/tools/loop_wrapper/loop_wrapper.py +9 -9
- telegrinder/tools/magic.py +64 -19
- telegrinder/types/__init__.py +1 -0
- telegrinder/types/enums.py +1 -0
- telegrinder/types/methods.py +78 -11
- telegrinder/types/objects.py +46 -24
- telegrinder/verification_utils.py +1 -3
- {telegrinder-0.1.dev171.dist-info → telegrinder-0.2.0.post1.dist-info}/METADATA +1 -1
- telegrinder-0.2.0.post1.dist-info/RECORD +145 -0
- telegrinder/api/abc.py +0 -79
- telegrinder-0.1.dev171.dist-info/RECORD +0 -145
- {telegrinder-0.1.dev171.dist-info → telegrinder-0.2.0.post1.dist-info}/LICENSE +0 -0
- {telegrinder-0.1.dev171.dist-info → telegrinder-0.2.0.post1.dist-info}/WHEEL +0 -0
telegrinder/model.py
CHANGED
|
@@ -16,10 +16,17 @@ if typing.TYPE_CHECKING:
|
|
|
16
16
|
|
|
17
17
|
T = typing.TypeVar("T")
|
|
18
18
|
|
|
19
|
+
|
|
20
|
+
def rename_field(name: str) -> str:
|
|
21
|
+
if name.endswith("_") and name.removesuffix("_") in keyword.kwlist:
|
|
22
|
+
return name.removesuffix("_")
|
|
23
|
+
return name if not keyword.iskeyword(name) else name + "_"
|
|
24
|
+
|
|
25
|
+
|
|
19
26
|
MODEL_CONFIG: typing.Final[dict[str, typing.Any]] = {
|
|
20
27
|
"omit_defaults": True,
|
|
21
28
|
"dict": True,
|
|
22
|
-
"rename":
|
|
29
|
+
"rename": rename_field,
|
|
23
30
|
}
|
|
24
31
|
|
|
25
32
|
|
|
@@ -168,7 +175,7 @@ class DataConverter:
|
|
|
168
175
|
data: Model,
|
|
169
176
|
serialize: bool = True,
|
|
170
177
|
) -> str | dict[str, typing.Any]:
|
|
171
|
-
converted_dct = self(data.
|
|
178
|
+
converted_dct = self(data.to_dict(), serialize=False)
|
|
172
179
|
return encoder.encode(converted_dct) if serialize is True else converted_dct
|
|
173
180
|
|
|
174
181
|
def convert_dct(
|
telegrinder/msgspec_utils.py
CHANGED
|
@@ -10,10 +10,16 @@ if typing.TYPE_CHECKING:
|
|
|
10
10
|
from datetime import datetime
|
|
11
11
|
|
|
12
12
|
from fntypes.option import Option
|
|
13
|
-
|
|
13
|
+
|
|
14
|
+
def get_class_annotations(obj: typing.Any) -> dict[str, type[typing.Any]]: ...
|
|
15
|
+
|
|
16
|
+
def get_type_hints(obj: typing.Any) -> dict[str, type[typing.Any]]: ...
|
|
17
|
+
|
|
14
18
|
else:
|
|
15
19
|
from datetime import datetime as dt
|
|
16
20
|
|
|
21
|
+
from msgspec._utils import get_class_annotations, get_type_hints
|
|
22
|
+
|
|
17
23
|
Value = typing.TypeVar("Value")
|
|
18
24
|
Err = typing.TypeVar("Err")
|
|
19
25
|
|
|
@@ -23,16 +29,10 @@ else:
|
|
|
23
29
|
def __instancecheck__(cls, __instance: typing.Any) -> bool:
|
|
24
30
|
return isinstance(__instance, fntypes.option.Some | fntypes.option.Nothing)
|
|
25
31
|
|
|
26
|
-
class ResultMeta(type):
|
|
27
|
-
def __instancecheck__(cls, __instance: typing.Any) -> bool:
|
|
28
|
-
return isinstance(__instance, fntypes.result.Ok | fntypes.result.Error)
|
|
29
32
|
|
|
30
33
|
class Option(typing.Generic[Value], metaclass=OptionMeta):
|
|
31
34
|
pass
|
|
32
35
|
|
|
33
|
-
class Result(typing.Generic[Value, Err], metaclass=ResultMeta):
|
|
34
|
-
pass
|
|
35
|
-
|
|
36
36
|
|
|
37
37
|
T = typing.TypeVar("T")
|
|
38
38
|
|
|
@@ -46,10 +46,29 @@ def get_origin(t: type[T]) -> type[T]:
|
|
|
46
46
|
return typing.cast(T, typing.get_origin(t)) or t
|
|
47
47
|
|
|
48
48
|
|
|
49
|
-
def repr_type(t:
|
|
49
|
+
def repr_type(t: typing.Any) -> str:
|
|
50
50
|
return getattr(t, "__name__", repr(get_origin(t)))
|
|
51
51
|
|
|
52
52
|
|
|
53
|
+
def is_common_type(type_: typing.Any) -> typing.TypeGuard[type[typing.Any]]:
|
|
54
|
+
if not isinstance(type_, type):
|
|
55
|
+
return False
|
|
56
|
+
return (
|
|
57
|
+
type_ in (str, int, float, bool, None, Variative)
|
|
58
|
+
or issubclass(type_, msgspec.Struct)
|
|
59
|
+
or hasattr(type_, "__dataclass_fields__")
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def type_check(obj: typing.Any, t: typing.Any) -> bool:
|
|
64
|
+
return (
|
|
65
|
+
isinstance(obj, t)
|
|
66
|
+
if isinstance(t, type)
|
|
67
|
+
and issubclass(t, msgspec.Struct)
|
|
68
|
+
else type(obj) in t if isinstance(t, tuple) else type(obj) is t
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
|
|
53
72
|
def msgspec_convert(obj: typing.Any, t: type[T]) -> Result[T, str]:
|
|
54
73
|
try:
|
|
55
74
|
return Ok(decoder.convert(obj, type=t, strict=True))
|
|
@@ -68,72 +87,62 @@ def msgspec_to_builtins(
|
|
|
68
87
|
str_keys: bool = False,
|
|
69
88
|
builtin_types: typing.Iterable[type[typing.Any]] | None = None,
|
|
70
89
|
order: typing.Literal["deterministic", "sorted"] | None = None,
|
|
71
|
-
) -> typing.Any:
|
|
72
|
-
|
|
90
|
+
) -> fntypes.result.Result[typing.Any, msgspec.ValidationError]:
|
|
91
|
+
try:
|
|
92
|
+
return Ok(encoder.to_builtins(**locals()))
|
|
93
|
+
except msgspec.ValidationError as exc:
|
|
94
|
+
return Error(exc)
|
|
73
95
|
|
|
74
96
|
|
|
75
97
|
def option_dec_hook(tp: type[Option[typing.Any]], obj: typing.Any) -> Option[typing.Any]:
|
|
76
|
-
|
|
77
|
-
|
|
98
|
+
if obj is None:
|
|
99
|
+
return Nothing
|
|
78
100
|
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
def result_dec_hook(
|
|
85
|
-
tp: type[Result[typing.Any, typing.Any]], obj: typing.Any
|
|
86
|
-
) -> Result[typing.Any, typing.Any]:
|
|
87
|
-
if not isinstance(obj, dict):
|
|
88
|
-
raise TypeError(f"Cannot parse to Result object of type `{repr_type(type(obj))}`.")
|
|
89
|
-
|
|
90
|
-
orig_type = get_origin(tp)
|
|
91
|
-
(first_type, second_type) = (
|
|
92
|
-
typing.get_args(tp) + (typing.Any,) if len(typing.get_args(tp)) == 1 else typing.get_args(tp)
|
|
93
|
-
) or (typing.Any, typing.Any)
|
|
101
|
+
(value_type,) = typing.get_args(tp) or (typing.Any,)
|
|
102
|
+
orig_value_type = typing.get_origin(value_type) or value_type
|
|
103
|
+
orig_obj = obj
|
|
94
104
|
|
|
95
|
-
if
|
|
96
|
-
|
|
105
|
+
if not isinstance(orig_obj, dict | list) and is_common_type(orig_value_type):
|
|
106
|
+
if orig_value_type is Variative:
|
|
107
|
+
obj = value_type(orig_obj) # type: ignore
|
|
108
|
+
orig_value_type = typing.get_args(value_type)
|
|
97
109
|
|
|
98
|
-
|
|
99
|
-
|
|
110
|
+
if not type_check(orig_obj, orig_value_type):
|
|
111
|
+
raise TypeError(f"Expected `{repr_type(orig_value_type)}`, got `{repr_type(type(orig_obj))}`.")
|
|
100
112
|
|
|
101
|
-
|
|
102
|
-
match obj:
|
|
103
|
-
case {"ok": ok}:
|
|
104
|
-
return Ok(msgspec_convert(ok, first_type).unwrap())
|
|
105
|
-
case {"error": error}:
|
|
106
|
-
return Error(msgspec_convert(error, second_type).unwrap())
|
|
113
|
+
return fntypes.option.Some(obj)
|
|
107
114
|
|
|
108
|
-
|
|
115
|
+
return fntypes.option.Some(decoder.convert(orig_obj, type=value_type))
|
|
109
116
|
|
|
110
117
|
|
|
111
118
|
def variative_dec_hook(tp: type[Variative], obj: typing.Any) -> Variative:
|
|
112
119
|
union_types = typing.get_args(tp)
|
|
113
120
|
|
|
114
121
|
if isinstance(obj, dict):
|
|
115
|
-
|
|
122
|
+
models_struct_fields: dict[type[msgspec.Struct], int] = {
|
|
116
123
|
m: sum(1 for k in obj if k in m.__struct_fields__)
|
|
117
124
|
for m in union_types
|
|
118
125
|
if issubclass(get_origin(m), msgspec.Struct)
|
|
119
126
|
}
|
|
120
|
-
union_types = tuple(t for t in union_types if t not in
|
|
127
|
+
union_types = tuple(t for t in union_types if t not in models_struct_fields)
|
|
121
128
|
reverse = False
|
|
122
129
|
|
|
123
|
-
if len(set(
|
|
124
|
-
|
|
130
|
+
if len(set(models_struct_fields.values())) != len(models_struct_fields.values()):
|
|
131
|
+
models_struct_fields = {m: len(m.__struct_fields__) for m in models_struct_fields}
|
|
125
132
|
reverse = True
|
|
126
133
|
|
|
127
134
|
union_types = (
|
|
128
135
|
*sorted(
|
|
129
|
-
|
|
130
|
-
key=lambda k:
|
|
136
|
+
models_struct_fields,
|
|
137
|
+
key=lambda k: models_struct_fields[k],
|
|
131
138
|
reverse=reverse,
|
|
132
139
|
),
|
|
133
140
|
*union_types,
|
|
134
141
|
)
|
|
135
142
|
|
|
136
143
|
for t in union_types:
|
|
144
|
+
if not isinstance(obj, dict | list) and is_common_type(t) and type_check(obj, t):
|
|
145
|
+
return tp(obj)
|
|
137
146
|
match msgspec_convert(obj, t):
|
|
138
147
|
case Ok(value):
|
|
139
148
|
return tp(value)
|
|
@@ -179,12 +188,9 @@ class Decoder:
|
|
|
179
188
|
|
|
180
189
|
def __init__(self) -> None:
|
|
181
190
|
self.dec_hooks: dict[typing.Any, DecHook[typing.Any]] = {
|
|
182
|
-
Result: result_dec_hook,
|
|
183
191
|
Option: option_dec_hook,
|
|
184
192
|
Variative: variative_dec_hook,
|
|
185
193
|
datetime: lambda t, obj: t.fromtimestamp(obj),
|
|
186
|
-
fntypes.result.Error: result_dec_hook,
|
|
187
|
-
fntypes.result.Ok: result_dec_hook,
|
|
188
194
|
fntypes.option.Some: option_dec_hook,
|
|
189
195
|
fntypes.option.Nothing: option_dec_hook,
|
|
190
196
|
}
|
|
@@ -271,10 +277,6 @@ class Encoder:
|
|
|
271
277
|
self.enc_hooks: dict[typing.Any, EncHook[typing.Any]] = {
|
|
272
278
|
fntypes.option.Some: lambda opt: opt.value,
|
|
273
279
|
fntypes.option.Nothing: lambda _: None,
|
|
274
|
-
fntypes.result.Ok: lambda ok: {"ok": ok.value},
|
|
275
|
-
fntypes.result.Error: lambda err: {
|
|
276
|
-
"error": (str(err.error) if isinstance(err.error, BaseException) else err.error)
|
|
277
|
-
},
|
|
278
280
|
Variative: lambda variative: variative.v,
|
|
279
281
|
datetime: lambda date: int(date.timestamp()),
|
|
280
282
|
}
|
|
@@ -342,10 +344,8 @@ __all__ = (
|
|
|
342
344
|
"datetime",
|
|
343
345
|
"decoder",
|
|
344
346
|
"encoder",
|
|
345
|
-
"get_origin",
|
|
346
347
|
"msgspec_convert",
|
|
348
|
+
"get_class_annotations",
|
|
349
|
+
"get_type_hints",
|
|
347
350
|
"msgspec_to_builtins",
|
|
348
|
-
"option_dec_hook",
|
|
349
|
-
"repr_type",
|
|
350
|
-
"variative_dec_hook",
|
|
351
351
|
)
|
telegrinder/node/__init__.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from .attachment import Attachment, Audio, Photo, Video
|
|
2
|
-
from .base import
|
|
2
|
+
from .base import ComposeError, DataNode, Node, ScalarNode, is_node
|
|
3
3
|
from .callback_query import CallbackQueryNode
|
|
4
4
|
from .command import CommandInfo
|
|
5
5
|
from .composer import Composition, NodeCollection, NodeSession, compose_node, compose_nodes
|
|
@@ -18,7 +18,6 @@ from .update import UpdateNode
|
|
|
18
18
|
__all__ = (
|
|
19
19
|
"Attachment",
|
|
20
20
|
"Audio",
|
|
21
|
-
"BaseNode",
|
|
22
21
|
"CallbackQueryNode",
|
|
23
22
|
"ChatSource",
|
|
24
23
|
"CommandInfo",
|
|
@@ -52,7 +51,6 @@ __all__ = (
|
|
|
52
51
|
"global_node",
|
|
53
52
|
"impl",
|
|
54
53
|
"is_node",
|
|
55
|
-
"node_impl",
|
|
56
54
|
"per_call",
|
|
57
55
|
"per_event",
|
|
58
56
|
)
|
telegrinder/node/base.py
CHANGED
|
@@ -3,17 +3,10 @@ import inspect
|
|
|
3
3
|
import typing
|
|
4
4
|
from types import AsyncGeneratorType
|
|
5
5
|
|
|
6
|
-
from telegrinder.api.api import API
|
|
7
|
-
from telegrinder.bot.cute_types.update import UpdateCute
|
|
8
|
-
from telegrinder.bot.dispatch.context import Context
|
|
9
6
|
from telegrinder.node.scope import NodeScope
|
|
10
7
|
from telegrinder.tools.magic import (
|
|
11
|
-
NODE_IMPL_MARK,
|
|
12
8
|
cache_magic_value,
|
|
13
9
|
get_annotations,
|
|
14
|
-
get_impls_by_key,
|
|
15
|
-
magic_bundle,
|
|
16
|
-
node_impl,
|
|
17
10
|
)
|
|
18
11
|
|
|
19
12
|
ComposeResult: typing.TypeAlias = typing.Awaitable[typing.Any] | typing.AsyncGenerator[typing.Any, None]
|
|
@@ -29,11 +22,6 @@ def is_node(maybe_node: type[typing.Any]) -> typing.TypeGuard[type["Node"]]:
|
|
|
29
22
|
)
|
|
30
23
|
|
|
31
24
|
|
|
32
|
-
@cache_magic_value("__compose_annotations__")
|
|
33
|
-
def get_compose_annotations(function: typing.Callable[..., typing.Any]) -> dict[str, typing.Any]:
|
|
34
|
-
return {k: v for k, v in get_annotations(function).items() if not is_node(v)}
|
|
35
|
-
|
|
36
|
-
|
|
37
25
|
@cache_magic_value("__nodes__")
|
|
38
26
|
def get_nodes(function: typing.Callable[..., typing.Any]) -> dict[str, type["Node"]]:
|
|
39
27
|
return {k: v for k, v in get_annotations(function).items() if is_node(v)}
|
|
@@ -44,23 +32,19 @@ def is_generator(function: typing.Callable[..., typing.Any]) -> typing.TypeGuard
|
|
|
44
32
|
return inspect.isasyncgenfunction(function)
|
|
45
33
|
|
|
46
34
|
|
|
47
|
-
def
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
if issubclass(node_cls, BaseNode):
|
|
51
|
-
impls |= get_impls_by_key(BaseNode, NODE_IMPL_MARK)
|
|
52
|
-
setattr(node_cls, "__node_impls__", impls)
|
|
53
|
-
return getattr(node_cls, "__node_impls__")
|
|
54
|
-
|
|
35
|
+
def get_node_calc_lst(node: type["Node"]) -> list[type["Node"]]:
|
|
36
|
+
"""Returns flattened list of node types in ordering required to calculate given node.
|
|
37
|
+
Provides caching for passed node type"""
|
|
55
38
|
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
for
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
39
|
+
if calc_lst := getattr(node, "__nodes_calc_lst__", None):
|
|
40
|
+
return calc_lst
|
|
41
|
+
nodes_lst: list[type["Node"]] = []
|
|
42
|
+
annotations = list(node.as_node().get_subnodes().values())
|
|
43
|
+
for node_type in annotations:
|
|
44
|
+
nodes_lst.extend(get_node_calc_lst(node_type))
|
|
45
|
+
calc_lst = [*nodes_lst, node]
|
|
46
|
+
setattr(node, "__nodes_calc_lst__", calc_lst)
|
|
47
|
+
return calc_lst
|
|
64
48
|
|
|
65
49
|
|
|
66
50
|
class ComposeError(BaseException):
|
|
@@ -76,47 +60,14 @@ class Node(abc.ABC):
|
|
|
76
60
|
def compose(cls, *args, **kwargs) -> ComposeResult:
|
|
77
61
|
pass
|
|
78
62
|
|
|
79
|
-
@classmethod
|
|
80
|
-
async def compose_annotation(
|
|
81
|
-
cls,
|
|
82
|
-
annotation: typing.Any,
|
|
83
|
-
update: UpdateCute,
|
|
84
|
-
ctx: Context,
|
|
85
|
-
) -> typing.Any:
|
|
86
|
-
orig_annotation: type[typing.Any] = typing.get_origin(annotation) or annotation
|
|
87
|
-
n_impl = get_node_impl(orig_annotation, cls.get_node_impls())
|
|
88
|
-
if n_impl is None:
|
|
89
|
-
raise ComposeError(f"Node implementation for {orig_annotation!r} not found.")
|
|
90
|
-
|
|
91
|
-
result = n_impl(
|
|
92
|
-
cls,
|
|
93
|
-
**magic_bundle(
|
|
94
|
-
n_impl,
|
|
95
|
-
{"update": update, "context": ctx},
|
|
96
|
-
start_idx=0,
|
|
97
|
-
bundle_ctx=False,
|
|
98
|
-
),
|
|
99
|
-
)
|
|
100
|
-
if inspect.isawaitable(result):
|
|
101
|
-
return await result
|
|
102
|
-
return result
|
|
103
|
-
|
|
104
63
|
@classmethod
|
|
105
64
|
def compose_error(cls, error: str | None = None) -> typing.NoReturn:
|
|
106
65
|
raise ComposeError(error)
|
|
107
66
|
|
|
108
67
|
@classmethod
|
|
109
|
-
def
|
|
68
|
+
def get_subnodes(cls) -> dict[str, type["Node"]]:
|
|
110
69
|
return get_nodes(cls.compose)
|
|
111
70
|
|
|
112
|
-
@classmethod
|
|
113
|
-
def get_compose_annotations(cls) -> dict[str, typing.Any]:
|
|
114
|
-
return get_compose_annotations(cls.compose)
|
|
115
|
-
|
|
116
|
-
@classmethod
|
|
117
|
-
def get_node_impls(cls) -> dict[str, typing.Callable[..., typing.Any]]:
|
|
118
|
-
return get_node_impls(cls)
|
|
119
|
-
|
|
120
71
|
@classmethod
|
|
121
72
|
def as_node(cls) -> type[typing.Self]:
|
|
122
73
|
return cls
|
|
@@ -126,26 +77,7 @@ class Node(abc.ABC):
|
|
|
126
77
|
return is_generator(cls.compose)
|
|
127
78
|
|
|
128
79
|
|
|
129
|
-
class
|
|
130
|
-
@classmethod
|
|
131
|
-
@abc.abstractmethod
|
|
132
|
-
def compose(cls, *args, **kwargs) -> ComposeResult:
|
|
133
|
-
pass
|
|
134
|
-
|
|
135
|
-
@node_impl
|
|
136
|
-
def compose_api(cls, update: UpdateCute) -> API:
|
|
137
|
-
return update.ctx_api
|
|
138
|
-
|
|
139
|
-
@node_impl
|
|
140
|
-
def compose_context(cls, context: Context) -> Context:
|
|
141
|
-
return context
|
|
142
|
-
|
|
143
|
-
@node_impl
|
|
144
|
-
def compose_update(cls, update: UpdateCute) -> UpdateCute:
|
|
145
|
-
return update
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
class DataNode(BaseNode, abc.ABC):
|
|
80
|
+
class DataNode(Node, abc.ABC):
|
|
149
81
|
node = "data"
|
|
150
82
|
|
|
151
83
|
@typing.dataclass_transform()
|
|
@@ -155,7 +87,7 @@ class DataNode(BaseNode, abc.ABC):
|
|
|
155
87
|
pass
|
|
156
88
|
|
|
157
89
|
|
|
158
|
-
class ScalarNodeProto(
|
|
90
|
+
class ScalarNodeProto(Node, abc.ABC):
|
|
159
91
|
@classmethod
|
|
160
92
|
@abc.abstractmethod
|
|
161
93
|
async def compose(cls, *args, **kwargs) -> ComposeResult:
|
|
@@ -171,6 +103,10 @@ if typing.TYPE_CHECKING:
|
|
|
171
103
|
pass
|
|
172
104
|
|
|
173
105
|
else:
|
|
106
|
+
def __init_subclass__(cls, *args, **kwargs): # noqa: N807
|
|
107
|
+
if any(issubclass(base, ScalarNode) for base in cls.__bases__ if base is not ScalarNode):
|
|
108
|
+
raise RuntimeError("Scalar nodes do not support inheritance.")
|
|
109
|
+
|
|
174
110
|
def create_node(cls, bases, dct):
|
|
175
111
|
dct.update(cls.__dict__)
|
|
176
112
|
return type(cls.__name__, bases, dct)
|
|
@@ -182,6 +118,7 @@ else:
|
|
|
182
118
|
{
|
|
183
119
|
"as_node": classmethod(lambda cls: create_node(cls, bases, dct)),
|
|
184
120
|
"scope": Node.scope,
|
|
121
|
+
"__init_subclass__": __init_subclass__,
|
|
185
122
|
},
|
|
186
123
|
)
|
|
187
124
|
|
|
@@ -190,14 +127,12 @@ else:
|
|
|
190
127
|
|
|
191
128
|
|
|
192
129
|
__all__ = (
|
|
193
|
-
"BaseNode",
|
|
194
130
|
"ComposeError",
|
|
195
131
|
"DataNode",
|
|
196
132
|
"Node",
|
|
197
133
|
"SCALAR_NODE",
|
|
198
134
|
"ScalarNode",
|
|
199
135
|
"ScalarNodeProto",
|
|
200
|
-
"get_compose_annotations",
|
|
201
136
|
"get_nodes",
|
|
202
137
|
"is_node",
|
|
203
138
|
)
|
telegrinder/node/command.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
import typing
|
|
2
2
|
from dataclasses import dataclass, field
|
|
3
3
|
|
|
4
|
-
from fntypes import Nothing, Option, Some
|
|
4
|
+
from fntypes.option import Nothing, Option, Some
|
|
5
5
|
|
|
6
|
-
from .base import DataNode
|
|
7
|
-
from .text import Text
|
|
6
|
+
from telegrinder.node.base import DataNode
|
|
7
|
+
from telegrinder.node.text import Text
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
def single_split(s: str, separator: str) -> tuple[str, str]:
|
telegrinder/node/composer.py
CHANGED
|
@@ -1,105 +1,97 @@
|
|
|
1
1
|
import dataclasses
|
|
2
2
|
import typing
|
|
3
3
|
|
|
4
|
+
from fntypes import Error, Ok, Result
|
|
4
5
|
from fntypes.error import UnwrapError
|
|
5
6
|
|
|
6
|
-
from telegrinder.
|
|
7
|
+
from telegrinder.api import API
|
|
8
|
+
from telegrinder.bot.cute_types.update import Update, UpdateCute
|
|
7
9
|
from telegrinder.bot.dispatch.context import Context
|
|
8
10
|
from telegrinder.modules import logger
|
|
9
11
|
from telegrinder.node.base import (
|
|
10
|
-
BaseNode,
|
|
11
12
|
ComposeError,
|
|
12
13
|
Node,
|
|
13
14
|
NodeScope,
|
|
14
|
-
|
|
15
|
+
get_node_calc_lst,
|
|
15
16
|
get_nodes,
|
|
16
17
|
)
|
|
17
18
|
from telegrinder.tools.magic import magic_bundle
|
|
18
19
|
|
|
19
|
-
CONTEXT_STORE_NODES_KEY = "
|
|
20
|
+
CONTEXT_STORE_NODES_KEY = "_node_ctx"
|
|
21
|
+
GLOBAL_VALUE_KEY = "_value"
|
|
20
22
|
|
|
21
23
|
|
|
22
24
|
async def compose_node(
|
|
23
25
|
_node: type[Node],
|
|
24
|
-
|
|
25
|
-
ctx: Context,
|
|
26
|
+
linked: dict[type, typing.Any],
|
|
26
27
|
) -> "NodeSession":
|
|
27
28
|
node = _node.as_node()
|
|
28
|
-
|
|
29
|
-
node_ctx: dict[type[Node], "NodeSession"] = ctx.get_or_set(CONTEXT_STORE_NODES_KEY, {})
|
|
30
|
-
|
|
31
|
-
for name, subnode in node.get_sub_nodes().items():
|
|
32
|
-
if subnode in node_ctx:
|
|
33
|
-
context.sessions[name] = node_ctx[subnode]
|
|
34
|
-
else:
|
|
35
|
-
context.sessions[name] = await compose_node(subnode, update, ctx)
|
|
36
|
-
|
|
37
|
-
if getattr(subnode, "scope", None) is NodeScope.PER_EVENT:
|
|
38
|
-
node_ctx[subnode] = context.sessions[name]
|
|
39
|
-
|
|
40
|
-
for name, annotation in node.get_compose_annotations().items():
|
|
41
|
-
context.sessions[name] = NodeSession(
|
|
42
|
-
None,
|
|
43
|
-
await node.compose_annotation(annotation, update, ctx),
|
|
44
|
-
{},
|
|
45
|
-
)
|
|
29
|
+
kwargs = magic_bundle(node.compose, linked, typebundle=True)
|
|
46
30
|
|
|
47
31
|
if node.is_generator():
|
|
48
|
-
generator = typing.cast(typing.AsyncGenerator[typing.Any, None], node.compose(**
|
|
32
|
+
generator = typing.cast(typing.AsyncGenerator[typing.Any, None], node.compose(**kwargs))
|
|
49
33
|
value = await generator.asend(None)
|
|
50
34
|
else:
|
|
51
35
|
generator = None
|
|
52
|
-
value = await typing.cast(typing.Awaitable[typing.Any], node.compose(**
|
|
36
|
+
value = await typing.cast(typing.Awaitable[typing.Any], node.compose(**kwargs))
|
|
53
37
|
|
|
54
|
-
return NodeSession(_node, value,
|
|
38
|
+
return NodeSession(_node, value, {}, generator)
|
|
55
39
|
|
|
56
40
|
|
|
57
41
|
async def compose_nodes(
|
|
58
|
-
update: UpdateCute,
|
|
59
|
-
ctx: Context,
|
|
60
42
|
nodes: dict[str, type[Node]],
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
) -> "NodeCollection
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
43
|
+
ctx: Context,
|
|
44
|
+
data: dict[type, typing.Any] | None = None,
|
|
45
|
+
) -> Result["NodeCollection", ComposeError]:
|
|
46
|
+
logger.debug("Composing nodes: {!r}...", nodes)
|
|
47
|
+
|
|
48
|
+
parent_nodes: dict[type[Node], NodeSession] = {}
|
|
49
|
+
event_nodes: dict[type[Node], NodeSession] = ctx.get_or_set(CONTEXT_STORE_NODES_KEY, {})
|
|
50
|
+
data = {Context: ctx} | (data or {})
|
|
51
|
+
|
|
52
|
+
# Create flattened list of ordered nodes to be calculated
|
|
53
|
+
# TODO: optimize flattened list calculation via caching key = tuple of node types
|
|
54
|
+
calculation_nodes: list[list[type[Node]]] = []
|
|
55
|
+
for node_t in nodes.values():
|
|
56
|
+
calculation_nodes.append(get_node_calc_lst(node_t))
|
|
57
|
+
|
|
58
|
+
for linked_nodes in calculation_nodes:
|
|
59
|
+
local_nodes: dict[type[Node], "NodeSession"] = {}
|
|
60
|
+
for node_t in linked_nodes:
|
|
69
61
|
scope = getattr(node_t, "scope", None)
|
|
70
62
|
|
|
71
|
-
if scope is NodeScope.PER_EVENT and node_t in
|
|
72
|
-
|
|
63
|
+
if scope is NodeScope.PER_EVENT and node_t in event_nodes:
|
|
64
|
+
local_nodes[node_t] = event_nodes[node_t]
|
|
73
65
|
continue
|
|
74
|
-
elif scope is NodeScope.GLOBAL and hasattr(node_t,
|
|
75
|
-
|
|
66
|
+
elif scope is NodeScope.GLOBAL and hasattr(node_t, GLOBAL_VALUE_KEY):
|
|
67
|
+
local_nodes[node_t] = getattr(node_t, GLOBAL_VALUE_KEY)
|
|
76
68
|
continue
|
|
77
69
|
|
|
78
|
-
|
|
70
|
+
subnodes = {
|
|
71
|
+
k: session.value
|
|
72
|
+
for k, session in
|
|
73
|
+
(local_nodes | event_nodes).items()
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
try:
|
|
77
|
+
local_nodes[node_t] = await compose_node(node_t, subnodes | data)
|
|
78
|
+
except (ComposeError, UnwrapError) as exc:
|
|
79
|
+
for t, local_node in local_nodes.items():
|
|
80
|
+
if t.scope is NodeScope.PER_CALL:
|
|
81
|
+
await local_node.close()
|
|
82
|
+
return Error(ComposeError(f"Cannot compose {node_t}. Error: {exc}"))
|
|
79
83
|
|
|
80
84
|
if scope is NodeScope.PER_EVENT:
|
|
81
|
-
|
|
85
|
+
event_nodes[node_t] = local_nodes[node_t]
|
|
82
86
|
elif scope is NodeScope.GLOBAL:
|
|
83
|
-
setattr(node_t,
|
|
84
|
-
except (ComposeError, UnwrapError) as exc:
|
|
85
|
-
logger.debug(f"Composing node (name={name!r}, node_class={node_t!r}) failed with error: {str(exc)!r}")
|
|
86
|
-
await NodeCollection(node_sessions).close_all()
|
|
87
|
-
return None
|
|
87
|
+
setattr(node_t, GLOBAL_VALUE_KEY, local_nodes[node_t])
|
|
88
88
|
|
|
89
|
-
|
|
90
|
-
|
|
89
|
+
# Last node is the parent node
|
|
90
|
+
parent_node_t = linked_nodes[-1]
|
|
91
|
+
parent_nodes[parent_node_t] = local_nodes[parent_node_t]
|
|
91
92
|
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
node_sessions[name] = await node_class.compose_annotation(annotation, update, ctx)
|
|
95
|
-
except (ComposeError, UnwrapError) as exc:
|
|
96
|
-
logger.debug(
|
|
97
|
-
f"Composing context annotation (name={name!r}, annotation={annotation!r}) failed with error: {str(exc)!r}",
|
|
98
|
-
)
|
|
99
|
-
await NodeCollection(node_sessions).close_all()
|
|
100
|
-
return None
|
|
101
|
-
|
|
102
|
-
return NodeCollection(node_sessions)
|
|
93
|
+
node_sessions = {k: parent_nodes[t] for k, t in nodes.items()}
|
|
94
|
+
return Ok(NodeCollection(node_sessions))
|
|
103
95
|
|
|
104
96
|
|
|
105
97
|
@dataclasses.dataclass(slots=True, repr=False)
|
|
@@ -107,7 +99,7 @@ class NodeSession:
|
|
|
107
99
|
node_type: type[Node] | None
|
|
108
100
|
value: typing.Any
|
|
109
101
|
subnodes: dict[str, typing.Self]
|
|
110
|
-
generator: typing.AsyncGenerator[typing.Any, None] | None = None
|
|
102
|
+
generator: typing.AsyncGenerator[typing.Any, typing.Any | None] | None = None
|
|
111
103
|
|
|
112
104
|
def __repr__(self) -> str:
|
|
113
105
|
return f"<{self.__class__.__name__}: {self.value!r}" + (" ACTIVE>" if self.generator else ">")
|
|
@@ -126,6 +118,7 @@ class NodeSession:
|
|
|
126
118
|
if self.generator is None:
|
|
127
119
|
return
|
|
128
120
|
try:
|
|
121
|
+
logger.debug("Closing session for node {!r}...", self.node_type)
|
|
129
122
|
await self.generator.asend(with_value)
|
|
130
123
|
except StopAsyncIteration:
|
|
131
124
|
self.generator = None
|
|
@@ -140,6 +133,7 @@ class NodeCollection:
|
|
|
140
133
|
def __repr__(self) -> str:
|
|
141
134
|
return "<{}: sessions={!r}>".format(self.__class__.__name__, self.sessions)
|
|
142
135
|
|
|
136
|
+
@property
|
|
143
137
|
def values(self) -> dict[str, typing.Any]:
|
|
144
138
|
return {name: session.value for name, session in self.sessions.items()}
|
|
145
139
|
|
|
@@ -156,30 +150,33 @@ class NodeCollection:
|
|
|
156
150
|
class Composition:
|
|
157
151
|
func: typing.Callable[..., typing.Any]
|
|
158
152
|
is_blocking: bool
|
|
159
|
-
node_class: type[Node] = dataclasses.field(default_factory=lambda: BaseNode)
|
|
160
153
|
nodes: dict[str, type[Node]] = dataclasses.field(init=False)
|
|
161
|
-
context_annotations: dict[str, typing.Any] = dataclasses.field(init=False)
|
|
162
154
|
|
|
163
155
|
def __post_init__(self) -> None:
|
|
164
156
|
self.nodes = get_nodes(self.func)
|
|
165
|
-
self.context_annotations = get_compose_annotations(self.func)
|
|
166
157
|
|
|
167
158
|
def __repr__(self) -> str:
|
|
168
|
-
return "<{}: for function={!r} with nodes={!r}
|
|
159
|
+
return "<{}: for function={!r} with nodes={!r}>".format(
|
|
169
160
|
("blocking " if self.is_blocking else "") + self.__class__.__name__,
|
|
170
161
|
self.func.__qualname__,
|
|
171
162
|
self.nodes,
|
|
172
|
-
self.context_annotations,
|
|
173
163
|
)
|
|
174
164
|
|
|
175
|
-
async def compose_nodes(
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
165
|
+
async def compose_nodes(
|
|
166
|
+
self,
|
|
167
|
+
update: UpdateCute,
|
|
168
|
+
context: Context,
|
|
169
|
+
) -> NodeCollection | None:
|
|
170
|
+
match await compose_nodes(
|
|
179
171
|
nodes=self.nodes,
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
)
|
|
172
|
+
ctx=context,
|
|
173
|
+
data={Update: update, API: update.api},
|
|
174
|
+
):
|
|
175
|
+
case Ok(col):
|
|
176
|
+
return col
|
|
177
|
+
case Error(err):
|
|
178
|
+
logger.debug(f"Composition failed with error: {err}")
|
|
179
|
+
return None
|
|
183
180
|
|
|
184
181
|
async def __call__(self, **kwargs: typing.Any) -> typing.Any:
|
|
185
182
|
return await self.func(**magic_bundle(self.func, kwargs, start_idx=0, bundle_ctx=False)) # type: ignore
|