genesis-flow 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genesis_flow-1.0.0.dist-info/METADATA +822 -0
- genesis_flow-1.0.0.dist-info/RECORD +645 -0
- genesis_flow-1.0.0.dist-info/WHEEL +5 -0
- genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
- genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
- genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
- mlflow/__init__.py +367 -0
- mlflow/__main__.py +3 -0
- mlflow/ag2/__init__.py +56 -0
- mlflow/ag2/ag2_logger.py +294 -0
- mlflow/anthropic/__init__.py +40 -0
- mlflow/anthropic/autolog.py +129 -0
- mlflow/anthropic/chat.py +144 -0
- mlflow/artifacts/__init__.py +268 -0
- mlflow/autogen/__init__.py +144 -0
- mlflow/autogen/chat.py +142 -0
- mlflow/azure/__init__.py +26 -0
- mlflow/azure/auth_handler.py +257 -0
- mlflow/azure/client.py +319 -0
- mlflow/azure/config.py +120 -0
- mlflow/azure/connection_factory.py +340 -0
- mlflow/azure/exceptions.py +27 -0
- mlflow/azure/stores.py +327 -0
- mlflow/azure/utils.py +183 -0
- mlflow/bedrock/__init__.py +45 -0
- mlflow/bedrock/_autolog.py +202 -0
- mlflow/bedrock/chat.py +122 -0
- mlflow/bedrock/stream.py +160 -0
- mlflow/bedrock/utils.py +43 -0
- mlflow/cli.py +707 -0
- mlflow/client.py +12 -0
- mlflow/config/__init__.py +56 -0
- mlflow/crewai/__init__.py +79 -0
- mlflow/crewai/autolog.py +253 -0
- mlflow/crewai/chat.py +29 -0
- mlflow/data/__init__.py +75 -0
- mlflow/data/artifact_dataset_sources.py +170 -0
- mlflow/data/code_dataset_source.py +40 -0
- mlflow/data/dataset.py +123 -0
- mlflow/data/dataset_registry.py +168 -0
- mlflow/data/dataset_source.py +110 -0
- mlflow/data/dataset_source_registry.py +219 -0
- mlflow/data/delta_dataset_source.py +167 -0
- mlflow/data/digest_utils.py +108 -0
- mlflow/data/evaluation_dataset.py +562 -0
- mlflow/data/filesystem_dataset_source.py +81 -0
- mlflow/data/http_dataset_source.py +145 -0
- mlflow/data/huggingface_dataset.py +258 -0
- mlflow/data/huggingface_dataset_source.py +118 -0
- mlflow/data/meta_dataset.py +104 -0
- mlflow/data/numpy_dataset.py +223 -0
- mlflow/data/pandas_dataset.py +231 -0
- mlflow/data/polars_dataset.py +352 -0
- mlflow/data/pyfunc_dataset_mixin.py +31 -0
- mlflow/data/schema.py +76 -0
- mlflow/data/sources.py +1 -0
- mlflow/data/spark_dataset.py +406 -0
- mlflow/data/spark_dataset_source.py +74 -0
- mlflow/data/spark_delta_utils.py +118 -0
- mlflow/data/tensorflow_dataset.py +350 -0
- mlflow/data/uc_volume_dataset_source.py +81 -0
- mlflow/db.py +27 -0
- mlflow/dspy/__init__.py +17 -0
- mlflow/dspy/autolog.py +197 -0
- mlflow/dspy/callback.py +398 -0
- mlflow/dspy/constant.py +1 -0
- mlflow/dspy/load.py +93 -0
- mlflow/dspy/save.py +393 -0
- mlflow/dspy/util.py +109 -0
- mlflow/dspy/wrapper.py +226 -0
- mlflow/entities/__init__.py +104 -0
- mlflow/entities/_mlflow_object.py +52 -0
- mlflow/entities/assessment.py +545 -0
- mlflow/entities/assessment_error.py +80 -0
- mlflow/entities/assessment_source.py +141 -0
- mlflow/entities/dataset.py +92 -0
- mlflow/entities/dataset_input.py +51 -0
- mlflow/entities/dataset_summary.py +62 -0
- mlflow/entities/document.py +48 -0
- mlflow/entities/experiment.py +109 -0
- mlflow/entities/experiment_tag.py +35 -0
- mlflow/entities/file_info.py +45 -0
- mlflow/entities/input_tag.py +35 -0
- mlflow/entities/lifecycle_stage.py +35 -0
- mlflow/entities/logged_model.py +228 -0
- mlflow/entities/logged_model_input.py +26 -0
- mlflow/entities/logged_model_output.py +32 -0
- mlflow/entities/logged_model_parameter.py +46 -0
- mlflow/entities/logged_model_status.py +74 -0
- mlflow/entities/logged_model_tag.py +33 -0
- mlflow/entities/metric.py +200 -0
- mlflow/entities/model_registry/__init__.py +29 -0
- mlflow/entities/model_registry/_model_registry_entity.py +13 -0
- mlflow/entities/model_registry/model_version.py +243 -0
- mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
- mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
- mlflow/entities/model_registry/model_version_search.py +25 -0
- mlflow/entities/model_registry/model_version_stages.py +25 -0
- mlflow/entities/model_registry/model_version_status.py +35 -0
- mlflow/entities/model_registry/model_version_tag.py +35 -0
- mlflow/entities/model_registry/prompt.py +73 -0
- mlflow/entities/model_registry/prompt_version.py +244 -0
- mlflow/entities/model_registry/registered_model.py +175 -0
- mlflow/entities/model_registry/registered_model_alias.py +35 -0
- mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
- mlflow/entities/model_registry/registered_model_search.py +25 -0
- mlflow/entities/model_registry/registered_model_tag.py +35 -0
- mlflow/entities/multipart_upload.py +74 -0
- mlflow/entities/param.py +49 -0
- mlflow/entities/run.py +97 -0
- mlflow/entities/run_data.py +84 -0
- mlflow/entities/run_info.py +188 -0
- mlflow/entities/run_inputs.py +59 -0
- mlflow/entities/run_outputs.py +43 -0
- mlflow/entities/run_status.py +41 -0
- mlflow/entities/run_tag.py +36 -0
- mlflow/entities/source_type.py +31 -0
- mlflow/entities/span.py +774 -0
- mlflow/entities/span_event.py +96 -0
- mlflow/entities/span_status.py +102 -0
- mlflow/entities/trace.py +317 -0
- mlflow/entities/trace_data.py +71 -0
- mlflow/entities/trace_info.py +220 -0
- mlflow/entities/trace_info_v2.py +162 -0
- mlflow/entities/trace_location.py +173 -0
- mlflow/entities/trace_state.py +39 -0
- mlflow/entities/trace_status.py +68 -0
- mlflow/entities/view_type.py +51 -0
- mlflow/environment_variables.py +866 -0
- mlflow/evaluation/__init__.py +16 -0
- mlflow/evaluation/assessment.py +369 -0
- mlflow/evaluation/evaluation.py +411 -0
- mlflow/evaluation/evaluation_tag.py +61 -0
- mlflow/evaluation/fluent.py +48 -0
- mlflow/evaluation/utils.py +201 -0
- mlflow/exceptions.py +213 -0
- mlflow/experiments.py +140 -0
- mlflow/gemini/__init__.py +81 -0
- mlflow/gemini/autolog.py +186 -0
- mlflow/gemini/chat.py +261 -0
- mlflow/genai/__init__.py +71 -0
- mlflow/genai/datasets/__init__.py +67 -0
- mlflow/genai/datasets/evaluation_dataset.py +131 -0
- mlflow/genai/evaluation/__init__.py +3 -0
- mlflow/genai/evaluation/base.py +411 -0
- mlflow/genai/evaluation/constant.py +23 -0
- mlflow/genai/evaluation/utils.py +244 -0
- mlflow/genai/judges/__init__.py +21 -0
- mlflow/genai/judges/databricks.py +404 -0
- mlflow/genai/label_schemas/__init__.py +153 -0
- mlflow/genai/label_schemas/label_schemas.py +209 -0
- mlflow/genai/labeling/__init__.py +159 -0
- mlflow/genai/labeling/labeling.py +250 -0
- mlflow/genai/optimize/__init__.py +13 -0
- mlflow/genai/optimize/base.py +198 -0
- mlflow/genai/optimize/optimizers/__init__.py +4 -0
- mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
- mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
- mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
- mlflow/genai/optimize/types.py +75 -0
- mlflow/genai/optimize/util.py +30 -0
- mlflow/genai/prompts/__init__.py +206 -0
- mlflow/genai/scheduled_scorers.py +431 -0
- mlflow/genai/scorers/__init__.py +26 -0
- mlflow/genai/scorers/base.py +492 -0
- mlflow/genai/scorers/builtin_scorers.py +765 -0
- mlflow/genai/scorers/scorer_utils.py +138 -0
- mlflow/genai/scorers/validation.py +165 -0
- mlflow/genai/utils/data_validation.py +146 -0
- mlflow/genai/utils/enum_utils.py +23 -0
- mlflow/genai/utils/trace_utils.py +211 -0
- mlflow/groq/__init__.py +42 -0
- mlflow/groq/_groq_autolog.py +74 -0
- mlflow/johnsnowlabs/__init__.py +888 -0
- mlflow/langchain/__init__.py +24 -0
- mlflow/langchain/api_request_parallel_processor.py +330 -0
- mlflow/langchain/autolog.py +147 -0
- mlflow/langchain/chat_agent_langgraph.py +340 -0
- mlflow/langchain/constant.py +1 -0
- mlflow/langchain/constants.py +1 -0
- mlflow/langchain/databricks_dependencies.py +444 -0
- mlflow/langchain/langchain_tracer.py +597 -0
- mlflow/langchain/model.py +919 -0
- mlflow/langchain/output_parsers.py +142 -0
- mlflow/langchain/retriever_chain.py +153 -0
- mlflow/langchain/runnables.py +527 -0
- mlflow/langchain/utils/chat.py +402 -0
- mlflow/langchain/utils/logging.py +671 -0
- mlflow/langchain/utils/serialization.py +36 -0
- mlflow/legacy_databricks_cli/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/provider.py +482 -0
- mlflow/litellm/__init__.py +175 -0
- mlflow/llama_index/__init__.py +22 -0
- mlflow/llama_index/autolog.py +55 -0
- mlflow/llama_index/chat.py +43 -0
- mlflow/llama_index/constant.py +1 -0
- mlflow/llama_index/model.py +577 -0
- mlflow/llama_index/pyfunc_wrapper.py +332 -0
- mlflow/llama_index/serialize_objects.py +188 -0
- mlflow/llama_index/tracer.py +561 -0
- mlflow/metrics/__init__.py +479 -0
- mlflow/metrics/base.py +39 -0
- mlflow/metrics/genai/__init__.py +25 -0
- mlflow/metrics/genai/base.py +101 -0
- mlflow/metrics/genai/genai_metric.py +771 -0
- mlflow/metrics/genai/metric_definitions.py +450 -0
- mlflow/metrics/genai/model_utils.py +371 -0
- mlflow/metrics/genai/prompt_template.py +68 -0
- mlflow/metrics/genai/prompts/__init__.py +0 -0
- mlflow/metrics/genai/prompts/v1.py +422 -0
- mlflow/metrics/genai/utils.py +6 -0
- mlflow/metrics/metric_definitions.py +619 -0
- mlflow/mismatch.py +34 -0
- mlflow/mistral/__init__.py +34 -0
- mlflow/mistral/autolog.py +71 -0
- mlflow/mistral/chat.py +135 -0
- mlflow/ml_package_versions.py +452 -0
- mlflow/models/__init__.py +97 -0
- mlflow/models/auth_policy.py +83 -0
- mlflow/models/cli.py +354 -0
- mlflow/models/container/__init__.py +294 -0
- mlflow/models/container/scoring_server/__init__.py +0 -0
- mlflow/models/container/scoring_server/nginx.conf +39 -0
- mlflow/models/dependencies_schemas.py +287 -0
- mlflow/models/display_utils.py +158 -0
- mlflow/models/docker_utils.py +211 -0
- mlflow/models/evaluation/__init__.py +23 -0
- mlflow/models/evaluation/_shap_patch.py +64 -0
- mlflow/models/evaluation/artifacts.py +194 -0
- mlflow/models/evaluation/base.py +1811 -0
- mlflow/models/evaluation/calibration_curve.py +109 -0
- mlflow/models/evaluation/default_evaluator.py +996 -0
- mlflow/models/evaluation/deprecated.py +23 -0
- mlflow/models/evaluation/evaluator_registry.py +80 -0
- mlflow/models/evaluation/evaluators/classifier.py +704 -0
- mlflow/models/evaluation/evaluators/default.py +233 -0
- mlflow/models/evaluation/evaluators/regressor.py +96 -0
- mlflow/models/evaluation/evaluators/shap.py +296 -0
- mlflow/models/evaluation/lift_curve.py +178 -0
- mlflow/models/evaluation/utils/metric.py +123 -0
- mlflow/models/evaluation/utils/trace.py +179 -0
- mlflow/models/evaluation/validation.py +434 -0
- mlflow/models/flavor_backend.py +93 -0
- mlflow/models/flavor_backend_registry.py +53 -0
- mlflow/models/model.py +1639 -0
- mlflow/models/model_config.py +150 -0
- mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
- mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
- mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
- mlflow/models/python_api.py +369 -0
- mlflow/models/rag_signatures.py +128 -0
- mlflow/models/resources.py +321 -0
- mlflow/models/signature.py +662 -0
- mlflow/models/utils.py +2054 -0
- mlflow/models/wheeled_model.py +280 -0
- mlflow/openai/__init__.py +57 -0
- mlflow/openai/_agent_tracer.py +364 -0
- mlflow/openai/api_request_parallel_processor.py +131 -0
- mlflow/openai/autolog.py +509 -0
- mlflow/openai/constant.py +1 -0
- mlflow/openai/model.py +824 -0
- mlflow/openai/utils/chat_schema.py +367 -0
- mlflow/optuna/__init__.py +3 -0
- mlflow/optuna/storage.py +646 -0
- mlflow/plugins/__init__.py +72 -0
- mlflow/plugins/base.py +358 -0
- mlflow/plugins/builtin/__init__.py +24 -0
- mlflow/plugins/builtin/pytorch_plugin.py +150 -0
- mlflow/plugins/builtin/sklearn_plugin.py +158 -0
- mlflow/plugins/builtin/transformers_plugin.py +187 -0
- mlflow/plugins/cli.py +321 -0
- mlflow/plugins/discovery.py +340 -0
- mlflow/plugins/manager.py +465 -0
- mlflow/plugins/registry.py +316 -0
- mlflow/plugins/templates/framework_plugin_template.py +329 -0
- mlflow/prompt/constants.py +20 -0
- mlflow/prompt/promptlab_model.py +197 -0
- mlflow/prompt/registry_utils.py +248 -0
- mlflow/promptflow/__init__.py +495 -0
- mlflow/protos/__init__.py +0 -0
- mlflow/protos/assessments_pb2.py +174 -0
- mlflow/protos/databricks_artifacts_pb2.py +489 -0
- mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
- mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
- mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
- mlflow/protos/databricks_pb2.py +267 -0
- mlflow/protos/databricks_trace_server_pb2.py +374 -0
- mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
- mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
- mlflow/protos/facet_feature_statistics_pb2.py +296 -0
- mlflow/protos/internal_pb2.py +77 -0
- mlflow/protos/mlflow_artifacts_pb2.py +336 -0
- mlflow/protos/model_registry_pb2.py +1073 -0
- mlflow/protos/scalapb/__init__.py +0 -0
- mlflow/protos/scalapb/scalapb_pb2.py +104 -0
- mlflow/protos/service_pb2.py +2600 -0
- mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
- mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
- mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
- mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
- mlflow/py.typed +0 -0
- mlflow/pydantic_ai/__init__.py +57 -0
- mlflow/pydantic_ai/autolog.py +173 -0
- mlflow/pyfunc/__init__.py +3844 -0
- mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
- mlflow/pyfunc/backend.py +523 -0
- mlflow/pyfunc/context.py +78 -0
- mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
- mlflow/pyfunc/loaders/__init__.py +7 -0
- mlflow/pyfunc/loaders/chat_agent.py +117 -0
- mlflow/pyfunc/loaders/chat_model.py +125 -0
- mlflow/pyfunc/loaders/code_model.py +31 -0
- mlflow/pyfunc/loaders/responses_agent.py +112 -0
- mlflow/pyfunc/mlserver.py +46 -0
- mlflow/pyfunc/model.py +1473 -0
- mlflow/pyfunc/scoring_server/__init__.py +604 -0
- mlflow/pyfunc/scoring_server/app.py +7 -0
- mlflow/pyfunc/scoring_server/client.py +146 -0
- mlflow/pyfunc/spark_model_cache.py +48 -0
- mlflow/pyfunc/stdin_server.py +44 -0
- mlflow/pyfunc/utils/__init__.py +3 -0
- mlflow/pyfunc/utils/data_validation.py +224 -0
- mlflow/pyfunc/utils/environment.py +22 -0
- mlflow/pyfunc/utils/input_converter.py +47 -0
- mlflow/pyfunc/utils/serving_data_parser.py +11 -0
- mlflow/pytorch/__init__.py +1171 -0
- mlflow/pytorch/_lightning_autolog.py +580 -0
- mlflow/pytorch/_pytorch_autolog.py +50 -0
- mlflow/pytorch/pickle_module.py +35 -0
- mlflow/rfunc/__init__.py +42 -0
- mlflow/rfunc/backend.py +134 -0
- mlflow/runs.py +89 -0
- mlflow/server/__init__.py +302 -0
- mlflow/server/auth/__init__.py +1224 -0
- mlflow/server/auth/__main__.py +4 -0
- mlflow/server/auth/basic_auth.ini +6 -0
- mlflow/server/auth/cli.py +11 -0
- mlflow/server/auth/client.py +537 -0
- mlflow/server/auth/config.py +34 -0
- mlflow/server/auth/db/__init__.py +0 -0
- mlflow/server/auth/db/cli.py +18 -0
- mlflow/server/auth/db/migrations/__init__.py +0 -0
- mlflow/server/auth/db/migrations/alembic.ini +110 -0
- mlflow/server/auth/db/migrations/env.py +76 -0
- mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
- mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
- mlflow/server/auth/db/models.py +67 -0
- mlflow/server/auth/db/utils.py +37 -0
- mlflow/server/auth/entities.py +165 -0
- mlflow/server/auth/logo.py +14 -0
- mlflow/server/auth/permissions.py +65 -0
- mlflow/server/auth/routes.py +18 -0
- mlflow/server/auth/sqlalchemy_store.py +263 -0
- mlflow/server/graphql/__init__.py +0 -0
- mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
- mlflow/server/graphql/graphql_custom_scalars.py +24 -0
- mlflow/server/graphql/graphql_errors.py +15 -0
- mlflow/server/graphql/graphql_no_batching.py +89 -0
- mlflow/server/graphql/graphql_schema_extensions.py +74 -0
- mlflow/server/handlers.py +3217 -0
- mlflow/server/prometheus_exporter.py +17 -0
- mlflow/server/validation.py +30 -0
- mlflow/shap/__init__.py +691 -0
- mlflow/sklearn/__init__.py +1994 -0
- mlflow/sklearn/utils.py +1041 -0
- mlflow/smolagents/__init__.py +66 -0
- mlflow/smolagents/autolog.py +139 -0
- mlflow/smolagents/chat.py +29 -0
- mlflow/store/__init__.py +10 -0
- mlflow/store/_unity_catalog/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/constants.py +2 -0
- mlflow/store/_unity_catalog/registry/__init__.py +6 -0
- mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
- mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
- mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
- mlflow/store/_unity_catalog/registry/utils.py +121 -0
- mlflow/store/artifact/__init__.py +0 -0
- mlflow/store/artifact/artifact_repo.py +472 -0
- mlflow/store/artifact/artifact_repository_registry.py +154 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
- mlflow/store/artifact/cli.py +141 -0
- mlflow/store/artifact/cloud_artifact_repo.py +332 -0
- mlflow/store/artifact/databricks_artifact_repo.py +729 -0
- mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
- mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
- mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
- mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
- mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
- mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
- mlflow/store/artifact/ftp_artifact_repo.py +132 -0
- mlflow/store/artifact/gcs_artifact_repo.py +296 -0
- mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
- mlflow/store/artifact/http_artifact_repo.py +218 -0
- mlflow/store/artifact/local_artifact_repo.py +142 -0
- mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
- mlflow/store/artifact/models_artifact_repo.py +259 -0
- mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
- mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
- mlflow/store/artifact/r2_artifact_repo.py +70 -0
- mlflow/store/artifact/runs_artifact_repo.py +265 -0
- mlflow/store/artifact/s3_artifact_repo.py +330 -0
- mlflow/store/artifact/sftp_artifact_repo.py +141 -0
- mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
- mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
- mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
- mlflow/store/artifact/utils/__init__.py +0 -0
- mlflow/store/artifact/utils/models.py +148 -0
- mlflow/store/db/__init__.py +0 -0
- mlflow/store/db/base_sql_model.py +3 -0
- mlflow/store/db/db_types.py +10 -0
- mlflow/store/db/utils.py +314 -0
- mlflow/store/db_migrations/__init__.py +0 -0
- mlflow/store/db_migrations/alembic.ini +74 -0
- mlflow/store/db_migrations/env.py +84 -0
- mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
- mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
- mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
- mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
- mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
- mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
- mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
- mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
- mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
- mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
- mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
- mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
- mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
- mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
- mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
- mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
- mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
- mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
- mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
- mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
- mlflow/store/db_migrations/versions/__init__.py +0 -0
- mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
- mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
- mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
- mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
- mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
- mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
- mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
- mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
- mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
- mlflow/store/entities/__init__.py +3 -0
- mlflow/store/entities/paged_list.py +18 -0
- mlflow/store/model_registry/__init__.py +10 -0
- mlflow/store/model_registry/abstract_store.py +1081 -0
- mlflow/store/model_registry/base_rest_store.py +44 -0
- mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
- mlflow/store/model_registry/dbmodels/__init__.py +0 -0
- mlflow/store/model_registry/dbmodels/models.py +206 -0
- mlflow/store/model_registry/file_store.py +1091 -0
- mlflow/store/model_registry/rest_store.py +481 -0
- mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
- mlflow/store/tracking/__init__.py +23 -0
- mlflow/store/tracking/abstract_store.py +816 -0
- mlflow/store/tracking/dbmodels/__init__.py +0 -0
- mlflow/store/tracking/dbmodels/initial_models.py +243 -0
- mlflow/store/tracking/dbmodels/models.py +1073 -0
- mlflow/store/tracking/file_store.py +2438 -0
- mlflow/store/tracking/postgres_managed_identity.py +146 -0
- mlflow/store/tracking/rest_store.py +1131 -0
- mlflow/store/tracking/sqlalchemy_store.py +2785 -0
- mlflow/system_metrics/__init__.py +61 -0
- mlflow/system_metrics/metrics/__init__.py +0 -0
- mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
- mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
- mlflow/system_metrics/metrics/disk_monitor.py +21 -0
- mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
- mlflow/system_metrics/metrics/network_monitor.py +34 -0
- mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
- mlflow/system_metrics/system_metrics_monitor.py +198 -0
- mlflow/tracing/__init__.py +16 -0
- mlflow/tracing/assessment.py +356 -0
- mlflow/tracing/client.py +531 -0
- mlflow/tracing/config.py +125 -0
- mlflow/tracing/constant.py +105 -0
- mlflow/tracing/destination.py +81 -0
- mlflow/tracing/display/__init__.py +40 -0
- mlflow/tracing/display/display_handler.py +196 -0
- mlflow/tracing/export/async_export_queue.py +186 -0
- mlflow/tracing/export/inference_table.py +138 -0
- mlflow/tracing/export/mlflow_v3.py +137 -0
- mlflow/tracing/export/utils.py +70 -0
- mlflow/tracing/fluent.py +1417 -0
- mlflow/tracing/processor/base_mlflow.py +199 -0
- mlflow/tracing/processor/inference_table.py +175 -0
- mlflow/tracing/processor/mlflow_v3.py +47 -0
- mlflow/tracing/processor/otel.py +73 -0
- mlflow/tracing/provider.py +487 -0
- mlflow/tracing/trace_manager.py +200 -0
- mlflow/tracing/utils/__init__.py +616 -0
- mlflow/tracing/utils/artifact_utils.py +28 -0
- mlflow/tracing/utils/copy.py +55 -0
- mlflow/tracing/utils/environment.py +55 -0
- mlflow/tracing/utils/exception.py +21 -0
- mlflow/tracing/utils/once.py +35 -0
- mlflow/tracing/utils/otlp.py +63 -0
- mlflow/tracing/utils/processor.py +54 -0
- mlflow/tracing/utils/search.py +292 -0
- mlflow/tracing/utils/timeout.py +250 -0
- mlflow/tracing/utils/token.py +19 -0
- mlflow/tracing/utils/truncation.py +124 -0
- mlflow/tracing/utils/warning.py +76 -0
- mlflow/tracking/__init__.py +39 -0
- mlflow/tracking/_model_registry/__init__.py +1 -0
- mlflow/tracking/_model_registry/client.py +764 -0
- mlflow/tracking/_model_registry/fluent.py +853 -0
- mlflow/tracking/_model_registry/registry.py +67 -0
- mlflow/tracking/_model_registry/utils.py +251 -0
- mlflow/tracking/_tracking_service/__init__.py +0 -0
- mlflow/tracking/_tracking_service/client.py +883 -0
- mlflow/tracking/_tracking_service/registry.py +56 -0
- mlflow/tracking/_tracking_service/utils.py +275 -0
- mlflow/tracking/artifact_utils.py +179 -0
- mlflow/tracking/client.py +5900 -0
- mlflow/tracking/context/__init__.py +0 -0
- mlflow/tracking/context/abstract_context.py +35 -0
- mlflow/tracking/context/databricks_cluster_context.py +15 -0
- mlflow/tracking/context/databricks_command_context.py +15 -0
- mlflow/tracking/context/databricks_job_context.py +49 -0
- mlflow/tracking/context/databricks_notebook_context.py +41 -0
- mlflow/tracking/context/databricks_repo_context.py +43 -0
- mlflow/tracking/context/default_context.py +51 -0
- mlflow/tracking/context/git_context.py +32 -0
- mlflow/tracking/context/registry.py +98 -0
- mlflow/tracking/context/system_environment_context.py +15 -0
- mlflow/tracking/default_experiment/__init__.py +1 -0
- mlflow/tracking/default_experiment/abstract_context.py +43 -0
- mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
- mlflow/tracking/default_experiment/registry.py +75 -0
- mlflow/tracking/fluent.py +3595 -0
- mlflow/tracking/metric_value_conversion_utils.py +93 -0
- mlflow/tracking/multimedia.py +206 -0
- mlflow/tracking/registry.py +86 -0
- mlflow/tracking/request_auth/__init__.py +0 -0
- mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
- mlflow/tracking/request_auth/registry.py +60 -0
- mlflow/tracking/request_header/__init__.py +0 -0
- mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
- mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
- mlflow/tracking/request_header/default_request_header_provider.py +17 -0
- mlflow/tracking/request_header/registry.py +79 -0
- mlflow/transformers/__init__.py +2982 -0
- mlflow/transformers/flavor_config.py +258 -0
- mlflow/transformers/hub_utils.py +83 -0
- mlflow/transformers/llm_inference_utils.py +468 -0
- mlflow/transformers/model_io.py +301 -0
- mlflow/transformers/peft.py +51 -0
- mlflow/transformers/signature.py +183 -0
- mlflow/transformers/torch_utils.py +55 -0
- mlflow/types/__init__.py +21 -0
- mlflow/types/agent.py +270 -0
- mlflow/types/chat.py +240 -0
- mlflow/types/llm.py +935 -0
- mlflow/types/responses.py +139 -0
- mlflow/types/responses_helpers.py +416 -0
- mlflow/types/schema.py +1505 -0
- mlflow/types/type_hints.py +647 -0
- mlflow/types/utils.py +753 -0
- mlflow/utils/__init__.py +283 -0
- mlflow/utils/_capture_modules.py +256 -0
- mlflow/utils/_capture_transformers_modules.py +75 -0
- mlflow/utils/_spark_utils.py +201 -0
- mlflow/utils/_unity_catalog_oss_utils.py +97 -0
- mlflow/utils/_unity_catalog_utils.py +479 -0
- mlflow/utils/annotations.py +218 -0
- mlflow/utils/arguments_utils.py +16 -0
- mlflow/utils/async_logging/__init__.py +1 -0
- mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
- mlflow/utils/async_logging/async_logging_queue.py +366 -0
- mlflow/utils/async_logging/run_artifact.py +38 -0
- mlflow/utils/async_logging/run_batch.py +58 -0
- mlflow/utils/async_logging/run_operations.py +49 -0
- mlflow/utils/autologging_utils/__init__.py +737 -0
- mlflow/utils/autologging_utils/client.py +432 -0
- mlflow/utils/autologging_utils/config.py +33 -0
- mlflow/utils/autologging_utils/events.py +294 -0
- mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
- mlflow/utils/autologging_utils/metrics_queue.py +71 -0
- mlflow/utils/autologging_utils/safety.py +1104 -0
- mlflow/utils/autologging_utils/versioning.py +95 -0
- mlflow/utils/checkpoint_utils.py +206 -0
- mlflow/utils/class_utils.py +6 -0
- mlflow/utils/cli_args.py +257 -0
- mlflow/utils/conda.py +354 -0
- mlflow/utils/credentials.py +231 -0
- mlflow/utils/data_utils.py +17 -0
- mlflow/utils/databricks_utils.py +1436 -0
- mlflow/utils/docstring_utils.py +477 -0
- mlflow/utils/doctor.py +133 -0
- mlflow/utils/download_cloud_file_chunk.py +43 -0
- mlflow/utils/env_manager.py +16 -0
- mlflow/utils/env_pack.py +131 -0
- mlflow/utils/environment.py +1009 -0
- mlflow/utils/exception_utils.py +14 -0
- mlflow/utils/file_utils.py +978 -0
- mlflow/utils/git_utils.py +77 -0
- mlflow/utils/gorilla.py +797 -0
- mlflow/utils/import_hooks/__init__.py +363 -0
- mlflow/utils/lazy_load.py +51 -0
- mlflow/utils/logging_utils.py +168 -0
- mlflow/utils/mime_type_utils.py +58 -0
- mlflow/utils/mlflow_tags.py +103 -0
- mlflow/utils/model_utils.py +486 -0
- mlflow/utils/name_utils.py +346 -0
- mlflow/utils/nfs_on_spark.py +62 -0
- mlflow/utils/openai_utils.py +164 -0
- mlflow/utils/os.py +12 -0
- mlflow/utils/oss_registry_utils.py +29 -0
- mlflow/utils/plugins.py +17 -0
- mlflow/utils/process.py +182 -0
- mlflow/utils/promptlab_utils.py +146 -0
- mlflow/utils/proto_json_utils.py +743 -0
- mlflow/utils/pydantic_utils.py +54 -0
- mlflow/utils/request_utils.py +279 -0
- mlflow/utils/requirements_utils.py +704 -0
- mlflow/utils/rest_utils.py +673 -0
- mlflow/utils/search_logged_model_utils.py +127 -0
- mlflow/utils/search_utils.py +2111 -0
- mlflow/utils/secure_loading.py +221 -0
- mlflow/utils/security_validation.py +384 -0
- mlflow/utils/server_cli_utils.py +61 -0
- mlflow/utils/spark_utils.py +15 -0
- mlflow/utils/string_utils.py +138 -0
- mlflow/utils/thread_utils.py +63 -0
- mlflow/utils/time.py +54 -0
- mlflow/utils/timeout.py +42 -0
- mlflow/utils/uri.py +572 -0
- mlflow/utils/validation.py +662 -0
- mlflow/utils/virtualenv.py +458 -0
- mlflow/utils/warnings_utils.py +25 -0
- mlflow/utils/yaml_utils.py +179 -0
- mlflow/version.py +24 -0
@@ -0,0 +1,2111 @@
|
|
1
|
+
import ast
|
2
|
+
import base64
|
3
|
+
import json
|
4
|
+
import math
|
5
|
+
import operator
|
6
|
+
import re
|
7
|
+
import shlex
|
8
|
+
from dataclasses import asdict, dataclass
|
9
|
+
from typing import Any, Optional
|
10
|
+
|
11
|
+
import sqlparse
|
12
|
+
from packaging.version import Version
|
13
|
+
from sqlparse.sql import (
|
14
|
+
Comparison,
|
15
|
+
Identifier,
|
16
|
+
Parenthesis,
|
17
|
+
Statement,
|
18
|
+
Token,
|
19
|
+
TokenList,
|
20
|
+
)
|
21
|
+
from sqlparse.tokens import Token as TokenType
|
22
|
+
|
23
|
+
from mlflow.entities import LoggedModel, Metric, RunInfo
|
24
|
+
from mlflow.entities.model_registry.model_version_stages import STAGE_DELETED_INTERNAL
|
25
|
+
from mlflow.entities.model_registry.prompt_version import IS_PROMPT_TAG_KEY
|
26
|
+
from mlflow.exceptions import MlflowException
|
27
|
+
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
|
28
|
+
from mlflow.store.db.db_types import MSSQL, MYSQL, POSTGRES, SQLITE
|
29
|
+
from mlflow.tracing.constant import TraceMetadataKey, TraceTagKey
|
30
|
+
from mlflow.utils.mlflow_tags import (
|
31
|
+
MLFLOW_DATASET_CONTEXT,
|
32
|
+
)
|
33
|
+
|
34
|
+
|
35
|
+
def _convert_like_pattern_to_regex(pattern, flags=0):
|
36
|
+
if not pattern.startswith("%"):
|
37
|
+
pattern = "^" + pattern
|
38
|
+
if not pattern.endswith("%"):
|
39
|
+
pattern = pattern + "$"
|
40
|
+
return re.compile(pattern.replace("_", ".").replace("%", ".*"), flags)
|
41
|
+
|
42
|
+
|
43
|
+
def _like(string, pattern):
|
44
|
+
return _convert_like_pattern_to_regex(pattern).match(string) is not None
|
45
|
+
|
46
|
+
|
47
|
+
def _ilike(string, pattern):
|
48
|
+
return _convert_like_pattern_to_regex(pattern, flags=re.IGNORECASE).match(string) is not None
|
49
|
+
|
50
|
+
|
51
|
+
def _join_in_comparison_tokens(tokens, search_traces=False):
|
52
|
+
"""
|
53
|
+
Find a sequence of tokens that matches the pattern of an IN comparison or a NOT IN comparison,
|
54
|
+
join the tokens into a single Comparison token. Otherwise, return the original list of tokens.
|
55
|
+
"""
|
56
|
+
if Version(sqlparse.__version__) < Version("0.4.4"):
|
57
|
+
# In sqlparse < 0.4.4, IN is treated as a comparison, we don't need to join tokens
|
58
|
+
return tokens
|
59
|
+
|
60
|
+
non_whitespace_tokens = [t for t in tokens if not t.is_whitespace]
|
61
|
+
joined_tokens = []
|
62
|
+
num_tokens = len(non_whitespace_tokens)
|
63
|
+
iterator = enumerate(non_whitespace_tokens)
|
64
|
+
while elem := next(iterator, None):
|
65
|
+
index, first = elem
|
66
|
+
# We need at least 3 tokens to form an IN comparison or a NOT IN comparison
|
67
|
+
if num_tokens - index < 3:
|
68
|
+
joined_tokens.extend(non_whitespace_tokens[index:])
|
69
|
+
break
|
70
|
+
|
71
|
+
if search_traces:
|
72
|
+
# timestamp
|
73
|
+
if first.match(ttype=TokenType.Name.Builtin, values=["timestamp", "timestamp_ms"]):
|
74
|
+
(_, second) = next(iterator, (None, None))
|
75
|
+
(_, third) = next(iterator, (None, None))
|
76
|
+
if any(x is None for x in [second, third]):
|
77
|
+
raise MlflowException(
|
78
|
+
f"Invalid comparison clause with token `{first}, {second}, {third}`, "
|
79
|
+
"expected 3 tokens",
|
80
|
+
error_code=INVALID_PARAMETER_VALUE,
|
81
|
+
)
|
82
|
+
if (
|
83
|
+
second.match(
|
84
|
+
ttype=TokenType.Operator.Comparison,
|
85
|
+
values=SearchTraceUtils.VALID_NUMERIC_ATTRIBUTE_COMPARATORS,
|
86
|
+
)
|
87
|
+
and third.ttype == TokenType.Literal.Number.Integer
|
88
|
+
):
|
89
|
+
joined_tokens.append(Comparison(TokenList([first, second, third])))
|
90
|
+
continue
|
91
|
+
else:
|
92
|
+
joined_tokens.extend([first, second, third])
|
93
|
+
|
94
|
+
# Wait until we encounter an identifier token
|
95
|
+
if not isinstance(first, Identifier):
|
96
|
+
joined_tokens.append(first)
|
97
|
+
continue
|
98
|
+
|
99
|
+
(_, second) = next(iterator)
|
100
|
+
(_, third) = next(iterator)
|
101
|
+
|
102
|
+
# IN
|
103
|
+
if (
|
104
|
+
isinstance(first, Identifier)
|
105
|
+
and second.match(ttype=TokenType.Keyword, values=["IN"])
|
106
|
+
and isinstance(third, Parenthesis)
|
107
|
+
):
|
108
|
+
joined_tokens.append(Comparison(TokenList([first, second, third])))
|
109
|
+
continue
|
110
|
+
|
111
|
+
(_, fourth) = next(iterator, (None, None))
|
112
|
+
if fourth is None:
|
113
|
+
joined_tokens.extend([first, second, third])
|
114
|
+
break
|
115
|
+
|
116
|
+
# NOT IN
|
117
|
+
if (
|
118
|
+
isinstance(first, Identifier)
|
119
|
+
and second.match(ttype=TokenType.Keyword, values=["NOT"])
|
120
|
+
and third.match(ttype=TokenType.Keyword, values=["IN"])
|
121
|
+
and isinstance(fourth, Parenthesis)
|
122
|
+
):
|
123
|
+
joined_tokens.append(
|
124
|
+
Comparison(TokenList([first, Token(TokenType.Keyword, "NOT IN"), fourth]))
|
125
|
+
)
|
126
|
+
continue
|
127
|
+
|
128
|
+
joined_tokens.extend([first, second, third, fourth])
|
129
|
+
|
130
|
+
return joined_tokens
|
131
|
+
|
132
|
+
|
133
|
+
class SearchUtils:
|
134
|
+
LIKE_OPERATOR = "LIKE"
|
135
|
+
ILIKE_OPERATOR = "ILIKE"
|
136
|
+
ASC_OPERATOR = "asc"
|
137
|
+
DESC_OPERATOR = "desc"
|
138
|
+
VALID_ORDER_BY_TAGS = [ASC_OPERATOR, DESC_OPERATOR]
|
139
|
+
VALID_METRIC_COMPARATORS = {">", ">=", "!=", "=", "<", "<="}
|
140
|
+
VALID_PARAM_COMPARATORS = {"!=", "=", LIKE_OPERATOR, ILIKE_OPERATOR}
|
141
|
+
VALID_TAG_COMPARATORS = {"!=", "=", LIKE_OPERATOR, ILIKE_OPERATOR}
|
142
|
+
VALID_STRING_ATTRIBUTE_COMPARATORS = {"!=", "=", LIKE_OPERATOR, ILIKE_OPERATOR, "IN", "NOT IN"}
|
143
|
+
VALID_NUMERIC_ATTRIBUTE_COMPARATORS = VALID_METRIC_COMPARATORS
|
144
|
+
VALID_DATASET_COMPARATORS = {"!=", "=", LIKE_OPERATOR, ILIKE_OPERATOR, "IN", "NOT IN"}
|
145
|
+
_BUILTIN_NUMERIC_ATTRIBUTES = {"start_time", "end_time"}
|
146
|
+
_ALTERNATE_NUMERIC_ATTRIBUTES = {"created", "Created"}
|
147
|
+
_ALTERNATE_STRING_ATTRIBUTES = {"run name", "Run name", "Run Name"}
|
148
|
+
NUMERIC_ATTRIBUTES = set(
|
149
|
+
list(_BUILTIN_NUMERIC_ATTRIBUTES) + list(_ALTERNATE_NUMERIC_ATTRIBUTES)
|
150
|
+
)
|
151
|
+
DATASET_ATTRIBUTES = {"name", "digest", "context"}
|
152
|
+
VALID_SEARCH_ATTRIBUTE_KEYS = set(
|
153
|
+
RunInfo.get_searchable_attributes()
|
154
|
+
+ list(_ALTERNATE_NUMERIC_ATTRIBUTES)
|
155
|
+
+ list(_ALTERNATE_STRING_ATTRIBUTES)
|
156
|
+
)
|
157
|
+
VALID_ORDER_BY_ATTRIBUTE_KEYS = set(
|
158
|
+
RunInfo.get_orderable_attributes() + list(_ALTERNATE_NUMERIC_ATTRIBUTES)
|
159
|
+
)
|
160
|
+
_METRIC_IDENTIFIER = "metric"
|
161
|
+
_ALTERNATE_METRIC_IDENTIFIERS = {"metrics"}
|
162
|
+
_PARAM_IDENTIFIER = "parameter"
|
163
|
+
_ALTERNATE_PARAM_IDENTIFIERS = {"parameters", "param", "params"}
|
164
|
+
_TAG_IDENTIFIER = "tag"
|
165
|
+
_ALTERNATE_TAG_IDENTIFIERS = {"tags"}
|
166
|
+
_ATTRIBUTE_IDENTIFIER = "attribute"
|
167
|
+
_ALTERNATE_ATTRIBUTE_IDENTIFIERS = {"attr", "attributes", "run"}
|
168
|
+
_DATASET_IDENTIFIER = "dataset"
|
169
|
+
_ALTERNATE_DATASET_IDENTIFIERS = {"datasets"}
|
170
|
+
_IDENTIFIERS = [
|
171
|
+
_METRIC_IDENTIFIER,
|
172
|
+
_PARAM_IDENTIFIER,
|
173
|
+
_TAG_IDENTIFIER,
|
174
|
+
_ATTRIBUTE_IDENTIFIER,
|
175
|
+
_DATASET_IDENTIFIER,
|
176
|
+
]
|
177
|
+
_VALID_IDENTIFIERS = set(
|
178
|
+
_IDENTIFIERS
|
179
|
+
+ list(_ALTERNATE_METRIC_IDENTIFIERS)
|
180
|
+
+ list(_ALTERNATE_PARAM_IDENTIFIERS)
|
181
|
+
+ list(_ALTERNATE_TAG_IDENTIFIERS)
|
182
|
+
+ list(_ALTERNATE_ATTRIBUTE_IDENTIFIERS)
|
183
|
+
+ list(_ALTERNATE_DATASET_IDENTIFIERS)
|
184
|
+
)
|
185
|
+
STRING_VALUE_TYPES = {TokenType.Literal.String.Single}
|
186
|
+
DELIMITER_VALUE_TYPES = {TokenType.Punctuation}
|
187
|
+
WHITESPACE_VALUE_TYPE = TokenType.Text.Whitespace
|
188
|
+
NUMERIC_VALUE_TYPES = {TokenType.Literal.Number.Integer, TokenType.Literal.Number.Float}
|
189
|
+
# Registered Models Constants
|
190
|
+
ORDER_BY_KEY_TIMESTAMP = "timestamp"
|
191
|
+
ORDER_BY_KEY_LAST_UPDATED_TIMESTAMP = "last_updated_timestamp"
|
192
|
+
ORDER_BY_KEY_MODEL_NAME = "name"
|
193
|
+
VALID_ORDER_BY_KEYS_REGISTERED_MODELS = {
|
194
|
+
ORDER_BY_KEY_TIMESTAMP,
|
195
|
+
ORDER_BY_KEY_LAST_UPDATED_TIMESTAMP,
|
196
|
+
ORDER_BY_KEY_MODEL_NAME,
|
197
|
+
}
|
198
|
+
VALID_TIMESTAMP_ORDER_BY_KEYS = {ORDER_BY_KEY_TIMESTAMP, ORDER_BY_KEY_LAST_UPDATED_TIMESTAMP}
|
199
|
+
# We encourage users to use timestamp for order-by
|
200
|
+
RECOMMENDED_ORDER_BY_KEYS_REGISTERED_MODELS = {ORDER_BY_KEY_MODEL_NAME, ORDER_BY_KEY_TIMESTAMP}
|
201
|
+
|
202
|
+
@staticmethod
|
203
|
+
def get_comparison_func(comparator):
|
204
|
+
return {
|
205
|
+
">": operator.gt,
|
206
|
+
">=": operator.ge,
|
207
|
+
"=": operator.eq,
|
208
|
+
"!=": operator.ne,
|
209
|
+
"<=": operator.le,
|
210
|
+
"<": operator.lt,
|
211
|
+
"LIKE": _like,
|
212
|
+
"ILIKE": _ilike,
|
213
|
+
"IN": lambda x, y: x in y,
|
214
|
+
"NOT IN": lambda x, y: x not in y,
|
215
|
+
}[comparator]
|
216
|
+
|
217
|
+
@staticmethod
|
218
|
+
def get_sql_comparison_func(comparator, dialect):
|
219
|
+
import sqlalchemy as sa
|
220
|
+
|
221
|
+
def comparison_func(column, value):
|
222
|
+
if comparator == "LIKE":
|
223
|
+
return column.like(value)
|
224
|
+
elif comparator == "ILIKE":
|
225
|
+
return column.ilike(value)
|
226
|
+
elif comparator == "IN":
|
227
|
+
return column.in_(value)
|
228
|
+
elif comparator == "NOT IN":
|
229
|
+
return ~column.in_(value)
|
230
|
+
return SearchUtils.get_comparison_func(comparator)(column, value)
|
231
|
+
|
232
|
+
def mssql_comparison_func(column, value):
|
233
|
+
if not isinstance(column.type, sa.types.String):
|
234
|
+
return comparison_func(column, value)
|
235
|
+
|
236
|
+
collated = column.collate("Japanese_Bushu_Kakusu_100_CS_AS_KS_WS")
|
237
|
+
return comparison_func(collated, value)
|
238
|
+
|
239
|
+
def mysql_comparison_func(column, value):
|
240
|
+
if not isinstance(column.type, sa.types.String):
|
241
|
+
return comparison_func(column, value)
|
242
|
+
|
243
|
+
# MySQL is case insensitive by default, so we need to use the binary operator to
|
244
|
+
# perform case sensitive comparisons.
|
245
|
+
templates = {
|
246
|
+
# Use non-binary ahead of binary comparison for runtime performance
|
247
|
+
"=": "({column} = :value AND BINARY {column} = :value)",
|
248
|
+
"!=": "({column} != :value OR BINARY {column} != :value)",
|
249
|
+
"LIKE": "({column} LIKE :value AND BINARY {column} LIKE :value)",
|
250
|
+
}
|
251
|
+
if comparator in templates:
|
252
|
+
column = f"{column.class_.__tablename__}.{column.key}"
|
253
|
+
return sa.text(templates[comparator].format(column=column)).bindparams(
|
254
|
+
sa.bindparam("value", value=value, unique=True)
|
255
|
+
)
|
256
|
+
|
257
|
+
return comparison_func(column, value)
|
258
|
+
|
259
|
+
return {
|
260
|
+
POSTGRES: comparison_func,
|
261
|
+
SQLITE: comparison_func,
|
262
|
+
MSSQL: mssql_comparison_func,
|
263
|
+
MYSQL: mysql_comparison_func,
|
264
|
+
}[dialect]
|
265
|
+
|
266
|
+
@staticmethod
|
267
|
+
def translate_key_alias(key):
|
268
|
+
if key in ["created", "Created"]:
|
269
|
+
return "start_time"
|
270
|
+
if key in ["run name", "Run name", "Run Name"]:
|
271
|
+
return "run_name"
|
272
|
+
return key
|
273
|
+
|
274
|
+
@classmethod
|
275
|
+
def _trim_ends(cls, string_value):
|
276
|
+
return string_value[1:-1]
|
277
|
+
|
278
|
+
@classmethod
|
279
|
+
def _is_quoted(cls, value, pattern):
|
280
|
+
return len(value) >= 2 and value.startswith(pattern) and value.endswith(pattern)
|
281
|
+
|
282
|
+
@classmethod
|
283
|
+
def _trim_backticks(cls, entity_type):
|
284
|
+
"""Remove backticks from identifier like `param`, if they exist."""
|
285
|
+
if cls._is_quoted(entity_type, "`"):
|
286
|
+
return cls._trim_ends(entity_type)
|
287
|
+
return entity_type
|
288
|
+
|
289
|
+
@classmethod
|
290
|
+
def _strip_quotes(cls, value, expect_quoted_value=False):
|
291
|
+
"""
|
292
|
+
Remove quotes for input string.
|
293
|
+
Values of type strings are expected to have quotes.
|
294
|
+
Keys containing special characters are also expected to be enclose in quotes.
|
295
|
+
"""
|
296
|
+
if cls._is_quoted(value, "'") or cls._is_quoted(value, '"'):
|
297
|
+
return cls._trim_ends(value)
|
298
|
+
elif expect_quoted_value:
|
299
|
+
raise MlflowException(
|
300
|
+
"Parameter value is either not quoted or unidentified quote "
|
301
|
+
f"types used for string value {value}. Use either single or double "
|
302
|
+
"quotes.",
|
303
|
+
error_code=INVALID_PARAMETER_VALUE,
|
304
|
+
)
|
305
|
+
else:
|
306
|
+
return value
|
307
|
+
|
308
|
+
@classmethod
|
309
|
+
def _valid_entity_type(cls, entity_type):
|
310
|
+
entity_type = cls._trim_backticks(entity_type)
|
311
|
+
if entity_type not in cls._VALID_IDENTIFIERS:
|
312
|
+
raise MlflowException(
|
313
|
+
f"Invalid entity type '{entity_type}'. Valid values are {cls._IDENTIFIERS}",
|
314
|
+
error_code=INVALID_PARAMETER_VALUE,
|
315
|
+
)
|
316
|
+
|
317
|
+
if entity_type in cls._ALTERNATE_PARAM_IDENTIFIERS:
|
318
|
+
return cls._PARAM_IDENTIFIER
|
319
|
+
elif entity_type in cls._ALTERNATE_METRIC_IDENTIFIERS:
|
320
|
+
return cls._METRIC_IDENTIFIER
|
321
|
+
elif entity_type in cls._ALTERNATE_TAG_IDENTIFIERS:
|
322
|
+
return cls._TAG_IDENTIFIER
|
323
|
+
elif entity_type in cls._ALTERNATE_ATTRIBUTE_IDENTIFIERS:
|
324
|
+
return cls._ATTRIBUTE_IDENTIFIER
|
325
|
+
elif entity_type in cls._ALTERNATE_DATASET_IDENTIFIERS:
|
326
|
+
return cls._DATASET_IDENTIFIER
|
327
|
+
else:
|
328
|
+
# one of ("metric", "parameter", "tag", or "attribute") since it a valid type
|
329
|
+
return entity_type
|
330
|
+
|
331
|
+
@classmethod
|
332
|
+
def _get_identifier(cls, identifier, valid_attributes):
|
333
|
+
try:
|
334
|
+
tokens = identifier.split(".", 1)
|
335
|
+
if len(tokens) == 1:
|
336
|
+
key = tokens[0]
|
337
|
+
entity_type = cls._ATTRIBUTE_IDENTIFIER
|
338
|
+
else:
|
339
|
+
entity_type, key = tokens
|
340
|
+
except ValueError:
|
341
|
+
raise MlflowException(
|
342
|
+
f"Invalid identifier {identifier!r}. Columns should be specified as "
|
343
|
+
"'attribute.<key>', 'metric.<key>', 'tag.<key>', 'dataset.<key>', or "
|
344
|
+
"'param.'.",
|
345
|
+
error_code=INVALID_PARAMETER_VALUE,
|
346
|
+
)
|
347
|
+
identifier = cls._valid_entity_type(entity_type)
|
348
|
+
key = cls._trim_backticks(cls._strip_quotes(key))
|
349
|
+
if identifier == cls._ATTRIBUTE_IDENTIFIER and key not in valid_attributes:
|
350
|
+
raise MlflowException.invalid_parameter_value(
|
351
|
+
f"Invalid attribute key '{key}' specified. Valid keys are '{valid_attributes}'"
|
352
|
+
)
|
353
|
+
elif identifier == cls._DATASET_IDENTIFIER and key not in cls.DATASET_ATTRIBUTES:
|
354
|
+
raise MlflowException.invalid_parameter_value(
|
355
|
+
f"Invalid dataset key '{key}' specified. Valid keys are '{cls.DATASET_ATTRIBUTES}'"
|
356
|
+
)
|
357
|
+
return {"type": identifier, "key": key}
|
358
|
+
|
359
|
+
@classmethod
|
360
|
+
def validate_list_supported(cls, key: str) -> None:
|
361
|
+
if key != "run_id":
|
362
|
+
raise MlflowException(
|
363
|
+
"Only the 'run_id' attribute supports comparison with a list of quoted "
|
364
|
+
"string values.",
|
365
|
+
error_code=INVALID_PARAMETER_VALUE,
|
366
|
+
)
|
367
|
+
|
368
|
+
@classmethod
|
369
|
+
def _get_value(cls, identifier_type, key, token):
|
370
|
+
if identifier_type == cls._METRIC_IDENTIFIER:
|
371
|
+
if token.ttype not in cls.NUMERIC_VALUE_TYPES:
|
372
|
+
raise MlflowException(
|
373
|
+
f"Expected numeric value type for metric. Found {token.value}",
|
374
|
+
error_code=INVALID_PARAMETER_VALUE,
|
375
|
+
)
|
376
|
+
return token.value
|
377
|
+
elif identifier_type == cls._PARAM_IDENTIFIER or identifier_type == cls._TAG_IDENTIFIER:
|
378
|
+
if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier):
|
379
|
+
return cls._strip_quotes(token.value, expect_quoted_value=True)
|
380
|
+
raise MlflowException(
|
381
|
+
"Expected a quoted string value for "
|
382
|
+
f"{identifier_type} (e.g. 'my-value'). Got value "
|
383
|
+
f"{token.value}",
|
384
|
+
error_code=INVALID_PARAMETER_VALUE,
|
385
|
+
)
|
386
|
+
elif identifier_type == cls._ATTRIBUTE_IDENTIFIER:
|
387
|
+
if key in cls.NUMERIC_ATTRIBUTES:
|
388
|
+
if token.ttype not in cls.NUMERIC_VALUE_TYPES:
|
389
|
+
raise MlflowException(
|
390
|
+
f"Expected numeric value type for numeric attribute: {key}. "
|
391
|
+
f"Found {token.value}",
|
392
|
+
error_code=INVALID_PARAMETER_VALUE,
|
393
|
+
)
|
394
|
+
return token.value
|
395
|
+
elif token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier):
|
396
|
+
return cls._strip_quotes(token.value, expect_quoted_value=True)
|
397
|
+
elif isinstance(token, Parenthesis):
|
398
|
+
cls.validate_list_supported(key)
|
399
|
+
return cls._parse_run_ids(token)
|
400
|
+
else:
|
401
|
+
raise MlflowException(
|
402
|
+
f"Expected a quoted string value for attributes. Got value {token.value}",
|
403
|
+
error_code=INVALID_PARAMETER_VALUE,
|
404
|
+
)
|
405
|
+
elif identifier_type == cls._DATASET_IDENTIFIER:
|
406
|
+
if key in cls.DATASET_ATTRIBUTES and (
|
407
|
+
token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier)
|
408
|
+
):
|
409
|
+
return cls._strip_quotes(token.value, expect_quoted_value=True)
|
410
|
+
elif isinstance(token, Parenthesis):
|
411
|
+
if key not in ("name", "digest", "context"):
|
412
|
+
raise MlflowException(
|
413
|
+
"Only the dataset 'name' and 'digest' supports comparison with a list of "
|
414
|
+
"quoted string values.",
|
415
|
+
error_code=INVALID_PARAMETER_VALUE,
|
416
|
+
)
|
417
|
+
return cls._parse_run_ids(token)
|
418
|
+
else:
|
419
|
+
raise MlflowException(
|
420
|
+
"Expected a quoted string value for dataset attributes. "
|
421
|
+
f"Got value {token.value}",
|
422
|
+
error_code=INVALID_PARAMETER_VALUE,
|
423
|
+
)
|
424
|
+
else:
|
425
|
+
# Expected to be either "param" or "metric".
|
426
|
+
raise MlflowException(
|
427
|
+
"Invalid identifier type. Expected one of "
|
428
|
+
f"{[cls._METRIC_IDENTIFIER, cls._PARAM_IDENTIFIER]}."
|
429
|
+
)
|
430
|
+
|
431
|
+
@classmethod
|
432
|
+
def _validate_comparison(cls, tokens, search_traces=False):
|
433
|
+
base_error_string = "Invalid comparison clause"
|
434
|
+
if len(tokens) != 3:
|
435
|
+
raise MlflowException(
|
436
|
+
f"{base_error_string}. Expected 3 tokens found {len(tokens)}",
|
437
|
+
error_code=INVALID_PARAMETER_VALUE,
|
438
|
+
)
|
439
|
+
if not isinstance(tokens[0], Identifier):
|
440
|
+
if not search_traces:
|
441
|
+
raise MlflowException(
|
442
|
+
f"{base_error_string}. Expected 'Identifier' found '{tokens[0]}'",
|
443
|
+
error_code=INVALID_PARAMETER_VALUE,
|
444
|
+
)
|
445
|
+
if search_traces and not tokens[0].match(
|
446
|
+
ttype=TokenType.Name.Builtin, values=["timestamp", "timestamp_ms"]
|
447
|
+
):
|
448
|
+
raise MlflowException(
|
449
|
+
f"{base_error_string}. Expected 'TokenType.Name.Builtin' found '{tokens[0]}'",
|
450
|
+
error_code=INVALID_PARAMETER_VALUE,
|
451
|
+
)
|
452
|
+
if not isinstance(tokens[1], Token) and tokens[1].ttype != TokenType.Operator.Comparison:
|
453
|
+
raise MlflowException(
|
454
|
+
f"{base_error_string}. Expected comparison found '{tokens[1]}'",
|
455
|
+
error_code=INVALID_PARAMETER_VALUE,
|
456
|
+
)
|
457
|
+
if not isinstance(tokens[2], Token) and (
|
458
|
+
tokens[2].ttype not in cls.STRING_VALUE_TYPES.union(cls.NUMERIC_VALUE_TYPES)
|
459
|
+
or isinstance(tokens[2], Identifier)
|
460
|
+
):
|
461
|
+
raise MlflowException(
|
462
|
+
f"{base_error_string}. Expected value token found '{tokens[2]}'",
|
463
|
+
error_code=INVALID_PARAMETER_VALUE,
|
464
|
+
)
|
465
|
+
|
466
|
+
@classmethod
|
467
|
+
def _get_comparison(cls, comparison):
|
468
|
+
stripped_comparison = [token for token in comparison.tokens if not token.is_whitespace]
|
469
|
+
cls._validate_comparison(stripped_comparison)
|
470
|
+
comp = cls._get_identifier(stripped_comparison[0].value, cls.VALID_SEARCH_ATTRIBUTE_KEYS)
|
471
|
+
comp["comparator"] = stripped_comparison[1].value
|
472
|
+
comp["value"] = cls._get_value(comp.get("type"), comp.get("key"), stripped_comparison[2])
|
473
|
+
return comp
|
474
|
+
|
475
|
+
@classmethod
|
476
|
+
def _invalid_statement_token_search_runs(cls, token):
|
477
|
+
if (
|
478
|
+
isinstance(token, Comparison)
|
479
|
+
or token.is_whitespace
|
480
|
+
or token.match(ttype=TokenType.Keyword, values=["AND"])
|
481
|
+
):
|
482
|
+
return False
|
483
|
+
return True
|
484
|
+
|
485
|
+
@classmethod
|
486
|
+
def _process_statement(cls, statement):
|
487
|
+
# check validity
|
488
|
+
tokens = _join_in_comparison_tokens(statement.tokens)
|
489
|
+
invalids = list(filter(cls._invalid_statement_token_search_runs, tokens))
|
490
|
+
if len(invalids) > 0:
|
491
|
+
invalid_clauses = ", ".join(f"'{token}'" for token in invalids)
|
492
|
+
raise MlflowException(
|
493
|
+
f"Invalid clause(s) in filter string: {invalid_clauses}",
|
494
|
+
error_code=INVALID_PARAMETER_VALUE,
|
495
|
+
)
|
496
|
+
return [cls._get_comparison(si) for si in tokens if isinstance(si, Comparison)]
|
497
|
+
|
498
|
+
@classmethod
|
499
|
+
def parse_search_filter(cls, filter_string):
|
500
|
+
if not filter_string:
|
501
|
+
return []
|
502
|
+
try:
|
503
|
+
parsed = sqlparse.parse(filter_string)
|
504
|
+
except Exception:
|
505
|
+
raise MlflowException(
|
506
|
+
f"Error on parsing filter '{filter_string}'", error_code=INVALID_PARAMETER_VALUE
|
507
|
+
)
|
508
|
+
if len(parsed) == 0 or not isinstance(parsed[0], Statement):
|
509
|
+
raise MlflowException(
|
510
|
+
f"Invalid filter '{filter_string}'. Could not be parsed.",
|
511
|
+
error_code=INVALID_PARAMETER_VALUE,
|
512
|
+
)
|
513
|
+
elif len(parsed) > 1:
|
514
|
+
raise MlflowException(
|
515
|
+
f"Search filter contained multiple expression {filter_string!r}. "
|
516
|
+
"Provide AND-ed expression list.",
|
517
|
+
error_code=INVALID_PARAMETER_VALUE,
|
518
|
+
)
|
519
|
+
return cls._process_statement(parsed[0])
|
520
|
+
|
521
|
+
@classmethod
|
522
|
+
def is_metric(cls, key_type, comparator):
|
523
|
+
if key_type == cls._METRIC_IDENTIFIER:
|
524
|
+
if comparator not in cls.VALID_METRIC_COMPARATORS:
|
525
|
+
raise MlflowException(
|
526
|
+
f"Invalid comparator '{comparator}' not one of '{cls.VALID_METRIC_COMPARATORS}",
|
527
|
+
error_code=INVALID_PARAMETER_VALUE,
|
528
|
+
)
|
529
|
+
return True
|
530
|
+
return False
|
531
|
+
|
532
|
+
@classmethod
|
533
|
+
def is_param(cls, key_type, comparator):
|
534
|
+
if key_type == cls._PARAM_IDENTIFIER:
|
535
|
+
if comparator not in cls.VALID_PARAM_COMPARATORS:
|
536
|
+
raise MlflowException(
|
537
|
+
f"Invalid comparator '{comparator}' not one of '{cls.VALID_PARAM_COMPARATORS}'",
|
538
|
+
error_code=INVALID_PARAMETER_VALUE,
|
539
|
+
)
|
540
|
+
return True
|
541
|
+
return False
|
542
|
+
|
543
|
+
@classmethod
|
544
|
+
def is_tag(cls, key_type, comparator):
|
545
|
+
if key_type == cls._TAG_IDENTIFIER:
|
546
|
+
if comparator not in cls.VALID_TAG_COMPARATORS:
|
547
|
+
raise MlflowException(
|
548
|
+
f"Invalid comparator '{comparator}' not one of '{cls.VALID_TAG_COMPARATORS}",
|
549
|
+
error_code=INVALID_PARAMETER_VALUE,
|
550
|
+
)
|
551
|
+
return True
|
552
|
+
return False
|
553
|
+
|
554
|
+
@classmethod
|
555
|
+
def is_attribute(cls, key_type, key_name, comparator):
|
556
|
+
return cls.is_string_attribute(key_type, key_name, comparator) or cls.is_numeric_attribute(
|
557
|
+
key_type, key_name, comparator
|
558
|
+
)
|
559
|
+
|
560
|
+
@classmethod
|
561
|
+
def is_string_attribute(cls, key_type, key_name, comparator):
|
562
|
+
if key_type == cls._ATTRIBUTE_IDENTIFIER and key_name not in cls.NUMERIC_ATTRIBUTES:
|
563
|
+
if comparator not in cls.VALID_STRING_ATTRIBUTE_COMPARATORS:
|
564
|
+
raise MlflowException(
|
565
|
+
f"Invalid comparator '{comparator}' not one of "
|
566
|
+
f"'{cls.VALID_STRING_ATTRIBUTE_COMPARATORS}'",
|
567
|
+
error_code=INVALID_PARAMETER_VALUE,
|
568
|
+
)
|
569
|
+
return True
|
570
|
+
return False
|
571
|
+
|
572
|
+
@classmethod
|
573
|
+
def is_numeric_attribute(cls, key_type, key_name, comparator):
|
574
|
+
if key_type == cls._ATTRIBUTE_IDENTIFIER and key_name in cls.NUMERIC_ATTRIBUTES:
|
575
|
+
if comparator not in cls.VALID_NUMERIC_ATTRIBUTE_COMPARATORS:
|
576
|
+
raise MlflowException(
|
577
|
+
f"Invalid comparator '{comparator}' not one of "
|
578
|
+
f"'{cls.VALID_STRING_ATTRIBUTE_COMPARATORS}",
|
579
|
+
error_code=INVALID_PARAMETER_VALUE,
|
580
|
+
)
|
581
|
+
return True
|
582
|
+
return False
|
583
|
+
|
584
|
+
@classmethod
|
585
|
+
def is_dataset(cls, key_type, comparator):
|
586
|
+
if key_type == cls._DATASET_IDENTIFIER:
|
587
|
+
if comparator not in cls.VALID_DATASET_COMPARATORS:
|
588
|
+
raise MlflowException(
|
589
|
+
f"Invalid comparator '{comparator}' "
|
590
|
+
f"not one of '{cls.VALID_DATASET_COMPARATORS}",
|
591
|
+
error_code=INVALID_PARAMETER_VALUE,
|
592
|
+
)
|
593
|
+
return True
|
594
|
+
return False
|
595
|
+
|
596
|
+
@classmethod
|
597
|
+
def _is_metric_on_dataset(cls, metric: Metric, dataset: dict[str, Any]) -> bool:
|
598
|
+
return metric.dataset_name == dataset.get("dataset_name") and (
|
599
|
+
dataset.get("dataset_digest") is None
|
600
|
+
or dataset.get("dataset_digest") == metric.dataset_digest
|
601
|
+
)
|
602
|
+
|
603
|
+
@classmethod
|
604
|
+
def _does_run_match_clause(cls, run, sed):
|
605
|
+
key_type = sed.get("type")
|
606
|
+
key = sed.get("key")
|
607
|
+
value = sed.get("value")
|
608
|
+
comparator = sed.get("comparator").upper()
|
609
|
+
|
610
|
+
key = SearchUtils.translate_key_alias(key)
|
611
|
+
|
612
|
+
if cls.is_metric(key_type, comparator):
|
613
|
+
lhs = run.data.metrics.get(key, None)
|
614
|
+
value = float(value)
|
615
|
+
elif cls.is_param(key_type, comparator):
|
616
|
+
lhs = run.data.params.get(key, None)
|
617
|
+
elif cls.is_tag(key_type, comparator):
|
618
|
+
lhs = run.data.tags.get(key, None)
|
619
|
+
elif cls.is_string_attribute(key_type, key, comparator):
|
620
|
+
lhs = getattr(run.info, key)
|
621
|
+
elif cls.is_numeric_attribute(key_type, key, comparator):
|
622
|
+
lhs = getattr(run.info, key)
|
623
|
+
value = int(value)
|
624
|
+
elif cls.is_dataset(key_type, comparator):
|
625
|
+
if key == "context":
|
626
|
+
return any(
|
627
|
+
SearchUtils.get_comparison_func(comparator)(tag.value if tag else None, value)
|
628
|
+
for dataset_input in run.inputs.dataset_inputs
|
629
|
+
for tag in dataset_input.tags
|
630
|
+
if tag.key == MLFLOW_DATASET_CONTEXT
|
631
|
+
)
|
632
|
+
else:
|
633
|
+
return any(
|
634
|
+
SearchUtils.get_comparison_func(comparator)(
|
635
|
+
getattr(dataset_input.dataset, key), value
|
636
|
+
)
|
637
|
+
for dataset_input in run.inputs.dataset_inputs
|
638
|
+
)
|
639
|
+
else:
|
640
|
+
raise MlflowException(
|
641
|
+
f"Invalid search expression type '{key_type}'", error_code=INVALID_PARAMETER_VALUE
|
642
|
+
)
|
643
|
+
if lhs is None:
|
644
|
+
return False
|
645
|
+
|
646
|
+
return SearchUtils.get_comparison_func(comparator)(lhs, value)
|
647
|
+
|
648
|
+
@classmethod
|
649
|
+
def _does_model_match_clause(cls, model, sed):
|
650
|
+
key_type = sed.get("type")
|
651
|
+
key = sed.get("key")
|
652
|
+
value = sed.get("value")
|
653
|
+
comparator = sed.get("comparator").upper()
|
654
|
+
|
655
|
+
key = SearchUtils.translate_key_alias(key)
|
656
|
+
|
657
|
+
if cls.is_metric(key_type, comparator):
|
658
|
+
matching_metrics = [metric for metric in model.metrics if metric.key == key]
|
659
|
+
lhs = matching_metrics[0].value if matching_metrics else None
|
660
|
+
value = float(value)
|
661
|
+
elif cls.is_param(key_type, comparator):
|
662
|
+
lhs = model.params.get(key, None)
|
663
|
+
elif cls.is_tag(key_type, comparator):
|
664
|
+
lhs = model.tags.get(key, None)
|
665
|
+
elif cls.is_string_attribute(key_type, key, comparator):
|
666
|
+
lhs = getattr(model.info, key)
|
667
|
+
elif cls.is_numeric_attribute(key_type, key, comparator):
|
668
|
+
lhs = getattr(model.info, key)
|
669
|
+
value = int(value)
|
670
|
+
else:
|
671
|
+
raise MlflowException(
|
672
|
+
f"Invalid model search expression type '{key_type}'",
|
673
|
+
error_code=INVALID_PARAMETER_VALUE,
|
674
|
+
)
|
675
|
+
if lhs is None:
|
676
|
+
return False
|
677
|
+
|
678
|
+
return SearchUtils.get_comparison_func(comparator)(lhs, value)
|
679
|
+
|
680
|
+
@classmethod
|
681
|
+
def filter(cls, runs, filter_string):
|
682
|
+
"""Filters a set of runs based on a search filter string."""
|
683
|
+
if not filter_string:
|
684
|
+
return runs
|
685
|
+
parsed = cls.parse_search_filter(filter_string)
|
686
|
+
|
687
|
+
def run_matches(run):
|
688
|
+
return all(cls._does_run_match_clause(run, s) for s in parsed)
|
689
|
+
|
690
|
+
return [run for run in runs if run_matches(run)]
|
691
|
+
|
692
|
+
@classmethod
|
693
|
+
def _validate_order_by_and_generate_token(cls, order_by):
|
694
|
+
try:
|
695
|
+
parsed = sqlparse.parse(order_by)
|
696
|
+
except Exception:
|
697
|
+
raise MlflowException(
|
698
|
+
f"Error on parsing order_by clause '{order_by}'",
|
699
|
+
error_code=INVALID_PARAMETER_VALUE,
|
700
|
+
)
|
701
|
+
if len(parsed) != 1 or not isinstance(parsed[0], Statement):
|
702
|
+
raise MlflowException(
|
703
|
+
f"Invalid order_by clause '{order_by}'. Could not be parsed.",
|
704
|
+
error_code=INVALID_PARAMETER_VALUE,
|
705
|
+
)
|
706
|
+
statement = parsed[0]
|
707
|
+
ttype_for_timestamp = (
|
708
|
+
TokenType.Name.Builtin
|
709
|
+
if Version(sqlparse.__version__) >= Version("0.4.3")
|
710
|
+
else TokenType.Keyword
|
711
|
+
)
|
712
|
+
|
713
|
+
if len(statement.tokens) == 1 and isinstance(statement[0], Identifier):
|
714
|
+
token_value = statement.tokens[0].value
|
715
|
+
elif len(statement.tokens) == 1 and statement.tokens[0].match(
|
716
|
+
ttype=ttype_for_timestamp, values=[cls.ORDER_BY_KEY_TIMESTAMP]
|
717
|
+
):
|
718
|
+
token_value = cls.ORDER_BY_KEY_TIMESTAMP
|
719
|
+
elif (
|
720
|
+
statement.tokens[0].match(
|
721
|
+
ttype=ttype_for_timestamp, values=[cls.ORDER_BY_KEY_TIMESTAMP]
|
722
|
+
)
|
723
|
+
and all(token.is_whitespace for token in statement.tokens[1:-1])
|
724
|
+
and statement.tokens[-1].ttype == TokenType.Keyword.Order
|
725
|
+
):
|
726
|
+
token_value = cls.ORDER_BY_KEY_TIMESTAMP + " " + statement.tokens[-1].value
|
727
|
+
else:
|
728
|
+
raise MlflowException(
|
729
|
+
f"Invalid order_by clause '{order_by}'. Could not be parsed.",
|
730
|
+
error_code=INVALID_PARAMETER_VALUE,
|
731
|
+
)
|
732
|
+
return token_value
|
733
|
+
|
734
|
+
@classmethod
|
735
|
+
def _parse_order_by_string(cls, order_by):
|
736
|
+
token_value = cls._validate_order_by_and_generate_token(order_by)
|
737
|
+
is_ascending = True
|
738
|
+
tokens = shlex.split(token_value.replace("`", '"'))
|
739
|
+
if len(tokens) > 2:
|
740
|
+
raise MlflowException(
|
741
|
+
f"Invalid order_by clause '{order_by}'. Could not be parsed.",
|
742
|
+
error_code=INVALID_PARAMETER_VALUE,
|
743
|
+
)
|
744
|
+
elif len(tokens) == 2:
|
745
|
+
order_token = tokens[1].lower()
|
746
|
+
if order_token not in cls.VALID_ORDER_BY_TAGS:
|
747
|
+
raise MlflowException(
|
748
|
+
f"Invalid ordering key in order_by clause '{order_by}'.",
|
749
|
+
error_code=INVALID_PARAMETER_VALUE,
|
750
|
+
)
|
751
|
+
is_ascending = order_token == cls.ASC_OPERATOR
|
752
|
+
token_value = tokens[0]
|
753
|
+
return token_value, is_ascending
|
754
|
+
|
755
|
+
@classmethod
|
756
|
+
def parse_order_by_for_search_runs(cls, order_by):
|
757
|
+
token_value, is_ascending = cls._parse_order_by_string(order_by)
|
758
|
+
identifier = cls._get_identifier(token_value.strip(), cls.VALID_ORDER_BY_ATTRIBUTE_KEYS)
|
759
|
+
return identifier["type"], identifier["key"], is_ascending
|
760
|
+
|
761
|
+
@classmethod
|
762
|
+
def parse_order_by_for_search_registered_models(cls, order_by):
|
763
|
+
token_value, is_ascending = cls._parse_order_by_string(order_by)
|
764
|
+
token_value = token_value.strip()
|
765
|
+
if token_value not in cls.VALID_ORDER_BY_KEYS_REGISTERED_MODELS:
|
766
|
+
raise MlflowException(
|
767
|
+
f"Invalid order by key '{token_value}' specified. Valid keys "
|
768
|
+
f"are '{cls.RECOMMENDED_ORDER_BY_KEYS_REGISTERED_MODELS}'",
|
769
|
+
error_code=INVALID_PARAMETER_VALUE,
|
770
|
+
)
|
771
|
+
return token_value, is_ascending
|
772
|
+
|
773
|
+
@classmethod
def _get_value_for_sort(cls, run, key_type, key, ascending):
    """Build a tuple sort key for ``run`` on a single ``order_by`` column.

    The value is read from the run's metrics, params or tags, or from an attribute
    of ``run.info``, depending on ``key_type``. Missing (None) and NaN values are
    flagged in the first tuple element and replaced with +/- infinity. As a result,
    when sorted with ``reverse=not ascending`` (as ``sort`` does), they come after
    all real values, with NaN ahead of None.

    Raises:
        MlflowException: If ``key_type`` is not a metric, param, tag or attribute
            identifier.
    """
    key = SearchUtils.translate_key_alias(key)
    if key_type == cls._METRIC_IDENTIFIER:
        value = run.data.metrics.get(key)
    elif key_type == cls._PARAM_IDENTIFIER:
        value = run.data.params.get(key)
    elif key_type == cls._TAG_IDENTIFIER:
        value = run.data.tags.get(key)
    elif key_type == cls._ATTRIBUTE_IDENTIFIER:
        value = getattr(run.info, key)
    else:
        raise MlflowException(
            f"Invalid order_by entity type '{key_type}'", error_code=INVALID_PARAMETER_VALUE
        )

    missing = value is None
    nan = isinstance(value, float) and math.isnan(value)
    # Sentinel at the far end of the requested direction; NaN gets the opposite
    # infinity so it lands ahead of None within the "missing" group.
    sentinel = math.inf if ascending else -math.inf
    if missing:
        value = sentinel
    elif nan:
        value = -sentinel

    flagged = missing or nan
    if ascending:
        return (flagged, value)
    return (not flagged, value)
|
804
|
+
|
805
|
+
@classmethod
def _get_model_value_for_sort(cls, model, key_type, key, ascending):
    """Build a tuple sort key for ``model`` on a single ``order_by`` column.

    For metrics, the value of the first entry in ``model.metrics`` whose ``key``
    matches is used, converted to float. Params and tags are looked up by key, and
    attributes are read directly from ``model``. Missing (None) and NaN values are
    flagged in the first tuple element and replaced with +/- infinity, mirroring
    ``_get_value_for_sort``.

    Raises:
        MlflowException: If ``key_type`` is not a metric, param, tag or attribute
            identifier.
    """
    key = SearchUtils.translate_key_alias(key)
    if key_type == cls._METRIC_IDENTIFIER:
        first_match = next((m for m in model.metrics if m.key == key), None)
        value = None if first_match is None else float(first_match.value)
    elif key_type == cls._PARAM_IDENTIFIER:
        value = model.params.get(key)
    elif key_type == cls._TAG_IDENTIFIER:
        value = model.tags.get(key)
    elif key_type == cls._ATTRIBUTE_IDENTIFIER:
        value = getattr(model, key)
    else:
        raise MlflowException(
            f"Invalid models order_by entity type '{key_type}'",
            error_code=INVALID_PARAMETER_VALUE,
        )

    missing = value is None
    nan = isinstance(value, float) and math.isnan(value)
    sentinel = math.inf if ascending else -math.inf
    if missing:
        value = sentinel
    elif nan:
        value = -sentinel

    flagged = missing or nan
    if ascending:
        return (flagged, value)
    return (not flagged, value)
|
838
|
+
|
839
|
+
@classmethod
def sort(cls, runs, order_by_list):
    """Sort runs by their natural ordering, then by any explicit ``order_by`` clauses.

    The natural ordering is start time descending, with run id as the tie-breaker.
    Each clause in ``order_by_list`` then overrides that ordering, with earlier
    clauses taking precedence over later ones.

    Returns:
        A new sorted list; ``runs`` itself is not modified.
    """
    ordered = sorted(runs, key=lambda r: (-r.info.start_time, r.info.run_id))
    # Python's sort is stable, so applying the clauses from last to first gives a
    # lexicographic ordering in which the first clause dominates.
    for clause in reversed(order_by_list or []):
        key_type, key, ascending = cls.parse_order_by_for_search_runs(clause)
        ordered.sort(
            key=lambda r, kt=key_type, k=key, asc=ascending: cls._get_value_for_sort(
                r, kt, k, asc
            ),
            reverse=not ascending,
        )
    return ordered
|
858
|
+
|
859
|
+
@classmethod
def parse_start_offset_from_page_token(cls, page_token):
    """Decode the starting offset stored in ``page_token``.

    The token is expected to be base64-encoded JSON of the form ``{"offset": N}``,
    as produced by ``create_page_token``. This format is not stable and should not
    be relied upon outside of this method.

    Args:
        page_token: Token returned with a previous page, or a falsy value for the
            first page.

    Returns:
        The non-negative integer offset, or ``0`` when ``page_token`` is falsy.

    Raises:
        MlflowException: With ``INVALID_PARAMETER_VALUE`` if the token cannot be
            base64-decoded, is not a JSON object, has no ``offset`` entry, or its
            offset is not a non-negative integer.
    """
    if not page_token:
        return 0

    try:
        decoded_token = base64.b64decode(page_token)
    except (TypeError, ValueError) as e:
        # binascii.Error (bad padding / alphabet) subclasses ValueError, as does the
        # error for non-ASCII str input; unsupported argument types raise TypeError.
        raise MlflowException(
            "Invalid page token, could not base64-decode", error_code=INVALID_PARAMETER_VALUE
        ) from e

    try:
        parsed_token = json.loads(decoded_token)
    except ValueError as e:
        raise MlflowException(
            f"Invalid page token, decoded value={decoded_token}",
            error_code=INVALID_PARAMETER_VALUE,
        ) from e

    # Valid JSON that is not an object (a list, number, ...) has no ``.get``.
    offset_str = parsed_token.get("offset") if isinstance(parsed_token, dict) else None
    # Test for None rather than falsiness: an offset of 0 is legitimate and is what
    # ``paginate`` encodes when ``max_results`` is 0.
    if offset_str is None:
        raise MlflowException(
            f"Invalid page token, parsed value={parsed_token}",
            error_code=INVALID_PARAMETER_VALUE,
        )

    try:
        offset = int(offset_str)
    except (TypeError, ValueError) as e:
        # TypeError covers non-scalar offsets such as lists or objects.
        raise MlflowException(
            f"Invalid page token, not stringable {offset_str}",
            error_code=INVALID_PARAMETER_VALUE,
        ) from e

    # A negative offset would silently slice from the end of the result list.
    if offset < 0:
        raise MlflowException(
            f"Invalid page token, parsed value={parsed_token}",
            error_code=INVALID_PARAMETER_VALUE,
        )

    return offset
|
902
|
+
|
903
|
+
@classmethod
def create_page_token(cls, offset):
    """Encode ``offset`` as an opaque page token.

    The token is the base64 encoding (as ``bytes``) of the UTF-8 JSON document
    ``{"offset": offset}``.
    """
    payload = json.dumps({"offset": offset}).encode("utf-8")
    return base64.b64encode(payload)
|
906
|
+
|
907
|
+
@classmethod
def paginate(cls, runs, page_token, max_results):
    """Return one page of ``runs`` plus the token for the following page.

    The page starts at the offset encoded in ``page_token`` and holds at most
    ``max_results`` runs.

    Returns:
        A ``(page, next_page_token)`` tuple. ``next_page_token`` is None when
        there are no further runs after this page.
    """
    begin = cls.parse_start_offset_from_page_token(page_token)
    end = begin + max_results
    next_token = cls.create_page_token(end) if end < len(runs) else None
    return (runs[begin:end], next_token)
|
921
|
+
|
922
|
+
# Model Registry specific parser
# TODO: Tech debt. Refactor search code into common utils, tracking server, and model
# registry specific code.

# Presumably the attribute keys accepted when searching model versions — confirm
# against callers.
VALID_SEARCH_KEYS_FOR_MODEL_VERSIONS = {"run_id", "source_path", "name"}
# Presumably the attribute keys accepted when searching registered models.
VALID_SEARCH_KEYS_FOR_REGISTERED_MODELS = {"name"}
|
928
|
+
|
929
|
+
@classmethod
def _check_valid_identifier_list(cls, tup: tuple[Any, ...]) -> None:
    """Ensure ``tup`` is a non-empty tuple made up only of strings.

    Raises:
        MlflowException: With ``INVALID_PARAMETER_VALUE`` if ``tup`` is empty or
            contains a non-string element.
    """
    if not tup:
        raise MlflowException(
            "While parsing a list in the query,"
            " expected a non-empty list of string values, but got empty list",
            error_code=INVALID_PARAMETER_VALUE,
        )

    if any(not isinstance(item, str) for item in tup):
        raise MlflowException(
            "While parsing a list in the query, expected string value, punctuation, "
            f"or whitespace, but got different type in list: {tup}",
            error_code=INVALID_PARAMETER_VALUE,
        )
|
947
|
+
|
948
|
+
@classmethod
def _parse_list_from_sql_token(cls, token):
    """Parse a parenthesized SQL list token into a tuple of strings.

    ``token.value`` (e.g. ``('a', 'b')``) is evaluated with ``ast.literal_eval``.
    A single parenthesized value evaluates to a scalar and is wrapped in a
    one-element tuple.

    Returns:
        A non-empty tuple of strings.

    Raises:
        MlflowException: With ``INVALID_PARAMETER_VALUE`` if the text is not a
            valid Python literal, or if the result is empty or contains non-string
            elements.
    """
    try:
        parsed = ast.literal_eval(token.value)
    except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError) as e:
        # literal_eval raises ValueError for non-literal nodes such as bare names
        # (e.g. "(abc)"), not only SyntaxError; all of these are user input errors.
        raise MlflowException(
            "While parsing a list in the query,"
            " expected a non-empty list of string values, but got ill-formed list.",
            error_code=INVALID_PARAMETER_VALUE,
        ) from e

    parsed = parsed if isinstance(parsed, tuple) else (parsed,)
    cls._check_valid_identifier_list(parsed)
    return parsed
|
962
|
+
|
963
|
+
@classmethod
def _parse_run_ids(cls, token):
    """Parse a parenthesized list of run IDs from ``token``.

    Returns:
        A list of the parsed run IDs, excluding any that contain uppercase
        characters.
    """
    run_id_list = cls._parse_list_from_sql_token(token)
    # MySQL's IN clause is case-insensitive, but run IDs only contain lowercase
    # letters, so drop IDs containing uppercase characters to prevent spurious
    # case-insensitive matches. ``str.islower()`` is deliberately not used: it is
    # False for strings with no cased characters (e.g. all-digit IDs), which would
    # wrongly drop them.
    return [run_id for run_id in run_id_list if run_id == run_id.lower()]
|
969
|
+
|
970
|
+
|
971
|
+
class SearchExperimentsUtils(SearchUtils):
    """Filter-string parsing, in-memory filtering and sorting for experiments."""

    VALID_SEARCH_ATTRIBUTE_KEYS = {"name", "creation_time", "last_update_time"}
    VALID_ORDER_BY_ATTRIBUTE_KEYS = {"name", "experiment_id", "creation_time", "last_update_time"}
    NUMERIC_ATTRIBUTES = {"creation_time", "last_update_time"}

    @classmethod
    def _invalid_statement_token_search_experiments(cls, token):
        """Return True if ``token`` may not appear at the top level of a filter.

        Only comparisons, whitespace and the ``AND`` keyword are permitted.
        """
        if (
            isinstance(token, Comparison)
            or token.is_whitespace
            or token.match(ttype=TokenType.Keyword, values=["AND"])
        ):
            return False
        return True

    @classmethod
    def _process_statement(cls, statement):
        """Convert a parsed filter statement into a list of comparison dicts.

        Raises:
            MlflowException: If any top-level token is not a comparison,
                whitespace, or ``AND``.
        """
        tokens = _join_in_comparison_tokens(statement.tokens)
        invalids = list(filter(cls._invalid_statement_token_search_experiments, tokens))
        if len(invalids) > 0:
            invalid_clauses = ", ".join(map(str, invalids))
            raise MlflowException.invalid_parameter_value(
                f"Invalid clause(s) in filter string: {invalid_clauses}"
            )
        return [cls._get_comparison(t) for t in tokens if isinstance(t, Comparison)]

    @classmethod
    def _get_identifier(cls, identifier, valid_attributes):
        """Split ``identifier`` into an entity type and a key.

        An unqualified identifier (no ``.``) is an attribute. A qualified one must
        use the ``attribute``, ``tag`` or ``tags`` prefix. The key is cleaned with
        ``_strip_quotes`` / ``_trim_backticks`` before attribute keys are checked
        against ``valid_attributes``.

        Returns:
            ``{"type": <identifier type>, "key": <key>}``.

        Raises:
            MlflowException: On an unknown entity prefix or an attribute key not in
                ``valid_attributes``.
        """
        tokens = identifier.split(".", maxsplit=1)
        if len(tokens) == 1:
            key = tokens[0]
            identifier = cls._ATTRIBUTE_IDENTIFIER
        else:
            entity_type, key = tokens
            valid_entity_types = ("attribute", "tag", "tags")
            if entity_type not in valid_entity_types:
                raise MlflowException.invalid_parameter_value(
                    f"Invalid entity type '{entity_type}'. "
                    f"Valid entity types are {valid_entity_types}"
                )
            identifier = cls._valid_entity_type(entity_type)

        key = cls._trim_backticks(cls._strip_quotes(key))
        if identifier == cls._ATTRIBUTE_IDENTIFIER and key not in valid_attributes:
            raise MlflowException.invalid_parameter_value(
                f"Invalid attribute key '{key}' specified. Valid keys are '{valid_attributes}'"
            )
        return {"type": identifier, "key": key}

    @classmethod
    def _get_comparison(cls, comparison):
        """Convert a ``Comparison`` token into a dict with type, key, comparator and value."""
        stripped_comparison = [token for token in comparison.tokens if not token.is_whitespace]
        cls._validate_comparison(stripped_comparison)
        left, comparator, right = stripped_comparison
        comp = cls._get_identifier(left.value, cls.VALID_SEARCH_ATTRIBUTE_KEYS)
        comp["comparator"] = comparator.value
        comp["value"] = cls._get_value(comp.get("type"), comp.get("key"), right)
        return comp

    @classmethod
    def parse_order_by_for_search_experiments(cls, order_by):
        """Parse an experiment ``order_by`` clause into ``(type, key, is_ascending)``."""
        token_value, is_ascending = cls._parse_order_by_string(order_by)
        identifier = cls._get_identifier(token_value.strip(), cls.VALID_ORDER_BY_ATTRIBUTE_KEYS)
        return identifier["type"], identifier["key"], is_ascending

    @classmethod
    def is_attribute(cls, key_type, comparator):
        """Return True if ``key_type`` is the attribute identifier.

        Raises:
            MlflowException: With ``INVALID_PARAMETER_VALUE`` if the key is an
                attribute but ``comparator`` is not in
                ``VALID_STRING_ATTRIBUTE_COMPARATORS``.
        """
        if key_type == cls._ATTRIBUTE_IDENTIFIER:
            if comparator not in cls.VALID_STRING_ATTRIBUTE_COMPARATORS:
                # Report as a bad request, like every other validation error here.
                raise MlflowException(
                    f"Invalid comparator '{comparator}' not one of "
                    f"'{cls.VALID_STRING_ATTRIBUTE_COMPARATORS}'",
                    error_code=INVALID_PARAMETER_VALUE,
                )
            return True
        return False

    @classmethod
    def _does_experiment_match_clause(cls, experiment, sed):
        """Return True if ``experiment`` satisfies the parsed comparison ``sed``.

        ``sed`` is a dict produced by ``_get_comparison`` with ``type``, ``key``,
        ``comparator`` and ``value`` entries.
        """
        key_type = sed.get("type")
        key = sed.get("key")
        value = sed.get("value")
        comparator = sed.get("comparator").upper()

        if cls.is_string_attribute(key_type, key, comparator):
            lhs = getattr(experiment, key)
        elif cls.is_numeric_attribute(key_type, key, comparator):
            lhs = getattr(experiment, key)
            value = float(value)
        elif cls.is_tag(key_type, comparator):
            lhs = experiment.tags.get(key)
            if lhs is None:
                # A missing or None-valued tag cannot satisfy a comparison. Return
                # a real bool rather than the (truthy) experiment object.
                return False
        else:
            raise MlflowException(
                f"Invalid search expression type '{key_type}'", error_code=INVALID_PARAMETER_VALUE
            )

        return SearchUtils.get_comparison_func(comparator)(lhs, value)

    @classmethod
    def filter(cls, experiments, filter_string):
        """Return the experiments matching every clause of ``filter_string``."""
        if not filter_string:
            return experiments
        parsed = cls.parse_search_filter(filter_string)

        def experiment_matches(experiment):
            return all(cls._does_experiment_match_clause(experiment, s) for s in parsed)

        return list(filter(experiment_matches, experiments))

    @classmethod
    def _get_sort_key(cls, order_by_list):
        """Build a ``sorted`` key function from a list of ``order_by`` clauses.

        Only attribute keys are supported. ``experiment_id`` descending is added as
        a tie-breaker unless it is already present. None values sort last in either
        direction.
        """
        order_by = []
        parsed_order_by = map(cls.parse_order_by_for_search_experiments, order_by_list)
        for type_, key, ascending in parsed_order_by:
            if type_ == "attribute":
                order_by.append((key, ascending))
            else:
                raise MlflowException.invalid_parameter_value(f"Invalid order_by entity: {type_}")

        # Add a tie-breaker
        if not any(key == "experiment_id" for key, _ in order_by):
            order_by.append(("experiment_id", False))

        # https://stackoverflow.com/a/56842689
        class _Sorter:
            def __init__(self, obj, ascending):
                self.obj = obj
                self.ascending = ascending

            # Only < and == are needed for use as a key parameter in the sorted function
            def __eq__(self, other):
                return other.obj == self.obj

            def __lt__(self, other):
                if self.obj is None:
                    return False
                elif other.obj is None:
                    return True
                elif self.ascending:
                    return self.obj < other.obj
                else:
                    return other.obj < self.obj

        def _apply_sorter(experiment, key, ascending):
            attr = getattr(experiment, key)
            return _Sorter(attr, ascending)

        return lambda experiment: tuple(_apply_sorter(experiment, k, asc) for (k, asc) in order_by)

    @classmethod
    def sort(cls, experiments, order_by_list):
        """Return ``experiments`` sorted according to ``order_by_list``."""
        return sorted(experiments, key=cls._get_sort_key(order_by_list))
|
1126
|
+
|
1127
|
+
|
1128
|
+
# https://stackoverflow.com/a/56842689
class _Reversor:
    """Sort-key wrapper that inverts the natural ordering of ``obj``.

    ``None`` compares as the greatest value, so it always sorts last. Only ``<``
    and ``==`` are defined, which is all ``sorted`` needs for (tuple) keys.
    """

    def __init__(self, obj):
        self.obj = obj

    def __eq__(self, other):
        return other.obj == self.obj

    def __lt__(self, other):
        mine, theirs = self.obj, other.obj
        if mine is None:
            return False
        return theirs is None or theirs < mine
|
1143
|
+
|
1144
|
+
|
1145
|
+
def _apply_reversor(model, key, ascending):
    """Return ``model.<key>``, wrapped in ``_Reversor`` when sorting descending."""
    value = getattr(model, key)
    if not ascending:
        value = _Reversor(value)
    return value
|
1148
|
+
|
1149
|
+
|
1150
|
+
class SearchModelUtils(SearchUtils):
    """Filter-string parsing, in-memory filtering and sorting for registered models."""

    NUMERIC_ATTRIBUTES = {"creation_timestamp", "last_updated_timestamp"}
    VALID_SEARCH_ATTRIBUTE_KEYS = {"name"}
    VALID_ORDER_BY_KEYS_REGISTERED_MODELS = {"name", "creation_timestamp", "last_updated_timestamp"}

    @classmethod
    def _does_registered_model_match_clauses(cls, model, sed):
        """Return True if ``model`` satisfies the parsed comparison ``sed``.

        ``sed`` is a dict produced by ``_get_comparison`` with ``type``, ``key``,
        ``comparator`` and ``value`` entries.
        """
        key_type = sed.get("type")
        key = sed.get("key")
        value = sed.get("value")
        comparator = sed.get("comparator").upper()

        # Comparator validity is presumably enforced by the is_* helpers inherited
        # from SearchUtils — confirm there.
        if cls.is_string_attribute(key_type, key, comparator):
            lhs = getattr(model, key)
        elif cls.is_numeric_attribute(key_type, key, comparator):
            lhs = getattr(model, key)
            value = int(value)
        elif cls.is_tag(key_type, comparator):
            # NB: We should use the private attribute `_tags` instead of the `tags` property
            # to consider all tags including reserved ones.
            lhs = model._tags.get(key, None)
        else:
            raise MlflowException(
                f"Invalid search expression type '{key_type}'", error_code=INVALID_PARAMETER_VALUE
            )

        # NB: Handling the special `mlflow.prompt.is_prompt` tag. This tag is used for
        # distinguishing between prompt models and normal models. For example, we want to
        # search for models only by the following filter string:
        #
        #   tags.`mlflow.prompt.is_prompt` != 'true'
        #   tags.`mlflow.prompt.is_prompt` = 'false'
        #
        # However, models do not have this tag, so lhs is None in this case. Instead of returning
        # False like normal tag filter, we need to return True here.
        if key == IS_PROMPT_TAG_KEY and lhs is None:
            return (comparator == "=" and value == "false") or (
                comparator == "!=" and value == "true"
            )

        if lhs is None:
            return False

        return SearchUtils.get_comparison_func(comparator)(lhs, value)

    @classmethod
    def filter(cls, registered_models, filter_string):
        """Filters a set of registered models based on a search filter string."""
        if not filter_string:
            return registered_models
        parsed = cls.parse_search_filter(filter_string)

        def registered_model_matches(model):
            return all(cls._does_registered_model_match_clauses(model, s) for s in parsed)

        return [
            registered_model
            for registered_model in registered_models
            if registered_model_matches(registered_model)
        ]

    @classmethod
    def parse_order_by_for_search_registered_models(cls, order_by):
        """Parse a registered-model ``order_by`` clause into ``(type, key, is_ascending)``."""
        token_value, is_ascending = cls._parse_order_by_string(order_by)
        identifier = SearchExperimentsUtils._get_identifier(
            token_value.strip(), cls.VALID_ORDER_BY_KEYS_REGISTERED_MODELS
        )
        return identifier["type"], identifier["key"], is_ascending

    @classmethod
    def _get_sort_key(cls, order_by_list):
        """Build a ``sorted`` key function from ``order_by`` clauses.

        Only attribute keys are supported. ``name`` ascending is appended as a
        tie-breaker unless already present.
        """
        order_by = []
        parsed_order_by = map(cls.parse_order_by_for_search_registered_models, order_by_list or [])
        for type_, key, ascending in parsed_order_by:
            if type_ == "attribute":
                order_by.append((key, ascending))
            else:
                raise MlflowException.invalid_parameter_value(f"Invalid order_by entity: {type_}")

        # Add a tie-breaker
        if not any(key == "name" for key, _ in order_by):
            order_by.append(("name", True))

        return lambda model: tuple(_apply_reversor(model, k, asc) for (k, asc) in order_by)

    @classmethod
    def sort(cls, models, order_by_list):
        """Return ``models`` sorted according to ``order_by_list``."""
        return sorted(models, key=cls._get_sort_key(order_by_list))

    @classmethod
    def _process_statement(cls, statement):
        """Convert a parsed filter statement into a list of comparison dicts.

        Raises:
            MlflowException: If any top-level token is not a comparison,
                whitespace, or ``AND``.
        """
        tokens = _join_in_comparison_tokens(statement.tokens)
        invalids = list(filter(cls._invalid_statement_token_search_model_registry, tokens))
        if len(invalids) > 0:
            invalid_clauses = ", ".join(map(str, invalids))
            raise MlflowException.invalid_parameter_value(
                f"Invalid clause(s) in filter string: {invalid_clauses}"
            )
        return [cls._get_comparison(t) for t in tokens if isinstance(t, Comparison)]

    @classmethod
    def _get_model_search_identifier(cls, identifier, valid_attributes):
        """Split ``identifier`` into an entity type (attribute or tag) and a key.

        The key is cleaned with ``_strip_quotes`` / ``_trim_backticks`` *before*
        attribute keys are validated, matching ``SearchExperimentsUtils._get_identifier``
        so that quoted attribute names such as ``name`` in backticks are accepted.

        Returns:
            ``{"type": <identifier type>, "key": <key>}``.

        Raises:
            MlflowException: On an unknown entity prefix or an attribute key not in
                ``valid_attributes``.
        """
        tokens = identifier.split(".", maxsplit=1)
        if len(tokens) == 1:
            key = tokens[0]
            identifier = cls._ATTRIBUTE_IDENTIFIER
        else:
            entity_type, key = tokens
            valid_entity_types = ("attribute", "tag", "tags")
            if entity_type not in valid_entity_types:
                raise MlflowException.invalid_parameter_value(
                    f"Invalid entity type '{entity_type}'. "
                    f"Valid entity types are {valid_entity_types}"
                )
            identifier = (
                cls._TAG_IDENTIFIER if entity_type in ("tag", "tags") else cls._ATTRIBUTE_IDENTIFIER
            )

        key = cls._trim_backticks(cls._strip_quotes(key))
        if identifier == cls._ATTRIBUTE_IDENTIFIER and key not in valid_attributes:
            raise MlflowException.invalid_parameter_value(
                f"Invalid attribute key '{key}' specified. Valid keys are '{valid_attributes}'"
            )
        return {"type": identifier, "key": key}

    @classmethod
    def _get_comparison(cls, comparison):
        """Convert a ``Comparison`` token into a dict with an upper-cased comparator."""
        stripped_comparison = [token for token in comparison.tokens if not token.is_whitespace]
        cls._validate_comparison(stripped_comparison)
        left, comparator, right = stripped_comparison
        comp = cls._get_model_search_identifier(left.value, cls.VALID_SEARCH_ATTRIBUTE_KEYS)
        comp["comparator"] = comparator.value.upper()
        comp["value"] = cls._get_value(comp.get("type"), comp.get("key"), right)
        return comp

    @classmethod
    def _get_value(cls, identifier_type, key, token):
        """Extract the right-hand-side value of a comparison.

        Tags and attributes accept quoted strings. ``run_id`` attributes also
        accept a parenthesized list of strings.

        Raises:
            MlflowException: On an unsupported value form or identifier type.
        """
        if identifier_type == cls._TAG_IDENTIFIER:
            if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier):
                return cls._strip_quotes(token.value, expect_quoted_value=True)
            raise MlflowException(
                "Expected a quoted string value for "
                f"{identifier_type} (e.g. 'my-value'). Got value "
                f"{token.value}",
                error_code=INVALID_PARAMETER_VALUE,
            )
        elif identifier_type == cls._ATTRIBUTE_IDENTIFIER:
            if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier):
                return cls._strip_quotes(token.value, expect_quoted_value=True)
            elif isinstance(token, Parenthesis):
                if key != "run_id":
                    raise MlflowException(
                        "Only the 'run_id' attribute supports comparison with a list of quoted "
                        "string values.",
                        error_code=INVALID_PARAMETER_VALUE,
                    )
                return cls._parse_run_ids(token)
            else:
                raise MlflowException(
                    "Expected a quoted string value or a list of quoted string values for "
                    f"attributes. Got value {token.value}",
                    error_code=INVALID_PARAMETER_VALUE,
                )
        else:
            # Expected to be either "param" or "metric".
            raise MlflowException(
                "Invalid identifier type. Expected one of "
                f"{[cls._ATTRIBUTE_IDENTIFIER, cls._TAG_IDENTIFIER]}.",
                error_code=INVALID_PARAMETER_VALUE,
            )

    @classmethod
    def _invalid_statement_token_search_model_registry(cls, token):
        """Return True if ``token`` may not appear at the top level of a filter."""
        if (
            isinstance(token, Comparison)
            or token.is_whitespace
            or token.match(ttype=TokenType.Keyword, values=["AND"])
        ):
            return False
        return True
|
1332
|
+
|
1333
|
+
|
1334
|
+
class SearchModelVersionUtils(SearchUtils):
|
1335
|
+
NUMERIC_ATTRIBUTES = {"version_number", "creation_timestamp", "last_updated_timestamp"}
|
1336
|
+
VALID_SEARCH_ATTRIBUTE_KEYS = {
|
1337
|
+
"name",
|
1338
|
+
"version_number",
|
1339
|
+
"run_id",
|
1340
|
+
"source_path",
|
1341
|
+
}
|
1342
|
+
VALID_ORDER_BY_ATTRIBUTE_KEYS = {
|
1343
|
+
"name",
|
1344
|
+
"version_number",
|
1345
|
+
"creation_timestamp",
|
1346
|
+
"last_updated_timestamp",
|
1347
|
+
}
|
1348
|
+
VALID_STRING_ATTRIBUTE_COMPARATORS = {"!=", "=", "LIKE", "ILIKE", "IN"}
|
1349
|
+
|
1350
|
+
@classmethod
|
1351
|
+
def _does_model_version_match_clauses(cls, mv, sed):
|
1352
|
+
key_type = sed.get("type")
|
1353
|
+
key = sed.get("key")
|
1354
|
+
value = sed.get("value")
|
1355
|
+
comparator = sed.get("comparator").upper()
|
1356
|
+
|
1357
|
+
if cls.is_string_attribute(key_type, key, comparator):
|
1358
|
+
lhs = getattr(mv, "source" if key == "source_path" else key)
|
1359
|
+
elif cls.is_numeric_attribute(key_type, key, comparator):
|
1360
|
+
if key == "version_number":
|
1361
|
+
key = "version"
|
1362
|
+
lhs = getattr(mv, key)
|
1363
|
+
value = int(value)
|
1364
|
+
elif cls.is_tag(key_type, comparator):
|
1365
|
+
lhs = mv.tags.get(key, None)
|
1366
|
+
else:
|
1367
|
+
raise MlflowException(
|
1368
|
+
f"Invalid search expression type '{key_type}'", error_code=INVALID_PARAMETER_VALUE
|
1369
|
+
)
|
1370
|
+
|
1371
|
+
# NB: Handling the special `mlflow.prompt.is_prompt` tag. This tag is used for
|
1372
|
+
# distinguishing between prompt models and normal models. For example, we want to
|
1373
|
+
# search for models only by the following filter string:
|
1374
|
+
#
|
1375
|
+
# tags.`mlflow.prompt.is_prompt` != 'true'
|
1376
|
+
# tags.`mlflow.prompt.is_prompt` = 'false'
|
1377
|
+
#
|
1378
|
+
# However, models do not have this tag, so lhs is None in this case. Instead of returning
|
1379
|
+
# False like normal tag filter, we need to return True here.
|
1380
|
+
if key == IS_PROMPT_TAG_KEY and lhs is None:
|
1381
|
+
return (comparator == "=" and value == "false") or (
|
1382
|
+
comparator == "!=" and value == "true"
|
1383
|
+
)
|
1384
|
+
|
1385
|
+
if lhs is None:
|
1386
|
+
return False
|
1387
|
+
|
1388
|
+
if comparator == "IN" and isinstance(value, (set, list)):
|
1389
|
+
return lhs in set(value)
|
1390
|
+
|
1391
|
+
return SearchUtils.get_comparison_func(comparator)(lhs, value)
|
1392
|
+
|
1393
|
+
@classmethod
|
1394
|
+
def filter(cls, model_versions, filter_string):
|
1395
|
+
"""Filters a set of model versions based on a search filter string."""
|
1396
|
+
model_versions = [mv for mv in model_versions if mv.current_stage != STAGE_DELETED_INTERNAL]
|
1397
|
+
if not filter_string:
|
1398
|
+
return model_versions
|
1399
|
+
parsed = cls.parse_search_filter(filter_string)
|
1400
|
+
|
1401
|
+
def model_version_matches(mv):
|
1402
|
+
return all(cls._does_model_version_match_clauses(mv, s) for s in parsed)
|
1403
|
+
|
1404
|
+
return [mv for mv in model_versions if model_version_matches(mv)]
|
1405
|
+
|
1406
|
+
@classmethod
|
1407
|
+
def parse_order_by_for_search_model_versions(cls, order_by):
|
1408
|
+
token_value, is_ascending = cls._parse_order_by_string(order_by)
|
1409
|
+
identifier = SearchExperimentsUtils._get_identifier(
|
1410
|
+
token_value.strip(), cls.VALID_ORDER_BY_ATTRIBUTE_KEYS
|
1411
|
+
)
|
1412
|
+
return identifier["type"], identifier["key"], is_ascending
|
1413
|
+
|
1414
|
+
@classmethod
|
1415
|
+
def _get_sort_key(cls, order_by_list):
|
1416
|
+
order_by = []
|
1417
|
+
parsed_order_by = map(cls.parse_order_by_for_search_model_versions, order_by_list or [])
|
1418
|
+
for type_, key, ascending in parsed_order_by:
|
1419
|
+
if type_ == "attribute":
|
1420
|
+
# Need to add this mapping because version is a keyword in sql
|
1421
|
+
if key == "version_number":
|
1422
|
+
key = "version"
|
1423
|
+
order_by.append((key, ascending))
|
1424
|
+
else:
|
1425
|
+
raise MlflowException.invalid_parameter_value(f"Invalid order_by entity: {type_}")
|
1426
|
+
|
1427
|
+
# Add a tie-breaker
|
1428
|
+
if not any(key == "name" for key, _ in order_by):
|
1429
|
+
order_by.append(("name", True))
|
1430
|
+
if not any(key == "version_number" for key, _ in order_by):
|
1431
|
+
order_by.append(("version", False))
|
1432
|
+
|
1433
|
+
return lambda model_version: tuple(
|
1434
|
+
_apply_reversor(model_version, k, asc) for (k, asc) in order_by
|
1435
|
+
)
|
1436
|
+
|
1437
|
+
@classmethod
|
1438
|
+
def sort(cls, model_versions, order_by_list):
|
1439
|
+
return sorted(model_versions, key=cls._get_sort_key(order_by_list))
|
1440
|
+
|
1441
|
+
@classmethod
|
1442
|
+
def _get_model_version_search_identifier(cls, identifier, valid_attributes):
|
1443
|
+
tokens = identifier.split(".", maxsplit=1)
|
1444
|
+
if len(tokens) == 1:
|
1445
|
+
key = tokens[0]
|
1446
|
+
identifier = cls._ATTRIBUTE_IDENTIFIER
|
1447
|
+
else:
|
1448
|
+
entity_type, key = tokens
|
1449
|
+
valid_entity_types = ("attribute", "tag", "tags")
|
1450
|
+
if entity_type not in valid_entity_types:
|
1451
|
+
raise MlflowException.invalid_parameter_value(
|
1452
|
+
f"Invalid entity type '{entity_type}'. "
|
1453
|
+
f"Valid entity types are {valid_entity_types}"
|
1454
|
+
)
|
1455
|
+
identifier = (
|
1456
|
+
cls._TAG_IDENTIFIER if entity_type in ("tag", "tags") else cls._ATTRIBUTE_IDENTIFIER
|
1457
|
+
)
|
1458
|
+
|
1459
|
+
if identifier == cls._ATTRIBUTE_IDENTIFIER and key not in valid_attributes:
|
1460
|
+
raise MlflowException.invalid_parameter_value(
|
1461
|
+
f"Invalid attribute key '{key}' specified. Valid keys are '{valid_attributes}'"
|
1462
|
+
)
|
1463
|
+
|
1464
|
+
key = cls._trim_backticks(cls._strip_quotes(key))
|
1465
|
+
return {"type": identifier, "key": key}
|
1466
|
+
|
1467
|
+
@classmethod
|
1468
|
+
def _get_comparison(cls, comparison):
|
1469
|
+
stripped_comparison = [token for token in comparison.tokens if not token.is_whitespace]
|
1470
|
+
cls._validate_comparison(stripped_comparison)
|
1471
|
+
left, comparator, right = stripped_comparison
|
1472
|
+
comp = cls._get_model_version_search_identifier(left.value, cls.VALID_SEARCH_ATTRIBUTE_KEYS)
|
1473
|
+
comp["comparator"] = comparator.value.upper()
|
1474
|
+
comp["value"] = cls._get_value(comp.get("type"), comp.get("key"), right)
|
1475
|
+
return comp
|
1476
|
+
|
1477
|
+
@classmethod
|
1478
|
+
def _get_value(cls, identifier_type, key, token):
|
1479
|
+
if identifier_type == cls._TAG_IDENTIFIER:
|
1480
|
+
if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier):
|
1481
|
+
return cls._strip_quotes(token.value, expect_quoted_value=True)
|
1482
|
+
raise MlflowException(
|
1483
|
+
"Expected a quoted string value for "
|
1484
|
+
f"{identifier_type} (e.g. 'my-value'). Got value "
|
1485
|
+
f"{token.value}",
|
1486
|
+
error_code=INVALID_PARAMETER_VALUE,
|
1487
|
+
)
|
1488
|
+
elif identifier_type == cls._ATTRIBUTE_IDENTIFIER:
|
1489
|
+
if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier):
|
1490
|
+
return cls._strip_quotes(token.value, expect_quoted_value=True)
|
1491
|
+
elif isinstance(token, Parenthesis):
|
1492
|
+
if key != "run_id":
|
1493
|
+
raise MlflowException(
|
1494
|
+
"Only the 'run_id' attribute supports comparison with a list of quoted "
|
1495
|
+
"string values.",
|
1496
|
+
error_code=INVALID_PARAMETER_VALUE,
|
1497
|
+
)
|
1498
|
+
return cls._parse_run_ids(token)
|
1499
|
+
elif token.ttype in cls.NUMERIC_VALUE_TYPES:
|
1500
|
+
if key not in cls.NUMERIC_ATTRIBUTES:
|
1501
|
+
raise MlflowException(
|
1502
|
+
f"Only the '{cls.NUMERIC_ATTRIBUTES}' attributes support comparison with "
|
1503
|
+
"numeric values.",
|
1504
|
+
error_code=INVALID_PARAMETER_VALUE,
|
1505
|
+
)
|
1506
|
+
if token.ttype == TokenType.Literal.Number.Integer:
|
1507
|
+
return int(token.value)
|
1508
|
+
elif token.ttype == TokenType.Literal.Number.Float:
|
1509
|
+
return float(token.value)
|
1510
|
+
else:
|
1511
|
+
raise MlflowException(
|
1512
|
+
"Expected a quoted string value or a list of quoted string values for "
|
1513
|
+
f"attributes. Got value {token.value}",
|
1514
|
+
error_code=INVALID_PARAMETER_VALUE,
|
1515
|
+
)
|
1516
|
+
else:
|
1517
|
+
# Expected to be either "param" or "metric".
|
1518
|
+
raise MlflowException(
|
1519
|
+
"Invalid identifier type. Expected one of "
|
1520
|
+
f"{[cls._ATTRIBUTE_IDENTIFIER, cls._TAG_IDENTIFIER]}.",
|
1521
|
+
error_code=INVALID_PARAMETER_VALUE,
|
1522
|
+
)
|
1523
|
+
|
1524
|
+
@classmethod
def _process_statement(cls, statement):
    """Validate a parsed model-version filter statement and extract its comparisons.

    Only comparisons, whitespace and the ``AND`` keyword may appear at the top level
    of the statement; any other token is reported back to the caller.

    Returns:
        A list with one parsed comparison (from ``_get_comparison``) per
        ``Comparison`` token, in statement order.

    Raises:
        MlflowException: INVALID_PARAMETER_VALUE listing every disallowed clause.
    """
    tokens = _join_in_comparison_tokens(statement.tokens)
    bad_tokens = [
        tok for tok in tokens if cls._invalid_statement_token_search_model_version(tok)
    ]
    if bad_tokens:
        joined = ", ".join(str(tok) for tok in bad_tokens)
        raise MlflowException.invalid_parameter_value(
            f"Invalid clause(s) in filter string: {joined}"
        )
    return [cls._get_comparison(tok) for tok in tokens if isinstance(tok, Comparison)]
|
1533
|
+
return [cls._get_comparison(t) for t in tokens if isinstance(t, Comparison)]
|
1534
|
+
|
1535
|
+
@classmethod
def _invalid_statement_token_search_model_version(cls, token):
    """Return True when ``token`` is not allowed at the top level of a filter.

    Permitted tokens are comparisons, whitespace and the ``AND`` keyword; everything
    else (e.g. ``OR``) is considered invalid.
    """
    is_allowed = (
        isinstance(token, Comparison)
        or token.is_whitespace
        or token.match(ttype=TokenType.Keyword, values=["AND"])
    )
    return not is_allowed
|
1544
|
+
|
1545
|
+
@classmethod
def parse_search_filter(cls, filter_string):
    """Parse a search filter string into a list of comparison dicts.

    An empty (or ``None``) filter yields an empty list. Otherwise the string is
    tokenized with ``sqlparse`` and must form exactly one statement, which is then
    validated and decomposed by ``_process_statement``.

    Raises:
        MlflowException: INVALID_PARAMETER_VALUE if ``sqlparse`` fails, nothing
            parseable is produced, or more than one statement is present.
    """
    if not filter_string:
        return []

    try:
        statements = sqlparse.parse(filter_string)
    except Exception:
        raise MlflowException(
            f"Error on parsing filter '{filter_string}'", error_code=INVALID_PARAMETER_VALUE
        )

    if not statements or not isinstance(statements[0], Statement):
        raise MlflowException(
            f"Invalid filter '{filter_string}'. Could not be parsed.",
            error_code=INVALID_PARAMETER_VALUE,
        )
    if len(statements) > 1:
        raise MlflowException(
            f"Search filter contained multiple expression {filter_string!r}. "
            "Provide AND-ed expression list.",
            error_code=INVALID_PARAMETER_VALUE,
        )

    (only_statement,) = statements
    return cls._process_statement(only_statement)
|
1567
|
+
|
1568
|
+
|
1569
|
+
class SearchTraceUtils(SearchUtils):
    """
    Utility class for searching traces.

    Parses trace filter / order-by strings into clause dicts (with ``type``, ``key``,
    ``comparator`` and ``value`` entries) and applies them to in-memory collections
    of traces via :meth:`filter` and :meth:`sort`.
    """

    VALID_SEARCH_ATTRIBUTE_KEYS = {
        "request_id",
        "timestamp",
        "timestamp_ms",
        "execution_time",
        "execution_time_ms",
        "status",
        # The following keys are mapped to tags or metadata
        "name",
        "run_id",
    }
    VALID_ORDER_BY_ATTRIBUTE_KEYS = {
        "experiment_id",
        "timestamp",
        "timestamp_ms",
        "execution_time",
        "execution_time_ms",
        "status",
        "request_id",
        # The following keys are mapped to tags or metadata
        "name",
        "run_id",
    }

    NUMERIC_ATTRIBUTES = {
        "timestamp_ms",
        "timestamp",
        "execution_time_ms",
        "execution_time",
    }

    # For now, don't support LIKE/ILIKE operators for trace search because it may
    # cause performance issues with large attributes and tags. We can revisit this
    # decision if we find a way to support them efficiently.
    VALID_TAG_COMPARATORS = {"!=", "="}
    VALID_STRING_ATTRIBUTE_COMPARATORS = {"!=", "=", "IN", "NOT IN"}

    _REQUEST_METADATA_IDENTIFIER = "request_metadata"
    _TAG_IDENTIFIER = "tag"
    _ATTRIBUTE_IDENTIFIER = "attribute"

    # These are aliases for the base identifiers
    # e.g. trace.status is equivalent to attribute.status
    _ALTERNATE_IDENTIFIERS = {
        "tags": _TAG_IDENTIFIER,
        "attributes": _ATTRIBUTE_IDENTIFIER,
        "trace": _ATTRIBUTE_IDENTIFIER,
        "metadata": _REQUEST_METADATA_IDENTIFIER,
    }
    _IDENTIFIERS = {_TAG_IDENTIFIER, _REQUEST_METADATA_IDENTIFIER, _ATTRIBUTE_IDENTIFIER}
    _VALID_IDENTIFIERS = _IDENTIFIERS | set(_ALTERNATE_IDENTIFIERS.keys())

    SUPPORT_IN_COMPARISON_ATTRIBUTE_KEYS = {"name", "status", "request_id", "run_id"}

    # Some search keys are defined differently in the DB models.
    # E.g. "name" is mapped to TraceTagKey.TRACE_NAME
    SEARCH_KEY_TO_TAG = {
        "name": TraceTagKey.TRACE_NAME,
    }
    SEARCH_KEY_TO_METADATA = {
        "run_id": TraceMetadataKey.SOURCE_RUN,
    }
    # Alias for attribute keys
    SEARCH_KEY_TO_ATTRIBUTE = {
        "timestamp": "timestamp_ms",
        "execution_time": "execution_time_ms",
    }

    @classmethod
    def filter(cls, traces, filter_string):
        """Filters a set of traces based on a search filter string.

        Returns ``traces`` unchanged when ``filter_string`` is empty; otherwise a new
        list containing only the traces that satisfy every (AND-ed) clause.
        """
        if not filter_string:
            return traces
        parsed = cls.parse_search_filter_for_search_traces(filter_string)

        def trace_matches(trace):
            return all(cls._does_trace_match_clause(trace, s) for s in parsed)

        # ``filter`` here is the builtin: class attributes are not in scope inside a
        # method body, so this does not refer back to this classmethod.
        return list(filter(trace_matches, traces))

    @classmethod
    def _does_trace_match_clause(cls, trace, sed):
        """Evaluate one parsed clause ``sed`` against ``trace``.

        ``sed`` is a dict with ``type``, ``key``, ``value`` and ``comparator``
        entries. Returns False when the trace has no value (``None``) for the
        referenced tag, request-metadata key or attribute.

        Raises:
            MlflowException: if the clause does not refer to a tag, request metadata
                or a supported attribute, or if the type check rejects the
                comparator (as ``is_request_metadata`` does).
        """
        type_ = sed.get("type")
        key = sed.get("key")
        value = sed.get("value")
        comparator = sed.get("comparator").upper()

        if cls.is_tag(type_, comparator):
            lhs = trace.tags.get(key)
        elif cls.is_request_metadata(type_, comparator):
            lhs = trace.request_metadata.get(key)
        elif cls.is_attribute(type_, key, comparator):
            lhs = getattr(trace, key)
        elif sed.get("type") == cls._TAG_IDENTIFIER:
            # NOTE(review): probably unreachable if ``is_tag`` (defined on the base
            # class) handles every tag-typed clause like ``is_request_metadata``
            # does; kept as a fallback.
            lhs = trace.tags.get(key)
        else:
            raise MlflowException(
                f"Invalid search key '{key}', supported are {cls.VALID_SEARCH_ATTRIBUTE_KEYS}",
                error_code=INVALID_PARAMETER_VALUE,
            )
        if lhs is None:
            return False

        return SearchUtils.get_comparison_func(comparator)(lhs, value)

    @classmethod
    def sort(cls, traces, order_by_list):
        """Return a new list of ``traces`` sorted by the given order-by strings."""
        return sorted(traces, key=cls._get_sort_key(order_by_list))

    @classmethod
    def parse_order_by_for_search_traces(cls, order_by):
        """Parse one order-by string into ``(entity_type, key, is_ascending)``.

        Attribute shortcuts (e.g. ``timestamp``) are resolved through
        ``_replace_key_to_tag_or_metadata``.
        """
        token_value, is_ascending = cls._parse_order_by_string(order_by)
        identifier = cls._get_identifier(token_value.strip(), cls.VALID_ORDER_BY_ATTRIBUTE_KEYS)
        identifier = cls._replace_key_to_tag_or_metadata(identifier)
        return identifier["type"], identifier["key"], is_ascending

    @classmethod
    def parse_search_filter_for_search_traces(cls, filter_string):
        """Parse a trace filter string and resolve attribute shortcut keys."""
        parsed = cls.parse_search_filter(filter_string)
        return [cls._replace_key_to_tag_or_metadata(p) for p in parsed]

    @classmethod
    def _replace_key_to_tag_or_metadata(cls, parsed: dict[str, Any]):
        """
        Replace search key to tag or metadata key if it is in the mapping.

        The shortcuts (``name``, ``run_id``, ``timestamp``, ``execution_time``) are
        attribute aliases, so only attribute-typed clauses are rewritten. An explicit
        ``tags.name`` or ``tags.run_id`` must keep referring to the user's own tag
        instead of being redirected to ``TraceTagKey.TRACE_NAME`` or
        ``TraceMetadataKey.SOURCE_RUN``. ``parsed`` is modified in place and returned.
        """
        if parsed.get("type") != cls._ATTRIBUTE_IDENTIFIER:
            return parsed

        key = parsed.get("key").lower()
        if key in cls.SEARCH_KEY_TO_TAG:
            parsed["type"] = cls._TAG_IDENTIFIER
            parsed["key"] = cls.SEARCH_KEY_TO_TAG[key]
        elif key in cls.SEARCH_KEY_TO_METADATA:
            parsed["type"] = cls._REQUEST_METADATA_IDENTIFIER
            parsed["key"] = cls.SEARCH_KEY_TO_METADATA[key]
        elif key in cls.SEARCH_KEY_TO_ATTRIBUTE:
            parsed["key"] = cls.SEARCH_KEY_TO_ATTRIBUTE[key]
        return parsed

    @classmethod
    def is_request_metadata(cls, key_type, comparator):
        """Return True for request-metadata clauses; False for any other type.

        Raises:
            MlflowException: if a request-metadata clause uses a comparator outside
                ``VALID_TAG_COMPARATORS``.
        """
        if key_type == cls._REQUEST_METADATA_IDENTIFIER:
            # Request metadata accepts the same set of comparators as tags
            if comparator not in cls.VALID_TAG_COMPARATORS:
                raise MlflowException(
                    f"Invalid comparator '{comparator}' not one of '{cls.VALID_TAG_COMPARATORS}'",
                    error_code=INVALID_PARAMETER_VALUE,
                )
            return True
        return False

    @classmethod
    def _valid_entity_type(cls, entity_type):
        """Validate an entity prefix and map aliases (e.g. ``tags``) to base identifiers.

        Raises:
            MlflowException: if the (backtick-trimmed) prefix is not recognized.
        """
        entity_type = cls._trim_backticks(entity_type)
        if entity_type not in cls._VALID_IDENTIFIERS:
            raise MlflowException(
                f"Invalid entity type '{entity_type}'. Valid values are {cls._VALID_IDENTIFIERS}",
                error_code=INVALID_PARAMETER_VALUE,
            )
        elif entity_type in cls._ALTERNATE_IDENTIFIERS:
            return cls._ALTERNATE_IDENTIFIERS[entity_type]
        else:
            return entity_type

    @classmethod
    def _get_sort_key(cls, order_by_list):
        """Build a ``sorted`` key function for traces from order-by strings.

        Only attribute ordering is supported. ``timestamp_ms`` (descending) and
        ``request_id`` (ascending) are appended as tie-breakers when not already
        present.

        Raises:
            MlflowException: if an order-by clause targets a non-attribute entity.
        """
        order_by = []
        parsed_order_by = map(cls.parse_order_by_for_search_traces, order_by_list or [])
        for type_, key, ascending in parsed_order_by:
            if type_ == cls._ATTRIBUTE_IDENTIFIER:
                order_by.append((key, ascending))
            else:
                raise MlflowException.invalid_parameter_value(
                    f"Invalid order_by entity `{type_}` with key `{key}`"
                )

        # Add a tie-breaker
        if not any(key == "timestamp_ms" for key, _ in order_by):
            order_by.append(("timestamp_ms", False))
        if not any(key == "request_id" for key, _ in order_by):
            order_by.append(("request_id", True))

        return lambda trace: tuple(_apply_reversor(trace, k, asc) for (k, asc) in order_by)

    @classmethod
    def _get_value(cls, identifier_type, key, token):
        """Convert the right-hand-side token of a comparison into a Python value.

        Quoted strings / identifiers become unquoted strings; parenthesized lists
        become lists (attributes only if ``key`` supports ``IN``); numeric literals
        become ``int``/``float`` (numeric attributes only).

        Raises:
            MlflowException: if the token kind is not allowed for the identifier type.
        """
        if identifier_type == cls._TAG_IDENTIFIER:
            if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier):
                return cls._strip_quotes(token.value, expect_quoted_value=True)
            elif isinstance(token, Parenthesis):
                return cls._parse_attribute_lists(token)
            raise MlflowException(
                "Expected a quoted string value for "
                f"{identifier_type} (e.g. 'my-value'). Got value "
                f"{token.value}",
                error_code=INVALID_PARAMETER_VALUE,
            )
        elif identifier_type == cls._ATTRIBUTE_IDENTIFIER:
            if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier):
                return cls._strip_quotes(token.value, expect_quoted_value=True)
            elif isinstance(token, Parenthesis):
                if key not in cls.SUPPORT_IN_COMPARISON_ATTRIBUTE_KEYS:
                    raise MlflowException(
                        f"Only attributes in {cls.SUPPORT_IN_COMPARISON_ATTRIBUTE_KEYS} "
                        "supports comparison with a list of quoted string values.",
                        error_code=INVALID_PARAMETER_VALUE,
                    )
                return cls._parse_attribute_lists(token)
            elif token.ttype in cls.NUMERIC_VALUE_TYPES:
                if key not in cls.NUMERIC_ATTRIBUTES:
                    raise MlflowException(
                        f"Only the '{cls.NUMERIC_ATTRIBUTES}' attributes support comparison with "
                        "numeric values.",
                        error_code=INVALID_PARAMETER_VALUE,
                    )
                if token.ttype == TokenType.Literal.Number.Integer:
                    return int(token.value)
                elif token.ttype == TokenType.Literal.Number.Float:
                    return float(token.value)
            else:
                raise MlflowException(
                    "Expected a quoted string value or a list of quoted string values for "
                    f"attributes. Got value {token.value}",
                    error_code=INVALID_PARAMETER_VALUE,
                )
        elif identifier_type == cls._REQUEST_METADATA_IDENTIFIER:
            if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier):
                return cls._strip_quotes(token.value, expect_quoted_value=True)
            else:
                raise MlflowException(
                    "Expected a quoted string value for "
                    f"{identifier_type} (e.g. 'my-value'). Got value "
                    f"{token.value}",
                    error_code=INVALID_PARAMETER_VALUE,
                )
        else:
            # Any identifier type other than tag / attribute / request metadata
            # is unsupported for trace search.
            raise MlflowException(
                f"Invalid identifier type: {identifier_type}. "
                f"Expected one of {cls._VALID_IDENTIFIERS}.",
                error_code=INVALID_PARAMETER_VALUE,
            )

    @classmethod
    def _parse_attribute_lists(cls, token):
        """Parse a parenthesized SQL list token into a Python list."""
        return cls._parse_list_from_sql_token(token)

    @classmethod
    def _process_statement(cls, statement):
        """Validate a parsed trace filter statement and extract its comparisons.

        Raises:
            MlflowException: INVALID_PARAMETER_VALUE listing every token that is not
                a comparison, whitespace or ``AND``.
        """
        # check validity
        tokens = _join_in_comparison_tokens(statement.tokens, search_traces=True)
        invalids = list(filter(cls._invalid_statement_token_search_traces, tokens))
        if len(invalids) > 0:
            invalid_clauses = ", ".join(f"'{token}'" for token in invalids)
            raise MlflowException(
                f"Invalid clause(s) in filter string: {invalid_clauses}",
                error_code=INVALID_PARAMETER_VALUE,
            )
        return [cls._get_comparison(si) for si in tokens if isinstance(si, Comparison)]

    @classmethod
    def _invalid_statement_token_search_traces(cls, token):
        """Return True if ``token`` is not a comparison, whitespace or ``AND``."""
        if (
            isinstance(token, Comparison)
            or token.is_whitespace
            or token.match(ttype=TokenType.Keyword, values=["AND"])
        ):
            return False
        return True

    @classmethod
    def _get_comparison(cls, comparison):
        """Turn a ``Comparison`` token into a clause dict.

        The dict holds ``type`` and ``key`` (from ``_get_identifier``) plus
        ``comparator`` and ``value``.
        """
        stripped_comparison = [token for token in comparison.tokens if not token.is_whitespace]
        cls._validate_comparison(stripped_comparison, search_traces=True)
        comp = cls._get_identifier(stripped_comparison[0].value, cls.VALID_SEARCH_ATTRIBUTE_KEYS)
        comp["comparator"] = stripped_comparison[1].value
        comp["value"] = cls._get_value(comp.get("type"), comp.get("key"), stripped_comparison[2])
        return comp
|
1850
|
+
|
1851
|
+
|
1852
|
+
class SearchLoggedModelsUtils(SearchUtils):
    """Helpers for filtering and sorting ``LoggedModel`` entities in memory."""

    NUMERIC_ATTRIBUTES = {
        "creation_timestamp",
        "creation_time",
        "last_updated_timestamp",
        "last_updated_time",
    }
    VALID_SEARCH_ATTRIBUTE_KEYS = {
        "name",
        "model_id",
        "model_type",
        "status",
        "source_run_id",
    } | NUMERIC_ATTRIBUTES
    VALID_ORDER_BY_ATTRIBUTE_KEYS = VALID_SEARCH_ATTRIBUTE_KEYS

    # Short attribute names accepted in filters and order-by clauses (they appear in
    # NUMERIC_ATTRIBUTES above). They are mapped to the ``*_timestamp`` names before
    # the value is read from the model with ``getattr``.
    _ATTRIBUTE_ALIASES = {
        "creation_time": "creation_timestamp",
        "last_updated_time": "last_updated_timestamp",
    }

    @classmethod
    def _does_logged_model_match_clause(
        cls,
        model: LoggedModel,
        condition: dict[str, Any],
        datasets: Optional[list[dict[str, Any]]] = None,
    ):
        """Evaluate one parsed filter clause against ``model``.

        Metric clauses use the first of the model's metrics with the given key,
        restricted to metrics logged on one of ``datasets`` when provided. Returns
        False if the referenced metric / param / tag / attribute is missing.

        Raises:
            MlflowException: if ``key`` is not a known logged-model field.
        """
        key_type = condition.get("type")
        key = condition.get("key")
        value = condition.get("value")
        comparator = condition.get("comparator").upper()

        key = SearchUtils.translate_key_alias(key)
        if key_type == cls._ATTRIBUTE_IDENTIFIER:
            # Resolve e.g. ``creation_time`` to ``creation_timestamp`` so the
            # ``getattr`` calls below read an attribute the model actually has.
            key = cls._ATTRIBUTE_ALIASES.get(key, key)

        if cls.is_metric(key_type, comparator):
            matching_metrics = [metric for metric in model.metrics if metric.key == key]
            if datasets:
                matching_metrics = [
                    metric
                    for metric in matching_metrics
                    if any(cls._is_metric_on_dataset(metric, dataset) for dataset in datasets)
                ]
            lhs = matching_metrics[0].value if matching_metrics else None
            value = float(value)
        elif cls.is_param(key_type, comparator):
            lhs = model.params.get(key, None)
        elif cls.is_tag(key_type, comparator):
            lhs = model.tags.get(key, None)
        elif cls.is_numeric_attribute(key_type, key, comparator):
            lhs = getattr(model, key)
            value = int(value)
        elif hasattr(model, key):
            lhs = getattr(model, key)
        else:
            raise MlflowException.invalid_parameter_value(
                f"Invalid logged model search key '{key}'",
            )
        if lhs is None:
            return False

        return SearchUtils.get_comparison_func(comparator)(lhs, value)

    @classmethod
    def validate_list_supported(cls, key: str) -> None:
        """
        Override to allow logged model attributes to be used with IN/NOT IN.
        """

    @classmethod
    def filter_logged_models(
        cls,
        models: list[LoggedModel],
        filter_string: Optional[str] = None,
        datasets: Optional[list[dict[str, Any]]] = None,
    ):
        """Filters a set of logged models based on a search filter string and list of
        dataset filters.

        When ``datasets`` is given but the filter has no metric clause, only models
        with at least one metric on one of the datasets are kept. Returns ``models``
        unchanged when neither a filter string nor datasets are provided.
        """
        if not filter_string and not datasets:
            return models

        parsed = cls.parse_search_filter(filter_string)

        # If there are dataset filters but no metric filters in the filter string,
        # filter for models that have any metrics on the datasets
        if datasets and not any(
            cls.is_metric(s.get("type"), s.get("comparator").upper()) for s in parsed
        ):

            def model_has_metrics_on_datasets(model):
                return any(
                    any(cls._is_metric_on_dataset(metric, dataset) for dataset in datasets)
                    for metric in model.metrics
                )

            models = [model for model in models if model_has_metrics_on_datasets(model)]

        def model_matches(model):
            return all(cls._does_logged_model_match_clause(model, s, datasets) for s in parsed)

        return [model for model in models if model_matches(model)]

    @dataclass
    class OrderBy:
        # Attribute name, or ``metrics.<name>`` for metric ordering.
        field_name: str
        ascending: bool = True
        # Optional dataset restriction for metric ordering.
        dataset_name: Optional[str] = None
        dataset_digest: Optional[str] = None

    @classmethod
    def parse_order_by_for_logged_models(cls, order_by: dict[str, Any]) -> OrderBy:
        """Validate one order-by dict and convert it into an ``OrderBy``.

        Attribute aliases (``creation_time``, ``last_updated_time``) are resolved to
        the attribute names read from ``LoggedModel``.

        Raises:
            MlflowException: on a non-dict entry, a missing or invalid ``field_name``,
                a non-boolean ``ascending``, or ``dataset_digest`` without
                ``dataset_name``.
        """
        if not isinstance(order_by, dict):
            raise MlflowException.invalid_parameter_value(
                "`order_by` must be a list of dictionaries."
            )
        field_name = order_by.get("field_name")
        if field_name is None:
            raise MlflowException.invalid_parameter_value(
                "`field_name` in the `order_by` clause must be specified."
            )
        if "." in field_name:
            entity = field_name.split(".", 1)[0]
            if entity != "metrics":
                raise MlflowException.invalid_parameter_value(
                    f"Invalid order by field name: {entity}, only `metrics.<name>` is allowed."
                )
        else:
            field_name = field_name.strip()
            if field_name not in cls.VALID_ORDER_BY_ATTRIBUTE_KEYS:
                raise MlflowException.invalid_parameter_value(
                    f"Invalid order by field name: {field_name}."
                )
        ascending = order_by.get("ascending", True)
        if ascending not in [True, False]:
            raise MlflowException.invalid_parameter_value(
                "Value of `ascending` in the `order_by` clause must be a boolean, got "
                f"{type(ascending)} for field {field_name}."
            )
        dataset_name = order_by.get("dataset_name")
        dataset_digest = order_by.get("dataset_digest")
        if dataset_digest and not dataset_name:
            raise MlflowException.invalid_parameter_value(
                "`dataset_digest` can only be specified if `dataset_name` is also specified."
            )

        return cls.OrderBy(
            cls._ATTRIBUTE_ALIASES.get(field_name, field_name),
            ascending,
            dataset_name,
            dataset_digest,
        )

    @classmethod
    def _apply_reversor_for_logged_model(
        cls,
        model: LoggedModel,
        order_by: OrderBy,
    ):
        """Compute the per-clause sort key for ``model``.

        For ``metrics.<name>`` the value of the most recent (by timestamp) matching
        metric is used, optionally restricted to the requested dataset; in ascending
        order a missing metric sorts last. Attributes are read with ``getattr``.
        """
        if "." in order_by.field_name:
            metric_key = order_by.field_name.split(".", 1)[1]
            filtered_metrics = sorted(
                [
                    m
                    for m in model.metrics
                    if m.key == metric_key
                    and (not order_by.dataset_name or m.dataset_name == order_by.dataset_name)
                    and (not order_by.dataset_digest or m.dataset_digest == order_by.dataset_digest)
                ],
                key=lambda metric: metric.timestamp,
                reverse=True,
            )
            latest_metric_value = None if len(filtered_metrics) == 0 else filtered_metrics[0].value
            return (
                _LoggedModelMetricComp(latest_metric_value)
                if order_by.ascending
                else _Reversor(latest_metric_value)
            )
        else:
            value = getattr(model, order_by.field_name)
            return value if order_by.ascending else _Reversor(value)

    @classmethod
    def _get_sort_key(cls, order_by_list: Optional[list[dict[str, Any]]]):
        """Build a ``sorted`` key function from order-by dicts.

        ``creation_timestamp`` (descending) and ``model_id`` (ascending) are
        appended as tie-breakers when not already present.
        """
        parsed_order_by = list(map(cls.parse_order_by_for_logged_models, order_by_list or []))

        # Add a tie-breaker
        if not any(order_by.field_name == "creation_timestamp" for order_by in parsed_order_by):
            parsed_order_by.append(cls.OrderBy("creation_timestamp", False))
        if not any(order_by.field_name == "model_id" for order_by in parsed_order_by):
            parsed_order_by.append(cls.OrderBy("model_id"))

        return lambda logged_model: tuple(
            cls._apply_reversor_for_logged_model(logged_model, order_by)
            for order_by in parsed_order_by
        )

    @classmethod
    def sort(cls, models, order_by_list):
        """Return a new list of ``models`` sorted by the given order-by dicts."""
        return sorted(models, key=cls._get_sort_key(order_by_list))
|
2045
|
+
|
2046
|
+
|
2047
|
+
class _LoggedModelMetricComp:
    """Ascending sort-key wrapper for metric values in which ``None`` sorts last.

    A wrapped ``None`` is never "less than" anything. Any present value is "less
    than" a wrapped ``None``. Two present values compare with ``<``.
    """

    def __init__(self, obj):
        # Wrapped metric value, or None when the model has no matching metric.
        self.obj = obj

    def __eq__(self, other):
        return other.obj == self.obj

    def __lt__(self, other):
        mine, theirs = self.obj, other.obj
        if mine is None:
            # A missing value never precedes anything.
            return False
        # A present value precedes a missing one; otherwise use natural order.
        return True if theirs is None else mine < theirs
|
2060
|
+
|
2061
|
+
|
2062
|
+
@dataclass
|
2063
|
+
class SearchLoggedModelsPaginationToken:
|
2064
|
+
experiment_ids: list[str]
|
2065
|
+
filter_string: Optional[str] = None
|
2066
|
+
order_by: Optional[list[dict[str, Any]]] = None
|
2067
|
+
offset: int = 0
|
2068
|
+
|
2069
|
+
def to_json(self) -> str:
    """Serialize every dataclass field of the token into a JSON string."""
    payload = asdict(self)
    return json.dumps(payload)
|
2071
|
+
|
2072
|
+
def encode(self) -> str:
    """Return the token's JSON form (``to_json``) as a standard base64 string."""
    raw_bytes = self.to_json().encode("utf-8")
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
|
2074
|
+
|
2075
|
+
@classmethod
def decode(cls, token: str) -> "SearchLoggedModelsPaginationToken":
    """Rebuild a pagination token from a string produced by ``encode``.

    Empty ``filter_string`` / ``order_by`` values are normalized to ``None``. A
    missing or falsy ``offset`` becomes ``0``.

    Raises:
        MlflowException: INVALID_PARAMETER_VALUE if ``token`` is not valid base64,
            not UTF-8, not JSON, or does not decode to a JSON object.
    """
    try:
        payload = json.loads(base64.b64decode(token.encode("utf-8")).decode("utf-8"))
    except ValueError as e:
        # binascii.Error (malformed base64), UnicodeDecodeError and
        # json.JSONDecodeError are all ValueError subclasses. Catching only the
        # JSON error let corrupted tokens escape as raw, non-Mlflow errors.
        raise MlflowException.invalid_parameter_value(f"Invalid page token: {token}. {e}") from e

    if not isinstance(payload, dict):
        # A token encoding e.g. a JSON list or number would otherwise fail below
        # with an AttributeError on ``.get``.
        raise MlflowException.invalid_parameter_value(
            f"Invalid page token: {token}. Expected an encoded JSON object."
        )

    return cls(
        experiment_ids=payload.get("experiment_ids"),
        filter_string=payload.get("filter_string") or None,
        order_by=payload.get("order_by") or None,
        offset=payload.get("offset") or 0,
    )
|
2088
|
+
|
2089
|
+
def validate(
    self,
    experiment_ids: list[str],
    filter_string: Optional[str],
    order_by: Optional[list[dict[str, Any]]],
) -> None:
    """Ensure this page token was issued for the same search request.

    Raises ``MlflowException`` (INVALID_PARAMETER_VALUE) when the token's experiment
    IDs, filter string or order-by differ from those of the current request.
    Comparison is plain ``!=``.

    NOTE(review): ``decode`` normalizes an empty ``filter_string`` / ``order_by`` to
    ``None``. A caller passing ``""`` or ``[]`` here would therefore be rejected on
    the next page unless it normalizes the same way. Confirm callers pass ``None``.
    """
    # Token must belong to the same set of experiments (order-sensitive list compare).
    if self.experiment_ids != experiment_ids:
        raise MlflowException.invalid_parameter_value(
            f"Experiment IDs in the page token do not match the requested experiment IDs. "
            f"Expected: {experiment_ids}. Found: {self.experiment_ids}"
        )

    if self.filter_string != filter_string:
        raise MlflowException.invalid_parameter_value(
            f"Filter string in the page token does not match the requested filter string. "
            f"Expected: {filter_string}. Found: {self.filter_string}"
        )

    if self.order_by != order_by:
        raise MlflowException.invalid_parameter_value(
            f"Order by in the page token does not match the requested order by. "
            f"Expected: {order_by}. Found: {self.order_by}"
        )
|