genesis-flow 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genesis_flow-1.0.0.dist-info/METADATA +822 -0
- genesis_flow-1.0.0.dist-info/RECORD +645 -0
- genesis_flow-1.0.0.dist-info/WHEEL +5 -0
- genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
- genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
- genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
- mlflow/__init__.py +367 -0
- mlflow/__main__.py +3 -0
- mlflow/ag2/__init__.py +56 -0
- mlflow/ag2/ag2_logger.py +294 -0
- mlflow/anthropic/__init__.py +40 -0
- mlflow/anthropic/autolog.py +129 -0
- mlflow/anthropic/chat.py +144 -0
- mlflow/artifacts/__init__.py +268 -0
- mlflow/autogen/__init__.py +144 -0
- mlflow/autogen/chat.py +142 -0
- mlflow/azure/__init__.py +26 -0
- mlflow/azure/auth_handler.py +257 -0
- mlflow/azure/client.py +319 -0
- mlflow/azure/config.py +120 -0
- mlflow/azure/connection_factory.py +340 -0
- mlflow/azure/exceptions.py +27 -0
- mlflow/azure/stores.py +327 -0
- mlflow/azure/utils.py +183 -0
- mlflow/bedrock/__init__.py +45 -0
- mlflow/bedrock/_autolog.py +202 -0
- mlflow/bedrock/chat.py +122 -0
- mlflow/bedrock/stream.py +160 -0
- mlflow/bedrock/utils.py +43 -0
- mlflow/cli.py +707 -0
- mlflow/client.py +12 -0
- mlflow/config/__init__.py +56 -0
- mlflow/crewai/__init__.py +79 -0
- mlflow/crewai/autolog.py +253 -0
- mlflow/crewai/chat.py +29 -0
- mlflow/data/__init__.py +75 -0
- mlflow/data/artifact_dataset_sources.py +170 -0
- mlflow/data/code_dataset_source.py +40 -0
- mlflow/data/dataset.py +123 -0
- mlflow/data/dataset_registry.py +168 -0
- mlflow/data/dataset_source.py +110 -0
- mlflow/data/dataset_source_registry.py +219 -0
- mlflow/data/delta_dataset_source.py +167 -0
- mlflow/data/digest_utils.py +108 -0
- mlflow/data/evaluation_dataset.py +562 -0
- mlflow/data/filesystem_dataset_source.py +81 -0
- mlflow/data/http_dataset_source.py +145 -0
- mlflow/data/huggingface_dataset.py +258 -0
- mlflow/data/huggingface_dataset_source.py +118 -0
- mlflow/data/meta_dataset.py +104 -0
- mlflow/data/numpy_dataset.py +223 -0
- mlflow/data/pandas_dataset.py +231 -0
- mlflow/data/polars_dataset.py +352 -0
- mlflow/data/pyfunc_dataset_mixin.py +31 -0
- mlflow/data/schema.py +76 -0
- mlflow/data/sources.py +1 -0
- mlflow/data/spark_dataset.py +406 -0
- mlflow/data/spark_dataset_source.py +74 -0
- mlflow/data/spark_delta_utils.py +118 -0
- mlflow/data/tensorflow_dataset.py +350 -0
- mlflow/data/uc_volume_dataset_source.py +81 -0
- mlflow/db.py +27 -0
- mlflow/dspy/__init__.py +17 -0
- mlflow/dspy/autolog.py +197 -0
- mlflow/dspy/callback.py +398 -0
- mlflow/dspy/constant.py +1 -0
- mlflow/dspy/load.py +93 -0
- mlflow/dspy/save.py +393 -0
- mlflow/dspy/util.py +109 -0
- mlflow/dspy/wrapper.py +226 -0
- mlflow/entities/__init__.py +104 -0
- mlflow/entities/_mlflow_object.py +52 -0
- mlflow/entities/assessment.py +545 -0
- mlflow/entities/assessment_error.py +80 -0
- mlflow/entities/assessment_source.py +141 -0
- mlflow/entities/dataset.py +92 -0
- mlflow/entities/dataset_input.py +51 -0
- mlflow/entities/dataset_summary.py +62 -0
- mlflow/entities/document.py +48 -0
- mlflow/entities/experiment.py +109 -0
- mlflow/entities/experiment_tag.py +35 -0
- mlflow/entities/file_info.py +45 -0
- mlflow/entities/input_tag.py +35 -0
- mlflow/entities/lifecycle_stage.py +35 -0
- mlflow/entities/logged_model.py +228 -0
- mlflow/entities/logged_model_input.py +26 -0
- mlflow/entities/logged_model_output.py +32 -0
- mlflow/entities/logged_model_parameter.py +46 -0
- mlflow/entities/logged_model_status.py +74 -0
- mlflow/entities/logged_model_tag.py +33 -0
- mlflow/entities/metric.py +200 -0
- mlflow/entities/model_registry/__init__.py +29 -0
- mlflow/entities/model_registry/_model_registry_entity.py +13 -0
- mlflow/entities/model_registry/model_version.py +243 -0
- mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
- mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
- mlflow/entities/model_registry/model_version_search.py +25 -0
- mlflow/entities/model_registry/model_version_stages.py +25 -0
- mlflow/entities/model_registry/model_version_status.py +35 -0
- mlflow/entities/model_registry/model_version_tag.py +35 -0
- mlflow/entities/model_registry/prompt.py +73 -0
- mlflow/entities/model_registry/prompt_version.py +244 -0
- mlflow/entities/model_registry/registered_model.py +175 -0
- mlflow/entities/model_registry/registered_model_alias.py +35 -0
- mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
- mlflow/entities/model_registry/registered_model_search.py +25 -0
- mlflow/entities/model_registry/registered_model_tag.py +35 -0
- mlflow/entities/multipart_upload.py +74 -0
- mlflow/entities/param.py +49 -0
- mlflow/entities/run.py +97 -0
- mlflow/entities/run_data.py +84 -0
- mlflow/entities/run_info.py +188 -0
- mlflow/entities/run_inputs.py +59 -0
- mlflow/entities/run_outputs.py +43 -0
- mlflow/entities/run_status.py +41 -0
- mlflow/entities/run_tag.py +36 -0
- mlflow/entities/source_type.py +31 -0
- mlflow/entities/span.py +774 -0
- mlflow/entities/span_event.py +96 -0
- mlflow/entities/span_status.py +102 -0
- mlflow/entities/trace.py +317 -0
- mlflow/entities/trace_data.py +71 -0
- mlflow/entities/trace_info.py +220 -0
- mlflow/entities/trace_info_v2.py +162 -0
- mlflow/entities/trace_location.py +173 -0
- mlflow/entities/trace_state.py +39 -0
- mlflow/entities/trace_status.py +68 -0
- mlflow/entities/view_type.py +51 -0
- mlflow/environment_variables.py +866 -0
- mlflow/evaluation/__init__.py +16 -0
- mlflow/evaluation/assessment.py +369 -0
- mlflow/evaluation/evaluation.py +411 -0
- mlflow/evaluation/evaluation_tag.py +61 -0
- mlflow/evaluation/fluent.py +48 -0
- mlflow/evaluation/utils.py +201 -0
- mlflow/exceptions.py +213 -0
- mlflow/experiments.py +140 -0
- mlflow/gemini/__init__.py +81 -0
- mlflow/gemini/autolog.py +186 -0
- mlflow/gemini/chat.py +261 -0
- mlflow/genai/__init__.py +71 -0
- mlflow/genai/datasets/__init__.py +67 -0
- mlflow/genai/datasets/evaluation_dataset.py +131 -0
- mlflow/genai/evaluation/__init__.py +3 -0
- mlflow/genai/evaluation/base.py +411 -0
- mlflow/genai/evaluation/constant.py +23 -0
- mlflow/genai/evaluation/utils.py +244 -0
- mlflow/genai/judges/__init__.py +21 -0
- mlflow/genai/judges/databricks.py +404 -0
- mlflow/genai/label_schemas/__init__.py +153 -0
- mlflow/genai/label_schemas/label_schemas.py +209 -0
- mlflow/genai/labeling/__init__.py +159 -0
- mlflow/genai/labeling/labeling.py +250 -0
- mlflow/genai/optimize/__init__.py +13 -0
- mlflow/genai/optimize/base.py +198 -0
- mlflow/genai/optimize/optimizers/__init__.py +4 -0
- mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
- mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
- mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
- mlflow/genai/optimize/types.py +75 -0
- mlflow/genai/optimize/util.py +30 -0
- mlflow/genai/prompts/__init__.py +206 -0
- mlflow/genai/scheduled_scorers.py +431 -0
- mlflow/genai/scorers/__init__.py +26 -0
- mlflow/genai/scorers/base.py +492 -0
- mlflow/genai/scorers/builtin_scorers.py +765 -0
- mlflow/genai/scorers/scorer_utils.py +138 -0
- mlflow/genai/scorers/validation.py +165 -0
- mlflow/genai/utils/data_validation.py +146 -0
- mlflow/genai/utils/enum_utils.py +23 -0
- mlflow/genai/utils/trace_utils.py +211 -0
- mlflow/groq/__init__.py +42 -0
- mlflow/groq/_groq_autolog.py +74 -0
- mlflow/johnsnowlabs/__init__.py +888 -0
- mlflow/langchain/__init__.py +24 -0
- mlflow/langchain/api_request_parallel_processor.py +330 -0
- mlflow/langchain/autolog.py +147 -0
- mlflow/langchain/chat_agent_langgraph.py +340 -0
- mlflow/langchain/constant.py +1 -0
- mlflow/langchain/constants.py +1 -0
- mlflow/langchain/databricks_dependencies.py +444 -0
- mlflow/langchain/langchain_tracer.py +597 -0
- mlflow/langchain/model.py +919 -0
- mlflow/langchain/output_parsers.py +142 -0
- mlflow/langchain/retriever_chain.py +153 -0
- mlflow/langchain/runnables.py +527 -0
- mlflow/langchain/utils/chat.py +402 -0
- mlflow/langchain/utils/logging.py +671 -0
- mlflow/langchain/utils/serialization.py +36 -0
- mlflow/legacy_databricks_cli/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/provider.py +482 -0
- mlflow/litellm/__init__.py +175 -0
- mlflow/llama_index/__init__.py +22 -0
- mlflow/llama_index/autolog.py +55 -0
- mlflow/llama_index/chat.py +43 -0
- mlflow/llama_index/constant.py +1 -0
- mlflow/llama_index/model.py +577 -0
- mlflow/llama_index/pyfunc_wrapper.py +332 -0
- mlflow/llama_index/serialize_objects.py +188 -0
- mlflow/llama_index/tracer.py +561 -0
- mlflow/metrics/__init__.py +479 -0
- mlflow/metrics/base.py +39 -0
- mlflow/metrics/genai/__init__.py +25 -0
- mlflow/metrics/genai/base.py +101 -0
- mlflow/metrics/genai/genai_metric.py +771 -0
- mlflow/metrics/genai/metric_definitions.py +450 -0
- mlflow/metrics/genai/model_utils.py +371 -0
- mlflow/metrics/genai/prompt_template.py +68 -0
- mlflow/metrics/genai/prompts/__init__.py +0 -0
- mlflow/metrics/genai/prompts/v1.py +422 -0
- mlflow/metrics/genai/utils.py +6 -0
- mlflow/metrics/metric_definitions.py +619 -0
- mlflow/mismatch.py +34 -0
- mlflow/mistral/__init__.py +34 -0
- mlflow/mistral/autolog.py +71 -0
- mlflow/mistral/chat.py +135 -0
- mlflow/ml_package_versions.py +452 -0
- mlflow/models/__init__.py +97 -0
- mlflow/models/auth_policy.py +83 -0
- mlflow/models/cli.py +354 -0
- mlflow/models/container/__init__.py +294 -0
- mlflow/models/container/scoring_server/__init__.py +0 -0
- mlflow/models/container/scoring_server/nginx.conf +39 -0
- mlflow/models/dependencies_schemas.py +287 -0
- mlflow/models/display_utils.py +158 -0
- mlflow/models/docker_utils.py +211 -0
- mlflow/models/evaluation/__init__.py +23 -0
- mlflow/models/evaluation/_shap_patch.py +64 -0
- mlflow/models/evaluation/artifacts.py +194 -0
- mlflow/models/evaluation/base.py +1811 -0
- mlflow/models/evaluation/calibration_curve.py +109 -0
- mlflow/models/evaluation/default_evaluator.py +996 -0
- mlflow/models/evaluation/deprecated.py +23 -0
- mlflow/models/evaluation/evaluator_registry.py +80 -0
- mlflow/models/evaluation/evaluators/classifier.py +704 -0
- mlflow/models/evaluation/evaluators/default.py +233 -0
- mlflow/models/evaluation/evaluators/regressor.py +96 -0
- mlflow/models/evaluation/evaluators/shap.py +296 -0
- mlflow/models/evaluation/lift_curve.py +178 -0
- mlflow/models/evaluation/utils/metric.py +123 -0
- mlflow/models/evaluation/utils/trace.py +179 -0
- mlflow/models/evaluation/validation.py +434 -0
- mlflow/models/flavor_backend.py +93 -0
- mlflow/models/flavor_backend_registry.py +53 -0
- mlflow/models/model.py +1639 -0
- mlflow/models/model_config.py +150 -0
- mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
- mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
- mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
- mlflow/models/python_api.py +369 -0
- mlflow/models/rag_signatures.py +128 -0
- mlflow/models/resources.py +321 -0
- mlflow/models/signature.py +662 -0
- mlflow/models/utils.py +2054 -0
- mlflow/models/wheeled_model.py +280 -0
- mlflow/openai/__init__.py +57 -0
- mlflow/openai/_agent_tracer.py +364 -0
- mlflow/openai/api_request_parallel_processor.py +131 -0
- mlflow/openai/autolog.py +509 -0
- mlflow/openai/constant.py +1 -0
- mlflow/openai/model.py +824 -0
- mlflow/openai/utils/chat_schema.py +367 -0
- mlflow/optuna/__init__.py +3 -0
- mlflow/optuna/storage.py +646 -0
- mlflow/plugins/__init__.py +72 -0
- mlflow/plugins/base.py +358 -0
- mlflow/plugins/builtin/__init__.py +24 -0
- mlflow/plugins/builtin/pytorch_plugin.py +150 -0
- mlflow/plugins/builtin/sklearn_plugin.py +158 -0
- mlflow/plugins/builtin/transformers_plugin.py +187 -0
- mlflow/plugins/cli.py +321 -0
- mlflow/plugins/discovery.py +340 -0
- mlflow/plugins/manager.py +465 -0
- mlflow/plugins/registry.py +316 -0
- mlflow/plugins/templates/framework_plugin_template.py +329 -0
- mlflow/prompt/constants.py +20 -0
- mlflow/prompt/promptlab_model.py +197 -0
- mlflow/prompt/registry_utils.py +248 -0
- mlflow/promptflow/__init__.py +495 -0
- mlflow/protos/__init__.py +0 -0
- mlflow/protos/assessments_pb2.py +174 -0
- mlflow/protos/databricks_artifacts_pb2.py +489 -0
- mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
- mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
- mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
- mlflow/protos/databricks_pb2.py +267 -0
- mlflow/protos/databricks_trace_server_pb2.py +374 -0
- mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
- mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
- mlflow/protos/facet_feature_statistics_pb2.py +296 -0
- mlflow/protos/internal_pb2.py +77 -0
- mlflow/protos/mlflow_artifacts_pb2.py +336 -0
- mlflow/protos/model_registry_pb2.py +1073 -0
- mlflow/protos/scalapb/__init__.py +0 -0
- mlflow/protos/scalapb/scalapb_pb2.py +104 -0
- mlflow/protos/service_pb2.py +2600 -0
- mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
- mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
- mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
- mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
- mlflow/py.typed +0 -0
- mlflow/pydantic_ai/__init__.py +57 -0
- mlflow/pydantic_ai/autolog.py +173 -0
- mlflow/pyfunc/__init__.py +3844 -0
- mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
- mlflow/pyfunc/backend.py +523 -0
- mlflow/pyfunc/context.py +78 -0
- mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
- mlflow/pyfunc/loaders/__init__.py +7 -0
- mlflow/pyfunc/loaders/chat_agent.py +117 -0
- mlflow/pyfunc/loaders/chat_model.py +125 -0
- mlflow/pyfunc/loaders/code_model.py +31 -0
- mlflow/pyfunc/loaders/responses_agent.py +112 -0
- mlflow/pyfunc/mlserver.py +46 -0
- mlflow/pyfunc/model.py +1473 -0
- mlflow/pyfunc/scoring_server/__init__.py +604 -0
- mlflow/pyfunc/scoring_server/app.py +7 -0
- mlflow/pyfunc/scoring_server/client.py +146 -0
- mlflow/pyfunc/spark_model_cache.py +48 -0
- mlflow/pyfunc/stdin_server.py +44 -0
- mlflow/pyfunc/utils/__init__.py +3 -0
- mlflow/pyfunc/utils/data_validation.py +224 -0
- mlflow/pyfunc/utils/environment.py +22 -0
- mlflow/pyfunc/utils/input_converter.py +47 -0
- mlflow/pyfunc/utils/serving_data_parser.py +11 -0
- mlflow/pytorch/__init__.py +1171 -0
- mlflow/pytorch/_lightning_autolog.py +580 -0
- mlflow/pytorch/_pytorch_autolog.py +50 -0
- mlflow/pytorch/pickle_module.py +35 -0
- mlflow/rfunc/__init__.py +42 -0
- mlflow/rfunc/backend.py +134 -0
- mlflow/runs.py +89 -0
- mlflow/server/__init__.py +302 -0
- mlflow/server/auth/__init__.py +1224 -0
- mlflow/server/auth/__main__.py +4 -0
- mlflow/server/auth/basic_auth.ini +6 -0
- mlflow/server/auth/cli.py +11 -0
- mlflow/server/auth/client.py +537 -0
- mlflow/server/auth/config.py +34 -0
- mlflow/server/auth/db/__init__.py +0 -0
- mlflow/server/auth/db/cli.py +18 -0
- mlflow/server/auth/db/migrations/__init__.py +0 -0
- mlflow/server/auth/db/migrations/alembic.ini +110 -0
- mlflow/server/auth/db/migrations/env.py +76 -0
- mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
- mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
- mlflow/server/auth/db/models.py +67 -0
- mlflow/server/auth/db/utils.py +37 -0
- mlflow/server/auth/entities.py +165 -0
- mlflow/server/auth/logo.py +14 -0
- mlflow/server/auth/permissions.py +65 -0
- mlflow/server/auth/routes.py +18 -0
- mlflow/server/auth/sqlalchemy_store.py +263 -0
- mlflow/server/graphql/__init__.py +0 -0
- mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
- mlflow/server/graphql/graphql_custom_scalars.py +24 -0
- mlflow/server/graphql/graphql_errors.py +15 -0
- mlflow/server/graphql/graphql_no_batching.py +89 -0
- mlflow/server/graphql/graphql_schema_extensions.py +74 -0
- mlflow/server/handlers.py +3217 -0
- mlflow/server/prometheus_exporter.py +17 -0
- mlflow/server/validation.py +30 -0
- mlflow/shap/__init__.py +691 -0
- mlflow/sklearn/__init__.py +1994 -0
- mlflow/sklearn/utils.py +1041 -0
- mlflow/smolagents/__init__.py +66 -0
- mlflow/smolagents/autolog.py +139 -0
- mlflow/smolagents/chat.py +29 -0
- mlflow/store/__init__.py +10 -0
- mlflow/store/_unity_catalog/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/constants.py +2 -0
- mlflow/store/_unity_catalog/registry/__init__.py +6 -0
- mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
- mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
- mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
- mlflow/store/_unity_catalog/registry/utils.py +121 -0
- mlflow/store/artifact/__init__.py +0 -0
- mlflow/store/artifact/artifact_repo.py +472 -0
- mlflow/store/artifact/artifact_repository_registry.py +154 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
- mlflow/store/artifact/cli.py +141 -0
- mlflow/store/artifact/cloud_artifact_repo.py +332 -0
- mlflow/store/artifact/databricks_artifact_repo.py +729 -0
- mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
- mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
- mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
- mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
- mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
- mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
- mlflow/store/artifact/ftp_artifact_repo.py +132 -0
- mlflow/store/artifact/gcs_artifact_repo.py +296 -0
- mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
- mlflow/store/artifact/http_artifact_repo.py +218 -0
- mlflow/store/artifact/local_artifact_repo.py +142 -0
- mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
- mlflow/store/artifact/models_artifact_repo.py +259 -0
- mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
- mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
- mlflow/store/artifact/r2_artifact_repo.py +70 -0
- mlflow/store/artifact/runs_artifact_repo.py +265 -0
- mlflow/store/artifact/s3_artifact_repo.py +330 -0
- mlflow/store/artifact/sftp_artifact_repo.py +141 -0
- mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
- mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
- mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
- mlflow/store/artifact/utils/__init__.py +0 -0
- mlflow/store/artifact/utils/models.py +148 -0
- mlflow/store/db/__init__.py +0 -0
- mlflow/store/db/base_sql_model.py +3 -0
- mlflow/store/db/db_types.py +10 -0
- mlflow/store/db/utils.py +314 -0
- mlflow/store/db_migrations/__init__.py +0 -0
- mlflow/store/db_migrations/alembic.ini +74 -0
- mlflow/store/db_migrations/env.py +84 -0
- mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
- mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
- mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
- mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
- mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
- mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
- mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
- mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
- mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
- mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
- mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
- mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
- mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
- mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
- mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
- mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
- mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
- mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
- mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
- mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
- mlflow/store/db_migrations/versions/__init__.py +0 -0
- mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
- mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
- mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
- mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
- mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
- mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
- mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
- mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
- mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
- mlflow/store/entities/__init__.py +3 -0
- mlflow/store/entities/paged_list.py +18 -0
- mlflow/store/model_registry/__init__.py +10 -0
- mlflow/store/model_registry/abstract_store.py +1081 -0
- mlflow/store/model_registry/base_rest_store.py +44 -0
- mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
- mlflow/store/model_registry/dbmodels/__init__.py +0 -0
- mlflow/store/model_registry/dbmodels/models.py +206 -0
- mlflow/store/model_registry/file_store.py +1091 -0
- mlflow/store/model_registry/rest_store.py +481 -0
- mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
- mlflow/store/tracking/__init__.py +23 -0
- mlflow/store/tracking/abstract_store.py +816 -0
- mlflow/store/tracking/dbmodels/__init__.py +0 -0
- mlflow/store/tracking/dbmodels/initial_models.py +243 -0
- mlflow/store/tracking/dbmodels/models.py +1073 -0
- mlflow/store/tracking/file_store.py +2438 -0
- mlflow/store/tracking/postgres_managed_identity.py +146 -0
- mlflow/store/tracking/rest_store.py +1131 -0
- mlflow/store/tracking/sqlalchemy_store.py +2785 -0
- mlflow/system_metrics/__init__.py +61 -0
- mlflow/system_metrics/metrics/__init__.py +0 -0
- mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
- mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
- mlflow/system_metrics/metrics/disk_monitor.py +21 -0
- mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
- mlflow/system_metrics/metrics/network_monitor.py +34 -0
- mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
- mlflow/system_metrics/system_metrics_monitor.py +198 -0
- mlflow/tracing/__init__.py +16 -0
- mlflow/tracing/assessment.py +356 -0
- mlflow/tracing/client.py +531 -0
- mlflow/tracing/config.py +125 -0
- mlflow/tracing/constant.py +105 -0
- mlflow/tracing/destination.py +81 -0
- mlflow/tracing/display/__init__.py +40 -0
- mlflow/tracing/display/display_handler.py +196 -0
- mlflow/tracing/export/async_export_queue.py +186 -0
- mlflow/tracing/export/inference_table.py +138 -0
- mlflow/tracing/export/mlflow_v3.py +137 -0
- mlflow/tracing/export/utils.py +70 -0
- mlflow/tracing/fluent.py +1417 -0
- mlflow/tracing/processor/base_mlflow.py +199 -0
- mlflow/tracing/processor/inference_table.py +175 -0
- mlflow/tracing/processor/mlflow_v3.py +47 -0
- mlflow/tracing/processor/otel.py +73 -0
- mlflow/tracing/provider.py +487 -0
- mlflow/tracing/trace_manager.py +200 -0
- mlflow/tracing/utils/__init__.py +616 -0
- mlflow/tracing/utils/artifact_utils.py +28 -0
- mlflow/tracing/utils/copy.py +55 -0
- mlflow/tracing/utils/environment.py +55 -0
- mlflow/tracing/utils/exception.py +21 -0
- mlflow/tracing/utils/once.py +35 -0
- mlflow/tracing/utils/otlp.py +63 -0
- mlflow/tracing/utils/processor.py +54 -0
- mlflow/tracing/utils/search.py +292 -0
- mlflow/tracing/utils/timeout.py +250 -0
- mlflow/tracing/utils/token.py +19 -0
- mlflow/tracing/utils/truncation.py +124 -0
- mlflow/tracing/utils/warning.py +76 -0
- mlflow/tracking/__init__.py +39 -0
- mlflow/tracking/_model_registry/__init__.py +1 -0
- mlflow/tracking/_model_registry/client.py +764 -0
- mlflow/tracking/_model_registry/fluent.py +853 -0
- mlflow/tracking/_model_registry/registry.py +67 -0
- mlflow/tracking/_model_registry/utils.py +251 -0
- mlflow/tracking/_tracking_service/__init__.py +0 -0
- mlflow/tracking/_tracking_service/client.py +883 -0
- mlflow/tracking/_tracking_service/registry.py +56 -0
- mlflow/tracking/_tracking_service/utils.py +275 -0
- mlflow/tracking/artifact_utils.py +179 -0
- mlflow/tracking/client.py +5900 -0
- mlflow/tracking/context/__init__.py +0 -0
- mlflow/tracking/context/abstract_context.py +35 -0
- mlflow/tracking/context/databricks_cluster_context.py +15 -0
- mlflow/tracking/context/databricks_command_context.py +15 -0
- mlflow/tracking/context/databricks_job_context.py +49 -0
- mlflow/tracking/context/databricks_notebook_context.py +41 -0
- mlflow/tracking/context/databricks_repo_context.py +43 -0
- mlflow/tracking/context/default_context.py +51 -0
- mlflow/tracking/context/git_context.py +32 -0
- mlflow/tracking/context/registry.py +98 -0
- mlflow/tracking/context/system_environment_context.py +15 -0
- mlflow/tracking/default_experiment/__init__.py +1 -0
- mlflow/tracking/default_experiment/abstract_context.py +43 -0
- mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
- mlflow/tracking/default_experiment/registry.py +75 -0
- mlflow/tracking/fluent.py +3595 -0
- mlflow/tracking/metric_value_conversion_utils.py +93 -0
- mlflow/tracking/multimedia.py +206 -0
- mlflow/tracking/registry.py +86 -0
- mlflow/tracking/request_auth/__init__.py +0 -0
- mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
- mlflow/tracking/request_auth/registry.py +60 -0
- mlflow/tracking/request_header/__init__.py +0 -0
- mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
- mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
- mlflow/tracking/request_header/default_request_header_provider.py +17 -0
- mlflow/tracking/request_header/registry.py +79 -0
- mlflow/transformers/__init__.py +2982 -0
- mlflow/transformers/flavor_config.py +258 -0
- mlflow/transformers/hub_utils.py +83 -0
- mlflow/transformers/llm_inference_utils.py +468 -0
- mlflow/transformers/model_io.py +301 -0
- mlflow/transformers/peft.py +51 -0
- mlflow/transformers/signature.py +183 -0
- mlflow/transformers/torch_utils.py +55 -0
- mlflow/types/__init__.py +21 -0
- mlflow/types/agent.py +270 -0
- mlflow/types/chat.py +240 -0
- mlflow/types/llm.py +935 -0
- mlflow/types/responses.py +139 -0
- mlflow/types/responses_helpers.py +416 -0
- mlflow/types/schema.py +1505 -0
- mlflow/types/type_hints.py +647 -0
- mlflow/types/utils.py +753 -0
- mlflow/utils/__init__.py +283 -0
- mlflow/utils/_capture_modules.py +256 -0
- mlflow/utils/_capture_transformers_modules.py +75 -0
- mlflow/utils/_spark_utils.py +201 -0
- mlflow/utils/_unity_catalog_oss_utils.py +97 -0
- mlflow/utils/_unity_catalog_utils.py +479 -0
- mlflow/utils/annotations.py +218 -0
- mlflow/utils/arguments_utils.py +16 -0
- mlflow/utils/async_logging/__init__.py +1 -0
- mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
- mlflow/utils/async_logging/async_logging_queue.py +366 -0
- mlflow/utils/async_logging/run_artifact.py +38 -0
- mlflow/utils/async_logging/run_batch.py +58 -0
- mlflow/utils/async_logging/run_operations.py +49 -0
- mlflow/utils/autologging_utils/__init__.py +737 -0
- mlflow/utils/autologging_utils/client.py +432 -0
- mlflow/utils/autologging_utils/config.py +33 -0
- mlflow/utils/autologging_utils/events.py +294 -0
- mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
- mlflow/utils/autologging_utils/metrics_queue.py +71 -0
- mlflow/utils/autologging_utils/safety.py +1104 -0
- mlflow/utils/autologging_utils/versioning.py +95 -0
- mlflow/utils/checkpoint_utils.py +206 -0
- mlflow/utils/class_utils.py +6 -0
- mlflow/utils/cli_args.py +257 -0
- mlflow/utils/conda.py +354 -0
- mlflow/utils/credentials.py +231 -0
- mlflow/utils/data_utils.py +17 -0
- mlflow/utils/databricks_utils.py +1436 -0
- mlflow/utils/docstring_utils.py +477 -0
- mlflow/utils/doctor.py +133 -0
- mlflow/utils/download_cloud_file_chunk.py +43 -0
- mlflow/utils/env_manager.py +16 -0
- mlflow/utils/env_pack.py +131 -0
- mlflow/utils/environment.py +1009 -0
- mlflow/utils/exception_utils.py +14 -0
- mlflow/utils/file_utils.py +978 -0
- mlflow/utils/git_utils.py +77 -0
- mlflow/utils/gorilla.py +797 -0
- mlflow/utils/import_hooks/__init__.py +363 -0
- mlflow/utils/lazy_load.py +51 -0
- mlflow/utils/logging_utils.py +168 -0
- mlflow/utils/mime_type_utils.py +58 -0
- mlflow/utils/mlflow_tags.py +103 -0
- mlflow/utils/model_utils.py +486 -0
- mlflow/utils/name_utils.py +346 -0
- mlflow/utils/nfs_on_spark.py +62 -0
- mlflow/utils/openai_utils.py +164 -0
- mlflow/utils/os.py +12 -0
- mlflow/utils/oss_registry_utils.py +29 -0
- mlflow/utils/plugins.py +17 -0
- mlflow/utils/process.py +182 -0
- mlflow/utils/promptlab_utils.py +146 -0
- mlflow/utils/proto_json_utils.py +743 -0
- mlflow/utils/pydantic_utils.py +54 -0
- mlflow/utils/request_utils.py +279 -0
- mlflow/utils/requirements_utils.py +704 -0
- mlflow/utils/rest_utils.py +673 -0
- mlflow/utils/search_logged_model_utils.py +127 -0
- mlflow/utils/search_utils.py +2111 -0
- mlflow/utils/secure_loading.py +221 -0
- mlflow/utils/security_validation.py +384 -0
- mlflow/utils/server_cli_utils.py +61 -0
- mlflow/utils/spark_utils.py +15 -0
- mlflow/utils/string_utils.py +138 -0
- mlflow/utils/thread_utils.py +63 -0
- mlflow/utils/time.py +54 -0
- mlflow/utils/timeout.py +42 -0
- mlflow/utils/uri.py +572 -0
- mlflow/utils/validation.py +662 -0
- mlflow/utils/virtualenv.py +458 -0
- mlflow/utils/warnings_utils.py +25 -0
- mlflow/utils/yaml_utils.py +179 -0
- mlflow/version.py +24 -0
@@ -0,0 +1,737 @@
|
|
1
|
+
import contextlib
|
2
|
+
import importlib
|
3
|
+
import inspect
|
4
|
+
import logging
|
5
|
+
import threading
|
6
|
+
import time
|
7
|
+
from typing import Any, Callable, Optional
|
8
|
+
|
9
|
+
import mlflow
|
10
|
+
from mlflow.entities import Metric
|
11
|
+
from mlflow.utils.validation import MAX_METRICS_PER_BATCH
|
12
|
+
|
13
|
+
# Define the module-level logger for autologging utilities before importing utilities defined in
|
14
|
+
# submodules (e.g., `safety`, `events`) that depend on the module-level logger. As a result, the
|
15
|
+
# imports below intentionally follow this statement ("import not at top of file", E402).
|
16
|
+
# NOTE(review): those imports carry no `noqa: E402`; presumably E402 is ignored via lint config.
_logger = logging.getLogger(__name__)
|
17
|
+
|
18
|
+
# Import autologging utilities used by this module (names marked `noqa: F401` are only re-exported)
|
19
|
+
from mlflow.ml_package_versions import _ML_PACKAGE_VERSIONS, FLAVOR_TO_MODULE_NAME
|
20
|
+
from mlflow.utils.autologging_utils.client import MlflowAutologgingQueueingClient # noqa: F401
|
21
|
+
from mlflow.utils.autologging_utils.events import AutologgingEventLogger
|
22
|
+
from mlflow.utils.autologging_utils.logging_and_warnings import (
|
23
|
+
MlflowEventsAndWarningsBehaviorGlobally,
|
24
|
+
NonMlflowWarningsBehaviorForCurrentThread,
|
25
|
+
)
|
26
|
+
|
27
|
+
# Wildcard import other autologging utilities (e.g. safety utilities, event logging utilities) used
|
28
|
+
# in autologging integration implementations, which reference them via the
|
29
|
+
# `mlflow.utils.autologging_utils` module
|
30
|
+
from mlflow.utils.autologging_utils.safety import ( # noqa: F401
|
31
|
+
ExceptionSafeAbstractClass,
|
32
|
+
ExceptionSafeClass,
|
33
|
+
exception_safe_function_for_class,
|
34
|
+
is_testing,
|
35
|
+
picklable_exception_safe_function,
|
36
|
+
revert_patches,
|
37
|
+
safe_patch,
|
38
|
+
update_wrapper_extended,
|
39
|
+
with_managed_run,
|
40
|
+
)
|
41
|
+
from mlflow.utils.autologging_utils.versioning import (
|
42
|
+
get_min_max_version_and_pip_release,
|
43
|
+
is_flavor_supported_for_associated_package_versions,
|
44
|
+
)
|
45
|
+
|
46
|
+
INPUT_EXAMPLE_SAMPLE_ROWS = 5
ENSURE_AUTOLOGGING_ENABLED_TEXT = (
    "please ensure that autologging is enabled before constructing the dataset."
)

# Flag indicating whether autologging is globally disabled for all integrations.
_AUTOLOGGING_GLOBALLY_DISABLED = False

# Autologging config key indicating whether or not a particular autologging integration
# was configured (i.e. its various `log_models`, `disable`, etc. configuration options
# were set) via a call to `mlflow.autolog()`, rather than via a call to the integration-specific
# autologging method (e.g., `mlflow.tensorflow.autolog()`, ...)
AUTOLOGGING_CONF_KEY_IS_GLOBALLY_CONFIGURED = "globally_configured"

# Dict mapping integration name to its config.
AUTOLOGGING_INTEGRATIONS = {}

# When the library version installed in the user's environment is outside of the supported
# version range declared in `ml-package-versions.yml`, a warning message is issued to the user.
# However, some libraries release new versions very frequently, and our configuration (updated on
# each MLflow release) cannot keep up with the pace, resulting in false alarms. Therefore, we
# suppress warnings for certain libraries that are known to have frequent releases.
_AUTOLOGGING_SUPPORTED_VERSION_WARNING_SUPPRESS_LIST = [
    "langchain",
    "llama_index",
    "litellm",
    "openai",
    "dspy",
    "autogen",
    "ag2",
    "gemini",
    "anthropic",
    "crewai",
    "bedrock",
]

# Global lock for turning autologging on / off.
# An "RLock" is required instead of a plain lock to avoid deadlock: lock-holding autolog
# functions can call other lock-holding autolog functions (e.g. `mlflow.autolog()` invoking
# each integration's `autolog()`), re-acquiring the lock on the same thread.
_autolog_conf_global_lock = threading.RLock()
|
87
|
+
|
88
|
+
|
89
|
+
def autologging_conf_lock(fn):
    """
    Decorate ``fn`` so that every call runs while holding the module's global
    autologging-configuration lock (used by functions that enable / disable autologging).

    Args:
        fn: The function to guard.

    Returns:
        A wrapper around ``fn`` with its metadata copied over via ``update_wrapper_extended``.
    """

    def _call_under_lock(*call_args, **call_kwargs):
        # The lock is reentrant, so nested guarded calls on one thread do not deadlock.
        with _autolog_conf_global_lock:
            result = fn(*call_args, **call_kwargs)
        return result

    return update_wrapper_extended(_call_under_lock, fn)
|
99
|
+
|
100
|
+
|
101
|
+
def get_mlflow_run_params_for_fn_args(fn, args, kwargs, unlogged=None):
    """Build the MLflow Run parameter key / value pairs for a call to ``fn``.

    Args:
        fn: function whose parameters are to be logged.
        args: arguments explicitly passed into fn. If `fn` is defined on a class,
            `self` should not be part of `args`; the caller is responsible for
            filtering out `self` before calling this function.
        kwargs: kwargs explicitly passed into fn.
        unlogged: parameters not to be logged.

    Returns:
        A dictionary of MLflow Run parameter key / value pairs: positional arguments first,
        then explicit keyword arguments, then the declared defaults of every remaining
        parameter, minus any names listed in ``unlogged``.
    """
    excluded = unlogged or []
    # `self` is dropped because callers are required to have removed it from `args`.
    candidates = [
        param for param in inspect.signature(fn).parameters.values() if param.name != "self"
    ]
    num_positional = len(args)

    # Positional values are matched to parameter names by position.
    collected = dict(zip((param.name for param in candidates[:num_positional]), args))
    collected.update(kwargs)
    # Parameters the caller did not supply fall back to their declared default values.
    for param in candidates[num_positional:]:
        if param.name not in kwargs:
            collected[param.name] = param.default

    return {name: value for name, value in collected.items() if name not in excluded}
|
141
|
+
|
142
|
+
|
143
|
+
def log_fn_args_as_params(fn, args, kwargs, unlogged=None):
    """Log the arguments of a call to ``fn`` as parameters of the active MLflow Run.

    The parameter mapping is produced by ``get_mlflow_run_params_for_fn_args`` and passed to
    ``mlflow.log_params``.

    Args:
        fn: function whose parameters are to be logged
        args: arguments explicitly passed into fn. If `fn` is defined on a class,
            `self` should not be part of `args`; the caller is responsible for
            filtering out `self` before calling this function.
        kwargs: kwargs explicitly passed into fn
        unlogged: parameters not to be logged

    Returns:
        None
    """
    mlflow.log_params(get_mlflow_run_params_for_fn_args(fn, args, kwargs, unlogged))
|
161
|
+
|
162
|
+
|
163
|
+
class InputExampleInfo:
    """
    Holds the outcome of collecting an input example ahead of the point where it is consumed.

    For instance, the xgboost and lightgbm integrations attach an instance of this class to a
    dataset so that the training method can read it later.

    Exactly one of ``input_example`` or ``error_msg`` is expected to be populated.
    """

    def __init__(self, input_example=None, error_msg=None):
        # Why collecting the input example failed, if it did.
        self.error_msg = error_msg
        # The collected input example, if collection succeeded.
        self.input_example = input_example
|
176
|
+
|
177
|
+
|
178
|
+
def resolve_input_example_and_signature(
    get_input_example, infer_model_signature, log_input_example, log_model_signature, logger
):
    """Handles the logic of calling functions to gather the input example and infer the model
    signature.

    Args:
        get_input_example: Function which returns an input example, usually sliced from a
            dataset. This function can raise an exception, its message will be
            shown to the user in a warning in the logs.
        infer_model_signature: Function which takes an input example and returns the signature
            of the inputs and outputs of the model. This function can raise
            an exception, its message will be shown to the user in a warning
            in the logs.
        log_input_example: Whether to log errors while collecting the input example, and if it
            succeeds, whether to return the input example to the user. We collect
            it even if this parameter is False because it is needed for inferring
            the model signature.
        log_model_signature: Whether to infer and return the model signature.
        logger: The logger instance used to log warnings to the user during input example
            collection and model signature inference.

    Returns:
        A tuple of input_example and signature. Either or both could be None based on the
        values of log_input_example and log_model_signature. The signature is ``False`` when
        an input example is available but signature inference was not requested or failed,
        which tells model logging not to infer a signature from the example itself.
    """

    input_example = None
    input_example_user_msg = None
    input_example_failure_msg = None
    if log_input_example or log_model_signature:
        try:
            input_example = get_input_example()
        except Exception as e:
            input_example_failure_msg = str(e)
            input_example_user_msg = "Failed to gather input example: " + str(e)

    model_signature = None
    model_signature_user_msg = None
    if log_model_signature:
        try:
            if input_example is None:
                # `get_input_example` may return None without raising, in which case there is
                # no failure message; avoid concatenating None (which would raise a TypeError
                # and replace this explanation with an unrelated one).
                reason = (
                    input_example_failure_msg
                    if input_example_failure_msg is not None
                    else "no input example was produced"
                )
                raise Exception("could not sample data to infer model signature: " + reason)
            model_signature = infer_model_signature(input_example)
        except Exception as e:
            model_signature_user_msg = "Failed to infer model signature: " + str(e)

    # disable input_example signature inference in model logging if `log_model_signature`
    # is set to `False` or signature inference in autologging fails
    if (
        model_signature is None
        and input_example is not None
        and (not log_model_signature or model_signature_user_msg is not None)
    ):
        model_signature = False

    if log_input_example and input_example_user_msg is not None:
        logger.warning(input_example_user_msg)
    if log_model_signature and model_signature_user_msg is not None:
        logger.warning(model_signature_user_msg)

    return input_example if log_input_example else None, model_signature
|
243
|
+
|
244
|
+
|
245
|
+
class BatchMetricsLogger:
    """
    The BatchMetricsLogger will log metrics in batch against an mlflow run.
    If run_id is passed to the constructor then all recording and logging will
    happen against that run_id.
    If no run_id is passed into the constructor, then the run ID will be fetched
    from `mlflow.active_run()` each time pending metrics are actually sent (from
    `record_metrics()` or `flush()`); in this case, callers must ensure that an active run
    is present whenever there are metrics waiting to be flushed.
    """

    def __init__(self, run_id=None, tracking_uri=None, model_id=None):
        from mlflow.tracking.client import MlflowClient

        self.run_id = run_id
        self.model_id = model_id
        self.client = MlflowClient(tracking_uri)

        # data is an array of Metric objects that have not been sent yet
        self.data = []
        # Accumulated seconds between consecutive `record_metrics()` calls.
        self.total_training_time = 0
        # Accumulated seconds spent sending batches in `_timed_log_batch()`.
        self.total_log_batch_time = 0
        self.previous_training_timestamp = None

    def flush(self):
        """
        The metrics accumulated by BatchMetricsLogger will be batch logged to an MLflow run.

        When no metrics are pending this is a no-op, so no active run is needed in that case.
        """
        if not self.data:
            # Nothing to send; avoid resolving the run ID (which requires an active run when
            # `run_id` was not given) for an empty batch.
            return
        self._timed_log_batch()
        self.data = []

    def _timed_log_batch(self):
        # Retrieving run_id from active mlflow run when run_id is empty.
        current_run_id = mlflow.active_run().info.run_id if self.run_id is None else self.run_id

        start = time.time()
        # `log_batch` accepts at most MAX_METRICS_PER_BATCH metrics per call, so split the
        # pending metrics into slices of that size.
        metrics_slices = [
            self.data[i : i + MAX_METRICS_PER_BATCH]
            for i in range(0, len(self.data), MAX_METRICS_PER_BATCH)
        ]
        for metrics_slice in metrics_slices:
            self.client.log_batch(run_id=current_run_id, metrics=metrics_slice)
        end = time.time()
        self.total_log_batch_time += end - start

    def _should_flush(self):
        # Flush only while logging time stays at or below 1/10th of the measured training
        # time, i.e. logging overhead of at most ~10%.
        target_training_to_logging_time_ratio = 10
        return (
            self.total_training_time
            >= self.total_log_batch_time * target_training_to_logging_time_ratio
        )

    def record_metrics(self, metrics, step=None):
        """
        Submit a set of metrics to be logged. The metrics may not be immediately logged, as this
        class will batch them in order to not increase execution time too much by logging
        frequently.

        Args:
            metrics: Dictionary containing key, value pairs of metrics to be logged.
            step: The training step that the metrics correspond to.
        """
        current_timestamp = time.time()
        if self.previous_training_timestamp is None:
            self.previous_training_timestamp = current_timestamp

        # Time elapsed since the previous call is attributed to training.
        training_time = current_timestamp - self.previous_training_timestamp

        self.total_training_time += training_time

        # log_batch() requires step to be defined. Therefore will set step to 0 if not defined.
        if step is None:
            step = 0

        for key, value in metrics.items():
            self.data.append(
                Metric(key, value, int(current_timestamp * 1000), step, model_id=self.model_id)
            )

        if self._should_flush():
            self.flush()

        self.previous_training_timestamp = current_timestamp
|
331
|
+
|
332
|
+
|
333
|
+
@contextlib.contextmanager
def batch_metrics_logger(run_id: Optional[str] = None, model_id: Optional[str] = None):
    """
    Context manager that yields a BatchMetricsLogger object, which metrics can be logged against.
    The BatchMetricsLogger keeps metrics in a list until it decides they should be logged, at
    which point the accumulated metrics will be batch logged. The BatchMetricsLogger aims to
    keep logging overhead at no more than ~10% of the training time, where the training time is
    measured by adding up the time elapsed between consecutive calls to record_metrics.

    Errors raised while logging a batch are not caught here: they propagate out of
    ``record_metrics()`` / the final flush to the caller.

    Once the context exits normally, any metrics that have yet to be logged will be logged.
    If the body of the ``with`` block raises, pending metrics are not flushed.

    Args:
        run_id: ID of the run that the metrics will be logged to. If None, the active run at
            flush time is used.
        model_id: ID of the model that the metrics will be associated with.
    """

    # Named distinctly from this function to avoid shadowing it inside its own body.
    metrics_logger = BatchMetricsLogger(run_id, model_id=model_id)
    yield metrics_logger
    metrics_logger.flush()
|
355
|
+
|
356
|
+
|
357
|
+
def gen_autologging_package_version_requirements_doc(integration_name):
    """
    Build a reST note describing which package versions an autologging integration is known to
    work with.

    Args:
        integration_name: The autologging integration (flavor) name.

    Returns:
        A document note string saying the compatibility for the specified autologging
        integration's associated package versions, terminated by a blank line.
    """
    min_ver, max_ver, pip_release = get_min_max_version_and_pip_release(integration_name)
    version_range = f"``{min_ver}`` <= ``{pip_release}`` <= ``{max_ver}``"

    note_prefix = " .. Note:: Autologging is known to be compatible with the following package versions: "
    note_suffix = ". Autologging may not succeed when used with package versions outside of this range."
    return "".join([note_prefix, version_range, note_suffix, "\n\n"])
|
372
|
+
|
373
|
+
|
374
|
+
def _check_and_log_warning_for_unsupported_package_versions(integration_name):
    """
    When autologging is enabled and `disable_for_unsupported_versions=False` for the specified
    autologging integration, check whether the currently-installed versions of the integration's
    associated package versions are supported by the specified integration. If the package versions
    are not supported, log a warning message.

    Integrations listed in `_AUTOLOGGING_SUPPORTED_VERSION_WARNING_SUPPRESS_LIST` never warn.
    """
    if (
        integration_name in FLAVOR_TO_MODULE_NAME
        and integration_name not in _AUTOLOGGING_SUPPORTED_VERSION_WARNING_SUPPRESS_LIST
        and not get_autologging_config(integration_name, "disable", True)
        and not get_autologging_config(integration_name, "disable_for_unsupported_versions", False)
        and not is_flavor_supported_for_associated_package_versions(integration_name)
    ):
        min_ver, max_ver, pip_release = get_min_max_version_and_pip_release(integration_name)
        module = importlib.import_module(FLAVOR_TO_MODULE_NAME[integration_name])
        # Guard against a module without `__version__`: this runs while enabling autologging,
        # and an AttributeError here would abort enablement just to build a warning message.
        installed_version = getattr(module, "__version__", "unknown")
        _logger.warning(
            f"MLflow {integration_name} autologging is known to be compatible with "
            f"{min_ver} <= {pip_release} <= {max_ver}, but the installed version is "
            f"{installed_version}. If you encounter errors during autologging, try upgrading "
            f"/ downgrading {pip_release} to a compatible version, or try upgrading MLflow.",
        )
|
396
|
+
|
397
|
+
|
398
|
+
def autologging_integration(name):
    """
    **All autologging integrations should be decorated with this wrapper.**

    Wraps an autologging function in order to store its configuration arguments. This enables
    patch functions to broadly obey certain configurations (e.g., disable=True) without
    requiring specific logic to be present in each autologging integration.

    Args:
        name: The autologging integration name. It is used as the key into
            `AUTOLOGGING_INTEGRATIONS` and stored as `integration_name` on the wrapped function.

    Returns:
        A decorator that validates an `autolog()` function's signature and wraps it.

    Raises:
        Exception: (from the returned decorator) if the decorated `autolog()` does not declare
            `disable` and `silent` parameters that both default to `False`.
    """

    def validate_param_spec(param_spec):
        # The generic machinery below reads `disable` and `silent` from the stored config, so
        # every integration must declare both with a default of `False`.
        if "disable" not in param_spec or param_spec["disable"].default is not False:
            raise Exception(
                f"Invalid `autolog()` function for integration '{name}'. `autolog()` functions"
                " must specify a 'disable' argument with default value 'False'"
            )
        elif "silent" not in param_spec or param_spec["silent"].default is not False:
            raise Exception(
                f"Invalid `autolog()` function for integration '{name}'. `autolog()` functions"
                " must specify a 'silent' argument with default value 'False'"
            )

    def wrapper(_autolog):
        param_spec = inspect.signature(_autolog).parameters
        validate_param_spec(param_spec)

        AUTOLOGGING_INTEGRATIONS[name] = {}
        default_params = {param.name: param.default for param in param_spec.values()}

        @autologging_conf_lock
        def autolog(*args, **kwargs):
            # Effective config: declared defaults, overridden by positional arguments (matched
            # to parameters by position), overridden by keyword arguments.
            config_to_store = dict(default_params)
            config_to_store.update(
                {param.name: arg for arg, param in zip(args, param_spec.values())}
            )
            config_to_store.update(kwargs)
            AUTOLOGGING_INTEGRATIONS[name] = config_to_store

            try:
                # Pass `autolog()` arguments to `log_autolog_called` in keyword format to enable
                # event loggers to more easily identify important configuration parameters
                # (e.g., `disable`) without examining positional arguments. Passing positional
                # arguments to `log_autolog_called` is deprecated in MLflow > 1.13.1
                AutologgingEventLogger.get_logger().log_autolog_called(name, (), config_to_store)
            except Exception:
                # Event logging is best-effort and must never block autologging setup, but keep
                # a debug-level record instead of discarding the failure silently.
                _logger.debug(
                    "Failed to log the autolog() call for integration '%s'", name, exc_info=True
                )

            revert_patches(name)

            # If disabling autologging using fluent api, then every active integration's autolog
            # needs to be called with disable=True. So do not short circuit and let
            # `mlflow.autolog()` invoke all active integrations with disable=True.
            if name != "mlflow" and get_autologging_config(name, "disable", True):
                return

            is_silent_mode = get_autologging_config(name, "silent", False)
            # Reroute non-MLflow warnings encountered during autologging enablement to an
            # MLflow event logger, and enforce silent mode if applicable (i.e. if the corresponding
            # autologging integration was called with `silent=True`)
            with (
                MlflowEventsAndWarningsBehaviorGlobally(
                    # MLflow warnings emitted during autologging setup / enablement are likely
                    # actionable and relevant to the user, so they should be emitted as normal
                    # when `silent=False`. For reference, see recommended warning and event logging
                    # behaviors from https://docs.python.org/3/howto/logging.html#when-to-use-logging
                    reroute_warnings=False,
                    disable_event_logs=is_silent_mode,
                    disable_warnings=is_silent_mode,
                ),
                NonMlflowWarningsBehaviorForCurrentThread(
                    # non-MLflow warnings emitted during autologging setup / enablement are not
                    # actionable for the user, as they are a byproduct of the autologging
                    # implementation. Accordingly, they should be rerouted to `logger.warning()`.
                    # For reference, see recommended warning and event logging
                    # behaviors from https://docs.python.org/3/howto/logging.html#when-to-use-logging
                    reroute_warnings=True,
                    disable_warnings=is_silent_mode,
                ),
            ):
                _check_and_log_warning_for_unsupported_package_versions(name)

                return _autolog(*args, **kwargs)

        wrapped_autolog = update_wrapper_extended(autolog, _autolog)
        # Set the autologging integration name as a function attribute on the wrapped autologging
        # function, allowing the integration name to be extracted from the function. This is used
        # during the execution of import hooks for `mlflow.autolog()`.
        wrapped_autolog.integration_name = name

        if name in FLAVOR_TO_MODULE_NAME:
            wrapped_autolog.__doc__ = gen_autologging_package_version_requirements_doc(name) + (
                wrapped_autolog.__doc__ or ""
            )
        return wrapped_autolog

    return wrapper
|
493
|
+
|
494
|
+
|
495
|
+
def get_autologging_config(flavor_name, config_key, default_value=None):
    """
    Look up one configuration value recorded for an autologging integration.

    Args:
        flavor_name: An autologging integration flavor name.
        config_key: The key for the desired config value.
        default_value: Returned when `flavor_name` has no recorded config, or when the recorded
            config does not contain `config_key`.

    Returns:
        The recorded value for `config_key`, otherwise `default_value`.
    """
    flavor_config = AUTOLOGGING_INTEGRATIONS.get(flavor_name)
    if flavor_config is None:
        return default_value
    return flavor_config.get(config_key, default_value)
|
512
|
+
|
513
|
+
|
514
|
+
def autologging_is_disabled(integration_name):
    """Report whether an autologging integration is currently disabled.

    An integration counts as disabled when either its recorded `disable` option is truthy
    (integrations with no recorded config count as disabled), or it opted into
    `disable_for_unsupported_versions` and the installed package versions are unsupported.

    Args:
        integration_name: An autologging integration flavor name.

    Returns:
        ``True`` if disabled, ``False`` otherwise.
    """
    if get_autologging_config(integration_name, "disable", True):
        return True

    return bool(
        integration_name in FLAVOR_TO_MODULE_NAME
        and get_autologging_config(integration_name, "disable_for_unsupported_versions", False)
        and not is_flavor_supported_for_associated_package_versions(integration_name)
    )
|
533
|
+
|
534
|
+
|
535
|
+
def is_autolog_supported(integration_name: str) -> bool:
    """
    Report whether the given autologging integration is supported by the current environment,
    based on whether its package-version entry declares an ``"autologging"`` section.

    Args:
        integration_name: An autologging integration flavor name.
    """
    # NB: We don't check for the presence of autolog() function as it requires importing
    # the flavor module, which may cause import error or overhead.
    package_entry = _ML_PACKAGE_VERSIONS.get(integration_name, {})
    return "autologging" in package_entry
|
545
|
+
|
546
|
+
|
547
|
+
def get_autolog_function(integration_name: str) -> Optional[Callable[..., Any]]:
    """
    Import ``mlflow.<integration_name>`` and return its ``autolog`` attribute.

    Returns None if the flavor module has no ``autolog`` attribute. Import errors raised by
    ``importlib.import_module`` (e.g. ``ModuleNotFoundError``) propagate to the caller.
    """
    module_path = f"mlflow.{integration_name}"
    flavor_module = importlib.import_module(module_path)
    return getattr(flavor_module, "autolog", None)
|
554
|
+
|
555
|
+
|
556
|
+
@contextlib.contextmanager
def disable_autologging():
    """
    Context manager that temporarily disables autologging globally for all integrations upon
    entry and restores the previous autologging configuration upon exit.

    The previous value of the module-level `_AUTOLOGGING_GLOBALLY_DISABLED` flag is restored
    (rather than unconditionally reset to `False`), so nested uses keep autologging disabled
    until the outermost context exits.
    """
    global _AUTOLOGGING_GLOBALLY_DISABLED
    previous_state = _AUTOLOGGING_GLOBALLY_DISABLED
    _AUTOLOGGING_GLOBALLY_DISABLED = True
    try:
        yield
    finally:
        _AUTOLOGGING_GLOBALLY_DISABLED = previous_state
|
568
|
+
|
569
|
+
|
570
|
+
@contextlib.contextmanager
def disable_discrete_autologging(flavors_to_disable: list[str]) -> None:
    """
    Context manager for disabling specific autologging integrations temporarily while another
    flavor's autologging is activated. This context wrapper is useful in the event that, for
    example, a particular library calls upon another library within a training API that has a
    current MLflow autologging integration.
    For instance, the transformers library's Trainer class, when running metric scoring,
    builds a sklearn model and runs evaluations as part of its accuracy scoring. Without this
    temporary autologging disabling, a new run will be generated that contains a sklearn model
    that holds no use for tracking purposes as it is only used during the metric evaluation phase
    of training.

    Only flavors that were enabled on entry are disabled, and exactly those are re-enabled
    (via ``autolog(disable=False)``) on exit — including when the body raises.

    Args:
        flavors_to_disable: A list of flavors that need to be temporarily disabled while
            executing another flavor's autologging to prevent spurious run
            logging of unrelated models, metrics, and parameters.
    """
    enabled_flavors = []
    try:
        for flavor in flavors_to_disable:
            if not autologging_is_disabled(flavor):
                # Track the flavor before disabling so it is re-enabled even if disabling it
                # (or a later flavor) raises.
                enabled_flavors.append(flavor)
                flavor_module = getattr(mlflow, flavor)
                flavor_module.autolog(disable=True)
        yield
    finally:
        # Always restore, otherwise an exception in the body would leave these flavors
        # permanently disabled.
        for flavor in enabled_flavors:
            flavor_module = getattr(mlflow, flavor)
            flavor_module.autolog(disable=False)
|
598
|
+
|
599
|
+
|
600
|
+
# Registry of every session class produced by `_get_new_training_session_class()`; consulted by
# `_has_active_training_session()`.
_training_sessions: list[type] = []
|
601
|
+
|
602
|
+
|
603
|
+
def _get_new_training_session_class():
    """
    Returns a session manager class for nested autologging runs.

    Each call creates a fresh class with its own session stack and registers it in the
    module-level `_training_sessions` list.

    Examples
    --------
    >>> class Parent:
    ...     pass
    >>> class Child:
    ...     pass
    >>> class Grandchild:
    ...     pass
    >>>
    >>> _TrainingSession = _get_new_training_session_class()
    >>> with _TrainingSession(Parent, False) as p:
    ...     with _TrainingSession(Child, True) as c:
    ...         with _TrainingSession(Grandchild, True) as g:
    ...             print(p.should_log(), c.should_log(), g.should_log())
    True False False
    >>>
    >>> with _TrainingSession(Parent, True) as p:
    ...     with _TrainingSession(Child, False) as c:
    ...         with _TrainingSession(Grandchild, True) as g:
    ...             print(p.should_log(), c.should_log(), g.should_log())
    True True False
    >>>
    >>> with _TrainingSession(Child, True) as c1:
    ...     with _TrainingSession(Child, True) as c2:
    ...         print(c1.should_log(), c2.should_log())
    True False
    """

    # NOTE: The current implementation doesn't guarantee thread-safety, but that's okay for now
    # because:
    # 1. We don't currently have any use cases for allow_children=True.
    # 2. The list append & pop operations are thread-safe, so we will always clear the session stack
    #    once all _TrainingSessions exit.
    class _TrainingSession:
        # Stack of currently-entered sessions of this class, outermost first.
        _session_stack = []

        def __init__(self, estimator, allow_children=True):
            """A session manager for nested autologging runs.

            Args:
                estimator: An estimator that this session originates from.
                allow_children: If True, allows autologging in child sessions.
                    If False, disallows autologging in all descendant sessions.

            """
            self.allow_children = allow_children
            self.estimator = estimator
            self._parent = None

        def __enter__(self):
            if _TrainingSession._session_stack:
                self._parent = _TrainingSession._session_stack[-1]
                # A session that forbids children forbids all descendants, so the flag is
                # AND-ed down the stack.
                self.allow_children = (
                    _TrainingSession._session_stack[-1].allow_children and self.allow_children
                )
            _TrainingSession._session_stack.append(self)
            return self

        def __exit__(self, tp, val, traceback):
            _TrainingSession._session_stack.pop()

        def should_log(self):
            """
            Returns True when at least one of the following conditions satisfies:

            1. This session is the root session.
            2. The parent session allows autologging and its estimator differs from this session's
               estimator.

            Returns False whenever an enclosing session uses the same estimator object.
            """
            for training_session in _TrainingSession._session_stack:
                if training_session is self:
                    break
                elif training_session.estimator is self.estimator:
                    return False

            return self._parent is None or self._parent.allow_children

        @staticmethod
        def is_active():
            return len(_TrainingSession._session_stack) != 0

        @staticmethod
        def get_current_session():
            if _TrainingSession.is_active():
                return _TrainingSession._session_stack[-1]
            return None

    _training_sessions.append(_TrainingSession)
    return _TrainingSession
|
696
|
+
|
697
|
+
|
698
|
+
def _has_active_training_session():
    """Return True if any registered training-session class currently has an entered session."""
    for session_cls in _training_sessions:
        if session_cls.is_active():
            return True
    return False
|
700
|
+
|
701
|
+
|
702
|
+
def get_instance_method_first_arg_value(method, call_pos_args, call_kwargs):
    """Return the value supplied for an instance method's first argument after `self`.

    Args:
        method: A `cls.method` object which includes the `self` argument.
        call_pos_args: positional arguments excluding the first `self` argument.
        call_kwargs: keywords arguments.

    Returns:
        The first positional value if any were given; otherwise the keyword value for the
        parameter declared right after `self`, or None if it was not passed.
    """
    if len(call_pos_args) > 0:
        return call_pos_args[0]

    signature_params = inspect.signature(method).parameters
    # Position 0 is `self`; the parameter of interest is the next one declared.
    target_name = list(signature_params.keys())[1]
    variadic_kinds = (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL)
    assert signature_params[target_name].kind not in variadic_kinds
    return call_kwargs.get(target_name)
|
720
|
+
|
721
|
+
|
722
|
+
def get_method_call_arg_value(arg_index, arg_name, default_value, call_pos_args, call_kwargs):
    """Resolve the value an argument received in a method call.

    Keyword arguments take precedence, then the positional argument at `arg_index`, then
    `default_value`.

    Args:
        arg_index: The argument index in the function signature. Start from 0.
        arg_name: The argument name in the function signature.
        default_value: Default argument value.
        call_pos_args: The positional argument values in the method call.
        call_kwargs: The keyword argument values in the method call.
    """
    if arg_name not in call_kwargs:
        passed_positionally = arg_index < len(call_pos_args)
        return call_pos_args[arg_index] if passed_positionally else default_value
    return call_kwargs[arg_name]
|