telegrinder 0.1.dev171__py3-none-any.whl → 0.2.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (94)
  1. telegrinder/__init__.py +2 -2
  2. telegrinder/api/__init__.py +1 -2
  3. telegrinder/api/api.py +3 -3
  4. telegrinder/api/token.py +36 -0
  5. telegrinder/bot/__init__.py +12 -6
  6. telegrinder/bot/bot.py +12 -5
  7. telegrinder/bot/cute_types/__init__.py +7 -7
  8. telegrinder/bot/cute_types/base.py +17 -43
  9. telegrinder/bot/cute_types/callback_query.py +5 -6
  10. telegrinder/bot/cute_types/chat_join_request.py +4 -5
  11. telegrinder/bot/cute_types/chat_member_updated.py +3 -4
  12. telegrinder/bot/cute_types/inline_query.py +3 -4
  13. telegrinder/bot/cute_types/message.py +9 -10
  14. telegrinder/bot/cute_types/update.py +13 -13
  15. telegrinder/bot/cute_types/utils.py +1 -1
  16. telegrinder/bot/dispatch/__init__.py +9 -9
  17. telegrinder/bot/dispatch/abc.py +2 -2
  18. telegrinder/bot/dispatch/context.py +11 -2
  19. telegrinder/bot/dispatch/dispatch.py +18 -33
  20. telegrinder/bot/dispatch/handler/__init__.py +3 -3
  21. telegrinder/bot/dispatch/handler/abc.py +3 -3
  22. telegrinder/bot/dispatch/handler/func.py +17 -12
  23. telegrinder/bot/dispatch/handler/message_reply.py +6 -7
  24. telegrinder/bot/dispatch/middleware/__init__.py +1 -1
  25. telegrinder/bot/dispatch/process.py +30 -11
  26. telegrinder/bot/dispatch/return_manager/__init__.py +4 -4
  27. telegrinder/bot/dispatch/return_manager/callback_query.py +1 -2
  28. telegrinder/bot/dispatch/return_manager/inline_query.py +1 -2
  29. telegrinder/bot/dispatch/return_manager/message.py +1 -2
  30. telegrinder/bot/dispatch/view/__init__.py +8 -8
  31. telegrinder/bot/dispatch/view/abc.py +9 -4
  32. telegrinder/bot/dispatch/view/box.py +2 -2
  33. telegrinder/bot/dispatch/view/callback_query.py +1 -2
  34. telegrinder/bot/dispatch/view/chat_join_request.py +1 -2
  35. telegrinder/bot/dispatch/view/chat_member.py +16 -2
  36. telegrinder/bot/dispatch/view/inline_query.py +1 -2
  37. telegrinder/bot/dispatch/view/message.py +1 -2
  38. telegrinder/bot/dispatch/view/raw.py +8 -10
  39. telegrinder/bot/dispatch/waiter_machine/__init__.py +3 -3
  40. telegrinder/bot/dispatch/waiter_machine/machine.py +10 -6
  41. telegrinder/bot/dispatch/waiter_machine/short_state.py +2 -2
  42. telegrinder/bot/polling/abc.py +1 -1
  43. telegrinder/bot/polling/polling.py +3 -3
  44. telegrinder/bot/rules/__init__.py +20 -20
  45. telegrinder/bot/rules/abc.py +50 -40
  46. telegrinder/bot/rules/adapter/__init__.py +5 -5
  47. telegrinder/bot/rules/adapter/abc.py +6 -3
  48. telegrinder/bot/rules/adapter/errors.py +2 -1
  49. telegrinder/bot/rules/adapter/event.py +27 -15
  50. telegrinder/bot/rules/adapter/node.py +28 -22
  51. telegrinder/bot/rules/adapter/raw_update.py +13 -5
  52. telegrinder/bot/rules/callback_data.py +12 -6
  53. telegrinder/bot/rules/chat_join.py +4 -4
  54. telegrinder/bot/rules/func.py +1 -1
  55. telegrinder/bot/rules/inline.py +3 -3
  56. telegrinder/bot/rules/markup.py +3 -1
  57. telegrinder/bot/rules/message_entities.py +1 -1
  58. telegrinder/bot/rules/text.py +1 -2
  59. telegrinder/bot/rules/update.py +1 -2
  60. telegrinder/bot/scenario/abc.py +2 -2
  61. telegrinder/bot/scenario/checkbox.py +1 -2
  62. telegrinder/bot/scenario/choice.py +1 -2
  63. telegrinder/model.py +9 -2
  64. telegrinder/msgspec_utils.py +55 -55
  65. telegrinder/node/__init__.py +1 -3
  66. telegrinder/node/base.py +20 -85
  67. telegrinder/node/command.py +3 -3
  68. telegrinder/node/composer.py +71 -74
  69. telegrinder/node/container.py +3 -3
  70. telegrinder/node/event.py +45 -33
  71. telegrinder/node/me.py +3 -4
  72. telegrinder/node/polymorphic.py +12 -6
  73. telegrinder/node/rule.py +1 -9
  74. telegrinder/node/scope.py +9 -1
  75. telegrinder/node/source.py +11 -0
  76. telegrinder/node/update.py +6 -2
  77. telegrinder/rules.py +59 -0
  78. telegrinder/tools/error_handler/abc.py +2 -2
  79. telegrinder/tools/error_handler/error_handler.py +5 -5
  80. telegrinder/tools/global_context/global_context.py +1 -1
  81. telegrinder/tools/keyboard.py +1 -1
  82. telegrinder/tools/loop_wrapper/loop_wrapper.py +9 -9
  83. telegrinder/tools/magic.py +64 -19
  84. telegrinder/types/__init__.py +1 -0
  85. telegrinder/types/enums.py +1 -0
  86. telegrinder/types/methods.py +78 -11
  87. telegrinder/types/objects.py +46 -24
  88. telegrinder/verification_utils.py +1 -3
  89. {telegrinder-0.1.dev171.dist-info → telegrinder-0.2.0.post1.dist-info}/METADATA +1 -1
  90. telegrinder-0.2.0.post1.dist-info/RECORD +145 -0
  91. telegrinder/api/abc.py +0 -79
  92. telegrinder-0.1.dev171.dist-info/RECORD +0 -145
  93. {telegrinder-0.1.dev171.dist-info → telegrinder-0.2.0.post1.dist-info}/LICENSE +0 -0
  94. {telegrinder-0.1.dev171.dist-info → telegrinder-0.2.0.post1.dist-info}/WHEEL +0 -0
telegrinder/model.py CHANGED
@@ -16,10 +16,17 @@ if typing.TYPE_CHECKING:
 
 T = typing.TypeVar("T")
 
+
+def rename_field(name: str) -> str:
+    if name.endswith("_") and name.removesuffix("_") in keyword.kwlist:
+        return name.removesuffix("_")
+    return name if not keyword.iskeyword(name) else name + "_"
+
+
 MODEL_CONFIG: typing.Final[dict[str, typing.Any]] = {
     "omit_defaults": True,
     "dict": True,
-    "rename": {kw + "_": kw for kw in keyword.kwlist},
+    "rename": rename_field,
 }
 
 
@@ -168,7 +175,7 @@ class DataConverter:
         data: Model,
         serialize: bool = True,
     ) -> str | dict[str, typing.Any]:
-        converted_dct = self(data.to_full_dict(), serialize=False)
+        converted_dct = self(data.to_dict(), serialize=False)
         return encoder.encode(converted_dct) if serialize is True else converted_dct
 
     def convert_dct(
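
Note: `rename_field` generalizes the old static `{kw + "_": kw}` rename map. msgspec invokes the rename hook per field name, so keyword-escaped attributes such as `from_` serialize to their bare wire names, while raw keywords gain a trailing underscore. A minimal standalone sketch of the behavior (the assertions are illustrative, not from the package):

    import keyword

    def rename_field(name: str) -> str:
        # Strip the escape underscore from keyword-escaped attribute names.
        if name.endswith("_") and name.removesuffix("_") in keyword.kwlist:
            return name.removesuffix("_")
        # Escape raw keywords; leave ordinary names untouched.
        return name if not keyword.iskeyword(name) else name + "_"

    assert rename_field("from_") == "from"       # escaped keyword -> wire name
    assert rename_field("for") == "for_"         # raw keyword -> safe attribute name
    assert rename_field("chat_id") == "chat_id"  # ordinary field unchanged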
telegrinder/msgspec_utils.py CHANGED
@@ -10,10 +10,16 @@ if typing.TYPE_CHECKING:
     from datetime import datetime
 
     from fntypes.option import Option
-    from fntypes.result import Result
+
+    def get_class_annotations(obj: typing.Any) -> dict[str, type[typing.Any]]: ...
+
+    def get_type_hints(obj: typing.Any) -> dict[str, type[typing.Any]]: ...
+
 else:
     from datetime import datetime as dt
 
+    from msgspec._utils import get_class_annotations, get_type_hints
+
     Value = typing.TypeVar("Value")
     Err = typing.TypeVar("Err")
 
@@ -23,16 +29,10 @@ else:
         def __instancecheck__(cls, __instance: typing.Any) -> bool:
             return isinstance(__instance, fntypes.option.Some | fntypes.option.Nothing)
 
-    class ResultMeta(type):
-        def __instancecheck__(cls, __instance: typing.Any) -> bool:
-            return isinstance(__instance, fntypes.result.Ok | fntypes.result.Error)
 
     class Option(typing.Generic[Value], metaclass=OptionMeta):
         pass
 
-    class Result(typing.Generic[Value, Err], metaclass=ResultMeta):
-        pass
-
 
 T = typing.TypeVar("T")
 
@@ -46,10 +46,29 @@ def get_origin(t: type[T]) -> type[T]:
     return typing.cast(T, typing.get_origin(t)) or t
 
 
-def repr_type(t: type) -> str:
+def repr_type(t: typing.Any) -> str:
     return getattr(t, "__name__", repr(get_origin(t)))
 
 
+def is_common_type(type_: typing.Any) -> typing.TypeGuard[type[typing.Any]]:
+    if not isinstance(type_, type):
+        return False
+    return (
+        type_ in (str, int, float, bool, None, Variative)
+        or issubclass(type_, msgspec.Struct)
+        or hasattr(type_, "__dataclass_fields__")
+    )
+
+
+def type_check(obj: typing.Any, t: typing.Any) -> bool:
+    return (
+        isinstance(obj, t)
+        if isinstance(t, type)
+        and issubclass(t, msgspec.Struct)
+        else type(obj) in t if isinstance(t, tuple) else type(obj) is t
+    )
+
+
 def msgspec_convert(obj: typing.Any, t: type[T]) -> Result[T, str]:
     try:
         return Ok(decoder.convert(obj, type=t, strict=True))
@@ -68,72 +87,62 @@ def msgspec_to_builtins(
     str_keys: bool = False,
     builtin_types: typing.Iterable[type[typing.Any]] | None = None,
     order: typing.Literal["deterministic", "sorted"] | None = None,
-) -> typing.Any:
-    return encoder.to_builtins(**locals())
+) -> fntypes.result.Result[typing.Any, msgspec.ValidationError]:
+    try:
+        return Ok(encoder.to_builtins(**locals()))
+    except msgspec.ValidationError as exc:
+        return Error(exc)
 
 
 def option_dec_hook(tp: type[Option[typing.Any]], obj: typing.Any) -> Option[typing.Any]:
-    orig_type = get_origin(tp)
-    (value_type,) = typing.get_args(tp) or (typing.Any,)
+    if obj is None:
+        return Nothing
 
-    if obj is None and orig_type in (fntypes.option.Nothing, Option):
-        return fntypes.option.Nothing()
-    return fntypes.option.Some(msgspec_convert(obj, value_type).unwrap())
-
-
-def result_dec_hook(
-    tp: type[Result[typing.Any, typing.Any]], obj: typing.Any
-) -> Result[typing.Any, typing.Any]:
-    if not isinstance(obj, dict):
-        raise TypeError(f"Cannot parse to Result object of type `{repr_type(type(obj))}`.")
-
-    orig_type = get_origin(tp)
-    (first_type, second_type) = (
-        typing.get_args(tp) + (typing.Any,) if len(typing.get_args(tp)) == 1 else typing.get_args(tp)
-    ) or (typing.Any, typing.Any)
+    (value_type,) = typing.get_args(tp) or (typing.Any,)
+    orig_value_type = typing.get_origin(value_type) or value_type
+    orig_obj = obj
 
-    if orig_type is Ok and "ok" in obj:
-        return Ok(msgspec_convert(obj["ok"], first_type).unwrap())
+    if not isinstance(orig_obj, dict | list) and is_common_type(orig_value_type):
+        if orig_value_type is Variative:
+            obj = value_type(orig_obj)  # type: ignore
+            orig_value_type = typing.get_args(value_type)
 
-    if orig_type is Error and "error" in obj:
-        return Error(msgspec_convert(obj["error"], first_type).unwrap())
+        if not type_check(orig_obj, orig_value_type):
+            raise TypeError(f"Expected `{repr_type(orig_value_type)}`, got `{repr_type(type(orig_obj))}`.")
 
-    if orig_type is Result:
-        match obj:
-            case {"ok": ok}:
-                return Ok(msgspec_convert(ok, first_type).unwrap())
-            case {"error": error}:
-                return Error(msgspec_convert(error, second_type).unwrap())
+        return fntypes.option.Some(obj)
 
-    raise msgspec.ValidationError(f"Cannot parse object `{obj!r}` to Result.")
+    return fntypes.option.Some(decoder.convert(orig_obj, type=value_type))
 
 
 def variative_dec_hook(tp: type[Variative], obj: typing.Any) -> Variative:
     union_types = typing.get_args(tp)
 
     if isinstance(obj, dict):
-        struct_fields_match_sums: dict[type[msgspec.Struct], int] = {
+        models_struct_fields: dict[type[msgspec.Struct], int] = {
            m: sum(1 for k in obj if k in m.__struct_fields__)
            for m in union_types
            if issubclass(get_origin(m), msgspec.Struct)
        }
-        union_types = tuple(t for t in union_types if t not in struct_fields_match_sums)
+        union_types = tuple(t for t in union_types if t not in models_struct_fields)
        reverse = False
 
-        if len(set(struct_fields_match_sums.values())) != len(struct_fields_match_sums.values()):
-            struct_fields_match_sums = {m: len(m.__struct_fields__) for m in struct_fields_match_sums}
+        if len(set(models_struct_fields.values())) != len(models_struct_fields.values()):
+            models_struct_fields = {m: len(m.__struct_fields__) for m in models_struct_fields}
            reverse = True
 
        union_types = (
            *sorted(
-                struct_fields_match_sums,
-                key=lambda k: struct_fields_match_sums[k],
+                models_struct_fields,
+                key=lambda k: models_struct_fields[k],
                reverse=reverse,
            ),
            *union_types,
        )
 
     for t in union_types:
+        if not isinstance(obj, dict | list) and is_common_type(t) and type_check(obj, t):
+            return tp(obj)
         match msgspec_convert(obj, t):
            case Ok(value):
                return tp(value)
@@ -179,12 +188,9 @@ class Decoder:
 
     def __init__(self) -> None:
         self.dec_hooks: dict[typing.Any, DecHook[typing.Any]] = {
-            Result: result_dec_hook,
            Option: option_dec_hook,
            Variative: variative_dec_hook,
            datetime: lambda t, obj: t.fromtimestamp(obj),
-            fntypes.result.Error: result_dec_hook,
-            fntypes.result.Ok: result_dec_hook,
            fntypes.option.Some: option_dec_hook,
            fntypes.option.Nothing: option_dec_hook,
        }
@@ -271,10 +277,6 @@ class Encoder:
         self.enc_hooks: dict[typing.Any, EncHook[typing.Any]] = {
            fntypes.option.Some: lambda opt: opt.value,
            fntypes.option.Nothing: lambda _: None,
-            fntypes.result.Ok: lambda ok: {"ok": ok.value},
-            fntypes.result.Error: lambda err: {
-                "error": (str(err.error) if isinstance(err.error, BaseException) else err.error)
-            },
            Variative: lambda variative: variative.v,
            datetime: lambda date: int(date.timestamp()),
        }
@@ -342,10 +344,8 @@ __all__ = (
     "datetime",
     "decoder",
     "encoder",
-    "get_origin",
     "msgspec_convert",
+    "get_class_annotations",
+    "get_type_hints",
     "msgspec_to_builtins",
-    "option_dec_hook",
-    "repr_type",
-    "variative_dec_hook",
 )
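
Note: the removed `Result` decoding hooks are superseded by the new `is_common_type`/`type_check` helpers, which let `option_dec_hook` and `variative_dec_hook` accept plain scalar values without a full msgspec conversion. `type_check` matches by exact type rather than `isinstance` (except for `msgspec.Struct` subclasses). A simplified standalone sketch of that exact-type rule, with the Struct branch omitted:

    def type_check(obj, t):
        # Tuple of types: membership by exact type; single type: identity check.
        return type(obj) in t if isinstance(t, tuple) else type(obj) is t

    assert type_check(1, int)           # exact type matches
    assert not type_check(True, int)    # bool is rejected, unlike isinstance(True, int)
    assert type_check("x", (str, int))  # tuple membership, still by exact type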
telegrinder/node/__init__.py CHANGED
@@ -1,5 +1,5 @@
 from .attachment import Attachment, Audio, Photo, Video
-from .base import BaseNode, ComposeError, DataNode, Node, ScalarNode, is_node, node_impl
+from .base import ComposeError, DataNode, Node, ScalarNode, is_node
 from .callback_query import CallbackQueryNode
 from .command import CommandInfo
 from .composer import Composition, NodeCollection, NodeSession, compose_node, compose_nodes
@@ -18,7 +18,6 @@ from .update import UpdateNode
 __all__ = (
     "Attachment",
     "Audio",
-    "BaseNode",
     "CallbackQueryNode",
     "ChatSource",
     "CommandInfo",
@@ -52,7 +51,6 @@ __all__ = (
     "global_node",
     "impl",
     "is_node",
-    "node_impl",
     "per_call",
     "per_event",
 )
telegrinder/node/base.py CHANGED
@@ -3,17 +3,10 @@ import inspect
 import typing
 from types import AsyncGeneratorType
 
-from telegrinder.api.api import API
-from telegrinder.bot.cute_types.update import UpdateCute
-from telegrinder.bot.dispatch.context import Context
 from telegrinder.node.scope import NodeScope
 from telegrinder.tools.magic import (
-    NODE_IMPL_MARK,
     cache_magic_value,
     get_annotations,
-    get_impls_by_key,
-    magic_bundle,
-    node_impl,
 )
 
 ComposeResult: typing.TypeAlias = typing.Awaitable[typing.Any] | typing.AsyncGenerator[typing.Any, None]
@@ -29,11 +22,6 @@ def is_node(maybe_node: type[typing.Any]) -> typing.TypeGuard[type["Node"]]:
     )
 
 
-@cache_magic_value("__compose_annotations__")
-def get_compose_annotations(function: typing.Callable[..., typing.Any]) -> dict[str, typing.Any]:
-    return {k: v for k, v in get_annotations(function).items() if not is_node(v)}
-
-
 @cache_magic_value("__nodes__")
 def get_nodes(function: typing.Callable[..., typing.Any]) -> dict[str, type["Node"]]:
     return {k: v for k, v in get_annotations(function).items() if is_node(v)}
@@ -44,23 +32,19 @@ def is_generator(function: typing.Callable[..., typing.Any]) -> typing.TypeGuard
     return inspect.isasyncgenfunction(function)
 
 
-def get_node_impls(node_cls: type["Node"]) -> dict[str, typing.Any]:
-    if not hasattr(node_cls, "__node_impls__"):
-        impls = get_impls_by_key(node_cls, NODE_IMPL_MARK)
-        if issubclass(node_cls, BaseNode):
-            impls |= get_impls_by_key(BaseNode, NODE_IMPL_MARK)
-        setattr(node_cls, "__node_impls__", impls)
-    return getattr(node_cls, "__node_impls__")
-
+def get_node_calc_lst(node: type["Node"]) -> list[type["Node"]]:
+    """Returns flattened list of node types in ordering required to calculate given node.
+    Provides caching for passed node type"""
 
-def get_node_impl(
-    node: type[typing.Any],
-    node_impls: dict[str, typing.Callable[..., typing.Any]],
-) -> typing.Callable[..., typing.Any] | None:
-    for n_impl in node_impls.values():
-        if "return" in n_impl.__annotations__ and node is n_impl.__annotations__["return"]:
-            return n_impl
-    return None
+    if calc_lst := getattr(node, "__nodes_calc_lst__", None):
+        return calc_lst
+    nodes_lst: list[type["Node"]] = []
+    annotations = list(node.as_node().get_subnodes().values())
+    for node_type in annotations:
+        nodes_lst.extend(get_node_calc_lst(node_type))
+    calc_lst = [*nodes_lst, node]
+    setattr(node, "__nodes_calc_lst__", calc_lst)
+    return calc_lst
 
 
 class ComposeError(BaseException):
@@ -76,47 +60,14 @@ class Node(abc.ABC):
     def compose(cls, *args, **kwargs) -> ComposeResult:
         pass
 
-    @classmethod
-    async def compose_annotation(
-        cls,
-        annotation: typing.Any,
-        update: UpdateCute,
-        ctx: Context,
-    ) -> typing.Any:
-        orig_annotation: type[typing.Any] = typing.get_origin(annotation) or annotation
-        n_impl = get_node_impl(orig_annotation, cls.get_node_impls())
-        if n_impl is None:
-            raise ComposeError(f"Node implementation for {orig_annotation!r} not found.")
-
-        result = n_impl(
-            cls,
-            **magic_bundle(
-                n_impl,
-                {"update": update, "context": ctx},
-                start_idx=0,
-                bundle_ctx=False,
-            ),
-        )
-        if inspect.isawaitable(result):
-            return await result
-        return result
-
     @classmethod
     def compose_error(cls, error: str | None = None) -> typing.NoReturn:
         raise ComposeError(error)
 
     @classmethod
-    def get_sub_nodes(cls) -> dict[str, type["Node"]]:
+    def get_subnodes(cls) -> dict[str, type["Node"]]:
         return get_nodes(cls.compose)
 
-    @classmethod
-    def get_compose_annotations(cls) -> dict[str, typing.Any]:
-        return get_compose_annotations(cls.compose)
-
-    @classmethod
-    def get_node_impls(cls) -> dict[str, typing.Callable[..., typing.Any]]:
-        return get_node_impls(cls)
-
     @classmethod
     def as_node(cls) -> type[typing.Self]:
         return cls
@@ -126,26 +77,7 @@
         return is_generator(cls.compose)
 
 
-class BaseNode(Node, abc.ABC):
-    @classmethod
-    @abc.abstractmethod
-    def compose(cls, *args, **kwargs) -> ComposeResult:
-        pass
-
-    @node_impl
-    def compose_api(cls, update: UpdateCute) -> API:
-        return update.ctx_api
-
-    @node_impl
-    def compose_context(cls, context: Context) -> Context:
-        return context
-
-    @node_impl
-    def compose_update(cls, update: UpdateCute) -> UpdateCute:
-        return update
-
-
-class DataNode(BaseNode, abc.ABC):
+class DataNode(Node, abc.ABC):
     node = "data"
 
     @typing.dataclass_transform()
@@ -155,7 +87,7 @@ class DataNode(BaseNode, abc.ABC):
         pass
 
 
-class ScalarNodeProto(BaseNode, abc.ABC):
+class ScalarNodeProto(Node, abc.ABC):
     @classmethod
     @abc.abstractmethod
     async def compose(cls, *args, **kwargs) -> ComposeResult:
@@ -171,6 +103,10 @@ if typing.TYPE_CHECKING:
         pass
 
 else:
+    def __init_subclass__(cls, *args, **kwargs):  # noqa: N807
+        if any(issubclass(base, ScalarNode) for base in cls.__bases__ if base is not ScalarNode):
+            raise RuntimeError("Scalar nodes do not support inheritance.")
+
     def create_node(cls, bases, dct):
         dct.update(cls.__dict__)
         return type(cls.__name__, bases, dct)
@@ -182,6 +118,7 @@ else:
         {
            "as_node": classmethod(lambda cls: create_node(cls, bases, dct)),
            "scope": Node.scope,
+            "__init_subclass__": __init_subclass__,
        },
     )
 
@@ -190,14 +127,12 @@
 
 
 __all__ = (
-    "BaseNode",
     "ComposeError",
     "DataNode",
     "Node",
     "SCALAR_NODE",
     "ScalarNode",
     "ScalarNodeProto",
-    "get_compose_annotations",
     "get_nodes",
     "is_node",
 )
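
Note: `get_node_calc_lst` replaces the `node_impl` lookup machinery with a flattened, post-order dependency list cached on each node type via `__nodes_calc_lst__`. A self-contained sketch of the same flattening algorithm over a plain dependency mapping (the node names and the `DEPS` table are hypothetical; telegrinder derives the dependencies from the annotations of each node's `compose`):

    # Post-order flattening with memoization, mirroring get_node_calc_lst:
    # every dependency precedes its dependents, and the requested node comes last.
    DEPS: dict[str, list[str]] = {"Text": ["Message"], "Message": ["Update"], "Update": []}
    _cache: dict[str, list[str]] = {}

    def calc_list(node: str) -> list[str]:
        if node in _cache:
            return _cache[node]
        flat: list[str] = []
        for dep in DEPS[node]:
            flat.extend(calc_list(dep))
        _cache[node] = [*flat, node]
        return _cache[node]

    assert calc_list("Text") == ["Update", "Message", "Text"]

As in the original, a dependency shared by several subnodes may appear more than once in the list; the composer's PER_EVENT and GLOBAL scope caches keep such nodes from being recomputed.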
telegrinder/node/command.py CHANGED
@@ -1,10 +1,10 @@
 import typing
 from dataclasses import dataclass, field
 
-from fntypes import Nothing, Option, Some
+from fntypes.option import Nothing, Option, Some
 
-from .base import DataNode
-from .text import Text
+from telegrinder.node.base import DataNode
+from telegrinder.node.text import Text
 
 
 def single_split(s: str, separator: str) -> tuple[str, str]:
telegrinder/node/composer.py CHANGED
@@ -1,105 +1,97 @@
 import dataclasses
 import typing
 
+from fntypes import Error, Ok, Result
 from fntypes.error import UnwrapError
 
-from telegrinder.bot.cute_types.update import UpdateCute
+from telegrinder.api import API
+from telegrinder.bot.cute_types.update import Update, UpdateCute
 from telegrinder.bot.dispatch.context import Context
 from telegrinder.modules import logger
 from telegrinder.node.base import (
-    BaseNode,
     ComposeError,
     Node,
     NodeScope,
-    get_compose_annotations,
+    get_node_calc_lst,
     get_nodes,
 )
 from telegrinder.tools.magic import magic_bundle
 
-CONTEXT_STORE_NODES_KEY = "node_ctx"
+CONTEXT_STORE_NODES_KEY = "_node_ctx"
+GLOBAL_VALUE_KEY = "_value"
 
 
 async def compose_node(
     _node: type[Node],
-    update: UpdateCute,
-    ctx: Context,
+    linked: dict[type, typing.Any],
 ) -> "NodeSession":
     node = _node.as_node()
-    context = NodeCollection({})
-    node_ctx: dict[type[Node], "NodeSession"] = ctx.get_or_set(CONTEXT_STORE_NODES_KEY, {})
-
-    for name, subnode in node.get_sub_nodes().items():
-        if subnode in node_ctx:
-            context.sessions[name] = node_ctx[subnode]
-        else:
-            context.sessions[name] = await compose_node(subnode, update, ctx)
-
-        if getattr(subnode, "scope", None) is NodeScope.PER_EVENT:
-            node_ctx[subnode] = context.sessions[name]
-
-    for name, annotation in node.get_compose_annotations().items():
-        context.sessions[name] = NodeSession(
-            None,
-            await node.compose_annotation(annotation, update, ctx),
-            {},
-        )
+    kwargs = magic_bundle(node.compose, linked, typebundle=True)
 
     if node.is_generator():
-        generator = typing.cast(typing.AsyncGenerator[typing.Any, None], node.compose(**context.values()))
+        generator = typing.cast(typing.AsyncGenerator[typing.Any, None], node.compose(**kwargs))
         value = await generator.asend(None)
     else:
         generator = None
-        value = await typing.cast(typing.Awaitable[typing.Any], node.compose(**context.values()))
+        value = await typing.cast(typing.Awaitable[typing.Any], node.compose(**kwargs))
 
-    return NodeSession(_node, value, context.sessions, generator)
+    return NodeSession(_node, value, {}, generator)
 
 
 async def compose_nodes(
-    update: UpdateCute,
-    ctx: Context,
     nodes: dict[str, type[Node]],
-    node_class: type[Node] | None = None,
-    context_annotations: dict[str, typing.Any] | None = None,
-) -> "NodeCollection | None":
-    node_sessions: dict[str, NodeSession] = {}
-    node_ctx: dict[type[Node], "NodeSession"] = ctx.get_or_set(CONTEXT_STORE_NODES_KEY, {})
-
-    try:
-        for name, node_t in nodes.items():
+    ctx: Context,
+    data: dict[type, typing.Any] | None = None,
+) -> Result["NodeCollection", ComposeError]:
+    logger.debug("Composing nodes: {!r}...", nodes)
+
+    parent_nodes: dict[type[Node], NodeSession] = {}
+    event_nodes: dict[type[Node], NodeSession] = ctx.get_or_set(CONTEXT_STORE_NODES_KEY, {})
+    data = {Context: ctx} | (data or {})
+
+    # Create flattened list of ordered nodes to be calculated
+    # TODO: optimize flattened list calculation via caching key = tuple of node types
+    calculation_nodes: list[list[type[Node]]] = []
+    for node_t in nodes.values():
+        calculation_nodes.append(get_node_calc_lst(node_t))
+
+    for linked_nodes in calculation_nodes:
+        local_nodes: dict[type[Node], "NodeSession"] = {}
+        for node_t in linked_nodes:
            scope = getattr(node_t, "scope", None)
 
-            if scope is NodeScope.PER_EVENT and node_t in node_ctx:
-                node_sessions[name] = node_ctx[node_t]
+            if scope is NodeScope.PER_EVENT and node_t in event_nodes:
+                local_nodes[node_t] = event_nodes[node_t]
                continue
-            elif scope is NodeScope.GLOBAL and hasattr(node_t, "_value"):
-                node_sessions[name] = getattr(node_t, "_value")
+            elif scope is NodeScope.GLOBAL and hasattr(node_t, GLOBAL_VALUE_KEY):
+                local_nodes[node_t] = getattr(node_t, GLOBAL_VALUE_KEY)
                continue
 
-            node_sessions[name] = await compose_node(node_t, update, ctx)
+            subnodes = {
+                k: session.value
+                for k, session in
+                (local_nodes | event_nodes).items()
+            }
+
+            try:
+                local_nodes[node_t] = await compose_node(node_t, subnodes | data)
+            except (ComposeError, UnwrapError) as exc:
+                for t, local_node in local_nodes.items():
+                    if t.scope is NodeScope.PER_CALL:
+                        await local_node.close()
+                return Error(ComposeError(f"Cannot compose {node_t}. Error: {exc}"))
 
            if scope is NodeScope.PER_EVENT:
-                node_ctx[node_t] = node_sessions[name]
+                event_nodes[node_t] = local_nodes[node_t]
            elif scope is NodeScope.GLOBAL:
-                setattr(node_t, "_value", node_sessions[name])
-    except (ComposeError, UnwrapError) as exc:
-        logger.debug(f"Composing node (name={name!r}, node_class={node_t!r}) failed with error: {str(exc)!r}")
-        await NodeCollection(node_sessions).close_all()
-        return None
+                setattr(node_t, GLOBAL_VALUE_KEY, local_nodes[node_t])
 
-    if context_annotations:
-        node_class = node_class or BaseNode
+        # Last node is the parent node
+        parent_node_t = linked_nodes[-1]
+        parent_nodes[parent_node_t] = local_nodes[parent_node_t]
 
-    try:
-        for name, annotation in context_annotations.items():
-            node_sessions[name] = await node_class.compose_annotation(annotation, update, ctx)
-    except (ComposeError, UnwrapError) as exc:
-        logger.debug(
-            f"Composing context annotation (name={name!r}, annotation={annotation!r}) failed with error: {str(exc)!r}",
-        )
-        await NodeCollection(node_sessions).close_all()
-        return None
-
-    return NodeCollection(node_sessions)
+    node_sessions = {k: parent_nodes[t] for k, t in nodes.items()}
+    return Ok(NodeCollection(node_sessions))
 
 
 @dataclasses.dataclass(slots=True, repr=False)
@@ -107,7 +99,7 @@ class NodeSession:
     node_type: type[Node] | None
     value: typing.Any
     subnodes: dict[str, typing.Self]
-    generator: typing.AsyncGenerator[typing.Any, None] | None = None
+    generator: typing.AsyncGenerator[typing.Any, typing.Any | None] | None = None
 
     def __repr__(self) -> str:
         return f"<{self.__class__.__name__}: {self.value!r}" + (" ACTIVE>" if self.generator else ">")
@@ -126,6 +118,7 @@
         if self.generator is None:
            return
         try:
+            logger.debug("Closing session for node {!r}...", self.node_type)
            await self.generator.asend(with_value)
         except StopAsyncIteration:
            self.generator = None
@@ -140,6 +133,7 @@ class NodeCollection:
     def __repr__(self) -> str:
         return "<{}: sessions={!r}>".format(self.__class__.__name__, self.sessions)
 
+    @property
     def values(self) -> dict[str, typing.Any]:
         return {name: session.value for name, session in self.sessions.items()}
 
@@ -156,30 +150,33 @@ class NodeCollection:
 class Composition:
     func: typing.Callable[..., typing.Any]
     is_blocking: bool
-    node_class: type[Node] = dataclasses.field(default_factory=lambda: BaseNode)
     nodes: dict[str, type[Node]] = dataclasses.field(init=False)
-    context_annotations: dict[str, typing.Any] = dataclasses.field(init=False)
 
     def __post_init__(self) -> None:
         self.nodes = get_nodes(self.func)
-        self.context_annotations = get_compose_annotations(self.func)
 
     def __repr__(self) -> str:
-        return "<{}: for function={!r} with nodes={!r}, context_annotations={!r}>".format(
+        return "<{}: for function={!r} with nodes={!r}>".format(
            ("blocking " if self.is_blocking else "") + self.__class__.__name__,
            self.func.__qualname__,
            self.nodes,
-            self.context_annotations,
        )
 
-    async def compose_nodes(self, update: UpdateCute, context: Context) -> NodeCollection | None:
-        return await compose_nodes(
-            update=update,
-            ctx=context,
+    async def compose_nodes(
+        self,
+        update: UpdateCute,
+        context: Context,
+    ) -> NodeCollection | None:
+        match await compose_nodes(
            nodes=self.nodes,
-            node_class=self.node_class,
-            context_annotations=self.context_annotations,
-        )
+            ctx=context,
+            data={Update: update, API: update.api},
+        ):
+            case Ok(col):
+                return col
+            case Error(err):
+                logger.debug(f"Composition failed with error: {err}")
+                return None
 
     async def __call__(self, **kwargs: typing.Any) -> typing.Any:
         return await self.func(**magic_bundle(self.func, kwargs, start_idx=0, bundle_ctx=False))  # type: ignore
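
Note: `compose_nodes` now returns `Result[NodeCollection, ComposeError]` instead of `NodeCollection | None`, so callers pattern-match on `Ok`/`Error` the way `Composition.compose_nodes` does above. A hedged usage sketch (the `Text` node and the `update`/`ctx` objects are illustrative, not a prescribed API):

    from fntypes import Error, Ok

    from telegrinder.api import API
    from telegrinder.bot.cute_types.update import Update
    from telegrinder.node.composer import compose_nodes
    from telegrinder.node.text import Text

    async def handle(update, ctx) -> None:
        match await compose_nodes(
            nodes={"text": Text},
            ctx=ctx,
            data={Update: update, API: update.api},
        ):
            case Ok(collection):
                # `values` is a property as of this release, not a method.
                print(collection.values["text"])
            case Error(err):
                print(f"Composition failed: {err}")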