sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (363) hide show
  1. sagemaker/core/__init__.py +16 -0
  2. sagemaker/core/_studio.py +116 -0
  3. sagemaker/core/_version.py +11 -0
  4. sagemaker/core/accept_types.py +131 -0
  5. sagemaker/core/analytics.py +744 -0
  6. sagemaker/core/apiutils/__init__.py +13 -0
  7. sagemaker/core/apiutils/_base_types.py +228 -0
  8. sagemaker/core/apiutils/_boto_functions.py +130 -0
  9. sagemaker/core/apiutils/_utils.py +34 -0
  10. sagemaker/core/base_deserializers.py +35 -0
  11. sagemaker/core/base_serializers.py +35 -0
  12. sagemaker/core/clarify/__init__.py +2898 -0
  13. sagemaker/core/collection.py +467 -0
  14. sagemaker/core/common_utils.py +2281 -0
  15. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  16. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  17. sagemaker/core/config/__init__.py +181 -0
  18. sagemaker/core/config/config.py +238 -0
  19. sagemaker/core/config/config_manager.py +595 -0
  20. sagemaker/core/config/config_schema.py +1220 -0
  21. sagemaker/core/config/config_utils.py +297 -0
  22. {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
  23. sagemaker/core/constants.py +73 -0
  24. sagemaker/core/content_types.py +137 -0
  25. sagemaker/core/debugger/__init__.py +39 -0
  26. sagemaker/core/debugger/debugger.py +945 -0
  27. sagemaker/core/debugger/framework_profile.py +292 -0
  28. sagemaker/core/debugger/metrics_config.py +468 -0
  29. sagemaker/core/debugger/profiler.py +42 -0
  30. sagemaker/core/debugger/profiler_config.py +190 -0
  31. sagemaker/core/debugger/profiler_constants.py +40 -0
  32. sagemaker/core/debugger/utils.py +148 -0
  33. sagemaker/core/deprecations.py +254 -0
  34. sagemaker/core/deserializers/__init__.py +10 -0
  35. sagemaker/core/deserializers/base.py +424 -0
  36. sagemaker/core/deserializers/implementations.py +157 -0
  37. sagemaker/core/drift_check_baselines.py +106 -0
  38. sagemaker/core/enums.py +51 -0
  39. sagemaker/core/environment_variables.py +101 -0
  40. sagemaker/core/exceptions.py +108 -0
  41. sagemaker/core/experiments/__init__.py +53 -0
  42. sagemaker/core/experiments/_api_types.py +251 -0
  43. sagemaker/core/experiments/_environment.py +124 -0
  44. sagemaker/core/experiments/_helper.py +294 -0
  45. sagemaker/core/experiments/_metrics.py +333 -0
  46. sagemaker/core/experiments/_run_context.py +58 -0
  47. sagemaker/core/experiments/_utils.py +216 -0
  48. sagemaker/core/experiments/experiment.py +244 -0
  49. sagemaker/core/experiments/run.py +970 -0
  50. sagemaker/core/experiments/trial.py +296 -0
  51. sagemaker/core/experiments/trial_component.py +387 -0
  52. sagemaker/core/explainer/__init__.py +24 -0
  53. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  54. sagemaker/core/explainer/explainer_config.py +44 -0
  55. sagemaker/core/fw_utils.py +1176 -0
  56. sagemaker/core/git_utils.py +349 -0
  57. sagemaker/core/helper/pipeline_variable.py +82 -0
  58. sagemaker/core/helper/session_helper.py +2965 -0
  59. sagemaker/core/huggingface/__init__.py +29 -0
  60. sagemaker/core/huggingface/llm_utils.py +150 -0
  61. sagemaker/core/huggingface/processing.py +139 -0
  62. sagemaker/core/huggingface/training_compiler/config.py +167 -0
  63. sagemaker/core/hyperparameters.py +172 -0
  64. sagemaker/core/image_retriever/__init__.py +3 -0
  65. sagemaker/core/image_retriever/image_retriever.py +640 -0
  66. sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
  67. sagemaker/core/image_retriever/test.py +7 -0
  68. sagemaker/core/image_uri_config/__init__.py +13 -0
  69. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  70. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  71. sagemaker/core/image_uri_config/chainer.json +104 -0
  72. sagemaker/core/image_uri_config/clarify.json +39 -0
  73. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  74. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  75. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  76. sagemaker/core/image_uri_config/debugger.json +34 -0
  77. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  78. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  79. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  80. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  81. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  82. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  83. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  84. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  85. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
  86. sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
  87. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  88. sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
  89. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  90. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  91. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  92. sagemaker/core/image_uri_config/huggingface.json +2138 -0
  93. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  94. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  95. sagemaker/core/image_uri_config/image-classification.json +50 -0
  96. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  97. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  98. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  99. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  100. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  101. sagemaker/core/image_uri_config/kmeans.json +50 -0
  102. sagemaker/core/image_uri_config/knn.json +50 -0
  103. sagemaker/core/image_uri_config/lda.json +26 -0
  104. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  105. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  106. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  107. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  108. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  109. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  110. sagemaker/core/image_uri_config/ntm.json +50 -0
  111. sagemaker/core/image_uri_config/object-detection.json +50 -0
  112. sagemaker/core/image_uri_config/object2vec.json +50 -0
  113. sagemaker/core/image_uri_config/pca.json +50 -0
  114. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  115. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  116. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  117. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  118. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  119. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  120. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  121. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  122. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  123. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  124. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
  125. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  126. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  127. sagemaker/core/image_uri_config/sklearn.json +446 -0
  128. sagemaker/core/image_uri_config/spark.json +280 -0
  129. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  130. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  131. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  132. sagemaker/core/image_uri_config/vw.json +25 -0
  133. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  134. sagemaker/core/image_uri_config/xgboost.json +888 -0
  135. sagemaker/core/image_uris.py +810 -0
  136. sagemaker/core/inference_config.py +144 -0
  137. sagemaker/core/inference_recommender/__init__.py +18 -0
  138. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  139. sagemaker/core/inputs.py +366 -0
  140. sagemaker/core/instance_group.py +61 -0
  141. sagemaker/core/instance_types.py +164 -0
  142. sagemaker/core/instance_types_gpu_info.py +43 -0
  143. sagemaker/core/interactive_apps/__init__.py +41 -0
  144. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  145. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  146. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  147. sagemaker/core/iterators.py +186 -0
  148. sagemaker/core/job.py +380 -0
  149. sagemaker/core/jumpstart/__init__.py +156 -0
  150. sagemaker/core/jumpstart/accessors.py +390 -0
  151. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  152. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  153. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  154. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  155. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  156. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  157. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  158. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  159. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  160. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  161. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  162. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  163. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  164. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  165. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  166. sagemaker/core/jumpstart/cache.py +663 -0
  167. sagemaker/core/jumpstart/configs.py +50 -0
  168. sagemaker/core/jumpstart/constants.py +198 -0
  169. sagemaker/core/jumpstart/deserializers.py +81 -0
  170. sagemaker/core/jumpstart/document.py +76 -0
  171. sagemaker/core/jumpstart/enums.py +168 -0
  172. sagemaker/core/jumpstart/exceptions.py +236 -0
  173. sagemaker/core/jumpstart/factory/utils.py +833 -0
  174. sagemaker/core/jumpstart/filters.py +597 -0
  175. sagemaker/core/jumpstart/hub/__init__.py +0 -0
  176. sagemaker/core/jumpstart/hub/constants.py +16 -0
  177. sagemaker/core/jumpstart/hub/hub.py +291 -0
  178. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  179. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  180. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  181. sagemaker/core/jumpstart/hub/types.py +35 -0
  182. sagemaker/core/jumpstart/hub/utils.py +260 -0
  183. sagemaker/core/jumpstart/models.py +499 -0
  184. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  185. sagemaker/core/jumpstart/parameters.py +20 -0
  186. sagemaker/core/jumpstart/payload_utils.py +239 -0
  187. sagemaker/core/jumpstart/region_config.json +163 -0
  188. sagemaker/core/jumpstart/search.py +171 -0
  189. sagemaker/core/jumpstart/serializers.py +81 -0
  190. sagemaker/core/jumpstart/session_utils.py +234 -0
  191. sagemaker/core/jumpstart/types.py +3044 -0
  192. sagemaker/core/jumpstart/utils.py +1731 -0
  193. sagemaker/core/jumpstart/validators.py +257 -0
  194. sagemaker/core/lambda_helper.py +312 -0
  195. sagemaker/core/lineage/__init__.py +42 -0
  196. sagemaker/core/lineage/_api_types.py +239 -0
  197. sagemaker/core/lineage/_utils.py +49 -0
  198. sagemaker/core/lineage/action.py +345 -0
  199. sagemaker/core/lineage/artifact.py +646 -0
  200. sagemaker/core/lineage/association.py +190 -0
  201. sagemaker/core/lineage/context.py +505 -0
  202. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  203. sagemaker/core/lineage/query.py +732 -0
  204. sagemaker/core/lineage/visualizer.py +346 -0
  205. sagemaker/core/local/__init__.py +18 -0
  206. sagemaker/core/local/data.py +413 -0
  207. sagemaker/core/local/entities.py +678 -0
  208. sagemaker/core/local/exceptions.py +17 -0
  209. sagemaker/core/local/image.py +1243 -0
  210. sagemaker/core/local/local_session.py +739 -0
  211. sagemaker/core/local/utils.py +245 -0
  212. sagemaker/core/logs.py +181 -0
  213. sagemaker/core/metadata_properties.py +56 -0
  214. sagemaker/core/metric_definitions.py +91 -0
  215. sagemaker/core/mlflow/__init__.py +38 -0
  216. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  217. sagemaker/core/model_card/__init__.py +26 -0
  218. sagemaker/core/model_life_cycle.py +51 -0
  219. sagemaker/core/model_metrics.py +160 -0
  220. sagemaker/core/model_monitor/__init__.py +66 -0
  221. sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
  222. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  223. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  224. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  225. sagemaker/core/model_monitor/dataset_format.py +102 -0
  226. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  227. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  228. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  229. sagemaker/core/model_monitor/utils.py +793 -0
  230. sagemaker/core/model_registry.py +480 -0
  231. sagemaker/core/model_uris.py +97 -0
  232. sagemaker/core/modules/__init__.py +19 -0
  233. sagemaker/core/modules/configs.py +226 -0
  234. sagemaker/core/modules/constants.py +37 -0
  235. sagemaker/core/modules/distributed.py +182 -0
  236. sagemaker/core/modules/local_core/__init__.py +0 -0
  237. sagemaker/core/modules/local_core/local_container.py +605 -0
  238. sagemaker/core/modules/templates.py +83 -0
  239. sagemaker/core/modules/train/__init__.py +14 -0
  240. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  241. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  242. sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
  243. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  244. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  245. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  246. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  247. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  248. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  249. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  250. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  251. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  252. sagemaker/core/modules/types.py +19 -0
  253. sagemaker/core/modules/utils.py +194 -0
  254. sagemaker/core/network.py +185 -0
  255. sagemaker/core/parameter.py +173 -0
  256. sagemaker/core/payloads.py +185 -0
  257. sagemaker/core/processing.py +1597 -0
  258. sagemaker/core/remote_function/__init__.py +19 -0
  259. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  260. sagemaker/core/remote_function/client.py +1285 -0
  261. sagemaker/core/remote_function/core/__init__.py +0 -0
  262. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  263. sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
  264. sagemaker/core/remote_function/core/serialization.py +422 -0
  265. sagemaker/core/remote_function/core/stored_function.py +226 -0
  266. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  267. sagemaker/core/remote_function/errors.py +104 -0
  268. sagemaker/core/remote_function/invoke_function.py +172 -0
  269. sagemaker/core/remote_function/job.py +2140 -0
  270. sagemaker/core/remote_function/logging_config.py +38 -0
  271. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  272. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  273. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  274. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  275. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  276. sagemaker/core/remote_function/spark_config.py +149 -0
  277. sagemaker/core/resource_requirements.py +168 -0
  278. {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
  279. sagemaker/core/s3/__init__.py +41 -0
  280. sagemaker/core/s3/client.py +367 -0
  281. sagemaker/core/s3/utils.py +175 -0
  282. sagemaker/core/script_uris.py +93 -0
  283. sagemaker/core/serializers/__init__.py +11 -0
  284. sagemaker/core/serializers/base.py +510 -0
  285. sagemaker/core/serializers/implementations.py +159 -0
  286. sagemaker/core/serializers/utils.py +223 -0
  287. sagemaker/core/serverless_inference_config.py +63 -0
  288. sagemaker/core/session_settings.py +55 -0
  289. sagemaker/core/shapes/__init__.py +3 -0
  290. sagemaker/core/shapes/model_card_shapes.py +159 -0
  291. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
  292. sagemaker/core/spark/__init__.py +16 -0
  293. sagemaker/core/spark/defaults.py +16 -0
  294. sagemaker/core/spark/processing.py +1380 -0
  295. sagemaker/core/telemetry/__init__.py +23 -0
  296. sagemaker/core/telemetry/constants.py +84 -0
  297. sagemaker/core/telemetry/telemetry_logging.py +284 -0
  298. sagemaker/core/tools/__init__.py +1 -0
  299. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  300. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  301. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  302. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  303. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  304. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  305. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  306. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  307. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  308. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  309. sagemaker/core/training/__init__.py +14 -0
  310. sagemaker/core/training/configs.py +333 -0
  311. sagemaker/core/training/constants.py +37 -0
  312. sagemaker/core/training/utils.py +77 -0
  313. sagemaker/core/training_compiler/__init__.py +16 -0
  314. sagemaker/core/training_compiler/config.py +197 -0
  315. sagemaker/core/training_compiler_config.py +197 -0
  316. sagemaker/core/transformer.py +793 -0
  317. sagemaker/core/user_agent.py +76 -0
  318. sagemaker/core/utilities/__init__.py +24 -0
  319. sagemaker/core/utilities/cache.py +169 -0
  320. sagemaker/core/utilities/search_expression.py +133 -0
  321. sagemaker/core/utils/__init__.py +48 -0
  322. sagemaker/core/utils/code_injection/__init__.py +0 -0
  323. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  324. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
  325. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  326. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  327. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  328. {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
  329. sagemaker/core/workflow/__init__.py +152 -0
  330. sagemaker/core/workflow/conditions.py +313 -0
  331. sagemaker/core/workflow/entities.py +58 -0
  332. sagemaker/core/workflow/execution_variables.py +89 -0
  333. sagemaker/core/workflow/functions.py +193 -0
  334. sagemaker/core/workflow/parameters.py +222 -0
  335. sagemaker/core/workflow/pipeline_context.py +394 -0
  336. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  337. sagemaker/core/workflow/properties.py +285 -0
  338. sagemaker/core/workflow/step_outputs.py +65 -0
  339. sagemaker/core/workflow/utilities.py +507 -0
  340. sagemaker/lineage/__init__.py +33 -0
  341. sagemaker/lineage/action.py +28 -0
  342. sagemaker/lineage/artifact.py +28 -0
  343. sagemaker/lineage/context.py +28 -0
  344. sagemaker/lineage/lineage_trial_component.py +28 -0
  345. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
  346. sagemaker_core-2.1.1.dist-info/RECORD +355 -0
  347. sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
  348. sagemaker_core/__init__.py +0 -4
  349. sagemaker_core/_version.py +0 -3
  350. sagemaker_core/helper/session_helper.py +0 -769
  351. sagemaker_core/resources/__init__.py +0 -1
  352. sagemaker_core/shapes/__init__.py +0 -1
  353. sagemaker_core/tools/__init__.py +0 -1
  354. sagemaker_core-1.0.47.dist-info/RECORD +0 -35
  355. sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
  356. {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
  357. {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
  358. {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  359. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  360. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  361. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  362. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
  363. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1495 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module contains code related to Amazon SageMaker Explainability AI Model Monitoring.
14
+
15
+ These classes assist with suggesting baselines and creating monitoring schedules for monitoring
16
+ bias metrics and feature attribution of SageMaker Endpoints.
17
+ """
18
+ from __future__ import print_function, absolute_import
19
+
20
+ import copy
21
+ import json
22
+ import logging
23
+ import uuid
24
+
25
+ from sagemaker.core.model_monitor import model_monitoring as mm
26
+ from sagemaker.core.model_monitor.utils import (
27
+ boto_describe_monitoring_schedule,
28
+ boto_list_monitoring_executions,
29
+ )
30
+ from sagemaker.core import image_uris, s3
31
+ from sagemaker.core.helper.session_helper import Session, expand_role
32
+ from sagemaker.core.common_utils import (
33
+ name_from_base,
34
+ format_tags,
35
+ get_resource_name_from_arn,
36
+ list_tags,
37
+ )
38
+ from sagemaker.core.clarify import SageMakerClarifyProcessor, ModelPredictedLabelConfig
39
+ from sagemaker.core.processing import logs_for_processing_job
40
+
41
+ # Setting _LOGGER for backward compatibility, in case users import it...
42
+ logger = _LOGGER = logging.getLogger(__name__)
43
+
44
+
45
+ class ClarifyModelMonitor(mm.ModelMonitor):
46
+ """Base class of Amazon SageMaker Explainability API model monitors.
47
+
48
+ This class is an ``abstract base class``, please instantiate its subclasses
49
+ if you want to monitor bias metrics or feature attribution of an endpoint.
50
+ """
51
+
52
def __init__(
    self,
    role=None,
    instance_count=1,
    instance_type="ml.m5.xlarge",
    volume_size_in_gb=30,
    volume_kms_key=None,
    output_kms_key=None,
    max_runtime_in_seconds=None,
    base_job_name=None,
    sagemaker_session=None,
    env=None,
    tags=None,
    network_config=None,
):
    """Configure a Clarify-based monitor on top of the generic model monitor.

    The Clarify container image is looked up for the session's region and all
    remaining settings are forwarded to the parent ``ModelMonitor`` constructor.

    Args:
        role (str): An AWS IAM role used by the Amazon SageMaker jobs.
        instance_count (int): Number of instances the jobs run on (default: 1).
        instance_type (str): EC2 instance type for the jobs
            (default: 'ml.m5.xlarge').
        volume_size_in_gb (int): Size in GB of the EBS volume used for storing
            data during processing (default: 30).
        volume_kms_key (str): A KMS key for the job's volume.
        output_kms_key (str): The KMS key id for the job's outputs.
        max_runtime_in_seconds (int): Timeout in seconds after which Amazon
            SageMaker terminates the job regardless of its current status
            (default: None, forwarded unchanged to the parent class).
        base_job_name (str): Prefix for the job name.
        sagemaker_session (sagemaker.core.helper.session_helper.Session): Session
            object which manages interactions with Amazon SageMaker APIs and any
            other AWS services needed. A new ``Session()`` is created when this
            is not supplied.
        env (dict): Environment variables to be passed to the job.
        tags (Optional[Tags]): Tags to be passed to the job; they are run
            through ``format_tags`` before being handed to the parent class.
        network_config (sagemaker.network.NetworkConfig): A NetworkConfig
            object that configures network isolation, encryption of
            inter-container traffic, security group IDs, and subnets.

    Raises:
        TypeError: If this base class itself (not a subclass) is instantiated.
    """
    # Exact type comparison on purpose: subclasses must pass, the base must not.
    if type(self) == __class__:  # pylint: disable=unidiomatic-typecheck
        abstract_message = "{} is abstract, please instantiate its subclasses instead.".format(
            __class__.__name__
        )
        raise TypeError(abstract_message)

    active_session = sagemaker_session if sagemaker_session else Session()
    region_name = active_session.boto_session.region_name

    monitor_settings = dict(
        role=role,
        image_uri=image_uris.retrieve("clarify", region_name),
        instance_count=instance_count,
        instance_type=instance_type,
        volume_size_in_gb=volume_size_in_gb,
        volume_kms_key=volume_kms_key,
        output_kms_key=output_kms_key,
        max_runtime_in_seconds=max_runtime_in_seconds,
        base_job_name=base_job_name,
        sagemaker_session=active_session,
        env=env,
        tags=format_tags(tags),
        network_config=network_config,
    )
    super(ClarifyModelMonitor, self).__init__(**monitor_settings)

    # No baselining job has been configured yet for this monitor.
    self.latest_baselining_job_config = None
124
+
125
def run_baseline(self, **_):
    """Reject baselining through the generic ``run_baseline`` entry point.

    '.run_baseline()' is only allowed for ModelMonitor objects.
    Please use `suggest_baseline` instead.

    Args:
        **_: Accepted and ignored so any call signature reaches the error below.

    Raises:
        NotImplementedError: Always.
    """
    # The trailing space on the first literal matters: adjacent string literals
    # are concatenated verbatim, and without it the message read
    # "...objects.Please use...".
    raise NotImplementedError(
        "'.run_baseline()' is only allowed for ModelMonitor objects. "
        "Please use suggest_baseline instead."
    )
138
+
139
def latest_monitoring_statistics(self, **_):
    """Signal that statistics are unavailable for this kind of monitor.

    Args:
        **_: Accepted and ignored.

    Raises:
        NotImplementedError: Always; the message names the concrete class.
    """
    monitor_class_name = self.__class__.__name__
    raise NotImplementedError("{} doesn't support statistics.".format(monitor_class_name))
148
+
149
def list_executions(self):
    """List the latest monitoring executions as Clarify execution objects.

    Each execution returned by the parent class is re-wrapped, in the same
    order, as a ``ClarifyMonitoringExecution`` carrying the same session, job
    name, inputs, output and output KMS key.

    Returns:
        [sagemaker.model_monitor.ClarifyMonitoringExecution]: List of
            ClarifyMonitoringExecution in descending order of "ScheduledTime".
    """
    clarify_executions = []
    for base_execution in super(ClarifyModelMonitor, self).list_executions():
        clarify_executions.append(
            ClarifyMonitoringExecution(
                sagemaker_session=base_execution.sagemaker_session,
                job_name=base_execution.job_name,
                inputs=base_execution.inputs,
                output=base_execution.output,
                output_kms_key=base_execution.output_kms_key,
            )
        )
    return clarify_executions
167
+
168
def get_latest_execution_logs(self, wait=False):
    """Stream the processing-job logs of the most recent monitoring execution.

    Args:
        wait (bool): Whether the call should wait until the job completes
            (default: False).

    Raises:
        ValueError: If the schedule has no executions, or the first listed
            execution has no processing job attached.

    Returns: None
    """
    listing = boto_list_monitoring_executions(
        sagemaker_session=self.sagemaker_session,
        monitoring_schedule_name=self.monitoring_schedule_name,
    )
    summaries = listing["MonitoringExecutionSummaries"]
    if len(summaries) == 0:
        raise ValueError("No execution jobs were kicked off.")

    # Only the first summary in the listing is inspected.
    newest_summary = summaries[0]
    if "ProcessingJobArn" not in newest_summary:
        raise ValueError("Processing Job did not run for the last execution")

    processing_job_name = get_resource_name_from_arn(newest_summary["ProcessingJobArn"])
    logs_for_processing_job(
        self.sagemaker_session,
        job_name=processing_job_name,
        wait=wait,
    )
193
+
194
def _create_baselining_processor(self):
    """Build the SageMakerClarifyProcessor used to run the baselining job.

    The processor is constructed from this monitor's compute, encryption,
    session, environment, tag and network settings; its image URI and base job
    name are then overwritten with the monitor's own values.

    Returns:
        sagemaker.clarify.SageMakerClarifyProcessor object.
    """
    processor_settings = {
        "role": self.role,
        "instance_count": self.instance_count,
        "instance_type": self.instance_type,
        "volume_size_in_gb": self.volume_size_in_gb,
        "volume_kms_key": self.volume_kms_key,
        "output_kms_key": self.output_kms_key,
        "max_runtime_in_seconds": self.max_runtime_in_seconds,
        "sagemaker_session": self.sagemaker_session,
        "env": self.env,
        "tags": self.tags,
        "network_config": self.network_config,
    }
    clarify_processor = SageMakerClarifyProcessor(**processor_settings)

    # Set after construction, mirroring the monitor's own image and job prefix.
    clarify_processor.image_uri = self.image_uri
    clarify_processor.base_job_name = self.base_job_name
    return clarify_processor
217
+
218
def _upload_analysis_config(self, analysis_config, output_s3_uri, job_definition_name, kms_key):
    """Upload the analysis config as JSON to S3.

    The destination is built by joining
    ``<output_s3_uri>/<job_definition_name>/<random uuid4>/analysis_config.json``,
    so every call writes to a fresh location.

    Args:
        analysis_config (dict): analysis config of a Clarify model monitor;
            serialized with ``json.dumps`` before upload.
        output_s3_uri (str): S3 prefix under which the config is stored.
        job_definition_name (str): Job definition name, used as a path segment.
        kms_key (str): KMS key forwarded to the uploader for the uploaded
            file (may be None).

    Returns:
        str: The S3 uri of the uploaded file(s).
    """
    s3_uri = s3.s3_path_join(
        output_s3_uri,
        job_definition_name,
        str(uuid.uuid4()),
        "analysis_config.json",
    )
    # Pass the URI as a lazy logging argument; the previous plain string literal
    # logged the placeholder text "{s3_uri}" instead of the actual destination.
    logger.info("Uploading analysis config to %s.", s3_uri)
    return s3.S3Uploader.upload_string_as_file_body(
        json.dumps(analysis_config),
        desired_s3_uri=s3_uri,
        sagemaker_session=self.sagemaker_session,
        kms_key=kms_key,
    )
246
+
247
def _build_create_job_definition_request(
    self,
    monitoring_schedule_name,
    job_definition_name,
    image_uri,
    latest_baselining_job_name=None,
    latest_baselining_job_config=None,
    existing_job_desc=None,
    endpoint_input=None,
    ground_truth_input=None,
    analysis_config=None,
    output_s3_uri=None,
    constraints=None,
    enable_cloudwatch_metrics=None,
    role=None,
    instance_count=None,
    instance_type=None,
    volume_size_in_gb=None,
    volume_kms_key=None,
    output_kms_key=None,
    max_runtime_in_seconds=None,
    env=None,
    tags=None,
    network_config=None,
    batch_transform_input=None,
):
    """Build the request for the job definition creation API.

    When ``existing_job_desc`` is given, its sections are the starting point and are
    updated (in place) with every argument that was passed in; otherwise the request is
    assembled from the arguments alone.

    Args:
        monitoring_schedule_name (str): Monitoring schedule name, used when normalizing
            the monitoring output.
        job_definition_name (str): Name of the job definition to create.
        image_uri (str): The uri of the image to use for the jobs started by the Monitor.
        latest_baselining_job_name (str): Name of the last baselining job. Recorded in the
            baseline config when no ``constraints`` are given.
        latest_baselining_job_config (ClarifyBaseliningConfig): Config of the last
            baselining job. Supplies the analysis config when ``analysis_config`` is None
            and backfills attributes that are None on the endpoint/batch transform input.
        existing_job_desc (dict): Description of an existing job definition. It is updated
            by the values that were passed in and then used to build the new request.
        endpoint_input (str or sagemaker.model_monitor.EndpointInput): The endpoint to
            monitor. This can either be the endpoint name or an EndpointInput.
        ground_truth_input (str): S3 URI to ground truth dataset.
        analysis_config (str or BiasAnalysisConfig or ExplainabilityAnalysisConfig): Either
            the URI of an analysis config file, or a config object that is uploaded to S3.
            If None, the latest baselining job config is used, then the ``ConfigUri`` of
            the existing app specification.
        output_s3_uri (str): S3 destination of the constraint_violations and analysis
            result. Also the upload prefix for a non-URI ``analysis_config``.
        constraints (sagemaker.model_monitor.Constraints or str): If provided it will be
            used for monitoring. It can be a Constraints object or an S3 uri pointing to a
            constraints JSON file.
        enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
            the baselining or monitoring jobs.
        role (str): An AWS IAM role. Falls back to the role of ``existing_job_desc``.
        instance_count (int): The number of instances to run the jobs with.
        instance_type (str): Type of EC2 instance to use for the job, for example,
            'ml.m5.xlarge'.
        volume_size_in_gb (int): Size in GB of the EBS volume to use for storing data
            during processing.
        volume_kms_key (str): A KMS key for the job's volume.
        output_kms_key (str): KMS key id for output.
        max_runtime_in_seconds (int): Timeout in seconds for the job.
        env (dict): Environment variables to be passed to the job.
        tags (Optional[Tags]): List of tags to be passed to the job.
        network_config (sagemaker.network.NetworkConfig): A NetworkConfig object. Takes
            precedence over the network config of ``existing_job_desc``.
        batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run
            the monitoring schedule on the batch transform. Only used when
            ``endpoint_input`` is None.

    Returns:
        dict: request parameters to create job definition.

    Raises:
        ValueError: If ``analysis_config`` is None and it can be taken neither from
            ``latest_baselining_job_config`` nor from the existing app specification.
    """
    monitoring_type = self.monitoring_type()
    app_specification_key = "{}AppSpecification".format(monitoring_type)
    baseline_config_key = "{}BaselineConfig".format(monitoring_type)
    job_input_key = "{}JobInput".format(monitoring_type)
    job_output_key = "{}JobOutputConfig".format(monitoring_type)

    if existing_job_desc is not None:
        app_specification = existing_job_desc[app_specification_key]
        baseline_config = existing_job_desc.get(baseline_config_key, {})
        job_input = existing_job_desc[job_input_key]
        job_output = existing_job_desc[job_output_key]
        cluster_config = existing_job_desc["JobResources"]["ClusterConfig"]
        if role is None:
            role = existing_job_desc["RoleArn"]
        existing_network_config = existing_job_desc.get("NetworkConfig")
        stop_condition = existing_job_desc.get("StoppingCondition", {})
    else:
        app_specification = {}
        baseline_config = {}
        job_input = {}
        job_output = {}
        cluster_config = {}
        existing_network_config = None
        stop_condition = {}

    # job output
    if output_s3_uri is not None:
        normalized_monitoring_output = self._normalize_monitoring_output(
            monitoring_schedule_name, output_s3_uri
        )
        job_output["MonitoringOutputs"] = [normalized_monitoring_output._to_request_dict()]
    if output_kms_key is not None:
        job_output["KmsKeyId"] = output_kms_key

    # app specification
    if analysis_config is None:
        if latest_baselining_job_config is not None:
            analysis_config = latest_baselining_job_config.analysis_config
        elif app_specification:
            analysis_config = app_specification["ConfigUri"]
        else:
            raise ValueError("analysis_config is mandatory.")
    # A string is already the S3 URI of the config; anything else is a config object that
    # has to be serialized and uploaded first.
    if isinstance(analysis_config, str):
        analysis_config_uri = analysis_config
    else:
        analysis_config_uri = self._upload_analysis_config(
            analysis_config._to_dict(), output_s3_uri, job_definition_name, output_kms_key
        )
    app_specification["ConfigUri"] = analysis_config_uri
    app_specification["ImageUri"] = image_uri
    normalized_env = self._generate_env_map(
        env=env, enable_cloudwatch_metrics=enable_cloudwatch_metrics
    )
    if normalized_env:
        app_specification["Environment"] = normalized_env

    # baseline config
    if constraints:
        # noinspection PyTypeChecker
        _, constraints_object = self._get_baseline_files(
            statistics=None, constraints=constraints, sagemaker_session=self.sagemaker_session
        )
        constraints_s3_uri = None
        if constraints_object is not None:
            constraints_s3_uri = constraints_object.file_s3_uri
        baseline_config["ConstraintsResource"] = dict(S3Uri=constraints_s3_uri)
    elif latest_baselining_job_name:
        baseline_config["BaseliningJobName"] = latest_baselining_job_name

    # job input
    # Remember the ground truth of the existing job definition: a new endpoint or batch
    # transform input replaces job_input wholesale and would otherwise drop it.
    existing_ground_truth = job_input.get("GroundTruthS3Input")
    backfilled_attributes = (
        "features_attribute",
        "inference_attribute",
        "probability_attribute",
        "probability_threshold_attribute",
    )

    def _to_job_input(monitored_input):
        """Backfill unset attributes from the baselining config and build the request."""
        if latest_baselining_job_config is not None:
            for attribute in backfilled_attributes:
                if getattr(monitored_input, attribute) is None:
                    setattr(
                        monitored_input,
                        attribute,
                        getattr(latest_baselining_job_config, attribute),
                    )
        new_job_input = monitored_input._to_request_dict()
        if existing_ground_truth is not None:
            new_job_input.setdefault("GroundTruthS3Input", existing_ground_truth)
        return new_job_input

    if endpoint_input is not None:
        job_input = _to_job_input(self._normalize_endpoint_input(endpoint_input=endpoint_input))
    elif batch_transform_input is not None:
        job_input = _to_job_input(batch_transform_input)

    # An explicitly passed ground truth always wins over the existing one.
    if ground_truth_input is not None:
        job_input["GroundTruthS3Input"] = dict(S3Uri=ground_truth_input)

    # cluster config
    if instance_count is not None:
        cluster_config["InstanceCount"] = instance_count
    if instance_type is not None:
        cluster_config["InstanceType"] = instance_type
    if volume_size_in_gb is not None:
        cluster_config["VolumeSizeInGB"] = volume_size_in_gb
    if volume_kms_key is not None:
        cluster_config["VolumeKmsKeyId"] = volume_kms_key

    # stop condition
    if max_runtime_in_seconds is not None:
        stop_condition["MaxRuntimeInSeconds"] = max_runtime_in_seconds

    request_dict = {
        "JobDefinitionName": job_definition_name,
        app_specification_key: app_specification,
        job_input_key: job_input,
        job_output_key: job_output,
        "JobResources": dict(ClusterConfig=cluster_config),
        "RoleArn": expand_role(self.sagemaker_session, role),
    }

    # Optional sections are only sent when they carry something.
    if baseline_config:
        request_dict[baseline_config_key] = baseline_config

    if network_config is not None:
        request_dict["NetworkConfig"] = network_config._to_request_dict()
    elif existing_network_config is not None:
        request_dict["NetworkConfig"] = existing_network_config

    if stop_condition:
        request_dict["StoppingCondition"] = stop_condition

    if tags is not None:
        request_dict["Tags"] = format_tags(tags)

    return request_dict
476
+
477
+
478
class ModelBiasMonitor(ClarifyModelMonitor):
    """Amazon SageMaker model monitor that monitors bias metrics.

    Instances are created through the ``__init__`` method of the base class.
    """

    JOB_DEFINITION_BASE_NAME = "model-bias-job-definition"

    @classmethod
    def monitoring_type(cls):
        """Return the type of the monitoring job handled by this class."""
        return "ModelBias"

    def suggest_baseline(
        self,
        data_config,
        bias_config,
        model_config,
        model_predicted_label_config=None,
        wait=False,
        logs=False,
        job_name=None,
        kms_key=None,
    ):
        """Run a bias baselining job and remember its configuration on the monitor.

        Args:
            data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output
                data.
            bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
            model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and
                its endpoint to be created.
            model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
                Config of how to extract the predicted label from the model output.
            wait (bool): Whether the call should wait until the job completes
                (default: False).
            logs (bool): Whether to show the logs produced by the job.
                Only meaningful when wait is True (default: False).
            job_name (str): Processing job name, handed to
                ``_generate_baselining_job_name`` (default: None).
            kms_key (str): The ARN of the KMS key that is used to encrypt the
                user code file (default: None).

        Returns:
            sagemaker.processing.ProcessingJob: The ProcessingJob object representing the
                baselining job.
        """
        processor = self._create_baselining_processor()
        resolved_job_name = self._generate_baselining_job_name(job_name=job_name)
        processor.run_bias(
            data_config=data_config,
            bias_config=bias_config,
            model_config=model_config,
            model_predicted_label_config=model_predicted_label_config,
            wait=wait,
            logs=logs,
            job_name=resolved_job_name,
            kms_key=kms_key,
        )

        bias_analysis_config = BiasAnalysisConfig(
            bias_config=bias_config, headers=data_config.headers, label=data_config.label
        )
        baselining_config = ClarifyBaseliningConfig(
            analysis_config=bias_analysis_config,
            features_attribute=data_config.features,
        )
        if model_predicted_label_config is not None:
            # Label and probability are stored as strings, but None stays None.
            predicted_label = model_predicted_label_config.label
            if predicted_label is not None:
                predicted_label = str(predicted_label)
            baselining_config.inference_attribute = predicted_label

            probability = model_predicted_label_config.probability
            if probability is not None:
                probability = str(probability)
            baselining_config.probability_attribute = probability

            baselining_config.probability_threshold_attribute = (
                model_predicted_label_config.probability_threshold
            )

        self.latest_baselining_job_config = baselining_config
        self.latest_baselining_job_name = resolved_job_name
        self.latest_baselining_job = ClarifyBaseliningJob(processing_job=processor.latest_job)

        self.baselining_jobs.append(self.latest_baselining_job)
        return processor.latest_job

    # noinspection PyMethodOverriding
    def create_monitoring_schedule(
        self,
        endpoint_input=None,
        ground_truth_input=None,
        analysis_config=None,
        output_s3_uri=None,
        constraints=None,
        monitor_schedule_name=None,
        schedule_cron_expression=None,
        enable_cloudwatch_metrics=True,
        batch_transform_input=None,
        data_analysis_start_time=None,
        data_analysis_end_time=None,
    ):
        """Create a model bias job definition and a monitoring schedule that uses it.

        If creating the schedule fails, the freshly created job definition is deleted
        again (best effort) and the original error is re-raised.

        Args:
            endpoint_input (str or sagemaker.model_monitor.EndpointInput): The endpoint to
                monitor. This can either be the endpoint name or an EndpointInput.
                (default: None)
            ground_truth_input (str): S3 URI to ground truth dataset. Required even though
                it defaults to None in the signature. (default: None)
            analysis_config (str or BiasAnalysisConfig): URI to analysis_config for the bias
                job. If it is None then configuration of the latest baselining job will be
                reused, but if no baselining job then fail the call. (default: None)
            output_s3_uri (str): S3 destination of the constraint_violations and analysis
                result. (default: None)
            constraints (sagemaker.model_monitor.Constraints or str): If provided it will be
                used for monitoring the endpoint. It can be a Constraints object or an S3
                uri pointing to a constraints JSON file. (default: None)
            monitor_schedule_name (str): Schedule name, handed to
                ``_generate_monitoring_schedule_name``. (default: None)
            schedule_cron_expression (str): The cron expression that dictates the frequency
                that this job run. See sagemaker.model_monitor.CronExpressionGenerator for
                valid expressions. (default: None)
            enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part
                of the baselining or monitoring jobs. (default: True)
            batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
                run the monitoring schedule on the batch transform (default: None)
            data_analysis_start_time (str): Start time for the data analysis window
                for the one time monitoring schedule (NOW), e.g. "-PT1H" (default: None)
            data_analysis_end_time (str): End time for the data analysis window
                for the one time monitoring schedule (NOW), e.g. "-PT1H" (default: None)

        Raises:
            ValueError: If ``ground_truth_input`` is missing, if this object already owns a
                schedule or job definition, or if not exactly one of ``endpoint_input`` and
                ``batch_transform_input`` is given.
        """
        # ground_truth_input defaults to None in the signature only for positional-argument
        # backward compatibility; it is still mandatory.
        if not ground_truth_input:
            raise ValueError("ground_truth_input can not be None.")

        already_used = not (
            self.job_definition_name is None and self.monitoring_schedule_name is None
        )
        if already_used:
            message = (
                "It seems that this object was already used to create an Amazon Model "
                "Monitoring Schedule. To create another, first delete the existing one "
                "using my_monitor.delete_monitoring_schedule()."
            )
            logger.error(message)
            raise ValueError(message)

        # Exactly one of the two inputs must be given: both set or both missing is an error.
        if (batch_transform_input is None) == (endpoint_input is None):
            message = (
                "Need to have either batch_transform_input or endpoint_input to create an "
                "Amazon Model Monitoring Schedule. "
                "Please provide only one of the above required inputs"
            )
            logger.error(message)
            raise ValueError(message)

        self._check_monitoring_schedule_cron_validity(
            schedule_cron_expression=schedule_cron_expression,
            data_analysis_start_time=data_analysis_start_time,
            data_analysis_end_time=data_analysis_end_time,
        )

        # create job definition
        monitor_schedule_name = self._generate_monitoring_schedule_name(
            schedule_name=monitor_schedule_name
        )
        new_job_definition_name = name_from_base(self.JOB_DEFINITION_BASE_NAME)
        normalized_output = self._normalize_monitoring_output(
            monitor_schedule_name, output_s3_uri
        )
        request_dict = self._build_create_job_definition_request(
            monitoring_schedule_name=monitor_schedule_name,
            job_definition_name=new_job_definition_name,
            image_uri=self.image_uri,
            latest_baselining_job_name=self.latest_baselining_job_name,
            latest_baselining_job_config=self.latest_baselining_job_config,
            endpoint_input=endpoint_input,
            ground_truth_input=ground_truth_input,
            analysis_config=analysis_config,
            output_s3_uri=normalized_output.s3_output.s3_uri,
            constraints=constraints,
            enable_cloudwatch_metrics=enable_cloudwatch_metrics,
            role=self.role,
            instance_count=self.instance_count,
            instance_type=self.instance_type,
            volume_size_in_gb=self.volume_size_in_gb,
            volume_kms_key=self.volume_kms_key,
            output_kms_key=self.output_kms_key,
            max_runtime_in_seconds=self.max_runtime_in_seconds,
            env=self.env,
            tags=self.tags,
            network_config=self.network_config,
            batch_transform_input=batch_transform_input,
        )
        sagemaker_client = self.sagemaker_session.sagemaker_client
        sagemaker_client.create_model_bias_job_definition(**request_dict)

        # create schedule
        try:
            self._create_monitoring_schedule_from_job_definition(
                monitor_schedule_name=monitor_schedule_name,
                job_definition_name=new_job_definition_name,
                schedule_cron_expression=schedule_cron_expression,
                data_analysis_start_time=data_analysis_start_time,
                data_analysis_end_time=data_analysis_end_time,
            )
            self.job_definition_name = new_job_definition_name
            self.monitoring_schedule_name = monitor_schedule_name
        except Exception:
            logger.exception("Failed to create monitoring schedule.")
            self.monitoring_schedule_name = None
            # Roll back the job definition; a failure here is only logged so that the
            # original error is the one that propagates.
            # noinspection PyBroadException
            try:
                self.sagemaker_session.sagemaker_client.delete_model_bias_job_definition(
                    JobDefinitionName=new_job_definition_name
                )
            except Exception:  # pylint: disable=W0703
                logger.exception(f"Failed to delete job definition {new_job_definition_name}.")
            raise

    # noinspection PyMethodOverriding
    def update_monitoring_schedule(
        self,
        endpoint_input=None,
        ground_truth_input=None,
        analysis_config=None,
        output_s3_uri=None,
        constraints=None,
        schedule_cron_expression=None,
        enable_cloudwatch_metrics=None,
        role=None,
        instance_count=None,
        instance_type=None,
        volume_size_in_gb=None,
        volume_kms_key=None,
        output_kms_key=None,
        max_runtime_in_seconds=None,
        env=None,
        network_config=None,
        batch_transform_input=None,
        data_analysis_start_time=None,
        data_analysis_end_time=None,
    ):
        """Update the existing monitoring schedule.

        If more options than schedule_cron_expression are to be updated, a new job
        definition will be created to hold them. The old job definition will not be deleted.
        Calling the method without any non-None argument is a no-op.

        Args:
            endpoint_input (str or sagemaker.model_monitor.EndpointInput): The endpoint to
                monitor. This can either be the endpoint name or an EndpointInput.
            ground_truth_input (str): S3 URI to ground truth dataset.
            analysis_config (str or BiasAnalysisConfig): URI to analysis_config for the bias
                job, or a config object.
            output_s3_uri (str): S3 destination of the constraint_violations and analysis
                result.
            constraints (sagemaker.model_monitor.Constraints or str): If provided it will be
                used for monitoring the endpoint. It can be a Constraints object or an S3
                uri pointing to a constraints JSON file.
            schedule_cron_expression (str): The cron expression that dictates the frequency
                that this job run. See sagemaker.model_monitor.CronExpressionGenerator for
                valid expressions.
            enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part
                of the baselining or monitoring jobs.
            role (str): An AWS IAM role. The Amazon SageMaker jobs use this role.
            instance_count (int): The number of instances to run the jobs with.
            instance_type (str): Type of EC2 instance to use for the job, for example,
                'ml.m5.xlarge'.
            volume_size_in_gb (int): Size in GB of the EBS volume to use for storing data
                during processing.
            volume_kms_key (str): A KMS key for the job's volume.
            output_kms_key (str): The KMS key id for the job's outputs.
            max_runtime_in_seconds (int): Timeout in seconds for the job.
            env (dict): Environment variables to be passed to the job.
            network_config (sagemaker.network.NetworkConfig): A NetworkConfig
                object that configures network isolation, encryption of
                inter-container traffic, security group IDs, and subnets.
            batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
                run the monitoring schedule on the batch transform
            data_analysis_start_time (str): Start time for the data analysis window,
                forwarded to the schedule update.
            data_analysis_end_time (str): End time for the data analysis window,
                forwarded to the schedule update.

        Raises:
            ValueError: If both ``endpoint_input`` and ``batch_transform_input`` are given.
        """
        # Every argument of this method except self, in signature order.
        all_arguments = (
            endpoint_input,
            ground_truth_input,
            analysis_config,
            output_s3_uri,
            constraints,
            schedule_cron_expression,
            enable_cloudwatch_metrics,
            role,
            instance_count,
            instance_type,
            volume_size_in_gb,
            volume_kms_key,
            output_kms_key,
            max_runtime_in_seconds,
            env,
            network_config,
            batch_transform_input,
            data_analysis_start_time,
            data_analysis_end_time,
        )
        supplied_count = sum(1 for value in all_arguments if value is not None)

        # Nothing to update
        if supplied_count == 0:
            return

        if not (batch_transform_input is None or endpoint_input is None):
            message = (
                "Need to have either batch_transform_input or endpoint_input to create an "
                "Amazon Model Monitoring Schedule. "
                "Please provide only one of the above required inputs"
            )
            logger.error(message)
            raise ValueError(message)

        # Only need to update schedule expression
        if supplied_count == 1 and schedule_cron_expression is not None:
            self._update_monitoring_schedule(self.job_definition_name, schedule_cron_expression)
            return

        # Need to update schedule with a new job definition
        sagemaker_client = self.sagemaker_session.sagemaker_client
        job_desc = sagemaker_client.describe_model_bias_job_definition(
            JobDefinitionName=self.job_definition_name
        )
        new_job_definition_name = name_from_base(self.JOB_DEFINITION_BASE_NAME)
        request_dict = self._build_create_job_definition_request(
            monitoring_schedule_name=self.monitoring_schedule_name,
            job_definition_name=new_job_definition_name,
            image_uri=self.image_uri,
            existing_job_desc=job_desc,
            endpoint_input=endpoint_input,
            ground_truth_input=ground_truth_input,
            analysis_config=analysis_config,
            output_s3_uri=output_s3_uri,
            constraints=constraints,
            enable_cloudwatch_metrics=enable_cloudwatch_metrics,
            role=role,
            instance_count=instance_count,
            instance_type=instance_type,
            volume_size_in_gb=volume_size_in_gb,
            volume_kms_key=volume_kms_key,
            output_kms_key=output_kms_key,
            max_runtime_in_seconds=max_runtime_in_seconds,
            env=env,
            tags=self.tags,
            network_config=network_config,
            batch_transform_input=batch_transform_input,
        )
        self.sagemaker_session.sagemaker_client.create_model_bias_job_definition(**request_dict)
        try:
            self._update_monitoring_schedule(
                new_job_definition_name,
                schedule_cron_expression,
                data_analysis_start_time=data_analysis_start_time,
                data_analysis_end_time=data_analysis_end_time,
            )
            self.job_definition_name = new_job_definition_name
            # Mirror the overrides on the monitor once the schedule points at them.
            overrides = (
                ("role", role),
                ("instance_count", instance_count),
                ("instance_type", instance_type),
                ("volume_size_in_gb", volume_size_in_gb),
                ("volume_kms_key", volume_kms_key),
                ("output_kms_key", output_kms_key),
                ("max_runtime_in_seconds", max_runtime_in_seconds),
                ("env", env),
                ("network_config", network_config),
            )
            for attribute_name, new_value in overrides:
                if new_value is not None:
                    setattr(self, attribute_name, new_value)
        except Exception:
            logger.exception("Failed to update monitoring schedule.")
            # Roll back the new job definition; a failure here is only logged so that the
            # original error is the one that propagates.
            # noinspection PyBroadException
            try:
                self.sagemaker_session.sagemaker_client.delete_model_bias_job_definition(
                    JobDefinitionName=new_job_definition_name
                )
            except Exception:  # pylint: disable=W0703
                logger.exception(f"Failed to delete job definition {new_job_definition_name}.")
            raise

    def delete_monitoring_schedule(self):
        """Delete the monitoring schedule, then the job definition it used."""
        super().delete_monitoring_schedule()
        # Delete job definition.
        logger.info(f"Deleting Model Bias Job Definition with name: {self.job_definition_name}")
        self.sagemaker_session.sagemaker_client.delete_model_bias_job_definition(
            JobDefinitionName=self.job_definition_name
        )
        self.job_definition_name = None

    @classmethod
    def attach(cls, monitor_schedule_name, sagemaker_session=None):
        """Build a monitor object from an existing ModelBias monitoring schedule.

        The schedule, its job definition and the schedule's tags are looked up and handed
        to ``ClarifyModelMonitor._attach``.

        Args:
            monitor_schedule_name (str): The name of the schedule to attach to.
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object
                which manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created with ``Session()``.

        Returns:
            The object returned by ``ClarifyModelMonitor._attach`` for ``cls``.

        Raises:
            TypeError: If the schedule's monitoring type is not the one of this class.
        """
        if not sagemaker_session:
            sagemaker_session = Session()
        schedule_desc = boto_describe_monitoring_schedule(
            sagemaker_session, monitoring_schedule_name=monitor_schedule_name
        )
        schedule_config = schedule_desc["MonitoringScheduleConfig"]
        if schedule_config.get("MonitoringType") != cls.monitoring_type():
            raise TypeError("{} can only attach to ModelBias schedule.".format(__class__.__name__))
        job_desc = sagemaker_session.sagemaker_client.describe_model_bias_job_definition(
            JobDefinitionName=schedule_desc["MonitoringScheduleConfig"][
                "MonitoringJobDefinitionName"
            ]
        )
        tags = list_tags(sagemaker_session, resource_arn=schedule_desc["MonitoringScheduleArn"])
        return ClarifyModelMonitor._attach(
            clazz=cls,
            sagemaker_session=sagemaker_session,
            schedule_desc=schedule_desc,
            job_desc=job_desc,
            tags=tags,
        )
896
+
897
+
898
class BiasAnalysisConfig:
    """Analysis configuration for ModelBiasMonitor."""

    def __init__(self, bias_config, headers=None, label=None):
        """Creates an analysis config dictionary.

        The dictionary starts as a copy of ``bias_config.get_config()`` and gets a
        ``"headers"`` and/or ``"label"`` entry when the respective argument is not None.

        Args:
            bias_config (sagemaker.clarify.BiasConfig): Config object related to bias
                configurations. Only its ``get_config()`` method is used.
            headers (list[str]): A list of column names in the input dataset.
            label (str): Target attribute for the model required by bias metrics. Specified
                as column name or index for CSV dataset, or as JMESPath expression for
                JSONLines.
        """
        # Shallow copy: the keys added below must never end up in a dict that is owned by
        # bias_config (get_config() may hand out its internal state).
        self.analysis_config = dict(bias_config.get_config())
        if headers is not None:
            self.analysis_config["headers"] = headers
        if label is not None:
            self.analysis_config["label"] = label

    def _to_dict(self):
        """Return the analysis config dictionary held by this object (not a copy)."""
        return self.analysis_config
920
+
921
+
922
+ class ModelExplainabilityMonitor(ClarifyModelMonitor):
923
+ """Amazon SageMaker model monitor to monitor feature attribution of an endpoint.
924
+
925
+ Please see the __init__ method of its base class for how to instantiate it.
926
+ """
927
+
928
+ JOB_DEFINITION_BASE_NAME = "model-explainability-job-definition"
929
+
930
@classmethod
def monitoring_type(cls):
    """Return the type of the monitoring job handled by this class.

    The value is used as the prefix of the job definition request keys.
    """
    return "ModelExplainability"
934
+
935
def suggest_baseline(
    self,
    data_config,
    explainability_config,
    model_config,
    model_scores=None,
    wait=False,
    logs=False,
    job_name=None,
    kms_key=None,
):
    """Suggest baselines for use with Amazon SageMaker Model Monitoring Schedules.

    Runs an explainability baselining job and records its name, job object and
    configuration on the monitor.

    Args:
        data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
        explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig`): Config of
            the specific explainability method.
        model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
            endpoint to be created.
        model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
            Index or JMESPath expression to locate the predicted scores in the model output.
            Alternatively, it can be an instance of ModelPredictedLabelConfig to provide
            more parameters like label_headers; its ``label`` may then be None.
        wait (bool): Whether the call should wait until the job completes (default: False).
        logs (bool): Whether to show the logs produced by the job.
            Only meaningful when wait is True (default: False).
        job_name (str): Processing job name, handed to ``_generate_baselining_job_name``
            (default: None).
        kms_key (str): The ARN of the KMS key that is used to encrypt the
            user code file (default: None).

    Returns:
        sagemaker.processing.ProcessingJob: The ProcessingJob object representing the
            baselining job.
    """
    baselining_processor = self._create_baselining_processor()
    baselining_job_name = self._generate_baselining_job_name(job_name=job_name)
    baselining_processor.run_explainability(
        data_config=data_config,
        model_config=model_config,
        explainability_config=explainability_config,
        model_scores=model_scores,
        wait=wait,
        logs=logs,
        job_name=baselining_job_name,
        kms_key=kms_key,
    )

    # Explainability analysis doesn't need label. Work on a copy so the headers of
    # data_config stay untouched.
    headers = copy.deepcopy(data_config.headers)
    if headers and data_config.label in headers:
        headers.remove(data_config.label)

    if model_scores is None:
        inference_attribute = None
        label_headers = None
    elif isinstance(model_scores, ModelPredictedLabelConfig):
        # A config without a label must yield None, not the string "None".
        predicted_label = model_scores.label
        inference_attribute = None if predicted_label is None else str(predicted_label)
        label_headers = model_scores.label_headers
    else:
        inference_attribute = str(model_scores)
        label_headers = None

    self.latest_baselining_job_config = ClarifyBaseliningConfig(
        analysis_config=ExplainabilityAnalysisConfig(
            explainability_config=explainability_config,
            model_config=model_config,
            headers=headers,
            label_headers=label_headers,
        ),
        features_attribute=data_config.features,
        inference_attribute=inference_attribute,
    )
    self.latest_baselining_job_name = baselining_job_name
    self.latest_baselining_job = ClarifyBaseliningJob(
        processing_job=baselining_processor.latest_job
    )

    self.baselining_jobs.append(self.latest_baselining_job)
    return baselining_processor.latest_job
1014
+
1015
+ # noinspection PyMethodOverriding
1016
def create_monitoring_schedule(
    self,
    endpoint_input=None,
    analysis_config=None,
    output_s3_uri=None,
    constraints=None,
    monitor_schedule_name=None,
    schedule_cron_expression=None,
    enable_cloudwatch_metrics=True,
    batch_transform_input=None,
    data_analysis_start_time=None,
    data_analysis_end_time=None,
):
    """Creates a monitoring schedule.

    A model explainability job definition is created first, then the schedule is created
    from that job definition. If creating the schedule fails, deletion of the new job
    definition is attempted (a failure to delete is only logged) and the original
    exception is re-raised.

    Args:
        endpoint_input (str or sagemaker.model_monitor.EndpointInput): The endpoint to monitor.
            This can either be the endpoint name or an EndpointInput. (default: None)
        analysis_config (str or ExplainabilityAnalysisConfig): URI to the analysis_config for
            the explainability job. If it is None then configuration of the latest baselining
            job will be reused, but if no baselining job then fail the call.
        output_s3_uri (str): S3 destination of the constraint_violations and analysis result.
            Default: "s3://<default_session_bucket>/<job_name>/output"
        constraints (sagemaker.model_monitor.Constraints or str): If provided it will be used
            for monitoring the endpoint. It can be a Constraints object or an S3 uri pointing
            to a constraints JSON file.
        monitor_schedule_name (str): Schedule name. If not specified, the processor generates
            a default job name, based on the image name and current timestamp.
        schedule_cron_expression (str): The cron expression that dictates the frequency that
            this job run. See sagemaker.model_monitor.CronExpressionGenerator for valid
            expressions. Default: Daily.
        enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
            the baselining or monitoring jobs.
        batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
            run the monitoring schedule on the batch transform
        data_analysis_start_time (str): Start time for the data analysis window
            for the one time monitoring schedule (NOW), e.g. "-PT1H" (default: None)
        data_analysis_end_time (str): End time for the data analysis window
            for the one time monitoring schedule (NOW), e.g. "-PT1H" (default: None)

    Raises:
        ValueError: If this object already tracks a job definition or a monitoring
            schedule, or if not exactly one of ``endpoint_input`` and
            ``batch_transform_input`` is provided.
    """
    if self.job_definition_name is not None or self.monitoring_schedule_name is not None:
        message = (
            "It seems that this object was already used to create an Amazon Model "
            "Monitoring Schedule. To create another, first delete the existing one "
            "using my_monitor.delete_monitoring_schedule()."
        )
        logger.error(message)
        raise ValueError(message)

    # Exactly one input source is required: reject both "neither given" and "both given".
    if (batch_transform_input is None) == (endpoint_input is None):
        message = (
            "Need to have either batch_transform_input or endpoint_input to create an "
            "Amazon Model Monitoring Schedule. "
            "Please provide only one of the above required inputs"
        )
        logger.error(message)
        raise ValueError(message)

    self._check_monitoring_schedule_cron_validity(
        schedule_cron_expression=schedule_cron_expression,
        data_analysis_start_time=data_analysis_start_time,
        data_analysis_end_time=data_analysis_end_time,
    )

    # create job definition
    monitor_schedule_name = self._generate_monitoring_schedule_name(
        schedule_name=monitor_schedule_name
    )
    new_job_definition_name = name_from_base(self.JOB_DEFINITION_BASE_NAME)
    request_dict = self._build_create_job_definition_request(
        monitoring_schedule_name=monitor_schedule_name,
        job_definition_name=new_job_definition_name,
        image_uri=self.image_uri,
        latest_baselining_job_name=self.latest_baselining_job_name,
        latest_baselining_job_config=self.latest_baselining_job_config,
        endpoint_input=endpoint_input,
        analysis_config=analysis_config,
        output_s3_uri=self._normalize_monitoring_output(
            monitor_schedule_name, output_s3_uri
        ).s3_output.s3_uri,
        constraints=constraints,
        enable_cloudwatch_metrics=enable_cloudwatch_metrics,
        role=self.role,
        instance_count=self.instance_count,
        instance_type=self.instance_type,
        volume_size_in_gb=self.volume_size_in_gb,
        volume_kms_key=self.volume_kms_key,
        output_kms_key=self.output_kms_key,
        max_runtime_in_seconds=self.max_runtime_in_seconds,
        env=self.env,
        tags=self.tags,
        network_config=self.network_config,
        batch_transform_input=batch_transform_input,
    )
    self.sagemaker_session.sagemaker_client.create_model_explainability_job_definition(
        **request_dict
    )

    # create schedule
    try:
        self._create_monitoring_schedule_from_job_definition(
            monitor_schedule_name=monitor_schedule_name,
            job_definition_name=new_job_definition_name,
            schedule_cron_expression=schedule_cron_expression,
        )
        # Only record the new names once the schedule really exists.
        self.job_definition_name = new_job_definition_name
        self.monitoring_schedule_name = monitor_schedule_name
    except Exception:
        logger.exception("Failed to create monitoring schedule.")
        self.monitoring_schedule_name = None
        # Best-effort cleanup of the job definition created above; a cleanup failure is
        # logged but must not mask the original error, which is re-raised below.
        # noinspection PyBroadException
        try:
            self.sagemaker_session.sagemaker_client.delete_model_explainability_job_definition(
                JobDefinitionName=new_job_definition_name
            )
        except Exception:  # pylint: disable=W0703
            message = "Failed to delete job definition {}.".format(new_job_definition_name)
            logger.exception(message)
        raise
1135
+
1136
+ # noinspection PyMethodOverriding
1137
def update_monitoring_schedule(
    self,
    endpoint_input=None,
    analysis_config=None,
    output_s3_uri=None,
    constraints=None,
    schedule_cron_expression=None,
    enable_cloudwatch_metrics=None,
    role=None,
    instance_count=None,
    instance_type=None,
    volume_size_in_gb=None,
    volume_kms_key=None,
    output_kms_key=None,
    max_runtime_in_seconds=None,
    env=None,
    network_config=None,
    batch_transform_input=None,
    data_analysis_start_time=None,
    data_analysis_end_time=None,
):
    """Updates the existing monitoring schedule.

    When the cron expression is the only argument supplied, the schedule is updated in
    place against the current job definition. Otherwise a new model explainability job
    definition is created to hold the new options and the schedule is pointed at it;
    the old job definition is not deleted. If updating the schedule fails, deletion of
    the new job definition is attempted and the original exception is re-raised.

    Args:
        endpoint_input (str or sagemaker.model_monitor.EndpointInput): The endpoint to monitor.
            This can either be the endpoint name or an EndpointInput.
        analysis_config (str or ExplainabilityAnalysisConfig): URI to the analysis_config
            for the explainability job, forwarded to the job definition request.
        output_s3_uri (str): S3 destination of the constraint_violations and analysis result.
            Default: "s3://<default_session_bucket>/<job_name>/output"
        constraints (sagemaker.model_monitor.Constraints or str): If provided it will be used
            for monitoring the endpoint. It can be a Constraints object or an S3 uri pointing
            to a constraints JSON file.
        schedule_cron_expression (str): The cron expression that dictates the frequency that
            this job run. See sagemaker.model_monitor.CronExpressionGenerator for valid
            expressions. Default: Daily.
        enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
            the baselining or monitoring jobs.
        role (str): An AWS IAM role. The Amazon SageMaker jobs use this role.
        instance_count (int): The number of instances to run the jobs with.
        instance_type (str): Type of EC2 instance to use for the job, for example,
            'ml.m5.xlarge'.
        volume_size_in_gb (int): Size in GB of the EBS volume to use for storing data
            during processing (default: 30).
        volume_kms_key (str): A KMS key for the job's volume.
        output_kms_key (str): The KMS key id for the job's outputs.
        max_runtime_in_seconds (int): Timeout in seconds. After this amount of
            time, Amazon SageMaker terminates the job regardless of its current status.
            Default: 3600
        env (dict): Environment variables to be passed to the job.
        network_config (sagemaker.network.NetworkConfig): A NetworkConfig
            object that configures network isolation, encryption of
            inter-container traffic, security group IDs, and subnets.
        batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
            run the monitoring schedule on the batch transform
        data_analysis_start_time (str): Start time for the data analysis window
            for the one time monitoring schedule (NOW), e.g. "-PT1H" (default: None)
        data_analysis_end_time (str): End time for the data analysis window
            for the one time monitoring schedule (NOW), e.g. "-PT1H" (default: None)

    Raises:
        ValueError: If every argument is None, or if both ``endpoint_input`` and
            ``batch_transform_input`` are provided.
    """
    # NOTE: this must remain the first statement. It snapshots the call arguments via
    # locals(), so any local bound earlier would be counted as a supplied argument.
    valid_args = {
        name: value
        for name, value in locals().items()
        if name != "self" and value is not None
    }

    if not valid_args:
        raise ValueError("Nothing to update.")

    if endpoint_input is not None and batch_transform_input is not None:
        message = (
            "Need to have either batch_transform_input or endpoint_input to create an "
            "Amazon Model Monitoring Schedule. "
            "Please provide only one of the above required inputs"
        )
        logger.error(message)
        raise ValueError(message)

    # Fast path: the cron expression is the sole change, so the current job definition
    # can be kept as is.
    cron_is_only_change = schedule_cron_expression is not None and len(valid_args) == 1
    if cron_is_only_change:
        self._update_monitoring_schedule(self.job_definition_name, schedule_cron_expression)
        return

    # Anything else requires a fresh job definition derived from the existing one.
    sagemaker_client = self.sagemaker_session.sagemaker_client
    job_desc = sagemaker_client.describe_model_explainability_job_definition(
        JobDefinitionName=self.job_definition_name
    )
    new_job_definition_name = name_from_base(self.JOB_DEFINITION_BASE_NAME)
    request_dict = self._build_create_job_definition_request(
        monitoring_schedule_name=self.monitoring_schedule_name,
        job_definition_name=new_job_definition_name,
        image_uri=self.image_uri,
        existing_job_desc=job_desc,
        endpoint_input=endpoint_input,
        analysis_config=analysis_config,
        output_s3_uri=output_s3_uri,
        constraints=constraints,
        enable_cloudwatch_metrics=enable_cloudwatch_metrics,
        role=role,
        instance_count=instance_count,
        instance_type=instance_type,
        volume_size_in_gb=volume_size_in_gb,
        volume_kms_key=volume_kms_key,
        output_kms_key=output_kms_key,
        max_runtime_in_seconds=max_runtime_in_seconds,
        env=env,
        tags=self.tags,
        network_config=network_config,
        batch_transform_input=batch_transform_input,
    )
    sagemaker_client.create_model_explainability_job_definition(**request_dict)

    try:
        self._update_monitoring_schedule(
            new_job_definition_name,
            schedule_cron_expression,
            data_analysis_start_time=data_analysis_start_time,
            data_analysis_end_time=data_analysis_end_time,
        )
        self.job_definition_name = new_job_definition_name
        # Mirror every explicitly supplied job setting onto this monitor object.
        supplied_settings = (
            ("role", role),
            ("instance_count", instance_count),
            ("instance_type", instance_type),
            ("volume_size_in_gb", volume_size_in_gb),
            ("volume_kms_key", volume_kms_key),
            ("output_kms_key", output_kms_key),
            ("max_runtime_in_seconds", max_runtime_in_seconds),
            ("env", env),
            ("network_config", network_config),
        )
        for attribute_name, new_value in supplied_settings:
            if new_value is not None:
                setattr(self, attribute_name, new_value)
    except Exception:
        logger.exception("Failed to update monitoring schedule.")
        # Best-effort removal of the job definition created above; a failure here is
        # logged and the original exception is still re-raised.
        # noinspection PyBroadException
        try:
            sagemaker_client.delete_model_explainability_job_definition(
                JobDefinitionName=new_job_definition_name
            )
        except Exception:  # pylint: disable=W0703
            message = "Failed to delete job definition {}.".format(new_job_definition_name)
            logger.exception(message)
        raise
1293
+
1294
def delete_monitoring_schedule(self):
    """Deletes the monitoring schedule and then the job definition it used.

    The parent class deletes the schedule first; afterwards the model explainability
    job definition named by ``self.job_definition_name`` is deleted and that attribute
    is reset to None.
    """
    super(ModelExplainabilityMonitor, self).delete_monitoring_schedule()

    # Delete job definition.
    definition_name = self.job_definition_name
    logger.info(
        "Deleting Model Explainability Job Definition with name: {}".format(definition_name)
    )
    self.sagemaker_session.sagemaker_client.delete_model_explainability_job_definition(
        JobDefinitionName=definition_name
    )
    self.job_definition_name = None
1306
+
1307
@classmethod
def attach(cls, monitor_schedule_name, sagemaker_session=None):
    """Builds a monitor object bound to an existing monitoring schedule.

    This allows subsequent describe_schedule or list_executions calls to point
    to the given schedule.

    Args:
        monitor_schedule_name (str): The name of the schedule to attach to.
        sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
            manages interactions with Amazon SageMaker APIs and any other
            AWS services needed. If not specified, one is created using
            the default AWS configuration chain.

    Returns:
        The object produced by ``ClarifyModelMonitor._attach`` for ``cls``
        (presumably an instance of ``cls``; confirm against that helper).

    Raises:
        TypeError: If the schedule's monitoring type differs from ``cls.monitoring_type()``.
    """
    session = sagemaker_session or Session()
    schedule_desc = boto_describe_monitoring_schedule(
        sagemaker_session=session, monitoring_schedule_name=monitor_schedule_name
    )
    schedule_config = schedule_desc["MonitoringScheduleConfig"]

    # Refuse schedules that belong to a different kind of monitor.
    if schedule_config.get("MonitoringType") != cls.monitoring_type():
        raise TypeError(
            "{} can only attach to ModelExplainability schedule.".format(__class__.__name__)
        )

    job_desc = session.sagemaker_client.describe_model_explainability_job_definition(
        JobDefinitionName=schedule_config["MonitoringJobDefinitionName"]
    )
    tags = list_tags(
        sagemaker_session=session, resource_arn=schedule_desc["MonitoringScheduleArn"]
    )
    return ClarifyModelMonitor._attach(
        clazz=cls,
        sagemaker_session=session,
        schedule_desc=schedule_desc,
        job_desc=job_desc,
        tags=tags,
    )
1346
+
1347
+
1348
class ExplainabilityAnalysisConfig:
    """Analysis configuration for ModelExplainabilityMonitor."""

    def __init__(self, explainability_config, model_config, headers=None, label_headers=None):
        """Assembles the analysis config dictionary.

        Args:
            explainability_config (sagemaker.clarify.ExplainabilityConfig): Config object
                whose ``get_explainability_config()`` result becomes the "methods" entry.
            model_config (sagemaker.clarify.ModelConfig): Config object whose
                ``get_predictor_config()`` result becomes the "predictor" entry.
            headers (list[str]): A list of feature names (without label) of model/endpoint
                input. Stored under "headers" only when not None.
            label_headers (list[str]): List of headers, each for a predicted score in model
                output. It is used to beautify the analysis report by replacing placeholders
                like "label0". Stored inside the predictor config only when not None.
        """
        predictor_section = model_config.get_predictor_config()
        assembled = {
            "methods": explainability_config.get_explainability_config(),
            "predictor": predictor_section,
        }
        if headers is not None:
            assembled["headers"] = headers
        if label_headers is not None:
            # Same dict object as assembled["predictor"], so the entry shows up there.
            predictor_section["label_headers"] = label_headers
        self.analysis_config = assembled

    def _to_dict(self):
        """Returns the analysis config dictionary built by the constructor (not a copy)."""
        return self.analysis_config
1377
+
1378
+
1379
class ClarifyBaseliningConfig:
    """Data class to hold some essential analysis configuration of ClarifyBaseliningJob"""

    def __init__(
        self,
        analysis_config,
        features_attribute=None,
        inference_attribute=None,
        probability_attribute=None,
        probability_threshold_attribute=None,
    ):
        """Stores the given values unchanged as same-named attributes.

        Args:
            analysis_config (BiasAnalysisConfig or ExplainabilityAnalysisConfig): analysis config
                from configurations of the baselining job.
            features_attribute (str): JMESPath expression to locate features in predictor request
                payload. Only required when predictor content type is JSONlines.
            inference_attribute (str): Index, header or JMESPath expression to locate predicted
                label in predictor response payload.
            probability_attribute (str): Index or JMESPath expression to locate probabilities or
                scores in the model output for computing feature attribution.
            probability_threshold_attribute (float): Value to indicate the threshold to select
                the binary label in the case of binary classification (default: None).
        """
        self.analysis_config = analysis_config
        # The optional locators are kept verbatim; None means "not provided".
        (
            self.features_attribute,
            self.inference_attribute,
            self.probability_attribute,
            self.probability_threshold_attribute,
        ) = (
            features_attribute,
            inference_attribute,
            probability_attribute,
            probability_threshold_attribute,
        )
1409
+
1410
+
1411
class ClarifyBaseliningJob(mm.BaseliningJob):
    """Provides functionality to retrieve baseline-specific output from Clarify baselining job."""

    def __init__(
        self,
        processing_job,
    ):
        """Initializes a ClarifyBaseliningJob that tracks a baselining job by suggest_baseline()

        The session, job name, inputs, outputs and output KMS key are copied from the
        given processing job and handed to the parent constructor.

        Args:
            processing_job (sagemaker.processing.ProcessingJob): The ProcessingJob used for
                baselining instance.
        """
        super().__init__(
            sagemaker_session=processing_job.sagemaker_session,
            job_name=processing_job.job_name,
            inputs=processing_job.inputs,
            outputs=processing_job.outputs,
            output_kms_key=processing_job.output_kms_key,
        )

    def baseline_statistics(self, **_):
        """Not implemented.

        The class doesn't support statistics; any keyword arguments are ignored.

        Raises:
            NotImplementedError
        """
        message = "{} doesn't support statistics.".format(__class__.__name__)
        raise NotImplementedError(message)

    def suggested_constraints(self, file_name=None, kms_key=None):
        """Returns the Constraints object for the constraints file produced by this job.

        The lookup is delegated to the parent class with the fixed file name
        "analysis.json".

        Args:
            file_name (str): Keep this parameter to align with method signature in super class,
                but it will be ignored.
            kms_key (str): The kms key to use when retrieving the file.

        Returns:
            sagemaker.model_monitor.Constraints: The Constraints object representing the file that
                was generated by the job.

        Raises:
            UnexpectedStatusException: This is thrown if the job is not in a 'Complete' state.
        """
        # file_name is intentionally unused: the constraints file is always analysis.json.
        return super().suggested_constraints("analysis.json", kms_key)
1460
+
1461
+
1462
+ class ClarifyMonitoringExecution(mm.MonitoringExecution):
1463
+ """Provides functionality to retrieve monitoring-specific files output from executions."""
1464
+
1465
+ def __init__(self, sagemaker_session, job_name, inputs, output, output_kms_key=None):
1466
+ """Initializes an object that tracks a monitoring execution by a Clarify model monitor
1467
+
1468
+ Args:
1469
+ sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
1470
+ manages interactions with Amazon SageMaker APIs and any other
1471
+ AWS services needed. If not specified, one is created using
1472
+ the default AWS configuration chain.
1473
+ job_name (str): The name of the monitoring execution job.
1474
+ output (sagemaker.Processing.ProcessingOutput): The output associated with the
1475
+ monitoring execution.
1476
+ output_kms_key (str): The output kms key associated with the job. Defaults to None
1477
+ if not provided.
1478
+ """
1479
+ super(ClarifyMonitoringExecution, self).__init__(
1480
+ sagemaker_session=sagemaker_session,
1481
+ job_name=job_name,
1482
+ inputs=inputs,
1483
+ output=output,
1484
+ output_kms_key=output_kms_key,
1485
+ )
1486
+
1487
+ def statistics(self, **_):
1488
+ """Not implemented.
1489
+
1490
+ The class doesn't support statistics.
1491
+
1492
+ Raises:
1493
+ NotImplementedError
1494
+ """
1495
+ raise NotImplementedError("{} doesn't support statistics.".format(__class__.__name__))