zrb 1.5.8__py3-none-any.whl → 1.5.10__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
@@ -0,0 +1,274 @@
+ import asyncio
+ from typing import TYPE_CHECKING, Any
+
+ from zrb.context.any_context import AnyContext
+ from zrb.session.any_session import AnySession
+ from zrb.util.attr import get_bool_attr
+ from zrb.util.run import run_async
+ from zrb.xcom.xcom import Xcom
+
+ if TYPE_CHECKING:
+     from zrb.task.base_task import BaseTask
+
+
+ async def execute_task_chain(task: "BaseTask", session: AnySession):
+     """
+     Executes the task and its downstream successors if conditions are met.
+     """
+     if session.is_terminated or not session.is_allowed_to_run(task):
+         return
+     result = await execute_task_action(task, session)
+     # Get next tasks
+     nexts = session.get_next_tasks(task)
+     if session.is_terminated or len(nexts) == 0:
+         return result
+     # Run next tasks asynchronously
+     next_coros = [run_async(next_task.exec_chain(session)) for next_task in nexts]
+     # Wait for the next tasks to complete. The result of the current task is returned.
+     await asyncio.gather(*next_coros)
+     return result
+
+
+ async def execute_task_action(task: "BaseTask", session: AnySession):
+     """
+     Executes a single task's action, handling conditions and readiness checks.
+     """
+     ctx = task.get_ctx(session)
+     if not session.is_allowed_to_run(task):
+         # Task is not allowed to run, skip it for now.
+         # This will be triggered later if dependencies are met.
+         ctx.log_info("Not allowed to run")
+         return
+     if not check_execute_condition(task, session):
+         # Skip the task based on its execute_condition
+         ctx.log_info("Marked as skipped (condition false)")
+         session.get_task_status(task).mark_as_skipped()
+         return
+     # Wait for task to be ready (handles action execution, readiness checks)
+     return await run_async(execute_action_until_ready(task, session))
+
+
+ def check_execute_condition(task: "BaseTask", session: AnySession) -> bool:
+     """
+     Evaluates the task's execute_condition attribute.
+     """
+     ctx = task.get_ctx(session)
+     execute_condition_attr = getattr(task, "_execute_condition", True)
+     return get_bool_attr(ctx, execute_condition_attr, True, auto_render=True)
+
+
+ async def execute_action_until_ready(task: "BaseTask", session: AnySession):
+     """
+     Manages the execution of the task's action, coordinating with readiness checks.
+     """
+     ctx = task.get_ctx(session)
+     readiness_checks = task.readiness_checks
+     readiness_check_delay = getattr(task, "_readiness_check_delay", 0.5)
+     monitor_readiness = getattr(task, "_monitor_readiness", False)
+
+     if not readiness_checks:  # Simplified check for empty list
+         ctx.log_info("No readiness checks")
+         # Task has no readiness check, execute action directly
+         result = await run_async(execute_action_with_retry(task, session))
+         # Mark ready only if the action completed successfully (not failed/cancelled)
+         if session.get_task_status(task).is_completed:
+             ctx.log_info("Marked as ready")
+             session.get_task_status(task).mark_as_ready()
+         return result
+
+     # Start the task action and readiness checks concurrently
+     ctx.log_info("Starting action and readiness checks")
+     action_coro = asyncio.create_task(
+         run_async(execute_action_with_retry(task, session))
+     )
+
+     await asyncio.sleep(readiness_check_delay)
+
+     readiness_check_coros = [
+         run_async(check.exec_chain(session)) for check in readiness_checks
+     ]
+
+     # Wait primarily for readiness checks to complete
+     ctx.log_info("Waiting for readiness checks")
+     readiness_passed = False
+     try:
+         # Gather results, but primarily interested in completion/errors
+         await asyncio.gather(*readiness_check_coros)
+         # Check if all readiness tasks actually completed successfully
+         all_readiness_completed = all(
+             session.get_task_status(check).is_completed for check in readiness_checks
+         )
+         if all_readiness_completed:
+             ctx.log_info("Readiness checks completed successfully")
+             readiness_passed = True
+             # Mark task as ready only if checks passed and action didn't fail during checks
+             if not session.get_task_status(task).is_failed:
+                 ctx.log_info("Marked as ready")
+                 session.get_task_status(task).mark_as_ready()
+         else:
+             ctx.log_warning(
+                 "One or more readiness checks did not complete successfully."
+             )
+
+     except Exception as e:
+         ctx.log_error(f"Readiness check failed with exception: {e}")
+         # If readiness checks fail with an exception, the task is not ready.
+         # The action_coro might still be running or have failed.
+         # execute_action_with_retry handles marking the main task status.
+
+     # Defer the main action coroutine; it will be awaited later if needed
+     session.defer_action(task, action_coro)
+
+     # Start monitoring only if readiness passed and monitoring is enabled
+     if readiness_passed and monitor_readiness:
+         # Import dynamically to avoid circular dependency if monitoring imports execution
+         from zrb.task.base.monitoring import monitor_task_readiness
+
+         monitor_coro = asyncio.create_task(
+             run_async(monitor_task_readiness(task, session, action_coro))
+         )
+         session.defer_monitoring(task, monitor_coro)
+
+     # The result here is primarily about readiness check completion.
+     # The actual task result is handled by the deferred action_coro.
+     return None
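
Note: execute_action_until_ready starts the action as a background asyncio task, gives it a readiness_check_delay head start, runs the readiness checks, and then hands the still-running action to the session (defer_action) instead of awaiting it. A minimal standalone sketch of that coordination, using hypothetical start_server/ping_server stand-ins rather than zrb's task and session API:

import asyncio

async def start_server():
    # Hypothetical long-running action; it stays "up" for the life of the run.
    await asyncio.sleep(10)

async def ping_server() -> bool:
    # Hypothetical readiness probe.
    await asyncio.sleep(0.1)
    return True

async def run_until_ready(readiness_check_delay: float = 0.5) -> asyncio.Task:
    # Start the action first, then give it a head start before probing.
    action = asyncio.create_task(start_server())
    await asyncio.sleep(readiness_check_delay)
    print("ready:", await ping_server())
    # The action is intentionally not awaited here; zrb defers it on the
    # session and awaits it later via session.wait_deferred().
    return action

async def main():
    deferred_action = await run_until_ready()
    deferred_action.cancel()  # stand-in for session teardown at the end of a run
    try:
        await deferred_action
    except asyncio.CancelledError:
        pass

asyncio.run(main())
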
+
+
+ async def execute_action_with_retry(task: "BaseTask", session: AnySession) -> Any:
+     """
+     Executes the task's core action (`_exec_action`) with retry logic,
+     handling success (triggering successors) and failure (triggering fallbacks).
+     """
+     ctx = task.get_ctx(session)
+     retries = getattr(task, "_retries", 2)
+     retry_period = getattr(task, "_retry_period", 0)
+     max_attempt = retries + 1
+     ctx.set_max_attempt(max_attempt)
+
+     for attempt in range(max_attempt):
+         ctx.set_attempt(attempt + 1)
+         if attempt > 0:
+             ctx.log_info(f"Retrying in {retry_period}s...")
+             await asyncio.sleep(retry_period)
+
+         try:
+             ctx.log_info("Marked as started")
+             session.get_task_status(task).mark_as_started()
+
+             # Execute the underlying action (which might be overridden in subclasses)
+             # We call the task's _exec_action method directly here.
+             result = await run_async(task._exec_action(ctx))
+
+             ctx.log_info("Marked as completed")
+             session.get_task_status(task).mark_as_completed()
+
+             # Store result in XCom
+             task_xcom: Xcom = ctx.xcom.get(task.name)
+             task_xcom.push(result)
+
+             # Skip fallbacks and execute successors on success
+             skip_fallbacks(task, session)
+             await run_async(execute_successors(task, session))
+             return result
+
+         except (asyncio.CancelledError, KeyboardInterrupt):
+             ctx.log_warning("Task cancelled or interrupted")
+             session.get_task_status(task).mark_as_failed()  # Mark as failed on cancel
+             # Do not trigger fallbacks/successors on cancellation
+             return  # Or re-raise? Depends on desired cancellation behavior
+
+         except BaseException as e:
+             ctx.log_error(f"Attempt {attempt + 1}/{max_attempt} failed: {e}")
+             session.get_task_status(
+                 task
+             ).mark_as_failed()  # Mark failed for this attempt
+
+             if attempt < max_attempt - 1:
+                 # More retries available
+                 continue
+             else:
+                 # Final attempt failed
+                 ctx.log_error("Marked as permanently failed")
+                 session.get_task_status(task).mark_as_permanently_failed()
+                 # Skip successors and execute fallbacks on permanent failure
+                 skip_successors(task, session)
+                 await run_async(execute_fallbacks(task, session))
+                 raise e  # Re-raise the exception after handling fallbacks
+
+
+ async def run_default_action(task: "BaseTask", ctx: AnyContext) -> Any:
+     """
+     Executes the specific action defined by the '_action' attribute for BaseTask.
+     This is the default implementation called by BaseTask._exec_action.
+     Subclasses like LLMTask override _exec_action with their own logic.
+     """
+     action = getattr(task, "_action", None)
+     if action is None:
+         ctx.log_debug("No action defined for this task.")
+         return None
+     if isinstance(action, str):
+         # Render f-string action
+         rendered_action = ctx.render(action)
+         ctx.log_debug(f"Rendered action string: {rendered_action}")
+         # Assuming string actions are meant to be returned as is.
+         # If they need execution (e.g., shell commands), that logic would go here.
+         return rendered_action
+     elif callable(action):
+         # Execute callable action
+         ctx.log_debug(f"Executing callable action: {action.__name__}")
+         return await run_async(action(ctx))
+     else:
+         ctx.log_warning(f"Unsupported action type: {type(action)}")
+         return None
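
Note: run_default_action dispatches on the type of _action: strings are rendered through ctx.render and returned, callables are invoked and awaited when they return a coroutine (which is what run_async takes care of). A rough standalone sketch of that dispatch, with a hypothetical render helper standing in for ctx.render:

import asyncio
import inspect
from typing import Any, Callable

def render(template: str, **data: Any) -> str:
    # Hypothetical stand-in for ctx.render(); the real context renders templates
    # against the task's inputs and environment.
    return template.format(**data)

async def run_action(action: "str | Callable[..., Any] | None", **data: Any) -> Any:
    if action is None:
        return None
    if isinstance(action, str):
        return render(action, **data)  # string actions are rendered and returned
    if callable(action):
        result = action(**data)
        if inspect.isawaitable(result):  # await coroutine results, like run_async
            result = await result
        return result
    raise TypeError(f"Unsupported action type: {type(action)}")

async def greet(name: str) -> str:
    return f"hello {name}"

print(asyncio.run(run_action("hi {name}", name="zrb")))  # hi zrb
print(asyncio.run(run_action(greet, name="zrb")))        # hello zrb
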
+
+
+ async def execute_successors(task: "BaseTask", session: "AnySession"):
+     """Executes all successor tasks."""
+     ctx = task.get_ctx(session)
+     successors = task.successors
+     if successors:
+         ctx.log_info(f"Executing {len(successors)} successor(s)")
+         successor_coros = [
+             run_async(successor.exec_chain(session)) for successor in successors
+         ]
+         await asyncio.gather(*successor_coros)
+     else:
+         ctx.log_debug("No successors to execute.")
+
+
+ def skip_successors(task: "BaseTask", session: AnySession):
+     """Marks all successor tasks as skipped."""
+     ctx = task.get_ctx(session)
+     successors = task.successors
+     if successors:
+         ctx.log_info(f"Skipping {len(successors)} successor(s)")
+         for successor in successors:
+             # Check if already skipped to avoid redundant logging/state changes
+             if not session.get_task_status(successor).is_skipped:
+                 session.get_task_status(successor).mark_as_skipped()
+
+
+ async def execute_fallbacks(task: "BaseTask", session: AnySession):
+     """Executes all fallback tasks."""
+     ctx = task.get_ctx(session)
+     fallbacks = task.fallbacks
+     if fallbacks:
+         ctx.log_info(f"Executing {len(fallbacks)} fallback(s)")
+         fallback_coros = [
+             run_async(fallback.exec_chain(session)) for fallback in fallbacks
+         ]
+         await asyncio.gather(*fallback_coros)
+     else:
+         ctx.log_debug("No fallbacks to execute.")
+
+
+ def skip_fallbacks(task: "BaseTask", session: AnySession):
+     """Marks all fallback tasks as skipped."""
+     ctx = task.get_ctx(session)
+     fallbacks = task.fallbacks
+     if fallbacks:
+         ctx.log_info(f"Skipping {len(fallbacks)} fallback(s)")
+         for fallback in fallbacks:
+             # Check if already skipped
+             if not session.get_task_status(fallback).is_skipped:
+                 session.get_task_status(fallback).mark_as_skipped()
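
Note: the retry semantics in this module boil down to retries + 1 attempts, a retry_period sleep before each retry, successors on success, and fallbacks plus a re-raise on the final failure. A stripped-down sketch of just the attempt loop (no session, successors, or fallbacks), with a hypothetical flaky_action that succeeds on its third call:

import asyncio

calls = {"count": 0}

async def flaky_action() -> str:
    # Hypothetical action that fails twice before succeeding.
    calls["count"] += 1
    if calls["count"] < 3:
        raise RuntimeError(f"attempt {calls['count']} failed")
    return "ok"

async def run_with_retry(retries: int = 2, retry_period: float = 0.0) -> str:
    max_attempt = retries + 1
    for attempt in range(max_attempt):
        if attempt > 0:
            await asyncio.sleep(retry_period)  # wait before retrying
        try:
            return await flaky_action()
        except BaseException:
            if attempt < max_attempt - 1:
                continue  # more retries available
            raise  # final attempt: re-raise, as execute_action_with_retry does

print(asyncio.run(run_with_retry()))  # "ok" on the third attempt
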
@@ -0,0 +1,182 @@
+ import asyncio
+ from typing import Any
+
+ from zrb.context.shared_context import SharedContext
+ from zrb.session.any_session import AnySession
+ from zrb.session.session import Session
+ from zrb.task.any_task import AnyTask
+ from zrb.task.base.context import fill_shared_context_envs, fill_shared_context_inputs
+ from zrb.util.run import run_async
+
+
+ async def run_and_cleanup(
+     task: AnyTask,
+     session: AnySession | None = None,
+     str_kwargs: dict[str, str] = {},
+ ) -> Any:
+     """
+     Wrapper for async_run that ensures session termination and cleanup of
+     other concurrent asyncio tasks. This is the main entry point for `task.run()`.
+     """
+     # Ensure a session exists
+     if session is None:
+         session = Session(shared_ctx=SharedContext())
+
+     # Create the main task execution coroutine
+     main_task_coro = asyncio.create_task(run_task_async(task, session, str_kwargs))
+
+     try:
+         result = await main_task_coro
+         return result
+     except (asyncio.CancelledError, KeyboardInterrupt) as e:
+         ctx = task.get_ctx(session)  # Get context for logging
+         ctx.log_warning(f"Run cancelled/interrupted: {e}")
+         raise  # Re-raise to propagate
+     finally:
+         # Ensure session termination if it exists and wasn't terminated by the run
+         if session and not session.is_terminated:
+             ctx = task.get_ctx(session)  # Get context for logging
+             ctx.log_info("Terminating session after run completion/error.")
+             session.terminate()
+
+         # Clean up other potentially running asyncio tasks (excluding the main one)
+         # Be cautious with blanket cancellation if other background tasks are expected
+         pending = [
+             t for t in asyncio.all_tasks() if t is not main_task_coro and not t.done()
+         ]
+         if pending:
+             ctx = task.get_ctx(session)  # Get context for logging
+             ctx.log_debug(f"Cleaning up {len(pending)} pending asyncio tasks...")
+             for t in pending:
+                 t.cancel()
+             try:
+                 # Give cancelled tasks a moment to process cancellation
+                 await asyncio.wait(pending, timeout=1.0)
+             except asyncio.CancelledError:
+                 # Expected if tasks handle cancellation promptly
+                 pass
+             except Exception as cleanup_exc:
+                 # Log errors during cleanup if necessary
+                 ctx.log_warning(f"Error during task cleanup: {cleanup_exc}")
+
+
+ async def run_task_async(
+     task: AnyTask,
+     session: AnySession | None = None,
+     str_kwargs: dict[str, str] = {},
+ ) -> Any:
+     """
+     Asynchronous entry point for running a task (`task.async_run()`).
+     Sets up the session and initiates the root task execution chain.
+     """
+     if session is None:
+         session = Session(shared_ctx=SharedContext())
+
+     # Populate shared context with inputs and environment variables
+     fill_shared_context_inputs(task, session.shared_ctx, str_kwargs)
+     fill_shared_context_envs(session.shared_ctx)  # Inject OS env vars
+
+     # Start the execution chain from the root tasks
+     result = await task.exec_root_tasks(session)
+     return result
+
+
+ async def execute_root_tasks(task: AnyTask, session: AnySession):
+     """
+     Identifies and executes the root tasks required for the main task,
+     manages session state logging, and handles overall execution flow.
+     """
+     session.set_main_task(task)
+     session.state_logger.write(session.as_state_log())  # Initial state log
+     ctx = task.get_ctx(session)  # Get context early for logging
+
+     log_state_task = None
+     try:
+         # Start background state logging
+         log_state_task = asyncio.create_task(log_session_state(task, session))
+
+         # Identify root tasks allowed to run
+         root_tasks = [
+             t for t in session.get_root_tasks(task) if session.is_allowed_to_run(t)
+         ]
+
+         if not root_tasks:
+             ctx.log_info("No root tasks to execute for this task.")
+             # If the main task itself should run even with no explicit roots?
+             # Current logic seems to imply if no roots, nothing runs.
+             session.terminate()  # Terminate if nothing to run
+             return None
+
+         ctx.log_info(f"Executing {len(root_tasks)} root task(s)")
+         root_task_coros = [
+             # Assuming exec_chain exists on AnyTask (it's abstract)
+             run_async(root_task.exec_chain(session))
+             for root_task in root_tasks
+         ]
+
+         # Wait for all root chains to complete
+         await asyncio.gather(*root_task_coros)
+
+         # Wait for any deferred actions (like long-running task bodies)
+         ctx.log_info("Waiting for deferred actions...")
+         await session.wait_deferred()
+         ctx.log_info("Deferred actions complete.")
+
+         # Final termination and logging
+         session.terminate()
+         if log_state_task and not log_state_task.done():
+             await log_state_task  # Ensure final state is logged
+         ctx.log_info("Session finished.")
+         return session.final_result
+
+     except IndexError:
+         # This might occur if get_root_tasks fails unexpectedly
+         ctx.log_error(
+             "IndexError during root task execution, potentially session issue."
+         )
+         session.terminate()  # Ensure termination on error
+         return None
+     except (asyncio.CancelledError, KeyboardInterrupt):
+         ctx.log_warning("Session execution cancelled or interrupted.")
+         # Session termination happens in finally block
+         return None  # Indicate abnormal termination
+     finally:
+         # Ensure termination and final state logging regardless of outcome
+         if not session.is_terminated:
+             session.terminate()
+         # Ensure the state logger task is awaited/cancelled properly
+         if log_state_task:
+             if not log_state_task.done():
+                 log_state_task.cancel()
+                 try:
+                     await log_state_task
+                 except asyncio.CancelledError:
+                     pass  # Expected cancellation
+             # Log final state after ensuring logger task is finished
+             session.state_logger.write(session.as_state_log())
+         else:
+             # Log final state even if logger task didn't start
+             session.state_logger.write(session.as_state_log())
+
+         ctx.log_debug(f"Final session state: {session}")  # Log final session details
+
+
+ async def log_session_state(task: AnyTask, session: AnySession):
+     """
+     Periodically logs the session state until the session is terminated.
+     """
+     try:
+         while not session.is_terminated:
+             session.state_logger.write(session.as_state_log())
+             await asyncio.sleep(0.1)  # Log interval
+         # Log one final time after termination signal
+         session.state_logger.write(session.as_state_log())
+     except asyncio.CancelledError:
+         # Log final state on cancellation too
+         session.state_logger.write(session.as_state_log())
+         ctx = task.get_ctx(session)
+         ctx.log_debug("Session state logger cancelled.")
+     except Exception as e:
+         # Log any unexpected errors in the logger itself
+         ctx = task.get_ctx(session)
+         ctx.log_error(f"Error in session state logger: {e}")
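
Note: run_and_cleanup wraps the actual run in a finally block that terminates the session and then cancels whatever asyncio tasks are still pending, waiting briefly for them to unwind. A minimal sketch of that shutdown pattern without the session machinery; the background task is hypothetical, and the sketch additionally skips the current task because everything here runs inside one event loop:

import asyncio

async def background_noise():
    # Hypothetical leftover background work that nobody awaits.
    while True:
        await asyncio.sleep(0.2)

async def main_work() -> int:
    asyncio.create_task(background_noise())
    await asyncio.sleep(0.1)
    return 42

async def run_and_cleanup_sketch() -> int:
    main_task = asyncio.create_task(main_work())
    try:
        return await main_task
    finally:
        # Cancel anything still pending, then give it a moment to unwind.
        pending = [
            t
            for t in asyncio.all_tasks()
            if t is not main_task and t is not asyncio.current_task() and not t.done()
        ]
        for t in pending:
            t.cancel()
        if pending:
            await asyncio.wait(pending, timeout=1.0)

print(asyncio.run(run_and_cleanup_sketch()))  # 42, with background_noise cancelled
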
@@ -0,0 +1,134 @@
+ import asyncio
+
+ from zrb.session.any_session import AnySession
+ from zrb.task.base.execution import execute_action_with_retry
+ from zrb.task.base_task import BaseTask
+ from zrb.util.run import run_async
+ from zrb.xcom.xcom import Xcom
+
+
+ async def monitor_task_readiness(
+     task: BaseTask, session: AnySession, action_coro: asyncio.Task
+ ):
+     """
+     Monitors the readiness of a task after its initial execution.
+     If readiness checks fail beyond a threshold, it cancels the original action,
+     resets the task status, and re-executes the action.
+     """
+     ctx = task.get_ctx(session)
+     readiness_checks = task.readiness_checks
+     readiness_check_period = getattr(task, "_readiness_check_period", 5.0)
+     readiness_failure_threshold = getattr(task, "_readiness_failure_threshold", 1)
+     readiness_timeout = getattr(task, "_readiness_timeout", 60)
+
+     if not readiness_checks:
+         ctx.log_debug("No readiness checks defined, monitoring is not applicable.")
+         return
+
+     failure_count = 0
+     ctx.log_info("Starting readiness monitoring...")
+
+     while not session.is_terminated:
+         await asyncio.sleep(readiness_check_period)
+
+         if session.is_terminated:
+             break  # Exit loop if session terminated during sleep
+
+         if failure_count < readiness_failure_threshold:
+             ctx.log_info("Performing periodic readiness check...")
+             # Reset status and XCom for readiness check tasks before re-running
+             for check in readiness_checks:
+                 session.get_task_status(check).reset_history()
+                 session.get_task_status(check).reset()
+                 # Clear previous XCom data for the check task if needed
+                 check_xcom: Xcom = ctx.xcom.get(check.name)
+                 check_xcom.clear()
+
+             readiness_check_coros = [
+                 run_async(check.exec_chain(session)) for check in readiness_checks
+             ]
+
+             try:
+                 # Wait for checks with a timeout
+                 await asyncio.wait_for(
+                     asyncio.gather(*readiness_check_coros),
+                     timeout=readiness_timeout,
+                 )
+                 # Check if all checks actually completed successfully
+                 all_checks_completed = all(
+                     session.get_task_status(check).is_completed
+                     for check in readiness_checks
+                 )
+                 if all_checks_completed:
+                     ctx.log_info("Readiness check OK.")
+                     failure_count = 0  # Reset failure count on success
+                     continue  # Continue monitoring
+                 else:
+                     ctx.log_warning(
+                         "Periodic readiness check failed (tasks did not complete)."
+                     )
+                     failure_count += 1
+
+             except asyncio.TimeoutError:
+                 failure_count += 1
+                 ctx.log_warning(
+                     f"Readiness check timed out ({readiness_timeout}s). "
+                     f"Failure count: {failure_count}/{readiness_failure_threshold}"
+                 )
+                 # Ensure check tasks are marked as failed on timeout
+                 for check in readiness_checks:
+                     if not session.get_task_status(check).is_finished:
+                         session.get_task_status(check).mark_as_failed()
+
+             except (asyncio.CancelledError, KeyboardInterrupt):
+                 ctx.log_info("Monitoring cancelled or interrupted.")
+                 break  # Exit monitoring loop
+
+             except Exception as e:
+                 failure_count += 1
+                 ctx.log_error(
+                     f"Readiness check failed with exception: {e}. "
+                     f"Failure count: {failure_count}"
+                 )
+                 # Mark checks as failed
+                 for check in readiness_checks:
+                     if not session.get_task_status(check).is_finished:
+                         session.get_task_status(check).mark_as_failed()
+
+         # If failure threshold is reached
+         if failure_count >= readiness_failure_threshold:
+             ctx.log_warning(
+                 f"Readiness failure threshold ({readiness_failure_threshold}) reached."
+             )
+
+             # Cancel the original running action if it's still running
+             if action_coro and not action_coro.done():
+                 ctx.log_info("Cancelling original task action...")
+                 action_coro.cancel()
+                 try:
+                     await action_coro  # Allow cancellation to process
+                 except asyncio.CancelledError:
+                     ctx.log_info("Original task action cancelled.")
+                 except Exception as e:
+                     ctx.log_warning(f"Error during original action cancellation: {e}")
+
+             # Reset the main task status
+             ctx.log_info("Resetting task status.")
+             session.get_task_status(task).reset()
+
+             # Re-execute the action (with retries)
+             ctx.log_info("Re-executing task action...")
+             # Import dynamically to avoid circular dependency
+             new_action_coro = asyncio.create_task(
+                 run_async(execute_action_with_retry(task, session))
+             )
+             # Defer the new action coroutine
+             session.defer_action(task, new_action_coro)
+             # Update the reference for the next monitoring cycle
+             action_coro = new_action_coro
+
+             # Reset failure count after attempting restart
+             failure_count = 0
+             ctx.log_info("Continuing monitoring...")
+
+     ctx.log_info("Stopping readiness monitoring.")
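
Note: the monitor sleeps for readiness_check_period, re-runs the readiness checks, resets the failure counter whenever they pass, and once readiness_failure_threshold failures accumulate it cancels and restarts the task action. The same control flow, reduced to a standalone loop with hypothetical probe/restart_service stand-ins:

import asyncio
import random

async def probe() -> bool:
    # Hypothetical readiness probe that occasionally reports unhealthy.
    await asyncio.sleep(0)
    return random.random() > 0.3

async def restart_service() -> None:
    print("restarting service...")  # zrb re-runs execute_action_with_retry instead

async def monitor(
    period: float = 0.01, failure_threshold: int = 2, cycles: int = 50
) -> None:
    failure_count = 0
    for _ in range(cycles):  # the real loop runs until the session terminates
        await asyncio.sleep(period)
        if await probe():
            failure_count = 0  # healthy again: reset the counter
            continue
        failure_count += 1
        if failure_count >= failure_threshold:
            await restart_service()
            failure_count = 0  # start counting again after the restart

asyncio.run(monitor())
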
@@ -0,0 +1,41 @@
+ # No specific imports needed from typing for these changes
+ from zrb.task.any_task import AnyTask
+
+
+ def handle_rshift(
+     left_task: AnyTask, right_operand: AnyTask | list[AnyTask]
+ ) -> AnyTask | list[AnyTask]:
+     """
+     Implements the >> operator logic: left_task becomes an upstream for right_operand.
+     Modifies the right_operand(s) by calling append_upstream.
+     Returns the right_operand.
+     """
+     try:
+         if isinstance(right_operand, list):
+             for task in right_operand:
+                 # Assuming append_upstream exists and handles duplicates
+                 task.append_upstream(left_task)
+         else:
+             # Assuming right_operand is a single AnyTask
+             right_operand.append_upstream(left_task)
+         return right_operand
+     except Exception as e:
+         # Catch potential errors during append_upstream or type issues
+         raise ValueError(f"Invalid operation {left_task} >> {right_operand}: {e}")
+
+
+ def handle_lshift(
+     left_task: AnyTask, right_operand: AnyTask | list[AnyTask]
+ ) -> AnyTask:
+     """
+     Implements the << operator logic: right_operand becomes an upstream for left_task.
+     Modifies the left_task by calling append_upstream.
+     Returns the left_task.
+     """
+     try:
+         # Assuming append_upstream exists and handles single or list input
+         left_task.append_upstream(right_operand)
+         return left_task
+     except Exception as e:
+         # Catch potential errors during append_upstream or type issues
+         raise ValueError(f"Invalid operation {left_task} << {right_operand}: {e}")
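
Note: these helpers implement the >> and << operator logic (presumably wired to __rshift__/__lshift__ on the task classes), so a >> b registers a as an upstream of b and returns b, letting chains read left to right, while a << b registers b as an upstream of a and returns a. A toy illustration of the same wiring on a hypothetical Node class; zrb's AnyTask carries far more than this:

from __future__ import annotations

class Node:
    def __init__(self, name: str):
        self.name = name
        self.upstreams: list[Node] = []

    def append_upstream(self, other: Node | list[Node]) -> None:
        others = other if isinstance(other, list) else [other]
        for o in others:
            if o not in self.upstreams:  # tolerate duplicate wiring
                self.upstreams.append(o)

    def __rshift__(self, other: Node | list[Node]):
        # a >> b: a becomes an upstream of b; return b so chains read left to right
        targets = other if isinstance(other, list) else [other]
        for t in targets:
            t.append_upstream(self)
        return other

    def __lshift__(self, other: Node | list[Node]):
        # a << b: b becomes an upstream of a; return a
        self.append_upstream(other)
        return self

a, b, c = Node("a"), Node("b"), Node("c")
a >> b >> c
print([u.name for u in c.upstreams])  # ['b']
print([u.name for u in b.upstreams])  # ['a']
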