genesis-flow 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genesis_flow-1.0.0.dist-info/METADATA +822 -0
- genesis_flow-1.0.0.dist-info/RECORD +645 -0
- genesis_flow-1.0.0.dist-info/WHEEL +5 -0
- genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
- genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
- genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
- mlflow/__init__.py +367 -0
- mlflow/__main__.py +3 -0
- mlflow/ag2/__init__.py +56 -0
- mlflow/ag2/ag2_logger.py +294 -0
- mlflow/anthropic/__init__.py +40 -0
- mlflow/anthropic/autolog.py +129 -0
- mlflow/anthropic/chat.py +144 -0
- mlflow/artifacts/__init__.py +268 -0
- mlflow/autogen/__init__.py +144 -0
- mlflow/autogen/chat.py +142 -0
- mlflow/azure/__init__.py +26 -0
- mlflow/azure/auth_handler.py +257 -0
- mlflow/azure/client.py +319 -0
- mlflow/azure/config.py +120 -0
- mlflow/azure/connection_factory.py +340 -0
- mlflow/azure/exceptions.py +27 -0
- mlflow/azure/stores.py +327 -0
- mlflow/azure/utils.py +183 -0
- mlflow/bedrock/__init__.py +45 -0
- mlflow/bedrock/_autolog.py +202 -0
- mlflow/bedrock/chat.py +122 -0
- mlflow/bedrock/stream.py +160 -0
- mlflow/bedrock/utils.py +43 -0
- mlflow/cli.py +707 -0
- mlflow/client.py +12 -0
- mlflow/config/__init__.py +56 -0
- mlflow/crewai/__init__.py +79 -0
- mlflow/crewai/autolog.py +253 -0
- mlflow/crewai/chat.py +29 -0
- mlflow/data/__init__.py +75 -0
- mlflow/data/artifact_dataset_sources.py +170 -0
- mlflow/data/code_dataset_source.py +40 -0
- mlflow/data/dataset.py +123 -0
- mlflow/data/dataset_registry.py +168 -0
- mlflow/data/dataset_source.py +110 -0
- mlflow/data/dataset_source_registry.py +219 -0
- mlflow/data/delta_dataset_source.py +167 -0
- mlflow/data/digest_utils.py +108 -0
- mlflow/data/evaluation_dataset.py +562 -0
- mlflow/data/filesystem_dataset_source.py +81 -0
- mlflow/data/http_dataset_source.py +145 -0
- mlflow/data/huggingface_dataset.py +258 -0
- mlflow/data/huggingface_dataset_source.py +118 -0
- mlflow/data/meta_dataset.py +104 -0
- mlflow/data/numpy_dataset.py +223 -0
- mlflow/data/pandas_dataset.py +231 -0
- mlflow/data/polars_dataset.py +352 -0
- mlflow/data/pyfunc_dataset_mixin.py +31 -0
- mlflow/data/schema.py +76 -0
- mlflow/data/sources.py +1 -0
- mlflow/data/spark_dataset.py +406 -0
- mlflow/data/spark_dataset_source.py +74 -0
- mlflow/data/spark_delta_utils.py +118 -0
- mlflow/data/tensorflow_dataset.py +350 -0
- mlflow/data/uc_volume_dataset_source.py +81 -0
- mlflow/db.py +27 -0
- mlflow/dspy/__init__.py +17 -0
- mlflow/dspy/autolog.py +197 -0
- mlflow/dspy/callback.py +398 -0
- mlflow/dspy/constant.py +1 -0
- mlflow/dspy/load.py +93 -0
- mlflow/dspy/save.py +393 -0
- mlflow/dspy/util.py +109 -0
- mlflow/dspy/wrapper.py +226 -0
- mlflow/entities/__init__.py +104 -0
- mlflow/entities/_mlflow_object.py +52 -0
- mlflow/entities/assessment.py +545 -0
- mlflow/entities/assessment_error.py +80 -0
- mlflow/entities/assessment_source.py +141 -0
- mlflow/entities/dataset.py +92 -0
- mlflow/entities/dataset_input.py +51 -0
- mlflow/entities/dataset_summary.py +62 -0
- mlflow/entities/document.py +48 -0
- mlflow/entities/experiment.py +109 -0
- mlflow/entities/experiment_tag.py +35 -0
- mlflow/entities/file_info.py +45 -0
- mlflow/entities/input_tag.py +35 -0
- mlflow/entities/lifecycle_stage.py +35 -0
- mlflow/entities/logged_model.py +228 -0
- mlflow/entities/logged_model_input.py +26 -0
- mlflow/entities/logged_model_output.py +32 -0
- mlflow/entities/logged_model_parameter.py +46 -0
- mlflow/entities/logged_model_status.py +74 -0
- mlflow/entities/logged_model_tag.py +33 -0
- mlflow/entities/metric.py +200 -0
- mlflow/entities/model_registry/__init__.py +29 -0
- mlflow/entities/model_registry/_model_registry_entity.py +13 -0
- mlflow/entities/model_registry/model_version.py +243 -0
- mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
- mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
- mlflow/entities/model_registry/model_version_search.py +25 -0
- mlflow/entities/model_registry/model_version_stages.py +25 -0
- mlflow/entities/model_registry/model_version_status.py +35 -0
- mlflow/entities/model_registry/model_version_tag.py +35 -0
- mlflow/entities/model_registry/prompt.py +73 -0
- mlflow/entities/model_registry/prompt_version.py +244 -0
- mlflow/entities/model_registry/registered_model.py +175 -0
- mlflow/entities/model_registry/registered_model_alias.py +35 -0
- mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
- mlflow/entities/model_registry/registered_model_search.py +25 -0
- mlflow/entities/model_registry/registered_model_tag.py +35 -0
- mlflow/entities/multipart_upload.py +74 -0
- mlflow/entities/param.py +49 -0
- mlflow/entities/run.py +97 -0
- mlflow/entities/run_data.py +84 -0
- mlflow/entities/run_info.py +188 -0
- mlflow/entities/run_inputs.py +59 -0
- mlflow/entities/run_outputs.py +43 -0
- mlflow/entities/run_status.py +41 -0
- mlflow/entities/run_tag.py +36 -0
- mlflow/entities/source_type.py +31 -0
- mlflow/entities/span.py +774 -0
- mlflow/entities/span_event.py +96 -0
- mlflow/entities/span_status.py +102 -0
- mlflow/entities/trace.py +317 -0
- mlflow/entities/trace_data.py +71 -0
- mlflow/entities/trace_info.py +220 -0
- mlflow/entities/trace_info_v2.py +162 -0
- mlflow/entities/trace_location.py +173 -0
- mlflow/entities/trace_state.py +39 -0
- mlflow/entities/trace_status.py +68 -0
- mlflow/entities/view_type.py +51 -0
- mlflow/environment_variables.py +866 -0
- mlflow/evaluation/__init__.py +16 -0
- mlflow/evaluation/assessment.py +369 -0
- mlflow/evaluation/evaluation.py +411 -0
- mlflow/evaluation/evaluation_tag.py +61 -0
- mlflow/evaluation/fluent.py +48 -0
- mlflow/evaluation/utils.py +201 -0
- mlflow/exceptions.py +213 -0
- mlflow/experiments.py +140 -0
- mlflow/gemini/__init__.py +81 -0
- mlflow/gemini/autolog.py +186 -0
- mlflow/gemini/chat.py +261 -0
- mlflow/genai/__init__.py +71 -0
- mlflow/genai/datasets/__init__.py +67 -0
- mlflow/genai/datasets/evaluation_dataset.py +131 -0
- mlflow/genai/evaluation/__init__.py +3 -0
- mlflow/genai/evaluation/base.py +411 -0
- mlflow/genai/evaluation/constant.py +23 -0
- mlflow/genai/evaluation/utils.py +244 -0
- mlflow/genai/judges/__init__.py +21 -0
- mlflow/genai/judges/databricks.py +404 -0
- mlflow/genai/label_schemas/__init__.py +153 -0
- mlflow/genai/label_schemas/label_schemas.py +209 -0
- mlflow/genai/labeling/__init__.py +159 -0
- mlflow/genai/labeling/labeling.py +250 -0
- mlflow/genai/optimize/__init__.py +13 -0
- mlflow/genai/optimize/base.py +198 -0
- mlflow/genai/optimize/optimizers/__init__.py +4 -0
- mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
- mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
- mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
- mlflow/genai/optimize/types.py +75 -0
- mlflow/genai/optimize/util.py +30 -0
- mlflow/genai/prompts/__init__.py +206 -0
- mlflow/genai/scheduled_scorers.py +431 -0
- mlflow/genai/scorers/__init__.py +26 -0
- mlflow/genai/scorers/base.py +492 -0
- mlflow/genai/scorers/builtin_scorers.py +765 -0
- mlflow/genai/scorers/scorer_utils.py +138 -0
- mlflow/genai/scorers/validation.py +165 -0
- mlflow/genai/utils/data_validation.py +146 -0
- mlflow/genai/utils/enum_utils.py +23 -0
- mlflow/genai/utils/trace_utils.py +211 -0
- mlflow/groq/__init__.py +42 -0
- mlflow/groq/_groq_autolog.py +74 -0
- mlflow/johnsnowlabs/__init__.py +888 -0
- mlflow/langchain/__init__.py +24 -0
- mlflow/langchain/api_request_parallel_processor.py +330 -0
- mlflow/langchain/autolog.py +147 -0
- mlflow/langchain/chat_agent_langgraph.py +340 -0
- mlflow/langchain/constant.py +1 -0
- mlflow/langchain/constants.py +1 -0
- mlflow/langchain/databricks_dependencies.py +444 -0
- mlflow/langchain/langchain_tracer.py +597 -0
- mlflow/langchain/model.py +919 -0
- mlflow/langchain/output_parsers.py +142 -0
- mlflow/langchain/retriever_chain.py +153 -0
- mlflow/langchain/runnables.py +527 -0
- mlflow/langchain/utils/chat.py +402 -0
- mlflow/langchain/utils/logging.py +671 -0
- mlflow/langchain/utils/serialization.py +36 -0
- mlflow/legacy_databricks_cli/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/provider.py +482 -0
- mlflow/litellm/__init__.py +175 -0
- mlflow/llama_index/__init__.py +22 -0
- mlflow/llama_index/autolog.py +55 -0
- mlflow/llama_index/chat.py +43 -0
- mlflow/llama_index/constant.py +1 -0
- mlflow/llama_index/model.py +577 -0
- mlflow/llama_index/pyfunc_wrapper.py +332 -0
- mlflow/llama_index/serialize_objects.py +188 -0
- mlflow/llama_index/tracer.py +561 -0
- mlflow/metrics/__init__.py +479 -0
- mlflow/metrics/base.py +39 -0
- mlflow/metrics/genai/__init__.py +25 -0
- mlflow/metrics/genai/base.py +101 -0
- mlflow/metrics/genai/genai_metric.py +771 -0
- mlflow/metrics/genai/metric_definitions.py +450 -0
- mlflow/metrics/genai/model_utils.py +371 -0
- mlflow/metrics/genai/prompt_template.py +68 -0
- mlflow/metrics/genai/prompts/__init__.py +0 -0
- mlflow/metrics/genai/prompts/v1.py +422 -0
- mlflow/metrics/genai/utils.py +6 -0
- mlflow/metrics/metric_definitions.py +619 -0
- mlflow/mismatch.py +34 -0
- mlflow/mistral/__init__.py +34 -0
- mlflow/mistral/autolog.py +71 -0
- mlflow/mistral/chat.py +135 -0
- mlflow/ml_package_versions.py +452 -0
- mlflow/models/__init__.py +97 -0
- mlflow/models/auth_policy.py +83 -0
- mlflow/models/cli.py +354 -0
- mlflow/models/container/__init__.py +294 -0
- mlflow/models/container/scoring_server/__init__.py +0 -0
- mlflow/models/container/scoring_server/nginx.conf +39 -0
- mlflow/models/dependencies_schemas.py +287 -0
- mlflow/models/display_utils.py +158 -0
- mlflow/models/docker_utils.py +211 -0
- mlflow/models/evaluation/__init__.py +23 -0
- mlflow/models/evaluation/_shap_patch.py +64 -0
- mlflow/models/evaluation/artifacts.py +194 -0
- mlflow/models/evaluation/base.py +1811 -0
- mlflow/models/evaluation/calibration_curve.py +109 -0
- mlflow/models/evaluation/default_evaluator.py +996 -0
- mlflow/models/evaluation/deprecated.py +23 -0
- mlflow/models/evaluation/evaluator_registry.py +80 -0
- mlflow/models/evaluation/evaluators/classifier.py +704 -0
- mlflow/models/evaluation/evaluators/default.py +233 -0
- mlflow/models/evaluation/evaluators/regressor.py +96 -0
- mlflow/models/evaluation/evaluators/shap.py +296 -0
- mlflow/models/evaluation/lift_curve.py +178 -0
- mlflow/models/evaluation/utils/metric.py +123 -0
- mlflow/models/evaluation/utils/trace.py +179 -0
- mlflow/models/evaluation/validation.py +434 -0
- mlflow/models/flavor_backend.py +93 -0
- mlflow/models/flavor_backend_registry.py +53 -0
- mlflow/models/model.py +1639 -0
- mlflow/models/model_config.py +150 -0
- mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
- mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
- mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
- mlflow/models/python_api.py +369 -0
- mlflow/models/rag_signatures.py +128 -0
- mlflow/models/resources.py +321 -0
- mlflow/models/signature.py +662 -0
- mlflow/models/utils.py +2054 -0
- mlflow/models/wheeled_model.py +280 -0
- mlflow/openai/__init__.py +57 -0
- mlflow/openai/_agent_tracer.py +364 -0
- mlflow/openai/api_request_parallel_processor.py +131 -0
- mlflow/openai/autolog.py +509 -0
- mlflow/openai/constant.py +1 -0
- mlflow/openai/model.py +824 -0
- mlflow/openai/utils/chat_schema.py +367 -0
- mlflow/optuna/__init__.py +3 -0
- mlflow/optuna/storage.py +646 -0
- mlflow/plugins/__init__.py +72 -0
- mlflow/plugins/base.py +358 -0
- mlflow/plugins/builtin/__init__.py +24 -0
- mlflow/plugins/builtin/pytorch_plugin.py +150 -0
- mlflow/plugins/builtin/sklearn_plugin.py +158 -0
- mlflow/plugins/builtin/transformers_plugin.py +187 -0
- mlflow/plugins/cli.py +321 -0
- mlflow/plugins/discovery.py +340 -0
- mlflow/plugins/manager.py +465 -0
- mlflow/plugins/registry.py +316 -0
- mlflow/plugins/templates/framework_plugin_template.py +329 -0
- mlflow/prompt/constants.py +20 -0
- mlflow/prompt/promptlab_model.py +197 -0
- mlflow/prompt/registry_utils.py +248 -0
- mlflow/promptflow/__init__.py +495 -0
- mlflow/protos/__init__.py +0 -0
- mlflow/protos/assessments_pb2.py +174 -0
- mlflow/protos/databricks_artifacts_pb2.py +489 -0
- mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
- mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
- mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
- mlflow/protos/databricks_pb2.py +267 -0
- mlflow/protos/databricks_trace_server_pb2.py +374 -0
- mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
- mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
- mlflow/protos/facet_feature_statistics_pb2.py +296 -0
- mlflow/protos/internal_pb2.py +77 -0
- mlflow/protos/mlflow_artifacts_pb2.py +336 -0
- mlflow/protos/model_registry_pb2.py +1073 -0
- mlflow/protos/scalapb/__init__.py +0 -0
- mlflow/protos/scalapb/scalapb_pb2.py +104 -0
- mlflow/protos/service_pb2.py +2600 -0
- mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
- mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
- mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
- mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
- mlflow/py.typed +0 -0
- mlflow/pydantic_ai/__init__.py +57 -0
- mlflow/pydantic_ai/autolog.py +173 -0
- mlflow/pyfunc/__init__.py +3844 -0
- mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
- mlflow/pyfunc/backend.py +523 -0
- mlflow/pyfunc/context.py +78 -0
- mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
- mlflow/pyfunc/loaders/__init__.py +7 -0
- mlflow/pyfunc/loaders/chat_agent.py +117 -0
- mlflow/pyfunc/loaders/chat_model.py +125 -0
- mlflow/pyfunc/loaders/code_model.py +31 -0
- mlflow/pyfunc/loaders/responses_agent.py +112 -0
- mlflow/pyfunc/mlserver.py +46 -0
- mlflow/pyfunc/model.py +1473 -0
- mlflow/pyfunc/scoring_server/__init__.py +604 -0
- mlflow/pyfunc/scoring_server/app.py +7 -0
- mlflow/pyfunc/scoring_server/client.py +146 -0
- mlflow/pyfunc/spark_model_cache.py +48 -0
- mlflow/pyfunc/stdin_server.py +44 -0
- mlflow/pyfunc/utils/__init__.py +3 -0
- mlflow/pyfunc/utils/data_validation.py +224 -0
- mlflow/pyfunc/utils/environment.py +22 -0
- mlflow/pyfunc/utils/input_converter.py +47 -0
- mlflow/pyfunc/utils/serving_data_parser.py +11 -0
- mlflow/pytorch/__init__.py +1171 -0
- mlflow/pytorch/_lightning_autolog.py +580 -0
- mlflow/pytorch/_pytorch_autolog.py +50 -0
- mlflow/pytorch/pickle_module.py +35 -0
- mlflow/rfunc/__init__.py +42 -0
- mlflow/rfunc/backend.py +134 -0
- mlflow/runs.py +89 -0
- mlflow/server/__init__.py +302 -0
- mlflow/server/auth/__init__.py +1224 -0
- mlflow/server/auth/__main__.py +4 -0
- mlflow/server/auth/basic_auth.ini +6 -0
- mlflow/server/auth/cli.py +11 -0
- mlflow/server/auth/client.py +537 -0
- mlflow/server/auth/config.py +34 -0
- mlflow/server/auth/db/__init__.py +0 -0
- mlflow/server/auth/db/cli.py +18 -0
- mlflow/server/auth/db/migrations/__init__.py +0 -0
- mlflow/server/auth/db/migrations/alembic.ini +110 -0
- mlflow/server/auth/db/migrations/env.py +76 -0
- mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
- mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
- mlflow/server/auth/db/models.py +67 -0
- mlflow/server/auth/db/utils.py +37 -0
- mlflow/server/auth/entities.py +165 -0
- mlflow/server/auth/logo.py +14 -0
- mlflow/server/auth/permissions.py +65 -0
- mlflow/server/auth/routes.py +18 -0
- mlflow/server/auth/sqlalchemy_store.py +263 -0
- mlflow/server/graphql/__init__.py +0 -0
- mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
- mlflow/server/graphql/graphql_custom_scalars.py +24 -0
- mlflow/server/graphql/graphql_errors.py +15 -0
- mlflow/server/graphql/graphql_no_batching.py +89 -0
- mlflow/server/graphql/graphql_schema_extensions.py +74 -0
- mlflow/server/handlers.py +3217 -0
- mlflow/server/prometheus_exporter.py +17 -0
- mlflow/server/validation.py +30 -0
- mlflow/shap/__init__.py +691 -0
- mlflow/sklearn/__init__.py +1994 -0
- mlflow/sklearn/utils.py +1041 -0
- mlflow/smolagents/__init__.py +66 -0
- mlflow/smolagents/autolog.py +139 -0
- mlflow/smolagents/chat.py +29 -0
- mlflow/store/__init__.py +10 -0
- mlflow/store/_unity_catalog/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/constants.py +2 -0
- mlflow/store/_unity_catalog/registry/__init__.py +6 -0
- mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
- mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
- mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
- mlflow/store/_unity_catalog/registry/utils.py +121 -0
- mlflow/store/artifact/__init__.py +0 -0
- mlflow/store/artifact/artifact_repo.py +472 -0
- mlflow/store/artifact/artifact_repository_registry.py +154 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
- mlflow/store/artifact/cli.py +141 -0
- mlflow/store/artifact/cloud_artifact_repo.py +332 -0
- mlflow/store/artifact/databricks_artifact_repo.py +729 -0
- mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
- mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
- mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
- mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
- mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
- mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
- mlflow/store/artifact/ftp_artifact_repo.py +132 -0
- mlflow/store/artifact/gcs_artifact_repo.py +296 -0
- mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
- mlflow/store/artifact/http_artifact_repo.py +218 -0
- mlflow/store/artifact/local_artifact_repo.py +142 -0
- mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
- mlflow/store/artifact/models_artifact_repo.py +259 -0
- mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
- mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
- mlflow/store/artifact/r2_artifact_repo.py +70 -0
- mlflow/store/artifact/runs_artifact_repo.py +265 -0
- mlflow/store/artifact/s3_artifact_repo.py +330 -0
- mlflow/store/artifact/sftp_artifact_repo.py +141 -0
- mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
- mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
- mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
- mlflow/store/artifact/utils/__init__.py +0 -0
- mlflow/store/artifact/utils/models.py +148 -0
- mlflow/store/db/__init__.py +0 -0
- mlflow/store/db/base_sql_model.py +3 -0
- mlflow/store/db/db_types.py +10 -0
- mlflow/store/db/utils.py +314 -0
- mlflow/store/db_migrations/__init__.py +0 -0
- mlflow/store/db_migrations/alembic.ini +74 -0
- mlflow/store/db_migrations/env.py +84 -0
- mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
- mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
- mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
- mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
- mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
- mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
- mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
- mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
- mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
- mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
- mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
- mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
- mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
- mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
- mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
- mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
- mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
- mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
- mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
- mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
- mlflow/store/db_migrations/versions/__init__.py +0 -0
- mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
- mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
- mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
- mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
- mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
- mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
- mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
- mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
- mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
- mlflow/store/entities/__init__.py +3 -0
- mlflow/store/entities/paged_list.py +18 -0
- mlflow/store/model_registry/__init__.py +10 -0
- mlflow/store/model_registry/abstract_store.py +1081 -0
- mlflow/store/model_registry/base_rest_store.py +44 -0
- mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
- mlflow/store/model_registry/dbmodels/__init__.py +0 -0
- mlflow/store/model_registry/dbmodels/models.py +206 -0
- mlflow/store/model_registry/file_store.py +1091 -0
- mlflow/store/model_registry/rest_store.py +481 -0
- mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
- mlflow/store/tracking/__init__.py +23 -0
- mlflow/store/tracking/abstract_store.py +816 -0
- mlflow/store/tracking/dbmodels/__init__.py +0 -0
- mlflow/store/tracking/dbmodels/initial_models.py +243 -0
- mlflow/store/tracking/dbmodels/models.py +1073 -0
- mlflow/store/tracking/file_store.py +2438 -0
- mlflow/store/tracking/postgres_managed_identity.py +146 -0
- mlflow/store/tracking/rest_store.py +1131 -0
- mlflow/store/tracking/sqlalchemy_store.py +2785 -0
- mlflow/system_metrics/__init__.py +61 -0
- mlflow/system_metrics/metrics/__init__.py +0 -0
- mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
- mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
- mlflow/system_metrics/metrics/disk_monitor.py +21 -0
- mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
- mlflow/system_metrics/metrics/network_monitor.py +34 -0
- mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
- mlflow/system_metrics/system_metrics_monitor.py +198 -0
- mlflow/tracing/__init__.py +16 -0
- mlflow/tracing/assessment.py +356 -0
- mlflow/tracing/client.py +531 -0
- mlflow/tracing/config.py +125 -0
- mlflow/tracing/constant.py +105 -0
- mlflow/tracing/destination.py +81 -0
- mlflow/tracing/display/__init__.py +40 -0
- mlflow/tracing/display/display_handler.py +196 -0
- mlflow/tracing/export/async_export_queue.py +186 -0
- mlflow/tracing/export/inference_table.py +138 -0
- mlflow/tracing/export/mlflow_v3.py +137 -0
- mlflow/tracing/export/utils.py +70 -0
- mlflow/tracing/fluent.py +1417 -0
- mlflow/tracing/processor/base_mlflow.py +199 -0
- mlflow/tracing/processor/inference_table.py +175 -0
- mlflow/tracing/processor/mlflow_v3.py +47 -0
- mlflow/tracing/processor/otel.py +73 -0
- mlflow/tracing/provider.py +487 -0
- mlflow/tracing/trace_manager.py +200 -0
- mlflow/tracing/utils/__init__.py +616 -0
- mlflow/tracing/utils/artifact_utils.py +28 -0
- mlflow/tracing/utils/copy.py +55 -0
- mlflow/tracing/utils/environment.py +55 -0
- mlflow/tracing/utils/exception.py +21 -0
- mlflow/tracing/utils/once.py +35 -0
- mlflow/tracing/utils/otlp.py +63 -0
- mlflow/tracing/utils/processor.py +54 -0
- mlflow/tracing/utils/search.py +292 -0
- mlflow/tracing/utils/timeout.py +250 -0
- mlflow/tracing/utils/token.py +19 -0
- mlflow/tracing/utils/truncation.py +124 -0
- mlflow/tracing/utils/warning.py +76 -0
- mlflow/tracking/__init__.py +39 -0
- mlflow/tracking/_model_registry/__init__.py +1 -0
- mlflow/tracking/_model_registry/client.py +764 -0
- mlflow/tracking/_model_registry/fluent.py +853 -0
- mlflow/tracking/_model_registry/registry.py +67 -0
- mlflow/tracking/_model_registry/utils.py +251 -0
- mlflow/tracking/_tracking_service/__init__.py +0 -0
- mlflow/tracking/_tracking_service/client.py +883 -0
- mlflow/tracking/_tracking_service/registry.py +56 -0
- mlflow/tracking/_tracking_service/utils.py +275 -0
- mlflow/tracking/artifact_utils.py +179 -0
- mlflow/tracking/client.py +5900 -0
- mlflow/tracking/context/__init__.py +0 -0
- mlflow/tracking/context/abstract_context.py +35 -0
- mlflow/tracking/context/databricks_cluster_context.py +15 -0
- mlflow/tracking/context/databricks_command_context.py +15 -0
- mlflow/tracking/context/databricks_job_context.py +49 -0
- mlflow/tracking/context/databricks_notebook_context.py +41 -0
- mlflow/tracking/context/databricks_repo_context.py +43 -0
- mlflow/tracking/context/default_context.py +51 -0
- mlflow/tracking/context/git_context.py +32 -0
- mlflow/tracking/context/registry.py +98 -0
- mlflow/tracking/context/system_environment_context.py +15 -0
- mlflow/tracking/default_experiment/__init__.py +1 -0
- mlflow/tracking/default_experiment/abstract_context.py +43 -0
- mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
- mlflow/tracking/default_experiment/registry.py +75 -0
- mlflow/tracking/fluent.py +3595 -0
- mlflow/tracking/metric_value_conversion_utils.py +93 -0
- mlflow/tracking/multimedia.py +206 -0
- mlflow/tracking/registry.py +86 -0
- mlflow/tracking/request_auth/__init__.py +0 -0
- mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
- mlflow/tracking/request_auth/registry.py +60 -0
- mlflow/tracking/request_header/__init__.py +0 -0
- mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
- mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
- mlflow/tracking/request_header/default_request_header_provider.py +17 -0
- mlflow/tracking/request_header/registry.py +79 -0
- mlflow/transformers/__init__.py +2982 -0
- mlflow/transformers/flavor_config.py +258 -0
- mlflow/transformers/hub_utils.py +83 -0
- mlflow/transformers/llm_inference_utils.py +468 -0
- mlflow/transformers/model_io.py +301 -0
- mlflow/transformers/peft.py +51 -0
- mlflow/transformers/signature.py +183 -0
- mlflow/transformers/torch_utils.py +55 -0
- mlflow/types/__init__.py +21 -0
- mlflow/types/agent.py +270 -0
- mlflow/types/chat.py +240 -0
- mlflow/types/llm.py +935 -0
- mlflow/types/responses.py +139 -0
- mlflow/types/responses_helpers.py +416 -0
- mlflow/types/schema.py +1505 -0
- mlflow/types/type_hints.py +647 -0
- mlflow/types/utils.py +753 -0
- mlflow/utils/__init__.py +283 -0
- mlflow/utils/_capture_modules.py +256 -0
- mlflow/utils/_capture_transformers_modules.py +75 -0
- mlflow/utils/_spark_utils.py +201 -0
- mlflow/utils/_unity_catalog_oss_utils.py +97 -0
- mlflow/utils/_unity_catalog_utils.py +479 -0
- mlflow/utils/annotations.py +218 -0
- mlflow/utils/arguments_utils.py +16 -0
- mlflow/utils/async_logging/__init__.py +1 -0
- mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
- mlflow/utils/async_logging/async_logging_queue.py +366 -0
- mlflow/utils/async_logging/run_artifact.py +38 -0
- mlflow/utils/async_logging/run_batch.py +58 -0
- mlflow/utils/async_logging/run_operations.py +49 -0
- mlflow/utils/autologging_utils/__init__.py +737 -0
- mlflow/utils/autologging_utils/client.py +432 -0
- mlflow/utils/autologging_utils/config.py +33 -0
- mlflow/utils/autologging_utils/events.py +294 -0
- mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
- mlflow/utils/autologging_utils/metrics_queue.py +71 -0
- mlflow/utils/autologging_utils/safety.py +1104 -0
- mlflow/utils/autologging_utils/versioning.py +95 -0
- mlflow/utils/checkpoint_utils.py +206 -0
- mlflow/utils/class_utils.py +6 -0
- mlflow/utils/cli_args.py +257 -0
- mlflow/utils/conda.py +354 -0
- mlflow/utils/credentials.py +231 -0
- mlflow/utils/data_utils.py +17 -0
- mlflow/utils/databricks_utils.py +1436 -0
- mlflow/utils/docstring_utils.py +477 -0
- mlflow/utils/doctor.py +133 -0
- mlflow/utils/download_cloud_file_chunk.py +43 -0
- mlflow/utils/env_manager.py +16 -0
- mlflow/utils/env_pack.py +131 -0
- mlflow/utils/environment.py +1009 -0
- mlflow/utils/exception_utils.py +14 -0
- mlflow/utils/file_utils.py +978 -0
- mlflow/utils/git_utils.py +77 -0
- mlflow/utils/gorilla.py +797 -0
- mlflow/utils/import_hooks/__init__.py +363 -0
- mlflow/utils/lazy_load.py +51 -0
- mlflow/utils/logging_utils.py +168 -0
- mlflow/utils/mime_type_utils.py +58 -0
- mlflow/utils/mlflow_tags.py +103 -0
- mlflow/utils/model_utils.py +486 -0
- mlflow/utils/name_utils.py +346 -0
- mlflow/utils/nfs_on_spark.py +62 -0
- mlflow/utils/openai_utils.py +164 -0
- mlflow/utils/os.py +12 -0
- mlflow/utils/oss_registry_utils.py +29 -0
- mlflow/utils/plugins.py +17 -0
- mlflow/utils/process.py +182 -0
- mlflow/utils/promptlab_utils.py +146 -0
- mlflow/utils/proto_json_utils.py +743 -0
- mlflow/utils/pydantic_utils.py +54 -0
- mlflow/utils/request_utils.py +279 -0
- mlflow/utils/requirements_utils.py +704 -0
- mlflow/utils/rest_utils.py +673 -0
- mlflow/utils/search_logged_model_utils.py +127 -0
- mlflow/utils/search_utils.py +2111 -0
- mlflow/utils/secure_loading.py +221 -0
- mlflow/utils/security_validation.py +384 -0
- mlflow/utils/server_cli_utils.py +61 -0
- mlflow/utils/spark_utils.py +15 -0
- mlflow/utils/string_utils.py +138 -0
- mlflow/utils/thread_utils.py +63 -0
- mlflow/utils/time.py +54 -0
- mlflow/utils/timeout.py +42 -0
- mlflow/utils/uri.py +572 -0
- mlflow/utils/validation.py +662 -0
- mlflow/utils/virtualenv.py +458 -0
- mlflow/utils/warnings_utils.py +25 -0
- mlflow/utils/yaml_utils.py +179 -0
- mlflow/version.py +24 -0
@@ -0,0 +1,1131 @@
|
|
1
|
+
import json
|
2
|
+
import logging
|
3
|
+
from typing import Any, Optional
|
4
|
+
|
5
|
+
from mlflow.entities import (
|
6
|
+
DatasetInput,
|
7
|
+
Experiment,
|
8
|
+
LoggedModel,
|
9
|
+
LoggedModelInput,
|
10
|
+
LoggedModelOutput,
|
11
|
+
LoggedModelParameter,
|
12
|
+
LoggedModelStatus,
|
13
|
+
LoggedModelTag,
|
14
|
+
Metric,
|
15
|
+
Run,
|
16
|
+
RunInfo,
|
17
|
+
ViewType,
|
18
|
+
)
|
19
|
+
from mlflow.entities.assessment import Assessment, Expectation, Feedback
|
20
|
+
from mlflow.entities.trace import Trace
|
21
|
+
from mlflow.entities.trace_data import TraceData
|
22
|
+
from mlflow.entities.trace_info import TraceInfo
|
23
|
+
from mlflow.entities.trace_info_v2 import TraceInfoV2
|
24
|
+
from mlflow.entities.trace_location import TraceLocation
|
25
|
+
from mlflow.entities.trace_status import TraceStatus
|
26
|
+
from mlflow.environment_variables import (
|
27
|
+
_MLFLOW_CREATE_LOGGED_MODEL_PARAMS_BATCH_SIZE,
|
28
|
+
_MLFLOW_LOG_LOGGED_MODEL_PARAMS_BATCH_SIZE,
|
29
|
+
MLFLOW_ASYNC_TRACE_LOGGING_RETRY_TIMEOUT,
|
30
|
+
)
|
31
|
+
from mlflow.exceptions import MlflowException
|
32
|
+
from mlflow.protos import databricks_pb2
|
33
|
+
from mlflow.protos.service_pb2 import (
|
34
|
+
CreateAssessment,
|
35
|
+
CreateExperiment,
|
36
|
+
CreateLoggedModel,
|
37
|
+
CreateRun,
|
38
|
+
DeleteAssessment,
|
39
|
+
DeleteExperiment,
|
40
|
+
DeleteLoggedModel,
|
41
|
+
DeleteLoggedModelTag,
|
42
|
+
DeleteRun,
|
43
|
+
DeleteTag,
|
44
|
+
DeleteTraces,
|
45
|
+
DeleteTraceTag,
|
46
|
+
EndTrace,
|
47
|
+
FinalizeLoggedModel,
|
48
|
+
GetAssessmentRequest,
|
49
|
+
GetExperiment,
|
50
|
+
GetExperimentByName,
|
51
|
+
GetLoggedModel,
|
52
|
+
GetMetricHistory,
|
53
|
+
GetOnlineTraceDetails,
|
54
|
+
GetRun,
|
55
|
+
GetTraceInfo,
|
56
|
+
GetTraceInfoV3,
|
57
|
+
LogBatch,
|
58
|
+
LogInputs,
|
59
|
+
LogLoggedModelParamsRequest,
|
60
|
+
LogMetric,
|
61
|
+
LogModel,
|
62
|
+
LogOutputs,
|
63
|
+
LogParam,
|
64
|
+
MlflowService,
|
65
|
+
RestoreExperiment,
|
66
|
+
RestoreRun,
|
67
|
+
SearchExperiments,
|
68
|
+
SearchLoggedModels,
|
69
|
+
SearchRuns,
|
70
|
+
SearchTraces,
|
71
|
+
SearchTracesV3,
|
72
|
+
SearchUnifiedTraces,
|
73
|
+
SetExperimentTag,
|
74
|
+
SetLoggedModelTags,
|
75
|
+
SetTag,
|
76
|
+
SetTraceTag,
|
77
|
+
StartTrace,
|
78
|
+
StartTraceV3,
|
79
|
+
TraceRequestMetadata,
|
80
|
+
TraceTag,
|
81
|
+
UpdateAssessment,
|
82
|
+
UpdateExperiment,
|
83
|
+
UpdateRun,
|
84
|
+
)
|
85
|
+
from mlflow.store.entities.paged_list import PagedList
|
86
|
+
from mlflow.store.tracking import SEARCH_TRACES_DEFAULT_MAX_RESULTS
|
87
|
+
from mlflow.store.tracking.abstract_store import AbstractStore
|
88
|
+
from mlflow.utils.proto_json_utils import message_to_json
|
89
|
+
from mlflow.utils.rest_utils import (
|
90
|
+
_REST_API_PATH_PREFIX,
|
91
|
+
_V3_TRACE_REST_API_PATH_PREFIX,
|
92
|
+
call_endpoint,
|
93
|
+
extract_api_info_for_service,
|
94
|
+
get_logged_model_endpoint,
|
95
|
+
get_single_assessment_endpoint,
|
96
|
+
get_single_trace_endpoint,
|
97
|
+
get_trace_tag_endpoint,
|
98
|
+
)
|
99
|
+
|
100
|
+
_METHOD_TO_INFO = extract_api_info_for_service(MlflowService, _REST_API_PATH_PREFIX)
|
101
|
+
_logger = logging.getLogger(__name__)
|
102
|
+
|
103
|
+
|
104
|
+
class RestStore(AbstractStore):
|
105
|
+
"""
|
106
|
+
Client for a remote tracking server accessed via REST API calls
|
107
|
+
|
108
|
+
Args
|
109
|
+
get_host_creds: Method to be invoked prior to every REST request to get the
|
110
|
+
:py:class:`mlflow.rest_utils.MlflowHostCreds` for the request. Note that this
|
111
|
+
is a function so that we can obtain fresh credentials in the case of expiry.
|
112
|
+
"""
|
113
|
+
|
114
|
+
def __init__(self, get_host_creds):
|
115
|
+
super().__init__()
|
116
|
+
self.get_host_creds = get_host_creds
|
117
|
+
|
118
|
+
def _call_endpoint(
|
119
|
+
self,
|
120
|
+
api,
|
121
|
+
json_body=None,
|
122
|
+
endpoint=None,
|
123
|
+
retry_timeout_seconds=None,
|
124
|
+
):
|
125
|
+
if endpoint:
|
126
|
+
# Allow customizing the endpoint for compatibility with dynamic endpoints, such as
|
127
|
+
# /mlflow/traces/{trace_id}/info.
|
128
|
+
_, method = _METHOD_TO_INFO[api]
|
129
|
+
else:
|
130
|
+
endpoint, method = _METHOD_TO_INFO[api]
|
131
|
+
response_proto = api.Response()
|
132
|
+
return call_endpoint(
|
133
|
+
self.get_host_creds(),
|
134
|
+
endpoint,
|
135
|
+
method,
|
136
|
+
json_body,
|
137
|
+
response_proto,
|
138
|
+
retry_timeout_seconds=retry_timeout_seconds,
|
139
|
+
)
|
140
|
+
|
141
|
+
def search_experiments(
|
142
|
+
self,
|
143
|
+
view_type=ViewType.ACTIVE_ONLY,
|
144
|
+
max_results=None,
|
145
|
+
filter_string=None,
|
146
|
+
order_by=None,
|
147
|
+
page_token=None,
|
148
|
+
):
|
149
|
+
req_body = message_to_json(
|
150
|
+
SearchExperiments(
|
151
|
+
view_type=view_type,
|
152
|
+
max_results=max_results,
|
153
|
+
page_token=page_token,
|
154
|
+
order_by=order_by,
|
155
|
+
filter=filter_string,
|
156
|
+
)
|
157
|
+
)
|
158
|
+
response_proto = self._call_endpoint(SearchExperiments, req_body)
|
159
|
+
experiments = [Experiment.from_proto(x) for x in response_proto.experiments]
|
160
|
+
token = (
|
161
|
+
response_proto.next_page_token if response_proto.HasField("next_page_token") else None
|
162
|
+
)
|
163
|
+
return PagedList(experiments, token)
|
164
|
+
|
165
|
+
def create_experiment(self, name, artifact_location=None, tags=None):
|
166
|
+
"""
|
167
|
+
Create a new experiment.
|
168
|
+
If an experiment with the given name already exists, throws exception.
|
169
|
+
|
170
|
+
Args:
|
171
|
+
name: Desired name for an experiment.
|
172
|
+
artifact_location: Location to store run artifacts.
|
173
|
+
tags: A list of :py:class:`mlflow.entities.ExperimentTag` instances to set for the
|
174
|
+
experiment.
|
175
|
+
|
176
|
+
Returns:
|
177
|
+
experiment_id for the newly created experiment if successful, else None
|
178
|
+
"""
|
179
|
+
tag_protos = [tag.to_proto() for tag in tags] if tags else []
|
180
|
+
req_body = message_to_json(
|
181
|
+
CreateExperiment(name=name, artifact_location=artifact_location, tags=tag_protos)
|
182
|
+
)
|
183
|
+
response_proto = self._call_endpoint(CreateExperiment, req_body)
|
184
|
+
return response_proto.experiment_id
|
185
|
+
|
186
|
+
def get_experiment(self, experiment_id):
|
187
|
+
"""
|
188
|
+
Fetch the experiment from the backend store.
|
189
|
+
|
190
|
+
Args:
|
191
|
+
experiment_id: String id for the experiment
|
192
|
+
|
193
|
+
Returns:
|
194
|
+
A single :py:class:`mlflow.entities.Experiment` object if it exists,
|
195
|
+
otherwise raises an Exception.
|
196
|
+
"""
|
197
|
+
req_body = message_to_json(GetExperiment(experiment_id=str(experiment_id)))
|
198
|
+
response_proto = self._call_endpoint(GetExperiment, req_body)
|
199
|
+
return Experiment.from_proto(response_proto.experiment)
|
200
|
+
|
201
|
+
def delete_experiment(self, experiment_id):
|
202
|
+
req_body = message_to_json(DeleteExperiment(experiment_id=str(experiment_id)))
|
203
|
+
self._call_endpoint(DeleteExperiment, req_body)
|
204
|
+
|
205
|
+
def restore_experiment(self, experiment_id):
|
206
|
+
req_body = message_to_json(RestoreExperiment(experiment_id=str(experiment_id)))
|
207
|
+
self._call_endpoint(RestoreExperiment, req_body)
|
208
|
+
|
209
|
+
def rename_experiment(self, experiment_id, new_name):
|
210
|
+
req_body = message_to_json(
|
211
|
+
UpdateExperiment(experiment_id=str(experiment_id), new_name=new_name)
|
212
|
+
)
|
213
|
+
self._call_endpoint(UpdateExperiment, req_body)
|
214
|
+
|
215
|
+
def get_run(self, run_id):
|
216
|
+
"""
|
217
|
+
Fetch the run from backend store
|
218
|
+
|
219
|
+
Args:
|
220
|
+
run_id: Unique identifier for the run
|
221
|
+
|
222
|
+
Returns:
|
223
|
+
A single Run object if it exists, otherwise raises an Exception
|
224
|
+
"""
|
225
|
+
req_body = message_to_json(GetRun(run_uuid=run_id, run_id=run_id))
|
226
|
+
response_proto = self._call_endpoint(GetRun, req_body)
|
227
|
+
return Run.from_proto(response_proto.run)
|
228
|
+
|
229
|
+
def update_run_info(self, run_id, run_status, end_time, run_name):
|
230
|
+
"""Updates the metadata of the specified run."""
|
231
|
+
req_body = message_to_json(
|
232
|
+
UpdateRun(
|
233
|
+
run_uuid=run_id,
|
234
|
+
run_id=run_id,
|
235
|
+
status=run_status,
|
236
|
+
end_time=end_time,
|
237
|
+
run_name=run_name,
|
238
|
+
)
|
239
|
+
)
|
240
|
+
response_proto = self._call_endpoint(UpdateRun, req_body)
|
241
|
+
return RunInfo.from_proto(response_proto.run_info)
|
242
|
+
|
243
|
+
def create_run(self, experiment_id, user_id, start_time, tags, run_name):
|
244
|
+
"""
|
245
|
+
Create a run under the specified experiment ID, setting the run's status to "RUNNING"
|
246
|
+
and the start time to the current time.
|
247
|
+
|
248
|
+
Args:
|
249
|
+
experiment_id: ID of the experiment for this run.
|
250
|
+
user_id: ID of the user launching this run.
|
251
|
+
start_time: timestamp of the initialization of the run.
|
252
|
+
tags: tags to apply to this run at initialization.
|
253
|
+
run_name: Name of this run.
|
254
|
+
|
255
|
+
Returns:
|
256
|
+
The created Run object.
|
257
|
+
"""
|
258
|
+
|
259
|
+
tag_protos = [tag.to_proto() for tag in tags]
|
260
|
+
req_body = message_to_json(
|
261
|
+
CreateRun(
|
262
|
+
experiment_id=str(experiment_id),
|
263
|
+
user_id=user_id,
|
264
|
+
start_time=start_time,
|
265
|
+
tags=tag_protos,
|
266
|
+
run_name=run_name,
|
267
|
+
)
|
268
|
+
)
|
269
|
+
response_proto = self._call_endpoint(CreateRun, req_body)
|
270
|
+
return Run.from_proto(response_proto.run)
|
271
|
+
|
272
|
+
def start_trace(self, trace_info: TraceInfo) -> TraceInfo:
|
273
|
+
"""
|
274
|
+
Create a new trace using the V3 API format.
|
275
|
+
|
276
|
+
NB: The backend API is named "StartTraceV3" for some internal reason, but actually
|
277
|
+
it is supposed to be called at the end of the trace.
|
278
|
+
|
279
|
+
Args:
|
280
|
+
trace_info: The TraceInfo object to create in the backend.
|
281
|
+
|
282
|
+
Returns:
|
283
|
+
The returned TraceInfo object from the backend.
|
284
|
+
"""
|
285
|
+
# NB: The Databricks backend expects a Trace object, not a TraceInfo object, although
|
286
|
+
# it doesn't use the data field at all. Trace data increases the payload size significantly,
|
287
|
+
# so we create a Trace object with an empty data field here.
|
288
|
+
trace = Trace(info=trace_info, data=TraceData(spans=[]))
|
289
|
+
req_body = message_to_json(StartTraceV3(trace=trace.to_proto()))
|
290
|
+
|
291
|
+
try:
|
292
|
+
response_proto = self._call_endpoint(
|
293
|
+
# NB: _call_endpoint doesn't handle versioning between v2 and v3 endpoint
|
294
|
+
# yet, so manually passing the v3 endpoint here.
|
295
|
+
StartTraceV3,
|
296
|
+
req_body,
|
297
|
+
endpoint=_V3_TRACE_REST_API_PATH_PREFIX,
|
298
|
+
retry_timeout_seconds=MLFLOW_ASYNC_TRACE_LOGGING_RETRY_TIMEOUT.get(),
|
299
|
+
)
|
300
|
+
return TraceInfo.from_proto(response_proto.trace.trace_info)
|
301
|
+
except MlflowException as e:
|
302
|
+
if e.error_code == databricks_pb2.ErrorCode.Name(databricks_pb2.ENDPOINT_NOT_FOUND):
|
303
|
+
_logger.debug(
|
304
|
+
"Server does not support StartTraceV3 API yet. Falling back to V2 API."
|
305
|
+
)
|
306
|
+
return self._create_trace_v2_fallback(trace_info)
|
307
|
+
raise
|
308
|
+
|
309
|
+
def _create_trace_v2_fallback(self, trace_info: TraceInfo) -> TraceInfo:
|
310
|
+
"""
|
311
|
+
Create a new trace using the V2 API format. This is a fallback for the case where the
|
312
|
+
client is v3 but the tracking server does not support v3 yet(<= 3.2.0).
|
313
|
+
"""
|
314
|
+
trace_info_v2 = self.deprecated_start_trace_v2(
|
315
|
+
experiment_id=trace_info.experiment_id,
|
316
|
+
timestamp_ms=trace_info.request_time,
|
317
|
+
request_metadata=trace_info.trace_metadata,
|
318
|
+
tags=trace_info.tags,
|
319
|
+
)
|
320
|
+
self.deprecated_end_trace_v2(
|
321
|
+
request_id=trace_info_v2.request_id,
|
322
|
+
timestamp_ms=trace_info.request_time + trace_info.execution_duration,
|
323
|
+
status=trace_info.status,
|
324
|
+
request_metadata=trace_info.trace_metadata,
|
325
|
+
tags=trace_info.tags,
|
326
|
+
)
|
327
|
+
return trace_info_v2.to_v3()
|
328
|
+
|
329
|
+
def _delete_traces(
|
330
|
+
self,
|
331
|
+
experiment_id: str,
|
332
|
+
max_timestamp_millis: Optional[int] = None,
|
333
|
+
max_traces: Optional[int] = None,
|
334
|
+
trace_ids: Optional[list[str]] = None,
|
335
|
+
) -> int:
|
336
|
+
req_body = message_to_json(
|
337
|
+
DeleteTraces(
|
338
|
+
experiment_id=experiment_id,
|
339
|
+
max_timestamp_millis=max_timestamp_millis,
|
340
|
+
max_traces=max_traces,
|
341
|
+
request_ids=trace_ids,
|
342
|
+
)
|
343
|
+
)
|
344
|
+
res = self._call_endpoint(DeleteTraces, req_body)
|
345
|
+
return res.traces_deleted
|
346
|
+
|
347
|
+
def get_trace_info(self, trace_id: str) -> TraceInfo:
|
348
|
+
"""
|
349
|
+
Get the trace matching the `trace_id`.
|
350
|
+
|
351
|
+
Args:
|
352
|
+
trace_id: String id of the trace to fetch.
|
353
|
+
|
354
|
+
Returns:
|
355
|
+
The fetched Trace object, of type ``mlflow.entities.TraceInfo``.
|
356
|
+
"""
|
357
|
+
trace_v3_req_body = message_to_json(GetTraceInfoV3(trace_id=trace_id))
|
358
|
+
trace_v3_endpoint = get_single_trace_endpoint(trace_id)
|
359
|
+
try:
|
360
|
+
trace_v3_response_proto = self._call_endpoint(
|
361
|
+
GetTraceInfoV3, trace_v3_req_body, endpoint=trace_v3_endpoint
|
362
|
+
)
|
363
|
+
return TraceInfo.from_proto(trace_v3_response_proto.trace.trace_info)
|
364
|
+
except MlflowException as e:
|
365
|
+
# If the tracking server does not support V3 trace API yet, fallback to V2 API.
|
366
|
+
if e.error_code != databricks_pb2.ErrorCode.Name(databricks_pb2.ENDPOINT_NOT_FOUND):
|
367
|
+
raise
|
368
|
+
_logger.debug("Server does not support GetTraceInfoV3 API yet. Falling back to V2 API.")
|
369
|
+
|
370
|
+
req_body = message_to_json(GetTraceInfo(request_id=trace_id))
|
371
|
+
endpoint = get_single_trace_endpoint(trace_id, use_v3=False)
|
372
|
+
response_proto = self._call_endpoint(GetTraceInfo, req_body, endpoint=endpoint)
|
373
|
+
return TraceInfoV2.from_proto(response_proto.trace_info).to_v3()
|
374
|
+
|
375
|
+
def get_online_trace_details(
|
376
|
+
self,
|
377
|
+
trace_id: str,
|
378
|
+
sql_warehouse_id: str,
|
379
|
+
source_inference_table: str,
|
380
|
+
source_databricks_request_id: str,
|
381
|
+
):
|
382
|
+
req = GetOnlineTraceDetails(
|
383
|
+
trace_id=trace_id,
|
384
|
+
sql_warehouse_id=sql_warehouse_id,
|
385
|
+
source_inference_table=source_inference_table,
|
386
|
+
source_databricks_request_id=source_databricks_request_id,
|
387
|
+
)
|
388
|
+
req_body = message_to_json(req)
|
389
|
+
response_proto = self._call_endpoint(GetOnlineTraceDetails, req_body)
|
390
|
+
return response_proto.trace_data
|
391
|
+
|
392
|
+
def search_traces(
|
393
|
+
self,
|
394
|
+
experiment_ids: list[str],
|
395
|
+
filter_string: Optional[str] = None,
|
396
|
+
max_results: int = SEARCH_TRACES_DEFAULT_MAX_RESULTS,
|
397
|
+
order_by: Optional[list[str]] = None,
|
398
|
+
page_token: Optional[str] = None,
|
399
|
+
model_id: Optional[str] = None,
|
400
|
+
sql_warehouse_id: Optional[str] = None,
|
401
|
+
):
|
402
|
+
if sql_warehouse_id is None:
|
403
|
+
# Create trace_locations from experiment_ids for the V3 API
|
404
|
+
trace_locations = []
|
405
|
+
for exp_id in experiment_ids:
|
406
|
+
try:
|
407
|
+
location = TraceLocation.from_experiment_id(exp_id)
|
408
|
+
proto_location = location.to_proto()
|
409
|
+
trace_locations.append(proto_location)
|
410
|
+
except Exception as e:
|
411
|
+
raise MlflowException(
|
412
|
+
f"Invalid experiment ID format: {exp_id}. Error: {e!s}"
|
413
|
+
) from e
|
414
|
+
|
415
|
+
# Create V3 request message using protobuf
|
416
|
+
request = SearchTracesV3(
|
417
|
+
locations=trace_locations,
|
418
|
+
filter=filter_string,
|
419
|
+
max_results=max_results,
|
420
|
+
order_by=order_by,
|
421
|
+
page_token=page_token,
|
422
|
+
)
|
423
|
+
|
424
|
+
req_body = message_to_json(request)
|
425
|
+
v3_endpoint = f"{_V3_TRACE_REST_API_PATH_PREFIX}/search"
|
426
|
+
|
427
|
+
try:
|
428
|
+
response_proto = self._call_endpoint(SearchTracesV3, req_body, v3_endpoint)
|
429
|
+
except MlflowException as e:
|
430
|
+
if e.error_code == databricks_pb2.ErrorCode.Name(databricks_pb2.ENDPOINT_NOT_FOUND):
|
431
|
+
_logger.debug(
|
432
|
+
"Server does not support SearchTracesV3 API yet. Falling back to V2 API."
|
433
|
+
)
|
434
|
+
response_proto = self._call_endpoint(SearchTraces, req_body)
|
435
|
+
else:
|
436
|
+
raise
|
437
|
+
|
438
|
+
trace_infos = [TraceInfo.from_proto(t) for t in response_proto.traces]
|
439
|
+
else:
|
440
|
+
response_proto = self._search_unified_traces(
|
441
|
+
model_id=model_id,
|
442
|
+
sql_warehouse_id=sql_warehouse_id,
|
443
|
+
experiment_ids=experiment_ids,
|
444
|
+
filter_string=filter_string,
|
445
|
+
max_results=max_results,
|
446
|
+
order_by=order_by,
|
447
|
+
page_token=page_token,
|
448
|
+
)
|
449
|
+
# Convert TraceInfo (v2) objects to TraceInfoV3 objects for consistency
|
450
|
+
trace_infos = [TraceInfo.from_proto(t) for t in response_proto.traces]
|
451
|
+
return trace_infos, response_proto.next_page_token or None
|
452
|
+
|
453
|
+
def _search_unified_traces(
|
454
|
+
self,
|
455
|
+
model_id: str,
|
456
|
+
sql_warehouse_id: str,
|
457
|
+
experiment_ids: list[str],
|
458
|
+
filter_string: Optional[str] = None,
|
459
|
+
max_results: int = SEARCH_TRACES_DEFAULT_MAX_RESULTS,
|
460
|
+
order_by: Optional[list[str]] = None,
|
461
|
+
page_token: Optional[str] = None,
|
462
|
+
):
|
463
|
+
request = SearchUnifiedTraces(
|
464
|
+
model_id=model_id,
|
465
|
+
sql_warehouse_id=sql_warehouse_id,
|
466
|
+
experiment_ids=experiment_ids,
|
467
|
+
filter=filter_string,
|
468
|
+
max_results=max_results,
|
469
|
+
order_by=order_by,
|
470
|
+
page_token=page_token,
|
471
|
+
)
|
472
|
+
req_body = message_to_json(request)
|
473
|
+
return self._call_endpoint(SearchUnifiedTraces, req_body)
|
474
|
+
|
475
|
+
def set_trace_tag(self, trace_id: str, key: str, value: str):
|
476
|
+
"""
|
477
|
+
Set a tag on the trace with the given trace_id.
|
478
|
+
|
479
|
+
Args:
|
480
|
+
trace_id: The ID of the trace.
|
481
|
+
key: The string key of the tag.
|
482
|
+
value: The string value of the tag.
|
483
|
+
"""
|
484
|
+
# Always use v2 endpoint
|
485
|
+
req_body = message_to_json(SetTraceTag(key=key, value=value))
|
486
|
+
self._call_endpoint(SetTraceTag, req_body, endpoint=get_trace_tag_endpoint(trace_id))
|
487
|
+
|
488
|
+
def delete_trace_tag(self, trace_id: str, key: str):
|
489
|
+
"""
|
490
|
+
Delete a tag on the trace with the given trace_id.
|
491
|
+
|
492
|
+
Args:
|
493
|
+
trace_id: The ID of the trace.
|
494
|
+
key: The string key of the tag.
|
495
|
+
"""
|
496
|
+
# Always use v2 endpoint
|
497
|
+
req_body = message_to_json(DeleteTraceTag(key=key))
|
498
|
+
self._call_endpoint(DeleteTraceTag, req_body, endpoint=get_trace_tag_endpoint(trace_id))
|
499
|
+
|
500
|
+
def get_assessment(self, trace_id: str, assessment_id: str) -> Assessment:
|
501
|
+
"""
|
502
|
+
Get an assessment entity from the backend store.
|
503
|
+
"""
|
504
|
+
req_body = message_to_json(
|
505
|
+
GetAssessmentRequest(trace_id=trace_id, assessment_id=assessment_id)
|
506
|
+
)
|
507
|
+
response_proto = self._call_endpoint(
|
508
|
+
GetAssessmentRequest,
|
509
|
+
req_body,
|
510
|
+
endpoint=get_single_assessment_endpoint(trace_id, assessment_id),
|
511
|
+
)
|
512
|
+
return Assessment.from_proto(response_proto.assessment)
|
513
|
+
|
514
|
+
def create_assessment(self, assessment: Assessment) -> Assessment:
|
515
|
+
"""
|
516
|
+
Create an assessment entity in the backend store.
|
517
|
+
|
518
|
+
Args:
|
519
|
+
assessment: The assessment to log (without an assessment_id).
|
520
|
+
|
521
|
+
Returns:
|
522
|
+
The created Assessment object.
|
523
|
+
"""
|
524
|
+
req_body = message_to_json(CreateAssessment(assessment=assessment.to_proto()))
|
525
|
+
response_proto = self._call_endpoint(
|
526
|
+
CreateAssessment,
|
527
|
+
req_body,
|
528
|
+
endpoint=f"{_V3_TRACE_REST_API_PATH_PREFIX}/{assessment.trace_id}/assessments",
|
529
|
+
)
|
530
|
+
return Assessment.from_proto(response_proto.assessment)
|
531
|
+
|
532
|
+
def update_assessment(
|
533
|
+
self,
|
534
|
+
trace_id: str,
|
535
|
+
assessment_id: str,
|
536
|
+
name: Optional[str] = None,
|
537
|
+
expectation: Optional[Expectation] = None,
|
538
|
+
feedback: Optional[Feedback] = None,
|
539
|
+
rationale: Optional[str] = None,
|
540
|
+
metadata: Optional[dict[str, str]] = None,
|
541
|
+
) -> Assessment:
|
542
|
+
"""
|
543
|
+
Update an existing assessment entity in the backend store.
|
544
|
+
|
545
|
+
Args:
|
546
|
+
trace_id: The ID of the trace.
|
547
|
+
assessment_id: The ID of the assessment to update.
|
548
|
+
name: The updated name of the assessment.
|
549
|
+
expectation: The updated expectation value of the assessment.
|
550
|
+
feedback: The updated feedback value of the assessment.
|
551
|
+
rationale: The updated rationale of the feedback. Not applicable for expectations.
|
552
|
+
metadata: Additional metadata for the assessment.
|
553
|
+
"""
|
554
|
+
if expectation is not None and feedback is not None:
|
555
|
+
raise MlflowException.invalid_parameter_value(
|
556
|
+
"Exactly one of `expectation` or `feedback` should be specified."
|
557
|
+
)
|
558
|
+
|
559
|
+
update = UpdateAssessment()
|
560
|
+
|
561
|
+
# The assessment object to be sent to the backend (only contains fields to update and IDs)
|
562
|
+
assessment = update.assessment
|
563
|
+
# Field mask specifies which fields to update.
|
564
|
+
mask = update.update_mask
|
565
|
+
|
566
|
+
assessment.assessment_id = assessment_id
|
567
|
+
assessment.trace_id = trace_id
|
568
|
+
|
569
|
+
if name is not None:
|
570
|
+
assessment.assessment_name = name
|
571
|
+
mask.paths.append("assessment_name")
|
572
|
+
if expectation is not None:
|
573
|
+
assessment.expectation.CopyFrom(expectation.to_proto())
|
574
|
+
mask.paths.append("expectation")
|
575
|
+
if feedback is not None:
|
576
|
+
assessment.feedback.CopyFrom(feedback.to_proto())
|
577
|
+
mask.paths.append("feedback")
|
578
|
+
if rationale is not None:
|
579
|
+
assessment.rationale = rationale
|
580
|
+
mask.paths.append("rationale")
|
581
|
+
if metadata is not None:
|
582
|
+
assessment.metadata.update(metadata)
|
583
|
+
mask.paths.append("metadata")
|
584
|
+
|
585
|
+
req_body = message_to_json(update)
|
586
|
+
response_proto = self._call_endpoint(
|
587
|
+
UpdateAssessment,
|
588
|
+
req_body,
|
589
|
+
endpoint=get_single_assessment_endpoint(trace_id, assessment_id),
|
590
|
+
)
|
591
|
+
return Assessment.from_proto(response_proto.assessment)
|
592
|
+
|
593
|
+
def delete_assessment(self, trace_id: str, assessment_id: str):
|
594
|
+
"""
|
595
|
+
Delete an assessment associated with a trace.
|
596
|
+
|
597
|
+
Args:
|
598
|
+
trace_id: String ID of the trace.
|
599
|
+
assessment_id: String ID of the assessment to delete.
|
600
|
+
"""
|
601
|
+
req_body = message_to_json(DeleteAssessment(trace_id=trace_id, assessment_id=assessment_id))
|
602
|
+
self._call_endpoint(
|
603
|
+
DeleteAssessment,
|
604
|
+
req_body,
|
605
|
+
endpoint=get_single_assessment_endpoint(trace_id, assessment_id),
|
606
|
+
)
|
607
|
+
|
608
|
+
def log_metric(self, run_id: str, metric: Metric):
|
609
|
+
"""
|
610
|
+
Log a metric for the specified run
|
611
|
+
|
612
|
+
Args:
|
613
|
+
run_id: String id for the run
|
614
|
+
metric: Metric instance to log
|
615
|
+
"""
|
616
|
+
req_body = message_to_json(
|
617
|
+
LogMetric(
|
618
|
+
run_uuid=run_id,
|
619
|
+
run_id=run_id,
|
620
|
+
key=metric.key,
|
621
|
+
value=metric.value,
|
622
|
+
timestamp=metric.timestamp,
|
623
|
+
step=metric.step,
|
624
|
+
model_id=metric.model_id,
|
625
|
+
dataset_name=metric.dataset_name,
|
626
|
+
dataset_digest=metric.dataset_digest,
|
627
|
+
)
|
628
|
+
)
|
629
|
+
self._call_endpoint(LogMetric, req_body)
|
630
|
+
|
631
|
+
def log_param(self, run_id, param):
|
632
|
+
"""
|
633
|
+
Log a param for the specified run
|
634
|
+
|
635
|
+
Args:
|
636
|
+
run_id: String id for the run
|
637
|
+
param: Param instance to log
|
638
|
+
"""
|
639
|
+
req_body = message_to_json(
|
640
|
+
LogParam(run_uuid=run_id, run_id=run_id, key=param.key, value=param.value)
|
641
|
+
)
|
642
|
+
self._call_endpoint(LogParam, req_body)
|
643
|
+
|
644
|
+
def set_experiment_tag(self, experiment_id, tag):
|
645
|
+
"""
|
646
|
+
Set a tag for the specified experiment
|
647
|
+
|
648
|
+
Args:
|
649
|
+
experiment_id: String ID of the experiment
|
650
|
+
tag: ExperimentRunTag instance to log
|
651
|
+
"""
|
652
|
+
req_body = message_to_json(
|
653
|
+
SetExperimentTag(experiment_id=experiment_id, key=tag.key, value=tag.value)
|
654
|
+
)
|
655
|
+
self._call_endpoint(SetExperimentTag, req_body)
|
656
|
+
|
657
|
+
def set_tag(self, run_id, tag):
|
658
|
+
"""
|
659
|
+
Set a tag for the specified run
|
660
|
+
|
661
|
+
Args:
|
662
|
+
run_id: String ID of the run
|
663
|
+
tag: RunTag instance to log
|
664
|
+
"""
|
665
|
+
req_body = message_to_json(
|
666
|
+
SetTag(run_uuid=run_id, run_id=run_id, key=tag.key, value=tag.value)
|
667
|
+
)
|
668
|
+
self._call_endpoint(SetTag, req_body)
|
669
|
+
|
670
|
+
def delete_tag(self, run_id, key):
|
671
|
+
"""
|
672
|
+
Delete a tag from a run. This is irreversible.
|
673
|
+
|
674
|
+
Args:
|
675
|
+
run_id: String ID of the run.
|
676
|
+
key: Name of the tag.
|
677
|
+
"""
|
678
|
+
req_body = message_to_json(DeleteTag(run_id=run_id, key=key))
|
679
|
+
self._call_endpoint(DeleteTag, req_body)
|
680
|
+
|
681
|
+
def get_metric_history(self, run_id, metric_key, max_results=None, page_token=None):
|
682
|
+
"""
|
683
|
+
Return all logged values for a given metric.
|
684
|
+
|
685
|
+
Args:
|
686
|
+
run_id: Unique identifier for run.
|
687
|
+
metric_key: Metric name within the run.
|
688
|
+
max_results: Maximum number of metric history events (steps) to return per paged
|
689
|
+
query. Only supported in 'databricks' backend.
|
690
|
+
page_token: A Token specifying the next paginated set of results of metric history.
|
691
|
+
|
692
|
+
Returns:
|
693
|
+
A PagedList of :py:class:`mlflow.entities.Metric` entities if a paginated request
|
694
|
+
is made by setting ``max_results`` to a value other than ``None``, a List of
|
695
|
+
:py:class:`mlflow.entities.Metric` entities if ``max_results`` is None, else, if no
|
696
|
+
metrics of the ``metric_key`` have been logged to the ``run_id``, an empty list.
|
697
|
+
"""
|
698
|
+
req_body = message_to_json(
|
699
|
+
GetMetricHistory(
|
700
|
+
run_uuid=run_id,
|
701
|
+
run_id=run_id,
|
702
|
+
metric_key=metric_key,
|
703
|
+
max_results=max_results,
|
704
|
+
page_token=page_token,
|
705
|
+
)
|
706
|
+
)
|
707
|
+
response_proto = self._call_endpoint(GetMetricHistory, req_body)
|
708
|
+
|
709
|
+
metric_history = [Metric.from_proto(metric) for metric in response_proto.metrics]
|
710
|
+
return PagedList(metric_history, response_proto.next_page_token or None)
|
711
|
+
|
712
|
+
def _search_runs(
|
713
|
+
self, experiment_ids, filter_string, run_view_type, max_results, order_by, page_token
|
714
|
+
):
|
715
|
+
experiment_ids = [str(experiment_id) for experiment_id in experiment_ids]
|
716
|
+
sr = SearchRuns(
|
717
|
+
experiment_ids=experiment_ids,
|
718
|
+
filter=filter_string,
|
719
|
+
run_view_type=ViewType.to_proto(run_view_type),
|
720
|
+
max_results=max_results,
|
721
|
+
order_by=order_by,
|
722
|
+
page_token=page_token,
|
723
|
+
)
|
724
|
+
req_body = message_to_json(sr)
|
725
|
+
response_proto = self._call_endpoint(SearchRuns, req_body)
|
726
|
+
runs = [Run.from_proto(proto_run) for proto_run in response_proto.runs]
|
727
|
+
# If next_page_token is not set, we will see it as "". We need to convert this to None.
|
728
|
+
next_page_token = None
|
729
|
+
if response_proto.next_page_token:
|
730
|
+
next_page_token = response_proto.next_page_token
|
731
|
+
return runs, next_page_token
|
732
|
+
|
733
|
+
def delete_run(self, run_id):
|
734
|
+
req_body = message_to_json(DeleteRun(run_id=run_id))
|
735
|
+
self._call_endpoint(DeleteRun, req_body)
|
736
|
+
|
737
|
+
def restore_run(self, run_id):
|
738
|
+
req_body = message_to_json(RestoreRun(run_id=run_id))
|
739
|
+
self._call_endpoint(RestoreRun, req_body)
|
740
|
+
|
741
|
+
def get_experiment_by_name(self, experiment_name):
|
742
|
+
try:
|
743
|
+
req_body = message_to_json(GetExperimentByName(experiment_name=experiment_name))
|
744
|
+
response_proto = self._call_endpoint(GetExperimentByName, req_body)
|
745
|
+
return Experiment.from_proto(response_proto.experiment)
|
746
|
+
except MlflowException as e:
|
747
|
+
if e.error_code == databricks_pb2.ErrorCode.Name(
|
748
|
+
databricks_pb2.RESOURCE_DOES_NOT_EXIST
|
749
|
+
):
|
750
|
+
return None
|
751
|
+
else:
|
752
|
+
raise
|
753
|
+
|
754
|
+
def log_batch(self, run_id, metrics, params, tags):
|
755
|
+
metric_protos = [metric.to_proto() for metric in metrics]
|
756
|
+
param_protos = [param.to_proto() for param in params]
|
757
|
+
tag_protos = [tag.to_proto() for tag in tags]
|
758
|
+
req_body = message_to_json(
|
759
|
+
LogBatch(metrics=metric_protos, params=param_protos, tags=tag_protos, run_id=run_id)
|
760
|
+
)
|
761
|
+
self._call_endpoint(LogBatch, req_body)
|
762
|
+
|
763
|
+
def record_logged_model(self, run_id, mlflow_model):
|
764
|
+
req_body = message_to_json(
|
765
|
+
LogModel(run_id=run_id, model_json=json.dumps(mlflow_model.get_tags_dict()))
|
766
|
+
)
|
767
|
+
self._call_endpoint(LogModel, req_body)
|
768
|
+
|
769
|
+
def create_logged_model(
|
770
|
+
self,
|
771
|
+
experiment_id: str,
|
772
|
+
name: Optional[str] = None,
|
773
|
+
source_run_id: Optional[str] = None,
|
774
|
+
tags: Optional[list[LoggedModelTag]] = None,
|
775
|
+
params: Optional[list[LoggedModelParameter]] = None,
|
776
|
+
model_type: Optional[str] = None,
|
777
|
+
) -> LoggedModel:
|
778
|
+
"""
|
779
|
+
Create a new logged model.
|
780
|
+
|
781
|
+
Args:
|
782
|
+
experiment_id: ID of the experiment to which the model belongs.
|
783
|
+
name: Name of the model. If not specified, a random name will be generated.
|
784
|
+
source_run_id: ID of the run that produced the model.
|
785
|
+
tags: Tags to set on the model.
|
786
|
+
params: Parameters to set on the model.
|
787
|
+
model_type: Type of the model.
|
788
|
+
|
789
|
+
Returns:
|
790
|
+
The created model.
|
791
|
+
"""
|
792
|
+
# Include the first 100 params in the initial request
|
793
|
+
initial_params = []
|
794
|
+
remaining_params = []
|
795
|
+
if params:
|
796
|
+
initial_batch_size = _MLFLOW_CREATE_LOGGED_MODEL_PARAMS_BATCH_SIZE.get()
|
797
|
+
initial_params = params[:initial_batch_size]
|
798
|
+
remaining_params = params[initial_batch_size:]
|
799
|
+
|
800
|
+
req_body = message_to_json(
|
801
|
+
CreateLoggedModel(
|
802
|
+
experiment_id=experiment_id,
|
803
|
+
name=name,
|
804
|
+
model_type=model_type,
|
805
|
+
source_run_id=source_run_id,
|
806
|
+
params=[p.to_proto() for p in initial_params],
|
807
|
+
tags=[t.to_proto() for t in tags or []],
|
808
|
+
)
|
809
|
+
)
|
810
|
+
|
811
|
+
response_proto = self._call_endpoint(CreateLoggedModel, req_body)
|
812
|
+
model = LoggedModel.from_proto(response_proto.model)
|
813
|
+
|
814
|
+
# Log remaining params if there are any
|
815
|
+
if remaining_params:
|
816
|
+
self.log_logged_model_params(model_id=model.model_id, params=remaining_params)
|
817
|
+
model = self.get_logged_model(model_id=model.model_id)
|
818
|
+
|
819
|
+
return model
|
820
|
+
|
821
|
+
def log_logged_model_params(self, model_id: str, params: list[LoggedModelParameter]) -> None:
|
822
|
+
"""
|
823
|
+
Log parameters for a logged model in batches of 100.
|
824
|
+
|
825
|
+
Args:
|
826
|
+
model_id: ID of the model to log parameters for.
|
827
|
+
params: List of parameters to log.
|
828
|
+
|
829
|
+
Returns:
|
830
|
+
None
|
831
|
+
"""
|
832
|
+
# Process params in batches to avoid exceeding per-request backend limits
|
833
|
+
batch_size = _MLFLOW_LOG_LOGGED_MODEL_PARAMS_BATCH_SIZE.get()
|
834
|
+
endpoint = get_logged_model_endpoint(model_id)
|
835
|
+
for i in range(0, len(params), batch_size):
|
836
|
+
batch = params[i : i + batch_size]
|
837
|
+
req_body = message_to_json(
|
838
|
+
LogLoggedModelParamsRequest(
|
839
|
+
model_id=model_id,
|
840
|
+
params=[p.to_proto() for p in batch],
|
841
|
+
)
|
842
|
+
)
|
843
|
+
self._call_endpoint(
|
844
|
+
LogLoggedModelParamsRequest, json_body=req_body, endpoint=f"{endpoint}/params"
|
845
|
+
)
|
846
|
+
|
847
|
+
def get_logged_model(self, model_id: str) -> LoggedModel:
    """
    Retrieve a single logged model by its ID.

    Args:
        model_id: ID of the logged model to retrieve.

    Returns:
        The :py:class:`LoggedModel <mlflow.entities.LoggedModel>` built from the
        server response.
    """
    response = self._call_endpoint(GetLoggedModel, endpoint=get_logged_model_endpoint(model_id))
    return LoggedModel.from_proto(response.model)
|
860
|
+
|
861
|
+
def delete_logged_model(self, model_id: str) -> None:
    """
    Delete the logged model with the specified ID.

    Args:
        model_id: ID of the model to delete. It is sent both in the request path
            and in the JSON request body.

    Returns:
        None
    """
    request = DeleteLoggedModel(model_id=model_id)
    endpoint = get_logged_model_endpoint(model_id)
    self._call_endpoint(
        DeleteLoggedModel, endpoint=endpoint, json_body=message_to_json(request)
    )
|
867
|
+
|
868
|
+
def search_logged_models(
    self,
    experiment_ids: list[str],
    filter_string: Optional[str] = None,
    datasets: Optional[list[dict[str, Any]]] = None,
    max_results: Optional[int] = None,
    order_by: Optional[list[dict[str, Any]]] = None,
    page_token: Optional[str] = None,
) -> PagedList[LoggedModel]:
    """
    Search for logged models matching the given criteria.

    Args:
        experiment_ids: IDs of the experiments to search within.
        filter_string: Optional search filter expression.
        datasets: Optional list of dictionaries restricting which datasets metric
            filters apply to. Supported keys:

            dataset_name (str): Required. Name of the dataset.
            dataset_digest (str): Optional. Digest of the dataset.
        max_results: Maximum number of logged models to return.
        order_by: Optional list of dictionaries describing the result ordering.
            Supported keys:

            field_name (str): Required. Field to order by, e.g. "metrics.accuracy".
            ascending (bool): Optional. Sort ascending when true (default True).
            dataset_name (str): Optional. Only valid when ``field_name`` refers to a
                metric; restricts ordering to metrics logged against this dataset name.
            dataset_digest (str): Optional. Only valid when ``dataset_name`` is also
                set; further restricts ordering to metrics for this dataset digest.
        page_token: Token identifying the page of results to fetch.

    Returns:
        A :py:class:`PagedList <mlflow.store.entities.PagedList>` of
        :py:class:`LoggedModel <mlflow.entities.LoggedModel>` objects. Its token is
        ``None`` when the server returns an empty next-page token.
    """
    dataset_filters = [
        SearchLoggedModels.Dataset(
            dataset_name=entry["dataset_name"],
            dataset_digest=entry.get("dataset_digest"),
        )
        for entry in (datasets or [])
    ]
    ordering = [
        SearchLoggedModels.OrderBy(
            field_name=clause["field_name"],
            ascending=clause.get("ascending", True),
            dataset_name=clause.get("dataset_name"),
            dataset_digest=clause.get("dataset_digest"),
        )
        for clause in (order_by or [])
    ]
    request = SearchLoggedModels(
        experiment_ids=experiment_ids,
        filter=filter_string,
        datasets=dataset_filters,
        max_results=max_results,
        order_by=ordering,
        page_token=page_token,
    )
    response = self._call_endpoint(SearchLoggedModels, message_to_json(request))
    found = [LoggedModel.from_proto(proto) for proto in response.models]
    # Normalise an empty token ("") to None so callers can test for "no more pages".
    next_token = response.next_page_token or None
    return PagedList(found, next_token)
|
935
|
+
|
936
|
+
def finalize_logged_model(self, model_id: str, status: LoggedModelStatus) -> LoggedModel:
    """
    Mark a logged model as finalized by setting its status.

    Args:
        model_id: ID of the model to finalize.
        status: Terminal status to record for the model.

    Returns:
        The model as returned by the server after the update.
    """
    target = get_logged_model_endpoint(model_id)
    payload = message_to_json(
        FinalizeLoggedModel(model_id=model_id, status=status.to_proto())
    )
    response = self._call_endpoint(FinalizeLoggedModel, json_body=payload, endpoint=target)
    return LoggedModel.from_proto(response.model)
|
955
|
+
|
956
|
+
def set_logged_model_tags(self, model_id: str, tags: list[LoggedModelTag]) -> None:
    """
    Attach the given tags to a logged model.

    Args:
        model_id: ID of the model to tag; it is used to build the request path.
        tags: Tags to set on the model.

    Returns:
        None
    """
    tags_endpoint = f"{get_logged_model_endpoint(model_id)}/tags"
    tag_protos = [tag.to_proto() for tag in tags]
    payload = message_to_json(SetLoggedModelTags(tags=tag_protos))
    self._call_endpoint(SetLoggedModelTags, json_body=payload, endpoint=tags_endpoint)
|
970
|
+
|
971
|
+
def delete_logged_model_tag(self, model_id: str, key: str) -> None:
    """
    Delete a tag from the specified logged model.

    Args:
        model_id: ID of the model.
        key: Key of the tag to delete. It is appended to the request path as-is,
            without URL-encoding.

    Returns:
        None. The updated model is not returned; use ``get_logged_model`` to fetch it.
    """
    endpoint = get_logged_model_endpoint(model_id)
    # NOTE(review): keys containing URL-reserved characters (e.g. "/", "?", "#") are
    # not escaped here -- confirm the server-side route accepts such keys.
    self._call_endpoint(DeleteLoggedModelTag, endpoint=f"{endpoint}/tags/{key}")
|
984
|
+
|
985
|
+
def log_inputs(
    self,
    run_id: str,
    datasets: Optional[list[DatasetInput]] = None,
    models: Optional[list[LoggedModelInput]] = None,
):
    """
    Record inputs (datasets and/or logged models) against a run.

    Args:
        run_id: ID of the run to attach the inputs to.
        datasets: Optional :py:class:`mlflow.entities.DatasetInput` instances to log
            as run inputs. ``None`` is treated as an empty list.
        models: Optional :py:class:`mlflow.entities.LoggedModelInput` instances to log.
            ``None`` is treated as an empty list.

    Returns:
        None.
    """
    request = LogInputs(
        run_id=run_id,
        datasets=[dataset_input.to_proto() for dataset_input in (datasets or [])],
        models=[model_input.to_proto() for model_input in (models or [])],
    )
    self._call_endpoint(LogInputs, message_to_json(request))
|
1013
|
+
|
1014
|
+
def log_outputs(self, run_id: str, models: list[LoggedModelOutput]):
    """
    Record outputs (logged models) produced by a run.

    Args:
        run_id: ID of the run the outputs belong to.
        models: :py:class:`mlflow.entities.LoggedModelOutput` instances to log as
            outputs of the run.

    Returns:
        None.
    """
    output_protos = [output.to_proto() for output in models]
    payload = message_to_json(LogOutputs(run_id=run_id, models=output_protos))
    self._call_endpoint(LogOutputs, payload)
|
1028
|
+
|
1029
|
+
############################################################################################
|
1030
|
+
# Deprecated MLflow Tracing APIs. Kept for backward compatibility but do not use.
|
1031
|
+
############################################################################################
|
1032
|
+
def deprecated_start_trace_v2(
    self,
    experiment_id: str,
    timestamp_ms: int,
    request_metadata: dict[str, str],
    tags: dict[str, str],
) -> TraceInfoV2:
    """
    DEPRECATED. DO NOT USE.

    Create the initial TraceInfo record for a trace in the backend store.

    Args:
        experiment_id: ID of the experiment the trace belongs to (sent as a string).
        timestamp_ms: Trace start time, in milliseconds since the UNIX epoch.
        request_metadata: Trace metadata; every value is converted with ``str``.
        tags: Trace tags; every value is converted with ``str``.

    Returns:
        The TraceInfo object created by the server.
    """
    metadata_protos = [
        TraceRequestMetadata(key=meta_key, value=str(meta_value))
        for meta_key, meta_value in request_metadata.items()
    ]
    tag_protos = [
        TraceTag(key=tag_key, value=str(tag_value)) for tag_key, tag_value in tags.items()
    ]
    request = StartTrace(
        experiment_id=str(experiment_id),
        timestamp_ms=timestamp_ms,
        request_metadata=metadata_protos,
        tags=tag_protos,
    )
    response = self._call_endpoint(StartTrace, message_to_json(request))
    return TraceInfoV2.from_proto(response.trace_info)
|
1077
|
+
|
1078
|
+
def deprecated_end_trace_v2(
    self,
    request_id: str,
    timestamp_ms: int,
    status: TraceStatus,
    request_metadata: dict[str, str],
    tags: dict[str, str],
) -> TraceInfoV2:
    """
    DEPRECATED. DO NOT USE.

    Complete a trace by updating its TraceInfo record in the backend store.

    Args:
        request_id: Unique string identifier of the trace.
        timestamp_ms: Trace end time, in milliseconds. The execution time in the
            TraceInfo is derived by subtracting the start time from this value.
        status: Final status of the trace.
        request_metadata: Trace metadata, merged with the metadata logged at
            start_trace. Every value is converted with ``str``.
        tags: Trace tags, merged with tags logged at start_trace or via
            set_trace_tag. Every value is converted with ``str``.

    Returns:
        The updated TraceInfo object returned by the server.
    """
    metadata_protos = [
        TraceRequestMetadata(key=meta_key, value=str(meta_value))
        for meta_key, meta_value in request_metadata.items()
    ]
    tag_protos = [
        TraceTag(key=tag_key, value=str(tag_value)) for tag_key, tag_value in tags.items()
    ]
    request = EndTrace(
        request_id=request_id,
        timestamp_ms=timestamp_ms,
        status=status.to_proto(),
        request_metadata=metadata_protos,
        tags=tag_protos,
    )
    payload = message_to_json(request)
    # The EndTrace path embeds the request_id. The v2 prefix is always used here
    # (never v3) to stay compatible with this legacy endpoint.
    trace_path = f"{_REST_API_PATH_PREFIX}/mlflow/traces/{request_id}"
    response = self._call_endpoint(EndTrace, payload, endpoint=trace_path)
    return TraceInfoV2.from_proto(response.trace_info)
|