zenml-nightly 0.83.1.dev20250708__py3-none-any.whl → 0.83.1.dev20250710__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/cli/login.py +141 -18
- zenml/cli/project.py +8 -6
- zenml/cli/utils.py +63 -16
- zenml/client.py +4 -1
- zenml/config/compiler.py +1 -0
- zenml/config/retry_config.py +5 -3
- zenml/config/step_configurations.py +7 -1
- zenml/console.py +4 -1
- zenml/constants.py +0 -1
- zenml/enums.py +13 -4
- zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py +58 -4
- zenml/integrations/kubernetes/orchestrators/kube_utils.py +172 -0
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +37 -23
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +92 -22
- zenml/integrations/kubernetes/orchestrators/manifest_utils.py +59 -0
- zenml/logger.py +6 -4
- zenml/login/web_login.py +13 -6
- zenml/models/v2/core/model_version.py +9 -1
- zenml/models/v2/core/pipeline_run.py +1 -0
- zenml/models/v2/core/step_run.py +35 -1
- zenml/orchestrators/base_orchestrator.py +63 -8
- zenml/orchestrators/dag_runner.py +3 -1
- zenml/orchestrators/publish_utils.py +4 -1
- zenml/orchestrators/step_launcher.py +77 -139
- zenml/orchestrators/step_run_utils.py +16 -0
- zenml/orchestrators/step_runner.py +1 -4
- zenml/pipelines/pipeline_decorator.py +6 -1
- zenml/pipelines/pipeline_definition.py +7 -0
- zenml/zen_server/auth.py +0 -1
- zenml/zen_stores/migrations/versions/360fa84718bf_step_run_versioning.py +64 -0
- zenml/zen_stores/migrations/versions/85289fea86ff_adding_source_to_logs.py +1 -1
- zenml/zen_stores/schemas/pipeline_deployment_schemas.py +21 -0
- zenml/zen_stores/schemas/pipeline_run_schemas.py +31 -2
- zenml/zen_stores/schemas/step_run_schemas.py +41 -17
- zenml/zen_stores/sql_zen_store.py +152 -32
- zenml/zen_stores/template_utils.py +29 -9
- zenml_nightly-0.83.1.dev20250710.dist-info/METADATA +499 -0
- {zenml_nightly-0.83.1.dev20250708.dist-info → zenml_nightly-0.83.1.dev20250710.dist-info}/RECORD +42 -41
- zenml_nightly-0.83.1.dev20250708.dist-info/METADATA +0 -538
- {zenml_nightly-0.83.1.dev20250708.dist-info → zenml_nightly-0.83.1.dev20250710.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.83.1.dev20250708.dist-info → zenml_nightly-0.83.1.dev20250710.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.83.1.dev20250708.dist-info → zenml_nightly-0.83.1.dev20250710.dist-info}/entry_points.txt +0 -0
zenml/login/web_login.py
CHANGED
@@ -21,6 +21,7 @@ from typing import Optional, Union
|
|
21
21
|
import requests
|
22
22
|
|
23
23
|
from zenml import __version__
|
24
|
+
from zenml.cli import utils as cli_utils
|
24
25
|
from zenml.config.global_config import GlobalConfiguration
|
25
26
|
from zenml.constants import (
|
26
27
|
API,
|
@@ -175,11 +176,17 @@ def web_login(
|
|
175
176
|
# URL to it
|
176
177
|
verification_uri = base_url + verification_uri
|
177
178
|
webbrowser.open(verification_uri)
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
179
|
+
|
180
|
+
# Display the verification URL without panel styling
|
181
|
+
from zenml.console import console
|
182
|
+
|
183
|
+
console.print()
|
184
|
+
console.print(
|
185
|
+
"If your browser did not open automatically, please open the following URL into your browser to proceed with the authentication:",
|
186
|
+
style="white",
|
182
187
|
)
|
188
|
+
console.print(verification_uri, style="bright_blue underline")
|
189
|
+
console.print()
|
183
190
|
|
184
191
|
# Poll the OAuth2 server until the user has authorized the device
|
185
192
|
token_request = OAuthDeviceTokenRequest(
|
@@ -201,9 +208,9 @@ def web_login(
|
|
201
208
|
# The user has authorized the device, so we can extract the access token
|
202
209
|
token_response = OAuthTokenResponse(**response.json())
|
203
210
|
if zenml_pro:
|
204
|
-
|
211
|
+
cli_utils.success("✔ Successfully logged in to ZenML Pro.")
|
205
212
|
else:
|
206
|
-
|
213
|
+
cli_utils.success(f"✔ Successfully logged in to {url}.")
|
207
214
|
break
|
208
215
|
elif response.status_code == 400:
|
209
216
|
try:
|
@@ -13,6 +13,7 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""Models representing model versions."""
|
15
15
|
|
16
|
+
import json
|
16
17
|
from typing import (
|
17
18
|
TYPE_CHECKING,
|
18
19
|
ClassVar,
|
@@ -411,10 +412,17 @@ class ModelVersionResponse(
|
|
411
412
|
|
412
413
|
from zenml.client import Client
|
413
414
|
|
415
|
+
data_artifact_types = [
|
416
|
+
value
|
417
|
+
for value in ArtifactType.values()
|
418
|
+
if value
|
419
|
+
not in [ArtifactType.MODEL.value, ArtifactType.SERVICE.value]
|
420
|
+
]
|
421
|
+
|
414
422
|
artifact_versions = pagination_utils.depaginate(
|
415
423
|
Client().list_artifact_versions,
|
416
424
|
model_version_id=self.id,
|
417
|
-
type=
|
425
|
+
type="oneof:" + json.dumps(data_artifact_types),
|
418
426
|
project=self.project_id,
|
419
427
|
)
|
420
428
|
|
zenml/models/v2/core/step_run.py
CHANGED
@@ -177,6 +177,12 @@ class StepRunResponseBody(ProjectScopedResponseBody):
|
|
177
177
|
"""Response body for step runs."""
|
178
178
|
|
179
179
|
status: ExecutionStatus = Field(title="The status of the step.")
|
180
|
+
version: int = Field(
|
181
|
+
title="The version of the step run.",
|
182
|
+
)
|
183
|
+
is_retriable: bool = Field(
|
184
|
+
title="Whether the step run is retriable.",
|
185
|
+
)
|
180
186
|
start_time: Optional[datetime] = Field(
|
181
187
|
title="The start time of the step run.",
|
182
188
|
default=None,
|
@@ -420,6 +426,24 @@ class StepRunResponse(
|
|
420
426
|
"""
|
421
427
|
return self.get_body().status
|
422
428
|
|
429
|
+
@property
|
430
|
+
def version(self) -> int:
|
431
|
+
"""The `version` property.
|
432
|
+
|
433
|
+
Returns:
|
434
|
+
the value of the property.
|
435
|
+
"""
|
436
|
+
return self.get_body().version
|
437
|
+
|
438
|
+
@property
|
439
|
+
def is_retriable(self) -> bool:
|
440
|
+
"""The `is_retriable` property.
|
441
|
+
|
442
|
+
Returns:
|
443
|
+
the value of the property.
|
444
|
+
"""
|
445
|
+
return self.get_body().is_retriable
|
446
|
+
|
423
447
|
@property
|
424
448
|
def inputs(self) -> Dict[str, List[StepRunInputResponse]]:
|
425
449
|
"""The `inputs` property.
|
@@ -602,6 +626,7 @@ class StepRunFilter(ProjectScopedFilter, RunMetadataFilterMixin):
|
|
602
626
|
*ProjectScopedFilter.FILTER_EXCLUDE_FIELDS,
|
603
627
|
*RunMetadataFilterMixin.FILTER_EXCLUDE_FIELDS,
|
604
628
|
"model",
|
629
|
+
"exclude_retried",
|
605
630
|
]
|
606
631
|
CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [
|
607
632
|
*ProjectScopedFilter.CLI_EXCLUDE_FIELDS,
|
@@ -666,6 +691,10 @@ class StepRunFilter(ProjectScopedFilter, RunMetadataFilterMixin):
|
|
666
691
|
default=None,
|
667
692
|
description="Name/ID of the model associated with the step run.",
|
668
693
|
)
|
694
|
+
exclude_retried: Optional[bool] = Field(
|
695
|
+
default=None,
|
696
|
+
description="Whether to exclude retried step runs.",
|
697
|
+
)
|
669
698
|
model_config = ConfigDict(protected_namespaces=())
|
670
699
|
|
671
700
|
def get_custom_filters(
|
@@ -681,7 +710,7 @@ class StepRunFilter(ProjectScopedFilter, RunMetadataFilterMixin):
|
|
681
710
|
"""
|
682
711
|
custom_filters = super().get_custom_filters(table)
|
683
712
|
|
684
|
-
from sqlmodel import and_
|
713
|
+
from sqlmodel import and_, col
|
685
714
|
|
686
715
|
from zenml.zen_stores.schemas import (
|
687
716
|
ModelSchema,
|
@@ -699,4 +728,9 @@ class StepRunFilter(ProjectScopedFilter, RunMetadataFilterMixin):
|
|
699
728
|
)
|
700
729
|
custom_filters.append(model_filter)
|
701
730
|
|
731
|
+
if self.exclude_retried:
|
732
|
+
custom_filters.append(
|
733
|
+
col(StepRunSchema.status) != ExecutionStatus.RETRIED.value
|
734
|
+
)
|
735
|
+
|
702
736
|
return custom_filters
|
@@ -13,6 +13,7 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""Base orchestrator class."""
|
15
15
|
|
16
|
+
import time
|
16
17
|
from abc import ABC, abstractmethod
|
17
18
|
from typing import (
|
18
19
|
TYPE_CHECKING,
|
@@ -33,7 +34,7 @@ from zenml.constants import (
|
|
33
34
|
handle_bool_env_var,
|
34
35
|
)
|
35
36
|
from zenml.enums import ExecutionStatus, StackComponentType
|
36
|
-
from zenml.exceptions import RunMonitoringError
|
37
|
+
from zenml.exceptions import RunMonitoringError, RunStoppedException
|
37
38
|
from zenml.logger import get_logger
|
38
39
|
from zenml.metadata.metadata_types import MetadataType
|
39
40
|
from zenml.orchestrators.publish_utils import (
|
@@ -127,6 +128,15 @@ class BaseOrchestratorConfig(StackComponentConfig):
|
|
127
128
|
"""
|
128
129
|
return True
|
129
130
|
|
131
|
+
@property
|
132
|
+
def handles_step_retries(self) -> bool:
|
133
|
+
"""Whether the orchestrator handles step retries.
|
134
|
+
|
135
|
+
Returns:
|
136
|
+
Whether the orchestrator handles step retries.
|
137
|
+
"""
|
138
|
+
return False
|
139
|
+
|
130
140
|
|
131
141
|
class BaseOrchestrator(StackComponent, ABC):
|
132
142
|
"""Base class for all orchestrators."""
|
@@ -346,14 +356,59 @@ class BaseOrchestrator(StackComponent, ABC):
|
|
346
356
|
|
347
357
|
Args:
|
348
358
|
step: The step to run.
|
359
|
+
|
360
|
+
Raises:
|
361
|
+
RunStoppedException: If the run was stopped.
|
362
|
+
BaseException: If the step failed all retries.
|
349
363
|
"""
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
364
|
+
|
365
|
+
def _launch_step() -> None:
|
366
|
+
assert self._active_deployment
|
367
|
+
|
368
|
+
launcher = StepLauncher(
|
369
|
+
deployment=self._active_deployment,
|
370
|
+
step=step,
|
371
|
+
orchestrator_run_id=self.get_orchestrator_run_id(),
|
372
|
+
)
|
373
|
+
launcher.launch()
|
374
|
+
|
375
|
+
if self.config.handles_step_retries:
|
376
|
+
_launch_step()
|
377
|
+
else:
|
378
|
+
# The orchestrator subclass doesn't handle step retries, so we
|
379
|
+
# handle it in-process instead
|
380
|
+
retries = 0
|
381
|
+
retry_config = step.config.retry
|
382
|
+
max_retries = retry_config.max_retries if retry_config else 0
|
383
|
+
delay = retry_config.delay if retry_config else 0
|
384
|
+
backoff = retry_config.backoff if retry_config else 1
|
385
|
+
|
386
|
+
while retries <= max_retries:
|
387
|
+
try:
|
388
|
+
_launch_step()
|
389
|
+
except RunStoppedException:
|
390
|
+
# Don't retry if the run was stopped
|
391
|
+
raise
|
392
|
+
except BaseException:
|
393
|
+
retries += 1
|
394
|
+
if retries <= max_retries:
|
395
|
+
logger.info(
|
396
|
+
"Sleeping for %d seconds before retrying step `%s`.",
|
397
|
+
delay,
|
398
|
+
step.config.name,
|
399
|
+
)
|
400
|
+
time.sleep(delay)
|
401
|
+
delay *= backoff
|
402
|
+
else:
|
403
|
+
if max_retries > 0:
|
404
|
+
logger.error(
|
405
|
+
"Failed to run step `%s` after %d retries.",
|
406
|
+
step.config.name,
|
407
|
+
max_retries,
|
408
|
+
)
|
409
|
+
raise
|
410
|
+
else:
|
411
|
+
break
|
357
412
|
|
358
413
|
@staticmethod
|
359
414
|
def requires_resources_in_orchestration_environment(
|
@@ -217,7 +217,9 @@ class ThreadedDagRunner:
|
|
217
217
|
self.node_states[node] = NodeStatus.PENDING
|
218
218
|
|
219
219
|
# Run node in new thread.
|
220
|
-
thread = threading.Thread(
|
220
|
+
thread = threading.Thread(
|
221
|
+
name=node, target=self._run_node, args=(node,)
|
222
|
+
)
|
221
223
|
thread.start()
|
222
224
|
return thread
|
223
225
|
|
@@ -190,7 +190,10 @@ def get_pipeline_run_status(
|
|
190
190
|
return ExecutionStatus.FAILED
|
191
191
|
|
192
192
|
# If there is a running step, the run is running
|
193
|
-
elif
|
193
|
+
elif (
|
194
|
+
ExecutionStatus.RUNNING in step_statuses
|
195
|
+
or ExecutionStatus.RETRYING in step_statuses
|
196
|
+
):
|
194
197
|
return ExecutionStatus.RUNNING
|
195
198
|
|
196
199
|
# If there are less steps than the total number of steps, it is running
|
@@ -13,7 +13,6 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""Class to launch (run directly or using a step operator) steps."""
|
15
15
|
|
16
|
-
import os
|
17
16
|
import signal
|
18
17
|
import time
|
19
18
|
from contextlib import nullcontext
|
@@ -25,7 +24,6 @@ from zenml.config.step_configurations import Step
|
|
25
24
|
from zenml.config.step_run_info import StepRunInfo
|
26
25
|
from zenml.constants import (
|
27
26
|
ENV_ZENML_DISABLE_STEP_LOGS_STORAGE,
|
28
|
-
ENV_ZENML_IGNORE_FAILURE_HOOK,
|
29
27
|
handle_bool_env_var,
|
30
28
|
)
|
31
29
|
from zenml.enums import ExecutionStatus
|
@@ -245,141 +243,93 @@ class StepLauncher:
|
|
245
243
|
artifact_store_id=self._stack.artifact_store.id,
|
246
244
|
)
|
247
245
|
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
246
|
+
with logs_context:
|
247
|
+
if run_was_created:
|
248
|
+
pipeline_run_metadata = self._stack.get_pipeline_run_metadata(
|
249
|
+
run_id=pipeline_run.id
|
250
|
+
)
|
251
|
+
publish_utils.publish_pipeline_run_metadata(
|
252
|
+
pipeline_run_id=pipeline_run.id,
|
253
|
+
pipeline_run_metadata=pipeline_run_metadata,
|
254
|
+
)
|
255
|
+
if model_version := pipeline_run.model_version:
|
256
|
+
step_run_utils.log_model_version_dashboard_url(
|
257
|
+
model_version=model_version
|
255
258
|
)
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
+
|
260
|
+
request_factory = step_run_utils.StepRunRequestFactory(
|
261
|
+
deployment=self._deployment,
|
262
|
+
pipeline_run=pipeline_run,
|
263
|
+
stack=self._stack,
|
264
|
+
)
|
265
|
+
step_run_request = request_factory.create_request(
|
266
|
+
invocation_id=self._step_name
|
267
|
+
)
|
268
|
+
step_run_request.logs = logs_model
|
269
|
+
|
270
|
+
try:
|
271
|
+
request_factory.populate_request(request=step_run_request)
|
272
|
+
except:
|
273
|
+
logger.exception(f"Failed preparing step `{self._step_name}`.")
|
274
|
+
step_run_request.status = ExecutionStatus.FAILED
|
275
|
+
step_run_request.end_time = utc_now()
|
276
|
+
raise
|
277
|
+
finally:
|
278
|
+
step_run = Client().zen_store.create_run_step(step_run_request)
|
279
|
+
self._step_run = step_run
|
280
|
+
if model_version := step_run.model_version:
|
281
|
+
step_run_utils.log_model_version_dashboard_url(
|
282
|
+
model_version=model_version
|
259
283
|
)
|
260
|
-
if model_version := pipeline_run.model_version:
|
261
|
-
step_run_utils.log_model_version_dashboard_url(
|
262
|
-
model_version=model_version
|
263
|
-
)
|
264
284
|
|
265
|
-
|
266
|
-
|
267
|
-
pipeline_run=pipeline_run,
|
268
|
-
stack=self._stack,
|
269
|
-
)
|
270
|
-
step_run_request = request_factory.create_request(
|
271
|
-
invocation_id=self._step_name
|
272
|
-
)
|
273
|
-
step_run_request.logs = logs_model
|
285
|
+
if not step_run.status.is_finished:
|
286
|
+
logger.info(f"Step `{self._step_name}` has started.")
|
274
287
|
|
275
288
|
try:
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
step_run_request
|
287
|
-
)
|
288
|
-
# Store step run ID for signal handler
|
289
|
-
self._step_run = step_run
|
290
|
-
if model_version := step_run.model_version:
|
291
|
-
step_run_utils.log_model_version_dashboard_url(
|
292
|
-
model_version=model_version
|
289
|
+
# here pass a forced save_to_file callable to be
|
290
|
+
# used as a dump function to use before starting
|
291
|
+
# the external jobs in step operators
|
292
|
+
if isinstance(
|
293
|
+
logs_context,
|
294
|
+
step_logging.PipelineLogsStorageContext,
|
295
|
+
):
|
296
|
+
force_write_logs = partial(
|
297
|
+
logs_context.storage.save_to_file,
|
298
|
+
force=True,
|
293
299
|
)
|
300
|
+
else:
|
294
301
|
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
)
|
304
|
-
delay = (
|
305
|
-
step_run.config.retry.delay
|
306
|
-
if step_run.config.retry
|
307
|
-
else 0
|
302
|
+
def _bypass() -> None:
|
303
|
+
return None
|
304
|
+
|
305
|
+
force_write_logs = _bypass
|
306
|
+
self._run_step(
|
307
|
+
pipeline_run=pipeline_run,
|
308
|
+
step_run=step_run,
|
309
|
+
force_write_logs=force_write_logs,
|
308
310
|
)
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
311
|
+
except RunStoppedException as e:
|
312
|
+
raise e
|
313
|
+
except BaseException as e: # noqa: E722
|
314
|
+
logger.error(
|
315
|
+
"Failed to run step `%s`: %s",
|
316
|
+
self._step_name,
|
317
|
+
e,
|
313
318
|
)
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
force=True,
|
328
|
-
)
|
329
|
-
else:
|
330
|
-
|
331
|
-
def _bypass() -> None:
|
332
|
-
return None
|
333
|
-
|
334
|
-
force_write_logs = _bypass
|
335
|
-
self._run_step(
|
336
|
-
pipeline_run=pipeline_run,
|
337
|
-
step_run=step_run,
|
338
|
-
last_retry=last_retry,
|
339
|
-
force_write_logs=force_write_logs,
|
340
|
-
)
|
341
|
-
break
|
342
|
-
except RunStoppedException as e:
|
343
|
-
raise e
|
344
|
-
except BaseException as e: # noqa: E722
|
345
|
-
retries += 1
|
346
|
-
if retries < max_retries:
|
347
|
-
logger.error(
|
348
|
-
f"Failed to run step `{self._step_name}`. Retrying..."
|
349
|
-
)
|
350
|
-
logger.exception(e)
|
351
|
-
logger.info(
|
352
|
-
f"Sleeping for {delay} seconds before retrying."
|
353
|
-
)
|
354
|
-
time.sleep(delay)
|
355
|
-
delay *= backoff
|
356
|
-
else:
|
357
|
-
logger.error(
|
358
|
-
f"Failed to run step `{self._step_name}` after {max_retries} retries. Exiting."
|
359
|
-
)
|
360
|
-
logger.exception(e)
|
361
|
-
publish_utils.publish_failed_step_run(
|
362
|
-
step_run.id
|
363
|
-
)
|
364
|
-
raise
|
365
|
-
else:
|
366
|
-
logger.info(
|
367
|
-
f"Using cached version of step `{self._step_name}`."
|
319
|
+
publish_utils.publish_failed_step_run(step_run.id)
|
320
|
+
raise
|
321
|
+
else:
|
322
|
+
logger.info(
|
323
|
+
f"Using cached version of step `{self._step_name}`."
|
324
|
+
)
|
325
|
+
if (
|
326
|
+
model_version := step_run.model_version
|
327
|
+
or pipeline_run.model_version
|
328
|
+
):
|
329
|
+
step_run_utils.link_output_artifacts_to_model_version(
|
330
|
+
artifacts=step_run.outputs,
|
331
|
+
model_version=model_version,
|
368
332
|
)
|
369
|
-
if (
|
370
|
-
model_version := step_run.model_version
|
371
|
-
or pipeline_run.model_version
|
372
|
-
):
|
373
|
-
step_run_utils.link_output_artifacts_to_model_version(
|
374
|
-
artifacts=step_run.outputs,
|
375
|
-
model_version=model_version,
|
376
|
-
)
|
377
|
-
except RunStoppedException:
|
378
|
-
logger.info(f"Pipeline run `{pipeline_run.name}` stopped.")
|
379
|
-
raise
|
380
|
-
except: # noqa: E722
|
381
|
-
logger.error(f"Pipeline run `{pipeline_run.name}` failed.")
|
382
|
-
raise
|
383
333
|
|
384
334
|
def _create_or_reuse_run(self) -> Tuple[PipelineRunResponse, bool]:
|
385
335
|
"""Creates a pipeline run or reuses an existing one.
|
@@ -421,7 +371,6 @@ class StepLauncher:
|
|
421
371
|
pipeline_run: PipelineRunResponse,
|
422
372
|
step_run: StepRunResponse,
|
423
373
|
force_write_logs: Callable[..., Any],
|
424
|
-
last_retry: bool = True,
|
425
374
|
) -> None:
|
426
375
|
"""Runs the current step.
|
427
376
|
|
@@ -429,7 +378,6 @@ class StepLauncher:
|
|
429
378
|
pipeline_run: The model of the current pipeline run.
|
430
379
|
step_run: The model of the current step run.
|
431
380
|
force_write_logs: The context for the step logs.
|
432
|
-
last_retry: Whether this is the last retry of the step.
|
433
381
|
"""
|
434
382
|
# Prepare step run information.
|
435
383
|
step_run_info = StepRunInfo(
|
@@ -457,7 +405,6 @@ class StepLauncher:
|
|
457
405
|
self._run_step_with_step_operator(
|
458
406
|
step_operator_name=step_operator_name,
|
459
407
|
step_run_info=step_run_info,
|
460
|
-
last_retry=last_retry,
|
461
408
|
)
|
462
409
|
else:
|
463
410
|
self._run_step_without_step_operator(
|
@@ -466,7 +413,6 @@ class StepLauncher:
|
|
466
413
|
step_run_info=step_run_info,
|
467
414
|
input_artifacts=step_run.regular_inputs,
|
468
415
|
output_artifact_uris=output_artifact_uris,
|
469
|
-
last_retry=last_retry,
|
470
416
|
)
|
471
417
|
except: # noqa: E722
|
472
418
|
output_utils.remove_artifact_dirs(
|
@@ -484,14 +430,12 @@ class StepLauncher:
|
|
484
430
|
self,
|
485
431
|
step_operator_name: Optional[str],
|
486
432
|
step_run_info: StepRunInfo,
|
487
|
-
last_retry: bool,
|
488
433
|
) -> None:
|
489
434
|
"""Runs the current step with a step operator.
|
490
435
|
|
491
436
|
Args:
|
492
437
|
step_operator_name: The name of the step operator to use.
|
493
438
|
step_run_info: Additional information needed to run the step.
|
494
|
-
last_retry: Whether this is the last retry of the step.
|
495
439
|
"""
|
496
440
|
step_operator = _get_step_operator(
|
497
441
|
stack=self._stack,
|
@@ -509,8 +453,6 @@ class StepLauncher:
|
|
509
453
|
environment = orchestrator_utils.get_config_environment_vars(
|
510
454
|
pipeline_run_id=step_run_info.run_id,
|
511
455
|
)
|
512
|
-
if last_retry:
|
513
|
-
environment[ENV_ZENML_IGNORE_FAILURE_HOOK] = str(False)
|
514
456
|
logger.info(
|
515
457
|
"Using step operator `%s` to run step `%s`.",
|
516
458
|
step_operator.name,
|
@@ -529,7 +471,6 @@ class StepLauncher:
|
|
529
471
|
step_run_info: StepRunInfo,
|
530
472
|
input_artifacts: Dict[str, StepRunInputResponse],
|
531
473
|
output_artifact_uris: Dict[str, str],
|
532
|
-
last_retry: bool,
|
533
474
|
) -> None:
|
534
475
|
"""Runs the current step without a step operator.
|
535
476
|
|
@@ -539,10 +480,7 @@ class StepLauncher:
|
|
539
480
|
step_run_info: Additional information needed to run the step.
|
540
481
|
input_artifacts: The input artifact versions of the current step.
|
541
482
|
output_artifact_uris: The output artifact URIs of the current step.
|
542
|
-
last_retry: Whether this is the last retry of the step.
|
543
483
|
"""
|
544
|
-
if last_retry:
|
545
|
-
os.environ[ENV_ZENML_IGNORE_FAILURE_HOOK] = "false"
|
546
484
|
runner = StepRunner(step=self._step, stack=self._stack)
|
547
485
|
runner.run(
|
548
486
|
pipeline_run=pipeline_run,
|
@@ -59,6 +59,22 @@ class StepRunRequestFactory:
|
|
59
59
|
self.pipeline_run = pipeline_run
|
60
60
|
self.stack = stack
|
61
61
|
|
62
|
+
def has_caching_enabled(self, invocation_id: str) -> bool:
|
63
|
+
"""Check if the step has caching enabled.
|
64
|
+
|
65
|
+
Args:
|
66
|
+
invocation_id: The invocation ID for which to check if caching is
|
67
|
+
enabled.
|
68
|
+
|
69
|
+
Returns:
|
70
|
+
Whether the step has caching enabled.
|
71
|
+
"""
|
72
|
+
step = self.deployment.step_configurations[invocation_id]
|
73
|
+
return utils.is_setting_enabled(
|
74
|
+
is_enabled_on_step=step.config.enable_cache,
|
75
|
+
is_enabled_on_pipeline=self.deployment.pipeline_configuration.enable_cache,
|
76
|
+
)
|
77
|
+
|
62
78
|
def create_request(self, invocation_id: str) -> StepRunRequest:
|
63
79
|
"""Create a step run request.
|
64
80
|
|
@@ -34,7 +34,6 @@ from zenml.config.step_configurations import StepConfiguration
|
|
34
34
|
from zenml.config.step_run_info import StepRunInfo
|
35
35
|
from zenml.constants import (
|
36
36
|
ENV_ZENML_DISABLE_STEP_LOGS_STORAGE,
|
37
|
-
ENV_ZENML_IGNORE_FAILURE_HOOK,
|
38
37
|
handle_bool_env_var,
|
39
38
|
)
|
40
39
|
from zenml.enums import ArtifactSaveType
|
@@ -194,9 +193,7 @@ class StepRunner:
|
|
194
193
|
)
|
195
194
|
except BaseException as step_exception: # noqa: E722
|
196
195
|
step_failed = True
|
197
|
-
if not
|
198
|
-
ENV_ZENML_IGNORE_FAILURE_HOOK, False
|
199
|
-
):
|
196
|
+
if not step_run.is_retriable:
|
200
197
|
if (
|
201
198
|
failure_hook_source
|
202
199
|
:= self.configuration.failure_hook_source
|
@@ -29,6 +29,7 @@ from zenml.logger import get_logger
|
|
29
29
|
|
30
30
|
if TYPE_CHECKING:
|
31
31
|
from zenml.config.base_settings import SettingsOrDict
|
32
|
+
from zenml.config.retry_config import StepRetryConfig
|
32
33
|
from zenml.model.model import Model
|
33
34
|
from zenml.pipelines.pipeline_definition import Pipeline
|
34
35
|
from zenml.types import HookSpecification
|
@@ -57,6 +58,7 @@ def pipeline(
|
|
57
58
|
on_failure: Optional["HookSpecification"] = None,
|
58
59
|
on_success: Optional["HookSpecification"] = None,
|
59
60
|
model: Optional["Model"] = None,
|
61
|
+
retry: Optional["StepRetryConfig"] = None,
|
60
62
|
substitutions: Optional[Dict[str, str]] = None,
|
61
63
|
) -> Callable[["F"], "Pipeline"]: ...
|
62
64
|
|
@@ -75,6 +77,7 @@ def pipeline(
|
|
75
77
|
on_failure: Optional["HookSpecification"] = None,
|
76
78
|
on_success: Optional["HookSpecification"] = None,
|
77
79
|
model: Optional["Model"] = None,
|
80
|
+
retry: Optional["StepRetryConfig"] = None,
|
78
81
|
substitutions: Optional[Dict[str, str]] = None,
|
79
82
|
) -> Union["Pipeline", Callable[["F"], "Pipeline"]]:
|
80
83
|
"""Decorator to create a pipeline.
|
@@ -97,6 +100,7 @@ def pipeline(
|
|
97
100
|
function with no arguments, or a source path to such a function
|
98
101
|
(e.g. `module.my_function`).
|
99
102
|
model: configuration of the model in the Model Control Plane.
|
103
|
+
retry: Retry configuration for the pipeline steps.
|
100
104
|
substitutions: Extra placeholders to use in the name templates.
|
101
105
|
|
102
106
|
Returns:
|
@@ -108,6 +112,7 @@ def pipeline(
|
|
108
112
|
|
109
113
|
p = Pipeline(
|
110
114
|
name=name or func.__name__,
|
115
|
+
entrypoint=func,
|
111
116
|
enable_cache=enable_cache,
|
112
117
|
enable_artifact_metadata=enable_artifact_metadata,
|
113
118
|
enable_step_logs=enable_step_logs,
|
@@ -118,7 +123,7 @@ def pipeline(
|
|
118
123
|
on_failure=on_failure,
|
119
124
|
on_success=on_success,
|
120
125
|
model=model,
|
121
|
-
|
126
|
+
retry=retry,
|
122
127
|
substitutions=substitutions,
|
123
128
|
)
|
124
129
|
|