runnable-0.50.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. extensions/README.md +0 -0
  2. extensions/__init__.py +0 -0
  3. extensions/catalog/README.md +0 -0
  4. extensions/catalog/any_path.py +214 -0
  5. extensions/catalog/file_system.py +52 -0
  6. extensions/catalog/minio.py +72 -0
  7. extensions/catalog/pyproject.toml +14 -0
  8. extensions/catalog/s3.py +11 -0
  9. extensions/job_executor/README.md +0 -0
  10. extensions/job_executor/__init__.py +236 -0
  11. extensions/job_executor/emulate.py +70 -0
  12. extensions/job_executor/k8s.py +553 -0
  13. extensions/job_executor/k8s_job_spec.yaml +37 -0
  14. extensions/job_executor/local.py +35 -0
  15. extensions/job_executor/local_container.py +161 -0
  16. extensions/job_executor/pyproject.toml +16 -0
  17. extensions/nodes/README.md +0 -0
  18. extensions/nodes/__init__.py +0 -0
  19. extensions/nodes/conditional.py +301 -0
  20. extensions/nodes/fail.py +78 -0
  21. extensions/nodes/loop.py +394 -0
  22. extensions/nodes/map.py +477 -0
  23. extensions/nodes/parallel.py +281 -0
  24. extensions/nodes/pyproject.toml +15 -0
  25. extensions/nodes/stub.py +93 -0
  26. extensions/nodes/success.py +78 -0
  27. extensions/nodes/task.py +156 -0
  28. extensions/pipeline_executor/README.md +0 -0
  29. extensions/pipeline_executor/__init__.py +871 -0
  30. extensions/pipeline_executor/argo.py +1266 -0
  31. extensions/pipeline_executor/emulate.py +119 -0
  32. extensions/pipeline_executor/local.py +226 -0
  33. extensions/pipeline_executor/local_container.py +369 -0
  34. extensions/pipeline_executor/mocked.py +159 -0
  35. extensions/pipeline_executor/pyproject.toml +16 -0
  36. extensions/run_log_store/README.md +0 -0
  37. extensions/run_log_store/__init__.py +0 -0
  38. extensions/run_log_store/any_path.py +100 -0
  39. extensions/run_log_store/chunked_fs.py +122 -0
  40. extensions/run_log_store/chunked_minio.py +141 -0
  41. extensions/run_log_store/file_system.py +91 -0
  42. extensions/run_log_store/generic_chunked.py +549 -0
  43. extensions/run_log_store/minio.py +114 -0
  44. extensions/run_log_store/pyproject.toml +15 -0
  45. extensions/secrets/README.md +0 -0
  46. extensions/secrets/dotenv.py +62 -0
  47. extensions/secrets/pyproject.toml +15 -0
  48. runnable/__init__.py +108 -0
  49. runnable/catalog.py +141 -0
  50. runnable/cli.py +484 -0
  51. runnable/context.py +730 -0
  52. runnable/datastore.py +1058 -0
  53. runnable/defaults.py +159 -0
  54. runnable/entrypoints.py +390 -0
  55. runnable/exceptions.py +137 -0
  56. runnable/executor.py +561 -0
  57. runnable/gantt.py +1646 -0
  58. runnable/graph.py +501 -0
  59. runnable/names.py +546 -0
  60. runnable/nodes.py +593 -0
  61. runnable/parameters.py +217 -0
  62. runnable/pickler.py +96 -0
  63. runnable/sdk.py +1277 -0
  64. runnable/secrets.py +92 -0
  65. runnable/tasks.py +1268 -0
  66. runnable/telemetry.py +142 -0
  67. runnable/utils.py +423 -0
  68. runnable-0.50.0.dist-info/METADATA +189 -0
  69. runnable-0.50.0.dist-info/RECORD +72 -0
  70. runnable-0.50.0.dist-info/WHEEL +4 -0
  71. runnable-0.50.0.dist-info/entry_points.txt +53 -0
  72. runnable-0.50.0.dist-info/licenses/LICENSE +201 -0
runnable/tasks.py ADDED
@@ -0,0 +1,1268 @@
+ import contextlib
+ import copy
+ import importlib
+ import inspect
+ import io
+ import json
+ import logging
+ import os
+ import subprocess
+ import sys
+ from datetime import datetime
+ from pathlib import Path
+ from pickle import PicklingError
+ from string import Template
+ from typing import Any, Callable, Dict, List, Literal, Optional, cast
+
+ import logfire_api as logfire
+ from pydantic import BaseModel, ConfigDict, Field, field_validator
+ from rich.segment import Segment
+ from rich.style import Style
+ from stevedore import driver
+
+ import runnable.context as context
+ from runnable import console, defaults, exceptions, parameters, task_console, utils
+ from runnable.datastore import (
+     JsonParameter,
+     MetricParameter,
+     ObjectParameter,
+     Parameter,
+     StepAttempt,
+ )
+ from runnable.defaults import IterableParameterModel
+ from runnable.telemetry import truncate_value
+
+ logger = logging.getLogger(defaults.LOGGER_NAME)
+
+
+ class TeeIO(io.StringIO):
+     """
+     A custom class to write to the buffer, output stream, and Rich console simultaneously.
+
+     This implementation directly adds to Rich Console's internal recording buffer using
+     proper Segment objects, avoiding the infinite recursion that occurs when using
+     Rich Console's print() method.
+     """
+
+     def __init__(
+         self, output_stream=sys.stdout, rich_console=None, stream_type="stdout"
+     ):
+         super().__init__()
+         self.output_stream = output_stream
+         self.rich_console = rich_console
+         self.stream_type = stream_type
+
+     def write(self, s):
+         if s:  # Only process non-empty strings
+             super().write(s)  # Write to the buffer for later retrieval
+             self.output_stream.write(s)  # Display immediately
+
+             # Record directly to Rich's internal buffer using proper Segments
+             # Note: We record ALL content including newlines, not just stripped content
+             if self.rich_console:
+                 if self.stream_type == "stderr":
+                     # Red style for stderr
+                     style = Style(color="red")
+                     segment = Segment(s, style)
+                 else:
+                     # No style for stdout
+                     segment = Segment(s)
+
+                 # Add to Rich's record buffer (no recursion!)
+                 self.rich_console._record_buffer.append(segment)
+
+         return len(s) if s else 0
+
+     def flush(self):
+         super().flush()
+         self.output_stream.flush()
+
+
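A minimal sketch of TeeIO on its own (rich's Console is real; the strings are illustrative). A write lands in three places at once: the StringIO buffer, the wrapped stream, and the console's recording buffer, which is what export_text() reads when the console was created with record=True.

import sys

from rich.console import Console

recording_console = Console(record=True)  # record=True backs _record_buffer
tee = TeeIO(output_stream=sys.stdout, rich_console=recording_console)

tee.write("hello from the task\n")  # printed to stdout immediately

assert tee.getvalue() == "hello from the task\n"  # kept for later retrieval
assert "hello from the task" in recording_console.export_text()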
+ @contextlib.contextmanager
+ def redirect_output(console=None):
+     """
+     Context manager that captures output to both the display and Rich console recording.
+
+     This implementation uses TeeIO, which records directly to the Rich Console's internal
+     buffer, eliminating the need for post-processing and avoiding infinite recursion.
+
+     Args:
+         console: Rich Console instance for recording (typically task_console)
+     """
+     # Back up the original stdout and stderr
+     original_stdout = sys.stdout
+     original_stderr = sys.stderr
+
+     # Create TeeIO instances that handle display + Rich console recording simultaneously
+     sys.stdout = TeeIO(original_stdout, rich_console=console, stream_type="stdout")
+     sys.stderr = TeeIO(original_stderr, rich_console=console, stream_type="stderr")
+
+     # Update logging handlers to use the new stdout
+     original_streams = []
+     for handler in logging.getLogger().handlers:
+         if isinstance(handler, logging.StreamHandler):
+             original_streams.append((handler, handler.stream))
+             handler.stream = sys.stdout
+
+     try:
+         yield sys.stdout, sys.stderr
+     finally:
+         # Restore the original stdout and stderr
+         sys.stdout = original_stdout
+         sys.stderr = original_stderr
+
+         # Restore logging handler streams
+         for handler, original_stream in original_streams:
+             handler.stream = original_stream
+
+         # No additional Rich console processing needed - TeeIO handles it directly!
+
+
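Illustrative usage of redirect_output, assuming it is imported from this module: inside the block, ordinary print() output is displayed live and simultaneously captured in both the TeeIO buffer and the recording console.

from rich.console import Console

capture_console = Console(record=True)

with redirect_output(console=capture_console) as (stdout_buffer, stderr_buffer):
    print("visible and recorded")  # StreamHandler loggers are retargeted too

# Streams are restored on exit; the captured text remains available.
assert "visible and recorded" in stdout_buffer.getvalue()
assert "visible and recorded" in capture_console.export_text()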
+ class TaskReturns(BaseModel):
+     name: str
+     kind: Literal["json", "object", "metric"] = Field(default="json")
+
+
+ class BaseTaskType(BaseModel):
+     """A base task class which executes the command defined by the user."""
+
+     task_type: str = Field(serialization_alias="command_type")
+     secrets: List[str] = Field(
+         default_factory=list
+     )  # A list of secrets to expose via the secrets manager
+     returns: List[TaskReturns] = Field(
+         default_factory=list, alias="returns"
+     )  # The return values of the task
+     internal_branch_name: str = Field(default="")
+
+     model_config = ConfigDict(extra="forbid")
+
+     def get_summary(self) -> Dict[str, Any]:
+         return self.model_dump(by_alias=True)
+
+     @property
+     def _context(self):
+         current_context = context.get_run_context()
+         if current_context is None:
+             raise RuntimeError("No run context available in current execution context")
+         return current_context
+
+     def set_secrets_as_env_variables(self):
+         # Preparing the environment for the task execution
+         current_context = context.get_run_context()
+         if current_context is None:
+             raise RuntimeError("No run context available for secrets")
+
+         for key in self.secrets:
+             secret_value = current_context.secrets.get(key)
+             os.environ[key] = secret_value
+
+     def delete_secrets_from_env_variables(self):
+         # Cleaning up the environment after the task execution
+         for key in self.secrets:
+             if key in os.environ:
+                 del os.environ[key]
+
+     def execute_command(
+         self,
+         iter_variable: Optional[IterableParameterModel] = None,
+     ) -> StepAttempt:
+         """The function to execute the command.
+
+         The iter_variable is sent in as an argument to the function.
+
+         Args:
+             iter_variable (IterableParameterModel, optional): If the command is part of a
+                 map node, the iteration value. Defaults to None.
+
+         Raises:
+             NotImplementedError: Base class, not implemented
+         """
+         raise NotImplementedError()
+
+     async def execute_command_async(
+         self,
+         iter_variable: Optional[IterableParameterModel] = None,
+         event_callback: Optional[Callable[[dict], None]] = None,
+     ) -> StepAttempt:
+         """
+         Async command execution.
+
+         Only implemented by task types that support async execution
+         (AsyncPythonTaskType). Sync task types (PythonTaskType,
+         NotebookTaskType, ShellTaskType) raise NotImplementedError.
+
+         Args:
+             iter_variable: If the command is part of a map node, the iteration value.
+             event_callback: Optional callback for streaming events.
+
+         Raises:
+             NotImplementedError: If the task type does not support async execution.
+         """
+         raise NotImplementedError(
+             f"{self.__class__.__name__} does not support async execution. "
+             f"Use AsyncPythonTask for async functions."
+         )
+
+     def _diff_parameters(
+         self, parameters_in: Dict[str, Parameter], context_params: Dict[str, Parameter]
+     ) -> Dict[str, Parameter]:
+         # A parameter is kept only if it is new or differs from the incoming parameters
+         diff: Dict[str, Parameter] = {}
+         for param_name, param in context_params.items():
+             if param_name in parameters_in:
+                 if parameters_in[param_name] != param:
+                     diff[param_name] = param
+                 continue
+
+             diff[param_name] = param
+
+         return diff
+
+     @contextlib.contextmanager
+     def expose_secrets(self):
+         """Context manager to expose secrets to the execution."""
+         self.set_secrets_as_env_variables()
+         try:
+             yield
+         except Exception as e:  # pylint: disable=broad-except
+             logger.exception(e)
+         finally:
+             self.delete_secrets_from_env_variables()
+
+     def _safe_serialize_params(self, params: Dict[str, Parameter]) -> Dict[str, Any]:
+         """Safely serialize parameters for telemetry, truncating per value.
+
+         ObjectParameter values are not serializable (pickled objects),
+         so they are represented as "<object>".
+         """
+         serializable: Dict[str, Any] = {}
+         for k, v in params.items():
+             if isinstance(v, ObjectParameter):
+                 serializable[k] = "<object>"
+             else:
+                 serializable[k] = truncate_value(v.get_value())
+         return serializable
+
+     def _emit_event(self, event: Dict[str, Any]) -> None:
+         """Push the event to the stream queue if one is set (for SSE streaming)."""
+         from runnable.telemetry import get_stream_queue
+
+         q = get_stream_queue()
+         if q is not None:
+             q.put_nowait(event)
+
+     @contextlib.contextmanager
+     def execution_context(
+         self,
+         iter_variable: Optional[IterableParameterModel] = None,
+         allow_complex: bool = True,
+     ):
+         params = self._context.run_log_store.get_parameters(
+             run_id=self._context.run_id, internal_branch_name=self.internal_branch_name
+         ).copy()
+         logger.info(f"Parameters available for the execution: {params}")
+
+         task_console.log("Parameters available for the execution:")
+         task_console.log(params)
+
+         logger.debug(f"Resolved parameters: {params}")
+
+         if not allow_complex:
+             params = {
+                 key: value
+                 for key, value in params.items()
+                 if isinstance(value, JsonParameter)
+                 or isinstance(value, MetricParameter)
+             }
+
+         parameters_in = copy.deepcopy(params)
+         try:
+             yield params
+         except Exception as e:  # pylint: disable=broad-except
+             console.log(e, style=defaults.error_style)
+             logger.exception(e)
+         finally:
+             # Update parameters
+             # This should only update the parameters that are changed at the root level.
+             diff_parameters = self._diff_parameters(
+                 parameters_in=parameters_in, context_params=params
+             )
+             self._context.run_log_store.set_parameters(
+                 parameters=diff_parameters,
+                 run_id=self._context.run_id,
+                 internal_branch_name=self.internal_branch_name,
+             )
+
+
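The write-back contract of execution_context leans on _diff_parameters: only parameters that are new or whose values changed during the task are persisted. A small sketch of that behaviour, assuming pydantic's field-based equality for JsonParameter (the names and values are made up):

before = {
    "x": JsonParameter(kind="json", value=1),
    "y": JsonParameter(kind="json", value="a"),
}
after = {
    "x": JsonParameter(kind="json", value=1),    # unchanged -> dropped
    "y": JsonParameter(kind="json", value="b"),  # changed   -> kept
    "z": JsonParameter(kind="json", value=5),    # new       -> kept
}

# Any concrete task type inherits _diff_parameters; the command is hypothetical.
task = PythonTaskType(command="my_module.my_function")
delta = task._diff_parameters(parameters_in=before, context_params=after)
assert set(delta) == {"y", "z"}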
+ def task_return_to_parameter(task_return: TaskReturns, value: Any) -> Parameter:
+     # implicit support for pydantic models
+     if isinstance(value, BaseModel) and task_return.kind == "json":
+         try:
+             return JsonParameter(kind="json", value=value.model_dump(by_alias=True))
+         except PicklingError:
+             logging.warning("Pydantic model is not serializable")
+
+     if task_return.kind == "json":
+         return JsonParameter(kind="json", value=value)
+
+     if task_return.kind == "metric":
+         return MetricParameter(kind="metric", value=value)
+
+     if task_return.kind == "object":
+         obj = ObjectParameter(value=task_return.name, kind="object")
+         obj.put_object(data=value)
+         return obj
+
+     raise Exception(f"Unknown return type: {task_return.kind}")
+
+
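How the return kinds map onto parameter models, as a sketch (ObjectParameter.put_object needs a live run context, so the object kind is omitted here; the values are illustrative):

param = task_return_to_parameter(
    task_return=TaskReturns(name="accuracy", kind="metric"), value=0.93
)
assert isinstance(param, MetricParameter)

class Score(BaseModel):
    value: float

# Pydantic models returned under kind "json" are dumped to plain dicts first.
param = task_return_to_parameter(
    task_return=TaskReturns(name="score"), value=Score(value=0.5)
)
assert isinstance(param, JsonParameter)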
+ class PythonTaskType(BaseTaskType):  # pylint: disable=too-few-public-methods
+     """
+     --8<-- [start:python_reference]
+     An execution node of the pipeline of python functions.
+     Please refer to define pipeline/tasks/python for more information.
+
+     As part of the dag definition, a python task is defined as follows:
+
+     dag:
+       steps:
+         python_task:  # The name of the node
+           type: task
+           command_type: python  # this is the default
+           command: my_module.my_function  # the dotted path to the function. Please refer to
+             the yaml section of define pipeline/tasks/python for concrete details.
+           returns:
+             - name:  # The name to assign the return value
+               kind: json  # the default value is json; it can be object for python
+                 objects and metric for metrics
+           secrets:
+             - my_secret_key  # A list of secrets to expose via the secrets manager
+           catalog:
+             get:
+               - A list of glob patterns to get from the catalog to the local file system
+             put:
+               - A list of glob patterns to put into the catalog from the local file system
+           on_failure: The name of the step to traverse in case of failure
+           overrides:
+             Individual tasks can override the global configuration by referring to the
+             specific override.
+
+             For example:
+             # Global configuration
+             executor:
+               type: local-container
+               config:
+                 docker_image: "runnable/runnable:latest"
+               overrides:
+                 custom_docker_image:
+                   docker_image: "runnable/runnable:custom"
+
+             ## In the node definition
+             overrides:
+               local-container:
+                 docker_image: "runnable/runnable:custom"
+
+             This instruction will override the docker image for the local-container executor.
+           next: The next node to execute after this task; use "success" to terminate the
+             pipeline successfully or "fail" to terminate it with an error.
+     --8<-- [end:python_reference]
+     """
+
+     task_type: str = Field(default="python", serialization_alias="command_type")
+     command: str
+
+     def execute_command(
+         self,
+         iter_variable: Optional[IterableParameterModel] = None,
+     ) -> StepAttempt:
+         """Execute the python function as defined by the command."""
+         attempt_log = StepAttempt(
+             status=defaults.FAIL,
+             start_time=str(datetime.now()),
+             retry_indicator=self._context.retry_indicator,
+         )
+
+         with logfire.span(
+             "task:{task_name}",
+             task_name=self.command,
+             task_type=self.task_type,
+         ):
+             with (
+                 self.execution_context(iter_variable=iter_variable) as params,
+                 self.expose_secrets() as _,
+             ):
+                 logfire.info(
+                     "Task started",
+                     inputs=self._safe_serialize_params(params),
+                 )
+                 self._emit_event(
+                     {
+                         "type": "task_started",
+                         "name": self.command,
+                         "inputs": self._safe_serialize_params(params),
+                     }
+                 )
+
+                 module, func = utils.get_module_and_attr_names(self.command)
+                 sys.path.insert(
+                     0, os.getcwd()
+                 )  # Need to add the current directory to the path
+                 imported_module = importlib.import_module(module)
+                 f = getattr(imported_module, func)
+
+                 try:
+                     try:
+                         filtered_parameters = parameters.filter_arguments_for_func(
+                             f, params.copy(), iter_variable
+                         )
+                         logger.info(
+                             f"Calling {func} from {module} with {filtered_parameters}"
+                         )
+                         with redirect_output(console=task_console) as (
+                             buffer,
+                             stderr_buffer,
+                         ):
+                             user_set_parameters = f(
+                                 **filtered_parameters
+                             )  # This is a tuple or a single value
+                     except Exception as e:
+                         raise exceptions.CommandCallError(
+                             f"Function call: {self.command} did not succeed.\n"
+                         ) from e
+                     finally:
+                         attempt_log.input_parameters = params.copy()
+                         if iter_variable and iter_variable.map_variable:
+                             attempt_log.input_parameters.update(
+                                 {
+                                     k: JsonParameter(value=v.value, kind="json")
+                                     for k, v in iter_variable.map_variable.items()
+                                 }
+                             )
+
+                     if self.returns:
+                         if not isinstance(
+                             user_set_parameters, tuple
+                         ):  # make it a tuple
+                             user_set_parameters = (user_set_parameters,)
+
+                         if len(user_set_parameters) != len(self.returns):
+                             raise ValueError(
+                                 "Returns task signature does not match the function returns"
+                             )
+
+                         output_parameters: Dict[str, Parameter] = {}
+                         metrics: Dict[str, Parameter] = {}
+
+                         for i, task_return in enumerate(self.returns):
+                             output_parameter = task_return_to_parameter(
+                                 task_return=task_return,
+                                 value=user_set_parameters[i],
+                             )
+
+                             if task_return.kind == "metric":
+                                 metrics[task_return.name] = output_parameter
+
+                             output_parameters[task_return.name] = output_parameter
+
+                         attempt_log.output_parameters = output_parameters
+                         attempt_log.user_defined_metrics = metrics
+                         params.update(output_parameters)
+
+                         logfire.info(
+                             "Task completed",
+                             outputs=self._safe_serialize_params(output_parameters),
+                             status="success",
+                         )
+                         self._emit_event(
+                             {
+                                 "type": "task_completed",
+                                 "name": self.command,
+                                 "outputs": self._safe_serialize_params(
+                                     output_parameters
+                                 ),
+                             }
+                         )
+                     else:
+                         logfire.info("Task completed", status="success")
+                         self._emit_event(
+                             {
+                                 "type": "task_completed",
+                                 "name": self.command,
+                             }
+                         )
+
+                     attempt_log.status = defaults.SUCCESS
+                 except Exception as _e:
+                     msg = f"Call to the function {self.command} did not succeed.\n"
+                     attempt_log.message = msg
+                     task_console.print_exception(show_locals=False)
+                     task_console.log(_e, style=defaults.error_style)
+                     logfire.error("Task failed", error=str(_e)[:256])
+                     self._emit_event(
+                         {
+                             "type": "task_error",
+                             "name": self.command,
+                             "error": str(_e)[:256],
+                         }
+                     )
+
+         attempt_log.end_time = str(datetime.now())
+
+         return attempt_log
+
+
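From the user's side the contract is positional: the function's return values must line up one-to-one with the returns declared on the task. A hypothetical module matching the YAML in the docstring above:

# my_module.py (hypothetical)
def my_function(x: int, threshold: float = 0.5):
    # arguments are filled by name from the run's parameters
    doubled = x * 2
    accuracy = 0.91
    # must line up with the task's returns, e.g.
    #   returns: [{name: doubled}, {name: accuracy, kind: metric}]
    return doubled, accuracy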
+ class NotebookTaskType(BaseTaskType):
+     """
+     --8<-- [start:notebook_reference]
+     An execution node of the pipeline of notebook execution.
+     Please refer to define pipeline/tasks/notebook for more information.
+
+     As part of the dag definition, a notebook task is defined as follows:
+
+     dag:
+       steps:
+         notebook_task:  # The name of the node
+           type: task
+           command_type: notebook
+           command: the path to the notebook relative to the project root.
+           optional_ploomber_args: a dictionary of arguments to be passed to ploomber engine
+           returns:
+             - name:  # The name to assign the return value
+               kind: json  # the default value is json; it can be object for python
+                 objects and metric for metrics
+           secrets:
+             - my_secret_key  # A list of secrets to expose via the secrets manager
+           catalog:
+             get:
+               - A list of glob patterns to get from the catalog to the local file system
+             put:
+               - A list of glob patterns to put into the catalog from the local file system
+           on_failure: The name of the step to traverse in case of failure
+           overrides:
+             Individual tasks can override the global configuration by referring to the
+             specific override.
+
+             For example:
+             # Global configuration
+             executor:
+               type: local-container
+               config:
+                 docker_image: "runnable/runnable:latest"
+               overrides:
+                 custom_docker_image:
+                   docker_image: "runnable/runnable:custom"
+
+             ## In the node definition
+             overrides:
+               local-container:
+                 docker_image: "runnable/runnable:custom"
+
+             This instruction will override the docker image for the local-container executor.
+           next: The next node to execute after this task; use "success" to terminate the
+             pipeline successfully or "fail" to terminate it with an error.
+     --8<-- [end:notebook_reference]
+     """
+
+     task_type: str = Field(default="notebook", serialization_alias="command_type")
+     command: str
+     optional_ploomber_args: dict = {}
+
+     @field_validator("command")
+     @classmethod
+     def notebook_should_end_with_ipynb(cls, command: str) -> str:
+         if not command.endswith(".ipynb"):
+             raise Exception("Notebook task should point to a ipynb file")
+
+         return command
+
+     def get_notebook_output_path(
+         self,
+         iter_variable: Optional[IterableParameterModel] = None,
+     ) -> str:
+         tag = ""
+         if iter_variable and iter_variable.map_variable:
+             for key, value_model in iter_variable.map_variable.items():
+                 tag += f"{key}_{value_model.value}_"
+
+         if isinstance(self._context, context.PipelineContext):
+             assert self._context.pipeline_executor._context_node
+             tag += self._context.pipeline_executor._context_node.name
+
+         tag = "".join(x for x in tag if x.isalnum()).strip("-")
+
+         output_path = Path(".", self.command)
+         file_name = output_path.parent / (output_path.stem + f"-{tag}_out.ipynb")
+
+         return str(file_name)
+
+     def execute_command(
+         self,
+         iter_variable: Optional[IterableParameterModel] = None,
+     ) -> StepAttempt:
+         """Execute the python notebook as defined by the command.
+
+         Args:
+             iter_variable (IterableParameterModel, optional): If the node is part of an
+                 internal branch. Defaults to None.
+
+         Raises:
+             ImportError: If necessary dependencies are not installed
+             Exception: If anything else fails
+         """
+         attempt_log = StepAttempt(
+             status=defaults.FAIL,
+             start_time=str(datetime.now()),
+             retry_indicator=self._context.retry_indicator,
+         )
+
+         with logfire.span(
+             "task:{task_name}",
+             task_name=self.command,
+             task_type=self.task_type,
+         ):
+             try:
+                 import ploomber_engine as pm
+                 from ploomber_engine.ipython import PloomberClient
+
+                 notebook_output_path = self.get_notebook_output_path(
+                     iter_variable=iter_variable
+                 )
+
+                 with (
+                     self.execution_context(
+                         iter_variable=iter_variable, allow_complex=False
+                     ) as params,
+                     self.expose_secrets() as _,
+                 ):
+                     logfire.info(
+                         "Task started",
+                         inputs=self._safe_serialize_params(params),
+                     )
+                     self._emit_event(
+                         {
+                             "type": "task_started",
+                             "name": self.command,
+                             "inputs": self._safe_serialize_params(params),
+                         }
+                     )
+
+                     attempt_log.input_parameters = params.copy()
+                     copy_params = copy.deepcopy(params)
+
+                     if iter_variable and iter_variable.map_variable:
+                         for key, value_model in iter_variable.map_variable.items():
+                             copy_params[key] = JsonParameter(
+                                 kind="json", value=value_model.value
+                             )
+
+                     notebook_params = {k: v.get_value() for k, v in copy_params.items()}
+
+                     ploomber_optional_args = self.optional_ploomber_args
+
+                     kwds = {
+                         "input_path": self.command,
+                         "output_path": notebook_output_path,
+                         "parameters": notebook_params,
+                         "log_output": True,
+                         "progress_bar": False,
+                     }
+                     kwds.update(ploomber_optional_args)
+
+                     with redirect_output(console=task_console) as (
+                         buffer,
+                         stderr_buffer,
+                     ):
+                         pm.execute_notebook(**kwds)
+
+                     current_context = context.get_run_context()
+                     if current_context is None:
+                         raise RuntimeError(
+                             "No run context available for catalog operations"
+                         )
+                     current_context.catalog.put(name=notebook_output_path)
+
+                     client = PloomberClient.from_path(path=notebook_output_path)
+                     namespace = client.get_namespace()
+
+                     output_parameters: Dict[str, Parameter] = {}
+                     try:
+                         for task_return in self.returns:
+                             template_vars = {}
+                             if iter_variable and iter_variable.map_variable:
+                                 template_vars = {
+                                     k: v.value
+                                     for k, v in iter_variable.map_variable.items()
+                                 }
+                             param_name = Template(task_return.name).safe_substitute(
+                                 template_vars  # type: ignore
+                             )
+
+                             output_parameters[param_name] = task_return_to_parameter(
+                                 task_return=task_return,
+                                 value=namespace[task_return.name],
+                             )
+                     except PicklingError as e:
+                         logger.exception("Notebooks cannot return objects")
+                         logger.exception(e)
+                         logfire.error("Notebook pickling error", error=str(e)[:256])
+                         raise
+
+                     if output_parameters:
+                         attempt_log.output_parameters = output_parameters
+                         params.update(output_parameters)
+                         logfire.info(
+                             "Task completed",
+                             outputs=self._safe_serialize_params(output_parameters),
+                             status="success",
+                         )
+                         self._emit_event(
+                             {
+                                 "type": "task_completed",
+                                 "name": self.command,
+                                 "outputs": self._safe_serialize_params(
+                                     output_parameters
+                                 ),
+                             }
+                         )
+                     else:
+                         logfire.info("Task completed", status="success")
+                         self._emit_event(
+                             {
+                                 "type": "task_completed",
+                                 "name": self.command,
+                             }
+                         )
+
+                     attempt_log.status = defaults.SUCCESS
+
+             except Exception as e:  # includes ImportError when notebook extras are missing
+                 msg = (
+                     f"Call to the notebook command {self.command} did not succeed.\n"
+                     "Ensure that you have installed runnable with notebook extras"
+                 )
+                 logger.exception(msg)
+                 logger.exception(e)
+                 logfire.error("Task failed", error=str(e)[:256])
+                 self._emit_event(
+                     {"type": "task_error", "name": self.command, "error": str(e)[:256]}
+                 )
+
+                 attempt_log.status = defaults.FAIL
+
+         attempt_log.end_time = str(datetime.now())
+
+         return attempt_log
+
+
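On the notebook side the mechanism is plain variables: notebook_params are injected by ploomber-engine (papermill-style, via a cell tagged "parameters"), and each name listed in returns is read back out of the executed notebook's namespace. A hypothetical cell:

# cell tagged "parameters" -- values are overridden with notebook_params
n_samples = 100

# ... notebook body ...
accuracy = 0.87  # listed under returns, so it is fetched from the
                 # executed notebook's namespace after execution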
+ class ShellTaskType(BaseTaskType):
+     """
+     --8<-- [start:shell_reference]
+     An execution node of the pipeline of shell execution.
+     Please refer to define pipeline/tasks/shell for more information.
+
+     As part of the dag definition, a shell task is defined as follows:
+
+     dag:
+       steps:
+         shell_task:  # The name of the node
+           type: task
+           command_type: shell
+           command: The command to execute; it can be multiline
+           returns:
+             - name:  # The name to assign the return value
+               kind: json  # the default value is json; it can be object for python
+                 objects and metric for metrics
+           secrets:
+             - my_secret_key  # A list of secrets to expose via the secrets manager
+           catalog:
+             get:
+               - A list of glob patterns to get from the catalog to the local file system
+             put:
+               - A list of glob patterns to put into the catalog from the local file system
+           on_failure: The name of the step to traverse in case of failure
+           overrides:
+             Individual tasks can override the global configuration by referring to the
+             specific override.
+
+             For example:
+             # Global configuration
+             executor:
+               type: local-container
+               config:
+                 docker_image: "runnable/runnable:latest"
+               overrides:
+                 custom_docker_image:
+                   docker_image: "runnable/runnable:custom"
+
+             ## In the node definition
+             overrides:
+               local-container:
+                 docker_image: "runnable/runnable:custom"
+
+             This instruction will override the docker image for the local-container executor.
+           next: The next node to execute after this task; use "success" to terminate the
+             pipeline successfully or "fail" to terminate it with an error.
+     --8<-- [end:shell_reference]
+     """
+
+     task_type: str = Field(default="shell", serialization_alias="command_type")
+     command: str
+
+     @field_validator("returns")
+     @classmethod
+     def returns_should_be_json(cls, returns: List[TaskReturns]):
+         for task_return in returns:
+             if task_return.kind == "object" or task_return.kind == "pydantic":
+                 raise ValueError(
+                     "Pydantic models or Objects are not allowed in returns"
+                 )
+
+         return returns
+
+     def execute_command(
+         self,
+         iter_variable: Optional[IterableParameterModel] = None,
+     ) -> StepAttempt:
+         """Execute the shell command as defined by the command.
+
+         shell=True is used so that chained commands execute in the same shell.
+
+         Args:
+             iter_variable (IterableParameterModel, optional): If the node is part of an
+                 internal branch. Defaults to None.
+         """
+         attempt_log = StepAttempt(
+             status=defaults.FAIL,
+             start_time=str(datetime.now()),
+             retry_indicator=self._context.retry_indicator,
+         )
+         subprocess_env = {}
+
+         # Expose RUNNABLE environment variables to be passed to the subprocess.
+         for key, value in os.environ.items():
+             if key.startswith("RUNNABLE_"):
+                 subprocess_env[key] = value
+
+         # Expose map variables as environment variables
+         if iter_variable and iter_variable.map_variable:
+             for key, value_model in iter_variable.map_variable.items():
+                 subprocess_env[key] = str(value_model.value)
+
+         # Expose secrets as environment variables
+         if self.secrets:
+             current_context = context.get_run_context()
+             if current_context is None:
+                 raise RuntimeError("No run context available for secrets")
+
+             for key in self.secrets:
+                 secret_value = current_context.secrets.get(key)
+                 subprocess_env[key] = secret_value
+
+         with logfire.span(
+             "task:{task_name}",
+             task_name=self.command[:100],  # Truncate long commands
+             task_type=self.task_type,
+         ):
+             try:
+                 with self.execution_context(
+                     iter_variable=iter_variable, allow_complex=False
+                 ) as params:
+                     logfire.info(
+                         "Task started",
+                         inputs=self._safe_serialize_params(params),
+                     )
+                     self._emit_event(
+                         {
+                             "type": "task_started",
+                             "name": self.command[:100],
+                             "inputs": self._safe_serialize_params(params),
+                         }
+                     )
+
+                     subprocess_env.update({k: v.get_value() for k, v in params.items()})
+
+                     attempt_log.input_parameters = params.copy()
+                     # JSON-dump all non-string environment variable values
+                     for key, value in subprocess_env.items():
+                         if isinstance(value, str):
+                             continue
+                         subprocess_env[key] = json.dumps(value)
+
+                     collect_delimiter = "=== COLLECT ==="
+
+                     command = (
+                         self.command.strip() + f" && echo '{collect_delimiter}' && env"
+                     )
+                     logger.info(f"Executing shell command: {command}")
+
+                     capture = False
+                     return_keys = {x.name: x for x in self.returns}
+
+                     proc = subprocess.Popen(
+                         command,
+                         shell=True,
+                         env=subprocess_env,
+                         stdout=subprocess.PIPE,
+                         stderr=subprocess.PIPE,
+                         text=True,
+                     )
+                     result = proc.communicate()
+                     logger.debug(result)
+                     logger.info(proc.returncode)
+
+                     if proc.returncode != 0:
+                         msg = ",".join(result[1].split("\n"))
+                         task_console.print(msg, style=defaults.error_style)
+                         raise exceptions.CommandCallError(msg)
+
+                     # for stderr
+                     for line in result[1].split("\n"):
+                         if line.strip() == "":
+                             continue
+                         task_console.print(line, style=defaults.warning_style)
+
+                     output_parameters: Dict[str, Parameter] = {}
+                     metrics: Dict[str, Parameter] = {}
+
+                     # only from stdout
+                     for line in result[0].split("\n"):
+                         if line.strip() == "":
+                             continue
+
+                         logger.info(line)
+                         task_console.print(line)
+
+                         if line.strip() == collect_delimiter:
+                             # The lines from now on should be captured
+                             capture = True
+                             continue
+
+                         if capture:
+                             key, value = line.strip().split("=", 1)
+                             if key in return_keys:
+                                 task_return = return_keys[key]
+
+                                 try:
+                                     value = json.loads(value)
+                                 except json.JSONDecodeError:
+                                     pass  # keep the raw string value
+
+                                 output_parameter = task_return_to_parameter(
+                                     task_return=task_return,
+                                     value=value,
+                                 )
+
+                                 if task_return.kind == "metric":
+                                     metrics[task_return.name] = output_parameter
+
+                                 output_parameters[task_return.name] = output_parameter
+
+                     attempt_log.output_parameters = output_parameters
+                     attempt_log.user_defined_metrics = metrics
+                     params.update(output_parameters)
+
+                     if output_parameters:
+                         logfire.info(
+                             "Task completed",
+                             outputs=self._safe_serialize_params(output_parameters),
+                             status="success",
+                         )
+                         self._emit_event(
+                             {
+                                 "type": "task_completed",
+                                 "name": self.command[:100],
+                                 "outputs": self._safe_serialize_params(
+                                     output_parameters
+                                 ),
+                             }
+                         )
+                     else:
+                         logfire.info("Task completed", status="success")
+                         self._emit_event(
+                             {
+                                 "type": "task_completed",
+                                 "name": self.command[:100],
+                             }
+                         )
+
+                     attempt_log.status = defaults.SUCCESS
+             except exceptions.CommandCallError as e:
+                 msg = f"Call to the command {self.command} did not succeed"
+                 logger.exception(msg)
+                 logger.exception(e)
+
+                 task_console.log(msg, style=defaults.error_style)
+                 task_console.log(e, style=defaults.error_style)
+                 logfire.error("Task failed", error=str(e)[:256])
+                 self._emit_event(
+                     {
+                         "type": "task_error",
+                         "name": self.command[:100],
+                         "error": str(e)[:256],
+                     }
+                 )
+
+                 attempt_log.status = defaults.FAIL
+
+         attempt_log.end_time = str(datetime.now())
+         return attempt_log
+
+
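The return-value plumbing deserves a spelled-out restatement: the task appends && echo '=== COLLECT ===' && env to the user command, then scans stdout after the delimiter for KEY=value pairs that match declared returns, so a shell task "returns" values by exporting them. A distilled, runnable mirror of that parsing loop (the stdout string is made up):

# Stdout of:  export count=10 && echo '=== COLLECT ===' && env  (simplified)
stdout = "step output\n=== COLLECT ===\nPATH=/usr/bin\ncount=10\n"

capture, collected = False, {}
for line in stdout.split("\n"):
    if line.strip() == "=== COLLECT ===":
        capture = True  # everything after the delimiter is the env dump
        continue
    if capture and "=" in line:
        key, value = line.strip().split("=", 1)
        collected[key] = value

assert collected["count"] == "10"  # json.loads() in the real loop then upgrades "10" -> 10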
+ class AsyncPythonTaskType(BaseTaskType):
+     """
+     An execution node for async Python functions.
+
+     This task type is designed for async functions that need to be awaited.
+     It supports AsyncGenerator functions for streaming events.
+
+     Usage in pipeline definition:
+         task = AsyncPythonTask(
+             function=my_async_function,
+             name="async_task",
+             returns=[...]
+         )
+     """
+
+     task_type: str = Field(default="async-python", serialization_alias="command_type")
+     command: str
+     stream_end_type: str = Field(default="done")
+
+     def execute_command(
+         self,
+         iter_variable: Optional[IterableParameterModel] = None,
+     ) -> StepAttempt:
+         """Sync execution is not supported for async tasks."""
+         raise RuntimeError(
+             "AsyncPythonTaskType requires async execution. "
+             "Use execute_command_async() or run the pipeline with execute_async()."
+         )
+
+     async def execute_command_async(
+         self,
+         iter_variable: Optional[IterableParameterModel] = None,
+         event_callback: Optional[Callable[[dict], None]] = None,
+     ) -> StepAttempt:
+         """Execute the async Python function."""
+         attempt_log = StepAttempt(
+             status=defaults.FAIL,
+             start_time=str(datetime.now()),
+             retry_indicator=self._context.retry_indicator,
+         )
+
+         with logfire.span(
+             "task:{task_name}",
+             task_name=self.command,
+             task_type=self.task_type,
+         ):
+             with (
+                 self.execution_context(iter_variable=iter_variable) as params,
+                 self.expose_secrets() as _,
+             ):
+                 logfire.info(
+                     "Task started",
+                     inputs=self._safe_serialize_params(params),
+                 )
+                 self._emit_event(
+                     {
+                         "type": "task_started",
+                         "name": self.command,
+                         "inputs": self._safe_serialize_params(params),
+                     }
+                 )
+
+                 module, func = utils.get_module_and_attr_names(self.command)
+                 sys.path.insert(0, os.getcwd())
+                 imported_module = importlib.import_module(module)
+                 f = getattr(imported_module, func)
+
+                 try:
+                     try:
+                         filtered_parameters = parameters.filter_arguments_for_func(
+                             f, params.copy(), iter_variable
+                         )
+                         logger.info(
+                             f"Calling async {func} from {module} with {filtered_parameters}"
+                         )
+
+                         with redirect_output(console=task_console) as (
+                             buffer,
+                             stderr_buffer,
+                         ):
+                             result = f(**filtered_parameters)
+
+                             # Check if result is an AsyncGenerator for streaming
+                             if inspect.isasyncgen(result):
+                                 user_set_parameters = None
+                                 async for item in result:
+                                     if isinstance(item, dict) and "type" in item:
+                                         # It's an event - emit it
+                                         if event_callback:
+                                             event_callback(item)
+                                         self._emit_event(item)
+
+                                         # Extract return values from the final event
+                                         # The stream end event contains the actual return values
+                                         if item.get("type") == self.stream_end_type:
+                                             # Remove the "type" key and use remaining keys as return values
+                                             return_data = {
+                                                 k: v
+                                                 for k, v in item.items()
+                                                 if k != "type"
+                                             }
+                                             # If only one value, return it directly; otherwise return tuple
+                                             if len(return_data) == 1:
+                                                 user_set_parameters = list(
+                                                     return_data.values()
+                                                 )[0]
+                                             elif len(return_data) > 1:
+                                                 user_set_parameters = tuple(
+                                                     return_data.values()
+                                                 )
+                             elif inspect.iscoroutine(result):
+                                 # Regular async function
+                                 user_set_parameters = await result
+                             else:
+                                 # Sync function called through async task (shouldn't happen but handle it)
+                                 user_set_parameters = result
+
+                     except Exception as e:
+                         raise exceptions.CommandCallError(
+                             f"Async function call: {self.command} did not succeed.\n"
+                         ) from e
+                     finally:
+                         attempt_log.input_parameters = params.copy()
+                         if iter_variable and iter_variable.map_variable:
+                             attempt_log.input_parameters.update(
+                                 {
+                                     k: JsonParameter(value=v.value, kind="json")
+                                     for k, v in iter_variable.map_variable.items()
+                                 }
+                             )
+
+                     if self.returns:
+                         if not isinstance(user_set_parameters, tuple):
+                             user_set_parameters = (user_set_parameters,)
+
+                         if len(user_set_parameters) != len(self.returns):
+                             raise ValueError(
+                                 "Returns task signature does not match the function returns"
+                             )
+
+                         output_parameters: Dict[str, Parameter] = {}
+                         metrics: Dict[str, Parameter] = {}
+
+                         for i, task_return in enumerate(self.returns):
+                             output_parameter = task_return_to_parameter(
+                                 task_return=task_return,
+                                 value=user_set_parameters[i],
+                             )
+
+                             if task_return.kind == "metric":
+                                 metrics[task_return.name] = output_parameter
+
+                             output_parameters[task_return.name] = output_parameter
+
+                         attempt_log.output_parameters = output_parameters
+                         attempt_log.user_defined_metrics = metrics
+                         params.update(output_parameters)
+
+                         logfire.info(
+                             "Task completed",
+                             outputs=self._safe_serialize_params(output_parameters),
+                             status="success",
+                         )
+                         self._emit_event(
+                             {
+                                 "type": "task_completed",
+                                 "name": self.command,
+                                 "outputs": self._safe_serialize_params(
+                                     output_parameters
+                                 ),
+                             }
+                         )
+                     else:
+                         logfire.info("Task completed", status="success")
+                         self._emit_event(
+                             {
+                                 "type": "task_completed",
+                                 "name": self.command,
+                             }
+                         )
+
+                     attempt_log.status = defaults.SUCCESS
+                 except Exception as _e:
+                     msg = (
+                         f"Call to the async function {self.command} did not succeed.\n"
+                     )
+                     attempt_log.message = msg
+                     task_console.print_exception(show_locals=False)
+                     task_console.log(_e, style=defaults.error_style)
+                     logfire.error("Task failed", error=str(_e)[:256])
+                     self._emit_event(
+                         {
+                             "type": "task_error",
+                             "name": self.command,
+                             "error": str(_e)[:256],
+                         }
+                     )
+
+         attempt_log.end_time = str(datetime.now())
+
+         return attempt_log
+
+
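What an async-generator command looks like from the user side: yield event dicts carrying a "type" key as the stream, and let the final event (whose type matches stream_end_type, "done" by default) carry the return values in its remaining keys. A hypothetical function:

# my_module.py (hypothetical)
import asyncio
from typing import AsyncGenerator


async def stream_tokens(prompt: str) -> AsyncGenerator[dict, None]:
    for token in ["hel", "lo"]:
        await asyncio.sleep(0)  # stand-in for real async work
        yield {"type": "token", "value": token}  # forwarded to event_callback

    # Non-"type" keys of this final event become the task's return values
    yield {"type": "done", "answer": "hello"}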
+ def convert_binary_to_string(data):
+     """
+     Recursively converts 1 and 0 values in a nested dictionary to "1" and "0".
+
+     Note: booleans compare equal to 1 and 0 in Python, so True and False are
+     converted as well.
+
+     Args:
+         data (dict or any): The input data (dictionary, list, or other).
+
+     Returns:
+         dict or any: The modified data with binary values converted to strings.
+     """
+     if isinstance(data, dict):
+         for key, value in data.items():
+             data[key] = convert_binary_to_string(value)
+         return data
+     elif isinstance(data, list):
+         return [convert_binary_to_string(item) for item in data]
+     elif data == 1:
+         return "1"
+     elif data == 0:
+         return "0"
+     else:
+         return data  # Return other values unchanged
+
+
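The conversion in action; note that Python booleans compare equal to 1 and 0, so they would be stringified as well:

data = {"flag": 1, "nested": {"off": 0, "items": [1, 0, 2]}, "keep": "text"}
assert convert_binary_to_string(data) == {
    "flag": "1",
    "nested": {"off": "0", "items": ["1", "0", 2]},
    "keep": "text",
}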
+ def create_task(kwargs_for_init) -> BaseTaskType:
+     """
+     Creates a task object from the command configuration.
+
+     Args:
+         kwargs_for_init (dict): The keyword arguments sent by the user for the task;
+             they are validated against the model class of the resolved task type.
+
+     Returns:
+         tasks.BaseTaskType: The command object
+     """
+     # The incoming dictionary must not be modified, so work on a copy
+     kwargs = kwargs_for_init.copy()
+     command_type = kwargs.pop("command_type", defaults.COMMAND_TYPE)
+
+     kwargs = convert_binary_to_string(kwargs)
+
+     try:
+         task_mgr: driver.DriverManager = driver.DriverManager(
+             namespace="tasks",
+             name=command_type,
+             invoke_on_load=True,
+             invoke_kwds=kwargs,
+         )
+         return cast(BaseTaskType, task_mgr.driver)
+     except Exception as _e:
+         msg = (
+             f"Could not find the task type {command_type}. Please ensure you have installed "
+             "the extension that provides the node type."
+         )
+         raise Exception(msg) from _e
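
create_task resolves command_type through stevedore's "tasks" entry-point namespace, which is what the wheel's entry_points.txt feeds: an installed distribution maps each command_type name to one of the classes above. A usage sketch, assuming the package (and hence its entry points) is installed:

task = create_task(
    {
        "command_type": "shell",
        "command": "export count=10",
        "returns": [{"name": "count"}],
    }
)
assert isinstance(task, ShellTaskType)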