zenml-nightly 0.75.0.dev20250314__py3-none-any.whl → 0.75.0.dev20250316__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/cli/login.py +21 -18
- zenml/cli/server.py +5 -5
- zenml/client.py +5 -1
- zenml/config/server_config.py +9 -9
- zenml/integrations/gcp/__init__.py +1 -0
- zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py +5 -0
- zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py +5 -28
- zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +125 -78
- zenml/integrations/gcp/vertex_custom_job_parameters.py +50 -0
- zenml/login/credentials.py +26 -27
- zenml/login/credentials_store.py +5 -5
- zenml/login/pro/client.py +9 -9
- zenml/login/pro/utils.py +8 -8
- zenml/login/pro/{tenant → workspace}/__init__.py +1 -1
- zenml/login/pro/{tenant → workspace}/client.py +25 -25
- zenml/login/pro/{tenant → workspace}/models.py +27 -28
- zenml/models/v2/core/model.py +9 -1
- zenml/models/v2/core/tag.py +96 -3
- zenml/models/v2/misc/server_models.py +6 -6
- zenml/orchestrators/step_run_utils.py +1 -1
- zenml/utils/dashboard_utils.py +1 -1
- zenml/utils/tag_utils.py +0 -12
- zenml/zen_server/cloud_utils.py +3 -3
- zenml/zen_server/feature_gate/endpoint_utils.py +1 -1
- zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py +1 -1
- zenml/zen_server/rbac/models.py +30 -5
- zenml/zen_server/rbac/zenml_cloud_rbac.py +7 -70
- zenml/zen_server/routers/server_endpoints.py +2 -2
- zenml/zen_server/zen_server_api.py +3 -3
- zenml/zen_stores/base_zen_store.py +3 -3
- zenml/zen_stores/rest_zen_store.py +1 -1
- zenml/zen_stores/sql_zen_store.py +60 -7
- {zenml_nightly-0.75.0.dev20250314.dist-info → zenml_nightly-0.75.0.dev20250316.dist-info}/METADATA +1 -1
- {zenml_nightly-0.75.0.dev20250314.dist-info → zenml_nightly-0.75.0.dev20250316.dist-info}/RECORD +38 -37
- {zenml_nightly-0.75.0.dev20250314.dist-info → zenml_nightly-0.75.0.dev20250316.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.75.0.dev20250314.dist-info → zenml_nightly-0.75.0.dev20250316.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.75.0.dev20250314.dist-info → zenml_nightly-0.75.0.dev20250316.dist-info}/entry_points.txt +0 -0
zenml/VERSION
CHANGED
@@ -1 +1 @@
-0.75.0.dev20250314
+0.75.0.dev20250316
zenml/cli/login.py
CHANGED
@@ -246,7 +246,7 @@ def connect_to_pro_server(
 """
 from zenml.login.credentials_store import get_credentials_store
 from zenml.login.pro.client import ZenMLProClient
-from zenml.login.pro.
+from zenml.login.pro.workspace.models import WorkspaceStatus

 pro_api_url = pro_api_url or ZENML_PRO_API_URL
 pro_api_url = pro_api_url.rstrip("/")
@@ -295,7 +295,7 @@ def connect_to_pro_server(
 # We also need to remove all existing API tokens associated with the
 # target ZenML Pro API, otherwise they will continue to be used after
 # the re-login flow.
-credentials_store.clear_all_pro_tokens(
+credentials_store.clear_all_pro_tokens()
 try:
 token = web_login(
 pro_api_url=pro_api_url,
@@ -310,13 +310,14 @@ def connect_to_pro_server(
 "your session expires."
 )

-
+workspace_id: Optional[str] = None
 if token.device_metadata:
-
+# TODO: is this still correct?
+workspace_id = token.device_metadata.get("tenant_id")

-if
+if workspace_id is None and pro_server is None:
 # This is not really supposed to happen, because the implementation
-# of the web login workflow should always return a
+# of the web login workflow should always return a workspace ID, but
 # we're handling it just in case.
 cli_utils.declare(
 "A valid server was not selected during the login process. "
@@ -328,14 +329,14 @@ def connect_to_pro_server(

 # The server selected during the web login process overrides any
 # server argument passed to the command.
-server_id = UUID(
+server_id = UUID(workspace_id)

 client = ZenMLProClient(pro_api_url)

 if server_id:
-server = client.
+server = client.workspace.get(server_id)
 elif server_url:
-servers = client.
+servers = client.workspace.list(url=server_url, member_only=True)
 if not servers:
 raise AuthorizationException(
 f"The '{server_url}' URL belongs to a ZenML Pro server, "
@@ -345,7 +346,9 @@ def connect_to_pro_server(

 server = servers[0]
 elif server_name:
-servers = client.
+servers = client.workspace.list(
+workspace_name=server_name, member_only=True
+)
 if not servers:
 raise AuthorizationException(
 f"No ZenML Pro server with the name '{server_name}' exists "
@@ -361,15 +364,15 @@ def connect_to_pro_server(

 server_id = server.id

-if server.status ==
+if server.status == WorkspaceStatus.PENDING:
 with console.status(
 f"Waiting for your `{server.name}` ZenML Pro server to be set up..."
 ):
 timeout = 180 # 3 minutes
 while True:
 time.sleep(5)
-server = client.
-if server.status !=
+server = client.workspace.get(server_id)
+if server.status != WorkspaceStatus.PENDING:
 break
 timeout -= 5
 if timeout <= 0:
@@ -380,7 +383,7 @@ def connect_to_pro_server(
 f"ZenML Pro dashboard at {server.dashboard_url}."
 )

-if server.status ==
+if server.status == WorkspaceStatus.FAILED:
 cli_utils.error(
 f"Your `{server.name}` ZenML Pro server is currently in a "
 "failed state. Please manage the server state by visiting the "
@@ -388,7 +391,7 @@ def connect_to_pro_server(
 "your server administrator."
 )

-elif server.status ==
+elif server.status == WorkspaceStatus.DEACTIVATED:
 cli_utils.error(
 f"Your `{server.name}` ZenML Pro server is currently "
 "deactivated. Please manage the server state by visiting the "
@@ -396,7 +399,7 @@ def connect_to_pro_server(
 "your server administrator."
 )

-elif server.status ==
+elif server.status == WorkspaceStatus.AVAILABLE:
 if not server.url:
 cli_utils.error(
 f"The ZenML Pro server '{server.name}' is not currently "
@@ -418,7 +421,7 @@ def connect_to_pro_server(
 connect_to_server(server.url, api_key=api_key, pro_server=True)

 # Update the stored server info with more accurate data taken from the
-# ZenML Pro
+# ZenML Pro workspace object.
 credentials_store.update_server_info(server.url, server)

 cli_utils.declare(f"Connected to ZenML Pro server: {server.name}.")
@@ -836,7 +839,7 @@ def login(
 pro_api_url=pro_api_url,
 )

-elif current_non_local_server:
+elif current_non_local_server and not refresh:
 # The server argument is not provided, so we default to
 # re-authenticating to the current non-local server that the client is
 # connected to.
zenml/cli/server.py
CHANGED
@@ -220,7 +220,7 @@ def status() -> None:
 )
 if pro_credentials:
 pro_client = ZenMLProClient(pro_credentials.url)
-pro_servers = pro_client.
+pro_servers = pro_client.workspace.list(
 url=store_cfg.url, member_only=True
 )
 if pro_servers:
@@ -575,7 +575,7 @@ def server_list(
 from zenml.login.credentials_store import get_credentials_store
 from zenml.login.pro.client import ZenMLProClient
 from zenml.login.pro.constants import ZENML_PRO_API_URL
-from zenml.login.pro.
+from zenml.login.pro.workspace.models import WorkspaceRead, WorkspaceStatus

 pro_api_url = pro_api_url or ZENML_PRO_API_URL
 pro_api_url = pro_api_url.rstrip("/")
@@ -601,10 +601,10 @@ def server_list(
 # that the user has never connected to (and are therefore not stored in
 # the credentials store).

-accessible_pro_servers: List[
+accessible_pro_servers: List[WorkspaceRead] = []
 try:
 client = ZenMLProClient(pro_api_url)
-accessible_pro_servers = client.
+accessible_pro_servers = client.workspace.list(member_only=not all)
 except AuthorizationException as e:
 cli_utils.warning(f"ZenML Pro authorization error: {e}")

@@ -638,7 +638,7 @@ def server_list(
 accessible_pro_servers = [
 s
 for s in accessible_pro_servers
-if s.status ==
+if s.status == WorkspaceStatus.AVAILABLE
 ]

 if not accessible_pro_servers:
zenml/client.py
CHANGED
@@ -68,6 +68,7 @@ from zenml.enums import (
 SorterOps,
 StackComponentType,
 StoreType,
+TaggableResourceTypes,
 )
 from zenml.exceptions import (
 AuthorizationException,
@@ -7770,6 +7771,7 @@ class Client(metaclass=ClientMetaClass):
 name: Optional[str] = None,
 color: Optional[Union[str, ColorVariants]] = None,
 exclusive: Optional[bool] = None,
+resource_type: Optional[Union[str, TaggableResourceTypes]] = None,
 hydrate: bool = False,
 ) -> Page[TagResponse]:
 """Get tags by filter.
@@ -7777,7 +7779,7 @@ class Client(metaclass=ClientMetaClass):
 Args:
 sort_by: The column to sort by.
 page: The page of items.
-size: The maximum size of all pages
+size: The maximum size of all pages
 logical_operator: Which logical operator to use [and, or].
 id: Use the id of stacks to filter by.
 user: Use the user to filter by.
@@ -7786,6 +7788,7 @@ class Client(metaclass=ClientMetaClass):
 name: The name of the tag.
 color: The color of the tag.
 exclusive: Flag indicating whether the tag is exclusive.
+resource_type: Filter tags associated with a specific resource type.
 hydrate: Flag deciding whether to hydrate the output model(s)
 by including metadata fields in the response.

@@ -7805,6 +7808,7 @@ class Client(metaclass=ClientMetaClass):
 name=name,
 color=color,
 exclusive=exclusive,
+resource_type=resource_type,
 ),
 hydrate=hydrate,
 )
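Note: the new `resource_type` argument on `Client.list_tags` shown above can be exercised as below; a minimal sketch, assuming `TaggableResourceTypes` exposes a member such as `MODEL` (the concrete member names are not part of this diff).

from zenml.client import Client
from zenml.enums import TaggableResourceTypes

# List only the tags attached to model resources. MODEL is an assumed
# example member of TaggableResourceTypes; a plain string is also accepted,
# since the parameter is typed Union[str, TaggableResourceTypes].
tags = Client().list_tags(resource_type=TaggableResourceTypes.MODEL)
for tag in tags.items:
    print(tag.name, tag.color)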
zenml/config/server_config.py
CHANGED
@@ -592,23 +592,23 @@ class ServerConfiguration(BaseModel):
 server_config.external_user_info_url = (
 f"{server_pro_config.api_url}/users/authorize_server"
 )
-server_config.external_server_id = server_pro_config.
+server_config.external_server_id = server_pro_config.workspace_id
 server_config.rbac_implementation_source = (
 "zenml.zen_server.rbac.zenml_cloud_rbac.ZenMLCloudRBAC"
 )
 server_config.feature_gate_implementation_source = "zenml.zen_server.feature_gate.zenml_cloud_feature_gate.ZenMLCloudFeatureGateInterface"
 server_config.reportable_resources = DEFAULT_REPORTABLE_RESOURCES
-server_config.dashboard_url = f"{server_pro_config.dashboard_url}/
+server_config.dashboard_url = f"{server_pro_config.dashboard_url}/workspaces/{server_pro_config.workspace_id}"
 server_config.metadata.update(
 dict(
 account_id=str(server_pro_config.organization_id),
 organization_id=str(server_pro_config.organization_id),
-
+workspace_id=str(server_pro_config.workspace_id),
 )
 )
-if server_pro_config.
+if server_pro_config.workspace_name:
 server_config.metadata.update(
-dict(
+dict(workspace_name=server_pro_config.workspace_name)
 )

 extra_cors_allow_origins = [
@@ -660,8 +660,8 @@ class ServerProConfiguration(BaseModel):
 oauth2_audience: The OAuth2 audience.
 organization_id: The ZenML Pro organization ID.
 organization_name: The ZenML Pro organization name.
-
-
+workspace_id: The ZenML Pro workspace ID.
+workspace_name: The ZenML Pro workspace name.
 """

 api_url: str
@@ -670,8 +670,8 @@ class ServerProConfiguration(BaseModel):
 oauth2_audience: str
 organization_id: UUID
 organization_name: Optional[str] = None
-
-
+workspace_id: UUID
+workspace_name: Optional[str] = None

 @field_validator("api_url", "dashboard_url")
 @classmethod
zenml/integrations/gcp/__init__.py
CHANGED
@@ -52,6 +52,7 @@ class GcpIntegration(Integration):
 "google-cloud-storage>=2.9.0",
 "google-cloud-aiplatform>=1.34.0", # includes shapely pin fix
 "google-cloud-build>=3.11.0",
+"google-cloud-pipeline-components>=2.19.0",
 "kubernetes",
 ]
 REQUIREMENTS_IGNORED_ON_UNINSTALL = ["kubernetes","kfp"]
zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py
CHANGED
@@ -23,6 +23,9 @@ from zenml.integrations.gcp import (
 from zenml.integrations.gcp.google_credentials_mixin import (
 GoogleCredentialsConfigMixin,
 )
+from zenml.integrations.gcp.vertex_custom_job_parameters import (
+VertexCustomJobParameters,
+)
 from zenml.integrations.kubernetes.pod_settings import KubernetesPodSettings
 from zenml.models import ServiceConnectorRequirements
 from zenml.orchestrators import BaseOrchestratorConfig, BaseOrchestratorFlavor
@@ -61,6 +64,8 @@ class VertexOrchestratorSettings(BaseSettings):
 node_selector_constraint: Optional[Tuple[str, str]] = None
 pod_settings: Optional[KubernetesPodSettings] = None

+custom_job_parameters: Optional[VertexCustomJobParameters] = None
+
 _node_selector_deprecation = (
 deprecation_utils.deprecate_pydantic_attributes(
 "node_selector_constraint"
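Note: the new `custom_job_parameters` setting above lets steps or whole pipelines run as Vertex AI custom training jobs. A minimal sketch, assuming `VertexCustomJobParameters` carries the fields previously declared on `VertexStepOperatorSettings` (see the step operator diff further down) and that the orchestrator flavor is registered as "vertex".

from zenml import step
from zenml.integrations.gcp.flavors.vertex_orchestrator_flavor import (
    VertexOrchestratorSettings,
)
from zenml.integrations.gcp.vertex_custom_job_parameters import (
    VertexCustomJobParameters,
)

# Field names assumed from the removed VertexStepOperatorSettings attributes.
vertex_settings = VertexOrchestratorSettings(
    custom_job_parameters=VertexCustomJobParameters(
        machine_type="n1-standard-8",
        accelerator_type="NVIDIA_TESLA_T4",
        accelerator_count=1,
    )
)

@step(settings={"orchestrator.vertex": vertex_settings})
def train() -> None:
    ...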
zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py
CHANGED
@@ -23,6 +23,9 @@ from zenml.integrations.gcp import (
 from zenml.integrations.gcp.google_credentials_mixin import (
 GoogleCredentialsConfigMixin,
 )
+from zenml.integrations.gcp.vertex_custom_job_parameters import (
+VertexCustomJobParameters,
+)
 from zenml.models import ServiceConnectorRequirements
 from zenml.step_operators.base_step_operator import (
 BaseStepOperatorConfig,
@@ -33,34 +36,8 @@ if TYPE_CHECKING:
 from zenml.integrations.gcp.step_operators import VertexStepOperator


-class VertexStepOperatorSettings(BaseSettings):
-"""Settings for the Vertex step operator.
-
-Attributes:
-accelerator_type: Defines which accelerator (GPU, TPU) is used for the
-job. Check out out this table to see which accelerator
-type and count are compatible with your chosen machine type:
-https://cloud.google.com/vertex-ai/docs/training/configure-compute#gpu-compatibility-table.
-accelerator_count: Defines number of accelerators to be used for the
-job. Check out out this table to see which accelerator
-type and count are compatible with your chosen machine type:
-https://cloud.google.com/vertex-ai/docs/training/configure-compute#gpu-compatibility-table.
-machine_type: Machine type specified here
-https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types.
-boot_disk_size_gb: Size of the boot disk in GB. (Default: 100)
-https://cloud.google.com/vertex-ai/docs/training/configure-compute#boot_disk_options
-boot_disk_type: Type of the boot disk. (Default: pd-ssd)
-https://cloud.google.com/vertex-ai/docs/training/configure-compute#boot_disk_options
-persistent_resource_id: The ID of the persistent resource to use for the job.
-https://cloud.google.com/vertex-ai/docs/training/persistent-resource-overview
-"""
-
-accelerator_type: Optional[str] = None
-accelerator_count: int = 0
-machine_type: str = "n1-standard-4"
-boot_disk_size_gb: int = 100
-boot_disk_type: str = "pd-ssd"
-persistent_resource_id: Optional[str] = None
+class VertexStepOperatorSettings(VertexCustomJobParameters, BaseSettings):
+"""Settings for the Vertex step operator."""


 class VertexStepOperatorConfig(
zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
CHANGED
@@ -49,8 +49,12 @@ from uuid import UUID
 from google.api_core import exceptions as google_exceptions
 from google.cloud import aiplatform
 from google.cloud.aiplatform_v1.types import PipelineState
+from google_cloud_pipeline_components.v1.custom_job.utils import (
+create_custom_training_job_from_component,
+)
 from kfp import dsl
 from kfp.compiler import Compiler
+from kfp.dsl.base_component import BaseComponent

 from zenml.config.resource_settings import ResourceSettings
 from zenml.constants import (
@@ -71,13 +75,15 @@ from zenml.integrations.gcp.flavors.vertex_orchestrator_flavor import (
 from zenml.integrations.gcp.google_credentials_mixin import (
 GoogleCredentialsMixin,
 )
+from zenml.integrations.gcp.vertex_custom_job_parameters import (
+VertexCustomJobParameters,
+)
 from zenml.io import fileio
 from zenml.logger import get_logger
 from zenml.metadata.metadata_types import MetadataType, Uri
 from zenml.orchestrators import ContainerizedOrchestrator
 from zenml.orchestrators.utils import get_orchestrator_run_name
 from zenml.stack.stack_validator import StackValidator
-from zenml.utils import yaml_utils
 from zenml.utils.io_utils import get_global_config_directory

 if TYPE_CHECKING:
@@ -263,14 +269,14 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
 "schedule to a Vertex orchestrator."
 )

-def
+def _create_container_component(
 self,
 image: str,
 command: List[str],
 arguments: List[str],
 component_name: str,
-) ->
-"""Creates a
+) -> BaseComponent:
+"""Creates a container component for a Vertex pipeline.

 Args:
 image: The image to use for the component.
@@ -279,7 +285,7 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
 component_name: The name of the component.

 Returns:
-The
+The container component.
 """

 def dynamic_container_component() -> dsl.ContainerSpec:
@@ -294,7 +300,6 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
 args=arguments,
 )

-# Change the name of the function
 new_container_spec_func = types.FunctionType(
 dynamic_container_component.__code__,
 dynamic_container_component.__globals__,
@@ -303,9 +308,50 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
 closure=dynamic_container_component.__closure__,
 )
 pipeline_task = dsl.container_component(new_container_spec_func)
-
 return pipeline_task

+def _convert_to_custom_training_job(
+self,
+component: BaseComponent,
+settings: VertexOrchestratorSettings,
+environment: Dict[str, str],
+) -> BaseComponent:
+"""Convert a component to a custom training job component.
+
+Args:
+component: The component to convert.
+settings: The settings for the custom training job.
+environment: The environment variables to set in the custom
+training job.
+
+Returns:
+The custom training job component.
+"""
+custom_job_parameters = (
+settings.custom_job_parameters or VertexCustomJobParameters()
+)
+if (
+custom_job_parameters.persistent_resource_id
+and not custom_job_parameters.service_account
+):
+# Persistent resources require an explicit service account, but
+# none was provided in the custom job parameters. We try to fall
+# back to the workload service account.
+custom_job_parameters.service_account = (
+self.config.workload_service_account
+)
+
+custom_job_component = create_custom_training_job_from_component(
+component_spec=component,
+env=[
+{"name": key, "value": value}
+for key, value in environment.items()
+],
+**custom_job_parameters.model_dump(),
+)
+
+return custom_job_component
+
 def prepare_or_run_pipeline(
 self,
 deployment: "PipelineDeploymentResponse",
@@ -383,7 +429,7 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
 Returns:
 pipeline_func
 """
-step_name_to_dynamic_component: Dict[str,
+step_name_to_dynamic_component: Dict[str, BaseComponent] = {}

 for step_name, step in deployment.step_configurations.items():
 image = self.get_image(
@@ -397,7 +443,7 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
 deployment_id=deployment.id,
 )
 )
-
+component = self._create_container_component(
 image, command, arguments, step_name
 )
 step_settings = cast(
@@ -442,7 +488,11 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
 key,
 )

-step_name_to_dynamic_component[step_name] =
+step_name_to_dynamic_component[step_name] = component
+
+environment[ENV_ZENML_VERTEX_RUN_ID] = (
+dsl.PIPELINE_JOB_NAME_PLACEHOLDER
+)

 @dsl.pipeline( # type: ignore[misc]
 display_name=orchestrator_run_name,
@@ -462,81 +512,81 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
 step_name_to_dynamic_component[upstream_step_name]
 for upstream_step_name in step.spec.upstream_steps
 ]
-task = (
-component()
-.set_display_name(
-name=component_name,
-)
-.set_caching_options(enable_caching=False)
-.set_env_variable(
-name=ENV_ZENML_VERTEX_RUN_ID,
-value=dsl.PIPELINE_JOB_NAME_PLACEHOLDER,
-)
-.after(*upstream_step_components)
-)

 step_settings = cast(
 VertexOrchestratorSettings, self.get_settings(step)
 )
-
-
-
-
-
-
-
-
-
-
-
-
+
+use_custom_training_job = (
+step_settings.custom_job_parameters is not None
+)
+
+if use_custom_training_job:
+if not step.config.resource_settings.empty:
+logger.warning(
+"Ignoring resource settings because "
+"the step is running as a custom training job. "
+"Use `custom_job_parameters.machine_type` "
+"to configure the machine type instead."
+)
+if step_settings.node_selector_constraint:
+logger.warning(
+"Ignoring node selector constraint because "
+"the step is running as a custom training job. "
+"Use `custom_job_parameters.accelerator_type` "
+"to configure the accelerator type instead."
+)
+component = self._convert_to_custom_training_job(
+component,
+settings=step_settings,
+environment=environment,
+)
+task = (
+component()
+.set_display_name(name=component_name)
+.set_caching_options(enable_caching=False)
+.after(*upstream_step_components)
 )
-
-
-
-
+else:
+task = (
+component()
+.set_display_name(
+name=component_name,
+)
+.set_caching_options(enable_caching=False)
+.after(*upstream_step_components)
 )
+for key, value in environment.items():
+task = task.set_env_variable(name=key, value=value)

-
-dynamic_component=task,
-resource_settings=step.config.resource_settings,
-node_selector_constraint=node_selector_constraint,
-)
+pod_settings = step_settings.pod_settings

-
+node_selector_constraint: Optional[Tuple[str, str]] = (
+None
+)
+if pod_settings and (
+GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
+in pod_settings.node_selectors.keys()
+):
+node_selector_constraint = (
+GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL,
+pod_settings.node_selectors[
+GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
+],
+)
+elif step_settings.node_selector_constraint:
+node_selector_constraint = (
+GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL,
+step_settings.node_selector_constraint[1],
+)

-
-
-
-
+self._configure_container_resources(
+dynamic_component=task,
+resource_settings=step.config.resource_settings,
+node_selector_constraint=node_selector_constraint,
+)

-
-yaml_file_path: The path to the YAML file to update.
-environment: A dictionary of environment variables to add.
-"""
-pipeline_definition = yaml_utils.read_json(pipeline_file_path)
-
-# Iterate through each component and add the environment variables
-for executor in pipeline_definition["deploymentSpec"]["executors"]:
-if (
-"container"
-in pipeline_definition["deploymentSpec"]["executors"][
-executor
-]
-):
-container = pipeline_definition["deploymentSpec"][
-"executors"
-][executor]["container"]
-if "env" not in container:
-container["env"] = []
-for key, value in environment.items():
-container["env"].append({"name": key, "value": value})
-
-yaml_utils.write_json(pipeline_file_path, pipeline_definition)
-
-print(
-f"Updated YAML file with environment variables at {yaml_file_path}"
-)
+return dynamic_pipeline

 # Save the generated pipeline to a file.
 fileio.makedirs(self.pipeline_directory)
@@ -556,9 +606,6 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
 ),
 )

-# Let's update the YAML file with the environment variables
-_update_json_with_environment(pipeline_file_path, environment)
-
 logger.info(
 "Writing Vertex workflow definition to `%s`.", pipeline_file_path
 )
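Note: per the `_convert_to_custom_training_job` hunk above, a step only takes the custom training job path when its settings carry `custom_job_parameters`, and a persistent resource without an explicit service account falls back to the orchestrator's `workload_service_account`. A hedged sketch of a step opting in (field names assumed, the resource ID is a placeholder):

from zenml import step
from zenml.integrations.gcp.flavors.vertex_orchestrator_flavor import (
    VertexOrchestratorSettings,
)
from zenml.integrations.gcp.vertex_custom_job_parameters import (
    VertexCustomJobParameters,
)

@step(
    settings={
        # Settings key assumes a Vertex orchestrator flavor named "vertex".
        "orchestrator.vertex": VertexOrchestratorSettings(
            custom_job_parameters=VertexCustomJobParameters(
                # Placeholder ID; with no service_account set here, the
                # orchestrator falls back to config.workload_service_account.
                persistent_resource_id="my-persistent-resource",
            )
        )
    }
)
def evaluate() -> None:
    ...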