zrb 1.5.7__py3-none-any.whl → 1.5.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zrb/task/base_task.py CHANGED
@@ -1,19 +1,25 @@
1
1
  import asyncio
2
- import os
3
- from collections.abc import Callable, Coroutine
2
+ from collections.abc import Callable
4
3
  from typing import Any
5
4
 
6
5
  from zrb.attr.type import BoolAttr, fstring
7
6
  from zrb.context.any_context import AnyContext
8
- from zrb.context.shared_context import AnySharedContext, SharedContext
9
7
  from zrb.env.any_env import AnyEnv
10
8
  from zrb.input.any_input import AnyInput
11
9
  from zrb.session.any_session import AnySession
12
- from zrb.session.session import Session
13
10
  from zrb.task.any_task import AnyTask
14
- from zrb.util.attr import get_bool_attr
15
- from zrb.util.run import run_async
16
- from zrb.xcom.xcom import Xcom
11
+ from zrb.task.base.context import (
12
+ build_task_context,
13
+ get_combined_envs,
14
+ get_combined_inputs,
15
+ )
16
+ from zrb.task.base.execution import (
17
+ execute_task_action,
18
+ execute_task_chain,
19
+ run_default_action,
20
+ )
21
+ from zrb.task.base.lifecycle import execute_root_tasks, run_and_cleanup, run_task_async
22
+ from zrb.task.base.operators import handle_lshift, handle_rshift
17
23
 
18
24
 
19
25
  class BaseTask(AnyTask):
@@ -62,25 +68,13 @@ class BaseTask(AnyTask):
62
68
  self._action = action
63
69
 
64
70
  def __repr__(self):
65
- return f"<{self.__class__.__name__} name={self._name}>"
71
+ return f"<{self.__class__.__name__} name={self.name}>"
66
72
 
67
73
  def __rshift__(self, other: AnyTask | list[AnyTask]) -> AnyTask | list[AnyTask]:
68
- try:
69
- if isinstance(other, AnyTask):
70
- other.append_upstream(self)
71
- elif isinstance(other, list):
72
- for task in other:
73
- task.append_upstream(self)
74
- return other
75
- except Exception as e:
76
- raise ValueError(f"Invalid operation {self} >> {other}: {e}")
74
+ return handle_rshift(self, other)
77
75
 
78
76
  def __lshift__(self, other: AnyTask | list[AnyTask]) -> AnyTask:
79
- try:
80
- self.append_upstream(other)
81
- return self
82
- except Exception as e:
83
- raise ValueError(f"Invalid operation {self} << {other}: {e}")
77
+ return handle_lshift(self, other)
84
78
 
85
79
  @property
86
80
  def name(self) -> str:
@@ -104,430 +98,130 @@ class BaseTask(AnyTask):
104
98
 
105
99
  @property
106
100
  def envs(self) -> list[AnyEnv]:
107
- envs = []
108
- for upstream in self.upstreams:
109
- envs += upstream.envs
110
- if isinstance(self._envs, AnyEnv):
111
- envs.append(self._envs)
112
- elif self._envs is not None:
113
- envs += self._envs
114
- return [env for env in envs if env is not None]
101
+ return get_combined_envs(self)
115
102
 
116
103
  @property
117
104
  def inputs(self) -> list[AnyInput]:
118
- inputs = []
119
- for upstream in self.upstreams:
120
- self.__combine_inputs(inputs, upstream.inputs)
121
- if self._inputs is not None:
122
- self.__combine_inputs(inputs, self._inputs)
123
- return [task_input for task_input in inputs if task_input is not None]
124
-
125
- def __combine_inputs(
126
- self,
127
- inputs: list[AnyInput],
128
- other_inputs: list[AnyInput | None] | AnyInput | None,
129
- ):
130
- input_names = [task_input.name for task_input in inputs]
131
- if isinstance(other_inputs, AnyInput):
132
- other_inputs = [other_inputs]
133
- elif other_inputs is None:
134
- other_inputs = []
135
- for task_input in other_inputs:
136
- if task_input is None:
137
- continue
138
- if task_input.name not in input_names:
139
- inputs.append(task_input)
105
+ return get_combined_inputs(self)
140
106
 
141
107
  @property
142
108
  def fallbacks(self) -> list[AnyTask]:
109
+ """Returns the list of fallback tasks."""
143
110
  if self._fallbacks is None:
144
111
  return []
145
- elif isinstance(self._fallbacks, AnyTask):
146
- return [self._fallbacks]
147
- return self._fallbacks
112
+ elif isinstance(self._fallbacks, list):
113
+ return self._fallbacks
114
+ return [self._fallbacks] # Assume single task
148
115
 
149
116
  def append_fallback(self, fallbacks: AnyTask | list[AnyTask]):
150
- fallback_list = [fallbacks] if isinstance(fallbacks, AnyTask) else fallbacks
151
- for fallback in fallback_list:
152
- self.__append_fallback(fallback)
153
-
154
- def __append_fallback(self, fallback: AnyTask):
155
- # Make sure self._fallbacks is a list
117
+ """Appends fallback tasks, ensuring no duplicates."""
156
118
  if self._fallbacks is None:
157
119
  self._fallbacks = []
158
- elif isinstance(self._fallbacks, AnyTask):
120
+ elif not isinstance(self._fallbacks, list):
159
121
  self._fallbacks = [self._fallbacks]
160
- # Add fallback if it was not on self._fallbacks
161
- if fallback not in self._fallbacks:
162
- self._fallbacks.append(fallback)
122
+ to_add = fallbacks if isinstance(fallbacks, list) else [fallbacks]
123
+ for fb in to_add:
124
+ if fb not in self._fallbacks:
125
+ self._fallbacks.append(fb)
163
126
 
164
127
  @property
165
128
  def successors(self) -> list[AnyTask]:
129
+ """Returns the list of successor tasks."""
166
130
  if self._successors is None:
167
131
  return []
168
- elif isinstance(self._successors, AnyTask):
169
- return [self._successors]
170
- return self._successors
132
+ elif isinstance(self._successors, list):
133
+ return self._successors
134
+ return [self._successors] # Assume single task
171
135
 
172
136
  def append_successor(self, successors: AnyTask | list[AnyTask]):
173
- successor_list = [successors] if isinstance(successors, AnyTask) else successors
174
- for successor in successor_list:
175
- self.__append_successor(successor)
176
-
177
- def __append_successor(self, successor: AnyTask):
178
- # Make sure self._successors is a list
137
+ """Appends successor tasks, ensuring no duplicates."""
179
138
  if self._successors is None:
180
139
  self._successors = []
181
- elif isinstance(self._successors, AnyTask):
140
+ elif not isinstance(self._successors, list):
182
141
  self._successors = [self._successors]
183
- # Add successor if it was not on self._successors
184
- if successor not in self._successors:
185
- self._successors.append(successor)
142
+ to_add = successors if isinstance(successors, list) else [successors]
143
+ for succ in to_add:
144
+ if succ not in self._successors:
145
+ self._successors.append(succ)
186
146
 
187
147
  @property
188
148
  def readiness_checks(self) -> list[AnyTask]:
149
+ """Returns the list of readiness check tasks."""
189
150
  if self._readiness_checks is None:
190
151
  return []
191
- elif isinstance(self._readiness_checks, AnyTask):
192
- return [self._readiness_checks]
193
- return self._readiness_checks
152
+ elif isinstance(self._readiness_checks, list):
153
+ return self._readiness_checks
154
+ return [self._readiness_checks] # Assume single task
194
155
 
195
156
  def append_readiness_check(self, readiness_checks: AnyTask | list[AnyTask]):
196
- readiness_check_list = (
197
- [readiness_checks]
198
- if isinstance(readiness_checks, AnyTask)
199
- else readiness_checks
200
- )
201
- for readiness_check in readiness_check_list:
202
- self.__append_readiness_check(readiness_check)
203
-
204
- def __append_readiness_check(self, readiness_check: AnyTask):
205
- # Make sure self._readiness_checks is a list
157
+ """Appends readiness check tasks, ensuring no duplicates."""
206
158
  if self._readiness_checks is None:
207
159
  self._readiness_checks = []
208
- elif isinstance(self._readiness_checks, AnyTask):
160
+ elif not isinstance(self._readiness_checks, list):
209
161
  self._readiness_checks = [self._readiness_checks]
210
- # Add readiness_check if it was not on self._readiness_checks
211
- if readiness_check not in self._readiness_checks:
212
- self._readiness_checks.append(readiness_check)
162
+ to_add = (
163
+ readiness_checks
164
+ if isinstance(readiness_checks, list)
165
+ else [readiness_checks]
166
+ )
167
+ for rc in to_add:
168
+ if rc not in self._readiness_checks:
169
+ self._readiness_checks.append(rc)
213
170
 
214
171
  @property
215
172
  def upstreams(self) -> list[AnyTask]:
173
+ """Returns the list of upstream tasks."""
216
174
  if self._upstreams is None:
217
175
  return []
218
- elif isinstance(self._upstreams, AnyTask):
219
- return [self._upstreams]
220
- return self._upstreams
176
+ elif isinstance(self._upstreams, list):
177
+ return self._upstreams
178
+ return [self._upstreams] # Assume single task
221
179
 
222
180
  def append_upstream(self, upstreams: AnyTask | list[AnyTask]):
223
- upstream_list = [upstreams] if isinstance(upstreams, AnyTask) else upstreams
224
- for upstream in upstream_list:
225
- self.__append_upstream(upstream)
226
-
227
- def __append_upstream(self, upstream: AnyTask):
228
- # Make sure self._upstreams is a list
181
+ """Appends upstream tasks, ensuring no duplicates."""
229
182
  if self._upstreams is None:
230
183
  self._upstreams = []
231
- elif isinstance(self._upstreams, AnyTask):
184
+ elif not isinstance(self._upstreams, list):
232
185
  self._upstreams = [self._upstreams]
233
- # Add upstream if it was not on self._upstreams
234
- if upstream not in self._upstreams:
235
- self._upstreams.append(upstream)
186
+ to_add = upstreams if isinstance(upstreams, list) else [upstreams]
187
+ for up in to_add:
188
+ if up not in self._upstreams:
189
+ self._upstreams.append(up)
236
190
 
237
191
  def get_ctx(self, session: AnySession) -> AnyContext:
238
- ctx = session.get_ctx(self)
239
- # Enhance session ctx with current task env
240
- for env in self.envs:
241
- env.update_context(ctx)
242
- return ctx
192
+ return build_task_context(self, session)
243
193
 
244
194
  def run(
245
195
  self, session: AnySession | None = None, str_kwargs: dict[str, str] = {}
246
196
  ) -> Any:
247
- return asyncio.run(self._run_and_cleanup(session, str_kwargs))
248
-
249
- async def _run_and_cleanup(
250
- self,
251
- session: AnySession | None = None,
252
- str_kwargs: dict[str, str] = {},
253
- ) -> Any:
254
- current_task = asyncio.create_task(self.async_run(session, str_kwargs))
255
- try:
256
- result = await current_task
257
- finally:
258
- if session and not session.is_terminated:
259
- session.terminate()
260
- # Cancel all running tasks except the current one
261
- pending = [task for task in asyncio.all_tasks() if task is not current_task]
262
- for task in pending:
263
- task.cancel()
264
- if pending:
265
- try:
266
- await asyncio.wait(pending, timeout=5)
267
- except asyncio.CancelledError:
268
- pass
269
- return result
197
+ # Use asyncio.run() to execute the async cleanup wrapper
198
+ return asyncio.run(run_and_cleanup(self, session, str_kwargs))
270
199
 
271
200
  async def async_run(
272
201
  self, session: AnySession | None = None, str_kwargs: dict[str, str] = {}
273
202
  ) -> Any:
274
- if session is None:
275
- session = Session(shared_ctx=SharedContext())
276
- # Update session
277
- self.__fill_shared_context_inputs(session.shared_ctx, str_kwargs)
278
- self.__fill_shared_context_envs(session.shared_ctx)
279
- result = await run_async(self.exec_root_tasks(session))
280
- return result
281
-
282
- def __fill_shared_context_inputs(
283
- self, shared_context: AnySharedContext, str_kwargs: dict[str, str] = {}
284
- ):
285
- for task_input in self.inputs:
286
- if task_input.name not in shared_context.input:
287
- str_value = str_kwargs.get(task_input.name, None)
288
- task_input.update_shared_context(shared_context, str_value)
289
-
290
- def __fill_shared_context_envs(self, shared_context: AnySharedContext):
291
- # Inject os environ
292
- os_env_map = {
293
- key: val for key, val in os.environ.items() if key not in shared_context.env
294
- }
295
- shared_context.env.update(os_env_map)
203
+ return await run_task_async(self, session, str_kwargs)
296
204
 
297
205
  async def exec_root_tasks(self, session: AnySession):
298
- session.set_main_task(self)
299
- session.state_logger.write(session.as_state_log())
300
- try:
301
- log_state = asyncio.create_task(self._log_session_state(session))
302
- root_tasks = [
303
- task
304
- for task in session.get_root_tasks(self)
305
- if session.is_allowed_to_run(task)
306
- ]
307
- root_task_coros = [
308
- run_async(root_task.exec_chain(session)) for root_task in root_tasks
309
- ]
310
- await asyncio.gather(*root_task_coros)
311
- await session.wait_deferred()
312
- session.terminate()
313
- await log_state
314
- return session.final_result
315
- except IndexError:
316
- return None
317
- except (asyncio.CancelledError, KeyboardInterrupt):
318
- ctx = self.get_ctx(session)
319
- ctx.log_info("Session terminated")
320
- finally:
321
- session.terminate()
322
- session.state_logger.write(session.as_state_log())
323
- ctx = self.get_ctx(session)
324
- ctx.log_debug(session)
325
-
326
- async def _log_session_state(self, session: AnySession):
327
- session.state_logger.write(session.as_state_log())
328
- while not session.is_terminated:
329
- await asyncio.sleep(0.1)
330
- session.state_logger.write(session.as_state_log())
331
- session.state_logger.write(session.as_state_log())
206
+ return await execute_root_tasks(self, session)
332
207
 
333
208
  async def exec_chain(self, session: AnySession):
334
- if session.is_terminated or not session.is_allowed_to_run(self):
335
- return
336
- result = await self.exec(session)
337
- # Get next tasks
338
- nexts = session.get_next_tasks(self)
339
- if session.is_terminated or len(nexts) == 0:
340
- return result
341
- # Run next tasks asynchronously
342
- next_coros = [run_async(next.exec_chain(session)) for next in nexts]
343
- return await asyncio.gather(*next_coros)
209
+ return await execute_task_chain(self, session)
344
210
 
345
211
  async def exec(self, session: AnySession):
346
- ctx = self.get_ctx(session)
347
- if not session.is_allowed_to_run(self):
348
- # Task is not allowed to run, skip it for now.
349
- # This will be triggered later
350
- ctx.log_info("Not allowed to run")
351
- return
352
- if not self.__get_execute_condition(session):
353
- # Skip the task
354
- ctx.log_info("Marked as skipped")
355
- session.get_task_status(self).mark_as_skipped()
356
- return
357
- # Wait for task to be ready
358
- return await run_async(self.__exec_action_until_ready(session))
359
-
360
- def __get_execute_condition(self, session: Session) -> bool:
361
- ctx = self.get_ctx(session)
362
- return get_bool_attr(ctx, self._execute_condition, True, auto_render=True)
363
-
364
- async def __exec_action_until_ready(self, session: AnySession):
365
- ctx = self.get_ctx(session)
366
- readiness_checks = self.readiness_checks
367
- if len(readiness_checks) == 0:
368
- ctx.log_info("No readiness checks")
369
- # Task has no readiness check
370
- result = await run_async(self.__exec_action_and_retry(session))
371
- ctx.log_info("Marked as ready")
372
- session.get_task_status(self).mark_as_ready()
373
- return result
374
- # Start the task along with the readiness checks
375
- action_coro = asyncio.create_task(
376
- run_async(self.__exec_action_and_retry(session))
377
- )
378
- await asyncio.sleep(self._readiness_check_delay)
379
- readiness_check_coros = [
380
- run_async(check.exec_chain(session)) for check in readiness_checks
381
- ]
382
- # Only wait for readiness checks and mark the task as ready
383
- ctx.log_info("Start readiness checks")
384
- result = await asyncio.gather(*readiness_check_coros)
385
- ctx.log_info("Readiness checks completed")
386
- ctx.log_info("Marked as ready")
387
- session.get_task_status(self).mark_as_ready()
388
- # Defer task's coroutines, will be waited later
389
- session.defer_action(self, action_coro)
390
- if self._monitor_readiness:
391
- monitor_and_rerun_coro = asyncio.create_task(
392
- run_async(self.__exec_monitoring(session, action_coro))
393
- )
394
- session.defer_monitoring(self, monitor_and_rerun_coro)
395
- return result
396
-
397
- async def __exec_monitoring(self, session: AnySession, action_coro: asyncio.Task):
398
- readiness_checks = self.readiness_checks
399
- failure_count = 0
400
- ctx = self.get_ctx(session)
401
- while not session.is_terminated:
402
- await asyncio.sleep(self._readiness_check_period)
403
- if failure_count < self._readiness_failure_threshold:
404
- for readiness_check in readiness_checks:
405
- session.get_task_status(readiness_check).reset_history()
406
- session.get_task_status(readiness_check).reset()
407
- readiness_xcom: Xcom = ctx.xcom[self.name]
408
- readiness_xcom.clear()
409
- readiness_check_coros = [
410
- check.exec_chain(session) for check in readiness_checks
411
- ]
412
- try:
413
- ctx.log_info("Checking")
414
- await asyncio.wait_for(
415
- asyncio.gather(*readiness_check_coros),
416
- timeout=self._readiness_timeout,
417
- )
418
- ctx.log_info("OK")
419
- continue
420
- except (asyncio.CancelledError, KeyboardInterrupt):
421
- for readiness_check in readiness_checks:
422
- ctx.log_info("Marked as failed")
423
- session.get_task_status(readiness_check).mark_as_failed()
424
- except asyncio.TimeoutError:
425
- failure_count += 1
426
- ctx.log_info("Detecting failure")
427
- ctx.log_debug(f"Failure count: {failure_count}")
428
- # Readiness check failed, reset
429
- ctx.log_info("Resetting")
430
- action_coro.cancel()
431
- session.get_task_status(self).reset()
432
- # defer this action
433
- ctx.log_info("Running")
434
- action_coro: Coroutine = asyncio.create_task(
435
- run_async(self.__exec_action_and_retry(session))
436
- )
437
- session.defer_action(self, action_coro)
438
- failure_count = 0
439
- ctx.log_info("Continue monitoring")
440
-
441
- async def __exec_action_and_retry(self, session: AnySession) -> Any:
442
- """
443
- Executes an action with retry logic.
444
-
445
- This method attempts to execute the action defined in `_exec_action` with a specified number of retries.
446
- If the action fails, it will retry after a specified period until the maximum number of attempts is reached.
447
- If the action succeeds, it marks the task as completed and executes any successors.
448
- If the action fails permanently, it marks the task as permanently failed and executes any fallbacks.
449
-
450
- Args:
451
- session (AnySession): The session object containing the task status and context.
452
-
453
- Returns:
454
- Any: The result of the executed action if successful.
455
-
456
- Raises:
457
- Exception: If the action fails permanently after all retry attempts.
458
- """
459
- ctx = self.get_ctx(session)
460
- max_attempt = self._retries + 1
461
- ctx.set_max_attempt(max_attempt)
462
- for attempt in range(max_attempt):
463
- ctx.set_attempt(attempt + 1)
464
- if attempt > 0:
465
- # apply retry period only if this is not the first attempt
466
- await asyncio.sleep(self._retry_period)
467
- try:
468
- ctx.log_info("Marked as started")
469
- session.get_task_status(self).mark_as_started()
470
- result = await run_async(self._exec_action(ctx))
471
- ctx.log_info("Marked as completed")
472
- session.get_task_status(self).mark_as_completed()
473
- # Put result on xcom
474
- task_xcom: Xcom = ctx.xcom.get(self.name)
475
- task_xcom.push(result)
476
- self.__skip_fallbacks(session)
477
- await run_async(self.__exec_successors(session))
478
- return result
479
- except (asyncio.CancelledError, KeyboardInterrupt):
480
- ctx.log_info("Marked as failed")
481
- session.get_task_status(self).mark_as_failed()
482
- return
483
- except BaseException as e:
484
- ctx.log_error(e)
485
- if attempt < max_attempt - 1:
486
- ctx.log_info("Marked as failed")
487
- session.get_task_status(self).mark_as_failed()
488
- continue
489
- ctx.log_info("Marked as permanently failed")
490
- session.get_task_status(self).mark_as_permanently_failed()
491
- self.__skip_successors(session)
492
- await run_async(self.__exec_fallbacks(session))
493
- raise e
494
-
495
- async def __exec_successors(self, session: AnySession) -> Any:
496
- successors: list[AnyTask] = self.successors
497
- successor_coros = [
498
- run_async(successor.exec_chain(session)) for successor in successors
499
- ]
500
- await asyncio.gather(*successor_coros)
501
-
502
- def __skip_successors(self, session: AnySession) -> Any:
503
- for successor in self.successors:
504
- session.get_task_status(successor).mark_as_skipped()
505
-
506
- async def __exec_fallbacks(self, session: AnySession) -> Any:
507
- fallbacks: list[AnyTask] = self.fallbacks
508
- fallback_coros = [
509
- run_async(fallback.exec_chain(session)) for fallback in fallbacks
510
- ]
511
- await asyncio.gather(*fallback_coros)
512
-
513
- def __skip_fallbacks(self, session: AnySession) -> Any:
514
- for fallback in self.fallbacks:
515
- session.get_task_status(fallback).mark_as_skipped()
212
+ return await execute_task_action(self, session)
516
213
 
517
214
  async def _exec_action(self, ctx: AnyContext) -> Any:
518
- """Execute the main action of the task.
519
- By default will render and run the _action attribute.
520
-
521
- Override this method to define custom action.
215
+ """
216
+ Execute the main action of the task.
217
+ This is the primary method to override in subclasses for custom action logic.
218
+ The default implementation handles the '_action' attribute (string or callable).
522
219
 
523
220
  Args:
524
- session (AnySession): The shared session.
221
+ ctx (AnyContext): The execution context for this task.
525
222
 
526
223
  Returns:
527
224
  Any: The result of the action execution.
528
225
  """
529
- if self._action is None:
530
- return
531
- if isinstance(self._action, str):
532
- return ctx.render(self._action)
533
- return await run_async(self._action(ctx))
226
+ # Delegate to the helper function for the default behavior
227
+ return await run_default_action(self, ctx)
zrb/task/cmd_task.py CHANGED
@@ -130,7 +130,8 @@ class CmdTask(BaseTask):
130
130
  partial(ctx.print, plain=True) if self._should_plain_print else ctx.print
131
131
  )
132
132
  xcom_pid_key = f"{self.name}-pid"
133
- ctx.xcom[xcom_pid_key] = Xcom([])
133
+ if xcom_pid_key not in ctx.xcom:
134
+ ctx.xcom[xcom_pid_key] = Xcom([])
134
135
  cmd_result, return_code = await run_command(
135
136
  cmd=[shell, shell_flag, cmd_script],
136
137
  cwd=cwd,