genesis-flow 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genesis_flow-1.0.0.dist-info/METADATA +822 -0
- genesis_flow-1.0.0.dist-info/RECORD +645 -0
- genesis_flow-1.0.0.dist-info/WHEEL +5 -0
- genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
- genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
- genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
- mlflow/__init__.py +367 -0
- mlflow/__main__.py +3 -0
- mlflow/ag2/__init__.py +56 -0
- mlflow/ag2/ag2_logger.py +294 -0
- mlflow/anthropic/__init__.py +40 -0
- mlflow/anthropic/autolog.py +129 -0
- mlflow/anthropic/chat.py +144 -0
- mlflow/artifacts/__init__.py +268 -0
- mlflow/autogen/__init__.py +144 -0
- mlflow/autogen/chat.py +142 -0
- mlflow/azure/__init__.py +26 -0
- mlflow/azure/auth_handler.py +257 -0
- mlflow/azure/client.py +319 -0
- mlflow/azure/config.py +120 -0
- mlflow/azure/connection_factory.py +340 -0
- mlflow/azure/exceptions.py +27 -0
- mlflow/azure/stores.py +327 -0
- mlflow/azure/utils.py +183 -0
- mlflow/bedrock/__init__.py +45 -0
- mlflow/bedrock/_autolog.py +202 -0
- mlflow/bedrock/chat.py +122 -0
- mlflow/bedrock/stream.py +160 -0
- mlflow/bedrock/utils.py +43 -0
- mlflow/cli.py +707 -0
- mlflow/client.py +12 -0
- mlflow/config/__init__.py +56 -0
- mlflow/crewai/__init__.py +79 -0
- mlflow/crewai/autolog.py +253 -0
- mlflow/crewai/chat.py +29 -0
- mlflow/data/__init__.py +75 -0
- mlflow/data/artifact_dataset_sources.py +170 -0
- mlflow/data/code_dataset_source.py +40 -0
- mlflow/data/dataset.py +123 -0
- mlflow/data/dataset_registry.py +168 -0
- mlflow/data/dataset_source.py +110 -0
- mlflow/data/dataset_source_registry.py +219 -0
- mlflow/data/delta_dataset_source.py +167 -0
- mlflow/data/digest_utils.py +108 -0
- mlflow/data/evaluation_dataset.py +562 -0
- mlflow/data/filesystem_dataset_source.py +81 -0
- mlflow/data/http_dataset_source.py +145 -0
- mlflow/data/huggingface_dataset.py +258 -0
- mlflow/data/huggingface_dataset_source.py +118 -0
- mlflow/data/meta_dataset.py +104 -0
- mlflow/data/numpy_dataset.py +223 -0
- mlflow/data/pandas_dataset.py +231 -0
- mlflow/data/polars_dataset.py +352 -0
- mlflow/data/pyfunc_dataset_mixin.py +31 -0
- mlflow/data/schema.py +76 -0
- mlflow/data/sources.py +1 -0
- mlflow/data/spark_dataset.py +406 -0
- mlflow/data/spark_dataset_source.py +74 -0
- mlflow/data/spark_delta_utils.py +118 -0
- mlflow/data/tensorflow_dataset.py +350 -0
- mlflow/data/uc_volume_dataset_source.py +81 -0
- mlflow/db.py +27 -0
- mlflow/dspy/__init__.py +17 -0
- mlflow/dspy/autolog.py +197 -0
- mlflow/dspy/callback.py +398 -0
- mlflow/dspy/constant.py +1 -0
- mlflow/dspy/load.py +93 -0
- mlflow/dspy/save.py +393 -0
- mlflow/dspy/util.py +109 -0
- mlflow/dspy/wrapper.py +226 -0
- mlflow/entities/__init__.py +104 -0
- mlflow/entities/_mlflow_object.py +52 -0
- mlflow/entities/assessment.py +545 -0
- mlflow/entities/assessment_error.py +80 -0
- mlflow/entities/assessment_source.py +141 -0
- mlflow/entities/dataset.py +92 -0
- mlflow/entities/dataset_input.py +51 -0
- mlflow/entities/dataset_summary.py +62 -0
- mlflow/entities/document.py +48 -0
- mlflow/entities/experiment.py +109 -0
- mlflow/entities/experiment_tag.py +35 -0
- mlflow/entities/file_info.py +45 -0
- mlflow/entities/input_tag.py +35 -0
- mlflow/entities/lifecycle_stage.py +35 -0
- mlflow/entities/logged_model.py +228 -0
- mlflow/entities/logged_model_input.py +26 -0
- mlflow/entities/logged_model_output.py +32 -0
- mlflow/entities/logged_model_parameter.py +46 -0
- mlflow/entities/logged_model_status.py +74 -0
- mlflow/entities/logged_model_tag.py +33 -0
- mlflow/entities/metric.py +200 -0
- mlflow/entities/model_registry/__init__.py +29 -0
- mlflow/entities/model_registry/_model_registry_entity.py +13 -0
- mlflow/entities/model_registry/model_version.py +243 -0
- mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
- mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
- mlflow/entities/model_registry/model_version_search.py +25 -0
- mlflow/entities/model_registry/model_version_stages.py +25 -0
- mlflow/entities/model_registry/model_version_status.py +35 -0
- mlflow/entities/model_registry/model_version_tag.py +35 -0
- mlflow/entities/model_registry/prompt.py +73 -0
- mlflow/entities/model_registry/prompt_version.py +244 -0
- mlflow/entities/model_registry/registered_model.py +175 -0
- mlflow/entities/model_registry/registered_model_alias.py +35 -0
- mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
- mlflow/entities/model_registry/registered_model_search.py +25 -0
- mlflow/entities/model_registry/registered_model_tag.py +35 -0
- mlflow/entities/multipart_upload.py +74 -0
- mlflow/entities/param.py +49 -0
- mlflow/entities/run.py +97 -0
- mlflow/entities/run_data.py +84 -0
- mlflow/entities/run_info.py +188 -0
- mlflow/entities/run_inputs.py +59 -0
- mlflow/entities/run_outputs.py +43 -0
- mlflow/entities/run_status.py +41 -0
- mlflow/entities/run_tag.py +36 -0
- mlflow/entities/source_type.py +31 -0
- mlflow/entities/span.py +774 -0
- mlflow/entities/span_event.py +96 -0
- mlflow/entities/span_status.py +102 -0
- mlflow/entities/trace.py +317 -0
- mlflow/entities/trace_data.py +71 -0
- mlflow/entities/trace_info.py +220 -0
- mlflow/entities/trace_info_v2.py +162 -0
- mlflow/entities/trace_location.py +173 -0
- mlflow/entities/trace_state.py +39 -0
- mlflow/entities/trace_status.py +68 -0
- mlflow/entities/view_type.py +51 -0
- mlflow/environment_variables.py +866 -0
- mlflow/evaluation/__init__.py +16 -0
- mlflow/evaluation/assessment.py +369 -0
- mlflow/evaluation/evaluation.py +411 -0
- mlflow/evaluation/evaluation_tag.py +61 -0
- mlflow/evaluation/fluent.py +48 -0
- mlflow/evaluation/utils.py +201 -0
- mlflow/exceptions.py +213 -0
- mlflow/experiments.py +140 -0
- mlflow/gemini/__init__.py +81 -0
- mlflow/gemini/autolog.py +186 -0
- mlflow/gemini/chat.py +261 -0
- mlflow/genai/__init__.py +71 -0
- mlflow/genai/datasets/__init__.py +67 -0
- mlflow/genai/datasets/evaluation_dataset.py +131 -0
- mlflow/genai/evaluation/__init__.py +3 -0
- mlflow/genai/evaluation/base.py +411 -0
- mlflow/genai/evaluation/constant.py +23 -0
- mlflow/genai/evaluation/utils.py +244 -0
- mlflow/genai/judges/__init__.py +21 -0
- mlflow/genai/judges/databricks.py +404 -0
- mlflow/genai/label_schemas/__init__.py +153 -0
- mlflow/genai/label_schemas/label_schemas.py +209 -0
- mlflow/genai/labeling/__init__.py +159 -0
- mlflow/genai/labeling/labeling.py +250 -0
- mlflow/genai/optimize/__init__.py +13 -0
- mlflow/genai/optimize/base.py +198 -0
- mlflow/genai/optimize/optimizers/__init__.py +4 -0
- mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
- mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
- mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
- mlflow/genai/optimize/types.py +75 -0
- mlflow/genai/optimize/util.py +30 -0
- mlflow/genai/prompts/__init__.py +206 -0
- mlflow/genai/scheduled_scorers.py +431 -0
- mlflow/genai/scorers/__init__.py +26 -0
- mlflow/genai/scorers/base.py +492 -0
- mlflow/genai/scorers/builtin_scorers.py +765 -0
- mlflow/genai/scorers/scorer_utils.py +138 -0
- mlflow/genai/scorers/validation.py +165 -0
- mlflow/genai/utils/data_validation.py +146 -0
- mlflow/genai/utils/enum_utils.py +23 -0
- mlflow/genai/utils/trace_utils.py +211 -0
- mlflow/groq/__init__.py +42 -0
- mlflow/groq/_groq_autolog.py +74 -0
- mlflow/johnsnowlabs/__init__.py +888 -0
- mlflow/langchain/__init__.py +24 -0
- mlflow/langchain/api_request_parallel_processor.py +330 -0
- mlflow/langchain/autolog.py +147 -0
- mlflow/langchain/chat_agent_langgraph.py +340 -0
- mlflow/langchain/constant.py +1 -0
- mlflow/langchain/constants.py +1 -0
- mlflow/langchain/databricks_dependencies.py +444 -0
- mlflow/langchain/langchain_tracer.py +597 -0
- mlflow/langchain/model.py +919 -0
- mlflow/langchain/output_parsers.py +142 -0
- mlflow/langchain/retriever_chain.py +153 -0
- mlflow/langchain/runnables.py +527 -0
- mlflow/langchain/utils/chat.py +402 -0
- mlflow/langchain/utils/logging.py +671 -0
- mlflow/langchain/utils/serialization.py +36 -0
- mlflow/legacy_databricks_cli/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/provider.py +482 -0
- mlflow/litellm/__init__.py +175 -0
- mlflow/llama_index/__init__.py +22 -0
- mlflow/llama_index/autolog.py +55 -0
- mlflow/llama_index/chat.py +43 -0
- mlflow/llama_index/constant.py +1 -0
- mlflow/llama_index/model.py +577 -0
- mlflow/llama_index/pyfunc_wrapper.py +332 -0
- mlflow/llama_index/serialize_objects.py +188 -0
- mlflow/llama_index/tracer.py +561 -0
- mlflow/metrics/__init__.py +479 -0
- mlflow/metrics/base.py +39 -0
- mlflow/metrics/genai/__init__.py +25 -0
- mlflow/metrics/genai/base.py +101 -0
- mlflow/metrics/genai/genai_metric.py +771 -0
- mlflow/metrics/genai/metric_definitions.py +450 -0
- mlflow/metrics/genai/model_utils.py +371 -0
- mlflow/metrics/genai/prompt_template.py +68 -0
- mlflow/metrics/genai/prompts/__init__.py +0 -0
- mlflow/metrics/genai/prompts/v1.py +422 -0
- mlflow/metrics/genai/utils.py +6 -0
- mlflow/metrics/metric_definitions.py +619 -0
- mlflow/mismatch.py +34 -0
- mlflow/mistral/__init__.py +34 -0
- mlflow/mistral/autolog.py +71 -0
- mlflow/mistral/chat.py +135 -0
- mlflow/ml_package_versions.py +452 -0
- mlflow/models/__init__.py +97 -0
- mlflow/models/auth_policy.py +83 -0
- mlflow/models/cli.py +354 -0
- mlflow/models/container/__init__.py +294 -0
- mlflow/models/container/scoring_server/__init__.py +0 -0
- mlflow/models/container/scoring_server/nginx.conf +39 -0
- mlflow/models/dependencies_schemas.py +287 -0
- mlflow/models/display_utils.py +158 -0
- mlflow/models/docker_utils.py +211 -0
- mlflow/models/evaluation/__init__.py +23 -0
- mlflow/models/evaluation/_shap_patch.py +64 -0
- mlflow/models/evaluation/artifacts.py +194 -0
- mlflow/models/evaluation/base.py +1811 -0
- mlflow/models/evaluation/calibration_curve.py +109 -0
- mlflow/models/evaluation/default_evaluator.py +996 -0
- mlflow/models/evaluation/deprecated.py +23 -0
- mlflow/models/evaluation/evaluator_registry.py +80 -0
- mlflow/models/evaluation/evaluators/classifier.py +704 -0
- mlflow/models/evaluation/evaluators/default.py +233 -0
- mlflow/models/evaluation/evaluators/regressor.py +96 -0
- mlflow/models/evaluation/evaluators/shap.py +296 -0
- mlflow/models/evaluation/lift_curve.py +178 -0
- mlflow/models/evaluation/utils/metric.py +123 -0
- mlflow/models/evaluation/utils/trace.py +179 -0
- mlflow/models/evaluation/validation.py +434 -0
- mlflow/models/flavor_backend.py +93 -0
- mlflow/models/flavor_backend_registry.py +53 -0
- mlflow/models/model.py +1639 -0
- mlflow/models/model_config.py +150 -0
- mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
- mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
- mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
- mlflow/models/python_api.py +369 -0
- mlflow/models/rag_signatures.py +128 -0
- mlflow/models/resources.py +321 -0
- mlflow/models/signature.py +662 -0
- mlflow/models/utils.py +2054 -0
- mlflow/models/wheeled_model.py +280 -0
- mlflow/openai/__init__.py +57 -0
- mlflow/openai/_agent_tracer.py +364 -0
- mlflow/openai/api_request_parallel_processor.py +131 -0
- mlflow/openai/autolog.py +509 -0
- mlflow/openai/constant.py +1 -0
- mlflow/openai/model.py +824 -0
- mlflow/openai/utils/chat_schema.py +367 -0
- mlflow/optuna/__init__.py +3 -0
- mlflow/optuna/storage.py +646 -0
- mlflow/plugins/__init__.py +72 -0
- mlflow/plugins/base.py +358 -0
- mlflow/plugins/builtin/__init__.py +24 -0
- mlflow/plugins/builtin/pytorch_plugin.py +150 -0
- mlflow/plugins/builtin/sklearn_plugin.py +158 -0
- mlflow/plugins/builtin/transformers_plugin.py +187 -0
- mlflow/plugins/cli.py +321 -0
- mlflow/plugins/discovery.py +340 -0
- mlflow/plugins/manager.py +465 -0
- mlflow/plugins/registry.py +316 -0
- mlflow/plugins/templates/framework_plugin_template.py +329 -0
- mlflow/prompt/constants.py +20 -0
- mlflow/prompt/promptlab_model.py +197 -0
- mlflow/prompt/registry_utils.py +248 -0
- mlflow/promptflow/__init__.py +495 -0
- mlflow/protos/__init__.py +0 -0
- mlflow/protos/assessments_pb2.py +174 -0
- mlflow/protos/databricks_artifacts_pb2.py +489 -0
- mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
- mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
- mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
- mlflow/protos/databricks_pb2.py +267 -0
- mlflow/protos/databricks_trace_server_pb2.py +374 -0
- mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
- mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
- mlflow/protos/facet_feature_statistics_pb2.py +296 -0
- mlflow/protos/internal_pb2.py +77 -0
- mlflow/protos/mlflow_artifacts_pb2.py +336 -0
- mlflow/protos/model_registry_pb2.py +1073 -0
- mlflow/protos/scalapb/__init__.py +0 -0
- mlflow/protos/scalapb/scalapb_pb2.py +104 -0
- mlflow/protos/service_pb2.py +2600 -0
- mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
- mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
- mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
- mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
- mlflow/py.typed +0 -0
- mlflow/pydantic_ai/__init__.py +57 -0
- mlflow/pydantic_ai/autolog.py +173 -0
- mlflow/pyfunc/__init__.py +3844 -0
- mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
- mlflow/pyfunc/backend.py +523 -0
- mlflow/pyfunc/context.py +78 -0
- mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
- mlflow/pyfunc/loaders/__init__.py +7 -0
- mlflow/pyfunc/loaders/chat_agent.py +117 -0
- mlflow/pyfunc/loaders/chat_model.py +125 -0
- mlflow/pyfunc/loaders/code_model.py +31 -0
- mlflow/pyfunc/loaders/responses_agent.py +112 -0
- mlflow/pyfunc/mlserver.py +46 -0
- mlflow/pyfunc/model.py +1473 -0
- mlflow/pyfunc/scoring_server/__init__.py +604 -0
- mlflow/pyfunc/scoring_server/app.py +7 -0
- mlflow/pyfunc/scoring_server/client.py +146 -0
- mlflow/pyfunc/spark_model_cache.py +48 -0
- mlflow/pyfunc/stdin_server.py +44 -0
- mlflow/pyfunc/utils/__init__.py +3 -0
- mlflow/pyfunc/utils/data_validation.py +224 -0
- mlflow/pyfunc/utils/environment.py +22 -0
- mlflow/pyfunc/utils/input_converter.py +47 -0
- mlflow/pyfunc/utils/serving_data_parser.py +11 -0
- mlflow/pytorch/__init__.py +1171 -0
- mlflow/pytorch/_lightning_autolog.py +580 -0
- mlflow/pytorch/_pytorch_autolog.py +50 -0
- mlflow/pytorch/pickle_module.py +35 -0
- mlflow/rfunc/__init__.py +42 -0
- mlflow/rfunc/backend.py +134 -0
- mlflow/runs.py +89 -0
- mlflow/server/__init__.py +302 -0
- mlflow/server/auth/__init__.py +1224 -0
- mlflow/server/auth/__main__.py +4 -0
- mlflow/server/auth/basic_auth.ini +6 -0
- mlflow/server/auth/cli.py +11 -0
- mlflow/server/auth/client.py +537 -0
- mlflow/server/auth/config.py +34 -0
- mlflow/server/auth/db/__init__.py +0 -0
- mlflow/server/auth/db/cli.py +18 -0
- mlflow/server/auth/db/migrations/__init__.py +0 -0
- mlflow/server/auth/db/migrations/alembic.ini +110 -0
- mlflow/server/auth/db/migrations/env.py +76 -0
- mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
- mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
- mlflow/server/auth/db/models.py +67 -0
- mlflow/server/auth/db/utils.py +37 -0
- mlflow/server/auth/entities.py +165 -0
- mlflow/server/auth/logo.py +14 -0
- mlflow/server/auth/permissions.py +65 -0
- mlflow/server/auth/routes.py +18 -0
- mlflow/server/auth/sqlalchemy_store.py +263 -0
- mlflow/server/graphql/__init__.py +0 -0
- mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
- mlflow/server/graphql/graphql_custom_scalars.py +24 -0
- mlflow/server/graphql/graphql_errors.py +15 -0
- mlflow/server/graphql/graphql_no_batching.py +89 -0
- mlflow/server/graphql/graphql_schema_extensions.py +74 -0
- mlflow/server/handlers.py +3217 -0
- mlflow/server/prometheus_exporter.py +17 -0
- mlflow/server/validation.py +30 -0
- mlflow/shap/__init__.py +691 -0
- mlflow/sklearn/__init__.py +1994 -0
- mlflow/sklearn/utils.py +1041 -0
- mlflow/smolagents/__init__.py +66 -0
- mlflow/smolagents/autolog.py +139 -0
- mlflow/smolagents/chat.py +29 -0
- mlflow/store/__init__.py +10 -0
- mlflow/store/_unity_catalog/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/constants.py +2 -0
- mlflow/store/_unity_catalog/registry/__init__.py +6 -0
- mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
- mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
- mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
- mlflow/store/_unity_catalog/registry/utils.py +121 -0
- mlflow/store/artifact/__init__.py +0 -0
- mlflow/store/artifact/artifact_repo.py +472 -0
- mlflow/store/artifact/artifact_repository_registry.py +154 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
- mlflow/store/artifact/cli.py +141 -0
- mlflow/store/artifact/cloud_artifact_repo.py +332 -0
- mlflow/store/artifact/databricks_artifact_repo.py +729 -0
- mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
- mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
- mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
- mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
- mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
- mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
- mlflow/store/artifact/ftp_artifact_repo.py +132 -0
- mlflow/store/artifact/gcs_artifact_repo.py +296 -0
- mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
- mlflow/store/artifact/http_artifact_repo.py +218 -0
- mlflow/store/artifact/local_artifact_repo.py +142 -0
- mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
- mlflow/store/artifact/models_artifact_repo.py +259 -0
- mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
- mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
- mlflow/store/artifact/r2_artifact_repo.py +70 -0
- mlflow/store/artifact/runs_artifact_repo.py +265 -0
- mlflow/store/artifact/s3_artifact_repo.py +330 -0
- mlflow/store/artifact/sftp_artifact_repo.py +141 -0
- mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
- mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
- mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
- mlflow/store/artifact/utils/__init__.py +0 -0
- mlflow/store/artifact/utils/models.py +148 -0
- mlflow/store/db/__init__.py +0 -0
- mlflow/store/db/base_sql_model.py +3 -0
- mlflow/store/db/db_types.py +10 -0
- mlflow/store/db/utils.py +314 -0
- mlflow/store/db_migrations/__init__.py +0 -0
- mlflow/store/db_migrations/alembic.ini +74 -0
- mlflow/store/db_migrations/env.py +84 -0
- mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
- mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
- mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
- mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
- mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
- mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
- mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
- mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
- mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
- mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
- mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
- mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
- mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
- mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
- mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
- mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
- mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
- mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
- mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
- mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
- mlflow/store/db_migrations/versions/__init__.py +0 -0
- mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
- mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
- mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
- mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
- mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
- mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
- mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
- mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
- mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
- mlflow/store/entities/__init__.py +3 -0
- mlflow/store/entities/paged_list.py +18 -0
- mlflow/store/model_registry/__init__.py +10 -0
- mlflow/store/model_registry/abstract_store.py +1081 -0
- mlflow/store/model_registry/base_rest_store.py +44 -0
- mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
- mlflow/store/model_registry/dbmodels/__init__.py +0 -0
- mlflow/store/model_registry/dbmodels/models.py +206 -0
- mlflow/store/model_registry/file_store.py +1091 -0
- mlflow/store/model_registry/rest_store.py +481 -0
- mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
- mlflow/store/tracking/__init__.py +23 -0
- mlflow/store/tracking/abstract_store.py +816 -0
- mlflow/store/tracking/dbmodels/__init__.py +0 -0
- mlflow/store/tracking/dbmodels/initial_models.py +243 -0
- mlflow/store/tracking/dbmodels/models.py +1073 -0
- mlflow/store/tracking/file_store.py +2438 -0
- mlflow/store/tracking/postgres_managed_identity.py +146 -0
- mlflow/store/tracking/rest_store.py +1131 -0
- mlflow/store/tracking/sqlalchemy_store.py +2785 -0
- mlflow/system_metrics/__init__.py +61 -0
- mlflow/system_metrics/metrics/__init__.py +0 -0
- mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
- mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
- mlflow/system_metrics/metrics/disk_monitor.py +21 -0
- mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
- mlflow/system_metrics/metrics/network_monitor.py +34 -0
- mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
- mlflow/system_metrics/system_metrics_monitor.py +198 -0
- mlflow/tracing/__init__.py +16 -0
- mlflow/tracing/assessment.py +356 -0
- mlflow/tracing/client.py +531 -0
- mlflow/tracing/config.py +125 -0
- mlflow/tracing/constant.py +105 -0
- mlflow/tracing/destination.py +81 -0
- mlflow/tracing/display/__init__.py +40 -0
- mlflow/tracing/display/display_handler.py +196 -0
- mlflow/tracing/export/async_export_queue.py +186 -0
- mlflow/tracing/export/inference_table.py +138 -0
- mlflow/tracing/export/mlflow_v3.py +137 -0
- mlflow/tracing/export/utils.py +70 -0
- mlflow/tracing/fluent.py +1417 -0
- mlflow/tracing/processor/base_mlflow.py +199 -0
- mlflow/tracing/processor/inference_table.py +175 -0
- mlflow/tracing/processor/mlflow_v3.py +47 -0
- mlflow/tracing/processor/otel.py +73 -0
- mlflow/tracing/provider.py +487 -0
- mlflow/tracing/trace_manager.py +200 -0
- mlflow/tracing/utils/__init__.py +616 -0
- mlflow/tracing/utils/artifact_utils.py +28 -0
- mlflow/tracing/utils/copy.py +55 -0
- mlflow/tracing/utils/environment.py +55 -0
- mlflow/tracing/utils/exception.py +21 -0
- mlflow/tracing/utils/once.py +35 -0
- mlflow/tracing/utils/otlp.py +63 -0
- mlflow/tracing/utils/processor.py +54 -0
- mlflow/tracing/utils/search.py +292 -0
- mlflow/tracing/utils/timeout.py +250 -0
- mlflow/tracing/utils/token.py +19 -0
- mlflow/tracing/utils/truncation.py +124 -0
- mlflow/tracing/utils/warning.py +76 -0
- mlflow/tracking/__init__.py +39 -0
- mlflow/tracking/_model_registry/__init__.py +1 -0
- mlflow/tracking/_model_registry/client.py +764 -0
- mlflow/tracking/_model_registry/fluent.py +853 -0
- mlflow/tracking/_model_registry/registry.py +67 -0
- mlflow/tracking/_model_registry/utils.py +251 -0
- mlflow/tracking/_tracking_service/__init__.py +0 -0
- mlflow/tracking/_tracking_service/client.py +883 -0
- mlflow/tracking/_tracking_service/registry.py +56 -0
- mlflow/tracking/_tracking_service/utils.py +275 -0
- mlflow/tracking/artifact_utils.py +179 -0
- mlflow/tracking/client.py +5900 -0
- mlflow/tracking/context/__init__.py +0 -0
- mlflow/tracking/context/abstract_context.py +35 -0
- mlflow/tracking/context/databricks_cluster_context.py +15 -0
- mlflow/tracking/context/databricks_command_context.py +15 -0
- mlflow/tracking/context/databricks_job_context.py +49 -0
- mlflow/tracking/context/databricks_notebook_context.py +41 -0
- mlflow/tracking/context/databricks_repo_context.py +43 -0
- mlflow/tracking/context/default_context.py +51 -0
- mlflow/tracking/context/git_context.py +32 -0
- mlflow/tracking/context/registry.py +98 -0
- mlflow/tracking/context/system_environment_context.py +15 -0
- mlflow/tracking/default_experiment/__init__.py +1 -0
- mlflow/tracking/default_experiment/abstract_context.py +43 -0
- mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
- mlflow/tracking/default_experiment/registry.py +75 -0
- mlflow/tracking/fluent.py +3595 -0
- mlflow/tracking/metric_value_conversion_utils.py +93 -0
- mlflow/tracking/multimedia.py +206 -0
- mlflow/tracking/registry.py +86 -0
- mlflow/tracking/request_auth/__init__.py +0 -0
- mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
- mlflow/tracking/request_auth/registry.py +60 -0
- mlflow/tracking/request_header/__init__.py +0 -0
- mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
- mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
- mlflow/tracking/request_header/default_request_header_provider.py +17 -0
- mlflow/tracking/request_header/registry.py +79 -0
- mlflow/transformers/__init__.py +2982 -0
- mlflow/transformers/flavor_config.py +258 -0
- mlflow/transformers/hub_utils.py +83 -0
- mlflow/transformers/llm_inference_utils.py +468 -0
- mlflow/transformers/model_io.py +301 -0
- mlflow/transformers/peft.py +51 -0
- mlflow/transformers/signature.py +183 -0
- mlflow/transformers/torch_utils.py +55 -0
- mlflow/types/__init__.py +21 -0
- mlflow/types/agent.py +270 -0
- mlflow/types/chat.py +240 -0
- mlflow/types/llm.py +935 -0
- mlflow/types/responses.py +139 -0
- mlflow/types/responses_helpers.py +416 -0
- mlflow/types/schema.py +1505 -0
- mlflow/types/type_hints.py +647 -0
- mlflow/types/utils.py +753 -0
- mlflow/utils/__init__.py +283 -0
- mlflow/utils/_capture_modules.py +256 -0
- mlflow/utils/_capture_transformers_modules.py +75 -0
- mlflow/utils/_spark_utils.py +201 -0
- mlflow/utils/_unity_catalog_oss_utils.py +97 -0
- mlflow/utils/_unity_catalog_utils.py +479 -0
- mlflow/utils/annotations.py +218 -0
- mlflow/utils/arguments_utils.py +16 -0
- mlflow/utils/async_logging/__init__.py +1 -0
- mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
- mlflow/utils/async_logging/async_logging_queue.py +366 -0
- mlflow/utils/async_logging/run_artifact.py +38 -0
- mlflow/utils/async_logging/run_batch.py +58 -0
- mlflow/utils/async_logging/run_operations.py +49 -0
- mlflow/utils/autologging_utils/__init__.py +737 -0
- mlflow/utils/autologging_utils/client.py +432 -0
- mlflow/utils/autologging_utils/config.py +33 -0
- mlflow/utils/autologging_utils/events.py +294 -0
- mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
- mlflow/utils/autologging_utils/metrics_queue.py +71 -0
- mlflow/utils/autologging_utils/safety.py +1104 -0
- mlflow/utils/autologging_utils/versioning.py +95 -0
- mlflow/utils/checkpoint_utils.py +206 -0
- mlflow/utils/class_utils.py +6 -0
- mlflow/utils/cli_args.py +257 -0
- mlflow/utils/conda.py +354 -0
- mlflow/utils/credentials.py +231 -0
- mlflow/utils/data_utils.py +17 -0
- mlflow/utils/databricks_utils.py +1436 -0
- mlflow/utils/docstring_utils.py +477 -0
- mlflow/utils/doctor.py +133 -0
- mlflow/utils/download_cloud_file_chunk.py +43 -0
- mlflow/utils/env_manager.py +16 -0
- mlflow/utils/env_pack.py +131 -0
- mlflow/utils/environment.py +1009 -0
- mlflow/utils/exception_utils.py +14 -0
- mlflow/utils/file_utils.py +978 -0
- mlflow/utils/git_utils.py +77 -0
- mlflow/utils/gorilla.py +797 -0
- mlflow/utils/import_hooks/__init__.py +363 -0
- mlflow/utils/lazy_load.py +51 -0
- mlflow/utils/logging_utils.py +168 -0
- mlflow/utils/mime_type_utils.py +58 -0
- mlflow/utils/mlflow_tags.py +103 -0
- mlflow/utils/model_utils.py +486 -0
- mlflow/utils/name_utils.py +346 -0
- mlflow/utils/nfs_on_spark.py +62 -0
- mlflow/utils/openai_utils.py +164 -0
- mlflow/utils/os.py +12 -0
- mlflow/utils/oss_registry_utils.py +29 -0
- mlflow/utils/plugins.py +17 -0
- mlflow/utils/process.py +182 -0
- mlflow/utils/promptlab_utils.py +146 -0
- mlflow/utils/proto_json_utils.py +743 -0
- mlflow/utils/pydantic_utils.py +54 -0
- mlflow/utils/request_utils.py +279 -0
- mlflow/utils/requirements_utils.py +704 -0
- mlflow/utils/rest_utils.py +673 -0
- mlflow/utils/search_logged_model_utils.py +127 -0
- mlflow/utils/search_utils.py +2111 -0
- mlflow/utils/secure_loading.py +221 -0
- mlflow/utils/security_validation.py +384 -0
- mlflow/utils/server_cli_utils.py +61 -0
- mlflow/utils/spark_utils.py +15 -0
- mlflow/utils/string_utils.py +138 -0
- mlflow/utils/thread_utils.py +63 -0
- mlflow/utils/time.py +54 -0
- mlflow/utils/timeout.py +42 -0
- mlflow/utils/uri.py +572 -0
- mlflow/utils/validation.py +662 -0
- mlflow/utils/virtualenv.py +458 -0
- mlflow/utils/warnings_utils.py +25 -0
- mlflow/utils/yaml_utils.py +179 -0
- mlflow/version.py +24 -0
mlflow/sklearn/utils.py
ADDED
@@ -0,0 +1,1041 @@
|
|
1
|
+
import collections
|
2
|
+
import inspect
|
3
|
+
import logging
|
4
|
+
import pkgutil
|
5
|
+
import platform
|
6
|
+
import warnings
|
7
|
+
from copy import deepcopy
|
8
|
+
from importlib import import_module
|
9
|
+
from numbers import Number
|
10
|
+
from operator import itemgetter
|
11
|
+
|
12
|
+
import numpy as np
|
13
|
+
from packaging.version import Version
|
14
|
+
|
15
|
+
from mlflow import MlflowClient
|
16
|
+
from mlflow.entities.dataset_input import DatasetInput
|
17
|
+
from mlflow.entities.input_tag import InputTag
|
18
|
+
from mlflow.tracking.fluent import MLFLOW_DATASET_CONTEXT
|
19
|
+
from mlflow.utils.arguments_utils import _get_arg_names
|
20
|
+
from mlflow.utils.file_utils import TempDir
|
21
|
+
from mlflow.utils.mlflow_tags import MLFLOW_PARENT_RUN_ID
|
22
|
+
from mlflow.utils.time import get_current_time_millis
|
23
|
+
|
24
|
+
_logger = logging.getLogger(__name__)

# Prefix marking metrics and artifacts that were computed solely from the training dataset.
_TRAINING_PREFIX = "training_"

# Name of the optional per-sample weighting argument looked up in estimator method signatures.
_SAMPLE_WEIGHT = "sample_weight"

# _SklearnArtifact describes an artifact (e.g. a confusion matrix) that is computed and logged
# during the autologging routine for a particular model type (e.g. classifier, regressor).
_SklearnArtifact = collections.namedtuple(
    "_SklearnArtifact",
    ("name", "function", "arguments", "title"),
)

# _SklearnMetric describes a metric (e.g. precision_score) that is computed and logged
# during the autologging routine for a particular model type (e.g. classifier, regressor).
_SklearnMetric = collections.namedtuple(
    "_SklearnMetric",
    ("name", "function", "arguments"),
)
|
40
|
+
|
41
|
+
|
42
|
+
def _gen_xgboost_sklearn_estimators_to_patch():
    """
    Collect the scikit-learn style estimator classes exposed by `xgboost.sklearn`.

    Returns:
        A list of every class found in `xgboost.sklearn` that is a strict subclass of
        `xgboost.sklearn.XGBModel` (the base class itself is excluded).
    """
    import xgboost as xgb

    discovered = inspect.getmembers(xgb.sklearn, inspect.isclass)
    root_class = xgb.sklearn.XGBModel
    return [
        candidate
        for _, candidate in discovered
        if issubclass(candidate, root_class) and candidate != root_class
    ]
|
53
|
+
|
54
|
+
|
55
|
+
def _gen_lightgbm_sklearn_estimators_to_patch():
    """
    Collect the scikit-learn style estimator classes exposed by `lightgbm.sklearn`.

    Returns:
        A list of every class found in `lightgbm.sklearn` that (a) is defined in a module
        whose top-level package name equals `mlflow.lightgbm.FLAVOR_NAME` and (b) is a strict
        subclass of `lightgbm.sklearn._LGBMModelBase` (the base class itself is excluded).
    """
    import lightgbm as lgb

    import mlflow.lightgbm

    discovered = inspect.getmembers(lgb.sklearn, inspect.isclass)
    root_class = lgb.sklearn._LGBMModelBase
    # The package check filters out classes that are merely re-exported through
    # `lightgbm.sklearn` but are defined in a different top-level package.
    return [
        candidate
        for _, candidate in discovered
        if candidate.__module__.split(".")[0] == mlflow.lightgbm.FLAVOR_NAME
        and issubclass(candidate, root_class)
        and candidate != root_class
    ]
|
73
|
+
|
74
|
+
|
75
|
+
def _get_estimator_info_tags(estimator):
    """
    Build the tags that identify the class of `estimator`.

    Args:
        estimator: Any object; only its `__class__` is inspected.

    Returns:
        A dictionary with two entries: `estimator_name` (the bare class name) and
        `estimator_class` (the class name qualified with its module, joined by a dot).
    """
    estimator_type = estimator.__class__
    short_name = estimator_type.__name__
    return {
        "estimator_name": short_name,
        "estimator_class": ".".join((estimator_type.__module__, short_name)),
    }
|
84
|
+
|
85
|
+
|
86
|
+
def _get_X_y_and_sample_weight(fit_func, fit_args, fit_kwargs):
    """
    Recover the training data, labels and sample weights that were passed to a fit call.

    The names of the first two arguments of `fit_func` are used as the keyword names for
    the data and the labels. In most cases these are "X" and "y", but certain sklearn models
    use different names, e.g.
    https://scikit-learn.org/stable/modules/generated/sklearn.multioutput.MultiOutputClassifier.html#sklearn.multioutput.MultiOutputClassifier.fit

    Args:
        fit_func: A fit function object.
        fit_args: Positional arguments given to fit_func.
        fit_kwargs: Keyword arguments given to fit_func.

    Returns:
        A tuple `(X, y, sample_weight)`. `y` is `None` when no labels were supplied, and
        `sample_weight` is `None` when `fit_func` has no `sample_weight` argument or none
        was supplied. `X` and `y` are deep copies, made in order to avoid mutation of the
        dataset during training; `sample_weight` is returned as passed.
    """
    fit_arg_names = _get_arg_names(fit_func)
    X_var_name, y_var_name = fit_arg_names[:2]

    positional_count = len(fit_args)
    if positional_count >= 2:
        # model.fit(X, y, ...)
        X, y = fit_args[:2]
    elif positional_count == 1:
        # model.fit(X, <y_var_name>=y)
        X, y = fit_args[0], fit_kwargs.get(y_var_name)
    else:
        # model.fit(<X_var_name>=X, <y_var_name>=y)
        X, y = fit_kwargs[X_var_name], fit_kwargs.get(y_var_name)

    if X is not None:
        X = deepcopy(X)
    if y is not None:
        y = deepcopy(y)

    sample_weight = None
    if _SAMPLE_WEIGHT in fit_arg_names:
        weight_position = fit_arg_names.index(_SAMPLE_WEIGHT)
        if len(fit_args) > weight_position:
            # model.fit(X, y, ..., sample_weight)
            sample_weight = fit_args[weight_position]
        elif _SAMPLE_WEIGHT in fit_kwargs:
            # model.fit(X, y, ..., sample_weight=sample_weight)
            sample_weight = fit_kwargs[_SAMPLE_WEIGHT]

    return (X, y, sample_weight)
|
149
|
+
|
150
|
+
|
151
|
+
def _get_metrics_value_dict(metrics_list):
    """
    Evaluate each metric description and collect the results by metric name.

    Args:
        metrics_list: An iterable of objects exposing `name`, `function` and `arguments`
            (e.g. `_SklearnMetric`); `function` is called with `arguments` as keywords.

    Returns:
        A dict mapping metric name to computed value. A metric whose function raises is
        reported through `_log_warning_for_metrics` and left out of the result.
    """
    computed = {}
    for entry in metrics_list:
        try:
            value = entry.function(**entry.arguments)
        except Exception as exc:
            # Best effort: one failing metric must not prevent the others from being computed.
            _log_warning_for_metrics(entry.name, entry.function, exc)
            continue
        computed[entry.name] = value
    return computed
|
161
|
+
|
162
|
+
|
163
|
+
def _get_classifier_metrics(fitted_estimator, prefix, X, y_true, sample_weight, pos_label):  # noqa: D417
    """
    Compute various common metrics for classifiers.

    For (1) precision score:
    https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html
    (2) recall score:
    https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html
    (3) f1_score:
    https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html
    `average` is set to `weighted` when `pos_label` is `None`, and to `binary` otherwise.

    For (4) accuracy score:
    https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html
    `normalize` is set to `True` so that the fraction (rather than the absolute count) of
    correct predictions is reported.

    If the classifier has a `predict_proba` method, these are computed as well:
    (5) log loss:
    https://scikit-learn.org/stable/modules/generated/sklearn.metrics.log_loss.html
    (6) roc_auc_score (only when `_is_metric_supported("roc_auc_score")` is truthy):
    https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html
    For roc_auc_score, `average` is `weighted` and `multi_class` is `ovo`, to make the
    output more insensitive to dataset imbalance.

    Args:
        fitted_estimator: The already fitted classifier.
        prefix: String prepended to every metric name.
        X: The data samples passed to `predict` / `predict_proba`.
        y_true: The true labels.
        sample_weight: Per-sample weights forwarded to every metric function (may be `None`).
        pos_label: The positive label for precision / recall / f1, or `None`.

    Returns:
        Dictionary of (metric name, computed value). Metrics whose computation fails are
        omitted (see `_get_metrics_value_dict`).
    """
    import sklearn

    if pos_label is None:
        averaging = "weighted"
    else:
        averaging = "binary"
    predictions = fitted_estimator.predict(X)

    # Precision, recall and f1 all take exactly the same keyword arguments.
    label_based_scorers = (
        ("precision_score", sklearn.metrics.precision_score),
        ("recall_score", sklearn.metrics.recall_score),
        ("f1_score", sklearn.metrics.f1_score),
    )
    classifier_metrics = [
        _SklearnMetric(
            name=prefix + metric_name,
            function=scorer,
            arguments={
                "y_true": y_true,
                "y_pred": predictions,
                "pos_label": pos_label,
                "average": averaging,
                "sample_weight": sample_weight,
            },
        )
        for metric_name, scorer in label_based_scorers
    ]
    classifier_metrics.append(
        _SklearnMetric(
            name=prefix + "accuracy_score",
            function=sklearn.metrics.accuracy_score,
            arguments={
                "y_true": y_true,
                "y_pred": predictions,
                "normalize": True,
                "sample_weight": sample_weight,
            },
        )
    )

    # Probability-based metrics are only available for estimators exposing `predict_proba`.
    if not hasattr(fitted_estimator, "predict_proba"):
        return _get_metrics_value_dict(classifier_metrics)

    probabilities = fitted_estimator.predict_proba(X)
    classifier_metrics.append(
        _SklearnMetric(
            name=prefix + "log_loss",
            function=sklearn.metrics.log_loss,
            arguments={
                "y_true": y_true,
                "y_pred": probabilities,
                "sample_weight": sample_weight,
            },
        )
    )

    if _is_metric_supported("roc_auc_score"):
        # For the binary case, `y_score` must hold the scores of the class with the
        # greater label, i.e. the second probability column.
        if len(probabilities[0]) == 2:
            roc_scores = probabilities[:, 1]
        else:
            roc_scores = probabilities

        classifier_metrics.append(
            _SklearnMetric(
                name=prefix + "roc_auc",
                function=sklearn.metrics.roc_auc_score,
                arguments={
                    "y_true": y_true,
                    "y_score": roc_scores,
                    "average": "weighted",
                    "sample_weight": sample_weight,
                    "multi_class": "ovo",
                },
            )
        )

    return _get_metrics_value_dict(classifier_metrics)
|
295
|
+
|
296
|
+
|
297
|
+
def _get_class_labels_from_estimator(estimator):
    """
    Return `estimator.classes_` when the estimator has that attribute, otherwise `None`.
    """
    return getattr(estimator, "classes_", None)
|
302
|
+
|
303
|
+
|
304
|
+
def _get_classifier_artifacts(fitted_estimator, prefix, X, y_true, sample_weight):  # noqa: D417
    """
    Describe the plots to be drawn for a classifier.

    For all classifiers, a normalized confusion matrix is described:
    https://scikit-learn.org/stable/modules/generated/sklearn.metrics.ConfusionMatrixDisplay.html

    When `y_true` contains exactly two distinct labels, these are described as well:
    (1) ROC curve:
    https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
    (2) precision recall curve:
    https://scikit-learn.org/stable/modules/generated/sklearn.metrics.PrecisionRecallDisplay.html

    With scikit-learn >= 1.0 the `*Display.from_estimator` constructors are used; with older
    versions the `sklearn.metrics.plot_*` functions are used instead.

    Args:
        fitted_estimator: The already fitted classifier.
        prefix: String prepended to every artifact name.
        X: The data samples.
        y_true: The true labels.
        sample_weight: Per-sample weights forwarded to the plotting functions (may be `None`).

    Returns:
        A list of `_SklearnArtifact`; nothing is plotted here, the functions are invoked by
        the caller. The list is empty when `_is_plotting_supported()` is falsy.
    """
    import sklearn

    if not _is_plotting_supported():
        return []

    uses_display_api = Version(sklearn.__version__) >= Version("1.0")

    # NOTE: the name of this nested function ends up in warning messages (through its
    # `__qualname__`), so it is deliberately kept stable.
    def plot_confusion_matrix(*args, **kwargs):
        import matplotlib
        import matplotlib.pyplot as plt

        class_labels = _get_class_labels_from_estimator(fitted_estimator)
        if class_labels is None:
            class_labels = set(y_true)

        # Shrink the font as the number of classes grows so that cell text stays legible.
        rc_overrides = {
            "font.size": min(8.0, 50.0 / len(class_labels)),
            "axes.labelsize": 8.0,
            "figure.dpi": 175,
        }
        with matplotlib.rc_context(rc_overrides):
            _, ax = plt.subplots(1, 1, figsize=(6.0, 4.0))
            if uses_display_api:
                return sklearn.metrics.ConfusionMatrixDisplay.from_estimator(
                    *args, **kwargs, ax=ax
                )
            return sklearn.metrics.plot_confusion_matrix(*args, **kwargs, ax=ax)

    # The label argument is called `y` in the display API and `y_true` in the legacy function.
    confusion_matrix_arguments = {
        "estimator": fitted_estimator,
        "X": X,
        "sample_weight": sample_weight,
        "normalize": "true",
        "cmap": "Blues",
    }
    confusion_matrix_arguments["y" if uses_display_api else "y_true"] = y_true

    classifier_artifacts = [
        _SklearnArtifact(
            name=prefix + "confusion_matrix",
            function=plot_confusion_matrix,
            arguments=confusion_matrix_arguments,
            title="Normalized confusion matrix",
        ),
    ]

    # The ROC and precision-recall plots can only be supported for binary classifiers.
    if len(set(y_true)) != 2:
        return classifier_artifacts

    if uses_display_api:
        roc_plotter = sklearn.metrics.RocCurveDisplay.from_estimator
        precision_recall_plotter = sklearn.metrics.PrecisionRecallDisplay.from_estimator
    else:
        roc_plotter = sklearn.metrics.plot_roc_curve
        precision_recall_plotter = sklearn.metrics.plot_precision_recall_curve

    binary_plots = (
        ("roc_curve", roc_plotter, "ROC curve"),
        ("precision_recall_curve", precision_recall_plotter, "Precision recall curve"),
    )
    for artifact_name, plotter, plot_title in binary_plots:
        classifier_artifacts.append(
            _SklearnArtifact(
                name=prefix + artifact_name,
                function=plotter,
                arguments={
                    "estimator": fitted_estimator,
                    "X": X,
                    "y": y_true,
                    "sample_weight": sample_weight,
                },
                title=plot_title,
            )
        )

    return classifier_artifacts
|
414
|
+
|
415
|
+
|
416
|
+
def _get_regressor_metrics(fitted_estimator, prefix, X, y_true, sample_weight):  # noqa: D417
    """
    Compute various common metrics for regressors.

    For (1) (root) mean squared error:
    https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html
    (2) mean absolute error:
    https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_error.html
    (3) r2 score:
    https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html
    The parameter `multioutput` is set to `uniform_average` to average outputs with
    uniform weight.

    Args:
        fitted_estimator: The already fitted regressor.
        prefix: String prepended to every metric name.
        X: The data samples passed to `predict`.
        y_true: The true target values.
        sample_weight: Per-sample weights forwarded to every metric function (may be `None`).

    Returns:
        Dictionary of (metric name, computed value). Metrics whose computation fails are
        omitted (see `_get_metrics_value_dict`); the root mean squared error is only present
        when the mean squared error could be computed.
    """
    import sklearn

    y_pred = fitted_estimator.predict(X)

    # All three metrics take exactly the same keyword arguments.
    scorers = (
        ("mean_squared_error", sklearn.metrics.mean_squared_error),
        ("mean_absolute_error", sklearn.metrics.mean_absolute_error),
        ("r2_score", sklearn.metrics.r2_score),
    )
    regressor_metrics = [
        _SklearnMetric(
            name=prefix + metric_name,
            function=scorer,
            arguments={
                "y_true": y_true,
                "y_pred": y_pred,
                "sample_weight": sample_weight,
                "multioutput": "uniform_average",
            },
        )
        for metric_name, scorer in scorers
    ]

    # To be compatible with older versions of scikit-learn (below 0.22.2), where
    # `sklearn.metrics.mean_squared_error` does not have "squared" parameter to calculate `rmse`,
    # we compute it through np.sqrt(<value of mse>)
    metrics_value_dict = _get_metrics_value_dict(regressor_metrics)
    mse_key = prefix + "mean_squared_error"
    # The MSE entry is missing when its computation failed (the failure has already been
    # logged). Skip the RMSE in that case rather than raising KeyError, which would
    # discard the metrics that were computed successfully.
    if mse_key in metrics_value_dict:
        metrics_value_dict[prefix + "root_mean_squared_error"] = np.sqrt(
            metrics_value_dict[mse_key]
        )

    return metrics_value_dict
|
490
|
+
|
491
|
+
|
492
|
+
def _log_warning_for_metrics(func_name, func_call, err):
    """
    Emit a warning that a metric could not be computed and will not be recorded.

    Args:
        func_name: Name of the metric (a string).
        func_call: The callable that failed; its `__qualname__` is included in the message.
        err: The error that was raised; its string form is included in the message.
    """
    pieces = (
        func_call.__qualname__,
        " failed. The metric ",
        func_name,
        " will not be recorded.",
        " Metric error: ",
        str(err),
    )
    _logger.warning("".join(pieces))
|
502
|
+
|
503
|
+
|
504
|
+
def _log_warning_for_artifacts(func_name, func_call, err):
    """
    Emit a warning that an artifact could not be produced and will not be recorded.

    Args:
        func_name: Name of the artifact (a string).
        func_call: The callable that failed; its `__qualname__` is included in the message.
        err: The error that was raised; its string form is included in the message.
    """
    pieces = (
        func_call.__qualname__,
        " failed. The artifact ",
        func_name,
        " will not be recorded.",
        " Artifact error: ",
        str(err),
    )
    _logger.warning("".join(pieces))
|
514
|
+
|
515
|
+
|
516
|
+
def _log_specialized_estimator_content(
    autologging_client,
    fitted_estimator,
    run_id,
    prefix,
    X,
    y_true,
    sample_weight,
    pos_label,
    model_id,
    dataset,
):
    """
    Log the metrics and plots that are specific to the estimator's type.

    Metrics are computed only when `y_true` is not `None`: classifier metrics for
    classifiers, regressor metrics for regressors. For classifiers, the plots described by
    `_get_classifier_artifacts` are additionally rendered to PNG files and logged as run
    artifacts. Every step is best effort: failures are logged as warnings and do not
    propagate.

    Args:
        autologging_client: Client whose `log_metrics` method receives the computed metrics.
        fitted_estimator: The already fitted estimator.
        run_id: The run under which the content is logged.
        prefix: String prepended to every metric and artifact name.
        X: The data samples.
        y_true: The labels, or `None`.
        sample_weight: Per-sample weights (may be `None`).
        pos_label: The positive label for binary classification metrics, or `None`.
        model_id: Forwarded to `autologging_client.log_metrics`.
        dataset: Forwarded to `autologging_client.log_metrics`.

    Returns:
        A dict of the computed metrics (empty when none were computed).
    """
    import sklearn

    metrics = {}

    if y_true is not None:
        try:
            if sklearn.base.is_classifier(fitted_estimator):
                metrics = _get_classifier_metrics(
                    fitted_estimator, prefix, X, y_true, sample_weight, pos_label
                )
            elif sklearn.base.is_regressor(fitted_estimator):
                metrics = _get_regressor_metrics(fitted_estimator, prefix, X, y_true, sample_weight)
        except Exception as err:
            msg = (
                "Failed to autolog metrics for "
                + fitted_estimator.__class__.__name__
                + ". Logging error: "
                + str(err)
            )
            _logger.warning(msg)
        else:
            autologging_client.log_metrics(
                run_id=run_id,
                metrics=metrics,
                model_id=model_id,
                dataset=dataset,
            )

    if sklearn.base.is_classifier(fitted_estimator):
        try:
            artifacts = _get_classifier_artifacts(
                fitted_estimator, prefix, X, y_true, sample_weight
            )
        except Exception as e:
            msg = (
                "Failed to autolog artifacts for "
                + fitted_estimator.__class__.__name__
                + ". Logging error: "
                + str(e)
            )
            _logger.warning(msg)
            return metrics

        try:
            import matplotlib
            import matplotlib.pyplot as plt
        except ImportError as ie:
            _logger.warning(
                f"Failed to import matplotlib (error: {ie!r}). Skipping artifact logging."
            )
            return metrics

        _matplotlib_config = {"savefig.dpi": 175, "figure.autolayout": True, "font.size": 8}
        with TempDir() as tmp_dir:
            for artifact in artifacts:
                try:
                    with matplotlib.rc_context(_matplotlib_config):
                        display = artifact.function(**artifact.arguments)
                        try:
                            display.ax_.set_title(artifact.title)
                            artifact_path = f"{artifact.name}.png"
                            filepath = tmp_dir.path(artifact_path)
                            display.figure_.savefig(fname=filepath, format="png")
                        finally:
                            # Always release the figure, even when titling or saving fails;
                            # otherwise pyplot keeps it open for the rest of the process.
                            plt.close(display.figure_)
                except Exception as e:
                    # Best effort: one failing plot must not prevent the others.
                    _log_warning_for_artifacts(artifact.name, artifact.function, e)

            # Upload whatever was rendered before the temporary directory is removed.
            MlflowClient().log_artifacts(run_id, tmp_dir.path())

    return metrics
|
597
|
+
|
598
|
+
|
599
|
+
def _is_estimator_html_repr_supported():
    """
    Returns:
        `True` if the installed scikit-learn is at least version 0.23.0, the first release
        that supports `estimator_html_repr`. `False` otherwise.
    """
    import sklearn

    installed_version = Version(sklearn.__version__)
    minimum_version = Version("0.23.0")
    return installed_version >= minimum_version
|
604
|
+
|
605
|
+
|
606
|
+
def _log_estimator_html(run_id, estimator):
    """
    Log an HTML rendering of `estimator` as the `estimator.html` artifact of a run.

    Does nothing when `_is_estimator_html_repr_supported()` is falsy.

    Args:
        run_id: The ID of the run that receives the artifact.
        estimator: The estimator passed to `sklearn.utils.estimator_html_repr`.
    """
    if not _is_estimator_html_repr_supported():
        return

    # Imported lazily: only reached once the version check above has passed.
    from sklearn.utils import estimator_html_repr

    # Specifies charset so triangle toggle buttons are not garbled
    estimator_html_string = f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8"/>
</head>
<body>
{estimator_html_repr(estimator)}
</body>
</html>
"""
    MlflowClient().log_text(run_id, estimator_html_string, artifact_file="estimator.html")
|
625
|
+
|
626
|
+
|
627
|
+
def _log_estimator_content(
    autologging_client,
    estimator,
    run_id,
    prefix,
    X,
    y_true=None,
    sample_weight=None,
    pos_label=None,
    model_id=None,
    dataset=None,
):
    """
    Logs content for the given estimator, which includes metrics and artifacts that might be
    tailored to the estimator's type (e.g., regression vs classification). Training labels
    are required for metric computation; metrics will be omitted if labels are not available.

    In addition to the type-specific content, the result of `estimator.score` is logged as
    the `<prefix>score` metric (when the estimator has a `score` method and labels are
    available), and an HTML rendering of the estimator is logged through
    `_log_estimator_html`.

    Args:
        autologging_client: An instance of `MlflowAutologgingQueueingClient` used for
            efficiently logging run data to MLflow Tracking.
        estimator: The estimator used to compute metrics and artifacts.
        run_id: The run under which the content is logged.
        prefix: A prefix used to name the logged content. Typically it's 'training_' for
            training-time content and user-controlled for evaluation-time content.
        X: The data samples.
        y_true: Labels.
        sample_weight: Per-sample weights used in the computation of metrics and artifacts.
        pos_label: The positive label used to compute binary classification metrics such as
            precision, recall, f1, etc. This parameter is only used for classification metrics.
            If set to `None`, the function will calculate metrics for each label and find their
            average weighted by support (number of true instances for each label).
        model_id: Model ID.
        dataset: The dataset used to evaluate the model.

    Returns:
        A dict of the computed metrics.
    """
    metrics = _log_specialized_estimator_content(
        autologging_client=autologging_client,
        fitted_estimator=estimator,
        run_id=run_id,
        prefix=prefix,
        X=X,
        y_true=y_true,
        sample_weight=sample_weight,
        pos_label=pos_label,
        model_id=model_id,
        dataset=dataset,
    )

    if hasattr(estimator, "score") and y_true is not None:
        try:
            # Use the sample weight only if it is present in the score args
            score_arg_names = _get_arg_names(estimator.score)
            score_args = (
                (X, y_true, sample_weight) if _SAMPLE_WEIGHT in score_arg_names else (X, y_true)
            )
            score = estimator.score(*score_args)
        except Exception as e:
            # Name the metric by its actual key: the prefix is not always 'training_'
            # (it is user-controlled for evaluation-time content).
            msg = (
                f"{estimator.score.__qualname__} failed. The '{prefix}score' metric will not "
                f"be recorded. Scoring error: {e}"
            )
            _logger.warning(msg)
        else:
            score_key = prefix + "score"
            autologging_client.log_metrics(
                run_id=run_id,
                metrics={score_key: score},
                model_id=model_id,
                dataset=dataset,
            )
            metrics[score_key] = score
    _log_estimator_html(run_id, estimator)
    return metrics
|
703
|
+
|
704
|
+
|
705
|
+
def _get_meta_estimators_for_autologging():
    """
    Returns:
        A list of the meta estimator class definitions that should be included when
        patching training functions for autologging:
        `sklearn.model_selection.GridSearchCV`,
        `sklearn.model_selection.RandomizedSearchCV` and `sklearn.pipeline.Pipeline`.
    """
    from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
    from sklearn.pipeline import Pipeline

    meta_estimators = [GridSearchCV, RandomizedSearchCV]
    meta_estimators.append(Pipeline)
    return meta_estimators
|
720
|
+
|
721
|
+
|
722
|
+
def _is_parameter_search_estimator(estimator):
    """
    Returns:
        `True` if the specified scikit-learn estimator is an instance of a parameter search
        estimator (`GridSearchCV` or `RandomizedSearchCV`). `False` otherwise.
    """
    from sklearn.model_selection import GridSearchCV, RandomizedSearchCV

    search_estimator_types = (GridSearchCV, RandomizedSearchCV)
    return isinstance(estimator, search_estimator_types)
|
739
|
+
|
740
|
+
|
741
|
+
def _log_parameter_search_results_as_artifact(cv_results_df, run_id):
    """
    Records a collection of parameter search results as an MLflow artifact
    for the specified run.

    The results are written, without the index, to a file named `cv_results.csv` in a
    temporary directory, and that file is logged to the run.

    Args:
        cv_results_df: A Pandas DataFrame containing the results of a parameter search
            training session, which may be obtained by parsing the `cv_results_`
            attribute of a trained parameter search estimator such as
            `GridSearchCV`.
        run_id: The ID of the MLflow Run to which the artifact should be recorded.
    """
    with TempDir() as scratch_dir:
        csv_path = scratch_dir.path("cv_results.csv")
        cv_results_df.to_csv(csv_path, index=False)
        # Must happen inside the `with` block, while the temporary file still exists.
        client = MlflowClient()
        client.log_artifact(run_id, csv_path)
|
757
|
+
|
758
|
+
|
759
|
+
# Log how many child runs will be created vs omitted based on `max_tuning_runs`.
|
760
|
+
def _log_child_runs_info(max_tuning_runs, total_runs):
    """
    Log, at INFO level, how many child runs will be created and how many will be omitted.

    Args:
        max_tuning_runs: The maximum number of child runs to create.
        total_runs: The total number of candidate runs.
    """
    omitted_count = total_runs - max_tuning_runs

    # Phrase describing the runs that will be logged.
    logging_phrase = (
        "no runs"
        if max_tuning_runs == 0
        else "the best run"
        if max_tuning_runs == 1
        else f"the {max_tuning_runs} best runs"
    )

    # Phrase describing the runs that will be omitted.
    omitting_phrase = (
        "no runs"
        if omitted_count <= 0
        else "one run"
        if omitted_count == 1
        else f"{omitted_count} runs"
    )

    _logger.info("Logging %s, %s will be omitted.", logging_phrase, omitting_phrase)
|
780
|
+
|
781
|
+
|
782
|
+
def _create_child_runs_for_parameter_search(
    autologging_client,
    cv_estimator,
    parent_run,
    max_tuning_runs,
    child_tags=None,
    dataset=None,
    best_estimator_params=None,
    best_estimator_model_id=None,
):
    """
    Creates a collection of child runs for a parameter search training session.
    Runs are reconstructed from the `cv_results_` attribute of the specified trained
    parameter search estimator - `cv_estimator`, which provides relevant performance
    metrics for each point in the parameter search space. One child run is created
    for each point in the parameter search space (or only for the best-ranked
    `max_tuning_runs` points when a limit is given). For additional information, see
    `https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html`_.

    Args:
        autologging_client: An instance of `MlflowAutologgingQueueingClient` used for
            efficiently logging run data to MLflow Tracking.
        cv_estimator: The trained parameter search estimator for which to create
            child runs.
        parent_run: A py:class:`mlflow.entities.Run` object referring to the parent
            parameter search run for which child runs should be created.
        max_tuning_runs: The maximum number of child runs to create. The rows of
            `cv_results_` with the smallest rank (`rank_test_score`, or the first
            `rank_test_*` column when that one is absent) are kept. If `None`, a
            child run is created for every row.
        child_tags: An optional dictionary of MLflow tag keys and values to log
            for each child run.
        dataset: The dataset used to evaluate the model. If `None`, no dataset
            input is logged for the child runs.
        best_estimator_params: The parameters of the best estimator.
        best_estimator_model_id: The model ID of the logged best estimator.
    """
    import pandas as pd

    def first_custom_rank_column(df):
        # Yields None when no column name contains "rank_test_".
        return next((col for col in df.columns.values if "rank_test_" in col), None)

    # Use the start time of the parent parameter search run as a rough estimate for the
    # start time of child runs, since we cannot precisely determine when each point
    # in the parameter search space was explored
    child_run_start_time = parent_run.info.start_time
    child_run_end_time = get_current_time_millis()

    seed_estimator = cv_estimator.estimator
    # In the unlikely case that a seed of a parameter search estimator is,
    # itself, a parameter search estimator, we should avoid logging the untuned
    # parameters of the seed's seed estimator
    should_log_params_deeply = not _is_parameter_search_estimator(seed_estimator)
    # Each row of `cv_results_` only provides parameters that vary across
    # the user-specified parameter grid. In order to log the complete set
    # of parameters for each child run, we fetch the parameters defined by
    # the seed estimator and update them with parameter subset specified
    # in the result row
    base_params = seed_estimator.get_params(deep=should_log_params_deeply)
    cv_results_df = pd.DataFrame.from_dict(cv_estimator.cv_results_)

    if max_tuning_runs is None:
        cv_results_best_n_df = cv_results_df
    else:
        rank_column_name = "rank_test_score"
        if rank_column_name not in cv_results_df.columns.values:
            rank_column_name = first_custom_rank_column(cv_results_df)
            warnings.warn(
                f"Top {max_tuning_runs} child runs will be created based on ordering in "
                f"{rank_column_name} column. You can choose not to limit the number of "
                "child runs created by setting `max_tuning_runs=None`."
            )
        cv_results_best_n_df = cv_results_df.nsmallest(max_tuning_runs, rank_column_name)
        # Log how many child runs will be created vs omitted.
        _log_child_runs_info(max_tuning_runs, len(cv_results_df))

    # `dataset` is optional: only build (and later log) dataset inputs when one was
    # actually supplied, instead of dereferencing `None`.
    datasets = None
    if dataset is not None:
        datasets = [
            DatasetInput(
                dataset._to_mlflow_entity(),
                tags=[InputTag(key=MLFLOW_DATASET_CONTEXT, value="train")],
            )
        ]

    # The tags are identical for every child run, so assemble them once. Update order
    # (child tags, then parent run ID, then estimator info) fixes which value wins
    # on a key collision.
    shared_tags = dict(child_tags) if child_tags else {}
    shared_tags.update({MLFLOW_PARENT_RUN_ID: parent_run.info.run_id})
    shared_tags.update(_get_estimator_info_tags(seed_estimator))

    # Parameters values are recorded twice in the set of search `cv_results_`:
    # once within a `params` column with dictionary values and once within
    # a separate dataframe column that is created for each parameter. To prevent
    # duplication of parameters, we log the consolidated values from the parameter
    # dictionary column and filter out the other parameter-specific columns with
    # names of the form `param_{param_name}`. Additionally, `cv_results_` produces
    # metrics for each training split, which is fairly verbose; accordingly, we filter
    # out per-split metrics in favor of aggregate metrics (mean, std, etc.)
    excluded_metric_prefixes = ("param", "split")

    for _, result_row in cv_results_best_n_df.iterrows():
        pending_child_run_id = autologging_client.create_run(
            experiment_id=parent_run.info.experiment_id,
            start_time=child_run_start_time,
            # Hand each run its own copy so runs never share one mutable dict.
            tags=dict(shared_tags),
        )

        row_params = result_row.get("params", {})
        params_to_log = dict(base_params)
        params_to_log.update(row_params)
        autologging_client.log_params(run_id=pending_child_run_id, params=params_to_log)

        metrics_to_log = {
            key: value
            for key, value in result_row.items()
            if not key.startswith(excluded_metric_prefixes) and isinstance(value, Number)
        }
        # Only log metrics to the best_estimator_model when the child run's
        # parameters match the best_estimator's parameters.
        model_id = (
            best_estimator_model_id
            if best_estimator_params and row_params.items() <= best_estimator_params.items()
            else None
        )
        autologging_client.log_metrics(
            run_id=pending_child_run_id,
            metrics=metrics_to_log,
            dataset=dataset,
            model_id=model_id,
        )
        if datasets is not None:
            autologging_client.log_inputs(run_id=pending_child_run_id, datasets=datasets)
        autologging_client.set_terminated(run_id=pending_child_run_id, end_time=child_run_end_time)
|
905
|
+
|
906
|
+
|
907
|
+
# Util function to check whether a metric can be computed with the installed sklearn version
|
908
|
+
def _is_metric_supported(metric_name):
    """
    Report whether the installed sklearn is new enough to compute `metric_name`.

    Args:
        metric_name: Name of the metric to look up. It must be a key of the
            minimum-version table below; any other name raises `KeyError`.

    Returns:
        True if the installed sklearn version is at least the minimum version
        recorded for the metric, False otherwise.
    """
    import sklearn

    # Minimum sklearn release per metric; extend this table when other metrics
    # turn out to need a specific version.
    minimum_versions = {"roc_auc_score": "0.22.2"}

    installed = Version(sklearn.__version__)
    required = Version(minimum_versions[metric_name])
    return installed >= required
|
915
|
+
|
916
|
+
|
917
|
+
# Util function to check whether artifact plotting functions are able to be computed
|
918
|
+
# with the installed sklearn version (requires sklearn >= 0.22.0)
|
919
|
+
def _is_plotting_supported():
    """
    Report whether the installed sklearn version is at least 0.22.0.

    Returns:
        True if `sklearn.__version__` is 0.22.0 or newer, False otherwise.
    """
    import sklearn

    installed = Version(sklearn.__version__)
    minimum = Version("0.22.0")
    return installed >= minimum
|
923
|
+
|
924
|
+
|
925
|
+
def _all_estimators():
    """
    Return the estimators reported by sklearn's own `all_estimators`, falling back to
    the backported implementation when that attempt raises `ImportError`.
    """
    try:
        from sklearn.utils import all_estimators as sklearn_all_estimators

        # The call deliberately stays inside the `try`: an ImportError raised while
        # sklearn enumerates its modules also triggers the fallback.
        discovered = sklearn_all_estimators()
    except ImportError:
        discovered = _backported_all_estimators()
    return discovered
|
932
|
+
|
933
|
+
|
934
|
+
def _backported_all_estimators(type_filter=None):
    """
    Backported from scikit-learn 0.23.2:
    https://github.com/scikit-learn/scikit-learn/blob/0.23.2/sklearn/utils/__init__.py#L1146

    Use this backported `all_estimators` in old versions of sklearn because:
    1. An inferior version of `all_estimators` that old versions of sklearn use for testing,
       might function differently from a newer version.
    2. This backported `all_estimators` works on old versions of sklearn that don't even define
       the testing utility variant of `all_estimators`.

    Deviations from the upstream implementation:
    - FutureWarnings are silenced with the standard library `warnings` module rather than
      `sklearn.utils._testing.ignore_warnings`, so the backport does not depend on a private
      sklearn testing module that old sklearn releases may not provide.
    - On PyPy, `FeatureHasher` is excluded from `feature_extraction` modules, as the inline
      comment intends (upstream kept *only* `FeatureHasher`).

    ========== original docstring ==========
    Get a list of all estimators from sklearn.
    This function crawls the module and gets all classes that inherit
    from BaseEstimator. Classes that are defined in test-modules are not
    included.
    By default meta_estimators such as GridSearchCV are also not included.
    Parameters
    ----------
    type_filter : string, list of string, or None, default=None
        Which kind of estimators should be returned. If None, no filter is
        applied and all estimators are returned.  Possible values are
        'classifier', 'regressor', 'cluster' and 'transformer' to get
        estimators only of these specific types, or a list of these to
        get the estimators that fit at least one of the types.

    Returns
    -------
    estimators : list of tuples
        List of (name, class), where ``name`` is the class name as string
        and ``class`` is the actual type of the class.

    Raises
    ------
    ValueError
        If ``type_filter`` contains a value other than the four supported names.
    """
    # lazy import to avoid circular imports from sklearn.base
    import warnings

    import sklearn
    from sklearn.base import (
        BaseEstimator,
        ClassifierMixin,
        ClusterMixin,
        RegressorMixin,
        TransformerMixin,
    )

    IS_PYPY = platform.python_implementation() == "PyPy"

    def is_abstract(c):
        # A class is abstract when it carries a non-empty `__abstractmethods__`.
        return bool(getattr(c, "__abstractmethods__", None))

    all_classes = []
    modules_to_ignore = {"tests", "externals", "setup", "conftest"}
    root = sklearn.__path__[0]  # sklearn package
    # Ignore deprecation warnings triggered at import time and from walking
    # packages
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=FutureWarning)
        for _, modname, _ in pkgutil.walk_packages(path=[root], prefix="sklearn."):
            mod_parts = modname.split(".")
            # Skip test/vendored/setup modules and anything under a private module.
            if any(part in modules_to_ignore for part in mod_parts) or "._" in modname:
                continue
            module = import_module(modname)
            classes = [
                (name, est_cls)
                for name, est_cls in inspect.getmembers(module, inspect.isclass)
                if not name.startswith("_")
            ]

            # TODO: Remove when FeatureHasher is implemented in PYPY
            # Skips FeatureHasher for PYPY
            if IS_PYPY and "feature_extraction" in modname:
                classes = [(name, est_cls) for name, est_cls in classes if name != "FeatureHasher"]

            all_classes.extend(classes)

    # The same class is reachable from several modules; collapse the duplicates.
    all_classes = set(all_classes)

    estimators = [
        c for c in all_classes if (issubclass(c[1], BaseEstimator) and c[0] != "BaseEstimator")
    ]
    # get rid of abstract base classes
    estimators = [c for c in estimators if not is_abstract(c[1])]

    if type_filter is not None:
        # copy the object if type_filter is a list, so the caller's list is not mutated
        type_filter = list(type_filter) if isinstance(type_filter, list) else [type_filter]
        filtered_estimators = []
        filters = {
            "classifier": ClassifierMixin,
            "regressor": RegressorMixin,
            "transformer": TransformerMixin,
            "cluster": ClusterMixin,
        }
        for name, mixin in filters.items():
            if name in type_filter:
                type_filter.remove(name)
                filtered_estimators.extend([est for est in estimators if issubclass(est[1], mixin)])
        estimators = filtered_estimators
        # Whatever is left in `type_filter` was not a recognised estimator kind.
        if type_filter:
            raise ValueError(
                "Parameter type_filter must be 'classifier', "
                "'regressor', 'transformer', 'cluster' or "
                "None, got"
                f" {type_filter!r}"
            )

    # drop duplicates, sort for reproducibility
    # itemgetter is used to ensure the sort does not extend to the 2nd item of
    # the tuple
    return sorted(set(estimators), key=itemgetter(0))
|