genesis-flow 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genesis_flow-1.0.0.dist-info/METADATA +822 -0
- genesis_flow-1.0.0.dist-info/RECORD +645 -0
- genesis_flow-1.0.0.dist-info/WHEEL +5 -0
- genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
- genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
- genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
- mlflow/__init__.py +367 -0
- mlflow/__main__.py +3 -0
- mlflow/ag2/__init__.py +56 -0
- mlflow/ag2/ag2_logger.py +294 -0
- mlflow/anthropic/__init__.py +40 -0
- mlflow/anthropic/autolog.py +129 -0
- mlflow/anthropic/chat.py +144 -0
- mlflow/artifacts/__init__.py +268 -0
- mlflow/autogen/__init__.py +144 -0
- mlflow/autogen/chat.py +142 -0
- mlflow/azure/__init__.py +26 -0
- mlflow/azure/auth_handler.py +257 -0
- mlflow/azure/client.py +319 -0
- mlflow/azure/config.py +120 -0
- mlflow/azure/connection_factory.py +340 -0
- mlflow/azure/exceptions.py +27 -0
- mlflow/azure/stores.py +327 -0
- mlflow/azure/utils.py +183 -0
- mlflow/bedrock/__init__.py +45 -0
- mlflow/bedrock/_autolog.py +202 -0
- mlflow/bedrock/chat.py +122 -0
- mlflow/bedrock/stream.py +160 -0
- mlflow/bedrock/utils.py +43 -0
- mlflow/cli.py +707 -0
- mlflow/client.py +12 -0
- mlflow/config/__init__.py +56 -0
- mlflow/crewai/__init__.py +79 -0
- mlflow/crewai/autolog.py +253 -0
- mlflow/crewai/chat.py +29 -0
- mlflow/data/__init__.py +75 -0
- mlflow/data/artifact_dataset_sources.py +170 -0
- mlflow/data/code_dataset_source.py +40 -0
- mlflow/data/dataset.py +123 -0
- mlflow/data/dataset_registry.py +168 -0
- mlflow/data/dataset_source.py +110 -0
- mlflow/data/dataset_source_registry.py +219 -0
- mlflow/data/delta_dataset_source.py +167 -0
- mlflow/data/digest_utils.py +108 -0
- mlflow/data/evaluation_dataset.py +562 -0
- mlflow/data/filesystem_dataset_source.py +81 -0
- mlflow/data/http_dataset_source.py +145 -0
- mlflow/data/huggingface_dataset.py +258 -0
- mlflow/data/huggingface_dataset_source.py +118 -0
- mlflow/data/meta_dataset.py +104 -0
- mlflow/data/numpy_dataset.py +223 -0
- mlflow/data/pandas_dataset.py +231 -0
- mlflow/data/polars_dataset.py +352 -0
- mlflow/data/pyfunc_dataset_mixin.py +31 -0
- mlflow/data/schema.py +76 -0
- mlflow/data/sources.py +1 -0
- mlflow/data/spark_dataset.py +406 -0
- mlflow/data/spark_dataset_source.py +74 -0
- mlflow/data/spark_delta_utils.py +118 -0
- mlflow/data/tensorflow_dataset.py +350 -0
- mlflow/data/uc_volume_dataset_source.py +81 -0
- mlflow/db.py +27 -0
- mlflow/dspy/__init__.py +17 -0
- mlflow/dspy/autolog.py +197 -0
- mlflow/dspy/callback.py +398 -0
- mlflow/dspy/constant.py +1 -0
- mlflow/dspy/load.py +93 -0
- mlflow/dspy/save.py +393 -0
- mlflow/dspy/util.py +109 -0
- mlflow/dspy/wrapper.py +226 -0
- mlflow/entities/__init__.py +104 -0
- mlflow/entities/_mlflow_object.py +52 -0
- mlflow/entities/assessment.py +545 -0
- mlflow/entities/assessment_error.py +80 -0
- mlflow/entities/assessment_source.py +141 -0
- mlflow/entities/dataset.py +92 -0
- mlflow/entities/dataset_input.py +51 -0
- mlflow/entities/dataset_summary.py +62 -0
- mlflow/entities/document.py +48 -0
- mlflow/entities/experiment.py +109 -0
- mlflow/entities/experiment_tag.py +35 -0
- mlflow/entities/file_info.py +45 -0
- mlflow/entities/input_tag.py +35 -0
- mlflow/entities/lifecycle_stage.py +35 -0
- mlflow/entities/logged_model.py +228 -0
- mlflow/entities/logged_model_input.py +26 -0
- mlflow/entities/logged_model_output.py +32 -0
- mlflow/entities/logged_model_parameter.py +46 -0
- mlflow/entities/logged_model_status.py +74 -0
- mlflow/entities/logged_model_tag.py +33 -0
- mlflow/entities/metric.py +200 -0
- mlflow/entities/model_registry/__init__.py +29 -0
- mlflow/entities/model_registry/_model_registry_entity.py +13 -0
- mlflow/entities/model_registry/model_version.py +243 -0
- mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
- mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
- mlflow/entities/model_registry/model_version_search.py +25 -0
- mlflow/entities/model_registry/model_version_stages.py +25 -0
- mlflow/entities/model_registry/model_version_status.py +35 -0
- mlflow/entities/model_registry/model_version_tag.py +35 -0
- mlflow/entities/model_registry/prompt.py +73 -0
- mlflow/entities/model_registry/prompt_version.py +244 -0
- mlflow/entities/model_registry/registered_model.py +175 -0
- mlflow/entities/model_registry/registered_model_alias.py +35 -0
- mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
- mlflow/entities/model_registry/registered_model_search.py +25 -0
- mlflow/entities/model_registry/registered_model_tag.py +35 -0
- mlflow/entities/multipart_upload.py +74 -0
- mlflow/entities/param.py +49 -0
- mlflow/entities/run.py +97 -0
- mlflow/entities/run_data.py +84 -0
- mlflow/entities/run_info.py +188 -0
- mlflow/entities/run_inputs.py +59 -0
- mlflow/entities/run_outputs.py +43 -0
- mlflow/entities/run_status.py +41 -0
- mlflow/entities/run_tag.py +36 -0
- mlflow/entities/source_type.py +31 -0
- mlflow/entities/span.py +774 -0
- mlflow/entities/span_event.py +96 -0
- mlflow/entities/span_status.py +102 -0
- mlflow/entities/trace.py +317 -0
- mlflow/entities/trace_data.py +71 -0
- mlflow/entities/trace_info.py +220 -0
- mlflow/entities/trace_info_v2.py +162 -0
- mlflow/entities/trace_location.py +173 -0
- mlflow/entities/trace_state.py +39 -0
- mlflow/entities/trace_status.py +68 -0
- mlflow/entities/view_type.py +51 -0
- mlflow/environment_variables.py +866 -0
- mlflow/evaluation/__init__.py +16 -0
- mlflow/evaluation/assessment.py +369 -0
- mlflow/evaluation/evaluation.py +411 -0
- mlflow/evaluation/evaluation_tag.py +61 -0
- mlflow/evaluation/fluent.py +48 -0
- mlflow/evaluation/utils.py +201 -0
- mlflow/exceptions.py +213 -0
- mlflow/experiments.py +140 -0
- mlflow/gemini/__init__.py +81 -0
- mlflow/gemini/autolog.py +186 -0
- mlflow/gemini/chat.py +261 -0
- mlflow/genai/__init__.py +71 -0
- mlflow/genai/datasets/__init__.py +67 -0
- mlflow/genai/datasets/evaluation_dataset.py +131 -0
- mlflow/genai/evaluation/__init__.py +3 -0
- mlflow/genai/evaluation/base.py +411 -0
- mlflow/genai/evaluation/constant.py +23 -0
- mlflow/genai/evaluation/utils.py +244 -0
- mlflow/genai/judges/__init__.py +21 -0
- mlflow/genai/judges/databricks.py +404 -0
- mlflow/genai/label_schemas/__init__.py +153 -0
- mlflow/genai/label_schemas/label_schemas.py +209 -0
- mlflow/genai/labeling/__init__.py +159 -0
- mlflow/genai/labeling/labeling.py +250 -0
- mlflow/genai/optimize/__init__.py +13 -0
- mlflow/genai/optimize/base.py +198 -0
- mlflow/genai/optimize/optimizers/__init__.py +4 -0
- mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
- mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
- mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
- mlflow/genai/optimize/types.py +75 -0
- mlflow/genai/optimize/util.py +30 -0
- mlflow/genai/prompts/__init__.py +206 -0
- mlflow/genai/scheduled_scorers.py +431 -0
- mlflow/genai/scorers/__init__.py +26 -0
- mlflow/genai/scorers/base.py +492 -0
- mlflow/genai/scorers/builtin_scorers.py +765 -0
- mlflow/genai/scorers/scorer_utils.py +138 -0
- mlflow/genai/scorers/validation.py +165 -0
- mlflow/genai/utils/data_validation.py +146 -0
- mlflow/genai/utils/enum_utils.py +23 -0
- mlflow/genai/utils/trace_utils.py +211 -0
- mlflow/groq/__init__.py +42 -0
- mlflow/groq/_groq_autolog.py +74 -0
- mlflow/johnsnowlabs/__init__.py +888 -0
- mlflow/langchain/__init__.py +24 -0
- mlflow/langchain/api_request_parallel_processor.py +330 -0
- mlflow/langchain/autolog.py +147 -0
- mlflow/langchain/chat_agent_langgraph.py +340 -0
- mlflow/langchain/constant.py +1 -0
- mlflow/langchain/constants.py +1 -0
- mlflow/langchain/databricks_dependencies.py +444 -0
- mlflow/langchain/langchain_tracer.py +597 -0
- mlflow/langchain/model.py +919 -0
- mlflow/langchain/output_parsers.py +142 -0
- mlflow/langchain/retriever_chain.py +153 -0
- mlflow/langchain/runnables.py +527 -0
- mlflow/langchain/utils/chat.py +402 -0
- mlflow/langchain/utils/logging.py +671 -0
- mlflow/langchain/utils/serialization.py +36 -0
- mlflow/legacy_databricks_cli/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/provider.py +482 -0
- mlflow/litellm/__init__.py +175 -0
- mlflow/llama_index/__init__.py +22 -0
- mlflow/llama_index/autolog.py +55 -0
- mlflow/llama_index/chat.py +43 -0
- mlflow/llama_index/constant.py +1 -0
- mlflow/llama_index/model.py +577 -0
- mlflow/llama_index/pyfunc_wrapper.py +332 -0
- mlflow/llama_index/serialize_objects.py +188 -0
- mlflow/llama_index/tracer.py +561 -0
- mlflow/metrics/__init__.py +479 -0
- mlflow/metrics/base.py +39 -0
- mlflow/metrics/genai/__init__.py +25 -0
- mlflow/metrics/genai/base.py +101 -0
- mlflow/metrics/genai/genai_metric.py +771 -0
- mlflow/metrics/genai/metric_definitions.py +450 -0
- mlflow/metrics/genai/model_utils.py +371 -0
- mlflow/metrics/genai/prompt_template.py +68 -0
- mlflow/metrics/genai/prompts/__init__.py +0 -0
- mlflow/metrics/genai/prompts/v1.py +422 -0
- mlflow/metrics/genai/utils.py +6 -0
- mlflow/metrics/metric_definitions.py +619 -0
- mlflow/mismatch.py +34 -0
- mlflow/mistral/__init__.py +34 -0
- mlflow/mistral/autolog.py +71 -0
- mlflow/mistral/chat.py +135 -0
- mlflow/ml_package_versions.py +452 -0
- mlflow/models/__init__.py +97 -0
- mlflow/models/auth_policy.py +83 -0
- mlflow/models/cli.py +354 -0
- mlflow/models/container/__init__.py +294 -0
- mlflow/models/container/scoring_server/__init__.py +0 -0
- mlflow/models/container/scoring_server/nginx.conf +39 -0
- mlflow/models/dependencies_schemas.py +287 -0
- mlflow/models/display_utils.py +158 -0
- mlflow/models/docker_utils.py +211 -0
- mlflow/models/evaluation/__init__.py +23 -0
- mlflow/models/evaluation/_shap_patch.py +64 -0
- mlflow/models/evaluation/artifacts.py +194 -0
- mlflow/models/evaluation/base.py +1811 -0
- mlflow/models/evaluation/calibration_curve.py +109 -0
- mlflow/models/evaluation/default_evaluator.py +996 -0
- mlflow/models/evaluation/deprecated.py +23 -0
- mlflow/models/evaluation/evaluator_registry.py +80 -0
- mlflow/models/evaluation/evaluators/classifier.py +704 -0
- mlflow/models/evaluation/evaluators/default.py +233 -0
- mlflow/models/evaluation/evaluators/regressor.py +96 -0
- mlflow/models/evaluation/evaluators/shap.py +296 -0
- mlflow/models/evaluation/lift_curve.py +178 -0
- mlflow/models/evaluation/utils/metric.py +123 -0
- mlflow/models/evaluation/utils/trace.py +179 -0
- mlflow/models/evaluation/validation.py +434 -0
- mlflow/models/flavor_backend.py +93 -0
- mlflow/models/flavor_backend_registry.py +53 -0
- mlflow/models/model.py +1639 -0
- mlflow/models/model_config.py +150 -0
- mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
- mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
- mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
- mlflow/models/python_api.py +369 -0
- mlflow/models/rag_signatures.py +128 -0
- mlflow/models/resources.py +321 -0
- mlflow/models/signature.py +662 -0
- mlflow/models/utils.py +2054 -0
- mlflow/models/wheeled_model.py +280 -0
- mlflow/openai/__init__.py +57 -0
- mlflow/openai/_agent_tracer.py +364 -0
- mlflow/openai/api_request_parallel_processor.py +131 -0
- mlflow/openai/autolog.py +509 -0
- mlflow/openai/constant.py +1 -0
- mlflow/openai/model.py +824 -0
- mlflow/openai/utils/chat_schema.py +367 -0
- mlflow/optuna/__init__.py +3 -0
- mlflow/optuna/storage.py +646 -0
- mlflow/plugins/__init__.py +72 -0
- mlflow/plugins/base.py +358 -0
- mlflow/plugins/builtin/__init__.py +24 -0
- mlflow/plugins/builtin/pytorch_plugin.py +150 -0
- mlflow/plugins/builtin/sklearn_plugin.py +158 -0
- mlflow/plugins/builtin/transformers_plugin.py +187 -0
- mlflow/plugins/cli.py +321 -0
- mlflow/plugins/discovery.py +340 -0
- mlflow/plugins/manager.py +465 -0
- mlflow/plugins/registry.py +316 -0
- mlflow/plugins/templates/framework_plugin_template.py +329 -0
- mlflow/prompt/constants.py +20 -0
- mlflow/prompt/promptlab_model.py +197 -0
- mlflow/prompt/registry_utils.py +248 -0
- mlflow/promptflow/__init__.py +495 -0
- mlflow/protos/__init__.py +0 -0
- mlflow/protos/assessments_pb2.py +174 -0
- mlflow/protos/databricks_artifacts_pb2.py +489 -0
- mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
- mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
- mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
- mlflow/protos/databricks_pb2.py +267 -0
- mlflow/protos/databricks_trace_server_pb2.py +374 -0
- mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
- mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
- mlflow/protos/facet_feature_statistics_pb2.py +296 -0
- mlflow/protos/internal_pb2.py +77 -0
- mlflow/protos/mlflow_artifacts_pb2.py +336 -0
- mlflow/protos/model_registry_pb2.py +1073 -0
- mlflow/protos/scalapb/__init__.py +0 -0
- mlflow/protos/scalapb/scalapb_pb2.py +104 -0
- mlflow/protos/service_pb2.py +2600 -0
- mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
- mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
- mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
- mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
- mlflow/py.typed +0 -0
- mlflow/pydantic_ai/__init__.py +57 -0
- mlflow/pydantic_ai/autolog.py +173 -0
- mlflow/pyfunc/__init__.py +3844 -0
- mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
- mlflow/pyfunc/backend.py +523 -0
- mlflow/pyfunc/context.py +78 -0
- mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
- mlflow/pyfunc/loaders/__init__.py +7 -0
- mlflow/pyfunc/loaders/chat_agent.py +117 -0
- mlflow/pyfunc/loaders/chat_model.py +125 -0
- mlflow/pyfunc/loaders/code_model.py +31 -0
- mlflow/pyfunc/loaders/responses_agent.py +112 -0
- mlflow/pyfunc/mlserver.py +46 -0
- mlflow/pyfunc/model.py +1473 -0
- mlflow/pyfunc/scoring_server/__init__.py +604 -0
- mlflow/pyfunc/scoring_server/app.py +7 -0
- mlflow/pyfunc/scoring_server/client.py +146 -0
- mlflow/pyfunc/spark_model_cache.py +48 -0
- mlflow/pyfunc/stdin_server.py +44 -0
- mlflow/pyfunc/utils/__init__.py +3 -0
- mlflow/pyfunc/utils/data_validation.py +224 -0
- mlflow/pyfunc/utils/environment.py +22 -0
- mlflow/pyfunc/utils/input_converter.py +47 -0
- mlflow/pyfunc/utils/serving_data_parser.py +11 -0
- mlflow/pytorch/__init__.py +1171 -0
- mlflow/pytorch/_lightning_autolog.py +580 -0
- mlflow/pytorch/_pytorch_autolog.py +50 -0
- mlflow/pytorch/pickle_module.py +35 -0
- mlflow/rfunc/__init__.py +42 -0
- mlflow/rfunc/backend.py +134 -0
- mlflow/runs.py +89 -0
- mlflow/server/__init__.py +302 -0
- mlflow/server/auth/__init__.py +1224 -0
- mlflow/server/auth/__main__.py +4 -0
- mlflow/server/auth/basic_auth.ini +6 -0
- mlflow/server/auth/cli.py +11 -0
- mlflow/server/auth/client.py +537 -0
- mlflow/server/auth/config.py +34 -0
- mlflow/server/auth/db/__init__.py +0 -0
- mlflow/server/auth/db/cli.py +18 -0
- mlflow/server/auth/db/migrations/__init__.py +0 -0
- mlflow/server/auth/db/migrations/alembic.ini +110 -0
- mlflow/server/auth/db/migrations/env.py +76 -0
- mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
- mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
- mlflow/server/auth/db/models.py +67 -0
- mlflow/server/auth/db/utils.py +37 -0
- mlflow/server/auth/entities.py +165 -0
- mlflow/server/auth/logo.py +14 -0
- mlflow/server/auth/permissions.py +65 -0
- mlflow/server/auth/routes.py +18 -0
- mlflow/server/auth/sqlalchemy_store.py +263 -0
- mlflow/server/graphql/__init__.py +0 -0
- mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
- mlflow/server/graphql/graphql_custom_scalars.py +24 -0
- mlflow/server/graphql/graphql_errors.py +15 -0
- mlflow/server/graphql/graphql_no_batching.py +89 -0
- mlflow/server/graphql/graphql_schema_extensions.py +74 -0
- mlflow/server/handlers.py +3217 -0
- mlflow/server/prometheus_exporter.py +17 -0
- mlflow/server/validation.py +30 -0
- mlflow/shap/__init__.py +691 -0
- mlflow/sklearn/__init__.py +1994 -0
- mlflow/sklearn/utils.py +1041 -0
- mlflow/smolagents/__init__.py +66 -0
- mlflow/smolagents/autolog.py +139 -0
- mlflow/smolagents/chat.py +29 -0
- mlflow/store/__init__.py +10 -0
- mlflow/store/_unity_catalog/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/constants.py +2 -0
- mlflow/store/_unity_catalog/registry/__init__.py +6 -0
- mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
- mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
- mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
- mlflow/store/_unity_catalog/registry/utils.py +121 -0
- mlflow/store/artifact/__init__.py +0 -0
- mlflow/store/artifact/artifact_repo.py +472 -0
- mlflow/store/artifact/artifact_repository_registry.py +154 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
- mlflow/store/artifact/cli.py +141 -0
- mlflow/store/artifact/cloud_artifact_repo.py +332 -0
- mlflow/store/artifact/databricks_artifact_repo.py +729 -0
- mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
- mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
- mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
- mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
- mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
- mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
- mlflow/store/artifact/ftp_artifact_repo.py +132 -0
- mlflow/store/artifact/gcs_artifact_repo.py +296 -0
- mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
- mlflow/store/artifact/http_artifact_repo.py +218 -0
- mlflow/store/artifact/local_artifact_repo.py +142 -0
- mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
- mlflow/store/artifact/models_artifact_repo.py +259 -0
- mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
- mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
- mlflow/store/artifact/r2_artifact_repo.py +70 -0
- mlflow/store/artifact/runs_artifact_repo.py +265 -0
- mlflow/store/artifact/s3_artifact_repo.py +330 -0
- mlflow/store/artifact/sftp_artifact_repo.py +141 -0
- mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
- mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
- mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
- mlflow/store/artifact/utils/__init__.py +0 -0
- mlflow/store/artifact/utils/models.py +148 -0
- mlflow/store/db/__init__.py +0 -0
- mlflow/store/db/base_sql_model.py +3 -0
- mlflow/store/db/db_types.py +10 -0
- mlflow/store/db/utils.py +314 -0
- mlflow/store/db_migrations/__init__.py +0 -0
- mlflow/store/db_migrations/alembic.ini +74 -0
- mlflow/store/db_migrations/env.py +84 -0
- mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
- mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
- mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
- mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
- mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
- mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
- mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
- mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
- mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
- mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
- mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
- mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
- mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
- mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
- mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
- mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
- mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
- mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
- mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
- mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
- mlflow/store/db_migrations/versions/__init__.py +0 -0
- mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
- mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
- mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
- mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
- mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
- mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
- mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
- mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
- mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
- mlflow/store/entities/__init__.py +3 -0
- mlflow/store/entities/paged_list.py +18 -0
- mlflow/store/model_registry/__init__.py +10 -0
- mlflow/store/model_registry/abstract_store.py +1081 -0
- mlflow/store/model_registry/base_rest_store.py +44 -0
- mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
- mlflow/store/model_registry/dbmodels/__init__.py +0 -0
- mlflow/store/model_registry/dbmodels/models.py +206 -0
- mlflow/store/model_registry/file_store.py +1091 -0
- mlflow/store/model_registry/rest_store.py +481 -0
- mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
- mlflow/store/tracking/__init__.py +23 -0
- mlflow/store/tracking/abstract_store.py +816 -0
- mlflow/store/tracking/dbmodels/__init__.py +0 -0
- mlflow/store/tracking/dbmodels/initial_models.py +243 -0
- mlflow/store/tracking/dbmodels/models.py +1073 -0
- mlflow/store/tracking/file_store.py +2438 -0
- mlflow/store/tracking/postgres_managed_identity.py +146 -0
- mlflow/store/tracking/rest_store.py +1131 -0
- mlflow/store/tracking/sqlalchemy_store.py +2785 -0
- mlflow/system_metrics/__init__.py +61 -0
- mlflow/system_metrics/metrics/__init__.py +0 -0
- mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
- mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
- mlflow/system_metrics/metrics/disk_monitor.py +21 -0
- mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
- mlflow/system_metrics/metrics/network_monitor.py +34 -0
- mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
- mlflow/system_metrics/system_metrics_monitor.py +198 -0
- mlflow/tracing/__init__.py +16 -0
- mlflow/tracing/assessment.py +356 -0
- mlflow/tracing/client.py +531 -0
- mlflow/tracing/config.py +125 -0
- mlflow/tracing/constant.py +105 -0
- mlflow/tracing/destination.py +81 -0
- mlflow/tracing/display/__init__.py +40 -0
- mlflow/tracing/display/display_handler.py +196 -0
- mlflow/tracing/export/async_export_queue.py +186 -0
- mlflow/tracing/export/inference_table.py +138 -0
- mlflow/tracing/export/mlflow_v3.py +137 -0
- mlflow/tracing/export/utils.py +70 -0
- mlflow/tracing/fluent.py +1417 -0
- mlflow/tracing/processor/base_mlflow.py +199 -0
- mlflow/tracing/processor/inference_table.py +175 -0
- mlflow/tracing/processor/mlflow_v3.py +47 -0
- mlflow/tracing/processor/otel.py +73 -0
- mlflow/tracing/provider.py +487 -0
- mlflow/tracing/trace_manager.py +200 -0
- mlflow/tracing/utils/__init__.py +616 -0
- mlflow/tracing/utils/artifact_utils.py +28 -0
- mlflow/tracing/utils/copy.py +55 -0
- mlflow/tracing/utils/environment.py +55 -0
- mlflow/tracing/utils/exception.py +21 -0
- mlflow/tracing/utils/once.py +35 -0
- mlflow/tracing/utils/otlp.py +63 -0
- mlflow/tracing/utils/processor.py +54 -0
- mlflow/tracing/utils/search.py +292 -0
- mlflow/tracing/utils/timeout.py +250 -0
- mlflow/tracing/utils/token.py +19 -0
- mlflow/tracing/utils/truncation.py +124 -0
- mlflow/tracing/utils/warning.py +76 -0
- mlflow/tracking/__init__.py +39 -0
- mlflow/tracking/_model_registry/__init__.py +1 -0
- mlflow/tracking/_model_registry/client.py +764 -0
- mlflow/tracking/_model_registry/fluent.py +853 -0
- mlflow/tracking/_model_registry/registry.py +67 -0
- mlflow/tracking/_model_registry/utils.py +251 -0
- mlflow/tracking/_tracking_service/__init__.py +0 -0
- mlflow/tracking/_tracking_service/client.py +883 -0
- mlflow/tracking/_tracking_service/registry.py +56 -0
- mlflow/tracking/_tracking_service/utils.py +275 -0
- mlflow/tracking/artifact_utils.py +179 -0
- mlflow/tracking/client.py +5900 -0
- mlflow/tracking/context/__init__.py +0 -0
- mlflow/tracking/context/abstract_context.py +35 -0
- mlflow/tracking/context/databricks_cluster_context.py +15 -0
- mlflow/tracking/context/databricks_command_context.py +15 -0
- mlflow/tracking/context/databricks_job_context.py +49 -0
- mlflow/tracking/context/databricks_notebook_context.py +41 -0
- mlflow/tracking/context/databricks_repo_context.py +43 -0
- mlflow/tracking/context/default_context.py +51 -0
- mlflow/tracking/context/git_context.py +32 -0
- mlflow/tracking/context/registry.py +98 -0
- mlflow/tracking/context/system_environment_context.py +15 -0
- mlflow/tracking/default_experiment/__init__.py +1 -0
- mlflow/tracking/default_experiment/abstract_context.py +43 -0
- mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
- mlflow/tracking/default_experiment/registry.py +75 -0
- mlflow/tracking/fluent.py +3595 -0
- mlflow/tracking/metric_value_conversion_utils.py +93 -0
- mlflow/tracking/multimedia.py +206 -0
- mlflow/tracking/registry.py +86 -0
- mlflow/tracking/request_auth/__init__.py +0 -0
- mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
- mlflow/tracking/request_auth/registry.py +60 -0
- mlflow/tracking/request_header/__init__.py +0 -0
- mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
- mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
- mlflow/tracking/request_header/default_request_header_provider.py +17 -0
- mlflow/tracking/request_header/registry.py +79 -0
- mlflow/transformers/__init__.py +2982 -0
- mlflow/transformers/flavor_config.py +258 -0
- mlflow/transformers/hub_utils.py +83 -0
- mlflow/transformers/llm_inference_utils.py +468 -0
- mlflow/transformers/model_io.py +301 -0
- mlflow/transformers/peft.py +51 -0
- mlflow/transformers/signature.py +183 -0
- mlflow/transformers/torch_utils.py +55 -0
- mlflow/types/__init__.py +21 -0
- mlflow/types/agent.py +270 -0
- mlflow/types/chat.py +240 -0
- mlflow/types/llm.py +935 -0
- mlflow/types/responses.py +139 -0
- mlflow/types/responses_helpers.py +416 -0
- mlflow/types/schema.py +1505 -0
- mlflow/types/type_hints.py +647 -0
- mlflow/types/utils.py +753 -0
- mlflow/utils/__init__.py +283 -0
- mlflow/utils/_capture_modules.py +256 -0
- mlflow/utils/_capture_transformers_modules.py +75 -0
- mlflow/utils/_spark_utils.py +201 -0
- mlflow/utils/_unity_catalog_oss_utils.py +97 -0
- mlflow/utils/_unity_catalog_utils.py +479 -0
- mlflow/utils/annotations.py +218 -0
- mlflow/utils/arguments_utils.py +16 -0
- mlflow/utils/async_logging/__init__.py +1 -0
- mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
- mlflow/utils/async_logging/async_logging_queue.py +366 -0
- mlflow/utils/async_logging/run_artifact.py +38 -0
- mlflow/utils/async_logging/run_batch.py +58 -0
- mlflow/utils/async_logging/run_operations.py +49 -0
- mlflow/utils/autologging_utils/__init__.py +737 -0
- mlflow/utils/autologging_utils/client.py +432 -0
- mlflow/utils/autologging_utils/config.py +33 -0
- mlflow/utils/autologging_utils/events.py +294 -0
- mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
- mlflow/utils/autologging_utils/metrics_queue.py +71 -0
- mlflow/utils/autologging_utils/safety.py +1104 -0
- mlflow/utils/autologging_utils/versioning.py +95 -0
- mlflow/utils/checkpoint_utils.py +206 -0
- mlflow/utils/class_utils.py +6 -0
- mlflow/utils/cli_args.py +257 -0
- mlflow/utils/conda.py +354 -0
- mlflow/utils/credentials.py +231 -0
- mlflow/utils/data_utils.py +17 -0
- mlflow/utils/databricks_utils.py +1436 -0
- mlflow/utils/docstring_utils.py +477 -0
- mlflow/utils/doctor.py +133 -0
- mlflow/utils/download_cloud_file_chunk.py +43 -0
- mlflow/utils/env_manager.py +16 -0
- mlflow/utils/env_pack.py +131 -0
- mlflow/utils/environment.py +1009 -0
- mlflow/utils/exception_utils.py +14 -0
- mlflow/utils/file_utils.py +978 -0
- mlflow/utils/git_utils.py +77 -0
- mlflow/utils/gorilla.py +797 -0
- mlflow/utils/import_hooks/__init__.py +363 -0
- mlflow/utils/lazy_load.py +51 -0
- mlflow/utils/logging_utils.py +168 -0
- mlflow/utils/mime_type_utils.py +58 -0
- mlflow/utils/mlflow_tags.py +103 -0
- mlflow/utils/model_utils.py +486 -0
- mlflow/utils/name_utils.py +346 -0
- mlflow/utils/nfs_on_spark.py +62 -0
- mlflow/utils/openai_utils.py +164 -0
- mlflow/utils/os.py +12 -0
- mlflow/utils/oss_registry_utils.py +29 -0
- mlflow/utils/plugins.py +17 -0
- mlflow/utils/process.py +182 -0
- mlflow/utils/promptlab_utils.py +146 -0
- mlflow/utils/proto_json_utils.py +743 -0
- mlflow/utils/pydantic_utils.py +54 -0
- mlflow/utils/request_utils.py +279 -0
- mlflow/utils/requirements_utils.py +704 -0
- mlflow/utils/rest_utils.py +673 -0
- mlflow/utils/search_logged_model_utils.py +127 -0
- mlflow/utils/search_utils.py +2111 -0
- mlflow/utils/secure_loading.py +221 -0
- mlflow/utils/security_validation.py +384 -0
- mlflow/utils/server_cli_utils.py +61 -0
- mlflow/utils/spark_utils.py +15 -0
- mlflow/utils/string_utils.py +138 -0
- mlflow/utils/thread_utils.py +63 -0
- mlflow/utils/time.py +54 -0
- mlflow/utils/timeout.py +42 -0
- mlflow/utils/uri.py +572 -0
- mlflow/utils/validation.py +662 -0
- mlflow/utils/virtualenv.py +458 -0
- mlflow/utils/warnings_utils.py +25 -0
- mlflow/utils/yaml_utils.py +179 -0
- mlflow/version.py +24 -0
@@ -0,0 +1,1286 @@
|
|
1
|
+
import logging
|
2
|
+
import urllib
|
3
|
+
from typing import Any, Optional, Union
|
4
|
+
|
5
|
+
import sqlalchemy
|
6
|
+
from sqlalchemy.future import select
|
7
|
+
|
8
|
+
import mlflow.store.db.utils
|
9
|
+
from mlflow.entities.model_registry.model_version_stages import (
|
10
|
+
ALL_STAGES,
|
11
|
+
DEFAULT_STAGES_FOR_GET_LATEST_VERSIONS,
|
12
|
+
STAGE_ARCHIVED,
|
13
|
+
STAGE_DELETED_INTERNAL,
|
14
|
+
get_canonical_stage,
|
15
|
+
)
|
16
|
+
from mlflow.entities.model_registry.prompt_version import IS_PROMPT_TAG_KEY
|
17
|
+
from mlflow.exceptions import MlflowException
|
18
|
+
from mlflow.prompt.registry_utils import handle_resource_already_exist_error, has_prompt_tag
|
19
|
+
from mlflow.protos.databricks_pb2 import (
|
20
|
+
INVALID_PARAMETER_VALUE,
|
21
|
+
INVALID_STATE,
|
22
|
+
RESOURCE_ALREADY_EXISTS,
|
23
|
+
RESOURCE_DOES_NOT_EXIST,
|
24
|
+
)
|
25
|
+
from mlflow.store.artifact.utils.models import _parse_model_uri
|
26
|
+
from mlflow.store.entities.paged_list import PagedList
|
27
|
+
from mlflow.store.model_registry import (
|
28
|
+
SEARCH_MODEL_VERSION_MAX_RESULTS_DEFAULT,
|
29
|
+
SEARCH_MODEL_VERSION_MAX_RESULTS_THRESHOLD,
|
30
|
+
SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT,
|
31
|
+
SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD,
|
32
|
+
)
|
33
|
+
from mlflow.store.model_registry.abstract_store import AbstractStore
|
34
|
+
from mlflow.store.model_registry.dbmodels.models import (
|
35
|
+
SqlModelVersion,
|
36
|
+
SqlModelVersionTag,
|
37
|
+
SqlRegisteredModel,
|
38
|
+
SqlRegisteredModelAlias,
|
39
|
+
SqlRegisteredModelTag,
|
40
|
+
)
|
41
|
+
from mlflow.tracking.client import MlflowClient
|
42
|
+
from mlflow.utils.search_utils import SearchModelUtils, SearchModelVersionUtils, SearchUtils
|
43
|
+
from mlflow.utils.time import get_current_time_millis
|
44
|
+
from mlflow.utils.uri import extract_db_type_from_uri
|
45
|
+
from mlflow.utils.validation import (
|
46
|
+
_validate_model_alias_name,
|
47
|
+
_validate_model_name,
|
48
|
+
_validate_model_renaming,
|
49
|
+
_validate_model_version,
|
50
|
+
_validate_model_version_tag,
|
51
|
+
_validate_registered_model_tag,
|
52
|
+
_validate_tag_name,
|
53
|
+
)
|
54
|
+
|
55
|
+
_logger = logging.getLogger(__name__)

# Force configuration of every ORM mapper declared so far. Relationship attributes that are
# only defined from the other side through a backreference (for example
# ``SqlRegisteredModel.model_versions``) become attributes of their target class only once
# mappers are configured, so doing it at import time makes them usable immediately. See
# https://docs.sqlalchemy.org/en/latest/orm/mapping_api.html#sqlalchemy.orm.configure_mappers
# and https://docs.sqlalchemy.org/en/latest/orm/mapping_api.html#sqlalchemy.orm.mapper.Mapper
sqlalchemy.orm.configure_mappers()
|
64
|
+
|
65
|
+
|
66
|
+
class SqlAlchemyStore(AbstractStore):
|
67
|
+
"""
|
68
|
+
This entity may change or be removed in a future release without warning.
|
69
|
+
SQLAlchemy compliant backend store for tracking meta data for MLflow entities. MLflow
|
70
|
+
supports the database dialects ``mysql``, ``mssql``, ``sqlite``, and ``postgresql``.
|
71
|
+
As specified in the
|
72
|
+
`SQLAlchemy docs <https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls>`_ ,
|
73
|
+
the database URI is expected in the format
|
74
|
+
``<dialect>+<driver>://<username>:<password>@<host>:<port>/<database>``. If you do not
|
75
|
+
specify a driver, SQLAlchemy uses a dialect's default driver.
|
76
|
+
|
77
|
+
This store interacts with SQL store using SQLAlchemy abstractions defined for MLflow entities.
|
78
|
+
:py:class:`mlflow.store.model_registry.models.RegisteredModel` and
|
79
|
+
:py:class:`mlflow.store.model_registry.models.ModelVersion`
|
80
|
+
"""
|
81
|
+
|
82
|
+
CREATE_MODEL_VERSION_RETRIES = 3
|
83
|
+
|
84
|
+
def __init__(self, db_uri):
    """
    Create a database backed model registry store.

    Args:
        db_uri: The SQLAlchemy database URI string to connect to the database. See
            the `SQLAlchemy docs
            <https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls>`_
            for format specifications. MLflow supports the dialects ``mysql``,
            ``mssql``, ``sqlite``, and ``postgresql``.

    Raises:
        MlflowException: If the registered model or model version table is still missing
            after table initialization.
    """
    super().__init__()
    self.db_uri = db_uri
    self.db_type = extract_db_type_from_uri(db_uri)
    self.engine = mlflow.store.db.utils.create_sqlalchemy_engine_with_retry(db_uri)
    # Create the schema only when it is not fully present; existing tables are left alone.
    if not mlflow.store.db.utils._all_tables_exist(self.engine):
        mlflow.store.db.utils._initialize_tables(self.engine)
    # Verify that all model registry tables exist.
    SqlAlchemyStore._verify_registry_tables_exist(self.engine)
    SessionMaker = sqlalchemy.orm.sessionmaker(bind=self.engine)
    # Every store method opens its unit of work via ``with self.ManagedSessionMaker()``.
    self.ManagedSessionMaker = mlflow.store.db.utils._get_managed_session_maker(
        SessionMaker, self.db_type
    )
    # TODO: verify schema here once we add logic to initialize the registry tables if they
    # don't exist (schema verification will fail in tests otherwise)
    # mlflow.store.db.utils._verify_schema(self.engine)
|
112
|
+
|
113
|
+
def _get_dialect(self):
|
114
|
+
return self.engine.dialect.name
|
115
|
+
|
116
|
+
def _dispose_engine(self):
|
117
|
+
self.engine.dispose()
|
118
|
+
|
119
|
+
@staticmethod
def _verify_registry_tables_exist(engine):
    """
    Check that the model registry tables exist in the database behind ``engine``.

    Raises:
        MlflowException: If the registered model or the model version table is missing.
    """
    existing_tables = set(sqlalchemy.inspect(engine).get_table_names())
    required_tables = {SqlRegisteredModel.__tablename__, SqlModelVersion.__tablename__}
    if not required_tables.issubset(existing_tables):
        # TODO: Replace the MlflowException with the following line once it's possible to run
        # the registry against a different DB than the tracking server:
        # mlflow.store.db.utils._initialize_tables(engine)
        raise MlflowException("Database migration in unexpected state. Run manual upgrade.")
|
132
|
+
|
133
|
+
@staticmethod
def _get_eager_registered_model_query_options():
    """
    Return SQLAlchemy loader options that eagerly load ``registered_model_tags`` whenever a
    registered model is fetched.
    """
    # A subquery load (instead of a joined load) keeps the memory overhead of eager loading
    # down. Background on the available relationship loading techniques:
    # https://docs.sqlalchemy.org/en/13/orm/
    # loading_relationships.html#relationship-loading-techniques
    tags_loader = sqlalchemy.orm.subqueryload(SqlRegisteredModel.registered_model_tags)
    return [tags_loader]
|
145
|
+
|
146
|
+
@staticmethod
def _get_eager_model_version_query_options():
    """
    Return SQLAlchemy loader options that eagerly load ``model_version_tags`` whenever a
    model version is fetched.
    """
    # A subquery load (instead of a joined load) keeps the memory overhead of eager loading
    # down. Background on the available relationship loading techniques:
    # https://docs.sqlalchemy.org/en/13/orm/
    # loading_relationships.html#relationship-loading-techniques
    tags_loader = sqlalchemy.orm.subqueryload(SqlModelVersion.model_version_tags)
    return [tags_loader]
|
158
|
+
|
159
|
+
def create_registered_model(self, name, tags=None, description=None, deployment_job_id=None):
    """
    Create a new registered model in backend store.

    Args:
        name: Name of the new model. This is expected to be unique in the backend store.
        tags: A list of :py:class:`mlflow.entities.model_registry.RegisteredModelTag`
            instances associated with this registered model. If several tags share a key,
            the last one wins.
        description: Description of the registered model.
        deployment_job_id: Optional deployment job ID. Accepted for interface compatibility;
            this store does not persist it.

    Returns:
        A single object of :py:class:`mlflow.entities.model_registry.RegisteredModel`
        created in the backend.

    Raises:
        MlflowException: If ``name`` or any tag fails validation. A name collision is
            reported through ``handle_resource_already_exist_error`` (expected to raise).
    """
    _validate_model_name(name)
    for tag in tags or []:
        _validate_registered_model_tag(tag.key, tag.value)
    with self.ManagedSessionMaker() as session:
        try:
            creation_time = get_current_time_millis()
            registered_model = SqlRegisteredModel(
                name=name,
                creation_time=creation_time,
                last_updated_time=creation_time,
                description=description,
            )
            # Collapse duplicate keys so each key is inserted only once (last value wins).
            tags_dict = {tag.key: tag.value for tag in tags or []}
            registered_model.registered_model_tags = [
                SqlRegisteredModelTag(key=key, value=value) for key, value in tags_dict.items()
            ]
            session.add(registered_model)
            session.flush()
            return registered_model.to_mlflow_entity()
        except sqlalchemy.exc.IntegrityError:
            # The insert collided with an existing row; fetch it so the helper can compare
            # the prompt tags of the existing entity with those of the requested one.
            existing_model = self.get_registered_model(name)
            handle_resource_already_exist_error(
                name, has_prompt_tag(existing_model._tags), has_prompt_tag(tags)
            )
|
200
|
+
|
201
|
+
@classmethod
def _get_registered_model(cls, session, name, eager=False):
    """
    Fetch the single ``SqlRegisteredModel`` row called ``name`` using ``session``.

    Args:
        session: Active SQLAlchemy session used for the query.
        name: Registered model name; validated before querying.
        eager: If ``True``, eagerly loads the registered model's tags. If ``False``, these
            attributes are not eagerly loaded and will be loaded when their corresponding
            object properties are accessed from the resulting ``SqlRegisteredModel`` object.

    Raises:
        MlflowException: ``RESOURCE_DOES_NOT_EXIST`` when no row matches, ``INVALID_STATE``
            when more than one does.
    """
    _validate_model_name(name)
    loader_options = cls._get_eager_registered_model_query_options() if eager else []
    query = session.query(SqlRegisteredModel).options(*loader_options)
    matches = query.filter(SqlRegisteredModel.name == name).all()

    match_count = len(matches)
    if match_count == 1:
        return matches[0]
    if match_count == 0:
        raise MlflowException(
            f"Registered Model with name={name} not found", RESOURCE_DOES_NOT_EXIST
        )
    raise MlflowException(
        f"Expected only 1 registered model with name={name}. Found {match_count}.",
        INVALID_STATE,
    )
|
228
|
+
|
229
|
+
def update_registered_model(self, name, description, deployment_job_id=None):
    """
    Replace the description of a registered model and bump its ``last_updated_time``.

    Args:
        name: Registered model name.
        description: New description.
        deployment_job_id: Optional deployment job ID. Not used by this store.

    Returns:
        A single updated :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
    """
    with self.ManagedSessionMaker() as session:
        model_row = self._get_registered_model(session, name)
        now = get_current_time_millis()
        model_row.description = description
        model_row.last_updated_time = now
        session.add(model_row)
        session.flush()
        return model_row.to_mlflow_entity()
|
250
|
+
|
251
|
+
def rename_registered_model(self, name, new_name):
    """
    Rename a registered model together with all of its model versions.

    Args:
        name: Registered model name.
        new_name: New proposed name.

    Returns:
        A single updated :py:class:`mlflow.entities.model_registry.RegisteredModel` object.

    Raises:
        MlflowException: ``RESOURCE_ALREADY_EXISTS`` when the flush hits an integrity error
            (for example because ``new_name`` is already taken).
    """
    _validate_model_renaming(new_name)
    with self.ManagedSessionMaker() as session:
        model_row = self._get_registered_model(session, name)
        try:
            now = get_current_time_millis()
            model_row.name = new_name
            version_rows = model_row.model_versions
            # Version rows carry the model name too, so they are renamed in the same flush.
            for version_row in version_rows:
                version_row.name = new_name
                version_row.last_updated_time = now
            model_row.last_updated_time = now
            session.add_all([model_row, *version_rows])
            session.flush()
            return model_row.to_mlflow_entity()
        except sqlalchemy.exc.IntegrityError as e:
            raise MlflowException(
                f"Registered Model (name={new_name}) already exists. Error: {e}",
                RESOURCE_ALREADY_EXISTS,
            )
|
281
|
+
|
282
|
+
def delete_registered_model(self, name):
    """
    Delete the registered model called ``name``.

    Args:
        name: Registered model name.

    Returns:
        None

    Raises:
        MlflowException: If no registered model with ``name`` exists (raised by the lookup).
    """
    with self.ManagedSessionMaker() as session:
        session.delete(self._get_registered_model(session, name))
|
296
|
+
|
297
|
+
def _compute_next_token(self, max_results_for_query, current_size, offset, max_results):
    """
    Return the page token for the following page, or ``None`` when this is the last page.

    ``search_registered_models`` asks the database for one row more than the page size;
    receiving ``max_results_for_query`` rows therefore means more results exist after
    ``offset + max_results``.
    """
    if current_size != max_results_for_query:
        return None
    return SearchUtils.create_page_token(offset + max_results)
|
303
|
+
|
304
|
+
def search_registered_models(
    self,
    filter_string=None,
    max_results=SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT,
    order_by=None,
    page_token=None,
):
    """
    Search for registered models in backend that satisfy the filter criteria.

    Args:
        filter_string: Filter query string, defaults to searching all registered models.
        max_results: Maximum number of registered models desired. Must be at least 1 and at
            most ``SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD``.
        order_by: List of column names with ASC|DESC annotation, to be used for ordering
            matching search results.
        page_token: Token specifying the next page of results. It should be obtained from
            a ``search_registered_models`` call.

    Returns:
        A PagedList of :py:class:`mlflow.entities.model_registry.RegisteredModel` objects
        that satisfy the search expressions. The pagination token for the next page can be
        obtained via the ``token`` attribute of the object.

    Raises:
        MlflowException: ``INVALID_PARAMETER_VALUE`` when ``max_results`` is out of range, or
            when the filter or order-by expressions are invalid.
    """
    # A non-positive page size breaks pagination: with 0, every call would return an empty
    # page together with a token pointing back at the same offset.
    if max_results < 1:
        raise MlflowException(
            "Invalid value for request parameter max_results. It must be a positive "
            f"integer, but got value {max_results}",
            INVALID_PARAMETER_VALUE,
        )
    if max_results > SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD:
        raise MlflowException(
            "Invalid value for request parameter max_results. It must be at most "
            f"{SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD}, but got value {max_results}",
            INVALID_PARAMETER_VALUE,
        )

    parsed_filters = SearchModelUtils.parse_search_filter(filter_string)

    filter_query = self._get_search_registered_model_filter_query(
        parsed_filters, self._get_dialect()
    )

    parsed_orderby = self._parse_search_registered_models_order_by(order_by)
    offset = SearchUtils.parse_start_offset_from_page_token(page_token)
    # we query for max_results + 1 items to check whether there is another page to return.
    # this remediates having to make another query which returns no items.
    max_results_for_query = max_results + 1

    with self.ManagedSessionMaker() as session:
        query = (
            filter_query.options(*self._get_eager_registered_model_query_options())
            .order_by(*parsed_orderby)
            .limit(max_results_for_query)
        )
        if page_token:
            query = query.offset(offset)
        sql_registered_models = session.execute(query).scalars(SqlRegisteredModel).all()
        next_page_token = self._compute_next_token(
            max_results_for_query, len(sql_registered_models), offset, max_results
        )
        # Drop the extra look-ahead row before converting to entities.
        rm_entities = [rm.to_mlflow_entity() for rm in sql_registered_models[:max_results]]
        return PagedList(rm_entities, next_page_token)
|
360
|
+
|
361
|
+
@classmethod
def _get_search_registered_model_filter_query(cls, parsed_filters, dialect):
    """
    Build a ``select(SqlRegisteredModel)`` statement implementing the parsed search filters.

    Attribute filters (only ``name`` is supported) become WHERE clauses. Tag filters are
    grouped per tag key and matched through a grouped subquery on the tag table. Unless the
    filters explicitly ask for prompts, prompt entities are excluded.

    Args:
        parsed_filters: Output of ``SearchModelUtils.parse_search_filter``: dicts carrying
            ``type``, ``key``, ``comparator`` and ``value`` entries.
        dialect: SQL dialect name, used to pick dialect-aware comparison functions.

    Raises:
        MlflowException: On an unsupported attribute, comparator, or filter type.
    """
    supported_comparators = ("=", "!=", "LIKE", "ILIKE")
    attribute_filters = []
    # tag key -> [key-equality clause, value clause, value clause, ...]
    tag_filters = {}
    for parsed in parsed_filters:
        type_, key = parsed["type"], parsed["key"]
        comparator, value = parsed["comparator"], parsed["value"]
        if type_ == "attribute":
            if key != "name":
                raise MlflowException(
                    f"Invalid attribute name: {key}", error_code=INVALID_PARAMETER_VALUE
                )
            if comparator not in supported_comparators:
                raise MlflowException(
                    f"Invalid comparator for attribute: {comparator}",
                    error_code=INVALID_PARAMETER_VALUE,
                )
            compare = SearchUtils.get_sql_comparison_func(comparator, dialect)
            attribute_filters.append(compare(getattr(SqlRegisteredModel, key), value))
        elif type_ == "tag":
            if comparator not in supported_comparators:
                raise MlflowException.invalid_parameter_value(
                    f"Invalid comparator for tag: {comparator}"
                )
            if key not in tag_filters:
                key_equals = SearchUtils.get_sql_comparison_func("=", dialect)
                tag_filters[key] = [key_equals(SqlRegisteredModelTag.key, key)]
            compare = SearchUtils.get_sql_comparison_func(comparator, dialect)
            tag_filters[key].append(compare(SqlRegisteredModelTag.value, value))
        else:
            raise MlflowException(
                f"Invalid token type: {type_}", error_code=INVALID_PARAMETER_VALUE
            )

    rm_query = select(SqlRegisteredModel).filter(*attribute_filters)

    if not cls._is_querying_prompt(parsed_filters):
        # NOTE: this also removes the prompt-tag entry from ``tag_filters`` in place, so it
        # has to run before the tag subquery below is built.
        rm_query = cls._update_query_to_exclude_prompts(
            rm_query, tag_filters, dialect, SqlRegisteredModel, SqlRegisteredModelTag
        )

    if not tag_filters:
        return rm_query

    # Keep models whose count of tag rows satisfying some per-key condition equals the number
    # of requested keys (assumes at most one tag row per (model, key) — confirm in dbmodels).
    per_key_conditions = (sqlalchemy.and_(*clauses) for clauses in tag_filters.values())
    matching_names = (
        select(SqlRegisteredModelTag.name)
        .filter(sqlalchemy.or_(*per_key_conditions))
        .group_by(SqlRegisteredModelTag.name)
        .having(sqlalchemy.func.count(sqlalchemy.literal(1)) == len(tag_filters))
        .subquery()
    )
    return rm_query.join(matching_names, SqlRegisteredModel.name == matching_names.c.name)
|
425
|
+
|
426
|
+
@classmethod
def _get_search_model_versions_filter_clauses(cls, parsed_filters, dialect):
    """
    Build a ``select(SqlModelVersion)`` statement implementing the parsed search filters.

    Despite its name, this returns a complete select statement rather than a list of
    clauses. Attribute filters become WHERE clauses on the model version table. Tag filters
    are grouped per tag key and matched through a grouped subquery on the model version tag
    table. Unless the filters explicitly ask for prompts, prompt entities are excluded.

    Args:
        parsed_filters: Parsed filter dicts with ``type``, ``key``, ``comparator`` and
            ``value`` entries.
        dialect: SQL dialect name, used to pick dialect-aware comparison functions.

    Raises:
        MlflowException: On an unsupported attribute, comparator, or filter type.
    """
    # Search keys whose model-version column is named differently.
    column_aliases = {"source_path": "source", "version_number": "version"}
    attribute_filters = []
    # tag key -> [key-equality clause, value clause, value clause, ...]
    tag_filters = {}
    for parsed in parsed_filters:
        type_, key = parsed["type"], parsed["key"]
        comparator, value = parsed["comparator"], parsed["value"]
        if type_ == "attribute":
            if key not in SearchModelVersionUtils.VALID_SEARCH_ATTRIBUTE_KEYS:
                raise MlflowException(
                    f"Invalid attribute name: {key}", error_code=INVALID_PARAMETER_VALUE
                )
            if key in SearchModelVersionUtils.NUMERIC_ATTRIBUTES:
                valid_comparators = SearchModelVersionUtils.VALID_NUMERIC_ATTRIBUTE_COMPARATORS
                if comparator not in valid_comparators:
                    raise MlflowException(
                        f"Invalid comparator for attribute {key}: {comparator}",
                        error_code=INVALID_PARAMETER_VALUE,
                    )
            elif (
                comparator not in SearchModelVersionUtils.VALID_STRING_ATTRIBUTE_COMPARATORS
                or (comparator == "IN" and key != "run_id")
            ):
                raise MlflowException(
                    f"Invalid comparator for attribute: {comparator}",
                    error_code=INVALID_PARAMETER_VALUE,
                )
            column = getattr(SqlModelVersion, column_aliases.get(key, key))
            if comparator == "IN":
                # Note: Here the run_id values in databases contain only lower case letters,
                # so we already filter out comparison values containing upper case letters
                # in `SearchModelUtils._get_value`. This addresses MySQL IN clause case
                # in-sensitive issue.
                attribute_filters.append(column.in_(value))
            else:
                compare = SearchUtils.get_sql_comparison_func(comparator, dialect)
                attribute_filters.append(compare(column, value))
        elif type_ == "tag":
            if comparator not in ("=", "!=", "LIKE", "ILIKE"):
                raise MlflowException.invalid_parameter_value(
                    f"Invalid comparator for tag: {comparator}",
                )
            if key not in tag_filters:
                key_equals = SearchUtils.get_sql_comparison_func("=", dialect)
                tag_filters[key] = [key_equals(SqlModelVersionTag.key, key)]
            compare = SearchUtils.get_sql_comparison_func(comparator, dialect)
            tag_filters[key].append(compare(SqlModelVersionTag.value, value))
        else:
            raise MlflowException(
                f"Invalid token type: {type_}", error_code=INVALID_PARAMETER_VALUE
            )

    mv_query = select(SqlModelVersion).filter(*attribute_filters)

    if not cls._is_querying_prompt(parsed_filters):
        # NOTE: this also removes the prompt-tag entry from ``tag_filters`` in place, so it
        # has to run before the tag subquery below is built.
        mv_query = cls._update_query_to_exclude_prompts(
            mv_query, tag_filters, dialect, SqlModelVersion, SqlModelVersionTag
        )

    if not tag_filters:
        return mv_query

    # Keep versions whose count of tag rows satisfying some per-key condition equals the
    # number of requested keys (assumes one tag row per (name, version, key) — confirm).
    per_key_conditions = (sqlalchemy.and_(*clauses) for clauses in tag_filters.values())
    matching_versions = (
        select(SqlModelVersionTag.name, SqlModelVersionTag.version)
        .filter(sqlalchemy.or_(*per_key_conditions))
        .group_by(SqlModelVersionTag.name, SqlModelVersionTag.version)
        .having(sqlalchemy.func.count(sqlalchemy.literal(1)) == len(tag_filters))
        .subquery()
    )
    return mv_query.join(
        matching_versions,
        sqlalchemy.and_(
            SqlModelVersion.name == matching_versions.c.name,
            SqlModelVersion.version == matching_versions.c.version,
        ),
    )
|
520
|
+
|
521
|
+
@classmethod
def _update_query_to_exclude_prompts(
    cls,
    query: Any,
    tag_filters: dict[str, list[Any]],
    dialect: str,
    main_db_model: Union[type[SqlModelVersion], type[SqlRegisteredModel]],
    tag_db_model: Union[type[SqlModelVersionTag], type[SqlRegisteredModelTag]],
):
    """
    Update query to exclude all prompt rows and return only normal models or model versions.

    Prompts and normal models are distinguished by the ``mlflow.prompt.is_prompt`` tag.
    The search API should only return normal models by default. However, simply filtering
    rows on the tag, with either of::

        tags.`mlflow.prompt.is_prompt` != 'true'
        tags.`mlflow.prompt.is_prompt` = 'false'

    does not work, because normal models do not carry the prompt tag at all. Instead, a
    subquery collects the names of all entities tagged as prompts, and an anti-join (outer
    join, then keep rows with no match) removes them. Matching is done on ``name`` only.

    ``tag_filters`` is modified in place: any entry for the prompt tag key is removed,
    because the anti-join supersedes it.

    Args:
        query: The select statement to restrict.
        tag_filters: Mapping of tag key to the SQL clauses built for that key.
        dialect: SQL dialect name, used to pick a dialect-aware equality function.
        main_db_model: Mapped class queried by ``query`` (the class itself, not an
            instance).
        tag_db_model: Mapped tag class that belongs to ``main_db_model``.

    Returns:
        ``query`` extended with the prompt-excluding anti-join.
    """
    # If the tag filter contains the prompt tag, remove it
    tag_filters.pop(IS_PROMPT_TAG_KEY, None)

    # Filter to get all prompt rows
    equal = SearchUtils.get_sql_comparison_func("=", dialect)
    prompts_subquery = (
        select(tag_db_model.name)
        .filter(
            equal(tag_db_model.key, IS_PROMPT_TAG_KEY),
            equal(tag_db_model.value, "true"),
        )
        .group_by(tag_db_model.name)
        .subquery()
    )
    return query.join(
        prompts_subquery, main_db_model.name == prompts_subquery.c.name, isouter=True
    ).filter(prompts_subquery.c.name.is_(None))
|
560
|
+
|
561
|
+
@classmethod
def _is_querying_prompt(cls, parsed_filters: list[dict[str, Any]]) -> bool:
    """
    Return whether the parsed filters explicitly request prompts instead of normal models.

    Only the first tag filter on ``IS_PROMPT_TAG_KEY`` is consulted. It requests prompts when
    it is ``= 'true'`` or ``!= 'false'`` (value compared case-insensitively). Without such a
    filter, normal models are queried.
    """
    prompt_filter = next(
        (
            candidate
            for candidate in parsed_filters
            if candidate["type"] == "tag" and candidate["key"] == IS_PROMPT_TAG_KEY
        ),
        None,
    )
    if prompt_filter is None:
        # Query should return only normal models by default
        return False

    comparator = prompt_filter["comparator"]
    if comparator == "=":
        return prompt_filter["value"].lower() == "true"
    if comparator == "!=":
        return prompt_filter["value"].lower() == "false"
    return False
|
573
|
+
|
574
|
+
@classmethod
def _parse_search_registered_models_order_by(cls, order_by_list):
    """
    Translate ``order_by`` strings into SQLAlchemy ORDER BY clauses for registered models.

    Each entry is parsed with ``SearchUtils.parse_order_by_for_search_registered_models``.
    The ``name`` key sorts by model name, and any key in
    ``SearchUtils.VALID_TIMESTAMP_ORDER_BY_KEYS`` sorts by ``last_updated_time``. Unless
    ``name`` was requested explicitly, ``name ASC`` is appended last. Registered models are
    therefore naturally ordered by name ascending.

    Args:
        order_by_list: Optional list of order-by strings with ASC|DESC annotation, such as
            ``"name DESC"``.

    Returns:
        A list of SQLAlchemy ordering clauses.

    Raises:
        MlflowException: ``INVALID_PARAMETER_VALUE`` for an unsupported key or when the same
            field is requested more than once.
    """
    clauses = []
    observed_order_by_clauses = set()
    for order_by_clause in order_by_list or []:
        attribute_token, ascending = SearchUtils.parse_order_by_for_search_registered_models(
            order_by_clause
        )
        if attribute_token == SqlRegisteredModel.name.key:
            field = SqlRegisteredModel.name
        elif attribute_token in SearchUtils.VALID_TIMESTAMP_ORDER_BY_KEYS:
            field = SqlRegisteredModel.last_updated_time
        else:
            raise MlflowException(
                f"Invalid order by key '{attribute_token}' specified. Valid keys are "
                f"'{SearchUtils.RECOMMENDED_ORDER_BY_KEYS_REGISTERED_MODELS}'",
                error_code=INVALID_PARAMETER_VALUE,
            )
        # Several timestamp keys map to the same column, so deduplicate on the column key.
        if field.key in observed_order_by_clauses:
            raise MlflowException(
                f"`order_by` contains duplicate fields: {order_by_list}",
                error_code=INVALID_PARAMETER_VALUE,
            )
        observed_order_by_clauses.add(field.key)
        clauses.append(field.asc() if ascending else field.desc())

    # Name is the final tie-breaker so ordering is deterministic.
    if SqlRegisteredModel.name.key not in observed_order_by_clauses:
        clauses.append(SqlRegisteredModel.name.asc())
    return clauses
|
609
|
+
|
610
|
+
def get_registered_model(self, name):
    """
    Look up a registered model by name, with its tags eagerly loaded.

    Args:
        name: Registered model name.

    Returns:
        A single :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
    """
    with self.ManagedSessionMaker() as session:
        model_row = self._get_registered_model(session, name, eager=True)
        return model_row.to_mlflow_entity()
|
622
|
+
|
623
|
+
def get_latest_versions(self, name, stages=None):
    """
    Latest version models for each requested stage. If no ``stages`` argument is provided,
    returns the latest version for each stage.

    Args:
        name: Registered model name.
        stages: List of desired stages. If input list is None or empty, return latest
            versions for each stage. Stage names are canonicalized with
            ``get_canonical_stage`` before matching.

    Returns:
        List of :py:class:`mlflow.entities.model_registry.ModelVersion` objects, each with
        its ``aliases`` populated.
    """
    with self.ManagedSessionMaker() as session:
        sql_registered_model = self._get_registered_model(session, name)
        # Convert to RegisteredModel entity first and then extract latest_versions
        latest_versions = sql_registered_model.to_mlflow_entity().latest_versions
        requested_stages = stages if stages else ALL_STAGES
        expected_stages = {get_canonical_stage(stage) for stage in requested_stages}
        mvs = [mv for mv in latest_versions if mv.current_stage in expected_stages]

        if mvs:
            # Group alias names by version in one pass instead of rescanning every alias
            # for each model version.
            aliases_by_version = {}
            for alias in sql_registered_model.registered_model_aliases:
                aliases_by_version.setdefault(alias.version, []).append(alias.alias)
            for mv in mvs:
                mv.aliases = list(aliases_by_version.get(mv.version, ()))

        return mvs
|
653
|
+
|
654
|
+
@classmethod
def _get_registered_model_tag(cls, session, name, key):
    """
    Return the tag row with ``key`` on registered model ``name``, or ``None`` if absent.

    Raises:
        MlflowException: ``INVALID_STATE`` if more than one matching tag row exists.
    """
    matches = (
        session.query(SqlRegisteredModelTag)
        .filter(SqlRegisteredModelTag.name == name, SqlRegisteredModelTag.key == key)
        .all()
    )
    if len(matches) > 1:
        raise MlflowException(
            f"Expected only 1 registered model tag with name={name}, key={key}. "
            f"Found {len(matches)}.",
            INVALID_STATE,
        )
    return matches[0] if matches else None
|
670
|
+
|
671
|
+
def set_registered_model_tag(self, name, tag):
    """
    Set a tag for the registered model, storing the row via ``session.merge``.

    Args:
        name: Registered model name.
        tag: :py:class:`mlflow.entities.model_registry.RegisteredModelTag` instance to log.

    Returns:
        None

    Raises:
        MlflowException: If validation fails or the registered model does not exist.
    """
    _validate_model_name(name)
    _validate_registered_model_tag(tag.key, tag.value)
    with self.ManagedSessionMaker() as session:
        # Looked up only so that a missing registered model raises before writing the tag.
        self._get_registered_model(session, name)
        # merge() upserts by primary key (presumably (name, key) — confirm in dbmodels).
        tag_row = SqlRegisteredModelTag(name=name, key=tag.key, value=tag.value)
        session.merge(tag_row)
|
688
|
+
|
689
|
+
def delete_registered_model_tag(self, name, key):
    """
    Delete a tag associated with the registered model; a missing tag is ignored.

    Args:
        name: Registered model name.
        key: Registered model tag key.

    Returns:
        None

    Raises:
        MlflowException: If validation fails or the registered model does not exist.
    """
    _validate_model_name(name)
    _validate_tag_name(key)
    with self.ManagedSessionMaker() as session:
        # Raises if the registered model itself does not exist.
        self._get_registered_model(session, name)
        tag_row = self._get_registered_model_tag(session, name, key)
        if tag_row is None:
            return
        session.delete(tag_row)
|
708
|
+
|
709
|
+
# CRUD API for ModelVersion objects
|
710
|
+
|
711
|
+
def create_model_version(
    self,
    name,
    source,
    run_id=None,
    tags=None,
    run_link=None,
    description=None,
    local_model_path=None,
    model_id: Optional[str] = None,
):
    """
    Create a new model version from given source and run ID.

    Args:
        name: Registered model name.
        source: URI indicating the location of the model artifacts.
        run_id: Run ID from MLflow tracking server that generated the model.
        tags: A list of :py:class:`mlflow.entities.model_registry.ModelVersionTag`
            instances associated with this model version.
        run_link: Link to the run from an MLflow tracking server that generated this model.
        description: Description of the version.
        local_model_path: Unused.
        model_id: The ID of the model (from an Experiment) that is being promoted to a
            registered model version, if applicable.

    Returns:
        A single object of :py:class:`mlflow.entities.model_registry.ModelVersion`
        created in the backend.

    Raises:
        MlflowException: If ``source`` is a ``models:/`` URI whose artifact location
            cannot be resolved, or if the version row could not be inserted after
            ``CREATE_MODEL_VERSION_RETRIES`` attempts.
    """

    def next_version(sql_registered_model):
        # New versions are numbered max(existing) + 1, starting at 1 for the first version.
        existing_versions = [mv.version for mv in sql_registered_model.model_versions]
        return max(existing_versions, default=0) + 1

    _validate_model_name(name)
    for tag in tags or []:
        _validate_model_version_tag(tag.key, tag.value)

    # ``models:/`` sources are resolved to the underlying artifact location. Any other
    # source is stored as its own storage location.
    storage_location = source
    if urllib.parse.urlparse(source).scheme == "models":
        parsed_model_uri = _parse_model_uri(source)
        try:
            if parsed_model_uri.model_id is not None:
                # TODO: Propagate tracking URI to file sqlalchemy directly, rather than relying
                # on global URI (individual MlflowClient instances may have different tracking
                # URIs)
                model = MlflowClient().get_logged_model(parsed_model_uri.model_id)
                storage_location = model.artifact_location
                run_id = run_id or model.source_run_id
            else:
                storage_location = self.get_model_version_download_uri(
                    parsed_model_uri.name, parsed_model_uri.version
                )
        except Exception as e:
            raise MlflowException(
                f"Unable to fetch model from model URI source artifact location '{source}'. "
                f"Error: {e}"
            ) from e

    with self.ManagedSessionMaker() as session:
        creation_time = get_current_time_millis()
        for attempt in range(self.CREATE_MODEL_VERSION_RETRIES):
            try:
                sql_registered_model = self._get_registered_model(session, name)
                sql_registered_model.last_updated_time = creation_time
                version = next_version(sql_registered_model)
                model_version = SqlModelVersion(
                    name=name,
                    version=version,
                    creation_time=creation_time,
                    last_updated_time=creation_time,
                    source=source,
                    storage_location=storage_location,
                    run_id=run_id,
                    run_link=run_link,
                    description=description,
                )
                # Later tags with a duplicate key overwrite earlier ones.
                tags_dict = {tag.key: tag.value for tag in tags or []}
                model_version.model_version_tags = [
                    SqlModelVersionTag(key=key, value=value) for key, value in tags_dict.items()
                ]
                session.add_all([sql_registered_model, model_version])
                session.flush()
                return self._populate_model_version_aliases(
                    session, name, model_version.to_mlflow_entity()
                )
            except sqlalchemy.exc.IntegrityError:
                # A failed flush leaves the session's transaction inactive, and SQLAlchemy
                # rejects any further use of the session until it is explicitly rolled back.
                # Without this, the next attempt would fail immediately instead of retrying
                # with a freshly computed version number.
                session.rollback()
                more_retries = self.CREATE_MODEL_VERSION_RETRIES - attempt - 1
                _logger.info(
                    "Model Version creation error (name=%s) Retrying %s more time%s.",
                    name,
                    str(more_retries),
                    "" if more_retries == 1 else "s",
                )
    raise MlflowException(
        f"Model Version creation error (name={name}). Giving up after "
        f"{self.CREATE_MODEL_VERSION_RETRIES} attempts."
    )
|
813
|
+
|
814
|
+
@classmethod
def _populate_model_version_aliases(cls, session, name, version):
    """
    Fill in ``version.aliases`` with the registered model's aliases that target it.

    Args:
        session: Active SQLAlchemy session.
        name: Registered model name.
        version: Model version entity (as produced by ``to_mlflow_entity()``). Its
            ``aliases`` attribute is overwritten in place.

    Returns:
        The same ``version`` object that was passed in.
    """
    sql_registered_model = cls._get_registered_model(session, name)
    matching_alias_names = []
    for model_alias in sql_registered_model.registered_model_aliases:
        if model_alias.version == version.version:
            matching_alias_names.append(model_alias.alias)
    version.aliases = matching_alias_names
    return version
|
821
|
+
|
822
|
+
@classmethod
def _get_model_version_from_db(cls, session, name, version, conditions, query_options=None):
    """
    Query ``SqlModelVersion`` rows and return the single row that matches.

    Args:
        session: Active SQLAlchemy session.
        name: Registered model name (used only in error messages).
        version: Model version (used only in error messages).
        conditions: Filter expressions applied to the ``SqlModelVersion`` query.
        query_options: Optional loader options passed to ``Query.options``.
            Defaults to none.

    Returns:
        The matching ``SqlModelVersion`` row.

    Raises:
        MlflowException: ``RESOURCE_DOES_NOT_EXIST`` if no row matches, or
            ``INVALID_STATE`` if more than one row matches.
    """
    loader_options = [] if query_options is None else query_options
    matches = session.query(SqlModelVersion).options(*loader_options).filter(*conditions).all()
    match_count = len(matches)

    if match_count == 0:
        raise MlflowException(
            f"Model Version (name={name}, version={version}) not found",
            RESOURCE_DOES_NOT_EXIST,
        )
    if match_count > 1:
        raise MlflowException(
            f"Expected only 1 model version with (name={name}, version={version}). "
            f"Found {match_count}.",
            INVALID_STATE,
        )
    return matches[0]
|
840
|
+
|
841
|
+
@classmethod
def _get_sql_model_version(cls, session, name, version, eager=False):  # noqa: D417
    """
    Fetch the ``SqlModelVersion`` row for ``name``/``version``, excluding deleted versions.

    Args:
        eager: If ``True``, eagerly loads the model version's tags.
            If ``False``, these attributes are not eagerly loaded and
            will be loaded when their corresponding object properties
            are accessed from the resulting ``SqlModelVersion`` object.
    """
    _validate_model_name(name)
    _validate_model_version(version)
    loader_options = cls._get_eager_model_version_query_options() if eager else []
    filters = [
        SqlModelVersion.name == name,
        SqlModelVersion.version == version,
        # Versions marked with the internal deleted stage are hidden from normal lookups.
        SqlModelVersion.current_stage != STAGE_DELETED_INTERNAL,
    ]
    return cls._get_model_version_from_db(session, name, version, filters, loader_options)
|
859
|
+
|
860
|
+
def _get_sql_model_version_including_deleted(self, name, version):
    """
    Private method to retrieve model versions including those that are internally deleted.
    Used in tests to verify redaction behavior on deletion.

    Unlike ``_get_sql_model_version``, this applies no stage filter and performs no
    name/version validation.

    Args:
        name: Registered model name.
        version: Registered model version.

    Returns:
        A single :py:class:`mlflow.entities.model_registry.ModelVersion` object.
    """
    with self.ManagedSessionMaker() as session:
        row = self._get_model_version_from_db(
            session,
            name,
            version,
            [SqlModelVersion.name == name, SqlModelVersion.version == version],
        )
        entity = row.to_mlflow_entity()
        return self._populate_model_version_aliases(session, name, entity)
|
881
|
+
|
882
|
+
def update_model_version(self, name, version, description=None):
    """
    Update metadata associated with a model version in backend.

    The description is assigned as given, so passing ``None`` clears it. The
    version's ``last_updated_time`` is refreshed.

    Args:
        name: Registered model name.
        version: Registered model version.
        description: New model description.

    Returns:
        A single :py:class:`mlflow.entities.model_registry.ModelVersion` object.
    """
    with self.ManagedSessionMaker() as session:
        now_millis = get_current_time_millis()
        mv_row = self._get_sql_model_version(session, name=name, version=version)
        mv_row.description = description
        mv_row.last_updated_time = now_millis
        session.add(mv_row)
        entity = mv_row.to_mlflow_entity()
        return self._populate_model_version_aliases(session, name, entity)
|
904
|
+
|
905
|
+
def transition_model_version_stage(self, name, version, stage, archive_existing_versions):
    """
    Move a model version to a new stage.

    Args:
        name: Registered model name.
        version: Registered model version.
        stage: New desired stage for this model version.
        archive_existing_versions: If this flag is set to ``True``, all existing model
            versions in the stage will be automatically moved to the "archived" stage. Only
            valid when ``stage`` is ``"staging"`` or ``"production"`` otherwise an error will
            be raised.

    Returns:
        A single :py:class:`mlflow.entities.model_registry.ModelVersion` object.
    """
    target_stage = get_canonical_stage(stage)
    if archive_existing_versions and target_stage not in DEFAULT_STAGES_FOR_GET_LATEST_VERSIONS:
        msg_tpl = (
            "Model version transition cannot archive existing model versions "
            "because '{}' is not an Active stage. Valid stages are {}"
        )
        raise MlflowException(msg_tpl.format(stage, DEFAULT_STAGES_FOR_GET_LATEST_VERSIONS))

    with self.ManagedSessionMaker() as session:
        timestamp = get_current_time_millis()

        # Other versions of this model currently in the target stage are archived
        # only when the caller asked for it.
        displaced_versions = []
        if archive_existing_versions:
            displaced_versions = (
                session.query(SqlModelVersion)
                .filter(
                    SqlModelVersion.name == name,
                    SqlModelVersion.version != version,
                    SqlModelVersion.current_stage == target_stage,
                )
                .all()
            )
            for other in displaced_versions:
                other.current_stage = STAGE_ARCHIVED
                other.last_updated_time = timestamp

        mv_row = self._get_sql_model_version(session=session, name=name, version=version)
        mv_row.current_stage = target_stage
        mv_row.last_updated_time = timestamp
        parent_model = mv_row.registered_model
        parent_model.last_updated_time = timestamp
        session.add_all([*displaced_versions, mv_row, parent_model])
        return self._populate_model_version_aliases(session, name, mv_row.to_mlflow_entity())
|
956
|
+
|
957
|
+
def delete_model_version(self, name, version):
    """
    Delete model version in backend.

    This is a soft delete. The row is kept but moved to the internal deleted stage,
    and its source, run ID and run link are replaced with redaction placeholders.
    Its description, user ID and status message are cleared. Aliases that point at
    this version are removed, and the parent registered model's
    ``last_updated_time`` is bumped.

    Args:
        name: Registered model name.
        version: Registered model version.

    Returns:
        None
    """
    # currently delete model version still keeps the tags associated with the version
    with self.ManagedSessionMaker() as session:
        updated_time = get_current_time_millis()
        sql_model_version = self._get_sql_model_version(session, name, version)
        sql_registered_model = sql_model_version.registered_model
        sql_registered_model.last_updated_time = updated_time
        # Match aliases against the version number loaded from the database, not the
        # caller's argument. ``version`` may be supplied as a string such as "3", and
        # in Python an int never equals a str. Comparing the raw argument could
        # therefore leave aliases pointing at a deleted version.
        deleted_version_number = sql_model_version.version
        for alias in sql_registered_model.registered_model_aliases:
            if alias.version == deleted_version_number:
                session.delete(alias)
        sql_model_version.current_stage = STAGE_DELETED_INTERNAL
        sql_model_version.last_updated_time = updated_time
        sql_model_version.description = None
        sql_model_version.user_id = None
        sql_model_version.source = "REDACTED-SOURCE-PATH"
        sql_model_version.run_id = "REDACTED-RUN-ID"
        sql_model_version.run_link = "REDACTED-RUN-LINK"
        sql_model_version.status_message = None
        session.add_all([sql_registered_model, sql_model_version])
|
987
|
+
|
988
|
+
def get_model_version(self, name, version):
    """
    Look up a single (non-deleted) model version by name and version number.

    Args:
        name: Registered model name.
        version: Registered model version.

    Returns:
        A single :py:class:`mlflow.entities.model_registry.ModelVersion` object.
    """
    with self.ManagedSessionMaker() as session:
        # Tags are loaded eagerly here, unlike the default lookup.
        mv_row = self._get_sql_model_version(session, name, version, eager=True)
        entity = mv_row.to_mlflow_entity()
        return self._populate_model_version_aliases(session, name, entity)
|
1004
|
+
|
1005
|
+
def get_model_version_download_uri(self, name, version):
    """
    Get the download location in Model Registry for this model version.

    NOTE: For first version of Model Registry, since the models are not copied over to another
    location, download URI points to input source path.

    Args:
        name: Registered model name.
        version: Registered model version.

    Returns:
        The stored ``storage_location`` if it is set (truthy), otherwise the
        original ``source`` URI.
    """
    with self.ManagedSessionMaker() as session:
        mv_row = self._get_sql_model_version(session, name, version)
        if mv_row.storage_location:
            return mv_row.storage_location
        return mv_row.source
|
1021
|
+
|
1022
|
+
def search_model_versions(
    self,
    filter_string=None,
    max_results=SEARCH_MODEL_VERSION_MAX_RESULTS_DEFAULT,
    order_by=None,
    page_token=None,
):
    """
    Search for model versions in backend that satisfy the filter criteria.

    Deleted model versions are never returned.

    Args:
        filter_string: A filter string expression. Currently supports a single filter
            condition either name of model like ``name = 'model_name'`` or
            ``run_id = '...'``.
        max_results: Maximum number of model versions desired.
        order_by: List of column names with ASC|DESC annotation, to be used for ordering
            matching search results. Defaults to last-updated descending, then name
            ascending, then version descending.
        page_token: Token specifying the next page of results. It should be obtained from
            a ``search_model_versions`` call.

    Returns:
        A PagedList of :py:class:`mlflow.entities.model_registry.ModelVersion`
        objects that satisfy the search expressions. The pagination token for the next
        page can be obtained via the ``token`` attribute of the object.
    """
    max_results_ok = isinstance(max_results, int) and max_results >= 1
    if not max_results_ok:
        raise MlflowException(
            "Invalid value for max_results. It must be a positive integer,"
            f" but got {max_results}",
            INVALID_PARAMETER_VALUE,
        )

    if max_results > SEARCH_MODEL_VERSION_MAX_RESULTS_THRESHOLD:
        raise MlflowException(
            "Invalid value for request parameter max_results. It must be at most "
            f"{SEARCH_MODEL_VERSION_MAX_RESULTS_THRESHOLD}, but got value {max_results}",
            INVALID_PARAMETER_VALUE,
        )

    filters = SearchModelVersionUtils.parse_search_filter(filter_string)
    base_query = self._get_search_model_versions_filter_clauses(
        filters, self.engine.dialect.name
    )

    default_order_by = ["last_updated_timestamp DESC", "name ASC", "version_number DESC"]
    order_clauses = self._parse_search_model_versions_order_by(order_by or default_order_by)
    offset = SearchUtils.parse_start_offset_from_page_token(page_token)
    # One extra row is requested so the presence of a further page can be detected
    # without issuing a second query.
    fetch_limit = max_results + 1

    with self.ManagedSessionMaker() as session:
        statement = (
            base_query.options(*self._get_eager_model_version_query_options())
            .filter(SqlModelVersion.current_stage != STAGE_DELETED_INTERNAL)
            .order_by(*order_clauses)
            .limit(fetch_limit)
        )
        if page_token:
            statement = statement.offset(offset)
        rows = session.execute(statement).scalars(SqlModelVersion).all()
        next_page_token = self._compute_next_token(fetch_limit, len(rows), offset, max_results)
        entities = [row.to_mlflow_entity() for row in rows]
        return PagedList(entities[:max_results], next_page_token)
|
1091
|
+
|
1092
|
+
@classmethod
def _parse_search_model_versions_order_by(cls, order_by_list):
    """
    Convert ``order_by`` strings into SQLAlchemy ordering clauses for model versions.

    Each entry is parsed with
    ``SearchModelVersionUtils.parse_order_by_for_search_model_versions`` and mapped to
    the matching ``SqlModelVersion`` column. ``version_number``,
    ``creation_timestamp`` and ``last_updated_timestamp`` map to ``version``,
    ``creation_time`` and ``last_updated_time``; any other valid key is looked up by
    attribute name. If the caller did not order by them, ``name ASC`` and then
    ``version DESC`` are appended as tiebreakers.

    Args:
        order_by_list: Order-by strings such as ``"name ASC"``. May be ``None`` or empty.

    Returns:
        A list of SQLAlchemy ordering clauses.

    Raises:
        MlflowException: With ``INVALID_PARAMETER_VALUE`` if a key is not orderable or
            the same field appears more than once.
    """
    # Public order-by keys whose SqlModelVersion column has a different name.
    renamed_columns = {
        "version_number": SqlModelVersion.version,
        "creation_timestamp": SqlModelVersion.creation_time,
        "last_updated_timestamp": SqlModelVersion.last_updated_time,
    }

    clauses = []
    observed_order_by_clauses = set()
    for order_by_clause in order_by_list or []:
        (
            _,
            key,
            ascending,
        ) = SearchModelVersionUtils.parse_order_by_for_search_model_versions(order_by_clause)
        if key not in SearchModelVersionUtils.VALID_ORDER_BY_ATTRIBUTE_KEYS:
            raise MlflowException(
                f"Invalid order by key '{key}' specified. "
                "Valid keys are "
                f"{SearchModelVersionUtils.VALID_ORDER_BY_ATTRIBUTE_KEYS}",
                error_code=INVALID_PARAMETER_VALUE,
            )
        # Explicit membership test: SQLAlchemy column attributes must not be used in a
        # boolean context (e.g. ``dict.get(key) or ...``).
        if key in renamed_columns:
            field = renamed_columns[key]
        else:
            field = getattr(SqlModelVersion, key)
        if field.key in observed_order_by_clauses:
            # Duplicate fields are a caller input error, reported with the same code
            # as the invalid-key error above.
            raise MlflowException(
                f"`order_by` contains duplicate fields: {order_by_list}",
                error_code=INVALID_PARAMETER_VALUE,
            )
        observed_order_by_clauses.add(field.key)
        clauses.append(field.asc() if ascending else field.desc())

    # Tiebreakers to make the ordering deterministic.
    if SqlModelVersion.name.key not in observed_order_by_clauses:
        clauses.append(SqlModelVersion.name.asc())
    if SqlModelVersion.version.key not in observed_order_by_clauses:
        clauses.append(SqlModelVersion.version.desc())
    return clauses
|
1138
|
+
|
1139
|
+
@classmethod
def _get_model_version_tag(cls, session, name, version, key):
    """
    Look up one tag row for a model version.

    Args:
        session: Active SQLAlchemy session.
        name: Registered model name.
        version: Model version.
        key: Tag key.

    Returns:
        The matching ``SqlModelVersionTag`` row, or ``None`` if the tag is not set.

    Raises:
        MlflowException: ``INVALID_STATE`` if more than one row matches.
    """
    matching_tags = (
        session.query(SqlModelVersionTag)
        .filter(
            SqlModelVersionTag.name == name,
            SqlModelVersionTag.version == version,
            SqlModelVersionTag.key == key,
        )
        .all()
    )
    if not matching_tags:
        return None
    if len(matching_tags) == 1:
        return matching_tags[0]
    raise MlflowException(
        f"Expected only 1 model version tag with name={name}, version={version}, "
        f"key={key}. Found {len(matching_tags)}.",
        INVALID_STATE,
    )
|
1159
|
+
|
1160
|
+
def set_model_version_tag(self, name, version, tag):
    """
    Create or update a tag on a model version.

    Args:
        name: Registered model name.
        version: Registered model version.
        tag: :py:class:`mlflow.entities.model_registry.ModelVersionTag` instance to log.

    Returns:
        None
    """
    _validate_model_name(name)
    _validate_model_version(version)
    _validate_model_version_tag(tag.key, tag.value)
    with self.ManagedSessionMaker() as session:
        # Looked up only to verify that the (non-deleted) model version exists.
        self._get_sql_model_version(session, name, version)
        # merge() reconciles by primary key (presumably name/version/key — confirm
        # against the SqlModelVersionTag model), so an existing tag's value is replaced.
        tag_row = SqlModelVersionTag(name=name, version=version, key=tag.key, value=tag.value)
        session.merge(tag_row)
|
1181
|
+
|
1182
|
+
def delete_model_version_tag(self, name, version, key):
    """
    Remove a tag from a model version.

    Removing a key that is not set does nothing; the model version itself must
    still exist.

    Args:
        name: Registered model name.
        version: Registered model version.
        key: Tag key.

    Returns:
        None
    """
    _validate_model_name(name)
    _validate_model_version(version)
    _validate_tag_name(key)
    with self.ManagedSessionMaker() as session:
        # Looked up only to verify that the (non-deleted) model version exists.
        self._get_sql_model_version(session, name, version)
        tag_row = self._get_model_version_tag(session, name, version, key)
        if tag_row is None:
            return
        session.delete(tag_row)
|
1203
|
+
|
1204
|
+
@classmethod
def _get_registered_model_alias(cls, session, name, alias):
    """
    Return the alias row for ``(name, alias)``.

    Args:
        session: Active SQLAlchemy session.
        name: Registered model name.
        alias: Alias name.

    Returns:
        The first matching ``SqlRegisteredModelAlias`` row, or ``None`` if absent.
    """
    alias_query = session.query(SqlRegisteredModelAlias).filter(
        SqlRegisteredModelAlias.name == name,
        SqlRegisteredModelAlias.alias == alias,
    )
    return alias_query.first()
|
1214
|
+
|
1215
|
+
def set_registered_model_alias(self, name, alias, version):
    """
    Point a registered model alias at a model version.

    Args:
        name: Registered model name.
        alias: Name of the alias.
        version: Registered model version number.

    Returns:
        None
    """
    _validate_model_name(name)
    _validate_model_alias_name(alias)
    _validate_model_version(version)
    with self.ManagedSessionMaker() as session:
        # Looked up only to verify that the (non-deleted) target version exists.
        self._get_sql_model_version(session, name, version)
        # merge() reconciles by primary key, so an existing alias row is presumably
        # re-pointed rather than duplicated — confirm against the alias model's key.
        alias_row = SqlRegisteredModelAlias(name=name, alias=alias, version=version)
        session.merge(alias_row)
|
1234
|
+
|
1235
|
+
def delete_registered_model_alias(self, name, alias):
    """
    Remove an alias from a registered model.

    Removing an alias that does not exist does nothing; the registered model itself
    must still exist.

    Args:
        name: Registered model name.
        alias: Name of the alias.

    Returns:
        None
    """
    _validate_model_name(name)
    _validate_model_alias_name(alias)
    with self.ManagedSessionMaker() as session:
        # Looked up only to verify that the registered model exists.
        self._get_registered_model(session, name)
        alias_row = self._get_registered_model_alias(session, name, alias)
        if alias_row is None:
            return
        session.delete(alias_row)
|
1254
|
+
|
1255
|
+
def get_model_version_by_alias(self, name, alias):
    """
    Resolve a registered model alias to the model version it points at.

    Args:
        name: Registered model name.
        alias: Name of the alias.

    Returns:
        A single :py:class:`mlflow.entities.model_registry.ModelVersion` object.

    Raises:
        MlflowException: ``INVALID_PARAMETER_VALUE`` if the alias does not exist.
    """
    _validate_model_name(name)
    _validate_model_alias_name(alias)
    with self.ManagedSessionMaker() as session:
        alias_row = self._get_registered_model_alias(session, name, alias)
        if alias_row is None:
            raise MlflowException(
                f"Registered model alias {alias} not found.", INVALID_PARAMETER_VALUE
            )
        mv_row = self._get_sql_model_version(session, alias_row.name, alias_row.version)
        entity = mv_row.to_mlflow_entity()
        return self._populate_model_version_aliases(session, name, entity)
|
1281
|
+
|
1282
|
+
def _await_model_version_creation(self, mv, await_creation_for):
|
1283
|
+
"""
|
1284
|
+
Does not wait for the model version to become READY as a successful creation will
|
1285
|
+
immediately place the model version in a READY state.
|
1286
|
+
"""
|