zenml-nightly 0.80.2.dev20250414__py3-none-any.whl → 0.80.2.dev20250416__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. zenml/VERSION +1 -1
  2. zenml/artifacts/utils.py +7 -2
  3. zenml/cli/utils.py +13 -11
  4. zenml/config/compiler.py +1 -0
  5. zenml/config/global_config.py +1 -1
  6. zenml/config/pipeline_configurations.py +1 -0
  7. zenml/config/pipeline_run_configuration.py +1 -0
  8. zenml/config/server_config.py +7 -0
  9. zenml/constants.py +8 -0
  10. zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +47 -5
  11. zenml/integrations/gcp/vertex_custom_job_parameters.py +15 -1
  12. zenml/integrations/kubernetes/flavors/kubernetes_step_operator_flavor.py +12 -0
  13. zenml/integrations/kubernetes/orchestrators/kube_utils.py +92 -0
  14. zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +12 -3
  15. zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +11 -65
  16. zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py +11 -3
  17. zenml/logging/step_logging.py +41 -21
  18. zenml/login/credentials_store.py +31 -0
  19. zenml/materializers/path_materializer.py +17 -2
  20. zenml/models/v2/base/base.py +8 -4
  21. zenml/models/v2/base/filter.py +1 -1
  22. zenml/models/v2/core/pipeline_run.py +19 -0
  23. zenml/orchestrators/step_launcher.py +2 -3
  24. zenml/orchestrators/step_runner.py +2 -2
  25. zenml/orchestrators/utils.py +2 -5
  26. zenml/pipelines/pipeline_context.py +1 -0
  27. zenml/pipelines/pipeline_decorator.py +4 -0
  28. zenml/pipelines/pipeline_definition.py +83 -22
  29. zenml/pipelines/run_utils.py +4 -0
  30. zenml/steps/utils.py +1 -1
  31. zenml/utils/io_utils.py +23 -0
  32. zenml/zen_server/auth.py +96 -64
  33. zenml/zen_server/cloud_utils.py +7 -1
  34. zenml/zen_server/download_utils.py +123 -0
  35. zenml/zen_server/jwt.py +0 -14
  36. zenml/zen_server/rbac/rbac_interface.py +10 -3
  37. zenml/zen_server/rbac/utils.py +13 -3
  38. zenml/zen_server/rbac/zenml_cloud_rbac.py +14 -8
  39. zenml/zen_server/routers/artifact_version_endpoints.py +86 -3
  40. zenml/zen_server/routers/auth_endpoints.py +5 -36
  41. zenml/zen_server/routers/pipeline_deployments_endpoints.py +63 -26
  42. zenml/zen_server/routers/runs_endpoints.py +57 -0
  43. zenml/zen_server/routers/users_endpoints.py +13 -8
  44. zenml/zen_server/template_execution/utils.py +3 -3
  45. zenml/zen_stores/migrations/versions/ff538a321a92_migrate_onboarding_state.py +123 -0
  46. zenml/zen_stores/rest_zen_store.py +16 -13
  47. zenml/zen_stores/schemas/pipeline_run_schemas.py +1 -0
  48. zenml/zen_stores/schemas/server_settings_schemas.py +4 -1
  49. zenml/zen_stores/sql_zen_store.py +18 -0
  50. {zenml_nightly-0.80.2.dev20250414.dist-info → zenml_nightly-0.80.2.dev20250416.dist-info}/METADATA +2 -1
  51. {zenml_nightly-0.80.2.dev20250414.dist-info → zenml_nightly-0.80.2.dev20250416.dist-info}/RECORD +54 -52
  52. {zenml_nightly-0.80.2.dev20250414.dist-info → zenml_nightly-0.80.2.dev20250416.dist-info}/LICENSE +0 -0
  53. {zenml_nightly-0.80.2.dev20250414.dist-info → zenml_nightly-0.80.2.dev20250416.dist-info}/WHEEL +0 -0
  54. {zenml_nightly-0.80.2.dev20250414.dist-info → zenml_nightly-0.80.2.dev20250416.dist-info}/entry_points.txt +0 -0
zenml/VERSION CHANGED
@@ -1 +1 @@
1
- 0.80.2.dev20250414
1
+ 0.80.2.dev20250416
zenml/artifacts/utils.py CHANGED
@@ -35,7 +35,9 @@ from zenml.artifacts.preexisting_data_materializer import (
35
35
  PreexistingDataMaterializer,
36
36
  )
37
37
  from zenml.client import Client
38
- from zenml.constants import MODEL_METADATA_YAML_FILE_NAME
38
+ from zenml.constants import (
39
+ MODEL_METADATA_YAML_FILE_NAME,
40
+ )
39
41
  from zenml.enums import (
40
42
  ArtifactSaveType,
41
43
  ArtifactType,
@@ -43,7 +45,10 @@ from zenml.enums import (
43
45
  StackComponentType,
44
46
  VisualizationType,
45
47
  )
46
- from zenml.exceptions import DoesNotExistException, StepContextError
48
+ from zenml.exceptions import (
49
+ DoesNotExistException,
50
+ StepContextError,
51
+ )
47
52
  from zenml.io import fileio
48
53
  from zenml.logger import get_logger
49
54
  from zenml.metadata.metadata_types import validate_metadata
zenml/cli/utils.py CHANGED
@@ -308,13 +308,13 @@ def print_pydantic_models(
308
308
  if isinstance(model, BaseIdentifiedResponse):
309
309
  include_columns = ["id"]
310
310
 
311
- if "name" in model.model_fields:
311
+ if "name" in type(model).model_fields:
312
312
  include_columns.append("name")
313
313
 
314
314
  include_columns.extend(
315
315
  [
316
316
  k
317
- for k in model.get_body().model_fields.keys()
317
+ for k in type(model.get_body()).model_fields.keys()
318
318
  if k not in exclude_columns
319
319
  ]
320
320
  )
@@ -323,7 +323,9 @@ def print_pydantic_models(
323
323
  include_columns.extend(
324
324
  [
325
325
  k
326
- for k in model.get_metadata().model_fields.keys()
326
+ for k in type(
327
+ model.get_metadata()
328
+ ).model_fields.keys()
327
329
  if k not in exclude_columns
328
330
  ]
329
331
  )
@@ -347,7 +349,7 @@ def print_pydantic_models(
347
349
  # we want to attempt to represent them by name, if they contain
348
350
  # such a field, else the id is used
349
351
  if isinstance(value, BaseIdentifiedResponse):
350
- if "name" in value.model_fields:
352
+ if "name" in type(value).model_fields:
351
353
  items[k] = str(getattr(value, "name"))
352
354
  else:
353
355
  items[k] = str(value.id)
@@ -357,7 +359,7 @@ def print_pydantic_models(
357
359
  elif isinstance(value, list):
358
360
  for v in value:
359
361
  if isinstance(v, BaseIdentifiedResponse):
360
- if "name" in v.model_fields:
362
+ if "name" in type(v).model_fields:
361
363
  items.setdefault(k, []).append(
362
364
  str(getattr(v, "name"))
363
365
  )
@@ -448,13 +450,13 @@ def print_pydantic_model(
448
450
  if isinstance(model, BaseIdentifiedResponse):
449
451
  include_columns = ["id"]
450
452
 
451
- if "name" in model.model_fields:
453
+ if "name" in type(model).model_fields:
452
454
  include_columns.append("name")
453
455
 
454
456
  include_columns.extend(
455
457
  [
456
458
  k
457
- for k in model.get_body().model_fields.keys()
459
+ for k in type(model.get_body()).model_fields.keys()
458
460
  if k not in exclude_columns
459
461
  ]
460
462
  )
@@ -463,7 +465,7 @@ def print_pydantic_model(
463
465
  include_columns.extend(
464
466
  [
465
467
  k
466
- for k in model.get_metadata().model_fields.keys()
468
+ for k in type(model.get_metadata()).model_fields.keys()
467
469
  if k not in exclude_columns
468
470
  ]
469
471
  )
@@ -482,7 +484,7 @@ def print_pydantic_model(
482
484
  for k in include_columns:
483
485
  value = getattr(model, k)
484
486
  if isinstance(value, BaseIdentifiedResponse):
485
- if "name" in value.model_fields:
487
+ if "name" in type(value).model_fields:
486
488
  items[k] = str(getattr(value, "name"))
487
489
  else:
488
490
  items[k] = str(value.id)
@@ -492,7 +494,7 @@ def print_pydantic_model(
492
494
  elif isinstance(value, list):
493
495
  for v in value:
494
496
  if isinstance(v, BaseIdentifiedResponse):
495
- if "name" in v.model_fields:
497
+ if "name" in type(v).model_fields:
496
498
  items.setdefault(k, []).append(str(getattr(v, "name")))
497
499
  else:
498
500
  items.setdefault(k, []).append(str(v.id))
@@ -2138,7 +2140,7 @@ def _scrub_secret(config: StackComponentConfig) -> Dict[str, Any]:
2138
2140
  A configuration with secret values removed.
2139
2141
  """
2140
2142
  config_dict = {}
2141
- config_fields = config.__class__.model_fields
2143
+ config_fields = type(config).model_fields
2142
2144
  for key, value in config_fields.items():
2143
2145
  if getattr(config, key):
2144
2146
  if secret_utils.is_secret_field(value):
zenml/config/compiler.py CHANGED
@@ -210,6 +210,7 @@ class Compiler:
210
210
  enable_artifact_metadata=config.enable_artifact_metadata,
211
211
  enable_artifact_visualization=config.enable_artifact_visualization,
212
212
  enable_step_logs=config.enable_step_logs,
213
+ enable_pipeline_logs=config.enable_pipeline_logs,
213
214
  settings=config.settings,
214
215
  tags=config.tags,
215
216
  extra=config.extra,
@@ -447,7 +447,7 @@ class GlobalConfiguration(BaseModel, metaclass=GlobalConfigMetaClass):
447
447
  """
448
448
  environment_vars = {}
449
449
 
450
- for key in self.model_fields.keys():
450
+ for key in type(self).model_fields.keys():
451
451
  if key == "store":
452
452
  # The store configuration uses its own environment variable
453
453
  # naming scheme
@@ -41,6 +41,7 @@ class PipelineConfigurationUpdate(StrictBaseModel):
41
41
  enable_artifact_metadata: Optional[bool] = None
42
42
  enable_artifact_visualization: Optional[bool] = None
43
43
  enable_step_logs: Optional[bool] = None
44
+ enable_pipeline_logs: Optional[bool] = None
44
45
  settings: Dict[str, SerializeAsAny[BaseSettings]] = {}
45
46
  tags: Optional[List[Union[str, "Tag"]]] = None
46
47
  extra: Dict[str, Any] = {}
@@ -40,6 +40,7 @@ class PipelineRunConfiguration(
40
40
  enable_artifact_metadata: Optional[bool] = None
41
41
  enable_artifact_visualization: Optional[bool] = None
42
42
  enable_step_logs: Optional[bool] = None
43
+ enable_pipeline_logs: Optional[bool] = None
43
44
  schedule: Optional[Schedule] = None
44
45
  build: Union[PipelineBuildBase, UUID, None] = Field(
45
46
  default=None, union_mode="left_to_right"
@@ -34,6 +34,7 @@ from zenml.constants import (
34
34
  DEFAULT_ZENML_JWT_TOKEN_LEEWAY,
35
35
  DEFAULT_ZENML_SERVER_DEVICE_AUTH_POLLING,
36
36
  DEFAULT_ZENML_SERVER_DEVICE_AUTH_TIMEOUT,
37
+ DEFAULT_ZENML_SERVER_FILE_DOWNLOAD_SIZE_LIMIT,
37
38
  DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_LIFETIME,
38
39
  DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_MAX_LIFETIME,
39
40
  DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_DAY,
@@ -245,6 +246,8 @@ class ServerConfiguration(BaseModel):
245
246
  memcache_default_expiry: The default expiry time in seconds for cache
246
247
  entries. If not specified, the default value of 30 seconds will be
247
248
  used.
249
+ file_download_size_limit: The maximum size of the file download in
250
+ bytes. If not specified, the default value of 2GB will be used.
248
251
  """
249
252
 
250
253
  deployment_type: ServerDeploymentType = ServerDeploymentType.OTHER
@@ -346,6 +349,10 @@ class ServerConfiguration(BaseModel):
346
349
  memcache_max_capacity: int = 1000
347
350
  memcache_default_expiry: int = 30
348
351
 
352
+ file_download_size_limit: int = (
353
+ DEFAULT_ZENML_SERVER_FILE_DOWNLOAD_SIZE_LIMIT
354
+ )
355
+
349
356
  _deployment_id: Optional[UUID] = None
350
357
 
351
358
  @model_validator(mode="before")
zenml/constants.py CHANGED
@@ -168,6 +168,7 @@ ENV_ZENML_SKIP_STACK_VALIDATION = "ZENML_SKIP_STACK_VALIDATION"
168
168
  ENV_ZENML_SERVER = "ZENML_SERVER"
169
169
  ENV_ZENML_ENFORCE_TYPE_ANNOTATIONS = "ZENML_ENFORCE_TYPE_ANNOTATIONS"
170
170
  ENV_ZENML_ENABLE_IMPLICIT_AUTH_METHODS = "ZENML_ENABLE_IMPLICIT_AUTH_METHODS"
171
+ ENV_ZENML_DISABLE_PIPELINE_LOGS_STORAGE = "ZENML_DISABLE_PIPELINE_LOGS_STORAGE"
171
172
  ENV_ZENML_DISABLE_STEP_LOGS_STORAGE = "ZENML_DISABLE_STEP_LOGS_STORAGE"
172
173
  ENV_ZENML_DISABLE_STEP_NAMES_IN_LOGS = "ZENML_DISABLE_STEP_NAMES_IN_LOGS"
173
174
  ENV_ZENML_IGNORE_FAILURE_HOOK = "ZENML_IGNORE_FAILURE_HOOK"
@@ -192,12 +193,16 @@ ENV_ZENML_SERVER_PRO_PREFIX = "ZENML_SERVER_PRO_"
192
193
  ENV_ZENML_SERVER_DEPLOYMENT_TYPE = f"{ENV_ZENML_SERVER_PREFIX}DEPLOYMENT_TYPE"
193
194
  ENV_ZENML_SERVER_AUTH_SCHEME = f"{ENV_ZENML_SERVER_PREFIX}AUTH_SCHEME"
194
195
  ENV_ZENML_SERVER_AUTO_ACTIVATE = f"{ENV_ZENML_SERVER_PREFIX}AUTO_ACTIVATE"
196
+
195
197
  ENV_ZENML_RUN_SINGLE_STEPS_WITHOUT_STACK = (
196
198
  "ZENML_RUN_SINGLE_STEPS_WITHOUT_STACK"
197
199
  )
198
200
  ENV_ZENML_PREVENT_CLIENT_SIDE_CACHING = "ZENML_PREVENT_CLIENT_SIDE_CACHING"
199
201
  ENV_ZENML_DISABLE_CREDENTIALS_DISK_CACHING = "DISABLE_CREDENTIALS_DISK_CACHING"
200
202
  ENV_ZENML_RUNNER_IMAGE_DISABLE_UV = "ZENML_RUNNER_IMAGE_DISABLE_UV"
203
+ ENV_ZENML_WORKLOAD_TOKEN_EXPIRATION_LEEWAY = (
204
+ "ZENML_WORKLOAD_TOKEN_EXPIRATION_LEEWAY"
205
+ )
201
206
  # Logging variables
202
207
  IS_DEBUG_ENV: bool = handle_bool_env_var(ENV_ZENML_DEBUG, default=False)
203
208
 
@@ -284,6 +289,7 @@ DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_LIFETIME = 60 * 60 # 1 hour
284
289
  DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_MAX_LIFETIME = (
285
290
  60 * 60 * 24 * 7
286
291
  ) # 7 days
292
+ DEFAULT_ZENML_SERVER_FILE_DOWNLOAD_SIZE_LIMIT = 2 * 1024 * 1024 * 1024 # 2 GB
287
293
 
288
294
  DEFAULT_ZENML_SERVER_SECURE_HEADERS_HSTS = (
289
295
  "max-age=63072000; includeSubdomains"
@@ -350,10 +356,12 @@ CODE_REPOSITORIES = "/code_repositories"
350
356
  COMPONENT_TYPES = "/component-types"
351
357
  CONFIG = "/config"
352
358
  CURRENT_USER = "/current-user"
359
+ DATA = "/data"
353
360
  DEACTIVATE = "/deactivate"
354
361
  DEVICES = "/devices"
355
362
  DEVICE_AUTHORIZATION = "/device_authorization"
356
363
  DEVICE_VERIFY = "/verify"
364
+ DOWNLOAD_TOKEN = "/download-token"
357
365
  EMAIL_ANALYTICS = "/email-opt-in"
358
366
  EVENT_FLAVORS = "/event-flavors"
359
367
  EVENT_SOURCES = "/event-sources"
@@ -341,13 +341,55 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
341
341
  self.config.workload_service_account
342
342
  )
343
343
 
344
+ # Create a dictionary of explicit parameters
345
+ params = custom_job_parameters.model_dump(
346
+ exclude_none=True, exclude={"additional_training_job_args"}
347
+ )
348
+
349
+ # Remove None values to let defaults be set by the function
350
+ params = {k: v for k, v in params.items() if v is not None}
351
+
352
+ # Add environment variables
353
+ params["env"] = [
354
+ {"name": key, "value": value} for key, value in environment.items()
355
+ ]
356
+
357
+ # Check if any advanced parameters will override explicit parameters
358
+ if custom_job_parameters.additional_training_job_args:
359
+ overridden_params = set(params.keys()) & set(
360
+ custom_job_parameters.additional_training_job_args.keys()
361
+ )
362
+ if overridden_params:
363
+ logger.warning(
364
+ f"The following explicit parameters are being overridden by values in "
365
+ f"additional_training_job_args: {', '.join(overridden_params)}. "
366
+ f"This may lead to unexpected behavior. Consider using either explicit "
367
+ f"parameters or additional_training_job_args, but not both for the same parameters."
368
+ )
369
+
370
+ # Add any advanced parameters - these will override explicit parameters if provided
371
+ params.update(custom_job_parameters.additional_training_job_args)
372
+
373
+ # Add other parameters from orchestrator config if not already in params
374
+ if self.config.network and "network" not in params:
375
+ params["network"] = self.config.network
376
+
377
+ if (
378
+ self.config.encryption_spec_key_name
379
+ and "encryption_spec_key_name" not in params
380
+ ):
381
+ params["encryption_spec_key_name"] = (
382
+ self.config.encryption_spec_key_name
383
+ )
384
+ if (
385
+ self.config.workload_service_account
386
+ and "service_account" not in params
387
+ ):
388
+ params["service_account"] = self.config.workload_service_account
389
+
344
390
  custom_job_component = create_custom_training_job_from_component(
345
391
  component_spec=component,
346
- env=[
347
- {"name": key, "value": value}
348
- for key, value in environment.items()
349
- ],
350
- **custom_job_parameters.model_dump(),
392
+ **params,
351
393
  )
352
394
 
353
395
  return custom_job_component
@@ -13,7 +13,7 @@
13
13
  # permissions and limitations under the License.
14
14
  """Vertex custom job parameter model."""
15
15
 
16
- from typing import Optional
16
+ from typing import Any, Dict, Optional
17
17
 
18
18
  from pydantic import BaseModel
19
19
 
@@ -37,8 +37,21 @@ class VertexCustomJobParameters(BaseModel):
37
37
  boot_disk_type: Type of the boot disk. (Default: pd-ssd)
38
38
  https://cloud.google.com/vertex-ai/docs/training/configure-compute#boot_disk_options
39
39
  persistent_resource_id: The ID of the persistent resource to use for the job.
40
+ If empty (default), the job will not use a persistent resource.
41
+ When using a persistent resource, you must also specify a service_account.
42
+ Conversely, when explicitly setting this to an empty string, you
43
+ should not specify a service_account (ZenML will handle this automatically).
40
44
  https://cloud.google.com/vertex-ai/docs/training/persistent-resource-overview
41
45
  service_account: Specifies the service account to be used.
46
+ This is required when using a persistent_resource_id, and
47
+ should not be set when persistent_resource_id="".
48
+ additional_training_job_args: Additional arguments to pass to the create_custom_training_job_from_component
49
+ function. This allows passing any additional parameters supported by the Google
50
+ Cloud Pipeline Components library without requiring ZenML to update its API.
51
+ Note: If you specify parameters in this dictionary that are also defined as explicit
52
+ attributes (like machine_type or boot_disk_size_gb), the values in this dictionary
53
+ will override the explicit values.
54
+ See: https://google-cloud-pipeline-components.readthedocs.io/en/google-cloud-pipeline-components-2.19.0/api/v1/custom_job.html
42
55
  """
43
56
 
44
57
  accelerator_type: Optional[str] = None
@@ -48,3 +61,4 @@ class VertexCustomJobParameters(BaseModel):
48
61
  boot_disk_type: str = "pd-ssd"
49
62
  persistent_resource_id: Optional[str] = None
50
63
  service_account: Optional[str] = None
64
+ additional_training_job_args: Dict[str, Any] = {}
@@ -35,11 +35,23 @@ class KubernetesStepOperatorSettings(BaseSettings):
35
35
  pod_settings: Pod settings to apply to pods executing the steps.
36
36
  service_account_name: Name of the service account to use for the pod.
37
37
  privileged: If the container should be run in privileged mode.
38
+ pod_startup_timeout: The maximum time to wait for a pending step pod to
39
+ start (in seconds).
40
+ pod_failure_max_retries: The maximum number of times to retry a step
41
+ pod if the Kubernetes pod fails to start.
42
+ pod_failure_retry_delay: The delay in seconds between pod
43
+ failure retries and pod startup retries.
44
+ pod_failure_backoff: The backoff factor for pod failure retries and
45
+ pod startup retries.
38
46
  """
39
47
 
40
48
  pod_settings: Optional[KubernetesPodSettings] = None
41
49
  service_account_name: Optional[str] = None
42
50
  privileged: bool = False
51
+ pod_startup_timeout: int = 60 * 10 # Default 10 minutes
52
+ pod_failure_max_retries: int = 3
53
+ pod_failure_retry_delay: int = 10
54
+ pod_failure_backoff: float = 1.0
43
55
 
44
56
 
45
57
  class KubernetesStepOperatorConfig(
@@ -462,3 +462,95 @@ def delete_secret(
462
462
  name=secret_name,
463
463
  namespace=namespace,
464
464
  )
465
+
466
+
467
+ def create_and_wait_for_pod_to_start(
468
+ core_api: k8s_client.CoreV1Api,
469
+ pod_display_name: str,
470
+ pod_name: str,
471
+ pod_manifest: k8s_client.V1Pod,
472
+ namespace: str,
473
+ startup_max_retries: int,
474
+ startup_failure_delay: float,
475
+ startup_failure_backoff: float,
476
+ startup_timeout: float,
477
+ ) -> None:
478
+ """Create a pod and wait for it to reach a desired state.
479
+
480
+ Args:
481
+ core_api: Client of Core V1 API of Kubernetes API.
482
+ pod_display_name: The display name of the pod to use in logs.
483
+ pod_name: The name of the pod to create.
484
+ pod_manifest: The manifest of the pod to create.
485
+ namespace: The namespace in which to create the pod.
486
+ startup_max_retries: The maximum number of retries for the pod startup.
487
+ startup_failure_delay: The delay between retries for the pod startup.
488
+ startup_failure_backoff: The backoff factor for the pod startup.
489
+ startup_timeout: The maximum time to wait for the pod to start.
490
+
491
+ Raises:
492
+ TimeoutError: If the pod is still in a pending state after the maximum
493
+ wait time has elapsed.
494
+ Exception: If the pod fails to start after the maximum number of
495
+ retries.
496
+ """
497
+ retries = 0
498
+
499
+ while retries < startup_max_retries:
500
+ try:
501
+ # Create and run pod.
502
+ core_api.create_namespaced_pod(
503
+ namespace=namespace,
504
+ body=pod_manifest,
505
+ )
506
+ break
507
+ except Exception as e:
508
+ retries += 1
509
+ if retries < startup_max_retries:
510
+ logger.debug(f"The {pod_display_name} failed to start: {e}")
511
+ logger.error(
512
+ f"Failed to create {pod_display_name}. "
513
+ f"Retrying in {startup_failure_delay} seconds..."
514
+ )
515
+ time.sleep(startup_failure_delay)
516
+ startup_failure_delay *= startup_failure_backoff
517
+ else:
518
+ logger.error(
519
+ f"Failed to create {pod_display_name} after "
520
+ f"{startup_max_retries} retries. Exiting."
521
+ )
522
+ raise
523
+
524
+ # Wait for pod to start
525
+ logger.info(f"Waiting for {pod_display_name} to start...")
526
+ max_wait = startup_timeout
527
+ total_wait: float = 0
528
+ delay = startup_failure_delay
529
+ while True:
530
+ pod = get_pod(
531
+ core_api=core_api,
532
+ pod_name=pod_name,
533
+ namespace=namespace,
534
+ )
535
+ if not pod or pod_is_not_pending(pod):
536
+ break
537
+ if total_wait >= max_wait:
538
+ # Have to delete the pending pod so it doesn't start running
539
+ # later on.
540
+ try:
541
+ core_api.delete_namespaced_pod(
542
+ name=pod_name,
543
+ namespace=namespace,
544
+ )
545
+ except Exception:
546
+ pass
547
+ raise TimeoutError(
548
+ f"The {pod_display_name} is still in a pending state "
549
+ f"after {total_wait} seconds. Exiting."
550
+ )
551
+
552
+ if total_wait + delay > max_wait:
553
+ delay = max_wait - total_wait
554
+ total_wait += delay
555
+ time.sleep(delay)
556
+ delay *= startup_failure_backoff
@@ -543,14 +543,23 @@ class KubernetesOrchestrator(ContainerizedOrchestrator):
543
543
  mount_local_stores=self.config.is_local,
544
544
  )
545
545
 
546
- self._k8s_core_api.create_namespaced_pod(
546
+ kube_utils.create_and_wait_for_pod_to_start(
547
+ core_api=self._k8s_core_api,
548
+ pod_display_name="Kubernetes orchestrator pod",
549
+ pod_name=pod_name,
550
+ pod_manifest=pod_manifest,
547
551
  namespace=self.config.kubernetes_namespace,
548
- body=pod_manifest,
552
+ startup_max_retries=settings.pod_failure_max_retries,
553
+ startup_failure_delay=settings.pod_failure_retry_delay,
554
+ startup_failure_backoff=settings.pod_failure_backoff,
555
+ startup_timeout=settings.pod_startup_timeout,
549
556
  )
550
557
 
551
558
  # Wait for the orchestrator pod to finish and stream logs.
552
559
  if settings.synchronous:
553
- logger.info("Waiting for Kubernetes orchestrator pod...")
560
+ logger.info(
561
+ "Waiting for Kubernetes orchestrator pod to finish..."
562
+ )
554
563
  kube_utils.wait_pod(
555
564
  kube_client_fn=self.get_kube_client,
556
565
  pod_name=pod_name,
@@ -15,7 +15,6 @@
15
15
 
16
16
  import argparse
17
17
  import socket
18
- import time
19
18
  from typing import Any, Dict
20
19
  from uuid import UUID
21
20
 
@@ -103,8 +102,6 @@ def main() -> None:
103
102
 
104
103
  Raises:
105
104
  Exception: If the pod fails to start.
106
- TimeoutError: If the pod is still in a pending state after the
107
- maximum wait time has elapsed.
108
105
  """
109
106
  # Define Kubernetes pod name.
110
107
  pod_name = f"{orchestrator_run_id}-{step_name}"
@@ -176,68 +173,17 @@ def main() -> None:
176
173
  mount_local_stores=mount_local_stores,
177
174
  )
178
175
 
179
- retries = 0
180
- max_retries = settings.pod_failure_max_retries
181
- delay: float = settings.pod_failure_retry_delay
182
- backoff = settings.pod_failure_backoff
183
-
184
- while retries < max_retries:
185
- try:
186
- # Create and run pod.
187
- core_api.create_namespaced_pod(
188
- namespace=args.kubernetes_namespace,
189
- body=pod_manifest,
190
- )
191
- break
192
- except Exception as e:
193
- retries += 1
194
- if retries < max_retries:
195
- logger.debug(
196
- f"Pod for step `{step_name}` failed to start: {e}"
197
- )
198
- logger.error(
199
- f"Failed to create pod for step `{step_name}`. "
200
- f"Retrying in {delay} seconds..."
201
- )
202
- time.sleep(delay)
203
- delay *= backoff
204
- else:
205
- logger.error(
206
- f"Failed to create pod for step `{step_name}` after "
207
- f"{max_retries} retries. Exiting."
208
- )
209
- raise
210
-
211
- # Wait for pod to start
212
- max_wait = settings.pod_startup_timeout
213
- total_wait: float = 0
214
- delay = settings.pod_failure_retry_delay
215
- while True:
216
- pod = kube_utils.get_pod(
217
- core_api, pod_name, args.kubernetes_namespace
218
- )
219
- if not pod or kube_utils.pod_is_not_pending(pod):
220
- break
221
- if total_wait >= max_wait:
222
- # Have to delete the pending pod so it doesn't start running
223
- # later on.
224
- try:
225
- core_api.delete_namespaced_pod(
226
- name=pod_name,
227
- namespace=args.kubernetes_namespace,
228
- )
229
- except Exception:
230
- pass
231
- raise TimeoutError(
232
- f"Pod for step `{step_name}` is still in a pending state "
233
- f"after {total_wait} seconds. Exiting."
234
- )
235
-
236
- if total_wait + delay > max_wait:
237
- delay = max_wait - total_wait
238
- total_wait += delay
239
- time.sleep(delay)
240
- delay *= backoff
176
+ kube_utils.create_and_wait_for_pod_to_start(
177
+ core_api=core_api,
178
+ pod_display_name=f"pod for step `{step_name}`",
179
+ pod_name=pod_name,
180
+ pod_manifest=pod_manifest,
181
+ namespace=args.kubernetes_namespace,
182
+ startup_max_retries=settings.pod_failure_max_retries,
183
+ startup_failure_delay=settings.pod_failure_retry_delay,
184
+ startup_failure_backoff=settings.pod_failure_backoff,
185
+ startup_timeout=settings.pod_startup_timeout,
186
+ )
241
187
 
242
188
  # Wait for pod to finish.
243
189
  logger.info(f"Waiting for pod of step `{step_name}` to finish...")
@@ -218,13 +218,21 @@ class KubernetesStepOperator(BaseStepOperator):
218
218
  mount_local_stores=False,
219
219
  )
220
220
 
221
- self._k8s_core_api.create_namespaced_pod(
221
+ kube_utils.create_and_wait_for_pod_to_start(
222
+ core_api=self._k8s_core_api,
223
+ pod_display_name=f"pod of step `{info.pipeline_step_name}`",
224
+ pod_name=pod_name,
225
+ pod_manifest=pod_manifest,
222
226
  namespace=self.config.kubernetes_namespace,
223
- body=pod_manifest,
227
+ startup_max_retries=settings.pod_failure_max_retries,
228
+ startup_failure_delay=settings.pod_failure_retry_delay,
229
+ startup_failure_backoff=settings.pod_failure_backoff,
230
+ startup_timeout=settings.pod_startup_timeout,
224
231
  )
225
232
 
226
233
  logger.info(
227
- "Waiting for pod of step `%s` to start...", info.pipeline_step_name
234
+ "Waiting for pod of step `%s` to finish...",
235
+ info.pipeline_step_name,
228
236
  )
229
237
  kube_utils.wait_pod(
230
238
  kube_client_fn=self.get_kube_client,