genesis-flow 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genesis_flow-1.0.0.dist-info/METADATA +822 -0
- genesis_flow-1.0.0.dist-info/RECORD +645 -0
- genesis_flow-1.0.0.dist-info/WHEEL +5 -0
- genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
- genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
- genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
- mlflow/__init__.py +367 -0
- mlflow/__main__.py +3 -0
- mlflow/ag2/__init__.py +56 -0
- mlflow/ag2/ag2_logger.py +294 -0
- mlflow/anthropic/__init__.py +40 -0
- mlflow/anthropic/autolog.py +129 -0
- mlflow/anthropic/chat.py +144 -0
- mlflow/artifacts/__init__.py +268 -0
- mlflow/autogen/__init__.py +144 -0
- mlflow/autogen/chat.py +142 -0
- mlflow/azure/__init__.py +26 -0
- mlflow/azure/auth_handler.py +257 -0
- mlflow/azure/client.py +319 -0
- mlflow/azure/config.py +120 -0
- mlflow/azure/connection_factory.py +340 -0
- mlflow/azure/exceptions.py +27 -0
- mlflow/azure/stores.py +327 -0
- mlflow/azure/utils.py +183 -0
- mlflow/bedrock/__init__.py +45 -0
- mlflow/bedrock/_autolog.py +202 -0
- mlflow/bedrock/chat.py +122 -0
- mlflow/bedrock/stream.py +160 -0
- mlflow/bedrock/utils.py +43 -0
- mlflow/cli.py +707 -0
- mlflow/client.py +12 -0
- mlflow/config/__init__.py +56 -0
- mlflow/crewai/__init__.py +79 -0
- mlflow/crewai/autolog.py +253 -0
- mlflow/crewai/chat.py +29 -0
- mlflow/data/__init__.py +75 -0
- mlflow/data/artifact_dataset_sources.py +170 -0
- mlflow/data/code_dataset_source.py +40 -0
- mlflow/data/dataset.py +123 -0
- mlflow/data/dataset_registry.py +168 -0
- mlflow/data/dataset_source.py +110 -0
- mlflow/data/dataset_source_registry.py +219 -0
- mlflow/data/delta_dataset_source.py +167 -0
- mlflow/data/digest_utils.py +108 -0
- mlflow/data/evaluation_dataset.py +562 -0
- mlflow/data/filesystem_dataset_source.py +81 -0
- mlflow/data/http_dataset_source.py +145 -0
- mlflow/data/huggingface_dataset.py +258 -0
- mlflow/data/huggingface_dataset_source.py +118 -0
- mlflow/data/meta_dataset.py +104 -0
- mlflow/data/numpy_dataset.py +223 -0
- mlflow/data/pandas_dataset.py +231 -0
- mlflow/data/polars_dataset.py +352 -0
- mlflow/data/pyfunc_dataset_mixin.py +31 -0
- mlflow/data/schema.py +76 -0
- mlflow/data/sources.py +1 -0
- mlflow/data/spark_dataset.py +406 -0
- mlflow/data/spark_dataset_source.py +74 -0
- mlflow/data/spark_delta_utils.py +118 -0
- mlflow/data/tensorflow_dataset.py +350 -0
- mlflow/data/uc_volume_dataset_source.py +81 -0
- mlflow/db.py +27 -0
- mlflow/dspy/__init__.py +17 -0
- mlflow/dspy/autolog.py +197 -0
- mlflow/dspy/callback.py +398 -0
- mlflow/dspy/constant.py +1 -0
- mlflow/dspy/load.py +93 -0
- mlflow/dspy/save.py +393 -0
- mlflow/dspy/util.py +109 -0
- mlflow/dspy/wrapper.py +226 -0
- mlflow/entities/__init__.py +104 -0
- mlflow/entities/_mlflow_object.py +52 -0
- mlflow/entities/assessment.py +545 -0
- mlflow/entities/assessment_error.py +80 -0
- mlflow/entities/assessment_source.py +141 -0
- mlflow/entities/dataset.py +92 -0
- mlflow/entities/dataset_input.py +51 -0
- mlflow/entities/dataset_summary.py +62 -0
- mlflow/entities/document.py +48 -0
- mlflow/entities/experiment.py +109 -0
- mlflow/entities/experiment_tag.py +35 -0
- mlflow/entities/file_info.py +45 -0
- mlflow/entities/input_tag.py +35 -0
- mlflow/entities/lifecycle_stage.py +35 -0
- mlflow/entities/logged_model.py +228 -0
- mlflow/entities/logged_model_input.py +26 -0
- mlflow/entities/logged_model_output.py +32 -0
- mlflow/entities/logged_model_parameter.py +46 -0
- mlflow/entities/logged_model_status.py +74 -0
- mlflow/entities/logged_model_tag.py +33 -0
- mlflow/entities/metric.py +200 -0
- mlflow/entities/model_registry/__init__.py +29 -0
- mlflow/entities/model_registry/_model_registry_entity.py +13 -0
- mlflow/entities/model_registry/model_version.py +243 -0
- mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
- mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
- mlflow/entities/model_registry/model_version_search.py +25 -0
- mlflow/entities/model_registry/model_version_stages.py +25 -0
- mlflow/entities/model_registry/model_version_status.py +35 -0
- mlflow/entities/model_registry/model_version_tag.py +35 -0
- mlflow/entities/model_registry/prompt.py +73 -0
- mlflow/entities/model_registry/prompt_version.py +244 -0
- mlflow/entities/model_registry/registered_model.py +175 -0
- mlflow/entities/model_registry/registered_model_alias.py +35 -0
- mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
- mlflow/entities/model_registry/registered_model_search.py +25 -0
- mlflow/entities/model_registry/registered_model_tag.py +35 -0
- mlflow/entities/multipart_upload.py +74 -0
- mlflow/entities/param.py +49 -0
- mlflow/entities/run.py +97 -0
- mlflow/entities/run_data.py +84 -0
- mlflow/entities/run_info.py +188 -0
- mlflow/entities/run_inputs.py +59 -0
- mlflow/entities/run_outputs.py +43 -0
- mlflow/entities/run_status.py +41 -0
- mlflow/entities/run_tag.py +36 -0
- mlflow/entities/source_type.py +31 -0
- mlflow/entities/span.py +774 -0
- mlflow/entities/span_event.py +96 -0
- mlflow/entities/span_status.py +102 -0
- mlflow/entities/trace.py +317 -0
- mlflow/entities/trace_data.py +71 -0
- mlflow/entities/trace_info.py +220 -0
- mlflow/entities/trace_info_v2.py +162 -0
- mlflow/entities/trace_location.py +173 -0
- mlflow/entities/trace_state.py +39 -0
- mlflow/entities/trace_status.py +68 -0
- mlflow/entities/view_type.py +51 -0
- mlflow/environment_variables.py +866 -0
- mlflow/evaluation/__init__.py +16 -0
- mlflow/evaluation/assessment.py +369 -0
- mlflow/evaluation/evaluation.py +411 -0
- mlflow/evaluation/evaluation_tag.py +61 -0
- mlflow/evaluation/fluent.py +48 -0
- mlflow/evaluation/utils.py +201 -0
- mlflow/exceptions.py +213 -0
- mlflow/experiments.py +140 -0
- mlflow/gemini/__init__.py +81 -0
- mlflow/gemini/autolog.py +186 -0
- mlflow/gemini/chat.py +261 -0
- mlflow/genai/__init__.py +71 -0
- mlflow/genai/datasets/__init__.py +67 -0
- mlflow/genai/datasets/evaluation_dataset.py +131 -0
- mlflow/genai/evaluation/__init__.py +3 -0
- mlflow/genai/evaluation/base.py +411 -0
- mlflow/genai/evaluation/constant.py +23 -0
- mlflow/genai/evaluation/utils.py +244 -0
- mlflow/genai/judges/__init__.py +21 -0
- mlflow/genai/judges/databricks.py +404 -0
- mlflow/genai/label_schemas/__init__.py +153 -0
- mlflow/genai/label_schemas/label_schemas.py +209 -0
- mlflow/genai/labeling/__init__.py +159 -0
- mlflow/genai/labeling/labeling.py +250 -0
- mlflow/genai/optimize/__init__.py +13 -0
- mlflow/genai/optimize/base.py +198 -0
- mlflow/genai/optimize/optimizers/__init__.py +4 -0
- mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
- mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
- mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
- mlflow/genai/optimize/types.py +75 -0
- mlflow/genai/optimize/util.py +30 -0
- mlflow/genai/prompts/__init__.py +206 -0
- mlflow/genai/scheduled_scorers.py +431 -0
- mlflow/genai/scorers/__init__.py +26 -0
- mlflow/genai/scorers/base.py +492 -0
- mlflow/genai/scorers/builtin_scorers.py +765 -0
- mlflow/genai/scorers/scorer_utils.py +138 -0
- mlflow/genai/scorers/validation.py +165 -0
- mlflow/genai/utils/data_validation.py +146 -0
- mlflow/genai/utils/enum_utils.py +23 -0
- mlflow/genai/utils/trace_utils.py +211 -0
- mlflow/groq/__init__.py +42 -0
- mlflow/groq/_groq_autolog.py +74 -0
- mlflow/johnsnowlabs/__init__.py +888 -0
- mlflow/langchain/__init__.py +24 -0
- mlflow/langchain/api_request_parallel_processor.py +330 -0
- mlflow/langchain/autolog.py +147 -0
- mlflow/langchain/chat_agent_langgraph.py +340 -0
- mlflow/langchain/constant.py +1 -0
- mlflow/langchain/constants.py +1 -0
- mlflow/langchain/databricks_dependencies.py +444 -0
- mlflow/langchain/langchain_tracer.py +597 -0
- mlflow/langchain/model.py +919 -0
- mlflow/langchain/output_parsers.py +142 -0
- mlflow/langchain/retriever_chain.py +153 -0
- mlflow/langchain/runnables.py +527 -0
- mlflow/langchain/utils/chat.py +402 -0
- mlflow/langchain/utils/logging.py +671 -0
- mlflow/langchain/utils/serialization.py +36 -0
- mlflow/legacy_databricks_cli/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/provider.py +482 -0
- mlflow/litellm/__init__.py +175 -0
- mlflow/llama_index/__init__.py +22 -0
- mlflow/llama_index/autolog.py +55 -0
- mlflow/llama_index/chat.py +43 -0
- mlflow/llama_index/constant.py +1 -0
- mlflow/llama_index/model.py +577 -0
- mlflow/llama_index/pyfunc_wrapper.py +332 -0
- mlflow/llama_index/serialize_objects.py +188 -0
- mlflow/llama_index/tracer.py +561 -0
- mlflow/metrics/__init__.py +479 -0
- mlflow/metrics/base.py +39 -0
- mlflow/metrics/genai/__init__.py +25 -0
- mlflow/metrics/genai/base.py +101 -0
- mlflow/metrics/genai/genai_metric.py +771 -0
- mlflow/metrics/genai/metric_definitions.py +450 -0
- mlflow/metrics/genai/model_utils.py +371 -0
- mlflow/metrics/genai/prompt_template.py +68 -0
- mlflow/metrics/genai/prompts/__init__.py +0 -0
- mlflow/metrics/genai/prompts/v1.py +422 -0
- mlflow/metrics/genai/utils.py +6 -0
- mlflow/metrics/metric_definitions.py +619 -0
- mlflow/mismatch.py +34 -0
- mlflow/mistral/__init__.py +34 -0
- mlflow/mistral/autolog.py +71 -0
- mlflow/mistral/chat.py +135 -0
- mlflow/ml_package_versions.py +452 -0
- mlflow/models/__init__.py +97 -0
- mlflow/models/auth_policy.py +83 -0
- mlflow/models/cli.py +354 -0
- mlflow/models/container/__init__.py +294 -0
- mlflow/models/container/scoring_server/__init__.py +0 -0
- mlflow/models/container/scoring_server/nginx.conf +39 -0
- mlflow/models/dependencies_schemas.py +287 -0
- mlflow/models/display_utils.py +158 -0
- mlflow/models/docker_utils.py +211 -0
- mlflow/models/evaluation/__init__.py +23 -0
- mlflow/models/evaluation/_shap_patch.py +64 -0
- mlflow/models/evaluation/artifacts.py +194 -0
- mlflow/models/evaluation/base.py +1811 -0
- mlflow/models/evaluation/calibration_curve.py +109 -0
- mlflow/models/evaluation/default_evaluator.py +996 -0
- mlflow/models/evaluation/deprecated.py +23 -0
- mlflow/models/evaluation/evaluator_registry.py +80 -0
- mlflow/models/evaluation/evaluators/classifier.py +704 -0
- mlflow/models/evaluation/evaluators/default.py +233 -0
- mlflow/models/evaluation/evaluators/regressor.py +96 -0
- mlflow/models/evaluation/evaluators/shap.py +296 -0
- mlflow/models/evaluation/lift_curve.py +178 -0
- mlflow/models/evaluation/utils/metric.py +123 -0
- mlflow/models/evaluation/utils/trace.py +179 -0
- mlflow/models/evaluation/validation.py +434 -0
- mlflow/models/flavor_backend.py +93 -0
- mlflow/models/flavor_backend_registry.py +53 -0
- mlflow/models/model.py +1639 -0
- mlflow/models/model_config.py +150 -0
- mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
- mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
- mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
- mlflow/models/python_api.py +369 -0
- mlflow/models/rag_signatures.py +128 -0
- mlflow/models/resources.py +321 -0
- mlflow/models/signature.py +662 -0
- mlflow/models/utils.py +2054 -0
- mlflow/models/wheeled_model.py +280 -0
- mlflow/openai/__init__.py +57 -0
- mlflow/openai/_agent_tracer.py +364 -0
- mlflow/openai/api_request_parallel_processor.py +131 -0
- mlflow/openai/autolog.py +509 -0
- mlflow/openai/constant.py +1 -0
- mlflow/openai/model.py +824 -0
- mlflow/openai/utils/chat_schema.py +367 -0
- mlflow/optuna/__init__.py +3 -0
- mlflow/optuna/storage.py +646 -0
- mlflow/plugins/__init__.py +72 -0
- mlflow/plugins/base.py +358 -0
- mlflow/plugins/builtin/__init__.py +24 -0
- mlflow/plugins/builtin/pytorch_plugin.py +150 -0
- mlflow/plugins/builtin/sklearn_plugin.py +158 -0
- mlflow/plugins/builtin/transformers_plugin.py +187 -0
- mlflow/plugins/cli.py +321 -0
- mlflow/plugins/discovery.py +340 -0
- mlflow/plugins/manager.py +465 -0
- mlflow/plugins/registry.py +316 -0
- mlflow/plugins/templates/framework_plugin_template.py +329 -0
- mlflow/prompt/constants.py +20 -0
- mlflow/prompt/promptlab_model.py +197 -0
- mlflow/prompt/registry_utils.py +248 -0
- mlflow/promptflow/__init__.py +495 -0
- mlflow/protos/__init__.py +0 -0
- mlflow/protos/assessments_pb2.py +174 -0
- mlflow/protos/databricks_artifacts_pb2.py +489 -0
- mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
- mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
- mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
- mlflow/protos/databricks_pb2.py +267 -0
- mlflow/protos/databricks_trace_server_pb2.py +374 -0
- mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
- mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
- mlflow/protos/facet_feature_statistics_pb2.py +296 -0
- mlflow/protos/internal_pb2.py +77 -0
- mlflow/protos/mlflow_artifacts_pb2.py +336 -0
- mlflow/protos/model_registry_pb2.py +1073 -0
- mlflow/protos/scalapb/__init__.py +0 -0
- mlflow/protos/scalapb/scalapb_pb2.py +104 -0
- mlflow/protos/service_pb2.py +2600 -0
- mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
- mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
- mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
- mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
- mlflow/py.typed +0 -0
- mlflow/pydantic_ai/__init__.py +57 -0
- mlflow/pydantic_ai/autolog.py +173 -0
- mlflow/pyfunc/__init__.py +3844 -0
- mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
- mlflow/pyfunc/backend.py +523 -0
- mlflow/pyfunc/context.py +78 -0
- mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
- mlflow/pyfunc/loaders/__init__.py +7 -0
- mlflow/pyfunc/loaders/chat_agent.py +117 -0
- mlflow/pyfunc/loaders/chat_model.py +125 -0
- mlflow/pyfunc/loaders/code_model.py +31 -0
- mlflow/pyfunc/loaders/responses_agent.py +112 -0
- mlflow/pyfunc/mlserver.py +46 -0
- mlflow/pyfunc/model.py +1473 -0
- mlflow/pyfunc/scoring_server/__init__.py +604 -0
- mlflow/pyfunc/scoring_server/app.py +7 -0
- mlflow/pyfunc/scoring_server/client.py +146 -0
- mlflow/pyfunc/spark_model_cache.py +48 -0
- mlflow/pyfunc/stdin_server.py +44 -0
- mlflow/pyfunc/utils/__init__.py +3 -0
- mlflow/pyfunc/utils/data_validation.py +224 -0
- mlflow/pyfunc/utils/environment.py +22 -0
- mlflow/pyfunc/utils/input_converter.py +47 -0
- mlflow/pyfunc/utils/serving_data_parser.py +11 -0
- mlflow/pytorch/__init__.py +1171 -0
- mlflow/pytorch/_lightning_autolog.py +580 -0
- mlflow/pytorch/_pytorch_autolog.py +50 -0
- mlflow/pytorch/pickle_module.py +35 -0
- mlflow/rfunc/__init__.py +42 -0
- mlflow/rfunc/backend.py +134 -0
- mlflow/runs.py +89 -0
- mlflow/server/__init__.py +302 -0
- mlflow/server/auth/__init__.py +1224 -0
- mlflow/server/auth/__main__.py +4 -0
- mlflow/server/auth/basic_auth.ini +6 -0
- mlflow/server/auth/cli.py +11 -0
- mlflow/server/auth/client.py +537 -0
- mlflow/server/auth/config.py +34 -0
- mlflow/server/auth/db/__init__.py +0 -0
- mlflow/server/auth/db/cli.py +18 -0
- mlflow/server/auth/db/migrations/__init__.py +0 -0
- mlflow/server/auth/db/migrations/alembic.ini +110 -0
- mlflow/server/auth/db/migrations/env.py +76 -0
- mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
- mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
- mlflow/server/auth/db/models.py +67 -0
- mlflow/server/auth/db/utils.py +37 -0
- mlflow/server/auth/entities.py +165 -0
- mlflow/server/auth/logo.py +14 -0
- mlflow/server/auth/permissions.py +65 -0
- mlflow/server/auth/routes.py +18 -0
- mlflow/server/auth/sqlalchemy_store.py +263 -0
- mlflow/server/graphql/__init__.py +0 -0
- mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
- mlflow/server/graphql/graphql_custom_scalars.py +24 -0
- mlflow/server/graphql/graphql_errors.py +15 -0
- mlflow/server/graphql/graphql_no_batching.py +89 -0
- mlflow/server/graphql/graphql_schema_extensions.py +74 -0
- mlflow/server/handlers.py +3217 -0
- mlflow/server/prometheus_exporter.py +17 -0
- mlflow/server/validation.py +30 -0
- mlflow/shap/__init__.py +691 -0
- mlflow/sklearn/__init__.py +1994 -0
- mlflow/sklearn/utils.py +1041 -0
- mlflow/smolagents/__init__.py +66 -0
- mlflow/smolagents/autolog.py +139 -0
- mlflow/smolagents/chat.py +29 -0
- mlflow/store/__init__.py +10 -0
- mlflow/store/_unity_catalog/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/constants.py +2 -0
- mlflow/store/_unity_catalog/registry/__init__.py +6 -0
- mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
- mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
- mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
- mlflow/store/_unity_catalog/registry/utils.py +121 -0
- mlflow/store/artifact/__init__.py +0 -0
- mlflow/store/artifact/artifact_repo.py +472 -0
- mlflow/store/artifact/artifact_repository_registry.py +154 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
- mlflow/store/artifact/cli.py +141 -0
- mlflow/store/artifact/cloud_artifact_repo.py +332 -0
- mlflow/store/artifact/databricks_artifact_repo.py +729 -0
- mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
- mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
- mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
- mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
- mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
- mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
- mlflow/store/artifact/ftp_artifact_repo.py +132 -0
- mlflow/store/artifact/gcs_artifact_repo.py +296 -0
- mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
- mlflow/store/artifact/http_artifact_repo.py +218 -0
- mlflow/store/artifact/local_artifact_repo.py +142 -0
- mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
- mlflow/store/artifact/models_artifact_repo.py +259 -0
- mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
- mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
- mlflow/store/artifact/r2_artifact_repo.py +70 -0
- mlflow/store/artifact/runs_artifact_repo.py +265 -0
- mlflow/store/artifact/s3_artifact_repo.py +330 -0
- mlflow/store/artifact/sftp_artifact_repo.py +141 -0
- mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
- mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
- mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
- mlflow/store/artifact/utils/__init__.py +0 -0
- mlflow/store/artifact/utils/models.py +148 -0
- mlflow/store/db/__init__.py +0 -0
- mlflow/store/db/base_sql_model.py +3 -0
- mlflow/store/db/db_types.py +10 -0
- mlflow/store/db/utils.py +314 -0
- mlflow/store/db_migrations/__init__.py +0 -0
- mlflow/store/db_migrations/alembic.ini +74 -0
- mlflow/store/db_migrations/env.py +84 -0
- mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
- mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
- mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
- mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
- mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
- mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
- mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
- mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
- mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
- mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
- mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
- mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
- mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
- mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
- mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
- mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
- mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
- mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
- mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
- mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
- mlflow/store/db_migrations/versions/__init__.py +0 -0
- mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
- mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
- mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
- mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
- mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
- mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
- mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
- mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
- mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
- mlflow/store/entities/__init__.py +3 -0
- mlflow/store/entities/paged_list.py +18 -0
- mlflow/store/model_registry/__init__.py +10 -0
- mlflow/store/model_registry/abstract_store.py +1081 -0
- mlflow/store/model_registry/base_rest_store.py +44 -0
- mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
- mlflow/store/model_registry/dbmodels/__init__.py +0 -0
- mlflow/store/model_registry/dbmodels/models.py +206 -0
- mlflow/store/model_registry/file_store.py +1091 -0
- mlflow/store/model_registry/rest_store.py +481 -0
- mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
- mlflow/store/tracking/__init__.py +23 -0
- mlflow/store/tracking/abstract_store.py +816 -0
- mlflow/store/tracking/dbmodels/__init__.py +0 -0
- mlflow/store/tracking/dbmodels/initial_models.py +243 -0
- mlflow/store/tracking/dbmodels/models.py +1073 -0
- mlflow/store/tracking/file_store.py +2438 -0
- mlflow/store/tracking/postgres_managed_identity.py +146 -0
- mlflow/store/tracking/rest_store.py +1131 -0
- mlflow/store/tracking/sqlalchemy_store.py +2785 -0
- mlflow/system_metrics/__init__.py +61 -0
- mlflow/system_metrics/metrics/__init__.py +0 -0
- mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
- mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
- mlflow/system_metrics/metrics/disk_monitor.py +21 -0
- mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
- mlflow/system_metrics/metrics/network_monitor.py +34 -0
- mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
- mlflow/system_metrics/system_metrics_monitor.py +198 -0
- mlflow/tracing/__init__.py +16 -0
- mlflow/tracing/assessment.py +356 -0
- mlflow/tracing/client.py +531 -0
- mlflow/tracing/config.py +125 -0
- mlflow/tracing/constant.py +105 -0
- mlflow/tracing/destination.py +81 -0
- mlflow/tracing/display/__init__.py +40 -0
- mlflow/tracing/display/display_handler.py +196 -0
- mlflow/tracing/export/async_export_queue.py +186 -0
- mlflow/tracing/export/inference_table.py +138 -0
- mlflow/tracing/export/mlflow_v3.py +137 -0
- mlflow/tracing/export/utils.py +70 -0
- mlflow/tracing/fluent.py +1417 -0
- mlflow/tracing/processor/base_mlflow.py +199 -0
- mlflow/tracing/processor/inference_table.py +175 -0
- mlflow/tracing/processor/mlflow_v3.py +47 -0
- mlflow/tracing/processor/otel.py +73 -0
- mlflow/tracing/provider.py +487 -0
- mlflow/tracing/trace_manager.py +200 -0
- mlflow/tracing/utils/__init__.py +616 -0
- mlflow/tracing/utils/artifact_utils.py +28 -0
- mlflow/tracing/utils/copy.py +55 -0
- mlflow/tracing/utils/environment.py +55 -0
- mlflow/tracing/utils/exception.py +21 -0
- mlflow/tracing/utils/once.py +35 -0
- mlflow/tracing/utils/otlp.py +63 -0
- mlflow/tracing/utils/processor.py +54 -0
- mlflow/tracing/utils/search.py +292 -0
- mlflow/tracing/utils/timeout.py +250 -0
- mlflow/tracing/utils/token.py +19 -0
- mlflow/tracing/utils/truncation.py +124 -0
- mlflow/tracing/utils/warning.py +76 -0
- mlflow/tracking/__init__.py +39 -0
- mlflow/tracking/_model_registry/__init__.py +1 -0
- mlflow/tracking/_model_registry/client.py +764 -0
- mlflow/tracking/_model_registry/fluent.py +853 -0
- mlflow/tracking/_model_registry/registry.py +67 -0
- mlflow/tracking/_model_registry/utils.py +251 -0
- mlflow/tracking/_tracking_service/__init__.py +0 -0
- mlflow/tracking/_tracking_service/client.py +883 -0
- mlflow/tracking/_tracking_service/registry.py +56 -0
- mlflow/tracking/_tracking_service/utils.py +275 -0
- mlflow/tracking/artifact_utils.py +179 -0
- mlflow/tracking/client.py +5900 -0
- mlflow/tracking/context/__init__.py +0 -0
- mlflow/tracking/context/abstract_context.py +35 -0
- mlflow/tracking/context/databricks_cluster_context.py +15 -0
- mlflow/tracking/context/databricks_command_context.py +15 -0
- mlflow/tracking/context/databricks_job_context.py +49 -0
- mlflow/tracking/context/databricks_notebook_context.py +41 -0
- mlflow/tracking/context/databricks_repo_context.py +43 -0
- mlflow/tracking/context/default_context.py +51 -0
- mlflow/tracking/context/git_context.py +32 -0
- mlflow/tracking/context/registry.py +98 -0
- mlflow/tracking/context/system_environment_context.py +15 -0
- mlflow/tracking/default_experiment/__init__.py +1 -0
- mlflow/tracking/default_experiment/abstract_context.py +43 -0
- mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
- mlflow/tracking/default_experiment/registry.py +75 -0
- mlflow/tracking/fluent.py +3595 -0
- mlflow/tracking/metric_value_conversion_utils.py +93 -0
- mlflow/tracking/multimedia.py +206 -0
- mlflow/tracking/registry.py +86 -0
- mlflow/tracking/request_auth/__init__.py +0 -0
- mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
- mlflow/tracking/request_auth/registry.py +60 -0
- mlflow/tracking/request_header/__init__.py +0 -0
- mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
- mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
- mlflow/tracking/request_header/default_request_header_provider.py +17 -0
- mlflow/tracking/request_header/registry.py +79 -0
- mlflow/transformers/__init__.py +2982 -0
- mlflow/transformers/flavor_config.py +258 -0
- mlflow/transformers/hub_utils.py +83 -0
- mlflow/transformers/llm_inference_utils.py +468 -0
- mlflow/transformers/model_io.py +301 -0
- mlflow/transformers/peft.py +51 -0
- mlflow/transformers/signature.py +183 -0
- mlflow/transformers/torch_utils.py +55 -0
- mlflow/types/__init__.py +21 -0
- mlflow/types/agent.py +270 -0
- mlflow/types/chat.py +240 -0
- mlflow/types/llm.py +935 -0
- mlflow/types/responses.py +139 -0
- mlflow/types/responses_helpers.py +416 -0
- mlflow/types/schema.py +1505 -0
- mlflow/types/type_hints.py +647 -0
- mlflow/types/utils.py +753 -0
- mlflow/utils/__init__.py +283 -0
- mlflow/utils/_capture_modules.py +256 -0
- mlflow/utils/_capture_transformers_modules.py +75 -0
- mlflow/utils/_spark_utils.py +201 -0
- mlflow/utils/_unity_catalog_oss_utils.py +97 -0
- mlflow/utils/_unity_catalog_utils.py +479 -0
- mlflow/utils/annotations.py +218 -0
- mlflow/utils/arguments_utils.py +16 -0
- mlflow/utils/async_logging/__init__.py +1 -0
- mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
- mlflow/utils/async_logging/async_logging_queue.py +366 -0
- mlflow/utils/async_logging/run_artifact.py +38 -0
- mlflow/utils/async_logging/run_batch.py +58 -0
- mlflow/utils/async_logging/run_operations.py +49 -0
- mlflow/utils/autologging_utils/__init__.py +737 -0
- mlflow/utils/autologging_utils/client.py +432 -0
- mlflow/utils/autologging_utils/config.py +33 -0
- mlflow/utils/autologging_utils/events.py +294 -0
- mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
- mlflow/utils/autologging_utils/metrics_queue.py +71 -0
- mlflow/utils/autologging_utils/safety.py +1104 -0
- mlflow/utils/autologging_utils/versioning.py +95 -0
- mlflow/utils/checkpoint_utils.py +206 -0
- mlflow/utils/class_utils.py +6 -0
- mlflow/utils/cli_args.py +257 -0
- mlflow/utils/conda.py +354 -0
- mlflow/utils/credentials.py +231 -0
- mlflow/utils/data_utils.py +17 -0
- mlflow/utils/databricks_utils.py +1436 -0
- mlflow/utils/docstring_utils.py +477 -0
- mlflow/utils/doctor.py +133 -0
- mlflow/utils/download_cloud_file_chunk.py +43 -0
- mlflow/utils/env_manager.py +16 -0
- mlflow/utils/env_pack.py +131 -0
- mlflow/utils/environment.py +1009 -0
- mlflow/utils/exception_utils.py +14 -0
- mlflow/utils/file_utils.py +978 -0
- mlflow/utils/git_utils.py +77 -0
- mlflow/utils/gorilla.py +797 -0
- mlflow/utils/import_hooks/__init__.py +363 -0
- mlflow/utils/lazy_load.py +51 -0
- mlflow/utils/logging_utils.py +168 -0
- mlflow/utils/mime_type_utils.py +58 -0
- mlflow/utils/mlflow_tags.py +103 -0
- mlflow/utils/model_utils.py +486 -0
- mlflow/utils/name_utils.py +346 -0
- mlflow/utils/nfs_on_spark.py +62 -0
- mlflow/utils/openai_utils.py +164 -0
- mlflow/utils/os.py +12 -0
- mlflow/utils/oss_registry_utils.py +29 -0
- mlflow/utils/plugins.py +17 -0
- mlflow/utils/process.py +182 -0
- mlflow/utils/promptlab_utils.py +146 -0
- mlflow/utils/proto_json_utils.py +743 -0
- mlflow/utils/pydantic_utils.py +54 -0
- mlflow/utils/request_utils.py +279 -0
- mlflow/utils/requirements_utils.py +704 -0
- mlflow/utils/rest_utils.py +673 -0
- mlflow/utils/search_logged_model_utils.py +127 -0
- mlflow/utils/search_utils.py +2111 -0
- mlflow/utils/secure_loading.py +221 -0
- mlflow/utils/security_validation.py +384 -0
- mlflow/utils/server_cli_utils.py +61 -0
- mlflow/utils/spark_utils.py +15 -0
- mlflow/utils/string_utils.py +138 -0
- mlflow/utils/thread_utils.py +63 -0
- mlflow/utils/time.py +54 -0
- mlflow/utils/timeout.py +42 -0
- mlflow/utils/uri.py +572 -0
- mlflow/utils/validation.py +662 -0
- mlflow/utils/virtualenv.py +458 -0
- mlflow/utils/warnings_utils.py +25 -0
- mlflow/utils/yaml_utils.py +179 -0
- mlflow/version.py +24 -0
@@ -0,0 +1,704 @@
|
|
1
|
+
import logging
|
2
|
+
import math
|
3
|
+
from collections import namedtuple
|
4
|
+
from contextlib import contextmanager
|
5
|
+
from typing import Optional
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
import pandas as pd
|
9
|
+
from sklearn import metrics as sk_metrics
|
10
|
+
|
11
|
+
import mlflow
|
12
|
+
from mlflow import MlflowException
|
13
|
+
from mlflow.environment_variables import _MLFLOW_EVALUATE_SUPPRESS_CLASSIFICATION_ERRORS
|
14
|
+
from mlflow.models.evaluation.artifacts import CsvEvaluationArtifact
|
15
|
+
from mlflow.models.evaluation.base import EvaluationMetric, EvaluationResult, _ModelType
|
16
|
+
from mlflow.models.evaluation.default_evaluator import (
|
17
|
+
BuiltInEvaluator,
|
18
|
+
_extract_raw_model,
|
19
|
+
_get_aggregate_metrics_values,
|
20
|
+
)
|
21
|
+
from mlflow.models.utils import plot_lines
|
22
|
+
|
23
|
+
_logger = logging.getLogger(__name__)


# Record returned by `_gen_classifier_curve`: the plotting callable, the keyword arguments
# to call it with, and the curve's summary score (a single float for a binary curve, a
# list with one score per class for a multiclass curve).
_Curve = namedtuple("_Curve", "plot_fn plot_fn_args auc")
|
27
|
+
|
28
|
+
|
29
|
+
class ClassifierEvaluator(BuiltInEvaluator):
    """
    A built-in evaluator for classifier models.

    Computes the built-in classification metrics (plus any ``extra_metrics``), then logs
    the metrics, the evaluation table and classifier plots to the active run: ROC /
    precision-recall / lift / calibration curves when class probabilities are available,
    and a normalized confusion matrix.
    """

    name = "classifier"

    @classmethod
    def can_evaluate(cls, *, model_type, evaluator_config, **kwargs):
        """Return True if ``model_type`` is the classifier model type."""
        # TODO: Also the model needs to be pyfunc model, not function or endpoint URI
        return model_type == _ModelType.CLASSIFIER

    def _evaluate(
        self,
        model: Optional["mlflow.pyfunc.PyFuncModel"],
        extra_metrics: list[EvaluationMetric],
        custom_artifacts=None,
        **kwargs,
    ) -> Optional[EvaluationResult]:
        """
        Run the classifier evaluation end to end.

        Args:
            model: The model to evaluate, or None when the predictions are read from the
                evaluation dataset instead.
            extra_metrics: Additional metrics evaluated on the predictions and targets.
            custom_artifacts: Optional custom artifact functions to evaluate and log.

        Returns:
            An ``EvaluationResult`` with the aggregate metrics, logged artifacts and run id.

        Raises:
            MlflowException: If the configured ``pos_label`` is not in the configured
                ``label_list``, or if fewer than two distinct labels are available.
        """
        # Get classification config
        self.y_true = self.dataset.labels_data
        self.label_list = self.evaluator_config.get("label_list")
        self.pos_label = self.evaluator_config.get("pos_label")
        self.sample_weights = self.evaluator_config.get("sample_weights")
        # Compare against None explicitly: a pos_label of 0/False is legitimate, and a
        # numpy-array label_list has no unambiguous truth value.
        if (
            self.pos_label is not None
            and self.label_list is not None
            and self.pos_label not in self.label_list
        ):
            raise MlflowException.invalid_parameter_value(
                f"'pos_label' {self.pos_label} must exist in 'label_list' {self.label_list}."
            )

        # Check if the model_type is consistent with ground truth labels. A None result
        # means the label type could not be determined (e.g. plain integer labels), which
        # is not evidence of a mismatch, so only warn on a definite disagreement.
        inferred_model_type = _infer_model_type_by_labels(self.y_true)
        if inferred_model_type is not None and inferred_model_type != _ModelType.CLASSIFIER:
            _logger.warning(
                f"According to the evaluation dataset label values, the model type looks like "
                f"{inferred_model_type}, but you specified model type 'classifier'. Please "
                f"verify that you set the `model_type` and `dataset` arguments correctly."
            )

        # Run model prediction
        input_df = self.X.copy_to_avoid_mutation()
        self.y_pred, self.y_probs = self._generate_model_predictions(model, input_df)

        self._validate_label_list()

        self._compute_builtin_metrics(model)
        self.evaluate_metrics(extra_metrics, prediction=self.y_pred, target=self.y_true)
        self.evaluate_and_log_custom_artifacts(
            custom_artifacts, prediction=self.y_pred, target=self.y_true
        )

        # Log metrics and artifacts
        self.log_metrics()
        self.log_eval_table(self.y_pred)

        if len(self.label_list) == 2:
            self._log_binary_classifier_artifacts()
        else:
            self._log_multiclass_classifier_artifacts()
        self._log_confusion_matrix()

        return EvaluationResult(
            metrics=self.aggregate_metrics, artifacts=self.artifacts, run_id=self.run_id
        )

    def _generate_model_predictions(self, model, input_df):
        """
        Return ``(y_pred, y_probs)`` for ``input_df``.

        When ``model`` is None the predictions come from the dataset. ``y_probs`` is None
        unless a ``predict_proba``-style function could be resolved for the model.
        """
        predict_fn, predict_proba_fn = _extract_predict_fn_and_prodict_proba_fn(model)
        # Classifier model is guaranteed to output single column of predictions
        y_pred = self.dataset.predictions_data if model is None else predict_fn(input_df)

        # Predict class probabilities if the model supports it
        y_probs = predict_proba_fn(input_df) if predict_proba_fn is not None else None
        return y_pred, y_probs

    def _validate_label_list(self):
        """
        Finalize ``self.label_list`` (and ``self.pos_label`` for binary problems).

        If no label list was configured, it is inferred as the sorted union of the true
        and predicted labels. For two labels the positive label is placed last, defaulting
        to the largest label when ``pos_label`` was not configured.

        Raises:
            MlflowException: If fewer than two labels are available.
        """
        if self.label_list is None:
            # If label list is not specified, infer label list from model output
            self.label_list = np.unique(np.concatenate([self.y_true, self.y_pred]))
        else:
            # np.where only works for numpy array, not list
            self.label_list = np.array(self.label_list)

        if len(self.label_list) < 2:
            # Keep the hint inside the single message argument; a second positional
            # argument would be consumed as MlflowException's error code.
            raise MlflowException(
                "Evaluation dataset for classification must contain at least two unique "
                f"labels, but only {len(self.label_list)} unique labels were found. "
                "Please provide a 'label_list' parameter in 'evaluator_config' with all "
                "possible classes, e.g., evaluator_config={'label_list': [0, 1]}."
            )

        # sort label_list ASC, for binary classification it makes sure the last one is pos label
        self.label_list.sort()

        if len(self.label_list) == 2:
            if self.pos_label is None:
                self.pos_label = self.label_list[-1]
            else:
                # Move the configured positive label to the end of the list.
                if self.pos_label in self.label_list:
                    self.label_list = np.delete(
                        self.label_list, np.where(self.label_list == self.pos_label)
                    )
                self.label_list = np.append(self.label_list, self.pos_label)
            with _suppress_class_imbalance_errors(IndexError, log_warning=False):
                _logger.info(
                    "The evaluation dataset is inferred as binary dataset, positive label is "
                    f"{self.label_list[1]}, negative label is {self.label_list[0]}."
                )
        else:
            _logger.info(
                "The evaluation dataset is inferred as multiclass dataset, number of classes "
                f"is inferred as {len(self.label_list)}. If this is incorrect, please specify the "
                "`label_list` parameter in `evaluator_config`."
            )

    def _compute_builtin_metrics(self, model):
        """Compute the built-in metrics and merge them into ``self.metrics_values``."""
        self._evaluate_sklearn_model_score_if_scorable(model, self.y_true, self.sample_weights)

        if len(self.label_list) == 2:
            metrics = _get_binary_classifier_metrics(
                y_true=self.y_true,
                y_pred=self.y_pred,
                y_proba=self.y_probs,
                labels=self.label_list,
                pos_label=self.pos_label,
                sample_weights=self.sample_weights,
            )
            # metrics is None when the computation failed and the error was suppressed.
            if metrics:
                self.metrics_values.update(_get_aggregate_metrics_values(metrics))
                self._compute_roc_and_pr_curve()
        else:
            average = self.evaluator_config.get("average", "weighted")
            metrics = _get_multiclass_classifier_metrics(
                y_true=self.y_true,
                y_pred=self.y_pred,
                y_proba=self.y_probs,
                labels=self.label_list,
                average=average,
                sample_weights=self.sample_weights,
            )
            if metrics:
                self.metrics_values.update(_get_aggregate_metrics_values(metrics))

    def _compute_roc_and_pr_curve(self):
        """
        Build the binary ROC and precision-recall curves and record their scores.

        Stores the curves on ``self.roc_curve`` / ``self.pr_curve`` for later plotting and
        adds ``roc_auc`` / ``precision_recall_auc`` to ``self.metrics_values``. Does
        nothing when no class probabilities are available.
        """
        if self.y_probs is None:
            return
        # NOTE(review): assumes y_probs is 2-D with column 1 holding the positive-class
        # probability — confirm against the model's class ordering.
        # Each metric update stays inside its suppression block: if the curve fails and
        # the error is suppressed, the curve attribute was never assigned.
        with _suppress_class_imbalance_errors(ValueError, log_warning=False):
            self.roc_curve = _gen_classifier_curve(
                is_binomial=True,
                y=self.y_true,
                y_probs=self.y_probs[:, 1],
                labels=self.label_list,
                pos_label=self.pos_label,
                curve_type="roc",
                sample_weights=self.sample_weights,
            )

            self.metrics_values.update(
                _get_aggregate_metrics_values({"roc_auc": self.roc_curve.auc})
            )
        with _suppress_class_imbalance_errors(ValueError, log_warning=False):
            self.pr_curve = _gen_classifier_curve(
                is_binomial=True,
                y=self.y_true,
                y_probs=self.y_probs[:, 1],
                labels=self.label_list,
                pos_label=self.pos_label,
                curve_type="pr",
                sample_weights=self.sample_weights,
            )

            self.metrics_values.update(
                _get_aggregate_metrics_values({"precision_recall_auc": self.pr_curve.auc})
            )

    def _log_pandas_df_artifact(self, pandas_df, artifact_name):
        """Write ``pandas_df`` as ``<artifact_name>.csv``, log it, and record the artifact."""
        artifact_file_name = f"{artifact_name}.csv"
        artifact_file_local_path = self.temp_dir.path(artifact_file_name)
        pandas_df.to_csv(artifact_file_local_path, index=False)
        mlflow.log_artifact(artifact_file_local_path)
        artifact = CsvEvaluationArtifact(
            uri=mlflow.get_artifact_uri(artifact_file_name),
            content=pandas_df,
        )
        artifact._load(artifact_file_local_path)
        self.artifacts[artifact_name] = artifact

    def _log_multiclass_classifier_artifacts(self):
        """
        Log the multiclass artifacts: the per-class metrics table and, when probabilities
        are available, the calibration curve plus one-vs-rest ROC and PR curves (the
        latter only up to ``max_classes_for_multiclass_roc_pr`` classes, default 10).
        """
        per_class_metrics_collection_df = _get_classifier_per_class_metrics_collection_df(
            y=self.y_true,
            y_pred=self.y_pred,
            labels=self.label_list,
            sample_weights=self.sample_weights,
        )

        log_roc_pr_curve = False
        if self.y_probs is not None:
            with _suppress_class_imbalance_errors(TypeError, log_warning=False):
                self._log_calibration_curve()

            max_classes_for_multiclass_roc_pr = self.evaluator_config.get(
                "max_classes_for_multiclass_roc_pr", 10
            )
            if len(self.label_list) <= max_classes_for_multiclass_roc_pr:
                log_roc_pr_curve = True
            else:
                _logger.warning(
                    f"The classifier num_classes > {max_classes_for_multiclass_roc_pr}, skip "
                    f"logging ROC curve and Precision-Recall curve. You can add evaluator config "
                    f"'max_classes_for_multiclass_roc_pr' to increase the threshold."
                )

        if log_roc_pr_curve:
            roc_curve = _gen_classifier_curve(
                is_binomial=False,
                y=self.y_true,
                y_probs=self.y_probs,
                labels=self.label_list,
                pos_label=self.pos_label,
                curve_type="roc",
                sample_weights=self.sample_weights,
            )

            def plot_roc_curve():
                roc_curve.plot_fn(**roc_curve.plot_fn_args)

            self._log_image_artifact(plot_roc_curve, "roc_curve_plot")
            per_class_metrics_collection_df["roc_auc"] = roc_curve.auc

            pr_curve = _gen_classifier_curve(
                is_binomial=False,
                y=self.y_true,
                y_probs=self.y_probs,
                labels=self.label_list,
                pos_label=self.pos_label,
                curve_type="pr",
                sample_weights=self.sample_weights,
            )

            def plot_pr_curve():
                pr_curve.plot_fn(**pr_curve.plot_fn_args)

            self._log_image_artifact(plot_pr_curve, "precision_recall_curve_plot")
            per_class_metrics_collection_df["precision_recall_auc"] = pr_curve.auc

        self._log_pandas_df_artifact(per_class_metrics_collection_df, "per_class_metrics")

    def _log_roc_curve(self):
        """Log the binary ROC curve computed by ``_compute_roc_and_pr_curve``."""

        def _plot_roc_curve():
            self.roc_curve.plot_fn(**self.roc_curve.plot_fn_args)

        self._log_image_artifact(_plot_roc_curve, "roc_curve_plot")

    def _log_precision_recall_curve(self):
        """Log the binary precision-recall curve computed by ``_compute_roc_and_pr_curve``."""

        def _plot_pr_curve():
            self.pr_curve.plot_fn(**self.pr_curve.plot_fn_args)

        self._log_image_artifact(_plot_pr_curve, "precision_recall_curve_plot")

    def _log_lift_curve(self):
        """Log a lift curve plot built from the true labels and predicted probabilities."""
        from mlflow.models.evaluation.lift_curve import plot_lift_curve

        def _plot_lift_curve():
            return plot_lift_curve(self.y_true, self.y_probs, pos_label=self.pos_label)

        self._log_image_artifact(_plot_lift_curve, "lift_curve_plot")

    def _log_calibration_curve(self):
        """
        Log a calibration curve plot; every ``calibration_*`` evaluator config entry is
        forwarded as ``calibration_config``.
        """
        from mlflow.models.evaluation.calibration_curve import plot_calibration_curve

        def _plot_calibration_curve():
            return plot_calibration_curve(
                y_true=self.y_true,
                y_probs=self.y_probs,
                pos_label=self.pos_label,
                calibration_config={
                    k: v for k, v in self.evaluator_config.items() if k.startswith("calibration_")
                },
                label_list=self.label_list,
            )

        self._log_image_artifact(_plot_calibration_curve, "calibration_curve_plot")

    def _log_binary_classifier_artifacts(self):
        """Log the binary ROC, PR, lift and calibration plots when probabilities exist."""
        if self.y_probs is not None:
            with _suppress_class_imbalance_errors(log_warning=False):
                self._log_roc_curve()
            with _suppress_class_imbalance_errors(log_warning=False):
                self._log_precision_recall_curve()
            with _suppress_class_imbalance_errors(ValueError, log_warning=False):
                self._log_lift_curve()
            with _suppress_class_imbalance_errors(TypeError, log_warning=False):
                self._log_calibration_curve()

    def _log_confusion_matrix(self):
        """
        Helper method for logging confusion matrix.

        The matrix is row-normalized over the true labels (``normalize="true"``) and only
        plotted when the installed sklearn provides ``ConfusionMatrixDisplay``.
        """
        # normalize the confusion matrix, keep consistent with sklearn autologging.
        confusion_matrix = sk_metrics.confusion_matrix(
            self.y_true,
            self.y_pred,
            labels=self.label_list,
            normalize="true",
            sample_weight=self.sample_weights,
        )

        def plot_confusion_matrix():
            import matplotlib
            import matplotlib.pyplot as plt

            # Shrink the font as the number of classes grows so cell labels stay legible.
            with matplotlib.rc_context(
                {
                    "font.size": min(8, math.ceil(50.0 / len(self.label_list))),
                    "axes.labelsize": 8,
                }
            ):
                _, ax = plt.subplots(1, 1, figsize=(6.0, 4.0), dpi=175)
                disp = sk_metrics.ConfusionMatrixDisplay(
                    confusion_matrix=confusion_matrix,
                    display_labels=self.label_list,
                ).plot(cmap="Blues", ax=ax)
                disp.ax_.set_title("Normalized confusion matrix")

        if hasattr(sk_metrics, "ConfusionMatrixDisplay"):
            self._log_image_artifact(
                plot_confusion_matrix,
                "confusion_matrix",
            )
|
356
|
+
|
357
|
+
|
358
|
+
def _is_categorical(values):
    """
    Best-effort check for categorical label values.

    Returns True when pandas' nullable-dtype inference (``convert_dtypes``) types the
    values as category, string or boolean. False means "could not determine", not
    "definitely not categorical".
    """
    inferred_dtype = pd.Series(values).convert_dtypes().dtype
    categorical_names = ["category", "string", "boolean"]
    return inferred_dtype.name.lower() in categorical_names
|
365
|
+
|
366
|
+
|
367
|
+
def _is_continuous(values):
    """
    Best-effort check for continuous label values.

    Returns True when pandas' nullable-dtype inference (``convert_dtypes``) yields a float
    dtype. Float values that are all whole numbers are converted to an integer dtype and
    therefore report False. False means "could not determine".
    """
    converted = pd.Series(values).convert_dtypes()
    lowered_name = converted.dtype.name.lower()
    return lowered_name.startswith("float")
|
374
|
+
|
375
|
+
|
376
|
+
def _infer_model_type_by_labels(labels):
    """
    Guess the model type implied by the target values.

    Returns ``_ModelType.CLASSIFIER`` for categorical-looking labels and
    ``_ModelType.REGRESSOR`` for float-valued labels. Returns None when neither check can
    decide.
    """
    if _is_categorical(labels):
        return _ModelType.CLASSIFIER
    if _is_continuous(labels):
        return _ModelType.REGRESSOR
    # Unknown
    return None
|
386
|
+
|
387
|
+
|
388
|
+
def _extract_predict_fn_and_prodict_proba_fn(model):
    """
    Resolve the prediction callables for ``model``.

    Returns:
        A ``(predict_fn, predict_proba_fn)`` pair.
        - When ``_extract_raw_model`` yields an underlying estimator, its ``predict`` and
          (if present) ``predict_proba`` are returned, passed through the xgboost wrappers
          with ``validate_features=False`` whenever ``mlflow.xgboost`` can be imported.
        - Otherwise the pair is ``(model.predict, None)`` if ``model`` is not None.
        - If ``model`` is None, the pair is ``(None, None)``.
    """
    _, raw_model = _extract_raw_model(model)

    if raw_model is None:
        fallback_predict = model.predict if model is not None else None
        return fallback_predict, None

    predict_fn = raw_model.predict
    predict_proba_fn = getattr(raw_model, "predict_proba", None)
    try:
        from mlflow.xgboost import (
            _wrapped_xgboost_model_predict_fn,
            _wrapped_xgboost_model_predict_proba_fn,
        )

        # Because shap evaluation will pass evaluation data in ndarray format
        # (without feature names), if set validate_features=True it will raise error.
        predict_fn = _wrapped_xgboost_model_predict_fn(raw_model, validate_features=False)
        predict_proba_fn = _wrapped_xgboost_model_predict_proba_fn(
            raw_model, validate_features=False
        )
    except ImportError:
        pass
    return predict_fn, predict_proba_fn
|
415
|
+
|
416
|
+
|
417
|
+
@contextmanager
def _suppress_class_imbalance_errors(exception_type=Exception, log_warning=True):
    """
    Context manager that swallows ``exception_type`` errors when the private environment
    variable ``_MLFLOW_EVALUATE_SUPPRESS_CLASSIFICATION_ERRORS`` is enabled.

    It exists so that a binary or multiclass classification AutoML run is not aborted by an
    extreme minority-class imbalance. Such imbalances can arise during iterative training
    because of the non-deterministic sampling of Spark's ``DataFrame.sample()`` API. The
    caught exceptions are intentionally broad and only meant to keep iterative
    hyperparameter tuning going; final evaluations are done in a more deterministic (but
    expensive) fashion.

    Args:
        exception_type: Exception class(es) eligible for suppression; anything else
            propagates untouched.
        log_warning: Whether to log a warning when an error is suppressed.
    """
    try:
        yield
    except exception_type:
        if not _MLFLOW_EVALUATE_SUPPRESS_CLASSIFICATION_ERRORS.get():
            # Suppression is opt-in; otherwise the original error propagates.
            raise
        if log_warning:
            _logger.warning(
                "Failed to calculate metrics due to class imbalance. "
                "This is expected when the dataset is imbalanced."
            )
|
441
|
+
|
442
|
+
|
443
|
+
def _get_binary_sum_up_label_pred_prob(positive_class_index, positive_class, y, y_pred, y_probs):
    """
    One-vs-rest binarization for a single class.

    Returns ``(y_bin, y_pred_bin, y_prob_bin)``:
    - ``y_bin`` and ``y_pred_bin`` are 0/1 arrays marking entries equal to
      ``positive_class``.
    - ``y_prob_bin`` is column ``positive_class_index`` of ``y_probs``.
    ``y_pred_bin`` / ``y_prob_bin`` are None when the corresponding input is None.
    """
    y_bin = np.where(np.array(y) == positive_class, 1, 0)
    y_pred_bin = (
        None if y_pred is None else np.where(np.array(y_pred) == positive_class, 1, 0)
    )
    y_prob_bin = None if y_probs is None else np.array(y_probs)[:, positive_class_index]
    return y_bin, y_pred_bin, y_prob_bin
|
457
|
+
|
458
|
+
|
459
|
+
def _get_common_classifier_metrics(
    *, y_true, y_pred, y_proba, labels, average, pos_label, sample_weights
):
    """
    Compute the metrics shared by binary and multiclass evaluation.

    Always returns ``example_count``, ``accuracy_score``, ``recall_score``,
    ``precision_score`` and ``f1_score``. The recall/precision/F1 scores use ``average``
    and ``pos_label``. ``log_loss`` is added when ``y_proba`` is given, unless its
    ValueError is suppressed by ``_suppress_class_imbalance_errors``.
    """
    averaging_kwargs = {
        "average": average,
        "pos_label": pos_label,
        "sample_weight": sample_weights,
    }
    metrics = {
        "example_count": len(y_true),
        "accuracy_score": sk_metrics.accuracy_score(y_true, y_pred, sample_weight=sample_weights),
        "recall_score": sk_metrics.recall_score(y_true, y_pred, **averaging_kwargs),
        "precision_score": sk_metrics.precision_score(y_true, y_pred, **averaging_kwargs),
        "f1_score": sk_metrics.f1_score(y_true, y_pred, **averaging_kwargs),
    }

    if y_proba is None:
        return metrics

    with _suppress_class_imbalance_errors(ValueError):
        metrics["log_loss"] = sk_metrics.log_loss(
            y_true, y_proba, labels=labels, sample_weight=sample_weights
        )
    return metrics
|
494
|
+
|
495
|
+
|
496
|
+
def _get_binary_classifier_metrics(
    *, y_true, y_pred, y_proba=None, labels=None, pos_label=1, sample_weights=None
):
    """
    Compute binary classification metrics.

    Returns a dict with the (unweighted) confusion-matrix counts ``true_negatives``,
    ``false_positives``, ``false_negatives`` and ``true_positives``, followed by the
    common metrics computed with ``average="binary"``. If a ValueError occurs (e.g. the
    confusion matrix is not 2x2) and is suppressed by ``_suppress_class_imbalance_errors``,
    the function returns None.
    """
    with _suppress_class_imbalance_errors(ValueError):
        tn, fp, fn, tp = sk_metrics.confusion_matrix(y_true, y_pred, labels=labels).ravel()
        shared_metrics = _get_common_classifier_metrics(
            y_true=y_true,
            y_pred=y_pred,
            y_proba=y_proba,
            labels=labels,
            average="binary",
            pos_label=pos_label,
            sample_weights=sample_weights,
        )
        return {
            "true_negatives": tn,
            "false_positives": fp,
            "false_negatives": fn,
            "true_positives": tp,
            **shared_metrics,
        }
|
516
|
+
|
517
|
+
|
518
|
+
def _get_multiclass_classifier_metrics(
    *,
    y_true,
    y_pred,
    y_proba=None,
    labels=None,
    average="weighted",
    sample_weights=None,
):
    """
    Compute multiclass classification metrics.

    Args:
        y_true: True labels.
        y_pred: Predicted labels.
        y_proba: Optional per-class predicted probabilities.
        labels: The set of labels, forwarded to ``log_loss``.
        average: Averaging strategy for recall/precision/F1 (and ROC AUC).
        sample_weights: Optional sample weights.

    Returns:
        The common classifier metrics, plus a one-vs-rest ``roc_auc`` when ``average`` is
        "macro" or "weighted" and ``y_proba`` is given.
    """
    metrics = _get_common_classifier_metrics(
        y_true=y_true,
        y_pred=y_pred,
        y_proba=y_proba,
        labels=labels,
        average=average,
        pos_label=None,
        sample_weights=sample_weights,
    )
    if average in ("macro", "weighted") and y_proba is not None:
        # One-vs-rest ROC AUC raises ValueError when some class is absent from y_true (a
        # class-imbalance failure), so guard it the same way log_loss is guarded. Without
        # the opt-in environment variable the error still propagates.
        with _suppress_class_imbalance_errors(ValueError):
            metrics["roc_auc"] = sk_metrics.roc_auc_score(
                y_true=y_true,
                y_score=y_proba,
                sample_weight=sample_weights,
                average=average,
                multi_class="ovr",
            )
    return metrics
|
547
|
+
|
548
|
+
|
549
|
+
def _get_classifier_per_class_metrics_collection_df(y, y_pred, labels, sample_weights):
    """
    Build a DataFrame with one row per label in ``labels``.

    Each row holds the label under ``positive_class`` plus the binary classifier metrics
    of that label versus all others. If those metrics were suppressed (returned None),
    the row contains only ``positive_class``.
    """
    rows = []
    for class_index, class_label in enumerate(labels):
        y_bin, y_pred_bin, _ = _get_binary_sum_up_label_pred_prob(
            class_index, class_label, y, y_pred, None
        )
        row = {"positive_class": class_label}
        binary_metrics = _get_binary_classifier_metrics(
            y_true=y_bin,
            y_pred=y_pred_bin,
            labels=[0, 1],  # Use binary labels for per-class metrics
            pos_label=1,
            sample_weights=sample_weights,
        )
        row.update(binary_metrics or {})
        rows.append(row)

    return pd.DataFrame(rows)
|
572
|
+
|
573
|
+
|
574
|
+
# NOTE: `_Curve` (fields: plot_fn, plot_fn_args, auc) is defined once near the top of this
# module and is the record type returned by `_gen_classifier_curve` below.
|
575
|
+
|
576
|
+
|
577
|
+
def _gen_classifier_curve(
    is_binomial,
    y,
    y_probs,
    labels,
    pos_label,
    curve_type,
    sample_weights,
):
    """
    Generate precision-recall curve or ROC curve for classifier.

    Args:
        is_binomial: True if it is binary classifier otherwise False
        y: True label values
        y_probs: if binary classifier, the predicted probability for positive class.
            if multiclass classifier, the predicted probabilities for all classes.
        labels: The set of labels.
        pos_label: The label of the positive class.
        curve_type: "pr" or "roc"
        sample_weights: Optional sample weights.

    Returns:
        An instance of "_Curve" which includes attributes "plot_fn", "plot_fn_args", "auc".
        ``auc`` is a single score for a binary curve, or a list with one score per label
        (one-vs-rest) for a multiclass curve. For "pr" curves the score is the average
        precision.

    Raises:
        MlflowException: If ``curve_type`` is neither "roc" nor "pr".
    """
    if curve_type == "roc":

        def gen_line_x_y_label_auc(_y, _y_prob, _pos_label):
            fpr, tpr, _ = sk_metrics.roc_curve(
                _y,
                _y_prob,
                sample_weight=sample_weights,
                # In the multiclass one-vs-rest case `_y` has been binarized to {0, 1}, so
                # the caller's pos_label no longer applies; None lets sklearn use 1.
                pos_label=_pos_label if _pos_label == pos_label else None,
            )

            auc = sk_metrics.roc_auc_score(y_true=_y, y_score=_y_prob, sample_weight=sample_weights)
            return fpr, tpr, f"AUC={auc:.3f}", auc

        xlabel = "False Positive Rate"
        ylabel = "True Positive Rate"
        title = "ROC curve"
        # Compare against None so a falsy positive label such as 0 is still shown.
        if pos_label is not None:
            xlabel = f"False Positive Rate (Positive label: {pos_label})"
            ylabel = f"True Positive Rate (Positive label: {pos_label})"
    elif curve_type == "pr":

        def gen_line_x_y_label_auc(_y, _y_prob, _pos_label):
            precision, recall, _ = sk_metrics.precision_recall_curve(
                _y,
                _y_prob,
                sample_weight=sample_weights,
                # In the multiclass one-vs-rest case `_y` has been binarized to {0, 1}, so
                # the caller's pos_label no longer applies; None lets sklearn use 1.
                pos_label=_pos_label if _pos_label == pos_label else None,
            )
            # NB: We return average precision score (AP) instead of AUC because AP is more
            # appropriate for summarizing a precision-recall curve
            ap = sk_metrics.average_precision_score(
                y_true=_y, y_score=_y_prob, pos_label=_pos_label, sample_weight=sample_weights
            )
            return recall, precision, f"AP={ap:.3f}", ap

        xlabel = "Recall"
        ylabel = "Precision"
        title = "Precision recall curve"
        if pos_label is not None:
            xlabel = f"Recall (Positive label: {pos_label})"
            ylabel = f"Precision (Positive label: {pos_label})"
    else:
        # An explicit raise (unlike `assert`) is not stripped under `python -O`.
        raise MlflowException.invalid_parameter_value(
            f"Illegal curve type {curve_type!r}; expected 'roc' or 'pr'."
        )

    if is_binomial:
        x_data, y_data, line_label, auc = gen_line_x_y_label_auc(y, y_probs, pos_label)
        data_series = [(line_label, x_data, y_data)]
    else:
        curve_list = []
        for positive_class_index, positive_class in enumerate(labels):
            # Only the binarized truth and this class's probability column are needed,
            # so no predictions are passed.
            y_bin, _, y_prob_bin = _get_binary_sum_up_label_pred_prob(
                positive_class_index, positive_class, y, None, y_probs
            )

            x_data, y_data, line_label, class_auc = gen_line_x_y_label_auc(
                y_bin, y_prob_bin, _pos_label=1
            )
            curve_list.append((positive_class, x_data, y_data, line_label, class_auc))

        data_series = [
            (f"label={positive_class},{line_label}", x_data, y_data)
            for positive_class, x_data, y_data, line_label, _ in curve_list
        ]
        auc = [class_auc for *_, class_auc in curve_list]

    def _do_plot(**kwargs):
        from matplotlib import pyplot

        _, ax = plot_lines(**kwargs)
        dash_line_args = {
            "color": "gray",
            "alpha": 0.3,
            "drawstyle": "default",
            "linestyle": "dashed",
        }
        # Gray dashed reference diagonal for the chosen curve type.
        if curve_type == "pr":
            ax.plot([0, 1], [1, 0], **dash_line_args)
        elif curve_type == "roc":
            ax.plot([0, 1], [0, 1], **dash_line_args)

        if is_binomial:
            ax.legend(loc="best")
        else:
            # Place the per-class legend outside the axes to the right.
            ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
            pyplot.subplots_adjust(right=0.6, bottom=0.25)

    return _Curve(
        plot_fn=_do_plot,
        plot_fn_args={
            "data_series": data_series,
            "xlabel": xlabel,
            "ylabel": ylabel,
            "line_kwargs": {"drawstyle": "steps-post", "linewidth": 1},
            "title": title,
        },
        auc=auc,
    )
|