zenml-nightly 0.83.1.dev20250708__py3-none-any.whl → 0.83.1.dev20250710__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. zenml/VERSION +1 -1
  2. zenml/cli/login.py +141 -18
  3. zenml/cli/project.py +8 -6
  4. zenml/cli/utils.py +63 -16
  5. zenml/client.py +4 -1
  6. zenml/config/compiler.py +1 -0
  7. zenml/config/retry_config.py +5 -3
  8. zenml/config/step_configurations.py +7 -1
  9. zenml/console.py +4 -1
  10. zenml/constants.py +0 -1
  11. zenml/enums.py +13 -4
  12. zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py +58 -4
  13. zenml/integrations/kubernetes/orchestrators/kube_utils.py +172 -0
  14. zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +37 -23
  15. zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +92 -22
  16. zenml/integrations/kubernetes/orchestrators/manifest_utils.py +59 -0
  17. zenml/logger.py +6 -4
  18. zenml/login/web_login.py +13 -6
  19. zenml/models/v2/core/model_version.py +9 -1
  20. zenml/models/v2/core/pipeline_run.py +1 -0
  21. zenml/models/v2/core/step_run.py +35 -1
  22. zenml/orchestrators/base_orchestrator.py +63 -8
  23. zenml/orchestrators/dag_runner.py +3 -1
  24. zenml/orchestrators/publish_utils.py +4 -1
  25. zenml/orchestrators/step_launcher.py +77 -139
  26. zenml/orchestrators/step_run_utils.py +16 -0
  27. zenml/orchestrators/step_runner.py +1 -4
  28. zenml/pipelines/pipeline_decorator.py +6 -1
  29. zenml/pipelines/pipeline_definition.py +7 -0
  30. zenml/zen_server/auth.py +0 -1
  31. zenml/zen_stores/migrations/versions/360fa84718bf_step_run_versioning.py +64 -0
  32. zenml/zen_stores/migrations/versions/85289fea86ff_adding_source_to_logs.py +1 -1
  33. zenml/zen_stores/schemas/pipeline_deployment_schemas.py +21 -0
  34. zenml/zen_stores/schemas/pipeline_run_schemas.py +31 -2
  35. zenml/zen_stores/schemas/step_run_schemas.py +41 -17
  36. zenml/zen_stores/sql_zen_store.py +152 -32
  37. zenml/zen_stores/template_utils.py +29 -9
  38. zenml_nightly-0.83.1.dev20250710.dist-info/METADATA +499 -0
  39. {zenml_nightly-0.83.1.dev20250708.dist-info → zenml_nightly-0.83.1.dev20250710.dist-info}/RECORD +42 -41
  40. zenml_nightly-0.83.1.dev20250708.dist-info/METADATA +0 -538
  41. {zenml_nightly-0.83.1.dev20250708.dist-info → zenml_nightly-0.83.1.dev20250710.dist-info}/LICENSE +0 -0
  42. {zenml_nightly-0.83.1.dev20250708.dist-info → zenml_nightly-0.83.1.dev20250710.dist-info}/WHEEL +0 -0
  43. {zenml_nightly-0.83.1.dev20250708.dist-info → zenml_nightly-0.83.1.dev20250710.dist-info}/entry_points.txt +0 -0
zenml/login/web_login.py CHANGED
@@ -21,6 +21,7 @@ from typing import Optional, Union
 import requests
 
 from zenml import __version__
+from zenml.cli import utils as cli_utils
 from zenml.config.global_config import GlobalConfiguration
 from zenml.constants import (
     API,
@@ -175,11 +176,17 @@ def web_login(
     # URL to it
     verification_uri = base_url + verification_uri
     webbrowser.open(verification_uri)
-    logger.info(
-        f"If your browser did not open automatically, please open the "
-        f"following URL into your browser to proceed with the authentication:"
-        f"\n\n{verification_uri}\n"
+
+    # Display the verification URL without panel styling
+    from zenml.console import console
+
+    console.print()
+    console.print(
+        "If your browser did not open automatically, please open the following URL into your browser to proceed with the authentication:",
+        style="white",
     )
+    console.print(verification_uri, style="bright_blue underline")
+    console.print()
 
     # Poll the OAuth2 server until the user has authorized the device
     token_request = OAuthDeviceTokenRequest(
@@ -201,9 +208,9 @@ def web_login(
             # The user has authorized the device, so we can extract the access token
             token_response = OAuthTokenResponse(**response.json())
             if zenml_pro:
-                logger.info("Successfully logged in to ZenML Pro.")
+                cli_utils.success("Successfully logged in to ZenML Pro.")
             else:
-                logger.info(f"Successfully logged in to {url}.")
+                cli_utils.success(f"Successfully logged in to {url}.")
             break
         elif response.status_code == 400:
             try:
zenml/models/v2/core/model_version.py CHANGED
@@ -13,6 +13,7 @@
 # permissions and limitations under the License.
 """Models representing model versions."""
 
+import json
 from typing import (
     TYPE_CHECKING,
     ClassVar,
@@ -411,10 +412,17 @@ class ModelVersionResponse(
 
         from zenml.client import Client
 
+        data_artifact_types = [
+            value
+            for value in ArtifactType.values()
+            if value
+            not in [ArtifactType.MODEL.value, ArtifactType.SERVICE.value]
+        ]
+
         artifact_versions = pagination_utils.depaginate(
             Client().list_artifact_versions,
             model_version_id=self.id,
-            type=ArtifactType.DATA,
+            type="oneof:" + json.dumps(data_artifact_types),
             project=self.project_id,
         )
 
zenml/models/v2/core/pipeline_run.py CHANGED
@@ -497,6 +497,7 @@ class PipelineRunResponse(
                 Client().list_run_steps,
                 pipeline_run_id=self.id,
                 project=self.project_id,
+                exclude_retried=True,
             )
         }
 
zenml/models/v2/core/step_run.py CHANGED
@@ -177,6 +177,12 @@ class StepRunResponseBody(ProjectScopedResponseBody):
     """Response body for step runs."""
 
     status: ExecutionStatus = Field(title="The status of the step.")
+    version: int = Field(
+        title="The version of the step run.",
+    )
+    is_retriable: bool = Field(
+        title="Whether the step run is retriable.",
+    )
     start_time: Optional[datetime] = Field(
         title="The start time of the step run.",
         default=None,
@@ -420,6 +426,24 @@ class StepRunResponse(
         """
         return self.get_body().status
 
+    @property
+    def version(self) -> int:
+        """The `version` property.
+
+        Returns:
+            the value of the property.
+        """
+        return self.get_body().version
+
+    @property
+    def is_retriable(self) -> bool:
+        """The `is_retriable` property.
+
+        Returns:
+            the value of the property.
+        """
+        return self.get_body().is_retriable
+
     @property
     def inputs(self) -> Dict[str, List[StepRunInputResponse]]:
         """The `inputs` property.
@@ -602,6 +626,7 @@ class StepRunFilter(ProjectScopedFilter, RunMetadataFilterMixin):
         *ProjectScopedFilter.FILTER_EXCLUDE_FIELDS,
         *RunMetadataFilterMixin.FILTER_EXCLUDE_FIELDS,
         "model",
+        "exclude_retried",
     ]
     CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [
         *ProjectScopedFilter.CLI_EXCLUDE_FIELDS,
@@ -666,6 +691,10 @@ class StepRunFilter(ProjectScopedFilter, RunMetadataFilterMixin):
         default=None,
         description="Name/ID of the model associated with the step run.",
     )
+    exclude_retried: Optional[bool] = Field(
+        default=None,
+        description="Whether to exclude retried step runs.",
+    )
     model_config = ConfigDict(protected_namespaces=())
 
     def get_custom_filters(
@@ -681,7 +710,7 @@ class StepRunFilter(ProjectScopedFilter, RunMetadataFilterMixin):
         """
         custom_filters = super().get_custom_filters(table)
 
-        from sqlmodel import and_
+        from sqlmodel import and_, col
 
         from zenml.zen_stores.schemas import (
             ModelSchema,
@@ -699,4 +728,9 @@ class StepRunFilter(ProjectScopedFilter, RunMetadataFilterMixin):
             )
             custom_filters.append(model_filter)
 
+        if self.exclude_retried:
+            custom_filters.append(
+                col(StepRunSchema.status) != ExecutionStatus.RETRIED.value
+            )
+
         return custom_filters
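The new `version` and `is_retriable` fields are plain response-body properties, so they can be read off any fetched step run. A small hedged sketch, assuming an existing run identifier:

```python
from zenml.client import Client

run = Client().get_pipeline_run("<run-name-or-id>")  # placeholder identifier

for name, step_run in run.steps.items():
    # `version` counts re-executions of the same invocation;
    # `is_retriable` says whether another retry may still happen.
    print(name, step_run.status, step_run.version, step_run.is_retriable)
```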
zenml/orchestrators/base_orchestrator.py CHANGED
@@ -13,6 +13,7 @@
 # permissions and limitations under the License.
 """Base orchestrator class."""
 
+import time
 from abc import ABC, abstractmethod
 from typing import (
     TYPE_CHECKING,
@@ -33,7 +34,7 @@ from zenml.constants import (
     handle_bool_env_var,
 )
 from zenml.enums import ExecutionStatus, StackComponentType
-from zenml.exceptions import RunMonitoringError
+from zenml.exceptions import RunMonitoringError, RunStoppedException
 from zenml.logger import get_logger
 from zenml.metadata.metadata_types import MetadataType
 from zenml.orchestrators.publish_utils import (
@@ -127,6 +128,15 @@ class BaseOrchestratorConfig(StackComponentConfig):
         """
         return True
 
+    @property
+    def handles_step_retries(self) -> bool:
+        """Whether the orchestrator handles step retries.
+
+        Returns:
+            Whether the orchestrator handles step retries.
+        """
+        return False
+
 
 class BaseOrchestrator(StackComponent, ABC):
     """Base class for all orchestrators."""
@@ -346,14 +356,59 @@ class BaseOrchestrator(StackComponent, ABC):
 
         Args:
             step: The step to run.
+
+        Raises:
+            RunStoppedException: If the run was stopped.
+            BaseException: If the step failed all retries.
         """
-        assert self._active_deployment
-        launcher = StepLauncher(
-            deployment=self._active_deployment,
-            step=step,
-            orchestrator_run_id=self.get_orchestrator_run_id(),
-        )
-        launcher.launch()
+
+        def _launch_step() -> None:
+            assert self._active_deployment
+
+            launcher = StepLauncher(
+                deployment=self._active_deployment,
+                step=step,
+                orchestrator_run_id=self.get_orchestrator_run_id(),
+            )
+            launcher.launch()
+
+        if self.config.handles_step_retries:
+            _launch_step()
+        else:
+            # The orchestrator subclass doesn't handle step retries, so we
+            # handle it in-process instead
+            retries = 0
+            retry_config = step.config.retry
+            max_retries = retry_config.max_retries if retry_config else 0
+            delay = retry_config.delay if retry_config else 0
+            backoff = retry_config.backoff if retry_config else 1
+
+            while retries <= max_retries:
+                try:
+                    _launch_step()
+                except RunStoppedException:
+                    # Don't retry if the run was stopped
                    raise
+                except BaseException:
+                    retries += 1
+                    if retries <= max_retries:
+                        logger.info(
+                            "Sleeping for %d seconds before retrying step `%s`.",
+                            delay,
+                            step.config.name,
+                        )
+                        time.sleep(delay)
+                        delay *= backoff
+                    else:
+                        if max_retries > 0:
+                            logger.error(
+                                "Failed to run step `%s` after %d retries.",
+                                step.config.name,
+                                max_retries,
+                            )
+                        raise
+                else:
+                    break
 
     @staticmethod
     def requires_resources_in_orchestration_environment(
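With `handles_step_retries` defaulting to `False`, the retry loop above runs in-process; an orchestrator whose backend already re-schedules failed steps can opt out by overriding the config property. A sketch with a made-up flavor config class:

```python
from zenml.orchestrators.base_orchestrator import BaseOrchestratorConfig


class MyRetryingOrchestratorConfig(BaseOrchestratorConfig):
    """Hypothetical config for an orchestrator whose backend retries steps itself."""

    @property
    def handles_step_retries(self) -> bool:
        # Tells the base orchestrator to call the step launcher exactly once
        # and leave any retrying to the orchestration backend.
        return True
```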
zenml/orchestrators/dag_runner.py CHANGED
@@ -217,7 +217,9 @@ class ThreadedDagRunner:
             self.node_states[node] = NodeStatus.PENDING
 
         # Run node in new thread.
-        thread = threading.Thread(target=self._run_node, args=(node,))
+        thread = threading.Thread(
+            name=node, target=self._run_node, args=(node,)
+        )
         thread.start()
         return thread
 
zenml/orchestrators/publish_utils.py CHANGED
@@ -190,7 +190,10 @@ def get_pipeline_run_status(
         return ExecutionStatus.FAILED
 
     # If there is a running step, the run is running
-    elif ExecutionStatus.RUNNING in step_statuses:
+    elif (
+        ExecutionStatus.RUNNING in step_statuses
+        or ExecutionStatus.RETRYING in step_statuses
+    ):
         return ExecutionStatus.RUNNING
 
     # If there are less steps than the total number of steps, it is running
zenml/orchestrators/step_launcher.py CHANGED
@@ -13,7 +13,6 @@
 # permissions and limitations under the License.
 """Class to launch (run directly or using a step operator) steps."""
 
-import os
 import signal
 import time
 from contextlib import nullcontext
@@ -25,7 +24,6 @@ from zenml.config.step_configurations import Step
 from zenml.config.step_run_info import StepRunInfo
 from zenml.constants import (
     ENV_ZENML_DISABLE_STEP_LOGS_STORAGE,
-    ENV_ZENML_IGNORE_FAILURE_HOOK,
     handle_bool_env_var,
 )
 from zenml.enums import ExecutionStatus
@@ -245,141 +243,93 @@ class StepLauncher:
                 artifact_store_id=self._stack.artifact_store.id,
             )
 
-        try:
-            with logs_context:
-                if run_was_created:
-                    pipeline_run_metadata = (
-                        self._stack.get_pipeline_run_metadata(
-                            run_id=pipeline_run.id
-                        )
+        with logs_context:
+            if run_was_created:
+                pipeline_run_metadata = self._stack.get_pipeline_run_metadata(
+                    run_id=pipeline_run.id
+                )
+                publish_utils.publish_pipeline_run_metadata(
+                    pipeline_run_id=pipeline_run.id,
+                    pipeline_run_metadata=pipeline_run_metadata,
+                )
+                if model_version := pipeline_run.model_version:
+                    step_run_utils.log_model_version_dashboard_url(
+                        model_version=model_version
                     )
-                    publish_utils.publish_pipeline_run_metadata(
-                        pipeline_run_id=pipeline_run.id,
-                        pipeline_run_metadata=pipeline_run_metadata,
+
+            request_factory = step_run_utils.StepRunRequestFactory(
+                deployment=self._deployment,
+                pipeline_run=pipeline_run,
+                stack=self._stack,
+            )
+            step_run_request = request_factory.create_request(
+                invocation_id=self._step_name
+            )
+            step_run_request.logs = logs_model
+
+            try:
+                request_factory.populate_request(request=step_run_request)
+            except:
+                logger.exception(f"Failed preparing step `{self._step_name}`.")
+                step_run_request.status = ExecutionStatus.FAILED
+                step_run_request.end_time = utc_now()
+                raise
+            finally:
+                step_run = Client().zen_store.create_run_step(step_run_request)
+                self._step_run = step_run
+                if model_version := step_run.model_version:
+                    step_run_utils.log_model_version_dashboard_url(
+                        model_version=model_version
                    )
-                    if model_version := pipeline_run.model_version:
-                        step_run_utils.log_model_version_dashboard_url(
-                            model_version=model_version
-                        )
 
-                request_factory = step_run_utils.StepRunRequestFactory(
-                    deployment=self._deployment,
-                    pipeline_run=pipeline_run,
-                    stack=self._stack,
-                )
-                step_run_request = request_factory.create_request(
-                    invocation_id=self._step_name
-                )
-                step_run_request.logs = logs_model
+            if not step_run.status.is_finished:
+                logger.info(f"Step `{self._step_name}` has started.")
 
                 try:
-                    request_factory.populate_request(request=step_run_request)
-                except:
-                    logger.exception(
-                        f"Failed preparing step `{self._step_name}`."
-                    )
-                    step_run_request.status = ExecutionStatus.FAILED
-                    step_run_request.end_time = utc_now()
-                    raise
-                finally:
-                    step_run = Client().zen_store.create_run_step(
-                        step_run_request
-                    )
-                    # Store step run ID for signal handler
-                    self._step_run = step_run
-                    if model_version := step_run.model_version:
-                        step_run_utils.log_model_version_dashboard_url(
-                            model_version=model_version
+                    # here pass a forced save_to_file callable to be
+                    # used as a dump function to use before starting
+                    # the external jobs in step operators
+                    if isinstance(
+                        logs_context,
+                        step_logging.PipelineLogsStorageContext,
+                    ):
+                        force_write_logs = partial(
+                            logs_context.storage.save_to_file,
+                            force=True,
                        )
+                    else:
 
-                if not step_run.status.is_finished:
-                    logger.info(f"Step `{self._step_name}` has started.")
-                    retries = 0
-                    last_retry = True
-                    max_retries = (
-                        step_run.config.retry.max_retries
-                        if step_run.config.retry
-                        else 1
-                    )
-                    delay = (
-                        step_run.config.retry.delay
-                        if step_run.config.retry
-                        else 0
+                        def _bypass() -> None:
+                            return None
+
+                        force_write_logs = _bypass
+                    self._run_step(
+                        pipeline_run=pipeline_run,
+                        step_run=step_run,
+                        force_write_logs=force_write_logs,
                    )
-                    backoff = (
-                        step_run.config.retry.backoff
-                        if step_run.config.retry
-                        else 1
+                except RunStoppedException as e:
+                    raise e
+                except BaseException as e:  # noqa: E722
+                    logger.error(
+                        "Failed to run step `%s`: %s",
+                        self._step_name,
+                        e,
                    )
-
-                    while retries < max_retries:
-                        last_retry = retries == max_retries - 1
-                        try:
-                            # here pass a forced save_to_file callable to be
-                            # used as a dump function to use before starting
-                            # the external jobs in step operators
-                            if isinstance(
-                                logs_context,
-                                step_logging.PipelineLogsStorageContext,
-                            ):
-                                force_write_logs = partial(
-                                    logs_context.storage.save_to_file,
-                                    force=True,
-                                )
-                            else:
-
-                                def _bypass() -> None:
-                                    return None
-
-                                force_write_logs = _bypass
-                            self._run_step(
-                                pipeline_run=pipeline_run,
-                                step_run=step_run,
-                                last_retry=last_retry,
-                                force_write_logs=force_write_logs,
-                            )
-                            break
-                        except RunStoppedException as e:
-                            raise e
-                        except BaseException as e:  # noqa: E722
-                            retries += 1
-                            if retries < max_retries:
-                                logger.error(
-                                    f"Failed to run step `{self._step_name}`. Retrying..."
-                                )
-                                logger.exception(e)
-                                logger.info(
-                                    f"Sleeping for {delay} seconds before retrying."
-                                )
-                                time.sleep(delay)
-                                delay *= backoff
-                            else:
-                                logger.error(
-                                    f"Failed to run step `{self._step_name}` after {max_retries} retries. Exiting."
-                                )
-                                logger.exception(e)
-                                publish_utils.publish_failed_step_run(
-                                    step_run.id
-                                )
-                                raise
-                else:
-                    logger.info(
-                        f"Using cached version of step `{self._step_name}`."
+                    publish_utils.publish_failed_step_run(step_run.id)
+                    raise
+            else:
+                logger.info(
+                    f"Using cached version of step `{self._step_name}`."
+                )
+                if (
+                    model_version := step_run.model_version
+                    or pipeline_run.model_version
+                ):
+                    step_run_utils.link_output_artifacts_to_model_version(
+                        artifacts=step_run.outputs,
+                        model_version=model_version,
                    )
-                    if (
-                        model_version := step_run.model_version
-                        or pipeline_run.model_version
-                    ):
-                        step_run_utils.link_output_artifacts_to_model_version(
-                            artifacts=step_run.outputs,
-                            model_version=model_version,
-                        )
-        except RunStoppedException:
-            logger.info(f"Pipeline run `{pipeline_run.name}` stopped.")
-            raise
-        except:  # noqa: E722
-            logger.error(f"Pipeline run `{pipeline_run.name}` failed.")
-            raise
 
     def _create_or_reuse_run(self) -> Tuple[PipelineRunResponse, bool]:
         """Creates a pipeline run or reuses an existing one.
@@ -421,7 +371,6 @@
         pipeline_run: PipelineRunResponse,
         step_run: StepRunResponse,
         force_write_logs: Callable[..., Any],
-        last_retry: bool = True,
     ) -> None:
         """Runs the current step.
 
@@ -429,7 +378,6 @@
             pipeline_run: The model of the current pipeline run.
             step_run: The model of the current step run.
             force_write_logs: The context for the step logs.
-            last_retry: Whether this is the last retry of the step.
         """
         # Prepare step run information.
         step_run_info = StepRunInfo(
@@ -457,7 +405,6 @@
                 self._run_step_with_step_operator(
                     step_operator_name=step_operator_name,
                     step_run_info=step_run_info,
-                    last_retry=last_retry,
                 )
             else:
                 self._run_step_without_step_operator(
@@ -466,7 +413,6 @@
                     step_run_info=step_run_info,
                     input_artifacts=step_run.regular_inputs,
                     output_artifact_uris=output_artifact_uris,
-                    last_retry=last_retry,
                 )
         except:  # noqa: E722
            output_utils.remove_artifact_dirs(
@@ -484,14 +430,12 @@
         self,
         step_operator_name: Optional[str],
         step_run_info: StepRunInfo,
-        last_retry: bool,
     ) -> None:
         """Runs the current step with a step operator.
 
         Args:
             step_operator_name: The name of the step operator to use.
             step_run_info: Additional information needed to run the step.
-            last_retry: Whether this is the last retry of the step.
         """
         step_operator = _get_step_operator(
             stack=self._stack,
@@ -509,8 +453,6 @@
         environment = orchestrator_utils.get_config_environment_vars(
             pipeline_run_id=step_run_info.run_id,
         )
-        if last_retry:
-            environment[ENV_ZENML_IGNORE_FAILURE_HOOK] = str(False)
         logger.info(
             "Using step operator `%s` to run step `%s`.",
             step_operator.name,
@@ -529,7 +471,6 @@
         step_run_info: StepRunInfo,
         input_artifacts: Dict[str, StepRunInputResponse],
         output_artifact_uris: Dict[str, str],
-        last_retry: bool,
     ) -> None:
         """Runs the current step without a step operator.
 
@@ -539,10 +480,7 @@
             step_run_info: Additional information needed to run the step.
             input_artifacts: The input artifact versions of the current step.
             output_artifact_uris: The output artifact URIs of the current step.
-            last_retry: Whether this is the last retry of the step.
         """
-        if last_retry:
-            os.environ[ENV_ZENML_IGNORE_FAILURE_HOOK] = "false"
         runner = StepRunner(step=self._step, stack=self._stack)
         runner.run(
             pipeline_run=pipeline_run,
zenml/orchestrators/step_run_utils.py CHANGED
@@ -59,6 +59,22 @@ class StepRunRequestFactory:
         self.pipeline_run = pipeline_run
         self.stack = stack
 
+    def has_caching_enabled(self, invocation_id: str) -> bool:
+        """Check if the step has caching enabled.
+
+        Args:
+            invocation_id: The invocation ID for which to check if caching is
+                enabled.
+
+        Returns:
+            Whether the step has caching enabled.
+        """
+        step = self.deployment.step_configurations[invocation_id]
+        return utils.is_setting_enabled(
+            is_enabled_on_step=step.config.enable_cache,
+            is_enabled_on_pipeline=self.deployment.pipeline_configuration.enable_cache,
+        )
+
     def create_request(self, invocation_id: str) -> StepRunRequest:
         """Create a step run request.
 
zenml/orchestrators/step_runner.py CHANGED
@@ -34,7 +34,6 @@ from zenml.config.step_configurations import StepConfiguration
 from zenml.config.step_run_info import StepRunInfo
 from zenml.constants import (
     ENV_ZENML_DISABLE_STEP_LOGS_STORAGE,
-    ENV_ZENML_IGNORE_FAILURE_HOOK,
     handle_bool_env_var,
 )
 from zenml.enums import ArtifactSaveType
@@ -194,9 +193,7 @@ class StepRunner:
                 )
             except BaseException as step_exception:  # noqa: E722
                 step_failed = True
-                if not handle_bool_env_var(
-                    ENV_ZENML_IGNORE_FAILURE_HOOK, False
-                ):
+                if not step_run.is_retriable:
                     if (
                         failure_hook_source
                         := self.configuration.failure_hook_source
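With this change the failure hook is skipped while the step run is still retriable and only fires after the last attempt has failed. A minimal hedged sketch of attaching such a hook (hook and step names are illustrative):

```python
from zenml import step
from zenml.config.retry_config import StepRetryConfig


def notify_on_failure(exception: BaseException) -> None:
    # Under the new behavior this should only run once the step run is no
    # longer retriable, i.e. after the final configured attempt has failed.
    print(f"Step failed for good: {exception}")


@step(
    on_failure=notify_on_failure,
    retry=StepRetryConfig(max_retries=2, delay=5, backoff=2),
)
def flaky_step() -> int:
    raise RuntimeError("boom")
```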
zenml/pipelines/pipeline_decorator.py CHANGED
@@ -29,6 +29,7 @@ from zenml.logger import get_logger
 
 if TYPE_CHECKING:
     from zenml.config.base_settings import SettingsOrDict
+    from zenml.config.retry_config import StepRetryConfig
     from zenml.model.model import Model
     from zenml.pipelines.pipeline_definition import Pipeline
     from zenml.types import HookSpecification
@@ -57,6 +58,7 @@ def pipeline(
     on_failure: Optional["HookSpecification"] = None,
     on_success: Optional["HookSpecification"] = None,
     model: Optional["Model"] = None,
+    retry: Optional["StepRetryConfig"] = None,
     substitutions: Optional[Dict[str, str]] = None,
 ) -> Callable[["F"], "Pipeline"]: ...
 
@@ -75,6 +77,7 @@ def pipeline(
     on_failure: Optional["HookSpecification"] = None,
     on_success: Optional["HookSpecification"] = None,
     model: Optional["Model"] = None,
+    retry: Optional["StepRetryConfig"] = None,
     substitutions: Optional[Dict[str, str]] = None,
 ) -> Union["Pipeline", Callable[["F"], "Pipeline"]]:
     """Decorator to create a pipeline.
@@ -97,6 +100,7 @@ def pipeline(
             function with no arguments, or a source path to such a function
             (e.g. `module.my_function`).
         model: configuration of the model in the Model Control Plane.
+        retry: Retry configuration for the pipeline steps.
         substitutions: Extra placeholders to use in the name templates.
 
     Returns:
@@ -108,6 +112,7 @@ def pipeline(
 
         p = Pipeline(
             name=name or func.__name__,
+            entrypoint=func,
             enable_cache=enable_cache,
             enable_artifact_metadata=enable_artifact_metadata,
             enable_step_logs=enable_step_logs,
@@ -118,7 +123,7 @@ def pipeline(
             on_failure=on_failure,
             on_success=on_success,
             model=model,
-            entrypoint=func,
+            retry=retry,
             substitutions=substitutions,
         )