genesis_flow-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genesis_flow-1.0.0.dist-info/METADATA +822 -0
- genesis_flow-1.0.0.dist-info/RECORD +645 -0
- genesis_flow-1.0.0.dist-info/WHEEL +5 -0
- genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
- genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
- genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
- mlflow/__init__.py +367 -0
- mlflow/__main__.py +3 -0
- mlflow/ag2/__init__.py +56 -0
- mlflow/ag2/ag2_logger.py +294 -0
- mlflow/anthropic/__init__.py +40 -0
- mlflow/anthropic/autolog.py +129 -0
- mlflow/anthropic/chat.py +144 -0
- mlflow/artifacts/__init__.py +268 -0
- mlflow/autogen/__init__.py +144 -0
- mlflow/autogen/chat.py +142 -0
- mlflow/azure/__init__.py +26 -0
- mlflow/azure/auth_handler.py +257 -0
- mlflow/azure/client.py +319 -0
- mlflow/azure/config.py +120 -0
- mlflow/azure/connection_factory.py +340 -0
- mlflow/azure/exceptions.py +27 -0
- mlflow/azure/stores.py +327 -0
- mlflow/azure/utils.py +183 -0
- mlflow/bedrock/__init__.py +45 -0
- mlflow/bedrock/_autolog.py +202 -0
- mlflow/bedrock/chat.py +122 -0
- mlflow/bedrock/stream.py +160 -0
- mlflow/bedrock/utils.py +43 -0
- mlflow/cli.py +707 -0
- mlflow/client.py +12 -0
- mlflow/config/__init__.py +56 -0
- mlflow/crewai/__init__.py +79 -0
- mlflow/crewai/autolog.py +253 -0
- mlflow/crewai/chat.py +29 -0
- mlflow/data/__init__.py +75 -0
- mlflow/data/artifact_dataset_sources.py +170 -0
- mlflow/data/code_dataset_source.py +40 -0
- mlflow/data/dataset.py +123 -0
- mlflow/data/dataset_registry.py +168 -0
- mlflow/data/dataset_source.py +110 -0
- mlflow/data/dataset_source_registry.py +219 -0
- mlflow/data/delta_dataset_source.py +167 -0
- mlflow/data/digest_utils.py +108 -0
- mlflow/data/evaluation_dataset.py +562 -0
- mlflow/data/filesystem_dataset_source.py +81 -0
- mlflow/data/http_dataset_source.py +145 -0
- mlflow/data/huggingface_dataset.py +258 -0
- mlflow/data/huggingface_dataset_source.py +118 -0
- mlflow/data/meta_dataset.py +104 -0
- mlflow/data/numpy_dataset.py +223 -0
- mlflow/data/pandas_dataset.py +231 -0
- mlflow/data/polars_dataset.py +352 -0
- mlflow/data/pyfunc_dataset_mixin.py +31 -0
- mlflow/data/schema.py +76 -0
- mlflow/data/sources.py +1 -0
- mlflow/data/spark_dataset.py +406 -0
- mlflow/data/spark_dataset_source.py +74 -0
- mlflow/data/spark_delta_utils.py +118 -0
- mlflow/data/tensorflow_dataset.py +350 -0
- mlflow/data/uc_volume_dataset_source.py +81 -0
- mlflow/db.py +27 -0
- mlflow/dspy/__init__.py +17 -0
- mlflow/dspy/autolog.py +197 -0
- mlflow/dspy/callback.py +398 -0
- mlflow/dspy/constant.py +1 -0
- mlflow/dspy/load.py +93 -0
- mlflow/dspy/save.py +393 -0
- mlflow/dspy/util.py +109 -0
- mlflow/dspy/wrapper.py +226 -0
- mlflow/entities/__init__.py +104 -0
- mlflow/entities/_mlflow_object.py +52 -0
- mlflow/entities/assessment.py +545 -0
- mlflow/entities/assessment_error.py +80 -0
- mlflow/entities/assessment_source.py +141 -0
- mlflow/entities/dataset.py +92 -0
- mlflow/entities/dataset_input.py +51 -0
- mlflow/entities/dataset_summary.py +62 -0
- mlflow/entities/document.py +48 -0
- mlflow/entities/experiment.py +109 -0
- mlflow/entities/experiment_tag.py +35 -0
- mlflow/entities/file_info.py +45 -0
- mlflow/entities/input_tag.py +35 -0
- mlflow/entities/lifecycle_stage.py +35 -0
- mlflow/entities/logged_model.py +228 -0
- mlflow/entities/logged_model_input.py +26 -0
- mlflow/entities/logged_model_output.py +32 -0
- mlflow/entities/logged_model_parameter.py +46 -0
- mlflow/entities/logged_model_status.py +74 -0
- mlflow/entities/logged_model_tag.py +33 -0
- mlflow/entities/metric.py +200 -0
- mlflow/entities/model_registry/__init__.py +29 -0
- mlflow/entities/model_registry/_model_registry_entity.py +13 -0
- mlflow/entities/model_registry/model_version.py +243 -0
- mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
- mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
- mlflow/entities/model_registry/model_version_search.py +25 -0
- mlflow/entities/model_registry/model_version_stages.py +25 -0
- mlflow/entities/model_registry/model_version_status.py +35 -0
- mlflow/entities/model_registry/model_version_tag.py +35 -0
- mlflow/entities/model_registry/prompt.py +73 -0
- mlflow/entities/model_registry/prompt_version.py +244 -0
- mlflow/entities/model_registry/registered_model.py +175 -0
- mlflow/entities/model_registry/registered_model_alias.py +35 -0
- mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
- mlflow/entities/model_registry/registered_model_search.py +25 -0
- mlflow/entities/model_registry/registered_model_tag.py +35 -0
- mlflow/entities/multipart_upload.py +74 -0
- mlflow/entities/param.py +49 -0
- mlflow/entities/run.py +97 -0
- mlflow/entities/run_data.py +84 -0
- mlflow/entities/run_info.py +188 -0
- mlflow/entities/run_inputs.py +59 -0
- mlflow/entities/run_outputs.py +43 -0
- mlflow/entities/run_status.py +41 -0
- mlflow/entities/run_tag.py +36 -0
- mlflow/entities/source_type.py +31 -0
- mlflow/entities/span.py +774 -0
- mlflow/entities/span_event.py +96 -0
- mlflow/entities/span_status.py +102 -0
- mlflow/entities/trace.py +317 -0
- mlflow/entities/trace_data.py +71 -0
- mlflow/entities/trace_info.py +220 -0
- mlflow/entities/trace_info_v2.py +162 -0
- mlflow/entities/trace_location.py +173 -0
- mlflow/entities/trace_state.py +39 -0
- mlflow/entities/trace_status.py +68 -0
- mlflow/entities/view_type.py +51 -0
- mlflow/environment_variables.py +866 -0
- mlflow/evaluation/__init__.py +16 -0
- mlflow/evaluation/assessment.py +369 -0
- mlflow/evaluation/evaluation.py +411 -0
- mlflow/evaluation/evaluation_tag.py +61 -0
- mlflow/evaluation/fluent.py +48 -0
- mlflow/evaluation/utils.py +201 -0
- mlflow/exceptions.py +213 -0
- mlflow/experiments.py +140 -0
- mlflow/gemini/__init__.py +81 -0
- mlflow/gemini/autolog.py +186 -0
- mlflow/gemini/chat.py +261 -0
- mlflow/genai/__init__.py +71 -0
- mlflow/genai/datasets/__init__.py +67 -0
- mlflow/genai/datasets/evaluation_dataset.py +131 -0
- mlflow/genai/evaluation/__init__.py +3 -0
- mlflow/genai/evaluation/base.py +411 -0
- mlflow/genai/evaluation/constant.py +23 -0
- mlflow/genai/evaluation/utils.py +244 -0
- mlflow/genai/judges/__init__.py +21 -0
- mlflow/genai/judges/databricks.py +404 -0
- mlflow/genai/label_schemas/__init__.py +153 -0
- mlflow/genai/label_schemas/label_schemas.py +209 -0
- mlflow/genai/labeling/__init__.py +159 -0
- mlflow/genai/labeling/labeling.py +250 -0
- mlflow/genai/optimize/__init__.py +13 -0
- mlflow/genai/optimize/base.py +198 -0
- mlflow/genai/optimize/optimizers/__init__.py +4 -0
- mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
- mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
- mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
- mlflow/genai/optimize/types.py +75 -0
- mlflow/genai/optimize/util.py +30 -0
- mlflow/genai/prompts/__init__.py +206 -0
- mlflow/genai/scheduled_scorers.py +431 -0
- mlflow/genai/scorers/__init__.py +26 -0
- mlflow/genai/scorers/base.py +492 -0
- mlflow/genai/scorers/builtin_scorers.py +765 -0
- mlflow/genai/scorers/scorer_utils.py +138 -0
- mlflow/genai/scorers/validation.py +165 -0
- mlflow/genai/utils/data_validation.py +146 -0
- mlflow/genai/utils/enum_utils.py +23 -0
- mlflow/genai/utils/trace_utils.py +211 -0
- mlflow/groq/__init__.py +42 -0
- mlflow/groq/_groq_autolog.py +74 -0
- mlflow/johnsnowlabs/__init__.py +888 -0
- mlflow/langchain/__init__.py +24 -0
- mlflow/langchain/api_request_parallel_processor.py +330 -0
- mlflow/langchain/autolog.py +147 -0
- mlflow/langchain/chat_agent_langgraph.py +340 -0
- mlflow/langchain/constant.py +1 -0
- mlflow/langchain/constants.py +1 -0
- mlflow/langchain/databricks_dependencies.py +444 -0
- mlflow/langchain/langchain_tracer.py +597 -0
- mlflow/langchain/model.py +919 -0
- mlflow/langchain/output_parsers.py +142 -0
- mlflow/langchain/retriever_chain.py +153 -0
- mlflow/langchain/runnables.py +527 -0
- mlflow/langchain/utils/chat.py +402 -0
- mlflow/langchain/utils/logging.py +671 -0
- mlflow/langchain/utils/serialization.py +36 -0
- mlflow/legacy_databricks_cli/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/provider.py +482 -0
- mlflow/litellm/__init__.py +175 -0
- mlflow/llama_index/__init__.py +22 -0
- mlflow/llama_index/autolog.py +55 -0
- mlflow/llama_index/chat.py +43 -0
- mlflow/llama_index/constant.py +1 -0
- mlflow/llama_index/model.py +577 -0
- mlflow/llama_index/pyfunc_wrapper.py +332 -0
- mlflow/llama_index/serialize_objects.py +188 -0
- mlflow/llama_index/tracer.py +561 -0
- mlflow/metrics/__init__.py +479 -0
- mlflow/metrics/base.py +39 -0
- mlflow/metrics/genai/__init__.py +25 -0
- mlflow/metrics/genai/base.py +101 -0
- mlflow/metrics/genai/genai_metric.py +771 -0
- mlflow/metrics/genai/metric_definitions.py +450 -0
- mlflow/metrics/genai/model_utils.py +371 -0
- mlflow/metrics/genai/prompt_template.py +68 -0
- mlflow/metrics/genai/prompts/__init__.py +0 -0
- mlflow/metrics/genai/prompts/v1.py +422 -0
- mlflow/metrics/genai/utils.py +6 -0
- mlflow/metrics/metric_definitions.py +619 -0
- mlflow/mismatch.py +34 -0
- mlflow/mistral/__init__.py +34 -0
- mlflow/mistral/autolog.py +71 -0
- mlflow/mistral/chat.py +135 -0
- mlflow/ml_package_versions.py +452 -0
- mlflow/models/__init__.py +97 -0
- mlflow/models/auth_policy.py +83 -0
- mlflow/models/cli.py +354 -0
- mlflow/models/container/__init__.py +294 -0
- mlflow/models/container/scoring_server/__init__.py +0 -0
- mlflow/models/container/scoring_server/nginx.conf +39 -0
- mlflow/models/dependencies_schemas.py +287 -0
- mlflow/models/display_utils.py +158 -0
- mlflow/models/docker_utils.py +211 -0
- mlflow/models/evaluation/__init__.py +23 -0
- mlflow/models/evaluation/_shap_patch.py +64 -0
- mlflow/models/evaluation/artifacts.py +194 -0
- mlflow/models/evaluation/base.py +1811 -0
- mlflow/models/evaluation/calibration_curve.py +109 -0
- mlflow/models/evaluation/default_evaluator.py +996 -0
- mlflow/models/evaluation/deprecated.py +23 -0
- mlflow/models/evaluation/evaluator_registry.py +80 -0
- mlflow/models/evaluation/evaluators/classifier.py +704 -0
- mlflow/models/evaluation/evaluators/default.py +233 -0
- mlflow/models/evaluation/evaluators/regressor.py +96 -0
- mlflow/models/evaluation/evaluators/shap.py +296 -0
- mlflow/models/evaluation/lift_curve.py +178 -0
- mlflow/models/evaluation/utils/metric.py +123 -0
- mlflow/models/evaluation/utils/trace.py +179 -0
- mlflow/models/evaluation/validation.py +434 -0
- mlflow/models/flavor_backend.py +93 -0
- mlflow/models/flavor_backend_registry.py +53 -0
- mlflow/models/model.py +1639 -0
- mlflow/models/model_config.py +150 -0
- mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
- mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
- mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
- mlflow/models/python_api.py +369 -0
- mlflow/models/rag_signatures.py +128 -0
- mlflow/models/resources.py +321 -0
- mlflow/models/signature.py +662 -0
- mlflow/models/utils.py +2054 -0
- mlflow/models/wheeled_model.py +280 -0
- mlflow/openai/__init__.py +57 -0
- mlflow/openai/_agent_tracer.py +364 -0
- mlflow/openai/api_request_parallel_processor.py +131 -0
- mlflow/openai/autolog.py +509 -0
- mlflow/openai/constant.py +1 -0
- mlflow/openai/model.py +824 -0
- mlflow/openai/utils/chat_schema.py +367 -0
- mlflow/optuna/__init__.py +3 -0
- mlflow/optuna/storage.py +646 -0
- mlflow/plugins/__init__.py +72 -0
- mlflow/plugins/base.py +358 -0
- mlflow/plugins/builtin/__init__.py +24 -0
- mlflow/plugins/builtin/pytorch_plugin.py +150 -0
- mlflow/plugins/builtin/sklearn_plugin.py +158 -0
- mlflow/plugins/builtin/transformers_plugin.py +187 -0
- mlflow/plugins/cli.py +321 -0
- mlflow/plugins/discovery.py +340 -0
- mlflow/plugins/manager.py +465 -0
- mlflow/plugins/registry.py +316 -0
- mlflow/plugins/templates/framework_plugin_template.py +329 -0
- mlflow/prompt/constants.py +20 -0
- mlflow/prompt/promptlab_model.py +197 -0
- mlflow/prompt/registry_utils.py +248 -0
- mlflow/promptflow/__init__.py +495 -0
- mlflow/protos/__init__.py +0 -0
- mlflow/protos/assessments_pb2.py +174 -0
- mlflow/protos/databricks_artifacts_pb2.py +489 -0
- mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
- mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
- mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
- mlflow/protos/databricks_pb2.py +267 -0
- mlflow/protos/databricks_trace_server_pb2.py +374 -0
- mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
- mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
- mlflow/protos/facet_feature_statistics_pb2.py +296 -0
- mlflow/protos/internal_pb2.py +77 -0
- mlflow/protos/mlflow_artifacts_pb2.py +336 -0
- mlflow/protos/model_registry_pb2.py +1073 -0
- mlflow/protos/scalapb/__init__.py +0 -0
- mlflow/protos/scalapb/scalapb_pb2.py +104 -0
- mlflow/protos/service_pb2.py +2600 -0
- mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
- mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
- mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
- mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
- mlflow/py.typed +0 -0
- mlflow/pydantic_ai/__init__.py +57 -0
- mlflow/pydantic_ai/autolog.py +173 -0
- mlflow/pyfunc/__init__.py +3844 -0
- mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
- mlflow/pyfunc/backend.py +523 -0
- mlflow/pyfunc/context.py +78 -0
- mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
- mlflow/pyfunc/loaders/__init__.py +7 -0
- mlflow/pyfunc/loaders/chat_agent.py +117 -0
- mlflow/pyfunc/loaders/chat_model.py +125 -0
- mlflow/pyfunc/loaders/code_model.py +31 -0
- mlflow/pyfunc/loaders/responses_agent.py +112 -0
- mlflow/pyfunc/mlserver.py +46 -0
- mlflow/pyfunc/model.py +1473 -0
- mlflow/pyfunc/scoring_server/__init__.py +604 -0
- mlflow/pyfunc/scoring_server/app.py +7 -0
- mlflow/pyfunc/scoring_server/client.py +146 -0
- mlflow/pyfunc/spark_model_cache.py +48 -0
- mlflow/pyfunc/stdin_server.py +44 -0
- mlflow/pyfunc/utils/__init__.py +3 -0
- mlflow/pyfunc/utils/data_validation.py +224 -0
- mlflow/pyfunc/utils/environment.py +22 -0
- mlflow/pyfunc/utils/input_converter.py +47 -0
- mlflow/pyfunc/utils/serving_data_parser.py +11 -0
- mlflow/pytorch/__init__.py +1171 -0
- mlflow/pytorch/_lightning_autolog.py +580 -0
- mlflow/pytorch/_pytorch_autolog.py +50 -0
- mlflow/pytorch/pickle_module.py +35 -0
- mlflow/rfunc/__init__.py +42 -0
- mlflow/rfunc/backend.py +134 -0
- mlflow/runs.py +89 -0
- mlflow/server/__init__.py +302 -0
- mlflow/server/auth/__init__.py +1224 -0
- mlflow/server/auth/__main__.py +4 -0
- mlflow/server/auth/basic_auth.ini +6 -0
- mlflow/server/auth/cli.py +11 -0
- mlflow/server/auth/client.py +537 -0
- mlflow/server/auth/config.py +34 -0
- mlflow/server/auth/db/__init__.py +0 -0
- mlflow/server/auth/db/cli.py +18 -0
- mlflow/server/auth/db/migrations/__init__.py +0 -0
- mlflow/server/auth/db/migrations/alembic.ini +110 -0
- mlflow/server/auth/db/migrations/env.py +76 -0
- mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
- mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
- mlflow/server/auth/db/models.py +67 -0
- mlflow/server/auth/db/utils.py +37 -0
- mlflow/server/auth/entities.py +165 -0
- mlflow/server/auth/logo.py +14 -0
- mlflow/server/auth/permissions.py +65 -0
- mlflow/server/auth/routes.py +18 -0
- mlflow/server/auth/sqlalchemy_store.py +263 -0
- mlflow/server/graphql/__init__.py +0 -0
- mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
- mlflow/server/graphql/graphql_custom_scalars.py +24 -0
- mlflow/server/graphql/graphql_errors.py +15 -0
- mlflow/server/graphql/graphql_no_batching.py +89 -0
- mlflow/server/graphql/graphql_schema_extensions.py +74 -0
- mlflow/server/handlers.py +3217 -0
- mlflow/server/prometheus_exporter.py +17 -0
- mlflow/server/validation.py +30 -0
- mlflow/shap/__init__.py +691 -0
- mlflow/sklearn/__init__.py +1994 -0
- mlflow/sklearn/utils.py +1041 -0
- mlflow/smolagents/__init__.py +66 -0
- mlflow/smolagents/autolog.py +139 -0
- mlflow/smolagents/chat.py +29 -0
- mlflow/store/__init__.py +10 -0
- mlflow/store/_unity_catalog/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/constants.py +2 -0
- mlflow/store/_unity_catalog/registry/__init__.py +6 -0
- mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
- mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
- mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
- mlflow/store/_unity_catalog/registry/utils.py +121 -0
- mlflow/store/artifact/__init__.py +0 -0
- mlflow/store/artifact/artifact_repo.py +472 -0
- mlflow/store/artifact/artifact_repository_registry.py +154 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
- mlflow/store/artifact/cli.py +141 -0
- mlflow/store/artifact/cloud_artifact_repo.py +332 -0
- mlflow/store/artifact/databricks_artifact_repo.py +729 -0
- mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
- mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
- mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
- mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
- mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
- mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
- mlflow/store/artifact/ftp_artifact_repo.py +132 -0
- mlflow/store/artifact/gcs_artifact_repo.py +296 -0
- mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
- mlflow/store/artifact/http_artifact_repo.py +218 -0
- mlflow/store/artifact/local_artifact_repo.py +142 -0
- mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
- mlflow/store/artifact/models_artifact_repo.py +259 -0
- mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
- mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
- mlflow/store/artifact/r2_artifact_repo.py +70 -0
- mlflow/store/artifact/runs_artifact_repo.py +265 -0
- mlflow/store/artifact/s3_artifact_repo.py +330 -0
- mlflow/store/artifact/sftp_artifact_repo.py +141 -0
- mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
- mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
- mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
- mlflow/store/artifact/utils/__init__.py +0 -0
- mlflow/store/artifact/utils/models.py +148 -0
- mlflow/store/db/__init__.py +0 -0
- mlflow/store/db/base_sql_model.py +3 -0
- mlflow/store/db/db_types.py +10 -0
- mlflow/store/db/utils.py +314 -0
- mlflow/store/db_migrations/__init__.py +0 -0
- mlflow/store/db_migrations/alembic.ini +74 -0
- mlflow/store/db_migrations/env.py +84 -0
- mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
- mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
- mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
- mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
- mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
- mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
- mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
- mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
- mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
- mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
- mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
- mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
- mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
- mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
- mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
- mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
- mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
- mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
- mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
- mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
- mlflow/store/db_migrations/versions/__init__.py +0 -0
- mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
- mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
- mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
- mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
- mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
- mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
- mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
- mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
- mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
- mlflow/store/entities/__init__.py +3 -0
- mlflow/store/entities/paged_list.py +18 -0
- mlflow/store/model_registry/__init__.py +10 -0
- mlflow/store/model_registry/abstract_store.py +1081 -0
- mlflow/store/model_registry/base_rest_store.py +44 -0
- mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
- mlflow/store/model_registry/dbmodels/__init__.py +0 -0
- mlflow/store/model_registry/dbmodels/models.py +206 -0
- mlflow/store/model_registry/file_store.py +1091 -0
- mlflow/store/model_registry/rest_store.py +481 -0
- mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
- mlflow/store/tracking/__init__.py +23 -0
- mlflow/store/tracking/abstract_store.py +816 -0
- mlflow/store/tracking/dbmodels/__init__.py +0 -0
- mlflow/store/tracking/dbmodels/initial_models.py +243 -0
- mlflow/store/tracking/dbmodels/models.py +1073 -0
- mlflow/store/tracking/file_store.py +2438 -0
- mlflow/store/tracking/postgres_managed_identity.py +146 -0
- mlflow/store/tracking/rest_store.py +1131 -0
- mlflow/store/tracking/sqlalchemy_store.py +2785 -0
- mlflow/system_metrics/__init__.py +61 -0
- mlflow/system_metrics/metrics/__init__.py +0 -0
- mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
- mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
- mlflow/system_metrics/metrics/disk_monitor.py +21 -0
- mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
- mlflow/system_metrics/metrics/network_monitor.py +34 -0
- mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
- mlflow/system_metrics/system_metrics_monitor.py +198 -0
- mlflow/tracing/__init__.py +16 -0
- mlflow/tracing/assessment.py +356 -0
- mlflow/tracing/client.py +531 -0
- mlflow/tracing/config.py +125 -0
- mlflow/tracing/constant.py +105 -0
- mlflow/tracing/destination.py +81 -0
- mlflow/tracing/display/__init__.py +40 -0
- mlflow/tracing/display/display_handler.py +196 -0
- mlflow/tracing/export/async_export_queue.py +186 -0
- mlflow/tracing/export/inference_table.py +138 -0
- mlflow/tracing/export/mlflow_v3.py +137 -0
- mlflow/tracing/export/utils.py +70 -0
- mlflow/tracing/fluent.py +1417 -0
- mlflow/tracing/processor/base_mlflow.py +199 -0
- mlflow/tracing/processor/inference_table.py +175 -0
- mlflow/tracing/processor/mlflow_v3.py +47 -0
- mlflow/tracing/processor/otel.py +73 -0
- mlflow/tracing/provider.py +487 -0
- mlflow/tracing/trace_manager.py +200 -0
- mlflow/tracing/utils/__init__.py +616 -0
- mlflow/tracing/utils/artifact_utils.py +28 -0
- mlflow/tracing/utils/copy.py +55 -0
- mlflow/tracing/utils/environment.py +55 -0
- mlflow/tracing/utils/exception.py +21 -0
- mlflow/tracing/utils/once.py +35 -0
- mlflow/tracing/utils/otlp.py +63 -0
- mlflow/tracing/utils/processor.py +54 -0
- mlflow/tracing/utils/search.py +292 -0
- mlflow/tracing/utils/timeout.py +250 -0
- mlflow/tracing/utils/token.py +19 -0
- mlflow/tracing/utils/truncation.py +124 -0
- mlflow/tracing/utils/warning.py +76 -0
- mlflow/tracking/__init__.py +39 -0
- mlflow/tracking/_model_registry/__init__.py +1 -0
- mlflow/tracking/_model_registry/client.py +764 -0
- mlflow/tracking/_model_registry/fluent.py +853 -0
- mlflow/tracking/_model_registry/registry.py +67 -0
- mlflow/tracking/_model_registry/utils.py +251 -0
- mlflow/tracking/_tracking_service/__init__.py +0 -0
- mlflow/tracking/_tracking_service/client.py +883 -0
- mlflow/tracking/_tracking_service/registry.py +56 -0
- mlflow/tracking/_tracking_service/utils.py +275 -0
- mlflow/tracking/artifact_utils.py +179 -0
- mlflow/tracking/client.py +5900 -0
- mlflow/tracking/context/__init__.py +0 -0
- mlflow/tracking/context/abstract_context.py +35 -0
- mlflow/tracking/context/databricks_cluster_context.py +15 -0
- mlflow/tracking/context/databricks_command_context.py +15 -0
- mlflow/tracking/context/databricks_job_context.py +49 -0
- mlflow/tracking/context/databricks_notebook_context.py +41 -0
- mlflow/tracking/context/databricks_repo_context.py +43 -0
- mlflow/tracking/context/default_context.py +51 -0
- mlflow/tracking/context/git_context.py +32 -0
- mlflow/tracking/context/registry.py +98 -0
- mlflow/tracking/context/system_environment_context.py +15 -0
- mlflow/tracking/default_experiment/__init__.py +1 -0
- mlflow/tracking/default_experiment/abstract_context.py +43 -0
- mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
- mlflow/tracking/default_experiment/registry.py +75 -0
- mlflow/tracking/fluent.py +3595 -0
- mlflow/tracking/metric_value_conversion_utils.py +93 -0
- mlflow/tracking/multimedia.py +206 -0
- mlflow/tracking/registry.py +86 -0
- mlflow/tracking/request_auth/__init__.py +0 -0
- mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
- mlflow/tracking/request_auth/registry.py +60 -0
- mlflow/tracking/request_header/__init__.py +0 -0
- mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
- mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
- mlflow/tracking/request_header/default_request_header_provider.py +17 -0
- mlflow/tracking/request_header/registry.py +79 -0
- mlflow/transformers/__init__.py +2982 -0
- mlflow/transformers/flavor_config.py +258 -0
- mlflow/transformers/hub_utils.py +83 -0
- mlflow/transformers/llm_inference_utils.py +468 -0
- mlflow/transformers/model_io.py +301 -0
- mlflow/transformers/peft.py +51 -0
- mlflow/transformers/signature.py +183 -0
- mlflow/transformers/torch_utils.py +55 -0
- mlflow/types/__init__.py +21 -0
- mlflow/types/agent.py +270 -0
- mlflow/types/chat.py +240 -0
- mlflow/types/llm.py +935 -0
- mlflow/types/responses.py +139 -0
- mlflow/types/responses_helpers.py +416 -0
- mlflow/types/schema.py +1505 -0
- mlflow/types/type_hints.py +647 -0
- mlflow/types/utils.py +753 -0
- mlflow/utils/__init__.py +283 -0
- mlflow/utils/_capture_modules.py +256 -0
- mlflow/utils/_capture_transformers_modules.py +75 -0
- mlflow/utils/_spark_utils.py +201 -0
- mlflow/utils/_unity_catalog_oss_utils.py +97 -0
- mlflow/utils/_unity_catalog_utils.py +479 -0
- mlflow/utils/annotations.py +218 -0
- mlflow/utils/arguments_utils.py +16 -0
- mlflow/utils/async_logging/__init__.py +1 -0
- mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
- mlflow/utils/async_logging/async_logging_queue.py +366 -0
- mlflow/utils/async_logging/run_artifact.py +38 -0
- mlflow/utils/async_logging/run_batch.py +58 -0
- mlflow/utils/async_logging/run_operations.py +49 -0
- mlflow/utils/autologging_utils/__init__.py +737 -0
- mlflow/utils/autologging_utils/client.py +432 -0
- mlflow/utils/autologging_utils/config.py +33 -0
- mlflow/utils/autologging_utils/events.py +294 -0
- mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
- mlflow/utils/autologging_utils/metrics_queue.py +71 -0
- mlflow/utils/autologging_utils/safety.py +1104 -0
- mlflow/utils/autologging_utils/versioning.py +95 -0
- mlflow/utils/checkpoint_utils.py +206 -0
- mlflow/utils/class_utils.py +6 -0
- mlflow/utils/cli_args.py +257 -0
- mlflow/utils/conda.py +354 -0
- mlflow/utils/credentials.py +231 -0
- mlflow/utils/data_utils.py +17 -0
- mlflow/utils/databricks_utils.py +1436 -0
- mlflow/utils/docstring_utils.py +477 -0
- mlflow/utils/doctor.py +133 -0
- mlflow/utils/download_cloud_file_chunk.py +43 -0
- mlflow/utils/env_manager.py +16 -0
- mlflow/utils/env_pack.py +131 -0
- mlflow/utils/environment.py +1009 -0
- mlflow/utils/exception_utils.py +14 -0
- mlflow/utils/file_utils.py +978 -0
- mlflow/utils/git_utils.py +77 -0
- mlflow/utils/gorilla.py +797 -0
- mlflow/utils/import_hooks/__init__.py +363 -0
- mlflow/utils/lazy_load.py +51 -0
- mlflow/utils/logging_utils.py +168 -0
- mlflow/utils/mime_type_utils.py +58 -0
- mlflow/utils/mlflow_tags.py +103 -0
- mlflow/utils/model_utils.py +486 -0
- mlflow/utils/name_utils.py +346 -0
- mlflow/utils/nfs_on_spark.py +62 -0
- mlflow/utils/openai_utils.py +164 -0
- mlflow/utils/os.py +12 -0
- mlflow/utils/oss_registry_utils.py +29 -0
- mlflow/utils/plugins.py +17 -0
- mlflow/utils/process.py +182 -0
- mlflow/utils/promptlab_utils.py +146 -0
- mlflow/utils/proto_json_utils.py +743 -0
- mlflow/utils/pydantic_utils.py +54 -0
- mlflow/utils/request_utils.py +279 -0
- mlflow/utils/requirements_utils.py +704 -0
- mlflow/utils/rest_utils.py +673 -0
- mlflow/utils/search_logged_model_utils.py +127 -0
- mlflow/utils/search_utils.py +2111 -0
- mlflow/utils/secure_loading.py +221 -0
- mlflow/utils/security_validation.py +384 -0
- mlflow/utils/server_cli_utils.py +61 -0
- mlflow/utils/spark_utils.py +15 -0
- mlflow/utils/string_utils.py +138 -0
- mlflow/utils/thread_utils.py +63 -0
- mlflow/utils/time.py +54 -0
- mlflow/utils/timeout.py +42 -0
- mlflow/utils/uri.py +572 -0
- mlflow/utils/validation.py +662 -0
- mlflow/utils/virtualenv.py +458 -0
- mlflow/utils/warnings_utils.py +25 -0
- mlflow/utils/yaml_utils.py +179 -0
- mlflow/version.py +24 -0
mlflow/utils/databricks_utils.py
@@ -0,0 +1,1436 @@
|
|
1
|
+
import functools
|
2
|
+
import getpass
|
3
|
+
import json
|
4
|
+
import logging
|
5
|
+
import os
|
6
|
+
import platform
|
7
|
+
import subprocess
|
8
|
+
import time
|
9
|
+
from dataclasses import dataclass
|
10
|
+
from typing import TYPE_CHECKING, NamedTuple, Optional, TypeVar
|
11
|
+
|
12
|
+
from mlflow.utils.logging_utils import eprint
|
13
|
+
from mlflow.utils.request_utils import augmented_raise_for_status
|
14
|
+
|
15
|
+
if TYPE_CHECKING:
|
16
|
+
from pyspark.sql.connect.session import SparkSession as SparkConnectSession
|
17
|
+
|
18
|
+
|
19
|
+
import mlflow.utils
|
20
|
+
from mlflow.environment_variables import (
|
21
|
+
MLFLOW_ENABLE_DB_SDK,
|
22
|
+
MLFLOW_TRACKING_URI,
|
23
|
+
)
|
24
|
+
from mlflow.exceptions import MlflowException
|
25
|
+
from mlflow.legacy_databricks_cli.configure.provider import (
|
26
|
+
DatabricksConfig,
|
27
|
+
DatabricksConfigProvider,
|
28
|
+
DatabricksModelServingConfigProvider,
|
29
|
+
EnvironmentVariableConfigProvider,
|
30
|
+
ProfileConfigProvider,
|
31
|
+
SparkTaskContextConfigProvider,
|
32
|
+
)
|
33
|
+
from mlflow.utils._spark_utils import _get_active_spark_session
|
34
|
+
from mlflow.utils.rest_utils import MlflowHostCreds, http_request
|
35
|
+
from mlflow.utils.uri import (
|
36
|
+
_DATABRICKS_UNITY_CATALOG_SCHEME,
|
37
|
+
get_db_info_from_uri,
|
38
|
+
is_databricks_uri,
|
39
|
+
)
|
40
|
+
|
41
|
+
_logger = logging.getLogger(__name__)
|
42
|
+
|
43
|
+
|
44
|
+
_MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH = "/var/credentials-secret/model-dependencies-oauth-token"
|
45
|
+
|
46
|
+
|
47
|
+
def _use_repl_context_if_available(
|
48
|
+
name: str,
|
49
|
+
*,
|
50
|
+
ignore_none: bool = False,
|
51
|
+
):
|
52
|
+
"""Creates a decorator to insert a short circuit that returns the specified REPL context
|
53
|
+
attribute if it's available.
|
54
|
+
|
55
|
+
Args:
|
56
|
+
name: Attribute name (e.g. "apiUrl").
|
57
|
+
ignore_none: If True, use the original function if the REPL context attribute exists but
|
58
|
+
is None.
|
59
|
+
|
60
|
+
Returns:
|
61
|
+
Decorator to insert the short circuit.
|
62
|
+
"""
|
63
|
+
|
64
|
+
def decorator(f):
|
65
|
+
@functools.wraps(f)
|
66
|
+
def wrapper(*args, **kwargs):
|
67
|
+
try:
|
68
|
+
from dbruntime.databricks_repl_context import get_context
|
69
|
+
|
70
|
+
context = get_context()
|
71
|
+
if context is not None and hasattr(context, name):
|
72
|
+
attr = getattr(context, name)
|
73
|
+
if attr is None and ignore_none:
|
74
|
+
# do nothing and continue to the original function
|
75
|
+
pass
|
76
|
+
else:
|
77
|
+
return attr
|
78
|
+
except Exception:
|
79
|
+
pass
|
80
|
+
return f(*args, **kwargs)
|
81
|
+
|
82
|
+
return wrapper
|
83
|
+
|
84
|
+
return decorator
|
85
|
+
|
86
|
+
|
87
|
+
def get_mlflow_credential_context_by_run_id(run_id):
|
88
|
+
from mlflow.tracking.artifact_utils import get_artifact_uri
|
89
|
+
from mlflow.utils.uri import get_databricks_profile_uri_from_artifact_uri
|
90
|
+
|
91
|
+
run_root_artifact_uri = get_artifact_uri(run_id=run_id)
|
92
|
+
profile = get_databricks_profile_uri_from_artifact_uri(run_root_artifact_uri)
|
93
|
+
return MlflowCredentialContext(profile)
|
94
|
+
|
95
|
+
|
96
|
+
class MlflowCredentialContext:
|
97
|
+
"""Sets and clears credentials on a context using the provided profile URL."""
|
98
|
+
|
99
|
+
def __init__(self, databricks_profile_url):
|
100
|
+
self.databricks_profile_url = databricks_profile_url or "databricks"
|
101
|
+
self.db_utils = _get_dbutils()
|
102
|
+
|
103
|
+
def __enter__(self):
|
104
|
+
db_creds = _get_databricks_creds_config(self.databricks_profile_url)
|
105
|
+
self.db_utils.notebook.entry_point.putMlflowProperties(
|
106
|
+
db_creds.host,
|
107
|
+
db_creds.insecure,
|
108
|
+
db_creds.token,
|
109
|
+
db_creds.username,
|
110
|
+
db_creds.password,
|
111
|
+
)
|
112
|
+
|
113
|
+
def __exit__(self, exc_type, exc_value, exc_traceback):
|
114
|
+
self.db_utils.notebook.entry_point.clearMlflowProperties()
|
115
|
+
|
116
|
+
|
117
|
+
def _get_dbutils():
|
118
|
+
try:
|
119
|
+
import IPython
|
120
|
+
|
121
|
+
ip_shell = IPython.get_ipython()
|
122
|
+
if ip_shell is None:
|
123
|
+
raise _NoDbutilsError
|
124
|
+
return ip_shell.ns_table["user_global"]["dbutils"]
|
125
|
+
except ImportError:
|
126
|
+
raise _NoDbutilsError
|
127
|
+
except KeyError:
|
128
|
+
raise _NoDbutilsError
|
129
|
+
|
130
|
+
|
131
|
+
class _NoDbutilsError(Exception):
|
132
|
+
pass
|
133
|
+
|
134
|
+
|
135
|
+
def _get_java_dbutils():
|
136
|
+
dbutils = _get_dbutils()
|
137
|
+
return dbutils.notebook.entry_point.getDbutils()
|
138
|
+
|
139
|
+
|
140
|
+
def _get_command_context():
|
141
|
+
return _get_java_dbutils().notebook().getContext()
|
142
|
+
|
143
|
+
|
144
|
+
def _get_extra_context(context_key):
|
145
|
+
opt = _get_command_context().extraContext().get(context_key)
|
146
|
+
return opt.get() if opt.isDefined() else None
|
147
|
+
|
148
|
+
|
149
|
+
def _get_context_tag(context_tag_key):
|
150
|
+
try:
|
151
|
+
tag_opt = _get_command_context().tags().get(context_tag_key)
|
152
|
+
if tag_opt.isDefined():
|
153
|
+
return tag_opt.get()
|
154
|
+
except Exception:
|
155
|
+
pass
|
156
|
+
|
157
|
+
return None
|
158
|
+
|
159
|
+
|
160
|
+
@_use_repl_context_if_available("aclPathOfAclRoot")
|
161
|
+
def acl_path_of_acl_root():
|
162
|
+
try:
|
163
|
+
return _get_command_context().aclPathOfAclRoot().get()
|
164
|
+
except Exception:
|
165
|
+
return _get_extra_context("aclPathOfAclRoot")
|
166
|
+
|
167
|
+
|
168
|
+
def _get_property_from_spark_context(key):
|
169
|
+
try:
|
170
|
+
from pyspark import TaskContext
|
171
|
+
|
172
|
+
task_context = TaskContext.get()
|
173
|
+
if task_context:
|
174
|
+
return task_context.getLocalProperty(key)
|
175
|
+
except Exception:
|
176
|
+
return None
|
177
|
+
|
178
|
+
|
179
|
+
def is_databricks_default_tracking_uri(tracking_uri):
|
180
|
+
return tracking_uri.lower().strip() == "databricks"
|
181
|
+
|
182
|
+
|
183
|
+
@_use_repl_context_if_available("isInNotebook")
|
184
|
+
def is_in_databricks_notebook():
|
185
|
+
if _get_property_from_spark_context("spark.databricks.notebook.id") is not None:
|
186
|
+
return True
|
187
|
+
try:
|
188
|
+
return path.startswith("/workspace") if (path := acl_path_of_acl_root()) else False
|
189
|
+
except Exception:
|
190
|
+
return False
|
191
|
+
|
192
|
+
|
193
|
+
@_use_repl_context_if_available("isInJob")
|
194
|
+
def is_in_databricks_job():
|
195
|
+
try:
|
196
|
+
return get_job_id() is not None and get_job_run_id() is not None
|
197
|
+
except Exception:
|
198
|
+
return False
|
199
|
+
|
200
|
+
|
201
|
+
def is_in_databricks_model_serving_environment():
|
202
|
+
"""
|
203
|
+
Check if the code is running in Databricks Model Serving environment.
|
204
|
+
The environment variable set by Databricks when starting the serving container.
|
205
|
+
"""
|
206
|
+
val = (
|
207
|
+
os.environ.get("IS_IN_DB_MODEL_SERVING_ENV")
|
208
|
+
# Checking the old env var name for backward compatibility. The env var was renamed once
|
209
|
+
# to fix a model loading issue, but we still need to support it for a while.
|
210
|
+
# TODO: Remove this once the new env var is fully rolled out.
|
211
|
+
or os.environ.get("IS_IN_DATABRICKS_MODEL_SERVING_ENV")
|
212
|
+
or "false"
|
213
|
+
)
|
214
|
+
return val.lower() == "true"
|
215
|
+
|
216
|
+
|
217
|
+
def is_mlflow_tracing_enabled_in_model_serving() -> bool:
|
218
|
+
"""
|
219
|
+
This environment variable guards tracing behaviors for models in databricks
|
220
|
+
model serving. Tracing in serving is only enabled when this env var is true.
|
221
|
+
"""
|
222
|
+
return os.environ.get("ENABLE_MLFLOW_TRACING", "false").lower() == "true"
|
223
|
+
|
224
|
+
|
225
|
+
# this should only be the case when we are in model serving environment
|
226
|
+
# and OAuth token file exists in specified path
|
227
|
+
def should_fetch_model_serving_environment_oauth():
|
228
|
+
return (
|
229
|
+
is_in_databricks_model_serving_environment()
|
230
|
+
and os.path.exists(_MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH)
|
231
|
+
and os.path.isfile(_MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH)
|
232
|
+
)
|
233
|
+
|
234
|
+
|
235
|
+
def is_in_databricks_repo():
|
236
|
+
try:
|
237
|
+
return get_git_repo_relative_path() is not None
|
238
|
+
except Exception:
|
239
|
+
return False
|
240
|
+
|
241
|
+
|
242
|
+
def is_in_databricks_repo_notebook():
|
243
|
+
try:
|
244
|
+
path = get_notebook_path()
|
245
|
+
return path is not None and path.startswith("/Repos")
|
246
|
+
except Exception:
|
247
|
+
return False
|
248
|
+
|
249
|
+
|
250
|
+
_DATABRICKS_VERSION_FILE_PATH = "/databricks/DBR_VERSION"
|
251
|
+
|
252
|
+
|
253
|
+
def get_databricks_runtime_version():
|
254
|
+
if ver := os.environ.get("DATABRICKS_RUNTIME_VERSION"):
|
255
|
+
return ver
|
256
|
+
if os.path.exists(_DATABRICKS_VERSION_FILE_PATH):
|
257
|
+
# In Databricks DCS cluster, it doesn't have DATABRICKS_RUNTIME_VERSION
|
258
|
+
# environment variable, we have to read version from the version file.
|
259
|
+
with open(_DATABRICKS_VERSION_FILE_PATH) as f:
|
260
|
+
return f.read().strip()
|
261
|
+
return None
|
262
|
+
|
263
|
+
|
264
|
+
def is_in_databricks_runtime():
|
265
|
+
return get_databricks_runtime_version() is not None
|
266
|
+
|
267
|
+
|
268
|
+
def is_in_databricks_serverless_runtime():
|
269
|
+
dbr_version = get_databricks_runtime_version()
|
270
|
+
return dbr_version and dbr_version.startswith("client.")
|
271
|
+
|
272
|
+
|
273
|
+
def is_in_databricks_shared_cluster_runtime():
|
274
|
+
from mlflow.utils.spark_utils import is_spark_connect_mode
|
275
|
+
|
276
|
+
return (
|
277
|
+
is_in_databricks_runtime()
|
278
|
+
and is_spark_connect_mode()
|
279
|
+
and not is_in_databricks_serverless_runtime()
|
280
|
+
)
|
281
|
+
|
282
|
+
|
283
|
+
def is_databricks_connect(spark=None):
|
284
|
+
"""
|
285
|
+
Return True if current Spark-connect client connects to Databricks cluster.
|
286
|
+
"""
|
287
|
+
from mlflow.utils.spark_utils import is_spark_connect_mode
|
288
|
+
|
289
|
+
if is_in_databricks_serverless_runtime() or is_in_databricks_shared_cluster_runtime():
|
290
|
+
return True
|
291
|
+
|
292
|
+
spark = spark or _get_active_spark_session()
|
293
|
+
if spark is None:
|
294
|
+
return False
|
295
|
+
|
296
|
+
if not is_spark_connect_mode():
|
297
|
+
return False
|
298
|
+
|
299
|
+
if hasattr(spark.client, "metadata"):
|
300
|
+
metadata = spark.client.metadata()
|
301
|
+
else:
|
302
|
+
metadata = spark.client._builder.metadata()
|
303
|
+
|
304
|
+
return any(k in ["x-databricks-session-id", "x-databricks-cluster-id"] for k, v in metadata)
|
305
|
+
|
306
|
+
|
307
|
+
@dataclass
|
308
|
+
class DBConnectUDFSandboxInfo:
|
309
|
+
spark: "SparkConnectSession"
|
310
|
+
image_version: str
|
311
|
+
runtime_version: str
|
312
|
+
platform_machine: str
|
313
|
+
mlflow_version: str
|
314
|
+
|
315
|
+
|
316
|
+
_dbconnect_udf_sandbox_info_cache: Optional[DBConnectUDFSandboxInfo] = None
|
317
|
+
|
318
|
+
|
319
|
+
def get_dbconnect_udf_sandbox_info(spark):
|
320
|
+
"""
|
321
|
+
Get Databricks UDF sandbox info which includes the following fields:
|
322
|
+
- image_version like
|
323
|
+
'{major_version}.{minor_version}' or 'client.{major_version}.{minor_version}'
|
324
|
+
- runtime_version like '{major_version}.{minor_version}'
|
325
|
+
- platform_machine like 'x86_64' or 'aarch64'
|
326
|
+
- mlflow_version
|
327
|
+
"""
|
328
|
+
global _dbconnect_udf_sandbox_info_cache
|
329
|
+
from pyspark.sql.functions import pandas_udf
|
330
|
+
|
331
|
+
if (
|
332
|
+
_dbconnect_udf_sandbox_info_cache is not None
|
333
|
+
and spark is _dbconnect_udf_sandbox_info_cache.spark
|
334
|
+
):
|
335
|
+
return _dbconnect_udf_sandbox_info_cache
|
336
|
+
|
337
|
+
# version is like '15.4.x-scala2.12'
|
338
|
+
version = spark.sql("SELECT current_version().dbr_version").collect()[0][0]
|
339
|
+
major, minor, *_rest = version.split(".")
|
340
|
+
runtime_version = f"{major}.{minor}"
|
341
|
+
|
342
|
+
# For Databricks Serverless python REPL,
|
343
|
+
# the UDF sandbox runs on client image, which has version like 'client.1.1'
|
344
|
+
# in other cases, UDF sandbox runs on databricks runtime image with version like '15.4'
|
345
|
+
if is_in_databricks_runtime():
|
346
|
+
_dbconnect_udf_sandbox_info_cache = DBConnectUDFSandboxInfo(
|
347
|
+
spark=_get_active_spark_session(),
|
348
|
+
runtime_version=runtime_version,
|
349
|
+
image_version=get_databricks_runtime_version(),
|
350
|
+
platform_machine=platform.machine(),
|
351
|
+
# In databricks runtime, driver and executor should have the
|
352
|
+
# same version.
|
353
|
+
mlflow_version=mlflow.__version__,
|
354
|
+
)
|
355
|
+
else:
|
356
|
+
image_version = runtime_version
|
357
|
+
|
358
|
+
@pandas_udf("string")
|
359
|
+
def f(_):
|
360
|
+
import pandas as pd
|
361
|
+
|
362
|
+
platform_machine = platform.machine()
|
363
|
+
|
364
|
+
try:
|
365
|
+
import mlflow
|
366
|
+
|
367
|
+
mlflow_version = mlflow.__version__
|
368
|
+
except ImportError:
|
369
|
+
mlflow_version = ""
|
370
|
+
|
371
|
+
return pd.Series([f"{platform_machine}\n{mlflow_version}"])
|
372
|
+
|
373
|
+
platform_machine, mlflow_version = (
|
374
|
+
spark.range(1).select(f("id")).collect()[0][0].split("\n")
|
375
|
+
)
|
376
|
+
if mlflow_version == "":
|
377
|
+
mlflow_version = None
|
378
|
+
_dbconnect_udf_sandbox_info_cache = DBConnectUDFSandboxInfo(
|
379
|
+
spark=spark,
|
380
|
+
image_version=image_version,
|
381
|
+
runtime_version=runtime_version,
|
382
|
+
platform_machine=platform_machine,
|
383
|
+
mlflow_version=mlflow_version,
|
384
|
+
)
|
385
|
+
|
386
|
+
return _dbconnect_udf_sandbox_info_cache
|
387
|
+
|
388
|
+
|
389
|
+
def is_databricks_serverless(spark):
|
390
|
+
"""
|
391
|
+
Return True if running on Databricks Serverless notebook or
|
392
|
+
on Databricks Connect client that connects to Databricks Serverless.
|
393
|
+
"""
|
394
|
+
from mlflow.utils.spark_utils import is_spark_connect_mode
|
395
|
+
|
396
|
+
if not is_spark_connect_mode():
|
397
|
+
return False
|
398
|
+
|
399
|
+
if hasattr(spark.client, "metadata"):
|
400
|
+
metadata = spark.client.metadata()
|
401
|
+
else:
|
402
|
+
metadata = spark.client._builder.metadata()
|
403
|
+
|
404
|
+
return any(k == "x-databricks-session-id" for k, v in metadata)
|
405
|
+
|
406
|
+
|
407
|
+
def is_dbfs_fuse_available():
|
408
|
+
if not is_in_databricks_runtime():
|
409
|
+
return False
|
410
|
+
|
411
|
+
try:
|
412
|
+
return (
|
413
|
+
subprocess.call(
|
414
|
+
["mountpoint", "/dbfs"],
|
415
|
+
stderr=subprocess.DEVNULL,
|
416
|
+
stdout=subprocess.DEVNULL,
|
417
|
+
)
|
418
|
+
== 0
|
419
|
+
)
|
420
|
+
except Exception:
|
421
|
+
return False
|
422
|
+
|
423
|
+
|
424
|
+
def is_uc_volume_fuse_available():
|
425
|
+
try:
|
426
|
+
return (
|
427
|
+
subprocess.call(
|
428
|
+
["mountpoint", "/Volumes"],
|
429
|
+
stderr=subprocess.DEVNULL,
|
430
|
+
stdout=subprocess.DEVNULL,
|
431
|
+
)
|
432
|
+
== 0
|
433
|
+
)
|
434
|
+
except Exception:
|
435
|
+
return False
|
436
|
+
|
437
|
+
|
438
|
+
@_use_repl_context_if_available("isInCluster")
|
439
|
+
def is_in_cluster():
|
440
|
+
try:
|
441
|
+
spark_session = _get_active_spark_session()
|
442
|
+
return (
|
443
|
+
spark_session is not None
|
444
|
+
and spark_session.conf.get("spark.databricks.clusterUsageTags.clusterId", None)
|
445
|
+
is not None
|
446
|
+
)
|
447
|
+
except Exception:
|
448
|
+
return False
|
449
|
+
|
450
|
+
|
451
|
+
@_use_repl_context_if_available("notebookId")
|
452
|
+
def get_notebook_id():
|
453
|
+
"""Should only be called if is_in_databricks_notebook is true"""
|
454
|
+
if notebook_id := _get_property_from_spark_context("spark.databricks.notebook.id"):
|
455
|
+
return notebook_id
|
456
|
+
if (path := acl_path_of_acl_root()) and path.startswith("/workspace"):
|
457
|
+
return path.split("/")[-1]
|
458
|
+
return None
|
459
|
+
|
460
|
+
|
461
|
+
@_use_repl_context_if_available("notebookPath")
|
462
|
+
def get_notebook_path():
|
463
|
+
"""Should only be called if is_in_databricks_notebook is true"""
|
464
|
+
path = _get_property_from_spark_context("spark.databricks.notebook.path")
|
465
|
+
if path is not None:
|
466
|
+
return path
|
467
|
+
try:
|
468
|
+
return _get_command_context().notebookPath().get()
|
469
|
+
except Exception:
|
470
|
+
return _get_extra_context("notebook_path")
|
471
|
+
|
472
|
+
|
473
|
+
@_use_repl_context_if_available("clusterId")
|
474
|
+
def get_cluster_id():
|
475
|
+
spark_session = _get_active_spark_session()
|
476
|
+
if spark_session is None:
|
477
|
+
return None
|
478
|
+
return spark_session.conf.get("spark.databricks.clusterUsageTags.clusterId", None)
|
479
|
+
|
480
|
+
|
481
|
+
@_use_repl_context_if_available("jobGroupId")
|
482
|
+
def get_job_group_id():
|
483
|
+
try:
|
484
|
+
dbutils = _get_dbutils()
|
485
|
+
job_group_id = dbutils.entry_point.getJobGroupId()
|
486
|
+
if job_group_id is not None:
|
487
|
+
return job_group_id
|
488
|
+
except Exception:
|
489
|
+
return None
|
490
|
+
|
491
|
+
|
492
|
+
@_use_repl_context_if_available("replId")
|
493
|
+
def get_repl_id():
|
494
|
+
"""
|
495
|
+
Returns:
|
496
|
+
The ID of the current Databricks Python REPL.
|
497
|
+
"""
|
498
|
+
# Attempt to fetch the REPL ID from the Python REPL's entrypoint object. This REPL ID
|
499
|
+
# is guaranteed to be set upon REPL startup in DBR / MLR 9.0
|
500
|
+
try:
|
501
|
+
dbutils = _get_dbutils()
|
502
|
+
repl_id = dbutils.entry_point.getReplId()
|
503
|
+
if repl_id is not None:
|
504
|
+
return repl_id
|
505
|
+
except Exception:
|
506
|
+
pass
|
507
|
+
|
508
|
+
# If the REPL ID entrypoint property is unavailable due to an older runtime version (< 9.0),
|
509
|
+
# attempt to fetch the REPL ID from the Spark Context. This property may not be available
|
510
|
+
# until several seconds after REPL startup
|
511
|
+
try:
|
512
|
+
from pyspark import SparkContext
|
513
|
+
|
514
|
+
repl_id = SparkContext.getOrCreate().getLocalProperty("spark.databricks.replId")
|
515
|
+
if repl_id is not None:
|
516
|
+
return repl_id
|
517
|
+
except Exception:
|
518
|
+
pass
|
519
|
+
|
520
|
+
|
521
|
+
@_use_repl_context_if_available("jobId")
|
522
|
+
def get_job_id():
|
523
|
+
try:
|
524
|
+
return _get_command_context().jobId().get()
|
525
|
+
except Exception:
|
526
|
+
return _get_context_tag("jobId")
|
527
|
+
|
528
|
+
|
529
|
+
@_use_repl_context_if_available("idInJob")
|
530
|
+
def get_job_run_id():
|
531
|
+
try:
|
532
|
+
return _get_command_context().idInJob().get()
|
533
|
+
except Exception:
|
534
|
+
return _get_context_tag("idInJob")
|
535
|
+
|
536
|
+
|
537
|
+
@_use_repl_context_if_available("jobTaskType")
|
538
|
+
def get_job_type():
|
539
|
+
"""Should only be called if is_in_databricks_job is true"""
|
540
|
+
try:
|
541
|
+
return _get_command_context().jobTaskType().get()
|
542
|
+
except Exception:
|
543
|
+
return _get_context_tag("jobTaskType")
|
544
|
+
|
545
|
+
|
546
|
+
@_use_repl_context_if_available("jobType")
|
547
|
+
def get_job_type_info():
|
548
|
+
try:
|
549
|
+
return _get_context_tag("jobType")
|
550
|
+
except Exception:
|
551
|
+
return None
|
552
|
+
|
553
|
+
|
554
|
+
@_use_repl_context_if_available("commandRunId")
|
555
|
+
def get_command_run_id():
|
556
|
+
try:
|
557
|
+
return _get_command_context().commandRunId().get()
|
558
|
+
except Exception:
|
559
|
+
# Older runtimes may not have the commandRunId available
|
560
|
+
return None
|
561
|
+
|
562
|
+
|
563
|
+
@_use_repl_context_if_available("workloadId")
|
564
|
+
def get_workload_id():
|
565
|
+
try:
|
566
|
+
return _get_command_context().workloadId().get()
|
567
|
+
except Exception:
|
568
|
+
return _get_context_tag("workloadId")
|
569
|
+
|
570
|
+
|
571
|
+
@_use_repl_context_if_available("workloadClass")
|
572
|
+
def get_workload_class():
|
573
|
+
try:
|
574
|
+
return _get_command_context().workloadClass().get()
|
575
|
+
except Exception:
|
576
|
+
return _get_context_tag("workloadClass")
|
577
|
+
|
578
|
+
|
579
|
+
@_use_repl_context_if_available("apiUrl")
|
580
|
+
def get_webapp_url():
|
581
|
+
"""Should only be called if is_in_databricks_notebook or is_in_databricks_jobs is true"""
|
582
|
+
url = _get_property_from_spark_context("spark.databricks.api.url")
|
583
|
+
if url is not None:
|
584
|
+
return url
|
585
|
+
try:
|
586
|
+
return _get_command_context().apiUrl().get()
|
587
|
+
except Exception:
|
588
|
+
return _get_extra_context("api_url")
|
589
|
+
|
590
|
+
|
591
|
+
@_use_repl_context_if_available("workspaceId")
|
592
|
+
def get_workspace_id():
|
593
|
+
try:
|
594
|
+
return _get_command_context().workspaceId().get()
|
595
|
+
except Exception:
|
596
|
+
return _get_context_tag("orgId")
|
597
|
+
|
598
|
+
|
599
|
+
@_use_repl_context_if_available("browserHostName")
|
600
|
+
def get_browser_hostname():
|
601
|
+
try:
|
602
|
+
return _get_command_context().browserHostName().get()
|
603
|
+
except Exception:
|
604
|
+
return _get_context_tag("browserHostName")
|
605
|
+
|
606
|
+
|
607
|
+
def get_workspace_info_from_dbutils():
|
608
|
+
try:
|
609
|
+
dbutils = _get_dbutils()
|
610
|
+
if dbutils:
|
611
|
+
browser_hostname = get_browser_hostname()
|
612
|
+
workspace_host = "https://" + browser_hostname if browser_hostname else get_webapp_url()
|
613
|
+
workspace_id = get_workspace_id()
|
614
|
+
return workspace_host, workspace_id
|
615
|
+
except Exception:
|
616
|
+
pass
|
617
|
+
return None, None
|
618
|
+
|
619
|
+
|
620
|
+
@_use_repl_context_if_available("workspaceUrl", ignore_none=True)
|
621
|
+
def _get_workspace_url():
|
622
|
+
try:
|
623
|
+
if spark_session := _get_active_spark_session():
|
624
|
+
if workspace_url := spark_session.conf.get("spark.databricks.workspaceUrl", None):
|
625
|
+
return workspace_url
|
626
|
+
except Exception:
|
627
|
+
return None
|
628
|
+
|
629
|
+
|
630
|
+
def get_workspace_url():
|
631
|
+
if url := _get_workspace_url():
|
632
|
+
return f"https://{url}" if not url.startswith("https://") else url
|
633
|
+
return None
|
634
|
+
|
635
|
+
|
636
|
+
def warn_on_deprecated_cross_workspace_registry_uri(registry_uri):
|
637
|
+
workspace_host, workspace_id = get_workspace_info_from_databricks_secrets(
|
638
|
+
tracking_uri=registry_uri
|
639
|
+
)
|
640
|
+
if workspace_host is not None or workspace_id is not None:
|
641
|
+
_logger.warning(
|
642
|
+
"Accessing remote workspace model registries using registry URIs of the form "
|
643
|
+
"'databricks://scope:prefix', or by loading models via URIs of the form "
|
644
|
+
"'models://scope:prefix@databricks/model-name/stage-or-version', is deprecated. "
|
645
|
+
"Use Models in Unity Catalog instead for easy cross-workspace model access, with "
|
646
|
+
"granular per-user audit logging and no extra setup required. See "
|
647
|
+
"https://docs.databricks.com/machine-learning/manage-model-lifecycle/index.html "
|
648
|
+
"for more details."
|
649
|
+
)
|
650
|
+
|
651
|
+
|
652
|
+
def get_workspace_info_from_databricks_secrets(tracking_uri):
|
653
|
+
profile, key_prefix = get_db_info_from_uri(tracking_uri)
|
654
|
+
if key_prefix:
|
655
|
+
dbutils = _get_dbutils()
|
656
|
+
if dbutils:
|
657
|
+
workspace_id = dbutils.secrets.get(scope=profile, key=key_prefix + "-workspace-id")
|
658
|
+
workspace_host = dbutils.secrets.get(scope=profile, key=key_prefix + "-host")
|
659
|
+
return workspace_host, workspace_id
|
660
|
+
return None, None
|
661
|
+
|
662
|
+
|
663
|
+
def _fail_malformed_databricks_auth(uri):
    if uri and uri.startswith(_DATABRICKS_UNITY_CATALOG_SCHEME):
        uri_name = "registry URI"
        uri_scheme = _DATABRICKS_UNITY_CATALOG_SCHEME
    else:
        uri_name = "tracking URI"
        uri_scheme = "databricks"
    if is_in_databricks_model_serving_environment():
        raise MlflowException(
            f"Reading Databricks credential configuration in model serving failed. "
            f"Most commonly, this happens because the model currently "
            f"being served was logged without Databricks resource dependencies "
            f"properly specified. Re-log your model, specifying resource dependencies as "
            f"described in "
            f"https://docs.databricks.com/en/generative-ai/agent-framework/log-agent.html"
            f"#specify-resources-for-pyfunc-or-langchain-agent "
            f"and then register and attempt to serve it again. Alternatively, you can explicitly "
            f"configure authentication by setting environment variables as described in "
            f"https://docs.databricks.com/en/generative-ai/agent-framework/deploy-agent.html"
            f"#manual-authentication. "
            f"Additional debug info: the MLflow {uri_name} was set to '{uri}'"
        )
    raise MlflowException(
        f"Reading Databricks credential configuration failed with MLflow {uri_name} '{uri}'. "
        "Please ensure that the 'databricks-sdk' PyPI library is installed, the tracking "
        "URI is set correctly, and Databricks authentication is properly configured. "
        f"The {uri_name} can be either '{uri_scheme}' "
        f"(using the profile name specified by the 'DATABRICKS_CONFIG_PROFILE' environment "
        f"variable, or the 'DEFAULT' authentication profile if 'DATABRICKS_CONFIG_PROFILE' "
        f"is not set) or '{uri_scheme}://{{profile}}'. "
        "You can configure Databricks authentication in several ways, for example by "
        "specifying environment variables (e.g. DATABRICKS_HOST + DATABRICKS_TOKEN) or "
        "logging in using 'databricks auth login'. \n"
        "For details on configuring Databricks authentication, please refer to "
        "'https://docs.databricks.com/en/dev-tools/auth/index.html#unified-auth'."
    )


# Helper function to attempt to read the OAuth token from a
# mounted file in the Databricks Model Serving environment
def get_model_dependency_oauth_token(should_retry=True):
    try:
        with open(_MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH) as f:
            oauth_dict = json.load(f)
            return oauth_dict["OAUTH_TOKEN"][0]["oauthTokenValue"]
    except Exception as e:
        # sleep and retry in case of any race conditions with OAuth refreshing
        if should_retry:
            time.sleep(0.5)
            return get_model_dependency_oauth_token(should_retry=False)
        else:
            raise MlflowException(
                "Failed to read OAuth credentials from the file mount for the "
                "Databricks Model Serving dependency"
            ) from e


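# Illustrative sketch: the reader above implies the shape of the mounted
# credentials file. The field names come directly from the parsing code; the
# token value is a placeholder.
def _example_oauth_token_file_shape():
    example_payload = {"OAUTH_TOKEN": [{"oauthTokenValue": "<redacted-token>"}]}
    # get_model_dependency_oauth_token() effectively performs this extraction:
    return example_payload["OAUTH_TOKEN"][0]["oauthTokenValue"]

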
class TrackingURIConfigProvider(DatabricksConfigProvider):
    """
    TrackingURIConfigProvider extracts `scope` and `key_prefix` from a tracking URI
    of the form `databricks://scope:key_prefix`, then reads the host and token values
    from dbutils secrets under the keys "{key_prefix}-host" and "{key_prefix}-token".

    This provider only works in the Databricks runtime, and it is deprecated: in the
    Databricks runtime you can simply use 'databricks' as the tracking URI, and MLflow
    automatically reads a dynamically generated token.
    """

    def __init__(self, tracking_uri):
        self.tracking_uri = tracking_uri

    def get_config(self):
        scope, key_prefix = get_db_info_from_uri(self.tracking_uri)

        if scope and key_prefix:
            dbutils = _get_dbutils()
            if dbutils:
                # Prefix differentiates users and is provided as path information in the URI
                host = dbutils.secrets.get(scope=scope, key=key_prefix + "-host")
                token = dbutils.secrets.get(scope=scope, key=key_prefix + "-token")
                return DatabricksConfig.from_token(host=host, token=token, insecure=False)

        return None


def get_databricks_host_creds(server_uri=None):
    """
    Reads in the configuration necessary to make HTTP requests to a Databricks server,
    using the Databricks SDK workspace client API.
    If no credential configuration is available for the server URI, this function
    will attempt to retrieve the credentials from the Databricks Secret Manager. For that to
    work, the server URI must be of the following format: "databricks://scope:prefix". In the
    Databricks Secret Manager, we will query the scope "<scope>" for secrets with
    keys of the form "<prefix>-host" and "<prefix>-token". Note that this prefix *cannot* be
    empty if trying to authenticate with this method. If found, those host credentials will be
    used. This method will throw an exception if sufficient auth cannot be found.

    Args:
        server_uri: A URI that specifies the Databricks profile you want to use for making
            requests.

    Returns:
        MlflowHostCreds which includes the hostname if databricks-sdk authentication is
        available, otherwise includes the hostname and authentication information necessary
        to talk to the Databricks server.

    .. Warning:: This API is deprecated and might be removed in the future.
    """
    profile, key_prefix = get_db_info_from_uri(server_uri)
    profile = profile or os.environ.get("DATABRICKS_CONFIG_PROFILE")

    if MLFLOW_ENABLE_DB_SDK.get():
        from databricks.sdk import WorkspaceClient

        if key_prefix is not None:
            try:
                config = TrackingURIConfigProvider(server_uri).get_config()
                WorkspaceClient(host=config.host, token=config.token)
                return MlflowHostCreds(
                    config.host,
                    token=config.token,
                    use_databricks_sdk=True,
                    use_secret_scope_token=True,
                )
            except Exception as e:
                raise MlflowException(
                    f"The hostname and credentials configured by {server_uri} are invalid. "
                    "Please create a valid hostname secret with the command "
                    f"'databricks secrets put-secret {profile} {key_prefix}-host' and "
                    "a valid token secret with the command "
                    f"'databricks secrets put-secret {profile} {key_prefix}-token'."
                ) from e
        try:
            # Use databricks-sdk to create a Databricks WorkspaceClient instance.
            # If authentication fails, MLflow falls back to legacy authentication methods;
            # see `SparkTaskContextConfigProvider`, `DatabricksModelServingConfigProvider`,
            # and `TrackingURIConfigProvider`.
            # databricks-sdk supports many kinds of authentication; it tries to read
            # authentication information in the following order:
            # 1. Read a dynamically generated token via databricks `dbutils`.
            # 2. Parse relevant environment variables (such as DATABRICKS_HOST + DATABRICKS_TOKEN
            #    or DATABRICKS_HOST + DATABRICKS_CLIENT_ID + DATABRICKS_CLIENT_SECRET)
            #    to get authentication information.
            # 3. Parse the ~/.databrickscfg file (generated by the databricks CLI tool)
            #    to get authentication information.
            # databricks-sdk is designed to hide authentication details and
            # support various authentication methods, so it does not provide an API
            # to get credential values. Instead, we can use the ``WorkspaceClient``
            # API to invoke Databricks shard REST APIs.
            WorkspaceClient(profile=profile)
            use_databricks_sdk = True
            databricks_auth_profile = profile
        except Exception as e:
            _logger.debug(f"Failed to create databricks SDK workspace client, error: {e!r}")
            use_databricks_sdk = False
            databricks_auth_profile = None
    else:
        use_databricks_sdk = False
        databricks_auth_profile = None

    config = _get_databricks_creds_config(server_uri)

    if not config:
        _fail_malformed_databricks_auth(profile)

    return MlflowHostCreds(
        config.host,
        username=config.username,
        password=config.password,
        ignore_tls_verification=config.insecure,
        token=config.token,
        client_id=config.client_id,
        client_secret=config.client_secret,
        use_databricks_sdk=use_databricks_sdk,
        databricks_auth_profile=databricks_auth_profile,
    )


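# Illustrative usage sketch, assuming this module is importable as
# mlflow.utils.databricks_utils (its location in upstream MLflow). The profile
# name "my-profile" and the secret-scope URI are hypothetical; with the plain
# 'databricks' URI, resolution falls back to environment variables (e.g.
# DATABRICKS_HOST + DATABRICKS_TOKEN) or the 'DEFAULT' profile.
def _example_get_databricks_host_creds():
    creds = get_databricks_host_creds("databricks")  # default resolution chain
    profile_creds = get_databricks_host_creds("databricks://my-profile")
    scope_creds = get_databricks_host_creds("databricks://ml-scope:prod")
    return creds.host, profile_creds.host, scope_creds.host

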
@_use_repl_context_if_available("mlflowGitRepoUrl")
def get_git_repo_url():
    try:
        return _get_command_context().mlflowGitRepoUrl().get()
    except Exception:
        return _get_extra_context("mlflowGitUrl")


@_use_repl_context_if_available("mlflowGitRepoProvider")
def get_git_repo_provider():
    try:
        return _get_command_context().mlflowGitRepoProvider().get()
    except Exception:
        return _get_extra_context("mlflowGitProvider")


@_use_repl_context_if_available("mlflowGitRepoCommit")
def get_git_repo_commit():
    try:
        return _get_command_context().mlflowGitRepoCommit().get()
    except Exception:
        return _get_extra_context("mlflowGitCommit")


@_use_repl_context_if_available("mlflowGitRelativePath")
def get_git_repo_relative_path():
    try:
        return _get_command_context().mlflowGitRelativePath().get()
    except Exception:
        return _get_extra_context("mlflowGitRelativePath")


@_use_repl_context_if_available("mlflowGitRepoReference")
def get_git_repo_reference():
    try:
        return _get_command_context().mlflowGitRepoReference().get()
    except Exception:
        return _get_extra_context("mlflowGitReference")


@_use_repl_context_if_available("mlflowGitRepoReferenceType")
def get_git_repo_reference_type():
    try:
        return _get_command_context().mlflowGitRepoReferenceType().get()
    except Exception:
        return _get_extra_context("mlflowGitReferenceType")


@_use_repl_context_if_available("mlflowGitRepoStatus")
def get_git_repo_status():
    try:
        return _get_command_context().mlflowGitRepoStatus().get()
    except Exception:
        return _get_extra_context("mlflowGitStatus")


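# Illustrative sketch: aggregating the git context getters above into a single
# dict, keeping only the fields available in the current environment. The key
# names are arbitrary, not tags defined by this package.
def _example_collect_git_context():
    getters = {
        "url": get_git_repo_url,
        "provider": get_git_repo_provider,
        "commit": get_git_repo_commit,
        "relative_path": get_git_repo_relative_path,
        "reference": get_git_repo_reference,
        "reference_type": get_git_repo_reference_type,
        "status": get_git_repo_status,
    }
    return {name: value for name, getter in getters.items() if (value := getter()) is not None}

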
def is_running_in_ipython_environment():
    try:
        from IPython import get_ipython

        return get_ipython() is not None
    except (ImportError, ModuleNotFoundError):
        return False


def get_databricks_run_url(tracking_uri: str, run_id: str, artifact_path=None) -> Optional[str]:
    """
    Obtains a Databricks URL corresponding to the specified MLflow Run, optionally referring
    to an artifact within the run.

    Args:
        tracking_uri: The URI of the MLflow Tracking server containing the Run.
        run_id: The ID of the MLflow Run for which to obtain a Databricks URL.
        artifact_path: An optional relative artifact path within the Run to which the URL
            should refer.

    Returns:
        A Databricks URL corresponding to the specified MLflow Run
        (and artifact path, if specified), or None if the MLflow Run does not belong to a
        Databricks Workspace.
    """
    from mlflow.tracking.client import MlflowClient

    try:
        workspace_info = (
            DatabricksWorkspaceInfo.from_environment()
            or get_databricks_workspace_info_from_uri(tracking_uri)
        )
        if workspace_info is not None:
            experiment_id = MlflowClient(tracking_uri).get_run(run_id).info.experiment_id
            return _construct_databricks_run_url(
                host=workspace_info.host,
                experiment_id=experiment_id,
                run_id=run_id,
                workspace_id=workspace_info.workspace_id,
                artifact_path=artifact_path,
            )
    except Exception:
        return None


def get_databricks_model_version_url(registry_uri: str, name: str, version: str) -> Optional[str]:
    """Obtains a Databricks URL corresponding to the specified Model Version.

    Args:
        registry_uri: The URI of the Model Registry server containing the Model Version.
        name: The name of the registered model containing the Model Version.
        version: Version number of the Model Version.

    Returns:
        A Databricks URL corresponding to the specified Model Version, or None if the
        Model Version does not belong to a Databricks Workspace.
    """
    try:
        workspace_info = (
            DatabricksWorkspaceInfo.from_environment()
            or get_databricks_workspace_info_from_uri(registry_uri)
        )
        if workspace_info is not None:
            return _construct_databricks_model_version_url(
                host=workspace_info.host,
                name=name,
                version=version,
                workspace_id=workspace_info.workspace_id,
            )
    except Exception:
        return None


DatabricksWorkspaceInfoType = TypeVar("DatabricksWorkspaceInfo", bound="DatabricksWorkspaceInfo")


class DatabricksWorkspaceInfo:
    WORKSPACE_HOST_ENV_VAR = "_DATABRICKS_WORKSPACE_HOST"
    WORKSPACE_ID_ENV_VAR = "_DATABRICKS_WORKSPACE_ID"

    def __init__(self, host: str, workspace_id: Optional[str] = None):
        self.host = host
        self.workspace_id = workspace_id

    @classmethod
    def from_environment(cls) -> Optional[DatabricksWorkspaceInfoType]:
        if DatabricksWorkspaceInfo.WORKSPACE_HOST_ENV_VAR in os.environ:
            return DatabricksWorkspaceInfo(
                host=os.environ[DatabricksWorkspaceInfo.WORKSPACE_HOST_ENV_VAR],
                workspace_id=os.environ.get(DatabricksWorkspaceInfo.WORKSPACE_ID_ENV_VAR),
            )
        else:
            return None

    def to_environment(self):
        env = {
            DatabricksWorkspaceInfo.WORKSPACE_HOST_ENV_VAR: self.host,
        }
        if self.workspace_id is not None:
            env[DatabricksWorkspaceInfo.WORKSPACE_ID_ENV_VAR] = self.workspace_id

        return env


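# Illustrative round-trip sketch: serialize workspace info into environment
# variables for a child process and recover it there. The host and workspace ID
# are made-up values.
def _example_workspace_info_round_trip():
    info = DatabricksWorkspaceInfo(
        host="https://example.cloud.databricks.com", workspace_id="12345"
    )
    os.environ.update(info.to_environment())
    recovered = DatabricksWorkspaceInfo.from_environment()
    assert recovered is not None and recovered.host == info.host

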
def get_databricks_workspace_info_from_uri(tracking_uri: str) -> Optional[DatabricksWorkspaceInfo]:
    if not is_databricks_uri(tracking_uri):
        return None

    if is_databricks_default_tracking_uri(tracking_uri) and (
        is_in_databricks_notebook() or is_in_databricks_job()
    ):
        workspace_host, workspace_id = get_workspace_info_from_dbutils()
    else:
        workspace_host, workspace_id = get_workspace_info_from_databricks_secrets(tracking_uri)
        if not workspace_id:
            _logger.info(
                "No workspace ID specified; if your Databricks workspaces share the same"
                " host URL, you may want to specify the workspace ID (along with the host"
                " information in the secret manager) for run lineage tracking. For more"
                " details on how to specify this information in the secret manager,"
                " please refer to the Databricks MLflow documentation."
            )

    if workspace_host:
        return DatabricksWorkspaceInfo(host=workspace_host, workspace_id=workspace_id)
    else:
        return None


def check_databricks_secret_scope_access(scope_name):
    dbutils = _get_dbutils()
    if dbutils:
        try:
            dbutils.secrets.list(scope_name)
        except Exception as e:
            _logger.warning(
                f"Unable to access Databricks secret scope '{scope_name}' for OpenAI credentials "
                "that will be used to deploy the model to Databricks Model Serving. "
                "Please verify that the current Databricks user has 'READ' permission for "
                "this scope. For more information, see "
                "https://mlflow.org/docs/latest/python_api/openai/index.html#credential-management-for-openai-on-databricks. "  # noqa: E501
                f"Error: {e}"
            )


def _construct_databricks_run_url(
    host: str,
    experiment_id: str,
    run_id: str,
    workspace_id: Optional[str] = None,
    artifact_path: Optional[str] = None,
) -> str:
    run_url = host
    if workspace_id and workspace_id != "0":
        run_url += "?o=" + str(workspace_id)

    run_url += f"#mlflow/experiments/{experiment_id}/runs/{run_id}"

    if artifact_path is not None:
        run_url += f"/artifactPath/{artifact_path.lstrip('/')}"

    return run_url


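# Illustrative expected output, following the concatenation above (hypothetical
# host, IDs, and artifact path):
def _example_run_url():
    url = _construct_databricks_run_url(
        host="https://example.cloud.databricks.com",
        experiment_id="123",
        run_id="abcdef",
        workspace_id="456",
        artifact_path="model/MLmodel",
    )
    assert url == (
        "https://example.cloud.databricks.com?o=456"
        "#mlflow/experiments/123/runs/abcdef/artifactPath/model/MLmodel"
    )

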
def _construct_databricks_model_version_url(
    host: str, name: str, version: str, workspace_id: Optional[str] = None
) -> str:
    model_version_url = host
    if workspace_id and workspace_id != "0":
        model_version_url += "?o=" + str(workspace_id)

    model_version_url += f"#mlflow/models/{name}/versions/{version}"

    return model_version_url


def _construct_databricks_logged_model_url(
    workspace_url: str, experiment_id: str, model_id: str, workspace_id: Optional[str] = None
) -> str:
    """
    Get a Databricks URL for a given logged model.

    Args:
        workspace_url: The URL of the workspace the logged model is in.
        experiment_id: The ID of the experiment the model is logged to.
        model_id: The ID of the logged model to create the URL for.
        workspace_id: The ID of the workspace to include as a query parameter (if provided).

    Returns:
        The Databricks URL for the logged model.
    """
    query = f"?o={workspace_id}" if (workspace_id and workspace_id != "0") else ""
    return f"{workspace_url}/ml/experiments/{experiment_id}/models/{model_id}{query}"


def _construct_databricks_uc_registered_model_url(
    workspace_url: str, registered_model_name: str, version: str, workspace_id: Optional[str] = None
) -> str:
    """
    Get a Databricks URL for a given registered model version in Unity Catalog.

    Args:
        workspace_url: The URL of the workspace the registered model is in.
        registered_model_name: The full name of the registered model containing the version.
        version: The version of the registered model to create the URL for.
        workspace_id: The ID of the workspace to include as a query parameter (if provided).

    Returns:
        The Databricks URL for a registered model in Unity Catalog.
    """
    path = registered_model_name.replace(".", "/")
    query = f"?o={workspace_id}" if (workspace_id and workspace_id != "0") else ""
    return f"{workspace_url}/explore/data/models/{path}/version/{version}{query}"


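# Illustrative expected output: the three-level Unity Catalog name maps onto a
# slash-separated path (all values hypothetical):
def _example_uc_registered_model_url():
    url = _construct_databricks_uc_registered_model_url(
        workspace_url="https://example.cloud.databricks.com",
        registered_model_name="main.default.my_model",
        version="3",
    )
    assert url == (
        "https://example.cloud.databricks.com/explore/data/models/"
        "main/default/my_model/version/3"
    )

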
def _print_databricks_deployment_job_url(
    model_name: str,
    job_id: str,
    workspace_url: Optional[str] = None,
    workspace_id: Optional[str] = None,
) -> Optional[str]:
    if not workspace_url:
        workspace_url = get_workspace_url()
    if not workspace_id:
        workspace_id = get_workspace_id()
    # If there is no workspace_url, we cannot print the job URL
    if not workspace_url:
        return None

    query = f"?o={workspace_id}" if (workspace_id and workspace_id != "0") else ""
    job_url = f"{workspace_url}/jobs/{job_id}{query}"
    eprint(f"🔗 Linked deployment job to '{model_name}': {job_url}")
    return job_url


def _get_databricks_creds_config(tracking_uri):
    # Note:
    # `_get_databricks_creds_config` reads credential token values or passwords and
    # returns a `DatabricksConfig` object.
    # The Databricks SDK API doesn't support reading credential token values, so in
    # this function we still have to use the configuration providers defined in the
    # legacy Databricks CLI python library to read token values.
    profile, key_prefix = get_db_info_from_uri(tracking_uri)

    config = None

    if profile and key_prefix:
        # legacy way to read credentials by setting `tracking_uri` to 'databricks://scope:prefix'
        providers = [TrackingURIConfigProvider(tracking_uri)]
    elif profile:
        # If `tracking_uri` is 'databricks://<profile>',
        # MLflow should only read credentials from this profile.
        providers = [ProfileConfigProvider(profile)]
    else:
        providers = [
            # `EnvironmentVariableConfigProvider` should be prioritized at the highest level,
            # to align with Databricks-SDK behavior.
            EnvironmentVariableConfigProvider(),
            _dynamic_token_config_provider,
            ProfileConfigProvider(None),
            SparkTaskContextConfigProvider(),
            DatabricksModelServingConfigProvider(),
        ]

    for provider in providers:
        if provider:
            _config = provider.get_config()
            if _config is not None and _config.is_valid:
                config = _config
                break

    if not config or not config.host:
        _fail_malformed_databricks_auth(tracking_uri)

    return config


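# Illustrative sketch: the provider chain above only requires objects exposing
# `get_config()`. A toy provider with placeholder host/token values; a real
# provider would read them from a trustworthy source rather than hard-coding:
class _ExampleStaticConfigProvider(DatabricksConfigProvider):
    def get_config(self):
        return DatabricksConfig.from_token(
            host="https://example.cloud.databricks.com",
            token="<redacted-token>",
            insecure=False,
        )

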
def get_databricks_env_vars(tracking_uri):
    if not mlflow.utils.uri.is_databricks_uri(tracking_uri):
        return {}

    config = _get_databricks_creds_config(tracking_uri)

    if config.auth_type == "databricks-cli":
        raise MlflowException(
            "You configured the authentication type as 'databricks-cli'; in this case, MLflow "
            "cannot read credential values and therefore cannot construct the Databricks "
            "environment variables for child process authentication."
        )

    # We set these via environment variables so that only the current profile is exposed, rather
    # than all profiles in ~/.databrickscfg; maybe better would be to mount the necessary
    # part of ~/.databrickscfg into the container
    env_vars = {}
    env_vars[MLFLOW_TRACKING_URI.name] = "databricks"
    env_vars["DATABRICKS_HOST"] = config.host
    if config.username:
        env_vars["DATABRICKS_USERNAME"] = config.username
    if config.password:
        env_vars["DATABRICKS_PASSWORD"] = config.password
    if config.token:
        env_vars["DATABRICKS_TOKEN"] = config.token
    if config.insecure:
        env_vars["DATABRICKS_INSECURE"] = str(config.insecure)
    if config.client_id:
        env_vars["DATABRICKS_CLIENT_ID"] = config.client_id
    if config.client_secret:
        env_vars["DATABRICKS_CLIENT_SECRET"] = config.client_secret

    workspace_info = get_databricks_workspace_info_from_uri(tracking_uri)
    if workspace_info is not None:
        env_vars.update(workspace_info.to_environment())

    return env_vars


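# Illustrative sketch: propagating Databricks auth to a child process via the
# variables above. The tracking URI and child command are arbitrary.
def _example_run_child_process_with_databricks_env():
    import subprocess

    env = {**os.environ, **get_databricks_env_vars("databricks")}
    subprocess.run(
        ["python", "-c", "import os; print(os.environ['DATABRICKS_HOST'])"],
        env=env,
        check=True,
    )

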
def _get_databricks_serverless_env_vars() -> dict[str, str]:
    """
    Returns the environment variables required to initialize WorkspaceClient in a subprocess
    with serverless compute.

    Note:
        Databricks authentication related environment variables such as DATABRICKS_HOST are
        set in the _capture_imported_modules function.
    """
    envs = {}
    if "SPARK_REMOTE" in os.environ:
        envs["SPARK_LOCAL_REMOTE"] = os.environ["SPARK_REMOTE"]
    else:
        _logger.warning(
            "Missing required environment variable `SPARK_LOCAL_REMOTE` or `SPARK_REMOTE`. "
            "These are necessary to initialize the WorkspaceClient with serverless compute in "
            "a subprocess in Databricks for UC function execution. Setting the value to 'true'."
        )
        envs["SPARK_LOCAL_REMOTE"] = "true"
    return envs


class DatabricksRuntimeVersion(NamedTuple):
    is_client_image: bool
    major: int
    minor: int

    @classmethod
    def parse(cls, databricks_runtime: Optional[str] = None):
        dbr_version = databricks_runtime or get_databricks_runtime_version()
        try:
            dbr_version_splits = dbr_version.split(".", maxsplit=2)
            if dbr_version_splits[0] == "client":
                is_client_image = True
                major = int(dbr_version_splits[1])
                minor = int(dbr_version_splits[2]) if len(dbr_version_splits) > 2 else 0
            else:
                is_client_image = False
                major = int(dbr_version_splits[0])
                minor = int(dbr_version_splits[1])
            return cls(is_client_image, major, minor)
        except Exception:
            raise MlflowException(f"Failed to parse databricks runtime version '{dbr_version}'.")


def get_databricks_runtime_major_minor_version():
    return DatabricksRuntimeVersion.parse()


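# Illustrative parses following the split logic above; the version strings are
# typical shapes, not values taken from this package:
def _example_parse_runtime_versions():
    assert DatabricksRuntimeVersion.parse("15.4.x-scala2.12") == DatabricksRuntimeVersion(
        is_client_image=False, major=15, minor=4
    )
    assert DatabricksRuntimeVersion.parse("client.1.5") == DatabricksRuntimeVersion(
        is_client_image=True, major=1, minor=5
    )

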
_dynamic_token_config_provider = None


def _init_databricks_dynamic_token_config_provider(entry_point):
    """
    Set a custom DatabricksConfigProvider with the hostname and token of the
    user running the current command (achieved by looking at
    PythonAccessibleThreadLocals.commandContext, via the already-exposed
    NotebookUtils.getContext API).
    """
    global _dynamic_token_config_provider

    notebook_utils = entry_point.getDbutils().notebook()

    dbr_version = get_databricks_runtime_major_minor_version()
    dbr_major_minor_version = (dbr_version.major, dbr_version.minor)

    # the CLI code in client-branch-1.0 is the same as in the 15.0 runtime branch
    if dbr_version.is_client_image or dbr_major_minor_version >= (13, 2):

        class DynamicConfigProvider(DatabricksConfigProvider):
            def get_config(self):
                logger = entry_point.getLogger()
                try:
                    from dbruntime.databricks_repl_context import get_context

                    ctx = get_context()
                    if ctx and ctx.apiUrl and ctx.apiToken:
                        return DatabricksConfig.from_token(
                            host=ctx.apiUrl, token=ctx.apiToken, insecure=ctx.sslTrustAll
                        )
                except Exception as e:
                    _logger.debug(
                        "Unexpected internal error while constructing `DatabricksConfig` "
                        f"from REPL context: {e}",
                    )
                # Invoking getContext() will attempt to find the credentials related to the
                # current command execution, so it's critical that we execute it on every
                # get_config().
                api_url_option = notebook_utils.getContext().apiUrl()
                api_url = api_url_option.get() if api_url_option.isDefined() else None
                # Invoking getNonUcApiToken() will attempt to find the credentials related
                # to the current command execution and refresh them automatically if they
                # are expired, so it's critical that we execute it on every get_config().
                api_token = None
                try:
                    api_token = entry_point.getNonUcApiToken()
                except Exception:
                    # Using apiToken from the command context would return a token that is
                    # not refreshed.
                    fallback_api_token_option = notebook_utils.getContext().apiToken()
                    logger.logUsage(
                        "refreshableTokenNotFound",
                        {"api_url": api_url},
                        None,
                    )
                    if fallback_api_token_option.isDefined():
                        api_token = fallback_api_token_option.get()

                ssl_trust_all = entry_point.getDriverConf().workflowSslTrustAll()

                if api_token is None or api_url is None:
                    return None

                return DatabricksConfig.from_token(
                    host=api_url, token=api_token, insecure=ssl_trust_all
                )
    elif dbr_major_minor_version >= (10, 3):

        class DynamicConfigProvider(DatabricksConfigProvider):
            def get_config(self):
                try:
                    from dbruntime.databricks_repl_context import get_context

                    ctx = get_context()
                    if ctx and ctx.apiUrl and ctx.apiToken:
                        return DatabricksConfig.from_token(
                            host=ctx.apiUrl, token=ctx.apiToken, insecure=ctx.sslTrustAll
                        )
                except Exception as e:
                    _logger.debug(
                        "Unexpected internal error while constructing `DatabricksConfig` "
                        f"from REPL context: {e}",
                    )
                # Invoking getContext() will attempt to find the credentials related to the
                # current command execution, so it's critical that we execute it on every
                # get_config().
                api_token_option = notebook_utils.getContext().apiToken()
                api_url_option = notebook_utils.getContext().apiUrl()
                ssl_trust_all = entry_point.getDriverConf().workflowSslTrustAll()

                if not api_token_option.isDefined() or not api_url_option.isDefined():
                    return None

                return DatabricksConfig.from_token(
                    host=api_url_option.get(), token=api_token_option.get(), insecure=ssl_trust_all
                )
    else:

        class DynamicConfigProvider(DatabricksConfigProvider):
            def get_config(self):
                # Invoking getContext() will attempt to find the credentials related to the
                # current command execution, so it's critical that we execute it on every
                # get_config().
                api_token_option = notebook_utils.getContext().apiToken()
                api_url_option = notebook_utils.getContext().apiUrl()
                ssl_trust_all = entry_point.getDriverConf().workflowSslTrustAll()

                if not api_token_option.isDefined() or not api_url_option.isDefined():
                    return None

                return DatabricksConfig.from_token(
                    host=api_url_option.get(), token=api_token_option.get(), insecure=ssl_trust_all
                )

    _dynamic_token_config_provider = DynamicConfigProvider()


if is_in_databricks_runtime():
    try:
        dbutils = _get_dbutils()
        _init_databricks_dynamic_token_config_provider(dbutils.entry_point)
    except _NoDbutilsError:
        # If no dbutils is available, the code is running in the Databricks driver-local
        # test suite; in this case we don't need to initialize a Databricks token because
        # there is no backend MLflow service available.
        pass


def get_databricks_nfs_temp_dir():
    entry_point = _get_dbutils().entry_point
    if getpass.getuser().lower() == "root":
        return entry_point.getReplNFSTempDir()
    else:
        try:
            # If the user is not root, the code is running in safe-spark mode.
            # In this case, we should get the temporary directory of the current user,
            # and `getReplNFSTempDir` will be deprecated for this case.
            return entry_point.getUserNFSTempDir()
        except Exception:
            # fallback
            return entry_point.getReplNFSTempDir()


def get_databricks_local_temp_dir():
    entry_point = _get_dbutils().entry_point
    if getpass.getuser().lower() == "root":
        return entry_point.getReplLocalTempDir()
    else:
        try:
            # If the user is not root, the code is running in safe-spark mode.
            # In this case, we should get the temporary directory of the current user,
            # and `getReplLocalTempDir` will be deprecated for this case.
            return entry_point.getUserLocalTempDir()
        except Exception:
            # fallback
            return entry_point.getReplLocalTempDir()


def stage_model_for_databricks_model_serving(model_name: str, model_version: str):
    response = http_request(
        host_creds=get_databricks_host_creds(),
        endpoint="/api/2.0/serving-endpoints:stageDeployment",
        method="POST",
        raise_on_status=False,
        json={
            "model_name": model_name,
            "model_version": model_version,
        },
    )
    augmented_raise_for_status(response)