genesis-flow 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genesis_flow-1.0.0.dist-info/METADATA +822 -0
- genesis_flow-1.0.0.dist-info/RECORD +645 -0
- genesis_flow-1.0.0.dist-info/WHEEL +5 -0
- genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
- genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
- genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
- mlflow/__init__.py +367 -0
- mlflow/__main__.py +3 -0
- mlflow/ag2/__init__.py +56 -0
- mlflow/ag2/ag2_logger.py +294 -0
- mlflow/anthropic/__init__.py +40 -0
- mlflow/anthropic/autolog.py +129 -0
- mlflow/anthropic/chat.py +144 -0
- mlflow/artifacts/__init__.py +268 -0
- mlflow/autogen/__init__.py +144 -0
- mlflow/autogen/chat.py +142 -0
- mlflow/azure/__init__.py +26 -0
- mlflow/azure/auth_handler.py +257 -0
- mlflow/azure/client.py +319 -0
- mlflow/azure/config.py +120 -0
- mlflow/azure/connection_factory.py +340 -0
- mlflow/azure/exceptions.py +27 -0
- mlflow/azure/stores.py +327 -0
- mlflow/azure/utils.py +183 -0
- mlflow/bedrock/__init__.py +45 -0
- mlflow/bedrock/_autolog.py +202 -0
- mlflow/bedrock/chat.py +122 -0
- mlflow/bedrock/stream.py +160 -0
- mlflow/bedrock/utils.py +43 -0
- mlflow/cli.py +707 -0
- mlflow/client.py +12 -0
- mlflow/config/__init__.py +56 -0
- mlflow/crewai/__init__.py +79 -0
- mlflow/crewai/autolog.py +253 -0
- mlflow/crewai/chat.py +29 -0
- mlflow/data/__init__.py +75 -0
- mlflow/data/artifact_dataset_sources.py +170 -0
- mlflow/data/code_dataset_source.py +40 -0
- mlflow/data/dataset.py +123 -0
- mlflow/data/dataset_registry.py +168 -0
- mlflow/data/dataset_source.py +110 -0
- mlflow/data/dataset_source_registry.py +219 -0
- mlflow/data/delta_dataset_source.py +167 -0
- mlflow/data/digest_utils.py +108 -0
- mlflow/data/evaluation_dataset.py +562 -0
- mlflow/data/filesystem_dataset_source.py +81 -0
- mlflow/data/http_dataset_source.py +145 -0
- mlflow/data/huggingface_dataset.py +258 -0
- mlflow/data/huggingface_dataset_source.py +118 -0
- mlflow/data/meta_dataset.py +104 -0
- mlflow/data/numpy_dataset.py +223 -0
- mlflow/data/pandas_dataset.py +231 -0
- mlflow/data/polars_dataset.py +352 -0
- mlflow/data/pyfunc_dataset_mixin.py +31 -0
- mlflow/data/schema.py +76 -0
- mlflow/data/sources.py +1 -0
- mlflow/data/spark_dataset.py +406 -0
- mlflow/data/spark_dataset_source.py +74 -0
- mlflow/data/spark_delta_utils.py +118 -0
- mlflow/data/tensorflow_dataset.py +350 -0
- mlflow/data/uc_volume_dataset_source.py +81 -0
- mlflow/db.py +27 -0
- mlflow/dspy/__init__.py +17 -0
- mlflow/dspy/autolog.py +197 -0
- mlflow/dspy/callback.py +398 -0
- mlflow/dspy/constant.py +1 -0
- mlflow/dspy/load.py +93 -0
- mlflow/dspy/save.py +393 -0
- mlflow/dspy/util.py +109 -0
- mlflow/dspy/wrapper.py +226 -0
- mlflow/entities/__init__.py +104 -0
- mlflow/entities/_mlflow_object.py +52 -0
- mlflow/entities/assessment.py +545 -0
- mlflow/entities/assessment_error.py +80 -0
- mlflow/entities/assessment_source.py +141 -0
- mlflow/entities/dataset.py +92 -0
- mlflow/entities/dataset_input.py +51 -0
- mlflow/entities/dataset_summary.py +62 -0
- mlflow/entities/document.py +48 -0
- mlflow/entities/experiment.py +109 -0
- mlflow/entities/experiment_tag.py +35 -0
- mlflow/entities/file_info.py +45 -0
- mlflow/entities/input_tag.py +35 -0
- mlflow/entities/lifecycle_stage.py +35 -0
- mlflow/entities/logged_model.py +228 -0
- mlflow/entities/logged_model_input.py +26 -0
- mlflow/entities/logged_model_output.py +32 -0
- mlflow/entities/logged_model_parameter.py +46 -0
- mlflow/entities/logged_model_status.py +74 -0
- mlflow/entities/logged_model_tag.py +33 -0
- mlflow/entities/metric.py +200 -0
- mlflow/entities/model_registry/__init__.py +29 -0
- mlflow/entities/model_registry/_model_registry_entity.py +13 -0
- mlflow/entities/model_registry/model_version.py +243 -0
- mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
- mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
- mlflow/entities/model_registry/model_version_search.py +25 -0
- mlflow/entities/model_registry/model_version_stages.py +25 -0
- mlflow/entities/model_registry/model_version_status.py +35 -0
- mlflow/entities/model_registry/model_version_tag.py +35 -0
- mlflow/entities/model_registry/prompt.py +73 -0
- mlflow/entities/model_registry/prompt_version.py +244 -0
- mlflow/entities/model_registry/registered_model.py +175 -0
- mlflow/entities/model_registry/registered_model_alias.py +35 -0
- mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
- mlflow/entities/model_registry/registered_model_search.py +25 -0
- mlflow/entities/model_registry/registered_model_tag.py +35 -0
- mlflow/entities/multipart_upload.py +74 -0
- mlflow/entities/param.py +49 -0
- mlflow/entities/run.py +97 -0
- mlflow/entities/run_data.py +84 -0
- mlflow/entities/run_info.py +188 -0
- mlflow/entities/run_inputs.py +59 -0
- mlflow/entities/run_outputs.py +43 -0
- mlflow/entities/run_status.py +41 -0
- mlflow/entities/run_tag.py +36 -0
- mlflow/entities/source_type.py +31 -0
- mlflow/entities/span.py +774 -0
- mlflow/entities/span_event.py +96 -0
- mlflow/entities/span_status.py +102 -0
- mlflow/entities/trace.py +317 -0
- mlflow/entities/trace_data.py +71 -0
- mlflow/entities/trace_info.py +220 -0
- mlflow/entities/trace_info_v2.py +162 -0
- mlflow/entities/trace_location.py +173 -0
- mlflow/entities/trace_state.py +39 -0
- mlflow/entities/trace_status.py +68 -0
- mlflow/entities/view_type.py +51 -0
- mlflow/environment_variables.py +866 -0
- mlflow/evaluation/__init__.py +16 -0
- mlflow/evaluation/assessment.py +369 -0
- mlflow/evaluation/evaluation.py +411 -0
- mlflow/evaluation/evaluation_tag.py +61 -0
- mlflow/evaluation/fluent.py +48 -0
- mlflow/evaluation/utils.py +201 -0
- mlflow/exceptions.py +213 -0
- mlflow/experiments.py +140 -0
- mlflow/gemini/__init__.py +81 -0
- mlflow/gemini/autolog.py +186 -0
- mlflow/gemini/chat.py +261 -0
- mlflow/genai/__init__.py +71 -0
- mlflow/genai/datasets/__init__.py +67 -0
- mlflow/genai/datasets/evaluation_dataset.py +131 -0
- mlflow/genai/evaluation/__init__.py +3 -0
- mlflow/genai/evaluation/base.py +411 -0
- mlflow/genai/evaluation/constant.py +23 -0
- mlflow/genai/evaluation/utils.py +244 -0
- mlflow/genai/judges/__init__.py +21 -0
- mlflow/genai/judges/databricks.py +404 -0
- mlflow/genai/label_schemas/__init__.py +153 -0
- mlflow/genai/label_schemas/label_schemas.py +209 -0
- mlflow/genai/labeling/__init__.py +159 -0
- mlflow/genai/labeling/labeling.py +250 -0
- mlflow/genai/optimize/__init__.py +13 -0
- mlflow/genai/optimize/base.py +198 -0
- mlflow/genai/optimize/optimizers/__init__.py +4 -0
- mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
- mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
- mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
- mlflow/genai/optimize/types.py +75 -0
- mlflow/genai/optimize/util.py +30 -0
- mlflow/genai/prompts/__init__.py +206 -0
- mlflow/genai/scheduled_scorers.py +431 -0
- mlflow/genai/scorers/__init__.py +26 -0
- mlflow/genai/scorers/base.py +492 -0
- mlflow/genai/scorers/builtin_scorers.py +765 -0
- mlflow/genai/scorers/scorer_utils.py +138 -0
- mlflow/genai/scorers/validation.py +165 -0
- mlflow/genai/utils/data_validation.py +146 -0
- mlflow/genai/utils/enum_utils.py +23 -0
- mlflow/genai/utils/trace_utils.py +211 -0
- mlflow/groq/__init__.py +42 -0
- mlflow/groq/_groq_autolog.py +74 -0
- mlflow/johnsnowlabs/__init__.py +888 -0
- mlflow/langchain/__init__.py +24 -0
- mlflow/langchain/api_request_parallel_processor.py +330 -0
- mlflow/langchain/autolog.py +147 -0
- mlflow/langchain/chat_agent_langgraph.py +340 -0
- mlflow/langchain/constant.py +1 -0
- mlflow/langchain/constants.py +1 -0
- mlflow/langchain/databricks_dependencies.py +444 -0
- mlflow/langchain/langchain_tracer.py +597 -0
- mlflow/langchain/model.py +919 -0
- mlflow/langchain/output_parsers.py +142 -0
- mlflow/langchain/retriever_chain.py +153 -0
- mlflow/langchain/runnables.py +527 -0
- mlflow/langchain/utils/chat.py +402 -0
- mlflow/langchain/utils/logging.py +671 -0
- mlflow/langchain/utils/serialization.py +36 -0
- mlflow/legacy_databricks_cli/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/provider.py +482 -0
- mlflow/litellm/__init__.py +175 -0
- mlflow/llama_index/__init__.py +22 -0
- mlflow/llama_index/autolog.py +55 -0
- mlflow/llama_index/chat.py +43 -0
- mlflow/llama_index/constant.py +1 -0
- mlflow/llama_index/model.py +577 -0
- mlflow/llama_index/pyfunc_wrapper.py +332 -0
- mlflow/llama_index/serialize_objects.py +188 -0
- mlflow/llama_index/tracer.py +561 -0
- mlflow/metrics/__init__.py +479 -0
- mlflow/metrics/base.py +39 -0
- mlflow/metrics/genai/__init__.py +25 -0
- mlflow/metrics/genai/base.py +101 -0
- mlflow/metrics/genai/genai_metric.py +771 -0
- mlflow/metrics/genai/metric_definitions.py +450 -0
- mlflow/metrics/genai/model_utils.py +371 -0
- mlflow/metrics/genai/prompt_template.py +68 -0
- mlflow/metrics/genai/prompts/__init__.py +0 -0
- mlflow/metrics/genai/prompts/v1.py +422 -0
- mlflow/metrics/genai/utils.py +6 -0
- mlflow/metrics/metric_definitions.py +619 -0
- mlflow/mismatch.py +34 -0
- mlflow/mistral/__init__.py +34 -0
- mlflow/mistral/autolog.py +71 -0
- mlflow/mistral/chat.py +135 -0
- mlflow/ml_package_versions.py +452 -0
- mlflow/models/__init__.py +97 -0
- mlflow/models/auth_policy.py +83 -0
- mlflow/models/cli.py +354 -0
- mlflow/models/container/__init__.py +294 -0
- mlflow/models/container/scoring_server/__init__.py +0 -0
- mlflow/models/container/scoring_server/nginx.conf +39 -0
- mlflow/models/dependencies_schemas.py +287 -0
- mlflow/models/display_utils.py +158 -0
- mlflow/models/docker_utils.py +211 -0
- mlflow/models/evaluation/__init__.py +23 -0
- mlflow/models/evaluation/_shap_patch.py +64 -0
- mlflow/models/evaluation/artifacts.py +194 -0
- mlflow/models/evaluation/base.py +1811 -0
- mlflow/models/evaluation/calibration_curve.py +109 -0
- mlflow/models/evaluation/default_evaluator.py +996 -0
- mlflow/models/evaluation/deprecated.py +23 -0
- mlflow/models/evaluation/evaluator_registry.py +80 -0
- mlflow/models/evaluation/evaluators/classifier.py +704 -0
- mlflow/models/evaluation/evaluators/default.py +233 -0
- mlflow/models/evaluation/evaluators/regressor.py +96 -0
- mlflow/models/evaluation/evaluators/shap.py +296 -0
- mlflow/models/evaluation/lift_curve.py +178 -0
- mlflow/models/evaluation/utils/metric.py +123 -0
- mlflow/models/evaluation/utils/trace.py +179 -0
- mlflow/models/evaluation/validation.py +434 -0
- mlflow/models/flavor_backend.py +93 -0
- mlflow/models/flavor_backend_registry.py +53 -0
- mlflow/models/model.py +1639 -0
- mlflow/models/model_config.py +150 -0
- mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
- mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
- mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
- mlflow/models/python_api.py +369 -0
- mlflow/models/rag_signatures.py +128 -0
- mlflow/models/resources.py +321 -0
- mlflow/models/signature.py +662 -0
- mlflow/models/utils.py +2054 -0
- mlflow/models/wheeled_model.py +280 -0
- mlflow/openai/__init__.py +57 -0
- mlflow/openai/_agent_tracer.py +364 -0
- mlflow/openai/api_request_parallel_processor.py +131 -0
- mlflow/openai/autolog.py +509 -0
- mlflow/openai/constant.py +1 -0
- mlflow/openai/model.py +824 -0
- mlflow/openai/utils/chat_schema.py +367 -0
- mlflow/optuna/__init__.py +3 -0
- mlflow/optuna/storage.py +646 -0
- mlflow/plugins/__init__.py +72 -0
- mlflow/plugins/base.py +358 -0
- mlflow/plugins/builtin/__init__.py +24 -0
- mlflow/plugins/builtin/pytorch_plugin.py +150 -0
- mlflow/plugins/builtin/sklearn_plugin.py +158 -0
- mlflow/plugins/builtin/transformers_plugin.py +187 -0
- mlflow/plugins/cli.py +321 -0
- mlflow/plugins/discovery.py +340 -0
- mlflow/plugins/manager.py +465 -0
- mlflow/plugins/registry.py +316 -0
- mlflow/plugins/templates/framework_plugin_template.py +329 -0
- mlflow/prompt/constants.py +20 -0
- mlflow/prompt/promptlab_model.py +197 -0
- mlflow/prompt/registry_utils.py +248 -0
- mlflow/promptflow/__init__.py +495 -0
- mlflow/protos/__init__.py +0 -0
- mlflow/protos/assessments_pb2.py +174 -0
- mlflow/protos/databricks_artifacts_pb2.py +489 -0
- mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
- mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
- mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
- mlflow/protos/databricks_pb2.py +267 -0
- mlflow/protos/databricks_trace_server_pb2.py +374 -0
- mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
- mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
- mlflow/protos/facet_feature_statistics_pb2.py +296 -0
- mlflow/protos/internal_pb2.py +77 -0
- mlflow/protos/mlflow_artifacts_pb2.py +336 -0
- mlflow/protos/model_registry_pb2.py +1073 -0
- mlflow/protos/scalapb/__init__.py +0 -0
- mlflow/protos/scalapb/scalapb_pb2.py +104 -0
- mlflow/protos/service_pb2.py +2600 -0
- mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
- mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
- mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
- mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
- mlflow/py.typed +0 -0
- mlflow/pydantic_ai/__init__.py +57 -0
- mlflow/pydantic_ai/autolog.py +173 -0
- mlflow/pyfunc/__init__.py +3844 -0
- mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
- mlflow/pyfunc/backend.py +523 -0
- mlflow/pyfunc/context.py +78 -0
- mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
- mlflow/pyfunc/loaders/__init__.py +7 -0
- mlflow/pyfunc/loaders/chat_agent.py +117 -0
- mlflow/pyfunc/loaders/chat_model.py +125 -0
- mlflow/pyfunc/loaders/code_model.py +31 -0
- mlflow/pyfunc/loaders/responses_agent.py +112 -0
- mlflow/pyfunc/mlserver.py +46 -0
- mlflow/pyfunc/model.py +1473 -0
- mlflow/pyfunc/scoring_server/__init__.py +604 -0
- mlflow/pyfunc/scoring_server/app.py +7 -0
- mlflow/pyfunc/scoring_server/client.py +146 -0
- mlflow/pyfunc/spark_model_cache.py +48 -0
- mlflow/pyfunc/stdin_server.py +44 -0
- mlflow/pyfunc/utils/__init__.py +3 -0
- mlflow/pyfunc/utils/data_validation.py +224 -0
- mlflow/pyfunc/utils/environment.py +22 -0
- mlflow/pyfunc/utils/input_converter.py +47 -0
- mlflow/pyfunc/utils/serving_data_parser.py +11 -0
- mlflow/pytorch/__init__.py +1171 -0
- mlflow/pytorch/_lightning_autolog.py +580 -0
- mlflow/pytorch/_pytorch_autolog.py +50 -0
- mlflow/pytorch/pickle_module.py +35 -0
- mlflow/rfunc/__init__.py +42 -0
- mlflow/rfunc/backend.py +134 -0
- mlflow/runs.py +89 -0
- mlflow/server/__init__.py +302 -0
- mlflow/server/auth/__init__.py +1224 -0
- mlflow/server/auth/__main__.py +4 -0
- mlflow/server/auth/basic_auth.ini +6 -0
- mlflow/server/auth/cli.py +11 -0
- mlflow/server/auth/client.py +537 -0
- mlflow/server/auth/config.py +34 -0
- mlflow/server/auth/db/__init__.py +0 -0
- mlflow/server/auth/db/cli.py +18 -0
- mlflow/server/auth/db/migrations/__init__.py +0 -0
- mlflow/server/auth/db/migrations/alembic.ini +110 -0
- mlflow/server/auth/db/migrations/env.py +76 -0
- mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
- mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
- mlflow/server/auth/db/models.py +67 -0
- mlflow/server/auth/db/utils.py +37 -0
- mlflow/server/auth/entities.py +165 -0
- mlflow/server/auth/logo.py +14 -0
- mlflow/server/auth/permissions.py +65 -0
- mlflow/server/auth/routes.py +18 -0
- mlflow/server/auth/sqlalchemy_store.py +263 -0
- mlflow/server/graphql/__init__.py +0 -0
- mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
- mlflow/server/graphql/graphql_custom_scalars.py +24 -0
- mlflow/server/graphql/graphql_errors.py +15 -0
- mlflow/server/graphql/graphql_no_batching.py +89 -0
- mlflow/server/graphql/graphql_schema_extensions.py +74 -0
- mlflow/server/handlers.py +3217 -0
- mlflow/server/prometheus_exporter.py +17 -0
- mlflow/server/validation.py +30 -0
- mlflow/shap/__init__.py +691 -0
- mlflow/sklearn/__init__.py +1994 -0
- mlflow/sklearn/utils.py +1041 -0
- mlflow/smolagents/__init__.py +66 -0
- mlflow/smolagents/autolog.py +139 -0
- mlflow/smolagents/chat.py +29 -0
- mlflow/store/__init__.py +10 -0
- mlflow/store/_unity_catalog/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/constants.py +2 -0
- mlflow/store/_unity_catalog/registry/__init__.py +6 -0
- mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
- mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
- mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
- mlflow/store/_unity_catalog/registry/utils.py +121 -0
- mlflow/store/artifact/__init__.py +0 -0
- mlflow/store/artifact/artifact_repo.py +472 -0
- mlflow/store/artifact/artifact_repository_registry.py +154 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
- mlflow/store/artifact/cli.py +141 -0
- mlflow/store/artifact/cloud_artifact_repo.py +332 -0
- mlflow/store/artifact/databricks_artifact_repo.py +729 -0
- mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
- mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
- mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
- mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
- mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
- mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
- mlflow/store/artifact/ftp_artifact_repo.py +132 -0
- mlflow/store/artifact/gcs_artifact_repo.py +296 -0
- mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
- mlflow/store/artifact/http_artifact_repo.py +218 -0
- mlflow/store/artifact/local_artifact_repo.py +142 -0
- mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
- mlflow/store/artifact/models_artifact_repo.py +259 -0
- mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
- mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
- mlflow/store/artifact/r2_artifact_repo.py +70 -0
- mlflow/store/artifact/runs_artifact_repo.py +265 -0
- mlflow/store/artifact/s3_artifact_repo.py +330 -0
- mlflow/store/artifact/sftp_artifact_repo.py +141 -0
- mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
- mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
- mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
- mlflow/store/artifact/utils/__init__.py +0 -0
- mlflow/store/artifact/utils/models.py +148 -0
- mlflow/store/db/__init__.py +0 -0
- mlflow/store/db/base_sql_model.py +3 -0
- mlflow/store/db/db_types.py +10 -0
- mlflow/store/db/utils.py +314 -0
- mlflow/store/db_migrations/__init__.py +0 -0
- mlflow/store/db_migrations/alembic.ini +74 -0
- mlflow/store/db_migrations/env.py +84 -0
- mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
- mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
- mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
- mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
- mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
- mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
- mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
- mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
- mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
- mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
- mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
- mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
- mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
- mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
- mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
- mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
- mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
- mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
- mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
- mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
- mlflow/store/db_migrations/versions/__init__.py +0 -0
- mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
- mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
- mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
- mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
- mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
- mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
- mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
- mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
- mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
- mlflow/store/entities/__init__.py +3 -0
- mlflow/store/entities/paged_list.py +18 -0
- mlflow/store/model_registry/__init__.py +10 -0
- mlflow/store/model_registry/abstract_store.py +1081 -0
- mlflow/store/model_registry/base_rest_store.py +44 -0
- mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
- mlflow/store/model_registry/dbmodels/__init__.py +0 -0
- mlflow/store/model_registry/dbmodels/models.py +206 -0
- mlflow/store/model_registry/file_store.py +1091 -0
- mlflow/store/model_registry/rest_store.py +481 -0
- mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
- mlflow/store/tracking/__init__.py +23 -0
- mlflow/store/tracking/abstract_store.py +816 -0
- mlflow/store/tracking/dbmodels/__init__.py +0 -0
- mlflow/store/tracking/dbmodels/initial_models.py +243 -0
- mlflow/store/tracking/dbmodels/models.py +1073 -0
- mlflow/store/tracking/file_store.py +2438 -0
- mlflow/store/tracking/postgres_managed_identity.py +146 -0
- mlflow/store/tracking/rest_store.py +1131 -0
- mlflow/store/tracking/sqlalchemy_store.py +2785 -0
- mlflow/system_metrics/__init__.py +61 -0
- mlflow/system_metrics/metrics/__init__.py +0 -0
- mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
- mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
- mlflow/system_metrics/metrics/disk_monitor.py +21 -0
- mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
- mlflow/system_metrics/metrics/network_monitor.py +34 -0
- mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
- mlflow/system_metrics/system_metrics_monitor.py +198 -0
- mlflow/tracing/__init__.py +16 -0
- mlflow/tracing/assessment.py +356 -0
- mlflow/tracing/client.py +531 -0
- mlflow/tracing/config.py +125 -0
- mlflow/tracing/constant.py +105 -0
- mlflow/tracing/destination.py +81 -0
- mlflow/tracing/display/__init__.py +40 -0
- mlflow/tracing/display/display_handler.py +196 -0
- mlflow/tracing/export/async_export_queue.py +186 -0
- mlflow/tracing/export/inference_table.py +138 -0
- mlflow/tracing/export/mlflow_v3.py +137 -0
- mlflow/tracing/export/utils.py +70 -0
- mlflow/tracing/fluent.py +1417 -0
- mlflow/tracing/processor/base_mlflow.py +199 -0
- mlflow/tracing/processor/inference_table.py +175 -0
- mlflow/tracing/processor/mlflow_v3.py +47 -0
- mlflow/tracing/processor/otel.py +73 -0
- mlflow/tracing/provider.py +487 -0
- mlflow/tracing/trace_manager.py +200 -0
- mlflow/tracing/utils/__init__.py +616 -0
- mlflow/tracing/utils/artifact_utils.py +28 -0
- mlflow/tracing/utils/copy.py +55 -0
- mlflow/tracing/utils/environment.py +55 -0
- mlflow/tracing/utils/exception.py +21 -0
- mlflow/tracing/utils/once.py +35 -0
- mlflow/tracing/utils/otlp.py +63 -0
- mlflow/tracing/utils/processor.py +54 -0
- mlflow/tracing/utils/search.py +292 -0
- mlflow/tracing/utils/timeout.py +250 -0
- mlflow/tracing/utils/token.py +19 -0
- mlflow/tracing/utils/truncation.py +124 -0
- mlflow/tracing/utils/warning.py +76 -0
- mlflow/tracking/__init__.py +39 -0
- mlflow/tracking/_model_registry/__init__.py +1 -0
- mlflow/tracking/_model_registry/client.py +764 -0
- mlflow/tracking/_model_registry/fluent.py +853 -0
- mlflow/tracking/_model_registry/registry.py +67 -0
- mlflow/tracking/_model_registry/utils.py +251 -0
- mlflow/tracking/_tracking_service/__init__.py +0 -0
- mlflow/tracking/_tracking_service/client.py +883 -0
- mlflow/tracking/_tracking_service/registry.py +56 -0
- mlflow/tracking/_tracking_service/utils.py +275 -0
- mlflow/tracking/artifact_utils.py +179 -0
- mlflow/tracking/client.py +5900 -0
- mlflow/tracking/context/__init__.py +0 -0
- mlflow/tracking/context/abstract_context.py +35 -0
- mlflow/tracking/context/databricks_cluster_context.py +15 -0
- mlflow/tracking/context/databricks_command_context.py +15 -0
- mlflow/tracking/context/databricks_job_context.py +49 -0
- mlflow/tracking/context/databricks_notebook_context.py +41 -0
- mlflow/tracking/context/databricks_repo_context.py +43 -0
- mlflow/tracking/context/default_context.py +51 -0
- mlflow/tracking/context/git_context.py +32 -0
- mlflow/tracking/context/registry.py +98 -0
- mlflow/tracking/context/system_environment_context.py +15 -0
- mlflow/tracking/default_experiment/__init__.py +1 -0
- mlflow/tracking/default_experiment/abstract_context.py +43 -0
- mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
- mlflow/tracking/default_experiment/registry.py +75 -0
- mlflow/tracking/fluent.py +3595 -0
- mlflow/tracking/metric_value_conversion_utils.py +93 -0
- mlflow/tracking/multimedia.py +206 -0
- mlflow/tracking/registry.py +86 -0
- mlflow/tracking/request_auth/__init__.py +0 -0
- mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
- mlflow/tracking/request_auth/registry.py +60 -0
- mlflow/tracking/request_header/__init__.py +0 -0
- mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
- mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
- mlflow/tracking/request_header/default_request_header_provider.py +17 -0
- mlflow/tracking/request_header/registry.py +79 -0
- mlflow/transformers/__init__.py +2982 -0
- mlflow/transformers/flavor_config.py +258 -0
- mlflow/transformers/hub_utils.py +83 -0
- mlflow/transformers/llm_inference_utils.py +468 -0
- mlflow/transformers/model_io.py +301 -0
- mlflow/transformers/peft.py +51 -0
- mlflow/transformers/signature.py +183 -0
- mlflow/transformers/torch_utils.py +55 -0
- mlflow/types/__init__.py +21 -0
- mlflow/types/agent.py +270 -0
- mlflow/types/chat.py +240 -0
- mlflow/types/llm.py +935 -0
- mlflow/types/responses.py +139 -0
- mlflow/types/responses_helpers.py +416 -0
- mlflow/types/schema.py +1505 -0
- mlflow/types/type_hints.py +647 -0
- mlflow/types/utils.py +753 -0
- mlflow/utils/__init__.py +283 -0
- mlflow/utils/_capture_modules.py +256 -0
- mlflow/utils/_capture_transformers_modules.py +75 -0
- mlflow/utils/_spark_utils.py +201 -0
- mlflow/utils/_unity_catalog_oss_utils.py +97 -0
- mlflow/utils/_unity_catalog_utils.py +479 -0
- mlflow/utils/annotations.py +218 -0
- mlflow/utils/arguments_utils.py +16 -0
- mlflow/utils/async_logging/__init__.py +1 -0
- mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
- mlflow/utils/async_logging/async_logging_queue.py +366 -0
- mlflow/utils/async_logging/run_artifact.py +38 -0
- mlflow/utils/async_logging/run_batch.py +58 -0
- mlflow/utils/async_logging/run_operations.py +49 -0
- mlflow/utils/autologging_utils/__init__.py +737 -0
- mlflow/utils/autologging_utils/client.py +432 -0
- mlflow/utils/autologging_utils/config.py +33 -0
- mlflow/utils/autologging_utils/events.py +294 -0
- mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
- mlflow/utils/autologging_utils/metrics_queue.py +71 -0
- mlflow/utils/autologging_utils/safety.py +1104 -0
- mlflow/utils/autologging_utils/versioning.py +95 -0
- mlflow/utils/checkpoint_utils.py +206 -0
- mlflow/utils/class_utils.py +6 -0
- mlflow/utils/cli_args.py +257 -0
- mlflow/utils/conda.py +354 -0
- mlflow/utils/credentials.py +231 -0
- mlflow/utils/data_utils.py +17 -0
- mlflow/utils/databricks_utils.py +1436 -0
- mlflow/utils/docstring_utils.py +477 -0
- mlflow/utils/doctor.py +133 -0
- mlflow/utils/download_cloud_file_chunk.py +43 -0
- mlflow/utils/env_manager.py +16 -0
- mlflow/utils/env_pack.py +131 -0
- mlflow/utils/environment.py +1009 -0
- mlflow/utils/exception_utils.py +14 -0
- mlflow/utils/file_utils.py +978 -0
- mlflow/utils/git_utils.py +77 -0
- mlflow/utils/gorilla.py +797 -0
- mlflow/utils/import_hooks/__init__.py +363 -0
- mlflow/utils/lazy_load.py +51 -0
- mlflow/utils/logging_utils.py +168 -0
- mlflow/utils/mime_type_utils.py +58 -0
- mlflow/utils/mlflow_tags.py +103 -0
- mlflow/utils/model_utils.py +486 -0
- mlflow/utils/name_utils.py +346 -0
- mlflow/utils/nfs_on_spark.py +62 -0
- mlflow/utils/openai_utils.py +164 -0
- mlflow/utils/os.py +12 -0
- mlflow/utils/oss_registry_utils.py +29 -0
- mlflow/utils/plugins.py +17 -0
- mlflow/utils/process.py +182 -0
- mlflow/utils/promptlab_utils.py +146 -0
- mlflow/utils/proto_json_utils.py +743 -0
- mlflow/utils/pydantic_utils.py +54 -0
- mlflow/utils/request_utils.py +279 -0
- mlflow/utils/requirements_utils.py +704 -0
- mlflow/utils/rest_utils.py +673 -0
- mlflow/utils/search_logged_model_utils.py +127 -0
- mlflow/utils/search_utils.py +2111 -0
- mlflow/utils/secure_loading.py +221 -0
- mlflow/utils/security_validation.py +384 -0
- mlflow/utils/server_cli_utils.py +61 -0
- mlflow/utils/spark_utils.py +15 -0
- mlflow/utils/string_utils.py +138 -0
- mlflow/utils/thread_utils.py +63 -0
- mlflow/utils/time.py +54 -0
- mlflow/utils/timeout.py +42 -0
- mlflow/utils/uri.py +572 -0
- mlflow/utils/validation.py +662 -0
- mlflow/utils/virtualenv.py +458 -0
- mlflow/utils/warnings_utils.py +25 -0
- mlflow/utils/yaml_utils.py +179 -0
- mlflow/version.py +24 -0
@@ -0,0 +1,729 @@
|
|
1
|
+
import base64
|
2
|
+
import json
|
3
|
+
import logging
|
4
|
+
import os
|
5
|
+
import posixpath
|
6
|
+
import uuid
|
7
|
+
from typing import Any, Optional
|
8
|
+
|
9
|
+
import requests
|
10
|
+
|
11
|
+
import mlflow.tracking
|
12
|
+
from mlflow.azure.client import (
|
13
|
+
patch_adls_file_upload,
|
14
|
+
patch_adls_flush,
|
15
|
+
put_adls_file_creation,
|
16
|
+
put_block,
|
17
|
+
put_block_list,
|
18
|
+
)
|
19
|
+
from mlflow.entities import FileInfo
|
20
|
+
from mlflow.environment_variables import (
|
21
|
+
MLFLOW_ASYNC_TRACE_LOGGING_RETRY_TIMEOUT,
|
22
|
+
MLFLOW_MULTIPART_DOWNLOAD_CHUNK_SIZE,
|
23
|
+
MLFLOW_MULTIPART_UPLOAD_CHUNK_SIZE,
|
24
|
+
MLFLOW_MULTIPART_UPLOAD_MINIMUM_FILE_SIZE,
|
25
|
+
)
|
26
|
+
from mlflow.exceptions import (
|
27
|
+
MlflowException,
|
28
|
+
MlflowTraceDataCorrupted,
|
29
|
+
MlflowTraceDataNotFound,
|
30
|
+
)
|
31
|
+
from mlflow.protos.databricks_artifacts_pb2 import (
|
32
|
+
ArtifactCredentialType,
|
33
|
+
CompleteMultipartUpload,
|
34
|
+
CreateMultipartUpload,
|
35
|
+
DatabricksMlflowArtifactsService,
|
36
|
+
GetPresignedUploadPartUrl,
|
37
|
+
PartEtag,
|
38
|
+
)
|
39
|
+
from mlflow.protos.databricks_pb2 import (
|
40
|
+
INTERNAL_ERROR,
|
41
|
+
INVALID_PARAMETER_VALUE,
|
42
|
+
)
|
43
|
+
from mlflow.protos.service_pb2 import MlflowService
|
44
|
+
from mlflow.store.artifact.artifact_repo import write_local_temp_trace_data_file
|
45
|
+
from mlflow.store.artifact.cloud_artifact_repo import (
|
46
|
+
CloudArtifactRepository,
|
47
|
+
_complete_futures,
|
48
|
+
_compute_num_chunks,
|
49
|
+
_validate_chunk_size_aws,
|
50
|
+
)
|
51
|
+
from mlflow.store.artifact.databricks_artifact_repo_resources import (
|
52
|
+
_CredentialType,
|
53
|
+
_LoggedModel,
|
54
|
+
_Resource,
|
55
|
+
_Run,
|
56
|
+
_Trace,
|
57
|
+
)
|
58
|
+
from mlflow.tracing.constant import TRACE_REQUEST_ID_PREFIX
|
59
|
+
from mlflow.utils import chunk_list
|
60
|
+
from mlflow.utils.databricks_utils import get_databricks_host_creds
|
61
|
+
from mlflow.utils.file_utils import (
|
62
|
+
download_file_using_http_uri,
|
63
|
+
read_chunk,
|
64
|
+
)
|
65
|
+
from mlflow.utils.proto_json_utils import message_to_json
|
66
|
+
from mlflow.utils.request_utils import cloud_storage_http_request
|
67
|
+
from mlflow.utils.rest_utils import (
|
68
|
+
_REST_API_PATH_PREFIX,
|
69
|
+
augmented_raise_for_status,
|
70
|
+
call_endpoint,
|
71
|
+
extract_api_info_for_service,
|
72
|
+
)
|
73
|
+
from mlflow.utils.uri import (
|
74
|
+
extract_and_normalize_path,
|
75
|
+
get_databricks_profile_uri_from_artifact_uri,
|
76
|
+
is_databricks_acled_artifacts_uri,
|
77
|
+
is_valid_dbfs_uri,
|
78
|
+
remove_databricks_profile_info_from_artifact_uri,
|
79
|
+
)
|
80
|
+
|
81
|
+
# Module-level logger used by the artifact repository below.
_logger = logging.getLogger(__name__)

# Max number of artifact paths in a single credentials request
_MAX_CREDENTIALS_REQUEST_SIZE = 2000

# Per-service API info extracted under the REST API path prefix; indexed as
# ``_SERVICE_AND_METHOD_TO_INFO[service][api]`` to obtain an ``(endpoint, method)`` pair.
_SERVICE_AND_METHOD_TO_INFO = {
    svc: extract_api_info_for_service(svc, _REST_API_PATH_PREFIX)
    for svc in [MlflowService, DatabricksMlflowArtifactsService]
}
|
87
|
+
|
88
|
+
|
89
|
+
class DatabricksArtifactRepository(CloudArtifactRepository):
|
90
|
+
"""
|
91
|
+
Performs storage operations on artifacts in the access-controlled
|
92
|
+
`dbfs:/databricks/mlflow-tracking` location.
|
93
|
+
|
94
|
+
Signed access URIs for S3 / Azure Blob Storage are fetched from the MLflow service and used to
|
95
|
+
read and write files from/to this location.
|
96
|
+
|
97
|
+
The artifact_uri is expected to be in one of the following forms:
|
98
|
+
- dbfs:/databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/
|
99
|
+
- databricks/mlflow-tracking/<EXP_ID>/logged_models/<MODEL_ID>/artifacts/<path>
|
100
|
+
"""
|
101
|
+
|
102
|
+
def __init__(self, artifact_uri: str, tracking_uri: Optional[str] = None) -> None:
    """
    Validate ``artifact_uri`` and bind this repository to the MLflow resource it points at.

    Args:
        artifact_uri: DBFS artifact URI, optionally carrying Databricks profile info.
        tracking_uri: Forwarded unchanged to the parent repository constructor.

    Raises:
        MlflowException: With ``INVALID_PARAMETER_VALUE`` when the URI is not a valid DBFS URI
            or is not an access-controlled Databricks artifacts URI.
    """
    uri_is_dbfs = is_valid_dbfs_uri(artifact_uri)
    if not uri_is_dbfs:
        raise MlflowException(
            message="DBFS URI must be of the form dbfs:/<path> or "
            + "dbfs://profile@databricks/<path>",
            error_code=INVALID_PARAMETER_VALUE,
        )

    uri_is_acled = is_databricks_acled_artifacts_uri(artifact_uri)
    if not uri_is_acled:
        raise MlflowException(
            message=(
                "Artifact URI incorrect. Expected path prefix to be"
                " databricks/mlflow-tracking/path/to/artifact/.."
            ),
            error_code=INVALID_PARAMETER_VALUE,
        )

    # Artifact operations must use a dbfs:/ path without the Databricks profile info, so the
    # profile is stripped before the parent class stores ``artifact_uri``.
    profile_free_uri = remove_databricks_profile_info_from_artifact_uri(artifact_uri)
    super().__init__(profile_free_uri, tracking_uri)

    # Prefer the profile embedded in the artifact URI; otherwise use the current tracking URI.
    profile_uri = get_databricks_profile_uri_from_artifact_uri(artifact_uri)
    if not profile_uri:
        profile_uri = mlflow.tracking.get_tracking_uri()
    self.databricks_profile_uri = profile_uri

    self.resource = self._extract_resource(self.artifact_uri)
|
128
|
+
|
129
|
+
def _extract_resource(self, artifact_uri) -> _Resource:
    """
    Build the `_Resource` wrapper for the MLflow entity that owns ``artifact_uri``.

    The artifact_uri is expected to be in one of the following formats:
        - dbfs:/databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/artifacts/<path>
        - dbfs:/databricks/mlflow-tracking/<EXP_ID>/logged_models/<MODEL_ID>/artifacts/<path>
        - databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/artifacts/<path>
        - databricks/mlflow-tracking/<EXP_ID>/logged_models/<MODEL_ID>/artifacts/<path>

    Args:
        artifact_uri: The artifact URI to inspect.

    Returns:
        A `_LoggedModel`, `_Trace` or `_Run` object representing the MLflow resource associated
        with the specified artifact URI.

    Raises:
        MlflowException: With ``INVALID_PARAMETER_VALUE`` when the normalized path has too few
            components to contain a resource ID.
    """
    artifact_path = extract_and_normalize_path(artifact_uri)
    parts = artifact_path.split("/")

    # parts[3] is the run ID, the trace ID, or the literal "logged_models"; in the last case
    # the model ID lives in parts[4]. Reject shorter paths explicitly instead of letting the
    # indexing below fail with a bare IndexError.
    if len(parts) < 4 or (parts[3] == "logged_models" and len(parts) < 5):
        raise MlflowException(
            message=f"Artifact URI '{artifact_uri}' does not contain an MLflow resource ID.",
            error_code=INVALID_PARAMETER_VALUE,
        )

    resource_key = parts[3]
    if resource_key == "logged_models":
        return _LoggedModel(
            id_=parts[4], artifact_uri=artifact_uri, call_endpoint=self._call_endpoint
        )

    if resource_key.startswith(TRACE_REQUEST_ID_PREFIX):
        return _Trace(
            id_=resource_key, artifact_uri=artifact_uri, call_endpoint=self._call_endpoint
        )

    return _Run(id_=resource_key, artifact_uri=artifact_uri, call_endpoint=self._call_endpoint)
|
155
|
+
|
156
|
+
@staticmethod
def _extract_run_id(artifact_uri: str) -> Optional[str]:
    """
    Return the run ID embedded in a run artifact URI.

    Args:
        artifact_uri: The artifact URI to inspect.

    Returns:
        The fourth component of the normalized path, or ``None`` when the path has fewer than
        four components or that component denotes a logged model or a trace.
    """
    segments = extract_and_normalize_path(artifact_uri).split("/")
    if len(segments) >= 4:
        candidate = segments[3]
        is_logged_model = candidate == "logged_models"
        if not is_logged_model and not candidate.startswith(TRACE_REQUEST_ID_PREFIX):
            return candidate
    return None
|
170
|
+
|
171
|
+
def _call_endpoint(
    self, service, api, json_body=None, path_params=None, retry_timeout_seconds=None
):
    """
    Invoke a REST API of ``service`` using the host credentials of this repository's profile.

    Args:
        service: The service to call; must be a key of ``_SERVICE_AND_METHOD_TO_INFO``.
        api: The API to call; its ``Response`` class is instantiated to hold the reply.
        json_body: The JSON body of the request.
        path_params: Mapping substituted into the endpoint template via ``str.format``.
        retry_timeout_seconds: The timeout in seconds for retrying failed requests.

    Returns:
        Whatever ``call_endpoint`` returns for the request.
    """
    host_creds = get_databricks_host_creds(self.databricks_profile_uri)
    api_info = _SERVICE_AND_METHOD_TO_INFO[service][api]
    endpoint, method = api_info
    # Only format the endpoint template when there is something to substitute.
    endpoint = endpoint.format(**path_params) if path_params else endpoint

    return call_endpoint(
        host_creds=host_creds,
        endpoint=endpoint,
        method=method,
        json_body=json_body,
        response_proto=api.Response(),
        retry_timeout_seconds=retry_timeout_seconds,
    )
|
201
|
+
|
202
|
+
def _get_credential_infos(self, cred_type: _CredentialType, paths: list[str]):
    """
    Fetch artifact credentials of the given type for ``paths`` from ``self.resource``.

    The paths are split into batches of at most ``_MAX_CREDENTIALS_REQUEST_SIZE`` entries, and
    every batch is paged through until no page token is returned or a page comes back empty.

    Args:
        cred_type: Specifies the type of access credentials, read or write.
        paths: The resource-relative artifact paths to request credentials for.

    Returns:
        A flat list of the credential infos returned across all batches and pages.
    """
    collected = []

    for batch in chunk_list(paths, _MAX_CREDENTIALS_REQUEST_SIZE):
        token = None
        has_more = True
        while has_more:
            page, token = self.resource.get_credentials(
                cred_type=cred_type,
                paths=batch,
                page_token=token,
            )
            collected.extend(page)
            # Keep paging only while a token was returned and the page had results.
            has_more = bool(token) and len(page) != 0

    return collected
|
233
|
+
|
234
|
+
def _get_write_credential_infos(self, remote_file_paths):
    """
    Request write credentials for the given artifact paths.

    Args:
        remote_file_paths: Iterable of artifact paths; a falsy entry is treated as ``""``.

    Returns:
        The result of `_get_credential_infos` for ``_CredentialType.WRITE`` and the paths
        joined onto ``self.resource.relative_path``.
    """
    resource_relative_paths = []
    for remote_path in remote_file_paths:
        suffix = remote_path if remote_path else ""
        resource_relative_paths.append(posixpath.join(self.resource.relative_path, suffix))
    return self._get_credential_infos(_CredentialType.WRITE, resource_relative_paths)
|
243
|
+
|
244
|
+
def download_trace_data(self) -> dict[str, Any]:
    """
    Download and parse the trace data stored for ``self.resource``.

    Returns:
        The JSON-decoded response body.

    Raises:
        MlflowTraceDataNotFound: If the storage request fails with HTTP 404.
        MlflowTraceDataCorrupted: If the response body is not valid JSON.
        requests.HTTPError: For any other failed HTTP status.
    """
    cred_infos, _ = self.resource.get_credentials(cred_type=_CredentialType.READ)
    # Exactly one credential is expected; unpacking fails otherwise.
    [cred] = cred_infos
    request_headers = self._extract_headers_from_credentials(cred.headers)

    with cloud_storage_http_request("get", cred.signed_uri, headers=request_headers) as resp:
        try:
            augmented_raise_for_status(resp)
        except requests.HTTPError as e:
            if e.response.status_code != 404:
                raise
            raise MlflowTraceDataNotFound(request_id=self.resource.id) from e

        try:
            payload = json.loads(resp.content)
        except json.JSONDecodeError as e:
            raise MlflowTraceDataCorrupted(request_id=self.resource.id) from e
        return payload
|
260
|
+
|
261
|
+
def upload_trace_data(self, trace_data: str) -> None:
    """
    Upload serialized trace data to the storage location of ``self.resource``.

    The data is written to a local temporary file and uploaded with the mechanism matching the
    type of the write credential returned for the resource.

    Args:
        trace_data: The serialized trace data to upload.

    Raises:
        MlflowException: With ``INTERNAL_ERROR`` if the credential type is not one of the
            supported cloud providers, or if the underlying upload fails.
    """

    def _fetch_write_credentials(artifact_paths=None):
        # Used both for the initial request and as the refresh callback of the Azure upload
        # helpers, which call ``get_credentials([path])[0]``; the paths are ignored because
        # the credential is resolved from the resource itself.
        cred_infos, _ = self.resource.get_credentials(
            cred_type=_CredentialType.WRITE,
            timeout=MLFLOW_ASYNC_TRACE_LOGGING_RETRY_TIMEOUT.get(),
        )
        return cred_infos

    # Exactly one credential is expected; unpacking fails otherwise.
    [cred] = _fetch_write_credentials()

    with write_local_temp_trace_data_file(trace_data) as temp_file:
        if cred.type == ArtifactCredentialType.AZURE_ADLS_GEN2_SAS_URI:
            self._azure_adls_gen2_upload_file(
                credentials=cred,
                local_file=temp_file,
                artifact_file_path=None,
                get_credentials=_fetch_write_credentials,
            )
        elif cred.type == ArtifactCredentialType.AZURE_SAS_URI:
            self._azure_upload_file(
                credentials=cred,
                local_file=temp_file,
                artifact_file_path=None,
                get_credentials=_fetch_write_credentials,
            )
        elif cred.type in (
            ArtifactCredentialType.AWS_PRESIGNED_URL,
            ArtifactCredentialType.GCP_SIGNED_URL,
        ):
            self._signed_url_upload_file(cred, temp_file)
        else:
            # Fail loudly rather than silently dropping the trace data.
            raise MlflowException(
                message="Cloud provider not supported.", error_code=INTERNAL_ERROR
            )
|
290
|
+
|
291
|
+
def _get_read_credential_infos(self, remote_file_paths):
    """
    Request read credentials for one or more artifact paths.

    Args:
        remote_file_paths: A single path string or a list of path strings.

    Returns:
        A list of `ArtifactCredentialInfo` objects providing read access to the specified
        relative artifact `paths` within the MLflow resource specified.

    Raises:
        MlflowException: If ``remote_file_paths`` is neither a string nor a list.
    """
    # isinstance (rather than exact type comparison) also accepts str/list subclasses.
    if isinstance(remote_file_paths, str):
        remote_file_paths = [remote_file_paths]
    if not isinstance(remote_file_paths, list):
        raise MlflowException(
            f"Expected `paths` to be a list of strings. Got {type(remote_file_paths)}"
        )
    relative_remote_paths = [
        posixpath.join(self.resource.relative_path, p) for p in remote_file_paths
    ]
    return self._get_credential_infos(_CredentialType.READ, relative_remote_paths)
|
307
|
+
|
308
|
+
def _extract_headers_from_credentials(self, headers):
|
309
|
+
"""
|
310
|
+
Returns:
|
311
|
+
A python dictionary of http headers converted from the protobuf credentials.
|
312
|
+
"""
|
313
|
+
return {header.name: header.value for header in headers}
|
314
|
+
|
315
|
+
def _azure_upload_chunk(
    self,
    credentials,
    headers,
    local_file,
    artifact_file_path,
    start_byte,
    size,
    get_credentials,
):
    """
    Upload one chunk of a local file as a block to an Azure storage location.

    Args:
        credentials: The credentials for the upload; ``signed_uri`` is the block target.
        headers: The headers for the upload.
        local_file: The local file to read the chunk from.
        artifact_file_path: The artifact path passed to ``get_credentials`` on refresh.
        start_byte: The starting byte of the chunk.
        size: The size of the chunk.
        get_credentials: The function to call to get new credentials; it receives a
            one-element list of paths and its first result is used.

    Returns:
        The generated block ID under which the chunk was uploaded.

    Raises:
        requests.HTTPError: If the upload fails with a status other than 401/403, or the
            retry with refreshed credentials fails.
    """
    # The block ID is a base64-encoded UUID hex string, decoded back to ``str`` for
    # compliance with Azure Blob Storage API requests.
    raw_id = uuid.uuid4().hex.encode()
    block_id = base64.b64encode(raw_id).decode("utf-8")
    chunk = read_chunk(local_file, size, start_byte)
    try:
        put_block(credentials.signed_uri, block_id, chunk, headers=headers)
    except requests.HTTPError as e:
        if e.response.status_code not in [401, 403]:
            raise e
        _logger.info(
            "Failed to authorize request, possibly due to credential expiration."
            " Refreshing credentials and trying again..."
        )
        refreshed = get_credentials([artifact_file_path])[0]
        put_block(refreshed.signed_uri, block_id, chunk, headers=headers)
    return block_id
|
354
|
+
|
355
|
+
def _azure_upload_file(self, credentials, local_file, artifact_file_path, get_credentials):
    """
    Uploads a file to a given Azure storage location.

    The file is split into chunks of ``MLFLOW_MULTIPART_UPLOAD_CHUNK_SIZE`` bytes that are
    staged as blocks on ``self.chunk_thread_pool``; the blocks are then committed in chunk
    order with a block list. Both the per-chunk upload and the final commit refresh the
    credentials once via ``get_credentials`` when the request fails with 401/403, since the
    credentials may expire during a long upload.

    Args:
        credentials: The credentials for the upload.
        local_file: The local file to upload.
        artifact_file_path: The path to the artifact file.
        get_credentials: The function to call to get new credentials.

    Raises:
        MlflowException: If any chunk fails to upload or any other error occurs; the original
            error is attached as the exception cause.
    """
    try:
        headers = self._extract_headers_from_credentials(credentials.headers)
        # Read the chunk size once so the chunk count, offsets and sizes always agree.
        chunk_size = MLFLOW_MULTIPART_UPLOAD_CHUNK_SIZE.get()
        num_chunks = _compute_num_chunks(local_file, chunk_size)
        futures = {}
        for index in range(num_chunks):
            future = self.chunk_thread_pool.submit(
                self._azure_upload_chunk,
                credentials=credentials,
                headers=headers,
                local_file=local_file,
                artifact_file_path=artifact_file_path,
                start_byte=index * chunk_size,
                size=chunk_size,
                get_credentials=get_credentials,
            )
            futures[future] = index

        results, errors = _complete_futures(futures, local_file)
        if errors:
            raise MlflowException(
                f"Failed to upload at least one part of {local_file}. Errors: {errors}"
            )
        # Sort results by the chunk index
        uploading_block_list = [results[index] for index in sorted(results)]

        try:
            put_block_list(credentials.signed_uri, uploading_block_list, headers=headers)
        except requests.HTTPError as e:
            if e.response.status_code not in (401, 403):
                raise
            _logger.info(
                "Failed to authorize request, possibly due to credential expiration."
                " Refreshing credentials and trying again..."
            )
            credential_info = get_credentials([artifact_file_path])[0]
            put_block_list(credential_info.signed_uri, uploading_block_list, headers=headers)
    except Exception as err:
        raise MlflowException(err) from err
|
414
|
+
|
415
|
+
def _retryable_adls_function(self, func, artifact_file_path, get_credentials, **kwargs):
|
416
|
+
"""
|
417
|
+
Calls the passed function, retrying if the credentials have expired.
|
418
|
+
|
419
|
+
Args:
|
420
|
+
func: The function to call.
|
421
|
+
artifact_file_path: The artifact file path.
|
422
|
+
get_credentials: The function to call to get new credentials.
|
423
|
+
**kwargs: The keyword arguments to pass to the function.
|
424
|
+
"""
|
425
|
+
# Attempt to call the passed function. Retry if the credentials have expired
|
426
|
+
try:
|
427
|
+
func(**kwargs)
|
428
|
+
except requests.HTTPError as e:
|
429
|
+
if e.response.status_code in [403]:
|
430
|
+
_logger.info(
|
431
|
+
"Failed to authorize ADLS operation, possibly due "
|
432
|
+
"to credential expiration. Refreshing credentials and trying again..."
|
433
|
+
)
|
434
|
+
new_credentials = get_credentials([artifact_file_path])[0]
|
435
|
+
kwargs["sas_url"] = new_credentials.signed_uri
|
436
|
+
func(**kwargs)
|
437
|
+
else:
|
438
|
+
raise e
|
439
|
+
|
440
|
+
def _azure_adls_gen2_upload_file(
    self, credentials, local_file, artifact_file_path, get_credentials
):
    """
    Uploads a file to a given Azure storage location using the ADLS gen2 API.

    The upload creates the remote file, appends the local file in chunks of
    ``MLFLOW_MULTIPART_UPLOAD_CHUNK_SIZE`` bytes on ``self.chunk_thread_pool``, and finally
    flushes it. Each step goes through `_retryable_adls_function` so that expired credentials
    are refreshed once.

    Args:
        credentials: The credentials for the upload.
        local_file: The local file to upload.
        artifact_file_path: The path to the artifact file.
        get_credentials: The function to call to get new credentials.

    Raises:
        MlflowException: If any chunk fails to upload or any other error occurs; the original
            error is attached as the exception cause.
    """
    try:
        headers = self._extract_headers_from_credentials(credentials.headers)

        # try to create the file
        self._retryable_adls_function(
            func=put_adls_file_creation,
            artifact_file_path=artifact_file_path,
            get_credentials=get_credentials,
            sas_url=credentials.signed_uri,
            headers=headers,
        )

        # next try to append the file
        file_size = os.path.getsize(local_file)
        # Read the chunk size once so the chunk count, offsets and sizes always agree.
        chunk_size = MLFLOW_MULTIPART_UPLOAD_CHUNK_SIZE.get()
        num_chunks = _compute_num_chunks(local_file, chunk_size)
        use_single_part_upload = num_chunks == 1
        futures = {}
        for index in range(num_chunks):
            start_byte = index * chunk_size
            future = self.chunk_thread_pool.submit(
                self._retryable_adls_function,
                func=patch_adls_file_upload,
                artifact_file_path=artifact_file_path,
                get_credentials=get_credentials,
                sas_url=credentials.signed_uri,
                local_file=local_file,
                start_byte=start_byte,
                size=chunk_size,
                position=start_byte,
                headers=headers,
                is_single=use_single_part_upload,
            )
            futures[future] = index

        _, errors = _complete_futures(futures, local_file)
        if errors:
            raise MlflowException(
                f"Failed to upload at least one part of {artifact_file_path}. Errors: {errors}"
            )

        # finally try to flush the file; skipped for single-part uploads
        # NOTE(review): presumably ``is_single=True`` makes the append call flush as well;
        # confirm against patch_adls_file_upload.
        if not use_single_part_upload:
            self._retryable_adls_function(
                func=patch_adls_flush,
                artifact_file_path=artifact_file_path,
                get_credentials=get_credentials,
                sas_url=credentials.signed_uri,
                position=file_size,
                headers=headers,
            )
    except Exception as err:
        raise MlflowException(err) from err
|
504
|
+
|
505
|
+
def _signed_url_upload_file(self, credentials, local_file):
    """
    Upload a local file with a single HTTP PUT to the signed URI in ``credentials``.

    Args:
        credentials: Credential info providing ``signed_uri`` and ``headers``.
        local_file: Path of the local file to upload.

    Raises:
        MlflowException: Wrapping any error raised while reading or uploading the file.
    """
    try:
        headers = self._extract_headers_from_credentials(credentials.headers)
        signed_write_uri = credentials.signed_uri

        def _put(payload):
            with cloud_storage_http_request(
                "put", signed_write_uri, data=payload, headers=headers
            ) as response:
                augmented_raise_for_status(response)

        # Putting an empty file in a request by reading file bytes gives 501 error, so an
        # empty string body is sent instead of the open file.
        if os.stat(local_file).st_size == 0:
            _put("")
        else:
            with open(local_file, "rb") as file:
                _put(file)
    except Exception as err:
        raise MlflowException(err)
|
523
|
+
|
524
|
+
def _upload_to_cloud(self, cloud_credential_info, src_file_path, artifact_file_path):
    """
    Upload a local file using the mechanism that matches the credential type.

    AWS uploads switch to multipart when the file is larger than
    ``MLFLOW_MULTIPART_UPLOAD_MINIMUM_FILE_SIZE``; GCP and smaller AWS uploads use a single
    signed-URL PUT; the two Azure types use their dedicated upload helpers.

    Args:
        cloud_credential_info: ArtifactCredentialInfo object with presigned URL for the file.
        src_file_path: Local source file path for the upload.
        artifact_file_path: Path in the artifact repository, relative to the resource root
            path, where the artifact will be logged.

    Raises:
        MlflowException: With ``INTERNAL_ERROR`` if the credential type is not supported.
    """
    cred_type = cloud_credential_info.type

    if cred_type == ArtifactCredentialType.AZURE_SAS_URI:
        self._azure_upload_file(
            cloud_credential_info,
            src_file_path,
            artifact_file_path,
            get_credentials=self._get_write_credential_infos,
        )
        return

    if cred_type == ArtifactCredentialType.AZURE_ADLS_GEN2_SAS_URI:
        self._azure_adls_gen2_upload_file(
            cloud_credential_info,
            src_file_path,
            artifact_file_path,
            self._get_write_credential_infos,
        )
        return

    if cred_type == ArtifactCredentialType.AWS_PRESIGNED_URL:
        needs_multipart = (
            os.path.getsize(src_file_path) > MLFLOW_MULTIPART_UPLOAD_MINIMUM_FILE_SIZE.get()
        )
        if needs_multipart:
            _validate_chunk_size_aws(MLFLOW_MULTIPART_UPLOAD_CHUNK_SIZE.get())
            self._multipart_upload(src_file_path, artifact_file_path)
        else:
            self._signed_url_upload_file(cloud_credential_info, src_file_path)
        return

    if cred_type == ArtifactCredentialType.GCP_SIGNED_URL:
        self._signed_url_upload_file(cloud_credential_info, src_file_path)
        return

    raise MlflowException(message="Cloud provider not supported.", error_code=INTERNAL_ERROR)
|
562
|
+
|
563
|
+
def _download_from_cloud(self, remote_file_path, local_path):
    """
    Download a file from the input `remote_file_path` and save it to `local_path`.

    Args:
        remote_file_path: Path relative to the resource root path to file in remote artifact
            repository.
        local_path: Local path to download file to.

    Raises:
        MlflowException: With ``INTERNAL_ERROR`` if the credentials response does not contain
            exactly one entry or its type is not a supported cloud provider; also raised,
            with the original error as its cause, when the download itself fails.
    """
    read_credentials = self._get_read_credential_infos(remote_file_path)
    # Read credentials for only one file were requested, so exactly one value is expected.
    # Checked explicitly because ``assert`` statements are stripped under ``python -O``.
    if len(read_credentials) != 1:
        raise MlflowException(
            message=(
                f"Expected exactly one read credential for '{remote_file_path}', "
                f"got {len(read_credentials)}."
            ),
            error_code=INTERNAL_ERROR,
        )
    cloud_credential_info = read_credentials[0]

    supported_types = (
        ArtifactCredentialType.AZURE_SAS_URI,
        ArtifactCredentialType.AZURE_ADLS_GEN2_SAS_URI,
        ArtifactCredentialType.AWS_PRESIGNED_URL,
        ArtifactCredentialType.GCP_SIGNED_URL,
    )
    if cloud_credential_info.type not in supported_types:
        raise MlflowException(
            message="Cloud provider not supported.", error_code=INTERNAL_ERROR
        )
    try:
        download_file_using_http_uri(
            cloud_credential_info.signed_uri,
            local_path,
            MLFLOW_MULTIPART_DOWNLOAD_CHUNK_SIZE.get(),
            self._extract_headers_from_credentials(cloud_credential_info.headers),
        )
    except Exception as err:
        raise MlflowException(err) from err
|
597
|
+
|
598
|
+
def _create_multipart_upload(self, run_id, path, num_parts):
    """
    Call the ``CreateMultipartUpload`` API of the Databricks artifacts service.

    Args:
        run_id: Value for the request's ``run_id`` field.
        path: Value for the request's ``path`` field.
        num_parts: Value for the request's ``num_parts`` field.

    Returns:
        The response returned by `_call_endpoint`.
    """
    request = CreateMultipartUpload(run_id=run_id, path=path, num_parts=num_parts)
    request_json = message_to_json(request)
    return self._call_endpoint(
        DatabricksMlflowArtifactsService, CreateMultipartUpload, request_json
    )
|
604
|
+
|
605
|
+
def _get_presigned_upload_part_url(self, run_id, path, upload_id, part_number):
    """
    Call the ``GetPresignedUploadPartUrl`` API of the Databricks artifacts service.

    Args:
        run_id: Value for the request's ``run_id`` field.
        path: Value for the request's ``path`` field.
        upload_id: Value for the request's ``upload_id`` field.
        part_number: Value for the request's ``part_number`` field.

    Returns:
        The response returned by `_call_endpoint`.
    """
    request = GetPresignedUploadPartUrl(
        run_id=run_id, path=path, upload_id=upload_id, part_number=part_number
    )
    request_json = message_to_json(request)
    return self._call_endpoint(
        DatabricksMlflowArtifactsService, GetPresignedUploadPartUrl, request_json
    )
|
615
|
+
|
616
|
+
def _upload_part(self, cred_info, data):
    """
    PUT one part of a multipart upload to the signed URI in ``cred_info``.

    Args:
        cred_info: Credential info providing ``signed_uri`` and ``headers``.
        data: The request body for the part.

    Returns:
        The value of the ``ETag`` response header.
    """
    request_headers = self._extract_headers_from_credentials(cred_info.headers)
    with cloud_storage_http_request(
        "put", cred_info.signed_uri, data=data, headers=request_headers
    ) as response:
        augmented_raise_for_status(response)
        etag = response.headers["ETag"]
    return etag
|
626
|
+
|
627
|
+
def _upload_part_retry(self, cred_info, upload_id, part_number, local_file, start_byte, size):
    """
    Upload one part, retrying once with a fresh presigned URL on HTTP 401/403.

    Args:
        cred_info: Credential info for the part; also supplies ``run_id`` and ``path`` when a
            new presigned URL has to be requested.
        upload_id: ID of the multipart upload.
        part_number: Number of the part being uploaded.
        local_file: Local file to read the part from.
        start_byte: Offset of the part within the file.
        size: Number of bytes to read for the part.

    Returns:
        The ETag returned by `_upload_part`.

    Raises:
        requests.HTTPError: If the upload fails with a status other than 401/403, or the
            retry fails.
    """
    part_bytes = read_chunk(local_file, size, start_byte)
    try:
        return self._upload_part(cred_info, part_bytes)
    except requests.HTTPError as e:
        auth_failed = e.response.status_code in (401, 403)
        if not auth_failed:
            raise e
        _logger.info(
            "Failed to authorize request, possibly due to credential expiration."
            " Refreshing credentials and trying again..."
        )
        refreshed = self._get_presigned_upload_part_url(
            cred_info.run_id, cred_info.path, upload_id, part_number
        )
        return self._upload_part(refreshed.upload_credential_info, part_bytes)
|
642
|
+
|
643
|
+
def _upload_parts(self, local_file, create_mpu_resp):
    """
    Upload every part of a multipart upload on ``self.chunk_thread_pool``.

    One part is submitted per entry of ``create_mpu_resp.upload_credential_infos``; part
    numbers start at 1 and each part covers ``MLFLOW_MULTIPART_UPLOAD_CHUNK_SIZE`` bytes.

    Args:
        local_file: Local file to upload.
        create_mpu_resp: Response of the create-multipart-upload call, providing
            ``upload_id`` and ``upload_credential_infos``.

    Returns:
        A list of `PartEtag` messages ordered by part number.

    Raises:
        MlflowException: If at least one part failed to upload.
    """
    # Read the chunk size once so every part uses the same offsets and size.
    chunk_size = MLFLOW_MULTIPART_UPLOAD_CHUNK_SIZE.get()
    futures = {}
    for index, cred_info in enumerate(create_mpu_resp.upload_credential_infos):
        part_number = index + 1
        future = self.chunk_thread_pool.submit(
            self._upload_part_retry,
            cred_info=cred_info,
            upload_id=create_mpu_resp.upload_id,
            part_number=part_number,
            local_file=local_file,
            start_byte=index * chunk_size,
            size=chunk_size,
        )
        futures[future] = part_number

    results, errors = _complete_futures(futures, local_file)
    if errors:
        raise MlflowException(
            f"Failed to upload at least one part of {local_file}. Errors: {errors}"
        )

    return [
        PartEtag(part_number=part_number, etag=results[part_number])
        for part_number in sorted(results)
    ]
|
669
|
+
|
670
|
+
def _complete_multipart_upload(self, run_id, path, upload_id, part_etags):
    """
    Call the ``CompleteMultipartUpload`` API of the Databricks artifacts service.

    Args:
        run_id: Value for the request's ``run_id`` field.
        path: Value for the request's ``path`` field.
        upload_id: Value for the request's ``upload_id`` field.
        part_etags: Value for the request's ``part_etags`` field.

    Returns:
        The response returned by `_call_endpoint`.
    """
    request = CompleteMultipartUpload(
        run_id=run_id,
        path=path,
        upload_id=upload_id,
        part_etags=part_etags,
    )
    request_json = message_to_json(request)
    return self._call_endpoint(
        DatabricksMlflowArtifactsService, CompleteMultipartUpload, request_json
    )
|
683
|
+
|
684
|
+
def _abort_multipart_upload(self, cred_info):
    """
    Abort a multipart upload by sending an HTTP DELETE to the signed URI in ``cred_info``.

    Args:
        cred_info: Credential info providing ``signed_uri`` and ``headers``.

    Returns:
        The HTTP response of the DELETE request.
    """
    request_headers = self._extract_headers_from_credentials(cred_info.headers)
    with cloud_storage_http_request(
        "delete", cred_info.signed_uri, headers=request_headers
    ) as response:
        augmented_raise_for_status(response)
        abort_response = response
    return abort_response
|
691
|
+
|
692
|
+
def _multipart_upload(self, local_file, artifact_file_path):
    """
    Upload ``local_file`` with a multipart upload, aborting the upload on any failure.

    Args:
        local_file: Local file to upload.
        artifact_file_path: Artifact path joined onto ``self.resource.relative_path``; a
            falsy value is treated as ``""``.

    Raises:
        Exception: Any error from uploading or completing the parts is re-raised after the
            abort request has been sent.
    """
    relative_target = posixpath.join(self.resource.relative_path, artifact_file_path or "")
    part_count = _compute_num_chunks(local_file, MLFLOW_MULTIPART_UPLOAD_CHUNK_SIZE.get())
    mpu = self._create_multipart_upload(self.resource.id, relative_target, part_count)
    try:
        etags = self._upload_parts(local_file, mpu)
        self._complete_multipart_upload(
            self.resource.id,
            relative_target,
            mpu.upload_id,
            etags,
        )
    except Exception as e:
        _logger.warning(
            "Encountered an unexpected error during multipart upload: %s, aborting", e
        )
        self._abort_multipart_upload(mpu.abort_credential_info)
        raise e
|
714
|
+
|
715
|
+
def log_artifact(self, local_file, artifact_path=None):
    """
    Upload a single local file as an artifact.

    Args:
        local_file: Path of the local file; its base name becomes the artifact file name.
        artifact_path: Optional directory the artifact is placed under; a falsy value means
            the file name alone is used as the artifact path.
    """
    target_dir = artifact_path if artifact_path else ""
    artifact_file_path = posixpath.join(target_dir, os.path.basename(local_file))
    credential_infos = self._get_write_credential_infos([artifact_file_path])
    self._upload_to_cloud(
        cloud_credential_info=credential_infos[0],
        src_file_path=local_file,
        artifact_file_path=artifact_file_path,
    )
|
724
|
+
|
725
|
+
def list_artifacts(self, path: Optional[str] = None) -> list[FileInfo]:
    """
    List artifacts by delegating to ``self.resource.list_artifacts``.

    Args:
        path: Optional path passed through unchanged to the resource.

    Returns:
        Whatever the resource's ``list_artifacts`` returns for ``path``.
    """
    resource = self.resource
    return resource.list_artifacts(path)
|
727
|
+
|
728
|
+
def delete_artifacts(self, artifact_path=None):
    """
    Artifact deletion is not supported by this repository.

    Args:
        artifact_path: Ignored.

    Raises:
        MlflowException: Always.
    """
    message = "Not implemented yet"
    raise MlflowException(message)
|