zenml-nightly 0.66.0.dev20240923__py3-none-any.whl → 0.66.0.dev20240928__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/cli/__init__.py +7 -0
- zenml/cli/base.py +2 -2
- zenml/cli/pipeline.py +21 -0
- zenml/cli/utils.py +14 -11
- zenml/client.py +68 -3
- zenml/config/step_configurations.py +0 -5
- zenml/constants.py +3 -0
- zenml/enums.py +2 -0
- zenml/integrations/__init__.py +1 -0
- zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +76 -7
- zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +370 -115
- zenml/integrations/azure/__init__.py +6 -2
- zenml/integrations/azure/orchestrators/azureml_orchestrator.py +157 -4
- zenml/integrations/constants.py +1 -0
- zenml/integrations/deepchecks/__init__.py +1 -1
- zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py +55 -14
- zenml/integrations/deepchecks/validation_checks.py +62 -5
- zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +207 -18
- zenml/integrations/lightning/__init__.py +1 -1
- zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py +9 -0
- zenml/integrations/lightning/orchestrators/lightning_orchestrator.py +18 -17
- zenml/integrations/lightning/orchestrators/lightning_orchestrator_entrypoint.py +2 -6
- zenml/integrations/mlflow/steps/mlflow_registry.py +2 -0
- zenml/integrations/skypilot/orchestrators/skypilot_base_vm_orchestrator.py +38 -26
- zenml/integrations/skypilot_kubernetes/__init__.py +52 -0
- zenml/integrations/skypilot_kubernetes/flavors/__init__.py +26 -0
- zenml/integrations/skypilot_kubernetes/flavors/skypilot_orchestrator_kubernetes_vm_flavor.py +125 -0
- zenml/integrations/skypilot_kubernetes/orchestrators/__init__.py +25 -0
- zenml/integrations/skypilot_kubernetes/orchestrators/skypilot_kubernetes_vm_orchestrator.py +74 -0
- zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py +1 -1
- zenml/models/v2/base/filter.py +315 -149
- zenml/models/v2/base/scoped.py +5 -2
- zenml/models/v2/core/artifact_version.py +69 -8
- zenml/models/v2/core/model.py +43 -6
- zenml/models/v2/core/model_version.py +49 -1
- zenml/models/v2/core/model_version_artifact.py +18 -3
- zenml/models/v2/core/model_version_pipeline_run.py +18 -4
- zenml/models/v2/core/pipeline.py +108 -1
- zenml/models/v2/core/pipeline_run.py +172 -21
- zenml/models/v2/core/run_template.py +53 -1
- zenml/models/v2/core/stack.py +33 -5
- zenml/models/v2/core/step_run.py +7 -0
- zenml/new/pipelines/pipeline.py +4 -0
- zenml/new/pipelines/run_utils.py +4 -1
- zenml/orchestrators/base_orchestrator.py +41 -12
- zenml/stack/stack.py +11 -2
- zenml/utils/env_utils.py +54 -1
- zenml/utils/string_utils.py +50 -0
- zenml/zen_server/cloud_utils.py +33 -8
- zenml/zen_server/dashboard/assets/{404-iO8vpun1.js → 404-Y50hSt65.js} +1 -1
- zenml/zen_server/dashboard/assets/{@reactflow-B6kq9fJZ.js → @reactflow-ytavUpwh.js} +1 -1
- zenml/zen_server/dashboard/assets/AlertDialogDropdownItem-xLR9a1iw.js +1 -0
- zenml/zen_server/dashboard/assets/{CodeSnippet-DNWdQmbo.js → CodeSnippet-IxXNxUDa.js} +2 -2
- zenml/zen_server/dashboard/assets/{CollapsibleCard-B2OVjWYE.js → CollapsibleCard-BhutZbBL.js} +1 -1
- zenml/zen_server/dashboard/assets/{Commands-DsoaVElZ.js → Commands-Bf-rd1z8.js} +1 -1
- zenml/zen_server/dashboard/assets/ComponentBadge-gKR1OIwG.js +1 -0
- zenml/zen_server/dashboard/assets/{CopyButton-BqE_-PHO.js → CopyButton-DcFHidFJ.js} +1 -1
- zenml/zen_server/dashboard/assets/{CsvVizualization-Dyasr2jU.js → CsvVizualization-QSbjrfxw.js} +1 -1
- zenml/zen_server/dashboard/assets/{DialogItem-Cz1VLRwa.js → DialogItem-Cd3HqST4.js} +1 -1
- zenml/zen_server/dashboard/assets/{Error-DorJD_va.js → Error-BhwdmqK7.js} +1 -1
- zenml/zen_server/dashboard/assets/{ExecutionStatus-CIfQTutR.js → ExecutionStatus-D6r6aK8J.js} +1 -1
- zenml/zen_server/dashboard/assets/{Helpbox-CmfvtNeq.js → Helpbox-0pBpTwTm.js} +1 -1
- zenml/zen_server/dashboard/assets/Infobox-BTK_EUKT.js +1 -0
- zenml/zen_server/dashboard/assets/{InlineAvatar-Ds2ZFHPc.js → InlineAvatar-CA3DFMcM.js} +1 -1
- zenml/zen_server/dashboard/assets/Partials-QLOZw624.js +1 -0
- zenml/zen_server/dashboard/assets/{ProviderIcon-BOQJgapd.js → ProviderIcon-C16CCIN4.js} +1 -1
- zenml/zen_server/dashboard/assets/{ProviderRadio-BsYBw9YA.js → ProviderRadio-D3FuCHf3.js} +1 -1
- zenml/zen_server/dashboard/assets/{SearchField-W3GXpLlI.js → SearchField-BzmfxS0L.js} +1 -1
- zenml/zen_server/dashboard/assets/SecretTooltip-BaMwHF-Q.js +1 -0
- zenml/zen_server/dashboard/assets/{SetPassword-B-0a8UCj.js → SetPassword-DuIC65H9.js} +1 -1
- zenml/zen_server/dashboard/assets/{Tick-i1DYsVcX.js → Tick-DJTCF0Re.js} +1 -1
- zenml/zen_server/dashboard/assets/{UpdatePasswordSchemas-C6Zb7ASL.js → UpdatePasswordSchemas-CUm-DMpw.js} +1 -1
- zenml/zen_server/dashboard/assets/UsageReason-CKw0juLF.js +1 -0
- zenml/zen_server/dashboard/assets/{WizardFooter-BHbO7zOa.js → WizardFooter-Cv9ApYWU.js} +1 -1
- zenml/zen_server/dashboard/assets/{all-pipeline-runs-query-BBEe6I9-.js → all-pipeline-runs-query-BA3R2Sey.js} +1 -1
- zenml/zen_server/dashboard/assets/{cloud-only-BuP4Kt_7.js → cloud-only-BB4BVa6E.js} +1 -1
- zenml/zen_server/dashboard/assets/{create-stack-B2x2d4r1.js → create-stack-F29xAUEx.js} +1 -1
- zenml/zen_server/dashboard/assets/delete-run-CP0pcJ3U.js +1 -0
- zenml/zen_server/dashboard/assets/{form-schemas-Bap0f854.js → form-schemas-BKXwSDK2.js} +1 -1
- zenml/zen_server/dashboard/assets/index-BhJ6ZJxv.css +1 -0
- zenml/zen_server/dashboard/assets/{index-B9wVwe7u.js → index-Ci0nJ8EZ.js} +5 -5
- zenml/zen_server/dashboard/assets/{index-DFi8BroH.js → index-D-mtoBj3.js} +1 -1
- zenml/zen_server/dashboard/assets/{login-mutation-DwxUz8VA.js → login-mutation-ax6iL2Mb.js} +1 -1
- zenml/zen_server/dashboard/assets/{not-found-D5i9DunU.js → not-found-DbjllLY_.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-oS4hqS8M.js → page-3qPX9WYH.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-iwoJnwPv.js → page-6mfzecin.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-DGMa3ZQL.js → page-8kYmrh0B.js} +1 -1
- zenml/zen_server/dashboard/assets/page-B1n7_W7z.js +1 -0
- zenml/zen_server/dashboard/assets/page-BDg1F-Ug.js +6 -0
- zenml/zen_server/dashboard/assets/{page-xQG6GmFJ.js → page-BXarY9K2.js} +1 -1
- zenml/zen_server/dashboard/assets/page-BZZhLo2u.js +1 -0
- zenml/zen_server/dashboard/assets/page-Bbf_oBjn.js +1 -0
- zenml/zen_server/dashboard/assets/page-BjjuBvZG.js +9 -0
- zenml/zen_server/dashboard/assets/{page-J0s8Sq3N.js → page-BukXK1Aa.js} +1 -1
- zenml/zen_server/dashboard/assets/page-CHaQkFK5.js +1 -0
- zenml/zen_server/dashboard/assets/{page-BitfWsiW.js → page-CKHNAq7z.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-DE03uZZR.js → page-CS0SYFK8.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-WCQ659by.js → page-CvKnNK1S.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-CrSdkteO.js → page-DGM1CbYT.js} +2 -2
- zenml/zen_server/dashboard/assets/{page-DQGCHKrQ.js → page-DMSLXKGT.js} +1 -1
- zenml/zen_server/dashboard/assets/page-DOmIZ2ra.js +1 -0
- zenml/zen_server/dashboard/assets/{page-DgM-N9RL.js → page-DRfcRK1w.js} +1 -1
- zenml/zen_server/dashboard/assets/page-DYVmJ9_w.js +3 -0
- zenml/zen_server/dashboard/assets/{page-BiF8hLbO.js → page-DcTjHmYZ.js} +1 -1
- zenml/zen_server/dashboard/assets/page-DuqYMYmH.js +1 -0
- zenml/zen_server/dashboard/assets/page-Dwow2doB.js +1 -0
- zenml/zen_server/dashboard/assets/{page-DQdwZZ9x.js → page-HkVBdZl6.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-bimkItOg.js → page-MAXyfXBq.js} +1 -1
- zenml/zen_server/dashboard/assets/page-miU2rhYG.js +1 -0
- zenml/zen_server/dashboard/assets/page-p0BhSAWx.js +1 -0
- zenml/zen_server/dashboard/assets/{page-DFCK65G9.js → page-uORspyRu.js} +1 -1
- zenml/zen_server/dashboard/assets/persist-BxIR2XZs.js +1 -0
- zenml/zen_server/dashboard/assets/{persist-xsYgVtR1.js → persist-CfJMar_k.js} +1 -1
- zenml/zen_server/dashboard/assets/sharedSchema-vub0rii3.js +14 -0
- zenml/zen_server/dashboard/assets/stack-detail-query-DQcyzG-2.js +1 -0
- zenml/zen_server/dashboard/assets/tick-circle-m-hJG8i9.js +1 -0
- zenml/zen_server/dashboard/assets/{update-server-settings-mutation-DNqmQXDM.js → update-server-settings-mutation-FGVP7X2U.js} +1 -1
- zenml/zen_server/dashboard/assets/{url-DwbuKk1b.js → url-CbAPzsmT.js} +1 -1
- zenml/zen_server/dashboard/index.html +4 -4
- zenml/zen_server/dashboard_legacy/asset-manifest.json +4 -4
- zenml/zen_server/dashboard_legacy/index.html +1 -1
- zenml/zen_server/dashboard_legacy/{precache-manifest.290b95d5b43efa3368b3dc63d20c4782.js → precache-manifest.6d320abb70db612019dda6c4948e7a90.js} +4 -4
- zenml/zen_server/dashboard_legacy/service-worker.js +1 -1
- zenml/zen_server/dashboard_legacy/static/js/{main.840d1bf0.chunk.js → main.fa9299d5.chunk.js} +2 -2
- zenml/zen_server/dashboard_legacy/static/js/{main.840d1bf0.chunk.js.map → main.fa9299d5.chunk.js.map} +1 -1
- zenml/zen_server/routers/runs_endpoints.py +89 -3
- zenml/zen_stores/sql_zen_store.py +1 -0
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240928.dist-info}/METADATA +8 -1
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240928.dist-info}/RECORD +133 -125
- zenml/zen_server/dashboard/assets/AlertDialogDropdownItem-BXeSvmMY.js +0 -1
- zenml/zen_server/dashboard/assets/EditSecretDialog-Du423_3U.js +0 -1
- zenml/zen_server/dashboard/assets/Infobox-BL9NOS37.js +0 -1
- zenml/zen_server/dashboard/assets/Partials-DX-8iEa1.js +0 -1
- zenml/zen_server/dashboard/assets/UsageReason-CCnzmwS8.js +0 -1
- zenml/zen_server/dashboard/assets/index-6DYjZgDn.css +0 -1
- zenml/zen_server/dashboard/assets/page-BFuJICXM.js +0 -9
- zenml/zen_server/dashboard/assets/page-CDOQLrPC.js +0 -1
- zenml/zen_server/dashboard/assets/page-CEJWu1YO.js +0 -1
- zenml/zen_server/dashboard/assets/page-CIbehp7V.js +0 -1
- zenml/zen_server/dashboard/assets/page-CLiRGfWo.js +0 -1
- zenml/zen_server/dashboard/assets/page-CV44mQn9.js +0 -1
- zenml/zen_server/dashboard/assets/page-D5F3DJjm.js +0 -1
- zenml/zen_server/dashboard/assets/page-DI-qTWrm.js +0 -1
- zenml/zen_server/dashboard/assets/page-Dt8VgzbE.js +0 -1
- zenml/zen_server/dashboard/assets/page-oSqx9dkH.js +0 -1
- zenml/zen_server/dashboard/assets/page-p3GqEAUW.js +0 -1
- zenml/zen_server/dashboard/assets/page-qvcUVPE-.js +0 -1
- zenml/zen_server/dashboard/assets/persist-mEZN_fgH.js +0 -1
- zenml/zen_server/dashboard/assets/sharedSchema-BfZcy7aP.js +0 -14
- zenml/zen_server/dashboard/assets/stack-detail-query-CU4egfhp.js +0 -1
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240928.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240928.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240928.dist-info}/entry_points.txt +0 -0
@@ -19,6 +19,7 @@ from typing import (
|
|
19
19
|
TYPE_CHECKING,
|
20
20
|
Any,
|
21
21
|
Dict,
|
22
|
+
Iterator,
|
22
23
|
List,
|
23
24
|
Optional,
|
24
25
|
Tuple,
|
@@ -46,8 +47,11 @@ from azure.identity import DefaultAzureCredential
|
|
46
47
|
|
47
48
|
from zenml.config.base_settings import BaseSettings
|
48
49
|
from zenml.config.step_configurations import Step
|
49
|
-
from zenml.constants import
|
50
|
-
|
50
|
+
from zenml.constants import (
|
51
|
+
METADATA_ORCHESTRATOR_RUN_ID,
|
52
|
+
METADATA_ORCHESTRATOR_URL,
|
53
|
+
)
|
54
|
+
from zenml.enums import ExecutionStatus, StackComponentType
|
51
55
|
from zenml.integrations.azure.azureml_utils import create_or_get_compute
|
52
56
|
from zenml.integrations.azure.flavors.azureml import AzureMLComputeTypes
|
53
57
|
from zenml.integrations.azure.flavors.azureml_orchestrator_flavor import (
|
@@ -65,7 +69,7 @@ from zenml.stack import StackValidator
|
|
65
69
|
from zenml.utils.string_utils import b64_encode
|
66
70
|
|
67
71
|
if TYPE_CHECKING:
|
68
|
-
from zenml.models import PipelineDeploymentResponse
|
72
|
+
from zenml.models import PipelineDeploymentResponse, PipelineRunResponse
|
69
73
|
from zenml.stack import Stack
|
70
74
|
|
71
75
|
logger = get_logger(__name__)
|
@@ -199,7 +203,7 @@ class AzureMLOrchestrator(ContainerizedOrchestrator):
|
|
199
203
|
deployment: "PipelineDeploymentResponse",
|
200
204
|
stack: "Stack",
|
201
205
|
environment: Dict[str, str],
|
202
|
-
) ->
|
206
|
+
) -> Iterator[Dict[str, MetadataType]]:
|
203
207
|
"""Prepares or runs a pipeline on AzureML.
|
204
208
|
|
205
209
|
Args:
|
@@ -210,6 +214,9 @@ class AzureMLOrchestrator(ContainerizedOrchestrator):
|
|
210
214
|
|
211
215
|
Raises:
|
212
216
|
RuntimeError: If the creation of the schedule fails.
|
217
|
+
|
218
|
+
Yields:
|
219
|
+
A dictionary of metadata related to the pipeline run.
|
213
220
|
"""
|
214
221
|
# Authentication
|
215
222
|
if connector := self.get_connector():
|
@@ -379,6 +386,10 @@ class AzureMLOrchestrator(ContainerizedOrchestrator):
|
|
379
386
|
else:
|
380
387
|
job = ml_client.jobs.create_or_update(pipeline_job)
|
381
388
|
logger.info(f"Pipeline {run_name} has been started.")
|
389
|
+
|
390
|
+
# Yield metadata based on the generated job object
|
391
|
+
yield from self.compute_metadata(job)
|
392
|
+
|
382
393
|
assert job.services is not None
|
383
394
|
assert job.name is not None
|
384
395
|
|
@@ -428,3 +439,145 @@ class AzureMLOrchestrator(ContainerizedOrchestrator):
|
|
428
439
|
f"job: {e}"
|
429
440
|
)
|
430
441
|
return {}
|
442
|
+
|
443
|
+
def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus:
|
444
|
+
"""Refreshes the status of a specific pipeline run.
|
445
|
+
|
446
|
+
Args:
|
447
|
+
run: The run that was executed by this orchestrator.
|
448
|
+
|
449
|
+
Returns:
|
450
|
+
the actual status of the pipeline execution.
|
451
|
+
|
452
|
+
Raises:
|
453
|
+
AssertionError: If the run was not executed by this orchestrator.
|
454
|
+
ValueError: If it fetches an unknown state or if we can not fetch
|
455
|
+
the orchestrator run ID.
|
456
|
+
"""
|
457
|
+
# Make sure that the stack exists and is accessible
|
458
|
+
if run.stack is None:
|
459
|
+
raise ValueError(
|
460
|
+
"The stack that the run was executed on is not available "
|
461
|
+
"anymore."
|
462
|
+
)
|
463
|
+
|
464
|
+
# Make sure that the run belongs to this orchestrator
|
465
|
+
assert (
|
466
|
+
self.id
|
467
|
+
== run.stack.components[StackComponentType.ORCHESTRATOR][0].id
|
468
|
+
)
|
469
|
+
|
470
|
+
# Initialize the AzureML client
|
471
|
+
if connector := self.get_connector():
|
472
|
+
credentials = connector.connect()
|
473
|
+
else:
|
474
|
+
credentials = DefaultAzureCredential()
|
475
|
+
|
476
|
+
ml_client = MLClient(
|
477
|
+
credential=credentials,
|
478
|
+
subscription_id=self.config.subscription_id,
|
479
|
+
resource_group_name=self.config.resource_group,
|
480
|
+
workspace_name=self.config.workspace,
|
481
|
+
)
|
482
|
+
|
483
|
+
# Fetch the status of the PipelineJob
|
484
|
+
if METADATA_ORCHESTRATOR_RUN_ID in run.run_metadata:
|
485
|
+
run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID].value
|
486
|
+
elif run.orchestrator_run_id is not None:
|
487
|
+
run_id = run.orchestrator_run_id
|
488
|
+
else:
|
489
|
+
raise ValueError(
|
490
|
+
"Can not find the orchestrator run ID, thus can not fetch "
|
491
|
+
"the status."
|
492
|
+
)
|
493
|
+
status = ml_client.jobs.get(run_id).status
|
494
|
+
|
495
|
+
# Map the potential outputs to ZenML ExecutionStatus. Potential values:
|
496
|
+
# https://learn.microsoft.com/en-us/python/api/azure-ai-ml/azure.ai.ml.entities.pipelinejob?view=azure-python#azure-ai-ml-entities-pipelinejob-status
|
497
|
+
if status in [
|
498
|
+
"NotStarted",
|
499
|
+
"Starting",
|
500
|
+
"Provisioning",
|
501
|
+
"Preparing",
|
502
|
+
"Queued",
|
503
|
+
]:
|
504
|
+
return ExecutionStatus.INITIALIZING
|
505
|
+
elif status in ["Running", "Finalizing"]:
|
506
|
+
return ExecutionStatus.RUNNING
|
507
|
+
elif status in [
|
508
|
+
"CancelRequested",
|
509
|
+
"Failed",
|
510
|
+
"Canceled",
|
511
|
+
"NotResponding",
|
512
|
+
]:
|
513
|
+
return ExecutionStatus.FAILED
|
514
|
+
elif status in ["Completed"]:
|
515
|
+
return ExecutionStatus.COMPLETED
|
516
|
+
else:
|
517
|
+
raise ValueError("Unknown status for the pipeline job.")
|
518
|
+
|
519
|
+
def compute_metadata(self, job: Any) -> Iterator[Dict[str, MetadataType]]:
|
520
|
+
"""Generate run metadata based on the generated AzureML PipelineJob.
|
521
|
+
|
522
|
+
Args:
|
523
|
+
job: The corresponding PipelineJob object.
|
524
|
+
|
525
|
+
Yields:
|
526
|
+
A dictionary of metadata related to the pipeline run.
|
527
|
+
"""
|
528
|
+
# Metadata
|
529
|
+
metadata: Dict[str, MetadataType] = {}
|
530
|
+
|
531
|
+
# Orchestrator Run ID
|
532
|
+
if run_id := self._compute_orchestrator_run_id(job):
|
533
|
+
metadata[METADATA_ORCHESTRATOR_RUN_ID] = run_id
|
534
|
+
|
535
|
+
# URL to the AzureML's pipeline view
|
536
|
+
if orchestrator_url := self._compute_orchestrator_url(job):
|
537
|
+
metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_url)
|
538
|
+
|
539
|
+
yield metadata
|
540
|
+
|
541
|
+
@staticmethod
|
542
|
+
def _compute_orchestrator_url(job: Any) -> Optional[str]:
|
543
|
+
"""Generate the Orchestrator Dashboard URL upon pipeline execution.
|
544
|
+
|
545
|
+
Args:
|
546
|
+
job: The corresponding PipelineJob object.
|
547
|
+
|
548
|
+
Returns:
|
549
|
+
the URL to the dashboard view in AzureML.
|
550
|
+
"""
|
551
|
+
try:
|
552
|
+
if job.studio_url:
|
553
|
+
return str(job.studio_url)
|
554
|
+
|
555
|
+
return None
|
556
|
+
|
557
|
+
except Exception as e:
|
558
|
+
logger.warning(
|
559
|
+
f"There was an issue while extracting the pipeline url: {e}"
|
560
|
+
)
|
561
|
+
return None
|
562
|
+
|
563
|
+
@staticmethod
|
564
|
+
def _compute_orchestrator_run_id(job: Any) -> Optional[str]:
|
565
|
+
"""Generate the orchestrator run ID upon pipeline execution.
|
566
|
+
|
567
|
+
Args:
|
568
|
+
job: The corresponding PipelineJob object.
|
569
|
+
|
570
|
+
Returns:
|
571
|
+
the orchestrator run ID, taken from the name of the AzureML job.
|
572
|
+
"""
|
573
|
+
try:
|
574
|
+
if job.name:
|
575
|
+
return str(job.name)
|
576
|
+
|
577
|
+
return None
|
578
|
+
|
579
|
+
except Exception as e:
|
580
|
+
logger.warning(
|
581
|
+
f"There was an issue while extracting the pipeline run ID: {e}"
|
582
|
+
)
|
583
|
+
return None
|
zenml/integrations/constants.py
CHANGED
@@ -35,7 +35,7 @@ class DeepchecksIntegration(Integration):
|
|
35
35
|
|
36
36
|
NAME = DEEPCHECKS
|
37
37
|
REQUIREMENTS = [
|
38
|
-
"deepchecks[vision]
|
38
|
+
"deepchecks[vision]~=0.18.0",
|
39
39
|
"torchvision>=0.14.0",
|
40
40
|
"opencv-python==4.5.5.64", # pin to same version
|
41
41
|
"opencv-python-headless==4.5.5.64", # pin to same version
|
@@ -17,6 +17,7 @@ from typing import (
|
|
17
17
|
Any,
|
18
18
|
ClassVar,
|
19
19
|
Dict,
|
20
|
+
List,
|
20
21
|
Optional,
|
21
22
|
Sequence,
|
22
23
|
Tuple,
|
@@ -28,9 +29,8 @@ import pandas as pd
|
|
28
29
|
from deepchecks.core.checks import BaseCheck
|
29
30
|
from deepchecks.core.suite import SuiteResult
|
30
31
|
from deepchecks.tabular import Dataset as TabularData
|
32
|
+
from deepchecks.tabular import ModelComparisonSuite
|
31
33
|
from deepchecks.tabular import Suite as TabularSuite
|
32
|
-
|
33
|
-
# not part of deepchecks.tabular.checks
|
34
34
|
from deepchecks.tabular.suites import full_suite as full_tabular_suite
|
35
35
|
from deepchecks.vision import Suite as VisionSuite
|
36
36
|
from deepchecks.vision import VisionData
|
@@ -102,7 +102,7 @@ class DeepchecksDataValidator(BaseDataValidator):
|
|
102
102
|
comparison_dataset: Optional[
|
103
103
|
Union[pd.DataFrame, DataLoader[Any]]
|
104
104
|
] = None,
|
105
|
-
|
105
|
+
models: Optional[List[Union[ClassifierMixin, Module]]] = None,
|
106
106
|
check_list: Optional[Sequence[str]] = None,
|
107
107
|
dataset_kwargs: Dict[str, Any] = {},
|
108
108
|
check_kwargs: Dict[str, Dict[str, Any]] = {},
|
@@ -123,7 +123,7 @@ class DeepchecksDataValidator(BaseDataValidator):
|
|
123
123
|
validation.
|
124
124
|
comparison_dataset: Optional secondary (comparison) dataset argument
|
125
125
|
used during comparison checks.
|
126
|
-
|
126
|
+
models: Optional list of models used during validation.
|
127
127
|
check_list: Optional list of ZenML Deepchecks check identifiers
|
128
128
|
specifying the list of Deepchecks checks to be performed.
|
129
129
|
dataset_kwargs: Additional keyword arguments to be passed to the
|
@@ -149,6 +149,7 @@ class DeepchecksDataValidator(BaseDataValidator):
|
|
149
149
|
# arguments and the check list.
|
150
150
|
is_tabular = False
|
151
151
|
is_vision = False
|
152
|
+
is_multi_model = False
|
152
153
|
for dataset in [reference_dataset, comparison_dataset]:
|
153
154
|
if dataset is None:
|
154
155
|
continue
|
@@ -163,7 +164,18 @@ class DeepchecksDataValidator(BaseDataValidator):
|
|
163
164
|
f"data and {str(DataLoader)} for computer vision data."
|
164
165
|
)
|
165
166
|
|
166
|
-
if
|
167
|
+
if models:
|
168
|
+
# if there is more than one model, we should set the
|
169
|
+
# is_multi_model to True
|
170
|
+
if len(models) > 1:
|
171
|
+
is_multi_model = True
|
172
|
+
# if the models are of different types, raise an error
|
173
|
+
# only the same type of models can be used for comparison
|
174
|
+
if len(set(type(model) for model in models)) > 1:
|
175
|
+
raise TypeError(
|
176
|
+
"Models used for comparison checks must be of the same type."
|
177
|
+
)
|
178
|
+
model = models[0]
|
167
179
|
if isinstance(model, ClassifierMixin):
|
168
180
|
is_tabular = True
|
169
181
|
elif isinstance(model, Module):
|
@@ -190,8 +202,18 @@ class DeepchecksDataValidator(BaseDataValidator):
|
|
190
202
|
if not check_list:
|
191
203
|
# default to executing all the checks listed in the supplied
|
192
204
|
# checks enum type if a custom check list is not supplied
|
205
|
+
# don't include the TABULAR_PERFORMANCE_BIAS check enum value
|
206
|
+
# as it requires a protected feature name to be set
|
207
|
+
checks_to_exclude = [
|
208
|
+
DeepchecksModelValidationCheck.TABULAR_PERFORMANCE_BIAS
|
209
|
+
]
|
210
|
+
check_enum_values = [
|
211
|
+
check.value
|
212
|
+
for check in check_enum
|
213
|
+
if check not in checks_to_exclude
|
214
|
+
]
|
193
215
|
tabular_checks, vision_checks = cls._split_checks(
|
194
|
-
|
216
|
+
check_enum_values
|
195
217
|
)
|
196
218
|
if is_tabular:
|
197
219
|
check_list = tabular_checks
|
@@ -254,6 +276,10 @@ class DeepchecksDataValidator(BaseDataValidator):
|
|
254
276
|
suite_class = VisionSuite
|
255
277
|
full_suite = full_vision_suite()
|
256
278
|
|
279
|
+
# if is_multi_model is True, we need to use the ModelComparisonSuite
|
280
|
+
if is_multi_model:
|
281
|
+
suite_class = ModelComparisonSuite
|
282
|
+
|
257
283
|
train_dataset = dataset_class(reference_dataset, **dataset_kwargs)
|
258
284
|
test_dataset = None
|
259
285
|
if comparison_dataset is not None:
|
@@ -294,13 +320,28 @@ class DeepchecksDataValidator(BaseDataValidator):
|
|
294
320
|
continue
|
295
321
|
condition_method(**condition_kwargs)
|
296
322
|
|
297
|
-
suite
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
323
|
+
# if the check is supported by the suite, add it
|
324
|
+
if isinstance(check, suite.supported_checks()):
|
325
|
+
suite.add(check)
|
326
|
+
else:
|
327
|
+
logger.warning(
|
328
|
+
f"Check {check_name} is not supported by the {suite_class} "
|
329
|
+
"suite. Ignoring the check."
|
330
|
+
)
|
331
|
+
|
332
|
+
if isinstance(suite, ModelComparisonSuite):
|
333
|
+
return suite.run(
|
334
|
+
models=models,
|
335
|
+
train_datasets=train_dataset,
|
336
|
+
test_datasets=test_dataset,
|
337
|
+
)
|
338
|
+
else:
|
339
|
+
return suite.run(
|
340
|
+
train_dataset=train_dataset,
|
341
|
+
test_dataset=test_dataset,
|
342
|
+
model=models[0] if models else None,
|
343
|
+
**run_kwargs,
|
344
|
+
)
|
304
345
|
|
305
346
|
def data_validation(
|
306
347
|
self,
|
@@ -444,7 +485,7 @@ class DeepchecksDataValidator(BaseDataValidator):
|
|
444
485
|
check_enum=check_enum,
|
445
486
|
reference_dataset=dataset,
|
446
487
|
comparison_dataset=comparison_dataset,
|
447
|
-
|
488
|
+
models=[model],
|
448
489
|
check_list=check_list,
|
449
490
|
dataset_kwargs=dataset_kwargs,
|
450
491
|
check_kwargs=check_kwargs,
|
@@ -153,8 +153,8 @@ class DeepchecksDataIntegrityCheck(DeepchecksValidationCheck):
|
|
153
153
|
|
154
154
|
This list reflects the set of data integrity checks provided by Deepchecks:
|
155
155
|
|
156
|
-
* [for tabular data](https://docs.deepchecks.com/
|
157
|
-
* [for computer vision](https://docs.deepchecks.com/
|
156
|
+
* [for tabular data](https://docs.deepchecks.com/stable/tabular/auto_checks/data_integrity/index.html)
|
157
|
+
* [for computer vision](https://docs.deepchecks.com/stable/vision/auto_checks/data_integrity/index.html)
|
158
158
|
|
159
159
|
All these checks inherit from `deepchecks.tabular.SingleDatasetCheck` or
|
160
160
|
`deepchecks.vision.SingleDatasetCheck` and require a single dataset as input.
|
@@ -176,6 +176,9 @@ class DeepchecksDataIntegrityCheck(DeepchecksValidationCheck):
|
|
176
176
|
TABULAR_FEATURE_LABEL_CORRELATION = source_utils.resolve(
|
177
177
|
tabular_checks.FeatureLabelCorrelation
|
178
178
|
).import_path
|
179
|
+
TABULAR_IDENTIFIER_LABEL_CORRELATION = source_utils.resolve(
|
180
|
+
tabular_checks.IdentifierLabelCorrelation
|
181
|
+
).import_path
|
179
182
|
TABULAR_IS_SINGLE_VALUE = source_utils.resolve(
|
180
183
|
tabular_checks.IsSingleValue
|
181
184
|
).import_path
|
@@ -197,6 +200,12 @@ class DeepchecksDataIntegrityCheck(DeepchecksValidationCheck):
|
|
197
200
|
TABULAR_STRING_MISMATCH = source_utils.resolve(
|
198
201
|
tabular_checks.StringMismatch
|
199
202
|
).import_path
|
203
|
+
TABULAR_CLASS_IMBALANCE = source_utils.resolve(
|
204
|
+
tabular_checks.ClassImbalance
|
205
|
+
).import_path
|
206
|
+
TABULAR_PERCENT_OF_NULLS = source_utils.resolve(
|
207
|
+
tabular_checks.PercentOfNulls
|
208
|
+
).import_path
|
200
209
|
|
201
210
|
VISION_IMAGE_PROPERTY_OUTLIERS = source_utils.resolve(
|
202
211
|
vision_checks.ImagePropertyOutliers
|
@@ -204,6 +213,9 @@ class DeepchecksDataIntegrityCheck(DeepchecksValidationCheck):
|
|
204
213
|
VISION_LABEL_PROPERTY_OUTLIERS = source_utils.resolve(
|
205
214
|
vision_checks.LabelPropertyOutliers
|
206
215
|
).import_path
|
216
|
+
VISION_PROPERTY_LABEL_CORRELATION = source_utils.resolve(
|
217
|
+
vision_checks.PropertyLabelCorrelation
|
218
|
+
).import_path
|
207
219
|
|
208
220
|
|
209
221
|
class DeepchecksDataDriftCheck(DeepchecksValidationCheck):
|
@@ -246,19 +258,37 @@ class DeepchecksDataDriftCheck(DeepchecksValidationCheck):
|
|
246
258
|
TABULAR_TRAIN_TEST_FEATURE_DRIFT = source_utils.resolve(
|
247
259
|
tabular_checks.TrainTestFeatureDrift
|
248
260
|
).import_path
|
261
|
+
TABULAR_FEATURE_DRIFT = source_utils.resolve(
|
262
|
+
tabular_checks.FeatureDrift
|
263
|
+
).import_path
|
249
264
|
TABULAR_TRAIN_TEST_LABEL_DRIFT = source_utils.resolve(
|
250
265
|
tabular_checks.TrainTestLabelDrift
|
251
266
|
).import_path
|
267
|
+
TABULAR_LABEL_DRIFT = source_utils.resolve(
|
268
|
+
tabular_checks.LabelDrift
|
269
|
+
).import_path
|
252
270
|
TABULAR_TRAIN_TEST_SAMPLES_MIX = source_utils.resolve(
|
253
271
|
tabular_checks.TrainTestSamplesMix
|
254
272
|
).import_path
|
255
273
|
TABULAR_WHOLE_DATASET_DRIFT = source_utils.resolve(
|
256
274
|
tabular_checks.WholeDatasetDrift
|
257
275
|
).import_path
|
276
|
+
TABULAR_NEW_CATEGORY_TRAIN_TEST = source_utils.resolve(
|
277
|
+
tabular_checks.NewCategoryTrainTest
|
278
|
+
).import_path
|
279
|
+
TABULAR_MULTIVARIATE_DRIFT = source_utils.resolve(
|
280
|
+
tabular_checks.MultivariateDrift
|
281
|
+
).import_path
|
258
282
|
|
283
|
+
VISION_PROPERTY_LABEL_CORRELATION_CHANGE = source_utils.resolve(
|
284
|
+
vision_checks.PropertyLabelCorrelationChange
|
285
|
+
).import_path
|
259
286
|
VISION_HEATMAP_COMPARISON = source_utils.resolve(
|
260
287
|
vision_checks.HeatmapComparison
|
261
288
|
).import_path
|
289
|
+
VISION_LABEL_DRIFT = source_utils.resolve(
|
290
|
+
vision_checks.LabelDrift
|
291
|
+
).import_path
|
262
292
|
VISION_IMAGE_DATASET_DRIFT = source_utils.resolve(
|
263
293
|
vision_checks.ImageDatasetDrift
|
264
294
|
).import_path
|
@@ -268,9 +298,6 @@ class DeepchecksDataDriftCheck(DeepchecksValidationCheck):
|
|
268
298
|
VISION_NEW_LABELS = source_utils.resolve(
|
269
299
|
vision_checks.NewLabels
|
270
300
|
).import_path
|
271
|
-
VISION_TRAIN_TEST_LABEL_DRIFT = source_utils.resolve(
|
272
|
-
vision_checks.TrainTestLabelDrift
|
273
|
-
).import_path
|
274
301
|
|
275
302
|
|
276
303
|
class DeepchecksModelValidationCheck(DeepchecksValidationCheck):
|
@@ -296,6 +323,12 @@ class DeepchecksModelValidationCheck(DeepchecksValidationCheck):
|
|
296
323
|
TABULAR_MODEL_INFERENCE_TIME = source_utils.resolve(
|
297
324
|
tabular_checks.ModelInferenceTime
|
298
325
|
).import_path
|
326
|
+
TABULAR_MODEL_INFO = source_utils.resolve(
|
327
|
+
tabular_checks.ModelInfo
|
328
|
+
).import_path
|
329
|
+
TABULAR_PERFORMANCE_BIAS = source_utils.resolve(
|
330
|
+
tabular_checks.model_evaluation.PerformanceBias
|
331
|
+
).import_path
|
299
332
|
TABULAR_REGRESSION_ERROR_DISTRIBUTION = source_utils.resolve(
|
300
333
|
tabular_checks.RegressionErrorDistribution
|
301
334
|
).import_path
|
@@ -308,6 +341,18 @@ class DeepchecksModelValidationCheck(DeepchecksValidationCheck):
|
|
308
341
|
TABULAR_SEGMENT_PERFORMANCE = source_utils.resolve(
|
309
342
|
tabular_checks.SegmentPerformance
|
310
343
|
).import_path
|
344
|
+
TABULAR_WEAK_SEGMENT_PERFORMANCE = source_utils.resolve(
|
345
|
+
tabular_checks.WeakSegmentsPerformance
|
346
|
+
).import_path
|
347
|
+
TABULAR_SINGLE_DATASET_PERFORMANCE = source_utils.resolve(
|
348
|
+
tabular_checks.SingleDatasetPerformance
|
349
|
+
).import_path
|
350
|
+
TABULAR_TRAIN_TEST_PERFORMANCE = source_utils.resolve(
|
351
|
+
tabular_checks.TrainTestPerformance
|
352
|
+
).import_path
|
353
|
+
TABULAR_MULTI_MODEL_PERFORMANCE_REPORT = source_utils.resolve(
|
354
|
+
tabular_checks.MultiModelPerformanceReport
|
355
|
+
).import_path
|
311
356
|
|
312
357
|
VISION_CONFUSION_MATRIX_REPORT = source_utils.resolve(
|
313
358
|
vision_checks.ConfusionMatrixReport
|
@@ -318,6 +363,12 @@ class DeepchecksModelValidationCheck(DeepchecksValidationCheck):
|
|
318
363
|
VISION_MEAN_AVERAGE_RECALL_REPORT = source_utils.resolve(
|
319
364
|
vision_checks.MeanAverageRecallReport
|
320
365
|
).import_path
|
366
|
+
VISION_SINGLE_DATASET_PERFORMANCE = source_utils.resolve(
|
367
|
+
vision_checks.SingleDatasetPerformance
|
368
|
+
).import_path
|
369
|
+
VISION_WEAK_SEGMENT_PERFORMANCE = source_utils.resolve(
|
370
|
+
vision_checks.WeakSegmentsPerformance
|
371
|
+
).import_path
|
321
372
|
|
322
373
|
|
323
374
|
class DeepchecksModelDriftCheck(DeepchecksValidationCheck):
|
@@ -343,6 +394,9 @@ class DeepchecksModelDriftCheck(DeepchecksValidationCheck):
|
|
343
394
|
TABULAR_TRAIN_TEST_PREDICTION_DRIFT = source_utils.resolve(
|
344
395
|
tabular_checks.TrainTestPredictionDrift
|
345
396
|
).import_path
|
397
|
+
TABULAR_PREDICTION_DRIFT = source_utils.resolve(
|
398
|
+
tabular_checks.PredictionDrift
|
399
|
+
).import_path
|
346
400
|
TABULAR_UNUSED_FEATURES = source_utils.resolve(
|
347
401
|
tabular_checks.UnusedFeatures
|
348
402
|
).import_path
|
@@ -356,3 +410,6 @@ class DeepchecksModelDriftCheck(DeepchecksValidationCheck):
|
|
356
410
|
VISION_TRAIN_TEST_PREDICTION_DRIFT = source_utils.resolve(
|
357
411
|
vision_checks.TrainTestPredictionDrift
|
358
412
|
).import_path
|
413
|
+
VISION_PREDICTION_DRIFT = source_utils.resolve(
|
414
|
+
vision_checks.PredictionDrift
|
415
|
+
).import_path
|