sagemaker-core 1.0.62__py3-none-any.whl → 2.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (358)
  1. sagemaker/__init__.py +2 -0
  2. sagemaker/core/__init__.py +16 -0
  3. sagemaker/core/_studio.py +116 -0
  4. sagemaker/core/_version.py +11 -0
  5. sagemaker/core/accept_types.py +131 -0
  6. sagemaker/core/analytics.py +744 -0
  7. sagemaker/core/apiutils/__init__.py +13 -0
  8. sagemaker/core/apiutils/_base_types.py +228 -0
  9. sagemaker/core/apiutils/_boto_functions.py +130 -0
  10. sagemaker/core/apiutils/_utils.py +34 -0
  11. sagemaker/core/base_deserializers.py +35 -0
  12. sagemaker/core/base_serializers.py +35 -0
  13. sagemaker/core/clarify/__init__.py +2898 -0
  14. sagemaker/core/collection.py +467 -0
  15. sagemaker/core/common_utils.py +2399 -0
  16. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  17. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  18. sagemaker/core/config/__init__.py +181 -0
  19. sagemaker/core/config/config.py +238 -0
  20. sagemaker/core/config/config_manager.py +595 -0
  21. sagemaker/core/config/config_schema.py +1220 -0
  22. sagemaker/core/config/config_utils.py +297 -0
  23. {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
  24. sagemaker/core/constants.py +73 -0
  25. sagemaker/core/content_types.py +137 -0
  26. sagemaker/core/debugger/__init__.py +39 -0
  27. sagemaker/core/debugger/debugger.py +945 -0
  28. sagemaker/core/debugger/framework_profile.py +292 -0
  29. sagemaker/core/debugger/metrics_config.py +468 -0
  30. sagemaker/core/debugger/profiler.py +42 -0
  31. sagemaker/core/debugger/profiler_config.py +190 -0
  32. sagemaker/core/debugger/profiler_constants.py +40 -0
  33. sagemaker/core/debugger/utils.py +148 -0
  34. sagemaker/core/deprecations.py +254 -0
  35. sagemaker/core/deserializers/__init__.py +10 -0
  36. sagemaker/core/deserializers/base.py +424 -0
  37. sagemaker/core/deserializers/implementations.py +157 -0
  38. sagemaker/core/drift_check_baselines.py +106 -0
  39. sagemaker/core/enums.py +51 -0
  40. sagemaker/core/environment_variables.py +101 -0
  41. sagemaker/core/exceptions.py +108 -0
  42. sagemaker/core/experiments/__init__.py +53 -0
  43. sagemaker/core/experiments/_api_types.py +251 -0
  44. sagemaker/core/experiments/_environment.py +124 -0
  45. sagemaker/core/experiments/_helper.py +294 -0
  46. sagemaker/core/experiments/_metrics.py +333 -0
  47. sagemaker/core/experiments/_run_context.py +58 -0
  48. sagemaker/core/experiments/_utils.py +216 -0
  49. sagemaker/core/experiments/experiment.py +247 -0
  50. sagemaker/core/experiments/run.py +970 -0
  51. sagemaker/core/experiments/trial.py +296 -0
  52. sagemaker/core/experiments/trial_component.py +387 -0
  53. sagemaker/core/explainer/__init__.py +24 -0
  54. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  55. sagemaker/core/explainer/explainer_config.py +44 -0
  56. sagemaker/core/fw_utils.py +1220 -0
  57. sagemaker/core/git_utils.py +415 -0
  58. sagemaker/core/helper/pipeline_variable.py +82 -0
  59. sagemaker/core/helper/session_helper.py +2977 -0
  60. sagemaker/core/hyperparameters.py +172 -0
  61. sagemaker/core/image_retriever/__init__.py +3 -0
  62. sagemaker/core/image_retriever/image_retriever.py +640 -0
  63. sagemaker/core/image_retriever/image_retriever_utils.py +509 -0
  64. sagemaker/core/image_retriever/test.py +7 -0
  65. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  66. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  67. sagemaker/core/image_uri_config/chainer.json +104 -0
  68. sagemaker/core/image_uri_config/clarify.json +39 -0
  69. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  70. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  71. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  72. sagemaker/core/image_uri_config/debugger.json +34 -0
  73. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  74. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  75. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  76. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  77. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  78. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  79. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  80. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  81. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +770 -0
  82. sagemaker/core/image_uri_config/huggingface-llm.json +1267 -0
  83. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  84. sagemaker/core/image_uri_config/huggingface-neuronx.json +686 -0
  85. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  86. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  87. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  88. sagemaker/core/image_uri_config/huggingface-vllm-neuronx.json +38 -0
  89. sagemaker/core/image_uri_config/huggingface.json +2287 -0
  90. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  91. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  92. sagemaker/core/image_uri_config/image-classification.json +50 -0
  93. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  94. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  95. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  96. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  97. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  98. sagemaker/core/image_uri_config/kmeans.json +50 -0
  99. sagemaker/core/image_uri_config/knn.json +50 -0
  100. sagemaker/core/image_uri_config/lda.json +26 -0
  101. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  102. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  103. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  104. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  105. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  106. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  107. sagemaker/core/image_uri_config/ntm.json +50 -0
  108. sagemaker/core/image_uri_config/object-detection.json +50 -0
  109. sagemaker/core/image_uri_config/object2vec.json +50 -0
  110. sagemaker/core/image_uri_config/pca.json +50 -0
  111. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  112. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  113. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  114. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  115. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  116. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  117. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  118. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  119. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  120. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  121. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +252 -0
  122. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  123. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  124. sagemaker/core/image_uri_config/sklearn.json +494 -0
  125. sagemaker/core/image_uri_config/spark.json +280 -0
  126. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  127. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  128. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  129. sagemaker/core/image_uri_config/vw.json +25 -0
  130. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  131. sagemaker/core/image_uri_config/xgboost.json +972 -0
  132. sagemaker/core/image_uris.py +816 -0
  133. sagemaker/core/inference_config.py +144 -0
  134. sagemaker/core/inference_recommender/__init__.py +18 -0
  135. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  136. sagemaker/core/inputs.py +366 -0
  137. sagemaker/core/instance_group.py +61 -0
  138. sagemaker/core/instance_types.py +164 -0
  139. sagemaker/core/instance_types_gpu_info.py +43 -0
  140. sagemaker/core/interactive_apps/__init__.py +41 -0
  141. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  142. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  143. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  144. sagemaker/core/iterators.py +197 -0
  145. sagemaker/core/job.py +380 -0
  146. sagemaker/core/jumpstart/__init__.py +156 -0
  147. sagemaker/core/jumpstart/accessors.py +390 -0
  148. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  149. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  150. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  151. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  152. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  153. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  154. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  155. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  156. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  157. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  158. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  159. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  160. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  161. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  162. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  163. sagemaker/core/jumpstart/cache.py +663 -0
  164. sagemaker/core/jumpstart/configs.py +50 -0
  165. sagemaker/core/jumpstart/constants.py +198 -0
  166. sagemaker/core/jumpstart/deserializers.py +81 -0
  167. sagemaker/core/jumpstart/document.py +76 -0
  168. sagemaker/core/jumpstart/enums.py +168 -0
  169. sagemaker/core/jumpstart/exceptions.py +236 -0
  170. sagemaker/core/jumpstart/factory/utils.py +833 -0
  171. sagemaker/core/jumpstart/filters.py +597 -0
  172. sagemaker/core/jumpstart/hub/constants.py +16 -0
  173. sagemaker/core/jumpstart/hub/hub.py +291 -0
  174. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  175. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  176. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  177. sagemaker/core/jumpstart/hub/types.py +35 -0
  178. sagemaker/core/jumpstart/hub/utils.py +260 -0
  179. sagemaker/core/jumpstart/models.py +501 -0
  180. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  181. sagemaker/core/jumpstart/parameters.py +20 -0
  182. sagemaker/core/jumpstart/payload_utils.py +239 -0
  183. sagemaker/core/jumpstart/region_config.json +171 -0
  184. sagemaker/core/jumpstart/search.py +171 -0
  185. sagemaker/core/jumpstart/serializers.py +81 -0
  186. sagemaker/core/jumpstart/session_utils.py +234 -0
  187. sagemaker/core/jumpstart/types.py +3044 -0
  188. sagemaker/core/jumpstart/utils.py +1731 -0
  189. sagemaker/core/jumpstart/validators.py +257 -0
  190. sagemaker/core/lambda_helper.py +312 -0
  191. sagemaker/core/lineage/__init__.py +42 -0
  192. sagemaker/core/lineage/_api_types.py +239 -0
  193. sagemaker/core/lineage/_utils.py +49 -0
  194. sagemaker/core/lineage/action.py +345 -0
  195. sagemaker/core/lineage/artifact.py +646 -0
  196. sagemaker/core/lineage/association.py +190 -0
  197. sagemaker/core/lineage/context.py +505 -0
  198. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  199. sagemaker/core/lineage/query.py +732 -0
  200. sagemaker/core/lineage/visualizer.py +346 -0
  201. sagemaker/core/local/__init__.py +18 -0
  202. sagemaker/core/local/data.py +423 -0
  203. sagemaker/core/local/entities.py +678 -0
  204. sagemaker/core/local/exceptions.py +17 -0
  205. sagemaker/core/local/image.py +1243 -0
  206. sagemaker/core/local/local_session.py +739 -0
  207. sagemaker/core/local/utils.py +246 -0
  208. sagemaker/core/logs.py +181 -0
  209. sagemaker/core/metadata_properties.py +56 -0
  210. sagemaker/core/metric_definitions.py +91 -0
  211. sagemaker/core/mlflow/__init__.py +38 -0
  212. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  213. sagemaker/core/model_card/__init__.py +26 -0
  214. sagemaker/core/model_life_cycle.py +51 -0
  215. sagemaker/core/model_metrics.py +160 -0
  216. sagemaker/core/model_monitor/__init__.py +66 -0
  217. sagemaker/core/model_monitor/clarify_model_monitoring.py +1497 -0
  218. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  219. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  220. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  221. sagemaker/core/model_monitor/dataset_format.py +102 -0
  222. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  223. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  224. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  225. sagemaker/core/model_monitor/utils.py +793 -0
  226. sagemaker/core/model_registry.py +480 -0
  227. sagemaker/core/model_uris.py +97 -0
  228. sagemaker/core/modules/__init__.py +19 -0
  229. sagemaker/core/modules/configs.py +239 -0
  230. sagemaker/core/modules/constants.py +37 -0
  231. sagemaker/core/modules/distributed.py +182 -0
  232. sagemaker/core/modules/local_core/local_container.py +605 -0
  233. sagemaker/core/modules/templates.py +83 -0
  234. sagemaker/core/modules/train/__init__.py +14 -0
  235. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  236. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  237. sagemaker/core/modules/train/container_drivers/common/utils.py +205 -0
  238. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  239. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  240. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  241. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  242. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  243. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  244. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  245. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  246. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  247. sagemaker/core/modules/types.py +19 -0
  248. sagemaker/core/modules/utils.py +194 -0
  249. sagemaker/core/network.py +185 -0
  250. sagemaker/core/parameter.py +173 -0
  251. sagemaker/core/payloads.py +185 -0
  252. sagemaker/core/processing.py +1599 -0
  253. sagemaker/core/remote_function/__init__.py +19 -0
  254. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  255. sagemaker/core/remote_function/client.py +1310 -0
  256. sagemaker/core/remote_function/core/__init__.py +0 -0
  257. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  258. sagemaker/core/remote_function/core/pipeline_variables.py +347 -0
  259. sagemaker/core/remote_function/core/serialization.py +410 -0
  260. sagemaker/core/remote_function/core/stored_function.py +223 -0
  261. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  262. sagemaker/core/remote_function/errors.py +102 -0
  263. sagemaker/core/remote_function/invoke_function.py +167 -0
  264. sagemaker/core/remote_function/job.py +2121 -0
  265. sagemaker/core/remote_function/logging_config.py +38 -0
  266. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  267. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  268. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  269. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  270. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  271. sagemaker/core/remote_function/spark_config.py +149 -0
  272. sagemaker/core/resource_requirements.py +168 -0
  273. {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
  274. sagemaker/core/s3/__init__.py +41 -0
  275. sagemaker/core/s3/client.py +367 -0
  276. sagemaker/core/s3/utils.py +175 -0
  277. sagemaker/core/script_uris.py +93 -0
  278. sagemaker/core/serializers/__init__.py +11 -0
  279. sagemaker/core/serializers/base.py +510 -0
  280. sagemaker/core/serializers/implementations.py +159 -0
  281. sagemaker/core/serializers/utils.py +223 -0
  282. sagemaker/core/serverless_inference_config.py +63 -0
  283. sagemaker/core/session_settings.py +55 -0
  284. sagemaker/core/shapes/__init__.py +3 -0
  285. sagemaker/core/shapes/model_card_shapes.py +159 -0
  286. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
  287. sagemaker/core/spark/__init__.py +16 -0
  288. sagemaker/core/spark/defaults.py +16 -0
  289. sagemaker/core/spark/processing.py +1380 -0
  290. sagemaker/core/telemetry/__init__.py +23 -0
  291. sagemaker/core/telemetry/constants.py +82 -0
  292. sagemaker/core/telemetry/telemetry_logging.py +285 -0
  293. sagemaker/core/tools/__init__.py +1 -0
  294. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  295. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  296. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  297. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  298. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  299. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  300. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  301. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  302. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  303. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  304. sagemaker/core/training/__init__.py +14 -0
  305. sagemaker/core/training/configs.py +345 -0
  306. sagemaker/core/training/constants.py +37 -0
  307. sagemaker/core/training/utils.py +77 -0
  308. sagemaker/core/training_compiler/__init__.py +16 -0
  309. sagemaker/core/training_compiler/config.py +197 -0
  310. sagemaker/core/training_compiler_config.py +197 -0
  311. sagemaker/core/transformer.py +793 -0
  312. sagemaker/core/user_agent.py +76 -0
  313. sagemaker/core/utilities/__init__.py +24 -0
  314. sagemaker/core/utilities/cache.py +169 -0
  315. sagemaker/core/utilities/search_expression.py +133 -0
  316. sagemaker/core/utils/__init__.py +48 -0
  317. sagemaker/core/utils/code_injection/__init__.py +0 -0
  318. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  319. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
  320. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  321. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  322. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  323. {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
  324. sagemaker/core/workflow/__init__.py +152 -0
  325. sagemaker/core/workflow/conditions.py +313 -0
  326. sagemaker/core/workflow/entities.py +58 -0
  327. sagemaker/core/workflow/execution_variables.py +89 -0
  328. sagemaker/core/workflow/functions.py +193 -0
  329. sagemaker/core/workflow/parameters.py +222 -0
  330. sagemaker/core/workflow/pipeline_context.py +394 -0
  331. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  332. sagemaker/core/workflow/properties.py +285 -0
  333. sagemaker/core/workflow/step_outputs.py +65 -0
  334. sagemaker/core/workflow/utilities.py +514 -0
  335. sagemaker/lineage/__init__.py +33 -0
  336. sagemaker/lineage/action.py +28 -0
  337. sagemaker/lineage/artifact.py +28 -0
  338. sagemaker/lineage/context.py +28 -0
  339. sagemaker/lineage/lineage_trial_component.py +28 -0
  340. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/METADATA +28 -9
  341. sagemaker_core-2.3.1.dist-info/RECORD +351 -0
  342. sagemaker_core-2.3.1.dist-info/top_level.txt +1 -0
  343. sagemaker_core/_version.py +0 -3
  344. sagemaker_core/helper/session_helper.py +0 -769
  345. sagemaker_core/resources/__init__.py +0 -1
  346. sagemaker_core/shapes/__init__.py +0 -1
  347. sagemaker_core/tools/__init__.py +0 -1
  348. sagemaker_core-1.0.62.dist-info/RECORD +0 -35
  349. sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
  350. {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
  351. {sagemaker_core/helper → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  352. {sagemaker_core/main → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
  353. {sagemaker_core/main/code_injection → sagemaker/core/modules/local_core}/__init__.py +0 -0
  354. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  355. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  356. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  357. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/WHEEL +0 -0
  358. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/licenses/LICENSE +0 -0
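The moves above fold the old sagemaker_core top-level package into the sagemaker.core namespace (for example, sagemaker_core/main/resources.py becomes sagemaker/core/resources.py and sagemaker_core/main/shapes.py becomes sagemaker/core/shapes/shapes.py). The following is a minimal, hypothetical sketch of what those moves suggest for import paths when upgrading from 1.0.62 to 2.3.1; the TrainingJob and ResourceConfig names and the get() lookup are illustrative assumptions and should be confirmed against the actual public exports of the 2.3.1 wheel.

# Hypothetical migration sketch for the sagemaker_core -> sagemaker.core repackaging.
# Old layout (sagemaker-core 1.0.x):
#   from sagemaker_core.main.resources import TrainingJob
#   from sagemaker_core.main.shapes import ResourceConfig
# New layout (sagemaker-core 2.3.x), assuming public imports track the file moves above:
from sagemaker.core.resources import TrainingJob
from sagemaker.core.shapes import ResourceConfig

# Illustrative lookup of an existing training job through the resource class
# (method name and attribute are assumptions based on the 1.x resource API).
job = TrainingJob.get(training_job_name="my-training-job")
print(job.training_job_status)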
sagemaker/core/helper/session_helper.py (new file)
@@ -0,0 +1,2977 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
+ # may not use this file except in compliance with the License. A copy of
+ # the License is located at
+ #
+ #     http://aws.amazon.com/apache2.0/
+ #
+ # or in the "license" file accompanying this file. This file is
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+ # ANY KIND, either express or implied. See the License for the specific
+ # language governing permissions and limitations under the License.
+ from __future__ import absolute_import, annotations, print_function
+
+ import json
+ import logging
+ import os
+ import re
+ import sys
+ import time
+ import uuid
+ import six
+ import warnings
+ from functools import reduce
+ from typing import Dict, Any, Optional, List
+
+ import boto3
+ import botocore
+ import botocore.config
+ from botocore.exceptions import ClientError
+ from sagemaker.core import exceptions
+ from sagemaker.core.common_utils import (
+     secondary_training_status_changed,
+     secondary_training_status_message,
+ )
+ import sagemaker.core.logs
+ from sagemaker.core.session_settings import SessionSettings
+ from sagemaker.core.common_utils import (
+     secondary_training_status_changed,
+     secondary_training_status_message,
+     sts_regional_endpoint,
+     retries,
+     resolve_value_from_config,
+     get_sagemaker_config_value,
+     resolve_class_attribute_from_config,
+     resolve_nested_dict_value_from_config,
+     update_nested_dictionary_with_values_from_config,
+     update_list_of_dicts_with_values_from_config,
+     format_tags,
+     Tags,
+     TagsDict,
+     instance_supports_kms,
+     create_paginator_config,
+ )
+
+ from sagemaker.core.config.config_utils import _log_sagemaker_config_merge
+ from sagemaker.core._studio import _append_project_tags
+ from sagemaker.core.config.config import load_sagemaker_config, validate_sagemaker_config
+ from sagemaker.core.config.config_schema import (
+     KEY,
+     TRANSFORM_JOB,
+     TRANSFORM_JOB_ENVIRONMENT_PATH,
+     TRANSFORM_JOB_KMS_KEY_ID_PATH,
+     TRANSFORM_OUTPUT_KMS_KEY_ID_PATH,
+     VOLUME_KMS_KEY_ID,
+     TRANSFORM_JOB_VOLUME_KMS_KEY_ID_PATH,
+     MODEL,
+     MODEL_CONTAINERS_PATH,
+     MODEL_EXECUTION_ROLE_ARN_PATH,
+     MODEL_ENABLE_NETWORK_ISOLATION_PATH,
+     MODEL_PRIMARY_CONTAINER_PATH,
+     MODEL_VPC_CONFIG_PATH,
+     ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH,
+     KMS_KEY_ID,
+     ENDPOINT_CONFIG_KMS_KEY_ID_PATH,
+     ENDPOINT_CONFIG,
+     ENDPOINT_CONFIG_DATA_CAPTURE_PATH,
+     ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH,
+     ENDPOINT_CONFIG_VPC_CONFIG_PATH,
+     ENDPOINT_CONFIG_ENABLE_NETWORK_ISOLATION_PATH,
+     ENDPOINT_CONFIG_EXECUTION_ROLE_ARN_PATH,
+     ENDPOINT,
+     INFERENCE_COMPONENT,
+     SAGEMAKER,
+     TAGS,
+     SESSION_DEFAULT_S3_BUCKET_PATH,
+     SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH,
+ )
+
+ # Setting LOGGER for backward compatibility, in case users import it...
+ logger = LOGGER = logging.getLogger("sagemaker")
+
+ NOTEBOOK_METADATA_FILE = "/opt/ml/metadata/resource-metadata.json"
+ MODEL_MONITOR_ONE_TIME_SCHEDULE = "NOW"
+ _STATUS_CODE_TABLE = {
+     "COMPLETED": "Completed",
+     "INPROGRESS": "InProgress",
+     "IN_PROGRESS": "InProgress",
+     "FAILED": "Failed",
+     "STOPPED": "Stopped",
+     "STOPPING": "Stopping",
+     "STARTING": "Starting",
+     "PENDING": "Pending",
+ }
+ EP_LOGGER_POLL = 30
+ DEFAULT_EP_POLL = 30
+
+
+ class LogState(object):
+     """Placeholder docstring"""
+
+     STARTING = 1
+     WAIT_IN_PROGRESS = 2
+     TAILING = 3
+     JOB_COMPLETE = 4
+     COMPLETE = 5
+
+
+ class Session(object):  # pylint: disable=too-many-public-methods
+     """Manage interactions with the Amazon SageMaker APIs and any other AWS services needed.
+
+     This class provides convenient methods for manipulating entities and resources that Amazon
+     SageMaker uses, such as training jobs, endpoints, and input datasets in S3.
+     AWS service calls are delegated to an underlying Boto3 session, which by default
+     is initialized using the AWS configuration chain. When you make an Amazon SageMaker API call
+     that accesses an S3 bucket location and one is not specified, the ``Session`` creates a default
+     bucket based on a naming convention which includes the current AWS account ID.
+     """
+
+     def __init__(
+         self,
+         boto_session=None,
+         sagemaker_client=None,
+         sagemaker_runtime_client=None,
+         sagemaker_featurestore_runtime_client=None,
+         default_bucket=None,
+         sagemaker_config: dict = None,
+         settings=None,
+         sagemaker_metrics_client=None,
+         default_bucket_prefix: str = None,
+     ):
+         """Initialize a SageMaker ``Session``.
+
+         Args:
+             boto_session (boto3.session.Session): The underlying Boto3 session which AWS service
+                 calls are delegated to (default: None). If not provided, one is created with
+                 default AWS configuration chain.
+             sagemaker_client (boto3.SageMaker.Client): Client which makes Amazon SageMaker service
+                 calls other than ``InvokeEndpoint`` (default: None). Estimators created using this
+                 ``Session`` use this client. If not provided, one will be created using this
+                 instance's ``boto_session``.
+             sagemaker_runtime_client (boto3.SageMakerRuntime.Client): Client which makes
+                 ``InvokeEndpoint`` calls to Amazon SageMaker (default: None). Predictors created
+                 using this ``Session`` use this client. If not provided, one will be created using
+                 this instance's ``boto_session``.
+             sagemaker_featurestore_runtime_client (boto3.SageMakerFeatureStoreRuntime.Client):
+                 Client which makes SageMaker FeatureStore record related calls to Amazon SageMaker
+                 (default: None). If not provided, one will be created using
+                 this instance's ``boto_session``.
+             default_bucket (str): The default Amazon S3 bucket to be used by this session.
+                 This will be created the next time an Amazon S3 bucket is needed (by calling
+                 :func:`default_bucket`).
+                 If not provided, it will be fetched from the sagemaker_config. If not configured
+                 there either, a default bucket will be created based on the following format:
+                 "sagemaker-{region}-{aws-account-id}".
+                 Example: "sagemaker-my-custom-bucket".
+             sagemaker_metrics_client (boto3.SageMakerMetrics.Client):
+                 Client which makes SageMaker Metrics related calls to Amazon SageMaker
+                 (default: None). If not provided, one will be created using
+                 this instance's ``boto_session``.
+             default_bucket_prefix (str): The default prefix to use for S3 Object Keys. (default:
+                 None). If provided and where applicable, it will be used by the SDK to construct
+                 default S3 URIs, in the format:
+                 `s3://{default_bucket}/{default_bucket_prefix}/<rest of object key>`
+                 This parameter can also be specified via `{sagemaker_config}` instead of here. If
+                 not provided here or within `{sagemaker_config}`, default S3 URIs will have the
+                 format: `s3://{default_bucket}/<rest of object key>`
+         """
+
+         # sagemaker_config is validated and initialized inside :func:`_initialize`,
+         # so if default_bucket is None and the sagemaker_config has a default S3 bucket configured,
+         # _default_bucket_name_override will be set again inside :func:`_initialize`.
+         self.endpoint_arn = None
+         self._default_bucket = None
+         self._default_bucket_name_override = default_bucket
+         # this may also be set again inside :func:`_initialize` if it is None
+         self.default_bucket_prefix = default_bucket_prefix
+         self._default_bucket_set_by_sdk = False
+
+         self.s3_resource = None
+         self.s3_client = None
+         self.resource_groups_client = None
+         self.resource_group_tagging_client = None
+         self._config = None
+         self.lambda_client = None
+         self.settings = settings if settings else SessionSettings()
+
+         self._initialize(
+             boto_session=boto_session,
+             sagemaker_client=sagemaker_client,
+             sagemaker_runtime_client=sagemaker_runtime_client,
+             sagemaker_featurestore_runtime_client=sagemaker_featurestore_runtime_client,
+             sagemaker_metrics_client=sagemaker_metrics_client,
+             sagemaker_config=sagemaker_config,
+         )
+
+     def _initialize(
+         self,
+         boto_session,
+         sagemaker_client,
+         sagemaker_runtime_client,
+         sagemaker_featurestore_runtime_client,
+         sagemaker_metrics_client,
+         sagemaker_config: dict = None,
+     ):
+         """Initialize this SageMaker Session.
+
+         Creates or uses a boto_session, sagemaker_client and sagemaker_runtime_client.
+         Sets the region_name.
+         """
+
+         self.boto_session = boto_session or boto3.DEFAULT_SESSION or boto3.Session()
+
+         self._region_name = self.boto_session.region_name
+         if self._region_name is None:
+             raise ValueError(
+                 "Must setup local AWS configuration with a region supported by SageMaker."
+             )
+
+         # Make use of user_agent_extra field of the botocore_config object
+         # to append SageMaker Python SDK specific user_agent suffix
+         # to the current User-Agent header value from boto3
+         # This config will also make sure that user_agent never fails to log the User-Agent string
+         # even if boto User-Agent header format is updated in the future
+         # Ref: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+
+         # Create sagemaker_client with the botocore_config object
+         # This config is customized to append SageMaker Python SDK specific user_agent suffix
+         if sagemaker_client is not None:
+             self.sagemaker_client = sagemaker_client
+         else:
+             from sagemaker.core.user_agent import get_user_agent_extra_suffix
+             config = botocore.config.Config(user_agent_extra=get_user_agent_extra_suffix())
+             self.sagemaker_client = self.boto_session.client("sagemaker", config=config)
+
+         if sagemaker_runtime_client is not None:
+             self.sagemaker_runtime_client = sagemaker_runtime_client
+         else:
+             config = botocore.config.Config(read_timeout=80)
+             self.sagemaker_runtime_client = self.boto_session.client(
+                 "runtime.sagemaker", config=config
+             )
+
+         if sagemaker_featurestore_runtime_client:
+             self.sagemaker_featurestore_runtime_client = sagemaker_featurestore_runtime_client
+         else:
+             self.sagemaker_featurestore_runtime_client = self.boto_session.client(
+                 "sagemaker-featurestore-runtime"
+             )
+
+         if sagemaker_metrics_client:
+             self.sagemaker_metrics_client = sagemaker_metrics_client
+         else:
+             self.sagemaker_metrics_client = self.boto_session.client("sagemaker-metrics")
+
+         self.s3_client = self.boto_session.client("s3", region_name=self.boto_region_name)
+         self.s3_resource = self.boto_session.resource("s3", region_name=self.boto_region_name)
+
+         self.local_mode = False
+
+         if sagemaker_config:
+             validate_sagemaker_config(sagemaker_config)
+             self.sagemaker_config = sagemaker_config
+         else:
+             # self.s3_resource might be None. If it is None, load_sagemaker_config will
+             # create a default S3 resource, but only if it needs to fetch from S3
+             self.sagemaker_config = load_sagemaker_config(s3_resource=self.s3_resource)
+
+         # after sagemaker_config initialization, update self._default_bucket_name_override if needed
+         self._default_bucket_name_override = resolve_value_from_config(
+             direct_input=self._default_bucket_name_override,
+             config_path=SESSION_DEFAULT_S3_BUCKET_PATH,
+             sagemaker_session=self,
+         )
+         # after sagemaker_config initialization, update self.default_bucket_prefix if needed
+         self.default_bucket_prefix = resolve_value_from_config(
+             direct_input=self.default_bucket_prefix,
+             config_path=SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH,
+             sagemaker_session=self,
+         )
+
+     def account_id(self) -> str:
+         """Get the AWS account id of the caller.
+
+         Returns:
+             AWS account ID.
+         """
+         region = self.boto_session.region_name
+         sts_client = self.boto_session.client(
+             "sts", region_name=region, endpoint_url=sts_regional_endpoint(region)
+         )
+         return sts_client.get_caller_identity()["Account"]
+
+     @property
+     def config(self) -> Dict | None:
+         """The config for the local mode, unused in a normal session"""
+         return self._config
+
+     @config.setter
+     def config(self, value: Dict | None):
+         """The config for the local mode, unused in a normal session"""
+         self._config = value
+
+     @property
+     def boto_region_name(self):
+         """Placeholder docstring"""
+         return self._region_name
+
+     def get_caller_identity_arn(self):
+         """Returns the ARN of the user or role whose credentials are used to call the API.
+
+         Returns:
+             str: The ARN of the user or role
+         """
+         if os.path.exists(NOTEBOOK_METADATA_FILE):
+             with open(NOTEBOOK_METADATA_FILE, "rb") as f:
+                 metadata = json.loads(f.read())
+                 instance_name = metadata.get("ResourceName")
+                 domain_id = metadata.get("DomainId")
+                 user_profile_name = metadata.get("UserProfileName")
+                 execution_role_arn = metadata.get("ExecutionRoleArn")
+             try:
+                 # find execution role from the metadata file if present
+                 if execution_role_arn is not None:
+                     return execution_role_arn
+
+                 if domain_id is None:
+                     instance_desc = self.sagemaker_client.describe_notebook_instance(
+                         NotebookInstanceName=instance_name
+                     )
+                     return instance_desc["RoleArn"]
+
+                 user_profile_desc = self.sagemaker_client.describe_user_profile(
+                     DomainId=domain_id, UserProfileName=user_profile_name
+                 )
+
+                 # First, try to find role in userSettings
+                 if user_profile_desc.get("UserSettings", {}).get("ExecutionRole"):
+                     return user_profile_desc["UserSettings"]["ExecutionRole"]
+
+                 # If not found, fallback to the domain
+                 domain_desc = self.sagemaker_client.describe_domain(DomainId=domain_id)
+                 return domain_desc["DefaultUserSettings"]["ExecutionRole"]
+             except ClientError:
+                 logger.debug(
+                     "Couldn't call 'describe_notebook_instance' to get the Role "
+                     "ARN of the instance %s.",
+                     instance_name,
+                 )
+
+         assumed_role = self.boto_session.client(
+             "sts",
+             region_name=self.boto_region_name,
+             endpoint_url=sts_regional_endpoint(self.boto_region_name),
+         ).get_caller_identity()["Arn"]
+
+         role = re.sub(r"^(.+)sts::(\d+):assumed-role/(.+?)/.*$", r"\1iam::\2:role/\3", assumed_role)
+
+         # Call IAM to get the role's path
+         role_name = role[role.rfind("/") + 1 :]
+         try:
+             role = self.boto_session.client("iam").get_role(RoleName=role_name)["Role"]["Arn"]
+         except ClientError:
+             logger.warning(
+                 "Couldn't call 'get_role' to get Role ARN from role name %s to get Role path.",
+                 role_name,
+             )
+
+             # This conditional has been present since the inception of SageMaker
+             # Guessing this conditional's purpose was to handle lack of IAM permissions
+             # https://github.com/aws/sagemaker-python-sdk/issues/2089#issuecomment-791802713
+             if "AmazonSageMaker-ExecutionRole" in assumed_role:
+                 logger.warning(
+                     "Assuming role was created in SageMaker AWS console, "
+                     "as the name contains `AmazonSageMaker-ExecutionRole`. "
+                     "Defaulting to Role ARN with service-role in path. "
+                     "If this Role ARN is incorrect, please add "
+                     "IAM read permissions to your role or supply the "
+                     "Role Arn directly."
+                 )
+                 role = re.sub(
+                     r"^(.+)sts::(\d+):assumed-role/(.+?)/.*$",
+                     r"\1iam::\2:role/service-role/\3",
+                     assumed_role,
+                 )
+
+         return role
+
+     def upload_data(self, path, bucket=None, key_prefix="data", callback=None, extra_args=None):
+         """Upload local file or directory to S3.
+
+         If a single file is specified for upload, the resulting S3 object key is
+         ``{key_prefix}/{filename}`` (filename does not include the local path, if any specified).
+         If a directory is specified for upload, the API uploads all content, recursively,
+         preserving relative structure of subdirectories. The resulting object key names are:
+         ``{key_prefix}/{relative_subdirectory_path}/filename``.
+
+         Args:
+             path (str): Path (absolute or relative) of local file or directory to upload.
+             bucket (str): Name of the S3 Bucket to upload to (default: None). If not specified, the
+                 default bucket of the ``Session`` is used (if default bucket does not exist, the
+                 ``Session`` creates it).
+             key_prefix (str): Optional S3 object key name prefix (default: 'data'). S3 uses the
+                 prefix to create a directory structure for the bucket content that it displays in
+                 the S3 console.
+             extra_args (dict): Optional extra arguments that may be passed to the upload operation.
+                 Similar to ExtraArgs parameter in S3 upload_file function. Please refer to the
+                 ExtraArgs parameter documentation here:
+                 https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-uploading-files.html#the-extraargs-parameter
+
+         Returns:
+             str: The S3 URI of the uploaded file(s). If a file is specified in the path argument,
+                 the URI format is: ``s3://{bucket name}/{key_prefix}/{original_file_name}``.
+                 If a directory is specified in the path argument, the URI format is
+                 ``s3://{bucket name}/{key_prefix}``.
+         """
+         bucket, key_prefix = self.determine_bucket_and_prefix(
+             bucket=bucket, key_prefix=key_prefix, sagemaker_session=self
+         )
+
+         # Generate a tuple for each file that we want to upload of the form (local_path, s3_key).
+         files = []
+         key_suffix = None
+         if os.path.isdir(path):
+             for dirpath, _, filenames in os.walk(path):
+                 for name in filenames:
+                     local_path = os.path.join(dirpath, name)
+                     s3_relative_prefix = (
+                         "" if path == dirpath else os.path.relpath(dirpath, start=path) + "/"
+                     )
+                     s3_key = "{}/{}{}".format(key_prefix, s3_relative_prefix, name)
+                     files.append((local_path, s3_key))
+         else:
+             _, name = os.path.split(path)
+             s3_key = "{}/{}".format(key_prefix, name)
+             files.append((path, s3_key))
+             key_suffix = name
+
+         if self.s3_resource is None:
+             s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)
+         else:
+             s3 = self.s3_resource
+
+         for local_path, s3_key in files:
+             s3.Object(bucket, s3_key).upload_file(
+                 local_path, Callback=callback, ExtraArgs=extra_args
+             )
+
+         s3_uri = "s3://{}/{}".format(bucket, key_prefix)
+         # If a specific file was used as input (instead of a directory), we return the full S3 key
+         # of the uploaded object. This prevents unintentionally using other files under the same
+         # prefix during training.
+         if key_suffix:
+             s3_uri = "{}/{}".format(s3_uri, key_suffix)
+         return s3_uri
+
+     def upload_string_as_file_body(self, body, bucket, key, kms_key=None):
+         """Upload a string as a file body.
+         Args:
+             body (str): String representing the body of the file.
+             bucket (str): Name of the S3 Bucket to upload to (default: None). If not specified, the
+                 default bucket of the ``Session`` is used (if default bucket does not exist, the
+                 ``Session`` creates it).
+             key (str): S3 object key. This is the s3 path to the file.
+             kms_key (str): The KMS key to use for encrypting the file.
+         Returns:
+             str: The S3 URI of the uploaded file.
+                 The URI format is: ``s3://{bucket name}/{key}``.
+         """
+         if self.s3_resource is None:
+             s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)
+         else:
+             s3 = self.s3_resource
+
+         s3_object = s3.Object(bucket_name=bucket, key=key)
+
+         if kms_key is not None:
+             s3_object.put(Body=body, SSEKMSKeyId=kms_key, ServerSideEncryption="aws:kms")
+         else:
+             s3_object.put(Body=body)
+
+         s3_uri = "s3://{}/{}".format(bucket, key)
+         return s3_uri
+
+     def download_data(self, path, bucket, key_prefix="", extra_args=None):
+         """Download file or directory from S3.
+         Args:
+             path (str): Local path where the file or directory should be downloaded to.
+             bucket (str): Name of the S3 Bucket to download from.
+             key_prefix (str): Optional S3 object key name prefix.
+             extra_args (dict): Optional extra arguments that may be passed to the
+                 download operation. Please refer to the ExtraArgs parameter in the boto3
+                 documentation here:
+                 https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-example-download-file.html
+         Returns:
+             list[str]: List of local paths of downloaded files
+         """
+         # Initialize the S3 client.
+         if self.s3_client is None:
+             s3 = self.boto_session.client("s3", region_name=self.boto_region_name)
+         else:
+             s3 = self.s3_client
+
+         # Initialize the variables used to loop through the contents of the S3 bucket.
+         keys = []
+         directories = []
+         next_token = ""
+         base_parameters = {"Bucket": bucket, "Prefix": key_prefix}
+
+         # Loop through the contents of the bucket, 1,000 objects at a time. Gathering all keys into
+         # a "keys" list.
+         while next_token is not None:
+             request_parameters = base_parameters.copy()
+             if next_token != "":
+                 request_parameters.update({"ContinuationToken": next_token})
+             response = s3.list_objects_v2(**request_parameters)
+             contents = response.get("Contents", None)
+             if not contents:
+                 logger.info(
+                     "Nothing to download from bucket: %s, key_prefix: %s.", bucket, key_prefix
+                 )
+                 return []
+             # For each object, save its key or directory.
+             for s3_object in contents:
+                 key: str = s3_object.get("Key")
+                 obj_size = s3_object.get("Size")
+                 if key.endswith("/") and int(obj_size) == 0:
+                     directories.append(os.path.join(path, key))
+                 else:
+                     keys.append(key)
+             next_token = response.get("NextContinuationToken")
+
+         # For each object key, create the directory on the local machine if needed, and then
+         # download the file.
+         downloaded_paths = []
+         for dir_path in directories:
+             os.makedirs(os.path.dirname(dir_path), exist_ok=True)
+         for key in keys:
+             tail_s3_uri_path = os.path.basename(key)
+             if not os.path.splitext(key_prefix)[1]:
+                 tail_s3_uri_path = os.path.relpath(key, key_prefix)
+             destination_path = os.path.join(path, tail_s3_uri_path)
+             if not os.path.exists(os.path.dirname(destination_path)):
+                 os.makedirs(os.path.dirname(destination_path), exist_ok=True)
+             s3.download_file(
+                 Bucket=bucket, Key=key, Filename=destination_path, ExtraArgs=extra_args
+             )
+             downloaded_paths.append(destination_path)
+         return downloaded_paths
+
+     def read_s3_file(self, bucket, key_prefix):
+         """Read a single file from S3.
+
+         Args:
+             bucket (str): Name of the S3 Bucket to download from.
+             key_prefix (str): S3 object key name prefix.
+
+         Returns:
+             str: The body of the s3 file as a string.
+         """
+         if self.s3_client is None:
+             s3 = self.boto_session.client("s3", region_name=self.boto_region_name)
+         else:
+             s3 = self.s3_client
+
+         # Explicitly passing a None kms_key to boto3 throws a validation error.
+         s3_object = s3.get_object(Bucket=bucket, Key=key_prefix)
+
+         return s3_object["Body"].read().decode("utf-8")
+
+     def list_s3_files(self, bucket, key_prefix):
+         """Lists the S3 files given an S3 bucket and key.
+         Args:
+             bucket (str): Name of the S3 Bucket to download from.
+             key_prefix (str): S3 object key name prefix.
+         Returns:
+             [str]: The list of files at the S3 path.
+         """
+         if self.s3_resource is None:
+             s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)
+         else:
+             s3 = self.s3_resource
+
+         s3_bucket = s3.Bucket(name=bucket)
+         s3_objects = s3_bucket.objects.filter(Prefix=key_prefix).all()
+         return [s3_object.key for s3_object in s3_objects]
+
+     def default_bucket(self):
+         """Return the name of the default bucket to use in relevant Amazon SageMaker interactions.
+
+         This function will create the s3 bucket if it does not exist.
+
+         Returns:
+             str: The name of the default bucket. If the name was not explicitly specified through
+                 the Session or sagemaker_config, the bucket will take the form:
+                 ``sagemaker-{region}-{AWS account ID}``.
+         """
+
+         if self._default_bucket:
+             return self._default_bucket
+
+         region = self.boto_session.region_name
+
+         default_bucket = self._default_bucket_name_override
+         if not default_bucket:
+             default_bucket = self.generate_default_sagemaker_bucket_name(self.boto_session)
+             self._default_bucket_set_by_sdk = True
+
+         self._create_s3_bucket_if_it_does_not_exist(
+             bucket_name=default_bucket,
+             region=region,
+         )
+
+         self._default_bucket = default_bucket
+
+         return self._default_bucket
+
+     def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
+         """Creates an S3 Bucket if it does not exist.
+
+         Also swallows a few common exceptions that indicate that the bucket already exists or
+         that it is being created.
+
+         Args:
+             bucket_name (str): Name of the S3 bucket to be created.
+             region (str): The region in which to create the bucket.
+
+         Raises:
+             botocore.exceptions.ClientError: If S3 throws an unexpected exception during bucket
+                 creation.
+                 If the exception is due to the bucket already existing or
+                 already being created, no exception is raised.
+         """
+         if self.s3_resource is None:
+             s3 = self.boto_session.resource("s3", region_name=region)
+         else:
+             s3 = self.s3_resource
+
+         bucket = s3.Bucket(name=bucket_name)
+         if bucket.creation_date is None:
+             self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, True)
+
+         elif self._default_bucket_set_by_sdk:
+             self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False)
+
+             expected_bucket_owner_id = self.account_id()
+             self.expected_bucket_owner_id_bucket_check(bucket_name, s3, expected_bucket_owner_id)
+
+     def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket_owner_id):
+         """Checks if the bucket belongs to a particular owner and throws a Client Error if it does not
+
+         Args:
+             bucket_name (str): Name of the S3 bucket
+             s3 (str): S3 object from boto session
+             expected_bucket_owner_id (str): Owner ID string
+
+         """
+         try:
+             if self.default_bucket_prefix:
+                 s3.meta.client.list_objects_v2(
+                     Bucket=bucket_name,
+                     Prefix=self.default_bucket_prefix,
+                     ExpectedBucketOwner=expected_bucket_owner_id,
+                 )
+             else:
+                 s3.meta.client.head_bucket(
+                     Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
+                 )
+         except ClientError as e:
+             error_code = e.response["Error"]["Code"]
+             message = e.response["Error"]["Message"]
+             if error_code == "403" and message == "Forbidden":
+                 LOGGER.error(
+                     "Since default_bucket param was not set, SageMaker Python SDK tried to use "
+                     "%s bucket. "
+                     "This bucket cannot be configured to use as it is not owned by Account %s. "
+                     "To unblock it's recommended to use custom default_bucket "
+                     "parameter in sagemaker.Session",
+                     bucket_name,
+                     expected_bucket_owner_id,
+                 )
+                 raise
+
+     def general_bucket_check_if_user_has_permission(
+         self, bucket_name, s3, bucket, region, bucket_creation_date_none
+     ):
+         """Checks if the person running has the permissions to the bucket
+
+         If there is any other error that comes up with calling head bucket, it is raised up here
+         If there is no bucket, it will create one
+
+         Args:
+             bucket_name (str): Name of the S3 bucket
+             s3 (str): S3 object from boto session
+             region (str): The region in which to create the bucket.
+             bucket_creation_date_none (bool): Indicates whether the S3 bucket already exists or not
+         """
+         try:
+             if self.default_bucket_prefix:
+                 s3.meta.client.list_objects_v2(
+                     Bucket=bucket_name, Prefix=self.default_bucket_prefix
+                 )
+             else:
+                 s3.meta.client.head_bucket(Bucket=bucket_name)
+         except ClientError as e:
+             error_code = e.response["Error"]["Code"]
+             message = e.response["Error"]["Message"]
+             # bucket does not exist or forbidden to access
+             if bucket_creation_date_none:
+                 if error_code == "404" and message == "Not Found":
+                     self.create_bucket_for_not_exist_error(bucket_name, region, s3)
+                 elif error_code == "403" and message == "Forbidden":
+                     LOGGER.error(
+                         "Bucket %s exists, but access is forbidden. Please try again after "
+                         "adding appropriate access.",
+                         bucket.name,
+                     )
+                     raise
+                 else:
+                     raise
+
+     def create_bucket_for_not_exist_error(self, bucket_name, region, s3):
+         """Creates the S3 bucket in the given region
+
+         Args:
+             bucket_name (str): Name of the S3 bucket
+             s3 (str): S3 object from boto session
+             region (str): The region in which to create the bucket.
+         """
+         # bucket does not exist, create one
+         try:
+             if region == "us-east-1":
+                 # 'us-east-1' cannot be specified because it is the default region:
+                 # https://github.com/boto/boto3/issues/125
+                 s3.create_bucket(Bucket=bucket_name)
+             else:
+                 s3.create_bucket(
+                     Bucket=bucket_name,
+                     CreateBucketConfiguration={"LocationConstraint": region},
+                 )
+
+             logger.info("Created S3 bucket: %s", bucket_name)
+         except ClientError as e:
+             error_code = e.response["Error"]["Code"]
+             message = e.response["Error"]["Message"]
+
+             if error_code == "OperationAborted" and "conflicting conditional operation" in message:
+                 # If this bucket is already being concurrently created,
+                 # we don't need to create it again.
+                 pass
+             else:
+                 raise
+
+     def generate_default_sagemaker_bucket_name(self, boto_session):
+         """Generates a name for the default sagemaker S3 bucket.
+
+         Args:
+             boto_session (boto3.session.Session): The underlying Boto3 session which AWS service
+         """
+         region = boto_session.region_name
+         account = boto_session.client(
+             "sts", region_name=region, endpoint_url=sts_regional_endpoint(region)
+         ).get_caller_identity()["Account"]
+         return "sagemaker-{}-{}".format(region, account)
+
+     def determine_bucket_and_prefix(
+         self, bucket: Optional[str] = None, key_prefix: Optional[str] = None, sagemaker_session=None
+     ):
+         """Helper function that returns the correct S3 bucket and prefix to use depending on the inputs.
+
+         Args:
+             bucket (Optional[str]): S3 Bucket to use (if it exists)
+             key_prefix (Optional[str]): S3 Object Key Prefix to use or append to (if it exists)
+             sagemaker_session (sagemaker.core.session.Session): Session to fetch a default bucket and
+                 prefix from, if bucket doesn't exist. Expected to exist
+
+         Returns: The correct S3 Bucket and S3 Object Key Prefix that should be used
+         """
+         if bucket:
+             final_bucket = bucket
+             final_key_prefix = key_prefix
+         else:
+             final_bucket = sagemaker_session.default_bucket()
+
+             # default_bucket_prefix (if it exists) should be appended if (and only if) 'bucket' does not
+             # exist and we are using the Session's default_bucket.
+             final_key_prefix = s3_path_join(sagemaker_session.default_bucket_prefix, key_prefix)
+
+         # We should not append default_bucket_prefix even if the bucket exists but is equal to the
+         # default_bucket, because either:
+         # (1) the bucket was explicitly passed in by the user and just happens to be the same as the
+         #     default_bucket (in which case we don't want to change the user's input), or
+         # (2) the default_bucket was fetched from Session earlier already (and the default prefix
+         #     should have been fetched then as well), and then this function was
+         #     called with it. If we appended the default prefix here, we would be appending it more than
+         #     once in total.
+
+         return final_bucket, final_key_prefix
+
+     def _append_sagemaker_config_tags(self, tags: List[TagsDict], config_path_to_tags: str):
+         """Appends tags specified in the sagemaker_config to the given list of tags.
+
+         To minimize the chance of duplicate tags being applied, this is intended to be used
+         immediately before calls to sagemaker_client, rather than during initialization of
+         classes like EstimatorBase.
+
+         Args:
+             tags: The list of tags to append to.
+             config_path_to_tags: The path to look up tags in the config.
+
+         Returns:
+             A list of tags.
+         """
+         config_tags = get_sagemaker_config_value(self, config_path_to_tags)
+
+         if config_tags is None or len(config_tags) == 0:
+             return tags
+
+         all_tags = tags or []
+         for config_tag in config_tags:
+             config_tag_key = config_tag[KEY]
+             if not any(tag.get("Key", None) == config_tag_key for tag in all_tags):
+                 # This check prevents new tags with duplicate keys from being added
+                 # (to prevent API failure and/or overwriting of tags). If there is a conflict,
+                 # the user-provided tag should take precedence over the config-provided tag.
+                 # Note: this does not check user-provided tags for conflicts with other
+                 # user-provided tags.
+                 all_tags.append(config_tag)
+
+         _log_sagemaker_config_merge(
+             source_value=tags,
+             config_value=config_tags,
+             merged_source_and_config_value=all_tags,
+             config_key_path=config_path_to_tags,
+         )
+
+         return all_tags
+
849
+ def endpoint_from_production_variants(
850
+ self,
851
+ name,
852
+ production_variants,
853
+ tags=None,
854
+ kms_key=None,
855
+ wait=True,
856
+ data_capture_config_dict=None,
857
+ async_inference_config_dict=None,
858
+ explainer_config_dict=None,
859
+ live_logging=False,
860
+ vpc_config=None,
861
+ enable_network_isolation=None,
862
+ role=None,
863
+ ):
864
+ """Create an SageMaker ``Endpoint`` from a list of production variants.
865
+
866
+ Args:
867
+ name (str): The name of the ``Endpoint`` to create.
868
+ production_variants (list[dict[str, str]]): The list of production variants to deploy.
869
+ tags (Optional[Tags]): A list of key-value pairs for tagging the endpoint
870
+ (default: None).
871
+ kms_key (str): The KMS key that is used to encrypt the data on the storage volume
872
+ attached to the instance hosting the endpoint.
873
+ wait (bool): Whether to wait for the endpoint deployment to complete before returning
874
+ (default: True).
875
+ data_capture_config_dict (dict): Specifies configuration related to Endpoint data
876
+ capture for use with Amazon SageMaker Model Monitoring. Default: None.
877
+ async_inference_config_dict (dict): Specifies configuration related to an async endpoint.
878
+ Use this configuration when creating an async endpoint and making async inference requests
879
+ (default: None).
880
+ explainer_config_dict (dict): Specifies configuration related to the explainer.
881
+ Use this configuration when trying to use online explainability.
882
+ (default: None).
+ live_logging (bool): Whether to stream the endpoint's CloudWatch logs while waiting
+ for the deployment to complete (default: False).
883
+ vpc_config (dict[str, list[str]]):
884
+ The VpcConfig set on the model (default: None).
885
+ * 'Subnets' (list[str]): List of subnet ids.
886
+ * 'SecurityGroupIds' (list[str]): List of security group ids.
887
+ enable_network_isolation (bool): Default: False.
888
+ If True, enables network isolation in the endpoint, isolating the model
889
+ container. No inbound or outbound network calls can be made to
890
+ or from the model container.
891
+ role (str): An AWS IAM role (either name or full ARN). The Amazon
892
+ SageMaker training jobs and APIs that create Amazon SageMaker
893
+ endpoints use this role to access training data and model
894
+ artifacts. After the endpoint is created, the inference code
895
+ might use the IAM role if it needs to access some AWS resources.
896
+ (default: None).
897
+ Returns:
898
+ str: The name of the created ``Endpoint``.
899
+ """
900
+
901
+ supports_kms = any(
902
+ [
903
+ instance_supports_kms(production_variant["InstanceType"])
904
+ for production_variant in production_variants
905
+ if "InstanceType" in production_variant
906
+ ]
907
+ )
908
+
909
+ update_list_of_dicts_with_values_from_config(
910
+ production_variants,
911
+ ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH,
912
+ required_key_paths=["CoreDumpConfig.DestinationS3Uri"],
913
+ sagemaker_session=self,
914
+ )
915
+
916
+ config_options = {"EndpointConfigName": name, "ProductionVariants": production_variants}
917
+
918
+ kms_key = (
919
+ resolve_value_from_config(
920
+ kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self
921
+ )
922
+ if supports_kms
923
+ else kms_key
924
+ )
925
+
926
+ vpc_config = resolve_value_from_config(
927
+ vpc_config,
928
+ ENDPOINT_CONFIG_VPC_CONFIG_PATH,
929
+ sagemaker_session=self,
930
+ )
931
+
932
+ enable_network_isolation = resolve_value_from_config(
933
+ enable_network_isolation,
934
+ ENDPOINT_CONFIG_ENABLE_NETWORK_ISOLATION_PATH,
935
+ sagemaker_session=self,
936
+ )
937
+
938
+ role = resolve_value_from_config(
939
+ role,
940
+ ENDPOINT_CONFIG_EXECUTION_ROLE_ARN_PATH,
941
+ sagemaker_session=self,
942
+ sagemaker_config=load_sagemaker_config() if (self is None) else None,
943
+ )
944
+
945
+ # For an Amazon SageMaker inference-component-based endpoint, model names are not
946
+ # passed during endpoint creation. Instead, ExecutionRoleArn is required in the
947
+ # endpoint config in order to create the Endpoint.
948
+ model_names = [pv["ModelName"] for pv in production_variants if "ModelName" in pv]
949
+ if len(model_names) == 0:
950
+ # Currently, the SageMaker Python SDK allows using a RoleName to deploy models.
951
+ # Use expand_role method to handle this situation.
952
+ role = self.expand_role(role)
953
+ config_options["ExecutionRoleArn"] = role
954
+ endpoint_config_tags = _append_project_tags(format_tags(tags))
955
+ endpoint_tags = _append_project_tags(format_tags(tags))
956
+
957
+ endpoint_config_tags = self._append_sagemaker_config_tags(
958
+ endpoint_config_tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_CONFIG, TAGS)
959
+ )
960
+ if endpoint_config_tags:
961
+ config_options["Tags"] = endpoint_config_tags
962
+ if kms_key:
963
+ config_options["KmsKeyId"] = kms_key
964
+ if data_capture_config_dict is not None:
965
+ inferred_data_capture_config_dict = update_nested_dictionary_with_values_from_config(
966
+ data_capture_config_dict, ENDPOINT_CONFIG_DATA_CAPTURE_PATH, sagemaker_session=self
967
+ )
968
+ config_options["DataCaptureConfig"] = inferred_data_capture_config_dict
969
+ if async_inference_config_dict is not None:
970
+ inferred_async_inference_config_dict = update_nested_dictionary_with_values_from_config(
971
+ async_inference_config_dict,
972
+ ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH,
973
+ sagemaker_session=self,
974
+ )
975
+ config_options["AsyncInferenceConfig"] = inferred_async_inference_config_dict
976
+ if explainer_config_dict is not None:
977
+ config_options["ExplainerConfig"] = explainer_config_dict
978
+ if vpc_config is not None:
979
+ config_options["VpcConfig"] = vpc_config
980
+ if enable_network_isolation is not None:
981
+ config_options["EnableNetworkIsolation"] = enable_network_isolation
982
+ if role is not None:
983
+ config_options["ExecutionRoleArn"] = role
984
+
985
+ logger.info("Creating endpoint-config with name %s", name)
986
+ self.sagemaker_client.create_endpoint_config(**config_options)
987
+
988
+ return self.create_endpoint(
989
+ endpoint_name=name,
990
+ config_name=name,
991
+ tags=endpoint_tags,
992
+ wait=wait,
993
+ live_logging=live_logging,
994
+ )
995
+
996
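A hedged usage sketch of ``endpoint_from_production_variants``; ``session`` is assumed to be an instance of the Session-like class defined in this module, and the model name, endpoint name, and instance type are placeholders:

variants = [
    {
        "VariantName": "AllTraffic",
        "ModelName": "my-model",          # an existing SageMaker Model
        "InstanceType": "ml.m5.xlarge",
        "InitialInstanceCount": 1,
        "InitialVariantWeight": 1.0,
    }
]
endpoint_name = session.endpoint_from_production_variants(
    name="my-endpoint",
    production_variants=variants,
    tags=[{"Key": "project", "Value": "demo"}],
    wait=True,
)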
+ def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True, live_logging=False):
997
+ """Create an Amazon SageMaker ``Endpoint`` according to the configuration in the request.
998
+
999
+ Once the ``Endpoint`` is created, client applications can send requests to obtain
1000
+ inferences. The endpoint configuration is created using the ``CreateEndpointConfig`` API.
1001
+
1002
+ Args:
1003
+ endpoint_name (str): Name of the Amazon SageMaker ``Endpoint`` being created.
1004
+ config_name (str): Name of the Amazon SageMaker endpoint configuration to deploy.
1005
+ wait (bool): Whether to wait for the endpoint deployment to complete before returning
1006
+ (default: True).
1007
+ tags (Optional[Tags]): A list of key-value pairs for tagging the endpoint
1008
+ (default: None).
+ live_logging (bool): Whether to stream the endpoint's CloudWatch logs while waiting
+ for the deployment to complete (default: False).
1009
+
1010
+ Returns:
1011
+ str: Name of the Amazon SageMaker ``Endpoint`` created.
1012
+
1013
+ Raises:
1014
+ botocore.exceptions.ClientError: If Sagemaker throws an exception while creating
1015
+ endpoint.
1016
+ """
1017
+ logger.info("Creating endpoint with name %s", endpoint_name)
1018
+
1019
+ tags = format_tags(tags) or []
1020
+ tags = _append_project_tags(tags)
1021
+ tags = self._append_sagemaker_config_tags(
1022
+ tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT, TAGS)
1023
+ )
1024
+ try:
1025
+ res = self.sagemaker_client.create_endpoint(
1026
+ EndpointName=endpoint_name, EndpointConfigName=config_name, Tags=tags
1027
+ )
1028
+ if res:
1029
+ self.endpoint_arn = res["EndpointArn"]
1030
+
1031
+ if wait:
1032
+ self.wait_for_endpoint(endpoint_name, live_logging=live_logging)
1033
+ return endpoint_name
1034
+ except Exception as e:
1035
+ troubleshooting = (
1036
+ "https://docs.aws.amazon.com/sagemaker/latest/dg/"
1037
+ "sagemaker-python-sdk-troubleshooting.html"
1038
+ "#sagemaker-python-sdk-troubleshooting-create-endpoint"
1039
+ )
1040
+ logger.error(
1041
+ "Please check the troubleshooting guide for common errors: %s", troubleshooting
1042
+ )
1043
+ raise e
1044
+
1045
+ def wait_for_endpoint(self, endpoint, poll=DEFAULT_EP_POLL, live_logging=False):
1046
+ """Wait for an Amazon SageMaker endpoint deployment to complete.
1047
+
1048
+ Args:
1049
+ endpoint (str): Name of the ``Endpoint`` to wait for.
1050
+ poll (int): Polling interval in seconds (default: 30).
+ live_logging (bool): Whether to stream the endpoint's CloudWatch logs while waiting,
+ when the caller has the required permissions (default: False).
1051
+
1052
+ Raises:
1053
+ exceptions.CapacityError: If the endpoint creation job fails with CapacityError.
1054
+ exceptions.UnexpectedStatusException: If the endpoint creation job fails.
1055
+
1056
+ Returns:
1057
+ dict: Return value from the ``DescribeEndpoint`` API.
1058
+ """
1059
+
1060
+ if not live_logging or not _has_permission_for_live_logging(self.boto_session, endpoint):
1061
+ desc = _wait_until(lambda: _deploy_done(self.sagemaker_client, endpoint), poll)
1062
+ else:
1063
+ cloudwatch_client = self.boto_session.client("logs")
1064
+ paginator = cloudwatch_client.get_paginator("filter_log_events")
1065
+ paginator_config = create_paginator_config()
1066
+ desc = _wait_until(
1067
+ lambda: _live_logging_deploy_done(
1068
+ self.sagemaker_client, endpoint, paginator, paginator_config, EP_LOGGER_POLL
1069
+ ),
1070
+ poll=EP_LOGGER_POLL,
1071
+ )
1072
+ status = desc["EndpointStatus"]
1073
+
1074
+ if status != "InService":
1075
+ reason = desc.get("FailureReason", None)
1076
+ trouble_shooting = (
1077
+ "Try changing the instance type or reference the troubleshooting page "
1078
+ "https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference-troubleshooting"
1079
+ ".html"
1080
+ )
1081
+ message = "Error hosting endpoint {}: {}. Reason: {}. {}".format(
1082
+ endpoint, status, reason, trouble_shooting
1083
+ )
1084
+ if "CapacityError" in str(reason):
1085
+ raise exceptions.CapacityError(
1086
+ message=message,
1087
+ allowed_statuses=["InService"],
1088
+ actual_status=status,
1089
+ )
1090
+ raise exceptions.UnexpectedStatusException(
1091
+ message=message,
1092
+ allowed_statuses=["InService"],
1093
+ actual_status=status,
1094
+ )
1095
+ return desc
1096
+
1097
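Callers that create an endpoint with ``wait=False`` can poll for completion later and handle the failure modes documented above; a minimal sketch, assuming ``session`` is an instance of this class and ``exceptions`` is the module imported by this file:

try:
    desc = session.wait_for_endpoint("my-endpoint", poll=30)
    print("Endpoint status:", desc["EndpointStatus"])
except exceptions.CapacityError:
    # Capacity issues are often transient or instance-type specific; consider retrying
    # with a different instance type, as the failure message above suggests.
    raise
except exceptions.UnexpectedStatusException as err:
    print("Endpoint deployment failed:", err)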
+ def create_inference_component(
1098
+ self,
1099
+ inference_component_name: str,
1100
+ endpoint_name: str,
1101
+ variant_name: str,
1102
+ specification: Dict[str, Any],
1103
+ runtime_config: Optional[Dict[str, Any]] = None,
1104
+ tags: Optional[Tags] = None,
1105
+ wait: bool = True,
1106
+ ):
1107
+ """Create an Amazon SageMaker Inference Component.
1108
+
1109
+ Args:
1110
+ inference_component_name (str): Name of the Amazon SageMaker inference component
1111
+ to create.
1112
+ endpoint_name (str): Name of the Amazon SageMaker endpoint that the inference component
1113
+ will deploy to.
1114
+ variant_name (str): Name of the Amazon SageMaker variant that the inference component
1115
+ will deploy to.
1116
+ specification (Dict[str, Any]): The inference component specification.
1117
+ runtime_config (Optional[Dict[str, Any]]): Optional. The inference component
1118
+ runtime configuration. (Default: None).
1119
+ tags (Optional[Tags]): Optional. Either a dictionary or a list
1120
+ of dictionaries containing key-value pairs. (Default: None).
1121
+ wait (bool): Optional. Wait for the inference component to finish being created before
1122
+ returning a value. (Default: True).
1123
+
1124
+ Returns:
1125
+ str: Name of the Amazon SageMaker ``InferenceComponent`` if created.
1126
+ """
1127
+ LOGGER.info(
1128
+ "Creating inference component with name %s for endpoint %s",
1129
+ inference_component_name,
1130
+ endpoint_name,
1131
+ )
1132
+
1133
+ if runtime_config is None:
1134
+ runtime_config = {"CopyCount": 1}
1135
+
1136
+ request = {
1137
+ "InferenceComponentName": inference_component_name,
1138
+ "EndpointName": endpoint_name,
1139
+ "VariantName": variant_name,
1140
+ "Specification": specification,
1141
+ "RuntimeConfig": runtime_config,
1142
+ }
1143
+
1144
+ tags = format_tags(tags)
1145
+ tags = _append_project_tags(tags)
1146
+ tags = self._append_sagemaker_config_tags(
1147
+ tags, "{}.{}.{}".format(SAGEMAKER, INFERENCE_COMPONENT, TAGS)
1148
+ )
1149
+ if tags and len(tags) != 0:
1150
+ request["Tags"] = tags
1151
+
1152
+ self.sagemaker_client.create_inference_component(**request)
1153
+ if wait:
1154
+ self.wait_for_inference_component(inference_component_name)
1155
+ return inference_component_name
1156
+
1157
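A hedged usage sketch for ``create_inference_component``; the names are placeholders, and the ``specification`` shape below (a model name plus compute resource requirements) follows the CreateInferenceComponent API and should be verified against the current service documentation:

ic_name = session.create_inference_component(
    inference_component_name="my-inference-component",
    endpoint_name="my-endpoint",
    variant_name="AllTraffic",
    specification={
        "ModelName": "my-model",
        "ComputeResourceRequirements": {
            "NumberOfCpuCoresRequired": 1,
            "MinMemoryRequiredInMb": 1024,
        },
    },
    runtime_config={"CopyCount": 1},  # same default the method applies when omitted
    wait=True,
)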
+ def wait_for_inference_component(self, inference_component_name, poll=20):
1158
+ """Wait for an Amazon SageMaker ``Inference Component`` deployment to complete.
1159
+
1160
+ Args:
1161
+ inference_component_name (str): Name of the ``Inference Component`` to wait for.
1162
+ poll (int): Polling interval in seconds (default: 20).
1163
+
1164
+ Raises:
1165
+ exceptions.CapacityError: If the inference component creation fails with CapacityError.
1166
+ exceptions.UnexpectedStatusException: If the inference component creation fails.
1167
+
1168
+ Returns:
1169
+ dict: Return value from the ``DescribeInferenceComponent`` API.
1170
+ """
1171
+ desc = _wait_until(
1172
+ lambda: self._inference_component_done(self.sagemaker_client, inference_component_name),
1173
+ poll,
1174
+ )
1175
+ status = desc["InferenceComponentStatus"]
1176
+
1177
+ if status != "InService":
1178
+ message = f"Error creating inference component '{inference_component_name}'"
1179
+ reason = desc.get("FailureReason")
1180
+ if reason:
1181
+ message = f"{message}: {reason}"
1182
+ if "CapacityError" in str(reason):
1183
+ raise exceptions.CapacityError(
1184
+ message=message,
1185
+ allowed_statuses=["InService"],
1186
+ actual_status=status,
1187
+ )
1188
+ raise exceptions.UnexpectedStatusException(
1189
+ message=message,
1190
+ allowed_statuses=["InService"],
1191
+ actual_status=status,
1192
+ )
1193
+ return desc
1194
+
1195
+ def describe_inference_component(self, inference_component_name):
1196
+ """Describe an Amazon SageMaker ``InferenceComponent``
1197
+
1198
+ Args:
1199
+ inference_component_name (str): Name of the Amazon SageMaker ``InferenceComponent``.
1200
+
1201
+ Returns:
1202
+ dict[str,str]: Inference component details.
1203
+ """
1204
+
1205
+ return self.sagemaker_client.describe_inference_component(
1206
+ InferenceComponentName=inference_component_name
1207
+ )
1208
+
1209
+ def _inference_component_done(self, sagemaker_client, inference_component_name):
1210
+ """Check if creation of inference component is done.
1211
+
1212
+ Args:
1213
+ sagemaker_client (boto3.SageMaker.Client): Client which makes Amazon SageMaker
1214
+ service calls
1215
+ inference_component_name (str): Name of the Amazon SageMaker ``InferenceComponent``.
1216
+ Returns:
1217
+ dict[str,str]: Inference component details.
1218
+ """
1219
+
1220
+ create_inference_component_codes = {
1221
+ "InService": "!",
1222
+ "Creating": "-",
1223
+ "Updating": "-",
1224
+ "Failed": "*",
1225
+ "Deleting": "o",
1226
+ }
1227
+ in_progress_statuses = ["Creating", "Updating", "Deleting"]
1228
+
1229
+ desc = sagemaker_client.describe_inference_component(
1230
+ InferenceComponentName=inference_component_name
1231
+ )
1232
+ status = desc["InferenceComponentStatus"]
1233
+
1234
+ print(create_inference_component_codes.get(status, "?"), end="", flush=True)
1235
+
1236
+ return None if status in in_progress_statuses else desc
1237
+
1238
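``_wait_until`` is defined elsewhere in this module; based on how it is used here (a callable that returns ``None`` while the resource is still in progress and the describe result once it is done), a minimal stand-in could look like the following sketch:

import time

def _wait_until_sketch(poll_fn, poll=5):
    """Illustrative stand-in: call poll_fn until it returns a non-None result."""
    result = poll_fn()
    while result is None:
        time.sleep(poll)
        result = poll_fn()
    return result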
+ def update_endpoint(self, endpoint_name, endpoint_config_name, wait=True):
1239
+ """Update an Amazon SageMaker ``Endpoint`` , Raise an error endpoint_name does not exist.
1240
+
1241
+ Args:
1242
+ endpoint_name (str): Name of the Amazon SageMaker ``Endpoint`` to update.
1243
+ endpoint_config_name (str): Name of the Amazon SageMaker endpoint configuration to
1244
+ deploy.
1245
+ wait (bool): Whether to wait for the endpoint deployment to complete before returning
1246
+ (default: True).
1247
+
1248
+ Returns:
1249
+ str: Name of the Amazon SageMaker ``Endpoint`` being updated.
1250
+
1251
+ Raises:
1252
+ - ValueError: if the endpoint does not already exist
1253
+ - botocore.exceptions.ClientError: If SageMaker throws an error while
1254
+ creating endpoint config, describing endpoint or updating endpoint
1255
+ """
1256
+ if not _deployment_entity_exists(
1257
+ lambda: self.sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
1258
+ ):
1259
+ raise ValueError(
1260
+ "Endpoint with name '{}' does not exist; please use an "
1261
+ "existing endpoint name".format(endpoint_name)
1262
+ )
1263
+
1264
+ try:
1265
+
1266
+ res = self.sagemaker_client.update_endpoint(
1267
+ EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
1268
+ )
1269
+ if res:
1270
+ self.endpoint_arn = res["EndpointArn"]
1271
+
1272
+ if wait:
1273
+ self.wait_for_endpoint(endpoint_name)
1274
+ return endpoint_name
1275
+ except Exception as e:
1276
+ troubleshooting = (
1277
+ "https://docs.aws.amazon.com/sagemaker/latest/dg/"
1278
+ "sagemaker-python-sdk-troubleshooting.html"
1279
+ "#sagemaker-python-sdk-troubleshooting-update-endpoint"
1280
+ )
1281
+ logger.error(
1282
+ "Please check the troubleshooting guide for common errors: %s", troubleshooting
1283
+ )
1284
+ raise e
1285
+
1286
+ def endpoint_in_service_or_not(self, endpoint_name: str):
1287
+ """Check whether an Amazon SageMaker ``Endpoint``` is in IN_SERVICE status.
1288
+
1289
+ Raise any exception that is not recognized as "not found".
1290
+
1291
+ Args:
1292
+ endpoint_name (str): Name of the Amazon SageMaker ``Endpoint`` to
1293
+ check status.
1294
+
1295
+ Returns:
1296
+ bool: True if the ``Endpoint`` is in service; False if it does not exist
1297
+ or is in another status.
1298
+
1299
+ Raises:
1300
+ botocore.exceptions.ClientError: Any error other than one indicating
+ that the endpoint was not found.
1301
+ """
1302
+ try:
1303
+ desc = self.sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
1304
+ status = desc["EndpointStatus"]
1305
+ if status == "InService":
1306
+ return True
1307
+ return False
1308
+
1309
+ except botocore.exceptions.ClientError as e:
1310
+ str_err = str(e).lower()
1311
+ if "could not find" in str_err or "not found" in str_err:
1312
+ return False
1313
+ raise
1314
+
1315
+ def _intercept_create_request(
1316
+ self,
1317
+ request: Dict,
1318
+ create,
1319
+ func_name: str = None,
1320
+ # pylint: disable=unused-argument
1321
+ ):
1322
+ """This function intercepts the create job request.
1323
+
1324
+ PipelineSession inherits this Session class and will override
1325
+ this function to intercept the create request.
1326
+
1327
+ Args:
1328
+ request (dict): the create job request
1329
+ create (functor): a functor calls the sagemaker client create method
1330
+ func_name (str): the name of the function being intercepted
1331
+ """
1332
+ return create(request)
1333
+
1334
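As the docstring notes, subclasses (such as a pipeline-oriented session) can override this hook to capture or rewrite requests instead of submitting them immediately; a hypothetical override, assuming ``Session`` is the class defined in this module:

class RecordingSession(Session):
    """Hypothetical subclass that records create requests before submitting them."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.recorded_requests = []

    def _intercept_create_request(self, request, create, func_name=None):
        # Keep a copy of every create request for later inspection, then defer
        # to the default behavior of actually calling the SageMaker client.
        self.recorded_requests.append((func_name, dict(request)))
        return create(request)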
+ def _create_inference_recommendations_job_request(
1335
+ self,
1336
+ role: str,
1337
+ job_name: str,
1338
+ job_description: str,
1339
+ framework: str,
1340
+ sample_payload_url: str,
1341
+ supported_content_types: List[str],
1342
+ tags: Optional[Tags],
1343
+ model_name: str = None,
1344
+ model_package_version_arn: str = None,
1345
+ job_duration_in_seconds: int = None,
1346
+ job_type: str = "Default",
1347
+ framework_version: str = None,
1348
+ nearest_model_name: str = None,
1349
+ supported_instance_types: List[str] = None,
1350
+ endpoint_configurations: List[Dict[str, Any]] = None,
1351
+ traffic_pattern: Dict[str, Any] = None,
1352
+ stopping_conditions: Dict[str, Any] = None,
1353
+ resource_limit: Dict[str, Any] = None,
1354
+ ) -> Dict[str, Any]:
1355
+ """Get request dictionary for CreateInferenceRecommendationsJob API.
1356
+
1357
+ Args:
1358
+ role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training
1359
+ jobs and APIs that create Amazon SageMaker endpoints use this role to access
1360
+ training data and model artifacts.
1361
+ You must grant sufficient permissions to this role.
1362
+ job_name (str): The name of the Inference Recommendations Job.
1363
+ job_description (str): A description of the Inference Recommendations Job.
1364
+ framework (str): The machine learning framework of the Image URI.
1365
+ sample_payload_url (str): The S3 path where the sample payload is stored.
1366
+ supported_content_types (List[str]): The supported MIME types for the input data.
1367
+ model_name (str): Name of the Amazon SageMaker ``Model`` to be used.
1368
+ model_package_version_arn (str): The Amazon Resource Name (ARN) of a
1369
+ versioned model package.
1370
+ job_duration_in_seconds (int): The maximum job duration that a job
1371
+ can run for. Will be used for `Advanced` jobs.
1372
+ job_type (str): The type of job being run. Must be either `Default` or `Advanced`.
1373
+ framework_version (str): The framework version of the Image URI.
1374
+ nearest_model_name (str): The name of a pre-trained machine learning model
1375
+ benchmarked by Amazon SageMaker Inference Recommender that matches your model.
1376
+ supported_instance_types (List[str]): A list of the instance types that are used
1377
+ to generate inferences in real-time.
1378
+ tags (Optional[Tags]): Tags used to identify where
1379
+ the Inference Recommendations call was made from.
1380
+ endpoint_configurations (List[Dict[str, any]]): Specifies the endpoint configurations
1381
+ to use for a job. Will be used for `Advanced` jobs.
1382
+ traffic_pattern (Dict[str, any]): Specifies the traffic pattern for the job.
1383
+ Will be used for `Advanced` jobs.
1384
+ stopping_conditions (Dict[str, any]): A set of conditions for stopping a
1385
+ recommendation job.
1386
+ If any of the conditions are met, the job is automatically stopped.
1387
+ Will be used for `Advanced` jobs.
1388
+ resource_limit (Dict[str, any]): Defines the resource limit for the job.
1389
+ Will be used for `Advanced` jobs.
1390
+ Returns:
1391
+ Dict[str, Any]: request dictionary for the CreateInferenceRecommendationsJob API
1392
+ """
1393
+
1394
+ containerConfig = {
1395
+ "Domain": "MACHINE_LEARNING",
1396
+ "Task": "OTHER",
1397
+ "Framework": framework,
1398
+ "PayloadConfig": {
1399
+ "SamplePayloadUrl": sample_payload_url,
1400
+ "SupportedContentTypes": supported_content_types,
1401
+ },
1402
+ }
1403
+
1404
+ if framework_version:
1405
+ containerConfig["FrameworkVersion"] = framework_version
1406
+ if nearest_model_name:
1407
+ containerConfig["NearestModelName"] = nearest_model_name
1408
+ if supported_instance_types:
1409
+ containerConfig["SupportedInstanceTypes"] = supported_instance_types
1410
+
1411
+ request = {
1412
+ "JobName": job_name,
1413
+ "JobType": job_type,
1414
+ "RoleArn": role,
1415
+ "InputConfig": {
1416
+ "ContainerConfig": containerConfig,
1417
+ },
1418
+ "Tags": format_tags(tags),
1419
+ }
1420
+
1421
+ request.get("InputConfig").update(
1422
+ {"ModelPackageVersionArn": model_package_version_arn}
1423
+ if model_package_version_arn
1424
+ else {"ModelName": model_name}
1425
+ )
1426
+
1427
+ if job_description:
1428
+ request["JobDescription"] = job_description
1429
+ if job_duration_in_seconds:
1430
+ request["InputConfig"]["JobDurationInSeconds"] = job_duration_in_seconds
1431
+
1432
+ if job_type == "Advanced":
1433
+ if stopping_conditions:
1434
+ request["StoppingConditions"] = stopping_conditions
1435
+ if resource_limit:
1436
+ request["InputConfig"]["ResourceLimit"] = resource_limit
1437
+ if traffic_pattern:
1438
+ request["InputConfig"]["TrafficPattern"] = traffic_pattern
1439
+ if endpoint_configurations:
1440
+ request["InputConfig"]["EndpointConfigurations"] = endpoint_configurations
1441
+
1442
+ return request
1443
+
1444
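For a ``Default`` job keyed by model name, the request assembled by the method above reduces to roughly the following shape (all values are placeholders):

example_request = {
    "JobName": "SMPYTHONSDK-example",
    "JobType": "Default",
    "RoleArn": "arn:aws:iam::111122223333:role/ExampleRole",
    "InputConfig": {
        "ContainerConfig": {
            "Domain": "MACHINE_LEARNING",
            "Task": "OTHER",
            "Framework": "PYTORCH",
            "PayloadConfig": {
                "SamplePayloadUrl": "s3://example-bucket/payload.tar.gz",
                "SupportedContentTypes": ["application/json"],
            },
        },
        "ModelName": "my-model",
    },
    "Tags": [{"Key": "ClientType", "Value": "PythonSDK-RightSize"}],
}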
+ def create_inference_recommendations_job(
1445
+ self,
1446
+ role: str,
1447
+ sample_payload_url: str,
1448
+ supported_content_types: List[str],
1449
+ job_name: str = None,
1450
+ job_type: str = "Default",
1451
+ model_name: str = None,
1452
+ model_package_version_arn: str = None,
1453
+ job_duration_in_seconds: int = None,
1454
+ nearest_model_name: str = None,
1455
+ supported_instance_types: List[str] = None,
1456
+ framework: str = None,
1457
+ framework_version: str = None,
1458
+ endpoint_configurations: List[Dict[str, any]] = None,
1459
+ traffic_pattern: Dict[str, any] = None,
1460
+ stopping_conditions: Dict[str, any] = None,
1461
+ resource_limit: Dict[str, any] = None,
1462
+ ):
1463
+ """Creates an Inference Recommendations Job
1464
+
1465
+ Args:
1466
+ role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training
1467
+ jobs and APIs that create Amazon SageMaker endpoints use this role to access
1468
+ training data and model artifacts.
1469
+ You must grant sufficient permissions to this role.
1470
+ sample_payload_url (str): The S3 path where the sample payload is stored.
1471
+ supported_content_types (List[str]): The supported MIME types for the input data.
1472
+ model_name (str): Name of the Amazon SageMaker ``Model`` to be used.
1473
+ model_package_version_arn (str): The Amazon Resource Name (ARN) of a
1474
+ versioned model package.
1475
+ job_name (str): The name of the job being run.
1476
+ job_type (str): The type of job being run. Must be either `Default` or `Advanced`.
1477
+ job_duration_in_seconds (int): The maximum job duration that a job
1478
+ can run for. Will be used for `Advanced` jobs.
1479
+ nearest_model_name (str): The name of a pre-trained machine learning model
1480
+ benchmarked by Amazon SageMaker Inference Recommender that matches your model.
1481
+ supported_instance_types (List[str]): A list of the instance types that are used
1482
+ to generate inferences in real-time.
1483
+ framework (str): The machine learning framework of the Image URI.
1484
+ framework_version (str): The framework version of the Image URI.
1485
+ endpoint_configurations (List[Dict[str, any]]): Specifies the endpoint configurations
1486
+ to use for a job. Will be used for `Advanced` jobs.
1487
+ traffic_pattern (Dict[str, any]): Specifies the traffic pattern for the job.
1488
+ Will be used for `Advanced` jobs.
1489
+ stopping_conditions (Dict[str, any]): A set of conditions for stopping a
1490
+ recommendation job.
1491
+ If any of the conditions are met, the job is automatically stopped.
1492
+ Will be used for `Advanced` jobs.
1493
+ resource_limit (Dict[str, any]): Defines the resource limit for the job.
1494
+ Will be used for `Advanced` jobs.
1495
+ Returns:
1496
+ str: The name of the job created, in the form `SMPYTHONSDK-<unique suffix>`
1497
+ """
1498
+
1499
+ if model_name is None and model_package_version_arn is None:
1500
+ raise ValueError("Please provide either model_name or model_package_version_arn.")
1501
+
1502
+ if model_name is not None and model_package_version_arn is not None:
1503
+ raise ValueError("Please provide either model_name or model_package_version_arn.")
1504
+
1505
+ if not job_name:
1506
+ unique_tail = uuid.uuid4()
1507
+ job_name = "SMPYTHONSDK-" + str(unique_tail)
1508
+ job_description = "#python-sdk-create"
1509
+
1510
+ tags = [{"Key": "ClientType", "Value": "PythonSDK-RightSize"}]
1511
+
1512
+ create_inference_recommendations_job_request = (
1513
+ self._create_inference_recommendations_job_request(
1514
+ role=role,
1515
+ model_name=model_name,
1516
+ model_package_version_arn=model_package_version_arn,
1517
+ job_name=job_name,
1518
+ job_type=job_type,
1519
+ job_duration_in_seconds=job_duration_in_seconds,
1520
+ job_description=job_description,
1521
+ framework=framework,
1522
+ framework_version=framework_version,
1523
+ nearest_model_name=nearest_model_name,
1524
+ sample_payload_url=sample_payload_url,
1525
+ supported_content_types=supported_content_types,
1526
+ supported_instance_types=supported_instance_types,
1527
+ endpoint_configurations=endpoint_configurations,
1528
+ traffic_pattern=traffic_pattern,
1529
+ stopping_conditions=stopping_conditions,
1530
+ resource_limit=resource_limit,
1531
+ tags=tags,
1532
+ )
1533
+ )
1534
+
1535
+ def submit(request):
1536
+ logger.info("Creating Inference Recommendations job with name: %s", job_name)
1537
+ logger.debug("process request: %s", json.dumps(request, indent=4))
1538
+ self.sagemaker_client.create_inference_recommendations_job(**request)
1539
+
1540
+ self._intercept_create_request(
1541
+ create_inference_recommendations_job_request,
1542
+ submit,
1543
+ self.create_inference_recommendations_job.__name__,
1544
+ )
1545
+ return job_name
1546
+
1547
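A hedged end-to-end sketch of launching and waiting on a recommendations job; ``session``, the role ARN, the payload location, and the model name are placeholders:

job_name = session.create_inference_recommendations_job(
    role="arn:aws:iam::111122223333:role/ExampleRole",
    sample_payload_url="s3://example-bucket/payload.tar.gz",
    supported_content_types=["application/json"],
    model_name="my-model",
    framework="PYTORCH",
)
result = session.wait_for_inference_recommendations_job(job_name, log_level="Quiet")
print(result["Status"])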
+ def wait_for_inference_recommendations_job(
1548
+ self, job_name: str, poll: int = 120, log_level: str = "Verbose"
1549
+ ) -> Dict[str, Any]:
1550
+ """Wait for an Amazon SageMaker Inference Recommender job to complete.
1551
+
1552
+ Args:
1553
+ job_name (str): Name of the Inference Recommender job to wait for.
1554
+ poll (int): Polling interval in seconds (default: 120).
1555
+ log_level (str): The level of verbosity for the logs.
1556
+ Can be "Quiet" or "Verbose" (default: "Quiet").
1557
+
1558
+ Returns:
1559
+ (dict): Return value from the ``DescribeInferenceRecommendationsJob`` API.
1560
+
1561
+ Raises:
1562
+ exceptions.CapacityError: If the Inference Recommender job fails with CapacityError.
1563
+ exceptions.UnexpectedStatusException: If the Inference Recommender job fails.
1564
+ """
1565
+ if log_level == "Quiet":
1566
+ _wait_until(
1567
+ lambda: _describe_inference_recommendations_job_status(
1568
+ self.sagemaker_client, job_name
1569
+ ),
1570
+ poll,
1571
+ )
1572
+ elif log_level == "Verbose":
1573
+ _display_inference_recommendations_job_steps_status(
1574
+ self, self.sagemaker_client, job_name
1575
+ )
1576
+ else:
1577
+ raise ValueError("log_level must be either Quiet or Verbose")
1578
+ desc = _describe_inference_recommendations_job_status(self.sagemaker_client, job_name)
1579
+ _check_job_status(job_name, desc, "Status")
1580
+ return desc
1581
+
1582
+ def delete_model(self, model_name):
1583
+ """Delete an Amazon SageMaker Model.
1584
+
1585
+ Args:
1586
+ model_name (str): Name of the Amazon SageMaker model to delete.
1587
+ """
1588
+ logger.info("Deleting model with name: %s", model_name)
1589
+ self.sagemaker_client.delete_model(ModelName=model_name)
1590
+
1591
+ def delete_endpoint(self, endpoint_name):
1592
+ """Delete an Amazon SageMaker ``Endpoint``.
1593
+
1594
+ Args:
1595
+ endpoint_name (str): Name of the Amazon SageMaker ``Endpoint`` to delete.
1596
+ """
1597
+ logger.info("Deleting endpoint with name: %s", endpoint_name)
1598
+ self.sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
1599
+
1600
+ def delete_endpoint_config(self, endpoint_config_name):
1601
+ """Delete an Amazon SageMaker endpoint configuration.
1602
+
1603
+ Args:
1604
+ endpoint_config_name (str): Name of the Amazon SageMaker endpoint configuration to
1605
+ delete.
1606
+ """
1607
+ logger.info("Deleting endpoint configuration with name: %s", endpoint_config_name)
1608
+ self.sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
1609
+
1610
+ def wait_for_optimization_job(self, job, poll=5):
1611
+ """Wait for an Amazon SageMaker Optimization job to complete.
1612
+
1613
+ Args:
1614
+ job (str): Name of optimization job to wait for.
1615
+ poll (int): Polling interval in seconds (default: 5).
1616
+
1617
+ Returns:
1618
+ (dict): Return value from the ``DescribeOptimizationJob`` API.
1619
+
1620
+ Raises:
1621
+ exceptions.CapacityError: If the optimization job fails with CapacityError.
1622
+ exceptions.UnexpectedStatusException: If optimization job fails.
1623
+ """
1624
+ desc = _wait_until(lambda: _optimization_job_status(self.sagemaker_client, job), poll)
1625
+ _check_job_status(job, desc, "OptimizationJobStatus")
1626
+ return desc
1627
+
1628
+ def update_inference_component(
1629
+ self, inference_component_name, specification=None, runtime_config=None, wait=True
1630
+ ):
1631
+ """Update an Amazon SageMaker ``InferenceComponent``
1632
+
1633
+ Args:
1634
+ inference_component_name (str): Name of the Amazon SageMaker ``InferenceComponent``.
1635
+ specification (dict[str, int]): Resource configuration. Optional.
1636
+ Example: {
1637
+ "MinMemoryRequiredInMb": 1024,
1638
+ "NumberOfCpuCoresRequired": 1,
1639
+ "NumberOfAcceleratorDevicesRequired": 1,
1640
+ "MaxMemoryRequiredInMb": 4096,
1641
+ },
1642
+ runtime_config (dict[str, int]): Number of copies. Optional.
1643
+ Default: {
1644
+ "CopyCount": 1
1645
+ }
1646
+ wait: Wait for inference component to be created before return. Optional. Default is
1647
+ True.
1648
+
1649
+ Return:
1650
+ str: inference component name
1651
+
1652
+ Raises:
1653
+ ValueError: If the inference_component_name does not exist.
1654
+ """
1655
+ if not _deployment_entity_exists(
1656
+ lambda: self.sagemaker_client.describe_inference_component(
1657
+ InferenceComponentName=inference_component_name
1658
+ )
1659
+ ):
1660
+ raise ValueError(
1661
+ "InferenceComponent with name '{}' does not exist; please use an "
1662
+ "existing model name".format(inference_component_name)
1663
+ )
1664
+
1665
+ request = {
1666
+ "InferenceComponentName": inference_component_name,
1667
+ "Specification": specification,
1668
+ "RuntimeConfig": runtime_config,
1669
+ }
1670
+
1671
+ self.sagemaker_client.update_inference_component(**request)
1672
+
1673
+ if wait:
1674
+ self.wait_for_inference_component(inference_component_name)
1675
+ return inference_component_name
1676
+
1677
+ def _create_model_request(
1678
+ self,
1679
+ name,
1680
+ role,
1681
+ container_defs,
1682
+ vpc_config=None,
1683
+ enable_network_isolation=False,
1684
+ primary_container=None,
1685
+ tags=None,
1686
+ ): # pylint: disable=redefined-outer-name
1687
+ """Placeholder docstring"""
1688
+
1689
+ if container_defs and primary_container:
1690
+ raise ValueError("Both container_defs and primary_container can not be passed as input")
1691
+
1692
+ if primary_container:
1693
+ msg = (
1694
+ "primary_container is going to be deprecated in a future release. Please use "
1695
+ "container_defs instead."
1696
+ )
1697
+ warnings.warn(msg, DeprecationWarning)
1698
+ container_defs = primary_container
1699
+
1700
+ role = self.expand_role(role)
1701
+
1702
+ if isinstance(container_defs, list):
1703
+ update_list_of_dicts_with_values_from_config(
1704
+ container_defs, MODEL_CONTAINERS_PATH, sagemaker_session=self
1705
+ )
1706
+ container_definition = container_defs
1707
+ else:
1708
+ container_definition = _expand_container_def(container_defs)
1709
+ container_definition = update_nested_dictionary_with_values_from_config(
1710
+ container_definition, MODEL_PRIMARY_CONTAINER_PATH, sagemaker_session=self
1711
+ )
1712
+
1713
+ request = {"ModelName": name, "ExecutionRoleArn": role}
1714
+ if isinstance(container_definition, list):
1715
+ request["Containers"] = container_definition
1716
+ elif "ModelPackageName" in container_definition:
1717
+ request["Containers"] = [container_definition]
1718
+ else:
1719
+ request["PrimaryContainer"] = container_definition
1720
+
1721
+ if tags:
1722
+ request["Tags"] = format_tags(tags)
1723
+
1724
+ if vpc_config:
1725
+ request["VpcConfig"] = vpc_config
1726
+
1727
+ if enable_network_isolation:
1728
+ # enable_network_isolation may be a pipeline variable which is
1729
+ # parsed in execution time
1730
+ request["EnableNetworkIsolation"] = enable_network_isolation
1731
+
1732
+ return request
1733
+
1734
+ def create_model(
1735
+ self,
1736
+ name,
1737
+ role=None,
1738
+ container_defs=None,
1739
+ vpc_config=None,
1740
+ enable_network_isolation=None,
1741
+ primary_container=None,
1742
+ tags=None,
1743
+ ):
1744
+ """Create an Amazon SageMaker ``Model``.
1745
+
1746
+ Specify the S3 location of the model artifacts and Docker image containing
1747
+ the inference code. Amazon SageMaker uses this information to deploy the
1748
+ model in Amazon SageMaker. This method can also be used to create a Model for an Inference
1749
+ Pipeline if you pass the list of container definitions through the containers parameter.
1750
+
1751
+ Args:
1752
+ name (str): Name of the Amazon SageMaker ``Model`` to create.
1753
+ role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training
1754
+ jobs and APIs that create Amazon SageMaker endpoints use this role to access
1755
+ training data and model artifacts. You must grant sufficient permissions to this
1756
+ role.
1757
+ container_defs (list[dict[str, str]] or [dict[str, str]]): A single container
1758
+ definition or a list of container definitions which will be invoked sequentially
1759
+ while performing the prediction. If the list contains only one container, then
1760
+ it'll be passed to SageMaker Hosting as the ``PrimaryContainer`` and otherwise,
1761
+ it'll be passed as ``Containers``. You can also specify the return value of
1762
+ ``sagemaker.get_container_def()`` or ``sagemaker.pipeline_container_def()``,
1763
+ which will be used to create more advanced container configurations, including model
1764
+ containers which need artifacts from S3.
1765
+ vpc_config (dict[str, list[str]]): The VpcConfig set on the model (default: None)
1766
+ * 'Subnets' (list[str]): List of subnet ids.
1767
+ * 'SecurityGroupIds' (list[str]): List of security group ids.
1768
+ enable_network_isolation (bool): Whether the model requires network isolation or not.
1769
+ primary_container (str or dict[str, str]): Docker image which defines the inference
1770
+ code. You can also specify the return value of ``sagemaker.container_def()``,
1771
+ which is used to create more advanced container configurations, including model
1772
+ containers which need artifacts from S3. This field is deprecated, please use
1773
+ container_defs instead.
1774
+ tags(Optional[Tags]): Optional. The list of tags to add to the model.
1775
+
1776
+ Example:
1777
+ >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
1778
+ For more information about tags, see https://boto3.amazonaws.com/v1/documentation\
1779
+ /api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
1780
+
1781
+ Returns:
1782
+ str: Name of the Amazon SageMaker ``Model`` created.
1783
+ """
1784
+ tags = _append_project_tags(format_tags(tags))
1785
+ tags = self._append_sagemaker_config_tags(tags, "{}.{}.{}".format(SAGEMAKER, MODEL, TAGS))
1786
+ role = resolve_value_from_config(
1787
+ role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self
1788
+ )
1789
+ vpc_config = resolve_value_from_config(
1790
+ vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self
1791
+ )
1792
+ enable_network_isolation = resolve_value_from_config(
1793
+ direct_input=enable_network_isolation,
1794
+ config_path=MODEL_ENABLE_NETWORK_ISOLATION_PATH,
1795
+ default_value=False,
1796
+ sagemaker_session=self,
1797
+ )
1798
+
1799
+ # Due to ambiguity in container_defs, which accepts both a single
1800
+ # container definition(dtype: dict) and a list of container definitions (dtype: list),
1801
+ # we need to inject environment variables into the container_defs in the helper function
1802
+ # _create_model_request.
1803
+ create_model_request = self._create_model_request(
1804
+ name=name,
1805
+ role=role,
1806
+ container_defs=container_defs,
1807
+ vpc_config=vpc_config,
1808
+ enable_network_isolation=enable_network_isolation,
1809
+ primary_container=primary_container,
1810
+ tags=tags,
1811
+ )
1812
+
1813
+ def submit(request):
1814
+ logger.info("Creating model with name: %s", name)
1815
+ logger.debug("CreateModel request: %s", json.dumps(request, indent=4))
1816
+ try:
1817
+ self.sagemaker_client.create_model(**request)
1818
+ except ClientError as e:
1819
+ error_code = e.response["Error"]["Code"]
1820
+ message = e.response["Error"]["Message"]
1821
+ if (
1822
+ error_code == "ValidationException"
1823
+ and "Cannot create already existing model" in message
1824
+ ):
1825
+ logger.warning("Using already existing model: %s", name)
1826
+ else:
1827
+ raise
1828
+
1829
+ self._intercept_create_request(create_model_request, submit, self.create_model.__name__)
1830
+ return name
1831
+
1832
+ def create_model_package_from_algorithm(self, name, description, algorithm_arn, model_data):
1833
+ """Create a SageMaker Model Package from the results of training with an Algorithm Package.
1834
+
1835
+ Args:
1836
+ name (str): ModelPackage name
1837
+ description (str): Model Package description
1838
+ algorithm_arn (str): arn or name of the algorithm used for training.
1839
+ model_data (str or dict[str, Any]): s3 URI or a dictionary representing a
1840
+ ``ModelDataSource`` to the model artifacts produced by training
1841
+ """
1842
+ sourceAlgorithm = {"AlgorithmName": algorithm_arn}
1843
+ if isinstance(model_data, dict):
1844
+ sourceAlgorithm["ModelDataSource"] = model_data
1845
+ else:
1846
+ sourceAlgorithm["ModelDataUrl"] = model_data
1847
+
1848
+ request = {
1849
+ "ModelPackageName": name,
1850
+ "ModelPackageDescription": description,
1851
+ "SourceAlgorithmSpecification": {"SourceAlgorithms": [sourceAlgorithm]},
1852
+ }
1853
+ try:
1854
+ logger.info("Creating model package with name: %s", name)
1855
+ self.sagemaker_client.create_model_package(**request)
1856
+ except ClientError as e:
1857
+ error_code = e.response["Error"]["Code"]
1858
+ message = e.response["Error"]["Message"]
1859
+
1860
+ if error_code == "ValidationException" and "ModelPackage already exists" in message:
1861
+ logger.warning("Using already existing model package: %s", name)
1862
+ else:
1863
+ raise
1864
+
1865
+ def expand_role(self, role):
1866
+ """Expand an IAM role name into an ARN.
1867
+
1868
+ If the role is already in the form of an ARN, then the role is simply returned. Otherwise
1869
+ we retrieve the full ARN and return it.
1870
+
1871
+ Args:
1872
+ role (str): An AWS IAM role (either name or full ARN).
1873
+
1874
+ Returns:
1875
+ str: The corresponding AWS IAM role ARN.
1876
+ """
1877
+ if "/" in role:
1878
+ return role
1879
+ return self.boto_session.resource("iam").Role(role).arn
1880
+
1881
+
1882
+ def _expand_container_def(c_def):
1883
+ """Placeholder docstring"""
1884
+ if isinstance(c_def, six.string_types):
1885
+ return container_def(c_def)
1886
+ return c_def
1887
+
1888
+
1889
+ def s3_path_join(*args, with_end_slash: bool = False):
1907
+ """Returns the arguments joined by a slash ("/"), similar to ``os.path.join()`` (on Unix).
1908
+
1909
+ Behavior of this function:
1910
+ - If the first argument is "s3://", then that is preserved.
1911
+ - The output by default will have no slashes at the beginning or end. There is one exception
1912
+ (see `with_end_slash`). For example, `s3_path_join("/foo", "bar/")` will yield
1913
+ `"foo/bar"` and `s3_path_join("foo", "bar", with_end_slash=True)` will yield `"foo/bar/"`
1914
+ - Any repeat slashes will be removed in the output (except for "s3://" if provided at the
1915
+ beginning). For example, `s3_path_join("s3://", "//foo/", "/bar///baz")` will yield
1916
+ `"s3://foo/bar/baz"`.
1917
+ - Empty or None arguments will be skipped. For example
1918
+ `s3_path_join("foo", "", None, "bar")` will yield `"foo/bar"`
1919
+
1920
+ Alternatives to this function that are NOT recommended for S3 paths:
1921
+ - `os.path.join(...)` will have different behavior on Unix machines vs non-Unix machines
1922
+ - `pathlib.PurePosixPath(...)` will apply potentially unintended simplification of single
1923
+ dots (".") and root directories. (for example
1924
+ `pathlib.PurePosixPath("foo", "/bar/./", "baz")` would yield `"/bar/baz"`)
1925
+ - `"{}/{}/{}".format(...)` and similar may result in unintended repeat slashes
1926
+
1927
+ Args:
1928
+ *args: The strings to join with a slash.
1929
+ with_end_slash (bool): (default: False) If true and if the path is not empty, appends a "/"
1930
+ to the end of the path
1931
+
1932
+ Returns:
1933
+ str: The joined string, without a slash at the end unless with_end_slash is True.
1934
+ """
1935
+ delimiter = "/"
1936
+
1937
+ non_empty_args = list(filter(lambda item: item is not None and item != "", args))
1938
+
1939
+ merged_path = ""
1940
+ for index, path in enumerate(non_empty_args):
1941
+ if (
1942
+ index == 0
1943
+ or (merged_path and merged_path[-1] == delimiter)
1944
+ or (path and path[0] == delimiter)
1945
+ ):
1946
+ # dont need to add an extra slash because either this is the beginning of the string,
1947
+ # or one (or more) slash already exists
1948
+ merged_path += path
1949
+ else:
1950
+ merged_path += delimiter + path
1951
+
1952
+ if with_end_slash and merged_path and merged_path[-1] != delimiter:
1953
+ merged_path += delimiter
1954
+
1955
+ # At this point, merged_path may include slashes at the beginning and/or end. And some of the
1956
+ # provided args may have had duplicate slashes inside or at the ends.
1957
+ # For backwards compatibility reasons, these need to be filtered out (done below). In the
1958
+ # future, if there is a desire to support multiple slashes for S3 paths throughout the SDK,
1959
+ # one option is to create a new optional argument (or a new function) that only executes the
1960
+ # logic above.
1961
+ filtered_path = merged_path
1962
+
1963
+ # remove duplicate slashes
1964
+ if filtered_path:
1965
+
1966
+ def duplicate_delimiter_remover(sequence, next_char):
1967
+ if sequence[-1] == delimiter and next_char == delimiter:
1968
+ return sequence
1969
+ return sequence + next_char
1970
+
1971
+ if filtered_path.startswith("s3://"):
1972
+ filtered_path = reduce(
1973
+ duplicate_delimiter_remover, filtered_path[5:], filtered_path[:5]
1974
+ )
1975
+ else:
1976
+ filtered_path = reduce(duplicate_delimiter_remover, filtered_path)
1977
+
1978
+ # remove beginning slashes
1979
+ filtered_path = filtered_path.lstrip(delimiter)
1980
+
1981
+ # remove end slashes
1982
+ if not with_end_slash and filtered_path != "s3://":
1983
+ filtered_path = filtered_path.rstrip(delimiter)
1984
+
1985
+ return filtered_path
1986
+
1987
+
1988
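The behaviors listed in the docstring can be exercised directly; each expected result below restates an example from the docstring:

assert s3_path_join("s3://", "//foo/", "/bar///baz") == "s3://foo/bar/baz"
assert s3_path_join("/foo", "bar/") == "foo/bar"
assert s3_path_join("foo", "", None, "bar") == "foo/bar"
assert s3_path_join("foo", "bar", with_end_slash=True) == "foo/bar/"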
+ def botocore_resolver():
1989
+ """Get the DNS suffix for the given region.
1990
+
1991
+ Args:
1992
+ region (str): AWS region name
1993
+
1994
+ Returns:
1995
+ str: the DNS suffix
1996
+ """
1997
+ loader = botocore.loaders.create_loader()
1998
+ return botocore.regions.EndpointResolver(loader.load_data("endpoints"))
1999
+
2000
+
2001
+ def sts_regional_endpoint(region):
2002
+ """Get the AWS STS endpoint specific for the given region.
2003
+
2004
+ We need this function because the AWS SDK does not yet honor
2005
+ the ``region_name`` parameter when creating an AWS STS client.
2006
+
2007
+ For the list of regional endpoints, see
2008
+ https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html#id_credentials_region-endpoints.
2009
+
2010
+ Args:
2011
+ region (str): AWS region name
2012
+
2013
+ Returns:
2014
+ str: AWS STS regional endpoint
2015
+ """
2016
+ endpoint_data = botocore_resolver().construct_endpoint("sts", region)
2017
+ if region == "il-central-1" and not endpoint_data:
2018
+ endpoint_data = {"hostname": "sts.{}.amazonaws.com".format(region)}
2019
+ return "https://{}".format(endpoint_data["hostname"])
2020
+
2021
+
2022
+ def get_execution_role(sagemaker_session=None, use_default=False):
2023
+ """Return the role ARN whose credentials are used to call the API.
2024
+
2025
+ Raises a ``ValueError`` if the caller's identity is not an IAM role and ``use_default`` is False.
2026
+
2027
+ Args:
2028
+ sagemaker_session (Session): Current sagemaker session.
2029
+ use_default (bool): Use a default role if ``get_caller_identity_arn`` does not
2030
+ return a correct role. This default role will be created if needed.
2031
+ Defaults to ``False``.
2032
+
2033
+ Returns:
2034
+ (str): The role ARN
2035
+ """
2036
+ if not sagemaker_session:
2037
+ sagemaker_session = Session()
2038
+ arn = sagemaker_session.get_caller_identity_arn()
2039
+
2040
+ if ":role/" in arn:
2041
+ return arn
2042
+
2043
+ if use_default:
2044
+ default_role_name = "AmazonSageMaker-DefaultRole"
2045
+
2046
+ LOGGER.warning("Using default role: %s", default_role_name)
2047
+
2048
+ boto3_session = sagemaker_session.boto_session
2049
+ permissions_policy = json.dumps(
2050
+ {
2051
+ "Version": "2012-10-17",
2052
+ "Statement": [
2053
+ {
2054
+ "Effect": "Allow",
2055
+ "Principal": {"Service": ["sagemaker.amazonaws.com"]},
2056
+ "Action": "sts:AssumeRole",
2057
+ }
2058
+ ],
2059
+ }
2060
+ )
2061
+ iam_client = boto3_session.client("iam")
2062
+ try:
2063
+ iam_client.get_role(RoleName=default_role_name)
2064
+ except iam_client.exceptions.NoSuchEntityException:
2065
+ iam_client.create_role(
2066
+ RoleName=default_role_name, AssumeRolePolicyDocument=str(permissions_policy)
2067
+ )
2068
+
2069
+ LOGGER.warning("Created new sagemaker execution role: %s", default_role_name)
2070
+
2071
+ iam_client.attach_role_policy(
2072
+ PolicyArn="arn:aws:iam::aws:policy/AmazonSageMakerFullAccess",
2073
+ RoleName=default_role_name,
2074
+ )
2075
+ return iam_client.get_role(RoleName=default_role_name)["Role"]["Arn"]
2076
+
2077
+ message = (
2078
+ "The current AWS identity is not a role: {}, therefore it cannot be used as a "
2079
+ "SageMaker execution role"
2080
+ )
2081
+ raise ValueError(message.format(arn))
2082
+
2083
+
2084
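A hedged sketch of typical usage; inside SageMaker the caller's identity is usually already a role, and ``use_default=True`` is only appropriate when creating and attaching the broad default role is acceptable:

role_arn = get_execution_role()  # returns the caller's role ARN when one is attached
# Outside of SageMaker, fall back to (and create, if missing) the default role:
# role_arn = get_execution_role(use_default=True)
print(role_arn)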
+ def get_add_model_package_inference_args(
2085
+ model_package_arn,
2086
+ name,
2087
+ containers=None,
2088
+ content_types=None,
2089
+ response_types=None,
2090
+ inference_instances=None,
2091
+ transform_instances=None,
2092
+ description=None,
2093
+ ):
2094
+ """Get request dictionary for UpdateModelPackage API for additional inference.
2095
+
2096
+ Args:
2097
+ model_package_arn (str): Arn for the model package.
2098
+ name (str): Name to identify the additional inference specification
2099
+ containers (dict): The Amazon ECR registry path of the Docker image
2100
+ that contains the inference code.
2101
+ description (str): Description for the additional inference specification
2103
+ content_types (list[str]): The supported MIME types
2104
+ for the input data.
2105
+ response_types (list[str]): The supported MIME types
2106
+ for the output data.
2107
+ inference_instances (list[str]): A list of the instance
2108
+ types that are used to generate inferences in real-time (default: None).
2109
+ transform_instances (list[str]): A list of the instance
2110
+ types on which a transformation job can be run or on which an endpoint can be
2111
+ deployed (default: None).
2112
+ """
2113
+
2114
+ request_dict = {}
2115
+ if containers is not None:
2116
+ inference_specification = {
2117
+ "Containers": containers,
2118
+ }
2119
+
2120
+ if name is not None:
2121
+ inference_specification.update({"Name": name})
2122
+
2123
+ if description is not None:
2124
+ inference_specification.update({"Description": description})
2125
+ if content_types is not None:
2126
+ inference_specification.update(
2127
+ {
2128
+ "SupportedContentTypes": content_types,
2129
+ }
2130
+ )
2131
+ if response_types is not None:
2132
+ inference_specification.update(
2133
+ {
2134
+ "SupportedResponseMIMETypes": response_types,
2135
+ }
2136
+ )
2137
+ if inference_instances is not None:
2138
+ inference_specification.update(
2139
+ {
2140
+ "SupportedRealtimeInferenceInstanceTypes": inference_instances,
2141
+ }
2142
+ )
2143
+ if transform_instances is not None:
2144
+ inference_specification.update(
2145
+ {
2146
+ "SupportedTransformInstanceTypes": transform_instances,
2147
+ }
2148
+ )
2149
+ request_dict["AdditionalInferenceSpecificationsToAdd"] = [inference_specification]
2150
+ request_dict.update({"ModelPackageArn": model_package_arn})
2151
+ return request_dict
2152
+
2153
+
2154
+ def get_update_model_package_inference_args(
2155
+ model_package_arn,
2156
+ containers=None,
2157
+ content_types=None,
2158
+ response_types=None,
2159
+ inference_instances=None,
2160
+ transform_instances=None,
2161
+ ):
2162
+ """Get request dictionary for UpdateModelPackage API for inference specification.
2163
+
2164
+ Args:
2165
+ model_package_arn (str): Arn for the model package.
2166
+ containers (dict): The Amazon ECR registry path of the Docker image
2167
+ that contains the inference code.
2168
+ content_types (list[str]): The supported MIME types
2169
+ for the input data.
2170
+ response_types (list[str]): The supported MIME types
2171
+ for the output data.
2172
+ inference_instances (list[str]): A list of the instance
2173
+ types that are used to generate inferences in real-time (default: None).
2174
+ transform_instances (list[str]): A list of the instance
2175
+ types on which a transformation job can be run or on which an endpoint can be
2176
+ deployed (default: None).
2177
+ """
2178
+
2179
+ request_dict = {}
2180
+ if containers is not None:
2181
+ inference_specification = {
2182
+ "Containers": containers,
2183
+ }
2184
+ if content_types is not None:
2185
+ inference_specification.update(
2186
+ {
2187
+ "SupportedContentTypes": content_types,
2188
+ }
2189
+ )
2190
+ if response_types is not None:
2191
+ inference_specification.update(
2192
+ {
2193
+ "SupportedResponseMIMETypes": response_types,
2194
+ }
2195
+ )
2196
+ if inference_instances is not None:
2197
+ inference_specification.update(
2198
+ {
2199
+ "SupportedRealtimeInferenceInstanceTypes": inference_instances,
2200
+ }
2201
+ )
2202
+ if transform_instances is not None:
2203
+ inference_specification.update(
2204
+ {
2205
+ "SupportedTransformInstanceTypes": transform_instances,
2206
+ }
2207
+ )
2208
+ request_dict["InferenceSpecification"] = inference_specification
2209
+ request_dict.update({"ModelPackageArn": model_package_arn})
2210
+ return request_dict
2211
+
2212
+
2213
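A hedged sketch showing how the returned dictionary might be fed to the boto3 client; the ARN and image URI are placeholders, and passing the result to ``update_model_package`` mirrors how these request builders are typically consumed:

args = get_update_model_package_inference_args(
    model_package_arn="arn:aws:sagemaker:us-west-2:111122223333:model-package/example/1",
    containers=[{"Image": "111122223333.dkr.ecr.us-west-2.amazonaws.com/example:latest"}],
    content_types=["application/json"],
    response_types=["application/json"],
)
# e.g. sagemaker_session.sagemaker_client.update_model_package(**args)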
+ def _logs_for_job( # noqa: C901 - suppress complexity warning for this method
2214
+ sagemaker_session, job_name, wait=False, poll=10, log_type="All", timeout=None
2215
+ ):
2216
+ """Display logs for a given training job, optionally tailing them until job is complete.
2217
+
2218
+ If the output is a tty or a Jupyter cell, it will be color-coded
2219
+ based on which instance the log entry is from.
2220
+
2221
+ Args:
2222
+ sagemaker_session (sagemaker.session.Session): A SageMaker Session
2223
+ object, used for SageMaker interactions.
2224
+ job_name (str): Name of the training job to display the logs for.
2225
+ wait (bool): Whether to keep looking for new log entries until the job completes
2226
+ (default: False).
2227
+ poll (int): The interval in seconds between polling for new log entries and job
2228
+ completion (default: 10).
2229
+ log_type (str): Specifies which logs to print. Acceptable
2230
+ strings are "All", "None", "Training", or "Rules". To maintain backwards
2231
+ compatibility, boolean values are also accepted and converted to strings.
2232
+ timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by
2233
+ default.
2234
+ Returns:
2235
+ Last call to sagemaker DescribeTrainingJob
2236
+ Raises:
2237
+ exceptions.CapacityError: If the training job fails with CapacityError.
2238
+ exceptions.UnexpectedStatusException: If waiting and the training job fails.
2239
+ """
2240
+ sagemaker_client = sagemaker_session.sagemaker_client
2241
+ request_end_time = time.time() + timeout if timeout else None
2242
+ description = _wait_until(
2243
+ lambda: sagemaker_client.describe_training_job(TrainingJobName=job_name)
2244
+ )
2245
+ print(secondary_training_status_message(description, None), end="")
2246
+
2247
+ instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init(
2248
+ sagemaker_session.boto_session, description, job="Training"
2249
+ )
2250
+
2251
+ state = _get_initial_job_state(description, "TrainingJobStatus", wait)
2252
+
2253
+ # The loop below implements a state machine that alternates between checking the job status
2254
+ # and reading whatever is available in the logs at this point. Note, that if we were
2255
+ # called with wait == False, we never check the job status.
2256
+ #
2257
+ # If wait == TRUE and job is not completed, the initial state is TAILING
2258
+ # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is
2259
+ # complete).
2260
+ #
2261
+ # The state table:
2262
+ #
2263
+ # STATE ACTIONS CONDITION NEW STATE
2264
+ # ---------------- ---------------- ----------------- ----------------
2265
+ # TAILING Read logs, Pause, Get status Job complete JOB_COMPLETE
2266
+ # Else TAILING
2267
+ # JOB_COMPLETE Read logs, Pause Any COMPLETE
2268
+ # COMPLETE Read logs, Exit N/A
2269
+ #
2270
+ # Notes:
2271
+ # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to
2272
+ # Cloudwatch after the job was marked complete.
2273
+ last_describe_job_call = time.time()
2274
+ last_description = description
2275
+ last_debug_rule_statuses = None
2276
+ last_profiler_rule_statuses = None
2277
+
2278
+ while True:
2279
+ _flush_log_streams(
2280
+ stream_names,
2281
+ instance_count,
2282
+ client,
2283
+ log_group,
2284
+ job_name,
2285
+ positions,
2286
+ dot,
2287
+ color_wrap,
2288
+ )
2289
+ if timeout and time.time() > request_end_time:
2290
+ print("Timeout Exceeded. {} seconds elapsed.".format(timeout))
2291
+ break
2292
+
2293
+ if state == LogState.COMPLETE:
2294
+ break
2295
+
2296
+ time.sleep(poll)
2297
+
2298
+ if state == LogState.JOB_COMPLETE:
2299
+ state = LogState.COMPLETE
2300
+ elif time.time() - last_describe_job_call >= 30:
2301
+ description = sagemaker_client.describe_training_job(TrainingJobName=job_name)
2302
+ last_describe_job_call = time.time()
2303
+
2304
+ if secondary_training_status_changed(description, last_description):
2305
+ print()
2306
+ print(secondary_training_status_message(description, last_description), end="")
2307
+ last_description = description
2308
+
2309
+ status = description["TrainingJobStatus"]
2310
+
2311
+ if status in ("Completed", "Failed", "Stopped"):
2312
+ print()
2313
+ state = LogState.JOB_COMPLETE
2314
+
2315
+ # Print prettified logs related to the status of SageMaker Debugger rules.
2316
+ debug_rule_statuses = description.get("DebugRuleEvaluationStatuses", {})
2317
+ if (
2318
+ debug_rule_statuses
2319
+ and _rule_statuses_changed(debug_rule_statuses, last_debug_rule_statuses)
2320
+ and (log_type in {"All", "Rules"})
2321
+ ):
2322
+ for status in debug_rule_statuses:
2323
+ rule_log = (
2324
+ f"{status['RuleConfigurationName']}: {status['RuleEvaluationStatus']}"
2325
+ )
2326
+ print(rule_log)
2327
+
2328
+ last_debug_rule_statuses = debug_rule_statuses
2329
+
2330
+ # Print prettified logs related to the status of SageMaker Profiler rules.
2331
+ profiler_rule_statuses = description.get("ProfilerRuleEvaluationStatuses", {})
2332
+ if (
2333
+ profiler_rule_statuses
2334
+ and _rule_statuses_changed(profiler_rule_statuses, last_profiler_rule_statuses)
2335
+ and (log_type in {"All", "Rules"})
2336
+ ):
2337
+ for status in profiler_rule_statuses:
2338
+ rule_log = (
2339
+ f"{status['RuleConfigurationName']}: {status['RuleEvaluationStatus']}"
2340
+ )
2341
+ print(rule_log)
2342
+
2343
+ last_profiler_rule_statuses = profiler_rule_statuses
2344
+
2345
+ if wait:
2346
+ _check_job_status(job_name, description, "TrainingJobStatus")
2347
+ if dot:
2348
+ print()
2349
+ # Customers are not billed for hardware provisioning, so billable time is less than
2350
+ # total time
2351
+ training_time = description.get("TrainingTimeInSeconds")
2352
+ billable_time = description.get("BillableTimeInSeconds")
2353
+ if training_time is not None:
2354
+ print("Training seconds:", training_time * instance_count)
2355
+ if billable_time is not None:
2356
+ print("Billable seconds:", billable_time * instance_count)
2357
+ if description.get("EnableManagedSpotTraining"):
2358
+ saving = (1 - float(billable_time) / training_time) * 100
2359
+ print("Managed Spot Training savings: {:.1f}%".format(saving))
2360
+ return last_description
2361
+
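The state table embedded in ``_logs_for_job`` above reads as a small state machine. The following self-contained sketch (not part of the package; the AWS calls are stubbed out with plain callables) walks the same TAILING -> JOB_COMPLETE -> COMPLETE progression, including the extra read after the job is marked complete.

import enum
import time

class LogState(enum.Enum):
    TAILING = 1       # job still running, keep reading logs
    JOB_COMPLETE = 2  # job finished, do one more read for late events
    COMPLETE = 3      # final read done, exit

def tail(job_is_complete, read_logs, poll=1):
    state = LogState.TAILING
    while True:
        read_logs()                          # "Read logs" action from the table
        if state is LogState.COMPLETE:
            break                            # COMPLETE: read logs, exit
        time.sleep(poll)                     # "Pause" action
        if state is LogState.JOB_COMPLETE:
            state = LogState.COMPLETE        # always one extra pass after completion
        elif job_is_complete():
            state = LogState.JOB_COMPLETE    # "Get status" found the job complete

# Toy usage: the job "completes" on the third status check.
ticks = iter([False, False, True])
tail(lambda: next(ticks), lambda: print("flush logs"), poll=0)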
2362
+
2363
+ def _check_job_status(job, desc, status_key_name):
2364
+ """Check to see if the job completed successfully.
2365
+
2366
+ If not, construct and raise an ``exceptions.UnexpectedStatusException``.
2367
+
2368
+ Args:
2369
+ job (str): The name of the job to check.
2370
+ desc (dict[str, str]): The result of ``describe_training_job()``.
2371
+ status_key_name (str): Status key name to check for.
2372
+
2373
+ Raises:
2374
+ exceptions.CapacityError: If the training job fails with CapacityError.
2375
+ exceptions.UnexpectedStatusException: If the training job fails.
2376
+ """
2377
+ status = desc[status_key_name]
2378
+ # If the status is uppercase (e.g. "COMPLETED"), convert it to CamelCase via the status table
2379
+ status = _STATUS_CODE_TABLE.get(status, status)
2380
+
2381
+ if status == "Stopped":
2382
+ logger.warning(
2383
+ "Job ended with status 'Stopped' rather than 'Completed'. "
2384
+ "This could mean the job timed out or stopped early for some other reason: "
2385
+ "Consider checking whether it completed as you expect."
2386
+ )
2387
+ elif status != "Completed":
2388
+ reason = desc.get("FailureReason", "(No reason provided)")
2389
+ job_type = status_key_name.replace("JobStatus", " job")
2390
+ troubleshooting = (
2391
+ "https://docs.aws.amazon.com/sagemaker/latest/dg/"
2392
+ "sagemaker-python-sdk-troubleshooting.html"
2393
+ )
2394
+ message = (
2395
+ "Error for {job_type} {job_name}: {status}. Reason: {reason}. "
2396
+ "Check troubleshooting guide for common errors: {troubleshooting}"
2397
+ ).format(
2398
+ job_type=job_type,
2399
+ job_name=job,
2400
+ status=status,
2401
+ reason=reason,
2402
+ troubleshooting=troubleshooting,
2403
+ )
2404
+ if "CapacityError" in str(reason):
2405
+ raise exceptions.CapacityError(
2406
+ message=message,
2407
+ allowed_statuses=["Completed", "Stopped"],
2408
+ actual_status=status,
2409
+ )
2410
+ raise exceptions.UnexpectedStatusException(
2411
+ message=message,
2412
+ allowed_statuses=["Completed", "Stopped"],
2413
+ actual_status=status,
2414
+ )
2415
+
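A hedged sketch of reacting to the exceptions ``_check_job_status`` raises. The ``sagemaker.core.exceptions`` module appears in this package's file list; that it exports ``CapacityError`` and ``UnexpectedStatusException`` is inferred from their use above, the helper's module path is an assumption, and the describe response here is hand-written for illustration.

# Assumed import paths for this package; adjust if these helpers live elsewhere.
from sagemaker.core import exceptions
from sagemaker.core.helper.session_helper import _check_job_status  # assumed module path

desc = {"TrainingJobStatus": "Failed", "FailureReason": "AlgorithmError: bad hyperparameter"}
try:
    _check_job_status("my-training-job", desc, "TrainingJobStatus")
except exceptions.CapacityError:
    print("No capacity right now; retry later or with another instance type.")
except exceptions.UnexpectedStatusException as err:
    print("Job did not complete:", err)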
2416
+
2417
+ def _logs_init(boto_session, description, job):
2418
+ """Placeholder docstring"""
2419
+ if job == "Training":
2420
+ if "InstanceGroups" in description["ResourceConfig"]:
2421
+ instance_count = 0
2422
+ for instanceGroup in description["ResourceConfig"]["InstanceGroups"]:
2423
+ instance_count += instanceGroup["InstanceCount"]
2424
+ else:
2425
+ instance_count = description["ResourceConfig"]["InstanceCount"]
2426
+ elif job == "Transform":
2427
+ instance_count = description["TransformResources"]["InstanceCount"]
2428
+ elif job == "Processing":
2429
+ instance_count = description["ProcessingResources"]["ClusterConfig"]["InstanceCount"]
2430
+ elif job == "AutoML":
2431
+ instance_count = 0
2432
+
2433
+ stream_names = [] # The list of log streams
2434
+ positions = {} # The current position in each stream, map of stream name -> position
2435
+
2436
+ # Increase retries allowed (from default of 4), as we don't want waiting for a training job
2437
+ # to be interrupted by a transient exception.
2438
+ config = botocore.config.Config(retries={"max_attempts": 15})
2439
+ client = boto_session.client("logs", config=config)
2440
+ log_group = "/aws/sagemaker/" + job + "Jobs"
2441
+
2442
+ dot = False
2443
+
2444
+ color_wrap = sagemaker.core.logs.ColorWrap()
2445
+
2446
+ return instance_count, stream_names, positions, client, log_group, dot, color_wrap
2447
+
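The retry bump in ``_logs_init`` is a plain botocore pattern; a standalone illustration (region name is illustrative):

import boto3
import botocore.config

# Raise max_attempts from the default of 4 so a transient CloudWatch error does not
# interrupt a long wait on a training job.
config = botocore.config.Config(retries={"max_attempts": 15})
logs_client = boto3.Session(region_name="us-east-1").client("logs", config=config)
log_group = "/aws/sagemaker/TrainingJobs"  # the group _logs_init builds for job="Training"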
2448
+
2449
+ def _flush_log_streams(
2450
+ stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap
2451
+ ):
2452
+ """Placeholder docstring"""
2453
+ if len(stream_names) < instance_count:
2454
+ # Log streams are created whenever a container starts writing to stdout/err, so this list
2455
+ # may be dynamic until we have a stream for every instance.
2456
+ try:
2457
+ streams = client.describe_log_streams(
2458
+ logGroupName=log_group,
2459
+ logStreamNamePrefix=job_name + "/",
2460
+ orderBy="LogStreamName",
2461
+ limit=min(instance_count, 50),
2462
+ )
2463
+ stream_names = [s["logStreamName"] for s in streams["logStreams"]]
2464
+
2465
+ while "nextToken" in streams:
2466
+ streams = client.describe_log_streams(
2467
+ logGroupName=log_group,
2468
+ logStreamNamePrefix=job_name + "/",
2469
+ orderBy="LogStreamName",
2470
+ nextToken=streams["nextToken"],
+ limit=50,
2471
+ )
2472
+
2473
+ stream_names.extend([s["logStreamName"] for s in streams["logStreams"]])
2474
+
2475
+ positions.update(
2476
+ [
2477
+ (s, sagemaker.core.logs.Position(timestamp=0, skip=0))
2478
+ for s in stream_names
2479
+ if s not in positions
2480
+ ]
2481
+ )
2482
+ except ClientError as e:
2483
+ # On the very first training job run on an account, there's no log group until
2484
+ # the container starts logging, so ignore any errors thrown about that
2485
+ err = e.response.get("Error", {})
2486
+ if err.get("Code", None) != "ResourceNotFoundException":
2487
+ raise
2488
+
2489
+ if len(stream_names) > 0:
2490
+ if dot:
2491
+ print("")
2492
+ dot = False
2493
+ for idx, event in sagemaker.core.logs.multi_stream_iter(
2494
+ client, log_group, stream_names, positions
2495
+ ):
2496
+ color_wrap(idx, event["message"])
2497
+ ts, count = positions[stream_names[idx]]
2498
+ if event["timestamp"] == ts:
2499
+ positions[stream_names[idx]] = sagemaker.core.logs.Position(
2500
+ timestamp=ts, skip=count + 1
2501
+ )
2502
+ else:
2503
+ positions[stream_names[idx]] = sagemaker.core.logs.Position(
2504
+ timestamp=event["timestamp"], skip=1
2505
+ )
2506
+ else:
2507
+ dot = True
2508
+ print(".", end="")
2509
+ sys.stdout.flush()
2510
+
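The per-stream position bookkeeping at the end of ``_flush_log_streams`` is the part that is easiest to get wrong; here is a self-contained sketch with a local ``Position`` stand-in for ``sagemaker.core.logs.Position`` (which, judging from its use above, is a (timestamp, skip) pair).

from collections import namedtuple

Position = namedtuple("Position", ["timestamp", "skip"])  # local stand-in

positions = {"my-training-job/algo-1-169": Position(timestamp=0, skip=0)}

def advance(stream_name, event):
    """Record that `event` has been printed so the next read can skip past it."""
    ts, skip = positions[stream_name]
    if event["timestamp"] == ts:
        # Another event at the same millisecond timestamp: skip one more next time.
        positions[stream_name] = Position(timestamp=ts, skip=skip + 1)
    else:
        positions[stream_name] = Position(timestamp=event["timestamp"], skip=1)

advance("my-training-job/algo-1-169", {"timestamp": 1700000000000, "message": "step 1"})
print(positions["my-training-job/algo-1-169"])  # Position(timestamp=1700000000000, skip=1)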
2511
+
2512
+ def _wait_until(callable_fn, poll=5):
2513
+ """Placeholder docstring"""
2514
+ elapsed_time = 0
2515
+ result = None
2516
+ while result is None:
2517
+ try:
2518
+ elapsed_time += poll
2519
+ time.sleep(poll)
2520
+ result = callable_fn()
2521
+ except botocore.exceptions.ClientError as err:
2522
+ # For the first 5 minutes, tolerate AccessDeniedException.
2523
+ # This gives resource tags time to propagate, avoiding false AccessDenied errors when
2524
+ # the access policy is tag-based. The caveat is that a genuine AccessDenied case
2525
+ # will only surface as a failure after 5 minutes.
2526
+ if err.response["Error"]["Code"] == "AccessDeniedException" and elapsed_time <= 300:
2527
+ logger.warning(
2528
+ "Received AccessDeniedException. This could mean the IAM role does not "
2529
+ "have the resource permissions, in which case please add resource access "
2530
+ "and retry. For cases where the role has tag based resource policy, "
2531
+ "continuing to wait for tag propagation.."
2532
+ )
2533
+ continue
2534
+ raise err
2535
+ return result
2536
+
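A hedged usage sketch for ``_wait_until``: poll ``DescribeTrainingJob`` until it succeeds, leaning on the AccessDenied grace period handled inside the helper. The boto3 call is standard; importing the private helper assumes this module path, and the job name and region are illustrative.

import boto3

from sagemaker.core.helper.session_helper import _wait_until  # assumed module path

sm = boto3.client("sagemaker", region_name="us-east-1")
description = _wait_until(
    lambda: sm.describe_training_job(TrainingJobName="my-training-job"),
    poll=5,
)
print(description["TrainingJobStatus"])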
2537
+
2538
+ def _get_initial_job_state(description, status_key, wait):
2539
+ """Placeholder docstring"""
2540
+ status = description[status_key]
2541
+ job_already_completed = status in ("Completed", "Failed", "Stopped")
2542
+ return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
2543
+
2544
+
2545
+ def _rule_statuses_changed(current_statuses, last_statuses):
2546
+ """Checks the rule evaluation statuses for SageMaker Debugger and Profiler rules."""
2547
+ if not last_statuses:
2548
+ return True
2549
+
2550
+ for current, last in zip(current_statuses, last_statuses):
2551
+ if (current["RuleConfigurationName"] == last["RuleConfigurationName"]) and (
2552
+ current["RuleEvaluationStatus"] != last["RuleEvaluationStatus"]
2553
+ ):
2554
+ return True
2555
+
2556
+ return False
2557
+
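A quick self-check of the change detection with hand-written status dicts (the rule name is illustrative; the private import assumes this module path):

from sagemaker.core.helper.session_helper import _rule_statuses_changed  # assumed path

last = [{"RuleConfigurationName": "VanishingGradient", "RuleEvaluationStatus": "InProgress"}]
curr = [{"RuleConfigurationName": "VanishingGradient", "RuleEvaluationStatus": "IssuesFound"}]

assert _rule_statuses_changed(curr, last)        # a rule's status flipped
assert not _rule_statuses_changed(curr, curr)    # nothing changed
assert _rule_statuses_changed(curr, None)        # first observation always reports a change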
2558
+
2559
+ def update_args(args: Dict[str, Any], **kwargs):
2560
+ """Updates the request arguments dict with the value if populated.
2561
+
2562
+ This handles the case where the service API does not accept ``None`` values for arguments.
2563
+
2564
+ Args:
2565
+ args (Dict[str, Any]): the request arguments dict
2566
+ kwargs: key, value pairs to update the args dict
2567
+ """
2568
+ for key, value in kwargs.items():
2569
+ if value is not None:
2570
+ args.update({key: value})
2571
+
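``update_args`` only copies keys whose values are not ``None``, which keeps request dicts free of null fields the service APIs reject; a small illustration (names are illustrative, the private import assumes this module path):

from sagemaker.core.helper.session_helper import update_args  # assumed module path

request_args = {"EndpointConfigName": "my-endpoint-config"}
update_args(
    request_args,
    KmsKeyId=None,                                   # dropped: value is None
    Tags=[{"Key": "team", "Value": "ml-platform"}],  # kept
)
print(request_args)
# {'EndpointConfigName': 'my-endpoint-config', 'Tags': [{'Key': 'team', 'Value': 'ml-platform'}]}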
2572
+
2573
+ def production_variant(
2574
+ model_name=None,
2575
+ instance_type=None,
2576
+ initial_instance_count=None,
2577
+ variant_name="AllTraffic",
2578
+ initial_weight=1,
2579
+ accelerator_type=None,
2580
+ serverless_inference_config=None,
2581
+ volume_size=None,
2582
+ model_data_download_timeout=None,
2583
+ container_startup_health_check_timeout=None,
2584
+ managed_instance_scaling=None,
2585
+ routing_config=None,
2586
+ inference_ami_version=None,
2587
+ ):
2588
+ """Create a production variant description suitable for use in a ``ProductionVariant`` list.
2589
+
2590
+ This is also part of a ``CreateEndpointConfig`` request.
2591
+
2592
+ Args:
2593
+ model_name (str): The name of the SageMaker model this production variant references.
2594
+ instance_type (str): The EC2 instance type for this production variant. For example,
2595
+ 'ml.c4.8xlarge'.
2596
+ initial_instance_count (int): The initial instance count for this production variant
2597
+ (default: None; treated as 1 when no serverless config is given).
2598
+ variant_name (str): The ``VariantName`` of this production variant
2599
+ (default: 'AllTraffic').
2600
+ initial_weight (int): The relative ``InitialVariantWeight`` of this production variant
2601
+ (default: 1).
2602
+ accelerator_type (str): Type of Elastic Inference accelerator for this production variant.
2603
+ For example, 'ml.eia1.medium'.
2604
+ For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
2605
+ serverless_inference_config (dict): Specifies configuration dict related to serverless
2606
+ endpoint. The dict is converted from sagemaker.model_monitor.ServerlessInferenceConfig
2607
+ object (default: None)
2608
+ volume_size (int): The size, in GB, of the ML storage volume attached to each
2609
+ inference instance associated with the production variant. Currently only Amazon EBS
2610
+ gp2 storage volumes are supported.
2611
+ model_data_download_timeout (int): The timeout value, in seconds, to download and extract
2612
+ model data from Amazon S3 to the individual inference instance associated with this
2613
+ production variant.
2614
+ container_startup_health_check_timeout (int): The timeout value, in seconds, for your
2615
+ inference container to pass health check by SageMaker Hosting. For more information
2616
+ about health checks see:
2617
+ https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
2618
+ managed_instance_scaling (dict): Managed instance scaling configuration for the
+ production variant (default: None).
2619
+ routing_config (dict): Request routing configuration for the production variant
+ (default: None).
+ inference_ami_version (str): The inference AMI version to use for instances in the
+ production variant (default: None).
2620
+
2621
+ Returns:
2622
+ dict[str, str]: A SageMaker ``ProductionVariant`` description
2623
+ """
2624
+ production_variant_configuration = {
2625
+ "VariantName": variant_name,
2626
+ }
2627
+ if model_name:
2628
+ production_variant_configuration["ModelName"] = model_name
2629
+ production_variant_configuration["InitialVariantWeight"] = initial_weight
2630
+
2631
+ if accelerator_type:
2632
+ production_variant_configuration["AcceleratorType"] = accelerator_type
2633
+
2634
+ if managed_instance_scaling:
2635
+ production_variant_configuration["ManagedInstanceScaling"] = managed_instance_scaling
2636
+
2637
+ if serverless_inference_config:
2638
+ production_variant_configuration["ServerlessConfig"] = serverless_inference_config
2639
+ else:
2640
+ initial_instance_count = initial_instance_count or 1
2641
+ production_variant_configuration["InitialInstanceCount"] = initial_instance_count
2642
+ production_variant_configuration["InstanceType"] = instance_type
2643
+ update_args(
2644
+ production_variant_configuration,
2645
+ VolumeSizeInGB=volume_size,
2646
+ ModelDataDownloadTimeoutInSeconds=model_data_download_timeout,
2647
+ ContainerStartupHealthCheckTimeoutInSeconds=container_startup_health_check_timeout,
2648
+ RoutingConfig=routing_config,
2649
+ )
2650
+
2651
+ if inference_ami_version:
2652
+ production_variant_configuration["InferenceAmiVersion"] = inference_ami_version
2653
+
2654
+ return production_variant_configuration
2655
+
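A hedged end-to-end sketch: build a variant description with ``production_variant`` and hand it to ``CreateEndpointConfig`` through boto3. The model, config, and region names are illustrative, the model is assumed to already exist, and the private import assumes this module path.

import boto3

from sagemaker.core.helper.session_helper import production_variant  # assumed path

variant = production_variant(
    model_name="my-model",
    instance_type="ml.m5.xlarge",
    initial_instance_count=1,
    volume_size=30,                 # GB of EBS gp2 storage per instance
)

sm = boto3.client("sagemaker", region_name="us-east-1")
sm.create_endpoint_config(
    EndpointConfigName="my-endpoint-config",
    ProductionVariants=[variant],
)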
2656
+
2657
+ def _has_permission_for_live_logging(boto_session, endpoint_name) -> bool:
2658
+ """Validate if customer's role has the right permission to access logs from CloudWatch"""
2659
+ try:
2660
+ cloudwatch_client = boto_session.client("logs")
2661
+ cloudwatch_client.filter_log_events(
2662
+ logGroupName=f"/aws/sagemaker/Endpoints/{endpoint_name}",
2663
+ logStreamNamePrefix="AllTraffic/",
2664
+ )
2665
+ return True
2666
+ except ClientError as e:
2667
+ if e.response["Error"]["Code"] == "AccessDeniedException":
2668
+ LOGGER.warning(
2669
+ ("Failed to enable live logging: %s. Fallback to default logging..."),
2670
+ e,
2671
+ )
2672
+
2673
+ return False
2674
+ return True
2675
+
2676
+
2677
+ def _deploy_done(sagemaker_client, endpoint_name):
2678
+ """Placeholder docstring"""
2679
+ hosting_status_codes = {
2680
+ "OutOfService": "x",
2681
+ "Creating": "-",
2682
+ "Updating": "-",
2683
+ "InService": "!",
2684
+ "RollingBack": "<",
2685
+ "Deleting": "o",
2686
+ "Failed": "*",
2687
+ }
2688
+ in_progress_statuses = ["Creating", "Updating"]
2689
+
2690
+ desc = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
2691
+ status = desc["EndpointStatus"]
2692
+
2693
+ print(hosting_status_codes.get(status, "?"), end="")
2694
+ sys.stdout.flush()
2695
+
2696
+ return None if status in in_progress_statuses else desc
2697
+
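``_deploy_done`` pairs naturally with ``_wait_until`` defined earlier in this module; a hedged sketch (endpoint name and region are illustrative, and the private imports assume this module path):

import functools

import boto3

from sagemaker.core.helper.session_helper import _deploy_done, _wait_until  # assumed path

sm = boto3.client("sagemaker", region_name="us-east-1")
desc = _wait_until(functools.partial(_deploy_done, sm, "my-endpoint"), poll=30)
print()  # finish the line of one-character status codes
print(desc["EndpointStatus"])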
2698
+
2699
+ def _live_logging_deploy_done(sagemaker_client, endpoint_name, paginator, paginator_config, poll):
2700
+ """Placeholder docstring"""
2701
+ stop = False
2702
+ endpoint_status = None
2703
+ try:
2704
+ desc = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
2705
+ endpoint_status = desc["EndpointStatus"]
2706
+ except ClientError as e:
2707
+ if e.response["Error"]["Code"] == "ValidationException":
2708
+ LOGGER.debug("Waiting for endpoint to become visible")
2709
+ return None
2710
+ raise e
2711
+
2712
+ try:
2713
+ # If the endpoint has left the Creating state, do one final log flush and stop; otherwise sleep, then flush
2714
+ if endpoint_status != "Creating":
2715
+ stop = True
2716
+ if endpoint_status == "InService":
2717
+ LOGGER.info("Created endpoint with name %s. Waiting for it to be InService", endpoint_name)
2718
+ else:
2719
+ time.sleep(poll)
2720
+
2721
+ pages = paginator.paginate(
2722
+ logGroupName=f"/aws/sagemaker/Endpoints/{endpoint_name}",
2723
+ logStreamNamePrefix="AllTraffic/",
2724
+ PaginationConfig=paginator_config,
2725
+ )
2726
+
2727
+ for page in pages:
2728
+ if "nextToken" in page:
2729
+ paginator_config["StartingToken"] = page["nextToken"]
2730
+ for event in page["events"]:
2731
+ LOGGER.info(event["message"])
2732
+ else:
2733
+ LOGGER.debug("No log events available")
2734
+
2735
+ # if stop is true -> return the describe response and stop polling
2736
+ if stop:
2737
+ return desc
2738
+ except ClientError as e:
2739
+ if e.response["Error"]["Code"] == "ResourceNotFoundException":
2740
+ LOGGER.debug("Waiting for endpoint log group to appear")
2741
+ return None
2742
+ raise e
2743
+
2744
+ return None
2745
+
2746
+
2747
+ def _deployment_entity_exists(describe_fn):
2748
+ """Placeholder docstring"""
2749
+ try:
2750
+ describe_fn()
2751
+ return True
2752
+ except ClientError as ce:
2753
+ error_code = ce.response["Error"]["Code"]
2754
+ if not (
2755
+ error_code == "ValidationException"
2756
+ and "Could not find" in ce.response["Error"]["Message"]
2757
+ ):
2758
+ raise ce
2759
+ return False
2760
+
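A hedged sketch of using ``_deployment_entity_exists`` to check for an existing endpoint config before creating one (names and region are illustrative; the private import assumes this module path):

import boto3

from sagemaker.core.helper.session_helper import _deployment_entity_exists  # assumed path

sm = boto3.client("sagemaker", region_name="us-east-1")
config_exists = _deployment_entity_exists(
    lambda: sm.describe_endpoint_config(EndpointConfigName="my-endpoint-config")
)
print("config already exists" if config_exists else "safe to create the config")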
2761
+
2762
+ def get_log_events_for_inference_recommender(cw_client, log_group_name, log_stream_name):
2763
+ """Retrieves log events from the specified CloudWatch log group and log stream.
2764
+
2765
+ Args:
2766
+ cw_client (boto3.client): A boto3 CloudWatch client.
2767
+ log_group_name (str): The name of the CloudWatch log group.
2768
+ log_stream_name (str): The name of the CloudWatch log stream.
2769
+
2770
+ Returns:
2771
+ (dict): A dictionary containing log events from CloudWatch log group and log stream.
2772
+ """
2773
+ print("Fetching logs from CloudWatch...", flush=True)
2774
+ for _ in retries(
2775
+ max_retry_count=30, # 30*10 = 5min
2776
+ exception_message_prefix="Waiting for cloudwatch stream to appear. ",
2777
+ seconds_to_sleep=10,
2778
+ ):
2779
+ try:
2780
+ return cw_client.get_log_events(
2781
+ logGroupName=log_group_name, logStreamName=log_stream_name
2782
+ )
2783
+ except ClientError as e:
2784
+ if e.response["Error"]["Code"] == "ResourceNotFoundException":
2785
+ pass
2786
+
2787
+
2788
+ def _describe_inference_recommendations_job_status(sagemaker_client, job_name: str):
2789
+ """Describes the status of a job and returns the job description.
2790
+
2791
+ Args:
2792
+ sagemaker_client (boto3.client.sagemaker): A SageMaker client.
2793
+ job_name (str): The name of the job.
2794
+
2795
+ Returns:
2796
+ dict: The job description, or None if the job is still in progress.
2797
+ """
2798
+ inference_recommendations_job_status_codes = {
2799
+ "PENDING": ".",
2800
+ "IN_PROGRESS": ".",
2801
+ "COMPLETED": "!",
2802
+ "FAILED": "*",
2803
+ "STOPPING": "_",
2804
+ "STOPPED": "s",
2805
+ }
2806
+ in_progress_statuses = {"PENDING", "IN_PROGRESS", "STOPPING"}
2807
+
2808
+ desc = sagemaker_client.describe_inference_recommendations_job(JobName=job_name)
2809
+ status = desc["Status"]
2810
+
2811
+ print(inference_recommendations_job_status_codes.get(status, "?"), end="", flush=True)
2812
+
2813
+ if status in in_progress_statuses:
2814
+ return None
2815
+
2816
+ print("")
2817
+ return desc
2818
+
2819
+
2820
+ def _display_inference_recommendations_job_steps_status(
2821
+ sagemaker_session, sagemaker_client, job_name: str, poll: int = 60
2822
+ ):
2823
+ """Placeholder docstring"""
2824
+ cloudwatch_client = sagemaker_session.boto_session.client("logs")
2825
+ in_progress_statuses = {"PENDING", "IN_PROGRESS", "STOPPING"}
2826
+ log_group_name = "/aws/sagemaker/InferenceRecommendationsJobs"
2827
+ log_stream_name = job_name + "/execution"
2828
+
2829
+ initial_logs_batch = get_log_events_for_inference_recommender(
2830
+ cloudwatch_client, log_group_name, log_stream_name
2831
+ )
2832
+ print(f"Retrieved logStream: {log_stream_name} from logGroup: {log_group_name}", flush=True)
2833
+ events = initial_logs_batch["events"]
2834
+ print(*[event["message"] for event in events], sep="\n", flush=True)
2835
+
2836
+ next_forward_token = initial_logs_batch["nextForwardToken"] if events else None
2837
+ flush_remaining = True
2838
+ while True:
2839
+ logs_batch = (
2840
+ cloudwatch_client.get_log_events(
2841
+ logGroupName=log_group_name,
2842
+ logStreamName=log_stream_name,
2843
+ nextToken=next_forward_token,
2844
+ )
2845
+ if next_forward_token
2846
+ else cloudwatch_client.get_log_events(
2847
+ logGroupName=log_group_name, logStreamName=log_stream_name
2848
+ )
2849
+ )
2850
+
2851
+ events = logs_batch["events"]
2852
+
2853
+ desc = sagemaker_client.describe_inference_recommendations_job(JobName=job_name)
2854
+ status = desc["Status"]
2855
+
2856
+ if not events:
2857
+ if status in in_progress_statuses:
2858
+ time.sleep(poll)
2859
+ continue
2860
+ if flush_remaining:
2861
+ flush_remaining = False
2862
+ time.sleep(poll)
2863
+ continue
2864
+
2865
+ next_forward_token = logs_batch["nextForwardToken"]
2866
+ print(*[event["message"] for event in events], sep="\n", flush=True)
2867
+
2868
+ if status not in in_progress_statuses:
2869
+ break
2870
+
2871
+ time.sleep(poll)
2872
+
2873
+
2874
+ def _optimization_job_status(sagemaker_client, job_name):
2875
+ """Placeholder docstring"""
2876
+ optimization_job_status_codes = {
2877
+ "INPROGRESS": ".",
2878
+ "COMPLETED": "!",
2879
+ "FAILED": "*",
2880
+ "STARTING": ".",
2881
+ "STOPPING": "_",
2882
+ "STOPPED": "s",
2883
+ }
2884
+ in_progress_statuses = ["INPROGRESS", "STARTING", "STOPPING"]
2885
+
2886
+ desc = sagemaker_client.describe_optimization_job(OptimizationJobName=job_name)
2887
+ status = desc["OptimizationJobStatus"]
2888
+
2889
+ print(optimization_job_status_codes.get(status, "?"), end="")
2890
+ sys.stdout.flush()
2891
+
2892
+ if status in in_progress_statuses:
2893
+ return None
2894
+
2895
+ print("")
2896
+ return desc
2897
+
2898
+
2899
+ def container_def(
2900
+ image_uri,
2901
+ model_data_url=None,
2902
+ env=None,
2903
+ container_mode=None,
2904
+ image_config=None,
2905
+ accept_eula=None,
2906
+ additional_model_data_sources=None,
2907
+ model_reference_arn=None,
2908
+ ):
2909
+ """Create a definition for executing a container as part of a SageMaker model.
2910
+
2911
+ Args:
2912
+ image_uri (str): Docker image URI to run for this container.
2913
+ model_data_url (str or dict[str, Any]): S3 location of model data required by this
2914
+ container, e.g. SageMaker training job model artifacts. It can either be a string
2915
+ representing S3 URI of model data, or a dictionary representing a
2916
+ ``ModelDataSource`` object. (default: None).
2917
+ env (dict[str, str]): Environment variables to set inside the container (default: None).
2918
+ container_mode (str): The model container mode. Valid modes:
2919
+ * MultiModel: Indicates that model container can support hosting multiple models
2920
+ * SingleModel: Indicates that model container can support hosting a single model
2921
+ This is the default model container mode when container_mode = None
2922
+ image_config (dict[str, str]): Specifies whether the image of model container is pulled
2923
+ from ECR, or private registry in your VPC. By default it is set to pull model
2924
+ container image from ECR. (default: None).
2925
+ accept_eula (bool): For models that require a Model Access Config, specify True or
2926
+ False to indicate whether model terms of use have been accepted.
2927
+ The `accept_eula` value must be explicitly defined as `True` in order to
2928
+ accept the end-user license agreement (EULA) that some
2929
+ models require. (Default: None).
2930
+ additional_model_data_sources (PipelineVariable or dict): Additional location
2931
+ of SageMaker model data (default: None).
+ model_reference_arn (str): Hub content ARN of a model reference; when set, it is
+ recorded as the ``HubAccessConfig`` of the model data source (default: None).
2932
+
2933
+ Returns:
2934
+ dict[str, str]: A complete container definition object usable with the CreateModel API if
2935
+ passed via the `PrimaryContainer` field.
2936
+ """
2937
+ if env is None:
2938
+ env = {}
2939
+ c_def = {"Image": image_uri, "Environment": env}
2940
+
2941
+ if additional_model_data_sources:
2942
+ c_def["AdditionalModelDataSources"] = additional_model_data_sources
2943
+
2944
+ if isinstance(model_data_url, str) and (
2945
+ not (model_data_url.startswith("s3://") and model_data_url.endswith("tar.gz"))
2946
+ or accept_eula is None
2947
+ ):
2948
+ c_def["ModelDataUrl"] = model_data_url
2949
+
2950
+ elif isinstance(model_data_url, (dict, str)):
2951
+ if isinstance(model_data_url, dict):
2952
+ c_def["ModelDataSource"] = model_data_url
2953
+ else:
2954
+ c_def["ModelDataSource"] = {
2955
+ "S3DataSource": {
2956
+ "S3Uri": model_data_url,
2957
+ "S3DataType": "S3Object",
2958
+ "CompressionType": "Gzip",
2959
+ }
2960
+ }
2961
+ if accept_eula is not None:
2962
+ c_def["ModelDataSource"]["S3DataSource"]["ModelAccessConfig"] = {
2963
+ "AcceptEula": accept_eula
2964
+ }
2965
+ if model_reference_arn:
2966
+ c_def["ModelDataSource"]["S3DataSource"]["HubAccessConfig"] = {
2967
+ "HubContentArn": model_reference_arn
2968
+ }
2969
+
2970
+ elif model_data_url is not None:
2971
+ c_def["ModelDataUrl"] = model_data_url
2972
+
2973
+ if container_mode:
2974
+ c_def["Mode"] = container_mode
2975
+ if image_config:
2976
+ c_def["ImageConfig"] = image_config
2977
+ return c_def
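A hedged sketch of ``container_def`` producing a ``ModelDataSource``-style definition (the EULA path, since ``accept_eula`` is set) and passing it to ``CreateModel`` via boto3. The image URI, S3 URI, role ARN, environment variable, and names are illustrative, and the private import assumes this module path.

import boto3

from sagemaker.core.helper.session_helper import container_def  # assumed module path

container = container_def(
    image_uri="123456789012.dkr.ecr.us-east-1.amazonaws.com/my-inference-image:latest",
    model_data_url="s3://my-bucket/models/my-model/model.tar.gz",
    env={"SAGEMAKER_PROGRAM": "inference.py"},  # illustrative environment variable
    accept_eula=True,  # routes the S3 URI through ModelDataSource with an accepted EULA
)

sm = boto3.client("sagemaker", region_name="us-east-1")
sm.create_model(
    ModelName="my-model",
    PrimaryContainer=container,
    ExecutionRoleArn="arn:aws:iam::123456789012:role/MySageMakerExecutionRole",
)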