runnable-0.50.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. extensions/README.md +0 -0
  2. extensions/__init__.py +0 -0
  3. extensions/catalog/README.md +0 -0
  4. extensions/catalog/any_path.py +214 -0
  5. extensions/catalog/file_system.py +52 -0
  6. extensions/catalog/minio.py +72 -0
  7. extensions/catalog/pyproject.toml +14 -0
  8. extensions/catalog/s3.py +11 -0
  9. extensions/job_executor/README.md +0 -0
  10. extensions/job_executor/__init__.py +236 -0
  11. extensions/job_executor/emulate.py +70 -0
  12. extensions/job_executor/k8s.py +553 -0
  13. extensions/job_executor/k8s_job_spec.yaml +37 -0
  14. extensions/job_executor/local.py +35 -0
  15. extensions/job_executor/local_container.py +161 -0
  16. extensions/job_executor/pyproject.toml +16 -0
  17. extensions/nodes/README.md +0 -0
  18. extensions/nodes/__init__.py +0 -0
  19. extensions/nodes/conditional.py +301 -0
  20. extensions/nodes/fail.py +78 -0
  21. extensions/nodes/loop.py +394 -0
  22. extensions/nodes/map.py +477 -0
  23. extensions/nodes/parallel.py +281 -0
  24. extensions/nodes/pyproject.toml +15 -0
  25. extensions/nodes/stub.py +93 -0
  26. extensions/nodes/success.py +78 -0
  27. extensions/nodes/task.py +156 -0
  28. extensions/pipeline_executor/README.md +0 -0
  29. extensions/pipeline_executor/__init__.py +871 -0
  30. extensions/pipeline_executor/argo.py +1266 -0
  31. extensions/pipeline_executor/emulate.py +119 -0
  32. extensions/pipeline_executor/local.py +226 -0
  33. extensions/pipeline_executor/local_container.py +369 -0
  34. extensions/pipeline_executor/mocked.py +159 -0
  35. extensions/pipeline_executor/pyproject.toml +16 -0
  36. extensions/run_log_store/README.md +0 -0
  37. extensions/run_log_store/__init__.py +0 -0
  38. extensions/run_log_store/any_path.py +100 -0
  39. extensions/run_log_store/chunked_fs.py +122 -0
  40. extensions/run_log_store/chunked_minio.py +141 -0
  41. extensions/run_log_store/file_system.py +91 -0
  42. extensions/run_log_store/generic_chunked.py +549 -0
  43. extensions/run_log_store/minio.py +114 -0
  44. extensions/run_log_store/pyproject.toml +15 -0
  45. extensions/secrets/README.md +0 -0
  46. extensions/secrets/dotenv.py +62 -0
  47. extensions/secrets/pyproject.toml +15 -0
  48. runnable/__init__.py +108 -0
  49. runnable/catalog.py +141 -0
  50. runnable/cli.py +484 -0
  51. runnable/context.py +730 -0
  52. runnable/datastore.py +1058 -0
  53. runnable/defaults.py +159 -0
  54. runnable/entrypoints.py +390 -0
  55. runnable/exceptions.py +137 -0
  56. runnable/executor.py +561 -0
  57. runnable/gantt.py +1646 -0
  58. runnable/graph.py +501 -0
  59. runnable/names.py +546 -0
  60. runnable/nodes.py +593 -0
  61. runnable/parameters.py +217 -0
  62. runnable/pickler.py +96 -0
  63. runnable/sdk.py +1277 -0
  64. runnable/secrets.py +92 -0
  65. runnable/tasks.py +1268 -0
  66. runnable/telemetry.py +142 -0
  67. runnable/utils.py +423 -0
  68. runnable-0.50.0.dist-info/METADATA +189 -0
  69. runnable-0.50.0.dist-info/RECORD +72 -0
  70. runnable-0.50.0.dist-info/WHEEL +4 -0
  71. runnable-0.50.0.dist-info/entry_points.txt +53 -0
  72. runnable-0.50.0.dist-info/licenses/LICENSE +201 -0
extensions/pipeline_executor/argo.py
@@ -0,0 +1,1266 @@
1
+ import json
2
+ import os
3
+ import random
4
+ import secrets
5
+ import shlex
6
+ import string
7
+ from collections import namedtuple
8
+ from enum import Enum
9
+ from functools import cached_property
10
+ from typing import Annotated, Any, Literal, Optional, cast
11
+
12
+ from pydantic import (
13
+ BaseModel,
14
+ ConfigDict,
15
+ Field,
16
+ PlainSerializer,
17
+ PrivateAttr,
18
+ computed_field,
19
+ model_validator,
20
+ )
21
+ from pydantic.alias_generators import to_camel
22
+ from ruamel.yaml import YAML
23
+
24
+ from extensions.nodes.conditional import ConditionalNode
25
+ from extensions.nodes.map import MapNode
26
+ from extensions.nodes.parallel import ParallelNode
27
+ from extensions.nodes.task import TaskNode
28
+ from extensions.pipeline_executor import GenericPipelineExecutor
29
+ from runnable import defaults, exceptions
30
+ from runnable.datastore import StepAttempt
31
+ from runnable.defaults import IterableParameterModel, MapVariableModel
32
+ from runnable.graph import Graph, search_node_by_internal_name
33
+ from runnable.nodes import BaseNode
34
+
35
+ # TODO: Do we need a PVC if we are using remote storage?
36
+
37
+
38
+ class BaseModelWIthConfig(BaseModel, use_enum_values=True):
39
+ model_config = ConfigDict(
40
+ extra="forbid",
41
+ alias_generator=to_camel,
42
+ populate_by_name=True,
43
+ from_attributes=True,
44
+ validate_default=True,
45
+ )
46
+
47
+
48
+ class BackOff(BaseModelWIthConfig):
49
+ duration: str = Field(default="2m")
50
+ factor: float = Field(default=2)
51
+ max_duration: str = Field(default="1h")
52
+
53
+
54
+ class RetryStrategy(BaseModelWIthConfig):
55
+ back_off: Optional[BackOff] = Field(default=None)
56
+ limit: int = 0
57
+ retry_policy: str = Field(default="Always")
58
+
59
+
60
+ class SecretEnvVar(BaseModel):
61
+ """
62
+ Renders:
63
+ env:
64
+ - name: MYSECRETPASSWORD
65
+ valueFrom:
66
+ secretKeyRef:
67
+ name: my-secret
68
+ key: MYSECRETPASSWORD
69
+ """
70
+
71
+ environment_variable: str = Field(serialization_alias="name")
72
+ secret_name: str = Field(exclude=True)
73
+ secret_key: str = Field(exclude=True)
74
+
75
+ @computed_field # type: ignore
76
+ @property
77
+ def valueFrom(self) -> dict[str, dict[str, str]]:
78
+ return {
79
+ "secretKeyRef": {
80
+ "name": self.secret_name,
81
+ "key": self.secret_key,
82
+ }
83
+ }
84
+
85
+
86
+ class EnvVar(BaseModelWIthConfig):
87
+ name: str
88
+ value: str
89
+
90
+
91
+ class OutputParameter(BaseModelWIthConfig):
92
+ name: str
93
+ value_from: dict[str, str] = {
94
+ "path": "/tmp/output.txt",
95
+ }
96
+
97
+
98
+ class Parameter(BaseModelWIthConfig):
99
+ name: str
100
+ value: Optional[str | int | float | bool] = Field(default=None)
101
+
102
+
103
+ class Inputs(BaseModelWIthConfig):
104
+ parameters: Optional[list[Parameter]] = Field(default=None)
105
+
106
+
107
+ class Outputs(BaseModelWIthConfig):
108
+ parameters: Optional[list[OutputParameter]] = Field(default=None)
109
+
110
+
111
+ class Arguments(BaseModelWIthConfig):
112
+ parameters: Optional[list[Parameter]] = Field(default=None)
113
+
114
+
115
+ class ConfigMapCache(BaseModelWIthConfig):
116
+ name: str
117
+ key: str = Field(default="cache")
118
+
119
+
120
+ class Cache(BaseModelWIthConfig):
121
+ config_map: ConfigMapCache
122
+
123
+
124
+ class Memoize(BaseModelWIthConfig):
125
+ key: str
126
+ cache: Optional[Cache] = Field(default=None)
127
+
128
+
129
+ class TolerationEffect(str, Enum):
130
+ NoSchedule = "NoSchedule"
131
+ PreferNoSchedule = "PreferNoSchedule"
132
+ NoExecute = "NoExecute"
133
+
134
+
135
+ class TolerationOperator(str, Enum):
136
+ Exists = "Exists"
137
+ Equal = "Equal"
138
+
139
+
140
+ class PodMetaData(BaseModelWIthConfig):
141
+ annotations: dict[str, str] = Field(default_factory=dict)
142
+ labels: dict[str, str] = Field(default_factory=dict)
143
+
144
+
145
+ class Toleration(BaseModelWIthConfig):
146
+ effect: Optional[TolerationEffect] = Field(default=None)
147
+ key: Optional[str] = Field(default=None)
148
+ operator: TolerationOperator = Field(default=TolerationOperator.Equal)
149
+ tolerationSeconds: Optional[int] = Field(default=None)
150
+ value: Optional[str] = Field(default=None)
151
+
152
+ @model_validator(mode="after")
153
+ def validate_tolerations(self) -> "Toleration":
154
+ if not self.key:
155
+ if self.operator != TolerationOperator.Exists:
156
+ raise ValueError("Toleration key is required when operator is Equal")
157
+
158
+ if self.operator == TolerationOperator.Exists:
159
+ if self.value:
160
+ raise ValueError(
161
+ "Toleration value is not allowed when operator is Exists"
162
+ )
163
+ return self
164
+
165
+
166
+ class ImagePullPolicy(str, Enum):
167
+ Always = "Always"
168
+ IfNotPresent = "IfNotPresent"
169
+ Never = "Never"
170
+
171
+
172
+ class PersistentVolumeClaimSource(BaseModelWIthConfig):
173
+ claim_name: str
174
+ read_only: bool = Field(default=False)
175
+
176
+
177
+ class Volume(BaseModelWIthConfig):
178
+ name: str
179
+ persistent_volume_claim: PersistentVolumeClaimSource
180
+
181
+ def __hash__(self):
182
+ return hash(self.name)
183
+
184
+
185
+ class VolumeMount(BaseModelWIthConfig):
186
+ mount_path: str
187
+ name: str
188
+ read_only: bool = Field(default=False)
189
+
190
+ @model_validator(mode="after")
191
+ def validate_volume_mount(self) -> "VolumeMount":
192
+ if "." in self.mount_path:
193
+ raise ValueError("mount_path cannot contain '.'")
194
+
195
+ return self
196
+
197
+
198
+ VolumePair = namedtuple("VolumePair", ["volume", "volume_mount"])
199
+
200
+
201
+ class LabelSelectorRequirement(BaseModelWIthConfig):
202
+ key: str
203
+ operator: str
204
+ values: list[str]
205
+
206
+
207
+ class PodGCStrategy(str, Enum):
208
+ OnPodCompletion = "OnPodCompletion"
209
+ OnPodSuccess = "OnPodSuccess"
210
+ OnWorkflowCompletion = "OnWorkflowCompletion"
211
+ OnWorkflowSuccess = "OnWorkflowSuccess"
212
+
213
+
214
+ class LabelSelector(BaseModelWIthConfig):
215
+ matchExpressions: list[LabelSelectorRequirement] = Field(default_factory=list)
216
+ matchLabels: dict[str, str] = Field(default_factory=dict)
217
+
218
+
219
+ class PodGC(BaseModelWIthConfig):
220
+ delete_delay_duration: str = Field(default="1h") # 1 hour
221
+ label_selector: Optional[LabelSelector] = Field(default=None)
222
+ strategy: Optional[PodGCStrategy] = Field(default=None)
223
+
224
+
225
+ class Request(BaseModel):
226
+ """
227
+ The default requests
228
+ """
229
+
230
+ memory: str = "1Gi"
231
+ cpu: str = "250m"
232
+
233
+
234
+ VendorGPU = Annotated[
235
+ Optional[int],
236
+ PlainSerializer(lambda x: str(x), return_type=str, when_used="unless-none"),
237
+ ]
238
+
239
+
240
+ class Limit(Request):
241
+ """
242
+ The default limits
243
+ """
244
+
245
+ gpu: VendorGPU = Field(default=None, serialization_alias="nvidia.com/gpu")
246
+
247
+
248
+ class Resources(BaseModel):
249
+ limits: Limit = Field(default=Limit(), serialization_alias="limits")
250
+ requests: Request = Field(default=Request(), serialization_alias="requests")
251
+
252
+
253
+ # Let's construct this from UserDefaults
254
+ class ArgoTemplateDefaults(BaseModelWIthConfig):
255
+ active_deadline_seconds: Optional[int] = Field(default=86400) # 1 day
256
+ fail_fast: bool = Field(default=True)
257
+ node_selector: dict[str, str] = Field(default_factory=dict)
258
+ parallelism: Optional[int] = Field(default=None)
259
+ retry_strategy: Optional[RetryStrategy] = Field(default=None)
260
+ timeout: Optional[str] = Field(default=None)
261
+ tolerations: Optional[list[Toleration]] = Field(default=None)
262
+
263
+ model_config = ConfigDict(
264
+ extra="ignore",
265
+ )
266
+
267
+
268
+ class CommonDefaults(BaseModelWIthConfig):
269
+ active_deadline_seconds: Optional[int] = Field(default=86400) # 1 day
270
+ fail_fast: bool = Field(default=True)
271
+ node_selector: dict[str, str] = Field(default_factory=dict)
272
+ parallelism: Optional[int] = Field(default=None)
273
+ retry_strategy: Optional[RetryStrategy] = Field(default=None)
274
+ timeout: Optional[str] = Field(default=None)
275
+ tolerations: Optional[list[Toleration]] = Field(default=None)
276
+ image_pull_policy: ImagePullPolicy = Field(default=ImagePullPolicy.Always)
277
+ resources: Resources = Field(default_factory=Resources)
278
+ env: list[EnvVar | SecretEnvVar] = Field(default_factory=list)
279
+
280
+
281
+ # The user provided defaults at the top level
282
+ class UserDefaults(CommonDefaults):
283
+ image: str
284
+
285
+
286
+ # Overrides need not have image
287
+ class Overrides(CommonDefaults):
288
+ image: Optional[str] = Field(default=None)
289
+
290
+
291
+ # User provides this as part of the argoSpec
292
+ # some can be provided here or as a template default or node override
293
+ class ArgoWorkflowSpec(BaseModelWIthConfig):
294
+ active_deadline_seconds: int = Field(default=86400) # 1 day for the whole workflow
295
+ arguments: Optional[Arguments] = Field(default=None)
296
+ entrypoint: Literal["runnable-dag"] = Field(default="runnable-dag", frozen=True)
297
+ node_selector: dict[str, str] = Field(default_factory=dict)
298
+ parallelism: Optional[int] = Field(default=None) # Global parallelism
299
+ pod_gc: Optional[PodGC] = Field(default=None, serialization_alias="podGC")
300
+ retry_strategy: Optional[RetryStrategy] = Field(default=None)
301
+ service_account_name: Optional[str] = Field(default=None)
302
+ tolerations: Optional[list[Toleration]] = Field(default=None)
303
+ template_defaults: Optional[ArgoTemplateDefaults] = Field(default=None)
304
+
305
+
306
+ class ArgoMetadata(BaseModelWIthConfig):
307
+ annotations: Optional[dict[str, str]] = Field(default=None)
308
+ generate_name: str # User can mention this to uniquely identify the run
309
+ labels: dict[str, str] = Field(default_factory=dict)
310
+ namespace: Optional[str] = Field(default="default")
311
+
312
+
313
+ class ArgoWorkflow(BaseModelWIthConfig):
314
+ apiVersion: Literal["argoproj.io/v1alpha1"] = Field(
315
+ default="argoproj.io/v1alpha1", frozen=True
316
+ )
317
+ kind: Literal["Workflow"] = Field(default="Workflow", frozen=True)
318
+ metadata: ArgoMetadata
319
+ spec: ArgoWorkflowSpec
320
+
321
+
322
+ class CronSchedule(BaseModelWIthConfig):
323
+ """Minimal cron schedule configuration for CronWorkflows."""
324
+
325
+ schedules: list[str] # Cron expressions, e.g. ["0 0 * * *"]
326
+ timezone: Optional[str] = Field(default=None) # e.g. "America/Los_Angeles"
327
+
328
+
329
+ # The below are not visible to the user
330
+ class DagTask(BaseModelWIthConfig):
331
+ name: str
332
+ template: str # Should be the name of a container template or dag template
333
+ arguments: Optional[Arguments] = Field(default=None)
334
+ with_param: Optional[str] = Field(default=None)
335
+ when_param: Optional[str] = Field(default=None, serialization_alias="when")
336
+ depends: Optional[str] = Field(default=None)
337
+
338
+
339
+ class CoreDagTemplate(BaseModelWIthConfig):
340
+ tasks: list[DagTask] = Field(default_factory=list[DagTask])
341
+
342
+
343
+ class CoreContainerTemplate(BaseModelWIthConfig):
344
+ image: str
345
+ command: list[str]
346
+ image_pull_policy: ImagePullPolicy = Field(default=ImagePullPolicy.IfNotPresent)
347
+ env: list[EnvVar | SecretEnvVar] = Field(default_factory=list)
348
+ volume_mounts: list[VolumeMount] = Field(default_factory=list)
349
+ resources: Resources = Field(default_factory=Resources)
350
+
351
+
352
+ class DagTemplate(BaseModelWIthConfig):
353
+ name: str
354
+ dag: CoreDagTemplate = Field(default_factory=CoreDagTemplate)
355
+ inputs: Optional[Inputs] = Field(default=None)
356
+ parallelism: Optional[int] = Field(default=None) # Not sure if this is needed
357
+ fail_fast: bool = Field(default=True)
358
+
359
+ model_config = ConfigDict(
360
+ extra="ignore",
361
+ )
362
+
363
+ def __hash__(self):
364
+ return hash(self.name)
365
+
366
+
367
+ class ContainerTemplate(BaseModelWIthConfig):
368
+ name: str
369
+ container: CoreContainerTemplate
370
+ inputs: Optional[Inputs] = Field(default=None)
371
+ outputs: Optional[Outputs] = Field(default=None)
372
+ memoize: Optional[Memoize] = Field(default=None)
373
+
374
+ active_deadline_seconds: Optional[int] = Field(default=86400) # 1 day
375
+ metadata: Optional[PodMetaData] = Field(default=None)
376
+ node_selector: dict[str, str] = Field(default_factory=dict)
377
+ parallelism: Optional[int] = Field(default=None) # Not sure if this is needed
378
+ retry_strategy: Optional[RetryStrategy] = Field(default=None)
379
+ timeout: Optional[str] = Field(default=None)
380
+ tolerations: Optional[list[Toleration]] = Field(default=None)
381
+ volumes: Optional[list[Volume]] = Field(default=None)
382
+
383
+ model_config = ConfigDict(
384
+ extra="ignore",
385
+ )
386
+
387
+ def __hash__(self):
388
+ return hash(self.name)
389
+
390
+
391
+ class CustomVolume(BaseModelWIthConfig):
392
+ mount_path: str
393
+ persistent_volume_claim: PersistentVolumeClaimSource
394
+
395
+
396
+ class ArgoExecutor(GenericPipelineExecutor):
397
+ """
398
+ Executes the pipeline using Argo Workflows.
399
+
400
+ The defaults configuration is kept similar to the
401
+ [Argo Workflow spec](https://argo-workflows.readthedocs.io/en/latest/fields/#workflow).
402
+
403
+ Configuration:
404
+
405
+ ```yaml
406
+ pipeline-executor:
407
+ type: argo
408
+ config:
409
+ pvc_for_runnable: "my-pvc"
410
+ custom_volumes:
411
+ - mount_path: "/tmp"
412
+ persistent_volume_claim:
413
+ claim_name: "my-pvc"
414
+ read_only: false/true
415
+ expose_parameters_as_inputs: true/false
416
+ configmap_cache_name: "my-cache-name" # Optional: defaults to runnable-xxxxxx
417
+ secrets_from_k8s:
418
+ - key1
419
+ - key2
420
+ - ...
421
+ output_file: "argo-pipeline.yaml"
422
+ log_level: "DEBUG"/"INFO"/"WARNING"/"ERROR"/"CRITICAL"
423
+ cron_schedule: # Optional: generates CronWorkflow instead of Workflow
424
+ schedules:
425
+ - "0 0 * * *" # Cron expressions
426
+ timezone: "UTC" # Optional timezone
427
+ defaults:
428
+ image: "my-image"
429
+ activeDeadlineSeconds: 86400
430
+ failFast: true
431
+ nodeSelector:
432
+ label: value
433
+ parallelism: 1
434
+ retryStrategy:
435
+ backoff:
436
+ duration: "2m"
437
+ factor: 2
438
+ maxDuration: "1h"
439
+ limit: 0
440
+ retryPolicy: "Always"
441
+ timeout: "1h"
442
+ tolerations:
443
+ imagePullPolicy: "Always"/"IfNotPresent"/"Never"
444
+ resources:
445
+ limits:
446
+ memory: "1Gi"
447
+ cpu: "250m"
448
+ gpu: 0
449
+ requests:
450
+ memory: "1Gi"
451
+ cpu: "250m"
452
+ env:
453
+ - name: "MY_ENV"
454
+ value: "my-value"
455
+ - name: secret_env
456
+ secretName: "my-secret"
457
+ secretKey: "my-key"
458
+ overrides:
459
+ key1:
460
+ ... similar structure to defaults
461
+
462
+ argoWorkflow:
463
+ metadata:
464
+ annotations:
465
+ key1: value1
466
+ key2: value2
467
+ generateName: "my-workflow"
468
+ labels:
469
+ key1: value1
470
+
471
+ ```
472
+
473
+ As of now, ```runnable``` needs a PVC to store the logs and the catalog; it is provided via ```pvc_for_runnable```.
474
+ - ```custom_volumes``` can be used to mount additional volumes to the container.
475
+
476
+ - ```expose_parameters_as_inputs``` can be used to expose the initial parameters as inputs to the workflow.
477
+ - ```secrets_from_k8s``` can be used to expose the secrets from the k8s secret store.
478
+ - ```output_file``` is the file where the argo pipeline will be dumped.
479
+ - ```log_level``` is the log level for the containers.
480
+ - ```cron_schedule``` generates an Argo CronWorkflow instead of a regular Workflow for scheduled execution.
481
+ - ```defaults``` is the default configuration for all the containers.
482
+
483
+
484
+ """
485
+
486
+ service_name: str = "argo"
487
+ _should_setup_run_log_at_traversal: bool = PrivateAttr(default=False)
488
+ mock: bool = False
489
+
490
+ model_config = ConfigDict(
491
+ extra="forbid",
492
+ alias_generator=to_camel,
493
+ populate_by_name=True,
494
+ from_attributes=True,
495
+ use_enum_values=True,
496
+ )
497
+ pvc_for_runnable: Optional[str] = Field(default=None)
498
+ custom_volumes: Optional[list[CustomVolume]] = Field(
499
+ default_factory=list[CustomVolume]
500
+ )
501
+
502
+ expose_parameters_as_inputs: bool = True
503
+ configmap_cache_name: Optional[str] = Field(default=None)
504
+ secret_from_k8s: Optional[str] = Field(default=None)
505
+ output_file: str = Field(default="argo-pipeline.yaml")
506
+ log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field(
507
+ default="INFO"
508
+ )
509
+
510
+ defaults: UserDefaults
511
+ argo_workflow: ArgoWorkflow
512
+ cron_schedule: Optional[CronSchedule] = Field(default=None)
513
+
514
+ overrides: dict[str, Overrides] = Field(default_factory=dict)
515
+
516
+ # This should be used when we refer to run_id or log_level in the containers
517
+ _run_id_as_parameter: str = PrivateAttr(default="{{workflow.parameters.run_id}}")
518
+ _log_level_as_parameter: str = PrivateAttr(
519
+ default="{{workflow.parameters.log_level}}"
520
+ )
521
+
522
+ _templates: list[ContainerTemplate | DagTemplate] = PrivateAttr(
523
+ default_factory=list
524
+ )
525
+ _container_log_location: str = PrivateAttr(default="/tmp/run_logs/")
526
+ _container_catalog_location: str = PrivateAttr(default="/tmp/catalog/")
527
+ _added_initial_container: bool = PrivateAttr(default=False)
528
+ _cache_name: str = PrivateAttr(default="")
529
+
530
+ def model_post_init(self, __context: Any) -> None:
531
+ self.argo_workflow.spec.template_defaults = ArgoTemplateDefaults(
532
+ **self.defaults.model_dump()
533
+ )
534
+ self._cache_name = self._generate_cache_name()
535
+
536
+ def _generate_cache_name(self) -> str:
537
+ """Generate or use configured ConfigMap name for this workflow's cache."""
538
+ if self.configmap_cache_name:
539
+ return self.configmap_cache_name
540
+
541
+ chars = string.ascii_lowercase + string.digits
542
+ suffix = "".join(secrets.choice(chars) for _ in range(6))
543
+ return f"runnable-{suffix}"
544
+
545
+ @property
546
+ def cache_name(self) -> str:
547
+ """Get the ConfigMap name for this workflow's memoization cache."""
548
+ return self._cache_name
549
+
550
+ def sanitize_name(self, name: str) -> str:
551
+ formatted_name = name.replace(" ", "-").replace(".", "-").replace("_", "-")
552
+ tag = "".join(random.choices(string.ascii_lowercase + string.digits, k=6))
553
+ unique_name = f"{formatted_name}-{tag}"
554
+ unique_name = unique_name.replace("map-variable-placeholder-", "")
555
+ return unique_name
556
+
557
+ def _add_retry_env_vars(self, container_template: CoreContainerTemplate):
558
+ """Add retry environment variables to all containers."""
559
+ # Add retry run id environment variable
560
+ retry_run_id_env = EnvVar(
561
+ name=defaults.RETRY_RUN_ID, value="{{workflow.parameters.retry_run_id}}"
562
+ )
563
+ container_template.env.append(retry_run_id_env)
564
+
565
+ # Add retry indicator environment variable
566
+ retry_indicator_env = EnvVar(
567
+ name=defaults.RETRY_INDICATOR,
568
+ value="{{workflow.parameters.retry_indicator}}",
569
+ )
570
+ container_template.env.append(retry_indicator_env)
571
+
572
+ def _set_up_initial_container(self, container_template: CoreContainerTemplate):
573
+ if self._added_initial_container:
574
+ return
575
+
576
+ parameters: list[Parameter] = []
577
+
578
+ if self.argo_workflow.spec.arguments:
579
+ parameters = self.argo_workflow.spec.arguments.parameters or []
580
+
581
+ for parameter in parameters or []:
582
+ key, _ = parameter.name, parameter.value
583
+ env_var = EnvVar(
584
+ name=defaults.PARAMETER_PREFIX + key,
585
+ value="{{workflow.parameters." + key + "}}",
586
+ )
587
+ container_template.env.append(env_var)
588
+
589
+ env_var = EnvVar(name="error_on_existing_run_id", value="true")
590
+ container_template.env.append(env_var)
591
+
592
+ # After the first container is added, set the added_initial_container to True
593
+ self._added_initial_container = True
594
+
595
+ def _create_fan_templates(
596
+ self,
597
+ node: BaseNode,
598
+ mode: str,
599
+ parameters: Optional[list[Parameter]],
600
+ task_name: str,
601
+ ):
602
+ iter_variable: IterableParameterModel = IterableParameterModel()
603
+ for parameter in parameters or []:
604
+ iter_variable.map_variable[parameter.name] = ( # type: ignore
605
+ "{{inputs.parameters." + str(parameter.name) + "}}"
606
+ )
607
+
608
+ fan_command = self._context.get_fan_command(
609
+ mode=mode,
610
+ node=node,
611
+ run_id=self._run_id_as_parameter,
612
+ iter_variable=iter_variable,
613
+ )
614
+
615
+ core_container_template = CoreContainerTemplate(
616
+ command=shlex.split(fan_command),
617
+ image=self.defaults.image,
618
+ image_pull_policy=self.defaults.image_pull_policy,
619
+ volume_mounts=[
620
+ volume_pair.volume_mount for volume_pair in self.volume_pairs
621
+ ],
622
+ )
623
+
624
+ # Add retry environment variables to all containers
625
+ self._add_retry_env_vars(container_template=core_container_template)
626
+
627
+ # Either a task or a fan-out can be the first container
628
+ self._set_up_initial_container(container_template=core_container_template)
629
+
630
+ task_name += f"-fan-{mode}"
631
+
632
+ outputs: Optional[Outputs] = None
633
+ if mode == "out" and node.node_type == "map":
634
+ outputs = Outputs(parameters=[OutputParameter(name="iterate-on")])
635
+ if mode == "out" and node.node_type == "conditional":
636
+ outputs = Outputs(parameters=[OutputParameter(name="case")])
637
+
638
+ config_map_key = (
639
+ "{{workflow.parameters.run_id}}-"
640
+ + f"{task_name}-"
641
+ + "".join(random.choices(string.ascii_lowercase + string.digits, k=4))
642
+ )
643
+
644
+ container_template = ContainerTemplate(
645
+ name=task_name,
646
+ container=core_container_template,
647
+ inputs=Inputs(parameters=parameters),
648
+ outputs=outputs,
649
+ memoize=Memoize(
650
+ key=config_map_key,
651
+ cache=Cache(config_map=ConfigMapCache(name=self.cache_name)),
652
+ ),
653
+ active_deadline_seconds=self.defaults.active_deadline_seconds,
654
+ node_selector=self.defaults.node_selector,
655
+ parallelism=self.defaults.parallelism,
656
+ retry_strategy=self.defaults.retry_strategy,
657
+ timeout=self.defaults.timeout,
658
+ tolerations=self.defaults.tolerations,
659
+ volumes=[volume_pair.volume for volume_pair in self.volume_pairs],
660
+ )
661
+
662
+ self._templates.append(container_template)
663
+
664
+ def _create_container_template(
665
+ self,
666
+ node: BaseNode,
667
+ task_name: str,
668
+ inputs: Optional[Inputs] = None,
669
+ ) -> ContainerTemplate:
670
+ assert node.node_type in ["task", "success", "stub", "fail"]
671
+
672
+ node_override = None
673
+ if hasattr(node, "overrides"):
674
+ override_key = node.overrides.get(self.service_name, "")
675
+ try:
676
+ node_override = self.overrides.get(override_key)
677
+ except: # noqa
678
+ raise Exception("Override not found for: ", override_key)
679
+
680
+ effective_settings = self.defaults.model_dump()
681
+ if node_override:
682
+ effective_settings.update(node_override.model_dump(exclude_none=True))
683
+
684
+ inputs = inputs or Inputs(parameters=[])
685
+
686
+ # Should look like: '{"map_variable":{"chunk":{"value":3}},"loop_variable":[]}'
687
+ iter_variable: IterableParameterModel = IterableParameterModel()
688
+ for parameter in inputs.parameters or []:
689
+ map_variable = MapVariableModel(
690
+ value=f"{{{{inputs.parameters.{str(parameter.name)}}}}}"
691
+ )
692
+ iter_variable.map_variable[parameter.name] = ( # type: ignore
693
+ map_variable
694
+ )
695
+
696
+ # command = "runnable execute-single-node"
697
+ command = self._context.get_node_callable_command(
698
+ node=node,
699
+ iter_variable=iter_variable,
700
+ over_write_run_id=self._run_id_as_parameter,
701
+ log_level=self._log_level_as_parameter,
702
+ )
703
+
704
+ core_container_template = CoreContainerTemplate(
705
+ command=shlex.split(command),
706
+ image=effective_settings["image"],
707
+ image_pull_policy=effective_settings["image_pull_policy"],
708
+ resources=effective_settings["resources"],
709
+ volume_mounts=[
710
+ volume_pair.volume_mount for volume_pair in self.volume_pairs
711
+ ],
712
+ )
713
+
714
+ # Add retry environment variables to all containers
715
+ self._add_retry_env_vars(container_template=core_container_template)
716
+
717
+ self._set_up_initial_container(container_template=core_container_template)
718
+ self._expose_secrets_to_task(
719
+ working_on=node, container_template=core_container_template
720
+ )
721
+ self._set_env_vars_to_task(node, core_container_template)
722
+ config_map_key = (
723
+ "{{workflow.parameters.run_id}}-"
724
+ + f"{task_name}-"
725
+ + "".join(random.choices(string.ascii_lowercase + string.digits, k=4))
726
+ )
727
+
728
+ container_template = ContainerTemplate(
729
+ name=task_name,
730
+ container=core_container_template,
731
+ inputs=Inputs(
732
+ parameters=[
733
+ Parameter(name=param.name) for param in inputs.parameters or []
734
+ ]
735
+ ),
736
+ memoize=Memoize(
737
+ key=config_map_key,
738
+ cache=Cache(config_map=ConfigMapCache(name=self.cache_name)),
739
+ ),
740
+ volumes=[volume_pair.volume for volume_pair in self.volume_pairs],
741
+ **node_override.model_dump() if node_override else {},
742
+ )
743
+
744
+ if iter_variable and iter_variable.map_variable:
745
+ # Do not cache map tasks as they have different inputs each time
746
+ container_template.memoize = None
747
+
748
+ return container_template
749
+
750
+ def _set_env_vars_to_task(
751
+ self, working_on: BaseNode, container_template: CoreContainerTemplate
752
+ ):
753
+ global_envs: dict[str, str] = {}
754
+
755
+ # Apply defaults environment variables to all node types
756
+ for env_var in self.defaults.env:
757
+ env_var = cast(EnvVar, env_var)
758
+ global_envs[env_var.name] = env_var.value
759
+
760
+ # Apply node-specific overrides only for task nodes that support overrides
761
+ if working_on.node_type in ["task"] and hasattr(working_on, "overrides"):
762
+ override_key = working_on.overrides.get(self.service_name, "")
763
+ node_override = self.overrides.get(override_key, None)
764
+
765
+ # Update the global envs with the node overrides
766
+ if node_override:
767
+ for env_var in node_override.env:
768
+ env_var = cast(EnvVar, env_var)
769
+ global_envs[env_var.name] = env_var.value
770
+
771
+ for key, value in global_envs.items():
772
+ env_var_to_add = EnvVar(name=key, value=value)
773
+ container_template.env.append(env_var_to_add)
774
+
775
+ # Add argo uid as environment variable
776
+ argo_uid_env = EnvVar(
777
+ name="RUNNABLE_CODE_ID_ARGO_WORKFLOW_UID",
778
+ value="{{workflow.uid}}",
779
+ )
780
+ container_template.env.append(argo_uid_env)
781
+
782
+ def _expose_secrets_to_task(
783
+ self,
784
+ working_on: BaseNode,
785
+ container_template: CoreContainerTemplate,
786
+ ):
787
+ if not isinstance(working_on, TaskNode):
788
+ return
789
+ secrets = working_on.executable.secrets
790
+ for secret in secrets:
791
+ assert self.secret_from_k8s is not None
792
+ secret_env_var = SecretEnvVar(
793
+ environment_variable=secret,
794
+ secret_name=self.secret_from_k8s, # This has to be exposed from config
795
+ secret_key=secret,
796
+ )
797
+ container_template.env.append(secret_env_var)
798
+
799
+ def _handle_failures(
800
+ self,
801
+ working_on: BaseNode,
802
+ dag: Graph,
803
+ task_name: str,
804
+ parent_dag_template: DagTemplate,
805
+ ):
806
+ if working_on._get_on_failure_node():
807
+ # Create a new dag template
808
+ on_failure_dag: DagTemplate = DagTemplate(name=f"on-failure-{task_name}")
809
+ # Add on failure of the current task to be the failure dag template
810
+ on_failure_task = DagTask(
811
+ name=f"on-failure-{task_name}",
812
+ template=f"on-failure-{task_name}",
813
+ depends=task_name + ".Failed",
814
+ )
815
+ # Set failfast of the dag template to be false
816
+ # If not, this branch will never be invoked
817
+ parent_dag_template.fail_fast = False
818
+
819
+ assert parent_dag_template.dag
820
+
821
+ parent_dag_template.dag.tasks.append(on_failure_task)
822
+
823
+ self._gather_tasks_for_dag_template(
824
+ on_failure_dag,
825
+ dag=dag,
826
+ start_at=working_on._get_on_failure_node(),
827
+ )
828
+
829
+ # For the future me:
830
+ # - A task can output an array: in this case, it's the fan-out.
831
+ # - We are using withParam and arguments of the map template to send that value in
832
+ # - The map template should receive that value as a parameter into the template.
833
+ # - The task then starts to use it as inputs.parameters.iterate-on
834
+ # - The when param should be an evaluation
835
+
836
+ def _gather_tasks_for_dag_template(
837
+ self,
838
+ dag_template: DagTemplate,
839
+ dag: Graph,
840
+ start_at: str,
841
+ parameters: Optional[list[Parameter]] = None,
842
+ ):
843
+ current_node: str = start_at
844
+ depends: str = ""
845
+
846
+ dag_template.dag = CoreDagTemplate()
847
+
848
+ while True:
849
+ # Create the dag task for the parent dag
850
+ working_on: BaseNode = dag.get_node_by_name(current_node)
851
+ task_name = self.sanitize_name(working_on.internal_name)
852
+ current_task = DagTask(
853
+ name=task_name,
854
+ template=task_name,
855
+ depends=depends if not depends else depends + ".Succeeded",
856
+ arguments=Arguments(
857
+ parameters=[
858
+ Parameter(
859
+ name=param.name,
860
+ value=f"{{{{inputs.parameters.{param.name}}}}}",
861
+ )
862
+ for param in parameters or []
863
+ ]
864
+ ),
865
+ )
866
+ dag_template.dag.tasks.append(current_task)
867
+ depends = task_name
868
+
869
+ match working_on.node_type:
870
+ case "task" | "success" | "stub" | "fail":
871
+ template_of_container = self._create_container_template(
872
+ working_on,
873
+ task_name=task_name,
874
+ inputs=Inputs(parameters=parameters),
875
+ )
876
+ assert template_of_container.container is not None
877
+
878
+ self._templates.append(template_of_container)
879
+
880
+ case "map" | "parallel" | "conditional":
881
+ assert (
882
+ isinstance(working_on, MapNode)
883
+ or isinstance(working_on, ParallelNode)
884
+ or isinstance(working_on, ConditionalNode)
885
+ )
886
+ node_type = working_on.node_type
887
+
888
+ composite_template: DagTemplate = DagTemplate(
889
+ name=task_name, fail_fast=False
890
+ )
891
+
892
+ # Add the fan out task
893
+ fan_out_task = DagTask(
894
+ name=f"{task_name}-fan-out",
895
+ template=f"{task_name}-fan-out",
896
+ arguments=Arguments(parameters=parameters),
897
+ )
898
+ composite_template.dag.tasks.append(fan_out_task)
899
+ self._create_fan_templates(
900
+ node=working_on,
901
+ mode="out",
902
+ parameters=parameters,
903
+ task_name=task_name,
904
+ )
905
+
906
+ # Add the composite task
907
+ with_param: Optional[str] = None
908
+ when_param: Optional[str] = None
909
+ added_parameters = parameters or []
910
+ branches = {}
911
+ if node_type == "map":
912
+ # If the node is map, we need to handle the iterate as and on
913
+ assert isinstance(working_on, MapNode)
914
+ added_parameters = added_parameters + [
915
+ Parameter(name=working_on.iterate_as, value="{{item}}")
916
+ ]
917
+ with_param = f"{{{{tasks.{task_name}-fan-out.outputs.parameters.iterate-on}}}}"
918
+
919
+ branches["branch"] = working_on.branch
920
+ elif node_type == "parallel":
921
+ assert isinstance(working_on, ParallelNode)
922
+ branches = working_on.branches
923
+ elif node_type == "conditional":
924
+ assert isinstance(working_on, ConditionalNode)
925
+ branches = working_on.branches
926
+ when_param = (
927
+ f"{{{{tasks.{task_name}-fan-out.outputs.parameters.case}}}}"
928
+ )
929
+ else:
930
+ raise ValueError("Invalid node type")
931
+
932
+ fan_in_depends = ""
933
+
934
+ for name, branch in branches.items():
935
+ match_when = branch.internal_branch_name.split(".")[-1]
936
+ name = (
937
+ name.replace(" ", "-").replace(".", "-").replace("_", "-")
938
+ )
939
+
940
+ if node_type == "conditional":
941
+ assert isinstance(working_on, ConditionalNode)
942
+ when_param = f"'{match_when}' == {{{{tasks.{task_name}-fan-out.outputs.parameters.case}}}}"
943
+
944
+ branch_task = DagTask(
945
+ name=f"{task_name}-{name}",
946
+ template=f"{task_name}-{name}",
947
+ depends=f"{task_name}-fan-out.Succeeded",
948
+ arguments=Arguments(parameters=added_parameters),
949
+ with_param=with_param,
950
+ when_param=when_param,
951
+ )
952
+ composite_template.dag.tasks.append(branch_task)
953
+
954
+ branch_template = DagTemplate(
955
+ name=branch_task.name,
956
+ inputs=Inputs(
957
+ parameters=[
958
+ Parameter(name=param.name, value=None)
959
+ for param in added_parameters
960
+ ]
961
+ ),
962
+ )
963
+
964
+ assert isinstance(branch, Graph)
965
+
966
+ self._gather_tasks_for_dag_template(
967
+ dag_template=branch_template,
968
+ dag=branch,
969
+ start_at=branch.start_at,
970
+ parameters=added_parameters,
971
+ )
972
+
973
+ fan_in_depends += f"{branch_task.name}.Succeeded || {branch_task.name}.Failed || "
974
+
975
+ fan_in_task = DagTask(
976
+ name=f"{task_name}-fan-in",
977
+ template=f"{task_name}-fan-in",
978
+ depends=fan_in_depends.strip(" || "),
979
+ arguments=Arguments(parameters=parameters),
980
+ )
981
+
982
+ composite_template.dag.tasks.append(fan_in_task)
983
+ self._create_fan_templates(
984
+ node=working_on,
985
+ mode="in",
986
+ parameters=parameters,
987
+ task_name=task_name,
988
+ )
989
+
990
+ self._templates.append(composite_template)
991
+
992
+ self._handle_failures(
993
+ working_on,
994
+ dag,
995
+ task_name,
996
+ parent_dag_template=dag_template,
997
+ )
998
+
999
+ if working_on.node_type == "success" or working_on.node_type == "fail":
1000
+ break
1001
+
1002
+ current_node = working_on._get_next_node()
1003
+
1004
+ self._templates.append(dag_template)
1005
+
1006
+ def execute_graph(
1007
+ self,
1008
+ dag: Graph,
1009
+ iter_variable: Optional[IterableParameterModel] = None,
1010
+ ):
1011
+ # All the arguments set at the spec level can be referred to as "{{workflow.parameters.*}}"
1012
+ # We want to use that functionality to override the parameters at the task level
1013
+ # We should be careful to override them only at the first task.
1014
+ arguments = [] # Can be updated in the UI
1015
+ if self.expose_parameters_as_inputs:
1016
+ for key, value in self._get_parameters().items():
1017
+ value = value.get_value() # type: ignore
1018
+ if isinstance(value, dict) or isinstance(value, list):
1019
+ continue
1020
+
1021
+ parameter = Parameter(name=key, value=value) # type: ignore
1022
+ arguments.append(parameter)
1023
+
1024
+ # run_id parameter - required at workflow submission time (no default value)
1025
+ run_id_var = Parameter(name="run_id")
1026
+ arguments.append(run_id_var)
1027
+
1028
+ # Optional retry parameters with empty string defaults
1029
+ retry_run_id_var = Parameter(name="retry_run_id", value="")
1030
+ retry_indicator_var = Parameter(name="retry_indicator", value="")
1031
+ arguments.append(retry_run_id_var)
1032
+ arguments.append(retry_indicator_var)
1033
+
1034
+ log_level_var = Parameter(name="log_level", value=self.log_level)
1035
+ arguments.append(log_level_var)
1036
+ self.argo_workflow.spec.arguments = Arguments(parameters=arguments)
1037
+
1038
+ # This is the entry point of the argo execution
1039
+ runnable_dag: DagTemplate = DagTemplate(name="runnable-dag")
1040
+
1041
+ self._gather_tasks_for_dag_template(
1042
+ runnable_dag,
1043
+ dag,
1044
+ start_at=dag.start_at,
1045
+ parameters=[],
1046
+ )
1047
+
1048
+ argo_workflow_dump = self.argo_workflow.model_dump(
1049
+ by_alias=True,
1050
+ exclude_none=True,
1051
+ round_trip=False,
1052
+ )
1053
+ argo_workflow_dump["spec"]["templates"] = [
1054
+ template.model_dump(
1055
+ by_alias=True,
1056
+ exclude_none=True,
1057
+ )
1058
+ for template in self._templates
1059
+ ]
1060
+
1061
+ argo_workflow_dump["spec"]["volumes"] = [
1062
+ volume_pair.volume.model_dump(by_alias=True)
1063
+ for volume_pair in self.volume_pairs
1064
+ ]
1065
+
1066
+ # If cron_schedule is set, wrap in CronWorkflow
1067
+ if self.cron_schedule:
1068
+ output_dump = self._wrap_as_cron_workflow(argo_workflow_dump)
1069
+ else:
1070
+ output_dump = argo_workflow_dump
1071
+
1072
+ yaml = YAML()
1073
+ with open(self.output_file, "w") as f:
1074
+ yaml.indent(mapping=2, sequence=4, offset=2)
1075
+ yaml.dump(
1076
+ output_dump,
1077
+ f,
1078
+ )
1079
+
1080
+ def _wrap_as_cron_workflow(self, workflow_dump: dict) -> dict:
1081
+ """Wrap a Workflow dump as a CronWorkflow."""
1082
+ assert self.cron_schedule is not None
1083
+
1084
+ # Extract metadata and convert generateName to name
1085
+ metadata = workflow_dump["metadata"].copy()
1086
+ generate_name = metadata.pop("generateName", "runnable-cron-")
1087
+ metadata["name"] = generate_name.rstrip("-")
1088
+
1089
+ # Build CronWorkflow spec
1090
+ cron_spec: dict[str, Any] = {
1091
+ "schedules": self.cron_schedule.schedules,
1092
+ "workflowSpec": workflow_dump["spec"],
1093
+ }
1094
+
1095
+ if self.cron_schedule.timezone:
1096
+ cron_spec["timezone"] = self.cron_schedule.timezone
1097
+
1098
+ return {
1099
+ "apiVersion": "argoproj.io/v1alpha1",
1100
+ "kind": "CronWorkflow",
1101
+ "metadata": metadata,
1102
+ "spec": cron_spec,
1103
+ }
1104
+
1105
+ def _implicitly_fail(
1106
+ self,
1107
+ node: BaseNode,
1108
+ iter_variable: Optional[IterableParameterModel] = None,
1109
+ ):
1110
+ assert self._context.dag
1111
+ _, current_branch = search_node_by_internal_name(
1112
+ dag=self._context.dag, internal_name=node.internal_name
1113
+ )
1114
+ _, next_node_name = self._get_status_and_next_node_name(
1115
+ node, current_branch, iter_variable=iter_variable
1116
+ )
1117
+ if next_node_name:
1118
+ # Terminal nodes do not have next node name
1119
+ next_node = current_branch.get_node_by_name(next_node_name)
1120
+
1121
+ if next_node.node_type == defaults.FAIL:
1122
+ self.execute_node(next_node, iter_variable=iter_variable)
1123
+
1124
+ def add_code_identities(self, node: BaseNode, attempt_log: StepAttempt):
1125
+ super().add_code_identities(node, attempt_log)
1126
+
1127
+ workflow_uid = os.getenv("RUNNABLE_CODE_ID_ARGO_WORKFLOW_UID")
1128
+
1129
+ if workflow_uid:
1130
+ code_id = self._context.run_log_store.create_code_identity()
1131
+
1132
+ code_id.code_identifier = workflow_uid
1133
+ code_id.code_identifier_type = "argo"
1134
+ code_id.code_identifier_dependable = True
1135
+ code_id.code_identifier_url = "argo workflow"
1136
+ attempt_log.code_identities.append(code_id)
1137
+
1138
+ def execute_node(
1139
+ self,
1140
+ node: BaseNode,
1141
+ iter_variable: Optional[IterableParameterModel] = None,
1142
+ ):
1143
+ error_on_existing_run_id = os.environ.get("error_on_existing_run_id", "false")
1144
+ exists_ok = error_on_existing_run_id == "false"
1145
+
1146
+ self._use_volumes()
1147
+ self._set_up_run_log(exists_ok=exists_ok)
1148
+
1149
+ try:
1150
+ # This should only happen during a retry
1151
+ step_log = self._context.run_log_store.get_step_log(
1152
+ node._get_step_log_name(iter_variable), self._context.run_id
1153
+ )
1154
+ assert self._context.is_retry
1155
+ except exceptions.StepLogNotFoundError:
1156
+ step_log = self._context.run_log_store.create_step_log(
1157
+ node.name, node._get_step_log_name(iter_variable)
1158
+ )
1159
+
1160
+ step_log.step_type = node.node_type
1161
+ step_log.status = defaults.PROCESSING
1162
+ self._context.run_log_store.add_step_log(step_log, self._context.run_id)
1163
+
1164
+ self._execute_node(node=node, iter_variable=iter_variable)
1165
+
1166
+ # Raise exception if the step failed
1167
+ step_log = self._context.run_log_store.get_step_log(
1168
+ node._get_step_log_name(iter_variable), self._context.run_id
1169
+ )
1170
+ if step_log.status == defaults.FAIL:
1171
+ run_log = self._context.run_log_store.get_run_log_by_id(
1172
+ self._context.run_id
1173
+ )
1174
+ run_log.status = defaults.FAIL
1175
+ self._context.run_log_store.put_run_log(run_log)
1176
+ raise Exception(f"Step {node.name} failed")
1177
+
1178
+ # This makes the fail node execute if we are heading that way.
1179
+ self._implicitly_fail(node, iter_variable)
1180
+
1181
+ def fan_out(
1182
+ self,
1183
+ node: BaseNode,
1184
+ iter_variable: Optional[IterableParameterModel] = None,
1185
+ ):
1186
+ # This could be the first step of the graph
1187
+ self._use_volumes()
1188
+
1189
+ error_on_existing_run_id = os.environ.get("error_on_existing_run_id", "false")
1190
+ exists_ok = error_on_existing_run_id == "false"
1191
+ self._set_up_run_log(exists_ok=exists_ok)
1192
+
1193
+ super().fan_out(node, iter_variable)
1194
+
1195
+ # If it's a map node, write the list values to "/tmp/output.txt"
1196
+ if node.node_type == "map":
1197
+ assert isinstance(node, MapNode)
1198
+ iterate_on = self._context.run_log_store.get_parameters(
1199
+ self._context.run_id
1200
+ )[node.iterate_on]
1201
+
1202
+ with open("/tmp/output.txt", mode="w", encoding="utf-8") as myfile:
1203
+ json.dump(iterate_on.get_value(), myfile, indent=4)
1204
+
1205
+ if node.node_type == "conditional":
1206
+ assert isinstance(node, ConditionalNode)
1207
+
1208
+ with open("/tmp/output.txt", mode="w", encoding="utf-8") as myfile:
1209
+ json.dump(node.get_parameter_value(), myfile, indent=4)
1210
+
1211
+ def fan_in(
1212
+ self,
1213
+ node: BaseNode,
1214
+ iter_variable: Optional[IterableParameterModel] = None,
1215
+ ):
1216
+ self._use_volumes()
1217
+ super().fan_in(node, iter_variable)
1218
+
1219
+ def _use_volumes(self):
1220
+ match self._context.run_log_store.service_name:
1221
+ case "file-system":
1222
+ self._context.run_log_store.log_folder = self._container_log_location
1223
+ case "chunked-fs":
1224
+ self._context.run_log_store.log_folder = self._container_log_location
1225
+
1226
+ match self._context.catalog.service_name:
1227
+ case "file-system":
1228
+ self._context.catalog.catalog_location = (
1229
+ self._container_catalog_location
1230
+ )
1231
+
1232
+ @cached_property
1233
+ def volume_pairs(self) -> list[VolumePair]:
1234
+ volume_pairs: list[VolumePair] = []
1235
+
1236
+ if self.pvc_for_runnable:
1237
+ common_volume = Volume(
1238
+ name="runnable",
1239
+ persistent_volume_claim=PersistentVolumeClaimSource(
1240
+ claim_name=self.pvc_for_runnable
1241
+ ),
1242
+ )
1243
+ common_volume_mount = VolumeMount(
1244
+ name="runnable",
1245
+ mount_path="/tmp",
1246
+ )
1247
+ volume_pairs.append(
1248
+ VolumePair(volume=common_volume, volume_mount=common_volume_mount)
1249
+ )
1250
+ counter = 0
1251
+ for custom_volume in self.custom_volumes or []:
1252
+ name = f"custom-volume-{counter}"
1253
+ volume_pairs.append(
1254
+ VolumePair(
1255
+ volume=Volume(
1256
+ name=name,
1257
+ persistent_volume_claim=custom_volume.persistent_volume_claim,
1258
+ ),
1259
+ volume_mount=VolumeMount(
1260
+ name=name,
1261
+ mount_path=custom_volume.mount_path,
1262
+ ),
1263
+ )
1264
+ )
1265
+ counter += 1
1266
+ return volume_pairs
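
The executor builds the final manifest by dumping these pydantic models with `model_dump(by_alias=True)` (see `execute_graph`), so the snake_case fields above come out as the camelCase keys Argo and Kubernetes expect, via the `to_camel` alias generator on `BaseModelWIthConfig`. The following is a minimal sketch of that serialization, not part of the package itself, assuming the models are importable from `extensions.pipeline_executor.argo` as laid out in this diff:

```python
# Illustrative only; not part of runnable 0.50.0. Assumes the models defined
# above are importable from extensions.pipeline_executor.argo.
from extensions.pipeline_executor.argo import (
    PersistentVolumeClaimSource,
    Volume,
    VolumeMount,
)

volume = Volume(
    name="runnable",
    persistent_volume_claim=PersistentVolumeClaimSource(claim_name="my-pvc"),
)
mount = VolumeMount(name="runnable", mount_path="/tmp")

# by_alias=True applies the to_camel alias generator from BaseModelWIthConfig,
# producing the Kubernetes-style keys that end up in the generated workflow YAML.
print(volume.model_dump(by_alias=True))
# e.g. {'name': 'runnable', 'persistentVolumeClaim': {'claimName': 'my-pvc', 'readOnly': False}}
print(mount.model_dump(by_alias=True))
# e.g. {'mountPath': '/tmp', 'name': 'runnable', 'readOnly': False}
```

This is the same dump used for `spec.volumes` in `execute_graph`, which is why the executor can hand the result straight to `ruamel.yaml` without any further key translation.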