genesis_flow-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genesis_flow-1.0.0.dist-info/METADATA +822 -0
- genesis_flow-1.0.0.dist-info/RECORD +645 -0
- genesis_flow-1.0.0.dist-info/WHEEL +5 -0
- genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
- genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
- genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
- mlflow/__init__.py +367 -0
- mlflow/__main__.py +3 -0
- mlflow/ag2/__init__.py +56 -0
- mlflow/ag2/ag2_logger.py +294 -0
- mlflow/anthropic/__init__.py +40 -0
- mlflow/anthropic/autolog.py +129 -0
- mlflow/anthropic/chat.py +144 -0
- mlflow/artifacts/__init__.py +268 -0
- mlflow/autogen/__init__.py +144 -0
- mlflow/autogen/chat.py +142 -0
- mlflow/azure/__init__.py +26 -0
- mlflow/azure/auth_handler.py +257 -0
- mlflow/azure/client.py +319 -0
- mlflow/azure/config.py +120 -0
- mlflow/azure/connection_factory.py +340 -0
- mlflow/azure/exceptions.py +27 -0
- mlflow/azure/stores.py +327 -0
- mlflow/azure/utils.py +183 -0
- mlflow/bedrock/__init__.py +45 -0
- mlflow/bedrock/_autolog.py +202 -0
- mlflow/bedrock/chat.py +122 -0
- mlflow/bedrock/stream.py +160 -0
- mlflow/bedrock/utils.py +43 -0
- mlflow/cli.py +707 -0
- mlflow/client.py +12 -0
- mlflow/config/__init__.py +56 -0
- mlflow/crewai/__init__.py +79 -0
- mlflow/crewai/autolog.py +253 -0
- mlflow/crewai/chat.py +29 -0
- mlflow/data/__init__.py +75 -0
- mlflow/data/artifact_dataset_sources.py +170 -0
- mlflow/data/code_dataset_source.py +40 -0
- mlflow/data/dataset.py +123 -0
- mlflow/data/dataset_registry.py +168 -0
- mlflow/data/dataset_source.py +110 -0
- mlflow/data/dataset_source_registry.py +219 -0
- mlflow/data/delta_dataset_source.py +167 -0
- mlflow/data/digest_utils.py +108 -0
- mlflow/data/evaluation_dataset.py +562 -0
- mlflow/data/filesystem_dataset_source.py +81 -0
- mlflow/data/http_dataset_source.py +145 -0
- mlflow/data/huggingface_dataset.py +258 -0
- mlflow/data/huggingface_dataset_source.py +118 -0
- mlflow/data/meta_dataset.py +104 -0
- mlflow/data/numpy_dataset.py +223 -0
- mlflow/data/pandas_dataset.py +231 -0
- mlflow/data/polars_dataset.py +352 -0
- mlflow/data/pyfunc_dataset_mixin.py +31 -0
- mlflow/data/schema.py +76 -0
- mlflow/data/sources.py +1 -0
- mlflow/data/spark_dataset.py +406 -0
- mlflow/data/spark_dataset_source.py +74 -0
- mlflow/data/spark_delta_utils.py +118 -0
- mlflow/data/tensorflow_dataset.py +350 -0
- mlflow/data/uc_volume_dataset_source.py +81 -0
- mlflow/db.py +27 -0
- mlflow/dspy/__init__.py +17 -0
- mlflow/dspy/autolog.py +197 -0
- mlflow/dspy/callback.py +398 -0
- mlflow/dspy/constant.py +1 -0
- mlflow/dspy/load.py +93 -0
- mlflow/dspy/save.py +393 -0
- mlflow/dspy/util.py +109 -0
- mlflow/dspy/wrapper.py +226 -0
- mlflow/entities/__init__.py +104 -0
- mlflow/entities/_mlflow_object.py +52 -0
- mlflow/entities/assessment.py +545 -0
- mlflow/entities/assessment_error.py +80 -0
- mlflow/entities/assessment_source.py +141 -0
- mlflow/entities/dataset.py +92 -0
- mlflow/entities/dataset_input.py +51 -0
- mlflow/entities/dataset_summary.py +62 -0
- mlflow/entities/document.py +48 -0
- mlflow/entities/experiment.py +109 -0
- mlflow/entities/experiment_tag.py +35 -0
- mlflow/entities/file_info.py +45 -0
- mlflow/entities/input_tag.py +35 -0
- mlflow/entities/lifecycle_stage.py +35 -0
- mlflow/entities/logged_model.py +228 -0
- mlflow/entities/logged_model_input.py +26 -0
- mlflow/entities/logged_model_output.py +32 -0
- mlflow/entities/logged_model_parameter.py +46 -0
- mlflow/entities/logged_model_status.py +74 -0
- mlflow/entities/logged_model_tag.py +33 -0
- mlflow/entities/metric.py +200 -0
- mlflow/entities/model_registry/__init__.py +29 -0
- mlflow/entities/model_registry/_model_registry_entity.py +13 -0
- mlflow/entities/model_registry/model_version.py +243 -0
- mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
- mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
- mlflow/entities/model_registry/model_version_search.py +25 -0
- mlflow/entities/model_registry/model_version_stages.py +25 -0
- mlflow/entities/model_registry/model_version_status.py +35 -0
- mlflow/entities/model_registry/model_version_tag.py +35 -0
- mlflow/entities/model_registry/prompt.py +73 -0
- mlflow/entities/model_registry/prompt_version.py +244 -0
- mlflow/entities/model_registry/registered_model.py +175 -0
- mlflow/entities/model_registry/registered_model_alias.py +35 -0
- mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
- mlflow/entities/model_registry/registered_model_search.py +25 -0
- mlflow/entities/model_registry/registered_model_tag.py +35 -0
- mlflow/entities/multipart_upload.py +74 -0
- mlflow/entities/param.py +49 -0
- mlflow/entities/run.py +97 -0
- mlflow/entities/run_data.py +84 -0
- mlflow/entities/run_info.py +188 -0
- mlflow/entities/run_inputs.py +59 -0
- mlflow/entities/run_outputs.py +43 -0
- mlflow/entities/run_status.py +41 -0
- mlflow/entities/run_tag.py +36 -0
- mlflow/entities/source_type.py +31 -0
- mlflow/entities/span.py +774 -0
- mlflow/entities/span_event.py +96 -0
- mlflow/entities/span_status.py +102 -0
- mlflow/entities/trace.py +317 -0
- mlflow/entities/trace_data.py +71 -0
- mlflow/entities/trace_info.py +220 -0
- mlflow/entities/trace_info_v2.py +162 -0
- mlflow/entities/trace_location.py +173 -0
- mlflow/entities/trace_state.py +39 -0
- mlflow/entities/trace_status.py +68 -0
- mlflow/entities/view_type.py +51 -0
- mlflow/environment_variables.py +866 -0
- mlflow/evaluation/__init__.py +16 -0
- mlflow/evaluation/assessment.py +369 -0
- mlflow/evaluation/evaluation.py +411 -0
- mlflow/evaluation/evaluation_tag.py +61 -0
- mlflow/evaluation/fluent.py +48 -0
- mlflow/evaluation/utils.py +201 -0
- mlflow/exceptions.py +213 -0
- mlflow/experiments.py +140 -0
- mlflow/gemini/__init__.py +81 -0
- mlflow/gemini/autolog.py +186 -0
- mlflow/gemini/chat.py +261 -0
- mlflow/genai/__init__.py +71 -0
- mlflow/genai/datasets/__init__.py +67 -0
- mlflow/genai/datasets/evaluation_dataset.py +131 -0
- mlflow/genai/evaluation/__init__.py +3 -0
- mlflow/genai/evaluation/base.py +411 -0
- mlflow/genai/evaluation/constant.py +23 -0
- mlflow/genai/evaluation/utils.py +244 -0
- mlflow/genai/judges/__init__.py +21 -0
- mlflow/genai/judges/databricks.py +404 -0
- mlflow/genai/label_schemas/__init__.py +153 -0
- mlflow/genai/label_schemas/label_schemas.py +209 -0
- mlflow/genai/labeling/__init__.py +159 -0
- mlflow/genai/labeling/labeling.py +250 -0
- mlflow/genai/optimize/__init__.py +13 -0
- mlflow/genai/optimize/base.py +198 -0
- mlflow/genai/optimize/optimizers/__init__.py +4 -0
- mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
- mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
- mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
- mlflow/genai/optimize/types.py +75 -0
- mlflow/genai/optimize/util.py +30 -0
- mlflow/genai/prompts/__init__.py +206 -0
- mlflow/genai/scheduled_scorers.py +431 -0
- mlflow/genai/scorers/__init__.py +26 -0
- mlflow/genai/scorers/base.py +492 -0
- mlflow/genai/scorers/builtin_scorers.py +765 -0
- mlflow/genai/scorers/scorer_utils.py +138 -0
- mlflow/genai/scorers/validation.py +165 -0
- mlflow/genai/utils/data_validation.py +146 -0
- mlflow/genai/utils/enum_utils.py +23 -0
- mlflow/genai/utils/trace_utils.py +211 -0
- mlflow/groq/__init__.py +42 -0
- mlflow/groq/_groq_autolog.py +74 -0
- mlflow/johnsnowlabs/__init__.py +888 -0
- mlflow/langchain/__init__.py +24 -0
- mlflow/langchain/api_request_parallel_processor.py +330 -0
- mlflow/langchain/autolog.py +147 -0
- mlflow/langchain/chat_agent_langgraph.py +340 -0
- mlflow/langchain/constant.py +1 -0
- mlflow/langchain/constants.py +1 -0
- mlflow/langchain/databricks_dependencies.py +444 -0
- mlflow/langchain/langchain_tracer.py +597 -0
- mlflow/langchain/model.py +919 -0
- mlflow/langchain/output_parsers.py +142 -0
- mlflow/langchain/retriever_chain.py +153 -0
- mlflow/langchain/runnables.py +527 -0
- mlflow/langchain/utils/chat.py +402 -0
- mlflow/langchain/utils/logging.py +671 -0
- mlflow/langchain/utils/serialization.py +36 -0
- mlflow/legacy_databricks_cli/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/provider.py +482 -0
- mlflow/litellm/__init__.py +175 -0
- mlflow/llama_index/__init__.py +22 -0
- mlflow/llama_index/autolog.py +55 -0
- mlflow/llama_index/chat.py +43 -0
- mlflow/llama_index/constant.py +1 -0
- mlflow/llama_index/model.py +577 -0
- mlflow/llama_index/pyfunc_wrapper.py +332 -0
- mlflow/llama_index/serialize_objects.py +188 -0
- mlflow/llama_index/tracer.py +561 -0
- mlflow/metrics/__init__.py +479 -0
- mlflow/metrics/base.py +39 -0
- mlflow/metrics/genai/__init__.py +25 -0
- mlflow/metrics/genai/base.py +101 -0
- mlflow/metrics/genai/genai_metric.py +771 -0
- mlflow/metrics/genai/metric_definitions.py +450 -0
- mlflow/metrics/genai/model_utils.py +371 -0
- mlflow/metrics/genai/prompt_template.py +68 -0
- mlflow/metrics/genai/prompts/__init__.py +0 -0
- mlflow/metrics/genai/prompts/v1.py +422 -0
- mlflow/metrics/genai/utils.py +6 -0
- mlflow/metrics/metric_definitions.py +619 -0
- mlflow/mismatch.py +34 -0
- mlflow/mistral/__init__.py +34 -0
- mlflow/mistral/autolog.py +71 -0
- mlflow/mistral/chat.py +135 -0
- mlflow/ml_package_versions.py +452 -0
- mlflow/models/__init__.py +97 -0
- mlflow/models/auth_policy.py +83 -0
- mlflow/models/cli.py +354 -0
- mlflow/models/container/__init__.py +294 -0
- mlflow/models/container/scoring_server/__init__.py +0 -0
- mlflow/models/container/scoring_server/nginx.conf +39 -0
- mlflow/models/dependencies_schemas.py +287 -0
- mlflow/models/display_utils.py +158 -0
- mlflow/models/docker_utils.py +211 -0
- mlflow/models/evaluation/__init__.py +23 -0
- mlflow/models/evaluation/_shap_patch.py +64 -0
- mlflow/models/evaluation/artifacts.py +194 -0
- mlflow/models/evaluation/base.py +1811 -0
- mlflow/models/evaluation/calibration_curve.py +109 -0
- mlflow/models/evaluation/default_evaluator.py +996 -0
- mlflow/models/evaluation/deprecated.py +23 -0
- mlflow/models/evaluation/evaluator_registry.py +80 -0
- mlflow/models/evaluation/evaluators/classifier.py +704 -0
- mlflow/models/evaluation/evaluators/default.py +233 -0
- mlflow/models/evaluation/evaluators/regressor.py +96 -0
- mlflow/models/evaluation/evaluators/shap.py +296 -0
- mlflow/models/evaluation/lift_curve.py +178 -0
- mlflow/models/evaluation/utils/metric.py +123 -0
- mlflow/models/evaluation/utils/trace.py +179 -0
- mlflow/models/evaluation/validation.py +434 -0
- mlflow/models/flavor_backend.py +93 -0
- mlflow/models/flavor_backend_registry.py +53 -0
- mlflow/models/model.py +1639 -0
- mlflow/models/model_config.py +150 -0
- mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
- mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
- mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
- mlflow/models/python_api.py +369 -0
- mlflow/models/rag_signatures.py +128 -0
- mlflow/models/resources.py +321 -0
- mlflow/models/signature.py +662 -0
- mlflow/models/utils.py +2054 -0
- mlflow/models/wheeled_model.py +280 -0
- mlflow/openai/__init__.py +57 -0
- mlflow/openai/_agent_tracer.py +364 -0
- mlflow/openai/api_request_parallel_processor.py +131 -0
- mlflow/openai/autolog.py +509 -0
- mlflow/openai/constant.py +1 -0
- mlflow/openai/model.py +824 -0
- mlflow/openai/utils/chat_schema.py +367 -0
- mlflow/optuna/__init__.py +3 -0
- mlflow/optuna/storage.py +646 -0
- mlflow/plugins/__init__.py +72 -0
- mlflow/plugins/base.py +358 -0
- mlflow/plugins/builtin/__init__.py +24 -0
- mlflow/plugins/builtin/pytorch_plugin.py +150 -0
- mlflow/plugins/builtin/sklearn_plugin.py +158 -0
- mlflow/plugins/builtin/transformers_plugin.py +187 -0
- mlflow/plugins/cli.py +321 -0
- mlflow/plugins/discovery.py +340 -0
- mlflow/plugins/manager.py +465 -0
- mlflow/plugins/registry.py +316 -0
- mlflow/plugins/templates/framework_plugin_template.py +329 -0
- mlflow/prompt/constants.py +20 -0
- mlflow/prompt/promptlab_model.py +197 -0
- mlflow/prompt/registry_utils.py +248 -0
- mlflow/promptflow/__init__.py +495 -0
- mlflow/protos/__init__.py +0 -0
- mlflow/protos/assessments_pb2.py +174 -0
- mlflow/protos/databricks_artifacts_pb2.py +489 -0
- mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
- mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
- mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
- mlflow/protos/databricks_pb2.py +267 -0
- mlflow/protos/databricks_trace_server_pb2.py +374 -0
- mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
- mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
- mlflow/protos/facet_feature_statistics_pb2.py +296 -0
- mlflow/protos/internal_pb2.py +77 -0
- mlflow/protos/mlflow_artifacts_pb2.py +336 -0
- mlflow/protos/model_registry_pb2.py +1073 -0
- mlflow/protos/scalapb/__init__.py +0 -0
- mlflow/protos/scalapb/scalapb_pb2.py +104 -0
- mlflow/protos/service_pb2.py +2600 -0
- mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
- mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
- mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
- mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
- mlflow/py.typed +0 -0
- mlflow/pydantic_ai/__init__.py +57 -0
- mlflow/pydantic_ai/autolog.py +173 -0
- mlflow/pyfunc/__init__.py +3844 -0
- mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
- mlflow/pyfunc/backend.py +523 -0
- mlflow/pyfunc/context.py +78 -0
- mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
- mlflow/pyfunc/loaders/__init__.py +7 -0
- mlflow/pyfunc/loaders/chat_agent.py +117 -0
- mlflow/pyfunc/loaders/chat_model.py +125 -0
- mlflow/pyfunc/loaders/code_model.py +31 -0
- mlflow/pyfunc/loaders/responses_agent.py +112 -0
- mlflow/pyfunc/mlserver.py +46 -0
- mlflow/pyfunc/model.py +1473 -0
- mlflow/pyfunc/scoring_server/__init__.py +604 -0
- mlflow/pyfunc/scoring_server/app.py +7 -0
- mlflow/pyfunc/scoring_server/client.py +146 -0
- mlflow/pyfunc/spark_model_cache.py +48 -0
- mlflow/pyfunc/stdin_server.py +44 -0
- mlflow/pyfunc/utils/__init__.py +3 -0
- mlflow/pyfunc/utils/data_validation.py +224 -0
- mlflow/pyfunc/utils/environment.py +22 -0
- mlflow/pyfunc/utils/input_converter.py +47 -0
- mlflow/pyfunc/utils/serving_data_parser.py +11 -0
- mlflow/pytorch/__init__.py +1171 -0
- mlflow/pytorch/_lightning_autolog.py +580 -0
- mlflow/pytorch/_pytorch_autolog.py +50 -0
- mlflow/pytorch/pickle_module.py +35 -0
- mlflow/rfunc/__init__.py +42 -0
- mlflow/rfunc/backend.py +134 -0
- mlflow/runs.py +89 -0
- mlflow/server/__init__.py +302 -0
- mlflow/server/auth/__init__.py +1224 -0
- mlflow/server/auth/__main__.py +4 -0
- mlflow/server/auth/basic_auth.ini +6 -0
- mlflow/server/auth/cli.py +11 -0
- mlflow/server/auth/client.py +537 -0
- mlflow/server/auth/config.py +34 -0
- mlflow/server/auth/db/__init__.py +0 -0
- mlflow/server/auth/db/cli.py +18 -0
- mlflow/server/auth/db/migrations/__init__.py +0 -0
- mlflow/server/auth/db/migrations/alembic.ini +110 -0
- mlflow/server/auth/db/migrations/env.py +76 -0
- mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
- mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
- mlflow/server/auth/db/models.py +67 -0
- mlflow/server/auth/db/utils.py +37 -0
- mlflow/server/auth/entities.py +165 -0
- mlflow/server/auth/logo.py +14 -0
- mlflow/server/auth/permissions.py +65 -0
- mlflow/server/auth/routes.py +18 -0
- mlflow/server/auth/sqlalchemy_store.py +263 -0
- mlflow/server/graphql/__init__.py +0 -0
- mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
- mlflow/server/graphql/graphql_custom_scalars.py +24 -0
- mlflow/server/graphql/graphql_errors.py +15 -0
- mlflow/server/graphql/graphql_no_batching.py +89 -0
- mlflow/server/graphql/graphql_schema_extensions.py +74 -0
- mlflow/server/handlers.py +3217 -0
- mlflow/server/prometheus_exporter.py +17 -0
- mlflow/server/validation.py +30 -0
- mlflow/shap/__init__.py +691 -0
- mlflow/sklearn/__init__.py +1994 -0
- mlflow/sklearn/utils.py +1041 -0
- mlflow/smolagents/__init__.py +66 -0
- mlflow/smolagents/autolog.py +139 -0
- mlflow/smolagents/chat.py +29 -0
- mlflow/store/__init__.py +10 -0
- mlflow/store/_unity_catalog/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/constants.py +2 -0
- mlflow/store/_unity_catalog/registry/__init__.py +6 -0
- mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
- mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
- mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
- mlflow/store/_unity_catalog/registry/utils.py +121 -0
- mlflow/store/artifact/__init__.py +0 -0
- mlflow/store/artifact/artifact_repo.py +472 -0
- mlflow/store/artifact/artifact_repository_registry.py +154 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
- mlflow/store/artifact/cli.py +141 -0
- mlflow/store/artifact/cloud_artifact_repo.py +332 -0
- mlflow/store/artifact/databricks_artifact_repo.py +729 -0
- mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
- mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
- mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
- mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
- mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
- mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
- mlflow/store/artifact/ftp_artifact_repo.py +132 -0
- mlflow/store/artifact/gcs_artifact_repo.py +296 -0
- mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
- mlflow/store/artifact/http_artifact_repo.py +218 -0
- mlflow/store/artifact/local_artifact_repo.py +142 -0
- mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
- mlflow/store/artifact/models_artifact_repo.py +259 -0
- mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
- mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
- mlflow/store/artifact/r2_artifact_repo.py +70 -0
- mlflow/store/artifact/runs_artifact_repo.py +265 -0
- mlflow/store/artifact/s3_artifact_repo.py +330 -0
- mlflow/store/artifact/sftp_artifact_repo.py +141 -0
- mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
- mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
- mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
- mlflow/store/artifact/utils/__init__.py +0 -0
- mlflow/store/artifact/utils/models.py +148 -0
- mlflow/store/db/__init__.py +0 -0
- mlflow/store/db/base_sql_model.py +3 -0
- mlflow/store/db/db_types.py +10 -0
- mlflow/store/db/utils.py +314 -0
- mlflow/store/db_migrations/__init__.py +0 -0
- mlflow/store/db_migrations/alembic.ini +74 -0
- mlflow/store/db_migrations/env.py +84 -0
- mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
- mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
- mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
- mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
- mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
- mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
- mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
- mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
- mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
- mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
- mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
- mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
- mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
- mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
- mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
- mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
- mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
- mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
- mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
- mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
- mlflow/store/db_migrations/versions/__init__.py +0 -0
- mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
- mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
- mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
- mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
- mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
- mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
- mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
- mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
- mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
- mlflow/store/entities/__init__.py +3 -0
- mlflow/store/entities/paged_list.py +18 -0
- mlflow/store/model_registry/__init__.py +10 -0
- mlflow/store/model_registry/abstract_store.py +1081 -0
- mlflow/store/model_registry/base_rest_store.py +44 -0
- mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
- mlflow/store/model_registry/dbmodels/__init__.py +0 -0
- mlflow/store/model_registry/dbmodels/models.py +206 -0
- mlflow/store/model_registry/file_store.py +1091 -0
- mlflow/store/model_registry/rest_store.py +481 -0
- mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
- mlflow/store/tracking/__init__.py +23 -0
- mlflow/store/tracking/abstract_store.py +816 -0
- mlflow/store/tracking/dbmodels/__init__.py +0 -0
- mlflow/store/tracking/dbmodels/initial_models.py +243 -0
- mlflow/store/tracking/dbmodels/models.py +1073 -0
- mlflow/store/tracking/file_store.py +2438 -0
- mlflow/store/tracking/postgres_managed_identity.py +146 -0
- mlflow/store/tracking/rest_store.py +1131 -0
- mlflow/store/tracking/sqlalchemy_store.py +2785 -0
- mlflow/system_metrics/__init__.py +61 -0
- mlflow/system_metrics/metrics/__init__.py +0 -0
- mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
- mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
- mlflow/system_metrics/metrics/disk_monitor.py +21 -0
- mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
- mlflow/system_metrics/metrics/network_monitor.py +34 -0
- mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
- mlflow/system_metrics/system_metrics_monitor.py +198 -0
- mlflow/tracing/__init__.py +16 -0
- mlflow/tracing/assessment.py +356 -0
- mlflow/tracing/client.py +531 -0
- mlflow/tracing/config.py +125 -0
- mlflow/tracing/constant.py +105 -0
- mlflow/tracing/destination.py +81 -0
- mlflow/tracing/display/__init__.py +40 -0
- mlflow/tracing/display/display_handler.py +196 -0
- mlflow/tracing/export/async_export_queue.py +186 -0
- mlflow/tracing/export/inference_table.py +138 -0
- mlflow/tracing/export/mlflow_v3.py +137 -0
- mlflow/tracing/export/utils.py +70 -0
- mlflow/tracing/fluent.py +1417 -0
- mlflow/tracing/processor/base_mlflow.py +199 -0
- mlflow/tracing/processor/inference_table.py +175 -0
- mlflow/tracing/processor/mlflow_v3.py +47 -0
- mlflow/tracing/processor/otel.py +73 -0
- mlflow/tracing/provider.py +487 -0
- mlflow/tracing/trace_manager.py +200 -0
- mlflow/tracing/utils/__init__.py +616 -0
- mlflow/tracing/utils/artifact_utils.py +28 -0
- mlflow/tracing/utils/copy.py +55 -0
- mlflow/tracing/utils/environment.py +55 -0
- mlflow/tracing/utils/exception.py +21 -0
- mlflow/tracing/utils/once.py +35 -0
- mlflow/tracing/utils/otlp.py +63 -0
- mlflow/tracing/utils/processor.py +54 -0
- mlflow/tracing/utils/search.py +292 -0
- mlflow/tracing/utils/timeout.py +250 -0
- mlflow/tracing/utils/token.py +19 -0
- mlflow/tracing/utils/truncation.py +124 -0
- mlflow/tracing/utils/warning.py +76 -0
- mlflow/tracking/__init__.py +39 -0
- mlflow/tracking/_model_registry/__init__.py +1 -0
- mlflow/tracking/_model_registry/client.py +764 -0
- mlflow/tracking/_model_registry/fluent.py +853 -0
- mlflow/tracking/_model_registry/registry.py +67 -0
- mlflow/tracking/_model_registry/utils.py +251 -0
- mlflow/tracking/_tracking_service/__init__.py +0 -0
- mlflow/tracking/_tracking_service/client.py +883 -0
- mlflow/tracking/_tracking_service/registry.py +56 -0
- mlflow/tracking/_tracking_service/utils.py +275 -0
- mlflow/tracking/artifact_utils.py +179 -0
- mlflow/tracking/client.py +5900 -0
- mlflow/tracking/context/__init__.py +0 -0
- mlflow/tracking/context/abstract_context.py +35 -0
- mlflow/tracking/context/databricks_cluster_context.py +15 -0
- mlflow/tracking/context/databricks_command_context.py +15 -0
- mlflow/tracking/context/databricks_job_context.py +49 -0
- mlflow/tracking/context/databricks_notebook_context.py +41 -0
- mlflow/tracking/context/databricks_repo_context.py +43 -0
- mlflow/tracking/context/default_context.py +51 -0
- mlflow/tracking/context/git_context.py +32 -0
- mlflow/tracking/context/registry.py +98 -0
- mlflow/tracking/context/system_environment_context.py +15 -0
- mlflow/tracking/default_experiment/__init__.py +1 -0
- mlflow/tracking/default_experiment/abstract_context.py +43 -0
- mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
- mlflow/tracking/default_experiment/registry.py +75 -0
- mlflow/tracking/fluent.py +3595 -0
- mlflow/tracking/metric_value_conversion_utils.py +93 -0
- mlflow/tracking/multimedia.py +206 -0
- mlflow/tracking/registry.py +86 -0
- mlflow/tracking/request_auth/__init__.py +0 -0
- mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
- mlflow/tracking/request_auth/registry.py +60 -0
- mlflow/tracking/request_header/__init__.py +0 -0
- mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
- mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
- mlflow/tracking/request_header/default_request_header_provider.py +17 -0
- mlflow/tracking/request_header/registry.py +79 -0
- mlflow/transformers/__init__.py +2982 -0
- mlflow/transformers/flavor_config.py +258 -0
- mlflow/transformers/hub_utils.py +83 -0
- mlflow/transformers/llm_inference_utils.py +468 -0
- mlflow/transformers/model_io.py +301 -0
- mlflow/transformers/peft.py +51 -0
- mlflow/transformers/signature.py +183 -0
- mlflow/transformers/torch_utils.py +55 -0
- mlflow/types/__init__.py +21 -0
- mlflow/types/agent.py +270 -0
- mlflow/types/chat.py +240 -0
- mlflow/types/llm.py +935 -0
- mlflow/types/responses.py +139 -0
- mlflow/types/responses_helpers.py +416 -0
- mlflow/types/schema.py +1505 -0
- mlflow/types/type_hints.py +647 -0
- mlflow/types/utils.py +753 -0
- mlflow/utils/__init__.py +283 -0
- mlflow/utils/_capture_modules.py +256 -0
- mlflow/utils/_capture_transformers_modules.py +75 -0
- mlflow/utils/_spark_utils.py +201 -0
- mlflow/utils/_unity_catalog_oss_utils.py +97 -0
- mlflow/utils/_unity_catalog_utils.py +479 -0
- mlflow/utils/annotations.py +218 -0
- mlflow/utils/arguments_utils.py +16 -0
- mlflow/utils/async_logging/__init__.py +1 -0
- mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
- mlflow/utils/async_logging/async_logging_queue.py +366 -0
- mlflow/utils/async_logging/run_artifact.py +38 -0
- mlflow/utils/async_logging/run_batch.py +58 -0
- mlflow/utils/async_logging/run_operations.py +49 -0
- mlflow/utils/autologging_utils/__init__.py +737 -0
- mlflow/utils/autologging_utils/client.py +432 -0
- mlflow/utils/autologging_utils/config.py +33 -0
- mlflow/utils/autologging_utils/events.py +294 -0
- mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
- mlflow/utils/autologging_utils/metrics_queue.py +71 -0
- mlflow/utils/autologging_utils/safety.py +1104 -0
- mlflow/utils/autologging_utils/versioning.py +95 -0
- mlflow/utils/checkpoint_utils.py +206 -0
- mlflow/utils/class_utils.py +6 -0
- mlflow/utils/cli_args.py +257 -0
- mlflow/utils/conda.py +354 -0
- mlflow/utils/credentials.py +231 -0
- mlflow/utils/data_utils.py +17 -0
- mlflow/utils/databricks_utils.py +1436 -0
- mlflow/utils/docstring_utils.py +477 -0
- mlflow/utils/doctor.py +133 -0
- mlflow/utils/download_cloud_file_chunk.py +43 -0
- mlflow/utils/env_manager.py +16 -0
- mlflow/utils/env_pack.py +131 -0
- mlflow/utils/environment.py +1009 -0
- mlflow/utils/exception_utils.py +14 -0
- mlflow/utils/file_utils.py +978 -0
- mlflow/utils/git_utils.py +77 -0
- mlflow/utils/gorilla.py +797 -0
- mlflow/utils/import_hooks/__init__.py +363 -0
- mlflow/utils/lazy_load.py +51 -0
- mlflow/utils/logging_utils.py +168 -0
- mlflow/utils/mime_type_utils.py +58 -0
- mlflow/utils/mlflow_tags.py +103 -0
- mlflow/utils/model_utils.py +486 -0
- mlflow/utils/name_utils.py +346 -0
- mlflow/utils/nfs_on_spark.py +62 -0
- mlflow/utils/openai_utils.py +164 -0
- mlflow/utils/os.py +12 -0
- mlflow/utils/oss_registry_utils.py +29 -0
- mlflow/utils/plugins.py +17 -0
- mlflow/utils/process.py +182 -0
- mlflow/utils/promptlab_utils.py +146 -0
- mlflow/utils/proto_json_utils.py +743 -0
- mlflow/utils/pydantic_utils.py +54 -0
- mlflow/utils/request_utils.py +279 -0
- mlflow/utils/requirements_utils.py +704 -0
- mlflow/utils/rest_utils.py +673 -0
- mlflow/utils/search_logged_model_utils.py +127 -0
- mlflow/utils/search_utils.py +2111 -0
- mlflow/utils/secure_loading.py +221 -0
- mlflow/utils/security_validation.py +384 -0
- mlflow/utils/server_cli_utils.py +61 -0
- mlflow/utils/spark_utils.py +15 -0
- mlflow/utils/string_utils.py +138 -0
- mlflow/utils/thread_utils.py +63 -0
- mlflow/utils/time.py +54 -0
- mlflow/utils/timeout.py +42 -0
- mlflow/utils/uri.py +572 -0
- mlflow/utils/validation.py +662 -0
- mlflow/utils/virtualenv.py +458 -0
- mlflow/utils/warnings_utils.py +25 -0
- mlflow/utils/yaml_utils.py +179 -0
- mlflow/version.py +24 -0
mlflow/server/handlers.py
@@ -0,0 +1,3217 @@
# Define all the service endpoint handlers here.
import bisect
import io
import json
import logging
import os
import pathlib
import posixpath
import re
import tempfile
import time
import urllib
from functools import wraps
from typing import Optional

import requests
from flask import Response, current_app, jsonify, request, send_file
from google.protobuf import descriptor
from google.protobuf.json_format import ParseError

from mlflow.entities import (
    DatasetInput,
    ExperimentTag,
    FileInfo,
    Metric,
    Param,
    RunTag,
    ViewType,
)
from mlflow.entities.logged_model import LoggedModel
from mlflow.entities.logged_model_input import LoggedModelInput
from mlflow.entities.logged_model_output import LoggedModelOutput
from mlflow.entities.logged_model_parameter import LoggedModelParameter
from mlflow.entities.logged_model_status import LoggedModelStatus
from mlflow.entities.logged_model_tag import LoggedModelTag
from mlflow.entities.model_registry import ModelVersionTag, RegisteredModelTag
from mlflow.entities.model_registry.prompt_version import IS_PROMPT_TAG_KEY
from mlflow.entities.multipart_upload import MultipartUploadPart
from mlflow.entities.trace_info import TraceInfo
from mlflow.entities.trace_info_v2 import TraceInfoV2
from mlflow.entities.trace_status import TraceStatus
from mlflow.environment_variables import (
    MLFLOW_CREATE_MODEL_VERSION_SOURCE_VALIDATION_REGEX,
    MLFLOW_DEPLOYMENTS_TARGET,
)
from mlflow.exceptions import MlflowException, _UnsupportedMultipartUploadException
from mlflow.models import Model
from mlflow.protos import databricks_pb2
from mlflow.protos.databricks_pb2 import (
    BAD_REQUEST,
    INVALID_PARAMETER_VALUE,
    RESOURCE_DOES_NOT_EXIST,
)
from mlflow.protos.mlflow_artifacts_pb2 import (
    AbortMultipartUpload,
    CompleteMultipartUpload,
    CreateMultipartUpload,
    DeleteArtifact,
    DownloadArtifact,
    MlflowArtifactsService,
    UploadArtifact,
)
from mlflow.protos.mlflow_artifacts_pb2 import (
    ListArtifacts as ListArtifactsMlflowArtifacts,
)
from mlflow.protos.model_registry_pb2 import (
    CreateModelVersion,
    CreateRegisteredModel,
    DeleteModelVersion,
    DeleteModelVersionTag,
    DeleteRegisteredModel,
    DeleteRegisteredModelAlias,
    DeleteRegisteredModelTag,
    GetLatestVersions,
    GetModelVersion,
    GetModelVersionByAlias,
    GetModelVersionDownloadUri,
    GetRegisteredModel,
    ModelRegistryService,
    RenameRegisteredModel,
    SearchModelVersions,
    SearchRegisteredModels,
    SetModelVersionTag,
    SetRegisteredModelAlias,
    SetRegisteredModelTag,
    TransitionModelVersionStage,
    UpdateModelVersion,
    UpdateRegisteredModel,
)
from mlflow.protos.service_pb2 import (
    CreateExperiment,
    CreateLoggedModel,
    CreateRun,
    DeleteExperiment,
    DeleteLoggedModel,
    DeleteLoggedModelTag,
    DeleteRun,
    DeleteTag,
    DeleteTraces,
    DeleteTraceTag,
    EndTrace,
    FinalizeLoggedModel,
    GetExperiment,
    GetExperimentByName,
    GetLoggedModel,
    GetMetricHistory,
    GetMetricHistoryBulkInterval,
    GetRun,
    GetTraceInfo,
    GetTraceInfoV3,
    ListArtifacts,
    ListLoggedModelArtifacts,
    LogBatch,
    LogInputs,
    LogLoggedModelParamsRequest,
    LogMetric,
    LogModel,
    LogOutputs,
    LogParam,
    MlflowService,
    RestoreExperiment,
    RestoreRun,
    SearchDatasets,
    SearchExperiments,
    SearchLoggedModels,
    SearchRuns,
    SearchTraces,
    SearchTracesV3,
    SetExperimentTag,
    SetLoggedModelTags,
    SetTag,
    SetTraceTag,
    StartTrace,
    StartTraceV3,
    UpdateExperiment,
    UpdateRun,
)
from mlflow.protos.service_pb2 import Trace as ProtoTrace
from mlflow.server.validation import _validate_content_type
from mlflow.store.artifact.artifact_repo import MultipartUploadMixin
from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
from mlflow.store.db.db_types import DATABASE_ENGINES
from mlflow.tracing.utils.artifact_utils import (
    TRACE_DATA_FILE_NAME,
    get_artifact_uri_for_trace,
)
from mlflow.tracking._model_registry import utils as registry_utils
from mlflow.tracking._model_registry.registry import ModelRegistryStoreRegistry
from mlflow.tracking._tracking_service import utils
from mlflow.tracking._tracking_service.registry import TrackingStoreRegistry
from mlflow.tracking.registry import UnsupportedModelRegistryStoreURIException
from mlflow.utils.file_utils import local_file_uri_to_path
from mlflow.utils.mime_type_utils import _guess_mime_type
from mlflow.utils.promptlab_utils import _create_promptlab_run_impl
from mlflow.utils.proto_json_utils import message_to_json, parse_dict
from mlflow.utils.security_validation import InputValidator, SecurityValidationError
from mlflow.utils.string_utils import is_string_type
from mlflow.utils.uri import is_local_uri, validate_path_is_safe, validate_query_string
from mlflow.utils.validation import (
    _validate_batch_log_api_req,
    invalid_value,
    missing_value,
)

_logger = logging.getLogger(__name__)
_tracking_store = None
_model_registry_store = None
_artifact_repo = None
STATIC_PREFIX_ENV_VAR = "_MLFLOW_STATIC_PREFIX"
MAX_RUNS_GET_METRIC_HISTORY_BULK = 100
MAX_RESULTS_PER_RUN = 2500
MAX_RESULTS_GET_METRIC_HISTORY = 25000


class TrackingStoreRegistryWrapper(TrackingStoreRegistry):
    def __init__(self):
        super().__init__()
        self.register("", self._get_file_store)
        self.register("file", self._get_file_store)
        for scheme in DATABASE_ENGINES:
            self.register(scheme, self._get_sqlalchemy_store)
        self.register_entrypoints()

    @classmethod
    def _get_file_store(cls, store_uri, artifact_uri):
        from mlflow.store.tracking.file_store import FileStore

        return FileStore(store_uri, artifact_uri)

    @classmethod
    def _get_sqlalchemy_store(cls, store_uri, artifact_uri):
        from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore

        return SqlAlchemyStore(store_uri, artifact_uri)


class ModelRegistryStoreRegistryWrapper(ModelRegistryStoreRegistry):
    def __init__(self):
        super().__init__()
        self.register("", self._get_file_store)
        self.register("file", self._get_file_store)
        for scheme in DATABASE_ENGINES:
            self.register(scheme, self._get_sqlalchemy_store)
        self.register_entrypoints()

    @classmethod
    def _get_file_store(cls, store_uri):
        from mlflow.store.model_registry.file_store import FileStore

        return FileStore(store_uri)

    @classmethod
    def _get_sqlalchemy_store(cls, store_uri):
        from mlflow.store.model_registry.sqlalchemy_store import SqlAlchemyStore

        return SqlAlchemyStore(store_uri)


_tracking_store_registry = TrackingStoreRegistryWrapper()
_model_registry_store_registry = ModelRegistryStoreRegistryWrapper()


def _get_artifact_repo_mlflow_artifacts():
    """
    Get an artifact repository specified by ``--artifacts-destination`` option for ``mlflow server``
    command.
    """
    from mlflow.server import ARTIFACTS_DESTINATION_ENV_VAR

    global _artifact_repo
    if _artifact_repo is None:
        _artifact_repo = get_artifact_repository(os.environ[ARTIFACTS_DESTINATION_ENV_VAR])
    return _artifact_repo


def _get_trace_artifact_repo(trace_info: TraceInfo):
    """
    Resolve the artifact repository for fetching data for the given trace.

    Args:
        trace_info: The trace info object containing metadata about the trace.
    """
    artifact_uri = get_artifact_uri_for_trace(trace_info)

    if _is_servable_proxied_run_artifact_root(artifact_uri):
        # If the artifact location is a proxied run artifact root (e.g. mlflow-artifacts://...),
        # we need to resolve it to the actual artifact location.
        from mlflow.server import ARTIFACTS_DESTINATION_ENV_VAR

        path = _get_proxied_run_artifact_destination_path(artifact_uri)
        if not path:
            raise MlflowException(
                f"Failed to resolve the proxied run artifact URI: {artifact_uri}. "
                "Trace artifact URI must contain subpath to the trace data directory.",
                error_code=BAD_REQUEST,
            )
        root = os.environ[ARTIFACTS_DESTINATION_ENV_VAR]
        artifact_uri = posixpath.join(root, path)

        # We don't set it to global var unlike run artifact, because the artifact repo has
        # to be created with full trace artifact URI including request_id.
        # e.g. s3://<experiment_id>/traces/<request_id>
        artifact_repo = get_artifact_repository(artifact_uri)
    else:
        artifact_repo = get_artifact_repository(artifact_uri)
    return artifact_repo


def _is_serving_proxied_artifacts():
    """
    Returns:
        True if the MLflow server is serving proxied artifacts (i.e. acting as a proxy for
        artifact upload / download / list operations), as would be enabled by specifying the
        --serve-artifacts configuration option. False otherwise.
    """
    from mlflow.server import SERVE_ARTIFACTS_ENV_VAR

    return os.environ.get(SERVE_ARTIFACTS_ENV_VAR, "false") == "true"


def _is_servable_proxied_run_artifact_root(run_artifact_root):
    """
    Determines whether or not the following are true:

    - The specified Run artifact root is a proxied artifact root (i.e. an artifact root with scheme
      ``http``, ``https``, or ``mlflow-artifacts``).

    - The MLflow server is capable of resolving and accessing the underlying storage location
      corresponding to the proxied artifact root, allowing it to fulfill artifact list and
      download requests by using this storage location directly.

    Args:
        run_artifact_root: The Run artifact root location (URI).

    Returns:
        True if the specified Run artifact root refers to proxied artifacts that can be
        served by this MLflow server (i.e. the server has access to the destination and
        can respond to list and download requests for the artifact). False otherwise.
    """
    parsed_run_artifact_root = urllib.parse.urlparse(run_artifact_root)
    # NB: If the run artifact root is a proxied artifact root (has scheme `http`, `https`, or
    # `mlflow-artifacts`) *and* the MLflow server is configured to serve artifacts, the MLflow
    # server always assumes that it has access to the underlying storage location for the proxied
    # artifacts. This may not always be accurate. For example:
    #
    # An organization may initially use the MLflow server to serve Tracking API requests and proxy
    # access to artifacts stored in Location A (via `mlflow server --serve-artifacts`). Then, for
    # scalability and / or security purposes, the organization may decide to store artifacts in a
    # new location B and set up a separate server (e.g. `mlflow server --artifacts-only`) to proxy
    # access to artifacts stored in Location B.
    #
    # In this scenario, requests for artifacts stored in Location B that are sent to the original
    # MLflow server will fail if the original MLflow server does not have access to Location B
    # because it will assume that it can serve all proxied artifacts regardless of the underlying
    # location. Such failures can be remediated by granting the original MLflow server access to
    # Location B.
    return (
        parsed_run_artifact_root.scheme in ["http", "https", "mlflow-artifacts"]
        and _is_serving_proxied_artifacts()
    )


def _get_proxied_run_artifact_destination_path(proxied_artifact_root, relative_path=None):
    """
    Resolves the specified proxied artifact location within a Run to a concrete storage location.

    Args:
        proxied_artifact_root: The Run artifact root location (URI) with scheme ``http``,
            ``https``, or `mlflow-artifacts` that can be resolved by the MLflow server to a
            concrete storage location.
        relative_path: The relative path of the destination within the specified
            ``proxied_artifact_root``. If ``None``, the destination is assumed to be
            the resolved ``proxied_artifact_root``.

    Returns:
        The storage location of the specified artifact.
    """
    parsed_proxied_artifact_root = urllib.parse.urlparse(proxied_artifact_root)
    assert parsed_proxied_artifact_root.scheme in ["http", "https", "mlflow-artifacts"]

    if parsed_proxied_artifact_root.scheme == "mlflow-artifacts":
        # If the proxied artifact root is an `mlflow-artifacts` URI, the run artifact root path is
        # simply the path component of the URI, since the fully-qualified format of an
        # `mlflow-artifacts` URI is `mlflow-artifacts://<netloc>/path/to/artifact`
        proxied_run_artifact_root_path = parsed_proxied_artifact_root.path.lstrip("/")
    else:
        # In this case, the proxied artifact root is an HTTP(S) URL referring to an mlflow-artifacts
        # API route that can be used to download the artifact. These routes are always anchored at
        # `/api/2.0/mlflow-artifacts/artifacts`. Accordingly, we split the path on this route anchor
        # and interpret the rest of the path (everything after the route anchor) as the run artifact
        # root path
        mlflow_artifacts_http_route_anchor = "/api/2.0/mlflow-artifacts/artifacts/"
        assert mlflow_artifacts_http_route_anchor in parsed_proxied_artifact_root.path

        proxied_run_artifact_root_path = parsed_proxied_artifact_root.path.split(
            mlflow_artifacts_http_route_anchor
        )[1].lstrip("/")

    return (
        posixpath.join(proxied_run_artifact_root_path, relative_path)
        if relative_path is not None
        else proxied_run_artifact_root_path
    )

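A quick illustrative sketch (editorial, not part of the packaged handlers.py): how the helper above maps the two proxied-artifact URI forms onto storage-relative paths. It assumes the wheel is installed so that `mlflow.server.handlers` imports cleanly; the experiment/run IDs are placeholders.

```python
# Illustration only -- exercises the helper defined above; not part of handlers.py.
from mlflow.server.handlers import _get_proxied_run_artifact_destination_path

# `mlflow-artifacts` scheme: the storage-relative path is the URI's path component.
assert _get_proxied_run_artifact_destination_path(
    "mlflow-artifacts://host/1/abc123/artifacts", relative_path="model/MLmodel"
) == "1/abc123/artifacts/model/MLmodel"

# HTTP(S) scheme: everything after the fixed route anchor is the storage-relative path.
assert _get_proxied_run_artifact_destination_path(
    "http://localhost:5000/api/2.0/mlflow-artifacts/artifacts/1/abc123/artifacts"
) == "1/abc123/artifacts"
```
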
def _get_tracking_store(backend_store_uri=None, default_artifact_root=None):
    from mlflow.server import ARTIFACT_ROOT_ENV_VAR, BACKEND_STORE_URI_ENV_VAR

    global _tracking_store
    if _tracking_store is None:
        store_uri = backend_store_uri or os.environ.get(BACKEND_STORE_URI_ENV_VAR, None)
        artifact_root = default_artifact_root or os.environ.get(ARTIFACT_ROOT_ENV_VAR, None)
        _tracking_store = _tracking_store_registry.get_store(store_uri, artifact_root)
        utils.set_tracking_uri(store_uri)
    return _tracking_store


def _get_model_registry_store(registry_store_uri=None):
    from mlflow.server import BACKEND_STORE_URI_ENV_VAR, REGISTRY_STORE_URI_ENV_VAR

    global _model_registry_store
    if _model_registry_store is None:
        store_uri = (
            registry_store_uri
            or os.environ.get(REGISTRY_STORE_URI_ENV_VAR, None)
            or os.environ.get(BACKEND_STORE_URI_ENV_VAR, None)
        )
        _model_registry_store = _model_registry_store_registry.get_store(store_uri)
        registry_utils.set_registry_uri(store_uri)
    return _model_registry_store


def initialize_backend_stores(
    backend_store_uri=None, registry_store_uri=None, default_artifact_root=None
):
    _get_tracking_store(backend_store_uri, default_artifact_root)
    try:
        _get_model_registry_store(registry_store_uri)
    except UnsupportedModelRegistryStoreURIException:
        pass

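Another short sketch (editorial, not packaged code): the getters above cache a single store instance per process, so the backend URI is only consulted on the first call. The local paths below are placeholders.

```python
# Illustration only -- not part of handlers.py.
from mlflow.server import handlers

# Empty / "file" schemes resolve to FileStore; database schemes resolve to SqlAlchemyStore.
handlers.initialize_backend_stores(
    backend_store_uri="./mlruns_demo",
    default_artifact_root="./mlruns_demo",
)

store_a = handlers._get_tracking_store()
store_b = handlers._get_tracking_store("sqlite:///ignored.db")  # ignored: store already cached
assert store_a is store_b
```
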
def _assert_string(x):
    assert isinstance(x, str)


def _assert_intlike(x):
    try:
        x = int(x)
    except ValueError:
        pass

    assert isinstance(x, int)


def _assert_bool(x):
    assert isinstance(x, bool)


def _assert_floatlike(x):
    try:
        x = float(x)
    except ValueError:
        pass

    assert isinstance(x, float)


def _assert_array(x):
    assert isinstance(x, list)


def _assert_map_key_present(x):
    _assert_array(x)
    for entry in x:
        _assert_required(entry.get("key"))


def _assert_required(x, path=None):
    if path is None:
        assert x is not None
        # When parsing JSON payloads via proto, absent string fields
        # are expressed as empty strings
        assert x != ""
    else:
        assert x is not None, missing_value(path)
        assert x != "", missing_value(path)


def _assert_less_than_or_equal(x, max_value, message=None):
    if x > max_value:
        raise AssertionError(message) if message else AssertionError()


def _assert_intlike_within_range(x, min_value, max_value, message=None):
    if not min_value <= x <= max_value:
        raise AssertionError(message) if message else AssertionError()


def _assert_item_type_string(x):
    assert all(isinstance(item, str) for item in x)


_TYPE_VALIDATORS = {
    _assert_intlike,
    _assert_string,
    _assert_bool,
    _assert_floatlike,
    _assert_array,
    _assert_item_type_string,
}


def _validate_param_against_schema(schema, param, value, proto_parsing_succeeded=False):
    """
    Attempts to validate a single parameter against a specified schema. Examples of the elements of
    the schema are type assertions and checks for required parameters. Returns None on validation
    success. Otherwise, raises an MlflowException if an assertion fails. This method is intended
    to be called for side effects.

    Args:
        schema: A list of functions to validate the parameter against.
        param: The string name of the parameter being validated.
        value: The corresponding value of the `param` being validated.
        proto_parsing_succeeded: A boolean value indicating whether proto parsing succeeded.
            If the proto was successfully parsed, we assume all of the types of the parameters in
            the request body were correctly specified, and thus we skip validating types. If proto
            parsing failed, then we validate types in addition to the rest of the schema. For
            details, see https://github.com/mlflow/mlflow/pull/5458#issuecomment-1080880870.
    """

    for f in schema:
        if f in _TYPE_VALIDATORS and proto_parsing_succeeded:
            continue

        try:
            f(value)
        except AssertionError as e:
            if e.args:
                message = e.args[0]
            elif f == _assert_required:
                message = f"Missing value for required parameter '{param}'."
            else:
                message = invalid_value(
                    param, value, f" Hint: Value was of type '{type(value).__name__}'."
                )
            raise MlflowException(
                message=(
                    message + " See the API docs for more information about request parameters."
                ),
                error_code=INVALID_PARAMETER_VALUE,
            )

    return None

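An illustrative sketch (editorial, not packaged code) of the schema machinery above; the schema shape mirrors the one `_create_experiment` passes further down in this file, and it assumes the wheel is installed.

```python
# Illustration only -- not part of handlers.py.
from mlflow.exceptions import MlflowException
from mlflow.server.handlers import (
    _assert_required,
    _assert_string,
    _validate_param_against_schema,
)

schema = [_assert_required, _assert_string]  # e.g. the rules for `name` in CreateExperiment

_validate_param_against_schema(schema, "name", "my-experiment")  # passes silently

try:
    # Proto parsing turns an absent string field into "", which _assert_required rejects.
    _validate_param_against_schema(schema, "name", "")
except MlflowException as e:
    print(e.message)  # "Missing value for required parameter 'name'. See the API docs ..."
```
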
def _get_request_json(flask_request=request):
    _validate_content_type(flask_request, ["application/json"])
    return flask_request.get_json(force=True, silent=True)


def _get_request_message(request_message, flask_request=request, schema=None):
    if flask_request.method == "GET" and flask_request.args:
        # Convert atomic values of repeated fields to lists before calling protobuf deserialization.
        # Context: We parse the parameter string into a dictionary outside of protobuf since
        # protobuf does not know how to read the query parameters directly. The query parser above
        # has no type information and hence any parameter that occurs exactly once is parsed as an
        # atomic value. Since protobuf requires that the values of repeated fields are lists,
        # deserialization will fail unless we do the fix below.
        request_json = {}
        for field in request_message.DESCRIPTOR.fields:
            if field.name not in flask_request.args:
                continue

            if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
                request_json[field.name] = flask_request.args.getlist(field.name)
            else:
                request_json[field.name] = flask_request.args.get(field.name)
    else:
        request_json = _get_request_json(flask_request)

        # Older clients may post their JSON double-encoded as strings, so the get_json
        # above actually converts it to a string. Therefore, we check this condition
        # (which we can tell for sure because any proper request should be a dictionary),
        # and decode it a second time.
        if is_string_type(request_json):
            request_json = json.loads(request_json)

        # If request doesn't have json body then assume it's empty.
        if request_json is None:
            request_json = {}

    proto_parsing_succeeded = True
    try:
        parse_dict(request_json, request_message)
    except ParseError:
        proto_parsing_succeeded = False

    schema = schema or {}
    for schema_key, schema_validation_fns in schema.items():
        if schema_key in request_json or _assert_required in schema_validation_fns:
            value = request_json.get(schema_key)
            if schema_key == "run_id" and value is None and "run_uuid" in request_json:
                value = request_json.get("run_uuid")
            _validate_param_against_schema(
                schema=schema_validation_fns,
                param=schema_key,
                value=value,
                proto_parsing_succeeded=proto_parsing_succeeded,
            )

    return request_message

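A minimal sketch (editorial, not packaged code) of the Flask behaviour the GET branch above relies on: a repeated proto field sent as a query string arrives as several parameters with the same name, and only `getlist` preserves all of them. The URL path is a placeholder.

```python
# Illustration only -- not part of handlers.py.
from flask import Flask, request

app = Flask(__name__)

with app.test_request_context("/endpoint?run_ids=a&run_ids=b&max_results=10"):
    assert request.args.get("run_ids") == "a"              # atomic view: silently drops "b"
    assert request.args.getlist("run_ids") == ["a", "b"]   # list view used for repeated fields
    assert request.args.get("max_results") == "10"         # singular fields stay atomic strings
```
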
def _response_with_file_attachment_headers(file_path, response):
    mime_type = _guess_mime_type(file_path)
    filename = pathlib.Path(file_path).name
    response.mimetype = mime_type
    content_disposition_header_name = "Content-Disposition"
    if content_disposition_header_name not in response.headers:
        response.headers[content_disposition_header_name] = f"attachment; filename={filename}"
    response.headers["X-Content-Type-Options"] = "nosniff"
    response.headers["Content-Type"] = mime_type
    return response


def _send_artifact(artifact_repository, path):
    file_path = os.path.abspath(artifact_repository.download_artifacts(path))
    # Always send artifacts as attachments to prevent the browser from displaying them on our web
    # server's domain, which might enable XSS.
    mime_type = _guess_mime_type(file_path)
    file_sender_response = send_file(file_path, mimetype=mime_type, as_attachment=True)
    return _response_with_file_attachment_headers(file_path, file_sender_response)


def catch_mlflow_exception(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except MlflowException as e:
            response = Response(mimetype="application/json")
            response.set_data(e.serialize_as_json())
            response.status_code = e.get_http_status_code()
            return response

    return wrapper


def _disable_unless_serve_artifacts(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        if not _is_serving_proxied_artifacts():
            return Response(
                (
                    f"Endpoint: {request.url_rule} disabled due to the mlflow server running "
                    "with `--no-serve-artifacts`. To enable artifacts server functionality, "
                    "run `mlflow server` with `--serve-artifacts`"
                ),
                503,
            )
        return func(*args, **kwargs)

    return wrapper


def _disable_if_artifacts_only(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        from mlflow.server import ARTIFACTS_ONLY_ENV_VAR

        if os.environ.get(ARTIFACTS_ONLY_ENV_VAR):
            return Response(
                (
                    f"Endpoint: {request.url_rule} disabled due to the mlflow server running "
                    "in `--artifacts-only` mode. To enable tracking server functionality, run "
                    "`mlflow server` without `--artifacts-only`"
                ),
                503,
            )
        return func(*args, **kwargs)

    return wrapper

@catch_mlflow_exception
|
648
|
+
def get_artifact_handler():
|
649
|
+
run_id = request.args.get("run_id") or request.args.get("run_uuid")
|
650
|
+
path = request.args["path"]
|
651
|
+
path = validate_path_is_safe(path)
|
652
|
+
run = _get_tracking_store().get_run(run_id)
|
653
|
+
|
654
|
+
if _is_servable_proxied_run_artifact_root(run.info.artifact_uri):
|
655
|
+
artifact_repo = _get_artifact_repo_mlflow_artifacts()
|
656
|
+
artifact_path = _get_proxied_run_artifact_destination_path(
|
657
|
+
proxied_artifact_root=run.info.artifact_uri,
|
658
|
+
relative_path=path,
|
659
|
+
)
|
660
|
+
else:
|
661
|
+
artifact_repo = _get_artifact_repo(run)
|
662
|
+
artifact_path = path
|
663
|
+
|
664
|
+
return _send_artifact(artifact_repo, artifact_path)
|
665
|
+
|
666
|
+
|
667
|
+
def _not_implemented():
|
668
|
+
response = Response()
|
669
|
+
response.status_code = 404
|
670
|
+
return response
|
671
|
+
|
672
|
+
|
673
|
+
# Tracking Server APIs
|
674
|
+
|
675
|
+
|
676
|
+
@catch_mlflow_exception
|
677
|
+
@_disable_if_artifacts_only
|
678
|
+
def _create_experiment():
|
679
|
+
request_message = _get_request_message(
|
680
|
+
CreateExperiment(),
|
681
|
+
schema={
|
682
|
+
"name": [_assert_required, _assert_string],
|
683
|
+
"artifact_location": [_assert_string],
|
684
|
+
"tags": [_assert_array],
|
685
|
+
},
|
686
|
+
)
|
687
|
+
|
688
|
+
# Security validation for experiment name
|
689
|
+
try:
|
690
|
+
validated_name = InputValidator.validate_experiment_name(request_message.name)
|
691
|
+
except SecurityValidationError as e:
|
692
|
+
raise MlflowException(
|
693
|
+
f"Invalid experiment name: {e}",
|
694
|
+
error_code=INVALID_PARAMETER_VALUE,
|
695
|
+
)
|
696
|
+
|
697
|
+
# Security validation for tags
|
698
|
+
validated_tags = []
|
699
|
+
for tag in request_message.tags:
|
700
|
+
try:
|
701
|
+
validated_key = InputValidator.validate_tag_key(tag.key)
|
702
|
+
validated_value = InputValidator.validate_tag_value(tag.value)
|
703
|
+
validated_tags.append(ExperimentTag(validated_key, validated_value))
|
704
|
+
except SecurityValidationError as e:
|
705
|
+
raise MlflowException(
|
706
|
+
f"Invalid tag: {e}",
|
707
|
+
error_code=INVALID_PARAMETER_VALUE,
|
708
|
+
)
|
709
|
+
|
710
|
+
# Security validation for artifact location
|
711
|
+
if request_message.artifact_location:
|
712
|
+
try:
|
713
|
+
validated_artifact_location = InputValidator.validate_uri(request_message.artifact_location)
|
714
|
+
except SecurityValidationError as e:
|
715
|
+
raise MlflowException(
|
716
|
+
f"Invalid artifact location: {e}",
|
717
|
+
error_code=INVALID_PARAMETER_VALUE,
|
718
|
+
)
|
719
|
+
else:
|
720
|
+
validated_artifact_location = request_message.artifact_location
|
721
|
+
|
722
|
+
# Validate query string in artifact location to prevent attacks
|
723
|
+
parsed_artifact_location = urllib.parse.urlparse(validated_artifact_location)
|
724
|
+
if parsed_artifact_location.fragment or parsed_artifact_location.params:
|
725
|
+
raise MlflowException(
|
726
|
+
"'artifact_location' URL can't include fragments or params.",
|
727
|
+
error_code=INVALID_PARAMETER_VALUE,
|
728
|
+
)
|
729
|
+
validate_query_string(parsed_artifact_location.query)
|
730
|
+
experiment_id = _get_tracking_store().create_experiment(
|
731
|
+
validated_name, validated_artifact_location, validated_tags
|
732
|
+
)
|
733
|
+
response_message = CreateExperiment.Response()
|
734
|
+
response_message.experiment_id = experiment_id
|
735
|
+
response = Response(mimetype="application/json")
|
736
|
+
response.set_data(message_to_json(response_message))
|
737
|
+
return response
|
738
|
+
|
739
|
+
|
740
|
+
@catch_mlflow_exception
|
741
|
+
@_disable_if_artifacts_only
|
742
|
+
def _get_experiment():
|
743
|
+
request_message = _get_request_message(
|
744
|
+
GetExperiment(), schema={"experiment_id": [_assert_required, _assert_string]}
|
745
|
+
)
|
746
|
+
response_message = get_experiment_impl(request_message)
|
747
|
+
response = Response(mimetype="application/json")
|
748
|
+
response.set_data(message_to_json(response_message))
|
749
|
+
return response
|
750
|
+
|
751
|
+
|
752
|
+
def get_experiment_impl(request_message):
|
753
|
+
response_message = GetExperiment.Response()
|
754
|
+
experiment = _get_tracking_store().get_experiment(request_message.experiment_id).to_proto()
|
755
|
+
response_message.experiment.MergeFrom(experiment)
|
756
|
+
return response_message
|
757
|
+
|
758
|
+
|
759
|
+
@catch_mlflow_exception
|
760
|
+
@_disable_if_artifacts_only
|
761
|
+
def _get_experiment_by_name():
|
762
|
+
request_message = _get_request_message(
|
763
|
+
GetExperimentByName(),
|
764
|
+
schema={"experiment_name": [_assert_required, _assert_string]},
|
765
|
+
)
|
766
|
+
response_message = GetExperimentByName.Response()
|
767
|
+
store_exp = _get_tracking_store().get_experiment_by_name(request_message.experiment_name)
|
768
|
+
if store_exp is None:
|
769
|
+
raise MlflowException(
|
770
|
+
f"Could not find experiment with name '{request_message.experiment_name}'",
|
771
|
+
error_code=RESOURCE_DOES_NOT_EXIST,
|
772
|
+
)
|
773
|
+
experiment = store_exp.to_proto()
|
774
|
+
response_message.experiment.MergeFrom(experiment)
|
775
|
+
response = Response(mimetype="application/json")
|
776
|
+
response.set_data(message_to_json(response_message))
|
777
|
+
return response
|
778
|
+
|
779
|
+
|
780
|
+
@catch_mlflow_exception
|
781
|
+
@_disable_if_artifacts_only
|
782
|
+
def _delete_experiment():
|
783
|
+
request_message = _get_request_message(
|
784
|
+
DeleteExperiment(), schema={"experiment_id": [_assert_required, _assert_string]}
|
785
|
+
)
|
786
|
+
_get_tracking_store().delete_experiment(request_message.experiment_id)
|
787
|
+
response_message = DeleteExperiment.Response()
|
788
|
+
response = Response(mimetype="application/json")
|
789
|
+
response.set_data(message_to_json(response_message))
|
790
|
+
return response
|
791
|
+
|
792
|
+
|
793
|
+
@catch_mlflow_exception
|
794
|
+
@_disable_if_artifacts_only
|
795
|
+
def _restore_experiment():
|
796
|
+
request_message = _get_request_message(
|
797
|
+
RestoreExperiment(),
|
798
|
+
schema={"experiment_id": [_assert_required, _assert_string]},
|
799
|
+
)
|
800
|
+
_get_tracking_store().restore_experiment(request_message.experiment_id)
|
801
|
+
response_message = RestoreExperiment.Response()
|
802
|
+
response = Response(mimetype="application/json")
|
803
|
+
response.set_data(message_to_json(response_message))
|
804
|
+
return response
|
805
|
+
|
806
|
+
|
807
|
+
@catch_mlflow_exception
|
808
|
+
@_disable_if_artifacts_only
|
809
|
+
def _update_experiment():
|
810
|
+
request_message = _get_request_message(
|
811
|
+
UpdateExperiment(),
|
812
|
+
schema={
|
813
|
+
"experiment_id": [_assert_required, _assert_string],
|
814
|
+
"new_name": [_assert_string, _assert_required],
|
815
|
+
},
|
816
|
+
)
|
817
|
+
if request_message.new_name:
|
818
|
+
_get_tracking_store().rename_experiment(
|
819
|
+
request_message.experiment_id, request_message.new_name
|
820
|
+
)
|
821
|
+
response_message = UpdateExperiment.Response()
|
822
|
+
response = Response(mimetype="application/json")
|
823
|
+
response.set_data(message_to_json(response_message))
|
824
|
+
return response
|
825
|
+
|
826
|
+
|
827
|
+
@catch_mlflow_exception
|
828
|
+
@_disable_if_artifacts_only
|
829
|
+
def _create_run():
|
830
|
+
request_message = _get_request_message(
|
831
|
+
CreateRun(),
|
832
|
+
schema={
|
833
|
+
"experiment_id": [_assert_string],
|
834
|
+
"start_time": [_assert_intlike],
|
835
|
+
"run_name": [_assert_string],
|
836
|
+
},
|
837
|
+
)
|
838
|
+
|
839
|
+
tags = [RunTag(tag.key, tag.value) for tag in request_message.tags]
|
840
|
+
run = _get_tracking_store().create_run(
|
841
|
+
experiment_id=request_message.experiment_id,
|
842
|
+
user_id=request_message.user_id,
|
843
|
+
start_time=request_message.start_time,
|
844
|
+
tags=tags,
|
845
|
+
run_name=request_message.run_name,
|
846
|
+
)
|
847
|
+
|
848
|
+
response_message = CreateRun.Response()
|
849
|
+
response_message.run.MergeFrom(run.to_proto())
|
850
|
+
response = Response(mimetype="application/json")
|
851
|
+
response.set_data(message_to_json(response_message))
|
852
|
+
return response
|
853
|
+
|
854
|
+
|
855
|
+
@catch_mlflow_exception
|
856
|
+
@_disable_if_artifacts_only
|
857
|
+
def _update_run():
|
858
|
+
request_message = _get_request_message(
|
859
|
+
UpdateRun(),
|
860
|
+
schema={
|
861
|
+
"run_id": [_assert_required, _assert_string],
|
862
|
+
"end_time": [_assert_intlike],
|
863
|
+
"status": [_assert_string],
|
864
|
+
"run_name": [_assert_string],
|
865
|
+
},
|
866
|
+
)
|
867
|
+
run_id = request_message.run_id or request_message.run_uuid
|
868
|
+
run_name = request_message.run_name if request_message.HasField("run_name") else None
|
869
|
+
end_time = request_message.end_time if request_message.HasField("end_time") else None
|
870
|
+
status = request_message.status if request_message.HasField("status") else None
|
871
|
+
updated_info = _get_tracking_store().update_run_info(run_id, status, end_time, run_name)
|
872
|
+
response_message = UpdateRun.Response(run_info=updated_info.to_proto())
|
873
|
+
response = Response(mimetype="application/json")
|
874
|
+
response.set_data(message_to_json(response_message))
|
875
|
+
return response
|
876
|
+
|
877
|
+
|
878
|
+
@catch_mlflow_exception
|
879
|
+
@_disable_if_artifacts_only
|
880
|
+
def _delete_run():
|
881
|
+
request_message = _get_request_message(
|
882
|
+
DeleteRun(), schema={"run_id": [_assert_required, _assert_string]}
|
883
|
+
)
|
884
|
+
_get_tracking_store().delete_run(request_message.run_id)
|
885
|
+
response_message = DeleteRun.Response()
|
886
|
+
response = Response(mimetype="application/json")
|
887
|
+
response.set_data(message_to_json(response_message))
|
888
|
+
return response
|
889
|
+
|
890
|
+
|
891
|
+
@catch_mlflow_exception
|
892
|
+
@_disable_if_artifacts_only
|
893
|
+
def _restore_run():
|
894
|
+
request_message = _get_request_message(
|
895
|
+
RestoreRun(), schema={"run_id": [_assert_required, _assert_string]}
|
896
|
+
)
|
897
|
+
_get_tracking_store().restore_run(request_message.run_id)
|
898
|
+
response_message = RestoreRun.Response()
|
899
|
+
response = Response(mimetype="application/json")
|
900
|
+
response.set_data(message_to_json(response_message))
|
901
|
+
return response
|
902
|
+
|
903
|
+
|
904
|
+
@catch_mlflow_exception
|
905
|
+
@_disable_if_artifacts_only
|
906
|
+
def _log_metric():
|
907
|
+
request_message = _get_request_message(
|
908
|
+
LogMetric(),
|
909
|
+
schema={
|
910
|
+
"run_id": [_assert_required, _assert_string],
|
911
|
+
"key": [_assert_required, _assert_string],
|
912
|
+
"value": [_assert_required, _assert_floatlike],
|
913
|
+
"timestamp": [_assert_intlike, _assert_required],
|
914
|
+
"step": [_assert_intlike],
|
915
|
+
"model_id": [_assert_string],
|
916
|
+
"dataset_name": [_assert_string],
|
917
|
+
"dataset_digest": [_assert_string],
|
918
|
+
},
|
919
|
+
)
|
920
|
+
|
921
|
+
# Security validation for metric key
|
922
|
+
try:
|
923
|
+
validated_key = InputValidator.validate_metric_key(request_message.key)
|
924
|
+
except SecurityValidationError as e:
|
925
|
+
raise MlflowException(
|
926
|
+
f"Invalid metric key: {e}",
|
927
|
+
error_code=INVALID_PARAMETER_VALUE,
|
928
|
+
)
|
929
|
+
|
930
|
+
metric = Metric(
|
931
|
+
validated_key,
|
932
|
+
request_message.value,
|
933
|
+
request_message.timestamp,
|
934
|
+
request_message.step,
|
935
|
+
request_message.model_id or None,
|
936
|
+
request_message.dataset_name or None,
|
937
|
+
request_message.dataset_digest or None,
|
938
|
+
request_message.run_id or None,
|
939
|
+
)
|
940
|
+
run_id = request_message.run_id or request_message.run_uuid
|
941
|
+
_get_tracking_store().log_metric(run_id, metric)
|
942
|
+
response_message = LogMetric.Response()
|
943
|
+
response = Response(mimetype="application/json")
|
944
|
+
response.set_data(message_to_json(response_message))
|
945
|
+
return response
|
946
|
+
|
947
|
+
|
948
|
+
@catch_mlflow_exception
|
949
|
+
@_disable_if_artifacts_only
|
950
|
+
def _log_param():
|
951
|
+
request_message = _get_request_message(
|
952
|
+
LogParam(),
|
953
|
+
schema={
|
954
|
+
"run_id": [_assert_required, _assert_string],
|
955
|
+
"key": [_assert_required, _assert_string],
|
956
|
+
"value": [_assert_string],
|
957
|
+
},
|
958
|
+
)
|
959
|
+
|
960
|
+
# Security validation for parameter key and value
|
961
|
+
try:
|
962
|
+
validated_key = InputValidator.validate_param_key(request_message.key)
|
963
|
+
validated_value = InputValidator.validate_param_value(request_message.value)
|
964
|
+
except SecurityValidationError as e:
|
965
|
+
raise MlflowException(
|
966
|
+
f"Invalid parameter: {e}",
|
967
|
+
error_code=INVALID_PARAMETER_VALUE,
|
968
|
+
)
|
969
|
+
|
970
|
+
param = Param(validated_key, validated_value)
|
971
|
+
run_id = request_message.run_id or request_message.run_uuid
|
972
|
+
_get_tracking_store().log_param(run_id, param)
|
973
|
+
response_message = LogParam.Response()
|
974
|
+
response = Response(mimetype="application/json")
|
975
|
+
response.set_data(message_to_json(response_message))
|
976
|
+
return response
|
977
|
+
|
978
|
+
|
979
|
+
@catch_mlflow_exception
|
980
|
+
@_disable_if_artifacts_only
|
981
|
+
def _log_inputs():
|
982
|
+
request_message = _get_request_message(
|
983
|
+
LogInputs(),
|
984
|
+
schema={
|
985
|
+
"run_id": [_assert_required, _assert_string],
|
986
|
+
"datasets": [_assert_array],
|
987
|
+
"models": [_assert_array],
|
988
|
+
},
|
989
|
+
)
|
990
|
+
run_id = request_message.run_id
|
991
|
+
datasets = [
|
992
|
+
DatasetInput.from_proto(proto_dataset_input)
|
993
|
+
for proto_dataset_input in request_message.datasets
|
994
|
+
]
|
995
|
+
models = (
|
996
|
+
[
|
997
|
+
LoggedModelInput.from_proto(proto_logged_model_input)
|
998
|
+
for proto_logged_model_input in request_message.models
|
999
|
+
]
|
1000
|
+
if request_message.models
|
1001
|
+
else None
|
1002
|
+
)
|
1003
|
+
|
1004
|
+
_get_tracking_store().log_inputs(run_id, datasets=datasets, models=models)
|
1005
|
+
response_message = LogInputs.Response()
|
1006
|
+
response = Response(mimetype="application/json")
|
1007
|
+
response.set_data(message_to_json(response_message))
|
1008
|
+
return response
|
1009
|
+
|
1010
|
+
|
1011
|
+
@catch_mlflow_exception
|
1012
|
+
@_disable_if_artifacts_only
|
1013
|
+
def _log_outputs():
|
1014
|
+
request_message = _get_request_message(
|
1015
|
+
LogOutputs(),
|
1016
|
+
schema={
|
1017
|
+
"run_id": [_assert_required, _assert_string],
|
1018
|
+
"models": [_assert_required, _assert_array],
|
1019
|
+
},
|
1020
|
+
)
|
1021
|
+
models = [LoggedModelOutput.from_proto(p) for p in request_message.models]
|
1022
|
+
_get_tracking_store().log_outputs(run_id=request_message.run_id, models=models)
|
1023
|
+
response_message = LogOutputs.Response()
|
1024
|
+
return _wrap_response(response_message)
|
1025
|
+
|
1026
|
+
|
1027
|
+
@catch_mlflow_exception
|
1028
|
+
@_disable_if_artifacts_only
|
1029
|
+
def _set_experiment_tag():
|
1030
|
+
request_message = _get_request_message(
|
1031
|
+
SetExperimentTag(),
|
1032
|
+
schema={
|
1033
|
+
"experiment_id": [_assert_required, _assert_string],
|
1034
|
+
"key": [_assert_required, _assert_string],
|
1035
|
+
"value": [_assert_string],
|
1036
|
+
},
|
1037
|
+
)
|
1038
|
+
tag = ExperimentTag(request_message.key, request_message.value)
|
1039
|
+
_get_tracking_store().set_experiment_tag(request_message.experiment_id, tag)
|
1040
|
+
response_message = SetExperimentTag.Response()
|
1041
|
+
response = Response(mimetype="application/json")
|
1042
|
+
response.set_data(message_to_json(response_message))
|
1043
|
+
return response
|
1044
|
+
|
1045
|
+
|
1046
|
+
@catch_mlflow_exception
|
1047
|
+
@_disable_if_artifacts_only
|
1048
|
+
def _set_tag():
|
1049
|
+
request_message = _get_request_message(
|
1050
|
+
SetTag(),
|
1051
|
+
schema={
|
1052
|
+
"run_id": [_assert_required, _assert_string],
|
1053
|
+
"key": [_assert_required, _assert_string],
|
1054
|
+
"value": [_assert_string],
|
1055
|
+
},
|
1056
|
+
)
|
1057
|
+
tag = RunTag(request_message.key, request_message.value)
|
1058
|
+
run_id = request_message.run_id or request_message.run_uuid
|
1059
|
+
_get_tracking_store().set_tag(run_id, tag)
|
1060
|
+
response_message = SetTag.Response()
|
1061
|
+
response = Response(mimetype="application/json")
|
1062
|
+
response.set_data(message_to_json(response_message))
|
1063
|
+
return response
|
1064
|
+
|
1065
|
+
|
1066
|
+
@catch_mlflow_exception
|
1067
|
+
@_disable_if_artifacts_only
|
1068
|
+
def _delete_tag():
|
1069
|
+
request_message = _get_request_message(
|
1070
|
+
DeleteTag(),
|
1071
|
+
schema={
|
1072
|
+
"run_id": [_assert_required, _assert_string],
|
1073
|
+
"key": [_assert_required, _assert_string],
|
1074
|
+
},
|
1075
|
+
)
|
1076
|
+
_get_tracking_store().delete_tag(request_message.run_id, request_message.key)
|
1077
|
+
response_message = DeleteTag.Response()
|
1078
|
+
response = Response(mimetype="application/json")
|
1079
|
+
response.set_data(message_to_json(response_message))
|
1080
|
+
return response
|
1081
|
+
|
1082
|
+
|
1083
|
+
@catch_mlflow_exception
|
1084
|
+
@_disable_if_artifacts_only
|
1085
|
+
def _get_run():
|
1086
|
+
request_message = _get_request_message(
|
1087
|
+
GetRun(), schema={"run_id": [_assert_required, _assert_string]}
|
1088
|
+
)
|
1089
|
+
response_message = get_run_impl(request_message)
|
1090
|
+
response = Response(mimetype="application/json")
|
1091
|
+
response.set_data(message_to_json(response_message))
|
1092
|
+
return response
|
1093
|
+
|
1094
|
+
|
1095
|
+
def get_run_impl(request_message):
|
1096
|
+
response_message = GetRun.Response()
|
1097
|
+
run_id = request_message.run_id or request_message.run_uuid
|
1098
|
+
response_message.run.MergeFrom(_get_tracking_store().get_run(run_id).to_proto())
|
1099
|
+
return response_message
|
1100
|
+
|
1101
|
+
|
1102
|
+
@catch_mlflow_exception
|
1103
|
+
@_disable_if_artifacts_only
|
1104
|
+
def _search_runs():
|
1105
|
+
request_message = _get_request_message(
|
1106
|
+
SearchRuns(),
|
1107
|
+
schema={
|
1108
|
+
"experiment_ids": [_assert_array],
|
1109
|
+
"filter": [_assert_string],
|
1110
|
+
"max_results": [
|
1111
|
+
_assert_intlike,
|
1112
|
+
lambda x: _assert_less_than_or_equal(int(x), 50000),
|
1113
|
+
],
|
1114
|
+
"order_by": [_assert_array, _assert_item_type_string],
|
1115
|
+
},
|
1116
|
+
)
|
1117
|
+
response_message = search_runs_impl(request_message)
|
1118
|
+
response = Response(mimetype="application/json")
|
1119
|
+
response.set_data(message_to_json(response_message))
|
1120
|
+
return response
|
1121
|
+
|
1122
|
+
|
1123
|
+
def search_runs_impl(request_message):
|
1124
|
+
response_message = SearchRuns.Response()
|
1125
|
+
run_view_type = ViewType.ACTIVE_ONLY
|
1126
|
+
if request_message.HasField("run_view_type"):
|
1127
|
+
run_view_type = ViewType.from_proto(request_message.run_view_type)
|
1128
|
+
filter_string = request_message.filter
|
1129
|
+
max_results = request_message.max_results
|
1130
|
+
experiment_ids = request_message.experiment_ids
|
1131
|
+
order_by = request_message.order_by
|
1132
|
+
page_token = request_message.page_token
|
1133
|
+
run_entities = _get_tracking_store().search_runs(
|
1134
|
+
experiment_ids, filter_string, run_view_type, max_results, order_by, page_token
|
1135
|
+
)
|
1136
|
+
response_message.runs.extend([r.to_proto() for r in run_entities])
|
1137
|
+
if run_entities.token:
|
1138
|
+
response_message.next_page_token = run_entities.token
|
1139
|
+
return response_message
|
1140
|
+
|
1141
|
+
|
1142
|
+
@catch_mlflow_exception
|
1143
|
+
@_disable_if_artifacts_only
|
1144
|
+
def _list_artifacts():
|
1145
|
+
request_message = _get_request_message(
|
1146
|
+
ListArtifacts(),
|
1147
|
+
schema={
|
1148
|
+
"run_id": [_assert_string, _assert_required],
|
1149
|
+
"path": [_assert_string],
|
1150
|
+
"page_token": [_assert_string],
|
1151
|
+
},
|
1152
|
+
)
|
1153
|
+
response_message = list_artifacts_impl(request_message)
|
1154
|
+
response = Response(mimetype="application/json")
|
1155
|
+
response.set_data(message_to_json(response_message))
|
1156
|
+
return response
|
1157
|
+
|
1158
|
+
|
1159
|
+
def list_artifacts_impl(request_message):
|
1160
|
+
response_message = ListArtifacts.Response()
|
1161
|
+
if request_message.HasField("path"):
|
1162
|
+
path = request_message.path
|
1163
|
+
path = validate_path_is_safe(path)
|
1164
|
+
else:
|
1165
|
+
path = None
|
1166
|
+
run_id = request_message.run_id or request_message.run_uuid
|
1167
|
+
run = _get_tracking_store().get_run(run_id)
|
1168
|
+
|
1169
|
+
if _is_servable_proxied_run_artifact_root(run.info.artifact_uri):
|
1170
|
+
artifact_entities = _list_artifacts_for_proxied_run_artifact_root(
|
1171
|
+
proxied_artifact_root=run.info.artifact_uri,
|
1172
|
+
relative_path=path,
|
1173
|
+
)
|
1174
|
+
else:
|
1175
|
+
artifact_entities = _get_artifact_repo(run).list_artifacts(path)
|
1176
|
+
|
1177
|
+
response_message.files.extend([a.to_proto() for a in artifact_entities])
|
1178
|
+
response_message.root_uri = run.info.artifact_uri
|
1179
|
+
return response_message
|
1180
|
+
|
1181
|
+
|
1182
|
+
@catch_mlflow_exception
|
1183
|
+
def _list_artifacts_for_proxied_run_artifact_root(proxied_artifact_root, relative_path=None):
|
1184
|
+
"""
|
1185
|
+
Lists artifacts from the specified ``relative_path`` within the specified proxied Run artifact
|
1186
|
+
root (i.e. a Run artifact root with scheme ``http``, ``https``, or ``mlflow-artifacts``).
|
1187
|
+
|
1188
|
+
Args:
|
1189
|
+
proxied_artifact_root: The Run artifact root location (URI) with scheme ``http``,
|
1190
|
+
``https``, or ``mlflow-artifacts`` that can be resolved by the
|
1191
|
+
MLflow server to a concrete storage location.
|
1192
|
+
relative_path: The relative path within the specified ``proxied_artifact_root`` under
|
1193
|
+
which to list artifact contents. If ``None``, artifacts are listed from
|
1194
|
+
the ``proxied_artifact_root`` directory.
|
1195
|
+
"""
|
1196
|
+
parsed_proxied_artifact_root = urllib.parse.urlparse(proxied_artifact_root)
|
1197
|
+
assert parsed_proxied_artifact_root.scheme in ["http", "https", "mlflow-artifacts"]
|
1198
|
+
|
1199
|
+
artifact_destination_repo = _get_artifact_repo_mlflow_artifacts()
|
1200
|
+
artifact_destination_path = _get_proxied_run_artifact_destination_path(
|
1201
|
+
proxied_artifact_root=proxied_artifact_root,
|
1202
|
+
relative_path=relative_path,
|
1203
|
+
)
|
1204
|
+
|
1205
|
+
artifact_entities = []
|
1206
|
+
for file_info in artifact_destination_repo.list_artifacts(artifact_destination_path):
|
1207
|
+
basename = posixpath.basename(file_info.path)
|
1208
|
+
run_relative_artifact_path = (
|
1209
|
+
posixpath.join(relative_path, basename) if relative_path else basename
|
1210
|
+
)
|
1211
|
+
artifact_entities.append(
|
1212
|
+
FileInfo(run_relative_artifact_path, file_info.is_dir, file_info.file_size)
|
1213
|
+
)
|
1214
|
+
|
1215
|
+
return artifact_entities
|
1216
|
+
|
1217
|
+
|
1218
|
+
@catch_mlflow_exception
|
1219
|
+
@_disable_if_artifacts_only
|
1220
|
+
def _get_metric_history():
|
1221
|
+
request_message = _get_request_message(
|
1222
|
+
GetMetricHistory(),
|
1223
|
+
schema={
|
1224
|
+
"run_id": [_assert_string, _assert_required],
|
1225
|
+
"metric_key": [_assert_string, _assert_required],
|
1226
|
+
"page_token": [_assert_string],
|
1227
|
+
},
|
1228
|
+
)
|
1229
|
+
response_message = GetMetricHistory.Response()
|
1230
|
+
run_id = request_message.run_id or request_message.run_uuid
|
1231
|
+
|
1232
|
+
max_results = request_message.max_results if request_message.max_results is not None else None
|
1233
|
+
page_token = request_message.page_token if request_message.page_token else None
|
1234
|
+
|
1235
|
+
metric_entities = _get_tracking_store().get_metric_history(
|
1236
|
+
run_id, request_message.metric_key, max_results=max_results, page_token=page_token
|
1237
|
+
)
|
1238
|
+
response_message.metrics.extend([m.to_proto() for m in metric_entities])
|
1239
|
+
|
1240
|
+
# Set next_page_token if available
|
1241
|
+
if next_page_token := metric_entities.token:
|
1242
|
+
response_message.next_page_token = next_page_token
|
1243
|
+
|
1244
|
+
response = Response(mimetype="application/json")
|
1245
|
+
response.set_data(message_to_json(response_message))
|
1246
|
+
return response
|
1247
|
+
|
1248
|
+
|
1249
|
+
@catch_mlflow_exception
|
1250
|
+
@_disable_if_artifacts_only
|
1251
|
+
def get_metric_history_bulk_handler():
|
1252
|
+
MAX_HISTORY_RESULTS = 25000
|
1253
|
+
MAX_RUN_IDS_PER_REQUEST = 100
|
1254
|
+
run_ids = request.args.to_dict(flat=False).get("run_id", [])
|
1255
|
+
if not run_ids:
|
1256
|
+
raise MlflowException(
|
1257
|
+
message="GetMetricHistoryBulk request must specify at least one run_id.",
|
1258
|
+
error_code=INVALID_PARAMETER_VALUE,
|
1259
|
+
)
|
1260
|
+
if len(run_ids) > MAX_RUN_IDS_PER_REQUEST:
|
1261
|
+
raise MlflowException(
|
1262
|
+
message=(
|
1263
|
+
f"GetMetricHistoryBulk request cannot specify more than {MAX_RUN_IDS_PER_REQUEST}"
|
1264
|
+
f" run_ids. Received {len(run_ids)} run_ids."
|
1265
|
+
),
|
1266
|
+
error_code=INVALID_PARAMETER_VALUE,
|
1267
|
+
)
|
1268
|
+
|
1269
|
+
metric_key = request.args.get("metric_key")
|
1270
|
+
if metric_key is None:
|
1271
|
+
raise MlflowException(
|
1272
|
+
message="GetMetricHistoryBulk request must specify a metric_key.",
|
1273
|
+
error_code=INVALID_PARAMETER_VALUE,
|
1274
|
+
)
|
1275
|
+
|
1276
|
+
max_results = int(request.args.get("max_results", MAX_HISTORY_RESULTS))
|
1277
|
+
max_results = min(max_results, MAX_HISTORY_RESULTS)
|
1278
|
+
|
1279
|
+
store = _get_tracking_store()
|
1280
|
+
|
1281
|
+
def _default_history_bulk_impl():
|
1282
|
+
metrics_with_run_ids = []
|
1283
|
+
for run_id in sorted(run_ids):
|
1284
|
+
metrics_for_run = sorted(
|
1285
|
+
store.get_metric_history(
|
1286
|
+
run_id=run_id,
|
1287
|
+
metric_key=metric_key,
|
1288
|
+
max_results=max_results,
|
1289
|
+
),
|
1290
|
+
key=lambda metric: (metric.timestamp, metric.step, metric.value),
|
1291
|
+
)
|
1292
|
+
metrics_with_run_ids.extend(
|
1293
|
+
[
|
1294
|
+
{
|
1295
|
+
"key": metric.key,
|
1296
|
+
"value": metric.value,
|
1297
|
+
"timestamp": metric.timestamp,
|
1298
|
+
"step": metric.step,
|
1299
|
+
"run_id": run_id,
|
1300
|
+
}
|
1301
|
+
for metric in metrics_for_run
|
1302
|
+
]
|
1303
|
+
)
|
1304
|
+
return metrics_with_run_ids
|
1305
|
+
|
1306
|
+
if hasattr(store, "get_metric_history_bulk"):
|
1307
|
+
metrics_with_run_ids = [
|
1308
|
+
metric.to_dict()
|
1309
|
+
for metric in store.get_metric_history_bulk(
|
1310
|
+
run_ids=run_ids,
|
1311
|
+
metric_key=metric_key,
|
1312
|
+
max_results=max_results,
|
1313
|
+
)
|
1314
|
+
]
|
1315
|
+
else:
|
1316
|
+
metrics_with_run_ids = _default_history_bulk_impl()
|
1317
|
+
|
1318
|
+
return {
|
1319
|
+
"metrics": metrics_with_run_ids[:max_results],
|
1320
|
+
}
|
1321
|
+
|
1322
|
+
|
1323
|
+
def _get_sampled_steps_from_steps(
|
1324
|
+
start_step: int, end_step: int, max_results: int, all_steps: list[int]
|
1325
|
+
) -> set[int]:
|
1326
|
+
# NOTE: all_steps should be sorted before
|
1327
|
+
# being passed to this function
|
1328
|
+
start_idx = bisect.bisect_left(all_steps, start_step)
|
1329
|
+
end_idx = bisect.bisect_right(all_steps, end_step)
|
1330
|
+
if end_idx - start_idx <= max_results:
|
1331
|
+
return set(all_steps[start_idx:end_idx])
|
1332
|
+
|
1333
|
+
num_steps = end_idx - start_idx
|
1334
|
+
interval = num_steps / max_results
|
1335
|
+
sampled_steps = []
|
1336
|
+
|
1337
|
+
for i in range(0, max_results):
|
1338
|
+
idx = start_idx + int(i * interval)
|
1339
|
+
if idx < num_steps:
|
1340
|
+
sampled_steps.append(all_steps[idx])
|
1341
|
+
|
1342
|
+
sampled_steps.append(all_steps[end_idx - 1])
|
1343
|
+
return set(sampled_steps)
|
1344
|
+
|
1345
|
+
|
1346
|
+
@catch_mlflow_exception
|
1347
|
+
@_disable_if_artifacts_only
|
1348
|
+
def get_metric_history_bulk_interval_handler():
|
1349
|
+
request_message = _get_request_message(
|
1350
|
+
GetMetricHistoryBulkInterval(),
|
1351
|
+
schema={
|
1352
|
+
"run_ids": [
|
1353
|
+
_assert_required,
|
1354
|
+
_assert_array,
|
1355
|
+
_assert_item_type_string,
|
1356
|
+
lambda x: _assert_less_than_or_equal(
|
1357
|
+
len(x),
|
1358
|
+
MAX_RUNS_GET_METRIC_HISTORY_BULK,
|
1359
|
+
message=f"GetMetricHistoryBulkInterval request must specify at most "
|
1360
|
+
f"{MAX_RUNS_GET_METRIC_HISTORY_BULK} run_ids. Received {len(x)} run_ids.",
|
1361
|
+
),
|
1362
|
+
],
|
1363
|
+
"metric_key": [_assert_required, _assert_string],
|
1364
|
+
"start_step": [_assert_intlike],
|
1365
|
+
"end_step": [_assert_intlike],
|
1366
|
+
"max_results": [
|
1367
|
+
_assert_intlike,
|
1368
|
+
lambda x: _assert_intlike_within_range(
|
1369
|
+
int(x),
|
1370
|
+
1,
|
1371
|
+
MAX_RESULTS_PER_RUN,
|
1372
|
+
message=f"max_results must be between 1 and {MAX_RESULTS_PER_RUN}.",
|
1373
|
+
),
|
1374
|
+
],
|
1375
|
+
},
|
1376
|
+
)
|
1377
|
+
response_message = get_metric_history_bulk_interval_impl(request_message)
|
1378
|
+
response = Response(mimetype="application/json")
|
1379
|
+
response.set_data(message_to_json(response_message))
|
1380
|
+
return response
|
1381
|
+
|
1382
|
+
|
1383
|
+
def get_metric_history_bulk_interval_impl(request_message):
|
1384
|
+
args = request.args
|
1385
|
+
run_ids = request_message.run_ids
|
1386
|
+
metric_key = request_message.metric_key
|
1387
|
+
max_results = int(args.get("max_results", MAX_RESULTS_PER_RUN))
|
1388
|
+
|
1389
|
+
store = _get_tracking_store()
|
1390
|
+
|
1391
|
+
def _get_sampled_steps(run_ids, metric_key, max_results):
|
1392
|
+
# cannot fetch from request_message as the default value is 0
|
1393
|
+
start_step = args.get("start_step")
|
1394
|
+
end_step = args.get("end_step")
|
1395
|
+
|
1396
|
+
# perform validation before any data fetching occurs
|
1397
|
+
if start_step is not None and end_step is not None:
|
1398
|
+
start_step = int(start_step)
|
1399
|
+
end_step = int(end_step)
|
1400
|
+
if start_step > end_step:
|
1401
|
+
raise MlflowException.invalid_parameter_value(
|
1402
|
+
"end_step must be greater than start_step. "
|
1403
|
+
f"Found start_step={start_step} and end_step={end_step}."
|
1404
|
+
)
|
1405
|
+
elif start_step is not None or end_step is not None:
|
1406
|
+
raise MlflowException.invalid_parameter_value(
|
1407
|
+
"If either start step or end step are specified, both must be specified."
|
1408
|
+
)
|
1409
|
+
|
1410
|
+
# get a list of all steps for all runs. this is necessary
|
1411
|
+
# because we can't assume that every step was logged, so
|
1412
|
+
# sampling needs to be done on the steps that actually exist
|
1413
|
+
all_runs = [
|
1414
|
+
[m.step for m in store.get_metric_history(run_id, metric_key)] for run_id in run_ids
|
1415
|
+
]
|
1416
|
+
|
1417
|
+
# save mins and maxes to be added back later
|
1418
|
+
all_mins_and_maxes = {step for run in all_runs if run for step in [min(run), max(run)]}
|
1419
|
+
all_steps = sorted({step for sublist in all_runs for step in sublist})
|
1420
|
+
|
1421
|
+
# init start and end step if not provided in args
|
1422
|
+
if start_step is None and end_step is None:
|
1423
|
+
start_step = 0
|
1424
|
+
end_step = all_steps[-1] if all_steps else 0
|
1425
|
+
|
1426
|
+
# remove any steps outside of the range
|
1427
|
+
all_mins_and_maxes = {step for step in all_mins_and_maxes if start_step <= step <= end_step}
|
1428
|
+
|
1429
|
+
# doing extra iterations here shouldn't badly affect performance,
|
1430
|
+
# since the number of steps at this point should be relatively small
|
1431
|
+
# (MAX_RESULTS_PER_RUN + len(all_mins_and_maxes))
|
1432
|
+
sampled_steps = _get_sampled_steps_from_steps(start_step, end_step, max_results, all_steps)
|
1433
|
+
return sorted(sampled_steps.union(all_mins_and_maxes))
|
1434
|
+
|
1435
|
+
def _default_history_bulk_interval_impl():
|
1436
|
+
steps = _get_sampled_steps(run_ids, metric_key, max_results)
|
1437
|
+
metrics_with_run_ids = []
|
1438
|
+
for run_id in run_ids:
|
1439
|
+
metrics_with_run_ids.extend(
|
1440
|
+
store.get_metric_history_bulk_interval_from_steps(
|
1441
|
+
run_id=run_id,
|
1442
|
+
metric_key=metric_key,
|
1443
|
+
steps=steps,
|
1444
|
+
max_results=MAX_RESULTS_GET_METRIC_HISTORY,
|
1445
|
+
)
|
1446
|
+
)
|
1447
|
+
return metrics_with_run_ids
|
1448
|
+
|
1449
|
+
metrics_with_run_ids = _default_history_bulk_interval_impl()
|
1450
|
+
|
1451
|
+
response_message = GetMetricHistoryBulkInterval.Response()
|
1452
|
+
response_message.metrics.extend([m.to_proto() for m in metrics_with_run_ids])
|
1453
|
+
return response_message
|
1454
|
+
|
1455
|
+
|
1456
|
+
@catch_mlflow_exception
|
1457
|
+
@_disable_if_artifacts_only
|
1458
|
+
def search_datasets_handler():
|
1459
|
+
request_message = _get_request_message(
|
1460
|
+
SearchDatasets(),
|
1461
|
+
)
|
1462
|
+
response_message = search_datasets_impl(request_message)
|
1463
|
+
response = Response(mimetype="application/json")
|
1464
|
+
response.set_data(message_to_json(response_message))
|
1465
|
+
return response
|
1466
|
+
|
1467
|
+
|
1468
|
+
def search_datasets_impl(request_message):
|
1469
|
+
MAX_EXPERIMENT_IDS_PER_REQUEST = 20
|
1470
|
+
_validate_content_type(request, ["application/json"])
|
1471
|
+
experiment_ids = request_message.experiment_ids or []
|
1472
|
+
if not experiment_ids:
|
1473
|
+
raise MlflowException(
|
1474
|
+
message="SearchDatasets request must specify at least one experiment_id.",
|
1475
|
+
error_code=INVALID_PARAMETER_VALUE,
|
1476
|
+
)
|
1477
|
+
if len(experiment_ids) > MAX_EXPERIMENT_IDS_PER_REQUEST:
|
1478
|
+
raise MlflowException(
|
1479
|
+
message=(
|
1480
|
+
f"SearchDatasets request cannot specify more than {MAX_EXPERIMENT_IDS_PER_REQUEST}"
|
1481
|
+
f" experiment_ids. Received {len(experiment_ids)} experiment_ids."
|
1482
|
+
),
|
1483
|
+
error_code=INVALID_PARAMETER_VALUE,
|
1484
|
+
)
|
1485
|
+
|
1486
|
+
store = _get_tracking_store()
|
1487
|
+
|
1488
|
+
if hasattr(store, "_search_datasets"):
|
1489
|
+
response_message = SearchDatasets.Response()
|
1490
|
+
response_message.dataset_summaries.extend(
|
1491
|
+
[summary.to_proto() for summary in store._search_datasets(experiment_ids)]
|
1492
|
+
)
|
1493
|
+
return response_message
|
1494
|
+
else:
|
1495
|
+
return _not_implemented()
|
1496
|
+
|
1497
|
+
|
1498
|
+
def _validate_gateway_path(method: str, gateway_path: str) -> None:
|
1499
|
+
if not gateway_path:
|
1500
|
+
raise MlflowException(
|
1501
|
+
message="Deployments proxy request must specify a gateway_path.",
|
1502
|
+
error_code=INVALID_PARAMETER_VALUE,
|
1503
|
+
)
|
1504
|
+
elif method == "GET":
|
1505
|
+
if gateway_path.strip("/") != "api/2.0/endpoints":
|
1506
|
+
raise MlflowException(
|
1507
|
+
message=f"Invalid gateway_path: {gateway_path} for method: {method}",
|
1508
|
+
error_code=INVALID_PARAMETER_VALUE,
|
1509
|
+
)
|
1510
|
+
elif method == "POST":
|
1511
|
+
# For POST, gateway_path must be in the form of "gateway/{name}/invocations"
|
1512
|
+
if not re.fullmatch(r"gateway/[^/]+/invocations", gateway_path.strip("/")):
|
1513
|
+
raise MlflowException(
|
1514
|
+
message=f"Invalid gateway_path: {gateway_path} for method: {method}",
|
1515
|
+
error_code=INVALID_PARAMETER_VALUE,
|
1516
|
+
)
|
1517
|
+
|
1518
|
+
|
1519
|
+
@catch_mlflow_exception
|
1520
|
+
def gateway_proxy_handler():
|
1521
|
+
target_uri = MLFLOW_DEPLOYMENTS_TARGET.get()
|
1522
|
+
if not target_uri:
|
1523
|
+
# Pretend an empty gateway service is running
|
1524
|
+
return {"endpoints": []}
|
1525
|
+
|
1526
|
+
args = request.args if request.method == "GET" else request.json
|
1527
|
+
gateway_path = args.get("gateway_path")
|
1528
|
+
_validate_gateway_path(request.method, gateway_path)
|
1529
|
+
json_data = args.get("json_data", None)
|
1530
|
+
response = requests.request(request.method, f"{target_uri}/{gateway_path}", json=json_data)
|
1531
|
+
if response.status_code == 200:
|
1532
|
+
return response.json()
|
1533
|
+
else:
|
1534
|
+
raise MlflowException(
|
1535
|
+
message=f"Deployments proxy request failed with error code {response.status_code}. "
|
1536
|
+
f"Error message: {response.text}",
|
1537
|
+
error_code=response.status_code,
|
1538
|
+
)
|
1539
|
+
|
1540
|
+
|
1541
|
+
@catch_mlflow_exception
|
1542
|
+
@_disable_if_artifacts_only
|
1543
|
+
def create_promptlab_run_handler():
|
1544
|
+
def assert_arg_exists(arg_name, arg):
|
1545
|
+
if not arg:
|
1546
|
+
raise MlflowException(
|
1547
|
+
message=f"CreatePromptlabRun request must specify {arg_name}.",
|
1548
|
+
error_code=INVALID_PARAMETER_VALUE,
|
1549
|
+
)
|
1550
|
+
|
1551
|
+
_validate_content_type(request, ["application/json"])
|
1552
|
+
|
1553
|
+
args = request.json
|
1554
|
+
experiment_id = args.get("experiment_id")
|
1555
|
+
assert_arg_exists("experiment_id", experiment_id)
|
1556
|
+
run_name = args.get("run_name", None)
|
1557
|
+
tags = args.get("tags", [])
|
1558
|
+
prompt_template = args.get("prompt_template")
|
1559
|
+
assert_arg_exists("prompt_template", prompt_template)
|
1560
|
+
raw_prompt_parameters = args.get("prompt_parameters")
|
1561
|
+
assert_arg_exists("prompt_parameters", raw_prompt_parameters)
|
1562
|
+
prompt_parameters = [
|
1563
|
+
Param(param.get("key"), param.get("value")) for param in args.get("prompt_parameters")
|
1564
|
+
]
|
1565
|
+
model_route = args.get("model_route")
|
1566
|
+
assert_arg_exists("model_route", model_route)
|
1567
|
+
raw_model_parameters = args.get("model_parameters", [])
|
1568
|
+
model_parameters = [
|
1569
|
+
Param(param.get("key"), param.get("value")) for param in raw_model_parameters
|
1570
|
+
]
|
1571
|
+
model_input = args.get("model_input")
|
1572
|
+
assert_arg_exists("model_input", model_input)
|
1573
|
+
model_output = args.get("model_output", None)
|
1574
|
+
raw_model_output_parameters = args.get("model_output_parameters", [])
|
1575
|
+
model_output_parameters = [
|
1576
|
+
Param(param.get("key"), param.get("value")) for param in raw_model_output_parameters
|
1577
|
+
]
|
1578
|
+
mlflow_version = args.get("mlflow_version")
|
1579
|
+
assert_arg_exists("mlflow_version", mlflow_version)
|
1580
|
+
user_id = args.get("user_id", "unknown")
|
1581
|
+
|
1582
|
+
# use current time if not provided
|
1583
|
+
start_time = args.get("start_time", int(time.time() * 1000))
|
1584
|
+
|
1585
|
+
store = _get_tracking_store()
|
1586
|
+
|
1587
|
+
run = _create_promptlab_run_impl(
|
1588
|
+
store,
|
1589
|
+
experiment_id=experiment_id,
|
1590
|
+
run_name=run_name,
|
1591
|
+
tags=tags,
|
1592
|
+
prompt_template=prompt_template,
|
1593
|
+
prompt_parameters=prompt_parameters,
|
1594
|
+
model_route=model_route,
|
1595
|
+
model_parameters=model_parameters,
|
1596
|
+
model_input=model_input,
|
1597
|
+
model_output=model_output,
|
1598
|
+
model_output_parameters=model_output_parameters,
|
1599
|
+
mlflow_version=mlflow_version,
|
1600
|
+
user_id=user_id,
|
1601
|
+
start_time=start_time,
|
1602
|
+
)
|
1603
|
+
response_message = CreateRun.Response()
|
1604
|
+
response_message.run.MergeFrom(run.to_proto())
|
1605
|
+
response = Response(mimetype="application/json")
|
1606
|
+
response.set_data(message_to_json(response_message))
|
1607
|
+
return response
|
1608
|
+
|
1609
|
+
|
1610
|
+
@catch_mlflow_exception
|
1611
|
+
def upload_artifact_handler():
|
1612
|
+
args = request.args
|
1613
|
+
run_uuid = args.get("run_uuid")
|
1614
|
+
if not run_uuid:
|
1615
|
+
raise MlflowException(
|
1616
|
+
message="Request must specify run_uuid.",
|
1617
|
+
error_code=INVALID_PARAMETER_VALUE,
|
1618
|
+
)
|
1619
|
+
path = args.get("path")
|
1620
|
+
if not path:
|
1621
|
+
raise MlflowException(
|
1622
|
+
message="Request must specify path.",
|
1623
|
+
error_code=INVALID_PARAMETER_VALUE,
|
1624
|
+
)
|
1625
|
+
|
1626
|
+
# Security validation for artifact path
|
1627
|
+
try:
|
1628
|
+
validated_path = InputValidator.validate_artifact_path(path)
|
1629
|
+
except SecurityValidationError as e:
|
1630
|
+
raise MlflowException(
|
1631
|
+
f"Invalid artifact path: {e}",
|
1632
|
+
error_code=INVALID_PARAMETER_VALUE,
|
1633
|
+
)
|
1634
|
+
|
1635
|
+
path = validate_path_is_safe(validated_path)
|
1636
|
+
|
1637
|
+
if request.content_length and request.content_length > 10 * 1024 * 1024:
|
1638
|
+
raise MlflowException(
|
1639
|
+
message="Artifact size is too large. Max size is 10MB.",
|
1640
|
+
error_code=INVALID_PARAMETER_VALUE,
|
1641
|
+
)
|
1642
|
+
|
1643
|
+
data = request.data
|
1644
|
+
if not data:
|
1645
|
+
raise MlflowException(
|
1646
|
+
message="Request must specify data.",
|
1647
|
+
error_code=INVALID_PARAMETER_VALUE,
|
1648
|
+
)
|
1649
|
+
|
1650
|
+
run = _get_tracking_store().get_run(run_uuid)
|
1651
|
+
artifact_dir = run.info.artifact_uri
|
1652
|
+
|
1653
|
+
basename = posixpath.basename(path)
|
1654
|
+
dirname = posixpath.dirname(path)
|
1655
|
+
|
1656
|
+
def _log_artifact_to_repo(file, run, dirname, artifact_dir):
|
1657
|
+
if _is_servable_proxied_run_artifact_root(run.info.artifact_uri):
|
1658
|
+
artifact_repo = _get_artifact_repo_mlflow_artifacts()
|
1659
|
+
path_to_log = (
|
1660
|
+
os.path.join(run.info.experiment_id, run.info.run_id, "artifacts", dirname)
|
1661
|
+
if dirname
|
1662
|
+
else os.path.join(run.info.experiment_id, run.info.run_id, "artifacts")
|
1663
|
+
)
|
1664
|
+
else:
|
1665
|
+
artifact_repo = get_artifact_repository(artifact_dir)
|
1666
|
+
path_to_log = dirname
|
1667
|
+
|
1668
|
+
artifact_repo.log_artifact(file, path_to_log)
|
1669
|
+
|
1670
|
+
with tempfile.TemporaryDirectory() as tmpdir:
|
1671
|
+
dir_path = os.path.join(tmpdir, dirname) if dirname else tmpdir
|
1672
|
+
file_path = os.path.join(dir_path, basename)
|
1673
|
+
|
1674
|
+
os.makedirs(dir_path, exist_ok=True)
|
1675
|
+
|
1676
|
+
with open(file_path, "wb") as f:
|
1677
|
+
f.write(data)
|
1678
|
+
|
1679
|
+
_log_artifact_to_repo(file_path, run, dirname, artifact_dir)
|
1680
|
+
|
1681
|
+
return Response(mimetype="application/json")
|
1682
|
+
|
1683
|
+
|
1684
|
+
@catch_mlflow_exception
|
1685
|
+
@_disable_if_artifacts_only
|
1686
|
+
def _search_experiments():
|
1687
|
+
request_message = _get_request_message(
|
1688
|
+
SearchExperiments(),
|
1689
|
+
schema={
|
1690
|
+
"view_type": [_assert_intlike],
|
1691
|
+
"max_results": [_assert_intlike],
|
1692
|
+
"order_by": [_assert_array],
|
1693
|
+
"filter": [_assert_string],
|
1694
|
+
"page_token": [_assert_string],
|
1695
|
+
},
|
1696
|
+
)
|
1697
|
+
experiment_entities = _get_tracking_store().search_experiments(
|
1698
|
+
view_type=request_message.view_type,
|
1699
|
+
max_results=request_message.max_results,
|
1700
|
+
order_by=request_message.order_by,
|
1701
|
+
filter_string=request_message.filter,
|
1702
|
+
page_token=request_message.page_token,
|
1703
|
+
)
|
1704
|
+
response_message = SearchExperiments.Response()
|
1705
|
+
response_message.experiments.extend([e.to_proto() for e in experiment_entities])
|
1706
|
+
if experiment_entities.token:
|
1707
|
+
response_message.next_page_token = experiment_entities.token
|
1708
|
+
response = Response(mimetype="application/json")
|
1709
|
+
response.set_data(message_to_json(response_message))
|
1710
|
+
return response
|
1711
|
+
|
1712
|
+
|
1713
|
+
@catch_mlflow_exception
|
1714
|
+
def _get_artifact_repo(run):
|
1715
|
+
return get_artifact_repository(run.info.artifact_uri)
|
1716
|
+
|
1717
|
+
|
1718
|
+
@catch_mlflow_exception
|
1719
|
+
@_disable_if_artifacts_only
|
1720
|
+
def _log_batch():
|
1721
|
+
def _assert_metrics_fields_present(metrics):
|
1722
|
+
for idx, m in enumerate(metrics):
|
1723
|
+
_assert_required(m.get("key"), path=f"metrics[{idx}].key")
|
1724
|
+
_assert_required(m.get("value"), path=f"metrics[{idx}].value")
|
1725
|
+
_assert_required(m.get("timestamp"), path=f"metrics[{idx}].timestamp")
|
1726
|
+
|
1727
|
+
def _assert_params_fields_present(params):
|
1728
|
+
for idx, param in enumerate(params):
|
1729
|
+
_assert_required(param.get("key"), path=f"params[{idx}].key")
|
1730
|
+
|
1731
|
+
def _assert_tags_fields_present(tags):
|
1732
|
+
for idx, tag in enumerate(tags):
|
1733
|
+
_assert_required(tag.get("key"), path=f"tags[{idx}].key")
|
1734
|
+
|
1735
|
+
_validate_batch_log_api_req(_get_request_json())
|
1736
|
+
request_message = _get_request_message(
|
1737
|
+
LogBatch(),
|
1738
|
+
schema={
|
1739
|
+
"run_id": [_assert_string, _assert_required],
|
1740
|
+
"metrics": [_assert_array, _assert_metrics_fields_present],
|
1741
|
+
"params": [_assert_array, _assert_params_fields_present],
|
1742
|
+
"tags": [_assert_array, _assert_tags_fields_present],
|
1743
|
+
},
|
1744
|
+
)
|
1745
|
+
metrics = [Metric.from_proto(proto_metric) for proto_metric in request_message.metrics]
|
1746
|
+
params = [Param.from_proto(proto_param) for proto_param in request_message.params]
|
1747
|
+
tags = [RunTag.from_proto(proto_tag) for proto_tag in request_message.tags]
|
1748
|
+
_get_tracking_store().log_batch(
|
1749
|
+
run_id=request_message.run_id, metrics=metrics, params=params, tags=tags
|
1750
|
+
)
|
1751
|
+
response_message = LogBatch.Response()
|
1752
|
+
response = Response(mimetype="application/json")
|
1753
|
+
response.set_data(message_to_json(response_message))
|
1754
|
+
return response
|
1755
|
+
|
1756
|
+
|
1757
|
+
@catch_mlflow_exception
|
1758
|
+
@_disable_if_artifacts_only
|
1759
|
+
def _log_model():
|
1760
|
+
request_message = _get_request_message(
|
1761
|
+
LogModel(),
|
1762
|
+
schema={
|
1763
|
+
"run_id": [_assert_string, _assert_required],
|
1764
|
+
"model_json": [_assert_string, _assert_required],
|
1765
|
+
},
|
1766
|
+
)
|
1767
|
+
try:
|
1768
|
+
model = json.loads(request_message.model_json)
|
1769
|
+
except Exception:
|
1770
|
+
raise MlflowException(
|
1771
|
+
f"Malformed model info. \n {request_message.model_json} \n is not a valid JSON.",
|
1772
|
+
error_code=INVALID_PARAMETER_VALUE,
|
1773
|
+
)
|
1774
|
+
|
1775
|
+
missing_fields = {"artifact_path", "flavors", "utc_time_created", "run_id"} - set(model.keys())
|
1776
|
+
|
1777
|
+
if missing_fields:
|
1778
|
+
raise MlflowException(
|
1779
|
+
f"Model json is missing mandatory fields: {missing_fields}",
|
1780
|
+
error_code=INVALID_PARAMETER_VALUE,
|
1781
|
+
)
|
1782
|
+
_get_tracking_store().record_logged_model(
|
1783
|
+
run_id=request_message.run_id, mlflow_model=Model.from_dict(model)
|
1784
|
+
)
|
1785
|
+
response_message = LogModel.Response()
|
1786
|
+
response = Response(mimetype="application/json")
|
1787
|
+
response.set_data(message_to_json(response_message))
|
1788
|
+
return response
|
1789
|
+
|
1790
|
+
|
1791
|
+
def _wrap_response(response_message):
|
1792
|
+
response = Response(mimetype="application/json")
|
1793
|
+
response.set_data(message_to_json(response_message))
|
1794
|
+
return response
|
1795
|
+
|
1796
|
+
|
1797
|
+
# Model Registry APIs
|
1798
|
+
|
1799
|
+
|
1800
|
+
@catch_mlflow_exception
|
1801
|
+
@_disable_if_artifacts_only
|
1802
|
+
def _create_registered_model():
|
1803
|
+
request_message = _get_request_message(
|
1804
|
+
CreateRegisteredModel(),
|
1805
|
+
schema={
|
1806
|
+
"name": [_assert_string, _assert_required],
|
1807
|
+
"tags": [_assert_array],
|
1808
|
+
"description": [_assert_string],
|
1809
|
+
},
|
1810
|
+
)
|
1811
|
+
registered_model = _get_model_registry_store().create_registered_model(
|
1812
|
+
name=request_message.name,
|
1813
|
+
tags=request_message.tags,
|
1814
|
+
description=request_message.description,
|
1815
|
+
)
|
1816
|
+
response_message = CreateRegisteredModel.Response(registered_model=registered_model.to_proto())
|
1817
|
+
return _wrap_response(response_message)
|
1818
|
+
|
1819
|
+
|
1820
|
+
@catch_mlflow_exception
|
1821
|
+
@_disable_if_artifacts_only
|
1822
|
+
def _get_registered_model():
|
1823
|
+
request_message = _get_request_message(
|
1824
|
+
GetRegisteredModel(), schema={"name": [_assert_string, _assert_required]}
|
1825
|
+
)
|
1826
|
+
registered_model = _get_model_registry_store().get_registered_model(name=request_message.name)
|
1827
|
+
response_message = GetRegisteredModel.Response(registered_model=registered_model.to_proto())
|
1828
|
+
return _wrap_response(response_message)
|
1829
|
+
|
1830
|
+
|
1831
|
+
@catch_mlflow_exception
|
1832
|
+
@_disable_if_artifacts_only
|
1833
|
+
def _update_registered_model():
|
1834
|
+
request_message = _get_request_message(
|
1835
|
+
UpdateRegisteredModel(),
|
1836
|
+
schema={
|
1837
|
+
"name": [_assert_string, _assert_required],
|
1838
|
+
"description": [_assert_string],
|
1839
|
+
},
|
1840
|
+
)
|
1841
|
+
name = request_message.name
|
1842
|
+
new_description = request_message.description
|
1843
|
+
registered_model = _get_model_registry_store().update_registered_model(
|
1844
|
+
name=name, description=new_description
|
1845
|
+
)
|
1846
|
+
response_message = UpdateRegisteredModel.Response(registered_model=registered_model.to_proto())
|
1847
|
+
return _wrap_response(response_message)
|
1848
|
+
|
1849
|
+
|
1850
|
+
@catch_mlflow_exception
|
1851
|
+
@_disable_if_artifacts_only
|
1852
|
+
def _rename_registered_model():
|
1853
|
+
request_message = _get_request_message(
|
1854
|
+
RenameRegisteredModel(),
|
1855
|
+
schema={
|
1856
|
+
"name": [_assert_string, _assert_required],
|
1857
|
+
"new_name": [_assert_string, _assert_required],
|
1858
|
+
},
|
1859
|
+
)
|
1860
|
+
name = request_message.name
|
1861
|
+
new_name = request_message.new_name
|
1862
|
+
registered_model = _get_model_registry_store().rename_registered_model(
|
1863
|
+
name=name, new_name=new_name
|
1864
|
+
)
|
1865
|
+
response_message = RenameRegisteredModel.Response(registered_model=registered_model.to_proto())
|
1866
|
+
return _wrap_response(response_message)
|
1867
|
+
|
1868
|
+
|
1869
|
+
@catch_mlflow_exception
|
1870
|
+
@_disable_if_artifacts_only
|
1871
|
+
def _delete_registered_model():
|
1872
|
+
request_message = _get_request_message(
|
1873
|
+
DeleteRegisteredModel(), schema={"name": [_assert_string, _assert_required]}
|
1874
|
+
)
|
1875
|
+
_get_model_registry_store().delete_registered_model(name=request_message.name)
|
1876
|
+
return _wrap_response(DeleteRegisteredModel.Response())
|
1877
|
+
|
1878
|
+
|
1879
|
+
@catch_mlflow_exception
|
1880
|
+
@_disable_if_artifacts_only
|
1881
|
+
def _search_registered_models():
|
1882
|
+
request_message = _get_request_message(
|
1883
|
+
SearchRegisteredModels(),
|
1884
|
+
schema={
|
1885
|
+
"filter": [_assert_string],
|
1886
|
+
"max_results": [
|
1887
|
+
_assert_intlike,
|
1888
|
+
lambda x: _assert_less_than_or_equal(int(x), 1000),
|
1889
|
+
],
|
1890
|
+
"order_by": [_assert_array, _assert_item_type_string],
|
1891
|
+
"page_token": [_assert_string],
|
1892
|
+
},
|
1893
|
+
)
|
1894
|
+
store = _get_model_registry_store()
|
1895
|
+
registered_models = store.search_registered_models(
|
1896
|
+
filter_string=request_message.filter,
|
1897
|
+
max_results=request_message.max_results,
|
1898
|
+
order_by=request_message.order_by,
|
1899
|
+
page_token=request_message.page_token,
|
1900
|
+
)
|
1901
|
+
response_message = SearchRegisteredModels.Response()
|
1902
|
+
response_message.registered_models.extend([e.to_proto() for e in registered_models])
|
1903
|
+
if registered_models.token:
|
1904
|
+
response_message.next_page_token = registered_models.token
|
1905
|
+
return _wrap_response(response_message)
|
1906
|
+
|
1907
|
+
|
1908
|
+
@catch_mlflow_exception
|
1909
|
+
@_disable_if_artifacts_only
|
1910
|
+
def _get_latest_versions():
|
1911
|
+
request_message = _get_request_message(
|
1912
|
+
GetLatestVersions(),
|
1913
|
+
schema={
|
1914
|
+
"name": [_assert_string, _assert_required],
|
1915
|
+
"stages": [_assert_array, _assert_item_type_string],
|
1916
|
+
},
|
1917
|
+
)
|
1918
|
+
latest_versions = _get_model_registry_store().get_latest_versions(
|
1919
|
+
name=request_message.name, stages=request_message.stages
|
1920
|
+
)
|
1921
|
+
response_message = GetLatestVersions.Response()
|
1922
|
+
response_message.model_versions.extend([e.to_proto() for e in latest_versions])
|
1923
|
+
return _wrap_response(response_message)
|
1924
|
+
|
1925
|
+
|
1926
|
+
@catch_mlflow_exception
|
1927
|
+
@_disable_if_artifacts_only
|
1928
|
+
def _set_registered_model_tag():
|
1929
|
+
request_message = _get_request_message(
|
1930
|
+
SetRegisteredModelTag(),
|
1931
|
+
schema={
|
1932
|
+
"name": [_assert_string, _assert_required],
|
1933
|
+
"key": [_assert_string, _assert_required],
|
1934
|
+
"value": [_assert_string],
|
1935
|
+
},
|
1936
|
+
)
|
1937
|
+
tag = RegisteredModelTag(key=request_message.key, value=request_message.value)
|
1938
|
+
_get_model_registry_store().set_registered_model_tag(name=request_message.name, tag=tag)
|
1939
|
+
return _wrap_response(SetRegisteredModelTag.Response())
|
1940
|
+
|
1941
|
+
|
1942
|
+
@catch_mlflow_exception
|
1943
|
+
@_disable_if_artifacts_only
|
1944
|
+
def _delete_registered_model_tag():
|
1945
|
+
request_message = _get_request_message(
|
1946
|
+
DeleteRegisteredModelTag(),
|
1947
|
+
schema={
|
1948
|
+
"name": [_assert_string, _assert_required],
|
1949
|
+
"key": [_assert_string, _assert_required],
|
1950
|
+
},
|
1951
|
+
)
|
1952
|
+
_get_model_registry_store().delete_registered_model_tag(
|
1953
|
+
name=request_message.name, key=request_message.key
|
1954
|
+
)
|
1955
|
+
return _wrap_response(DeleteRegisteredModelTag.Response())
|
1956
|
+
|
1957
|
+
|
1958
|
+
def _validate_non_local_source_contains_relative_paths(source: str):
    """
    Validation check to ensure that sources that are provided that conform to the schemes:
    http, https, or mlflow-artifacts do not contain relative path designations that are intended
    to access local file system paths on the tracking server.

    Example paths that this validation function is intended to find and raise an Exception if
    passed:
    "mlflow-artifacts://host:port/../../../../"
    "http://host:port/api/2.0/mlflow-artifacts/artifacts/../../../../"
    "https://host:port/api/2.0/mlflow-artifacts/artifacts/../../../../"
    "/models/artifacts/../../../"
    "s3:/my_bucket/models/path/../../other/path"
    "file://path/to/../../../../some/where/you/should/not/be"
    "mlflow-artifacts://host:port/..%2f..%2f..%2f..%2f"
    "http://host:port/api/2.0/mlflow-artifacts/artifacts%00"
    """
    invalid_source_error_message = (
        f"Invalid model version source: '{source}'. If supplying a source as an http, https, "
        "local file path, ftp, objectstore, or mlflow-artifacts uri, an absolute path must be "
        "provided without relative path references present. "
        "Please provide an absolute path."
    )

    while (unquoted := urllib.parse.unquote_plus(source)) != source:
        source = unquoted
    source_path = re.sub(r"/+", "/", urllib.parse.urlparse(source).path.rstrip("/"))
    if "\x00" in source_path or any(p == ".." for p in source.split("/")):
        raise MlflowException(invalid_source_error_message, INVALID_PARAMETER_VALUE)
    resolved_source = pathlib.Path(source_path).resolve().as_posix()
    # NB: drive split is specifically for Windows since WindowsPath.resolve() will append the
    # drive path of the pwd to a given path. We don't care about the drive here, though.
    _, resolved_path = os.path.splitdrive(resolved_source)

    if resolved_path != source_path:
        raise MlflowException(invalid_source_error_message, INVALID_PARAMETER_VALUE)


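# A minimal sketch that directly exercises the check above (the accepted and rejected
# URIs are taken from the docstring examples; nothing else is assumed):
def _example_relative_path_validation():
    _validate_non_local_source_contains_relative_paths(
        "https://host:port/api/2.0/mlflow-artifacts/artifacts/models"
    )  # absolute path, no traversal: passes silently
    try:
        _validate_non_local_source_contains_relative_paths(
            "mlflow-artifacts://host:port/..%2f..%2f..%2f..%2f"
        )
    except MlflowException:
        pass  # percent-encoded ".." segments are unquoted in the loop above and rejected

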
def _validate_source_run(source: str, run_id: str) -> None:
    if is_local_uri(source):
        if run_id:
            store = _get_tracking_store()
            run = store.get_run(run_id)
            source = pathlib.Path(local_file_uri_to_path(source)).resolve()
            if is_local_uri(run.info.artifact_uri):
                run_artifact_dir = pathlib.Path(
                    local_file_uri_to_path(run.info.artifact_uri)
                ).resolve()
                if run_artifact_dir in [source, *source.parents]:
                    return

        raise MlflowException(
            f"Invalid model version source: '{source}'. To use a local path as a model version "
            "source, the run_id request parameter has to be specified and the local path has to be "
            "contained within the artifact directory of the run specified by the run_id.",
            INVALID_PARAMETER_VALUE,
        )

    # Checks if relative paths are present in the source (a security threat). If any are present,
    # raises an Exception.
    _validate_non_local_source_contains_relative_paths(source)


def _validate_source_model(source: str, model_id: str) -> None:
    if is_local_uri(source):
        if model_id:
            store = _get_tracking_store()
            model = store.get_logged_model(model_id)
            source = pathlib.Path(local_file_uri_to_path(source)).resolve()
            if is_local_uri(model.artifact_location):
                run_artifact_dir = pathlib.Path(
                    local_file_uri_to_path(model.artifact_location)
                ).resolve()
                if run_artifact_dir in [source, *source.parents]:
                    return

        raise MlflowException(
            f"Invalid model version source: '{source}'. To use a local path as a model version "
            "source, the model_id request parameter has to be specified and the local path has to "
            "be contained within the artifact directory of the run specified by the model_id.",
            INVALID_PARAMETER_VALUE,
        )

    # Checks if relative paths are present in the source (a security threat). If any are present,
    # raises an Exception.
    _validate_non_local_source_contains_relative_paths(source)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_model_version():
    request_message = _get_request_message(
        CreateModelVersion(),
        schema={
            "name": [_assert_string, _assert_required],
            "source": [_assert_string, _assert_required],
            "run_id": [_assert_string],
            "tags": [_assert_array],
            "run_link": [_assert_string],
            "description": [_assert_string],
            "model_id": [_assert_string],
        },
    )

    if request_message.source and (
        regex := MLFLOW_CREATE_MODEL_VERSION_SOURCE_VALIDATION_REGEX.get()
    ):
        if not re.search(regex, request_message.source):
            raise MlflowException(
                f"Invalid model version source: '{request_message.source}'.",
                error_code=INVALID_PARAMETER_VALUE,
            )

    # If the model version is a prompt, we don't validate the source
    if not _is_prompt_request(request_message):
        if request_message.model_id:
            _validate_source_model(request_message.source, request_message.model_id)
        else:
            _validate_source_run(request_message.source, request_message.run_id)

    model_version = _get_model_registry_store().create_model_version(
        name=request_message.name,
        source=request_message.source,
        run_id=request_message.run_id,
        run_link=request_message.run_link,
        tags=request_message.tags,
        description=request_message.description,
        model_id=request_message.model_id,
    )
    if not _is_prompt_request(request_message) and request_message.model_id:
        tracking_store = _get_tracking_store()
        tracking_store.set_model_versions_tags(
            name=request_message.name,
            version=model_version.version,
            model_id=request_message.model_id,
        )
    response_message = CreateModelVersion.Response(model_version=model_version.to_proto())
    return _wrap_response(response_message)


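# A minimal client-side sketch of the corresponding REST call (assumptions: a tracking
# server at http://localhost:5000, and a run "abc123" whose artifact root contains the
# "model" directory; both values are illustrative):
def _example_create_model_version_request():
    import requests

    requests.post(
        "http://localhost:5000/api/2.0/mlflow/model-versions/create",
        json={
            "name": "my-model",
            "source": "mlflow-artifacts:/0/abc123/artifacts/model",
            "run_id": "abc123",
        },
    )

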
def _is_prompt_request(request_message):
    return any(tag.key == IS_PROMPT_TAG_KEY for tag in request_message.tags)


@catch_mlflow_exception
@_disable_if_artifacts_only
def get_model_version_artifact_handler():
    name = request.args.get("name")
    version = request.args.get("version")
    path = request.args["path"]
    path = validate_path_is_safe(path)
    artifact_uri = _get_model_registry_store().get_model_version_download_uri(name, version)
    if _is_servable_proxied_run_artifact_root(artifact_uri):
        artifact_repo = _get_artifact_repo_mlflow_artifacts()
        artifact_path = _get_proxied_run_artifact_destination_path(
            proxied_artifact_root=artifact_uri,
            relative_path=path,
        )
    else:
        artifact_repo = get_artifact_repository(artifact_uri)
        artifact_path = path

    return _send_artifact(artifact_repo, artifact_path)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_model_version():
    request_message = _get_request_message(
        GetModelVersion(),
        schema={
            "name": [_assert_string, _assert_required],
            "version": [_assert_string, _assert_required],
        },
    )
    model_version = _get_model_registry_store().get_model_version(
        name=request_message.name, version=request_message.version
    )
    response_proto = model_version.to_proto()
    response_message = GetModelVersion.Response(model_version=response_proto)
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _update_model_version():
    request_message = _get_request_message(
        UpdateModelVersion(),
        schema={
            "name": [_assert_string, _assert_required],
            "version": [_assert_string, _assert_required],
            "description": [_assert_string],
        },
    )
    new_description = None
    if request_message.HasField("description"):
        new_description = request_message.description
    model_version = _get_model_registry_store().update_model_version(
        name=request_message.name,
        version=request_message.version,
        description=new_description,
    )
    return _wrap_response(UpdateModelVersion.Response(model_version=model_version.to_proto()))


@catch_mlflow_exception
@_disable_if_artifacts_only
def _transition_stage():
    request_message = _get_request_message(
        TransitionModelVersionStage(),
        schema={
            "name": [_assert_string, _assert_required],
            "version": [_assert_string, _assert_required],
            "stage": [_assert_string, _assert_required],
            "archive_existing_versions": [_assert_bool],
        },
    )
    model_version = _get_model_registry_store().transition_model_version_stage(
        name=request_message.name,
        version=request_message.version,
        stage=request_message.stage,
        archive_existing_versions=request_message.archive_existing_versions,
    )
    return _wrap_response(
        TransitionModelVersionStage.Response(model_version=model_version.to_proto())
    )


@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_model_version():
    request_message = _get_request_message(
        DeleteModelVersion(),
        schema={
            "name": [_assert_string, _assert_required],
            "version": [_assert_string, _assert_required],
        },
    )
    _get_model_registry_store().delete_model_version(
        name=request_message.name, version=request_message.version
    )
    return _wrap_response(DeleteModelVersion.Response())


@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_model_version_download_uri():
    request_message = _get_request_message(GetModelVersionDownloadUri())
    download_uri = _get_model_registry_store().get_model_version_download_uri(
        name=request_message.name, version=request_message.version
    )
    response_message = GetModelVersionDownloadUri.Response(artifact_uri=download_uri)
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _search_model_versions():
    request_message = _get_request_message(
        SearchModelVersions(),
        schema={
            "filter": [_assert_string],
            "max_results": [
                _assert_intlike,
                lambda x: _assert_less_than_or_equal(int(x), 200_000),
            ],
            "order_by": [_assert_array, _assert_item_type_string],
            "page_token": [_assert_string],
        },
    )
    response_message = search_model_versions_impl(request_message)
    return _wrap_response(response_message)


def search_model_versions_impl(request_message):
    store = _get_model_registry_store()
    model_versions = store.search_model_versions(
        filter_string=request_message.filter,
        max_results=request_message.max_results,
        order_by=request_message.order_by,
        page_token=request_message.page_token,
    )
    response_message = SearchModelVersions.Response()
    response_message.model_versions.extend([e.to_proto() for e in model_versions])
    if model_versions.token:
        response_message.next_page_token = model_versions.token
    return response_message


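# A minimal client-side sketch of paging through model versions (assumptions: a tracking
# server at http://localhost:5000; the filter string is illustrative; `filter` and
# `max_results` mirror the schema validated above):
def _example_search_model_versions_request():
    import requests

    resp = requests.get(
        "http://localhost:5000/api/2.0/mlflow/model-versions/search",
        params={"filter": "name = 'my-model'", "max_results": 100},
    )
    next_page_token = resp.json().get("next_page_token")
    # Pass `next_page_token` back as the `page_token` parameter to fetch the next page.
    return next_page_token

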
@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_model_version_tag():
    request_message = _get_request_message(
        SetModelVersionTag(),
        schema={
            "name": [_assert_string, _assert_required],
            "version": [_assert_string, _assert_required],
            "key": [_assert_string, _assert_required],
            "value": [_assert_string],
        },
    )
    tag = ModelVersionTag(key=request_message.key, value=request_message.value)
    _get_model_registry_store().set_model_version_tag(
        name=request_message.name, version=request_message.version, tag=tag
    )
    return _wrap_response(SetModelVersionTag.Response())


@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_model_version_tag():
    request_message = _get_request_message(
        DeleteModelVersionTag(),
        schema={
            "name": [_assert_string, _assert_required],
            "version": [_assert_string, _assert_required],
            "key": [_assert_string, _assert_required],
        },
    )
    _get_model_registry_store().delete_model_version_tag(
        name=request_message.name,
        version=request_message.version,
        key=request_message.key,
    )
    return _wrap_response(DeleteModelVersionTag.Response())


@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_registered_model_alias():
    request_message = _get_request_message(
        SetRegisteredModelAlias(),
        schema={
            "name": [_assert_string, _assert_required],
            "alias": [_assert_string, _assert_required],
            "version": [_assert_string, _assert_required],
        },
    )
    _get_model_registry_store().set_registered_model_alias(
        name=request_message.name,
        alias=request_message.alias,
        version=request_message.version,
    )
    return _wrap_response(SetRegisteredModelAlias.Response())


@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_registered_model_alias():
    request_message = _get_request_message(
        DeleteRegisteredModelAlias(),
        schema={
            "name": [_assert_string, _assert_required],
            "alias": [_assert_string, _assert_required],
        },
    )
    _get_model_registry_store().delete_registered_model_alias(
        name=request_message.name, alias=request_message.alias
    )
    return _wrap_response(DeleteRegisteredModelAlias.Response())


@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_model_version_by_alias():
    request_message = _get_request_message(
        GetModelVersionByAlias(),
        schema={
            "name": [_assert_string, _assert_required],
            "alias": [_assert_string, _assert_required],
        },
    )
    model_version = _get_model_registry_store().get_model_version_by_alias(
        name=request_message.name, alias=request_message.alias
    )
    response_proto = model_version.to_proto()
    response_message = GetModelVersionByAlias.Response(model_version=response_proto)
    return _wrap_response(response_message)


# MLflow Artifacts APIs


@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _download_artifact(artifact_path):
    """
    A request handler for `GET /mlflow-artifacts/artifacts/<artifact_path>` to download an artifact
    from `artifact_path` (a relative path from the root artifact directory).
    """
    artifact_path = validate_path_is_safe(artifact_path)
    tmp_dir = tempfile.TemporaryDirectory()
    artifact_repo = _get_artifact_repo_mlflow_artifacts()
    dst = artifact_repo.download_artifacts(artifact_path, tmp_dir.name)

    # Ref: https://stackoverflow.com/a/24613980/6943581
    file_handle = open(dst, "rb")  # noqa: SIM115

    def stream_and_remove_file():
        yield from file_handle
        file_handle.close()
        tmp_dir.cleanup()

    file_sender_response = current_app.response_class(stream_and_remove_file())

    return _response_with_file_attachment_headers(artifact_path, file_sender_response)


@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _upload_artifact(artifact_path):
    """
    A request handler for `PUT /mlflow-artifacts/artifacts/<artifact_path>` to upload an artifact
    to `artifact_path` (a relative path from the root artifact directory).
    """
    artifact_path = validate_path_is_safe(artifact_path)
    head, tail = posixpath.split(artifact_path)
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, tail)
        with open(tmp_path, "wb") as f:
            chunk_size = 1024 * 1024  # 1 MB
            while True:
                chunk = request.stream.read(chunk_size)
                if len(chunk) == 0:
                    break
                f.write(chunk)

        artifact_repo = _get_artifact_repo_mlflow_artifacts()
        artifact_repo.log_artifact(tmp_path, artifact_path=head or None)

    return _wrap_response(UploadArtifact.Response())


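# A minimal client-side sketch of the two proxied-artifact endpoints above (assumptions:
# the server was started with artifact serving enabled and is reachable at
# http://localhost:5000; the artifact path and file name are illustrative):
def _example_proxied_artifact_roundtrip():
    import requests

    base = "http://localhost:5000/api/2.0/mlflow-artifacts/artifacts"
    with open("model.pkl", "rb") as f:
        requests.put(f"{base}/0/abc123/artifacts/model.pkl", data=f)  # streamed upload
    content = requests.get(f"{base}/0/abc123/artifacts/model.pkl").content  # streamed download
    return content

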
@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _list_artifacts_mlflow_artifacts():
    """
    A request handler for `GET /mlflow-artifacts/artifacts?path=<value>` to list artifacts in `path`
    (a relative path from the root artifact directory).
    """
    request_message = _get_request_message(ListArtifactsMlflowArtifacts())
    path = validate_path_is_safe(request_message.path) if request_message.HasField("path") else None
    artifact_repo = _get_artifact_repo_mlflow_artifacts()
    files = []
    for file_info in artifact_repo.list_artifacts(path):
        basename = posixpath.basename(file_info.path)
        new_file_info = FileInfo(basename, file_info.is_dir, file_info.file_size)
        files.append(new_file_info.to_proto())
    response_message = ListArtifacts.Response()
    response_message.files.extend(files)
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(response_message))
    return response


@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _delete_artifact_mlflow_artifacts(artifact_path):
    """
    A request handler for `DELETE /mlflow-artifacts/artifacts?path=<value>` to delete artifacts in
    `path` (a relative path from the root artifact directory).
    """
    artifact_path = validate_path_is_safe(artifact_path)
    _get_request_message(DeleteArtifact())
    artifact_repo = _get_artifact_repo_mlflow_artifacts()
    artifact_repo.delete_artifacts(artifact_path)
    response_message = DeleteArtifact.Response()
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(response_message))
    return response


@catch_mlflow_exception
def _graphql():
    from graphql import parse

    from mlflow.server.graphql.graphql_no_batching import check_query_safety
    from mlflow.server.graphql.graphql_schema_extensions import schema

    # Extracting the query, variables, and operationName from the request
    request_json = _get_request_json()
    query = request_json.get("query")
    variables = request_json.get("variables")
    operation_name = request_json.get("operationName")

    node = parse(query)
    if check_result := check_query_safety(node):
        result = check_result
    else:
        # Executing the GraphQL query using the Graphene schema
        result = schema.execute(query, variables=variables, operation_name=operation_name)

    # Convert execution result into json.
    result_data = {
        "data": result.data,
        "errors": [error.message for error in result.errors] if result.errors else None,
    }

    # Return the response
    return jsonify(result_data)


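# A minimal client-side sketch of the GraphQL endpoint above (assumptions: a tracking
# server at http://localhost:5000; the query shown is a trivial one any schema accepts,
# real queries depend on the schema defined in graphql_schema_extensions):
def _example_graphql_request():
    import requests

    resp = requests.post(
        "http://localhost:5000/graphql",
        json={"query": "{ __typename }", "variables": None, "operationName": None},
    )
    return resp.json()["data"]

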
def _validate_support_multipart_upload(artifact_repo):
    if not isinstance(artifact_repo, MultipartUploadMixin):
        raise _UnsupportedMultipartUploadException()


@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _create_multipart_upload_artifact(artifact_path):
    """
    A request handler for `POST /mlflow-artifacts/mpu/create` to create a multipart upload
    to `artifact_path` (a relative path from the root artifact directory).
    """
    artifact_path = validate_path_is_safe(artifact_path)

    request_message = _get_request_message(
        CreateMultipartUpload(),
        schema={
            "path": [_assert_required, _assert_string],
            "num_parts": [_assert_intlike],
        },
    )
    path = request_message.path
    num_parts = request_message.num_parts

    artifact_repo = _get_artifact_repo_mlflow_artifacts()
    _validate_support_multipart_upload(artifact_repo)

    create_response = artifact_repo.create_multipart_upload(
        path,
        num_parts,
        artifact_path,
    )
    response_message = create_response.to_proto()
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(response_message))
    return response


@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _complete_multipart_upload_artifact(artifact_path):
    """
    A request handler for `POST /mlflow-artifacts/mpu/complete` to complete a multipart upload
    to `artifact_path` (a relative path from the root artifact directory).
    """
    artifact_path = validate_path_is_safe(artifact_path)

    request_message = _get_request_message(
        CompleteMultipartUpload(),
        schema={
            "path": [_assert_required, _assert_string],
            "upload_id": [_assert_string],
            "parts": [_assert_required],
        },
    )
    path = request_message.path
    upload_id = request_message.upload_id
    parts = [MultipartUploadPart.from_proto(part) for part in request_message.parts]

    artifact_repo = _get_artifact_repo_mlflow_artifacts()
    _validate_support_multipart_upload(artifact_repo)

    artifact_repo.complete_multipart_upload(
        path,
        upload_id,
        parts,
        artifact_path,
    )
    return _wrap_response(CompleteMultipartUpload.Response())


@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _abort_multipart_upload_artifact(artifact_path):
    """
    A request handler for `POST /mlflow-artifacts/mpu/abort` to abort a multipart upload
    to `artifact_path` (a relative path from the root artifact directory).
    """
    artifact_path = validate_path_is_safe(artifact_path)

    request_message = _get_request_message(
        AbortMultipartUpload(),
        schema={
            "path": [_assert_required, _assert_string],
            "upload_id": [_assert_string],
        },
    )
    path = request_message.path
    upload_id = request_message.upload_id

    artifact_repo = _get_artifact_repo_mlflow_artifacts()
    _validate_support_multipart_upload(artifact_repo)

    artifact_repo.abort_multipart_upload(
        path,
        upload_id,
        artifact_path,
    )
    return _wrap_response(AbortMultipartUpload.Response())


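# The three multipart-upload handlers above follow the usual create -> upload parts ->
# complete/abort lifecycle. A minimal sketch of that flow (assumptions: `artifact_repo`
# is any repository implementing MultipartUploadMixin, and the `upload_id` attribute on
# the create response is an assumption inferred from how the complete/abort handlers
# consume it; part uploads against the presigned credentials are elided):
def _example_multipart_upload_flow(artifact_repo, local_file, artifact_path):
    create = artifact_repo.create_multipart_upload(local_file, 2, artifact_path)
    parts = []  # MultipartUploadPart entries collected after uploading each part
    try:
        artifact_repo.complete_multipart_upload(local_file, create.upload_id, parts, artifact_path)
    except Exception:
        artifact_repo.abort_multipart_upload(local_file, create.upload_id, artifact_path)

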
# MLflow Tracing APIs


@catch_mlflow_exception
@_disable_if_artifacts_only
def _start_trace_v3():
    """
    A request handler for `POST /mlflow/traces` to create a new TraceInfo record in tracking store.
    """
    request_message = _get_request_message(
        StartTraceV3(),
        schema={"trace": [_assert_required]},
    )
    trace_info = TraceInfo.from_proto(request_message.trace.trace_info)
    trace_info = _get_tracking_store().start_trace(trace_info)
    response_message = StartTraceV3.Response(trace=ProtoTrace(trace_info=trace_info.to_proto()))
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_trace_info_v3(trace_id):
    """
    A request handler for `GET /mlflow/traces/{trace_id}/info` to retrieve
    an existing TraceInfo record from tracking store.
    """
    trace_info = _get_tracking_store().get_trace_info(trace_id)
    response_message = GetTraceInfoV3.Response(trace=ProtoTrace(trace_info=trace_info.to_proto()))
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _search_traces_v3():
    """
    A request handler for `GET /mlflow/traces` to search for TraceInfo records in tracking store.
    """
    request_message = _get_request_message(
        SearchTracesV3(),
        schema={
            "locations": [_assert_array, _assert_required],
            "filter": [_assert_string],
            "max_results": [
                _assert_intlike,
                lambda x: _assert_less_than_or_equal(int(x), 500),
            ],
            "order_by": [_assert_array, _assert_item_type_string],
            "page_token": [_assert_string],
        },
    )
    experiment_ids = []
    for location in request_message.locations:
        if location.HasField("mlflow_experiment"):
            experiment_ids.append(location.mlflow_experiment.experiment_id)

    traces, token = _get_tracking_store().search_traces(
        experiment_ids=experiment_ids,
        filter_string=request_message.filter,
        max_results=request_message.max_results,
        order_by=request_message.order_by,
        page_token=request_message.page_token,
    )
    response_message = SearchTracesV3.Response()
    response_message.traces.extend([e.to_proto() for e in traces])
    if token:
        response_message.next_page_token = token
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_traces():
    """
    A request handler for `POST /mlflow/traces/delete-traces` to delete TraceInfo records
    from tracking store.
    """
    request_message = _get_request_message(
        DeleteTraces(),
        schema={
            "experiment_id": [_assert_string, _assert_required],
            "max_timestamp_millis": [_assert_intlike],
            "max_traces": [_assert_intlike],
            "request_ids": [_assert_array, _assert_item_type_string],
        },
    )

    # NB: Interestingly, the field accessor for the message object returns the default
    # value for optional field if it's not set. For example, `request_message.max_traces`
    # returns 0 if max_traces is not specified in the request. This is not desirable,
    # because null and 0 means completely opposite i.e. the former is 'delete nothing'
    # while the latter is 'delete all'. To handle this, we need to explicitly check
    # if the field is set or not using `HasField` method and return None if not.
    def _get_nullable_field(field):
        if request_message.HasField(field):
            return getattr(request_message, field)
        return None

    traces_deleted = _get_tracking_store().delete_traces(
        experiment_id=request_message.experiment_id,
        max_timestamp_millis=_get_nullable_field("max_timestamp_millis"),
        max_traces=_get_nullable_field("max_traces"),
        trace_ids=request_message.request_ids,
    )
    return _wrap_response(DeleteTraces.Response(traces_deleted=traces_deleted))


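# A minimal client-side sketch of the trace-deletion endpoint above (assumptions: a
# tracking server at http://localhost:5000 and experiment "0"; max_traces and
# request_ids are optional and omitted here):
def _example_delete_old_traces_request():
    import time

    import requests

    requests.post(
        "http://localhost:5000/api/2.0/mlflow/traces/delete-traces",
        json={
            "experiment_id": "0",
            # delete traces older than 7 days
            "max_timestamp_millis": int(time.time() * 1000) - 7 * 24 * 60 * 60 * 1000,
        },
    )

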
@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_trace_tag(request_id):
    """
    A request handler for `PATCH /mlflow/traces/{request_id}/tags` to set tags on a TraceInfo record
    """
    request_message = _get_request_message(
        SetTraceTag(),
        schema={
            "key": [_assert_string, _assert_required],
            "value": [_assert_string],
        },
    )
    _get_tracking_store().set_trace_tag(request_id, request_message.key, request_message.value)
    return _wrap_response(SetTraceTag.Response())


@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_trace_tag(request_id):
    """
    A request handler for `DELETE /mlflow/traces/{request_id}/tags` to delete tags from a TraceInfo
    record.
    """
    request_message = _get_request_message(
        DeleteTraceTag(),
        schema={
            "key": [_assert_string, _assert_required],
        },
    )
    _get_tracking_store().delete_trace_tag(request_id, request_message.key)
    return _wrap_response(DeleteTraceTag.Response())


@catch_mlflow_exception
@_disable_if_artifacts_only
def get_trace_artifact_handler():
    request_id = request.args.get("request_id")

    if not request_id:
        raise MlflowException(
            'Request must include the "request_id" query parameter.',
            error_code=BAD_REQUEST,
        )

    trace_info = _get_tracking_store().get_trace_info(request_id)
    trace_data = _get_trace_artifact_repo(trace_info).download_trace_data()

    # Write data to a BytesIO buffer instead of needing to save a temp file
    buf = io.BytesIO()
    buf.write(json.dumps(trace_data).encode())
    buf.seek(0)

    file_sender_response = send_file(
        buf,
        mimetype="application/octet-stream",
        as_attachment=True,
        download_name=TRACE_DATA_FILE_NAME,
    )
    return _response_with_file_attachment_headers(TRACE_DATA_FILE_NAME, file_sender_response)


# Deprecated MLflow Tracing APIs. Kept for backward compatibility but do not use.


@catch_mlflow_exception
@_disable_if_artifacts_only
def _deprecated_start_trace_v2():
    """
    A request handler for `POST /mlflow/traces` to create a new TraceInfo record in tracking store.
    """
    request_message = _get_request_message(
        StartTrace(),
        schema={
            "experiment_id": [_assert_string],
            "timestamp_ms": [_assert_intlike],
            "request_metadata": [_assert_map_key_present],
            "tags": [_assert_map_key_present],
        },
    )
    request_metadata = {e.key: e.value for e in request_message.request_metadata}
    tags = {e.key: e.value for e in request_message.tags}

    trace_info = _get_tracking_store().deprecated_start_trace_v2(
        experiment_id=request_message.experiment_id,
        timestamp_ms=request_message.timestamp_ms,
        request_metadata=request_metadata,
        tags=tags,
    )
    response_message = StartTrace.Response(trace_info=trace_info.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _deprecated_end_trace_v2(request_id):
    """
    A request handler for `PATCH /mlflow/traces/{request_id}` to mark an existing TraceInfo
    record completed in tracking store.
    """
    request_message = _get_request_message(
        EndTrace(),
        schema={
            "timestamp_ms": [_assert_intlike],
            "status": [_assert_string],
            "request_metadata": [_assert_map_key_present],
            "tags": [_assert_map_key_present],
        },
    )
    request_metadata = {e.key: e.value for e in request_message.request_metadata}
    tags = {e.key: e.value for e in request_message.tags}

    trace_info = _get_tracking_store().deprecated_end_trace_v2(
        request_id=request_id,
        timestamp_ms=request_message.timestamp_ms,
        status=TraceStatus.from_proto(request_message.status),
        request_metadata=request_metadata,
        tags=tags,
    )

    if isinstance(trace_info, TraceInfo):
        trace_info = TraceInfoV2.from_v3(trace_info)

    response_message = EndTrace.Response(trace_info=trace_info.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _deprecated_get_trace_info_v2(request_id):
    """
    A request handler for `GET /mlflow/traces/{request_id}/info` to retrieve
    an existing TraceInfo record from tracking store.
    """
    trace_info = _get_tracking_store().get_trace_info(request_id)
    trace_info = TraceInfoV2.from_v3(trace_info)
    response_message = GetTraceInfo.Response(trace_info=trace_info.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _deprecated_search_traces_v2():
    """
    A request handler for `GET /mlflow/traces` to search for TraceInfo records in tracking store.
    """
    request_message = _get_request_message(
        SearchTraces(),
        schema={
            "experiment_ids": [
                _assert_array,
                _assert_item_type_string,
                _assert_required,
            ],
            "filter": [_assert_string],
            "max_results": [
                _assert_intlike,
                lambda x: _assert_less_than_or_equal(int(x), 500),
            ],
            "order_by": [_assert_array, _assert_item_type_string],
            "page_token": [_assert_string],
        },
    )
    traces, token = _get_tracking_store().search_traces(
        experiment_ids=request_message.experiment_ids,
        filter_string=request_message.filter,
        max_results=request_message.max_results,
        order_by=request_message.order_by,
        page_token=request_message.page_token,
    )
    traces = [TraceInfoV2.from_v3(t) for t in traces]
    response_message = SearchTraces.Response()
    response_message.traces.extend([e.to_proto() for e in traces])
    if token:
        response_message.next_page_token = token
    return _wrap_response(response_message)


# Logged Models APIs


@catch_mlflow_exception
@_disable_if_artifacts_only
def get_logged_model_artifact_handler(model_id: str):
    artifact_file_path = request.args.get("artifact_file_path")
    if not artifact_file_path:
        raise MlflowException(
            'Request must include the "artifact_file_path" query parameter.',
            error_code=BAD_REQUEST,
        )
    validate_path_is_safe(artifact_file_path)

    logged_model: LoggedModel = _get_tracking_store().get_logged_model(model_id)
    if _is_servable_proxied_run_artifact_root(logged_model.artifact_location):
        artifact_repo = _get_artifact_repo_mlflow_artifacts()
        artifact_path = _get_proxied_run_artifact_destination_path(
            proxied_artifact_root=logged_model.artifact_location,
            relative_path=artifact_file_path,
        )
    else:
        artifact_repo = get_artifact_repository(logged_model.artifact_location)
        artifact_path = artifact_file_path

    return _send_artifact(artifact_repo, artifact_path)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_logged_model():
    request_message = _get_request_message(
        CreateLoggedModel(),
        schema={
            "experiment_id": [_assert_string, _assert_required],
            "name": [_assert_string],
            "model_type": [_assert_string],
            "source_run_id": [_assert_string],
            "params": [_assert_array],
            "tags": [_assert_array],
        },
    )

    model = _get_tracking_store().create_logged_model(
        experiment_id=request_message.experiment_id,
        name=request_message.name or None,
        model_type=request_message.model_type,
        source_run_id=request_message.source_run_id,
        params=(
            [LoggedModelParameter.from_proto(param) for param in request_message.params]
            if request_message.params
            else None
        ),
        tags=(
            [LoggedModelTag(key=tag.key, value=tag.value) for tag in request_message.tags]
            if request_message.tags
            else None
        ),
    )
    response_message = CreateLoggedModel.Response(model=model.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _log_logged_model_params(model_id: str):
    request_message = _get_request_message(
        LogLoggedModelParamsRequest(),
        schema={
            "model_id": [_assert_string, _assert_required],
            "params": [_assert_array],
        },
    )
    params = (
        [LoggedModelParameter.from_proto(param) for param in request_message.params]
        if request_message.params
        else []
    )
    _get_tracking_store().log_logged_model_params(model_id, params)
    return _wrap_response(LogLoggedModelParamsRequest.Response())


@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_logged_model(model_id: str):
    model = _get_tracking_store().get_logged_model(model_id)
    response_message = GetLoggedModel.Response(model=model.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _finalize_logged_model(model_id: str):
    request_message = _get_request_message(
        FinalizeLoggedModel(),
        schema={
            "model_id": [_assert_string, _assert_required],
            "status": [_assert_intlike, _assert_required],
        },
    )
    model = _get_tracking_store().finalize_logged_model(
        request_message.model_id, LoggedModelStatus.from_int(request_message.status)
    )
    response_message = FinalizeLoggedModel.Response(model=model.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_logged_model(model_id: str):
    _get_tracking_store().delete_logged_model(model_id)
    return _wrap_response(DeleteLoggedModel.Response())


@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_logged_model_tags(model_id: str):
    request_message = _get_request_message(
        SetLoggedModelTags(),
        schema={"tags": [_assert_array]},
    )
    tags = [LoggedModelTag(key=tag.key, value=tag.value) for tag in request_message.tags]
    _get_tracking_store().set_logged_model_tags(model_id, tags)
    return _wrap_response(SetLoggedModelTags.Response())


@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_logged_model_tag(model_id: str, tag_key: str):
    _get_tracking_store().delete_logged_model_tag(model_id, tag_key)
    return _wrap_response(DeleteLoggedModelTag.Response())


@catch_mlflow_exception
@_disable_if_artifacts_only
def _search_logged_models():
    request_message = _get_request_message(
        SearchLoggedModels(),
        schema={
            "experiment_ids": [
                _assert_array,
                _assert_item_type_string,
                _assert_required,
            ],
            "filter": [_assert_string],
            "datasets": [_assert_array],
            "max_results": [_assert_intlike],
            "order_by": [_assert_array],
            "page_token": [_assert_string],
        },
    )
    models = _get_tracking_store().search_logged_models(
        # Convert `RepeatedScalarContainer` objects (experiment_ids and order_by) to `list`
        # to avoid serialization issues
        experiment_ids=list(request_message.experiment_ids),
        filter_string=request_message.filter or None,
        datasets=(
            [
                {
                    "dataset_name": d.dataset_name,
                    "dataset_digest": d.dataset_digest or None,
                }
                for d in request_message.datasets
            ]
            if request_message.datasets
            else None
        ),
        max_results=request_message.max_results or None,
        order_by=(
            [
                {
                    "field_name": ob.field_name,
                    "ascending": ob.ascending,
                    "dataset_name": ob.dataset_name or None,
                    "dataset_digest": ob.dataset_digest or None,
                }
                for ob in request_message.order_by
            ]
            if request_message.order_by
            else None
        ),
        page_token=request_message.page_token or None,
    )
    response_message = SearchLoggedModels.Response()
    response_message.models.extend([e.to_proto() for e in models])
    if models.token:
        response_message.next_page_token = models.token
    return _wrap_response(response_message)


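# A request body matching the schema accepted above (illustrative sketch; the concrete
# values, including the filter string and the order-by field name, are assumptions):
_EXAMPLE_SEARCH_LOGGED_MODELS_REQUEST = {
    "experiment_ids": ["0"],
    "filter": "metrics.accuracy > 0.9",
    "datasets": [{"dataset_name": "eval", "dataset_digest": "abc123"}],
    "max_results": 50,
    "order_by": [{"field_name": "creation_time", "ascending": False}],
}

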
@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_logged_model_artifacts(model_id: str):
    request_message = _get_request_message(
        ListLoggedModelArtifacts(),
        schema={"artifact_directory_path": [_assert_string]},
    )
    if request_message.HasField("artifact_directory_path"):
        artifact_path = validate_path_is_safe(request_message.artifact_directory_path)
    else:
        artifact_path = None

    return _list_logged_model_artifacts_impl(model_id, artifact_path)


def _list_logged_model_artifacts_impl(
    model_id: str, artifact_directory_path: Optional[str]
) -> Response:
    response = ListLoggedModelArtifacts.Response()
    logged_model: LoggedModel = _get_tracking_store().get_logged_model(model_id)
    if _is_servable_proxied_run_artifact_root(logged_model.artifact_location):
        artifacts = _list_artifacts_for_proxied_run_artifact_root(
            proxied_artifact_root=logged_model.artifact_location,
            relative_path=artifact_directory_path,
        )
    else:
        artifacts = get_artifact_repository(logged_model.artifact_location).list_artifacts(
            artifact_directory_path
        )

    response.files.extend([a.to_proto() for a in artifacts])
    response.root_uri = logged_model.artifact_location
    return _wrap_response(response)


def _get_rest_path(base_path, version=2):
    return f"/api/{version}.0{base_path}"


def _get_ajax_path(base_path, version=2):
    return _add_static_prefix(f"/ajax-api/{version}.0{base_path}")


def _add_static_prefix(route: str) -> str:
    if prefix := os.environ.get(STATIC_PREFIX_ENV_VAR):
        return prefix.rstrip("/") + route
    return route


def _get_paths(base_path, version=2):
    """
    A service endpoints base path is typically something like /mlflow/experiment.
    We should register paths like /api/2.0/mlflow/experiment and
    /ajax-api/2.0/mlflow/experiment in the Flask router.
    """
    base_path = _convert_path_parameter_to_flask_format(base_path)
    return [_get_rest_path(base_path, version), _get_ajax_path(base_path, version)]


def _convert_path_parameter_to_flask_format(path):
    """
    Converts path parameter format to Flask compatible format.

    Some protobuf endpoint paths contain parameters like /mlflow/trace/{request_id}.
    This can be interpreted correctly by gRPC framework like Armeria, but Flask does
    not understand it. Instead, we need to specify it with a different format,
    like /mlflow/trace/<request_id>.
    """
    return re.sub(r"{(\w+)}", r"<\1>", path)


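# For example, the substitution above turns a protobuf-style path into a Flask rule:
#     _convert_path_parameter_to_flask_format("/mlflow/traces/{request_id}/info")
#     returns "/mlflow/traces/<request_id>/info"

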
def get_handler(request_class):
    """
    Args:
        request_class: The type of protobuf message
    """
    return HANDLERS.get(request_class, _not_implemented)


def get_service_endpoints(service, get_handler):
    ret = []
    for service_method in service.DESCRIPTOR.methods:
        endpoints = service_method.GetOptions().Extensions[databricks_pb2.rpc].endpoints
        for endpoint in endpoints:
            for http_path in _get_paths(endpoint.path, version=endpoint.since.major):
                handler = get_handler(service().GetRequestClass(service_method))
                ret.append((http_path, handler, [endpoint.method]))
    return ret


def get_endpoints(get_handler=get_handler):
    """
    Returns:
        List of tuples (path, handler, methods)
    """
    return (
        get_service_endpoints(MlflowService, get_handler)
        + get_service_endpoints(ModelRegistryService, get_handler)
        + get_service_endpoints(MlflowArtifactsService, get_handler)
        + [(_add_static_prefix("/graphql"), _graphql, ["GET", "POST"])]
    )


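# A minimal sketch of how the endpoint table produced by get_endpoints() can be wired
# into a Flask app (assumption: the real wiring lives in mlflow.server, not in this
# module; endpoint names are made unique here only to keep the sketch self-contained):
def _example_register_endpoints(flask_app):
    for i, (http_path, handler, methods) in enumerate(get_endpoints()):
        flask_app.add_url_rule(http_path, f"mlflow_endpoint_{i}", handler, methods=methods)

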
HANDLERS = {
    # Tracking Server APIs
    CreateExperiment: _create_experiment,
    GetExperiment: _get_experiment,
    GetExperimentByName: _get_experiment_by_name,
    DeleteExperiment: _delete_experiment,
    RestoreExperiment: _restore_experiment,
    UpdateExperiment: _update_experiment,
    CreateRun: _create_run,
    UpdateRun: _update_run,
    DeleteRun: _delete_run,
    RestoreRun: _restore_run,
    LogParam: _log_param,
    LogMetric: _log_metric,
    SetExperimentTag: _set_experiment_tag,
    SetTag: _set_tag,
    DeleteTag: _delete_tag,
    LogBatch: _log_batch,
    LogModel: _log_model,
    GetRun: _get_run,
    SearchRuns: _search_runs,
    ListArtifacts: _list_artifacts,
    GetMetricHistory: _get_metric_history,
    GetMetricHistoryBulkInterval: get_metric_history_bulk_interval_handler,
    SearchExperiments: _search_experiments,
    LogInputs: _log_inputs,
    LogOutputs: _log_outputs,
    # Model Registry APIs
    CreateRegisteredModel: _create_registered_model,
    GetRegisteredModel: _get_registered_model,
    DeleteRegisteredModel: _delete_registered_model,
    UpdateRegisteredModel: _update_registered_model,
    RenameRegisteredModel: _rename_registered_model,
    SearchRegisteredModels: _search_registered_models,
    GetLatestVersions: _get_latest_versions,
    CreateModelVersion: _create_model_version,
    GetModelVersion: _get_model_version,
    DeleteModelVersion: _delete_model_version,
    UpdateModelVersion: _update_model_version,
    TransitionModelVersionStage: _transition_stage,
    GetModelVersionDownloadUri: _get_model_version_download_uri,
    SearchModelVersions: _search_model_versions,
    SetRegisteredModelTag: _set_registered_model_tag,
    DeleteRegisteredModelTag: _delete_registered_model_tag,
    SetModelVersionTag: _set_model_version_tag,
    DeleteModelVersionTag: _delete_model_version_tag,
    SetRegisteredModelAlias: _set_registered_model_alias,
    DeleteRegisteredModelAlias: _delete_registered_model_alias,
    GetModelVersionByAlias: _get_model_version_by_alias,
    # MLflow Artifacts APIs
    DownloadArtifact: _download_artifact,
    UploadArtifact: _upload_artifact,
    ListArtifactsMlflowArtifacts: _list_artifacts_mlflow_artifacts,
    DeleteArtifact: _delete_artifact_mlflow_artifacts,
    CreateMultipartUpload: _create_multipart_upload_artifact,
    CompleteMultipartUpload: _complete_multipart_upload_artifact,
    AbortMultipartUpload: _abort_multipart_upload_artifact,
    # MLflow Tracing APIs (V3)
    StartTraceV3: _start_trace_v3,
    GetTraceInfoV3: _get_trace_info_v3,
    SearchTracesV3: _search_traces_v3,
    DeleteTraces: _delete_traces,
    SetTraceTag: _set_trace_tag,
    DeleteTraceTag: _delete_trace_tag,
    # Legacy MLflow Tracing V2 APIs. Kept for backward compatibility but do not use.
    StartTrace: _deprecated_start_trace_v2,
    EndTrace: _deprecated_end_trace_v2,
    GetTraceInfo: _deprecated_get_trace_info_v2,
    SearchTraces: _deprecated_search_traces_v2,
    # Logged Models APIs
    CreateLoggedModel: _create_logged_model,
    GetLoggedModel: _get_logged_model,
    FinalizeLoggedModel: _finalize_logged_model,
    DeleteLoggedModel: _delete_logged_model,
    SetLoggedModelTags: _set_logged_model_tags,
    DeleteLoggedModelTag: _delete_logged_model_tag,
    SearchLoggedModels: _search_logged_models,
    ListLoggedModelArtifacts: _list_logged_model_artifacts,
    LogLoggedModelParamsRequest: _log_logged_model_params,
}