zenml-nightly 0.84.1.dev20250805__py3-none-any.whl → 0.84.1.dev20250806__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/integrations/kubernetes/constants.py +27 -0
- zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py +79 -36
- zenml/integrations/kubernetes/flavors/kubernetes_step_operator_flavor.py +55 -24
- zenml/integrations/kubernetes/orchestrators/dag_runner.py +367 -0
- zenml/integrations/kubernetes/orchestrators/kube_utils.py +368 -1
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +144 -262
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +392 -244
- zenml/integrations/kubernetes/orchestrators/manifest_utils.py +53 -85
- zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py +74 -32
- zenml/logging/step_logging.py +33 -30
- zenml/steps/base_step.py +6 -6
- zenml/steps/step_decorator.py +4 -4
- {zenml_nightly-0.84.1.dev20250805.dist-info → zenml_nightly-0.84.1.dev20250806.dist-info}/METADATA +1 -1
- {zenml_nightly-0.84.1.dev20250805.dist-info → zenml_nightly-0.84.1.dev20250806.dist-info}/RECORD +18 -16
- {zenml_nightly-0.84.1.dev20250805.dist-info → zenml_nightly-0.84.1.dev20250806.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.84.1.dev20250805.dist-info → zenml_nightly-0.84.1.dev20250806.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.84.1.dev20250805.dist-info → zenml_nightly-0.84.1.dev20250806.dist-info}/entry_points.txt +0 -0
zenml/VERSION
CHANGED
@@ -1 +1 @@
-0.84.1.dev20250805
+0.84.1.dev20250806
zenml/integrations/kubernetes/constants.py
ADDED
@@ -0,0 +1,27 @@
+# Copyright (c) ZenML GmbH 2025. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+#
+# Parts of the `prepare_or_run_pipeline()` method of this file are
+# inspired by the Kubernetes dag runner implementation of tfx
+"""Kubernetes orchestrator constants."""
+
+ENV_ZENML_KUBERNETES_RUN_ID = "ZENML_KUBERNETES_RUN_ID"
+KUBERNETES_SECRET_TOKEN_KEY_NAME = "zenml_api_token"
+KUBERNETES_CRON_JOB_METADATA_KEY = "cron_job_name"
+# Annotation keys
+ORCHESTRATOR_ANNOTATION_KEY = "zenml.io/orchestrator"
+RUN_ID_ANNOTATION_KEY = "zenml.io/run-id"
+ORCHESTRATOR_RUN_ID_ANNOTATION_KEY = "zenml.io/orchestrator-run-id"
+STEP_NAME_ANNOTATION_KEY = "zenml.io/step-name"
+STEP_OPERATOR_ANNOTATION_KEY = "zenml.io/step-operator"
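The annotation keys above are stamped onto the Kubernetes objects the orchestrator creates, which makes runs discoverable from the cluster side. As a minimal sketch (assuming the official `kubernetes` Python client, a reachable kubeconfig, and a `zenml` namespace, all placeholders), step pods for a given run could be looked up like this:

# Hypothetical helper: list the pods annotated with a given ZenML run ID.
# The namespace and kubeconfig setup are assumptions, not part of the diff.
from kubernetes import client, config

from zenml.integrations.kubernetes.constants import RUN_ID_ANNOTATION_KEY

config.load_kube_config()  # or config.load_incluster_config() inside a pod
core_api = client.CoreV1Api()

def pods_for_run(run_id: str, namespace: str = "zenml") -> list:
    """Return all pods annotated with the given ZenML run ID."""
    pods = core_api.list_namespaced_pod(namespace=namespace)
    return [
        pod
        for pod in pods.items
        if (pod.metadata.annotations or {}).get(RUN_ID_ANNOTATION_KEY)
        == run_id
    ]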
zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py
CHANGED
@@ -15,7 +15,13 @@
 
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
 
-from pydantic import
+from pydantic import (
+    Field,
+    NonNegativeInt,
+    PositiveFloat,
+    PositiveInt,
+    field_validator,
+)
 
 from zenml.config.base_settings import BaseSettings
 from zenml.constants import KUBERNETES_CLUSTER_RESOURCE_TYPE
@@ -23,6 +29,7 @@ from zenml.integrations.kubernetes import KUBERNETES_ORCHESTRATOR_FLAVOR
 from zenml.integrations.kubernetes.pod_settings import KubernetesPodSettings
 from zenml.models import ServiceConnectorRequirements
 from zenml.orchestrators import BaseOrchestratorConfig, BaseOrchestratorFlavor
+from zenml.utils import deprecation_utils
 
 if TYPE_CHECKING:
     from zenml.integrations.kubernetes.orchestrators import (
@@ -42,16 +49,6 @@ class KubernetesOrchestratorSettings(BaseSettings):
         description="Whether to wait for all pipeline steps to complete. "
         "When `False`, the client returns immediately and execution continues asynchronously.",
     )
-    timeout: int = Field(
-        default=0,
-        description="Maximum seconds to wait for synchronous runs. Set to `0` for unlimited duration.",
-    )
-    stream_step_logs: bool = Field(
-        default=True,
-        description="If `True`, the orchestrator pod will stream the logs "
-        "of the step pods. This only has an effect if specified on the "
-        "pipeline, not on individual steps.",
-    )
     service_account_name: Optional[str] = Field(
         default=None,
         description="Kubernetes service account for the orchestrator pod. "
@@ -74,25 +71,9 @@ class KubernetesOrchestratorSettings(BaseSettings):
         default=None,
         description="Pod configuration for the orchestrator container that launches step pods.",
     )
-    pod_name_prefix: Optional[str] = Field(
+    job_name_prefix: Optional[str] = Field(
         default=None,
-        description="
-    )
-    pod_startup_timeout: int = Field(
-        default=600,
-        description="Maximum seconds to wait for step pods to start. Default is 10 minutes.",
-    )
-    pod_failure_max_retries: int = Field(
-        default=3,
-        description="Maximum retry attempts when step pods fail to start.",
-    )
-    pod_failure_retry_delay: int = Field(
-        default=10,
-        description="Delay in seconds between pod failure retry attempts.",
-    )
-    pod_failure_backoff: float = Field(
-        default=1.0,
-        description="Exponential backoff factor for retry delays. Values > 1.0 increase delay with each retry.",
+        description="Prefix for the job name.",
     )
     max_parallelism: Optional[PositiveInt] = Field(
         default=None,
@@ -100,19 +81,20 @@ class KubernetesOrchestratorSettings(BaseSettings):
     )
     successful_jobs_history_limit: Optional[NonNegativeInt] = Field(
         default=None,
-        description="Number of successful scheduled jobs to retain in
+        description="Number of successful scheduled jobs to retain in history.",
     )
     failed_jobs_history_limit: Optional[NonNegativeInt] = Field(
         default=None,
-        description="Number of failed scheduled jobs to retain in
+        description="Number of failed scheduled jobs to retain in history.",
     )
     ttl_seconds_after_finished: Optional[NonNegativeInt] = Field(
         default=None,
-        description="Seconds to keep finished
+        description="Seconds to keep finished jobs before automatic cleanup.",
     )
     active_deadline_seconds: Optional[NonNegativeInt] = Field(
         default=None,
-        description="
+        description="Job deadline in seconds. If the job doesn't finish "
+        "within this time, it will be terminated.",
     )
     backoff_limit_margin: NonNegativeInt = Field(
         default=0,
@@ -131,6 +113,10 @@ class KubernetesOrchestratorSettings(BaseSettings):
         "the chance of the server receiving the maximum amount of retry "
         "requests.",
     )
+    orchestrator_job_backoff_limit: NonNegativeInt = Field(
+        default=3,
+        description="The backoff limit for the orchestrator job.",
+    )
     fail_on_container_waiting_reasons: Optional[List[str]] = Field(
         default=[
             "InvalidImageName",
@@ -144,11 +130,21 @@ class KubernetesOrchestratorSettings(BaseSettings):
         "`pod.status.containerStatuses[*].state.waiting.reason` of a job pod, "
         "should cause the job to fail immediately.",
     )
-    job_monitoring_interval:
+    job_monitoring_interval: PositiveFloat = Field(
         default=3,
         description="The interval in seconds to monitor the job. Each interval "
-        "is used to check for container issues
-
+        "is used to check for container issues for the job pods.",
+    )
+    job_monitoring_delay: PositiveFloat = Field(
+        default=0.0,
+        description="The delay in seconds to wait between monitoring active "
+        "step jobs. This can be used to reduce load on the Kubernetes API "
+        "server.",
+    )
+    interrupt_check_interval: PositiveFloat = Field(
+        default=1.0,
+        description="The interval in seconds to check for run interruptions.",
+        ge=0.5,
     )
     pod_failure_policy: Optional[Dict[str, Any]] = Field(
         default=None,
@@ -169,6 +165,53 @@ class KubernetesOrchestratorSettings(BaseSettings):
         description="When stopping a pipeline run, the amount of seconds to wait for a step pod to shutdown gracefully.",
     )
 
+    # Deprecated fields
+    timeout: Optional[int] = Field(
+        default=None,
+        deprecated=True,
+        description="DEPRECATED/UNUSED.",
+    )
+    stream_step_logs: Optional[bool] = Field(
+        default=None,
+        deprecated=True,
+        description="DEPRECATED/UNUSED.",
+    )
+    pod_startup_timeout: Optional[int] = Field(
+        default=None,
+        description="DEPRECATED/UNUSED.",
+        deprecated=True,
+    )
+    pod_failure_max_retries: Optional[int] = Field(
+        default=None,
+        description="DEPRECATED/UNUSED.",
+        deprecated=True,
+    )
+    pod_failure_retry_delay: Optional[int] = Field(
+        default=None,
+        description="DEPRECATED/UNUSED.",
+        deprecated=True,
+    )
+    pod_failure_backoff: Optional[float] = Field(
+        default=None,
+        description="DEPRECATED/UNUSED.",
+        deprecated=True,
+    )
+    pod_name_prefix: Optional[str] = Field(
+        default=None,
+        deprecated=True,
+        description="DEPRECATED/UNUSED.",
+    )
+
+    _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes(
+        "timeout",
+        "stream_step_logs",
+        "pod_startup_timeout",
+        "pod_failure_max_retries",
+        "pod_failure_retry_delay",
+        "pod_failure_backoff",
+        ("pod_name_prefix", "job_name_prefix"),
+    )
+
     @field_validator("pod_failure_policy", mode="before")
     @classmethod
     def _convert_pod_failure_policy(cls, value: Any) -> Any:
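To illustrate the migrated settings, here is a hedged sketch of configuring the orchestrator on a pipeline. The `orchestrator.kubernetes` settings key follows ZenML's `<component_type>.<flavor>` convention; the pipeline itself is a placeholder. The fields removed above (`timeout`, `stream_step_logs`, the `pod_*` retry knobs) are still accepted but ignored, and the `("pod_name_prefix", "job_name_prefix")` tuple in the deprecation validator suggests the old value is forwarded to the new field.

# A minimal sketch, assuming a registered Kubernetes orchestrator;
# my_pipeline and the values below are illustrative only.
from zenml import pipeline
from zenml.integrations.kubernetes.flavors import KubernetesOrchestratorSettings

k8s_settings = KubernetesOrchestratorSettings(
    job_name_prefix="team-a",          # replaces the deprecated pod_name_prefix
    max_parallelism=4,                 # at most 4 step jobs run concurrently
    orchestrator_job_backoff_limit=3,  # retries for the orchestrator job itself
    job_monitoring_interval=5.0,       # seconds between job status checks
    job_monitoring_delay=0.5,          # pause between monitoring individual jobs
    interrupt_check_interval=1.0,      # the field enforces >= 0.5
)

@pipeline(settings={"orchestrator.kubernetes": k8s_settings})
def my_pipeline() -> None:  # placeholder pipeline
    ...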
zenml/integrations/kubernetes/flavors/kubernetes_step_operator_flavor.py
CHANGED
@@ -13,9 +13,9 @@
 # permissions and limitations under the License.
 """Kubernetes step operator flavor."""
 
-from typing import TYPE_CHECKING, Optional, Type
+from typing import TYPE_CHECKING, List, Optional, Type
 
-from pydantic import Field
+from pydantic import Field, NonNegativeInt
 
 from zenml.config.base_settings import BaseSettings
 from zenml.constants import KUBERNETES_CLUSTER_RESOURCE_TYPE
@@ -23,6 +23,7 @@ from zenml.integrations.kubernetes import KUBERNETES_STEP_OPERATOR_FLAVOR
 from zenml.integrations.kubernetes.pod_settings import KubernetesPodSettings
 from zenml.models import ServiceConnectorRequirements
 from zenml.step_operators import BaseStepOperatorConfig, BaseStepOperatorFlavor
+from zenml.utils import deprecation_utils
 
 if TYPE_CHECKING:
     from zenml.integrations.kubernetes.step_operators import (
@@ -31,11 +32,7 @@ if TYPE_CHECKING:
 
 
 class KubernetesStepOperatorSettings(BaseSettings):
-    """Settings for the Kubernetes step operator.
-
-    Configuration options for individual step execution on Kubernetes.
-    Field descriptions are defined inline using Field() descriptors.
-    """
+    """Settings for the Kubernetes step operator."""
 
     pod_settings: Optional[KubernetesPodSettings] = Field(
         default=None,
@@ -49,32 +46,66 @@ class KubernetesStepOperatorSettings(BaseSettings):
         default=False,
         description="Whether to run step containers in privileged mode with extended permissions.",
     )
-
-        default=
-        description="
+    job_name_prefix: Optional[str] = Field(
+        default=None,
+        description="Prefix for the job name.",
     )
-
-        default=
-        description="
+    ttl_seconds_after_finished: Optional[NonNegativeInt] = Field(
+        default=None,
+        description="Seconds to keep finished jobs before automatic cleanup.",
     )
-
-        default=
-        description="
+    active_deadline_seconds: Optional[NonNegativeInt] = Field(
+        default=None,
+        description="Job deadline in seconds. If the job doesn't finish "
+        "within this time, it will be terminated.",
     )
-
-        default=
-
+    fail_on_container_waiting_reasons: Optional[List[str]] = Field(
+        default=[
+            "InvalidImageName",
+            "ErrImagePull",
+            "ImagePullBackOff",
+            "CreateContainerConfigError",
+        ],
+        description="List of container waiting reasons that should cause the "
+        "job to fail immediately. This should be set to a list of "
+        "nonrecoverable reasons, which if found in any "
+        "`pod.status.containerStatuses[*].state.waiting.reason` of a job pod, "
+        "should cause the job to fail immediately.",
+    )
+
+    # Deprecated fields
+    pod_startup_timeout: Optional[int] = Field(
+        default=None,
+        deprecated=True,
+        description="DEPRECATED/UNUSED.",
+    )
+    pod_failure_max_retries: Optional[int] = Field(
+        default=None,
+        deprecated=True,
+        description="DEPRECATED/UNUSED.",
+    )
+    pod_failure_retry_delay: Optional[int] = Field(
+        default=None,
+        deprecated=True,
+        description="DEPRECATED/UNUSED.",
+    )
+    pod_failure_backoff: Optional[float] = Field(
+        default=None,
+        deprecated=True,
+        description="DEPRECATED/UNUSED.",
+    )
+    _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes(
+        "pod_startup_timeout",
+        "pod_failure_max_retries",
+        "pod_failure_retry_delay",
+        "pod_failure_backoff",
     )
 
 
 class KubernetesStepOperatorConfig(
     BaseStepOperatorConfig, KubernetesStepOperatorSettings
 ):
-    """Configuration for the Kubernetes step operator.
-
-    Defines cluster connection and execution settings.
-    Field descriptions are defined inline using Field() descriptors.
-    """
+    """Configuration for the Kubernetes step operator."""
 
     kubernetes_namespace: str = Field(
         default="zenml",
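Analogously for the step operator, a sketch of per-step configuration using the new fields; the step operator name `k8s_step_operator`, the settings key convention, and the step body are assumptions:

# A minimal sketch, assuming a step operator registered as "k8s_step_operator".
from zenml import step
from zenml.integrations.kubernetes.flavors import KubernetesStepOperatorSettings

step_settings = KubernetesStepOperatorSettings(
    job_name_prefix="train",
    ttl_seconds_after_finished=3600,  # clean finished jobs up after an hour
    active_deadline_seconds=7200,     # terminate jobs running longer than 2 hours
    fail_on_container_waiting_reasons=["ErrImagePull", "ImagePullBackOff"],
)

@step(
    step_operator="k8s_step_operator",  # assumed name of the registered component
    settings={"step_operator.kubernetes": step_settings},
)
def train() -> None:  # placeholder step
    ...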
zenml/integrations/kubernetes/orchestrators/dag_runner.py
ADDED
@@ -0,0 +1,367 @@
+# Copyright (c) ZenML GmbH 2025. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+"""DAG runner."""
+
+import queue
+import threading
+import time
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Callable, Dict, List, Optional
+
+from pydantic import BaseModel
+
+from zenml.logger import get_logger
+from zenml.utils.enum_utils import StrEnum
+
+logger = get_logger(__name__)
+
+
+class NodeStatus(StrEnum):
+    """Status of a DAG node."""
+
+    NOT_READY = "not_ready"  # Can not be started yet
+    READY = "ready"  # Can be started but is still waiting in the queue
+    STARTING = "starting"  # Is being started, but not yet running
+    RUNNING = "running"
+    COMPLETED = "completed"
+    FAILED = "failed"
+    SKIPPED = "skipped"
+    CANCELLED = "cancelled"
+
+
+class InterruptMode(StrEnum):
+    """Interrupt mode."""
+
+    GRACEFUL = "graceful"
+    FORCE = "force"
+
+
+class Node(BaseModel):
+    """DAG node."""
+
+    id: str
+    status: NodeStatus = NodeStatus.NOT_READY
+    upstream_nodes: List[str] = []
+    metadata: Dict[str, Any] = {}
+
+    @property
+    def is_finished(self) -> bool:
+        """Whether the node is finished.
+
+        Returns:
+            Whether the node is finished.
+        """
+        return self.status in {
+            NodeStatus.COMPLETED,
+            NodeStatus.FAILED,
+            NodeStatus.SKIPPED,
+            NodeStatus.CANCELLED,
+        }
+
+
+class DagRunner:
+    """DAG runner.
+
+    This class does the orchestration of running the nodes of a DAG. It is
+    running two loops in separate threads:
+    The main thread
+    - checks if any nodes should be skipped or are ready to
+      run, in which case the node will be added to the startup queue
+    - creates a worker thread to start the node and executes it in a thread
+      pool if there are nodes in the startup queue and the maximum
+      parallelism is not reached
+    - periodically checks if the DAG should be interrupted
+    The monitoring thread
+    - monitors the running nodes and updates their status
+    """
+
+    def __init__(
+        self,
+        nodes: List[Node],
+        node_startup_function: Callable[[Node], NodeStatus],
+        node_monitoring_function: Callable[[Node], NodeStatus],
+        node_stop_function: Optional[Callable[[Node], None]] = None,
+        interrupt_function: Optional[
+            Callable[[], Optional[InterruptMode]]
+        ] = None,
+        monitoring_interval: float = 1.0,
+        monitoring_delay: float = 0.0,
+        interrupt_check_interval: float = 1.0,
+        max_parallelism: Optional[int] = None,
+    ) -> None:
+        """Initialize the DAG runner.
+
+        Args:
+            nodes: The nodes of the DAG.
+            node_startup_function: The function to start a node.
+            node_monitoring_function: The function to monitor a node.
+            node_stop_function: The function to stop a node.
+            interrupt_function: Will be periodically called to check if the
+                DAG should be interrupted.
+            monitoring_interval: The interval in which the nodes are monitored.
+            monitoring_delay: The delay in seconds to wait between monitoring
+                different nodes.
+            interrupt_check_interval: The interval in which the interrupt
+                function is called.
+            max_parallelism: The maximum number of nodes to run in parallel.
+        """
+        self.nodes = {node.id: node for node in nodes}
+        self.startup_queue: queue.Queue[Node] = queue.Queue()
+        self.node_startup_function = node_startup_function
+        self.node_monitoring_function = node_monitoring_function
+        self.node_stop_function = node_stop_function
+        self.interrupt_function = interrupt_function
+        self.monitoring_thread = threading.Thread(
+            name="DagRunner-Monitoring-Loop",
+            target=self._monitoring_loop,
+            daemon=True,
+        )
+        self.monitoring_interval = monitoring_interval
+        self.monitoring_delay = monitoring_delay
+        self.interrupt_check_interval = interrupt_check_interval
+        self.max_parallelism = max_parallelism
+        self.shutdown_event = threading.Event()
+        self.startup_executor = ThreadPoolExecutor(
+            max_workers=10, thread_name_prefix="DagRunner-Startup"
+        )
+
+    @property
+    def running_nodes(self) -> List[Node]:
+        """Running nodes.
+
+        Returns:
+            Running nodes.
+        """
+        return [
+            node
+            for node in self.nodes.values()
+            if node.status == NodeStatus.RUNNING
+        ]
+
+    @property
+    def active_nodes(self) -> List[Node]:
+        """Active nodes.
+
+        Active nodes are nodes that are either running or starting.
+
+        Returns:
+            Active nodes.
+        """
+        return [
+            node
+            for node in self.nodes.values()
+            if node.status in {NodeStatus.RUNNING, NodeStatus.STARTING}
+        ]
+
+    def _initialize_startup_queue(self) -> None:
+        """Initialize the startup queue.
+
+        The startup queue contains all nodes that are ready to be started.
+        """
+        for node in self.nodes.values():
+            if node.status in {NodeStatus.READY, NodeStatus.STARTING}:
+                self.startup_queue.put(node)
+
+    def _can_start_node(self, node: Node) -> bool:
+        """Check if a node can be started.
+
+        Args:
+            node: The node to check.
+
+        Returns:
+            Whether the node can be started.
+        """
+        return all(
+            self.nodes[upstream_node_id].status == NodeStatus.COMPLETED
+            for upstream_node_id in node.upstream_nodes
+        )
+
+    def _should_skip_node(self, node: Node) -> bool:
+        """Check if a node should be skipped.
+
+        Args:
+            node: The node to check.
+
+        Returns:
+            Whether the node should be skipped.
+        """
+        return any(
+            self.nodes[upstream_node_id].status
+            in {NodeStatus.FAILED, NodeStatus.SKIPPED, NodeStatus.CANCELLED}
+            for upstream_node_id in node.upstream_nodes
+        )
+
+    def _start_node(self, node: Node) -> None:
+        """Start a node.
+
+        This will start of a thread that will run the startup function.
+
+        Args:
+            node: The node to start.
+        """
+        node.status = NodeStatus.STARTING
+
+        def _start_node_task() -> None:
+            if self.shutdown_event.is_set():
+                logger.debug(
+                    "Cancelling startup of node `%s` because shutdown was "
+                    "requested.",
+                    node.id,
+                )
+                return
+
+            try:
+                node.status = self.node_startup_function(node)
+            except Exception:
+                node.status = NodeStatus.FAILED
+                logger.exception("Node `%s` failed to start.", node.id)
+            else:
+                logger.info(
+                    "Node `%s` started (status: %s)", node.id, node.status
+                )
+
+        self.startup_executor.submit(_start_node_task)
+
+    def _stop_node(self, node: Node) -> None:
+        """Stop a node.
+
+        Args:
+            node: The node to stop.
+
+        Raises:
+            RuntimeError: If the node stop function is not set.
+        """
+        if not self.node_stop_function:
+            raise RuntimeError("Node stop function is not set.")
+
+        self.node_stop_function(node)
+
+    def _stop_all_nodes(self) -> None:
+        """Stop all running nodes."""
+        for node in self.running_nodes:
+            self._stop_node(node)
+            node.status = NodeStatus.CANCELLED
+
+    def _process_nodes(self) -> bool:
+        """Process the nodes.
+
+        This method will check if any nodes should be skipped or are ready to
+        run, in which case the node will be added to the startup queue.
+
+        Returns:
+            Whether the DAG is finished.
+        """
+        finished = True
+
+        for node in self.nodes.values():
+            if node.status == NodeStatus.NOT_READY:
+                if self._should_skip_node(node):
+                    node.status = NodeStatus.SKIPPED
+                    logger.warning(
+                        "Skipping node `%s` because upstream node failed.",
+                        node.id,
+                    )
+                elif self._can_start_node(node):
+                    node.status = NodeStatus.READY
+                    self.startup_queue.put(node)
+
+            if not node.is_finished:
+                finished = False
+
+        # Start nodes until we reach the maximum configured parallelism
+        max_parallelism = self.max_parallelism or len(self.nodes)
+        while len(self.active_nodes) < max_parallelism:
+            try:
+                node = self.startup_queue.get_nowait()
+            except queue.Empty:
+                break
+            else:
+                self.startup_queue.task_done()
+                self._start_node(node)
+
+        return finished
+
+    def _monitoring_loop(self) -> None:
+        """Monitoring loop.
+
+        This should run in a separate thread and monitors the running nodes.
+        """
+        while not self.shutdown_event.is_set():
+            start_time = time.time()
+            for node in self.running_nodes:
+                try:
+                    node.status = self.node_monitoring_function(node)
+                except Exception:
+                    node.status = NodeStatus.FAILED
+                    logger.exception("Node `%s` failed.", node.id)
+                else:
+                    logger.debug(
+                        "Node `%s` status updated to `%s`",
+                        node.id,
+                        node.status,
+                    )
+                    if node.status == NodeStatus.FAILED:
+                        logger.error("Node `%s` failed.", node.id)
+                    elif node.status == NodeStatus.COMPLETED:
+                        logger.info("Node `%s` completed.", node.id)
+
+                time.sleep(self.monitoring_delay)
+
+            duration = time.time() - start_time
+            time_to_sleep = max(0, self.monitoring_interval - duration)
+            self.shutdown_event.wait(timeout=time_to_sleep)
+
+    def run(self) -> Dict[str, NodeStatus]:
+        """Run the DAG.
+
+        Returns:
+            The final node states.
+        """
+        self._initialize_startup_queue()
+
+        self.monitoring_thread.start()
+
+        interrupt_mode = None
+        last_interrupt_check = time.time()
+
+        while True:
+            if self.interrupt_function is not None:
+                if (
+                    time.time() - last_interrupt_check
+                    >= self.interrupt_check_interval
+                ):
+                    if interrupt_mode := self.interrupt_function():
+                        logger.warning("DAG execution interrupted.")
+                        break
+                    last_interrupt_check = time.time()
+
+            is_finished = self._process_nodes()
+            if is_finished:
+                break
+
+            time.sleep(0.5)
+
+        self.shutdown_event.set()
+        if interrupt_mode == InterruptMode.FORCE:
+            # If a force interrupt was requested, we stop all running nodes.
+            self._stop_all_nodes()
+
+        self.monitoring_thread.join()
+
+        node_statuses = {
+            node_id: node.status for node_id, node in self.nodes.items()
+        }
+        logger.debug("Finished with node statuses: %s", node_statuses)
+
+        return node_statuses
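To make the control flow above concrete, here is a minimal, self-contained sketch that drives `DagRunner` with in-process callables instead of Kubernetes jobs; the two-node DAG and the start/monitor functions are illustrative only. The startup function reports each node as `RUNNING`, and the monitoring loop promotes it to `COMPLETED` after roughly a second:

import time

from zenml.integrations.kubernetes.orchestrators.dag_runner import (
    DagRunner,
    Node,
    NodeStatus,
)

started_at = {}  # node ID -> wall-clock start time (illustrative state)

def start_node(node: Node) -> NodeStatus:
    # Pretend to launch the node's workload and report it as running.
    started_at[node.id] = time.time()
    return NodeStatus.RUNNING

def monitor_node(node: Node) -> NodeStatus:
    # Report a node as completed once it has "run" for one second.
    if time.time() - started_at[node.id] >= 1.0:
        return NodeStatus.COMPLETED
    return NodeStatus.RUNNING

runner = DagRunner(
    nodes=[
        Node(id="load_data"),
        Node(id="train", upstream_nodes=["load_data"]),  # waits for load_data
    ],
    node_startup_function=start_node,
    node_monitoring_function=monitor_node,
    monitoring_interval=0.5,
    max_parallelism=2,
)
statuses = runner.run()
# e.g. {"load_data": NodeStatus.COMPLETED, "train": NodeStatus.COMPLETED}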