zenml-nightly 0.66.0.dev20240924__py3-none-any.whl → 0.66.0.dev20240925__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -21,6 +21,7 @@ from typing import (
21
21
  List,
22
22
  Optional,
23
23
  Union,
24
+ cast,
24
25
  )
25
26
  from uuid import UUID
26
27
 
@@ -136,7 +137,8 @@ class PipelineRunUpdate(BaseModel):
136
137
  "configured by this pipeline run explicitly.",
137
138
  default=None,
138
139
  )
139
- # TODO: we should maybe have a different update model here, the upper three attributes should only be for internal use
140
+ # TODO: we should maybe have a different update model here, the upper
141
+ # three attributes should only be for internal use
140
142
  add_tags: Optional[List[str]] = Field(
141
143
  default=None, title="New tags to add to the pipeline run."
142
144
  )
@@ -235,6 +237,7 @@ class PipelineRunResponseMetadata(WorkspaceScopedResponseMetadata):
235
237
  description="Template used for the pipeline run.",
236
238
  )
237
239
  is_templatable: bool = Field(
240
+ default=False,
238
241
  description="Whether a template can be created from this run.",
239
242
  )
240
243
 
@@ -307,6 +310,64 @@ class PipelineRunResponse(
307
310
 
308
311
  return get_artifacts_versions_of_pipeline_run(self, only_produced=True)
309
312
 
313
+ def refresh_run_status(self) -> "PipelineRunResponse":
314
+ """Method to refresh the status of a run if it is initializing/running.
315
+
316
+ Returns:
317
+ The updated pipeline.
318
+
319
+ Raises:
320
+ ValueError: If the stack of the run response is None.
321
+ """
322
+ if self.status in [
323
+ ExecutionStatus.INITIALIZING,
324
+ ExecutionStatus.RUNNING,
325
+ ]:
326
+ # Check if the stack still accessible
327
+ if self.stack is None:
328
+ raise ValueError(
329
+ "The stack that this pipeline run response was executed on"
330
+ "has been deleted."
331
+ )
332
+
333
+ # Create the orchestrator instance
334
+ from zenml.enums import StackComponentType
335
+ from zenml.orchestrators.base_orchestrator import BaseOrchestrator
336
+ from zenml.stack.stack_component import StackComponent
337
+
338
+ # Check if the stack still accessible
339
+ orchestrator_list = self.stack.components.get(
340
+ StackComponentType.ORCHESTRATOR, []
341
+ )
342
+ if len(orchestrator_list) == 0:
343
+ raise ValueError(
344
+ "The orchestrator that this pipeline run response was "
345
+ "executed with has been deleted."
346
+ )
347
+
348
+ orchestrator = cast(
349
+ BaseOrchestrator,
350
+ StackComponent.from_model(
351
+ component_model=orchestrator_list[0]
352
+ ),
353
+ )
354
+
355
+ # Fetch the status
356
+ status = orchestrator.fetch_status(run=self)
357
+
358
+ # If it is different from the current status, update it
359
+ if status != self.status:
360
+ from zenml.client import Client
361
+ from zenml.models import PipelineRunUpdate
362
+
363
+ client = Client()
364
+ return client.zen_store.update_run(
365
+ run_id=self.id,
366
+ run_update=PipelineRunUpdate(status=status),
367
+ )
368
+
369
+ return self
370
+
310
371
  # Body and metadata properties
311
372
  @property
312
373
  def status(self) -> ExecutionStatus:
@@ -138,7 +138,10 @@ def deploy_pipeline(
138
138
  previous_value = constants.SHOULD_PREVENT_PIPELINE_EXECUTION
139
139
  constants.SHOULD_PREVENT_PIPELINE_EXECUTION = True
140
140
  try:
141
- stack.deploy_pipeline(deployment=deployment)
141
+ stack.deploy_pipeline(
142
+ deployment=deployment,
143
+ placeholder_run=placeholder_run,
144
+ )
142
145
  except Exception as e:
143
146
  if (
144
147
  placeholder_run
@@ -14,12 +14,14 @@
14
14
  """Base orchestrator class."""
15
15
 
16
16
  from abc import ABC, abstractmethod
17
- from typing import TYPE_CHECKING, Any, Dict, Optional, Type, cast
17
+ from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, Type, cast
18
18
 
19
19
  from pydantic import model_validator
20
20
 
21
- from zenml.enums import StackComponentType
21
+ from zenml.enums import ExecutionStatus, StackComponentType
22
22
  from zenml.logger import get_logger
23
+ from zenml.metadata.metadata_types import MetadataType
24
+ from zenml.orchestrators.publish_utils import publish_pipeline_run_metadata
23
25
  from zenml.orchestrators.step_launcher import StepLauncher
24
26
  from zenml.orchestrators.utils import get_config_environment_vars
25
27
  from zenml.stack import Flavor, Stack, StackComponent, StackComponentConfig
@@ -27,7 +29,7 @@ from zenml.utils.pydantic_utils import before_validator_handler
27
29
 
28
30
  if TYPE_CHECKING:
29
31
  from zenml.config.step_configurations import Step
30
- from zenml.models import PipelineDeploymentResponse
32
+ from zenml.models import PipelineDeploymentResponse, PipelineRunResponse
31
33
 
32
34
  logger = get_logger(__name__)
33
35
 
@@ -124,7 +126,7 @@ class BaseOrchestrator(StackComponent, ABC):
124
126
  deployment: "PipelineDeploymentResponse",
125
127
  stack: "Stack",
126
128
  environment: Dict[str, str],
127
- ) -> Any:
129
+ ) -> Optional[Iterator[Dict[str, MetadataType]]]:
128
130
  """The method needs to be implemented by the respective orchestrator.
129
131
 
130
132
  Depending on the type of orchestrator you'll have to perform slightly
@@ -169,29 +171,41 @@ class BaseOrchestrator(StackComponent, ABC):
169
171
  self,
170
172
  deployment: "PipelineDeploymentResponse",
171
173
  stack: "Stack",
174
+ placeholder_run: Optional["PipelineRunResponse"] = None,
172
175
  ) -> Any:
173
176
  """Runs a pipeline on a stack.
174
177
 
175
178
  Args:
176
179
  deployment: The pipeline deployment.
177
180
  stack: The stack on which to run the pipeline.
178
-
179
- Returns:
180
- Orchestrator-specific return value.
181
+ placeholder_run: An optional placeholder run for the deployment.
182
+ This will be deleted in case the pipeline deployment failed.
181
183
  """
182
184
  self._prepare_run(deployment=deployment)
183
185
 
184
186
  environment = get_config_environment_vars(deployment=deployment)
185
187
 
186
188
  try:
187
- result = self.prepare_or_run_pipeline(
188
- deployment=deployment, stack=stack, environment=environment
189
- )
189
+ if metadata_iterator := self.prepare_or_run_pipeline(
190
+ deployment=deployment,
191
+ stack=stack,
192
+ environment=environment,
193
+ ):
194
+ for metadata_dict in metadata_iterator:
195
+ try:
196
+ if placeholder_run:
197
+ publish_pipeline_run_metadata(
198
+ pipeline_run_id=placeholder_run.id,
199
+ pipeline_run_metadata={self.id: metadata_dict},
200
+ )
201
+ except Exception as e:
202
+ logger.debug(
203
+ "Something went went wrong trying to publish the"
204
+ f"run metadata: {e}"
205
+ )
190
206
  finally:
191
207
  self._cleanup_run()
192
208
 
193
- return result
194
-
195
209
  def run_step(self, step: "Step") -> None:
196
210
  """Runs the given step.
197
211
 
@@ -239,6 +253,21 @@ class BaseOrchestrator(StackComponent, ABC):
239
253
  """Cleans up the active run."""
240
254
  self._active_deployment = None
241
255
 
256
+ def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus:
257
+ """Refreshes the status of a specific pipeline run.
258
+
259
+ Args:
260
+ run: A pipeline run response to fetch its status.
261
+
262
+ Raises:
263
+ NotImplementedError: If any orchestrator inheriting from the base
264
+ class does not implement this logic.
265
+ """
266
+ raise NotImplementedError(
267
+ "The fetch status functionality is not implemented for the "
268
+ f"'{self.__class__.__name__}' orchestrator."
269
+ )
270
+
242
271
 
243
272
  class BaseOrchestratorFlavor(Flavor):
244
273
  """Base orchestrator flavor class."""
zenml/stack/stack.py CHANGED
@@ -62,7 +62,11 @@ if TYPE_CHECKING:
62
62
  from zenml.image_builders import BaseImageBuilder
63
63
  from zenml.model_deployers import BaseModelDeployer
64
64
  from zenml.model_registries import BaseModelRegistry
65
- from zenml.models import PipelineDeploymentBase, PipelineDeploymentResponse
65
+ from zenml.models import (
66
+ PipelineDeploymentBase,
67
+ PipelineDeploymentResponse,
68
+ PipelineRunResponse,
69
+ )
66
70
  from zenml.orchestrators import BaseOrchestrator
67
71
  from zenml.stack import StackComponent
68
72
  from zenml.step_operators import BaseStepOperator
@@ -826,16 +830,21 @@ class Stack:
826
830
  def deploy_pipeline(
827
831
  self,
828
832
  deployment: "PipelineDeploymentResponse",
833
+ placeholder_run: Optional["PipelineRunResponse"] = None,
829
834
  ) -> Any:
830
835
  """Deploys a pipeline on this stack.
831
836
 
832
837
  Args:
833
838
  deployment: The pipeline deployment.
839
+ placeholder_run: An optional placeholder run for the deployment.
840
+ This will be deleted in case the pipeline deployment failed.
834
841
 
835
842
  Returns:
836
843
  The return value of the call to `orchestrator.run_pipeline(...)`.
837
844
  """
838
- return self.orchestrator.run(deployment=deployment, stack=self)
845
+ return self.orchestrator.run(
846
+ deployment=deployment, stack=self, placeholder_run=placeholder_run
847
+ )
839
848
 
840
849
  def _get_active_components_for_step(
841
850
  self, step_config: "StepConfiguration"
@@ -1,6 +1,7 @@
1
1
  """Utils concerning anything concerning the cloud control plane backend."""
2
2
 
3
3
  import os
4
+ from datetime import datetime, timedelta, timezone
4
5
  from typing import Any, Dict, Optional
5
6
 
6
7
  import requests
@@ -19,11 +20,9 @@ class ZenMLCloudConfiguration(BaseModel):
19
20
  """ZenML Pro RBAC configuration."""
20
21
 
21
22
  api_url: str
22
-
23
23
  oauth2_client_id: str
24
24
  oauth2_client_secret: str
25
25
  oauth2_audience: str
26
- auth0_domain: str
27
26
 
28
27
  @field_validator("api_url")
29
28
  @classmethod
@@ -68,6 +67,8 @@ class ZenMLCloudConnection:
68
67
  """Initialize the RBAC component."""
69
68
  self._config = ZenMLCloudConfiguration.from_environment()
70
69
  self._session: Optional[requests.Session] = None
70
+ self._token: Optional[str] = None
71
+ self._token_expires_at: Optional[datetime] = None
71
72
 
72
73
  def get(
73
74
  self, endpoint: str, params: Optional[Dict[str, Any]]
@@ -91,7 +92,8 @@ class ZenMLCloudConnection:
91
92
 
92
93
  response = self.session.get(url=url, params=params, timeout=7)
93
94
  if response.status_code == 401:
94
- # Refresh the auth token and try again
95
+ # If we get an Unauthorized error from the API serer, we refresh the
96
+ # auth token and try again
95
97
  self._clear_session()
96
98
  response = self.session.get(url=url, params=params, timeout=7)
97
99
 
@@ -186,6 +188,8 @@ class ZenMLCloudConnection:
186
188
  def _clear_session(self) -> None:
187
189
  """Clear the authentication session."""
188
190
  self._session = None
191
+ self._token = None
192
+ self._token_expires_at = None
189
193
 
190
194
  def _fetch_auth_token(self) -> str:
191
195
  """Fetch an auth token for the Cloud API from auth0.
@@ -196,8 +200,16 @@ class ZenMLCloudConnection:
196
200
  Returns:
197
201
  Auth token.
198
202
  """
203
+ if (
204
+ self._token is not None
205
+ and self._token_expires_at is not None
206
+ and datetime.now(timezone.utc) + timedelta(minutes=5)
207
+ < self._token_expires_at
208
+ ):
209
+ return self._token
210
+
199
211
  # Get an auth token from auth0
200
- auth0_url = f"https://{self._config.auth0_domain}/oauth/token"
212
+ login_url = f"{self._config.api_url}/auth/login"
201
213
  headers = {"content-type": "application/x-www-form-urlencoded"}
202
214
  payload = {
203
215
  "client_id": self._config.oauth2_client_id,
@@ -207,18 +219,31 @@ class ZenMLCloudConnection:
207
219
  }
208
220
  try:
209
221
  response = requests.post(
210
- auth0_url, headers=headers, data=payload, timeout=7
222
+ login_url, headers=headers, data=payload, timeout=7
211
223
  )
212
224
  response.raise_for_status()
213
225
  except Exception as e:
214
226
  raise RuntimeError(f"Error fetching auth token from auth0: {e}")
215
227
 
216
- access_token = response.json().get("access_token", "")
228
+ json_response = response.json()
229
+ access_token = json_response.get("access_token", "")
230
+ expires_in = json_response.get("expires_in", 0)
217
231
 
218
- if not access_token or not isinstance(access_token, str):
232
+ if (
233
+ not access_token
234
+ or not isinstance(access_token, str)
235
+ or not expires_in
236
+ or not isinstance(expires_in, int)
237
+ ):
219
238
  raise RuntimeError("Could not fetch auth token from auth0.")
220
239
 
221
- return str(access_token)
240
+ self._token = access_token
241
+ self._token_expires_at = datetime.now(timezone.utc) + timedelta(
242
+ seconds=expires_in
243
+ )
244
+
245
+ assert self._token is not None
246
+ return self._token
222
247
 
223
248
 
224
249
  def cloud_connection() -> ZenMLCloudConnection:
@@ -22,13 +22,15 @@ from zenml.constants import (
22
22
  API,
23
23
  GRAPH,
24
24
  PIPELINE_CONFIGURATION,
25
+ REFRESH,
25
26
  RUNS,
26
27
  STATUS,
27
28
  STEPS,
28
29
  VERSION_1,
29
30
  )
30
- from zenml.enums import ExecutionStatus
31
+ from zenml.enums import ExecutionStatus, StackComponentType
31
32
  from zenml.lineage_graph.lineage_graph import LineageGraph
33
+ from zenml.logger import get_logger
32
34
  from zenml.models import (
33
35
  Page,
34
36
  PipelineRunFilter,
@@ -45,7 +47,8 @@ from zenml.zen_server.rbac.endpoint_utils import (
45
47
  verify_permissions_and_list_entities,
46
48
  verify_permissions_and_update_entity,
47
49
  )
48
- from zenml.zen_server.rbac.models import ResourceType
50
+ from zenml.zen_server.rbac.models import Action, ResourceType
51
+ from zenml.zen_server.rbac.utils import verify_permission_for_model
49
52
  from zenml.zen_server.utils import (
50
53
  handle_exceptions,
51
54
  make_dependable,
@@ -59,6 +62,9 @@ router = APIRouter(
59
62
  )
60
63
 
61
64
 
65
+ logger = get_logger(__name__)
66
+
67
+
62
68
  @router.get(
63
69
  "",
64
70
  response_model=Page[PipelineRunResponse],
@@ -99,6 +105,7 @@ def list_runs(
99
105
  def get_run(
100
106
  run_id: UUID,
101
107
  hydrate: bool = True,
108
+ refresh_status: bool = False,
102
109
  _: AuthContext = Security(authorize),
103
110
  ) -> PipelineRunResponse:
104
111
  """Get a specific pipeline run using its ID.
@@ -107,13 +114,47 @@ def get_run(
107
114
  run_id: ID of the pipeline run to get.
108
115
  hydrate: Flag deciding whether to hydrate the output model(s)
109
116
  by including metadata fields in the response.
117
+ refresh_status: Flag deciding whether we should try to refresh
118
+ the status of the pipeline run using its orchestrator.
110
119
 
111
120
  Returns:
112
121
  The pipeline run.
122
+
123
+ Raises:
124
+ RuntimeError: If the stack or the orchestrator of the run is deleted.
113
125
  """
114
- return verify_permissions_and_get_entity(
126
+ run = verify_permissions_and_get_entity(
115
127
  id=run_id, get_method=zen_store().get_run, hydrate=hydrate
116
128
  )
129
+ if refresh_status:
130
+ try:
131
+ # Check the stack and its orchestrator
132
+ if run.stack is not None:
133
+ orchestrators = run.stack.components.get(
134
+ StackComponentType.ORCHESTRATOR, []
135
+ )
136
+ if orchestrators:
137
+ verify_permission_for_model(
138
+ model=orchestrators[0], action=Action.READ
139
+ )
140
+ else:
141
+ raise RuntimeError(
142
+ f"The orchestrator, the run '{run.id}' was executed "
143
+ "with, is deleted."
144
+ )
145
+ else:
146
+ raise RuntimeError(
147
+ f"The stack, the run '{run.id}' was executed on, is deleted."
148
+ )
149
+
150
+ run = run.refresh_run_status()
151
+
152
+ except Exception as e:
153
+ logger.warning(
154
+ "An error occurred while refreshing the status of the "
155
+ f"pipeline run: {e}"
156
+ )
157
+ return run
117
158
 
118
159
 
119
160
  @router.put(
@@ -267,3 +308,48 @@ def get_run_status(
267
308
  id=run_id, get_method=zen_store().get_run, hydrate=False
268
309
  )
269
310
  return run.status
311
+
312
+
313
+ @router.get(
314
+ "/{run_id}" + REFRESH,
315
+ responses={401: error_response, 404: error_response, 422: error_response},
316
+ )
317
+ @handle_exceptions
318
+ def refresh_run_status(
319
+ run_id: UUID,
320
+ _: AuthContext = Security(authorize),
321
+ ) -> None:
322
+ """Refreshes the status of a specific pipeline run.
323
+
324
+ Args:
325
+ run_id: ID of the pipeline run to refresh.
326
+
327
+ Raises:
328
+ RuntimeError: If the stack or the orchestrator of the run is deleted.
329
+ """
330
+ # Verify access to the run
331
+ run = verify_permissions_and_get_entity(
332
+ id=run_id,
333
+ get_method=zen_store().get_run,
334
+ hydrate=True,
335
+ )
336
+
337
+ # Check the stack and its orchestrator
338
+ if run.stack is not None:
339
+ orchestrators = run.stack.components.get(
340
+ StackComponentType.ORCHESTRATOR, []
341
+ )
342
+ if orchestrators:
343
+ verify_permission_for_model(
344
+ model=orchestrators[0], action=Action.READ
345
+ )
346
+ else:
347
+ raise RuntimeError(
348
+ f"The orchestrator, the run '{run.id}' was executed with, is "
349
+ "deleted."
350
+ )
351
+ else:
352
+ raise RuntimeError(
353
+ f"The stack, the run '{run.id}' was executed on, is deleted."
354
+ )
355
+ run.refresh_run_status()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: zenml-nightly
3
- Version: 0.66.0.dev20240924
3
+ Version: 0.66.0.dev20240925
4
4
  Summary: ZenML: Write production-ready ML code.
5
5
  Home-page: https://zenml.io
6
6
  License: Apache-2.0
@@ -24,6 +24,7 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
24
24
  Classifier: Topic :: System :: Distributed Computing
25
25
  Classifier: Typing :: Typed
26
26
  Provides-Extra: adlfs
27
+ Provides-Extra: azureml
27
28
  Provides-Extra: connectors-aws
28
29
  Provides-Extra: connectors-azure
29
30
  Provides-Extra: connectors-gcp
@@ -32,6 +33,7 @@ Provides-Extra: dev
32
33
  Provides-Extra: gcsfs
33
34
  Provides-Extra: mlstacks
34
35
  Provides-Extra: s3fs
36
+ Provides-Extra: sagemaker
35
37
  Provides-Extra: secrets-aws
36
38
  Provides-Extra: secrets-azure
37
39
  Provides-Extra: secrets-gcp
@@ -39,10 +41,12 @@ Provides-Extra: secrets-hashicorp
39
41
  Provides-Extra: server
40
42
  Provides-Extra: templates
41
43
  Provides-Extra: terraform
44
+ Provides-Extra: vertex
42
45
  Requires-Dist: Jinja2 ; extra == "server"
43
46
  Requires-Dist: adlfs (>=2021.10.0) ; extra == "adlfs"
44
47
  Requires-Dist: alembic (>=1.8.1,<1.9.0)
45
48
  Requires-Dist: aws-profile-manager (>=0.5.0) ; extra == "connectors-aws"
49
+ Requires-Dist: azure-ai-ml (==1.18.0) ; extra == "azureml"
46
50
  Requires-Dist: azure-identity (>=1.4.0) ; extra == "secrets-azure" or extra == "connectors-azure"
47
51
  Requires-Dist: azure-keyvault-secrets (>=4.0.0) ; extra == "secrets-azure"
48
52
  Requires-Dist: azure-mgmt-containerregistry (>=10.0.0) ; extra == "connectors-azure"
@@ -63,6 +67,7 @@ Requires-Dist: docker (>=7.1.0,<7.2.0)
63
67
  Requires-Dist: fastapi (>=0.100,<=0.110) ; extra == "server"
64
68
  Requires-Dist: gcsfs (>=2022.11.0) ; extra == "gcsfs"
65
69
  Requires-Dist: gitpython (>=3.1.18,<4.0.0)
70
+ Requires-Dist: google-cloud-aiplatform (>=1.34.0) ; extra == "vertex"
66
71
  Requires-Dist: google-cloud-artifact-registry (>=1.11.3) ; extra == "connectors-gcp"
67
72
  Requires-Dist: google-cloud-container (>=2.21.0) ; extra == "connectors-gcp"
68
73
  Requires-Dist: google-cloud-secret-manager (>=2.12.5) ; extra == "secrets-gcp"
@@ -72,6 +77,7 @@ Requires-Dist: hypothesis (>=6.43.1,<7.0.0) ; extra == "dev"
72
77
  Requires-Dist: importlib_metadata (<=7.0.0) ; python_version < "3.10"
73
78
  Requires-Dist: ipinfo (>=4.4.3) ; extra == "server"
74
79
  Requires-Dist: jinja2-time (>=0.2.0,<0.3.0) ; extra == "templates"
80
+ Requires-Dist: kfp (>=2.6.0) ; extra == "vertex"
75
81
  Requires-Dist: kubernetes (>=18.20.0) ; extra == "connectors-kubernetes" or extra == "connectors-aws" or extra == "connectors-gcp" or extra == "connectors-azure"
76
82
  Requires-Dist: maison (<2.0) ; extra == "dev"
77
83
  Requires-Dist: mike (>=1.1.2,<2.0.0) ; extra == "dev"
@@ -105,6 +111,7 @@ Requires-Dist: pyyaml-include (<2.0) ; extra == "templates"
105
111
  Requires-Dist: rich[jupyter] (>=12.0.0)
106
112
  Requires-Dist: ruff (>=0.1.7) ; extra == "templates" or extra == "dev"
107
113
  Requires-Dist: s3fs (>=2022.11.0) ; extra == "s3fs"
114
+ Requires-Dist: sagemaker (>=2.117.0) ; extra == "sagemaker"
108
115
  Requires-Dist: secure (>=0.3.0,<0.4.0) ; extra == "server"
109
116
  Requires-Dist: setuptools
110
117
  Requires-Dist: sqlalchemy (>=2.0.0,<3.0.0)