sagemaker-core 1.0.62__py3-none-any.whl → 2.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (358)
  1. sagemaker/__init__.py +2 -0
  2. sagemaker/core/__init__.py +16 -0
  3. sagemaker/core/_studio.py +116 -0
  4. sagemaker/core/_version.py +11 -0
  5. sagemaker/core/accept_types.py +131 -0
  6. sagemaker/core/analytics.py +744 -0
  7. sagemaker/core/apiutils/__init__.py +13 -0
  8. sagemaker/core/apiutils/_base_types.py +228 -0
  9. sagemaker/core/apiutils/_boto_functions.py +130 -0
  10. sagemaker/core/apiutils/_utils.py +34 -0
  11. sagemaker/core/base_deserializers.py +35 -0
  12. sagemaker/core/base_serializers.py +35 -0
  13. sagemaker/core/clarify/__init__.py +2898 -0
  14. sagemaker/core/collection.py +467 -0
  15. sagemaker/core/common_utils.py +2399 -0
  16. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  17. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  18. sagemaker/core/config/__init__.py +181 -0
  19. sagemaker/core/config/config.py +238 -0
  20. sagemaker/core/config/config_manager.py +595 -0
  21. sagemaker/core/config/config_schema.py +1220 -0
  22. sagemaker/core/config/config_utils.py +297 -0
  23. {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
  24. sagemaker/core/constants.py +73 -0
  25. sagemaker/core/content_types.py +137 -0
  26. sagemaker/core/debugger/__init__.py +39 -0
  27. sagemaker/core/debugger/debugger.py +945 -0
  28. sagemaker/core/debugger/framework_profile.py +292 -0
  29. sagemaker/core/debugger/metrics_config.py +468 -0
  30. sagemaker/core/debugger/profiler.py +42 -0
  31. sagemaker/core/debugger/profiler_config.py +190 -0
  32. sagemaker/core/debugger/profiler_constants.py +40 -0
  33. sagemaker/core/debugger/utils.py +148 -0
  34. sagemaker/core/deprecations.py +254 -0
  35. sagemaker/core/deserializers/__init__.py +10 -0
  36. sagemaker/core/deserializers/base.py +424 -0
  37. sagemaker/core/deserializers/implementations.py +157 -0
  38. sagemaker/core/drift_check_baselines.py +106 -0
  39. sagemaker/core/enums.py +51 -0
  40. sagemaker/core/environment_variables.py +101 -0
  41. sagemaker/core/exceptions.py +108 -0
  42. sagemaker/core/experiments/__init__.py +53 -0
  43. sagemaker/core/experiments/_api_types.py +251 -0
  44. sagemaker/core/experiments/_environment.py +124 -0
  45. sagemaker/core/experiments/_helper.py +294 -0
  46. sagemaker/core/experiments/_metrics.py +333 -0
  47. sagemaker/core/experiments/_run_context.py +58 -0
  48. sagemaker/core/experiments/_utils.py +216 -0
  49. sagemaker/core/experiments/experiment.py +247 -0
  50. sagemaker/core/experiments/run.py +970 -0
  51. sagemaker/core/experiments/trial.py +296 -0
  52. sagemaker/core/experiments/trial_component.py +387 -0
  53. sagemaker/core/explainer/__init__.py +24 -0
  54. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  55. sagemaker/core/explainer/explainer_config.py +44 -0
  56. sagemaker/core/fw_utils.py +1220 -0
  57. sagemaker/core/git_utils.py +415 -0
  58. sagemaker/core/helper/pipeline_variable.py +82 -0
  59. sagemaker/core/helper/session_helper.py +2977 -0
  60. sagemaker/core/hyperparameters.py +172 -0
  61. sagemaker/core/image_retriever/__init__.py +3 -0
  62. sagemaker/core/image_retriever/image_retriever.py +640 -0
  63. sagemaker/core/image_retriever/image_retriever_utils.py +509 -0
  64. sagemaker/core/image_retriever/test.py +7 -0
  65. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  66. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  67. sagemaker/core/image_uri_config/chainer.json +104 -0
  68. sagemaker/core/image_uri_config/clarify.json +39 -0
  69. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  70. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  71. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  72. sagemaker/core/image_uri_config/debugger.json +34 -0
  73. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  74. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  75. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  76. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  77. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  78. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  79. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  80. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  81. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +770 -0
  82. sagemaker/core/image_uri_config/huggingface-llm.json +1267 -0
  83. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  84. sagemaker/core/image_uri_config/huggingface-neuronx.json +686 -0
  85. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  86. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  87. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  88. sagemaker/core/image_uri_config/huggingface-vllm-neuronx.json +38 -0
  89. sagemaker/core/image_uri_config/huggingface.json +2287 -0
  90. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  91. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  92. sagemaker/core/image_uri_config/image-classification.json +50 -0
  93. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  94. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  95. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  96. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  97. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  98. sagemaker/core/image_uri_config/kmeans.json +50 -0
  99. sagemaker/core/image_uri_config/knn.json +50 -0
  100. sagemaker/core/image_uri_config/lda.json +26 -0
  101. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  102. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  103. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  104. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  105. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  106. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  107. sagemaker/core/image_uri_config/ntm.json +50 -0
  108. sagemaker/core/image_uri_config/object-detection.json +50 -0
  109. sagemaker/core/image_uri_config/object2vec.json +50 -0
  110. sagemaker/core/image_uri_config/pca.json +50 -0
  111. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  112. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  113. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  114. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  115. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  116. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  117. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  118. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  119. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  120. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  121. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +252 -0
  122. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  123. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  124. sagemaker/core/image_uri_config/sklearn.json +494 -0
  125. sagemaker/core/image_uri_config/spark.json +280 -0
  126. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  127. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  128. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  129. sagemaker/core/image_uri_config/vw.json +25 -0
  130. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  131. sagemaker/core/image_uri_config/xgboost.json +972 -0
  132. sagemaker/core/image_uris.py +816 -0
  133. sagemaker/core/inference_config.py +144 -0
  134. sagemaker/core/inference_recommender/__init__.py +18 -0
  135. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  136. sagemaker/core/inputs.py +366 -0
  137. sagemaker/core/instance_group.py +61 -0
  138. sagemaker/core/instance_types.py +164 -0
  139. sagemaker/core/instance_types_gpu_info.py +43 -0
  140. sagemaker/core/interactive_apps/__init__.py +41 -0
  141. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  142. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  143. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  144. sagemaker/core/iterators.py +197 -0
  145. sagemaker/core/job.py +380 -0
  146. sagemaker/core/jumpstart/__init__.py +156 -0
  147. sagemaker/core/jumpstart/accessors.py +390 -0
  148. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  149. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  150. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  151. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  152. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  153. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  154. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  155. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  156. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  157. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  158. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  159. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  160. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  161. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  162. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  163. sagemaker/core/jumpstart/cache.py +663 -0
  164. sagemaker/core/jumpstart/configs.py +50 -0
  165. sagemaker/core/jumpstart/constants.py +198 -0
  166. sagemaker/core/jumpstart/deserializers.py +81 -0
  167. sagemaker/core/jumpstart/document.py +76 -0
  168. sagemaker/core/jumpstart/enums.py +168 -0
  169. sagemaker/core/jumpstart/exceptions.py +236 -0
  170. sagemaker/core/jumpstart/factory/utils.py +833 -0
  171. sagemaker/core/jumpstart/filters.py +597 -0
  172. sagemaker/core/jumpstart/hub/constants.py +16 -0
  173. sagemaker/core/jumpstart/hub/hub.py +291 -0
  174. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  175. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  176. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  177. sagemaker/core/jumpstart/hub/types.py +35 -0
  178. sagemaker/core/jumpstart/hub/utils.py +260 -0
  179. sagemaker/core/jumpstart/models.py +501 -0
  180. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  181. sagemaker/core/jumpstart/parameters.py +20 -0
  182. sagemaker/core/jumpstart/payload_utils.py +239 -0
  183. sagemaker/core/jumpstart/region_config.json +171 -0
  184. sagemaker/core/jumpstart/search.py +171 -0
  185. sagemaker/core/jumpstart/serializers.py +81 -0
  186. sagemaker/core/jumpstart/session_utils.py +234 -0
  187. sagemaker/core/jumpstart/types.py +3044 -0
  188. sagemaker/core/jumpstart/utils.py +1731 -0
  189. sagemaker/core/jumpstart/validators.py +257 -0
  190. sagemaker/core/lambda_helper.py +312 -0
  191. sagemaker/core/lineage/__init__.py +42 -0
  192. sagemaker/core/lineage/_api_types.py +239 -0
  193. sagemaker/core/lineage/_utils.py +49 -0
  194. sagemaker/core/lineage/action.py +345 -0
  195. sagemaker/core/lineage/artifact.py +646 -0
  196. sagemaker/core/lineage/association.py +190 -0
  197. sagemaker/core/lineage/context.py +505 -0
  198. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  199. sagemaker/core/lineage/query.py +732 -0
  200. sagemaker/core/lineage/visualizer.py +346 -0
  201. sagemaker/core/local/__init__.py +18 -0
  202. sagemaker/core/local/data.py +423 -0
  203. sagemaker/core/local/entities.py +678 -0
  204. sagemaker/core/local/exceptions.py +17 -0
  205. sagemaker/core/local/image.py +1243 -0
  206. sagemaker/core/local/local_session.py +739 -0
  207. sagemaker/core/local/utils.py +246 -0
  208. sagemaker/core/logs.py +181 -0
  209. sagemaker/core/metadata_properties.py +56 -0
  210. sagemaker/core/metric_definitions.py +91 -0
  211. sagemaker/core/mlflow/__init__.py +38 -0
  212. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  213. sagemaker/core/model_card/__init__.py +26 -0
  214. sagemaker/core/model_life_cycle.py +51 -0
  215. sagemaker/core/model_metrics.py +160 -0
  216. sagemaker/core/model_monitor/__init__.py +66 -0
  217. sagemaker/core/model_monitor/clarify_model_monitoring.py +1497 -0
  218. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  219. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  220. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  221. sagemaker/core/model_monitor/dataset_format.py +102 -0
  222. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  223. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  224. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  225. sagemaker/core/model_monitor/utils.py +793 -0
  226. sagemaker/core/model_registry.py +480 -0
  227. sagemaker/core/model_uris.py +97 -0
  228. sagemaker/core/modules/__init__.py +19 -0
  229. sagemaker/core/modules/configs.py +239 -0
  230. sagemaker/core/modules/constants.py +37 -0
  231. sagemaker/core/modules/distributed.py +182 -0
  232. sagemaker/core/modules/local_core/local_container.py +605 -0
  233. sagemaker/core/modules/templates.py +83 -0
  234. sagemaker/core/modules/train/__init__.py +14 -0
  235. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  236. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  237. sagemaker/core/modules/train/container_drivers/common/utils.py +205 -0
  238. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  239. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  240. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  241. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  242. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  243. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  244. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  245. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  246. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  247. sagemaker/core/modules/types.py +19 -0
  248. sagemaker/core/modules/utils.py +194 -0
  249. sagemaker/core/network.py +185 -0
  250. sagemaker/core/parameter.py +173 -0
  251. sagemaker/core/payloads.py +185 -0
  252. sagemaker/core/processing.py +1599 -0
  253. sagemaker/core/remote_function/__init__.py +19 -0
  254. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  255. sagemaker/core/remote_function/client.py +1310 -0
  256. sagemaker/core/remote_function/core/__init__.py +0 -0
  257. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  258. sagemaker/core/remote_function/core/pipeline_variables.py +347 -0
  259. sagemaker/core/remote_function/core/serialization.py +410 -0
  260. sagemaker/core/remote_function/core/stored_function.py +223 -0
  261. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  262. sagemaker/core/remote_function/errors.py +102 -0
  263. sagemaker/core/remote_function/invoke_function.py +167 -0
  264. sagemaker/core/remote_function/job.py +2121 -0
  265. sagemaker/core/remote_function/logging_config.py +38 -0
  266. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  267. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  268. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  269. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  270. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  271. sagemaker/core/remote_function/spark_config.py +149 -0
  272. sagemaker/core/resource_requirements.py +168 -0
  273. {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
  274. sagemaker/core/s3/__init__.py +41 -0
  275. sagemaker/core/s3/client.py +367 -0
  276. sagemaker/core/s3/utils.py +175 -0
  277. sagemaker/core/script_uris.py +93 -0
  278. sagemaker/core/serializers/__init__.py +11 -0
  279. sagemaker/core/serializers/base.py +510 -0
  280. sagemaker/core/serializers/implementations.py +159 -0
  281. sagemaker/core/serializers/utils.py +223 -0
  282. sagemaker/core/serverless_inference_config.py +63 -0
  283. sagemaker/core/session_settings.py +55 -0
  284. sagemaker/core/shapes/__init__.py +3 -0
  285. sagemaker/core/shapes/model_card_shapes.py +159 -0
  286. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
  287. sagemaker/core/spark/__init__.py +16 -0
  288. sagemaker/core/spark/defaults.py +16 -0
  289. sagemaker/core/spark/processing.py +1380 -0
  290. sagemaker/core/telemetry/__init__.py +23 -0
  291. sagemaker/core/telemetry/constants.py +82 -0
  292. sagemaker/core/telemetry/telemetry_logging.py +285 -0
  293. sagemaker/core/tools/__init__.py +1 -0
  294. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  295. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  296. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  297. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  298. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  299. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  300. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  301. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  302. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  303. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  304. sagemaker/core/training/__init__.py +14 -0
  305. sagemaker/core/training/configs.py +345 -0
  306. sagemaker/core/training/constants.py +37 -0
  307. sagemaker/core/training/utils.py +77 -0
  308. sagemaker/core/training_compiler/__init__.py +16 -0
  309. sagemaker/core/training_compiler/config.py +197 -0
  310. sagemaker/core/training_compiler_config.py +197 -0
  311. sagemaker/core/transformer.py +793 -0
  312. sagemaker/core/user_agent.py +76 -0
  313. sagemaker/core/utilities/__init__.py +24 -0
  314. sagemaker/core/utilities/cache.py +169 -0
  315. sagemaker/core/utilities/search_expression.py +133 -0
  316. sagemaker/core/utils/__init__.py +48 -0
  317. sagemaker/core/utils/code_injection/__init__.py +0 -0
  318. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  319. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
  320. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  321. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  322. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  323. {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
  324. sagemaker/core/workflow/__init__.py +152 -0
  325. sagemaker/core/workflow/conditions.py +313 -0
  326. sagemaker/core/workflow/entities.py +58 -0
  327. sagemaker/core/workflow/execution_variables.py +89 -0
  328. sagemaker/core/workflow/functions.py +193 -0
  329. sagemaker/core/workflow/parameters.py +222 -0
  330. sagemaker/core/workflow/pipeline_context.py +394 -0
  331. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  332. sagemaker/core/workflow/properties.py +285 -0
  333. sagemaker/core/workflow/step_outputs.py +65 -0
  334. sagemaker/core/workflow/utilities.py +514 -0
  335. sagemaker/lineage/__init__.py +33 -0
  336. sagemaker/lineage/action.py +28 -0
  337. sagemaker/lineage/artifact.py +28 -0
  338. sagemaker/lineage/context.py +28 -0
  339. sagemaker/lineage/lineage_trial_component.py +28 -0
  340. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/METADATA +28 -9
  341. sagemaker_core-2.3.1.dist-info/RECORD +351 -0
  342. sagemaker_core-2.3.1.dist-info/top_level.txt +1 -0
  343. sagemaker_core/_version.py +0 -3
  344. sagemaker_core/helper/session_helper.py +0 -769
  345. sagemaker_core/resources/__init__.py +0 -1
  346. sagemaker_core/shapes/__init__.py +0 -1
  347. sagemaker_core/tools/__init__.py +0 -1
  348. sagemaker_core-1.0.62.dist-info/RECORD +0 -35
  349. sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
  350. {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
  351. {sagemaker_core/helper → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  352. {sagemaker_core/main → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
  353. {sagemaker_core/main/code_injection → sagemaker/core/modules/local_core}/__init__.py +0 -0
  354. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  355. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  356. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  357. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/WHEEL +0 -0
  358. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,2399 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Placeholder docstring"""
14
+ from __future__ import absolute_import
15
+
16
+ import sys
17
+ import contextlib
18
+ import copy
19
+ import errno
20
+ import inspect
21
+ import logging
22
+ import os
23
+ import random
24
+ import re
25
+ import shutil
26
+ import tarfile
27
+ import tempfile
28
+ import time
29
+ from functools import lru_cache
30
+ from typing import Union, Any, List, Optional, Dict
31
+ import json
32
+ import abc
33
+ import uuid
34
+ from datetime import datetime
35
+ from os.path import abspath, realpath, dirname, normpath, join as joinpath
36
+
37
+ from importlib import import_module
38
+
39
+ import boto3
40
+ import botocore
41
+ from botocore.utils import merge_dicts
42
+ from botocore import exceptions
43
+ from botocore.exceptions import ClientError
44
+ from six.moves.urllib import parse
45
+ from six import viewitems
46
+
47
+ import sagemaker
48
+
49
+ from sagemaker.core.enums import RoutingStrategy
50
+ from sagemaker.core.session_settings import SessionSettings
51
+ from sagemaker.core.workflow import is_pipeline_variable, is_pipeline_parameter_string
52
+ from sagemaker.core.helper.pipeline_variable import PipelineVariable
53
+ from enum import Enum
54
+
55
# Region name -> DNS domain suffix for regions listed here. Presumably consulted
# where the default "amazonaws.com" suffix does not apply — TODO confirm at call sites.
ALTERNATE_DOMAINS = {
    "cn-north-1": "amazonaws.com.cn",
    "cn-northwest-1": "amazonaws.com.cn",
    "us-iso-east-1": "c2s.ic.gov",
    "us-isob-east-1": "sc2s.sgov.gov",
    "us-isof-south-1": "csp.hci.ic.gov",
    "us-isof-east-1": "csp.hci.ic.gov",
    "eu-isoe-west-1": "cloud.adc-e.uk",
}

# ECR image URI of the shape "<digits>.dkr.ecr.<...>/<...>"; the final capture
# group requires a ":" after the "/" (e.g. a tag), so URIs without one do not match.
ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$"
# ARN patterns. Capture groups: suffix after "aws" (partition variant, may be empty),
# region, 12-digit account id, and the resource name. Neither pattern uses ^/$, so
# anchoring depends on whether the caller uses re.match or re.search.
MODEL_PACKAGE_ARN_PATTERN = (
    r"arn:aws([a-z\-]*)?:sagemaker:([a-z0-9\-]*):([0-9]{12}):model-package/(.*)"
)
MODEL_ARN_PATTERN = r"arn:aws([a-z\-]*):sagemaker:([a-z0-9\-]*):([0-9]{12}):model/(.*)"
MAX_BUCKET_PATHS_COUNT = 5
# URI scheme prefixes.
S3_PREFIX = "s3://"
HTTP_PREFIX = "http://"
HTTPS_PREFIX = "https://"
# Polling / pagination defaults; their consumers are outside this chunk.
DEFAULT_SLEEP_TIME_SECONDS = 10
WAITING_DOT_NUMBER = 10
MAX_ITEMS = 100
PAGE_SIZE = 10
_MAX_BUFFER_SIZE = 100 * 1024 * 1024  # 100 MB - Maximum buffer size for streaming iterators

# Filesystem locations treated as sensitive. The "~" entries are expanded against the
# current user's home directory once, at import time. How these are enforced is not
# visible here — presumably user-supplied paths are rejected if they fall under
# one of them (TODO confirm at call sites).
_SENSITIVE_SYSTEM_PATHS = [
    abspath(os.path.expanduser("~/.aws")),
    abspath(os.path.expanduser("~/.ssh")),
    abspath(os.path.expanduser("~/.kube")),
    abspath(os.path.expanduser("~/.docker")),
    abspath(os.path.expanduser("~/.config")),
    abspath(os.path.expanduser("~/.credentials")),
    "/etc",
    "/root",
    "/var/lib",
    "/opt/ml/metadata",
]

logger = logging.getLogger(__name__)

# Tag type aliases: a single tag mapping (values may be pipeline variables),
# or a list of such mappings.
TagsDict = Dict[str, Union[str, PipelineVariable]]
Tags = Union[List[TagsDict], TagsDict]
97
+
98
+
99
class ModelApprovalStatusEnum(str, Enum):
    """Model package approval status enumerator.

    Mixes in ``str`` so every member is also a string and compares equal to its
    value (e.g. ``ModelApprovalStatusEnum.APPROVED == "Approved"``).
    """

    APPROVED = "Approved"
    REJECTED = "Rejected"
    PENDING_MANUAL_APPROVAL = "PendingManualApproval"
105
+
106
+
107
+ # Use the base name of the image as the job name if the user doesn't give us one
108
def name_from_image(image, max_length=63):
    """Create a training job name based on the image name and a timestamp.

    The algorithm name is extracted from ``image`` with ``base_name_from_image``
    and then suffixed with a timestamp by ``name_from_base``.

    Args:
        image (str): Image name.
        max_length (int): Maximum length for the resulting string (default: 63).

    Returns:
        str: Training job name using the algorithm from the image name and a
            timestamp.
    """
    return name_from_base(base_name_from_image(image), max_length=max_length)
120
+
121
+
122
def name_from_base(base, max_length=63, short=False):
    """Return ``base`` followed by ``-`` and a timestamp.

    ``base`` is cut so the combined ``"<base>-<timestamp>"`` string stays within
    ``max_length`` characters, provided ``max_length`` leaves room for the
    timestamp and separator.

    Args:
        base (str): Prefix for the generated name.
        max_length (int): Upper bound on the result length (default: 63).
        short (bool): Use the compact ``yymmdd-HHMM`` timestamp instead of the
            millisecond-precision one (default: False).

    Returns:
        str: The (possibly truncated) base joined to the timestamp.
    """
    if short:
        stamp = sagemaker_short_timestamp()
    else:
        stamp = sagemaker_timestamp()
    # Characters left for the base once "-<stamp>" is accounted for.
    room_for_base = max_length - len(stamp) - 1
    prefix = base[:room_for_base]
    return "{}-{}".format(prefix, stamp)
140
+
141
+
142
def unique_name_from_base_uuid4(base, max_length=63):
    """Append a random UUID4 to the provided string.

    A UUID is used instead of a timestamp for uniqueness. ``base`` is truncated so
    that ``"<base>-<uuid>"`` is at most ``max_length`` characters; a UUID string is
    36 characters long, so this holds whenever ``max_length`` >= 37.

    Args:
        base (str): String used as prefix to generate the unique name.
        max_length (int): Maximum length for the resulting string (default: 63).

    Returns:
        str: ``base`` (possibly truncated), a ``-`` separator, and a UUID4 string.
    """
    # NOTE: deliberately no ``random.seed`` call — the UUID alone provides uniqueness,
    # and reseeding the module-level RNG would silently reset callers' random state.
    unique = str(uuid.uuid4())
    # Reserve room for the UUID plus the "-" separator.
    trimmed_base = base[: max_length - len(unique) - 1]
    return "{}-{}".format(trimmed_base, unique)
159
+
160
+
161
def unique_name_from_base(base, max_length=63):
    """Append a Unix timestamp and a 4-hex-digit random suffix to ``base``.

    Produces ``"<base>-<epoch seconds>-<xxxx>"``. ``base`` is truncated so the
    result fits within ``max_length`` characters, provided ``max_length`` leaves
    room for the timestamp, the suffix, and the two separators.

    Args:
        base (str): String used as prefix to generate the unique name.
        max_length (int): Maximum length for the resulting string (default: 63).

    Returns:
        str: The generated unique name.
    """
    # Take the random suffix from a fresh UUID4 rather than reseeding the
    # module-level ``random`` generator, which would reset callers' random state.
    unique = uuid.uuid4().hex[:4]  # 4 lowercase hex digits
    ts = str(int(time.time()))
    # Two "-" separators plus the timestamp and the random suffix.
    available_length = max_length - 2 - len(ts) - len(unique)
    trimmed = base[:available_length]
    return "{}-{}-{}".format(trimmed, ts, unique)
169
+
170
+
171
def base_name_from_image(image, default_base_name=None):
    """Derive the 'algorithm name' for a job from a container image reference.

    For a plain string, the last path component with any ``:tag`` removed is
    returned (the string itself if the pattern does not match). For a pipeline
    variable, the default value of a string pipeline parameter is used when one is
    set; otherwise ``default_base_name`` (or ``"base_name"``) is returned.

    Args:
        image (str): Image name.
        default_base_name (str): The default base name

    Returns:
        str: Algorithm name, as extracted from the image name.
    """
    image_str = image
    if is_pipeline_variable(image):
        if not (is_pipeline_parameter_string(image) and image.default_value):
            return default_base_name or "base_name"
        image_str = image.default_value

    match = re.match(r"^(.+/)?([^:/]+)(:[^:]+)?$", image_str)
    if match is None:
        return image_str
    return match.group(2)
192
+
193
+
194
def base_from_name(name):
    """Strip a generated timestamp suffix from a resource name.

    Recognises both the long ``YYYY-mm-dd-HH-MM-SS-fff`` and the short
    ``yymmdd-HHMM`` timestamp shapes appended by
    :func:`~sagemaker.utils.name_from_base`. Names without such a suffix are
    returned unchanged.

    Args:
        name (str): The resource name.

    Returns:
        str: The base name, as extracted from the resource name.
    """
    found = re.match(r"^(.+)-(\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}-\d{3}|\d{6}-\d{4})", name)
    if found is None:
        return name
    return found.group(1)
208
+
209
+
210
def sagemaker_timestamp():
    """Return the current UTC time as ``YYYY-mm-dd-HH-MM-SS-fff`` (millisecond precision).

    The millisecond field is always exactly three digits, so names built from this
    timestamp are recognised by :func:`base_from_name`.
    """
    moment = time.time()
    # Take the fractional digits from the float's repr (truncating, not rounding) and
    # right-pad with zeros: e.g. repr "1700000000.5" must yield "500" ms, not "5".
    fraction = repr(moment).partition(".")[2]
    moment_ms = (fraction + "000")[:3]
    return time.strftime("%Y-%m-%d-%H-%M-%S-{}".format(moment_ms), time.gmtime(moment))
215
+
216
+
217
def sagemaker_short_timestamp():
    """Return a compact ``yymmdd-HHMM`` timestamp for the current local time.

    Unlike ``sagemaker_timestamp``, this uses local time rather than UTC.
    """
    now_local = time.localtime()
    return time.strftime("%y%m%d-%H%M", now_local)
220
+
221
+
222
def build_dict(key, value):
    """Return ``{key: value}`` when ``value`` is truthy, otherwise an empty dict.

    Any falsy ``value`` (``None``, ``""``, ``0``, ``False``, an empty container)
    yields ``{}``, not only ``None``.

    Args:
        key (str): input key
        value (str): input value

    Returns:
        dict: dict of key and value or an empty dict.
    """
    return {key: value} if value else {}
235
+
236
+
237
def get_config_value(key_path, config):
    """Look up a dot-separated ``key_path`` (e.g. ``"a.b.c"``) in a nested ``config``.

    Args:
        key_path (str): Dot-separated sequence of keys to follow.
        config (dict): The (possibly nested) configuration mapping, or None.

    Returns:
        The value stored at ``key_path``. Returns None if any of the following holds:
        ``config`` is None, a key along the path is missing, or an intermediate
        value is not a mapping.
    """
    from collections.abc import Mapping

    if config is None:
        return None

    current_section = config
    for key in key_path.split("."):
        # Only descend into mappings. Otherwise ``key in "some string"`` would do a
        # substring test followed by an invalid ``str[key]``, and ``key in None``
        # would raise, instead of reporting "not found".
        if not isinstance(current_section, Mapping) or key not in current_section:
            return None
        current_section = current_section[key]

    return current_section
250
+
251
+
252
def get_nested_value(dictionary: dict, nested_keys: List[str]):
    """Returns a nested value from the given dictionary, and None if none present.

    Args:
        dictionary (dict): The dictionary to read from. Anything that is not a dict
            (including None) yields None.
        nested_keys (List[str]): Keys to follow, outermost first. None or an empty
            list yields None.

    Returns:
        The value at the end of ``nested_keys``. Returns None if a key along the
        path is missing or maps to None.

    Raises:
        ValueError: if an intermediate value along the path exists but is not a dict.
    """
    if not isinstance(dictionary, dict) or nested_keys is None or len(nested_keys) == 0:
        return None

    current_section = dictionary
    for key in nested_keys[:-1]:
        current_section = current_section.get(key, None)
        if current_section is None:
            # The full path of nested_keys doesn't exist, or the value was set to None.
            return None
        if not isinstance(current_section, dict):
            # Build a single message string; passing two positional args to
            # ValueError would render the message as a tuple.
            raise ValueError(
                "Unexpected structure of dictionary. "
                "Expected value of type dict at key '{}' but got '{}' for dict '{}'".format(
                    key, current_section, dictionary
                )
            )
    return current_section.get(nested_keys[-1], None)
283
+
284
+
285
def set_nested_value(dictionary: dict, nested_keys: List[str], value_to_set: object):
    """Store ``value_to_set`` at the path ``nested_keys`` inside ``dictionary``.

    The dictionary is modified in place and also returned; a new dict is created
    when ``dictionary`` is None. Any intermediate value that is missing, None, or
    not a dict is replaced by a fresh empty dict. This can overwrite an unexpected
    part of the dict when given an unintended key list, so checking with
    ``get_nested_value`` first is recommended. A non-dict ``dictionary``, or a None
    or empty ``nested_keys``, leaves the input untouched.
    """
    root = {} if dictionary is None else dictionary
    if not isinstance(root, dict) or nested_keys is None or len(nested_keys) == 0:
        return root

    *parent_keys, leaf_key = nested_keys
    node = root
    for key in parent_keys:
        child = node.get(key)
        if not isinstance(child, dict):
            # Missing, None, or non-dict values are replaced by a new level.
            child = {}
            node[key] = child
        node = child

    node[leaf_key] = value_to_set
    return root
314
+
315
+
316
def get_short_version(framework_version):
    """Trim a version string down to its first two dot-separated components.

    Args:
        framework_version: The version string to be shortened, e.g. "1.13.1".

    Returns:
        str: The "major.minor" portion, e.g. "1.13". Strings with fewer than two
        components are returned unchanged.
    """
    major_minor = framework_version.split(".")[:2]
    return ".".join(major_minor)
326
+
327
+
328
def secondary_training_status_changed(current_job_description, prev_job_description):
    """Tell whether the latest secondary status message differs from the previous poll.

    Args:
        current_job_description: Current job description, returned from DescribeTrainingJob call.
        prev_job_description: Previous job description, returned from DescribeTrainingJob call
            (may be None).

    Returns:
        bool: False when the current description has no secondary status transitions;
        otherwise True if the newest ``StatusMessage`` differs from the previous one
        (an absent previous message counts as "").
    """
    current_transitions = current_job_description.get("SecondaryStatusTransitions")
    if not current_transitions:
        return False

    previous_transitions = None
    if prev_job_description is not None:
        previous_transitions = prev_job_description.get("SecondaryStatusTransitions")

    previous_message = previous_transitions[-1]["StatusMessage"] if previous_transitions else ""
    latest_message = current_job_description["SecondaryStatusTransitions"][-1]["StatusMessage"]
    return latest_message != previous_message
362
+
363
+
364
def secondary_training_status_message(job_description, prev_description):
    """Render the new secondary training status transitions as printable lines.

    Args:
        job_description: Returned response from DescribeTrainingJob call.
        prev_description: Previous job description from DescribeTrainingJob call (may be None).

    Returns:
        str: One "<UTC time> <Status> - <StatusMessage>" line per transition to report,
        joined by newlines; "" when there is nothing to report. The time shown is the
        job's ``LastModifiedTime``.
    """
    # Local import: the module-level import block is outside this function.
    from datetime import timezone

    if job_description is None or not job_description.get("SecondaryStatusTransitions"):
        return ""

    current_transitions = job_description["SecondaryStatusTransitions"]
    prev_transitions = (
        prev_description.get("SecondaryStatusTransitions") if prev_description is not None else None
    )
    prev_transitions_num = len(prev_transitions) if prev_transitions is not None else 0

    if len(current_transitions) == prev_transitions_num:
        # Secondary status is not changed but the message changed.
        transitions_to_print = current_transitions[-1:]
    else:
        # Secondary status is changed: print all the new entries.
        transitions_to_print = current_transitions[prev_transitions_num - len(current_transitions) :]

    if not transitions_to_print:
        return ""

    # The timestamp is identical for every line, so compute it once.
    # mktime interprets the timetuple as local time; the epoch value is then rendered in UTC.
    # fromtimestamp(..., tz=utc) replaces the deprecated datetime.utcfromtimestamp.
    time_str = datetime.fromtimestamp(
        time.mktime(job_description["LastModifiedTime"].timetuple()), tz=timezone.utc
    ).strftime("%Y-%m-%d %H:%M:%S")

    return "\n".join(
        "{} {} - {}".format(time_str, transition["Status"], transition["StatusMessage"])
        for transition in transitions_to_print
    )
410
+
411
+
412
def download_folder(bucket_name, prefix, target, sagemaker_session):
    """Download an S3 "folder" (or a single object) into a local directory.

    Args:
        bucket_name (str): S3 bucket name
        prefix (str): S3 prefix within the bucket that will be downloaded. Can
            be a single file.
        target (str): destination path where the downloaded items will be placed
        sagemaker_session (sagemaker.core.helper.session.Session): a sagemaker session to
            interact with S3.

    Raises:
        ValueError: If the prefix contains "..".
    """
    s3_resource = sagemaker_session.s3_resource

    normalized_prefix = prefix.lstrip("/")
    if ".." in normalized_prefix:
        raise ValueError("Traversal components are not allowed in S3 path!")

    # Try the prefix as a single object first, in case it is a file and not a 'directory'.
    # Doing this first helps when the object has broader permissions than the bucket.
    if not normalized_prefix.endswith("/"):
        single_file_destination = os.path.join(target, os.path.basename(normalized_prefix))
        try:
            s3_resource.Object(bucket_name, normalized_prefix).download_file(
                single_file_destination
            )
            return
        except botocore.exceptions.ClientError as client_error:
            error_details = client_error.response["Error"]
            # S3 reports a folder-like prefix as a 404 too, so treat that case as
            # "not a file" and fall through; a genuine 404 surfaces later.
            looks_like_folder = (
                error_details["Code"] == "404" and error_details["Message"] == "Not Found"
            )
            if not looks_like_folder:
                raise

    _download_files_under_prefix(bucket_name, normalized_prefix, target, s3_resource)
447
+
448
+
449
+ def _download_files_under_prefix(bucket_name, prefix, target, s3):
450
+ """Download all S3 files which match the given prefix
451
+
452
+ Args:
453
+ bucket_name (str): S3 bucket name
454
+ prefix (str): S3 prefix within the bucket that will be downloaded
455
+ target (str): destination path where the downloaded items will be placed
456
+ s3 (boto3.resources.base.ServiceResource): S3 resource
457
+ """
458
+ bucket = s3.Bucket(bucket_name)
459
+ for obj_sum in bucket.objects.filter(Prefix=prefix):
460
+ # if obj_sum is a folder object skip it.
461
+ if obj_sum.key.endswith("/"):
462
+ continue
463
+ obj = s3.Object(obj_sum.bucket_name, obj_sum.key)
464
+ s3_relative_path = obj_sum.key[len(prefix) :].lstrip("/")
465
+ file_path = os.path.join(target, s3_relative_path)
466
+
467
+ try:
468
+ os.makedirs(os.path.dirname(file_path))
469
+ except OSError as exc:
470
+ # EEXIST means the folder already exists, this is safe to skip
471
+ # anything else will be raised.
472
+ if exc.errno != errno.EEXIST:
473
+ raise
474
+ obj.download_file(file_path)
475
+
476
+
477
def create_tar_file(source_files, target=None):
    """Create a gzip-compressed tar file containing all the source_files.

    Each entry is stored at the root of the archive under its basename. Symlinks are
    dereferenced, so the archive holds the files' contents.

    Args:
        source_files (List[str]): List of file paths that will be contained in the tar file.
        target (str): Optional output path. When falsy, a new temporary file is used.

    Returns:
        str: path to created tar file
    """
    if target:
        filename = target
    else:
        # mkstemp returns an open OS-level descriptor as well as the path;
        # close it so it is not leaked (tarfile reopens the path itself).
        fd, filename = tempfile.mkstemp()
        os.close(fd)

    with tarfile.open(filename, mode="w:gz", dereference=True) as t:
        for sf in source_files:
            # Add all files into the root of the directory structure of the tar.
            t.add(sf, arcname=os.path.basename(sf))
    return filename
497
+
498
+
499
+ @contextlib.contextmanager
500
+ def _tmpdir(suffix="", prefix="tmp", directory=None):
501
+ """Create a temporary directory with a context manager.
502
+
503
+ The file is deleted when the context exits, even when there's an exception.
504
+ The prefix, suffix, and dir arguments are the same as for mkstemp().
505
+
506
+ Args:
507
+ suffix (str): If suffix is specified, the file name will end with that
508
+ suffix, otherwise there will be no suffix.
509
+ prefix (str): If prefix is specified, the file name will begin with that
510
+ prefix; otherwise, a default prefix is used.
511
+ directory (str): If a directory is specified, the file will be downloaded
512
+ in this directory; otherwise, a default directory is used.
513
+
514
+ Returns:
515
+ str: path to the directory
516
+ """
517
+ if directory is not None and not (os.path.exists(directory) and os.path.isdir(directory)):
518
+ raise ValueError(
519
+ "Inputted directory for storing newly generated temporary "
520
+ f"directory does not exist: '{directory}'"
521
+ )
522
+ tmp = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=directory)
523
+ try:
524
+ yield tmp
525
+ finally:
526
+ shutil.rmtree(tmp)
527
+
528
+
529
def repack_model(
    inference_script,
    source_directory,
    dependencies,
    model_uri,
    repacked_model_uri,
    sagemaker_session,
    kms_key=None,
):
    """Unpack a model tarball and create a new model tarball with the provided code.

    This function does the following:
        - extracts the model tarball from S3 or the local file system into a temp folder,
        - replaces the inference code in the model with the new code provided,
        - compresses the new model tarball and saves it to S3 or the local file system.

    Args:
        inference_script (str): path or basename of the inference script that
            will be packed into the model
        source_directory (str): path including all the files that will be packed
            into the model
        dependencies (list[str]): A list of paths to directories (absolute or
            relative) with any additional libraries that will be exported to the
            container (default: []). The library folders will be copied to
            SageMaker in the same folder where the entrypoint is copied.
            Example

                The following call >>> Estimator(entry_point='train.py',
                dependencies=['my/libs/common', 'virtual-env']) results in the
                following inside the container:

                >>> $ ls

                >>> opt/ml/code
                >>>     |------ train.py
                >>>     |------ common
                >>>     |------ virtual-env
        model_uri (str): S3 or file system location of the original model tar
        repacked_model_uri (str): path or file system location where the new
            model will be saved
        sagemaker_session (sagemaker.core.helper.session.Session): a sagemaker session to
            interact with S3.
        kms_key (str): KMS key ARN for encrypting the repacked model file

    Returns:
        None: the repacked model is written to ``repacked_model_uri``.
    """
    dependencies = dependencies or []

    # Use the session's configured download directory if any; otherwise the system temp dir.
    settings = sagemaker_session.settings
    local_download_dir = None if settings is None else settings.local_download_dir

    with _tmpdir(directory=local_download_dir) as tmp:
        model_dir = _extract_model(model_uri, sagemaker_session, tmp)

        _create_or_update_code_dir(
            model_dir,
            inference_script,
            source_directory,
            dependencies,
            sagemaker_session,
            tmp,
        )

        tmp_model_path = os.path.join(tmp, "temp-model.tar.gz")
        with tarfile.open(tmp_model_path, mode="w:gz") as t:
            t.add(model_dir, arcname=os.path.sep)

        # Must run inside the context: the temp dir holding tmp_model_path is removed on exit.
        _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key=kms_key)
601
+
602
+
603
+ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
604
+ """Placeholder docstring"""
605
+ if repacked_model_uri.lower().startswith("s3://"):
606
+ url = parse.urlparse(repacked_model_uri)
607
+ bucket, key = url.netloc, url.path.lstrip("/")
608
+ new_key = key.replace(os.path.basename(key), os.path.basename(repacked_model_uri))
609
+
610
+ settings = (
611
+ sagemaker_session.settings if sagemaker_session is not None else SessionSettings()
612
+ )
613
+ encrypt_artifact = settings.encrypt_repacked_artifacts
614
+
615
+ if kms_key:
616
+ extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key}
617
+ elif encrypt_artifact:
618
+ extra_args = {"ServerSideEncryption": "aws:kms"}
619
+ else:
620
+ extra_args = None
621
+ sagemaker_session.boto_session.resource(
622
+ "s3", region_name=sagemaker_session.boto_region_name
623
+ ).Object(bucket, new_key).upload_file(tmp_model_path, ExtraArgs=extra_args)
624
+ else:
625
+ shutil.move(tmp_model_path, repacked_model_uri.replace("file://", ""))
626
+
627
+
628
+ def _validate_source_directory(source_directory):
629
+ """Validate that source_directory is safe to use.
630
+
631
+ Ensures the source directory path does not access restricted system locations.
632
+
633
+ Args:
634
+ source_directory (str): The source directory path to validate.
635
+
636
+ Raises:
637
+ ValueError: If the path is not allowed.
638
+ """
639
+ if not source_directory or source_directory.lower().startswith("s3://"):
640
+ # S3 paths and None are safe
641
+ return
642
+
643
+ # Resolve symlinks to get the actual path
644
+ abs_source = abspath(realpath(source_directory))
645
+
646
+ # Check if the source path is under any sensitive directory
647
+ for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
648
+ if abs_source != "/" and abs_source.startswith(sensitive_path):
649
+ raise ValueError(
650
+ f"source_directory cannot access sensitive system paths. "
651
+ f"Got: {source_directory} (resolved to {abs_source})"
652
+ )
653
+
654
+
655
+ def _validate_dependency_path(dependency):
656
+ """Validate that a dependency path is safe to use.
657
+
658
+ Ensures the dependency path does not access restricted system locations.
659
+
660
+ Args:
661
+ dependency (str): The dependency path to validate.
662
+
663
+ Raises:
664
+ ValueError: If the path is not allowed.
665
+ """
666
+ if not dependency:
667
+ return
668
+
669
+ # Resolve symlinks to get the actual path
670
+ abs_dependency = abspath(realpath(dependency))
671
+
672
+ # Check if the dependency path is under any sensitive directory
673
+ for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
674
+ if abs_dependency != "/" and abs_dependency.startswith(sensitive_path):
675
+ raise ValueError(
676
+ f"dependency path cannot access sensitive system paths. "
677
+ f"Got: {dependency} (resolved to {abs_dependency})"
678
+ )
679
+
680
+
681
def _create_or_update_code_dir(
    model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp
):
    """Populate ``<model_dir>/code`` with the inference code and its dependencies.

    The code comes from one of three places:
        - an ``s3://`` source_directory: a tar.gz downloaded into ``tmp`` and extracted,
        - a local source_directory: copied over any existing code directory,
        - otherwise, the single ``inference_script`` file, copied into the code directory.

    Each dependency is then copied into ``<code>/lib``: directories as sub-directories
    named after their basename, files directly.

    Raises:
        ValueError: If the code directory, source_directory or a dependency resolves to
            a sensitive system path.
    """
    code_dir = os.path.join(model_dir, "code")
    resolved_code_dir = _get_resolved_path(code_dir)

    # Validate that code_dir does not resolve to a sensitive system path.
    if resolved_code_dir != "/" and any(
        resolved_code_dir.startswith(restricted) for restricted in _SENSITIVE_SYSTEM_PATHS
    ):
        raise ValueError(
            f"Invalid code_dir path: {code_dir} resolves to sensitive system path {resolved_code_dir}"
        )

    if not source_directory:
        if not os.path.exists(code_dir):
            os.mkdir(code_dir)
        try:
            shutil.copy2(inference_script, code_dir)
        except FileNotFoundError:
            # Tolerate a script given by basename that is already inside the code dir.
            if not os.path.exists(os.path.join(code_dir, inference_script)):
                raise
    elif source_directory.lower().startswith("s3://"):
        local_code_path = os.path.join(tmp, "local_code.tar.gz")
        download_file_from_url(source_directory, local_code_path, sagemaker_session)
        with tarfile.open(name=local_code_path, mode="r:gz") as code_archive:
            custom_extractall_tarfile(code_archive, code_dir)
    else:
        _validate_source_directory(source_directory)
        if os.path.exists(code_dir):
            shutil.rmtree(code_dir)
        shutil.copytree(source_directory, code_dir)

    lib_dir = os.path.join(code_dir, "lib")
    for dependency in dependencies:
        _validate_dependency_path(dependency)
        if os.path.isdir(dependency):
            shutil.copytree(dependency, os.path.join(lib_dir, os.path.basename(dependency)))
            continue
        if not os.path.exists(lib_dir):
            os.mkdir(lib_dir)
        shutil.copy2(dependency, lib_dir)
729
+
730
+
731
def _extract_model(model_uri, sagemaker_session, tmp):
    """Extract a gzip-compressed model tarball into ``<tmp>/model`` and return that path.

    An ``s3://`` model_uri is first downloaded to ``<tmp>/tar_file``. Any other value is
    treated as a local path, with an optional ``file://`` scheme stripped.
    """
    extraction_dir = os.path.join(tmp, "model")
    os.mkdir(extraction_dir)

    if model_uri.lower().startswith("s3://"):
        archive_path = os.path.join(tmp, "tar_file")
        download_file_from_url(model_uri, archive_path, sagemaker_session)
    else:
        archive_path = model_uri.replace("file://", "")

    with tarfile.open(name=archive_path, mode="r:gz") as model_archive:
        custom_extractall_tarfile(model_archive, extraction_dir)
    return extraction_dir
743
+
744
+
745
def download_file_from_url(url, dst, sagemaker_session):
    """Download the object named by an ``s3://bucket/key`` URL to the local path ``dst``."""
    parsed_url = parse.urlparse(url)
    download_file(parsed_url.netloc, parsed_url.path.lstrip("/"), dst, sagemaker_session)
751
+
752
+
753
def download_file(bucket_name, path, target, sagemaker_session):
    """Download a single S3 object to a local path.

    Args:
        bucket_name (str): S3 bucket name
        path (str): key of the object within the bucket (leading "/" is ignored)
        target (str): local destination for the downloaded file.
        sagemaker_session (sagemaker.core.helper.session.Session): a sagemaker session to
            interact with S3.
    """
    object_key = path.lstrip("/")
    s3_resource = sagemaker_session.boto_session.resource(
        "s3", region_name=sagemaker_session.boto_region_name
    )
    s3_resource.Bucket(bucket_name).download_file(object_key, target)
769
+
770
+
771
def sts_regional_endpoint(region):
    """Return the HTTPS URL of the regional AWS STS endpoint for ``region``.

    This helper exists because the AWS SDK does not yet honor the ``region_name``
    parameter when creating an AWS STS client. The hostname comes from botocore's
    endpoint data; il-central-1 falls back to a constructed hostname when that data
    does not know the region.

    For the list of regional endpoints, see
    https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html#id_credentials_region-endpoints.

    Args:
        region (str): AWS region name

    Returns:
        str: AWS STS regional endpoint
    """
    endpoint_data = _botocore_resolver().construct_endpoint("sts", region)
    if not endpoint_data and region == "il-central-1":
        endpoint_data = {"hostname": "sts.{}.amazonaws.com".format(region)}
    hostname = endpoint_data["hostname"]
    return "https://{}".format(hostname)
790
+
791
+
792
def retries(
    max_retry_count,
    exception_message_prefix,
    seconds_to_sleep=DEFAULT_SLEEP_TIME_SECONDS,
):
    """Generator yielding attempt indices, sleeping after each, then failing when exhausted.

    Yields 0, 1, ..., ``max_retry_count - 1``, sleeping ``seconds_to_sleep`` seconds
    after each yield. Presumably the caller breaks out of its loop on success. If the
    generator is exhausted, a generic ``Exception`` is raised.

    Args:
        max_retry_count (int): The retry count.
        exception_message_prefix (str): The message to include in the exception on failure.
        seconds_to_sleep (int): The number of seconds to sleep between executions.

    Raises:
        Exception: After all ``max_retry_count`` attempts have been yielded.
    """
    for attempt_index in range(max_retry_count):
        yield attempt_index
        time.sleep(seconds_to_sleep)

    failure_message = "'{}' has reached the maximum retry count of {}".format(
        exception_message_prefix, max_retry_count
    )
    raise Exception(failure_message)
814
+
815
+
816
def retry_with_backoff(callable_func, num_attempts=8, botocore_client_error_code=None):
    """Call ``callable_func`` until it succeeds, backing off exponentially between tries.

    After failed attempt ``i`` (0-based), the function logs an error and sleeps ``2**i``
    seconds before trying again.

    Args:
        callable_func (Callable): The callable function to retry.
        num_attempts (int): The maximum number of attempts to retry. (Default: 8)
        botocore_client_error_code (str): The specific Botocore ClientError exception error code
            on which to retry on.
            If provided other exceptions will be raised directly w/o retry.
            If not provided, retry on any exception.
            (Default: None)

    Returns:
        Whatever ``callable_func`` returns on its first successful call.

    Raises:
        ValueError: If ``num_attempts`` is less than 1.
        Exception: The last error once attempts are exhausted, or immediately for a
            non-retryable error.
    """
    if num_attempts < 1:
        raise ValueError(
            "The num_attempts must be >= 1, but the given value is {}.".format(num_attempts)
        )

    final_attempt = num_attempts - 1
    for attempt in range(num_attempts):
        try:
            return callable_func()
        except Exception as ex:  # pylint: disable=broad-except
            should_retry = not botocore_client_error_code or (
                isinstance(ex, botocore.exceptions.ClientError)
                and ex.response["Error"]["Code"]  # pylint: disable=no-member
                == botocore_client_error_code
            )
            if not should_retry or attempt == final_attempt:
                raise
            logger.error("Retrying in attempt %s, due to %s", (attempt + 1), str(ex))
            time.sleep(2**attempt)
848
+
849
+
850
def _botocore_resolver():
    """Build a botocore endpoint resolver from botocore's bundled "endpoints" data.

    Takes no arguments. Callers in this module use the resolver's
    ``construct_endpoint(service, region)`` to look up endpoint metadata such as the
    hostname and partition.

    Returns:
        botocore.regions.EndpointResolver: A resolver over the loaded endpoint data.
    """
    loader = botocore.loaders.create_loader()
    return botocore.regions.EndpointResolver(loader.load_data("endpoints"))
861
+
862
+
863
def aws_partition(region):
    """Given a region name (ex: "cn-north-1"), return the corresponding aws partition ("aws-cn").

    Args:
        region (str): The region name for which to return the corresponding partition.
            Ex: "cn-north-1"

    Returns:
        str: partition corresponding to the region name passed in. Ex: "aws-cn"
    """
    endpoint_data = _botocore_resolver().construct_endpoint("sts", region)
    if region == "il-central-1" and not endpoint_data:
        # Fallback for botocore data that does not know il-central-1. It must carry a
        # "partition" entry, because that key is read below (otherwise KeyError).
        # il-central-1 is a commercial region in the standard "aws" partition.
        endpoint_data = {
            "hostname": "sts.{}.amazonaws.com".format(region),
            "partition": "aws",
        }
    return endpoint_data["partition"]
877
+
878
+
879
class DeferredError:
    """Placeholder that re-raises a stored exception when one of its attributes is used.

    Enables soft dependencies on optional imports: the original ImportError surfaces
    only if code actually touches the missing library.

    Example::

        try:
            import obscurelib
        except ImportError as e:
            logger.warning("Failed to import obscurelib. Obscure features will not work.")
            obscurelib = DeferredError(e)
    """

    def __init__(self, exception):
        """Keep ``exception`` so it can be raised lazily later."""
        self.exc = exception

    def __getattr__(self, name):
        """Raise the stored exception.

        Python invokes ``__getattr__`` only for attributes not found by normal lookup,
        so any access to an unknown attribute (e.g. a module function) raises.

        Args:
            name: The attribute being looked up (unused).
        """
        raise self.exc
907
+
908
+
909
+ def _module_import_error(py_module, feature, extras):
910
+ """Return error message for module import errors, provide installation details.
911
+
912
+ Args:
913
+ py_module (str): Module that failed to be imported
914
+ feature (str): Affected SageMaker feature
915
+ extras (str): Name of the `extras_require` to install the relevant dependencies
916
+
917
+ Returns:
918
+ str: Error message with installation instructions.
919
+ """
920
+ error_msg = (
921
+ "Failed to import {}. {} features will be impaired or broken. "
922
+ "Please run \"pip install 'sagemaker[{}]'\" "
923
+ "to install all required dependencies."
924
+ )
925
+ return error_msg.format(py_module, feature, extras)
926
+
927
+
928
class DataConfig(abc.ABC):
    """Abstract interface for fetching a data config hosted in AWS resources.

    Concrete subclasses supply the retrieval logic by overriding ``fetch_data_config``.
    """

    @abc.abstractmethod
    def fetch_data_config(self):
        """Retrieve the data config from the subclass's pre-configured data source.

        Returns:
            object: The data configuration object.
        """
941
+
942
+
943
class S3DataConfig(DataConfig):
    """DataConfig implementation that reads a JSON data config file stored in S3."""

    def __init__(
        self,
        sagemaker_session,
        bucket_name,
        prefix,
    ):
        """Initialize a ``S3DataConfig`` instance.

        Args:
            sagemaker_session (Session): SageMaker session instance to use for boto configuration.
            bucket_name (str): Required. Bucket name from which data config needs to be fetched.
            prefix (str): Required. The object prefix for the hosted data config.

        Raises:
            ValueError: If ``bucket_name`` or ``prefix`` is None.
        """
        if bucket_name is None or prefix is None:
            raise ValueError(
                "Bucket Name and S3 file Prefix are required arguments and must be provided."
            )

        super().__init__()

        self.bucket_name = bucket_name
        self.prefix = prefix
        self.sagemaker_session = sagemaker_session

    def fetch_data_config(self):
        """Read the config object from S3 and parse it as JSON.

        Returns:
            object: The JSON object containing data configuration.
        """
        raw_config = self.sagemaker_session.read_s3_file(self.bucket_name, self.prefix)
        return json.loads(raw_config)

    def get_data_bucket(self, region_requested=None):
        """Look up the data bucket for a region in the fetched config.

        Args:
            region_requested (str): The region for which the data is being requested;
                defaults to the session's region.

        Returns:
            str: The config entry for the region, or the "default" entry when the
            region is not listed.
        """
        config = self.fetch_data_config()
        region = region_requested or self.sagemaker_session.boto_region_name
        if region in config:
            return config[region]
        return config["default"]
994
+
995
+
996
def update_container_with_inference_params(
    framework=None,
    framework_version=None,
    nearest_model_name=None,
    data_input_configuration=None,
    container_def=None,
    container_list=None,
):
    """Add any provided inference recommender parameters to the given container(s).

    Every dict in ``container_list`` and ``container_def`` (when not None) is updated
    in place via ``construct_container_object``.

    Args:
        framework (str): Machine learning framework of the model package container image
            (default: None).
        framework_version (str): Framework version of the Model Package Container Image
            (default: None).
        nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
            Amazon SageMaker Inference Recommender (default: None).
        data_input_configuration (str): Input object for the model (default: None).
        container_def (dict): object to be updated.
        container_list (list): list to be updated.

    Returns:
        ``container_list`` if it is truthy, otherwise ``container_def``.
    """
    inference_params = (data_input_configuration, framework, framework_version, nearest_model_name)

    if container_list is not None:
        for container in container_list:
            construct_container_object(container, *inference_params)

    if container_def is not None:
        construct_container_object(container_def, *inference_params)

    return container_list or container_def
1037
+
1038
+
1039
def construct_container_object(
    obj, data_input_configuration, framework, framework_version, nearest_model_name
):
    """Copy the non-None inference recommender parameters into ``obj`` in place.

    Args:
        obj (dict): object to be updated.
        data_input_configuration (str): Input object for the model; stored as
            ``{"ModelInput": {"DataInputConfig": ...}}`` (default: None).
        framework (str): Machine learning framework of the model package container image
            (default: None).
        framework_version (str): Framework version of the Model Package Container Image
            (default: None).
        nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
            Amazon SageMaker Inference Recommender (default: None).

    Returns:
        dict: The same ``obj``, updated.
    """
    simple_fields = (
        ("Framework", framework),
        ("FrameworkVersion", framework_version),
        ("NearestModelName", nearest_model_name),
    )
    for field_name, field_value in simple_fields:
        if field_value is not None:
            obj[field_name] = field_value

    if data_input_configuration is not None:
        obj["ModelInput"] = {"DataInputConfig": data_input_configuration}

    return obj
1089
+
1090
+
1091
def pop_out_unused_kwarg(arg_name: str, kwargs: dict, override_val: Optional[str] = None):
    """Remove ``arg_name`` from ``kwargs`` (in place) and log a warning if it was present.

    Args:
        arg_name (str): The name of the argument to be checked if it is unused.
        kwargs (dict): The key-word argument dict.
        override_val (str): The value used to override the unused argument (default: None).
    """
    if arg_name in kwargs:
        message = "{} supplied in kwargs will be ignored".format(arg_name)
        if override_val:
            message = message + " and further overridden with {}.".format(override_val)
        logging.warning(message)
        kwargs.pop(arg_name)
1106
+
1107
+
1108
def to_string(obj: object):
    """Convert an object to a string.

    PipelineVariable objects are rendered with their own ``to_string()`` method;
    anything else goes through ``str()``.

    Args:
        obj (object): The object to be converted
    """
    if is_pipeline_variable(obj):
        return obj.to_string()
    return str(obj)
1117
+
1118
+
1119
def _start_waiting(waiting_time: int):
    """Sleep for ``waiting_time`` seconds while drawing a growing row of dots on stdout.

    The wait is split into ``WAITING_DOT_NUMBER`` equal pauses. Each redraw uses a
    carriage return to overwrite the line, and the dots are blanked out at the end.

    Args:
        waiting_time (int): The total waiting time.
    """
    pause = float(waiting_time) / WAITING_DOT_NUMBER

    dots_drawn = 0
    for dots_drawn in range(1, WAITING_DOT_NUMBER + 1):
        print("." * dots_drawn, end="\r")
        time.sleep(pause)
    print(" " * dots_drawn, end="\r")
1133
+
1134
+
1135
def get_module(module_name):
    """Import a module.

    Args:
        module_name (str): name of the module to import.

    Returns:
        object: The imported module.

    Raises:
        Exception: when the module name is not found. The original ImportError is
            chained as ``__cause__`` so the real failure reason is not lost.
    """
    try:
        return import_module(module_name)
    except ImportError as err:
        raise Exception("Cannot import module {}, please try again.".format(module_name)) from err
1151
+
1152
+
1153
def check_and_get_run_experiment_config(experiment_config: Optional[dict] = None) -> dict:
    """Pick the experiment config: the user's if supplied, else the active Run's.

    Args:
        experiment_config (dict): The experiment_config supplied by the user.

    Returns:
        dict: ``experiment_config`` when it is truthy (a warning is logged if a Run
        context is also active). Otherwise the current Run object's experiment_config,
        or None when there is no current Run.
    """
    from sagemaker.core.experiments._run_context import _RunContext

    current_run = _RunContext.get_current_run()

    if not experiment_config:
        return current_run.experiment_config if current_run else None

    if current_run:
        logger.warning(
            "The function is invoked within an Experiment Run context "
            "but another experiment_config (%s) was supplied, so "
            "ignoring the experiment_config fetched from the Run object.",
            experiment_config,
        )
    return experiment_config
1177
+
1178
+
1179
def resolve_value_from_config(
    direct_input=None,
    config_path: str = None,
    default_value=None,
    sagemaker_session=None,
    sagemaker_config: dict = None,
):
    """Choose the value a caller should use, consulting the SageMaker config.

    Priority, highest first:
        1. direct_input
        2. config value
        3. default_value
        4. None

    The substitution decision is always logged through
    ``_log_sagemaker_config_single_substitution``.

    Args:
        direct_input: The value that the caller of this method starts with. Usually this is an
            input to the caller's class or method.
        config_path (str): A string denoting the path used to lookup the value in the
            sagemaker config.
        default_value: The value used if not present elsewhere.
        sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker Session object, used for
            SageMaker interactions (default: None).
        sagemaker_config (dict): The sdk defaults config that is normally accessed through a
            Session object by doing `session.sagemaker_config`. (default: None) This parameter will
            be checked for the config value if (and only if) sagemaker_session is None. This
            parameter exists for the rare cases where the user provided no Session but a default
            Session cannot be initialized before config injection is needed. In that case,
            the config dictionary may be loaded and passed here before a default Session object
            is created.

    Returns:
        The value that should be used by the caller
    """
    config_value = None
    if config_path:
        config_value = get_sagemaker_config_value(
            sagemaker_session, config_path, sagemaker_config=sagemaker_config
        )

    from sagemaker.core.config.config_utils import _log_sagemaker_config_single_substitution

    _log_sagemaker_config_single_substitution(direct_input, config_value, config_path)

    if direct_input is not None:
        return direct_input
    return config_value if config_value is not None else default_value
1234
+
1235
+
1236
def get_sagemaker_config_value(sagemaker_session, key, sagemaker_config: dict = None):
    """Look up ``key`` in the SageMaker defaults config and return a private copy of it.

    Args:
        sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker Session
            object, used for SageMaker interactions. Its ``sagemaker_config`` is used
            when the session is truthy and has that attribute.
        key: Key path of the config file entry.
        sagemaker_config (dict): The sdk defaults config normally reached through
            ``session.sagemaker_config`` (default: None). It is used only when no
            usable session is given, for the rare case where a default Session cannot
            be created before config injection is needed.

    Returns:
        object: A deep copy of the configured value, or None when no config is
        available.
    """
    from sagemaker.core.config.config_manager import SageMakerConfig

    source_config = (
        sagemaker_session.sagemaker_config
        if sagemaker_session and hasattr(sagemaker_session, "sagemaker_config")
        else sagemaker_config
    )
    if not source_config:
        return None

    SageMakerConfig().validate_sagemaker_config(source_config)
    # Hand out a copy so callers that mutate the result cannot alter the loaded config.
    return copy.deepcopy(get_config_value(key, source_config))
1267
+
1268
+
1269
def get_resource_name_from_arn(arn):
    """Return the part of an ARN's resource field that follows the first ``/``.

    For ``arn:aws:sagemaker:<region>:<account>:training-job/my-job`` this is
    ``my-job``. Everything after the first slash is kept, so nested names such as
    ``model-package/group/3`` yield ``group/3``.

    Args:
        arn (str): An ARN.

    Returns:
        str: The resource name.
    """
    resource_field = arn.split(":", 5)[5]
    resource_name = resource_field.split("/", 1)[1]
    return resource_name
1279
+
1280
+
1281
def list_tags(sagemaker_session, resource_arn, max_results=50):
    """List every user tag attached to an Amazon Resource Name.

    All pages of the SageMaker ``ListTags`` API are fetched. Tags whose key starts
    with the reserved ``aws:`` prefix are dropped from the result.

    Args:
        sagemaker_session: Session whose ``sagemaker_client`` is used for the calls.
        resource_arn (str): The Amazon Resource Name (ARN) for which to get the tags list.
        max_results (int): The maximum number of results to include in a single page.
            This method takes care of that abstraction and returns a full list.

    Returns:
        list: The tag dicts (``{"Key": ..., "Value": ...}``) across all pages.

    Raises:
        botocore.exceptions.ClientError: If a ``ListTags`` call fails. It is logged
            and then re-raised.
    """
    client = sagemaker_session.sagemaker_client
    request = {"ResourceArn": resource_arn, "MaxResults": max_results}
    tags_list = []

    try:
        while True:
            response = client.list_tags(**request)
            tags_list.extend(response["Tags"])
            # ListTags returns its pagination token as "NextToken". The lowercase
            # spelling is accepted as well for backward compatibility.
            next_token = response.get("NextToken") or response.get("nextToken")
            if not next_token:
                break
            request["NextToken"] = next_token
    except ClientError as error:
        logger.error("Error retrieving tags. resource_arn: %s", resource_arn)
        raise error

    # The "aws:" prefix is reserved for AWS-managed tags. Match it as a prefix only,
    # so user keys that merely contain "aws:" are kept.
    return [tag for tag in tags_list if not tag["Key"].startswith("aws:")]
1313
+
1314
+
1315
def resolve_class_attribute_from_config(
    clazz: Optional[type],
    instance: Optional[object],
    attribute: str,
    config_path: str,
    default_value=None,
    sagemaker_session=None,
):
    """Utility method that merges config values to data classes.

    Takes an instance of a class and, if the attribute is not already set, sets it to
    a value fetched from the sagemaker_config or to the default_value.

    Uses this order of prioritization to determine what the value of the attribute should be:
    1. current value of attribute
    2. config value
    3. default_value
    4. does not set it

    Args:
        clazz (Optional[type]): Class of 'instance'. Used to generate a new instance if the
            instance is None. If None is provided here, no new object will be created
            if 'instance' doesn't exist. Note: if provided, the constructor should set default
            values to None; otherwise, the constructor's non-None default will be left
            as-is even if a config value was defined.
        instance (Optional[object]): instance of the Class 'clazz' that has an attribute
            of 'attribute' to set
        attribute (str): attribute of the instance to set if not already set
        config_path (str): a string denoting the path to use to lookup the config value in the
            sagemaker config
        default_value: the value to use if not present elsewhere
        sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker Session object,
            used for SageMaker interactions (default: None).

    Returns:
        The updated class instance that should be used by the caller instead of the
        'instance' parameter that was passed in.

    Raises:
        TypeError: If the (possibly newly created) instance has no attribute named
            ``attribute``.
    """
    config_value = get_sagemaker_config_value(sagemaker_session, config_path)

    if config_value is None and default_value is None:
        # Nothing could be applied: return the instance unmodified (None or populated).
        return instance

    if instance is None:
        if clazz is None or not inspect.isclass(clazz):
            return instance
        # Construct a new instance so the resolved value has somewhere to live.
        instance = clazz()

    if not hasattr(instance, attribute):
        # One formatted message so str(error) reads as a sentence, not a tuple.
        raise TypeError(
            "Unexpected structure of object. "
            "Expected attribute {} to be present inside instance {} of class {}".format(
                attribute, instance, clazz
            )
        )

    current_value = getattr(instance, attribute)
    if current_value is None:
        # Only set a value if the object does not already have one.
        if config_value is not None:
            setattr(instance, attribute, config_value)
        elif default_value is not None:
            setattr(instance, attribute, default_value)

    from sagemaker.core.config.config_utils import _log_sagemaker_config_single_substitution

    _log_sagemaker_config_single_substitution(current_value, config_value, config_path)

    return instance
1386
+
1387
+
1388
def resolve_nested_dict_value_from_config(
    dictionary: dict,
    nested_keys: List[str],
    config_path: str,
    default_value: object = None,
    sagemaker_session=None,
):
    """Utility method that sets the value of a key path in a nested dictionary.

    This method takes a dictionary and, if not already set, sets the value for the provided
    list of nested keys to the value fetched from the sagemaker_config or the default_value.

    Uses this order of prioritization to determine what the value of the attribute should be:
    (1) current value of nested key, (2) config value, (3) default_value, (4) does not set it

    Args:
        dictionary: The dict to update.
        nested_keys: The paths of keys where the value should be checked and set if needed.
        config_path (str): A string denoting the path used to find the config value in the
            sagemaker config.
        default_value: The value to use if not present elsewhere.
        sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker Session object,
            used for SageMaker interactions (default: None).

    Returns:
        The updated dictionary that should be used by the caller instead of the
        'dictionary' parameter that was passed in. If the existing value cannot be
        read (``get_nested_value`` raises ValueError), the error is logged and the
        dictionary is returned unchanged.
    """
    config_value = get_sagemaker_config_value(sagemaker_session, config_path)

    if config_value is None and default_value is None:
        # If there is nothing to set, return early. There is no need to traverse the
        # dictionary or add nested dicts to it.
        return dictionary

    try:
        current_nested_value = get_nested_value(dictionary, nested_keys)
    except ValueError as e:
        # Use the module logger (like the rest of this module), not the root logger.
        logger.error("Failed to check dictionary for applying sagemaker config: %s", e)
        return dictionary

    if current_nested_value is None:
        # Only set a value if one is not already present.
        if config_value is not None:
            dictionary = set_nested_value(dictionary, nested_keys, config_value)
        elif default_value is not None:
            dictionary = set_nested_value(dictionary, nested_keys, default_value)

    from sagemaker.core.config.config_utils import _log_sagemaker_config_single_substitution

    _log_sagemaker_config_single_substitution(current_nested_value, config_value, config_path)

    return dictionary
1441
+
1442
+
1443
def update_list_of_dicts_with_values_from_config(
    input_list,
    config_key_path,
    required_key_paths: List[str] = None,
    union_key_paths: List[List[str]] = None,
    sagemaker_session=None,
):
    """Fill in missing values of a list of dictionaries from the SageMaker config, in place.

    Items are paired by position with the list stored in the config. For each pair,
    the user's dict is merged into a copy of the config dict. The merged dict replaces
    the user's entry only if it passes two checks:

    * every path in ``required_key_paths`` is present. The config may introduce a
      parameter that the service only accepts together with another one.
    * each group in ``union_key_paths`` has at most one path set. The service may
      accept either an existing parameter or the config-supplied one, but not both.

    Args:
        input_list: The input list that was provided as a method parameter.
        config_key_path: The Key Path in the Config file that corresponds to the input_list
            parameter.
        required_key_paths (List[str]): Key paths that must exist in the merged item. If
            one is missing, that item is left unmerged.
        union_key_paths (List[List[str]]): Groups of key paths of which zero or one may be
            set. For example, if the result may have 'X1' or 'X2' or neither but not both,
            pass [['X1', 'X2']].
        sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker Session object,
            used for SageMaker interactions (default: None).

    Returns:
        None. ``input_list`` is modified in place.
    """
    if not input_list:
        return

    user_inputs_snapshot = copy.deepcopy(input_list)
    config_inputs = get_sagemaker_config_value(sagemaker_session, config_key_path) or []
    config_inputs_snapshot = copy.deepcopy(config_inputs)

    # zip() stops at the shorter list, so only positions present in both are merged.
    for position, (user_dict, merged_dict) in enumerate(zip(input_list, config_inputs)):
        merge_dicts(merged_dict, user_dict)
        if not _validate_required_paths_in_a_dict(merged_dict, required_key_paths):
            # The config introduced a parameter whose required companion is missing.
            continue
        if not _validate_union_key_paths_in_a_dict(merged_dict, union_key_paths):
            # Mutually exclusive parameters would both be set.
            continue
        input_list[position] = merged_dict

    from sagemaker.core.config.config_utils import _log_sagemaker_config_merge

    _log_sagemaker_config_merge(
        source_value=user_inputs_snapshot,
        config_value=config_inputs_snapshot,
        merged_source_and_config_value=input_list,
        config_key_path=config_key_path,
    )
1511
+
1512
+
1513
def _validate_required_paths_in_a_dict(source_dict, required_key_paths: List[str] = None) -> bool:
    """Return True when every key path in ``required_key_paths`` resolves to a non-None value.

    Args:
        source_dict (dict): The dictionary to inspect.
        required_key_paths (List[str]): Key paths that must be present. None or an empty
            list means there is nothing to check.

    Returns:
        bool: False as soon as one path resolves to None via ``get_config_value``.
    """
    if not required_key_paths:
        return True
    return all(
        get_config_value(key_path, source_dict) is not None for key_path in required_key_paths
    )
1521
+
1522
+
1523
def _validate_union_key_paths_in_a_dict(
    source_dict, union_key_paths: List[List[str]] = None
) -> bool:
    """Check that each group of mutually exclusive key paths has at most one path set.

    A key path counts as "set" when ``get_config_value`` returns a truthy value for it.

    Args:
        source_dict (dict): The dictionary to inspect.
        union_key_paths (List[List[str]]): Groups of key paths. Within each group at most
            one path may be set. None or an empty list means there is nothing to check.

    Returns:
        bool: False as soon as any group has two paths set, True otherwise.
    """
    if not union_key_paths:
        return True
    for group in union_key_paths:
        hits = 0
        for key_path in group:
            if get_config_value(key_path, source_dict):
                hits += 1
                if hits > 1:
                    return False
    return True
1537
+
1538
+
1539
def update_nested_dictionary_with_values_from_config(
    source_dict, config_key_path, sagemaker_session=None
) -> dict:
    """Fill gaps in a nested dictionary with values taken from the SageMaker config.

    Args:
        source_dict: The input nested dictionary that was provided as method parameter.
        config_key_path: The Key Path in the Config file which corresponds to this
            source_dict parameter.
        sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker Session object,
            used for SageMaker interactions (default: None).

    Returns:
        dict: ``source_dict`` exactly as given when the config holds no (or an empty)
        value at ``config_key_path``. Otherwise it returns the config value with
        ``source_dict`` merged into it by ``merge_dicts``. Presumably the source
        values take precedence; confirm against ``merge_dicts``.
    """
    merged = get_sagemaker_config_value(sagemaker_session, config_key_path) or {}
    config_snapshot = copy.deepcopy(merged)
    merge_dicts(merged, source_dict or {})

    if config_snapshot == {}:
        # No config value was used, so return source_dict unchanged (and log nothing).
        # Returning {} in place of a None source could break boto calls: e.g. an empty
        # VpcConfig fails ParamValidationError because its required fields are missing.
        return source_dict

    from sagemaker.core.config.config_utils import _log_sagemaker_config_merge

    _log_sagemaker_config_merge(
        source_value=source_dict,
        config_value=config_snapshot,
        merged_source_and_config_value=merged,
        config_key_path=config_key_path,
    )
    return merged
1581
+
1582
+
1583
def stringify_object(obj: Any) -> str:
    """Render ``obj`` as ``"<ClassName>: {attr: value, ...}"``, omitting None-valued attributes."""
    populated = {}
    for attr_name, attr_value in obj.__dict__.items():
        if attr_value is not None:
            populated[attr_name] = attr_value
    return "{}: {}".format(type(obj).__name__, populated)
1587
+
1588
+
1589
def volume_size_supported(instance_type: str) -> bool:
    """Returns True if SageMaker allows volume_size to be used for the instance type.

    Local-mode instance types (prefix ``local``) and instance types supplied as
    pipeline variables return False. Otherwise the type is parsed as
    ``[ml.]<family>.<size>``. Families containing a ``d`` (e.g. c5d, p4d) and the
    g5, g6, p5 and trn1 families do not support attaching an EBS volume.

    Args:
        instance_type (str): The instance type, e.g. ``ml.m5.xlarge``.

    Raises:
        ValueError: If the instance type is improperly formatted.
    """
    try:
        # Local mode and pipeline-variable instance types do not support volume size.
        # Keep this order: the pipeline-variable check must run before any string
        # method is called (presumably a pipeline variable is not a plain str).
        if is_pipeline_variable(instance_type) or instance_type.startswith("local"):
            return False

        parts: List[str] = instance_type.split(".")
        if len(parts) == 3 and parts[0] == "ml":
            parts = parts[1:]

        if len(parts) != 2:
            raise ValueError(f"Failed to parse instance type '{instance_type}'")

        family = parts[0]
        unsupported_families = ["g5", "g6", "p5", "trn1"]

        return "d" not in family and not any(
            family.startswith(prefix) for prefix in unsupported_families
        )
    except ValueError:
        # Already a complete parse error; re-wrapping would duplicate the message.
        raise
    except Exception as e:
        raise ValueError(f"Failed to parse instance type '{instance_type}': {str(e)}") from e
1623
+
1624
+
1625
def instance_supports_kms(instance_type: str) -> bool:
    """Report whether SageMaker allows a KMS key to be attached to ``instance_type``.

    The answer is whatever ``volume_size_supported`` returns for the same instance
    type.

    Raises:
        ValueError: If the instance type is improperly formatted.
    """
    supports_kms = volume_size_supported(instance_type)
    return supports_kms
1632
+
1633
+
1634
def get_instance_type_family(instance_type: str) -> str:
    """Return the family part of a SageMaker instance type.

    Matches either ``"ml.<family>.<size>"`` or ``"ml_<family>"``. For example,
    ``"ml.c5.xlarge"`` gives ``"c5"``. Returns an empty string for non-string input
    or when the pattern does not match.
    """
    if not isinstance(instance_type, str):
        return ""
    found = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
    return "" if found is None else found[1]
1646
+
1647
+
1648
def create_paginator_config(max_items: int = None, page_size: int = None) -> Dict[str, int]:
    """Build a boto3 ``PaginationConfig`` dict.

    Falsy values (None or 0) fall back to the module defaults ``MAX_ITEMS`` and
    ``PAGE_SIZE``.
    """
    return dict(
        MaxItems=max_items or MAX_ITEMS,
        PageSize=page_size or PAGE_SIZE,
    )
1654
+
1655
+
1656
def format_tags(tags: Tags) -> List[TagsDict]:
    """Process tags to turn them into the expected format for Sagemaker.

    A plain ``{key: value}`` mapping becomes a list of ``{"Key": str, "Value": str}``
    dicts. Any other input (e.g. an already formatted list, or None) is returned
    unchanged.
    """
    if not isinstance(tags, dict):
        return tags
    return [{"Key": str(tag_key), "Value": str(tag_value)} for tag_key, tag_value in tags.items()]
1662
+
1663
+
1664
def _get_resolved_path(path):
    """Return ``path`` as an absolute, symlink-free, normalized path string.

    Three steps are applied in order:
    abspath  - makes the path absolute without resolving symlinks
    realpath - resolves symlinks to the actual path
    normpath - normalizes the result (e.g. removes redundant separators) and
               handles platform-specific differences
    """
    absolute = abspath(path)
    symlink_free = realpath(absolute)
    return normpath(symlink_free)
1673
+
1674
+
1675
def _is_bad_path(path, base):
    """Checks if the joined path (base directory + file path) escapes the base directory.

    Ensures that the file does not attempt to access paths outside the expected
    directory structure.

    Args:
        path (str): The file path.
        base (str): The base directory, expected to be an already resolved path.

    Returns:
        bool: True if the path is not rooted under the base directory, False otherwise.
    """
    # joinpath will ignore base if path is absolute
    resolved = _get_resolved_path(joinpath(base, path))
    base = normpath(base)
    # A bare startswith(base) would accept "/data/out-evil" as inside "/data/out".
    # Require an exact match or a path separator right after the base.
    return not (resolved == base or resolved.startswith(base.rstrip(os.sep) + os.sep))
1690
+
1691
+
1692
def _is_bad_link(info, base):
    """Checks if a link's target is rooted under the base directory.

    Ensures that a symlink or hard link does not point at paths outside the
    expected extraction directory.

    Args:
        info (tarfile.TarInfo): The tar file info of a symlink or hard link.
        base (str): The (resolved) base directory the archive is extracted into.

    Returns:
        bool: True if the link target is not rooted under the base directory,
        False otherwise.
    """
    if info.issym():
        # Symlink targets are interpreted relative to the directory containing the link.
        link_dir = _get_resolved_path(joinpath(base, dirname(info.name)))
        target = joinpath(link_dir, info.linkname)
    else:
        # Hard-link targets name another archive member and are resolved relative to
        # the extraction root, not to the link's own directory.
        target = info.linkname
    # The target only has to stay inside the extraction root.
    return _is_bad_path(target, base=base)
1707
+
1708
+
1709
def _get_safe_members(members, base=None):
    """A generator that yields only the tar members that are safe to extract under ``base``.

    Members whose own path escapes ``base``, and symlinks / hard links whose target
    escapes ``base``, are logged and skipped.

    Args:
        members (Iterable[tarfile.TarInfo]): The members to check. An open
            ``tarfile.TarFile`` can be passed directly.
        base (str): The directory the members will be extracted into. When None
            (the default), the current working directory is used.

    Yields:
        tarfile.TarInfo: Each member that passed the checks.
    """
    # Resolve once so every member is compared against the same canonical root.
    base = _get_resolved_path("" if base is None else base)

    for file_info in members:
        if _is_bad_path(file_info.name, base):
            logger.error("%s is blocked (illegal path)", file_info.name)
        elif file_info.issym() and _is_bad_link(file_info, base):
            logger.error("%s is blocked: Symlink to %s", file_info.name, file_info.linkname)
        elif file_info.islnk() and _is_bad_link(file_info, base):
            logger.error("%s is blocked: Hard link to %s", file_info.name, file_info.linkname)
        else:
            yield file_info


def _validate_extracted_paths(extract_path):
    """Validate that extracted paths remain within the expected directory.

    Performs post-extraction validation. Every directory and file found by walking
    ``extract_path`` is resolved (following symlinks) and must land inside the
    resolved ``extract_path``.

    Args:
        extract_path (str): The path where files were extracted.

    Raises:
        ValueError: If any extracted file or directory resolves outside the expected
            extraction path.
    """
    base = _get_resolved_path(extract_path)
    # Compare against "<base><sep>" so a sibling such as "<base>-evil" is not taken
    # for a child of base, as a plain prefix check would do.
    base_prefix = base.rstrip(os.sep) + os.sep

    def _inside(resolved):
        """Return True if an already resolved path is base itself or below it."""
        return resolved == base or resolved.startswith(base_prefix)

    for root, dirs, files in os.walk(extract_path):
        # Check directories
        for dir_name in dirs:
            dir_path = os.path.join(root, dir_name)
            if not _inside(_get_resolved_path(dir_path)):
                logger.error("Extracted directory escaped extraction path: %s", dir_path)
                raise ValueError(f"Extracted path outside expected directory: {dir_path}")

        # Check files
        for file_name in files:
            file_path = os.path.join(root, file_name)
            if not _inside(_get_resolved_path(file_path)):
                logger.error("Extracted file escaped extraction path: %s", file_path)
                raise ValueError(f"Extracted path outside expected directory: {file_path}")


def custom_extractall_tarfile(tar, extract_path):
    """Extract a tarfile, using the ``data`` extraction filter when available.

    TODO: The function and its usages can be deprecated once SageMaker Python SDK
    is upgraded to use Python 3.12+.

    If ``tarfile`` provides ``data_filter``, extraction uses ``filter="data"``.
    Otherwise members are screened by ``_get_safe_members`` against
    ``extract_path``. After extraction, every path under ``extract_path`` is
    re-validated with ``_validate_extracted_paths``.

    Args:
        tar (tarfile.TarFile): The opened tarfile object.
        extract_path (str): The path to extract the contents of the tarfile.

    Returns:
        None

    Raises:
        ValueError: In the fallback path, if an extracted path resolves outside
            ``extract_path``.
    """
    if hasattr(tarfile, "data_filter"):
        tar.extractall(path=extract_path, filter="data")
    else:
        # Member names are relative to extract_path, so validate them against it
        # rather than against the current working directory.
        tar.extractall(path=extract_path, members=_get_safe_members(tar, base=extract_path))
        # Re-validate extracted paths to catch symlink race conditions
        _validate_extracted_paths(extract_path)
1787
+
1788
+
1789
def can_model_package_source_uri_autopopulate(source_uri: str):
    """Checks if the source_uri can lead to auto-population of information in the Model registry.

    That is the case when it matches either the model-package ARN pattern or the
    model ARN pattern.

    Args:
        source_uri (str): The source uri.

    Returns:
        bool: True if the source_uri can lead to auto-population, False otherwise.
    """
    for arn_pattern in (MODEL_PACKAGE_ARN_PATTERN, MODEL_ARN_PATTERN):
        if re.match(arn_pattern, source_uri):
            return True
    return False
1801
+
1802
+
1803
def flatten_dict(
    d: Dict[str, Any],
    max_flatten_depth=None,
) -> Dict[str, Any]:
    """Flatten a nested dictionary into a single-level dict keyed by key-path tuples.

    ``{"a": {"b": 1}, "c": 2}`` becomes ``{("a", "b"): 1, ("c",): 2}``. A nested
    empty dict is kept as a leaf value.

    Args:
        d (Dict[str, Any]): The dict that will be flattened.
        max_flatten_depth (Optional[int]): Maximum depth to flatten. ``None`` flattens
            every level. ``1`` keeps top-level values intact (keys become 1-tuples).

    Returns:
        Dict[tuple, Any]: The flattened mapping.

    Raises:
        ValueError: If ``max_flatten_depth`` is less than 1, or if two paths produce
            the same flattened key.
    """
    if max_flatten_depth is not None and max_flatten_depth < 1:
        raise ValueError("max_flatten_depth should not be less than 1.")

    flat_dict = {}

    def _flatten(_d, depth, parent=()):
        """Add ``_d``'s leaves to flat_dict under ``parent``. Return whether ``_d`` had items."""
        has_item = False
        for key, value in _d.items():
            has_item = True
            flat_key = parent + (key,)
            if isinstance(value, dict) and (max_flatten_depth is None or depth < max_flatten_depth):
                # A non-empty nested dict contributes its own leaves. An empty one
                # falls through and is stored as a value.
                if _flatten(value, depth=depth + 1, parent=flat_key):
                    continue

            if flat_key in flat_dict:
                raise ValueError("duplicated key '{}'".format(flat_key))
            flat_dict[flat_key] = value

        return has_item

    _flatten(d, depth=1)
    return flat_dict
1847
+
1848
+
1849
def nested_set_dict(d: Dict[str, Any], keys: List[str], value: Any) -> None:
    """Assign ``value`` at the key path ``keys`` inside ``d``, creating dicts as needed.

    Args:
        d: The dictionary to modify in place.
        keys: Non-empty sequence of keys. All but the last name intermediate dicts,
            which are created with ``setdefault`` when missing.
        value: The value stored under the last key.

    Raises:
        IndexError: If ``keys`` is empty.
    """
    leaf_key = keys[-1]
    node = d
    for parent_key in keys[:-1]:
        node = node.setdefault(parent_key, {})
    node[leaf_key] = value
1860
+
1861
+
1862
def unflatten_dict(d: Dict[str, Any]) -> Dict[str, Any]:
    """Rebuild a nested dict from a mapping of key-path tuples to values.

    This is the inverse of ``flatten_dict``: ``{("a", "b"): 1}`` becomes
    ``{"a": {"b": 1}}``.

    Args:
        d (Dict[str, Any]): The dict that will be unflattened. Its keys are
            sequences of keys.
    """
    rebuilt = {}
    for key_path, leaf_value in d.items():
        nested_set_dict(rebuilt, key_path, leaf_value)
    return rebuilt
1875
+
1876
+
1877
def deep_override_dict(
    dict1: Dict[str, Any], dict2: Dict[str, Any], skip_keys: Optional[List[str]] = None
) -> Dict[str, Any]:
    """Overrides any overlapping contents of dict1 with the contents of dict2.

    Both dicts are flattened to key paths first, so overriding happens leaf by leaf.
    None-valued leaves of ``dict1`` are dropped. Top-level keys of ``dict2`` listed in
    ``skip_keys`` are ignored.

    Returns:
        Dict[str, Any]: The merged nested dict, or ``{}`` when nothing remains.
    """
    excluded = [] if skip_keys is None else skip_keys

    base_leaves = {path: leaf for path, leaf in flatten_dict(dict1).items() if leaf is not None}
    override_leaves = flatten_dict(
        {top_key: top_value for top_key, top_value in dict2.items() if top_key not in excluded}
    )
    merged_leaves = {**base_leaves, **override_leaves}

    if not merged_leaves:
        return {}
    return unflatten_dict(merged_leaves)
1891
+
1892
+
1893
def _resolve_routing_config(routing_config: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Normalize a routing config to ``{"RoutingStrategy": <strategy name>}``.

    Args:
        routing_config (Optional[Dict[str, Any]]): The routing config. Its
            ``RoutingStrategy`` entry may be a ``RoutingStrategy`` member or a
            case-insensitive string naming ``RANDOM`` or ``LEAST_OUTSTANDING_REQUESTS``.

    Returns:
        Optional[Dict[str, Any]]: The resolved routing config, or None when the config
        or its ``RoutingStrategy`` entry is missing or falsy.

    Raises:
        ValueError: If the RoutingStrategy is invalid.
    """
    if not routing_config:
        return None

    strategy = routing_config.get("RoutingStrategy", None)
    if not strategy:
        return None

    if isinstance(strategy, RoutingStrategy):
        return {"RoutingStrategy": strategy.name}

    if isinstance(strategy, str):
        normalized = strategy.upper()
        if normalized in (
            RoutingStrategy.RANDOM.name,
            RoutingStrategy.LEAST_OUTSTANDING_REQUESTS.name,
        ):
            return {"RoutingStrategy": normalized}

    raise ValueError(
        "RoutingStrategy must be either RoutingStrategy.RANDOM "
        "or RoutingStrategy.LEAST_OUTSTANDING_REQUESTS"
    )
1921
+
1922
+
1923
@lru_cache
def get_instance_rate_per_hour(
    instance_type: str,
    region: str,
) -> Optional[Dict[str, str]]:
    """Gets instance rate per hour for the given instance type.

    Queries the AWS Pricing API for AmazonSageMaker products that match the instance
    type and region code. The Pricing endpoint is chosen by region prefix:
    eu/af -> eu-central-1, ap/cn -> ap-south-1, everything else -> us-east-1.
    Successful results are memoized by ``lru_cache``.

    Args:
        instance_type (str): The instance type.
        region (str): The region.
    Returns:
        Optional[Dict[str, str]]: Instance rate per hour.
            Example: {'name': 'Instance Rate', 'unit': 'USD/Hrs', 'value': '1.125'}.

    Raises:
        Exception: An exception is raised if
            the IAM role is not authorized to perform pricing:GetProducts.
            or unexpected event happened.
    """
    if region.startswith(("eu", "af")):
        pricing_region = "eu-central-1"
    elif region.startswith(("ap", "cn")):
        pricing_region = "ap-south-1"
    else:
        pricing_region = "us-east-1"

    pricing_client = boto3.client("pricing", region_name=pricing_region)
    response = pricing_client.get_products(
        ServiceCode="AmazonSageMaker",
        Filters=[
            {"Type": "TERM_MATCH", "Field": "instanceName", "Value": instance_type},
            {"Type": "TERM_MATCH", "Field": "locationType", "Value": "AWS Region"},
            {"Type": "TERM_MATCH", "Field": "regionCode", "Value": region},
        ],
    )

    products = response.get("PriceList", [])
    if products:
        # Only the first matching product is inspected. It may arrive as a JSON string.
        first_product = products[0]
        if isinstance(first_product, str):
            first_product = json.loads(first_product)

        rate = extract_instance_rate_per_hour(first_product)
        if rate is not None:
            return rate
    raise Exception(f"Unable to get instance rate per hour for instance type: {instance_type}.")
1968
+
1969
+
1970
def extract_instance_rate_per_hour(price_data: Dict[str, Any]) -> Optional[Dict[str, str]]:
    """Extract instance rate per hour for the given Price JSON data.

    Walks ``terms.OnDemand.*.priceDimensions.*.pricePerUnit`` and returns the first
    non-None price, rounded to three decimals.

    Args:
        price_data (Dict[str, Any]): The Price JSON data.
    Returns:
        Optional[Dict[str, str]]: ``{"unit": "<currency>/Hr", "value": ..., "name": ...}``,
        or None when no price is found.
    """
    if price_data is None:
        return None

    on_demand_terms = price_data.get("terms", {}).get("OnDemand", {})
    for term in on_demand_terms.values():
        for dimension in term.get("priceDimensions", {}).values():
            unit_prices = dimension.get("pricePerUnit", {})
            for currency, amount in unit_prices.items():
                if amount is None:
                    continue
                return {
                    "unit": f"{currency}/Hr",
                    "value": str(round(float(amount), 3)),
                    "name": "On-demand Instance Rate",
                }
    return None
1993
+
1994
+
1995
def camel_case_to_pascal_case(data: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively convert every snake_case key of ``data`` to PascalCase.

    Nested dicts, and dicts inside (possibly nested) lists, are converted too.
    Each ``_``-separated part is passed through ``str.capitalize``, so
    ``"foo_bar"`` becomes ``"FooBar"``.

    Args:
        data (dict): The dictionary to convert. It is not modified.

    Returns:
        dict: A new dictionary with keys in PascalCase.
    """

    def to_pascal(name):
        """Join the capitalized ``_``-separated parts of ``name``."""
        return "".join(chunk.capitalize() for chunk in name.split("_"))

    def transform(item):
        """Convert dicts (and dicts inside lists); return other values unchanged."""
        if isinstance(item, dict):
            return camel_case_to_pascal_case(item)
        if isinstance(item, list):
            return [transform(element) for element in item]
        return item

    return {to_pascal(key): transform(value) for key, value in data.items()}
2023
+
2024
+
2025
def tag_exists(tag: TagsDict, curr_tags: Optional[Tags]) -> bool:
    """Report whether a tag with the same ``Key`` as ``tag`` is already present.

    Args:
        tag (TagsDict): The tag dictionary to look for. Only its ``Key`` is compared.
        curr_tags (Optional[Tags]): The current tags, or None.

    Returns:
        bool: True if some tag in ``curr_tags`` has the same ``Key``. False otherwise,
        including when ``curr_tags`` is None.
    """
    if curr_tags is None:
        return False
    return any(tag["Key"] == existing["Key"] for existing in curr_tags)
2043
+
2044
+
2045
+ def _validate_new_tags(new_tags: Optional[Tags], curr_tags: Optional[Tags]) -> Optional[Tags]:
2046
+ """Validates new tags against existing tags.
2047
+
2048
+ Args:
2049
+ new_tags (Optional[Tags]): The new tags.
2050
+ curr_tags (Optional[Tags]): The current tags.
2051
+
2052
+ Returns:
2053
+ Optional[Tags]: The updated tags.
2054
+ """
2055
+ if curr_tags is None:
2056
+ return new_tags
2057
+
2058
+ if curr_tags and isinstance(curr_tags, dict):
2059
+ curr_tags = [curr_tags]
2060
+
2061
+ if isinstance(new_tags, dict):
2062
+ if not tag_exists(new_tags, curr_tags):
2063
+ curr_tags.append(new_tags)
2064
+ elif isinstance(new_tags, list):
2065
+ for new_tag in new_tags:
2066
+ if not tag_exists(new_tag, curr_tags):
2067
+ curr_tags.append(new_tag)
2068
+
2069
+ return curr_tags
2070
+
2071
+
2072
def remove_tag_with_key(key: str, tags: Optional[Tags]) -> Optional[Tags]:
    """Drop every tag whose ``Key`` equals ``key``.

    Args:
        key (str): The tag key to remove.
        tags (Optional[Tags]): ``None``, a single tag dictionary, or a list of
            tag dictionaries. The input is not modified.

    Returns:
        Optional[Tags]: ``None`` if no tags remain. A single tag dictionary if
        exactly one remains. Otherwise, a new list of the remaining tags.
    """
    if tags is None:
        return None

    tag_list = [tags] if isinstance(tags, dict) else tags
    remaining = [tag for tag in tag_list if tag["Key"] != key]

    if len(remaining) == 0:
        return None
    return remaining[0] if len(remaining) == 1 else remaining
2097
+
2098
+
2099
def get_domain_for_region(region: str) -> str:
    """Look up the DNS domain suffix to use for ``region``.

    Args:
        region (str): AWS region name.

    Returns:
        str: The entry for ``region`` in ``ALTERNATE_DOMAINS`` if there is one,
        otherwise ``"amazonaws.com"``.
    """
    default_domain = "amazonaws.com"
    return ALTERNATE_DOMAINS.get(region, default_domain)
2106
+
2107
+
2108
def camel_to_snake(camel_case_string: str) -> str:
    """Convert a camelCase (or PascalCase) string to snake_case.

    An underscore is inserted before every ASCII upper-case letter (``A``-``Z``)
    except one at the very start, then the whole result is lower-cased. For
    example, ``"camelCase"`` -> ``"camel_case"`` and ``"ABC"`` -> ``"a_b_c"``.

    Whitespace is not handled specially:
    ``"camelString TwoWords"`` -> ``"camel_string _two_words"``.
    """
    pieces = []
    for position, char in enumerate(camel_case_string):
        if position > 0 and "A" <= char <= "Z":
            pieces.append("_")
        pieces.append(char)
    return "".join(pieces).lower()
2114
+
2115
+
2116
def walk_and_apply_json(
    json_obj: Dict[Any, Any], apply, stop_keys: Optional[List[str]] = ["metrics"]
) -> Dict[Any, Any]:
    """Recursively walk a JSON-like object and apply ``apply`` to every dict key.

    Nested dicts and lists (including lists of lists) are walked. Every other
    value, such as strings, numbers, booleans and ``None``, is copied through
    unchanged, including when it appears inside a list.

    Args:
        json_obj (Dict[Any, Any]): The object to transform. It is not modified.
        apply (Callable): Function applied to each dict key. Its return value
            becomes the key in the output.
        stop_keys (Optional[list[str]]): Transformed keys whose values are
            copied through as-is, so ``apply`` is not applied to any of their
            children. ``None`` or an empty list means no key stops the walk.
            The default list is only read, never mutated.

    Returns:
        Dict[Any, Any]: A new object with transformed keys.
    """

    def _walk(node):
        """Return ``node`` with ``apply`` applied to the keys of every nested dict."""
        if isinstance(node, dict):
            walked = {}
            for key, value in node.items():
                new_key = apply(key)
                # stop_keys is matched against the *transformed* key.
                if stop_keys and new_key in stop_keys:
                    walked[new_key] = value
                else:
                    walked[new_key] = _walk(value)
            return walked
        if isinstance(node, list):
            # Keep every element (scalars included) and walk nested containers.
            return [_walk(item) for item in node]
        return node

    return _walk(json_obj)
2152
+
2153
+
2154
+ def _wait_until(callable_fn, poll=5):
2155
+ """Placeholder docstring"""
2156
+ elapsed_time = 0
2157
+ result = None
2158
+ while result is None:
2159
+ try:
2160
+ elapsed_time += poll
2161
+ time.sleep(poll)
2162
+ result = callable_fn()
2163
+ except botocore.exceptions.ClientError as err:
2164
+ # For initial 5 mins we accept/pass AccessDeniedException.
2165
+ # The reason is to await tag propagation to avoid false AccessDenied claims for an
2166
+ # access policy based on resource tags, The caveat here is for true AccessDenied
2167
+ # cases the routine will fail after 5 mins
2168
+ if err.response["Error"]["Code"] == "AccessDeniedException" and elapsed_time <= 300:
2169
+ logger.warning(
2170
+ "Received AccessDeniedException. This could mean the IAM role does not "
2171
+ "have the resource permissions, in which case please add resource access "
2172
+ "and retry. For cases where the role has tag based resource policy, "
2173
+ "continuing to wait for tag propagation.."
2174
+ )
2175
+ continue
2176
+ raise err
2177
+ return result
2178
+
2179
+
2180
+ def _flush_log_streams(
2181
+ stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap
2182
+ ):
2183
+ """Placeholder docstring"""
2184
+ if len(stream_names) < instance_count:
2185
+ # Log streams are created whenever a container starts writing to stdout/err, so this list
2186
+ # may be dynamic until we have a stream for every instance.
2187
+ try:
2188
+ streams = client.describe_log_streams(
2189
+ logGroupName=log_group,
2190
+ logStreamNamePrefix=job_name + "/",
2191
+ orderBy="LogStreamName",
2192
+ limit=min(instance_count, 50),
2193
+ )
2194
+ stream_names = [s["logStreamName"] for s in streams["logStreams"]]
2195
+
2196
+ while "nextToken" in streams:
2197
+ streams = client.describe_log_streams(
2198
+ logGroupName=log_group,
2199
+ logStreamNamePrefix=job_name + "/",
2200
+ orderBy="LogStreamName",
2201
+ limit=50,
2202
+ )
2203
+
2204
+ stream_names.extend([s["logStreamName"] for s in streams["logStreams"]])
2205
+
2206
+ positions.update(
2207
+ [
2208
+ (s, sagemaker.core.logs.Position(timestamp=0, skip=0))
2209
+ for s in stream_names
2210
+ if s not in positions
2211
+ ]
2212
+ )
2213
+ except ClientError as e:
2214
+ # On the very first training job run on an account, there's no log group until
2215
+ # the container starts logging, so ignore any errors thrown about that
2216
+ err = e.response.get("Error", {})
2217
+ if err.get("Code", None) != "ResourceNotFoundException":
2218
+ raise
2219
+
2220
+ if len(stream_names) > 0:
2221
+ if dot:
2222
+ print("")
2223
+ dot = False
2224
+ for idx, event in sagemaker.core.logs.multi_stream_iter(
2225
+ client, log_group, stream_names, positions
2226
+ ):
2227
+ color_wrap(idx, event["message"])
2228
+ ts, count = positions[stream_names[idx]]
2229
+ if event["timestamp"] == ts:
2230
+ positions[stream_names[idx]] = sagemaker.core.logs.Position(
2231
+ timestamp=ts, skip=count + 1
2232
+ )
2233
+ else:
2234
+ positions[stream_names[idx]] = sagemaker.core.logs.Position(
2235
+ timestamp=event["timestamp"], skip=1
2236
+ )
2237
+ else:
2238
+ dot = True
2239
+ print(".", end="")
2240
+ sys.stdout.flush()
2241
+
2242
+
2243
class LogState:
    """Integer states used while tailing a job's logs.

    ``_get_initial_job_state`` selects ``TAILING`` or ``COMPLETE`` as the
    starting state. The meaning of the remaining states is defined by the code
    that consumes them (NOTE(review): confirm against callers).
    """

    STARTING = 1
    WAIT_IN_PROGRESS = 2
    TAILING = 3
    JOB_COMPLETE = 4
    COMPLETE = 5
2251
+
2252
+
2253
# Maps upper-case status codes (e.g. "IN_PROGRESS") to the CamelCase statuses
# that the rest of this module compares against. _check_job_status looks a
# status up here and falls back to the original value when there is no entry.
_STATUS_CODE_TABLE = {
    "COMPLETED": "Completed",
    "INPROGRESS": "InProgress",
    "IN_PROGRESS": "InProgress",
    "FAILED": "Failed",
    "STOPPED": "Stopped",
    "STOPPING": "Stopping",
    "STARTING": "Starting",
    "PENDING": "Pending",
}
2263
+
2264
+
2265
def _get_initial_job_state(description, status_key, wait):
    """Choose the log-tailing state to start in.

    Args:
        description (dict): The job's describe response.
        status_key (str): Key in ``description`` that holds the job status.
        wait (bool): If falsy, tailing is skipped and ``LogState.COMPLETE`` is returned.

    Returns:
        int: ``LogState.TAILING`` when ``wait`` is truthy and the job's status is
        not ``Completed``, ``Failed`` or ``Stopped``. Otherwise ``LogState.COMPLETE``.
    """
    raw_status = description[status_key]
    # Normalize upper-case status codes (e.g. "COMPLETED") the same way
    # _check_job_status does, so finished jobs are recognized in either form.
    status = _STATUS_CODE_TABLE.get(raw_status, raw_status)
    job_already_completed = status in ("Completed", "Failed", "Stopped")
    return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
2270
+
2271
+
2272
+ def _logs_init(boto_session, description, job):
2273
+ """Placeholder docstring"""
2274
+ if job == "Training":
2275
+ if "InstanceGroups" in description["ResourceConfig"]:
2276
+ instance_count = 0
2277
+ for instanceGroup in description["ResourceConfig"]["InstanceGroups"]:
2278
+ instance_count += instanceGroup["InstanceCount"]
2279
+ else:
2280
+ instance_count = description["ResourceConfig"]["InstanceCount"]
2281
+ elif job == "Transform":
2282
+ instance_count = description["TransformResources"]["InstanceCount"]
2283
+ elif job == "Processing":
2284
+ instance_count = description["ProcessingResources"]["ClusterConfig"]["InstanceCount"]
2285
+ elif job == "AutoML":
2286
+ instance_count = 0
2287
+
2288
+ stream_names = [] # The list of log streams
2289
+ positions = {} # The current position in each stream, map of stream name -> position
2290
+
2291
+ # Increase retries allowed (from default of 4), as we don't want waiting for a training job
2292
+ # to be interrupted by a transient exception.
2293
+ config = botocore.config.Config(retries={"max_attempts": 15})
2294
+ client = boto_session.client("logs", config=config)
2295
+ log_group = "/aws/sagemaker/" + job + "Jobs"
2296
+
2297
+ dot = False
2298
+
2299
+ color_wrap = sagemaker.core.logs.ColorWrap()
2300
+
2301
+ return instance_count, stream_names, positions, client, log_group, dot, color_wrap
2302
+
2303
+
2304
def _check_job_status(job, desc, status_key_name):
    """Raise if a finished job did not end in ``Completed``; ``Stopped`` only warns.

    Upper-case status codes (e.g. ``"IN_PROGRESS"``) are first mapped to their
    CamelCase form via ``_STATUS_CODE_TABLE``.

    Args:
        job (str): Name of the job. It is included in the error message.
        desc (dict): The job's describe response, for example the result of
            ``describe_training_job()``.
        status_key_name (str): Key in ``desc`` that holds the job status.

    Raises:
        exceptions.CapacityError: If the job ended in an unexpected status and
            its ``FailureReason`` mentions ``CapacityError``.
        exceptions.UnexpectedStatusException: If the job ended in any status
            other than ``Completed`` or ``Stopped``.
    """
    raw_status = desc[status_key_name]
    status = _STATUS_CODE_TABLE.get(raw_status, raw_status)

    if status == "Completed":
        return

    if status == "Stopped":
        logger.warning(
            "Job ended with status 'Stopped' rather than 'Completed'. "
            "This could mean the job timed out or stopped early for some other reason: "
            "Consider checking whether it completed as you expect."
        )
        return

    reason = desc.get("FailureReason", "(No reason provided)")
    # e.g. "TrainingJobStatus" -> "Training job"
    job_type = status_key_name.replace("JobStatus", " job")
    troubleshooting = (
        "https://docs.aws.amazon.com/sagemaker/latest/dg/"
        "sagemaker-python-sdk-troubleshooting.html"
    )
    message = (
        f"Error for {job_type} {job}: {status}. Reason: {reason}. "
        f"Check troubleshooting guide for common errors: {troubleshooting}"
    )
    error_cls = (
        exceptions.CapacityError
        if "CapacityError" in str(reason)
        else exceptions.UnexpectedStatusException
    )
    raise error_cls(
        message=message,
        allowed_statuses=["Completed", "Stopped"],
        actual_status=status,
    )
2356
+
2357
+
2358
+ def _create_resource(create_fn):
2359
+ """Call create function and accepts/pass when resource already exists.
2360
+
2361
+ This is a helper function to use an existing resource if found when creating.
2362
+
2363
+ Args:
2364
+ create_fn: Create resource function.
2365
+
2366
+ Returns:
2367
+ (bool): True if new resource was created, False if resource already exists.
2368
+ """
2369
+ try:
2370
+ create_fn()
2371
+ # create function succeeded, resource does not exist already
2372
+ return True
2373
+ except ClientError as ce:
2374
+ error_code = ce.response["Error"]["Code"]
2375
+ error_message = ce.response["Error"]["Message"]
2376
+ already_exists_exceptions = ["ValidationException", "ResourceInUse"]
2377
+ already_exists_msg_patterns = ["Cannot create already existing", "already exists"]
2378
+ if not (
2379
+ error_code in already_exists_exceptions
2380
+ and any(p in error_message for p in already_exists_msg_patterns)
2381
+ ):
2382
+ raise ce
2383
+ # no new resource created as resource already exists
2384
+ return False
2385
+
2386
+
2387
+ def _is_s3_uri(s3_uri: Optional[str]) -> bool:
2388
+ """Checks whether an S3 URI is valid.
2389
+
2390
+ Args:
2391
+ s3_uri (Optional[str]): The S3 URI.
2392
+
2393
+ Returns:
2394
+ bool: Whether the S3 URI is valid.
2395
+ """
2396
+ if s3_uri is None:
2397
+ return False
2398
+
2399
+ return re.match("^s3://([^/]+)/?(.*)$", s3_uri) is not None