sagemaker-core 1.0.62__py3-none-any.whl → 2.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (358)
  1. sagemaker/__init__.py +2 -0
  2. sagemaker/core/__init__.py +16 -0
  3. sagemaker/core/_studio.py +116 -0
  4. sagemaker/core/_version.py +11 -0
  5. sagemaker/core/accept_types.py +131 -0
  6. sagemaker/core/analytics.py +744 -0
  7. sagemaker/core/apiutils/__init__.py +13 -0
  8. sagemaker/core/apiutils/_base_types.py +228 -0
  9. sagemaker/core/apiutils/_boto_functions.py +130 -0
  10. sagemaker/core/apiutils/_utils.py +34 -0
  11. sagemaker/core/base_deserializers.py +35 -0
  12. sagemaker/core/base_serializers.py +35 -0
  13. sagemaker/core/clarify/__init__.py +2898 -0
  14. sagemaker/core/collection.py +467 -0
  15. sagemaker/core/common_utils.py +2399 -0
  16. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  17. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  18. sagemaker/core/config/__init__.py +181 -0
  19. sagemaker/core/config/config.py +238 -0
  20. sagemaker/core/config/config_manager.py +595 -0
  21. sagemaker/core/config/config_schema.py +1220 -0
  22. sagemaker/core/config/config_utils.py +297 -0
  23. {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
  24. sagemaker/core/constants.py +73 -0
  25. sagemaker/core/content_types.py +137 -0
  26. sagemaker/core/debugger/__init__.py +39 -0
  27. sagemaker/core/debugger/debugger.py +945 -0
  28. sagemaker/core/debugger/framework_profile.py +292 -0
  29. sagemaker/core/debugger/metrics_config.py +468 -0
  30. sagemaker/core/debugger/profiler.py +42 -0
  31. sagemaker/core/debugger/profiler_config.py +190 -0
  32. sagemaker/core/debugger/profiler_constants.py +40 -0
  33. sagemaker/core/debugger/utils.py +148 -0
  34. sagemaker/core/deprecations.py +254 -0
  35. sagemaker/core/deserializers/__init__.py +10 -0
  36. sagemaker/core/deserializers/base.py +424 -0
  37. sagemaker/core/deserializers/implementations.py +157 -0
  38. sagemaker/core/drift_check_baselines.py +106 -0
  39. sagemaker/core/enums.py +51 -0
  40. sagemaker/core/environment_variables.py +101 -0
  41. sagemaker/core/exceptions.py +108 -0
  42. sagemaker/core/experiments/__init__.py +53 -0
  43. sagemaker/core/experiments/_api_types.py +251 -0
  44. sagemaker/core/experiments/_environment.py +124 -0
  45. sagemaker/core/experiments/_helper.py +294 -0
  46. sagemaker/core/experiments/_metrics.py +333 -0
  47. sagemaker/core/experiments/_run_context.py +58 -0
  48. sagemaker/core/experiments/_utils.py +216 -0
  49. sagemaker/core/experiments/experiment.py +247 -0
  50. sagemaker/core/experiments/run.py +970 -0
  51. sagemaker/core/experiments/trial.py +296 -0
  52. sagemaker/core/experiments/trial_component.py +387 -0
  53. sagemaker/core/explainer/__init__.py +24 -0
  54. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  55. sagemaker/core/explainer/explainer_config.py +44 -0
  56. sagemaker/core/fw_utils.py +1220 -0
  57. sagemaker/core/git_utils.py +415 -0
  58. sagemaker/core/helper/pipeline_variable.py +82 -0
  59. sagemaker/core/helper/session_helper.py +2977 -0
  60. sagemaker/core/hyperparameters.py +172 -0
  61. sagemaker/core/image_retriever/__init__.py +3 -0
  62. sagemaker/core/image_retriever/image_retriever.py +640 -0
  63. sagemaker/core/image_retriever/image_retriever_utils.py +509 -0
  64. sagemaker/core/image_retriever/test.py +7 -0
  65. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  66. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  67. sagemaker/core/image_uri_config/chainer.json +104 -0
  68. sagemaker/core/image_uri_config/clarify.json +39 -0
  69. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  70. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  71. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  72. sagemaker/core/image_uri_config/debugger.json +34 -0
  73. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  74. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  75. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  76. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  77. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  78. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  79. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  80. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  81. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +770 -0
  82. sagemaker/core/image_uri_config/huggingface-llm.json +1267 -0
  83. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  84. sagemaker/core/image_uri_config/huggingface-neuronx.json +686 -0
  85. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  86. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  87. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  88. sagemaker/core/image_uri_config/huggingface-vllm-neuronx.json +38 -0
  89. sagemaker/core/image_uri_config/huggingface.json +2287 -0
  90. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  91. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  92. sagemaker/core/image_uri_config/image-classification.json +50 -0
  93. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  94. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  95. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  96. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  97. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  98. sagemaker/core/image_uri_config/kmeans.json +50 -0
  99. sagemaker/core/image_uri_config/knn.json +50 -0
  100. sagemaker/core/image_uri_config/lda.json +26 -0
  101. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  102. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  103. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  104. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  105. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  106. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  107. sagemaker/core/image_uri_config/ntm.json +50 -0
  108. sagemaker/core/image_uri_config/object-detection.json +50 -0
  109. sagemaker/core/image_uri_config/object2vec.json +50 -0
  110. sagemaker/core/image_uri_config/pca.json +50 -0
  111. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  112. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  113. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  114. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  115. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  116. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  117. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  118. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  119. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  120. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  121. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +252 -0
  122. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  123. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  124. sagemaker/core/image_uri_config/sklearn.json +494 -0
  125. sagemaker/core/image_uri_config/spark.json +280 -0
  126. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  127. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  128. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  129. sagemaker/core/image_uri_config/vw.json +25 -0
  130. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  131. sagemaker/core/image_uri_config/xgboost.json +972 -0
  132. sagemaker/core/image_uris.py +816 -0
  133. sagemaker/core/inference_config.py +144 -0
  134. sagemaker/core/inference_recommender/__init__.py +18 -0
  135. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  136. sagemaker/core/inputs.py +366 -0
  137. sagemaker/core/instance_group.py +61 -0
  138. sagemaker/core/instance_types.py +164 -0
  139. sagemaker/core/instance_types_gpu_info.py +43 -0
  140. sagemaker/core/interactive_apps/__init__.py +41 -0
  141. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  142. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  143. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  144. sagemaker/core/iterators.py +197 -0
  145. sagemaker/core/job.py +380 -0
  146. sagemaker/core/jumpstart/__init__.py +156 -0
  147. sagemaker/core/jumpstart/accessors.py +390 -0
  148. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  149. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  150. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  151. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  152. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  153. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  154. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  155. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  156. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  157. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  158. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  159. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  160. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  161. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  162. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  163. sagemaker/core/jumpstart/cache.py +663 -0
  164. sagemaker/core/jumpstart/configs.py +50 -0
  165. sagemaker/core/jumpstart/constants.py +198 -0
  166. sagemaker/core/jumpstart/deserializers.py +81 -0
  167. sagemaker/core/jumpstart/document.py +76 -0
  168. sagemaker/core/jumpstart/enums.py +168 -0
  169. sagemaker/core/jumpstart/exceptions.py +236 -0
  170. sagemaker/core/jumpstart/factory/utils.py +833 -0
  171. sagemaker/core/jumpstart/filters.py +597 -0
  172. sagemaker/core/jumpstart/hub/constants.py +16 -0
  173. sagemaker/core/jumpstart/hub/hub.py +291 -0
  174. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  175. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  176. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  177. sagemaker/core/jumpstart/hub/types.py +35 -0
  178. sagemaker/core/jumpstart/hub/utils.py +260 -0
  179. sagemaker/core/jumpstart/models.py +501 -0
  180. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  181. sagemaker/core/jumpstart/parameters.py +20 -0
  182. sagemaker/core/jumpstart/payload_utils.py +239 -0
  183. sagemaker/core/jumpstart/region_config.json +171 -0
  184. sagemaker/core/jumpstart/search.py +171 -0
  185. sagemaker/core/jumpstart/serializers.py +81 -0
  186. sagemaker/core/jumpstart/session_utils.py +234 -0
  187. sagemaker/core/jumpstart/types.py +3044 -0
  188. sagemaker/core/jumpstart/utils.py +1731 -0
  189. sagemaker/core/jumpstart/validators.py +257 -0
  190. sagemaker/core/lambda_helper.py +312 -0
  191. sagemaker/core/lineage/__init__.py +42 -0
  192. sagemaker/core/lineage/_api_types.py +239 -0
  193. sagemaker/core/lineage/_utils.py +49 -0
  194. sagemaker/core/lineage/action.py +345 -0
  195. sagemaker/core/lineage/artifact.py +646 -0
  196. sagemaker/core/lineage/association.py +190 -0
  197. sagemaker/core/lineage/context.py +505 -0
  198. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  199. sagemaker/core/lineage/query.py +732 -0
  200. sagemaker/core/lineage/visualizer.py +346 -0
  201. sagemaker/core/local/__init__.py +18 -0
  202. sagemaker/core/local/data.py +423 -0
  203. sagemaker/core/local/entities.py +678 -0
  204. sagemaker/core/local/exceptions.py +17 -0
  205. sagemaker/core/local/image.py +1243 -0
  206. sagemaker/core/local/local_session.py +739 -0
  207. sagemaker/core/local/utils.py +246 -0
  208. sagemaker/core/logs.py +181 -0
  209. sagemaker/core/metadata_properties.py +56 -0
  210. sagemaker/core/metric_definitions.py +91 -0
  211. sagemaker/core/mlflow/__init__.py +38 -0
  212. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  213. sagemaker/core/model_card/__init__.py +26 -0
  214. sagemaker/core/model_life_cycle.py +51 -0
  215. sagemaker/core/model_metrics.py +160 -0
  216. sagemaker/core/model_monitor/__init__.py +66 -0
  217. sagemaker/core/model_monitor/clarify_model_monitoring.py +1497 -0
  218. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  219. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  220. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  221. sagemaker/core/model_monitor/dataset_format.py +102 -0
  222. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  223. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  224. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  225. sagemaker/core/model_monitor/utils.py +793 -0
  226. sagemaker/core/model_registry.py +480 -0
  227. sagemaker/core/model_uris.py +97 -0
  228. sagemaker/core/modules/__init__.py +19 -0
  229. sagemaker/core/modules/configs.py +239 -0
  230. sagemaker/core/modules/constants.py +37 -0
  231. sagemaker/core/modules/distributed.py +182 -0
  232. sagemaker/core/modules/local_core/local_container.py +605 -0
  233. sagemaker/core/modules/templates.py +83 -0
  234. sagemaker/core/modules/train/__init__.py +14 -0
  235. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  236. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  237. sagemaker/core/modules/train/container_drivers/common/utils.py +205 -0
  238. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  239. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  240. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  241. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  242. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  243. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  244. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  245. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  246. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  247. sagemaker/core/modules/types.py +19 -0
  248. sagemaker/core/modules/utils.py +194 -0
  249. sagemaker/core/network.py +185 -0
  250. sagemaker/core/parameter.py +173 -0
  251. sagemaker/core/payloads.py +185 -0
  252. sagemaker/core/processing.py +1599 -0
  253. sagemaker/core/remote_function/__init__.py +19 -0
  254. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  255. sagemaker/core/remote_function/client.py +1310 -0
  256. sagemaker/core/remote_function/core/__init__.py +0 -0
  257. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  258. sagemaker/core/remote_function/core/pipeline_variables.py +347 -0
  259. sagemaker/core/remote_function/core/serialization.py +410 -0
  260. sagemaker/core/remote_function/core/stored_function.py +223 -0
  261. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  262. sagemaker/core/remote_function/errors.py +102 -0
  263. sagemaker/core/remote_function/invoke_function.py +167 -0
  264. sagemaker/core/remote_function/job.py +2121 -0
  265. sagemaker/core/remote_function/logging_config.py +38 -0
  266. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  267. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  268. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  269. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  270. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  271. sagemaker/core/remote_function/spark_config.py +149 -0
  272. sagemaker/core/resource_requirements.py +168 -0
  273. {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
  274. sagemaker/core/s3/__init__.py +41 -0
  275. sagemaker/core/s3/client.py +367 -0
  276. sagemaker/core/s3/utils.py +175 -0
  277. sagemaker/core/script_uris.py +93 -0
  278. sagemaker/core/serializers/__init__.py +11 -0
  279. sagemaker/core/serializers/base.py +510 -0
  280. sagemaker/core/serializers/implementations.py +159 -0
  281. sagemaker/core/serializers/utils.py +223 -0
  282. sagemaker/core/serverless_inference_config.py +63 -0
  283. sagemaker/core/session_settings.py +55 -0
  284. sagemaker/core/shapes/__init__.py +3 -0
  285. sagemaker/core/shapes/model_card_shapes.py +159 -0
  286. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
  287. sagemaker/core/spark/__init__.py +16 -0
  288. sagemaker/core/spark/defaults.py +16 -0
  289. sagemaker/core/spark/processing.py +1380 -0
  290. sagemaker/core/telemetry/__init__.py +23 -0
  291. sagemaker/core/telemetry/constants.py +82 -0
  292. sagemaker/core/telemetry/telemetry_logging.py +285 -0
  293. sagemaker/core/tools/__init__.py +1 -0
  294. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  295. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  296. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  297. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  298. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  299. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  300. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  301. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  302. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  303. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  304. sagemaker/core/training/__init__.py +14 -0
  305. sagemaker/core/training/configs.py +345 -0
  306. sagemaker/core/training/constants.py +37 -0
  307. sagemaker/core/training/utils.py +77 -0
  308. sagemaker/core/training_compiler/__init__.py +16 -0
  309. sagemaker/core/training_compiler/config.py +197 -0
  310. sagemaker/core/training_compiler_config.py +197 -0
  311. sagemaker/core/transformer.py +793 -0
  312. sagemaker/core/user_agent.py +76 -0
  313. sagemaker/core/utilities/__init__.py +24 -0
  314. sagemaker/core/utilities/cache.py +169 -0
  315. sagemaker/core/utilities/search_expression.py +133 -0
  316. sagemaker/core/utils/__init__.py +48 -0
  317. sagemaker/core/utils/code_injection/__init__.py +0 -0
  318. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  319. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
  320. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  321. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  322. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  323. {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
  324. sagemaker/core/workflow/__init__.py +152 -0
  325. sagemaker/core/workflow/conditions.py +313 -0
  326. sagemaker/core/workflow/entities.py +58 -0
  327. sagemaker/core/workflow/execution_variables.py +89 -0
  328. sagemaker/core/workflow/functions.py +193 -0
  329. sagemaker/core/workflow/parameters.py +222 -0
  330. sagemaker/core/workflow/pipeline_context.py +394 -0
  331. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  332. sagemaker/core/workflow/properties.py +285 -0
  333. sagemaker/core/workflow/step_outputs.py +65 -0
  334. sagemaker/core/workflow/utilities.py +514 -0
  335. sagemaker/lineage/__init__.py +33 -0
  336. sagemaker/lineage/action.py +28 -0
  337. sagemaker/lineage/artifact.py +28 -0
  338. sagemaker/lineage/context.py +28 -0
  339. sagemaker/lineage/lineage_trial_component.py +28 -0
  340. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/METADATA +28 -9
  341. sagemaker_core-2.3.1.dist-info/RECORD +351 -0
  342. sagemaker_core-2.3.1.dist-info/top_level.txt +1 -0
  343. sagemaker_core/_version.py +0 -3
  344. sagemaker_core/helper/session_helper.py +0 -769
  345. sagemaker_core/resources/__init__.py +0 -1
  346. sagemaker_core/shapes/__init__.py +0 -1
  347. sagemaker_core/tools/__init__.py +0 -1
  348. sagemaker_core-1.0.62.dist-info/RECORD +0 -35
  349. sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
  350. {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
  351. {sagemaker_core/helper → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  352. {sagemaker_core/main → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
  353. {sagemaker_core/main/code_injection → sagemaker/core/modules/local_core}/__init__.py +0 -0
  354. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  355. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  356. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  357. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/WHEEL +0 -0
  358. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/licenses/LICENSE +0 -0
sagemaker/core/clarify/__init__.py
@@ -0,0 +1,2898 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
+ # may not use this file except in compliance with the License. A copy of
+ # the License is located at
+ #
+ #     http://aws.amazon.com/apache2.0/
+ #
+ # or in the "license" file accompanying this file. This file is
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+ # ANY KIND, either express or implied. See the License for the specific
+ # language governing permissions and limitations under the License.
+ """This module configures the SageMaker Clarify bias and model explainability processor jobs.
+
+ SageMaker Clarify
+ ==================
+ """
+ from __future__ import absolute_import, print_function
+
+ import copy
+ import json
+ import logging
+ import os
+ import re
+
+ import tempfile
+ from abc import ABC, abstractmethod
+ from typing import List, Literal, Union, Dict, Optional, Any
+ from enum import Enum
+ from schema import Schema, And, Use, Or, Optional as SchemaOptional, Regex
+ from sagemaker.core import s3
+ from sagemaker.core import image_uris
+ from sagemaker.core.helper.session_helper import Session
+ from sagemaker.core.network import NetworkConfig
+ from sagemaker.core.shapes import ProcessingInput, ProcessingOutput
+ from sagemaker.core.processing import Processor
+ from sagemaker.core.common_utils import (
+     format_tags,
+     Tags,
+     name_from_base,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ ENDPOINT_NAME_PREFIX_PATTERN = "^[a-zA-Z0-9](-*[a-zA-Z0-9])"
+
+ # asym shap val config default values (timeseries)
+ ASYM_SHAP_VAL_DEFAULT_EXPLANATION_DIRECTION = "chronological"
+ ASYM_SHAP_VAL_DEFAULT_EXPLANATION_GRANULARITY = "timewise"
+ ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS = [
+     "chronological",
+     "anti_chronological",
+     "bidirectional",
+ ]
+ ASYM_SHAP_VAL_GRANULARITIES = [
+     "timewise",
+     "fine_grained",
+ ]
+
+ ANALYSIS_CONFIG_SCHEMA_V1_0 = Schema(
+     {
+         SchemaOptional("version"): str,
+         "dataset_type": And(
+             str,
+             Use(str.lower),
+             lambda s: s
+             in (
+                 "text/csv",
+                 "application/jsonlines",
+                 "application/json",
+                 "application/sagemakercapturejson",
+                 "application/x-parquet",
+                 "application/x-image",
+             ),
+         ),
+         SchemaOptional("dataset_uri"): str,
+         SchemaOptional("headers"): [str],
+         SchemaOptional("label"): Or(str, int),
+         # this field indicates user provides predicted_label in dataset
+         SchemaOptional("predicted_label"): Or(str, int),
+         SchemaOptional("features"): str,
+         SchemaOptional("label_values_or_threshold"): [Or(int, float, str)],
+         SchemaOptional("probability_threshold"): float,
+         SchemaOptional("segment_config"): [
+             {
+                 SchemaOptional("config_name"): str,
+                 "name_or_index": Or(str, int),
+                 "segments": [[Or(str, int)]],
+                 SchemaOptional("display_aliases"): [str],
+             }
+         ],
+         SchemaOptional("facet"): [
+             {
+                 "name_or_index": Or(str, int),
+                 SchemaOptional("value_or_threshold"): [Or(int, float, str)],
+             }
+         ],
+         SchemaOptional("facet_dataset_uri"): str,
+         SchemaOptional("facet_headers"): [str],
+         SchemaOptional("predicted_label_dataset_uri"): str,
+         SchemaOptional("predicted_label_headers"): [str],
+         SchemaOptional("excluded_columns"): [Or(int, str)],
+         SchemaOptional("joinsource_name_or_index"): Or(str, int),
+         SchemaOptional("group_variable"): Or(str, int),
+         SchemaOptional("time_series_data_config"): {
+             "target_time_series": Or(str, int),
+             "item_id": Or(str, int),
+             "timestamp": Or(str, int),
+             SchemaOptional("related_time_series"): Or([str], [int]),
+             SchemaOptional("static_covariates"): Or([str], [int]),
+             SchemaOptional("dataset_format"): And(
+                 str,
+                 Use(str.lower),
+                 lambda s: s
+                 in (
+                     "columns",
+                     "item_records",
+                     "timestamp_records",
+                 ),
+             ),
+         },
+         "methods": {
+             SchemaOptional("shap"): {
+                 SchemaOptional("baseline"): Or(
+                     # URI of the baseline data file
+                     str,
+                     # Inplace baseline data (a list of something)
+                     [
+                         Or(
+                             # CSV row
+                             [Or(int, float, str, None)],
+                             # JSON row (any JSON object). As I write this only
+                             # SageMaker JSONLines Dense Format ([1])
+                             # is supported and the validation is NOT done
+                             # by the schema but by the data loader.
+                             # [1] https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-inference.html#cm-jsonlines
+                             {object: object},
+                         )
+                     ],
+                     # Arbitrary JSON object as baseline
+                     {object: object},
+                 ),
+                 SchemaOptional("num_clusters"): int,
+                 SchemaOptional("use_logit"): bool,
+                 SchemaOptional("num_samples"): int,
+                 SchemaOptional("agg_method"): And(
+                     str, Use(str.lower), lambda s: s in ("mean_abs", "median", "mean_sq")
+                 ),
+                 SchemaOptional("save_local_shap_values"): bool,
+                 SchemaOptional("text_config"): {
+                     "granularity": And(
+                         str, Use(str.lower), lambda s: s in ("token", "sentence", "paragraph")
+                     ),
+                     "language": And(
+                         str,
+                         Use(str.lower),
+                         lambda s: s
+                         in (
+                             "chinese",
+                             "zh",
+                             "danish",
+                             "da",
+                             "dutch",
+                             "nl",
+                             "english",
+                             "en",
+                             "french",
+                             "fr",
+                             "german",
+                             "de",
+                             "greek",
+                             "el",
+                             "italian",
+                             "it",
+                             "japanese",
+                             "ja",
+                             "lithuanian",
+                             "lt",
+                             "multi-language",
+                             "xx",
+                             "norwegian bokmål",
+                             "nb",
+                             "polish",
+                             "pl",
+                             "portuguese",
+                             "pt",
+                             "romanian",
+                             "ro",
+                             "russian",
+                             "ru",
+                             "spanish",
+                             "es",
+                             "afrikaans",
+                             "af",
+                             "albanian",
+                             "sq",
+                             "arabic",
+                             "ar",
+                             "armenian",
+                             "hy",
+                             "basque",
+                             "eu",
+                             "bengali",
+                             "bn",
+                             "bulgarian",
+                             "bg",
+                             "catalan",
+                             "ca",
+                             "croatian",
+                             "hr",
+                             "czech",
+                             "cs",
+                             "estonian",
+                             "et",
+                             "finnish",
+                             "fi",
+                             "gujarati",
+                             "gu",
+                             "hebrew",
+                             "he",
+                             "hindi",
+                             "hi",
+                             "hungarian",
+                             "hu",
+                             "icelandic",
+                             "is",
+                             "indonesian",
+                             "id",
+                             "irish",
+                             "ga",
+                             "kannada",
+                             "kn",
+                             "kyrgyz",
+                             "ky",
+                             "latvian",
+                             "lv",
+                             "ligurian",
+                             "lij",
+                             "luxembourgish",
+                             "lb",
+                             "macedonian",
+                             "mk",
+                             "malayalam",
+                             "ml",
+                             "marathi",
+                             "mr",
+                             "nepali",
+                             "ne",
+                             "persian",
+                             "fa",
+                             "sanskrit",
+                             "sa",
+                             "serbian",
+                             "sr",
+                             "setswana",
+                             "tn",
+                             "sinhala",
+                             "si",
+                             "slovak",
+                             "sk",
+                             "slovenian",
+                             "sl",
+                             "swedish",
+                             "sv",
+                             "tagalog",
+                             "tl",
+                             "tamil",
+                             "ta",
+                             "tatar",
+                             "tt",
+                             "telugu",
+                             "te",
+                             "thai",
+                             "th",
+                             "turkish",
+                             "tr",
+                             "ukrainian",
+                             "uk",
+                             "urdu",
+                             "ur",
+                             "vietnamese",
+                             "vi",
+                             "yoruba",
+                             "yo",
+                         ),
+                     ),
+                     SchemaOptional("max_top_tokens"): int,
+                 },
+                 SchemaOptional("image_config"): {
+                     SchemaOptional("num_segments"): int,
+                     SchemaOptional("segment_compactness"): int,
+                     SchemaOptional("feature_extraction_method"): str,
+                     SchemaOptional("model_type"): str,
+                     SchemaOptional("max_objects"): int,
+                     SchemaOptional("iou_threshold"): float,
+                     SchemaOptional("context"): float,
+                     SchemaOptional("debug"): {
+                         SchemaOptional("image_names"): [str],
+                         SchemaOptional("class_ids"): [int],
+                         SchemaOptional("sample_from"): int,
+                         SchemaOptional("sample_to"): int,
+                     },
+                 },
+                 SchemaOptional("seed"): int,
+                 SchemaOptional("features_to_explain"): [Or(int, str)],
+             },
+             SchemaOptional("pre_training_bias"): {"methods": Or(str, [str])},
+             SchemaOptional("post_training_bias"): {"methods": Or(str, [str])},
+             SchemaOptional("pdp"): {
+                 "grid_resolution": int,
+                 SchemaOptional("features"): [Or(str, int)],
+                 SchemaOptional("top_k_features"): int,
+             },
+             SchemaOptional("report"): {"name": str, SchemaOptional("title"): str},
+             SchemaOptional("asymmetric_shapley_value"): {
+                 "direction": And(
+                     str,
+                     Use(str.lower),
+                     lambda s: s
+                     in (
+                         "chronological",
+                         "anti_chronological",
+                         "bidirectional",
+                     ),
+                 ),
+                 "granularity": And(
+                     str,
+                     Use(str.lower),
+                     lambda s: s
+                     in (
+                         "timewise",
+                         "fine_grained",
+                     ),
+                 ),
+                 SchemaOptional("num_samples"): int,
+                 SchemaOptional("baseline"): Or(
+                     str,
+                     {
+                         SchemaOptional("target_time_series", default="zero"): And(
+                             str,
+                             Use(str.lower),
+                             lambda s: s
+                             in (
+                                 "zero",
+                                 "mean",
+                             ),
+                         ),
+                         SchemaOptional("related_time_series"): And(
+                             str,
+                             Use(str.lower),
+                             lambda s: s
+                             in (
+                                 "zero",
+                                 "mean",
+                             ),
+                         ),
+                         SchemaOptional("static_covariates"): {Or(str, int): [Or(str, int, float)]},
+                     },
+                 ),
+             },
+         },
+         SchemaOptional("predictor"): {
+             SchemaOptional("endpoint_name"): str,
+             SchemaOptional("endpoint_name_prefix"): And(str, Regex(ENDPOINT_NAME_PREFIX_PATTERN)),
+             SchemaOptional("model_name"): str,
+             SchemaOptional("target_model"): str,
+             SchemaOptional("instance_type"): str,
+             SchemaOptional("initial_instance_count"): int,
+             SchemaOptional("accelerator_type"): str,
+             SchemaOptional("content_type"): And(
+                 str,
+                 Use(str.lower),
+                 lambda s: s
+                 in (
+                     "text/csv",
+                     "application/jsonlines",
+                     "application/json",
+                     "image/jpeg",
+                     "image/png",
+                     "application/x-npy",
+                 ),
+             ),
+             SchemaOptional("accept_type"): And(
+                 str,
+                 Use(str.lower),
+                 lambda s: s in ("text/csv", "application/jsonlines", "application/json"),
+             ),
+             SchemaOptional("label"): Or(str, int),
+             SchemaOptional("probability"): Or(str, int),
+             SchemaOptional("label_headers"): [Or(str, int)],
+             SchemaOptional("content_template"): Or(str, {str: str}),
+             SchemaOptional("record_template"): str,
+             SchemaOptional("custom_attributes"): str,
+             SchemaOptional("time_series_predictor_config"): {
+                 "forecast": str,
+             },
+         },
+     }
+ )
+
+
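A minimal sketch of how this schema can be exercised with the `schema` library, assuming the module is importable as `sagemaker.core.clarify` (per the file list above); the candidate config below is hypothetical, and `Schema.validate` raises `SchemaError` on any mismatch:

    from schema import SchemaError
    from sagemaker.core.clarify import ANALYSIS_CONFIG_SCHEMA_V1_0

    # Hypothetical analysis config: "dataset_type" and "methods" are the only
    # required top-level keys; everything else is SchemaOptional.
    candidate = {
        "dataset_type": "text/csv",
        "headers": ["age", "income", "label"],
        "label": "label",
        "methods": {"pre_training_bias": {"methods": "all"}},
    }
    try:
        ANALYSIS_CONFIG_SCHEMA_V1_0.validate(candidate)  # returns the validated dict
    except SchemaError as err:
        print(f"Invalid analysis config: {err}")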
+ class DatasetType(Enum):
+     """Enum to store different dataset types supported in the Analysis config file"""
+
+     TEXTCSV = "text/csv"
+     JSONLINES = "application/jsonlines"
+     JSON = "application/json"
+     PARQUET = "application/x-parquet"
+     IMAGE = "application/x-image"
+
+
+ class TimeSeriesJSONDatasetFormat(Enum):
+     """Possible dataset formats for JSON time series data files.
+
+     Below is an example ``COLUMNS`` dataset for time series explainability::
+
+         {
+             "ids": [1, 2],
+             "timestamps": [3, 4],
+             "target_ts": [5, 6],
+             "rts1": [0.25, 0.5],
+             "rts2": [1.25, 1.5],
+             "scv1": [10, 20],
+             "scv2": [30, 40]
+         }
+
+     For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows::
+
+         item_id="ids"
+         timestamp="timestamps"
+         target_time_series="target_ts"
+         related_time_series=["rts1", "rts2"]
+         static_covariates=["scv1", "scv2"]
+
+     Below is an example ``ITEM_RECORDS`` dataset for time series explainability::
+
+         [
+             {
+                 "id": 1,
+                 "scv1": 10,
+                 "scv2": "red",
+                 "timeseries": [
+                     {"timestamp": 1, "target_ts": 5, "rts1": 0.25, "rts2": 10},
+                     {"timestamp": 2, "target_ts": 6, "rts1": 0.35, "rts2": 20},
+                     {"timestamp": 3, "target_ts": 4, "rts1": 0.45, "rts2": 30}
+                 ]
+             },
+             {
+                 "id": 2,
+                 "scv1": 20,
+                 "scv2": "blue",
+                 "timeseries": [
+                     {"timestamp": 1, "target_ts": 4, "rts1": 0.25, "rts2": 40},
+                     {"timestamp": 2, "target_ts": 2, "rts1": 0.35, "rts2": 50}
+                 ]
+             }
+         ]
+
+     For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows::
+
+         item_id="[*].id"
+         timestamp="[*].timeseries[].timestamp"
+         target_time_series="[*].timeseries[].target_ts"
+         related_time_series=["[*].timeseries[].rts1", "[*].timeseries[].rts2"]
+         static_covariates=["[*].scv1", "[*].scv2"]
+
+     Below is an example ``TIMESTAMP_RECORDS`` dataset for time series explainability::
+
+         [
+             {"id": 1, "timestamp": 1, "target_ts": 5, "scv1": 10, "rts1": 0.25},
+             {"id": 1, "timestamp": 2, "target_ts": 6, "scv1": 10, "rts1": 0.5},
+             {"id": 1, "timestamp": 3, "target_ts": 3, "scv1": 10, "rts1": 0.75},
+             {"id": 2, "timestamp": 5, "target_ts": 10, "scv1": 20, "rts1": 1}
+         ]
+
+     For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows::
+
+         item_id="[*].id"
+         timestamp="[*].timestamp"
+         target_time_series="[*].target_ts"
+         related_time_series=["[*].rts1"]
+         static_covariates=["[*].scv1"]
+     """
+
+     COLUMNS = "columns"
+     ITEM_RECORDS = "item_records"
+     TIMESTAMP_RECORDS = "timestamp_records"
+
+
+ class SegmentationConfig:
+     """Config object that defines segment(s) of the dataset on which metrics are computed."""
+
+     def __init__(
+         self,
+         name_or_index: Union[str, int],
+         segments: List[List[Union[str, int]]],
+         config_name: Optional[str] = None,
+         display_aliases: Optional[List[str]] = None,
+     ):
+         """Initializes a segmentation configuration for a dataset column.
+
+         Args:
+             name_or_index (str or int): The name or index of the column in the dataset on which
+                 the segment(s) is defined.
+             segments (List[List[str or int]]): Each list of values represents one segment. If N
+                 lists are provided, we generate N+1 segments: the additional segment, denoted as
+                 the '__default__' segment, is for the rest of the values that are not covered by
+                 these lists. For continuous columns, a segment must be given as strings in interval
+                 notation (e.g.: ["[1, 4]"] or ["(2, 5]"]). A segment can also be composed of
+                 multiple intervals (e.g.: ["[1, 4]", "(5, 6]"] is one segment). For categorical
+                 columns, each segment should contain one or more of the categorical values for
+                 the categorical column, which may be strings or integers.
+                 E.g.: for a continuous column, ``segments`` could be
+                 [["[1, 4]", "(5, 6]"], ["(7, 9)"]] - this generates 3 segments, including the
+                 default segment. For a categorical column with values ("A", "B", "C", "D"),
+                 ``segments`` could be [["A", "B"]]. This generates 2 segments, including the
+                 default segment. A usage sketch follows this class definition.
+             config_name (str): Optional name for the segment config to identify the config.
+             display_aliases (List[str]): Optional list of display names for the ``segments`` for
+                 the analysis output and report. This list should be the same length as the number
+                 of lists provided in ``segments`` or with one additional display alias for the
+                 default segment.
+
+         Raises:
+             ValueError: when the ``name_or_index`` is None, ``segments`` is invalid, or a wrong
+                 number of ``display_aliases`` are specified.
+         """
+         if name_or_index is None:
+             raise ValueError("`name_or_index` cannot be None")
+         self.name_or_index = name_or_index
+         if (
+             not segments
+             or not isinstance(segments, list)
+             or not all([isinstance(segment, list) for segment in segments])
+         ):
+             raise ValueError("`segments` must be a list of lists of values or intervals.")
+         self.segments = segments
+         self.config_name = config_name
+         if display_aliases is not None and not (
+             len(display_aliases) == len(segments) or len(display_aliases) == len(segments) + 1
+         ):
+             raise ValueError(
+                 "Number of `display_aliases` must equal the number of segments"
+                 " specified or with one additional default segment display alias."
+             )
+         self.display_aliases = display_aliases
+
+     def to_dict(self) -> Dict[str, Any]:  # pragma: no cover
+         """Returns SegmentationConfig as a dict."""
+         segment_config_dict = {"name_or_index": self.name_or_index, "segments": self.segments}
+         if self.config_name:
+             segment_config_dict["config_name"] = self.config_name
+         if self.display_aliases:
+             segment_config_dict["display_aliases"] = self.display_aliases
+         return segment_config_dict
+
+
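A hedged usage sketch for ``SegmentationConfig`` following the docstring above; the "age" column and interval values are illustrative only:

    from sagemaker.core.clarify import SegmentationConfig

    # Two explicit segments over a continuous "age" column; Clarify adds a third
    # "__default__" segment for the remaining values, hence the extra alias.
    seg_config = SegmentationConfig(
        name_or_index="age",
        segments=[["[18, 30)"], ["[30, 60)"]],
        config_name="age_bands",
        display_aliases=["young", "middle-aged", "other"],
    )
    print(seg_config.to_dict())
    # {'name_or_index': 'age', 'segments': [['[18, 30)'], ['[30, 60)']],
    #  'config_name': 'age_bands', 'display_aliases': ['young', 'middle-aged', 'other']}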
+ class TimeSeriesDataConfig:
+     """Config object for TimeSeries explainability data configuration fields."""
+
+     def __init__(
+         self,
+         target_time_series: Union[str, int],
+         item_id: Union[str, int],
+         timestamp: Union[str, int],
+         related_time_series: Optional[List[Union[str, int]]] = None,
+         static_covariates: Optional[List[Union[str, int]]] = None,
+         dataset_format: Optional[TimeSeriesJSONDatasetFormat] = None,
+     ):
+         """Initializes TimeSeries explainability data configuration fields.
+
+         Args:
+             target_time_series (str or int): A string or a zero-based integer index.
+                 Used to locate the target time series in the shared input dataset.
+                 If this parameter is a string, then all other parameters except
+                 `dataset_format` must be strings or lists of strings. If
+                 this parameter is an int, then all other parameters except
+                 `dataset_format` must be ints or lists of ints.
+             item_id (str or int): A string or a zero-based integer index. Used to
+                 locate item id in the shared input dataset.
+             timestamp (str or int): A string or a zero-based integer index. Used to
+                 locate timestamp in the shared input dataset.
+             related_time_series (list[str] or list[int]): Optional. An array of strings
+                 or array of zero-based integer indices. Used to locate all related time
+                 series in the shared input dataset (if present).
+             static_covariates (list[str] or list[int]): Optional. An array of strings or
+                 array of zero-based integer indices. Used to locate all static covariate
+                 fields in the shared input dataset (if present).
+             dataset_format (TimeSeriesJSONDatasetFormat): Describes the format
+                 of the data files provided for analysis. Should only be provided
+                 when dataset is in JSON format.
+
+         Raises:
+             ValueError: If any required arguments are not provided or are the wrong type.
+         """
+         # check target_time_series, item_id, and timestamp are provided
+         if not target_time_series:
+             raise ValueError("Please provide a target time series.")
+         if not item_id:
+             raise ValueError("Please provide an item id.")
+         if not timestamp:
+             raise ValueError("Please provide a timestamp.")
+         # check all arguments are the right types
+         if not isinstance(target_time_series, (str, int)):
+             raise ValueError("Please provide a string or an int for ``target_time_series``")
+         params_type = type(target_time_series)
+         if not isinstance(item_id, params_type):
+             raise ValueError(f"Please provide {params_type} for ``item_id``")
+         if not isinstance(timestamp, params_type):
+             raise ValueError(f"Please provide {params_type} for ``timestamp``")
+         # add mandatory fields to an internal dictionary
+         self.time_series_data_config = dict()
+         _set(target_time_series, "target_time_series", self.time_series_data_config)
+         _set(item_id, "item_id", self.time_series_data_config)
+         _set(timestamp, "timestamp", self.time_series_data_config)
+         # check optional arguments are right types if provided
+         related_time_series_error_message = (
+             f"Please provide a list of {params_type} for ``related_time_series``"
+         )
+         if related_time_series:
+             if not isinstance(related_time_series, list):
+                 raise ValueError(
+                     related_time_series_error_message
+                 )  # related_time_series is not a list
+             if not all([isinstance(value, params_type) for value in related_time_series]):
+                 raise ValueError(
+                     related_time_series_error_message
+                 )  # related_time_series is not a list of strings or list of ints
+             if params_type == str and not all(related_time_series):
+                 raise ValueError("Please do not provide empty strings in ``related_time_series``.")
+             _set(
+                 related_time_series, "related_time_series", self.time_series_data_config
+             )  # related_time_series is valid, add it
+         static_covariates_series_error_message = (
+             f"Please provide a list of {params_type} for ``static_covariates``"
+         )
+         if static_covariates:
+             if not isinstance(static_covariates, list):
+                 raise ValueError(
+                     static_covariates_series_error_message
+                 )  # static_covariates is not a list
+             if not all([isinstance(value, params_type) for value in static_covariates]):
+                 raise ValueError(
+                     static_covariates_series_error_message
+                 )  # static_covariates is not a list of strings or list of ints
+             if params_type == str and not all(static_covariates):
+                 raise ValueError("Please do not provide empty strings in ``static_covariates``.")
+             _set(
+                 static_covariates, "static_covariates", self.time_series_data_config
+             )  # static_covariates is valid, add it
+         if params_type == str:
+             # check dataset_format is provided and valid
+             if not isinstance(dataset_format, TimeSeriesJSONDatasetFormat):
+                 raise ValueError("Please provide a valid dataset format.")
+             _set(dataset_format.value, "dataset_format", self.time_series_data_config)
+         else:
+             if dataset_format:
+                 raise ValueError(
+                     "Dataset format should only be provided when data files are JSONs."
+                 )
+
+     def get_time_series_data_config(self):
+         """Returns part of an analysis config dictionary."""
+         return copy.deepcopy(self.time_series_data_config)
+
+
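A hedged sketch of constructing ``TimeSeriesDataConfig`` for the ``COLUMNS``-format example shown in the ``TimeSeriesJSONDatasetFormat`` docstring; the field names come from that example, and ``dataset_format`` is required here because the locators are strings:

    from sagemaker.core.clarify import (
        TimeSeriesDataConfig,
        TimeSeriesJSONDatasetFormat,
    )

    ts_data_config = TimeSeriesDataConfig(
        target_time_series="target_ts",
        item_id="ids",
        timestamp="timestamps",
        related_time_series=["rts1", "rts2"],
        static_covariates=["scv1", "scv2"],
        dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS,
    )
    # Returns a deep copy, so callers cannot mutate the internal dict.
    print(ts_data_config.get_time_series_data_config())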
+ class DataConfig:
+     """Config object related to configurations of the input and output dataset."""
+
+     def __init__(
+         self,
+         s3_data_input_path: str,
+         s3_output_path: str,
+         s3_analysis_config_output_path: Optional[str] = None,
+         label: Optional[str] = None,
+         headers: Optional[List[str]] = None,
+         features: Optional[str] = None,
+         dataset_type: str = "text/csv",
+         s3_compression_type: str = "None",
+         joinsource: Optional[Union[str, int]] = None,
+         facet_dataset_uri: Optional[str] = None,
+         facet_headers: Optional[List[str]] = None,
+         predicted_label_dataset_uri: Optional[str] = None,
+         predicted_label_headers: Optional[List[str]] = None,
+         predicted_label: Optional[Union[str, int]] = None,
+         excluded_columns: Optional[Union[List[int], List[str]]] = None,
+         segmentation_config: Optional[List[SegmentationConfig]] = None,
+         time_series_data_config: Optional[TimeSeriesDataConfig] = None,
+     ):
+         """Initializes a configuration of both input and output datasets.
+
+         Args:
+             s3_data_input_path (str): Dataset S3 prefix/object URI.
+             s3_output_path (str): S3 prefix to store the output.
+             s3_analysis_config_output_path (str): S3 prefix to store the analysis config output.
+                 If this field is None, then the ``s3_output_path`` will be used
+                 to store the ``analysis_config`` output.
+             label (str): Target attribute of the model required by bias metrics. Specified as
+                 column name or index for CSV dataset or a JMESPath expression for JSON/JSON Lines.
+                 *Required parameter* except for when the input dataset does not contain the label.
+                 Note: For JSON, the JMESPath query must result in a list of labels for each
+                 sample. For JSON Lines, it must result in the label for each line.
+                 Only a single label per sample is supported at this time.
+             headers ([str]): List of column names in the dataset. If not provided, Clarify will
+                 generate headers to use internally. For time series explainability cases,
+                 please provide headers in the order of item_id, timestamp, target_time_series,
+                 all related_time_series columns, and then all static_covariate columns.
+             features (str): JMESPath expression to locate the feature values
+                 if the dataset format is JSON/JSON Lines.
+                 Note: For JSON, the JMESPath query must result in a 2-D list (or a matrix) of
+                 feature values. For JSON Lines, it must result in a 1-D list of features for each
+                 line.
+             dataset_type (str): Format of the dataset. Valid values are ``"text/csv"`` for CSV,
+                 ``"application/jsonlines"`` for JSON Lines, ``"application/json"`` for JSON, and
+                 ``"application/x-parquet"`` for Parquet.
+             s3_compression_type (str): Valid options are ``"None"`` or ``"Gzip"``.
+             joinsource (str or int): The name or index of the column in the dataset that
+                 acts as an identifier column (for instance, while performing a join).
+                 This column is only used as an identifier, and not used for any other computations.
+                 This is an optional field in all cases except:
+
+                 * The dataset contains more than one file and `save_local_shap_values`
+                   is set to true in :class:`~sagemaker.clarify.ShapConfig`, and/or
+                 * When the dataset and/or facet dataset and/or predicted label dataset
+                   are in separate files.
+
+             facet_dataset_uri (str): Dataset S3 prefix/object URI that contains facet attribute(s),
+                 used for bias analysis on datasets without facets.
+
+                 * If the dataset and the facet dataset are one single file each, then
+                   the original dataset and facet dataset must have the same number of rows.
+                 * If the dataset and facet dataset are in multiple files (either one), then
+                   an index column, ``joinsource``, is required to join the two datasets.
+
+                 Clarify will not use the ``joinsource`` column and columns present in the facet
+                 dataset when calling model inference APIs.
+                 Note: this is only supported for ``"text/csv"`` dataset type.
+             facet_headers (list[str]): List of column names in the facet dataset.
+             predicted_label_dataset_uri (str): Dataset S3 prefix/object URI with predicted labels,
+                 which are used directly for analysis instead of making model inference API calls.
+
+                 * If the dataset and the predicted label dataset are one single file each, then the
+                   original dataset and predicted label dataset must have the same number of rows.
+                 * If the dataset and predicted label dataset are in multiple files (either one),
+                   then an index column, ``joinsource``, is required to join the two datasets.
+
+                 Note: this is only supported for ``"text/csv"`` dataset type.
+             predicted_label_headers (list[str]): List of column names in the predicted label
+                 dataset.
+             predicted_label (str or int): Predicted label of the target attribute of the model
+                 required for running bias analysis. Specified as column name or index for CSV data,
+                 or a JMESPath expression for JSON/JSON Lines.
+                 Clarify uses the predicted labels directly instead of making model inference API
+                 calls.
+                 Note: For JSON, the JMESPath query must result in a list of predicted labels for
+                 each sample. For JSON Lines, it must result in the predicted label for each line.
+                 Only a single predicted label per sample is supported at this time.
+             excluded_columns (list[int] or list[str]): A list of names or indices of the columns
+                 which are to be excluded from making model inference API calls.
+             segmentation_config (list[SegmentationConfig]): A list of ``SegmentationConfig``
+                 objects.
+             time_series_data_config (TimeSeriesDataConfig): Optional. A config object for
+                 TimeSeries data specific fields, required for TimeSeries explainability use cases.
+
+         Raises:
+             ValueError: when the ``dataset_type`` is invalid, predicted label dataset parameters
+                 are used with an unsupported ``dataset_type``, or facet dataset parameters
+                 are used with an unsupported ``dataset_type``.
+         """
+         if dataset_type not in [
+             "text/csv",
+             "application/jsonlines",
+             "application/json",
+             "application/x-parquet",
+             "application/x-image",
+         ]:
+             raise ValueError(
+                 f"Invalid dataset_type '{dataset_type}'."
+                 f" Please check the API documentation for the supported dataset types."
+             )
+         # predicted_label and excluded_columns are only supported for tabular datasets
+         if dataset_type not in [
+             "text/csv",
+             "application/jsonlines",
+             "application/json",
+             "application/x-parquet",
+         ]:
+             if predicted_label:
+                 raise ValueError(
+                     f"The parameter 'predicted_label' is not supported"
+                     f" for dataset_type '{dataset_type}'."
+                     f" Please check the API documentation for the supported dataset types."
+                 )
+             if excluded_columns:
+                 raise ValueError(
+                     f"The parameter 'excluded_columns' is not supported"
+                     f" for dataset_type '{dataset_type}'."
+                     f" Please check the API documentation for the supported dataset types."
+                 )
+         # parameters for analysis on datasets without facets are only supported for CSV datasets
+         if dataset_type != "text/csv":
+             if facet_dataset_uri or facet_headers:
+                 raise ValueError(
+                     f"The parameters 'facet_dataset_uri' and 'facet_headers'"
+                     f" are not supported for dataset_type '{dataset_type}'."
+                     f" Please check the API documentation for the supported dataset types."
+                 )
+             if predicted_label_dataset_uri or predicted_label_headers:
+                 raise ValueError(
+                     f"The parameters 'predicted_label_dataset_uri' and 'predicted_label_headers'"
+                     f" are not supported for dataset_type '{dataset_type}'."
+                     f" Please check the API documentation for the supported dataset types."
+                 )
+         # check if any other format other than JSON is provided for time series case
+         if time_series_data_config:
+             if dataset_type != "application/json":
+                 raise ValueError(
+                     "Currently time series explainability only supports JSON format data."
+                 )
+         # features JMESPath is required for JSON as we can't derive it ourselves
+         if dataset_type == "application/json" and features is None and not time_series_data_config:
+             raise ValueError("features JMESPath is required for application/json dataset_type")
+         self.s3_data_input_path = s3_data_input_path
+         self.s3_output_path = s3_output_path
+         self.s3_analysis_config_output_path = s3_analysis_config_output_path
+         self.s3_data_distribution_type = "FullyReplicated"
+         self.s3_compression_type = s3_compression_type
+         self.label = label
+         self.headers = headers
+         self.features = features
+         self.facet_dataset_uri = facet_dataset_uri
+         self.facet_headers = facet_headers
+         self.predicted_label_dataset_uri = predicted_label_dataset_uri
+         self.predicted_label_headers = predicted_label_headers
+         self.predicted_label = predicted_label
+         self.excluded_columns = excluded_columns
+         self.segmentation_configs = segmentation_config
+         self.analysis_config = {
+             "dataset_type": dataset_type,
+         }
+         _set(features, "features", self.analysis_config)
+         _set(headers, "headers", self.analysis_config)
+         _set(label, "label", self.analysis_config)
+         _set(joinsource, "joinsource_name_or_index", self.analysis_config)
+         _set(facet_dataset_uri, "facet_dataset_uri", self.analysis_config)
+         _set(facet_headers, "facet_headers", self.analysis_config)
+         _set(
+             predicted_label_dataset_uri,
+             "predicted_label_dataset_uri",
+             self.analysis_config,
+         )
+         _set(predicted_label_headers, "predicted_label_headers", self.analysis_config)
+         _set(predicted_label, "predicted_label", self.analysis_config)
+         _set(excluded_columns, "excluded_columns", self.analysis_config)
+         if segmentation_config:
+             _set(
+                 [item.to_dict() for item in segmentation_config],
+                 "segment_config",
+                 self.analysis_config,
+             )
+         if time_series_data_config:
+             _set(
+                 time_series_data_config.get_time_series_data_config(),
+                 "time_series_data_config",
+                 self.analysis_config,
+             )
+
+     def get_config(self):
+         """Returns part of an analysis config dictionary."""
+         return copy.deepcopy(self.analysis_config)
+
+
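A hedged sketch of a CSV ``DataConfig``; the bucket, prefix, and column names are placeholders:

    from sagemaker.core.clarify import DataConfig

    data_config = DataConfig(
        s3_data_input_path="s3://my-bucket/clarify/input/",  # placeholder URI
        s3_output_path="s3://my-bucket/clarify/output/",     # placeholder URI
        label="label",
        headers=["age", "income", "label"],
        dataset_type="text/csv",
    )
    print(data_config.get_config())
    # expected: {'dataset_type': 'text/csv',
    #            'headers': ['age', 'income', 'label'], 'label': 'label'}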
+ class BiasConfig:
+     """Config object with user-defined bias configurations of the input dataset."""
+
+     def __init__(
+         self,
+         label_values_or_threshold: List[Union[int, float, str]],
+         facet_name: Union[str, int, List[str], List[int]],
+         facet_values_or_threshold: Optional[Union[int, float, str]] = None,
+         group_name: Optional[str] = None,
+     ):
+         """Initializes a configuration of the sensitive groups in the dataset.
+
+         Args:
+             label_values_or_threshold ([int or float or str]): List of label value(s) or threshold
+                 to indicate positive outcome used for bias metrics.
+                 The appropriate threshold depends on the problem type:
+
+                 * Binary: The list has one positive value.
+                 * Categorical: The list has one or more (but not all) categories
+                   which are the positive values.
+                 * Regression: The list should include one threshold that defines the **exclusive**
+                   lower bound of positive values.
+
+             facet_name (str or int or list[str] or list[int]): Sensitive attribute column name
+                 (or index in the input data) to use when computing bias metrics. It can also be a
+                 list of names (or indexes) for computing metrics for multiple sensitive attributes.
+             facet_values_or_threshold ([int or float or str] or [[int or float or str]]):
+                 The parameter controls the values of the sensitive group.
+                 If ``facet_name`` is a scalar, then it can be None or a list.
+                 Depending on the data type of the facet column, the values mean:
+
+                 * Binary data: None means computing the bias metrics for each binary value.
+                   Or add one binary value to the list, to compute its bias metrics only.
+                 * Categorical data: None means computing the bias metrics for each category. Or add
+                   one or more (but not all) categories to the list, to compute their
+                   bias metrics vs. the other categories.
+                 * Continuous data: The list should include one and only one threshold which defines
+                   the **exclusive** lower bound of a sensitive group.
+
+                 If ``facet_name`` is a list, then ``facet_values_or_threshold`` can be None
+                 if all facets are of binary or categorical type.
+                 Otherwise, ``facet_values_or_threshold`` should be a list, and each element
+                 is the value or threshold of the corresponding facet.
+             group_name (str): Optional column name or index to indicate a group column to be used
+                 for the bias metric
+                 `Conditional Demographic Disparity in Labels (CDDL) <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-cddl.html>`_
+                 or
+                 `Conditional Demographic Disparity in Predicted Labels (CDDPL) <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cddpl.html>`_.
+
+         Raises:
+             ValueError: If the number of ``facet_names`` doesn't equal the number of
+                 ``facet values``.
+         """  # noqa E501  # pylint: disable=c0301
+         if isinstance(facet_name, list):
+             assert len(facet_name) > 0, "Please provide at least one facet"
+             if facet_values_or_threshold is None:
+                 facet_list = [
+                     {"name_or_index": single_facet_name} for single_facet_name in facet_name
+                 ]
+             elif len(facet_values_or_threshold) == len(facet_name):
+                 facet_list = []
+                 for i, single_facet_name in enumerate(facet_name):
+                     facet = {"name_or_index": single_facet_name}
+                     if facet_values_or_threshold is not None:
+                         _set(facet_values_or_threshold[i], "value_or_threshold", facet)
+                     facet_list.append(facet)
+             else:
+                 raise ValueError(
+                     "The number of facet names doesn't match the number of facet values"
+                 )
+         else:
+             facet = {"name_or_index": facet_name}
+             _set(facet_values_or_threshold, "value_or_threshold", facet)
+             facet_list = [facet]
+         self.analysis_config = {
+             "label_values_or_threshold": label_values_or_threshold,
+             "facet": facet_list,
+         }
+         _set(group_name, "group_variable", self.analysis_config)
+
+     def get_config(self):
+         """Returns a dictionary of bias detection configurations, part of the analysis config"""
+         return copy.deepcopy(self.analysis_config)
+
+
957
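To illustrate how BiasConfig assembles its facet list, a sketch with hypothetical column names (binary label, one facet thresholded at 40):

    bias_config = BiasConfig(
        label_values_or_threshold=[1],
        facet_name="age",
        facet_values_or_threshold=[40],
        group_name="region",
    )
    # bias_config.get_config() ==
    # {"label_values_or_threshold": [1],
    #  "facet": [{"name_or_index": "age", "value_or_threshold": [40]}],
    #  "group_variable": "region"}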
+class TimeSeriesModelConfig:
+    """Config object for TimeSeries predictor configuration fields."""
+
+    def __init__(
+        self,
+        forecast: str,
+    ):
+        """Initializes model configuration fields for TimeSeries explainability use cases.
+
+        Args:
+            forecast (str): JMESPath expression to extract the forecast result.
+
+        Raises:
+            ValueError: when ``forecast`` is not a string or not provided
+        """
+        # check string forecast is provided
+        if not isinstance(forecast, str):
+            raise ValueError(
+                "Please provide a string JMESPath expression for ``forecast`` "
+                "to extract the forecast result."
+            )
+        # add fields to an internal config dictionary
+        self.time_series_model_config = dict()
+        _set(forecast, "forecast", self.time_series_model_config)
+
+    def get_time_series_model_config(self):
+        """Returns TimeSeries model config dictionary"""
+        return copy.deepcopy(self.time_series_model_config)
+
+
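For instance, if an endpoint returned responses shaped like ``{"predictions": {"mean": [...]}}`` (a hypothetical response format), the mean forecast could be extracted with:

    ts_model_config = TimeSeriesModelConfig(forecast="predictions.mean")
    ts_model_config.get_time_series_model_config()  # {"forecast": "predictions.mean"}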
+class ModelConfig:
+    """Config object related to a model and its endpoint to be created."""
+
+    def __init__(
+        self,
+        model_name: Optional[str] = None,
+        instance_count: Optional[int] = None,
+        instance_type: Optional[str] = None,
+        accept_type: Optional[str] = None,
+        content_type: Optional[str] = None,
+        content_template: Optional[str] = None,
+        record_template: Optional[str] = None,
+        custom_attributes: Optional[str] = None,
+        accelerator_type: Optional[str] = None,
+        endpoint_name_prefix: Optional[str] = None,
+        target_model: Optional[str] = None,
+        endpoint_name: Optional[str] = None,
+        time_series_model_config: Optional[TimeSeriesModelConfig] = None,
+    ):
+        r"""Initializes a configuration of a model and the endpoint to be created for it.
+
+        Args:
+            model_name (str): Model name (as created by
+                `CreateModel <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateModel.html>`_).
+                Cannot be set when ``endpoint_name`` is set.
+                Must be set with ``instance_count``, ``instance_type``
+            instance_count (int): The number of instances of a new endpoint for model inference.
+                Cannot be set when ``endpoint_name`` is set.
+                Must be set with ``model_name``, ``instance_type``
+            instance_type (str): The type of
+                `EC2 instance <https://aws.amazon.com/ec2/instance-types/>`_
+                to use for model inference; for example, ``"ml.c5.xlarge"``.
+                Cannot be set when ``endpoint_name`` is set.
+                Must be set with ``instance_count``, ``model_name``
+            accept_type (str): The model output format to be used for getting inferences with the
+                shadow endpoint. Valid values are ``"text/csv"`` for CSV,
+                ``"application/jsonlines"`` for JSON Lines, and ``"application/json"`` for JSON.
+                Default is the same as ``content_type``.
+            content_type (str): The model input format to be used for getting inferences with the
+                shadow endpoint. Valid values are ``"text/csv"`` for CSV,
+                ``"application/jsonlines"`` for JSON Lines, and ``"application/json"`` for JSON.
+                Default is the same as ``dataset_format``.
+            content_template (str): A template string to be used to construct the model input from
+                dataset instances. It is only used, and required, when ``content_type`` is
+                ``"application/jsonlines"`` or ``"application/json"``. When ``content_type``
+                is ``application/jsonlines``, the template should have one and only one
+                placeholder, ``$features``, which will be replaced by a features list for each
+                record to form the model inference input. When ``content_type`` is
+                ``application/json``, the template can have either placeholder ``$record``, which
+                will be replaced by a single record templated by ``record_template`` and only a
+                single record at a time will be sent to the model, or placeholder ``$records``,
+                which will be replaced by a list of records, each templated by ``record_template``.
+            record_template (str): A template string to be used to construct each record of the
+                model input from dataset instances. It is only used, and required, when
+                ``content_type`` is ``"application/json"``.
+                The template string may contain one of the following:
+
+                * Placeholder ``$features`` that will be substituted by the array of feature values
+                  and/or an optional placeholder ``$feature_names`` that will be substituted by the
+                  array of feature names.
+                * Exactly one placeholder ``$features_kvp`` that will be substituted by the
+                  key-value pairs of feature name and feature value.
+                * Or for each feature, if "A" is the feature name in the ``headers`` configuration,
+                  then placeholder syntax ``"${A}"`` (the double-quotes are part of the
+                  placeholder) will be substituted by the feature value.
+
+                ``record_template`` will be used in conjunction with ``content_template`` to
+                construct the model input.
+
+                **Examples:**
+
+                Given:
+
+                * ``headers``: ``["A", "B"]``
+                * ``features``: ``[[0, 1], [3, 4]]``
+
+                Example model input 1::
+
+                    {
+                        "instances": [[0, 1], [3, 4]],
+                        "feature_names": ["A", "B"]
+                    }
+
+                content_template and record_template to construct above:
+
+                * ``content_template``: ``"{\"instances\": $records}"``
+                * ``record_template``: ``"$features"``
+
+                Example model input 2::
+
+                    [
+                        { "A": 0, "B": 1 },
+                        { "A": 3, "B": 4 },
+                    ]
+
+                content_template and record_template to construct above:
+
+                * ``content_template``: ``"$records"``
+                * ``record_template``: ``"$features_kvp"``
+
+                Or, alternatively:
+
+                * ``content_template``: ``"$records"``
+                * ``record_template``: ``"{\"A\": \"${A}\", \"B\": \"${B}\"}"``
+
+                Example model input 3 (single record only)::
+
+                    { "A": 0, "B": 1 }
+
+                content_template and record_template to construct above:
+
+                * ``content_template``: ``"$record"``
+                * ``record_template``: ``"$features_kvp"``
+            custom_attributes (str): Provides additional information about a request for an
+                inference submitted to a model hosted at an Amazon SageMaker endpoint. The
+                information is an opaque value that is forwarded verbatim. You could use this
+                value, for example, to provide an ID that you can use to track a request or to
+                provide other metadata that a service endpoint was programmed to process. The value
+                must consist of no more than 1024 visible US-ASCII characters as specified in
+                Section 3.2.6,
+                `Field Value Components <https://tools.ietf.org/html/rfc7230#section-3.2.6>`_,
+                of the Hypertext Transfer Protocol (HTTP/1.1).
+            accelerator_type (str): SageMaker
+                `Elastic Inference <https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html>`_
+                accelerator type to deploy to the model endpoint instance
+                for making inferences to the model.
+            endpoint_name_prefix (str): The endpoint name prefix of a new endpoint. Must follow
+                pattern ``^[a-zA-Z0-9](-*[a-zA-Z0-9])``.
+            target_model (str): Sets the target model name when using a multi-model endpoint. For
+                more information about multi-model endpoints, see
+                https://docs.aws.amazon.com/sagemaker/latest/dg/multi-model-endpoints.html
+            endpoint_name (str): Sets the endpoint name when reusing an existing endpoint.
+                Cannot be set when ``model_name``, ``instance_count``,
+                and ``instance_type`` are set
+            time_series_model_config (TimeSeriesModelConfig): Optional. A config object for
+                TimeSeries predictor specific fields, required for TimeSeries
+                explainability use cases.
+
+        Raises:
+            ValueError: when
+                - ``endpoint_name_prefix`` is invalid,
+                - ``accept_type`` is invalid,
+                - ``content_type`` is invalid,
+                - ``content_template`` has no placeholder ``$features``,
+                - both [``endpoint_name``]
+                  AND [``model_name``, ``instance_count``, ``instance_type``] are set
+                - both [``endpoint_name``] AND [``endpoint_name_prefix``] are set
+        """
+
+        # validation
+        _model_endpoint_config_rule = (
+            all([model_name, instance_count, instance_type]),
+            all([endpoint_name]),
+        )
+        assert any(_model_endpoint_config_rule) and not all(_model_endpoint_config_rule)
+        if endpoint_name:
+            assert not endpoint_name_prefix
+
+        # main init logic
+        self.predictor_config = (
+            {
+                "model_name": model_name,
+                "instance_type": instance_type,
+                "initial_instance_count": instance_count,
+            }
+            if not endpoint_name
+            else {"endpoint_name": endpoint_name}
+        )
+        if endpoint_name_prefix:
+            if re.search("^[a-zA-Z0-9](-*[a-zA-Z0-9])", endpoint_name_prefix) is None:
+                raise ValueError(
+                    "Invalid endpoint_name_prefix."
+                    " Please follow pattern ^[a-zA-Z0-9](-*[a-zA-Z0-9])."
+                )
+            self.predictor_config["endpoint_name_prefix"] = endpoint_name_prefix
+        if accept_type is not None:
+            if accept_type not in ["text/csv", "application/jsonlines", "application/json"]:
+                raise ValueError(
+                    f"Invalid accept_type {accept_type}."
+                    " Please choose text/csv, application/jsonlines, or application/json."
+                )
+            if time_series_model_config and accept_type == "text/csv":
+                raise ValueError(
+                    "``accept_type`` must be JSON or JSONLines for time series explainability."
+                )
+            self.predictor_config["accept_type"] = accept_type
+        if content_type is not None:
+            if content_type not in [
+                "text/csv",
+                "application/jsonlines",
+                "application/json",
+                "image/jpeg",
+                "image/jpg",
+                "image/png",
+                "application/x-npy",
+            ]:
+                raise ValueError(
+                    f"Invalid content_type {content_type}."
+                    " Please choose one of text/csv, application/jsonlines, application/json,"
+                    " image/jpeg, image/jpg, image/png, or application/x-npy."
+                )
+            if content_type == "application/jsonlines":
+                if content_template is None:
+                    raise ValueError(
+                        f"content_template field is required for content_type {content_type}"
+                    )
+                if "$features" not in content_template:
+                    raise ValueError(
+                        f"Invalid content_template {content_template}."
+                        " Please include a placeholder $features."
+                    )
+            if content_type == "application/json":
+                if content_template is None or record_template is None:
+                    raise ValueError(
+                        f"content_template and record_template are required for content_type "
+                        f"{content_type}"
+                    )
+                if "$record" not in content_template:
+                    raise ValueError(
+                        f"Invalid content_template {content_template}."
+                        " Please include either placeholder $records or $record."
+                    )
+            if time_series_model_config and content_type not in [
+                "application/json",
+                "application/jsonlines",
+            ]:
+                raise ValueError(
+                    "``content_type`` must be JSON or JSONLines for time series explainability."
+                )
+            self.predictor_config["content_type"] = content_type
+        if content_template is not None:
+            self.predictor_config["content_template"] = content_template
+        if record_template is not None:
+            self.predictor_config["record_template"] = record_template
+        _set(custom_attributes, "custom_attributes", self.predictor_config)
+        _set(accelerator_type, "accelerator_type", self.predictor_config)
+        _set(target_model, "target_model", self.predictor_config)
+        if time_series_model_config:
+            _set(
+                time_series_model_config.get_time_series_model_config(),
+                "time_series_predictor_config",
+                self.predictor_config,
+            )
+
+    def get_predictor_config(self):
+        """Returns part of the predictor dictionary of the analysis config."""
+        return copy.deepcopy(self.predictor_config)
+
+
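A sketch tying the template placeholders above to a ModelConfig instance (the model name is hypothetical):

    model_config = ModelConfig(
        model_name="my-model",  # hypothetical, as created by CreateModel
        instance_count=1,
        instance_type="ml.c5.xlarge",
        content_type="application/json",
        accept_type="application/json",
        content_template='{"instances": $records}',
        record_template="$features",
    )
    # get_predictor_config() carries model_name, initial_instance_count,
    # instance_type, content/accept types, and both templates.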
+class ModelPredictedLabelConfig:
+    """Config object to extract a predicted label from the model output."""
+
+    def __init__(
+        self,
+        label: Optional[Union[str, int]] = None,
+        probability: Optional[Union[str, int]] = None,
+        probability_threshold: Optional[float] = None,
+        label_headers: Optional[List[str]] = None,
+    ):
+        """Initializes a model output config to extract the predicted label or predicted score(s).
+
+        The following examples show different parameter configurations depending on the endpoint:
+
+        * **Regression task:**
+          The model returns the score, e.g. ``1.2``. We don't need to specify
+          anything. For JSON output, e.g. ``{'score': 1.2}``, we can set ``label='score'``.
+        * **Binary classification:**
+
+          * The model returns a single probability score. We want to classify as ``"yes"``
+            predictions with a probability score over ``0.2``.
+            We can set ``probability_threshold=0.2`` and ``label_headers="yes"``.
+          * The model returns ``{"probability": 0.3}``, for which we would like to apply a
+            threshold of ``0.5`` to obtain a predicted label in ``{0, 1}``.
+            In this case we can set ``probability="probability"``.
+          * The model returns a tuple of the predicted label and the probability.
+            In this case we can set ``label = 0``.
+        * **Multiclass classification:**
+
+          * The model returns ``{'labels': ['cat', 'dog', 'fish'],
+            'probabilities': [0.35, 0.25, 0.4]}``. In this case we would set
+            ``probability='probabilities'``, ``label='labels'``,
+            and infer the predicted label to be ``'fish'``.
+          * The model returns ``{'predicted_label': 'fish', 'probabilities': [0.35, 0.25, 0.4]}``.
+            In this case we would set ``label='predicted_label'``.
+          * The model returns ``[0.35, 0.25, 0.4]``. In this case, we can set
+            ``label_headers=['cat','dog','fish']`` and infer the predicted label to be ``'fish'``.
+
+        Args:
+            label (str or int): Index or JMESPath expression to locate the prediction
+                in the model output. In case this is a predicted label of the same type
+                as the label in the dataset, no further arguments need to be specified.
+            probability (str or int): Index or JMESPath expression to locate the predicted score(s)
+                in the model output.
+            probability_threshold (float): An optional value for binary prediction tasks in which
+                the model returns a probability, to indicate the threshold to convert the
+                prediction to a boolean value. Default is ``0.5``.
+            label_headers (list[str]): List of headers, each for a predicted score in model output.
+                For bias analysis, it is used to extract the label value with the highest score as
+                predicted label. For explainability jobs, it is used to beautify the analysis report
+                by replacing placeholders like ``'label0'``.
+
+        Raises:
+            TypeError: when the ``probability_threshold`` cannot be cast to a float
+        """
+        self.label = label
+        self.probability = probability
+        self.probability_threshold = probability_threshold
+        self.label_headers = label_headers
+        if probability_threshold is not None:
+            try:
+                float(probability_threshold)
+            except ValueError:
+                raise TypeError(
+                    f"Invalid probability_threshold {probability_threshold}. "
+                    f"Please choose one that can be cast to float."
+                )
+        self.predictor_config = {}
+        _set(label, "label", self.predictor_config)
+        _set(probability, "probability", self.predictor_config)
+        _set(label_headers, "label_headers", self.predictor_config)
+
+    def get_predictor_config(self):
+        """Returns ``probability_threshold`` and predictor config dictionary."""
+        return self.probability_threshold, copy.deepcopy(self.predictor_config)
+
+
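For the multiclass example above, the corresponding config object would be (a sketch):

    pred_label_config = ModelPredictedLabelConfig(
        label="labels",
        probability="probabilities",
    )
    threshold, cfg = pred_label_config.get_predictor_config()
    # threshold is None; cfg == {"label": "labels", "probability": "probabilities"}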
+class ExplainabilityConfig(ABC):
+    """Abstract config class to configure an explainability method."""
+
+    @abstractmethod
+    def get_explainability_config(self):
+        """Returns config."""
+        return None
+
+
+class PDPConfig(ExplainabilityConfig):
+    """Config class for Partial Dependence Plots (PDP).
+
+    `PDPs <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-partial-dependence-plots.html>`_
+    show the marginal effect (the dependence) a subset of features has on the predicted
+    outcome of an ML model.
+
+    When PDP is requested (by passing in a :class:`~sagemaker.clarify.PDPConfig` to the
+    ``explainability_config`` parameter of :class:`~sagemaker.clarify.SageMakerClarifyProcessor`),
+    the Partial Dependence Plots are included in the output
+    `report <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-feature-attribute-baselines-reports.html>`__
+    and the corresponding values are included in the analysis output.
+    """ # noqa E501
+
+    def __init__(
+        self, features: Optional[List] = None, grid_resolution: int = 15, top_k_features: int = 10
+    ):
+        """Initializes PDP config.
+
+        Args:
+            features (None or list): List of feature names or indices for which partial dependence
+                plots are computed and plotted. When :class:`~sagemaker.clarify.ShapConfig`
+                is provided, this parameter is optional, as Clarify will compute the
+                partial dependence plots for top features based on
+                `SHAP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-shapley-values.html>`__
+                attributions. When :class:`~sagemaker.clarify.ShapConfig` is not provided,
+                ``features`` must be provided.
+            grid_resolution (int): When using numerical features, this integer represents the
+                number of buckets that the range of values must be divided into. This decides the
+                granularity of the grid in which the PDPs are plotted.
+            top_k_features (int): Sets the number of top SHAP attributes used to compute
+                partial dependence plots.
+        """ # noqa E501
+        self.pdp_config = {
+            "grid_resolution": grid_resolution,
+            "top_k_features": top_k_features,
+        }
+        if features is not None:
+            self.pdp_config["features"] = features
+
+    def get_explainability_config(self):
+        """Returns PDP config dictionary."""
+        return copy.deepcopy({"pdp": self.pdp_config})
+
+
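A quick sketch requesting PDPs for two hypothetical features with a coarser grid:

    pdp_config = PDPConfig(features=["age", "income"], grid_resolution=10)
    pdp_config.get_explainability_config()
    # {"pdp": {"grid_resolution": 10, "top_k_features": 10, "features": ["age", "income"]}}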
+class TextConfig:
+    """Config object to handle text features for text explainability.
+
+    `SHAP analysis <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-model-explainability.html>`__
+    breaks down longer text into chunks (e.g. tokens, sentences, or paragraphs)
+    and replaces them with the strings specified in the baseline for that feature.
+    The `shap value <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-shapley-values.html>`_
+    of a chunk then captures how much replacing it affects the prediction.
+    """ # noqa E501 # pylint: disable=c0301
+
+    _SUPPORTED_GRANULARITIES = ["token", "sentence", "paragraph"]
+    # each language name is paired with its shorthand code
+    _SUPPORTED_LANGUAGES = [
+        "chinese", "zh",
+        "danish", "da",
+        "dutch", "nl",
+        "english", "en",
+        "french", "fr",
+        "german", "de",
+        "greek", "el",
+        "italian", "it",
+        "japanese", "ja",
+        "lithuanian", "lt",
+        "multi-language", "xx",
+        "norwegian bokmål", "nb",
+        "polish", "pl",
+        "portuguese", "pt",
+        "romanian", "ro",
+        "russian", "ru",
+        "spanish", "es",
+        "afrikaans", "af",
+        "albanian", "sq",
+        "arabic", "ar",
+        "armenian", "hy",
+        "basque", "eu",
+        "bengali", "bn",
+        "bulgarian", "bg",
+        "catalan", "ca",
+        "croatian", "hr",
+        "czech", "cs",
+        "estonian", "et",
+        "finnish", "fi",
+        "gujarati", "gu",
+        "hebrew", "he",
+        "hindi", "hi",
+        "hungarian", "hu",
+        "icelandic", "is",
+        "indonesian", "id",
+        "irish", "ga",
+        "kannada", "kn",
+        "kyrgyz", "ky",
+        "latvian", "lv",
+        "ligurian", "lij",
+        "luxembourgish", "lb",
+        "macedonian", "mk",
+        "malayalam", "ml",
+        "marathi", "mr",
+        "nepali", "ne",
+        "persian", "fa",
+        "sanskrit", "sa",
+        "serbian", "sr",
+        "setswana", "tn",
+        "sinhala", "si",
+        "slovak", "sk",
+        "slovenian", "sl",
+        "swedish", "sv",
+        "tagalog", "tl",
+        "tamil", "ta",
+        "tatar", "tt",
+        "telugu", "te",
+        "thai", "th",
+        "turkish", "tr",
+        "ukrainian", "uk",
+        "urdu", "ur",
+        "vietnamese", "vi",
+        "yoruba", "yo",
+    ]
+
+    def __init__(
+        self,
+        granularity: str,
+        language: str,
+    ):
+        """Initializes a text configuration.
+
+        Args:
+            granularity (str): Determines the granularity into which text features are broken
+                down. Accepted values are ``"token"``, ``"sentence"``, or ``"paragraph"``.
+                Computes `shap values <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-shapley-values.html>`_
+                for these units.
+            language (str): Specifies the language of the text features. Accepted values are
+                one of the following:
+                ``"chinese"``, ``"danish"``, ``"dutch"``, ``"english"``, ``"french"``, ``"german"``,
+                ``"greek"``, ``"italian"``, ``"japanese"``, ``"lithuanian"``, ``"multi-language"``,
+                ``"norwegian bokmål"``, ``"polish"``, ``"portuguese"``, ``"romanian"``,
+                ``"russian"``, ``"spanish"``, ``"afrikaans"``, ``"albanian"``, ``"arabic"``,
+                ``"armenian"``, ``"basque"``, ``"bengali"``, ``"bulgarian"``, ``"catalan"``,
+                ``"croatian"``, ``"czech"``, ``"estonian"``, ``"finnish"``, ``"gujarati"``,
+                ``"hebrew"``, ``"hindi"``, ``"hungarian"``, ``"icelandic"``, ``"indonesian"``,
+                ``"irish"``, ``"kannada"``, ``"kyrgyz"``, ``"latvian"``, ``"ligurian"``,
+                ``"luxembourgish"``, ``"macedonian"``, ``"malayalam"``, ``"marathi"``, ``"nepali"``,
+                ``"persian"``, ``"sanskrit"``, ``"serbian"``, ``"setswana"``, ``"sinhala"``,
+                ``"slovak"``, ``"slovenian"``, ``"swedish"``, ``"tagalog"``, ``"tamil"``,
+                ``"tatar"``, ``"telugu"``, ``"thai"``, ``"turkish"``, ``"ukrainian"``, ``"urdu"``,
+                ``"vietnamese"``, ``"yoruba"``. Use "multi-language" for a mix of multiple
+                languages. The corresponding two-letter ISO codes are also accepted.
+
+        Raises:
+            ValueError: when ``granularity`` is not in list of supported values
+                or ``language`` is not in list of supported values
+        """ # noqa E501 # pylint: disable=c0301
+        if granularity not in TextConfig._SUPPORTED_GRANULARITIES:
+            raise ValueError(
+                f"Invalid granularity {granularity}. Please choose among "
+                f"{TextConfig._SUPPORTED_GRANULARITIES}"
+            )
+        if language not in TextConfig._SUPPORTED_LANGUAGES:
+            raise ValueError(
+                f"Invalid language {language}. Please choose among "
+                f"{TextConfig._SUPPORTED_LANGUAGES}"
+            )
+        self.text_config = {
+            "granularity": granularity,
+            "language": language,
+        }
+
+    def get_text_config(self):
+        """Returns a text config dictionary, part of the analysis config dictionary."""
+        return copy.deepcopy(self.text_config)
+
+
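For example, a sentence-level English configuration (to be passed to SHAPConfig below):

    text_config = TextConfig(granularity="sentence", language="english")
    text_config.get_text_config()  # {"granularity": "sentence", "language": "english"}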
+class ImageConfig:
+    """Config object for handling images"""
+
+    def __init__(
+        self,
+        model_type: str,
+        num_segments: Optional[int] = None,
+        feature_extraction_method: Optional[str] = None,
+        segment_compactness: Optional[float] = None,
+        max_objects: Optional[int] = None,
+        iou_threshold: Optional[float] = None,
+        context: Optional[float] = None,
+    ):
+        """Initializes a config object for Computer Vision (CV) Image explainability.
+
+        `SHAP for CV explainability <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-model-explainability-computer-vision.html>`__
+        generates heat maps that visualize feature attributions for input images.
+        These heat maps highlight the image's features according
+        to how much they contribute to the CV model prediction.
+
+        ``"IMAGE_CLASSIFICATION"`` and ``"OBJECT_DETECTION"`` are the two supported CV use cases.
+
+        Args:
+            model_type (str): Specifies the type of CV model and use case. Accepted options:
+                ``"IMAGE_CLASSIFICATION"`` or ``"OBJECT_DETECTION"``.
+            num_segments (None or int): Approximate number of segments to generate when running
+                SKLearn's `SLIC method <https://scikit-image.org/docs/dev/api/skimage.segmentation.html?highlight=slic#skimage.segmentation.slic>`_
+                for image segmentation to generate features/superpixels.
+                The default is None. When set to None, runs SLIC with 20 segments.
+            feature_extraction_method (None or str): Method used for extracting features from the
+                image (e.g. ``"segmentation"``). Default is ``"segmentation"``.
+            segment_compactness (None or float): Balances color proximity and space proximity.
+                Higher values give more weight to space proximity, making superpixel
+                shapes more square/cubic. We recommend exploring possible values on a log
+                scale, e.g., 0.01, 0.1, 1, 10, 100, before refining around a chosen value.
+                The default is None. When set to None, runs with the default value of ``5``.
+            max_objects (None or int): Maximum number of objects displayed when running SHAP
+                with an ``"OBJECT_DETECTION"`` model. The object detection algorithm may detect
+                more than the ``max_objects`` number of objects in a single image.
+                In that case, the algorithm displays the top ``max_objects`` number of objects
+                according to confidence score. Default value is None. In the ``"OBJECT_DETECTION"``
+                case, passing in None leads to a default value of ``3``.
+            iou_threshold (None or float): Minimum intersection over union for the object
+                bounding box to consider its confidence score for computing SHAP values,
+                in the range ``[0.0, 1.0]``. Used only for the ``"OBJECT_DETECTION"`` case,
+                where passing in None sets the default value of ``0.5``.
+            context (None or float): The portion of the image outside the bounding box used
+                in SHAP analysis, in the range ``[0.0, 1.0]``. If set to ``1.0``, the whole image
+                is considered; if set to ``0.0`` only the image inside the bounding box is
+                considered. Only used for the ``"OBJECT_DETECTION"`` case,
+                where passing in None sets the default value of ``1.0``.
+
+        Raises:
+            ValueError: when ``model_type`` is not one of the accepted options
+        """ # noqa E501 # pylint: disable=c0301
+        self.image_config = {}
+
+        if model_type not in ["OBJECT_DETECTION", "IMAGE_CLASSIFICATION"]:
+            raise ValueError(
+                "Clarify SHAP only supports object detection and image classification methods. "
+                "Please set model_type to OBJECT_DETECTION or IMAGE_CLASSIFICATION."
+            )
+        self.image_config["model_type"] = model_type
+        _set(num_segments, "num_segments", self.image_config)
+        _set(feature_extraction_method, "feature_extraction_method", self.image_config)
+        _set(segment_compactness, "segment_compactness", self.image_config)
+        _set(max_objects, "max_objects", self.image_config)
+        _set(iou_threshold, "iou_threshold", self.image_config)
+        _set(context, "context", self.image_config)
+
+    def get_image_config(self):
+        """Returns the image config part of an analysis config dictionary."""
+        return copy.deepcopy(self.image_config)
+
+
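An object-detection sketch that makes the documented defaults explicit:

    image_config = ImageConfig(
        model_type="OBJECT_DETECTION",
        num_segments=20,    # matches the SLIC default used when None
        max_objects=3,      # default when None for OBJECT_DETECTION
        iou_threshold=0.5,  # default when None
        context=1.0,        # default when None: consider the whole image
    )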
+class SHAPConfig(ExplainabilityConfig):
+    """Config class for `SHAP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-model-explainability.html>`__.
+
+    The SHAP algorithm calculates feature attributions by computing
+    the contribution of each feature to the prediction outcome, using the concept of
+    `Shapley values <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-shapley-values.html>`_.
+
+    These attributions can be provided for specific predictions (locally)
+    and at a global level for the model as a whole.
+    """ # noqa E501 # pylint: disable=c0301
+
+    def __init__(
+        self,
+        baseline: Optional[Union[str, List, Dict]] = None,
+        num_samples: Optional[int] = None,
+        agg_method: Optional[str] = None,
+        use_logit: bool = False,
+        save_local_shap_values: bool = True,
+        seed: Optional[int] = None,
+        num_clusters: Optional[int] = None,
+        text_config: Optional[TextConfig] = None,
+        image_config: Optional[ImageConfig] = None,
+        features_to_explain: Optional[List[Union[str, int]]] = None,
+    ):
+        """Initializes config for SHAP analysis.
+
+        Args:
+            baseline (None or str or list or dict): `Baseline dataset <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-feature-attribute-shap-baselines.html>`_
+                for the Kernel SHAP algorithm, accepted in the form of:
+                S3 object URI, a list of rows (with at least one element),
+                or None (for no input baseline). The baseline dataset must have the same format
+                as the input dataset specified in :class:`~sagemaker.clarify.DataConfig`.
+                Each row must have only the feature columns/values and omit the label column/values.
+                If None, a baseline will be calculated automatically on the input dataset
+                using K-means (for numerical data) or K-prototypes (if there is categorical data).
+            num_samples (None or int): Number of samples to be used in the Kernel SHAP algorithm.
+                This number determines the size of the generated synthetic dataset to compute the
+                SHAP values. If not provided, the Clarify job will choose a proper value according
+                to the count of features.
+            agg_method (None or str): Aggregation method for global SHAP values. Valid values are
+                ``"mean_abs"`` (mean of absolute SHAP values for all instances),
+                ``"median"`` (median of SHAP values for all instances) and
+                ``"mean_sq"`` (mean of squared SHAP values for all instances).
+                If None is provided, the Clarify job uses the method ``"mean_abs"``.
+            use_logit (bool): Indicates whether to apply the logit function to model predictions.
+                Default is False. If ``use_logit`` is True, the SHAP values will
+                have log-odds units.
+            save_local_shap_values (bool): Indicates whether to save the local SHAP values
+                in the output location. Default is True.
+            seed (int): Seed value to get deterministic SHAP values. Default is None.
+            num_clusters (None or int): If a ``baseline`` is not provided, Clarify automatically
+                computes a baseline dataset via a clustering algorithm (K-means/K-prototypes), which
+                takes ``num_clusters`` as a parameter. ``num_clusters`` will be the resulting size
+                of the baseline dataset. If not provided, the Clarify job uses a default value.
+            text_config (:class:`~sagemaker.clarify.TextConfig`): Config object for handling
+                text features. Default is None.
+            image_config (:class:`~sagemaker.clarify.ImageConfig`): Config for handling image
+                features. Default is None.
+            features_to_explain: A list of names or indices of dataset features to compute SHAP
+                values for. If not provided, SHAP values are computed for all features by default.
+                Currently only supported for tabular datasets.
+
+        Raises:
+            ValueError: when ``agg_method`` is invalid, ``baseline`` and ``num_clusters`` are
+                provided together, or ``features_to_explain`` is specified when ``text_config`` or
+                ``image_config`` is provided
+        """ # noqa E501 # pylint: disable=c0301
+        if agg_method is not None and agg_method not in [
+            "mean_abs",
+            "median",
+            "mean_sq",
+        ]:
+            raise ValueError(
+                f"Invalid agg_method {agg_method}. Please choose mean_abs, median, or mean_sq."
+            )
+        if num_clusters is not None and baseline is not None:
+            raise ValueError(
+                "Baseline and num_clusters cannot be provided together. "
+                "Please specify one of the two."
+            )
+        self.shap_config = {
+            "use_logit": use_logit,
+            "save_local_shap_values": save_local_shap_values,
+        }
+        _set(baseline, "baseline", self.shap_config)
+        _set(num_samples, "num_samples", self.shap_config)
+        _set(agg_method, "agg_method", self.shap_config)
+        _set(seed, "seed", self.shap_config)
+        _set(num_clusters, "num_clusters", self.shap_config)
+        if text_config:
+            _set(text_config.get_text_config(), "text_config", self.shap_config)
+            if not save_local_shap_values:
+                logger.warning(
+                    "Global aggregation is not yet supported for text features. "
+                    "Consider setting save_local_shap_values=True to inspect local text "
+                    "explanations."
+                )
+        if image_config:
+            _set(image_config.get_image_config(), "image_config", self.shap_config)
+        if features_to_explain is not None and (
+            text_config is not None or image_config is not None
+        ):
+            raise ValueError(
+                "`features_to_explain` is not supported for datasets containing text features "
+                "or images."
+            )
+        _set(features_to_explain, "features_to_explain", self.shap_config)
+
+    def get_explainability_config(self):
+        """Returns a shap config dictionary."""
+        return copy.deepcopy({"shap": self.shap_config})
+
+
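A typical tabular setup, a sketch with a hypothetical one-row baseline (columns must match the DataConfig dataset, with the label column omitted):

    shap_config = SHAPConfig(
        baseline=[[35, 50000]],  # one background record: [age, income]
        num_samples=100,
        agg_method="mean_abs",
        seed=123,
    )
    shap_config.get_explainability_config()["shap"]["agg_method"]  # "mean_abs"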
+class AsymmetricShapleyValueConfig(ExplainabilityConfig):
+    """Config class for the Asymmetric Shapley Value algorithm for time series explainability.
+
+    Asymmetric Shapley Values are a variant of the Shapley Value that drops the symmetry axiom [1].
+    We use these to determine how features contribute to the forecasting outcome. Asymmetric
+    Shapley Values can take into account the temporal dependencies of the time series that
+    forecasting models take as input.
+
+    [1] Frye, Christopher, Colin Rowat, and Ilya Feige. "Asymmetric shapley values: incorporating
+    causal knowledge into model-agnostic explainability." NeurIPS (2020).
+    https://doi.org/10.48550/arXiv.1910.06358
+    """
+
+    def __init__(
+        self,
+        direction: Literal[
+            "chronological",
+            "anti_chronological",
+            "bidirectional",
+        ] = ASYM_SHAP_VAL_DEFAULT_EXPLANATION_DIRECTION,
+        granularity: Literal[
+            "timewise",
+            "fine_grained",
+        ] = ASYM_SHAP_VAL_DEFAULT_EXPLANATION_GRANULARITY,
+        num_samples: Optional[int] = None,
+        baseline: Optional[Union[str, Dict[str, Any]]] = None,
+    ):
+        """Initializes config for time series explainability with Asymmetric Shapley Values.
+
+        AsymmetricShapleyValueConfig is used specifically and only for TimeSeries explainability
+        purposes.
+
+        Args:
+            direction (str): Type of explanation to be used. Available explanation
+                types are ``"chronological"``, ``"anti_chronological"``, and ``"bidirectional"``.
+            granularity (str): Explanation granularity to be used. Available granularity options
+                are ``"timewise"`` and ``"fine_grained"``.
+            num_samples (None or int): Number of samples to be used in the Asymmetric Shapley
+                Value forecasting algorithm. Only applicable when using ``"fine_grained"``
+                explanations.
+            baseline (str or dict): Link to a baseline configuration or a dictionary for it. The
+                baseline config is used to replace out-of-coalition values for the corresponding
+                datasets (also known as background data). For temporal data (target time series,
+                related time series), the baseline value types are "zero", where all
+                out-of-coalition values will be replaced with 0.0, or "mean", where all
+                out-of-coalition values will be replaced with the average of the time series.
+                For static data (static covariates), a baseline value for each covariate should be
+                provided for each possible item_id. An example config follows, where ``item1`` and
+                ``item2`` are item ids::
+
+                    {
+                        "target_time_series": "zero",
+                        "related_time_series": "zero",
+                        "static_covariates": {
+                            "item1": [1, 1],
+                            "item2": [0, 1]
+                        }
+                    }
+
+        Raises:
+            ValueError: when ``direction`` or ``granularity`` are not valid, ``num_samples`` is not
+                provided for fine-grained explanations, ``num_samples`` is provided for non
+                fine-grained explanations, or when ``direction`` is not ``"chronological"`` while
+                ``granularity`` is ``"fine_grained"``.
+        """
+        self.asymmetric_shapley_value_config = dict()
+        # validate explanation direction
+        if direction not in ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS:
+            raise ValueError(
+                "Please provide a valid explanation direction from: "
+                + ", ".join(ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS)
+            )
+        # validate granularity
+        if granularity not in ASYM_SHAP_VAL_GRANULARITIES:
+            raise ValueError(
+                "Please provide a valid granularity from: " + ", ".join(ASYM_SHAP_VAL_GRANULARITIES)
+            )
+        if granularity == "fine_grained":
+            if not isinstance(num_samples, int):
+                raise ValueError("Please provide an integer for ``num_samples``.")
+            if direction != "chronological":
+                raise ValueError(
+                    f"{direction} direction and {granularity} granularity are not supported "
+                    f"together."
+                )
+        elif num_samples:  # validate num_samples is not provided when unnecessary
+            raise ValueError("``num_samples`` is only used for fine-grained explanations.")
+        # validate baseline if provided as a dictionary
+        if isinstance(baseline, dict):
+            temporal_baselines = ["zero", "mean"]  # possible baseline options for temporal fields
+            if "target_time_series" in baseline:
+                target_baseline = baseline.get("target_time_series")
+                if target_baseline not in temporal_baselines:
+                    raise ValueError(
+                        f"Provided value {target_baseline} for ``target_time_series`` is "
+                        f"invalid. Please select one of {temporal_baselines}."
+                    )
+            if "related_time_series" in baseline:
+                related_baseline = baseline.get("related_time_series")
+                if related_baseline not in temporal_baselines:
+                    raise ValueError(
+                        f"Provided value {related_baseline} for ``related_time_series`` is "
+                        f"invalid. Please select one of {temporal_baselines}."
+                    )
+        # set explanation type and (if provided) num_samples in internal config dictionary
+        _set(direction, "direction", self.asymmetric_shapley_value_config)
+        _set(granularity, "granularity", self.asymmetric_shapley_value_config)
+        _set(
+            num_samples, "num_samples", self.asymmetric_shapley_value_config
+        )  # _set() does nothing if a given argument is None
+        _set(baseline, "baseline", self.asymmetric_shapley_value_config)
+
+    def get_explainability_config(self):
+        """Returns an asymmetric shap config dictionary."""
+        return copy.deepcopy({"asymmetric_shapley_value": self.asymmetric_shapley_value_config})
+
+
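A fine-grained chronological configuration, a sketch reusing the baseline shape documented above (item ids are hypothetical):

    asym_config = AsymmetricShapleyValueConfig(
        direction="chronological",  # required when granularity is fine_grained
        granularity="fine_grained",
        num_samples=50,             # required when granularity is fine_grained
        baseline={
            "target_time_series": "zero",
            "related_time_series": "mean",
            "static_covariates": {"item1": [1, 1], "item2": [0, 1]},
        },
    )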
+class SageMakerClarifyProcessor(Processor):
+    """Handles SageMaker Processing tasks to compute bias metrics and model explanations."""
+
+    _CLARIFY_DATA_INPUT = "/opt/ml/processing/input/data"
+    _CLARIFY_CONFIG_INPUT = "/opt/ml/processing/input/config"
+    _CLARIFY_OUTPUT = "/opt/ml/processing/output"
+
+    def __init__(
+        self,
+        role: Optional[str] = None,
+        instance_count: Optional[int] = None,
+        instance_type: Optional[str] = None,
+        volume_size_in_gb: int = 30,
+        volume_kms_key: Optional[str] = None,
+        output_kms_key: Optional[str] = None,
+        max_runtime_in_seconds: Optional[int] = None,
+        sagemaker_session: Optional[Session] = None,
+        env: Optional[Dict[str, str]] = None,
+        tags: Optional[Tags] = None,
+        network_config: Optional[NetworkConfig] = None,
+        job_name_prefix: Optional[str] = None,
+        version: Optional[str] = None,
+        skip_early_validation: bool = False,
+    ):
+        """Initializes a SageMakerClarifyProcessor to compute bias metrics and model explanations.
+
+        Instance of :class:`~sagemaker.processing.Processor`.
+
+        Args:
+            role (str): An AWS IAM role name or ARN. Amazon SageMaker Processing
+                uses this role to access AWS resources, such as
+                data stored in Amazon S3.
+            instance_count (int): The number of instances to run
+                a processing job with.
+            instance_type (str): The type of
+                `EC2 instance <https://aws.amazon.com/ec2/instance-types/>`_
+                to use for the processing job; for example, ``"ml.c5.xlarge"``.
+            volume_size_in_gb (int): Size in GB of the
+                `EBS volume <https://docs.aws.amazon.com/sagemaker/latest/dg/host-instance-storage.html>`_
+                to use for storing data during processing (default: 30 GB).
+            volume_kms_key (str): A
+                `KMS key <https://docs.aws.amazon.com/sagemaker/latest/dg/key-management.html>`_
+                for the processing volume (default: None).
+            output_kms_key (str): The KMS key ID for processing job outputs (default: None).
+            max_runtime_in_seconds (int): Timeout in seconds (default: None).
+                After this amount of time, Amazon SageMaker terminates the job,
+                regardless of its current status. If ``max_runtime_in_seconds`` is not
+                specified, the default value is ``86400`` seconds (24 hours).
+            sagemaker_session (:class:`~sagemaker.session.Session`):
+                :class:`~sagemaker.session.Session` object which manages interactions
+                with Amazon SageMaker and any other AWS services needed. If not specified,
+                the Processor creates a :class:`~sagemaker.session.Session`
+                using the default AWS configuration chain.
+            env (dict[str, str]): Environment variables to be passed to
+                the processing jobs (default: None).
+            tags (Optional[Tags]): Tags to be passed to the processing job
+                (default: None). For more, see
+                https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
+            network_config (:class:`~sagemaker.network.NetworkConfig`):
+                A :class:`~sagemaker.network.NetworkConfig`
+                object that configures network isolation, encryption of
+                inter-container traffic, security group IDs, and subnets.
+            job_name_prefix (str): Processing job name prefix.
+            version (str): Clarify version to use.
+            skip_early_validation (bool): Whether to skip schema validation of the generated
+                analysis_config.json (default: False).
+        """ # noqa E501 # pylint: disable=c0301
+        container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version)
+        self._last_analysis_config = None
+        self.job_name_prefix = job_name_prefix
+        self.skip_early_validation = skip_early_validation
+        super(SageMakerClarifyProcessor, self).__init__(
+            role,
+            container_uri,
+            instance_count,
+            instance_type,
+            None,  # We manage the entrypoint.
+            volume_size_in_gb,
+            volume_kms_key,
+            output_kms_key,
+            max_runtime_in_seconds,
+            None,  # We set method-specific job names below.
+            sagemaker_session,
+            env,
+            format_tags(tags),
+            network_config,
+        )
+
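Constructing the processor, a sketch with a hypothetical role ARN and an existing Session object (the constructor reads ``sagemaker_session.boto_region_name`` to resolve the Clarify image, so a session should be supplied):

    processor = SageMakerClarifyProcessor(
        role="arn:aws:iam::111122223333:role/ClarifyRole",  # hypothetical ARN
        instance_count=1,
        instance_type="ml.c5.xlarge",
        sagemaker_session=session,  # an existing sagemaker Session
        job_name_prefix="clarify-demo",
    )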
+    def run(self, **_):
+        """Overriding the base class method but deferring to specific run_* methods."""
+        raise NotImplementedError(
+            "Please choose a method of run_pre_training_bias, run_post_training_bias or "
+            "run_explainability."
+        )
+
+    def _run(
+        self,
+        data_config: DataConfig,
+        analysis_config: Dict[str, Any],
+        wait: bool,
+        logs: bool,
+        job_name: str,
+        kms_key: str,
+        experiment_config: Dict[str, str],
+    ):
+        """Runs a :class:`~sagemaker.processing.ProcessingJob` with the SageMaker Clarify container
+
+        and analysis config.
+
+        Args:
+            data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
+            analysis_config (dict): Config following the analysis_config.json format.
+            wait (bool): Whether the call should wait until the job completes (default: True).
+            logs (bool): Whether to show the logs produced by the job.
+                Only meaningful when ``wait`` is True (default: True).
+            job_name (str): Processing job name.
+            kms_key (str): The ARN of the KMS key that is used to encrypt the
+                user code file (default: None).
+            experiment_config (dict[str, str]): Experiment management configuration.
+                Optionally, the dict can contain three keys:
+                ``'ExperimentName'``, ``'TrialName'``, and ``'TrialComponentDisplayName'``.
+
+                The behavior of setting these keys is as follows:
+
+                * If ``'ExperimentName'`` is supplied but ``'TrialName'`` is not, a Trial will be
+                  automatically created and the job's Trial Component associated with the Trial.
+                * If ``'TrialName'`` is supplied and the Trial already exists,
+                  the job's Trial Component will be associated with the Trial.
+                * If both ``'ExperimentName'`` and ``'TrialName'`` are not supplied,
+                  the Trial Component will be unassociated.
+                * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
+        """
+        # for debugging: to access locally, i.e. without a need to look for it in an S3 bucket
+        self._last_analysis_config = analysis_config
+        logger.info("Analysis Config: %s", analysis_config)
+        if not self.skip_early_validation:
+            ANALYSIS_CONFIG_SCHEMA_V1_0.validate(analysis_config)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")
+            with open(analysis_config_file, "w") as f:
+                json.dump(analysis_config, f)
+            s3_analysis_config_file = _upload_analysis_config(
+                analysis_config_file,
+                data_config.s3_analysis_config_output_path or data_config.s3_output_path,
+                self.sagemaker_session,
+                kms_key,
+            )
+            from sagemaker.core.shapes import ProcessingS3Input, ProcessingS3Output
+
+            config_input = ProcessingInput(
+                input_name="analysis_config",
+                s3_input=ProcessingS3Input(
+                    s3_uri=s3_analysis_config_file,
+                    local_path=self._CLARIFY_CONFIG_INPUT,
+                    s3_data_type="S3Prefix",
+                    s3_input_mode="File",
+                    s3_compression_type="None",
+                ),
+            )
+            data_input = ProcessingInput(
+                input_name="dataset",
+                s3_input=ProcessingS3Input(
+                    s3_uri=data_config.s3_data_input_path,
+                    local_path=self._CLARIFY_DATA_INPUT,
+                    s3_data_type="S3Prefix",
+                    s3_input_mode="File",
+                    s3_data_distribution_type=data_config.s3_data_distribution_type,
+                    s3_compression_type=data_config.s3_compression_type,
+                ),
+            )
+            result_output = ProcessingOutput(
+                output_name="analysis_result",
+                s3_output=ProcessingS3Output(
+                    s3_uri=data_config.s3_output_path,
+                    local_path=self._CLARIFY_OUTPUT,
+                    s3_upload_mode=ProcessingOutputHandler.get_s3_upload_mode(analysis_config),
+                ),
+            )
+
+            return super().run(
+                inputs=[data_input, config_input],
+                outputs=[result_output],
+                wait=wait,
+                logs=logs,
+                job_name=job_name,
+                kms_key=kms_key,
+                experiment_config=experiment_config,
+            )
+
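For reference, the dictionary validated and uploaded by ``_run`` as analysis_config.json is the merge of the config objects above. A pre-training bias run, for example, reduces to roughly the following shape (a sketch; the exact fields are produced by ``_AnalysisConfigGenerator``, which is defined elsewhere in this file):

    {
        "dataset_type": "text/csv",
        "label": "target",
        "headers": ["target", "age", "income"],
        "label_values_or_threshold": [1],
        "facet": [{"name_or_index": "age", "value_or_threshold": [40]}],
        "methods": {"pre_training_bias": {"methods": "all"}},
    }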
+    def run_pre_training_bias(
+        self,
+        data_config: DataConfig,
+        data_bias_config: BiasConfig,
+        methods: Union[str, List[str]] = "all",
+        wait: bool = True,
+        logs: bool = True,
+        job_name: Optional[str] = None,
+        kms_key: Optional[str] = None,
+        experiment_config: Optional[Dict[str, str]] = None,
+    ):
+        """Runs a :class:`~sagemaker.processing.ProcessingJob` to compute pre-training bias methods
+
+        Computes the requested ``methods`` on the input data. The ``methods`` compare
+        metrics (e.g. fraction of examples) for the sensitive group(s) vs. the other examples.
+
+        Args:
+            data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
+            data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
+            methods (str or list[str]): Selects a subset of potential metrics:
+                ["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance.html>`_",
+                "`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-true-label-imbalance.html>`_",
+                "`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kl-divergence.html>`_",
+                "`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-jensen-shannon-divergence.html>`_",
+                "`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-lp-norm.html>`_",
+                "`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-total-variation-distance.html>`_",
+                "`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kolmogorov-smirnov.html>`_",
+                "`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-cddl.html>`_"].
+                Defaults to ``"all"`` to run all metrics if left unspecified.
+            wait (bool): Whether the call should wait until the job completes (default: True).
+            logs (bool): Whether to show the logs produced by the job.
+                Only meaningful when ``wait`` is True (default: True).
+            job_name (str): Processing job name. When ``job_name`` is not specified,
+                if ``job_name_prefix`` in :class:`~sagemaker.clarify.SageMakerClarifyProcessor` is
+                specified, the job name will be the ``job_name_prefix`` and current timestamp;
+                otherwise use ``"Clarify-Pretraining-Bias"`` as prefix.
+            kms_key (str): The ARN of the KMS key that is used to encrypt the
+                user code file (default: None).
+            experiment_config (dict[str, str]): Experiment management configuration.
+                Optionally, the dict can contain three keys:
+                ``'ExperimentName'``, ``'TrialName'``, and ``'TrialComponentDisplayName'``.
+
+                The behavior of setting these keys is as follows:
+
+                * If ``'ExperimentName'`` is supplied but ``'TrialName'`` is not, a Trial will be
+                  automatically created and the job's Trial Component associated with the Trial.
+                * If ``'TrialName'`` is supplied and the Trial already exists,
+                  the job's Trial Component will be associated with the Trial.
+                * If both ``'ExperimentName'`` and ``'TrialName'`` are not supplied,
+                  the Trial Component will be unassociated.
+                * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
+        """ # noqa E501 # pylint: disable=c0301
+        analysis_config = _AnalysisConfigGenerator.bias_pre_training(
+            data_config, data_bias_config, methods
+        )
+        # when name is either not provided (is None) or an empty string ("")
+        job_name = job_name or name_from_base(self.job_name_prefix or "Clarify-Pretraining-Bias")
+        return self._run(
+            data_config,
+            analysis_config,
+            wait,
+            logs,
+            job_name,
+            kms_key,
+            experiment_config,
+        )
+
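Putting the earlier pieces together for a pre-training bias job (a sketch; ``data_config`` and ``bias_config`` are the objects built in the examples above):

    processor.run_pre_training_bias(
        data_config=data_config,
        data_bias_config=bias_config,
        methods=["CI", "DPL"],
        wait=True,
    )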
+    def run_post_training_bias(
+        self,
+        data_config: DataConfig,
+        data_bias_config: BiasConfig,
+        model_config: Optional[ModelConfig] = None,
+        model_predicted_label_config: Optional[ModelPredictedLabelConfig] = None,
+        methods: Union[str, List[str]] = "all",
+        wait: bool = True,
+        logs: bool = True,
+        job_name: Optional[str] = None,
+        kms_key: Optional[str] = None,
+        experiment_config: Optional[Dict[str, str]] = None,
+    ):
+        """Runs a :class:`~sagemaker.processing.ProcessingJob` to compute posttraining bias
+
+        Spins up a model endpoint and runs inference over the input dataset in
+        the ``s3_data_input_path`` (from the :class:`~sagemaker.clarify.DataConfig`) to obtain
+        predicted labels. Using model predictions, computes the requested posttraining bias
+        ``methods`` that compare metrics (e.g. accuracy, precision, recall) for the
+        sensitive group(s) versus the other examples.
+
+        Args:
+            data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
+            data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
+            model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
+                endpoint to be created. This is required unless ``predicted_label_dataset_uri`` or
+                ``predicted_label`` is provided in ``data_config``.
+            model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
+                Config of how to extract the predicted label from the model output.
+            methods (str or list[str]): Selects a subset of potential metrics:
+                ["`DPPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dppl.html>`_",
+                "`DI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-di.html>`_",
+                "`DCA <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dca.html>`_",
+                "`DCR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dcr.html>`_",
+                "`RD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-rd.html>`_",
+                "`DAR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dar.html>`_",
+                "`DRR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-drr.html>`_",
+                "`AD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ad.html>`_",
+                "`CDDPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cddpl.html>`_",
+                "`TE <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-te.html>`_",
+                "`FT <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ft.html>`_"].
+                Defaults to ``"all"`` to run all metrics if left unspecified.
+            wait (bool): Whether the call should wait until the job completes (default: True).
+            logs (bool): Whether to show the logs produced by the job.
+                Only meaningful when ``wait`` is True (default: True).
+            job_name (str): Processing job name. When ``job_name`` is not specified,
+                if ``job_name_prefix`` in :class:`~sagemaker.clarify.SageMakerClarifyProcessor`
+                is specified, the job name will be the ``job_name_prefix`` and current timestamp;
+                otherwise use ``"Clarify-Posttraining-Bias"`` as prefix.
+            kms_key (str): The ARN of the KMS key that is used to encrypt the
+                user code file (default: None).
+            experiment_config (dict[str, str]): Experiment management configuration.
+                Optionally, the dict can contain three keys:
+                ``'ExperimentName'``, ``'TrialName'``, and ``'TrialComponentDisplayName'``.
+
+                The behavior of setting these keys is as follows:
+
+                * If ``'ExperimentName'`` is supplied but ``'TrialName'`` is not, a Trial will be
+                  automatically created and the job's Trial Component associated with the Trial.
+                * If ``'TrialName'`` is supplied and the Trial already exists,
+                  the job's Trial Component will be associated with the Trial.
+                * If both ``'ExperimentName'`` and ``'TrialName'`` are not supplied,
+                  the Trial Component will be unassociated.
+                * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
+        """ # noqa E501 # pylint: disable=c0301
+        analysis_config = _AnalysisConfigGenerator.bias_post_training(
+            data_config,
+            data_bias_config,
+            model_predicted_label_config,
+            methods,
+            model_config,
+        )
+        # when name is either not provided (is None) or an empty string ("")
+        job_name = job_name or name_from_base(self.job_name_prefix or "Clarify-Posttraining-Bias")
+        return self._run(
+            data_config,
+            analysis_config,
+            wait,
+            logs,
+            job_name,
+            kms_key,
+            experiment_config,
+        )
+
+
2199
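For orientation, a minimal usage sketch of `run_post_training_bias`. The role ARN, bucket, column names, and model name are illustrative placeholders, not values shipped with this package; the import path assumes the exports listed in this module's `__all__`:

    from sagemaker.core.clarify import (
        BiasConfig,
        DataConfig,
        ModelConfig,
        ModelPredictedLabelConfig,
        SageMakerClarifyProcessor,
    )

    # Hypothetical processor and configs; adjust paths, headers, and the
    # execution role to your own account.
    processor = SageMakerClarifyProcessor(
        role="arn:aws:iam::111122223333:role/ClarifyRole",
        instance_count=1,
        instance_type="ml.m5.xlarge",
    )
    data_config = DataConfig(
        s3_data_input_path="s3://my-bucket/clarify/input/",
        s3_output_path="s3://my-bucket/clarify/output/",
        label="target",
        headers=["target", "age", "income"],
        dataset_type="text/csv",
    )
    bias_config = BiasConfig(label_values_or_threshold=[1], facet_name="age")
    model_config = ModelConfig(
        model_name="my-model", instance_count=1, instance_type="ml.m5.xlarge"
    )
    processor.run_post_training_bias(
        data_config=data_config,
        data_bias_config=bias_config,
        model_config=model_config,
        model_predicted_label_config=ModelPredictedLabelConfig(probability_threshold=0.8),
        methods=["DPPL", "DI"],
    )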
+    def run_bias(
+        self,
+        data_config: DataConfig,
+        bias_config: BiasConfig,
+        model_config: Optional[ModelConfig] = None,
+        model_predicted_label_config: Optional[ModelPredictedLabelConfig] = None,
+        pre_training_methods: Union[str, List[str]] = "all",
+        post_training_methods: Union[str, List[str]] = "all",
+        wait: bool = True,
+        logs: bool = True,
+        job_name: Optional[str] = None,
+        kms_key: Optional[str] = None,
+        experiment_config: Optional[Dict[str, str]] = None,
+    ):
+        """Runs a :class:`~sagemaker.processing.ProcessingJob` to compute the requested bias methods.
+
+        Computes metrics for both the pre-training and the post-training methods.
+        To calculate the post-training methods, it spins up a model endpoint and runs inference
+        over the input examples in the ``s3_data_input_path`` (from the
+        :class:`~sagemaker.clarify.DataConfig`) to obtain predicted labels.
+
+        Args:
+            data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
+            bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
+            model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
+                endpoint to be created. This is required unless ``predicted_label_dataset_uri`` or
+                ``predicted_label`` is provided in ``data_config``.
+            model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
+                Config of how to extract the predicted label from the model output.
+            pre_training_methods (str or list[str]): Selector of a subset of potential metrics:
+                ["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance.html>`_",
+                "`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-true-label-imbalance.html>`_",
+                "`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kl-divergence.html>`_",
+                "`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-jensen-shannon-divergence.html>`_",
+                "`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-lp-norm.html>`_",
+                "`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-total-variation-distance.html>`_",
+                "`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kolmogorov-smirnov.html>`_",
+                "`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-cddl.html>`_"].
+                Defaults to str "all" to run all metrics if left unspecified.
+            post_training_methods (str or list[str]): Selector of a subset of potential metrics:
+                ["`DPPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dppl.html>`_",
+                "`DI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-di.html>`_",
+                "`DCA <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dca.html>`_",
+                "`DCR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dcr.html>`_",
+                "`RD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-rd.html>`_",
+                "`DAR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dar.html>`_",
+                "`DRR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-drr.html>`_",
+                "`AD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ad.html>`_",
+                "`CDDPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cddpl.html>`_",
+                "`TE <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-te.html>`_",
+                "`FT <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ft.html>`_"].
+                Defaults to str "all" to run all metrics if left unspecified.
+            wait (bool): Whether the call should wait until the job completes (default: True).
+            logs (bool): Whether to show the logs produced by the job.
+                Only meaningful when ``wait`` is True (default: True).
+            job_name (str): Processing job name. When ``job_name`` is not specified,
+                if ``job_name_prefix`` in :class:`~sagemaker.clarify.SageMakerClarifyProcessor` is
+                specified, the job name will be ``job_name_prefix`` and the current timestamp;
+                otherwise use ``"Clarify-Bias"`` as prefix.
+            kms_key (str): The ARN of the KMS key that is used to encrypt the
+                user code file (default: None).
+            experiment_config (dict[str, str]): Experiment management configuration.
+                Optionally, the dict can contain three keys:
+                ``'ExperimentName'``, ``'TrialName'``, and ``'TrialComponentDisplayName'``.
+
+                The behavior of setting these keys is as follows:
+
+                * If ``'ExperimentName'`` is supplied but ``'TrialName'`` is not, a Trial will be
+                  automatically created and the job's Trial Component associated with the Trial.
+                * If ``'TrialName'`` is supplied and the Trial already exists,
+                  the job's Trial Component will be associated with the Trial.
+                * If both ``'ExperimentName'`` and ``'TrialName'`` are not supplied,
+                  the Trial Component will be unassociated.
+                * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
+        """  # noqa E501  # pylint: disable=c0301
+        analysis_config = _AnalysisConfigGenerator.bias(
+            data_config,
+            bias_config,
+            model_config,
+            model_predicted_label_config,
+            pre_training_methods,
+            post_training_methods,
+        )
+        # when name is either not provided (is None) or an empty string ("")
+        job_name = job_name or name_from_base(self.job_name_prefix or "Clarify-Bias")
+        return self._run(
+            data_config,
+            analysis_config,
+            wait,
+            logs,
+            job_name,
+            kms_key,
+            experiment_config,
+        )
+
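Building on the sketch above, a hedged example of `run_bias` computing both metric families in one job; the method subsets and job name are arbitrary choices for illustration:

    # Reuses processor, data_config, bias_config, and model_config from the
    # previous sketch; job_name is optional and otherwise auto-generated.
    processor.run_bias(
        data_config=data_config,
        bias_config=bias_config,
        model_config=model_config,
        pre_training_methods=["CI", "DPL"],
        post_training_methods=["DPPL", "DI"],
        job_name="my-clarify-bias-job",
    )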
+    def run_explainability(
+        self,
+        data_config: DataConfig,
+        model_config: ModelConfig,
+        explainability_config: Union[ExplainabilityConfig, List],
+        model_scores: Optional[Union[int, str, ModelPredictedLabelConfig]] = None,
+        wait: bool = True,
+        logs: bool = True,
+        job_name: Optional[str] = None,
+        kms_key: Optional[str] = None,
+        experiment_config: Optional[Dict[str, str]] = None,
+    ):
+        """Runs a :class:`~sagemaker.processing.ProcessingJob` computing feature attributions.
+
+        Spins up a model endpoint.
+
+        Currently, only SHAP and Partial Dependence Plots (PDP) are supported
+        as explainability methods.
+        You can request both methods, or one at a time, with the ``explainability_config``
+        parameter.
+
+        When SHAP is requested in the ``explainability_config``,
+        the SHAP algorithm calculates the feature importance for each input example
+        in the ``s3_data_input_path`` of the :class:`~sagemaker.clarify.DataConfig`,
+        by creating ``num_samples`` copies of the example with a subset of features
+        replaced with values from the ``baseline``.
+        It then runs model inference to see how the model's prediction changes with the replaced
+        features. If the model output returns multiple scores, importance is computed for each
+        score. Across examples, feature importance is aggregated using ``agg_method``.
+
+        When PDP is requested in the ``explainability_config``,
+        the PDP algorithm calculates the dependence of the target response
+        on the input features and marginalizes over the values of all other input features.
+        The Partial Dependence Plots are included in the output
+        `report <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-feature-attribute-baselines-reports.html>`__
+        and the corresponding values are included in the analysis output.
+
+        Args:
+            data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
+            model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
+                endpoint to be created.
+            explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list):
+                Config of the specific explainability method, or a list of
+                :class:`~sagemaker.clarify.ExplainabilityConfig` objects.
+                Currently, SHAP and PDP are the two methods supported.
+                You can request multiple methods at once by passing in a list of
+                :class:`~sagemaker.clarify.ExplainabilityConfig` objects.
+            model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
+                Index or JMESPath expression to locate the predicted scores in the model output.
+                This is not required if the model output is a single score. Alternatively,
+                it can be an instance of :class:`~sagemaker.clarify.ModelPredictedLabelConfig`
+                to provide more parameters like ``label_headers``.
+            wait (bool): Whether the call should wait until the job completes (default: True).
+            logs (bool): Whether to show the logs produced by the job.
+                Only meaningful when ``wait`` is True (default: True).
+            job_name (str): Processing job name. When ``job_name`` is not specified,
+                if ``job_name_prefix`` in :class:`~sagemaker.clarify.SageMakerClarifyProcessor`
+                is specified, the job name will be composed of ``job_name_prefix`` and the current
+                timestamp; otherwise use ``"Clarify-Explainability"`` as prefix.
+            kms_key (str): The ARN of the KMS key that is used to encrypt the
+                user code file (default: None).
+            experiment_config (dict[str, str]): Experiment management configuration.
+                Optionally, the dict can contain three keys:
+                ``'ExperimentName'``, ``'TrialName'``, and ``'TrialComponentDisplayName'``.
+
+                The behavior of setting these keys is as follows:
+
+                * If ``'ExperimentName'`` is supplied but ``'TrialName'`` is not, a Trial will be
+                  automatically created and the job's Trial Component associated with the Trial.
+                * If ``'TrialName'`` is supplied and the Trial already exists,
+                  the job's Trial Component will be associated with the Trial.
+                * If both ``'ExperimentName'`` and ``'TrialName'`` are not supplied,
+                  the Trial Component will be unassociated.
+                * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
+        """  # noqa E501  # pylint: disable=c0301
+        analysis_config = _AnalysisConfigGenerator.explainability(
+            data_config, model_config, model_scores, explainability_config
+        )
+        # when name is either not provided (is None) or an empty string ("")
+        job_name = job_name or name_from_base(self.job_name_prefix or "Clarify-Explainability")
+        return self._run(
+            data_config,
+            analysis_config,
+            wait,
+            logs,
+            job_name,
+            kms_key,
+            experiment_config,
+        )
+
+    def run_bias_and_explainability(
+        self,
+        data_config: DataConfig,
+        model_config: ModelConfig,
+        explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]],
+        bias_config: BiasConfig,
+        pre_training_methods: Union[str, List[str]] = "all",
+        post_training_methods: Union[str, List[str]] = "all",
+        model_predicted_label_config: Optional[ModelPredictedLabelConfig] = None,
+        wait: bool = True,
+        logs: bool = True,
+        job_name: Optional[str] = None,
+        kms_key: Optional[str] = None,
+        experiment_config: Optional[Dict[str, str]] = None,
+    ):
+        """Runs a :class:`~sagemaker.processing.ProcessingJob` computing bias metrics and feature attributions.
+
+        For bias:
+        Computes metrics for both the pre-training and the post-training methods.
+        To calculate the post-training methods, it spins up a model endpoint and runs inference
+        over the input examples in the ``s3_data_input_path`` (from the
+        :class:`~sagemaker.clarify.DataConfig`) to obtain predicted labels.
+
+        For explainability:
+        Spins up a model endpoint.
+
+        Currently, only SHAP and Partial Dependence Plots (PDP) are supported
+        as explainability methods.
+        You can request both methods, or one at a time, with the ``explainability_config``
+        parameter.
+
+        When SHAP is requested in the ``explainability_config``,
+        the SHAP algorithm calculates the feature importance for each input example
+        in the ``s3_data_input_path`` of the :class:`~sagemaker.clarify.DataConfig`,
+        by creating ``num_samples`` copies of the example with a subset of features
+        replaced with values from the ``baseline``.
+        It then runs model inference to see how the model's prediction changes with the replaced
+        features. If the model output returns multiple scores, importance is computed for each
+        score. Across examples, feature importance is aggregated using ``agg_method``.
+
+        When PDP is requested in the ``explainability_config``,
+        the PDP algorithm calculates the dependence of the target response
+        on the input features and marginalizes over the values of all other input features.
+        The Partial Dependence Plots are included in the output
+        `report <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-feature-attribute-baselines-reports.html>`__
+        and the corresponding values are included in the analysis output.
+
+        Args:
+            data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
+            model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
+                endpoint to be created.
+            explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list):
+                Config of the specific explainability method, or a list of
+                :class:`~sagemaker.clarify.ExplainabilityConfig` objects.
+                Currently, SHAP and PDP are the two methods supported.
+                You can request multiple methods at once by passing in a list of
+                :class:`~sagemaker.clarify.ExplainabilityConfig` objects.
+            bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
+            pre_training_methods (str or list[str]): Selector of a subset of potential metrics:
+                ["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance.html>`_",
+                "`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-true-label-imbalance.html>`_",
+                "`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kl-divergence.html>`_",
+                "`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-jensen-shannon-divergence.html>`_",
+                "`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-lp-norm.html>`_",
+                "`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-total-variation-distance.html>`_",
+                "`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kolmogorov-smirnov.html>`_",
+                "`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-cddl.html>`_"].
+                Defaults to str "all" to run all metrics if left unspecified.
+            post_training_methods (str or list[str]): Selector of a subset of potential metrics:
+                ["`DPPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dppl.html>`_",
+                "`DI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-di.html>`_",
+                "`DCA <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dca.html>`_",
+                "`DCR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dcr.html>`_",
+                "`RD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-rd.html>`_",
+                "`DAR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dar.html>`_",
+                "`DRR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-drr.html>`_",
+                "`AD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ad.html>`_",
+                "`CDDPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cddpl.html>`_",
+                "`TE <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-te.html>`_",
+                "`FT <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ft.html>`_"].
+                Defaults to str "all" to run all metrics if left unspecified.
+            model_predicted_label_config (int or str or
+                :class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
+                Index or JMESPath expression to locate the predicted scores in the model output.
+                This is not required if the model output is a single score. Alternatively,
+                it can be an instance of :class:`~sagemaker.clarify.ModelPredictedLabelConfig`
+                to provide more parameters like ``label_headers``.
+            wait (bool): Whether the call should wait until the job completes (default: True).
+            logs (bool): Whether to show the logs produced by the job.
+                Only meaningful when ``wait`` is True (default: True).
+            job_name (str): Processing job name. When ``job_name`` is not specified,
+                if ``job_name_prefix`` in :class:`~sagemaker.clarify.SageMakerClarifyProcessor`
+                is specified, the job name will be composed of ``job_name_prefix`` and the current
+                timestamp; otherwise use ``"Clarify-Bias-And-Explainability"`` as prefix.
+            kms_key (str): The ARN of the KMS key that is used to encrypt the
+                user code file (default: None).
+            experiment_config (dict[str, str]): Experiment management configuration.
+                Optionally, the dict can contain three keys:
+                ``'ExperimentName'``, ``'TrialName'``, and ``'TrialComponentDisplayName'``.
+
+                The behavior of setting these keys is as follows:
+
+                * If ``'ExperimentName'`` is supplied but ``'TrialName'`` is not, a Trial will be
+                  automatically created and the job's Trial Component associated with the Trial.
+                * If ``'TrialName'`` is supplied and the Trial already exists,
+                  the job's Trial Component will be associated with the Trial.
+                * If both ``'ExperimentName'`` and ``'TrialName'`` are not supplied,
+                  the Trial Component will be unassociated.
+                * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
+        """  # noqa E501  # pylint: disable=c0301
+        analysis_config = _AnalysisConfigGenerator.bias_and_explainability(
+            data_config,
+            model_config,
+            model_predicted_label_config,
+            explainability_config,
+            bias_config,
+            pre_training_methods,
+            post_training_methods,
+        )
+        # when name is either not provided (is None) or an empty string ("")
+        job_name = job_name or name_from_base(
+            self.job_name_prefix or "Clarify-Bias-And-Explainability"
+        )
+        return self._run(
+            data_config,
+            analysis_config,
+            wait,
+            logs,
+            job_name,
+            kms_key,
+            experiment_config,
+        )
+
+
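Combining the two sketches above into a single job; passing a list shows how SHAP and PDP can be requested together (the PDP feature and grid resolution are illustrative values):

    from sagemaker.core.clarify import PDPConfig

    # Reuses the objects from the earlier sketches. The two configs are
    # merged internally by _AnalysisConfigGenerator.
    processor.run_bias_and_explainability(
        data_config=data_config,
        model_config=model_config,
        explainability_config=[shap_config, PDPConfig(features=["age"], grid_resolution=15)],
        bias_config=bias_config,
    )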
+class _AnalysisConfigGenerator:
+    """Creates analysis_config objects for different types of runs."""
+
+    @classmethod
+    def bias_and_explainability(
+        cls,
+        data_config: DataConfig,
+        model_config: ModelConfig,
+        model_predicted_label_config: ModelPredictedLabelConfig,
+        explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]],
+        bias_config: BiasConfig,
+        pre_training_methods: Union[str, List[str]] = "all",
+        post_training_methods: Union[str, List[str]] = "all",
+    ):
+        """Generates a config for Bias and Explainability"""
+        # TimeSeries bias metrics are not supported
+        if (
+            isinstance(explainability_config, AsymmetricShapleyValueConfig)
+            or "time_series_data_config" in data_config.analysis_config
+            or (model_config and "time_series_predictor_config" in model_config.predictor_config)
+        ):
+            raise ValueError("Bias metrics are unsupported for time series.")
+        analysis_config = {**data_config.get_config(), **bias_config.get_config()}
+        analysis_config = cls._add_methods(
+            analysis_config,
+            pre_training_methods=pre_training_methods,
+            post_training_methods=post_training_methods,
+            explainability_config=explainability_config,
+        )
+        analysis_config = cls._add_predictor(
+            analysis_config, model_config, model_predicted_label_config
+        )
+        return analysis_config
+
+    @classmethod
+    def explainability(
+        cls,
+        data_config: DataConfig,
+        model_config: ModelConfig,
+        model_predicted_label_config: ModelPredictedLabelConfig,
+        explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]],
+    ):
+        """Generates a config for Explainability"""
+        # determine if this is a time series explainability case by checking
+        # if *both* TimeSeriesDataConfig and TimeSeriesModelConfig were given
+        ts_data_conf_absent = "time_series_data_config" not in data_config.analysis_config
+        ts_model_conf_absent = "time_series_predictor_config" not in model_config.predictor_config
+
+        if isinstance(explainability_config, AsymmetricShapleyValueConfig):
+            if ts_data_conf_absent:
+                raise ValueError("Please provide a TimeSeriesDataConfig to DataConfig.")
+            if ts_model_conf_absent:
+                raise ValueError("Please provide a TimeSeriesModelConfig to ModelConfig.")
+            # Check static covariates baseline matches number of provided static covariate columns
+            _AnalysisConfigGenerator._validate_time_series_static_covariates_baseline(
+                explainability_config=explainability_config,
+                data_config=data_config,
+            )
+        else:
+            if not ts_data_conf_absent:
+                raise ValueError(
+                    "Please provide an AsymmetricShapleyValueConfig for time series "
+                    "explainability cases. For non time series cases, please do not provide a "
+                    "TimeSeriesDataConfig."
+                )
+            if not ts_model_conf_absent:
+                raise ValueError(
+                    "Please provide an AsymmetricShapleyValueConfig for time series "
+                    "explainability cases. For non time series cases, please do not provide a "
+                    "TimeSeriesModelConfig."
+                )
+
+        # construct whole analysis config
+        analysis_config = data_config.analysis_config
+        analysis_config = cls._add_predictor(
+            analysis_config, model_config, model_predicted_label_config
+        )
+        analysis_config = cls._add_methods(
+            analysis_config,
+            explainability_config=explainability_config,
+        )
+        return analysis_config
+
+    @classmethod
+    def bias_pre_training(
+        cls,
+        data_config: DataConfig,
+        bias_config: BiasConfig,
+        methods: Union[str, List[str]],
+    ):
+        """Generates a config for Bias Pre Training"""
+        analysis_config = {**data_config.get_config(), **bias_config.get_config()}
+        analysis_config = cls._add_methods(analysis_config, pre_training_methods=methods)
+        return analysis_config
+
+    @classmethod
+    def bias_post_training(
+        cls,
+        data_config: DataConfig,
+        bias_config: BiasConfig,
+        model_predicted_label_config: ModelPredictedLabelConfig,
+        methods: Union[str, List[str]],
+        model_config: ModelConfig,
+    ):
+        """Generates a config for Bias Post Training"""
+        analysis_config = {**data_config.get_config(), **bias_config.get_config()}
+        analysis_config = cls._add_methods(analysis_config, post_training_methods=methods)
+        analysis_config = cls._add_predictor(
+            analysis_config, model_config, model_predicted_label_config
+        )
+        return analysis_config
+
+    @classmethod
+    def bias(
+        cls,
+        data_config: DataConfig,
+        bias_config: BiasConfig,
+        model_config: ModelConfig,
+        model_predicted_label_config: ModelPredictedLabelConfig,
+        pre_training_methods: Union[str, List[str]] = "all",
+        post_training_methods: Union[str, List[str]] = "all",
+    ):
+        """Generates a config for Bias"""
+        analysis_config = {**data_config.get_config(), **bias_config.get_config()}
+        analysis_config = cls._add_methods(
+            analysis_config,
+            pre_training_methods=pre_training_methods,
+            post_training_methods=post_training_methods,
+        )
+        analysis_config = cls._add_predictor(
+            analysis_config, model_config, model_predicted_label_config
+        )
+        return analysis_config
+
+    @classmethod
+    def _add_predictor(
+        cls,
+        analysis_config: Dict,
+        model_config: ModelConfig,
+        model_predicted_label_config: ModelPredictedLabelConfig,
+    ):
+        """Extends the analysis config with the predictor section."""
+        analysis_config = {**analysis_config}
+        if isinstance(model_config, ModelConfig):
+            analysis_config["predictor"] = model_config.get_predictor_config()
+        else:
+            if (
+                "shap" in analysis_config["methods"]
+                or "pdp" in analysis_config["methods"]
+                or "asymmetric_shapley_value" in analysis_config["methods"]
+            ):
+                raise ValueError(
+                    "model_config must be provided when explainability methods are selected."
+                )
+            if (
+                "predicted_label_dataset_uri" not in analysis_config
+                and "predicted_label" not in analysis_config
+            ):
+                raise ValueError(
+                    "model_config must be provided when neither `predicted_label_dataset_uri` "
+                    "nor `predicted_label` is provided in data_config."
+                )
+        if isinstance(model_predicted_label_config, ModelPredictedLabelConfig):
+            (
+                probability_threshold,
+                predictor_config,
+            ) = model_predicted_label_config.get_predictor_config()
+            if predictor_config and "predictor" in analysis_config:
+                analysis_config["predictor"].update(predictor_config)
+            _set(probability_threshold, "probability_threshold", analysis_config)
+        elif "predictor" in analysis_config:
+            _set(model_predicted_label_config, "label", analysis_config["predictor"])
+        return analysis_config
+
+    @classmethod
+    def _add_methods(
+        cls,
+        analysis_config: Dict,
+        pre_training_methods: Union[str, List[str]] = None,
+        post_training_methods: Union[str, List[str]] = None,
+        explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]] = None,
+        report: bool = True,
+    ):
+        """Extends the analysis config with the methods section."""
+        # validate
+        params = [pre_training_methods, post_training_methods, explainability_config]
+        if not any(params):
+            raise AttributeError(
+                "analysis_config must have at least one method: provide at least one of "
+                "`pre_training_methods`, `post_training_methods`, or `explainability_config`."
+            )
+
+        # main logic
+        analysis_config = {**analysis_config}
+        if "methods" not in analysis_config:
+            analysis_config["methods"] = {}
+
+        if report:
+            analysis_config["methods"]["report"] = {
+                "name": "report",
+                "title": "Analysis Report",
+            }
+
+        if pre_training_methods:
+            analysis_config["methods"]["pre_training_bias"] = {"methods": pre_training_methods}
+
+        if post_training_methods:
+            analysis_config["methods"]["post_training_bias"] = {"methods": post_training_methods}
+
+        if explainability_config is not None:
+            if isinstance(explainability_config, AsymmetricShapleyValueConfig):
+                explainability_methods = explainability_config.get_explainability_config()
+            else:
+                explainability_methods = cls._merge_explainability_configs(
+                    explainability_config,
+                )
+            analysis_config["methods"] = {
+                **analysis_config["methods"],
+                **explainability_methods,
+            }
+        return analysis_config
+
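To make the merge above concrete, a non-normative sketch of the ``methods`` section this helper produces when both bias selectors and a SHAP config are passed with ``report=True`` (values abbreviated; the authoritative schema is the analysis_config.json format):

    {
        "methods": {
            "report": {"name": "report", "title": "Analysis Report"},
            "pre_training_bias": {"methods": "all"},
            "post_training_bias": {"methods": "all"},
            "shap": {"baseline": [[35, 50000]], "num_samples": 100, "agg_method": "mean_abs"},
        }
    }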
+    @classmethod
+    def _merge_explainability_configs(
+        cls,
+        explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]],
+    ):
+        """Merges explainability configs, when more than one."""
+        non_ts = "Please do not provide Asymmetric Shapley Value configs for non-TimeSeries uses."
+        # validation
+        if isinstance(explainability_config, AsymmetricShapleyValueConfig):
+            raise ValueError(non_ts)
+        if (
+            isinstance(explainability_config, PDPConfig)
+            and "features" not in explainability_config.get_explainability_config()["pdp"]
+        ):
+            raise ValueError("PDP features must be provided when ShapConfig is not provided")
+        if isinstance(explainability_config, list):
+            if len(explainability_config) == 0:
+                raise ValueError("Please provide at least one explainability config.")
+            # list validation
+            for config in explainability_config:
+                # ensure all provided explainability configs are not AsymmetricShapleyValueConfig
+                if isinstance(config, AsymmetricShapleyValueConfig):
+                    raise ValueError(non_ts)
+            # main logic
+            explainability_methods = {}
+            for config in explainability_config:
+                explain_config = config.get_explainability_config()
+                explainability_methods.update(explain_config)
+            if len(explainability_methods) != len(explainability_config):
+                raise ValueError("Duplicate explainability configs are provided")
+            if (
+                "shap" not in explainability_methods
+                and "features" not in explainability_methods["pdp"]
+            ):
+                raise ValueError("PDP features must be provided when ShapConfig is not provided")
+            return explainability_methods
+        return explainability_config.get_explainability_config()
+
+    @classmethod
+    def _validate_time_series_static_covariates_baseline(
+        cls,
+        explainability_config: AsymmetricShapleyValueConfig,
+        data_config: DataConfig,
+    ):
+        """Validates static covariates in the baseline for Asymmetric Shapley Values (time series).
+
+        Checks that the baseline values set for static covariate columns are
+        consistent between every item_id and the number of static covariate columns
+        provided in DataConfig.
+        """
+        baseline = explainability_config.get_explainability_config()[
+            "asymmetric_shapley_value"
+        ].get("baseline")
+        if isinstance(baseline, dict) and "static_covariates" in baseline:
+            covariate_count = len(
+                data_config.get_config()["time_series_data_config"].get("static_covariates", [])
+            )
+            if covariate_count > 0:
+                for item_id in baseline.get("static_covariates", []):
+                    baseline_entry = baseline["static_covariates"][item_id]
+                    if not isinstance(baseline_entry, list):
+                        raise ValueError(
+                            f"Baseline entry for {item_id} must be a list, but is "
+                            f"{type(baseline_entry)}."
+                        )
+                    if len(baseline_entry) != covariate_count:
+                        raise ValueError(
+                            f"Length of the baseline entry for {item_id} does not match the "
+                            f"number of static covariate columns. Please ensure every covariate "
+                            f"has a baseline value for every item id."
+                        )
+            else:
+                raise ValueError(
+                    "Static covariate baselines are provided in AsymmetricShapleyValueConfig, "
+                    "but no static covariate columns are provided in TimeSeriesDataConfig. "
+                    "Please check these configs."
+                )
+
+
+def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key):
+    """Uploads the local ``analysis_config_file`` to the ``s3_output_path``.
+
+    Args:
+        analysis_config_file (str): File path to the local analysis config file.
+        s3_output_path (str): S3 prefix to store the analysis config file.
+        sagemaker_session (:class:`~sagemaker.session.Session`):
+            :class:`~sagemaker.session.Session` object which manages interactions with
+            Amazon SageMaker and any other AWS services needed. If not specified,
+            the processor creates a :class:`~sagemaker.session.Session`
+            using the default AWS configuration chain.
+        kms_key (str): The ARN of the KMS key that is used to encrypt the
+            user code file (default: None).
+
+    Returns:
+        The S3 URI of the uploaded file.
+    """
+    return s3.S3Uploader.upload(
+        local_path=analysis_config_file,
+        desired_s3_uri=s3_output_path,
+        sagemaker_session=sagemaker_session,
+        kms_key=kms_key,
+    )
+
+
+class ProcessingOutputHandler:
+    """Class to handle the parameters for ``sagemaker.processing.ProcessingOutput``."""
+
+    class S3UploadMode(Enum):
+        """Enum values for the different upload modes to the S3 bucket."""
+
+        CONTINUOUS = "Continuous"
+        ENDOFJOB = "EndOfJob"
+
+    @classmethod
+    def get_s3_upload_mode(cls, analysis_config: Dict[str, Any]) -> str:
+        """Fetches the S3 upload mode based on the ``dataset_type`` in the analysis config.
+
+        Args:
+            analysis_config (dict): Config following the analysis_config.json format.
+
+        Returns:
+            The s3_upload_mode type for the processing output.
+        """
+        dataset_type = analysis_config["dataset_type"]
+        return (
+            ProcessingOutputHandler.S3UploadMode.CONTINUOUS.value
+            if dataset_type == DatasetType.IMAGE.value
+            else ProcessingOutputHandler.S3UploadMode.ENDOFJOB.value
+        )
+
+
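A quick illustration of the dispatch above; the MIME strings assume the DatasetType enum values (e.g. ``application/x-image`` for image datasets):

    # Image datasets stream output continuously; everything else uploads
    # at end of job.
    assert (
        ProcessingOutputHandler.get_s3_upload_mode({"dataset_type": "application/x-image"})
        == "Continuous"
    )
    assert (
        ProcessingOutputHandler.get_s3_upload_mode({"dataset_type": "text/csv"})
        == "EndOfJob"
    )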
+def _set(value, key, dictionary):
+    """Sets ``dictionary[key] = value`` if ``value`` is not None."""
+    if value is not None:
+        dictionary[key] = value
+
+
+# Public API
+__all__ = [
+    "AsymmetricShapleyValueConfig",
+    "BiasConfig",
+    "DataConfig",
+    "DatasetType",
+    "ExplainabilityConfig",
+    "ImageConfig",
+    "ModelConfig",
+    "ModelPredictedLabelConfig",
+    "PDPConfig",
+    "ProcessingOutputHandler",
+    "SageMakerClarifyProcessor",
+    "SegmentationConfig",
+    "SHAPConfig",
+    "TextConfig",
+    "TimeSeriesDataConfig",
+    "TimeSeriesJSONDatasetFormat",
+    "TimeSeriesModelConfig",
+]