runnable 0.19.1__py3-none-any.whl → 0.20.0__py3-none-any.whl

@@ -1,65 +1,53 @@
  import json
- import logging
+ import os
  import random
  import shlex
  import string
- from abc import ABC, abstractmethod
- from collections import OrderedDict
- from typing import Dict, List, Optional, Union, cast
+ from collections import namedtuple
+ from enum import Enum
+ from functools import cached_property
+ from typing import Annotated, Literal, Optional

  from pydantic import (
  BaseModel,
  ConfigDict,
  Field,
+ PlainSerializer,
+ PrivateAttr,
  computed_field,
- field_serializer,
- field_validator,
+ model_validator,
  )
- from pydantic.functional_serializers import PlainSerializer
+ from pydantic.alias_generators import to_camel
  from ruamel.yaml import YAML
- from typing_extensions import Annotated

- from extensions.nodes.nodes import DagNode, MapNode, ParallelNode
+ from extensions.nodes.nodes import MapNode, ParallelNode, TaskNode
  from extensions.pipeline_executor import GenericPipelineExecutor
- from runnable import defaults, exceptions, utils
+ from runnable import defaults, utils
  from runnable.defaults import TypeMapVariable
- from runnable.graph import Graph, create_node, search_node_by_internal_name
+ from runnable.graph import Graph, search_node_by_internal_name
  from runnable.nodes import BaseNode

- logger = logging.getLogger(defaults.NAME)
-
- # TODO: Leave the run log in consistent state.
- # TODO: Make the config camel case just like Argo does.
-
- """
- executor:
- type: argo
- config:
- image: # apply to template
- max_workflow_duration: # Apply to spec
- nodeSelector: #Apply to spec
- parallelism: #apply to spec
- resources: # convert to podSpecPath
- limits:
- requests:
- retryStrategy:
- max_step_duration: # apply to templateDefaults
- step_timeout: # apply to templateDefaults
- tolerations: # apply to spec
- imagePullPolicy: # apply to template
-
- overrides:
- override:
- tolerations: # template
- image: # container
- max_step_duration: # template
- step_timeout: #template
- nodeSelector: #template
- parallelism: # this need to applied for map
- resources: # container
- imagePullPolicy: #container
- retryStrategy: # template
- """
+
+ class BaseModelWIthConfig(BaseModel, use_enum_values=True):
+ model_config = ConfigDict(
+ extra="forbid",
+ alias_generator=to_camel,
+ populate_by_name=True,
+ from_attributes=True,
+ validate_default=True,
+ )
+
+
+ class BackOff(BaseModelWIthConfig):
+ duration: str = Field(default="2m")
+ factor: float = Field(default=2)
+ max_duration: str = Field(default="1h")
+
+
+ class RetryStrategy(BaseModelWIthConfig):
+ back_off: Optional[BackOff] = Field(default=None)
+ limit: int = 0
+ retry_policy: str = Field(default="Always")


  class SecretEnvVar(BaseModel):
@@ -79,7 +67,7 @@ class SecretEnvVar(BaseModel):

  @computed_field # type: ignore
  @property
- def valueFrom(self) -> Dict[str, Dict[str, str]]:
+ def value_from(self) -> dict[str, dict[str, str]]:
  return {
  "secretKeyRef": {
  "name": self.secret_name,
@@ -88,778 +76,756 @@ class SecretEnvVar(BaseModel):
  }


- class EnvVar(BaseModel):
- """
- Renders:
- parameters: # in arguments
- - name: x
- value: 3 # This is optional for workflow parameters
+ class EnvVar(BaseModelWIthConfig):
+ name: str
+ value: str

- """

+ class OutputParameter(BaseModelWIthConfig):
  name: str
- value: Union[str, int, float] = Field(default="")
+ value_from: dict[str, str] = {
+ "path": "/tmp/output.txt",
+ }


- class Parameter(BaseModel):
+ class Parameter(BaseModelWIthConfig):
  name: str
- value: Optional[str] = None
+ value: Optional[str | int | float | bool] = Field(default=None)

- @field_serializer("name")
- def serialize_name(self, name: str) -> str:
- return f"{str(name)}"

- @field_serializer("value")
- def serialize_value(self, value: str) -> str:
- return f"{value}"
+ class Inputs(BaseModelWIthConfig):
+ parameters: Optional[list[Parameter]] = Field(default=None)


- class OutputParameter(Parameter):
- """
- Renders:
- - name: step-name
- valueFrom:
- path: /tmp/output.txt
- """
+ class Outputs(BaseModelWIthConfig):
+ parameters: Optional[list[OutputParameter]] = Field(default=None)

- path: str = Field(default="/tmp/output.txt", exclude=True)

- @computed_field # type: ignore
- @property
- def valueFrom(self) -> Dict[str, str]:
- return {"path": self.path}
+ class Arguments(BaseModelWIthConfig):
+ parameters: Optional[list[Parameter]] = Field(default=None)


- class Argument(BaseModel):
- """
- Templates are called with arguments, which become inputs for the template
- Renders:
- arguments:
- parameters:
- - name: The name of the parameter
- value: The value of the parameter
- """
+ class TolerationEffect(str, Enum):
+ NoSchedule = "NoSchedule"
+ PreferNoSchedule = "PreferNoSchedule"
+ NoExecute = "NoExecute"

- name: str
- value: str

- @field_serializer("name")
- def serialize_name(self, name: str) -> str:
- return f"{str(name)}"
+ class TolerationOperator(str, Enum):
+ Exists = "Exists"
+ Equal = "Equal"

- @field_serializer("value")
- def serialize_value(self, value: str) -> str:
- return f"{value}"

+ class PodMetaData(BaseModelWIthConfig):
+ annotations: dict[str, str] = Field(default_factory=dict)
+ labels: dict[str, str] = Field(default_factory=dict)

- class Request(BaseModel):
- """
- The default requests
- """

- memory: str = "1Gi"
- cpu: str = "250m"
+ class Toleration(BaseModelWIthConfig):
+ effect: Optional[TolerationEffect] = Field(default=None)
+ key: Optional[str] = Field(default=None)
+ operator: TolerationOperator = Field(default=TolerationOperator.Equal)
+ tolerationSeconds: Optional[int] = Field(default=None)
+ value: Optional[str] = Field(default=None)

+ @model_validator(mode="after")
+ def validate_tolerations(self) -> "Toleration":
+ if not self.key:
+ if self.operator != TolerationOperator.Exists:
+ raise ValueError("Toleration key is required when operator is Equal")

- VendorGPU = Annotated[
- Optional[int],
- PlainSerializer(lambda x: str(x), return_type=str, when_used="unless-none"),
- ]
+ if self.operator == TolerationOperator.Exists:
+ if self.value:
+ raise ValueError(
+ "Toleration value is not allowed when operator is Exists"
+ )
+ return self


- class Limit(Request):
- """
- The default limits
- """
+ class ImagePullPolicy(str, Enum):
+ Always = "Always"
+ IfNotPresent = "IfNotPresent"
+ Never = "Never"

- gpu: VendorGPU = Field(default=None, serialization_alias="nvidia.com/gpu")

+ class PersistentVolumeClaimSource(BaseModelWIthConfig):
+ claim_name: str
+ read_only: bool = Field(default=False)

- class Resources(BaseModel):
- limits: Limit = Field(default=Limit(), serialization_alias="limits")
- requests: Request = Field(default=Request(), serialization_alias="requests")

+ class Volume(BaseModelWIthConfig):
+ name: str
+ persistent_volume_claim: PersistentVolumeClaimSource

- class BackOff(BaseModel):
- duration_in_seconds: int = Field(default=2 * 60, serialization_alias="duration")
- factor: float = Field(default=2, serialization_alias="factor")
- max_duration: int = Field(default=60 * 60, serialization_alias="maxDuration")
+ def __hash__(self):
+ return hash(self.name)

- @field_serializer("duration_in_seconds")
- def cast_duration_as_str(self, duration_in_seconds: int, _info) -> str:
- return str(duration_in_seconds)

- @field_serializer("max_duration")
- def cast_mas_duration_as_str(self, max_duration: int, _info) -> str:
- return str(max_duration)
+ class VolumeMount(BaseModelWIthConfig):
+ mount_path: str
+ name: str
+ read_only: bool = Field(default=False)

+ @model_validator(mode="after")
+ def validate_volume_mount(self) -> "VolumeMount":
+ if "." in self.mount_path:
+ raise ValueError("mount_path cannot contain '.'")
+
+ return self

- class Retry(BaseModel):
- limit: int = 0
- retry_policy: str = Field(default="Always", serialization_alias="retryPolicy")
- back_off: BackOff = Field(default=BackOff(), serialization_alias="backoff")

- @field_serializer("limit")
- def cast_limit_as_str(self, limit: int, _info) -> str:
- return str(limit)
+ VolumePair = namedtuple("VolumePair", ["volume", "volume_mount"])


- class Toleration(BaseModel):
- effect: str
+ class LabelSelectorRequirement(BaseModelWIthConfig):
  key: str
  operator: str
- value: str
-
-
- class TemplateDefaults(BaseModel):
- max_step_duration: int = Field(
- default=60 * 60 * 2,
- serialization_alias="activeDeadlineSeconds",
- gt=0,
- description="Max run time of a step",
- )
-
- @computed_field # type: ignore
- @property
- def timeout(self) -> str:
- return f"{self.max_step_duration + 60*60}s"
-
+ values: list[str]

- ShlexCommand = Annotated[
- str, PlainSerializer(lambda x: shlex.split(x), return_type=List[str])
- ]

+ class PodGCStrategy(str, Enum):
+ OnPodCompletion = "OnPodCompletion"
+ OnPodSuccess = "OnPodSuccess"
+ OnWorkflowCompletion = "OnWorkflowCompletion"
+ OnWorkflowSuccess = "OnWorkflowSuccess"

- class Container(BaseModel):
- image: str
- command: ShlexCommand
- volume_mounts: Optional[List["ContainerVolume"]] = Field(
- default=None, serialization_alias="volumeMounts"
- )
- image_pull_policy: str = Field(default="", serialization_alias="imagePullPolicy")
- resources: Optional[Resources] = Field(
- default=None, serialization_alias="resources"
- )

- env_vars: List[EnvVar] = Field(default_factory=list, exclude=True)
- secrets_from_k8s: List[SecretEnvVar] = Field(default_factory=list, exclude=True)
+ class LabelSelector(BaseModelWIthConfig):
+ matchExpressions: list[LabelSelectorRequirement] = Field(default_factory=list)
+ matchLabels: dict[str, str] = Field(default_factory=dict)

- @computed_field # type: ignore
- @property
- def env(self) -> Optional[List[Union[EnvVar, SecretEnvVar]]]:
- if not self.env_vars and not self.secrets_from_k8s:
- return None

- return self.env_vars + self.secrets_from_k8s
+ class PodGC(BaseModelWIthConfig):
+ delete_delay_duration: str = Field(default="1h") # 1 hour
+ label_selector: Optional[LabelSelector] = Field(default=None)
+ strategy: Optional[PodGCStrategy] = Field(default=None)


- class DagTaskTemplate(BaseModel):
+ class Request(BaseModel):
  """
- dag:
- tasks:
- name: A
- template: nested-diamond
- arguments:
- parameters: [{name: message, value: A}]
+ The default requests
  """

- name: str
- template: str
- depends: List[str] = []
- arguments: Optional[List[Argument]] = Field(default=None)
- with_param: Optional[str] = Field(default=None, serialization_alias="withParam")
-
- @field_serializer("depends")
- def transform_depends_as_str(self, depends: List[str]) -> str:
- return " || ".join(depends)
+ memory: str = "1Gi"
+ cpu: str = "250m"

- @field_serializer("arguments", when_used="unless-none")
- def empty_arguments_to_none(
- self, arguments: List[Argument]
- ) -> Dict[str, List[Argument]]:
- return {"parameters": arguments}

+ VendorGPU = Annotated[
+ Optional[int],
+ PlainSerializer(lambda x: str(x), return_type=str, when_used="unless-none"),
+ ]

- class ContainerTemplate(BaseModel):
- # These templates are used for actual execution nodes.
- name: str
- active_deadline_seconds: Optional[int] = Field(
- default=None, serialization_alias="activeDeadlineSeconds", gt=0
- )
- node_selector: Optional[Dict[str, str]] = Field(
- default=None, serialization_alias="nodeSelector"
- )
- retry_strategy: Optional[Retry] = Field(
- default=None, serialization_alias="retryStrategy"
- )
- tolerations: Optional[List[Toleration]] = Field(
- default=None, serialization_alias="tolerations"
- )

- container: Container
+ class Limit(Request):
+ """
+ The default limits
+ """

- outputs: Optional[List[OutputParameter]] = Field(
- default=None, serialization_alias="outputs"
- )
- inputs: Optional[List[Parameter]] = Field(
- default=None, serialization_alias="inputs"
- )
+ gpu: VendorGPU = Field(default=None, serialization_alias="nvidia.com/gpu")

- def __hash__(self):
- return hash(self.name)

- @field_serializer("outputs", when_used="unless-none")
- def reshape_outputs(
- self, outputs: List[OutputParameter]
- ) -> Dict[str, List[OutputParameter]]:
- return {"parameters": outputs}
+ class Resources(BaseModel):
+ limits: Limit = Field(default=Limit(), serialization_alias="limits")
+ requests: Request = Field(default=Request(), serialization_alias="requests")

- @field_serializer("inputs", when_used="unless-none")
- def reshape_inputs(self, inputs: List[Parameter]) -> Dict[str, List[Parameter]]:
- return {"parameters": inputs}

+ # This is what the user can override per template
+ # Some are specific to container and some are specific to dag
+ class TemplateDefaults(BaseModelWIthConfig):
+ active_deadline_seconds: Optional[int] = Field(default=86400) # 1 day
+ fail_fast: bool = Field(default=True)
+ node_selector: dict[str, str] = Field(default_factory=dict)
+ parallelism: Optional[int] = Field(default=None)
+ retry_strategy: Optional[RetryStrategy] = Field(default=None)
+ timeout: Optional[str] = Field(default=None)
+ tolerations: Optional[list[Toleration]] = Field(default=None)

- class DagTemplate(BaseModel):
- # These are used for parallel, map nodes dag definition
- name: str = "runnable-dag"
- tasks: List[DagTaskTemplate] = Field(default=[], exclude=True)
- inputs: Optional[List[Parameter]] = Field(
- default=None, serialization_alias="inputs"
+ # These are in addition to what argo spec provides
+ image: str
+ image_pull_policy: Optional[ImagePullPolicy] = Field(default=ImagePullPolicy.Always)
+ resources: Resources = Field(default_factory=Resources)
+
+
+ # User provides this as part of the argoSpec
+ # some an be provided here or as a template default or node override
+ class ArgoWorkflowSpec(BaseModelWIthConfig):
+ active_deadline_seconds: int = Field(default=86400) # 1 day for the whole workflow
+ arguments: Optional[Arguments] = Field(default=None)
+ entrypoint: Literal["runnable-dag"] = Field(default="runnable-dag", frozen=True)
+ node_selector: dict[str, str] = Field(default_factory=dict)
+ parallelism: Optional[int] = Field(default=None) # GLobal parallelism
+ pod_gc: Optional[PodGC] = Field(default=None, serialization_alias="podGC")
+ retry_strategy: Optional[RetryStrategy] = Field(default=None)
+ service_account_name: Optional[str] = Field(default=None)
+ template_defaults: TemplateDefaults
+ tolerations: Optional[list[Toleration]] = Field(default=None)
+
+
+ class ArgoMetadata(BaseModelWIthConfig):
+ annotations: Optional[dict[str, str]] = Field(default=None)
+ generate_name: str # User can mention this to uniquely identify the run
+ labels: dict[str, str] = Field(default_factory=dict)
+ namespace: Optional[str] = Field(default="default")
+
+
+ class ArgoWorkflow(BaseModelWIthConfig):
+ apiVersion: Literal["argoproj.io/v1alpha1"] = Field(
+ default="argoproj.io/v1alpha1", frozen=True
  )
- parallelism: Optional[int] = None
- fail_fast: bool = Field(default=False, serialization_alias="failFast")
+ kind: Literal["Workflow"] = Field(default="Workflow", frozen=True)
+ metadata: ArgoMetadata
+ spec: ArgoWorkflowSpec

- @field_validator("parallelism")
- @classmethod
- def validate_parallelism(cls, parallelism: Optional[int]) -> Optional[int]:
- if parallelism is not None and parallelism <= 0:
- raise ValueError("Parallelism must be a positive integer greater than 0")
- return parallelism

- @computed_field # type: ignore
- @property
- def dag(self) -> Dict[str, List[DagTaskTemplate]]:
- return {"tasks": self.tasks}
+ # The below are not visible to the user
+ class DagTask(BaseModelWIthConfig):
+ name: str
+ template: str # Should be name of a container template or dag template
+ arguments: Optional[Arguments] = Field(default=None)
+ with_param: Optional[str] = Field(default=None)
+ depends: Optional[str] = Field(default=None)

- @field_serializer("inputs", when_used="unless-none")
- def reshape_inputs(
- self, inputs: List[Parameter], _info
- ) -> Dict[str, List[Parameter]]:
- return {"parameters": inputs}

+ class CoreDagTemplate(BaseModelWIthConfig):
+ tasks: list[DagTask] = Field(default_factory=list[DagTask])

- class Volume(BaseModel):
- """
- spec config requires, name and persistentVolumeClaim
- step requires name and mountPath
- """
-
- name: str
- claim: str = Field(exclude=True)
- mount_path: str = Field(serialization_alias="mountPath", exclude=True)

- @computed_field # type: ignore
- @property
- def persistentVolumeClaim(self) -> Dict[str, str]:
- return {"claimName": self.claim}
+ class CoreContainerTemplate(BaseModelWIthConfig):
+ image: str
+ command: list[str]
+ image_pull_policy: ImagePullPolicy = Field(default=ImagePullPolicy.IfNotPresent)
+ env: list[EnvVar | SecretEnvVar] = Field(default_factory=list)
+ volume_mounts: list[VolumeMount] = Field(default_factory=list)
+ resources: Resources = Field(default_factory=Resources)


- class ContainerVolume(BaseModel):
+ class DagTemplate(BaseModelWIthConfig):
  name: str
- mount_path: str = Field(serialization_alias="mountPath")
+ dag: CoreDagTemplate = Field(default_factory=CoreDagTemplate)
+ inputs: Optional[Inputs] = Field(default=None)
+ parallelism: Optional[int] = Field(default=None) # Not sure if this is needed
+ fail_fast: bool = Field(default=True)

+ model_config = ConfigDict(
+ extra="ignore",
+ )

- class UserVolumeMounts(BaseModel):
- """
- The volume specification as user defines it.
- """
+ def __hash__(self):
+ return hash(self.name)

- name: str # This is the name of the PVC on K8s
- mount_path: str # This is mount path on the container

+ class ContainerTemplate((BaseModelWIthConfig)):
+ name: str
+ container: CoreContainerTemplate
+ inputs: Optional[Inputs] = Field(default=None)
+ outputs: Optional[Outputs] = Field(default=None)
+
+ # The remaining can be from template defaults or node overrides
+ active_deadline_seconds: Optional[int] = Field(default=86400) # 1 day
+ metadata: Optional[PodMetaData] = Field(default=None)
+ node_selector: dict[str, str] = Field(default_factory=dict)
+ parallelism: Optional[int] = Field(default=None) # Not sure if this is needed
+ retry_strategy: Optional[RetryStrategy] = Field(default=None)
+ timeout: Optional[str] = Field(default=None)
+ tolerations: Optional[list[Toleration]] = Field(default=None)
+ volumes: Optional[list[Volume]] = Field(default=None)
+
+ model_config = ConfigDict(
+ extra="ignore",
+ )

- class NodeRenderer(ABC):
- allowed_node_types: List[str] = []
+ def __hash__(self):
+ return hash(self.name)

- def __init__(self, executor: "ArgoExecutor", node: BaseNode) -> None:
- self.executor = executor
- self.node = node

- @abstractmethod
- def render(self, list_of_iter_values: Optional[List] = None):
- pass
+ class CustomVolume(BaseModelWIthConfig):
+ mount_path: str
+ persistent_volume_claim: PersistentVolumeClaimSource


- class ExecutionNode(NodeRenderer):
- allowed_node_types = ["task", "stub", "success", "fail"]
+ class ArgoExecutor(GenericPipelineExecutor):
+ service_name: str = "argo"
+ _is_local: bool = False
+ mock: bool = False
+
+ model_config = ConfigDict(
+ extra="forbid",
+ alias_generator=to_camel,
+ populate_by_name=True,
+ from_attributes=True,
+ use_enum_values=True,
+ )

- def render(self, list_of_iter_values: Optional[List] = None):
- """
- Compose the map variable and create the execution command.
- Create an input to the command.
- create_container_template : creates an argument for the list of iter values
- """
- map_variable = self.executor.compose_map_variable(list_of_iter_values)
- command = utils.get_node_execution_command(
- self.node,
- over_write_run_id=self.executor._run_id_placeholder,
- map_variable=map_variable,
- log_level=self.executor._log_level,
- )
+ argo_workflow: ArgoWorkflow

- inputs = []
- if list_of_iter_values:
- for val in list_of_iter_values:
- inputs.append(Parameter(name=val))
+ # Lets use a generic one
+ pvc_for_runnable: Optional[str] = Field(default=None)
+ # pvc_for_catalog: Optional[str] = Field(default=None)
+ # pvc_for_run_log: Optional[str] = Field(default=None)
+ custom_volumes: Optional[list[CustomVolume]] = Field(
+ default_factory=list[CustomVolume]
+ )

- # Create the container template
- container_template = self.executor.create_container_template(
- working_on=self.node,
- command=command,
- inputs=inputs,
- )
+ expose_parameters_as_inputs: bool = True
+ secret_from_k8s: Optional[str] = Field(default=None)
+ output_file: str = Field(default="argo-pipeline.yaml")
+ log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field(
+ default="INFO"
+ )

- self.executor._container_templates.append(container_template)
+ # This should be used when we refer to run_id or log_level in the containers
+ _run_id_as_parameter: str = PrivateAttr(default="{{workflow.parameters.run_id}}")
+ _log_level_as_parameter: str = PrivateAttr(
+ default="{{workflow.parameters.log_level}}"
+ )

+ _templates: list[ContainerTemplate | DagTemplate] = PrivateAttr(
+ default_factory=list
+ )
+ _container_log_location: str = PrivateAttr(default="/tmp/run_logs/")
+ _container_catalog_location: str = PrivateAttr(default="/tmp/catalog/")
+ _added_initial_container: bool = PrivateAttr(default=False)
+
+ def sanitize_name(self, name: str) -> str:
+ formatted_name = name.replace(" ", "-").replace(".", "-").replace("_", "-")
+ tag = "".join(random.choices(string.ascii_lowercase + string.digits, k=6))
+ unique_name = f"{formatted_name}-{tag}"
+ unique_name = unique_name.replace("map-variable-placeholder-", "")
+ return unique_name
+
+ def _set_up_initial_container(self, container_template: CoreContainerTemplate):
+ if self._added_initial_container:
+ return
+
+ parameters: list[Parameter] = []
+
+ if self.argo_workflow.spec.arguments:
+ parameters = self.argo_workflow.spec.arguments.parameters or []
+
+ for parameter in parameters or []:
+ key, _ = parameter.name, parameter.value
+ env_var = EnvVar(
+ name=defaults.PARAMETER_PREFIX + key,
+ value="{{workflow.parameters." + key + "}}",
+ )
+ container_template.env.append(env_var)

- class DagNodeRenderer(NodeRenderer):
- allowed_node_types = ["dag"]
+ env_var = EnvVar(name="error_on_existing_run_id", value="true")
+ container_template.env.append(env_var)

- def render(self, list_of_iter_values: Optional[List] = None):
- self.node = cast(DagNode, self.node)
- task_template_arguments = []
- dag_inputs = []
- if list_of_iter_values:
- for value in list_of_iter_values:
- task_template_arguments.append(
- Argument(name=value, value="{{inputs.parameters." + value + "}}")
- )
- dag_inputs.append(Parameter(name=value))
+ # After the first container is added, set the added_initial_container to True
+ self._added_initial_container = True

- clean_name = self.executor.get_clean_name(self.node)
- fan_out_template = self.executor._create_fan_out_template(
- composite_node=self.node, list_of_iter_values=list_of_iter_values
- )
- fan_out_template.arguments = (
- task_template_arguments if task_template_arguments else None
- )
+ def _create_fan_templates(
+ self,
+ node: BaseNode,
+ mode: str,
+ parameters: Optional[list[Parameter]],
+ task_name: str,
+ ):
+ template_defaults = self.argo_workflow.spec.template_defaults.model_dump()

- fan_in_template = self.executor._create_fan_in_template(
- composite_node=self.node, list_of_iter_values=list_of_iter_values
- )
- fan_in_template.arguments = (
- task_template_arguments if task_template_arguments else None
- )
+ map_variable: TypeMapVariable = {}
+ for parameter in parameters or []:
+ map_variable[parameter.name] = ( # type: ignore
+ "{{inputs.parameters." + str(parameter.name) + "}}"
+ )

- self.executor._gather_task_templates_of_dag(
- self.node.branch,
- dag_name=f"{clean_name}-branch",
- list_of_iter_values=list_of_iter_values,
+ fan_command = utils.get_fan_command(
+ mode=mode,
+ node=node,
+ run_id=self._run_id_as_parameter,
+ map_variable=map_variable,
  )

- branch_template = DagTaskTemplate(
- name=f"{clean_name}-branch",
- template=f"{clean_name}-branch",
- arguments=task_template_arguments if task_template_arguments else None,
+ core_container_template = CoreContainerTemplate(
+ command=shlex.split(fan_command),
+ image=template_defaults["image"],
+ image_pull_policy=template_defaults["image_pull_policy"],
+ volume_mounts=[
+ volume_pair.volume_mount for volume_pair in self.volume_pairs
+ ],
  )
- branch_template.depends.append(f"{clean_name}-fan-out.Succeeded")
- fan_in_template.depends.append(f"{clean_name}-branch.Succeeded")
- fan_in_template.depends.append(f"{clean_name}-branch.Failed")
-
- self.executor._dag_templates.append(
- DagTemplate(
- tasks=[fan_out_template, branch_template, fan_in_template],
- name=clean_name,
- inputs=dag_inputs if dag_inputs else None,
- )
- )
-

- class ParallelNodeRender(NodeRenderer):
- allowed_node_types = ["parallel"]
+ # Either a task or a fan-out can the first container
+ self._set_up_initial_container(container_template=core_container_template)

- def render(self, list_of_iter_values: Optional[List] = None):
- self.node = cast(ParallelNode, self.node)
- task_template_arguments = []
- dag_inputs = []
- if list_of_iter_values:
- for value in list_of_iter_values:
- task_template_arguments.append(
- Argument(name=value, value="{{inputs.parameters." + value + "}}")
- )
- dag_inputs.append(Parameter(name=value))
+ task_name += f"-fan-{mode}"

- clean_name = self.executor.get_clean_name(self.node)
- fan_out_template = self.executor._create_fan_out_template(
- composite_node=self.node, list_of_iter_values=list_of_iter_values
- )
- fan_out_template.arguments = (
- task_template_arguments if task_template_arguments else None
- )
+ outputs: Optional[Outputs] = None
+ if mode == "out" and node.node_type == "map":
+ outputs = Outputs(parameters=[OutputParameter(name="iterate-on")])

- fan_in_template = self.executor._create_fan_in_template(
- composite_node=self.node, list_of_iter_values=list_of_iter_values
- )
- fan_in_template.arguments = (
- task_template_arguments if task_template_arguments else None
+ container_template = ContainerTemplate(
+ container=core_container_template,
+ name=task_name,
+ volumes=[volume_pair.volume for volume_pair in self.volume_pairs],
+ inputs=Inputs(parameters=parameters),
+ outputs=outputs,
+ **template_defaults,
  )

- branch_templates = []
- for name, branch in self.node.branches.items():
- branch_name = self.executor.sanitize_name(name)
- self.executor._gather_task_templates_of_dag(
- branch,
- dag_name=f"{clean_name}-{branch_name}",
- list_of_iter_values=list_of_iter_values,
- )
- task_template = DagTaskTemplate(
- name=f"{clean_name}-{branch_name}",
- template=f"{clean_name}-{branch_name}",
- arguments=task_template_arguments if task_template_arguments else None,
- )
- task_template.depends.append(f"{clean_name}-fan-out.Succeeded")
- fan_in_template.depends.append(f"{task_template.name}.Succeeded")
- fan_in_template.depends.append(f"{task_template.name}.Failed")
- branch_templates.append(task_template)
-
- executor_config = self.executor._resolve_executor_config(self.node)
-
- self.executor._dag_templates.append(
- DagTemplate(
- tasks=[fan_out_template] + branch_templates + [fan_in_template],
- name=clean_name,
- inputs=dag_inputs if dag_inputs else None,
- parallelism=executor_config.get("parallelism", None),
- )
- )
+ self._templates.append(container_template)

+ def _create_container_template(
+ self,
+ node: BaseNode,
+ task_name: str,
+ inputs: Optional[Inputs] = None,
+ ) -> ContainerTemplate:
+ template_defaults = self.argo_workflow.spec.template_defaults.model_dump()

- class MapNodeRender(NodeRenderer):
- allowed_node_types = ["map"]
+ node_overide = {}
+ if hasattr(node, "overides"):
+ node_overide = node.overides

- def render(self, list_of_iter_values: Optional[List] = None):
- self.node = cast(MapNode, self.node)
- task_template_arguments = []
- dag_inputs = []
+ # update template defaults with node overrides
+ template_defaults.update(node_overide)

- if not list_of_iter_values:
- list_of_iter_values = []
+ inputs = inputs or Inputs(parameters=[])

- for value in list_of_iter_values:
- task_template_arguments.append(
- Argument(name=value, value="{{inputs.parameters." + value + "}}")
+ map_variable: TypeMapVariable = {}
+ for parameter in inputs.parameters or []:
+ map_variable[parameter.name] = ( # type: ignore
+ "{{inputs.parameters." + str(parameter.name) + "}}"
  )
- dag_inputs.append(Parameter(name=value))

- clean_name = self.executor.get_clean_name(self.node)
-
- fan_out_template = self.executor._create_fan_out_template(
- composite_node=self.node, list_of_iter_values=list_of_iter_values
- )
- fan_out_template.arguments = (
- task_template_arguments if task_template_arguments else None
+ # command = "runnable execute-single-node"
+ command = utils.get_node_execution_command(
+ node=node,
+ over_write_run_id=self._run_id_as_parameter,
+ map_variable=map_variable,
+ log_level=self._log_level_as_parameter,
  )

- fan_in_template = self.executor._create_fan_in_template(
- composite_node=self.node, list_of_iter_values=list_of_iter_values
- )
- fan_in_template.arguments = (
- task_template_arguments if task_template_arguments else None
+ core_container_template = CoreContainerTemplate(
+ command=shlex.split(command),
+ image=template_defaults["image"],
+ image_pull_policy=template_defaults["image_pull_policy"],
+ volume_mounts=[
+ volume_pair.volume_mount for volume_pair in self.volume_pairs
+ ],
  )

- list_of_iter_values.append(self.node.iterate_as)
+ self._set_up_initial_container(container_template=core_container_template)

- self.executor._gather_task_templates_of_dag(
- self.node.branch,
- dag_name=f"{clean_name}-map",
- list_of_iter_values=list_of_iter_values,
- )
-
- task_template = DagTaskTemplate(
- name=f"{clean_name}-map",
- template=f"{clean_name}-map",
- arguments=task_template_arguments if task_template_arguments else None,
- )
- task_template.with_param = (
- "{{tasks."
- + f"{clean_name}-fan-out"
- + ".outputs.parameters."
- + "iterate-on"
- + "}}"
+ container_template = ContainerTemplate(
+ container=core_container_template,
+ name=task_name,
+ inputs=Inputs(
+ parameters=[
+ Parameter(name=param.name) for param in inputs.parameters or []
+ ]
+ ),
+ volumes=[volume_pair.volume for volume_pair in self.volume_pairs],
+ **template_defaults,
  )

- argument = Argument(name=self.node.iterate_as, value="{{item}}")
- if task_template.arguments is None:
- task_template.arguments = []
- task_template.arguments.append(argument)
-
- task_template.depends.append(f"{clean_name}-fan-out.Succeeded")
- fan_in_template.depends.append(f"{clean_name}-map.Succeeded")
- fan_in_template.depends.append(f"{clean_name}-map.Failed")
-
- executor_config = self.executor._resolve_executor_config(self.node)
+ return container_template

- self.executor._dag_templates.append(
- DagTemplate(
- tasks=[fan_out_template, task_template, fan_in_template],
- name=clean_name,
- inputs=dag_inputs if dag_inputs else None,
- parallelism=executor_config.get("parallelism", None),
- fail_fast=executor_config.get("fail_fast", True),
+ def _expose_secrets_to_task(
+ self,
+ working_on: BaseNode,
+ container_template: CoreContainerTemplate,
+ ):
+ assert isinstance(working_on, TaskNode)
+ secrets = working_on.executable.secrets
+ for secret in secrets:
+ assert self.secret_from_k8s is not None
+ secret_env_var = SecretEnvVar(
+ environment_variable=secret,
+ secret_name=self.secret_from_k8s, # This has to be exposed from config
+ secret_key=secret,
  )
- )
-
-
- def get_renderer(node):
- renderers = NodeRenderer.__subclasses__()
+ container_template.env.append(secret_env_var)

- for renderer in renderers:
- if node.node_type in renderer.allowed_node_types:
- return renderer
- raise Exception("This node type is not render-able")
-
-
- class MetaData(BaseModel):
- generate_name: str = Field(
- default="runnable-dag-", serialization_alias="generateName"
- )
- # The type ignore is related to: https://github.com/python/mypy/issues/18191
- annotations: Optional[Dict[str, str]] = Field(default_factory=dict) # type: ignore
- labels: Optional[Dict[str, str]] = Field(default_factory=dict) # type: ignore
- namespace: Optional[str] = Field(default=None)
-
-
- class Spec(BaseModel):
- active_deadline_seconds: int = Field(serialization_alias="activeDeadlineSeconds")
- entrypoint: str = Field(default="runnable-dag")
- node_selector: Optional[Dict[str, str]] = Field(
- default_factory=dict, # type: ignore
- serialization_alias="nodeSelector",
- )
- tolerations: Optional[List[Toleration]] = Field(
- default=None, serialization_alias="tolerations"
- )
- parallelism: Optional[int] = Field(default=None, serialization_alias="parallelism")
+ def _handle_failures(
+ self,
+ working_on: BaseNode,
+ dag: Graph,
+ task_name: str,
+ parent_dag_template: DagTemplate,
+ ):
+ if working_on._get_on_failure_node():
+ # Create a new dag template
+ on_failure_dag: DagTemplate = DagTemplate(name=f"on-failure-{task_name}")
+ # Add on failure of the current task to be the failure dag template
+ on_failure_task = DagTask(
+ name=f"on-failure-{task_name}",
+ template=f"on-failure-{task_name}",
+ depends=task_name + ".Failed",
+ )
+ # Set failfast of the dag template to be false
+ # If not, this branch will never be invoked
+ parent_dag_template.fail_fast = False

- # TODO: This has to be user driven
- pod_gc: Dict[str, str] = Field( # type ignore
- default={"strategy": "OnPodSuccess", "deleteDelayDuration": "600s"},
- serialization_alias="podGC",
- )
+ assert parent_dag_template.dag

- retry_strategy: Retry = Field(default=Retry(), serialization_alias="retryStrategy")
- service_account_name: Optional[str] = Field(
- default=None, serialization_alias="serviceAccountName"
- )
+ parent_dag_template.dag.tasks.append(on_failure_task)
+ self._gather_tasks_for_dag_template(
+ on_failure_dag,
+ dag=dag,
+ start_at=working_on._get_on_failure_node(),
+ )

- templates: List[Union[DagTemplate, ContainerTemplate]] = Field(default_factory=list)
- template_defaults: Optional[TemplateDefaults] = Field(
- default=None, serialization_alias="templateDefaults"
- )
+ # For the future me:
+ # - A task can output a array: in this case, its the fan out.
+ # - We are using withParam and arguments of the map template to send that value in
+ # - The map template should receive that value as a parameter into the template.
+ # - The task then start to use it as inputs.parameters.iterate-on

- arguments: Optional[List[EnvVar]] = Field(default_factory=list) # type: ignore
- persistent_volumes: List[UserVolumeMounts] = Field(
- default_factory=list, exclude=True
- )
+ def _gather_tasks_for_dag_template(
+ self,
+ dag_template: DagTemplate,
+ dag: Graph,
+ start_at: str,
+ parameters: Optional[list[Parameter]] = None,
+ ):
+ current_node: str = start_at
+ depends: str = ""

- @field_validator("parallelism")
- @classmethod
- def validate_parallelism(cls, parallelism: Optional[int]) -> Optional[int]:
- if parallelism is not None and parallelism <= 0:
- raise ValueError("Parallelism must be a positive integer greater than 0")
- return parallelism
+ dag_template.dag = CoreDagTemplate()

- @computed_field # type: ignore
- @property
- def volumes(self) -> List[Volume]:
- volumes: List[Volume] = []
- claim_names = {}
- for i, user_volume in enumerate(self.persistent_volumes):
- if user_volume.name in claim_names:
- raise Exception(f"Duplicate claim name {user_volume.name}")
- claim_names[user_volume.name] = user_volume.name
-
- volume = Volume(
- name=f"executor-{i}",
- claim=user_volume.name,
- mount_path=user_volume.mount_path,
+ while True:
+ # Create the dag task with for the parent dag
+ working_on: BaseNode = dag.get_node_by_name(current_node)
+ task_name = self.sanitize_name(working_on.internal_name)
+ current_task = DagTask(
+ name=task_name,
+ template=task_name,
+ depends=depends if not depends else depends + ".Succeeded",
+ arguments=Arguments(
+ parameters=[
+ Parameter(
+ name=param.name,
+ value=f"{{{{inputs.parameters.{param.name}}}}}",
+ )
+ for param in parameters or []
+ ]
+ ),
  )
- volumes.append(volume)
- return volumes
+ dag_template.dag.tasks.append(current_task)
+ depends = task_name
+
+ match working_on.node_type:
+ case "task" | "success" | "stub":
+ template_of_container = self._create_container_template(
+ working_on,
+ task_name=task_name,
+ inputs=Inputs(parameters=parameters),
+ )
+ assert template_of_container.container is not None

- @field_serializer("arguments", when_used="unless-none")
- def reshape_arguments(
- self, arguments: List[EnvVar], _info
- ) -> Dict[str, List[EnvVar]]:
- return {"parameters": arguments}
+ if working_on.node_type == "task":
+ self._expose_secrets_to_task(
+ working_on,
+ container_template=template_of_container.container,
+ )

+ self._templates.append(template_of_container)

- class Workflow(BaseModel):
- api_version: str = Field(
- default="argoproj.io/v1alpha1",
- serialization_alias="apiVersion",
- )
- kind: str = "Workflow"
- metadata: MetaData = Field(default=MetaData())
- spec: Spec
+ case "map" | "parallel":
+ assert isinstance(working_on, MapNode) or isinstance(
+ working_on, ParallelNode
+ )
+ node_type = working_on.node_type

+ composite_template: DagTemplate = DagTemplate(
+ name=task_name, fail_fast=False
+ )

- class Override(BaseModel):
- model_config = ConfigDict(extra="ignore")
+ # Add the fan out task
+ fan_out_task = DagTask(
+ name=f"{task_name}-fan-out",
+ template=f"{task_name}-fan-out",
+ arguments=Arguments(parameters=parameters),
+ )
+ composite_template.dag.tasks.append(fan_out_task)
+ self._create_fan_templates(
+ node=working_on,
+ mode="out",
+ parameters=parameters,
+ task_name=task_name,
+ )

- image: str
- tolerations: Optional[List[Toleration]] = Field(default=None)
+ # Add the composite task
+ with_param = None
+ added_parameters = parameters or []
+ branches = {}
+ if node_type == "map":
+ # If the node is map, we need to handle the iterate as and on
+ assert isinstance(working_on, MapNode)
+ added_parameters = added_parameters + [
+ Parameter(name=working_on.iterate_as, value="{{item}}")
+ ]
+ with_param = f"{{{{tasks.{task_name}-fan-out.outputs.parameters.iterate-on}}}}"
+
+ branches["branch"] = working_on.branch
+ elif node_type == "parallel":
+ assert isinstance(working_on, ParallelNode)
+ branches = working_on.branches
+ else:
+ raise ValueError("Invalid node type")
+
+ fan_in_depends = ""
+
+ for name, branch in branches.items():
+ name = (
+ name.replace(" ", "-").replace(".", "-").replace("_", "-")
+ )

- max_step_duration_in_seconds: int = Field(
- default=2 * 60 * 60, # 2 hours
- gt=0,
- )
+ branch_task = DagTask(
+ name=f"{task_name}-{name}",
+ template=f"{task_name}-{name}",
+ depends=f"{task_name}-fan-out.Succeeded",
+ arguments=Arguments(parameters=added_parameters),
+ with_param=with_param,
+ )
+ composite_template.dag.tasks.append(branch_task)
+
+ branch_template = DagTemplate(
+ name=branch_task.name,
+ inputs=Inputs(
+ parameters=[
+ Parameter(name=param.name, value=None)
+ for param in added_parameters
+ ]
+ ),
+ )

- node_selector: Optional[Dict[str, str]] = Field(
- default=None,
- serialization_alias="nodeSelector",
- )
+ self._gather_tasks_for_dag_template(
+ dag_template=branch_template,
+ dag=branch,
+ start_at=branch.start_at,
+ parameters=added_parameters,
+ )

- parallelism: Optional[int] = Field(
- default=None,
- serialization_alias="parallelism",
- )
+ fan_in_depends += f"{branch_task.name}.Succeeded || {branch_task.name}.Failed || "

- resources: Resources = Field(
- default=Resources(),
- serialization_alias="resources",
- )
+ fan_in_task = DagTask(
+ name=f"{task_name}-fan-in",
+ template=f"{task_name}-fan-in",
+ depends=fan_in_depends.strip(" || "),
+ arguments=Arguments(parameters=parameters),
+ )

- image_pull_policy: str = Field(default="")
+ composite_template.dag.tasks.append(fan_in_task)
+ self._create_fan_templates(
+ node=working_on,
+ mode="in",
+ parameters=parameters,
+ task_name=task_name,
+ )

- retry_strategy: Retry = Field(
- default=Retry(),
- serialization_alias="retryStrategy",
- description="Common across all templates",
- )
+ self._templates.append(composite_template)

- @field_validator("parallelism")
- @classmethod
- def validate_parallelism(cls, parallelism: Optional[int]) -> Optional[int]:
- if parallelism is not None and parallelism <= 0:
- raise ValueError("Parallelism must be a positive integer greater than 0")
- return parallelism
+ self._handle_failures(
+ working_on,
+ dag,
+ task_name,
+ parent_dag_template=dag_template,
+ )

+ if working_on.node_type == "success" or working_on.node_type == "fail":
+ break

- class ArgoExecutor(GenericPipelineExecutor):
- service_name: str = "argo"
- _is_local: bool = False
+ current_node = working_on._get_next_node()

- # TODO: Add logging level as option.
+ self._templates.append(dag_template)

- model_config = ConfigDict(extra="forbid")
+ def execute_graph(
+ self,
+ dag: Graph,
+ map_variable: dict[str, str | int | float] | None = None,
+ **kwargs,
+ ):
+ # All the arguments set at the spec level can be referred as "{{workflow.parameters.*}}"
+ # We want to use that functionality to override the parameters at the task level
+ # We should be careful to override them only at the first task.
+ arguments = [] # Can be updated in the UI
+ if self.expose_parameters_as_inputs:
+ for key, value in self._get_parameters().items():
+ value = value.get_value() # type: ignore
+ if isinstance(value, dict) or isinstance(value, list):
+ continue

- image: str
- expose_parameters_as_inputs: bool = True
- secrets_from_k8s: List[SecretEnvVar] = Field(default_factory=list)
- output_file: str = "argo-pipeline.yaml"
+ parameter = Parameter(name=key, value=value) # type: ignore
+ arguments.append(parameter)

- # Metadata related fields
- name: str = Field(
- default="runnable-dag-", description="Used as an identifier for the workflow"
- )
- annotations: Dict[str, str] = Field(default_factory=dict)
- labels: Dict[str, str] = Field(default_factory=dict)
+ run_id_var = Parameter(name="run_id", value="{{workflow.uid}}")
+ log_level_var = Parameter(name="log_level", value=self.log_level)
+ arguments.append(run_id_var)
+ arguments.append(log_level_var)
+ self.argo_workflow.spec.arguments = Arguments(parameters=arguments)
+
+ # This is the entry point of the argo execution
+ runnable_dag: DagTemplate = DagTemplate(name="runnable-dag")
+
+ self._gather_tasks_for_dag_template(
+ runnable_dag,
+ dag,
+ start_at=dag.start_at,
+ parameters=[],
+ )
+
+ argo_workflow_dump = self.argo_workflow.model_dump(
+ by_alias=True,
+ exclude={
+ "spec": {
+ "template_defaults": {"image_pull_policy", "image", "resources"}
+ }
+ },
+ exclude_none=True,
+ round_trip=False,
+ )
+ argo_workflow_dump["spec"]["templates"] = [
+ template.model_dump(
+ by_alias=True,
+ exclude_none=True,
+ )
+ for template in self._templates
+ ]

- max_workflow_duration_in_seconds: int = Field(
- 2 * 24 * 60 * 60, # 2 days
- serialization_alias="activeDeadlineSeconds",
- gt=0,
- )
- node_selector: Optional[Dict[str, str]] = Field(
- default=None,
- serialization_alias="nodeSelector",
- )
- parallelism: Optional[int] = Field(
- default=None,
- serialization_alias="parallelism",
- )
- resources: Resources = Field(
- default=Resources(),
- serialization_alias="resources",
- exclude=True,
- )
- retry_strategy: Retry = Field(
- default=Retry(),
- serialization_alias="retryStrategy",
- description="Common across all templates",
- )
- max_step_duration_in_seconds: int = Field(
- default=2 * 60 * 60, # 2 hours
- gt=0,
- )
- tolerations: Optional[List[Toleration]] = Field(default=None)
- image_pull_policy: str = Field(default="")
- service_account_name: Optional[str] = None
- persistent_volumes: List[UserVolumeMounts] = Field(default_factory=list)
-
- _run_id_placeholder: str = "{{workflow.parameters.run_id}}"
- _log_level: str = "{{workflow.parameters.log_level}}"
- _container_templates: List[ContainerTemplate] = []
- _dag_templates: List[DagTemplate] = []
- _clean_names: Dict[str, str] = {}
- _container_volumes: List[ContainerVolume] = []
-
- @field_validator("parallelism")
- @classmethod
- def validate_parallelism(cls, parallelism: Optional[int]) -> Optional[int]:
- if parallelism is not None and parallelism <= 0:
- raise ValueError("Parallelism must be a positive integer greater than 0")
- return parallelism
+ argo_workflow_dump["spec"]["volumes"] = [
+ volume_pair.volume.model_dump(by_alias=True)
+ for volume_pair in self.volume_pairs
+ ]

- @computed_field # type: ignore
- @property
- def step_timeout(self) -> int:
- """
- Maximum time the step can take to complete, including the pending state.
- """
- return (
- self.max_step_duration_in_seconds + 2 * 60 * 60
- ) # 2 hours + max_step_duration_in_seconds
+ yaml = YAML()
+ with open(self.output_file, "w") as f:
+ yaml.indent(mapping=2, sequence=4, offset=2)
+ yaml.dump(
+ argo_workflow_dump,
+ f,
+ )

- @property
- def metadata(self) -> MetaData:
- return MetaData(
- generate_name=self.name,
- annotations=self.annotations,
- labels=self.labels,
+ def _implicitly_fail(self, node: BaseNode, map_variable: TypeMapVariable):
+ assert self._context.dag
+ _, current_branch = search_node_by_internal_name(
+ dag=self._context.dag, internal_name=node.internal_name
  )
-
- @property
- def spec(self) -> Spec:
- return Spec(
- active_deadline_seconds=self.max_workflow_duration_in_seconds,
- node_selector=self.node_selector,
- tolerations=self.tolerations,
- parallelism=self.parallelism,
- retry_strategy=self.retry_strategy,
- service_account_name=self.service_account_name,
- persistent_volumes=self.persistent_volumes,
- template_defaults=TemplateDefaults(
- max_step_duration=self.max_step_duration_in_seconds
- ),
+ _, next_node_name = self._get_status_and_next_node_name(
+ node, current_branch, map_variable=map_variable
  )
+ if next_node_name:
+ # Terminal nodes do not have next node name
+ next_node = current_branch.get_node_by_name(next_node_name)

- # TODO: This has to move to execute_node?
- def prepare_for_execution(self):
- """
- Perform any modifications to the services prior to execution of the node.
-
- Args:
- node (Node): [description]
- map_variable (dict, optional): [description]. Defaults to None.
- """
-
- self._set_up_run_log(exists_ok=True)
+ if next_node.node_type == defaults.FAIL:
+ self.execute_node(next_node, map_variable=map_variable)

  def execute_node(
- self, node: BaseNode, map_variable: TypeMapVariable = None, **kwargs
+ self,
+ node: BaseNode,
+ map_variable: dict[str, str | int | float] | None = None,
+ **kwargs,
  ):
+ error_on_existing_run_id = os.environ.get("error_on_existing_run_id", "false")
+ exists_ok = error_on_existing_run_id == "false"
+
+ self._use_volumes()
+ self._set_up_run_log(exists_ok=exists_ok)
+
  step_log = self._context.run_log_store.create_step_log(
  node.name, node._get_step_log_name(map_variable)
  )
@@ -870,36 +836,30 @@ class ArgoExecutor(GenericPipelineExecutor):
  step_log.status = defaults.PROCESSING
  self._context.run_log_store.add_step_log(step_log, self._context.run_id)

- super()._execute_node(node, map_variable=map_variable, **kwargs)
-
- # Implicit fail
- if self._context.dag:
- # functions and notebooks do not have dags
- _, current_branch = search_node_by_internal_name(
- dag=self._context.dag, internal_name=node.internal_name
- )
- _, next_node_name = self._get_status_and_next_node_name(
- node, current_branch, map_variable=map_variable
- )
- if next_node_name:
- # Terminal nodes do not have next node name
- next_node = current_branch.get_node_by_name(next_node_name)
-
- if next_node.node_type == defaults.FAIL:
- self.execute_node(next_node, map_variable=map_variable)
+ self._execute_node(node=node, map_variable=map_variable, **kwargs)

+ # Raise exception if the step failed
  step_log = self._context.run_log_store.get_step_log(
  node._get_step_log_name(map_variable), self._context.run_id
  )
  if step_log.status == defaults.FAIL:
  raise Exception(f"Step {node.name} failed")

+ self._implicitly_fail(node, map_variable)
+
  def fan_out(self, node: BaseNode, map_variable: TypeMapVariable = None):
+ # This could be the first step of the graph
+ self._use_volumes()
+
+ error_on_existing_run_id = os.environ.get("error_on_existing_run_id", "false")
+ exists_ok = error_on_existing_run_id == "false"
+ self._set_up_run_log(exists_ok=exists_ok)
+
  super().fan_out(node, map_variable)

  # If its a map node, write the list values to "/tmp/output.txt"
  if node.node_type == "map":
- node = cast(MapNode, node)
+ assert isinstance(node, MapNode)
  iterate_on = self._context.run_log_store.get_parameters(
  self._context.run_id
  )[node.iterate_on]
@@ -907,401 +867,55 @@ class ArgoExecutor(GenericPipelineExecutor):
907
867
  with open("/tmp/output.txt", mode="w", encoding="utf-8") as myfile:
908
868
  json.dump(iterate_on.get_value(), myfile, indent=4)
909
869
 
910
- def sanitize_name(self, name):
911
- return name.replace(" ", "-").replace(".", "-").replace("_", "-")
912
-
913
- def get_clean_name(self, node: BaseNode):
914
- # Cache names for the node
915
- if node.internal_name not in self._clean_names:
916
- sanitized = self.sanitize_name(node.name)
917
- tag = "".join(random.choices(string.ascii_lowercase + string.digits, k=6))
918
- self._clean_names[node.internal_name] = (
919
- f"{sanitized}-{node.node_type}-{tag}"
920
- )
921
-
922
- return self._clean_names[node.internal_name]
923
-
924
- def compose_map_variable(
925
- self, list_of_iter_values: Optional[List] = None
926
- ) -> TypeMapVariable:
927
- map_variable = OrderedDict()
928
-
929
- # If we are inside a map node, compose a map_variable
930
- # The values of "iterate_as" are sent over as inputs to the container template
931
- if list_of_iter_values:
932
- for var in list_of_iter_values:
933
- map_variable[var] = "{{inputs.parameters." + str(var) + "}}"
934
-
935
- return map_variable # type: ignore
936
-
937
- def create_container_template(
938
- self,
939
- working_on: BaseNode,
940
- command: str,
941
- inputs: Optional[List] = None,
942
- outputs: Optional[List] = None,
943
- overwrite_name: str = "",
944
- ):
945
- effective_node_config = self._resolve_executor_config(working_on)
946
-
947
- override: Override = Override(**effective_node_config)
948
-
949
- container = Container(
950
- command=command,
951
- image=override.image,
952
- volume_mounts=self._container_volumes,
953
- image_pull_policy=override.image_pull_policy,
954
- resources=override.resources,
955
- secrets_from_k8s=self.secrets_from_k8s,
956
- )
957
-
958
- if (
959
- working_on.name == self._context.dag.start_at
960
- and self.expose_parameters_as_inputs
961
- ):
962
- for key, value in self._get_parameters().items():
963
- value = value.get_value() # type: ignore
964
- # Get the value from work flow parameters for dynamic behavior
965
- if (
966
- isinstance(value, int)
967
- or isinstance(value, float)
968
- or isinstance(value, str)
969
- ):
970
- env_var = EnvVar(
971
- name=defaults.PARAMETER_PREFIX + key,
972
- value="{{workflow.parameters." + key + "}}",
973
- )
974
- container.env_vars.append(env_var)
975
-
976
- clean_name = self.get_clean_name(working_on)
977
- if overwrite_name:
978
- clean_name = overwrite_name
979
-
980
- container_template = ContainerTemplate(
981
- name=clean_name,
982
- active_deadline_seconds=(
983
- override.max_step_duration_in_seconds
984
- if self.max_step_duration_in_seconds
985
- != override.max_step_duration_in_seconds
986
- else None
987
- ),
988
- container=container,
989
- retry_strategy=override.retry_strategy
990
- if self.retry_strategy != override.retry_strategy
991
- else None,
992
- tolerations=override.tolerations
993
- if self.tolerations != override.tolerations
994
- else None,
995
- node_selector=override.node_selector
996
- if self.node_selector != override.node_selector
997
- else None,
998
- )
999
-
1000
- # inputs are the "iterate_as" value map variables in the same order as they are observed
1001
- # We need to expose the map variables in the command of the container
1002
- if inputs:
1003
- if not container_template.inputs:
1004
- container_template.inputs = []
1005
- container_template.inputs.extend(inputs)
870
+ def fan_in(self, node: BaseNode, map_variable: TypeMapVariable = None):
871
+ self._use_volumes()
872
+ super().fan_in(node, map_variable)
873
+
874
+ def _use_volumes(self):
875
+ match self._context.run_log_store.service_name:
876
+ case "file-system":
877
+ self._context.run_log_store.log_folder = self._container_log_location
878
+ case "chunked-fs":
879
+ self._context.run_log_store.log_folder = self._container_log_location
880
+
881
+ match self._context.catalog_handler.service_name:
882
+ case "file-system":
883
+ self._context.catalog_handler.catalog_location = (
884
+ self._container_catalog_location
885
+ )
1006
886
 
1007
- # The map step fan out would create an output that we should propagate via Argo
1008
- if outputs:
1009
- if not container_template.outputs:
1010
- container_template.outputs = []
1011
- container_template.outputs.extend(outputs)
887
+ @cached_property
888
+ def volume_pairs(self) -> list[VolumePair]:
889
+ volume_pairs: list[VolumePair] = []
1012
890
 
1013
- return container_template
1014
-
1015
- def _create_fan_out_template(
1016
- self, composite_node, list_of_iter_values: Optional[List] = None
1017
- ):
1018
- clean_name = self.get_clean_name(composite_node)
1019
- inputs = []
1020
- # If we are fanning out already map state, we need to send the map variable inside
1021
- # The container template also should be accepting an input parameter
1022
- map_variable = None
1023
- if list_of_iter_values:
1024
- map_variable = self.compose_map_variable(
1025
- list_of_iter_values=list_of_iter_values
891
+ if self.pvc_for_runnable:
892
+ common_volume = Volume(
893
+ name="runnable",
894
+ persistent_volume_claim=PersistentVolumeClaimSource(
895
+ claim_name=self.pvc_for_runnable
896
+ ),
1026
897
  )
1027
-
1028
- for val in list_of_iter_values:
1029
- inputs.append(Parameter(name=val))
1030
-
1031
- command = utils.get_fan_command(
1032
- mode="out",
1033
- node=composite_node,
1034
- run_id=self._run_id_placeholder,
1035
- map_variable=map_variable,
1036
- log_level=self._log_level,
1037
- )
1038
-
1039
- outputs = []
1040
- # If the node is a map node, we have to set the output parameters
1041
- # Output is always the step's internal name + iterate-on
1042
- if composite_node.node_type == "map":
1043
- output_parameter = OutputParameter(name="iterate-on")
1044
- outputs.append(output_parameter)
1045
-
1046
- # Create the node now
1047
- step_config = {"command": command, "type": "task", "next": "dummy"}
1048
- node = create_node(name=f"{clean_name}-fan-out", step_config=step_config)
1049
-
1050
- container_template = self.create_container_template(
1051
- working_on=node,
1052
- command=command,
1053
- outputs=outputs,
1054
- inputs=inputs,
1055
- overwrite_name=f"{clean_name}-fan-out",
1056
- )
1057
-
1058
- self._container_templates.append(container_template)
1059
- return DagTaskTemplate(
1060
- name=f"{clean_name}-fan-out", template=f"{clean_name}-fan-out"
1061
- )
1062
-
1063
- def _create_fan_in_template(
1064
- self, composite_node, list_of_iter_values: Optional[List] = None
1065
- ):
1066
- clean_name = self.get_clean_name(composite_node)
1067
- inputs = []
1068
- # If we are fanning in already map state, we need to send the map variable inside
1069
- # The container template also should be accepting an input parameter
1070
- map_variable = None
1071
- if list_of_iter_values:
1072
- map_variable = self.compose_map_variable(
1073
- list_of_iter_values=list_of_iter_values
898
+ common_volume_mount = VolumeMount(
899
+ name="runnable",
900
+ mount_path="/tmp",
1074
901
  )
1075
-
1076
- for val in list_of_iter_values:
1077
- inputs.append(Parameter(name=val))
1078
-
1079
- command = utils.get_fan_command(
1080
- mode="in",
1081
- node=composite_node,
1082
- run_id=self._run_id_placeholder,
1083
- map_variable=map_variable,
1084
- log_level=self._log_level,
1085
- )
1086
-
1087
- step_config = {"command": command, "type": "task", "next": "dummy"}
1088
- node = create_node(name=f"{clean_name}-fan-in", step_config=step_config)
1089
- container_template = self.create_container_template(
1090
- working_on=node,
1091
- command=command,
1092
- inputs=inputs,
1093
- overwrite_name=f"{clean_name}-fan-in",
1094
- )
1095
- self._container_templates.append(container_template)
1096
- clean_name = self.get_clean_name(composite_node)
1097
- return DagTaskTemplate(
1098
- name=f"{clean_name}-fan-in", template=f"{clean_name}-fan-in"
1099
- )
1100
-
1101
- def _gather_task_templates_of_dag(
1102
- self,
1103
- dag: Graph,
1104
- dag_name="runnable-dag",
1105
- list_of_iter_values: Optional[List] = None,
1106
- ):
1107
- current_node = dag.start_at
1108
- previous_node = None
1109
- previous_node_template_name = None
1110
-
1111
- templates: Dict[str, DagTaskTemplate] = {}
1112
-
1113
- if not list_of_iter_values:
1114
- list_of_iter_values = []
1115
-
1116
- while True:
1117
- working_on = dag.get_node_by_name(current_node)
1118
- if previous_node == current_node:
1119
- raise Exception("Potentially running in a infinite loop")
1120
-
1121
- render_obj = get_renderer(working_on)(executor=self, node=working_on)
1122
- render_obj.render(list_of_iter_values=list_of_iter_values.copy())
1123
-
1124
- clean_name = self.get_clean_name(working_on)
1125
-
1126
- # If a task template for clean name exists, retrieve it (could have been created by on_failure)
1127
- template = templates.get(
1128
- clean_name, DagTaskTemplate(name=clean_name, template=clean_name)
902
+ volume_pairs.append(
903
+ VolumePair(volume=common_volume, volume_mount=common_volume_mount)
1129
904
  )
1130
-
1131
- # Link the current node to previous node, if the previous node was successful.
1132
- if previous_node:
1133
- template.depends.append(f"{previous_node_template_name}.Succeeded")
1134
-
1135
- templates[clean_name] = template
1136
-
1137
- # On failure nodes
1138
- if (
1139
- working_on.node_type not in ["success", "fail"]
1140
- and working_on._get_on_failure_node()
1141
- ):
1142
- failure_node = dag.get_node_by_name(working_on._get_on_failure_node())
1143
-
1144
- # same logic, if a template exists, retrieve it
1145
- # if not, create a new one
1146
- render_obj = get_renderer(working_on)(executor=self, node=failure_node)
1147
- render_obj.render(list_of_iter_values=list_of_iter_values.copy())
1148
-
1149
- failure_template_name = self.get_clean_name(failure_node)
1150
- # If a task template for clean name exists, retrieve it
1151
- failure_template = templates.get(
1152
- failure_template_name,
1153
- DagTaskTemplate(
1154
- name=failure_template_name, template=failure_template_name
905
+ counter = 0
906
+ for custom_volume in self.custom_volumes or []:
907
+ name = f"custom-volume-{counter}"
908
+ volume_pairs.append(
909
+ VolumePair(
910
+ volume=Volume(
911
+ name=name,
912
+ persistent_volume_claim=custom_volume.persistent_volume_claim,
913
+ ),
914
+ volume_mount=VolumeMount(
915
+ name=name,
916
+ mount_path=custom_volume.mount_path,
1155
917
  ),
1156
918
  )
1157
- failure_template.depends.append(f"{clean_name}.Failed")
1158
- templates[failure_template_name] = failure_template
1159
-
1160
- # If we are in a map node, we need to add the values as arguments
1161
- template = templates[clean_name]
1162
- if list_of_iter_values:
1163
- if not template.arguments:
1164
- template.arguments = []
1165
- for value in list_of_iter_values:
1166
- template.arguments.append(
1167
- Argument(
1168
- name=value, value="{{inputs.parameters." + value + "}}"
1169
- )
1170
- )
1171
-
1172
- # Move ahead to the next node
1173
- previous_node = current_node
1174
- previous_node_template_name = self.get_clean_name(working_on)
1175
-
1176
- if working_on.node_type in ["success", "fail"]:
1177
- break
1178
-
1179
- current_node = working_on._get_next_node()
1180
-
1181
- # Add the iteration values as input to dag template
1182
- dag_template = DagTemplate(tasks=list(templates.values()), name=dag_name)
1183
- if list_of_iter_values:
1184
- if not dag_template.inputs:
1185
- dag_template.inputs = []
1186
- dag_template.inputs.extend(
1187
- [Parameter(name=val) for val in list_of_iter_values]
1188
- )
1189
-
1190
- # Add the dag template to the list of templates
1191
- self._dag_templates.append(dag_template)
1192
-
1193
- def _get_template_defaults(self) -> TemplateDefaults:
1194
- user_provided_config = self.model_dump(by_alias=False)
1195
-
1196
- return TemplateDefaults(**user_provided_config)
1197
-
1198
- def execute_graph(self, dag: Graph, map_variable: Optional[dict] = None, **kwargs):
1199
- # TODO: Add metadata
1200
- arguments = []
1201
- # Expose "simple" parameters as workflow arguments for dynamic behavior
1202
- if self.expose_parameters_as_inputs:
1203
- for key, value in self._get_parameters().items():
1204
- value = value.get_value() # type: ignore
1205
- if isinstance(value, dict) or isinstance(value, list):
1206
- continue
1207
-
1208
- env_var = EnvVar(name=key, value=value) # type: ignore
1209
- arguments.append(env_var)
1210
-
1211
- run_id_var = EnvVar(name="run_id", value="{{workflow.uid}}")
1212
- log_level_var = EnvVar(name="log_level", value=defaults.LOG_LEVEL)
1213
- arguments.append(run_id_var)
1214
- arguments.append(log_level_var)
1215
-
1216
- # TODO: Can we do reruns?
1217
-
1218
- for volume in self.spec.volumes:
1219
- self._container_volumes.append(
1220
- ContainerVolume(name=volume.name, mount_path=volume.mount_path)
1221
- )
1222
-
1223
- # Container specifications are globally collected and added at the end.
1224
- # Dag specifications are added as part of the dag traversal.
1225
- templates: List[Union[DagTemplate, ContainerTemplate]] = []
1226
- self._gather_task_templates_of_dag(dag=dag, list_of_iter_values=[])
1227
- templates.extend(self._dag_templates)
1228
- templates.extend(self._container_templates)
1229
-
1230
- spec = self.spec
1231
- spec.templates = templates
1232
- spec.arguments = arguments
1233
- workflow = Workflow(metadata=self.metadata, spec=spec)
1234
-
1235
- yaml = YAML()
1236
- with open(self.output_file, "w") as f:
1237
- yaml.indent(mapping=2, sequence=4, offset=2)
1238
-
1239
- yaml.dump(workflow.model_dump(by_alias=True, exclude_none=True), f)
1240
-
1241
- def send_return_code(self, stage="traversal"):
1242
- """
1243
- Convenience function used by pipeline to send return code to the caller of the cli
1244
-
1245
- Raises:
1246
- Exception: If the pipeline execution failed
1247
- """
1248
- if (
1249
- stage != "traversal"
1250
- ): # traversal does no actual execution, so return code is pointless
1251
- run_id = self._context.run_id
1252
-
1253
- run_log = self._context.run_log_store.get_run_log_by_id(
1254
- run_id=run_id, full=False
1255
919
  )
1256
- if run_log.status == defaults.FAIL:
1257
- raise exceptions.ExecutionFailedError(run_id)
1258
-
1259
-
1260
- # TODO:
1261
- # class FileSystemRunLogStore(BaseIntegration):
1262
- # """
1263
- # Only local execution mode is possible for Buffered Run Log store
1264
- # """
1265
-
1266
- # executor_type = "argo"
1267
- # service_type = "run_log_store" # One of secret, catalog, datastore
1268
- # service_provider = "file-system" # The actual implementation of the service
1269
-
1270
- # def validate(self, **kwargs):
1271
- # msg = (
1272
- # "Argo cannot run work with file-system run log store. "
1273
- # "Unless you have made a mechanism to use volume mounts."
1274
- # "Using this run log store if the pipeline has concurrent tasks might lead to unexpected results"
1275
- # )
1276
- # logger.warning(msg)
1277
-
1278
-
1279
- # class ChunkedFileSystemRunLogStore(BaseIntegration):
1280
- # """
1281
- # Only local execution mode is possible for Buffered Run Log store
1282
- # """
1283
-
1284
- # executor_type = "argo"
1285
- # service_type = "run_log_store" # One of secret, catalog, datastore
1286
- # service_provider = "chunked-fs" # The actual implementation of the service
1287
-
1288
- # def validate(self, **kwargs):
1289
- # msg = (
1290
- # "Argo cannot run work with chunked file-system run log store. "
1291
- # "Unless you have made a mechanism to use volume mounts"
1292
- # )
1293
- # logger.warning(msg)
1294
-
1295
-
1296
- # class FileSystemCatalog(BaseIntegration):
1297
- # """
1298
- # Only local execution mode is possible for Buffered Run Log store
1299
- # """
1300
-
1301
- # executor_type = "argo"
1302
- # service_type = "catalog" # One of secret, catalog, datastore
1303
- # service_provider = "file-system" # The actual implementation of the service
1304
-
1305
- # def validate(self, **kwargs):
1306
- # msg = "Argo cannot run work with file-system run log store. Unless you have made a mechanism to use volume mounts"
1307
- # logger.warning(msg)
920
+ counter += 1
921
+ return volume_pairs
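To see what the pairing amounts to once rendered into the workflow spec, a hedged example for a hypothetical `pvc_for_runnable: my-claim` plus one custom volume; the dictionaries mirror the standard Kubernetes volume and volumeMount schema rather than the pydantic models above:

# Illustrative rendered fragments only; claim names and the /data mount are made up.
volumes = [
    {"name": "runnable", "persistentVolumeClaim": {"claimName": "my-claim"}},
    {"name": "custom-volume-0", "persistentVolumeClaim": {"claimName": "data-claim"}},
]
volume_mounts = [
    {"name": "runnable", "mountPath": "/tmp"},
    {"name": "custom-volume-0", "mountPath": "/data"},
]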