runnable 0.19.1__py3-none-any.whl → 0.20.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
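
Most of the model changes below replace per-field `serialization_alias` / `field_serializer` declarations with a shared pydantic base class (`BaseModelWIthConfig`) that camel-cases field names when a model is dumped. The following standalone sketch is not code from the package — `CamelModel` and `WorkflowBits` are illustrative names — but it shows the pydantic v2 behaviour that pattern relies on:

    from pydantic import BaseModel, ConfigDict, Field
    from pydantic.alias_generators import to_camel


    class CamelModel(BaseModel):
        # Mirrors the knobs set on BaseModelWIthConfig in the diff: unknown keys are
        # rejected, aliases are generated as camelCase, and either naming style is
        # accepted on input.
        model_config = ConfigDict(
            extra="forbid",
            alias_generator=to_camel,
            populate_by_name=True,
            validate_default=True,
        )


    class WorkflowBits(CamelModel):
        node_selector: dict[str, str] = Field(default_factory=dict)
        service_account_name: str | None = Field(default=None)


    bits = WorkflowBits(nodeSelector={"role": "worker"})  # camelCase accepted on input
    print(bits.model_dump(by_alias=True, exclude_none=True))
    # {'nodeSelector': {'role': 'worker'}}

Because `populate_by_name=True`, snake_case keyword arguments work just as well, which is why the executor code can keep using the Python field names internally while still emitting camelCase keys in the generated Argo spec.
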
@@ -1,65 +1,53 @@
1
1
  import json
2
- import logging
2
+ import os
3
3
  import random
4
4
  import shlex
5
5
  import string
6
- from abc import ABC, abstractmethod
7
- from collections import OrderedDict
8
- from typing import Dict, List, Optional, Union, cast
6
+ from collections import namedtuple
7
+ from enum import Enum
8
+ from functools import cached_property
9
+ from typing import Annotated, Literal, Optional
9
10
 
10
11
  from pydantic import (
11
12
  BaseModel,
12
13
  ConfigDict,
13
14
  Field,
15
+ PlainSerializer,
16
+ PrivateAttr,
14
17
  computed_field,
15
- field_serializer,
16
- field_validator,
18
+ model_validator,
17
19
  )
18
- from pydantic.functional_serializers import PlainSerializer
20
+ from pydantic.alias_generators import to_camel
19
21
  from ruamel.yaml import YAML
20
- from typing_extensions import Annotated
21
22
 
22
- from extensions.nodes.nodes import DagNode, MapNode, ParallelNode
23
+ from extensions.nodes.nodes import MapNode, ParallelNode, TaskNode
23
24
  from extensions.pipeline_executor import GenericPipelineExecutor
24
- from runnable import defaults, exceptions, utils
25
+ from runnable import defaults, utils
25
26
  from runnable.defaults import TypeMapVariable
26
- from runnable.graph import Graph, create_node, search_node_by_internal_name
27
+ from runnable.graph import Graph, search_node_by_internal_name
27
28
  from runnable.nodes import BaseNode
28
29
 
29
- logger = logging.getLogger(defaults.NAME)
30
-
31
- # TODO: Leave the run log in consistent state.
32
- # TODO: Make the config camel case just like Argo does.
33
-
34
- """
35
- executor:
36
- type: argo
37
- config:
38
- image: # apply to template
39
- max_workflow_duration: # Apply to spec
40
- nodeSelector: #Apply to spec
41
- parallelism: #apply to spec
42
- resources: # convert to podSpecPath
43
- limits:
44
- requests:
45
- retryStrategy:
46
- max_step_duration: # apply to templateDefaults
47
- step_timeout: # apply to templateDefaults
48
- tolerations: # apply to spec
49
- imagePullPolicy: # apply to template
50
-
51
- overrides:
52
- override:
53
- tolerations: # template
54
- image: # container
55
- max_step_duration: # template
56
- step_timeout: #template
57
- nodeSelector: #template
58
- parallelism: # this need to applied for map
59
- resources: # container
60
- imagePullPolicy: #container
61
- retryStrategy: # template
62
- """
30
+
31
+ class BaseModelWIthConfig(BaseModel, use_enum_values=True):
32
+ model_config = ConfigDict(
33
+ extra="forbid",
34
+ alias_generator=to_camel,
35
+ populate_by_name=True,
36
+ from_attributes=True,
37
+ validate_default=True,
38
+ )
39
+
40
+
41
+ class BackOff(BaseModelWIthConfig):
42
+ duration: str = Field(default="2m")
43
+ factor: float = Field(default=2)
44
+ max_duration: str = Field(default="1h")
45
+
46
+
47
+ class RetryStrategy(BaseModelWIthConfig):
48
+ back_off: Optional[BackOff] = Field(default=None)
49
+ limit: int = 0
50
+ retry_policy: str = Field(default="Always")
63
51
 
64
52
 
65
53
  class SecretEnvVar(BaseModel):
@@ -79,7 +67,7 @@ class SecretEnvVar(BaseModel):
79
67
 
80
68
  @computed_field # type: ignore
81
69
  @property
82
- def valueFrom(self) -> Dict[str, Dict[str, str]]:
70
+ def value_from(self) -> dict[str, dict[str, str]]:
83
71
  return {
84
72
  "secretKeyRef": {
85
73
  "name": self.secret_name,
@@ -88,778 +76,756 @@ class SecretEnvVar(BaseModel):
88
76
  }
89
77
 
90
78
 
91
- class EnvVar(BaseModel):
92
- """
93
- Renders:
94
- parameters: # in arguments
95
- - name: x
96
- value: 3 # This is optional for workflow parameters
79
+ class EnvVar(BaseModelWIthConfig):
80
+ name: str
81
+ value: str
97
82
 
98
- """
99
83
 
84
+ class OutputParameter(BaseModelWIthConfig):
100
85
  name: str
101
- value: Union[str, int, float] = Field(default="")
86
+ value_from: dict[str, str] = {
87
+ "path": "/tmp/output.txt",
88
+ }
102
89
 
103
90
 
104
- class Parameter(BaseModel):
91
+ class Parameter(BaseModelWIthConfig):
105
92
  name: str
106
- value: Optional[str] = None
93
+ value: Optional[str | int | float | bool] = Field(default=None)
107
94
 
108
- @field_serializer("name")
109
- def serialize_name(self, name: str) -> str:
110
- return f"{str(name)}"
111
95
 
112
- @field_serializer("value")
113
- def serialize_value(self, value: str) -> str:
114
- return f"{value}"
96
+ class Inputs(BaseModelWIthConfig):
97
+ parameters: Optional[list[Parameter]] = Field(default=None)
115
98
 
116
99
 
117
- class OutputParameter(Parameter):
118
- """
119
- Renders:
120
- - name: step-name
121
- valueFrom:
122
- path: /tmp/output.txt
123
- """
100
+ class Outputs(BaseModelWIthConfig):
101
+ parameters: Optional[list[OutputParameter]] = Field(default=None)
124
102
 
125
- path: str = Field(default="/tmp/output.txt", exclude=True)
126
103
 
127
- @computed_field # type: ignore
128
- @property
129
- def valueFrom(self) -> Dict[str, str]:
130
- return {"path": self.path}
104
+ class Arguments(BaseModelWIthConfig):
105
+ parameters: Optional[list[Parameter]] = Field(default=None)
131
106
 
132
107
 
133
- class Argument(BaseModel):
134
- """
135
- Templates are called with arguments, which become inputs for the template
136
- Renders:
137
- arguments:
138
- parameters:
139
- - name: The name of the parameter
140
- value: The value of the parameter
141
- """
108
+ class TolerationEffect(str, Enum):
109
+ NoSchedule = "NoSchedule"
110
+ PreferNoSchedule = "PreferNoSchedule"
111
+ NoExecute = "NoExecute"
142
112
 
143
- name: str
144
- value: str
145
113
 
146
- @field_serializer("name")
147
- def serialize_name(self, name: str) -> str:
148
- return f"{str(name)}"
114
+ class TolerationOperator(str, Enum):
115
+ Exists = "Exists"
116
+ Equal = "Equal"
149
117
 
150
- @field_serializer("value")
151
- def serialize_value(self, value: str) -> str:
152
- return f"{value}"
153
118
 
119
+ class PodMetaData(BaseModelWIthConfig):
120
+ annotations: dict[str, str] = Field(default_factory=dict)
121
+ labels: dict[str, str] = Field(default_factory=dict)
154
122
 
155
- class Request(BaseModel):
156
- """
157
- The default requests
158
- """
159
123
 
160
- memory: str = "1Gi"
161
- cpu: str = "250m"
124
+ class Toleration(BaseModelWIthConfig):
125
+ effect: Optional[TolerationEffect] = Field(default=None)
126
+ key: Optional[str] = Field(default=None)
127
+ operator: TolerationOperator = Field(default=TolerationOperator.Equal)
128
+ tolerationSeconds: Optional[int] = Field(default=None)
129
+ value: Optional[str] = Field(default=None)
162
130
 
131
+ @model_validator(mode="after")
132
+ def validate_tolerations(self) -> "Toleration":
133
+ if not self.key:
134
+ if self.operator != TolerationOperator.Exists:
135
+ raise ValueError("Toleration key is required when operator is Equal")
163
136
 
164
- VendorGPU = Annotated[
165
- Optional[int],
166
- PlainSerializer(lambda x: str(x), return_type=str, when_used="unless-none"),
167
- ]
137
+ if self.operator == TolerationOperator.Exists:
138
+ if self.value:
139
+ raise ValueError(
140
+ "Toleration value is not allowed when operator is Exists"
141
+ )
142
+ return self
168
143
 
169
144
 
170
- class Limit(Request):
171
- """
172
- The default limits
173
- """
145
+ class ImagePullPolicy(str, Enum):
146
+ Always = "Always"
147
+ IfNotPresent = "IfNotPresent"
148
+ Never = "Never"
174
149
 
175
- gpu: VendorGPU = Field(default=None, serialization_alias="nvidia.com/gpu")
176
150
 
151
+ class PersistentVolumeClaimSource(BaseModelWIthConfig):
152
+ claim_name: str
153
+ read_only: bool = Field(default=False)
177
154
 
178
- class Resources(BaseModel):
179
- limits: Limit = Field(default=Limit(), serialization_alias="limits")
180
- requests: Request = Field(default=Request(), serialization_alias="requests")
181
155
 
156
+ class Volume(BaseModelWIthConfig):
157
+ name: str
158
+ persistent_volume_claim: PersistentVolumeClaimSource
182
159
 
183
- class BackOff(BaseModel):
184
- duration_in_seconds: int = Field(default=2 * 60, serialization_alias="duration")
185
- factor: float = Field(default=2, serialization_alias="factor")
186
- max_duration: int = Field(default=60 * 60, serialization_alias="maxDuration")
160
+ def __hash__(self):
161
+ return hash(self.name)
187
162
 
188
- @field_serializer("duration_in_seconds")
189
- def cast_duration_as_str(self, duration_in_seconds: int, _info) -> str:
190
- return str(duration_in_seconds)
191
163
 
192
- @field_serializer("max_duration")
193
- def cast_mas_duration_as_str(self, max_duration: int, _info) -> str:
194
- return str(max_duration)
164
+ class VolumeMount(BaseModelWIthConfig):
165
+ mount_path: str
166
+ name: str
167
+ read_only: bool = Field(default=False)
195
168
 
169
+ @model_validator(mode="after")
170
+ def validate_volume_mount(self) -> "VolumeMount":
171
+ if "." in self.mount_path:
172
+ raise ValueError("mount_path cannot contain '.'")
173
+
174
+ return self
196
175
 
197
- class Retry(BaseModel):
198
- limit: int = 0
199
- retry_policy: str = Field(default="Always", serialization_alias="retryPolicy")
200
- back_off: BackOff = Field(default=BackOff(), serialization_alias="backoff")
201
176
 
202
- @field_serializer("limit")
203
- def cast_limit_as_str(self, limit: int, _info) -> str:
204
- return str(limit)
177
+ VolumePair = namedtuple("VolumePair", ["volume", "volume_mount"])
205
178
 
206
179
 
207
- class Toleration(BaseModel):
208
- effect: str
180
+ class LabelSelectorRequirement(BaseModelWIthConfig):
209
181
  key: str
210
182
  operator: str
211
- value: str
212
-
213
-
214
- class TemplateDefaults(BaseModel):
215
- max_step_duration: int = Field(
216
- default=60 * 60 * 2,
217
- serialization_alias="activeDeadlineSeconds",
218
- gt=0,
219
- description="Max run time of a step",
220
- )
221
-
222
- @computed_field # type: ignore
223
- @property
224
- def timeout(self) -> str:
225
- return f"{self.max_step_duration + 60*60}s"
226
-
183
+ values: list[str]
227
184
 
228
- ShlexCommand = Annotated[
229
- str, PlainSerializer(lambda x: shlex.split(x), return_type=List[str])
230
- ]
231
185
 
186
+ class PodGCStrategy(str, Enum):
187
+ OnPodCompletion = "OnPodCompletion"
188
+ OnPodSuccess = "OnPodSuccess"
189
+ OnWorkflowCompletion = "OnWorkflowCompletion"
190
+ OnWorkflowSuccess = "OnWorkflowSuccess"
232
191
 
233
- class Container(BaseModel):
234
- image: str
235
- command: ShlexCommand
236
- volume_mounts: Optional[List["ContainerVolume"]] = Field(
237
- default=None, serialization_alias="volumeMounts"
238
- )
239
- image_pull_policy: str = Field(default="", serialization_alias="imagePullPolicy")
240
- resources: Optional[Resources] = Field(
241
- default=None, serialization_alias="resources"
242
- )
243
192
 
244
- env_vars: List[EnvVar] = Field(default_factory=list, exclude=True)
245
- secrets_from_k8s: List[SecretEnvVar] = Field(default_factory=list, exclude=True)
193
+ class LabelSelector(BaseModelWIthConfig):
194
+ matchExpressions: list[LabelSelectorRequirement] = Field(default_factory=list)
195
+ matchLabels: dict[str, str] = Field(default_factory=dict)
246
196
 
247
- @computed_field # type: ignore
248
- @property
249
- def env(self) -> Optional[List[Union[EnvVar, SecretEnvVar]]]:
250
- if not self.env_vars and not self.secrets_from_k8s:
251
- return None
252
197
 
253
- return self.env_vars + self.secrets_from_k8s
198
+ class PodGC(BaseModelWIthConfig):
199
+ delete_delay_duration: str = Field(default="1h") # 1 hour
200
+ label_selector: Optional[LabelSelector] = Field(default=None)
201
+ strategy: Optional[PodGCStrategy] = Field(default=None)
254
202
 
255
203
 
256
- class DagTaskTemplate(BaseModel):
204
+ class Request(BaseModel):
257
205
  """
258
- dag:
259
- tasks:
260
- name: A
261
- template: nested-diamond
262
- arguments:
263
- parameters: [{name: message, value: A}]
206
+ The default requests
264
207
  """
265
208
 
266
- name: str
267
- template: str
268
- depends: List[str] = []
269
- arguments: Optional[List[Argument]] = Field(default=None)
270
- with_param: Optional[str] = Field(default=None, serialization_alias="withParam")
271
-
272
- @field_serializer("depends")
273
- def transform_depends_as_str(self, depends: List[str]) -> str:
274
- return " || ".join(depends)
209
+ memory: str = "1Gi"
210
+ cpu: str = "250m"
275
211
 
276
- @field_serializer("arguments", when_used="unless-none")
277
- def empty_arguments_to_none(
278
- self, arguments: List[Argument]
279
- ) -> Dict[str, List[Argument]]:
280
- return {"parameters": arguments}
281
212
 
213
+ VendorGPU = Annotated[
214
+ Optional[int],
215
+ PlainSerializer(lambda x: str(x), return_type=str, when_used="unless-none"),
216
+ ]
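As a quick illustration of the `VendorGPU` annotation above (a sketch; `Limits` here is a stand-in for the package's `Limit` model), `PlainSerializer` with `when_used="unless-none"` turns an integer GPU count into a string only when a value is set, and the `nvidia.com/gpu` serialization alias places it under the key Kubernetes expects:

    from typing import Annotated, Optional

    from pydantic import BaseModel, Field, PlainSerializer

    VendorGPU = Annotated[
        Optional[int],
        PlainSerializer(lambda x: str(x), return_type=str, when_used="unless-none"),
    ]


    class Limits(BaseModel):
        # Stand-in for the Limit model; gpu stays an int in Python but is emitted
        # as a string under the vendor-specific resource key.
        gpu: VendorGPU = Field(default=None, serialization_alias="nvidia.com/gpu")


    print(Limits(gpu=2).model_dump(by_alias=True, exclude_none=True))
    # {'nvidia.com/gpu': '2'}
    print(Limits().model_dump(by_alias=True, exclude_none=True))
    # {}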
282
217
 
283
- class ContainerTemplate(BaseModel):
284
- # These templates are used for actual execution nodes.
285
- name: str
286
- active_deadline_seconds: Optional[int] = Field(
287
- default=None, serialization_alias="activeDeadlineSeconds", gt=0
288
- )
289
- node_selector: Optional[Dict[str, str]] = Field(
290
- default=None, serialization_alias="nodeSelector"
291
- )
292
- retry_strategy: Optional[Retry] = Field(
293
- default=None, serialization_alias="retryStrategy"
294
- )
295
- tolerations: Optional[List[Toleration]] = Field(
296
- default=None, serialization_alias="tolerations"
297
- )
298
218
 
299
- container: Container
219
+ class Limit(Request):
220
+ """
221
+ The default limits
222
+ """
300
223
 
301
- outputs: Optional[List[OutputParameter]] = Field(
302
- default=None, serialization_alias="outputs"
303
- )
304
- inputs: Optional[List[Parameter]] = Field(
305
- default=None, serialization_alias="inputs"
306
- )
224
+ gpu: VendorGPU = Field(default=None, serialization_alias="nvidia.com/gpu")
307
225
 
308
- def __hash__(self):
309
- return hash(self.name)
310
226
 
311
- @field_serializer("outputs", when_used="unless-none")
312
- def reshape_outputs(
313
- self, outputs: List[OutputParameter]
314
- ) -> Dict[str, List[OutputParameter]]:
315
- return {"parameters": outputs}
227
+ class Resources(BaseModel):
228
+ limits: Limit = Field(default=Limit(), serialization_alias="limits")
229
+ requests: Request = Field(default=Request(), serialization_alias="requests")
316
230
 
317
- @field_serializer("inputs", when_used="unless-none")
318
- def reshape_inputs(self, inputs: List[Parameter]) -> Dict[str, List[Parameter]]:
319
- return {"parameters": inputs}
320
231
 
232
+ # This is what the user can override per template
233
+ # Some are specific to container and some are specific to dag
234
+ class TemplateDefaults(BaseModelWIthConfig):
235
+ active_deadline_seconds: Optional[int] = Field(default=86400) # 1 day
236
+ fail_fast: bool = Field(default=True)
237
+ node_selector: dict[str, str] = Field(default_factory=dict)
238
+ parallelism: Optional[int] = Field(default=None)
239
+ retry_strategy: Optional[RetryStrategy] = Field(default=None)
240
+ timeout: Optional[str] = Field(default=None)
241
+ tolerations: Optional[list[Toleration]] = Field(default=None)
321
242
 
322
- class DagTemplate(BaseModel):
323
- # These are used for parallel, map nodes dag definition
324
- name: str = "runnable-dag"
325
- tasks: List[DagTaskTemplate] = Field(default=[], exclude=True)
326
- inputs: Optional[List[Parameter]] = Field(
327
- default=None, serialization_alias="inputs"
243
+ # These are in addition to what argo spec provides
244
+ image: str
245
+ image_pull_policy: Optional[ImagePullPolicy] = Field(default=ImagePullPolicy.Always)
246
+ resources: Resources = Field(default_factory=Resources)
247
+
248
+
249
+ # User provides this as part of the argoSpec
250
+ # some can be provided here or as a template default or node override
251
+ class ArgoWorkflowSpec(BaseModelWIthConfig):
252
+ active_deadline_seconds: int = Field(default=86400) # 1 day for the whole workflow
253
+ arguments: Optional[Arguments] = Field(default=None)
254
+ entrypoint: Literal["runnable-dag"] = Field(default="runnable-dag", frozen=True)
255
+ node_selector: dict[str, str] = Field(default_factory=dict)
256
+ parallelism: Optional[int] = Field(default=None) # Global parallelism
257
+ pod_gc: Optional[PodGC] = Field(default=None, serialization_alias="podGC")
258
+ retry_strategy: Optional[RetryStrategy] = Field(default=None)
259
+ service_account_name: Optional[str] = Field(default=None)
260
+ template_defaults: TemplateDefaults
261
+ tolerations: Optional[list[Toleration]] = Field(default=None)
262
+
263
+
264
+ class ArgoMetadata(BaseModelWIthConfig):
265
+ annotations: Optional[dict[str, str]] = Field(default=None)
266
+ generate_name: str # User can mention this to uniquely identify the run
267
+ labels: dict[str, str] = Field(default_factory=dict)
268
+ namespace: Optional[str] = Field(default="default")
269
+
270
+
271
+ class ArgoWorkflow(BaseModelWIthConfig):
272
+ apiVersion: Literal["argoproj.io/v1alpha1"] = Field(
273
+ default="argoproj.io/v1alpha1", frozen=True
328
274
  )
329
- parallelism: Optional[int] = None
330
- fail_fast: bool = Field(default=False, serialization_alias="failFast")
275
+ kind: Literal["Workflow"] = Field(default="Workflow", frozen=True)
276
+ metadata: ArgoMetadata
277
+ spec: ArgoWorkflowSpec
331
278
 
332
- @field_validator("parallelism")
333
- @classmethod
334
- def validate_parallelism(cls, parallelism: Optional[int]) -> Optional[int]:
335
- if parallelism is not None and parallelism <= 0:
336
- raise ValueError("Parallelism must be a positive integer greater than 0")
337
- return parallelism
338
279
 
339
- @computed_field # type: ignore
340
- @property
341
- def dag(self) -> Dict[str, List[DagTaskTemplate]]:
342
- return {"tasks": self.tasks}
280
+ # The below are not visible to the user
281
+ class DagTask(BaseModelWIthConfig):
282
+ name: str
283
+ template: str # Should be name of a container template or dag template
284
+ arguments: Optional[Arguments] = Field(default=None)
285
+ with_param: Optional[str] = Field(default=None)
286
+ depends: Optional[str] = Field(default=None)
343
287
 
344
- @field_serializer("inputs", when_used="unless-none")
345
- def reshape_inputs(
346
- self, inputs: List[Parameter], _info
347
- ) -> Dict[str, List[Parameter]]:
348
- return {"parameters": inputs}
349
288
 
289
+ class CoreDagTemplate(BaseModelWIthConfig):
290
+ tasks: list[DagTask] = Field(default_factory=list[DagTask])
350
291
 
351
- class Volume(BaseModel):
352
- """
353
- spec config requires, name and persistentVolumeClaim
354
- step requires name and mountPath
355
- """
356
-
357
- name: str
358
- claim: str = Field(exclude=True)
359
- mount_path: str = Field(serialization_alias="mountPath", exclude=True)
360
292
 
361
- @computed_field # type: ignore
362
- @property
363
- def persistentVolumeClaim(self) -> Dict[str, str]:
364
- return {"claimName": self.claim}
293
+ class CoreContainerTemplate(BaseModelWIthConfig):
294
+ image: str
295
+ command: list[str]
296
+ image_pull_policy: ImagePullPolicy = Field(default=ImagePullPolicy.IfNotPresent)
297
+ env: list[EnvVar | SecretEnvVar] = Field(default_factory=list)
298
+ volume_mounts: list[VolumeMount] = Field(default_factory=list)
299
+ resources: Resources = Field(default_factory=Resources)
365
300
 
366
301
 
367
- class ContainerVolume(BaseModel):
302
+ class DagTemplate(BaseModelWIthConfig):
368
303
  name: str
369
- mount_path: str = Field(serialization_alias="mountPath")
304
+ dag: CoreDagTemplate = Field(default_factory=CoreDagTemplate)
305
+ inputs: Optional[Inputs] = Field(default=None)
306
+ parallelism: Optional[int] = Field(default=None) # Not sure if this is needed
307
+ fail_fast: bool = Field(default=True)
370
308
 
309
+ model_config = ConfigDict(
310
+ extra="ignore",
311
+ )
371
312
 
372
- class UserVolumeMounts(BaseModel):
373
- """
374
- The volume specification as user defines it.
375
- """
313
+ def __hash__(self):
314
+ return hash(self.name)
376
315
 
377
- name: str # This is the name of the PVC on K8s
378
- mount_path: str # This is mount path on the container
379
316
 
317
+ class ContainerTemplate((BaseModelWIthConfig)):
318
+ name: str
319
+ container: CoreContainerTemplate
320
+ inputs: Optional[Inputs] = Field(default=None)
321
+ outputs: Optional[Outputs] = Field(default=None)
322
+
323
+ # The remaining can be from template defaults or node overrides
324
+ active_deadline_seconds: Optional[int] = Field(default=86400) # 1 day
325
+ metadata: Optional[PodMetaData] = Field(default=None)
326
+ node_selector: dict[str, str] = Field(default_factory=dict)
327
+ parallelism: Optional[int] = Field(default=None) # Not sure if this is needed
328
+ retry_strategy: Optional[RetryStrategy] = Field(default=None)
329
+ timeout: Optional[str] = Field(default=None)
330
+ tolerations: Optional[list[Toleration]] = Field(default=None)
331
+ volumes: Optional[list[Volume]] = Field(default=None)
332
+
333
+ model_config = ConfigDict(
334
+ extra="ignore",
335
+ )
380
336
 
381
- class NodeRenderer(ABC):
382
- allowed_node_types: List[str] = []
337
+ def __hash__(self):
338
+ return hash(self.name)
383
339
 
384
- def __init__(self, executor: "ArgoExecutor", node: BaseNode) -> None:
385
- self.executor = executor
386
- self.node = node
387
340
 
388
- @abstractmethod
389
- def render(self, list_of_iter_values: Optional[List] = None):
390
- pass
341
+ class CustomVolume(BaseModelWIthConfig):
342
+ mount_path: str
343
+ persistent_volume_claim: PersistentVolumeClaimSource
391
344
 
392
345
 
393
- class ExecutionNode(NodeRenderer):
394
- allowed_node_types = ["task", "stub", "success", "fail"]
346
+ class ArgoExecutor(GenericPipelineExecutor):
347
+ service_name: str = "argo"
348
+ _is_local: bool = False
349
+ mock: bool = False
350
+
351
+ model_config = ConfigDict(
352
+ extra="forbid",
353
+ alias_generator=to_camel,
354
+ populate_by_name=True,
355
+ from_attributes=True,
356
+ use_enum_values=True,
357
+ )
395
358
 
396
- def render(self, list_of_iter_values: Optional[List] = None):
397
- """
398
- Compose the map variable and create the execution command.
399
- Create an input to the command.
400
- create_container_template : creates an argument for the list of iter values
401
- """
402
- map_variable = self.executor.compose_map_variable(list_of_iter_values)
403
- command = utils.get_node_execution_command(
404
- self.node,
405
- over_write_run_id=self.executor._run_id_placeholder,
406
- map_variable=map_variable,
407
- log_level=self.executor._log_level,
408
- )
359
+ argo_workflow: ArgoWorkflow
409
360
 
410
- inputs = []
411
- if list_of_iter_values:
412
- for val in list_of_iter_values:
413
- inputs.append(Parameter(name=val))
361
+ # Let's use a generic one
362
+ pvc_for_runnable: Optional[str] = Field(default=None)
363
+ # pvc_for_catalog: Optional[str] = Field(default=None)
364
+ # pvc_for_run_log: Optional[str] = Field(default=None)
365
+ custom_volumes: Optional[list[CustomVolume]] = Field(
366
+ default_factory=list[CustomVolume]
367
+ )
414
368
 
415
- # Create the container template
416
- container_template = self.executor.create_container_template(
417
- working_on=self.node,
418
- command=command,
419
- inputs=inputs,
420
- )
369
+ expose_parameters_as_inputs: bool = True
370
+ secret_from_k8s: Optional[str] = Field(default=None)
371
+ output_file: str = Field(default="argo-pipeline.yaml")
372
+ log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field(
373
+ default="INFO"
374
+ )
421
375
 
422
- self.executor._container_templates.append(container_template)
376
+ # This should be used when we refer to run_id or log_level in the containers
377
+ _run_id_as_parameter: str = PrivateAttr(default="{{workflow.parameters.run_id}}")
378
+ _log_level_as_parameter: str = PrivateAttr(
379
+ default="{{workflow.parameters.log_level}}"
380
+ )
423
381
 
382
+ _templates: list[ContainerTemplate | DagTemplate] = PrivateAttr(
383
+ default_factory=list
384
+ )
385
+ _container_log_location: str = PrivateAttr(default="/tmp/run_logs/")
386
+ _container_catalog_location: str = PrivateAttr(default="/tmp/catalog/")
387
+ _added_initial_container: bool = PrivateAttr(default=False)
388
+
389
+ def sanitize_name(self, name: str) -> str:
390
+ formatted_name = name.replace(" ", "-").replace(".", "-").replace("_", "-")
391
+ tag = "".join(random.choices(string.ascii_lowercase + string.digits, k=6))
392
+ unique_name = f"{formatted_name}-{tag}"
393
+ unique_name = unique_name.replace("map-variable-placeholder-", "")
394
+ return unique_name
395
+
396
+ def _set_up_initial_container(self, container_template: CoreContainerTemplate):
397
+ if self._added_initial_container:
398
+ return
399
+
400
+ parameters: list[Parameter] = []
401
+
402
+ if self.argo_workflow.spec.arguments:
403
+ parameters = self.argo_workflow.spec.arguments.parameters or []
404
+
405
+ for parameter in parameters or []:
406
+ key, _ = parameter.name, parameter.value
407
+ env_var = EnvVar(
408
+ name=defaults.PARAMETER_PREFIX + key,
409
+ value="{{workflow.parameters." + key + "}}",
410
+ )
411
+ container_template.env.append(env_var)
424
412
 
425
- class DagNodeRenderer(NodeRenderer):
426
- allowed_node_types = ["dag"]
413
+ env_var = EnvVar(name="error_on_existing_run_id", value="true")
414
+ container_template.env.append(env_var)
427
415
 
428
- def render(self, list_of_iter_values: Optional[List] = None):
429
- self.node = cast(DagNode, self.node)
430
- task_template_arguments = []
431
- dag_inputs = []
432
- if list_of_iter_values:
433
- for value in list_of_iter_values:
434
- task_template_arguments.append(
435
- Argument(name=value, value="{{inputs.parameters." + value + "}}")
436
- )
437
- dag_inputs.append(Parameter(name=value))
416
+ # After the first container is added, set the added_initial_container to True
417
+ self._added_initial_container = True
438
418
 
439
- clean_name = self.executor.get_clean_name(self.node)
440
- fan_out_template = self.executor._create_fan_out_template(
441
- composite_node=self.node, list_of_iter_values=list_of_iter_values
442
- )
443
- fan_out_template.arguments = (
444
- task_template_arguments if task_template_arguments else None
445
- )
419
+ def _create_fan_templates(
420
+ self,
421
+ node: BaseNode,
422
+ mode: str,
423
+ parameters: Optional[list[Parameter]],
424
+ task_name: str,
425
+ ):
426
+ template_defaults = self.argo_workflow.spec.template_defaults.model_dump()
446
427
 
447
- fan_in_template = self.executor._create_fan_in_template(
448
- composite_node=self.node, list_of_iter_values=list_of_iter_values
449
- )
450
- fan_in_template.arguments = (
451
- task_template_arguments if task_template_arguments else None
452
- )
428
+ map_variable: TypeMapVariable = {}
429
+ for parameter in parameters or []:
430
+ map_variable[parameter.name] = ( # type: ignore
431
+ "{{inputs.parameters." + str(parameter.name) + "}}"
432
+ )
453
433
 
454
- self.executor._gather_task_templates_of_dag(
455
- self.node.branch,
456
- dag_name=f"{clean_name}-branch",
457
- list_of_iter_values=list_of_iter_values,
434
+ fan_command = utils.get_fan_command(
435
+ mode=mode,
436
+ node=node,
437
+ run_id=self._run_id_as_parameter,
438
+ map_variable=map_variable,
458
439
  )
459
440
 
460
- branch_template = DagTaskTemplate(
461
- name=f"{clean_name}-branch",
462
- template=f"{clean_name}-branch",
463
- arguments=task_template_arguments if task_template_arguments else None,
441
+ core_container_template = CoreContainerTemplate(
442
+ command=shlex.split(fan_command),
443
+ image=template_defaults["image"],
444
+ image_pull_policy=template_defaults["image_pull_policy"],
445
+ volume_mounts=[
446
+ volume_pair.volume_mount for volume_pair in self.volume_pairs
447
+ ],
464
448
  )
465
- branch_template.depends.append(f"{clean_name}-fan-out.Succeeded")
466
- fan_in_template.depends.append(f"{clean_name}-branch.Succeeded")
467
- fan_in_template.depends.append(f"{clean_name}-branch.Failed")
468
-
469
- self.executor._dag_templates.append(
470
- DagTemplate(
471
- tasks=[fan_out_template, branch_template, fan_in_template],
472
- name=clean_name,
473
- inputs=dag_inputs if dag_inputs else None,
474
- )
475
- )
476
-
477
449
 
478
- class ParallelNodeRender(NodeRenderer):
479
- allowed_node_types = ["parallel"]
450
+ # Either a task or a fan-out can be the first container
451
+ self._set_up_initial_container(container_template=core_container_template)
480
452
 
481
- def render(self, list_of_iter_values: Optional[List] = None):
482
- self.node = cast(ParallelNode, self.node)
483
- task_template_arguments = []
484
- dag_inputs = []
485
- if list_of_iter_values:
486
- for value in list_of_iter_values:
487
- task_template_arguments.append(
488
- Argument(name=value, value="{{inputs.parameters." + value + "}}")
489
- )
490
- dag_inputs.append(Parameter(name=value))
453
+ task_name += f"-fan-{mode}"
491
454
 
492
- clean_name = self.executor.get_clean_name(self.node)
493
- fan_out_template = self.executor._create_fan_out_template(
494
- composite_node=self.node, list_of_iter_values=list_of_iter_values
495
- )
496
- fan_out_template.arguments = (
497
- task_template_arguments if task_template_arguments else None
498
- )
455
+ outputs: Optional[Outputs] = None
456
+ if mode == "out" and node.node_type == "map":
457
+ outputs = Outputs(parameters=[OutputParameter(name="iterate-on")])
499
458
 
500
- fan_in_template = self.executor._create_fan_in_template(
501
- composite_node=self.node, list_of_iter_values=list_of_iter_values
502
- )
503
- fan_in_template.arguments = (
504
- task_template_arguments if task_template_arguments else None
459
+ container_template = ContainerTemplate(
460
+ container=core_container_template,
461
+ name=task_name,
462
+ volumes=[volume_pair.volume for volume_pair in self.volume_pairs],
463
+ inputs=Inputs(parameters=parameters),
464
+ outputs=outputs,
465
+ **template_defaults,
505
466
  )
506
467
 
507
- branch_templates = []
508
- for name, branch in self.node.branches.items():
509
- branch_name = self.executor.sanitize_name(name)
510
- self.executor._gather_task_templates_of_dag(
511
- branch,
512
- dag_name=f"{clean_name}-{branch_name}",
513
- list_of_iter_values=list_of_iter_values,
514
- )
515
- task_template = DagTaskTemplate(
516
- name=f"{clean_name}-{branch_name}",
517
- template=f"{clean_name}-{branch_name}",
518
- arguments=task_template_arguments if task_template_arguments else None,
519
- )
520
- task_template.depends.append(f"{clean_name}-fan-out.Succeeded")
521
- fan_in_template.depends.append(f"{task_template.name}.Succeeded")
522
- fan_in_template.depends.append(f"{task_template.name}.Failed")
523
- branch_templates.append(task_template)
524
-
525
- executor_config = self.executor._resolve_executor_config(self.node)
526
-
527
- self.executor._dag_templates.append(
528
- DagTemplate(
529
- tasks=[fan_out_template] + branch_templates + [fan_in_template],
530
- name=clean_name,
531
- inputs=dag_inputs if dag_inputs else None,
532
- parallelism=executor_config.get("parallelism", None),
533
- )
534
- )
468
+ self._templates.append(container_template)
535
469
 
470
+ def _create_container_template(
471
+ self,
472
+ node: BaseNode,
473
+ task_name: str,
474
+ inputs: Optional[Inputs] = None,
475
+ ) -> ContainerTemplate:
476
+ template_defaults = self.argo_workflow.spec.template_defaults.model_dump()
536
477
 
537
- class MapNodeRender(NodeRenderer):
538
- allowed_node_types = ["map"]
478
+ node_overide = {}
479
+ if hasattr(node, "overides"):
480
+ node_overide = node.overides
539
481
 
540
- def render(self, list_of_iter_values: Optional[List] = None):
541
- self.node = cast(MapNode, self.node)
542
- task_template_arguments = []
543
- dag_inputs = []
482
+ # update template defaults with node overrides
483
+ template_defaults.update(node_overide)
544
484
 
545
- if not list_of_iter_values:
546
- list_of_iter_values = []
485
+ inputs = inputs or Inputs(parameters=[])
547
486
 
548
- for value in list_of_iter_values:
549
- task_template_arguments.append(
550
- Argument(name=value, value="{{inputs.parameters." + value + "}}")
487
+ map_variable: TypeMapVariable = {}
488
+ for parameter in inputs.parameters or []:
489
+ map_variable[parameter.name] = ( # type: ignore
490
+ "{{inputs.parameters." + str(parameter.name) + "}}"
551
491
  )
552
- dag_inputs.append(Parameter(name=value))
553
492
 
554
- clean_name = self.executor.get_clean_name(self.node)
555
-
556
- fan_out_template = self.executor._create_fan_out_template(
557
- composite_node=self.node, list_of_iter_values=list_of_iter_values
558
- )
559
- fan_out_template.arguments = (
560
- task_template_arguments if task_template_arguments else None
493
+ # command = "runnable execute-single-node"
494
+ command = utils.get_node_execution_command(
495
+ node=node,
496
+ over_write_run_id=self._run_id_as_parameter,
497
+ map_variable=map_variable,
498
+ log_level=self._log_level_as_parameter,
561
499
  )
562
500
 
563
- fan_in_template = self.executor._create_fan_in_template(
564
- composite_node=self.node, list_of_iter_values=list_of_iter_values
565
- )
566
- fan_in_template.arguments = (
567
- task_template_arguments if task_template_arguments else None
501
+ core_container_template = CoreContainerTemplate(
502
+ command=shlex.split(command),
503
+ image=template_defaults["image"],
504
+ image_pull_policy=template_defaults["image_pull_policy"],
505
+ volume_mounts=[
506
+ volume_pair.volume_mount for volume_pair in self.volume_pairs
507
+ ],
568
508
  )
569
509
 
570
- list_of_iter_values.append(self.node.iterate_as)
510
+ self._set_up_initial_container(container_template=core_container_template)
571
511
 
572
- self.executor._gather_task_templates_of_dag(
573
- self.node.branch,
574
- dag_name=f"{clean_name}-map",
575
- list_of_iter_values=list_of_iter_values,
576
- )
577
-
578
- task_template = DagTaskTemplate(
579
- name=f"{clean_name}-map",
580
- template=f"{clean_name}-map",
581
- arguments=task_template_arguments if task_template_arguments else None,
582
- )
583
- task_template.with_param = (
584
- "{{tasks."
585
- + f"{clean_name}-fan-out"
586
- + ".outputs.parameters."
587
- + "iterate-on"
588
- + "}}"
512
+ container_template = ContainerTemplate(
513
+ container=core_container_template,
514
+ name=task_name,
515
+ inputs=Inputs(
516
+ parameters=[
517
+ Parameter(name=param.name) for param in inputs.parameters or []
518
+ ]
519
+ ),
520
+ volumes=[volume_pair.volume for volume_pair in self.volume_pairs],
521
+ **template_defaults,
589
522
  )
590
523
 
591
- argument = Argument(name=self.node.iterate_as, value="{{item}}")
592
- if task_template.arguments is None:
593
- task_template.arguments = []
594
- task_template.arguments.append(argument)
595
-
596
- task_template.depends.append(f"{clean_name}-fan-out.Succeeded")
597
- fan_in_template.depends.append(f"{clean_name}-map.Succeeded")
598
- fan_in_template.depends.append(f"{clean_name}-map.Failed")
599
-
600
- executor_config = self.executor._resolve_executor_config(self.node)
524
+ return container_template
601
525
 
602
- self.executor._dag_templates.append(
603
- DagTemplate(
604
- tasks=[fan_out_template, task_template, fan_in_template],
605
- name=clean_name,
606
- inputs=dag_inputs if dag_inputs else None,
607
- parallelism=executor_config.get("parallelism", None),
608
- fail_fast=executor_config.get("fail_fast", True),
526
+ def _expose_secrets_to_task(
527
+ self,
528
+ working_on: BaseNode,
529
+ container_template: CoreContainerTemplate,
530
+ ):
531
+ assert isinstance(working_on, TaskNode)
532
+ secrets = working_on.executable.secrets
533
+ for secret in secrets:
534
+ assert self.secret_from_k8s is not None
535
+ secret_env_var = SecretEnvVar(
536
+ environment_variable=secret,
537
+ secret_name=self.secret_from_k8s, # This has to be exposed from config
538
+ secret_key=secret,
609
539
  )
610
- )
611
-
612
-
613
- def get_renderer(node):
614
- renderers = NodeRenderer.__subclasses__()
540
+ container_template.env.append(secret_env_var)
615
541
 
616
- for renderer in renderers:
617
- if node.node_type in renderer.allowed_node_types:
618
- return renderer
619
- raise Exception("This node type is not render-able")
620
-
621
-
622
- class MetaData(BaseModel):
623
- generate_name: str = Field(
624
- default="runnable-dag-", serialization_alias="generateName"
625
- )
626
- # The type ignore is related to: https://github.com/python/mypy/issues/18191
627
- annotations: Optional[Dict[str, str]] = Field(default_factory=dict) # type: ignore
628
- labels: Optional[Dict[str, str]] = Field(default_factory=dict) # type: ignore
629
- namespace: Optional[str] = Field(default=None)
630
-
631
-
632
- class Spec(BaseModel):
633
- active_deadline_seconds: int = Field(serialization_alias="activeDeadlineSeconds")
634
- entrypoint: str = Field(default="runnable-dag")
635
- node_selector: Optional[Dict[str, str]] = Field(
636
- default_factory=dict, # type: ignore
637
- serialization_alias="nodeSelector",
638
- )
639
- tolerations: Optional[List[Toleration]] = Field(
640
- default=None, serialization_alias="tolerations"
641
- )
642
- parallelism: Optional[int] = Field(default=None, serialization_alias="parallelism")
542
+ def _handle_failures(
543
+ self,
544
+ working_on: BaseNode,
545
+ dag: Graph,
546
+ task_name: str,
547
+ parent_dag_template: DagTemplate,
548
+ ):
549
+ if working_on._get_on_failure_node():
550
+ # Create a new dag template
551
+ on_failure_dag: DagTemplate = DagTemplate(name=f"on-failure-{task_name}")
552
+ # Add on failure of the current task to be the failure dag template
553
+ on_failure_task = DagTask(
554
+ name=f"on-failure-{task_name}",
555
+ template=f"on-failure-{task_name}",
556
+ depends=task_name + ".Failed",
557
+ )
558
+ # Set failfast of the dag template to be false
559
+ # If not, this branch will never be invoked
560
+ parent_dag_template.fail_fast = False
643
561
 
644
- # TODO: This has to be user driven
645
- pod_gc: Dict[str, str] = Field( # type ignore
646
- default={"strategy": "OnPodSuccess", "deleteDelayDuration": "600s"},
647
- serialization_alias="podGC",
648
- )
562
+ assert parent_dag_template.dag
649
563
 
650
- retry_strategy: Retry = Field(default=Retry(), serialization_alias="retryStrategy")
651
- service_account_name: Optional[str] = Field(
652
- default=None, serialization_alias="serviceAccountName"
653
- )
564
+ parent_dag_template.dag.tasks.append(on_failure_task)
565
+ self._gather_tasks_for_dag_template(
566
+ on_failure_dag,
567
+ dag=dag,
568
+ start_at=working_on._get_on_failure_node(),
569
+ )
654
570
 
655
- templates: List[Union[DagTemplate, ContainerTemplate]] = Field(default_factory=list)
656
- template_defaults: Optional[TemplateDefaults] = Field(
657
- default=None, serialization_alias="templateDefaults"
658
- )
571
+ # For the future me:
572
+ # - A task can output an array: in this case, it's the fan out.
573
+ # - We are using withParam and arguments of the map template to send that value in
574
+ # - The map template should receive that value as a parameter into the template.
575
+ # - The task then start to use it as inputs.parameters.iterate-on
659
576
 
660
- arguments: Optional[List[EnvVar]] = Field(default_factory=list) # type: ignore
661
- persistent_volumes: List[UserVolumeMounts] = Field(
662
- default_factory=list, exclude=True
663
- )
577
+ def _gather_tasks_for_dag_template(
578
+ self,
579
+ dag_template: DagTemplate,
580
+ dag: Graph,
581
+ start_at: str,
582
+ parameters: Optional[list[Parameter]] = None,
583
+ ):
584
+ current_node: str = start_at
585
+ depends: str = ""
664
586
 
665
- @field_validator("parallelism")
666
- @classmethod
667
- def validate_parallelism(cls, parallelism: Optional[int]) -> Optional[int]:
668
- if parallelism is not None and parallelism <= 0:
669
- raise ValueError("Parallelism must be a positive integer greater than 0")
670
- return parallelism
587
+ dag_template.dag = CoreDagTemplate()
671
588
 
672
- @computed_field # type: ignore
673
- @property
674
- def volumes(self) -> List[Volume]:
675
- volumes: List[Volume] = []
676
- claim_names = {}
677
- for i, user_volume in enumerate(self.persistent_volumes):
678
- if user_volume.name in claim_names:
679
- raise Exception(f"Duplicate claim name {user_volume.name}")
680
- claim_names[user_volume.name] = user_volume.name
681
-
682
- volume = Volume(
683
- name=f"executor-{i}",
684
- claim=user_volume.name,
685
- mount_path=user_volume.mount_path,
589
+ while True:
590
+ # Create the dag task for the parent dag
591
+ working_on: BaseNode = dag.get_node_by_name(current_node)
592
+ task_name = self.sanitize_name(working_on.internal_name)
593
+ current_task = DagTask(
594
+ name=task_name,
595
+ template=task_name,
596
+ depends=depends if not depends else depends + ".Succeeded",
597
+ arguments=Arguments(
598
+ parameters=[
599
+ Parameter(
600
+ name=param.name,
601
+ value=f"{{{{inputs.parameters.{param.name}}}}}",
602
+ )
603
+ for param in parameters or []
604
+ ]
605
+ ),
686
606
  )
687
- volumes.append(volume)
688
- return volumes
607
+ dag_template.dag.tasks.append(current_task)
608
+ depends = task_name
609
+
610
+ match working_on.node_type:
611
+ case "task" | "success" | "stub":
612
+ template_of_container = self._create_container_template(
613
+ working_on,
614
+ task_name=task_name,
615
+ inputs=Inputs(parameters=parameters),
616
+ )
617
+ assert template_of_container.container is not None
689
618
 
690
- @field_serializer("arguments", when_used="unless-none")
691
- def reshape_arguments(
692
- self, arguments: List[EnvVar], _info
693
- ) -> Dict[str, List[EnvVar]]:
694
- return {"parameters": arguments}
619
+ if working_on.node_type == "task":
620
+ self._expose_secrets_to_task(
621
+ working_on,
622
+ container_template=template_of_container.container,
623
+ )
695
624
 
625
+ self._templates.append(template_of_container)
696
626
 
697
- class Workflow(BaseModel):
698
- api_version: str = Field(
699
- default="argoproj.io/v1alpha1",
700
- serialization_alias="apiVersion",
701
- )
702
- kind: str = "Workflow"
703
- metadata: MetaData = Field(default=MetaData())
704
- spec: Spec
627
+ case "map" | "parallel":
628
+ assert isinstance(working_on, MapNode) or isinstance(
629
+ working_on, ParallelNode
630
+ )
631
+ node_type = working_on.node_type
705
632
 
633
+ composite_template: DagTemplate = DagTemplate(
634
+ name=task_name, fail_fast=False
635
+ )
706
636
 
707
- class Override(BaseModel):
708
- model_config = ConfigDict(extra="ignore")
637
+ # Add the fan out task
638
+ fan_out_task = DagTask(
639
+ name=f"{task_name}-fan-out",
640
+ template=f"{task_name}-fan-out",
641
+ arguments=Arguments(parameters=parameters),
642
+ )
643
+ composite_template.dag.tasks.append(fan_out_task)
644
+ self._create_fan_templates(
645
+ node=working_on,
646
+ mode="out",
647
+ parameters=parameters,
648
+ task_name=task_name,
649
+ )
709
650
 
710
- image: str
711
- tolerations: Optional[List[Toleration]] = Field(default=None)
651
+ # Add the composite task
652
+ with_param = None
653
+ added_parameters = parameters or []
654
+ branches = {}
655
+ if node_type == "map":
656
+ # If the node is map, we need to handle iterate_as and iterate_on
657
+ assert isinstance(working_on, MapNode)
658
+ added_parameters = added_parameters + [
659
+ Parameter(name=working_on.iterate_as, value="{{item}}")
660
+ ]
661
+ with_param = f"{{{{tasks.{task_name}-fan-out.outputs.parameters.iterate-on}}}}"
662
+
663
+ branches["branch"] = working_on.branch
664
+ elif node_type == "parallel":
665
+ assert isinstance(working_on, ParallelNode)
666
+ branches = working_on.branches
667
+ else:
668
+ raise ValueError("Invalid node type")
669
+
670
+ fan_in_depends = ""
671
+
672
+ for name, branch in branches.items():
673
+ name = (
674
+ name.replace(" ", "-").replace(".", "-").replace("_", "-")
675
+ )
712
676
 
713
- max_step_duration_in_seconds: int = Field(
714
- default=2 * 60 * 60, # 2 hours
715
- gt=0,
716
- )
677
+ branch_task = DagTask(
678
+ name=f"{task_name}-{name}",
679
+ template=f"{task_name}-{name}",
680
+ depends=f"{task_name}-fan-out.Succeeded",
681
+ arguments=Arguments(parameters=added_parameters),
682
+ with_param=with_param,
683
+ )
684
+ composite_template.dag.tasks.append(branch_task)
685
+
686
+ branch_template = DagTemplate(
687
+ name=branch_task.name,
688
+ inputs=Inputs(
689
+ parameters=[
690
+ Parameter(name=param.name, value=None)
691
+ for param in added_parameters
692
+ ]
693
+ ),
694
+ )
717
695
 
718
- node_selector: Optional[Dict[str, str]] = Field(
719
- default=None,
720
- serialization_alias="nodeSelector",
721
- )
696
+ self._gather_tasks_for_dag_template(
697
+ dag_template=branch_template,
698
+ dag=branch,
699
+ start_at=branch.start_at,
700
+ parameters=added_parameters,
701
+ )
722
702
 
723
- parallelism: Optional[int] = Field(
724
- default=None,
725
- serialization_alias="parallelism",
726
- )
703
+ fan_in_depends += f"{branch_task.name}.Succeeded || {branch_task.name}.Failed || "
727
704
 
728
- resources: Resources = Field(
729
- default=Resources(),
730
- serialization_alias="resources",
731
- )
705
+ fan_in_task = DagTask(
706
+ name=f"{task_name}-fan-in",
707
+ template=f"{task_name}-fan-in",
708
+ depends=fan_in_depends.strip(" || "),
709
+ arguments=Arguments(parameters=parameters),
710
+ )
732
711
 
733
- image_pull_policy: str = Field(default="")
712
+ composite_template.dag.tasks.append(fan_in_task)
713
+ self._create_fan_templates(
714
+ node=working_on,
715
+ mode="in",
716
+ parameters=parameters,
717
+ task_name=task_name,
718
+ )
734
719
 
735
- retry_strategy: Retry = Field(
736
- default=Retry(),
737
- serialization_alias="retryStrategy",
738
- description="Common across all templates",
739
- )
720
+ self._templates.append(composite_template)
740
721
 
741
- @field_validator("parallelism")
742
- @classmethod
743
- def validate_parallelism(cls, parallelism: Optional[int]) -> Optional[int]:
744
- if parallelism is not None and parallelism <= 0:
745
- raise ValueError("Parallelism must be a positive integer greater than 0")
746
- return parallelism
722
+ self._handle_failures(
723
+ working_on,
724
+ dag,
725
+ task_name,
726
+ parent_dag_template=dag_template,
727
+ )
747
728
 
729
+ if working_on.node_type == "success" or working_on.node_type == "fail":
730
+ break
748
731
 
749
- class ArgoExecutor(GenericPipelineExecutor):
750
- service_name: str = "argo"
751
- _is_local: bool = False
732
+ current_node = working_on._get_next_node()
752
733
 
753
- # TODO: Add logging level as option.
734
+ self._templates.append(dag_template)
754
735
 
755
- model_config = ConfigDict(extra="forbid")
736
+ def execute_graph(
737
+ self,
738
+ dag: Graph,
739
+ map_variable: dict[str, str | int | float] | None = None,
740
+ **kwargs,
741
+ ):
742
+ # All the arguments set at the spec level can be referred to as "{{workflow.parameters.*}}"
743
+ # We want to use that functionality to override the parameters at the task level
744
+ # We should be careful to override them only at the first task.
745
+ arguments = [] # Can be updated in the UI
746
+ if self.expose_parameters_as_inputs:
747
+ for key, value in self._get_parameters().items():
748
+ value = value.get_value() # type: ignore
749
+ if isinstance(value, dict) or isinstance(value, list):
750
+ continue
756
751
 
757
- image: str
758
- expose_parameters_as_inputs: bool = True
759
- secrets_from_k8s: List[SecretEnvVar] = Field(default_factory=list)
760
- output_file: str = "argo-pipeline.yaml"
752
+ parameter = Parameter(name=key, value=value) # type: ignore
753
+ arguments.append(parameter)
761
754
 
762
- # Metadata related fields
763
- name: str = Field(
764
- default="runnable-dag-", description="Used as an identifier for the workflow"
765
- )
766
- annotations: Dict[str, str] = Field(default_factory=dict)
767
- labels: Dict[str, str] = Field(default_factory=dict)
755
+ run_id_var = Parameter(name="run_id", value="{{workflow.uid}}")
756
+ log_level_var = Parameter(name="log_level", value=self.log_level)
757
+ arguments.append(run_id_var)
758
+ arguments.append(log_level_var)
759
+ self.argo_workflow.spec.arguments = Arguments(parameters=arguments)
760
+
761
+ # This is the entry point of the argo execution
762
+ runnable_dag: DagTemplate = DagTemplate(name="runnable-dag")
763
+
764
+ self._gather_tasks_for_dag_template(
765
+ runnable_dag,
766
+ dag,
767
+ start_at=dag.start_at,
768
+ parameters=[],
769
+ )
770
+
771
+ argo_workflow_dump = self.argo_workflow.model_dump(
772
+ by_alias=True,
773
+ exclude={
774
+ "spec": {
775
+ "template_defaults": {"image_pull_policy", "image", "resources"}
776
+ }
777
+ },
778
+ exclude_none=True,
779
+ round_trip=False,
780
+ )
781
+ argo_workflow_dump["spec"]["templates"] = [
782
+ template.model_dump(
783
+ by_alias=True,
784
+ exclude_none=True,
785
+ )
786
+ for template in self._templates
787
+ ]
768
788
 
769
- max_workflow_duration_in_seconds: int = Field(
770
- 2 * 24 * 60 * 60, # 2 days
771
- serialization_alias="activeDeadlineSeconds",
772
- gt=0,
773
- )
774
- node_selector: Optional[Dict[str, str]] = Field(
775
- default=None,
776
- serialization_alias="nodeSelector",
777
- )
778
- parallelism: Optional[int] = Field(
779
- default=None,
780
- serialization_alias="parallelism",
781
- )
782
- resources: Resources = Field(
783
- default=Resources(),
784
- serialization_alias="resources",
785
- exclude=True,
786
- )
787
- retry_strategy: Retry = Field(
788
- default=Retry(),
789
- serialization_alias="retryStrategy",
790
- description="Common across all templates",
791
- )
792
- max_step_duration_in_seconds: int = Field(
793
- default=2 * 60 * 60, # 2 hours
794
- gt=0,
795
- )
796
- tolerations: Optional[List[Toleration]] = Field(default=None)
797
- image_pull_policy: str = Field(default="")
798
- service_account_name: Optional[str] = None
799
- persistent_volumes: List[UserVolumeMounts] = Field(default_factory=list)
800
-
801
- _run_id_placeholder: str = "{{workflow.parameters.run_id}}"
802
- _log_level: str = "{{workflow.parameters.log_level}}"
803
- _container_templates: List[ContainerTemplate] = []
804
- _dag_templates: List[DagTemplate] = []
805
- _clean_names: Dict[str, str] = {}
806
- _container_volumes: List[ContainerVolume] = []
807
-
808
- @field_validator("parallelism")
809
- @classmethod
810
- def validate_parallelism(cls, parallelism: Optional[int]) -> Optional[int]:
811
- if parallelism is not None and parallelism <= 0:
812
- raise ValueError("Parallelism must be a positive integer greater than 0")
813
- return parallelism
789
+ argo_workflow_dump["spec"]["volumes"] = [
790
+ volume_pair.volume.model_dump(by_alias=True)
791
+ for volume_pair in self.volume_pairs
792
+ ]
814
793
 
815
- @computed_field # type: ignore
816
- @property
817
- def step_timeout(self) -> int:
818
- """
819
- Maximum time the step can take to complete, including the pending state.
820
- """
821
- return (
822
- self.max_step_duration_in_seconds + 2 * 60 * 60
823
- ) # 2 hours + max_step_duration_in_seconds
794
+ yaml = YAML()
795
+ with open(self.output_file, "w") as f:
796
+ yaml.indent(mapping=2, sequence=4, offset=2)
797
+ yaml.dump(
798
+ argo_workflow_dump,
799
+ f,
800
+ )
824
801
 
825
- @property
826
- def metadata(self) -> MetaData:
827
- return MetaData(
828
- generate_name=self.name,
829
- annotations=self.annotations,
830
- labels=self.labels,
802
+ def _implicitly_fail(self, node: BaseNode, map_variable: TypeMapVariable):
803
+ assert self._context.dag
804
+ _, current_branch = search_node_by_internal_name(
805
+ dag=self._context.dag, internal_name=node.internal_name
831
806
  )
832
-
833
- @property
834
- def spec(self) -> Spec:
835
- return Spec(
836
- active_deadline_seconds=self.max_workflow_duration_in_seconds,
837
- node_selector=self.node_selector,
838
- tolerations=self.tolerations,
839
- parallelism=self.parallelism,
840
- retry_strategy=self.retry_strategy,
841
- service_account_name=self.service_account_name,
842
- persistent_volumes=self.persistent_volumes,
843
- template_defaults=TemplateDefaults(
844
- max_step_duration=self.max_step_duration_in_seconds
845
- ),
807
+ _, next_node_name = self._get_status_and_next_node_name(
808
+ node, current_branch, map_variable=map_variable
846
809
  )
810
+ if next_node_name:
811
+ # Terminal nodes do not have next node name
812
+ next_node = current_branch.get_node_by_name(next_node_name)
847
813
 
848
- # TODO: This has to move to execute_node?
849
- def prepare_for_execution(self):
850
- """
851
- Perform any modifications to the services prior to execution of the node.
852
-
853
- Args:
854
- node (Node): [description]
855
- map_variable (dict, optional): [description]. Defaults to None.
856
- """
857
-
858
- self._set_up_run_log(exists_ok=True)
814
+ if next_node.node_type == defaults.FAIL:
815
+ self.execute_node(next_node, map_variable=map_variable)
859
816
 
860
817
  def execute_node(
861
- self, node: BaseNode, map_variable: TypeMapVariable = None, **kwargs
818
+ self,
819
+ node: BaseNode,
820
+ map_variable: dict[str, str | int | float] | None = None,
821
+ **kwargs,
862
822
  ):
823
+ error_on_existing_run_id = os.environ.get("error_on_existing_run_id", "false")
824
+ exists_ok = error_on_existing_run_id == "false"
825
+
826
+ self._use_volumes()
827
+ self._set_up_run_log(exists_ok=exists_ok)
828
+
863
829
  step_log = self._context.run_log_store.create_step_log(
864
830
  node.name, node._get_step_log_name(map_variable)
865
831
  )
@@ -870,36 +836,30 @@ class ArgoExecutor(GenericPipelineExecutor):
870
836
  step_log.status = defaults.PROCESSING
871
837
  self._context.run_log_store.add_step_log(step_log, self._context.run_id)
872
838
 
873
- super()._execute_node(node, map_variable=map_variable, **kwargs)
874
-
875
- # Implicit fail
876
- if self._context.dag:
877
- # functions and notebooks do not have dags
878
- _, current_branch = search_node_by_internal_name(
879
- dag=self._context.dag, internal_name=node.internal_name
880
- )
881
- _, next_node_name = self._get_status_and_next_node_name(
882
- node, current_branch, map_variable=map_variable
883
- )
884
- if next_node_name:
885
- # Terminal nodes do not have next node name
886
- next_node = current_branch.get_node_by_name(next_node_name)
887
-
888
- if next_node.node_type == defaults.FAIL:
889
- self.execute_node(next_node, map_variable=map_variable)
839
+ self._execute_node(node=node, map_variable=map_variable, **kwargs)
890
840
 
841
+ # Raise exception if the step failed
891
842
  step_log = self._context.run_log_store.get_step_log(
892
843
  node._get_step_log_name(map_variable), self._context.run_id
893
844
  )
894
845
  if step_log.status == defaults.FAIL:
895
846
  raise Exception(f"Step {node.name} failed")
896
847
 
848
+ self._implicitly_fail(node, map_variable)
849
+
897
850
  def fan_out(self, node: BaseNode, map_variable: TypeMapVariable = None):
851
+ # This could be the first step of the graph
852
+ self._use_volumes()
853
+
854
+ error_on_existing_run_id = os.environ.get("error_on_existing_run_id", "false")
855
+ exists_ok = error_on_existing_run_id == "false"
856
+ self._set_up_run_log(exists_ok=exists_ok)
857
+
898
858
  super().fan_out(node, map_variable)
899
859
 
900
860
  # If its a map node, write the list values to "/tmp/output.txt"
901
861
  if node.node_type == "map":
902
- node = cast(MapNode, node)
862
+ assert isinstance(node, MapNode)
903
863
  iterate_on = self._context.run_log_store.get_parameters(
904
864
  self._context.run_id
905
865
  )[node.iterate_on]
@@ -907,401 +867,55 @@ class ArgoExecutor(GenericPipelineExecutor):
907
867
  with open("/tmp/output.txt", mode="w", encoding="utf-8") as myfile:
908
868
  json.dump(iterate_on.get_value(), myfile, indent=4)
909
869
 
910
- def sanitize_name(self, name):
911
- return name.replace(" ", "-").replace(".", "-").replace("_", "-")
912
-
913
- def get_clean_name(self, node: BaseNode):
914
- # Cache names for the node
915
- if node.internal_name not in self._clean_names:
916
- sanitized = self.sanitize_name(node.name)
917
- tag = "".join(random.choices(string.ascii_lowercase + string.digits, k=6))
918
- self._clean_names[node.internal_name] = (
919
- f"{sanitized}-{node.node_type}-{tag}"
920
- )
921
-
922
- return self._clean_names[node.internal_name]
923
-
924
- def compose_map_variable(
925
- self, list_of_iter_values: Optional[List] = None
926
- ) -> TypeMapVariable:
927
- map_variable = OrderedDict()
928
-
929
- # If we are inside a map node, compose a map_variable
930
- # The values of "iterate_as" are sent over as inputs to the container template
931
- if list_of_iter_values:
932
- for var in list_of_iter_values:
933
- map_variable[var] = "{{inputs.parameters." + str(var) + "}}"
934
-
935
- return map_variable # type: ignore
936
-
937
- def create_container_template(
938
- self,
939
- working_on: BaseNode,
940
- command: str,
941
- inputs: Optional[List] = None,
942
- outputs: Optional[List] = None,
943
- overwrite_name: str = "",
944
- ):
945
- effective_node_config = self._resolve_executor_config(working_on)
946
-
947
- override: Override = Override(**effective_node_config)
948
-
949
- container = Container(
950
- command=command,
951
- image=override.image,
952
- volume_mounts=self._container_volumes,
953
- image_pull_policy=override.image_pull_policy,
954
- resources=override.resources,
955
- secrets_from_k8s=self.secrets_from_k8s,
956
- )
957
-
958
- if (
959
- working_on.name == self._context.dag.start_at
960
- and self.expose_parameters_as_inputs
961
- ):
962
- for key, value in self._get_parameters().items():
963
- value = value.get_value() # type: ignore
964
- # Get the value from workflow parameters for dynamic behavior
965
- if (
966
- isinstance(value, int)
967
- or isinstance(value, float)
968
- or isinstance(value, str)
969
- ):
970
- env_var = EnvVar(
971
- name=defaults.PARAMETER_PREFIX + key,
972
- value="{{workflow.parameters." + key + "}}",
973
- )
974
- container.env_vars.append(env_var)
975
-
976
- clean_name = self.get_clean_name(working_on)
977
- if overwrite_name:
978
- clean_name = overwrite_name
979
-
980
- container_template = ContainerTemplate(
981
- name=clean_name,
982
- active_deadline_seconds=(
983
- override.max_step_duration_in_seconds
984
- if self.max_step_duration_in_seconds
985
- != override.max_step_duration_in_seconds
986
- else None
987
- ),
988
- container=container,
989
- retry_strategy=override.retry_strategy
990
- if self.retry_strategy != override.retry_strategy
991
- else None,
992
- tolerations=override.tolerations
993
- if self.tolerations != override.tolerations
994
- else None,
995
- node_selector=override.node_selector
996
- if self.node_selector != override.node_selector
997
- else None,
998
- )
999
-
1000
- # inputs are the "iterate_as" value map variables in the same order as they are observed
1001
- # We need to expose the map variables in the command of the container
1002
- if inputs:
1003
- if not container_template.inputs:
1004
- container_template.inputs = []
1005
- container_template.inputs.extend(inputs)
870
+ def fan_in(self, node: BaseNode, map_variable: TypeMapVariable = None):
871
+ self._use_volumes()
872
+ super().fan_in(node, map_variable)
873
+
874
+ def _use_volumes(self):
875
+ match self._context.run_log_store.service_name:
876
+ case "file-system":
877
+ self._context.run_log_store.log_folder = self._container_log_location
878
+ case "chunked-fs":
879
+ self._context.run_log_store.log_folder = self._container_log_location
880
+
881
+ match self._context.catalog_handler.service_name:
882
+ case "file-system":
883
+ self._context.catalog_handler.catalog_location = (
884
+ self._container_catalog_location
885
+ )
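A minimal sketch (not from the package) of the remapping _use_volumes performs, using stand-in classes; the container paths below are assumptions standing in for the executor's _container_log_location and _container_catalog_location:

from dataclasses import dataclass

@dataclass
class StubRunLogStore:
    service_name: str
    log_folder: str = ".run_log_store"

@dataclass
class StubCatalog:
    service_name: str
    catalog_location: str = ".catalog"

def use_volumes(run_log_store, catalog,
                container_log_location="/tmp/run_logs/",
                container_catalog_location="/tmp/catalog/"):
    # File-system backed services are pointed at the volume mounted inside the pod.
    match run_log_store.service_name:
        case "file-system" | "chunked-fs":
            run_log_store.log_folder = container_log_location
    match catalog.service_name:
        case "file-system":
            catalog.catalog_location = container_catalog_location

store, catalog = StubRunLogStore("chunked-fs"), StubCatalog("file-system")
use_volumes(store, catalog)
print(store.log_folder, catalog.catalog_location)  # /tmp/run_logs/ /tmp/catalog/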
1006
886
 
1007
- # The map step fan out would create an output that we should propagate via Argo
1008
- if outputs:
1009
- if not container_template.outputs:
1010
- container_template.outputs = []
1011
- container_template.outputs.extend(outputs)
887
+ @cached_property
888
+ def volume_pairs(self) -> list[VolumePair]:
889
+ volume_pairs: list[VolumePair] = []
1012
890
 
1013
- return container_template
1014
-
1015
- def _create_fan_out_template(
1016
- self, composite_node, list_of_iter_values: Optional[List] = None
1017
- ):
1018
- clean_name = self.get_clean_name(composite_node)
1019
- inputs = []
1020
- # If we are fanning out from within an existing map state, we need to pass the map variable in
1021
- # The container template should also accept an input parameter
1022
- map_variable = None
1023
- if list_of_iter_values:
1024
- map_variable = self.compose_map_variable(
1025
- list_of_iter_values=list_of_iter_values
891
+ if self.pvc_for_runnable:
892
+ common_volume = Volume(
893
+ name="runnable",
894
+ persistent_volume_claim=PersistentVolumeClaimSource(
895
+ claim_name=self.pvc_for_runnable
896
+ ),
1026
897
  )
1027
-
1028
- for val in list_of_iter_values:
1029
- inputs.append(Parameter(name=val))
1030
-
1031
- command = utils.get_fan_command(
1032
- mode="out",
1033
- node=composite_node,
1034
- run_id=self._run_id_placeholder,
1035
- map_variable=map_variable,
1036
- log_level=self._log_level,
1037
- )
1038
-
1039
- outputs = []
1040
- # If the node is a map node, we have to set the output parameters
1041
- # Output is always the step's internal name + iterate-on
1042
- if composite_node.node_type == "map":
1043
- output_parameter = OutputParameter(name="iterate-on")
1044
- outputs.append(output_parameter)
1045
-
1046
- # Create the node now
1047
- step_config = {"command": command, "type": "task", "next": "dummy"}
1048
- node = create_node(name=f"{clean_name}-fan-out", step_config=step_config)
1049
-
1050
- container_template = self.create_container_template(
1051
- working_on=node,
1052
- command=command,
1053
- outputs=outputs,
1054
- inputs=inputs,
1055
- overwrite_name=f"{clean_name}-fan-out",
1056
- )
1057
-
1058
- self._container_templates.append(container_template)
1059
- return DagTaskTemplate(
1060
- name=f"{clean_name}-fan-out", template=f"{clean_name}-fan-out"
1061
- )
1062
-
1063
- def _create_fan_in_template(
1064
- self, composite_node, list_of_iter_values: Optional[List] = None
1065
- ):
1066
- clean_name = self.get_clean_name(composite_node)
1067
- inputs = []
1068
- # If we are fanning in from within an existing map state, we need to pass the map variable in
1069
- # The container template should also accept an input parameter
1070
- map_variable = None
1071
- if list_of_iter_values:
1072
- map_variable = self.compose_map_variable(
1073
- list_of_iter_values=list_of_iter_values
898
+ common_volume_mount = VolumeMount(
899
+ name="runnable",
900
+ mount_path="/tmp",
1074
901
  )
1075
-
1076
- for val in list_of_iter_values:
1077
- inputs.append(Parameter(name=val))
1078
-
1079
- command = utils.get_fan_command(
1080
- mode="in",
1081
- node=composite_node,
1082
- run_id=self._run_id_placeholder,
1083
- map_variable=map_variable,
1084
- log_level=self._log_level,
1085
- )
1086
-
1087
- step_config = {"command": command, "type": "task", "next": "dummy"}
1088
- node = create_node(name=f"{clean_name}-fan-in", step_config=step_config)
1089
- container_template = self.create_container_template(
1090
- working_on=node,
1091
- command=command,
1092
- inputs=inputs,
1093
- overwrite_name=f"{clean_name}-fan-in",
1094
- )
1095
- self._container_templates.append(container_template)
1096
- clean_name = self.get_clean_name(composite_node)
1097
- return DagTaskTemplate(
1098
- name=f"{clean_name}-fan-in", template=f"{clean_name}-fan-in"
1099
- )
1100
-
1101
- def _gather_task_templates_of_dag(
1102
- self,
1103
- dag: Graph,
1104
- dag_name="runnable-dag",
1105
- list_of_iter_values: Optional[List] = None,
1106
- ):
1107
- current_node = dag.start_at
1108
- previous_node = None
1109
- previous_node_template_name = None
1110
-
1111
- templates: Dict[str, DagTaskTemplate] = {}
1112
-
1113
- if not list_of_iter_values:
1114
- list_of_iter_values = []
1115
-
1116
- while True:
1117
- working_on = dag.get_node_by_name(current_node)
1118
- if previous_node == current_node:
1119
- raise Exception("Potentially running in an infinite loop")
1120
-
1121
- render_obj = get_renderer(working_on)(executor=self, node=working_on)
1122
- render_obj.render(list_of_iter_values=list_of_iter_values.copy())
1123
-
1124
- clean_name = self.get_clean_name(working_on)
1125
-
1126
- # If a task template for clean name exists, retrieve it (could have been created by on_failure)
1127
- template = templates.get(
1128
- clean_name, DagTaskTemplate(name=clean_name, template=clean_name)
902
+ volume_pairs.append(
903
+ VolumePair(volume=common_volume, volume_mount=common_volume_mount)
1129
904
  )
1130
-
1131
- # Link the current node to previous node, if the previous node was successful.
1132
- if previous_node:
1133
- template.depends.append(f"{previous_node_template_name}.Succeeded")
1134
-
1135
- templates[clean_name] = template
1136
-
1137
- # On failure nodes
1138
- if (
1139
- working_on.node_type not in ["success", "fail"]
1140
- and working_on._get_on_failure_node()
1141
- ):
1142
- failure_node = dag.get_node_by_name(working_on._get_on_failure_node())
1143
-
1144
- # same logic, if a template exists, retrieve it
1145
- # if not, create a new one
1146
- render_obj = get_renderer(working_on)(executor=self, node=failure_node)
1147
- render_obj.render(list_of_iter_values=list_of_iter_values.copy())
1148
-
1149
- failure_template_name = self.get_clean_name(failure_node)
1150
- # If a task template for clean name exists, retrieve it
1151
- failure_template = templates.get(
1152
- failure_template_name,
1153
- DagTaskTemplate(
1154
- name=failure_template_name, template=failure_template_name
905
+ counter = 0
906
+ for custom_volume in self.custom_volumes or []:
907
+ name = f"custom-volume-{counter}"
908
+ volume_pairs.append(
909
+ VolumePair(
910
+ volume=Volume(
911
+ name=name,
912
+ persistent_volume_claim=custom_volume.persistent_volume_claim,
913
+ ),
914
+ volume_mount=VolumeMount(
915
+ name=name,
916
+ mount_path=custom_volume.mount_path,
1155
917
  ),
1156
918
  )
1157
- failure_template.depends.append(f"{clean_name}.Failed")
1158
- templates[failure_template_name] = failure_template
1159
-
1160
- # If we are in a map node, we need to add the values as arguments
1161
- template = templates[clean_name]
1162
- if list_of_iter_values:
1163
- if not template.arguments:
1164
- template.arguments = []
1165
- for value in list_of_iter_values:
1166
- template.arguments.append(
1167
- Argument(
1168
- name=value, value="{{inputs.parameters." + value + "}}"
1169
- )
1170
- )
1171
-
1172
- # Move ahead to the next node
1173
- previous_node = current_node
1174
- previous_node_template_name = self.get_clean_name(working_on)
1175
-
1176
- if working_on.node_type in ["success", "fail"]:
1177
- break
1178
-
1179
- current_node = working_on._get_next_node()
1180
-
1181
- # Add the iteration values as input to dag template
1182
- dag_template = DagTemplate(tasks=list(templates.values()), name=dag_name)
1183
- if list_of_iter_values:
1184
- if not dag_template.inputs:
1185
- dag_template.inputs = []
1186
- dag_template.inputs.extend(
1187
- [Parameter(name=val) for val in list_of_iter_values]
1188
- )
1189
-
1190
- # Add the dag template to the list of templates
1191
- self._dag_templates.append(dag_template)
1192
-
1193
- def _get_template_defaults(self) -> TemplateDefaults:
1194
- user_provided_config = self.model_dump(by_alias=False)
1195
-
1196
- return TemplateDefaults(**user_provided_config)
1197
-
1198
- def execute_graph(self, dag: Graph, map_variable: Optional[dict] = None, **kwargs):
1199
- # TODO: Add metadata
1200
- arguments = []
1201
- # Expose "simple" parameters as workflow arguments for dynamic behavior
1202
- if self.expose_parameters_as_inputs:
1203
- for key, value in self._get_parameters().items():
1204
- value = value.get_value() # type: ignore
1205
- if isinstance(value, dict) or isinstance(value, list):
1206
- continue
1207
-
1208
- env_var = EnvVar(name=key, value=value) # type: ignore
1209
- arguments.append(env_var)
1210
-
1211
- run_id_var = EnvVar(name="run_id", value="{{workflow.uid}}")
1212
- log_level_var = EnvVar(name="log_level", value=defaults.LOG_LEVEL)
1213
- arguments.append(run_id_var)
1214
- arguments.append(log_level_var)
1215
-
1216
- # TODO: Can we do reruns?
1217
-
1218
- for volume in self.spec.volumes:
1219
- self._container_volumes.append(
1220
- ContainerVolume(name=volume.name, mount_path=volume.mount_path)
1221
- )
1222
-
1223
- # Container specifications are globally collected and added at the end.
1224
- # Dag specifications are added as part of the dag traversal.
1225
- templates: List[Union[DagTemplate, ContainerTemplate]] = []
1226
- self._gather_task_templates_of_dag(dag=dag, list_of_iter_values=[])
1227
- templates.extend(self._dag_templates)
1228
- templates.extend(self._container_templates)
1229
-
1230
- spec = self.spec
1231
- spec.templates = templates
1232
- spec.arguments = arguments
1233
- workflow = Workflow(metadata=self.metadata, spec=spec)
1234
-
1235
- yaml = YAML()
1236
- with open(self.output_file, "w") as f:
1237
- yaml.indent(mapping=2, sequence=4, offset=2)
1238
-
1239
- yaml.dump(workflow.model_dump(by_alias=True, exclude_none=True), f)
1240
-
1241
- def send_return_code(self, stage="traversal"):
1242
- """
1243
- Convenience function used by pipeline to send return code to the caller of the cli
1244
-
1245
- Raises:
1246
- Exception: If the pipeline execution failed
1247
- """
1248
- if (
1249
- stage != "traversal"
1250
- ): # traversal does no actual execution, so return code is pointless
1251
- run_id = self._context.run_id
1252
-
1253
- run_log = self._context.run_log_store.get_run_log_by_id(
1254
- run_id=run_id, full=False
1255
919
  )
1256
- if run_log.status == defaults.FAIL:
1257
- raise exceptions.ExecutionFailedError(run_id)
1258
-
1259
-
1260
- # TODO:
1261
- # class FileSystemRunLogStore(BaseIntegration):
1262
- # """
1263
- # Only local execution mode is possible for Buffered Run Log store
1264
- # """
1265
-
1266
- # executor_type = "argo"
1267
- # service_type = "run_log_store" # One of secret, catalog, datastore
1268
- # service_provider = "file-system" # The actual implementation of the service
1269
-
1270
- # def validate(self, **kwargs):
1271
- # msg = (
1272
- # "Argo cannot run work with file-system run log store. "
1273
- # "Unless you have made a mechanism to use volume mounts."
1274
- # "Using this run log store if the pipeline has concurrent tasks might lead to unexpected results"
1275
- # )
1276
- # logger.warning(msg)
1277
-
1278
-
1279
- # class ChunkedFileSystemRunLogStore(BaseIntegration):
1280
- # """
1281
- # Only local execution mode is possible for Buffered Run Log store
1282
- # """
1283
-
1284
- # executor_type = "argo"
1285
- # service_type = "run_log_store" # One of secret, catalog, datastore
1286
- # service_provider = "chunked-fs" # The actual implementation of the service
1287
-
1288
- # def validate(self, **kwargs):
1289
- # msg = (
1290
- # "Argo cannot run work with chunked file-system run log store. "
1291
- # "Unless you have made a mechanism to use volume mounts"
1292
- # )
1293
- # logger.warning(msg)
1294
-
1295
-
1296
- # class FileSystemCatalog(BaseIntegration):
1297
- # """
1298
- # Only local execution mode is possible for Buffered Run Log store
1299
- # """
1300
-
1301
- # executor_type = "argo"
1302
- # service_type = "catalog" # One of secret, catalog, datastore
1303
- # service_provider = "file-system" # The actual implementation of the service
1304
-
1305
- # def validate(self, **kwargs):
1306
- # msg = "Argo cannot run work with file-system run log store. Unless you have made a mechanism to use volume mounts"
1307
- # logger.warning(msg)
920
+ counter += 1
921
+ return volume_pairs
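A rough sketch (under stated assumptions, not the package's implementation) of the pairing volume_pairs builds: one shared pair for the runnable PVC mounted at /tmp, plus one pair per custom volume. Stand-in dataclasses replace the executor's Pydantic models, and the claim names and mount paths in the example are invented:

from dataclasses import dataclass
from typing import Optional

@dataclass
class StubVolume:
    name: str
    claim_name: str

@dataclass
class StubVolumeMount:
    name: str
    mount_path: str

@dataclass
class StubVolumePair:
    volume: StubVolume
    volume_mount: StubVolumeMount

def build_volume_pairs(pvc_for_runnable: Optional[str],
                       custom_volumes: Optional[list[dict]]) -> list[StubVolumePair]:
    pairs: list[StubVolumePair] = []
    if pvc_for_runnable:
        # The shared claim is always mounted at /tmp, matching the diff above.
        pairs.append(StubVolumePair(StubVolume("runnable", pvc_for_runnable),
                                    StubVolumeMount("runnable", "/tmp")))
    for counter, custom in enumerate(custom_volumes or []):
        name = f"custom-volume-{counter}"
        pairs.append(StubVolumePair(StubVolume(name, custom["claim_name"]),
                                    StubVolumeMount(name, custom["mount_path"])))
    return pairs

# Example with invented claim names and mount paths:
print(build_volume_pairs("runnable-pvc", [{"claim_name": "data-pvc", "mount_path": "/data"}]))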