zenml-nightly 0.57.1.dev20240522__py3-none-any.whl → 0.57.1.dev20240527__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/cli/base.py +15 -16
- zenml/client.py +11 -1
- zenml/config/__init__.py +2 -0
- zenml/config/pipeline_configurations.py +2 -0
- zenml/config/pipeline_run_configuration.py +2 -0
- zenml/config/retry_config.py +27 -0
- zenml/config/server_config.py +13 -9
- zenml/config/step_configurations.py +2 -0
- zenml/constants.py +1 -0
- zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py +2 -0
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +14 -0
- zenml/integrations/kubernetes/orchestrators/manifest_utils.py +4 -0
- zenml/integrations/slack/alerters/slack_alerter.py +0 -2
- zenml/model/model.py +77 -45
- zenml/models/v2/core/model_version.py +1 -1
- zenml/models/v2/core/pipeline_run.py +12 -0
- zenml/models/v2/core/step_run.py +12 -0
- zenml/models/v2/misc/server_models.py +9 -3
- zenml/new/pipelines/run_utils.py +8 -2
- zenml/new/steps/step_decorator.py +5 -0
- zenml/orchestrators/step_launcher.py +71 -53
- zenml/orchestrators/step_runner.py +26 -132
- zenml/orchestrators/utils.py +158 -1
- zenml/steps/base_step.py +7 -0
- zenml/utils/dashboard_utils.py +4 -8
- zenml/zen_server/deploy/helm/templates/_environment.tpl +5 -5
- zenml/zen_server/deploy/helm/values.yaml +13 -9
- zenml/zen_server/pipeline_deployment/utils.py +6 -2
- zenml/zen_server/routers/auth_endpoints.py +4 -4
- zenml/zen_server/zen_server_api.py +1 -1
- zenml/zen_stores/base_zen_store.py +2 -2
- zenml/zen_stores/schemas/pipeline_run_schemas.py +12 -0
- zenml/zen_stores/schemas/step_run_schemas.py +14 -0
- zenml/zen_stores/sql_zen_store.py +4 -2
- {zenml_nightly-0.57.1.dev20240522.dist-info → zenml_nightly-0.57.1.dev20240527.dist-info}/METADATA +3 -3
- {zenml_nightly-0.57.1.dev20240522.dist-info → zenml_nightly-0.57.1.dev20240527.dist-info}/RECORD +40 -39
- {zenml_nightly-0.57.1.dev20240522.dist-info → zenml_nightly-0.57.1.dev20240527.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.57.1.dev20240522.dist-info → zenml_nightly-0.57.1.dev20240527.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.57.1.dev20240522.dist-info → zenml_nightly-0.57.1.dev20240527.dist-info}/entry_points.txt +0 -0
@@ -33,6 +33,7 @@ if TYPE_CHECKING:
|
|
33
33
|
from types import FunctionType
|
34
34
|
|
35
35
|
from zenml.config.base_settings import SettingsOrDict
|
36
|
+
from zenml.config.retry_config import StepRetryConfig
|
36
37
|
from zenml.config.source import Source
|
37
38
|
from zenml.materializers.base_materializer import BaseMaterializer
|
38
39
|
from zenml.model.model import Model
|
@@ -72,6 +73,7 @@ def step(
|
|
72
73
|
on_failure: Optional["HookSpecification"] = None,
|
73
74
|
on_success: Optional["HookSpecification"] = None,
|
74
75
|
model: Optional["Model"] = None,
|
76
|
+
retry: Optional["StepRetryConfig"] = None,
|
75
77
|
model_version: Optional["Model"] = None, # TODO: deprecate me
|
76
78
|
) -> Callable[["F"], "BaseStep"]: ...
|
77
79
|
|
@@ -92,6 +94,7 @@ def step(
|
|
92
94
|
on_failure: Optional["HookSpecification"] = None,
|
93
95
|
on_success: Optional["HookSpecification"] = None,
|
94
96
|
model: Optional["Model"] = None,
|
97
|
+
retry: Optional["StepRetryConfig"] = None,
|
95
98
|
model_version: Optional["Model"] = None, # TODO: deprecate me
|
96
99
|
) -> Union["BaseStep", Callable[["F"], "BaseStep"]]:
|
97
100
|
"""Decorator to create a ZenML step.
|
@@ -123,6 +126,7 @@ def step(
|
|
123
126
|
function with no arguments, or a source path to such a function
|
124
127
|
(e.g. `module.my_function`).
|
125
128
|
model: configuration of the model in the Model Control Plane.
|
129
|
+
retry: configuration of step retry in case of step failure.
|
126
130
|
model_version: DEPRECATED, please use `model` instead.
|
127
131
|
|
128
132
|
Returns:
|
@@ -162,6 +166,7 @@ def step(
|
|
162
166
|
on_failure=on_failure,
|
163
167
|
on_success=on_success,
|
164
168
|
model=model or model_version,
|
169
|
+
retry=retry,
|
165
170
|
)
|
166
171
|
|
167
172
|
return step_instance
|
@@ -13,6 +13,7 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""Class to launch (run directly or using a step operator) steps."""
|
15
15
|
|
16
|
+
import os
|
16
17
|
import time
|
17
18
|
from contextlib import nullcontext
|
18
19
|
from datetime import datetime
|
@@ -23,6 +24,7 @@ from zenml.config.step_configurations import Step
|
|
23
24
|
from zenml.config.step_run_info import StepRunInfo
|
24
25
|
from zenml.constants import (
|
25
26
|
ENV_ZENML_DISABLE_STEP_LOGS_STORAGE,
|
27
|
+
ENV_ZENML_IGNORE_FAILURE_HOOK,
|
26
28
|
STEP_SOURCE_PARAMETER_NAME,
|
27
29
|
TEXT_FIELD_MAX_LENGTH,
|
28
30
|
handle_bool_env_var,
|
@@ -31,7 +33,6 @@ from zenml.enums import ExecutionStatus
|
|
31
33
|
from zenml.environment import get_run_environment_dict
|
32
34
|
from zenml.logger import get_logger
|
33
35
|
from zenml.logging import step_logging
|
34
|
-
from zenml.model.utils import link_artifact_config_to_model
|
35
36
|
from zenml.models import (
|
36
37
|
ArtifactVersionResponse,
|
37
38
|
LogsRequest,
|
@@ -53,7 +54,6 @@ from zenml.stack import Stack
|
|
53
54
|
from zenml.utils import string_utils
|
54
55
|
|
55
56
|
if TYPE_CHECKING:
|
56
|
-
from zenml.model.model import Model
|
57
57
|
from zenml.step_operators import BaseStepOperator
|
58
58
|
|
59
59
|
logger = get_logger(__name__)
|
@@ -226,20 +226,56 @@ class StepLauncher:
|
|
226
226
|
|
227
227
|
logger.info(f"Step `{self._step_name}` has started.")
|
228
228
|
if execution_needed:
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
229
|
+
retries = 0
|
230
|
+
last_retry = True
|
231
|
+
max_retries = (
|
232
|
+
step_run_response.config.retry.max_retries
|
233
|
+
if step_run_response.config.retry
|
234
|
+
else 1
|
235
|
+
)
|
236
|
+
delay = (
|
237
|
+
step_run_response.config.retry.delay
|
238
|
+
if step_run_response.config.retry
|
239
|
+
else 0
|
240
|
+
)
|
241
|
+
backoff = (
|
242
|
+
step_run_response.config.retry.backoff
|
243
|
+
if step_run_response.config.retry
|
244
|
+
else 1
|
245
|
+
)
|
246
|
+
|
247
|
+
while retries < max_retries:
|
248
|
+
last_retry = retries == max_retries - 1
|
249
|
+
try:
|
250
|
+
self._run_step(
|
251
|
+
pipeline_run=pipeline_run,
|
252
|
+
step_run=step_run_response,
|
253
|
+
last_retry=last_retry,
|
254
|
+
)
|
255
|
+
logger.info(
|
256
|
+
f"Step '{self._step_name}' completed successfully."
|
257
|
+
)
|
258
|
+
break
|
259
|
+
except BaseException as e: # noqa: E722
|
260
|
+
retries += 1
|
261
|
+
if retries < max_retries:
|
262
|
+
logger.error(
|
263
|
+
f"Failed to run step '{self._step_name}'. Retrying..."
|
264
|
+
)
|
265
|
+
logger.exception(e)
|
266
|
+
logger.info(
|
267
|
+
f"Sleeping for {delay} seconds before retrying."
|
268
|
+
)
|
269
|
+
time.sleep(delay)
|
270
|
+
delay *= backoff
|
271
|
+
else:
|
272
|
+
logger.error(
|
273
|
+
f"Failed to run step '{self._step_name}' after {max_retries} retries. Exiting."
|
274
|
+
)
|
275
|
+
publish_utils.publish_failed_step_run(
|
276
|
+
step_run_response.id
|
277
|
+
)
|
278
|
+
raise
|
243
279
|
|
244
280
|
except: # noqa: E722
|
245
281
|
logger.error(f"Pipeline run `{pipeline_run.name}` failed.")
|
@@ -360,61 +396,33 @@ class StepLauncher:
|
|
360
396
|
output_name: artifact.id
|
361
397
|
for output_name, artifact in cached_outputs.items()
|
362
398
|
}
|
363
|
-
|
399
|
+
orchestrator_utils._link_cached_artifacts_to_model(
|
364
400
|
model_from_context=model,
|
365
401
|
step_run=step_run,
|
402
|
+
step_source=self._step.spec.source,
|
366
403
|
)
|
404
|
+
if self._step.config.model:
|
405
|
+
orchestrator_utils._link_pipeline_run_to_model_from_context(
|
406
|
+
pipeline_run_id=step_run.pipeline_run_id,
|
407
|
+
model=self._step.config.model,
|
408
|
+
)
|
367
409
|
step_run.status = ExecutionStatus.CACHED
|
368
410
|
step_run.end_time = step_run.start_time
|
369
411
|
|
370
412
|
return execution_needed, step_run
|
371
413
|
|
372
|
-
def _link_cached_artifacts_to_model(
|
373
|
-
self,
|
374
|
-
model_from_context: Optional["Model"],
|
375
|
-
step_run: StepRunRequest,
|
376
|
-
) -> None:
|
377
|
-
"""Links the output artifacts of the cached step to the model version in Control Plane.
|
378
|
-
|
379
|
-
Args:
|
380
|
-
model_from_context: The model version of the current step.
|
381
|
-
step_run: The step to run.
|
382
|
-
"""
|
383
|
-
from zenml.artifacts.artifact_config import ArtifactConfig
|
384
|
-
from zenml.steps.base_step import BaseStep
|
385
|
-
from zenml.steps.utils import parse_return_type_annotations
|
386
|
-
|
387
|
-
step_instance = BaseStep.load_from_source(self._step.spec.source)
|
388
|
-
output_annotations = parse_return_type_annotations(
|
389
|
-
step_instance.entrypoint
|
390
|
-
)
|
391
|
-
for output_name_, output_id in step_run.outputs.items():
|
392
|
-
artifact_config_ = None
|
393
|
-
if output_name_ in output_annotations:
|
394
|
-
annotation = output_annotations.get(output_name_, None)
|
395
|
-
if annotation and annotation.artifact_config is not None:
|
396
|
-
artifact_config_ = annotation.artifact_config.copy()
|
397
|
-
# no artifact config found or artifact was produced by `save_artifact`
|
398
|
-
# inside the step body, so was never in annotations
|
399
|
-
if artifact_config_ is None:
|
400
|
-
artifact_config_ = ArtifactConfig(name=output_name_)
|
401
|
-
|
402
|
-
link_artifact_config_to_model(
|
403
|
-
artifact_config=artifact_config_,
|
404
|
-
model=model_from_context,
|
405
|
-
artifact_version_id=output_id,
|
406
|
-
)
|
407
|
-
|
408
414
|
def _run_step(
|
409
415
|
self,
|
410
416
|
pipeline_run: PipelineRunResponse,
|
411
417
|
step_run: StepRunResponse,
|
418
|
+
last_retry: bool = True,
|
412
419
|
) -> None:
|
413
420
|
"""Runs the current step.
|
414
421
|
|
415
422
|
Args:
|
416
423
|
pipeline_run: The model of the current pipeline run.
|
417
424
|
step_run: The model of the current step run.
|
425
|
+
last_retry: Whether this is the last retry of the step.
|
418
426
|
"""
|
419
427
|
# Prepare step run information.
|
420
428
|
step_run_info = StepRunInfo(
|
@@ -437,6 +445,7 @@ class StepLauncher:
|
|
437
445
|
self._run_step_with_step_operator(
|
438
446
|
step_operator_name=self._step.config.step_operator,
|
439
447
|
step_run_info=step_run_info,
|
448
|
+
last_retry=last_retry,
|
440
449
|
)
|
441
450
|
else:
|
442
451
|
self._run_step_without_step_operator(
|
@@ -445,6 +454,7 @@ class StepLauncher:
|
|
445
454
|
step_run_info=step_run_info,
|
446
455
|
input_artifacts=step_run.inputs,
|
447
456
|
output_artifact_uris=output_artifact_uris,
|
457
|
+
last_retry=last_retry,
|
448
458
|
)
|
449
459
|
except: # noqa: E722
|
450
460
|
output_utils.remove_artifact_dirs(
|
@@ -462,12 +472,14 @@ class StepLauncher:
|
|
462
472
|
self,
|
463
473
|
step_operator_name: str,
|
464
474
|
step_run_info: StepRunInfo,
|
475
|
+
last_retry: bool,
|
465
476
|
) -> None:
|
466
477
|
"""Runs the current step with a step operator.
|
467
478
|
|
468
479
|
Args:
|
469
480
|
step_operator_name: The name of the step operator to use.
|
470
481
|
step_run_info: Additional information needed to run the step.
|
482
|
+
last_retry: Whether this is the last retry of the step.
|
471
483
|
"""
|
472
484
|
step_operator = _get_step_operator(
|
473
485
|
stack=self._stack,
|
@@ -485,6 +497,8 @@ class StepLauncher:
|
|
485
497
|
environment = orchestrator_utils.get_config_environment_vars(
|
486
498
|
deployment=self._deployment
|
487
499
|
)
|
500
|
+
if last_retry:
|
501
|
+
environment[ENV_ZENML_IGNORE_FAILURE_HOOK] = str(False)
|
488
502
|
logger.info(
|
489
503
|
"Using step operator `%s` to run step `%s`.",
|
490
504
|
step_operator.name,
|
@@ -503,6 +517,7 @@ class StepLauncher:
|
|
503
517
|
step_run_info: StepRunInfo,
|
504
518
|
input_artifacts: Dict[str, ArtifactVersionResponse],
|
505
519
|
output_artifact_uris: Dict[str, str],
|
520
|
+
last_retry: bool,
|
506
521
|
) -> None:
|
507
522
|
"""Runs the current step without a step operator.
|
508
523
|
|
@@ -512,7 +527,10 @@ class StepLauncher:
|
|
512
527
|
step_run_info: Additional information needed to run the step.
|
513
528
|
input_artifacts: The input artifact versions of the current step.
|
514
529
|
output_artifact_uris: The output artifact URIs of the current step.
|
530
|
+
last_retry: Whether this is the last retry of the step.
|
515
531
|
"""
|
532
|
+
if last_retry:
|
533
|
+
os.environ[ENV_ZENML_IGNORE_FAILURE_HOOK] = "false"
|
516
534
|
runner = StepRunner(step=self._step, stack=self._stack)
|
517
535
|
runner.run(
|
518
536
|
pipeline_run=pipeline_run,
|
@@ -23,7 +23,6 @@ from typing import (
|
|
23
23
|
Dict,
|
24
24
|
List,
|
25
25
|
Optional,
|
26
|
-
Set,
|
27
26
|
Tuple,
|
28
27
|
Type,
|
29
28
|
)
|
@@ -33,11 +32,11 @@ from pydantic.typing import get_origin, is_union
|
|
33
32
|
|
34
33
|
from zenml.artifacts.unmaterialized_artifact import UnmaterializedArtifact
|
35
34
|
from zenml.artifacts.utils import save_artifact
|
36
|
-
from zenml.client import Client
|
37
35
|
from zenml.config.step_configurations import StepConfiguration
|
38
36
|
from zenml.config.step_run_info import StepRunInfo
|
39
37
|
from zenml.constants import (
|
40
38
|
ENV_ZENML_DISABLE_STEP_LOGS_STORAGE,
|
39
|
+
ENV_ZENML_IGNORE_FAILURE_HOOK,
|
41
40
|
handle_bool_env_var,
|
42
41
|
)
|
43
42
|
from zenml.exceptions import StepContextError, StepInterfaceError
|
@@ -52,7 +51,11 @@ from zenml.orchestrators.publish_utils import (
|
|
52
51
|
publish_step_run_metadata,
|
53
52
|
publish_successful_step_run,
|
54
53
|
)
|
55
|
-
from zenml.orchestrators.utils import
|
54
|
+
from zenml.orchestrators.utils import (
|
55
|
+
_link_pipeline_run_to_model_from_artifacts,
|
56
|
+
_link_pipeline_run_to_model_from_context,
|
57
|
+
is_setting_enabled,
|
58
|
+
)
|
56
59
|
from zenml.steps.step_environment import StepEnvironment
|
57
60
|
from zenml.steps.utils import (
|
58
61
|
OutputSignature,
|
@@ -62,9 +65,6 @@ from zenml.steps.utils import (
|
|
62
65
|
from zenml.utils import materializer_utils, source_utils
|
63
66
|
|
64
67
|
if TYPE_CHECKING:
|
65
|
-
from zenml.artifacts.external_artifact_config import (
|
66
|
-
ExternalArtifactConfiguration,
|
67
|
-
)
|
68
68
|
from zenml.config.source import Source
|
69
69
|
from zenml.config.step_configurations import Step
|
70
70
|
from zenml.models import (
|
@@ -192,8 +192,8 @@ class StepRunner:
|
|
192
192
|
input_artifacts=input_artifacts,
|
193
193
|
)
|
194
194
|
|
195
|
-
|
196
|
-
|
195
|
+
_link_pipeline_run_to_model_from_context(
|
196
|
+
pipeline_run_id=pipeline_run.id
|
197
197
|
)
|
198
198
|
|
199
199
|
step_failed = False
|
@@ -203,15 +203,18 @@ class StepRunner:
|
|
203
203
|
)
|
204
204
|
except BaseException as step_exception: # noqa: E722
|
205
205
|
step_failed = True
|
206
|
-
|
207
|
-
|
208
|
-
)
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
206
|
+
if not handle_bool_env_var(
|
207
|
+
ENV_ZENML_IGNORE_FAILURE_HOOK, False
|
208
|
+
):
|
209
|
+
if (
|
210
|
+
failure_hook_source
|
211
|
+
:= self.configuration.failure_hook_source
|
212
|
+
):
|
213
|
+
logger.info("Detected failure hook. Running...")
|
214
|
+
self.load_and_run_hook(
|
215
|
+
failure_hook_source,
|
216
|
+
step_exception=step_exception,
|
217
|
+
)
|
215
218
|
raise
|
216
219
|
finally:
|
217
220
|
step_run_metadata = self._stack.get_step_run_metadata(
|
@@ -225,10 +228,10 @@ class StepRunner:
|
|
225
228
|
info=step_run_info, step_failed=step_failed
|
226
229
|
)
|
227
230
|
if not step_failed:
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
231
|
+
if (
|
232
|
+
success_hook_source
|
233
|
+
:= self.configuration.success_hook_source
|
234
|
+
):
|
232
235
|
logger.info("Detected success hook. Running...")
|
233
236
|
self.load_and_run_hook(
|
234
237
|
success_hook_source,
|
@@ -258,8 +261,8 @@ class StepRunner:
|
|
258
261
|
link_step_artifacts_to_model(
|
259
262
|
artifact_version_ids=output_artifact_ids
|
260
263
|
)
|
261
|
-
|
262
|
-
|
264
|
+
_link_pipeline_run_to_model_from_artifacts(
|
265
|
+
pipeline_run_id=pipeline_run.id,
|
263
266
|
artifact_names=list(output_artifact_ids.keys()),
|
264
267
|
external_artifacts=list(
|
265
268
|
step_run.config.external_input_artifacts.values()
|
@@ -645,115 +648,6 @@ class StepRunner:
|
|
645
648
|
except StepContextError:
|
646
649
|
return
|
647
650
|
|
648
|
-
def _get_model_versions_from_artifacts(
|
649
|
-
self,
|
650
|
-
artifact_names: List[str],
|
651
|
-
) -> Set[Tuple[UUID, UUID]]:
|
652
|
-
"""Gets the model versions from the artifacts.
|
653
|
-
|
654
|
-
Args:
|
655
|
-
artifact_names: The names of the published output artifacts.
|
656
|
-
|
657
|
-
Returns:
|
658
|
-
Set of tuples of (model_id, model_version_id).
|
659
|
-
"""
|
660
|
-
models = set()
|
661
|
-
for artifact_name in artifact_names:
|
662
|
-
artifact_config = (
|
663
|
-
get_step_context()._get_output(artifact_name).artifact_config
|
664
|
-
)
|
665
|
-
if artifact_config is not None:
|
666
|
-
if (model := artifact_config._model) is not None:
|
667
|
-
model_version_response = (
|
668
|
-
model._get_or_create_model_version()
|
669
|
-
)
|
670
|
-
models.add(
|
671
|
-
(
|
672
|
-
model_version_response.model.id,
|
673
|
-
model_version_response.id,
|
674
|
-
)
|
675
|
-
)
|
676
|
-
else:
|
677
|
-
break
|
678
|
-
return models
|
679
|
-
|
680
|
-
def _get_model_versions_from_config(self) -> Set[Tuple[UUID, UUID]]:
|
681
|
-
"""Gets the model versions from the step model version.
|
682
|
-
|
683
|
-
Returns:
|
684
|
-
Set of tuples of (model_id, model_version_id).
|
685
|
-
"""
|
686
|
-
try:
|
687
|
-
mc = get_step_context().model
|
688
|
-
model_version = mc._get_or_create_model_version()
|
689
|
-
return {(model_version.model.id, model_version.id)}
|
690
|
-
except StepContextError:
|
691
|
-
return set()
|
692
|
-
|
693
|
-
def _link_pipeline_run_to_model_from_context(
|
694
|
-
self,
|
695
|
-
pipeline_run: "PipelineRunResponse",
|
696
|
-
) -> None:
|
697
|
-
"""Links the pipeline run to the model version using artifacts data.
|
698
|
-
|
699
|
-
Args:
|
700
|
-
pipeline_run: The response model of current pipeline run.
|
701
|
-
"""
|
702
|
-
from zenml.models import ModelVersionPipelineRunRequest
|
703
|
-
|
704
|
-
models = self._get_model_versions_from_config()
|
705
|
-
|
706
|
-
client = Client()
|
707
|
-
for model in models:
|
708
|
-
client.zen_store.create_model_version_pipeline_run_link(
|
709
|
-
ModelVersionPipelineRunRequest(
|
710
|
-
user=Client().active_user.id,
|
711
|
-
workspace=Client().active_workspace.id,
|
712
|
-
pipeline_run=pipeline_run.id,
|
713
|
-
model=model[0],
|
714
|
-
model_version=model[1],
|
715
|
-
)
|
716
|
-
)
|
717
|
-
|
718
|
-
def _link_pipeline_run_to_model_from_artifacts(
|
719
|
-
self,
|
720
|
-
pipeline_run: "PipelineRunResponse",
|
721
|
-
artifact_names: List[str],
|
722
|
-
external_artifacts: List["ExternalArtifactConfiguration"],
|
723
|
-
) -> None:
|
724
|
-
"""Links the pipeline run to the model version using artifacts data.
|
725
|
-
|
726
|
-
Args:
|
727
|
-
pipeline_run: The response model of current pipeline run.
|
728
|
-
artifact_names: The name of the published output artifacts.
|
729
|
-
external_artifacts: The external artifacts of the step.
|
730
|
-
"""
|
731
|
-
from zenml.models import ModelVersionPipelineRunRequest
|
732
|
-
|
733
|
-
models = self._get_model_versions_from_artifacts(artifact_names)
|
734
|
-
client = Client()
|
735
|
-
|
736
|
-
# Add models from external artifacts
|
737
|
-
for external_artifact in external_artifacts:
|
738
|
-
if external_artifact.model:
|
739
|
-
models.add(
|
740
|
-
(
|
741
|
-
external_artifact.model.model_id,
|
742
|
-
external_artifact.model.id,
|
743
|
-
)
|
744
|
-
)
|
745
|
-
|
746
|
-
for model in models:
|
747
|
-
client.zen_store.create_model_version_pipeline_run_link(
|
748
|
-
ModelVersionPipelineRunRequest(
|
749
|
-
user=client.active_user.id,
|
750
|
-
workspace=client.active_workspace.id,
|
751
|
-
pipeline_run=pipeline_run.id,
|
752
|
-
model=model[0],
|
753
|
-
model_version=model[1],
|
754
|
-
)
|
755
|
-
)
|
756
|
-
|
757
651
|
def load_and_run_hook(
|
758
652
|
self,
|
759
653
|
hook_source: "Source",
|
zenml/orchestrators/utils.py
CHANGED
@@ -15,13 +15,14 @@
|
|
15
15
|
|
16
16
|
import random
|
17
17
|
from datetime import datetime
|
18
|
-
from typing import TYPE_CHECKING, Dict, Optional
|
18
|
+
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
|
19
19
|
from uuid import UUID
|
20
20
|
|
21
21
|
from zenml.client import Client
|
22
22
|
from zenml.config.global_config import (
|
23
23
|
GlobalConfiguration,
|
24
24
|
)
|
25
|
+
from zenml.config.source import Source
|
25
26
|
from zenml.constants import (
|
26
27
|
ENV_ZENML_ACTIVE_STACK_ID,
|
27
28
|
ENV_ZENML_ACTIVE_WORKSPACE_ID,
|
@@ -29,8 +30,16 @@ from zenml.constants import (
|
|
29
30
|
PIPELINE_API_TOKEN_EXPIRES_MINUTES,
|
30
31
|
)
|
31
32
|
from zenml.enums import StoreType
|
33
|
+
from zenml.exceptions import StepContextError
|
34
|
+
from zenml.model.utils import link_artifact_config_to_model
|
35
|
+
from zenml.models.v2.core.step_run import StepRunRequest
|
36
|
+
from zenml.new.steps.step_context import get_step_context
|
32
37
|
|
33
38
|
if TYPE_CHECKING:
|
39
|
+
from zenml.artifacts.external_artifact_config import (
|
40
|
+
ExternalArtifactConfiguration,
|
41
|
+
)
|
42
|
+
from zenml.model.model import Model
|
34
43
|
from zenml.models import PipelineDeploymentResponse
|
35
44
|
|
36
45
|
|
@@ -148,3 +157,151 @@ def get_run_name(run_name_template: str) -> str:
|
|
148
157
|
raise ValueError("Empty run names are not allowed.")
|
149
158
|
|
150
159
|
return run_name
|
160
|
+
|
161
|
+
|
162
|
+
def _link_pipeline_run_to_model_from_context(
    pipeline_run_id: "UUID", model: Optional["Model"] = None
) -> None:
    """Links the pipeline run to the model version configured for the step.

    If `model` is given, its model ID and model version ID are used directly.
    Otherwise the IDs are resolved from the model of the active step context;
    when that lookup yields no IDs, no link is created.

    Args:
        pipeline_run_id: The ID of the current pipeline run.
        model: Model configured in the step. If `None`, the model of the
            active step context is used instead.
    """
    from zenml.models import ModelVersionPipelineRunRequest

    if model is None:
        model_id, model_version_id = _get_model_versions_from_config()
    else:
        model_id, model_version_id = model.model_id, model.id

    # Only link when both IDs could be resolved.
    if model_id and model_version_id:
        # Reuse one client for the store call and the user/workspace lookups.
        client = Client()
        client.zen_store.create_model_version_pipeline_run_link(
            ModelVersionPipelineRunRequest(
                user=client.active_user.id,
                workspace=client.active_workspace.id,
                pipeline_run=pipeline_run_id,
                model=model_id,
                model_version=model_version_id,
            )
        )
|
188
|
+
|
189
|
+
|
190
|
+
def _get_model_versions_from_config() -> Tuple[Optional[UUID], Optional[UUID]]:
    """Resolves the model and model version attached to the running step.

    Returns:
        A `(model_id, model_version_id)` pair read from the step context's
        model, or `(None, None)` if resolving it raises `StepContextError`.
    """
    try:
        step_model = get_step_context().model
        resolved_ids = (step_model.model_id, step_model.id)
    except StepContextError:
        resolved_ids = (None, None)
    return resolved_ids
|
201
|
+
|
202
|
+
|
203
|
+
def _link_cached_artifacts_to_model(
    model_from_context: Optional["Model"],
    step_run: StepRunRequest,
    step_source: Source,
) -> None:
    """Links the output artifacts of the cached step to the model version in Control Plane.

    For each output of the step run, the artifact config declared in the
    step's return-type annotation is used (as a copy) when present. Otherwise
    a default `ArtifactConfig` named after the output is created. The default
    also covers artifacts saved with `save_artifact` inside the step body,
    which never appear in the annotations.

    Args:
        model_from_context: The model version of the current step, passed
            through to the linking helper.
        step_run: The step run request; each entry of its `outputs`
            (output name -> artifact version ID) is linked.
        step_source: The source of the step, loaded to read its output
            annotations.
    """
    from zenml.artifacts.artifact_config import ArtifactConfig
    from zenml.steps.base_step import BaseStep
    from zenml.steps.utils import parse_return_type_annotations

    loaded_step = BaseStep.load_from_source(step_source)
    annotations_by_output = parse_return_type_annotations(
        loaded_step.entrypoint
    )

    for output_name, artifact_version_id in step_run.outputs.items():
        annotation = annotations_by_output.get(output_name)
        if annotation and annotation.artifact_config is not None:
            config = annotation.artifact_config.copy()
        else:
            config = ArtifactConfig(name=output_name)

        link_artifact_config_to_model(
            artifact_config=config,
            model=model_from_context,
            artifact_version_id=artifact_version_id,
        )
|
239
|
+
|
240
|
+
|
241
|
+
def _link_pipeline_run_to_model_from_artifacts(
    pipeline_run_id: UUID,
    artifact_names: List[str],
    external_artifacts: List["ExternalArtifactConfiguration"],
) -> None:
    """Links the pipeline run to every model version referenced by artifacts.

    Model versions are gathered from the configs of the step's published
    output artifacts and from the models attached to the step's external
    input artifacts. One pipeline-run link is created per distinct
    `(model_id, model_version_id)` pair.

    Args:
        pipeline_run_id: The ID of the current pipeline run.
        artifact_names: The name of the published output artifacts.
        external_artifacts: The external artifacts of the step.
    """
    from zenml.models import ModelVersionPipelineRunRequest

    model_pairs = _get_model_versions_from_artifacts(artifact_names)
    client = Client()

    # External input artifacts may also carry a model to link against.
    model_pairs.update(
        (external.model.model_id, external.model.id)
        for external in external_artifacts
        if external.model
    )

    for model_id, model_version_id in model_pairs:
        client.zen_store.create_model_version_pipeline_run_link(
            ModelVersionPipelineRunRequest(
                user=client.active_user.id,
                workspace=client.active_workspace.id,
                pipeline_run=pipeline_run_id,
                model=model_id,
                model_version=model_version_id,
            )
        )
|
278
|
+
|
279
|
+
|
280
|
+
def _get_model_versions_from_artifacts(
    artifact_names: List[str],
) -> Set[Tuple[UUID, UUID]]:
    """Gets the model versions referenced by the step's output artifacts.

    Every published output whose artifact config carries a model contributes
    one `(model_id, model_version_id)` pair; the model version is obtained via
    `_get_or_create_model_version()`. Outputs without an artifact config, or
    whose config has no model, are skipped. They do not end the scan, so
    later outputs that do declare a model are still collected.

    Args:
        artifact_names: The names of the published output artifacts.

    Returns:
        Set of tuples of (model_id, model_version_id).
    """
    models: Set[Tuple[UUID, UUID]] = set()
    for artifact_name in artifact_names:
        artifact_config = (
            get_step_context()._get_output(artifact_name).artifact_config
        )
        if artifact_config is None:
            continue
        model = artifact_config._model
        if model is None:
            # Skip instead of stopping: other outputs may still declare a
            # model and must not lose their pipeline-run link.
            continue
        model_version_response = model._get_or_create_model_version()
        models.add(
            (
                model_version_response.model.id,
                model_version_response.id,
            )
        )
    return models
|