reflex 0.8.7a1__py3-none-any.whl → 0.8.8a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflex might be problematic; see the package's registry page for more details.

Files changed (67)
  1. reflex/app.py +13 -5
  2. reflex/app_mixins/lifespan.py +8 -2
  3. reflex/compiler/compiler.py +12 -12
  4. reflex/compiler/templates.py +629 -102
  5. reflex/compiler/utils.py +29 -20
  6. reflex/components/base/bare.py +17 -0
  7. reflex/components/component.py +37 -33
  8. reflex/components/core/cond.py +6 -12
  9. reflex/components/core/foreach.py +1 -1
  10. reflex/components/core/match.py +83 -60
  11. reflex/components/dynamic.py +3 -3
  12. reflex/components/el/elements/forms.py +31 -14
  13. reflex/components/el/elements/forms.pyi +0 -5
  14. reflex/components/lucide/icon.py +2 -1
  15. reflex/components/lucide/icon.pyi +2 -1
  16. reflex/components/markdown/markdown.py +2 -2
  17. reflex/components/radix/primitives/accordion.py +1 -1
  18. reflex/components/radix/primitives/drawer.py +1 -1
  19. reflex/components/radix/primitives/form.py +1 -1
  20. reflex/components/radix/primitives/slider.py +1 -1
  21. reflex/components/tags/cond_tag.py +14 -5
  22. reflex/components/tags/iter_tag.py +0 -26
  23. reflex/components/tags/match_tag.py +15 -6
  24. reflex/components/tags/tag.py +3 -6
  25. reflex/components/tags/tagless.py +14 -0
  26. reflex/constants/base.py +0 -2
  27. reflex/constants/installer.py +4 -4
  28. reflex/custom_components/custom_components.py +202 -15
  29. reflex/experimental/client_state.py +1 -1
  30. reflex/istate/manager.py +2 -1
  31. reflex/plugins/shared_tailwind.py +87 -62
  32. reflex/plugins/tailwind_v3.py +2 -2
  33. reflex/plugins/tailwind_v4.py +4 -4
  34. reflex/state.py +5 -1
  35. reflex/utils/format.py +2 -3
  36. reflex/utils/frontend_skeleton.py +2 -2
  37. reflex/utils/imports.py +18 -0
  38. reflex/utils/pyi_generator.py +10 -2
  39. reflex/utils/telemetry.py +4 -1
  40. reflex/utils/templates.py +1 -6
  41. {reflex-0.8.7a1.dist-info → reflex-0.8.8a1.dist-info}/METADATA +3 -4
  42. {reflex-0.8.7a1.dist-info → reflex-0.8.8a1.dist-info}/RECORD +45 -67
  43. reflex/.templates/jinja/app/rxconfig.py.jinja2 +0 -9
  44. reflex/.templates/jinja/custom_components/README.md.jinja2 +0 -9
  45. reflex/.templates/jinja/custom_components/__init__.py.jinja2 +0 -1
  46. reflex/.templates/jinja/custom_components/demo_app.py.jinja2 +0 -39
  47. reflex/.templates/jinja/custom_components/pyproject.toml.jinja2 +0 -25
  48. reflex/.templates/jinja/custom_components/src.py.jinja2 +0 -57
  49. reflex/.templates/jinja/web/package.json.jinja2 +0 -27
  50. reflex/.templates/jinja/web/pages/_app.js.jinja2 +0 -62
  51. reflex/.templates/jinja/web/pages/_document.js.jinja2 +0 -9
  52. reflex/.templates/jinja/web/pages/base_page.js.jinja2 +0 -21
  53. reflex/.templates/jinja/web/pages/component.js.jinja2 +0 -2
  54. reflex/.templates/jinja/web/pages/custom_component.js.jinja2 +0 -22
  55. reflex/.templates/jinja/web/pages/index.js.jinja2 +0 -18
  56. reflex/.templates/jinja/web/pages/macros.js.jinja2 +0 -38
  57. reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 +0 -15
  58. reflex/.templates/jinja/web/pages/stateful_components.js.jinja2 +0 -5
  59. reflex/.templates/jinja/web/pages/utils.js.jinja2 +0 -93
  60. reflex/.templates/jinja/web/styles/styles.css.jinja2 +0 -6
  61. reflex/.templates/jinja/web/utils/context.js.jinja2 +0 -129
  62. reflex/.templates/jinja/web/utils/theme.js.jinja2 +0 -1
  63. reflex/.templates/jinja/web/vite.config.js.jinja2 +0 -74
  64. reflex/components/core/client_side_routing.pyi +0 -68
  65. {reflex-0.8.7a1.dist-info → reflex-0.8.8a1.dist-info}/WHEEL +0 -0
  66. {reflex-0.8.7a1.dist-info → reflex-0.8.8a1.dist-info}/entry_points.txt +0 -0
  67. {reflex-0.8.7a1.dist-info → reflex-0.8.8a1.dist-info}/licenses/LICENSE +0 -0
reflex/compiler/utils.py CHANGED
@@ -8,7 +8,7 @@ import traceback
8
8
  from collections.abc import Mapping, Sequence
9
9
  from datetime import datetime
10
10
  from pathlib import Path
11
- from typing import Any
11
+ from typing import Any, TypedDict
12
12
  from urllib.parse import urlparse
13
13
 
14
14
  from reflex import constants
@@ -90,7 +90,13 @@ def validate_imports(import_dict: ParsedImportDict):
90
90
  used_tags[import_name] = lib
91
91
 
92
92
 
93
- def compile_imports(import_dict: ParsedImportDict) -> list[dict]:
93
+ class _ImportDict(TypedDict):
94
+ lib: str
95
+ default: str
96
+ rest: list[str]
97
+
98
+
99
+ def compile_imports(import_dict: ParsedImportDict) -> list[_ImportDict]:
94
100
  """Compile an import dict.
95
101
 
96
102
  Args:
@@ -104,7 +110,7 @@ def compile_imports(import_dict: ParsedImportDict) -> list[dict]:
104
110
  """
105
111
  collapsed_import_dict: ParsedImportDict = imports.collapse_imports(import_dict)
106
112
  validate_imports(collapsed_import_dict)
107
- import_dicts = []
113
+ import_dicts: list[_ImportDict] = []
108
114
  for lib, fields in collapsed_import_dict.items():
109
115
  # prevent lib from being rendered on the page if all imports are non rendered kind
110
116
  if not any(f.render for f in fields):
@@ -139,7 +145,9 @@ def compile_imports(import_dict: ParsedImportDict) -> list[dict]:
139
145
  return import_dicts
140
146
 
141
147
 
142
- def get_import_dict(lib: str, default: str = "", rest: list[str] | None = None) -> dict:
148
+ def get_import_dict(
149
+ lib: str, default: str = "", rest: list[str] | None = None
150
+ ) -> _ImportDict:
143
151
  """Get dictionary for import template.
144
152
 
145
153
  Args:
@@ -150,11 +158,11 @@ def get_import_dict(lib: str, default: str = "", rest: list[str] | None = None)
150
158
  Returns:
151
159
  A dictionary for import template.
152
160
  """
153
- return {
154
- "lib": lib,
155
- "default": default,
156
- "rest": rest if rest else [],
157
- }
161
+ return _ImportDict(
162
+ lib=lib,
163
+ default=default,
164
+ rest=rest if rest else [],
165
+ )
158
166
 
159
167
 
160
168
  def save_error(error: Exception) -> str:
@@ -237,23 +245,20 @@ def _compile_client_storage_field(
237
245
 
238
246
  def _compile_client_storage_recursive(
239
247
  state: type[BaseState],
240
- ) -> tuple[dict[str, dict], dict[str, dict], dict[str, dict]]:
248
+ ) -> tuple[
249
+ dict[str, dict[str, Any]], dict[str, dict[str, Any]], dict[str, dict[str, Any]]
250
+ ]:
241
251
  """Compile the client-side storage for the given state recursively.
242
252
 
243
253
  Args:
244
254
  state: The app state object.
245
255
 
246
256
  Returns:
247
- A tuple of the compiled client-side storage info:
248
- (
249
- cookies: dict[str, dict],
250
- local_storage: dict[str, dict[str, str]]
251
- session_storage: dict[str, dict[str, str]]
252
- ).
257
+ A tuple of the compiled client-side storage info: (cookies, local_storage, session_storage).
253
258
  """
254
- cookies = {}
255
- local_storage = {}
256
- session_storage = {}
259
+ cookies: dict[str, dict[str, Any]] = {}
260
+ local_storage: dict[str, dict[str, Any]] = {}
261
+ session_storage: dict[str, dict[str, Any]] = {}
257
262
  state_name = state.get_full_name()
258
263
  for name, field in state.__fields__.items():
259
264
  if name in state.inherited_vars:
@@ -261,6 +266,8 @@ def _compile_client_storage_recursive(
261
266
  continue
262
267
  state_key = f"{state_name}.{name}" + FIELD_MARKER
263
268
  field_type, options = _compile_client_storage_field(field)
269
+ if field_type is None or options is None:
270
+ continue
264
271
  if field_type is Cookie:
265
272
  cookies[state_key] = options
266
273
  elif field_type is LocalStorage:
@@ -281,7 +288,9 @@ def _compile_client_storage_recursive(
281
288
  return cookies, local_storage, session_storage
282
289
 
283
290
 
284
- def compile_client_storage(state: type[BaseState]) -> dict[str, dict]:
291
+ def compile_client_storage(
292
+ state: type[BaseState],
293
+ ) -> dict[str, dict[str, dict[str, Any]]]:
285
294
  """Compile the client-side storage for the given state.
286
295
 
287
296
  Args:
reflex/components/base/bare.py CHANGED
@@ -188,6 +188,23 @@ class Bare(Component):
188
188
  return Tagless(contents=f"{contents.to_string()!s}")
189
189
  return Tagless(contents=f"{contents!s}")
190
190
 
191
+ def render(self) -> dict:
192
+ """Render the component as a dictionary.
193
+
194
+ This is overridden to provide a short performant path for rendering.
195
+
196
+ Returns:
197
+ The rendered component.
198
+ """
199
+ contents = (
200
+ Var.create(self.contents)
201
+ if not isinstance(self.contents, Var)
202
+ else self.contents
203
+ )
204
+ if isinstance(contents, (BooleanVar, ObjectVar)):
205
+ return {"contents": f"{contents.to_string()!s}"}
206
+ return {"contents": f"{contents!s}"}
207
+
191
208
  def _add_style_recursive(
192
209
  self, style: ComponentStyle, theme: Component | None = None
193
210
  ) -> Component:
reflex/components/component.py CHANGED
@@ -22,7 +22,7 @@ from typing_extensions import dataclass_transform
22
22
 
23
23
  import reflex.state
24
24
  from reflex import constants
25
- from reflex.compiler.templates import STATEFUL_COMPONENT
25
+ from reflex.compiler.templates import stateful_component_template
26
26
  from reflex.components.core.breakpoints import Breakpoints
27
27
  from reflex.components.dynamic import load_dynamic_serializer
28
28
  from reflex.components.field import BaseField, FieldBasedMeta
@@ -51,7 +51,7 @@ from reflex.event import (
51
51
  )
52
52
  from reflex.style import Style, format_as_emotion
53
53
  from reflex.utils import console, format, imports, types
54
- from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
54
+ from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict
55
55
  from reflex.vars import VarData
56
56
  from reflex.vars.base import (
57
57
  CachedVarOperation,
@@ -1270,7 +1270,6 @@ class Component(BaseComponent, ABC):
1270
1270
  rendered_dict = dict(
1271
1271
  tag.set(
1272
1272
  children=[child.render() for child in self.children],
1273
- contents=str(tag.contents),
1274
1273
  )
1275
1274
  )
1276
1275
  self._replace_prop_names(rendered_dict)
@@ -1498,7 +1497,7 @@ class Component(BaseComponent, ABC):
1498
1497
  yield clz.__name__
1499
1498
 
1500
1499
  @classmethod
1501
- def _iter_parent_classes_with_method(cls, method: str) -> Iterator[type[Component]]:
1500
+ def _iter_parent_classes_with_method(cls, method: str) -> Sequence[type[Component]]:
1502
1501
  """Iterate through parent classes that define a given method.
1503
1502
 
1504
1503
  Used for handling the `add_*` API functions that internally simulate a super() call chain.
@@ -1506,12 +1505,13 @@ class Component(BaseComponent, ABC):
1506
1505
  Args:
1507
1506
  method: The method to look for.
1508
1507
 
1509
- Yields:
1510
- The parent classes that define the method (differently than the base).
1508
+ Returns:
1509
+ A sequence of parent classes that define the method (differently than the base).
1511
1510
  """
1512
1511
  seen_methods = (
1513
1512
  {getattr(Component, method)} if hasattr(Component, method) else set()
1514
1513
  )
1514
+ clzs: list[type[Component]] = []
1515
1515
  for clz in cls.mro():
1516
1516
  if clz is Component:
1517
1517
  break
@@ -1521,7 +1521,8 @@ class Component(BaseComponent, ABC):
1521
1521
  if not callable(method_func) or method_func in seen_methods:
1522
1522
  continue
1523
1523
  seen_methods.add(method_func)
1524
- yield clz
1524
+ clzs.append(clz)
1525
+ return clzs
1525
1526
 
1526
1527
  def _get_custom_code(self) -> str | None:
1527
1528
  """Get custom code for the component.
@@ -1644,18 +1645,18 @@ class Component(BaseComponent, ABC):
1644
1645
  Returns:
1645
1646
  The imports needed by the component.
1646
1647
  """
1647
- _imports = {}
1648
-
1649
- # Import this component's tag from the main library.
1650
- if self.library is not None and self.tag is not None:
1651
- _imports[self.library] = self.import_var
1648
+ _imports = (
1649
+ {self.library: [self.import_var]}
1650
+ if self.library is not None and self.tag is not None
1651
+ else {}
1652
+ )
1652
1653
 
1653
1654
  # Get static imports required for event processing.
1654
1655
  event_imports = Imports.EVENTS if self.event_triggers else {}
1655
1656
 
1656
1657
  # Collect imports from Vars used directly by this component.
1657
1658
  var_imports = [
1658
- var_data.imports
1659
+ dict(var_data.imports)
1659
1660
  for var in self._get_vars()
1660
1661
  if (var_data := var._get_all_var_data()) is not None
1661
1662
  ]
@@ -1665,16 +1666,16 @@ class Component(BaseComponent, ABC):
1665
1666
  list_of_import_dict = clz.add_imports(self)
1666
1667
 
1667
1668
  if not isinstance(list_of_import_dict, list):
1668
- list_of_import_dict = [list_of_import_dict]
1669
-
1670
- added_import_dicts.extend(
1671
- [parse_imports(import_dict) for import_dict in list_of_import_dict]
1672
- )
1669
+ added_import_dicts.append(imports.parse_imports(list_of_import_dict))
1670
+ else:
1671
+ added_import_dicts.extend(
1672
+ [imports.parse_imports(item) for item in list_of_import_dict]
1673
+ )
1673
1674
 
1674
- return imports.merge_imports(
1675
+ return imports.merge_parsed_imports(
1675
1676
  self._get_dependencies_imports(),
1676
1677
  self._get_hooks_imports(),
1677
- {**_imports},
1678
+ _imports,
1678
1679
  event_imports,
1679
1680
  *var_imports,
1680
1681
  *added_import_dicts,
@@ -1689,7 +1690,7 @@ class Component(BaseComponent, ABC):
1689
1690
  Returns:
1690
1691
  The import dict with the required imports.
1691
1692
  """
1692
- _imports = imports.merge_imports(
1693
+ _imports = imports.merge_parsed_imports(
1693
1694
  self._get_imports(), *[child._get_all_imports() for child in self.children]
1694
1695
  )
1695
1696
  return imports.collapse_imports(_imports) if collapse else _imports
@@ -2470,7 +2471,7 @@ class StatefulComponent(BaseComponent):
2470
2471
  if not self.tag:
2471
2472
  return ""
2472
2473
  # Render the code for this component and hooks.
2473
- return STATEFUL_COMPONENT.render(
2474
+ return stateful_component_template(
2474
2475
  tag_name=self.tag,
2475
2476
  memo_trigger_hooks=self.memo_trigger_hooks,
2476
2477
  component=self.component,
@@ -2796,6 +2797,9 @@ def render_dict_to_var(tag: dict | Component | str) -> Var:
2796
2797
  return render_dict_to_var(tag.render())
2797
2798
  return Var.create(tag)
2798
2799
 
2800
+ if "contents" in tag:
2801
+ return Var(tag["contents"])
2802
+
2799
2803
  if "iterable" in tag:
2800
2804
  function_return = LiteralArrayVar.create(
2801
2805
  [render_dict_to_var(child.render()) for child in tag["children"]]
@@ -2813,27 +2817,30 @@ def render_dict_to_var(tag: dict | Component | str) -> Var:
2813
2817
  func,
2814
2818
  )
2815
2819
 
2816
- if tag["name"] == "match":
2817
- element = tag["cond"]
2820
+ if "match_cases" in tag:
2821
+ element = Var(tag["cond"])
2818
2822
 
2819
2823
  conditionals = render_dict_to_var(tag["default"])
2820
2824
 
2821
2825
  for case in tag["match_cases"][::-1]:
2822
- condition = case[0].to_string() == element.to_string()
2823
- for pattern in case[1:-1]:
2824
- condition = condition | (pattern.to_string() == element.to_string())
2826
+ conditions, return_value = case
2827
+ condition = Var.create(False)
2828
+ for pattern in conditions:
2829
+ condition = condition | (
2830
+ Var(pattern).to_string() == element.to_string()
2831
+ )
2825
2832
 
2826
2833
  conditionals = ternary_operation(
2827
2834
  condition,
2828
- render_dict_to_var(case[-1]),
2835
+ render_dict_to_var(return_value),
2829
2836
  conditionals,
2830
2837
  )
2831
2838
 
2832
2839
  return conditionals
2833
2840
 
2834
- if "cond" in tag:
2841
+ if "cond_state" in tag:
2835
2842
  return ternary_operation(
2836
- tag["cond"],
2843
+ Var(tag["cond_state"]),
2837
2844
  render_dict_to_var(tag["true_value"]),
2838
2845
  render_dict_to_var(tag["false_value"])
2839
2846
  if tag["false_value"] is not None
@@ -2842,8 +2849,6 @@ def render_dict_to_var(tag: dict | Component | str) -> Var:
2842
2849
 
2843
2850
  props = Var("({" + ",".join(tag["props"]) + "})")
2844
2851
 
2845
- contents = tag["contents"] if tag["contents"] else None
2846
-
2847
2852
  raw_tag_name = tag.get("name")
2848
2853
  tag_name = Var(raw_tag_name or "Fragment")
2849
2854
 
@@ -2852,7 +2857,6 @@ def render_dict_to_var(tag: dict | Component | str) -> Var:
2852
2857
  ).call(
2853
2858
  tag_name,
2854
2859
  props,
2855
- *([Var(contents)] if contents is not None else []),
2856
2860
  *[render_dict_to_var(child) for child in tag["children"]],
2857
2861
  )
2858
2862
 
reflex/components/core/cond.py CHANGED
@@ -61,7 +61,7 @@ class Cond(Component):
61
61
 
62
62
  def _render(self) -> Tag:
63
63
  return CondTag(
64
- cond=self.cond,
64
+ cond_state=str(self.cond),
65
65
  true_value=self.children[0].render(),
66
66
  false_value=self.children[1].render(),
67
67
  )
@@ -72,17 +72,11 @@ class Cond(Component):
72
72
  Returns:
73
73
  The dictionary for template of component.
74
74
  """
75
- tag = self._render()
76
- return dict(
77
- tag.add_props(
78
- **self.event_triggers,
79
- key=self.key,
80
- sx=self.style,
81
- id=self.id,
82
- class_name=self.class_name,
83
- ),
84
- cond_state=str(self.cond),
85
- )
75
+ return {
76
+ "cond_state": str(self.cond),
77
+ "true_value": self.children[0].render(),
78
+ "false_value": self.children[1].render(),
79
+ }
86
80
 
87
81
  def add_imports(self) -> ImportDict:
88
82
  """Add imports for the Cond component.
reflex/components/core/foreach.py CHANGED
@@ -181,7 +181,7 @@ class Foreach(Component):
181
181
  tag,
182
182
  iterable_state=str(tag.iterable),
183
183
  arg_name=tag.arg_var_name,
184
- arg_index=tag.get_index_var_arg(),
184
+ arg_index=tag.index_var_name,
185
185
  )
186
186
 
187
187
 
reflex/components/core/match.py CHANGED
@@ -1,11 +1,12 @@
1
1
  """rx.match."""
2
2
 
3
3
  import textwrap
4
- from typing import Any
4
+ from typing import Any, cast
5
5
 
6
6
  from reflex.components.base import Fragment
7
7
  from reflex.components.component import BaseComponent, Component, MemoizationLeaf, field
8
- from reflex.components.tags import MatchTag, Tag
8
+ from reflex.components.tags import Tag
9
+ from reflex.components.tags.match_tag import MatchTag
9
10
  from reflex.style import Style
10
11
  from reflex.utils import format
11
12
  from reflex.utils.exceptions import MatchTypeError
@@ -21,10 +22,14 @@ class Match(MemoizationLeaf):
21
22
  cond: Var[Any]
22
23
 
23
24
  # The list of match cases to be matched.
24
- match_cases: list[Any] = field(default_factory=list, is_javascript_property=False)
25
+ match_cases: list[tuple[list[Var], BaseComponent]] = field(
26
+ default_factory=list, is_javascript_property=False
27
+ )
25
28
 
26
29
  # The catchall case to match.
27
- default: Any = field(default=None, is_javascript_property=False)
30
+ default: BaseComponent = field(
31
+ default_factory=Fragment.create, is_javascript_property=False
32
+ )
28
33
 
29
34
  @classmethod
30
35
  def create(cls, cond: Any, *cases) -> Component | Var:
@@ -44,9 +49,9 @@ class Match(MemoizationLeaf):
44
49
  cases, default = cls._process_cases(list(cases))
45
50
  match_cases = cls._process_match_cases(cases)
46
51
 
47
- cls._validate_return_types(match_cases)
52
+ match_cases = cls._validate_return_types(match_cases)
48
53
 
49
- if default is None and isinstance(match_cases[0][-1], Var):
54
+ if default is None and isinstance(match_cases[0][1], Var):
50
55
  msg = "For cases with return types as Vars, a default case must be provided"
51
56
  raise ValueError(msg)
52
57
 
@@ -75,7 +80,9 @@ class Match(MemoizationLeaf):
75
80
  return match_cond_var
76
81
 
77
82
  @classmethod
78
- def _process_cases(cls, cases: list) -> tuple[list, Var | BaseComponent | None]:
83
+ def _process_cases(
84
+ cls, cases: list
85
+ ) -> tuple[list[tuple], Var | BaseComponent | None]:
79
86
  """Process the list of match cases and the catchall default case.
80
87
 
81
88
  Args:
@@ -87,26 +94,29 @@ class Match(MemoizationLeaf):
87
94
  Raises:
88
95
  ValueError: If there are multiple default cases.
89
96
  """
90
- default = None
91
-
92
- if len([case for case in cases if not isinstance(case, tuple)]) > 1:
93
- msg = "rx.match can only have one default case."
94
- raise ValueError(msg)
95
-
96
97
  if not cases:
97
98
  msg = "rx.match should have at least one case."
98
99
  raise ValueError(msg)
99
100
 
100
- # Get the default case which should be the last non-tuple arg
101
101
  if not isinstance(cases[-1], tuple):
102
- default = cases.pop()
103
- default = (
104
- cls._create_case_var_with_var_data(default)
105
- if not isinstance(default, BaseComponent)
106
- else default
102
+ *cases, default_return = cases
103
+ default_return = (
104
+ cls._create_case_var_with_var_data(default_return)
105
+ if not isinstance(default_return, BaseComponent)
106
+ else default_return
107
107
  )
108
+ else:
109
+ default_return = None
108
110
 
109
- return cases, default
111
+ if any(case for case in cases if not isinstance(case, tuple)):
112
+ msg = "rx.match should have tuples of cases and one default case as the last argument."
113
+ raise ValueError(msg)
114
+
115
+ if not cases:
116
+ msg = "rx.match should have at least one case."
117
+ raise ValueError(msg)
118
+
119
+ return cases, default_return
110
120
 
111
121
  @classmethod
112
122
  def _create_case_var_with_var_data(cls, case_element: Any) -> Var:
@@ -124,7 +134,9 @@ class Match(MemoizationLeaf):
124
134
  return LiteralVar.create(case_element, _var_data=_var_data)
125
135
 
126
136
  @classmethod
127
- def _process_match_cases(cls, cases: list) -> list[list[Var]]:
137
+ def _process_match_cases(
138
+ cls, cases: list[tuple]
139
+ ) -> list[tuple[list[Var], BaseComponent | Var]]:
128
140
  """Process the individual match cases.
129
141
 
130
142
  Args:
@@ -136,40 +148,48 @@ class Match(MemoizationLeaf):
136
148
  Raises:
137
149
  ValueError: If the default case is not the last case or the tuple elements are less than 2.
138
150
  """
139
- match_cases = []
140
- for case in cases:
141
- if not isinstance(case, tuple):
142
- msg = "rx.match should have tuples of cases and a default case as the last argument."
143
- raise ValueError(msg)
151
+ match_cases: list[tuple[list[Var], BaseComponent | Var]] = []
152
+ for case_index, case in enumerate(cases):
144
153
  # There should be at least two elements in a case tuple(a condition and return value)
145
154
  if len(case) < 2:
146
155
  msg = "A case tuple should have at least a match case element and a return value."
147
156
  raise ValueError(msg)
148
157
 
149
- case_list = []
150
- for element in case:
151
- # convert all non component element to vars.
152
- el = (
153
- cls._create_case_var_with_var_data(element)
154
- if not isinstance(element, BaseComponent)
155
- else element
156
- )
157
- if not isinstance(el, (Var, BaseComponent)):
158
- msg = "Case element must be a var or component"
158
+ *conditions, return_value = case
159
+
160
+ conditions_vars: list[Var] = []
161
+ for condition_index, condition in enumerate(conditions):
162
+ if isinstance(condition, BaseComponent):
163
+ msg = f"Match condition {condition_index} of case {case_index} cannot be a component."
159
164
  raise ValueError(msg)
160
- case_list.append(el)
165
+ conditions_vars.append(cls._create_case_var_with_var_data(condition))
166
+
167
+ return_value = (
168
+ cls._create_case_var_with_var_data(return_value)
169
+ if not isinstance(return_value, BaseComponent)
170
+ else return_value
171
+ )
172
+
173
+ if not isinstance(return_value, (Var, BaseComponent)):
174
+ msg = "Return value must be a var or component"
175
+ raise ValueError(msg)
161
176
 
162
- match_cases.append(case_list)
177
+ match_cases.append((conditions_vars, return_value))
163
178
 
164
179
  return match_cases
165
180
 
166
181
  @classmethod
167
- def _validate_return_types(cls, match_cases: list[list[Var]]) -> None:
182
+ def _validate_return_types(
183
+ cls, match_cases: list[tuple[list[Var], BaseComponent | Var]]
184
+ ) -> list[tuple[list[Var], Var]] | list[tuple[list[Var], BaseComponent]]:
168
185
  """Validate that match cases have the same return types.
169
186
 
170
187
  Args:
171
188
  match_cases: The match cases.
172
189
 
190
+ Returns:
191
+ The validated match cases.
192
+
173
193
  Raises:
174
194
  MatchTypeError: If the return types of cases are different.
175
195
  """
@@ -181,20 +201,25 @@ class Match(MemoizationLeaf):
181
201
  elif isinstance(first_case_return, Var):
182
202
  return_type = Var
183
203
 
204
+ cases = []
184
205
  for index, case in enumerate(match_cases):
185
- if not isinstance(case[-1], return_type):
206
+ conditions, return_value = case
207
+ if not isinstance(return_value, return_type):
186
208
  msg = (
187
209
  f"Match cases should have the same return types. Case {index} with return "
188
- f"value `{case[-1]._js_expr if isinstance(case[-1], Var) else textwrap.shorten(str(case[-1]), width=250)}`"
189
- f" of type {type(case[-1])!r} is not {return_type}"
210
+ f"value `{return_value._js_expr if isinstance(return_value, Var) else textwrap.shorten(str(return_value), width=250)}`"
211
+ f" of type {type(return_value)!r} is not {return_type}"
190
212
  )
191
213
  raise MatchTypeError(msg)
214
+ cases.append((conditions, return_value))
215
+ return cases
192
216
 
193
217
  @classmethod
194
218
  def _create_match_cond_var_or_component(
195
219
  cls,
196
220
  match_cond_var: Var,
197
- match_cases: list[list[Var]],
221
+ match_cases: list[tuple[list[Var], BaseComponent]]
222
+ | list[tuple[list[Var], Var]],
198
223
  default: Var | BaseComponent | None,
199
224
  ) -> Component | Var:
200
225
  """Create and return the match condition var or component.
@@ -206,29 +231,22 @@ class Match(MemoizationLeaf):
206
231
 
207
232
  Returns:
208
233
  The match component wrapped in a fragment or the match var.
209
-
210
- Raises:
211
- ValueError: If the return types are not vars when creating a match var for Var types.
212
234
  """
213
- if default is None and isinstance(match_cases[0][-1], BaseComponent):
214
- default = Fragment.create()
235
+ if isinstance(match_cases[0][1], BaseComponent):
236
+ if default is None:
237
+ default = Fragment.create()
215
238
 
216
- if isinstance(match_cases[0][-1], BaseComponent):
217
239
  return Fragment.create(
218
240
  cls._create(
219
241
  cond=match_cond_var,
220
242
  match_cases=match_cases,
221
243
  default=default,
222
- children=[case[-1] for case in match_cases] + [default], # pyright: ignore [reportArgumentType]
244
+ children=[case[1] for case in match_cases] + [default], # pyright: ignore [reportArgumentType]
223
245
  )
224
246
  )
225
247
 
226
- # Validate the match cases (as well as the default case) to have Var return types.
227
- if any(
228
- case for case in match_cases if not isinstance(case[-1], Var)
229
- ) or not isinstance(default, Var):
230
- msg = "Return types of match cases should be Vars."
231
- raise ValueError(msg)
248
+ match_cases = cast("list[tuple[list[Var], Var]]", match_cases)
249
+ default = cast("Var", default)
232
250
 
233
251
  return Var(
234
252
  _js_expr=format.format_match(
@@ -239,14 +257,20 @@ class Match(MemoizationLeaf):
239
257
  _var_type=default._var_type,
240
258
  _var_data=VarData.merge(
241
259
  match_cond_var._get_all_var_data(),
242
- *[el._get_all_var_data() for case in match_cases for el in case],
260
+ *[el._get_all_var_data() for case in match_cases for el in case[0]],
261
+ *[case[1]._get_all_var_data() for case in match_cases],
243
262
  default._get_all_var_data(),
244
263
  ),
245
264
  )
246
265
 
247
266
  def _render(self) -> Tag:
248
267
  return MatchTag(
249
- cond=self.cond, match_cases=self.match_cases, default=self.default
268
+ cond=str(self.cond),
269
+ match_cases=[
270
+ ([str(cond) for cond in conditions], return_value.render())
271
+ for conditions, return_value in self.match_cases
272
+ ],
273
+ default=self.default.render(),
250
274
  )
251
275
 
252
276
  def render(self) -> dict:
@@ -255,8 +279,7 @@ class Match(MemoizationLeaf):
255
279
  Returns:
256
280
  The dictionary for template of component.
257
281
  """
258
- tag = self._render()
259
- return dict(tag.set(name="match"))
282
+ return dict(self._render())
260
283
 
261
284
  def add_imports(self) -> ImportDict:
262
285
  """Add imports for the Match component.
reflex/components/dynamic.py CHANGED
@@ -86,7 +86,7 @@ def load_dynamic_serializer():
86
86
  )
87
87
 
88
88
  rendered_components[
89
- templates.STATEFUL_COMPONENT.render(
89
+ templates.stateful_component_template(
90
90
  tag_name="MySSRComponent",
91
91
  memo_trigger_hooks=[],
92
92
  component=component,
@@ -111,10 +111,10 @@ def load_dynamic_serializer():
111
111
  else:
112
112
  imports[lib] = names
113
113
 
114
- module_code_lines = templates.STATEFUL_COMPONENTS.render(
114
+ module_code_lines = templates.stateful_components_template(
115
115
  imports=utils.compile_imports(imports),
116
116
  memoized_code="\n".join(rendered_components),
117
- ).splitlines()[1:]
117
+ ).splitlines()
118
118
 
119
119
  # Rewrite imports from `/` to destructure from window
120
120
  for ix, line in enumerate(module_code_lines[:]):