genesis_flow-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genesis_flow-1.0.0.dist-info/METADATA +822 -0
- genesis_flow-1.0.0.dist-info/RECORD +645 -0
- genesis_flow-1.0.0.dist-info/WHEEL +5 -0
- genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
- genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
- genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
- mlflow/__init__.py +367 -0
- mlflow/__main__.py +3 -0
- mlflow/ag2/__init__.py +56 -0
- mlflow/ag2/ag2_logger.py +294 -0
- mlflow/anthropic/__init__.py +40 -0
- mlflow/anthropic/autolog.py +129 -0
- mlflow/anthropic/chat.py +144 -0
- mlflow/artifacts/__init__.py +268 -0
- mlflow/autogen/__init__.py +144 -0
- mlflow/autogen/chat.py +142 -0
- mlflow/azure/__init__.py +26 -0
- mlflow/azure/auth_handler.py +257 -0
- mlflow/azure/client.py +319 -0
- mlflow/azure/config.py +120 -0
- mlflow/azure/connection_factory.py +340 -0
- mlflow/azure/exceptions.py +27 -0
- mlflow/azure/stores.py +327 -0
- mlflow/azure/utils.py +183 -0
- mlflow/bedrock/__init__.py +45 -0
- mlflow/bedrock/_autolog.py +202 -0
- mlflow/bedrock/chat.py +122 -0
- mlflow/bedrock/stream.py +160 -0
- mlflow/bedrock/utils.py +43 -0
- mlflow/cli.py +707 -0
- mlflow/client.py +12 -0
- mlflow/config/__init__.py +56 -0
- mlflow/crewai/__init__.py +79 -0
- mlflow/crewai/autolog.py +253 -0
- mlflow/crewai/chat.py +29 -0
- mlflow/data/__init__.py +75 -0
- mlflow/data/artifact_dataset_sources.py +170 -0
- mlflow/data/code_dataset_source.py +40 -0
- mlflow/data/dataset.py +123 -0
- mlflow/data/dataset_registry.py +168 -0
- mlflow/data/dataset_source.py +110 -0
- mlflow/data/dataset_source_registry.py +219 -0
- mlflow/data/delta_dataset_source.py +167 -0
- mlflow/data/digest_utils.py +108 -0
- mlflow/data/evaluation_dataset.py +562 -0
- mlflow/data/filesystem_dataset_source.py +81 -0
- mlflow/data/http_dataset_source.py +145 -0
- mlflow/data/huggingface_dataset.py +258 -0
- mlflow/data/huggingface_dataset_source.py +118 -0
- mlflow/data/meta_dataset.py +104 -0
- mlflow/data/numpy_dataset.py +223 -0
- mlflow/data/pandas_dataset.py +231 -0
- mlflow/data/polars_dataset.py +352 -0
- mlflow/data/pyfunc_dataset_mixin.py +31 -0
- mlflow/data/schema.py +76 -0
- mlflow/data/sources.py +1 -0
- mlflow/data/spark_dataset.py +406 -0
- mlflow/data/spark_dataset_source.py +74 -0
- mlflow/data/spark_delta_utils.py +118 -0
- mlflow/data/tensorflow_dataset.py +350 -0
- mlflow/data/uc_volume_dataset_source.py +81 -0
- mlflow/db.py +27 -0
- mlflow/dspy/__init__.py +17 -0
- mlflow/dspy/autolog.py +197 -0
- mlflow/dspy/callback.py +398 -0
- mlflow/dspy/constant.py +1 -0
- mlflow/dspy/load.py +93 -0
- mlflow/dspy/save.py +393 -0
- mlflow/dspy/util.py +109 -0
- mlflow/dspy/wrapper.py +226 -0
- mlflow/entities/__init__.py +104 -0
- mlflow/entities/_mlflow_object.py +52 -0
- mlflow/entities/assessment.py +545 -0
- mlflow/entities/assessment_error.py +80 -0
- mlflow/entities/assessment_source.py +141 -0
- mlflow/entities/dataset.py +92 -0
- mlflow/entities/dataset_input.py +51 -0
- mlflow/entities/dataset_summary.py +62 -0
- mlflow/entities/document.py +48 -0
- mlflow/entities/experiment.py +109 -0
- mlflow/entities/experiment_tag.py +35 -0
- mlflow/entities/file_info.py +45 -0
- mlflow/entities/input_tag.py +35 -0
- mlflow/entities/lifecycle_stage.py +35 -0
- mlflow/entities/logged_model.py +228 -0
- mlflow/entities/logged_model_input.py +26 -0
- mlflow/entities/logged_model_output.py +32 -0
- mlflow/entities/logged_model_parameter.py +46 -0
- mlflow/entities/logged_model_status.py +74 -0
- mlflow/entities/logged_model_tag.py +33 -0
- mlflow/entities/metric.py +200 -0
- mlflow/entities/model_registry/__init__.py +29 -0
- mlflow/entities/model_registry/_model_registry_entity.py +13 -0
- mlflow/entities/model_registry/model_version.py +243 -0
- mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
- mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
- mlflow/entities/model_registry/model_version_search.py +25 -0
- mlflow/entities/model_registry/model_version_stages.py +25 -0
- mlflow/entities/model_registry/model_version_status.py +35 -0
- mlflow/entities/model_registry/model_version_tag.py +35 -0
- mlflow/entities/model_registry/prompt.py +73 -0
- mlflow/entities/model_registry/prompt_version.py +244 -0
- mlflow/entities/model_registry/registered_model.py +175 -0
- mlflow/entities/model_registry/registered_model_alias.py +35 -0
- mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
- mlflow/entities/model_registry/registered_model_search.py +25 -0
- mlflow/entities/model_registry/registered_model_tag.py +35 -0
- mlflow/entities/multipart_upload.py +74 -0
- mlflow/entities/param.py +49 -0
- mlflow/entities/run.py +97 -0
- mlflow/entities/run_data.py +84 -0
- mlflow/entities/run_info.py +188 -0
- mlflow/entities/run_inputs.py +59 -0
- mlflow/entities/run_outputs.py +43 -0
- mlflow/entities/run_status.py +41 -0
- mlflow/entities/run_tag.py +36 -0
- mlflow/entities/source_type.py +31 -0
- mlflow/entities/span.py +774 -0
- mlflow/entities/span_event.py +96 -0
- mlflow/entities/span_status.py +102 -0
- mlflow/entities/trace.py +317 -0
- mlflow/entities/trace_data.py +71 -0
- mlflow/entities/trace_info.py +220 -0
- mlflow/entities/trace_info_v2.py +162 -0
- mlflow/entities/trace_location.py +173 -0
- mlflow/entities/trace_state.py +39 -0
- mlflow/entities/trace_status.py +68 -0
- mlflow/entities/view_type.py +51 -0
- mlflow/environment_variables.py +866 -0
- mlflow/evaluation/__init__.py +16 -0
- mlflow/evaluation/assessment.py +369 -0
- mlflow/evaluation/evaluation.py +411 -0
- mlflow/evaluation/evaluation_tag.py +61 -0
- mlflow/evaluation/fluent.py +48 -0
- mlflow/evaluation/utils.py +201 -0
- mlflow/exceptions.py +213 -0
- mlflow/experiments.py +140 -0
- mlflow/gemini/__init__.py +81 -0
- mlflow/gemini/autolog.py +186 -0
- mlflow/gemini/chat.py +261 -0
- mlflow/genai/__init__.py +71 -0
- mlflow/genai/datasets/__init__.py +67 -0
- mlflow/genai/datasets/evaluation_dataset.py +131 -0
- mlflow/genai/evaluation/__init__.py +3 -0
- mlflow/genai/evaluation/base.py +411 -0
- mlflow/genai/evaluation/constant.py +23 -0
- mlflow/genai/evaluation/utils.py +244 -0
- mlflow/genai/judges/__init__.py +21 -0
- mlflow/genai/judges/databricks.py +404 -0
- mlflow/genai/label_schemas/__init__.py +153 -0
- mlflow/genai/label_schemas/label_schemas.py +209 -0
- mlflow/genai/labeling/__init__.py +159 -0
- mlflow/genai/labeling/labeling.py +250 -0
- mlflow/genai/optimize/__init__.py +13 -0
- mlflow/genai/optimize/base.py +198 -0
- mlflow/genai/optimize/optimizers/__init__.py +4 -0
- mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
- mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
- mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
- mlflow/genai/optimize/types.py +75 -0
- mlflow/genai/optimize/util.py +30 -0
- mlflow/genai/prompts/__init__.py +206 -0
- mlflow/genai/scheduled_scorers.py +431 -0
- mlflow/genai/scorers/__init__.py +26 -0
- mlflow/genai/scorers/base.py +492 -0
- mlflow/genai/scorers/builtin_scorers.py +765 -0
- mlflow/genai/scorers/scorer_utils.py +138 -0
- mlflow/genai/scorers/validation.py +165 -0
- mlflow/genai/utils/data_validation.py +146 -0
- mlflow/genai/utils/enum_utils.py +23 -0
- mlflow/genai/utils/trace_utils.py +211 -0
- mlflow/groq/__init__.py +42 -0
- mlflow/groq/_groq_autolog.py +74 -0
- mlflow/johnsnowlabs/__init__.py +888 -0
- mlflow/langchain/__init__.py +24 -0
- mlflow/langchain/api_request_parallel_processor.py +330 -0
- mlflow/langchain/autolog.py +147 -0
- mlflow/langchain/chat_agent_langgraph.py +340 -0
- mlflow/langchain/constant.py +1 -0
- mlflow/langchain/constants.py +1 -0
- mlflow/langchain/databricks_dependencies.py +444 -0
- mlflow/langchain/langchain_tracer.py +597 -0
- mlflow/langchain/model.py +919 -0
- mlflow/langchain/output_parsers.py +142 -0
- mlflow/langchain/retriever_chain.py +153 -0
- mlflow/langchain/runnables.py +527 -0
- mlflow/langchain/utils/chat.py +402 -0
- mlflow/langchain/utils/logging.py +671 -0
- mlflow/langchain/utils/serialization.py +36 -0
- mlflow/legacy_databricks_cli/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/provider.py +482 -0
- mlflow/litellm/__init__.py +175 -0
- mlflow/llama_index/__init__.py +22 -0
- mlflow/llama_index/autolog.py +55 -0
- mlflow/llama_index/chat.py +43 -0
- mlflow/llama_index/constant.py +1 -0
- mlflow/llama_index/model.py +577 -0
- mlflow/llama_index/pyfunc_wrapper.py +332 -0
- mlflow/llama_index/serialize_objects.py +188 -0
- mlflow/llama_index/tracer.py +561 -0
- mlflow/metrics/__init__.py +479 -0
- mlflow/metrics/base.py +39 -0
- mlflow/metrics/genai/__init__.py +25 -0
- mlflow/metrics/genai/base.py +101 -0
- mlflow/metrics/genai/genai_metric.py +771 -0
- mlflow/metrics/genai/metric_definitions.py +450 -0
- mlflow/metrics/genai/model_utils.py +371 -0
- mlflow/metrics/genai/prompt_template.py +68 -0
- mlflow/metrics/genai/prompts/__init__.py +0 -0
- mlflow/metrics/genai/prompts/v1.py +422 -0
- mlflow/metrics/genai/utils.py +6 -0
- mlflow/metrics/metric_definitions.py +619 -0
- mlflow/mismatch.py +34 -0
- mlflow/mistral/__init__.py +34 -0
- mlflow/mistral/autolog.py +71 -0
- mlflow/mistral/chat.py +135 -0
- mlflow/ml_package_versions.py +452 -0
- mlflow/models/__init__.py +97 -0
- mlflow/models/auth_policy.py +83 -0
- mlflow/models/cli.py +354 -0
- mlflow/models/container/__init__.py +294 -0
- mlflow/models/container/scoring_server/__init__.py +0 -0
- mlflow/models/container/scoring_server/nginx.conf +39 -0
- mlflow/models/dependencies_schemas.py +287 -0
- mlflow/models/display_utils.py +158 -0
- mlflow/models/docker_utils.py +211 -0
- mlflow/models/evaluation/__init__.py +23 -0
- mlflow/models/evaluation/_shap_patch.py +64 -0
- mlflow/models/evaluation/artifacts.py +194 -0
- mlflow/models/evaluation/base.py +1811 -0
- mlflow/models/evaluation/calibration_curve.py +109 -0
- mlflow/models/evaluation/default_evaluator.py +996 -0
- mlflow/models/evaluation/deprecated.py +23 -0
- mlflow/models/evaluation/evaluator_registry.py +80 -0
- mlflow/models/evaluation/evaluators/classifier.py +704 -0
- mlflow/models/evaluation/evaluators/default.py +233 -0
- mlflow/models/evaluation/evaluators/regressor.py +96 -0
- mlflow/models/evaluation/evaluators/shap.py +296 -0
- mlflow/models/evaluation/lift_curve.py +178 -0
- mlflow/models/evaluation/utils/metric.py +123 -0
- mlflow/models/evaluation/utils/trace.py +179 -0
- mlflow/models/evaluation/validation.py +434 -0
- mlflow/models/flavor_backend.py +93 -0
- mlflow/models/flavor_backend_registry.py +53 -0
- mlflow/models/model.py +1639 -0
- mlflow/models/model_config.py +150 -0
- mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
- mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
- mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
- mlflow/models/python_api.py +369 -0
- mlflow/models/rag_signatures.py +128 -0
- mlflow/models/resources.py +321 -0
- mlflow/models/signature.py +662 -0
- mlflow/models/utils.py +2054 -0
- mlflow/models/wheeled_model.py +280 -0
- mlflow/openai/__init__.py +57 -0
- mlflow/openai/_agent_tracer.py +364 -0
- mlflow/openai/api_request_parallel_processor.py +131 -0
- mlflow/openai/autolog.py +509 -0
- mlflow/openai/constant.py +1 -0
- mlflow/openai/model.py +824 -0
- mlflow/openai/utils/chat_schema.py +367 -0
- mlflow/optuna/__init__.py +3 -0
- mlflow/optuna/storage.py +646 -0
- mlflow/plugins/__init__.py +72 -0
- mlflow/plugins/base.py +358 -0
- mlflow/plugins/builtin/__init__.py +24 -0
- mlflow/plugins/builtin/pytorch_plugin.py +150 -0
- mlflow/plugins/builtin/sklearn_plugin.py +158 -0
- mlflow/plugins/builtin/transformers_plugin.py +187 -0
- mlflow/plugins/cli.py +321 -0
- mlflow/plugins/discovery.py +340 -0
- mlflow/plugins/manager.py +465 -0
- mlflow/plugins/registry.py +316 -0
- mlflow/plugins/templates/framework_plugin_template.py +329 -0
- mlflow/prompt/constants.py +20 -0
- mlflow/prompt/promptlab_model.py +197 -0
- mlflow/prompt/registry_utils.py +248 -0
- mlflow/promptflow/__init__.py +495 -0
- mlflow/protos/__init__.py +0 -0
- mlflow/protos/assessments_pb2.py +174 -0
- mlflow/protos/databricks_artifacts_pb2.py +489 -0
- mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
- mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
- mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
- mlflow/protos/databricks_pb2.py +267 -0
- mlflow/protos/databricks_trace_server_pb2.py +374 -0
- mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
- mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
- mlflow/protos/facet_feature_statistics_pb2.py +296 -0
- mlflow/protos/internal_pb2.py +77 -0
- mlflow/protos/mlflow_artifacts_pb2.py +336 -0
- mlflow/protos/model_registry_pb2.py +1073 -0
- mlflow/protos/scalapb/__init__.py +0 -0
- mlflow/protos/scalapb/scalapb_pb2.py +104 -0
- mlflow/protos/service_pb2.py +2600 -0
- mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
- mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
- mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
- mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
- mlflow/py.typed +0 -0
- mlflow/pydantic_ai/__init__.py +57 -0
- mlflow/pydantic_ai/autolog.py +173 -0
- mlflow/pyfunc/__init__.py +3844 -0
- mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
- mlflow/pyfunc/backend.py +523 -0
- mlflow/pyfunc/context.py +78 -0
- mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
- mlflow/pyfunc/loaders/__init__.py +7 -0
- mlflow/pyfunc/loaders/chat_agent.py +117 -0
- mlflow/pyfunc/loaders/chat_model.py +125 -0
- mlflow/pyfunc/loaders/code_model.py +31 -0
- mlflow/pyfunc/loaders/responses_agent.py +112 -0
- mlflow/pyfunc/mlserver.py +46 -0
- mlflow/pyfunc/model.py +1473 -0
- mlflow/pyfunc/scoring_server/__init__.py +604 -0
- mlflow/pyfunc/scoring_server/app.py +7 -0
- mlflow/pyfunc/scoring_server/client.py +146 -0
- mlflow/pyfunc/spark_model_cache.py +48 -0
- mlflow/pyfunc/stdin_server.py +44 -0
- mlflow/pyfunc/utils/__init__.py +3 -0
- mlflow/pyfunc/utils/data_validation.py +224 -0
- mlflow/pyfunc/utils/environment.py +22 -0
- mlflow/pyfunc/utils/input_converter.py +47 -0
- mlflow/pyfunc/utils/serving_data_parser.py +11 -0
- mlflow/pytorch/__init__.py +1171 -0
- mlflow/pytorch/_lightning_autolog.py +580 -0
- mlflow/pytorch/_pytorch_autolog.py +50 -0
- mlflow/pytorch/pickle_module.py +35 -0
- mlflow/rfunc/__init__.py +42 -0
- mlflow/rfunc/backend.py +134 -0
- mlflow/runs.py +89 -0
- mlflow/server/__init__.py +302 -0
- mlflow/server/auth/__init__.py +1224 -0
- mlflow/server/auth/__main__.py +4 -0
- mlflow/server/auth/basic_auth.ini +6 -0
- mlflow/server/auth/cli.py +11 -0
- mlflow/server/auth/client.py +537 -0
- mlflow/server/auth/config.py +34 -0
- mlflow/server/auth/db/__init__.py +0 -0
- mlflow/server/auth/db/cli.py +18 -0
- mlflow/server/auth/db/migrations/__init__.py +0 -0
- mlflow/server/auth/db/migrations/alembic.ini +110 -0
- mlflow/server/auth/db/migrations/env.py +76 -0
- mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
- mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
- mlflow/server/auth/db/models.py +67 -0
- mlflow/server/auth/db/utils.py +37 -0
- mlflow/server/auth/entities.py +165 -0
- mlflow/server/auth/logo.py +14 -0
- mlflow/server/auth/permissions.py +65 -0
- mlflow/server/auth/routes.py +18 -0
- mlflow/server/auth/sqlalchemy_store.py +263 -0
- mlflow/server/graphql/__init__.py +0 -0
- mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
- mlflow/server/graphql/graphql_custom_scalars.py +24 -0
- mlflow/server/graphql/graphql_errors.py +15 -0
- mlflow/server/graphql/graphql_no_batching.py +89 -0
- mlflow/server/graphql/graphql_schema_extensions.py +74 -0
- mlflow/server/handlers.py +3217 -0
- mlflow/server/prometheus_exporter.py +17 -0
- mlflow/server/validation.py +30 -0
- mlflow/shap/__init__.py +691 -0
- mlflow/sklearn/__init__.py +1994 -0
- mlflow/sklearn/utils.py +1041 -0
- mlflow/smolagents/__init__.py +66 -0
- mlflow/smolagents/autolog.py +139 -0
- mlflow/smolagents/chat.py +29 -0
- mlflow/store/__init__.py +10 -0
- mlflow/store/_unity_catalog/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/constants.py +2 -0
- mlflow/store/_unity_catalog/registry/__init__.py +6 -0
- mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
- mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
- mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
- mlflow/store/_unity_catalog/registry/utils.py +121 -0
- mlflow/store/artifact/__init__.py +0 -0
- mlflow/store/artifact/artifact_repo.py +472 -0
- mlflow/store/artifact/artifact_repository_registry.py +154 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
- mlflow/store/artifact/cli.py +141 -0
- mlflow/store/artifact/cloud_artifact_repo.py +332 -0
- mlflow/store/artifact/databricks_artifact_repo.py +729 -0
- mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
- mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
- mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
- mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
- mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
- mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
- mlflow/store/artifact/ftp_artifact_repo.py +132 -0
- mlflow/store/artifact/gcs_artifact_repo.py +296 -0
- mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
- mlflow/store/artifact/http_artifact_repo.py +218 -0
- mlflow/store/artifact/local_artifact_repo.py +142 -0
- mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
- mlflow/store/artifact/models_artifact_repo.py +259 -0
- mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
- mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
- mlflow/store/artifact/r2_artifact_repo.py +70 -0
- mlflow/store/artifact/runs_artifact_repo.py +265 -0
- mlflow/store/artifact/s3_artifact_repo.py +330 -0
- mlflow/store/artifact/sftp_artifact_repo.py +141 -0
- mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
- mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
- mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
- mlflow/store/artifact/utils/__init__.py +0 -0
- mlflow/store/artifact/utils/models.py +148 -0
- mlflow/store/db/__init__.py +0 -0
- mlflow/store/db/base_sql_model.py +3 -0
- mlflow/store/db/db_types.py +10 -0
- mlflow/store/db/utils.py +314 -0
- mlflow/store/db_migrations/__init__.py +0 -0
- mlflow/store/db_migrations/alembic.ini +74 -0
- mlflow/store/db_migrations/env.py +84 -0
- mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
- mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
- mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
- mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
- mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
- mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
- mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
- mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
- mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
- mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
- mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
- mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
- mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
- mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
- mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
- mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
- mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
- mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
- mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
- mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
- mlflow/store/db_migrations/versions/__init__.py +0 -0
- mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
- mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
- mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
- mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
- mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
- mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
- mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
- mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
- mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
- mlflow/store/entities/__init__.py +3 -0
- mlflow/store/entities/paged_list.py +18 -0
- mlflow/store/model_registry/__init__.py +10 -0
- mlflow/store/model_registry/abstract_store.py +1081 -0
- mlflow/store/model_registry/base_rest_store.py +44 -0
- mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
- mlflow/store/model_registry/dbmodels/__init__.py +0 -0
- mlflow/store/model_registry/dbmodels/models.py +206 -0
- mlflow/store/model_registry/file_store.py +1091 -0
- mlflow/store/model_registry/rest_store.py +481 -0
- mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
- mlflow/store/tracking/__init__.py +23 -0
- mlflow/store/tracking/abstract_store.py +816 -0
- mlflow/store/tracking/dbmodels/__init__.py +0 -0
- mlflow/store/tracking/dbmodels/initial_models.py +243 -0
- mlflow/store/tracking/dbmodels/models.py +1073 -0
- mlflow/store/tracking/file_store.py +2438 -0
- mlflow/store/tracking/postgres_managed_identity.py +146 -0
- mlflow/store/tracking/rest_store.py +1131 -0
- mlflow/store/tracking/sqlalchemy_store.py +2785 -0
- mlflow/system_metrics/__init__.py +61 -0
- mlflow/system_metrics/metrics/__init__.py +0 -0
- mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
- mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
- mlflow/system_metrics/metrics/disk_monitor.py +21 -0
- mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
- mlflow/system_metrics/metrics/network_monitor.py +34 -0
- mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
- mlflow/system_metrics/system_metrics_monitor.py +198 -0
- mlflow/tracing/__init__.py +16 -0
- mlflow/tracing/assessment.py +356 -0
- mlflow/tracing/client.py +531 -0
- mlflow/tracing/config.py +125 -0
- mlflow/tracing/constant.py +105 -0
- mlflow/tracing/destination.py +81 -0
- mlflow/tracing/display/__init__.py +40 -0
- mlflow/tracing/display/display_handler.py +196 -0
- mlflow/tracing/export/async_export_queue.py +186 -0
- mlflow/tracing/export/inference_table.py +138 -0
- mlflow/tracing/export/mlflow_v3.py +137 -0
- mlflow/tracing/export/utils.py +70 -0
- mlflow/tracing/fluent.py +1417 -0
- mlflow/tracing/processor/base_mlflow.py +199 -0
- mlflow/tracing/processor/inference_table.py +175 -0
- mlflow/tracing/processor/mlflow_v3.py +47 -0
- mlflow/tracing/processor/otel.py +73 -0
- mlflow/tracing/provider.py +487 -0
- mlflow/tracing/trace_manager.py +200 -0
- mlflow/tracing/utils/__init__.py +616 -0
- mlflow/tracing/utils/artifact_utils.py +28 -0
- mlflow/tracing/utils/copy.py +55 -0
- mlflow/tracing/utils/environment.py +55 -0
- mlflow/tracing/utils/exception.py +21 -0
- mlflow/tracing/utils/once.py +35 -0
- mlflow/tracing/utils/otlp.py +63 -0
- mlflow/tracing/utils/processor.py +54 -0
- mlflow/tracing/utils/search.py +292 -0
- mlflow/tracing/utils/timeout.py +250 -0
- mlflow/tracing/utils/token.py +19 -0
- mlflow/tracing/utils/truncation.py +124 -0
- mlflow/tracing/utils/warning.py +76 -0
- mlflow/tracking/__init__.py +39 -0
- mlflow/tracking/_model_registry/__init__.py +1 -0
- mlflow/tracking/_model_registry/client.py +764 -0
- mlflow/tracking/_model_registry/fluent.py +853 -0
- mlflow/tracking/_model_registry/registry.py +67 -0
- mlflow/tracking/_model_registry/utils.py +251 -0
- mlflow/tracking/_tracking_service/__init__.py +0 -0
- mlflow/tracking/_tracking_service/client.py +883 -0
- mlflow/tracking/_tracking_service/registry.py +56 -0
- mlflow/tracking/_tracking_service/utils.py +275 -0
- mlflow/tracking/artifact_utils.py +179 -0
- mlflow/tracking/client.py +5900 -0
- mlflow/tracking/context/__init__.py +0 -0
- mlflow/tracking/context/abstract_context.py +35 -0
- mlflow/tracking/context/databricks_cluster_context.py +15 -0
- mlflow/tracking/context/databricks_command_context.py +15 -0
- mlflow/tracking/context/databricks_job_context.py +49 -0
- mlflow/tracking/context/databricks_notebook_context.py +41 -0
- mlflow/tracking/context/databricks_repo_context.py +43 -0
- mlflow/tracking/context/default_context.py +51 -0
- mlflow/tracking/context/git_context.py +32 -0
- mlflow/tracking/context/registry.py +98 -0
- mlflow/tracking/context/system_environment_context.py +15 -0
- mlflow/tracking/default_experiment/__init__.py +1 -0
- mlflow/tracking/default_experiment/abstract_context.py +43 -0
- mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
- mlflow/tracking/default_experiment/registry.py +75 -0
- mlflow/tracking/fluent.py +3595 -0
- mlflow/tracking/metric_value_conversion_utils.py +93 -0
- mlflow/tracking/multimedia.py +206 -0
- mlflow/tracking/registry.py +86 -0
- mlflow/tracking/request_auth/__init__.py +0 -0
- mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
- mlflow/tracking/request_auth/registry.py +60 -0
- mlflow/tracking/request_header/__init__.py +0 -0
- mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
- mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
- mlflow/tracking/request_header/default_request_header_provider.py +17 -0
- mlflow/tracking/request_header/registry.py +79 -0
- mlflow/transformers/__init__.py +2982 -0
- mlflow/transformers/flavor_config.py +258 -0
- mlflow/transformers/hub_utils.py +83 -0
- mlflow/transformers/llm_inference_utils.py +468 -0
- mlflow/transformers/model_io.py +301 -0
- mlflow/transformers/peft.py +51 -0
- mlflow/transformers/signature.py +183 -0
- mlflow/transformers/torch_utils.py +55 -0
- mlflow/types/__init__.py +21 -0
- mlflow/types/agent.py +270 -0
- mlflow/types/chat.py +240 -0
- mlflow/types/llm.py +935 -0
- mlflow/types/responses.py +139 -0
- mlflow/types/responses_helpers.py +416 -0
- mlflow/types/schema.py +1505 -0
- mlflow/types/type_hints.py +647 -0
- mlflow/types/utils.py +753 -0
- mlflow/utils/__init__.py +283 -0
- mlflow/utils/_capture_modules.py +256 -0
- mlflow/utils/_capture_transformers_modules.py +75 -0
- mlflow/utils/_spark_utils.py +201 -0
- mlflow/utils/_unity_catalog_oss_utils.py +97 -0
- mlflow/utils/_unity_catalog_utils.py +479 -0
- mlflow/utils/annotations.py +218 -0
- mlflow/utils/arguments_utils.py +16 -0
- mlflow/utils/async_logging/__init__.py +1 -0
- mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
- mlflow/utils/async_logging/async_logging_queue.py +366 -0
- mlflow/utils/async_logging/run_artifact.py +38 -0
- mlflow/utils/async_logging/run_batch.py +58 -0
- mlflow/utils/async_logging/run_operations.py +49 -0
- mlflow/utils/autologging_utils/__init__.py +737 -0
- mlflow/utils/autologging_utils/client.py +432 -0
- mlflow/utils/autologging_utils/config.py +33 -0
- mlflow/utils/autologging_utils/events.py +294 -0
- mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
- mlflow/utils/autologging_utils/metrics_queue.py +71 -0
- mlflow/utils/autologging_utils/safety.py +1104 -0
- mlflow/utils/autologging_utils/versioning.py +95 -0
- mlflow/utils/checkpoint_utils.py +206 -0
- mlflow/utils/class_utils.py +6 -0
- mlflow/utils/cli_args.py +257 -0
- mlflow/utils/conda.py +354 -0
- mlflow/utils/credentials.py +231 -0
- mlflow/utils/data_utils.py +17 -0
- mlflow/utils/databricks_utils.py +1436 -0
- mlflow/utils/docstring_utils.py +477 -0
- mlflow/utils/doctor.py +133 -0
- mlflow/utils/download_cloud_file_chunk.py +43 -0
- mlflow/utils/env_manager.py +16 -0
- mlflow/utils/env_pack.py +131 -0
- mlflow/utils/environment.py +1009 -0
- mlflow/utils/exception_utils.py +14 -0
- mlflow/utils/file_utils.py +978 -0
- mlflow/utils/git_utils.py +77 -0
- mlflow/utils/gorilla.py +797 -0
- mlflow/utils/import_hooks/__init__.py +363 -0
- mlflow/utils/lazy_load.py +51 -0
- mlflow/utils/logging_utils.py +168 -0
- mlflow/utils/mime_type_utils.py +58 -0
- mlflow/utils/mlflow_tags.py +103 -0
- mlflow/utils/model_utils.py +486 -0
- mlflow/utils/name_utils.py +346 -0
- mlflow/utils/nfs_on_spark.py +62 -0
- mlflow/utils/openai_utils.py +164 -0
- mlflow/utils/os.py +12 -0
- mlflow/utils/oss_registry_utils.py +29 -0
- mlflow/utils/plugins.py +17 -0
- mlflow/utils/process.py +182 -0
- mlflow/utils/promptlab_utils.py +146 -0
- mlflow/utils/proto_json_utils.py +743 -0
- mlflow/utils/pydantic_utils.py +54 -0
- mlflow/utils/request_utils.py +279 -0
- mlflow/utils/requirements_utils.py +704 -0
- mlflow/utils/rest_utils.py +673 -0
- mlflow/utils/search_logged_model_utils.py +127 -0
- mlflow/utils/search_utils.py +2111 -0
- mlflow/utils/secure_loading.py +221 -0
- mlflow/utils/security_validation.py +384 -0
- mlflow/utils/server_cli_utils.py +61 -0
- mlflow/utils/spark_utils.py +15 -0
- mlflow/utils/string_utils.py +138 -0
- mlflow/utils/thread_utils.py +63 -0
- mlflow/utils/time.py +54 -0
- mlflow/utils/timeout.py +42 -0
- mlflow/utils/uri.py +572 -0
- mlflow/utils/validation.py +662 -0
- mlflow/utils/virtualenv.py +458 -0
- mlflow/utils/warnings_utils.py +25 -0
- mlflow/utils/yaml_utils.py +179 -0
- mlflow/version.py +24 -0
mlflow/store/tracking/sqlalchemy_store.py
@@ -0,0 +1,2785 @@
```python
import json
import logging
import math
import random
import threading
import time
import uuid
from collections import defaultdict
from functools import reduce
from typing import Any, Optional, TypedDict

import sqlalchemy
import sqlalchemy.orm
import sqlalchemy.sql.expression as sql
from sqlalchemy import and_, func, sql, text
from sqlalchemy.future import select

import mlflow.store.db.utils
from mlflow.entities import (
    DatasetInput,
    Experiment,
    Run,
    RunInputs,
    RunOutputs,
    RunStatus,
    RunTag,
    SourceType,
    TraceInfo,
    ViewType,
    _DatasetSummary,
)
from mlflow.entities.lifecycle_stage import LifecycleStage
from mlflow.entities.logged_model import LoggedModel
from mlflow.entities.logged_model_input import LoggedModelInput
from mlflow.entities.logged_model_output import LoggedModelOutput
from mlflow.entities.logged_model_parameter import LoggedModelParameter
from mlflow.entities.logged_model_status import LoggedModelStatus
from mlflow.entities.logged_model_tag import LoggedModelTag
from mlflow.entities.metric import Metric, MetricWithRunId
from mlflow.entities.trace_info_v2 import TraceInfoV2
from mlflow.entities.trace_status import TraceStatus
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import (
    INTERNAL_ERROR,
    INVALID_PARAMETER_VALUE,
    INVALID_STATE,
    RESOURCE_ALREADY_EXISTS,
    RESOURCE_DOES_NOT_EXIST,
)
from mlflow.store.db.db_types import MSSQL, MYSQL
from mlflow.store.entities.paged_list import PagedList
from mlflow.store.tracking import (
    SEARCH_LOGGED_MODEL_MAX_RESULTS_DEFAULT,
    SEARCH_MAX_RESULTS_DEFAULT,
    SEARCH_MAX_RESULTS_THRESHOLD,
    SEARCH_TRACES_DEFAULT_MAX_RESULTS,
)
from mlflow.store.tracking.abstract_store import AbstractStore
from mlflow.store.tracking.dbmodels.models import (
    SqlDataset,
    SqlExperiment,
    SqlExperimentTag,
    SqlInput,
    SqlInputTag,
    SqlLatestMetric,
    SqlLoggedModel,
    SqlLoggedModelMetric,
    SqlLoggedModelParam,
    SqlLoggedModelTag,
    SqlMetric,
    SqlParam,
    SqlRun,
    SqlTag,
    SqlTraceInfo,
    SqlTraceMetadata,
    SqlTraceTag,
)
from mlflow.tracing.utils import generate_request_id_v2
from mlflow.tracking.fluent import _get_experiment_id
from mlflow.utils.file_utils import local_file_uri_to_path, mkdir
from mlflow.utils.mlflow_tags import (
    MLFLOW_ARTIFACT_LOCATION,
    MLFLOW_DATASET_CONTEXT,
    MLFLOW_LOGGED_MODELS,
    MLFLOW_RUN_NAME,
    _get_run_name_from_tags,
)
from mlflow.utils.name_utils import _generate_random_name
from mlflow.utils.search_utils import (
    SearchExperimentsUtils,
    SearchLoggedModelsPaginationToken,
    SearchTraceUtils,
    SearchUtils,
)
from mlflow.utils.string_utils import is_string_type
from mlflow.utils.time import get_current_time_millis
from mlflow.utils.uri import (
    append_to_uri_path,
    extract_db_type_from_uri,
    is_local_uri,
    resolve_uri_if_local,
)
from mlflow.utils.validation import (
    _validate_batch_log_data,
    _validate_batch_log_limits,
    _validate_dataset_inputs,
    _validate_experiment_artifact_location_length,
    _validate_experiment_name,
    _validate_experiment_tag,
    _validate_logged_model_name,
    _validate_metric,
    _validate_param,
    _validate_param_keys_unique,
    _validate_run_id,
    _validate_tag,
    _validate_trace_tag,
)

_logger = logging.getLogger(__name__)

# For each database table, fetch its columns and define an appropriate attribute for each column
# on the table's associated object representation (Mapper). This is necessary to ensure that
# columns defined via backreference are available as Mapper instance attributes (e.g.,
# ``SqlExperiment.tags`` and ``SqlRun.params``). For more information, see
# https://docs.sqlalchemy.org/en/latest/orm/mapping_api.html#sqlalchemy.orm.configure_mappers
# and https://docs.sqlalchemy.org/en/latest/orm/mapping_api.html#sqlalchemy.orm.mapper.Mapper
sqlalchemy.orm.configure_mappers()
```
```python
class DatasetFilter(TypedDict, total=False):
    """
    Dataset filter used for search_logged_models.
    """

    dataset_name: str
    dataset_digest: str


class SqlAlchemyStore(AbstractStore):
    """
    SQLAlchemy compliant backend store for tracking meta data for MLflow entities. MLflow
    supports the database dialects ``mysql``, ``mssql``, ``sqlite``, and ``postgresql``.
    As specified in the
    `SQLAlchemy docs <https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls>`_ ,
    the database URI is expected in the format
    ``<dialect>+<driver>://<username>:<password>@<host>:<port>/<database>``. If you do not
    specify a driver, SQLAlchemy uses a dialect's default driver.

    This store interacts with SQL store using SQLAlchemy abstractions defined for MLflow entities.
    :py:class:`mlflow.store.dbmodels.models.SqlExperiment`,
    :py:class:`mlflow.store.dbmodels.models.SqlRun`,
    :py:class:`mlflow.store.dbmodels.models.SqlTag`,
    :py:class:`mlflow.store.dbmodels.models.SqlMetric`, and
    :py:class:`mlflow.store.dbmodels.models.SqlParam`.

    Run artifacts are stored in a separate location using artifact stores conforming to
    :py:class:`mlflow.store.artifact_repo.ArtifactRepository`. Default artifact locations for
    user experiments are stored in the database along with metadata. Each run artifact location
    is recorded in :py:class:`mlflow.store.dbmodels.models.SqlRun` and stored in the backend DB.
    """

    ARTIFACTS_FOLDER_NAME = "artifacts"
    MODELS_FOLDER_NAME = "models"
    TRACE_FOLDER_NAME = "traces"
    DEFAULT_EXPERIMENT_ID = "0"
    _db_uri_sql_alchemy_engine_map = {}
    _db_uri_sql_alchemy_engine_map_lock = threading.Lock()
```
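The class docstring above pins down the expected URI shape. As a hedged illustration (hosts, credentials, and driver choices below are made up, not part of the package), conforming URIs for the four supported dialects look like:

```python
# Illustrative SQLAlchemy database URIs for the dialects named in the docstring.
SQLITE_URI = "sqlite:///mlruns.db"  # dialect only; SQLAlchemy picks the default driver
POSTGRES_URI = "postgresql+psycopg2://user:pw@db.example.com:5432/mlflow"
MYSQL_URI = "mysql+pymysql://user:pw@db.example.com:3306/mlflow"
MSSQL_URI = "mssql+pyodbc://user:pw@db.example.com:1433/mlflow?driver=ODBC+Driver+17+for+SQL+Server"

# The store is constructed with a DB URI plus a default artifact root, e.g.:
#   store = SqlAlchemyStore(SQLITE_URI, "./mlruns")
```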
```python
    def __init__(self, db_uri, default_artifact_root):
        """
        Create a database backed store.

        Args:
            db_uri: The SQLAlchemy database URI string to connect to the database. See
                the `SQLAlchemy docs
                <https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls>`_
                for format specifications. MLflow supports the dialects ``mysql``,
                ``mssql``, ``sqlite``, and ``postgresql``.
            default_artifact_root: Path/URI to location suitable for large data (such as a blob
                store object, DBFS path, or shared NFS file system).
        """
        super().__init__()
        self.db_uri = db_uri
        self.db_type = extract_db_type_from_uri(db_uri)
        self.artifact_root_uri = resolve_uri_if_local(default_artifact_root)
        # Quick check to see if the respective SQLAlchemy database engine has already been created.
        if db_uri not in SqlAlchemyStore._db_uri_sql_alchemy_engine_map:
            with SqlAlchemyStore._db_uri_sql_alchemy_engine_map_lock:
                # Repeat check to prevent race conditions where one thread checks for an existing
                # engine while another is creating the respective one, resulting in multiple
                # engines being created. It isn't combined with the above check to prevent
                # inefficiency from multiple threads waiting for the lock to check for engine
                # existence if it has already been created.
                if db_uri not in SqlAlchemyStore._db_uri_sql_alchemy_engine_map:
                    SqlAlchemyStore._db_uri_sql_alchemy_engine_map[db_uri] = (
                        mlflow.store.db.utils.create_sqlalchemy_engine_with_retry(db_uri)
                    )
        self.engine = SqlAlchemyStore._db_uri_sql_alchemy_engine_map[db_uri]
        # On a completely fresh MLflow installation against an empty database (verify database
        # emptiness by checking that 'experiments' etc aren't in the list of table names), run all
        # DB migrations
        if not mlflow.store.db.utils._all_tables_exist(self.engine):
            mlflow.store.db.utils._initialize_tables(self.engine)
        SessionMaker = sqlalchemy.orm.sessionmaker(bind=self.engine)
        self.ManagedSessionMaker = mlflow.store.db.utils._get_managed_session_maker(
            SessionMaker, self.db_type
        )
        mlflow.store.db.utils._verify_schema(self.engine)

        if is_local_uri(default_artifact_root):
            mkdir(local_file_uri_to_path(default_artifact_root))

        if len(self.search_experiments(view_type=ViewType.ALL)) == 0:
            with self.ManagedSessionMaker() as session:
                self._create_default_experiment(session)
```
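The engine cache in ``__init__`` is a textbook double-checked locking pattern: an unlocked fast path, then a re-check under the lock so that exactly one engine is built per URI even under concurrent construction. A minimal standalone sketch of the same idea (``_cache``, ``_lock``, and ``make_engine`` are stand-ins here, not Genesis-Flow names):

```python
import threading

_cache: dict[str, object] = {}
_lock = threading.Lock()

def get_or_create_engine(key: str, make_engine):
    # Fast path: no lock once the engine already exists.
    if key not in _cache:
        with _lock:
            # Re-check under the lock: another thread may have created the
            # engine between our first check and lock acquisition.
            if key not in _cache:
                _cache[key] = make_engine(key)
    return _cache[key]
```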
```python
    def _get_dialect(self):
        return self.engine.dialect.name

    def _dispose_engine(self):
        self.engine.dispose()

    def _set_zero_value_insertion_for_autoincrement_column(self, session):
        if self.db_type == MYSQL:
            # config letting MySQL override default
            # to allow 0 value for experiment ID (auto increment column)
            session.execute(sql.text("SET @@SESSION.sql_mode='NO_AUTO_VALUE_ON_ZERO';"))
        if self.db_type == MSSQL:
            # config letting MSSQL override default
            # to allow any manual value inserted into IDENTITY column
            session.execute(sql.text("SET IDENTITY_INSERT experiments ON;"))

    # DB helper methods to allow zero values for columns with auto increments
    def _unset_zero_value_insertion_for_autoincrement_column(self, session):
        if self.db_type == MYSQL:
            session.execute(sql.text("SET @@SESSION.sql_mode='';"))
        if self.db_type == MSSQL:
            session.execute(sql.text("SET IDENTITY_INSERT experiments OFF;"))
```
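These two helpers are meant to bracket a manual insert into an autoincrement/IDENTITY column, as ``_create_default_experiment`` below does. The safe calling shape is a ``try``/``finally`` pair (sketch; ``store`` and ``session`` are assumed to be a live store and session, and ``insert_with_explicit_zero_id`` is a placeholder statement):

```python
try:
    store._set_zero_value_insertion_for_autoincrement_column(session)
    session.execute(insert_with_explicit_zero_id)  # the statement needing ID 0
finally:
    # Always restore the session defaults, even if the INSERT fails.
    store._unset_zero_value_insertion_for_autoincrement_column(session)
```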
```python
    def _create_default_experiment(self, session):
        """
        MLflow UI and client code expects a default experiment with ID 0.
        This method uses SQL insert statement to create the default experiment as a hack, since
        experiment table uses 'experiment_id' column is a PK and is also set to auto increment.
        MySQL and other implementation do not allow value '0' for such cases.

        ToDo: Identify a less hacky mechanism to create default experiment 0
        """
        table = SqlExperiment.__tablename__
        creation_time = get_current_time_millis()
        default_experiment = {
            SqlExperiment.experiment_id.name: int(SqlAlchemyStore.DEFAULT_EXPERIMENT_ID),
            SqlExperiment.name.name: Experiment.DEFAULT_EXPERIMENT_NAME,
            SqlExperiment.artifact_location.name: str(self._get_artifact_location(0)),
            SqlExperiment.lifecycle_stage.name: LifecycleStage.ACTIVE,
            SqlExperiment.creation_time.name: creation_time,
            SqlExperiment.last_update_time.name: creation_time,
        }

        def decorate(s):
            if is_string_type(s):
                return repr(s)
            else:
                return str(s)

        # Get a list of keys to ensure we have a deterministic ordering
        columns = list(default_experiment.keys())
        values = ", ".join([decorate(default_experiment.get(c)) for c in columns])

        try:
            self._set_zero_value_insertion_for_autoincrement_column(session)
            session.execute(
                sql.text(f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({values});")
            )
        finally:
            self._unset_zero_value_insertion_for_autoincrement_column(session)

    def _get_or_create(self, session, model, **kwargs):
        instance = session.query(model).filter_by(**kwargs).first()
        created = False

        if instance:
            return instance, created
        else:
            instance = model(**kwargs)
            session.add(instance)
            created = True

        return instance, created

    def _get_artifact_location(self, experiment_id):
        return append_to_uri_path(self.artifact_root_uri, str(experiment_id))
```
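``_create_default_experiment`` renders the row by hand rather than through the ORM, quoting string values with ``repr``. The same construction in plain Python, with made-up column values, shows the statement it produces:

```python
# Reproduces the INSERT construction above with plain Python; column names
# and values are illustrative stand-ins for the SqlExperiment defaults.
default_experiment = {
    "experiment_id": 0,
    "name": "Default",
    "lifecycle_stage": "active",
}

def decorate(s):
    # Strings are repr()-quoted so they land in the SQL as string literals.
    return repr(s) if isinstance(s, str) else str(s)

columns = list(default_experiment.keys())
values = ", ".join(decorate(default_experiment[c]) for c in columns)
print(f"INSERT INTO experiments ({', '.join(columns)}) VALUES ({values});")
# INSERT INTO experiments (experiment_id, name, lifecycle_stage) VALUES (0, 'Default', 'active');
```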
```python
    def create_experiment(self, name, artifact_location=None, tags=None):
        _validate_experiment_name(name)

        # Genesis-Flow: Use MLFLOW_ARTIFACT_LOCATION if no artifact location is provided
        if not artifact_location:
            from mlflow.environment_variables import MLFLOW_ARTIFACT_LOCATION
            if MLFLOW_ARTIFACT_LOCATION.defined:
                artifact_location = MLFLOW_ARTIFACT_LOCATION.get()

        if artifact_location:
            artifact_location = resolve_uri_if_local(artifact_location)
        _validate_experiment_artifact_location_length(artifact_location)
        with self.ManagedSessionMaker() as session:
            try:
                creation_time = get_current_time_millis()
                experiment = SqlExperiment(
                    name=name,
                    lifecycle_stage=LifecycleStage.ACTIVE,
                    artifact_location=artifact_location,
                    creation_time=creation_time,
                    last_update_time=creation_time,
                )
                experiment.tags = (
                    [SqlExperimentTag(key=tag.key, value=tag.value) for tag in tags] if tags else []
                )
                session.add(experiment)
                if not artifact_location:
                    # this requires a double write. The first one to generate an autoincrement-ed ID
                    eid = session.query(SqlExperiment).filter_by(name=name).first().experiment_id
                    experiment.artifact_location = self._get_artifact_location(eid)
            except sqlalchemy.exc.IntegrityError as e:
                raise MlflowException(
                    f"Experiment(name={name}) already exists. Error: {e}",
                    RESOURCE_ALREADY_EXISTS,
                )

            session.flush()
            return str(experiment.experiment_id)
```
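The ``# Genesis-Flow`` branch is the notable divergence from upstream MLflow here: when no ``artifact_location`` is passed, the store falls back to the ``MLFLOW_ARTIFACT_LOCATION`` environment variable before deriving the usual per-experiment default. A hedged usage sketch (the bucket URI is illustrative, and this assumes the tracking backend is this store):

```python
import os

# Point every newly created experiment at a shared artifact root.
os.environ["MLFLOW_ARTIFACT_LOCATION"] = "s3://example-bucket/mlflow-artifacts"

import mlflow

exp_id = mlflow.create_experiment("demo")  # no artifact_location argument
print(mlflow.get_experiment(exp_id).artifact_location)
# Expected, per the branch above: s3://example-bucket/mlflow-artifacts
```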
```python
    def _search_experiments(
        self,
        view_type,
        max_results,
        filter_string,
        order_by,
        page_token,
    ):
        def compute_next_token(current_size):
            next_token = None
            if max_results + 1 == current_size:
                final_offset = offset + max_results
                next_token = SearchExperimentsUtils.create_page_token(final_offset)

            return next_token

        self._validate_max_results_param(max_results)
        with self.ManagedSessionMaker() as session:
            parsed_filters = SearchExperimentsUtils.parse_search_filter(filter_string)
            attribute_filters, non_attribute_filters = _get_search_experiments_filter_clauses(
                parsed_filters, self._get_dialect()
            )

            order_by_clauses = _get_search_experiments_order_by_clauses(order_by)
            offset = SearchUtils.parse_start_offset_from_page_token(page_token)
            lifecycle_stags = set(LifecycleStage.view_type_to_stages(view_type))

            stmt = (
                reduce(lambda s, f: s.join(f), non_attribute_filters, select(SqlExperiment))
                .options(*self._get_eager_experiment_query_options())
                .filter(
                    *attribute_filters,
                    SqlExperiment.lifecycle_stage.in_(lifecycle_stags),
                )
                .order_by(*order_by_clauses)
                .offset(offset)
                .limit(max_results + 1)
            )
            queried_experiments = session.execute(stmt).scalars(SqlExperiment).all()
            experiments = [e.to_mlflow_entity() for e in queried_experiments]
            next_page_token = compute_next_token(len(experiments))

        return experiments[:max_results], next_page_token

    def search_experiments(
        self,
        view_type=ViewType.ACTIVE_ONLY,
        max_results=SEARCH_MAX_RESULTS_DEFAULT,
        filter_string=None,
        order_by=None,
        page_token=None,
    ):
        experiments, next_page_token = self._search_experiments(
            view_type, max_results, filter_string, order_by, page_token
        )
        return PagedList(experiments, next_page_token)
```
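Note the fetch of ``max_results + 1`` rows: the extra row is a cheap existence probe for the next page, and ``compute_next_token`` mints an offset token only when it shows up. Draining all pages from the caller's side is then the usual token loop (sketch; ``store`` is assumed to be a constructed ``SqlAlchemyStore``):

```python
experiments = []
token = None
while True:
    page = store.search_experiments(max_results=100, page_token=token)
    experiments.extend(page)
    token = page.token  # PagedList carries the token for the next page
    if token is None:
        break
```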
```python
    def _get_experiment(self, session, experiment_id, view_type, eager=False):  # noqa: D417
        """
        Args:
            eager: If ``True``, eagerly loads the experiments's tags. If ``False``, these tags
                are not eagerly loaded and will be loaded if/when their corresponding
                object properties are accessed from the resulting ``SqlExperiment`` object.
        """
        experiment_id = experiment_id or SqlAlchemyStore.DEFAULT_EXPERIMENT_ID
        stages = LifecycleStage.view_type_to_stages(view_type)
        query_options = self._get_eager_experiment_query_options() if eager else []

        experiment = (
            session.query(SqlExperiment)
            .options(*query_options)
            .filter(
                SqlExperiment.experiment_id == experiment_id,
                SqlExperiment.lifecycle_stage.in_(stages),
            )
            .one_or_none()
        )

        if experiment is None:
            raise MlflowException(
                f"No Experiment with id={experiment_id} exists", RESOURCE_DOES_NOT_EXIST
            )

        return experiment

    @staticmethod
    def _get_eager_experiment_query_options():
        """
        A list of SQLAlchemy query options that can be used to eagerly load the following
        experiment attributes when fetching an experiment: ``tags``.
        """
        return [
            # Use a subquery load rather than a joined load in order to minimize the memory overhead
            # of the eager loading procedure. For more information about relationship loading
            # techniques, see https://docs.sqlalchemy.org/en/13/orm/
            # loading_relationships.html#relationship-loading-techniques
            sqlalchemy.orm.subqueryload(SqlExperiment.tags),
        ]

    def get_experiment(self, experiment_id):
        with self.ManagedSessionMaker() as session:
            return self._get_experiment(
                session, experiment_id, ViewType.ALL, eager=True
            ).to_mlflow_entity()

    def get_experiment_by_name(self, experiment_name):
        """
        Specialized implementation for SQL backed store.
        """
        with self.ManagedSessionMaker() as session:
            stages = LifecycleStage.view_type_to_stages(ViewType.ALL)
            experiment = (
                session.query(SqlExperiment)
                .options(*self._get_eager_experiment_query_options())
                .filter(
                    SqlExperiment.name == experiment_name,
                    SqlExperiment.lifecycle_stage.in_(stages),
                )
                .one_or_none()
            )
            return experiment.to_mlflow_entity() if experiment is not None else None

    def delete_experiment(self, experiment_id):
        with self.ManagedSessionMaker() as session:
            experiment = self._get_experiment(session, experiment_id, ViewType.ACTIVE_ONLY)
            experiment.lifecycle_stage = LifecycleStage.DELETED
            experiment.last_update_time = get_current_time_millis()
            runs = self._list_run_infos(session, experiment_id)
            for run in runs:
                self._mark_run_deleted(session, run)
            session.add(experiment)

    def _hard_delete_experiment(self, experiment_id):
        """
        Permanently delete a experiment (metadata and metrics, tags, parameters).
        This is used by the ``mlflow gc`` command line and is not intended to be used elsewhere.
        """
        with self.ManagedSessionMaker() as session:
            experiment = self._get_experiment(
                experiment_id=experiment_id,
                session=session,
                view_type=ViewType.DELETED_ONLY,
            )
            session.delete(experiment)

    def _mark_run_deleted(self, session, run):
        run.lifecycle_stage = LifecycleStage.DELETED
        run.deleted_time = get_current_time_millis()
        session.add(run)

    def _mark_run_active(self, session, run):
        run.lifecycle_stage = LifecycleStage.ACTIVE
        run.deleted_time = None
        session.add(run)

    def _list_run_infos(self, session, experiment_id):
        return session.query(SqlRun).filter(SqlRun.experiment_id == experiment_id).all()

    def restore_experiment(self, experiment_id):
        with self.ManagedSessionMaker() as session:
            experiment = self._get_experiment(session, experiment_id, ViewType.DELETED_ONLY)
            experiment.lifecycle_stage = LifecycleStage.ACTIVE
            experiment.last_update_time = get_current_time_millis()
            runs = self._list_run_infos(session, experiment_id)
            for run in runs:
                self._mark_run_active(session, run)
            session.add(experiment)

    def rename_experiment(self, experiment_id, new_name):
        with self.ManagedSessionMaker() as session:
            experiment = self._get_experiment(session, experiment_id, ViewType.ALL)
            if experiment.lifecycle_stage != LifecycleStage.ACTIVE:
                raise MlflowException("Cannot rename a non-active experiment.", INVALID_STATE)

            experiment.name = new_name
            experiment.last_update_time = get_current_time_millis()
            session.add(experiment)
```
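``delete_experiment`` and ``restore_experiment`` are symmetric soft deletes: they flip ``lifecycle_stage`` on the experiment and propagate the mark to its runs, while ``_hard_delete_experiment`` (reserved for ``mlflow gc``) removes rows for good. A sketch of the lifecycle against a constructed ``store``:

```python
exp_id = store.create_experiment("scratch")

store.delete_experiment(exp_id)   # soft: experiment and its runs -> "deleted"
store.restore_experiment(exp_id)  # undo: experiment and its runs -> "active"

store.delete_experiment(exp_id)
store._hard_delete_experiment(exp_id)  # irreversible; normally only `mlflow gc` does this
```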
```python
    def create_run(self, experiment_id, user_id, start_time, tags, run_name):
        with self.ManagedSessionMaker() as session:
            experiment = self.get_experiment(experiment_id)
            self._check_experiment_is_active(experiment)

            # Note: we need to ensure the generated "run_id" only contains digits and lower
            # case letters, because some query filters contain "IN" clause, and in MYSQL the
            # "IN" clause is case-insensitive, we use a trick that filters out comparison values
            # containing upper case letters when parsing "IN" clause inside query filter.
            run_id = uuid.uuid4().hex
            artifact_location = append_to_uri_path(
                experiment.artifact_location,
                run_id,
                SqlAlchemyStore.ARTIFACTS_FOLDER_NAME,
            )
            tags = tags.copy() if tags else []
            run_name_tag = _get_run_name_from_tags(tags)
            if run_name and run_name_tag and (run_name != run_name_tag):
                raise MlflowException(
                    "Both 'run_name' argument and 'mlflow.runName' tag are specified, but with "
                    f"different values (run_name='{run_name}', mlflow.runName='{run_name_tag}').",
                    INVALID_PARAMETER_VALUE,
                )
            run_name = run_name or run_name_tag or _generate_random_name()
            if not run_name_tag:
                tags.append(RunTag(key=MLFLOW_RUN_NAME, value=run_name))
            run = SqlRun(
                name=run_name,
                artifact_uri=artifact_location,
                run_uuid=run_id,
                experiment_id=experiment_id,
                source_type=SourceType.to_string(SourceType.UNKNOWN),
                source_name="",
                entry_point_name="",
                user_id=user_id,
                status=RunStatus.to_string(RunStatus.RUNNING),
                start_time=start_time,
                end_time=None,
                deleted_time=None,
                source_version="",
                lifecycle_stage=LifecycleStage.ACTIVE,
            )

            run.tags = [SqlTag(key=tag.key, value=tag.value) for tag in tags]
            session.add(run)

            run = run.to_mlflow_entity()
            inputs_list = self._get_run_inputs(session, [run_id])
            dataset_inputs = inputs_list[0] if inputs_list else []
            return Run(run.info, run.data, RunInputs(dataset_inputs=dataset_inputs))
```
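``create_run`` reconciles the ``run_name`` argument with any ``mlflow.runName`` tag: equal values pass, a missing argument falls back to the tag (or a generated name, which is then mirrored into the tags), and conflicting values raise ``INVALID_PARAMETER_VALUE``. Sketch (``store`` assumed to be a constructed ``SqlAlchemyStore``; ``start_time`` is a millisecond epoch):

```python
from mlflow.entities import RunTag
from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME

tags = [RunTag(key=MLFLOW_RUN_NAME, value="baseline")]
run = store.create_run("0", user_id="me", start_time=0, tags=tags, run_name="baseline")
print(run.info.run_name)  # "baseline": argument and tag agree

# run_name="other" with the same tag raises MlflowException(INVALID_PARAMETER_VALUE)
# per the check above; run_name=None would silently adopt "baseline" from the tag.
```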
```python
    def _get_run(self, session, run_uuid, eager=False):  # noqa: D417
        """
        Args:
            eager: If ``True``, eagerly loads the run's summary metrics (``latest_metrics``),
                params, and tags when fetching the run. If ``False``, these attributes
                are not eagerly loaded and will be loaded when their corresponding
                object properties are accessed from the resulting ``SqlRun`` object.
        """
        query_options = self._get_eager_run_query_options() if eager else []
        runs = (
            session.query(SqlRun).options(*query_options).filter(SqlRun.run_uuid == run_uuid).all()
        )

        if len(runs) == 0:
            raise MlflowException(f"Run with id={run_uuid} not found", RESOURCE_DOES_NOT_EXIST)
        if len(runs) > 1:
            raise MlflowException(
                f"Expected only 1 run with id={run_uuid}. Found {len(runs)}.",
                INVALID_STATE,
            )

        return runs[0]

    def _get_run_inputs(self, session, run_uuids):
        datasets_with_tags = (
            session.query(
                SqlInput.input_uuid,
                SqlInput.destination_id.label("run_uuid"),
                SqlDataset,
                SqlInputTag,
            )
            .select_from(SqlInput)
            .join(SqlDataset, SqlInput.source_id == SqlDataset.dataset_uuid)
            .outerjoin(SqlInputTag, SqlInputTag.input_uuid == SqlInput.input_uuid)
            .filter(SqlInput.destination_type == "RUN", SqlInput.destination_id.in_(run_uuids))
            .order_by("run_uuid")
        ).all()

        dataset_inputs_per_run = defaultdict(dict)
        for input_uuid, run_uuid, dataset_sql, tag_sql in datasets_with_tags:
            dataset_inputs = dataset_inputs_per_run[run_uuid]
            dataset_uuid = dataset_sql.dataset_uuid
            dataset_input = dataset_inputs.get(dataset_uuid)
            if dataset_input is None:
                dataset_entity = dataset_sql.to_mlflow_entity()
                dataset_input = DatasetInput(dataset=dataset_entity, tags=[])
                dataset_inputs[dataset_uuid] = dataset_input
            if tag_sql is not None:
                dataset_input.tags.append(tag_sql.to_mlflow_entity())
        return [list(dataset_inputs_per_run[run_uuid].values()) for run_uuid in run_uuids]
```
@staticmethod
|
614
|
+
def _get_eager_run_query_options():
|
615
|
+
"""
|
616
|
+
A list of SQLAlchemy query options that can be used to eagerly load the following
|
617
|
+
run attributes when fetching a run: ``latest_metrics``, ``params``, and ``tags``.
|
618
|
+
"""
|
619
|
+
return [
|
620
|
+
# Use a select in load rather than a joined load in order to minimize the memory
|
621
|
+
# overhead of the eager loading procedure. For more information about relationship
|
622
|
+
# loading techniques, see https://docs.sqlalchemy.org/en/13/orm/
|
623
|
+
# loading_relationships.html#relationship-loading-techniques
|
624
|
+
sqlalchemy.orm.selectinload(SqlRun.latest_metrics),
|
625
|
+
sqlalchemy.orm.selectinload(SqlRun.params),
|
626
|
+
sqlalchemy.orm.selectinload(SqlRun.tags),
|
627
|
+
]
|
628
|
+
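    # Illustrative sketch (an assumption, not part of the packaged file): the options above
    # turn relationship loading into a couple of extra SELECT ... IN queries at fetch time
    # instead of one lazy query per accessed relationship, e.g.:
    #
    #   runs = (
    #       session.query(SqlRun)
    #       .options(*SqlAlchemyStore._get_eager_run_query_options())
    #       .filter(SqlRun.experiment_id == "0")
    #       .all()
    #   )
    #   # runs[0].params is already populated; no further SQL is emitted here.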
    def _check_run_is_active(self, run):
        if run.lifecycle_stage != LifecycleStage.ACTIVE:
            raise MlflowException(
                (
                    f"The run {run.run_uuid} must be in the 'active' state. "
                    f"Current state is {run.lifecycle_stage}."
                ),
                INVALID_PARAMETER_VALUE,
            )

    def _check_experiment_is_active(self, experiment):
        if experiment.lifecycle_stage != LifecycleStage.ACTIVE:
            raise MlflowException(
                (
                    f"The experiment {experiment.experiment_id} must be in the 'active' state. "
                    f"Current state is {experiment.lifecycle_stage}."
                ),
                INVALID_PARAMETER_VALUE,
            )

    def update_run_info(self, run_id, run_status, end_time, run_name):
        with self.ManagedSessionMaker() as session:
            run = self._get_run(run_uuid=run_id, session=session)
            self._check_run_is_active(run)
            if run_status is not None:
                run.status = RunStatus.to_string(run_status)
            if end_time is not None:
                run.end_time = end_time
            if run_name:
                run.name = run_name
                run_name_tag = self._try_get_run_tag(session, run_id, MLFLOW_RUN_NAME)
                if run_name_tag is None:
                    run.tags.append(SqlTag(key=MLFLOW_RUN_NAME, value=run_name))
                else:
                    run_name_tag.value = run_name

            session.add(run)
            run = run.to_mlflow_entity()

            return run.info

    def _try_get_run_tag(self, session, run_id, tagKey, eager=False):
        query_options = self._get_eager_run_query_options() if eager else []
        return (
            session.query(SqlTag)
            .options(*query_options)
            .filter(SqlTag.run_uuid == run_id, SqlTag.key == tagKey)
            .one_or_none()
        )

    def get_run(self, run_id):
        with self.ManagedSessionMaker() as session:
            # Load the run with the specified id and eagerly load its summary metrics, params,
            # and tags. These attributes are referenced during the invocation of
            # ``run.to_mlflow_entity()``, so eager loading helps avoid additional database queries
            # that are otherwise executed at attribute access time under a lazy loading model.
            run = self._get_run(run_uuid=run_id, session=session, eager=True)
            mlflow_run = run.to_mlflow_entity()
            # Get the run inputs and add them to the run
            inputs = self._get_run_inputs(run_uuids=[run_id], session=session)[0]
            model_inputs = self._get_model_inputs(run_id, session)
            model_outputs = self._get_model_outputs(run_id, session)
            return Run(
                mlflow_run.info,
                mlflow_run.data,
                RunInputs(dataset_inputs=inputs, model_inputs=model_inputs),
                RunOutputs(model_outputs),
            )

    def restore_run(self, run_id):
        with self.ManagedSessionMaker() as session:
            run = self._get_run(run_uuid=run_id, session=session)
            run.lifecycle_stage = LifecycleStage.ACTIVE
            run.deleted_time = None
            session.add(run)

    def delete_run(self, run_id):
        with self.ManagedSessionMaker() as session:
            run = self._get_run(run_uuid=run_id, session=session)
            run.lifecycle_stage = LifecycleStage.DELETED
            run.deleted_time = get_current_time_millis()
            session.add(run)

    def _hard_delete_run(self, run_id):
        """
        Permanently delete a run (metadata, metrics, tags, and parameters).
        This is used by the ``mlflow gc`` command line and is not intended to be used elsewhere.
        """
        with self.ManagedSessionMaker() as session:
            run = self._get_run(run_uuid=run_id, session=session)
            session.delete(run)

    def _get_deleted_runs(self, older_than=0):
        """
        Get all deleted run ids.

        Args:
            older_than: Get runs that are older than this value, in milliseconds.
                Defaults to 0 ms, i.e., all deleted runs.
        """
        current_time = get_current_time_millis()
        with self.ManagedSessionMaker() as session:
            runs = (
                session.query(SqlRun)
                .filter(
                    SqlRun.lifecycle_stage == LifecycleStage.DELETED,
                    SqlRun.deleted_time <= (current_time - older_than),
                )
                .all()
            )
            return [run.run_uuid for run in runs]
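    # Illustrative sketch (an assumption, not part of the packaged file): the two-phase
    # deletion used by ``mlflow gc`` composes the helpers above. ``delete_run`` only
    # soft-deletes; ``_hard_delete_run`` removes the rows for good.
    #
    #   one_day_ms = 24 * 60 * 60 * 1000
    #   for run_id in store._get_deleted_runs(older_than=one_day_ms):
    #       store._hard_delete_run(run_id)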
    def log_metric(self, run_id, metric):
        # simply call _log_metrics and let it handle the rest
        self._log_metrics(run_id, [metric])
        self._log_model_metrics(run_id, [metric])

    def sanitize_metric_value(self, metric_value: float) -> tuple[bool, float]:
        """
        Returns a tuple of two values:
        - A boolean indicating whether the metric is NaN.
        - The metric value, which is set to 0 if the metric is NaN.
        """
        is_nan = math.isnan(metric_value)
        if is_nan:
            value = 0
        elif math.isinf(metric_value):
            # NB: SQL cannot represent infinities => we replace +/- Inf with the max/min
            # 64-bit float value
            value = 1.7976931348623157e308 if metric_value > 0 else -1.7976931348623157e308
        else:
            value = metric_value
        return is_nan, value
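    # Worked example (illustrative, not part of the packaged file): how the three
    # branches of sanitize_metric_value behave.
    #
    #   store.sanitize_metric_value(float("nan"))   # -> (True, 0)
    #   store.sanitize_metric_value(float("inf"))   # -> (False, 1.7976931348623157e308)
    #   store.sanitize_metric_value(float("-inf"))  # -> (False, -1.7976931348623157e308)
    #   store.sanitize_metric_value(0.5)            # -> (False, 0.5)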
    def _log_metrics(self, run_id, metrics):
        # Duplicate metric values are eliminated here to maintain
        # the same behavior as log_metric
        metric_instances = []
        seen = set()
        is_single_metric = len(metrics) == 1
        for idx, metric in enumerate(metrics):
            _validate_metric(
                metric.key,
                metric.value,
                metric.timestamp,
                metric.step,
                path="" if is_single_metric else f"metrics[{idx}]",
            )
            if metric not in seen:
                is_nan, value = self.sanitize_metric_value(metric.value)
                metric_instances.append(
                    SqlMetric(
                        run_uuid=run_id,
                        key=metric.key,
                        value=value,
                        timestamp=metric.timestamp,
                        step=metric.step,
                        is_nan=is_nan,
                    )
                )
                seen.add(metric)

        with self.ManagedSessionMaker() as session:
            run = self._get_run(run_uuid=run_id, session=session)
            self._check_run_is_active(run)

            def _insert_metrics(metric_instances):
                session.add_all(metric_instances)
                self._update_latest_metrics_if_necessary(metric_instances, session)
                session.commit()

            try:
                _insert_metrics(metric_instances)
            except sqlalchemy.exc.IntegrityError:
                # The primary key can be violated when logging a metric with the same value,
                # timestamp, step, and key as an existing metric within the same run.
                # Roll back the current session to make it usable for further transactions. In
                # the event of an error during "commit", a rollback is required in order to
                # continue using the session. In this case, we re-use the session to query
                # SqlMetric
                session.rollback()
                # Divide metric keys into batches of 100 to avoid loading too much metric
                # history data into memory at once
                metric_keys = [m.key for m in metric_instances]
                metric_key_batches = [
                    metric_keys[i : i + 100] for i in range(0, len(metric_keys), 100)
                ]
                for metric_key_batch in metric_key_batches:
                    # obtain the metric history corresponding to the given metrics
                    metric_history = (
                        session.query(SqlMetric)
                        .filter(
                            SqlMetric.run_uuid == run_id,
                            SqlMetric.key.in_(metric_key_batch),
                        )
                        .all()
                    )
                    # convert to a set of Metric instances to take advantage of their
                    # hashability, then obtain the metrics that were not logged earlier
                    # within this run_id
                    metric_history = {m.to_mlflow_entity() for m in metric_history}
                    non_existing_metrics = [
                        m for m in metric_instances if m.to_mlflow_entity() not in metric_history
                    ]
                    # if there are metrics whose insertion was rolled back even though they
                    # did not violate the primary key, log them
                    _insert_metrics(non_existing_metrics)
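    # Illustrative sketch (an assumption, not part of the packaged file): ``Metric``
    # entities are hashable, which is what the ``seen`` set and the recovery path above
    # rely on. Two entities with identical (key, value, timestamp, step) collapse to one:
    #
    #   m1 = Metric(key="loss", value=0.3, timestamp=1000, step=1)
    #   m2 = Metric(key="loss", value=0.3, timestamp=1000, step=1)
    #   assert m1 == m2 and len({m1, m2}) == 1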
    def _log_model_metrics(
        self,
        run_id: str,
        metrics: list[Metric],
        dataset_uuid: Optional[str] = None,
        experiment_id: Optional[str] = None,
    ) -> None:
        if not metrics:
            return

        metric_instances: list[SqlLoggedModelMetric] = []
        is_single_metric = len(metrics) == 1
        seen: set[Metric] = set()
        for idx, metric in enumerate(metrics):
            if metric.model_id is None:
                continue

            if metric in seen:
                continue
            seen.add(metric)

            _validate_metric(
                metric.key,
                metric.value,
                metric.timestamp,
                metric.step,
                path="" if is_single_metric else f"metrics[{idx}]",
            )
            is_nan, value = self.sanitize_metric_value(metric.value)
            metric_instances.append(
                SqlLoggedModelMetric(
                    model_id=metric.model_id,
                    metric_name=metric.key,
                    metric_timestamp_ms=metric.timestamp,
                    metric_step=metric.step,
                    metric_value=value,
                    experiment_id=experiment_id or _get_experiment_id(),
                    run_id=run_id,
                    dataset_uuid=dataset_uuid,
                    dataset_name=metric.dataset_name,
                    dataset_digest=metric.dataset_digest,
                )
            )

        with self.ManagedSessionMaker() as session:
            try:
                session.add_all(metric_instances)
                session.commit()
            except sqlalchemy.exc.IntegrityError:
                # The primary key can be violated when logging a metric with the same value,
                # timestamp, step, and key as an existing metric within the same run.
                session.rollback()
                metric_keys = [m.metric_name for m in metric_instances]
                metric_key_batches = (
                    metric_keys[i : i + 100] for i in range(0, len(metric_keys), 100)
                )
                for batch in metric_key_batches:
                    existing_metrics = (
                        session.query(SqlLoggedModelMetric)
                        .filter(
                            SqlLoggedModelMetric.run_id == run_id,
                            SqlLoggedModelMetric.metric_name.in_(batch),
                        )
                        .all()
                    )
                    existing_metrics = {m.to_mlflow_entity() for m in existing_metrics}
                    non_existing_metrics = [
                        m for m in metric_instances if m.to_mlflow_entity() not in existing_metrics
                    ]
                    session.add_all(non_existing_metrics)
    def _update_latest_metrics_if_necessary(self, logged_metrics, session):
        def _compare_metrics(metric_a, metric_b):
            """
            Returns:
                True if ``metric_a`` is strictly more recent than ``metric_b``, as determined
                by ``step``, ``timestamp``, and ``value``. False otherwise.
            """
            return (metric_a.step, metric_a.timestamp, metric_a.value) > (
                metric_b.step,
                metric_b.timestamp,
                metric_b.value,
            )

        def _overwrite_metric(new_metric, old_metric):
            """
            Writes the content of ``new_metric`` over ``old_metric``. The content comprises
            ``value``, ``step``, ``timestamp``, and ``is_nan``.

            Returns:
                old_metric with its content updated.
            """
            old_metric.value = new_metric.value
            old_metric.step = new_metric.step
            old_metric.timestamp = new_metric.timestamp
            old_metric.is_nan = new_metric.is_nan
            return old_metric

        if not logged_metrics:
            return

        # Fetch the latest metric value corresponding to the specified run_id and metric keys
        # and lock their associated rows for the remainder of the transaction in order to
        # ensure isolation
        latest_metrics = {}
        metric_keys = [m.key for m in logged_metrics]
        # Divide metric keys into batches of 500 to avoid binding too many parameters to the SQL
        # query, which may produce limit exceeded errors or poor performance on certain database
        # platforms
        metric_key_batches = [metric_keys[i : i + 500] for i in range(0, len(metric_keys), 500)]
        for metric_key_batch in metric_key_batches:
            # First, determine which metric keys are present in the database
            latest_metrics_key_records_from_db = (
                session.query(SqlLatestMetric.key)
                .filter(
                    SqlLatestMetric.run_uuid == logged_metrics[0].run_uuid,
                    SqlLatestMetric.key.in_(metric_key_batch),
                )
                .all()
            )
            # Then, take a write lock on the rows corresponding to metric keys that are present,
            # ensuring that they aren't modified by another transaction until they can be
            # compared to the metric values logged by this transaction while avoiding gap locking
            # and next-key locking which may otherwise occur when issuing a `SELECT FOR UPDATE`
            # against nonexistent rows
            if len(latest_metrics_key_records_from_db) > 0:
                latest_metric_keys_from_db = [
                    record[0] for record in latest_metrics_key_records_from_db
                ]
                latest_metrics_batch = (
                    session.query(SqlLatestMetric)
                    .filter(
                        SqlLatestMetric.run_uuid == logged_metrics[0].run_uuid,
                        SqlLatestMetric.key.in_(latest_metric_keys_from_db),
                    )
                    # Order by the metric run ID and key to ensure a consistent locking order
                    # across transactions, reducing deadlock likelihood
                    .order_by(SqlLatestMetric.run_uuid, SqlLatestMetric.key)
                    .with_for_update()
                    .all()
                )
                latest_metrics.update({m.key: m for m in latest_metrics_batch})

        # Iterate over all logged metrics and compare them with the corresponding
        # SqlLatestMetric entries.
        # If there is no SqlLatestMetric entry for the current metric key,
        # create a new SqlLatestMetric instance and put it in
        # new_latest_metric_dict so that it can be saved later.
        new_latest_metric_dict = {}
        for logged_metric in logged_metrics:
            latest_metric = latest_metrics.get(logged_metric.key)
            # A metric key can be passed more than once within the logged metrics,
            # with different step/timestamp/value combinations. However, SqlLatestMetric
            # entries are inserted only after this loop completes, so retrieve the
            # instances that were just created and use them for comparison.
            new_latest_metric = new_latest_metric_dict.get(logged_metric.key)

            # Create a new SqlLatestMetric instance, since neither a latest_metric
            # row nor a recently created instance exists
            if not latest_metric and not new_latest_metric:
                new_latest_metric = SqlLatestMetric(
                    run_uuid=logged_metric.run_uuid,
                    key=logged_metric.key,
                    value=logged_metric.value,
                    timestamp=logged_metric.timestamp,
                    step=logged_metric.step,
                    is_nan=logged_metric.is_nan,
                )
                new_latest_metric_dict[logged_metric.key] = new_latest_metric

            # There is no row, but a new instance was recently created;
            # update the recent instance in new_latest_metric_dict if
            # the metric comparison is successful
            elif not latest_metric and new_latest_metric:
                if _compare_metrics(logged_metric, new_latest_metric):
                    new_latest_metric = _overwrite_metric(logged_metric, new_latest_metric)
                    new_latest_metric_dict[logged_metric.key] = new_latest_metric

            # compare with the row
            elif _compare_metrics(logged_metric, latest_metric):
                # editing the attributes of latest_metric, which is a
                # SqlLatestMetric instance, will result in an UPDATE on the DB side
                latest_metric = _overwrite_metric(logged_metric, latest_metric)

        if new_latest_metric_dict:
            session.add_all(new_latest_metric_dict.values())
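    # Worked example (illustrative, not part of the packaged file): recency in
    # _compare_metrics is lexicographic over (step, timestamp, value), so a higher
    # step wins even with an older timestamp:
    #
    #   (3, 100, 0.1) > (2, 999, 0.9)  # True: step dominates
    #   (2, 100, 0.2) > (2, 100, 0.1)  # True: value breaks the tie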
    def get_metric_history(self, run_id, metric_key, max_results=None, page_token=None):
        """
        Return all logged values for a given metric.

        Args:
            run_id: Unique identifier for run.
            metric_key: Metric name within the run.
            max_results: The maximum number of results to return per page.
            page_token: Token indicating the page of metric history to fetch.

        Returns:
            A :py:class:`mlflow.store.entities.paged_list.PagedList` of
            :py:class:`mlflow.entities.Metric` entities if ``metric_key`` values
            have been logged to the ``run_id``, else an empty list.
        """
        with self.ManagedSessionMaker() as session:
            query = session.query(SqlMetric).filter_by(run_uuid=run_id, key=metric_key)

            # Parse the offset from page_token for pagination
            offset = SearchUtils.parse_start_offset_from_page_token(page_token)

            # Add an ORDER BY clause to satisfy the MSSQL requirement for OFFSET
            query = query.order_by(SqlMetric.timestamp, SqlMetric.step, SqlMetric.value)
            query = query.offset(offset)

            if max_results is not None:
                query = query.limit(max_results + 1)

            metrics = query.all()

            # Compute the next token if more results are available
            next_token = None
            if max_results is not None and len(metrics) == max_results + 1:
                final_offset = offset + max_results
                next_token = SearchUtils.create_page_token(final_offset)
                metrics = metrics[:max_results]

            return PagedList([metric.to_mlflow_entity() for metric in metrics], next_token)
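    # Usage sketch (illustrative, not part of the packaged file): draining the paginated
    # history with the token carried by PagedList. The run ID and key are assumptions.
    #
    #   history, token = [], None
    #   while True:
    #       page = store.get_metric_history("abc123", "loss", max_results=1000, page_token=token)
    #       history.extend(page)
    #       token = page.token
    #       if token is None:
    #           break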
    def get_metric_history_bulk(self, run_ids, metric_key, max_results):
        """
        Return all logged values for a given metric across the specified runs.

        Args:
            run_ids: Unique identifiers of the runs from which to fetch the metric histories for
                the specified key.
            metric_key: Metric name within the runs.
            max_results: The maximum number of results to return.

        Returns:
            A List of SqlAlchemyStore.MetricWithRunId objects if metric_key values have been logged
            to one or more of the specified run_ids, else an empty list. Results are sorted by run
            ID in lexicographically ascending order, followed by timestamp, step, and value in
            numerically ascending order.
        """
        # NB: The SQLAlchemyStore does not currently support pagination for this API.
        with self.ManagedSessionMaker() as session:
            metrics = (
                session.query(SqlMetric)
                .filter(
                    SqlMetric.key == metric_key,
                    SqlMetric.run_uuid.in_(run_ids),
                )
                .order_by(
                    SqlMetric.run_uuid,
                    SqlMetric.timestamp,
                    SqlMetric.step,
                    SqlMetric.value,
                )
                .limit(max_results)
                .all()
            )
            return [
                MetricWithRunId(
                    run_id=metric.run_uuid,
                    metric=metric.to_mlflow_entity(),
                )
                for metric in metrics
            ]

    def get_max_step_for_metric(self, run_id, metric_key):
        with self.ManagedSessionMaker() as session:
            max_step = (
                session.query(func.max(SqlMetric.step))
                .filter(SqlMetric.run_uuid == run_id, SqlMetric.key == metric_key)
                .scalar()
            )
            return max_step or 0

    def get_metric_history_bulk_interval_from_steps(self, run_id, metric_key, steps, max_results):
        with self.ManagedSessionMaker() as session:
            metrics = (
                session.query(SqlMetric)
                .filter(
                    SqlMetric.key == metric_key,
                    SqlMetric.run_uuid == run_id,
                    SqlMetric.step.in_(steps),
                )
                .order_by(
                    SqlMetric.run_uuid,
                    SqlMetric.step,
                    SqlMetric.timestamp,
                    SqlMetric.value,
                )
                .limit(max_results)
                .all()
            )
            return [
                MetricWithRunId(
                    run_id=metric.run_uuid,
                    metric=metric.to_mlflow_entity(),
                )
                for metric in metrics
            ]
    def _search_datasets(self, experiment_ids):
        """
        Return all dataset summaries associated with the given experiments.

        Args:
            experiment_ids: List of experiment ids to scope the search.

        Returns:
            A List of :py:class:`SqlAlchemyStore.DatasetSummary` entities.
        """

        MAX_DATASET_SUMMARIES_RESULTS = 1000
        with self.ManagedSessionMaker() as session:
            # Note that the join with the input tag table is a left join. This is required so that
            # if an input does not have the MLFLOW_DATASET_CONTEXT tag, we still return that entry
            # as part of the final result with the context set to None.
            summaries = (
                session.query(
                    SqlDataset.experiment_id,
                    SqlDataset.name,
                    SqlDataset.digest,
                    SqlInputTag.value,
                )
                .select_from(SqlDataset)
                .distinct()
                .join(SqlInput, SqlInput.source_id == SqlDataset.dataset_uuid)
                .join(
                    SqlInputTag,
                    and_(
                        SqlInput.input_uuid == SqlInputTag.input_uuid,
                        SqlInputTag.name == MLFLOW_DATASET_CONTEXT,
                    ),
                    isouter=True,
                )
                .filter(SqlDataset.experiment_id.in_(experiment_ids))
                .limit(MAX_DATASET_SUMMARIES_RESULTS)
                .all()
            )

            return [
                _DatasetSummary(
                    experiment_id=str(summary.experiment_id),
                    name=summary.name,
                    digest=summary.digest,
                    context=summary.value,
                )
                for summary in summaries
            ]
    def log_param(self, run_id, param):
        param = _validate_param(param.key, param.value)
        with self.ManagedSessionMaker() as session:
            run = self._get_run(run_uuid=run_id, session=session)
            self._check_run_is_active(run)
            # If we try to update the value of an existing param, this will fail
            # because it will try to create a row with the same run_uuid and param key.
            try:
                # This relies on the integrity checks of the params table.
                # TODO: Consider prior checks for null values, types, param name validation, etc.
                self._get_or_create(
                    model=SqlParam,
                    session=session,
                    run_uuid=run_id,
                    key=param.key,
                    value=param.value,
                )
                # Explicitly commit the session in order to catch potential integrity errors
                # while maintaining the current managed session scope ("commit" checks that
                # a transaction satisfies uniqueness constraints and throws integrity errors
                # when they are violated; "get_or_create()" does not perform these checks). It is
                # important that we maintain the same session scope because, in the case of
                # an integrity error, we want to examine the uniqueness of parameter values using
                # the same database state that the session uses during "commit". Creating a new
                # session synchronizes the state with the database. As a result, if the conflicting
                # parameter value were to be removed prior to the creation of a new session,
                # we would be unable to determine the cause of failure for the first session's
                # "commit" operation.
                session.commit()
            except sqlalchemy.exc.IntegrityError:
                # Roll back the current session to make it usable for further transactions. In the
                # event of an error during "commit", a rollback is required in order to continue
                # using the session. In this case, we re-use the session because the SqlRun, `run`,
                # is lazily evaluated during the invocation of `run.params`.
                session.rollback()
                existing_params = [p.value for p in run.params if p.key == param.key]
                if len(existing_params) > 0:
                    old_value = existing_params[0]
                    if old_value != param.value:
                        raise MlflowException(
                            "Changing param values is not allowed. Param with key='{}' was already"
                            " logged with value='{}' for run ID='{}'. Attempted logging new value"
                            " '{}'.".format(param.key, old_value, run_id, param.value),
                            INVALID_PARAMETER_VALUE,
                        )
                else:
                    raise
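    # Usage sketch (illustrative, not part of the packaged file): params are immutable,
    # so re-logging a key with a different value raises INVALID_PARAMETER_VALUE, while
    # re-logging the identical value is accepted as a no-op.
    #
    #   store.log_param(run_id, Param("lr", "0.01"))
    #   store.log_param(run_id, Param("lr", "0.01"))  # accepted (same value)
    #   store.log_param(run_id, Param("lr", "0.02"))  # raises MlflowException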
    def _log_params(self, run_id, params):
        if not params:
            return

        with self.ManagedSessionMaker() as session:
            run = self._get_run(run_uuid=run_id, session=session)
            self._check_run_is_active(run)
            existing_params = {p.key: p.value for p in run.params}
            new_params = []
            non_matching_params = []
            for param in params:
                if param.key in existing_params:
                    if param.value != existing_params[param.key]:
                        non_matching_params.append(
                            {
                                "key": param.key,
                                "old_value": existing_params[param.key],
                                "new_value": param.value,
                            }
                        )
                    continue
                new_params.append(SqlParam(run_uuid=run_id, key=param.key, value=param.value))

            if non_matching_params:
                raise MlflowException(
                    "Changing param values is not allowed. Params were already"
                    f" logged='{non_matching_params}' for run ID='{run_id}'.",
                    INVALID_PARAMETER_VALUE,
                )

            if not new_params:
                return

            session.add_all(new_params)
    def set_experiment_tag(self, experiment_id, tag):
        """
        Set a tag for the specified experiment.

        Args:
            experiment_id: String ID of the experiment.
            tag: ExperimentTag instance to log.
        """
        _validate_experiment_tag(tag.key, tag.value)
        with self.ManagedSessionMaker() as session:
            tag = _validate_tag(tag.key, tag.value)
            experiment = self._get_experiment(
                session, experiment_id, ViewType.ALL
            ).to_mlflow_entity()
            self._check_experiment_is_active(experiment)
            session.merge(
                SqlExperimentTag(experiment_id=experiment_id, key=tag.key, value=tag.value)
            )

    def set_tag(self, run_id, tag):
        """
        Set a tag on a run.

        Args:
            run_id: String ID of the run.
            tag: RunTag instance to log.
        """
        with self.ManagedSessionMaker() as session:
            tag = _validate_tag(tag.key, tag.value)
            run = self._get_run(run_uuid=run_id, session=session)
            self._check_run_is_active(run)
            if tag.key == MLFLOW_RUN_NAME:
                # NB: Updating the run info will set the tag; no need to do it twice.
                run_status = RunStatus.from_string(run.status)
                self.update_run_info(run_id, run_status, run.end_time, tag.value)
            else:
                session.merge(SqlTag(run_uuid=run_id, key=tag.key, value=tag.value))
    def _set_tags(self, run_id, tags):
        """
        Set multiple tags on a run.

        Args:
            run_id: String ID of the run.
            tags: List of RunTag instances to log.
        """
        if not tags:
            return

        tags = [_validate_tag(t.key, t.value, path=f"tags[{idx}]") for (idx, t) in enumerate(tags)]

        with self.ManagedSessionMaker() as session:
            run = self._get_run(run_uuid=run_id, session=session)
            self._check_run_is_active(run)

            def _try_insert_tags(attempt_number, max_retries):
                try:
                    current_tags = (
                        session.query(SqlTag)
                        .filter(
                            SqlTag.run_uuid == run_id,
                            SqlTag.key.in_([t.key for t in tags]),
                        )
                        .all()
                    )
                    current_tags = {t.key: t for t in current_tags}

                    new_tag_dict = {}
                    for tag in tags:
                        # NB: If the run name tag is explicitly set, update the run info attribute
                        # and do not resubmit the tag for overwrite, as the tag will be set within
                        # `set_tag()` with a call to `update_run_info()`
                        if tag.key == MLFLOW_RUN_NAME:
                            self.set_tag(run_id, tag)
                        else:
                            current_tag = current_tags.get(tag.key)
                            new_tag = new_tag_dict.get(tag.key)

                            # update the SqlTag if it is already present in the DB
                            if current_tag:
                                current_tag.value = tag.value
                                continue

                            # if a SqlTag instance is already present in `new_tag_dict`,
                            # this means that multiple tags with the same key were passed to
                            # `set_tags`.
                            # In this case, we resolve potential conflicts by updating the value
                            # of the existing instance to the value of `tag`
                            if new_tag:
                                new_tag.value = tag.value
                            # otherwise, put it into the dict
                            else:
                                new_tag = SqlTag(run_uuid=run_id, key=tag.key, value=tag.value)

                            new_tag_dict[tag.key] = new_tag

                    # finally, save the new entries to the DB
                    session.add_all(new_tag_dict.values())
                    session.commit()
                except sqlalchemy.exc.IntegrityError:
                    session.rollback()
                    # Two concurrent operations may attempt to insert the same tags;
                    # retry with exponential backoff.
                    if attempt_number > max_retries:
                        raise MlflowException(
                            "Failed to set tags within {} retries. Keys: {}".format(
                                max_retries, [t.key for t in tags]
                            )
                        )
                    sleep_duration = (2**attempt_number) - 1
                    sleep_duration += random.uniform(0, 1)
                    time.sleep(sleep_duration)
                    _try_insert_tags(attempt_number + 1, max_retries=max_retries)

            _try_insert_tags(attempt_number=0, max_retries=3)
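    # Worked example (illustrative, not part of the packaged file): the backoff above
    # sleeps (2**attempt - 1) seconds plus up to 1 second of jitter, so successive
    # retries wait roughly 0-1s, 1-2s, 3-4s, and 7-8s before giving up after 3 retries.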
    def delete_tag(self, run_id, key):
        """
        Delete a tag from a run. This is irreversible.

        Args:
            run_id: String ID of the run.
            key: Name of the tag.
        """
        with self.ManagedSessionMaker() as session:
            run = self._get_run(run_uuid=run_id, session=session)
            self._check_run_is_active(run)
            filtered_tags = session.query(SqlTag).filter_by(run_uuid=run_id, key=key).all()
            if len(filtered_tags) == 0:
                raise MlflowException(
                    f"No tag with name: {key} in run with id {run_id}",
                    error_code=RESOURCE_DOES_NOT_EXIST,
                )
            elif len(filtered_tags) > 1:
                raise MlflowException(
                    "Bad data in database - tags for a specific run must have "
                    "a single unique value. "
                    "See https://mlflow.org/docs/latest/tracking.html#adding-tags-to-runs",
                    error_code=INVALID_STATE,
                )
            session.delete(filtered_tags[0])
    def _search_runs(
        self,
        experiment_ids,
        filter_string,
        run_view_type,
        max_results,
        order_by,
        page_token,
    ):
        def compute_next_token(current_size):
            next_token = None
            if max_results == current_size:
                final_offset = offset + max_results
                next_token = SearchUtils.create_page_token(final_offset)

            return next_token

        self._validate_max_results_param(max_results, allow_null=True)

        stages = set(LifecycleStage.view_type_to_stages(run_view_type))

        with self.ManagedSessionMaker() as session:
            # Fetch the appropriate runs and eagerly load their summary metrics, params, and
            # tags. These run attributes are referenced during the invocation of
            # ``run.to_mlflow_entity()``, so eager loading helps avoid additional database queries
            # that are otherwise executed at attribute access time under a lazy loading model.
            parsed_filters = SearchUtils.parse_search_filter(filter_string)
            cases_orderby, parsed_orderby, sorting_joins = _get_orderby_clauses(order_by, session)

            stmt = select(SqlRun, *cases_orderby)
            (
                attribute_filters,
                non_attribute_filters,
                dataset_filters,
            ) = _get_sqlalchemy_filter_clauses(parsed_filters, session, self._get_dialect())
            for non_attr_filter in non_attribute_filters:
                stmt = stmt.join(non_attr_filter)
            for idx, dataset_filter in enumerate(dataset_filters):
                # need to reference the anon table in the join condition
                anon_table_name = f"anon_{idx + 1}"
                stmt = stmt.join(
                    dataset_filter,
                    text(f"runs.run_uuid = {anon_table_name}.destination_id"),
                )
            # Using an outer join is necessary here because we want to be able to sort
            # on a column (tag, metric, or param) without removing the rows that do not
            # have a value for this column (which is what an inner join would do).
            for j in sorting_joins:
                stmt = stmt.outerjoin(j)

            offset = SearchUtils.parse_start_offset_from_page_token(page_token)
            stmt = (
                stmt.distinct()
                .options(*self._get_eager_run_query_options())
                .filter(
                    SqlRun.experiment_id.in_(experiment_ids),
                    SqlRun.lifecycle_stage.in_(stages),
                    *attribute_filters,
                )
                .order_by(*parsed_orderby)
                .offset(offset)
                .limit(max_results)
            )
            queried_runs = session.execute(stmt).scalars(SqlRun).all()

            runs = [run.to_mlflow_entity() for run in queried_runs]
            run_ids = [run.info.run_id for run in runs]

            # add inputs to runs
            inputs = self._get_run_inputs(run_uuids=run_ids, session=session)
            runs_with_inputs = []
            for i, run in enumerate(runs):
                runs_with_inputs.append(
                    Run(run.info, run.data, RunInputs(dataset_inputs=inputs[i]))
                )

            next_page_token = compute_next_token(len(runs_with_inputs))

        return runs_with_inputs, next_page_token
    def log_batch(self, run_id, metrics, params, tags):
        _validate_run_id(run_id)
        metrics, params, tags = _validate_batch_log_data(metrics, params, tags)
        _validate_batch_log_limits(metrics, params, tags)
        _validate_param_keys_unique(params)

        with self.ManagedSessionMaker() as session:
            run = self._get_run(run_uuid=run_id, session=session)
            self._check_run_is_active(run)
            try:
                self._log_params(run_id, params)
                self._log_metrics(run_id, metrics)
                self._log_model_metrics(run_id, metrics)
                self._set_tags(run_id, tags)
            except MlflowException as e:
                raise e
            except Exception as e:
                raise MlflowException(e, INTERNAL_ERROR)
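    # Usage sketch (illustrative, not part of the packaged file): logging several
    # entities in one call; the batch limits are validated over the whole payload.
    #
    #   now = get_current_time_millis()
    #   store.log_batch(
    #       run_id,
    #       metrics=[Metric("loss", 0.3, now, step=1)],
    #       params=[Param("lr", "0.01")],
    #       tags=[RunTag("team", "search")],
    #   )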
    def record_logged_model(self, run_id, mlflow_model):
        from mlflow.models import Model

        if not isinstance(mlflow_model, Model):
            raise TypeError(
                f"Argument 'mlflow_model' should be mlflow.models.Model, got '{type(mlflow_model)}'"
            )
        model_dict = mlflow_model.get_tags_dict()
        with self.ManagedSessionMaker() as session:
            run = self._get_run(run_uuid=run_id, session=session)
            self._check_run_is_active(run)
            previous_tag = [t for t in run.tags if t.key == MLFLOW_LOGGED_MODELS]
            if previous_tag:
                value = json.dumps(json.loads(previous_tag[0].value) + [model_dict])
            else:
                value = json.dumps([model_dict])
            _validate_tag(MLFLOW_LOGGED_MODELS, value)
            session.merge(SqlTag(key=MLFLOW_LOGGED_MODELS, value=value, run_uuid=run_id))
    def log_inputs(
        self,
        run_id: str,
        datasets: Optional[list[DatasetInput]] = None,
        models: Optional[list[LoggedModelInput]] = None,
    ):
        """
        Log inputs, such as datasets, to the specified run.

        Args:
            run_id: String id for the run.
            datasets: List of :py:class:`mlflow.entities.DatasetInput` instances to log
                as inputs to the run.
            models: List of :py:class:`mlflow.entities.LoggedModelInput` instances to log
                as inputs to the run.

        Returns:
            None.
        """
        _validate_run_id(run_id)
        if datasets is not None:
            if not isinstance(datasets, list):
                raise TypeError(f"Argument 'datasets' should be a list, got '{type(datasets)}'")
            _validate_dataset_inputs(datasets)

        with self.ManagedSessionMaker() as session:
            run = self._get_run(run_uuid=run_id, session=session)
            experiment_id = run.experiment_id
            self._check_run_is_active(run)
            try:
                self._log_inputs_impl(experiment_id, run_id, datasets, models)
            except MlflowException as e:
                raise e
            except Exception as e:
                raise MlflowException(e, INTERNAL_ERROR)
    def _log_inputs_impl(
        self,
        experiment_id,
        run_id,
        dataset_inputs: Optional[list[DatasetInput]] = None,
        models: Optional[list[LoggedModelInput]] = None,
    ):
        dataset_inputs = dataset_inputs or []
        for dataset_input in dataset_inputs:
            if dataset_input.dataset is None:
                raise MlflowException(
                    "Dataset input must have a dataset associated with it.",
                    INTERNAL_ERROR,
                )

        # Dedup the dataset_inputs list if two dataset inputs have the same name and digest,
        # keeping the first occurrence
        name_digest_keys = {}
        for dataset_input in dataset_inputs:
            key = (dataset_input.dataset.name, dataset_input.dataset.digest)
            if key not in name_digest_keys:
                name_digest_keys[key] = dataset_input
        dataset_inputs = list(name_digest_keys.values())

        with self.ManagedSessionMaker() as session:
            dataset_names_to_check = [
                dataset_input.dataset.name for dataset_input in dataset_inputs
            ]
            dataset_digests_to_check = [
                dataset_input.dataset.digest for dataset_input in dataset_inputs
            ]
            # find all datasets with the same name and digest;
            # if the dataset already exists, use the existing dataset uuid
            existing_datasets = (
                session.query(SqlDataset)
                .filter(SqlDataset.name.in_(dataset_names_to_check))
                .filter(SqlDataset.digest.in_(dataset_digests_to_check))
                .all()
            )
            dataset_uuids = {}
            for existing_dataset in existing_datasets:
                dataset_uuids[(existing_dataset.name, existing_dataset.digest)] = (
                    existing_dataset.dataset_uuid
                )

            # collect all objects to write to the DB in a single list
            objs_to_write = []

            # add datasets to objs_to_write
            for dataset_input in dataset_inputs:
                if (
                    dataset_input.dataset.name,
                    dataset_input.dataset.digest,
                ) not in dataset_uuids:
                    new_dataset_uuid = uuid.uuid4().hex
                    dataset_uuids[(dataset_input.dataset.name, dataset_input.dataset.digest)] = (
                        new_dataset_uuid
                    )
                    objs_to_write.append(
                        SqlDataset(
                            dataset_uuid=new_dataset_uuid,
                            experiment_id=experiment_id,
                            name=dataset_input.dataset.name,
                            digest=dataset_input.dataset.digest,
                            dataset_source_type=dataset_input.dataset.source_type,
                            dataset_source=dataset_input.dataset.source,
                            dataset_schema=dataset_input.dataset.schema,
                            dataset_profile=dataset_input.dataset.profile,
                        )
                    )

            # find all inputs with the same source_id and destination_id;
            # if the input already exists, use the existing input uuid
            existing_inputs = (
                session.query(SqlInput)
                .filter(SqlInput.source_type == "DATASET")
                .filter(SqlInput.source_id.in_(dataset_uuids.values()))
                .filter(SqlInput.destination_type == "RUN")
                .filter(SqlInput.destination_id == run_id)
                .all()
            )
            input_uuids = {}
            for existing_input in existing_inputs:
                input_uuids[(existing_input.source_id, existing_input.destination_id)] = (
                    existing_input.input_uuid
                )

            # add input edges to objs_to_write
            for dataset_input in dataset_inputs:
                dataset_uuid = dataset_uuids[
                    (dataset_input.dataset.name, dataset_input.dataset.digest)
                ]
                if (dataset_uuid, run_id) not in input_uuids:
                    new_input_uuid = uuid.uuid4().hex
                    # key by (dataset_uuid, run_id) to match the membership check above
                    input_uuids[(dataset_uuid, run_id)] = new_input_uuid
                    objs_to_write.append(
                        SqlInput(
                            input_uuid=new_input_uuid,
                            source_type="DATASET",
                            source_id=dataset_uuid,
                            destination_type="RUN",
                            destination_id=run_id,
                        )
                    )
                    # add input tags to objs_to_write
                    for input_tag in dataset_input.tags:
                        objs_to_write.append(
                            SqlInputTag(
                                input_uuid=new_input_uuid,
                                name=input_tag.key,
                                value=input_tag.value,
                            )
                        )

            if models:
                for model in models:
                    session.merge(
                        SqlInput(
                            input_uuid=uuid.uuid4().hex,
                            source_type="RUN_INPUT",
                            source_id=run_id,
                            destination_type="MODEL_INPUT",
                            destination_id=model.model_id,
                        )
                    )

            session.add_all(objs_to_write)
    def log_outputs(self, run_id: str, models: list[LoggedModelOutput]):
        with self.ManagedSessionMaker() as session:
            run = self._get_run(run_uuid=run_id, session=session)
            self._check_run_is_active(run)
            session.add_all(
                SqlInput(
                    input_uuid=uuid.uuid4().hex,
                    source_type="RUN_OUTPUT",
                    source_id=run_id,
                    destination_type="MODEL_OUTPUT",
                    destination_id=model.model_id,
                    step=model.step,
                )
                for model in models
            )

    def _get_model_inputs(
        self,
        run_id: str,
        session: sqlalchemy.orm.Session,
    ) -> list[LoggedModelInput]:
        return [
            LoggedModelInput(model_id=input.destination_id)
            for input in (
                session.query(SqlInput)
                .filter(
                    SqlInput.source_type == "RUN_INPUT",
                    SqlInput.source_id == run_id,
                    SqlInput.destination_type == "MODEL_INPUT",
                )
                .all()
            )
        ]

    def _get_model_outputs(
        self,
        run_id: str,
        session: sqlalchemy.orm.Session,
    ) -> list[LoggedModelOutput]:
        return [
            LoggedModelOutput(model_id=output.destination_id, step=output.step)
            for output in session.query(SqlInput)
            .filter(
                SqlInput.source_type == "RUN_OUTPUT",
                SqlInput.source_id == run_id,
                SqlInput.destination_type == "MODEL_OUTPUT",
            )
            .all()
        ]
    #######################################################################################
    # Logged models
    #######################################################################################
    def create_logged_model(
        self,
        experiment_id: str,
        name: Optional[str] = None,
        source_run_id: Optional[str] = None,
        tags: Optional[list[LoggedModelTag]] = None,
        params: Optional[list[LoggedModelParameter]] = None,
        model_type: Optional[str] = None,
    ) -> LoggedModel:
        _validate_logged_model_name(name)
        with self.ManagedSessionMaker() as session:
            experiment = self.get_experiment(experiment_id)
            self._check_experiment_is_active(experiment)
            model_id = f"m-{str(uuid.uuid4()).replace('-', '')}"
            artifact_location = append_to_uri_path(
                experiment.artifact_location,
                SqlAlchemyStore.MODELS_FOLDER_NAME,
                model_id,
                SqlAlchemyStore.ARTIFACTS_FOLDER_NAME,
            )
            name = name or _generate_random_name()
            creation_timestamp = get_current_time_millis()
            logged_model = SqlLoggedModel(
                model_id=model_id,
                experiment_id=experiment_id,
                name=name,
                artifact_location=artifact_location,
                creation_timestamp_ms=creation_timestamp,
                last_updated_timestamp_ms=creation_timestamp,
                model_type=model_type,
                status=LoggedModelStatus.PENDING.to_int(),
                lifecycle_stage=LifecycleStage.ACTIVE,
                source_run_id=source_run_id,
            )
            session.add(logged_model)

            if params:
                session.add_all(
                    SqlLoggedModelParam(
                        model_id=logged_model.model_id,
                        experiment_id=experiment_id,
                        param_key=param.key,
                        param_value=param.value,
                    )
                    for param in params
                )

            if tags:
                session.add_all(
                    SqlLoggedModelTag(
                        model_id=logged_model.model_id,
                        experiment_id=experiment_id,
                        tag_key=tag.key,
                        tag_value=tag.value,
                    )
                    for tag in tags
                )

            session.commit()
            return logged_model.to_mlflow_entity()
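    # Usage sketch (illustrative, not part of the packaged file): the minimal
    # create -> finalize lifecycle for a logged model.
    #
    #   model = store.create_logged_model(experiment_id="0", model_type="sklearn")
    #   # ... upload artifacts to model.artifact_location ...
    #   model = store.finalize_logged_model(model.model_id, LoggedModelStatus.READY)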
    def log_logged_model_params(self, model_id: str, params: list[LoggedModelParameter]):
        with self.ManagedSessionMaker() as session:
            logged_model = session.query(SqlLoggedModel).get(model_id)
            if not logged_model:
                self._raise_model_not_found(model_id)

            session.add_all(
                SqlLoggedModelParam(
                    model_id=model_id,
                    experiment_id=logged_model.experiment_id,
                    param_key=param.key,
                    param_value=param.value,
                )
                for param in params
            )

    def _raise_model_not_found(self, model_id: str):
        raise MlflowException(
            f"Logged model with ID '{model_id}' not found.",
            RESOURCE_DOES_NOT_EXIST,
        )

    def get_logged_model(self, model_id: str) -> LoggedModel:
        with self.ManagedSessionMaker() as session:
            logged_model = (
                session.query(SqlLoggedModel)
                .filter(
                    SqlLoggedModel.model_id == model_id,
                    SqlLoggedModel.lifecycle_stage != LifecycleStage.DELETED,
                )
                .first()
            )
            if not logged_model:
                self._raise_model_not_found(model_id)

            return logged_model.to_mlflow_entity()

    def delete_logged_model(self, model_id):
        with self.ManagedSessionMaker() as session:
            logged_model = session.query(SqlLoggedModel).get(model_id)
            if not logged_model:
                self._raise_model_not_found(model_id)

            logged_model.lifecycle_stage = LifecycleStage.DELETED
            logged_model.last_updated_timestamp_ms = get_current_time_millis()
            session.commit()

    def finalize_logged_model(self, model_id: str, status: LoggedModelStatus) -> LoggedModel:
        with self.ManagedSessionMaker() as session:
            logged_model = session.query(SqlLoggedModel).get(model_id)
            if not logged_model:
                self._raise_model_not_found(model_id)

            logged_model.status = status.to_int()
            logged_model.last_updated_timestamp_ms = get_current_time_millis()
            session.commit()
            return logged_model.to_mlflow_entity()

    def set_logged_model_tags(self, model_id: str, tags: list[LoggedModelTag]) -> None:
        with self.ManagedSessionMaker() as session:
            logged_model = session.query(SqlLoggedModel).get(model_id)
            if not logged_model:
                self._raise_model_not_found(model_id)

            # TODO: Consider upserting tags in a single transaction for performance
            for tag in tags:
                session.merge(
                    SqlLoggedModelTag(
                        model_id=model_id,
                        experiment_id=logged_model.experiment_id,
                        tag_key=tag.key,
                        tag_value=tag.value,
                    )
                )

    def delete_logged_model_tag(self, model_id: str, key: str) -> None:
        with self.ManagedSessionMaker() as session:
            logged_model = session.query(SqlLoggedModel).get(model_id)
            if not logged_model:
                self._raise_model_not_found(model_id)

            count = (
                session.query(SqlLoggedModelTag)
                .filter(
                    SqlLoggedModelTag.model_id == model_id,
                    SqlLoggedModelTag.tag_key == key,
                )
                .delete()
            )
            if count == 0:
                raise MlflowException(
                    f"No tag with key {key!r} found for model with ID {model_id!r}.",
                    RESOURCE_DOES_NOT_EXIST,
                )
def _apply_order_by_search_logged_models(
|
1912
|
+
self,
|
1913
|
+
models: sqlalchemy.orm.Query,
|
1914
|
+
session: sqlalchemy.orm.Session,
|
1915
|
+
order_by: Optional[list[dict[str, Any]]] = None,
|
1916
|
+
) -> sqlalchemy.orm.Query:
|
1917
|
+
order_by_clauses = []
|
1918
|
+
has_creation_timestamp = False
|
1919
|
+
for ob in order_by or []:
|
1920
|
+
field_name = ob.get("field_name")
|
1921
|
+
ascending = ob.get("ascending", True)
|
1922
|
+
if "." not in field_name:
|
1923
|
+
name = SqlLoggedModel.ALIASES.get(field_name, field_name)
|
1924
|
+
if name == "creation_timestamp_ms":
|
1925
|
+
has_creation_timestamp = True
|
1926
|
+
try:
|
1927
|
+
col = getattr(SqlLoggedModel, name)
|
1928
|
+
except AttributeError:
|
1929
|
+
raise MlflowException.invalid_parameter_value(
|
1930
|
+
f"Invalid order by field name: {field_name}"
|
1931
|
+
)
|
1932
|
+
# Why not use `nulls_last`? Because it's not supported by all dialects (e.g., MySQL)
|
1933
|
+
order_by_clauses.extend(
|
1934
|
+
[
|
1935
|
+
# Sort nulls last
|
1936
|
+
sqlalchemy.case((col.is_(None), 1), else_=0).asc(),
|
1937
|
+
col.asc() if ascending else col.desc(),
|
1938
|
+
]
|
1939
|
+
)
|
1940
|
+
continue
|
1941
|
+
|
1942
|
+
entity, name = field_name.split(".", 1)
|
1943
|
+
# TODO: Support filtering by other entities such as params if needed
|
1944
|
+
if entity != "metrics":
|
1945
|
+
raise MlflowException.invalid_parameter_value(
|
1946
|
+
f"Invalid order by field name: {field_name}. Only metrics are supported."
|
1947
|
+
)
|
1948
|
+
|
1949
|
+
# Sub query to get the latest metrics value for each (model_id, metric_name) pair
|
1950
|
+
dataset_filter = []
|
1951
|
+
if dataset_name := ob.get("dataset_name"):
|
1952
|
+
dataset_filter.append(SqlLoggedModelMetric.dataset_name == dataset_name)
|
1953
|
+
if dataset_digest := ob.get("dataset_digest"):
|
1954
|
+
dataset_filter.append(SqlLoggedModelMetric.dataset_digest == dataset_digest)
|
1955
|
+
|
1956
|
+
subquery = (
|
1957
|
+
session.query(
|
1958
|
+
SqlLoggedModelMetric.model_id,
|
1959
|
+
SqlLoggedModelMetric.metric_value,
|
1960
|
+
func.rank()
|
1961
|
+
.over(
|
1962
|
+
partition_by=[
|
1963
|
+
SqlLoggedModelMetric.model_id,
|
1964
|
+
SqlLoggedModelMetric.metric_name,
|
1965
|
+
],
|
1966
|
+
order_by=[
|
1967
|
+
SqlLoggedModelMetric.metric_timestamp_ms.desc(),
|
1968
|
+
SqlLoggedModelMetric.metric_step.desc(),
|
1969
|
+
],
|
1970
|
+
)
|
1971
|
+
.label("rank"),
|
1972
|
+
)
|
1973
|
+
.filter(
|
1974
|
+
SqlLoggedModelMetric.metric_name == name,
|
1975
|
+
*dataset_filter,
|
1976
|
+
)
|
1977
|
+
.subquery()
|
1978
|
+
)
|
1979
|
+
subquery = select(subquery.c).where(subquery.c.rank == 1).subquery()
|
1980
|
+
|
1981
|
+
models = models.outerjoin(subquery)
|
1982
|
+
# Why not use `nulls_last`? Because it's not supported by all dialects (e.g., MySQL)
|
1983
|
+
order_by_clauses.extend(
|
1984
|
+
[
|
1985
|
+
# Sort nulls last
|
1986
|
+
sqlalchemy.case((subquery.c.metric_value.is_(None), 1), else_=0).asc(),
|
1987
|
+
subquery.c.metric_value.asc() if ascending else subquery.c.metric_value.desc(),
|
1988
|
+
]
|
1989
|
+
)
|
1990
|
+
|
1991
|
+
if not has_creation_timestamp:
|
1992
|
+
order_by_clauses.append(SqlLoggedModel.creation_timestamp_ms.desc())
|
1993
|
+
|
1994
|
+
return models.order_by(*order_by_clauses)
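
The `sqlalchemy.case(...)` pattern above stands in for `NULLS LAST`, which MySQL (among other dialects) does not support. A minimal, self-contained sketch of the same trick; the `models_tbl` table and `metric` column here are hypothetical stand-ins, not part of this package:

```python
import sqlalchemy as sa

metadata = sa.MetaData()
# Hypothetical table standing in for the logged-model tables above.
models_tbl = sa.Table("models", metadata, sa.Column("metric", sa.Float))

stmt = sa.select(models_tbl).order_by(
    # NULL rows get 1, non-NULL rows get 0; sorting that flag ascending puts NULLs last.
    sa.case((models_tbl.c.metric.is_(None), 1), else_=0).asc(),
    models_tbl.c.metric.desc(),
)
# Renders as: ORDER BY CASE WHEN (models.metric IS NULL) THEN 1 ELSE 0 END ASC,
#             models.metric DESC
print(stmt)
```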

    def _apply_filter_string_datasets_search_logged_models(
        self,
        models: sqlalchemy.orm.Query,
        session: sqlalchemy.orm.Session,
        experiment_ids: list[str],
        filter_string: Optional[str],
        datasets: Optional[list[dict[str, Any]]],
    ):
        from mlflow.utils.search_logged_model_utils import EntityType, parse_filter_string

        comparisons = parse_filter_string(filter_string)
        dialect = self._get_dialect()
        attr_filters: list[sqlalchemy.BinaryExpression] = []
        non_attr_filters: list[sqlalchemy.BinaryExpression] = []

        dataset_filters = []
        if datasets:
            for dataset in datasets:
                dataset_filter = SqlLoggedModelMetric.dataset_name == dataset["dataset_name"]
                if "dataset_digest" in dataset:
                    dataset_filter = dataset_filter & (
                        SqlLoggedModelMetric.dataset_digest == dataset["dataset_digest"]
                    )
                dataset_filters.append(dataset_filter)

        has_metric_filters = False
        for comp in comparisons:
            comp_func = SearchUtils.get_sql_comparison_func(comp.op, dialect)
            if comp.entity.type == EntityType.ATTRIBUTE:
                attr_filters.append(comp_func(getattr(SqlLoggedModel, comp.entity.key), comp.value))
            elif comp.entity.type == EntityType.METRIC:
                has_metric_filters = True
                metric_filters = [
                    SqlLoggedModelMetric.metric_name == comp.entity.key,
                    comp_func(SqlLoggedModelMetric.metric_value, comp.value),
                ]
                if dataset_filters:
                    metric_filters.append(sqlalchemy.or_(*dataset_filters))
                non_attr_filters.append(
                    session.query(SqlLoggedModelMetric).filter(*metric_filters).subquery()
                )
            elif comp.entity.type == EntityType.PARAM:
                non_attr_filters.append(
                    session.query(SqlLoggedModelParam)
                    .filter(
                        SqlLoggedModelParam.param_key == comp.entity.key,
                        comp_func(SqlLoggedModelParam.param_value, comp.value),
                    )
                    .subquery()
                )
            elif comp.entity.type == EntityType.TAG:
                non_attr_filters.append(
                    session.query(SqlLoggedModelTag)
                    .filter(
                        SqlLoggedModelTag.tag_key == comp.entity.key,
                        comp_func(SqlLoggedModelTag.tag_value, comp.value),
                    )
                    .subquery()
                )

        for f in non_attr_filters:
            models = models.join(f)

        # If there are dataset filters but no metric filters,
        # filter for models that have any metrics on the datasets
        if dataset_filters and not has_metric_filters:
            subquery = (
                session.query(SqlLoggedModelMetric.model_id)
                .filter(sqlalchemy.or_(*dataset_filters))
                .distinct()
                .subquery()
            )
            models = models.join(subquery)

        return models.filter(
            SqlLoggedModel.lifecycle_stage != LifecycleStage.DELETED,
            SqlLoggedModel.experiment_id.in_(experiment_ids),
            *attr_filters,
        )
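
For reference, the `datasets` argument consumed above is a plain list of dicts. A hypothetical payload, mirroring how `dataset_filters` is assembled:

```python
# Hypothetical `datasets` payload: `dataset_name` is required per entry,
# `dataset_digest` optionally pins an exact dataset version.
datasets = [
    {"dataset_name": "train"},
    {"dataset_name": "eval", "dataset_digest": "abc123"},
]
# Entries are OR-ed together (`sqlalchemy.or_(*dataset_filters)`) and then
# AND-ed into each metric filter, so a metric comparison only needs to hold
# on one of the listed datasets.
```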

    def search_logged_models(
        self,
        experiment_ids: list[str],
        filter_string: Optional[str] = None,
        datasets: Optional[list[DatasetFilter]] = None,
        max_results: Optional[int] = None,
        order_by: Optional[list[dict[str, Any]]] = None,
        page_token: Optional[str] = None,
    ) -> PagedList[LoggedModel]:
        if datasets and not all(d.get("dataset_name") for d in datasets):
            raise MlflowException(
                "`dataset_name` in the `datasets` clause must be specified.",
                INVALID_PARAMETER_VALUE,
            )
        if page_token:
            token = SearchLoggedModelsPaginationToken.decode(page_token)
            token.validate(experiment_ids, filter_string, order_by)
            offset = token.offset
        else:
            offset = 0

        max_results = max_results or SEARCH_LOGGED_MODEL_MAX_RESULTS_DEFAULT
        with self.ManagedSessionMaker() as session:
            models = session.query(SqlLoggedModel)
            models = self._apply_filter_string_datasets_search_logged_models(
                models, session, experiment_ids, filter_string, datasets
            )
            models = self._apply_order_by_search_logged_models(models, session, order_by)
            models = models.offset(offset).limit(max_results + 1).all()

            if len(models) > max_results:
                token = SearchLoggedModelsPaginationToken(
                    offset=offset + max_results,
                    experiment_ids=experiment_ids,
                    filter_string=filter_string,
                    order_by=order_by,
                ).encode()
            else:
                token = None

            return PagedList([lm.to_mlflow_entity() for lm in models[:max_results]], token=token)
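
Since `search_logged_models` returns a `PagedList` carrying the encoded offset token, a caller can drain every page with a loop like this sketch (`store` is assumed to be an initialized store instance; the filter and order-by values are illustrative):

```python
def iter_all_logged_models(store, experiment_ids):
    """Illustrative helper: yield every matching logged model across pages."""
    token = None
    while True:
        page = store.search_logged_models(
            experiment_ids,
            filter_string="metrics.accuracy > 0.9",
            order_by=[{"field_name": "metrics.accuracy", "ascending": False}],
            page_token=token,
        )
        yield from page
        token = page.token  # None once fewer than max_results + 1 rows were fetched
        if token is None:
            break
```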

    #######################################################################################
    # Below are Tracing APIs. We may refactor them to be in a separate class in the future.
    #######################################################################################
    def _get_trace_artifact_location_tag(self, experiment, trace_id: str) -> SqlTraceTag:
        # Trace data is stored as file artifacts regardless of the tracking backend choice.
        # We use subdirectory "/traces" under the experiment's artifact location to isolate
        # them from run artifacts.
        artifact_uri = append_to_uri_path(
            experiment.artifact_location,
            SqlAlchemyStore.TRACE_FOLDER_NAME,
            trace_id,
            SqlAlchemyStore.ARTIFACTS_FOLDER_NAME,
        )
        return SqlTraceTag(request_id=trace_id, key=MLFLOW_ARTIFACT_LOCATION, value=artifact_uri)

    def start_trace(self, trace_info: "TraceInfo") -> TraceInfo:
        """
        Create a trace using the V3 API format with a complete Trace object.

        Args:
            trace_info: The TraceInfo object to create in the backend.

        Returns:
            The created TraceInfo object from the backend.
        """
        with self.ManagedSessionMaker() as session:
            experiment = self.get_experiment(trace_info.experiment_id)
            self._check_experiment_is_active(experiment)

            # Use the provided trace_id
            trace_id = trace_info.trace_id

            # Create SqlTraceInfo with V3 fields directly
            sql_trace_info = SqlTraceInfo(
                request_id=trace_id,
                experiment_id=trace_info.experiment_id,
                timestamp_ms=trace_info.request_time,
                execution_time_ms=trace_info.execution_duration,
                status=trace_info.state.value,
                client_request_id=trace_info.client_request_id,
                request_preview=trace_info.request_preview,
                response_preview=trace_info.response_preview,
            )

            sql_trace_info.tags = [
                SqlTraceTag(request_id=trace_id, key=k, value=v) for k, v in trace_info.tags.items()
            ]
            sql_trace_info.tags.append(self._get_trace_artifact_location_tag(experiment, trace_id))

            sql_trace_info.request_metadata = [
                SqlTraceMetadata(request_id=trace_id, key=k, value=v)
                for k, v in trace_info.trace_metadata.items()
            ]
            session.add(sql_trace_info)
            return sql_trace_info.to_mlflow_entity()

    def get_trace_info(self, trace_id: str) -> TraceInfo:
        """
        Fetch the trace info for the given trace id.

        Args:
            trace_id: Unique string identifier of the trace.

        Returns:
            The TraceInfo object.
        """
        with self.ManagedSessionMaker() as session:
            sql_trace_info = self._get_sql_trace_info(session, trace_id)
            return sql_trace_info.to_mlflow_entity()

    def _get_sql_trace_info(self, session, trace_id) -> SqlTraceInfo:
        sql_trace_info = (
            session.query(SqlTraceInfo).filter(SqlTraceInfo.request_id == trace_id).one_or_none()
        )
        if sql_trace_info is None:
            raise MlflowException(
                f"Trace with ID '{trace_id}' not found.",
                RESOURCE_DOES_NOT_EXIST,
            )
        return sql_trace_info

    def search_traces(
        self,
        experiment_ids: list[str],
        filter_string: Optional[str] = None,
        max_results: int = SEARCH_TRACES_DEFAULT_MAX_RESULTS,
        order_by: Optional[list[str]] = None,
        page_token: Optional[str] = None,
        model_id: Optional[str] = None,
        sql_warehouse_id: Optional[str] = None,
    ) -> tuple[list[TraceInfo], Optional[str]]:
        """
        Return traces that match the given list of search expressions within the experiments.

        Args:
            experiment_ids: List of experiment ids to scope the search.
            filter_string: A search filter string.
            max_results: Maximum number of traces desired.
            order_by: List of order_by clauses.
            page_token: Token specifying the next page of results. It should be obtained from
                a ``search_traces`` call.
            model_id: If specified, search traces associated with the given model ID.
            sql_warehouse_id: Only used in Databricks. The ID of the SQL warehouse to use for
                searching traces in inference tables.

        Returns:
            A tuple of a list of :py:class:`TraceInfo <mlflow.entities.TraceInfo>` objects that
            satisfy the search expressions and a pagination token for the next page of results.
        """
        self._validate_max_results_param(max_results)

        with self.ManagedSessionMaker() as session:
            cases_orderby, parsed_orderby, sorting_joins = _get_orderby_clauses_for_search_traces(
                order_by or [], session
            )
            stmt = select(SqlTraceInfo, *cases_orderby)

            attribute_filters, non_attribute_filters = _get_filter_clauses_for_search_traces(
                filter_string, session, self._get_dialect()
            )
            for non_attr_filter in non_attribute_filters:
                stmt = stmt.join(non_attr_filter)

            # using an outer join is necessary here because we want to be able to sort
            # on a column (tag, metric or param) without removing the lines that
            # do not have a value for this column (which is what inner join would do)
            for j in sorting_joins:
                stmt = stmt.outerjoin(j)

            offset = SearchTraceUtils.parse_start_offset_from_page_token(page_token)
            stmt = (
                # NB: We don't need to distinct the results of joins because of the fact that
                # the right tables of the joins are unique on the join key, trace_id.
                # This is because the subquery that is joined on the right side is conditioned
                # by a key and value pair of tags/metadata, and the combination of key and
                # trace_id is unique in those tables.
                # Be careful when changing the query building logic, as it may break this
                # uniqueness property and require deduplication, which can be expensive.
                stmt.filter(
                    SqlTraceInfo.experiment_id.in_(experiment_ids),
                    *attribute_filters,
                )
                .order_by(*parsed_orderby)
                .offset(offset)
                .limit(max_results)
            )
            queried_traces = session.execute(stmt).scalars(SqlTraceInfo).all()
            trace_infos = [t.to_mlflow_entity() for t in queried_traces]

            # Compute next search token
            if max_results == len(trace_infos):
                final_offset = offset + max_results
                next_token = SearchTraceUtils.create_page_token(final_offset)
            else:
                next_token = None

            return trace_infos, next_token
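
Unlike `search_logged_models`, this method returns a plain `(trace_infos, next_token)` tuple rather than a `PagedList`, so paging looks like the following sketch (`store` and the filter/order-by strings are illustrative assumptions):

```python
def iter_all_traces(store, experiment_ids):
    """Illustrative helper: yield every matching TraceInfo across pages."""
    token = None
    while True:
        trace_infos, token = store.search_traces(
            experiment_ids,
            filter_string="tags.environment = 'production'",
            order_by=["timestamp_ms DESC"],
            page_token=token,
        )
        yield from trace_infos
        if token is None:  # a full page yields a token; a short page ends the scan
            break
```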

    def _validate_max_results_param(self, max_results: int, allow_null=False):
        if (not allow_null and max_results is None) or max_results < 1:
            raise MlflowException(
                f"Invalid value {max_results} for parameter 'max_results' supplied. It must be "
                f"a positive integer",
                INVALID_PARAMETER_VALUE,
            )

        if max_results > SEARCH_MAX_RESULTS_THRESHOLD:
            raise MlflowException(
                f"Invalid value {max_results} for parameter 'max_results' supplied. It must be at "
                f"most {SEARCH_MAX_RESULTS_THRESHOLD}",
                INVALID_PARAMETER_VALUE,
            )

    def set_trace_tag(self, trace_id: str, key: str, value: str):
        """
        Set a tag on the trace with the given trace_id.

        Args:
            trace_id: The ID of the trace.
            key: The string key of the tag.
            value: The string value of the tag.
        """
        with self.ManagedSessionMaker() as session:
            key, value = _validate_trace_tag(key, value)
            session.merge(SqlTraceTag(request_id=trace_id, key=key, value=value))

    def delete_trace_tag(self, trace_id: str, key: str):
        """
        Delete a tag on the trace with the given trace_id.

        Args:
            trace_id: The ID of the trace.
            key: The string key of the tag.
        """
        with self.ManagedSessionMaker() as session:
            tags = session.query(SqlTraceTag).filter_by(request_id=trace_id, key=key)
            if tags.count() == 0:
                raise MlflowException(
                    f"No trace tag with key '{key}' for trace with ID '{trace_id}'",
                    RESOURCE_DOES_NOT_EXIST,
                )
            tags.delete()

    def _delete_traces(
        self,
        experiment_id: str,
        max_timestamp_millis: Optional[int] = None,
        max_traces: Optional[int] = None,
        trace_ids: Optional[list[str]] = None,
    ) -> int:
        """
        Delete traces based on the specified criteria.

        Args:
            experiment_id: ID of the associated experiment.
            max_timestamp_millis: The maximum timestamp in milliseconds since the UNIX epoch for
                deleting traces. Traces older than or equal to this timestamp will be deleted.
            max_traces: The maximum number of traces to delete.
            trace_ids: A set of request IDs to delete.

        Returns:
            The number of traces deleted.
        """
        with self.ManagedSessionMaker() as session:
            filters = [SqlTraceInfo.experiment_id == experiment_id]
            if max_timestamp_millis:
                filters.append(SqlTraceInfo.timestamp_ms <= max_timestamp_millis)
            if trace_ids:
                filters.append(SqlTraceInfo.request_id.in_(trace_ids))
            if max_traces:
                filters.append(
                    SqlTraceInfo.request_id.in_(
                        session.query(SqlTraceInfo.request_id)
                        .filter(*filters)
                        # Delete the oldest traces first
                        .order_by(SqlTraceInfo.timestamp_ms)
                        .limit(max_traces)
                        .subquery()
                    )
                )

            return (
                session.query(SqlTraceInfo)
                .filter(and_(*filters))
                .delete(synchronize_session="fetch")
            )
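
The three criteria compose: the age cutoff and explicit IDs become plain filters, while `max_traces` is funneled through the oldest-first subquery so the cap never removes newer traces before older ones. Hypothetical calls against this private helper (assuming `store` is an initialized store instance):

```python
# Delete at most the 100 oldest traces at or before the cutoff timestamp.
n = store._delete_traces(
    experiment_id="0",
    max_timestamp_millis=1_700_000_000_000,
    max_traces=100,
)

# Delete two specific traces by request ID.
n = store._delete_traces(experiment_id="0", trace_ids=["tr-123", "tr-456"])
```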

    #######################################################################################
    # Below are legacy V2 Tracing APIs. DO NOT USE. Use the V3 APIs instead.
    #######################################################################################
    def deprecated_start_trace_v2(
        self,
        experiment_id: str,
        timestamp_ms: int,
        request_metadata: dict[str, str],
        tags: dict[str, str],
    ) -> TraceInfoV2:
        """
        DEPRECATED. DO NOT USE.

        Create an initial TraceInfo object in the database.

        Args:
            experiment_id: String id of the experiment for this run.
            timestamp_ms: Start time of the trace, in milliseconds since the UNIX epoch.
            request_metadata: Metadata of the trace.
            tags: Tags of the trace.

        Returns:
            The created TraceInfo object.
        """
        with self.ManagedSessionMaker() as session:
            experiment = self.get_experiment(experiment_id)
            self._check_experiment_is_active(experiment)

            request_id = generate_request_id_v2()
            trace_info = SqlTraceInfo(
                request_id=request_id,
                experiment_id=experiment_id,
                timestamp_ms=timestamp_ms,
                execution_time_ms=None,
                status=TraceStatus.IN_PROGRESS,
            )

            trace_info.tags = [SqlTraceTag(key=k, value=v) for k, v in tags.items()]
            trace_info.tags.append(self._get_trace_artifact_location_tag(experiment, request_id))

            trace_info.request_metadata = [
                SqlTraceMetadata(key=k, value=v) for k, v in request_metadata.items()
            ]
            session.add(trace_info)

            return TraceInfoV2.from_v3(trace_info.to_mlflow_entity())

    def deprecated_end_trace_v2(
        self,
        request_id: str,
        timestamp_ms: int,
        status: TraceStatus,
        request_metadata: dict[str, str],
        tags: dict[str, str],
    ) -> TraceInfoV2:
        """
        DEPRECATED. DO NOT USE.

        Update the TraceInfo object in the database with the completed trace info.

        Args:
            request_id: Unique string identifier of the trace.
            timestamp_ms: End time of the trace, in milliseconds. The execution time field
                in the TraceInfo will be calculated by subtracting the start time from this.
            status: Status of the trace.
            request_metadata: Metadata of the trace. This will be merged with the existing
                metadata logged during the start_trace call.
            tags: Tags of the trace. This will be merged with the existing tags logged
                during the start_trace or set_trace_tag calls.

        Returns:
            The updated TraceInfo object.
        """
        with self.ManagedSessionMaker() as session:
            sql_trace_info = self._get_sql_trace_info(session, request_id)
            trace_start_time_ms = sql_trace_info.timestamp_ms
            execution_time_ms = timestamp_ms - trace_start_time_ms
            sql_trace_info.execution_time_ms = execution_time_ms
            sql_trace_info.status = status
            session.merge(sql_trace_info)
            for k, v in request_metadata.items():
                session.merge(SqlTraceMetadata(request_id=request_id, key=k, value=v))
            for k, v in tags.items():
                session.merge(SqlTraceTag(request_id=request_id, key=k, value=v))
            return TraceInfoV2.from_v3(sql_trace_info.to_mlflow_entity())


def _get_sqlalchemy_filter_clauses(parsed, session, dialect):
    """
    Creates run attribute filters and subqueries that will be inner-joined to SqlRun to act as
    multi-clause filters and return them as a tuple.
    """
    attribute_filters = []
    non_attribute_filters = []
    dataset_filters = []

    for sql_statement in parsed:
        key_type = sql_statement.get("type")
        key_name = sql_statement.get("key")
        value = sql_statement.get("value")
        comparator = sql_statement.get("comparator").upper()

        key_name = SearchUtils.translate_key_alias(key_name)

        if SearchUtils.is_string_attribute(
            key_type, key_name, comparator
        ) or SearchUtils.is_numeric_attribute(key_type, key_name, comparator):
            if key_name == "run_name":
                # Treat "attributes.run_name == <value>" as "tags.`mlflow.runName` == <value>".
                # The name column in the runs table is empty for runs logged in MLflow <= 1.29.0.
                key_filter = SearchUtils.get_sql_comparison_func("=", dialect)(
                    SqlTag.key, MLFLOW_RUN_NAME
                )
                val_filter = SearchUtils.get_sql_comparison_func(comparator, dialect)(
                    SqlTag.value, value
                )
                non_attribute_filters.append(
                    session.query(SqlTag).filter(key_filter, val_filter).subquery()
                )
            else:
                attribute = getattr(SqlRun, SqlRun.get_attribute_name(key_name))
                attr_filter = SearchUtils.get_sql_comparison_func(comparator, dialect)(
                    attribute, value
                )
                attribute_filters.append(attr_filter)
        else:
            if SearchUtils.is_metric(key_type, comparator):
                entity = SqlLatestMetric
                value = float(value)
            elif SearchUtils.is_param(key_type, comparator):
                entity = SqlParam
            elif SearchUtils.is_tag(key_type, comparator):
                entity = SqlTag
            elif SearchUtils.is_dataset(key_type, comparator):
                entity = SqlDataset
            else:
                raise MlflowException(
                    f"Invalid search expression type '{key_type}'",
                    error_code=INVALID_PARAMETER_VALUE,
                )

            if entity == SqlDataset:
                if key_name == "context":
                    dataset_filters.append(
                        session.query(entity, SqlInput, SqlInputTag)
                        .join(SqlInput, SqlInput.source_id == SqlDataset.dataset_uuid)
                        .join(
                            SqlInputTag,
                            and_(
                                SqlInputTag.input_uuid == SqlInput.input_uuid,
                                SqlInputTag.name == MLFLOW_DATASET_CONTEXT,
                                SearchUtils.get_sql_comparison_func(comparator, dialect)(
                                    getattr(SqlInputTag, "value"), value
                                ),
                            ),
                        )
                        .subquery()
                    )
                else:
                    dataset_attr_filter = SearchUtils.get_sql_comparison_func(comparator, dialect)(
                        getattr(SqlDataset, key_name), value
                    )
                    dataset_filters.append(
                        session.query(entity, SqlInput)
                        .join(SqlInput, SqlInput.source_id == SqlDataset.dataset_uuid)
                        .filter(dataset_attr_filter)
                        .subquery()
                    )
            else:
                key_filter = SearchUtils.get_sql_comparison_func("=", dialect)(entity.key, key_name)
                val_filter = SearchUtils.get_sql_comparison_func(comparator, dialect)(
                    entity.value, value
                )
                non_attribute_filters.append(
                    session.query(entity).filter(key_filter, val_filter).subquery()
                )

    return attribute_filters, non_attribute_filters, dataset_filters
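
Each `sql_statement` consumed above is a dict emitted by the search-filter parser. For a filter such as `metrics.rmse < 0.5 and tags.stage = 'prod'`, the parsed input is roughly the following (shape inferred from the `.get(...)` accesses above, not an official contract):

```python
parsed = [
    {"type": "metric", "key": "rmse", "comparator": "<", "value": "0.5"},
    {"type": "tag", "key": "stage", "comparator": "=", "value": "prod"},
]
# The first entry becomes a SqlLatestMetric subquery in non_attribute_filters
# (with `value` coerced to float); the second becomes a SqlTag subquery keyed
# on "stage".
```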


def _get_orderby_clauses(order_by_list, session):
    """Sorts a set of runs based on their natural ordering and an overriding set of order_bys.
    Runs are naturally ordered first by start time descending, then by run id for tie-breaking.
    """

    clauses = []
    ordering_joins = []
    clause_id = 0
    observed_order_by_clauses = set()
    select_clauses = []
    # contrary to filters, it is not easily feasible to separately handle sorting
    # on attributes and on joined tables as we must keep all clauses in the same order
    if order_by_list:
        for order_by_clause in order_by_list:
            clause_id += 1
            (key_type, key, ascending) = SearchUtils.parse_order_by_for_search_runs(order_by_clause)
            key = SearchUtils.translate_key_alias(key)
            if SearchUtils.is_string_attribute(
                key_type, key, "="
            ) or SearchUtils.is_numeric_attribute(key_type, key, "="):
                order_value = getattr(SqlRun, SqlRun.get_attribute_name(key))
            else:
                if SearchUtils.is_metric(key_type, "="):  # any valid comparator
                    entity = SqlLatestMetric
                elif SearchUtils.is_tag(key_type, "="):
                    entity = SqlTag
                elif SearchUtils.is_param(key_type, "="):
                    entity = SqlParam
                else:
                    raise MlflowException(
                        f"Invalid identifier type '{key_type}'",
                        error_code=INVALID_PARAMETER_VALUE,
                    )

                # build a subquery first because we will join it in the main request so that the
                # metric we want to sort on is available when we apply the sorting clause
                subquery = session.query(entity).filter(entity.key == key).subquery()

                ordering_joins.append(subquery)
                order_value = subquery.c.value

            # MySQL does not support NULLS LAST expression, so we sort first by
            # presence of the field (and is_nan for metrics), then by actual value
            # As the subqueries are created independently and used later in the
            # same main query, the CASE WHEN columns need to have unique names to
            # avoid ambiguity
            if SearchUtils.is_metric(key_type, "="):
                case = sql.case(
                    # Ideally the use of "IS" is preferred here but owing to sqlalchemy
                    # translation in MSSQL we are forced to use "=" instead.
                    # These 2 options are functionally identical / unchanged because
                    # the column (is_nan) is not nullable. However it could become an issue
                    # if this precondition changes in the future.
                    (subquery.c.is_nan == sqlalchemy.true(), 1),
                    (order_value.is_(None), 2),
                    else_=0,
                ).label(f"clause_{clause_id}")

            else:  # other entities do not have an 'is_nan' field
                case = sql.case((order_value.is_(None), 1), else_=0).label(f"clause_{clause_id}")
            clauses.append(case.name)
            select_clauses.append(case)
            select_clauses.append(order_value)

            if (key_type, key) in observed_order_by_clauses:
                raise MlflowException(f"`order_by` contains duplicate fields: {order_by_list}")
            observed_order_by_clauses.add((key_type, key))

            if ascending:
                clauses.append(order_value)
            else:
                clauses.append(order_value.desc())

    if (
        SearchUtils._ATTRIBUTE_IDENTIFIER,
        SqlRun.start_time.key,
    ) not in observed_order_by_clauses:
        clauses.append(SqlRun.start_time.desc())
    clauses.append(SqlRun.run_uuid)
    return select_clauses, clauses, ordering_joins


def _get_search_experiments_filter_clauses(parsed_filters, dialect):
    attribute_filters = []
    non_attribute_filters = []
    for f in parsed_filters:
        type_ = f["type"]
        key = f["key"]
        comparator = f["comparator"]
        value = f["value"]
        if type_ == "attribute":
            if SearchExperimentsUtils.is_string_attribute(
                type_, key, comparator
            ) and comparator not in ("=", "!=", "LIKE", "ILIKE"):
                raise MlflowException.invalid_parameter_value(
                    f"Invalid comparator for string attribute: {comparator}"
                )
            if SearchExperimentsUtils.is_numeric_attribute(
                type_, key, comparator
            ) and comparator not in ("=", "!=", "<", "<=", ">", ">="):
                raise MlflowException.invalid_parameter_value(
                    f"Invalid comparator for numeric attribute: {comparator}"
                )
            attr = getattr(SqlExperiment, key)
            attr_filter = SearchUtils.get_sql_comparison_func(comparator, dialect)(attr, value)
            attribute_filters.append(attr_filter)
        elif type_ == "tag":
            if comparator not in ("=", "!=", "LIKE", "ILIKE"):
                raise MlflowException.invalid_parameter_value(
                    f"Invalid comparator for tag: {comparator}"
                )
            val_filter = SearchUtils.get_sql_comparison_func(comparator, dialect)(
                SqlExperimentTag.value, value
            )
            key_filter = SearchUtils.get_sql_comparison_func("=", dialect)(
                SqlExperimentTag.key, key
            )
            non_attribute_filters.append(
                select(SqlExperimentTag).filter(key_filter, val_filter).subquery()
            )
        else:
            raise MlflowException.invalid_parameter_value(f"Invalid token type: {type_}")

    return attribute_filters, non_attribute_filters


def _get_search_experiments_order_by_clauses(order_by):
    order_by_clauses = []
    for type_, key, ascending in map(
        SearchExperimentsUtils.parse_order_by_for_search_experiments,
        order_by or ["creation_time DESC", "experiment_id ASC"],
    ):
        if type_ == "attribute":
            order_by_clauses.append((getattr(SqlExperiment, key), ascending))
        else:
            raise MlflowException.invalid_parameter_value(f"Invalid order_by entity: {type_}")

    # Add a tie-breaker
    if not any(col == SqlExperiment.experiment_id for col, _ in order_by_clauses):
        order_by_clauses.append((SqlExperiment.experiment_id, False))

    return [col.asc() if ascending else col.desc() for col, ascending in order_by_clauses]
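
For example (behavior read off the code above, not an official API contract):

```python
# Explicit order_by: sort by name, then the appended experiment_id tie-breaker.
_get_search_experiments_order_by_clauses(["name ASC"])
# -> [SqlExperiment.name.asc(), SqlExperiment.experiment_id.desc()]

# No order_by: the defaults already include experiment_id, so no extra
# tie-breaker is appended.
_get_search_experiments_order_by_clauses(None)
# -> [SqlExperiment.creation_time.desc(), SqlExperiment.experiment_id.asc()]
```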


def _get_orderby_clauses_for_search_traces(order_by_list: list[str], session):
    """Sorts a set of traces based on their natural ordering and an overriding set of order_bys.
    Traces are ordered first by timestamp_ms descending, then by trace_id for tie-breaking.
    """
    clauses = []
    ordering_joins = []
    observed_order_by_clauses = set()
    select_clauses = []

    for clause_id, order_by_clause in enumerate(order_by_list):
        (key_type, key, ascending) = SearchTraceUtils.parse_order_by_for_search_traces(
            order_by_clause
        )

        if SearchTraceUtils.is_attribute(key_type, key, "="):
            order_value = getattr(SqlTraceInfo, key)
        else:
            if SearchTraceUtils.is_tag(key_type, "="):
                entity = SqlTraceTag
            elif SearchTraceUtils.is_request_metadata(key_type, "="):
                entity = SqlTraceMetadata
            else:
                raise MlflowException(
                    f"Invalid identifier type '{key_type}'",
                    error_code=INVALID_PARAMETER_VALUE,
                )
            # Tags and request metadata requires a join to the main table (trace_info)
            subquery = session.query(entity).filter(entity.key == key).subquery()
            ordering_joins.append(subquery)
            order_value = subquery.c.value

        case = sql.case((order_value.is_(None), 1), else_=0).label(f"clause_{clause_id}")
        clauses.append(case.name)
        select_clauses.append(case)
        select_clauses.append(order_value)

        if (key_type, key) in observed_order_by_clauses:
            raise MlflowException(f"`order_by` contains duplicate fields: {order_by_list}")
        observed_order_by_clauses.add((key_type, key))
        clauses.append(order_value if ascending else order_value.desc())

    # Add descending trace start time as default ordering and a tie-breaker
    for attr, ascending in [
        (SqlTraceInfo.timestamp_ms, False),
        (SqlTraceInfo.request_id, True),
    ]:
        if (
            SearchTraceUtils._ATTRIBUTE_IDENTIFIER,
            attr.key,
        ) not in observed_order_by_clauses:
            clauses.append(attr if ascending else attr.desc())
    return select_clauses, clauses, ordering_joins


def _get_filter_clauses_for_search_traces(filter_string, session, dialect):
    """
    Creates trace attribute filters and subqueries that will be inner-joined
    to SqlTraceInfo to act as multi-clause filters and return them as a tuple.
    """
    attribute_filters = []
    non_attribute_filters = []

    parsed_filters = SearchTraceUtils.parse_search_filter_for_search_traces(filter_string)
    for sql_statement in parsed_filters:
        key_type = sql_statement.get("type")
        key_name = sql_statement.get("key")
        value = sql_statement.get("value")
        comparator = sql_statement.get("comparator").upper()

        if SearchTraceUtils.is_attribute(key_type, key_name, comparator):
            attribute = getattr(SqlTraceInfo, key_name)
            attr_filter = SearchTraceUtils.get_sql_comparison_func(comparator, dialect)(
                attribute, value
            )
            attribute_filters.append(attr_filter)
        else:
            if SearchTraceUtils.is_tag(key_type, comparator):
                entity = SqlTraceTag
            elif SearchTraceUtils.is_request_metadata(key_type, comparator):
                entity = SqlTraceMetadata
            else:
                raise MlflowException(
                    f"Invalid search expression type '{key_type}'",
                    error_code=INVALID_PARAMETER_VALUE,
                )

            key_filter = SearchTraceUtils.get_sql_comparison_func("=", dialect)(
                entity.key, key_name
            )
            val_filter = SearchTraceUtils.get_sql_comparison_func(comparator, dialect)(
                entity.value, value
            )
            non_attribute_filters.append(
                session.query(entity).filter(key_filter, val_filter).subquery()
            )

    return attribute_filters, non_attribute_filters
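
Putting the two halves together, a caller such as `search_traces` above inner-joins each key/value subquery and applies attribute comparisons directly. Roughly, assuming `session` is an open SQLAlchemy session and `dialect` the backend dialect name:

```python
# Illustrative composition, mirroring the query assembly in search_traces:
attribute_filters, non_attribute_filters = _get_filter_clauses_for_search_traces(
    "tags.environment = 'production'", session, dialect
)
stmt = select(SqlTraceInfo)
for subquery in non_attribute_filters:
    stmt = stmt.join(subquery)           # inner join: the trace must carry the tag
stmt = stmt.filter(*attribute_filters)   # attribute comparisons apply in place
```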