zenml-nightly 0.66.0.dev20240923__py3-none-any.whl → 0.66.0.dev20240925__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/cli/__init__.py +7 -0
- zenml/cli/base.py +2 -2
- zenml/cli/pipeline.py +21 -0
- zenml/cli/utils.py +14 -11
- zenml/client.py +68 -3
- zenml/config/step_configurations.py +0 -5
- zenml/constants.py +3 -0
- zenml/enums.py +2 -0
- zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +76 -7
- zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +370 -115
- zenml/integrations/azure/orchestrators/azureml_orchestrator.py +157 -4
- zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +207 -18
- zenml/integrations/lightning/__init__.py +1 -1
- zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py +9 -0
- zenml/integrations/lightning/orchestrators/lightning_orchestrator.py +18 -17
- zenml/integrations/lightning/orchestrators/lightning_orchestrator_entrypoint.py +2 -6
- zenml/integrations/mlflow/steps/mlflow_registry.py +2 -0
- zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py +1 -1
- zenml/models/v2/base/filter.py +315 -149
- zenml/models/v2/base/scoped.py +5 -2
- zenml/models/v2/core/artifact_version.py +69 -8
- zenml/models/v2/core/model.py +43 -6
- zenml/models/v2/core/model_version.py +49 -1
- zenml/models/v2/core/model_version_artifact.py +18 -3
- zenml/models/v2/core/model_version_pipeline_run.py +18 -4
- zenml/models/v2/core/pipeline.py +108 -1
- zenml/models/v2/core/pipeline_run.py +172 -21
- zenml/models/v2/core/run_template.py +53 -1
- zenml/models/v2/core/stack.py +33 -5
- zenml/models/v2/core/step_run.py +7 -0
- zenml/new/pipelines/pipeline.py +4 -0
- zenml/new/pipelines/run_utils.py +4 -1
- zenml/orchestrators/base_orchestrator.py +41 -12
- zenml/stack/stack.py +11 -2
- zenml/utils/env_utils.py +54 -1
- zenml/utils/string_utils.py +50 -0
- zenml/zen_server/cloud_utils.py +33 -8
- zenml/zen_server/routers/runs_endpoints.py +89 -3
- zenml/zen_stores/sql_zen_store.py +1 -0
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240925.dist-info}/METADATA +8 -1
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240925.dist-info}/RECORD +45 -45
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240925.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240925.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240925.dist-info}/entry_points.txt +0 -0
zenml/utils/string_utils.py
CHANGED
@@ -15,13 +15,17 @@
|
|
15
15
|
|
16
16
|
import base64
|
17
17
|
import datetime
|
18
|
+
import functools
|
18
19
|
import random
|
19
20
|
import string
|
21
|
+
from typing import Any, Callable, Dict, TypeVar, cast
|
20
22
|
|
21
23
|
from pydantic import BaseModel
|
22
24
|
|
23
25
|
from zenml.constants import BANNED_NAME_CHARACTERS
|
24
26
|
|
27
|
+
# Type variable used by `substitute_string` so that its return type matches
# the type of the value that was passed in.
V = TypeVar("V", bound=Any)
|
28
|
+
|
25
29
|
|
26
30
|
def get_human_readable_time(seconds: float) -> str:
|
27
31
|
"""Convert seconds into a human-readable string.
|
@@ -167,3 +171,49 @@ def format_name_template(
|
|
167
171
|
datetime.datetime.now(datetime.timezone.utc).strftime("%H_%M_%S_%f"),
|
168
172
|
)
|
169
173
|
return name_template.format(**kwargs)
|
174
|
+
|
175
|
+
|
176
|
+
def substitute_string(value: V, substitution_func: Callable[[str], str]) -> V:
    """Recursively substitute strings in objects.

    Args:
        value: An object in which the strings should be recursively
            substituted. This can be a pydantic model, dict, set, frozenset,
            list, tuple (including named tuples) or any primitive type.
        substitution_func: The function that does the actual string
            substitution.

    Returns:
        The object with the substitution function applied to all string
        values. Container types are rebuilt with the same type as the input,
        except for dicts, which are always returned as plain `dict`s.
    """
    substitute_ = functools.partial(
        substitute_string, substitution_func=substitution_func
    )

    if isinstance(value, BaseModel):
        model_values = {}

        for k, v in value:
            new_value = substitute_(v)

            if k not in value.model_fields_set and new_value == v:
                # This is a default value on the model and was not set
                # explicitly. In this case, we don't include it in the model
                # values to keep the `exclude_unset` behavior the same
                continue

            model_values[k] = new_value

        return cast(V, type(value).model_validate(model_values))
    elif isinstance(value, dict):
        return cast(
            V, {substitute_(k): substitute_(v) for k, v in value.items()}
        )
    elif isinstance(value, tuple) and hasattr(value, "_fields"):
        # Named tuples take their fields as positional arguments, so they
        # can't be rebuilt from a single iterable like a plain tuple can.
        return cast(V, type(value)(*(substitute_(v) for v in value)))
    elif isinstance(value, (list, set, frozenset, tuple)):
        return cast(V, type(value)(substitute_(v) for v in value))
    elif isinstance(value, str):
        return cast(V, substitution_func(value))

    return value
|
zenml/zen_server/cloud_utils.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
"""Utils concerning anything concerning the cloud control plane backend."""
|
2
2
|
|
3
3
|
import os
|
4
|
+
from datetime import datetime, timedelta, timezone
|
4
5
|
from typing import Any, Dict, Optional
|
5
6
|
|
6
7
|
import requests
|
@@ -19,11 +20,9 @@ class ZenMLCloudConfiguration(BaseModel):
|
|
19
20
|
"""ZenML Pro RBAC configuration."""
|
20
21
|
|
21
22
|
api_url: str
|
22
|
-
|
23
23
|
oauth2_client_id: str
|
24
24
|
oauth2_client_secret: str
|
25
25
|
oauth2_audience: str
|
26
|
-
auth0_domain: str
|
27
26
|
|
28
27
|
@field_validator("api_url")
|
29
28
|
@classmethod
|
@@ -68,6 +67,8 @@ class ZenMLCloudConnection:
|
|
68
67
|
"""Initialize the RBAC component."""
|
69
68
|
self._config = ZenMLCloudConfiguration.from_environment()
|
70
69
|
self._session: Optional[requests.Session] = None
|
70
|
+
self._token: Optional[str] = None
|
71
|
+
self._token_expires_at: Optional[datetime] = None
|
71
72
|
|
72
73
|
def get(
|
73
74
|
self, endpoint: str, params: Optional[Dict[str, Any]]
|
@@ -91,7 +92,8 @@ class ZenMLCloudConnection:
|
|
91
92
|
|
92
93
|
response = self.session.get(url=url, params=params, timeout=7)
|
93
94
|
if response.status_code == 401:
|
94
|
-
#
|
95
|
+
# If we get an Unauthorized error from the API server, we refresh the
|
96
|
+
# auth token and try again
|
95
97
|
self._clear_session()
|
96
98
|
response = self.session.get(url=url, params=params, timeout=7)
|
97
99
|
|
@@ -186,6 +188,8 @@ class ZenMLCloudConnection:
|
|
186
188
|
def _clear_session(self) -> None:
    """Forget the cached session and the cached auth token.

    Resetting the token and its expiry timestamp means `_fetch_auth_token`
    will request a fresh token instead of returning the cached one.
    """
    self._session = self._token = self._token_expires_at = None
|
189
193
|
|
190
194
|
def _fetch_auth_token(self) -> str:
|
191
195
|
"""Fetch an auth token for the Cloud API from auth0.
|
@@ -196,8 +200,16 @@ class ZenMLCloudConnection:
|
|
196
200
|
Returns:
|
197
201
|
Auth token.
|
198
202
|
"""
|
203
|
+
if (
|
204
|
+
self._token is not None
|
205
|
+
and self._token_expires_at is not None
|
206
|
+
and datetime.now(timezone.utc) + timedelta(minutes=5)
|
207
|
+
< self._token_expires_at
|
208
|
+
):
|
209
|
+
return self._token
|
210
|
+
|
199
211
|
# Get an auth token from auth0
|
200
|
-
|
212
|
+
login_url = f"{self._config.api_url}/auth/login"
|
201
213
|
headers = {"content-type": "application/x-www-form-urlencoded"}
|
202
214
|
payload = {
|
203
215
|
"client_id": self._config.oauth2_client_id,
|
@@ -207,18 +219,31 @@ class ZenMLCloudConnection:
|
|
207
219
|
}
|
208
220
|
try:
|
209
221
|
response = requests.post(
|
210
|
-
|
222
|
+
login_url, headers=headers, data=payload, timeout=7
|
211
223
|
)
|
212
224
|
response.raise_for_status()
|
213
225
|
except Exception as e:
|
214
226
|
raise RuntimeError(f"Error fetching auth token from auth0: {e}")
|
215
227
|
|
216
|
-
|
228
|
+
json_response = response.json()
|
229
|
+
access_token = json_response.get("access_token", "")
|
230
|
+
expires_in = json_response.get("expires_in", 0)
|
217
231
|
|
218
|
-
if
|
232
|
+
if (
|
233
|
+
not access_token
|
234
|
+
or not isinstance(access_token, str)
|
235
|
+
or not expires_in
|
236
|
+
or not isinstance(expires_in, int)
|
237
|
+
):
|
219
238
|
raise RuntimeError("Could not fetch auth token from auth0.")
|
220
239
|
|
221
|
-
|
240
|
+
self._token = access_token
|
241
|
+
self._token_expires_at = datetime.now(timezone.utc) + timedelta(
|
242
|
+
seconds=expires_in
|
243
|
+
)
|
244
|
+
|
245
|
+
assert self._token is not None
|
246
|
+
return self._token
|
222
247
|
|
223
248
|
|
224
249
|
def cloud_connection() -> ZenMLCloudConnection:
|
@@ -22,13 +22,15 @@ from zenml.constants import (
|
|
22
22
|
API,
|
23
23
|
GRAPH,
|
24
24
|
PIPELINE_CONFIGURATION,
|
25
|
+
REFRESH,
|
25
26
|
RUNS,
|
26
27
|
STATUS,
|
27
28
|
STEPS,
|
28
29
|
VERSION_1,
|
29
30
|
)
|
30
|
-
from zenml.enums import ExecutionStatus
|
31
|
+
from zenml.enums import ExecutionStatus, StackComponentType
|
31
32
|
from zenml.lineage_graph.lineage_graph import LineageGraph
|
33
|
+
from zenml.logger import get_logger
|
32
34
|
from zenml.models import (
|
33
35
|
Page,
|
34
36
|
PipelineRunFilter,
|
@@ -45,7 +47,8 @@ from zenml.zen_server.rbac.endpoint_utils import (
|
|
45
47
|
verify_permissions_and_list_entities,
|
46
48
|
verify_permissions_and_update_entity,
|
47
49
|
)
|
48
|
-
from zenml.zen_server.rbac.models import ResourceType
|
50
|
+
from zenml.zen_server.rbac.models import Action, ResourceType
|
51
|
+
from zenml.zen_server.rbac.utils import verify_permission_for_model
|
49
52
|
from zenml.zen_server.utils import (
|
50
53
|
handle_exceptions,
|
51
54
|
make_dependable,
|
@@ -59,6 +62,9 @@ router = APIRouter(
|
|
59
62
|
)
|
60
63
|
|
61
64
|
|
65
|
+
# Module-level logger; `get_run` uses it to report failed status refreshes.
logger = get_logger(__name__)
|
66
|
+
|
67
|
+
|
62
68
|
@router.get(
|
63
69
|
"",
|
64
70
|
response_model=Page[PipelineRunResponse],
|
@@ -99,6 +105,7 @@ def list_runs(
|
|
99
105
|
def get_run(
    run_id: UUID,
    hydrate: bool = True,
    refresh_status: bool = False,
    _: AuthContext = Security(authorize),
) -> PipelineRunResponse:
    """Get a specific pipeline run using its ID.

    Args:
        run_id: ID of the pipeline run to get.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.
        refresh_status: Flag deciding whether we should try to refresh
            the status of the pipeline run using its orchestrator. This is
            best-effort: if the run's stack or orchestrator was deleted, the
            orchestrator permission check fails, or the refresh itself
            raises, a warning is logged and the run is returned as it was
            fetched from the store.

    Returns:
        The pipeline run.
    """
    run = verify_permissions_and_get_entity(
        id=run_id, get_method=zen_store().get_run, hydrate=hydrate
    )
    if not refresh_status:
        return run

    try:
        # Refreshing goes through the orchestrator the run was executed
        # with, so the caller additionally needs read access to it.
        if run.stack is None:
            raise RuntimeError(
                f"The stack, the run '{run.id}' was executed on, is deleted."
            )

        orchestrators = run.stack.components.get(
            StackComponentType.ORCHESTRATOR, []
        )
        if not orchestrators:
            raise RuntimeError(
                f"The orchestrator, the run '{run.id}' was executed "
                "with, is deleted."
            )

        verify_permission_for_model(model=orchestrators[0], action=Action.READ)

        run = run.refresh_run_status()
    except Exception as e:
        # Deliberately best-effort: a failed refresh must not prevent the
        # caller from reading the run itself.
        logger.warning(
            "An error occurred while refreshing the status of the "
            "pipeline run: %s",
            e,
        )
    return run
|
117
158
|
|
118
159
|
|
119
160
|
@router.put(
|
@@ -267,3 +308,48 @@ def get_run_status(
|
|
267
308
|
id=run_id, get_method=zen_store().get_run, hydrate=False
|
268
309
|
)
|
269
310
|
return run.status
|
311
|
+
|
312
|
+
|
313
|
+
@router.get(
    "/{run_id}" + REFRESH,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def refresh_run_status(
    run_id: UUID,
    _: AuthContext = Security(authorize),
) -> None:
    """Ask the run's orchestrator to update the status of a pipeline run.

    Args:
        run_id: ID of the pipeline run to refresh.

    Raises:
        RuntimeError: If the stack or the orchestrator of the run is deleted.
    """
    # Fetching through the RBAC helper also verifies access to the run.
    pipeline_run = verify_permissions_and_get_entity(
        id=run_id,
        get_method=zen_store().get_run,
        hydrate=True,
    )

    if pipeline_run.stack is None:
        raise RuntimeError(
            f"The stack, the run '{pipeline_run.id}' was executed on, "
            "is deleted."
        )

    orchestrators = pipeline_run.stack.components.get(
        StackComponentType.ORCHESTRATOR, []
    )
    if not orchestrators:
        raise RuntimeError(
            f"The orchestrator, the run '{pipeline_run.id}' was executed "
            "with, is deleted."
        )

    # The caller must also be allowed to read the orchestrator that will be
    # queried for the new status.
    verify_permission_for_model(model=orchestrators[0], action=Action.READ)
    pipeline_run.refresh_run_status()
|
@@ -973,6 +973,7 @@ class SqlZenStore(BaseZenStore):
|
|
973
973
|
ValueError: if the filtered page number is out of bounds.
|
974
974
|
RuntimeError: if the schema does not have a `to_model` method.
|
975
975
|
"""
|
976
|
+
query = query.distinct()
|
976
977
|
query = filter_model.apply_filter(query=query, table=table)
|
977
978
|
query = query.distinct()
|
978
979
|
|
{zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240925.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: zenml-nightly
|
3
|
-
Version: 0.66.0.
|
3
|
+
Version: 0.66.0.dev20240925
|
4
4
|
Summary: ZenML: Write production-ready ML code.
|
5
5
|
Home-page: https://zenml.io
|
6
6
|
License: Apache-2.0
|
@@ -24,6 +24,7 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
24
24
|
Classifier: Topic :: System :: Distributed Computing
|
25
25
|
Classifier: Typing :: Typed
|
26
26
|
Provides-Extra: adlfs
|
27
|
+
Provides-Extra: azureml
|
27
28
|
Provides-Extra: connectors-aws
|
28
29
|
Provides-Extra: connectors-azure
|
29
30
|
Provides-Extra: connectors-gcp
|
@@ -32,6 +33,7 @@ Provides-Extra: dev
|
|
32
33
|
Provides-Extra: gcsfs
|
33
34
|
Provides-Extra: mlstacks
|
34
35
|
Provides-Extra: s3fs
|
36
|
+
Provides-Extra: sagemaker
|
35
37
|
Provides-Extra: secrets-aws
|
36
38
|
Provides-Extra: secrets-azure
|
37
39
|
Provides-Extra: secrets-gcp
|
@@ -39,10 +41,12 @@ Provides-Extra: secrets-hashicorp
|
|
39
41
|
Provides-Extra: server
|
40
42
|
Provides-Extra: templates
|
41
43
|
Provides-Extra: terraform
|
44
|
+
Provides-Extra: vertex
|
42
45
|
Requires-Dist: Jinja2 ; extra == "server"
|
43
46
|
Requires-Dist: adlfs (>=2021.10.0) ; extra == "adlfs"
|
44
47
|
Requires-Dist: alembic (>=1.8.1,<1.9.0)
|
45
48
|
Requires-Dist: aws-profile-manager (>=0.5.0) ; extra == "connectors-aws"
|
49
|
+
Requires-Dist: azure-ai-ml (==1.18.0) ; extra == "azureml"
|
46
50
|
Requires-Dist: azure-identity (>=1.4.0) ; extra == "secrets-azure" or extra == "connectors-azure"
|
47
51
|
Requires-Dist: azure-keyvault-secrets (>=4.0.0) ; extra == "secrets-azure"
|
48
52
|
Requires-Dist: azure-mgmt-containerregistry (>=10.0.0) ; extra == "connectors-azure"
|
@@ -63,6 +67,7 @@ Requires-Dist: docker (>=7.1.0,<7.2.0)
|
|
63
67
|
Requires-Dist: fastapi (>=0.100,<=0.110) ; extra == "server"
|
64
68
|
Requires-Dist: gcsfs (>=2022.11.0) ; extra == "gcsfs"
|
65
69
|
Requires-Dist: gitpython (>=3.1.18,<4.0.0)
|
70
|
+
Requires-Dist: google-cloud-aiplatform (>=1.34.0) ; extra == "vertex"
|
66
71
|
Requires-Dist: google-cloud-artifact-registry (>=1.11.3) ; extra == "connectors-gcp"
|
67
72
|
Requires-Dist: google-cloud-container (>=2.21.0) ; extra == "connectors-gcp"
|
68
73
|
Requires-Dist: google-cloud-secret-manager (>=2.12.5) ; extra == "secrets-gcp"
|
@@ -72,6 +77,7 @@ Requires-Dist: hypothesis (>=6.43.1,<7.0.0) ; extra == "dev"
|
|
72
77
|
Requires-Dist: importlib_metadata (<=7.0.0) ; python_version < "3.10"
|
73
78
|
Requires-Dist: ipinfo (>=4.4.3) ; extra == "server"
|
74
79
|
Requires-Dist: jinja2-time (>=0.2.0,<0.3.0) ; extra == "templates"
|
80
|
+
Requires-Dist: kfp (>=2.6.0) ; extra == "vertex"
|
75
81
|
Requires-Dist: kubernetes (>=18.20.0) ; extra == "connectors-kubernetes" or extra == "connectors-aws" or extra == "connectors-gcp" or extra == "connectors-azure"
|
76
82
|
Requires-Dist: maison (<2.0) ; extra == "dev"
|
77
83
|
Requires-Dist: mike (>=1.1.2,<2.0.0) ; extra == "dev"
|
@@ -105,6 +111,7 @@ Requires-Dist: pyyaml-include (<2.0) ; extra == "templates"
|
|
105
111
|
Requires-Dist: rich[jupyter] (>=12.0.0)
|
106
112
|
Requires-Dist: ruff (>=0.1.7) ; extra == "templates" or extra == "dev"
|
107
113
|
Requires-Dist: s3fs (>=2022.11.0) ; extra == "s3fs"
|
114
|
+
Requires-Dist: sagemaker (>=2.117.0) ; extra == "sagemaker"
|
108
115
|
Requires-Dist: secure (>=0.3.0,<0.4.0) ; extra == "server"
|
109
116
|
Requires-Dist: setuptools
|
110
117
|
Requires-Dist: sqlalchemy (>=2.0.0,<3.0.0)
|