zenml-nightly 0.66.0.dev20240923__py3-none-any.whl → 0.66.0.dev20240925__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. zenml/VERSION +1 -1
  2. zenml/cli/__init__.py +7 -0
  3. zenml/cli/base.py +2 -2
  4. zenml/cli/pipeline.py +21 -0
  5. zenml/cli/utils.py +14 -11
  6. zenml/client.py +68 -3
  7. zenml/config/step_configurations.py +0 -5
  8. zenml/constants.py +3 -0
  9. zenml/enums.py +2 -0
  10. zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +76 -7
  11. zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +370 -115
  12. zenml/integrations/azure/orchestrators/azureml_orchestrator.py +157 -4
  13. zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +207 -18
  14. zenml/integrations/lightning/__init__.py +1 -1
  15. zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py +9 -0
  16. zenml/integrations/lightning/orchestrators/lightning_orchestrator.py +18 -17
  17. zenml/integrations/lightning/orchestrators/lightning_orchestrator_entrypoint.py +2 -6
  18. zenml/integrations/mlflow/steps/mlflow_registry.py +2 -0
  19. zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py +1 -1
  20. zenml/models/v2/base/filter.py +315 -149
  21. zenml/models/v2/base/scoped.py +5 -2
  22. zenml/models/v2/core/artifact_version.py +69 -8
  23. zenml/models/v2/core/model.py +43 -6
  24. zenml/models/v2/core/model_version.py +49 -1
  25. zenml/models/v2/core/model_version_artifact.py +18 -3
  26. zenml/models/v2/core/model_version_pipeline_run.py +18 -4
  27. zenml/models/v2/core/pipeline.py +108 -1
  28. zenml/models/v2/core/pipeline_run.py +172 -21
  29. zenml/models/v2/core/run_template.py +53 -1
  30. zenml/models/v2/core/stack.py +33 -5
  31. zenml/models/v2/core/step_run.py +7 -0
  32. zenml/new/pipelines/pipeline.py +4 -0
  33. zenml/new/pipelines/run_utils.py +4 -1
  34. zenml/orchestrators/base_orchestrator.py +41 -12
  35. zenml/stack/stack.py +11 -2
  36. zenml/utils/env_utils.py +54 -1
  37. zenml/utils/string_utils.py +50 -0
  38. zenml/zen_server/cloud_utils.py +33 -8
  39. zenml/zen_server/routers/runs_endpoints.py +89 -3
  40. zenml/zen_stores/sql_zen_store.py +1 -0
  41. {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240925.dist-info}/METADATA +8 -1
  42. {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240925.dist-info}/RECORD +45 -45
  43. {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240925.dist-info}/LICENSE +0 -0
  44. {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240925.dist-info}/WHEEL +0 -0
  45. {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240925.dist-info}/entry_points.txt +0 -0
@@ -15,13 +15,17 @@
15
15
 
16
16
  import base64
17
17
  import datetime
18
+ import functools
18
19
  import random
19
20
  import string
21
+ from typing import Any, Callable, Dict, TypeVar, cast
20
22
 
21
23
  from pydantic import BaseModel
22
24
 
23
25
  from zenml.constants import BANNED_NAME_CHARACTERS
24
26
 
27
+ V = TypeVar("V", bound=Any)
28
+
25
29
 
26
30
  def get_human_readable_time(seconds: float) -> str:
27
31
  """Convert seconds into a human-readable string.
@@ -167,3 +171,49 @@ def format_name_template(
167
171
  datetime.datetime.now(datetime.timezone.utc).strftime("%H_%M_%S_%f"),
168
172
  )
169
173
  return name_template.format(**kwargs)
174
+
175
+
176
+ def substitute_string(value: V, substitution_func: Callable[[str], str]) -> V:
177
+ """Recursively substitute strings in objects.
178
+
179
+ Args:
180
+ value: An object in which the strings should be recursively substituted.
181
+ This can be a pydantic model, dict, set, list, tuple or any
182
+ primitive type.
183
+ substitution_func: The function that does the actual string
184
+ substitution.
185
+
186
+ Returns:
187
+ The object with the substitution function applied to all string values.
188
+ """
189
+ substitute_ = functools.partial(
190
+ substitute_string, substitution_func=substitution_func
191
+ )
192
+
193
+ if isinstance(value, BaseModel):
194
+ model_values = {}
195
+
196
+ for k, v in value.__iter__():
197
+ new_value = substitute_(v)
198
+
199
+ if k not in value.model_fields_set and new_value == getattr(
200
+ value, k
201
+ ):
202
+ # This is a default value on the model and was not set
203
+ # explicitly. In this case, we don't include it in the model
204
+ # values to keep the `exclude_unset` behavior the same
205
+ continue
206
+
207
+ model_values[k] = new_value
208
+
209
+ return cast(V, type(value).model_validate(model_values))
210
+ elif isinstance(value, Dict):
211
+ return cast(
212
+ V, {substitute_(k): substitute_(v) for k, v in value.items()}
213
+ )
214
+ elif isinstance(value, (list, set, tuple)):
215
+ return cast(V, type(value)(substitute_(v) for v in value))
216
+ elif isinstance(value, str):
217
+ return cast(V, substitution_func(value))
218
+
219
+ return value
@@ -1,6 +1,7 @@
1
1
  """Utilities for anything concerning the cloud control plane backend."""
2
2
 
3
3
  import os
4
+ from datetime import datetime, timedelta, timezone
4
5
  from typing import Any, Dict, Optional
5
6
 
6
7
  import requests
@@ -19,11 +20,9 @@ class ZenMLCloudConfiguration(BaseModel):
19
20
  """ZenML Pro RBAC configuration."""
20
21
 
21
22
  api_url: str
22
-
23
23
  oauth2_client_id: str
24
24
  oauth2_client_secret: str
25
25
  oauth2_audience: str
26
- auth0_domain: str
27
26
 
28
27
  @field_validator("api_url")
29
28
  @classmethod
@@ -68,6 +67,8 @@ class ZenMLCloudConnection:
68
67
  """Initialize the RBAC component."""
69
68
  self._config = ZenMLCloudConfiguration.from_environment()
70
69
  self._session: Optional[requests.Session] = None
70
+ self._token: Optional[str] = None
71
+ self._token_expires_at: Optional[datetime] = None
71
72
 
72
73
  def get(
73
74
  self, endpoint: str, params: Optional[Dict[str, Any]]
@@ -91,7 +92,8 @@ class ZenMLCloudConnection:
91
92
 
92
93
  response = self.session.get(url=url, params=params, timeout=7)
93
94
  if response.status_code == 401:
94
- # Refresh the auth token and try again
95
+ # If we get an Unauthorized error from the API server, we refresh the
96
+ # auth token and try again
95
97
  self._clear_session()
96
98
  response = self.session.get(url=url, params=params, timeout=7)
97
99
 
@@ -186,6 +188,8 @@ class ZenMLCloudConnection:
186
188
  def _clear_session(self) -> None:
187
189
  """Clear the authentication session."""
188
190
  self._session = None
191
+ self._token = None
192
+ self._token_expires_at = None
189
193
 
190
194
  def _fetch_auth_token(self) -> str:
191
195
  """Fetch an auth token for the Cloud API from auth0.
@@ -196,8 +200,16 @@ class ZenMLCloudConnection:
196
200
  Returns:
197
201
  Auth token.
198
202
  """
203
+ if (
204
+ self._token is not None
205
+ and self._token_expires_at is not None
206
+ and datetime.now(timezone.utc) + timedelta(minutes=5)
207
+ < self._token_expires_at
208
+ ):
209
+ return self._token
210
+
199
211
  # Get an auth token from auth0
200
- auth0_url = f"https://{self._config.auth0_domain}/oauth/token"
212
+ login_url = f"{self._config.api_url}/auth/login"
201
213
  headers = {"content-type": "application/x-www-form-urlencoded"}
202
214
  payload = {
203
215
  "client_id": self._config.oauth2_client_id,
@@ -207,18 +219,31 @@ class ZenMLCloudConnection:
207
219
  }
208
220
  try:
209
221
  response = requests.post(
210
- auth0_url, headers=headers, data=payload, timeout=7
222
+ login_url, headers=headers, data=payload, timeout=7
211
223
  )
212
224
  response.raise_for_status()
213
225
  except Exception as e:
214
226
  raise RuntimeError(f"Error fetching auth token from auth0: {e}")
215
227
 
216
- access_token = response.json().get("access_token", "")
228
+ json_response = response.json()
229
+ access_token = json_response.get("access_token", "")
230
+ expires_in = json_response.get("expires_in", 0)
217
231
 
218
- if not access_token or not isinstance(access_token, str):
232
+ if (
233
+ not access_token
234
+ or not isinstance(access_token, str)
235
+ or not expires_in
236
+ or not isinstance(expires_in, int)
237
+ ):
219
238
  raise RuntimeError("Could not fetch auth token from auth0.")
220
239
 
221
- return str(access_token)
240
+ self._token = access_token
241
+ self._token_expires_at = datetime.now(timezone.utc) + timedelta(
242
+ seconds=expires_in
243
+ )
244
+
245
+ assert self._token is not None
246
+ return self._token
222
247
 
223
248
 
224
249
  def cloud_connection() -> ZenMLCloudConnection:
@@ -22,13 +22,15 @@ from zenml.constants import (
22
22
  API,
23
23
  GRAPH,
24
24
  PIPELINE_CONFIGURATION,
25
+ REFRESH,
25
26
  RUNS,
26
27
  STATUS,
27
28
  STEPS,
28
29
  VERSION_1,
29
30
  )
30
- from zenml.enums import ExecutionStatus
31
+ from zenml.enums import ExecutionStatus, StackComponentType
31
32
  from zenml.lineage_graph.lineage_graph import LineageGraph
33
+ from zenml.logger import get_logger
32
34
  from zenml.models import (
33
35
  Page,
34
36
  PipelineRunFilter,
@@ -45,7 +47,8 @@ from zenml.zen_server.rbac.endpoint_utils import (
45
47
  verify_permissions_and_list_entities,
46
48
  verify_permissions_and_update_entity,
47
49
  )
48
- from zenml.zen_server.rbac.models import ResourceType
50
+ from zenml.zen_server.rbac.models import Action, ResourceType
51
+ from zenml.zen_server.rbac.utils import verify_permission_for_model
49
52
  from zenml.zen_server.utils import (
50
53
  handle_exceptions,
51
54
  make_dependable,
@@ -59,6 +62,9 @@ router = APIRouter(
59
62
  )
60
63
 
61
64
 
65
+ logger = get_logger(__name__)
66
+
67
+
62
68
  @router.get(
63
69
  "",
64
70
  response_model=Page[PipelineRunResponse],
@@ -99,6 +105,7 @@ def list_runs(
99
105
  def get_run(
100
106
  run_id: UUID,
101
107
  hydrate: bool = True,
108
+ refresh_status: bool = False,
102
109
  _: AuthContext = Security(authorize),
103
110
  ) -> PipelineRunResponse:
104
111
  """Get a specific pipeline run using its ID.
@@ -107,13 +114,47 @@ def get_run(
107
114
  run_id: ID of the pipeline run to get.
108
115
  hydrate: Flag deciding whether to hydrate the output model(s)
109
116
  by including metadata fields in the response.
117
+ refresh_status: Flag deciding whether we should try to refresh
118
+ the status of the pipeline run using its orchestrator.
110
119
 
111
120
  Returns:
112
121
  The pipeline run.
122
+
123
+ Raises:
124
+ RuntimeError: If the stack or the orchestrator of the run is deleted.
113
125
  """
114
- return verify_permissions_and_get_entity(
126
+ run = verify_permissions_and_get_entity(
115
127
  id=run_id, get_method=zen_store().get_run, hydrate=hydrate
116
128
  )
129
+ if refresh_status:
130
+ try:
131
+ # Check the stack and its orchestrator
132
+ if run.stack is not None:
133
+ orchestrators = run.stack.components.get(
134
+ StackComponentType.ORCHESTRATOR, []
135
+ )
136
+ if orchestrators:
137
+ verify_permission_for_model(
138
+ model=orchestrators[0], action=Action.READ
139
+ )
140
+ else:
141
+ raise RuntimeError(
142
+ f"The orchestrator, the run '{run.id}' was executed "
143
+ "with, is deleted."
144
+ )
145
+ else:
146
+ raise RuntimeError(
147
+ f"The stack, the run '{run.id}' was executed on, is deleted."
148
+ )
149
+
150
+ run = run.refresh_run_status()
151
+
152
+ except Exception as e:
153
+ logger.warning(
154
+ "An error occurred while refreshing the status of the "
155
+ f"pipeline run: {e}"
156
+ )
157
+ return run
117
158
 
118
159
 
119
160
  @router.put(
@@ -267,3 +308,48 @@ def get_run_status(
267
308
  id=run_id, get_method=zen_store().get_run, hydrate=False
268
309
  )
269
310
  return run.status
311
+
312
+
313
+ @router.get(
314
+ "/{run_id}" + REFRESH,
315
+ responses={401: error_response, 404: error_response, 422: error_response},
316
+ )
317
+ @handle_exceptions
318
+ def refresh_run_status(
319
+ run_id: UUID,
320
+ _: AuthContext = Security(authorize),
321
+ ) -> None:
322
+ """Refreshes the status of a specific pipeline run.
323
+
324
+ Args:
325
+ run_id: ID of the pipeline run to refresh.
326
+
327
+ Raises:
328
+ RuntimeError: If the stack or the orchestrator of the run is deleted.
329
+ """
330
+ # Verify access to the run
331
+ run = verify_permissions_and_get_entity(
332
+ id=run_id,
333
+ get_method=zen_store().get_run,
334
+ hydrate=True,
335
+ )
336
+
337
+ # Check the stack and its orchestrator
338
+ if run.stack is not None:
339
+ orchestrators = run.stack.components.get(
340
+ StackComponentType.ORCHESTRATOR, []
341
+ )
342
+ if orchestrators:
343
+ verify_permission_for_model(
344
+ model=orchestrators[0], action=Action.READ
345
+ )
346
+ else:
347
+ raise RuntimeError(
348
+ f"The orchestrator, the run '{run.id}' was executed with, is "
349
+ "deleted."
350
+ )
351
+ else:
352
+ raise RuntimeError(
353
+ f"The stack, the run '{run.id}' was executed on, is deleted."
354
+ )
355
+ run.refresh_run_status()
@@ -973,6 +973,7 @@ class SqlZenStore(BaseZenStore):
973
973
  ValueError: if the filtered page number is out of bounds.
974
974
  RuntimeError: if the schema does not have a `to_model` method.
975
975
  """
976
+ query = query.distinct()
976
977
  query = filter_model.apply_filter(query=query, table=table)
977
978
  query = query.distinct()
978
979
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: zenml-nightly
3
- Version: 0.66.0.dev20240923
3
+ Version: 0.66.0.dev20240925
4
4
  Summary: ZenML: Write production-ready ML code.
5
5
  Home-page: https://zenml.io
6
6
  License: Apache-2.0
@@ -24,6 +24,7 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
24
24
  Classifier: Topic :: System :: Distributed Computing
25
25
  Classifier: Typing :: Typed
26
26
  Provides-Extra: adlfs
27
+ Provides-Extra: azureml
27
28
  Provides-Extra: connectors-aws
28
29
  Provides-Extra: connectors-azure
29
30
  Provides-Extra: connectors-gcp
@@ -32,6 +33,7 @@ Provides-Extra: dev
32
33
  Provides-Extra: gcsfs
33
34
  Provides-Extra: mlstacks
34
35
  Provides-Extra: s3fs
36
+ Provides-Extra: sagemaker
35
37
  Provides-Extra: secrets-aws
36
38
  Provides-Extra: secrets-azure
37
39
  Provides-Extra: secrets-gcp
@@ -39,10 +41,12 @@ Provides-Extra: secrets-hashicorp
39
41
  Provides-Extra: server
40
42
  Provides-Extra: templates
41
43
  Provides-Extra: terraform
44
+ Provides-Extra: vertex
42
45
  Requires-Dist: Jinja2 ; extra == "server"
43
46
  Requires-Dist: adlfs (>=2021.10.0) ; extra == "adlfs"
44
47
  Requires-Dist: alembic (>=1.8.1,<1.9.0)
45
48
  Requires-Dist: aws-profile-manager (>=0.5.0) ; extra == "connectors-aws"
49
+ Requires-Dist: azure-ai-ml (==1.18.0) ; extra == "azureml"
46
50
  Requires-Dist: azure-identity (>=1.4.0) ; extra == "secrets-azure" or extra == "connectors-azure"
47
51
  Requires-Dist: azure-keyvault-secrets (>=4.0.0) ; extra == "secrets-azure"
48
52
  Requires-Dist: azure-mgmt-containerregistry (>=10.0.0) ; extra == "connectors-azure"
@@ -63,6 +67,7 @@ Requires-Dist: docker (>=7.1.0,<7.2.0)
63
67
  Requires-Dist: fastapi (>=0.100,<=0.110) ; extra == "server"
64
68
  Requires-Dist: gcsfs (>=2022.11.0) ; extra == "gcsfs"
65
69
  Requires-Dist: gitpython (>=3.1.18,<4.0.0)
70
+ Requires-Dist: google-cloud-aiplatform (>=1.34.0) ; extra == "vertex"
66
71
  Requires-Dist: google-cloud-artifact-registry (>=1.11.3) ; extra == "connectors-gcp"
67
72
  Requires-Dist: google-cloud-container (>=2.21.0) ; extra == "connectors-gcp"
68
73
  Requires-Dist: google-cloud-secret-manager (>=2.12.5) ; extra == "secrets-gcp"
@@ -72,6 +77,7 @@ Requires-Dist: hypothesis (>=6.43.1,<7.0.0) ; extra == "dev"
72
77
  Requires-Dist: importlib_metadata (<=7.0.0) ; python_version < "3.10"
73
78
  Requires-Dist: ipinfo (>=4.4.3) ; extra == "server"
74
79
  Requires-Dist: jinja2-time (>=0.2.0,<0.3.0) ; extra == "templates"
80
+ Requires-Dist: kfp (>=2.6.0) ; extra == "vertex"
75
81
  Requires-Dist: kubernetes (>=18.20.0) ; extra == "connectors-kubernetes" or extra == "connectors-aws" or extra == "connectors-gcp" or extra == "connectors-azure"
76
82
  Requires-Dist: maison (<2.0) ; extra == "dev"
77
83
  Requires-Dist: mike (>=1.1.2,<2.0.0) ; extra == "dev"
@@ -105,6 +111,7 @@ Requires-Dist: pyyaml-include (<2.0) ; extra == "templates"
105
111
  Requires-Dist: rich[jupyter] (>=12.0.0)
106
112
  Requires-Dist: ruff (>=0.1.7) ; extra == "templates" or extra == "dev"
107
113
  Requires-Dist: s3fs (>=2022.11.0) ; extra == "s3fs"
114
+ Requires-Dist: sagemaker (>=2.117.0) ; extra == "sagemaker"
108
115
  Requires-Dist: secure (>=0.3.0,<0.4.0) ; extra == "server"
109
116
  Requires-Dist: setuptools
110
117
  Requires-Dist: sqlalchemy (>=2.0.0,<3.0.0)