sagemaker-core 1.0.62__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (362)
  1. sagemaker/core/__init__.py +16 -0
  2. sagemaker/core/_studio.py +116 -0
  3. sagemaker/core/_version.py +11 -0
  4. sagemaker/core/accept_types.py +131 -0
  5. sagemaker/core/analytics.py +744 -0
  6. sagemaker/core/apiutils/__init__.py +13 -0
  7. sagemaker/core/apiutils/_base_types.py +228 -0
  8. sagemaker/core/apiutils/_boto_functions.py +130 -0
  9. sagemaker/core/apiutils/_utils.py +34 -0
  10. sagemaker/core/base_deserializers.py +35 -0
  11. sagemaker/core/base_serializers.py +35 -0
  12. sagemaker/core/clarify/__init__.py +2898 -0
  13. sagemaker/core/collection.py +467 -0
  14. sagemaker/core/common_utils.py +2281 -0
  15. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  16. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  17. sagemaker/core/config/__init__.py +181 -0
  18. sagemaker/core/config/config.py +238 -0
  19. sagemaker/core/config/config_manager.py +595 -0
  20. sagemaker/core/config/config_schema.py +1220 -0
  21. sagemaker/core/config/config_utils.py +297 -0
  22. {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
  23. sagemaker/core/constants.py +73 -0
  24. sagemaker/core/content_types.py +137 -0
  25. sagemaker/core/debugger/__init__.py +39 -0
  26. sagemaker/core/debugger/debugger.py +945 -0
  27. sagemaker/core/debugger/framework_profile.py +292 -0
  28. sagemaker/core/debugger/metrics_config.py +468 -0
  29. sagemaker/core/debugger/profiler.py +42 -0
  30. sagemaker/core/debugger/profiler_config.py +190 -0
  31. sagemaker/core/debugger/profiler_constants.py +40 -0
  32. sagemaker/core/debugger/utils.py +148 -0
  33. sagemaker/core/deprecations.py +254 -0
  34. sagemaker/core/deserializers/__init__.py +10 -0
  35. sagemaker/core/deserializers/base.py +424 -0
  36. sagemaker/core/deserializers/implementations.py +157 -0
  37. sagemaker/core/drift_check_baselines.py +106 -0
  38. sagemaker/core/enums.py +51 -0
  39. sagemaker/core/environment_variables.py +101 -0
  40. sagemaker/core/exceptions.py +108 -0
  41. sagemaker/core/experiments/__init__.py +53 -0
  42. sagemaker/core/experiments/_api_types.py +251 -0
  43. sagemaker/core/experiments/_environment.py +124 -0
  44. sagemaker/core/experiments/_helper.py +294 -0
  45. sagemaker/core/experiments/_metrics.py +333 -0
  46. sagemaker/core/experiments/_run_context.py +58 -0
  47. sagemaker/core/experiments/_utils.py +216 -0
  48. sagemaker/core/experiments/experiment.py +244 -0
  49. sagemaker/core/experiments/run.py +970 -0
  50. sagemaker/core/experiments/trial.py +296 -0
  51. sagemaker/core/experiments/trial_component.py +387 -0
  52. sagemaker/core/explainer/__init__.py +24 -0
  53. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  54. sagemaker/core/explainer/explainer_config.py +44 -0
  55. sagemaker/core/fw_utils.py +1176 -0
  56. sagemaker/core/git_utils.py +349 -0
  57. sagemaker/core/helper/pipeline_variable.py +82 -0
  58. sagemaker/core/helper/session_helper.py +2965 -0
  59. sagemaker/core/huggingface/__init__.py +29 -0
  60. sagemaker/core/huggingface/llm_utils.py +150 -0
  61. sagemaker/core/huggingface/processing.py +139 -0
  62. sagemaker/core/huggingface/training_compiler/config.py +167 -0
  63. sagemaker/core/hyperparameters.py +172 -0
  64. sagemaker/core/image_retriever/__init__.py +3 -0
  65. sagemaker/core/image_retriever/image_retriever.py +640 -0
  66. sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
  67. sagemaker/core/image_retriever/test.py +7 -0
  68. sagemaker/core/image_uri_config/__init__.py +13 -0
  69. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  70. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  71. sagemaker/core/image_uri_config/chainer.json +104 -0
  72. sagemaker/core/image_uri_config/clarify.json +39 -0
  73. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  74. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  75. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  76. sagemaker/core/image_uri_config/debugger.json +34 -0
  77. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  78. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  79. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  80. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  81. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  82. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  83. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  84. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  85. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
  86. sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
  87. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  88. sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
  89. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  90. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  91. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  92. sagemaker/core/image_uri_config/huggingface.json +2138 -0
  93. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  94. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  95. sagemaker/core/image_uri_config/image-classification.json +50 -0
  96. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  97. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  98. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  99. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  100. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  101. sagemaker/core/image_uri_config/kmeans.json +50 -0
  102. sagemaker/core/image_uri_config/knn.json +50 -0
  103. sagemaker/core/image_uri_config/lda.json +26 -0
  104. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  105. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  106. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  107. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  108. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  109. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  110. sagemaker/core/image_uri_config/ntm.json +50 -0
  111. sagemaker/core/image_uri_config/object-detection.json +50 -0
  112. sagemaker/core/image_uri_config/object2vec.json +50 -0
  113. sagemaker/core/image_uri_config/pca.json +50 -0
  114. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  115. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  116. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  117. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  118. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  119. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  120. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  121. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  122. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  123. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  124. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
  125. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  126. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  127. sagemaker/core/image_uri_config/sklearn.json +446 -0
  128. sagemaker/core/image_uri_config/spark.json +280 -0
  129. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  130. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  131. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  132. sagemaker/core/image_uri_config/vw.json +25 -0
  133. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  134. sagemaker/core/image_uri_config/xgboost.json +888 -0
  135. sagemaker/core/image_uris.py +810 -0
  136. sagemaker/core/inference_config.py +144 -0
  137. sagemaker/core/inference_recommender/__init__.py +18 -0
  138. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  139. sagemaker/core/inputs.py +366 -0
  140. sagemaker/core/instance_group.py +61 -0
  141. sagemaker/core/instance_types.py +164 -0
  142. sagemaker/core/instance_types_gpu_info.py +43 -0
  143. sagemaker/core/interactive_apps/__init__.py +41 -0
  144. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  145. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  146. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  147. sagemaker/core/iterators.py +186 -0
  148. sagemaker/core/job.py +380 -0
  149. sagemaker/core/jumpstart/__init__.py +156 -0
  150. sagemaker/core/jumpstart/accessors.py +390 -0
  151. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  152. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  153. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  154. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  155. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  156. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  157. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  158. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  159. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  160. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  161. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  162. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  163. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  164. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  165. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  166. sagemaker/core/jumpstart/cache.py +663 -0
  167. sagemaker/core/jumpstart/configs.py +50 -0
  168. sagemaker/core/jumpstart/constants.py +198 -0
  169. sagemaker/core/jumpstart/deserializers.py +81 -0
  170. sagemaker/core/jumpstart/document.py +76 -0
  171. sagemaker/core/jumpstart/enums.py +168 -0
  172. sagemaker/core/jumpstart/exceptions.py +236 -0
  173. sagemaker/core/jumpstart/factory/utils.py +833 -0
  174. sagemaker/core/jumpstart/filters.py +597 -0
  175. sagemaker/core/jumpstart/hub/constants.py +16 -0
  176. sagemaker/core/jumpstart/hub/hub.py +291 -0
  177. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  178. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  179. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  180. sagemaker/core/jumpstart/hub/types.py +35 -0
  181. sagemaker/core/jumpstart/hub/utils.py +260 -0
  182. sagemaker/core/jumpstart/models.py +499 -0
  183. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  184. sagemaker/core/jumpstart/parameters.py +20 -0
  185. sagemaker/core/jumpstart/payload_utils.py +239 -0
  186. sagemaker/core/jumpstart/region_config.json +163 -0
  187. sagemaker/core/jumpstart/search.py +171 -0
  188. sagemaker/core/jumpstart/serializers.py +81 -0
  189. sagemaker/core/jumpstart/session_utils.py +234 -0
  190. sagemaker/core/jumpstart/types.py +3044 -0
  191. sagemaker/core/jumpstart/utils.py +1731 -0
  192. sagemaker/core/jumpstart/validators.py +257 -0
  193. sagemaker/core/lambda_helper.py +312 -0
  194. sagemaker/core/lineage/__init__.py +42 -0
  195. sagemaker/core/lineage/_api_types.py +239 -0
  196. sagemaker/core/lineage/_utils.py +49 -0
  197. sagemaker/core/lineage/action.py +345 -0
  198. sagemaker/core/lineage/artifact.py +646 -0
  199. sagemaker/core/lineage/association.py +190 -0
  200. sagemaker/core/lineage/context.py +505 -0
  201. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  202. sagemaker/core/lineage/query.py +732 -0
  203. sagemaker/core/lineage/visualizer.py +346 -0
  204. sagemaker/core/local/__init__.py +18 -0
  205. sagemaker/core/local/data.py +413 -0
  206. sagemaker/core/local/entities.py +678 -0
  207. sagemaker/core/local/exceptions.py +17 -0
  208. sagemaker/core/local/image.py +1243 -0
  209. sagemaker/core/local/local_session.py +739 -0
  210. sagemaker/core/local/utils.py +245 -0
  211. sagemaker/core/logs.py +181 -0
  212. sagemaker/core/metadata_properties.py +56 -0
  213. sagemaker/core/metric_definitions.py +91 -0
  214. sagemaker/core/mlflow/__init__.py +38 -0
  215. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  216. sagemaker/core/model_card/__init__.py +26 -0
  217. sagemaker/core/model_life_cycle.py +51 -0
  218. sagemaker/core/model_metrics.py +160 -0
  219. sagemaker/core/model_monitor/__init__.py +66 -0
  220. sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
  221. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  222. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  223. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  224. sagemaker/core/model_monitor/dataset_format.py +102 -0
  225. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  226. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  227. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  228. sagemaker/core/model_monitor/utils.py +793 -0
  229. sagemaker/core/model_registry.py +480 -0
  230. sagemaker/core/model_uris.py +97 -0
  231. sagemaker/core/modules/__init__.py +19 -0
  232. sagemaker/core/modules/configs.py +226 -0
  233. sagemaker/core/modules/constants.py +37 -0
  234. sagemaker/core/modules/distributed.py +182 -0
  235. sagemaker/core/modules/local_core/__init__.py +0 -0
  236. sagemaker/core/modules/local_core/local_container.py +605 -0
  237. sagemaker/core/modules/templates.py +83 -0
  238. sagemaker/core/modules/train/__init__.py +14 -0
  239. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  240. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  241. sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
  242. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  243. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  244. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  245. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  246. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  247. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  248. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  249. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  250. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  251. sagemaker/core/modules/types.py +19 -0
  252. sagemaker/core/modules/utils.py +194 -0
  253. sagemaker/core/network.py +185 -0
  254. sagemaker/core/parameter.py +173 -0
  255. sagemaker/core/payloads.py +185 -0
  256. sagemaker/core/processing.py +1597 -0
  257. sagemaker/core/remote_function/__init__.py +19 -0
  258. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  259. sagemaker/core/remote_function/client.py +1285 -0
  260. sagemaker/core/remote_function/core/__init__.py +0 -0
  261. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  262. sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
  263. sagemaker/core/remote_function/core/serialization.py +422 -0
  264. sagemaker/core/remote_function/core/stored_function.py +226 -0
  265. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  266. sagemaker/core/remote_function/errors.py +104 -0
  267. sagemaker/core/remote_function/invoke_function.py +172 -0
  268. sagemaker/core/remote_function/job.py +2140 -0
  269. sagemaker/core/remote_function/logging_config.py +38 -0
  270. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  271. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  272. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  273. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  274. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  275. sagemaker/core/remote_function/spark_config.py +149 -0
  276. sagemaker/core/resource_requirements.py +168 -0
  277. {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
  278. sagemaker/core/s3/__init__.py +41 -0
  279. sagemaker/core/s3/client.py +367 -0
  280. sagemaker/core/s3/utils.py +175 -0
  281. sagemaker/core/script_uris.py +93 -0
  282. sagemaker/core/serializers/__init__.py +11 -0
  283. sagemaker/core/serializers/base.py +510 -0
  284. sagemaker/core/serializers/implementations.py +159 -0
  285. sagemaker/core/serializers/utils.py +223 -0
  286. sagemaker/core/serverless_inference_config.py +63 -0
  287. sagemaker/core/session_settings.py +55 -0
  288. sagemaker/core/shapes/__init__.py +3 -0
  289. sagemaker/core/shapes/model_card_shapes.py +159 -0
  290. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
  291. sagemaker/core/spark/__init__.py +16 -0
  292. sagemaker/core/spark/defaults.py +16 -0
  293. sagemaker/core/spark/processing.py +1380 -0
  294. sagemaker/core/telemetry/__init__.py +23 -0
  295. sagemaker/core/telemetry/constants.py +84 -0
  296. sagemaker/core/telemetry/telemetry_logging.py +284 -0
  297. sagemaker/core/tools/__init__.py +1 -0
  298. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  299. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  300. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  301. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  302. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  303. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  304. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  305. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  306. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  307. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  308. sagemaker/core/training/__init__.py +14 -0
  309. sagemaker/core/training/configs.py +333 -0
  310. sagemaker/core/training/constants.py +37 -0
  311. sagemaker/core/training/utils.py +77 -0
  312. sagemaker/core/training_compiler/__init__.py +16 -0
  313. sagemaker/core/training_compiler/config.py +197 -0
  314. sagemaker/core/training_compiler_config.py +197 -0
  315. sagemaker/core/transformer.py +793 -0
  316. sagemaker/core/user_agent.py +76 -0
  317. sagemaker/core/utilities/__init__.py +24 -0
  318. sagemaker/core/utilities/cache.py +169 -0
  319. sagemaker/core/utilities/search_expression.py +133 -0
  320. sagemaker/core/utils/__init__.py +48 -0
  321. sagemaker/core/utils/code_injection/__init__.py +0 -0
  322. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  323. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
  324. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  325. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  326. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  327. {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
  328. sagemaker/core/workflow/__init__.py +152 -0
  329. sagemaker/core/workflow/conditions.py +313 -0
  330. sagemaker/core/workflow/entities.py +58 -0
  331. sagemaker/core/workflow/execution_variables.py +89 -0
  332. sagemaker/core/workflow/functions.py +193 -0
  333. sagemaker/core/workflow/parameters.py +222 -0
  334. sagemaker/core/workflow/pipeline_context.py +394 -0
  335. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  336. sagemaker/core/workflow/properties.py +285 -0
  337. sagemaker/core/workflow/step_outputs.py +65 -0
  338. sagemaker/core/workflow/utilities.py +507 -0
  339. sagemaker/lineage/__init__.py +33 -0
  340. sagemaker/lineage/action.py +28 -0
  341. sagemaker/lineage/artifact.py +28 -0
  342. sagemaker/lineage/context.py +28 -0
  343. sagemaker/lineage/lineage_trial_component.py +28 -0
  344. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
  345. sagemaker_core-2.1.1.dist-info/RECORD +355 -0
  346. sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
  347. sagemaker_core/_version.py +0 -3
  348. sagemaker_core/helper/session_helper.py +0 -769
  349. sagemaker_core/resources/__init__.py +0 -1
  350. sagemaker_core/shapes/__init__.py +0 -1
  351. sagemaker_core/tools/__init__.py +0 -1
  352. sagemaker_core-1.0.62.dist-info/RECORD +0 -35
  353. sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
  354. {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
  355. {sagemaker_core/helper → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
  356. {sagemaker_core/main → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  357. {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
  358. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  359. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  360. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  361. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
  362. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,2965 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ from __future__ import absolute_import, annotations, print_function
14
+
15
+ import json
16
+ import logging
17
+ import os
18
+ import re
19
+ import sys
20
+ import time
21
+ import uuid
22
+ import six
23
+ import warnings
24
+ from functools import reduce
25
+ from typing import Dict, Any, Optional, List
26
+
27
+ import boto3
28
+ import botocore
29
+ import botocore.config
30
+ from botocore.exceptions import ClientError
31
+ from sagemaker.core import exceptions
32
+ from sagemaker.core.common_utils import (
33
+ secondary_training_status_changed,
34
+ secondary_training_status_message,
35
+ )
36
+ import sagemaker.core.logs
37
+ from sagemaker.core.session_settings import SessionSettings
38
+ from sagemaker.core.common_utils import (
39
+ secondary_training_status_changed,
40
+ secondary_training_status_message,
41
+ sts_regional_endpoint,
42
+ retries,
43
+ resolve_value_from_config,
44
+ get_sagemaker_config_value,
45
+ resolve_class_attribute_from_config,
46
+ resolve_nested_dict_value_from_config,
47
+ update_nested_dictionary_with_values_from_config,
48
+ update_list_of_dicts_with_values_from_config,
49
+ format_tags,
50
+ Tags,
51
+ TagsDict,
52
+ instance_supports_kms,
53
+ create_paginator_config,
54
+ )
55
+
56
+ from sagemaker.core.config.config_utils import _log_sagemaker_config_merge
57
+ from sagemaker.core._studio import _append_project_tags
58
+ from sagemaker.core.config.config import load_sagemaker_config, validate_sagemaker_config
59
+ from sagemaker.core.config.config_schema import (
60
+ KEY,
61
+ TRANSFORM_JOB,
62
+ TRANSFORM_JOB_ENVIRONMENT_PATH,
63
+ TRANSFORM_JOB_KMS_KEY_ID_PATH,
64
+ TRANSFORM_OUTPUT_KMS_KEY_ID_PATH,
65
+ VOLUME_KMS_KEY_ID,
66
+ TRANSFORM_JOB_VOLUME_KMS_KEY_ID_PATH,
67
+ MODEL,
68
+ MODEL_CONTAINERS_PATH,
69
+ MODEL_EXECUTION_ROLE_ARN_PATH,
70
+ MODEL_ENABLE_NETWORK_ISOLATION_PATH,
71
+ MODEL_PRIMARY_CONTAINER_PATH,
72
+ MODEL_VPC_CONFIG_PATH,
73
+ ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH,
74
+ KMS_KEY_ID,
75
+ ENDPOINT_CONFIG_KMS_KEY_ID_PATH,
76
+ ENDPOINT_CONFIG,
77
+ ENDPOINT_CONFIG_DATA_CAPTURE_PATH,
78
+ ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH,
79
+ ENDPOINT_CONFIG_VPC_CONFIG_PATH,
80
+ ENDPOINT_CONFIG_ENABLE_NETWORK_ISOLATION_PATH,
81
+ ENDPOINT_CONFIG_EXECUTION_ROLE_ARN_PATH,
82
+ ENDPOINT,
83
+ INFERENCE_COMPONENT,
84
+ SAGEMAKER,
85
+ TAGS,
86
+ SESSION_DEFAULT_S3_BUCKET_PATH,
87
+ SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH,
88
+ )
89
+
90
# Module logger. The upper-case alias ``LOGGER`` is kept for backward
# compatibility with code that imports it by that name; both names refer to
# the very same ``logging.Logger`` object.
logger = logging.getLogger("sagemaker")
LOGGER = logger

# Path of the resource-metadata JSON file. The name suggests it describes the
# hosting notebook instance; its readers live outside this view (confirm there).
NOTEBOOK_METADATA_FILE = "/opt/ml/metadata/resource-metadata.json"

# Sentinel schedule expression; presumably requests a one-time ("run now")
# model-monitoring execution — confirm at call sites.
MODEL_MONITOR_ONE_TIME_SCHEDULE = "NOW"

# Normalizes upper-case status tokens to their CamelCase spelling. Both
# "INPROGRESS" and "IN_PROGRESS" map to "InProgress".
_STATUS_CODE_TABLE = dict(
    COMPLETED="Completed",
    INPROGRESS="InProgress",
    IN_PROGRESS="InProgress",
    FAILED="Failed",
    STOPPED="Stopped",
    STOPPING="Stopping",
    STARTING="Starting",
    PENDING="Pending",
)

# Endpoint polling intervals. Units are not visible here (likely seconds —
# confirm where they are consumed).
EP_LOGGER_POLL = DEFAULT_EP_POLL = 30
107
+
108
+
109
class LogState:
    """Integer codes naming the phases of a log-streaming state machine.

    The constants are plain ``int`` values (1 through 5) in the order listed
    below. Their names suggest the progression of a loop that tails job logs:
    start-up, waiting while the job is in progress, tailing, job finished,
    and fully complete. The code that consumes these codes is outside this
    view, so confirm the exact transitions there.
    """

    STARTING = 1  # initial phase
    WAIT_IN_PROGRESS = 2  # presumably: job running, no log output yet
    TAILING = 3  # presumably: actively streaming log output
    JOB_COMPLETE = 4  # presumably: job done, draining remaining logs
    COMPLETE = 5  # terminal phase
+
118
+
119
+ class Session(object): # pylint: disable=too-many-public-methods
120
+ """Manage interactions with the Amazon SageMaker APIs and any other AWS services needed.
121
+
122
+ This class provides convenient methods for manipulating entities and resources that Amazon
123
+ SageMaker uses, such as training jobs, endpoints, and input datasets in S3.
124
+ AWS service calls are delegated to an underlying Boto3 session, which by default
125
+ is initialized using the AWS configuration chain. When you make an Amazon SageMaker API call
126
+ that accesses an S3 bucket location and one is not specified, the ``Session`` creates a default
127
+ bucket based on a naming convention which includes the current AWS account ID.
128
+ """
129
+
130
def __init__(
    self,
    boto_session=None,
    sagemaker_client=None,
    sagemaker_runtime_client=None,
    sagemaker_featurestore_runtime_client=None,
    default_bucket=None,
    sagemaker_config: Optional[dict] = None,
    settings=None,
    sagemaker_metrics_client=None,
    default_bucket_prefix: Optional[str] = None,
):
    """Initialize a SageMaker ``Session``.

    Args:
        boto_session (boto3.session.Session): The underlying Boto3 session which AWS service
            calls are delegated to (default: None). If not provided, boto3's default session
            is reused when one exists; otherwise a new one is created with the default AWS
            configuration chain.
        sagemaker_client (boto3.SageMaker.Client): Client which makes Amazon SageMaker service
            calls other than ``InvokeEndpoint`` (default: None). Estimators created using this
            ``Session`` use this client. If not provided, one will be created using this
            instance's ``boto_session``.
        sagemaker_runtime_client (boto3.SageMakerRuntime.Client): Client which makes
            ``InvokeEndpoint`` calls to Amazon SageMaker (default: None). Predictors created
            using this ``Session`` use this client. If not provided, one will be created using
            this instance's ``boto_session`` with an 80-second read timeout.
        sagemaker_featurestore_runtime_client (boto3.SageMakerFeatureStoreRuntime.Client):
            Client which makes SageMaker FeatureStore record related calls to Amazon SageMaker
            (default: None). If not provided, one will be created using
            this instance's ``boto_session``.
        default_bucket (str): The default Amazon S3 bucket to be used by this session.
            This will be created the next time an Amazon S3 bucket is needed (by calling
            :func:`default_bucket`).
            If not provided, it will be fetched from the sagemaker_config. If not configured
            there either, a default bucket will be created based on the following format:
            "sagemaker-{region}-{aws-account-id}".
            Example of a custom value: "sagemaker-my-custom-bucket".
        sagemaker_config (dict): A SageMaker Python SDK configuration dictionary
            (default: None). If provided, it is validated with ``validate_sagemaker_config``
            and used as-is. If not provided, it is loaded with ``load_sagemaker_config``.
            The resulting config supplies ``default_bucket`` and ``default_bucket_prefix``
            when those are not passed directly.
        settings (sagemaker.core.session_settings.SessionSettings): SDK-wide session
            settings (default: None). If not provided (or falsy), a default
            ``SessionSettings()`` instance is created.
        sagemaker_metrics_client (boto3.SageMakerMetrics.Client):
            Client which makes SageMaker Metrics related calls to Amazon SageMaker
            (default: None). If not provided, one will be created using
            this instance's ``boto_session``.
        default_bucket_prefix (str): The default prefix to use for S3 Object Keys. (default:
            None). If provided and where applicable, it will be used by the SDK to construct
            default S3 URIs, in the format:
            `s3://{default_bucket}/{default_bucket_prefix}/<rest of object key>`
            This parameter can also be specified via `{sagemaker_config}` instead of here. If
            not provided here or within `{sagemaker_config}`, default S3 URIs will have the
            format: `s3://{default_bucket}/<rest of object key>`

    Raises:
        ValueError: If the resolved Boto3 session has no AWS region configured.
    """

    # sagemaker_config is validated and initialized inside :func:`_initialize`,
    # so if default_bucket is None and the sagemaker_config has a default S3 bucket configured,
    # _default_bucket_name_override will be set again inside :func:`_initialize`.
    self.endpoint_arn = None
    self._default_bucket = None
    self._default_bucket_name_override = default_bucket
    # this may also be set again inside :func:`_initialize` if it is None
    self.default_bucket_prefix = default_bucket_prefix
    self._default_bucket_set_by_sdk = False

    # Service clients/resources start unset; _initialize creates the SageMaker and S3 ones.
    self.s3_resource = None
    self.s3_client = None
    self.resource_groups_client = None
    self.resource_group_tagging_client = None
    self._config = None
    self.lambda_client = None
    # Any falsy value (not only None) falls back to a fresh SessionSettings instance.
    self.settings = settings if settings else SessionSettings()

    self._initialize(
        boto_session=boto_session,
        sagemaker_client=sagemaker_client,
        sagemaker_runtime_client=sagemaker_runtime_client,
        sagemaker_featurestore_runtime_client=sagemaker_featurestore_runtime_client,
        sagemaker_metrics_client=sagemaker_metrics_client,
        sagemaker_config=sagemaker_config,
    )
206
+
207
def _initialize(
    self,
    boto_session,
    sagemaker_client,
    sagemaker_runtime_client,
    sagemaker_featurestore_runtime_client,
    sagemaker_metrics_client,
    sagemaker_config: dict = None,
):
    """Wire up the boto session, service clients and SageMaker config for this Session.

    Any client passed in is used as-is; missing ones are created from the boto session.

    Args:
        boto_session: Boto3 session to use; falls back to ``boto3.DEFAULT_SESSION`` and
            then to a fresh ``boto3.Session()``.
        sagemaker_client: SageMaker client to use, or None to create one whose
            User-Agent carries the SDK suffix.
        sagemaker_runtime_client: SageMaker runtime client to use, or None to create one
            with an 80 second read timeout.
        sagemaker_featurestore_runtime_client: Feature Store runtime client to use, or a
            falsy value to create one.
        sagemaker_metrics_client: SageMaker Metrics client to use, or a falsy value to
            create one.
        sagemaker_config (dict): Config to validate and use. When falsy, the config is
            loaded with ``load_sagemaker_config``.

    Raises:
        ValueError: If the boto session has no region configured.
    """
    session = boto_session or boto3.DEFAULT_SESSION or boto3.Session()
    self.boto_session = session

    self._region_name = session.region_name
    if self._region_name is None:
        raise ValueError(
            "Must setup local AWS configuration with a region supported by SageMaker."
        )

    # botocore's ``user_agent_extra`` appends the SDK-specific suffix to boto3's own
    # User-Agent header without depending on that header's format.
    # Ref: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
    if sagemaker_client is None:
        from sagemaker.core.user_agent import get_user_agent_extra_suffix

        sagemaker_client = session.client(
            "sagemaker",
            config=botocore.config.Config(user_agent_extra=get_user_agent_extra_suffix()),
        )
    self.sagemaker_client = sagemaker_client

    if sagemaker_runtime_client is None:
        sagemaker_runtime_client = session.client(
            "runtime.sagemaker", config=botocore.config.Config(read_timeout=80)
        )
    self.sagemaker_runtime_client = sagemaker_runtime_client

    self.sagemaker_featurestore_runtime_client = (
        sagemaker_featurestore_runtime_client
        or session.client("sagemaker-featurestore-runtime")
    )
    self.sagemaker_metrics_client = sagemaker_metrics_client or session.client(
        "sagemaker-metrics"
    )

    self.s3_client = session.client("s3", region_name=self.boto_region_name)
    self.s3_resource = session.resource("s3", region_name=self.boto_region_name)

    self.local_mode = False

    if sagemaker_config:
        validate_sagemaker_config(sagemaker_config)
    else:
        # Hand over the S3 resource created above so it can be reused if any config
        # source has to be fetched from S3.
        sagemaker_config = load_sagemaker_config(s3_resource=self.s3_resource)
    self.sagemaker_config = sagemaker_config

    # With the config available, let it supply the default bucket name and key prefix
    # when they were not given explicitly.
    self._default_bucket_name_override = resolve_value_from_config(
        direct_input=self._default_bucket_name_override,
        config_path=SESSION_DEFAULT_S3_BUCKET_PATH,
        sagemaker_session=self,
    )
    self.default_bucket_prefix = resolve_value_from_config(
        direct_input=self.default_bucket_prefix,
        config_path=SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH,
        sagemaker_session=self,
    )
291
+
292
def account_id(self) -> str:
    """Look up the AWS account ID of the credentials behind this session.

    STS ``GetCallerIdentity`` is called through the regional STS endpoint of the
    session's region.

    Returns:
        str: The AWS account ID.
    """
    region_name = self.boto_session.region_name
    identity = self.boto_session.client(
        "sts",
        region_name=region_name,
        endpoint_url=sts_regional_endpoint(region_name),
    ).get_caller_identity()
    return identity["Account"]
303
+
304
@property
def config(self) -> Optional[Dict]:
    """Local-mode configuration dictionary; unused by a normal (non-local) session.

    ``Optional[Dict]`` is used instead of ``Dict | None`` so the annotation can be
    evaluated at class-definition time on Python versions before 3.10.

    Returns:
        Optional[Dict]: The stored configuration, or None when none was set.
    """
    return self._config
308
+
309
@config.setter
def config(self, value: Optional[Dict]):
    """Store the local-mode configuration dictionary (unused by a normal session).

    ``Optional[Dict]`` is used instead of ``Dict | None`` so the annotation can be
    evaluated at class-definition time on Python versions before 3.10.

    Args:
        value (Optional[Dict]): Configuration to store, or None to clear it.
    """
    self._config = value
313
+
314
@property
def boto_region_name(self):
    """str: AWS region name recorded from the boto session during initialization."""
    region = self._region_name
    return region
318
+
319
def get_caller_identity_arn(self):
    """Resolve the ARN of the user or role whose credentials this session uses.

    Lookup order:

    1. If the notebook metadata file exists:

       * without a ``DomainId`` (notebook instance): the instance's ``RoleArn``;
       * with a ``DomainId`` (Studio): ``ExecutionRoleArn`` from the metadata, else the
         user profile's ``ExecutionRole``, else the domain's default ``ExecutionRole``.

       A ``ClientError`` from these SageMaker calls is logged at debug level and the
       lookup falls through to step 2.
    2. The STS caller identity. An ``assumed-role`` ARN is rewritten to the matching IAM
       role ARN, and IAM ``GetRole`` is called to recover the role's path. If that call
       fails and the name contains ``AmazonSageMaker-ExecutionRole``, a
       ``service-role/`` path is assumed.

    Returns:
        str: The ARN user or role
    """
    if os.path.exists(NOTEBOOK_METADATA_FILE):
        with open(NOTEBOOK_METADATA_FILE, "rb") as metadata_file:
            metadata = json.loads(metadata_file.read())
        instance_name = metadata.get("ResourceName")
        domain_id = metadata.get("DomainId")
        user_profile_name = metadata.get("UserProfileName")
        execution_role_arn = metadata.get("ExecutionRoleArn")
        try:
            if domain_id is None:
                # Notebook instance: the role is attached to the instance itself.
                return self.sagemaker_client.describe_notebook_instance(
                    NotebookInstanceName=instance_name
                )["RoleArn"]

            if execution_role_arn is not None:
                return execution_role_arn

            profile = self.sagemaker_client.describe_user_profile(
                DomainId=domain_id, UserProfileName=user_profile_name
            )
            profile_role = profile.get("UserSettings", {}).get("ExecutionRole")
            if profile_role:
                return profile_role

            # No role on the user profile: use the domain-wide default.
            domain = self.sagemaker_client.describe_domain(DomainId=domain_id)
            return domain["DefaultUserSettings"]["ExecutionRole"]
        except ClientError:
            logger.debug(
                "Couldn't call 'describe_notebook_instance' to get the Role "
                "ARN of the instance %s.",
                instance_name,
            )

    caller_arn = self.boto_session.client(
        "sts",
        region_name=self.boto_region_name,
        endpoint_url=sts_regional_endpoint(self.boto_region_name),
    ).get_caller_identity()["Arn"]

    # ...:sts::<account>:assumed-role/<name>/<session>  ->  ...:iam::<account>:role/<name>
    role = re.sub(r"^(.+)sts::(\d+):assumed-role/(.+?)/.*$", r"\1iam::\2:role/\3", caller_arn)

    # The STS ARN carries no IAM path, so ask IAM for the role's full ARN.
    role_name = role.rsplit("/", 1)[-1]
    try:
        return self.boto_session.client("iam").get_role(RoleName=role_name)["Role"]["Arn"]
    except ClientError:
        logger.warning(
            "Couldn't call 'get_role' to get Role ARN from role name %s to get Role path.",
            role_name,
        )

        # This conditional has been present since the inception of SageMaker
        # Guessing this conditional's purpose was to handle lack of IAM permissions
        # https://github.com/aws/sagemaker-python-sdk/issues/2089#issuecomment-791802713
        if "AmazonSageMaker-ExecutionRole" in caller_arn:
            logger.warning(
                "Assuming role was created in SageMaker AWS console, "
                "as the name contains `AmazonSageMaker-ExecutionRole`. "
                "Defaulting to Role ARN with service-role in path. "
                "If this Role ARN is incorrect, please add "
                "IAM read permissions to your role or supply the "
                "Role Arn directly."
            )
            role = re.sub(
                r"^(.+)sts::(\d+):assumed-role/(.+?)/.*$",
                r"\1iam::\2:role/service-role/\3",
                caller_arn,
            )

    return role
398
+
399
def upload_data(self, path, bucket=None, key_prefix="data", callback=None, extra_args=None):
    """Upload local file or directory to S3.

    If a single file is specified for upload, the resulting S3 object key is
    ``{key_prefix}/{filename}`` (filename does not include the local path, if any specified).
    If a directory is specified for upload, the API uploads all content, recursively,
    preserving relative structure of subdirectories. The resulting object key names are:
    ``{key_prefix}/{relative_subdirectory_path}/filename``. Subdirectory separators are
    always written as "/", regardless of the local OS path separator.

    Args:
        path (str): Path (absolute or relative) of local file or directory to upload.
        bucket (str): Name of the S3 Bucket to upload to (default: None). If not specified, the
            default bucket of the ``Session`` is used (if default bucket does not exist, the
            ``Session`` creates it).
        key_prefix (str): Optional S3 object key name prefix (default: 'data'). S3 uses the
            prefix to create a directory structure for the bucket content that it display in
            the S3 console.
        callback (callable): Optional progress callback, passed as ``Callback`` to each
            boto3 ``upload_file`` call (default: None).
        extra_args (dict): Optional extra arguments that may be passed to the upload operation.
            Similar to ExtraArgs parameter in S3 upload_file function. Please refer to the
            ExtraArgs parameter documentation here:
            https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-uploading-files.html#the-extraargs-parameter

    Returns:
        str: The S3 URI of the uploaded file(s). If a file is specified in the path argument,
            the URI format is: ``s3://{bucket name}/{key_prefix}/{original_file_name}``.
            If a directory is specified in the path argument, the URI format is
            ``s3://{bucket name}/{key_prefix}``.
    """
    bucket, key_prefix = self.determine_bucket_and_prefix(
        bucket=bucket, key_prefix=key_prefix, sagemaker_session=self
    )

    # (local_path, s3_key) pairs to upload.
    files = []
    key_suffix = None
    if os.path.isdir(path):
        for dirpath, _, filenames in os.walk(path):
            if path == dirpath:
                s3_relative_prefix = ""
            else:
                # S3 keys always use "/"; relpath uses the OS separator ("\\" on Windows).
                s3_relative_prefix = (
                    os.path.relpath(dirpath, start=path).replace(os.sep, "/") + "/"
                )
            for name in filenames:
                s3_key = "{}/{}{}".format(key_prefix, s3_relative_prefix, name)
                files.append((os.path.join(dirpath, name), s3_key))
    else:
        _, name = os.path.split(path)
        files.append((path, "{}/{}".format(key_prefix, name)))
        key_suffix = name

    s3 = self.s3_resource
    if s3 is None:
        s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)

    for local_path, s3_key in files:
        s3.Object(bucket, s3_key).upload_file(
            local_path, Callback=callback, ExtraArgs=extra_args
        )

    s3_uri = "s3://{}/{}".format(bucket, key_prefix)
    # If a specific file was used as input (instead of a directory), we return the full S3 key
    # of the uploaded object. This prevents unintentionally using other files under the same
    # prefix during training.
    if key_suffix:
        s3_uri = "{}/{}".format(s3_uri, key_suffix)
    return s3_uri
466
+
467
def upload_string_as_file_body(self, body, bucket, key, kms_key=None):
    """Upload a string as the body of an S3 object.

    Args:
        body (str): String representing the body of the file.
        bucket (str): Name of the S3 bucket to upload to. Required: unlike
            ``upload_data``, this method does not fall back to the session's default bucket.
        key (str): S3 object key. This is the s3 path to the file.
        kms_key (str): Optional KMS key ID. When given, the object is written with
            ``ServerSideEncryption="aws:kms"`` using this key (default: None).

    Returns:
        str: The S3 URI of the uploaded file, in the form ``s3://{bucket name}/{key}``.
    """
    s3 = self.s3_resource
    if s3 is None:
        s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)

    put_args = {"Body": body}
    if kms_key is not None:
        # Passing SSEKMSKeyId=None to boto3 would fail validation, so only add it when set.
        put_args.update(SSEKMSKeyId=kms_key, ServerSideEncryption="aws:kms")
    s3.Object(bucket_name=bucket, key=key).put(**put_args)

    return "s3://{}/{}".format(bucket, key)
494
+
495
def download_data(self, path, bucket, key_prefix="", extra_args=None):
    """Download file or directory from S3.

    Every object listed under ``key_prefix`` is downloaded below ``path``. Zero-byte keys
    ending in "/" are treated as directory markers and only create local directories.

    If ``key_prefix`` has a file extension, each object is saved under its base name.
    Otherwise its path relative to ``key_prefix`` is kept. Two kinds of key fall back to
    their base name, so that nothing is written outside ``path``:

    * keys that match the prefix only as a string (e.g. ``model.tar.gz`` for prefix
      ``model``);
    * keys whose relative path would climb out of ``path``.

    Args:
        path (str): Local path where the file or directory should be downloaded to.
        bucket (str): Name of the S3 Bucket to download from.
        key_prefix (str): Optional S3 object key name prefix.
        extra_args (dict): Optional extra arguments that may be passed to the
            download operation. Please refer to the ExtraArgs parameter in the boto3
            documentation here:
            https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-example-download-file.html

    Returns:
        list[str]: Local paths of the downloaded files, in listing order. Empty when
        nothing matches ``key_prefix``.
    """
    s3 = self.s3_client
    if s3 is None:
        s3 = self.boto_session.client("s3", region_name=self.boto_region_name)

    # Page through the listing (up to 1,000 objects per call), separating file keys from
    # directory markers.
    keys = []
    directories = []
    request_parameters = {"Bucket": bucket, "Prefix": key_prefix}
    while True:
        response = s3.list_objects_v2(**request_parameters)
        contents = response.get("Contents", None)
        if not contents:
            # An empty page ends the listing; keep whatever earlier pages returned
            # instead of discarding it.
            break
        for s3_object in contents:
            key: str = s3_object.get("Key")
            obj_size = s3_object.get("Size")
            if key.endswith("/") and int(obj_size) == 0:
                directories.append(os.path.join(path, key))
            else:
                keys.append(key)
        next_token = response.get("NextContinuationToken")
        if not next_token:
            break
        request_parameters["ContinuationToken"] = next_token

    if not keys and not directories:
        logger.info(
            "Nothing to download from bucket: %s, key_prefix: %s.", bucket, key_prefix
        )
        return []

    for dir_path in directories:
        os.makedirs(os.path.dirname(dir_path), exist_ok=True)

    prefix_has_extension = bool(os.path.splitext(key_prefix)[1])
    downloaded_paths = []
    for key in keys:
        tail_s3_uri_path = os.path.basename(key)
        if not prefix_has_extension:
            relative = os.path.relpath(key, key_prefix)
            # relpath returns "." for key == key_prefix and ".."-prefixed paths for keys
            # outside the prefix "directory"; keep the base name in those cases.
            if relative not in (os.curdir, os.pardir) and not relative.startswith(
                os.pardir + os.sep
            ):
                tail_s3_uri_path = relative
        destination_path = os.path.join(path, tail_s3_uri_path)
        destination_dir = os.path.dirname(destination_path)
        if destination_dir:
            os.makedirs(destination_dir, exist_ok=True)
        s3.download_file(
            Bucket=bucket, Key=key, Filename=destination_path, ExtraArgs=extra_args
        )
        downloaded_paths.append(destination_path)
    return downloaded_paths
560
+
561
def read_s3_file(self, bucket, key_prefix):
    """Fetch a single S3 object and return its contents as text.

    Args:
        bucket (str): Name of the S3 bucket to read from.
        key_prefix (str): Key of the object to read; it is used verbatim as the key.

    Returns:
        str: The object body decoded as UTF-8.
    """
    client = self.s3_client
    if client is None:
        client = self.boto_session.client("s3", region_name=self.boto_region_name)

    body = client.get_object(Bucket=bucket, Key=key_prefix)["Body"]
    return body.read().decode("utf-8")
580
+
581
def list_s3_files(self, bucket, key_prefix):
    """Return the keys of the objects in ``bucket`` listed under ``key_prefix``.

    Args:
        bucket (str): Name of the S3 bucket to list.
        key_prefix (str): Key prefix used to filter the listing.

    Returns:
        list[str]: Keys of the matching objects, in listing order.
    """
    resource = self.s3_resource
    if resource is None:
        resource = self.boto_session.resource("s3", region_name=self.boto_region_name)

    matching_objects = resource.Bucket(name=bucket).objects.filter(Prefix=key_prefix).all()
    keys = []
    for obj in matching_objects:
        keys.append(obj.key)
    return keys
597
+
598
def default_bucket(self):
    """Return the name of the default bucket to use in relevant Amazon SageMaker interactions.

    On the first call, the bucket name is resolved and the bucket is created if it does
    not exist; the result is cached for later calls.

    Returns:
        str: The name of the default bucket. If the name was not explicitly specified through
        the Session or sagemaker_config, the bucket will take the form:
        ``sagemaker-{region}-{AWS account ID}``.
    """
    if not self._default_bucket:
        bucket_name = self._default_bucket_name_override
        if not bucket_name:
            bucket_name = self.generate_default_sagemaker_bucket_name(self.boto_session)
            # Remember the SDK picked this name; the bucket-creation path then also
            # verifies who owns an already-existing bucket.
            self._default_bucket_set_by_sdk = True

        self._create_s3_bucket_if_it_does_not_exist(
            bucket_name=bucket_name,
            region=self.boto_session.region_name,
        )
        self._default_bucket = bucket_name

    return self._default_bucket
627
+
628
def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
    """Creates an S3 Bucket if it does not exist.

    Two cases are handled:

    * No creation date is known for the bucket: the bucket is probed with the
      "bucket missing" flag set, so a 404 leads to its creation.
    * The bucket has a creation date and its name was chosen by the SDK: the bucket is
      probed again, and its owner must then match this session's account.

    Args:
        bucket_name (str): Name of the S3 bucket to be created.
        region (str): The region in which to create the bucket.

    Raises:
        botocore.exceptions.ClientError: If S3 throws an unexpected exception during bucket
            creation.
            If the exception is due to the bucket already existing or
            already being created, no exception is raised.
    """
    s3 = self.s3_resource
    if s3 is None:
        s3 = self.boto_session.resource("s3", region_name=region)

    bucket = s3.Bucket(name=bucket_name)
    if bucket.creation_date is None:
        self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, True)
        return

    if not self._default_bucket_set_by_sdk:
        return

    self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False)
    owner_account = self.account_id()
    self.expected_bucket_owner_id_bucket_check(bucket_name, s3, owner_account)
658
+
659
def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket_owner_id):
    """Verify that ``bucket_name`` is owned by ``expected_bucket_owner_id``.

    Issues ``head_bucket`` with ``ExpectedBucketOwner``. A 403 "Forbidden" response is
    logged with guidance to pass a custom ``default_bucket`` and is then re-raised.

    Args:
        bucket_name (str): Name of the S3 bucket
        s3 (str): S3 object from boto session
        expected_bucket_owner_id (str): Owner ID string
    """
    try:
        s3.meta.client.head_bucket(
            Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
        )
    except ClientError as error:
        error_details = error.response["Error"]
        if (error_details["Code"], error_details["Message"]) == ("403", "Forbidden"):
            LOGGER.error(
                "Since default_bucket param was not set, SageMaker Python SDK tried to use "
                "%s bucket. "
                "This bucket cannot be configured to use as it is not owned by Account %s. "
                "To unblock it's recommended to use custom default_bucket "
                "parameter in sagemaker.Session",
                bucket_name,
                expected_bucket_owner_id,
            )
            raise
686
+
687
def general_bucket_check_if_user_has_permission(
    self, bucket_name, s3, bucket, region, bucket_creation_date_none
):
    """Probe ``bucket_name`` with ``head_bucket`` and react to access errors.

    Errors are only acted on when ``bucket_creation_date_none`` is True:

    * a 404 "Not Found" leads to the bucket being created in ``region``;
    * a 403 "Forbidden" is logged and re-raised;
    * any other ``ClientError`` is re-raised.

    When the flag is False, a ``ClientError`` from ``head_bucket`` is ignored.

    Args:
        bucket_name (str): Name of the S3 bucket
        s3 (str): S3 object from boto session
        bucket: S3 Bucket object; its ``name`` is used in the 403 log message.
        region (str): The region in which to create the bucket.
        bucket_creation_date_none (bool): Whether the bucket had no creation date, i.e.
            whether it may not exist yet.
    """
    try:
        s3.meta.client.head_bucket(Bucket=bucket_name)
    except ClientError as error:
        error_code = error.response["Error"]["Code"]
        message = error.response["Error"]["Message"]
        if not bucket_creation_date_none:
            return
        if error_code == "404" and message == "Not Found":
            self.create_bucket_for_not_exist_error(bucket_name, region, s3)
            return
        if error_code == "403" and message == "Forbidden":
            LOGGER.error(
                "Bucket %s exists, but access is forbidden. Please try again after "
                "adding appropriate access.",
                bucket.name,
            )
        raise
719
+
720
def create_bucket_for_not_exist_error(self, bucket_name, region, s3):
    """Create ``bucket_name`` in ``region``, tolerating a concurrent creation.

    Args:
        bucket_name (str): Name of the S3 bucket
        s3 (str): S3 object from boto session
        region (str): The region in which to create the bucket.

    Raises:
        botocore.exceptions.ClientError: For any creation error other than an
            ``OperationAborted`` "conflicting conditional operation".
    """
    create_kwargs = {"Bucket": bucket_name}
    if region != "us-east-1":
        # 'us-east-1' cannot be specified because it is the default region:
        # https://github.com/boto/boto3/issues/125
        create_kwargs["CreateBucketConfiguration"] = {"LocationConstraint": region}

    try:
        s3.create_bucket(**create_kwargs)
        logger.info("Created S3 bucket: %s", bucket_name)
    except ClientError as error:
        error_code = error.response["Error"]["Code"]
        message = error.response["Error"]["Message"]
        # A "conflicting conditional operation" abort means someone else is already
        # creating this bucket, so there is nothing left to do.
        concurrent_creation = (
            error_code == "OperationAborted" and "conflicting conditional operation" in message
        )
        if not concurrent_creation:
            raise
751
+
752
def generate_default_sagemaker_bucket_name(self, boto_session):
    """Generates a name for the default sagemaker S3 bucket.

    The name has the form ``sagemaker-{region}-{AWS account ID}``. The account ID is
    obtained from STS ``GetCallerIdentity`` through the regional STS endpoint.

    Args:
        boto_session (boto3.session.Session): The underlying Boto3 session whose region
            and credentials determine the bucket name.

    Returns:
        str: The generated bucket name.
    """
    region = boto_session.region_name
    sts_client = boto_session.client(
        "sts", region_name=region, endpoint_url=sts_regional_endpoint(region)
    )
    account = sts_client.get_caller_identity()["Account"]
    return "sagemaker-{}-{}".format(region, account)
763
+
764
def determine_bucket_and_prefix(
    self, bucket: Optional[str] = None, key_prefix: Optional[str] = None, sagemaker_session=None
):
    """Helper function that returns the correct S3 bucket and prefix to use depending on the inputs.

    Args:
        bucket (Optional[str]): S3 Bucket to use (if it exists)
        key_prefix (Optional[str]): S3 Object Key Prefix to use or append to (if it exists)
        sagemaker_session (sagemaker.core.session.Session): Session to fetch a default bucket
            and prefix from when ``bucket`` is not given. Defaults to this session.

    Returns: The correct S3 Bucket and S3 Object Key Prefix that should be used
    """
    if bucket:
        # An explicit bucket is used as-is, together with key_prefix, even if it equals the
        # default bucket, because either:
        # (1) the user passed it explicitly and it just happens to match default_bucket,
        #     and their input must not be changed, or
        # (2) it was fetched from the Session earlier (and the default prefix was applied
        #     then), so prepending the default prefix here would apply it twice.
        return bucket, key_prefix

    session = sagemaker_session if sagemaker_session is not None else self
    final_bucket = session.default_bucket()
    # default_bucket_prefix is prepended only when falling back to the session's default
    # bucket.
    final_key_prefix = s3_path_join(session.default_bucket_prefix, key_prefix)
    return final_bucket, final_key_prefix
797
+
798
def _append_sagemaker_config_tags(self, tags: List[TagsDict], config_path_to_tags: str):
    """Appends tags specified in the sagemaker_config to the given list of tags.

    To minimize the chance of duplicate tags being applied, this is intended to be used
    immediately before calls to sagemaker_client, rather than during initialization of
    classes like EstimatorBase.

    A config tag is skipped when a tag with the same ``"Key"`` is already present, so
    user-provided tags take precedence. The caller's ``tags`` list is never modified.

    Args:
        tags: The list of tags to append to (may be None).
        config_path_to_tags: The path to look up tags in the config.

    Returns:
        ``tags`` itself when the config defines no tags. Otherwise, a new list holding
        ``tags`` followed by the non-conflicting config tags.
    """
    config_tags = get_sagemaker_config_value(self, config_path_to_tags)

    if config_tags is None or len(config_tags) == 0:
        return tags

    # Copy instead of appending into the caller's list: callers may reuse that list, and
    # the merge log below must see the unmodified user-provided value.
    all_tags = list(tags) if tags else []
    for config_tag in config_tags:
        config_tag_key = config_tag[KEY]
        if not any(tag.get("Key", None) == config_tag_key for tag in all_tags):
            # This check prevents new tags with duplicate keys from being added
            # (to prevent API failure and/or overwriting of tags). If there is a conflict,
            # the user-provided tag should take precedence over the config-provided tag.
            # Note: this does not check user-provided tags for conflicts with other
            # user-provided tags.
            all_tags.append(config_tag)

    _log_sagemaker_config_merge(
        source_value=tags,
        config_value=config_tags,
        merged_source_and_config_value=all_tags,
        config_key_path=config_path_to_tags,
    )

    return all_tags
836
+
837
def endpoint_from_production_variants(
    self,
    name,
    production_variants,
    tags=None,
    kms_key=None,
    wait=True,
    data_capture_config_dict=None,
    async_inference_config_dict=None,
    explainer_config_dict=None,
    live_logging=False,
    vpc_config=None,
    enable_network_isolation=None,
    role=None,
):
    """Create an SageMaker ``Endpoint`` from a list of production variants.

    First an endpoint config named ``name`` is created, with unset values filled in from
    sagemaker_config. Then an endpoint with the same name is created from it.

    Args:
        name (str): The name of the ``Endpoint`` to create.
        production_variants (list[dict[str, str]]): The list of production variants to deploy.
        tags (Optional[Tags]): A list of key-value pairs for tagging the endpoint
            (default: None).
        kms_key (str): The KMS key that is used to encrypt the data on the storage volume
            attached to the instance hosting the endpoint.
        wait (bool): Whether to wait for the endpoint deployment to complete before returning
            (default: True).
        data_capture_config_dict (dict): Specifies configuration related to Endpoint data
            capture for use with Amazon SageMaker Model Monitoring. Default: None.
        async_inference_config_dict (dict) : specifies configuration related to async endpoint.
            Use this configuration when trying to create async endpoint and make async inference
            (default: None)
        explainer_config_dict (dict) : Specifies configuration related to explainer.
            Use this configuration when trying to use online explainability.
            (default: None).
        live_logging (bool): Forwarded to ``create_endpoint``, which passes it to
            ``wait_for_endpoint`` when ``wait`` is True (default: False).
        vpc_config (dict[str, list[str]]:
            The VpcConfig set on the model (default: None).
            * 'Subnets' (list[str]): List of subnet ids.
            * 'SecurityGroupIds' (list[str]): List of security group ids.
        enable_network_isolation (Boolean): Default False.
            If True, enables network isolation in the endpoint, isolating the model
            container. No inbound or outbound network calls can be made to
            or from the model container.
        role (str): An AWS IAM role (either name or full ARN). The Amazon
            SageMaker training jobs and APIs that create Amazon SageMaker
            endpoints use this role to access training data and model
            artifacts. After the endpoint is created, the inference code
            might use the IAM role if it needs to access some AWS resources.
            (default: None).
    Returns:
        str: The name of the created ``Endpoint``.
    """
    # A KMS key from the config is only considered when at least one variant's
    # instance type supports KMS.
    supports_kms = any(
        instance_supports_kms(production_variant["InstanceType"])
        for production_variant in production_variants
        if "InstanceType" in production_variant
    )

    update_list_of_dicts_with_values_from_config(
        production_variants,
        ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH,
        required_key_paths=["CoreDumpConfig.DestinationS3Uri"],
        sagemaker_session=self,
    )

    config_options = {"EndpointConfigName": name, "ProductionVariants": production_variants}

    if supports_kms:
        kms_key = resolve_value_from_config(
            kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self
        )

    vpc_config = resolve_value_from_config(
        vpc_config,
        ENDPOINT_CONFIG_VPC_CONFIG_PATH,
        sagemaker_session=self,
    )

    enable_network_isolation = resolve_value_from_config(
        enable_network_isolation,
        ENDPOINT_CONFIG_ENABLE_NETWORK_ISOLATION_PATH,
        sagemaker_session=self,
    )

    # ``self`` is always a Session here, so the config is read through it; no separate
    # ``load_sagemaker_config()`` fallback is needed.
    role = resolve_value_from_config(
        role,
        ENDPOINT_CONFIG_EXECUTION_ROLE_ARN_PATH,
        sagemaker_session=self,
    )

    # For Amazon SageMaker inference component based endpoint, it will not pass
    # Model names during endpoint creation. Instead, ExecutionRoleArn will be
    # needed in the endpoint config to create Endpoint
    model_names = [pv["ModelName"] for pv in production_variants if "ModelName" in pv]
    if not model_names:
        # Currently, SageMaker Python SDK allow using RoleName to deploy models.
        # Use expand_role method to handle this situation.
        role = self.expand_role(role)
        config_options["ExecutionRoleArn"] = role
    endpoint_config_tags = _append_project_tags(format_tags(tags))
    endpoint_tags = _append_project_tags(format_tags(tags))

    endpoint_config_tags = self._append_sagemaker_config_tags(
        endpoint_config_tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_CONFIG, TAGS)
    )
    if endpoint_config_tags:
        config_options["Tags"] = endpoint_config_tags
    if kms_key:
        config_options["KmsKeyId"] = kms_key
    if data_capture_config_dict is not None:
        inferred_data_capture_config_dict = update_nested_dictionary_with_values_from_config(
            data_capture_config_dict, ENDPOINT_CONFIG_DATA_CAPTURE_PATH, sagemaker_session=self
        )
        config_options["DataCaptureConfig"] = inferred_data_capture_config_dict
    if async_inference_config_dict is not None:
        inferred_async_inference_config_dict = update_nested_dictionary_with_values_from_config(
            async_inference_config_dict,
            ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH,
            sagemaker_session=self,
        )
        config_options["AsyncInferenceConfig"] = inferred_async_inference_config_dict
    if explainer_config_dict is not None:
        config_options["ExplainerConfig"] = explainer_config_dict
    if vpc_config is not None:
        config_options["VpcConfig"] = vpc_config
    if enable_network_isolation is not None:
        config_options["EnableNetworkIsolation"] = enable_network_isolation
    if role is not None:
        config_options["ExecutionRoleArn"] = role

    logger.info("Creating endpoint-config with name %s", name)
    self.sagemaker_client.create_endpoint_config(**config_options)

    return self.create_endpoint(
        endpoint_name=name,
        config_name=name,
        tags=endpoint_tags,
        wait=wait,
        live_logging=live_logging,
    )
983
+
984
def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True, live_logging=False):
    """Create an Amazon SageMaker ``Endpoint`` according to the configuration in the request.

    Once the ``Endpoint`` is created, client applications can send requests to obtain
    inferences. The endpoint configuration is created using the ``CreateEndpointConfig`` API.
    Project tags and endpoint tags from sagemaker_config are merged into ``tags``.

    Args:
        endpoint_name (str): Name of the Amazon SageMaker ``Endpoint`` being created.
        config_name (str): Name of the Amazon SageMaker endpoint configuration to deploy.
        tags (Optional[Tags]): A list of key-value pairs for tagging the endpoint
            (default: None).
        wait (bool): Whether to wait for the endpoint deployment to complete before returning
            (default: True).
        live_logging (bool): Passed to ``wait_for_endpoint`` when ``wait`` is True
            (default: False).

    Returns:
        str: Name of the Amazon SageMaker ``Endpoint`` created.

    Raises:
        botocore.exceptions.ClientError: If Sagemaker throws an exception while creating
            endpoint. Any exception raised while creating or waiting is re-raised after
            logging a pointer to the troubleshooting guide.
    """
    logger.info("Creating endpoint with name %s", endpoint_name)

    tags = format_tags(tags) or []
    tags = _append_project_tags(tags)
    tags = self._append_sagemaker_config_tags(
        tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT, TAGS)
    )
    try:
        res = self.sagemaker_client.create_endpoint(
            EndpointName=endpoint_name, EndpointConfigName=config_name, Tags=tags
        )
        if res:
            self.endpoint_arn = res["EndpointArn"]

        if wait:
            self.wait_for_endpoint(endpoint_name, live_logging=live_logging)
        return endpoint_name
    except Exception:
        troubleshooting = (
            "https://docs.aws.amazon.com/sagemaker/latest/dg/"
            "sagemaker-python-sdk-troubleshooting.html"
            "#sagemaker-python-sdk-troubleshooting-create-endpoint"
        )
        logger.error(
            "Please check the troubleshooting guide for common errors: %s", troubleshooting
        )
        # Bare ``raise`` re-raises the active exception with its traceback untouched.
        raise
1032
+
1033
def wait_for_endpoint(self, endpoint, poll=DEFAULT_EP_POLL, live_logging=False):
    """Block until an Amazon SageMaker endpoint deployment finishes.

    Args:
        endpoint (str): Name of the ``Endpoint`` to wait for.
        poll (int): Polling interval in seconds (default: ``DEFAULT_EP_POLL``). Used only on
            the plain polling path; the live-logging path polls every ``EP_LOGGER_POLL``
            seconds instead.
        live_logging (bool): If True and the session has permission to read the endpoint's
            CloudWatch logs, fetch those logs (``filter_log_events``) while waiting
            (default: False).

    Returns:
        dict: Return value from the ``DescribeEndpoint`` API.

    Raises:
        exceptions.CapacityError: If the deployment fails and its failure reason mentions
            ``CapacityError``.
        exceptions.UnexpectedStatusException: If the deployment ends in any other
            status than ``InService``.
    """
    # The permission check only runs when live logging was actually requested.
    use_live_logs = live_logging and _has_permission_for_live_logging(
        self.boto_session, endpoint
    )

    if use_live_logs:
        logs_client = self.boto_session.client("logs")
        log_paginator = logs_client.get_paginator("filter_log_events")
        log_paginator_config = create_paginator_config()
        desc = _wait_until(
            lambda: _live_logging_deploy_done(
                self.sagemaker_client,
                endpoint,
                log_paginator,
                log_paginator_config,
                EP_LOGGER_POLL,
            ),
            poll=EP_LOGGER_POLL,
        )
    else:
        desc = _wait_until(lambda: _deploy_done(self.sagemaker_client, endpoint), poll)

    status = desc["EndpointStatus"]
    if status == "InService":
        return desc

    reason = desc.get("FailureReason", None)
    trouble_shooting = (
        "Try changing the instance type or reference the troubleshooting page "
        "https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference-troubleshooting"
        ".html"
    )
    message = "Error hosting endpoint {}: {}. Reason: {}. {}".format(
        endpoint, status, reason, trouble_shooting
    )
    if "CapacityError" in str(reason):
        error_type = exceptions.CapacityError
    else:
        error_type = exceptions.UnexpectedStatusException
    raise error_type(
        message=message,
        allowed_statuses=["InService"],
        actual_status=status,
    )
1084
+
1085
def create_inference_component(
    self,
    inference_component_name: str,
    endpoint_name: str,
    variant_name: str,
    specification: Dict[str, Any],
    runtime_config: Optional[Dict[str, Any]] = None,
    tags: Optional[Tags] = None,
    wait: bool = True,
):
    """Create an Amazon SageMaker Inference Component.

    Args:
        inference_component_name (str): Name of the Amazon SageMaker inference component
            to create.
        endpoint_name (str): Name of the Amazon SageMaker endpoint that the inference component
            will deploy to.
        variant_name (str): Name of the Amazon SageMaker variant that the inference component
            will deploy to.
        specification (Dict[str, Any]): The inference component specification.
        runtime_config (Optional[Dict[str, Any]]): Optional. The inference component
            runtime configuration. When None, ``{"CopyCount": 1}`` is sent. (Default: None).
        tags (Optional[Tags]): Optional. Either a dictionary or a list
            of dictionaries containing key-value pairs. Project tags and tags from the
            SageMaker config are appended. (Default: None).
        wait (bool) : Optional. Wait for the inference component to finish being created before
            returning a value. (Default: True).

    Returns:
        str: Name of the Amazon SageMaker ``InferenceComponent`` if created.
    """
    # Use the same module logger as the rest of the session methods.
    logger.info(
        "Creating inference component with name %s for endpoint %s",
        inference_component_name,
        endpoint_name,
    )

    if runtime_config is None:
        runtime_config = {"CopyCount": 1}

    request = {
        "InferenceComponentName": inference_component_name,
        "EndpointName": endpoint_name,
        "VariantName": variant_name,
        "Specification": specification,
        "RuntimeConfig": runtime_config,
    }

    tags = format_tags(tags)
    tags = _append_project_tags(tags)
    tags = self._append_sagemaker_config_tags(
        tags, "{}.{}.{}".format(SAGEMAKER, INFERENCE_COMPONENT, TAGS)
    )
    # Omit the Tags key entirely when there is nothing to send.
    if tags:
        request["Tags"] = tags

    self.sagemaker_client.create_inference_component(**request)
    if wait:
        self.wait_for_inference_component(inference_component_name)
    return inference_component_name
1144
+
1145
def wait_for_inference_component(self, inference_component_name, poll=20):
    """Block until an Amazon SageMaker ``Inference Component`` deployment finishes.

    Args:
        inference_component_name (str): Name of the ``Inference Component`` to wait for.
        poll (int): Seconds between status checks (default: 20).

    Raises:
        exceptions.CapacityError: If the deployment fails with a reason that mentions
            ``CapacityError``.
        exceptions.UnexpectedStatusException: If the deployment ends in any other
            status than ``InService``.

    Returns:
        dict: Return value from the ``DescribeInferenceComponent`` API.
    """
    desc = _wait_until(
        lambda: self._inference_component_done(self.sagemaker_client, inference_component_name),
        poll,
    )
    status = desc["InferenceComponentStatus"]
    if status == "InService":
        return desc

    reason = desc.get("FailureReason")
    message = f"Error creating inference component '{inference_component_name}'"
    if reason:
        message = f"{message}: {reason}"

    if "CapacityError" in str(reason):
        failure_type = exceptions.CapacityError
    else:
        failure_type = exceptions.UnexpectedStatusException
    raise failure_type(
        message=message,
        allowed_statuses=["InService"],
        actual_status=status,
    )
1182
+
1183
def describe_inference_component(self, inference_component_name):
    """Fetch the ``DescribeInferenceComponent`` response for one inference component.

    Args:
        inference_component_name (str): Name of the Amazon SageMaker ``InferenceComponent``.

    Returns:
        dict[str,str]: Inference component details, as returned by the API.
    """
    client = self.sagemaker_client
    return client.describe_inference_component(InferenceComponentName=inference_component_name)
1196
+
1197
+ def _inference_component_done(self, sagemaker_client, inference_component_name):
1198
+ """Check if creation of inference component is done.
1199
+
1200
+ Args:
1201
+ sagemaker_client (boto3.SageMaker.Client): Client which makes Amazon SageMaker
1202
+ service calls
1203
+ inference_component_name (str): Name of the Amazon SageMaker ``InferenceComponent``.
1204
+ Returns:
1205
+ dict[str,str]: Inference component details.
1206
+ """
1207
+
1208
+ create_inference_component_codes = {
1209
+ "InService": "!",
1210
+ "Creating": "-",
1211
+ "Updating": "-",
1212
+ "Failed": "*",
1213
+ "Deleting": "o",
1214
+ }
1215
+ in_progress_statuses = ["Creating", "Updating", "Deleting"]
1216
+
1217
+ desc = sagemaker_client.describe_inference_component(
1218
+ InferenceComponentName=inference_component_name
1219
+ )
1220
+ status = desc["InferenceComponentStatus"]
1221
+
1222
+ print(create_inference_component_codes.get(status, "?"), end="", flush=True)
1223
+
1224
+ return None if status in in_progress_statuses else desc
1225
+
1226
def update_endpoint(self, endpoint_name, endpoint_config_name, wait=True):
    """Point an existing Amazon SageMaker ``Endpoint`` at a new endpoint configuration.

    Args:
        endpoint_name (str): Name of the Amazon SageMaker ``Endpoint`` to update.
        endpoint_config_name (str): Name of the Amazon SageMaker endpoint configuration to
            deploy.
        wait (bool): Whether to wait for the endpoint deployment to complete before returning
            (default: True).

    Returns:
        str: Name of the Amazon SageMaker ``Endpoint`` being updated.

    Raises:
        - ValueError: if the endpoint does not already exist
        - botocore.exceptions.ClientError: If SageMaker throws an error while
          describing or updating the endpoint
        - exceptions.CapacityError / exceptions.UnexpectedStatusException: if ``wait`` is
          True and the endpoint does not reach ``InService`` (from ``wait_for_endpoint``)
    """
    if not _deployment_entity_exists(
        lambda: self.sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
    ):
        raise ValueError(
            "Endpoint with name '{}' does not exist; please use an "
            "existing endpoint name".format(endpoint_name)
        )

    try:
        res = self.sagemaker_client.update_endpoint(
            EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
        )
        if res:
            self.endpoint_arn = res["EndpointArn"]

        if wait:
            self.wait_for_endpoint(endpoint_name)
        return endpoint_name
    except Exception:
        troubleshooting = (
            "https://docs.aws.amazon.com/sagemaker/latest/dg/"
            "sagemaker-python-sdk-troubleshooting.html"
            "#sagemaker-python-sdk-troubleshooting-update-endpoint"
        )
        logger.error(
            "Please check the troubleshooting guide for common errors: %s", troubleshooting
        )
        # Bare ``raise`` re-raises the active exception with its original traceback.
        raise
1273
+
1274
def endpoint_in_service_or_not(self, endpoint_name: str):
    """Check whether an Amazon SageMaker ``Endpoint`` is in ``InService`` status.

    A ``ClientError`` whose message says the endpoint could not be found is treated
    as "not in service"; any other error propagates.

    Args:
        endpoint_name (str): Name of the Amazon SageMaker ``Endpoint`` to
            check status.

    Returns:
        bool: True if the ``Endpoint`` is ``InService``; False if the ``Endpoint`` does not
        exist or is in any other status.

    Raises:
        botocore.exceptions.ClientError: If ``DescribeEndpoint`` fails for a reason other
            than the endpoint not being found.
    """
    try:
        desc = self.sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
    except botocore.exceptions.ClientError as e:
        str_err = str(e).lower()
        if "could not find" in str_err or "not found" in str_err:
            return False
        raise
    return desc["EndpointStatus"] == "InService"
1302
+
1303
+ def _intercept_create_request(
1304
+ self,
1305
+ request: Dict,
1306
+ create,
1307
+ func_name: str = None,
1308
+ # pylint: disable=unused-argument
1309
+ ):
1310
+ """This function intercepts the create job request.
1311
+
1312
+ PipelineSession inherits this Session class and will override
1313
+ this function to intercept the create request.
1314
+
1315
+ Args:
1316
+ request (dict): the create job request
1317
+ create (functor): a functor calls the sagemaker client create method
1318
+ func_name (str): the name of the function needed intercepting
1319
+ """
1320
+ return create(request)
1321
+
1322
def _create_inference_recommendations_job_request(
    self,
    role: str,
    job_name: str,
    job_description: str,
    framework: str,
    sample_payload_url: str,
    supported_content_types: List[str],
    tags: Optional[Tags],
    model_name: str = None,
    model_package_version_arn: str = None,
    job_duration_in_seconds: int = None,
    job_type: str = "Default",
    framework_version: str = None,
    nearest_model_name: str = None,
    supported_instance_types: List[str] = None,
    endpoint_configurations: List[Dict[str, Any]] = None,
    traffic_pattern: Dict[str, Any] = None,
    stopping_conditions: Dict[str, Any] = None,
    resource_limit: Dict[str, Any] = None,
) -> Dict[str, Any]:
    """Assemble the request body for the CreateInferenceRecommendationsJob API.

    Optional values are added only when they are truthy. The Advanced-only fields
    (stopping conditions, resource limit, traffic pattern, endpoint configurations)
    are ignored unless ``job_type`` is ``"Advanced"``.

    Args:
        role (str): An AWS IAM role (either name or full ARN) used by the job.
        job_name (str): The name of the Inference Recommendations Job.
        job_description (str): A description of the Inference Recommendations Job.
        framework (str): The machine learning framework of the Image URI.
        sample_payload_url (str): The S3 path where the sample payload is stored.
        supported_content_types (List[str]): The supported MIME types for the input data.
        tags (Optional[Tags]): Tags used to identify where
            the Inference Recommendations call was made from.
        model_name (str): Name of the Amazon SageMaker ``Model`` to be used. Used only when
            ``model_package_version_arn`` is falsy.
        model_package_version_arn (str): The Amazon Resource Name (ARN) of a
            versioned model package. Takes precedence over ``model_name``.
        job_duration_in_seconds (int): The maximum job duration that a job can run for.
        job_type (str): The type of job being run. Must either be `Default` or `Advanced`.
        framework_version (str): The framework version of the Image URI.
        nearest_model_name (str): The name of a pre-trained machine learning model
            benchmarked by Amazon SageMaker Inference Recommender that matches your model.
        supported_instance_types (List[str]): A list of the instance types that are used
            to generate inferences in real-time.
        endpoint_configurations (List[Dict[str, Any]]): Endpoint configurations to use for
            an `Advanced` job.
        traffic_pattern (Dict[str, Any]): Traffic pattern for an `Advanced` job.
        stopping_conditions (Dict[str, Any]): Conditions that automatically stop an
            `Advanced` job when any of them is met.
        resource_limit (Dict[str, Any]): Resource limit for an `Advanced` job.

    Returns:
        Dict[str, Any]: request dictionary for the CreateInferenceRecommendationsJob API
    """
    container_config = {
        "Domain": "MACHINE_LEARNING",
        "Task": "OTHER",
        "Framework": framework,
        "PayloadConfig": {
            "SamplePayloadUrl": sample_payload_url,
            "SupportedContentTypes": supported_content_types,
        },
    }
    optional_container_fields = (
        ("FrameworkVersion", framework_version),
        ("NearestModelName", nearest_model_name),
        ("SupportedInstanceTypes", supported_instance_types),
    )
    container_config.update({key: value for key, value in optional_container_fields if value})

    input_config = {"ContainerConfig": container_config}
    request = {
        "JobName": job_name,
        "JobType": job_type,
        "RoleArn": role,
        "InputConfig": input_config,
        "Tags": format_tags(tags),
    }

    # A model package ARN wins over a plain model name.
    if model_package_version_arn:
        input_config["ModelPackageVersionArn"] = model_package_version_arn
    else:
        input_config["ModelName"] = model_name

    if job_description:
        request["JobDescription"] = job_description
    if job_duration_in_seconds:
        input_config["JobDurationInSeconds"] = job_duration_in_seconds

    if job_type != "Advanced":
        return request

    if stopping_conditions:
        request["StoppingConditions"] = stopping_conditions
    advanced_input_fields = (
        ("ResourceLimit", resource_limit),
        ("TrafficPattern", traffic_pattern),
        ("EndpointConfigurations", endpoint_configurations),
    )
    for key, value in advanced_input_fields:
        if value:
            input_config[key] = value
    return request
1431
+
1432
def create_inference_recommendations_job(
    self,
    role: str,
    sample_payload_url: str,
    supported_content_types: List[str],
    job_name: str = None,
    job_type: str = "Default",
    model_name: str = None,
    model_package_version_arn: str = None,
    job_duration_in_seconds: int = None,
    nearest_model_name: str = None,
    supported_instance_types: List[str] = None,
    framework: str = None,
    framework_version: str = None,
    endpoint_configurations: List[Dict[str, Any]] = None,
    traffic_pattern: Dict[str, Any] = None,
    stopping_conditions: Dict[str, Any] = None,
    resource_limit: Dict[str, Any] = None,
):
    """Creates an Inference Recommendations Job

    Args:
        role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training
            jobs and APIs that create Amazon SageMaker endpoints use this role to access
            training data and model artifacts.
            You must grant sufficient permissions to this role.
        sample_payload_url (str): The S3 path where the sample payload is stored.
        supported_content_types (List[str]): The supported MIME types for the input data.
        model_name (str): Name of the Amazon SageMaker ``Model`` to be used.
        model_package_version_arn (str): The Amazon Resource Name (ARN) of a
            versioned model package.
        job_name (str): The name of the job being run. Generated when not provided.
        job_type (str): The type of job being run. Must either be `Default` or `Advanced`.
        job_duration_in_seconds (int): The maximum job duration that a job
            can run for. Will be used for `Advanced` jobs.
        nearest_model_name (str): The name of a pre-trained machine learning model
            benchmarked by Amazon SageMaker Inference Recommender that matches your model.
        supported_instance_types (List[str]): A list of the instance types that are used
            to generate inferences in real-time.
        framework (str): The machine learning framework of the Image URI.
        framework_version (str): The framework version of the Image URI.
        endpoint_configurations (List[Dict[str, Any]]): Specifies the endpoint configurations
            to use for a job. Will be used for `Advanced` jobs.
        traffic_pattern (Dict[str, Any]): Specifies the traffic pattern for the job.
            Will be used for `Advanced` jobs.
        stopping_conditions (Dict[str, Any]): A set of conditions for stopping a
            recommendation job.
            If any of the conditions are met, the job is automatically stopped.
            Will be used for `Advanced` jobs.
        resource_limit (Dict[str, Any]): Defines the resource limit for the job.
            Will be used for `Advanced` jobs.

    Returns:
        str: The name of the job created. When ``job_name`` is not given, it has the form
        `SMPYTHONSDK-<uuid4>`.

    Raises:
        ValueError: If neither or both of ``model_name`` and ``model_package_version_arn``
            are provided.
    """
    # Exactly one model source must be given.
    if (model_name is None) == (model_package_version_arn is None):
        raise ValueError("Please provide either model_name or model_package_version_arn.")

    if not job_name:
        job_name = "SMPYTHONSDK-" + str(uuid.uuid4())
    # Always bound, including when the caller supplies their own job_name.
    job_description = "#python-sdk-create"

    tags = [{"Key": "ClientType", "Value": "PythonSDK-RightSize"}]

    create_inference_recommendations_job_request = (
        self._create_inference_recommendations_job_request(
            role=role,
            model_name=model_name,
            model_package_version_arn=model_package_version_arn,
            job_name=job_name,
            job_type=job_type,
            job_duration_in_seconds=job_duration_in_seconds,
            job_description=job_description,
            framework=framework,
            framework_version=framework_version,
            nearest_model_name=nearest_model_name,
            sample_payload_url=sample_payload_url,
            supported_content_types=supported_content_types,
            supported_instance_types=supported_instance_types,
            endpoint_configurations=endpoint_configurations,
            traffic_pattern=traffic_pattern,
            stopping_conditions=stopping_conditions,
            resource_limit=resource_limit,
            tags=tags,
        )
    )

    def submit(request):
        logger.info("Creating Inference Recommendations job with name: %s", job_name)
        logger.debug("process request: %s", json.dumps(request, indent=4))
        self.sagemaker_client.create_inference_recommendations_job(**request)

    self._intercept_create_request(
        create_inference_recommendations_job_request,
        submit,
        self.create_inference_recommendations_job.__name__,
    )
    return job_name
1534
+
1535
def wait_for_inference_recommendations_job(
    self, job_name: str, poll: int = 120, log_level: str = "Verbose"
) -> Dict[str, Any]:
    """Wait for an Amazon SageMaker Inference Recommender job to complete.

    Args:
        job_name (str): Name of the Inference Recommender job to wait for.
        poll (int): Polling interval in seconds (default: 120). Only used when
            ``log_level`` is "Quiet".
        log_level (str): The level of verbosity for the logs.
            Can be "Quiet" or "Verbose" (default: "Verbose").

    Returns:
        (dict): Return value from the ``DescribeInferenceRecommendationsJob`` API.

    Raises:
        ValueError: If ``log_level`` is neither "Quiet" nor "Verbose".
        exceptions.CapacityError: If the Inference Recommender job fails with CapacityError.
        exceptions.UnexpectedStatusException: If the Inference Recommender job fails.
    """
    # Validate up front so a bad log_level fails before any API calls are made.
    if log_level not in ("Quiet", "Verbose"):
        raise ValueError("log_level must be either Quiet or Verbose")

    if log_level == "Quiet":
        _wait_until(
            lambda: _describe_inference_recommendations_job_status(
                self.sagemaker_client, job_name
            ),
            poll,
        )
    else:
        _display_inference_recommendations_job_steps_status(
            self, self.sagemaker_client, job_name
        )
    desc = _describe_inference_recommendations_job_status(self.sagemaker_client, job_name)
    _check_job_status(job_name, desc, "Status")
    return desc
1569
+
1570
def delete_model(self, model_name):
    """Remove an Amazon SageMaker Model.

    Args:
        model_name (str): Name of the Amazon SageMaker model to remove.
    """
    logger.info("Deleting model with name: %s", model_name)
    request = {"ModelName": model_name}
    self.sagemaker_client.delete_model(**request)
1578
+
1579
def delete_endpoint(self, endpoint_name):
    """Remove an Amazon SageMaker ``Endpoint``.

    Args:
        endpoint_name (str): Name of the Amazon SageMaker ``Endpoint`` to remove.
    """
    logger.info("Deleting endpoint with name: %s", endpoint_name)
    request = {"EndpointName": endpoint_name}
    self.sagemaker_client.delete_endpoint(**request)
1587
+
1588
def delete_endpoint_config(self, endpoint_config_name):
    """Remove an Amazon SageMaker endpoint configuration.

    Args:
        endpoint_config_name (str): Name of the Amazon SageMaker endpoint configuration to
            remove.
    """
    logger.info("Deleting endpoint configuration with name: %s", endpoint_config_name)
    request = {"EndpointConfigName": endpoint_config_name}
    self.sagemaker_client.delete_endpoint_config(**request)
1597
+
1598
def wait_for_optimization_job(self, job, poll=5):
    """Block until an Amazon SageMaker Optimization job finishes.

    Args:
        job (str): Name of optimization job to wait for.
        poll (int): Polling interval in seconds (default: 5).

    Returns:
        (dict): Return value from the ``DescribeOptimizationJob`` API.

    Raises:
        Whatever ``_check_job_status`` raises when ``OptimizationJobStatus`` is not a
        success status. NOTE(review): presumably exceptions.CapacityError for capacity
        failures and exceptions.UnexpectedStatusException otherwise; confirm against
        that helper.
    """

    def _poll_once():
        return _optimization_job_status(self.sagemaker_client, job)

    desc = _wait_until(_poll_once, poll)
    _check_job_status(job, desc, "OptimizationJobStatus")
    return desc
1615
+
1616
def update_inference_component(
    self, inference_component_name, specification=None, runtime_config=None, wait=True
):
    """Update an Amazon SageMaker ``InferenceComponent``.

    Args:
        inference_component_name (str): Name of the Amazon SageMaker ``InferenceComponent``.
        specification ([dict[str,int]]): Resource configuration. Optional; left out of the
            request when None.
            Example: {
                "MinMemoryRequiredInMb": 1024,
                "NumberOfCpuCoresRequired": 1,
                "NumberOfAcceleratorDevicesRequired": 1,
                "MaxMemoryRequiredInMb": 4096,
            },
        runtime_config ([dict[str,int]]): Number of copies. Optional; left out of the
            request when None.
            Example: {
                "CopyCount": 1
            }
        wait: Wait for the inference component update to finish before returning.
            Optional. Default is True.

    Return:
        str: inference component name

    Raises:
        ValueError: If the inference_component_name does not exist.
    """
    if not _deployment_entity_exists(
        lambda: self.sagemaker_client.describe_inference_component(
            InferenceComponentName=inference_component_name
        )
    ):
        raise ValueError(
            "InferenceComponent with name '{}' does not exist; please use an "
            "existing inference component name".format(inference_component_name)
        )

    request = {"InferenceComponentName": inference_component_name}
    # Send only the fields the caller provided: botocore's parameter validation
    # rejects None for structure-typed parameters.
    if specification is not None:
        request["Specification"] = specification
    if runtime_config is not None:
        request["RuntimeConfig"] = runtime_config

    self.sagemaker_client.update_inference_component(**request)

    if wait:
        self.wait_for_inference_component(inference_component_name)
    return inference_component_name
1664
+
1665
def _create_model_request(
    self,
    name,
    role,
    container_defs,
    vpc_config=None,
    enable_network_isolation=False,
    primary_container=None,
    tags=None,
):  # pylint: disable=redefined-outer-name
    """Build the request dictionary for the ``CreateModel`` API.

    Args:
        name (str): Model name.
        role (str): IAM role name or ARN; expanded to an ARN via ``expand_role``.
        container_defs (list or dict or str): A list of container definitions (sent as
            ``Containers``) or a single definition. A single definition is expanded with
            ``_expand_container_def``. It is sent as a one-element ``Containers`` list if it
            references a ``ModelPackageName``, and as ``PrimaryContainer`` otherwise.
        vpc_config (dict): Optional ``VpcConfig``; included only when truthy.
        enable_network_isolation: Included only when truthy.
        primary_container: Deprecated alias for ``container_defs``.
        tags: Optional tags, passed through ``format_tags``.

    Returns:
        dict: The ``CreateModel`` request.

    Raises:
        ValueError: If both ``container_defs`` and ``primary_container`` are given.
    """
    if container_defs and primary_container:
        raise ValueError("Both container_defs and primary_container can not be passed as input")

    if primary_container:
        warnings.warn(
            "primary_container is going to be deprecated in a future release. Please use "
            "container_defs instead.",
            DeprecationWarning,
        )
        container_defs = primary_container

    execution_role = self.expand_role(role)

    if isinstance(container_defs, list):
        # Config defaults are merged into the list's dicts in place.
        update_list_of_dicts_with_values_from_config(
            container_defs, MODEL_CONTAINERS_PATH, sagemaker_session=self
        )
        containers = container_defs
    else:
        containers = update_nested_dictionary_with_values_from_config(
            _expand_container_def(container_defs),
            MODEL_PRIMARY_CONTAINER_PATH,
            sagemaker_session=self,
        )

    request = {"ModelName": name, "ExecutionRoleArn": execution_role}
    if isinstance(containers, list):
        request["Containers"] = containers
    elif "ModelPackageName" in containers:
        request["Containers"] = [containers]
    else:
        request["PrimaryContainer"] = containers

    if tags:
        request["Tags"] = format_tags(tags)
    if vpc_config:
        request["VpcConfig"] = vpc_config
    # enable_network_isolation may be a pipeline variable that is only resolved at
    # execution time, so it is copied through as-is rather than coerced to bool.
    if enable_network_isolation:
        request["EnableNetworkIsolation"] = enable_network_isolation

    return request
1721
+
1722
def create_model(
    self,
    name,
    role=None,
    container_defs=None,
    vpc_config=None,
    enable_network_isolation=None,
    primary_container=None,
    tags=None,
):
    """Create an Amazon SageMaker ``Model``.

    Specify the S3 location of the model artifacts and Docker image containing
    the inference code. Amazon SageMaker uses this information to deploy the
    model in Amazon SageMaker. This method can also be used to create a Model for an Inference
    Pipeline if you pass the list of container definitions through the containers parameter.

    If SageMaker reports that a model with this name already exists, a warning is logged
    and the existing model is used instead of raising.

    Args:
        name (str): Name of the Amazon SageMaker ``Model`` to create.
        role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training
            jobs and APIs that create Amazon SageMaker endpoints use this role to access
            training data and model artifacts. You must grant sufficient permissions to this
            role. When None, the value from the SageMaker config is used if present.
        container_defs (list[dict[str, str]] or [dict[str, str]]): A single container
            definition or a list of container definitions which will be invoked sequentially
            while performing the prediction. A list (of any length) is passed to SageMaker
            Hosting as ``Containers``. A single definition is passed as
            ``PrimaryContainer``, or as a one-element ``Containers`` when it references a
            ``ModelPackageName``. You can also specify the return value of
            ``sagemaker.get_container_def()`` or ``sagemaker.pipeline_container_def()``,
            which will used to create more advanced container configurations, including model
            containers which need artifacts from S3.
        vpc_config (dict[str, list[str]]): The VpcConfig set on the model (default: None;
            falls back to the SageMaker config if present)
            * 'Subnets' (list[str]): List of subnet ids.
            * 'SecurityGroupIds' (list[str]): List of security group ids.
        enable_network_isolation (bool): Whether the model requires network isolation or not.
            When None, resolved from the SageMaker config, defaulting to False.
        primary_container (str or dict[str, str]): Docker image which defines the inference
            code. You can also specify the return value of ``sagemaker.container_def()``,
            which is used to create more advanced container configurations, including model
            containers which need artifacts from S3. This field is deprecated, please use
            container_defs instead.
        tags(Optional[Tags]): Optional. The list of tags to add to the model.

    Example:
        >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
        For more information about tags, see https://boto3.amazonaws.com/v1/documentation\
        /api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags

    Returns:
        str: Name of the Amazon SageMaker ``Model`` created.
    """
    tags = _append_project_tags(format_tags(tags))
    tags = self._append_sagemaker_config_tags(tags, "{}.{}.{}".format(SAGEMAKER, MODEL, TAGS))
    role = resolve_value_from_config(
        role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self
    )
    vpc_config = resolve_value_from_config(
        vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self
    )
    enable_network_isolation = resolve_value_from_config(
        direct_input=enable_network_isolation,
        config_path=MODEL_ENABLE_NETWORK_ISOLATION_PATH,
        default_value=False,
        sagemaker_session=self,
    )

    # Due to ambiguity in container_defs which accepts both a single
    # container definition (dtype: dict) and a list of container definitions (dtype: list),
    # we need to inject environment variables into the container_defs in the helper function
    # _create_model_request.
    create_model_request = self._create_model_request(
        name=name,
        role=role,
        container_defs=container_defs,
        vpc_config=vpc_config,
        enable_network_isolation=enable_network_isolation,
        primary_container=primary_container,
        tags=tags,
    )

    def submit(request):
        logger.info("Creating model with name: %s", name)
        logger.debug("CreateModel request: %s", json.dumps(request, indent=4))
        try:
            self.sagemaker_client.create_model(**request)
        except ClientError as e:
            error_code = e.response["Error"]["Code"]
            message = e.response["Error"]["Message"]
            if (
                error_code == "ValidationException"
                and "Cannot create already existing model" in message
            ):
                logger.warning("Using already existing model: %s", name)
            else:
                raise

    self._intercept_create_request(create_model_request, submit, self.create_model.__name__)
    return name
1819
+
1820
def create_model_package_from_algorithm(self, name, description, algorithm_arn, model_data):
    """Register a SageMaker Model Package built from an Algorithm Package training run.

    If SageMaker answers with a ValidationException saying the package already exists,
    a warning is logged and the existing package is kept.

    Args:
        name (str): ModelPackage name
        description (str): Model Package description
        algorithm_arn (str): arn or name of the algorithm used for training.
        model_data (str or dict[str, Any]): s3 URI or a dictionary representing a
            ``ModelDataSource`` to the model artifacts produced by training
    """
    # A dict is a structured ModelDataSource; anything else is treated as an S3 URL.
    data_key = "ModelDataSource" if isinstance(model_data, dict) else "ModelDataUrl"
    source_algorithm = {"AlgorithmName": algorithm_arn, data_key: model_data}

    request = {
        "ModelPackageName": name,
        "ModelPackageDescription": description,
        "SourceAlgorithmSpecification": {"SourceAlgorithms": [source_algorithm]},
    }
    try:
        logger.info("Creating model package with name: %s", name)
        self.sagemaker_client.create_model_package(**request)
    except ClientError as err:
        code = err.response["Error"]["Code"]
        text = err.response["Error"]["Message"]
        already_exists = code == "ValidationException" and "ModelPackage already exists" in text
        if not already_exists:
            raise
        logger.warning("Using already existing model package: %s", name)
1852
+
1853
+ def expand_role(self, role):
1854
+ """Expand an IAM role name into an ARN.
1855
+
1856
+ If the role is already in the form of an ARN, then the role is simply returned. Otherwise
1857
+ we retrieve the full ARN and return it.
1858
+
1859
+ Args:
1860
+ role (str): An AWS IAM role (either name or full ARN).
1861
+
1862
+ Returns:
1863
+ str: The corresponding AWS IAM role ARN.
1864
+ """
1865
+ if "/" in role:
1866
+ return role
1867
+ return self.boto_session.resource("iam").Role(role).arn
1868
+
1869
+
1870
+ def _expand_container_def(c_def):
1871
+ """Placeholder docstring"""
1872
+ if isinstance(c_def, six.string_types):
1873
+ return container_def(c_def)
1874
+ return c_def
1875
+
1876
+
1877
+ def expand_role(self, role):
1878
+ """Expand an IAM role name into an ARN.
1879
+
1880
+ If the role is already in the form of an ARN, then the role is simply returned. Otherwise
1881
+ we retrieve the full ARN and return it.
1882
+
1883
+ Args:
1884
+ role (str): An AWS IAM role (either name or full ARN).
1885
+
1886
+ Returns:
1887
+ str: The corresponding AWS IAM role ARN.
1888
+ """
1889
+ if "/" in role:
1890
+ return role
1891
+ return self.boto_session.resource("iam").Role(role).arn
1892
+
1893
+
1894
+ def s3_path_join(*args, with_end_slash: bool = False):
1895
+ """Returns the arguments joined by a slash ("/"), similar to ``os.path.join()`` (on Unix).
1896
+
1897
+ Behavior of this function:
1898
+ - If the first argument is "s3://", then that is preserved.
1899
+ - The output by default will have no slashes at the beginning or end. There is one exception
1900
+ (see `with_end_slash`). For example, `s3_path_join("/foo", "bar/")` will yield
1901
+ `"foo/bar"` and `s3_path_join("foo", "bar", with_end_slash=True)` will yield `"foo/bar/"`
1902
+ - Any repeat slashes will be removed in the output (except for "s3://" if provided at the
1903
+ beginning). For example, `s3_path_join("s3://", "//foo/", "/bar///baz")` will yield
1904
+ `"s3://foo/bar/baz"`.
1905
+ - Empty or None arguments will be skipped. For example
1906
+ `s3_path_join("foo", "", None, "bar")` will yield `"foo/bar"`
1907
+
1908
+ Alternatives to this function that are NOT recommended for S3 paths:
1909
+ - `os.path.join(...)` will have different behavior on Unix machines vs non-Unix machines
1910
+ - `pathlib.PurePosixPath(...)` will apply potentially unintended simplification of single
1911
+ dots (".") and root directories. (for example
1912
+ `pathlib.PurePosixPath("foo", "/bar/./", "baz")` would yield `"/bar/baz"`)
1913
+ - `"{}/{}/{}".format(...)` and similar may result in unintended repeat slashes
1914
+
1915
+ Args:
1916
+ *args: The strings to join with a slash.
1917
+ with_end_slash (bool): (default: False) If true and if the path is not empty, appends a "/"
1918
+ to the end of the path
1919
+
1920
+ Returns:
1921
+ str: The joined string, without a slash at the end unless with_end_slash is True.
1922
+ """
1923
+ delimiter = "/"
1924
+
1925
+ non_empty_args = list(filter(lambda item: item is not None and item != "", args))
1926
+
1927
+ merged_path = ""
1928
+ for index, path in enumerate(non_empty_args):
1929
+ if (
1930
+ index == 0
1931
+ or (merged_path and merged_path[-1] == delimiter)
1932
+ or (path and path[0] == delimiter)
1933
+ ):
1934
+ # dont need to add an extra slash because either this is the beginning of the string,
1935
+ # or one (or more) slash already exists
1936
+ merged_path += path
1937
+ else:
1938
+ merged_path += delimiter + path
1939
+
1940
+ if with_end_slash and merged_path and merged_path[-1] != delimiter:
1941
+ merged_path += delimiter
1942
+
1943
+ # At this point, merged_path may include slashes at the beginning and/or end. And some of the
1944
+ # provided args may have had duplicate slashes inside or at the ends.
1945
+ # For backwards compatibility reasons, these need to be filtered out (done below). In the
1946
+ # future, if there is a desire to support multiple slashes for S3 paths throughout the SDK,
1947
+ # one option is to create a new optional argument (or a new function) that only executes the
1948
+ # logic above.
1949
+ filtered_path = merged_path
1950
+
1951
+ # remove duplicate slashes
1952
+ if filtered_path:
1953
+
1954
+ def duplicate_delimiter_remover(sequence, next_char):
1955
+ if sequence[-1] == delimiter and next_char == delimiter:
1956
+ return sequence
1957
+ return sequence + next_char
1958
+
1959
+ if filtered_path.startswith("s3://"):
1960
+ filtered_path = reduce(
1961
+ duplicate_delimiter_remover, filtered_path[5:], filtered_path[:5]
1962
+ )
1963
+ else:
1964
+ filtered_path = reduce(duplicate_delimiter_remover, filtered_path)
1965
+
1966
+ # remove beginning slashes
1967
+ filtered_path = filtered_path.lstrip(delimiter)
1968
+
1969
+ # remove end slashes
1970
+ if not with_end_slash and filtered_path != "s3://":
1971
+ filtered_path = filtered_path.rstrip(delimiter)
1972
+
1973
+ return filtered_path
1974
+
1975
+
1976
+ def botocore_resolver():
1977
+ """Get the DNS suffix for the given region.
1978
+
1979
+ Args:
1980
+ region (str): AWS region name
1981
+
1982
+ Returns:
1983
+ str: the DNS suffix
1984
+ """
1985
+ loader = botocore.loaders.create_loader()
1986
+ return botocore.regions.EndpointResolver(loader.load_data("endpoints"))
1987
+
1988
+
1989
+ def sts_regional_endpoint(region):
1990
+ """Get the AWS STS endpoint specific for the given region.
1991
+
1992
+ We need this function because the AWS SDK does not yet honor
1993
+ the ``region_name`` parameter when creating an AWS STS client.
1994
+
1995
+ For the list of regional endpoints, see
1996
+ https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html#id_credentials_region-endpoints.
1997
+
1998
+ Args:
1999
+ region (str): AWS region name
2000
+
2001
+ Returns:
2002
+ str: AWS STS regional endpoint
2003
+ """
2004
+ endpoint_data = botocore_resolver().construct_endpoint("sts", region)
2005
+ if region == "il-central-1" and not endpoint_data:
2006
+ endpoint_data = {"hostname": "sts.{}.amazonaws.com".format(region)}
2007
+ return "https://{}".format(endpoint_data["hostname"])
2008
+
2009
+
2010
+ def get_execution_role(sagemaker_session=None, use_default=False):
2011
+ """Return the role ARN whose credentials are used to call the API.
2012
+
2013
+ Throws an exception if role doesn't exist.
2014
+
2015
+ Args:
2016
+ sagemaker_session (Session): Current sagemaker session.
2017
+ use_default (bool): Use a default role if ``get_caller_identity_arn`` does not
2018
+ return a correct role. This default role will be created if needed.
2019
+ Defaults to ``False``.
2020
+
2021
+ Returns:
2022
+ (str): The role ARN
2023
+ """
2024
+ if not sagemaker_session:
2025
+ sagemaker_session = Session()
2026
+ arn = sagemaker_session.get_caller_identity_arn()
2027
+
2028
+ if ":role/" in arn:
2029
+ return arn
2030
+
2031
+ if use_default:
2032
+ default_role_name = "AmazonSageMaker-DefaultRole"
2033
+
2034
+ LOGGER.warning("Using default role: %s", default_role_name)
2035
+
2036
+ boto3_session = sagemaker_session.boto_session
2037
+ permissions_policy = json.dumps(
2038
+ {
2039
+ "Version": "2012-10-17",
2040
+ "Statement": [
2041
+ {
2042
+ "Effect": "Allow",
2043
+ "Principal": {"Service": ["sagemaker.amazonaws.com"]},
2044
+ "Action": "sts:AssumeRole",
2045
+ }
2046
+ ],
2047
+ }
2048
+ )
2049
+ iam_client = boto3_session.client("iam")
2050
+ try:
2051
+ iam_client.get_role(RoleName=default_role_name)
2052
+ except iam_client.exceptions.NoSuchEntityException:
2053
+ iam_client.create_role(
2054
+ RoleName=default_role_name, AssumeRolePolicyDocument=str(permissions_policy)
2055
+ )
2056
+
2057
+ LOGGER.warning("Created new sagemaker execution role: %s", default_role_name)
2058
+
2059
+ iam_client.attach_role_policy(
2060
+ PolicyArn="arn:aws:iam::aws:policy/AmazonSageMakerFullAccess",
2061
+ RoleName=default_role_name,
2062
+ )
2063
+ return iam_client.get_role(RoleName=default_role_name)["Role"]["Arn"]
2064
+
2065
+ message = (
2066
+ "The current AWS identity is not a role: {}, therefore it cannot be used as a "
2067
+ "SageMaker execution role"
2068
+ )
2069
+ raise ValueError(message.format(arn))
2070
+
2071
+
2072
+ def get_add_model_package_inference_args(
2073
+ model_package_arn,
2074
+ name,
2075
+ containers=None,
2076
+ content_types=None,
2077
+ response_types=None,
2078
+ inference_instances=None,
2079
+ transform_instances=None,
2080
+ description=None,
2081
+ ):
2082
+ """Get request dictionary for UpdateModelPackage API for additional inference.
2083
+
2084
+ Args:
2085
+ model_package_arn (str): Arn for the model package.
2086
+ name (str): Name to identify the additional inference specification
2087
+ containers (dict): The Amazon ECR registry path of the Docker image
2088
+ that contains the inference code.
2089
+ image_uris (List[str]): The ECR path where inference code is stored.
2090
+ description (str): Description for the additional inference specification
2091
+ content_types (list[str]): The supported MIME types
2092
+ for the input data.
2093
+ response_types (list[str]): The supported MIME types
2094
+ for the output data.
2095
+ inference_instances (list[str]): A list of the instance
2096
+ types that are used to generate inferences in real-time (default: None).
2097
+ transform_instances (list[str]): A list of the instance
2098
+ types on which a transformation job can be run or on which an endpoint can be
2099
+ deployed (default: None).
2100
+ """
2101
+
2102
+ request_dict = {}
2103
+ if containers is not None:
2104
+ inference_specification = {
2105
+ "Containers": containers,
2106
+ }
2107
+
2108
+ if name is not None:
2109
+ inference_specification.update({"Name": name})
2110
+
2111
+ if description is not None:
2112
+ inference_specification.update({"Description": description})
2113
+ if content_types is not None:
2114
+ inference_specification.update(
2115
+ {
2116
+ "SupportedContentTypes": content_types,
2117
+ }
2118
+ )
2119
+ if response_types is not None:
2120
+ inference_specification.update(
2121
+ {
2122
+ "SupportedResponseMIMETypes": response_types,
2123
+ }
2124
+ )
2125
+ if inference_instances is not None:
2126
+ inference_specification.update(
2127
+ {
2128
+ "SupportedRealtimeInferenceInstanceTypes": inference_instances,
2129
+ }
2130
+ )
2131
+ if transform_instances is not None:
2132
+ inference_specification.update(
2133
+ {
2134
+ "SupportedTransformInstanceTypes": transform_instances,
2135
+ }
2136
+ )
2137
+ request_dict["AdditionalInferenceSpecificationsToAdd"] = [inference_specification]
2138
+ request_dict.update({"ModelPackageArn": model_package_arn})
2139
+ return request_dict
2140
+
2141
+
2142
+ def get_update_model_package_inference_args(
2143
+ model_package_arn,
2144
+ containers=None,
2145
+ content_types=None,
2146
+ response_types=None,
2147
+ inference_instances=None,
2148
+ transform_instances=None,
2149
+ ):
2150
+ """Get request dictionary for UpdateModelPackage API for inference specification.
2151
+
2152
+ Args:
2153
+ model_package_arn (str): Arn for the model package.
2154
+ containers (dict): The Amazon ECR registry path of the Docker image
2155
+ that contains the inference code.
2156
+ content_types (list[str]): The supported MIME types
2157
+ for the input data.
2158
+ response_types (list[str]): The supported MIME types
2159
+ for the output data.
2160
+ inference_instances (list[str]): A list of the instance
2161
+ types that are used to generate inferences in real-time (default: None).
2162
+ transform_instances (list[str]): A list of the instance
2163
+ types on which a transformation job can be run or on which an endpoint can be
2164
+ deployed (default: None).
2165
+ """
2166
+
2167
+ request_dict = {}
2168
+ if containers is not None:
2169
+ inference_specification = {
2170
+ "Containers": containers,
2171
+ }
2172
+ if content_types is not None:
2173
+ inference_specification.update(
2174
+ {
2175
+ "SupportedContentTypes": content_types,
2176
+ }
2177
+ )
2178
+ if response_types is not None:
2179
+ inference_specification.update(
2180
+ {
2181
+ "SupportedResponseMIMETypes": response_types,
2182
+ }
2183
+ )
2184
+ if inference_instances is not None:
2185
+ inference_specification.update(
2186
+ {
2187
+ "SupportedRealtimeInferenceInstanceTypes": inference_instances,
2188
+ }
2189
+ )
2190
+ if transform_instances is not None:
2191
+ inference_specification.update(
2192
+ {
2193
+ "SupportedTransformInstanceTypes": transform_instances,
2194
+ }
2195
+ )
2196
+ request_dict["InferenceSpecification"] = inference_specification
2197
+ request_dict.update({"ModelPackageArn": model_package_arn})
2198
+ return request_dict
2199
+
2200
+
2201
+ def _logs_for_job( # noqa: C901 - suppress complexity warning for this method
2202
+ sagemaker_session, job_name, wait=False, poll=10, log_type="All", timeout=None
2203
+ ):
2204
+ """Display logs for a given training job, optionally tailing them until job is complete.
2205
+
2206
+ If the output is a tty or a Jupyter cell, it will be color-coded
2207
+ based on which instance the log entry is from.
2208
+
2209
+ Args:
2210
+ sagemaker_session (sagemaker.session.Session): A SageMaker Session
2211
+ object, used for SageMaker interactions.
2212
+ job_name (str): Name of the training job to display the logs for.
2213
+ wait (bool): Whether to keep looking for new log entries until the job completes
2214
+ (default: False).
2215
+ poll (int): The interval in seconds between polling for new log entries and job
2216
+ completion (default: 5).
2217
+ log_type ([str]): A list of strings specifying which logs to print. Acceptable
2218
+ strings are "All", "None", "Training", or "Rules". To maintain backwards
2219
+ compatibility, boolean values are also accepted and converted to strings.
2220
+ timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by
2221
+ default.
2222
+ Returns:
2223
+ Last call to sagemaker DescribeTrainingJob
2224
+ Raises:
2225
+ exceptions.CapacityError: If the training job fails with CapacityError.
2226
+ exceptions.UnexpectedStatusException: If waiting and the training job fails.
2227
+ """
2228
+ sagemaker_client = sagemaker_session.sagemaker_client
2229
+ request_end_time = time.time() + timeout if timeout else None
2230
+ description = _wait_until(
2231
+ lambda: sagemaker_client.describe_training_job(TrainingJobName=job_name)
2232
+ )
2233
+ print(secondary_training_status_message(description, None), end="")
2234
+
2235
+ instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init(
2236
+ sagemaker_session.boto_session, description, job="Training"
2237
+ )
2238
+
2239
+ state = _get_initial_job_state(description, "TrainingJobStatus", wait)
2240
+
2241
+ # The loop below implements a state machine that alternates between checking the job status
2242
+ # and reading whatever is available in the logs at this point. Note, that if we were
2243
+ # called with wait == False, we never check the job status.
2244
+ #
2245
+ # If wait == TRUE and job is not completed, the initial state is TAILING
2246
+ # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is
2247
+ # complete).
2248
+ #
2249
+ # The state table:
2250
+ #
2251
+ # STATE ACTIONS CONDITION NEW STATE
2252
+ # ---------------- ---------------- ----------------- ----------------
2253
+ # TAILING Read logs, Pause, Get status Job complete JOB_COMPLETE
2254
+ # Else TAILING
2255
+ # JOB_COMPLETE Read logs, Pause Any COMPLETE
2256
+ # COMPLETE Read logs, Exit N/A
2257
+ #
2258
+ # Notes:
2259
+ # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to
2260
+ # Cloudwatch after the job was marked complete.
2261
+ last_describe_job_call = time.time()
2262
+ last_description = description
2263
+ last_debug_rule_statuses = None
2264
+ last_profiler_rule_statuses = None
2265
+
2266
+ while True:
2267
+ _flush_log_streams(
2268
+ stream_names,
2269
+ instance_count,
2270
+ client,
2271
+ log_group,
2272
+ job_name,
2273
+ positions,
2274
+ dot,
2275
+ color_wrap,
2276
+ )
2277
+ if timeout and time.time() > request_end_time:
2278
+ print("Timeout Exceeded. {} seconds elapsed.".format(timeout))
2279
+ break
2280
+
2281
+ if state == LogState.COMPLETE:
2282
+ break
2283
+
2284
+ time.sleep(poll)
2285
+
2286
+ if state == LogState.JOB_COMPLETE:
2287
+ state = LogState.COMPLETE
2288
+ elif time.time() - last_describe_job_call >= 30:
2289
+ description = sagemaker_client.describe_training_job(TrainingJobName=job_name)
2290
+ last_describe_job_call = time.time()
2291
+
2292
+ if secondary_training_status_changed(description, last_description):
2293
+ print()
2294
+ print(secondary_training_status_message(description, last_description), end="")
2295
+ last_description = description
2296
+
2297
+ status = description["TrainingJobStatus"]
2298
+
2299
+ if status in ("Completed", "Failed", "Stopped"):
2300
+ print()
2301
+ state = LogState.JOB_COMPLETE
2302
+
2303
+ # Print prettified logs related to the status of SageMaker Debugger rules.
2304
+ debug_rule_statuses = description.get("DebugRuleEvaluationStatuses", {})
2305
+ if (
2306
+ debug_rule_statuses
2307
+ and _rule_statuses_changed(debug_rule_statuses, last_debug_rule_statuses)
2308
+ and (log_type in {"All", "Rules"})
2309
+ ):
2310
+ for status in debug_rule_statuses:
2311
+ rule_log = (
2312
+ f"{status['RuleConfigurationName']}: {status['RuleEvaluationStatus']}"
2313
+ )
2314
+ print(rule_log)
2315
+
2316
+ last_debug_rule_statuses = debug_rule_statuses
2317
+
2318
+ # Print prettified logs related to the status of SageMaker Profiler rules.
2319
+ profiler_rule_statuses = description.get("ProfilerRuleEvaluationStatuses", {})
2320
+ if (
2321
+ profiler_rule_statuses
2322
+ and _rule_statuses_changed(profiler_rule_statuses, last_profiler_rule_statuses)
2323
+ and (log_type in {"All", "Rules"})
2324
+ ):
2325
+ for status in profiler_rule_statuses:
2326
+ rule_log = (
2327
+ f"{status['RuleConfigurationName']}: {status['RuleEvaluationStatus']}"
2328
+ )
2329
+ print(rule_log)
2330
+
2331
+ last_profiler_rule_statuses = profiler_rule_statuses
2332
+
2333
+ if wait:
2334
+ _check_job_status(job_name, description, "TrainingJobStatus")
2335
+ if dot:
2336
+ print()
2337
+ # Customers are not billed for hardware provisioning, so billable time is less than
2338
+ # total time
2339
+ training_time = description.get("TrainingTimeInSeconds")
2340
+ billable_time = description.get("BillableTimeInSeconds")
2341
+ if training_time is not None:
2342
+ print("Training seconds:", training_time * instance_count)
2343
+ if billable_time is not None:
2344
+ print("Billable seconds:", billable_time * instance_count)
2345
+ if description.get("EnableManagedSpotTraining"):
2346
+ saving = (1 - float(billable_time) / training_time) * 100
2347
+ print("Managed Spot Training savings: {:.1f}%".format(saving))
2348
+ return last_description
2349
+
2350
+
2351
+ def _check_job_status(job, desc, status_key_name):
2352
+ """Check to see if the job completed successfully.
2353
+
2354
+ If not, construct and raise a exceptions. (UnexpectedStatusException).
2355
+
2356
+ Args:
2357
+ job (str): The name of the job to check.
2358
+ desc (dict[str, str]): The result of ``describe_training_job()``.
2359
+ status_key_name (str): Status key name to check for.
2360
+
2361
+ Raises:
2362
+ exceptions.CapacityError: If the training job fails with CapacityError.
2363
+ exceptions.UnexpectedStatusException: If the training job fails.
2364
+ """
2365
+ status = desc[status_key_name]
2366
+ # If the status is capital case, then convert it to Camel case
2367
+ status = _STATUS_CODE_TABLE.get(status, status)
2368
+
2369
+ if status == "Stopped":
2370
+ logger.warning(
2371
+ "Job ended with status 'Stopped' rather than 'Completed'. "
2372
+ "This could mean the job timed out or stopped early for some other reason: "
2373
+ "Consider checking whether it completed as you expect."
2374
+ )
2375
+ elif status != "Completed":
2376
+ reason = desc.get("FailureReason", "(No reason provided)")
2377
+ job_type = status_key_name.replace("JobStatus", " job")
2378
+ troubleshooting = (
2379
+ "https://docs.aws.amazon.com/sagemaker/latest/dg/"
2380
+ "sagemaker-python-sdk-troubleshooting.html"
2381
+ )
2382
+ message = (
2383
+ "Error for {job_type} {job_name}: {status}. Reason: {reason}. "
2384
+ "Check troubleshooting guide for common errors: {troubleshooting}"
2385
+ ).format(
2386
+ job_type=job_type,
2387
+ job_name=job,
2388
+ status=status,
2389
+ reason=reason,
2390
+ troubleshooting=troubleshooting,
2391
+ )
2392
+ if "CapacityError" in str(reason):
2393
+ raise exceptions.CapacityError(
2394
+ message=message,
2395
+ allowed_statuses=["Completed", "Stopped"],
2396
+ actual_status=status,
2397
+ )
2398
+ raise exceptions.UnexpectedStatusException(
2399
+ message=message,
2400
+ allowed_statuses=["Completed", "Stopped"],
2401
+ actual_status=status,
2402
+ )
2403
+
2404
+
2405
+ def _logs_init(boto_session, description, job):
2406
+ """Placeholder docstring"""
2407
+ if job == "Training":
2408
+ if "InstanceGroups" in description["ResourceConfig"]:
2409
+ instance_count = 0
2410
+ for instanceGroup in description["ResourceConfig"]["InstanceGroups"]:
2411
+ instance_count += instanceGroup["InstanceCount"]
2412
+ else:
2413
+ instance_count = description["ResourceConfig"]["InstanceCount"]
2414
+ elif job == "Transform":
2415
+ instance_count = description["TransformResources"]["InstanceCount"]
2416
+ elif job == "Processing":
2417
+ instance_count = description["ProcessingResources"]["ClusterConfig"]["InstanceCount"]
2418
+ elif job == "AutoML":
2419
+ instance_count = 0
2420
+
2421
+ stream_names = [] # The list of log streams
2422
+ positions = {} # The current position in each stream, map of stream name -> position
2423
+
2424
+ # Increase retries allowed (from default of 4), as we don't want waiting for a training job
2425
+ # to be interrupted by a transient exception.
2426
+ config = botocore.config.Config(retries={"max_attempts": 15})
2427
+ client = boto_session.client("logs", config=config)
2428
+ log_group = "/aws/sagemaker/" + job + "Jobs"
2429
+
2430
+ dot = False
2431
+
2432
+ color_wrap = sagemaker.core.logs.ColorWrap()
2433
+
2434
+ return instance_count, stream_names, positions, client, log_group, dot, color_wrap
2435
+
2436
+
2437
+ def _flush_log_streams(
2438
+ stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap
2439
+ ):
2440
+ """Placeholder docstring"""
2441
+ if len(stream_names) < instance_count:
2442
+ # Log streams are created whenever a container starts writing to stdout/err, so this list
2443
+ # may be dynamic until we have a stream for every instance.
2444
+ try:
2445
+ streams = client.describe_log_streams(
2446
+ logGroupName=log_group,
2447
+ logStreamNamePrefix=job_name + "/",
2448
+ orderBy="LogStreamName",
2449
+ limit=min(instance_count, 50),
2450
+ )
2451
+ stream_names = [s["logStreamName"] for s in streams["logStreams"]]
2452
+
2453
+ while "nextToken" in streams:
2454
+ streams = client.describe_log_streams(
2455
+ logGroupName=log_group,
2456
+ logStreamNamePrefix=job_name + "/",
2457
+ orderBy="LogStreamName",
2458
+ limit=50,
2459
+ )
2460
+
2461
+ stream_names.extend([s["logStreamName"] for s in streams["logStreams"]])
2462
+
2463
+ positions.update(
2464
+ [
2465
+ (s, sagemaker.core.logs.Position(timestamp=0, skip=0))
2466
+ for s in stream_names
2467
+ if s not in positions
2468
+ ]
2469
+ )
2470
+ except ClientError as e:
2471
+ # On the very first training job run on an account, there's no log group until
2472
+ # the container starts logging, so ignore any errors thrown about that
2473
+ err = e.response.get("Error", {})
2474
+ if err.get("Code", None) != "ResourceNotFoundException":
2475
+ raise
2476
+
2477
+ if len(stream_names) > 0:
2478
+ if dot:
2479
+ print("")
2480
+ dot = False
2481
+ for idx, event in sagemaker.core.logs.multi_stream_iter(
2482
+ client, log_group, stream_names, positions
2483
+ ):
2484
+ color_wrap(idx, event["message"])
2485
+ ts, count = positions[stream_names[idx]]
2486
+ if event["timestamp"] == ts:
2487
+ positions[stream_names[idx]] = sagemaker.core.logs.Position(
2488
+ timestamp=ts, skip=count + 1
2489
+ )
2490
+ else:
2491
+ positions[stream_names[idx]] = sagemaker.core.logs.Position(
2492
+ timestamp=event["timestamp"], skip=1
2493
+ )
2494
+ else:
2495
+ dot = True
2496
+ print(".", end="")
2497
+ sys.stdout.flush()
2498
+
2499
+
2500
+ def _wait_until(callable_fn, poll=5):
2501
+ """Placeholder docstring"""
2502
+ elapsed_time = 0
2503
+ result = None
2504
+ while result is None:
2505
+ try:
2506
+ elapsed_time += poll
2507
+ time.sleep(poll)
2508
+ result = callable_fn()
2509
+ except botocore.exceptions.ClientError as err:
2510
+ # For initial 5 mins we accept/pass AccessDeniedException.
2511
+ # The reason is to await tag propagation to avoid false AccessDenied claims for an
2512
+ # access policy based on resource tags, The caveat here is for true AccessDenied
2513
+ # cases the routine will fail after 5 mins
2514
+ if err.response["Error"]["Code"] == "AccessDeniedException" and elapsed_time <= 300:
2515
+ logger.warning(
2516
+ "Received AccessDeniedException. This could mean the IAM role does not "
2517
+ "have the resource permissions, in which case please add resource access "
2518
+ "and retry. For cases where the role has tag based resource policy, "
2519
+ "continuing to wait for tag propagation.."
2520
+ )
2521
+ continue
2522
+ raise err
2523
+ return result
2524
+
2525
+
2526
+ def _get_initial_job_state(description, status_key, wait):
2527
+ """Placeholder docstring"""
2528
+ status = description[status_key]
2529
+ job_already_completed = status in ("Completed", "Failed", "Stopped")
2530
+ return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
2531
+
2532
+
2533
+ def _rule_statuses_changed(current_statuses, last_statuses):
2534
+ """Checks the rule evaluation statuses for SageMaker Debugger and Profiler rules."""
2535
+ if not last_statuses:
2536
+ return True
2537
+
2538
+ for current, last in zip(current_statuses, last_statuses):
2539
+ if (current["RuleConfigurationName"] == last["RuleConfigurationName"]) and (
2540
+ current["RuleEvaluationStatus"] != last["RuleEvaluationStatus"]
2541
+ ):
2542
+ return True
2543
+
2544
+ return False
2545
+
2546
+
2547
+ def update_args(args: Dict[str, Any], **kwargs):
2548
+ """Updates the request arguments dict with the value if populated.
2549
+
2550
+ This is to handle the case that the service API doesn't like NoneTypes for argument values.
2551
+
2552
+ Args:
2553
+ request_args (Dict[str, Any]): the request arguments dict
2554
+ kwargs: key, value pairs to update the args dict
2555
+ """
2556
+ for key, value in kwargs.items():
2557
+ if value is not None:
2558
+ args.update({key: value})
2559
+
2560
+
2561
+ def production_variant(
2562
+ model_name=None,
2563
+ instance_type=None,
2564
+ initial_instance_count=None,
2565
+ variant_name="AllTraffic",
2566
+ initial_weight=1,
2567
+ accelerator_type=None,
2568
+ serverless_inference_config=None,
2569
+ volume_size=None,
2570
+ model_data_download_timeout=None,
2571
+ container_startup_health_check_timeout=None,
2572
+ managed_instance_scaling=None,
2573
+ routing_config=None,
2574
+ inference_ami_version=None,
2575
+ ):
2576
+ """Create a production variant description suitable for use in a ``ProductionVariant`` list.
2577
+
2578
+ This is also part of a ``CreateEndpointConfig`` request.
2579
+
2580
+ Args:
2581
+ model_name (str): The name of the SageMaker model this production variant references.
2582
+ instance_type (str): The EC2 instance type for this production variant. For example,
2583
+ 'ml.c4.8xlarge'.
2584
+ initial_instance_count (int): The initial instance count for this production variant
2585
+ (default: 1).
2586
+ variant_name (string): The ``VariantName`` of this production variant
2587
+ (default: 'AllTraffic').
2588
+ initial_weight (int): The relative ``InitialVariantWeight`` of this production variant
2589
+ (default: 1).
2590
+ accelerator_type (str): Type of Elastic Inference accelerator for this production variant.
2591
+ For example, 'ml.eia1.medium'.
2592
+ For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
2593
+ serverless_inference_config (dict): Specifies configuration dict related to serverless
2594
+ endpoint. The dict is converted from sagemaker.model_monitor.ServerlessInferenceConfig
2595
+ object (default: None)
2596
+ volume_size (int): The size, in GB, of the ML storage volume attached to individual
2597
+ inference instance associated with the production variant. Currenly only Amazon EBS
2598
+ gp2 storage volumes are supported.
2599
+ model_data_download_timeout (int): The timeout value, in seconds, to download and extract
2600
+ model data from Amazon S3 to the individual inference instance associated with this
2601
+ production variant.
2602
+ container_startup_health_check_timeout (int): The timeout value, in seconds, for your
2603
+ inference container to pass health check by SageMaker Hosting. For more information
2604
+ about health check see:
2605
+ https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html
2606
+ #your-algorithms
2607
+ -inference-algo-ping-requests
2608
+
2609
+ Returns:
2610
+ dict[str, str]: An SageMaker ``ProductionVariant`` description
2611
+ """
2612
+ production_variant_configuration = {
2613
+ "VariantName": variant_name,
2614
+ }
2615
+ if model_name:
2616
+ production_variant_configuration["ModelName"] = model_name
2617
+ production_variant_configuration["InitialVariantWeight"] = initial_weight
2618
+
2619
+ if accelerator_type:
2620
+ production_variant_configuration["AcceleratorType"] = accelerator_type
2621
+
2622
+ if managed_instance_scaling:
2623
+ production_variant_configuration["ManagedInstanceScaling"] = managed_instance_scaling
2624
+
2625
+ if serverless_inference_config:
2626
+ production_variant_configuration["ServerlessConfig"] = serverless_inference_config
2627
+ else:
2628
+ initial_instance_count = initial_instance_count or 1
2629
+ production_variant_configuration["InitialInstanceCount"] = initial_instance_count
2630
+ production_variant_configuration["InstanceType"] = instance_type
2631
+ update_args(
2632
+ production_variant_configuration,
2633
+ VolumeSizeInGB=volume_size,
2634
+ ModelDataDownloadTimeoutInSeconds=model_data_download_timeout,
2635
+ ContainerStartupHealthCheckTimeoutInSeconds=container_startup_health_check_timeout,
2636
+ RoutingConfig=routing_config,
2637
+ )
2638
+
2639
+ if inference_ami_version:
2640
+ production_variant_configuration["InferenceAmiVersion"] = inference_ami_version
2641
+
2642
+ return production_variant_configuration
2643
+
2644
+
2645
+ def _has_permission_for_live_logging(boto_session, endpoint_name) -> bool:
2646
+ """Validate if customer's role has the right permission to access logs from CloudWatch"""
2647
+ try:
2648
+ cloudwatch_client = boto_session.client("logs")
2649
+ cloudwatch_client.filter_log_events(
2650
+ logGroupName=f"/aws/sagemaker/Endpoints/{endpoint_name}",
2651
+ logStreamNamePrefix="AllTraffic/",
2652
+ )
2653
+ return True
2654
+ except ClientError as e:
2655
+ if e.response["Error"]["Code"] == "AccessDeniedException":
2656
+ LOGGER.warning(
2657
+ ("Failed to enable live logging: %s. Fallback to default logging..."),
2658
+ e,
2659
+ )
2660
+
2661
+ return False
2662
+ return True
2663
+
2664
+
2665
+ def _deploy_done(sagemaker_client, endpoint_name):
2666
+ """Placeholder docstring"""
2667
+ hosting_status_codes = {
2668
+ "OutOfService": "x",
2669
+ "Creating": "-",
2670
+ "Updating": "-",
2671
+ "InService": "!",
2672
+ "RollingBack": "<",
2673
+ "Deleting": "o",
2674
+ "Failed": "*",
2675
+ }
2676
+ in_progress_statuses = ["Creating", "Updating"]
2677
+
2678
+ desc = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
2679
+ status = desc["EndpointStatus"]
2680
+
2681
+ print(hosting_status_codes.get(status, "?"), end="")
2682
+ sys.stdout.flush()
2683
+
2684
+ return None if status in in_progress_statuses else desc
2685
+
2686
+
2687
def _live_logging_deploy_done(sagemaker_client, endpoint_name, paginator, paginator_config, poll):
    """Poll an endpoint once and stream any new CloudWatch log events for it.

    Args:
        sagemaker_client: boto3 SageMaker client.
        endpoint_name (str): Name of the endpoint being deployed.
        paginator: A CloudWatch Logs paginator whose ``paginate`` accepts
            ``logGroupName``/``logStreamNamePrefix``/``PaginationConfig``
            (presumably ``filter_log_events`` — confirm at the call site).
        paginator_config (dict): Pagination config. It is updated in place
            with ``StartingToken`` so the next call resumes after the events
            printed here.
        poll (int | float): Seconds to sleep before the final log flush when
            the endpoint left ``Creating`` without reaching ``InService``.

    Returns:
        dict or None: The ``describe_endpoint`` response once the endpoint is
        no longer ``Creating``. None when polling should continue, including
        when the endpoint or its log group is not visible yet.

    Raises:
        botocore.exceptions.ClientError: For any error other than the
            expected "not visible yet" cases above.
    """
    stop = False
    try:
        desc = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
        endpoint_status = desc["EndpointStatus"]
    except ClientError as e:
        if e.response["Error"]["Code"] == "ValidationException":
            LOGGER.debug("Waiting for endpoint to become visible")
            return None
        raise

    try:
        # Any status other than "Creating" ends polling. Non-InService states
        # sleep first so the final batch of logs can be flushed below.
        if endpoint_status != "Creating":
            stop = True
            if endpoint_status == "InService":
                LOGGER.info("Created endpoint with name %s. Waiting for it to be InService", endpoint_name)
            else:
                time.sleep(poll)

        pages = paginator.paginate(
            logGroupName=f"/aws/sagemaker/Endpoints/{endpoint_name}",
            logStreamNamePrefix="AllTraffic/",
            PaginationConfig=paginator_config,
        )

        saw_events = False
        for page in pages:
            if "nextToken" in page:
                # Remember where we stopped so the next poll does not re-print events.
                paginator_config["StartingToken"] = page["nextToken"]
            for event in page["events"]:
                saw_events = True
                LOGGER.info(event["message"])
        # A for/else without `break` would log this unconditionally. Only
        # report it when nothing was actually printed.
        if not saw_events:
            LOGGER.debug("No log events available")
    except ClientError as e:
        if e.response["Error"]["Code"] == "ResourceNotFoundException":
            LOGGER.debug("Waiting for endpoint log group to appear")
            return None
        raise

    # Returning the description tells the caller to stop polling.
    return desc if stop else None
2733
+
2734
+
2735
+ def _deployment_entity_exists(describe_fn):
2736
+ """Placeholder docstring"""
2737
+ try:
2738
+ describe_fn()
2739
+ return True
2740
+ except ClientError as ce:
2741
+ error_code = ce.response["Error"]["Code"]
2742
+ if not (
2743
+ error_code == "ValidationException"
2744
+ and "Could not find" in ce.response["Error"]["Message"]
2745
+ ):
2746
+ raise ce
2747
+ return False
2748
+
2749
+
2750
def get_log_events_for_inference_recommender(cw_client, log_group_name, log_stream_name):
    """Retrieves log events from the specified CloudWatch log group and log stream.

    The stream may not exist yet when the job has just started, so
    ``ResourceNotFoundException`` is retried via ``retries`` (up to 30
    attempts, 10 seconds apart).

    Args:
        cw_client (boto3.client): A boto3 CloudWatch client.
        log_group_name (str): The name of the CloudWatch log group.
        log_stream_name (str): The name of the CloudWatch log stream.

    Returns:
        (dict): A dictionary containing log events from CloudWatch log group and log stream.

    Raises:
        botocore.exceptions.ClientError: Immediately, for any error other than
            ``ResourceNotFoundException``. Waiting cannot fix such errors
            (e.g. ``AccessDeniedException``).
    """
    print("Fetching logs from CloudWatch...", flush=True)
    for _ in retries(
        max_retry_count=30,  # 30*10 = 5min
        exception_message_prefix="Waiting for cloudwatch stream to appear. ",
        seconds_to_sleep=10,
    ):
        try:
            return cw_client.get_log_events(
                logGroupName=log_group_name, logStreamName=log_stream_name
            )
        except ClientError as e:
            # Only "stream not there yet" is worth waiting for. Previously
            # every ClientError was swallowed and retried, so the real cause
            # was only reported (as a misleading timeout) after 5 minutes.
            if e.response["Error"]["Code"] != "ResourceNotFoundException":
                raise
2774
+
2775
+
2776
+ def _describe_inference_recommendations_job_status(sagemaker_client, job_name: str):
2777
+ """Describes the status of a job and returns the job description.
2778
+
2779
+ Args:
2780
+ sagemaker_client (boto3.client.sagemaker): A SageMaker client.
2781
+ job_name (str): The name of the job.
2782
+
2783
+ Returns:
2784
+ dict: The job description, or None if the job is still in progress.
2785
+ """
2786
+ inference_recommendations_job_status_codes = {
2787
+ "PENDING": ".",
2788
+ "IN_PROGRESS": ".",
2789
+ "COMPLETED": "!",
2790
+ "FAILED": "*",
2791
+ "STOPPING": "_",
2792
+ "STOPPED": "s",
2793
+ }
2794
+ in_progress_statuses = {"PENDING", "IN_PROGRESS", "STOPPING"}
2795
+
2796
+ desc = sagemaker_client.describe_inference_recommendations_job(JobName=job_name)
2797
+ status = desc["Status"]
2798
+
2799
+ print(inference_recommendations_job_status_codes.get(status, "?"), end="", flush=True)
2800
+
2801
+ if status in in_progress_statuses:
2802
+ return None
2803
+
2804
+ print("")
2805
+ return desc
2806
+
2807
+
2808
def _display_inference_recommendations_job_steps_status(
    sagemaker_session, sagemaker_client, job_name: str, poll: int = 60
):
    """Print the execution log of an inference recommendations job until it ends.

    Reads the CloudWatch stream ``<job_name>/execution`` in the log group
    ``/aws/sagemaker/InferenceRecommendationsJobs``, printing each event's
    message. While the job is running (``PENDING``, ``IN_PROGRESS``,
    ``STOPPING``) it sleeps ``poll`` seconds between reads. After the job
    stops, it allows one extra empty read before exiting so that late events
    can still be printed.

    Args:
        sagemaker_session: Session whose ``boto_session`` builds the logs client.
        sagemaker_client: boto3 SageMaker client used to check the job status.
        job_name (str): Name of the inference recommendations job.
        poll (int): Seconds to sleep between polls (default: 60).
    """
    logs_client = sagemaker_session.boto_session.client("logs")
    running_statuses = {"PENDING", "IN_PROGRESS", "STOPPING"}
    group = "/aws/sagemaker/InferenceRecommendationsJobs"
    stream = job_name + "/execution"

    first_batch = get_log_events_for_inference_recommender(logs_client, group, stream)
    print(f"Retrieved logStream: {stream} from logGroup: {group}", flush=True)
    first_events = first_batch["events"]
    print(*[evt["message"] for evt in first_events], sep="\n", flush=True)

    # Only resume from a token if the first read actually returned something.
    token = first_batch["nextForwardToken"] if first_events else None
    grace_read_pending = True
    while True:
        request = {"logGroupName": group, "logStreamName": stream}
        if token:
            request["nextToken"] = token
        batch = logs_client.get_log_events(**request)
        batch_events = batch["events"]

        status = sagemaker_client.describe_inference_recommendations_job(JobName=job_name)[
            "Status"
        ]
        job_running = status in running_statuses

        if not batch_events and (job_running or grace_read_pending):
            if not job_running:
                # Job finished with nothing new yet: spend the single grace read.
                grace_read_pending = False
            time.sleep(poll)
            continue

        token = batch["nextForwardToken"]
        print(*[evt["message"] for evt in batch_events], sep="\n", flush=True)

        if not job_running:
            break
        time.sleep(poll)
2860
+
2861
+
2862
+ def _optimization_job_status(sagemaker_client, job_name):
2863
+ """Placeholder docstring"""
2864
+ optimization_job_status_codes = {
2865
+ "INPROGRESS": ".",
2866
+ "COMPLETED": "!",
2867
+ "FAILED": "*",
2868
+ "STARTING": ".",
2869
+ "STOPPING": "_",
2870
+ "STOPPED": "s",
2871
+ }
2872
+ in_progress_statuses = ["INPROGRESS", "STARTING", "STOPPING"]
2873
+
2874
+ desc = sagemaker_client.describe_optimization_job(OptimizationJobName=job_name)
2875
+ status = desc["OptimizationJobStatus"]
2876
+
2877
+ print(optimization_job_status_codes.get(status, "?"), end="")
2878
+ sys.stdout.flush()
2879
+
2880
+ if status in in_progress_statuses:
2881
+ return None
2882
+
2883
+ print("")
2884
+ return desc
2885
+
2886
+
2887
def container_def(
    image_uri,
    model_data_url=None,
    env=None,
    container_mode=None,
    image_config=None,
    accept_eula=None,
    additional_model_data_sources=None,
    model_reference_arn=None,
):
    """Create a definition for executing a container as part of a SageMaker model.

    Args:
        image_uri (str): Docker image URI to run for this container.
        model_data_url (str or dict[str, Any]): S3 location of model data required by this
            container, e.g. SageMaker training job model artifacts. It can either be a string
            representing S3 URI of model data, or a dictionary representing a
            ``ModelDataSource`` object. A dictionary passed here is never modified
            (default: None).
        env (dict[str, str]): Environment variables to set inside the container (default: None).
        container_mode (str): The model container mode. Valid modes:
                * MultiModel: Indicates that model container can support hosting multiple models
                * SingleModel: Indicates that model container can support hosting a single model
                This is the default model container mode when container_mode = None
        image_config (dict[str, str]): Specifies whether the image of model container is pulled
            from ECR, or private registry in your VPC. By default it is set to pull model
            container image from ECR. (default: None).
        accept_eula (bool): For models that require a Model Access Config, specify True or
            False to indicate whether model terms of use have been accepted.
            The `accept_eula` value must be explicitly defined as `True` in order to
            accept the end-user license agreement (EULA) that some
            models require. (Default: None).
        additional_model_data_sources (PipelineVariable or dict): Additional location
            of SageMaker model data (default: None).
        model_reference_arn (str): Hub content ARN recorded as
            ``S3DataSource.HubAccessConfig.HubContentArn`` when the model data is
            emitted as a ``ModelDataSource`` (default: None).

    Returns:
        dict[str, str]: A complete container definition object usable with the CreateModel API if
        passed via the `PrimaryContainer` field.

    Raises:
        KeyError: If ``model_data_url`` is a dict without an ``S3DataSource`` key while
            ``accept_eula`` or ``model_reference_arn`` is given.
    """
    if env is None:
        env = {}
    c_def = {"Image": image_uri, "Environment": env}

    if additional_model_data_sources:
        c_def["AdditionalModelDataSources"] = additional_model_data_sources

    if isinstance(model_data_url, str) and (
        not (model_data_url.startswith("s3://") and model_data_url.endswith("tar.gz"))
        or accept_eula is None
    ):
        # Strings that are not S3 tarballs, or any string without an EULA
        # decision, use the plain ``ModelDataUrl`` field.
        c_def["ModelDataUrl"] = model_data_url

    elif isinstance(model_data_url, (dict, str)):
        if isinstance(model_data_url, dict):
            # Copy the levels we may write into so the caller's dict is not
            # mutated. Otherwise access configs added here would leak into
            # later calls that reuse the same dict.
            model_data_source = dict(model_data_url)
            if accept_eula is not None or model_reference_arn:
                model_data_source["S3DataSource"] = dict(model_data_source["S3DataSource"])
        else:
            model_data_source = {
                "S3DataSource": {
                    "S3Uri": model_data_url,
                    "S3DataType": "S3Object",
                    "CompressionType": "Gzip",
                }
            }
        if accept_eula is not None:
            model_data_source["S3DataSource"]["ModelAccessConfig"] = {
                "AcceptEula": accept_eula
            }
        if model_reference_arn:
            model_data_source["S3DataSource"]["HubAccessConfig"] = {
                "HubContentArn": model_reference_arn
            }
        c_def["ModelDataSource"] = model_data_source

    elif model_data_url is not None:
        # e.g. a PipelineVariable: pass through unchanged.
        c_def["ModelDataUrl"] = model_data_url

    if container_mode:
        c_def["Mode"] = container_mode
    if image_config:
        c_def["ImageConfig"] = image_config
    return c_def