telegrinder 0.1.dev171__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of telegrinder might be problematic. Click here for more details.
- telegrinder/__init__.py +2 -2
- telegrinder/api/__init__.py +1 -2
- telegrinder/api/api.py +3 -3
- telegrinder/api/token.py +36 -0
- telegrinder/bot/__init__.py +12 -6
- telegrinder/bot/bot.py +12 -5
- telegrinder/bot/cute_types/__init__.py +7 -7
- telegrinder/bot/cute_types/base.py +7 -32
- telegrinder/bot/cute_types/callback_query.py +5 -6
- telegrinder/bot/cute_types/chat_join_request.py +4 -5
- telegrinder/bot/cute_types/chat_member_updated.py +3 -4
- telegrinder/bot/cute_types/inline_query.py +3 -4
- telegrinder/bot/cute_types/message.py +9 -10
- telegrinder/bot/cute_types/update.py +8 -9
- telegrinder/bot/cute_types/utils.py +1 -1
- telegrinder/bot/dispatch/__init__.py +9 -9
- telegrinder/bot/dispatch/abc.py +2 -2
- telegrinder/bot/dispatch/context.py +11 -2
- telegrinder/bot/dispatch/dispatch.py +18 -33
- telegrinder/bot/dispatch/handler/__init__.py +3 -3
- telegrinder/bot/dispatch/handler/abc.py +3 -3
- telegrinder/bot/dispatch/handler/func.py +17 -12
- telegrinder/bot/dispatch/handler/message_reply.py +6 -7
- telegrinder/bot/dispatch/middleware/__init__.py +1 -1
- telegrinder/bot/dispatch/process.py +30 -11
- telegrinder/bot/dispatch/return_manager/__init__.py +4 -4
- telegrinder/bot/dispatch/return_manager/callback_query.py +1 -2
- telegrinder/bot/dispatch/return_manager/inline_query.py +1 -2
- telegrinder/bot/dispatch/return_manager/message.py +1 -2
- telegrinder/bot/dispatch/view/__init__.py +8 -8
- telegrinder/bot/dispatch/view/abc.py +9 -4
- telegrinder/bot/dispatch/view/box.py +2 -2
- telegrinder/bot/dispatch/view/callback_query.py +1 -2
- telegrinder/bot/dispatch/view/chat_join_request.py +1 -2
- telegrinder/bot/dispatch/view/chat_member.py +16 -2
- telegrinder/bot/dispatch/view/inline_query.py +1 -2
- telegrinder/bot/dispatch/view/message.py +1 -2
- telegrinder/bot/dispatch/view/raw.py +8 -10
- telegrinder/bot/dispatch/waiter_machine/__init__.py +3 -3
- telegrinder/bot/dispatch/waiter_machine/machine.py +10 -6
- telegrinder/bot/dispatch/waiter_machine/short_state.py +2 -2
- telegrinder/bot/polling/abc.py +1 -1
- telegrinder/bot/polling/polling.py +3 -3
- telegrinder/bot/rules/__init__.py +20 -20
- telegrinder/bot/rules/abc.py +50 -40
- telegrinder/bot/rules/adapter/__init__.py +5 -5
- telegrinder/bot/rules/adapter/abc.py +6 -3
- telegrinder/bot/rules/adapter/errors.py +2 -1
- telegrinder/bot/rules/adapter/event.py +27 -15
- telegrinder/bot/rules/adapter/node.py +28 -22
- telegrinder/bot/rules/adapter/raw_update.py +13 -5
- telegrinder/bot/rules/callback_data.py +4 -4
- telegrinder/bot/rules/chat_join.py +4 -4
- telegrinder/bot/rules/func.py +1 -1
- telegrinder/bot/rules/inline.py +3 -3
- telegrinder/bot/rules/markup.py +3 -1
- telegrinder/bot/rules/message_entities.py +1 -1
- telegrinder/bot/rules/text.py +1 -2
- telegrinder/bot/rules/update.py +1 -2
- telegrinder/bot/scenario/abc.py +2 -2
- telegrinder/bot/scenario/checkbox.py +1 -2
- telegrinder/bot/scenario/choice.py +1 -2
- telegrinder/model.py +6 -1
- telegrinder/msgspec_utils.py +55 -55
- telegrinder/node/__init__.py +1 -3
- telegrinder/node/base.py +14 -86
- telegrinder/node/composer.py +71 -74
- telegrinder/node/container.py +3 -3
- telegrinder/node/event.py +40 -31
- telegrinder/node/polymorphic.py +12 -6
- telegrinder/node/rule.py +1 -9
- telegrinder/node/scope.py +9 -1
- telegrinder/node/source.py +11 -0
- telegrinder/node/update.py +6 -2
- telegrinder/rules.py +59 -0
- telegrinder/tools/error_handler/abc.py +2 -2
- telegrinder/tools/error_handler/error_handler.py +5 -5
- telegrinder/tools/global_context/global_context.py +1 -1
- telegrinder/tools/keyboard.py +1 -1
- telegrinder/tools/loop_wrapper/loop_wrapper.py +9 -9
- telegrinder/tools/magic.py +64 -19
- telegrinder/types/__init__.py +1 -0
- telegrinder/types/enums.py +1 -0
- telegrinder/types/methods.py +78 -11
- telegrinder/types/objects.py +46 -24
- telegrinder/verification_utils.py +1 -3
- {telegrinder-0.1.dev171.dist-info → telegrinder-0.2.0.dist-info}/METADATA +1 -1
- telegrinder-0.2.0.dist-info/RECORD +145 -0
- telegrinder/api/abc.py +0 -79
- telegrinder-0.1.dev171.dist-info/RECORD +0 -145
- {telegrinder-0.1.dev171.dist-info → telegrinder-0.2.0.dist-info}/LICENSE +0 -0
- {telegrinder-0.1.dev171.dist-info → telegrinder-0.2.0.dist-info}/WHEEL +0 -0
telegrinder/msgspec_utils.py
CHANGED
|
@@ -10,10 +10,16 @@ if typing.TYPE_CHECKING:
|
|
|
10
10
|
from datetime import datetime
|
|
11
11
|
|
|
12
12
|
from fntypes.option import Option
|
|
13
|
-
|
|
13
|
+
|
|
14
|
+
def get_class_annotations(obj: typing.Any) -> dict[str, type[typing.Any]]: ...
|
|
15
|
+
|
|
16
|
+
def get_type_hints(obj: typing.Any) -> dict[str, type[typing.Any]]: ...
|
|
17
|
+
|
|
14
18
|
else:
|
|
15
19
|
from datetime import datetime as dt
|
|
16
20
|
|
|
21
|
+
from msgspec._utils import get_class_annotations, get_type_hints
|
|
22
|
+
|
|
17
23
|
Value = typing.TypeVar("Value")
|
|
18
24
|
Err = typing.TypeVar("Err")
|
|
19
25
|
|
|
@@ -23,16 +29,10 @@ else:
|
|
|
23
29
|
def __instancecheck__(cls, __instance: typing.Any) -> bool:
|
|
24
30
|
return isinstance(__instance, fntypes.option.Some | fntypes.option.Nothing)
|
|
25
31
|
|
|
26
|
-
class ResultMeta(type):
|
|
27
|
-
def __instancecheck__(cls, __instance: typing.Any) -> bool:
|
|
28
|
-
return isinstance(__instance, fntypes.result.Ok | fntypes.result.Error)
|
|
29
32
|
|
|
30
33
|
class Option(typing.Generic[Value], metaclass=OptionMeta):
|
|
31
34
|
pass
|
|
32
35
|
|
|
33
|
-
class Result(typing.Generic[Value, Err], metaclass=ResultMeta):
|
|
34
|
-
pass
|
|
35
|
-
|
|
36
36
|
|
|
37
37
|
T = typing.TypeVar("T")
|
|
38
38
|
|
|
@@ -46,10 +46,29 @@ def get_origin(t: type[T]) -> type[T]:
|
|
|
46
46
|
return typing.cast(T, typing.get_origin(t)) or t
|
|
47
47
|
|
|
48
48
|
|
|
49
|
-
def repr_type(t:
|
|
49
|
+
def repr_type(t: typing.Any) -> str:
|
|
50
50
|
return getattr(t, "__name__", repr(get_origin(t)))
|
|
51
51
|
|
|
52
52
|
|
|
53
|
+
def is_common_type(type_: typing.Any) -> typing.TypeGuard[type[typing.Any]]:
|
|
54
|
+
if not isinstance(type_, type):
|
|
55
|
+
return False
|
|
56
|
+
return (
|
|
57
|
+
type_ in (str, int, float, bool, None, Variative)
|
|
58
|
+
or issubclass(type_, msgspec.Struct)
|
|
59
|
+
or hasattr(type_, "__dataclass_fields__")
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def type_check(obj: typing.Any, t: typing.Any) -> bool:
|
|
64
|
+
return (
|
|
65
|
+
isinstance(obj, t)
|
|
66
|
+
if isinstance(t, type)
|
|
67
|
+
and issubclass(t, msgspec.Struct)
|
|
68
|
+
else type(obj) in t if isinstance(t, tuple) else type(obj) is t
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
|
|
53
72
|
def msgspec_convert(obj: typing.Any, t: type[T]) -> Result[T, str]:
|
|
54
73
|
try:
|
|
55
74
|
return Ok(decoder.convert(obj, type=t, strict=True))
|
|
@@ -68,72 +87,62 @@ def msgspec_to_builtins(
|
|
|
68
87
|
str_keys: bool = False,
|
|
69
88
|
builtin_types: typing.Iterable[type[typing.Any]] | None = None,
|
|
70
89
|
order: typing.Literal["deterministic", "sorted"] | None = None,
|
|
71
|
-
) -> typing.Any:
|
|
72
|
-
|
|
90
|
+
) -> fntypes.result.Result[typing.Any, msgspec.ValidationError]:
|
|
91
|
+
try:
|
|
92
|
+
return Ok(encoder.to_builtins(**locals()))
|
|
93
|
+
except msgspec.ValidationError as exc:
|
|
94
|
+
return Error(exc)
|
|
73
95
|
|
|
74
96
|
|
|
75
97
|
def option_dec_hook(tp: type[Option[typing.Any]], obj: typing.Any) -> Option[typing.Any]:
|
|
76
|
-
|
|
77
|
-
|
|
98
|
+
if obj is None:
|
|
99
|
+
return Nothing
|
|
78
100
|
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
def result_dec_hook(
|
|
85
|
-
tp: type[Result[typing.Any, typing.Any]], obj: typing.Any
|
|
86
|
-
) -> Result[typing.Any, typing.Any]:
|
|
87
|
-
if not isinstance(obj, dict):
|
|
88
|
-
raise TypeError(f"Cannot parse to Result object of type `{repr_type(type(obj))}`.")
|
|
89
|
-
|
|
90
|
-
orig_type = get_origin(tp)
|
|
91
|
-
(first_type, second_type) = (
|
|
92
|
-
typing.get_args(tp) + (typing.Any,) if len(typing.get_args(tp)) == 1 else typing.get_args(tp)
|
|
93
|
-
) or (typing.Any, typing.Any)
|
|
101
|
+
(value_type,) = typing.get_args(tp) or (typing.Any,)
|
|
102
|
+
orig_value_type = typing.get_origin(value_type) or value_type
|
|
103
|
+
orig_obj = obj
|
|
94
104
|
|
|
95
|
-
if
|
|
96
|
-
|
|
105
|
+
if not isinstance(orig_obj, dict | list) and is_common_type(orig_value_type):
|
|
106
|
+
if orig_value_type is Variative:
|
|
107
|
+
obj = value_type(orig_obj) # type: ignore
|
|
108
|
+
orig_value_type = typing.get_args(value_type)
|
|
97
109
|
|
|
98
|
-
|
|
99
|
-
|
|
110
|
+
if not type_check(orig_obj, orig_value_type):
|
|
111
|
+
raise TypeError(f"Expected `{repr_type(orig_value_type)}`, got `{repr_type(type(orig_obj))}`.")
|
|
100
112
|
|
|
101
|
-
|
|
102
|
-
match obj:
|
|
103
|
-
case {"ok": ok}:
|
|
104
|
-
return Ok(msgspec_convert(ok, first_type).unwrap())
|
|
105
|
-
case {"error": error}:
|
|
106
|
-
return Error(msgspec_convert(error, second_type).unwrap())
|
|
113
|
+
return fntypes.option.Some(obj)
|
|
107
114
|
|
|
108
|
-
|
|
115
|
+
return fntypes.option.Some(decoder.convert(orig_obj, type=value_type))
|
|
109
116
|
|
|
110
117
|
|
|
111
118
|
def variative_dec_hook(tp: type[Variative], obj: typing.Any) -> Variative:
|
|
112
119
|
union_types = typing.get_args(tp)
|
|
113
120
|
|
|
114
121
|
if isinstance(obj, dict):
|
|
115
|
-
|
|
122
|
+
models_struct_fields: dict[type[msgspec.Struct], int] = {
|
|
116
123
|
m: sum(1 for k in obj if k in m.__struct_fields__)
|
|
117
124
|
for m in union_types
|
|
118
125
|
if issubclass(get_origin(m), msgspec.Struct)
|
|
119
126
|
}
|
|
120
|
-
union_types = tuple(t for t in union_types if t not in
|
|
127
|
+
union_types = tuple(t for t in union_types if t not in models_struct_fields)
|
|
121
128
|
reverse = False
|
|
122
129
|
|
|
123
|
-
if len(set(
|
|
124
|
-
|
|
130
|
+
if len(set(models_struct_fields.values())) != len(models_struct_fields.values()):
|
|
131
|
+
models_struct_fields = {m: len(m.__struct_fields__) for m in models_struct_fields}
|
|
125
132
|
reverse = True
|
|
126
133
|
|
|
127
134
|
union_types = (
|
|
128
135
|
*sorted(
|
|
129
|
-
|
|
130
|
-
key=lambda k:
|
|
136
|
+
models_struct_fields,
|
|
137
|
+
key=lambda k: models_struct_fields[k],
|
|
131
138
|
reverse=reverse,
|
|
132
139
|
),
|
|
133
140
|
*union_types,
|
|
134
141
|
)
|
|
135
142
|
|
|
136
143
|
for t in union_types:
|
|
144
|
+
if not isinstance(obj, dict | list) and is_common_type(t) and type_check(obj, t):
|
|
145
|
+
return tp(obj)
|
|
137
146
|
match msgspec_convert(obj, t):
|
|
138
147
|
case Ok(value):
|
|
139
148
|
return tp(value)
|
|
@@ -179,12 +188,9 @@ class Decoder:
|
|
|
179
188
|
|
|
180
189
|
def __init__(self) -> None:
|
|
181
190
|
self.dec_hooks: dict[typing.Any, DecHook[typing.Any]] = {
|
|
182
|
-
Result: result_dec_hook,
|
|
183
191
|
Option: option_dec_hook,
|
|
184
192
|
Variative: variative_dec_hook,
|
|
185
193
|
datetime: lambda t, obj: t.fromtimestamp(obj),
|
|
186
|
-
fntypes.result.Error: result_dec_hook,
|
|
187
|
-
fntypes.result.Ok: result_dec_hook,
|
|
188
194
|
fntypes.option.Some: option_dec_hook,
|
|
189
195
|
fntypes.option.Nothing: option_dec_hook,
|
|
190
196
|
}
|
|
@@ -271,10 +277,6 @@ class Encoder:
|
|
|
271
277
|
self.enc_hooks: dict[typing.Any, EncHook[typing.Any]] = {
|
|
272
278
|
fntypes.option.Some: lambda opt: opt.value,
|
|
273
279
|
fntypes.option.Nothing: lambda _: None,
|
|
274
|
-
fntypes.result.Ok: lambda ok: {"ok": ok.value},
|
|
275
|
-
fntypes.result.Error: lambda err: {
|
|
276
|
-
"error": (str(err.error) if isinstance(err.error, BaseException) else err.error)
|
|
277
|
-
},
|
|
278
280
|
Variative: lambda variative: variative.v,
|
|
279
281
|
datetime: lambda date: int(date.timestamp()),
|
|
280
282
|
}
|
|
@@ -342,10 +344,8 @@ __all__ = (
|
|
|
342
344
|
"datetime",
|
|
343
345
|
"decoder",
|
|
344
346
|
"encoder",
|
|
345
|
-
"get_origin",
|
|
346
347
|
"msgspec_convert",
|
|
348
|
+
"get_class_annotations",
|
|
349
|
+
"get_type_hints",
|
|
347
350
|
"msgspec_to_builtins",
|
|
348
|
-
"option_dec_hook",
|
|
349
|
-
"repr_type",
|
|
350
|
-
"variative_dec_hook",
|
|
351
351
|
)
|
telegrinder/node/__init__.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from .attachment import Attachment, Audio, Photo, Video
|
|
2
|
-
from .base import
|
|
2
|
+
from .base import ComposeError, DataNode, Node, ScalarNode, is_node
|
|
3
3
|
from .callback_query import CallbackQueryNode
|
|
4
4
|
from .command import CommandInfo
|
|
5
5
|
from .composer import Composition, NodeCollection, NodeSession, compose_node, compose_nodes
|
|
@@ -18,7 +18,6 @@ from .update import UpdateNode
|
|
|
18
18
|
__all__ = (
|
|
19
19
|
"Attachment",
|
|
20
20
|
"Audio",
|
|
21
|
-
"BaseNode",
|
|
22
21
|
"CallbackQueryNode",
|
|
23
22
|
"ChatSource",
|
|
24
23
|
"CommandInfo",
|
|
@@ -52,7 +51,6 @@ __all__ = (
|
|
|
52
51
|
"global_node",
|
|
53
52
|
"impl",
|
|
54
53
|
"is_node",
|
|
55
|
-
"node_impl",
|
|
56
54
|
"per_call",
|
|
57
55
|
"per_event",
|
|
58
56
|
)
|
telegrinder/node/base.py
CHANGED
|
@@ -3,17 +3,10 @@ import inspect
|
|
|
3
3
|
import typing
|
|
4
4
|
from types import AsyncGeneratorType
|
|
5
5
|
|
|
6
|
-
from telegrinder.api.api import API
|
|
7
|
-
from telegrinder.bot.cute_types.update import UpdateCute
|
|
8
|
-
from telegrinder.bot.dispatch.context import Context
|
|
9
6
|
from telegrinder.node.scope import NodeScope
|
|
10
7
|
from telegrinder.tools.magic import (
|
|
11
|
-
NODE_IMPL_MARK,
|
|
12
8
|
cache_magic_value,
|
|
13
9
|
get_annotations,
|
|
14
|
-
get_impls_by_key,
|
|
15
|
-
magic_bundle,
|
|
16
|
-
node_impl,
|
|
17
10
|
)
|
|
18
11
|
|
|
19
12
|
ComposeResult: typing.TypeAlias = typing.Awaitable[typing.Any] | typing.AsyncGenerator[typing.Any, None]
|
|
@@ -29,11 +22,6 @@ def is_node(maybe_node: type[typing.Any]) -> typing.TypeGuard[type["Node"]]:
|
|
|
29
22
|
)
|
|
30
23
|
|
|
31
24
|
|
|
32
|
-
@cache_magic_value("__compose_annotations__")
|
|
33
|
-
def get_compose_annotations(function: typing.Callable[..., typing.Any]) -> dict[str, typing.Any]:
|
|
34
|
-
return {k: v for k, v in get_annotations(function).items() if not is_node(v)}
|
|
35
|
-
|
|
36
|
-
|
|
37
25
|
@cache_magic_value("__nodes__")
|
|
38
26
|
def get_nodes(function: typing.Callable[..., typing.Any]) -> dict[str, type["Node"]]:
|
|
39
27
|
return {k: v for k, v in get_annotations(function).items() if is_node(v)}
|
|
@@ -44,23 +32,17 @@ def is_generator(function: typing.Callable[..., typing.Any]) -> typing.TypeGuard
|
|
|
44
32
|
return inspect.isasyncgenfunction(function)
|
|
45
33
|
|
|
46
34
|
|
|
47
|
-
def
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
node_impls: dict[str, typing.Callable[..., typing.Any]],
|
|
59
|
-
) -> typing.Callable[..., typing.Any] | None:
|
|
60
|
-
for n_impl in node_impls.values():
|
|
61
|
-
if "return" in n_impl.__annotations__ and node is n_impl.__annotations__["return"]:
|
|
62
|
-
return n_impl
|
|
63
|
-
return None
|
|
35
|
+
def get_node_calc_lst(node: type["Node"]) -> list[type["Node"]]:
|
|
36
|
+
""" Returns flattened list of node types in ordering required to calculate given node. Provides caching for passed node type """
|
|
37
|
+
if calc_lst := getattr(node, "__nodes_calc_lst__", None):
|
|
38
|
+
return calc_lst
|
|
39
|
+
nodes_lst: list[type["Node"]] = []
|
|
40
|
+
annotations = list(node.as_node().get_subnodes().values())
|
|
41
|
+
for node_type in annotations:
|
|
42
|
+
nodes_lst.extend(get_node_calc_lst(node_type))
|
|
43
|
+
calc_lst = [*nodes_lst, node]
|
|
44
|
+
setattr(node, "__nodes_calc_lst__", calc_lst)
|
|
45
|
+
return calc_lst
|
|
64
46
|
|
|
65
47
|
|
|
66
48
|
class ComposeError(BaseException):
|
|
@@ -76,47 +58,14 @@ class Node(abc.ABC):
|
|
|
76
58
|
def compose(cls, *args, **kwargs) -> ComposeResult:
|
|
77
59
|
pass
|
|
78
60
|
|
|
79
|
-
@classmethod
|
|
80
|
-
async def compose_annotation(
|
|
81
|
-
cls,
|
|
82
|
-
annotation: typing.Any,
|
|
83
|
-
update: UpdateCute,
|
|
84
|
-
ctx: Context,
|
|
85
|
-
) -> typing.Any:
|
|
86
|
-
orig_annotation: type[typing.Any] = typing.get_origin(annotation) or annotation
|
|
87
|
-
n_impl = get_node_impl(orig_annotation, cls.get_node_impls())
|
|
88
|
-
if n_impl is None:
|
|
89
|
-
raise ComposeError(f"Node implementation for {orig_annotation!r} not found.")
|
|
90
|
-
|
|
91
|
-
result = n_impl(
|
|
92
|
-
cls,
|
|
93
|
-
**magic_bundle(
|
|
94
|
-
n_impl,
|
|
95
|
-
{"update": update, "context": ctx},
|
|
96
|
-
start_idx=0,
|
|
97
|
-
bundle_ctx=False,
|
|
98
|
-
),
|
|
99
|
-
)
|
|
100
|
-
if inspect.isawaitable(result):
|
|
101
|
-
return await result
|
|
102
|
-
return result
|
|
103
|
-
|
|
104
61
|
@classmethod
|
|
105
62
|
def compose_error(cls, error: str | None = None) -> typing.NoReturn:
|
|
106
63
|
raise ComposeError(error)
|
|
107
64
|
|
|
108
65
|
@classmethod
|
|
109
|
-
def
|
|
66
|
+
def get_subnodes(cls) -> dict[str, type["Node"]]:
|
|
110
67
|
return get_nodes(cls.compose)
|
|
111
68
|
|
|
112
|
-
@classmethod
|
|
113
|
-
def get_compose_annotations(cls) -> dict[str, typing.Any]:
|
|
114
|
-
return get_compose_annotations(cls.compose)
|
|
115
|
-
|
|
116
|
-
@classmethod
|
|
117
|
-
def get_node_impls(cls) -> dict[str, typing.Callable[..., typing.Any]]:
|
|
118
|
-
return get_node_impls(cls)
|
|
119
|
-
|
|
120
69
|
@classmethod
|
|
121
70
|
def as_node(cls) -> type[typing.Self]:
|
|
122
71
|
return cls
|
|
@@ -126,26 +75,7 @@ class Node(abc.ABC):
|
|
|
126
75
|
return is_generator(cls.compose)
|
|
127
76
|
|
|
128
77
|
|
|
129
|
-
class
|
|
130
|
-
@classmethod
|
|
131
|
-
@abc.abstractmethod
|
|
132
|
-
def compose(cls, *args, **kwargs) -> ComposeResult:
|
|
133
|
-
pass
|
|
134
|
-
|
|
135
|
-
@node_impl
|
|
136
|
-
def compose_api(cls, update: UpdateCute) -> API:
|
|
137
|
-
return update.ctx_api
|
|
138
|
-
|
|
139
|
-
@node_impl
|
|
140
|
-
def compose_context(cls, context: Context) -> Context:
|
|
141
|
-
return context
|
|
142
|
-
|
|
143
|
-
@node_impl
|
|
144
|
-
def compose_update(cls, update: UpdateCute) -> UpdateCute:
|
|
145
|
-
return update
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
class DataNode(BaseNode, abc.ABC):
|
|
78
|
+
class DataNode(Node, abc.ABC):
|
|
149
79
|
node = "data"
|
|
150
80
|
|
|
151
81
|
@typing.dataclass_transform()
|
|
@@ -155,7 +85,7 @@ class DataNode(BaseNode, abc.ABC):
|
|
|
155
85
|
pass
|
|
156
86
|
|
|
157
87
|
|
|
158
|
-
class ScalarNodeProto(
|
|
88
|
+
class ScalarNodeProto(Node, abc.ABC):
|
|
159
89
|
@classmethod
|
|
160
90
|
@abc.abstractmethod
|
|
161
91
|
async def compose(cls, *args, **kwargs) -> ComposeResult:
|
|
@@ -190,14 +120,12 @@ else:
|
|
|
190
120
|
|
|
191
121
|
|
|
192
122
|
__all__ = (
|
|
193
|
-
"BaseNode",
|
|
194
123
|
"ComposeError",
|
|
195
124
|
"DataNode",
|
|
196
125
|
"Node",
|
|
197
126
|
"SCALAR_NODE",
|
|
198
127
|
"ScalarNode",
|
|
199
128
|
"ScalarNodeProto",
|
|
200
|
-
"get_compose_annotations",
|
|
201
129
|
"get_nodes",
|
|
202
130
|
"is_node",
|
|
203
131
|
)
|
telegrinder/node/composer.py
CHANGED
|
@@ -1,105 +1,97 @@
|
|
|
1
1
|
import dataclasses
|
|
2
2
|
import typing
|
|
3
3
|
|
|
4
|
+
from fntypes import Error, Ok, Result
|
|
4
5
|
from fntypes.error import UnwrapError
|
|
5
6
|
|
|
6
|
-
from telegrinder.
|
|
7
|
+
from telegrinder.api import API
|
|
8
|
+
from telegrinder.bot.cute_types.update import Update, UpdateCute
|
|
7
9
|
from telegrinder.bot.dispatch.context import Context
|
|
8
10
|
from telegrinder.modules import logger
|
|
9
11
|
from telegrinder.node.base import (
|
|
10
|
-
BaseNode,
|
|
11
12
|
ComposeError,
|
|
12
13
|
Node,
|
|
13
14
|
NodeScope,
|
|
14
|
-
|
|
15
|
+
get_node_calc_lst,
|
|
15
16
|
get_nodes,
|
|
16
17
|
)
|
|
17
18
|
from telegrinder.tools.magic import magic_bundle
|
|
18
19
|
|
|
19
|
-
CONTEXT_STORE_NODES_KEY = "
|
|
20
|
+
CONTEXT_STORE_NODES_KEY = "_node_ctx"
|
|
21
|
+
GLOBAL_VALUE_KEY = "_value"
|
|
20
22
|
|
|
21
23
|
|
|
22
24
|
async def compose_node(
|
|
23
25
|
_node: type[Node],
|
|
24
|
-
|
|
25
|
-
ctx: Context,
|
|
26
|
+
linked: dict[type, typing.Any],
|
|
26
27
|
) -> "NodeSession":
|
|
27
28
|
node = _node.as_node()
|
|
28
|
-
|
|
29
|
-
node_ctx: dict[type[Node], "NodeSession"] = ctx.get_or_set(CONTEXT_STORE_NODES_KEY, {})
|
|
30
|
-
|
|
31
|
-
for name, subnode in node.get_sub_nodes().items():
|
|
32
|
-
if subnode in node_ctx:
|
|
33
|
-
context.sessions[name] = node_ctx[subnode]
|
|
34
|
-
else:
|
|
35
|
-
context.sessions[name] = await compose_node(subnode, update, ctx)
|
|
36
|
-
|
|
37
|
-
if getattr(subnode, "scope", None) is NodeScope.PER_EVENT:
|
|
38
|
-
node_ctx[subnode] = context.sessions[name]
|
|
39
|
-
|
|
40
|
-
for name, annotation in node.get_compose_annotations().items():
|
|
41
|
-
context.sessions[name] = NodeSession(
|
|
42
|
-
None,
|
|
43
|
-
await node.compose_annotation(annotation, update, ctx),
|
|
44
|
-
{},
|
|
45
|
-
)
|
|
29
|
+
kwargs = magic_bundle(node.compose, linked, typebundle=True)
|
|
46
30
|
|
|
47
31
|
if node.is_generator():
|
|
48
|
-
generator = typing.cast(typing.AsyncGenerator[typing.Any, None], node.compose(**
|
|
32
|
+
generator = typing.cast(typing.AsyncGenerator[typing.Any, None], node.compose(**kwargs))
|
|
49
33
|
value = await generator.asend(None)
|
|
50
34
|
else:
|
|
51
35
|
generator = None
|
|
52
|
-
value = await typing.cast(typing.Awaitable[typing.Any], node.compose(**
|
|
36
|
+
value = await typing.cast(typing.Awaitable[typing.Any], node.compose(**kwargs))
|
|
53
37
|
|
|
54
|
-
return NodeSession(_node, value,
|
|
38
|
+
return NodeSession(_node, value, {}, generator)
|
|
55
39
|
|
|
56
40
|
|
|
57
41
|
async def compose_nodes(
|
|
58
|
-
update: UpdateCute,
|
|
59
|
-
ctx: Context,
|
|
60
42
|
nodes: dict[str, type[Node]],
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
) -> "NodeCollection
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
43
|
+
ctx: Context,
|
|
44
|
+
data: dict[type, typing.Any] | None = None,
|
|
45
|
+
) -> Result["NodeCollection", ComposeError]:
|
|
46
|
+
logger.debug("Composing nodes: {!r}...", nodes)
|
|
47
|
+
|
|
48
|
+
parent_nodes: dict[type[Node], NodeSession] = {}
|
|
49
|
+
event_nodes: dict[type[Node], "NodeSession"] = ctx.get_or_set(CONTEXT_STORE_NODES_KEY, {})
|
|
50
|
+
data = {Context: ctx} | (data or {})
|
|
51
|
+
|
|
52
|
+
# Create flattened list of ordered nodes to be calculated
|
|
53
|
+
# TODO: optimize flattened list calculation via caching key = tuple of node types
|
|
54
|
+
calculation_nodes: list[list[type[Node]]] = []
|
|
55
|
+
for node_t in nodes.values():
|
|
56
|
+
calculation_nodes.append(get_node_calc_lst(node_t))
|
|
57
|
+
|
|
58
|
+
for linked_nodes in calculation_nodes:
|
|
59
|
+
local_nodes: dict[type[Node], "NodeSession"] = {}
|
|
60
|
+
for node_t in linked_nodes:
|
|
69
61
|
scope = getattr(node_t, "scope", None)
|
|
70
62
|
|
|
71
|
-
if scope is NodeScope.PER_EVENT and node_t in
|
|
72
|
-
|
|
63
|
+
if scope is NodeScope.PER_EVENT and node_t in event_nodes:
|
|
64
|
+
local_nodes[node_t] = event_nodes[node_t]
|
|
73
65
|
continue
|
|
74
|
-
elif scope is NodeScope.GLOBAL and hasattr(node_t,
|
|
75
|
-
|
|
66
|
+
elif scope is NodeScope.GLOBAL and hasattr(node_t, GLOBAL_VALUE_KEY):
|
|
67
|
+
local_nodes[node_t] = getattr(node_t, GLOBAL_VALUE_KEY)
|
|
76
68
|
continue
|
|
77
69
|
|
|
78
|
-
|
|
70
|
+
subnodes = {
|
|
71
|
+
k: session.value
|
|
72
|
+
for k, session in
|
|
73
|
+
(local_nodes | event_nodes).items()
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
try:
|
|
77
|
+
local_nodes[node_t] = await compose_node(node_t, subnodes | data)
|
|
78
|
+
except (ComposeError, UnwrapError) as exc:
|
|
79
|
+
for t, local_node in local_nodes.items():
|
|
80
|
+
if t.scope is NodeScope.PER_CALL:
|
|
81
|
+
await local_node.close()
|
|
82
|
+
return Error(ComposeError(f"Cannot compose {node_t}. Error: {exc}"))
|
|
79
83
|
|
|
80
84
|
if scope is NodeScope.PER_EVENT:
|
|
81
|
-
|
|
85
|
+
event_nodes[node_t] = local_nodes[node_t]
|
|
82
86
|
elif scope is NodeScope.GLOBAL:
|
|
83
|
-
setattr(node_t,
|
|
84
|
-
except (ComposeError, UnwrapError) as exc:
|
|
85
|
-
logger.debug(f"Composing node (name={name!r}, node_class={node_t!r}) failed with error: {str(exc)!r}")
|
|
86
|
-
await NodeCollection(node_sessions).close_all()
|
|
87
|
-
return None
|
|
87
|
+
setattr(node_t, GLOBAL_VALUE_KEY, local_nodes[node_t])
|
|
88
88
|
|
|
89
|
-
|
|
90
|
-
|
|
89
|
+
# Last node is the parent node
|
|
90
|
+
parent_node_t = linked_nodes[-1]
|
|
91
|
+
parent_nodes[parent_node_t] = local_nodes[parent_node_t]
|
|
91
92
|
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
node_sessions[name] = await node_class.compose_annotation(annotation, update, ctx)
|
|
95
|
-
except (ComposeError, UnwrapError) as exc:
|
|
96
|
-
logger.debug(
|
|
97
|
-
f"Composing context annotation (name={name!r}, annotation={annotation!r}) failed with error: {str(exc)!r}",
|
|
98
|
-
)
|
|
99
|
-
await NodeCollection(node_sessions).close_all()
|
|
100
|
-
return None
|
|
101
|
-
|
|
102
|
-
return NodeCollection(node_sessions)
|
|
93
|
+
node_sessions = {k: parent_nodes[t] for k, t in nodes.items()}
|
|
94
|
+
return Ok(NodeCollection(node_sessions))
|
|
103
95
|
|
|
104
96
|
|
|
105
97
|
@dataclasses.dataclass(slots=True, repr=False)
|
|
@@ -107,7 +99,7 @@ class NodeSession:
|
|
|
107
99
|
node_type: type[Node] | None
|
|
108
100
|
value: typing.Any
|
|
109
101
|
subnodes: dict[str, typing.Self]
|
|
110
|
-
generator: typing.AsyncGenerator[typing.Any, None] | None = None
|
|
102
|
+
generator: typing.AsyncGenerator[typing.Any, typing.Any | None] | None = None
|
|
111
103
|
|
|
112
104
|
def __repr__(self) -> str:
|
|
113
105
|
return f"<{self.__class__.__name__}: {self.value!r}" + (" ACTIVE>" if self.generator else ">")
|
|
@@ -126,6 +118,7 @@ class NodeSession:
|
|
|
126
118
|
if self.generator is None:
|
|
127
119
|
return
|
|
128
120
|
try:
|
|
121
|
+
logger.debug("Closing session for node {!r}...", self.node_type)
|
|
129
122
|
await self.generator.asend(with_value)
|
|
130
123
|
except StopAsyncIteration:
|
|
131
124
|
self.generator = None
|
|
@@ -140,6 +133,7 @@ class NodeCollection:
|
|
|
140
133
|
def __repr__(self) -> str:
|
|
141
134
|
return "<{}: sessions={!r}>".format(self.__class__.__name__, self.sessions)
|
|
142
135
|
|
|
136
|
+
@property
|
|
143
137
|
def values(self) -> dict[str, typing.Any]:
|
|
144
138
|
return {name: session.value for name, session in self.sessions.items()}
|
|
145
139
|
|
|
@@ -156,30 +150,33 @@ class NodeCollection:
|
|
|
156
150
|
class Composition:
|
|
157
151
|
func: typing.Callable[..., typing.Any]
|
|
158
152
|
is_blocking: bool
|
|
159
|
-
node_class: type[Node] = dataclasses.field(default_factory=lambda: BaseNode)
|
|
160
153
|
nodes: dict[str, type[Node]] = dataclasses.field(init=False)
|
|
161
|
-
context_annotations: dict[str, typing.Any] = dataclasses.field(init=False)
|
|
162
154
|
|
|
163
155
|
def __post_init__(self) -> None:
|
|
164
156
|
self.nodes = get_nodes(self.func)
|
|
165
|
-
self.context_annotations = get_compose_annotations(self.func)
|
|
166
157
|
|
|
167
158
|
def __repr__(self) -> str:
|
|
168
|
-
return "<{}: for function={!r} with nodes={!r}
|
|
159
|
+
return "<{}: for function={!r} with nodes={!r}>".format(
|
|
169
160
|
("blocking " if self.is_blocking else "") + self.__class__.__name__,
|
|
170
161
|
self.func.__qualname__,
|
|
171
162
|
self.nodes,
|
|
172
|
-
self.context_annotations,
|
|
173
163
|
)
|
|
174
164
|
|
|
175
|
-
async def compose_nodes(
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
165
|
+
async def compose_nodes(
|
|
166
|
+
self,
|
|
167
|
+
update: UpdateCute,
|
|
168
|
+
context: Context,
|
|
169
|
+
) -> NodeCollection | None:
|
|
170
|
+
match await compose_nodes(
|
|
179
171
|
nodes=self.nodes,
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
)
|
|
172
|
+
ctx=context,
|
|
173
|
+
data={Update: update, API: update.api},
|
|
174
|
+
):
|
|
175
|
+
case Ok(col):
|
|
176
|
+
return col
|
|
177
|
+
case Error(err):
|
|
178
|
+
logger.debug(f"Composition failed with error: {err}")
|
|
179
|
+
return None
|
|
183
180
|
|
|
184
181
|
async def __call__(self, **kwargs: typing.Any) -> typing.Any:
|
|
185
182
|
return await self.func(**magic_bundle(self.func, kwargs, start_idx=0, bundle_ctx=False)) # type: ignore
|
telegrinder/node/container.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
import typing
|
|
2
2
|
|
|
3
|
-
from telegrinder.node.base import
|
|
3
|
+
from telegrinder.node.base import Node
|
|
4
4
|
|
|
5
5
|
|
|
6
|
-
class ContainerNode(
|
|
6
|
+
class ContainerNode(Node):
|
|
7
7
|
linked_nodes: typing.ClassVar[list[type[Node]]]
|
|
8
8
|
|
|
9
9
|
@classmethod
|
|
@@ -11,7 +11,7 @@ class ContainerNode(BaseNode):
|
|
|
11
11
|
return tuple(t[1] for t in sorted(kw.items(), key=lambda t: t[0]))
|
|
12
12
|
|
|
13
13
|
@classmethod
|
|
14
|
-
def
|
|
14
|
+
def get_subnodes(cls) -> dict[str, type[Node]]:
|
|
15
15
|
return {f"node_{i}": node_t for i, node_t in enumerate(cls.linked_nodes)}
|
|
16
16
|
|
|
17
17
|
@classmethod
|