telegrinder 0.1.dev171__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of telegrinder might be problematic. Click here for more details.

Files changed (92) hide show
  1. telegrinder/__init__.py +2 -2
  2. telegrinder/api/__init__.py +1 -2
  3. telegrinder/api/api.py +3 -3
  4. telegrinder/api/token.py +36 -0
  5. telegrinder/bot/__init__.py +12 -6
  6. telegrinder/bot/bot.py +12 -5
  7. telegrinder/bot/cute_types/__init__.py +7 -7
  8. telegrinder/bot/cute_types/base.py +7 -32
  9. telegrinder/bot/cute_types/callback_query.py +5 -6
  10. telegrinder/bot/cute_types/chat_join_request.py +4 -5
  11. telegrinder/bot/cute_types/chat_member_updated.py +3 -4
  12. telegrinder/bot/cute_types/inline_query.py +3 -4
  13. telegrinder/bot/cute_types/message.py +9 -10
  14. telegrinder/bot/cute_types/update.py +8 -9
  15. telegrinder/bot/cute_types/utils.py +1 -1
  16. telegrinder/bot/dispatch/__init__.py +9 -9
  17. telegrinder/bot/dispatch/abc.py +2 -2
  18. telegrinder/bot/dispatch/context.py +11 -2
  19. telegrinder/bot/dispatch/dispatch.py +18 -33
  20. telegrinder/bot/dispatch/handler/__init__.py +3 -3
  21. telegrinder/bot/dispatch/handler/abc.py +3 -3
  22. telegrinder/bot/dispatch/handler/func.py +17 -12
  23. telegrinder/bot/dispatch/handler/message_reply.py +6 -7
  24. telegrinder/bot/dispatch/middleware/__init__.py +1 -1
  25. telegrinder/bot/dispatch/process.py +30 -11
  26. telegrinder/bot/dispatch/return_manager/__init__.py +4 -4
  27. telegrinder/bot/dispatch/return_manager/callback_query.py +1 -2
  28. telegrinder/bot/dispatch/return_manager/inline_query.py +1 -2
  29. telegrinder/bot/dispatch/return_manager/message.py +1 -2
  30. telegrinder/bot/dispatch/view/__init__.py +8 -8
  31. telegrinder/bot/dispatch/view/abc.py +9 -4
  32. telegrinder/bot/dispatch/view/box.py +2 -2
  33. telegrinder/bot/dispatch/view/callback_query.py +1 -2
  34. telegrinder/bot/dispatch/view/chat_join_request.py +1 -2
  35. telegrinder/bot/dispatch/view/chat_member.py +16 -2
  36. telegrinder/bot/dispatch/view/inline_query.py +1 -2
  37. telegrinder/bot/dispatch/view/message.py +1 -2
  38. telegrinder/bot/dispatch/view/raw.py +8 -10
  39. telegrinder/bot/dispatch/waiter_machine/__init__.py +3 -3
  40. telegrinder/bot/dispatch/waiter_machine/machine.py +10 -6
  41. telegrinder/bot/dispatch/waiter_machine/short_state.py +2 -2
  42. telegrinder/bot/polling/abc.py +1 -1
  43. telegrinder/bot/polling/polling.py +3 -3
  44. telegrinder/bot/rules/__init__.py +20 -20
  45. telegrinder/bot/rules/abc.py +50 -40
  46. telegrinder/bot/rules/adapter/__init__.py +5 -5
  47. telegrinder/bot/rules/adapter/abc.py +6 -3
  48. telegrinder/bot/rules/adapter/errors.py +2 -1
  49. telegrinder/bot/rules/adapter/event.py +27 -15
  50. telegrinder/bot/rules/adapter/node.py +28 -22
  51. telegrinder/bot/rules/adapter/raw_update.py +13 -5
  52. telegrinder/bot/rules/callback_data.py +4 -4
  53. telegrinder/bot/rules/chat_join.py +4 -4
  54. telegrinder/bot/rules/func.py +1 -1
  55. telegrinder/bot/rules/inline.py +3 -3
  56. telegrinder/bot/rules/markup.py +3 -1
  57. telegrinder/bot/rules/message_entities.py +1 -1
  58. telegrinder/bot/rules/text.py +1 -2
  59. telegrinder/bot/rules/update.py +1 -2
  60. telegrinder/bot/scenario/abc.py +2 -2
  61. telegrinder/bot/scenario/checkbox.py +1 -2
  62. telegrinder/bot/scenario/choice.py +1 -2
  63. telegrinder/model.py +6 -1
  64. telegrinder/msgspec_utils.py +55 -55
  65. telegrinder/node/__init__.py +1 -3
  66. telegrinder/node/base.py +14 -86
  67. telegrinder/node/composer.py +71 -74
  68. telegrinder/node/container.py +3 -3
  69. telegrinder/node/event.py +40 -31
  70. telegrinder/node/polymorphic.py +12 -6
  71. telegrinder/node/rule.py +1 -9
  72. telegrinder/node/scope.py +9 -1
  73. telegrinder/node/source.py +11 -0
  74. telegrinder/node/update.py +6 -2
  75. telegrinder/rules.py +59 -0
  76. telegrinder/tools/error_handler/abc.py +2 -2
  77. telegrinder/tools/error_handler/error_handler.py +5 -5
  78. telegrinder/tools/global_context/global_context.py +1 -1
  79. telegrinder/tools/keyboard.py +1 -1
  80. telegrinder/tools/loop_wrapper/loop_wrapper.py +9 -9
  81. telegrinder/tools/magic.py +64 -19
  82. telegrinder/types/__init__.py +1 -0
  83. telegrinder/types/enums.py +1 -0
  84. telegrinder/types/methods.py +78 -11
  85. telegrinder/types/objects.py +46 -24
  86. telegrinder/verification_utils.py +1 -3
  87. {telegrinder-0.1.dev171.dist-info → telegrinder-0.2.0.dist-info}/METADATA +1 -1
  88. telegrinder-0.2.0.dist-info/RECORD +145 -0
  89. telegrinder/api/abc.py +0 -79
  90. telegrinder-0.1.dev171.dist-info/RECORD +0 -145
  91. {telegrinder-0.1.dev171.dist-info → telegrinder-0.2.0.dist-info}/LICENSE +0 -0
  92. {telegrinder-0.1.dev171.dist-info → telegrinder-0.2.0.dist-info}/WHEEL +0 -0
@@ -10,10 +10,16 @@ if typing.TYPE_CHECKING:
10
10
  from datetime import datetime
11
11
 
12
12
  from fntypes.option import Option
13
- from fntypes.result import Result
13
+
14
+ def get_class_annotations(obj: typing.Any) -> dict[str, type[typing.Any]]: ...
15
+
16
+ def get_type_hints(obj: typing.Any) -> dict[str, type[typing.Any]]: ...
17
+
14
18
  else:
15
19
  from datetime import datetime as dt
16
20
 
21
+ from msgspec._utils import get_class_annotations, get_type_hints
22
+
17
23
  Value = typing.TypeVar("Value")
18
24
  Err = typing.TypeVar("Err")
19
25
 
@@ -23,16 +29,10 @@ else:
23
29
  def __instancecheck__(cls, __instance: typing.Any) -> bool:
24
30
  return isinstance(__instance, fntypes.option.Some | fntypes.option.Nothing)
25
31
 
26
- class ResultMeta(type):
27
- def __instancecheck__(cls, __instance: typing.Any) -> bool:
28
- return isinstance(__instance, fntypes.result.Ok | fntypes.result.Error)
29
32
 
30
33
  class Option(typing.Generic[Value], metaclass=OptionMeta):
31
34
  pass
32
35
 
33
- class Result(typing.Generic[Value, Err], metaclass=ResultMeta):
34
- pass
35
-
36
36
 
37
37
  T = typing.TypeVar("T")
38
38
 
@@ -46,10 +46,29 @@ def get_origin(t: type[T]) -> type[T]:
46
46
  return typing.cast(T, typing.get_origin(t)) or t
47
47
 
48
48
 
49
- def repr_type(t: type) -> str:
49
+ def repr_type(t: typing.Any) -> str:
50
50
  return getattr(t, "__name__", repr(get_origin(t)))
51
51
 
52
52
 
53
+ def is_common_type(type_: typing.Any) -> typing.TypeGuard[type[typing.Any]]:
54
+ if not isinstance(type_, type):
55
+ return False
56
+ return (
57
+ type_ in (str, int, float, bool, None, Variative)
58
+ or issubclass(type_, msgspec.Struct)
59
+ or hasattr(type_, "__dataclass_fields__")
60
+ )
61
+
62
+
63
+ def type_check(obj: typing.Any, t: typing.Any) -> bool:
64
+ return (
65
+ isinstance(obj, t)
66
+ if isinstance(t, type)
67
+ and issubclass(t, msgspec.Struct)
68
+ else type(obj) in t if isinstance(t, tuple) else type(obj) is t
69
+ )
70
+
71
+
53
72
  def msgspec_convert(obj: typing.Any, t: type[T]) -> Result[T, str]:
54
73
  try:
55
74
  return Ok(decoder.convert(obj, type=t, strict=True))
@@ -68,72 +87,62 @@ def msgspec_to_builtins(
68
87
  str_keys: bool = False,
69
88
  builtin_types: typing.Iterable[type[typing.Any]] | None = None,
70
89
  order: typing.Literal["deterministic", "sorted"] | None = None,
71
- ) -> typing.Any:
72
- return encoder.to_builtins(**locals())
90
+ ) -> fntypes.result.Result[typing.Any, msgspec.ValidationError]:
91
+ try:
92
+ return Ok(encoder.to_builtins(**locals()))
93
+ except msgspec.ValidationError as exc:
94
+ return Error(exc)
73
95
 
74
96
 
75
97
  def option_dec_hook(tp: type[Option[typing.Any]], obj: typing.Any) -> Option[typing.Any]:
76
- orig_type = get_origin(tp)
77
- (value_type,) = typing.get_args(tp) or (typing.Any,)
98
+ if obj is None:
99
+ return Nothing
78
100
 
79
- if obj is None and orig_type in (fntypes.option.Nothing, Option):
80
- return fntypes.option.Nothing()
81
- return fntypes.option.Some(msgspec_convert(obj, value_type).unwrap())
82
-
83
-
84
- def result_dec_hook(
85
- tp: type[Result[typing.Any, typing.Any]], obj: typing.Any
86
- ) -> Result[typing.Any, typing.Any]:
87
- if not isinstance(obj, dict):
88
- raise TypeError(f"Cannot parse to Result object of type `{repr_type(type(obj))}`.")
89
-
90
- orig_type = get_origin(tp)
91
- (first_type, second_type) = (
92
- typing.get_args(tp) + (typing.Any,) if len(typing.get_args(tp)) == 1 else typing.get_args(tp)
93
- ) or (typing.Any, typing.Any)
101
+ (value_type,) = typing.get_args(tp) or (typing.Any,)
102
+ orig_value_type = typing.get_origin(value_type) or value_type
103
+ orig_obj = obj
94
104
 
95
- if orig_type is Ok and "ok" in obj:
96
- return Ok(msgspec_convert(obj["ok"], first_type).unwrap())
105
+ if not isinstance(orig_obj, dict | list) and is_common_type(orig_value_type):
106
+ if orig_value_type is Variative:
107
+ obj = value_type(orig_obj) # type: ignore
108
+ orig_value_type = typing.get_args(value_type)
97
109
 
98
- if orig_type is Error and "error" in obj:
99
- return Error(msgspec_convert(obj["error"], first_type).unwrap())
110
+ if not type_check(orig_obj, orig_value_type):
111
+ raise TypeError(f"Expected `{repr_type(orig_value_type)}`, got `{repr_type(type(orig_obj))}`.")
100
112
 
101
- if orig_type is Result:
102
- match obj:
103
- case {"ok": ok}:
104
- return Ok(msgspec_convert(ok, first_type).unwrap())
105
- case {"error": error}:
106
- return Error(msgspec_convert(error, second_type).unwrap())
113
+ return fntypes.option.Some(obj)
107
114
 
108
- raise msgspec.ValidationError(f"Cannot parse object `{obj!r}` to Result.")
115
+ return fntypes.option.Some(decoder.convert(orig_obj, type=value_type))
109
116
 
110
117
 
111
118
  def variative_dec_hook(tp: type[Variative], obj: typing.Any) -> Variative:
112
119
  union_types = typing.get_args(tp)
113
120
 
114
121
  if isinstance(obj, dict):
115
- struct_fields_match_sums: dict[type[msgspec.Struct], int] = {
122
+ models_struct_fields: dict[type[msgspec.Struct], int] = {
116
123
  m: sum(1 for k in obj if k in m.__struct_fields__)
117
124
  for m in union_types
118
125
  if issubclass(get_origin(m), msgspec.Struct)
119
126
  }
120
- union_types = tuple(t for t in union_types if t not in struct_fields_match_sums)
127
+ union_types = tuple(t for t in union_types if t not in models_struct_fields)
121
128
  reverse = False
122
129
 
123
- if len(set(struct_fields_match_sums.values())) != len(struct_fields_match_sums.values()):
124
- struct_fields_match_sums = {m: len(m.__struct_fields__) for m in struct_fields_match_sums}
130
+ if len(set(models_struct_fields.values())) != len(models_struct_fields.values()):
131
+ models_struct_fields = {m: len(m.__struct_fields__) for m in models_struct_fields}
125
132
  reverse = True
126
133
 
127
134
  union_types = (
128
135
  *sorted(
129
- struct_fields_match_sums,
130
- key=lambda k: struct_fields_match_sums[k],
136
+ models_struct_fields,
137
+ key=lambda k: models_struct_fields[k],
131
138
  reverse=reverse,
132
139
  ),
133
140
  *union_types,
134
141
  )
135
142
 
136
143
  for t in union_types:
144
+ if not isinstance(obj, dict | list) and is_common_type(t) and type_check(obj, t):
145
+ return tp(obj)
137
146
  match msgspec_convert(obj, t):
138
147
  case Ok(value):
139
148
  return tp(value)
@@ -179,12 +188,9 @@ class Decoder:
179
188
 
180
189
  def __init__(self) -> None:
181
190
  self.dec_hooks: dict[typing.Any, DecHook[typing.Any]] = {
182
- Result: result_dec_hook,
183
191
  Option: option_dec_hook,
184
192
  Variative: variative_dec_hook,
185
193
  datetime: lambda t, obj: t.fromtimestamp(obj),
186
- fntypes.result.Error: result_dec_hook,
187
- fntypes.result.Ok: result_dec_hook,
188
194
  fntypes.option.Some: option_dec_hook,
189
195
  fntypes.option.Nothing: option_dec_hook,
190
196
  }
@@ -271,10 +277,6 @@ class Encoder:
271
277
  self.enc_hooks: dict[typing.Any, EncHook[typing.Any]] = {
272
278
  fntypes.option.Some: lambda opt: opt.value,
273
279
  fntypes.option.Nothing: lambda _: None,
274
- fntypes.result.Ok: lambda ok: {"ok": ok.value},
275
- fntypes.result.Error: lambda err: {
276
- "error": (str(err.error) if isinstance(err.error, BaseException) else err.error)
277
- },
278
280
  Variative: lambda variative: variative.v,
279
281
  datetime: lambda date: int(date.timestamp()),
280
282
  }
@@ -342,10 +344,8 @@ __all__ = (
342
344
  "datetime",
343
345
  "decoder",
344
346
  "encoder",
345
- "get_origin",
346
347
  "msgspec_convert",
348
+ "get_class_annotations",
349
+ "get_type_hints",
347
350
  "msgspec_to_builtins",
348
- "option_dec_hook",
349
- "repr_type",
350
- "variative_dec_hook",
351
351
  )
@@ -1,5 +1,5 @@
1
1
  from .attachment import Attachment, Audio, Photo, Video
2
- from .base import BaseNode, ComposeError, DataNode, Node, ScalarNode, is_node, node_impl
2
+ from .base import ComposeError, DataNode, Node, ScalarNode, is_node
3
3
  from .callback_query import CallbackQueryNode
4
4
  from .command import CommandInfo
5
5
  from .composer import Composition, NodeCollection, NodeSession, compose_node, compose_nodes
@@ -18,7 +18,6 @@ from .update import UpdateNode
18
18
  __all__ = (
19
19
  "Attachment",
20
20
  "Audio",
21
- "BaseNode",
22
21
  "CallbackQueryNode",
23
22
  "ChatSource",
24
23
  "CommandInfo",
@@ -52,7 +51,6 @@ __all__ = (
52
51
  "global_node",
53
52
  "impl",
54
53
  "is_node",
55
- "node_impl",
56
54
  "per_call",
57
55
  "per_event",
58
56
  )
telegrinder/node/base.py CHANGED
@@ -3,17 +3,10 @@ import inspect
3
3
  import typing
4
4
  from types import AsyncGeneratorType
5
5
 
6
- from telegrinder.api.api import API
7
- from telegrinder.bot.cute_types.update import UpdateCute
8
- from telegrinder.bot.dispatch.context import Context
9
6
  from telegrinder.node.scope import NodeScope
10
7
  from telegrinder.tools.magic import (
11
- NODE_IMPL_MARK,
12
8
  cache_magic_value,
13
9
  get_annotations,
14
- get_impls_by_key,
15
- magic_bundle,
16
- node_impl,
17
10
  )
18
11
 
19
12
  ComposeResult: typing.TypeAlias = typing.Awaitable[typing.Any] | typing.AsyncGenerator[typing.Any, None]
@@ -29,11 +22,6 @@ def is_node(maybe_node: type[typing.Any]) -> typing.TypeGuard[type["Node"]]:
29
22
  )
30
23
 
31
24
 
32
- @cache_magic_value("__compose_annotations__")
33
- def get_compose_annotations(function: typing.Callable[..., typing.Any]) -> dict[str, typing.Any]:
34
- return {k: v for k, v in get_annotations(function).items() if not is_node(v)}
35
-
36
-
37
25
  @cache_magic_value("__nodes__")
38
26
  def get_nodes(function: typing.Callable[..., typing.Any]) -> dict[str, type["Node"]]:
39
27
  return {k: v for k, v in get_annotations(function).items() if is_node(v)}
@@ -44,23 +32,17 @@ def is_generator(function: typing.Callable[..., typing.Any]) -> typing.TypeGuard
44
32
  return inspect.isasyncgenfunction(function)
45
33
 
46
34
 
47
- def get_node_impls(node_cls: type["Node"]) -> dict[str, typing.Any]:
48
- if not hasattr(node_cls, "__node_impls__"):
49
- impls = get_impls_by_key(node_cls, NODE_IMPL_MARK)
50
- if issubclass(node_cls, BaseNode):
51
- impls |= get_impls_by_key(BaseNode, NODE_IMPL_MARK)
52
- setattr(node_cls, "__node_impls__", impls)
53
- return getattr(node_cls, "__node_impls__")
54
-
55
-
56
- def get_node_impl(
57
- node: type[typing.Any],
58
- node_impls: dict[str, typing.Callable[..., typing.Any]],
59
- ) -> typing.Callable[..., typing.Any] | None:
60
- for n_impl in node_impls.values():
61
- if "return" in n_impl.__annotations__ and node is n_impl.__annotations__["return"]:
62
- return n_impl
63
- return None
35
+ def get_node_calc_lst(node: type["Node"]) -> list[type["Node"]]:
36
+ """ Returns flattened list of node types in ordering required to calculate given node. Provides caching for passed node type """
37
+ if calc_lst := getattr(node, "__nodes_calc_lst__", None):
38
+ return calc_lst
39
+ nodes_lst: list[type["Node"]] = []
40
+ annotations = list(node.as_node().get_subnodes().values())
41
+ for node_type in annotations:
42
+ nodes_lst.extend(get_node_calc_lst(node_type))
43
+ calc_lst = [*nodes_lst, node]
44
+ setattr(node, "__nodes_calc_lst__", calc_lst)
45
+ return calc_lst
64
46
 
65
47
 
66
48
  class ComposeError(BaseException):
@@ -76,47 +58,14 @@ class Node(abc.ABC):
76
58
  def compose(cls, *args, **kwargs) -> ComposeResult:
77
59
  pass
78
60
 
79
- @classmethod
80
- async def compose_annotation(
81
- cls,
82
- annotation: typing.Any,
83
- update: UpdateCute,
84
- ctx: Context,
85
- ) -> typing.Any:
86
- orig_annotation: type[typing.Any] = typing.get_origin(annotation) or annotation
87
- n_impl = get_node_impl(orig_annotation, cls.get_node_impls())
88
- if n_impl is None:
89
- raise ComposeError(f"Node implementation for {orig_annotation!r} not found.")
90
-
91
- result = n_impl(
92
- cls,
93
- **magic_bundle(
94
- n_impl,
95
- {"update": update, "context": ctx},
96
- start_idx=0,
97
- bundle_ctx=False,
98
- ),
99
- )
100
- if inspect.isawaitable(result):
101
- return await result
102
- return result
103
-
104
61
  @classmethod
105
62
  def compose_error(cls, error: str | None = None) -> typing.NoReturn:
106
63
  raise ComposeError(error)
107
64
 
108
65
  @classmethod
109
- def get_sub_nodes(cls) -> dict[str, type["Node"]]:
66
+ def get_subnodes(cls) -> dict[str, type["Node"]]:
110
67
  return get_nodes(cls.compose)
111
68
 
112
- @classmethod
113
- def get_compose_annotations(cls) -> dict[str, typing.Any]:
114
- return get_compose_annotations(cls.compose)
115
-
116
- @classmethod
117
- def get_node_impls(cls) -> dict[str, typing.Callable[..., typing.Any]]:
118
- return get_node_impls(cls)
119
-
120
69
  @classmethod
121
70
  def as_node(cls) -> type[typing.Self]:
122
71
  return cls
@@ -126,26 +75,7 @@ class Node(abc.ABC):
126
75
  return is_generator(cls.compose)
127
76
 
128
77
 
129
- class BaseNode(Node, abc.ABC):
130
- @classmethod
131
- @abc.abstractmethod
132
- def compose(cls, *args, **kwargs) -> ComposeResult:
133
- pass
134
-
135
- @node_impl
136
- def compose_api(cls, update: UpdateCute) -> API:
137
- return update.ctx_api
138
-
139
- @node_impl
140
- def compose_context(cls, context: Context) -> Context:
141
- return context
142
-
143
- @node_impl
144
- def compose_update(cls, update: UpdateCute) -> UpdateCute:
145
- return update
146
-
147
-
148
- class DataNode(BaseNode, abc.ABC):
78
+ class DataNode(Node, abc.ABC):
149
79
  node = "data"
150
80
 
151
81
  @typing.dataclass_transform()
@@ -155,7 +85,7 @@ class DataNode(BaseNode, abc.ABC):
155
85
  pass
156
86
 
157
87
 
158
- class ScalarNodeProto(BaseNode, abc.ABC):
88
+ class ScalarNodeProto(Node, abc.ABC):
159
89
  @classmethod
160
90
  @abc.abstractmethod
161
91
  async def compose(cls, *args, **kwargs) -> ComposeResult:
@@ -190,14 +120,12 @@ else:
190
120
 
191
121
 
192
122
  __all__ = (
193
- "BaseNode",
194
123
  "ComposeError",
195
124
  "DataNode",
196
125
  "Node",
197
126
  "SCALAR_NODE",
198
127
  "ScalarNode",
199
128
  "ScalarNodeProto",
200
- "get_compose_annotations",
201
129
  "get_nodes",
202
130
  "is_node",
203
131
  )
@@ -1,105 +1,97 @@
1
1
  import dataclasses
2
2
  import typing
3
3
 
4
+ from fntypes import Error, Ok, Result
4
5
  from fntypes.error import UnwrapError
5
6
 
6
- from telegrinder.bot.cute_types.update import UpdateCute
7
+ from telegrinder.api import API
8
+ from telegrinder.bot.cute_types.update import Update, UpdateCute
7
9
  from telegrinder.bot.dispatch.context import Context
8
10
  from telegrinder.modules import logger
9
11
  from telegrinder.node.base import (
10
- BaseNode,
11
12
  ComposeError,
12
13
  Node,
13
14
  NodeScope,
14
- get_compose_annotations,
15
+ get_node_calc_lst,
15
16
  get_nodes,
16
17
  )
17
18
  from telegrinder.tools.magic import magic_bundle
18
19
 
19
- CONTEXT_STORE_NODES_KEY = "node_ctx"
20
+ CONTEXT_STORE_NODES_KEY = "_node_ctx"
21
+ GLOBAL_VALUE_KEY = "_value"
20
22
 
21
23
 
22
24
  async def compose_node(
23
25
  _node: type[Node],
24
- update: UpdateCute,
25
- ctx: Context,
26
+ linked: dict[type, typing.Any],
26
27
  ) -> "NodeSession":
27
28
  node = _node.as_node()
28
- context = NodeCollection({})
29
- node_ctx: dict[type[Node], "NodeSession"] = ctx.get_or_set(CONTEXT_STORE_NODES_KEY, {})
30
-
31
- for name, subnode in node.get_sub_nodes().items():
32
- if subnode in node_ctx:
33
- context.sessions[name] = node_ctx[subnode]
34
- else:
35
- context.sessions[name] = await compose_node(subnode, update, ctx)
36
-
37
- if getattr(subnode, "scope", None) is NodeScope.PER_EVENT:
38
- node_ctx[subnode] = context.sessions[name]
39
-
40
- for name, annotation in node.get_compose_annotations().items():
41
- context.sessions[name] = NodeSession(
42
- None,
43
- await node.compose_annotation(annotation, update, ctx),
44
- {},
45
- )
29
+ kwargs = magic_bundle(node.compose, linked, typebundle=True)
46
30
 
47
31
  if node.is_generator():
48
- generator = typing.cast(typing.AsyncGenerator[typing.Any, None], node.compose(**context.values()))
32
+ generator = typing.cast(typing.AsyncGenerator[typing.Any, None], node.compose(**kwargs))
49
33
  value = await generator.asend(None)
50
34
  else:
51
35
  generator = None
52
- value = await typing.cast(typing.Awaitable[typing.Any], node.compose(**context.values()))
36
+ value = await typing.cast(typing.Awaitable[typing.Any], node.compose(**kwargs))
53
37
 
54
- return NodeSession(_node, value, context.sessions, generator)
38
+ return NodeSession(_node, value, {}, generator)
55
39
 
56
40
 
57
41
  async def compose_nodes(
58
- update: UpdateCute,
59
- ctx: Context,
60
42
  nodes: dict[str, type[Node]],
61
- node_class: type[Node] | None = None,
62
- context_annotations: dict[str, typing.Any] | None = None,
63
- ) -> "NodeCollection | None":
64
- node_sessions: dict[str, NodeSession] = {}
65
- node_ctx: dict[type[Node], "NodeSession"] = ctx.get_or_set(CONTEXT_STORE_NODES_KEY, {})
66
-
67
- try:
68
- for name, node_t in nodes.items():
43
+ ctx: Context,
44
+ data: dict[type, typing.Any] | None = None,
45
+ ) -> Result["NodeCollection", ComposeError]:
46
+ logger.debug("Composing nodes: {!r}...", nodes)
47
+
48
+ parent_nodes: dict[type[Node], NodeSession] = {}
49
+ event_nodes: dict[type[Node], "NodeSession"] = ctx.get_or_set(CONTEXT_STORE_NODES_KEY, {})
50
+ data = {Context: ctx} | (data or {})
51
+
52
+ # Create flattened list of ordered nodes to be calculated
53
+ # TODO: optimize flattened list calculation via caching key = tuple of node types
54
+ calculation_nodes: list[list[type[Node]]] = []
55
+ for node_t in nodes.values():
56
+ calculation_nodes.append(get_node_calc_lst(node_t))
57
+
58
+ for linked_nodes in calculation_nodes:
59
+ local_nodes: dict[type[Node], "NodeSession"] = {}
60
+ for node_t in linked_nodes:
69
61
  scope = getattr(node_t, "scope", None)
70
62
 
71
- if scope is NodeScope.PER_EVENT and node_t in node_ctx:
72
- node_sessions[name] = node_ctx[node_t]
63
+ if scope is NodeScope.PER_EVENT and node_t in event_nodes:
64
+ local_nodes[node_t] = event_nodes[node_t]
73
65
  continue
74
- elif scope is NodeScope.GLOBAL and hasattr(node_t, "_value"):
75
- node_sessions[name] = getattr(node_t, "_value")
66
+ elif scope is NodeScope.GLOBAL and hasattr(node_t, GLOBAL_VALUE_KEY):
67
+ local_nodes[node_t] = getattr(node_t, GLOBAL_VALUE_KEY)
76
68
  continue
77
69
 
78
- node_sessions[name] = await compose_node(node_t, update, ctx)
70
+ subnodes = {
71
+ k: session.value
72
+ for k, session in
73
+ (local_nodes | event_nodes).items()
74
+ }
75
+
76
+ try:
77
+ local_nodes[node_t] = await compose_node(node_t, subnodes | data)
78
+ except (ComposeError, UnwrapError) as exc:
79
+ for t, local_node in local_nodes.items():
80
+ if t.scope is NodeScope.PER_CALL:
81
+ await local_node.close()
82
+ return Error(ComposeError(f"Cannot compose {node_t}. Error: {exc}"))
79
83
 
80
84
  if scope is NodeScope.PER_EVENT:
81
- node_ctx[node_t] = node_sessions[name]
85
+ event_nodes[node_t] = local_nodes[node_t]
82
86
  elif scope is NodeScope.GLOBAL:
83
- setattr(node_t, "_value", node_sessions[name])
84
- except (ComposeError, UnwrapError) as exc:
85
- logger.debug(f"Composing node (name={name!r}, node_class={node_t!r}) failed with error: {str(exc)!r}")
86
- await NodeCollection(node_sessions).close_all()
87
- return None
87
+ setattr(node_t, GLOBAL_VALUE_KEY, local_nodes[node_t])
88
88
 
89
- if context_annotations:
90
- node_class = node_class or BaseNode
89
+ # Last node is the parent node
90
+ parent_node_t = linked_nodes[-1]
91
+ parent_nodes[parent_node_t] = local_nodes[parent_node_t]
91
92
 
92
- try:
93
- for name, annotation in context_annotations.items():
94
- node_sessions[name] = await node_class.compose_annotation(annotation, update, ctx)
95
- except (ComposeError, UnwrapError) as exc:
96
- logger.debug(
97
- f"Composing context annotation (name={name!r}, annotation={annotation!r}) failed with error: {str(exc)!r}",
98
- )
99
- await NodeCollection(node_sessions).close_all()
100
- return None
101
-
102
- return NodeCollection(node_sessions)
93
+ node_sessions = {k: parent_nodes[t] for k, t in nodes.items()}
94
+ return Ok(NodeCollection(node_sessions))
103
95
 
104
96
 
105
97
  @dataclasses.dataclass(slots=True, repr=False)
@@ -107,7 +99,7 @@ class NodeSession:
107
99
  node_type: type[Node] | None
108
100
  value: typing.Any
109
101
  subnodes: dict[str, typing.Self]
110
- generator: typing.AsyncGenerator[typing.Any, None] | None = None
102
+ generator: typing.AsyncGenerator[typing.Any, typing.Any | None] | None = None
111
103
 
112
104
  def __repr__(self) -> str:
113
105
  return f"<{self.__class__.__name__}: {self.value!r}" + (" ACTIVE>" if self.generator else ">")
@@ -126,6 +118,7 @@ class NodeSession:
126
118
  if self.generator is None:
127
119
  return
128
120
  try:
121
+ logger.debug("Closing session for node {!r}...", self.node_type)
129
122
  await self.generator.asend(with_value)
130
123
  except StopAsyncIteration:
131
124
  self.generator = None
@@ -140,6 +133,7 @@ class NodeCollection:
140
133
  def __repr__(self) -> str:
141
134
  return "<{}: sessions={!r}>".format(self.__class__.__name__, self.sessions)
142
135
 
136
+ @property
143
137
  def values(self) -> dict[str, typing.Any]:
144
138
  return {name: session.value for name, session in self.sessions.items()}
145
139
 
@@ -156,30 +150,33 @@ class NodeCollection:
156
150
  class Composition:
157
151
  func: typing.Callable[..., typing.Any]
158
152
  is_blocking: bool
159
- node_class: type[Node] = dataclasses.field(default_factory=lambda: BaseNode)
160
153
  nodes: dict[str, type[Node]] = dataclasses.field(init=False)
161
- context_annotations: dict[str, typing.Any] = dataclasses.field(init=False)
162
154
 
163
155
  def __post_init__(self) -> None:
164
156
  self.nodes = get_nodes(self.func)
165
- self.context_annotations = get_compose_annotations(self.func)
166
157
 
167
158
  def __repr__(self) -> str:
168
- return "<{}: for function={!r} with nodes={!r}, context_annotations={!r}>".format(
159
+ return "<{}: for function={!r} with nodes={!r}>".format(
169
160
  ("blocking " if self.is_blocking else "") + self.__class__.__name__,
170
161
  self.func.__qualname__,
171
162
  self.nodes,
172
- self.context_annotations,
173
163
  )
174
164
 
175
- async def compose_nodes(self, update: UpdateCute, context: Context) -> NodeCollection | None:
176
- return await compose_nodes(
177
- update=update,
178
- ctx=context,
165
+ async def compose_nodes(
166
+ self,
167
+ update: UpdateCute,
168
+ context: Context,
169
+ ) -> NodeCollection | None:
170
+ match await compose_nodes(
179
171
  nodes=self.nodes,
180
- node_class=self.node_class,
181
- context_annotations=self.context_annotations,
182
- )
172
+ ctx=context,
173
+ data={Update: update, API: update.api},
174
+ ):
175
+ case Ok(col):
176
+ return col
177
+ case Error(err):
178
+ logger.debug(f"Composition failed with error: {err}")
179
+ return None
183
180
 
184
181
  async def __call__(self, **kwargs: typing.Any) -> typing.Any:
185
182
  return await self.func(**magic_bundle(self.func, kwargs, start_idx=0, bundle_ctx=False)) # type: ignore
@@ -1,9 +1,9 @@
1
1
  import typing
2
2
 
3
- from telegrinder.node.base import BaseNode, Node
3
+ from telegrinder.node.base import Node
4
4
 
5
5
 
6
- class ContainerNode(BaseNode):
6
+ class ContainerNode(Node):
7
7
  linked_nodes: typing.ClassVar[list[type[Node]]]
8
8
 
9
9
  @classmethod
@@ -11,7 +11,7 @@ class ContainerNode(BaseNode):
11
11
  return tuple(t[1] for t in sorted(kw.items(), key=lambda t: t[0]))
12
12
 
13
13
  @classmethod
14
- def get_sub_nodes(cls) -> dict[str, type[Node]]:
14
+ def get_subnodes(cls) -> dict[str, type[Node]]:
15
15
  return {f"node_{i}": node_t for i, node_t in enumerate(cls.linked_nodes)}
16
16
 
17
17
  @classmethod