zenml-nightly 0.66.0.dev20240923__py3-none-any.whl → 0.66.0.dev20240928__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (154) hide show
  1. zenml/VERSION +1 -1
  2. zenml/cli/__init__.py +7 -0
  3. zenml/cli/base.py +2 -2
  4. zenml/cli/pipeline.py +21 -0
  5. zenml/cli/utils.py +14 -11
  6. zenml/client.py +68 -3
  7. zenml/config/step_configurations.py +0 -5
  8. zenml/constants.py +3 -0
  9. zenml/enums.py +2 -0
  10. zenml/integrations/__init__.py +1 -0
  11. zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +76 -7
  12. zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +370 -115
  13. zenml/integrations/azure/__init__.py +6 -2
  14. zenml/integrations/azure/orchestrators/azureml_orchestrator.py +157 -4
  15. zenml/integrations/constants.py +1 -0
  16. zenml/integrations/deepchecks/__init__.py +1 -1
  17. zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py +55 -14
  18. zenml/integrations/deepchecks/validation_checks.py +62 -5
  19. zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +207 -18
  20. zenml/integrations/lightning/__init__.py +1 -1
  21. zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py +9 -0
  22. zenml/integrations/lightning/orchestrators/lightning_orchestrator.py +18 -17
  23. zenml/integrations/lightning/orchestrators/lightning_orchestrator_entrypoint.py +2 -6
  24. zenml/integrations/mlflow/steps/mlflow_registry.py +2 -0
  25. zenml/integrations/skypilot/orchestrators/skypilot_base_vm_orchestrator.py +38 -26
  26. zenml/integrations/skypilot_kubernetes/__init__.py +52 -0
  27. zenml/integrations/skypilot_kubernetes/flavors/__init__.py +26 -0
  28. zenml/integrations/skypilot_kubernetes/flavors/skypilot_orchestrator_kubernetes_vm_flavor.py +125 -0
  29. zenml/integrations/skypilot_kubernetes/orchestrators/__init__.py +25 -0
  30. zenml/integrations/skypilot_kubernetes/orchestrators/skypilot_kubernetes_vm_orchestrator.py +74 -0
  31. zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py +1 -1
  32. zenml/models/v2/base/filter.py +315 -149
  33. zenml/models/v2/base/scoped.py +5 -2
  34. zenml/models/v2/core/artifact_version.py +69 -8
  35. zenml/models/v2/core/model.py +43 -6
  36. zenml/models/v2/core/model_version.py +49 -1
  37. zenml/models/v2/core/model_version_artifact.py +18 -3
  38. zenml/models/v2/core/model_version_pipeline_run.py +18 -4
  39. zenml/models/v2/core/pipeline.py +108 -1
  40. zenml/models/v2/core/pipeline_run.py +172 -21
  41. zenml/models/v2/core/run_template.py +53 -1
  42. zenml/models/v2/core/stack.py +33 -5
  43. zenml/models/v2/core/step_run.py +7 -0
  44. zenml/new/pipelines/pipeline.py +4 -0
  45. zenml/new/pipelines/run_utils.py +4 -1
  46. zenml/orchestrators/base_orchestrator.py +41 -12
  47. zenml/stack/stack.py +11 -2
  48. zenml/utils/env_utils.py +54 -1
  49. zenml/utils/string_utils.py +50 -0
  50. zenml/zen_server/cloud_utils.py +33 -8
  51. zenml/zen_server/dashboard/assets/{404-iO8vpun1.js → 404-Y50hSt65.js} +1 -1
  52. zenml/zen_server/dashboard/assets/{@reactflow-B6kq9fJZ.js → @reactflow-ytavUpwh.js} +1 -1
  53. zenml/zen_server/dashboard/assets/AlertDialogDropdownItem-xLR9a1iw.js +1 -0
  54. zenml/zen_server/dashboard/assets/{CodeSnippet-DNWdQmbo.js → CodeSnippet-IxXNxUDa.js} +2 -2
  55. zenml/zen_server/dashboard/assets/{CollapsibleCard-B2OVjWYE.js → CollapsibleCard-BhutZbBL.js} +1 -1
  56. zenml/zen_server/dashboard/assets/{Commands-DsoaVElZ.js → Commands-Bf-rd1z8.js} +1 -1
  57. zenml/zen_server/dashboard/assets/ComponentBadge-gKR1OIwG.js +1 -0
  58. zenml/zen_server/dashboard/assets/{CopyButton-BqE_-PHO.js → CopyButton-DcFHidFJ.js} +1 -1
  59. zenml/zen_server/dashboard/assets/{CsvVizualization-Dyasr2jU.js → CsvVizualization-QSbjrfxw.js} +1 -1
  60. zenml/zen_server/dashboard/assets/{DialogItem-Cz1VLRwa.js → DialogItem-Cd3HqST4.js} +1 -1
  61. zenml/zen_server/dashboard/assets/{Error-DorJD_va.js → Error-BhwdmqK7.js} +1 -1
  62. zenml/zen_server/dashboard/assets/{ExecutionStatus-CIfQTutR.js → ExecutionStatus-D6r6aK8J.js} +1 -1
  63. zenml/zen_server/dashboard/assets/{Helpbox-CmfvtNeq.js → Helpbox-0pBpTwTm.js} +1 -1
  64. zenml/zen_server/dashboard/assets/Infobox-BTK_EUKT.js +1 -0
  65. zenml/zen_server/dashboard/assets/{InlineAvatar-Ds2ZFHPc.js → InlineAvatar-CA3DFMcM.js} +1 -1
  66. zenml/zen_server/dashboard/assets/Partials-QLOZw624.js +1 -0
  67. zenml/zen_server/dashboard/assets/{ProviderIcon-BOQJgapd.js → ProviderIcon-C16CCIN4.js} +1 -1
  68. zenml/zen_server/dashboard/assets/{ProviderRadio-BsYBw9YA.js → ProviderRadio-D3FuCHf3.js} +1 -1
  69. zenml/zen_server/dashboard/assets/{SearchField-W3GXpLlI.js → SearchField-BzmfxS0L.js} +1 -1
  70. zenml/zen_server/dashboard/assets/SecretTooltip-BaMwHF-Q.js +1 -0
  71. zenml/zen_server/dashboard/assets/{SetPassword-B-0a8UCj.js → SetPassword-DuIC65H9.js} +1 -1
  72. zenml/zen_server/dashboard/assets/{Tick-i1DYsVcX.js → Tick-DJTCF0Re.js} +1 -1
  73. zenml/zen_server/dashboard/assets/{UpdatePasswordSchemas-C6Zb7ASL.js → UpdatePasswordSchemas-CUm-DMpw.js} +1 -1
  74. zenml/zen_server/dashboard/assets/UsageReason-CKw0juLF.js +1 -0
  75. zenml/zen_server/dashboard/assets/{WizardFooter-BHbO7zOa.js → WizardFooter-Cv9ApYWU.js} +1 -1
  76. zenml/zen_server/dashboard/assets/{all-pipeline-runs-query-BBEe6I9-.js → all-pipeline-runs-query-BA3R2Sey.js} +1 -1
  77. zenml/zen_server/dashboard/assets/{cloud-only-BuP4Kt_7.js → cloud-only-BB4BVa6E.js} +1 -1
  78. zenml/zen_server/dashboard/assets/{create-stack-B2x2d4r1.js → create-stack-F29xAUEx.js} +1 -1
  79. zenml/zen_server/dashboard/assets/delete-run-CP0pcJ3U.js +1 -0
  80. zenml/zen_server/dashboard/assets/{form-schemas-Bap0f854.js → form-schemas-BKXwSDK2.js} +1 -1
  81. zenml/zen_server/dashboard/assets/index-BhJ6ZJxv.css +1 -0
  82. zenml/zen_server/dashboard/assets/{index-B9wVwe7u.js → index-Ci0nJ8EZ.js} +5 -5
  83. zenml/zen_server/dashboard/assets/{index-DFi8BroH.js → index-D-mtoBj3.js} +1 -1
  84. zenml/zen_server/dashboard/assets/{login-mutation-DwxUz8VA.js → login-mutation-ax6iL2Mb.js} +1 -1
  85. zenml/zen_server/dashboard/assets/{not-found-D5i9DunU.js → not-found-DbjllLY_.js} +1 -1
  86. zenml/zen_server/dashboard/assets/{page-oS4hqS8M.js → page-3qPX9WYH.js} +1 -1
  87. zenml/zen_server/dashboard/assets/{page-iwoJnwPv.js → page-6mfzecin.js} +1 -1
  88. zenml/zen_server/dashboard/assets/{page-DGMa3ZQL.js → page-8kYmrh0B.js} +1 -1
  89. zenml/zen_server/dashboard/assets/page-B1n7_W7z.js +1 -0
  90. zenml/zen_server/dashboard/assets/page-BDg1F-Ug.js +6 -0
  91. zenml/zen_server/dashboard/assets/{page-xQG6GmFJ.js → page-BXarY9K2.js} +1 -1
  92. zenml/zen_server/dashboard/assets/page-BZZhLo2u.js +1 -0
  93. zenml/zen_server/dashboard/assets/page-Bbf_oBjn.js +1 -0
  94. zenml/zen_server/dashboard/assets/page-BjjuBvZG.js +9 -0
  95. zenml/zen_server/dashboard/assets/{page-J0s8Sq3N.js → page-BukXK1Aa.js} +1 -1
  96. zenml/zen_server/dashboard/assets/page-CHaQkFK5.js +1 -0
  97. zenml/zen_server/dashboard/assets/{page-BitfWsiW.js → page-CKHNAq7z.js} +1 -1
  98. zenml/zen_server/dashboard/assets/{page-DE03uZZR.js → page-CS0SYFK8.js} +1 -1
  99. zenml/zen_server/dashboard/assets/{page-WCQ659by.js → page-CvKnNK1S.js} +1 -1
  100. zenml/zen_server/dashboard/assets/{page-CrSdkteO.js → page-DGM1CbYT.js} +2 -2
  101. zenml/zen_server/dashboard/assets/{page-DQGCHKrQ.js → page-DMSLXKGT.js} +1 -1
  102. zenml/zen_server/dashboard/assets/page-DOmIZ2ra.js +1 -0
  103. zenml/zen_server/dashboard/assets/{page-DgM-N9RL.js → page-DRfcRK1w.js} +1 -1
  104. zenml/zen_server/dashboard/assets/page-DYVmJ9_w.js +3 -0
  105. zenml/zen_server/dashboard/assets/{page-BiF8hLbO.js → page-DcTjHmYZ.js} +1 -1
  106. zenml/zen_server/dashboard/assets/page-DuqYMYmH.js +1 -0
  107. zenml/zen_server/dashboard/assets/page-Dwow2doB.js +1 -0
  108. zenml/zen_server/dashboard/assets/{page-DQdwZZ9x.js → page-HkVBdZl6.js} +1 -1
  109. zenml/zen_server/dashboard/assets/{page-bimkItOg.js → page-MAXyfXBq.js} +1 -1
  110. zenml/zen_server/dashboard/assets/page-miU2rhYG.js +1 -0
  111. zenml/zen_server/dashboard/assets/page-p0BhSAWx.js +1 -0
  112. zenml/zen_server/dashboard/assets/{page-DFCK65G9.js → page-uORspyRu.js} +1 -1
  113. zenml/zen_server/dashboard/assets/persist-BxIR2XZs.js +1 -0
  114. zenml/zen_server/dashboard/assets/{persist-xsYgVtR1.js → persist-CfJMar_k.js} +1 -1
  115. zenml/zen_server/dashboard/assets/sharedSchema-vub0rii3.js +14 -0
  116. zenml/zen_server/dashboard/assets/stack-detail-query-DQcyzG-2.js +1 -0
  117. zenml/zen_server/dashboard/assets/tick-circle-m-hJG8i9.js +1 -0
  118. zenml/zen_server/dashboard/assets/{update-server-settings-mutation-DNqmQXDM.js → update-server-settings-mutation-FGVP7X2U.js} +1 -1
  119. zenml/zen_server/dashboard/assets/{url-DwbuKk1b.js → url-CbAPzsmT.js} +1 -1
  120. zenml/zen_server/dashboard/index.html +4 -4
  121. zenml/zen_server/dashboard_legacy/asset-manifest.json +4 -4
  122. zenml/zen_server/dashboard_legacy/index.html +1 -1
  123. zenml/zen_server/dashboard_legacy/{precache-manifest.290b95d5b43efa3368b3dc63d20c4782.js → precache-manifest.6d320abb70db612019dda6c4948e7a90.js} +4 -4
  124. zenml/zen_server/dashboard_legacy/service-worker.js +1 -1
  125. zenml/zen_server/dashboard_legacy/static/js/{main.840d1bf0.chunk.js → main.fa9299d5.chunk.js} +2 -2
  126. zenml/zen_server/dashboard_legacy/static/js/{main.840d1bf0.chunk.js.map → main.fa9299d5.chunk.js.map} +1 -1
  127. zenml/zen_server/routers/runs_endpoints.py +89 -3
  128. zenml/zen_stores/sql_zen_store.py +1 -0
  129. {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240928.dist-info}/METADATA +8 -1
  130. {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240928.dist-info}/RECORD +133 -125
  131. zenml/zen_server/dashboard/assets/AlertDialogDropdownItem-BXeSvmMY.js +0 -1
  132. zenml/zen_server/dashboard/assets/EditSecretDialog-Du423_3U.js +0 -1
  133. zenml/zen_server/dashboard/assets/Infobox-BL9NOS37.js +0 -1
  134. zenml/zen_server/dashboard/assets/Partials-DX-8iEa1.js +0 -1
  135. zenml/zen_server/dashboard/assets/UsageReason-CCnzmwS8.js +0 -1
  136. zenml/zen_server/dashboard/assets/index-6DYjZgDn.css +0 -1
  137. zenml/zen_server/dashboard/assets/page-BFuJICXM.js +0 -9
  138. zenml/zen_server/dashboard/assets/page-CDOQLrPC.js +0 -1
  139. zenml/zen_server/dashboard/assets/page-CEJWu1YO.js +0 -1
  140. zenml/zen_server/dashboard/assets/page-CIbehp7V.js +0 -1
  141. zenml/zen_server/dashboard/assets/page-CLiRGfWo.js +0 -1
  142. zenml/zen_server/dashboard/assets/page-CV44mQn9.js +0 -1
  143. zenml/zen_server/dashboard/assets/page-D5F3DJjm.js +0 -1
  144. zenml/zen_server/dashboard/assets/page-DI-qTWrm.js +0 -1
  145. zenml/zen_server/dashboard/assets/page-Dt8VgzbE.js +0 -1
  146. zenml/zen_server/dashboard/assets/page-oSqx9dkH.js +0 -1
  147. zenml/zen_server/dashboard/assets/page-p3GqEAUW.js +0 -1
  148. zenml/zen_server/dashboard/assets/page-qvcUVPE-.js +0 -1
  149. zenml/zen_server/dashboard/assets/persist-mEZN_fgH.js +0 -1
  150. zenml/zen_server/dashboard/assets/sharedSchema-BfZcy7aP.js +0 -14
  151. zenml/zen_server/dashboard/assets/stack-detail-query-CU4egfhp.js +0 -1
  152. {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240928.dist-info}/LICENSE +0 -0
  153. {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240928.dist-info}/WHEEL +0 -0
  154. {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240928.dist-info}/entry_points.txt +0 -0
@@ -19,6 +19,7 @@ from typing import (
19
19
  TYPE_CHECKING,
20
20
  Any,
21
21
  Dict,
22
+ Iterator,
22
23
  List,
23
24
  Optional,
24
25
  Tuple,
@@ -46,8 +47,11 @@ from azure.identity import DefaultAzureCredential
46
47
 
47
48
  from zenml.config.base_settings import BaseSettings
48
49
  from zenml.config.step_configurations import Step
49
- from zenml.constants import METADATA_ORCHESTRATOR_URL
50
- from zenml.enums import StackComponentType
50
+ from zenml.constants import (
51
+ METADATA_ORCHESTRATOR_RUN_ID,
52
+ METADATA_ORCHESTRATOR_URL,
53
+ )
54
+ from zenml.enums import ExecutionStatus, StackComponentType
51
55
  from zenml.integrations.azure.azureml_utils import create_or_get_compute
52
56
  from zenml.integrations.azure.flavors.azureml import AzureMLComputeTypes
53
57
  from zenml.integrations.azure.flavors.azureml_orchestrator_flavor import (
@@ -65,7 +69,7 @@ from zenml.stack import StackValidator
65
69
  from zenml.utils.string_utils import b64_encode
66
70
 
67
71
  if TYPE_CHECKING:
68
- from zenml.models import PipelineDeploymentResponse
72
+ from zenml.models import PipelineDeploymentResponse, PipelineRunResponse
69
73
  from zenml.stack import Stack
70
74
 
71
75
  logger = get_logger(__name__)
@@ -199,7 +203,7 @@ class AzureMLOrchestrator(ContainerizedOrchestrator):
199
203
  deployment: "PipelineDeploymentResponse",
200
204
  stack: "Stack",
201
205
  environment: Dict[str, str],
202
- ) -> None:
206
+ ) -> Iterator[Dict[str, MetadataType]]:
203
207
  """Prepares or runs a pipeline on AzureML.
204
208
 
205
209
  Args:
@@ -210,6 +214,9 @@ class AzureMLOrchestrator(ContainerizedOrchestrator):
210
214
 
211
215
  Raises:
212
216
  RuntimeError: If the creation of the schedule fails.
217
+
218
+ Yields:
219
+ A dictionary of metadata related to the pipeline run.
213
220
  """
214
221
  # Authentication
215
222
  if connector := self.get_connector():
@@ -379,6 +386,10 @@ class AzureMLOrchestrator(ContainerizedOrchestrator):
379
386
  else:
380
387
  job = ml_client.jobs.create_or_update(pipeline_job)
381
388
  logger.info(f"Pipeline {run_name} has been started.")
389
+
390
+ # Yield metadata based on the generated job object
391
+ yield from self.compute_metadata(job)
392
+
382
393
  assert job.services is not None
383
394
  assert job.name is not None
384
395
 
@@ -428,3 +439,145 @@ class AzureMLOrchestrator(ContainerizedOrchestrator):
428
439
  f"job: {e}"
429
440
  )
430
441
  return {}
442
+
443
+ def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus:
444
+ """Refreshes the status of a specific pipeline run.
445
+
446
+ Args:
447
+ run: The run that was executed by this orchestrator.
448
+
449
+ Returns:
450
+ the actual status of the pipeline execution.
451
+
452
+ Raises:
453
+ AssertionError: If the run was not executed by this orchestrator.
454
+ ValueError: If it fetches an unknown state or if we can not fetch
455
+ the orchestrator run ID.
456
+ """
457
+ # Make sure that the stack exists and is accessible
458
+ if run.stack is None:
459
+ raise ValueError(
460
+ "The stack that the run was executed on is not available "
461
+ "anymore."
462
+ )
463
+
464
+ # Make sure that the run belongs to this orchestrator
465
+ assert (
466
+ self.id
467
+ == run.stack.components[StackComponentType.ORCHESTRATOR][0].id
468
+ )
469
+
470
+ # Initialize the AzureML client
471
+ if connector := self.get_connector():
472
+ credentials = connector.connect()
473
+ else:
474
+ credentials = DefaultAzureCredential()
475
+
476
+ ml_client = MLClient(
477
+ credential=credentials,
478
+ subscription_id=self.config.subscription_id,
479
+ resource_group_name=self.config.resource_group,
480
+ workspace_name=self.config.workspace,
481
+ )
482
+
483
+ # Fetch the status of the PipelineJob
484
+ if METADATA_ORCHESTRATOR_RUN_ID in run.run_metadata:
485
+ run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID].value
486
+ elif run.orchestrator_run_id is not None:
487
+ run_id = run.orchestrator_run_id
488
+ else:
489
+ raise ValueError(
490
+ "Can not find the orchestrator run ID, thus can not fetch "
491
+ "the status."
492
+ )
493
+ status = ml_client.jobs.get(run_id).status
494
+
495
+ # Map the potential outputs to ZenML ExecutionStatus. Potential values:
496
+ # https://learn.microsoft.com/en-us/python/api/azure-ai-ml/azure.ai.ml.entities.pipelinejob?view=azure-python#azure-ai-ml-entities-pipelinejob-status
497
+ if status in [
498
+ "NotStarted",
499
+ "Starting",
500
+ "Provisioning",
501
+ "Preparing",
502
+ "Queued",
503
+ ]:
504
+ return ExecutionStatus.INITIALIZING
505
+ elif status in ["Running", "Finalizing"]:
506
+ return ExecutionStatus.RUNNING
507
+ elif status in [
508
+ "CancelRequested",
509
+ "Failed",
510
+ "Canceled",
511
+ "NotResponding",
512
+ ]:
513
+ return ExecutionStatus.FAILED
514
+ elif status in ["Completed"]:
515
+ return ExecutionStatus.COMPLETED
516
+ else:
517
+ raise ValueError("Unknown status for the pipeline job.")
518
+
519
+ def compute_metadata(self, job: Any) -> Iterator[Dict[str, MetadataType]]:
520
+ """Generate run metadata based on the generated AzureML PipelineJob.
521
+
522
+ Args:
523
+ job: The corresponding PipelineJob object.
524
+
525
+ Yields:
526
+ A dictionary of metadata related to the pipeline run.
527
+ """
528
+ # Metadata
529
+ metadata: Dict[str, MetadataType] = {}
530
+
531
+ # Orchestrator Run ID
532
+ if run_id := self._compute_orchestrator_run_id(job):
533
+ metadata[METADATA_ORCHESTRATOR_RUN_ID] = run_id
534
+
535
+ # URL to the AzureML's pipeline view
536
+ if orchestrator_url := self._compute_orchestrator_url(job):
537
+ metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_url)
538
+
539
+ yield metadata
540
+
541
+ @staticmethod
542
+ def _compute_orchestrator_url(job: Any) -> Optional[str]:
543
+ """Generate the Orchestrator Dashboard URL upon pipeline execution.
544
+
545
+ Args:
546
+ job: The corresponding PipelineJob object.
547
+
548
+ Returns:
549
+ the URL to the dashboard view in AzureML.
550
+ """
551
+ try:
552
+ if job.studio_url:
553
+ return str(job.studio_url)
554
+
555
+ return None
556
+
557
+ except Exception as e:
558
+ logger.warning(
559
+ f"There was an issue while extracting the pipeline url: {e}"
560
+ )
561
+ return None
562
+
563
+ @staticmethod
564
+ def _compute_orchestrator_run_id(job: Any) -> Optional[str]:
565
+ """Extract the orchestrator run ID from the AzureML PipelineJob.
566
+
567
+ Args:
568
+ job: The corresponding PipelineJob object.
569
+
570
+ Returns:
571
+ the name of the AzureML job, used as the orchestrator run ID, or None if it is not available.
572
+ """
573
+ try:
574
+ if job.name:
575
+ return str(job.name)
576
+
577
+ return None
578
+
579
+ except Exception as e:
580
+ logger.warning(
581
+ f"There was an issue while extracting the pipeline run ID: {e}"
582
+ )
583
+ return None
@@ -64,6 +64,7 @@ SKYPILOT_AWS = "skypilot_aws"
64
64
  SKYPILOT_GCP = "skypilot_gcp"
65
65
  SKYPILOT_AZURE = "skypilot_azure"
66
66
  SKYPILOT_LAMBDA = "skypilot_lambda"
67
+ SKYPILOT_KUBERNETES = "skypilot_kubernetes"
67
68
  SLACK = "slack"
68
69
  SPARK = "spark"
69
70
  TEKTON = "tekton"
@@ -35,7 +35,7 @@ class DeepchecksIntegration(Integration):
35
35
 
36
36
  NAME = DEEPCHECKS
37
37
  REQUIREMENTS = [
38
- "deepchecks[vision]>=0.18.0",
38
+ "deepchecks[vision]~=0.18.0",
39
39
  "torchvision>=0.14.0",
40
40
  "opencv-python==4.5.5.64", # pin to same version
41
41
  "opencv-python-headless==4.5.5.64", # pin to same version
@@ -17,6 +17,7 @@ from typing import (
17
17
  Any,
18
18
  ClassVar,
19
19
  Dict,
20
+ List,
20
21
  Optional,
21
22
  Sequence,
22
23
  Tuple,
@@ -28,9 +29,8 @@ import pandas as pd
28
29
  from deepchecks.core.checks import BaseCheck
29
30
  from deepchecks.core.suite import SuiteResult
30
31
  from deepchecks.tabular import Dataset as TabularData
32
+ from deepchecks.tabular import ModelComparisonSuite
31
33
  from deepchecks.tabular import Suite as TabularSuite
32
-
33
- # not part of deepchecks.tabular.checks
34
34
  from deepchecks.tabular.suites import full_suite as full_tabular_suite
35
35
  from deepchecks.vision import Suite as VisionSuite
36
36
  from deepchecks.vision import VisionData
@@ -102,7 +102,7 @@ class DeepchecksDataValidator(BaseDataValidator):
102
102
  comparison_dataset: Optional[
103
103
  Union[pd.DataFrame, DataLoader[Any]]
104
104
  ] = None,
105
- model: Optional[Union[ClassifierMixin, Module]] = None,
105
+ models: Optional[List[Union[ClassifierMixin, Module]]] = None,
106
106
  check_list: Optional[Sequence[str]] = None,
107
107
  dataset_kwargs: Dict[str, Any] = {},
108
108
  check_kwargs: Dict[str, Dict[str, Any]] = {},
@@ -123,7 +123,7 @@ class DeepchecksDataValidator(BaseDataValidator):
123
123
  validation.
124
124
  comparison_dataset: Optional secondary (comparison) dataset argument
125
125
  used during comparison checks.
126
- model: Optional model argument used during validation.
126
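+ models: Optional list of models used during validation. If more than one model is supplied, they must all be of the same type and the model comparison suite is used.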
127
127
  check_list: Optional list of ZenML Deepchecks check identifiers
128
128
  specifying the list of Deepchecks checks to be performed.
129
129
  dataset_kwargs: Additional keyword arguments to be passed to the
@@ -149,6 +149,7 @@ class DeepchecksDataValidator(BaseDataValidator):
149
149
  # arguments and the check list.
150
150
  is_tabular = False
151
151
  is_vision = False
152
+ is_multi_model = False
152
153
  for dataset in [reference_dataset, comparison_dataset]:
153
154
  if dataset is None:
154
155
  continue
@@ -163,7 +164,18 @@ class DeepchecksDataValidator(BaseDataValidator):
163
164
  f"data and {str(DataLoader)} for computer vision data."
164
165
  )
165
166
 
166
- if model:
167
+ if models:
168
+ # if there is more than one model, we should set the
169
+ # is_multi_model to True
170
+ if len(models) > 1:
171
+ is_multi_model = True
172
+ # if the models are of different types, raise an error
173
+ # only the same type of models can be used for comparison
174
+ if len(set(type(model) for model in models)) > 1:
175
+ raise TypeError(
176
+ "Models used for comparison checks must be of the same type."
177
+ )
178
+ model = models[0]
167
179
  if isinstance(model, ClassifierMixin):
168
180
  is_tabular = True
169
181
  elif isinstance(model, Module):
@@ -190,8 +202,18 @@ class DeepchecksDataValidator(BaseDataValidator):
190
202
  if not check_list:
191
203
  # default to executing all the checks listed in the supplied
192
204
  # checks enum type if a custom check list is not supplied
205
+ # don't include the TABULAR_PERFORMANCE_BIAS check enum value
206
+ # as it requires a protected feature name to be set
207
+ checks_to_exclude = [
208
+ DeepchecksModelValidationCheck.TABULAR_PERFORMANCE_BIAS
209
+ ]
210
+ check_enum_values = [
211
+ check.value
212
+ for check in check_enum
213
+ if check not in checks_to_exclude
214
+ ]
193
215
  tabular_checks, vision_checks = cls._split_checks(
194
- check_enum.values()
216
+ check_enum_values
195
217
  )
196
218
  if is_tabular:
197
219
  check_list = tabular_checks
@@ -254,6 +276,10 @@ class DeepchecksDataValidator(BaseDataValidator):
254
276
  suite_class = VisionSuite
255
277
  full_suite = full_vision_suite()
256
278
 
279
+ # if is_multi_model is True, we need to use the ModelComparisonSuite
280
+ if is_multi_model:
281
+ suite_class = ModelComparisonSuite
282
+
257
283
  train_dataset = dataset_class(reference_dataset, **dataset_kwargs)
258
284
  test_dataset = None
259
285
  if comparison_dataset is not None:
@@ -294,13 +320,28 @@ class DeepchecksDataValidator(BaseDataValidator):
294
320
  continue
295
321
  condition_method(**condition_kwargs)
296
322
 
297
- suite.add(check)
298
- return suite.run(
299
- train_dataset=train_dataset,
300
- test_dataset=test_dataset,
301
- model=model,
302
- **run_kwargs,
303
- )
323
+ # if the check is supported by the suite, add it
324
+ if isinstance(check, suite.supported_checks()):
325
+ suite.add(check)
326
+ else:
327
+ logger.warning(
328
+ f"Check {check_name} is not supported by the {suite_class} "
329
+ "suite. Ignoring the check."
330
+ )
331
+
332
+ if isinstance(suite, ModelComparisonSuite):
333
+ return suite.run(
334
+ models=models,
335
+ train_datasets=train_dataset,
336
+ test_datasets=test_dataset,
337
+ )
338
+ else:
339
+ return suite.run(
340
+ train_dataset=train_dataset,
341
+ test_dataset=test_dataset,
342
+ model=models[0] if models else None,
343
+ **run_kwargs,
344
+ )
304
345
 
305
346
  def data_validation(
306
347
  self,
@@ -444,7 +485,7 @@ class DeepchecksDataValidator(BaseDataValidator):
444
485
  check_enum=check_enum,
445
486
  reference_dataset=dataset,
446
487
  comparison_dataset=comparison_dataset,
447
- model=model,
488
+ models=[model],
448
489
  check_list=check_list,
449
490
  dataset_kwargs=dataset_kwargs,
450
491
  check_kwargs=check_kwargs,
@@ -153,8 +153,8 @@ class DeepchecksDataIntegrityCheck(DeepchecksValidationCheck):
153
153
 
154
154
  This list reflects the set of data integrity checks provided by Deepchecks:
155
155
 
156
- * [for tabular data](https://docs.deepchecks.com/en/stable/checks_gallery/tabular.html#data-integrity)
157
- * [for computer vision](https://docs.deepchecks.com/en/stable/checks_gallery/vision.html#data-integrity)
156
+ * [for tabular data](https://docs.deepchecks.com/stable/tabular/auto_checks/data_integrity/index.html)
157
+ * [for computer vision](https://docs.deepchecks.com/stable/vision/auto_checks/data_integrity/index.html)
158
158
 
159
159
  All these checks inherit from `deepchecks.tabular.SingleDatasetCheck` or
160
160
  `deepchecks.vision.SingleDatasetCheck` and require a single dataset as input.
@@ -176,6 +176,9 @@ class DeepchecksDataIntegrityCheck(DeepchecksValidationCheck):
176
176
  TABULAR_FEATURE_LABEL_CORRELATION = source_utils.resolve(
177
177
  tabular_checks.FeatureLabelCorrelation
178
178
  ).import_path
179
+ TABULAR_IDENTIFIER_LABEL_CORRELATION = source_utils.resolve(
180
+ tabular_checks.IdentifierLabelCorrelation
181
+ ).import_path
179
182
  TABULAR_IS_SINGLE_VALUE = source_utils.resolve(
180
183
  tabular_checks.IsSingleValue
181
184
  ).import_path
@@ -197,6 +200,12 @@ class DeepchecksDataIntegrityCheck(DeepchecksValidationCheck):
197
200
  TABULAR_STRING_MISMATCH = source_utils.resolve(
198
201
  tabular_checks.StringMismatch
199
202
  ).import_path
203
+ TABULAR_CLASS_IMBALANCE = source_utils.resolve(
204
+ tabular_checks.ClassImbalance
205
+ ).import_path
206
+ TABULAR_PERCENT_OF_NULLS = source_utils.resolve(
207
+ tabular_checks.PercentOfNulls
208
+ ).import_path
200
209
 
201
210
  VISION_IMAGE_PROPERTY_OUTLIERS = source_utils.resolve(
202
211
  vision_checks.ImagePropertyOutliers
@@ -204,6 +213,9 @@ class DeepchecksDataIntegrityCheck(DeepchecksValidationCheck):
204
213
  VISION_LABEL_PROPERTY_OUTLIERS = source_utils.resolve(
205
214
  vision_checks.LabelPropertyOutliers
206
215
  ).import_path
216
+ VISION_PROPERTY_LABEL_CORRELATION = source_utils.resolve(
217
+ vision_checks.PropertyLabelCorrelation
218
+ ).import_path
207
219
 
208
220
 
209
221
  class DeepchecksDataDriftCheck(DeepchecksValidationCheck):
@@ -246,19 +258,37 @@ class DeepchecksDataDriftCheck(DeepchecksValidationCheck):
246
258
  TABULAR_TRAIN_TEST_FEATURE_DRIFT = source_utils.resolve(
247
259
  tabular_checks.TrainTestFeatureDrift
248
260
  ).import_path
261
+ TABULAR_FEATURE_DRIFT = source_utils.resolve(
262
+ tabular_checks.FeatureDrift
263
+ ).import_path
249
264
  TABULAR_TRAIN_TEST_LABEL_DRIFT = source_utils.resolve(
250
265
  tabular_checks.TrainTestLabelDrift
251
266
  ).import_path
267
+ TABULAR_LABEL_DRIFT = source_utils.resolve(
268
+ tabular_checks.LabelDrift
269
+ ).import_path
252
270
  TABULAR_TRAIN_TEST_SAMPLES_MIX = source_utils.resolve(
253
271
  tabular_checks.TrainTestSamplesMix
254
272
  ).import_path
255
273
  TABULAR_WHOLE_DATASET_DRIFT = source_utils.resolve(
256
274
  tabular_checks.WholeDatasetDrift
257
275
  ).import_path
276
+ TABULAR_NEW_CATEGORY_TRAIN_TEST = source_utils.resolve(
277
+ tabular_checks.NewCategoryTrainTest
278
+ ).import_path
279
+ TABULAR_MULTIVARIATE_DRIFT = source_utils.resolve(
280
+ tabular_checks.MultivariateDrift
281
+ ).import_path
258
282
 
283
+ VISION_PROPERTY_LABEL_CORRELATION_CHANGE = source_utils.resolve(
284
+ vision_checks.PropertyLabelCorrelationChange
285
+ ).import_path
259
286
  VISION_HEATMAP_COMPARISON = source_utils.resolve(
260
287
  vision_checks.HeatmapComparison
261
288
  ).import_path
289
+ VISION_LABEL_DRIFT = source_utils.resolve(
290
+ vision_checks.LabelDrift
291
+ ).import_path
262
292
  VISION_IMAGE_DATASET_DRIFT = source_utils.resolve(
263
293
  vision_checks.ImageDatasetDrift
264
294
  ).import_path
@@ -268,9 +298,6 @@ class DeepchecksDataDriftCheck(DeepchecksValidationCheck):
268
298
  VISION_NEW_LABELS = source_utils.resolve(
269
299
  vision_checks.NewLabels
270
300
  ).import_path
271
- VISION_TRAIN_TEST_LABEL_DRIFT = source_utils.resolve(
272
- vision_checks.TrainTestLabelDrift
273
- ).import_path
274
301
 
275
302
 
276
303
  class DeepchecksModelValidationCheck(DeepchecksValidationCheck):
@@ -296,6 +323,12 @@ class DeepchecksModelValidationCheck(DeepchecksValidationCheck):
296
323
  TABULAR_MODEL_INFERENCE_TIME = source_utils.resolve(
297
324
  tabular_checks.ModelInferenceTime
298
325
  ).import_path
326
+ TABULAR_MODEL_INFO = source_utils.resolve(
327
+ tabular_checks.ModelInfo
328
+ ).import_path
329
+ TABULAR_PERFORMANCE_BIAS = source_utils.resolve(
330
+ tabular_checks.model_evaluation.PerformanceBias
331
+ ).import_path
299
332
  TABULAR_REGRESSION_ERROR_DISTRIBUTION = source_utils.resolve(
300
333
  tabular_checks.RegressionErrorDistribution
301
334
  ).import_path
@@ -308,6 +341,18 @@ class DeepchecksModelValidationCheck(DeepchecksValidationCheck):
308
341
  TABULAR_SEGMENT_PERFORMANCE = source_utils.resolve(
309
342
  tabular_checks.SegmentPerformance
310
343
  ).import_path
344
+ TABULAR_WEAK_SEGMENT_PERFORMANCE = source_utils.resolve(
345
+ tabular_checks.WeakSegmentsPerformance
346
+ ).import_path
347
+ TABULAR_SINGLE_DATASET_PERFORMANCE = source_utils.resolve(
348
+ tabular_checks.SingleDatasetPerformance
349
+ ).import_path
350
+ TABULAR_TRAIN_TEST_PERFORMANCE = source_utils.resolve(
351
+ tabular_checks.TrainTestPerformance
352
+ ).import_path
353
+ TABULAR_MULTI_MODEL_PERFORMANCE_REPORT = source_utils.resolve(
354
+ tabular_checks.MultiModelPerformanceReport
355
+ ).import_path
311
356
 
312
357
  VISION_CONFUSION_MATRIX_REPORT = source_utils.resolve(
313
358
  vision_checks.ConfusionMatrixReport
@@ -318,6 +363,12 @@ class DeepchecksModelValidationCheck(DeepchecksValidationCheck):
318
363
  VISION_MEAN_AVERAGE_RECALL_REPORT = source_utils.resolve(
319
364
  vision_checks.MeanAverageRecallReport
320
365
  ).import_path
366
+ VISION_SINGLE_DATASET_PERFORMANCE = source_utils.resolve(
367
+ vision_checks.SingleDatasetPerformance
368
+ ).import_path
369
+ VISION_WEAK_SEGMENT_PERFORMANCE = source_utils.resolve(
370
+ vision_checks.WeakSegmentsPerformance
371
+ ).import_path
321
372
 
322
373
 
323
374
  class DeepchecksModelDriftCheck(DeepchecksValidationCheck):
@@ -343,6 +394,9 @@ class DeepchecksModelDriftCheck(DeepchecksValidationCheck):
343
394
  TABULAR_TRAIN_TEST_PREDICTION_DRIFT = source_utils.resolve(
344
395
  tabular_checks.TrainTestPredictionDrift
345
396
  ).import_path
397
+ TABULAR_PREDICTION_DRIFT = source_utils.resolve(
398
+ tabular_checks.PredictionDrift
399
+ ).import_path
346
400
  TABULAR_UNUSED_FEATURES = source_utils.resolve(
347
401
  tabular_checks.UnusedFeatures
348
402
  ).import_path
@@ -356,3 +410,6 @@ class DeepchecksModelDriftCheck(DeepchecksValidationCheck):
356
410
  VISION_TRAIN_TEST_PREDICTION_DRIFT = source_utils.resolve(
357
411
  vision_checks.TrainTestPredictionDrift
358
412
  ).import_path
413
+ VISION_PREDICTION_DRIFT = source_utils.resolve(
414
+ vision_checks.PredictionDrift
415
+ ).import_path