runnable-0.50.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- extensions/README.md +0 -0
- extensions/__init__.py +0 -0
- extensions/catalog/README.md +0 -0
- extensions/catalog/any_path.py +214 -0
- extensions/catalog/file_system.py +52 -0
- extensions/catalog/minio.py +72 -0
- extensions/catalog/pyproject.toml +14 -0
- extensions/catalog/s3.py +11 -0
- extensions/job_executor/README.md +0 -0
- extensions/job_executor/__init__.py +236 -0
- extensions/job_executor/emulate.py +70 -0
- extensions/job_executor/k8s.py +553 -0
- extensions/job_executor/k8s_job_spec.yaml +37 -0
- extensions/job_executor/local.py +35 -0
- extensions/job_executor/local_container.py +161 -0
- extensions/job_executor/pyproject.toml +16 -0
- extensions/nodes/README.md +0 -0
- extensions/nodes/__init__.py +0 -0
- extensions/nodes/conditional.py +301 -0
- extensions/nodes/fail.py +78 -0
- extensions/nodes/loop.py +394 -0
- extensions/nodes/map.py +477 -0
- extensions/nodes/parallel.py +281 -0
- extensions/nodes/pyproject.toml +15 -0
- extensions/nodes/stub.py +93 -0
- extensions/nodes/success.py +78 -0
- extensions/nodes/task.py +156 -0
- extensions/pipeline_executor/README.md +0 -0
- extensions/pipeline_executor/__init__.py +871 -0
- extensions/pipeline_executor/argo.py +1266 -0
- extensions/pipeline_executor/emulate.py +119 -0
- extensions/pipeline_executor/local.py +226 -0
- extensions/pipeline_executor/local_container.py +369 -0
- extensions/pipeline_executor/mocked.py +159 -0
- extensions/pipeline_executor/pyproject.toml +16 -0
- extensions/run_log_store/README.md +0 -0
- extensions/run_log_store/__init__.py +0 -0
- extensions/run_log_store/any_path.py +100 -0
- extensions/run_log_store/chunked_fs.py +122 -0
- extensions/run_log_store/chunked_minio.py +141 -0
- extensions/run_log_store/file_system.py +91 -0
- extensions/run_log_store/generic_chunked.py +549 -0
- extensions/run_log_store/minio.py +114 -0
- extensions/run_log_store/pyproject.toml +15 -0
- extensions/secrets/README.md +0 -0
- extensions/secrets/dotenv.py +62 -0
- extensions/secrets/pyproject.toml +15 -0
- runnable/__init__.py +108 -0
- runnable/catalog.py +141 -0
- runnable/cli.py +484 -0
- runnable/context.py +730 -0
- runnable/datastore.py +1058 -0
- runnable/defaults.py +159 -0
- runnable/entrypoints.py +390 -0
- runnable/exceptions.py +137 -0
- runnable/executor.py +561 -0
- runnable/gantt.py +1646 -0
- runnable/graph.py +501 -0
- runnable/names.py +546 -0
- runnable/nodes.py +593 -0
- runnable/parameters.py +217 -0
- runnable/pickler.py +96 -0
- runnable/sdk.py +1277 -0
- runnable/secrets.py +92 -0
- runnable/tasks.py +1268 -0
- runnable/telemetry.py +142 -0
- runnable/utils.py +423 -0
- runnable-0.50.0.dist-info/METADATA +189 -0
- runnable-0.50.0.dist-info/RECORD +72 -0
- runnable-0.50.0.dist-info/WHEEL +4 -0
- runnable-0.50.0.dist-info/entry_points.txt +53 -0
- runnable-0.50.0.dist-info/licenses/LICENSE +201 -0
extensions/pipeline_executor/argo.py
@@ -0,0 +1,1266 @@
+import json
+import os
+import random
+import secrets
+import shlex
+import string
+from collections import namedtuple
+from enum import Enum
+from functools import cached_property
+from typing import Annotated, Any, Literal, Optional, cast
+
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    PlainSerializer,
+    PrivateAttr,
+    computed_field,
+    model_validator,
+)
+from pydantic.alias_generators import to_camel
+from ruamel.yaml import YAML
+
+from extensions.nodes.conditional import ConditionalNode
+from extensions.nodes.map import MapNode
+from extensions.nodes.parallel import ParallelNode
+from extensions.nodes.task import TaskNode
+from extensions.pipeline_executor import GenericPipelineExecutor
+from runnable import defaults, exceptions
+from runnable.datastore import StepAttempt
+from runnable.defaults import IterableParameterModel, MapVariableModel
+from runnable.graph import Graph, search_node_by_internal_name
+from runnable.nodes import BaseNode
+
+# TODO: Do we need a PVC if we are using remote storage?
+
+
+class BaseModelWIthConfig(BaseModel, use_enum_values=True):
+    model_config = ConfigDict(
+        extra="forbid",
+        alias_generator=to_camel,
+        populate_by_name=True,
+        from_attributes=True,
+        validate_default=True,
+    )
+
+
+class BackOff(BaseModelWIthConfig):
+    duration: str = Field(default="2m")
+    factor: float = Field(default=2)
+    max_duration: str = Field(default="1h")
+
+
+class RetryStrategy(BaseModelWIthConfig):
+    back_off: Optional[BackOff] = Field(default=None)
+    limit: int = 0
+    retry_policy: str = Field(default="Always")
+
+
+class SecretEnvVar(BaseModel):
+    """
+    Renders:
+    env:
+      - name: MYSECRETPASSWORD
+        valueFrom:
+          secretKeyRef:
+            name: my-secret
+            key: MYSECRETPASSWORD
+    """
+
+    environment_variable: str = Field(serialization_alias="name")
+    secret_name: str = Field(exclude=True)
+    secret_key: str = Field(exclude=True)
+
+    @computed_field  # type: ignore
+    @property
+    def valueFrom(self) -> dict[str, dict[str, str]]:
+        return {
+            "secretKeyRef": {
+                "name": self.secret_name,
+                "key": self.secret_key,
+            }
+        }
+
+
+class EnvVar(BaseModelWIthConfig):
+    name: str
+    value: str
+
+
+class OutputParameter(BaseModelWIthConfig):
+    name: str
+    value_from: dict[str, str] = {
+        "path": "/tmp/output.txt",
+    }
+
+
+class Parameter(BaseModelWIthConfig):
+    name: str
+    value: Optional[str | int | float | bool] = Field(default=None)
+
+
+class Inputs(BaseModelWIthConfig):
+    parameters: Optional[list[Parameter]] = Field(default=None)
+
+
+class Outputs(BaseModelWIthConfig):
+    parameters: Optional[list[OutputParameter]] = Field(default=None)
+
+
+class Arguments(BaseModelWIthConfig):
+    parameters: Optional[list[Parameter]] = Field(default=None)
+
+
+class ConfigMapCache(BaseModelWIthConfig):
+    name: str
+    key: str = Field(default="cache")
+
+
+class Cache(BaseModelWIthConfig):
+    config_map: ConfigMapCache
+
+
+class Memoize(BaseModelWIthConfig):
+    key: str
+    cache: Optional[Cache] = Field(default=None)
+
+
+class TolerationEffect(str, Enum):
+    NoSchedule = "NoSchedule"
+    PreferNoSchedule = "PreferNoSchedule"
+    NoExecute = "NoExecute"
+
+
+class TolerationOperator(str, Enum):
+    Exists = "Exists"
+    Equal = "Equal"
+
+
+class PodMetaData(BaseModelWIthConfig):
+    annotations: dict[str, str] = Field(default_factory=dict)
+    labels: dict[str, str] = Field(default_factory=dict)
+
+
+class Toleration(BaseModelWIthConfig):
+    effect: Optional[TolerationEffect] = Field(default=None)
+    key: Optional[str] = Field(default=None)
+    operator: TolerationOperator = Field(default=TolerationOperator.Equal)
+    tolerationSeconds: Optional[int] = Field(default=None)
+    value: Optional[str] = Field(default=None)
+
+    @model_validator(mode="after")
+    def validate_tolerations(self) -> "Toleration":
+        if not self.key:
+            if self.operator != TolerationOperator.Exists:
+                raise ValueError("Toleration key is required when operator is Equal")
+
+        if self.operator == TolerationOperator.Exists:
+            if self.value:
+                raise ValueError(
+                    "Toleration value is not allowed when operator is Exists"
+                )
+        return self
+
+
+class ImagePullPolicy(str, Enum):
+    Always = "Always"
+    IfNotPresent = "IfNotPresent"
+    Never = "Never"
+
+
+class PersistentVolumeClaimSource(BaseModelWIthConfig):
+    claim_name: str
+    read_only: bool = Field(default=False)
+
+
+class Volume(BaseModelWIthConfig):
+    name: str
+    persistent_volume_claim: PersistentVolumeClaimSource
+
+    def __hash__(self):
+        return hash(self.name)
+
+
+class VolumeMount(BaseModelWIthConfig):
+    mount_path: str
+    name: str
+    read_only: bool = Field(default=False)
+
+    @model_validator(mode="after")
+    def validate_volume_mount(self) -> "VolumeMount":
+        if "." in self.mount_path:
+            raise ValueError("mount_path cannot contain '.'")
+
+        return self
+
+
+VolumePair = namedtuple("VolumePair", ["volume", "volume_mount"])
+
+
+class LabelSelectorRequirement(BaseModelWIthConfig):
+    key: str
+    operator: str
+    values: list[str]
+
+
+class PodGCStrategy(str, Enum):
+    OnPodCompletion = "OnPodCompletion"
+    OnPodSuccess = "OnPodSuccess"
+    OnWorkflowCompletion = "OnWorkflowCompletion"
+    OnWorkflowSuccess = "OnWorkflowSuccess"
+
+
+class LabelSelector(BaseModelWIthConfig):
+    matchExpressions: list[LabelSelectorRequirement] = Field(default_factory=list)
+    matchLabels: dict[str, str] = Field(default_factory=dict)
+
+
+class PodGC(BaseModelWIthConfig):
+    delete_delay_duration: str = Field(default="1h")  # 1 hour
+    label_selector: Optional[LabelSelector] = Field(default=None)
+    strategy: Optional[PodGCStrategy] = Field(default=None)
+
+
+class Request(BaseModel):
+    """
+    The default requests
+    """
+
+    memory: str = "1Gi"
+    cpu: str = "250m"
+
+
+VendorGPU = Annotated[
+    Optional[int],
+    PlainSerializer(lambda x: str(x), return_type=str, when_used="unless-none"),
+]
+
+
+class Limit(Request):
+    """
+    The default limits
+    """
+
+    gpu: VendorGPU = Field(default=None, serialization_alias="nvidia.com/gpu")
+
+
+class Resources(BaseModel):
+    limits: Limit = Field(default=Limit(), serialization_alias="limits")
+    requests: Request = Field(default=Request(), serialization_alias="requests")
+
+
+# Lets construct this from UserDefaults
+class ArgoTemplateDefaults(BaseModelWIthConfig):
+    active_deadline_seconds: Optional[int] = Field(default=86400)  # 1 day
+    fail_fast: bool = Field(default=True)
+    node_selector: dict[str, str] = Field(default_factory=dict)
+    parallelism: Optional[int] = Field(default=None)
+    retry_strategy: Optional[RetryStrategy] = Field(default=None)
+    timeout: Optional[str] = Field(default=None)
+    tolerations: Optional[list[Toleration]] = Field(default=None)
+
+    model_config = ConfigDict(
+        extra="ignore",
+    )
+
+
+class CommonDefaults(BaseModelWIthConfig):
+    active_deadline_seconds: Optional[int] = Field(default=86400)  # 1 day
+    fail_fast: bool = Field(default=True)
+    node_selector: dict[str, str] = Field(default_factory=dict)
+    parallelism: Optional[int] = Field(default=None)
+    retry_strategy: Optional[RetryStrategy] = Field(default=None)
+    timeout: Optional[str] = Field(default=None)
+    tolerations: Optional[list[Toleration]] = Field(default=None)
+    image_pull_policy: ImagePullPolicy = Field(default=ImagePullPolicy.Always)
+    resources: Resources = Field(default_factory=Resources)
+    env: list[EnvVar | SecretEnvVar] = Field(default_factory=list)
+
+
+# The user provided defaults at the top level
+class UserDefaults(CommonDefaults):
+    image: str
+
+
+# Overrides need not have image
+class Overrides(CommonDefaults):
+    image: Optional[str] = Field(default=None)
+
+
+# User provides this as part of the argoSpec
+# some an be provided here or as a template default or node override
+class ArgoWorkflowSpec(BaseModelWIthConfig):
+    active_deadline_seconds: int = Field(default=86400)  # 1 day for the whole workflow
+    arguments: Optional[Arguments] = Field(default=None)
+    entrypoint: Literal["runnable-dag"] = Field(default="runnable-dag", frozen=True)
+    node_selector: dict[str, str] = Field(default_factory=dict)
+    parallelism: Optional[int] = Field(default=None)  # GLobal parallelism
+    pod_gc: Optional[PodGC] = Field(default=None, serialization_alias="podGC")
+    retry_strategy: Optional[RetryStrategy] = Field(default=None)
+    service_account_name: Optional[str] = Field(default=None)
+    tolerations: Optional[list[Toleration]] = Field(default=None)
+    template_defaults: Optional[ArgoTemplateDefaults] = Field(default=None)
+
+
+class ArgoMetadata(BaseModelWIthConfig):
+    annotations: Optional[dict[str, str]] = Field(default=None)
+    generate_name: str  # User can mention this to uniquely identify the run
+    labels: dict[str, str] = Field(default_factory=dict)
+    namespace: Optional[str] = Field(default="default")
+
+
+class ArgoWorkflow(BaseModelWIthConfig):
+    apiVersion: Literal["argoproj.io/v1alpha1"] = Field(
+        default="argoproj.io/v1alpha1", frozen=True
+    )
+    kind: Literal["Workflow"] = Field(default="Workflow", frozen=True)
+    metadata: ArgoMetadata
+    spec: ArgoWorkflowSpec
+
+
+class CronSchedule(BaseModelWIthConfig):
+    """Minimal cron schedule configuration for CronWorkflows."""
+
+    schedules: list[str]  # Cron expressions, e.g. ["0 0 * * *"]
+    timezone: Optional[str] = Field(default=None)  # e.g. "America/Los_Angeles"
+
+
+# The below are not visible to the user
+class DagTask(BaseModelWIthConfig):
+    name: str
+    template: str  # Should be name of a container template or dag template
+    arguments: Optional[Arguments] = Field(default=None)
+    with_param: Optional[str] = Field(default=None)
+    when_param: Optional[str] = Field(default=None, serialization_alias="when")
+    depends: Optional[str] = Field(default=None)
+
+
+class CoreDagTemplate(BaseModelWIthConfig):
+    tasks: list[DagTask] = Field(default_factory=list[DagTask])
+
+
+class CoreContainerTemplate(BaseModelWIthConfig):
+    image: str
+    command: list[str]
+    image_pull_policy: ImagePullPolicy = Field(default=ImagePullPolicy.IfNotPresent)
+    env: list[EnvVar | SecretEnvVar] = Field(default_factory=list)
+    volume_mounts: list[VolumeMount] = Field(default_factory=list)
+    resources: Resources = Field(default_factory=Resources)
+
+
+class DagTemplate(BaseModelWIthConfig):
+    name: str
+    dag: CoreDagTemplate = Field(default_factory=CoreDagTemplate)
+    inputs: Optional[Inputs] = Field(default=None)
+    parallelism: Optional[int] = Field(default=None)  # Not sure if this is needed
+    fail_fast: bool = Field(default=True)
+
+    model_config = ConfigDict(
+        extra="ignore",
+    )
+
+    def __hash__(self):
+        return hash(self.name)
+
+
+class ContainerTemplate((BaseModelWIthConfig)):
+    name: str
+    container: CoreContainerTemplate
+    inputs: Optional[Inputs] = Field(default=None)
+    outputs: Optional[Outputs] = Field(default=None)
+    memoize: Optional[Memoize] = Field(default=None)
+
+    active_deadline_seconds: Optional[int] = Field(default=86400)  # 1 day
+    metadata: Optional[PodMetaData] = Field(default=None)
+    node_selector: dict[str, str] = Field(default_factory=dict)
+    parallelism: Optional[int] = Field(default=None)  # Not sure if this is needed
+    retry_strategy: Optional[RetryStrategy] = Field(default=None)
+    timeout: Optional[str] = Field(default=None)
+    tolerations: Optional[list[Toleration]] = Field(default=None)
+    volumes: Optional[list[Volume]] = Field(default=None)
+
+    model_config = ConfigDict(
+        extra="ignore",
+    )
+
+    def __hash__(self):
+        return hash(self.name)
+
+
+class CustomVolume(BaseModelWIthConfig):
+    mount_path: str
+    persistent_volume_claim: PersistentVolumeClaimSource
+
+
+class ArgoExecutor(GenericPipelineExecutor):
+    """
+    Executes the pipeline using Argo Workflows.
+
+    The defaults configuration is kept similar to the
+    [Argo Workflow spec](https://argo-workflows.readthedocs.io/en/latest/fields/#workflow).
+
+    Configuration:
+
+    ```yaml
+    pipeline-executor:
+      type: argo
+      config:
+        pvc_for_runnable: "my-pvc"
+        custom_volumes:
+          - mount_path: "/tmp"
+            persistent_volume_claim:
+              claim_name: "my-pvc"
+              read_only: false/true
+        expose_parameters_as_inputs: true/false
+        configmap_cache_name: "my-cache-name"  # Optional: defaults to runnable-xxxxxx
+        secrets_from_k8s:
+          - key1
+          - key2
+          - ...
+        output_file: "argo-pipeline.yaml"
+        log_level: "DEBUG"/"INFO"/"WARNING"/"ERROR"/"CRITICAL"
+        cron_schedule:  # Optional: generates CronWorkflow instead of Workflow
+          schedules:
+            - "0 0 * * *"  # Cron expressions
+          timezone: "UTC"  # Optional timezone
+        defaults:
+          image: "my-image"
+          activeDeadlineSeconds: 86400
+          failFast: true
+          nodeSelector:
+            label: value
+          parallelism: 1
+          retryStrategy:
+            backoff:
+              duration: "2m"
+              factor: 2
+              maxDuration: "1h"
+            limit: 0
+            retryPolicy: "Always"
+          timeout: "1h"
+          tolerations:
+          imagePullPolicy: "Always"/"IfNotPresent"/"Never"
+          resources:
+            limits:
+              memory: "1Gi"
+              cpu: "250m"
+              gpu: 0
+            requests:
+              memory: "1Gi"
+              cpu: "250m"
+          env:
+            - name: "MY_ENV"
+              value: "my-value"
+            - name: secret_env
+              secretName: "my-secret"
+              secretKey: "my-key"
+        overrides:
+          key1:
+            ... similar structure to defaults
+
+        argoWorkflow:
+          metadata:
+            annotations:
+              key1: value1
+              key2: value2
+            generateName: "my-workflow"
+            labels:
+              key1: value1
+
+    ```
+
+    As of now, ```runnable``` needs a pvc to store the logs and the catalog; provided by ```pvc_for_runnable```.
+    - ```custom_volumes``` can be used to mount additional volumes to the container.
+
+    - ```expose_parameters_as_inputs``` can be used to expose the initial parameters as inputs to the workflow.
+    - ```secrets_from_k8s``` can be used to expose the secrets from the k8s secret store.
+    - ```output_file``` is the file where the argo pipeline will be dumped.
+    - ```log_level``` is the log level for the containers.
+    - ```cron_schedule``` generates an Argo CronWorkflow instead of a regular Workflow for scheduled execution.
+    - ```defaults``` is the default configuration for all the containers.
+
+
+    """
+
+    service_name: str = "argo"
+    _should_setup_run_log_at_traversal: bool = PrivateAttr(default=False)
+    mock: bool = False
+
+    model_config = ConfigDict(
+        extra="forbid",
+        alias_generator=to_camel,
+        populate_by_name=True,
+        from_attributes=True,
+        use_enum_values=True,
+    )
+    pvc_for_runnable: Optional[str] = Field(default=None)
+    custom_volumes: Optional[list[CustomVolume]] = Field(
+        default_factory=list[CustomVolume]
+    )
+
+    expose_parameters_as_inputs: bool = True
+    configmap_cache_name: Optional[str] = Field(default=None)
+    secret_from_k8s: Optional[str] = Field(default=None)
+    output_file: str = Field(default="argo-pipeline.yaml")
+    log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field(
+        default="INFO"
+    )
+
+    defaults: UserDefaults
+    argo_workflow: ArgoWorkflow
+    cron_schedule: Optional[CronSchedule] = Field(default=None)
+
+    overrides: dict[str, Overrides] = Field(default_factory=dict)
+
+    # This should be used when we refer to run_id or log_level in the containers
+    _run_id_as_parameter: str = PrivateAttr(default="{{workflow.parameters.run_id}}")
+    _log_level_as_parameter: str = PrivateAttr(
+        default="{{workflow.parameters.log_level}}"
+    )
+
+    _templates: list[ContainerTemplate | DagTemplate] = PrivateAttr(
+        default_factory=list
+    )
+    _container_log_location: str = PrivateAttr(default="/tmp/run_logs/")
+    _container_catalog_location: str = PrivateAttr(default="/tmp/catalog/")
+    _added_initial_container: bool = PrivateAttr(default=False)
+    _cache_name: str = PrivateAttr(default="")
+
+    def model_post_init(self, __context: Any) -> None:
+        self.argo_workflow.spec.template_defaults = ArgoTemplateDefaults(
+            **self.defaults.model_dump()
+        )
+        self._cache_name = self._generate_cache_name()
+
+    def _generate_cache_name(self) -> str:
+        """Generate or use configured ConfigMap name for this workflow's cache."""
+        if self.configmap_cache_name:
+            return self.configmap_cache_name
+
+        chars = string.ascii_lowercase + string.digits
+        suffix = "".join(secrets.choice(chars) for _ in range(6))
+        return f"runnable-{suffix}"
+
+    @property
+    def cache_name(self) -> str:
+        """Get the ConfigMap name for this workflow's memoization cache."""
+        return self._cache_name
+
+    def sanitize_name(self, name: str) -> str:
+        formatted_name = name.replace(" ", "-").replace(".", "-").replace("_", "-")
+        tag = "".join(random.choices(string.ascii_lowercase + string.digits, k=6))
+        unique_name = f"{formatted_name}-{tag}"
+        unique_name = unique_name.replace("map-variable-placeholder-", "")
+        return unique_name
+
+    def _add_retry_env_vars(self, container_template: CoreContainerTemplate):
+        """Add retry environment variables to all containers."""
+        # Add retry run id environment variable
+        retry_run_id_env = EnvVar(
+            name=defaults.RETRY_RUN_ID, value="{{workflow.parameters.retry_run_id}}"
+        )
+        container_template.env.append(retry_run_id_env)
+
+        # Add retry indicator environment variable
+        retry_indicator_env = EnvVar(
+            name=defaults.RETRY_INDICATOR,
+            value="{{workflow.parameters.retry_indicator}}",
+        )
+        container_template.env.append(retry_indicator_env)
+
+    def _set_up_initial_container(self, container_template: CoreContainerTemplate):
+        if self._added_initial_container:
+            return
+
+        parameters: list[Parameter] = []
+
+        if self.argo_workflow.spec.arguments:
+            parameters = self.argo_workflow.spec.arguments.parameters or []
+
+        for parameter in parameters or []:
+            key, _ = parameter.name, parameter.value
+            env_var = EnvVar(
+                name=defaults.PARAMETER_PREFIX + key,
+                value="{{workflow.parameters." + key + "}}",
+            )
+            container_template.env.append(env_var)
+
+        env_var = EnvVar(name="error_on_existing_run_id", value="true")
+        container_template.env.append(env_var)
+
+        # After the first container is added, set the added_initial_container to True
+        self._added_initial_container = True
+
+    def _create_fan_templates(
+        self,
+        node: BaseNode,
+        mode: str,
+        parameters: Optional[list[Parameter]],
+        task_name: str,
+    ):
+        iter_variable: IterableParameterModel = IterableParameterModel()
+        for parameter in parameters or []:
+            iter_variable.map_variable[parameter.name] = (  # type: ignore
+                "{{inputs.parameters." + str(parameter.name) + "}}"
+            )
+
+        fan_command = self._context.get_fan_command(
+            mode=mode,
+            node=node,
+            run_id=self._run_id_as_parameter,
+            iter_variable=iter_variable,
+        )
+
+        core_container_template = CoreContainerTemplate(
+            command=shlex.split(fan_command),
+            image=self.defaults.image,
+            image_pull_policy=self.defaults.image_pull_policy,
+            volume_mounts=[
+                volume_pair.volume_mount for volume_pair in self.volume_pairs
+            ],
+        )
+
+        # Add retry environment variables to all containers
+        self._add_retry_env_vars(container_template=core_container_template)
+
+        # Either a task or a fan-out can the first container
+        self._set_up_initial_container(container_template=core_container_template)
+
+        task_name += f"-fan-{mode}"
+
+        outputs: Optional[Outputs] = None
+        if mode == "out" and node.node_type == "map":
+            outputs = Outputs(parameters=[OutputParameter(name="iterate-on")])
+        if mode == "out" and node.node_type == "conditional":
+            outputs = Outputs(parameters=[OutputParameter(name="case")])
+
+        config_map_key = (
+            "{{workflow.parameters.run_id}}-"
+            + f"{task_name}-"
+            + "".join(random.choices(string.ascii_lowercase + string.digits, k=4))
+        )
+
+        container_template = ContainerTemplate(
+            name=task_name,
+            container=core_container_template,
+            inputs=Inputs(parameters=parameters),
+            outputs=outputs,
+            memoize=Memoize(
+                key=config_map_key,
+                cache=Cache(config_map=ConfigMapCache(name=self.cache_name)),
+            ),
+            active_deadline_seconds=self.defaults.active_deadline_seconds,
+            node_selector=self.defaults.node_selector,
+            parallelism=self.defaults.parallelism,
+            retry_strategy=self.defaults.retry_strategy,
+            timeout=self.defaults.timeout,
+            tolerations=self.defaults.tolerations,
+            volumes=[volume_pair.volume for volume_pair in self.volume_pairs],
+        )
+
+        self._templates.append(container_template)
+
+    def _create_container_template(
+        self,
+        node: BaseNode,
+        task_name: str,
+        inputs: Optional[Inputs] = None,
+    ) -> ContainerTemplate:
+        assert node.node_type in ["task", "success", "stub", "fail"]
+
+        node_override = None
+        if hasattr(node, "overrides"):
+            override_key = node.overrides.get(self.service_name, "")
+            try:
+                node_override = self.overrides.get(override_key)
+            except:  # noqa
+                raise Exception("Override not found for: ", override_key)
+
+        effective_settings = self.defaults.model_dump()
+        if node_override:
+            effective_settings.update(node_override.model_dump(exclude_none=True))
+
+        inputs = inputs or Inputs(parameters=[])
+
+        # Should look like: '{"map_variable":{"chunk":{"value":3}},"loop_variable":[]}'
+        iter_variable: IterableParameterModel = IterableParameterModel()
+        for parameter in inputs.parameters or []:
+            map_variable = MapVariableModel(
+                value=f"{{{{inputs.parameters.{str(parameter.name)}}}}}"
+            )
+            iter_variable.map_variable[parameter.name] = (  # type: ignore
+                map_variable
+            )
+
+        # command = "runnable execute-single-node"
+        command = self._context.get_node_callable_command(
+            node=node,
+            iter_variable=iter_variable,
+            over_write_run_id=self._run_id_as_parameter,
+            log_level=self._log_level_as_parameter,
+        )
+
+        core_container_template = CoreContainerTemplate(
+            command=shlex.split(command),
+            image=effective_settings["image"],
+            image_pull_policy=effective_settings["image_pull_policy"],
+            resources=effective_settings["resources"],
+            volume_mounts=[
+                volume_pair.volume_mount for volume_pair in self.volume_pairs
+            ],
+        )
+
+        # Add retry environment variables to all containers
+        self._add_retry_env_vars(container_template=core_container_template)
+
+        self._set_up_initial_container(container_template=core_container_template)
+        self._expose_secrets_to_task(
+            working_on=node, container_template=core_container_template
+        )
+        self._set_env_vars_to_task(node, core_container_template)
+        config_map_key = (
+            "{{workflow.parameters.run_id}}-"
+            + f"{task_name}-"
+            + "".join(random.choices(string.ascii_lowercase + string.digits, k=4))
+        )
+
+        container_template = ContainerTemplate(
+            name=task_name,
+            container=core_container_template,
+            inputs=Inputs(
+                parameters=[
+                    Parameter(name=param.name) for param in inputs.parameters or []
+                ]
+            ),
+            memoize=Memoize(
+                key=config_map_key,
+                cache=Cache(config_map=ConfigMapCache(name=self.cache_name)),
+            ),
+            volumes=[volume_pair.volume for volume_pair in self.volume_pairs],
+            **node_override.model_dump() if node_override else {},
+        )
+
+        if iter_variable and iter_variable.map_variable:
+            # Do not cache map tasks as they have different inputs each time
+            container_template.memoize = None
+
+        return container_template
+
+    def _set_env_vars_to_task(
+        self, working_on: BaseNode, container_template: CoreContainerTemplate
+    ):
+        global_envs: dict[str, str] = {}
+
+        # Apply defaults environment variables to all node types
+        for env_var in self.defaults.env:
+            env_var = cast(EnvVar, env_var)
+            global_envs[env_var.name] = env_var.value
+
+        # Apply node-specific overrides only for task nodes that support overrides
+        if working_on.node_type in ["task"] and hasattr(working_on, "overrides"):
+            override_key = working_on.overrides.get(self.service_name, "")
+            node_override = self.overrides.get(override_key, None)
+
+            # Update the global envs with the node overrides
+            if node_override:
+                for env_var in node_override.env:
+                    env_var = cast(EnvVar, env_var)
+                    global_envs[env_var.name] = env_var.value
+
+        for key, value in global_envs.items():
+            env_var_to_add = EnvVar(name=key, value=value)
+            container_template.env.append(env_var_to_add)
+
+        # Add argo uid as environment variable
+        argo_uid_env = EnvVar(
+            name="RUNNABLE_CODE_ID_ARGO_WORKFLOW_UID",
+            value="{{workflow.uid}}",
+        )
+        container_template.env.append(argo_uid_env)
+
+    def _expose_secrets_to_task(
+        self,
+        working_on: BaseNode,
+        container_template: CoreContainerTemplate,
+    ):
+        if not isinstance(working_on, TaskNode):
+            return
+        secrets = working_on.executable.secrets
+        for secret in secrets:
+            assert self.secret_from_k8s is not None
+            secret_env_var = SecretEnvVar(
+                environment_variable=secret,
+                secret_name=self.secret_from_k8s,  # This has to be exposed from config
+                secret_key=secret,
+            )
+            container_template.env.append(secret_env_var)
+
+    def _handle_failures(
+        self,
+        working_on: BaseNode,
+        dag: Graph,
+        task_name: str,
+        parent_dag_template: DagTemplate,
+    ):
+        if working_on._get_on_failure_node():
+            # Create a new dag template
+            on_failure_dag: DagTemplate = DagTemplate(name=f"on-failure-{task_name}")
+            # Add on failure of the current task to be the failure dag template
+            on_failure_task = DagTask(
+                name=f"on-failure-{task_name}",
+                template=f"on-failure-{task_name}",
+                depends=task_name + ".Failed",
+            )
+            # Set failfast of the dag template to be false
+            # If not, this branch will never be invoked
+            parent_dag_template.fail_fast = False
+
+            assert parent_dag_template.dag
+
+            parent_dag_template.dag.tasks.append(on_failure_task)
+
+            self._gather_tasks_for_dag_template(
+                on_failure_dag,
+                dag=dag,
+                start_at=working_on._get_on_failure_node(),
+            )
+
+    # For the future me:
+    # - A task can output a array: in this case, its the fan out.
+    # - We are using withParam and arguments of the map template to send that value in
+    # - The map template should receive that value as a parameter into the template.
+    # - The task then start to use it as inputs.parameters.iterate-on
+    # the when param should be an evaluation
+
+    def _gather_tasks_for_dag_template(
+        self,
+        dag_template: DagTemplate,
+        dag: Graph,
+        start_at: str,
+        parameters: Optional[list[Parameter]] = None,
+    ):
+        current_node: str = start_at
+        depends: str = ""
+
+        dag_template.dag = CoreDagTemplate()
+
+        while True:
+            # Create the dag task with for the parent dag
+            working_on: BaseNode = dag.get_node_by_name(current_node)
+            task_name = self.sanitize_name(working_on.internal_name)
+            current_task = DagTask(
+                name=task_name,
+                template=task_name,
+                depends=depends if not depends else depends + ".Succeeded",
+                arguments=Arguments(
+                    parameters=[
+                        Parameter(
+                            name=param.name,
+                            value=f"{{{{inputs.parameters.{param.name}}}}}",
+                        )
+                        for param in parameters or []
+                    ]
+                ),
+            )
+            dag_template.dag.tasks.append(current_task)
+            depends = task_name
+
+            match working_on.node_type:
+                case "task" | "success" | "stub" | "fail":
+                    template_of_container = self._create_container_template(
+                        working_on,
+                        task_name=task_name,
+                        inputs=Inputs(parameters=parameters),
+                    )
+                    assert template_of_container.container is not None
+
+                    self._templates.append(template_of_container)
+
+                case "map" | "parallel" | "conditional":
+                    assert (
+                        isinstance(working_on, MapNode)
+                        or isinstance(working_on, ParallelNode)
+                        or isinstance(working_on, ConditionalNode)
+                    )
+                    node_type = working_on.node_type
+
+                    composite_template: DagTemplate = DagTemplate(
+                        name=task_name, fail_fast=False
+                    )
+
+                    # Add the fan out task
+                    fan_out_task = DagTask(
+                        name=f"{task_name}-fan-out",
+                        template=f"{task_name}-fan-out",
+                        arguments=Arguments(parameters=parameters),
+                    )
+                    composite_template.dag.tasks.append(fan_out_task)
+                    self._create_fan_templates(
+                        node=working_on,
+                        mode="out",
+                        parameters=parameters,
+                        task_name=task_name,
+                    )
+
+                    # Add the composite task
+                    with_param: Optional[str] = None
+                    when_param: Optional[str] = None
+                    added_parameters = parameters or []
+                    branches = {}
+                    if node_type == "map":
+                        # If the node is map, we need to handle the iterate as and on
+                        assert isinstance(working_on, MapNode)
+                        added_parameters = added_parameters + [
+                            Parameter(name=working_on.iterate_as, value="{{item}}")
+                        ]
+                        with_param = f"{{{{tasks.{task_name}-fan-out.outputs.parameters.iterate-on}}}}"
+
+                        branches["branch"] = working_on.branch
+                    elif node_type == "parallel":
+                        assert isinstance(working_on, ParallelNode)
+                        branches = working_on.branches
+                    elif node_type == "conditional":
+                        assert isinstance(working_on, ConditionalNode)
+                        branches = working_on.branches
+                        when_param = (
+                            f"{{{{tasks.{task_name}-fan-out.outputs.parameters.case}}}}"
+                        )
+                    else:
+                        raise ValueError("Invalid node type")
+
+                    fan_in_depends = ""
+
+                    for name, branch in branches.items():
+                        match_when = branch.internal_branch_name.split(".")[-1]
+                        name = (
+                            name.replace(" ", "-").replace(".", "-").replace("_", "-")
+                        )
+
+                        if node_type == "conditional":
+                            assert isinstance(working_on, ConditionalNode)
+                            when_param = f"'{match_when}' == {{{{tasks.{task_name}-fan-out.outputs.parameters.case}}}}"
+
+                        branch_task = DagTask(
+                            name=f"{task_name}-{name}",
+                            template=f"{task_name}-{name}",
+                            depends=f"{task_name}-fan-out.Succeeded",
+                            arguments=Arguments(parameters=added_parameters),
+                            with_param=with_param,
+                            when_param=when_param,
+                        )
+                        composite_template.dag.tasks.append(branch_task)
+
+                        branch_template = DagTemplate(
+                            name=branch_task.name,
+                            inputs=Inputs(
+                                parameters=[
+                                    Parameter(name=param.name, value=None)
+                                    for param in added_parameters
+                                ]
+                            ),
+                        )
+
+                        assert isinstance(branch, Graph)
+
+                        self._gather_tasks_for_dag_template(
+                            dag_template=branch_template,
+                            dag=branch,
+                            start_at=branch.start_at,
+                            parameters=added_parameters,
+                        )
+
+                        fan_in_depends += f"{branch_task.name}.Succeeded || {branch_task.name}.Failed || "
+
+                    fan_in_task = DagTask(
+                        name=f"{task_name}-fan-in",
+                        template=f"{task_name}-fan-in",
+                        depends=fan_in_depends.strip(" || "),
+                        arguments=Arguments(parameters=parameters),
+                    )
+
+                    composite_template.dag.tasks.append(fan_in_task)
+                    self._create_fan_templates(
+                        node=working_on,
+                        mode="in",
+                        parameters=parameters,
+                        task_name=task_name,
+                    )
+
+                    self._templates.append(composite_template)
+
+            self._handle_failures(
+                working_on,
+                dag,
+                task_name,
+                parent_dag_template=dag_template,
+            )
+
+            if working_on.node_type == "success" or working_on.node_type == "fail":
+                break
+
+            current_node = working_on._get_next_node()
+
+        self._templates.append(dag_template)
+
+    def execute_graph(
+        self,
+        dag: Graph,
+        iter_variable: Optional[IterableParameterModel] = None,
+    ):
+        # All the arguments set at the spec level can be referred as "{{workflow.parameters.*}}"
+        # We want to use that functionality to override the parameters at the task level
+        # We should be careful to override them only at the first task.
+        arguments = []  # Can be updated in the UI
+        if self.expose_parameters_as_inputs:
+            for key, value in self._get_parameters().items():
+                value = value.get_value()  # type: ignore
+                if isinstance(value, dict) or isinstance(value, list):
+                    continue
+
+                parameter = Parameter(name=key, value=value)  # type: ignore
+                arguments.append(parameter)
+
+        # run_id parameter - required at workflow submission time (no default value)
+        run_id_var = Parameter(name="run_id")
+        arguments.append(run_id_var)
+
+        # Optional retry parameters with empty string defaults
+        retry_run_id_var = Parameter(name="retry_run_id", value="")
+        retry_indicator_var = Parameter(name="retry_indicator", value="")
+        arguments.append(retry_run_id_var)
+        arguments.append(retry_indicator_var)
+
+        log_level_var = Parameter(name="log_level", value=self.log_level)
+        arguments.append(log_level_var)
+        self.argo_workflow.spec.arguments = Arguments(parameters=arguments)
+
+        # This is the entry point of the argo execution
+        runnable_dag: DagTemplate = DagTemplate(name="runnable-dag")
+
+        self._gather_tasks_for_dag_template(
+            runnable_dag,
+            dag,
+            start_at=dag.start_at,
+            parameters=[],
+        )
+
+        argo_workflow_dump = self.argo_workflow.model_dump(
+            by_alias=True,
+            exclude_none=True,
+            round_trip=False,
+        )
+        argo_workflow_dump["spec"]["templates"] = [
+            template.model_dump(
+                by_alias=True,
+                exclude_none=True,
+            )
+            for template in self._templates
+        ]
+
+        argo_workflow_dump["spec"]["volumes"] = [
+            volume_pair.volume.model_dump(by_alias=True)
+            for volume_pair in self.volume_pairs
+        ]
+
+        # If cron_schedule is set, wrap in CronWorkflow
+        if self.cron_schedule:
+            output_dump = self._wrap_as_cron_workflow(argo_workflow_dump)
+        else:
+            output_dump = argo_workflow_dump
+
+        yaml = YAML()
+        with open(self.output_file, "w") as f:
+            yaml.indent(mapping=2, sequence=4, offset=2)
+            yaml.dump(
+                output_dump,
+                f,
+            )
+
+    def _wrap_as_cron_workflow(self, workflow_dump: dict) -> dict:
+        """Wrap a Workflow dump as a CronWorkflow."""
+        assert self.cron_schedule is not None
+
+        # Extract metadata and convert generateName to name
+        metadata = workflow_dump["metadata"].copy()
+        generate_name = metadata.pop("generateName", "runnable-cron-")
+        metadata["name"] = generate_name.rstrip("-")
+
+        # Build CronWorkflow spec
+        cron_spec: dict[str, Any] = {
+            "schedules": self.cron_schedule.schedules,
+            "workflowSpec": workflow_dump["spec"],
+        }
+
+        if self.cron_schedule.timezone:
+            cron_spec["timezone"] = self.cron_schedule.timezone
+
+        return {
+            "apiVersion": "argoproj.io/v1alpha1",
+            "kind": "CronWorkflow",
+            "metadata": metadata,
+            "spec": cron_spec,
+        }
+
+    def _implicitly_fail(
+        self,
+        node: BaseNode,
+        iter_variable: Optional[IterableParameterModel] = None,
+    ):
+        assert self._context.dag
+        _, current_branch = search_node_by_internal_name(
+            dag=self._context.dag, internal_name=node.internal_name
+        )
+        _, next_node_name = self._get_status_and_next_node_name(
+            node, current_branch, iter_variable=iter_variable
+        )
+        if next_node_name:
+            # Terminal nodes do not have next node name
+            next_node = current_branch.get_node_by_name(next_node_name)
+
+            if next_node.node_type == defaults.FAIL:
+                self.execute_node(next_node, iter_variable=iter_variable)
+
+    def add_code_identities(self, node: BaseNode, attempt_log: StepAttempt):
+        super().add_code_identities(node, attempt_log)
+
+        workflow_uid = os.getenv("RUNNABLE_CODE_ID_ARGO_WORKFLOW_UID")
+
+        if workflow_uid:
+            code_id = self._context.run_log_store.create_code_identity()
+
+            code_id.code_identifier = workflow_uid
+            code_id.code_identifier_type = "argo"
+            code_id.code_identifier_dependable = True
+            code_id.code_identifier_url = "argo workflow"
+            attempt_log.code_identities.append(code_id)
+
+    def execute_node(
+        self,
+        node: BaseNode,
+        iter_variable: Optional[IterableParameterModel] = None,
+    ):
+        error_on_existing_run_id = os.environ.get("error_on_existing_run_id", "false")
+        exists_ok = error_on_existing_run_id == "false"
+
+        self._use_volumes()
+        self._set_up_run_log(exists_ok=exists_ok)
+
+        try:
+            # This should only happen during a retry
+            step_log = self._context.run_log_store.get_step_log(
+                node._get_step_log_name(iter_variable), self._context.run_id
+            )
+            assert self._context.is_retry
+        except exceptions.StepLogNotFoundError:
+            step_log = self._context.run_log_store.create_step_log(
+                node.name, node._get_step_log_name(iter_variable)
+            )
+
+        step_log.step_type = node.node_type
+        step_log.status = defaults.PROCESSING
+        self._context.run_log_store.add_step_log(step_log, self._context.run_id)
+
+        self._execute_node(node=node, iter_variable=iter_variable)
+
+        # Raise exception if the step failed
+        step_log = self._context.run_log_store.get_step_log(
+            node._get_step_log_name(iter_variable), self._context.run_id
+        )
+        if step_log.status == defaults.FAIL:
+            run_log = self._context.run_log_store.get_run_log_by_id(
+                self._context.run_id
+            )
+            run_log.status = defaults.FAIL
+            self._context.run_log_store.put_run_log(run_log)
+            raise Exception(f"Step {node.name} failed")
+
+        # This makes the fail node execute if we are heading that way.
+        self._implicitly_fail(node, iter_variable)
+
+    def fan_out(
+        self,
+        node: BaseNode,
+        iter_variable: Optional[IterableParameterModel] = None,
+    ):
+        # This could be the first step of the graph
+        self._use_volumes()
+
+        error_on_existing_run_id = os.environ.get("error_on_existing_run_id", "false")
+        exists_ok = error_on_existing_run_id == "false"
+        self._set_up_run_log(exists_ok=exists_ok)
+
+        super().fan_out(node, iter_variable)
+
+        # If its a map node, write the list values to "/tmp/output.txt"
+        if node.node_type == "map":
+            assert isinstance(node, MapNode)
+            iterate_on = self._context.run_log_store.get_parameters(
+                self._context.run_id
+            )[node.iterate_on]
+
+            with open("/tmp/output.txt", mode="w", encoding="utf-8") as myfile:
+                json.dump(iterate_on.get_value(), myfile, indent=4)
+
+        if node.node_type == "conditional":
+            assert isinstance(node, ConditionalNode)
+
+            with open("/tmp/output.txt", mode="w", encoding="utf-8") as myfile:
+                json.dump(node.get_parameter_value(), myfile, indent=4)
+
+    def fan_in(
+        self,
+        node: BaseNode,
+        iter_variable: Optional[IterableParameterModel] = None,
+    ):
+        self._use_volumes()
+        super().fan_in(node, iter_variable)
+
+    def _use_volumes(self):
+        match self._context.run_log_store.service_name:
+            case "file-system":
+                self._context.run_log_store.log_folder = self._container_log_location
+            case "chunked-fs":
+                self._context.run_log_store.log_folder = self._container_log_location
+
+        match self._context.catalog.service_name:
+            case "file-system":
+                self._context.catalog.catalog_location = (
+                    self._container_catalog_location
+                )
+
+    @cached_property
+    def volume_pairs(self) -> list[VolumePair]:
+        volume_pairs: list[VolumePair] = []
+
+        if self.pvc_for_runnable:
+            common_volume = Volume(
+                name="runnable",
+                persistent_volume_claim=PersistentVolumeClaimSource(
+                    claim_name=self.pvc_for_runnable
+                ),
+            )
+            common_volume_mount = VolumeMount(
+                name="runnable",
+                mount_path="/tmp",
+            )
+            volume_pairs.append(
+                VolumePair(volume=common_volume, volume_mount=common_volume_mount)
+            )
+        counter = 0
+        for custom_volume in self.custom_volumes or []:
+            name = f"custom-volume-{counter}"
+            volume_pairs.append(
+                VolumePair(
+                    volume=Volume(
+                        name=name,
+                        persistent_volume_claim=custom_volume.persistent_volume_claim,
+                    ),
+                    volume_mount=VolumeMount(
+                        name=name,
+                        mount_path=custom_volume.mount_path,
+                    ),
+                )
+            )
+            counter += 1
+        return volume_pairs
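
For orientation, here is a minimal sketch (not part of the package) of the Workflow manifest that `execute_graph` above serializes to `output_file`, assuming the configuration shown in the class docstring and a single task step. The random name suffixes and the exact container command are illustrative assumptions, not values taken from the source.

```yaml
# Editor's sketch of the generated argo-pipeline.yaml (illustrative names).
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
  generateName: "my-workflow"
  namespace: default
spec:
  entrypoint: runnable-dag
  activeDeadlineSeconds: 86400
  arguments:
    parameters:
      - name: run_id            # required at submission time, no default
      - name: retry_run_id
        value: ""
      - name: retry_indicator
        value: ""
      - name: log_level
        value: INFO
  volumes:
    - name: runnable
      persistentVolumeClaim:
        claimName: my-pvc
  templates:
    - name: runnable-dag
      failFast: true
      dag:
        tasks:
          - name: hello-ab12cd           # sanitize_name() output (made-up suffix)
            template: hello-ab12cd
    - name: hello-ab12cd
      container:
        image: my-image
        command: ["runnable", "execute-single-node", "..."]  # assumed CLI shape
        volumeMounts:
          - name: runnable
            mountPath: /tmp
      memoize:
        key: "{{workflow.parameters.run_id}}-hello-ab12cd-x1y2"
        cache:
          configMap:
            name: runnable-abc123        # from _generate_cache_name()
```

Since `run_id` has no default, submitting the manifest with a value for it (e.g. `argo submit argo-pipeline.yaml -p run_id=my-run`) is the expected entry point.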