genesis-flow 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genesis_flow-1.0.0.dist-info/METADATA +822 -0
- genesis_flow-1.0.0.dist-info/RECORD +645 -0
- genesis_flow-1.0.0.dist-info/WHEEL +5 -0
- genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
- genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
- genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
- mlflow/__init__.py +367 -0
- mlflow/__main__.py +3 -0
- mlflow/ag2/__init__.py +56 -0
- mlflow/ag2/ag2_logger.py +294 -0
- mlflow/anthropic/__init__.py +40 -0
- mlflow/anthropic/autolog.py +129 -0
- mlflow/anthropic/chat.py +144 -0
- mlflow/artifacts/__init__.py +268 -0
- mlflow/autogen/__init__.py +144 -0
- mlflow/autogen/chat.py +142 -0
- mlflow/azure/__init__.py +26 -0
- mlflow/azure/auth_handler.py +257 -0
- mlflow/azure/client.py +319 -0
- mlflow/azure/config.py +120 -0
- mlflow/azure/connection_factory.py +340 -0
- mlflow/azure/exceptions.py +27 -0
- mlflow/azure/stores.py +327 -0
- mlflow/azure/utils.py +183 -0
- mlflow/bedrock/__init__.py +45 -0
- mlflow/bedrock/_autolog.py +202 -0
- mlflow/bedrock/chat.py +122 -0
- mlflow/bedrock/stream.py +160 -0
- mlflow/bedrock/utils.py +43 -0
- mlflow/cli.py +707 -0
- mlflow/client.py +12 -0
- mlflow/config/__init__.py +56 -0
- mlflow/crewai/__init__.py +79 -0
- mlflow/crewai/autolog.py +253 -0
- mlflow/crewai/chat.py +29 -0
- mlflow/data/__init__.py +75 -0
- mlflow/data/artifact_dataset_sources.py +170 -0
- mlflow/data/code_dataset_source.py +40 -0
- mlflow/data/dataset.py +123 -0
- mlflow/data/dataset_registry.py +168 -0
- mlflow/data/dataset_source.py +110 -0
- mlflow/data/dataset_source_registry.py +219 -0
- mlflow/data/delta_dataset_source.py +167 -0
- mlflow/data/digest_utils.py +108 -0
- mlflow/data/evaluation_dataset.py +562 -0
- mlflow/data/filesystem_dataset_source.py +81 -0
- mlflow/data/http_dataset_source.py +145 -0
- mlflow/data/huggingface_dataset.py +258 -0
- mlflow/data/huggingface_dataset_source.py +118 -0
- mlflow/data/meta_dataset.py +104 -0
- mlflow/data/numpy_dataset.py +223 -0
- mlflow/data/pandas_dataset.py +231 -0
- mlflow/data/polars_dataset.py +352 -0
- mlflow/data/pyfunc_dataset_mixin.py +31 -0
- mlflow/data/schema.py +76 -0
- mlflow/data/sources.py +1 -0
- mlflow/data/spark_dataset.py +406 -0
- mlflow/data/spark_dataset_source.py +74 -0
- mlflow/data/spark_delta_utils.py +118 -0
- mlflow/data/tensorflow_dataset.py +350 -0
- mlflow/data/uc_volume_dataset_source.py +81 -0
- mlflow/db.py +27 -0
- mlflow/dspy/__init__.py +17 -0
- mlflow/dspy/autolog.py +197 -0
- mlflow/dspy/callback.py +398 -0
- mlflow/dspy/constant.py +1 -0
- mlflow/dspy/load.py +93 -0
- mlflow/dspy/save.py +393 -0
- mlflow/dspy/util.py +109 -0
- mlflow/dspy/wrapper.py +226 -0
- mlflow/entities/__init__.py +104 -0
- mlflow/entities/_mlflow_object.py +52 -0
- mlflow/entities/assessment.py +545 -0
- mlflow/entities/assessment_error.py +80 -0
- mlflow/entities/assessment_source.py +141 -0
- mlflow/entities/dataset.py +92 -0
- mlflow/entities/dataset_input.py +51 -0
- mlflow/entities/dataset_summary.py +62 -0
- mlflow/entities/document.py +48 -0
- mlflow/entities/experiment.py +109 -0
- mlflow/entities/experiment_tag.py +35 -0
- mlflow/entities/file_info.py +45 -0
- mlflow/entities/input_tag.py +35 -0
- mlflow/entities/lifecycle_stage.py +35 -0
- mlflow/entities/logged_model.py +228 -0
- mlflow/entities/logged_model_input.py +26 -0
- mlflow/entities/logged_model_output.py +32 -0
- mlflow/entities/logged_model_parameter.py +46 -0
- mlflow/entities/logged_model_status.py +74 -0
- mlflow/entities/logged_model_tag.py +33 -0
- mlflow/entities/metric.py +200 -0
- mlflow/entities/model_registry/__init__.py +29 -0
- mlflow/entities/model_registry/_model_registry_entity.py +13 -0
- mlflow/entities/model_registry/model_version.py +243 -0
- mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
- mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
- mlflow/entities/model_registry/model_version_search.py +25 -0
- mlflow/entities/model_registry/model_version_stages.py +25 -0
- mlflow/entities/model_registry/model_version_status.py +35 -0
- mlflow/entities/model_registry/model_version_tag.py +35 -0
- mlflow/entities/model_registry/prompt.py +73 -0
- mlflow/entities/model_registry/prompt_version.py +244 -0
- mlflow/entities/model_registry/registered_model.py +175 -0
- mlflow/entities/model_registry/registered_model_alias.py +35 -0
- mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
- mlflow/entities/model_registry/registered_model_search.py +25 -0
- mlflow/entities/model_registry/registered_model_tag.py +35 -0
- mlflow/entities/multipart_upload.py +74 -0
- mlflow/entities/param.py +49 -0
- mlflow/entities/run.py +97 -0
- mlflow/entities/run_data.py +84 -0
- mlflow/entities/run_info.py +188 -0
- mlflow/entities/run_inputs.py +59 -0
- mlflow/entities/run_outputs.py +43 -0
- mlflow/entities/run_status.py +41 -0
- mlflow/entities/run_tag.py +36 -0
- mlflow/entities/source_type.py +31 -0
- mlflow/entities/span.py +774 -0
- mlflow/entities/span_event.py +96 -0
- mlflow/entities/span_status.py +102 -0
- mlflow/entities/trace.py +317 -0
- mlflow/entities/trace_data.py +71 -0
- mlflow/entities/trace_info.py +220 -0
- mlflow/entities/trace_info_v2.py +162 -0
- mlflow/entities/trace_location.py +173 -0
- mlflow/entities/trace_state.py +39 -0
- mlflow/entities/trace_status.py +68 -0
- mlflow/entities/view_type.py +51 -0
- mlflow/environment_variables.py +866 -0
- mlflow/evaluation/__init__.py +16 -0
- mlflow/evaluation/assessment.py +369 -0
- mlflow/evaluation/evaluation.py +411 -0
- mlflow/evaluation/evaluation_tag.py +61 -0
- mlflow/evaluation/fluent.py +48 -0
- mlflow/evaluation/utils.py +201 -0
- mlflow/exceptions.py +213 -0
- mlflow/experiments.py +140 -0
- mlflow/gemini/__init__.py +81 -0
- mlflow/gemini/autolog.py +186 -0
- mlflow/gemini/chat.py +261 -0
- mlflow/genai/__init__.py +71 -0
- mlflow/genai/datasets/__init__.py +67 -0
- mlflow/genai/datasets/evaluation_dataset.py +131 -0
- mlflow/genai/evaluation/__init__.py +3 -0
- mlflow/genai/evaluation/base.py +411 -0
- mlflow/genai/evaluation/constant.py +23 -0
- mlflow/genai/evaluation/utils.py +244 -0
- mlflow/genai/judges/__init__.py +21 -0
- mlflow/genai/judges/databricks.py +404 -0
- mlflow/genai/label_schemas/__init__.py +153 -0
- mlflow/genai/label_schemas/label_schemas.py +209 -0
- mlflow/genai/labeling/__init__.py +159 -0
- mlflow/genai/labeling/labeling.py +250 -0
- mlflow/genai/optimize/__init__.py +13 -0
- mlflow/genai/optimize/base.py +198 -0
- mlflow/genai/optimize/optimizers/__init__.py +4 -0
- mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
- mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
- mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
- mlflow/genai/optimize/types.py +75 -0
- mlflow/genai/optimize/util.py +30 -0
- mlflow/genai/prompts/__init__.py +206 -0
- mlflow/genai/scheduled_scorers.py +431 -0
- mlflow/genai/scorers/__init__.py +26 -0
- mlflow/genai/scorers/base.py +492 -0
- mlflow/genai/scorers/builtin_scorers.py +765 -0
- mlflow/genai/scorers/scorer_utils.py +138 -0
- mlflow/genai/scorers/validation.py +165 -0
- mlflow/genai/utils/data_validation.py +146 -0
- mlflow/genai/utils/enum_utils.py +23 -0
- mlflow/genai/utils/trace_utils.py +211 -0
- mlflow/groq/__init__.py +42 -0
- mlflow/groq/_groq_autolog.py +74 -0
- mlflow/johnsnowlabs/__init__.py +888 -0
- mlflow/langchain/__init__.py +24 -0
- mlflow/langchain/api_request_parallel_processor.py +330 -0
- mlflow/langchain/autolog.py +147 -0
- mlflow/langchain/chat_agent_langgraph.py +340 -0
- mlflow/langchain/constant.py +1 -0
- mlflow/langchain/constants.py +1 -0
- mlflow/langchain/databricks_dependencies.py +444 -0
- mlflow/langchain/langchain_tracer.py +597 -0
- mlflow/langchain/model.py +919 -0
- mlflow/langchain/output_parsers.py +142 -0
- mlflow/langchain/retriever_chain.py +153 -0
- mlflow/langchain/runnables.py +527 -0
- mlflow/langchain/utils/chat.py +402 -0
- mlflow/langchain/utils/logging.py +671 -0
- mlflow/langchain/utils/serialization.py +36 -0
- mlflow/legacy_databricks_cli/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/provider.py +482 -0
- mlflow/litellm/__init__.py +175 -0
- mlflow/llama_index/__init__.py +22 -0
- mlflow/llama_index/autolog.py +55 -0
- mlflow/llama_index/chat.py +43 -0
- mlflow/llama_index/constant.py +1 -0
- mlflow/llama_index/model.py +577 -0
- mlflow/llama_index/pyfunc_wrapper.py +332 -0
- mlflow/llama_index/serialize_objects.py +188 -0
- mlflow/llama_index/tracer.py +561 -0
- mlflow/metrics/__init__.py +479 -0
- mlflow/metrics/base.py +39 -0
- mlflow/metrics/genai/__init__.py +25 -0
- mlflow/metrics/genai/base.py +101 -0
- mlflow/metrics/genai/genai_metric.py +771 -0
- mlflow/metrics/genai/metric_definitions.py +450 -0
- mlflow/metrics/genai/model_utils.py +371 -0
- mlflow/metrics/genai/prompt_template.py +68 -0
- mlflow/metrics/genai/prompts/__init__.py +0 -0
- mlflow/metrics/genai/prompts/v1.py +422 -0
- mlflow/metrics/genai/utils.py +6 -0
- mlflow/metrics/metric_definitions.py +619 -0
- mlflow/mismatch.py +34 -0
- mlflow/mistral/__init__.py +34 -0
- mlflow/mistral/autolog.py +71 -0
- mlflow/mistral/chat.py +135 -0
- mlflow/ml_package_versions.py +452 -0
- mlflow/models/__init__.py +97 -0
- mlflow/models/auth_policy.py +83 -0
- mlflow/models/cli.py +354 -0
- mlflow/models/container/__init__.py +294 -0
- mlflow/models/container/scoring_server/__init__.py +0 -0
- mlflow/models/container/scoring_server/nginx.conf +39 -0
- mlflow/models/dependencies_schemas.py +287 -0
- mlflow/models/display_utils.py +158 -0
- mlflow/models/docker_utils.py +211 -0
- mlflow/models/evaluation/__init__.py +23 -0
- mlflow/models/evaluation/_shap_patch.py +64 -0
- mlflow/models/evaluation/artifacts.py +194 -0
- mlflow/models/evaluation/base.py +1811 -0
- mlflow/models/evaluation/calibration_curve.py +109 -0
- mlflow/models/evaluation/default_evaluator.py +996 -0
- mlflow/models/evaluation/deprecated.py +23 -0
- mlflow/models/evaluation/evaluator_registry.py +80 -0
- mlflow/models/evaluation/evaluators/classifier.py +704 -0
- mlflow/models/evaluation/evaluators/default.py +233 -0
- mlflow/models/evaluation/evaluators/regressor.py +96 -0
- mlflow/models/evaluation/evaluators/shap.py +296 -0
- mlflow/models/evaluation/lift_curve.py +178 -0
- mlflow/models/evaluation/utils/metric.py +123 -0
- mlflow/models/evaluation/utils/trace.py +179 -0
- mlflow/models/evaluation/validation.py +434 -0
- mlflow/models/flavor_backend.py +93 -0
- mlflow/models/flavor_backend_registry.py +53 -0
- mlflow/models/model.py +1639 -0
- mlflow/models/model_config.py +150 -0
- mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
- mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
- mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
- mlflow/models/python_api.py +369 -0
- mlflow/models/rag_signatures.py +128 -0
- mlflow/models/resources.py +321 -0
- mlflow/models/signature.py +662 -0
- mlflow/models/utils.py +2054 -0
- mlflow/models/wheeled_model.py +280 -0
- mlflow/openai/__init__.py +57 -0
- mlflow/openai/_agent_tracer.py +364 -0
- mlflow/openai/api_request_parallel_processor.py +131 -0
- mlflow/openai/autolog.py +509 -0
- mlflow/openai/constant.py +1 -0
- mlflow/openai/model.py +824 -0
- mlflow/openai/utils/chat_schema.py +367 -0
- mlflow/optuna/__init__.py +3 -0
- mlflow/optuna/storage.py +646 -0
- mlflow/plugins/__init__.py +72 -0
- mlflow/plugins/base.py +358 -0
- mlflow/plugins/builtin/__init__.py +24 -0
- mlflow/plugins/builtin/pytorch_plugin.py +150 -0
- mlflow/plugins/builtin/sklearn_plugin.py +158 -0
- mlflow/plugins/builtin/transformers_plugin.py +187 -0
- mlflow/plugins/cli.py +321 -0
- mlflow/plugins/discovery.py +340 -0
- mlflow/plugins/manager.py +465 -0
- mlflow/plugins/registry.py +316 -0
- mlflow/plugins/templates/framework_plugin_template.py +329 -0
- mlflow/prompt/constants.py +20 -0
- mlflow/prompt/promptlab_model.py +197 -0
- mlflow/prompt/registry_utils.py +248 -0
- mlflow/promptflow/__init__.py +495 -0
- mlflow/protos/__init__.py +0 -0
- mlflow/protos/assessments_pb2.py +174 -0
- mlflow/protos/databricks_artifacts_pb2.py +489 -0
- mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
- mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
- mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
- mlflow/protos/databricks_pb2.py +267 -0
- mlflow/protos/databricks_trace_server_pb2.py +374 -0
- mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
- mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
- mlflow/protos/facet_feature_statistics_pb2.py +296 -0
- mlflow/protos/internal_pb2.py +77 -0
- mlflow/protos/mlflow_artifacts_pb2.py +336 -0
- mlflow/protos/model_registry_pb2.py +1073 -0
- mlflow/protos/scalapb/__init__.py +0 -0
- mlflow/protos/scalapb/scalapb_pb2.py +104 -0
- mlflow/protos/service_pb2.py +2600 -0
- mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
- mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
- mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
- mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
- mlflow/py.typed +0 -0
- mlflow/pydantic_ai/__init__.py +57 -0
- mlflow/pydantic_ai/autolog.py +173 -0
- mlflow/pyfunc/__init__.py +3844 -0
- mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
- mlflow/pyfunc/backend.py +523 -0
- mlflow/pyfunc/context.py +78 -0
- mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
- mlflow/pyfunc/loaders/__init__.py +7 -0
- mlflow/pyfunc/loaders/chat_agent.py +117 -0
- mlflow/pyfunc/loaders/chat_model.py +125 -0
- mlflow/pyfunc/loaders/code_model.py +31 -0
- mlflow/pyfunc/loaders/responses_agent.py +112 -0
- mlflow/pyfunc/mlserver.py +46 -0
- mlflow/pyfunc/model.py +1473 -0
- mlflow/pyfunc/scoring_server/__init__.py +604 -0
- mlflow/pyfunc/scoring_server/app.py +7 -0
- mlflow/pyfunc/scoring_server/client.py +146 -0
- mlflow/pyfunc/spark_model_cache.py +48 -0
- mlflow/pyfunc/stdin_server.py +44 -0
- mlflow/pyfunc/utils/__init__.py +3 -0
- mlflow/pyfunc/utils/data_validation.py +224 -0
- mlflow/pyfunc/utils/environment.py +22 -0
- mlflow/pyfunc/utils/input_converter.py +47 -0
- mlflow/pyfunc/utils/serving_data_parser.py +11 -0
- mlflow/pytorch/__init__.py +1171 -0
- mlflow/pytorch/_lightning_autolog.py +580 -0
- mlflow/pytorch/_pytorch_autolog.py +50 -0
- mlflow/pytorch/pickle_module.py +35 -0
- mlflow/rfunc/__init__.py +42 -0
- mlflow/rfunc/backend.py +134 -0
- mlflow/runs.py +89 -0
- mlflow/server/__init__.py +302 -0
- mlflow/server/auth/__init__.py +1224 -0
- mlflow/server/auth/__main__.py +4 -0
- mlflow/server/auth/basic_auth.ini +6 -0
- mlflow/server/auth/cli.py +11 -0
- mlflow/server/auth/client.py +537 -0
- mlflow/server/auth/config.py +34 -0
- mlflow/server/auth/db/__init__.py +0 -0
- mlflow/server/auth/db/cli.py +18 -0
- mlflow/server/auth/db/migrations/__init__.py +0 -0
- mlflow/server/auth/db/migrations/alembic.ini +110 -0
- mlflow/server/auth/db/migrations/env.py +76 -0
- mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
- mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
- mlflow/server/auth/db/models.py +67 -0
- mlflow/server/auth/db/utils.py +37 -0
- mlflow/server/auth/entities.py +165 -0
- mlflow/server/auth/logo.py +14 -0
- mlflow/server/auth/permissions.py +65 -0
- mlflow/server/auth/routes.py +18 -0
- mlflow/server/auth/sqlalchemy_store.py +263 -0
- mlflow/server/graphql/__init__.py +0 -0
- mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
- mlflow/server/graphql/graphql_custom_scalars.py +24 -0
- mlflow/server/graphql/graphql_errors.py +15 -0
- mlflow/server/graphql/graphql_no_batching.py +89 -0
- mlflow/server/graphql/graphql_schema_extensions.py +74 -0
- mlflow/server/handlers.py +3217 -0
- mlflow/server/prometheus_exporter.py +17 -0
- mlflow/server/validation.py +30 -0
- mlflow/shap/__init__.py +691 -0
- mlflow/sklearn/__init__.py +1994 -0
- mlflow/sklearn/utils.py +1041 -0
- mlflow/smolagents/__init__.py +66 -0
- mlflow/smolagents/autolog.py +139 -0
- mlflow/smolagents/chat.py +29 -0
- mlflow/store/__init__.py +10 -0
- mlflow/store/_unity_catalog/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/constants.py +2 -0
- mlflow/store/_unity_catalog/registry/__init__.py +6 -0
- mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
- mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
- mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
- mlflow/store/_unity_catalog/registry/utils.py +121 -0
- mlflow/store/artifact/__init__.py +0 -0
- mlflow/store/artifact/artifact_repo.py +472 -0
- mlflow/store/artifact/artifact_repository_registry.py +154 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
- mlflow/store/artifact/cli.py +141 -0
- mlflow/store/artifact/cloud_artifact_repo.py +332 -0
- mlflow/store/artifact/databricks_artifact_repo.py +729 -0
- mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
- mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
- mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
- mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
- mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
- mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
- mlflow/store/artifact/ftp_artifact_repo.py +132 -0
- mlflow/store/artifact/gcs_artifact_repo.py +296 -0
- mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
- mlflow/store/artifact/http_artifact_repo.py +218 -0
- mlflow/store/artifact/local_artifact_repo.py +142 -0
- mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
- mlflow/store/artifact/models_artifact_repo.py +259 -0
- mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
- mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
- mlflow/store/artifact/r2_artifact_repo.py +70 -0
- mlflow/store/artifact/runs_artifact_repo.py +265 -0
- mlflow/store/artifact/s3_artifact_repo.py +330 -0
- mlflow/store/artifact/sftp_artifact_repo.py +141 -0
- mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
- mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
- mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
- mlflow/store/artifact/utils/__init__.py +0 -0
- mlflow/store/artifact/utils/models.py +148 -0
- mlflow/store/db/__init__.py +0 -0
- mlflow/store/db/base_sql_model.py +3 -0
- mlflow/store/db/db_types.py +10 -0
- mlflow/store/db/utils.py +314 -0
- mlflow/store/db_migrations/__init__.py +0 -0
- mlflow/store/db_migrations/alembic.ini +74 -0
- mlflow/store/db_migrations/env.py +84 -0
- mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
- mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
- mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
- mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
- mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
- mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
- mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
- mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
- mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
- mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
- mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
- mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
- mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
- mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
- mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
- mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
- mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
- mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
- mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
- mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
- mlflow/store/db_migrations/versions/__init__.py +0 -0
- mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
- mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
- mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
- mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
- mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
- mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
- mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
- mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
- mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
- mlflow/store/entities/__init__.py +3 -0
- mlflow/store/entities/paged_list.py +18 -0
- mlflow/store/model_registry/__init__.py +10 -0
- mlflow/store/model_registry/abstract_store.py +1081 -0
- mlflow/store/model_registry/base_rest_store.py +44 -0
- mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
- mlflow/store/model_registry/dbmodels/__init__.py +0 -0
- mlflow/store/model_registry/dbmodels/models.py +206 -0
- mlflow/store/model_registry/file_store.py +1091 -0
- mlflow/store/model_registry/rest_store.py +481 -0
- mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
- mlflow/store/tracking/__init__.py +23 -0
- mlflow/store/tracking/abstract_store.py +816 -0
- mlflow/store/tracking/dbmodels/__init__.py +0 -0
- mlflow/store/tracking/dbmodels/initial_models.py +243 -0
- mlflow/store/tracking/dbmodels/models.py +1073 -0
- mlflow/store/tracking/file_store.py +2438 -0
- mlflow/store/tracking/postgres_managed_identity.py +146 -0
- mlflow/store/tracking/rest_store.py +1131 -0
- mlflow/store/tracking/sqlalchemy_store.py +2785 -0
- mlflow/system_metrics/__init__.py +61 -0
- mlflow/system_metrics/metrics/__init__.py +0 -0
- mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
- mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
- mlflow/system_metrics/metrics/disk_monitor.py +21 -0
- mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
- mlflow/system_metrics/metrics/network_monitor.py +34 -0
- mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
- mlflow/system_metrics/system_metrics_monitor.py +198 -0
- mlflow/tracing/__init__.py +16 -0
- mlflow/tracing/assessment.py +356 -0
- mlflow/tracing/client.py +531 -0
- mlflow/tracing/config.py +125 -0
- mlflow/tracing/constant.py +105 -0
- mlflow/tracing/destination.py +81 -0
- mlflow/tracing/display/__init__.py +40 -0
- mlflow/tracing/display/display_handler.py +196 -0
- mlflow/tracing/export/async_export_queue.py +186 -0
- mlflow/tracing/export/inference_table.py +138 -0
- mlflow/tracing/export/mlflow_v3.py +137 -0
- mlflow/tracing/export/utils.py +70 -0
- mlflow/tracing/fluent.py +1417 -0
- mlflow/tracing/processor/base_mlflow.py +199 -0
- mlflow/tracing/processor/inference_table.py +175 -0
- mlflow/tracing/processor/mlflow_v3.py +47 -0
- mlflow/tracing/processor/otel.py +73 -0
- mlflow/tracing/provider.py +487 -0
- mlflow/tracing/trace_manager.py +200 -0
- mlflow/tracing/utils/__init__.py +616 -0
- mlflow/tracing/utils/artifact_utils.py +28 -0
- mlflow/tracing/utils/copy.py +55 -0
- mlflow/tracing/utils/environment.py +55 -0
- mlflow/tracing/utils/exception.py +21 -0
- mlflow/tracing/utils/once.py +35 -0
- mlflow/tracing/utils/otlp.py +63 -0
- mlflow/tracing/utils/processor.py +54 -0
- mlflow/tracing/utils/search.py +292 -0
- mlflow/tracing/utils/timeout.py +250 -0
- mlflow/tracing/utils/token.py +19 -0
- mlflow/tracing/utils/truncation.py +124 -0
- mlflow/tracing/utils/warning.py +76 -0
- mlflow/tracking/__init__.py +39 -0
- mlflow/tracking/_model_registry/__init__.py +1 -0
- mlflow/tracking/_model_registry/client.py +764 -0
- mlflow/tracking/_model_registry/fluent.py +853 -0
- mlflow/tracking/_model_registry/registry.py +67 -0
- mlflow/tracking/_model_registry/utils.py +251 -0
- mlflow/tracking/_tracking_service/__init__.py +0 -0
- mlflow/tracking/_tracking_service/client.py +883 -0
- mlflow/tracking/_tracking_service/registry.py +56 -0
- mlflow/tracking/_tracking_service/utils.py +275 -0
- mlflow/tracking/artifact_utils.py +179 -0
- mlflow/tracking/client.py +5900 -0
- mlflow/tracking/context/__init__.py +0 -0
- mlflow/tracking/context/abstract_context.py +35 -0
- mlflow/tracking/context/databricks_cluster_context.py +15 -0
- mlflow/tracking/context/databricks_command_context.py +15 -0
- mlflow/tracking/context/databricks_job_context.py +49 -0
- mlflow/tracking/context/databricks_notebook_context.py +41 -0
- mlflow/tracking/context/databricks_repo_context.py +43 -0
- mlflow/tracking/context/default_context.py +51 -0
- mlflow/tracking/context/git_context.py +32 -0
- mlflow/tracking/context/registry.py +98 -0
- mlflow/tracking/context/system_environment_context.py +15 -0
- mlflow/tracking/default_experiment/__init__.py +1 -0
- mlflow/tracking/default_experiment/abstract_context.py +43 -0
- mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
- mlflow/tracking/default_experiment/registry.py +75 -0
- mlflow/tracking/fluent.py +3595 -0
- mlflow/tracking/metric_value_conversion_utils.py +93 -0
- mlflow/tracking/multimedia.py +206 -0
- mlflow/tracking/registry.py +86 -0
- mlflow/tracking/request_auth/__init__.py +0 -0
- mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
- mlflow/tracking/request_auth/registry.py +60 -0
- mlflow/tracking/request_header/__init__.py +0 -0
- mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
- mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
- mlflow/tracking/request_header/default_request_header_provider.py +17 -0
- mlflow/tracking/request_header/registry.py +79 -0
- mlflow/transformers/__init__.py +2982 -0
- mlflow/transformers/flavor_config.py +258 -0
- mlflow/transformers/hub_utils.py +83 -0
- mlflow/transformers/llm_inference_utils.py +468 -0
- mlflow/transformers/model_io.py +301 -0
- mlflow/transformers/peft.py +51 -0
- mlflow/transformers/signature.py +183 -0
- mlflow/transformers/torch_utils.py +55 -0
- mlflow/types/__init__.py +21 -0
- mlflow/types/agent.py +270 -0
- mlflow/types/chat.py +240 -0
- mlflow/types/llm.py +935 -0
- mlflow/types/responses.py +139 -0
- mlflow/types/responses_helpers.py +416 -0
- mlflow/types/schema.py +1505 -0
- mlflow/types/type_hints.py +647 -0
- mlflow/types/utils.py +753 -0
- mlflow/utils/__init__.py +283 -0
- mlflow/utils/_capture_modules.py +256 -0
- mlflow/utils/_capture_transformers_modules.py +75 -0
- mlflow/utils/_spark_utils.py +201 -0
- mlflow/utils/_unity_catalog_oss_utils.py +97 -0
- mlflow/utils/_unity_catalog_utils.py +479 -0
- mlflow/utils/annotations.py +218 -0
- mlflow/utils/arguments_utils.py +16 -0
- mlflow/utils/async_logging/__init__.py +1 -0
- mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
- mlflow/utils/async_logging/async_logging_queue.py +366 -0
- mlflow/utils/async_logging/run_artifact.py +38 -0
- mlflow/utils/async_logging/run_batch.py +58 -0
- mlflow/utils/async_logging/run_operations.py +49 -0
- mlflow/utils/autologging_utils/__init__.py +737 -0
- mlflow/utils/autologging_utils/client.py +432 -0
- mlflow/utils/autologging_utils/config.py +33 -0
- mlflow/utils/autologging_utils/events.py +294 -0
- mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
- mlflow/utils/autologging_utils/metrics_queue.py +71 -0
- mlflow/utils/autologging_utils/safety.py +1104 -0
- mlflow/utils/autologging_utils/versioning.py +95 -0
- mlflow/utils/checkpoint_utils.py +206 -0
- mlflow/utils/class_utils.py +6 -0
- mlflow/utils/cli_args.py +257 -0
- mlflow/utils/conda.py +354 -0
- mlflow/utils/credentials.py +231 -0
- mlflow/utils/data_utils.py +17 -0
- mlflow/utils/databricks_utils.py +1436 -0
- mlflow/utils/docstring_utils.py +477 -0
- mlflow/utils/doctor.py +133 -0
- mlflow/utils/download_cloud_file_chunk.py +43 -0
- mlflow/utils/env_manager.py +16 -0
- mlflow/utils/env_pack.py +131 -0
- mlflow/utils/environment.py +1009 -0
- mlflow/utils/exception_utils.py +14 -0
- mlflow/utils/file_utils.py +978 -0
- mlflow/utils/git_utils.py +77 -0
- mlflow/utils/gorilla.py +797 -0
- mlflow/utils/import_hooks/__init__.py +363 -0
- mlflow/utils/lazy_load.py +51 -0
- mlflow/utils/logging_utils.py +168 -0
- mlflow/utils/mime_type_utils.py +58 -0
- mlflow/utils/mlflow_tags.py +103 -0
- mlflow/utils/model_utils.py +486 -0
- mlflow/utils/name_utils.py +346 -0
- mlflow/utils/nfs_on_spark.py +62 -0
- mlflow/utils/openai_utils.py +164 -0
- mlflow/utils/os.py +12 -0
- mlflow/utils/oss_registry_utils.py +29 -0
- mlflow/utils/plugins.py +17 -0
- mlflow/utils/process.py +182 -0
- mlflow/utils/promptlab_utils.py +146 -0
- mlflow/utils/proto_json_utils.py +743 -0
- mlflow/utils/pydantic_utils.py +54 -0
- mlflow/utils/request_utils.py +279 -0
- mlflow/utils/requirements_utils.py +704 -0
- mlflow/utils/rest_utils.py +673 -0
- mlflow/utils/search_logged_model_utils.py +127 -0
- mlflow/utils/search_utils.py +2111 -0
- mlflow/utils/secure_loading.py +221 -0
- mlflow/utils/security_validation.py +384 -0
- mlflow/utils/server_cli_utils.py +61 -0
- mlflow/utils/spark_utils.py +15 -0
- mlflow/utils/string_utils.py +138 -0
- mlflow/utils/thread_utils.py +63 -0
- mlflow/utils/time.py +54 -0
- mlflow/utils/timeout.py +42 -0
- mlflow/utils/uri.py +572 -0
- mlflow/utils/validation.py +662 -0
- mlflow/utils/virtualenv.py +458 -0
- mlflow/utils/warnings_utils.py +25 -0
- mlflow/utils/yaml_utils.py +179 -0
- mlflow/version.py +24 -0
mlflow/shap/__init__.py
ADDED
@@ -0,0 +1,691 @@
|
|
1
|
+
import os
|
2
|
+
import tempfile
|
3
|
+
import types
|
4
|
+
import warnings
|
5
|
+
from contextlib import contextmanager
|
6
|
+
from typing import Any, Optional
|
7
|
+
|
8
|
+
import numpy as np
|
9
|
+
import yaml
|
10
|
+
|
11
|
+
import mlflow
|
12
|
+
import mlflow.utils.autologging_utils
|
13
|
+
from mlflow import pyfunc
|
14
|
+
from mlflow.models import Model, ModelInputExample, ModelSignature
|
15
|
+
from mlflow.models.model import MLMODEL_FILE_NAME
|
16
|
+
from mlflow.models.utils import _save_example
|
17
|
+
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
|
18
|
+
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
|
19
|
+
from mlflow.utils.docstring_utils import LOG_MODEL_PARAM_DOCS, format_docstring
|
20
|
+
from mlflow.utils.environment import (
|
21
|
+
_CONDA_ENV_FILE_NAME,
|
22
|
+
_CONSTRAINTS_FILE_NAME,
|
23
|
+
_PYTHON_ENV_FILE_NAME,
|
24
|
+
_REQUIREMENTS_FILE_NAME,
|
25
|
+
_get_pip_deps,
|
26
|
+
_mlflow_conda_env,
|
27
|
+
_process_conda_env,
|
28
|
+
_process_pip_requirements,
|
29
|
+
_PythonEnv,
|
30
|
+
_validate_env_arguments,
|
31
|
+
)
|
32
|
+
from mlflow.utils.file_utils import write_to
|
33
|
+
from mlflow.utils.model_utils import (
|
34
|
+
_add_code_from_conf_to_system_path,
|
35
|
+
_get_flavor_configuration,
|
36
|
+
_validate_and_copy_code_paths,
|
37
|
+
_validate_and_prepare_target_save_path,
|
38
|
+
)
|
39
|
+
from mlflow.utils.requirements_utils import _get_package_name
|
40
|
+
from mlflow.utils.uri import append_to_uri_path
|
41
|
+
|
42
|
+
# Name of this flavor as recorded in the MLmodel file.
FLAVOR_NAME = "shap"

# Upper bound on the number of k-means summarized rows used as KernelExplainer background data.
_MAXIMUM_BACKGROUND_DATA_SIZE = 100
# Default run-relative directory and file names used by ``log_explanation``.
_DEFAULT_ARTIFACT_PATH = "model_explanations_shap"
_SUMMARY_BAR_PLOT_FILE_NAME = "summary_bar_plot.png"
_BASE_VALUES_FILE_NAME = "base_values.npy"
_SHAP_VALUES_FILE_NAME = "shap_values.npy"
# Sentinel flavor meaning "MLflow cannot serialize the explainer's underlying model".
_UNKNOWN_MODEL_FLAVOR = "unknown"
# Sub-directory of a saved explainer that holds the MLflow-serialized underlying model.
_UNDERLYING_MODEL_SUBPATH = "underlying_model"
|
51
|
+
|
52
|
+
|
53
|
+
def get_underlying_model_flavor(model):
    """
    Determine which MLflow flavor, if any, can serialize the model wrapped by a SHAP explainer.

    Args:
        model: The ``model`` attribute of a SHAP explainer. A wrapped model exposes the
            original callable as ``inner_model``.

    Returns:
        ``mlflow.sklearn.FLAVOR_NAME`` when ``inner_model`` is a bound method of a scikit-learn
        estimator, ``mlflow.pytorch.FLAVOR_NAME`` when ``inner_model`` is a ``torch.nn.Module``,
        and ``_UNKNOWN_MODEL_FLAVOR`` otherwise (including when the relevant library is not
        installed).
    """
    # Nothing to inspect unless SHAP wrapped the model.
    if not hasattr(model, "inner_model"):
        return _UNKNOWN_MODEL_FLAVOR

    inner = model.inner_model

    # A bound method (e.g. ``estimator.predict``): the estimator is the object it is bound to.
    if isinstance(inner, types.MethodType):
        owner = inner.__self__
        try:
            import sklearn

            if issubclass(type(owner), sklearn.base.BaseEstimator):
                return mlflow.sklearn.FLAVOR_NAME
        except ImportError:
            pass

    # The wrapped object itself may be a PyTorch module.
    try:
        import torch

        if issubclass(type(inner), torch.nn.Module):
            return mlflow.pytorch.FLAVOR_NAME
    except ImportError:
        pass

    return _UNKNOWN_MODEL_FLAVOR
|
89
|
+
|
90
|
+
|
91
|
+
def get_default_pip_requirements():
    """
    Default pip requirements for MLflow Models produced by this flavor.

    The pip environment written by :func:`save_explainer()` and :func:`log_explainer()`
    contains at least these requirements.

    Returns:
        A one-element list pinning ``shap`` to the version importable in the current
        environment.
    """
    import shap

    installed_version = shap.__version__
    return [f"shap=={installed_version}"]
|
100
|
+
|
101
|
+
|
102
|
+
def get_default_conda_env():
    """
    Returns:
        The default Conda environment for MLflow Models produced by calls to
        :func:`save_explainer()` and :func:`log_explainer()`. It is built from
        :func:`get_default_pip_requirements()`.
    """
    pip_deps = get_default_pip_requirements()
    return _mlflow_conda_env(additional_pip_deps=pip_deps)
|
109
|
+
|
110
|
+
|
111
|
+
def _load_pyfunc(path):
    """
    Loader hook for this flavor, called by ``pyfunc.load_model``.

    Args:
        path: Local path to the saved MLflow model directory.

    Returns:
        A ``_SHAPWrapper`` around the explainer stored under ``path``.
    """
    wrapper = _SHAPWrapper(path)
    return wrapper
|
116
|
+
|
117
|
+
|
118
|
+
@contextmanager
def _log_artifact_contextmanager(out_file, artifact_path=None):
    """
    Yield a temporary file path and log that file as a run artifact when the block exits.

    The caller writes to the yielded path inside the ``with`` block. Afterwards the file is
    logged with ``mlflow.log_artifact`` and the temporary directory is removed. If the
    ``with`` body raises, nothing is logged.

    Args:
        out_file: Name of the file to create inside a fresh temporary directory.
        artifact_path: Optional run-relative directory to log the file into.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        local_file = os.path.join(scratch_dir, out_file)
        yield local_file
        mlflow.log_artifact(local_file, artifact_path)
|
127
|
+
|
128
|
+
|
129
|
+
def _log_numpy(numpy_obj, out_file, artifact_path=None):
    """
    Serialize ``numpy_obj`` with ``np.save`` and log it as a run artifact.

    Args:
        numpy_obj: Array (or array-like) to save.
        out_file: Artifact file name. ``np.save`` appends ``.npy`` when the name lacks it.
        artifact_path: Optional run-relative directory to log the file into.
    """
    with _log_artifact_contextmanager(out_file, artifact_path) as local_file:
        np.save(local_file, numpy_obj)
|
135
|
+
|
136
|
+
|
137
|
+
def _log_matplotlib_figure(fig, out_file, artifact_path=None):
    """
    Render a matplotlib figure with ``fig.savefig`` and log the image as a run artifact.

    Args:
        fig: The matplotlib figure to save.
        out_file: Artifact file name. Its extension selects the image format used by ``savefig``.
        artifact_path: Optional run-relative directory to log the file into.
    """
    with _log_artifact_contextmanager(out_file, artifact_path) as local_file:
        fig.savefig(local_file)
|
143
|
+
|
144
|
+
|
145
|
+
def _get_conda_env_for_underlying_model(underlying_model_path):
    """
    Read the conda environment stored with the MLflow-serialized underlying model.

    Args:
        underlying_model_path: Local directory the underlying model was saved to.

    Returns:
        The YAML content of the model's conda environment file, parsed with ``yaml.safe_load``.
    """
    # Use the same file-name constant MLflow uses when writing conda environments, rather than
    # a hard-coded literal that could drift from it.
    underlying_model_conda_path = os.path.join(underlying_model_path, _CONDA_ENV_FILE_NAME)
    with open(underlying_model_conda_path) as underlying_model_conda_file:
        return yaml.safe_load(underlying_model_conda_file)
|
149
|
+
|
150
|
+
|
151
|
+
def log_explanation(predict_function, features, artifact_path=None):
    r"""
    Given a ``predict_function`` capable of computing ML model output on the provided ``features``,
    computes and logs explanations of an ML model's output. Explanations are logged as a directory
    of artifacts containing the following items generated by `SHAP`_ (SHapley Additive
    exPlanations).

        - Base values
        - SHAP values (computed using `shap.KernelExplainer`_)
        - Summary bar plot (shows the average impact of each feature on model output)

    Args:
        predict_function:
            A function to compute the output of a model (e.g. ``predict_proba`` method of
            scikit-learn classifiers). Must have the following signature:

            .. code-block:: python

                def predict_function(X) -> pred: ...

            - ``X``: An array-like object whose shape should be (# samples, # features).
            - ``pred``: An array-like object whose shape should be (# samples) for a regressor or
              (# samples, # classes) for a classifier. For a classifier, the values in ``pred``
              should correspond to the predicted probability of each class.

            Acceptable array-like object types:

                - ``numpy.array``
                - ``pandas.DataFrame``
                - ``shap.common.DenseData``
                - ``scipy.sparse matrix``

        features:
            A matrix of features to compute SHAP values with. The provided features should
            have shape (# samples, # features), and can be either of the array-like object
            types listed above.

            .. note::
                Background data for `shap.KernelExplainer`_ is generated by subsampling ``features``
                with `shap.kmeans`_. The background data size is limited to 100 rows for performance
                reasons.

        artifact_path:
            The run-relative artifact path to which the explanation is saved.
            If unspecified, defaults to "model_explanations_shap".

    Returns:
        Artifact URI of the logged explanations.

    .. _SHAP: https://github.com/slundberg/shap

    .. _shap.KernelExplainer: https://shap.readthedocs.io/en/latest/generated
        /shap.KernelExplainer.html#shap.KernelExplainer

    .. _shap.kmeans: https://github.com/slundberg/shap/blob/v0.36.0/shap/utils/_legacy.py#L9

    .. code-block:: python
        :caption: Example

        import os

        import numpy as np
        import pandas as pd
        from sklearn.datasets import load_diabetes
        from sklearn.linear_model import LinearRegression

        import mlflow
        from mlflow import MlflowClient

        # prepare training data
        dataset = load_diabetes()
        X = pd.DataFrame(dataset.data[:50, :8], columns=dataset.feature_names[:8])
        y = dataset.target[:50]

        # train a model
        model = LinearRegression()
        model.fit(X, y)

        # log an explanation
        with mlflow.start_run() as run:
            mlflow.shap.log_explanation(model.predict, X)

        # list artifacts
        client = MlflowClient()
        artifact_path = "model_explanations_shap"
        artifacts = [x.path for x in client.list_artifacts(run.info.run_id, artifact_path)]
        print("# artifacts:")
        print(artifacts)

        # load back the logged explanation
        dst_path = client.download_artifacts(run.info.run_id, artifact_path)
        base_values = np.load(os.path.join(dst_path, "base_values.npy"))
        shap_values = np.load(os.path.join(dst_path, "shap_values.npy"))

        print("\n# base_values:")
        print(base_values)
        print("\n# shap_values:")
        print(shap_values[:3])

    .. code-block:: text
        :caption: Output (illustrative; numeric values depend on the data and SHAP version)

        # artifacts:
        ['model_explanations_shap/base_values.npy',
         'model_explanations_shap/shap_values.npy',
         'model_explanations_shap/summary_bar_plot.png']

        # base_values:
        20.502000000000002

        # shap_values:
        [[ 2.09975523  0.4746513   7.63759026  0.        ]
         [ 2.00883109 -0.18816665 -0.14419184  0.        ]
         [ 2.00891772 -0.18816665 -0.14419184  0.        ]]

    .. figure:: ../_static/images/shap-ui-screenshot.png

        Logged artifacts
    """
    import matplotlib.pyplot as plt
    import shap

    artifact_path = _DEFAULT_ARTIFACT_PATH if artifact_path is None else artifact_path

    # ``len`` raises TypeError for scipy sparse matrices (an accepted input type), so use the
    # row count from ``shape`` whenever the object exposes one.
    n_rows = features.shape[0] if hasattr(features, "shape") else len(features)

    with mlflow.utils.autologging_utils.disable_autologging():
        # Summarize the features into at most _MAXIMUM_BACKGROUND_DATA_SIZE k-means rows to keep
        # KernelExplainer tractable.
        background_data = shap.kmeans(features, min(_MAXIMUM_BACKGROUND_DATA_SIZE, n_rows))
        explainer = shap.KernelExplainer(predict_function, background_data)
        shap_values = explainer.shap_values(features)

    _log_numpy(explainer.expected_value, _BASE_VALUES_FILE_NAME, artifact_path)
    _log_numpy(shap_values, _SHAP_VALUES_FILE_NAME, artifact_path)

    # ``show=False`` leaves the plot on the current pyplot figure so it can be saved below.
    shap.summary_plot(shap_values, features, plot_type="bar", show=False)
    fig = plt.gcf()
    try:
        fig.tight_layout()
        _log_matplotlib_figure(fig, _SUMMARY_BAR_PLOT_FILE_NAME, artifact_path)
    finally:
        # Release the figure even if logging fails, so repeated calls do not leak figures.
        plt.close(fig)

    return append_to_uri_path(mlflow.active_run().info.artifact_uri, artifact_path)
|
289
|
+
|
290
|
+
|
291
|
+
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
def log_explainer(
    explainer,
    artifact_path: Optional[str] = None,
    serialize_model_using_mlflow=True,
    conda_env=None,
    code_paths=None,
    registered_model_name=None,
    signature: ModelSignature = None,
    input_example: ModelInputExample = None,
    await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
    pip_requirements=None,
    extra_pip_requirements=None,
    name: Optional[str] = None,
    metadata=None,
    params: Optional[dict[str, Any]] = None,
    tags: Optional[dict[str, Any]] = None,
    model_type: Optional[str] = None,
    step: int = 0,
    model_id: Optional[str] = None,
):
    """
    Log a SHAP explainer as an MLflow model for the current run.

    This delegates to :py:meth:`Model.log <mlflow.models.Model.log>` with this module as the
    flavor.

    Args:
        explainer: SHAP explainer to be saved.
        artifact_path: Deprecated. Use `name` instead.
        serialize_model_using_mlflow: When set to True, MLflow will extract the underlying
            model and serialize it as an MLmodel, otherwise it uses SHAP's internal serialization.
            Defaults to True. Currently MLflow serialization is only supported for models of
            'sklearn' or 'pytorch' flavors.
        conda_env: {{ conda_env }}
        code_paths: {{ code_paths }}
        registered_model_name: If given, create a model version under ``registered_model_name``,
            also creating a registered model if one with the given name does not exist.
        signature: :py:class:`ModelSignature <mlflow.models.ModelSignature>` describes model input
            and output :py:class:`Schema <mlflow.types.Schema>`. The model signature can be
            :py:func:`inferred <mlflow.models.infer_signature>` from datasets with valid model input
            (e.g. the training dataset with target column omitted) and valid model output
            (e.g. model predictions generated on the training dataset), for example:

            .. code-block:: python

                from mlflow.models import infer_signature

                train = df.drop_column("target_label")
                predictions = ...  # compute model predictions
                signature = infer_signature(train, predictions)
        input_example: {{ input_example }}
        await_registration_for: Number of seconds to wait for the model version to finish
            being created and is in ``READY`` status. By default, the function waits for five
            minutes. Specify 0 or None to skip waiting.
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        name: {{ name }}
        metadata: {{ metadata }}
        params: {{ params }}
        tags: {{ tags }}
        model_type: {{ model_type }}
        step: {{ step }}
        model_id: {{ model_id }}
    """
    return Model.log(
        # Destination / identity of the logged model.
        artifact_path=artifact_path,
        name=name,
        model_id=model_id,
        step=step,
        # Flavor and the object being saved.
        flavor=mlflow.shap,
        explainer=explainer,
        serialize_model_using_mlflow=serialize_model_using_mlflow,
        code_paths=code_paths,
        # Environment specification.
        conda_env=conda_env,
        pip_requirements=pip_requirements,
        extra_pip_requirements=extra_pip_requirements,
        # Descriptive model metadata.
        signature=signature,
        input_example=input_example,
        metadata=metadata,
        params=params,
        tags=tags,
        model_type=model_type,
        # Model registry.
        registered_model_name=registered_model_name,
        await_registration_for=await_registration_for,
    )
|
375
|
+
|
376
|
+
|
377
|
+
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
def save_explainer(
    explainer,
    path,
    serialize_model_using_mlflow=True,
    conda_env=None,
    code_paths=None,
    mlflow_model=None,
    signature: ModelSignature = None,
    input_example: ModelInputExample = None,
    pip_requirements=None,
    extra_pip_requirements=None,
    metadata=None,
):
    """
    Save a SHAP explainer to a path on the local file system. Produces an MLflow Model
    containing the following flavors:

        - :py:mod:`mlflow.shap`
        - :py:mod:`mlflow.pyfunc`

    Args:
        explainer: SHAP explainer to be saved.
        path: Local path where the explainer is to be saved.
        serialize_model_using_mlflow: When set to True, MLflow will extract the underlying
            model and serialize it as an MLmodel, otherwise it uses SHAP's internal serialization.
            Defaults to True. Currently MLflow serialization is only supported for models of
            'sklearn' or 'pytorch' flavors.
        conda_env: {{ conda_env }}
        code_paths: {{ code_paths }}
        mlflow_model: :py:class:`mlflow.models.Model` this flavor is being added to.
        signature: :py:class:`ModelSignature <mlflow.models.ModelSignature>` describes model input
            and output :py:class:`Schema <mlflow.types.Schema>`. The model signature can be
            :py:func:`inferred <mlflow.models.infer_signature>` from datasets with valid model input
            (e.g. the training dataset with target column omitted) and valid model output (e.g.
            model predictions generated on the training dataset), for example:

            .. code-block:: python

                from mlflow.models import infer_signature

                train = df.drop_column("target_label")
                predictions = ...  # compute model predictions
                signature = infer_signature(train, predictions)
        input_example: {{ input_example }}
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        metadata: {{ metadata }}
    """
    import shap

    _validate_env_arguments(conda_env, pip_requirements, extra_pip_requirements)

    _validate_and_prepare_target_save_path(path)
    code_dir_subpath = _validate_and_copy_code_paths(code_paths, path)

    if mlflow_model is None:
        mlflow_model = Model()
    if signature is not None:
        mlflow_model.signature = signature
    if input_example is not None:
        _save_example(mlflow_model, input_example, path)
    if metadata is not None:
        mlflow_model.metadata = metadata

    underlying_model_flavor = None
    underlying_model_path = None
    # True only when MLflow saves the underlying model itself, in which case SHAP must not
    # embed it in the serialized explainer. Can only become True if
    # serialize_model_using_mlflow is set.
    serializable_by_mlflow = False

    # Decide whether (and how) MLflow should save the underlying model.
    if serialize_model_using_mlflow:
        underlying_model_flavor = get_underlying_model_flavor(explainer.model)

        if underlying_model_flavor != _UNKNOWN_MODEL_FLAVOR:
            serializable_by_mlflow = True
            underlying_model_path = os.path.join(path, _UNDERLYING_MODEL_SUBPATH)
        else:
            warnings.warn(
                "Unable to serialize underlying model using MLflow, will use SHAP serialization"
            )

    if underlying_model_flavor == mlflow.sklearn.FLAVOR_NAME:
        # The sklearn flavor is only reported for a bound method, so save the estimator
        # the method is bound to.
        mlflow.sklearn.save_model(explainer.model.inner_model.__self__, underlying_model_path)
    elif underlying_model_flavor == mlflow.pytorch.FLAVOR_NAME:
        mlflow.pytorch.save_model(explainer.model.inner_model, underlying_model_path)

    # Save the explainer object itself.
    explainer_data_subpath = "explainer.shap"
    explainer_output_path = os.path.join(path, explainer_data_subpath)
    with open(explainer_output_path, "wb") as explainer_output_file_handle:
        if serializable_by_mlflow:
            # The model was saved by MLflow above; keep it out of SHAP's payload.
            explainer.save(explainer_output_file_handle, model_saver=False)
        else:
            explainer.save(explainer_output_file_handle)

    pyfunc.add_to_model(
        mlflow_model,
        loader_module="mlflow.shap",
        model_path=explainer_data_subpath,
        underlying_model_flavor=underlying_model_flavor,
        conda_env=_CONDA_ENV_FILE_NAME,
        python_env=_PYTHON_ENV_FILE_NAME,
        code=code_dir_subpath,
    )

    mlflow_model.add_flavor(
        FLAVOR_NAME,
        shap_version=shap.__version__,
        serialized_explainer=explainer_data_subpath,
        underlying_model_flavor=underlying_model_flavor,
        code=code_dir_subpath,
    )

    mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))

    if conda_env is None:
        if pip_requirements is None:
            default_reqs = get_default_pip_requirements()
            # To ensure `_load_pyfunc` can successfully load the model during the dependency
            # inference, `mlflow_model.save` must be called beforehand to save an MLmodel file.
            inferred_reqs = mlflow.models.infer_pip_requirements(
                path,
                FLAVOR_NAME,
                fallback=default_reqs,
            )
            default_reqs = sorted(set(inferred_reqs).union(default_reqs))
        else:
            default_reqs = None
        conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
            default_reqs,
            pip_requirements,
            extra_pip_requirements,
        )
    else:
        conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)

    # Fold the underlying model's own environment into the explainer's environment.
    if underlying_model_path is not None:
        underlying_model_conda_env = _get_conda_env_for_underlying_model(underlying_model_path)
        conda_env = _merge_environments(conda_env, underlying_model_conda_env)
        pip_requirements = _get_pip_deps(conda_env)

    with open(os.path.join(path, _CONDA_ENV_FILE_NAME), "w") as f:
        yaml.safe_dump(conda_env, stream=f, default_flow_style=False)

    # Save `constraints.txt` if necessary
    if pip_constraints:
        write_to(os.path.join(path, _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints))

    # Save `requirements.txt`
    write_to(os.path.join(path, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))

    _PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME))
|
529
|
+
|
530
|
+
|
531
|
+
# ``Model.log`` requires the flavor module to expose ``save_model``; for this flavor it is
# simply ``save_explainer`` under another name.
save_model = save_explainer
|
533
|
+
|
534
|
+
|
535
|
+
def _get_conda_and_pip_dependencies(conda_env):
    """
    Split a conda environment's ``dependencies`` into conda and pip requirement lists.

    Args:
        conda_env: Conda environment as a dict. Its ``dependencies`` entries are either conda
            requirement strings or dicts holding a ``pip`` list.

    Returns:
        A ``(conda_deps, pip_deps)`` tuple.

        - ``conda_deps`` keeps the original order. It omits ``python``, ``pip``, and entries
          for which ``_get_package_name`` returns None.
        - ``pip_deps`` is sorted and de-duplicated, and omits the bare ``"mlflow"`` requirement.

        A missing ``dependencies`` key yields two empty lists.
    """
    conda_deps = []
    # NB: Set operations are required in case there are multiple references of MLflow as a
    # dependency to ensure that duplicate entries are not present in the final consolidated
    # dependency list.
    pip_deps_set = set()

    for dependency in conda_env.get("dependencies") or []:
        if isinstance(dependency, dict):
            # Nested mapping such as ``{"pip": [...]}``. A missing or empty ``pip`` list used to
            # fall through to ``_get_package_name`` (which presumably expects a requirement
            # string) and fail; treat it as contributing nothing instead.
            for pip_dependency in dependency.get("pip") or []:
                if pip_dependency != "mlflow":
                    pip_deps_set.add(pip_dependency)
            continue

        package_name = _get_package_name(dependency)
        if package_name is not None and package_name not in ("python", "pip"):
            conda_deps.append(dependency)

    return conda_deps, sorted(pip_deps_set)
|
560
|
+
|
561
|
+
|
562
|
+
def _union_lists(l1, l2):
    """
    Return a new list with the items of ``l1`` followed by those of ``l2``.

    Duplicates are dropped, keeping the first occurrence, and the original order is preserved.
    Items must be hashable.
    """
    seen_in_order = {}
    for item in l1 + l2:
        seen_in_order.setdefault(item, None)
    return list(seen_in_order)
|
567
|
+
|
568
|
+
|
569
|
+
def _merge_environments(shap_environment, model_environment):
    """
    Merge the conda environments of SHAP and the underlying model.

    Args:
        shap_environment: SHAP conda environment (dict).
        model_environment: Underlying model conda environment (dict).

    Returns:
        A new environment built by ``_mlflow_conda_env``. It uses the union of both inputs'
        channels (minus ``conda-forge``), conda dependencies, and pip dependencies.
    """
    # Merge the channels from the two environments and remove the default conda channel if
    # present, since it is added later in `_mlflow_conda_env`. ``channels`` is optional in a
    # conda environment (e.g. a user-supplied ``conda_env``), so a missing key means none.
    merged_conda_channels = _union_lists(
        shap_environment.get("channels") or [], model_environment.get("channels") or []
    )
    merged_conda_channels = [x for x in merged_conda_channels if x != "conda-forge"]

    shap_conda_deps, shap_pip_deps = _get_conda_and_pip_dependencies(shap_environment)
    model_conda_deps, model_pip_deps = _get_conda_and_pip_dependencies(model_environment)

    merged_conda_deps = _union_lists(shap_conda_deps, model_conda_deps)
    merged_pip_deps = _union_lists(shap_pip_deps, model_pip_deps)
    return _mlflow_conda_env(
        additional_conda_deps=merged_conda_deps,
        additional_pip_deps=merged_pip_deps,
        additional_conda_channels=merged_conda_channels,
    )
|
594
|
+
|
595
|
+
|
596
|
+
def load_explainer(model_uri):
    """
    Load a SHAP explainer from a local file or a run.

    Args:
        model_uri: The location, in URI format, of the MLflow model. For example:

            - ``/Users/me/path/to/local/model``
            - ``relative/path/to/local/model``
            - ``s3://my_bucket/path/to/model``
            - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
            - ``models:/<model_name>/<model_version>``
            - ``models:/<model_name>/<stage>``

            For more information about supported URI schemes, see
            `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
            artifact-locations>`_.

    Returns:
        A SHAP explainer.
    """
    local_model_dir = _download_artifact_from_uri(artifact_uri=model_uri)
    conf = _get_flavor_configuration(model_path=local_model_dir, flavor_name=FLAVOR_NAME)
    _add_code_from_conf_to_system_path(local_model_dir, conf)

    serialized_explainer_file = os.path.join(local_model_dir, conf["serialized_explainer"])
    model_flavor = conf["underlying_model_flavor"]

    # Re-create the underlying model when MLflow (not SHAP) serialized it.
    restored_model = None
    if model_flavor != _UNKNOWN_MODEL_FLAVOR:
        model_dir = os.path.join(local_model_dir, _UNDERLYING_MODEL_SUBPATH)
        if model_flavor == mlflow.sklearn.FLAVOR_NAME:
            # NOTE(review): only ``predict`` is restored, even if the explainer originally
            # wrapped a different bound method (e.g. ``predict_proba``) — confirm intended.
            restored_model = mlflow.sklearn._load_pyfunc(model_dir).predict
        elif model_flavor == mlflow.pytorch.FLAVOR_NAME:
            restored_model = mlflow.pytorch._load_model(os.path.join(model_dir, "data"))

    return _load_explainer(explainer_file=serialized_explainer_file, model=restored_model)
|
633
|
+
|
634
|
+
|
635
|
+
def _load_explainer(explainer_file, model=None):
    """
    Load a SHAP explainer saved as an MLflow artifact on the local file system.

    Args:
        explainer_file: Local filesystem path to the MLflow Model saved with the ``shap`` flavor.
        model: Model to override underlying explainer model. When None, SHAP restores the
            model from the serialized explainer itself.

    Returns:
        The explainer returned by ``shap.Explainer.load``.
    """
    import shap

    def inject_model_loader(_in_file):
        # Ignore the serialized stream and hand SHAP the model MLflow already restored.
        return model

    load_kwargs = {} if model is None else {"model_loader": inject_model_loader}
    with open(explainer_file, "rb") as stream:
        return shap.Explainer.load(stream, **load_kwargs)
|
655
|
+
|
656
|
+
|
657
|
+
class _SHAPWrapper:
    """
    pyfunc adapter around a SHAP explainer loaded from a saved MLflow model directory.

    ``predict`` returns the explanation values computed by the explainer.
    """

    def __init__(self, path):
        """
        Args:
            path: Local path to the saved MLflow model directory.
        """
        flavor_conf = _get_flavor_configuration(model_path=path, flavor_name=FLAVOR_NAME)
        shap_explainer_artifacts_path = os.path.join(path, flavor_conf["serialized_explainer"])
        underlying_model_flavor = flavor_conf["underlying_model_flavor"]
        model = None
        # Re-create the underlying model when MLflow (not SHAP) serialized it.
        if underlying_model_flavor != _UNKNOWN_MODEL_FLAVOR:
            underlying_model_path = os.path.join(path, _UNDERLYING_MODEL_SUBPATH)
            if underlying_model_flavor == mlflow.sklearn.FLAVOR_NAME:
                model = mlflow.sklearn._load_pyfunc(underlying_model_path).predict
            elif underlying_model_flavor == mlflow.pytorch.FLAVOR_NAME:
                model = mlflow.pytorch._load_model(os.path.join(underlying_model_path, "data"))

        self.explainer = _load_explainer(explainer_file=shap_explainer_artifacts_path, model=model)

    def get_raw_model(self):
        """
        Returns the wrapped SHAP explainer.
        """
        return self.explainer

    def predict(
        self,
        dataframe,
        params: Optional[dict[str, Any]] = None,
    ):
        """
        Args:
            dataframe: Model input data. A ``numpy.ndarray`` is used as-is. Anything else
                (e.g. a ``pandas.DataFrame``) is converted via its ``.values`` attribute.
            params: Additional parameters to pass to the model for inference. Currently unused.

        Returns:
            The ``values`` of the explanation produced by calling the explainer on the input.
        """
        # A plain ndarray has no ``.values`` attribute; pyfunc can pass one (presumably for
        # tensor-based signatures), so pass it through instead of crashing.
        data = dataframe if isinstance(dataframe, np.ndarray) else dataframe.values
        return self.explainer(data).values
|