zenml-nightly 0.62.0.dev20240729__py3-none-any.whl → 0.64.0.dev20240809__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- README.md +2 -2
- RELEASE_NOTES.md +120 -0
- zenml/VERSION +1 -1
- zenml/__init__.py +0 -4
- zenml/actions/pipeline_run/pipeline_run_action.py +19 -17
- zenml/analytics/enums.py +4 -6
- zenml/cli/__init__.py +28 -76
- zenml/cli/base.py +2 -2
- zenml/cli/pipeline.py +54 -61
- zenml/cli/stack.py +6 -8
- zenml/cli/web_login.py +8 -0
- zenml/client.py +232 -103
- zenml/config/build_configuration.py +43 -17
- zenml/config/compiler.py +14 -22
- zenml/config/docker_settings.py +80 -57
- zenml/config/pipeline_run_configuration.py +3 -0
- zenml/config/server_config.py +3 -0
- zenml/config/source.py +60 -1
- zenml/constants.py +11 -2
- zenml/entrypoints/base_entrypoint_configuration.py +53 -8
- zenml/enums.py +4 -1
- zenml/environment.py +25 -9
- zenml/image_builders/base_image_builder.py +1 -1
- zenml/image_builders/build_context.py +25 -72
- zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +13 -4
- zenml/integrations/azure/__init__.py +4 -0
- zenml/integrations/azure/flavors/__init__.py +11 -0
- zenml/integrations/azure/flavors/azureml_orchestrator_flavor.py +263 -0
- zenml/{_hub → integrations/azure/orchestrators}/__init__.py +7 -2
- zenml/integrations/azure/orchestrators/azureml_orchestrator.py +544 -0
- zenml/integrations/azure/orchestrators/azureml_orchestrator_entrypoint_config.py +86 -0
- zenml/integrations/azure/step_operators/azureml_step_operator.py +3 -0
- zenml/integrations/databricks/flavors/databricks_orchestrator_flavor.py +20 -2
- zenml/integrations/databricks/orchestrators/databricks_orchestrator.py +19 -13
- zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +7 -2
- zenml/integrations/gcp/service_connectors/gcp_service_connector.py +123 -6
- zenml/integrations/kaniko/image_builders/kaniko_image_builder.py +1 -1
- zenml/integrations/mlflow/__init__.py +1 -1
- zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py +3 -1
- zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py +3 -0
- zenml/logger.py +13 -0
- zenml/models/__init__.py +26 -22
- zenml/models/v2/base/filter.py +32 -0
- zenml/models/v2/core/pipeline.py +73 -89
- zenml/models/v2/core/pipeline_build.py +15 -11
- zenml/models/v2/core/pipeline_deployment.py +72 -24
- zenml/models/v2/core/pipeline_run.py +65 -1
- zenml/models/v2/core/run_template.py +393 -0
- zenml/models/v2/core/server_settings.py +12 -0
- zenml/models/v2/core/user.py +0 -21
- zenml/models/v2/misc/server_models.py +7 -1
- zenml/models/v2/misc/stack_deployment.py +5 -0
- zenml/models/v2/misc/user_auth.py +0 -7
- zenml/new/pipelines/build_utils.py +220 -89
- zenml/new/pipelines/code_archive.py +157 -0
- zenml/new/pipelines/pipeline.py +46 -78
- zenml/new/pipelines/run_utils.py +79 -1
- zenml/post_execution/pipeline.py +1 -4
- zenml/service_connectors/service_connector_utils.py +18 -2
- zenml/stack_deployments/aws_stack_deployment.py +32 -8
- zenml/stack_deployments/azure_stack_deployment.py +122 -10
- zenml/stack_deployments/gcp_stack_deployment.py +36 -7
- zenml/stack_deployments/stack_deployment.py +23 -7
- zenml/steps/base_step.py +3 -0
- zenml/steps/utils.py +0 -4
- zenml/utils/archivable.py +149 -0
- zenml/utils/code_utils.py +244 -0
- zenml/utils/notebook_utils.py +122 -0
- zenml/utils/package_utils.py +39 -0
- zenml/utils/pipeline_docker_image_builder.py +3 -96
- zenml/utils/source_utils.py +109 -1
- zenml/zen_server/dashboard/assets/{404-B_YdvmwS.js → 404-CRAA_Lew.js} +1 -1
- zenml/zen_server/dashboard/assets/@radix-BXWm7HOa.js +85 -0
- zenml/zen_server/dashboard/assets/{@react-router-CO-OsFwI.js → @react-router-l3lMcXA2.js} +1 -1
- zenml/zen_server/dashboard/assets/{@reactflow-l_1hUr1S.js → @reactflow-CeVxyqYT.js} +2 -2
- zenml/zen_server/dashboard/assets/{@tanstack-DYiOyJUL.js → @tanstack-FmcYZMuX.js} +4 -4
- zenml/zen_server/dashboard/assets/AlertDialogDropdownItem-ErO9aOgK.js +1 -0
- zenml/zen_server/dashboard/assets/{AwarenessChannel-CFg5iX4Z.js → AwarenessChannel-CLXo5rKM.js} +1 -1
- zenml/zen_server/dashboard/assets/{CodeSnippet-Dvkx_82E.js → CodeSnippet-D0VLxT2A.js} +2 -2
- zenml/zen_server/dashboard/assets/CollapsibleCard-BaUPiVg0.js +1 -0
- zenml/zen_server/dashboard/assets/{Commands-DoN1xrEq.js → Commands-JrcZK-3j.js} +1 -1
- zenml/zen_server/dashboard/assets/CopyButton-Dbo52T1K.js +2 -0
- zenml/zen_server/dashboard/assets/{CsvVizualization-Ck-nZ43m.js → CsvVizualization-D3kAypDj.js} +3 -3
- zenml/zen_server/dashboard/assets/DisplayDate-DizbSeT-.js +1 -0
- zenml/zen_server/dashboard/assets/EditSecretDialog-Bd7mFLS4.js +1 -0
- zenml/zen_server/dashboard/assets/{EmptyState-BMLnFVlB.js → EmptyState-BHblM39I.js} +1 -1
- zenml/zen_server/dashboard/assets/{Error-kLtljEOM.js → Error-C6LeJSER.js} +1 -1
- zenml/zen_server/dashboard/assets/{ExecutionStatus-DguLLgTK.js → ExecutionStatus-jH4OrWBq.js} +1 -1
- zenml/zen_server/dashboard/assets/{Helpbox-BXUMP21n.js → Helpbox-aAB2XP-z.js} +1 -1
- zenml/zen_server/dashboard/assets/{Infobox-DSt0O-dm.js → Infobox-BQ0aty32.js} +1 -1
- zenml/zen_server/dashboard/assets/{InlineAvatar-xsrsIGE-.js → InlineAvatar-DpTLgM3Q.js} +1 -1
- zenml/zen_server/dashboard/assets/Lock-CNyJvf2r.js +1 -0
- zenml/zen_server/dashboard/assets/{MarkdownVisualization-xp3hhULl.js → MarkdownVisualization-Bajxn0HY.js} +1 -1
- zenml/zen_server/dashboard/assets/NumberBox-BmKE0qnO.js +1 -0
- zenml/zen_server/dashboard/assets/{PasswordChecker-DUveqlva.js → PasswordChecker-yGGoJSB-.js} +1 -1
- zenml/zen_server/dashboard/assets/ProviderRadio-BBqkIuTd.js +1 -0
- zenml/zen_server/dashboard/assets/RadioItem-xLhXoiFV.js +1 -0
- zenml/zen_server/dashboard/assets/SearchField-C9R0mdaX.js +1 -0
- zenml/zen_server/dashboard/assets/{SetPassword-BXGTWiwj.js → SetPassword-52sNxNiO.js} +1 -1
- zenml/zen_server/dashboard/assets/{SuccessStep-DZC60t0x.js → SuccessStep-DlkItqYG.js} +1 -1
- zenml/zen_server/dashboard/assets/Tick-uxv80Q6a.js +1 -0
- zenml/zen_server/dashboard/assets/{UpdatePasswordSchemas-DGvwFWO1.js → UpdatePasswordSchemas-oN4G3sKz.js} +1 -1
- zenml/zen_server/dashboard/assets/{aws-BgKTfTfx.js → aws-0_3UsPif.js} +1 -1
- zenml/zen_server/dashboard/assets/{check-circle-i56092KI.js → check-circle-1_I207rW.js} +1 -1
- zenml/zen_server/dashboard/assets/chevron-down-BpaF8JqM.js +1 -0
- zenml/zen_server/dashboard/assets/{chevron-right-double-CZBOf6JM.js → chevron-right-double-Dk8e2L99.js} +1 -1
- zenml/zen_server/dashboard/assets/{cloud-only-C_yFCAkP.js → cloud-only-BkUuI0lZ.js} +1 -1
- zenml/zen_server/dashboard/assets/components-Br2ezRib.js +1 -0
- zenml/zen_server/dashboard/assets/{copy-BXNk6BjL.js → copy-f3XGPPxt.js} +1 -1
- zenml/zen_server/dashboard/assets/{database-1xWSgZfO.js → database-cXYNX9tt.js} +1 -1
- zenml/zen_server/dashboard/assets/{docker-CQMVm_4d.js → docker-8uj__HHK.js} +1 -1
- zenml/zen_server/dashboard/assets/dots-horizontal-sKQlWEni.js +1 -0
- zenml/zen_server/dashboard/assets/edit-C0MVvPD2.js +1 -0
- zenml/zen_server/dashboard/assets/{file-text-CqD_iu6l.js → file-text-B9JibxTs.js} +1 -1
- zenml/zen_server/dashboard/assets/{help-bu_DgLKI.js → help-FuHlZwn0.js} +1 -1
- zenml/zen_server/dashboard/assets/{index-rK_Wuy2W.js → index-Bd1xgUQG.js} +1 -1
- zenml/zen_server/dashboard/assets/index-DaGknux4.css +1 -0
- zenml/zen_server/dashboard/assets/{index-BczVOqUf.js → index-DhIZtpxB.js} +5 -5
- zenml/zen_server/dashboard/assets/index.esm-DT4uyn2i.js +1 -0
- zenml/zen_server/dashboard/assets/layout-D6oiSbfd.js +1 -0
- zenml/zen_server/dashboard/assets/{login-mutation-CrHrndTI.js → login-mutation-13A_JSVA.js} +1 -1
- zenml/zen_server/dashboard/assets/{logs-D8k8BVFf.js → logs-CgeE2vZP.js} +1 -1
- zenml/zen_server/dashboard/assets/{not-found-DYa4pC-C.js → not-found-B0Mmb90p.js} +1 -1
- zenml/zen_server/dashboard/assets/package-DdkziX79.js +1 -0
- zenml/zen_server/dashboard/assets/page-7-v2OBm-.js +1 -0
- zenml/zen_server/dashboard/assets/{page-MFQyIJd3.js → page-B3ozwdD1.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-BkuQDIf-.js → page-BGwA9B1M.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-1iL8aMqs.js → page-BkjAUyTA.js} +1 -1
- zenml/zen_server/dashboard/assets/page-BnacgBiy.js +1 -0
- zenml/zen_server/dashboard/assets/page-BxF_KMQ3.js +2 -0
- zenml/zen_server/dashboard/assets/page-C4POHC0K.js +1 -0
- zenml/zen_server/dashboard/assets/page-C9kudd44.js +9 -0
- zenml/zen_server/dashboard/assets/page-CA1j3GpJ.js +1 -0
- zenml/zen_server/dashboard/assets/page-CCY6yfmu.js +1 -0
- zenml/zen_server/dashboard/assets/page-CgTe7Bme.js +1 -0
- zenml/zen_server/dashboard/assets/{page-8a4UMKXZ.js → page-Cgn-6v2Y.js} +1 -1
- zenml/zen_server/dashboard/assets/page-CxQmQqDw.js +1 -0
- zenml/zen_server/dashboard/assets/page-D2Goey3H.js +1 -0
- zenml/zen_server/dashboard/assets/page-DLpOnf7u.js +1 -0
- zenml/zen_server/dashboard/assets/{page-BhgCDInH.js → page-DSTQnBk-.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-1h_sD1jz.js → page-DTysUGOy.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-2grKx_MY.js → page-D_EXUFJb.js} +1 -1
- zenml/zen_server/dashboard/assets/page-Db15QzsM.js +1 -0
- zenml/zen_server/dashboard/assets/{page-BDns21Iz.js → page-DugsjcQ_.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-C6-UGEbH.js → page-OFKSPyN7.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-BkeAAYwp.js → page-RnG-qhv9.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-CCNRIt_f.js → page-T2BtjwPl.js} +1 -1
- zenml/zen_server/dashboard/assets/page-TXe1Eo3Z.js +1 -0
- zenml/zen_server/dashboard/assets/{page-BnaevhnB.js → page-YiF_fNbe.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-uA5prJGY.js → page-hQaiQXfg.js} +1 -1
- zenml/zen_server/dashboard/assets/persist-3-5nOJ6m.js +1 -0
- zenml/zen_server/dashboard/assets/{play-circle-CNtZKDnW.js → play-circle-XSkLR12B.js} +1 -1
- zenml/zen_server/dashboard/assets/plus-FB9-lEq_.js +1 -0
- zenml/zen_server/dashboard/assets/refresh-COb6KYDi.js +1 -0
- zenml/zen_server/dashboard/assets/sharedSchema-BoYx_B_L.js +14 -0
- zenml/zen_server/dashboard/assets/{stack-detail-query-Cficsl6d.js → stack-detail-query-B-US_-wa.js} +1 -1
- zenml/zen_server/dashboard/assets/{terminal-By9cErXc.js → terminal-grtjrIEJ.js} +1 -1
- zenml/zen_server/dashboard/assets/trash-Cd5CSFqA.js +1 -0
- zenml/zen_server/dashboard/assets/{update-server-settings-mutation-7d8xi1tS.js → update-server-settings-mutation-B8GB_ubU.js} +1 -1
- zenml/zen_server/dashboard/assets/{url-D7mAQGUM.js → url-hcMJkz8p.js} +1 -1
- zenml/zen_server/dashboard/assets/{zod-BhoGpZ63.js → zod-CnykDKJj.js} +1 -1
- zenml/zen_server/dashboard/index.html +7 -7
- zenml/zen_server/dashboard_legacy/asset-manifest.json +4 -4
- zenml/zen_server/dashboard_legacy/index.html +1 -1
- zenml/zen_server/dashboard_legacy/{precache-manifest.12246c7548e71e2c4438e496360de80c.js → precache-manifest.9c473c96a43298343a7ce1256183123b.js} +4 -4
- zenml/zen_server/dashboard_legacy/service-worker.js +1 -1
- zenml/zen_server/dashboard_legacy/static/js/{main.3b27024b.chunk.js → main.463c90b9.chunk.js} +2 -2
- zenml/zen_server/dashboard_legacy/static/js/{main.3b27024b.chunk.js.map → main.463c90b9.chunk.js.map} +1 -1
- zenml/zen_server/deploy/helm/Chart.yaml +1 -1
- zenml/zen_server/deploy/helm/README.md +2 -2
- zenml/zen_server/rbac/models.py +1 -0
- zenml/zen_server/rbac/utils.py +4 -0
- zenml/zen_server/routers/pipeline_builds_endpoints.py +2 -66
- zenml/zen_server/routers/pipeline_deployments_endpoints.py +2 -53
- zenml/zen_server/routers/pipelines_endpoints.py +1 -74
- zenml/zen_server/routers/run_templates_endpoints.py +212 -0
- zenml/zen_server/routers/stack_deployment_endpoints.py +6 -0
- zenml/zen_server/routers/users_endpoints.py +0 -7
- zenml/zen_server/routers/workspaces_endpoints.py +79 -0
- zenml/zen_server/{pipeline_deployment → template_execution}/runner_entrypoint_configuration.py +1 -8
- zenml/zen_server/{pipeline_deployment → template_execution}/utils.py +214 -92
- zenml/zen_server/utils.py +77 -2
- zenml/zen_server/zen_server_api.py +54 -2
- zenml/zen_stores/base_zen_store.py +7 -1
- zenml/zen_stores/migrations/versions/0.63.0_release.py +23 -0
- zenml/zen_stores/migrations/versions/0.64.0_release.py +23 -0
- zenml/zen_stores/migrations/versions/026d4577b6a0_add_code_path.py +39 -0
- zenml/zen_stores/migrations/versions/3dcc5d20e82f_add_last_user_activity.py +51 -0
- zenml/zen_stores/migrations/versions/7d1919bb1ef0_add_run_templates.py +100 -0
- zenml/zen_stores/migrations/versions/909550c7c4da_remove_user_hub_token.py +36 -0
- zenml/zen_stores/migrations/versions/b59aa68fdb1f_simplify_pipelines.py +139 -0
- zenml/zen_stores/rest_zen_store.py +112 -39
- zenml/zen_stores/schemas/__init__.py +2 -0
- zenml/zen_stores/schemas/pipeline_build_schemas.py +3 -3
- zenml/zen_stores/schemas/pipeline_deployment_schemas.py +32 -2
- zenml/zen_stores/schemas/pipeline_run_schemas.py +29 -3
- zenml/zen_stores/schemas/pipeline_schemas.py +29 -30
- zenml/zen_stores/schemas/run_template_schemas.py +264 -0
- zenml/zen_stores/schemas/server_settings_schemas.py +2 -0
- zenml/zen_stores/schemas/step_run_schemas.py +11 -4
- zenml/zen_stores/schemas/user_schemas.py +0 -2
- zenml/zen_stores/sql_zen_store.py +389 -151
- zenml/zen_stores/template_utils.py +261 -0
- zenml/zen_stores/zen_store_interface.py +93 -20
- {zenml_nightly-0.62.0.dev20240729.dist-info → zenml_nightly-0.64.0.dev20240809.dist-info}/METADATA +3 -3
- {zenml_nightly-0.62.0.dev20240729.dist-info → zenml_nightly-0.64.0.dev20240809.dist-info}/RECORD +211 -184
- zenml/_hub/client.py +0 -289
- zenml/_hub/constants.py +0 -21
- zenml/_hub/utils.py +0 -79
- zenml/cli/hub.py +0 -1116
- zenml/models/v2/core/pipeline_namespace.py +0 -113
- zenml/models/v2/misc/hub_plugin_models.py +0 -79
- zenml/new/pipelines/deserialization_utils.py +0 -292
- zenml/zen_server/dashboard/assets/@radix-CFOkMR_E.js +0 -85
- zenml/zen_server/dashboard/assets/CollapsibleCard-opiuBHHc.js +0 -1
- zenml/zen_server/dashboard/assets/CopyButton-Cr7xYEPb.js +0 -2
- zenml/zen_server/dashboard/assets/DisplayDate-DYgIjlDF.js +0 -1
- zenml/zen_server/dashboard/assets/Pagination-C6X-mifw.js +0 -1
- zenml/zen_server/dashboard/assets/index-EpMIKgrI.css +0 -1
- zenml/zen_server/dashboard/assets/index.esm-Corw4lXQ.js +0 -1
- zenml/zen_server/dashboard/assets/package-B3fWP-Dh.js +0 -1
- zenml/zen_server/dashboard/assets/page-5NCOHOsy.js +0 -1
- zenml/zen_server/dashboard/assets/page-B6h3iaHJ.js +0 -1
- zenml/zen_server/dashboard/assets/page-Bi-wtWiO.js +0 -5
- zenml/zen_server/dashboard/assets/page-Bq0YxkLV.js +0 -1
- zenml/zen_server/dashboard/assets/page-Bs2F4eoD.js +0 -2
- zenml/zen_server/dashboard/assets/page-CHNxpz3n.js +0 -1
- zenml/zen_server/dashboard/assets/page-DgorQFqi.js +0 -1
- zenml/zen_server/dashboard/assets/page-K8ebxVIs.js +0 -1
- zenml/zen_server/dashboard/assets/page-TgCF0P_U.js +0 -1
- zenml/zen_server/dashboard/assets/page-ZnCEe-eK.js +0 -9
- zenml/zen_server/dashboard/assets/persist-D7HJNBWx.js +0 -1
- zenml/zen_server/dashboard/assets/plus-C8WOyCzt.js +0 -1
- zenml/zen_server/dashboard/assets/secrets-video-OBJ6irhH.svg +0 -21
- zenml/zen_server/dashboard/assets/stacks-video-7gfxpAq4.svg +0 -21
- /zenml/zen_server/{pipeline_deployment → template_execution}/__init__.py +0 -0
- /zenml/zen_server/{pipeline_deployment → template_execution}/workload_manager_interface.py +0 -0
- {zenml_nightly-0.62.0.dev20240729.dist-info → zenml_nightly-0.64.0.dev20240809.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.62.0.dev20240729.dist-info → zenml_nightly-0.64.0.dev20240809.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.62.0.dev20240729.dist-info → zenml_nightly-0.64.0.dev20240809.dist-info}/entry_points.txt +0 -0
@@ -52,6 +52,7 @@ from zenml.orchestrators.utils import get_orchestrator_run_name
|
|
52
52
|
from zenml.orchestrators.wheeled_orchestrator import WheeledOrchestrator
|
53
53
|
from zenml.stack import StackValidator
|
54
54
|
from zenml.utils import io_utils
|
55
|
+
from zenml.utils.package_utils import clean_requirements
|
55
56
|
from zenml.utils.pipeline_docker_image_builder import (
|
56
57
|
PipelineDockerImageBuilder,
|
57
58
|
)
|
@@ -229,6 +230,9 @@ class DatabricksOrchestrator(WheeledOrchestrator):
|
|
229
230
|
ValueError: If the schedule is not set or if the cron expression
|
230
231
|
is not set.
|
231
232
|
"""
|
233
|
+
settings = cast(
|
234
|
+
DatabricksOrchestratorSettings, self.get_settings(deployment)
|
235
|
+
)
|
232
236
|
if deployment.schedule:
|
233
237
|
if (
|
234
238
|
deployment.schedule.catchup
|
@@ -246,7 +250,7 @@ class DatabricksOrchestrator(WheeledOrchestrator):
|
|
246
250
|
)
|
247
251
|
if (
|
248
252
|
deployment.schedule.cron_expression
|
249
|
-
and
|
253
|
+
and settings.schedule_timezone is None
|
250
254
|
):
|
251
255
|
raise ValueError(
|
252
256
|
"Property `schedule_timezone` must be set when passing "
|
@@ -321,7 +325,7 @@ class DatabricksOrchestrator(WheeledOrchestrator):
|
|
321
325
|
f"{deployment_id}_{step_name}",
|
322
326
|
ZENML_STEP_DEFAULT_ENTRYPOINT_COMMAND,
|
323
327
|
arguments,
|
324
|
-
requirements,
|
328
|
+
clean_requirements(requirements),
|
325
329
|
depends_on=upstream_steps,
|
326
330
|
zenml_project_wheel=zenml_project_wheel,
|
327
331
|
job_cluster_key=job_cluster_key,
|
@@ -366,7 +370,7 @@ class DatabricksOrchestrator(WheeledOrchestrator):
|
|
366
370
|
|
367
371
|
# Construct the env variables for the pipeline
|
368
372
|
env_vars = environment.copy()
|
369
|
-
spark_env_vars =
|
373
|
+
spark_env_vars = settings.spark_env_vars
|
370
374
|
if spark_env_vars:
|
371
375
|
for key, value in spark_env_vars.items():
|
372
376
|
env_vars[key] = value
|
@@ -385,6 +389,7 @@ class DatabricksOrchestrator(WheeledOrchestrator):
|
|
385
389
|
job_cluster_key = self.sanitize_name(f"{deployment_id}")
|
386
390
|
self._upload_and_run_pipeline(
|
387
391
|
pipeline_name=orchestrator_run_name,
|
392
|
+
settings=settings,
|
388
393
|
tasks=_construct_databricks_pipeline(
|
389
394
|
databricks_wheel_path, job_cluster_key
|
390
395
|
),
|
@@ -396,6 +401,7 @@ class DatabricksOrchestrator(WheeledOrchestrator):
|
|
396
401
|
def _upload_and_run_pipeline(
|
397
402
|
self,
|
398
403
|
pipeline_name: str,
|
404
|
+
settings: DatabricksOrchestratorSettings,
|
399
405
|
tasks: List[DatabricksTask],
|
400
406
|
env_vars: Dict[str, str],
|
401
407
|
job_cluster_key: str,
|
@@ -409,6 +415,7 @@ class DatabricksOrchestrator(WheeledOrchestrator):
|
|
409
415
|
env_vars: The environment variables.
|
410
416
|
job_cluster_key: The ID of the Databricks job cluster.
|
411
417
|
schedule: The schedule to run the pipeline
|
418
|
+
settings: The settings for the Databricks orchestrator.
|
412
419
|
|
413
420
|
Raises:
|
414
421
|
ValueError: If the `Job Compute` policy is not found.
|
@@ -416,12 +423,12 @@ class DatabricksOrchestrator(WheeledOrchestrator):
|
|
416
423
|
|
417
424
|
"""
|
418
425
|
databricks_client = self._get_databricks_client()
|
419
|
-
spark_conf =
|
426
|
+
spark_conf = settings.spark_conf or {}
|
420
427
|
spark_conf[
|
421
428
|
"spark.databricks.driver.dbfsLibraryInstallationAllowed"
|
422
429
|
] = "true"
|
423
430
|
|
424
|
-
policy_id =
|
431
|
+
policy_id = settings.policy_id or None
|
425
432
|
for policy in databricks_client.cluster_policies.list():
|
426
433
|
if policy.name == "Job Compute":
|
427
434
|
policy_id = policy.policy_id
|
@@ -432,17 +439,16 @@ class DatabricksOrchestrator(WheeledOrchestrator):
|
|
432
439
|
job_cluster = JobCluster(
|
433
440
|
job_cluster_key=job_cluster_key,
|
434
441
|
new_cluster=ClusterSpec(
|
435
|
-
spark_version=
|
442
|
+
spark_version=settings.spark_version
|
436
443
|
or DATABRICKS_SPARK_DEFAULT_VERSION,
|
437
|
-
num_workers=
|
438
|
-
node_type_id=
|
439
|
-
or "Standard_D4s_v5",
|
444
|
+
num_workers=settings.num_workers,
|
445
|
+
node_type_id=settings.node_type_id or "Standard_D4s_v5",
|
440
446
|
policy_id=policy_id,
|
441
447
|
autoscale=AutoScale(
|
442
|
-
min_workers=
|
443
|
-
max_workers=
|
448
|
+
min_workers=settings.autoscale[0],
|
449
|
+
max_workers=settings.autoscale[1],
|
444
450
|
),
|
445
|
-
single_user_name=
|
451
|
+
single_user_name=settings.single_user_name,
|
446
452
|
spark_env_vars=env_vars,
|
447
453
|
spark_conf=spark_conf,
|
448
454
|
workload_type=WorkloadType(
|
@@ -451,7 +457,7 @@ class DatabricksOrchestrator(WheeledOrchestrator):
|
|
451
457
|
),
|
452
458
|
)
|
453
459
|
if schedule and schedule.cron_expression:
|
454
|
-
schedule_timezone =
|
460
|
+
schedule_timezone = settings.schedule_timezone
|
455
461
|
if schedule_timezone:
|
456
462
|
databricks_schedule = CronSchedule(
|
457
463
|
quartz_cron_expression=schedule.cron_expression,
|
@@ -574,6 +574,9 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
|
|
574
574
|
run_name: Orchestrator run name.
|
575
575
|
settings: Pipeline level settings for this orchestrator.
|
576
576
|
schedule: The schedule the pipeline will run on.
|
577
|
+
|
578
|
+
Raises:
|
579
|
+
RuntimeError: If the Vertex Orchestrator fails to provision or any other Runtime errors
|
577
580
|
"""
|
578
581
|
# We have to replace the hyphens in the run name with underscores
|
579
582
|
# and lower case the string, because the Vertex AI Pipelines service
|
@@ -656,13 +659,15 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
|
|
656
659
|
run.wait()
|
657
660
|
|
658
661
|
except google_exceptions.ClientError as e:
|
659
|
-
logger.
|
660
|
-
|
662
|
+
logger.error("Failed to create the Vertex AI Pipelines job: %s", e)
|
663
|
+
raise RuntimeError(
|
664
|
+
f"Failed to create the Vertex AI Pipelines job: {e}"
|
661
665
|
)
|
662
666
|
except RuntimeError as e:
|
663
667
|
logger.error(
|
664
668
|
"The Vertex AI Pipelines job execution has failed: %s", e
|
665
669
|
)
|
670
|
+
raise
|
666
671
|
|
667
672
|
def get_orchestrator_run_id(self) -> str:
|
668
673
|
"""Returns the active orchestrator run id.
|
@@ -496,6 +496,108 @@ class GCPAuthenticationMethods(StrEnum):
|
|
496
496
|
IMPERSONATION = "impersonation"
|
497
497
|
|
498
498
|
|
499
|
+
try:
|
500
|
+
from google.auth.aws import _DefaultAwsSecurityCredentialsSupplier
|
501
|
+
|
502
|
+
class ZenMLAwsSecurityCredentialsSupplier(
|
503
|
+
_DefaultAwsSecurityCredentialsSupplier # type: ignore[misc]
|
504
|
+
):
|
505
|
+
"""An improved version of the GCP external account credential supplier for AWS.
|
506
|
+
|
507
|
+
The original GCP external account credential supplier only provides
|
508
|
+
rudimentary support for extracting AWS credentials from environment
|
509
|
+
variables or the AWS metadata service. This version improves on that by
|
510
|
+
using the boto3 library itself (if available), which uses the entire range
|
511
|
+
of implicit authentication features packed into it.
|
512
|
+
|
513
|
+
Without this improvement, `sts.AssumeRoleWithWebIdentity` authentication is
|
514
|
+
not supported for EKS pods and the EC2 attached role credentials are
|
515
|
+
used instead (see: https://medium.com/@derek10cloud/gcp-workload-identity-federation-doesnt-yet-support-eks-irsa-in-aws-a3c71877671a).
|
516
|
+
"""
|
517
|
+
|
518
|
+
def get_aws_security_credentials(
|
519
|
+
self, context: Any, request: Any
|
520
|
+
) -> gcp_aws.AwsSecurityCredentials:
|
521
|
+
"""Get the security credentials from the local environment.
|
522
|
+
|
523
|
+
This method is a copy of the original method from the
|
524
|
+
`google.auth.aws._DefaultAwsSecurityCredentialsSupplier` class. It has
|
525
|
+
been modified to use the boto3 library to extract the AWS credentials
|
526
|
+
from the local environment.
|
527
|
+
|
528
|
+
Args:
|
529
|
+
context: The context to use to get the security credentials.
|
530
|
+
request: The request to use to get the security credentials.
|
531
|
+
|
532
|
+
Returns:
|
533
|
+
The AWS temporary security credentials.
|
534
|
+
"""
|
535
|
+
try:
|
536
|
+
import boto3
|
537
|
+
|
538
|
+
session = boto3.Session()
|
539
|
+
credentials = session.get_credentials()
|
540
|
+
if credentials is not None:
|
541
|
+
creds = credentials.get_frozen_credentials()
|
542
|
+
return gcp_aws.AwsSecurityCredentials(
|
543
|
+
creds.access_key,
|
544
|
+
creds.secret_key,
|
545
|
+
creds.token,
|
546
|
+
)
|
547
|
+
except ImportError:
|
548
|
+
pass
|
549
|
+
|
550
|
+
logger.debug(
|
551
|
+
"Failed to extract AWS credentials from the local environment "
|
552
|
+
"using the boto3 library. Falling back to the original "
|
553
|
+
"implementation."
|
554
|
+
)
|
555
|
+
|
556
|
+
return super().get_aws_security_credentials(context, request)
|
557
|
+
|
558
|
+
def get_aws_region(self, context: Any, request: Any) -> str:
|
559
|
+
"""Get the AWS region from the local environment.
|
560
|
+
|
561
|
+
This method is a copy of the original method from the
|
562
|
+
`google.auth.aws._DefaultAwsSecurityCredentialsSupplier` class. It has
|
563
|
+
been modified to use the boto3 library to extract the AWS
|
564
|
+
region from the local environment.
|
565
|
+
|
566
|
+
Args:
|
567
|
+
context: The context to use to get the security credentials.
|
568
|
+
request: The request to use to get the security credentials.
|
569
|
+
|
570
|
+
Returns:
|
571
|
+
The AWS region.
|
572
|
+
"""
|
573
|
+
try:
|
574
|
+
import boto3
|
575
|
+
|
576
|
+
session = boto3.Session()
|
577
|
+
if session.region_name:
|
578
|
+
return session.region_name # type: ignore[no-any-return]
|
579
|
+
except ImportError:
|
580
|
+
pass
|
581
|
+
|
582
|
+
logger.debug(
|
583
|
+
"Failed to extract AWS region from the local environment "
|
584
|
+
"using the boto3 library. Falling back to the original "
|
585
|
+
"implementation."
|
586
|
+
)
|
587
|
+
|
588
|
+
return super().get_aws_region( # type: ignore[no-any-return]
|
589
|
+
context, request
|
590
|
+
)
|
591
|
+
|
592
|
+
except ImportError:
|
593
|
+
# The `google.auth.aws._DefaultAwsSecurityCredentialsSupplier`
|
594
|
+
# class has been introduced in the `google-auth` library version 2.29.0.
|
595
|
+
# Before that, the AWS logic was part of the `google.auth.awsCredentials`
|
596
|
+
# class itself.
|
597
|
+
ZenMLAwsSecurityCredentialsSupplier = None # type: ignore[assignment,misc]
|
598
|
+
pass
|
599
|
+
|
600
|
+
|
499
601
|
class ZenMLGCPAWSExternalAccountCredentials(gcp_aws.Credentials): # type: ignore[misc]
|
500
602
|
"""An improved version of the GCP external account credential for AWS.
|
501
603
|
|
@@ -508,6 +610,13 @@ class ZenMLGCPAWSExternalAccountCredentials(gcp_aws.Credentials): # type: ignor
|
|
508
610
|
Without this improvement, `sts.AssumeRoleWithWebIdentity` authentication is
|
509
611
|
not supported for EKS pods and the EC2 attached role credentials are
|
510
612
|
used instead (see: https://medium.com/@derek10cloud/gcp-workload-identity-federation-doesnt-yet-support-eks-irsa-in-aws-a3c71877671a).
|
613
|
+
|
614
|
+
IMPORTANT: subclassing this class only works with the `google-auth` library
|
615
|
+
version lower than 2.29.0. Starting from version 2.29.0, the AWS logic
|
616
|
+
has been moved to a separate `google.auth.aws._DefaultAwsSecurityCredentialsSupplier`
|
617
|
+
class that can be subclassed instead and supplied as the
|
618
|
+
`aws_security_credentials_supplier` parameter to the
|
619
|
+
`google.auth.aws.Credentials` class.
|
511
620
|
"""
|
512
621
|
|
513
622
|
def _get_security_credentials(
|
@@ -539,12 +648,14 @@ class ZenMLGCPAWSExternalAccountCredentials(gcp_aws.Credentials): # type: ignor
|
|
539
648
|
"secret_access_key": creds.secret_key,
|
540
649
|
"security_token": creds.token,
|
541
650
|
}
|
542
|
-
except
|
543
|
-
|
544
|
-
|
545
|
-
|
546
|
-
|
547
|
-
|
651
|
+
except ImportError:
|
652
|
+
pass
|
653
|
+
|
654
|
+
logger.debug(
|
655
|
+
"Failed to extract AWS credentials from the local environment "
|
656
|
+
"using the boto3 library. Falling back to the original "
|
657
|
+
"implementation."
|
658
|
+
)
|
548
659
|
|
549
660
|
return super()._get_security_credentials( # type: ignore[no-any-return]
|
550
661
|
request, imdsv2_session_token
|
@@ -1126,6 +1237,12 @@ class GCPServiceConnector(ServiceConnector):
|
|
1126
1237
|
account_info.get("subject_token_type")
|
1127
1238
|
== _AWS_SUBJECT_TOKEN_TYPE
|
1128
1239
|
):
|
1240
|
+
if ZenMLAwsSecurityCredentialsSupplier is not None:
|
1241
|
+
account_info["aws_security_credentials_supplier"] = (
|
1242
|
+
ZenMLAwsSecurityCredentialsSupplier(
|
1243
|
+
account_info.pop("credential_source"),
|
1244
|
+
)
|
1245
|
+
)
|
1129
1246
|
credentials = (
|
1130
1247
|
ZenMLGCPAWSExternalAccountCredentials.from_info(
|
1131
1248
|
account_info,
|
@@ -295,7 +295,7 @@ class KanikoImageBuilder(BaseImageBuilder):
|
|
295
295
|
logger.debug("Writing build context to process stdin.")
|
296
296
|
assert process.stdin
|
297
297
|
with process.stdin as _, tempfile.TemporaryFile(mode="w+b") as f:
|
298
|
-
build_context.write_archive(f,
|
298
|
+
build_context.write_archive(f, use_gzip=True)
|
299
299
|
while True:
|
300
300
|
data = f.read(1024)
|
301
301
|
if not data:
|
@@ -57,6 +57,7 @@ DATABRICKS_HOST = "DATABRICKS_HOST"
|
|
57
57
|
DATABRICKS_USERNAME = "DATABRICKS_USERNAME"
|
58
58
|
DATABRICKS_PASSWORD = "DATABRICKS_PASSWORD"
|
59
59
|
DATABRICKS_TOKEN = "DATABRICKS_TOKEN"
|
60
|
+
DATABRICKS_UNITY_CATALOG = "databricks-uc"
|
60
61
|
|
61
62
|
|
62
63
|
class MLFlowExperimentTracker(BaseExperimentTracker):
|
@@ -285,7 +286,6 @@ class MLFlowExperimentTracker(BaseExperimentTracker):
|
|
285
286
|
"""Configures the MLflow tracking URI and any additional credentials."""
|
286
287
|
tracking_uri = self.get_tracking_uri()
|
287
288
|
mlflow.set_tracking_uri(tracking_uri)
|
288
|
-
mlflow.set_registry_uri(tracking_uri)
|
289
289
|
|
290
290
|
if is_databricks_tracking_uri(tracking_uri):
|
291
291
|
if self.config.databricks_host:
|
@@ -296,6 +296,8 @@ class MLFlowExperimentTracker(BaseExperimentTracker):
|
|
296
296
|
os.environ[DATABRICKS_PASSWORD] = self.config.tracking_password
|
297
297
|
if self.config.tracking_token:
|
298
298
|
os.environ[DATABRICKS_TOKEN] = self.config.tracking_token
|
299
|
+
if self.config.enable_unity_catalog:
|
300
|
+
mlflow.set_registry_uri(DATABRICKS_UNITY_CATALOG)
|
299
301
|
else:
|
300
302
|
os.environ[MLFLOW_TRACKING_URI] = tracking_uri
|
301
303
|
if self.config.tracking_username:
|
@@ -98,6 +98,8 @@ class MLFlowExperimentTrackerConfig(
|
|
98
98
|
databricks_host: The host of the Databricks workspace with the MLflow
|
99
99
|
managed server to connect to. This is only required if
|
100
100
|
`tracking_uri` value is set to `"databricks"`.
|
101
|
+
enable_unity_catalog: If `True`, will enable the Databricks Unity Catalog for
|
102
|
+
logging and registering models.
|
101
103
|
"""
|
102
104
|
|
103
105
|
tracking_uri: Optional[str] = None
|
@@ -106,6 +108,7 @@ class MLFlowExperimentTrackerConfig(
|
|
106
108
|
tracking_token: Optional[str] = SecretField(default=None)
|
107
109
|
tracking_insecure_tls: bool = False
|
108
110
|
databricks_host: Optional[str] = None
|
111
|
+
enable_unity_catalog: bool = False
|
109
112
|
|
110
113
|
@model_validator(mode="after")
|
111
114
|
def _ensure_authentication_if_necessary(
|
zenml/logger.py
CHANGED
@@ -46,6 +46,7 @@ class CustomFormatter(logging.Formatter):
|
|
46
46
|
cyan: str = "\x1b[1;36m"
|
47
47
|
bold_red: str = "\x1b[31;1m"
|
48
48
|
purple: str = "\x1b[1;35m"
|
49
|
+
blue: str = "\x1b[34m"
|
49
50
|
reset: str = "\x1b[0m"
|
50
51
|
|
51
52
|
format_template: str = (
|
@@ -94,6 +95,18 @@ class CustomFormatter(logging.Formatter):
|
|
94
95
|
+ quoted
|
95
96
|
+ self.COLORS.get(LoggingLevels(record.levelno)),
|
96
97
|
)
|
98
|
+
|
99
|
+
# Format URLs
|
100
|
+
url_pattern = r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+"
|
101
|
+
urls = re.findall(url_pattern, formatted_message)
|
102
|
+
for url in urls:
|
103
|
+
formatted_message = formatted_message.replace(
|
104
|
+
url,
|
105
|
+
self.reset
|
106
|
+
+ self.blue
|
107
|
+
+ url
|
108
|
+
+ self.COLORS.get(LoggingLevels(record.levelno)),
|
109
|
+
)
|
97
110
|
return formatted_message
|
98
111
|
|
99
112
|
|
zenml/models/__init__.py
CHANGED
@@ -199,12 +199,7 @@ from zenml.models.v2.core.pipeline import (
|
|
199
199
|
PipelineResponse,
|
200
200
|
PipelineResponseBody,
|
201
201
|
PipelineResponseMetadata,
|
202
|
-
|
203
|
-
from zenml.models.v2.core.pipeline_namespace import (
|
204
|
-
PipelineNamespaceResponseBody,
|
205
|
-
PipelineNamespaceResponseMetadata,
|
206
|
-
PipelineNamespaceResponse,
|
207
|
-
PipelineNamespaceFilter,
|
202
|
+
PipelineResponseResources
|
208
203
|
)
|
209
204
|
from zenml.models.v2.core.pipeline_build import (
|
210
205
|
PipelineBuildBase,
|
@@ -230,6 +225,16 @@ from zenml.models.v2.core.pipeline_run import (
|
|
230
225
|
PipelineRunResponse,
|
231
226
|
PipelineRunResponseBody,
|
232
227
|
PipelineRunResponseMetadata,
|
228
|
+
PipelineRunResponseResources
|
229
|
+
)
|
230
|
+
from zenml.models.v2.core.run_template import (
|
231
|
+
RunTemplateRequest,
|
232
|
+
RunTemplateUpdate,
|
233
|
+
RunTemplateResponse,
|
234
|
+
RunTemplateResponseBody,
|
235
|
+
RunTemplateResponseMetadata,
|
236
|
+
RunTemplateResponseResources,
|
237
|
+
RunTemplateFilter,
|
233
238
|
)
|
234
239
|
from zenml.models.v2.base.base_plugin_flavor import BasePluginFlavorResponse
|
235
240
|
from zenml.models.v2.core.run_metadata import (
|
@@ -357,13 +362,6 @@ from zenml.models.v2.core.event_source import (
|
|
357
362
|
from zenml.models.v2.misc.user_auth import UserAuthModel
|
358
363
|
from zenml.models.v2.misc.build_item import BuildItem
|
359
364
|
from zenml.models.v2.misc.loaded_visualization import LoadedVisualization
|
360
|
-
from zenml.models.v2.misc.hub_plugin_models import (
|
361
|
-
HubPluginRequestModel,
|
362
|
-
HubPluginResponseModel,
|
363
|
-
HubUserResponseModel,
|
364
|
-
HubPluginBaseModel,
|
365
|
-
PluginStatus,
|
366
|
-
)
|
367
365
|
from zenml.models.v2.misc.external_user import ExternalUserModel
|
368
366
|
from zenml.models.v2.misc.auth_models import (
|
369
367
|
OAuthDeviceAuthorizationRequest,
|
@@ -424,6 +422,7 @@ ModelVersionPipelineRunResponseBody.model_rebuild()
|
|
424
422
|
OAuthDeviceResponseBody.model_rebuild()
|
425
423
|
PipelineResponseBody.model_rebuild()
|
426
424
|
PipelineResponseMetadata.model_rebuild()
|
425
|
+
PipelineResponseResources.model_rebuild()
|
427
426
|
PipelineBuildBase.model_rebuild()
|
428
427
|
PipelineBuildResponseBody.model_rebuild()
|
429
428
|
PipelineBuildResponseMetadata.model_rebuild()
|
@@ -433,6 +432,11 @@ PipelineDeploymentResponseMetadata.model_rebuild()
|
|
433
432
|
PipelineDeploymentResponseResources.model_rebuild()
|
434
433
|
PipelineRunResponseBody.model_rebuild()
|
435
434
|
PipelineRunResponseMetadata.model_rebuild()
|
435
|
+
PipelineRunResponseResources.model_rebuild()
|
436
|
+
RunTemplateResponseBody.model_rebuild()
|
437
|
+
RunTemplateResponseMetadata.model_rebuild()
|
438
|
+
RunTemplateResponseResources.model_rebuild()
|
439
|
+
RunTemplateResponseBody.model_rebuild()
|
436
440
|
RunMetadataResponseBody.model_rebuild()
|
437
441
|
RunMetadataResponseMetadata.model_rebuild()
|
438
442
|
ScheduleResponseBody.model_rebuild()
|
@@ -590,10 +594,7 @@ __all__ = [
|
|
590
594
|
"PipelineResponse",
|
591
595
|
"PipelineResponseBody",
|
592
596
|
"PipelineResponseMetadata",
|
593
|
-
"
|
594
|
-
"PipelineNamespaceResponse",
|
595
|
-
"PipelineNamespaceResponseBody",
|
596
|
-
"PipelineNamespaceResponseMetadata",
|
597
|
+
"PipelineResponseResources",
|
597
598
|
"PipelineBuildBase",
|
598
599
|
"PipelineBuildRequest",
|
599
600
|
"PipelineBuildFilter",
|
@@ -612,6 +613,14 @@ __all__ = [
|
|
612
613
|
"PipelineRunResponse",
|
613
614
|
"PipelineRunResponseBody",
|
614
615
|
"PipelineRunResponseMetadata",
|
616
|
+
"PipelineRunResponseResources",
|
617
|
+
"RunTemplateRequest",
|
618
|
+
"RunTemplateUpdate",
|
619
|
+
"RunTemplateResponse",
|
620
|
+
"RunTemplateResponseBody",
|
621
|
+
"RunTemplateResponseMetadata",
|
622
|
+
"RunTemplateResponseResources",
|
623
|
+
"RunTemplateFilter",
|
615
624
|
"RunMetadataRequest",
|
616
625
|
"RunMetadataFilter",
|
617
626
|
"RunMetadataResponse",
|
@@ -719,11 +728,6 @@ __all__ = [
|
|
719
728
|
"ExternalUserModel",
|
720
729
|
"BuildItem",
|
721
730
|
"LoadedVisualization",
|
722
|
-
"HubPluginRequestModel",
|
723
|
-
"HubPluginResponseModel",
|
724
|
-
"HubUserResponseModel",
|
725
|
-
"HubPluginBaseModel",
|
726
|
-
"PluginStatus",
|
727
731
|
"ServerModel",
|
728
732
|
"ServerDatabaseType",
|
729
733
|
"ServerDeploymentType",
|
zenml/models/v2/base/filter.py
CHANGED
@@ -36,6 +36,7 @@ from pydantic import (
|
|
36
36
|
field_validator,
|
37
37
|
model_validator,
|
38
38
|
)
|
39
|
+
from sqlalchemy import asc, desc
|
39
40
|
from sqlmodel import SQLModel
|
40
41
|
|
41
42
|
from zenml.constants import (
|
@@ -267,6 +268,7 @@ class BaseFilter(BaseModel):
|
|
267
268
|
"size",
|
268
269
|
"logical_operator",
|
269
270
|
]
|
271
|
+
CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = []
|
270
272
|
|
271
273
|
# List of fields that are not even mentioned as options in the CLI.
|
272
274
|
CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = []
|
@@ -352,6 +354,8 @@ class BaseFilter(BaseModel):
|
|
352
354
|
)
|
353
355
|
elif column in cls.model_fields:
|
354
356
|
return value
|
357
|
+
elif column in cls.CUSTOM_SORTING_OPTIONS:
|
358
|
+
return value
|
355
359
|
else:
|
356
360
|
raise ValueError(
|
357
361
|
"You can only sort by valid fields of this resource"
|
@@ -861,3 +865,31 @@ class BaseFilter(BaseModel):
|
|
861
865
|
query = query.where(filters)
|
862
866
|
|
863
867
|
return query
|
868
|
+
|
869
|
+
def apply_sorting(
|
870
|
+
self,
|
871
|
+
query: AnyQuery,
|
872
|
+
table: Type["AnySchema"],
|
873
|
+
) -> AnyQuery:
|
874
|
+
"""Apply sorting to the query.
|
875
|
+
|
876
|
+
Args:
|
877
|
+
query: The query to which to apply the sorting.
|
878
|
+
table: The query table.
|
879
|
+
|
880
|
+
Returns:
|
881
|
+
The query with sorting applied.
|
882
|
+
"""
|
883
|
+
column, operand = self.sorting_params
|
884
|
+
|
885
|
+
if operand == SorterOps.DESCENDING:
|
886
|
+
sort_clause = desc(getattr(table, column)) # type: ignore[var-annotated]
|
887
|
+
else:
|
888
|
+
sort_clause = asc(getattr(table, column))
|
889
|
+
|
890
|
+
# We always add the `id` column as a tiebreaker to ensure a stable,
|
891
|
+
# repeatable order of items, otherwise subsequent pages might contain
|
892
|
+
# the same items.
|
893
|
+
query = query.order_by(sort_clause, asc(table.id)) # type: ignore[arg-type]
|
894
|
+
|
895
|
+
return query
|