genesis-flow 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genesis_flow-1.0.0.dist-info/METADATA +822 -0
- genesis_flow-1.0.0.dist-info/RECORD +645 -0
- genesis_flow-1.0.0.dist-info/WHEEL +5 -0
- genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
- genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
- genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
- mlflow/__init__.py +367 -0
- mlflow/__main__.py +3 -0
- mlflow/ag2/__init__.py +56 -0
- mlflow/ag2/ag2_logger.py +294 -0
- mlflow/anthropic/__init__.py +40 -0
- mlflow/anthropic/autolog.py +129 -0
- mlflow/anthropic/chat.py +144 -0
- mlflow/artifacts/__init__.py +268 -0
- mlflow/autogen/__init__.py +144 -0
- mlflow/autogen/chat.py +142 -0
- mlflow/azure/__init__.py +26 -0
- mlflow/azure/auth_handler.py +257 -0
- mlflow/azure/client.py +319 -0
- mlflow/azure/config.py +120 -0
- mlflow/azure/connection_factory.py +340 -0
- mlflow/azure/exceptions.py +27 -0
- mlflow/azure/stores.py +327 -0
- mlflow/azure/utils.py +183 -0
- mlflow/bedrock/__init__.py +45 -0
- mlflow/bedrock/_autolog.py +202 -0
- mlflow/bedrock/chat.py +122 -0
- mlflow/bedrock/stream.py +160 -0
- mlflow/bedrock/utils.py +43 -0
- mlflow/cli.py +707 -0
- mlflow/client.py +12 -0
- mlflow/config/__init__.py +56 -0
- mlflow/crewai/__init__.py +79 -0
- mlflow/crewai/autolog.py +253 -0
- mlflow/crewai/chat.py +29 -0
- mlflow/data/__init__.py +75 -0
- mlflow/data/artifact_dataset_sources.py +170 -0
- mlflow/data/code_dataset_source.py +40 -0
- mlflow/data/dataset.py +123 -0
- mlflow/data/dataset_registry.py +168 -0
- mlflow/data/dataset_source.py +110 -0
- mlflow/data/dataset_source_registry.py +219 -0
- mlflow/data/delta_dataset_source.py +167 -0
- mlflow/data/digest_utils.py +108 -0
- mlflow/data/evaluation_dataset.py +562 -0
- mlflow/data/filesystem_dataset_source.py +81 -0
- mlflow/data/http_dataset_source.py +145 -0
- mlflow/data/huggingface_dataset.py +258 -0
- mlflow/data/huggingface_dataset_source.py +118 -0
- mlflow/data/meta_dataset.py +104 -0
- mlflow/data/numpy_dataset.py +223 -0
- mlflow/data/pandas_dataset.py +231 -0
- mlflow/data/polars_dataset.py +352 -0
- mlflow/data/pyfunc_dataset_mixin.py +31 -0
- mlflow/data/schema.py +76 -0
- mlflow/data/sources.py +1 -0
- mlflow/data/spark_dataset.py +406 -0
- mlflow/data/spark_dataset_source.py +74 -0
- mlflow/data/spark_delta_utils.py +118 -0
- mlflow/data/tensorflow_dataset.py +350 -0
- mlflow/data/uc_volume_dataset_source.py +81 -0
- mlflow/db.py +27 -0
- mlflow/dspy/__init__.py +17 -0
- mlflow/dspy/autolog.py +197 -0
- mlflow/dspy/callback.py +398 -0
- mlflow/dspy/constant.py +1 -0
- mlflow/dspy/load.py +93 -0
- mlflow/dspy/save.py +393 -0
- mlflow/dspy/util.py +109 -0
- mlflow/dspy/wrapper.py +226 -0
- mlflow/entities/__init__.py +104 -0
- mlflow/entities/_mlflow_object.py +52 -0
- mlflow/entities/assessment.py +545 -0
- mlflow/entities/assessment_error.py +80 -0
- mlflow/entities/assessment_source.py +141 -0
- mlflow/entities/dataset.py +92 -0
- mlflow/entities/dataset_input.py +51 -0
- mlflow/entities/dataset_summary.py +62 -0
- mlflow/entities/document.py +48 -0
- mlflow/entities/experiment.py +109 -0
- mlflow/entities/experiment_tag.py +35 -0
- mlflow/entities/file_info.py +45 -0
- mlflow/entities/input_tag.py +35 -0
- mlflow/entities/lifecycle_stage.py +35 -0
- mlflow/entities/logged_model.py +228 -0
- mlflow/entities/logged_model_input.py +26 -0
- mlflow/entities/logged_model_output.py +32 -0
- mlflow/entities/logged_model_parameter.py +46 -0
- mlflow/entities/logged_model_status.py +74 -0
- mlflow/entities/logged_model_tag.py +33 -0
- mlflow/entities/metric.py +200 -0
- mlflow/entities/model_registry/__init__.py +29 -0
- mlflow/entities/model_registry/_model_registry_entity.py +13 -0
- mlflow/entities/model_registry/model_version.py +243 -0
- mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
- mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
- mlflow/entities/model_registry/model_version_search.py +25 -0
- mlflow/entities/model_registry/model_version_stages.py +25 -0
- mlflow/entities/model_registry/model_version_status.py +35 -0
- mlflow/entities/model_registry/model_version_tag.py +35 -0
- mlflow/entities/model_registry/prompt.py +73 -0
- mlflow/entities/model_registry/prompt_version.py +244 -0
- mlflow/entities/model_registry/registered_model.py +175 -0
- mlflow/entities/model_registry/registered_model_alias.py +35 -0
- mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
- mlflow/entities/model_registry/registered_model_search.py +25 -0
- mlflow/entities/model_registry/registered_model_tag.py +35 -0
- mlflow/entities/multipart_upload.py +74 -0
- mlflow/entities/param.py +49 -0
- mlflow/entities/run.py +97 -0
- mlflow/entities/run_data.py +84 -0
- mlflow/entities/run_info.py +188 -0
- mlflow/entities/run_inputs.py +59 -0
- mlflow/entities/run_outputs.py +43 -0
- mlflow/entities/run_status.py +41 -0
- mlflow/entities/run_tag.py +36 -0
- mlflow/entities/source_type.py +31 -0
- mlflow/entities/span.py +774 -0
- mlflow/entities/span_event.py +96 -0
- mlflow/entities/span_status.py +102 -0
- mlflow/entities/trace.py +317 -0
- mlflow/entities/trace_data.py +71 -0
- mlflow/entities/trace_info.py +220 -0
- mlflow/entities/trace_info_v2.py +162 -0
- mlflow/entities/trace_location.py +173 -0
- mlflow/entities/trace_state.py +39 -0
- mlflow/entities/trace_status.py +68 -0
- mlflow/entities/view_type.py +51 -0
- mlflow/environment_variables.py +866 -0
- mlflow/evaluation/__init__.py +16 -0
- mlflow/evaluation/assessment.py +369 -0
- mlflow/evaluation/evaluation.py +411 -0
- mlflow/evaluation/evaluation_tag.py +61 -0
- mlflow/evaluation/fluent.py +48 -0
- mlflow/evaluation/utils.py +201 -0
- mlflow/exceptions.py +213 -0
- mlflow/experiments.py +140 -0
- mlflow/gemini/__init__.py +81 -0
- mlflow/gemini/autolog.py +186 -0
- mlflow/gemini/chat.py +261 -0
- mlflow/genai/__init__.py +71 -0
- mlflow/genai/datasets/__init__.py +67 -0
- mlflow/genai/datasets/evaluation_dataset.py +131 -0
- mlflow/genai/evaluation/__init__.py +3 -0
- mlflow/genai/evaluation/base.py +411 -0
- mlflow/genai/evaluation/constant.py +23 -0
- mlflow/genai/evaluation/utils.py +244 -0
- mlflow/genai/judges/__init__.py +21 -0
- mlflow/genai/judges/databricks.py +404 -0
- mlflow/genai/label_schemas/__init__.py +153 -0
- mlflow/genai/label_schemas/label_schemas.py +209 -0
- mlflow/genai/labeling/__init__.py +159 -0
- mlflow/genai/labeling/labeling.py +250 -0
- mlflow/genai/optimize/__init__.py +13 -0
- mlflow/genai/optimize/base.py +198 -0
- mlflow/genai/optimize/optimizers/__init__.py +4 -0
- mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
- mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
- mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
- mlflow/genai/optimize/types.py +75 -0
- mlflow/genai/optimize/util.py +30 -0
- mlflow/genai/prompts/__init__.py +206 -0
- mlflow/genai/scheduled_scorers.py +431 -0
- mlflow/genai/scorers/__init__.py +26 -0
- mlflow/genai/scorers/base.py +492 -0
- mlflow/genai/scorers/builtin_scorers.py +765 -0
- mlflow/genai/scorers/scorer_utils.py +138 -0
- mlflow/genai/scorers/validation.py +165 -0
- mlflow/genai/utils/data_validation.py +146 -0
- mlflow/genai/utils/enum_utils.py +23 -0
- mlflow/genai/utils/trace_utils.py +211 -0
- mlflow/groq/__init__.py +42 -0
- mlflow/groq/_groq_autolog.py +74 -0
- mlflow/johnsnowlabs/__init__.py +888 -0
- mlflow/langchain/__init__.py +24 -0
- mlflow/langchain/api_request_parallel_processor.py +330 -0
- mlflow/langchain/autolog.py +147 -0
- mlflow/langchain/chat_agent_langgraph.py +340 -0
- mlflow/langchain/constant.py +1 -0
- mlflow/langchain/constants.py +1 -0
- mlflow/langchain/databricks_dependencies.py +444 -0
- mlflow/langchain/langchain_tracer.py +597 -0
- mlflow/langchain/model.py +919 -0
- mlflow/langchain/output_parsers.py +142 -0
- mlflow/langchain/retriever_chain.py +153 -0
- mlflow/langchain/runnables.py +527 -0
- mlflow/langchain/utils/chat.py +402 -0
- mlflow/langchain/utils/logging.py +671 -0
- mlflow/langchain/utils/serialization.py +36 -0
- mlflow/legacy_databricks_cli/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/provider.py +482 -0
- mlflow/litellm/__init__.py +175 -0
- mlflow/llama_index/__init__.py +22 -0
- mlflow/llama_index/autolog.py +55 -0
- mlflow/llama_index/chat.py +43 -0
- mlflow/llama_index/constant.py +1 -0
- mlflow/llama_index/model.py +577 -0
- mlflow/llama_index/pyfunc_wrapper.py +332 -0
- mlflow/llama_index/serialize_objects.py +188 -0
- mlflow/llama_index/tracer.py +561 -0
- mlflow/metrics/__init__.py +479 -0
- mlflow/metrics/base.py +39 -0
- mlflow/metrics/genai/__init__.py +25 -0
- mlflow/metrics/genai/base.py +101 -0
- mlflow/metrics/genai/genai_metric.py +771 -0
- mlflow/metrics/genai/metric_definitions.py +450 -0
- mlflow/metrics/genai/model_utils.py +371 -0
- mlflow/metrics/genai/prompt_template.py +68 -0
- mlflow/metrics/genai/prompts/__init__.py +0 -0
- mlflow/metrics/genai/prompts/v1.py +422 -0
- mlflow/metrics/genai/utils.py +6 -0
- mlflow/metrics/metric_definitions.py +619 -0
- mlflow/mismatch.py +34 -0
- mlflow/mistral/__init__.py +34 -0
- mlflow/mistral/autolog.py +71 -0
- mlflow/mistral/chat.py +135 -0
- mlflow/ml_package_versions.py +452 -0
- mlflow/models/__init__.py +97 -0
- mlflow/models/auth_policy.py +83 -0
- mlflow/models/cli.py +354 -0
- mlflow/models/container/__init__.py +294 -0
- mlflow/models/container/scoring_server/__init__.py +0 -0
- mlflow/models/container/scoring_server/nginx.conf +39 -0
- mlflow/models/dependencies_schemas.py +287 -0
- mlflow/models/display_utils.py +158 -0
- mlflow/models/docker_utils.py +211 -0
- mlflow/models/evaluation/__init__.py +23 -0
- mlflow/models/evaluation/_shap_patch.py +64 -0
- mlflow/models/evaluation/artifacts.py +194 -0
- mlflow/models/evaluation/base.py +1811 -0
- mlflow/models/evaluation/calibration_curve.py +109 -0
- mlflow/models/evaluation/default_evaluator.py +996 -0
- mlflow/models/evaluation/deprecated.py +23 -0
- mlflow/models/evaluation/evaluator_registry.py +80 -0
- mlflow/models/evaluation/evaluators/classifier.py +704 -0
- mlflow/models/evaluation/evaluators/default.py +233 -0
- mlflow/models/evaluation/evaluators/regressor.py +96 -0
- mlflow/models/evaluation/evaluators/shap.py +296 -0
- mlflow/models/evaluation/lift_curve.py +178 -0
- mlflow/models/evaluation/utils/metric.py +123 -0
- mlflow/models/evaluation/utils/trace.py +179 -0
- mlflow/models/evaluation/validation.py +434 -0
- mlflow/models/flavor_backend.py +93 -0
- mlflow/models/flavor_backend_registry.py +53 -0
- mlflow/models/model.py +1639 -0
- mlflow/models/model_config.py +150 -0
- mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
- mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
- mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
- mlflow/models/python_api.py +369 -0
- mlflow/models/rag_signatures.py +128 -0
- mlflow/models/resources.py +321 -0
- mlflow/models/signature.py +662 -0
- mlflow/models/utils.py +2054 -0
- mlflow/models/wheeled_model.py +280 -0
- mlflow/openai/__init__.py +57 -0
- mlflow/openai/_agent_tracer.py +364 -0
- mlflow/openai/api_request_parallel_processor.py +131 -0
- mlflow/openai/autolog.py +509 -0
- mlflow/openai/constant.py +1 -0
- mlflow/openai/model.py +824 -0
- mlflow/openai/utils/chat_schema.py +367 -0
- mlflow/optuna/__init__.py +3 -0
- mlflow/optuna/storage.py +646 -0
- mlflow/plugins/__init__.py +72 -0
- mlflow/plugins/base.py +358 -0
- mlflow/plugins/builtin/__init__.py +24 -0
- mlflow/plugins/builtin/pytorch_plugin.py +150 -0
- mlflow/plugins/builtin/sklearn_plugin.py +158 -0
- mlflow/plugins/builtin/transformers_plugin.py +187 -0
- mlflow/plugins/cli.py +321 -0
- mlflow/plugins/discovery.py +340 -0
- mlflow/plugins/manager.py +465 -0
- mlflow/plugins/registry.py +316 -0
- mlflow/plugins/templates/framework_plugin_template.py +329 -0
- mlflow/prompt/constants.py +20 -0
- mlflow/prompt/promptlab_model.py +197 -0
- mlflow/prompt/registry_utils.py +248 -0
- mlflow/promptflow/__init__.py +495 -0
- mlflow/protos/__init__.py +0 -0
- mlflow/protos/assessments_pb2.py +174 -0
- mlflow/protos/databricks_artifacts_pb2.py +489 -0
- mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
- mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
- mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
- mlflow/protos/databricks_pb2.py +267 -0
- mlflow/protos/databricks_trace_server_pb2.py +374 -0
- mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
- mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
- mlflow/protos/facet_feature_statistics_pb2.py +296 -0
- mlflow/protos/internal_pb2.py +77 -0
- mlflow/protos/mlflow_artifacts_pb2.py +336 -0
- mlflow/protos/model_registry_pb2.py +1073 -0
- mlflow/protos/scalapb/__init__.py +0 -0
- mlflow/protos/scalapb/scalapb_pb2.py +104 -0
- mlflow/protos/service_pb2.py +2600 -0
- mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
- mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
- mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
- mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
- mlflow/py.typed +0 -0
- mlflow/pydantic_ai/__init__.py +57 -0
- mlflow/pydantic_ai/autolog.py +173 -0
- mlflow/pyfunc/__init__.py +3844 -0
- mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
- mlflow/pyfunc/backend.py +523 -0
- mlflow/pyfunc/context.py +78 -0
- mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
- mlflow/pyfunc/loaders/__init__.py +7 -0
- mlflow/pyfunc/loaders/chat_agent.py +117 -0
- mlflow/pyfunc/loaders/chat_model.py +125 -0
- mlflow/pyfunc/loaders/code_model.py +31 -0
- mlflow/pyfunc/loaders/responses_agent.py +112 -0
- mlflow/pyfunc/mlserver.py +46 -0
- mlflow/pyfunc/model.py +1473 -0
- mlflow/pyfunc/scoring_server/__init__.py +604 -0
- mlflow/pyfunc/scoring_server/app.py +7 -0
- mlflow/pyfunc/scoring_server/client.py +146 -0
- mlflow/pyfunc/spark_model_cache.py +48 -0
- mlflow/pyfunc/stdin_server.py +44 -0
- mlflow/pyfunc/utils/__init__.py +3 -0
- mlflow/pyfunc/utils/data_validation.py +224 -0
- mlflow/pyfunc/utils/environment.py +22 -0
- mlflow/pyfunc/utils/input_converter.py +47 -0
- mlflow/pyfunc/utils/serving_data_parser.py +11 -0
- mlflow/pytorch/__init__.py +1171 -0
- mlflow/pytorch/_lightning_autolog.py +580 -0
- mlflow/pytorch/_pytorch_autolog.py +50 -0
- mlflow/pytorch/pickle_module.py +35 -0
- mlflow/rfunc/__init__.py +42 -0
- mlflow/rfunc/backend.py +134 -0
- mlflow/runs.py +89 -0
- mlflow/server/__init__.py +302 -0
- mlflow/server/auth/__init__.py +1224 -0
- mlflow/server/auth/__main__.py +4 -0
- mlflow/server/auth/basic_auth.ini +6 -0
- mlflow/server/auth/cli.py +11 -0
- mlflow/server/auth/client.py +537 -0
- mlflow/server/auth/config.py +34 -0
- mlflow/server/auth/db/__init__.py +0 -0
- mlflow/server/auth/db/cli.py +18 -0
- mlflow/server/auth/db/migrations/__init__.py +0 -0
- mlflow/server/auth/db/migrations/alembic.ini +110 -0
- mlflow/server/auth/db/migrations/env.py +76 -0
- mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
- mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
- mlflow/server/auth/db/models.py +67 -0
- mlflow/server/auth/db/utils.py +37 -0
- mlflow/server/auth/entities.py +165 -0
- mlflow/server/auth/logo.py +14 -0
- mlflow/server/auth/permissions.py +65 -0
- mlflow/server/auth/routes.py +18 -0
- mlflow/server/auth/sqlalchemy_store.py +263 -0
- mlflow/server/graphql/__init__.py +0 -0
- mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
- mlflow/server/graphql/graphql_custom_scalars.py +24 -0
- mlflow/server/graphql/graphql_errors.py +15 -0
- mlflow/server/graphql/graphql_no_batching.py +89 -0
- mlflow/server/graphql/graphql_schema_extensions.py +74 -0
- mlflow/server/handlers.py +3217 -0
- mlflow/server/prometheus_exporter.py +17 -0
- mlflow/server/validation.py +30 -0
- mlflow/shap/__init__.py +691 -0
- mlflow/sklearn/__init__.py +1994 -0
- mlflow/sklearn/utils.py +1041 -0
- mlflow/smolagents/__init__.py +66 -0
- mlflow/smolagents/autolog.py +139 -0
- mlflow/smolagents/chat.py +29 -0
- mlflow/store/__init__.py +10 -0
- mlflow/store/_unity_catalog/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/constants.py +2 -0
- mlflow/store/_unity_catalog/registry/__init__.py +6 -0
- mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
- mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
- mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
- mlflow/store/_unity_catalog/registry/utils.py +121 -0
- mlflow/store/artifact/__init__.py +0 -0
- mlflow/store/artifact/artifact_repo.py +472 -0
- mlflow/store/artifact/artifact_repository_registry.py +154 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
- mlflow/store/artifact/cli.py +141 -0
- mlflow/store/artifact/cloud_artifact_repo.py +332 -0
- mlflow/store/artifact/databricks_artifact_repo.py +729 -0
- mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
- mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
- mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
- mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
- mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
- mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
- mlflow/store/artifact/ftp_artifact_repo.py +132 -0
- mlflow/store/artifact/gcs_artifact_repo.py +296 -0
- mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
- mlflow/store/artifact/http_artifact_repo.py +218 -0
- mlflow/store/artifact/local_artifact_repo.py +142 -0
- mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
- mlflow/store/artifact/models_artifact_repo.py +259 -0
- mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
- mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
- mlflow/store/artifact/r2_artifact_repo.py +70 -0
- mlflow/store/artifact/runs_artifact_repo.py +265 -0
- mlflow/store/artifact/s3_artifact_repo.py +330 -0
- mlflow/store/artifact/sftp_artifact_repo.py +141 -0
- mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
- mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
- mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
- mlflow/store/artifact/utils/__init__.py +0 -0
- mlflow/store/artifact/utils/models.py +148 -0
- mlflow/store/db/__init__.py +0 -0
- mlflow/store/db/base_sql_model.py +3 -0
- mlflow/store/db/db_types.py +10 -0
- mlflow/store/db/utils.py +314 -0
- mlflow/store/db_migrations/__init__.py +0 -0
- mlflow/store/db_migrations/alembic.ini +74 -0
- mlflow/store/db_migrations/env.py +84 -0
- mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
- mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
- mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
- mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
- mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
- mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
- mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
- mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
- mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
- mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
- mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
- mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
- mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
- mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
- mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
- mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
- mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
- mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
- mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
- mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
- mlflow/store/db_migrations/versions/__init__.py +0 -0
- mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
- mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
- mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
- mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
- mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
- mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
- mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
- mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
- mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
- mlflow/store/entities/__init__.py +3 -0
- mlflow/store/entities/paged_list.py +18 -0
- mlflow/store/model_registry/__init__.py +10 -0
- mlflow/store/model_registry/abstract_store.py +1081 -0
- mlflow/store/model_registry/base_rest_store.py +44 -0
- mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
- mlflow/store/model_registry/dbmodels/__init__.py +0 -0
- mlflow/store/model_registry/dbmodels/models.py +206 -0
- mlflow/store/model_registry/file_store.py +1091 -0
- mlflow/store/model_registry/rest_store.py +481 -0
- mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
- mlflow/store/tracking/__init__.py +23 -0
- mlflow/store/tracking/abstract_store.py +816 -0
- mlflow/store/tracking/dbmodels/__init__.py +0 -0
- mlflow/store/tracking/dbmodels/initial_models.py +243 -0
- mlflow/store/tracking/dbmodels/models.py +1073 -0
- mlflow/store/tracking/file_store.py +2438 -0
- mlflow/store/tracking/postgres_managed_identity.py +146 -0
- mlflow/store/tracking/rest_store.py +1131 -0
- mlflow/store/tracking/sqlalchemy_store.py +2785 -0
- mlflow/system_metrics/__init__.py +61 -0
- mlflow/system_metrics/metrics/__init__.py +0 -0
- mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
- mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
- mlflow/system_metrics/metrics/disk_monitor.py +21 -0
- mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
- mlflow/system_metrics/metrics/network_monitor.py +34 -0
- mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
- mlflow/system_metrics/system_metrics_monitor.py +198 -0
- mlflow/tracing/__init__.py +16 -0
- mlflow/tracing/assessment.py +356 -0
- mlflow/tracing/client.py +531 -0
- mlflow/tracing/config.py +125 -0
- mlflow/tracing/constant.py +105 -0
- mlflow/tracing/destination.py +81 -0
- mlflow/tracing/display/__init__.py +40 -0
- mlflow/tracing/display/display_handler.py +196 -0
- mlflow/tracing/export/async_export_queue.py +186 -0
- mlflow/tracing/export/inference_table.py +138 -0
- mlflow/tracing/export/mlflow_v3.py +137 -0
- mlflow/tracing/export/utils.py +70 -0
- mlflow/tracing/fluent.py +1417 -0
- mlflow/tracing/processor/base_mlflow.py +199 -0
- mlflow/tracing/processor/inference_table.py +175 -0
- mlflow/tracing/processor/mlflow_v3.py +47 -0
- mlflow/tracing/processor/otel.py +73 -0
- mlflow/tracing/provider.py +487 -0
- mlflow/tracing/trace_manager.py +200 -0
- mlflow/tracing/utils/__init__.py +616 -0
- mlflow/tracing/utils/artifact_utils.py +28 -0
- mlflow/tracing/utils/copy.py +55 -0
- mlflow/tracing/utils/environment.py +55 -0
- mlflow/tracing/utils/exception.py +21 -0
- mlflow/tracing/utils/once.py +35 -0
- mlflow/tracing/utils/otlp.py +63 -0
- mlflow/tracing/utils/processor.py +54 -0
- mlflow/tracing/utils/search.py +292 -0
- mlflow/tracing/utils/timeout.py +250 -0
- mlflow/tracing/utils/token.py +19 -0
- mlflow/tracing/utils/truncation.py +124 -0
- mlflow/tracing/utils/warning.py +76 -0
- mlflow/tracking/__init__.py +39 -0
- mlflow/tracking/_model_registry/__init__.py +1 -0
- mlflow/tracking/_model_registry/client.py +764 -0
- mlflow/tracking/_model_registry/fluent.py +853 -0
- mlflow/tracking/_model_registry/registry.py +67 -0
- mlflow/tracking/_model_registry/utils.py +251 -0
- mlflow/tracking/_tracking_service/__init__.py +0 -0
- mlflow/tracking/_tracking_service/client.py +883 -0
- mlflow/tracking/_tracking_service/registry.py +56 -0
- mlflow/tracking/_tracking_service/utils.py +275 -0
- mlflow/tracking/artifact_utils.py +179 -0
- mlflow/tracking/client.py +5900 -0
- mlflow/tracking/context/__init__.py +0 -0
- mlflow/tracking/context/abstract_context.py +35 -0
- mlflow/tracking/context/databricks_cluster_context.py +15 -0
- mlflow/tracking/context/databricks_command_context.py +15 -0
- mlflow/tracking/context/databricks_job_context.py +49 -0
- mlflow/tracking/context/databricks_notebook_context.py +41 -0
- mlflow/tracking/context/databricks_repo_context.py +43 -0
- mlflow/tracking/context/default_context.py +51 -0
- mlflow/tracking/context/git_context.py +32 -0
- mlflow/tracking/context/registry.py +98 -0
- mlflow/tracking/context/system_environment_context.py +15 -0
- mlflow/tracking/default_experiment/__init__.py +1 -0
- mlflow/tracking/default_experiment/abstract_context.py +43 -0
- mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
- mlflow/tracking/default_experiment/registry.py +75 -0
- mlflow/tracking/fluent.py +3595 -0
- mlflow/tracking/metric_value_conversion_utils.py +93 -0
- mlflow/tracking/multimedia.py +206 -0
- mlflow/tracking/registry.py +86 -0
- mlflow/tracking/request_auth/__init__.py +0 -0
- mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
- mlflow/tracking/request_auth/registry.py +60 -0
- mlflow/tracking/request_header/__init__.py +0 -0
- mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
- mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
- mlflow/tracking/request_header/default_request_header_provider.py +17 -0
- mlflow/tracking/request_header/registry.py +79 -0
- mlflow/transformers/__init__.py +2982 -0
- mlflow/transformers/flavor_config.py +258 -0
- mlflow/transformers/hub_utils.py +83 -0
- mlflow/transformers/llm_inference_utils.py +468 -0
- mlflow/transformers/model_io.py +301 -0
- mlflow/transformers/peft.py +51 -0
- mlflow/transformers/signature.py +183 -0
- mlflow/transformers/torch_utils.py +55 -0
- mlflow/types/__init__.py +21 -0
- mlflow/types/agent.py +270 -0
- mlflow/types/chat.py +240 -0
- mlflow/types/llm.py +935 -0
- mlflow/types/responses.py +139 -0
- mlflow/types/responses_helpers.py +416 -0
- mlflow/types/schema.py +1505 -0
- mlflow/types/type_hints.py +647 -0
- mlflow/types/utils.py +753 -0
- mlflow/utils/__init__.py +283 -0
- mlflow/utils/_capture_modules.py +256 -0
- mlflow/utils/_capture_transformers_modules.py +75 -0
- mlflow/utils/_spark_utils.py +201 -0
- mlflow/utils/_unity_catalog_oss_utils.py +97 -0
- mlflow/utils/_unity_catalog_utils.py +479 -0
- mlflow/utils/annotations.py +218 -0
- mlflow/utils/arguments_utils.py +16 -0
- mlflow/utils/async_logging/__init__.py +1 -0
- mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
- mlflow/utils/async_logging/async_logging_queue.py +366 -0
- mlflow/utils/async_logging/run_artifact.py +38 -0
- mlflow/utils/async_logging/run_batch.py +58 -0
- mlflow/utils/async_logging/run_operations.py +49 -0
- mlflow/utils/autologging_utils/__init__.py +737 -0
- mlflow/utils/autologging_utils/client.py +432 -0
- mlflow/utils/autologging_utils/config.py +33 -0
- mlflow/utils/autologging_utils/events.py +294 -0
- mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
- mlflow/utils/autologging_utils/metrics_queue.py +71 -0
- mlflow/utils/autologging_utils/safety.py +1104 -0
- mlflow/utils/autologging_utils/versioning.py +95 -0
- mlflow/utils/checkpoint_utils.py +206 -0
- mlflow/utils/class_utils.py +6 -0
- mlflow/utils/cli_args.py +257 -0
- mlflow/utils/conda.py +354 -0
- mlflow/utils/credentials.py +231 -0
- mlflow/utils/data_utils.py +17 -0
- mlflow/utils/databricks_utils.py +1436 -0
- mlflow/utils/docstring_utils.py +477 -0
- mlflow/utils/doctor.py +133 -0
- mlflow/utils/download_cloud_file_chunk.py +43 -0
- mlflow/utils/env_manager.py +16 -0
- mlflow/utils/env_pack.py +131 -0
- mlflow/utils/environment.py +1009 -0
- mlflow/utils/exception_utils.py +14 -0
- mlflow/utils/file_utils.py +978 -0
- mlflow/utils/git_utils.py +77 -0
- mlflow/utils/gorilla.py +797 -0
- mlflow/utils/import_hooks/__init__.py +363 -0
- mlflow/utils/lazy_load.py +51 -0
- mlflow/utils/logging_utils.py +168 -0
- mlflow/utils/mime_type_utils.py +58 -0
- mlflow/utils/mlflow_tags.py +103 -0
- mlflow/utils/model_utils.py +486 -0
- mlflow/utils/name_utils.py +346 -0
- mlflow/utils/nfs_on_spark.py +62 -0
- mlflow/utils/openai_utils.py +164 -0
- mlflow/utils/os.py +12 -0
- mlflow/utils/oss_registry_utils.py +29 -0
- mlflow/utils/plugins.py +17 -0
- mlflow/utils/process.py +182 -0
- mlflow/utils/promptlab_utils.py +146 -0
- mlflow/utils/proto_json_utils.py +743 -0
- mlflow/utils/pydantic_utils.py +54 -0
- mlflow/utils/request_utils.py +279 -0
- mlflow/utils/requirements_utils.py +704 -0
- mlflow/utils/rest_utils.py +673 -0
- mlflow/utils/search_logged_model_utils.py +127 -0
- mlflow/utils/search_utils.py +2111 -0
- mlflow/utils/secure_loading.py +221 -0
- mlflow/utils/security_validation.py +384 -0
- mlflow/utils/server_cli_utils.py +61 -0
- mlflow/utils/spark_utils.py +15 -0
- mlflow/utils/string_utils.py +138 -0
- mlflow/utils/thread_utils.py +63 -0
- mlflow/utils/time.py +54 -0
- mlflow/utils/timeout.py +42 -0
- mlflow/utils/uri.py +572 -0
- mlflow/utils/validation.py +662 -0
- mlflow/utils/virtualenv.py +458 -0
- mlflow/utils/warnings_utils.py +25 -0
- mlflow/utils/yaml_utils.py +179 -0
- mlflow/version.py +24 -0
@@ -0,0 +1,1740 @@
|
|
1
|
+
import base64
|
2
|
+
import functools
|
3
|
+
import json
|
4
|
+
import logging
|
5
|
+
import os
|
6
|
+
import re
|
7
|
+
import shutil
|
8
|
+
from contextlib import contextmanager
|
9
|
+
from dataclasses import dataclass
|
10
|
+
from typing import Optional, Union
|
11
|
+
|
12
|
+
import google.protobuf.empty_pb2
|
13
|
+
|
14
|
+
import mlflow
|
15
|
+
from mlflow.entities import Run
|
16
|
+
from mlflow.entities.logged_model import LoggedModel
|
17
|
+
from mlflow.entities.model_registry.prompt import Prompt
|
18
|
+
from mlflow.entities.model_registry.prompt_version import PromptVersion
|
19
|
+
from mlflow.exceptions import MlflowException, RestException
|
20
|
+
from mlflow.protos.databricks_pb2 import (
|
21
|
+
INTERNAL_ERROR,
|
22
|
+
INVALID_PARAMETER_VALUE,
|
23
|
+
RESOURCE_DOES_NOT_EXIST,
|
24
|
+
ErrorCode,
|
25
|
+
)
|
26
|
+
from mlflow.protos.databricks_uc_registry_messages_pb2 import (
|
27
|
+
MODEL_VERSION_OPERATION_READ_WRITE,
|
28
|
+
CreateModelVersionRequest,
|
29
|
+
CreateModelVersionResponse,
|
30
|
+
CreateRegisteredModelRequest,
|
31
|
+
CreateRegisteredModelResponse,
|
32
|
+
DeleteModelVersionRequest,
|
33
|
+
DeleteModelVersionResponse,
|
34
|
+
DeleteModelVersionTagRequest,
|
35
|
+
DeleteModelVersionTagResponse,
|
36
|
+
DeleteRegisteredModelAliasRequest,
|
37
|
+
DeleteRegisteredModelAliasResponse,
|
38
|
+
DeleteRegisteredModelRequest,
|
39
|
+
DeleteRegisteredModelResponse,
|
40
|
+
DeleteRegisteredModelTagRequest,
|
41
|
+
DeleteRegisteredModelTagResponse,
|
42
|
+
Entity,
|
43
|
+
FinalizeModelVersionRequest,
|
44
|
+
FinalizeModelVersionResponse,
|
45
|
+
GenerateTemporaryModelVersionCredentialsRequest,
|
46
|
+
GenerateTemporaryModelVersionCredentialsResponse,
|
47
|
+
GetModelVersionByAliasRequest,
|
48
|
+
GetModelVersionByAliasResponse,
|
49
|
+
GetModelVersionDownloadUriRequest,
|
50
|
+
GetModelVersionDownloadUriResponse,
|
51
|
+
GetModelVersionRequest,
|
52
|
+
GetModelVersionResponse,
|
53
|
+
GetRegisteredModelRequest,
|
54
|
+
GetRegisteredModelResponse,
|
55
|
+
Job,
|
56
|
+
Lineage,
|
57
|
+
LineageHeaderInfo,
|
58
|
+
Notebook,
|
59
|
+
SearchModelVersionsRequest,
|
60
|
+
SearchModelVersionsResponse,
|
61
|
+
SearchRegisteredModelsRequest,
|
62
|
+
SearchRegisteredModelsResponse,
|
63
|
+
Securable,
|
64
|
+
SetModelVersionTagRequest,
|
65
|
+
SetModelVersionTagResponse,
|
66
|
+
SetRegisteredModelAliasRequest,
|
67
|
+
SetRegisteredModelAliasResponse,
|
68
|
+
SetRegisteredModelTagRequest,
|
69
|
+
SetRegisteredModelTagResponse,
|
70
|
+
StorageMode,
|
71
|
+
Table,
|
72
|
+
TemporaryCredentials,
|
73
|
+
UpdateModelVersionRequest,
|
74
|
+
UpdateModelVersionResponse,
|
75
|
+
UpdateRegisteredModelRequest,
|
76
|
+
UpdateRegisteredModelResponse,
|
77
|
+
)
|
78
|
+
from mlflow.protos.databricks_uc_registry_service_pb2 import UcModelRegistryService
|
79
|
+
from mlflow.protos.service_pb2 import GetRun, MlflowService
|
80
|
+
from mlflow.protos.unity_catalog_prompt_messages_pb2 import (
|
81
|
+
CreatePromptRequest,
|
82
|
+
CreatePromptVersionRequest,
|
83
|
+
DeletePromptAliasRequest,
|
84
|
+
DeletePromptRequest,
|
85
|
+
DeletePromptTagRequest,
|
86
|
+
DeletePromptVersionRequest,
|
87
|
+
DeletePromptVersionTagRequest,
|
88
|
+
GetPromptRequest,
|
89
|
+
GetPromptVersionByAliasRequest,
|
90
|
+
GetPromptVersionRequest,
|
91
|
+
LinkPromptsToTracesRequest,
|
92
|
+
LinkPromptVersionsToModelsRequest,
|
93
|
+
LinkPromptVersionsToRunsRequest,
|
94
|
+
PromptVersionLinkEntry,
|
95
|
+
SearchPromptsRequest,
|
96
|
+
SearchPromptsResponse,
|
97
|
+
SearchPromptVersionsRequest,
|
98
|
+
SearchPromptVersionsResponse,
|
99
|
+
SetPromptAliasRequest,
|
100
|
+
SetPromptTagRequest,
|
101
|
+
SetPromptVersionTagRequest,
|
102
|
+
UnityCatalogSchema,
|
103
|
+
UpdatePromptRequest,
|
104
|
+
UpdatePromptVersionRequest,
|
105
|
+
)
|
106
|
+
from mlflow.protos.unity_catalog_prompt_messages_pb2 import (
|
107
|
+
Prompt as ProtoPrompt,
|
108
|
+
)
|
109
|
+
from mlflow.protos.unity_catalog_prompt_messages_pb2 import (
|
110
|
+
PromptVersion as ProtoPromptVersion,
|
111
|
+
)
|
112
|
+
from mlflow.protos.unity_catalog_prompt_service_pb2 import UnityCatalogPromptService
|
113
|
+
from mlflow.store._unity_catalog.lineage.constants import (
|
114
|
+
_DATABRICKS_LINEAGE_ID_HEADER,
|
115
|
+
_DATABRICKS_ORG_ID_HEADER,
|
116
|
+
)
|
117
|
+
from mlflow.store._unity_catalog.registry.utils import (
|
118
|
+
mlflow_tags_to_proto,
|
119
|
+
mlflow_tags_to_proto_version_tags,
|
120
|
+
proto_info_to_mlflow_prompt_info,
|
121
|
+
proto_to_mlflow_prompt,
|
122
|
+
)
|
123
|
+
from mlflow.store.artifact.databricks_sdk_models_artifact_repo import (
|
124
|
+
DatabricksSDKModelsArtifactRepository,
|
125
|
+
)
|
126
|
+
from mlflow.store.artifact.presigned_url_artifact_repo import (
|
127
|
+
PresignedUrlArtifactRepository,
|
128
|
+
)
|
129
|
+
from mlflow.store.entities.paged_list import PagedList
|
130
|
+
from mlflow.store.model_registry.rest_store import BaseRestStore
|
131
|
+
from mlflow.utils._spark_utils import _get_active_spark_session
|
132
|
+
from mlflow.utils._unity_catalog_utils import (
|
133
|
+
get_artifact_repo_from_storage_info,
|
134
|
+
get_full_name_from_sc,
|
135
|
+
is_databricks_sdk_models_artifact_repository_enabled,
|
136
|
+
model_version_from_uc_proto,
|
137
|
+
model_version_search_from_uc_proto,
|
138
|
+
registered_model_from_uc_proto,
|
139
|
+
registered_model_search_from_uc_proto,
|
140
|
+
uc_model_version_tag_from_mlflow_tags,
|
141
|
+
uc_registered_model_tag_from_mlflow_tags,
|
142
|
+
)
|
143
|
+
from mlflow.utils.databricks_utils import (
|
144
|
+
_print_databricks_deployment_job_url,
|
145
|
+
get_databricks_host_creds,
|
146
|
+
is_databricks_uri,
|
147
|
+
)
|
148
|
+
from mlflow.utils.mlflow_tags import (
|
149
|
+
MLFLOW_DATABRICKS_JOB_ID,
|
150
|
+
MLFLOW_DATABRICKS_JOB_RUN_ID,
|
151
|
+
MLFLOW_DATABRICKS_NOTEBOOK_ID,
|
152
|
+
)
|
153
|
+
from mlflow.utils.proto_json_utils import message_to_json, parse_dict
|
154
|
+
from mlflow.utils.rest_utils import (
|
155
|
+
_REST_API_PATH_PREFIX,
|
156
|
+
call_endpoint,
|
157
|
+
extract_all_api_info_for_service,
|
158
|
+
extract_api_info_for_service,
|
159
|
+
http_request,
|
160
|
+
verify_rest_response,
|
161
|
+
)
|
162
|
+
from mlflow.utils.uri import is_fuse_or_uc_volumes_uri
|
163
|
+
|
164
|
+
_TRACKING_METHOD_TO_INFO = extract_api_info_for_service(MlflowService, _REST_API_PATH_PREFIX)
|
165
|
+
_METHOD_TO_INFO = {
|
166
|
+
**extract_api_info_for_service(UcModelRegistryService, _REST_API_PATH_PREFIX),
|
167
|
+
**extract_api_info_for_service(UnityCatalogPromptService, _REST_API_PATH_PREFIX),
|
168
|
+
}
|
169
|
+
_METHOD_TO_ALL_INFO = {
|
170
|
+
**extract_all_api_info_for_service(UcModelRegistryService, _REST_API_PATH_PREFIX),
|
171
|
+
**extract_all_api_info_for_service(UnityCatalogPromptService, _REST_API_PATH_PREFIX),
|
172
|
+
}
|
173
|
+
|
174
|
+
_logger = logging.getLogger(__name__)
|
175
|
+
_DELTA_TABLE = "delta_table"
|
176
|
+
_MAX_LINEAGE_DATA_SOURCES = 10
|
177
|
+
|
178
|
+
# Pre-compiled regex patterns for better performance in search operations
|
179
|
+
_CATALOG_PATTERN = re.compile(r"catalog\s*=\s*['\"]([^'\"]+)['\"]", re.IGNORECASE)
|
180
|
+
_SCHEMA_PATTERN = re.compile(r"schema\s*=\s*['\"]([^'\"]+)['\"]", re.IGNORECASE)
|
181
|
+
|
182
|
+
|
183
|
+
@dataclass
|
184
|
+
class _CatalogSchemaFilter:
|
185
|
+
"""Internal class to hold parsed catalog, schema, and remaining filter."""
|
186
|
+
|
187
|
+
catalog_name: str
|
188
|
+
schema_name: str
|
189
|
+
remaining_filter: Optional[str]
|
190
|
+
|
191
|
+
|
192
|
+
def _require_arg_unspecified(arg_name, arg_value, default_values=None, message=None):
|
193
|
+
default_values = [None] if default_values is None else default_values
|
194
|
+
if arg_value not in default_values:
|
195
|
+
_raise_unsupported_arg(arg_name, message)
|
196
|
+
|
197
|
+
|
198
|
+
def _raise_unsupported_arg(arg_name, message=None):
|
199
|
+
messages = [
|
200
|
+
f"Argument '{arg_name}' is unsupported for models in the Unity Catalog.",
|
201
|
+
]
|
202
|
+
if message is not None:
|
203
|
+
messages.append(message)
|
204
|
+
raise MlflowException(" ".join(messages))
|
205
|
+
|
206
|
+
|
207
|
+
def _raise_unsupported_method(method, message=None):
|
208
|
+
messages = [
|
209
|
+
f"Method '{method}' is unsupported for models in the Unity Catalog.",
|
210
|
+
]
|
211
|
+
if message is not None:
|
212
|
+
messages.append(message)
|
213
|
+
raise MlflowException(" ".join(messages))
|
214
|
+
|
215
|
+
|
216
|
+
def _load_model(local_model_dir):
|
217
|
+
# Import Model here instead of in the top level, to avoid circular import; the
|
218
|
+
# mlflow.models.model module imports from MLflow tracking, which triggers an import of
|
219
|
+
# this file during store registry initialization
|
220
|
+
from mlflow.models.model import Model
|
221
|
+
|
222
|
+
try:
|
223
|
+
return Model.load(local_model_dir)
|
224
|
+
except Exception as e:
|
225
|
+
raise MlflowException(
|
226
|
+
"Unable to load model metadata. Ensure the source path of the model "
|
227
|
+
"being registered points to a valid MLflow model directory "
|
228
|
+
"(see https://mlflow.org/docs/latest/models.html#storage-format) containing a "
|
229
|
+
"model signature (https://mlflow.org/docs/latest/models.html#model-signature) "
|
230
|
+
"specifying both input and output type specifications."
|
231
|
+
) from e
|
232
|
+
|
233
|
+
|
234
|
+
def get_feature_dependencies(model_dir):
|
235
|
+
"""
|
236
|
+
Gets the features which a model depends on. This functionality is only implemented on
|
237
|
+
Databricks. In OSS mlflow, the dependencies are always empty ("").
|
238
|
+
"""
|
239
|
+
model = _load_model(model_dir)
|
240
|
+
model_info = model.get_model_info()
|
241
|
+
if (
|
242
|
+
model_info.flavors.get("python_function", {}).get("loader_module")
|
243
|
+
== mlflow.models.model._DATABRICKS_FS_LOADER_MODULE
|
244
|
+
):
|
245
|
+
raise MlflowException(
|
246
|
+
"This model was packaged by Databricks Feature Store and can only be registered on a "
|
247
|
+
"Databricks cluster."
|
248
|
+
)
|
249
|
+
return ""
|
250
|
+
|
251
|
+
|
252
|
+
def get_model_version_dependencies(model_dir):
|
253
|
+
"""
|
254
|
+
Gets the specified dependencies for a particular model version and formats them
|
255
|
+
to be passed into CreateModelVersion.
|
256
|
+
"""
|
257
|
+
from mlflow.models.resources import ResourceType
|
258
|
+
|
259
|
+
model = _load_model(model_dir)
|
260
|
+
model_info = model.get_model_info()
|
261
|
+
dependencies = []
|
262
|
+
|
263
|
+
# Try to get model.auth_policy.system_auth_policy.resources. If that is not found or empty,
|
264
|
+
# then use model.resources.
|
265
|
+
if model.auth_policy:
|
266
|
+
databricks_resources = model.auth_policy.get("system_auth_policy", {}).get("resources", {})
|
267
|
+
else:
|
268
|
+
databricks_resources = model.resources
|
269
|
+
|
270
|
+
if databricks_resources:
|
271
|
+
databricks_dependencies = databricks_resources.get("databricks", {})
|
272
|
+
dependencies.extend(
|
273
|
+
_fetch_langchain_dependency_from_model_resources(
|
274
|
+
databricks_dependencies,
|
275
|
+
ResourceType.VECTOR_SEARCH_INDEX.value,
|
276
|
+
"DATABRICKS_VECTOR_INDEX",
|
277
|
+
)
|
278
|
+
)
|
279
|
+
dependencies.extend(
|
280
|
+
_fetch_langchain_dependency_from_model_resources(
|
281
|
+
databricks_dependencies,
|
282
|
+
ResourceType.SERVING_ENDPOINT.value,
|
283
|
+
"DATABRICKS_MODEL_ENDPOINT",
|
284
|
+
)
|
285
|
+
)
|
286
|
+
dependencies.extend(
|
287
|
+
_fetch_langchain_dependency_from_model_resources(
|
288
|
+
databricks_dependencies,
|
289
|
+
ResourceType.FUNCTION.value,
|
290
|
+
"DATABRICKS_UC_FUNCTION",
|
291
|
+
)
|
292
|
+
)
|
293
|
+
dependencies.extend(
|
294
|
+
_fetch_langchain_dependency_from_model_resources(
|
295
|
+
databricks_dependencies,
|
296
|
+
ResourceType.UC_CONNECTION.value,
|
297
|
+
"DATABRICKS_UC_CONNECTION",
|
298
|
+
)
|
299
|
+
)
|
300
|
+
dependencies.extend(
|
301
|
+
_fetch_langchain_dependency_from_model_resources(
|
302
|
+
databricks_dependencies,
|
303
|
+
ResourceType.TABLE.value,
|
304
|
+
"DATABRICKS_TABLE",
|
305
|
+
)
|
306
|
+
)
|
307
|
+
else:
|
308
|
+
# These types of dependencies are required for old models that didn't use
|
309
|
+
# resources so they can be registered correctly to UC
|
310
|
+
_DATABRICKS_VECTOR_SEARCH_INDEX_NAME_KEY = "databricks_vector_search_index_name"
|
311
|
+
_DATABRICKS_EMBEDDINGS_ENDPOINT_NAME_KEY = "databricks_embeddings_endpoint_name"
|
312
|
+
_DATABRICKS_LLM_ENDPOINT_NAME_KEY = "databricks_llm_endpoint_name"
|
313
|
+
_DATABRICKS_CHAT_ENDPOINT_NAME_KEY = "databricks_chat_endpoint_name"
|
314
|
+
_DB_DEPENDENCY_KEY = "databricks_dependency"
|
315
|
+
|
316
|
+
databricks_dependencies = model_info.flavors.get("langchain", {}).get(
|
317
|
+
_DB_DEPENDENCY_KEY, {}
|
318
|
+
)
|
319
|
+
|
320
|
+
index_names = _fetch_langchain_dependency_from_model_info(
|
321
|
+
databricks_dependencies, _DATABRICKS_VECTOR_SEARCH_INDEX_NAME_KEY
|
322
|
+
)
|
323
|
+
for index_name in index_names:
|
324
|
+
dependencies.append({"type": "DATABRICKS_VECTOR_INDEX", "name": index_name})
|
325
|
+
for key in (
|
326
|
+
_DATABRICKS_EMBEDDINGS_ENDPOINT_NAME_KEY,
|
327
|
+
_DATABRICKS_LLM_ENDPOINT_NAME_KEY,
|
328
|
+
_DATABRICKS_CHAT_ENDPOINT_NAME_KEY,
|
329
|
+
):
|
330
|
+
endpoint_names = _fetch_langchain_dependency_from_model_info(
|
331
|
+
databricks_dependencies, key
|
332
|
+
)
|
333
|
+
for endpoint_name in endpoint_names:
|
334
|
+
dependencies.append({"type": "DATABRICKS_MODEL_ENDPOINT", "name": endpoint_name})
|
335
|
+
return dependencies
|
336
|
+
|
337
|
+
|
338
|
+
def _fetch_langchain_dependency_from_model_resources(databricks_dependencies, key, resource_type):
|
339
|
+
dependencies = databricks_dependencies.get(key, [])
|
340
|
+
deps = []
|
341
|
+
for dependency in dependencies:
|
342
|
+
if dependency.get("on_behalf_of_user", False):
|
343
|
+
continue
|
344
|
+
deps.append({"type": resource_type, "name": dependency["name"]})
|
345
|
+
return deps
|
346
|
+
|
347
|
+
|
348
|
+
def _fetch_langchain_dependency_from_model_info(databricks_dependencies, key):
|
349
|
+
return databricks_dependencies.get(key, [])
|
350
|
+
|
351
|
+
|
352
|
+
class UcModelRegistryStore(BaseRestStore):
|
353
|
+
"""
|
354
|
+
Client for a remote model registry server accessed via REST API calls
|
355
|
+
|
356
|
+
Args:
|
357
|
+
store_uri: URI with scheme 'databricks-uc'
|
358
|
+
tracking_uri: URI of the Databricks MLflow tracking server from which to fetch
|
359
|
+
run info and download run artifacts, when creating new model
|
360
|
+
versions from source artifacts logged to an MLflow run.
|
361
|
+
"""
|
362
|
+
|
363
|
+
def __init__(self, store_uri, tracking_uri):
|
364
|
+
super().__init__(get_host_creds=functools.partial(get_databricks_host_creds, store_uri))
|
365
|
+
self.tracking_uri = tracking_uri
|
366
|
+
self.get_tracking_host_creds = functools.partial(get_databricks_host_creds, tracking_uri)
|
367
|
+
try:
|
368
|
+
self.spark = _get_active_spark_session()
|
369
|
+
except Exception:
|
370
|
+
pass
|
371
|
+
|
372
|
+
def _get_response_from_method(self, method):
|
373
|
+
method_to_response = {
|
374
|
+
CreateRegisteredModelRequest: CreateRegisteredModelResponse,
|
375
|
+
UpdateRegisteredModelRequest: UpdateRegisteredModelResponse,
|
376
|
+
DeleteRegisteredModelRequest: DeleteRegisteredModelResponse,
|
377
|
+
CreateModelVersionRequest: CreateModelVersionResponse,
|
378
|
+
FinalizeModelVersionRequest: FinalizeModelVersionResponse,
|
379
|
+
UpdateModelVersionRequest: UpdateModelVersionResponse,
|
380
|
+
DeleteModelVersionRequest: DeleteModelVersionResponse,
|
381
|
+
GetModelVersionDownloadUriRequest: GetModelVersionDownloadUriResponse,
|
382
|
+
SearchModelVersionsRequest: SearchModelVersionsResponse,
|
383
|
+
GetRegisteredModelRequest: GetRegisteredModelResponse,
|
384
|
+
GetModelVersionRequest: GetModelVersionResponse,
|
385
|
+
SearchRegisteredModelsRequest: SearchRegisteredModelsResponse,
|
386
|
+
GenerateTemporaryModelVersionCredentialsRequest: (
|
387
|
+
GenerateTemporaryModelVersionCredentialsResponse
|
388
|
+
),
|
389
|
+
GetRun: GetRun.Response,
|
390
|
+
SetRegisteredModelAliasRequest: SetRegisteredModelAliasResponse,
|
391
|
+
DeleteRegisteredModelAliasRequest: DeleteRegisteredModelAliasResponse,
|
392
|
+
SetRegisteredModelTagRequest: SetRegisteredModelTagResponse,
|
393
|
+
DeleteRegisteredModelTagRequest: DeleteRegisteredModelTagResponse,
|
394
|
+
SetModelVersionTagRequest: SetModelVersionTagResponse,
|
395
|
+
DeleteModelVersionTagRequest: DeleteModelVersionTagResponse,
|
396
|
+
GetModelVersionByAliasRequest: GetModelVersionByAliasResponse,
|
397
|
+
CreatePromptRequest: ProtoPrompt,
|
398
|
+
SearchPromptsRequest: SearchPromptsResponse,
|
399
|
+
DeletePromptRequest: google.protobuf.empty_pb2.Empty,
|
400
|
+
SetPromptTagRequest: google.protobuf.empty_pb2.Empty,
|
401
|
+
DeletePromptTagRequest: google.protobuf.empty_pb2.Empty,
|
402
|
+
CreatePromptVersionRequest: ProtoPromptVersion,
|
403
|
+
GetPromptVersionRequest: ProtoPromptVersion,
|
404
|
+
DeletePromptVersionRequest: google.protobuf.empty_pb2.Empty,
|
405
|
+
GetPromptVersionByAliasRequest: ProtoPromptVersion,
|
406
|
+
UpdatePromptRequest: ProtoPrompt,
|
407
|
+
GetPromptRequest: ProtoPrompt,
|
408
|
+
SearchPromptVersionsRequest: SearchPromptVersionsResponse,
|
409
|
+
SetPromptAliasRequest: google.protobuf.empty_pb2.Empty,
|
410
|
+
DeletePromptAliasRequest: google.protobuf.empty_pb2.Empty,
|
411
|
+
SetPromptVersionTagRequest: google.protobuf.empty_pb2.Empty,
|
412
|
+
DeletePromptVersionTagRequest: google.protobuf.empty_pb2.Empty,
|
413
|
+
UpdatePromptVersionRequest: ProtoPromptVersion,
|
414
|
+
LinkPromptVersionsToModelsRequest: google.protobuf.empty_pb2.Empty,
|
415
|
+
LinkPromptsToTracesRequest: google.protobuf.empty_pb2.Empty,
|
416
|
+
LinkPromptVersionsToRunsRequest: google.protobuf.empty_pb2.Empty,
|
417
|
+
}
|
418
|
+
return method_to_response[method]()
|
419
|
+
|
420
|
+
def _get_endpoint_from_method(self, method):
|
421
|
+
return _METHOD_TO_INFO[method]
|
422
|
+
|
423
|
+
def _get_all_endpoints_from_method(self, method):
|
424
|
+
return _METHOD_TO_ALL_INFO[method]
|
425
|
+
|
426
|
+
# CRUD API for RegisteredModel objects
|
427
|
+
|
428
|
+
def create_registered_model(self, name, tags=None, description=None, deployment_job_id=None):
|
429
|
+
"""
|
430
|
+
Create a new registered model in backend store.
|
431
|
+
|
432
|
+
Args:
|
433
|
+
name: Name of the new model. This is expected to be unique in the backend store.
|
434
|
+
tags: A list of :py:class:`mlflow.entities.model_registry.RegisteredModelTag`
|
435
|
+
instances associated with this registered model.
|
436
|
+
description: Description of the model.
|
437
|
+
deployment_job_id: Optional deployment job id.
|
438
|
+
|
439
|
+
Returns:
|
440
|
+
A single object of :py:class:`mlflow.entities.model_registry.RegisteredModel`
|
441
|
+
created in the backend.
|
442
|
+
|
443
|
+
"""
|
444
|
+
full_name = get_full_name_from_sc(name, self.spark)
|
445
|
+
req_body = message_to_json(
|
446
|
+
CreateRegisteredModelRequest(
|
447
|
+
name=full_name,
|
448
|
+
description=description,
|
449
|
+
tags=uc_registered_model_tag_from_mlflow_tags(tags),
|
450
|
+
deployment_job_id=str(deployment_job_id) if deployment_job_id else None,
|
451
|
+
)
|
452
|
+
)
|
453
|
+
try:
|
454
|
+
response_proto = self._call_endpoint(CreateRegisteredModelRequest, req_body)
|
455
|
+
except RestException as e:
|
456
|
+
|
457
|
+
def reraise_with_legacy_hint(exception, legacy_hint):
|
458
|
+
new_message = exception.message.rstrip(".") + f". {legacy_hint}"
|
459
|
+
raise MlflowException(
|
460
|
+
message=new_message,
|
461
|
+
error_code=exception.error_code,
|
462
|
+
)
|
463
|
+
|
464
|
+
if "specify all three levels" in e.message:
|
465
|
+
# The exception is likely due to the user trying to create a registered model
|
466
|
+
# in Unity Catalog without specifying a 3-level name (catalog.schema.model).
|
467
|
+
# The user may not be intending to use the Unity Catalog Model Registry at all,
|
468
|
+
# but rather the legacy Workspace Model Registry. Accordingly, we re-raise with
|
469
|
+
# a hint
|
470
|
+
legacy_hint = (
|
471
|
+
"If you are trying to use the legacy Workspace Model Registry, instead of the"
|
472
|
+
" recommended Unity Catalog Model Registry, set the Model Registry URI to"
|
473
|
+
" 'databricks' (legacy) instead of 'databricks-uc' (recommended)."
|
474
|
+
)
|
475
|
+
reraise_with_legacy_hint(exception=e, legacy_hint=legacy_hint)
|
476
|
+
elif "METASTORE_DOES_NOT_EXIST" in e.message:
|
477
|
+
legacy_hint = (
|
478
|
+
"If you are trying to use the Model Registry in a Databricks workspace that"
|
479
|
+
" does not have Unity Catalog enabled, either enable Unity Catalog in the"
|
480
|
+
" workspace (recommended) or set the Model Registry URI to 'databricks' to"
|
481
|
+
" use the legacy Workspace Model Registry."
|
482
|
+
)
|
483
|
+
reraise_with_legacy_hint(exception=e, legacy_hint=legacy_hint)
|
484
|
+
else:
|
485
|
+
raise
|
486
|
+
|
487
|
+
if deployment_job_id:
|
488
|
+
_print_databricks_deployment_job_url(
|
489
|
+
model_name=full_name,
|
490
|
+
job_id=str(deployment_job_id),
|
491
|
+
)
|
492
|
+
return registered_model_from_uc_proto(response_proto.registered_model)
|
493
|
+
|
494
|
+
def update_registered_model(self, name, description=None, deployment_job_id=None):
|
495
|
+
"""
|
496
|
+
Update description of the registered model.
|
497
|
+
|
498
|
+
Args:
|
499
|
+
name: Registered model name.
|
500
|
+
description: New description.
|
501
|
+
deployment_job_id: Optional deployment job id.
|
502
|
+
|
503
|
+
Returns:
|
504
|
+
A single updated :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
|
505
|
+
"""
|
506
|
+
full_name = get_full_name_from_sc(name, self.spark)
|
507
|
+
req_body = message_to_json(
|
508
|
+
UpdateRegisteredModelRequest(
|
509
|
+
name=full_name,
|
510
|
+
description=description,
|
511
|
+
deployment_job_id=str(deployment_job_id) if deployment_job_id else None,
|
512
|
+
)
|
513
|
+
)
|
514
|
+
response_proto = self._call_endpoint(UpdateRegisteredModelRequest, req_body)
|
515
|
+
if deployment_job_id:
|
516
|
+
_print_databricks_deployment_job_url(
|
517
|
+
model_name=full_name,
|
518
|
+
job_id=str(deployment_job_id),
|
519
|
+
)
|
520
|
+
return registered_model_from_uc_proto(response_proto.registered_model)
|
521
|
+
|
522
|
+
def rename_registered_model(self, name, new_name):
|
523
|
+
"""
|
524
|
+
Rename the registered model.
|
525
|
+
|
526
|
+
Args:
|
527
|
+
name: Registered model name.
|
528
|
+
new_name: New proposed name.
|
529
|
+
|
530
|
+
Returns:
|
531
|
+
A single updated :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
|
532
|
+
"""
|
533
|
+
full_name = get_full_name_from_sc(name, self.spark)
|
534
|
+
req_body = message_to_json(UpdateRegisteredModelRequest(name=full_name, new_name=new_name))
|
535
|
+
response_proto = self._call_endpoint(UpdateRegisteredModelRequest, req_body)
|
536
|
+
return registered_model_from_uc_proto(response_proto.registered_model)
|
537
|
+
|
538
|
+
def delete_registered_model(self, name):
|
539
|
+
"""
|
540
|
+
Delete the registered model.
|
541
|
+
Backend raises exception if a registered model with given name does not exist.
|
542
|
+
|
543
|
+
Args:
|
544
|
+
name: Registered model name.
|
545
|
+
|
546
|
+
Returns:
|
547
|
+
None
|
548
|
+
"""
|
549
|
+
full_name = get_full_name_from_sc(name, self.spark)
|
550
|
+
req_body = message_to_json(DeleteRegisteredModelRequest(name=full_name))
|
551
|
+
self._call_endpoint(DeleteRegisteredModelRequest, req_body)
|
552
|
+
|
553
|
+
def search_registered_models(
|
554
|
+
self, filter_string=None, max_results=None, order_by=None, page_token=None
|
555
|
+
):
|
556
|
+
"""
|
557
|
+
Search for registered models in backend that satisfy the filter criteria.
|
558
|
+
|
559
|
+
Args:
|
560
|
+
filter_string: Filter query string, defaults to searching all registered models.
|
561
|
+
max_results: Maximum number of registered models desired.
|
562
|
+
order_by: List of column names with ASC|DESC annotation, to be used for ordering
|
563
|
+
matching search results.
|
564
|
+
page_token: Token specifying the next page of results. It should be obtained from
|
565
|
+
a ``search_registered_models`` call.
|
566
|
+
|
567
|
+
Returns:
|
568
|
+
A PagedList of :py:class:`mlflow.entities.model_registry.RegisteredModel` objects
|
569
|
+
that satisfy the search expressions. The pagination token for the next page can be
|
570
|
+
obtained via the ``token`` attribute of the object.
|
571
|
+
|
572
|
+
"""
|
573
|
+
_require_arg_unspecified("filter_string", filter_string)
|
574
|
+
_require_arg_unspecified("order_by", order_by)
|
575
|
+
req_body = message_to_json(
|
576
|
+
SearchRegisteredModelsRequest(
|
577
|
+
max_results=max_results,
|
578
|
+
page_token=page_token,
|
579
|
+
)
|
580
|
+
)
|
581
|
+
response_proto = self._call_endpoint(SearchRegisteredModelsRequest, req_body)
|
582
|
+
registered_models = [
|
583
|
+
registered_model_search_from_uc_proto(registered_model)
|
584
|
+
for registered_model in response_proto.registered_models
|
585
|
+
]
|
586
|
+
return PagedList(registered_models, response_proto.next_page_token)
|
587
|
+
|
588
|
+
def get_registered_model(self, name):
|
589
|
+
"""
|
590
|
+
Get registered model instance by name.
|
591
|
+
|
592
|
+
Args:
|
593
|
+
name: Registered model name.
|
594
|
+
|
595
|
+
Returns:
|
596
|
+
A single :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
|
597
|
+
"""
|
598
|
+
full_name = get_full_name_from_sc(name, self.spark)
|
599
|
+
req_body = message_to_json(GetRegisteredModelRequest(name=full_name))
|
600
|
+
response_proto = self._call_endpoint(GetRegisteredModelRequest, req_body)
|
601
|
+
return registered_model_from_uc_proto(response_proto.registered_model)
|
602
|
+
|
603
|
+
def get_latest_versions(self, name, stages=None):
|
604
|
+
"""
|
605
|
+
Latest version models for each requested stage. If no ``stages`` argument is provided,
|
606
|
+
returns the latest version for each stage.
|
607
|
+
|
608
|
+
Args:
|
609
|
+
name: Registered model name.
|
610
|
+
stages: List of desired stages. If input list is None, return latest versions for
|
611
|
+
each stage.
|
612
|
+
|
613
|
+
Returns:
|
614
|
+
List of :py:class:`mlflow.entities.model_registry.ModelVersion` objects.
|
615
|
+
"""
|
616
|
+
alias_doc_url = "https://mlflow.org/docs/latest/model-registry.html#deploy-and-organize-models-with-aliases-and-tags"
|
617
|
+
if stages is None:
|
618
|
+
message = (
|
619
|
+
"To load the latest version of a model in Unity Catalog, you can "
|
620
|
+
"set an alias on the model version and load it by alias. See "
|
621
|
+
f"{alias_doc_url} for details."
|
622
|
+
)
|
623
|
+
else:
|
624
|
+
message = (
|
625
|
+
f"Detected attempt to load latest model version in stages {stages}. "
|
626
|
+
"You may see this error because:\n"
|
627
|
+
"1) You're attempting to load a model version by stage. Setting stages "
|
628
|
+
"and loading model versions by stage is unsupported in Unity Catalog. Instead, "
|
629
|
+
"use aliases for flexible model deployment. See "
|
630
|
+
f"{alias_doc_url} for details.\n"
|
631
|
+
"2) You're attempting to load a model version by alias. Use "
|
632
|
+
"syntax 'models:/your_model_name@your_alias_name'\n"
|
633
|
+
"3) You're attempting load a model version by version number. Verify "
|
634
|
+
"that the version number is a valid integer"
|
635
|
+
)
|
636
|
+
|
637
|
+
_raise_unsupported_method(
|
638
|
+
method="get_latest_versions",
|
639
|
+
message=message,
|
640
|
+
)
|
641
|
+
|
642
|
+
def set_registered_model_tag(self, name, tag):
|
643
|
+
"""
|
644
|
+
Set a tag for the registered model.
|
645
|
+
|
646
|
+
Args:
|
647
|
+
name: Registered model name.
|
648
|
+
tag: :py:class:`mlflow.entities.model_registry.RegisteredModelTag` instance to log.
|
649
|
+
|
650
|
+
Returns:
|
651
|
+
None
|
652
|
+
"""
|
653
|
+
full_name = get_full_name_from_sc(name, self.spark)
|
654
|
+
req_body = message_to_json(
|
655
|
+
SetRegisteredModelTagRequest(name=full_name, key=tag.key, value=tag.value)
|
656
|
+
)
|
657
|
+
self._call_endpoint(SetRegisteredModelTagRequest, req_body)
|
658
|
+
|
659
|
+
def delete_registered_model_tag(self, name, key):
|
660
|
+
"""
|
661
|
+
Delete a tag associated with the registered model.
|
662
|
+
|
663
|
+
Args:
|
664
|
+
name: Registered model name.
|
665
|
+
key: Registered model tag key.
|
666
|
+
|
667
|
+
Returns:
|
668
|
+
None
|
669
|
+
"""
|
670
|
+
full_name = get_full_name_from_sc(name, self.spark)
|
671
|
+
req_body = message_to_json(DeleteRegisteredModelTagRequest(name=full_name, key=key))
|
672
|
+
self._call_endpoint(DeleteRegisteredModelTagRequest, req_body)
|
673
|
+
|
674
|
+
# CRUD API for ModelVersion objects
|
675
|
+
def _finalize_model_version(self, name, version):
|
676
|
+
"""
|
677
|
+
Finalize a UC model version after its files have been written to managed storage,
|
678
|
+
updating its status from PENDING_REGISTRATION to READY
|
679
|
+
|
680
|
+
Args:
|
681
|
+
name: Registered model name
|
682
|
+
version: Model version number
|
683
|
+
|
684
|
+
Returns:
|
685
|
+
Protobuf ModelVersion describing the finalized model version
|
686
|
+
"""
|
687
|
+
req_body = message_to_json(FinalizeModelVersionRequest(name=name, version=version))
|
688
|
+
return self._call_endpoint(FinalizeModelVersionRequest, req_body).model_version
|
689
|
+
|
690
|
+
def _get_temporary_model_version_write_credentials(self, name, version) -> TemporaryCredentials:
|
691
|
+
"""
|
692
|
+
Get temporary credentials for uploading model version files
|
693
|
+
|
694
|
+
Args:
|
695
|
+
name: Registered model name.
|
696
|
+
version: Model version number.
|
697
|
+
|
698
|
+
Returns:
|
699
|
+
mlflow.protos.databricks_uc_registry_messages_pb2.TemporaryCredentials containing
|
700
|
+
temporary model version credentials.
|
701
|
+
"""
|
702
|
+
req_body = message_to_json(
|
703
|
+
GenerateTemporaryModelVersionCredentialsRequest(
|
704
|
+
name=name, version=version, operation=MODEL_VERSION_OPERATION_READ_WRITE
|
705
|
+
)
|
706
|
+
)
|
707
|
+
return self._call_endpoint(
|
708
|
+
GenerateTemporaryModelVersionCredentialsRequest, req_body
|
709
|
+
).credentials
|
710
|
+
|
711
|
+
def _get_run_and_headers(self, run_id):
|
712
|
+
if run_id is None or not is_databricks_uri(self.tracking_uri):
|
713
|
+
return None, None
|
714
|
+
host_creds = self.get_tracking_host_creds()
|
715
|
+
endpoint, method = _TRACKING_METHOD_TO_INFO[GetRun]
|
716
|
+
response = http_request(
|
717
|
+
host_creds=host_creds,
|
718
|
+
endpoint=endpoint,
|
719
|
+
method=method,
|
720
|
+
params={"run_id": run_id},
|
721
|
+
)
|
722
|
+
try:
|
723
|
+
verify_rest_response(response, endpoint)
|
724
|
+
except MlflowException:
|
725
|
+
_logger.warning(
|
726
|
+
f"Unable to fetch model version's source run (with ID {run_id}) "
|
727
|
+
"from tracking server. The source run may be deleted or inaccessible to the "
|
728
|
+
"current user. No run link will be recorded for the model version."
|
729
|
+
)
|
730
|
+
return None, None
|
731
|
+
headers = response.headers
|
732
|
+
js_dict = response.json()
|
733
|
+
parsed_response = GetRun.Response()
|
734
|
+
parse_dict(js_dict=js_dict, message=parsed_response)
|
735
|
+
run = Run.from_proto(parsed_response.run)
|
736
|
+
return headers, run
|
737
|
+
|
738
|
+
def _get_workspace_id(self, headers):
|
739
|
+
if headers is None or _DATABRICKS_ORG_ID_HEADER not in headers:
|
740
|
+
_logger.warning(
|
741
|
+
"Unable to get model version source run's workspace ID from request headers. "
|
742
|
+
"No run link will be recorded for the model version"
|
743
|
+
)
|
744
|
+
return None
|
745
|
+
return headers[_DATABRICKS_ORG_ID_HEADER]
|
746
|
+
|
747
|
+
def _get_notebook_id(self, run):
|
748
|
+
if run is None:
|
749
|
+
return None
|
750
|
+
return run.data.tags.get(MLFLOW_DATABRICKS_NOTEBOOK_ID, None)
|
751
|
+
|
752
|
+
def _get_job_id(self, run):
|
753
|
+
if run is None:
|
754
|
+
return None
|
755
|
+
return run.data.tags.get(MLFLOW_DATABRICKS_JOB_ID, None)
|
756
|
+
|
757
|
+
def _get_job_run_id(self, run):
|
758
|
+
if run is None:
|
759
|
+
return None
|
760
|
+
return run.data.tags.get(MLFLOW_DATABRICKS_JOB_RUN_ID, None)
|
761
|
+
|
762
|
+
def _get_lineage_input_sources(self, run):
    """
    Collect Unity Catalog table securables from the source run's Delta-table inputs.

    Only dataset inputs whose source is a ``DeltaDatasetSource`` of type ``_DELTA_TABLE``
    and that carry both a table name and a table ID are included.

    Args:
        run: The source run, or None when it could not be fetched.

    Returns:
        None when ``run`` is None or ``run.inputs`` is None; otherwise a list of at most
        ``_MAX_LINEAGE_DATA_SOURCES`` ``Securable`` objects (a warning is logged when
        more were found and the list is truncated).
    """
    if run is None or run.inputs is None:
        return None

    # Imported lazily, and only once there are inputs to inspect.
    from mlflow.data.delta_dataset_source import DeltaDatasetSource

    securable_list = []
    for dataset in run.inputs.dataset_inputs:
        dataset_source = mlflow.data.get_source(dataset)
        if not (
            isinstance(dataset_source, DeltaDatasetSource)
            and dataset_source._get_source_type() == _DELTA_TABLE
        ):
            continue
        # A Delta source identifies a UC table only when it has both a name and an ID.
        if dataset_source.delta_table_name and dataset_source.delta_table_id:
            table_entity = Table(
                name=dataset_source.delta_table_name,
                table_id=dataset_source.delta_table_id,
            )
            securable_list.append(Securable(table=table_entity))

    if len(securable_list) > _MAX_LINEAGE_DATA_SOURCES:
        # Report the actual configured limit rather than a hard-coded number.
        _logger.warning(
            f"Model version has {len(securable_list)!s} upstream datasets, which "
            f"exceeds the max of {_MAX_LINEAGE_DATA_SOURCES} upstream datasets for "
            f"lineage tracking. Only the first {_MAX_LINEAGE_DATA_SOURCES} datasets "
            f"will be propagated to Unity Catalog lineage"
        )
    return securable_list[:_MAX_LINEAGE_DATA_SOURCES]
|
791
|
+
|
792
|
+
def _validate_model_signature(self, local_model_path):
    """
    Require that the model at ``local_model_path`` declares both input and output types.

    Args:
        local_model_path: Local directory containing the MLflow model.

    Raises:
        MlflowException: If the model has no signature, or its signature has no outputs.
    """
    signature = _load_model(local_model_path).signature
    explanation = (
        "All models in the Unity Catalog must be logged with a "
        "model signature containing both input and output "
        "type specifications. See "
        "https://mlflow.org/docs/latest/model/signatures.html#how-to-log-models-with-signatures"
        " for details on how to log a model with a signature"
    )
    if signature is None:
        raise MlflowException(
            "Model passed for registration did not contain any signature metadata. "
            f"{explanation}"
        )
    if signature.outputs is None:
        raise MlflowException(
            "Model passed for registration contained a signature that includes only inputs. "
            f"{explanation}"
        )
|
814
|
+
|
815
|
+
def _download_model_weights_if_not_saved(self, local_model_path):
    """
    Ensure a transformers model directory contains its base model weights.

    Transformers models can be logged with ``save_pretrained=False``, which omits the
    weights; such a model cannot be served directly. For those models this fetches the
    weights via ``mlflow.transformers.persist_pretrained_model`` and stores them in the
    model directory. Non-transformers models are left untouched.

    Args:
        local_model_path: Local directory containing the MLflow model.

    Raises:
        MlflowException: (INTERNAL_ERROR) if the weights cannot be fetched.
    """
    flavor_conf = _load_model(local_model_path).flavors.get("transformers")
    if not flavor_conf:
        return

    from mlflow.transformers.flavor_config import FlavorKey
    from mlflow.transformers.model_io import _MODEL_BINARY_FILE_NAME

    binary_path = os.path.join(local_model_path, _MODEL_BINARY_FILE_NAME)
    weights_present = (
        FlavorKey.MODEL_BINARY in flavor_conf
        and os.path.exists(binary_path)
        and FlavorKey.MODEL_REVISION not in flavor_conf
    )
    if weights_present:
        return

    _logger.info(
        "You are attempting to register a transformers model that does not have persisted "
        "model weights. Attempting to fetch the weights so that the model can be registered "
        "within Unity Catalog."
    )
    try:
        mlflow.transformers.persist_pretrained_model(local_model_path)
    except Exception as e:
        raise MlflowException(
            "Failed to download the model weights from the HuggingFace hub and cannot register "
            "the model in the Unity Catalog. Please ensure that the model was saved with the "
            "correct reference to the HuggingFace hub repository and that you have access to "
            "fetch model weights from the defined repository.",
            error_code=INTERNAL_ERROR,
        ) from e
|
854
|
+
|
855
|
+
@contextmanager
def _local_model_dir(self, source, local_model_path):
    """
    Yield a local directory holding the model's artifacts.

    If ``local_model_path`` is given it is yielded as-is and never deleted. Otherwise the
    artifacts at ``source`` are downloaded with ``mlflow.artifacts.download_artifacts``.
    The downloaded directory is removed on exit unless ``source`` exists locally or the
    returned path is FUSE-mounted / a UC volume.

    Args:
        source: Artifact URI of the model.
        local_model_path: Optional local path to an already-available copy of the model.

    Raises:
        MlflowException: If the artifacts cannot be downloaded from ``source``.
    """
    if local_model_path is not None:
        yield local_model_path
        return

    try:
        downloaded_dir = mlflow.artifacts.download_artifacts(
            artifact_uri=source, tracking_uri=self.tracking_uri
        )
    except Exception as e:
        raise MlflowException(
            f"Unable to download model artifacts from source artifact location "
            f"'{source}' in order to upload them to Unity Catalog. Please ensure "
            f"the source artifact location exists and that you can download from "
            f"it via mlflow.artifacts.download_artifacts()"
        ) from e

    try:
        yield downloaded_dir
    finally:
        # A non-local source means the artifacts were fetched into a temporary
        # directory. The FUSE / UC-volume check matters because download_artifacts can
        # return a mounted path equivalent to the remote source (e.g. /dbfs/some/path
        # for dbfs:/some/path), which must not be deleted.
        source_is_local = os.path.exists(source)
        if not source_is_local and not is_fuse_or_uc_volumes_uri(downloaded_dir):
            shutil.rmtree(downloaded_dir)
|
883
|
+
|
884
|
+
def _get_logged_model_from_model_id(self, model_id) -> Optional[LoggedModel]:
    """
    Load the MLflow LoggedModel for ``model_id``.

    Args:
        model_id: LoggedModel ID, or None.

    Returns:
        The result of ``mlflow.get_logged_model(model_id)``, or None when ``model_id``
        is None.
    """
    if model_id is not None:
        return mlflow.get_logged_model(model_id)
    return None
|
889
|
+
|
890
|
+
def create_model_version(
    self,
    name,
    source,
    run_id=None,
    tags=None,
    run_link=None,
    description=None,
    local_model_path=None,
    model_id: Optional[str] = None,
):
    """
    Create a new Unity Catalog model version from a source location and source run.

    The flow is:
    1. Resolve the source run (from ``model_id`` when given) and derive lineage metadata.
    2. Obtain the model files locally.
    3. Validate the signature and fetch missing transformers weights.
    4. Create the version on the server, upload the artifacts to its storage, and
       finalize it.

    Args:
        name: Registered model name.
        source: URI indicating the location of the model artifacts.
        run_id: Run ID from MLflow tracking server that generated the model. Overridden
            by the LoggedModel's ``source_run_id`` when ``model_id`` resolves to one.
        tags: A list of :py:class:`mlflow.entities.model_registry.ModelVersionTag`
            instances associated with this model version.
        run_link: Not supported by this store; must be left unspecified.
        description: Description of the version.
        local_model_path: Local path to the MLflow model, if it's already accessible on the
            local filesystem. Avoids a redundant download from ``source`` when logging
            and registering a model via a single
            mlflow.<flavor>.log_model(..., registered_model_name) call.
        model_id: The ID of the model (from an Experiment) that is being promoted to a
            registered model version, if applicable.

    Returns:
        A single object of :py:class:`mlflow.entities.model_registry.ModelVersion`
        created in the backend.
    """
    _require_arg_unspecified(arg_name="run_link", arg_value=run_link)

    # A LoggedModel, when present, is authoritative for which run produced the model.
    logged_model = self._get_logged_model_from_model_id(model_id)
    if logged_model:
        run_id = logged_model.source_run_id

    headers, run = self._get_run_and_headers(run_id)
    source_workspace_id = self._get_workspace_id(headers)
    notebook_id = self._get_notebook_id(run)
    lineage_securable_list = self._get_lineage_input_sources(run)
    job_id = self._get_job_id(run)
    job_run_id = self._get_job_run_id(run)

    # Lineage is only attached when the run came from a notebook or a job.
    extra_headers = None
    if not (notebook_id is None and job_id is None):
        entities = []
        if notebook_id is not None:
            entities.append(Entity(notebook=Notebook(id=str(notebook_id))))
        if job_id is not None:
            entities.append(Entity(job=Job(id=job_id, job_run_id=job_run_id)))
        lineages = (
            None
            if lineage_securable_list is None
            else [Lineage(source_securables=lineage_securable_list)]
        )
        header_info = LineageHeaderInfo(entities=entities, lineages=lineages)
        # Base64-encode the header value to ensure it's valid ASCII,
        # similar to JWT (see https://stackoverflow.com/a/40347926)
        encoded_header = base64.b64encode(message_to_json(header_info).encode())
        extra_headers = {_DATABRICKS_LINEAGE_ID_HEADER: encoded_header}

    full_name = get_full_name_from_sc(name, self.spark)
    with self._local_model_dir(source, local_model_path) as local_model_dir:
        self._validate_model_signature(local_model_dir)
        self._download_model_weights_if_not_saved(local_model_dir)
        feature_deps = get_feature_dependencies(local_model_dir)
        other_model_deps = get_model_version_dependencies(local_model_dir)
        create_request = CreateModelVersionRequest(
            name=full_name,
            source=source,
            run_id=run_id,
            description=description,
            tags=uc_model_version_tag_from_mlflow_tags(tags),
            run_tracking_server_id=source_workspace_id,
            feature_deps=feature_deps,
            model_version_dependencies=other_model_deps,
            model_id=model_id,
        )
        req_body = message_to_json(create_request)
        created = self._call_endpoint(
            CreateModelVersionRequest, req_body, extra_headers=extra_headers
        ).model_version

        # Upload the model files to the new version's storage, then finalize it.
        artifact_repo = self._get_artifact_repo(created, full_name)
        artifact_repo.log_artifacts(local_dir=local_model_dir, artifact_path="")
        finalized = self._finalize_model_version(name=full_name, version=created.version)
        return model_version_from_uc_proto(finalized)
|
981
|
+
|
982
|
+
def _get_artifact_repo(self, model_version, model_name=None):
    """
    Choose the artifact repository used to upload a model version's files.

    Args:
        model_version: The model version proto returned by the create call.
        model_name: Full model name. Used only by the Databricks SDK repository.

    Returns:
        One of the following:
        - A ``DatabricksSDKModelsArtifactRepository`` when that integration is enabled.
        - A ``PresignedUrlArtifactRepository`` for default storage.
        - Otherwise, a repository built from the version's storage location and scoped
          write credentials.
    """

    def refresh_write_credentials():
        return self._get_temporary_model_version_write_credentials(
            name=model_version.name, version=model_version.version
        )

    if is_databricks_sdk_models_artifact_repository_enabled(self.get_host_creds()):
        return DatabricksSDKModelsArtifactRepository(model_name, model_version.version)

    scoped_token = refresh_write_credentials()
    if scoped_token.storage_mode != StorageMode.DEFAULT_STORAGE:
        return get_artifact_repo_from_storage_info(
            storage_location=model_version.storage_location,
            scoped_token=scoped_token,
            base_credential_refresh_def=refresh_write_credentials,
        )
    return PresignedUrlArtifactRepository(
        self.get_host_creds(), model_version.name, model_version.version
    )
|
1002
|
+
|
1003
|
+
def transition_model_version_stage(self, name, version, stage, archive_existing_versions):
    """
    Stage transitions are not supported by this store.

    Delegates to ``_raise_unsupported_method`` with guidance to use registered model
    aliases instead of stages.

    Args:
        name: Registered model name.
        version: Registered model version.
        stage: New desired stage for this model version.
        archive_existing_versions: Whether other versions in ``stage`` would be archived.
    """
    alias_guidance = (
        "We recommend using aliases instead of stages for more flexible model "
        "deployment management. You can set an alias on a registered model using "
        "`MlflowClient().set_registered_model_alias(name, alias, version)` and load a model "
        "version by alias using the URI 'models:/your_model_name@your_alias', e.g. "
        "`mlflow.pyfunc.load_model('models:/your_model_name@your_alias')`."
    )
    _raise_unsupported_method(method="transition_model_version_stage", message=alias_guidance)
|
1024
|
+
|
1025
|
+
def update_model_version(self, name, version, description):
    """
    Update a model version's description.

    Args:
        name: Registered model name (resolved via ``get_full_name_from_sc``).
        version: Registered model version; sent to the server as a string.
        description: New model description.

    Returns:
        A single :py:class:`mlflow.entities.model_registry.ModelVersion` object.
    """
    request = UpdateModelVersionRequest(
        name=get_full_name_from_sc(name, self.spark),
        version=str(version),
        description=description,
    )
    response = self._call_endpoint(UpdateModelVersionRequest, message_to_json(request))
    return model_version_from_uc_proto(response.model_version)
|
1044
|
+
|
1045
|
+
def delete_model_version(self, name, version):
    """
    Delete a model version.

    Args:
        name: Registered model name (resolved via ``get_full_name_from_sc``).
        version: Registered model version; sent to the server as a string.

    Returns:
        None
    """
    request = DeleteModelVersionRequest(
        name=get_full_name_from_sc(name, self.spark), version=str(version)
    )
    self._call_endpoint(DeleteModelVersionRequest, message_to_json(request))
|
1059
|
+
|
1060
|
+
def get_model_version(self, name, version):
    """
    Fetch a model version by model name and version number.

    Args:
        name: Registered model name (resolved via ``get_full_name_from_sc``).
        version: Registered model version; sent to the server as a string.

    Returns:
        A single :py:class:`mlflow.entities.model_registry.ModelVersion` object.
    """
    request = GetModelVersionRequest(
        name=get_full_name_from_sc(name, self.spark), version=str(version)
    )
    response = self._call_endpoint(GetModelVersionRequest, message_to_json(request))
    return model_version_from_uc_proto(response.model_version)
|
1075
|
+
|
1076
|
+
def get_model_version_download_uri(self, name, version):
    """
    Get the artifact URI from which a model version's files can be downloaded.

    Args:
        name: Registered model name (resolved via ``get_full_name_from_sc``).
        version: Registered model version; sent to the server as a string.

    Returns:
        The ``artifact_uri`` reported by the server for this model version.
    """
    request = GetModelVersionDownloadUriRequest(
        name=get_full_name_from_sc(name, self.spark), version=str(version)
    )
    response = self._call_endpoint(GetModelVersionDownloadUriRequest, message_to_json(request))
    return response.artifact_uri
|
1095
|
+
|
1096
|
+
def search_model_versions(
    self, filter_string=None, max_results=None, order_by=None, page_token=None
):
    """
    Search for model versions matching a filter.

    Args:
        filter_string: A filter string expression. Currently supports a single filter
            condition either name of model like ``name = 'model_name'`` or
            ``run_id = '...'``.
        max_results: Maximum number of model versions desired.
        order_by: Not supported by this store; must be left unspecified.
        page_token: Token specifying the next page of results. It should be obtained from
            a ``search_model_versions`` call.

    Returns:
        A PagedList of :py:class:`mlflow.entities.model_registry.ModelVersion`
        objects that satisfy the search expressions. The pagination token for the next
        page can be obtained via the ``token`` attribute of the object.
    """
    # This store rejects any custom ordering up front.
    _require_arg_unspecified(arg_name="order_by", arg_value=order_by)
    request = SearchModelVersionsRequest(
        filter=filter_string, page_token=page_token, max_results=max_results
    )
    response = self._call_endpoint(SearchModelVersionsRequest, message_to_json(request))
    found = list(map(model_version_search_from_uc_proto, response.model_versions))
    return PagedList(found, response.next_page_token)
|
1129
|
+
|
1130
|
+
def set_model_version_tag(self, name, version, tag):
    """
    Set a tag on a model version.

    Args:
        name: Registered model name (resolved via ``get_full_name_from_sc``).
        version: Registered model version; sent to the server as a string.
        tag: :py:class:`mlflow.entities.model_registry.ModelVersionTag` instance to log.
    """
    request = SetModelVersionTagRequest(
        name=get_full_name_from_sc(name, self.spark),
        version=str(version),
        key=tag.key,
        value=tag.value,
    )
    self._call_endpoint(SetModelVersionTagRequest, message_to_json(request))
|
1146
|
+
|
1147
|
+
def delete_model_version_tag(self, name, version, key):
    """
    Delete a tag associated with the model version.

    Args:
        name: Registered model name (resolved via ``get_full_name_from_sc``).
        version: Registered model version, as an int or str; sent to the server as a
            string.
        key: Tag key.
    """
    full_name = get_full_name_from_sc(name, self.spark)
    # Stringify the version like every other model-version request in this store. The
    # request's version field is presumably string-typed (as the sibling requests'
    # str(version) suggests), so a raw int would be rejected.
    req_body = message_to_json(
        DeleteModelVersionTagRequest(name=full_name, version=str(version), key=key)
    )
    self._call_endpoint(DeleteModelVersionTagRequest, req_body)
|
1161
|
+
|
1162
|
+
def set_registered_model_alias(self, name, alias, version):
    """
    Point a registered model alias at a model version.

    Args:
        name: Registered model name (resolved via ``get_full_name_from_sc``).
        alias: Name of the alias.
        version: Registered model version number; sent to the server as a string.

    Returns:
        None
    """
    request = SetRegisteredModelAliasRequest(
        name=get_full_name_from_sc(name, self.spark), alias=alias, version=str(version)
    )
    self._call_endpoint(SetRegisteredModelAliasRequest, message_to_json(request))
|
1179
|
+
|
1180
|
+
def delete_registered_model_alias(self, name, alias):
    """
    Remove an alias from a registered model.

    Args:
        name: Registered model name (resolved via ``get_full_name_from_sc``).
        alias: Name of the alias.

    Returns:
        None
    """
    request = DeleteRegisteredModelAliasRequest(
        name=get_full_name_from_sc(name, self.spark), alias=alias
    )
    self._call_endpoint(DeleteRegisteredModelAliasRequest, message_to_json(request))
|
1194
|
+
|
1195
|
+
def get_model_version_by_alias(self, name, alias):
    """
    Fetch the model version that an alias currently points to.

    Args:
        name: Registered model name (resolved via ``get_full_name_from_sc``).
        alias: Name of the alias.

    Returns:
        A single :py:class:`mlflow.entities.model_registry.ModelVersion` object.
    """
    request = GetModelVersionByAliasRequest(
        name=get_full_name_from_sc(name, self.spark), alias=alias
    )
    response = self._call_endpoint(GetModelVersionByAliasRequest, message_to_json(request))
    return model_version_from_uc_proto(response.model_version)
|
1210
|
+
|
1211
|
+
def _await_model_version_creation(self, mv, await_creation_for):
    """
    No-op: this store never polls a model version for readiness.

    A successful creation immediately places the model version in a READY state, so
    there is nothing to wait for; both arguments are ignored.
    """
    return None
|
1216
|
+
|
1217
|
+
# Prompt-related method overrides for UC
|
1218
|
+
|
1219
|
+
def create_prompt(
    self,
    name: str,
    description: Optional[str] = None,
    tags: Optional[dict[str, str]] = None,
) -> Prompt:
    """
    Create a prompt in Unity Catalog (metadata only; no initial version is created).

    Args:
        name: Prompt name.
        description: Optional description; omitted from the request when falsy.
        tags: Optional prompt-level tags; omitted from the request when falsy.

    Returns:
        The created prompt, built from the server response and the supplied ``tags``.
    """
    new_prompt = ProtoPrompt()
    new_prompt.name = name
    if description:
        new_prompt.description = description
    if tags:
        new_prompt.tags.extend(mlflow_tags_to_proto(tags))

    request = CreatePromptRequest(name=name, prompt=new_prompt)
    response = self._call_endpoint(CreatePromptRequest, message_to_json(request))
    return proto_info_to_mlflow_prompt_info(response, tags or {})
|
1244
|
+
|
1245
|
+
def search_prompts(
    self,
    filter_string: Optional[str] = None,
    max_results: Optional[int] = None,
    order_by: Optional[list[str]] = None,
    page_token: Optional[str] = None,
) -> PagedList[Prompt]:
    """
    Search for prompts in Unity Catalog.

    Args:
        filter_string: Filter string that must include catalog and schema in the format:
            "catalog = 'catalog_name' AND schema = 'schema_name'". Any further AND-ed
            clauses are forwarded to the server as the remaining filter.
        max_results: Maximum number of results to return.
        order_by: List of fields to order by (not used in current implementation).
        page_token: Token for pagination.

    Returns:
        A PagedList of prompts (without extra tag fetching) and the next page token.

    Raises:
        MlflowException: (INVALID_PARAMETER_VALUE) if ``filter_string`` is empty or does
            not specify both a catalog and a schema.
    """
    # The parser itself rejects a missing/empty filter with the same
    # INVALID_PARAMETER_VALUE error, so no separate inline check is needed here.
    parsed_filter = self._parse_catalog_schema_from_filter(filter_string)

    unity_catalog_schema = UnityCatalogSchema(
        catalog_name=parsed_filter.catalog_name, schema_name=parsed_filter.schema_name
    )
    req_body = message_to_json(
        SearchPromptsRequest(
            catalog_schema=unity_catalog_schema,
            filter=parsed_filter.remaining_filter,
            max_results=max_results,
            page_token=page_token,
        )
    )

    response_proto = self._call_endpoint(SearchPromptsRequest, req_body)
    # For UC, only use the basic prompt info without extra tag fetching.
    prompts = [
        proto_info_to_mlflow_prompt_info(prompt_info, {})
        for prompt_info in response_proto.prompts
    ]
    return PagedList(prompts, response_proto.next_page_token)
|
1292
|
+
|
1293
|
+
def _parse_catalog_schema_from_filter(
    self, filter_string: Optional[str]
) -> _CatalogSchemaFilter:
    """
    Split a Unity Catalog prompt filter into catalog, schema, and any remaining clauses.

    Expects filter format: "catalog = 'catalog_name' AND schema = 'schema_name'",
    optionally followed by more AND-ed clauses.

    Args:
        filter_string: Filter string containing catalog and schema.

    Returns:
        _CatalogSchemaFilter with catalog_name, schema_name, and remaining_filter (the
        other clauses re-joined with " AND ", or None when there are none).

    Raises:
        MlflowException: If filter format is invalid for Unity Catalog.
    """
    if not filter_string:
        raise MlflowException(
            "For Unity Catalog prompt registries, you must specify catalog and schema "
            "in the filter string: \"catalog = 'catalog_name' AND schema = 'schema_name'\"",
            INVALID_PARAMETER_VALUE,
        )

    catalog_match = _CATALOG_PATTERN.search(filter_string)
    schema_match = _SCHEMA_PATTERN.search(filter_string)
    if catalog_match is None or schema_match is None:
        raise MlflowException(
            "For Unity Catalog prompt registries, filter string must include both "
            "catalog and schema in the format: "
            "\"catalog = 'catalog_name' AND schema = 'schema_name'\". "
            f"Got: {filter_string}",
            INVALID_PARAMETER_VALUE,
        )

    # Drop the catalog/schema clauses. Splitting is a plain case-insensitive split on
    # AND (it does not respect quoting).
    clauses = (
        clause.strip()
        for clause in re.split(r"\s+AND\s+", filter_string, flags=re.IGNORECASE)
    )
    remaining_parts = [
        clause
        for clause in clauses
        if not (_CATALOG_PATTERN.match(clause) or _SCHEMA_PATTERN.match(clause))
    ]
    remaining_filter = " AND ".join(remaining_parts) if remaining_parts else None

    return _CatalogSchemaFilter(catalog_match.group(1), schema_match.group(1), remaining_filter)
|
1349
|
+
|
1350
|
+
def delete_prompt(self, name: str) -> None:
    """
    Delete a prompt from Unity Catalog.

    Args:
        name: Prompt name; also substituted into the request endpoint.
    """
    request = DeletePromptRequest(name=name)
    req_body = message_to_json(request)
    endpoint, method = self._get_endpoint_from_method(DeletePromptRequest)
    self._edit_endpoint_and_call(
        endpoint=endpoint,
        method=method,
        req_body=req_body,
        name=name,
        proto_name=DeletePromptRequest,
    )
|
1363
|
+
|
1364
|
+
def set_prompt_tag(self, name: str, key: str, value: str) -> None:
    """
    Set a prompt-level tag in Unity Catalog.

    Args:
        name: Prompt name.
        key: Tag key.
        value: Tag value.
    """
    request = SetPromptTagRequest(name=name, key=key, value=value)
    req_body = message_to_json(request)
    endpoint, method = self._get_endpoint_from_method(SetPromptTagRequest)
    self._edit_endpoint_and_call(
        endpoint=endpoint,
        method=method,
        req_body=req_body,
        name=name,
        key=key,
        proto_name=SetPromptTagRequest,
    )
|
1378
|
+
|
1379
|
+
def delete_prompt_tag(self, name: str, key: str) -> None:
    """
    Remove a prompt-level tag in Unity Catalog.

    Args:
        name: Prompt name.
        key: Tag key to remove.
    """
    request = DeletePromptTagRequest(name=name, key=key)
    req_body = message_to_json(request)
    endpoint, method = self._get_endpoint_from_method(DeletePromptTagRequest)
    self._edit_endpoint_and_call(
        endpoint=endpoint,
        method=method,
        req_body=req_body,
        name=name,
        key=key,
        proto_name=DeletePromptTagRequest,
    )
|
1393
|
+
|
1394
|
+
def get_prompt(self, name: str) -> Optional[Prompt]:
    """
    Get prompt metadata by name from Unity Catalog.

    Args:
        name: Prompt name.

    Returns:
        The prompt (without extra tag fetching), or None when the server reports
        RESOURCE_DOES_NOT_EXIST.

    Raises:
        MlflowException: For any other server error; non-MLflow errors propagate as-is.
    """
    try:
        req_body = message_to_json(GetPromptRequest(name=name))
        endpoint, method = self._get_endpoint_from_method(GetPromptRequest)
        response_proto = self._edit_endpoint_and_call(
            endpoint=endpoint,
            method=method,
            req_body=req_body,
            name=name,
            proto_name=GetPromptRequest,
        )
        return proto_info_to_mlflow_prompt_info(response_proto, {})
    except MlflowException as e:
        # A missing prompt is an expected outcome, reported as None rather than an error.
        if e.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST):
            return None
        raise
|
1415
|
+
|
1416
|
+
def create_prompt_version(
    self,
    name: str,
    template: str,
    description: Optional[str] = None,
    tags: Optional[dict[str, str]] = None,
) -> PromptVersion:
    """
    Create a new version of a prompt in Unity Catalog.

    Args:
        name: Prompt name.
        template: Prompt template text; sent JSON-encoded (via ``json.dumps``).
        description: Optional version description; omitted when falsy.
        tags: Optional version-level tags; omitted when falsy.

    Returns:
        The created prompt version, converted from the server response.
    """
    version_proto = ProtoPromptVersion()
    version_proto.name = name
    # The Unity Catalog server expects the template JSON-encoded.
    version_proto.template = json.dumps(template)
    # The version number is assigned server-side, so it is deliberately left unset here.
    if description:
        version_proto.description = description
    if tags:
        version_proto.tags.extend(mlflow_tags_to_proto_version_tags(tags))

    request = CreatePromptVersionRequest(name=name, prompt_version=version_proto)
    req_body = message_to_json(request)
    endpoint, method = self._get_endpoint_from_method(CreatePromptVersionRequest)
    created = self._edit_endpoint_and_call(
        endpoint=endpoint,
        method=method,
        req_body=req_body,
        name=name,
        proto_name=CreatePromptVersionRequest,
    )
    return proto_to_mlflow_prompt(created)
|
1454
|
+
|
1455
|
+
def get_prompt_version(self, name: str, version: Union[str, int]) -> Optional[PromptVersion]:
    """
    Get a specific prompt version from Unity Catalog.

    Prompt-level tags are intentionally not fetched; only the version's own data is
    returned.

    Args:
        name: Prompt name.
        version: Prompt version, as an int or str; sent as a string in the request body.

    Returns:
        The prompt version, or None when the server reports RESOURCE_DOES_NOT_EXIST.

    Raises:
        MlflowException: For any other server error; non-MLflow errors propagate as-is.
    """
    try:
        req_body = message_to_json(GetPromptVersionRequest(name=name, version=str(version)))
        endpoint, method = self._get_endpoint_from_method(GetPromptVersionRequest)
        response_proto = self._edit_endpoint_and_call(
            endpoint=endpoint,
            method=method,
            req_body=req_body,
            name=name,
            version=version,
            proto_name=GetPromptVersionRequest,
        )
        return proto_to_mlflow_prompt(response_proto)
    except MlflowException as e:
        # A missing version is an expected outcome, reported as None rather than an error.
        if e.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST):
            return None
        raise
|
1479
|
+
|
1480
|
+
def delete_prompt_version(self, name: str, version: Union[str, int]) -> None:
    """
    Delete one specific prompt version from Unity Catalog.

    Args:
        name: Prompt name.
        version: Prompt version, as an int or str; sent as a string in the request body.
    """
    request = DeletePromptVersionRequest(name=name, version=str(version))
    req_body = message_to_json(request)
    endpoint, method = self._get_endpoint_from_method(DeletePromptVersionRequest)
    self._edit_endpoint_and_call(
        endpoint=endpoint,
        method=method,
        req_body=req_body,
        name=name,
        version=version,
        proto_name=DeletePromptVersionRequest,
    )
|
1495
|
+
|
1496
|
+
def search_prompt_versions(
    self, name: str, max_results: Optional[int] = None, page_token: Optional[str] = None
) -> SearchPromptVersionsResponse:
    """
    List one page of versions for a prompt in Unity Catalog.

    Note: Unity Catalog server uses a non-standard endpoint pattern for this operation,
    so the prompt ``name`` is supplied as a path parameter for the endpoint template.

    Args:
        name: Name of the prompt whose versions are listed.
        max_results: Optional cap on the number of versions returned.
        page_token: Optional pagination token from a previous response.

    Returns:
        The response message produced by ``_edit_endpoint_and_call`` for
        ``SearchPromptVersionsRequest``, containing the list of versions.
    """
    search_request = SearchPromptVersionsRequest(
        name=name,
        max_results=max_results,
        page_token=page_token,
    )
    payload = message_to_json(search_request)
    endpoint, method = self._get_endpoint_from_method(SearchPromptVersionsRequest)
    response = self._edit_endpoint_and_call(
        endpoint=endpoint,
        method=method,
        req_body=payload,
        proto_name=SearchPromptVersionsRequest,
        name=name,
    )
    return response
|
1523
|
+
|
1524
|
+
def set_prompt_version_tag(
    self, name: str, version: Union[str, int], key: str, value: str
) -> None:
    """
    Set a key/value tag on one prompt version in Unity Catalog.

    Args:
        name: Name of the prompt.
        version: Version to tag; stringified for the request body.
        key: Tag key; also passed as a path parameter for the endpoint template.
        value: Tag value, sent only in the request body.
    """
    tag_request = SetPromptVersionTagRequest(
        name=name, version=str(version), key=key, value=value
    )
    payload = message_to_json(tag_request)
    endpoint, method = self._get_endpoint_from_method(SetPromptVersionTagRequest)
    # Path parameters, in the same substitution order as the request fields.
    path_params = {"name": name, "version": version, "key": key}
    self._edit_endpoint_and_call(
        endpoint=endpoint,
        method=method,
        req_body=payload,
        proto_name=SetPromptVersionTagRequest,
        **path_params,
    )
|
1543
|
+
|
1544
|
+
def delete_prompt_version_tag(self, name: str, version: Union[str, int], key: str) -> None:
    """
    Remove a tag from one prompt version in Unity Catalog.

    Args:
        name: Name of the prompt.
        version: Version carrying the tag; stringified for the request body.
        key: Key of the tag to remove; also passed as a path parameter for the
            endpoint template.
    """
    payload = message_to_json(
        DeletePromptVersionTagRequest(
            name=name,
            version=str(version),
            key=key,
        )
    )
    endpoint, method = self._get_endpoint_from_method(DeletePromptVersionTagRequest)
    path_params = {"name": name, "version": version, "key": key}
    self._edit_endpoint_and_call(
        endpoint=endpoint,
        method=method,
        req_body=payload,
        proto_name=DeletePromptVersionTagRequest,
        **path_params,
    )
|
1561
|
+
|
1562
|
+
def get_prompt_version_by_alias(self, name: str, alias: str) -> Optional[PromptVersion]:
    """
    Get the prompt version that an alias points to in Unity Catalog.

    Args:
        name: Name of the prompt.
        alias: Alias to resolve.

    Returns:
        The prompt version converted by ``proto_to_mlflow_prompt``, or ``None`` when the
        server reports ``RESOURCE_DOES_NOT_EXIST``.

    Raises:
        MlflowException: For any other server-side error.
    """
    req_body = message_to_json(GetPromptVersionByAliasRequest(name=name, alias=alias))
    endpoint, method = self._get_endpoint_from_method(GetPromptVersionByAliasRequest)
    try:
        response_proto = self._edit_endpoint_and_call(
            endpoint=endpoint,
            method=method,
            req_body=req_body,
            name=name,
            alias=alias,
            proto_name=GetPromptVersionByAliasRequest,
        )
    except MlflowException as e:
        # An unknown prompt/alias is an expected lookup outcome; surface it as None.
        if e.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST):
            return None
        raise

    # Prompt-level tags are intentionally not merged into the returned version.
    return proto_to_mlflow_prompt(response_proto)
|
1586
|
+
|
1587
|
+
def set_prompt_alias(self, name: str, alias: str, version: Union[str, int]) -> None:
    """
    Point an alias of a Unity Catalog prompt at the given version.

    Args:
        name: Name of the prompt.
        alias: Alias to set.
        version: Target version; stringified for the request body and also passed
            as a path parameter for the endpoint template.
    """
    alias_request = SetPromptAliasRequest(name=name, alias=alias, version=str(version))
    body = message_to_json(alias_request)
    endpoint, method = self._get_endpoint_from_method(SetPromptAliasRequest)
    self._edit_endpoint_and_call(
        endpoint=endpoint,
        method=method,
        req_body=body,
        proto_name=SetPromptAliasRequest,
        name=name,
        alias=alias,
        version=version,
    )
|
1604
|
+
|
1605
|
+
def delete_prompt_alias(self, name: str, alias: str) -> None:
    """
    Remove an alias from a Unity Catalog prompt.

    Args:
        name: Name of the prompt carrying the alias.
        alias: Alias to remove; also passed as a path parameter for the endpoint
            template.
    """
    body = message_to_json(DeletePromptAliasRequest(name=name, alias=alias))
    endpoint, method = self._get_endpoint_from_method(DeletePromptAliasRequest)
    path_params = {"name": name, "alias": alias}
    self._edit_endpoint_and_call(
        endpoint=endpoint,
        method=method,
        req_body=body,
        proto_name=DeletePromptAliasRequest,
        **path_params,
    )
|
1619
|
+
|
1620
|
+
def link_prompt_version_to_model(self, name: str, version: str, model_id: str) -> None:
    """
    Link a prompt version to a model in Unity Catalog.

    The base-class implementation runs first. The subsequent Unity Catalog
    ``LinkPromptVersionsToModels`` call is best-effort: any exception it raises is
    logged at debug level and swallowed.

    Args:
        name: Name of the prompt.
        version: Version of the prompt to link. Non-string versions (e.g. ints) are
            stringified, matching ``link_prompts_to_trace``.
        model_id: ID of the model to link to.
    """
    # Call the default implementation, since the LinkPromptVersionsToModels API
    # will initially be a no-op until the Databricks backend supports it
    super().link_prompt_version_to_model(name=name, version=version, model_id=model_id)

    # Stringify like link_prompts_to_trace does for the same message: an int version
    # would otherwise fail here, outside the best-effort try block below.
    prompt_version_entry = PromptVersionLinkEntry(name=name, version=str(version))
    req_body = message_to_json(
        LinkPromptVersionsToModelsRequest(
            prompt_versions=[prompt_version_entry], model_ids=[model_id]
        )
    )
    endpoint, method = self._get_endpoint_from_method(LinkPromptVersionsToModelsRequest)
    try:
        self._edit_endpoint_and_call(
            endpoint=endpoint,
            method=method,
            req_body=req_body,
            name=name,
            version=version,
            model_id=model_id,
            proto_name=LinkPromptVersionsToModelsRequest,
        )
    except Exception:
        # NB: Failures (e.g. a backend that does not support linking yet) are swallowed
        # to reduce errors and log spam while the prompt registry remains experimental.
        _logger.debug("Failed to link prompt version to model in unity catalog", exc_info=True)
|
1655
|
+
|
1656
|
+
def link_prompts_to_trace(self, prompt_versions: list[PromptVersion], trace_id: str) -> None:
    """
    Link several prompt versions to one trace in Unity Catalog.

    The base-class implementation runs first. The versions are then sent to the
    ``LinkPromptsToTraces`` endpoint in chunks of at most 25 entries per request.
    Each chunk is best-effort: a failing request is logged at debug level and the
    remaining chunks are still sent.

    Args:
        prompt_versions: PromptVersion objects to link; their versions are stringified.
        trace_id: Trace ID each prompt version is linked to.
    """
    super().link_prompts_to_trace(prompt_versions=prompt_versions, trace_id=trace_id)

    link_entries = [
        PromptVersionLinkEntry(name=prompt.name, version=str(prompt.version))
        for prompt in prompt_versions
    ]

    chunk_size = 25
    endpoint, method = self._get_endpoint_from_method(LinkPromptsToTracesRequest)

    offset = 0
    while offset < len(link_entries):
        chunk = link_entries[offset : offset + chunk_size]
        offset += chunk_size
        payload = message_to_json(
            LinkPromptsToTracesRequest(prompt_versions=chunk, trace_ids=[trace_id])
        )
        try:
            self._edit_endpoint_and_call(
                endpoint=endpoint,
                method=method,
                req_body=payload,
                proto_name=LinkPromptsToTracesRequest,
            )
        except Exception:
            _logger.debug("Failed to link prompts to traces in unity catalog", exc_info=True)
|
1687
|
+
|
1688
|
+
def link_prompt_version_to_run(self, name: str, version: str, run_id: str) -> None:
    """
    Link a prompt version to a run in Unity Catalog.

    The base-class implementation runs first. The subsequent Unity Catalog
    ``LinkPromptVersionsToRuns`` call is best-effort: any exception it raises is
    logged at debug level and swallowed.

    Args:
        name: Name of the prompt.
        version: Version of the prompt to link. Non-string versions (e.g. ints) are
            stringified, matching ``link_prompts_to_trace``.
        run_id: ID of the run to link to.
    """
    super().link_prompt_version_to_run(name=name, version=version, run_id=run_id)

    # Stringify like link_prompts_to_trace does for the same message: an int version
    # would otherwise fail here, outside the best-effort try block below.
    prompt_version_entry = PromptVersionLinkEntry(name=name, version=str(version))
    endpoint, method = self._get_endpoint_from_method(LinkPromptVersionsToRunsRequest)

    req_body = message_to_json(
        LinkPromptVersionsToRunsRequest(
            prompt_versions=[prompt_version_entry], run_ids=[run_id]
        )
    )
    try:
        self._edit_endpoint_and_call(
            endpoint=endpoint,
            method=method,
            req_body=req_body,
            proto_name=LinkPromptVersionsToRunsRequest,
        )
    except Exception:
        _logger.debug("Failed to link prompt version to run in unity catalog", exc_info=True)
|
1716
|
+
|
1717
|
+
def _edit_endpoint_and_call(self, endpoint, method, req_body, proto_name, **kwargs):
    """
    Fill in the path parameters of an endpoint template and call it.

    Args:
        endpoint: URL template with placeholders like {name}, {key}
        method: HTTP method
        req_body: JSON request body
        proto_name: Request protobuf class; the response message type is looked up
            from it via ``_get_response_from_method``.
        **kwargs: Values substituted for the matching ``{placeholder}`` in the endpoint
            template, percent-encoded as single path segments. ``None`` values are
            skipped, leaving their placeholder untouched.

    Returns:
        The result of ``call_endpoint`` for the resolved endpoint.
    """
    from urllib.parse import quote

    # Replace placeholders in endpoint with actual values
    for placeholder, value in kwargs.items():
        if value is not None:
            # Percent-encode everything outside the unreserved set (including "/"),
            # so characters such as "/", "?", "#", "%" or "{...}" in a name, alias or
            # tag key cannot change the URL structure or be re-substituted by a later
            # placeholder. Unreserved characters (letters, digits, "-", ".", "_", "~")
            # are left as-is.
            endpoint = endpoint.replace(f"{{{placeholder}}}", quote(str(value), safe=""))

    # Make the API call
    return call_endpoint(
        self.get_host_creds(),
        endpoint=endpoint,
        method=method,
        json_body=req_body,
        response_proto=self._get_response_from_method(proto_name),
    )
|