relationalai 1.0.0a3__py3-none-any.whl → 1.0.0a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. relationalai/config/config.py +47 -21
  2. relationalai/config/connections/__init__.py +5 -2
  3. relationalai/config/connections/duckdb.py +2 -2
  4. relationalai/config/connections/local.py +31 -0
  5. relationalai/config/connections/snowflake.py +0 -1
  6. relationalai/config/external/raiconfig_converter.py +235 -0
  7. relationalai/config/external/raiconfig_models.py +202 -0
  8. relationalai/config/external/utils.py +31 -0
  9. relationalai/config/shims.py +1 -0
  10. relationalai/semantics/__init__.py +10 -8
  11. relationalai/semantics/backends/sql/sql_compiler.py +1 -4
  12. relationalai/semantics/experimental/__init__.py +0 -0
  13. relationalai/semantics/experimental/builder.py +295 -0
  14. relationalai/semantics/experimental/builtins.py +154 -0
  15. relationalai/semantics/frontend/base.py +67 -42
  16. relationalai/semantics/frontend/core.py +34 -6
  17. relationalai/semantics/frontend/front_compiler.py +209 -37
  18. relationalai/semantics/frontend/pprint.py +6 -2
  19. relationalai/semantics/metamodel/__init__.py +7 -0
  20. relationalai/semantics/metamodel/metamodel.py +2 -0
  21. relationalai/semantics/metamodel/metamodel_analyzer.py +58 -16
  22. relationalai/semantics/metamodel/pprint.py +6 -1
  23. relationalai/semantics/metamodel/rewriter.py +11 -7
  24. relationalai/semantics/metamodel/typer.py +116 -41
  25. relationalai/semantics/reasoners/__init__.py +11 -0
  26. relationalai/semantics/reasoners/graph/__init__.py +35 -0
  27. relationalai/semantics/reasoners/graph/core.py +9028 -0
  28. relationalai/semantics/std/__init__.py +30 -10
  29. relationalai/semantics/std/aggregates.py +641 -12
  30. relationalai/semantics/std/common.py +146 -13
  31. relationalai/semantics/std/constraints.py +71 -1
  32. relationalai/semantics/std/datetime.py +904 -21
  33. relationalai/semantics/std/decimals.py +143 -2
  34. relationalai/semantics/std/floats.py +57 -4
  35. relationalai/semantics/std/integers.py +98 -4
  36. relationalai/semantics/std/math.py +857 -35
  37. relationalai/semantics/std/numbers.py +216 -20
  38. relationalai/semantics/std/re.py +213 -5
  39. relationalai/semantics/std/strings.py +437 -44
  40. relationalai/shims/executor.py +60 -52
  41. relationalai/shims/fixtures.py +85 -0
  42. relationalai/shims/helpers.py +26 -2
  43. relationalai/shims/hoister.py +28 -9
  44. relationalai/shims/mm2v0.py +204 -173
  45. relationalai/tools/cli/cli.py +192 -10
  46. relationalai/tools/cli/components/progress_reader.py +1 -1
  47. relationalai/tools/cli/docs.py +394 -0
  48. relationalai/tools/debugger.py +11 -4
  49. relationalai/tools/qb_debugger.py +435 -0
  50. relationalai/tools/typer_debugger.py +1 -2
  51. relationalai/util/dataclasses.py +3 -5
  52. relationalai/util/docutils.py +1 -2
  53. relationalai/util/error.py +2 -5
  54. relationalai/util/python.py +23 -0
  55. relationalai/util/runtime.py +1 -2
  56. relationalai/util/schema.py +2 -4
  57. relationalai/util/structures.py +4 -2
  58. relationalai/util/tracing.py +8 -2
  59. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/METADATA +8 -5
  60. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/RECORD +118 -95
  61. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/WHEEL +1 -1
  62. v0/relationalai/__init__.py +1 -1
  63. v0/relationalai/clients/client.py +52 -18
  64. v0/relationalai/clients/exec_txn_poller.py +122 -0
  65. v0/relationalai/clients/local.py +23 -8
  66. v0/relationalai/clients/resources/azure/azure.py +36 -11
  67. v0/relationalai/clients/resources/snowflake/__init__.py +4 -4
  68. v0/relationalai/clients/resources/snowflake/cli_resources.py +12 -1
  69. v0/relationalai/clients/resources/snowflake/direct_access_resources.py +124 -100
  70. v0/relationalai/clients/resources/snowflake/engine_service.py +381 -0
  71. v0/relationalai/clients/resources/snowflake/engine_state_handlers.py +35 -29
  72. v0/relationalai/clients/resources/snowflake/error_handlers.py +43 -2
  73. v0/relationalai/clients/resources/snowflake/snowflake.py +277 -179
  74. v0/relationalai/clients/resources/snowflake/use_index_poller.py +8 -0
  75. v0/relationalai/clients/types.py +5 -0
  76. v0/relationalai/errors.py +19 -1
  77. v0/relationalai/semantics/lqp/algorithms.py +173 -0
  78. v0/relationalai/semantics/lqp/builtins.py +199 -2
  79. v0/relationalai/semantics/lqp/executor.py +68 -37
  80. v0/relationalai/semantics/lqp/ir.py +28 -2
  81. v0/relationalai/semantics/lqp/model2lqp.py +215 -45
  82. v0/relationalai/semantics/lqp/passes.py +13 -658
  83. v0/relationalai/semantics/lqp/rewrite/__init__.py +12 -0
  84. v0/relationalai/semantics/lqp/rewrite/algorithm.py +385 -0
  85. v0/relationalai/semantics/lqp/rewrite/constants_to_vars.py +70 -0
  86. v0/relationalai/semantics/lqp/rewrite/deduplicate_vars.py +104 -0
  87. v0/relationalai/semantics/lqp/rewrite/eliminate_data.py +108 -0
  88. v0/relationalai/semantics/lqp/rewrite/extract_keys.py +25 -3
  89. v0/relationalai/semantics/lqp/rewrite/period_math.py +77 -0
  90. v0/relationalai/semantics/lqp/rewrite/quantify_vars.py +65 -31
  91. v0/relationalai/semantics/lqp/rewrite/unify_definitions.py +317 -0
  92. v0/relationalai/semantics/lqp/utils.py +11 -1
  93. v0/relationalai/semantics/lqp/validators.py +14 -1
  94. v0/relationalai/semantics/metamodel/builtins.py +2 -1
  95. v0/relationalai/semantics/metamodel/compiler.py +2 -1
  96. v0/relationalai/semantics/metamodel/dependency.py +12 -3
  97. v0/relationalai/semantics/metamodel/executor.py +11 -1
  98. v0/relationalai/semantics/metamodel/factory.py +2 -2
  99. v0/relationalai/semantics/metamodel/helpers.py +7 -0
  100. v0/relationalai/semantics/metamodel/ir.py +3 -2
  101. v0/relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +30 -20
  102. v0/relationalai/semantics/metamodel/rewrite/flatten.py +50 -13
  103. v0/relationalai/semantics/metamodel/rewrite/format_outputs.py +9 -3
  104. v0/relationalai/semantics/metamodel/typer/checker.py +6 -4
  105. v0/relationalai/semantics/metamodel/typer/typer.py +4 -3
  106. v0/relationalai/semantics/metamodel/visitor.py +4 -3
  107. v0/relationalai/semantics/reasoners/optimization/solvers_dev.py +1 -1
  108. v0/relationalai/semantics/reasoners/optimization/solvers_pb.py +336 -86
  109. v0/relationalai/semantics/rel/compiler.py +2 -1
  110. v0/relationalai/semantics/rel/executor.py +3 -2
  111. v0/relationalai/semantics/tests/lqp/__init__.py +0 -0
  112. v0/relationalai/semantics/tests/lqp/algorithms.py +345 -0
  113. v0/relationalai/tools/cli.py +339 -186
  114. v0/relationalai/tools/cli_controls.py +216 -67
  115. v0/relationalai/tools/cli_helpers.py +410 -6
  116. v0/relationalai/util/format.py +5 -2
  117. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/entry_points.txt +0 -0
  118. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/top_level.txt +0 -0
@@ -2,6 +2,7 @@
2
2
  from __future__ import annotations
3
3
 
4
4
  # Standard library imports
5
+ import contextvars
5
6
  import io
6
7
  import itertools
7
8
  import os
@@ -12,7 +13,7 @@ import time
12
13
  import importlib
13
14
  from dataclasses import dataclass
14
15
  from pathlib import Path
15
- from typing import Any, Callable, Dict, List, Sequence, TextIO, cast
16
+ from typing import Any, Callable, Dict, List, Optional, Sequence, TextIO, cast
16
17
 
17
18
  # Third-party imports
18
19
  import rich
@@ -36,6 +37,27 @@ from ..environments import (
36
37
  runtime_env,
37
38
  )
38
39
 
40
+ # ---------------------------------------------
41
+ # Global controls for nesting TaskProgress
42
+ # ---------------------------------------------
43
+
44
+ # Type alias for any progress type that supports nesting
45
+ _ProgressType = Any # Actually TaskProgress | NotebookTaskProgress, but defined before those classes
46
+
47
+ _current_progress: contextvars.ContextVar[Optional[_ProgressType]] = contextvars.ContextVar(
48
+ 'current_progress', default=None
49
+ )
50
+
51
+
52
+ def get_current_progress() -> Optional[_ProgressType]:
53
+ """Get the currently active TaskProgress, if any."""
54
+ return _current_progress.get()
55
+
56
+
57
+ def _set_current_progress(progress: Optional[_ProgressType]) -> contextvars.Token:
58
+ """Set the current TaskProgress and return a token for restoration."""
59
+ return _current_progress.set(progress)
60
+
39
61
  #--------------------------------------------------
40
62
  # Constants
41
63
  #--------------------------------------------------
@@ -895,8 +917,14 @@ class _TaskStateMixin:
895
917
  self.add_sub_task(sub_status, "default")
896
918
  self.update_sub_task("default", sub_status)
897
919
 
920
+ def update_main_status_fn(self, fn: Callable[[], str]) -> None:
921
+ """Update the main status line using a callable function."""
922
+ self._description_fn = fn
923
+
898
924
  def update_main_status(self, message: str) -> None:
899
925
  """Update the main status line with custom information."""
926
+ if self._description_fn is not None:
927
+ self._description_fn = None
900
928
  if getattr(self, "description", "") != message:
901
929
  self.description = message
902
930
  self._after_task_update()
@@ -904,6 +932,8 @@ class _TaskStateMixin:
904
932
  def update_messages(self, updater: dict[str, str]) -> None:
905
933
  """Update both main message and sub-status if provided."""
906
934
  if "message" in updater:
935
+ if self._description_fn is not None:
936
+ self._description_fn = None
907
937
  self.description = updater["message"]
908
938
  self._after_task_update()
909
939
  if "sub_status" in updater:
@@ -998,7 +1028,7 @@ class TaskProgress(_TaskStateMixin):
998
1028
 
999
1029
  def __init__(
1000
1030
  self,
1001
- description: str = "",
1031
+ description: str | Callable[[], str] = "",
1002
1032
  success_message: str = "",
1003
1033
  failure_message: str = "",
1004
1034
  leading_newline: bool = False,
@@ -1007,8 +1037,13 @@ class TaskProgress(_TaskStateMixin):
1007
1037
  hide_on_completion: bool = False,
1008
1038
  show_duration_summary: bool = True,
1009
1039
  ):
1010
- # Public configuration
1011
- self.description = description
1040
+ # Public configuration - description can be a string or callable
1041
+ if callable(description):
1042
+ self._description_fn: Callable[[], str] | None = description
1043
+ self.description = description() # Initial value
1044
+ else:
1045
+ self._description_fn = None
1046
+ self.description = description
1012
1047
  self.success_message = success_message
1013
1048
  self.failure_message = failure_message
1014
1049
  self.leading_newline = leading_newline
@@ -1046,19 +1081,29 @@ class TaskProgress(_TaskStateMixin):
1046
1081
 
1047
1082
  self._spinner_thread = None
1048
1083
 
1084
+ # Nesting support
1085
+ self._context_token: Optional[contextvars.Token] = None
1086
+ self._parent_progress: Optional[TaskProgress] = None
1087
+
1049
1088
  def _generate_task_id(self) -> str:
1050
1089
  """Generate a unique task ID."""
1051
1090
  task_id = f"task_{self._next_task_id}"
1052
1091
  self._next_task_id += 1
1053
1092
  return task_id
1054
1093
 
1055
- def _compute_state_hash(self) -> int:
1094
+ def _get_description(self) -> str:
1095
+ """Get the current description, calling the function if one was provided."""
1096
+ if self._description_fn is not None:
1097
+ return self._description_fn()
1098
+ return self.description
1099
+
1100
+ def _compute_state_hash(self, description: str) -> int:
1056
1101
  """Compute a simple hash of the current state for caching."""
1057
1102
  # Use a simple hash based on key state variables
1058
1103
  state_parts = [
1059
1104
  str(self.main_completed),
1060
1105
  str(self.main_failed),
1061
- self.description,
1106
+ description,
1062
1107
  str(self.spinner_index),
1063
1108
  str(len(self._tasks)),
1064
1109
  str(len(self._highlighted_tasks)),
@@ -1074,8 +1119,11 @@ class TaskProgress(_TaskStateMixin):
1074
1119
 
1075
1120
  def _render_display(self):
1076
1121
  """Render the current display state with caching optimization."""
1122
+ # Get current description (may be dynamic)
1123
+ description = self._get_description()
1124
+
1077
1125
  # Check if we need to re-render
1078
- current_hash = self._compute_state_hash()
1126
+ current_hash = self._compute_state_hash(description)
1079
1127
  if current_hash == self._last_state_hash and self._render_cache is not None:
1080
1128
  return self._render_cache
1081
1129
 
@@ -1084,20 +1132,20 @@ class TaskProgress(_TaskStateMixin):
1084
1132
  # Build main task line
1085
1133
  if self.main_failed:
1086
1134
  # Split the description to style only the "Failed:" part in red
1087
- if self.description.startswith("Failed:"):
1135
+ if description.startswith("Failed:"):
1088
1136
  failed_part = "Failed:"
1089
- rest_part = self.description[len("Failed:"):].lstrip()
1137
+ rest_part = description[len("Failed:"):].lstrip()
1090
1138
  main_line = (Text(f"{FAIL_ICON} ", style="red") +
1091
1139
  Text(failed_part, style="red") +
1092
1140
  Text(f" {rest_part}", style="default"))
1093
1141
  else:
1094
1142
  # Fallback if description doesn't start with "Failed:"
1095
- main_line = Text(f"{FAIL_ICON} ", style="red") + Text(self.description, style="red")
1143
+ main_line = Text(f"{FAIL_ICON} ", style="red") + Text(description, style="red")
1096
1144
  elif self.main_completed:
1097
- main_line = Text(f"{SUCCESS_ICON} ", style="green") + Text(self.description, style="green")
1145
+ main_line = Text(f"{SUCCESS_ICON} ", style="green") + Text(description, style="green")
1098
1146
  else:
1099
1147
  spinner_text = SPINNER_FRAMES[self.spinner_index]
1100
- main_line = Text(f"{spinner_text} ", style="magenta") + Text(self.description, style="magenta")
1148
+ main_line = Text(f"{spinner_text} ", style="magenta") + Text(description, style="magenta")
1101
1149
 
1102
1150
  # Build subtask lines
1103
1151
  subtask_lines = self._render_subtask_lines()
@@ -1231,6 +1279,14 @@ class TaskProgress(_TaskStateMixin):
1231
1279
  return "\n".join(lines)
1232
1280
 
1233
1281
  def __enter__(self):
1282
+ # Handle nesting: pause any parent progress
1283
+ self._parent_progress = get_current_progress()
1284
+ if self._parent_progress is not None:
1285
+ self._parent_progress._pause()
1286
+
1287
+ # Set ourselves as the current progress
1288
+ self._context_token = _set_current_progress(self)
1289
+
1234
1290
  if self.leading_newline:
1235
1291
  print()
1236
1292
 
@@ -1246,16 +1302,40 @@ class TaskProgress(_TaskStateMixin):
1246
1302
 
1247
1303
  def _start_spinner(self):
1248
1304
  """Start the spinner animation thread."""
1305
+ self._spinner_paused = False
1306
+
1249
1307
  def spinner_animation():
1250
1308
  while self.live and not self.main_completed and not self.main_failed:
1251
1309
  time.sleep(SPINNER_UPDATE_INTERVAL)
1252
- if self.live:
1310
+ if self.live and not self._spinner_paused:
1253
1311
  self._advance_spinner()
1254
1312
  self.live.update(self._render_display())
1255
1313
 
1256
1314
  self._spinner_thread = threading.Thread(target=spinner_animation, daemon=True)
1257
1315
  self._spinner_thread.start()
1258
1316
 
1317
+ def _pause(self):
1318
+ """Pause the live display to allow a child progress to render."""
1319
+ self._spinner_paused = True
1320
+ if self.live:
1321
+ # Clear the live display content before stopping so child has a clean slate
1322
+ from rich.text import Text
1323
+ self.live.update(Text(""))
1324
+ self.live.stop()
1325
+ # live.stop() leaves the cursor after the (empty) rendered content.
1326
+ # Move cursor up one line so the child renders in the same place.
1327
+ if not self.is_ci and sys.stdout.isatty():
1328
+ sys.stdout.write("\033[A\r\033[K")
1329
+ sys.stdout.flush()
1330
+
1331
+ def _resume(self):
1332
+ """Resume the live display after a child progress finishes."""
1333
+ self._spinner_paused = False
1334
+ if self.live and not self.main_completed and not self.main_failed:
1335
+ # Re-render current state and restart the live display
1336
+ self.live.update(self._render_display())
1337
+ self.live.start()
1338
+
1259
1339
  def __exit__(self, exc_type, exc_val, exc_tb):
1260
1340
  # Stop timer manager
1261
1341
  self._timer_manager.stop()
@@ -1263,12 +1343,23 @@ class TaskProgress(_TaskStateMixin):
1263
1343
  if exc_type is not None:
1264
1344
  # Exception occurred - show failure message
1265
1345
  self._handle_failure(exc_val)
1266
- return False # Don't suppress the exception
1346
+ result = False # Don't suppress the exception
1267
1347
  else:
1268
1348
  # Success - show completion
1269
1349
  self._handle_success()
1350
+ result = True
1351
+
1352
+ # Restore the parent progress as current
1353
+ if self._context_token is not None:
1354
+ _current_progress.reset(self._context_token)
1355
+ self._context_token = None
1356
+
1357
+ # Resume the parent progress if there was one
1358
+ if self._parent_progress is not None:
1359
+ self._parent_progress._resume()
1360
+ self._parent_progress = None
1270
1361
 
1271
- return True
1362
+ return result
1272
1363
 
1273
1364
  def _handle_failure(self, exc_val):
1274
1365
  """Handle failure case in context manager exit."""
@@ -1352,7 +1443,7 @@ def _calculate_category_duration(category_name: str, tasks: Dict[str, TaskInfo])
1352
1443
  return sum(task_info.get_duration() for task_info in tasks.values())
1353
1444
 
1354
1445
 
1355
- def create_progress(description: str = "", success_message: str = "", failure_message: str = "",
1446
+ def create_progress(description: str | Callable[[], str] = "", success_message: str = "", failure_message: str = "",
1356
1447
  leading_newline: bool = False, trailing_newline: bool = False, show_duration_summary: bool = True):
1357
1448
  """Factory function to create the appropriate progress component based on environment.
1358
1449
 
@@ -1413,14 +1504,20 @@ class NotebookTaskProgress(_TaskStateMixin):
1413
1504
 
1414
1505
  def __init__(
1415
1506
  self,
1416
- description: str = "",
1507
+ description: str | Callable[[], str] = "",
1417
1508
  success_message: str = "",
1418
1509
  failure_message: str = "",
1419
1510
  leading_newline: bool = False,
1420
1511
  trailing_newline: bool = False,
1421
1512
  show_duration_summary: bool = True,
1422
1513
  ):
1423
- self.description = description
1514
+ # Description can be a string or callable
1515
+ if callable(description):
1516
+ self._description_fn: Callable[[], str] | None = description
1517
+ self.description = description() # Initial value
1518
+ else:
1519
+ self._description_fn = None
1520
+ self.description = description
1424
1521
  self.success_message = success_message
1425
1522
  self.failure_message = failure_message
1426
1523
  self.leading_newline = leading_newline
@@ -1447,6 +1544,16 @@ class NotebookTaskProgress(_TaskStateMixin):
1447
1544
  self._current_subtask = ""
1448
1545
  self.busy = False # Initialize busy state
1449
1546
 
1547
+ # Nesting support
1548
+ self._context_token: Optional[contextvars.Token] = None
1549
+ self._parent_progress: Optional[TaskProgress] = None
1550
+ self._spinner_paused = False
1551
+
1552
+ def _get_description(self) -> str:
1553
+ """Get the current description, calling the function if one was provided."""
1554
+ if self._description_fn is not None:
1555
+ return self._description_fn()
1556
+ return self.description
1450
1557
 
1451
1558
  def _generate_task_id(self) -> str:
1452
1559
  """Generate a unique task ID."""
@@ -1481,15 +1588,18 @@ class NotebookTaskProgress(_TaskStateMixin):
1481
1588
 
1482
1589
  spinner = "⏳⏳⏳⏳" if starting else next(self.spinner_generator)
1483
1590
 
1591
+ # Get current description (may be dynamic)
1592
+ description = self._get_description()
1593
+
1484
1594
  # If there's an active subtask, show ONLY the subtask
1485
1595
  if hasattr(self, '_current_subtask') and self._current_subtask:
1486
1596
  full_message = f"{spinner} {self._current_subtask}"
1487
1597
  else:
1488
1598
  # Otherwise show the main task with subtask count if any
1489
1599
  if len(self._tasks) > 0:
1490
- full_message = f"{spinner} {self.description} ({len(self._tasks)} active)"
1600
+ full_message = f"{spinner} {description} ({len(self._tasks)} active)"
1491
1601
  else:
1492
- full_message = f"{spinner} {self.description}"
1602
+ full_message = f"{spinner} {description}"
1493
1603
 
1494
1604
  if len(full_message) > max_width:
1495
1605
  return full_message[:max_width - 3] + "..."
@@ -1539,11 +1649,12 @@ class NotebookTaskProgress(_TaskStateMixin):
1539
1649
 
1540
1650
  def _build_jupyter_lines(self, starting: bool) -> list[str]:
1541
1651
  """Compose the main status and subtasks for Jupyter display."""
1652
+ description = self._get_description()
1542
1653
  if self.busy or starting:
1543
1654
  spinner = SPINNER_FRAMES[0] if starting else next(self.spinner_generator)
1544
- main_line = f"{spinner} {self.description}"
1655
+ main_line = f"{spinner} {description}"
1545
1656
  else:
1546
- main_text = self.success_message or self.description
1657
+ main_text = self.success_message or description
1547
1658
  main_line = f"{SUCCESS_ICON} {main_text}"
1548
1659
 
1549
1660
  visible_tasks = self._collect_visible_tasks()
@@ -1582,10 +1693,26 @@ class NotebookTaskProgress(_TaskStateMixin):
1582
1693
  def spinner_task(self):
1583
1694
  """Spinner animation task."""
1584
1695
  while self.busy and self.delay:
1585
- self.update()
1696
+ if not self._spinner_paused:
1697
+ self.update()
1586
1698
  time.sleep(self.delay) #type: ignore[union-attr] | we only call spinner_task if delay is not None anyway
1587
1699
  self.reset_cursor()
1588
1700
 
1701
+ def _pause(self):
1702
+ """Pause the display to allow a child progress to render."""
1703
+ self._spinner_paused = True
1704
+ # Clear the current line so child can render cleanly
1705
+ if not self.is_jupyter:
1706
+ sys.stdout.write("\r" + " " * DEFAULT_TERMINAL_WIDTH + "\r")
1707
+ sys.stdout.flush()
1708
+
1709
+ def _resume(self):
1710
+ """Resume the display after a child progress finishes."""
1711
+ self._spinner_paused = False
1712
+ # Force an immediate update to restore display
1713
+ if self.busy:
1714
+ self.update()
1715
+
1589
1716
  def _update_subtask_display(self, subtask_text: str):
1590
1717
  """Update sub-task display - shows ONLY the subtask text."""
1591
1718
  # Store the current display state
@@ -1652,6 +1779,14 @@ class NotebookTaskProgress(_TaskStateMixin):
1652
1779
  self._current_display = ""
1653
1780
 
1654
1781
  def __enter__(self):
1782
+ # Handle nesting: pause any parent progress
1783
+ self._parent_progress = get_current_progress()
1784
+ if self._parent_progress is not None:
1785
+ self._parent_progress._pause()
1786
+
1787
+ # Set ourselves as the current progress
1788
+ self._context_token = _set_current_progress(self)
1789
+
1655
1790
  # Skip leading newline for Jupyter - it interferes with IPython display
1656
1791
  if self.leading_newline and not self.is_jupyter:
1657
1792
  rich.print()
@@ -1675,54 +1810,68 @@ class NotebookTaskProgress(_TaskStateMixin):
1675
1810
  # For Jupyter, IPython display handles formatting
1676
1811
  if not self.is_jupyter:
1677
1812
  rich.print(file=sys.stderr)
1678
- return True
1679
- return False
1680
- if self.delay: # will be None for non-interactive environments
1681
- time.sleep(self.delay)
1682
-
1683
- # Generate summary BEFORE clearing the spinner line (so we have timing data)
1684
- # Only generate if show_duration_summary flag is True
1685
- summary = self.generate_summary() if self.show_duration_summary else ""
1686
-
1687
- # Clear the spinner line completely
1688
- self._clear_spinner_line()
1689
-
1690
- final_message: str | None = None
1691
- if self.success_message:
1692
- final_message = f"{SUCCESS_ICON} {self.success_message}"
1693
- elif summary:
1694
- final_message = f"{SUCCESS_ICON} Done"
1695
-
1696
- if final_message:
1697
- if self.is_jupyter:
1698
- if self.display is not None:
1699
- self.display.update({"text/plain": final_message}, raw=True)
1813
+ result = True
1814
+ else:
1815
+ result = False
1816
+ else:
1817
+ if self.delay: # will be None for non-interactive environments
1818
+ time.sleep(self.delay)
1819
+
1820
+ # Generate summary BEFORE clearing the spinner line (so we have timing data)
1821
+ # Only generate if show_duration_summary flag is True
1822
+ summary = self.generate_summary() if self.show_duration_summary else ""
1823
+
1824
+ # Clear the spinner line completely
1825
+ self._clear_spinner_line()
1826
+
1827
+ final_message: str | None = None
1828
+ if self.success_message:
1829
+ final_message = f"{SUCCESS_ICON} {self.success_message}"
1830
+ elif summary:
1831
+ final_message = f"{SUCCESS_ICON} Done"
1832
+
1833
+ if final_message:
1834
+ if self.is_jupyter:
1835
+ if self.display is not None:
1836
+ self.display.update({"text/plain": final_message}, raw=True)
1837
+ else:
1838
+ _, display_fn = _load_ipython_display()
1839
+ self.display = display_fn({"text/plain": final_message}, display_id=True, raw=True)
1700
1840
  else:
1841
+ print(final_message)
1842
+ elif self.success_message == "":
1843
+ # When there's no success message, clear the display for notebooks
1844
+ # The summary will be printed below if available
1845
+ if self.is_jupyter:
1846
+ self.update("")
1847
+ # For non-Jupyter notebooks, _clear_spinner_line() already handled it
1848
+
1849
+ # Print summary if there are completed tasks
1850
+ if summary:
1851
+ if self.is_jupyter:
1852
+ # Use IPython display to avoid blank stdout lines in notebooks
1701
1853
  _, display_fn = _load_ipython_display()
1702
- self.display = display_fn({"text/plain": final_message}, display_id=True, raw=True)
1703
- else:
1704
- print(final_message)
1705
- elif self.success_message == "":
1706
- # When there's no success message, clear the display for notebooks
1707
- # The summary will be printed below if available
1708
- if self.is_jupyter:
1709
- self.update("")
1710
- # For non-Jupyter notebooks, _clear_spinner_line() already handled it
1854
+ display_fn({"text/plain": summary.strip()}, raw=True)
1855
+ else:
1856
+ print()
1857
+ print(summary.strip()) # Summary includes visual separator line
1711
1858
 
1712
- # Print summary if there are completed tasks
1713
- if summary:
1714
- if self.is_jupyter:
1715
- # Use IPython display to avoid blank stdout lines in notebooks
1716
- _, display_fn = _load_ipython_display()
1717
- display_fn({"text/plain": summary.strip()}, raw=True)
1718
- else:
1719
- print()
1720
- print(summary.strip()) # Summary includes visual separator line
1859
+ # Skip trailing newline for Jupyter - it interferes with IPython display
1860
+ if self.trailing_newline and not self.is_jupyter:
1861
+ rich.print()
1862
+ result = True
1721
1863
 
1722
- # Skip trailing newline for Jupyter - it interferes with IPython display
1723
- if self.trailing_newline and not self.is_jupyter:
1724
- rich.print()
1725
- return True
1864
+ # Restore the parent progress as current
1865
+ if self._context_token is not None:
1866
+ _current_progress.reset(self._context_token)
1867
+ self._context_token = None
1868
+
1869
+ # Resume the parent progress if there was one
1870
+ if self._parent_progress is not None:
1871
+ self._parent_progress._resume()
1872
+ self._parent_progress = None
1873
+
1874
+ return result
1726
1875
 
1727
1876
  def _clear_spinner_line(self):
1728
1877
  """Clear the current spinner line completely."""