genesis-flow 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genesis_flow-1.0.0.dist-info/METADATA +822 -0
- genesis_flow-1.0.0.dist-info/RECORD +645 -0
- genesis_flow-1.0.0.dist-info/WHEEL +5 -0
- genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
- genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
- genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
- mlflow/__init__.py +367 -0
- mlflow/__main__.py +3 -0
- mlflow/ag2/__init__.py +56 -0
- mlflow/ag2/ag2_logger.py +294 -0
- mlflow/anthropic/__init__.py +40 -0
- mlflow/anthropic/autolog.py +129 -0
- mlflow/anthropic/chat.py +144 -0
- mlflow/artifacts/__init__.py +268 -0
- mlflow/autogen/__init__.py +144 -0
- mlflow/autogen/chat.py +142 -0
- mlflow/azure/__init__.py +26 -0
- mlflow/azure/auth_handler.py +257 -0
- mlflow/azure/client.py +319 -0
- mlflow/azure/config.py +120 -0
- mlflow/azure/connection_factory.py +340 -0
- mlflow/azure/exceptions.py +27 -0
- mlflow/azure/stores.py +327 -0
- mlflow/azure/utils.py +183 -0
- mlflow/bedrock/__init__.py +45 -0
- mlflow/bedrock/_autolog.py +202 -0
- mlflow/bedrock/chat.py +122 -0
- mlflow/bedrock/stream.py +160 -0
- mlflow/bedrock/utils.py +43 -0
- mlflow/cli.py +707 -0
- mlflow/client.py +12 -0
- mlflow/config/__init__.py +56 -0
- mlflow/crewai/__init__.py +79 -0
- mlflow/crewai/autolog.py +253 -0
- mlflow/crewai/chat.py +29 -0
- mlflow/data/__init__.py +75 -0
- mlflow/data/artifact_dataset_sources.py +170 -0
- mlflow/data/code_dataset_source.py +40 -0
- mlflow/data/dataset.py +123 -0
- mlflow/data/dataset_registry.py +168 -0
- mlflow/data/dataset_source.py +110 -0
- mlflow/data/dataset_source_registry.py +219 -0
- mlflow/data/delta_dataset_source.py +167 -0
- mlflow/data/digest_utils.py +108 -0
- mlflow/data/evaluation_dataset.py +562 -0
- mlflow/data/filesystem_dataset_source.py +81 -0
- mlflow/data/http_dataset_source.py +145 -0
- mlflow/data/huggingface_dataset.py +258 -0
- mlflow/data/huggingface_dataset_source.py +118 -0
- mlflow/data/meta_dataset.py +104 -0
- mlflow/data/numpy_dataset.py +223 -0
- mlflow/data/pandas_dataset.py +231 -0
- mlflow/data/polars_dataset.py +352 -0
- mlflow/data/pyfunc_dataset_mixin.py +31 -0
- mlflow/data/schema.py +76 -0
- mlflow/data/sources.py +1 -0
- mlflow/data/spark_dataset.py +406 -0
- mlflow/data/spark_dataset_source.py +74 -0
- mlflow/data/spark_delta_utils.py +118 -0
- mlflow/data/tensorflow_dataset.py +350 -0
- mlflow/data/uc_volume_dataset_source.py +81 -0
- mlflow/db.py +27 -0
- mlflow/dspy/__init__.py +17 -0
- mlflow/dspy/autolog.py +197 -0
- mlflow/dspy/callback.py +398 -0
- mlflow/dspy/constant.py +1 -0
- mlflow/dspy/load.py +93 -0
- mlflow/dspy/save.py +393 -0
- mlflow/dspy/util.py +109 -0
- mlflow/dspy/wrapper.py +226 -0
- mlflow/entities/__init__.py +104 -0
- mlflow/entities/_mlflow_object.py +52 -0
- mlflow/entities/assessment.py +545 -0
- mlflow/entities/assessment_error.py +80 -0
- mlflow/entities/assessment_source.py +141 -0
- mlflow/entities/dataset.py +92 -0
- mlflow/entities/dataset_input.py +51 -0
- mlflow/entities/dataset_summary.py +62 -0
- mlflow/entities/document.py +48 -0
- mlflow/entities/experiment.py +109 -0
- mlflow/entities/experiment_tag.py +35 -0
- mlflow/entities/file_info.py +45 -0
- mlflow/entities/input_tag.py +35 -0
- mlflow/entities/lifecycle_stage.py +35 -0
- mlflow/entities/logged_model.py +228 -0
- mlflow/entities/logged_model_input.py +26 -0
- mlflow/entities/logged_model_output.py +32 -0
- mlflow/entities/logged_model_parameter.py +46 -0
- mlflow/entities/logged_model_status.py +74 -0
- mlflow/entities/logged_model_tag.py +33 -0
- mlflow/entities/metric.py +200 -0
- mlflow/entities/model_registry/__init__.py +29 -0
- mlflow/entities/model_registry/_model_registry_entity.py +13 -0
- mlflow/entities/model_registry/model_version.py +243 -0
- mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
- mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
- mlflow/entities/model_registry/model_version_search.py +25 -0
- mlflow/entities/model_registry/model_version_stages.py +25 -0
- mlflow/entities/model_registry/model_version_status.py +35 -0
- mlflow/entities/model_registry/model_version_tag.py +35 -0
- mlflow/entities/model_registry/prompt.py +73 -0
- mlflow/entities/model_registry/prompt_version.py +244 -0
- mlflow/entities/model_registry/registered_model.py +175 -0
- mlflow/entities/model_registry/registered_model_alias.py +35 -0
- mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
- mlflow/entities/model_registry/registered_model_search.py +25 -0
- mlflow/entities/model_registry/registered_model_tag.py +35 -0
- mlflow/entities/multipart_upload.py +74 -0
- mlflow/entities/param.py +49 -0
- mlflow/entities/run.py +97 -0
- mlflow/entities/run_data.py +84 -0
- mlflow/entities/run_info.py +188 -0
- mlflow/entities/run_inputs.py +59 -0
- mlflow/entities/run_outputs.py +43 -0
- mlflow/entities/run_status.py +41 -0
- mlflow/entities/run_tag.py +36 -0
- mlflow/entities/source_type.py +31 -0
- mlflow/entities/span.py +774 -0
- mlflow/entities/span_event.py +96 -0
- mlflow/entities/span_status.py +102 -0
- mlflow/entities/trace.py +317 -0
- mlflow/entities/trace_data.py +71 -0
- mlflow/entities/trace_info.py +220 -0
- mlflow/entities/trace_info_v2.py +162 -0
- mlflow/entities/trace_location.py +173 -0
- mlflow/entities/trace_state.py +39 -0
- mlflow/entities/trace_status.py +68 -0
- mlflow/entities/view_type.py +51 -0
- mlflow/environment_variables.py +866 -0
- mlflow/evaluation/__init__.py +16 -0
- mlflow/evaluation/assessment.py +369 -0
- mlflow/evaluation/evaluation.py +411 -0
- mlflow/evaluation/evaluation_tag.py +61 -0
- mlflow/evaluation/fluent.py +48 -0
- mlflow/evaluation/utils.py +201 -0
- mlflow/exceptions.py +213 -0
- mlflow/experiments.py +140 -0
- mlflow/gemini/__init__.py +81 -0
- mlflow/gemini/autolog.py +186 -0
- mlflow/gemini/chat.py +261 -0
- mlflow/genai/__init__.py +71 -0
- mlflow/genai/datasets/__init__.py +67 -0
- mlflow/genai/datasets/evaluation_dataset.py +131 -0
- mlflow/genai/evaluation/__init__.py +3 -0
- mlflow/genai/evaluation/base.py +411 -0
- mlflow/genai/evaluation/constant.py +23 -0
- mlflow/genai/evaluation/utils.py +244 -0
- mlflow/genai/judges/__init__.py +21 -0
- mlflow/genai/judges/databricks.py +404 -0
- mlflow/genai/label_schemas/__init__.py +153 -0
- mlflow/genai/label_schemas/label_schemas.py +209 -0
- mlflow/genai/labeling/__init__.py +159 -0
- mlflow/genai/labeling/labeling.py +250 -0
- mlflow/genai/optimize/__init__.py +13 -0
- mlflow/genai/optimize/base.py +198 -0
- mlflow/genai/optimize/optimizers/__init__.py +4 -0
- mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
- mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
- mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
- mlflow/genai/optimize/types.py +75 -0
- mlflow/genai/optimize/util.py +30 -0
- mlflow/genai/prompts/__init__.py +206 -0
- mlflow/genai/scheduled_scorers.py +431 -0
- mlflow/genai/scorers/__init__.py +26 -0
- mlflow/genai/scorers/base.py +492 -0
- mlflow/genai/scorers/builtin_scorers.py +765 -0
- mlflow/genai/scorers/scorer_utils.py +138 -0
- mlflow/genai/scorers/validation.py +165 -0
- mlflow/genai/utils/data_validation.py +146 -0
- mlflow/genai/utils/enum_utils.py +23 -0
- mlflow/genai/utils/trace_utils.py +211 -0
- mlflow/groq/__init__.py +42 -0
- mlflow/groq/_groq_autolog.py +74 -0
- mlflow/johnsnowlabs/__init__.py +888 -0
- mlflow/langchain/__init__.py +24 -0
- mlflow/langchain/api_request_parallel_processor.py +330 -0
- mlflow/langchain/autolog.py +147 -0
- mlflow/langchain/chat_agent_langgraph.py +340 -0
- mlflow/langchain/constant.py +1 -0
- mlflow/langchain/constants.py +1 -0
- mlflow/langchain/databricks_dependencies.py +444 -0
- mlflow/langchain/langchain_tracer.py +597 -0
- mlflow/langchain/model.py +919 -0
- mlflow/langchain/output_parsers.py +142 -0
- mlflow/langchain/retriever_chain.py +153 -0
- mlflow/langchain/runnables.py +527 -0
- mlflow/langchain/utils/chat.py +402 -0
- mlflow/langchain/utils/logging.py +671 -0
- mlflow/langchain/utils/serialization.py +36 -0
- mlflow/legacy_databricks_cli/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/provider.py +482 -0
- mlflow/litellm/__init__.py +175 -0
- mlflow/llama_index/__init__.py +22 -0
- mlflow/llama_index/autolog.py +55 -0
- mlflow/llama_index/chat.py +43 -0
- mlflow/llama_index/constant.py +1 -0
- mlflow/llama_index/model.py +577 -0
- mlflow/llama_index/pyfunc_wrapper.py +332 -0
- mlflow/llama_index/serialize_objects.py +188 -0
- mlflow/llama_index/tracer.py +561 -0
- mlflow/metrics/__init__.py +479 -0
- mlflow/metrics/base.py +39 -0
- mlflow/metrics/genai/__init__.py +25 -0
- mlflow/metrics/genai/base.py +101 -0
- mlflow/metrics/genai/genai_metric.py +771 -0
- mlflow/metrics/genai/metric_definitions.py +450 -0
- mlflow/metrics/genai/model_utils.py +371 -0
- mlflow/metrics/genai/prompt_template.py +68 -0
- mlflow/metrics/genai/prompts/__init__.py +0 -0
- mlflow/metrics/genai/prompts/v1.py +422 -0
- mlflow/metrics/genai/utils.py +6 -0
- mlflow/metrics/metric_definitions.py +619 -0
- mlflow/mismatch.py +34 -0
- mlflow/mistral/__init__.py +34 -0
- mlflow/mistral/autolog.py +71 -0
- mlflow/mistral/chat.py +135 -0
- mlflow/ml_package_versions.py +452 -0
- mlflow/models/__init__.py +97 -0
- mlflow/models/auth_policy.py +83 -0
- mlflow/models/cli.py +354 -0
- mlflow/models/container/__init__.py +294 -0
- mlflow/models/container/scoring_server/__init__.py +0 -0
- mlflow/models/container/scoring_server/nginx.conf +39 -0
- mlflow/models/dependencies_schemas.py +287 -0
- mlflow/models/display_utils.py +158 -0
- mlflow/models/docker_utils.py +211 -0
- mlflow/models/evaluation/__init__.py +23 -0
- mlflow/models/evaluation/_shap_patch.py +64 -0
- mlflow/models/evaluation/artifacts.py +194 -0
- mlflow/models/evaluation/base.py +1811 -0
- mlflow/models/evaluation/calibration_curve.py +109 -0
- mlflow/models/evaluation/default_evaluator.py +996 -0
- mlflow/models/evaluation/deprecated.py +23 -0
- mlflow/models/evaluation/evaluator_registry.py +80 -0
- mlflow/models/evaluation/evaluators/classifier.py +704 -0
- mlflow/models/evaluation/evaluators/default.py +233 -0
- mlflow/models/evaluation/evaluators/regressor.py +96 -0
- mlflow/models/evaluation/evaluators/shap.py +296 -0
- mlflow/models/evaluation/lift_curve.py +178 -0
- mlflow/models/evaluation/utils/metric.py +123 -0
- mlflow/models/evaluation/utils/trace.py +179 -0
- mlflow/models/evaluation/validation.py +434 -0
- mlflow/models/flavor_backend.py +93 -0
- mlflow/models/flavor_backend_registry.py +53 -0
- mlflow/models/model.py +1639 -0
- mlflow/models/model_config.py +150 -0
- mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
- mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
- mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
- mlflow/models/python_api.py +369 -0
- mlflow/models/rag_signatures.py +128 -0
- mlflow/models/resources.py +321 -0
- mlflow/models/signature.py +662 -0
- mlflow/models/utils.py +2054 -0
- mlflow/models/wheeled_model.py +280 -0
- mlflow/openai/__init__.py +57 -0
- mlflow/openai/_agent_tracer.py +364 -0
- mlflow/openai/api_request_parallel_processor.py +131 -0
- mlflow/openai/autolog.py +509 -0
- mlflow/openai/constant.py +1 -0
- mlflow/openai/model.py +824 -0
- mlflow/openai/utils/chat_schema.py +367 -0
- mlflow/optuna/__init__.py +3 -0
- mlflow/optuna/storage.py +646 -0
- mlflow/plugins/__init__.py +72 -0
- mlflow/plugins/base.py +358 -0
- mlflow/plugins/builtin/__init__.py +24 -0
- mlflow/plugins/builtin/pytorch_plugin.py +150 -0
- mlflow/plugins/builtin/sklearn_plugin.py +158 -0
- mlflow/plugins/builtin/transformers_plugin.py +187 -0
- mlflow/plugins/cli.py +321 -0
- mlflow/plugins/discovery.py +340 -0
- mlflow/plugins/manager.py +465 -0
- mlflow/plugins/registry.py +316 -0
- mlflow/plugins/templates/framework_plugin_template.py +329 -0
- mlflow/prompt/constants.py +20 -0
- mlflow/prompt/promptlab_model.py +197 -0
- mlflow/prompt/registry_utils.py +248 -0
- mlflow/promptflow/__init__.py +495 -0
- mlflow/protos/__init__.py +0 -0
- mlflow/protos/assessments_pb2.py +174 -0
- mlflow/protos/databricks_artifacts_pb2.py +489 -0
- mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
- mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
- mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
- mlflow/protos/databricks_pb2.py +267 -0
- mlflow/protos/databricks_trace_server_pb2.py +374 -0
- mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
- mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
- mlflow/protos/facet_feature_statistics_pb2.py +296 -0
- mlflow/protos/internal_pb2.py +77 -0
- mlflow/protos/mlflow_artifacts_pb2.py +336 -0
- mlflow/protos/model_registry_pb2.py +1073 -0
- mlflow/protos/scalapb/__init__.py +0 -0
- mlflow/protos/scalapb/scalapb_pb2.py +104 -0
- mlflow/protos/service_pb2.py +2600 -0
- mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
- mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
- mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
- mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
- mlflow/py.typed +0 -0
- mlflow/pydantic_ai/__init__.py +57 -0
- mlflow/pydantic_ai/autolog.py +173 -0
- mlflow/pyfunc/__init__.py +3844 -0
- mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
- mlflow/pyfunc/backend.py +523 -0
- mlflow/pyfunc/context.py +78 -0
- mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
- mlflow/pyfunc/loaders/__init__.py +7 -0
- mlflow/pyfunc/loaders/chat_agent.py +117 -0
- mlflow/pyfunc/loaders/chat_model.py +125 -0
- mlflow/pyfunc/loaders/code_model.py +31 -0
- mlflow/pyfunc/loaders/responses_agent.py +112 -0
- mlflow/pyfunc/mlserver.py +46 -0
- mlflow/pyfunc/model.py +1473 -0
- mlflow/pyfunc/scoring_server/__init__.py +604 -0
- mlflow/pyfunc/scoring_server/app.py +7 -0
- mlflow/pyfunc/scoring_server/client.py +146 -0
- mlflow/pyfunc/spark_model_cache.py +48 -0
- mlflow/pyfunc/stdin_server.py +44 -0
- mlflow/pyfunc/utils/__init__.py +3 -0
- mlflow/pyfunc/utils/data_validation.py +224 -0
- mlflow/pyfunc/utils/environment.py +22 -0
- mlflow/pyfunc/utils/input_converter.py +47 -0
- mlflow/pyfunc/utils/serving_data_parser.py +11 -0
- mlflow/pytorch/__init__.py +1171 -0
- mlflow/pytorch/_lightning_autolog.py +580 -0
- mlflow/pytorch/_pytorch_autolog.py +50 -0
- mlflow/pytorch/pickle_module.py +35 -0
- mlflow/rfunc/__init__.py +42 -0
- mlflow/rfunc/backend.py +134 -0
- mlflow/runs.py +89 -0
- mlflow/server/__init__.py +302 -0
- mlflow/server/auth/__init__.py +1224 -0
- mlflow/server/auth/__main__.py +4 -0
- mlflow/server/auth/basic_auth.ini +6 -0
- mlflow/server/auth/cli.py +11 -0
- mlflow/server/auth/client.py +537 -0
- mlflow/server/auth/config.py +34 -0
- mlflow/server/auth/db/__init__.py +0 -0
- mlflow/server/auth/db/cli.py +18 -0
- mlflow/server/auth/db/migrations/__init__.py +0 -0
- mlflow/server/auth/db/migrations/alembic.ini +110 -0
- mlflow/server/auth/db/migrations/env.py +76 -0
- mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
- mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
- mlflow/server/auth/db/models.py +67 -0
- mlflow/server/auth/db/utils.py +37 -0
- mlflow/server/auth/entities.py +165 -0
- mlflow/server/auth/logo.py +14 -0
- mlflow/server/auth/permissions.py +65 -0
- mlflow/server/auth/routes.py +18 -0
- mlflow/server/auth/sqlalchemy_store.py +263 -0
- mlflow/server/graphql/__init__.py +0 -0
- mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
- mlflow/server/graphql/graphql_custom_scalars.py +24 -0
- mlflow/server/graphql/graphql_errors.py +15 -0
- mlflow/server/graphql/graphql_no_batching.py +89 -0
- mlflow/server/graphql/graphql_schema_extensions.py +74 -0
- mlflow/server/handlers.py +3217 -0
- mlflow/server/prometheus_exporter.py +17 -0
- mlflow/server/validation.py +30 -0
- mlflow/shap/__init__.py +691 -0
- mlflow/sklearn/__init__.py +1994 -0
- mlflow/sklearn/utils.py +1041 -0
- mlflow/smolagents/__init__.py +66 -0
- mlflow/smolagents/autolog.py +139 -0
- mlflow/smolagents/chat.py +29 -0
- mlflow/store/__init__.py +10 -0
- mlflow/store/_unity_catalog/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/constants.py +2 -0
- mlflow/store/_unity_catalog/registry/__init__.py +6 -0
- mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
- mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
- mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
- mlflow/store/_unity_catalog/registry/utils.py +121 -0
- mlflow/store/artifact/__init__.py +0 -0
- mlflow/store/artifact/artifact_repo.py +472 -0
- mlflow/store/artifact/artifact_repository_registry.py +154 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
- mlflow/store/artifact/cli.py +141 -0
- mlflow/store/artifact/cloud_artifact_repo.py +332 -0
- mlflow/store/artifact/databricks_artifact_repo.py +729 -0
- mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
- mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
- mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
- mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
- mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
- mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
- mlflow/store/artifact/ftp_artifact_repo.py +132 -0
- mlflow/store/artifact/gcs_artifact_repo.py +296 -0
- mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
- mlflow/store/artifact/http_artifact_repo.py +218 -0
- mlflow/store/artifact/local_artifact_repo.py +142 -0
- mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
- mlflow/store/artifact/models_artifact_repo.py +259 -0
- mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
- mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
- mlflow/store/artifact/r2_artifact_repo.py +70 -0
- mlflow/store/artifact/runs_artifact_repo.py +265 -0
- mlflow/store/artifact/s3_artifact_repo.py +330 -0
- mlflow/store/artifact/sftp_artifact_repo.py +141 -0
- mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
- mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
- mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
- mlflow/store/artifact/utils/__init__.py +0 -0
- mlflow/store/artifact/utils/models.py +148 -0
- mlflow/store/db/__init__.py +0 -0
- mlflow/store/db/base_sql_model.py +3 -0
- mlflow/store/db/db_types.py +10 -0
- mlflow/store/db/utils.py +314 -0
- mlflow/store/db_migrations/__init__.py +0 -0
- mlflow/store/db_migrations/alembic.ini +74 -0
- mlflow/store/db_migrations/env.py +84 -0
- mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
- mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
- mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
- mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
- mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
- mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
- mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
- mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
- mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
- mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
- mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
- mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
- mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
- mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
- mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
- mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
- mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
- mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
- mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
- mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
- mlflow/store/db_migrations/versions/__init__.py +0 -0
- mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
- mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
- mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
- mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
- mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
- mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
- mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
- mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
- mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
- mlflow/store/entities/__init__.py +3 -0
- mlflow/store/entities/paged_list.py +18 -0
- mlflow/store/model_registry/__init__.py +10 -0
- mlflow/store/model_registry/abstract_store.py +1081 -0
- mlflow/store/model_registry/base_rest_store.py +44 -0
- mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
- mlflow/store/model_registry/dbmodels/__init__.py +0 -0
- mlflow/store/model_registry/dbmodels/models.py +206 -0
- mlflow/store/model_registry/file_store.py +1091 -0
- mlflow/store/model_registry/rest_store.py +481 -0
- mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
- mlflow/store/tracking/__init__.py +23 -0
- mlflow/store/tracking/abstract_store.py +816 -0
- mlflow/store/tracking/dbmodels/__init__.py +0 -0
- mlflow/store/tracking/dbmodels/initial_models.py +243 -0
- mlflow/store/tracking/dbmodels/models.py +1073 -0
- mlflow/store/tracking/file_store.py +2438 -0
- mlflow/store/tracking/postgres_managed_identity.py +146 -0
- mlflow/store/tracking/rest_store.py +1131 -0
- mlflow/store/tracking/sqlalchemy_store.py +2785 -0
- mlflow/system_metrics/__init__.py +61 -0
- mlflow/system_metrics/metrics/__init__.py +0 -0
- mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
- mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
- mlflow/system_metrics/metrics/disk_monitor.py +21 -0
- mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
- mlflow/system_metrics/metrics/network_monitor.py +34 -0
- mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
- mlflow/system_metrics/system_metrics_monitor.py +198 -0
- mlflow/tracing/__init__.py +16 -0
- mlflow/tracing/assessment.py +356 -0
- mlflow/tracing/client.py +531 -0
- mlflow/tracing/config.py +125 -0
- mlflow/tracing/constant.py +105 -0
- mlflow/tracing/destination.py +81 -0
- mlflow/tracing/display/__init__.py +40 -0
- mlflow/tracing/display/display_handler.py +196 -0
- mlflow/tracing/export/async_export_queue.py +186 -0
- mlflow/tracing/export/inference_table.py +138 -0
- mlflow/tracing/export/mlflow_v3.py +137 -0
- mlflow/tracing/export/utils.py +70 -0
- mlflow/tracing/fluent.py +1417 -0
- mlflow/tracing/processor/base_mlflow.py +199 -0
- mlflow/tracing/processor/inference_table.py +175 -0
- mlflow/tracing/processor/mlflow_v3.py +47 -0
- mlflow/tracing/processor/otel.py +73 -0
- mlflow/tracing/provider.py +487 -0
- mlflow/tracing/trace_manager.py +200 -0
- mlflow/tracing/utils/__init__.py +616 -0
- mlflow/tracing/utils/artifact_utils.py +28 -0
- mlflow/tracing/utils/copy.py +55 -0
- mlflow/tracing/utils/environment.py +55 -0
- mlflow/tracing/utils/exception.py +21 -0
- mlflow/tracing/utils/once.py +35 -0
- mlflow/tracing/utils/otlp.py +63 -0
- mlflow/tracing/utils/processor.py +54 -0
- mlflow/tracing/utils/search.py +292 -0
- mlflow/tracing/utils/timeout.py +250 -0
- mlflow/tracing/utils/token.py +19 -0
- mlflow/tracing/utils/truncation.py +124 -0
- mlflow/tracing/utils/warning.py +76 -0
- mlflow/tracking/__init__.py +39 -0
- mlflow/tracking/_model_registry/__init__.py +1 -0
- mlflow/tracking/_model_registry/client.py +764 -0
- mlflow/tracking/_model_registry/fluent.py +853 -0
- mlflow/tracking/_model_registry/registry.py +67 -0
- mlflow/tracking/_model_registry/utils.py +251 -0
- mlflow/tracking/_tracking_service/__init__.py +0 -0
- mlflow/tracking/_tracking_service/client.py +883 -0
- mlflow/tracking/_tracking_service/registry.py +56 -0
- mlflow/tracking/_tracking_service/utils.py +275 -0
- mlflow/tracking/artifact_utils.py +179 -0
- mlflow/tracking/client.py +5900 -0
- mlflow/tracking/context/__init__.py +0 -0
- mlflow/tracking/context/abstract_context.py +35 -0
- mlflow/tracking/context/databricks_cluster_context.py +15 -0
- mlflow/tracking/context/databricks_command_context.py +15 -0
- mlflow/tracking/context/databricks_job_context.py +49 -0
- mlflow/tracking/context/databricks_notebook_context.py +41 -0
- mlflow/tracking/context/databricks_repo_context.py +43 -0
- mlflow/tracking/context/default_context.py +51 -0
- mlflow/tracking/context/git_context.py +32 -0
- mlflow/tracking/context/registry.py +98 -0
- mlflow/tracking/context/system_environment_context.py +15 -0
- mlflow/tracking/default_experiment/__init__.py +1 -0
- mlflow/tracking/default_experiment/abstract_context.py +43 -0
- mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
- mlflow/tracking/default_experiment/registry.py +75 -0
- mlflow/tracking/fluent.py +3595 -0
- mlflow/tracking/metric_value_conversion_utils.py +93 -0
- mlflow/tracking/multimedia.py +206 -0
- mlflow/tracking/registry.py +86 -0
- mlflow/tracking/request_auth/__init__.py +0 -0
- mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
- mlflow/tracking/request_auth/registry.py +60 -0
- mlflow/tracking/request_header/__init__.py +0 -0
- mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
- mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
- mlflow/tracking/request_header/default_request_header_provider.py +17 -0
- mlflow/tracking/request_header/registry.py +79 -0
- mlflow/transformers/__init__.py +2982 -0
- mlflow/transformers/flavor_config.py +258 -0
- mlflow/transformers/hub_utils.py +83 -0
- mlflow/transformers/llm_inference_utils.py +468 -0
- mlflow/transformers/model_io.py +301 -0
- mlflow/transformers/peft.py +51 -0
- mlflow/transformers/signature.py +183 -0
- mlflow/transformers/torch_utils.py +55 -0
- mlflow/types/__init__.py +21 -0
- mlflow/types/agent.py +270 -0
- mlflow/types/chat.py +240 -0
- mlflow/types/llm.py +935 -0
- mlflow/types/responses.py +139 -0
- mlflow/types/responses_helpers.py +416 -0
- mlflow/types/schema.py +1505 -0
- mlflow/types/type_hints.py +647 -0
- mlflow/types/utils.py +753 -0
- mlflow/utils/__init__.py +283 -0
- mlflow/utils/_capture_modules.py +256 -0
- mlflow/utils/_capture_transformers_modules.py +75 -0
- mlflow/utils/_spark_utils.py +201 -0
- mlflow/utils/_unity_catalog_oss_utils.py +97 -0
- mlflow/utils/_unity_catalog_utils.py +479 -0
- mlflow/utils/annotations.py +218 -0
- mlflow/utils/arguments_utils.py +16 -0
- mlflow/utils/async_logging/__init__.py +1 -0
- mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
- mlflow/utils/async_logging/async_logging_queue.py +366 -0
- mlflow/utils/async_logging/run_artifact.py +38 -0
- mlflow/utils/async_logging/run_batch.py +58 -0
- mlflow/utils/async_logging/run_operations.py +49 -0
- mlflow/utils/autologging_utils/__init__.py +737 -0
- mlflow/utils/autologging_utils/client.py +432 -0
- mlflow/utils/autologging_utils/config.py +33 -0
- mlflow/utils/autologging_utils/events.py +294 -0
- mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
- mlflow/utils/autologging_utils/metrics_queue.py +71 -0
- mlflow/utils/autologging_utils/safety.py +1104 -0
- mlflow/utils/autologging_utils/versioning.py +95 -0
- mlflow/utils/checkpoint_utils.py +206 -0
- mlflow/utils/class_utils.py +6 -0
- mlflow/utils/cli_args.py +257 -0
- mlflow/utils/conda.py +354 -0
- mlflow/utils/credentials.py +231 -0
- mlflow/utils/data_utils.py +17 -0
- mlflow/utils/databricks_utils.py +1436 -0
- mlflow/utils/docstring_utils.py +477 -0
- mlflow/utils/doctor.py +133 -0
- mlflow/utils/download_cloud_file_chunk.py +43 -0
- mlflow/utils/env_manager.py +16 -0
- mlflow/utils/env_pack.py +131 -0
- mlflow/utils/environment.py +1009 -0
- mlflow/utils/exception_utils.py +14 -0
- mlflow/utils/file_utils.py +978 -0
- mlflow/utils/git_utils.py +77 -0
- mlflow/utils/gorilla.py +797 -0
- mlflow/utils/import_hooks/__init__.py +363 -0
- mlflow/utils/lazy_load.py +51 -0
- mlflow/utils/logging_utils.py +168 -0
- mlflow/utils/mime_type_utils.py +58 -0
- mlflow/utils/mlflow_tags.py +103 -0
- mlflow/utils/model_utils.py +486 -0
- mlflow/utils/name_utils.py +346 -0
- mlflow/utils/nfs_on_spark.py +62 -0
- mlflow/utils/openai_utils.py +164 -0
- mlflow/utils/os.py +12 -0
- mlflow/utils/oss_registry_utils.py +29 -0
- mlflow/utils/plugins.py +17 -0
- mlflow/utils/process.py +182 -0
- mlflow/utils/promptlab_utils.py +146 -0
- mlflow/utils/proto_json_utils.py +743 -0
- mlflow/utils/pydantic_utils.py +54 -0
- mlflow/utils/request_utils.py +279 -0
- mlflow/utils/requirements_utils.py +704 -0
- mlflow/utils/rest_utils.py +673 -0
- mlflow/utils/search_logged_model_utils.py +127 -0
- mlflow/utils/search_utils.py +2111 -0
- mlflow/utils/secure_loading.py +221 -0
- mlflow/utils/security_validation.py +384 -0
- mlflow/utils/server_cli_utils.py +61 -0
- mlflow/utils/spark_utils.py +15 -0
- mlflow/utils/string_utils.py +138 -0
- mlflow/utils/thread_utils.py +63 -0
- mlflow/utils/time.py +54 -0
- mlflow/utils/timeout.py +42 -0
- mlflow/utils/uri.py +572 -0
- mlflow/utils/validation.py +662 -0
- mlflow/utils/virtualenv.py +458 -0
- mlflow/utils/warnings_utils.py +25 -0
- mlflow/utils/yaml_utils.py +179 -0
- mlflow/version.py +24 -0
@@ -0,0 +1,1081 @@
|
|
1
|
+
import json
|
2
|
+
import logging
|
3
|
+
import threading
|
4
|
+
from abc import ABCMeta, abstractmethod
|
5
|
+
from time import sleep, time
|
6
|
+
from typing import Any, Optional, Union
|
7
|
+
|
8
|
+
from mlflow.entities.logged_model_tag import LoggedModelTag
|
9
|
+
from mlflow.entities.model_registry import ModelVersionTag, RegisteredModelTag
|
10
|
+
from mlflow.entities.model_registry.model_version_status import ModelVersionStatus
|
11
|
+
from mlflow.entities.model_registry.model_version_tag import ModelVersionTag
|
12
|
+
from mlflow.entities.model_registry.prompt import Prompt
|
13
|
+
from mlflow.entities.model_registry.prompt_version import PromptVersion
|
14
|
+
from mlflow.exceptions import MlflowException
|
15
|
+
from mlflow.prompt.constants import (
|
16
|
+
IS_PROMPT_TAG_KEY,
|
17
|
+
LINKED_PROMPTS_TAG_KEY,
|
18
|
+
PROMPT_TEXT_TAG_KEY,
|
19
|
+
)
|
20
|
+
from mlflow.prompt.registry_utils import has_prompt_tag, model_version_to_prompt_version
|
21
|
+
from mlflow.protos.databricks_pb2 import (
|
22
|
+
INVALID_PARAMETER_VALUE,
|
23
|
+
RESOURCE_ALREADY_EXISTS,
|
24
|
+
RESOURCE_DOES_NOT_EXIST,
|
25
|
+
ErrorCode,
|
26
|
+
)
|
27
|
+
from mlflow.store.entities.paged_list import PagedList
|
28
|
+
from mlflow.utils.annotations import developer_stable
|
29
|
+
from mlflow.utils.logging_utils import eprint
|
30
|
+
|
31
|
+
# Logger shared by every store method defined in this module.
_logger = logging.getLogger(name=__name__)

# Polling interval (in seconds) used while waiting for a freshly created model
# version to leave the PENDING_REGISTRATION status.
AWAIT_MODEL_VERSION_CREATE_SLEEP_INTERVAL_SECONDS = 3
|
34
|
+
|
35
|
+
|
36
|
+
@developer_stable
|
37
|
+
class AbstractStore:
|
38
|
+
"""
|
39
|
+
Abstract class that defines API interfaces for storing Model Registry metadata.
|
40
|
+
"""
|
41
|
+
|
42
|
+
__metaclass__ = ABCMeta
|
43
|
+
|
44
|
+
def __init__(self, store_uri=None, tracking_uri=None):
    """
    Set up state shared by all stores. Deliberately not abstract, so derived classes are
    not forced to define a constructor of their own.

    Args:
        store_uri: The model registry store URI (unused by this base constructor).
        tracking_uri: URI of the current MLflow tracking server, used to perform operations
            like fetching source run metadata or downloading source run artifacts
            to support subsequently uploading them to the model registry storage
            location (unused by this base constructor).
    """
    # Linking prompts reads an entity's current tags and writes back an appended value.
    # A re-entrant lock serializes that read-modify-write across threads in this process.
    link_lock = threading.RLock()
    self._prompt_link_lock = link_lock
|
60
|
+
|
61
|
+
def __getstate__(self):
    """
    Return the instance state for pickling, minus the non-picklable prompt-link lock.

    The instance itself is not modified; a shallow copy of ``__dict__`` is returned.

    Returns:
        A dict of instance attributes without ``_prompt_link_lock``.
    """
    state = self.__dict__.copy()
    # RLock objects cannot be pickled (``__setstate__`` recreates one on load). Use pop() with
    # a default so instances that never ran the base ``__init__``, and therefore have no lock,
    # can still be pickled instead of failing with KeyError.
    state.pop("_prompt_link_lock", None)
    return state
|
67
|
+
|
68
|
+
def __setstate__(self, state):
    """
    Restore pickled instance state and attach a brand-new prompt-link lock.

    Args:
        state: Attribute dict produced by ``__getstate__``.
    """
    vars(self).update(state)
    # The lock is never part of the pickled state, so build a fresh one here.
    setattr(self, "_prompt_link_lock", threading.RLock())
|
73
|
+
|
74
|
+
# CRUD API for RegisteredModel objects
|
75
|
+
|
76
|
+
@abstractmethod
def create_registered_model(self, name, tags=None, description=None, deployment_job_id=None):
    """
    Register a brand-new model in the backend store.

    Args:
        name: Name for the new model; the backend expects it to be unique.
        tags: Optional list of :py:class:`mlflow.entities.model_registry.RegisteredModelTag`
            instances to attach to the registered model.
        description: Optional free-text description of the model.
        deployment_job_id: Optional deployment job ID.

    Returns:
        The :py:class:`mlflow.entities.model_registry.RegisteredModel` object that was
        created in the backend.
    """
|
93
|
+
|
94
|
+
@abstractmethod
def update_registered_model(self, name, description, deployment_job_id=None):
    """
    Replace the description of an existing registered model.

    Args:
        name: Name of the registered model to update.
        description: The replacement description.
        deployment_job_id: Optional deployment job ID.

    Returns:
        The updated :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
    """
|
107
|
+
|
108
|
+
@abstractmethod
def rename_registered_model(self, name, new_name):
    """
    Give an existing registered model a new name.

    Args:
        name: Current name of the registered model.
        new_name: The proposed new name.

    Returns:
        The updated :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
    """
|
120
|
+
|
121
|
+
@abstractmethod
def delete_registered_model(self, name):
    """
    Remove a registered model from the backend.

    The backend raises an exception when no registered model with the given name exists.

    Args:
        name: Name of the registered model to delete.

    Returns:
        None
    """
|
133
|
+
|
134
|
+
@abstractmethod
def search_registered_models(
    self,
    filter_string=None,
    max_results=None,
    order_by=None,
    page_token=None,
):
    """
    Find registered models in the backend that match a filter.

    Args:
        filter_string: Filter query string; when omitted, every registered model matches.
        max_results: Upper bound on the number of registered models returned.
        order_by: List of column names, each optionally annotated with ASC|DESC, used to
            order the matching results.
        page_token: Token selecting the next page of results, as obtained from a previous
            ``search_registered_models`` call.

    Returns:
        A PagedList of :py:class:`mlflow.entities.model_registry.RegisteredModel` objects
        matching the search expressions. The token for the next page is available through
        the object's ``token`` attribute.
    """
|
154
|
+
|
155
|
+
@abstractmethod
def get_registered_model(self, name):
    """
    Look up a registered model by its name.

    Args:
        name: Name of the registered model.

    Returns:
        The matching :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
    """
|
166
|
+
|
167
|
+
@abstractmethod
def get_latest_versions(self, name, stages=None):
    """
    Return the newest model version in each requested stage.

    When ``stages`` is not supplied, the newest version of every stage is returned.

    Args:
        name: Name of the registered model.
        stages: List of stages of interest; ``None`` means all stages.

    Returns:
        A list of :py:class:`mlflow.entities.model_registry.ModelVersion` objects.
    """
|
181
|
+
|
182
|
+
@abstractmethod
def set_registered_model_tag(self, name, tag):
    """
    Attach (or overwrite) a tag on a registered model.

    Args:
        name: Name of the registered model.
        tag: The :py:class:`mlflow.entities.model_registry.RegisteredModelTag` to record.

    Returns:
        None
    """
|
194
|
+
|
195
|
+
@abstractmethod
def delete_registered_model_tag(self, name, key):
    """
    Remove a tag from a registered model.

    Args:
        name: Name of the registered model.
        key: Key of the tag to remove.

    Returns:
        None
    """
|
207
|
+
|
208
|
+
# CRUD API for ModelVersion objects
|
209
|
+
|
210
|
+
@abstractmethod
def create_model_version(
    self,
    name,
    source,
    run_id=None,
    tags=None,
    run_link=None,
    description=None,
    local_model_path=None,
    model_id: Optional[str] = None,
):
    """
    Create a new version of a registered model from a source location and run ID.

    Args:
        name: Name of the registered model.
        source: URI pointing at the model artifacts.
        run_id: ID of the MLflow tracking run that produced the model.
        tags: Optional list of :py:class:`mlflow.entities.model_registry.ModelVersionTag`
            instances to attach to the new version.
        run_link: Link to the MLflow tracking run that produced the model.
        description: Optional description of this version.
        local_model_path: Local filesystem path of the MLflow model, when it is already
            available locally. Stores that upload model version files to the registry can
            use it to skip a redundant download from ``source`` when a model is logged and
            registered in one ``mlflow.<flavor>.log_model(..., registered_model_name)``
            call.
        model_id: ID of the (experiment) model being promoted to a registered model
            version, if any.

    Returns:
        The :py:class:`mlflow.entities.model_registry.ModelVersion` object created in the
        backend.
    """
|
247
|
+
|
248
|
+
@abstractmethod
def update_model_version(self, name, version, description):
    """
    Update the metadata stored for a model version.

    Args:
        name: Name of the registered model.
        version: The model version to update.
        description: The new description.

    Returns:
        The updated :py:class:`mlflow.entities.model_registry.ModelVersion` object.
    """
|
261
|
+
|
262
|
+
@abstractmethod
def transition_model_version_stage(self, name, version, stage, archive_existing_versions):
    """
    Move a model version to a different stage.

    Args:
        name: Name of the registered model.
        version: The model version to transition.
        stage: Target stage for this model version.
        archive_existing_versions: When ``True``, every other model version currently in
            ``stage`` is moved to the "archived" stage. Only valid when ``stage`` is
            ``"staging"`` or ``"production"``; any other stage raises an error.

    Returns:
        The updated :py:class:`mlflow.entities.model_registry.ModelVersion` object.
    """
|
280
|
+
|
281
|
+
@abstractmethod
def delete_model_version(self, name, version):
    """
    Delete one version of a registered model from the backend.

    Args:
        name: Name of the registered model.
        version: The model version to delete.

    Returns:
        None
    """
|
293
|
+
|
294
|
+
@abstractmethod
def get_model_version(self, name, version):
    """
    Fetch a model version given its registered model name and version.

    Args:
        name: Name of the registered model.
        version: The model version to fetch.

    Returns:
        The matching :py:class:`mlflow.entities.model_registry.ModelVersion` object.
    """
|
306
|
+
|
307
|
+
@abstractmethod
def get_model_version_download_uri(self, name, version):
    """
    Return the Model Registry download location of a model version.

    NOTE: In the first version of Model Registry, models are not copied to a separate
    location, so the download URI is simply the original source path.

    Args:
        name: Name of the registered model.
        version: The model version.

    Returns:
        A URI that can be read from to download the model version.
    """
|
321
|
+
|
322
|
+
@abstractmethod
def search_model_versions(
    self,
    filter_string=None,
    max_results=None,
    order_by=None,
    page_token=None,
):
    """
    Find model versions in the backend that match a filter.

    Args:
        filter_string: Filter expression. A single condition is currently supported, on
            either the model name (``name = 'model_name'``) or ``run_id = '...'``.
        max_results: Upper bound on the number of model versions returned.
        order_by: List of column names, each optionally annotated with ASC|DESC, used to
            order the matching results.
        page_token: Token selecting the next page of results, as obtained from a previous
            ``search_model_versions`` call.

    Returns:
        A PagedList of :py:class:`mlflow.entities.model_registry.ModelVersion` objects
        matching the search expressions. The token for the next page is available through
        the object's ``token`` attribute.
    """
|
345
|
+
|
346
|
+
@abstractmethod
def set_model_version_tag(self, name, version, tag):
    """
    Attach (or overwrite) a tag on a model version.

    Args:
        name: Name of the registered model.
        version: The model version to tag.
        tag: The :py:class:`mlflow.entities.model_registry.ModelVersionTag` to record.

    Returns:
        None
    """
|
359
|
+
|
360
|
+
@abstractmethod
def delete_model_version_tag(self, name, version, key):
    """
    Remove a tag from a model version.

    Args:
        name: Name of the registered model.
        version: The model version holding the tag.
        key: Key of the tag to remove.

    Returns:
        None
    """
|
373
|
+
|
374
|
+
@abstractmethod
def set_registered_model_alias(self, name, alias, version):
    """
    Point an alias of a registered model at one of its versions.

    Args:
        name: Name of the registered model.
        alias: The alias name.
        version: Version number the alias should refer to.

    Returns:
        None
    """
|
387
|
+
|
388
|
+
@abstractmethod
def delete_registered_model_alias(self, name, alias):
    """
    Remove an alias from a registered model.

    Args:
        name: Name of the registered model.
        alias: The alias name to remove.

    Returns:
        None
    """
|
400
|
+
|
401
|
+
@abstractmethod
def get_model_version_by_alias(self, name, alias):
    """
    Resolve a registered model alias to the model version it points at.

    Args:
        name: Name of the registered model.
        alias: The alias name.

    Returns:
        The matching :py:class:`mlflow.entities.model_registry.ModelVersion` object.
    """
|
413
|
+
|
414
|
+
def copy_model_version(self, src_mv, dst_name):
    """
    Copy a model version into another registered model, where it becomes a new version.

    The destination registered model is created first; if it already exists
    (RESOURCE_ALREADY_EXISTS), a new version is added to it instead. The copy uses a
    ``models:/<name>/<version>`` URI as its source and carries over the run ID, tags, run
    link and description of ``src_mv``.

    Args:
        src_mv: The :py:class:`mlflow.entities.model_registry.ModelVersion` to copy.
        dst_name: Name of the destination registered model; it is created when missing.

    Returns:
        The :py:class:`mlflow.entities.model_registry.ModelVersion` representing the copy.

    Raises:
        MlflowException: If creating the destination model fails for any reason other than
            it already existing, or if creating the new model version fails.
    """
    try:
        registered = self.create_registered_model(dst_name)
    except MlflowException as exc:
        # Only "already exists" is acceptable here; anything else is a real failure.
        if exc.error_code != ErrorCode.Name(RESOURCE_ALREADY_EXISTS):
            raise
        eprint(
            f"Registered model '{dst_name}' already exists."
            " Creating a new version of this model..."
        )
    else:
        eprint(f"Successfully registered model '{registered.name}'.")

    try:
        copied = self.create_model_version(
            name=dst_name,
            source=f"models:/{src_mv.name}/{src_mv.version}",
            run_id=src_mv.run_id,
            tags=[ModelVersionTag(key, value) for key, value in src_mv.tags.items()],
            run_link=src_mv.run_link,
            description=src_mv.description,
        )
    except MlflowException as exc:
        raise MlflowException(
            "Failed to create model version copy. The current model registry backend "
            f"may not yet support model version URI sources.\nError: {exc}"
        ) from exc

    eprint(
        f"Copied version '{src_mv.version}' of model '{src_mv.name}'"
        f" to version '{copied.version}' of model '{copied.name}'."
    )
    return copied
|
459
|
+
|
460
|
+
def _await_model_version_creation(self, mv, await_creation_for):
    """
    Block until a newly created model version is ready.

    Delegates to ``_await_model_version_creation_impl`` without an extra hint message.

    Args:
        mv: The :py:class:`mlflow.entities.model_registry.ModelVersion` to wait on.
        await_creation_for: Maximum number of seconds to wait for the version to finish
            creation and reach ``READY`` status.
    """
    wait_impl = self._await_model_version_creation_impl
    wait_impl(mv, await_creation_for)
|
470
|
+
|
471
|
+
def _await_model_version_creation_impl(self, mv, await_creation_for, hint=""):
    """
    Poll the store until ``mv`` leaves PENDING_REGISTRATION, then verify it is READY.

    Args:
        mv: The model (or prompt) version to wait on.
        await_creation_for: Maximum number of seconds to keep polling.
        hint: Extra text appended to the timeout error message.

    Raises:
        MlflowException: If the version is still pending after ``await_creation_for``
            seconds, or if it ends up in any status other than READY.
    """
    entity = "Prompt" if has_prompt_tag(mv.tags) else "Model"
    _logger.info(
        f"Waiting up to {await_creation_for} seconds for {entity.lower()} version to "
        f"finish creation. {entity} name: {mv.name}, version {mv.version}",
    )
    deadline = time() + await_creation_for
    pending = ModelVersionStatus.to_string(ModelVersionStatus.PENDING_REGISTRATION)

    while mv.status == pending:
        if time() > deadline:
            raise MlflowException(
                f"Exceeded max wait time for model name: {mv.name} version: {mv.version} "
                f"to become READY. Status: {mv.status} Wait Time: {await_creation_for}"
                f".{hint}"
            )
        mv = self.get_model_version(mv.name, mv.version)
        # Only sleep when another poll is actually needed.
        if mv.status == pending:
            sleep(AWAIT_MODEL_VERSION_CREATE_SLEEP_INTERVAL_SECONDS)

    ready = ModelVersionStatus.to_string(ModelVersionStatus.READY)
    if mv.status != ready:
        raise MlflowException(
            f"{entity} version creation failed for {entity.lower()} name: {mv.name} "
            f"version: {mv.version} with status: {mv.status} and message: {mv.status_message}"
        )
|
495
|
+
|
496
|
+
# Prompt-related methods with concrete implementations for OSS stores
|
497
|
+
|
498
|
+
def create_prompt(
    self,
    name: str,
    description: Optional[str] = None,
    tags: Optional[dict[str, str]] = None,
) -> Prompt:
    """
    Register a new prompt.

    Default implementation: the prompt is stored as a RegisteredModel marked with the
    internal prompt tag (``IS_PROMPT_TAG_KEY`` = "true") plus any user-supplied tags.
    Other store implementations may override this method.

    Args:
        name: Name of the prompt.
        description: Optional description of the prompt.
        tags: Optional mapping of prompt tag keys to values.

    Returns:
        A Prompt built from the created registered model; its ``tags`` are the user tags
        that were passed in (an empty dict when none were given).
    """
    user_tags = tags or {}
    registered_tags = [RegisteredModelTag(key=IS_PROMPT_TAG_KEY, value="true")]
    registered_tags += [RegisteredModelTag(key=k, value=v) for k, v in user_tags.items()]

    created = self.create_registered_model(name, tags=registered_tags, description=description)

    return Prompt(
        name=created.name,
        description=created.description,
        creation_timestamp=created.creation_timestamp,
        tags=user_tags,
    )
|
533
|
+
|
534
|
+
def search_prompts(
    self,
    filter_string: Optional[str] = None,
    max_results: Optional[int] = None,
    order_by: Optional[list[str]] = None,
    page_token: Optional[str] = None,
) -> PagedList[Prompt]:
    """
    Search the registry for prompts.

    Default implementation: searches registered models restricted to those carrying the
    internal prompt tag, then converts each hit to a Prompt. Other store implementations
    may override this method.

    Args:
        filter_string: Additional filter, combined with the prompt restriction via ``AND``.
            When omitted, all prompts match.
        max_results: Maximum number of prompts to return (100 when ``None``).
        order_by: List of order-by clauses.
        page_token: Pagination token for requesting subsequent pages.

    Returns:
        A PagedList of Prompt objects carrying the pagination token of the underlying
        registered-model search.
    """

    def _tags_as_dict(raw_tags):
        # Tags may arrive as a dict or as an iterable of tag entities (or be empty).
        if isinstance(raw_tags, dict):
            return raw_tags.copy()
        return {t.key: t.value for t in raw_tags} if raw_tags else {}

    limit = 100 if max_results is None else max_results

    # Backticks are required because the prompt tag key contains dots.
    clauses = [f"tags.`{IS_PROMPT_TAG_KEY}` = 'true'"]
    if filter_string:
        clauses.append(filter_string)

    page = self.search_registered_models(
        filter_string=" AND ".join(clauses),
        max_results=limit,
        order_by=order_by,
        page_token=page_token,
    )

    results = []
    for model in page:
        visible_tags = _tags_as_dict(model.tags)
        # The internal marker tag is not part of the user-visible prompt tags.
        visible_tags.pop(IS_PROMPT_TAG_KEY, None)
        results.append(
            Prompt(
                name=model.name,
                description=model.description,
                creation_timestamp=model.creation_timestamp,
                tags=visible_tags,
            )
        )

    return PagedList(results, page.token)
|
594
|
+
|
595
|
+
def delete_prompt(self, name: str) -> None:
    """
    Remove a prompt from the registry.

    Default implementation: deletes the RegisteredModel that backs the prompt and returns
    whatever ``delete_registered_model`` returns. Other store implementations may override
    this method.

    Args:
        name: Name of the prompt to delete.
    """
    remove_backing_model = self.delete_registered_model
    return remove_backing_model(name)
|
607
|
+
|
608
|
+
def set_prompt_tag(self, name: str, key: str, value: str) -> None:
    """
    Attach (or overwrite) a tag on a prompt.

    Default implementation: writes a RegisteredModelTag onto the RegisteredModel that
    backs the prompt. Other store implementations may override this method.

    Args:
        name: Name of the prompt.
        key: Tag key.
        value: Tag value.
    """
    return self.set_registered_model_tag(name, RegisteredModelTag(key=key, value=value))
|
623
|
+
|
624
|
+
def delete_prompt_tag(self, name: str, key: str) -> None:
    """
    Remove a tag from a prompt.

    Default implementation: removes the tag from the RegisteredModel that backs the
    prompt and returns whatever ``delete_registered_model_tag`` returns. Other store
    implementations may override this method.

    Args:
        name: Name of the prompt.
        key: Key of the tag to remove.
    """
    remove_tag = self.delete_registered_model_tag
    return remove_tag(name, key)
|
637
|
+
|
638
|
+
def get_prompt(self, name: str) -> Optional[Prompt]:
    """
    Get prompt metadata by name.

    Default implementation: loads the RegisteredModel with the given name and, if it
    carries the internal prompt tag, converts it to a Prompt. Other store implementations
    may override this method.

    This lookup is best-effort: any failure results in ``None``. Failures other than
    "resource does not exist" are logged as warnings so that backend or permission errors
    are not silently mistaken for a missing prompt.

    Args:
        name: Registered prompt name.

    Returns:
        A Prompt with the prompt metadata (internal prompt tag removed from ``tags``), or
        None if no such prompt exists, the name belongs to a non-prompt model, or loading
        fails.
    """

    def _tags_to_dict(raw_tags):
        # Tags may arrive as a dict or as an iterable of tag entities (or be missing/empty).
        if isinstance(raw_tags, dict):
            return raw_tags.copy()
        return {tag.key: tag.value for tag in raw_tags} if raw_tags else {}

    try:
        rm = self.get_registered_model(name)

        # The prompt marker is read from the internal ``_tags`` collection; getattr keeps
        # this robust for entities that do not expose it (treated as "not a prompt").
        internal_tags = _tags_to_dict(getattr(rm, "_tags", None))
        if internal_tags.get(IS_PROMPT_TAG_KEY) != "true":
            return None

        user_tags = _tags_to_dict(rm.tags)
        # Keep the internal marker out of user-visible tags, as search_prompts does.
        user_tags.pop(IS_PROMPT_TAG_KEY, None)

        return Prompt(
            name=rm.name,
            description=rm.description,
            creation_timestamp=rm.creation_timestamp,
            tags=user_tags,
        )
    except Exception as e:
        is_not_found = isinstance(e, MlflowException) and e.error_code == ErrorCode.Name(
            RESOURCE_DOES_NOT_EXIST
        )
        if not is_not_found:
            _logger.warning("Failed to load prompt '%s'; returning None: %s", name, e)
        return None
|
678
|
+
|
679
|
+
def create_prompt_version(
    self,
    name: str,
    template: str,
    description: Optional[str] = None,
    tags: Optional[dict[str, str]] = None,
) -> PromptVersion:
    """
    Create a new version of an existing prompt.

    Default implementation: creates a ModelVersion tagged with the internal prompt marker
    and the template text (plus any user tags), then converts it to a PromptVersion that
    also carries the prompt-level tags of the backing RegisteredModel. Other store
    implementations may override this method.

    Args:
        name: Name of the prompt.
        template: The prompt template text.
        description: Optional description of the prompt version.
        tags: Optional dictionary of version tags.

    Returns:
        A PromptVersion object representing the created version.
    """
    version_tags = [
        ModelVersionTag(key=IS_PROMPT_TAG_KEY, value="true"),
        ModelVersionTag(key=PROMPT_TEXT_TAG_KEY, value=template),
    ]
    if tags:
        version_tags.extend(ModelVersionTag(key=k, value=v) for k, v in tags.items())

    mv = self.create_model_version(
        name=name,
        source="prompt-template",  # Required field for ModelVersion
        tags=version_tags,
        description=description,
    )

    # Prompt-level tags live on the registered model. Guard against missing/empty tags,
    # consistent with the other prompt helpers, instead of failing on iteration of None.
    rm = self.get_registered_model(name)
    if isinstance(rm.tags, dict):
        prompt_tags = rm.tags.copy()
    else:
        prompt_tags = {tag.key: tag.value for tag in rm.tags} if rm.tags else {}

    return model_version_to_prompt_version(mv, prompt_tags=prompt_tags)
|
725
|
+
|
726
|
+
def get_prompt_version(self, name: str, version: Union[str, int]) -> Optional[PromptVersion]:
    """
    Get a specific prompt version.

    Default implementation: verifies that ``name`` is a prompt, fetches the underlying
    ModelVersion and converts it to a PromptVersion. Other store implementations may
    override this method.

    Args:
        name: Name of the prompt.
        version: Version number or alias. A value whose string form parses as an integer
            is treated as a version number; anything else is looked up as an alias.

    Returns:
        A PromptVersion object, or None if the resolved model version lacks the prompt tag
        or a non-MLflow error occurs while loading it.

    Raises:
        MlflowException: If ``name`` is registered as a model rather than a prompt, or
            when the backend raises an MlflowException (e.g. because the prompt, version or
            alias does not exist); these are re-raised rather than converted to None.
    """

    def _tags_to_dict(raw_tags):
        # Tags may arrive as a dict or as an iterable of tag entities (or be missing/empty).
        if isinstance(raw_tags, dict):
            return raw_tags.copy()
        return {tag.key: tag.value for tag in raw_tags} if raw_tags else {}

    try:
        # First check that this name is actually a prompt via the internal ``_tags``.
        rm = self.get_registered_model(name)
        internal_tags = _tags_to_dict(getattr(rm, "_tags", None))
        if internal_tags.get(IS_PROMPT_TAG_KEY) != "true":
            raise MlflowException(
                f"Name `{name}` is registered as a model, not a prompt. "
                "Use get_model_version() or load_model() instead.",
                INVALID_PARAMETER_VALUE,
            )

        # Only a failed integer parse means "treat as alias"; errors raised by the
        # version lookup itself must not trigger an alias lookup.
        try:
            version_int = int(str(version))
        except (ValueError, TypeError):
            mv = self.get_model_version_by_alias(name, str(version))
        else:
            mv = self.get_model_version(name, version_int)

        if not has_prompt_tag(mv.tags):
            return None

        # User-visible prompt-level tags come from the registered model.
        return model_version_to_prompt_version(mv, prompt_tags=_tags_to_dict(rm.tags))

    except MlflowException:
        raise  # Re-raise MlflowExceptions (including the not-a-prompt error above)
    except Exception:
        return None
|
782
|
+
|
783
|
+
def delete_prompt_version(self, name: str, version: Union[str, int]) -> None:
    """
    Delete a specific prompt version.

    Default implementation: deletes the underlying ModelVersion and returns whatever
    ``delete_model_version`` returns. Other store implementations may override this method.

    Args:
        name: Name of the prompt.
        version: Version number to delete (an int or an integer-valued string).

    Raises:
        MlflowException: With ``INVALID_PARAMETER_VALUE`` if ``version`` cannot be
            converted to an integer.
    """
    try:
        version_int = int(version)
    except (ValueError, TypeError) as e:
        # Bad user input is a parameter error, not an internal error; keep the cause.
        raise MlflowException(
            f"Invalid version number: {version}", error_code=INVALID_PARAMETER_VALUE
        ) from e
    return self.delete_model_version(name, version_int)
|
800
|
+
|
801
|
+
def get_prompt_version_by_alias(self, name: str, alias: str) -> Optional[PromptVersion]:
    """
    Resolve a prompt alias to a PromptVersion.

    Default implementation: forwards to ``get_prompt_version`` with the alias in place of
    the version. ``get_prompt_version`` first tries to parse that argument as an integer,
    so a purely numeric alias string is looked up as a version number.

    Args:
        name: Name of the prompt.
        alias: Alias name.

    Returns:
        A PromptVersion object, or None if not found.
    """
    resolved = self.get_prompt_version(name, alias)
    return resolved
|
815
|
+
|
816
|
+
def set_prompt_alias(self, name: str, alias: str, version: Union[str, int]) -> None:
    """
    Point an alias of a prompt at one of its versions.

    Default implementation: delegates to ``set_registered_model_alias`` on the backing
    registered model.

    Args:
        name: Name of the prompt.
        alias: Alias to set.
        version: Version the alias should refer to.
    """
    assign_alias = self.set_registered_model_alias
    assign_alias(name, alias, version)
|
828
|
+
|
829
|
+
def delete_prompt_alias(self, name: str, alias: str) -> None:
    """
    Remove an alias from a prompt.

    Default implementation: delegates to ``delete_registered_model_alias`` on the backing
    registered model.

    Args:
        name: Name of the prompt.
        alias: Alias to delete.
    """
    remove_alias = self.delete_registered_model_alias
    remove_alias(name, alias)
|
840
|
+
|
841
|
+
def search_prompt_versions(
    self,
    name: str,
    max_results: Optional[int] = None,
    page_token: Optional[str] = None,
):
    """
    Search the versions of a given prompt.

    Only Unity Catalog registries support this operation; this default implementation
    always raises.

    Args:
        name: Name of the prompt whose versions should be searched.
        max_results: Maximum number of versions to return.
        page_token: Token for pagination.

    Raises:
        MlflowException: Always, with ``INVALID_PARAMETER_VALUE``.
    """
    unsupported_message = (
        "search_prompt_versions() is not supported in this registry. "
        "This method is only available in Unity Catalog registries."
    )
    raise MlflowException(unsupported_message, INVALID_PARAMETER_VALUE)
|
863
|
+
|
864
|
+
def link_prompts_to_trace(self, prompt_versions: list[PromptVersion], trace_id: str) -> None:
    """
    Link multiple prompt versions to a trace.

    Default implementation: merges the given versions into the trace's
    ``LINKED_PROMPTS_TAG_KEY`` tag using ``_update_linked_prompts_tag`` and writes the tag
    back only when its value changed. Stores can override with custom behavior.

    The tag read-modify-write happens while holding ``self._prompt_link_lock``, so
    concurrent threads using this store instance do not overwrite each other's links.

    Args:
        prompt_versions: List of PromptVersion objects to link.
        trace_id: Trace ID to link to each prompt version.

    Raises:
        MlflowException: With error code ``RESOURCE_DOES_NOT_EXIST`` if no trace info is
            returned for ``trace_id``.
    """
    from mlflow.tracing.client import TracingClient

    client = TracingClient()
    with self._prompt_link_lock:
        trace_info = client.get_trace_info(trace_id)
        if not trace_info:
            # Pass the numeric proto enum value: MlflowException converts it to its name
            # itself, whereas an already-converted name string is not a valid enum number
            # and would be reported as INTERNAL_ERROR instead of RESOURCE_DOES_NOT_EXIST.
            raise MlflowException(
                f"Could not find trace with ID '{trace_id}' to which to link prompts.",
                error_code=RESOURCE_DOES_NOT_EXIST,
            )

        new_prompt_entries = [
            {"name": prompt_version.name, "version": str(prompt_version.version)}
            for prompt_version in prompt_versions
        ]

        current_tag_value = trace_info.tags.get(LINKED_PROMPTS_TAG_KEY)
        updated_tag_value = self._update_linked_prompts_tag(
            current_tag_value, new_prompt_entries
        )

        # Skip the write when nothing changed, avoiding redundant tag updates.
        if current_tag_value != updated_tag_value:
            client.set_trace_tag(trace_id, LINKED_PROMPTS_TAG_KEY, updated_tag_value)
|
907
|
+
|
908
|
+
def set_prompt_version_tag(
    self, name: str, version: Union[str, int], key: str, value: str
) -> None:
    """
    Set a tag on a prompt version.

    Default implementation: uses set_model_version_tag on the underlying ModelVersion.
    Unity Catalog store implementations may override this method.

    Args:
        name: Name of the prompt.
        version: Version number of the prompt (an int or a string holding an int).
        key: Tag key.
        value: Tag value.

    Returns:
        Whatever ``set_model_version_tag`` returns.

    Raises:
        MlflowException: With ``INVALID_PARAMETER_VALUE`` if ``version`` cannot be
            converted to an integer.
    """
    try:
        version_int = int(version)
    except (ValueError, TypeError) as err:
        # A malformed version is a caller error, not an internal one; keep the
        # original exception as the cause for debugging.
        raise MlflowException(
            f"Invalid version number: {version}",
            error_code=INVALID_PARAMETER_VALUE,
        ) from err

    # Wrap the key/value in a ModelVersionTag and delegate to the model-version API.
    tag = ModelVersionTag(key=key, value=value)
    return self.set_model_version_tag(name, version_int, tag)
|
932
|
+
|
933
|
+
def delete_prompt_version_tag(self, name: str, version: Union[str, int], key: str) -> None:
    """
    Delete a tag from a prompt version.

    Default implementation: uses delete_model_version_tag on the underlying ModelVersion.
    Unity Catalog store implementations may override this method.

    Args:
        name: Name of the prompt.
        version: Version number of the prompt (an int or a string holding an int).
        key: Tag key to delete.

    Returns:
        Whatever ``delete_model_version_tag`` returns.

    Raises:
        MlflowException: With ``INVALID_PARAMETER_VALUE`` if ``version`` cannot be
            converted to an integer.
    """
    try:
        version_int = int(version)
    except (ValueError, TypeError) as err:
        # A malformed version is a caller error, not an internal one; keep the
        # original exception as the cause for debugging.
        raise MlflowException(
            f"Invalid version number: {version}",
            error_code=INVALID_PARAMETER_VALUE,
        ) from err

    return self.delete_model_version_tag(name, version_int, key)
|
953
|
+
|
954
|
+
def link_prompt_version_to_model(self, name: str, version: str, model_id: str) -> None:
    """
    Link a prompt version to a model.

    Default implementation appends the prompt version to the JSON list stored in
    the logged model's ``LINKED_PROMPTS_TAG_KEY`` tag (no duplicates). Stores can
    override with custom behavior.

    Args:
        name: Name of the prompt.
        version: Version of the prompt to link.
        model_id: ID of the model to link to.

    Raises:
        MlflowException: With error code ``RESOURCE_DOES_NOT_EXIST`` if the tracking
            store returns no logged model for ``model_id``.
    """
    from mlflow.tracking import _get_store as _get_tracking_store

    prompt_version = self.get_prompt_version(name, version)
    tracking_store = _get_tracking_store()

    # Hold the prompt-link lock across the read-modify-write of the model tag.
    with self._prompt_link_lock:
        logged_model = tracking_store.get_logged_model(model_id)
        if not logged_model:
            raise MlflowException(
                f"Could not find model with ID '{model_id}' to which to link prompt '{name}'.",
                # NOTE(review): MlflowException appears to run error_code through
                # ErrorCode.Name() itself, so pass the raw enum value; passing the
                # already-converted name would fall back to INTERNAL_ERROR.
                error_code=RESOURCE_DOES_NOT_EXIST,
            )

        new_prompt_entry = {
            "name": prompt_version.name,
            "version": str(prompt_version.version),
        }

        current_tag_value = logged_model.tags.get(LINKED_PROMPTS_TAG_KEY)
        updated_tag_value = self._update_linked_prompts_tag(
            current_tag_value, [new_prompt_entry]
        )

        # Skip the write when this prompt version is already linked.
        if current_tag_value != updated_tag_value:
            tracking_store.set_logged_model_tags(
                model_id,
                [LoggedModelTag(key=LINKED_PROMPTS_TAG_KEY, value=updated_tag_value)],
            )
|
998
|
+
|
999
|
+
def link_prompt_version_to_run(self, name: str, version: str, run_id: str) -> None:
    """
    Link a prompt version to a run.

    Default implementation appends the prompt version to the JSON list stored in
    the run's ``LINKED_PROMPTS_TAG_KEY`` tag (no duplicates). Stores can override
    with custom behavior.

    Args:
        name: Name of the prompt.
        version: Version of the prompt to link.
        run_id: ID of the run to link to.

    Raises:
        MlflowException: With error code ``RESOURCE_DOES_NOT_EXIST`` if the tracking
            store returns no run for ``run_id``.
    """
    from mlflow.tracking import _get_store as _get_tracking_store

    prompt_version = self.get_prompt_version(name, version)
    tracking_store = _get_tracking_store()

    # Hold the prompt-link lock across the read-modify-write of the run tag.
    with self._prompt_link_lock:
        run = tracking_store.get_run(run_id)
        if not run:
            raise MlflowException(
                f"Could not find run with ID '{run_id}' to which to link prompt '{name}'.",
                # NOTE(review): MlflowException appears to run error_code through
                # ErrorCode.Name() itself, so pass the raw enum value; passing the
                # already-converted name would fall back to INTERNAL_ERROR.
                error_code=RESOURCE_DOES_NOT_EXIST,
            )

        new_prompt_entry = {
            "name": prompt_version.name,
            "version": str(prompt_version.version),
        }

        # run.data.tags may be a dict, or an iterable of tag objects exposing
        # .key/.value; read the linked-prompts tag from either shape.
        run_tags = run.data.tags
        if isinstance(run_tags, dict):
            current_tag_value = run_tags.get(LINKED_PROMPTS_TAG_KEY)
        else:
            current_tag_value = next(
                (tag.value for tag in run_tags if tag.key == LINKED_PROMPTS_TAG_KEY),
                None,
            )

        updated_tag_value = self._update_linked_prompts_tag(
            current_tag_value, [new_prompt_entry]
        )

        # Skip the write when this prompt version is already linked.
        if current_tag_value != updated_tag_value:
            from mlflow.entities import RunTag

            tracking_store.set_tag(run_id, RunTag(LINKED_PROMPTS_TAG_KEY, updated_tag_value))
|
1045
|
+
|
1046
|
+
def _update_linked_prompts_tag(
|
1047
|
+
self, current_tag_value: str, new_prompt_entries: list[dict[str, Any]]
|
1048
|
+
) -> str:
|
1049
|
+
"""
|
1050
|
+
Utility method to update linked prompts tag value with new entries.
|
1051
|
+
|
1052
|
+
Args:
|
1053
|
+
current_tag_value: Current JSON string value of the linked prompts tag
|
1054
|
+
new_prompt_entries: List of prompt entry dicts to add
|
1055
|
+
|
1056
|
+
Returns:
|
1057
|
+
Updated JSON string with new entries added (avoiding duplicates)
|
1058
|
+
|
1059
|
+
Raises:
|
1060
|
+
MlflowException: If current tag value has invalid JSON or format
|
1061
|
+
"""
|
1062
|
+
if current_tag_value is not None:
|
1063
|
+
try:
|
1064
|
+
parsed_prompts_tag_value = json.loads(current_tag_value)
|
1065
|
+
if not isinstance(parsed_prompts_tag_value, list):
|
1066
|
+
raise MlflowException(
|
1067
|
+
f"Invalid format for '{LINKED_PROMPTS_TAG_KEY}' tag: {current_tag_value}"
|
1068
|
+
)
|
1069
|
+
except json.JSONDecodeError:
|
1070
|
+
raise MlflowException(
|
1071
|
+
f"Invalid JSON format for '{LINKED_PROMPTS_TAG_KEY}' tag: {current_tag_value}"
|
1072
|
+
)
|
1073
|
+
else:
|
1074
|
+
parsed_prompts_tag_value = []
|
1075
|
+
|
1076
|
+
# Add new prompt entries that aren't already linked
|
1077
|
+
for new_prompt_entry in new_prompt_entries:
|
1078
|
+
if new_prompt_entry not in parsed_prompts_tag_value:
|
1079
|
+
parsed_prompts_tag_value.append(new_prompt_entry)
|
1080
|
+
|
1081
|
+
return json.dumps(parsed_prompts_tag_value)
|