sagemaker-core 1.0.62__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.62.dist-info/RECORD +0 -35
- sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
- {sagemaker_core/helper → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
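Most of the file list above is a repackaging: modules that previously shipped under the `sagemaker_core` (and `sagemaker_core/main`) top-level package move into the new `sagemaker/core` namespace, alongside a large body of newly vendored functionality (clarify, debugger, jumpstart, model_monitor, remote_function, workflow, and so on). Below is a minimal sketch of the import migration implied by the `{sagemaker_core/main → sagemaker/core}` moves; `TrainingJob` and `StoppingCondition` are illustrative class names from the generated modules, not names this diff itself lists:

```python
# sagemaker-core 1.0.62: generated resources and shapes lived under
# sagemaker_core.main (left-hand side of the renames above):
#
#   from sagemaker_core.main.resources import TrainingJob
#   from sagemaker_core.main.shapes import StoppingCondition
#
# sagemaker-core 2.1.1: the same modules now live in the sagemaker.core
# namespace (right-hand side of the renames above):
from sagemaker.core.resources import TrainingJob
from sagemaker.core.shapes import StoppingCondition
```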
|
@@ -0,0 +1,2898 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module configures the SageMaker Clarify bias and model explainability processor jobs.
|
|
14
|
+
|
|
15
|
+
SageMaker Clarify
|
|
16
|
+
==================
|
|
17
|
+
"""
|
|
18
|
+
from __future__ import absolute_import, print_function
|
|
19
|
+
|
|
20
|
+
import copy
|
|
21
|
+
import json
|
|
22
|
+
import logging
|
|
23
|
+
import os
|
|
24
|
+
import re
|
|
25
|
+
|
|
26
|
+
import tempfile
|
|
27
|
+
from abc import ABC, abstractmethod
|
|
28
|
+
from typing import List, Literal, Union, Dict, Optional, Any
|
|
29
|
+
from enum import Enum
|
|
30
|
+
from schema import Schema, And, Use, Or, Optional as SchemaOptional, Regex
|
|
31
|
+
from sagemaker.core import s3
|
|
32
|
+
from sagemaker.core import image_uris
|
|
33
|
+
from sagemaker.core.helper.session_helper import Session
|
|
34
|
+
from sagemaker.core.network import NetworkConfig
|
|
35
|
+
from sagemaker.core.shapes import ProcessingInput, ProcessingOutput
|
|
36
|
+
from sagemaker.core.processing import Processor
|
|
37
|
+
from sagemaker.core.common_utils import (
|
|
38
|
+
format_tags,
|
|
39
|
+
Tags,
|
|
40
|
+
name_from_base,
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
logger = logging.getLogger(__name__)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
ENDPOINT_NAME_PREFIX_PATTERN = "^[a-zA-Z0-9](-*[a-zA-Z0-9])"
|
|
47
|
+
|
|
48
|
+
# asym shap val config default values (timeseries)
|
|
49
|
+
ASYM_SHAP_VAL_DEFAULT_EXPLANATION_DIRECTION = "chronological"
|
|
50
|
+
ASYM_SHAP_VAL_DEFAULT_EXPLANATION_GRANULARITY = "timewise"
|
|
51
|
+
ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS = [
|
|
52
|
+
"chronological",
|
|
53
|
+
"anti_chronological",
|
|
54
|
+
"bidirectional",
|
|
55
|
+
]
|
|
56
|
+
ASYM_SHAP_VAL_GRANULARITIES = [
|
|
57
|
+
"timewise",
|
|
58
|
+
"fine_grained",
|
|
59
|
+
]
|
|
60
|
+
|
|
61
|
+
ANALYSIS_CONFIG_SCHEMA_V1_0 = Schema(
|
|
62
|
+
{
|
|
63
|
+
SchemaOptional("version"): str,
|
|
64
|
+
"dataset_type": And(
|
|
65
|
+
str,
|
|
66
|
+
Use(str.lower),
|
|
67
|
+
lambda s: s
|
|
68
|
+
in (
|
|
69
|
+
"text/csv",
|
|
70
|
+
"application/jsonlines",
|
|
71
|
+
"application/json",
|
|
72
|
+
"application/sagemakercapturejson",
|
|
73
|
+
"application/x-parquet",
|
|
74
|
+
"application/x-image",
|
|
75
|
+
),
|
|
76
|
+
),
|
|
77
|
+
SchemaOptional("dataset_uri"): str,
|
|
78
|
+
SchemaOptional("headers"): [str],
|
|
79
|
+
SchemaOptional("label"): Or(str, int),
|
|
80
|
+
# this field indicates user provides predicted_label in dataset
|
|
81
|
+
SchemaOptional("predicted_label"): Or(str, int),
|
|
82
|
+
SchemaOptional("features"): str,
|
|
83
|
+
SchemaOptional("label_values_or_threshold"): [Or(int, float, str)],
|
|
84
|
+
SchemaOptional("probability_threshold"): float,
|
|
85
|
+
SchemaOptional("segment_config"): [
|
|
86
|
+
{
|
|
87
|
+
SchemaOptional("config_name"): str,
|
|
88
|
+
"name_or_index": Or(str, int),
|
|
89
|
+
"segments": [[Or(str, int)]],
|
|
90
|
+
SchemaOptional("display_aliases"): [str],
|
|
91
|
+
}
|
|
92
|
+
],
|
|
93
|
+
SchemaOptional("facet"): [
|
|
94
|
+
{
|
|
95
|
+
"name_or_index": Or(str, int),
|
|
96
|
+
SchemaOptional("value_or_threshold"): [Or(int, float, str)],
|
|
97
|
+
}
|
|
98
|
+
],
|
|
99
|
+
SchemaOptional("facet_dataset_uri"): str,
|
|
100
|
+
SchemaOptional("facet_headers"): [str],
|
|
101
|
+
SchemaOptional("predicted_label_dataset_uri"): str,
|
|
102
|
+
SchemaOptional("predicted_label_headers"): [str],
|
|
103
|
+
SchemaOptional("excluded_columns"): [Or(int, str)],
|
|
104
|
+
SchemaOptional("joinsource_name_or_index"): Or(str, int),
|
|
105
|
+
SchemaOptional("group_variable"): Or(str, int),
|
|
106
|
+
SchemaOptional("time_series_data_config"): {
|
|
107
|
+
"target_time_series": Or(str, int),
|
|
108
|
+
"item_id": Or(str, int),
|
|
109
|
+
"timestamp": Or(str, int),
|
|
110
|
+
SchemaOptional("related_time_series"): Or([str], [int]),
|
|
111
|
+
SchemaOptional("static_covariates"): Or([str], [int]),
|
|
112
|
+
SchemaOptional("dataset_format"): And(
|
|
113
|
+
str,
|
|
114
|
+
Use(str.lower),
|
|
115
|
+
lambda s: s
|
|
116
|
+
in (
|
|
117
|
+
"columns",
|
|
118
|
+
"item_records",
|
|
119
|
+
"timestamp_records",
|
|
120
|
+
),
|
|
121
|
+
),
|
|
122
|
+
},
|
|
123
|
+
"methods": {
|
|
124
|
+
SchemaOptional("shap"): {
|
|
125
|
+
SchemaOptional("baseline"): Or(
|
|
126
|
+
# URI of the baseline data file
|
|
127
|
+
str,
|
|
128
|
+
# Inplace baseline data (a list of something)
|
|
129
|
+
[
|
|
130
|
+
Or(
|
|
131
|
+
# CSV row
|
|
132
|
+
[Or(int, float, str, None)],
|
|
133
|
+
# JSON row (any JSON object). As I write this only
|
|
134
|
+
# SageMaker JSONLines Dense Format ([1])
|
|
135
|
+
# is supported and the validation is NOT done
|
|
136
|
+
# by the schema but by the data loader.
|
|
137
|
+
# [1] https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-inference.html#cm-jsonlines
|
|
138
|
+
{object: object},
|
|
139
|
+
)
|
|
140
|
+
],
|
|
141
|
+
# Arbitrary JSON object as baseline
|
|
142
|
+
{object: object},
|
|
143
|
+
),
|
|
144
|
+
SchemaOptional("num_clusters"): int,
|
|
145
|
+
SchemaOptional("use_logit"): bool,
|
|
146
|
+
SchemaOptional("num_samples"): int,
|
|
147
|
+
SchemaOptional("agg_method"): And(
|
|
148
|
+
str, Use(str.lower), lambda s: s in ("mean_abs", "median", "mean_sq")
|
|
149
|
+
),
|
|
150
|
+
SchemaOptional("save_local_shap_values"): bool,
|
|
151
|
+
SchemaOptional("text_config"): {
|
|
152
|
+
"granularity": And(
|
|
153
|
+
str, Use(str.lower), lambda s: s in ("token", "sentence", "paragraph")
|
|
154
|
+
),
|
|
155
|
+
"language": And(
|
|
156
|
+
str,
|
|
157
|
+
Use(str.lower),
|
|
158
|
+
lambda s: s
|
|
159
|
+
in (
|
|
160
|
+
"chinese",
|
|
161
|
+
"zh",
|
|
162
|
+
"danish",
|
|
163
|
+
"da",
|
|
164
|
+
"dutch",
|
|
165
|
+
"nl",
|
|
166
|
+
"english",
|
|
167
|
+
"en",
|
|
168
|
+
"french",
|
|
169
|
+
"fr",
|
|
170
|
+
"german",
|
|
171
|
+
"de",
|
|
172
|
+
"greek",
|
|
173
|
+
"el",
|
|
174
|
+
"italian",
|
|
175
|
+
"it",
|
|
176
|
+
"japanese",
|
|
177
|
+
"ja",
|
|
178
|
+
"lithuanian",
|
|
179
|
+
"lt",
|
|
180
|
+
"multi-language",
|
|
181
|
+
"xx",
|
|
182
|
+
"norwegian bokmål",
|
|
183
|
+
"nb",
|
|
184
|
+
"polish",
|
|
185
|
+
"pl",
|
|
186
|
+
"portuguese",
|
|
187
|
+
"pt",
|
|
188
|
+
"romanian",
|
|
189
|
+
"ro",
|
|
190
|
+
"russian",
|
|
191
|
+
"ru",
|
|
192
|
+
"spanish",
|
|
193
|
+
"es",
|
|
194
|
+
"afrikaans",
|
|
195
|
+
"af",
|
|
196
|
+
"albanian",
|
|
197
|
+
"sq",
|
|
198
|
+
"arabic",
|
|
199
|
+
"ar",
|
|
200
|
+
"armenian",
|
|
201
|
+
"hy",
|
|
202
|
+
"basque",
|
|
203
|
+
"eu",
|
|
204
|
+
"bengali",
|
|
205
|
+
"bn",
|
|
206
|
+
"bulgarian",
|
|
207
|
+
"bg",
|
|
208
|
+
"catalan",
|
|
209
|
+
"ca",
|
|
210
|
+
"croatian",
|
|
211
|
+
"hr",
|
|
212
|
+
"czech",
|
|
213
|
+
"cs",
|
|
214
|
+
"estonian",
|
|
215
|
+
"et",
|
|
216
|
+
"finnish",
|
|
217
|
+
"fi",
|
|
218
|
+
"gujarati",
|
|
219
|
+
"gu",
|
|
220
|
+
"hebrew",
|
|
221
|
+
"he",
|
|
222
|
+
"hindi",
|
|
223
|
+
"hi",
|
|
224
|
+
"hungarian",
|
|
225
|
+
"hu",
|
|
226
|
+
"icelandic",
|
|
227
|
+
"is",
|
|
228
|
+
"indonesian",
|
|
229
|
+
"id",
|
|
230
|
+
"irish",
|
|
231
|
+
"ga",
|
|
232
|
+
"kannada",
|
|
233
|
+
"kn",
|
|
234
|
+
"kyrgyz",
|
|
235
|
+
"ky",
|
|
236
|
+
"latvian",
|
|
237
|
+
"lv",
|
|
238
|
+
"ligurian",
|
|
239
|
+
"lij",
|
|
240
|
+
"luxembourgish",
|
|
241
|
+
"lb",
|
|
242
|
+
"macedonian",
|
|
243
|
+
"mk",
|
|
244
|
+
"malayalam",
|
|
245
|
+
"ml",
|
|
246
|
+
"marathi",
|
|
247
|
+
"mr",
|
|
248
|
+
"nepali",
|
|
249
|
+
"ne",
|
|
250
|
+
"persian",
|
|
251
|
+
"fa",
|
|
252
|
+
"sanskrit",
|
|
253
|
+
"sa",
|
|
254
|
+
"serbian",
|
|
255
|
+
"sr",
|
|
256
|
+
"setswana",
|
|
257
|
+
"tn",
|
|
258
|
+
"sinhala",
|
|
259
|
+
"si",
|
|
260
|
+
"slovak",
|
|
261
|
+
"sk",
|
|
262
|
+
"slovenian",
|
|
263
|
+
"sl",
|
|
264
|
+
"swedish",
|
|
265
|
+
"sv",
|
|
266
|
+
"tagalog",
|
|
267
|
+
"tl",
|
|
268
|
+
"tamil",
|
|
269
|
+
"ta",
|
|
270
|
+
"tatar",
|
|
271
|
+
"tt",
|
|
272
|
+
"telugu",
|
|
273
|
+
"te",
|
|
274
|
+
"thai",
|
|
275
|
+
"th",
|
|
276
|
+
"turkish",
|
|
277
|
+
"tr",
|
|
278
|
+
"ukrainian",
|
|
279
|
+
"uk",
|
|
280
|
+
"urdu",
|
|
281
|
+
"ur",
|
|
282
|
+
"vietnamese",
|
|
283
|
+
"vi",
|
|
284
|
+
"yoruba",
|
|
285
|
+
"yo",
|
|
286
|
+
),
|
|
287
|
+
),
|
|
288
|
+
SchemaOptional("max_top_tokens"): int,
|
|
289
|
+
},
|
|
290
|
+
SchemaOptional("image_config"): {
|
|
291
|
+
SchemaOptional("num_segments"): int,
|
|
292
|
+
SchemaOptional("segment_compactness"): int,
|
|
293
|
+
SchemaOptional("feature_extraction_method"): str,
|
|
294
|
+
SchemaOptional("model_type"): str,
|
|
295
|
+
SchemaOptional("max_objects"): int,
|
|
296
|
+
SchemaOptional("iou_threshold"): float,
|
|
297
|
+
SchemaOptional("context"): float,
|
|
298
|
+
SchemaOptional("debug"): {
|
|
299
|
+
SchemaOptional("image_names"): [str],
|
|
300
|
+
SchemaOptional("class_ids"): [int],
|
|
301
|
+
SchemaOptional("sample_from"): int,
|
|
302
|
+
SchemaOptional("sample_to"): int,
|
|
303
|
+
},
|
|
304
|
+
},
|
|
305
|
+
SchemaOptional("seed"): int,
|
|
306
|
+
SchemaOptional("features_to_explain"): [Or(int, str)],
|
|
307
|
+
},
|
|
308
|
+
SchemaOptional("pre_training_bias"): {"methods": Or(str, [str])},
|
|
309
|
+
SchemaOptional("post_training_bias"): {"methods": Or(str, [str])},
|
|
310
|
+
SchemaOptional("pdp"): {
|
|
311
|
+
"grid_resolution": int,
|
|
312
|
+
SchemaOptional("features"): [Or(str, int)],
|
|
313
|
+
SchemaOptional("top_k_features"): int,
|
|
314
|
+
},
|
|
315
|
+
SchemaOptional("report"): {"name": str, SchemaOptional("title"): str},
|
|
316
|
+
SchemaOptional("asymmetric_shapley_value"): {
|
|
317
|
+
"direction": And(
|
|
318
|
+
str,
|
|
319
|
+
Use(str.lower),
|
|
320
|
+
lambda s: s
|
|
321
|
+
in (
|
|
322
|
+
"chronological",
|
|
323
|
+
"anti_chronological",
|
|
324
|
+
"bidirectional",
|
|
325
|
+
),
|
|
326
|
+
),
|
|
327
|
+
"granularity": And(
|
|
328
|
+
str,
|
|
329
|
+
Use(str.lower),
|
|
330
|
+
lambda s: s
|
|
331
|
+
in (
|
|
332
|
+
"timewise",
|
|
333
|
+
"fine_grained",
|
|
334
|
+
),
|
|
335
|
+
),
|
|
336
|
+
SchemaOptional("num_samples"): int,
|
|
337
|
+
SchemaOptional("baseline"): Or(
|
|
338
|
+
str,
|
|
339
|
+
{
|
|
340
|
+
SchemaOptional("target_time_series", default="zero"): And(
|
|
341
|
+
str,
|
|
342
|
+
Use(str.lower),
|
|
343
|
+
lambda s: s
|
|
344
|
+
in (
|
|
345
|
+
"zero",
|
|
346
|
+
"mean",
|
|
347
|
+
),
|
|
348
|
+
),
|
|
349
|
+
SchemaOptional("related_time_series"): And(
|
|
350
|
+
str,
|
|
351
|
+
Use(str.lower),
|
|
352
|
+
lambda s: s
|
|
353
|
+
in (
|
|
354
|
+
"zero",
|
|
355
|
+
"mean",
|
|
356
|
+
),
|
|
357
|
+
),
|
|
358
|
+
SchemaOptional("static_covariates"): {Or(str, int): [Or(str, int, float)]},
|
|
359
|
+
},
|
|
360
|
+
),
|
|
361
|
+
},
|
|
362
|
+
},
|
|
363
|
+
SchemaOptional("predictor"): {
|
|
364
|
+
SchemaOptional("endpoint_name"): str,
|
|
365
|
+
SchemaOptional("endpoint_name_prefix"): And(str, Regex(ENDPOINT_NAME_PREFIX_PATTERN)),
|
|
366
|
+
SchemaOptional("model_name"): str,
|
|
367
|
+
SchemaOptional("target_model"): str,
|
|
368
|
+
SchemaOptional("instance_type"): str,
|
|
369
|
+
SchemaOptional("initial_instance_count"): int,
|
|
370
|
+
SchemaOptional("accelerator_type"): str,
|
|
371
|
+
SchemaOptional("content_type"): And(
|
|
372
|
+
str,
|
|
373
|
+
Use(str.lower),
|
|
374
|
+
lambda s: s
|
|
375
|
+
in (
|
|
376
|
+
"text/csv",
|
|
377
|
+
"application/jsonlines",
|
|
378
|
+
"application/json",
|
|
379
|
+
"image/jpeg",
|
|
380
|
+
"image/png",
|
|
381
|
+
"application/x-npy",
|
|
382
|
+
),
|
|
383
|
+
),
|
|
384
|
+
SchemaOptional("accept_type"): And(
|
|
385
|
+
str,
|
|
386
|
+
Use(str.lower),
|
|
387
|
+
lambda s: s in ("text/csv", "application/jsonlines", "application/json"),
|
|
388
|
+
),
|
|
389
|
+
SchemaOptional("label"): Or(str, int),
|
|
390
|
+
SchemaOptional("probability"): Or(str, int),
|
|
391
|
+
SchemaOptional("label_headers"): [Or(str, int)],
|
|
392
|
+
SchemaOptional("content_template"): Or(str, {str: str}),
|
|
393
|
+
SchemaOptional("record_template"): str,
|
|
394
|
+
SchemaOptional("custom_attributes"): str,
|
|
395
|
+
SchemaOptional("time_series_predictor_config"): {
|
|
396
|
+
"forecast": str,
|
|
397
|
+
},
|
|
398
|
+
},
|
|
399
|
+
}
|
|
400
|
+
)
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
class DatasetType(Enum):
|
|
404
|
+
"""Enum to store different dataset types supported in the Analysis config file"""
|
|
405
|
+
|
|
406
|
+
TEXTCSV = "text/csv"
|
|
407
|
+
JSONLINES = "application/jsonlines"
|
|
408
|
+
JSON = "application/json"
|
|
409
|
+
PARQUET = "application/x-parquet"
|
|
410
|
+
IMAGE = "application/x-image"
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
class TimeSeriesJSONDatasetFormat(Enum):
|
|
414
|
+
"""Possible dataset formats for JSON time series data files.
|
|
415
|
+
|
|
416
|
+
Below is an example ``COLUMNS`` dataset for time series explainability::
|
|
417
|
+
|
|
418
|
+
{
|
|
419
|
+
"ids": [1, 2],
|
|
420
|
+
"timestamps": [3, 4],
|
|
421
|
+
"target_ts": [5, 6],
|
|
422
|
+
"rts1": [0.25, 0.5],
|
|
423
|
+
"rts2": [1.25, 1.5],
|
|
424
|
+
"scv1": [10, 20],
|
|
425
|
+
"scv2": [30, 40]
|
|
426
|
+
}
|
|
427
|
+
|
|
428
|
+
For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows::
|
|
429
|
+
|
|
430
|
+
item_id="ids"
|
|
431
|
+
timestamp="timestamps"
|
|
432
|
+
target_time_series="target_ts"
|
|
433
|
+
related_time_series=["rts1", "rts2"]
|
|
434
|
+
static_covariates=["scv1", "scv2"]
|
|
435
|
+
|
|
436
|
+
Below is an example ``ITEM_RECORDS`` dataset for time series explainability::
|
|
437
|
+
|
|
438
|
+
[
|
|
439
|
+
{
|
|
440
|
+
"id": 1,
|
|
441
|
+
"scv1": 10,
|
|
442
|
+
"scv2": "red",
|
|
443
|
+
"timeseries": [
|
|
444
|
+
{"timestamp": 1, "target_ts": 5, "rts1": 0.25, "rts2": 10},
|
|
445
|
+
{"timestamp": 2, "target_ts": 6, "rts1": 0.35, "rts2": 20},
|
|
446
|
+
{"timestamp": 3, "target_ts": 4, "rts1": 0.45, "rts2": 30}
|
|
447
|
+
]
|
|
448
|
+
},
|
|
449
|
+
{
|
|
450
|
+
"id": 2,
|
|
451
|
+
"scv1": 20,
|
|
452
|
+
"scv2": "blue",
|
|
453
|
+
"timeseries": [
|
|
454
|
+
{"timestamp": 1, "target_ts": 4, "rts1": 0.25, "rts2": 40},
|
|
455
|
+
{"timestamp": 2, "target_ts": 2, "rts1": 0.35, "rts2": 50}
|
|
456
|
+
]
|
|
457
|
+
}
|
|
458
|
+
]
|
|
459
|
+
|
|
460
|
+
For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows::
|
|
461
|
+
|
|
462
|
+
item_id="[*].id"
|
|
463
|
+
timestamp="[*].timeseries[].timestamp"
|
|
464
|
+
target_time_series="[*].timeseries[].target_ts"
|
|
465
|
+
related_time_series=["[*].timeseries[].rts1", "[*].timeseries[].rts2"]
|
|
466
|
+
static_covariates=["[*].scv1", "[*].scv2"]
|
|
467
|
+
|
|
468
|
+
Below is an example ``TIMESTAMP_RECORDS`` dataset for time series explainability::
|
|
469
|
+
|
|
470
|
+
[
|
|
471
|
+
{"id": 1, "timestamp": 1, "target_ts": 5, "scv1": 10, "rts1": 0.25},
|
|
472
|
+
{"id": 1, "timestamp": 2, "target_ts": 6, "scv1": 10, "rts1": 0.5},
|
|
473
|
+
{"id": 1, "timestamp": 3, "target_ts": 3, "scv1": 10, "rts1": 0.75},
|
|
474
|
+
{"id": 2, "timestamp": 5, "target_ts": 10, "scv1": 20, "rts1": 1}
|
|
475
|
+
]
|
|
476
|
+
|
|
477
|
+
For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows::
|
|
478
|
+
|
|
479
|
+
item_id="[*].id"
|
|
480
|
+
timestamp="[*].timestamp"
|
|
481
|
+
target_time_series="[*].target_ts"
|
|
482
|
+
related_time_series=["[*].rts1"]
|
|
483
|
+
static_covariates=["[*].scv1"]
|
|
484
|
+
"""
|
|
485
|
+
|
|
486
|
+
COLUMNS = "columns"
|
|
487
|
+
ITEM_RECORDS = "item_records"
|
|
488
|
+
TIMESTAMP_RECORDS = "timestamp_records"
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
class SegmentationConfig:
|
|
492
|
+
"""Config object that defines segment(s) of the dataset on which metrics are computed."""
|
|
493
|
+
|
|
494
|
+
def __init__(
|
|
495
|
+
self,
|
|
496
|
+
name_or_index: Union[str, int],
|
|
497
|
+
segments: List[List[Union[str, int]]],
|
|
498
|
+
config_name: Optional[str] = None,
|
|
499
|
+
display_aliases: Optional[List[str]] = None,
|
|
500
|
+
):
|
|
501
|
+
"""Initializes a segmentation configuration for a dataset column.
|
|
502
|
+
|
|
503
|
+
Args:
|
|
504
|
+
name_or_index (str or int): The name or index of the column in the dataset on which
|
|
505
|
+
the segment(s) is defined.
|
|
506
|
+
segments (List[List[str or int]]): Each List of values represents one segment. If N
|
|
507
|
+
Lists are provided, we generate N+1 segments - the additional segment, denoted as
|
|
508
|
+
the '__default__' segment, is for the rest of the values that are not covered by
|
|
509
|
+
these lists. For continuous columns, a segment must be given as strings in interval
|
|
510
|
+
notation (eg.: ["[1, 4]"] or ["(2, 5]"]). A segment can also be composed of
|
|
511
|
+
multiple intervals (eg.: ["[1, 4]", "(5, 6]"] is one segment). For categorical
|
|
512
|
+
columns, each segment should contain one or more of the categorical values for
|
|
513
|
+
the categorical column, which may be strings or integers.
|
|
514
|
+
Eg,: For a continuous column, ``segments`` could be
|
|
515
|
+
[["[1, 4]", "(5, 6]"], ["(7, 9)"]] - this generates 3 segments including the
|
|
516
|
+
default segment. For a categorical columns with values ("A", "B", "C", "D"),
|
|
517
|
+
``segments``,could be [["A", "B"]]. This generate 2 segments, including the default
|
|
518
|
+
segment.
|
|
519
|
+
config_name (str) - Optional name for the segment config to identify the config.
|
|
520
|
+
display_aliases (List[str]) - Optional list of display names for the ``segments`` for
|
|
521
|
+
the analysis output and report. This list should be the same length as the number of
|
|
522
|
+
lists provided in ``segments`` or with one additional display alias for the default
|
|
523
|
+
segment.
|
|
524
|
+
|
|
525
|
+
Raises:
|
|
526
|
+
ValueError: when the ``name_or_index`` is None, ``segments`` is invalid, or a wrong
|
|
527
|
+
number of ``display_aliases`` are specified.
|
|
528
|
+
"""
|
|
529
|
+
if name_or_index is None:
|
|
530
|
+
raise ValueError("`name_or_index` cannot be None")
|
|
531
|
+
self.name_or_index = name_or_index
|
|
532
|
+
if (
|
|
533
|
+
not segments
|
|
534
|
+
or not isinstance(segments, list)
|
|
535
|
+
or not all([isinstance(segment, list) for segment in segments])
|
|
536
|
+
):
|
|
537
|
+
raise ValueError("`segments` must be a list of lists of values or intervals.")
|
|
538
|
+
self.segments = segments
|
|
539
|
+
self.config_name = config_name
|
|
540
|
+
if display_aliases is not None and not (
|
|
541
|
+
len(display_aliases) == len(segments) or len(display_aliases) == len(segments) + 1
|
|
542
|
+
):
|
|
543
|
+
raise ValueError(
|
|
544
|
+
"Number of `display_aliases` must equal the number of segments"
|
|
545
|
+
" specified or with one additional default segment display alias."
|
|
546
|
+
)
|
|
547
|
+
self.display_aliases = display_aliases
|
|
548
|
+
|
|
549
|
+
def to_dict(self) -> Dict[str, Any]: # pragma: no cover
|
|
550
|
+
"""Returns SegmentationConfig as a dict."""
|
|
551
|
+
segment_config_dict = {"name_or_index": self.name_or_index, "segments": self.segments}
|
|
552
|
+
if self.config_name:
|
|
553
|
+
segment_config_dict["config_name"] = self.config_name
|
|
554
|
+
if self.display_aliases:
|
|
555
|
+
segment_config_dict["display_aliases"] = self.display_aliases
|
|
556
|
+
return segment_config_dict
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
class TimeSeriesDataConfig:
|
|
560
|
+
"""Config object for TimeSeries explainability data configuration fields."""
|
|
561
|
+
|
|
562
|
+
def __init__(
|
|
563
|
+
self,
|
|
564
|
+
target_time_series: Union[str, int],
|
|
565
|
+
item_id: Union[str, int],
|
|
566
|
+
timestamp: Union[str, int],
|
|
567
|
+
related_time_series: Optional[List[Union[str, int]]] = None,
|
|
568
|
+
static_covariates: Optional[List[Union[str, int]]] = None,
|
|
569
|
+
dataset_format: Optional[TimeSeriesJSONDatasetFormat] = None,
|
|
570
|
+
):
|
|
571
|
+
"""Initialises TimeSeries explainability data configuration fields.
|
|
572
|
+
|
|
573
|
+
Args:
|
|
574
|
+
target_time_series (str or int): A string or a zero-based integer index.
|
|
575
|
+
Used to locate the target time series in the shared input dataset.
|
|
576
|
+
If this parameter is a string, then all other parameters except
|
|
577
|
+
`dataset_format` must be strings or lists of strings. If
|
|
578
|
+
this parameter is an int, then all other parameters except
|
|
579
|
+
`dataset_format` must be ints or lists of ints.
|
|
580
|
+
item_id (str or int): A string or a zero-based integer index. Used to
|
|
581
|
+
locate item id in the shared input dataset.
|
|
582
|
+
timestamp (str or int): A string or a zero-based integer index. Used to
|
|
583
|
+
locate timestamp in the shared input dataset.
|
|
584
|
+
related_time_series (list[str] or list[int]): Optional. An array of strings
|
|
585
|
+
or array of zero-based integer indices. Used to locate all related time
|
|
586
|
+
series in the shared input dataset (if present).
|
|
587
|
+
static_covariates (list[str] or list[int]): Optional. An array of strings or
|
|
588
|
+
array of zero-based integer indices. Used to locate all static covariate
|
|
589
|
+
fields in the shared input dataset (if present).
|
|
590
|
+
dataset_format (TimeSeriesJSONDatasetFormat): Describes the format
|
|
591
|
+
of the data files provided for analysis. Should only be provided
|
|
592
|
+
when dataset is in JSON format.
|
|
593
|
+
|
|
594
|
+
Raises:
|
|
595
|
+
ValueError: If any required arguments are not provided or are the wrong type.
|
|
596
|
+
"""
|
|
597
|
+
# check target_time_series, item_id, and timestamp are provided
|
|
598
|
+
if not target_time_series:
|
|
599
|
+
raise ValueError("Please provide a target time series.")
|
|
600
|
+
if not item_id:
|
|
601
|
+
raise ValueError("Please provide an item id.")
|
|
602
|
+
if not timestamp:
|
|
603
|
+
raise ValueError("Please provide a timestamp.")
|
|
604
|
+
# check all arguments are the right types
|
|
605
|
+
if not isinstance(target_time_series, (str, int)):
|
|
606
|
+
raise ValueError("Please provide a string or an int for ``target_time_series``")
|
|
607
|
+
params_type = type(target_time_series)
|
|
608
|
+
if not isinstance(item_id, params_type):
|
|
609
|
+
raise ValueError(f"Please provide {params_type} for ``item_id``")
|
|
610
|
+
if not isinstance(timestamp, params_type):
|
|
611
|
+
raise ValueError(f"Please provide {params_type} for ``timestamp``")
|
|
612
|
+
# add mandatory fields to an internal dictionary
|
|
613
|
+
self.time_series_data_config = dict()
|
|
614
|
+
_set(target_time_series, "target_time_series", self.time_series_data_config)
|
|
615
|
+
_set(item_id, "item_id", self.time_series_data_config)
|
|
616
|
+
_set(timestamp, "timestamp", self.time_series_data_config)
|
|
617
|
+
# check optional arguments are right types if provided
|
|
618
|
+
related_time_series_error_message = (
|
|
619
|
+
f"Please provide a list of {params_type} for ``related_time_series``"
|
|
620
|
+
)
|
|
621
|
+
if related_time_series:
|
|
622
|
+
if not isinstance(related_time_series, list):
|
|
623
|
+
raise ValueError(
|
|
624
|
+
related_time_series_error_message
|
|
625
|
+
) # related_time_series is not a list
|
|
626
|
+
if not all([isinstance(value, params_type) for value in related_time_series]):
|
|
627
|
+
raise ValueError(
|
|
628
|
+
related_time_series_error_message
|
|
629
|
+
) # related_time_series is not a list of strings or list of ints
|
|
630
|
+
if params_type == str and not all(related_time_series):
|
|
631
|
+
raise ValueError("Please do not provide empty strings in ``related_time_series``.")
|
|
632
|
+
_set(
|
|
633
|
+
related_time_series, "related_time_series", self.time_series_data_config
|
|
634
|
+
) # related_time_series is valid, add it
|
|
635
|
+
static_covariates_series_error_message = (
|
|
636
|
+
f"Please provide a list of {params_type} for ``static_covariates``"
|
|
637
|
+
)
|
|
638
|
+
if static_covariates:
|
|
639
|
+
if not isinstance(static_covariates, list):
|
|
640
|
+
raise ValueError(
|
|
641
|
+
static_covariates_series_error_message
|
|
642
|
+
) # static_covariates is not a list
|
|
643
|
+
if not all([isinstance(value, params_type) for value in static_covariates]):
|
|
644
|
+
raise ValueError(
|
|
645
|
+
static_covariates_series_error_message
|
|
646
|
+
) # static_covariates is not a list of strings or list of ints
|
|
647
|
+
if params_type == str and not all(static_covariates):
|
|
648
|
+
raise ValueError("Please do not provide empty strings in ``static_covariates``.")
|
|
649
|
+
_set(
|
|
650
|
+
static_covariates, "static_covariates", self.time_series_data_config
|
|
651
|
+
) # static_covariates is valid, add it
|
|
652
|
+
if params_type == str:
|
|
653
|
+
# check dataset_format is provided and valid
|
|
654
|
+
if not isinstance(dataset_format, TimeSeriesJSONDatasetFormat):
|
|
655
|
+
raise ValueError("Please provide a valid dataset format.")
|
|
656
|
+
_set(dataset_format.value, "dataset_format", self.time_series_data_config)
|
|
657
|
+
else:
|
|
658
|
+
if dataset_format:
|
|
659
|
+
raise ValueError(
|
|
660
|
+
"Dataset format should only be provided when data files are JSONs."
|
|
661
|
+
)
|
|
662
|
+
|
|
663
|
+
def get_time_series_data_config(self):
|
|
664
|
+
"""Returns part of an analysis config dictionary."""
|
|
665
|
+
return copy.deepcopy(self.time_series_data_config)
|
|
666
|
+
|
|
667
|
+
|
|
668
|
+
class DataConfig:
|
|
669
|
+
"""Config object related to configurations of the input and output dataset."""
|
|
670
|
+
|
|
671
|
+
def __init__(
|
|
672
|
+
self,
|
|
673
|
+
s3_data_input_path: str,
|
|
674
|
+
s3_output_path: str,
|
|
675
|
+
s3_analysis_config_output_path: Optional[str] = None,
|
|
676
|
+
label: Optional[str] = None,
|
|
677
|
+
headers: Optional[List[str]] = None,
|
|
678
|
+
features: Optional[str] = None,
|
|
679
|
+
dataset_type: str = "text/csv",
|
|
680
|
+
s3_compression_type: str = "None",
|
|
681
|
+
joinsource: Optional[Union[str, int]] = None,
|
|
682
|
+
facet_dataset_uri: Optional[str] = None,
|
|
683
|
+
facet_headers: Optional[List[str]] = None,
|
|
684
|
+
predicted_label_dataset_uri: Optional[str] = None,
|
|
685
|
+
predicted_label_headers: Optional[List[str]] = None,
|
|
686
|
+
predicted_label: Optional[Union[str, int]] = None,
|
|
687
|
+
excluded_columns: Optional[Union[List[int], List[str]]] = None,
|
|
688
|
+
segmentation_config: Optional[List[SegmentationConfig]] = None,
|
|
689
|
+
time_series_data_config: Optional[TimeSeriesDataConfig] = None,
|
|
690
|
+
):
|
|
691
|
+
"""Initializes a configuration of both input and output datasets.
|
|
692
|
+
|
|
693
|
+
Args:
|
|
694
|
+
s3_data_input_path (str): Dataset S3 prefix/object URI.
|
|
695
|
+
s3_output_path (str): S3 prefix to store the output.
|
|
696
|
+
s3_analysis_config_output_path (str): S3 prefix to store the analysis config output.
|
|
697
|
+
If this field is None, then the ``s3_output_path`` will be used
|
|
698
|
+
to store the ``analysis_config`` output.
|
|
699
|
+
label (str): Target attribute of the model required by bias metrics. Specified as
|
|
700
|
+
column name or index for CSV dataset or a JMESPath expression for JSON/JSON Lines.
|
|
701
|
+
*Required parameter* except for when the input dataset does not contain the label.
|
|
702
|
+
Note: For JSON, the JMESPath query must result in a list of labels for each
|
|
703
|
+
sample. For JSON Lines, it must result in the label for each line.
|
|
704
|
+
Only a single label per sample is supported at this time.
|
|
705
|
+
headers ([str]): List of column names in the dataset. If not provided, Clarify will
|
|
706
|
+
generate headers to use internally. For time series explainability cases,
|
|
707
|
+
please provide headers in the order of item_id, timestamp, target_time_series,
|
|
708
|
+
all related_time_series columns, and then all static_covariate columns.
|
|
709
|
+
features (str): JMESPath expression to locate the feature values
|
|
710
|
+
if the dataset format is JSON/JSON Lines.
|
|
711
|
+
Note: For JSON, the JMESPath query must result in a 2-D list (or a matrix) of
|
|
712
|
+
feature values. For JSON Lines, it must result in a 1-D list of features for each
|
|
713
|
+
line.
|
|
714
|
+
dataset_type (str): Format of the dataset. Valid values are ``"text/csv"`` for CSV,
|
|
715
|
+
``"application/jsonlines"`` for JSON Lines, ``"application/json"`` for JSON, and
|
|
716
|
+
``"application/x-parquet"`` for Parquet.
|
|
717
|
+
s3_compression_type (str): Valid options are "None" or ``"Gzip"``.
|
|
718
|
+
joinsource (str or int): The name or index of the column in the dataset that
|
|
719
|
+
acts as an identifier column (for instance, while performing a join).
|
|
720
|
+
This column is only used as an identifier, and not used for any other computations.
|
|
721
|
+
This is an optional field in all cases except:
|
|
722
|
+
|
|
723
|
+
* The dataset contains more than one file and `save_local_shap_values`
|
|
724
|
+
is set to true in :class:`~sagemaker.clarify.ShapConfig`, and/or
|
|
725
|
+
* When the dataset and/or facet dataset and/or predicted label dataset
|
|
726
|
+
are in separate files.
|
|
727
|
+
|
|
728
|
+
facet_dataset_uri (str): Dataset S3 prefix/object URI that contains facet attribute(s),
|
|
729
|
+
used for bias analysis on datasets without facets.
|
|
730
|
+
|
|
731
|
+
* If the dataset and the facet dataset are one single file each, then
|
|
732
|
+
the original dataset and facet dataset must have the same number of rows.
|
|
733
|
+
* If the dataset and facet dataset are in multiple files (either one), then
|
|
734
|
+
an index column, ``joinsource``, is required to join the two datasets.
|
|
735
|
+
|
|
736
|
+
Clarify will not use the ``joinsource`` column and columns present in the facet
|
|
737
|
+
dataset when calling model inference APIs.
|
|
738
|
+
Note: this is only supported for ``"text/csv"`` dataset type.
|
|
739
|
+
facet_headers (list[str]): List of column names in the facet dataset.
|
|
740
|
+
predicted_label_dataset_uri (str): Dataset S3 prefix/object URI with predicted labels,
|
|
741
|
+
which are used directly for analysis instead of making model inference API calls.
|
|
742
|
+
|
|
743
|
+
* If the dataset and the predicted label dataset are one single file each, then the
|
|
744
|
+
original dataset and predicted label dataset must have the same number of rows.
|
|
745
|
+
* If the dataset and predicted label dataset are in multiple files (either one),
|
|
746
|
+
then an index column, ``joinsource``, is required to join the two datasets.
|
|
747
|
+
|
|
748
|
+
Note: this is only supported for ``"text/csv"`` dataset type.
|
|
749
|
+
predicted_label_headers (list[str]): List of column names in the predicted label dataset
|
|
750
|
+
predicted_label (str or int): Predicted label of the target attribute of the model
|
|
751
|
+
required for running bias analysis. Specified as column name or index for CSV data,
|
|
752
|
+
or a JMESPath expression for JSON/JSON Lines.
|
|
753
|
+
Clarify uses the predicted labels directly instead of making model inference API
|
|
754
|
+
calls.
|
|
755
|
+
Note: For JSON, the JMESPath query must result in a list of predicted labels for
|
|
756
|
+
each sample. For JSON Lines, it must result in the predicted label for each line.
|
|
757
|
+
Only a single predicted label per sample is supported at this time.
|
|
758
|
+
excluded_columns (list[int] or list[str]): A list of names or indices of the columns
|
|
759
|
+
which are to be excluded from making model inference API calls.
|
|
760
|
+
segmentation_config (list[SegmentationConfig]): A list of ``SegmentationConfig``
|
|
761
|
+
objects.
|
|
762
|
+
time_series_data_config (TimeSeriesDataConfig): Optional. A config object for TimeSeries
|
|
763
|
+
data specific fields, required for TimeSeries explainability use cases.
|
|
764
|
+
|
|
765
|
+
Raises:
|
|
766
|
+
ValueError: when the ``dataset_type`` is invalid, predicted label dataset parameters
|
|
767
|
+
are used with un-supported ``dataset_type``, or facet dataset parameters
|
|
768
|
+
are used with un-supported ``dataset_type``
|
|
769
|
+
"""
|
|
770
|
+
if dataset_type not in [
|
|
771
|
+
"text/csv",
|
|
772
|
+
"application/jsonlines",
|
|
773
|
+
"application/json",
|
|
774
|
+
"application/x-parquet",
|
|
775
|
+
"application/x-image",
|
|
776
|
+
]:
|
|
777
|
+
raise ValueError(
|
|
778
|
+
f"Invalid dataset_type '{dataset_type}'."
|
|
779
|
+
f" Please check the API documentation for the supported dataset types."
|
|
780
|
+
)
|
|
781
|
+
# predicted_label and excluded_columns are only supported for tabular datasets
|
|
782
|
+
if dataset_type not in [
|
|
783
|
+
"text/csv",
|
|
784
|
+
"application/jsonlines",
|
|
785
|
+
"application/json",
|
|
786
|
+
"application/x-parquet",
|
|
787
|
+
]:
|
|
788
|
+
if predicted_label:
|
|
789
|
+
raise ValueError(
|
|
790
|
+
f"The parameter 'predicted_label' is not supported"
|
|
791
|
+
f" for dataset_type '{dataset_type}'."
|
|
792
|
+
f" Please check the API documentation for the supported dataset types."
|
|
793
|
+
)
|
|
794
|
+
if excluded_columns:
|
|
795
|
+
raise ValueError(
|
|
796
|
+
f"The parameter 'excluded_columns' is not supported"
|
|
797
|
+
f" for dataset_type '{dataset_type}'."
|
|
798
|
+
f" Please check the API documentation for the supported dataset types."
|
|
799
|
+
)
|
|
800
|
+
# parameters for analysis on datasets without facets are only supported for CSV datasets
|
|
801
|
+
if dataset_type != "text/csv":
|
|
802
|
+
if facet_dataset_uri or facet_headers:
|
|
803
|
+
raise ValueError(
|
|
804
|
+
f"The parameters 'facet_dataset_uri' and 'facet_headers'"
|
|
805
|
+
f" are not supported for dataset_type '{dataset_type}'."
|
|
806
|
+
f" Please check the API documentation for the supported dataset types."
|
|
807
|
+
)
|
|
808
|
+
if predicted_label_dataset_uri or predicted_label_headers:
|
|
809
|
+
raise ValueError(
|
|
810
|
+
f"The parameters 'predicted_label_dataset_uri' and 'predicted_label_headers'"
|
|
811
|
+
f" are not supported for dataset_type '{dataset_type}'."
|
|
812
|
+
f" Please check the API documentation for the supported dataset types."
|
|
813
|
+
)
|
|
814
|
+
# check if any other format other than JSON is provided for time series case
|
|
815
|
+
if time_series_data_config:
|
|
816
|
+
if dataset_type != "application/json":
|
|
817
|
+
raise ValueError(
|
|
818
|
+
"Currently time series explainability only supports JSON format data."
|
|
819
|
+
)
|
|
820
|
+
# features JMESPath is required for JSON as we can't derive it ourselves
|
|
821
|
+
if dataset_type == "application/json" and features is None and not time_series_data_config:
|
|
822
|
+
raise ValueError("features JMESPath is required for application/json dataset_type")
|
|
823
|
+
self.s3_data_input_path = s3_data_input_path
|
|
824
|
+
self.s3_output_path = s3_output_path
|
|
825
|
+
self.s3_analysis_config_output_path = s3_analysis_config_output_path
|
|
826
|
+
self.s3_data_distribution_type = "FullyReplicated"
|
|
827
|
+
self.s3_compression_type = s3_compression_type
|
|
828
|
+
self.label = label
|
|
829
|
+
self.headers = headers
|
|
830
|
+
self.features = features
|
|
831
|
+
self.facet_dataset_uri = facet_dataset_uri
|
|
832
|
+
self.facet_headers = facet_headers
|
|
833
|
+
self.predicted_label_dataset_uri = predicted_label_dataset_uri
|
|
834
|
+
self.predicted_label_headers = predicted_label_headers
|
|
835
|
+
self.predicted_label = predicted_label
|
|
836
|
+
self.excluded_columns = excluded_columns
|
|
837
|
+
self.segmentation_configs = segmentation_config
|
|
838
|
+
self.analysis_config = {
|
|
839
|
+
"dataset_type": dataset_type,
|
|
840
|
+
}
|
|
841
|
+
_set(features, "features", self.analysis_config)
|
|
842
|
+
_set(headers, "headers", self.analysis_config)
|
|
843
|
+
_set(label, "label", self.analysis_config)
|
|
844
|
+
_set(joinsource, "joinsource_name_or_index", self.analysis_config)
|
|
845
|
+
_set(facet_dataset_uri, "facet_dataset_uri", self.analysis_config)
|
|
846
|
+
_set(facet_headers, "facet_headers", self.analysis_config)
|
|
847
|
+
_set(
|
|
848
|
+
predicted_label_dataset_uri,
|
|
849
|
+
"predicted_label_dataset_uri",
|
|
850
|
+
self.analysis_config,
|
|
851
|
+
)
|
|
852
|
+
_set(predicted_label_headers, "predicted_label_headers", self.analysis_config)
|
|
853
|
+
_set(predicted_label, "predicted_label", self.analysis_config)
|
|
854
|
+
_set(excluded_columns, "excluded_columns", self.analysis_config)
|
|
855
|
+
if segmentation_config:
|
|
856
|
+
_set(
|
|
857
|
+
[item.to_dict() for item in segmentation_config],
|
|
858
|
+
"segment_config",
|
|
859
|
+
self.analysis_config,
|
|
860
|
+
)
|
|
861
|
+
if time_series_data_config:
|
|
862
|
+
_set(
|
|
863
|
+
time_series_data_config.get_time_series_data_config(),
|
|
864
|
+
"time_series_data_config",
|
|
865
|
+
self.analysis_config,
|
|
866
|
+
)
|
|
867
|
+
|
|
868
|
+
def get_config(self):
|
|
869
|
+
"""Returns part of an analysis config dictionary."""
|
|
870
|
+
return copy.deepcopy(self.analysis_config)
|
|
871
|
+
|
|
872
|
+
|
|
873
|
+
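
# A minimal usage sketch for the DataConfig above. The bucket, paths, and column
# names are hypothetical placeholders; only keyword arguments handled by the
# constructor above are used.
def _example_data_config():
    """Sketch: a DataConfig for a tabular CSV dataset with a label column."""
    return DataConfig(
        s3_data_input_path="s3://example-bucket/input/dataset.csv",  # hypothetical URI
        s3_output_path="s3://example-bucket/clarify-output",  # hypothetical URI
        label="target",  # hypothetical label column name
        headers=["age", "income", "target"],  # hypothetical column names
        dataset_type="text/csv",
    )

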
class BiasConfig:
    """Config object with user-defined bias configurations of the input dataset."""

    def __init__(
        self,
        label_values_or_threshold: List[Union[int, float, str]],
        facet_name: Union[str, int, List[str], List[int]],
        facet_values_or_threshold: Optional[Union[int, float, str]] = None,
        group_name: Optional[str] = None,
    ):
        """Initializes a configuration of the sensitive groups in the dataset.

        Args:
            label_values_or_threshold ([int or float or str]): List of label value(s) or threshold
                to indicate positive outcome used for bias metrics.
                The appropriate threshold depends on the problem type:

                * Binary: The list has one positive value.
                * Categorical: The list has one or more (but not all) categories
                  which are the positive values.
                * Regression: The list should include one threshold that defines the **exclusive**
                  lower bound of positive values.

            facet_name (str or int or list[str] or list[int]): Sensitive attribute column name
                (or index in the input data) to use when computing bias metrics. It can also be a
                list of names (or indexes) for computing metrics for multiple sensitive attributes.
            facet_values_or_threshold ([int or float or str] or [[int or float or str]]):
                The parameter controls the values of the sensitive group.
                If ``facet_name`` is a scalar, then it can be None or a list.
                Depending on the data type of the facet column, the values mean:

                * Binary data: None means computing the bias metrics for each binary value.
                  Or add one binary value to the list, to compute its bias metrics only.
                * Categorical data: None means computing the bias metrics for each category. Or add
                  one or more (but not all) categories to the list, to compute their
                  bias metrics vs. the other categories.
                * Continuous data: The list should include one and only one threshold which defines
                  the **exclusive** lower bound of a sensitive group.

                If ``facet_name`` is a list, then ``facet_values_or_threshold`` can be None
                if all facets are of binary or categorical type.
                Otherwise, ``facet_values_or_threshold`` should be a list, and each element
                is the value or threshold of the corresponding facet.
            group_name (str): Optional column name or index to indicate a group column to be used
                for the bias metric
                `Conditional Demographic Disparity in Labels (CDDL) <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-cddl.html>`_
                or
                `Conditional Demographic Disparity in Predicted Labels (CDDPL) <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cddpl.html>`_.

        Raises:
            ValueError: If the number of facet names doesn't equal the number of facet values
        """  # noqa E501  # pylint: disable=c0301
        if isinstance(facet_name, list):
            assert len(facet_name) > 0, "Please provide at least one facet"
            if facet_values_or_threshold is None:
                facet_list = [
                    {"name_or_index": single_facet_name} for single_facet_name in facet_name
                ]
            elif len(facet_values_or_threshold) == len(facet_name):
                facet_list = []
                for i, single_facet_name in enumerate(facet_name):
                    facet = {"name_or_index": single_facet_name}
                    if facet_values_or_threshold is not None:
                        _set(facet_values_or_threshold[i], "value_or_threshold", facet)
                    facet_list.append(facet)
            else:
                raise ValueError(
                    "The number of facet names doesn't match the number of facet values"
                )
        else:
            facet = {"name_or_index": facet_name}
            _set(facet_values_or_threshold, "value_or_threshold", facet)
            facet_list = [facet]
        self.analysis_config = {
            "label_values_or_threshold": label_values_or_threshold,
            "facet": facet_list,
        }
        _set(group_name, "group_variable", self.analysis_config)

    def get_config(self):
        """Returns a dictionary of bias detection configurations, part of the analysis config"""
        return copy.deepcopy(self.analysis_config)
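
# A minimal usage sketch for BiasConfig, assuming a hypothetical binary label column
# and a hypothetical sensitive-attribute column named "gender".
def _example_bias_config():
    """Sketch: label value 1 is the positive outcome; measure bias for one facet group."""
    return BiasConfig(
        label_values_or_threshold=[1],
        facet_name="gender",
        facet_values_or_threshold=["female"],  # compute metrics for this group vs. the rest
    )

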
class TimeSeriesModelConfig:
    """Config object for TimeSeries predictor configuration fields."""

    def __init__(
        self,
        forecast: str,
    ):
        """Initializes model configuration fields for TimeSeries explainability use cases.

        Args:
            forecast (str): JMESPath expression to extract the forecast result.

        Raises:
            ValueError: when ``forecast`` is not a string or not provided
        """
        # check string forecast is provided
        if not isinstance(forecast, str):
            raise ValueError(
                "Please provide a string JMESPath expression for ``forecast`` "
                "to extract the forecast result."
            )
        # add fields to an internal config dictionary
        self.time_series_model_config = dict()
        _set(forecast, "forecast", self.time_series_model_config)

    def get_time_series_model_config(self):
        """Returns TimeSeries model config dictionary"""
        return copy.deepcopy(self.time_series_model_config)
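
# A minimal usage sketch for TimeSeriesModelConfig. The JMESPath below assumes a
# hypothetical endpoint response shaped like {"predictions": {"mean": [...]}}.
def _example_time_series_model_config():
    """Sketch: extract the mean forecast from a JSON response."""
    return TimeSeriesModelConfig(forecast="predictions.mean")

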
class ModelConfig:
    """Config object related to a model and its endpoint to be created."""

    def __init__(
        self,
        model_name: Optional[str] = None,
        instance_count: Optional[int] = None,
        instance_type: Optional[str] = None,
        accept_type: Optional[str] = None,
        content_type: Optional[str] = None,
        content_template: Optional[str] = None,
        record_template: Optional[str] = None,
        custom_attributes: Optional[str] = None,
        accelerator_type: Optional[str] = None,
        endpoint_name_prefix: Optional[str] = None,
        target_model: Optional[str] = None,
        endpoint_name: Optional[str] = None,
        time_series_model_config: Optional[TimeSeriesModelConfig] = None,
    ):
        r"""Initializes a configuration of a model and the endpoint to be created for it.

        Args:
            model_name (str): Model name (as created by
                `CreateModel <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateModel.html>`_).
                Cannot be set when ``endpoint_name`` is set.
                Must be set with ``instance_count``, ``instance_type``
            instance_count (int): The number of instances of a new endpoint for model inference.
                Cannot be set when ``endpoint_name`` is set.
                Must be set with ``model_name``, ``instance_type``
            instance_type (str): The type of
                `EC2 instance <https://aws.amazon.com/ec2/instance-types/>`_
                to use for model inference; for example, ``"ml.c5.xlarge"``.
                Cannot be set when ``endpoint_name`` is set.
                Must be set with ``instance_count``, ``model_name``
            accept_type (str): The model output format to be used for getting inferences with the
                shadow endpoint. Valid values are ``"text/csv"`` for CSV,
                ``"application/jsonlines"`` for JSON Lines, and ``"application/json"`` for JSON.
                Default is the same as ``content_type``.
            content_type (str): The model input format to be used for getting inferences with the
                shadow endpoint. Valid values are ``"text/csv"`` for CSV,
                ``"application/jsonlines"`` for JSON Lines, and ``"application/json"`` for JSON.
                Default is the same as ``dataset_format``.
            content_template (str): A template string to be used to construct the model input from
                dataset instances. It is only used, and required, when ``model_content_type`` is
                ``"application/jsonlines"`` or ``"application/json"``. When ``model_content_type``
                is ``application/jsonlines``, the template should have one and only one
                placeholder, ``$features``, which will be replaced by a features list for each
                record to form the model inference input. When ``model_content_type`` is
                ``application/json``, the template can have either placeholder ``$record``, which
                will be replaced by a single record templated by ``record_template`` and only a
                single record at a time will be sent to the model, or placeholder ``$records``,
                which will be replaced by a list of records, each templated by ``record_template``.
            record_template (str): A template string to be used to construct each record of the
                model input from dataset instances. It is only used, and required, when
                ``model_content_type`` is ``"application/json"``.
                The template string may contain one of the following:

                * Placeholder ``$features`` that will be substituted by the array of feature values
                  and/or an optional placeholder ``$feature_names`` that will be substituted by the
                  array of feature names.
                * Exactly one placeholder ``$features_kvp`` that will be substituted by the
                  key-value pairs of feature name and feature value.
                * Or for each feature, if "A" is the feature name in the ``headers`` configuration,
                  then placeholder syntax ``"${A}"`` (the double-quotes are part of the
                  placeholder) will be substituted by the feature value.

                ``record_template`` will be used in conjunction with ``content_template`` to
                construct the model input.

                **Examples:**

                Given:

                * ``headers``: ``["A", "B"]``
                * ``features``: ``[[0, 1], [3, 4]]``

                Example model input 1::

                    {
                        "instances": [[0, 1], [3, 4]],
                        "feature_names": ["A", "B"]
                    }

                content_template and record_template to construct above:

                * ``content_template``: ``"{\"instances\": $records}"``
                * ``record_template``: ``"$features"``

                Example model input 2::

                    [
                        { "A": 0, "B": 1 },
                        { "A": 3, "B": 4 },
                    ]

                content_template and record_template to construct above:

                * ``content_template``: ``"$records"``
                * ``record_template``: ``"$features_kvp"``

                Or, alternatively:

                * ``content_template``: ``"$records"``
                * ``record_template``: ``"{\"A\": \"${A}\", \"B\": \"${B}\"}"``

                Example model input 3 (single record only)::

                    { "A": 0, "B": 1 }

                content_template and record_template to construct above:

                * ``content_template``: ``"$record"``
                * ``record_template``: ``"$features_kvp"``
            custom_attributes (str): Provides additional information about a request for an
                inference submitted to a model hosted at an Amazon SageMaker endpoint. The
                information is an opaque value that is forwarded verbatim. You could use this
                value, for example, to provide an ID that you can use to track a request or to
                provide other metadata that a service endpoint was programmed to process. The value
                must consist of no more than 1024 visible US-ASCII characters as specified in
                Section 3.2.6.
                `Field Value Components <https://tools.ietf.org/html/rfc7230#section-3.2.6>`_
                of the Hypertext Transfer Protocol (HTTP/1.1).
            accelerator_type (str): SageMaker
                `Elastic Inference <https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html>`_
                accelerator type to deploy to the model endpoint instance
                for making inferences to the model.
            endpoint_name_prefix (str): The endpoint name prefix of a new endpoint. Must follow
                pattern ``^[a-zA-Z0-9](-*[a-zA-Z0-9])``.
            target_model (str): Sets the target model name when using a multi-model endpoint. For
                more information about multi-model endpoints, see
                https://docs.aws.amazon.com/sagemaker/latest/dg/multi-model-endpoints.html
            endpoint_name (str): Sets the endpoint_name when reusing an existing endpoint.
                Cannot be set when ``model_name``, ``instance_count``,
                and ``instance_type`` are set
            time_series_model_config (TimeSeriesModelConfig): Optional. A config object for
                TimeSeries predictor specific fields, required for TimeSeries
                explainability use cases.

        Raises:
            ValueError: when the
                - ``endpoint_name_prefix`` is invalid,
                - ``accept_type`` is invalid,
                - ``content_type`` is invalid,
                - ``content_template`` has no placeholder ``$features``,
                - both [``endpoint_name``]
                  AND [``model_name``, ``instance_count``, ``instance_type``] are set
                - both [``endpoint_name``] AND [``endpoint_name_prefix``] are set
        """

        # validation
        _model_endpoint_config_rule = (
            all([model_name, instance_count, instance_type]),
            all([endpoint_name]),
        )
        assert any(_model_endpoint_config_rule) and not all(_model_endpoint_config_rule)
        if endpoint_name:
            assert not endpoint_name_prefix

        # main init logic
        self.predictor_config = (
            {
                "model_name": model_name,
                "instance_type": instance_type,
                "initial_instance_count": instance_count,
            }
            if not endpoint_name
            else {"endpoint_name": endpoint_name}
        )
        if endpoint_name_prefix:
            if re.search("^[a-zA-Z0-9](-*[a-zA-Z0-9])", endpoint_name_prefix) is None:
                raise ValueError(
                    "Invalid endpoint_name_prefix."
                    " Please follow pattern ^[a-zA-Z0-9](-*[a-zA-Z0-9])."
                )
            self.predictor_config["endpoint_name_prefix"] = endpoint_name_prefix
        if accept_type is not None:
            if accept_type not in ["text/csv", "application/jsonlines", "application/json"]:
                raise ValueError(
                    f"Invalid accept_type {accept_type}."
                    " Please choose text/csv, application/jsonlines, or application/json."
                )
            if time_series_model_config and accept_type == "text/csv":
                raise ValueError(
                    "``accept_type`` must be JSON or JSONLines for time series explainability."
                )
            self.predictor_config["accept_type"] = accept_type
        if content_type is not None:
            if content_type not in [
                "text/csv",
                "application/jsonlines",
                "application/json",
                "image/jpeg",
                "image/jpg",
                "image/png",
                "application/x-npy",
            ]:
                raise ValueError(
                    f"Invalid content_type {content_type}."
                    " Please choose one of text/csv, application/jsonlines, application/json,"
                    " image/jpeg, image/jpg, image/png, or application/x-npy."
                )
            if content_type == "application/jsonlines":
                if content_template is None:
                    raise ValueError(
                        f"content_template field is required for content_type {content_type}"
                    )
                if "$features" not in content_template:
                    raise ValueError(
                        f"Invalid content_template {content_template}."
                        " Please include a placeholder $features."
                    )
            if content_type == "application/json":
                if content_template is None or record_template is None:
                    raise ValueError(
                        f"content_template and record_template are required for content_type "
                        f"{content_type}"
                    )
                if "$record" not in content_template:
                    raise ValueError(
                        f"Invalid content_template {content_template}."
                        " Please include either placeholder $records or $record."
                    )
            if time_series_model_config and content_type not in [
                "application/json",
                "application/jsonlines",
            ]:
                raise ValueError(
                    "``content_type`` must be JSON or JSONLines for time series explainability."
                )
            self.predictor_config["content_type"] = content_type
        if content_template is not None:
            self.predictor_config["content_template"] = content_template
        if record_template is not None:
            self.predictor_config["record_template"] = record_template
        _set(custom_attributes, "custom_attributes", self.predictor_config)
        _set(accelerator_type, "accelerator_type", self.predictor_config)
        _set(target_model, "target_model", self.predictor_config)
        if time_series_model_config:
            _set(
                time_series_model_config.get_time_series_model_config(),
                "time_series_predictor_config",
                self.predictor_config,
            )

    def get_predictor_config(self):
        """Returns part of the predictor dictionary of the analysis config."""
        return copy.deepcopy(self.predictor_config)
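
# A minimal usage sketch for ModelConfig: Clarify spins up a shadow endpoint for the
# named model and builds JSON requests from content_template/record_template. The
# model name is a hypothetical placeholder.
def _example_model_config():
    """Sketch: one ml.c5.xlarge instance, JSON input and output."""
    return ModelConfig(
        model_name="my-clarify-model",  # hypothetical model created via CreateModel
        instance_count=1,
        instance_type="ml.c5.xlarge",
        content_type="application/json",
        accept_type="application/json",
        content_template='{"instances": $records}',
        record_template="$features",
    )

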
class ModelPredictedLabelConfig:
    """Config object to extract a predicted label from the model output."""

    def __init__(
        self,
        label: Optional[Union[str, int]] = None,
        probability: Optional[Union[str, int]] = None,
        probability_threshold: Optional[float] = None,
        label_headers: Optional[List[str]] = None,
    ):
        """Initializes a model output config to extract the predicted label or predicted score(s).

        The following examples show different parameter configurations depending on the endpoint:

        * **Regression task:**
          The model returns the score, e.g. ``1.2``. We don't need to specify
          anything. For json output, e.g. ``{'score': 1.2}``, we can set ``label='score'``.
        * **Binary classification:**

          * The model returns a single probability score. We want to classify as ``"yes"``
            predictions with a probability score over ``0.2``.
            We can set ``probability_threshold=0.2`` and ``label_headers="yes"``.
          * The model returns ``{"probability": 0.3}``, for which we would like to apply a
            threshold of ``0.5`` to obtain a predicted label in ``{0, 1}``.
            In this case we can set ``label="probability"``.
          * The model returns a tuple of the predicted label and the probability.
            In this case we can set ``label = 0``.
        * **Multiclass classification:**

          * The model returns ``{'labels': ['cat', 'dog', 'fish'],
            'probabilities': [0.35, 0.25, 0.4]}``. In this case we would set
            ``probability='probabilities'``, ``label='labels'``,
            and infer the predicted label to be ``'fish'``.
          * The model returns ``{'predicted_label': 'fish', 'probabilities': [0.35, 0.25, 0.4]}``.
            In this case we would set ``label='predicted_label'``.
          * The model returns ``[0.35, 0.25, 0.4]``. In this case, we can set
            ``label_headers=['cat','dog','fish']`` and infer the predicted label to be ``'fish'``.

        Args:
            label (str or int): Index or JMESPath expression to locate the prediction
                in the model output. If this is a predicted label of the same type
                as the label in the dataset, no further arguments need to be specified.
            probability (str or int): Index or JMESPath expression to locate the predicted score(s)
                in the model output.
            probability_threshold (float): An optional value for binary prediction tasks in which
                the model returns a probability, to indicate the threshold to convert the
                prediction to a boolean value. Default is ``0.5``.
            label_headers (list[str]): List of headers, each for a predicted score in model output.
                For bias analysis, it is used to extract the label value with the highest score as
                predicted label. For explainability jobs, it is used to beautify the analysis report
                by replacing placeholders like ``'label0'``.

        Raises:
            TypeError: when the ``probability_threshold`` cannot be cast to a float
        """
        self.label = label
        self.probability = probability
        self.probability_threshold = probability_threshold
        self.label_headers = label_headers
        if probability_threshold is not None:
            try:
                float(probability_threshold)
            except ValueError:
                raise TypeError(
                    f"Invalid probability_threshold {probability_threshold}. "
                    f"Please choose one that can be cast to float."
                )
        self.predictor_config = {}
        _set(label, "label", self.predictor_config)
        _set(probability, "probability", self.predictor_config)
        _set(label_headers, "label_headers", self.predictor_config)

    def get_predictor_config(self):
        """Returns ``probability_threshold`` and predictor config dictionary."""
        return self.probability_threshold, copy.deepcopy(self.predictor_config)
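
# A minimal usage sketch for ModelPredictedLabelConfig, assuming a hypothetical
# multiclass response like {"labels": [...], "probabilities": [...]}.
def _example_model_predicted_label_config():
    """Sketch: locate labels and scores by JMESPath in the model output."""
    return ModelPredictedLabelConfig(label="labels", probability="probabilities")

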
class ExplainabilityConfig(ABC):
    """Abstract config class to configure an explainability method."""

    @abstractmethod
    def get_explainability_config(self):
        """Returns config."""
        return None


class PDPConfig(ExplainabilityConfig):
    """Config class for Partial Dependence Plots (PDP).

    `PDPs <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-partial-dependence-plots.html>`_
    show the marginal effect (the dependence) a subset of features has on the predicted
    outcome of an ML model.

    When PDP is requested (by passing in a :class:`~sagemaker.clarify.PDPConfig` to the
    ``explainability_config`` parameter of :class:`~sagemaker.clarify.SageMakerClarifyProcessor`),
    the Partial Dependence Plots are included in the output
    `report <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-feature-attribute-baselines-reports.html>`__
    and the corresponding values are included in the analysis output.
    """  # noqa E501

    def __init__(
        self, features: Optional[List] = None, grid_resolution: int = 15, top_k_features: int = 10
    ):
        """Initializes PDP config.

        Args:
            features (None or list): List of feature names or indices for which partial dependence
                plots are computed and plotted. When :class:`~sagemaker.clarify.ShapConfig`
                is provided, this parameter is optional, as Clarify will compute the
                partial dependence plots for top features based on
                `SHAP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-shapley-values.html>`__
                attributions. When :class:`~sagemaker.clarify.ShapConfig` is not provided,
                ``features`` must be provided.
            grid_resolution (int): When using numerical features, this integer represents the
                number of buckets that the range of values must be divided into. This decides the
                granularity of the grid in which the PDPs are plotted.
            top_k_features (int): Sets the number of top SHAP attributes used to compute
                partial dependence plots.
        """  # noqa E501
        self.pdp_config = {
            "grid_resolution": grid_resolution,
            "top_k_features": top_k_features,
        }
        if features is not None:
            self.pdp_config["features"] = features

    def get_explainability_config(self):
        """Returns PDP config dictionary."""
        return copy.deepcopy({"pdp": self.pdp_config})
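
# A minimal usage sketch for PDPConfig with hypothetical feature names; without a
# ShapConfig, ``features`` must be given explicitly (see the docstring above).
def _example_pdp_config():
    """Sketch: PDPs for two named features on a 10-bucket grid."""
    return PDPConfig(features=["age", "income"], grid_resolution=10)

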
class TextConfig:
    """Config object to handle text features for text explainability.

    `SHAP analysis <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-model-explainability.html>`__
    breaks down longer text into chunks (e.g. tokens, sentences, or paragraphs)
    and replaces them with the strings specified in the baseline for that feature.
    The `shap value <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-shapley-values.html>`_
    of a chunk then captures how much replacing it affects the prediction.
    """  # noqa E501  # pylint: disable=c0301

    _SUPPORTED_GRANULARITIES = ["token", "sentence", "paragraph"]
    _SUPPORTED_LANGUAGES = [
        "chinese",
        "zh",
        "danish",
        "da",
        "dutch",
        "nl",
        "english",
        "en",
        "french",
        "fr",
        "german",
        "de",
        "greek",
        "el",
        "italian",
        "it",
        "japanese",
        "ja",
        "lithuanian",
        "lt",
        "multi-language",
        "xx",
        "norwegian bokmål",
        "nb",
        "polish",
        "pl",
        "portuguese",
        "pt",
        "romanian",
        "ro",
        "russian",
        "ru",
        "spanish",
        "es",
        "afrikaans",
        "af",
        "albanian",
        "sq",
        "arabic",
        "ar",
        "armenian",
        "hy",
        "basque",
        "eu",
        "bengali",
        "bn",
        "bulgarian",
        "bg",
        "catalan",
        "ca",
        "croatian",
        "hr",
        "czech",
        "cs",
        "estonian",
        "et",
        "finnish",
        "fi",
        "gujarati",
        "gu",
        "hebrew",
        "he",
        "hindi",
        "hi",
        "hungarian",
        "hu",
        "icelandic",
        "is",
        "indonesian",
        "id",
        "irish",
        "ga",
        "kannada",
        "kn",
        "kyrgyz",
        "ky",
        "latvian",
        "lv",
        "ligurian",
        "lij",
        "luxembourgish",
        "lb",
        "macedonian",
        "mk",
        "malayalam",
        "ml",
        "marathi",
        "mr",
        "nepali",
        "ne",
        "persian",
        "fa",
        "sanskrit",
        "sa",
        "serbian",
        "sr",
        "setswana",
        "tn",
        "sinhala",
        "si",
        "slovak",
        "sk",
        "slovenian",
        "sl",
        "swedish",
        "sv",
        "tagalog",
        "tl",
        "tamil",
        "ta",
        "tatar",
        "tt",
        "telugu",
        "te",
        "thai",
        "th",
        "turkish",
        "tr",
        "ukrainian",
        "uk",
        "urdu",
        "ur",
        "vietnamese",
        "vi",
        "yoruba",
        "yo",
    ]

    def __init__(
        self,
        granularity: str,
        language: str,
    ):
        """Initializes a text configuration.

        Args:
            granularity (str): Determines the granularity into which text features are broken
                down. Accepted values are ``"token"``, ``"sentence"``, or ``"paragraph"``.
                Computes `shap values <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-shapley-values.html>`_
                for these units.
            language (str): Specifies the language of the text features. Accepted values are
                one of the following:
                ``"chinese"``, ``"danish"``, ``"dutch"``, ``"english"``, ``"french"``, ``"german"``,
                ``"greek"``, ``"italian"``, ``"japanese"``, ``"lithuanian"``, ``"multi-language"``,
                ``"norwegian bokmål"``, ``"polish"``, ``"portuguese"``, ``"romanian"``,
                ``"russian"``, ``"spanish"``, ``"afrikaans"``, ``"albanian"``, ``"arabic"``,
                ``"armenian"``, ``"basque"``, ``"bengali"``, ``"bulgarian"``, ``"catalan"``,
                ``"croatian"``, ``"czech"``, ``"estonian"``, ``"finnish"``, ``"gujarati"``,
                ``"hebrew"``, ``"hindi"``, ``"hungarian"``, ``"icelandic"``, ``"indonesian"``,
                ``"irish"``, ``"kannada"``, ``"kyrgyz"``, ``"latvian"``, ``"ligurian"``,
                ``"luxembourgish"``, ``"macedonian"``, ``"malayalam"``, ``"marathi"``, ``"nepali"``,
                ``"persian"``, ``"sanskrit"``, ``"serbian"``, ``"setswana"``, ``"sinhala"``,
                ``"slovak"``, ``"slovenian"``, ``"swedish"``, ``"tagalog"``, ``"tamil"``,
                ``"tatar"``, ``"telugu"``, ``"thai"``, ``"turkish"``, ``"ukrainian"``, ``"urdu"``,
                ``"vietnamese"``, ``"yoruba"``. Use "multi-language" for a mix of multiple
                languages. The corresponding two-letter ISO codes are also accepted.

        Raises:
            ValueError: when ``granularity`` is not in the list of supported values
                or ``language`` is not in the list of supported values
        """  # noqa E501  # pylint: disable=c0301
        if granularity not in TextConfig._SUPPORTED_GRANULARITIES:
            raise ValueError(
                f"Invalid granularity {granularity}. Please choose among "
                f"{TextConfig._SUPPORTED_GRANULARITIES}"
            )
        if language not in TextConfig._SUPPORTED_LANGUAGES:
            raise ValueError(
                f"Invalid language {language}. Please choose among "
                f"{TextConfig._SUPPORTED_LANGUAGES}"
            )
        self.text_config = {
            "granularity": granularity,
            "language": language,
        }

    def get_text_config(self):
        """Returns a text config dictionary, part of the analysis config dictionary."""
        return copy.deepcopy(self.text_config)
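
# A minimal usage sketch for TextConfig: sentence-level SHAP for English text. Both
# values come from the supported lists defined on the class above.
def _example_text_config():
    """Sketch: compute shap values per sentence for English features."""
    return TextConfig(granularity="sentence", language="english")

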
class ImageConfig:
    """Config object for handling images"""

    def __init__(
        self,
        model_type: str,
        num_segments: Optional[int] = None,
        feature_extraction_method: Optional[str] = None,
        segment_compactness: Optional[float] = None,
        max_objects: Optional[int] = None,
        iou_threshold: Optional[float] = None,
        context: Optional[float] = None,
    ):
        """Initializes a config object for Computer Vision (CV) Image explainability.

        `SHAP for CV explainability <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-model-explainability-computer-vision.html>`__
        generates heat maps that visualize feature attributions for input images.
        These heat maps highlight the image's features according
        to how much they contribute to the CV model prediction.

        ``"IMAGE_CLASSIFICATION"`` and ``"OBJECT_DETECTION"`` are the two supported CV use cases.

        Args:
            model_type (str): Specifies the type of CV model and use case. Accepted options:
                ``"IMAGE_CLASSIFICATION"`` or ``"OBJECT_DETECTION"``.
            num_segments (None or int): Approximate number of segments to generate when running
                SKLearn's `SLIC method <https://scikit-image.org/docs/dev/api/skimage.segmentation.html?highlight=slic#skimage.segmentation.slic>`_
                for image segmentation to generate features/superpixels.
                The default is None. When set to None, runs SLIC with 20 segments.
            feature_extraction_method (None or str): Method used for extracting features from the
                image (e.g. "segmentation"). Default is ``"segmentation"``.
            segment_compactness (None or float): Balances color proximity and space proximity.
                Higher values give more weight to space proximity, making superpixel
                shapes more square/cubic. We recommend exploring possible values on a log
                scale, e.g., 0.01, 0.1, 1, 10, 100, before refining around a chosen value.
                The default is None. When set to None, runs with the default value of ``5``.
            max_objects (None or int): Maximum number of objects displayed when running SHAP
                with an ``"OBJECT_DETECTION"`` model. The object detection algorithm may detect
                more than the ``max_objects`` number of objects in a single image.
                In that case, the algorithm displays the top ``max_objects`` number of objects
                according to confidence score. Default value is None. In the ``"OBJECT_DETECTION"``
                case, passing in None leads to a default value of ``3``.
            iou_threshold (None or float): Minimum intersection over union for the object
                bounding box to consider its confidence score for computing SHAP values,
                in the range ``[0.0, 1.0]``. Used only for the ``"OBJECT_DETECTION"`` case,
                where passing in None sets the default value of ``0.5``.
            context (None or float): The portion of the image outside the bounding box used
                in SHAP analysis, in the range ``[0.0, 1.0]``. If set to ``1.0``, the whole image
                is considered; if set to ``0.0`` only the image inside the bounding box is
                considered. Only used for the ``"OBJECT_DETECTION"`` case,
                where passing in None sets the default value of ``1.0``.
        """  # noqa E501  # pylint: disable=c0301
        self.image_config = {}

        if model_type not in ["OBJECT_DETECTION", "IMAGE_CLASSIFICATION"]:
            raise ValueError(
                "Clarify SHAP only supports object detection and image classification methods. "
                "Please set model_type to OBJECT_DETECTION or IMAGE_CLASSIFICATION."
            )
        self.image_config["model_type"] = model_type
        _set(num_segments, "num_segments", self.image_config)
        _set(feature_extraction_method, "feature_extraction_method", self.image_config)
        _set(segment_compactness, "segment_compactness", self.image_config)
        _set(max_objects, "max_objects", self.image_config)
        _set(iou_threshold, "iou_threshold", self.image_config)
        _set(context, "context", self.image_config)

    def get_image_config(self):
        """Returns the image config part of an analysis config dictionary."""
        return copy.deepcopy(self.image_config)
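
# A minimal usage sketch for ImageConfig: SHAP heat maps for an image classification
# model, with roughly 20 SLIC superpixels per image.
def _example_image_config():
    """Sketch: image classification with explicit segmentation granularity."""
    return ImageConfig(model_type="IMAGE_CLASSIFICATION", num_segments=20)

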
class SHAPConfig(ExplainabilityConfig):
    """Config class for `SHAP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-model-explainability.html>`__.

    The SHAP algorithm calculates feature attributions by computing
    the contribution of each feature to the prediction outcome, using the concept of
    `Shapley values <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-shapley-values.html>`_.

    These attributions can be provided for specific predictions (locally)
    and at a global level for the model as a whole.
    """  # noqa E501  # pylint: disable=c0301

    def __init__(
        self,
        baseline: Optional[Union[str, List, Dict]] = None,
        num_samples: Optional[int] = None,
        agg_method: Optional[str] = None,
        use_logit: bool = False,
        save_local_shap_values: bool = True,
        seed: Optional[int] = None,
        num_clusters: Optional[int] = None,
        text_config: Optional[TextConfig] = None,
        image_config: Optional[ImageConfig] = None,
        features_to_explain: Optional[List[Union[str, int]]] = None,
    ):
        """Initializes config for SHAP analysis.

        Args:
            baseline (None or str or list or dict): `Baseline dataset <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-feature-attribute-shap-baselines.html>`_
                for the Kernel SHAP algorithm, accepted in the form of:
                S3 object URI, a list of rows (with at least one element),
                or None (for no input baseline). The baseline dataset must have the same format
                as the input dataset specified in :class:`~sagemaker.clarify.DataConfig`.
                Each row must have only the feature columns/values and omit the label
                column/values. If None, a baseline will be calculated automatically on the input
                dataset using K-means (for numerical data) or K-prototypes (if there is
                categorical data).
            num_samples (None or int): Number of samples to be used in the Kernel SHAP algorithm.
                This number determines the size of the generated synthetic dataset to compute the
                SHAP values. If not provided, the Clarify job chooses a value based on
                the number of features.
            agg_method (None or str): Aggregation method for global SHAP values. Valid values are
                ``"mean_abs"`` (mean of absolute SHAP values for all instances),
                ``"median"`` (median of SHAP values for all instances) and
                ``"mean_sq"`` (mean of squared SHAP values for all instances).
                If None is provided, then the Clarify job uses the method ``"mean_abs"``.
            use_logit (bool): Indicates whether to apply the logit function to model predictions.
                Default is False. If ``use_logit`` is true then the SHAP values will
                have log-odds units.
            save_local_shap_values (bool): Indicates whether to save the local SHAP values
                in the output location. Default is True.
            seed (int): Seed value to get deterministic SHAP values. Default is None.
            num_clusters (None or int): If a ``baseline`` is not provided, Clarify automatically
                computes a baseline dataset via a clustering algorithm (K-means/K-prototypes),
                which takes ``num_clusters`` as a parameter. ``num_clusters`` will be the
                resulting size of the baseline dataset. If not provided, the Clarify job uses
                a default value.
            text_config (:class:`~sagemaker.clarify.TextConfig`): Config object for handling
                text features. Default is None.
            image_config (:class:`~sagemaker.clarify.ImageConfig`): Config for handling image
                features. Default is None.
            features_to_explain: A list of names or indices of dataset features to compute SHAP
                values for. If not provided, SHAP values are computed for all features by default.
                Currently only supported for tabular datasets.

        Raises:
            ValueError: when ``agg_method`` is invalid, ``baseline`` and ``num_clusters`` are
                provided together, or ``features_to_explain`` is specified when ``text_config``
                or ``image_config`` is provided
        """  # noqa E501  # pylint: disable=c0301
        if agg_method is not None and agg_method not in [
            "mean_abs",
            "median",
            "mean_sq",
        ]:
            raise ValueError(
                f"Invalid agg_method {agg_method}. Please choose mean_abs, median, or mean_sq."
            )
        if num_clusters is not None and baseline is not None:
            raise ValueError(
                "Baseline and num_clusters cannot be provided together. "
                "Please specify one of the two."
            )
        self.shap_config = {
            "use_logit": use_logit,
            "save_local_shap_values": save_local_shap_values,
        }
        _set(baseline, "baseline", self.shap_config)
        _set(num_samples, "num_samples", self.shap_config)
        _set(agg_method, "agg_method", self.shap_config)
        _set(seed, "seed", self.shap_config)
        _set(num_clusters, "num_clusters", self.shap_config)
        if text_config:
            _set(text_config.get_text_config(), "text_config", self.shap_config)
            if not save_local_shap_values:
                logger.warning(
                    "Global aggregation is not yet supported for text features. "
                    "Consider setting save_local_shap_values=True to inspect local text "
                    "explanations."
                )
        if image_config:
            _set(image_config.get_image_config(), "image_config", self.shap_config)
        if features_to_explain is not None and (
            text_config is not None or image_config is not None
        ):
            raise ValueError(
                "`features_to_explain` is not supported for datasets containing text features "
                "or images."
            )
        _set(features_to_explain, "features_to_explain", self.shap_config)

    def get_explainability_config(self):
        """Returns a shap config dictionary."""
        return copy.deepcopy({"shap": self.shap_config})
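
# A minimal usage sketch for SHAPConfig: no explicit baseline, so Clarify clusters the
# input dataset into ``num_clusters`` rows to build one (baseline and num_clusters are
# mutually exclusive, per the check above).
def _example_shap_config():
    """Sketch: Kernel SHAP with an auto-computed 10-row baseline."""
    return SHAPConfig(num_samples=100, agg_method="mean_abs", num_clusters=10)

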
class AsymmetricShapleyValueConfig(ExplainabilityConfig):
    """Config class for the Asymmetric Shapley Value algorithm for time series explainability.

    Asymmetric Shapley Values are a variant of the Shapley Value that drop the symmetry axiom [1].
    We use these to determine how features contribute to the forecasting outcome. Asymmetric
    Shapley Values can take into account the temporal dependencies of the time series that
    forecasting models take as input.

    [1] Frye, Christopher, Colin Rowat, and Ilya Feige. "Asymmetric shapley values: incorporating
    causal knowledge into model-agnostic explainability." NeurIPS (2020).
    https://doi.org/10.48550/arXiv.1910.06358
    """

    def __init__(
        self,
        direction: Literal[
            "chronological",
            "anti_chronological",
            "bidirectional",
        ] = ASYM_SHAP_VAL_DEFAULT_EXPLANATION_DIRECTION,
        granularity: Literal[
            "timewise",
            "fine_grained",
        ] = ASYM_SHAP_VAL_DEFAULT_EXPLANATION_GRANULARITY,
        num_samples: Optional[int] = None,
        baseline: Optional[Union[str, Dict[str, Any]]] = None,
    ):
        """Initializes config for time series explainability with Asymmetric Shapley Values.

        AsymmetricShapleyValueConfig is used specifically and only for TimeSeries explainability
        purposes.

        Args:
            direction (str): Type of explanation to be used. Available explanation
                types are ``"chronological"``, ``"anti_chronological"``, and ``"bidirectional"``.
            granularity (str): Explanation granularity to be used. Available granularity options
                are ``"timewise"`` and ``"fine_grained"``.
            num_samples (None or int): Number of samples to be used in the Asymmetric Shapley
                Value forecasting algorithm. Only applicable when using ``"fine_grained"``
                explanations.
            baseline (str or dict): Link to a baseline configuration or a dictionary for it. The
                baseline config is used to replace out-of-coalition values for the corresponding
                datasets (also known as background data). For temporal data (target time series,
                related time series), the baseline value types are "zero", where all
                out-of-coalition values will be replaced with 0.0, or "mean", where all
                out-of-coalition values will be replaced with the average of a time series. For
                static data (static covariates), a baseline value for each covariate should be
                provided for each possible item_id. An example config follows, where ``item1``
                and ``item2`` are item ids::

                    {
                        "target_time_series": "zero",
                        "related_time_series": "zero",
                        "static_covariates": {
                            "item1": [1, 1],
                            "item2": [0, 1]
                        }
                    }

        Raises:
            ValueError: when ``direction`` or ``granularity`` are not valid, ``num_samples`` is
                not provided for fine-grained explanations, ``num_samples`` is provided for
                non-fine-grained explanations, or when ``direction`` is not ``"chronological"``
                while ``granularity`` is ``"fine_grained"``.
        """
        self.asymmetric_shapley_value_config = dict()
        # validate explanation direction
        if direction not in ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS:
            raise ValueError(
                "Please provide a valid explanation direction from: "
                + ", ".join(ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS)
            )
        # validate granularity
        if granularity not in ASYM_SHAP_VAL_GRANULARITIES:
            raise ValueError(
                "Please provide a valid granularity from: " + ", ".join(ASYM_SHAP_VAL_GRANULARITIES)
            )
        if granularity == "fine_grained":
            if not isinstance(num_samples, int):
                raise ValueError("Please provide an integer for ``num_samples``.")
            if direction != "chronological":
                raise ValueError(
                    f"{direction} and {granularity} granularity are not supported together."
                )
        elif num_samples:  # validate num_samples is not provided when unnecessary
            raise ValueError("``num_samples`` is only used for fine-grained explanations.")
        # validate baseline if provided as a dictionary
        if isinstance(baseline, dict):
            temporal_baselines = ["zero", "mean"]  # possible baseline options for temporal fields
            if "target_time_series" in baseline:
                target_baseline = baseline.get("target_time_series")
                if target_baseline not in temporal_baselines:
                    raise ValueError(
                        f"Provided value {target_baseline} for ``target_time_series`` is "
                        f"invalid. Please select one of {temporal_baselines}."
                    )
            if "related_time_series" in baseline:
                related_baseline = baseline.get("related_time_series")
                if related_baseline not in temporal_baselines:
                    raise ValueError(
                        f"Provided value {related_baseline} for ``related_time_series`` is "
                        f"invalid. Please select one of {temporal_baselines}."
                    )
        # set explanation type and (if provided) num_samples in internal config dictionary
        _set(direction, "direction", self.asymmetric_shapley_value_config)
        _set(granularity, "granularity", self.asymmetric_shapley_value_config)
        _set(
            num_samples, "num_samples", self.asymmetric_shapley_value_config
        )  # _set() does nothing if a given argument is None
        _set(baseline, "baseline", self.asymmetric_shapley_value_config)

    def get_explainability_config(self):
        """Returns an asymmetric shap config dictionary."""
        return copy.deepcopy({"asymmetric_shapley_value": self.asymmetric_shapley_value_config})
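
# A minimal usage sketch for AsymmetricShapleyValueConfig. It assumes "chronological"
# and "timewise" are members of the ASYM_SHAP_VAL_* constants defined earlier in this
# module; ``num_samples`` is omitted because it only applies to fine_grained granularity.
def _example_asymmetric_shapley_value_config():
    """Sketch: timewise chronological explanations with zero baselines."""
    return AsymmetricShapleyValueConfig(
        direction="chronological",
        granularity="timewise",
        baseline={"target_time_series": "zero", "related_time_series": "zero"},
    )

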
class SageMakerClarifyProcessor(Processor):
|
|
1861
|
+
"""Handles SageMaker Processing tasks to compute bias metrics and model explanations."""
|
|
1862
|
+
|
|
1863
|
+
_CLARIFY_DATA_INPUT = "/opt/ml/processing/input/data"
|
|
1864
|
+
_CLARIFY_CONFIG_INPUT = "/opt/ml/processing/input/config"
|
|
1865
|
+
_CLARIFY_OUTPUT = "/opt/ml/processing/output"
|
|
1866
|
+
|
|
1867
|
+
def __init__(
|
|
1868
|
+
self,
|
|
1869
|
+
role: Optional[str] = None,
|
|
1870
|
+
instance_count: int = None,
|
|
1871
|
+
instance_type: str = None,
|
|
1872
|
+
volume_size_in_gb: int = 30,
|
|
1873
|
+
volume_kms_key: Optional[str] = None,
|
|
1874
|
+
output_kms_key: Optional[str] = None,
|
|
1875
|
+
max_runtime_in_seconds: Optional[int] = None,
|
|
1876
|
+
sagemaker_session: Optional[Session] = None,
|
|
1877
|
+
env: Optional[Dict[str, str]] = None,
|
|
1878
|
+
tags: Optional[Tags] = None,
|
|
1879
|
+
network_config: Optional[NetworkConfig] = None,
|
|
1880
|
+
job_name_prefix: Optional[str] = None,
|
|
1881
|
+
version: Optional[str] = None,
|
|
1882
|
+
skip_early_validation: bool = False,
|
|
1883
|
+
):
|
|
1884
|
+
"""Initializes a SageMakerClarifyProcessor to compute bias metrics and model explanations.
|
|
1885
|
+
|
|
1886
|
+
Instance of :class:`~sagemaker.processing.Processor`.
|
|
1887
|
+
|
|
1888
|
+
Args:
|
|
1889
|
+
role (str): An AWS IAM role name or ARN. Amazon SageMaker Processing
|
|
1890
|
+
uses this role to access AWS resources, such as
|
|
1891
|
+
data stored in Amazon S3.
|
|
1892
|
+
instance_count (int): The number of instances to run
|
|
1893
|
+
a processing job with.
|
|
1894
|
+
instance_type (str): The type of
|
|
1895
|
+
`EC2 instance <https://aws.amazon.com/ec2/instance-types/>`_
|
|
1896
|
+
to use for model inference; for example, ``"ml.c5.xlarge"``.
|
|
1897
|
+
volume_size_in_gb (int): Size in GB of the
|
|
1898
|
+
`EBS volume <https://docs.aws.amazon.com/sagemaker/latest/dg/host-instance-storage.html>`_.
|
|
1899
|
+
to use for storing data during processing (default: 30 GB).
|
|
1900
|
+
volume_kms_key (str): A
|
|
1901
|
+
`KMS key <https://docs.aws.amazon.com/sagemaker/latest/dg/key-management.html>`_
|
|
1902
|
+
for the processing volume (default: None).
|
|
1903
|
+
output_kms_key (str): The KMS key ID for processing job outputs (default: None).
|
|
1904
|
+
max_runtime_in_seconds (int): Timeout in seconds (default: None).
|
|
1905
|
+
After this amount of time, Amazon SageMaker terminates the job,
|
|
1906
|
+
regardless of its current status. If ``max_runtime_in_seconds`` is not
|
|
1907
|
+
specified, the default value is ``86400`` seconds (24 hours).
|
|
1908
|
+
sagemaker_session (:class:`~sagemaker.session.Session`):
|
|
1909
|
+
:class:`~sagemaker.session.Session` object which manages interactions
|
|
1910
|
+
with Amazon SageMaker and any other AWS services needed. If not specified,
|
|
1911
|
+
the Processor creates a :class:`~sagemaker.session.Session`
|
|
1912
|
+
using the default AWS configuration chain.
|
|
1913
|
+
env (dict[str, str]): Environment variables to be passed to
|
|
1914
|
+
the processing jobs (default: None).
|
|
1915
|
+
tags (Optional[Tags]): Tags to be passed to the processing job
|
|
1916
|
+
(default: None). For more, see
|
|
1917
|
+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
|
|
1918
|
+
network_config (:class:`~sagemaker.network.NetworkConfig`):
|
|
1919
|
+
A :class:`~sagemaker.network.NetworkConfig`
|
|
1920
|
+
object that configures network isolation, encryption of
|
|
1921
|
+
inter-container traffic, security group IDs, and subnets.
|
|
1922
|
+
job_name_prefix (str): Processing job name prefix.
|
|
1923
|
+
version (str): Clarify version to use.
|
|
1924
|
+
skip_early_validation (bool): To skip schema validation of the generated analysis_schema.json.
|
|
1925
|
+
""" # noqa E501 # pylint: disable=c0301
|
|
1926
|
+
container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version)
|
|
1927
|
+
self._last_analysis_config = None
|
|
1928
|
+
self.job_name_prefix = job_name_prefix
|
|
1929
|
+
self.skip_early_validation = skip_early_validation
|
|
1930
|
+
super(SageMakerClarifyProcessor, self).__init__(
|
|
1931
|
+
role,
|
|
1932
|
+
container_uri,
|
|
1933
|
+
instance_count,
|
|
1934
|
+
instance_type,
|
|
1935
|
+
None, # We manage the entrypoint.
|
|
1936
|
+
volume_size_in_gb,
|
|
1937
|
+
volume_kms_key,
|
|
1938
|
+
output_kms_key,
|
|
1939
|
+
max_runtime_in_seconds,
|
|
1940
|
+
None, # We set method-specific job names below.
|
|
1941
|
+
sagemaker_session,
|
|
1942
|
+
env,
|
|
1943
|
+
format_tags(tags),
|
|
1944
|
+
network_config,
|
|
1945
|
+
)
|
|
1946
|
+
|
|
1947
|
+
def run(self, **_):
|
|
1948
|
+
"""Overriding the base class method but deferring to specific run_* methods."""
|
|
1949
|
+
raise NotImplementedError(
|
|
1950
|
+
"Please choose a method of run_pre_training_bias, run_post_training_bias or "
|
|
1951
|
+
"run_explainability."
|
|
1952
|
+
)
|
|
1953
|
+
|
|
1954
|
+

    def _run(
        self,
        data_config: DataConfig,
        analysis_config: Dict[str, Any],
        wait: bool,
        logs: bool,
        job_name: str,
        kms_key: str,
        experiment_config: Dict[str, str],
    ):
        """Runs a :class:`~sagemaker.processing.ProcessingJob` with the SageMaker Clarify container

        and analysis config.

        Args:
            data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
            analysis_config (dict): Config following the analysis_config.json format.
            wait (bool): Whether the call should wait until the job completes (default: True).
            logs (bool): Whether to show the logs produced by the job.
                Only meaningful when ``wait`` is True (default: True).
            job_name (str): Processing job name.
            kms_key (str): The ARN of the KMS key that is used to encrypt the
                user code file (default: None).
            experiment_config (dict[str, str]): Experiment management configuration.
                Optionally, the dict can contain three keys:
                ``'ExperimentName'``, ``'TrialName'``, and ``'TrialComponentDisplayName'``.

                The behavior of setting these keys is as follows:
                * If ``'ExperimentName'`` is supplied but ``'TrialName'`` is not, a Trial will be
                  automatically created and the job's Trial Component associated with the Trial.
                * If ``'TrialName'`` is supplied and the Trial already exists,
                  the job's Trial Component will be associated with the Trial.
                * If both ``'ExperimentName'`` and ``'TrialName'`` are not supplied,
                  the Trial Component will be unassociated.
                * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
        """
        # for debugging: to access locally, i.e. without a need to look for it in an S3 bucket
        self._last_analysis_config = analysis_config
        logger.info("Analysis Config: %s", analysis_config)
        if not self.skip_early_validation:
            ANALYSIS_CONFIG_SCHEMA_V1_0.validate(analysis_config)

        with tempfile.TemporaryDirectory() as tmpdirname:
            analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")
            with open(analysis_config_file, "w") as f:
                json.dump(analysis_config, f)
            s3_analysis_config_file = _upload_analysis_config(
                analysis_config_file,
                data_config.s3_analysis_config_output_path or data_config.s3_output_path,
                self.sagemaker_session,
                kms_key,
            )
            from sagemaker.core.shapes import ProcessingS3Input, ProcessingS3Output

            config_input = ProcessingInput(
                input_name="analysis_config",
                s3_input=ProcessingS3Input(
                    s3_uri=s3_analysis_config_file,
                    local_path=self._CLARIFY_CONFIG_INPUT,
                    s3_data_type="S3Prefix",
                    s3_input_mode="File",
                    s3_compression_type="None",
                ),
            )
            data_input = ProcessingInput(
                input_name="dataset",
                s3_input=ProcessingS3Input(
                    s3_uri=data_config.s3_data_input_path,
                    local_path=self._CLARIFY_DATA_INPUT,
                    s3_data_type="S3Prefix",
                    s3_input_mode="File",
                    s3_data_distribution_type=data_config.s3_data_distribution_type,
                    s3_compression_type=data_config.s3_compression_type,
                ),
            )
            result_output = ProcessingOutput(
                output_name="analysis_result",
                s3_output=ProcessingS3Output(
                    s3_uri=data_config.s3_output_path,
                    local_path=self._CLARIFY_OUTPUT,
                    s3_upload_mode=ProcessingOutputHandler.get_s3_upload_mode(analysis_config),
                ),
            )

            return super().run(
                inputs=[data_input, config_input],
                outputs=[result_output],
                wait=wait,
                logs=logs,
                job_name=job_name,
                kms_key=kms_key,
                experiment_config=experiment_config,
            )
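
    # For orientation, a hedged sketch of the ``analysis_config`` dict that
    # ``_run`` serializes to analysis_config.json and uploads to S3. The shape
    # is assembled from the _AnalysisConfigGenerator helpers defined below;
    # concrete field values here are placeholders:
    #
    #     {
    #         "dataset_type": "text/csv",
    #         "label": "target",
    #         "methods": {
    #             "report": {"name": "report", "title": "Analysis Report"},
    #             "pre_training_bias": {"methods": "all"},
    #         },
    #         "predictor": {...},  # present only when a ModelConfig is supplied
    #     }
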

    def run_pre_training_bias(
        self,
        data_config: DataConfig,
        data_bias_config: BiasConfig,
        methods: Union[str, List[str]] = "all",
        wait: bool = True,
        logs: bool = True,
        job_name: Optional[str] = None,
        kms_key: Optional[str] = None,
        experiment_config: Optional[Dict[str, str]] = None,
    ):
        """Runs a :class:`~sagemaker.processing.ProcessingJob` to compute pre-training bias methods

        Computes the requested ``methods`` on the input data. The ``methods`` compare
        metrics (e.g. fraction of examples) for the sensitive group(s) vs. the other examples.

        Args:
            data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
            data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
            methods (str or list[str]): Selects a subset of potential metrics:
                ["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance.html>`_",
                "`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-true-label-imbalance.html>`_",
                "`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kl-divergence.html>`_",
                "`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-jensen-shannon-divergence.html>`_",
                "`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-lp-norm.html>`_",
                "`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-total-variation-distance.html>`_",
                "`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kolmogorov-smirnov.html>`_",
                "`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-cddl.html>`_"].
                Defaults to str "all" to run all metrics if left unspecified.
            wait (bool): Whether the call should wait until the job completes (default: True).
            logs (bool): Whether to show the logs produced by the job.
                Only meaningful when ``wait`` is True (default: True).
            job_name (str): Processing job name. When ``job_name`` is not specified,
                if ``job_name_prefix`` in :class:`~sagemaker.clarify.SageMakerClarifyProcessor` is
                specified, the job name will be the ``job_name_prefix`` and current timestamp;
                otherwise use ``"Clarify-Pretraining-Bias"`` as prefix.
            kms_key (str): The ARN of the KMS key that is used to encrypt the
                user code file (default: None).
            experiment_config (dict[str, str]): Experiment management configuration.
                Optionally, the dict can contain three keys:
                ``'ExperimentName'``, ``'TrialName'``, and ``'TrialComponentDisplayName'``.

                The behavior of setting these keys is as follows:

                * If ``'ExperimentName'`` is supplied but ``'TrialName'`` is not, a Trial will be
                  automatically created and the job's Trial Component associated with the Trial.
                * If ``'TrialName'`` is supplied and the Trial already exists,
                  the job's Trial Component will be associated with the Trial.
                * If both ``'ExperimentName'`` and ``'TrialName'`` are not supplied,
                  the Trial Component will be unassociated.
                * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
        """  # noqa E501  # pylint: disable=c0301
        analysis_config = _AnalysisConfigGenerator.bias_pre_training(
            data_config, data_bias_config, methods
        )
        # when name is either not provided (is None) or an empty string ("")
        job_name = job_name or name_from_base(self.job_name_prefix or "Clarify-Pretraining-Bias")
        return self._run(
            data_config,
            analysis_config,
            wait,
            logs,
            job_name,
            kms_key,
            experiment_config,
        )
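
    # Hedged usage sketch for ``run_pre_training_bias``. The DataConfig and
    # BiasConfig constructor arguments follow the Clarify API exported by this
    # module; the S3 URIs, headers, and facet values are placeholders:
    #
    #     data_config = DataConfig(
    #         s3_data_input_path="s3://my-bucket/input/train.csv",
    #         s3_output_path="s3://my-bucket/clarify-output/",
    #         label="target",
    #         headers=["age", "income", "gender", "target"],
    #         dataset_type="text/csv",
    #     )
    #     bias_config = BiasConfig(
    #         label_values_or_threshold=[1],
    #         facet_name="gender",
    #     )
    #     processor.run_pre_training_bias(data_config, bias_config, methods=["CI", "DPL"])
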

    def run_post_training_bias(
        self,
        data_config: DataConfig,
        data_bias_config: BiasConfig,
        model_config: Optional[ModelConfig] = None,
        model_predicted_label_config: Optional[ModelPredictedLabelConfig] = None,
        methods: Union[str, List[str]] = "all",
        wait: bool = True,
        logs: bool = True,
        job_name: Optional[str] = None,
        kms_key: Optional[str] = None,
        experiment_config: Optional[Dict[str, str]] = None,
    ):
        """Runs a :class:`~sagemaker.processing.ProcessingJob` to compute post-training bias

        Spins up a model endpoint and runs inference over the input dataset in
        the ``s3_data_input_path`` (from the :class:`~sagemaker.clarify.DataConfig`) to obtain
        predicted labels. Using model predictions, computes the requested post-training bias
        ``methods`` that compare metrics (e.g. accuracy, precision, recall) for the
        sensitive group(s) versus the other examples.

        Args:
            data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
            data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
            model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
                endpoint to be created. This is required unless ``predicted_label_dataset_uri`` or
                ``predicted_label`` is provided in ``data_config``.
            model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
                Config of how to extract the predicted label from the model output.
            methods (str or list[str]): Selector of a subset of potential metrics:
                ["`DPPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dppl.html>`_",
                "`DI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-di.html>`_",
                "`DCA <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dca.html>`_",
                "`DCR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dcr.html>`_",
                "`RD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-rd.html>`_",
                "`DAR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dar.html>`_",
                "`DRR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-drr.html>`_",
                "`AD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ad.html>`_",
                "`CDDPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cddpl.html>`_",
                "`TE <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-te.html>`_",
                "`FT <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ft.html>`_"].
                Defaults to str "all" to run all metrics if left unspecified.
            wait (bool): Whether the call should wait until the job completes (default: True).
            logs (bool): Whether to show the logs produced by the job.
                Only meaningful when ``wait`` is True (default: True).
            job_name (str): Processing job name. When ``job_name`` is not specified,
                if ``job_name_prefix`` in :class:`~sagemaker.clarify.SageMakerClarifyProcessor`
                is specified, the job name will be the ``job_name_prefix`` and current timestamp;
                otherwise use ``"Clarify-Posttraining-Bias"`` as prefix.
            kms_key (str): The ARN of the KMS key that is used to encrypt the
                user code file (default: None).
            experiment_config (dict[str, str]): Experiment management configuration.
                Optionally, the dict can contain three keys:
                ``'ExperimentName'``, ``'TrialName'``, and ``'TrialComponentDisplayName'``.

                The behavior of setting these keys is as follows:

                * If ``'ExperimentName'`` is supplied but ``'TrialName'`` is not, a Trial will be
                  automatically created and the job's Trial Component associated with the Trial.
                * If ``'TrialName'`` is supplied and the Trial already exists,
                  the job's Trial Component will be associated with the Trial.
                * If both ``'ExperimentName'`` and ``'TrialName'`` are not supplied,
                  the Trial Component will be unassociated.
                * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
        """  # noqa E501  # pylint: disable=c0301
        analysis_config = _AnalysisConfigGenerator.bias_post_training(
            data_config,
            data_bias_config,
            model_predicted_label_config,
            methods,
            model_config,
        )
        # when name is either not provided (is None) or an empty string ("")
        job_name = job_name or name_from_base(self.job_name_prefix or "Clarify-Posttraining-Bias")
        return self._run(
            data_config,
            analysis_config,
            wait,
            logs,
            job_name,
            kms_key,
            experiment_config,
        )
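
    # Hedged sketch for ``run_post_training_bias``. A shadow endpoint is created
    # from the ModelConfig, so the model name below must refer to an existing
    # SageMaker model; all values are placeholders:
    #
    #     model_config = ModelConfig(
    #         model_name="my-registered-model",
    #         instance_type="ml.m5.xlarge",
    #         instance_count=1,
    #         accept_type="text/csv",
    #     )
    #     predicted_label_config = ModelPredictedLabelConfig(probability_threshold=0.8)
    #     processor.run_post_training_bias(
    #         data_config, bias_config, model_config, predicted_label_config
    #     )
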

    def run_bias(
        self,
        data_config: DataConfig,
        bias_config: BiasConfig,
        model_config: Optional[ModelConfig] = None,
        model_predicted_label_config: Optional[ModelPredictedLabelConfig] = None,
        pre_training_methods: Union[str, List[str]] = "all",
        post_training_methods: Union[str, List[str]] = "all",
        wait: bool = True,
        logs: bool = True,
        job_name: Optional[str] = None,
        kms_key: Optional[str] = None,
        experiment_config: Optional[Dict[str, str]] = None,
    ):
        """Runs a :class:`~sagemaker.processing.ProcessingJob` to compute the requested bias methods

        Computes metrics for both the pre-training and the post-training methods.
        To calculate post-training methods, it spins up a model endpoint and runs inference over the
        input examples in ``s3_data_input_path`` (from the :class:`~sagemaker.clarify.DataConfig`)
        to obtain predicted labels.

        Args:
            data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
            bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
            model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
                endpoint to be created. This is required unless ``predicted_label_dataset_uri`` or
                ``predicted_label`` is provided in ``data_config``.
            model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
                Config of how to extract the predicted label from the model output.
            pre_training_methods (str or list[str]): Selector of a subset of potential metrics:
                ["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance.html>`_",
                "`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-true-label-imbalance.html>`_",
                "`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kl-divergence.html>`_",
                "`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-jensen-shannon-divergence.html>`_",
                "`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-lp-norm.html>`_",
                "`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-total-variation-distance.html>`_",
                "`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kolmogorov-smirnov.html>`_",
                "`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-cddl.html>`_"].
                Defaults to str "all" to run all metrics if left unspecified.
            post_training_methods (str or list[str]): Selector of a subset of potential metrics:
                ["`DPPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dppl.html>`_",
                "`DI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-di.html>`_",
                "`DCA <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dca.html>`_",
                "`DCR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dcr.html>`_",
                "`RD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-rd.html>`_",
                "`DAR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dar.html>`_",
                "`DRR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-drr.html>`_",
                "`AD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ad.html>`_",
                "`CDDPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cddpl.html>`_",
                "`TE <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-te.html>`_",
                "`FT <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ft.html>`_"].
                Defaults to str "all" to run all metrics if left unspecified.
            wait (bool): Whether the call should wait until the job completes (default: True).
            logs (bool): Whether to show the logs produced by the job.
                Only meaningful when ``wait`` is True (default: True).
            job_name (str): Processing job name. When ``job_name`` is not specified,
                if ``job_name_prefix`` in :class:`~sagemaker.clarify.SageMakerClarifyProcessor` is
                specified, the job name will be ``job_name_prefix`` and the current timestamp;
                otherwise use ``"Clarify-Bias"`` as prefix.
            kms_key (str): The ARN of the KMS key that is used to encrypt the
                user code file (default: None).
            experiment_config (dict[str, str]): Experiment management configuration.
                Optionally, the dict can contain three keys:
                ``'ExperimentName'``, ``'TrialName'``, and ``'TrialComponentDisplayName'``.

                The behavior of setting these keys is as follows:

                * If ``'ExperimentName'`` is supplied but ``'TrialName'`` is not, a Trial will be
                  automatically created and the job's Trial Component associated with the Trial.
                * If ``'TrialName'`` is supplied and the Trial already exists,
                  the job's Trial Component will be associated with the Trial.
                * If both ``'ExperimentName'`` and ``'TrialName'`` are not supplied,
                  the Trial Component will be unassociated.
                * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
        """  # noqa E501  # pylint: disable=c0301
        analysis_config = _AnalysisConfigGenerator.bias(
            data_config,
            bias_config,
            model_config,
            model_predicted_label_config,
            pre_training_methods,
            post_training_methods,
        )
        # when name is either not provided (is None) or an empty string ("")
        job_name = job_name or name_from_base(self.job_name_prefix or "Clarify-Bias")
        return self._run(
            data_config,
            analysis_config,
            wait,
            logs,
            job_name,
            kms_key,
            experiment_config,
        )
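
    # Hedged sketch for ``run_bias``, which combines both metric families in a
    # single job (configs reused from the sketches above; metric subsets are
    # placeholders):
    #
    #     processor.run_bias(
    #         data_config,
    #         bias_config,
    #         model_config,
    #         pre_training_methods=["CI", "DPL"],
    #         post_training_methods=["DPPL", "DI"],
    #     )
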

    def run_explainability(
        self,
        data_config: DataConfig,
        model_config: ModelConfig,
        explainability_config: Union[ExplainabilityConfig, List],
        model_scores: Optional[Union[int, str, ModelPredictedLabelConfig]] = None,
        wait: bool = True,
        logs: bool = True,
        job_name: Optional[str] = None,
        kms_key: Optional[str] = None,
        experiment_config: Optional[Dict[str, str]] = None,
    ):
        """Runs a :class:`~sagemaker.processing.ProcessingJob` computing feature attributions.

        Spins up a model endpoint.

        Currently, only SHAP and Partial Dependence Plots (PDP) are supported
        as explainability methods.
        You can request both methods or one at a time with the ``explainability_config`` parameter.

        When SHAP is requested in the ``explainability_config``,
        the SHAP algorithm calculates the feature importance for each input example
        in the ``s3_data_input_path`` of the :class:`~sagemaker.clarify.DataConfig`,
        by creating ``num_samples`` copies of the example with a subset of features
        replaced with values from the ``baseline``.
        It then runs model inference to see how the model's prediction changes with the replaced
        features. If the model output returns multiple scores, importance is computed for each score.
        Across examples, feature importance is aggregated using ``agg_method``.

        When PDP is requested in the ``explainability_config``,
        the PDP algorithm calculates the dependence of the target response
        on the input features and marginalizes over the values of all other input features.
        The Partial Dependence Plots are included in the output
        `report <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-feature-attribute-baselines-reports.html>`__
        and the corresponding values are included in the analysis output.

        Args:
            data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
            model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
                endpoint to be created.
            explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list):
                Config of the specific explainability method or a list of
                :class:`~sagemaker.clarify.ExplainabilityConfig` objects.
                Currently, SHAP and PDP are the two methods supported.
                You can request multiple methods at once by passing in a list of
                `~sagemaker.clarify.ExplainabilityConfig`.
            model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
                Index or JMESPath expression to locate the predicted scores in the model output.
                This is not required if the model output is a single score. Alternatively,
                it can be an instance of :class:`~sagemaker.clarify.ModelPredictedLabelConfig`
                to provide more parameters like ``label_headers``.
            wait (bool): Whether the call should wait until the job completes (default: True).
            logs (bool): Whether to show the logs produced by the job.
                Only meaningful when ``wait`` is True (default: True).
            job_name (str): Processing job name. When ``job_name`` is not specified,
                if ``job_name_prefix`` in :class:`~sagemaker.clarify.SageMakerClarifyProcessor`
                is specified, the job name will be composed of ``job_name_prefix`` and current
                timestamp; otherwise use ``"Clarify-Explainability"`` as prefix.
            kms_key (str): The ARN of the KMS key that is used to encrypt the
                user code file (default: None).
            experiment_config (dict[str, str]): Experiment management configuration.
                Optionally, the dict can contain three keys:
                ``'ExperimentName'``, ``'TrialName'``, and ``'TrialComponentDisplayName'``.

                The behavior of setting these keys is as follows:

                * If ``'ExperimentName'`` is supplied but ``'TrialName'`` is not, a Trial will be
                  automatically created and the job's Trial Component associated with the Trial.
                * If ``'TrialName'`` is supplied and the Trial already exists,
                  the job's Trial Component will be associated with the Trial.
                * If both ``'ExperimentName'`` and ``'TrialName'`` are not supplied,
                  the Trial Component will be unassociated.
                * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
        """  # noqa E501  # pylint: disable=c0301
        analysis_config = _AnalysisConfigGenerator.explainability(
            data_config, model_config, model_scores, explainability_config
        )
        # when name is either not provided (is None) or an empty string ("")
        job_name = job_name or name_from_base(self.job_name_prefix or "Clarify-Explainability")
        return self._run(
            data_config,
            analysis_config,
            wait,
            logs,
            job_name,
            kms_key,
            experiment_config,
        )
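
    # Hedged sketch for ``run_explainability`` requesting SHAP and PDP in one
    # job. The baseline row, sample count, and feature names are placeholders;
    # SHAPConfig and PDPConfig are exported by this module:
    #
    #     shap_config = SHAPConfig(
    #         baseline=[[35, 50000, 0]],  # one baseline row matching the model features
    #         num_samples=100,
    #         agg_method="mean_abs",
    #     )
    #     pdp_config = PDPConfig(features=["age", "income"], grid_resolution=15)
    #     processor.run_explainability(data_config, model_config, [shap_config, pdp_config])
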

    def run_bias_and_explainability(
        self,
        data_config: DataConfig,
        model_config: ModelConfig,
        explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]],
        bias_config: BiasConfig,
        pre_training_methods: Union[str, List[str]] = "all",
        post_training_methods: Union[str, List[str]] = "all",
        model_predicted_label_config: ModelPredictedLabelConfig = None,
        wait=True,
        logs=True,
        job_name=None,
        kms_key=None,
        experiment_config=None,
    ):
        """Runs a :class:`~sagemaker.processing.ProcessingJob` computing feature attributions.

        For bias:
        Computes metrics for both the pre-training and the post-training methods.
        To calculate post-training methods, it spins up a model endpoint and runs inference over the
        input examples in ``s3_data_input_path`` (from the :class:`~sagemaker.clarify.DataConfig`)
        to obtain predicted labels.

        For Explainability:
        Spins up a model endpoint.

        Currently, only SHAP and Partial Dependence Plots (PDP) are supported
        as explainability methods.
        You can request both methods or one at a time with the ``explainability_config`` parameter.

        When SHAP is requested in the ``explainability_config``,
        the SHAP algorithm calculates the feature importance for each input example
        in the ``s3_data_input_path`` of the :class:`~sagemaker.clarify.DataConfig`,
        by creating ``num_samples`` copies of the example with a subset of features
        replaced with values from the ``baseline``.
        It then runs model inference to see how the model's prediction changes with the replaced
        features. If the model output returns multiple scores, importance is computed for each score.
        Across examples, feature importance is aggregated using ``agg_method``.

        When PDP is requested in the ``explainability_config``,
        the PDP algorithm calculates the dependence of the target response
        on the input features and marginalizes over the values of all other input features.
        The Partial Dependence Plots are included in the output
        `report <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-feature-attribute-baselines-reports.html>`__
        and the corresponding values are included in the analysis output.

        Args:
            data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
            model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
                endpoint to be created.
            explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list):
                Config of the specific explainability method or a list of
                :class:`~sagemaker.clarify.ExplainabilityConfig` objects.
                Currently, SHAP and PDP are the two methods supported.
                You can request multiple methods at once by passing in a list of
                `~sagemaker.clarify.ExplainabilityConfig`.
            bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
            pre_training_methods (str or list[str]): Selector of a subset of potential metrics:
                ["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance.html>`_",
                "`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-true-label-imbalance.html>`_",
                "`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kl-divergence.html>`_",
                "`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-jensen-shannon-divergence.html>`_",
                "`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-lp-norm.html>`_",
                "`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-total-variation-distance.html>`_",
                "`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kolmogorov-smirnov.html>`_",
                "`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-cddl.html>`_"].
                Defaults to str "all" to run all metrics if left unspecified.
            post_training_methods (str or list[str]): Selector of a subset of potential metrics:
                ["`DPPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dppl.html>`_",
                "`DI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-di.html>`_",
                "`DCA <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dca.html>`_",
                "`DCR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dcr.html>`_",
                "`RD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-rd.html>`_",
                "`DAR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dar.html>`_",
                "`DRR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-drr.html>`_",
                "`AD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ad.html>`_",
                "`CDDPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cddpl.html>`_",
                "`TE <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-te.html>`_",
                "`FT <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ft.html>`_"].
                Defaults to str "all" to run all metrics if left unspecified.
            model_predicted_label_config (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
                Index or JMESPath expression to locate the predicted scores in the model output.
                This is not required if the model output is a single score. Alternatively,
                it can be an instance of :class:`~sagemaker.clarify.ModelPredictedLabelConfig`
                to provide more parameters like ``label_headers``.
            wait (bool): Whether the call should wait until the job completes (default: True).
            logs (bool): Whether to show the logs produced by the job.
                Only meaningful when ``wait`` is True (default: True).
            job_name (str): Processing job name. When ``job_name`` is not specified,
                if ``job_name_prefix`` in :class:`~sagemaker.clarify.SageMakerClarifyProcessor`
                is specified, the job name will be composed of ``job_name_prefix`` and current
                timestamp; otherwise use ``"Clarify-Explainability"`` as prefix.
            kms_key (str): The ARN of the KMS key that is used to encrypt the
                user code file (default: None).
            experiment_config (dict[str, str]): Experiment management configuration.
                Optionally, the dict can contain three keys:
                ``'ExperimentName'``, ``'TrialName'``, and ``'TrialComponentDisplayName'``.

                The behavior of setting these keys is as follows:

                * If ``'ExperimentName'`` is supplied but ``'TrialName'`` is not, a Trial will be
                  automatically created and the job's Trial Component associated with the Trial.
                * If ``'TrialName'`` is supplied and the Trial already exists,
                  the job's Trial Component will be associated with the Trial.
                * If both ``'ExperimentName'`` and ``'TrialName'`` are not supplied,
                  the Trial Component will be unassociated.
                * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
        """  # noqa E501  # pylint: disable=c0301
        analysis_config = _AnalysisConfigGenerator.bias_and_explainability(
            data_config,
            model_config,
            model_predicted_label_config,
            explainability_config,
            bias_config,
            pre_training_methods,
            post_training_methods,
        )
        # when name is either not provided (is None) or an empty string ("")
        job_name = job_name or name_from_base(
            self.job_name_prefix or "Clarify-Bias-And-Explainability"
        )
        return self._run(
            data_config,
            analysis_config,
            wait,
            logs,
            job_name,
            kms_key,
            experiment_config,
        )


class _AnalysisConfigGenerator:
    """Creates analysis_config objects for different types of runs."""

    @classmethod
    def bias_and_explainability(
        cls,
        data_config: DataConfig,
        model_config: ModelConfig,
        model_predicted_label_config: ModelPredictedLabelConfig,
        explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]],
        bias_config: BiasConfig,
        pre_training_methods: Union[str, List[str]] = "all",
        post_training_methods: Union[str, List[str]] = "all",
    ):
        """Generates a config for Bias and Explainability"""
        # TimeSeries bias metrics are not supported
        if (
            isinstance(explainability_config, AsymmetricShapleyValueConfig)
            or "time_series_data_config" in data_config.analysis_config
            or (model_config and "time_series_predictor_config" in model_config.predictor_config)
        ):
            raise ValueError("Bias metrics are unsupported for time series.")
        analysis_config = {**data_config.get_config(), **bias_config.get_config()}
        analysis_config = cls._add_methods(
            analysis_config,
            pre_training_methods=pre_training_methods,
            post_training_methods=post_training_methods,
            explainability_config=explainability_config,
        )
        analysis_config = cls._add_predictor(
            analysis_config, model_config, model_predicted_label_config
        )
        return analysis_config

    @classmethod
    def explainability(
        cls,
        data_config: DataConfig,
        model_config: ModelConfig,
        model_predicted_label_config: ModelPredictedLabelConfig,
        explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]],
    ):
        """Generates a config for Explainability"""
        # determine if this is a time series explainability case by checking
        # if *both* TimeSeriesDataConfig and TimeSeriesModelConfig were given
        ts_data_conf_absent = "time_series_data_config" not in data_config.analysis_config
        ts_model_conf_absent = "time_series_predictor_config" not in model_config.predictor_config

        if isinstance(explainability_config, AsymmetricShapleyValueConfig):
            if ts_data_conf_absent:
                raise ValueError("Please provide a TimeSeriesDataConfig to DataConfig.")
            if ts_model_conf_absent:
                raise ValueError("Please provide a TimeSeriesModelConfig to ModelConfig.")
            # Check static covariates baseline matches number of provided static covariate columns
            _AnalysisConfigGenerator._validate_time_series_static_covariates_baseline(
                explainability_config=explainability_config,
                data_config=data_config,
            )
        else:
            if not ts_data_conf_absent:
                raise ValueError(
                    "Please provide an AsymmetricShapleyValueConfig for time series "
                    "explainability cases. For non time series cases, please do not provide a "
                    "TimeSeriesDataConfig."
                )
            if not ts_model_conf_absent:
                raise ValueError(
                    "Please provide an AsymmetricShapleyValueConfig for time series "
                    "explainability cases. For non time series cases, please do not provide a "
                    "TimeSeriesModelConfig."
                )

        # construct whole analysis config
        analysis_config = data_config.analysis_config
        analysis_config = cls._add_predictor(
            analysis_config, model_config, model_predicted_label_config
        )
        analysis_config = cls._add_methods(
            analysis_config,
            explainability_config=explainability_config,
        )
        return analysis_config

    @classmethod
    def bias_pre_training(
        cls,
        data_config: DataConfig,
        bias_config: BiasConfig,
        methods: Union[str, List[str]],
    ):
        """Generates a config for Bias Pre Training"""
        analysis_config = {**data_config.get_config(), **bias_config.get_config()}
        analysis_config = cls._add_methods(analysis_config, pre_training_methods=methods)
        return analysis_config

    @classmethod
    def bias_post_training(
        cls,
        data_config: DataConfig,
        bias_config: BiasConfig,
        model_predicted_label_config: ModelPredictedLabelConfig,
        methods: Union[str, List[str]],
        model_config: ModelConfig,
    ):
        """Generates a config for Bias Post Training"""
        analysis_config = {**data_config.get_config(), **bias_config.get_config()}
        analysis_config = cls._add_methods(analysis_config, post_training_methods=methods)
        analysis_config = cls._add_predictor(
            analysis_config, model_config, model_predicted_label_config
        )
        return analysis_config

    @classmethod
    def bias(
        cls,
        data_config: DataConfig,
        bias_config: BiasConfig,
        model_config: ModelConfig,
        model_predicted_label_config: ModelPredictedLabelConfig,
        pre_training_methods: Union[str, List[str]] = "all",
        post_training_methods: Union[str, List[str]] = "all",
    ):
        """Generates a config for Bias"""
        analysis_config = {**data_config.get_config(), **bias_config.get_config()}
        analysis_config = cls._add_methods(
            analysis_config,
            pre_training_methods=pre_training_methods,
            post_training_methods=post_training_methods,
        )
        analysis_config = cls._add_predictor(
            analysis_config, model_config, model_predicted_label_config
        )
        return analysis_config

    @classmethod
    def _add_predictor(
        cls,
        analysis_config: Dict,
        model_config: ModelConfig,
        model_predicted_label_config: ModelPredictedLabelConfig,
    ):
        """Extends analysis config with predictor."""
        analysis_config = {**analysis_config}
        if isinstance(model_config, ModelConfig):
            analysis_config["predictor"] = model_config.get_predictor_config()
        else:
            if (
                "shap" in analysis_config["methods"]
                or "pdp" in analysis_config["methods"]
                or "asymmetric_shapley_value" in analysis_config["methods"]
            ):
                raise ValueError(
                    "model_config must be provided when explainability methods are selected."
                )
            if (
                "predicted_label_dataset_uri" not in analysis_config
                and "predicted_label" not in analysis_config
            ):
                raise ValueError(
                    "model_config must be provided when `predicted_label_dataset_uri` or "
                    "`predicted_label` are not provided in data_config."
                )
        if isinstance(model_predicted_label_config, ModelPredictedLabelConfig):
            (
                probability_threshold,
                predictor_config,
            ) = model_predicted_label_config.get_predictor_config()
            if predictor_config and "predictor" in analysis_config:
                analysis_config["predictor"].update(predictor_config)
            _set(probability_threshold, "probability_threshold", analysis_config)
        elif "predictor" in analysis_config:
            _set(model_predicted_label_config, "label", analysis_config["predictor"])
        return analysis_config

    @classmethod
    def _add_methods(
        cls,
        analysis_config: Dict,
        pre_training_methods: Union[str, List[str]] = None,
        post_training_methods: Union[str, List[str]] = None,
        explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]] = None,
        report: bool = True,
    ):
        """Extends analysis config with methods."""
        # validate
        params = [pre_training_methods, post_training_methods, explainability_config]
        if not any(params):
            raise AttributeError(
                "analysis_config must have at least one working method: "
                "One of the "
                "`pre_training_methods`, `post_training_methods`, `explainability_config`."
            )

        # main logic
        analysis_config = {**analysis_config}
        if "methods" not in analysis_config:
            analysis_config["methods"] = {}

        if report:
            analysis_config["methods"]["report"] = {
                "name": "report",
                "title": "Analysis Report",
            }

        if pre_training_methods:
            analysis_config["methods"]["pre_training_bias"] = {"methods": pre_training_methods}

        if post_training_methods:
            analysis_config["methods"]["post_training_bias"] = {"methods": post_training_methods}

        if explainability_config is not None:
            if isinstance(explainability_config, AsymmetricShapleyValueConfig):
                explainability_methods = explainability_config.get_explainability_config()
            else:
                explainability_methods = cls._merge_explainability_configs(
                    explainability_config,
                )
            analysis_config["methods"] = {
                **analysis_config["methods"],
                **explainability_methods,
            }
        return analysis_config
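
    # After ``_add_methods`` runs for a combined bias-and-SHAP request, the
    # "methods" section looks roughly like this (a hedged sketch; the shap
    # entry's fields come from SHAPConfig.get_explainability_config()):
    #
    #     {
    #         "report": {"name": "report", "title": "Analysis Report"},
    #         "pre_training_bias": {"methods": "all"},
    #         "post_training_bias": {"methods": "all"},
    #         "shap": {...},
    #     }
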

    @classmethod
    def _merge_explainability_configs(
        cls,
        explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]],
    ):
        """Merges explainability configs, when more than one."""
        non_ts = "Please do not provide Asymmetric Shapley Value configs for non-TimeSeries uses."
        # validation
        if isinstance(explainability_config, AsymmetricShapleyValueConfig):
            raise ValueError(non_ts)
        if (
            isinstance(explainability_config, PDPConfig)
            and "features" not in explainability_config.get_explainability_config()["pdp"]
        ):
            raise ValueError("PDP features must be provided when ShapConfig is not provided")
        if isinstance(explainability_config, list):
            if len(explainability_config) == 0:
                raise ValueError("Please provide at least one explainability config.")
            # list validation
            for config in explainability_config:
                # ensure all provided explainability configs are not AsymmetricShapleyValueConfig
                if isinstance(config, AsymmetricShapleyValueConfig):
                    raise ValueError(non_ts)
            # main logic
            explainability_methods = {}
            for config in explainability_config:
                explain_config = config.get_explainability_config()
                explainability_methods.update(explain_config)
            if not len(explainability_methods) == len(explainability_config):
                raise ValueError("Duplicate explainability configs are provided")
            if (
                "shap" not in explainability_methods
                and "features" not in explainability_methods["pdp"]
            ):
                raise ValueError("PDP features must be provided when ShapConfig is not provided")
            return explainability_methods
        return explainability_config.get_explainability_config()
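
    # Sketch of the merge behavior above: each config contributes one top-level
    # key, so [SHAPConfig(...), PDPConfig(features=[...])] yields
    # {"shap": {...}, "pdp": {...}}. Passing two SHAPConfig objects would
    # collapse to a single "shap" key and trip the duplicate-config check.
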

    @classmethod
    def _validate_time_series_static_covariates_baseline(
        cls,
        explainability_config: AsymmetricShapleyValueConfig,
        data_config: DataConfig,
    ):
        """Validates static covariates in baseline for asymmetric shapley value (for time series).

        Checks that baseline values set for static covariate columns are
        consistent between every item_id and the number of static covariate columns
        provided in DataConfig.
        """
        baseline = explainability_config.get_explainability_config()[
            "asymmetric_shapley_value"
        ].get("baseline")
        if isinstance(baseline, dict) and "static_covariates" in baseline:
            covariate_count = len(
                data_config.get_config()["time_series_data_config"].get("static_covariates", [])
            )
            if covariate_count > 0:
                for item_id in baseline.get("static_covariates", []):
                    baseline_entry = baseline["static_covariates"][item_id]
                    if not isinstance(baseline_entry, list):
                        raise ValueError(
                            f"Baseline entry for {item_id} must be a list, is "
                            f"{type(baseline_entry)}."
                        )
                    if len(baseline_entry) != covariate_count:
                        raise ValueError(
                            f"Length of baseline entry for {item_id} does not match number "
                            f"of static covariate columns. Please ensure every covariate "
                            f"has a baseline value for every item id."
                        )
            else:
                raise ValueError(
                    "Static covariate baselines are provided in AsymmetricShapleyValueConfig "
                    "when no static covariate columns are provided in TimeSeriesDataConfig. "
                    "Please check these configs."
                )
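
# Hedged example of a baseline that passes the static-covariates validation
# above, assuming the TimeSeriesDataConfig declared exactly two static
# covariate columns (item ids and values are placeholders):
#
#     baseline = {
#         "static_covariates": {
#             "item-1": [0.5, 1.0],  # one value per declared covariate column
#             "item-2": [0.3, 2.0],
#         }
#     }
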

def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key):
    """Uploads the local ``analysis_config_file`` to the ``s3_output_path``.

    Args:
        analysis_config_file (str): File path to the local analysis config file.
        s3_output_path (str): S3 prefix to store the analysis config file.
        sagemaker_session (:class:`~sagemaker.session.Session`):
            :class:`~sagemaker.session.Session` object which manages interactions with
            Amazon SageMaker and any other AWS services needed. If not specified,
            the processor creates a :class:`~sagemaker.session.Session`
            using the default AWS configuration chain.
        kms_key (str): The ARN of the KMS key that is used to encrypt the
            user code file (default: None).

    Returns:
        The S3 URI of the uploaded file.
    """
    return s3.S3Uploader.upload(
        local_path=analysis_config_file,
        desired_s3_uri=s3_output_path,
        sagemaker_session=sagemaker_session,
        kms_key=kms_key,
    )


class ProcessingOutputHandler:
    """Class to handle the parameters for the ``ProcessingOutput`` of a SageMaker Processing job"""

    class S3UploadMode(Enum):
        """Enum values for different upload modes to the S3 bucket"""

        CONTINUOUS = "Continuous"
        ENDOFJOB = "EndOfJob"

    @classmethod
    def get_s3_upload_mode(cls, analysis_config: Dict[str, Any]) -> str:
        """Fetches the S3 upload mode based on the dataset type in the analysis config

        Args:
            analysis_config (dict): Config following the analysis_config.json format

        Returns:
            The s3_upload_mode type for the processing output.
        """
        dataset_type = analysis_config["dataset_type"]
        return (
            ProcessingOutputHandler.S3UploadMode.CONTINUOUS.value
            if dataset_type == DatasetType.IMAGE.value
            else ProcessingOutputHandler.S3UploadMode.ENDOFJOB.value
        )
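
# Sketch of the selection above (the image MIME string is an assumption about
# DatasetType.IMAGE.value, not confirmed by this diff):
#
#     ProcessingOutputHandler.get_s3_upload_mode({"dataset_type": "application/x-image"})
#     # -> "Continuous"
#     ProcessingOutputHandler.get_s3_upload_mode({"dataset_type": "text/csv"})
#     # -> "EndOfJob"
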

def _set(value, key, dictionary):
    """Sets dictionary[key] = value if value is not None."""
    if value is not None:
        dictionary[key] = value


# Public API
__all__ = [
    "AsymmetricShapleyValueConfig",
    "BiasConfig",
    "DataConfig",
    "DatasetType",
    "ExplainabilityConfig",
    "ImageConfig",
    "ModelConfig",
    "ModelPredictedLabelConfig",
    "PDPConfig",
    "ProcessingOutputHandler",
    "SageMakerClarifyProcessor",
    "SegmentationConfig",
    "SHAPConfig",
    "TextConfig",
    "TimeSeriesDataConfig",
    "TimeSeriesJSONDatasetFormat",
    "TimeSeriesModelConfig",
]