genesis-flow 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genesis_flow-1.0.0.dist-info/METADATA +822 -0
- genesis_flow-1.0.0.dist-info/RECORD +645 -0
- genesis_flow-1.0.0.dist-info/WHEEL +5 -0
- genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
- genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
- genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
- mlflow/__init__.py +367 -0
- mlflow/__main__.py +3 -0
- mlflow/ag2/__init__.py +56 -0
- mlflow/ag2/ag2_logger.py +294 -0
- mlflow/anthropic/__init__.py +40 -0
- mlflow/anthropic/autolog.py +129 -0
- mlflow/anthropic/chat.py +144 -0
- mlflow/artifacts/__init__.py +268 -0
- mlflow/autogen/__init__.py +144 -0
- mlflow/autogen/chat.py +142 -0
- mlflow/azure/__init__.py +26 -0
- mlflow/azure/auth_handler.py +257 -0
- mlflow/azure/client.py +319 -0
- mlflow/azure/config.py +120 -0
- mlflow/azure/connection_factory.py +340 -0
- mlflow/azure/exceptions.py +27 -0
- mlflow/azure/stores.py +327 -0
- mlflow/azure/utils.py +183 -0
- mlflow/bedrock/__init__.py +45 -0
- mlflow/bedrock/_autolog.py +202 -0
- mlflow/bedrock/chat.py +122 -0
- mlflow/bedrock/stream.py +160 -0
- mlflow/bedrock/utils.py +43 -0
- mlflow/cli.py +707 -0
- mlflow/client.py +12 -0
- mlflow/config/__init__.py +56 -0
- mlflow/crewai/__init__.py +79 -0
- mlflow/crewai/autolog.py +253 -0
- mlflow/crewai/chat.py +29 -0
- mlflow/data/__init__.py +75 -0
- mlflow/data/artifact_dataset_sources.py +170 -0
- mlflow/data/code_dataset_source.py +40 -0
- mlflow/data/dataset.py +123 -0
- mlflow/data/dataset_registry.py +168 -0
- mlflow/data/dataset_source.py +110 -0
- mlflow/data/dataset_source_registry.py +219 -0
- mlflow/data/delta_dataset_source.py +167 -0
- mlflow/data/digest_utils.py +108 -0
- mlflow/data/evaluation_dataset.py +562 -0
- mlflow/data/filesystem_dataset_source.py +81 -0
- mlflow/data/http_dataset_source.py +145 -0
- mlflow/data/huggingface_dataset.py +258 -0
- mlflow/data/huggingface_dataset_source.py +118 -0
- mlflow/data/meta_dataset.py +104 -0
- mlflow/data/numpy_dataset.py +223 -0
- mlflow/data/pandas_dataset.py +231 -0
- mlflow/data/polars_dataset.py +352 -0
- mlflow/data/pyfunc_dataset_mixin.py +31 -0
- mlflow/data/schema.py +76 -0
- mlflow/data/sources.py +1 -0
- mlflow/data/spark_dataset.py +406 -0
- mlflow/data/spark_dataset_source.py +74 -0
- mlflow/data/spark_delta_utils.py +118 -0
- mlflow/data/tensorflow_dataset.py +350 -0
- mlflow/data/uc_volume_dataset_source.py +81 -0
- mlflow/db.py +27 -0
- mlflow/dspy/__init__.py +17 -0
- mlflow/dspy/autolog.py +197 -0
- mlflow/dspy/callback.py +398 -0
- mlflow/dspy/constant.py +1 -0
- mlflow/dspy/load.py +93 -0
- mlflow/dspy/save.py +393 -0
- mlflow/dspy/util.py +109 -0
- mlflow/dspy/wrapper.py +226 -0
- mlflow/entities/__init__.py +104 -0
- mlflow/entities/_mlflow_object.py +52 -0
- mlflow/entities/assessment.py +545 -0
- mlflow/entities/assessment_error.py +80 -0
- mlflow/entities/assessment_source.py +141 -0
- mlflow/entities/dataset.py +92 -0
- mlflow/entities/dataset_input.py +51 -0
- mlflow/entities/dataset_summary.py +62 -0
- mlflow/entities/document.py +48 -0
- mlflow/entities/experiment.py +109 -0
- mlflow/entities/experiment_tag.py +35 -0
- mlflow/entities/file_info.py +45 -0
- mlflow/entities/input_tag.py +35 -0
- mlflow/entities/lifecycle_stage.py +35 -0
- mlflow/entities/logged_model.py +228 -0
- mlflow/entities/logged_model_input.py +26 -0
- mlflow/entities/logged_model_output.py +32 -0
- mlflow/entities/logged_model_parameter.py +46 -0
- mlflow/entities/logged_model_status.py +74 -0
- mlflow/entities/logged_model_tag.py +33 -0
- mlflow/entities/metric.py +200 -0
- mlflow/entities/model_registry/__init__.py +29 -0
- mlflow/entities/model_registry/_model_registry_entity.py +13 -0
- mlflow/entities/model_registry/model_version.py +243 -0
- mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
- mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
- mlflow/entities/model_registry/model_version_search.py +25 -0
- mlflow/entities/model_registry/model_version_stages.py +25 -0
- mlflow/entities/model_registry/model_version_status.py +35 -0
- mlflow/entities/model_registry/model_version_tag.py +35 -0
- mlflow/entities/model_registry/prompt.py +73 -0
- mlflow/entities/model_registry/prompt_version.py +244 -0
- mlflow/entities/model_registry/registered_model.py +175 -0
- mlflow/entities/model_registry/registered_model_alias.py +35 -0
- mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
- mlflow/entities/model_registry/registered_model_search.py +25 -0
- mlflow/entities/model_registry/registered_model_tag.py +35 -0
- mlflow/entities/multipart_upload.py +74 -0
- mlflow/entities/param.py +49 -0
- mlflow/entities/run.py +97 -0
- mlflow/entities/run_data.py +84 -0
- mlflow/entities/run_info.py +188 -0
- mlflow/entities/run_inputs.py +59 -0
- mlflow/entities/run_outputs.py +43 -0
- mlflow/entities/run_status.py +41 -0
- mlflow/entities/run_tag.py +36 -0
- mlflow/entities/source_type.py +31 -0
- mlflow/entities/span.py +774 -0
- mlflow/entities/span_event.py +96 -0
- mlflow/entities/span_status.py +102 -0
- mlflow/entities/trace.py +317 -0
- mlflow/entities/trace_data.py +71 -0
- mlflow/entities/trace_info.py +220 -0
- mlflow/entities/trace_info_v2.py +162 -0
- mlflow/entities/trace_location.py +173 -0
- mlflow/entities/trace_state.py +39 -0
- mlflow/entities/trace_status.py +68 -0
- mlflow/entities/view_type.py +51 -0
- mlflow/environment_variables.py +866 -0
- mlflow/evaluation/__init__.py +16 -0
- mlflow/evaluation/assessment.py +369 -0
- mlflow/evaluation/evaluation.py +411 -0
- mlflow/evaluation/evaluation_tag.py +61 -0
- mlflow/evaluation/fluent.py +48 -0
- mlflow/evaluation/utils.py +201 -0
- mlflow/exceptions.py +213 -0
- mlflow/experiments.py +140 -0
- mlflow/gemini/__init__.py +81 -0
- mlflow/gemini/autolog.py +186 -0
- mlflow/gemini/chat.py +261 -0
- mlflow/genai/__init__.py +71 -0
- mlflow/genai/datasets/__init__.py +67 -0
- mlflow/genai/datasets/evaluation_dataset.py +131 -0
- mlflow/genai/evaluation/__init__.py +3 -0
- mlflow/genai/evaluation/base.py +411 -0
- mlflow/genai/evaluation/constant.py +23 -0
- mlflow/genai/evaluation/utils.py +244 -0
- mlflow/genai/judges/__init__.py +21 -0
- mlflow/genai/judges/databricks.py +404 -0
- mlflow/genai/label_schemas/__init__.py +153 -0
- mlflow/genai/label_schemas/label_schemas.py +209 -0
- mlflow/genai/labeling/__init__.py +159 -0
- mlflow/genai/labeling/labeling.py +250 -0
- mlflow/genai/optimize/__init__.py +13 -0
- mlflow/genai/optimize/base.py +198 -0
- mlflow/genai/optimize/optimizers/__init__.py +4 -0
- mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
- mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
- mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
- mlflow/genai/optimize/types.py +75 -0
- mlflow/genai/optimize/util.py +30 -0
- mlflow/genai/prompts/__init__.py +206 -0
- mlflow/genai/scheduled_scorers.py +431 -0
- mlflow/genai/scorers/__init__.py +26 -0
- mlflow/genai/scorers/base.py +492 -0
- mlflow/genai/scorers/builtin_scorers.py +765 -0
- mlflow/genai/scorers/scorer_utils.py +138 -0
- mlflow/genai/scorers/validation.py +165 -0
- mlflow/genai/utils/data_validation.py +146 -0
- mlflow/genai/utils/enum_utils.py +23 -0
- mlflow/genai/utils/trace_utils.py +211 -0
- mlflow/groq/__init__.py +42 -0
- mlflow/groq/_groq_autolog.py +74 -0
- mlflow/johnsnowlabs/__init__.py +888 -0
- mlflow/langchain/__init__.py +24 -0
- mlflow/langchain/api_request_parallel_processor.py +330 -0
- mlflow/langchain/autolog.py +147 -0
- mlflow/langchain/chat_agent_langgraph.py +340 -0
- mlflow/langchain/constant.py +1 -0
- mlflow/langchain/constants.py +1 -0
- mlflow/langchain/databricks_dependencies.py +444 -0
- mlflow/langchain/langchain_tracer.py +597 -0
- mlflow/langchain/model.py +919 -0
- mlflow/langchain/output_parsers.py +142 -0
- mlflow/langchain/retriever_chain.py +153 -0
- mlflow/langchain/runnables.py +527 -0
- mlflow/langchain/utils/chat.py +402 -0
- mlflow/langchain/utils/logging.py +671 -0
- mlflow/langchain/utils/serialization.py +36 -0
- mlflow/legacy_databricks_cli/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/provider.py +482 -0
- mlflow/litellm/__init__.py +175 -0
- mlflow/llama_index/__init__.py +22 -0
- mlflow/llama_index/autolog.py +55 -0
- mlflow/llama_index/chat.py +43 -0
- mlflow/llama_index/constant.py +1 -0
- mlflow/llama_index/model.py +577 -0
- mlflow/llama_index/pyfunc_wrapper.py +332 -0
- mlflow/llama_index/serialize_objects.py +188 -0
- mlflow/llama_index/tracer.py +561 -0
- mlflow/metrics/__init__.py +479 -0
- mlflow/metrics/base.py +39 -0
- mlflow/metrics/genai/__init__.py +25 -0
- mlflow/metrics/genai/base.py +101 -0
- mlflow/metrics/genai/genai_metric.py +771 -0
- mlflow/metrics/genai/metric_definitions.py +450 -0
- mlflow/metrics/genai/model_utils.py +371 -0
- mlflow/metrics/genai/prompt_template.py +68 -0
- mlflow/metrics/genai/prompts/__init__.py +0 -0
- mlflow/metrics/genai/prompts/v1.py +422 -0
- mlflow/metrics/genai/utils.py +6 -0
- mlflow/metrics/metric_definitions.py +619 -0
- mlflow/mismatch.py +34 -0
- mlflow/mistral/__init__.py +34 -0
- mlflow/mistral/autolog.py +71 -0
- mlflow/mistral/chat.py +135 -0
- mlflow/ml_package_versions.py +452 -0
- mlflow/models/__init__.py +97 -0
- mlflow/models/auth_policy.py +83 -0
- mlflow/models/cli.py +354 -0
- mlflow/models/container/__init__.py +294 -0
- mlflow/models/container/scoring_server/__init__.py +0 -0
- mlflow/models/container/scoring_server/nginx.conf +39 -0
- mlflow/models/dependencies_schemas.py +287 -0
- mlflow/models/display_utils.py +158 -0
- mlflow/models/docker_utils.py +211 -0
- mlflow/models/evaluation/__init__.py +23 -0
- mlflow/models/evaluation/_shap_patch.py +64 -0
- mlflow/models/evaluation/artifacts.py +194 -0
- mlflow/models/evaluation/base.py +1811 -0
- mlflow/models/evaluation/calibration_curve.py +109 -0
- mlflow/models/evaluation/default_evaluator.py +996 -0
- mlflow/models/evaluation/deprecated.py +23 -0
- mlflow/models/evaluation/evaluator_registry.py +80 -0
- mlflow/models/evaluation/evaluators/classifier.py +704 -0
- mlflow/models/evaluation/evaluators/default.py +233 -0
- mlflow/models/evaluation/evaluators/regressor.py +96 -0
- mlflow/models/evaluation/evaluators/shap.py +296 -0
- mlflow/models/evaluation/lift_curve.py +178 -0
- mlflow/models/evaluation/utils/metric.py +123 -0
- mlflow/models/evaluation/utils/trace.py +179 -0
- mlflow/models/evaluation/validation.py +434 -0
- mlflow/models/flavor_backend.py +93 -0
- mlflow/models/flavor_backend_registry.py +53 -0
- mlflow/models/model.py +1639 -0
- mlflow/models/model_config.py +150 -0
- mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
- mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
- mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
- mlflow/models/python_api.py +369 -0
- mlflow/models/rag_signatures.py +128 -0
- mlflow/models/resources.py +321 -0
- mlflow/models/signature.py +662 -0
- mlflow/models/utils.py +2054 -0
- mlflow/models/wheeled_model.py +280 -0
- mlflow/openai/__init__.py +57 -0
- mlflow/openai/_agent_tracer.py +364 -0
- mlflow/openai/api_request_parallel_processor.py +131 -0
- mlflow/openai/autolog.py +509 -0
- mlflow/openai/constant.py +1 -0
- mlflow/openai/model.py +824 -0
- mlflow/openai/utils/chat_schema.py +367 -0
- mlflow/optuna/__init__.py +3 -0
- mlflow/optuna/storage.py +646 -0
- mlflow/plugins/__init__.py +72 -0
- mlflow/plugins/base.py +358 -0
- mlflow/plugins/builtin/__init__.py +24 -0
- mlflow/plugins/builtin/pytorch_plugin.py +150 -0
- mlflow/plugins/builtin/sklearn_plugin.py +158 -0
- mlflow/plugins/builtin/transformers_plugin.py +187 -0
- mlflow/plugins/cli.py +321 -0
- mlflow/plugins/discovery.py +340 -0
- mlflow/plugins/manager.py +465 -0
- mlflow/plugins/registry.py +316 -0
- mlflow/plugins/templates/framework_plugin_template.py +329 -0
- mlflow/prompt/constants.py +20 -0
- mlflow/prompt/promptlab_model.py +197 -0
- mlflow/prompt/registry_utils.py +248 -0
- mlflow/promptflow/__init__.py +495 -0
- mlflow/protos/__init__.py +0 -0
- mlflow/protos/assessments_pb2.py +174 -0
- mlflow/protos/databricks_artifacts_pb2.py +489 -0
- mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
- mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
- mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
- mlflow/protos/databricks_pb2.py +267 -0
- mlflow/protos/databricks_trace_server_pb2.py +374 -0
- mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
- mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
- mlflow/protos/facet_feature_statistics_pb2.py +296 -0
- mlflow/protos/internal_pb2.py +77 -0
- mlflow/protos/mlflow_artifacts_pb2.py +336 -0
- mlflow/protos/model_registry_pb2.py +1073 -0
- mlflow/protos/scalapb/__init__.py +0 -0
- mlflow/protos/scalapb/scalapb_pb2.py +104 -0
- mlflow/protos/service_pb2.py +2600 -0
- mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
- mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
- mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
- mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
- mlflow/py.typed +0 -0
- mlflow/pydantic_ai/__init__.py +57 -0
- mlflow/pydantic_ai/autolog.py +173 -0
- mlflow/pyfunc/__init__.py +3844 -0
- mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
- mlflow/pyfunc/backend.py +523 -0
- mlflow/pyfunc/context.py +78 -0
- mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
- mlflow/pyfunc/loaders/__init__.py +7 -0
- mlflow/pyfunc/loaders/chat_agent.py +117 -0
- mlflow/pyfunc/loaders/chat_model.py +125 -0
- mlflow/pyfunc/loaders/code_model.py +31 -0
- mlflow/pyfunc/loaders/responses_agent.py +112 -0
- mlflow/pyfunc/mlserver.py +46 -0
- mlflow/pyfunc/model.py +1473 -0
- mlflow/pyfunc/scoring_server/__init__.py +604 -0
- mlflow/pyfunc/scoring_server/app.py +7 -0
- mlflow/pyfunc/scoring_server/client.py +146 -0
- mlflow/pyfunc/spark_model_cache.py +48 -0
- mlflow/pyfunc/stdin_server.py +44 -0
- mlflow/pyfunc/utils/__init__.py +3 -0
- mlflow/pyfunc/utils/data_validation.py +224 -0
- mlflow/pyfunc/utils/environment.py +22 -0
- mlflow/pyfunc/utils/input_converter.py +47 -0
- mlflow/pyfunc/utils/serving_data_parser.py +11 -0
- mlflow/pytorch/__init__.py +1171 -0
- mlflow/pytorch/_lightning_autolog.py +580 -0
- mlflow/pytorch/_pytorch_autolog.py +50 -0
- mlflow/pytorch/pickle_module.py +35 -0
- mlflow/rfunc/__init__.py +42 -0
- mlflow/rfunc/backend.py +134 -0
- mlflow/runs.py +89 -0
- mlflow/server/__init__.py +302 -0
- mlflow/server/auth/__init__.py +1224 -0
- mlflow/server/auth/__main__.py +4 -0
- mlflow/server/auth/basic_auth.ini +6 -0
- mlflow/server/auth/cli.py +11 -0
- mlflow/server/auth/client.py +537 -0
- mlflow/server/auth/config.py +34 -0
- mlflow/server/auth/db/__init__.py +0 -0
- mlflow/server/auth/db/cli.py +18 -0
- mlflow/server/auth/db/migrations/__init__.py +0 -0
- mlflow/server/auth/db/migrations/alembic.ini +110 -0
- mlflow/server/auth/db/migrations/env.py +76 -0
- mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
- mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
- mlflow/server/auth/db/models.py +67 -0
- mlflow/server/auth/db/utils.py +37 -0
- mlflow/server/auth/entities.py +165 -0
- mlflow/server/auth/logo.py +14 -0
- mlflow/server/auth/permissions.py +65 -0
- mlflow/server/auth/routes.py +18 -0
- mlflow/server/auth/sqlalchemy_store.py +263 -0
- mlflow/server/graphql/__init__.py +0 -0
- mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
- mlflow/server/graphql/graphql_custom_scalars.py +24 -0
- mlflow/server/graphql/graphql_errors.py +15 -0
- mlflow/server/graphql/graphql_no_batching.py +89 -0
- mlflow/server/graphql/graphql_schema_extensions.py +74 -0
- mlflow/server/handlers.py +3217 -0
- mlflow/server/prometheus_exporter.py +17 -0
- mlflow/server/validation.py +30 -0
- mlflow/shap/__init__.py +691 -0
- mlflow/sklearn/__init__.py +1994 -0
- mlflow/sklearn/utils.py +1041 -0
- mlflow/smolagents/__init__.py +66 -0
- mlflow/smolagents/autolog.py +139 -0
- mlflow/smolagents/chat.py +29 -0
- mlflow/store/__init__.py +10 -0
- mlflow/store/_unity_catalog/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/constants.py +2 -0
- mlflow/store/_unity_catalog/registry/__init__.py +6 -0
- mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
- mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
- mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
- mlflow/store/_unity_catalog/registry/utils.py +121 -0
- mlflow/store/artifact/__init__.py +0 -0
- mlflow/store/artifact/artifact_repo.py +472 -0
- mlflow/store/artifact/artifact_repository_registry.py +154 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
- mlflow/store/artifact/cli.py +141 -0
- mlflow/store/artifact/cloud_artifact_repo.py +332 -0
- mlflow/store/artifact/databricks_artifact_repo.py +729 -0
- mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
- mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
- mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
- mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
- mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
- mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
- mlflow/store/artifact/ftp_artifact_repo.py +132 -0
- mlflow/store/artifact/gcs_artifact_repo.py +296 -0
- mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
- mlflow/store/artifact/http_artifact_repo.py +218 -0
- mlflow/store/artifact/local_artifact_repo.py +142 -0
- mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
- mlflow/store/artifact/models_artifact_repo.py +259 -0
- mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
- mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
- mlflow/store/artifact/r2_artifact_repo.py +70 -0
- mlflow/store/artifact/runs_artifact_repo.py +265 -0
- mlflow/store/artifact/s3_artifact_repo.py +330 -0
- mlflow/store/artifact/sftp_artifact_repo.py +141 -0
- mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
- mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
- mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
- mlflow/store/artifact/utils/__init__.py +0 -0
- mlflow/store/artifact/utils/models.py +148 -0
- mlflow/store/db/__init__.py +0 -0
- mlflow/store/db/base_sql_model.py +3 -0
- mlflow/store/db/db_types.py +10 -0
- mlflow/store/db/utils.py +314 -0
- mlflow/store/db_migrations/__init__.py +0 -0
- mlflow/store/db_migrations/alembic.ini +74 -0
- mlflow/store/db_migrations/env.py +84 -0
- mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
- mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
- mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
- mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
- mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
- mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
- mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
- mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
- mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
- mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
- mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
- mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
- mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
- mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
- mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
- mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
- mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
- mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
- mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
- mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
- mlflow/store/db_migrations/versions/__init__.py +0 -0
- mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
- mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
- mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
- mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
- mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
- mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
- mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
- mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
- mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
- mlflow/store/entities/__init__.py +3 -0
- mlflow/store/entities/paged_list.py +18 -0
- mlflow/store/model_registry/__init__.py +10 -0
- mlflow/store/model_registry/abstract_store.py +1081 -0
- mlflow/store/model_registry/base_rest_store.py +44 -0
- mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
- mlflow/store/model_registry/dbmodels/__init__.py +0 -0
- mlflow/store/model_registry/dbmodels/models.py +206 -0
- mlflow/store/model_registry/file_store.py +1091 -0
- mlflow/store/model_registry/rest_store.py +481 -0
- mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
- mlflow/store/tracking/__init__.py +23 -0
- mlflow/store/tracking/abstract_store.py +816 -0
- mlflow/store/tracking/dbmodels/__init__.py +0 -0
- mlflow/store/tracking/dbmodels/initial_models.py +243 -0
- mlflow/store/tracking/dbmodels/models.py +1073 -0
- mlflow/store/tracking/file_store.py +2438 -0
- mlflow/store/tracking/postgres_managed_identity.py +146 -0
- mlflow/store/tracking/rest_store.py +1131 -0
- mlflow/store/tracking/sqlalchemy_store.py +2785 -0
- mlflow/system_metrics/__init__.py +61 -0
- mlflow/system_metrics/metrics/__init__.py +0 -0
- mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
- mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
- mlflow/system_metrics/metrics/disk_monitor.py +21 -0
- mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
- mlflow/system_metrics/metrics/network_monitor.py +34 -0
- mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
- mlflow/system_metrics/system_metrics_monitor.py +198 -0
- mlflow/tracing/__init__.py +16 -0
- mlflow/tracing/assessment.py +356 -0
- mlflow/tracing/client.py +531 -0
- mlflow/tracing/config.py +125 -0
- mlflow/tracing/constant.py +105 -0
- mlflow/tracing/destination.py +81 -0
- mlflow/tracing/display/__init__.py +40 -0
- mlflow/tracing/display/display_handler.py +196 -0
- mlflow/tracing/export/async_export_queue.py +186 -0
- mlflow/tracing/export/inference_table.py +138 -0
- mlflow/tracing/export/mlflow_v3.py +137 -0
- mlflow/tracing/export/utils.py +70 -0
- mlflow/tracing/fluent.py +1417 -0
- mlflow/tracing/processor/base_mlflow.py +199 -0
- mlflow/tracing/processor/inference_table.py +175 -0
- mlflow/tracing/processor/mlflow_v3.py +47 -0
- mlflow/tracing/processor/otel.py +73 -0
- mlflow/tracing/provider.py +487 -0
- mlflow/tracing/trace_manager.py +200 -0
- mlflow/tracing/utils/__init__.py +616 -0
- mlflow/tracing/utils/artifact_utils.py +28 -0
- mlflow/tracing/utils/copy.py +55 -0
- mlflow/tracing/utils/environment.py +55 -0
- mlflow/tracing/utils/exception.py +21 -0
- mlflow/tracing/utils/once.py +35 -0
- mlflow/tracing/utils/otlp.py +63 -0
- mlflow/tracing/utils/processor.py +54 -0
- mlflow/tracing/utils/search.py +292 -0
- mlflow/tracing/utils/timeout.py +250 -0
- mlflow/tracing/utils/token.py +19 -0
- mlflow/tracing/utils/truncation.py +124 -0
- mlflow/tracing/utils/warning.py +76 -0
- mlflow/tracking/__init__.py +39 -0
- mlflow/tracking/_model_registry/__init__.py +1 -0
- mlflow/tracking/_model_registry/client.py +764 -0
- mlflow/tracking/_model_registry/fluent.py +853 -0
- mlflow/tracking/_model_registry/registry.py +67 -0
- mlflow/tracking/_model_registry/utils.py +251 -0
- mlflow/tracking/_tracking_service/__init__.py +0 -0
- mlflow/tracking/_tracking_service/client.py +883 -0
- mlflow/tracking/_tracking_service/registry.py +56 -0
- mlflow/tracking/_tracking_service/utils.py +275 -0
- mlflow/tracking/artifact_utils.py +179 -0
- mlflow/tracking/client.py +5900 -0
- mlflow/tracking/context/__init__.py +0 -0
- mlflow/tracking/context/abstract_context.py +35 -0
- mlflow/tracking/context/databricks_cluster_context.py +15 -0
- mlflow/tracking/context/databricks_command_context.py +15 -0
- mlflow/tracking/context/databricks_job_context.py +49 -0
- mlflow/tracking/context/databricks_notebook_context.py +41 -0
- mlflow/tracking/context/databricks_repo_context.py +43 -0
- mlflow/tracking/context/default_context.py +51 -0
- mlflow/tracking/context/git_context.py +32 -0
- mlflow/tracking/context/registry.py +98 -0
- mlflow/tracking/context/system_environment_context.py +15 -0
- mlflow/tracking/default_experiment/__init__.py +1 -0
- mlflow/tracking/default_experiment/abstract_context.py +43 -0
- mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
- mlflow/tracking/default_experiment/registry.py +75 -0
- mlflow/tracking/fluent.py +3595 -0
- mlflow/tracking/metric_value_conversion_utils.py +93 -0
- mlflow/tracking/multimedia.py +206 -0
- mlflow/tracking/registry.py +86 -0
- mlflow/tracking/request_auth/__init__.py +0 -0
- mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
- mlflow/tracking/request_auth/registry.py +60 -0
- mlflow/tracking/request_header/__init__.py +0 -0
- mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
- mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
- mlflow/tracking/request_header/default_request_header_provider.py +17 -0
- mlflow/tracking/request_header/registry.py +79 -0
- mlflow/transformers/__init__.py +2982 -0
- mlflow/transformers/flavor_config.py +258 -0
- mlflow/transformers/hub_utils.py +83 -0
- mlflow/transformers/llm_inference_utils.py +468 -0
- mlflow/transformers/model_io.py +301 -0
- mlflow/transformers/peft.py +51 -0
- mlflow/transformers/signature.py +183 -0
- mlflow/transformers/torch_utils.py +55 -0
- mlflow/types/__init__.py +21 -0
- mlflow/types/agent.py +270 -0
- mlflow/types/chat.py +240 -0
- mlflow/types/llm.py +935 -0
- mlflow/types/responses.py +139 -0
- mlflow/types/responses_helpers.py +416 -0
- mlflow/types/schema.py +1505 -0
- mlflow/types/type_hints.py +647 -0
- mlflow/types/utils.py +753 -0
- mlflow/utils/__init__.py +283 -0
- mlflow/utils/_capture_modules.py +256 -0
- mlflow/utils/_capture_transformers_modules.py +75 -0
- mlflow/utils/_spark_utils.py +201 -0
- mlflow/utils/_unity_catalog_oss_utils.py +97 -0
- mlflow/utils/_unity_catalog_utils.py +479 -0
- mlflow/utils/annotations.py +218 -0
- mlflow/utils/arguments_utils.py +16 -0
- mlflow/utils/async_logging/__init__.py +1 -0
- mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
- mlflow/utils/async_logging/async_logging_queue.py +366 -0
- mlflow/utils/async_logging/run_artifact.py +38 -0
- mlflow/utils/async_logging/run_batch.py +58 -0
- mlflow/utils/async_logging/run_operations.py +49 -0
- mlflow/utils/autologging_utils/__init__.py +737 -0
- mlflow/utils/autologging_utils/client.py +432 -0
- mlflow/utils/autologging_utils/config.py +33 -0
- mlflow/utils/autologging_utils/events.py +294 -0
- mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
- mlflow/utils/autologging_utils/metrics_queue.py +71 -0
- mlflow/utils/autologging_utils/safety.py +1104 -0
- mlflow/utils/autologging_utils/versioning.py +95 -0
- mlflow/utils/checkpoint_utils.py +206 -0
- mlflow/utils/class_utils.py +6 -0
- mlflow/utils/cli_args.py +257 -0
- mlflow/utils/conda.py +354 -0
- mlflow/utils/credentials.py +231 -0
- mlflow/utils/data_utils.py +17 -0
- mlflow/utils/databricks_utils.py +1436 -0
- mlflow/utils/docstring_utils.py +477 -0
- mlflow/utils/doctor.py +133 -0
- mlflow/utils/download_cloud_file_chunk.py +43 -0
- mlflow/utils/env_manager.py +16 -0
- mlflow/utils/env_pack.py +131 -0
- mlflow/utils/environment.py +1009 -0
- mlflow/utils/exception_utils.py +14 -0
- mlflow/utils/file_utils.py +978 -0
- mlflow/utils/git_utils.py +77 -0
- mlflow/utils/gorilla.py +797 -0
- mlflow/utils/import_hooks/__init__.py +363 -0
- mlflow/utils/lazy_load.py +51 -0
- mlflow/utils/logging_utils.py +168 -0
- mlflow/utils/mime_type_utils.py +58 -0
- mlflow/utils/mlflow_tags.py +103 -0
- mlflow/utils/model_utils.py +486 -0
- mlflow/utils/name_utils.py +346 -0
- mlflow/utils/nfs_on_spark.py +62 -0
- mlflow/utils/openai_utils.py +164 -0
- mlflow/utils/os.py +12 -0
- mlflow/utils/oss_registry_utils.py +29 -0
- mlflow/utils/plugins.py +17 -0
- mlflow/utils/process.py +182 -0
- mlflow/utils/promptlab_utils.py +146 -0
- mlflow/utils/proto_json_utils.py +743 -0
- mlflow/utils/pydantic_utils.py +54 -0
- mlflow/utils/request_utils.py +279 -0
- mlflow/utils/requirements_utils.py +704 -0
- mlflow/utils/rest_utils.py +673 -0
- mlflow/utils/search_logged_model_utils.py +127 -0
- mlflow/utils/search_utils.py +2111 -0
- mlflow/utils/secure_loading.py +221 -0
- mlflow/utils/security_validation.py +384 -0
- mlflow/utils/server_cli_utils.py +61 -0
- mlflow/utils/spark_utils.py +15 -0
- mlflow/utils/string_utils.py +138 -0
- mlflow/utils/thread_utils.py +63 -0
- mlflow/utils/time.py +54 -0
- mlflow/utils/timeout.py +42 -0
- mlflow/utils/uri.py +572 -0
- mlflow/utils/validation.py +662 -0
- mlflow/utils/virtualenv.py +458 -0
- mlflow/utils/warnings_utils.py +25 -0
- mlflow/utils/yaml_utils.py +179 -0
- mlflow/version.py +24 -0
@@ -0,0 +1,1104 @@
|
|
1
|
+
import abc
|
2
|
+
import functools
|
3
|
+
import inspect
|
4
|
+
import itertools
|
5
|
+
import uuid
|
6
|
+
from contextlib import asynccontextmanager, contextmanager
|
7
|
+
from typing import Any, Callable, NamedTuple, Optional
|
8
|
+
|
9
|
+
import mlflow
|
10
|
+
import mlflow.utils.autologging_utils
|
11
|
+
from mlflow.entities.run_status import RunStatus
|
12
|
+
from mlflow.environment_variables import _MLFLOW_AUTOLOGGING_TESTING
|
13
|
+
from mlflow.exceptions import MlflowException
|
14
|
+
from mlflow.utils import gorilla, is_iterator
|
15
|
+
from mlflow.utils.autologging_utils import _logger
|
16
|
+
from mlflow.utils.autologging_utils.events import AutologgingEventLoggerWrapper
|
17
|
+
from mlflow.utils.autologging_utils.logging_and_warnings import (
|
18
|
+
MlflowEventsAndWarningsBehaviorGlobally,
|
19
|
+
NonMlflowWarningsBehaviorForCurrentThread,
|
20
|
+
)
|
21
|
+
from mlflow.utils.mlflow_tags import MLFLOW_AUTOLOGGING
|
22
|
+
|
23
|
+
# Registry of applied patches keyed by autologging integration name. `revert_patches`
# looks up the entry for an integration, reverts every patch in it and then drops the
# entry; `safe_patch` records each new patch via `_store_patch` (defined elsewhere in
# this file -- presumably it appends to this mapping; confirm there).
_AUTOLOGGING_PATCHES = {}


# Function attribute used for testing purposes to verify that a given function
# has been wrapped with the `exception_safe_function_for_class` and
# `picklable_exception_safe_function` decorators
_ATTRIBUTE_EXCEPTION_SAFE = "exception_safe"


# Warning template used by `safe_patch` when patch code fails: the first placeholder
# receives the autologging integration name, the second the captured exception.
_ERROR_MSG = "Encountered unexpected error during {} autologging: {}"
|
33
|
+
|
34
|
+
|
35
|
+
def exception_safe_function_for_class(function):
    """
    Return a wrapper around `function` that shields callers from unexpected
    autologging errors.

    Outside of test mode, any `Exception` raised by `function` is logged as a warning
    and the wrapper returns `None`; in test mode the exception is re-raised unchanged.

    The wrapper is a locally defined closure and therefore cannot be pickled itself.
    A class instance whose methods are decorated with this function should still be
    picklable, because pickle only saves instance attributes, not methods.
    See https://docs.python.org/3/library/pickle.html#pickling-class-instances for more details.

    Args:
        function: The callable to guard.

    Returns:
        The guarded callable, passed through `update_wrapper_extended` together with
        the original `function`.
    """
    # Test mode only: tag the original callable so tests can verify it was wrapped.
    if is_testing():
        setattr(function, _ATTRIBUTE_EXCEPTION_SAFE, True)

    def guarded_call(*call_args, **call_kwargs):
        try:
            outcome = function(*call_args, **call_kwargs)
        except Exception as err:
            # Surface the failure during tests; otherwise degrade to a warning.
            if is_testing():
                raise
            _logger.warning("Encountered unexpected error during autologging: %s", err)
            return None
        return outcome

    return update_wrapper_extended(guarded_call, function)
|
57
|
+
|
58
|
+
|
59
|
+
def _safe_function(function, *args, **kwargs):
    """
    Call `function(*args, **kwargs)` with broad exception handling.

    This is a module-level function (rather than a closure) so that it can be bound
    with `functools.partial` by `picklable_exception_safe_function`.

    Returns:
        Whatever `function` returns. If `function` raises an `Exception` outside of
        test mode, the error is logged as a warning and `None` is returned; in test
        mode the exception is re-raised.
    """
    try:
        outcome = function(*args, **kwargs)
    except Exception as err:
        if is_testing():
            raise
        _logger.warning("Encountered unexpected error during autologging: %s", err)
        return None
    return outcome
|
67
|
+
|
68
|
+
|
69
|
+
def picklable_exception_safe_function(function):
    """
    Return an exception-guarded version of `function` built from a `functools.partial`
    over the module-level `_safe_function`, instead of a local closure, so that
    picklability is preserved.

    Args:
        function: The callable to guard.

    Returns:
        The partial object, passed through `update_wrapper_extended` together with
        the original `function`.
    """
    # Test mode only: tag the original callable so tests can verify it was wrapped.
    if is_testing():
        setattr(function, _ATTRIBUTE_EXCEPTION_SAFE, True)

    guarded = functools.partial(_safe_function, function)
    return update_wrapper_extended(guarded, function)
|
78
|
+
|
79
|
+
|
80
|
+
def _exception_safe_class_factory(base_class):
    """
    Build an exception safe metaclass deriving from `base_class`.

    Args:
        base_class: The metaclass to extend (the module uses `type` and `abc.ABCMeta`).

    Returns:
        A metaclass that wraps every callable entry of a new class's namespace with
        `exception_safe_function_for_class` before delegating class creation to
        `base_class`.
    """

    class _ExceptionSafeClass(base_class):
        """
        Metaclass that wraps all functions defined on the specified class with broad error
        handling logic to guard against unexpected errors during autologging.

        Rationale: Patched autologging functions commonly pass additional class instances as
        arguments to their underlying original training routines; for example, Keras autologging
        constructs a subclass of `keras.callbacks.Callback` and forwards it to `Model.fit()`.
        To prevent errors encountered during method execution within such classes from disrupting
        model training, this metaclass wraps all class functions in a broad try / catch statement.

        Note: `ExceptionSafeClass` does not handle exceptions in class methods or static methods,
        as these are not always Python callables and are difficult to wrap
        """

        def __new__(cls, name, bases, dct):
            for attr_name in dct:
                attr_value = dct[attr_name]
                # class methods or static methods are not callable, so they are left as-is.
                if not callable(attr_value):
                    continue
                dct[attr_name] = exception_safe_function_for_class(attr_value)
            return base_class.__new__(cls, name, bases, dct)

    return _ExceptionSafeClass
|
108
|
+
|
109
|
+
|
110
|
+
# Exception safe metaclass for ordinary classes, i.e. those whose metaclass is `type`.
ExceptionSafeClass = _exception_safe_class_factory(type)

# `ExceptionSafeClass` causes an error when used with an abstract class.
#
# ```
# class AbstractClass(abc.ABC):
#     ...
#
# class DerivedClass(AbstractClass, metaclass=ExceptionSafeClass):
#     ...
# ```
#
# This raises:
#
# ```
# TypeError: metaclass conflict: the metaclass of a derived class must be
#            a (non-strict) subclass of the metaclasses of all its bases.
# ```
#
# To avoid this error, create `ExceptionSafeAbstractClass` that is based on `abc.ABCMeta`.
ExceptionSafeAbstractClass = _exception_safe_class_factory(abc.ABCMeta)
|
131
|
+
|
132
|
+
|
133
|
+
def with_managed_run(autologging_integration, patch_function, tags=None):
    """Wrap `patch_function` so that it always executes with an MLflow run available.

    The returned `augmented_patch_function` behaves as follows:

    - A run is started only when there is no active run and no active training session
      at the time the patch function is executed.

    - A run started by the wrapper is ended with the `FINISHED` status once
      `patch_function` returns.

    - A run started by the wrapper is ended with the `FAILED` status if `patch_function`
      raises an exception or a `KeyboardInterrupt`; the exception is then re-raised.

    Note that, if nested runs or non-fluent runs are created by `patch_function`,
    `patch_function` is responsible for terminating them by the time it terminates
    (or in the event of an exception).

    Args:
        autologging_integration: The autologging integration associated
            with the `patch_function`.
        patch_function: A function object compatible with `safe_patch`.
        tags: A dictionary of string tags to set on each managed run created during the
            execution of `patch_function`.

    Returns:
        A function with the signature `(original, *args, **kwargs)` that forwards to
        `patch_function` and returns its result.
    """
    from mlflow.tracking.fluent import active_run
    from mlflow.utils.autologging_utils import _has_active_training_session

    def patch_with_managed_run(original, *args, **kwargs):
        # An active training session without an active run in the current thread means
        # this thread was spawned by `estimator.fit` as a worker thread. Autologging
        # should be disabled in such worker threads, so no managed run is created there.
        needs_managed_run = not (active_run() or _has_active_training_session())

        managed_run = None
        if needs_managed_run:
            managed_run = mlflow.start_run(tags=tags)
            _logger.info(
                "Created MLflow autologging run with ID '%s', which will track hyperparameters,"
                " performance metrics, model artifacts, and lineage information for the"
                " current %s workflow",
                managed_run.info.run_id,
                autologging_integration,
            )

        try:
            result = patch_function(original, *args, **kwargs)
        except (Exception, KeyboardInterrupt):
            # Keyboard interrupts are handled alongside regular exceptions so that the
            # run is terminated when a user aborts training early (e.g. sigint / ctrl-c).
            if managed_run:
                mlflow.end_run(RunStatus.to_string(RunStatus.FAILED))
            raise

        if managed_run:
            mlflow.end_run(RunStatus.to_string(RunStatus.FINISHED))
        return result

    return patch_with_managed_run
|
195
|
+
|
196
|
+
|
197
|
+
def is_testing():
    """
    Report whether autologging is running in test mode, as read from the
    `_MLFLOW_AUTOLOGGING_TESTING` environment variable object. Test mode performs
    additional validation during autologging, including:

        - Checks for the exception safety of arguments passed to model training functions
          (i.e. all additional arguments should be "exception safe" functions or classes)
        - Disables exception handling for patched function logic, ensuring that patch code
          executes without errors during testing

    Returns:
        The value produced by `_MLFLOW_AUTOLOGGING_TESTING.get()`.
    """
    testing_enabled = _MLFLOW_AUTOLOGGING_TESTING.get()
    return testing_enabled
|
209
|
+
|
210
|
+
|
211
|
+
def _resolve_extra_tags(autologging_integration, extra_tags):
    """
    Build the tag dictionary applied to runs managed by autologging.

    The result always maps the reserved `MLFLOW_AUTOLOGGING` tag to
    `autologging_integration`; user-supplied `extra_tags` are merged on top of it.

    Note: when `extra_tags` contains the reserved `MLFLOW_AUTOLOGGING` key, that key is
    removed from the caller's dictionary *in place* and a warning is logged.

    Args:
        autologging_integration: Name of the autologging integration.
        extra_tags: Optional dictionary of additional tags. Falsy values are ignored.

    Returns:
        A new dictionary of tags.

    Raises:
        MlflowException: If `extra_tags` is truthy but not a dictionary.
    """
    tags = {MLFLOW_AUTOLOGGING: autologging_integration}

    # Nothing to merge: only the reserved tag is returned.
    if not extra_tags:
        return tags

    if not isinstance(extra_tags, dict):
        raise mlflow.exceptions.MlflowException.invalid_parameter_value(
            f"Invalid `extra_tags` type: expecting dictionary, "
            f"received `{type(extra_tags).__name__}`"
        )

    # The reserved tag cannot be overridden by the user.
    if MLFLOW_AUTOLOGGING in extra_tags:
        extra_tags.pop(MLFLOW_AUTOLOGGING)
        _logger.warning(
            f"Tag `{MLFLOW_AUTOLOGGING}` is ignored as it is a reserved tag by MLflow "
            f"autologging."
        )

    tags.update(extra_tags)
    return tags
|
228
|
+
|
229
|
+
|
230
|
+
def safe_patch(
    autologging_integration,
    destination,
    function_name,
    patch_function,
    manage_run=False,
    extra_tags=None,
):
    """Patches the specified `function_name` on the specified `destination` class for autologging
    purposes, preceding its implementation with an error-safe copy of the specified patch
    `patch_function` with the following error handling behavior:
        - Exceptions thrown from the underlying / original function
          (`<destination>.<function_name>`) are propagated to the caller.
        - Exceptions thrown from other parts of the patched implementation (`patch_function`)
          are caught and logged as warnings.

    The value returned to the caller of the patched function is always the result of the
    underlying / original function: the return value of `patch_function` itself is discarded,
    and if `patch_function` never invoked the original, the wrapper invokes it afterwards.

    Args:
        autologging_integration: The name of the autologging integration associated with the
            patch.
        destination: The Python class on which the patch is being defined.
        function_name: The name of the function to patch on the specified `destination` class.
        patch_function: The patched function code to apply. The first argument should be reserved
            for an `original` argument representing the underlying / original function. Subsequent
            arguments should be identical to those of the original function being patched.
            If it is a coroutine function, the async wrapper is installed instead of the
            synchronous one.
        manage_run: If `True`, applies the `with_managed_run` wrapper to the specified
            `patch_function`, which automatically creates & terminates an MLflow
            active run during patch code execution if necessary. If `False`,
            does not apply the `with_managed_run` wrapper to the specified
            `patch_function`.
        extra_tags: A dictionary of extra tags to set on each managed run created by autologging.
            Only consulted when `manage_run` is `True`.

    Raises:
        MlflowException: If `manage_run` is `True` for an async `patch_function`, or if an
            async `patch_function` targets a property.
        RuntimeError: If the attribute obtained through the descriptor protocol differs from
            the raw attribute.
    """
    from mlflow.tracking.fluent import active_run
    from mlflow.utils.autologging_utils import autologging_is_disabled, get_autologging_config

    # NB: Checking the signature of the patch function rather than original, so that we don't
    # accidentally change the behavior of existing patches that may use sync patch function
    # for async original functions (e.g. LangChain).
    is_async_function = inspect.iscoroutinefunction(patch_function)

    if manage_run:
        if is_async_function:
            raise MlflowException("manage_run parameter is not supported for async functions.")

        tags = _resolve_extra_tags(autologging_integration, extra_tags)
        patch_function = with_managed_run(
            autologging_integration,
            patch_function,
            tags=tags,
        )

    original_fn = gorilla.get_original_attribute(
        destination, function_name, bypass_descriptor_protocol=False
    )
    # Retrieve raw attribute while bypassing the descriptor protocol
    raw_original_obj = gorilla.get_original_attribute(
        destination, function_name, bypass_descriptor_protocol=True
    )
    # NOTE(review): a mismatch presumably means the attribute is a descriptor that transforms
    # on class access (e.g. classmethod / staticmethod) -- confirm against `gorilla`.
    if original_fn != raw_original_obj:
        raise RuntimeError(f"Unsupported patch on {destination}.{function_name}")
    elif isinstance(original_fn, property):
        if is_async_function:
            raise MlflowException("Patching async property methods is not supported.")

        is_property_method = True

        # For property decorated methods (a kind of method delegation), e.g.
        # class A:
        #   @property
        #   def f1(self):
        #     ...
        #     return delegated_f1
        #
        # suppose `a1` is an instance of class `A`,
        # `A.f1.fget` will get the original `def f1(self)` method,
        # and `A.f1.fget(a1)` will be equivalent to `a1.f1()` and
        # its return value will be the `delegated_f1` function.
        # So using the `property.fget` we can construct the (delegated) "original_fn"
        def original(self, *args, **kwargs):
            # the `original_fn.fget` will get the original method decorated by `property`
            # the `original_fn.fget(self)` will get the delegated function returned by the
            # property decorated method.
            bound_delegate_method = original_fn.fget(self)
            return bound_delegate_method(*args, **kwargs)

    else:
        original = original_fn
        is_property_method = False

    def safe_patch_function(*args, **kwargs):
        """
        A safe wrapper around the specified `patch_function` implementation designed to
        handle exceptions thrown during the execution of `patch_function`. This wrapper
        distinguishes exceptions thrown from the underlying / original function
        (`<destination>.<function_name>`) from exceptions thrown from other parts of
        `patch_function`. This distinction is made by passing an augmented version of the
        underlying / original function to `patch_function` that uses nonlocal state to track
        whether or not it has been executed and whether or not it threw an exception.
        Exceptions thrown from the underlying / original function are propagated to the caller,
        while exceptions thrown from other parts of `patch_function` are caught and logged as
        warnings.

        NB: PLEASE BE SUPER CAREFUL WHEN MODIFYING THIS FUNCTION. IT IS USED IN A WIDE VARIETY
        OF CONTEXTS AND CRITICAL PATH IN DBR/MLR BY DEFAULT. ANY BUG HERE CAN BREAK USERS'
        WORKLOAD WITHOUT THEM TAKING ANY ACTION.
        """
        # Reroute warnings encountered during the patch function implementation to an MLflow event
        # logger, and enforce silent mode if applicable (i.e. if the corresponding autologging
        # integration was called with `silent=True`), hiding MLflow event logging statements and
        # hiding all warnings in the autologging preamble and postamble (i.e. the code surrounding
        # the user's original / underlying ML function). Non-MLflow warnings are enabled during the
        # execution of the original / underlying ML function
        #
        # Note that we've opted *not* to apply this context manager as a decorator on
        # `safe_patch_function` because the context-manager-as-decorator pattern uses
        # `contextlib.ContextDecorator`, which creates generator expressions that cannot be pickled
        # during model serialization by ML frameworks such as scikit-learn
        is_silent_mode = get_autologging_config(autologging_integration, "silent", False)
        with (
            MlflowEventsAndWarningsBehaviorGlobally(
                # MLflow warnings emitted during autologging training sessions are likely not
                # actionable and result from the autologging implementation invoking another MLflow
                # API. Accordingly, we reroute these warnings to the MLflow event logger with level
                # WARNING For reference, see recommended warning and event logging behaviors from
                # https://docs.python.org/3/howto/logging.html#when-to-use-logging
                reroute_warnings=True,
                disable_event_logs=is_silent_mode,
                disable_warnings=is_silent_mode,
            ),
            NonMlflowWarningsBehaviorForCurrentThread(
                # non-MLflow Warnings emitted during the autologging preamble (before the original /
                # underlying ML function is called) and postamble (after the original / underlying
                # ML function is called) are likely not actionable and result from the autologging
                # implementation invoking an API from a dependent library. Accordingly, we reroute
                # these warnings to the MLflow event logger with level WARNING. For reference, see
                # recommended warning and event logging behaviors from
                # https://docs.python.org/3/howto/logging.html#when-to-use-logging
                reroute_warnings=True,
                disable_warnings=is_silent_mode,
            ),
        ):
            # Only bound in test mode; it is read again further below under `is_testing()`.
            if is_testing():
                preexisting_run_for_testing = active_run()

            # Whether or not to exclude autologged content from user-created fluent runs
            # (i.e. runs created manually via `mlflow.start_run()`)
            exclusive = get_autologging_config(autologging_integration, "exclusive", False)
            user_created_fluent_run_is_active = (
                active_run() and not _AutologgingSessionManager.active_session()
            )
            # An already-failed enclosing session causes this (nested) patch to be skipped.
            active_session_failed = (
                _AutologgingSessionManager.active_session() is not None
                and _AutologgingSessionManager.active_session().state == "failed"
            )

            if (
                active_session_failed
                or autologging_is_disabled(autologging_integration)
                or (user_created_fluent_run_is_active and exclusive)
                or (
                    mlflow.utils.autologging_utils._AUTOLOGGING_GLOBALLY_DISABLED
                    and autologging_integration
                )
            ):
                # If the autologging integration associated with this patch is disabled,
                # or if the current autologging integration is in exclusive mode and a user-created
                # fluent run is active, call the original function and return. Restore the original
                # warning behavior during original function execution, since autologging is being
                # skipped
                with NonMlflowWarningsBehaviorForCurrentThread(
                    disable_warnings=False,
                    reroute_warnings=False,
                ):
                    return original(*args, **kwargs)

            # Whether or not the original / underlying function has been called during the
            # execution of patched code
            original_has_been_called = False
            # The value returned by the call to the original / underlying function during
            # the execution of patched code
            original_result = None
            # Whether or not an exception was raised from within the original / underlying function
            # during the execution of patched code
            failed_during_original = False
            # The active MLflow run (if any) associated with patch code execution
            patch_function_run_for_testing = None
            # The exception raised during executing patching function
            patch_error = None

            with _AutologgingSessionManager.start_session(autologging_integration) as session:
                event_logger = AutologgingEventLoggerWrapper(session, destination, function_name)

                def call_original_fn_with_event_logging(original_fn, og_args, og_kwargs):
                    # Runs `original_fn` bracketed by start / success / error event logs and
                    # flags `failed_during_original` when it raises.
                    try:
                        event_logger.log_original_function_start(og_args, og_kwargs)

                        original_fn_result = original_fn(*og_args, **og_kwargs)

                        event_logger.log_original_function_success(og_args, og_kwargs)
                        return original_fn_result
                    except Exception as e:
                        event_logger.log_original_function_error(og_args, og_kwargs, e)

                        nonlocal failed_during_original
                        failed_during_original = True
                        raise

                try:

                    def call_original(*og_args, **og_kwargs):
                        def _original_fn(*_og_args, **_og_kwargs):
                            if is_testing():
                                _validate_args(
                                    autologging_integration,
                                    function_name,
                                    args,
                                    kwargs,
                                    og_args,
                                    og_kwargs,
                                )
                                # By the time `original` is called by the patch implementation, we
                                # assume that either: 1. the patch implementation has already
                                # created an MLflow run or 2. the patch code will not create an
                                # MLflow run during the current execution. Here, we capture a
                                # reference to the active run, which we will use later on to
                                # determine whether or not the patch implementation created
                                # a run and perform validation if necessary
                                nonlocal patch_function_run_for_testing
                                patch_function_run_for_testing = active_run()

                            nonlocal original_has_been_called
                            original_has_been_called = True

                            nonlocal original_result
                            # Show all non-MLflow warnings as normal (i.e. not as event logs)
                            # during original function execution, even if silent mode is enabled
                            # (`silent=True`), since these warnings originate from the ML framework
                            # or one of its dependencies and are likely relevant to the caller
                            with NonMlflowWarningsBehaviorForCurrentThread(
                                disable_warnings=False,
                                reroute_warnings=False,
                            ):
                                original_result = original(*_og_args, **_og_kwargs)
                                return original_result

                        return call_original_fn_with_event_logging(_original_fn, og_args, og_kwargs)

                    # Apply the name, docstring, and signature of `original` to `call_original`.
                    # This is important because several autologging patch implementations inspect
                    # the signature of the `original` argument during execution
                    call_original = update_wrapper_extended(call_original, original)

                    event_logger.log_patch_function_start(args, kwargs)

                    # The patch's own return value is intentionally ignored; the caller receives
                    # the original function's result (see the final `try` below).
                    patch_function(call_original, *args, **kwargs)

                    session.state = "succeeded"
                    event_logger.log_patch_function_success(args, kwargs)

                except Exception as e:
                    session.state = "failed"
                    patch_error = e
                    # Exceptions thrown during execution of the original function should be
                    # propagated to the caller. Additionally, exceptions encountered during test
                    # mode should be reraised to detect bugs in autologging implementations
                    if failed_during_original or is_testing():
                        raise

                if is_testing() and not preexisting_run_for_testing:
                    # If an MLflow run was created during the execution of patch code, verify that
                    # it is no longer active and that it contains expected autologging tags
                    assert not active_run(), (
                        f"Autologging integration {autologging_integration} leaked an active run"
                    )
                    if patch_function_run_for_testing:
                        _validate_autologging_run(
                            autologging_integration, patch_function_run_for_testing.info.run_id
                        )
                try:
                    if original_has_been_called:
                        return original_result
                    else:
                        # The patch failed (or returned) before invoking the original, so invoke
                        # it here to preserve the caller's expected behavior.
                        return call_original_fn_with_event_logging(original, args, kwargs)
                finally:
                    # If original function succeeds, but `patch_error` is set,
                    # it represents an unexpected failure in the patching code, so we call
                    # `log_patch_function_error` in this case.
                    # If original function failed, we don't call `log_patch_function_error`
                    # even if `patch_error` is set, because original function failure
                    # means there's some error in user code (e.g. user provide wrong arguments)
                    if patch_error is not None and not failed_during_original:
                        event_logger.log_patch_function_error(args, kwargs, patch_error)
                        _logger.warning(_ERROR_MSG.format(autologging_integration, patch_error))

    async def async_safe_patch_function(*args, **kwargs):
        """
        Async version of safe_patch_function.

        This code brainlessly copies the synchronous version of the function, but with async
        context managers and async functions. This is done to avoid the risk of introducing
        any bugs or regressions in the async version of the function. Note that we need to
        be really careful here, because autologging is enabled by-default in DBR/MLR, hence
        any bug here can break users' workload without them taking any action.

        That said, some long comments are omitted in this version to avoid redundancy. If
        you want to understand the context of the code better, please refer to the
        synchronous version as well.
        """
        is_silent_mode = get_autologging_config(autologging_integration, "silent", False)
        async with (
            MlflowEventsAndWarningsBehaviorGlobally(
                reroute_warnings=True,
                disable_event_logs=is_silent_mode,
                disable_warnings=is_silent_mode,
            ),
            NonMlflowWarningsBehaviorForCurrentThread(
                disable_warnings=is_silent_mode,
                reroute_warnings=True,
            ),
        ):
            # Only bound in test mode; it is read again further below under `is_testing()`.
            if is_testing():
                preexisting_run_for_testing = active_run()

            # Whether or not to exclude autologged content from user-created fluent runs
            # (i.e. runs created manually via `mlflow.start_run()`)
            exclusive = get_autologging_config(autologging_integration, "exclusive", False)
            user_created_fluent_run_is_active = (
                active_run() and not _AutologgingSessionManager.active_session()
            )
            active_session_failed = (
                _AutologgingSessionManager.active_session() is not None
                and _AutologgingSessionManager.active_session().state == "failed"
            )

            # Same skip conditions as the synchronous wrapper: autologging is bypassed and
            # the original coroutine is awaited with default warning behavior.
            if (
                active_session_failed
                or autologging_is_disabled(autologging_integration)
                or (user_created_fluent_run_is_active and exclusive)
                or (
                    mlflow.utils.autologging_utils._AUTOLOGGING_GLOBALLY_DISABLED
                    and autologging_integration
                )
            ):
                async with NonMlflowWarningsBehaviorForCurrentThread(False, False):
                    return await original(*args, **kwargs)

            # Nonlocal bookkeeping shared with the nested helpers (see the synchronous
            # version for the meaning of each flag).
            original_has_been_called = False
            original_result = None
            failed_during_original = False
            patch_function_run_for_testing = None
            patch_error = None

            async with _AutologgingSessionManager.astart_session(
                autologging_integration
            ) as session:
                event_logger = AutologgingEventLoggerWrapper(session, destination, function_name)

                async def call_original_fn_with_event_logging(original_fn, og_args, og_kwargs):
                    try:
                        event_logger.log_original_function_start(og_args, og_kwargs)
                        original_fn_result = await original_fn(*og_args, **og_kwargs)
                        event_logger.log_original_function_success(og_args, og_kwargs)
                        return original_fn_result
                    except Exception as e:
                        event_logger.log_original_function_error(og_args, og_kwargs, e)
                        nonlocal failed_during_original
                        failed_during_original = True
                        raise

                try:

                    async def call_original(*og_args, **og_kwargs):
                        async def _original_fn(*_og_args, **_og_kwargs):
                            if is_testing():
                                _validate_args(
                                    autologging_integration,
                                    function_name,
                                    args,
                                    kwargs,
                                    og_args,
                                    og_kwargs,
                                )
                                nonlocal patch_function_run_for_testing
                                patch_function_run_for_testing = active_run()

                            nonlocal original_has_been_called
                            original_has_been_called = True

                            nonlocal original_result
                            async with NonMlflowWarningsBehaviorForCurrentThread(False, False):
                                original_result = await original(*_og_args, **_og_kwargs)
                                return original_result

                        return await call_original_fn_with_event_logging(
                            _original_fn, og_args, og_kwargs
                        )

                    # Apply the name, docstring, and signature of `original` to `call_original`.
                    # This is important because several autologging patch implementations inspect
                    # the signature of the `original` argument during execution
                    call_original = update_wrapper_extended(call_original, original)

                    event_logger.log_patch_function_start(args, kwargs)

                    # The patch's own return value is intentionally ignored.
                    await patch_function(call_original, *args, **kwargs)

                    session.state = "succeeded"
                    event_logger.log_patch_function_success(args, kwargs)

                except Exception as e:
                    session.state = "failed"
                    patch_error = e
                    # Exceptions thrown during execution of the original function should be
                    # propagated to the caller. Additionally, exceptions encountered during test
                    # mode should be reraised to detect bugs in autologging implementations
                    if failed_during_original or is_testing():
                        raise

                if is_testing() and not preexisting_run_for_testing:
                    # If an MLflow run was created during the execution of patch code, verify that
                    # it is no longer active and that it contains expected autologging tags
                    assert not active_run(), (
                        f"Autologging integration {autologging_integration} leaked an active run"
                    )
                    if patch_function_run_for_testing:
                        _validate_autologging_run(
                            autologging_integration, patch_function_run_for_testing.info.run_id
                        )
                try:
                    if original_has_been_called:
                        return original_result
                    else:
                        return await call_original_fn_with_event_logging(original, args, kwargs)
                finally:
                    # Report a patch-code failure only when the original function itself
                    # did not fail.
                    if patch_error is not None and not failed_during_original:
                        event_logger.log_patch_function_error(args, kwargs, patch_error)
                        _logger.warning(_ERROR_MSG.format(autologging_integration, patch_error))

    if is_property_method:
        # Create a patched function (also property decorated)
        # like:
        #
        # class A:
        # @property
        # def get_bound_safe_patch_fn(self):
        #   original_fn.fget(self) # do availability check
        #   return bound_safe_patch_fn
        #
        # Suppose `a1` is instance of class A,
        # then `a1.get_bound_safe_patch_fn(*args, **kwargs)` will be equivalent to
        # `bound_safe_patch_fn(*args, **kwargs)`
        def get_bound_safe_patch_fn(self):
            # This `original_fn.fget` call is for availability check, if it raise error
            # then `hasattr(obj, {func_name})` will return False
            # so it mimic the original property behavior.
            original_fn.fget(self)

            def bound_safe_patch_fn(*args, **kwargs):
                return safe_patch_function(self, *args, **kwargs)

            # Make bound method `instance.target_method` keep the same doc and signature.
            # Here return the bound safe patch function because user call property decorated
            # method will like `instance.property_decorated_method(...)`, and internally it will
            # call the `bound_safe_patch_fn`, the argument list don't include the `self` argument,
            # so return bound function here.
            return update_wrapper_extended(bound_safe_patch_fn, original_fn.fget)

        # Make unbound method `class.target_method` keep the same doc and signature
        get_bound_safe_patch_fn = update_wrapper_extended(get_bound_safe_patch_fn, original_fn.fget)
        safe_patch_obj = property(get_bound_safe_patch_fn)
    elif is_async_function:
        safe_patch_obj = update_wrapper_extended(async_safe_patch_function, original)
    else:
        safe_patch_obj = update_wrapper_extended(safe_patch_function, original)

    # Install the wrapper on `destination` and remember it under the integration name.
    new_patch = _wrap_patch(destination, function_name, safe_patch_obj)
    _store_patch(autologging_integration, new_patch)
|
706
|
+
|
707
|
+
|
708
|
+
def revert_patches(autologging_integration):
    """Undo every patch recorded for an autologging integration, then forget them.

    Args:
        autologging_integration: The name of the autologging integration whose stored
            patches should be reverted. If nothing is stored under this name, the call
            does nothing.

    NOTE(review): the earlier docstring stated that a call made via the fluent API
    (`autologging_integration="mlflow"`) reverts the patches of all active integrations.
    This function only looks up the single key it is given, so that fan-out presumably
    happens in the caller -- confirm.
    """
    registered_patches = _AUTOLOGGING_PATCHES.get(autologging_integration, [])
    for registered_patch in registered_patches:
        gorilla.revert(registered_patch)

    # Drop the bookkeeping entry only after every stored patch has been reverted.
    _AUTOLOGGING_PATCHES.pop(autologging_integration, None)
|
721
|
+
|
722
|
+
|
723
|
+
# Represents an active autologging session using three fields:
|
724
|
+
# - integration: the name of the autologging integration corresponding to the session
|
725
|
+
# - id: a unique session identifier (e.g., a UUID)
|
726
|
+
# - state: the state of AutologgingSession, will be one of running/succeeded/failed
|
727
|
+
class AutologgingSession:
    """Plain record describing one autologging session.

    Attributes:
        integration: The name of the autologging integration the session belongs to.
        id: The identifier supplied for the session by its creator.
        state: The session state. It is initialised to "running"; the accompanying
            file comment lists running/succeeded/failed as the possible values.
    """

    def __init__(self, integration, id_):
        self.integration, self.id = integration, id_
        # Every session starts out in the "running" state.
        self.state = "running"
|
732
|
+
|
733
|
+
|
734
|
+
class _AutologgingSessionManager:
    """Keeps track of the single active autologging session in class-level state.

    A session is created by the outermost `start_session` / `astart_session` context and
    is shared with any context entered while it is active. Only the context that created
    the session clears it on exit.
    """

    # The currently active session, or None when no session is active.
    _session = None

    @classmethod
    @contextmanager
    def start_session(cls, integration):
        """Enter a session context, creating a new session only if none is active.

        Args:
            integration: The integration name recorded on a newly created session. It is
                not used when a session is already active.

        Yields:
            The active session: either the pre-existing one or the one created here.
        """
        try:
            created_here = cls._session is None
            if created_here:
                cls._session = AutologgingSession(integration, uuid.uuid4().hex)
            yield cls._session
        finally:
            # A session inherited from an enclosing context is left open so that its
            # creator can end it.
            if created_here:
                cls._end_session()

    @classmethod
    @asynccontextmanager
    async def astart_session(cls, integration):
        """Asynchronous counterpart of `start_session` with the same session semantics.

        Args:
            integration: The integration name recorded on a newly created session. It is
                not used when a session is already active.

        Yields:
            The active session: either the pre-existing one or the one created here.
        """
        try:
            created_here = cls._session is None
            if created_here:
                cls._session = AutologgingSession(integration, uuid.uuid4().hex)
            yield cls._session
        finally:
            # Same ownership rule as the synchronous variant.
            if created_here:
                cls._end_session()

    @classmethod
    def active_session(cls):
        """Return the currently active session, or None if there is none."""
        return cls._session

    @classmethod
    def _end_session(cls):
        """Clear the active session."""
        cls._session = None
|
773
|
+
|
774
|
+
|
775
|
+
def update_wrapper_extended(wrapper, wrapped):
    """Make `wrapper` present itself as `wrapped`, including its call signature.

    This extends the standard `functools` wrapper update: in addition to the metadata
    that `functools` copies (docstring, name, ...), the signature of `wrapped` is
    recorded on the wrapper as `__signature__`.

    Args:
        wrapper: The function to update.
        wrapped: The function whose metadata and signature are applied to `wrapper`.

    Returns:
        `wrapper` itself, updated in place; no copy is made.
    """
    decorated = functools.wraps(wrapped)(wrapper)
    try:
        # Some frameworks disallow signature inspection, in which case
        # `inspect.signature()` raises. The file's own example is
        # `tensorflow.estimator.Estimator.export_savedmodel()`. The wrapper is then
        # returned without an explicit signature.
        decorated.__signature__ = inspect.signature(wrapped)
    except Exception:
        _logger.debug("Failed to restore original signature for wrapper around %s", wrapped)
    return decorated
|
793
|
+
|
794
|
+
|
795
|
+
def _wrap_patch(destination, name, patch_obj, settings=None):
    """Build a gorilla patch for `{destination}.{name}`, apply it and return it.

    Args:
        destination: Patch destination.
        name: Name of the attribute at the destination.
        patch_obj: Patch object; it should be a function or a property decorated function
            to be assigned to the patch point {destination}.{name}.
        settings: Settings for gorilla.Patch. When omitted,
            `gorilla.Settings(allow_hit=True, store_hit=True)` is used.

    Returns:
        The applied `gorilla.Patch` object.
    """
    effective_settings = (
        gorilla.Settings(allow_hit=True, store_hit=True) if settings is None else settings
    )
    applied_patch = gorilla.Patch(destination, name, patch_obj, settings=effective_settings)
    gorilla.apply(applied_patch)
    return applied_patch
|
812
|
+
|
813
|
+
|
814
|
+
def _store_patch(autologging_integration, patch):
    """
    Records a patch under the given autologging integration so that it can be reverted
    later, when autologging is disabled.

    Args:
        autologging_integration: The name of the autologging integration associated with the
            patch.
        patch: The patch to be stored.
    """
    # The first patch for an integration creates its set; later ones join it.
    _AUTOLOGGING_PATCHES.setdefault(autologging_integration, set()).add(patch)
|
828
|
+
|
829
|
+
|
830
|
+
def _validate_autologging_run(autologging_integration, run_id):
    """
    Test-only check that the MLflow run with id `run_id`, produced by
    `autologging_integration`, looks the way an autologging run should:

    - its autologging tag holds the name of the autologging integration
    - its status is terminal (e.g., KILLED, FAILED, FINISHED)

    Args:
        autologging_integration: The name of the autologging integration expected in the
            run's autologging tag.
        run_id: The id of the run to fetch and check.

    Raises:
        AssertionError: If either property does not hold.
    """
    from mlflow.tracking.client import MlflowClient

    fetched_run = MlflowClient().get_run(run_id)

    actual_tag = fetched_run.data.tags.get(MLFLOW_AUTOLOGGING)
    assert actual_tag == autologging_integration, (
        f"Autologging run with id {run_id} failed to set autologging tag with expected value. "
        f"Expected: '{autologging_integration}', Actual: '{actual_tag}'"
    )

    run_status = fetched_run.info.status
    assert RunStatus.is_terminated(RunStatus.from_string(run_status)), (
        f"Autologging run with id {run_id} has a non-terminal status '{run_status}'"
    )
|
850
|
+
|
851
|
+
|
852
|
+
class ValidationExemptArgument(NamedTuple):
    """
    A NamedTuple describing one argument that is exempt from validation.

    Fields:
        autologging_integration: The name of the autologging integration.
        function_name: The name of the function that is being validated.
        type_function: A Callable that accepts an object and returns True if the given
            object matches the argument type. Returns False otherwise.
        positional_argument_index: The index of the argument in the function signature,
            or None if the exemption does not apply to a positional location.
        keyword_argument_name: The name of the argument in the function signature, or
            None if the exemption does not apply to a keyword location.
    """

    autologging_integration: str
    function_name: str
    type_function: Callable[..., Any]
    positional_argument_index: Optional[int] = None
    keyword_argument_name: Optional[str] = None

    def matches(
        self,
        autologging_integration,
        function_name,
        value,
        argument_index=None,
        argument_name=None,
    ):
        """
        Checks whether the described call-site argument is covered by this exemption.

        Args:
            autologging_integration: The name of an autologging integration.
            function_name: The name of the function that is being matched.
            value: The value of the argument.
            argument_index: The index of the argument, if it is passed as a positional
                argument. Otherwise it is None.
            argument_name: The name of the argument, if it is passed as a keyword
                argument. Otherwise it is None.

        Returns:
            False if the integration, the function name or the argument location does not
            match. Otherwise the result of `type_function(value)`, which is only invoked
            once everything else has matched.
        """
        if (
            self.autologging_integration != autologging_integration
            or self.function_name != function_name
        ):
            return False

        # A location matches only when the caller actually supplied it. Comparing the
        # raw values would treat "exemption has no index" and "argument was passed by
        # keyword" (both None) as equal, so an exemption that sets only one location
        # would match every argument passed the other way.
        index_matches = (
            argument_index is not None and self.positional_argument_index == argument_index
        )
        name_matches = (
            argument_name is not None and self.keyword_argument_name == argument_name
        )
        if not (index_matches or name_matches):
            return False

        return self.type_function(value)
|
904
|
+
|
905
|
+
|
906
|
+
# WARNING: Exemptions should NOT be introduced unless absolutely necessary. If deemed necessary,
|
907
|
+
# clear reasons must be provided as comment in addition to thorough integration tests.
|
908
|
+
_VALIDATION_EXEMPT_ARGUMENTS = [
    # `fit(x=...)` for tensorflow and keras: when `x` is a generator or an instance of a
    # generator class, the implicitly defined `batch_size` is extracted by consuming the
    # first element and then restoring it to the generator. As a consequence:
    #   1. The type of `x` becomes 'generator', whether the user supplied a generator or a
    #      custom generator class.
    #   2. The instance of `x` differs from the one the user passed, because the generator
    #      is reconstructed after its first element has been consumed.
    ValidationExemptArgument(
        autologging_integration="tensorflow",
        function_name="fit",
        type_function=is_iterator,
        positional_argument_index=1,
        keyword_argument_name="x",
    ),
    ValidationExemptArgument(
        autologging_integration="keras",
        function_name="fit",
        type_function=is_iterator,
        positional_argument_index=1,
        keyword_argument_name="x",
    ),
]
|
919
|
+
|
920
|
+
|
921
|
+
def _is_arg_exempt_from_validation(
    autologging_integration,
    function_name,
    argument,
    argument_index=None,
    argument_name=None,
):
    """Decide whether an argument is exempt from autolog safety validations, which cover
    both type checking and immutable checking.

    The argument is exempt as soon as one entry of `_VALIDATION_EXEMPT_ARGUMENTS` matches it.

    Args:
        autologging_integration: The name of the autologging integration.
        function_name: The name of the function that is being validated.
        argument: The actual argument.
        argument_index: The index of the argument, if it is passed as a positional
            argument. Otherwise it is None.
        argument_name: The name of the argument, if it is passed as a keyword argument.
            Otherwise it is None.

    Returns:
        True or False
    """
    for exemption in _VALIDATION_EXEMPT_ARGUMENTS:
        is_match = exemption.matches(
            autologging_integration,
            function_name,
            argument,
            argument_index,
            argument_name,
        )
        if is_match:
            return True
    return False
|
953
|
+
|
954
|
+
|
955
|
+
def _validate_args(
    autologging_integration,
    function_name,
    user_call_args,
    user_call_kwargs,
    autologging_call_args,
    autologging_call_kwargs,
):
    """
    Used for testing purposes to verify that, when a patched ML function calls its underlying
    / original ML function, the following properties are satisfied:

    - All arguments supplied to the patched ML function are forwarded to the
      original ML function
    - Any additional arguments supplied to the original function are exception safe (i.e.
      they are either functions decorated with the `@exception_safe_function_for_class` or
      `@pickalable_exception_safe_function` decorators, or classes / instances of classes with
      type `ExceptionSafeClass`

    Arguments covered by an entry in the validation exemption list (see
    `_is_arg_exempt_from_validation`) are skipped.

    Args:
        autologging_integration: The name of the autologging integration; used to look up
            validation exemptions.
        function_name: The name of the patched function; used to look up validation
            exemptions.
        user_call_args: Positional arguments the user passed to the patched function.
        user_call_kwargs: Keyword arguments the user passed to the patched function.
        autologging_call_args: Positional arguments the patch passed to the original function.
        autologging_call_kwargs: Keyword arguments the patch passed to the original function.

    Raises:
        AssertionError: If any of the properties above is violated.
    """

    def _validate_new_input(inp):
        """
        Validates a new input (arg or kwarg) introduced to the underlying / original ML function
        call during the execution of a patched ML function. The new input is valid if:

        - The new input is a function that has been decorated with
          `exception_safe_function_for_class` or `pickalable_exception_safe_function`
        - OR the new input is a class with the `ExceptionSafeClass` metaclass
        - OR the new input is a list and each of its elements is valid according to
          these criteria
        - OR the new input is a dict with a "callbacks" entry and that entry is valid
          according to these criteria
        """
        if type(inp) == list:
            for item in inp:
                _validate_new_input(item)
        elif isinstance(inp, dict) and "callbacks" in inp:
            # Only the "callbacks" entry of such a dict is inspected.
            _validate_new_input(inp["callbacks"])
        elif callable(inp):
            assert getattr(inp, _ATTRIBUTE_EXCEPTION_SAFE, False), (
                f"New function argument '{inp}' passed to original function is not exception-safe."
                " Please decorate the function with `exception_safe_function` or "
                "`pickalable_exception_safe_function`"
            )
        else:
            assert hasattr(inp, "__class__") and type(inp.__class__) in [
                ExceptionSafeClass,
                ExceptionSafeAbstractClass,
            ], (
                f"Invalid new input '{inp}'. New args / kwargs introduced to `original` function "
                "calls by patched code must either be functions decorated with "
                "`exception_safe_function_for_class`, instances of classes with the "
                "`ExceptionSafeClass` or `ExceptionSafeAbstractClass` metaclass safe or lists of "
                "such exception safe functions / classes."
            )

    def _assert_autologging_input_positional_args_are_superset(
        autologging_call_input, user_call_input
    ):
        """Asserts that autologging forwarded at least as many positional inputs as the user."""
        # Count the missing inputs as a positive number so the failure message reads
        # naturally (e.g. "2 expected inputs are missing ...") instead of "-2 ...".
        missing_count = len(user_call_input) - len(autologging_call_input)
        assert missing_count <= 0, (
            f"{missing_count} expected inputs are missing from the call to the original function."
        )

    def _assert_autologging_input_kwargs_are_superset(autologging_call_input, user_call_input):
        """Asserts that every key supplied by the user is present in the autologging input."""
        assert set(user_call_input.keys()).issubset(set(autologging_call_input.keys())), (
            "Keyword or dictionary arguments to original function omit"
            " one or more expected keys: '{}'".format(
                set(user_call_input.keys()) - set(autologging_call_input.keys())
            )
        )

    def _validate(autologging_call_input, user_call_input=None):
        """
        Validates that the specified `autologging_call_input` and `user_call_input`
        are compatible. If `user_call_input` is `None`, then `autologging_call_input`
        is regarded as a new input added by autologging and is validated using
        `_validate_new_input`. Otherwise, the following properties must hold:

        - `autologging_call_input` and `user_call_input` must have the same type
          (referred to as "input type")
        - if the input type is a tuple, list or dictionary, then `autologging_call_input` must
          be equivalent to `user_call_input` or be a superset of `user_call_input`
        - for all other input types, `autologging_call_input` and `user_call_input`
          must be equivalent by reference equality or by object equality

        Args:
            autologging_call_input: call input from autologging.
            user_call_input: call input from user.
        """

        if user_call_input is None and autologging_call_input is not None:
            _validate_new_input(autologging_call_input)
            return

        assert type(autologging_call_input) == type(user_call_input), (
            "Type of input to original function '{}' does not match expected type '{}'".format(
                type(autologging_call_input), type(user_call_input)
            )
        )

        if type(autologging_call_input) in [list, tuple]:
            _assert_autologging_input_positional_args_are_superset(
                autologging_call_input, user_call_input
            )
            # If the autologging call input is longer than the user call input,
            # `zip_longest` pads the user call input with `None` values so that the
            # recursive `_validate` calls treat the extra entries as new inputs added by
            # the autologging call.
            for a, u in itertools.zip_longest(autologging_call_input, user_call_input):
                _validate(a, u)
        elif type(autologging_call_input) == dict:
            _assert_autologging_input_kwargs_are_superset(autologging_call_input, user_call_input)
            for key in autologging_call_input.keys():
                _validate(autologging_call_input[key], user_call_input.get(key, None))

        else:
            assert (
                autologging_call_input is user_call_input
                or autologging_call_input == user_call_input
            ), (
                "Input to original function does not match expected input."
                f" Original: '{autologging_call_input}'. Expected: '{user_call_input}'"
            )

    # Same logic as the list/tuple branch of `_validate`, unrolled here so that validation
    # exempt positional arguments can be skipped. `zip_longest` pads both the index and the
    # user argument with `None` for positional arguments added by autologging.
    _assert_autologging_input_positional_args_are_superset(autologging_call_args, user_call_args)
    for index, autologging_call_arg, user_call_arg in itertools.zip_longest(
        range(len(user_call_args)), autologging_call_args, user_call_args
    ):
        if not _is_arg_exempt_from_validation(
            autologging_integration,
            function_name,
            user_call_arg,
            argument_index=index,
        ):
            _validate(autologging_call_arg, user_call_arg)

    # Same logic as the dict branch of `_validate`, unrolled here so that validation exempt
    # keyword arguments can be skipped. Keys added by autologging have no user value and
    # are therefore validated as new inputs.
    _assert_autologging_input_kwargs_are_superset(autologging_call_kwargs, user_call_kwargs)
    for key, autologging_call_kwarg in autologging_call_kwargs.items():
        user_call_kwarg = user_call_kwargs.get(key, None)
        if not _is_arg_exempt_from_validation(
            autologging_integration,
            function_name,
            user_call_kwarg,
            argument_name=key,
        ):
            _validate(autologging_call_kwarg, user_call_kwarg)
|