zenml-nightly 0.75.0.dev20250314__py3-none-any.whl → 0.75.0.dev20250315__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. zenml/VERSION +1 -1
  2. zenml/cli/login.py +21 -18
  3. zenml/cli/server.py +5 -5
  4. zenml/client.py +5 -1
  5. zenml/config/server_config.py +9 -9
  6. zenml/integrations/gcp/__init__.py +1 -0
  7. zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py +5 -0
  8. zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py +5 -28
  9. zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +125 -78
  10. zenml/integrations/gcp/vertex_custom_job_parameters.py +50 -0
  11. zenml/login/credentials.py +26 -27
  12. zenml/login/credentials_store.py +5 -5
  13. zenml/login/pro/client.py +9 -9
  14. zenml/login/pro/utils.py +8 -8
  15. zenml/login/pro/{tenant → workspace}/__init__.py +1 -1
  16. zenml/login/pro/{tenant → workspace}/client.py +25 -25
  17. zenml/login/pro/{tenant → workspace}/models.py +27 -28
  18. zenml/models/v2/core/model.py +9 -1
  19. zenml/models/v2/core/tag.py +96 -3
  20. zenml/models/v2/misc/server_models.py +6 -6
  21. zenml/orchestrators/step_run_utils.py +1 -1
  22. zenml/utils/dashboard_utils.py +1 -1
  23. zenml/utils/tag_utils.py +0 -12
  24. zenml/zen_server/cloud_utils.py +3 -3
  25. zenml/zen_server/feature_gate/endpoint_utils.py +1 -1
  26. zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py +1 -1
  27. zenml/zen_server/rbac/models.py +30 -5
  28. zenml/zen_server/rbac/zenml_cloud_rbac.py +7 -70
  29. zenml/zen_server/routers/server_endpoints.py +2 -2
  30. zenml/zen_server/zen_server_api.py +3 -3
  31. zenml/zen_stores/base_zen_store.py +3 -3
  32. zenml/zen_stores/rest_zen_store.py +1 -1
  33. zenml/zen_stores/sql_zen_store.py +60 -7
  34. {zenml_nightly-0.75.0.dev20250314.dist-info → zenml_nightly-0.75.0.dev20250315.dist-info}/METADATA +1 -1
  35. {zenml_nightly-0.75.0.dev20250314.dist-info → zenml_nightly-0.75.0.dev20250315.dist-info}/RECORD +38 -37
  36. {zenml_nightly-0.75.0.dev20250314.dist-info → zenml_nightly-0.75.0.dev20250315.dist-info}/LICENSE +0 -0
  37. {zenml_nightly-0.75.0.dev20250314.dist-info → zenml_nightly-0.75.0.dev20250315.dist-info}/WHEEL +0 -0
  38. {zenml_nightly-0.75.0.dev20250314.dist-info → zenml_nightly-0.75.0.dev20250315.dist-info}/entry_points.txt +0 -0
zenml/VERSION CHANGED
@@ -1 +1 @@
-0.75.0.dev20250314
+0.75.0.dev20250315
zenml/cli/login.py CHANGED
@@ -246,7 +246,7 @@ def connect_to_pro_server(
     """
     from zenml.login.credentials_store import get_credentials_store
     from zenml.login.pro.client import ZenMLProClient
-    from zenml.login.pro.tenant.models import TenantStatus
+    from zenml.login.pro.workspace.models import WorkspaceStatus

     pro_api_url = pro_api_url or ZENML_PRO_API_URL
     pro_api_url = pro_api_url.rstrip("/")
@@ -295,7 +295,7 @@ def connect_to_pro_server(
     # We also need to remove all existing API tokens associated with the
     # target ZenML Pro API, otherwise they will continue to be used after
     # the re-login flow.
-    credentials_store.clear_all_pro_tokens(pro_api_url)
+    credentials_store.clear_all_pro_tokens()
     try:
         token = web_login(
             pro_api_url=pro_api_url,
@@ -310,13 +310,14 @@ def connect_to_pro_server(
            "your session expires."
        )

-    tenant_id: Optional[str] = None
+    workspace_id: Optional[str] = None
     if token.device_metadata:
-        tenant_id = token.device_metadata.get("tenant_id")
+        # TODO: is this still correct?
+        workspace_id = token.device_metadata.get("tenant_id")

-    if tenant_id is None and pro_server is None:
+    if workspace_id is None and pro_server is None:
         # This is not really supposed to happen, because the implementation
-        # of the web login workflow should always return a tenant ID, but
+        # of the web login workflow should always return a workspace ID, but
         # we're handling it just in case.
         cli_utils.declare(
             "A valid server was not selected during the login process. "
@@ -328,14 +329,14 @@ def connect_to_pro_server(

     # The server selected during the web login process overrides any
     # server argument passed to the command.
-    server_id = UUID(tenant_id)
+    server_id = UUID(workspace_id)

     client = ZenMLProClient(pro_api_url)

     if server_id:
-        server = client.tenant.get(server_id)
+        server = client.workspace.get(server_id)
     elif server_url:
-        servers = client.tenant.list(url=server_url, member_only=True)
+        servers = client.workspace.list(url=server_url, member_only=True)
         if not servers:
             raise AuthorizationException(
                 f"The '{server_url}' URL belongs to a ZenML Pro server, "
@@ -345,7 +346,9 @@ def connect_to_pro_server(

         server = servers[0]
     elif server_name:
-        servers = client.tenant.list(tenant_name=server_name, member_only=True)
+        servers = client.workspace.list(
+            workspace_name=server_name, member_only=True
+        )
         if not servers:
             raise AuthorizationException(
                 f"No ZenML Pro server with the name '{server_name}' exists "
@@ -361,15 +364,15 @@ def connect_to_pro_server(

     server_id = server.id

-    if server.status == TenantStatus.PENDING:
+    if server.status == WorkspaceStatus.PENDING:
         with console.status(
             f"Waiting for your `{server.name}` ZenML Pro server to be set up..."
         ):
             timeout = 180  # 3 minutes
             while True:
                 time.sleep(5)
-                server = client.tenant.get(server_id)
-                if server.status != TenantStatus.PENDING:
+                server = client.workspace.get(server_id)
+                if server.status != WorkspaceStatus.PENDING:
                     break
                 timeout -= 5
                 if timeout <= 0:
@@ -380,7 +383,7 @@ def connect_to_pro_server(
                         f"ZenML Pro dashboard at {server.dashboard_url}."
                     )

-    if server.status == TenantStatus.FAILED:
+    if server.status == WorkspaceStatus.FAILED:
         cli_utils.error(
             f"Your `{server.name}` ZenML Pro server is currently in a "
             "failed state. Please manage the server state by visiting the "
@@ -388,7 +391,7 @@ def connect_to_pro_server(
             "your server administrator."
         )

-    elif server.status == TenantStatus.DEACTIVATED:
+    elif server.status == WorkspaceStatus.DEACTIVATED:
         cli_utils.error(
             f"Your `{server.name}` ZenML Pro server is currently "
             "deactivated. Please manage the server state by visiting the "
@@ -396,7 +399,7 @@ def connect_to_pro_server(
             "your server administrator."
         )

-    elif server.status == TenantStatus.AVAILABLE:
+    elif server.status == WorkspaceStatus.AVAILABLE:
         if not server.url:
             cli_utils.error(
                 f"The ZenML Pro server '{server.name}' is not currently "
@@ -418,7 +421,7 @@ def connect_to_pro_server(
     connect_to_server(server.url, api_key=api_key, pro_server=True)

     # Update the stored server info with more accurate data taken from the
-    # ZenML Pro tenant object.
+    # ZenML Pro workspace object.
     credentials_store.update_server_info(server.url, server)

     cli_utils.declare(f"Connected to ZenML Pro server: {server.name}.")
@@ -836,7 +839,7 @@ def login(
            pro_api_url=pro_api_url,
        )

-    elif current_non_local_server:
+    elif current_non_local_server and not refresh:
         # The server argument is not provided, so we default to
         # re-authenticating to the current non-local server that the client is
         # connected to.
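
Note: the net effect of the tenant → workspace rename on the Pro client surface is that every `client.tenant.*` access becomes `client.workspace.*` and the `tenant_name` keyword becomes `workspace_name`. A minimal sketch of the new call pattern (the API URL is illustrative, and an authenticated `zenml login` session is assumed):

    # Sketch of the renamed ZenML Pro client surface; assumes a prior login so
    # that valid Pro credentials are in the local credentials store.
    from zenml.login.pro.client import ZenMLProClient
    from zenml.login.pro.workspace.models import WorkspaceStatus

    client = ZenMLProClient("https://cloudapi.zenml.io")  # illustrative URL

    # Formerly: client.tenant.list(tenant_name=..., member_only=True)
    servers = client.workspace.list(
        workspace_name="my-workspace", member_only=True
    )
    for server in servers:
        if server.status == WorkspaceStatus.AVAILABLE:
            print(server.id, server.name, server.url)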
zenml/cli/server.py CHANGED
@@ -220,7 +220,7 @@ def status() -> None:
     )
     if pro_credentials:
         pro_client = ZenMLProClient(pro_credentials.url)
-        pro_servers = pro_client.tenant.list(
+        pro_servers = pro_client.workspace.list(
             url=store_cfg.url, member_only=True
         )
         if pro_servers:
@@ -575,7 +575,7 @@ def server_list(
     from zenml.login.credentials_store import get_credentials_store
     from zenml.login.pro.client import ZenMLProClient
     from zenml.login.pro.constants import ZENML_PRO_API_URL
-    from zenml.login.pro.tenant.models import TenantRead, TenantStatus
+    from zenml.login.pro.workspace.models import WorkspaceRead, WorkspaceStatus

     pro_api_url = pro_api_url or ZENML_PRO_API_URL
     pro_api_url = pro_api_url.rstrip("/")
@@ -601,10 +601,10 @@ def server_list(
     # that the user has never connected to (and are therefore not stored in
     # the credentials store).

-    accessible_pro_servers: List[TenantRead] = []
+    accessible_pro_servers: List[WorkspaceRead] = []
     try:
         client = ZenMLProClient(pro_api_url)
-        accessible_pro_servers = client.tenant.list(member_only=not all)
+        accessible_pro_servers = client.workspace.list(member_only=not all)
     except AuthorizationException as e:
         cli_utils.warning(f"ZenML Pro authorization error: {e}")

@@ -638,7 +638,7 @@ def server_list(
         accessible_pro_servers = [
             s
             for s in accessible_pro_servers
-            if s.status == TenantStatus.AVAILABLE
+            if s.status == WorkspaceStatus.AVAILABLE
         ]

     if not accessible_pro_servers:
zenml/client.py CHANGED
@@ -68,6 +68,7 @@ from zenml.enums import (
     SorterOps,
     StackComponentType,
     StoreType,
+    TaggableResourceTypes,
 )
 from zenml.exceptions import (
     AuthorizationException,
@@ -7770,6 +7771,7 @@ class Client(metaclass=ClientMetaClass):
         name: Optional[str] = None,
         color: Optional[Union[str, ColorVariants]] = None,
         exclusive: Optional[bool] = None,
+        resource_type: Optional[Union[str, TaggableResourceTypes]] = None,
         hydrate: bool = False,
     ) -> Page[TagResponse]:
         """Get tags by filter.
@@ -7777,7 +7779,7 @@ class Client(metaclass=ClientMetaClass):
         Args:
             sort_by: The column to sort by.
             page: The page of items.
-            size: The maximum size of all pages.
+            size: The maximum size of all pages
             logical_operator: Which logical operator to use [and, or].
             id: Use the id of stacks to filter by.
             user: Use the user to filter by.
@@ -7786,6 +7788,7 @@ class Client(metaclass=ClientMetaClass):
             name: The name of the tag.
             color: The color of the tag.
             exclusive: Flag indicating whether the tag is exclusive.
+            resource_type: Filter tags associated with a specific resource type.
             hydrate: Flag deciding whether to hydrate the output model(s)
                 by including metadata fields in the response.

@@ -7805,6 +7808,7 @@ class Client(metaclass=ClientMetaClass):
                 name=name,
                 color=color,
                 exclusive=exclusive,
+                resource_type=resource_type,
             ),
             hydrate=hydrate,
         )
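
Note: the new `resource_type` filter accepts either a raw string or a `TaggableResourceTypes` enum member. A hedged usage sketch; the specific `MODEL` member is an assumption, since this diff only confirms the enum import and the parameter plumbing into `TagFilter`:

    # Sketch: listing only the tags associated with a given resource type.
    from zenml.client import Client
    from zenml.enums import TaggableResourceTypes

    client = Client()

    # TaggableResourceTypes.MODEL is assumed here; a plain string should also
    # be accepted per the Union[str, TaggableResourceTypes] annotation.
    tags = client.list_tags(resource_type=TaggableResourceTypes.MODEL)
    for tag in tags:
        print(tag.name, tag.color)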
zenml/config/server_config.py CHANGED
@@ -592,23 +592,23 @@ class ServerConfiguration(BaseModel):
            server_config.external_user_info_url = (
                f"{server_pro_config.api_url}/users/authorize_server"
            )
-            server_config.external_server_id = server_pro_config.tenant_id
+            server_config.external_server_id = server_pro_config.workspace_id
            server_config.rbac_implementation_source = (
                "zenml.zen_server.rbac.zenml_cloud_rbac.ZenMLCloudRBAC"
            )
            server_config.feature_gate_implementation_source = "zenml.zen_server.feature_gate.zenml_cloud_feature_gate.ZenMLCloudFeatureGateInterface"
            server_config.reportable_resources = DEFAULT_REPORTABLE_RESOURCES
-            server_config.dashboard_url = f"{server_pro_config.dashboard_url}/organizations/{server_pro_config.organization_id}/tenants/{server_pro_config.tenant_id}"
+            server_config.dashboard_url = f"{server_pro_config.dashboard_url}/workspaces/{server_pro_config.workspace_id}"
            server_config.metadata.update(
                dict(
                    account_id=str(server_pro_config.organization_id),
                    organization_id=str(server_pro_config.organization_id),
-                    tenant_id=str(server_pro_config.tenant_id),
+                    workspace_id=str(server_pro_config.workspace_id),
                )
            )
-            if server_pro_config.tenant_name:
+            if server_pro_config.workspace_name:
                server_config.metadata.update(
-                    dict(tenant_name=server_pro_config.tenant_name)
+                    dict(workspace_name=server_pro_config.workspace_name)
                )

        extra_cors_allow_origins = [
@@ -660,8 +660,8 @@ class ServerProConfiguration(BaseModel):
         oauth2_audience: The OAuth2 audience.
         organization_id: The ZenML Pro organization ID.
         organization_name: The ZenML Pro organization name.
-        tenant_id: The ZenML Pro tenant ID.
-        tenant_name: The ZenML Pro tenant name.
+        workspace_id: The ZenML Pro workspace ID.
+        workspace_name: The ZenML Pro workspace name.
     """

     api_url: str
@@ -670,8 +670,8 @@ class ServerProConfiguration(BaseModel):
     oauth2_audience: str
     organization_id: UUID
     organization_name: Optional[str] = None
-    tenant_id: UUID
-    tenant_name: Optional[str] = None
+    workspace_id: UUID
+    workspace_name: Optional[str] = None

     @field_validator("api_url", "dashboard_url")
     @classmethod
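
Note: the derived dashboard URL loses its organization segment entirely. A before/after sketch with illustrative values:

    # Illustrative values only; the real ones come from ServerProConfiguration.
    dashboard_url = "https://cloud.zenml.io"
    organization_id = "11111111-1111-1111-1111-111111111111"
    workspace_id = "22222222-2222-2222-2222-222222222222"

    # Before: {dashboard_url}/organizations/{organization_id}/tenants/{tenant_id}
    # After:
    print(f"{dashboard_url}/workspaces/{workspace_id}")
    # -> https://cloud.zenml.io/workspaces/22222222-2222-2222-2222-222222222222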
zenml/integrations/gcp/__init__.py CHANGED
@@ -52,6 +52,7 @@ class GcpIntegration(Integration):
         "google-cloud-storage>=2.9.0",
         "google-cloud-aiplatform>=1.34.0",  # includes shapely pin fix
         "google-cloud-build>=3.11.0",
+        "google-cloud-pipeline-components>=2.19.0",
         "kubernetes",
     ]
     REQUIREMENTS_IGNORED_ON_UNINSTALL = ["kubernetes","kfp"]
zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py CHANGED
@@ -23,6 +23,9 @@ from zenml.integrations.gcp import (
 from zenml.integrations.gcp.google_credentials_mixin import (
     GoogleCredentialsConfigMixin,
 )
+from zenml.integrations.gcp.vertex_custom_job_parameters import (
+    VertexCustomJobParameters,
+)
 from zenml.integrations.kubernetes.pod_settings import KubernetesPodSettings
 from zenml.models import ServiceConnectorRequirements
 from zenml.orchestrators import BaseOrchestratorConfig, BaseOrchestratorFlavor
@@ -61,6 +64,8 @@ class VertexOrchestratorSettings(BaseSettings):
     node_selector_constraint: Optional[Tuple[str, str]] = None
     pod_settings: Optional[KubernetesPodSettings] = None

+    custom_job_parameters: Optional[VertexCustomJobParameters] = None
+
     _node_selector_deprecation = (
         deprecation_utils.deprecate_pydantic_attributes(
             "node_selector_constraint"
zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py CHANGED
@@ -23,6 +23,9 @@ from zenml.integrations.gcp import (
 from zenml.integrations.gcp.google_credentials_mixin import (
     GoogleCredentialsConfigMixin,
 )
+from zenml.integrations.gcp.vertex_custom_job_parameters import (
+    VertexCustomJobParameters,
+)
 from zenml.models import ServiceConnectorRequirements
 from zenml.step_operators.base_step_operator import (
     BaseStepOperatorConfig,
@@ -33,34 +36,8 @@ if TYPE_CHECKING:
     from zenml.integrations.gcp.step_operators import VertexStepOperator


-class VertexStepOperatorSettings(BaseSettings):
-    """Settings for the Vertex step operator.
-
-    Attributes:
-        accelerator_type: Defines which accelerator (GPU, TPU) is used for the
-            job. Check out out this table to see which accelerator
-            type and count are compatible with your chosen machine type:
-            https://cloud.google.com/vertex-ai/docs/training/configure-compute#gpu-compatibility-table.
-        accelerator_count: Defines number of accelerators to be used for the
-            job. Check out out this table to see which accelerator
-            type and count are compatible with your chosen machine type:
-            https://cloud.google.com/vertex-ai/docs/training/configure-compute#gpu-compatibility-table.
-        machine_type: Machine type specified here
-            https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types.
-        boot_disk_size_gb: Size of the boot disk in GB. (Default: 100)
-            https://cloud.google.com/vertex-ai/docs/training/configure-compute#boot_disk_options
-        boot_disk_type: Type of the boot disk. (Default: pd-ssd)
-            https://cloud.google.com/vertex-ai/docs/training/configure-compute#boot_disk_options
-        persistent_resource_id: The ID of the persistent resource to use for the job.
-            https://cloud.google.com/vertex-ai/docs/training/persistent-resource-overview
-    """
-
-    accelerator_type: Optional[str] = None
-    accelerator_count: int = 0
-    machine_type: str = "n1-standard-4"
-    boot_disk_size_gb: int = 100
-    boot_disk_type: str = "pd-ssd"
-    persistent_resource_id: Optional[str] = None
+class VertexStepOperatorSettings(VertexCustomJobParameters, BaseSettings):
+    """Settings for the Vertex step operator."""


 class VertexStepOperatorConfig(
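
Note: the new `zenml/integrations/gcp/vertex_custom_job_parameters.py` module (+50 lines in the file list) is not itself shown in this diff. Judging from the fields removed above and the `service_account` attribute the orchestrator reads below, it plausibly resembles this reconstruction; the real module may differ:

    # Hypothetical reconstruction: field names and defaults are inferred from
    # the removed VertexStepOperatorSettings plus the orchestrator code below.
    from typing import Optional

    from pydantic import BaseModel


    class VertexCustomJobParameters(BaseModel):
        """Parameters for running a step as a Vertex AI custom job (sketch)."""

        accelerator_type: Optional[str] = None
        accelerator_count: int = 0
        machine_type: str = "n1-standard-4"
        boot_disk_size_gb: int = 100
        boot_disk_type: str = "pd-ssd"
        persistent_resource_id: Optional[str] = None
        service_account: Optional[str] = None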
zenml/integrations/gcp/orchestrators/vertex_orchestrator.py CHANGED
@@ -49,8 +49,12 @@ from uuid import UUID
 from google.api_core import exceptions as google_exceptions
 from google.cloud import aiplatform
 from google.cloud.aiplatform_v1.types import PipelineState
+from google_cloud_pipeline_components.v1.custom_job.utils import (
+    create_custom_training_job_from_component,
+)
 from kfp import dsl
 from kfp.compiler import Compiler
+from kfp.dsl.base_component import BaseComponent

 from zenml.config.resource_settings import ResourceSettings
 from zenml.constants import (
@@ -71,13 +75,15 @@ from zenml.integrations.gcp.flavors.vertex_orchestrator_flavor import (
 from zenml.integrations.gcp.google_credentials_mixin import (
     GoogleCredentialsMixin,
 )
+from zenml.integrations.gcp.vertex_custom_job_parameters import (
+    VertexCustomJobParameters,
+)
 from zenml.io import fileio
 from zenml.logger import get_logger
 from zenml.metadata.metadata_types import MetadataType, Uri
 from zenml.orchestrators import ContainerizedOrchestrator
 from zenml.orchestrators.utils import get_orchestrator_run_name
 from zenml.stack.stack_validator import StackValidator
-from zenml.utils import yaml_utils
 from zenml.utils.io_utils import get_global_config_directory

 if TYPE_CHECKING:
@@ -263,14 +269,14 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
                "schedule to a Vertex orchestrator."
            )

-    def _create_dynamic_component(
+    def _create_container_component(
         self,
         image: str,
         command: List[str],
         arguments: List[str],
         component_name: str,
-    ) -> dsl.PipelineTask:
-        """Creates a dynamic container component for a Vertex pipeline.
+    ) -> BaseComponent:
+        """Creates a container component for a Vertex pipeline.

         Args:
             image: The image to use for the component.
@@ -279,7 +285,7 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
             component_name: The name of the component.

         Returns:
-            The dynamic container component.
+            The container component.
         """

         def dynamic_container_component() -> dsl.ContainerSpec:
@@ -294,7 +300,6 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
                args=arguments,
            )

-        # Change the name of the function
         new_container_spec_func = types.FunctionType(
             dynamic_container_component.__code__,
             dynamic_container_component.__globals__,
@@ -303,9 +308,50 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
             closure=dynamic_container_component.__closure__,
         )
         pipeline_task = dsl.container_component(new_container_spec_func)
-
         return pipeline_task

+    def _convert_to_custom_training_job(
+        self,
+        component: BaseComponent,
+        settings: VertexOrchestratorSettings,
+        environment: Dict[str, str],
+    ) -> BaseComponent:
+        """Convert a component to a custom training job component.
+
+        Args:
+            component: The component to convert.
+            settings: The settings for the custom training job.
+            environment: The environment variables to set in the custom
+                training job.
+
+        Returns:
+            The custom training job component.
+        """
+        custom_job_parameters = (
+            settings.custom_job_parameters or VertexCustomJobParameters()
+        )
+        if (
+            custom_job_parameters.persistent_resource_id
+            and not custom_job_parameters.service_account
+        ):
+            # Persistent resources require an explicit service account, but
+            # none was provided in the custom job parameters. We try to fall
+            # back to the workload service account.
+            custom_job_parameters.service_account = (
+                self.config.workload_service_account
+            )
+
+        custom_job_component = create_custom_training_job_from_component(
+            component_spec=component,
+            env=[
+                {"name": key, "value": value}
+                for key, value in environment.items()
+            ],
+            **custom_job_parameters.model_dump(),
+        )
+
+        return custom_job_component
+
     def prepare_or_run_pipeline(
         self,
         deployment: "PipelineDeploymentResponse",
@@ -383,7 +429,7 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
             Returns:
                 pipeline_func
             """
-            step_name_to_dynamic_component: Dict[str, Any] = {}
+            step_name_to_dynamic_component: Dict[str, BaseComponent] = {}

             for step_name, step in deployment.step_configurations.items():
                 image = self.get_image(
@@ -397,7 +443,7 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
                        deployment_id=deployment.id,
                    )
                )
-                dynamic_component = self._create_dynamic_component(
+                component = self._create_container_component(
                     image, command, arguments, step_name
                 )
                 step_settings = cast(
@@ -442,7 +488,11 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
                        key,
                    )

-                step_name_to_dynamic_component[step_name] = dynamic_component
+                step_name_to_dynamic_component[step_name] = component
+
+            environment[ENV_ZENML_VERTEX_RUN_ID] = (
+                dsl.PIPELINE_JOB_NAME_PLACEHOLDER
+            )

             @dsl.pipeline(  # type: ignore[misc]
                 display_name=orchestrator_run_name,
@@ -462,81 +512,81 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
                        step_name_to_dynamic_component[upstream_step_name]
                        for upstream_step_name in step.spec.upstream_steps
                    ]
-                    task = (
-                        component()
-                        .set_display_name(
-                            name=component_name,
-                        )
-                        .set_caching_options(enable_caching=False)
-                        .set_env_variable(
-                            name=ENV_ZENML_VERTEX_RUN_ID,
-                            value=dsl.PIPELINE_JOB_NAME_PLACEHOLDER,
-                        )
-                        .after(*upstream_step_components)
-                    )

                     step_settings = cast(
                         VertexOrchestratorSettings, self.get_settings(step)
                     )
-                    pod_settings = step_settings.pod_settings
-
-                    node_selector_constraint: Optional[Tuple[str, str]] = None
-                    if pod_settings and (
-                        GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
-                        in pod_settings.node_selectors.keys()
-                    ):
-                        node_selector_constraint = (
-                            GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL,
-                            pod_settings.node_selectors[
-                                GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
-                            ],
+
+                    use_custom_training_job = (
+                        step_settings.custom_job_parameters is not None
+                    )
+
+                    if use_custom_training_job:
+                        if not step.config.resource_settings.empty:
+                            logger.warning(
+                                "Ignoring resource settings because "
+                                "the step is running as a custom training job. "
+                                "Use `custom_job_parameters.machine_type` "
+                                "to configure the machine type instead."
+                            )
+                        if step_settings.node_selector_constraint:
+                            logger.warning(
+                                "Ignoring node selector constraint because "
+                                "the step is running as a custom training job. "
+                                "Use `custom_job_parameters.accelerator_type` "
+                                "to configure the accelerator type instead."
+                            )
+                        component = self._convert_to_custom_training_job(
+                            component,
+                            settings=step_settings,
+                            environment=environment,
+                        )
+                        task = (
+                            component()
+                            .set_display_name(name=component_name)
+                            .set_caching_options(enable_caching=False)
+                            .after(*upstream_step_components)
                         )
-                    elif step_settings.node_selector_constraint:
-                        node_selector_constraint = (
-                            GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL,
-                            step_settings.node_selector_constraint[1],
+                    else:
+                        task = (
+                            component()
+                            .set_display_name(
+                                name=component_name,
+                            )
+                            .set_caching_options(enable_caching=False)
+                            .after(*upstream_step_components)
                         )
+                        for key, value in environment.items():
+                            task = task.set_env_variable(name=key, value=value)

-                    self._configure_container_resources(
-                        dynamic_component=task,
-                        resource_settings=step.config.resource_settings,
-                        node_selector_constraint=node_selector_constraint,
-                    )
+                        pod_settings = step_settings.pod_settings

-            return dynamic_pipeline
+                        node_selector_constraint: Optional[Tuple[str, str]] = (
+                            None
+                        )
+                        if pod_settings and (
+                            GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
+                            in pod_settings.node_selectors.keys()
+                        ):
+                            node_selector_constraint = (
+                                GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL,
+                                pod_settings.node_selectors[
+                                    GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
+                                ],
+                            )
+                        elif step_settings.node_selector_constraint:
+                            node_selector_constraint = (
+                                GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL,
+                                step_settings.node_selector_constraint[1],
+                            )

-        def _update_json_with_environment(
-            yaml_file_path: str, environment: Dict[str, str]
-        ) -> None:
-            """Updates the env section of the steps in the YAML file with the given environment variables.
+                        self._configure_container_resources(
+                            dynamic_component=task,
+                            resource_settings=step.config.resource_settings,
+                            node_selector_constraint=node_selector_constraint,
+                        )

-            Args:
-                yaml_file_path: The path to the YAML file to update.
-                environment: A dictionary of environment variables to add.
-            """
-            pipeline_definition = yaml_utils.read_json(pipeline_file_path)
-
-            # Iterate through each component and add the environment variables
-            for executor in pipeline_definition["deploymentSpec"]["executors"]:
-                if (
-                    "container"
-                    in pipeline_definition["deploymentSpec"]["executors"][
-                        executor
-                    ]
-                ):
-                    container = pipeline_definition["deploymentSpec"][
-                        "executors"
-                    ][executor]["container"]
-                    if "env" not in container:
-                        container["env"] = []
-                    for key, value in environment.items():
-                        container["env"].append({"name": key, "value": value})
-
-            yaml_utils.write_json(pipeline_file_path, pipeline_definition)
-
-            print(
-                f"Updated YAML file with environment variables at {yaml_file_path}"
-            )
+            return dynamic_pipeline

         # Save the generated pipeline to a file.
         fileio.makedirs(self.pipeline_directory)
@@ -556,9 +606,6 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
            ),
        )

-        # Let's update the YAML file with the environment variables
-        _update_json_with_environment(pipeline_file_path, environment)
-
         logger.info(
             "Writing Vertex workflow definition to `%s`.", pipeline_file_path
         )
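
Note: two behavioral changes are bundled into this refactoring. First, `ENV_ZENML_VERTEX_RUN_ID` is now added to the shared `environment` dict and applied together with all other variables, rather than being set individually per task. Second, the removed `_update_json_with_environment` helper, which patched the compiled pipeline JSON after the fact (and ignored its own `yaml_file_path` argument in favor of the enclosing scope's `pipeline_file_path`), is replaced by attaching the variables at pipeline construction time: regular tasks receive them via `set_env_variable`, while custom training jobs receive them through the `env` argument of `create_custom_training_job_from_component`. A minimal standalone sketch of the new propagation pattern, using only kfp APIs that appear in this diff:

    # Sketch: env vars flow into each task at pipeline construction time.
    from typing import Dict

    from kfp import dsl

    environment: Dict[str, str] = {"ZENML_EXAMPLE_VAR": "value"}  # illustrative


    @dsl.container_component
    def example_step() -> dsl.ContainerSpec:
        return dsl.ContainerSpec(image="python:3.11", command=["true"], args=[])


    @dsl.pipeline(name="example")
    def example_pipeline() -> None:
        task = example_step().set_caching_options(enable_caching=False)
        for key, value in environment.items():
            task = task.set_env_variable(name=key, value=value)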