sagemaker-core 1.0.62__py3-none-any.whl → 2.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/__init__.py +2 -0
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2399 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +247 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1220 -0
- sagemaker/core/git_utils.py +415 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2977 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +509 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +770 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1267 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +686 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface-vllm-neuronx.json +38 -0
- sagemaker/core/image_uri_config/huggingface.json +2287 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +252 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +494 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +972 -0
- sagemaker/core/image_uris.py +816 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +197 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +501 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +171 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +423 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +246 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1497 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +239 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +205 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1599 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1310 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +347 -0
- sagemaker/core/remote_function/core/serialization.py +410 -0
- sagemaker/core/remote_function/core/stored_function.py +223 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +102 -0
- sagemaker/core/remote_function/invoke_function.py +167 -0
- sagemaker/core/remote_function/job.py +2121 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +82 -0
- sagemaker/core/telemetry/telemetry_logging.py +285 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +345 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +514 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.3.1.dist-info/RECORD +351 -0
- sagemaker_core-2.3.1.dist-info/top_level.txt +1 -0
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.62.dist-info/RECORD +0 -35
- sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
- {sagemaker_core/helper → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/modules/local_core}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1497 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module contains code related to Amazon SageMaker Explainability AI Model Monitoring.
|
|
14
|
+
|
|
15
|
+
These classes assist with suggesting baselines and creating monitoring schedules for monitoring
|
|
16
|
+
bias metrics and feature attribution of SageMaker Endpoints.
|
|
17
|
+
"""
|
|
18
|
+
from __future__ import print_function, absolute_import
|
|
19
|
+
|
|
20
|
+
import copy
|
|
21
|
+
import json
|
|
22
|
+
import logging
|
|
23
|
+
import uuid
|
|
24
|
+
|
|
25
|
+
from sagemaker.core.model_monitor import model_monitoring as mm
|
|
26
|
+
from sagemaker.core.model_monitor.utils import (
|
|
27
|
+
boto_describe_monitoring_schedule,
|
|
28
|
+
boto_list_monitoring_executions,
|
|
29
|
+
)
|
|
30
|
+
from sagemaker.core import image_uris, s3
|
|
31
|
+
from sagemaker.core.helper.session_helper import Session, expand_role
|
|
32
|
+
from sagemaker.core.common_utils import (
|
|
33
|
+
name_from_base,
|
|
34
|
+
format_tags,
|
|
35
|
+
get_resource_name_from_arn,
|
|
36
|
+
list_tags,
|
|
37
|
+
)
|
|
38
|
+
from sagemaker.core.clarify import SageMakerClarifyProcessor, ModelPredictedLabelConfig
|
|
39
|
+
from sagemaker.core.processing import logs_for_processing_job
|
|
40
|
+
|
|
41
|
+
# Setting _LOGGER for backward compatibility, in case users import it...
|
|
42
|
+
logger = _LOGGER = logging.getLogger(__name__)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class ClarifyModelMonitor(mm.ModelMonitor):
|
|
46
|
+
"""Base class of Amazon SageMaker Explainability API model monitors.
|
|
47
|
+
|
|
48
|
+
This class is an ``abstract base class``, please instantiate its subclasses
|
|
49
|
+
if you want to monitor bias metrics or feature attribution of an endpoint.
|
|
50
|
+
"""
|
|
51
|
+
|
|
52
|
+
def __init__(
    self,
    role=None,
    instance_count=1,
    instance_type="ml.m5.xlarge",
    volume_size_in_gb=30,
    volume_kms_key=None,
    output_kms_key=None,
    max_runtime_in_seconds=None,
    base_job_name=None,
    sagemaker_session=None,
    env=None,
    tags=None,
    network_config=None,
):
    """Set up a Clarify monitor; only subclasses may be instantiated.

    Resolves the session, looks up the "clarify" image URI for the session's
    region, and forwards everything to the base monitor's initializer.

    Args:
        role (str): An AWS IAM role. The Amazon SageMaker jobs use this role.
        instance_count (int): The number of instances to run the jobs with
            (default: 1).
        instance_type (str): Type of EC2 instance to use for the job
            (default: 'ml.m5.xlarge').
        volume_size_in_gb (int): Size in GB of the EBS volume to use for
            storing data during processing (default: 30).
        volume_kms_key (str): A KMS key for the job's volume.
        output_kms_key (str): The KMS key id for the job's outputs.
        max_runtime_in_seconds (int): Timeout in seconds, forwarded unchanged
            to the base monitor (default here: None).
            NOTE(review): earlier documentation quoted a default of 3600;
            confirm where, if anywhere, that value is applied.
        base_job_name (str): Prefix for the job name, forwarded unchanged to
            the base monitor.
        sagemaker_session (sagemaker.core.helper.session_helper.Session):
            Session object which manages interactions with Amazon SageMaker
            APIs. If it is falsy, a new ``Session()`` is created.
        env (dict): Environment variables to be passed to the job.
        tags (Optional[Tags]): Tags to be passed to the job; they are run
            through ``format_tags`` before being forwarded.
        network_config (sagemaker.network.NetworkConfig): A NetworkConfig
            object that configures network isolation, encryption of
            inter-container traffic, security group IDs, and subnets.

    Raises:
        TypeError: If this class itself (rather than a subclass) is
            instantiated.
    """
    # The base class is abstract by convention: refuse direct instantiation.
    if type(self) == __class__:  # pylint: disable=unidiomatic-typecheck
        abstract_name = __class__.__name__
        raise TypeError(
            "{} is abstract, please instantiate its subclasses instead.".format(abstract_name)
        )

    resolved_session = sagemaker_session if sagemaker_session else Session()
    region_name = resolved_session.boto_session.region_name
    clarify_image = image_uris.retrieve("clarify", region_name)
    normalized_tags = format_tags(tags)

    super().__init__(
        role=role,
        image_uri=clarify_image,
        instance_count=instance_count,
        instance_type=instance_type,
        volume_size_in_gb=volume_size_in_gb,
        volume_kms_key=volume_kms_key,
        output_kms_key=output_kms_key,
        max_runtime_in_seconds=max_runtime_in_seconds,
        base_job_name=base_job_name,
        sagemaker_session=resolved_session,
        env=env,
        tags=normalized_tags,
        network_config=network_config,
    )
    # No baselining job has been configured yet for this monitor.
    self.latest_baselining_job_config = None
|
|
124
|
+
|
|
125
|
+
def run_baseline(self, **_):
    """Not implemented.

    '.run_baseline()' is only allowed for ModelMonitor objects.
    Please use `suggest_baseline` instead.

    Args:
        **_: Ignored; accepted only so the signature tolerates any keyword
            arguments a caller passes.

    Raises:
        NotImplementedError: Always.
    """
    # The trailing space keeps the two sentences apart once the adjacent
    # literals are concatenated.
    raise NotImplementedError(
        "'.run_baseline()' is only allowed for ModelMonitor objects. "
        "Please use suggest_baseline instead."
    )
|
|
138
|
+
|
|
139
|
+
def latest_monitoring_statistics(self, **_):
    """Not implemented.

    Statistics are not supported by this class.

    Args:
        **_: Ignored; accepted only so the signature tolerates any keyword
            arguments a caller passes.

    Raises:
        NotImplementedError: Always; the message names the concrete class.
    """
    class_name = self.__class__.__name__
    raise NotImplementedError(f"{class_name} doesn't support statistics.")
|
|
148
|
+
|
|
149
|
+
def list_executions(self):
    """List monitoring executions as ClarifyMonitoringExecution objects.

    Each execution returned by the parent class's ``list_executions`` is
    re-wrapped, keeping the parent's ordering (documented as descending
    order of "ScheduledTime").

    Returns:
        [sagemaker.model_monitor.ClarifyMonitoringExecution]: One
            ClarifyMonitoringExecution per execution reported by the parent.
    """
    clarify_executions = []
    for base_execution in super().list_executions():
        clarify_executions.append(
            ClarifyMonitoringExecution(
                sagemaker_session=base_execution.sagemaker_session,
                job_name=base_execution.job_name,
                inputs=base_execution.inputs,
                output=base_execution.output,
                output_kms_key=base_execution.output_kms_key,
            )
        )
    return clarify_executions
|
|
167
|
+
|
|
168
|
+
def get_latest_execution_logs(self, wait=False):
    """Show the processing job logs of the most recent monitoring execution.

    The first entry of "MonitoringExecutionSummaries" for this monitor's
    schedule is taken as the most recent execution.

    Args:
        wait (bool): Whether the call should wait until the job completes
            (default: False).

    Raises:
        ValueError: If the schedule has no executions, or if the first
            execution summary carries no "ProcessingJobArn".

    Returns: None
    """
    response = boto_list_monitoring_executions(
        sagemaker_session=self.sagemaker_session,
        monitoring_schedule_name=self.monitoring_schedule_name,
    )
    summaries = response["MonitoringExecutionSummaries"]
    if len(summaries) == 0:
        raise ValueError("No execution jobs were kicked off.")

    latest_summary = summaries[0]
    if "ProcessingJobArn" not in latest_summary:
        raise ValueError("Processing Job did not run for the last execution")

    processing_job_arn = latest_summary["ProcessingJobArn"]
    logs_for_processing_job(
        self.sagemaker_session,
        job_name=get_resource_name_from_arn(processing_job_arn),
        wait=wait,
    )
|
|
193
|
+
|
|
194
|
+
def _create_baselining_processor(self):
    """Build the SageMakerClarifyProcessor used to run the baselining job.

    The processor is constructed from this monitor's own settings; its
    ``image_uri`` and ``base_job_name`` are then overwritten with the
    monitor's values.

    Returns:
        sagemaker.clarify.SageMakerClarifyProcessor object.
    """
    processor_settings = {
        "role": self.role,
        "instance_count": self.instance_count,
        "instance_type": self.instance_type,
        "volume_size_in_gb": self.volume_size_in_gb,
        "volume_kms_key": self.volume_kms_key,
        "output_kms_key": self.output_kms_key,
        "max_runtime_in_seconds": self.max_runtime_in_seconds,
        "sagemaker_session": self.sagemaker_session,
        "env": self.env,
        "tags": self.tags,
        "network_config": self.network_config,
    }
    processor = SageMakerClarifyProcessor(**processor_settings)
    # These two are set after construction rather than passed as arguments.
    processor.image_uri = self.image_uri
    processor.base_job_name = self.base_job_name
    return processor
|
|
217
|
+
|
|
218
|
+
def _upload_analysis_config(self, analysis_config, output_s3_uri, job_definition_name, kms_key):
    """Upload the analysis config as JSON to S3.

    The destination is
    ``<output_s3_uri>/<job_definition_name>/<random uuid4>/analysis_config.json``;
    the uuid segment gives every upload its own path.

    Args:
        analysis_config (dict): analysis config of a Clarify model monitor.
            It is serialized with ``json.dumps``.
        output_s3_uri (str): S3 URI used as the root of the upload path.
        job_definition_name (str): Job definition name, used as the second
            path segment.
        kms_key (str): KMS key forwarded to the uploader for encrypting the
            uploaded file (may be None).

    Returns:
        The value returned by ``s3.S3Uploader.upload_string_as_file_body``
        (presumably the S3 URI of the uploaded file — confirm against the
        uploader).
    """
    s3_uri = s3.s3_path_join(
        output_s3_uri,
        job_definition_name,
        str(uuid.uuid4()),
        "analysis_config.json",
    )
    # Pass the URI as a logging argument so it is actually interpolated.
    logger.info("Uploading analysis config to %s.", s3_uri)
    return s3.S3Uploader.upload_string_as_file_body(
        json.dumps(analysis_config),
        desired_s3_uri=s3_uri,
        sagemaker_session=self.sagemaker_session,
        kms_key=kms_key,
    )
|
|
246
|
+
|
|
247
|
+
def _build_create_job_definition_request(
    self,
    monitoring_schedule_name,
    job_definition_name,
    image_uri,
    latest_baselining_job_name=None,
    latest_baselining_job_config=None,
    existing_job_desc=None,
    endpoint_input=None,
    ground_truth_input=None,
    analysis_config=None,
    output_s3_uri=None,
    constraints=None,
    enable_cloudwatch_metrics=None,
    role=None,
    instance_count=None,
    instance_type=None,
    volume_size_in_gb=None,
    volume_kms_key=None,
    output_kms_key=None,
    max_runtime_in_seconds=None,
    env=None,
    tags=None,
    network_config=None,
    batch_transform_input=None,
):
    """Build the request for job definition creation API

    When ``existing_job_desc`` is given, its sections are used as the starting
    point and only the values explicitly passed in override them. Note that the
    nested dicts taken from ``existing_job_desc`` are modified in place, and that
    ``batch_transform_input`` may have its attribute fields backfilled in place.

    Args:
        monitoring_schedule_name (str): Monitoring schedule name.
        job_definition_name (str): Job definition name.
            If not specified then a default one will be generated.
        image_uri (str): The uri of the image to use for the jobs started by the Monitor.
        latest_baselining_job_name (str): name of the last baselining job.
        latest_baselining_job_config (ClarifyBaseliningConfig): analysis config from
            last baselining job.
        existing_job_desc (dict): description of existing job definition. It will be updated by
            values that were passed in, and then used to create the new job definition.
        endpoint_input (str or sagemaker.model_monitor.EndpointInput): The endpoint to monitor.
            This can either be the endpoint name or an EndpointInput.
        ground_truth_input (str): S3 URI to ground truth dataset.
        analysis_config (str or BiasAnalysisConfig or ExplainabilityAnalysisConfig): URI to the
            analysis_config.json for the bias job. If it is None then the configuration of the
            latest baselining job is reused; failing that, the ``ConfigUri`` of the existing
            job definition is reused. If none of these is available the call fails.
        output_s3_uri (str): S3 destination of the constraint_violations and analysis result.
            Default: "s3://<default_session_bucket>/<job_name>/output"
        constraints (sagemaker.model_monitor.Constraints or str): If provided it will be used
            for monitoring the endpoint. It can be a Constraints object or an S3 uri pointing
            to a constraints JSON file.
        enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
            the baselining or monitoring jobs.
        role (str): An AWS IAM role. The Amazon SageMaker jobs use this role.
        instance_count (int): The number of instances to run
            the jobs with.
        instance_type (str): Type of EC2 instance to use for
            the job, for example, 'ml.m5.xlarge'.
        volume_size_in_gb (int): Size in GB of the EBS volume
            to use for storing data during processing (default: 30).
        volume_kms_key (str): A KMS key for the job's volume.
        output_kms_key (str): KMS key id for output.
        max_runtime_in_seconds (int): Timeout in seconds. After this amount of
            time, Amazon SageMaker terminates the job regardless of its current status.
            Default: 3600
        env (dict): Environment variables to be passed to the job.
        tags (Optional[Tags]): List of tags to be passed to the job.
        network_config (sagemaker.network.NetworkConfig): A NetworkConfig
            object that configures network isolation, encryption of
            inter-container traffic, security group IDs, and subnets.
        batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run
            the monitoring schedule on the batch transform

    Returns:
        dict: request parameters to create job definition.

    Raises:
        ValueError: If ``analysis_config`` is None and neither
            ``latest_baselining_job_config`` nor an existing app specification
            can supply one.
    """
    # Seed every request section either from the existing job definition or empty.
    # Request keys are prefixed with the monitoring type (e.g. "<type>AppSpecification").
    if existing_job_desc is not None:
        app_specification = existing_job_desc[
            "{}AppSpecification".format(self.monitoring_type())
        ]
        baseline_config = existing_job_desc.get(
            "{}BaselineConfig".format(self.monitoring_type()), {}
        )
        job_input = existing_job_desc["{}JobInput".format(self.monitoring_type())]
        job_output = existing_job_desc["{}JobOutputConfig".format(self.monitoring_type())]
        cluster_config = existing_job_desc["JobResources"]["ClusterConfig"]
        # An explicitly passed role wins over the one stored in the description.
        if role is None:
            role = existing_job_desc["RoleArn"]
        existing_network_config = existing_job_desc.get("NetworkConfig")
        stop_condition = existing_job_desc.get("StoppingCondition", {})
    else:
        app_specification = {}
        baseline_config = {}
        job_input = {}
        job_output = {}
        cluster_config = {}
        existing_network_config = None
        stop_condition = {}

    # job output
    if output_s3_uri is not None:
        normalized_monitoring_output = self._normalize_monitoring_output(
            monitoring_schedule_name, output_s3_uri
        )
        job_output["MonitoringOutputs"] = [normalized_monitoring_output._to_request_dict()]
    if output_kms_key is not None:
        job_output["KmsKeyId"] = output_kms_key

    # app specification
    # Resolve the analysis config: explicit argument, then latest baselining job,
    # then the ConfigUri already present in the existing app specification.
    if analysis_config is None:
        if latest_baselining_job_config is not None:
            analysis_config = latest_baselining_job_config.analysis_config
        elif app_specification:
            analysis_config = app_specification["ConfigUri"]
        else:
            raise ValueError("analysis_config is mandatory.")
    # backfill analysis_config
    # A string is taken to be a URI already; anything else is serialized via
    # _to_dict() and uploaded so that the request can reference it by URI.
    # NOTE(review): the upload path is built from output_s3_uri, which may still be
    # None here when the caller did not pass one - confirm callers always supply it.
    if isinstance(analysis_config, str):
        analysis_config_uri = analysis_config
    else:
        analysis_config_uri = self._upload_analysis_config(
            analysis_config._to_dict(), output_s3_uri, job_definition_name, output_kms_key
        )
    app_specification["ConfigUri"] = analysis_config_uri
    app_specification["ImageUri"] = image_uri
    normalized_env = self._generate_env_map(
        env=env, enable_cloudwatch_metrics=enable_cloudwatch_metrics
    )
    # Only overwrite the environment when the generated map is non-empty.
    if normalized_env:
        app_specification["Environment"] = normalized_env

    # baseline config
    # Explicit constraints take precedence over referencing the latest baselining job.
    if constraints:
        # noinspection PyTypeChecker
        _, constraints_object = self._get_baseline_files(
            statistics=None, constraints=constraints, sagemaker_session=self.sagemaker_session
        )
        constraints_s3_uri = None
        if constraints_object is not None:
            constraints_s3_uri = constraints_object.file_s3_uri
        baseline_config["ConstraintsResource"] = dict(S3Uri=constraints_s3_uri)
    elif latest_baselining_job_name:
        baseline_config["BaseliningJobName"] = latest_baselining_job_name

    # job input
    # endpoint_input is checked first; batch_transform_input is only used when
    # no endpoint input was given. Either one replaces job_input entirely.
    if endpoint_input is not None:
        normalized_endpoint_input = self._normalize_endpoint_input(
            endpoint_input=endpoint_input
        )
        # backfill attributes to endpoint input
        # Only attributes left as None are filled from the latest baselining job.
        if latest_baselining_job_config is not None:
            if normalized_endpoint_input.features_attribute is None:
                normalized_endpoint_input.features_attribute = (
                    latest_baselining_job_config.features_attribute
                )
            if normalized_endpoint_input.inference_attribute is None:
                normalized_endpoint_input.inference_attribute = (
                    latest_baselining_job_config.inference_attribute
                )
            if normalized_endpoint_input.probability_attribute is None:
                normalized_endpoint_input.probability_attribute = (
                    latest_baselining_job_config.probability_attribute
                )
            if normalized_endpoint_input.probability_threshold_attribute is None:
                normalized_endpoint_input.probability_threshold_attribute = (
                    latest_baselining_job_config.probability_threshold_attribute
                )
        job_input = normalized_endpoint_input._to_request_dict()
    elif batch_transform_input is not None:
        # backfill attributes to batch transform input
        # This writes to the caller's batch_transform_input object directly.
        if latest_baselining_job_config is not None:
            if batch_transform_input.features_attribute is None:
                batch_transform_input.features_attribute = (
                    latest_baselining_job_config.features_attribute
                )
            if batch_transform_input.inference_attribute is None:
                batch_transform_input.inference_attribute = (
                    latest_baselining_job_config.inference_attribute
                )
            if batch_transform_input.probability_attribute is None:
                batch_transform_input.probability_attribute = (
                    latest_baselining_job_config.probability_attribute
                )
            if batch_transform_input.probability_threshold_attribute is None:
                batch_transform_input.probability_threshold_attribute = (
                    latest_baselining_job_config.probability_threshold_attribute
                )
        job_input = batch_transform_input._to_request_dict()

    if ground_truth_input is not None:
        job_input["GroundTruthS3Input"] = dict(S3Uri=ground_truth_input)

    # cluster config
    # Each value overrides the existing cluster config only when explicitly given.
    if instance_count is not None:
        cluster_config["InstanceCount"] = instance_count
    if instance_type is not None:
        cluster_config["InstanceType"] = instance_type
    if volume_size_in_gb is not None:
        cluster_config["VolumeSizeInGB"] = volume_size_in_gb
    if volume_kms_key is not None:
        cluster_config["VolumeKmsKeyId"] = volume_kms_key

    # stop condition
    if max_runtime_in_seconds is not None:
        stop_condition["MaxRuntimeInSeconds"] = max_runtime_in_seconds

    # Mandatory sections of the request.
    request_dict = {
        "JobDefinitionName": job_definition_name,
        "{}AppSpecification".format(self.monitoring_type()): app_specification,
        "{}JobInput".format(self.monitoring_type()): job_input,
        "{}JobOutputConfig".format(self.monitoring_type()): job_output,
        "JobResources": dict(ClusterConfig=cluster_config),
        "RoleArn": expand_role(self.sagemaker_session, role),
    }

    # Optional sections are added only when they carry something.
    if baseline_config:
        request_dict["{}BaselineConfig".format(self.monitoring_type())] = baseline_config

    # A network config passed in takes precedence over the existing one.
    if network_config is not None:
        network_config_dict = network_config._to_request_dict()
        request_dict["NetworkConfig"] = network_config_dict
    elif existing_network_config is not None:
        request_dict["NetworkConfig"] = existing_network_config

    if stop_condition:
        request_dict["StoppingCondition"] = stop_condition

    if tags is not None:
        request_dict["Tags"] = format_tags(tags)

    return request_dict
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
class ModelBiasMonitor(ClarifyModelMonitor):
|
|
479
|
+
"""Amazon SageMaker model monitor to monitor bias metrics of an endpoint.
|
|
480
|
+
|
|
481
|
+
Please see the __init__ method of its base class for how to instantiate it.
|
|
482
|
+
"""
|
|
483
|
+
|
|
484
|
+
JOB_DEFINITION_BASE_NAME = "model-bias-job-definition"
|
|
485
|
+
|
|
486
|
+
@classmethod
def monitoring_type(cls):
    """Type of the monitoring job.

    Returns:
        str: The constant ``"ModelBias"``.
    """
    return "ModelBias"
|
|
490
|
+
|
|
491
|
+
def suggest_baseline(
    self,
    data_config,
    bias_config,
    model_config,
    model_predicted_label_config=None,
    wait=False,
    logs=False,
    job_name=None,
    kms_key=None,
):
    """Run a bias baselining job and record it as this monitor's latest baseline.

    The baselining processor runs the bias analysis first; afterwards the
    configuration used for the run is stored on the monitor so that later
    schedule creation can reuse it.

    Args:
        data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
        bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
        model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
            endpoint to be created.
        model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
            Config of how to extract the predicted label from the model output.
        wait (bool): Whether the call should wait until the job completes (default: False).
        logs (bool): Whether to show the logs produced by the job.
            Only meaningful when wait is True (default: False).
        job_name (str): Processing job name. If not specified, the processor generates
            a default job name, based on the image name and current timestamp.
        kms_key (str): The ARN of the KMS key that is used to encrypt the
            user code file (default: None).

    Returns:
        sagemaker.processing.ProcessingJob: The ProcessingJob object representing the
            baselining job.
    """
    processor = self._create_baselining_processor()
    resolved_job_name = self._generate_baselining_job_name(job_name=job_name)
    processor.run_bias(
        data_config=data_config,
        bias_config=bias_config,
        model_config=model_config,
        model_predicted_label_config=model_predicted_label_config,
        wait=wait,
        logs=logs,
        job_name=resolved_job_name,
        kms_key=kms_key,
    )

    # Capture what the run was configured with so monitoring schedules can reuse it.
    bias_analysis = BiasAnalysisConfig(
        bias_config=bias_config, headers=data_config.headers, label=data_config.label
    )
    baselining_config = ClarifyBaseliningConfig(
        analysis_config=bias_analysis,
        features_attribute=data_config.features,
    )
    if model_predicted_label_config is not None:
        # label / probability are stored as strings; None is kept as None.
        predicted_label = model_predicted_label_config.label
        baselining_config.inference_attribute = (
            None if predicted_label is None else str(predicted_label)
        )
        probability = model_predicted_label_config.probability
        baselining_config.probability_attribute = (
            None if probability is None else str(probability)
        )
        baselining_config.probability_threshold_attribute = (
            model_predicted_label_config.probability_threshold
        )

    self.latest_baselining_job_config = baselining_config
    self.latest_baselining_job_name = resolved_job_name
    self.latest_baselining_job = ClarifyBaseliningJob(processing_job=processor.latest_job)

    self.baselining_jobs.append(self.latest_baselining_job)
    return processor.latest_job
|
|
564
|
+
|
|
565
|
+
# noinspection PyMethodOverriding
def create_monitoring_schedule(
    self,
    endpoint_input=None,
    ground_truth_input=None,
    analysis_config=None,
    output_s3_uri=None,
    constraints=None,
    monitor_schedule_name=None,
    schedule_cron_expression=None,
    enable_cloudwatch_metrics=True,
    batch_transform_input=None,
    data_analysis_start_time=None,
    data_analysis_end_time=None,
):
    """Create a model bias job definition and a monitoring schedule that uses it.

    If creating the schedule fails, the freshly created job definition is deleted
    on a best-effort basis and the original error is re-raised.

    Args:
        endpoint_input (str or sagemaker.model_monitor.EndpointInput): The endpoint to monitor.
            This can either be the endpoint name or an EndpointInput. (default: None)
        ground_truth_input (str): S3 URI to ground truth dataset. Required despite
            the None default. (default: None)
        analysis_config (str or BiasAnalysisConfig): URI to analysis_config for the bias job.
            If it is None then configuration of the latest baselining job will be reused, but
            if no baselining job then fail the call. (default: None)
        output_s3_uri (str): S3 destination of the constraint_violations and analysis result.
            Default: "s3://<default_session_bucket>/<job_name>/output" (default: None)
        constraints (sagemaker.model_monitor.Constraints or str): If provided it will be used
            for monitoring the endpoint. It can be a Constraints object or an S3 uri pointing
            to a constraints JSON file. (default: None)
        monitor_schedule_name (str): Schedule name. If not specified, the processor generates
            a default job name, based on the image name and current timestamp.
            (default: None)
        schedule_cron_expression (str): The cron expression that dictates the frequency that
            this job run. See sagemaker.model_monitor.CronExpressionGenerator for valid
            expressions. Default: Daily. (default: None)
        enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
            the baselining or monitoring jobs. (default: True)
        batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run
            the monitoring schedule on the batch transform (default: None)
        data_analysis_start_time (str): Start time for the data analysis window
            for the one time monitoring schedule (NOW), e.g. "-PT1H" (default: None)
        data_analysis_end_time (str): End time for the data analysis window
            for the one time monitoring schedule (NOW), e.g. "-PT1H" (default: None)

    Raises:
        ValueError: If ``ground_truth_input`` is missing, if this monitor already has a
            job definition or schedule, or if not exactly one of ``endpoint_input`` and
            ``batch_transform_input`` is supplied.
    """
    # ground_truth_input only defaults to None to keep positional-argument
    # backward compatibility; it is mandatory in practice.
    if not ground_truth_input:
        raise ValueError("ground_truth_input can not be None.")

    already_scheduled = (
        self.job_definition_name is not None or self.monitoring_schedule_name is not None
    )
    if already_scheduled:
        error_text = (
            "It seems that this object was already used to create an Amazon Model "
            "Monitoring Schedule. To create another, first delete the existing one "
            "using my_monitor.delete_monitoring_schedule()."
        )
        logger.error(error_text)
        raise ValueError(error_text)

    # Exactly one of the two input kinds must be present: both or neither is an error.
    has_batch_input = batch_transform_input is not None
    has_endpoint_input = endpoint_input is not None
    if has_batch_input == has_endpoint_input:
        error_text = (
            "Need to have either batch_transform_input or endpoint_input to create an "
            "Amazon Model Monitoring Schedule. "
            "Please provide only one of the above required inputs"
        )
        logger.error(error_text)
        raise ValueError(error_text)

    self._check_monitoring_schedule_cron_validity(
        schedule_cron_expression=schedule_cron_expression,
        data_analysis_start_time=data_analysis_start_time,
        data_analysis_end_time=data_analysis_end_time,
    )

    # create job definition
    monitor_schedule_name = self._generate_monitoring_schedule_name(
        schedule_name=monitor_schedule_name
    )
    new_job_definition_name = name_from_base(self.JOB_DEFINITION_BASE_NAME)
    resolved_output_uri = self._normalize_monitoring_output(
        monitor_schedule_name, output_s3_uri
    ).s3_output.s3_uri
    request_dict = self._build_create_job_definition_request(
        monitoring_schedule_name=monitor_schedule_name,
        job_definition_name=new_job_definition_name,
        image_uri=self.image_uri,
        latest_baselining_job_name=self.latest_baselining_job_name,
        latest_baselining_job_config=self.latest_baselining_job_config,
        endpoint_input=endpoint_input,
        ground_truth_input=ground_truth_input,
        analysis_config=analysis_config,
        output_s3_uri=resolved_output_uri,
        constraints=constraints,
        enable_cloudwatch_metrics=enable_cloudwatch_metrics,
        role=self.role,
        instance_count=self.instance_count,
        instance_type=self.instance_type,
        volume_size_in_gb=self.volume_size_in_gb,
        volume_kms_key=self.volume_kms_key,
        output_kms_key=self.output_kms_key,
        max_runtime_in_seconds=self.max_runtime_in_seconds,
        env=self.env,
        tags=self.tags,
        network_config=self.network_config,
        batch_transform_input=batch_transform_input,
    )
    self.sagemaker_session.sagemaker_client.create_model_bias_job_definition(**request_dict)

    # create schedule
    try:
        self._create_monitoring_schedule_from_job_definition(
            monitor_schedule_name=monitor_schedule_name,
            job_definition_name=new_job_definition_name,
            schedule_cron_expression=schedule_cron_expression,
            data_analysis_start_time=data_analysis_start_time,
            data_analysis_end_time=data_analysis_end_time,
        )
        self.job_definition_name = new_job_definition_name
        self.monitoring_schedule_name = monitor_schedule_name
    except Exception:
        logger.exception("Failed to create monitoring schedule.")
        self.monitoring_schedule_name = None
        # Roll back the job definition; a cleanup failure is logged but must not
        # mask the original error, which is re-raised below.
        # noinspection PyBroadException
        try:
            self.sagemaker_session.sagemaker_client.delete_model_bias_job_definition(
                JobDefinitionName=new_job_definition_name
            )
        except Exception:  # pylint: disable=W0703
            logger.exception(
                "Failed to delete job definition {}.".format(new_job_definition_name)
            )
        raise
|
|
694
|
+
|
|
695
|
+
# noinspection PyMethodOverriding
def update_monitoring_schedule(
    self,
    endpoint_input=None,
    ground_truth_input=None,
    analysis_config=None,
    output_s3_uri=None,
    constraints=None,
    schedule_cron_expression=None,
    enable_cloudwatch_metrics=None,
    role=None,
    instance_count=None,
    instance_type=None,
    volume_size_in_gb=None,
    volume_kms_key=None,
    output_kms_key=None,
    max_runtime_in_seconds=None,
    env=None,
    network_config=None,
    batch_transform_input=None,
    data_analysis_start_time=None,
    data_analysis_end_time=None,
):
    """Update the existing monitoring schedule.

    With no arguments this is a no-op. When only ``schedule_cron_expression`` is
    given, the schedule is updated against the current job definition. In every
    other case a new job definition is created from the current one plus the
    supplied overrides, and the schedule is pointed at it; the old job definition
    is not deleted. If the schedule update fails, the new job definition is
    deleted on a best-effort basis and the error is re-raised.

    Args:
        endpoint_input (str or sagemaker.model_monitor.EndpointInput): The endpoint to monitor.
            This can either be the endpoint name or an EndpointInput.
        ground_truth_input (str): S3 URI to ground truth dataset.
        analysis_config (str or BiasAnalysisConfig): URI to analysis_config for the bias job.
            If it is None then configuration of the latest baselining job will be reused, but
            if no baselining job then fail the call.
        output_s3_uri (str): S3 destination of the constraint_violations and analysis result.
            Default: "s3://<default_session_bucket>/<job_name>/output"
        constraints (sagemaker.model_monitor.Constraints or str): If provided it will be used
            for monitoring the endpoint. It can be a Constraints object or an S3 uri pointing
            to a constraints JSON file.
        schedule_cron_expression (str): The cron expression that dictates the frequency that
            this job run. See sagemaker.model_monitor.CronExpressionGenerator for valid
            expressions. Default: Daily.
        enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
            the baselining or monitoring jobs.
        role (str): An AWS IAM role. The Amazon SageMaker jobs use this role.
        instance_count (int): The number of instances to run
            the jobs with.
        instance_type (str): Type of EC2 instance to use for
            the job, for example, 'ml.m5.xlarge'.
        volume_size_in_gb (int): Size in GB of the EBS volume
            to use for storing data during processing (default: 30).
        volume_kms_key (str): A KMS key for the job's volume.
        output_kms_key (str): The KMS key id for the job's outputs.
        max_runtime_in_seconds (int): Timeout in seconds. After this amount of
            time, Amazon SageMaker terminates the job regardless of its current status.
            Default: 3600
        env (dict): Environment variables to be passed to the job.
        network_config (sagemaker.network.NetworkConfig): A NetworkConfig
            object that configures network isolation, encryption of
            inter-container traffic, security group IDs, and subnets.
        batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run
            the monitoring schedule on the batch transform
        data_analysis_start_time (str): Start time for the data analysis window,
            forwarded to the schedule update.
        data_analysis_end_time (str): End time for the data analysis window,
            forwarded to the schedule update.

    Raises:
        ValueError: If both ``batch_transform_input`` and ``endpoint_input`` are given.
    """
    # Must stay the first statement: locals() has to see only the parameters.
    supplied = {
        arg: value for arg, value in locals().items() if arg != "self" and value is not None
    }

    # Nothing to update
    if not supplied:
        return

    if batch_transform_input is not None and endpoint_input is not None:
        error_text = (
            "Need to have either batch_transform_input or endpoint_input to create an "
            "Amazon Model Monitoring Schedule. "
            "Please provide only one of the above required inputs"
        )
        logger.error(error_text)
        raise ValueError(error_text)

    # Only need to update schedule expression
    cron_only = len(supplied) == 1 and schedule_cron_expression is not None
    if cron_only:
        self._update_monitoring_schedule(self.job_definition_name, schedule_cron_expression)
        return

    # Need to update schedule with a new job definition
    sagemaker_client = self.sagemaker_session.sagemaker_client
    current_job_desc = sagemaker_client.describe_model_bias_job_definition(
        JobDefinitionName=self.job_definition_name
    )
    new_job_definition_name = name_from_base(self.JOB_DEFINITION_BASE_NAME)
    request_dict = self._build_create_job_definition_request(
        monitoring_schedule_name=self.monitoring_schedule_name,
        job_definition_name=new_job_definition_name,
        image_uri=self.image_uri,
        existing_job_desc=current_job_desc,
        endpoint_input=endpoint_input,
        ground_truth_input=ground_truth_input,
        analysis_config=analysis_config,
        output_s3_uri=output_s3_uri,
        constraints=constraints,
        enable_cloudwatch_metrics=enable_cloudwatch_metrics,
        role=role,
        instance_count=instance_count,
        instance_type=instance_type,
        volume_size_in_gb=volume_size_in_gb,
        volume_kms_key=volume_kms_key,
        output_kms_key=output_kms_key,
        max_runtime_in_seconds=max_runtime_in_seconds,
        env=env,
        tags=self.tags,
        network_config=network_config,
        batch_transform_input=batch_transform_input,
    )
    sagemaker_client.create_model_bias_job_definition(**request_dict)
    try:
        self._update_monitoring_schedule(
            new_job_definition_name,
            schedule_cron_expression,
            data_analysis_start_time=data_analysis_start_time,
            data_analysis_end_time=data_analysis_end_time,
        )
        self.job_definition_name = new_job_definition_name
        # Mirror the accepted overrides onto the monitor, in declaration order.
        accepted_overrides = (
            ("role", role),
            ("instance_count", instance_count),
            ("instance_type", instance_type),
            ("volume_size_in_gb", volume_size_in_gb),
            ("volume_kms_key", volume_kms_key),
            ("output_kms_key", output_kms_key),
            ("max_runtime_in_seconds", max_runtime_in_seconds),
            ("env", env),
            ("network_config", network_config),
        )
        for attribute_name, new_value in accepted_overrides:
            if new_value is not None:
                setattr(self, attribute_name, new_value)
    except Exception:
        logger.exception("Failed to update monitoring schedule.")
        # Roll back the new job definition; a cleanup failure is logged but must
        # not mask the original error, which is re-raised below.
        # noinspection PyBroadException
        try:
            sagemaker_client.delete_model_bias_job_definition(
                JobDefinitionName=new_job_definition_name
            )
        except Exception:  # pylint: disable=W0703
            logger.exception(
                "Failed to delete job definition {}.".format(new_job_definition_name)
            )
        raise
|
|
847
|
+
|
|
848
|
+
def delete_monitoring_schedule(self):
    """Delete the monitoring schedule, then the model bias job definition behind it.

    The base class removes the schedule first; afterwards the job definition is
    deleted and ``job_definition_name`` is reset to None.
    """
    super(ModelBiasMonitor, self).delete_monitoring_schedule()
    # Delete job definition.
    definition_name = self.job_definition_name
    logger.info("Deleting Model Bias Job Definition with name: {}".format(definition_name))
    self.sagemaker_session.sagemaker_client.delete_model_bias_job_definition(
        JobDefinitionName=definition_name
    )
    self.job_definition_name = None
|
|
860
|
+
|
|
861
|
+
@classmethod
def attach(cls, monitor_schedule_name, sagemaker_session=None):
    """Build a monitor attached to an existing ModelBias monitoring schedule.

    The schedule is described, its monitoring type is checked against this
    class, and the job definition and tags it references are fetched before the
    monitor is assembled by ``ClarifyModelMonitor._attach``. This allows
    subsequent describe_schedule or list_executions calls to point to the given
    schedule.

    Args:
        monitor_schedule_name (str): The name of the schedule to attach to.
        sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
            manages interactions with Amazon SageMaker APIs and any other
            AWS services needed. If not specified, one is created using
            the default AWS configuration chain.

    Returns:
        The object returned by ``ClarifyModelMonitor._attach`` for ``cls``
        (presumably an instance of ``cls`` - confirm against ``_attach``).

    Raises:
        TypeError: If the schedule's monitoring type is not the one this class handles.
    """
    sagemaker_session = sagemaker_session or Session()
    schedule_desc = boto_describe_monitoring_schedule(
        sagemaker_session, monitoring_schedule_name=monitor_schedule_name
    )
    schedule_config = schedule_desc["MonitoringScheduleConfig"]
    monitoring_type = schedule_config.get("MonitoringType")
    if monitoring_type != cls.monitoring_type():
        # Report the class attach() was actually called on; the implicit __class__
        # cell always names the defining class, which is wrong for subclasses.
        raise TypeError("{} can only attach to ModelBias schedule.".format(cls.__name__))
    job_definition_name = schedule_config["MonitoringJobDefinitionName"]
    job_desc = sagemaker_session.sagemaker_client.describe_model_bias_job_definition(
        JobDefinitionName=job_definition_name
    )
    tags = list_tags(sagemaker_session, resource_arn=schedule_desc["MonitoringScheduleArn"])
    return ClarifyModelMonitor._attach(
        clazz=cls,
        sagemaker_session=sagemaker_session,
        schedule_desc=schedule_desc,
        job_desc=job_desc,
        tags=tags,
    )
|
|
896
|
+
|
|
897
|
+
|
|
898
|
+
class BiasAnalysisConfig:
    """Analysis configuration for ModelBiasMonitor."""

    def __init__(self, bias_config, headers=None, label=None):
        """Builds the analysis config dictionary from a bias configuration.

        The dictionary returned by ``bias_config.get_config()`` is stored as
        ``self.analysis_config`` and extended in place with the optional entries.

        Args:
            bias_config (sagemaker.clarify.BiasConfig): Config object related to bias
                configurations. Only its ``get_config()`` method is used here.
            headers (list[str]): A list of column names in the input dataset. Stored under
                the ``"headers"`` key unless it is None.
            label (str): Target attribute for the model required by bias metrics. Specified as
                column name or index for CSV dataset, or as JMESPath expression for JSONLines.
                Stored under the ``"label"`` key unless it is None.
        """
        config = bias_config.get_config()
        self.analysis_config = config
        # Only None means "not provided"; falsy values such as [] or 0 are still stored.
        for key, value in (("headers", headers), ("label", label)):
            if value is not None:
                config[key] = value

    def _to_dict(self):
        """Returns the analysis config dictionary held by this object (not a copy)."""
        return self.analysis_config
|
|
920
|
+
|
|
921
|
+
|
|
922
|
+
class ModelExplainabilityMonitor(ClarifyModelMonitor):
    """Amazon SageMaker model monitor to monitor feature attribution of an endpoint.

    Please see the __init__ method of its base class for how to instantiate it.
    """

    JOB_DEFINITION_BASE_NAME = "model-explainability-job-definition"

    @classmethod
    def monitoring_type(cls):
        """Type of the monitoring job."""
        return "ModelExplainability"

    def suggest_baseline(
        self,
        data_config,
        explainability_config,
        model_config,
        model_scores=None,
        wait=False,
        logs=False,
        job_name=None,
        kms_key=None,
    ):
        """Suggest baselines for use with Amazon SageMaker Model Monitoring Schedules.

        Runs an explainability processing job and records its configuration on this
        object (``latest_baselining_job_config``, ``latest_baselining_job_name``,
        ``latest_baselining_job``, ``baselining_jobs``) so that a later
        ``create_monitoring_schedule`` call can reuse it.

        Args:
            data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
            explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig`): Config of the
                specific explainability method. Currently, only SHAP is supported.
            model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
                endpoint to be created.
            model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
                Index or JMESPath expression to locate the predicted scores in the model output.
                This is not required if the model output is a single score. Alternatively,
                it can be an instance of ModelPredictedLabelConfig to provide more parameters
                like label_headers.
            wait (bool): Whether the call should wait until the job completes (default: False).
            logs (bool): Whether to show the logs produced by the job.
                Only meaningful when wait is True (default: False).
            job_name (str): Processing job name. If not specified, the processor generates
                a default job name, based on the image name and current timestamp.
            kms_key (str): The ARN of the KMS key that is used to encrypt the
                user code file (default: None).

        Returns:
            sagemaker.processing.ProcessingJob: The ProcessingJob object representing the
                baselining job.
        """
        baselining_processor = self._create_baselining_processor()
        baselining_job_name = self._generate_baselining_job_name(job_name=job_name)
        baselining_processor.run_explainability(
            data_config=data_config,
            model_config=model_config,
            explainability_config=explainability_config,
            model_scores=model_scores,
            wait=wait,
            logs=logs,
            job_name=baselining_job_name,
            kms_key=kms_key,
        )

        # Explainability analysis doesn't need label.
        # Work on a deep copy so the caller's data_config.headers is left untouched.
        headers = copy.deepcopy(data_config.headers)
        if headers and data_config.label in headers:
            headers.remove(data_config.label)

        # Normalize the three accepted forms of model_scores into a string attribute
        # plus optional label headers.
        if model_scores is None:
            inference_attribute = None
            label_headers = None
        elif isinstance(model_scores, ModelPredictedLabelConfig):
            inference_attribute = str(model_scores.label)
            label_headers = model_scores.label_headers
        else:
            inference_attribute = str(model_scores)
            label_headers = None

        self.latest_baselining_job_config = ClarifyBaseliningConfig(
            analysis_config=ExplainabilityAnalysisConfig(
                explainability_config=explainability_config,
                model_config=model_config,
                headers=headers,
                label_headers=label_headers,
            ),
            features_attribute=data_config.features,
            inference_attribute=inference_attribute,
        )
        self.latest_baselining_job_name = baselining_job_name
        self.latest_baselining_job = ClarifyBaseliningJob(
            processing_job=baselining_processor.latest_job
        )

        self.baselining_jobs.append(self.latest_baselining_job)
        return baselining_processor.latest_job

    # noinspection PyMethodOverriding
    def create_monitoring_schedule(
        self,
        endpoint_input=None,
        analysis_config=None,
        output_s3_uri=None,
        constraints=None,
        monitor_schedule_name=None,
        schedule_cron_expression=None,
        enable_cloudwatch_metrics=True,
        batch_transform_input=None,
        data_analysis_start_time=None,
        data_analysis_end_time=None,
    ):
        """Creates a monitoring schedule.

        A model explainability job definition is created first and the schedule is then
        created from it. If creating the schedule fails, deletion of the new job definition
        is attempted (a failure of that cleanup is only logged) and the original error is
        re-raised.

        Args:
            endpoint_input (str or sagemaker.model_monitor.EndpointInput): The endpoint to monitor.
                This can either be the endpoint name or an EndpointInput. (default: None)
            analysis_config (str or ExplainabilityAnalysisConfig): URI to the analysis_config for
                the explainability job. If it is None then configuration of the latest baselining
                job will be reused, but if no baselining job then fail the call.
            output_s3_uri (str): S3 destination of the constraint_violations and analysis result.
                Default: "s3://<default_session_bucket>/<job_name>/output"
            constraints (sagemaker.model_monitor.Constraints or str): If provided it will be used
                for monitoring the endpoint. It can be a Constraints object or an S3 uri pointing
                to a constraints JSON file.
            monitor_schedule_name (str): Schedule name. If not specified, the processor generates
                a default job name, based on the image name and current timestamp.
            schedule_cron_expression (str): The cron expression that dictates the frequency that
                this job run. See sagemaker.model_monitor.CronExpressionGenerator for valid
                expressions. Default: Daily.
            enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
                the baselining or monitoring jobs.
            batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
                run the monitoring schedule on the batch transform
            data_analysis_start_time (str): Start time for the data analysis window
                for the one time monitoring schedule (NOW), e.g. "-PT1H" (default: None)
            data_analysis_end_time (str): End time for the data analysis window
                for the one time monitoring schedule (NOW), e.g. "-PT1H" (default: None)

        Raises:
            ValueError: If this object already holds a job definition or a monitoring
                schedule, or unless exactly one of ``endpoint_input`` and
                ``batch_transform_input`` is provided.
        """
        if self.job_definition_name is not None or self.monitoring_schedule_name is not None:
            message = (
                "It seems that this object was already used to create an Amazon Model "
                "Monitoring Schedule. To create another, first delete the existing one "
                "using my_monitor.delete_monitoring_schedule()."
            )
            logger.error(message)
            raise ValueError(message)

        # The XOR is True when both inputs are given or when neither is given,
        # i.e. whenever it is NOT the case that exactly one input was provided.
        if (batch_transform_input is not None) ^ (endpoint_input is None):
            message = (
                "Need to have either batch_transform_input or endpoint_input to create an "
                "Amazon Model Monitoring Schedule. "
                "Please provide only one of the above required inputs"
            )
            logger.error(message)
            raise ValueError(message)

        self._check_monitoring_schedule_cron_validity(
            schedule_cron_expression=schedule_cron_expression,
            data_analysis_start_time=data_analysis_start_time,
            data_analysis_end_time=data_analysis_end_time,
        )

        # create job definition
        monitor_schedule_name = self._generate_monitoring_schedule_name(
            schedule_name=monitor_schedule_name
        )
        new_job_definition_name = name_from_base(self.JOB_DEFINITION_BASE_NAME)
        request_dict = self._build_create_job_definition_request(
            monitoring_schedule_name=monitor_schedule_name,
            job_definition_name=new_job_definition_name,
            image_uri=self.image_uri,
            latest_baselining_job_name=self.latest_baselining_job_name,
            latest_baselining_job_config=self.latest_baselining_job_config,
            endpoint_input=endpoint_input,
            analysis_config=analysis_config,
            output_s3_uri=self._normalize_monitoring_output(
                monitor_schedule_name, output_s3_uri
            ).s3_output.s3_uri,
            constraints=constraints,
            enable_cloudwatch_metrics=enable_cloudwatch_metrics,
            role=self.role,
            instance_count=self.instance_count,
            instance_type=self.instance_type,
            volume_size_in_gb=self.volume_size_in_gb,
            volume_kms_key=self.volume_kms_key,
            output_kms_key=self.output_kms_key,
            max_runtime_in_seconds=self.max_runtime_in_seconds,
            env=self.env,
            tags=self.tags,
            network_config=self.network_config,
            batch_transform_input=batch_transform_input,
        )
        self.sagemaker_session.sagemaker_client.create_model_explainability_job_definition(
            **request_dict
        )

        # create schedule
        try:
            self._create_monitoring_schedule_from_job_definition(
                monitor_schedule_name=monitor_schedule_name,
                job_definition_name=new_job_definition_name,
                schedule_cron_expression=schedule_cron_expression,
                data_analysis_start_time=data_analysis_start_time,
                data_analysis_end_time=data_analysis_end_time,
            )
            # Remember the new resources only once the schedule really exists.
            self.job_definition_name = new_job_definition_name
            self.monitoring_schedule_name = monitor_schedule_name
        except Exception:
            logger.exception("Failed to create monitoring schedule.")
            self.monitoring_schedule_name = None
            # Best-effort cleanup of the job definition created above; a cleanup
            # failure is logged but must not mask the original error.
            # noinspection PyBroadException
            try:
                self.sagemaker_session.sagemaker_client.delete_model_explainability_job_definition(
                    JobDefinitionName=new_job_definition_name
                )
            except Exception:  # pylint: disable=W0703
                message = "Failed to delete job definition {}.".format(new_job_definition_name)
                logger.exception(message)
            raise

    # noinspection PyMethodOverriding
    def update_monitoring_schedule(
        self,
        endpoint_input=None,
        analysis_config=None,
        output_s3_uri=None,
        constraints=None,
        schedule_cron_expression=None,
        enable_cloudwatch_metrics=None,
        role=None,
        instance_count=None,
        instance_type=None,
        volume_size_in_gb=None,
        volume_kms_key=None,
        output_kms_key=None,
        max_runtime_in_seconds=None,
        env=None,
        network_config=None,
        batch_transform_input=None,
        data_analysis_start_time=None,
        data_analysis_end_time=None,
    ):
        """Updates the existing monitoring schedule.

        If more options than schedule_cron_expression are to be updated, a new job definition will
        be created to hold them. The old job definition will not be deleted.

        Args:
            endpoint_input (str or sagemaker.model_monitor.EndpointInput): The endpoint to monitor.
                This can either be the endpoint name or an EndpointInput.
            analysis_config (str or ExplainabilityAnalysisConfig): URI to analysis_config for the
                explainability job. If it is None then configuration of the latest baselining job
                will be reused, but if no baselining job then fail the call.
            output_s3_uri (str): S3 destination of the constraint_violations and analysis result.
                Default: "s3://<default_session_bucket>/<job_name>/output"
            constraints (sagemaker.model_monitor.Constraints or str): If provided it will be used
                for monitoring the endpoint. It can be a Constraints object or an S3 uri pointing
                to a constraints JSON file.
            schedule_cron_expression (str): The cron expression that dictates the frequency that
                this job run. See sagemaker.model_monitor.CronExpressionGenerator for valid
                expressions. Default: Daily.
            enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
                the baselining or monitoring jobs.
            role (str): An AWS IAM role. The Amazon SageMaker jobs use this role.
            instance_count (int): The number of instances to run
                the jobs with.
            instance_type (str): Type of EC2 instance to use for
                the job, for example, 'ml.m5.xlarge'.
            volume_size_in_gb (int): Size in GB of the EBS volume
                to use for storing data during processing (default: 30).
            volume_kms_key (str): A KMS key for the job's volume.
            output_kms_key (str): The KMS key id for the job's outputs.
            max_runtime_in_seconds (int): Timeout in seconds. After this amount of
                time, Amazon SageMaker terminates the job regardless of its current status.
                Default: 3600
            env (dict): Environment variables to be passed to the job.
            network_config (sagemaker.network.NetworkConfig): A NetworkConfig
                object that configures network isolation, encryption of
                inter-container traffic, security group IDs, and subnets.
            batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
                run the monitoring schedule on the batch transform
            data_analysis_start_time (str): Start time for the data analysis window
                for the one time monitoring schedule (NOW), e.g. "-PT1H" (default: None)
            data_analysis_end_time (str): End time for the data analysis window
                for the one time monitoring schedule (NOW), e.g. "-PT1H" (default: None)

        Raises:
            ValueError: If every argument is None, or if both ``endpoint_input`` and
                ``batch_transform_input`` are provided.
        """
        # NOTE: this must stay the first statement of the method. locals() is read before
        # any other local name is bound, so it holds exactly the call arguments (plus self).
        valid_args = {
            arg: value for arg, value in locals().items() if arg != "self" and value is not None
        }

        # Nothing to update
        if not valid_args:
            raise ValueError("Nothing to update.")

        if batch_transform_input is not None and endpoint_input is not None:
            message = (
                "Need to have either batch_transform_input or endpoint_input to create an "
                "Amazon Model Monitoring Schedule. "
                "Please provide only one of the above required inputs"
            )
            logger.error(message)
            raise ValueError(message)

        # Only need to update schedule expression
        if len(valid_args) == 1 and schedule_cron_expression is not None:
            self._update_monitoring_schedule(self.job_definition_name, schedule_cron_expression)
            return

        # Need to update schedule with a new job definition
        job_desc = (
            self.sagemaker_session.sagemaker_client.describe_model_explainability_job_definition(
                JobDefinitionName=self.job_definition_name
            )
        )
        new_job_definition_name = name_from_base(self.JOB_DEFINITION_BASE_NAME)
        request_dict = self._build_create_job_definition_request(
            monitoring_schedule_name=self.monitoring_schedule_name,
            job_definition_name=new_job_definition_name,
            image_uri=self.image_uri,
            existing_job_desc=job_desc,
            endpoint_input=endpoint_input,
            analysis_config=analysis_config,
            output_s3_uri=output_s3_uri,
            constraints=constraints,
            enable_cloudwatch_metrics=enable_cloudwatch_metrics,
            role=role,
            instance_count=instance_count,
            instance_type=instance_type,
            volume_size_in_gb=volume_size_in_gb,
            volume_kms_key=volume_kms_key,
            output_kms_key=output_kms_key,
            max_runtime_in_seconds=max_runtime_in_seconds,
            env=env,
            tags=self.tags,
            network_config=network_config,
            batch_transform_input=batch_transform_input,
        )
        self.sagemaker_session.sagemaker_client.create_model_explainability_job_definition(
            **request_dict
        )
        try:
            self._update_monitoring_schedule(
                new_job_definition_name,
                schedule_cron_expression,
                data_analysis_start_time=data_analysis_start_time,
                data_analysis_end_time=data_analysis_end_time,
            )
            # The schedule now points at the new job definition; mirror every
            # explicitly provided setting on this object.
            self.job_definition_name = new_job_definition_name
            if role is not None:
                self.role = role
            if instance_count is not None:
                self.instance_count = instance_count
            if instance_type is not None:
                self.instance_type = instance_type
            if volume_size_in_gb is not None:
                self.volume_size_in_gb = volume_size_in_gb
            if volume_kms_key is not None:
                self.volume_kms_key = volume_kms_key
            if output_kms_key is not None:
                self.output_kms_key = output_kms_key
            if max_runtime_in_seconds is not None:
                self.max_runtime_in_seconds = max_runtime_in_seconds
            if env is not None:
                self.env = env
            if network_config is not None:
                self.network_config = network_config
        except Exception:
            logger.exception("Failed to update monitoring schedule.")
            # Best-effort cleanup of the job definition created above; a cleanup
            # failure is logged but must not mask the original error.
            # noinspection PyBroadException
            try:
                self.sagemaker_session.sagemaker_client.delete_model_explainability_job_definition(
                    JobDefinitionName=new_job_definition_name
                )
            except Exception:  # pylint: disable=W0703
                message = "Failed to delete job definition {}.".format(new_job_definition_name)
                logger.exception(message)
            raise

    def delete_monitoring_schedule(self):
        """Deletes the monitoring schedule and its job definition."""
        super().delete_monitoring_schedule()
        # Delete job definition.
        message = "Deleting Model Explainability Job Definition with name: {}".format(
            self.job_definition_name
        )
        logger.info(message)
        self.sagemaker_session.sagemaker_client.delete_model_explainability_job_definition(
            JobDefinitionName=self.job_definition_name
        )
        self.job_definition_name = None

    @classmethod
    def attach(cls, monitor_schedule_name, sagemaker_session=None):
        """Sets this object's schedule name to the name provided.

        This allows subsequent describe_schedule or list_executions calls to point
        to the given schedule.

        Args:
            monitor_schedule_name (str): The name of the schedule to attach to.
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created using
                the default AWS configuration chain.

        Returns:
            The value returned by ``ClarifyModelMonitor._attach`` for ``clazz=cls``.

        Raises:
            TypeError: If the monitoring type of the schedule is not the one reported by
                ``cls.monitoring_type()``.
        """
        sagemaker_session = sagemaker_session or Session()
        schedule_desc = boto_describe_monitoring_schedule(
            sagemaker_session=sagemaker_session, monitoring_schedule_name=monitor_schedule_name
        )
        monitoring_type = schedule_desc["MonitoringScheduleConfig"].get("MonitoringType")
        if monitoring_type != cls.monitoring_type():
            raise TypeError(
                "{} can only attach to ModelExplainability schedule.".format(__class__.__name__)
            )
        job_definition_name = schedule_desc["MonitoringScheduleConfig"][
            "MonitoringJobDefinitionName"
        ]
        job_desc = sagemaker_session.sagemaker_client.describe_model_explainability_job_definition(
            JobDefinitionName=job_definition_name
        )
        tags = list_tags(
            sagemaker_session=sagemaker_session, resource_arn=schedule_desc["MonitoringScheduleArn"]
        )
        return ClarifyModelMonitor._attach(
            clazz=cls,
            sagemaker_session=sagemaker_session,
            schedule_desc=schedule_desc,
            job_desc=job_desc,
            tags=tags,
        )
|
|
1348
|
+
|
|
1349
|
+
|
|
1350
|
+
class ExplainabilityAnalysisConfig:
    """Analysis configuration for ModelExplainabilityMonitor."""

    def __init__(self, explainability_config, model_config, headers=None, label_headers=None):
        """Assembles the analysis config dictionary.

        The dictionary always has a ``"methods"`` entry (from
        ``explainability_config.get_explainability_config()``) and a ``"predictor"`` entry
        (from ``model_config.get_predictor_config()``).

        Args:
            explainability_config (sagemaker.clarify.ExplainabilityConfig): Config object related
                to explainability configurations.
            model_config (sagemaker.clarify.ModelConfig): Config object whose predictor
                configuration is embedded under the ``"predictor"`` key.
            headers (list[str]): A list of feature names (without label) of model/endpoint input.
                Stored under the ``"headers"`` key unless it is None.
            label_headers (list[str]): List of headers, each for a predicted score in model output.
                It is used to beautify the analysis report by replacing placeholders like "label0".
                Stored inside the predictor configuration unless it is None.
        """
        predictor_settings = model_config.get_predictor_config()
        config = dict(
            methods=explainability_config.get_explainability_config(),
            predictor=predictor_settings,
        )
        self.analysis_config = config

        headers_given = headers is not None
        if headers_given:
            config["headers"] = headers

        # label_headers belongs to the predictor section, not to the top level. The
        # predictor dictionary is the very object already referenced by ``config``.
        label_headers_given = label_headers is not None
        if label_headers_given:
            predictor_settings["label_headers"] = label_headers

    def _to_dict(self):
        """Returns the analysis config dictionary held by this object (not a copy)."""
        return self.analysis_config
|
|
1379
|
+
|
|
1380
|
+
|
|
1381
|
+
class ClarifyBaseliningConfig:
    """Data class to hold some essential analysis configuration of ClarifyBaseliningJob"""

    def __init__(
        self,
        analysis_config,
        features_attribute=None,
        inference_attribute=None,
        probability_attribute=None,
        probability_threshold_attribute=None,
    ):
        """Stores every argument, unchanged, as an attribute of the same name.

        Args:
            analysis_config (BiasAnalysisConfig or ExplainabilityAnalysisConfig): analysis config
                from configurations of the baselining job.
            features_attribute (str): JMESPath expression to locate features in predictor request
                payload. Only required when predictor content type is JSONlines.
            inference_attribute (str): Index, header or JMESPath expression to locate predicted
                label in predictor response payload.
            probability_attribute (str): Index or JMESPath expression to locate probabilities or
                scores in the model output for computing feature attribution.
            probability_threshold_attribute (float): Value to indicate the threshold to select
                the binary label in the case of binary classification (default: None).
        """
        # The analysis configuration itself.
        self.analysis_config = analysis_config

        # Locators into the predictor request and response payloads.
        self.features_attribute = features_attribute
        self.inference_attribute = inference_attribute

        # Probability related settings.
        self.probability_attribute = probability_attribute
        self.probability_threshold_attribute = probability_threshold_attribute
|
|
1411
|
+
|
|
1412
|
+
|
|
1413
|
+
class ClarifyBaseliningJob(mm.BaseliningJob):
    """Provides functionality to retrieve baseline-specific output from Clarify baselining job."""

    def __init__(
        self,
        processing_job,
    ):
        """Initializes a ClarifyBaseliningJob that tracks a baselining job by suggest_baseline()

        The session, job name, inputs, outputs and output KMS key are read from
        ``processing_job`` and forwarded to the base class initializer.

        Args:
            processing_job (sagemaker.processing.ProcessingJob): The ProcessingJob used for
                baselining instance.
        """
        super().__init__(
            sagemaker_session=processing_job.sagemaker_session,
            job_name=processing_job.job_name,
            inputs=processing_job.inputs,
            outputs=processing_job.outputs,
            output_kms_key=processing_job.output_kms_key,
        )

    def baseline_statistics(self, **_):
        """Not implemented.

        The class doesn't support statistics. Keyword arguments are accepted and ignored.

        Raises:
            NotImplementedError: Always.
        """
        error_text = "{} doesn't support statistics.".format(__class__.__name__)
        raise NotImplementedError(error_text)

    def suggested_constraints(self, file_name=None, kms_key=None):
        """Returns the constraints generated by this baselining job.

        The base class implementation is always asked for the file "analysis.json".

        Args:
            file_name (str): Keep this parameter to align with method signature in super class,
                but it will be ignored.
            kms_key (str): The kms key to use when retrieving the file.

        Returns:
            sagemaker.model_monitor.Constraints: The Constraints object representing the file that
                was generated by the job.

        Raises:
            UnexpectedStatusException: This is thrown if the job is not in a 'Complete' state.
        """
        constraints_file = "analysis.json"
        return super().suggested_constraints(constraints_file, kms_key)
|
|
1462
|
+
|
|
1463
|
+
|
|
1464
|
+
class ClarifyMonitoringExecution(mm.MonitoringExecution):
|
|
1465
|
+
"""Provides functionality to retrieve monitoring-specific files output from executions."""
|
|
1466
|
+
|
|
1467
|
+
def __init__(self, sagemaker_session, job_name, inputs, output, output_kms_key=None):
    """Initializes an object that tracks a monitoring execution by a Clarify model monitor

    Every argument is forwarded unchanged, by keyword, to the base class initializer.

    Args:
        sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
            manages interactions with Amazon SageMaker APIs and any other
            AWS services needed.
        job_name (str): The name of the monitoring execution job.
        inputs: The inputs associated with the monitoring execution
            (type not visible in this block -- confirm against the base class).
        output (sagemaker.Processing.ProcessingOutput): The output associated with the
            monitoring execution.
        output_kms_key (str): The output kms key associated with the job. Defaults to None
            if not provided.
    """
    super().__init__(
        sagemaker_session=sagemaker_session,
        job_name=job_name,
        inputs=inputs,
        output=output,
        output_kms_key=output_kms_key,
    )
|
|
1488
|
+
|
|
1489
|
+
def statistics(self, **_):
    """Not implemented.

    The class doesn't support statistics. Keyword arguments are accepted and ignored.

    Raises:
        NotImplementedError: Always.
    """
    error_text = "{} doesn't support statistics.".format(__class__.__name__)
    raise NotImplementedError(error_text)
|