genesis-flow 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genesis_flow-1.0.0.dist-info/METADATA +822 -0
- genesis_flow-1.0.0.dist-info/RECORD +645 -0
- genesis_flow-1.0.0.dist-info/WHEEL +5 -0
- genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
- genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
- genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
- mlflow/__init__.py +367 -0
- mlflow/__main__.py +3 -0
- mlflow/ag2/__init__.py +56 -0
- mlflow/ag2/ag2_logger.py +294 -0
- mlflow/anthropic/__init__.py +40 -0
- mlflow/anthropic/autolog.py +129 -0
- mlflow/anthropic/chat.py +144 -0
- mlflow/artifacts/__init__.py +268 -0
- mlflow/autogen/__init__.py +144 -0
- mlflow/autogen/chat.py +142 -0
- mlflow/azure/__init__.py +26 -0
- mlflow/azure/auth_handler.py +257 -0
- mlflow/azure/client.py +319 -0
- mlflow/azure/config.py +120 -0
- mlflow/azure/connection_factory.py +340 -0
- mlflow/azure/exceptions.py +27 -0
- mlflow/azure/stores.py +327 -0
- mlflow/azure/utils.py +183 -0
- mlflow/bedrock/__init__.py +45 -0
- mlflow/bedrock/_autolog.py +202 -0
- mlflow/bedrock/chat.py +122 -0
- mlflow/bedrock/stream.py +160 -0
- mlflow/bedrock/utils.py +43 -0
- mlflow/cli.py +707 -0
- mlflow/client.py +12 -0
- mlflow/config/__init__.py +56 -0
- mlflow/crewai/__init__.py +79 -0
- mlflow/crewai/autolog.py +253 -0
- mlflow/crewai/chat.py +29 -0
- mlflow/data/__init__.py +75 -0
- mlflow/data/artifact_dataset_sources.py +170 -0
- mlflow/data/code_dataset_source.py +40 -0
- mlflow/data/dataset.py +123 -0
- mlflow/data/dataset_registry.py +168 -0
- mlflow/data/dataset_source.py +110 -0
- mlflow/data/dataset_source_registry.py +219 -0
- mlflow/data/delta_dataset_source.py +167 -0
- mlflow/data/digest_utils.py +108 -0
- mlflow/data/evaluation_dataset.py +562 -0
- mlflow/data/filesystem_dataset_source.py +81 -0
- mlflow/data/http_dataset_source.py +145 -0
- mlflow/data/huggingface_dataset.py +258 -0
- mlflow/data/huggingface_dataset_source.py +118 -0
- mlflow/data/meta_dataset.py +104 -0
- mlflow/data/numpy_dataset.py +223 -0
- mlflow/data/pandas_dataset.py +231 -0
- mlflow/data/polars_dataset.py +352 -0
- mlflow/data/pyfunc_dataset_mixin.py +31 -0
- mlflow/data/schema.py +76 -0
- mlflow/data/sources.py +1 -0
- mlflow/data/spark_dataset.py +406 -0
- mlflow/data/spark_dataset_source.py +74 -0
- mlflow/data/spark_delta_utils.py +118 -0
- mlflow/data/tensorflow_dataset.py +350 -0
- mlflow/data/uc_volume_dataset_source.py +81 -0
- mlflow/db.py +27 -0
- mlflow/dspy/__init__.py +17 -0
- mlflow/dspy/autolog.py +197 -0
- mlflow/dspy/callback.py +398 -0
- mlflow/dspy/constant.py +1 -0
- mlflow/dspy/load.py +93 -0
- mlflow/dspy/save.py +393 -0
- mlflow/dspy/util.py +109 -0
- mlflow/dspy/wrapper.py +226 -0
- mlflow/entities/__init__.py +104 -0
- mlflow/entities/_mlflow_object.py +52 -0
- mlflow/entities/assessment.py +545 -0
- mlflow/entities/assessment_error.py +80 -0
- mlflow/entities/assessment_source.py +141 -0
- mlflow/entities/dataset.py +92 -0
- mlflow/entities/dataset_input.py +51 -0
- mlflow/entities/dataset_summary.py +62 -0
- mlflow/entities/document.py +48 -0
- mlflow/entities/experiment.py +109 -0
- mlflow/entities/experiment_tag.py +35 -0
- mlflow/entities/file_info.py +45 -0
- mlflow/entities/input_tag.py +35 -0
- mlflow/entities/lifecycle_stage.py +35 -0
- mlflow/entities/logged_model.py +228 -0
- mlflow/entities/logged_model_input.py +26 -0
- mlflow/entities/logged_model_output.py +32 -0
- mlflow/entities/logged_model_parameter.py +46 -0
- mlflow/entities/logged_model_status.py +74 -0
- mlflow/entities/logged_model_tag.py +33 -0
- mlflow/entities/metric.py +200 -0
- mlflow/entities/model_registry/__init__.py +29 -0
- mlflow/entities/model_registry/_model_registry_entity.py +13 -0
- mlflow/entities/model_registry/model_version.py +243 -0
- mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
- mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
- mlflow/entities/model_registry/model_version_search.py +25 -0
- mlflow/entities/model_registry/model_version_stages.py +25 -0
- mlflow/entities/model_registry/model_version_status.py +35 -0
- mlflow/entities/model_registry/model_version_tag.py +35 -0
- mlflow/entities/model_registry/prompt.py +73 -0
- mlflow/entities/model_registry/prompt_version.py +244 -0
- mlflow/entities/model_registry/registered_model.py +175 -0
- mlflow/entities/model_registry/registered_model_alias.py +35 -0
- mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
- mlflow/entities/model_registry/registered_model_search.py +25 -0
- mlflow/entities/model_registry/registered_model_tag.py +35 -0
- mlflow/entities/multipart_upload.py +74 -0
- mlflow/entities/param.py +49 -0
- mlflow/entities/run.py +97 -0
- mlflow/entities/run_data.py +84 -0
- mlflow/entities/run_info.py +188 -0
- mlflow/entities/run_inputs.py +59 -0
- mlflow/entities/run_outputs.py +43 -0
- mlflow/entities/run_status.py +41 -0
- mlflow/entities/run_tag.py +36 -0
- mlflow/entities/source_type.py +31 -0
- mlflow/entities/span.py +774 -0
- mlflow/entities/span_event.py +96 -0
- mlflow/entities/span_status.py +102 -0
- mlflow/entities/trace.py +317 -0
- mlflow/entities/trace_data.py +71 -0
- mlflow/entities/trace_info.py +220 -0
- mlflow/entities/trace_info_v2.py +162 -0
- mlflow/entities/trace_location.py +173 -0
- mlflow/entities/trace_state.py +39 -0
- mlflow/entities/trace_status.py +68 -0
- mlflow/entities/view_type.py +51 -0
- mlflow/environment_variables.py +866 -0
- mlflow/evaluation/__init__.py +16 -0
- mlflow/evaluation/assessment.py +369 -0
- mlflow/evaluation/evaluation.py +411 -0
- mlflow/evaluation/evaluation_tag.py +61 -0
- mlflow/evaluation/fluent.py +48 -0
- mlflow/evaluation/utils.py +201 -0
- mlflow/exceptions.py +213 -0
- mlflow/experiments.py +140 -0
- mlflow/gemini/__init__.py +81 -0
- mlflow/gemini/autolog.py +186 -0
- mlflow/gemini/chat.py +261 -0
- mlflow/genai/__init__.py +71 -0
- mlflow/genai/datasets/__init__.py +67 -0
- mlflow/genai/datasets/evaluation_dataset.py +131 -0
- mlflow/genai/evaluation/__init__.py +3 -0
- mlflow/genai/evaluation/base.py +411 -0
- mlflow/genai/evaluation/constant.py +23 -0
- mlflow/genai/evaluation/utils.py +244 -0
- mlflow/genai/judges/__init__.py +21 -0
- mlflow/genai/judges/databricks.py +404 -0
- mlflow/genai/label_schemas/__init__.py +153 -0
- mlflow/genai/label_schemas/label_schemas.py +209 -0
- mlflow/genai/labeling/__init__.py +159 -0
- mlflow/genai/labeling/labeling.py +250 -0
- mlflow/genai/optimize/__init__.py +13 -0
- mlflow/genai/optimize/base.py +198 -0
- mlflow/genai/optimize/optimizers/__init__.py +4 -0
- mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
- mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
- mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
- mlflow/genai/optimize/types.py +75 -0
- mlflow/genai/optimize/util.py +30 -0
- mlflow/genai/prompts/__init__.py +206 -0
- mlflow/genai/scheduled_scorers.py +431 -0
- mlflow/genai/scorers/__init__.py +26 -0
- mlflow/genai/scorers/base.py +492 -0
- mlflow/genai/scorers/builtin_scorers.py +765 -0
- mlflow/genai/scorers/scorer_utils.py +138 -0
- mlflow/genai/scorers/validation.py +165 -0
- mlflow/genai/utils/data_validation.py +146 -0
- mlflow/genai/utils/enum_utils.py +23 -0
- mlflow/genai/utils/trace_utils.py +211 -0
- mlflow/groq/__init__.py +42 -0
- mlflow/groq/_groq_autolog.py +74 -0
- mlflow/johnsnowlabs/__init__.py +888 -0
- mlflow/langchain/__init__.py +24 -0
- mlflow/langchain/api_request_parallel_processor.py +330 -0
- mlflow/langchain/autolog.py +147 -0
- mlflow/langchain/chat_agent_langgraph.py +340 -0
- mlflow/langchain/constant.py +1 -0
- mlflow/langchain/constants.py +1 -0
- mlflow/langchain/databricks_dependencies.py +444 -0
- mlflow/langchain/langchain_tracer.py +597 -0
- mlflow/langchain/model.py +919 -0
- mlflow/langchain/output_parsers.py +142 -0
- mlflow/langchain/retriever_chain.py +153 -0
- mlflow/langchain/runnables.py +527 -0
- mlflow/langchain/utils/chat.py +402 -0
- mlflow/langchain/utils/logging.py +671 -0
- mlflow/langchain/utils/serialization.py +36 -0
- mlflow/legacy_databricks_cli/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/provider.py +482 -0
- mlflow/litellm/__init__.py +175 -0
- mlflow/llama_index/__init__.py +22 -0
- mlflow/llama_index/autolog.py +55 -0
- mlflow/llama_index/chat.py +43 -0
- mlflow/llama_index/constant.py +1 -0
- mlflow/llama_index/model.py +577 -0
- mlflow/llama_index/pyfunc_wrapper.py +332 -0
- mlflow/llama_index/serialize_objects.py +188 -0
- mlflow/llama_index/tracer.py +561 -0
- mlflow/metrics/__init__.py +479 -0
- mlflow/metrics/base.py +39 -0
- mlflow/metrics/genai/__init__.py +25 -0
- mlflow/metrics/genai/base.py +101 -0
- mlflow/metrics/genai/genai_metric.py +771 -0
- mlflow/metrics/genai/metric_definitions.py +450 -0
- mlflow/metrics/genai/model_utils.py +371 -0
- mlflow/metrics/genai/prompt_template.py +68 -0
- mlflow/metrics/genai/prompts/__init__.py +0 -0
- mlflow/metrics/genai/prompts/v1.py +422 -0
- mlflow/metrics/genai/utils.py +6 -0
- mlflow/metrics/metric_definitions.py +619 -0
- mlflow/mismatch.py +34 -0
- mlflow/mistral/__init__.py +34 -0
- mlflow/mistral/autolog.py +71 -0
- mlflow/mistral/chat.py +135 -0
- mlflow/ml_package_versions.py +452 -0
- mlflow/models/__init__.py +97 -0
- mlflow/models/auth_policy.py +83 -0
- mlflow/models/cli.py +354 -0
- mlflow/models/container/__init__.py +294 -0
- mlflow/models/container/scoring_server/__init__.py +0 -0
- mlflow/models/container/scoring_server/nginx.conf +39 -0
- mlflow/models/dependencies_schemas.py +287 -0
- mlflow/models/display_utils.py +158 -0
- mlflow/models/docker_utils.py +211 -0
- mlflow/models/evaluation/__init__.py +23 -0
- mlflow/models/evaluation/_shap_patch.py +64 -0
- mlflow/models/evaluation/artifacts.py +194 -0
- mlflow/models/evaluation/base.py +1811 -0
- mlflow/models/evaluation/calibration_curve.py +109 -0
- mlflow/models/evaluation/default_evaluator.py +996 -0
- mlflow/models/evaluation/deprecated.py +23 -0
- mlflow/models/evaluation/evaluator_registry.py +80 -0
- mlflow/models/evaluation/evaluators/classifier.py +704 -0
- mlflow/models/evaluation/evaluators/default.py +233 -0
- mlflow/models/evaluation/evaluators/regressor.py +96 -0
- mlflow/models/evaluation/evaluators/shap.py +296 -0
- mlflow/models/evaluation/lift_curve.py +178 -0
- mlflow/models/evaluation/utils/metric.py +123 -0
- mlflow/models/evaluation/utils/trace.py +179 -0
- mlflow/models/evaluation/validation.py +434 -0
- mlflow/models/flavor_backend.py +93 -0
- mlflow/models/flavor_backend_registry.py +53 -0
- mlflow/models/model.py +1639 -0
- mlflow/models/model_config.py +150 -0
- mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
- mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
- mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
- mlflow/models/python_api.py +369 -0
- mlflow/models/rag_signatures.py +128 -0
- mlflow/models/resources.py +321 -0
- mlflow/models/signature.py +662 -0
- mlflow/models/utils.py +2054 -0
- mlflow/models/wheeled_model.py +280 -0
- mlflow/openai/__init__.py +57 -0
- mlflow/openai/_agent_tracer.py +364 -0
- mlflow/openai/api_request_parallel_processor.py +131 -0
- mlflow/openai/autolog.py +509 -0
- mlflow/openai/constant.py +1 -0
- mlflow/openai/model.py +824 -0
- mlflow/openai/utils/chat_schema.py +367 -0
- mlflow/optuna/__init__.py +3 -0
- mlflow/optuna/storage.py +646 -0
- mlflow/plugins/__init__.py +72 -0
- mlflow/plugins/base.py +358 -0
- mlflow/plugins/builtin/__init__.py +24 -0
- mlflow/plugins/builtin/pytorch_plugin.py +150 -0
- mlflow/plugins/builtin/sklearn_plugin.py +158 -0
- mlflow/plugins/builtin/transformers_plugin.py +187 -0
- mlflow/plugins/cli.py +321 -0
- mlflow/plugins/discovery.py +340 -0
- mlflow/plugins/manager.py +465 -0
- mlflow/plugins/registry.py +316 -0
- mlflow/plugins/templates/framework_plugin_template.py +329 -0
- mlflow/prompt/constants.py +20 -0
- mlflow/prompt/promptlab_model.py +197 -0
- mlflow/prompt/registry_utils.py +248 -0
- mlflow/promptflow/__init__.py +495 -0
- mlflow/protos/__init__.py +0 -0
- mlflow/protos/assessments_pb2.py +174 -0
- mlflow/protos/databricks_artifacts_pb2.py +489 -0
- mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
- mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
- mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
- mlflow/protos/databricks_pb2.py +267 -0
- mlflow/protos/databricks_trace_server_pb2.py +374 -0
- mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
- mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
- mlflow/protos/facet_feature_statistics_pb2.py +296 -0
- mlflow/protos/internal_pb2.py +77 -0
- mlflow/protos/mlflow_artifacts_pb2.py +336 -0
- mlflow/protos/model_registry_pb2.py +1073 -0
- mlflow/protos/scalapb/__init__.py +0 -0
- mlflow/protos/scalapb/scalapb_pb2.py +104 -0
- mlflow/protos/service_pb2.py +2600 -0
- mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
- mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
- mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
- mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
- mlflow/py.typed +0 -0
- mlflow/pydantic_ai/__init__.py +57 -0
- mlflow/pydantic_ai/autolog.py +173 -0
- mlflow/pyfunc/__init__.py +3844 -0
- mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
- mlflow/pyfunc/backend.py +523 -0
- mlflow/pyfunc/context.py +78 -0
- mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
- mlflow/pyfunc/loaders/__init__.py +7 -0
- mlflow/pyfunc/loaders/chat_agent.py +117 -0
- mlflow/pyfunc/loaders/chat_model.py +125 -0
- mlflow/pyfunc/loaders/code_model.py +31 -0
- mlflow/pyfunc/loaders/responses_agent.py +112 -0
- mlflow/pyfunc/mlserver.py +46 -0
- mlflow/pyfunc/model.py +1473 -0
- mlflow/pyfunc/scoring_server/__init__.py +604 -0
- mlflow/pyfunc/scoring_server/app.py +7 -0
- mlflow/pyfunc/scoring_server/client.py +146 -0
- mlflow/pyfunc/spark_model_cache.py +48 -0
- mlflow/pyfunc/stdin_server.py +44 -0
- mlflow/pyfunc/utils/__init__.py +3 -0
- mlflow/pyfunc/utils/data_validation.py +224 -0
- mlflow/pyfunc/utils/environment.py +22 -0
- mlflow/pyfunc/utils/input_converter.py +47 -0
- mlflow/pyfunc/utils/serving_data_parser.py +11 -0
- mlflow/pytorch/__init__.py +1171 -0
- mlflow/pytorch/_lightning_autolog.py +580 -0
- mlflow/pytorch/_pytorch_autolog.py +50 -0
- mlflow/pytorch/pickle_module.py +35 -0
- mlflow/rfunc/__init__.py +42 -0
- mlflow/rfunc/backend.py +134 -0
- mlflow/runs.py +89 -0
- mlflow/server/__init__.py +302 -0
- mlflow/server/auth/__init__.py +1224 -0
- mlflow/server/auth/__main__.py +4 -0
- mlflow/server/auth/basic_auth.ini +6 -0
- mlflow/server/auth/cli.py +11 -0
- mlflow/server/auth/client.py +537 -0
- mlflow/server/auth/config.py +34 -0
- mlflow/server/auth/db/__init__.py +0 -0
- mlflow/server/auth/db/cli.py +18 -0
- mlflow/server/auth/db/migrations/__init__.py +0 -0
- mlflow/server/auth/db/migrations/alembic.ini +110 -0
- mlflow/server/auth/db/migrations/env.py +76 -0
- mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
- mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
- mlflow/server/auth/db/models.py +67 -0
- mlflow/server/auth/db/utils.py +37 -0
- mlflow/server/auth/entities.py +165 -0
- mlflow/server/auth/logo.py +14 -0
- mlflow/server/auth/permissions.py +65 -0
- mlflow/server/auth/routes.py +18 -0
- mlflow/server/auth/sqlalchemy_store.py +263 -0
- mlflow/server/graphql/__init__.py +0 -0
- mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
- mlflow/server/graphql/graphql_custom_scalars.py +24 -0
- mlflow/server/graphql/graphql_errors.py +15 -0
- mlflow/server/graphql/graphql_no_batching.py +89 -0
- mlflow/server/graphql/graphql_schema_extensions.py +74 -0
- mlflow/server/handlers.py +3217 -0
- mlflow/server/prometheus_exporter.py +17 -0
- mlflow/server/validation.py +30 -0
- mlflow/shap/__init__.py +691 -0
- mlflow/sklearn/__init__.py +1994 -0
- mlflow/sklearn/utils.py +1041 -0
- mlflow/smolagents/__init__.py +66 -0
- mlflow/smolagents/autolog.py +139 -0
- mlflow/smolagents/chat.py +29 -0
- mlflow/store/__init__.py +10 -0
- mlflow/store/_unity_catalog/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/constants.py +2 -0
- mlflow/store/_unity_catalog/registry/__init__.py +6 -0
- mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
- mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
- mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
- mlflow/store/_unity_catalog/registry/utils.py +121 -0
- mlflow/store/artifact/__init__.py +0 -0
- mlflow/store/artifact/artifact_repo.py +472 -0
- mlflow/store/artifact/artifact_repository_registry.py +154 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
- mlflow/store/artifact/cli.py +141 -0
- mlflow/store/artifact/cloud_artifact_repo.py +332 -0
- mlflow/store/artifact/databricks_artifact_repo.py +729 -0
- mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
- mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
- mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
- mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
- mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
- mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
- mlflow/store/artifact/ftp_artifact_repo.py +132 -0
- mlflow/store/artifact/gcs_artifact_repo.py +296 -0
- mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
- mlflow/store/artifact/http_artifact_repo.py +218 -0
- mlflow/store/artifact/local_artifact_repo.py +142 -0
- mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
- mlflow/store/artifact/models_artifact_repo.py +259 -0
- mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
- mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
- mlflow/store/artifact/r2_artifact_repo.py +70 -0
- mlflow/store/artifact/runs_artifact_repo.py +265 -0
- mlflow/store/artifact/s3_artifact_repo.py +330 -0
- mlflow/store/artifact/sftp_artifact_repo.py +141 -0
- mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
- mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
- mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
- mlflow/store/artifact/utils/__init__.py +0 -0
- mlflow/store/artifact/utils/models.py +148 -0
- mlflow/store/db/__init__.py +0 -0
- mlflow/store/db/base_sql_model.py +3 -0
- mlflow/store/db/db_types.py +10 -0
- mlflow/store/db/utils.py +314 -0
- mlflow/store/db_migrations/__init__.py +0 -0
- mlflow/store/db_migrations/alembic.ini +74 -0
- mlflow/store/db_migrations/env.py +84 -0
- mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
- mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
- mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
- mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
- mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
- mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
- mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
- mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
- mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
- mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
- mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
- mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
- mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
- mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
- mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
- mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
- mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
- mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
- mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
- mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
- mlflow/store/db_migrations/versions/__init__.py +0 -0
- mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
- mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
- mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
- mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
- mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
- mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
- mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
- mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
- mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
- mlflow/store/entities/__init__.py +3 -0
- mlflow/store/entities/paged_list.py +18 -0
- mlflow/store/model_registry/__init__.py +10 -0
- mlflow/store/model_registry/abstract_store.py +1081 -0
- mlflow/store/model_registry/base_rest_store.py +44 -0
- mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
- mlflow/store/model_registry/dbmodels/__init__.py +0 -0
- mlflow/store/model_registry/dbmodels/models.py +206 -0
- mlflow/store/model_registry/file_store.py +1091 -0
- mlflow/store/model_registry/rest_store.py +481 -0
- mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
- mlflow/store/tracking/__init__.py +23 -0
- mlflow/store/tracking/abstract_store.py +816 -0
- mlflow/store/tracking/dbmodels/__init__.py +0 -0
- mlflow/store/tracking/dbmodels/initial_models.py +243 -0
- mlflow/store/tracking/dbmodels/models.py +1073 -0
- mlflow/store/tracking/file_store.py +2438 -0
- mlflow/store/tracking/postgres_managed_identity.py +146 -0
- mlflow/store/tracking/rest_store.py +1131 -0
- mlflow/store/tracking/sqlalchemy_store.py +2785 -0
- mlflow/system_metrics/__init__.py +61 -0
- mlflow/system_metrics/metrics/__init__.py +0 -0
- mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
- mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
- mlflow/system_metrics/metrics/disk_monitor.py +21 -0
- mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
- mlflow/system_metrics/metrics/network_monitor.py +34 -0
- mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
- mlflow/system_metrics/system_metrics_monitor.py +198 -0
- mlflow/tracing/__init__.py +16 -0
- mlflow/tracing/assessment.py +356 -0
- mlflow/tracing/client.py +531 -0
- mlflow/tracing/config.py +125 -0
- mlflow/tracing/constant.py +105 -0
- mlflow/tracing/destination.py +81 -0
- mlflow/tracing/display/__init__.py +40 -0
- mlflow/tracing/display/display_handler.py +196 -0
- mlflow/tracing/export/async_export_queue.py +186 -0
- mlflow/tracing/export/inference_table.py +138 -0
- mlflow/tracing/export/mlflow_v3.py +137 -0
- mlflow/tracing/export/utils.py +70 -0
- mlflow/tracing/fluent.py +1417 -0
- mlflow/tracing/processor/base_mlflow.py +199 -0
- mlflow/tracing/processor/inference_table.py +175 -0
- mlflow/tracing/processor/mlflow_v3.py +47 -0
- mlflow/tracing/processor/otel.py +73 -0
- mlflow/tracing/provider.py +487 -0
- mlflow/tracing/trace_manager.py +200 -0
- mlflow/tracing/utils/__init__.py +616 -0
- mlflow/tracing/utils/artifact_utils.py +28 -0
- mlflow/tracing/utils/copy.py +55 -0
- mlflow/tracing/utils/environment.py +55 -0
- mlflow/tracing/utils/exception.py +21 -0
- mlflow/tracing/utils/once.py +35 -0
- mlflow/tracing/utils/otlp.py +63 -0
- mlflow/tracing/utils/processor.py +54 -0
- mlflow/tracing/utils/search.py +292 -0
- mlflow/tracing/utils/timeout.py +250 -0
- mlflow/tracing/utils/token.py +19 -0
- mlflow/tracing/utils/truncation.py +124 -0
- mlflow/tracing/utils/warning.py +76 -0
- mlflow/tracking/__init__.py +39 -0
- mlflow/tracking/_model_registry/__init__.py +1 -0
- mlflow/tracking/_model_registry/client.py +764 -0
- mlflow/tracking/_model_registry/fluent.py +853 -0
- mlflow/tracking/_model_registry/registry.py +67 -0
- mlflow/tracking/_model_registry/utils.py +251 -0
- mlflow/tracking/_tracking_service/__init__.py +0 -0
- mlflow/tracking/_tracking_service/client.py +883 -0
- mlflow/tracking/_tracking_service/registry.py +56 -0
- mlflow/tracking/_tracking_service/utils.py +275 -0
- mlflow/tracking/artifact_utils.py +179 -0
- mlflow/tracking/client.py +5900 -0
- mlflow/tracking/context/__init__.py +0 -0
- mlflow/tracking/context/abstract_context.py +35 -0
- mlflow/tracking/context/databricks_cluster_context.py +15 -0
- mlflow/tracking/context/databricks_command_context.py +15 -0
- mlflow/tracking/context/databricks_job_context.py +49 -0
- mlflow/tracking/context/databricks_notebook_context.py +41 -0
- mlflow/tracking/context/databricks_repo_context.py +43 -0
- mlflow/tracking/context/default_context.py +51 -0
- mlflow/tracking/context/git_context.py +32 -0
- mlflow/tracking/context/registry.py +98 -0
- mlflow/tracking/context/system_environment_context.py +15 -0
- mlflow/tracking/default_experiment/__init__.py +1 -0
- mlflow/tracking/default_experiment/abstract_context.py +43 -0
- mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
- mlflow/tracking/default_experiment/registry.py +75 -0
- mlflow/tracking/fluent.py +3595 -0
- mlflow/tracking/metric_value_conversion_utils.py +93 -0
- mlflow/tracking/multimedia.py +206 -0
- mlflow/tracking/registry.py +86 -0
- mlflow/tracking/request_auth/__init__.py +0 -0
- mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
- mlflow/tracking/request_auth/registry.py +60 -0
- mlflow/tracking/request_header/__init__.py +0 -0
- mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
- mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
- mlflow/tracking/request_header/default_request_header_provider.py +17 -0
- mlflow/tracking/request_header/registry.py +79 -0
- mlflow/transformers/__init__.py +2982 -0
- mlflow/transformers/flavor_config.py +258 -0
- mlflow/transformers/hub_utils.py +83 -0
- mlflow/transformers/llm_inference_utils.py +468 -0
- mlflow/transformers/model_io.py +301 -0
- mlflow/transformers/peft.py +51 -0
- mlflow/transformers/signature.py +183 -0
- mlflow/transformers/torch_utils.py +55 -0
- mlflow/types/__init__.py +21 -0
- mlflow/types/agent.py +270 -0
- mlflow/types/chat.py +240 -0
- mlflow/types/llm.py +935 -0
- mlflow/types/responses.py +139 -0
- mlflow/types/responses_helpers.py +416 -0
- mlflow/types/schema.py +1505 -0
- mlflow/types/type_hints.py +647 -0
- mlflow/types/utils.py +753 -0
- mlflow/utils/__init__.py +283 -0
- mlflow/utils/_capture_modules.py +256 -0
- mlflow/utils/_capture_transformers_modules.py +75 -0
- mlflow/utils/_spark_utils.py +201 -0
- mlflow/utils/_unity_catalog_oss_utils.py +97 -0
- mlflow/utils/_unity_catalog_utils.py +479 -0
- mlflow/utils/annotations.py +218 -0
- mlflow/utils/arguments_utils.py +16 -0
- mlflow/utils/async_logging/__init__.py +1 -0
- mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
- mlflow/utils/async_logging/async_logging_queue.py +366 -0
- mlflow/utils/async_logging/run_artifact.py +38 -0
- mlflow/utils/async_logging/run_batch.py +58 -0
- mlflow/utils/async_logging/run_operations.py +49 -0
- mlflow/utils/autologging_utils/__init__.py +737 -0
- mlflow/utils/autologging_utils/client.py +432 -0
- mlflow/utils/autologging_utils/config.py +33 -0
- mlflow/utils/autologging_utils/events.py +294 -0
- mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
- mlflow/utils/autologging_utils/metrics_queue.py +71 -0
- mlflow/utils/autologging_utils/safety.py +1104 -0
- mlflow/utils/autologging_utils/versioning.py +95 -0
- mlflow/utils/checkpoint_utils.py +206 -0
- mlflow/utils/class_utils.py +6 -0
- mlflow/utils/cli_args.py +257 -0
- mlflow/utils/conda.py +354 -0
- mlflow/utils/credentials.py +231 -0
- mlflow/utils/data_utils.py +17 -0
- mlflow/utils/databricks_utils.py +1436 -0
- mlflow/utils/docstring_utils.py +477 -0
- mlflow/utils/doctor.py +133 -0
- mlflow/utils/download_cloud_file_chunk.py +43 -0
- mlflow/utils/env_manager.py +16 -0
- mlflow/utils/env_pack.py +131 -0
- mlflow/utils/environment.py +1009 -0
- mlflow/utils/exception_utils.py +14 -0
- mlflow/utils/file_utils.py +978 -0
- mlflow/utils/git_utils.py +77 -0
- mlflow/utils/gorilla.py +797 -0
- mlflow/utils/import_hooks/__init__.py +363 -0
- mlflow/utils/lazy_load.py +51 -0
- mlflow/utils/logging_utils.py +168 -0
- mlflow/utils/mime_type_utils.py +58 -0
- mlflow/utils/mlflow_tags.py +103 -0
- mlflow/utils/model_utils.py +486 -0
- mlflow/utils/name_utils.py +346 -0
- mlflow/utils/nfs_on_spark.py +62 -0
- mlflow/utils/openai_utils.py +164 -0
- mlflow/utils/os.py +12 -0
- mlflow/utils/oss_registry_utils.py +29 -0
- mlflow/utils/plugins.py +17 -0
- mlflow/utils/process.py +182 -0
- mlflow/utils/promptlab_utils.py +146 -0
- mlflow/utils/proto_json_utils.py +743 -0
- mlflow/utils/pydantic_utils.py +54 -0
- mlflow/utils/request_utils.py +279 -0
- mlflow/utils/requirements_utils.py +704 -0
- mlflow/utils/rest_utils.py +673 -0
- mlflow/utils/search_logged_model_utils.py +127 -0
- mlflow/utils/search_utils.py +2111 -0
- mlflow/utils/secure_loading.py +221 -0
- mlflow/utils/security_validation.py +384 -0
- mlflow/utils/server_cli_utils.py +61 -0
- mlflow/utils/spark_utils.py +15 -0
- mlflow/utils/string_utils.py +138 -0
- mlflow/utils/thread_utils.py +63 -0
- mlflow/utils/time.py +54 -0
- mlflow/utils/timeout.py +42 -0
- mlflow/utils/uri.py +572 -0
- mlflow/utils/validation.py +662 -0
- mlflow/utils/virtualenv.py +458 -0
- mlflow/utils/warnings_utils.py +25 -0
- mlflow/utils/yaml_utils.py +179 -0
- mlflow/version.py +24 -0
mlflow/optuna/storage.py
ADDED
@@ -0,0 +1,646 @@
|
|
1
|
+
import copy
|
2
|
+
import datetime
|
3
|
+
import json
|
4
|
+
import threading
|
5
|
+
import time
|
6
|
+
import uuid
|
7
|
+
from collections.abc import Container, Sequence
|
8
|
+
from typing import Any, Optional
|
9
|
+
|
10
|
+
from mlflow import MlflowClient
|
11
|
+
from mlflow.entities import Metric, Param, RunTag
|
12
|
+
from mlflow.utils.mlflow_tags import MLFLOW_PARENT_RUN_ID
|
13
|
+
|
14
|
+
try:
|
15
|
+
from optuna._typing import JSONSerializable
|
16
|
+
from optuna.distributions import (
|
17
|
+
BaseDistribution,
|
18
|
+
check_distribution_compatibility,
|
19
|
+
distribution_to_json,
|
20
|
+
json_to_distribution,
|
21
|
+
)
|
22
|
+
from optuna.storages import BaseStorage
|
23
|
+
from optuna.storages._base import DEFAULT_STUDY_NAME_PREFIX
|
24
|
+
from optuna.study import StudyDirection
|
25
|
+
from optuna.study._frozen import FrozenStudy
|
26
|
+
from optuna.trial import FrozenTrial, TrialState
|
27
|
+
except ImportError as e:
|
28
|
+
raise ImportError("Install optuna to use `mlflow.optuna` module") from e
|
29
|
+
|
30
|
+
# Optuna trial state -> MLflow run status under which that trial is persisted.
optuna_mlflow_status_map = {
    TrialState.RUNNING: "RUNNING",
    TrialState.COMPLETE: "FINISHED",
    TrialState.PRUNED: "KILLED",
    TrialState.FAIL: "FAILED",
    TrialState.WAITING: "SCHEDULED",
}

# MLflow run status -> Optuna trial state. Derived as the exact inverse of the table
# above (which is one-to-one) so the two directions can never drift apart.
mlflow_optuna_status_map = {status: state for state, status in optuna_mlflow_status_map.items()}
|
45
|
+
|
46
|
+
|
47
|
+
class MlflowStorage(BaseStorage):
    """
    MLflow based storage class with batch processing to avoid REST API throttling.

    Each Optuna study is stored as an MLflow run tagged with ``optuna.study_direction``;
    each trial is stored as a run linked to its study through the ``mlflow.parentRunId``
    tag. Metrics, params and tags are queued per run and written with
    ``MlflowClient.log_batch`` when a run's queue reaches ``batch_size_threshold`` items,
    periodically from a background thread, or right before a read that needs them.

    The background flush thread's target is a bound method of this instance, so the
    instance stays referenced (and ``__del__`` cannot run) while that thread is alive.
    Call :meth:`close` when finished to stop the thread and flush pending data.
    """

    # Tag holding the comma-separated StudyDirection names; only study runs carry it.
    _STUDY_DIRECTION_TAG = "optuna.study_direction"
    # Tag prefix under which a trial parameter's internal (float) representation is stored.
    _PARAM_INTERNAL_PREFIX = "param_internal_val_"

    def __init__(
        self,
        experiment_id: str,
        name: Optional[str] = None,
        batch_flush_interval: float = 1.0,
        batch_size_threshold: int = 100,
    ):
        """
        Initialize MLFlowStorage with batching capabilities.

        Parameters
        ----------
        experiment_id : str
            MLflow experiment ID
        name : Optional[str]
            Optional name for the storage
        batch_flush_interval : float
            Time in seconds between automatic batch flushes (default: 1.0)
        batch_size_threshold : int
            Maximum number of items in batch before triggering a flush (default: 100)

        Raises
        ------
        ValueError
            If ``experiment_id`` is empty. ValueError subclasses Exception, so callers
            that caught the previous bare Exception keep working.
        """
        if not experiment_id:
            raise ValueError("No experiment_id provided. MLFlowStorage cannot create experiments.")

        self._experiment_id = experiment_id
        self._mlflow_client = MlflowClient()
        self._name = name

        # Batching configuration
        self._batch_flush_interval = batch_flush_interval
        self._batch_size_threshold = batch_size_threshold

        # Pending writes: run_id -> {"metrics": [...], "params": [...], "tags": [...]}
        self._batch_queue = {}
        self._batch_lock = threading.RLock()
        self._last_flush_time = time.time()

        # Must exist before the worker thread starts, since the worker reads it.
        self._stop_worker = False

        # Background thread for periodic flushing.
        self._flush_thread = threading.Thread(
            target=self._periodic_flush_worker,
            daemon=True,
            name=f"mlflow_optuna_batch_flush_worker_{uuid.uuid4().hex[:8]}",
        )
        self._flush_thread.start()

    def __getstate__(self):
        """Return picklable state: no lock, no thread, and an empty write queue."""
        state = self.__dict__.copy()

        # Locks and threads cannot be pickled.
        state.pop("_batch_lock", None)
        flush_thread = state.pop("_flush_thread", None)

        # An unpickled copy has ``_flush_thread = None``; guard so it can be re-pickled.
        state["_thread_running"] = flush_thread is not None and flush_thread.is_alive()

        # Pending items stay with (and are flushed by) this instance. Shipping them too
        # would make the copy log them a second time, and pickling the live dict could
        # race with the flush thread mutating it.
        state["_batch_queue"] = {}

        return state

    def __setstate__(self, state):
        """Restore an unpickled instance without a background flush thread."""
        self.__dict__.update(state)

        self._batch_lock = threading.RLock()

        # Deliberately no thread on unpickled copies (e.g. distributed workers): they
        # flush on reads, when a run's queue reaches the threshold, and on close().
        self._flush_thread = None
        self._stop_worker = True

    def close(self, timeout: float = 5.0) -> None:
        """Stop the background flush thread (if any) and flush all pending data.

        Parameters
        ----------
        timeout : float
            Maximum number of seconds to wait for the flush thread to exit.
        """
        self._stop_worker = True

        flush_thread = getattr(self, "_flush_thread", None)
        if (
            flush_thread is not None
            and flush_thread.is_alive()
            and flush_thread is not threading.current_thread()
        ):
            flush_thread.join(timeout=timeout)

        # The attribute is missing only if __init__ failed before creating the queue.
        if hasattr(self, "_batch_queue"):
            self.flush_all_batches()

    def __del__(self):
        """Best-effort flush of queued data when the object is finalized."""
        try:
            self.close()
        except Exception:
            pass  # Never let errors escape a finalizer.

    def _periodic_flush_worker(self):
        """Background worker that flushes all batches every ``batch_flush_interval`` s."""
        while not self._stop_worker:
            try:
                # Sleep in small increments so a stop request is noticed quickly.
                time.sleep(min(0.1, self._batch_flush_interval / 10))

                current_time = time.time()
                if current_time - self._last_flush_time >= self._batch_flush_interval:
                    self.flush_all_batches()
                    self._last_flush_time = current_time
            except Exception:
                # Keep the thread alive; back off a little after an error.
                time.sleep(1.0)

    def _queue_batch_operation(
        self,
        run_id: str,
        metrics: Optional[list[Metric]] = None,
        params: Optional[list[Param]] = None,
        tags: Optional[list[RunTag]] = None,
    ):
        """Queue metrics, params and/or tags for ``run_id``.

        The run's batch is flushed immediately once it holds at least
        ``batch_size_threshold`` items.
        """
        with self._batch_lock:
            batch = self._batch_queue.setdefault(run_id, {"metrics": [], "params": [], "tags": []})

            if metrics:
                batch["metrics"].extend(metrics)
            if params:
                batch["params"].extend(params)
            if tags:
                batch["tags"].extend(tags)

            batch_size = len(batch["metrics"]) + len(batch["params"]) + len(batch["tags"])
            if batch_size >= self._batch_size_threshold:
                self._flush_batch(run_id)

    def _flush_batch(self, run_id: str):
        """Send the queued items for ``run_id`` to MLflow in one ``log_batch`` call.

        The batch is removed from the queue before the call, so a failing run can never
        block later flushes or leak an empty queue entry.

        Raises
        ------
        Exception
            Re-raises the client error when the run does not exist. Any other failure
            is logged as a warning and the batch is dropped.
        """
        with self._batch_lock:
            batch = self._batch_queue.pop(run_id, None)
            if batch is None or not (batch["metrics"] or batch["params"] or batch["tags"]):
                return
            try:
                self._mlflow_client.log_batch(
                    run_id, metrics=batch["metrics"], params=batch["params"], tags=batch["tags"]
                )
            except Exception as e:
                # If the run doesn't exist, propagate the error.
                if "Run with id=" in str(e) and "not found" in str(e):
                    raise
                import logging

                logging.getLogger(__name__).warning(
                    "Dropping %d metric(s), %d param(s) and %d tag(s) for run %s "
                    "after log_batch failed: %s",
                    len(batch["metrics"]),
                    len(batch["params"]),
                    len(batch["tags"]),
                    run_id,
                    e,
                )

    def flush_all_batches(self):
        """Flush all pending batches to MLflow.

        Every run is attempted even if an earlier one fails; the first error is
        re-raised after all runs have been processed.
        """
        with self._batch_lock:
            run_ids = list(self._batch_queue)

        first_error = None
        for run_id in run_ids:
            try:
                self._flush_batch(run_id)
            except Exception as e:
                if first_error is None:
                    first_error = e
        if first_error is not None:
            raise first_error

    def _search_all_runs(self, filter_string: str = "") -> list:
        """Return every active run in the experiment matching ``filter_string``.

        ``search_runs`` returns one page at a time, so follow ``page_token`` until the
        last page instead of silently truncating the result.
        """
        runs = []
        page_token = None
        while True:
            page = self._mlflow_client.search_runs(
                experiment_ids=[self._experiment_id],
                filter_string=filter_string,
                page_token=page_token,
            )
            runs.extend(page)
            page_token = page.token
            if not page_token:
                return runs

    def _search_run_by_name(self, run_name: str):
        """Return the first page of runs whose ``mlflow.runName`` equals ``run_name``.

        NOTE(review): ``run_name`` is interpolated into the filter unescaped, so a name
        containing a single quote produces an invalid filter.
        """
        filter_string = f"tags.mlflow.runName = '{run_name}'"
        return self._mlflow_client.search_runs(
            experiment_ids=[self._experiment_id], filter_string=filter_string
        )

    @staticmethod
    def _attrs_from_tags(tags: dict[str, str], prefix: str) -> dict[str, Any]:
        """Decode the JSON values of all tags starting with ``prefix`` (prefix stripped)."""
        return {
            key[len(prefix) :]: json.loads(value)
            for key, value in tags.items()
            if key.startswith(prefix)
        }

    @staticmethod
    def _attr_tags(attrs: dict[str, Any], prefix: str) -> list[RunTag]:
        """Encode ``attrs`` as JSON-valued run tags named ``prefix + key``."""
        return [RunTag(f"{prefix}{key}", json.dumps(value)) for key, value in attrs.items()]

    @staticmethod
    def _value_metrics(values: Optional[Sequence[float]], timestamp: int) -> list[Metric]:
        """Objective values as metrics: ``value`` for one objective, else ``value_<i>``."""
        if values is None:
            return []
        values = list(values)
        if len(values) == 1:
            return [Metric("value", values[0], timestamp, 1)]
        return [Metric(f"value_{idx}", val, timestamp, 1) for idx, val in enumerate(values)]

    def create_new_study(
        self, directions: Sequence[StudyDirection], study_name: Optional[str] = None
    ) -> int:
        """Create a new study as an MLflow run and return its run ID."""
        study_name = study_name or DEFAULT_STUDY_NAME_PREFIX + str(uuid.uuid4())
        tags = {
            "mlflow.runName": study_name,
            self._STUDY_DIRECTION_TAG: ",".join(direction.name for direction in directions),
        }
        study_run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=tags)
        return study_run.info.run_id

    def delete_study(self, study_id) -> None:
        """Delete the study run, after flushing any pending writes for it."""
        self._flush_batch(study_id)
        self._mlflow_client.delete_run(study_id)

    def set_study_user_attr(self, study_id, key: str, value: JSONSerializable) -> None:
        """Register a user-defined attribute as an MLflow run tag on a study run."""
        # Verify the run exists first to fail fast if it doesn't.
        self._mlflow_client.get_run(study_id)
        self._queue_batch_operation(study_id, tags=[RunTag(f"user_{key}", json.dumps(value))])

    def set_study_system_attr(self, study_id, key: str, value: JSONSerializable) -> None:
        """Register an optuna-internal attribute as an MLflow run tag on a study run."""
        # Verify the run exists first to fail fast if it doesn't.
        self._mlflow_client.get_run(study_id)
        self._queue_batch_operation(study_id, tags=[RunTag(f"sys_{key}", json.dumps(value))])

    def get_study_id_from_name(self, study_name: str) -> int:
        """Return the run ID of the study named ``study_name``.

        Raises
        ------
        KeyError
            If no study run with that name exists (the optuna storage contract).
        """
        self.flush_all_batches()

        for run in self._search_run_by_name(study_name):
            # Only study runs carry the direction tag; skip same-named trial/other runs.
            if self._STUDY_DIRECTION_TAG in run.data.tags:
                return run.info.run_id
        raise KeyError(f"Study {study_name} not found")

    def get_study_name_from_id(self, study_id) -> str:
        """Return the study name stored in the run's ``mlflow.runName`` tag."""
        self._flush_batch(study_id)

        run = self._mlflow_client.get_run(study_id)
        return run.data.tags["mlflow.runName"]

    def get_study_directions(self, study_id) -> list[StudyDirection]:
        """Return the optimization directions recorded when the study was created."""
        self._flush_batch(study_id)

        run = self._mlflow_client.get_run(study_id)
        directions_str = run.data.tags[self._STUDY_DIRECTION_TAG]
        return [StudyDirection[name] for name in directions_str.split(",")]

    def get_study_user_attrs(self, study_id) -> dict[str, Any]:
        """Return the study's user attributes (``user_*`` tags, JSON-decoded)."""
        self._flush_batch(study_id)

        run = self._mlflow_client.get_run(study_id)
        return self._attrs_from_tags(run.data.tags, "user_")

    def get_study_system_attrs(self, study_id) -> dict[str, Any]:
        """Return the study's system attributes (``sys_*`` tags, JSON-decoded)."""
        self._flush_batch(study_id)

        run = self._mlflow_client.get_run(study_id)
        return self._attrs_from_tags(run.data.tags, "sys_")

    def get_all_studies(self) -> list[FrozenStudy]:
        """Return every study (run carrying the direction tag) in the experiment."""
        self.flush_all_batches()

        studies = []
        for run in self._search_all_runs():
            tags = run.data.tags
            directions_str = tags.get(self._STUDY_DIRECTION_TAG)
            if directions_str is None:
                # Trial runs and unrelated runs share the experiment; they are not studies.
                continue
            studies.append(
                FrozenStudy(
                    study_name=tags["mlflow.runName"],
                    direction=None,
                    directions=[StudyDirection[name] for name in directions_str.split(",")],
                    # Tags were already returned by search_runs; no extra get_run needed.
                    user_attrs=self._attrs_from_tags(tags, "user_"),
                    system_attrs=self._attrs_from_tags(tags, "sys_"),
                    study_id=run.info.run_id,
                )
            )
        return studies

    def create_new_trial(self, study_id, template_trial: Optional[FrozenTrial] = None) -> int:
        """Create a trial run for ``study_id`` and return its run ID.

        The parent-run tag is written in the same flushed batch as the trial's number
        and data, so the trial is only discoverable through ``get_all_trials`` once
        those are present.
        """
        if template_trial:
            frozen = copy.deepcopy(template_trial)
        else:
            frozen = FrozenTrial(
                trial_id=-1,  # dummy value.
                number=-1,  # dummy value.
                state=TrialState.RUNNING,
                params={},
                distributions={},
                user_attrs={},
                system_attrs={},
                value=None,
                intermediate_values={},
                datetime_start=datetime.datetime.now(),
                datetime_complete=None,
            )

        distribution_json = {
            k: distribution_to_json(dist) for k, dist in frozen.distributions.items()
        }
        trial_run = self._mlflow_client.create_run(
            experiment_id=self._experiment_id,
            tags={"param_directions": json.dumps(distribution_json)},
        )
        trial_id = trial_run.info.run_id

        # The trial number is this trial's position in the study's "trial_id" metric
        # history. hash(trial_id) is only compared within this call in this process, so
        # per-process hash randomization does not affect the lookup.
        hash_id = float(hash(trial_id))
        self._queue_batch_operation(
            study_id, metrics=[Metric("trial_id", hash_id, int(time.time() * 1000), 1)]
        )
        self._flush_batch(study_id)
        trial_ids = self._mlflow_client.get_metric_history(study_id, "trial_id")
        number = next((i for i, obj in enumerate(trial_ids) if obj.value == hash_id), -1)

        # Set the trial state; the run is not linked to the study yet at this point.
        state = frozen.state
        if state.is_finished():
            self._mlflow_client.set_terminated(trial_id, status=optuna_mlflow_status_map[state])
        else:
            self._mlflow_client.update_run(trial_id, status=optuna_mlflow_status_map[state])

        timestamp = int(time.time() * 1000)
        metrics = self._value_metrics(frozen.values, timestamp)
        metrics.extend(
            Metric("intermediate_value", val, timestamp, int(step))
            for step, val in frozen.intermediate_values.items()
        )

        # MLflow stores params as strings; stringify explicitly for log_batch.
        params = [Param(k, str(value)) for k, value in frozen.params.items()]

        tags = [RunTag(MLFLOW_PARENT_RUN_ID, study_id), RunTag("numbers", str(number))]
        tags.extend(self._attr_tags(frozen.user_attrs, "user_"))
        tags.extend(self._attr_tags(frozen.system_attrs, "sys_"))
        tags.extend(
            RunTag(
                f"{self._PARAM_INTERNAL_PREFIX}{k}",
                json.dumps(frozen.distributions[k].to_internal_repr(value)),
            )
            for k, value in frozen.params.items()
        )

        # Flush right away: unpickled copies have no flush thread, and the trial must
        # be linked to its study (with its number) before anyone can read it.
        self._queue_batch_operation(trial_id, metrics=metrics, params=params, tags=tags)
        self._flush_batch(trial_id)

        return trial_id

    def set_trial_param(
        self,
        trial_id,
        param_name: str,
        param_value_internal: float,
        distribution: BaseDistribution,
    ) -> None:
        """Record a sampled parameter (external value, internal value and distribution).

        Raises
        ------
        UpdateFinishedTrialError
            Via ``check_trial_is_updatable`` if the trial is already finished.
        ValueError
            Via ``check_distribution_compatibility`` if the parameter already exists
            with an incompatible distribution.
        """
        # Flush first so param_directions reflects every earlier set_trial_param call.
        self._flush_batch(trial_id)

        trial_run = self._mlflow_client.get_run(trial_id)
        distributions_dict = json.loads(trial_run.data.tags["param_directions"])
        self.check_trial_is_updatable(trial_id, mlflow_optuna_status_map[trial_run.info.status])

        if param_name in trial_run.data.params:
            param_distribution = json_to_distribution(distributions_dict[param_name])
            check_distribution_compatibility(param_distribution, distribution)

        distributions_dict[param_name] = distribution_to_json(distribution)
        self._queue_batch_operation(
            trial_id,
            params=[Param(param_name, str(distribution.to_external_repr(param_value_internal)))],
            tags=[
                RunTag(
                    f"{self._PARAM_INTERNAL_PREFIX}{param_name}", json.dumps(param_value_internal)
                ),
                RunTag("param_directions", json.dumps(distributions_dict)),
            ],
        )

    def get_trial_id_from_study_id_trial_number(self, study_id, trial_number: int) -> int:
        """Return the run ID of the trial whose stored ``numbers`` tag is ``trial_number``.

        Trial numbers are assigned per process in create_new_trial, so if two trials
        share a number the first one found is returned.

        Raises
        ------
        KeyError
            If the study has no trial with that number.
        """
        self.flush_all_batches()

        for run in self._search_all_runs(f"tags.mlflow.parentRunId='{study_id}'"):
            if int(run.data.tags.get("numbers", 0)) == trial_number:
                return run.info.run_id
        raise KeyError(
            f"No trial with trial number {trial_number} exists in study with study_id {study_id}."
        )

    def get_trial_number_from_id(self, trial_id) -> int:
        """Return the trial number stored in the ``numbers`` tag (0 if absent)."""
        self._flush_batch(trial_id)

        trial_run = self._mlflow_client.get_run(trial_id)
        return int(trial_run.data.tags.get("numbers", 0))

    def get_trial_param(self, trial_id, param_name: str) -> float:
        """Return the internal (float) representation of a trial parameter."""
        self._flush_batch(trial_id)

        trial_run = self._mlflow_client.get_run(trial_id)
        param_value = trial_run.data.tags[f"{self._PARAM_INTERNAL_PREFIX}{param_name}"]

        return float(json.loads(param_value))

    def set_trial_state_values(
        self, trial_id, state: TrialState, values: Optional[Sequence[float]] = None
    ) -> bool:
        """Update a trial's state and, optionally, its objective values.

        Returns
        -------
        bool
            False when asked to move the trial to RUNNING but it is not WAITING (another
            worker already claimed it); True otherwise.

        Raises
        ------
        UpdateFinishedTrialError
            Via ``check_trial_is_updatable`` if the trial is already finished.
        """
        trial_run = self._mlflow_client.get_run(trial_id)
        current_state = mlflow_optuna_status_map[trial_run.info.status]
        self.check_trial_is_updatable(trial_id, current_state)

        # Optuna claims enqueued trials by moving WAITING -> RUNNING; only that
        # transition may succeed, otherwise the trial belongs to someone else.
        if state == TrialState.RUNNING and current_state != TrialState.WAITING:
            return False

        if values is not None:
            self._queue_batch_operation(
                trial_id, metrics=self._value_metrics(values, int(time.time() * 1000))
            )
            # Persist values before the state change, so readers never see a finished
            # trial without its values (and worker copies without a thread don't lose them).
            self._flush_batch(trial_id)

        if state.is_finished():
            self._mlflow_client.set_terminated(trial_id, status=optuna_mlflow_status_map[state])
        else:
            self._mlflow_client.update_run(trial_id, status=optuna_mlflow_status_map[state])

        return True

    def set_trial_intermediate_value(self, trial_id, step: int, intermediate_value: float) -> None:
        """Queue an ``intermediate_value`` metric at ``step``."""
        self._queue_batch_operation(
            trial_id,
            metrics=[
                Metric("intermediate_value", intermediate_value, int(time.time() * 1000), step)
            ],
        )

    def set_trial_user_attr(self, trial_id, key: str, value: Any) -> None:
        """Queue a user attribute as a JSON-encoded ``user_<key>`` tag."""
        self._queue_batch_operation(trial_id, tags=[RunTag(f"user_{key}", json.dumps(value))])

    def set_trial_system_attr(self, trial_id, key: str, value: Any) -> None:
        """Queue a system attribute as a JSON-encoded ``sys_<key>`` tag."""
        self._queue_batch_operation(trial_id, tags=[RunTag(f"sys_{key}", json.dumps(value))])

    def _frozen_trial_from_run(self, trial_run) -> FrozenTrial:
        """Build a FrozenTrial from an MLflow Run (from ``get_run`` or ``search_runs``).

        Only the intermediate values need an extra request (metric history).

        Raises
        ------
        ValueError
            If the ``param_directions`` tag is not valid JSON.
        """
        trial_id = trial_run.info.run_id
        tags = trial_run.data.tags

        param_directions = tags["param_directions"]
        try:
            distributions_dict = json.loads(param_directions)
        except json.decoder.JSONDecodeError as e:
            raise ValueError(f"error with param_directions = {param_directions!r}") from e

        distributions = {
            k: json_to_distribution(distribution) for k, distribution in distributions_dict.items()
        }
        params = {}
        for key, value in tags.items():
            if key.startswith(self._PARAM_INTERNAL_PREFIX):
                param_name = key[len(self._PARAM_INTERNAL_PREFIX) :]
                internal_value = float(json.loads(value))
                params[param_name] = distributions[param_name].to_external_repr(internal_value)

        metrics = trial_run.data.metrics
        values = None
        if "value_0" in metrics:
            # Multi-objective: collect value_0, value_1, ... up to the first gap, so other
            # metrics on the run (e.g. intermediate_value) are never counted as values.
            values = []
            while f"value_{len(values)}" in metrics:
                values.append(metrics[f"value_{len(values)}"])
        elif "value" in metrics:
            values = [metrics["value"]]

        start_time = datetime.datetime.fromtimestamp(trial_run.info.start_time / 1000)
        if trial_run.info.end_time:
            end_time = datetime.datetime.fromtimestamp(trial_run.info.end_time / 1000)
        else:
            end_time = None

        history = self._mlflow_client.get_metric_history(trial_id, "intermediate_value")
        return FrozenTrial(
            trial_id=trial_id,
            number=int(tags.get("numbers", 0)),
            state=mlflow_optuna_status_map[trial_run.info.status],
            value=None,
            values=values,
            datetime_start=start_time,
            datetime_complete=end_time,
            params=params,
            distributions=distributions,
            user_attrs=self._attrs_from_tags(tags, "user_"),
            system_attrs=self._attrs_from_tags(tags, "sys_"),
            intermediate_values={entry.step: entry.value for entry in history},
        )

    def get_trial(self, trial_id) -> FrozenTrial:
        """Return the trial with run ID ``trial_id`` after flushing its pending writes."""
        self._flush_batch(trial_id)
        return self._frozen_trial_from_run(self._mlflow_client.get_run(trial_id))

    def get_trial_user_attrs(self, trial_id) -> dict[str, Any]:
        """Return the trial's user attributes (``user_*`` tags, JSON-decoded)."""
        self._flush_batch(trial_id)

        run = self._mlflow_client.get_run(trial_id)
        return self._attrs_from_tags(run.data.tags, "user_")

    def get_trial_system_attrs(self, trial_id) -> dict[str, Any]:
        """Return the trial's system attributes (``sys_*`` tags, JSON-decoded)."""
        self._flush_batch(trial_id)

        run = self._mlflow_client.get_run(trial_id)
        return self._attrs_from_tags(run.data.tags, "sys_")

    def get_all_trials(
        self,
        study_id,
        deepcopy: bool = True,
        states: Optional[Container[TrialState]] = None,
    ) -> list[FrozenTrial]:
        """Return the study's trials, optionally filtered by state, ordered by number.

        ``deepcopy`` is accepted for interface compatibility; every call builds fresh
        FrozenTrial objects anyway.
        """
        self.flush_all_batches()

        runs = self._search_all_runs(f"tags.mlflow.parentRunId='{study_id}'")
        trials = [
            self._frozen_trial_from_run(run)
            for run in runs
            if states is None or mlflow_optuna_status_map[run.info.status] in states
        ]
        # search_runs returns newest runs first; optuna expects trial-number order.
        trials.sort(key=lambda trial: trial.number)
        return trials

    def get_n_trials(self, study_id, states=None, state=None) -> int:
        """Count the study's trials, optionally only those in the given state(s).

        ``state`` mirrors the keyword name used by optuna's ``BaseStorage``; either
        keyword accepts a single TrialState or a container of them.
        """
        self.flush_all_batches()

        wanted = states if states is not None else state
        if isinstance(wanted, TrialState):
            wanted = (wanted,)

        runs = self._search_all_runs(f"tags.mlflow.parentRunId='{study_id}'")
        if wanted is None:
            return len(runs)
        return sum(1 for run in runs if mlflow_optuna_status_map.get(run.info.status) in wanted)
|