zenml-nightly 0.70.0.dev20241202__py3-none-any.whl → 0.70.0.dev20241204__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/cli/__init__.py +4 -4
- zenml/cli/base.py +1 -1
- zenml/cli/pipeline.py +48 -79
- zenml/config/secret_reference_mixin.py +1 -1
- zenml/image_builders/base_image_builder.py +5 -2
- zenml/image_builders/build_context.py +7 -16
- zenml/integrations/aws/__init__.py +3 -0
- zenml/integrations/aws/flavors/__init__.py +6 -0
- zenml/integrations/aws/flavors/aws_image_builder_flavor.py +146 -0
- zenml/integrations/aws/image_builders/__init__.py +20 -0
- zenml/integrations/aws/image_builders/aws_image_builder.py +307 -0
- zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +1 -1
- zenml/integrations/kaniko/image_builders/kaniko_image_builder.py +2 -1
- zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py +11 -0
- zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py +11 -0
- zenml/integrations/neptune/experiment_trackers/neptune_experiment_tracker.py +7 -5
- zenml/integrations/neptune/experiment_trackers/run_state.py +69 -53
- zenml/integrations/registry.py +2 -2
- zenml/integrations/skypilot/flavors/skypilot_orchestrator_base_vm_config.py +12 -0
- zenml/materializers/built_in_materializer.py +1 -1
- zenml/orchestrators/base_orchestrator.py +13 -1
- zenml/orchestrators/output_utils.py +5 -1
- zenml/service_connectors/service_connector_utils.py +3 -9
- zenml/stack/stack_component.py +1 -1
- zenml/stack_deployments/aws_stack_deployment.py +22 -0
- zenml/utils/archivable.py +65 -36
- zenml/utils/code_utils.py +8 -4
- zenml/utils/docker_utils.py +9 -0
- zenml/zen_stores/rest_zen_store.py +1 -1
- {zenml_nightly-0.70.0.dev20241202.dist-info → zenml_nightly-0.70.0.dev20241204.dist-info}/METADATA +1 -1
- {zenml_nightly-0.70.0.dev20241202.dist-info → zenml_nightly-0.70.0.dev20241204.dist-info}/RECORD +35 -32
- {zenml_nightly-0.70.0.dev20241202.dist-info → zenml_nightly-0.70.0.dev20241204.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.70.0.dev20241202.dist-info → zenml_nightly-0.70.0.dev20241204.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.70.0.dev20241202.dist-info → zenml_nightly-0.70.0.dev20241204.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,307 @@
|
|
1
|
+
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
12
|
+
# or implied. See the License for the specific language governing
|
13
|
+
# permissions and limitations under the License.
|
14
|
+
"""AWS Code Build image builder implementation."""
|
15
|
+
|
16
|
+
import time
|
17
|
+
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, cast
|
18
|
+
from urllib.parse import urlparse
|
19
|
+
from uuid import uuid4
|
20
|
+
|
21
|
+
import boto3
|
22
|
+
|
23
|
+
from zenml.enums import StackComponentType
|
24
|
+
from zenml.image_builders import BaseImageBuilder
|
25
|
+
from zenml.integrations.aws import (
|
26
|
+
AWS_CONTAINER_REGISTRY_FLAVOR,
|
27
|
+
)
|
28
|
+
from zenml.integrations.aws.flavors import AWSImageBuilderConfig
|
29
|
+
from zenml.logger import get_logger
|
30
|
+
from zenml.stack import StackValidator
|
31
|
+
from zenml.utils.archivable import ArchiveType
|
32
|
+
|
33
|
+
if TYPE_CHECKING:
|
34
|
+
from zenml.container_registries import BaseContainerRegistry
|
35
|
+
from zenml.image_builders import BuildContext
|
36
|
+
from zenml.stack import Stack
|
37
|
+
|
38
|
+
logger = get_logger(__name__)
|
39
|
+
|
40
|
+
|
41
|
+
class AWSImageBuilder(BaseImageBuilder):
    """AWS Code Build image builder implementation.

    The build context is uploaded as a ZIP archive through
    `_upload_build_context` and an AWS Code Build run is started that builds
    the Docker image and pushes it to the container registry.
    """

    # Lazily created boto3 Code Build client, cached between calls.
    _code_build_client: Optional[Any] = None

    @property
    def config(self) -> AWSImageBuilderConfig:
        """The stack component configuration.

        Returns:
            The configuration.
        """
        return cast(AWSImageBuilderConfig, self._config)

    @property
    def is_building_locally(self) -> bool:
        """Whether the image builder builds the images on the client machine.

        Returns:
            True if the image builder builds locally, False otherwise.
        """
        return False

    @property
    def validator(self) -> Optional["StackValidator"]:
        """Validates the stack for the AWS Code Build Image Builder.

        The AWS Code Build Image Builder requires a container registry to
        push the image to and an S3 Artifact Store to upload the build context,
        so AWS Code Build can access it.

        Returns:
            Stack validator.
        """

        def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]:
            if stack.artifact_store.flavor != "s3":
                return False, (
                    "The AWS Image Builder requires an S3 Artifact Store to "
                    "upload the build context, so AWS Code Build can access "
                    "it. Please update your stack to include an S3 Artifact "
                    "Store and try again."
                )

            return True, ""

        return StackValidator(
            required_components={StackComponentType.CONTAINER_REGISTRY},
            custom_validation_function=_validate_remote_components,
        )

    @property
    def code_build_client(self) -> Any:
        """The authenticated AWS Code Build client to use for interacting with AWS services.

        Returns:
            The authenticated AWS Code Build client.

        Raises:
            RuntimeError: If the AWS Code Build client cannot be created.
        """
        # Drop the cached client once the linked connector has expired so a
        # fresh session is created below.
        if (
            self._code_build_client is not None
            and self.connector_has_expired()
        ):
            self._code_build_client = None
        if self._code_build_client is not None:
            return self._code_build_client

        # Option 1: Service connector
        if connector := self.get_connector():
            boto_session = connector.connect()
            if not isinstance(boto_session, boto3.Session):
                raise RuntimeError(
                    f"Expected to receive a `boto3.Session` object from the "
                    f"linked connector, but got type `{type(boto_session)}`."
                )
        # Option 2: Implicit configuration
        else:
            boto_session = boto3.Session()

        self._code_build_client = boto_session.client("codebuild")
        return self._code_build_client

    @staticmethod
    def _replace_image_tag(image_name: str, tag: str) -> str:
        """Return `image_name` with its tag (or digest) replaced by `tag`.

        Args:
            image_name: Image reference, optionally including a registry
                host with port, a tag and/or a digest.
            tag: The tag to put on the image reference.

        Returns:
            The image reference carrying only the given tag.
        """
        # Drop a trailing digest (`name@sha256:...`) if present.
        repository = image_name.split("@", 1)[0]
        # A colon only separates a tag when it comes after the last slash;
        # an earlier colon belongs to the registry host (`host:5000/repo`).
        if ":" in repository.rsplit("/", 1)[-1]:
            repository = repository.rsplit(":", 1)[0]
        return f"{repository}:{tag}"

    def build(
        self,
        image_name: str,
        build_context: "BuildContext",
        docker_build_options: Dict[str, Any],
        container_registry: Optional["BaseContainerRegistry"] = None,
    ) -> str:
        """Builds and pushes a Docker image.

        Args:
            image_name: Name of the image to build and push.
            build_context: The build context to use for the image.
            docker_build_options: Docker build options.
            container_registry: Optional container registry to push to.

        Returns:
            The Docker image name, tagged with a unique ID generated for
            this build.

        Raises:
            RuntimeError: If no container registry is passed.
            RuntimeError: If the Code Build build fails.
        """
        if not container_registry:
            raise RuntimeError(
                "The AWS Image Builder requires a container registry to push "
                "the image to. Please provide one and try again."
            )

        logger.info("Using AWS Code Build to build image `%s`", image_name)
        cloud_build_context = self._upload_build_context(
            build_context=build_context,
            parent_path_directory_name=f"code-build-contexts/{str(self.id)}",
            archive_type=ArchiveType.ZIP,
        )

        url_parts = urlparse(cloud_build_context)
        bucket = url_parts.netloc
        object_path = url_parts.path.lstrip("/")
        logger.info(
            "Build context located in bucket `%s` and object path `%s`",
            bucket,
            object_path,
        )

        # Pass authentication credentials as environment variables, if
        # the container registry has credentials and if implicit authentication
        # is disabled
        environment_variables_override: Dict[str, str] = {}
        pre_build_commands = []
        if not self.config.implicit_container_registry_auth:
            credentials = container_registry.credentials
            if credentials:
                # NOTE(review): these are sent as PLAINTEXT overrides (see
                # below) — confirm this is acceptable for registry secrets.
                environment_variables_override = {
                    "CONTAINER_REGISTRY_USERNAME": credentials[0],
                    "CONTAINER_REGISTRY_PASSWORD": credentials[1],
                }
                pre_build_commands = [
                    "echo Logging in to container registry",
                    'echo "$CONTAINER_REGISTRY_PASSWORD" | docker login --username "$CONTAINER_REGISTRY_USERNAME" --password-stdin '
                    f"{container_registry.config.uri}",
                ]
        elif container_registry.flavor == AWS_CONTAINER_REGISTRY_FLAVOR:
            # Use the public client metadata instead of private attributes
            # to look up the region the client is bound to.
            region_name = self.code_build_client.meta.region_name
            pre_build_commands = [
                "echo Logging in to ECR",
                f"aws ecr get-login-password --region {region_name} | docker login --username AWS --password-stdin {container_registry.config.uri}",
            ]

        # Convert the docker_build_options dictionary to a command line
        # argument string: list values repeat the option, booleans become
        # bare flags (False is omitted), other values follow the option.
        docker_build_args = ""
        for key, value in docker_build_options.items():
            option = f"--{key}"
            if isinstance(value, list):
                for val in value:
                    docker_build_args += f"{option} {val} "
            elif value is not None and not isinstance(value, bool):
                docker_build_args += f"{option} {value} "
            elif value is not False:
                docker_build_args += f"{option} "

        # Only emit the pre_build phase when there is something to run: a
        # phase that is present must carry a sequence of commands.
        pre_build_phase = ""
        if pre_build_commands:
            pre_build_commands_str = "\n".join(
                f"      - {command}" for command in pre_build_commands
            )
            pre_build_phase = (
                f"  pre_build:\n    commands:\n{pre_build_commands_str}\n"
            )

        # Generate and use a unique tag for the Docker image. This is easier
        # than trying to parse the image digest from the Code Build logs.
        build_id = str(uuid4())
        # Replace the tag in the image name with the unique build ID
        alt_image_name = self._replace_image_tag(image_name, build_id)

        buildspec = f"""
version: 0.2
phases:
{pre_build_phase}  build:
    commands:
      - echo Build started on `date`
      - echo Building the Docker image...
      - docker build -t {image_name} . {docker_build_args}
      - echo Build completed on `date`
  post_build:
    commands:
      - echo Pushing the Docker image...
      - docker push {image_name}
      - docker tag {image_name} {alt_image_name}
      - docker push {alt_image_name}
      - echo Pushed the Docker image
artifacts:
  files:
    - '**/*'
"""

        if self.config.custom_env_vars:
            environment_variables_override.update(self.config.custom_env_vars)

        environment_variables_override_list = [
            {
                "name": key,
                "value": value,
                "type": "PLAINTEXT",
            }
            for key, value in environment_variables_override.items()
        ]

        # Override the build project with the parameters needed to run a
        # docker-in-docker build, as covered here: https://docs.aws.amazon.com/codebuild/latest/userguide/sample-docker-section.html
        response = self.code_build_client.start_build(
            projectName=self.config.code_build_project,
            environmentTypeOverride="LINUX_CONTAINER",
            imageOverride=self.config.build_image,
            computeTypeOverride=self.config.compute_type,
            privilegedModeOverride=False,
            sourceTypeOverride="S3",
            sourceLocationOverride=f"{bucket}/{object_path}",
            buildspecOverride=buildspec,
            environmentVariablesOverride=environment_variables_override_list,
            # no artifacts
            artifactsOverride={"type": "NO_ARTIFACTS"},
        )

        build_arn = response["build"]["arn"]

        # Parse the AWS region, account, codebuild project and build name from the ARN
        aws_region, aws_account, build = build_arn.split(":", maxsplit=5)[3:6]
        codebuild_project = build.split("/")[1].split(":")[0]

        logs_url = f"https://{aws_region}.console.aws.amazon.com/codesuite/codebuild/{aws_account}/projects/{codebuild_project}/{build}/log"
        logger.info(
            "Running Code Build to build the Docker image. Code Build "
            "logs: `%s`",
            logs_url,
        )

        # Wait for the build to reach a terminal state
        code_build_id = response["build"]["id"]
        terminal_statuses = {
            "SUCCEEDED",
            "FAILED",
            "FAULT",
            "TIMED_OUT",
            "STOPPED",
        }
        while True:
            build_status = self.code_build_client.batch_get_builds(
                ids=[code_build_id]
            )
            status = build_status["builds"][0]["buildStatus"]
            if status in terminal_statuses:
                break
            time.sleep(10)

        if status != "SUCCEEDED":
            raise RuntimeError(
                f"The Code Build run to build the Docker image has failed "
                f"(status: {status}). More information can be found in the "
                f"Code Build logs: {logs_url}."
            )

        logger.info(
            "The Docker image has been built successfully. More information "
            "can be found in the Code Build logs: `%s`.",
            logs_url,
        )

        return alt_image_name
|
@@ -793,7 +793,7 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
|
|
793
793
|
"set or set to 0. The accelerator type will be ignored. "
|
794
794
|
"To fix this warning, either remove the specified "
|
795
795
|
"accelerator type or set the `gpu_count` using the "
|
796
|
-
"ResourceSettings (https://docs.zenml.io/how-to/training-with-gpus
|
796
|
+
"ResourceSettings (https://docs.zenml.io/how-to/advanced-topics/training-with-gpus)."
|
797
797
|
)
|
798
798
|
|
799
799
|
return dynamic_component
|
@@ -25,6 +25,7 @@ from zenml.image_builders import BaseImageBuilder
|
|
25
25
|
from zenml.integrations.kaniko.flavors import KanikoImageBuilderConfig
|
26
26
|
from zenml.logger import get_logger
|
27
27
|
from zenml.stack import StackValidator
|
28
|
+
from zenml.utils.archivable import ArchiveType
|
28
29
|
|
29
30
|
if TYPE_CHECKING:
|
30
31
|
from zenml.container_registries import BaseContainerRegistry
|
@@ -295,7 +296,7 @@ class KanikoImageBuilder(BaseImageBuilder):
|
|
295
296
|
logger.debug("Writing build context to process stdin.")
|
296
297
|
assert process.stdin
|
297
298
|
with process.stdin as _, tempfile.TemporaryFile(mode="w+b") as f:
|
298
|
-
build_context.write_archive(f,
|
299
|
+
build_context.write_archive(f, archive_type=ArchiveType.TAR_GZ)
|
299
300
|
while True:
|
300
301
|
data = f.read(1024)
|
301
302
|
if not data:
|
@@ -134,6 +134,17 @@ class KubernetesOrchestratorConfig(
|
|
134
134
|
"""
|
135
135
|
return True
|
136
136
|
|
137
|
+
@property
def supports_client_side_caching(self) -> bool:
    """Tell callers whether client-side caching may be used.

    Returns:
        Always False for this orchestrator configuration.
    """
    # Step pods are launched from a pipeline pod by the Kubernetes
    # orchestrator; client-side caching does not support that yet.
    return False
|
147
|
+
|
137
148
|
|
138
149
|
class KubernetesOrchestratorFlavor(BaseOrchestratorFlavor):
|
139
150
|
"""Kubernetes orchestrator flavor."""
|
@@ -94,6 +94,17 @@ class LightningOrchestratorConfig(
|
|
94
94
|
"""
|
95
95
|
return False
|
96
96
|
|
97
|
+
@property
def supports_client_side_caching(self) -> bool:
    """Tell callers whether client-side caching may be used.

    Returns:
        Always False for this orchestrator configuration.
    """
    # Step studios are launched from a pipeline studio by the Lightning
    # orchestrator; client-side caching does not support that yet.
    return False
|
107
|
+
|
97
108
|
|
98
109
|
class LightningOrchestratorFlavor(BaseOrchestratorFlavor):
|
99
110
|
"""Lightning orchestrator flavor."""
|
@@ -77,10 +77,12 @@ class NeptuneExperimentTracker(BaseExperimentTracker):
|
|
77
77
|
NeptuneExperimentTrackerSettings, self.get_settings(info)
|
78
78
|
)
|
79
79
|
|
80
|
-
self.run_state.
|
81
|
-
|
82
|
-
|
83
|
-
|
80
|
+
self.run_state.initialize(
|
81
|
+
project=self.config.project,
|
82
|
+
token=self.config.api_token,
|
83
|
+
run_name=info.run_name,
|
84
|
+
tags=list(settings.tags),
|
85
|
+
)
|
84
86
|
|
85
87
|
def get_step_run_metadata(
|
86
88
|
self, info: "StepRunInfo"
|
@@ -107,4 +109,4 @@ class NeptuneExperimentTracker(BaseExperimentTracker):
|
|
107
109
|
"""
|
108
110
|
self.run_state.active_run.sync()
|
109
111
|
self.run_state.active_run.stop()
|
110
|
-
self.run_state.
|
112
|
+
self.run_state.reset()
|
@@ -20,7 +20,6 @@ import neptune
|
|
20
20
|
|
21
21
|
import zenml
|
22
22
|
from zenml.client import Client
|
23
|
-
from zenml.integrations.constants import NEPTUNE
|
24
23
|
from zenml.utils.singleton import SingletonMetaClass
|
25
24
|
|
26
25
|
if TYPE_CHECKING:
|
@@ -29,20 +28,38 @@ if TYPE_CHECKING:
|
|
29
28
|
_INTEGRATION_VERSION_KEY = "source_code/integrations/zenml"
|
30
29
|
|
31
30
|
|
32
|
-
class InvalidExperimentTrackerSelected(Exception):
|
33
|
-
"""Raised if a Neptune run is fetched while using a different experiment tracker."""
|
34
|
-
|
35
|
-
|
36
31
|
class RunProvider(metaclass=SingletonMetaClass):
|
37
32
|
"""Singleton object used to store and persist a Neptune run state across the pipeline."""
|
38
33
|
|
39
34
|
def __init__(self) -> None:
|
40
35
|
"""Initialize RunProvider. Called with no arguments."""
|
41
36
|
self._active_run: Optional["Run"] = None
|
42
|
-
self._project: Optional[str]
|
43
|
-
self._run_name: Optional[str]
|
44
|
-
self._token: Optional[str]
|
45
|
-
self._tags: Optional[List[str]]
|
37
|
+
self._project: Optional[str] = None
|
38
|
+
self._run_name: Optional[str] = None
|
39
|
+
self._token: Optional[str] = None
|
40
|
+
self._tags: Optional[List[str]] = None
|
41
|
+
self._initialized = False
|
42
|
+
|
43
|
+
def initialize(
    self,
    project: Optional[str] = None,
    token: Optional[str] = None,
    run_name: Optional[str] = None,
    tags: Optional[List[str]] = None,
) -> None:
    """Store the run parameters and mark the run state as initialized.

    Args:
        project: The neptune project.
        token: The neptune token.
        run_name: The neptune run name.
        tags: Tags for the neptune run.
    """
    self._project, self._token = project, token
    self._run_name, self._tags = run_name, tags
    # Flag that the values above were set explicitly.
    self._initialized = True
|
46
63
|
|
47
64
|
@property
|
48
65
|
def project(self) -> Optional[Any]:
|
@@ -53,15 +70,6 @@ class RunProvider(metaclass=SingletonMetaClass):
|
|
53
70
|
"""
|
54
71
|
return self._project
|
55
72
|
|
56
|
-
@project.setter
|
57
|
-
def project(self, project: str) -> None:
|
58
|
-
"""Setter for project name.
|
59
|
-
|
60
|
-
Args:
|
61
|
-
project: Neptune project name
|
62
|
-
"""
|
63
|
-
self._project = project
|
64
|
-
|
65
73
|
@property
|
66
74
|
def token(self) -> Optional[Any]:
|
67
75
|
"""Getter for API token.
|
@@ -71,15 +79,6 @@ class RunProvider(metaclass=SingletonMetaClass):
|
|
71
79
|
"""
|
72
80
|
return self._token
|
73
81
|
|
74
|
-
@token.setter
|
75
|
-
def token(self, token: str) -> None:
|
76
|
-
"""Setter for API token.
|
77
|
-
|
78
|
-
Args:
|
79
|
-
token: Neptune API token
|
80
|
-
"""
|
81
|
-
self._token = token
|
82
|
-
|
83
82
|
@property
|
84
83
|
def run_name(self) -> Optional[Any]:
|
85
84
|
"""Getter for run name.
|
@@ -89,15 +88,6 @@ class RunProvider(metaclass=SingletonMetaClass):
|
|
89
88
|
"""
|
90
89
|
return self._run_name
|
91
90
|
|
92
|
-
@run_name.setter
|
93
|
-
def run_name(self, run_name: str) -> None:
|
94
|
-
"""Setter for run name.
|
95
|
-
|
96
|
-
Args:
|
97
|
-
run_name: name of the pipeline run
|
98
|
-
"""
|
99
|
-
self._run_name = run_name
|
100
|
-
|
101
91
|
@property
|
102
92
|
def tags(self) -> Optional[Any]:
|
103
93
|
"""Getter for run tags.
|
@@ -107,14 +97,14 @@ class RunProvider(metaclass=SingletonMetaClass):
|
|
107
97
|
"""
|
108
98
|
return self._tags
|
109
99
|
|
110
|
-
@
|
111
|
-
def
|
112
|
-
"""
|
100
|
+
@property
|
101
|
+
def initialized(self) -> bool:
|
102
|
+
"""If the run state is initialized.
|
113
103
|
|
114
|
-
|
115
|
-
|
104
|
+
Returns:
|
105
|
+
If the run state is initialized.
|
116
106
|
"""
|
117
|
-
self.
|
107
|
+
return self._initialized
|
118
108
|
|
119
109
|
@property
|
120
110
|
def active_run(self) -> "Run":
|
@@ -137,9 +127,14 @@ class RunProvider(metaclass=SingletonMetaClass):
|
|
137
127
|
self._active_run = run
|
138
128
|
return self._active_run
|
139
129
|
|
140
|
-
def
|
141
|
-
"""
|
130
|
+
def reset(self) -> None:
    """Clear the active run and all stored run parameters."""
    self._active_run = self._project = self._run_name = None
    self._token = self._tags = None
    # The state has to be initialized again before its next use.
    self._initialized = False
|
143
138
|
|
144
139
|
|
145
140
|
def get_neptune_run() -> "Run":
|
@@ -149,14 +144,35 @@ def get_neptune_run() -> "Run":
|
|
149
144
|
Neptune run object
|
150
145
|
|
151
146
|
Raises:
|
152
|
-
|
147
|
+
RuntimeError: When unable to fetch the active neptune run.
|
153
148
|
"""
|
154
|
-
|
155
|
-
|
156
|
-
if experiment_tracker.flavor == NEPTUNE: # type: ignore
|
157
|
-
return experiment_tracker.run_state.active_run # type: ignore
|
158
|
-
raise InvalidExperimentTrackerSelected(
|
159
|
-
"Fetching a Neptune run works only with the 'neptune' flavor of "
|
160
|
-
"the experiment tracker. The flavor currently selected is %s"
|
161
|
-
% experiment_tracker.flavor # type: ignore
|
149
|
+
from zenml.integrations.neptune.experiment_trackers import (
|
150
|
+
NeptuneExperimentTracker,
|
162
151
|
)
|
152
|
+
|
153
|
+
experiment_tracker = Client().active_stack.experiment_tracker
|
154
|
+
|
155
|
+
if not experiment_tracker:
|
156
|
+
raise RuntimeError(
|
157
|
+
"Unable to get neptune run: Missing experiment tracker in the "
|
158
|
+
"active stack."
|
159
|
+
)
|
160
|
+
|
161
|
+
if not isinstance(experiment_tracker, NeptuneExperimentTracker):
|
162
|
+
raise RuntimeError(
|
163
|
+
"Unable to get neptune run: Experiment tracker in the active "
|
164
|
+
f"stack ({experiment_tracker.flavor}) is not a neptune experiment "
|
165
|
+
"tracker."
|
166
|
+
)
|
167
|
+
|
168
|
+
run_state = experiment_tracker.run_state
|
169
|
+
if not run_state.initialized:
|
170
|
+
raise RuntimeError(
|
171
|
+
"Unable to get neptune run: The experiment tracker has not been "
|
172
|
+
"initialized. To solve this, make sure you use the experiment "
|
173
|
+
"tracker in your step. See "
|
174
|
+
"https://docs.zenml.io/stack-components/experiment-trackers/neptune#how-do-you-use-it "
|
175
|
+
"for more information."
|
176
|
+
)
|
177
|
+
|
178
|
+
return experiment_tracker.run_state.active_run
|
zenml/integrations/registry.py
CHANGED
@@ -111,7 +111,7 @@ class IntegrationRegistry(object):
|
|
111
111
|
)
|
112
112
|
else:
|
113
113
|
raise KeyError(
|
114
|
-
f"
|
114
|
+
f"Integration {integration_name} does not exist. "
|
115
115
|
f"Currently the following integrations are implemented. "
|
116
116
|
f"{self.list_integration_names}"
|
117
117
|
)
|
@@ -148,7 +148,7 @@ class IntegrationRegistry(object):
|
|
148
148
|
].get_uninstall_requirements(target_os=target_os)
|
149
149
|
else:
|
150
150
|
raise KeyError(
|
151
|
-
f"
|
151
|
+
f"Integration {integration_name} does not exist. "
|
152
152
|
f"Currently the following integrations are implemented. "
|
153
153
|
f"{self.list_integration_names}"
|
154
154
|
)
|
@@ -144,3 +144,15 @@ class SkypilotBaseOrchestratorConfig(
|
|
144
144
|
True if this config is for a local component, False otherwise.
|
145
145
|
"""
|
146
146
|
return False
|
147
|
+
|
148
|
+
@property
def supports_client_side_caching(self) -> bool:
    """Tell callers whether client-side caching may be used.

    Returns:
        Always False for this orchestrator configuration.
    """
    # With Skypilot the whole pipeline runs in one VM, or further VMs are
    # started from the root VM. Client-side caching currently supports
    # neither of those setups.
    return False
|
@@ -429,7 +429,7 @@ class BuiltInContainerMaterializer(BaseMaterializer):
|
|
429
429
|
# doesn't work for non-serializable types as they
|
430
430
|
# are saved as list of lists in different files
|
431
431
|
if _is_serializable(data):
|
432
|
-
return {self.data_path: VisualizationType.JSON}
|
432
|
+
return {self.data_path.replace("\\", "/"): VisualizationType.JSON}
|
433
433
|
return {}
|
434
434
|
|
435
435
|
def extract_metadata(self, data: Any) -> Dict[str, "MetadataType"]:
|
@@ -61,7 +61,7 @@ class BaseOrchestratorConfig(StackComponentConfig):
|
|
61
61
|
"The 'custom_docker_base_image_name' field has been "
|
62
62
|
"deprecated. To use a custom base container image with your "
|
63
63
|
"orchestrators, please use the DockerSettings in your "
|
64
|
-
"pipeline (see https://docs.zenml.io/how-to/customize-docker-builds)."
|
64
|
+
"pipeline (see https://docs.zenml.io/how-to/infrastructure-deployment/customize-docker-builds)."
|
65
65
|
)
|
66
66
|
|
67
67
|
return data
|
@@ -84,6 +84,15 @@ class BaseOrchestratorConfig(StackComponentConfig):
|
|
84
84
|
"""
|
85
85
|
return False
|
86
86
|
|
87
|
+
@property
def supports_client_side_caching(self) -> bool:
    """Tell callers whether client-side caching may be used.

    Returns:
        True in this base configuration.
    """
    return True
|
95
|
+
|
87
96
|
|
88
97
|
class BaseOrchestrator(StackComponent, ABC):
|
89
98
|
"""Base class for all orchestrators.
|
@@ -205,6 +214,7 @@ class BaseOrchestrator(StackComponent, ABC):
|
|
205
214
|
|
206
215
|
if (
|
207
216
|
placeholder_run
|
217
|
+
and self.config.supports_client_side_caching
|
208
218
|
and not deployment.schedule
|
209
219
|
and not prevent_client_side_caching
|
210
220
|
):
|
@@ -232,6 +242,8 @@ class BaseOrchestrator(StackComponent, ABC):
|
|
232
242
|
self._cleanup_run()
|
233
243
|
logger.info("All steps of the pipeline run were cached.")
|
234
244
|
return
|
245
|
+
else:
|
246
|
+
logger.debug("Skipping client-side caching.")
|
235
247
|
|
236
248
|
try:
|
237
249
|
if metadata_iterator := self.prepare_or_run_pipeline(
|