sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (363) hide show
  1. sagemaker/core/__init__.py +16 -0
  2. sagemaker/core/_studio.py +116 -0
  3. sagemaker/core/_version.py +11 -0
  4. sagemaker/core/accept_types.py +131 -0
  5. sagemaker/core/analytics.py +744 -0
  6. sagemaker/core/apiutils/__init__.py +13 -0
  7. sagemaker/core/apiutils/_base_types.py +228 -0
  8. sagemaker/core/apiutils/_boto_functions.py +130 -0
  9. sagemaker/core/apiutils/_utils.py +34 -0
  10. sagemaker/core/base_deserializers.py +35 -0
  11. sagemaker/core/base_serializers.py +35 -0
  12. sagemaker/core/clarify/__init__.py +2898 -0
  13. sagemaker/core/collection.py +467 -0
  14. sagemaker/core/common_utils.py +2281 -0
  15. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  16. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  17. sagemaker/core/config/__init__.py +181 -0
  18. sagemaker/core/config/config.py +238 -0
  19. sagemaker/core/config/config_manager.py +595 -0
  20. sagemaker/core/config/config_schema.py +1220 -0
  21. sagemaker/core/config/config_utils.py +297 -0
  22. {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
  23. sagemaker/core/constants.py +73 -0
  24. sagemaker/core/content_types.py +137 -0
  25. sagemaker/core/debugger/__init__.py +39 -0
  26. sagemaker/core/debugger/debugger.py +945 -0
  27. sagemaker/core/debugger/framework_profile.py +292 -0
  28. sagemaker/core/debugger/metrics_config.py +468 -0
  29. sagemaker/core/debugger/profiler.py +42 -0
  30. sagemaker/core/debugger/profiler_config.py +190 -0
  31. sagemaker/core/debugger/profiler_constants.py +40 -0
  32. sagemaker/core/debugger/utils.py +148 -0
  33. sagemaker/core/deprecations.py +254 -0
  34. sagemaker/core/deserializers/__init__.py +10 -0
  35. sagemaker/core/deserializers/base.py +424 -0
  36. sagemaker/core/deserializers/implementations.py +157 -0
  37. sagemaker/core/drift_check_baselines.py +106 -0
  38. sagemaker/core/enums.py +51 -0
  39. sagemaker/core/environment_variables.py +101 -0
  40. sagemaker/core/exceptions.py +108 -0
  41. sagemaker/core/experiments/__init__.py +53 -0
  42. sagemaker/core/experiments/_api_types.py +251 -0
  43. sagemaker/core/experiments/_environment.py +124 -0
  44. sagemaker/core/experiments/_helper.py +294 -0
  45. sagemaker/core/experiments/_metrics.py +333 -0
  46. sagemaker/core/experiments/_run_context.py +58 -0
  47. sagemaker/core/experiments/_utils.py +216 -0
  48. sagemaker/core/experiments/experiment.py +244 -0
  49. sagemaker/core/experiments/run.py +970 -0
  50. sagemaker/core/experiments/trial.py +296 -0
  51. sagemaker/core/experiments/trial_component.py +387 -0
  52. sagemaker/core/explainer/__init__.py +24 -0
  53. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  54. sagemaker/core/explainer/explainer_config.py +44 -0
  55. sagemaker/core/fw_utils.py +1176 -0
  56. sagemaker/core/git_utils.py +349 -0
  57. sagemaker/core/helper/pipeline_variable.py +82 -0
  58. sagemaker/core/helper/session_helper.py +2965 -0
  59. sagemaker/core/huggingface/__init__.py +29 -0
  60. sagemaker/core/huggingface/llm_utils.py +150 -0
  61. sagemaker/core/huggingface/processing.py +139 -0
  62. sagemaker/core/huggingface/training_compiler/config.py +167 -0
  63. sagemaker/core/hyperparameters.py +172 -0
  64. sagemaker/core/image_retriever/__init__.py +3 -0
  65. sagemaker/core/image_retriever/image_retriever.py +640 -0
  66. sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
  67. sagemaker/core/image_retriever/test.py +7 -0
  68. sagemaker/core/image_uri_config/__init__.py +13 -0
  69. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  70. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  71. sagemaker/core/image_uri_config/chainer.json +104 -0
  72. sagemaker/core/image_uri_config/clarify.json +39 -0
  73. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  74. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  75. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  76. sagemaker/core/image_uri_config/debugger.json +34 -0
  77. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  78. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  79. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  80. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  81. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  82. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  83. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  84. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  85. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
  86. sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
  87. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  88. sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
  89. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  90. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  91. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  92. sagemaker/core/image_uri_config/huggingface.json +2138 -0
  93. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  94. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  95. sagemaker/core/image_uri_config/image-classification.json +50 -0
  96. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  97. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  98. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  99. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  100. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  101. sagemaker/core/image_uri_config/kmeans.json +50 -0
  102. sagemaker/core/image_uri_config/knn.json +50 -0
  103. sagemaker/core/image_uri_config/lda.json +26 -0
  104. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  105. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  106. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  107. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  108. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  109. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  110. sagemaker/core/image_uri_config/ntm.json +50 -0
  111. sagemaker/core/image_uri_config/object-detection.json +50 -0
  112. sagemaker/core/image_uri_config/object2vec.json +50 -0
  113. sagemaker/core/image_uri_config/pca.json +50 -0
  114. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  115. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  116. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  117. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  118. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  119. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  120. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  121. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  122. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  123. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  124. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
  125. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  126. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  127. sagemaker/core/image_uri_config/sklearn.json +446 -0
  128. sagemaker/core/image_uri_config/spark.json +280 -0
  129. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  130. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  131. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  132. sagemaker/core/image_uri_config/vw.json +25 -0
  133. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  134. sagemaker/core/image_uri_config/xgboost.json +888 -0
  135. sagemaker/core/image_uris.py +810 -0
  136. sagemaker/core/inference_config.py +144 -0
  137. sagemaker/core/inference_recommender/__init__.py +18 -0
  138. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  139. sagemaker/core/inputs.py +366 -0
  140. sagemaker/core/instance_group.py +61 -0
  141. sagemaker/core/instance_types.py +164 -0
  142. sagemaker/core/instance_types_gpu_info.py +43 -0
  143. sagemaker/core/interactive_apps/__init__.py +41 -0
  144. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  145. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  146. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  147. sagemaker/core/iterators.py +186 -0
  148. sagemaker/core/job.py +380 -0
  149. sagemaker/core/jumpstart/__init__.py +156 -0
  150. sagemaker/core/jumpstart/accessors.py +390 -0
  151. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  152. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  153. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  154. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  155. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  156. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  157. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  158. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  159. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  160. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  161. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  162. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  163. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  164. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  165. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  166. sagemaker/core/jumpstart/cache.py +663 -0
  167. sagemaker/core/jumpstart/configs.py +50 -0
  168. sagemaker/core/jumpstart/constants.py +198 -0
  169. sagemaker/core/jumpstart/deserializers.py +81 -0
  170. sagemaker/core/jumpstart/document.py +76 -0
  171. sagemaker/core/jumpstart/enums.py +168 -0
  172. sagemaker/core/jumpstart/exceptions.py +236 -0
  173. sagemaker/core/jumpstart/factory/utils.py +833 -0
  174. sagemaker/core/jumpstart/filters.py +597 -0
  175. sagemaker/core/jumpstart/hub/__init__.py +0 -0
  176. sagemaker/core/jumpstart/hub/constants.py +16 -0
  177. sagemaker/core/jumpstart/hub/hub.py +291 -0
  178. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  179. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  180. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  181. sagemaker/core/jumpstart/hub/types.py +35 -0
  182. sagemaker/core/jumpstart/hub/utils.py +260 -0
  183. sagemaker/core/jumpstart/models.py +499 -0
  184. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  185. sagemaker/core/jumpstart/parameters.py +20 -0
  186. sagemaker/core/jumpstart/payload_utils.py +239 -0
  187. sagemaker/core/jumpstart/region_config.json +163 -0
  188. sagemaker/core/jumpstart/search.py +171 -0
  189. sagemaker/core/jumpstart/serializers.py +81 -0
  190. sagemaker/core/jumpstart/session_utils.py +234 -0
  191. sagemaker/core/jumpstart/types.py +3044 -0
  192. sagemaker/core/jumpstart/utils.py +1731 -0
  193. sagemaker/core/jumpstart/validators.py +257 -0
  194. sagemaker/core/lambda_helper.py +312 -0
  195. sagemaker/core/lineage/__init__.py +42 -0
  196. sagemaker/core/lineage/_api_types.py +239 -0
  197. sagemaker/core/lineage/_utils.py +49 -0
  198. sagemaker/core/lineage/action.py +345 -0
  199. sagemaker/core/lineage/artifact.py +646 -0
  200. sagemaker/core/lineage/association.py +190 -0
  201. sagemaker/core/lineage/context.py +505 -0
  202. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  203. sagemaker/core/lineage/query.py +732 -0
  204. sagemaker/core/lineage/visualizer.py +346 -0
  205. sagemaker/core/local/__init__.py +18 -0
  206. sagemaker/core/local/data.py +413 -0
  207. sagemaker/core/local/entities.py +678 -0
  208. sagemaker/core/local/exceptions.py +17 -0
  209. sagemaker/core/local/image.py +1243 -0
  210. sagemaker/core/local/local_session.py +739 -0
  211. sagemaker/core/local/utils.py +245 -0
  212. sagemaker/core/logs.py +181 -0
  213. sagemaker/core/metadata_properties.py +56 -0
  214. sagemaker/core/metric_definitions.py +91 -0
  215. sagemaker/core/mlflow/__init__.py +38 -0
  216. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  217. sagemaker/core/model_card/__init__.py +26 -0
  218. sagemaker/core/model_life_cycle.py +51 -0
  219. sagemaker/core/model_metrics.py +160 -0
  220. sagemaker/core/model_monitor/__init__.py +66 -0
  221. sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
  222. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  223. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  224. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  225. sagemaker/core/model_monitor/dataset_format.py +102 -0
  226. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  227. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  228. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  229. sagemaker/core/model_monitor/utils.py +793 -0
  230. sagemaker/core/model_registry.py +480 -0
  231. sagemaker/core/model_uris.py +97 -0
  232. sagemaker/core/modules/__init__.py +19 -0
  233. sagemaker/core/modules/configs.py +226 -0
  234. sagemaker/core/modules/constants.py +37 -0
  235. sagemaker/core/modules/distributed.py +182 -0
  236. sagemaker/core/modules/local_core/__init__.py +0 -0
  237. sagemaker/core/modules/local_core/local_container.py +605 -0
  238. sagemaker/core/modules/templates.py +83 -0
  239. sagemaker/core/modules/train/__init__.py +14 -0
  240. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  241. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  242. sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
  243. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  244. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  245. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  246. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  247. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  248. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  249. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  250. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  251. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  252. sagemaker/core/modules/types.py +19 -0
  253. sagemaker/core/modules/utils.py +194 -0
  254. sagemaker/core/network.py +185 -0
  255. sagemaker/core/parameter.py +173 -0
  256. sagemaker/core/payloads.py +185 -0
  257. sagemaker/core/processing.py +1597 -0
  258. sagemaker/core/remote_function/__init__.py +19 -0
  259. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  260. sagemaker/core/remote_function/client.py +1285 -0
  261. sagemaker/core/remote_function/core/__init__.py +0 -0
  262. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  263. sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
  264. sagemaker/core/remote_function/core/serialization.py +422 -0
  265. sagemaker/core/remote_function/core/stored_function.py +226 -0
  266. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  267. sagemaker/core/remote_function/errors.py +104 -0
  268. sagemaker/core/remote_function/invoke_function.py +172 -0
  269. sagemaker/core/remote_function/job.py +2140 -0
  270. sagemaker/core/remote_function/logging_config.py +38 -0
  271. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  272. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  273. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  274. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  275. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  276. sagemaker/core/remote_function/spark_config.py +149 -0
  277. sagemaker/core/resource_requirements.py +168 -0
  278. {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
  279. sagemaker/core/s3/__init__.py +41 -0
  280. sagemaker/core/s3/client.py +367 -0
  281. sagemaker/core/s3/utils.py +175 -0
  282. sagemaker/core/script_uris.py +93 -0
  283. sagemaker/core/serializers/__init__.py +11 -0
  284. sagemaker/core/serializers/base.py +510 -0
  285. sagemaker/core/serializers/implementations.py +159 -0
  286. sagemaker/core/serializers/utils.py +223 -0
  287. sagemaker/core/serverless_inference_config.py +63 -0
  288. sagemaker/core/session_settings.py +55 -0
  289. sagemaker/core/shapes/__init__.py +3 -0
  290. sagemaker/core/shapes/model_card_shapes.py +159 -0
  291. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
  292. sagemaker/core/spark/__init__.py +16 -0
  293. sagemaker/core/spark/defaults.py +16 -0
  294. sagemaker/core/spark/processing.py +1380 -0
  295. sagemaker/core/telemetry/__init__.py +23 -0
  296. sagemaker/core/telemetry/constants.py +84 -0
  297. sagemaker/core/telemetry/telemetry_logging.py +284 -0
  298. sagemaker/core/tools/__init__.py +1 -0
  299. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  300. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  301. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  302. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  303. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  304. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  305. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  306. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  307. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  308. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  309. sagemaker/core/training/__init__.py +14 -0
  310. sagemaker/core/training/configs.py +333 -0
  311. sagemaker/core/training/constants.py +37 -0
  312. sagemaker/core/training/utils.py +77 -0
  313. sagemaker/core/training_compiler/__init__.py +16 -0
  314. sagemaker/core/training_compiler/config.py +197 -0
  315. sagemaker/core/training_compiler_config.py +197 -0
  316. sagemaker/core/transformer.py +793 -0
  317. sagemaker/core/user_agent.py +76 -0
  318. sagemaker/core/utilities/__init__.py +24 -0
  319. sagemaker/core/utilities/cache.py +169 -0
  320. sagemaker/core/utilities/search_expression.py +133 -0
  321. sagemaker/core/utils/__init__.py +48 -0
  322. sagemaker/core/utils/code_injection/__init__.py +0 -0
  323. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  324. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
  325. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  326. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  327. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  328. {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
  329. sagemaker/core/workflow/__init__.py +152 -0
  330. sagemaker/core/workflow/conditions.py +313 -0
  331. sagemaker/core/workflow/entities.py +58 -0
  332. sagemaker/core/workflow/execution_variables.py +89 -0
  333. sagemaker/core/workflow/functions.py +193 -0
  334. sagemaker/core/workflow/parameters.py +222 -0
  335. sagemaker/core/workflow/pipeline_context.py +394 -0
  336. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  337. sagemaker/core/workflow/properties.py +285 -0
  338. sagemaker/core/workflow/step_outputs.py +65 -0
  339. sagemaker/core/workflow/utilities.py +507 -0
  340. sagemaker/lineage/__init__.py +33 -0
  341. sagemaker/lineage/action.py +28 -0
  342. sagemaker/lineage/artifact.py +28 -0
  343. sagemaker/lineage/context.py +28 -0
  344. sagemaker/lineage/lineage_trial_component.py +28 -0
  345. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
  346. sagemaker_core-2.1.1.dist-info/RECORD +355 -0
  347. sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
  348. sagemaker_core/__init__.py +0 -4
  349. sagemaker_core/_version.py +0 -3
  350. sagemaker_core/helper/session_helper.py +0 -769
  351. sagemaker_core/resources/__init__.py +0 -1
  352. sagemaker_core/shapes/__init__.py +0 -1
  353. sagemaker_core/tools/__init__.py +0 -1
  354. sagemaker_core-1.0.47.dist-info/RECORD +0 -35
  355. sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
  356. {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
  357. {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
  358. {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  359. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  360. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  361. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  362. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
  363. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,2281 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Placeholder docstring"""
14
+ from __future__ import absolute_import
15
+
16
+ import sys
17
+ import contextlib
18
+ import copy
19
+ import errno
20
+ import inspect
21
+ import logging
22
+ import os
23
+ import random
24
+ import re
25
+ import shutil
26
+ import tarfile
27
+ import tempfile
28
+ import time
29
+ from functools import lru_cache
30
+ from typing import Union, Any, List, Optional, Dict
31
+ import json
32
+ import abc
33
+ import uuid
34
+ from datetime import datetime
35
+ from os.path import abspath, realpath, dirname, normpath, join as joinpath
36
+
37
+ from importlib import import_module
38
+
39
+ import boto3
40
+ import botocore
41
+ from botocore.utils import merge_dicts
42
+ from botocore import exceptions
43
+ from botocore.exceptions import ClientError
44
+ from six.moves.urllib import parse
45
+ from six import viewitems
46
+
47
+ import sagemaker
48
+
49
+ from sagemaker.core.enums import RoutingStrategy
50
+ from sagemaker.core.session_settings import SessionSettings
51
+ from sagemaker.core.workflow import is_pipeline_variable, is_pipeline_parameter_string
52
+ from sagemaker.core.helper.pipeline_variable import PipelineVariable
53
+ from enum import Enum
54
+
55
# Maps region names to the domain suffix used for that region.
ALTERNATE_DOMAINS = {
    "cn-north-1": "amazonaws.com.cn",
    "cn-northwest-1": "amazonaws.com.cn",
    "us-iso-east-1": "c2s.ic.gov",
    "us-isob-east-1": "sc2s.sgov.gov",
    "us-isof-south-1": "csp.hci.ic.gov",
    "us-isof-east-1": "csp.hci.ic.gov",
}

# Regular expressions for ECR image URIs and SageMaker model / model-package ARNs.
ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$"
MODEL_PACKAGE_ARN_PATTERN = (
    r"arn:aws([a-z\-]*)?:sagemaker:([a-z0-9\-]*):([0-9]{12}):model-package/(.*)"
)
MODEL_ARN_PATTERN = r"arn:aws([a-z\-]*):sagemaker:([a-z0-9\-]*):([0-9]{12}):model/(.*)"

MAX_BUCKET_PATHS_COUNT = 5

# URI scheme prefixes.
S3_PREFIX = "s3://"
HTTP_PREFIX = "http://"
HTTPS_PREFIX = "https://"

# Polling and pagination defaults.
DEFAULT_SLEEP_TIME_SECONDS = 10
WAITING_DOT_NUMBER = 10
MAX_ITEMS = 100
PAGE_SIZE = 10

logger = logging.getLogger(__name__)

# A tag collection is either a single dict or a list of such dicts.
TagsDict = Dict[str, Union[str, PipelineVariable]]
Tags = Union[List[TagsDict], TagsDict]
82
+
83
+
84
class ModelApprovalStatusEnum(str, Enum):
    """Approval states of a model package.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value.
    """

    APPROVED = "Approved"
    REJECTED = "Rejected"
    PENDING_MANUAL_APPROVAL = "PendingManualApproval"
90
+
91
+
92
+ # Use the base name of the image as the job name if the user doesn't give us one
93
def name_from_image(image, max_length=63):
    """Build a timestamped training job name from an image name.

    Args:
        image (str): Image name.
        max_length (int): Maximum length for the resulting string (default: 63).

    Returns:
        str: The algorithm name extracted from the image name, followed by a
            timestamp.
    """
    algorithm_name = base_name_from_image(image)
    return name_from_base(algorithm_name, max_length=max_length)
105
+
106
+
107
def name_from_base(base, max_length=63, short=False):
    """Return ``base`` followed by a hyphen and a timestamp.

    ``base`` is trimmed as needed so that the prefix, the hyphen and the
    timestamp together fit within ``max_length``.

    Args:
        base (str): String used as prefix to generate the unique name.
        max_length (int): Maximum length for the resulting string (default: 63).
        short (bool): Whether or not to use a truncated timestamp (default: False).

    Returns:
        str: Input parameter with appended timestamp.
    """
    if short:
        timestamp = sagemaker_short_timestamp()
    else:
        timestamp = sagemaker_timestamp()

    # Reserve room for the timestamp plus the joining hyphen.
    prefix_budget = max_length - len(timestamp) - 1
    return f"{base[:prefix_budget]}-{timestamp}"
125
+
126
+
127
def unique_name_from_base_uuid4(base, max_length=63):
    """Append a UUID to the provided string.

    This function is used to generate a name using UUID instead of timestamps
    for uniqueness. The input is trimmed so that the prefix, the joining hyphen
    and the 36-character UUID fit within ``max_length``.

    Args:
        base (str): String used as prefix to generate the unique name.
        max_length (int): Maximum length for the resulting string (default: 63).

    Returns:
        str: Input parameter with an appended random (version 4) UUID.
    """
    # uuid4() is already random; the shared ``random`` module state is left
    # untouched so a seed set by the caller is not clobbered.
    unique = str(uuid.uuid4())
    trimmed_base = base[: max_length - len(unique) - 1]
    return "{}-{}".format(trimmed_base, unique)
144
+
145
+
146
def unique_name_from_base(base, max_length=63):
    """Append a Unix timestamp and a random 4-digit hex suffix to the provided string.

    The result has the form ``<base>-<seconds since epoch>-<4 hex digits>``.
    The input is trimmed so that all three parts and the two hyphens fit
    within ``max_length``.

    Args:
        base (str): String used as prefix to generate the unique name.
        max_length (int): Maximum length for the resulting string (default: 63).

    Returns:
        str: Input parameter with the timestamp and random suffix appended.
    """
    # Take the suffix straight from a fresh UUID instead of reseeding the
    # shared ``random`` module state, which would clobber a caller-set seed.
    unique = uuid.uuid4().hex[:4]  # 4-digit hex
    ts = str(int(time.time()))
    available_length = max_length - 2 - len(ts) - len(unique)
    trimmed = base[:available_length]
    return "{}-{}-{}".format(trimmed, ts, unique)
154
+
155
+
156
def base_name_from_image(image, default_base_name=None):
    """Derive the 'algorithm name' for a job from an image name.

    For a pipeline variable, the name is taken from its default value when the
    variable is a pipeline parameter string with a non-empty default;
    otherwise ``default_base_name`` (or ``"base_name"``) is returned.

    Args:
        image (str): Image name.
        default_base_name (str): The default base name

    Returns:
        str: Algorithm name, as extracted from the image name.
    """
    image_str = image
    if is_pipeline_variable(image):
        has_usable_default = is_pipeline_parameter_string(image) and image.default_value
        if not has_usable_default:
            return default_base_name or "base_name"
        image_str = image.default_value

    # Group 2 is the last path component with any ":tag" suffix removed.
    matched = re.match("^(.+/)?([^:/]+)(:[^:]+)?$", image_str)
    if matched:
        return matched.group(2)
    return image_str
177
+
178
+
179
def base_from_name(name):
    """Strip a trailing-style timestamp from a resource name to recover its base.

    Two timestamp layouts are recognised after a hyphen: the long
    ``YYYY-MM-DD-HH-MM-SS-mmm`` form and the short ``YYMMDD-HHMM`` form.
    Names with no such timestamp are returned unchanged.

    Args:
        name (str): The resource name.

    Returns:
        str: The base name, as extracted from the resource name.
    """
    found = re.match(r"^(.+)-(\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}-\d{3}|\d{6}-\d{4})", name)
    if found is None:
        return name
    return found.group(1)
193
+
194
+
195
def sagemaker_timestamp():
    """Return a UTC timestamp with millisecond precision.

    Returns:
        str: Timestamp formatted as ``YYYY-MM-DD-HH-MM-SS-mmm``, where ``mmm``
            is always exactly three digits.
    """
    moment = time.time()
    # repr() omits trailing zeros (e.g. ``...20.5``), so pad the fractional
    # digits on the right to always get three millisecond digits.
    moment_ms = repr(moment).split(".")[1][:3].ljust(3, "0")
    return time.strftime("%Y-%m-%d-%H-%M-%S-{}".format(moment_ms), time.gmtime(moment))
200
+
201
+
202
def sagemaker_short_timestamp():
    """Return a compact ``YYMMDD-HHMM`` timestamp based on local time."""
    now = time.localtime()
    return time.strftime("%y%m%d-%H%M", now)
205
+
206
+
207
def build_dict(key, value):
    """Wrap a key and value in a single-entry dict, skipping falsy values.

    Any falsy value (``None``, ``""``, ``0``, empty containers, ...) produces
    an empty dict.

    Args:
        key (str): input key
        value (str): input value

    Returns:
        dict: dict of key and value or an empty dict.
    """
    return {key: value} if value else {}
220
+
221
+
222
def get_config_value(key_path, config):
    """Look up a dot-separated key path in a nested config.

    Args:
        key_path (str): Keys joined by ``"."``, outermost first.
        config: Nested container to search; ``None`` yields ``None``.

    Returns:
        The value found at ``key_path``, or ``None`` if any key along the
        path is missing.
    """
    if config is None:
        return None

    node = config
    for part in key_path.split("."):
        if part not in node:
            return None
        node = node[part]
    return node
235
+
236
+
237
def get_nested_value(dictionary: dict, nested_keys: List[str]):
    """Returns a nested value from the given dictionary, and None if none present.

    Args:
        dictionary (dict): Dictionary to search. ``None`` or a non-dict yields ``None``.
        nested_keys (List[str]): Keys to follow, outermost first. ``None`` or an
            empty list yields ``None``.

    Returns:
        The value stored under the full key path, or ``None`` if any key along
        the path is missing or maps to ``None``.

    Raises:
        ValueError: If a value at an intermediate key is neither ``None`` nor a dict.
    """
    if not isinstance(dictionary, dict) or nested_keys is None or len(nested_keys) == 0:
        return None

    current_section = dictionary
    for key in nested_keys[:-1]:
        current_section = current_section.get(key, None)
        if current_section is None:
            # means the full path of nested_keys doesnt exist in the dictionary
            # or the value was set to None
            return None
        if not isinstance(current_section, dict):
            # One message string, so str(exc) reads as a sentence and not a tuple.
            raise ValueError(
                "Unexpected structure of dictionary. "
                "Expected value of type dict at key '{}' but got '{}' for dict '{}'".format(
                    key, current_section, dictionary
                )
            )
    return current_section.get(nested_keys[-1], None)
268
+
269
+
270
def set_nested_value(dictionary: dict, nested_keys: List[str], value_to_set: object):
    """Store a value under a path of nested keys, creating levels as needed.

    The dictionary is modified in place and also returned; a new dictionary is
    created when ``None`` is passed. Any intermediate entry that is missing,
    ``None`` or not a dict is replaced with a fresh dict, so an unintended key
    path can overwrite existing data. Consider checking with
    ``get_nested_value`` first.

    Args:
        dictionary (dict): Dictionary to update, or ``None`` to start from empty.
        nested_keys (List[str]): Keys to follow, outermost first.
        value_to_set (object): Value stored under the last key.

    Returns:
        The updated dictionary. The input is returned unchanged when it is not
        a dict or when ``nested_keys`` is ``None`` or empty.
    """
    if dictionary is None:
        dictionary = {}

    if not isinstance(dictionary, dict) or nested_keys is None or len(nested_keys) == 0:
        return dictionary

    *parent_keys, leaf_key = nested_keys
    node = dictionary
    for parent in parent_keys:
        # Anything that is not already a dict at this level gets replaced.
        if not isinstance(node.get(parent), dict):
            node[parent] = {}
        node = node[parent]

    node[leaf_key] = value_to_set
    return dictionary
299
+
300
+
301
def get_short_version(framework_version):
    """Shorten a dotted version string to at most its first two components.

    Args:
        framework_version (str): The version string to be shortened, e.g. ``"1.15.2"``.

    Returns:
        str: The leading components joined by ``"."``, e.g. ``"1.15"``.
    """
    leading_components = framework_version.split(".")[:2]
    return ".".join(leading_components)
311
+
312
+
313
def secondary_training_status_changed(current_job_description, prev_job_description):
    """Tell whether the latest secondary status message differs from the previous one.

    Args:
        current_job_description (dict): Current job description, returned from
            a DescribeTrainingJob call.
        prev_job_description (dict): Previous job description, returned from a
            DescribeTrainingJob call. May be ``None``.

    Returns:
        bool: ``False`` when the current description has no secondary status
        transitions; otherwise whether the last ``StatusMessage`` differs from
        the previous description's last ``StatusMessage`` (treated as ``""``
        when there is no previous transition).
    """
    current_transitions = current_job_description.get("SecondaryStatusTransitions")
    if current_transitions is None or len(current_transitions) == 0:
        return False

    previous_transitions = None
    if prev_job_description is not None:
        previous_transitions = prev_job_description.get("SecondaryStatusTransitions")

    if previous_transitions is not None and len(previous_transitions) > 0:
        previous_message = previous_transitions[-1]["StatusMessage"]
    else:
        previous_message = ""

    latest_message = current_job_description["SecondaryStatusTransitions"][-1]["StatusMessage"]
    return latest_message != previous_message
347
+
348
+
349
def secondary_training_status_message(job_description, prev_description):
    """Build the printable secondary-status lines for a training job.

    Args:
        job_description (dict): Returned response from a DescribeTrainingJob call.
        prev_description (dict): Previous job description from a
            DescribeTrainingJob call. May be ``None``.

    Returns:
        str: One line per transition to report, each formatted as
        ``"<time> <Status> - <StatusMessage>"`` and joined with newlines.
        An empty string is returned when there are no transitions to report.
    """
    # Imported locally so this block does not depend on a file-level import change.
    from datetime import timezone

    if (
        job_description is None
        or job_description.get("SecondaryStatusTransitions") is None
        or len(job_description.get("SecondaryStatusTransitions")) == 0
    ):
        return ""

    prev_transitions = (
        prev_description.get("SecondaryStatusTransitions") if prev_description is not None else None
    )
    prev_transitions_num = len(prev_transitions) if prev_transitions is not None else 0
    current_transitions = job_description["SecondaryStatusTransitions"]

    if len(current_transitions) == prev_transitions_num:
        # Secondary status is not changed but the message changed.
        transitions_to_print = current_transitions[-1:]
    else:
        # Secondary status is changed: print every entry added since last time.
        transitions_to_print = current_transitions[
            prev_transitions_num - len(current_transitions) :
        ]

    if not transitions_to_print:
        return ""

    # The timestamp is identical for every line, so compute it once.
    # ``fromtimestamp(..., tz=utc)`` replaces the deprecated ``utcfromtimestamp``
    # and yields the same wall-clock fields.
    time_str = datetime.fromtimestamp(
        time.mktime(job_description["LastModifiedTime"].timetuple()), tz=timezone.utc
    ).strftime("%Y-%m-%d %H:%M:%S")

    return "\n".join(
        "{} {} - {}".format(time_str, transition["Status"], transition["StatusMessage"])
        for transition in transitions_to_print
    )
395
+
396
+
397
def download_folder(bucket_name, prefix, target, sagemaker_session):
    """Download an S3 'folder' (or a single object) to a local path.

    Args:
        bucket_name (str): S3 bucket name.
        prefix (str): S3 prefix within the bucket that will be downloaded. Can
            be a single file.
        target (str): Destination path where the downloaded items will be placed.
        sagemaker_session (sagemaker.core.helper.session.Session): A sagemaker
            session whose ``s3_resource`` is used to talk to S3.
    """
    s3 = sagemaker_session.s3_resource
    prefix = prefix.lstrip("/")

    # A prefix without a trailing slash may name a single object. Try that
    # first, in case the object has broader permissions than the bucket.
    if not prefix.endswith("/"):
        single_file_destination = os.path.join(target, os.path.basename(prefix))
        try:
            s3.Object(bucket_name, prefix).download_file(single_file_destination)
            return
        except botocore.exceptions.ClientError as e:
            error = e.response["Error"]
            looks_like_folder = error["Code"] == "404" and error["Message"] == "Not Found"
            if not looks_like_folder:
                raise
            # S3 reports 404 for a 'folder' too, so fall through and list the
            # prefix; a genuine 404 surfaces from the listing step instead.

    _download_files_under_prefix(bucket_name, prefix, target, s3)
429
+
430
+
431
def _download_files_under_prefix(bucket_name, prefix, target, s3):
    """Download every S3 object whose key starts with the given prefix.

    Args:
        bucket_name (str): S3 bucket name.
        prefix (str): S3 prefix within the bucket that will be downloaded.
        target (str): Destination path where the downloaded items will be placed.
        s3 (boto3.resources.base.ServiceResource): S3 resource.
    """
    bucket = s3.Bucket(bucket_name)
    for obj_sum in bucket.objects.filter(Prefix=prefix):
        # Keys ending in "/" are folder placeholder objects; nothing to download.
        if obj_sum.key.endswith("/"):
            continue

        obj = s3.Object(obj_sum.bucket_name, obj_sum.key)
        # Path of the object relative to the prefix, used as the local layout.
        s3_relative_path = obj_sum.key[len(prefix) :].lstrip("/")
        file_path = os.path.join(target, s3_relative_path)

        # An already-existing directory is fine; any other failure propagates.
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        obj.download_file(file_path)
457
+
458
+
459
def create_tar_file(source_files, target=None):
    """Create a gzipped tar file containing all the source_files.

    Every entry is stored at the root of the archive under its base name, and
    symbolic links are dereferenced.

    Args:
        source_files (List[str]): List of file paths that will be contained in
            the tar file.
        target (str): Path of the archive to write. When falsy, a new temporary
            file is created and used instead.

    Returns:
        str: Path to the created tar file.
    """
    if target:
        filename = target
    else:
        fd, filename = tempfile.mkstemp()
        # mkstemp returns an open OS-level descriptor; close it so it is not
        # leaked, since tarfile reopens the file by name below.
        os.close(fd)

    with tarfile.open(filename, mode="w:gz", dereference=True) as t:
        for sf in source_files:
            # Add all files from the directory into the root of the directory structure of the tar
            t.add(sf, arcname=os.path.basename(sf))
    return filename
479
+
480
+
481
+ @contextlib.contextmanager
482
+ def _tmpdir(suffix="", prefix="tmp", directory=None):
483
+ """Create a temporary directory with a context manager.
484
+
485
+ The file is deleted when the context exits, even when there's an exception.
486
+ The prefix, suffix, and dir arguments are the same as for mkstemp().
487
+
488
+ Args:
489
+ suffix (str): If suffix is specified, the file name will end with that
490
+ suffix, otherwise there will be no suffix.
491
+ prefix (str): If prefix is specified, the file name will begin with that
492
+ prefix; otherwise, a default prefix is used.
493
+ directory (str): If a directory is specified, the file will be downloaded
494
+ in this directory; otherwise, a default directory is used.
495
+
496
+ Returns:
497
+ str: path to the directory
498
+ """
499
+ if directory is not None and not (os.path.exists(directory) and os.path.isdir(directory)):
500
+ raise ValueError(
501
+ "Inputted directory for storing newly generated temporary "
502
+ f"directory does not exist: '{directory}'"
503
+ )
504
+ tmp = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=directory)
505
+ try:
506
+ yield tmp
507
+ finally:
508
+ shutil.rmtree(tmp)
509
+
510
+
511
def repack_model(
    inference_script,
    source_directory,
    dependencies,
    model_uri,
    repacked_model_uri,
    sagemaker_session,
    kms_key=None,
):
    """Unpack a model tarball and create a new one that includes the provided code.

    The steps are:
      - extract the model tarball (from S3 or the local file system) into a
        temporary folder,
      - create or replace the ``code`` directory inside it with the new
        inference code and dependencies,
      - compress the result into a new tarball and save it to S3 or the local
        file system.

    Args:
        inference_script (str): Path or basename of the inference script that
            will be packed into the model.
        source_directory (str): Path including all the files that will be packed
            into the model.
        dependencies (list[str]): A list of paths to directories (absolute or
            relative) with any additional libraries that will be packed with
            the code (default: []). A falsy value is treated as an empty list.
        model_uri (str): S3 or file system location of the original model tar.
        repacked_model_uri (str): Path or file system location where the new
            model will be saved.
        sagemaker_session (sagemaker.core.helper.session.Session): A sagemaker
            session to interact with S3.
        kms_key (str): KMS key ARN for encrypting the repacked model file.
    """
    dependencies = dependencies or []

    # Use the configured download directory when the session provides one.
    session_settings = sagemaker_session.settings
    local_download_dir = None if session_settings is None else session_settings.local_download_dir

    with _tmpdir(directory=local_download_dir) as tmp:
        model_dir = _extract_model(model_uri, sagemaker_session, tmp)

        _create_or_update_code_dir(
            model_dir,
            inference_script,
            source_directory,
            dependencies,
            sagemaker_session,
            tmp,
        )

        tmp_model_path = os.path.join(tmp, "temp-model.tar.gz")
        with tarfile.open(tmp_model_path, mode="w:gz") as archive:
            archive.add(model_dir, arcname=os.path.sep)

        _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key=kms_key)
583
+
584
+
585
+ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
586
+ """Placeholder docstring"""
587
+ if repacked_model_uri.lower().startswith("s3://"):
588
+ url = parse.urlparse(repacked_model_uri)
589
+ bucket, key = url.netloc, url.path.lstrip("/")
590
+ new_key = key.replace(os.path.basename(key), os.path.basename(repacked_model_uri))
591
+
592
+ settings = (
593
+ sagemaker_session.settings if sagemaker_session is not None else SessionSettings()
594
+ )
595
+ encrypt_artifact = settings.encrypt_repacked_artifacts
596
+
597
+ if kms_key:
598
+ extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key}
599
+ elif encrypt_artifact:
600
+ extra_args = {"ServerSideEncryption": "aws:kms"}
601
+ else:
602
+ extra_args = None
603
+ sagemaker_session.boto_session.resource(
604
+ "s3", region_name=sagemaker_session.boto_region_name
605
+ ).Object(bucket, new_key).upload_file(tmp_model_path, ExtraArgs=extra_args)
606
+ else:
607
+ shutil.move(tmp_model_path, repacked_model_uri.replace("file://", ""))
608
+
609
+
610
+ def _create_or_update_code_dir(
611
+ model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp
612
+ ):
613
+ """Placeholder docstring"""
614
+ code_dir = os.path.join(model_dir, "code")
615
+ if source_directory and source_directory.lower().startswith("s3://"):
616
+ local_code_path = os.path.join(tmp, "local_code.tar.gz")
617
+ download_file_from_url(source_directory, local_code_path, sagemaker_session)
618
+
619
+ with tarfile.open(name=local_code_path, mode="r:gz") as t:
620
+ custom_extractall_tarfile(t, code_dir)
621
+
622
+ elif source_directory:
623
+ if os.path.exists(code_dir):
624
+ shutil.rmtree(code_dir)
625
+ shutil.copytree(source_directory, code_dir)
626
+ else:
627
+ if not os.path.exists(code_dir):
628
+ os.mkdir(code_dir)
629
+ try:
630
+ shutil.copy2(inference_script, code_dir)
631
+ except FileNotFoundError:
632
+ if os.path.exists(os.path.join(code_dir, inference_script)):
633
+ pass
634
+ else:
635
+ raise
636
+
637
+ for dependency in dependencies:
638
+ lib_dir = os.path.join(code_dir, "lib")
639
+ if os.path.isdir(dependency):
640
+ shutil.copytree(dependency, os.path.join(lib_dir, os.path.basename(dependency)))
641
+ else:
642
+ if not os.path.exists(lib_dir):
643
+ os.mkdir(lib_dir)
644
+ shutil.copy2(dependency, lib_dir)
645
+
646
+
647
def _extract_model(model_uri, sagemaker_session, tmp):
    """Extract a gzipped model tarball into ``<tmp>/model`` and return that path.

    Args:
        model_uri (str): ``s3://`` URI of the tarball (downloaded first), or a
            local path, optionally written with ``file://``.
        sagemaker_session: Session used to download from S3.
        tmp (str): Existing scratch directory to work in.

    Returns:
        str: Path of the directory containing the extracted model.
    """
    extraction_dir = os.path.join(tmp, "model")
    os.mkdir(extraction_dir)

    if model_uri.lower().startswith("s3://"):
        archive_path = os.path.join(tmp, "tar_file")
        download_file_from_url(model_uri, archive_path, sagemaker_session)
    else:
        archive_path = model_uri.replace("file://", "")

    with tarfile.open(name=archive_path, mode="r:gz") as archive:
        custom_extractall_tarfile(archive, extraction_dir)

    return extraction_dir
659
+
660
+
661
def download_file_from_url(url, dst, sagemaker_session):
    """Download the S3 object named by a URL to a local path.

    Args:
        url (str): URL whose network location is the bucket name and whose
            path (without leading slashes) is the object key.
        dst (str): Local destination passed on to ``download_file``.
        sagemaker_session: Session used to talk to S3.
    """
    parsed = parse.urlparse(url)
    bucket = parsed.netloc
    key = parsed.path.lstrip("/")

    download_file(bucket, key, dst, sagemaker_session)
667
+
668
+
669
def download_file(bucket_name, path, target, sagemaker_session):
    """Download a single file from S3 to a local path.

    Args:
        bucket_name (str): S3 bucket name.
        path (str): File path within the bucket; leading slashes are ignored.
        target (str): Local destination handed to the bucket's ``download_file``.
        sagemaker_session (sagemaker.core.helper.session.Session): A sagemaker
            session providing the boto session and region.
    """
    object_key = path.lstrip("/")
    s3_resource = sagemaker_session.boto_session.resource(
        "s3", region_name=sagemaker_session.boto_region_name
    )
    s3_resource.Bucket(bucket_name).download_file(object_key, target)
685
+
686
+
687
def sts_regional_endpoint(region):
    """Get the AWS STS endpoint specific for the given region.

    For the list of regional endpoints, see
    https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html#id_credentials_region-endpoints.

    Args:
        region (str): AWS region name.

    Returns:
        str: ``https://`` URL of the AWS STS regional endpoint.
    """
    endpoint_data = _botocore_resolver().construct_endpoint("sts", region)

    # Fallback used when the resolver returns nothing for il-central-1.
    if not endpoint_data and region == "il-central-1":
        endpoint_data = {"hostname": "sts.{}.amazonaws.com".format(region)}

    hostname = endpoint_data["hostname"]
    return "https://{}".format(hostname)
706
+
707
+
708
def retries(
    max_retry_count,
    exception_message_prefix,
    seconds_to_sleep=DEFAULT_SLEEP_TIME_SECONDS,
):
    """Generator that yields attempt numbers until the maximum is reached.

    The caller is expected to stop iterating on success. After each yielded
    attempt the generator sleeps before producing the next one; once all
    attempts are used up it raises.

    Args:
        max_retry_count (int): The retry count.
        exception_message_prefix (str): The message to include in the exception on failure.
        seconds_to_sleep (int): The number of seconds to sleep between executions.

    Yields:
        int: The zero-based attempt number.

    Raises:
        Exception: When ``max_retry_count`` attempts have been exhausted.
    """
    for attempt in range(max_retry_count):
        yield attempt
        time.sleep(seconds_to_sleep)

    failure_message = "'{}' has reached the maximum retry count of {}".format(
        exception_message_prefix, max_retry_count
    )
    raise Exception(failure_message)
730
+
731
+
732
def retry_with_backoff(callable_func, num_attempts=8, botocore_client_error_code=None):
    """Call a function, retrying with exponential backoff on failure.

    Args:
        callable_func (Callable): The callable function to retry.
        num_attempts (int): The maximum number of attempts. (Default: 8)
        botocore_client_error_code (str): The specific Botocore ClientError
            error code on which to retry.
            If provided, other exceptions are raised directly without retry.
            If not provided, retry on any exception.
            (Default: None)

    Returns:
        Whatever ``callable_func`` returns on its first successful call.

    Raises:
        ValueError: If ``num_attempts`` is less than 1.
        Exception: The exception from ``callable_func`` when it is not
            retryable or when the final attempt fails.
    """
    if num_attempts < 1:
        raise ValueError(
            "The num_attempts must be >= 1, but the given value is {}.".format(num_attempts)
        )

    final_attempt = num_attempts - 1
    for attempt in range(num_attempts):
        try:
            return callable_func()
        except Exception as ex:  # pylint: disable=broad-except
            if botocore_client_error_code:
                # Only a ClientError carrying the requested code is retried.
                retryable = (
                    isinstance(ex, botocore.exceptions.ClientError)
                    and ex.response["Error"]["Code"]  # pylint: disable=no-member
                    == botocore_client_error_code
                )
            else:
                retryable = True

            if not retryable or attempt == final_attempt:
                raise ex

            logger.error("Retrying in attempt %s, due to %s", (attempt + 1), str(ex))
            # Wait 1, 2, 4, ... seconds between attempts.
            time.sleep(2**attempt)
764
+
765
+
766
def _botocore_resolver():
    """Build a botocore endpoint resolver from botocore's "endpoints" data.

    Returns:
        botocore.regions.EndpointResolver: Resolver constructed from the data
        loaded by a default botocore loader.
    """
    endpoints_data = botocore.loaders.create_loader().load_data("endpoints")
    return botocore.regions.EndpointResolver(endpoints_data)
777
+
778
+
779
def aws_partition(region):
    """Given a region name (ex: "cn-north-1"), return the corresponding aws partition ("aws-cn").

    Args:
        region (str): The region name for which to return the corresponding partition.
            Ex: "cn-north-1"

    Returns:
        str: partition corresponding to the region name passed in. Ex: "aws-cn"
    """
    endpoint_data = _botocore_resolver().construct_endpoint("sts", region)
    if region == "il-central-1" and not endpoint_data:
        # Fallback used when the resolver returns nothing for il-central-1.
        # It must carry a "partition" entry, because that is the key read
        # below; il-central-1 belongs to the standard "aws" partition.
        endpoint_data = {
            "hostname": "sts.{}.amazonaws.com".format(region),
            "partition": "aws",
        }
    return endpoint_data["partition"]
793
+
794
+
795
class DeferredError:
    """Placeholder that re-raises a stored exception on attribute access.

    Useful to allow soft-dependencies on imports, so that the ImportError can
    be raised again later if code actually relies on the missing library.

    Example::

        try:
            import obscurelib
        except ImportError as e:
            logger.warning("Failed to import obscurelib. Obscure features will not work.")
            obscurelib = DeferredError(e)
    """

    def __init__(self, exception):
        """Remember the exception to raise later.

        Args:
            exception (BaseException): The exception to re-raise on access.
        """
        self.exc = exception

    def __getattr__(self, name):
        """Raise the stored exception for any attribute not found normally.

        Python calls this hook only when regular attribute lookup fails, so
        ``exc`` itself stays reachable while any other attribute raises.

        Args:
            name (str): Name of the attribute being looked up (unused).
        """
        stored_exception = self.exc
        raise stored_exception
823
+
824
+
825
+ def _module_import_error(py_module, feature, extras):
826
+ """Return error message for module import errors, provide installation details.
827
+
828
+ Args:
829
+ py_module (str): Module that failed to be imported
830
+ feature (str): Affected SageMaker feature
831
+ extras (str): Name of the `extras_require` to install the relevant dependencies
832
+
833
+ Returns:
834
+ str: Error message with installation instructions.
835
+ """
836
+ error_msg = (
837
+ "Failed to import {}. {} features will be impaired or broken. "
838
+ "Please run \"pip install 'sagemaker[{}]'\" "
839
+ "to install all required dependencies."
840
+ )
841
+ return error_msg.format(py_module, feature, extras)
842
+
843
+
844
class DataConfig(abc.ABC):
    """Abstract interface for reading a data config hosted in AWS resources.

    Subclasses customise where the config comes from by implementing
    ``fetch_data_config``.
    """

    @abc.abstractmethod
    def fetch_data_config(self):
        """Retrieve the data config from the subclass's pre-configured source.

        Returns:
            object: The data configuration object.
        """
857
+
858
+
859
class S3DataConfig(DataConfig):
    """DataConfig implementation that reads a JSON data config file hosted on S3."""

    def __init__(
        self,
        sagemaker_session,
        bucket_name,
        prefix,
    ):
        """Initialize a ``S3DataConfig`` instance.

        Args:
            sagemaker_session (Session): SageMaker session instance to use for boto configuration.
            bucket_name (str): Required. Bucket name from which data config needs to be fetched.
            prefix (str): Required. The object prefix for the hosted data config.

        Raises:
            ValueError: If ``bucket_name`` or ``prefix`` is ``None``.
        """
        if None in (bucket_name, prefix):
            raise ValueError(
                "Bucket Name and S3 file Prefix are required arguments and must be provided."
            )

        super().__init__()

        self.bucket_name = bucket_name
        self.prefix = prefix
        self.sagemaker_session = sagemaker_session

    def fetch_data_config(self):
        """Fetch the data configuration from the S3 bucket and parse it as JSON.

        Returns:
            object: The JSON object containing data configuration.
        """
        raw_json = self.sagemaker_session.read_s3_file(self.bucket_name, self.prefix)
        return json.loads(raw_json)

    def get_data_bucket(self, region_requested=None):
        """Provide the bucket containing the data for the specified region.

        Args:
            region_requested (str): The region for which the data is being
                requested. When falsy, the session's region is used.

        Returns:
            str: Name of the S3 bucket containing datasets in the requested
            region, or the config's ``"default"`` entry when the region is not listed.
        """
        config = self.fetch_data_config()
        region = region_requested or self.sagemaker_session.boto_region_name
        if region in config:
            return config[region]
        return config["default"]
910
+
911
+
912
def update_container_with_inference_params(
    framework=None,
    framework_version=None,
    nearest_model_name=None,
    data_input_configuration=None,
    container_def=None,
    container_list=None,
):
    """Add inference recommender parameters to container definitions.

    Args:
        framework (str): Machine learning framework of the model package container image
            (default: None).
        framework_version (str): Framework version of the Model Package Container Image
            (default: None).
        nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
            Amazon SageMaker Inference Recommender (default: None).
        data_input_configuration (str): Input object for the model (default: None).
        container_def (dict): object to be updated.
        container_list (list): list to be updated.

    Returns:
        ``container_list`` when it is truthy, otherwise ``container_def``.
    """
    # Entries of the list are processed first, then the single definition.
    containers_to_update = [] if container_list is None else list(container_list)
    if container_def is not None:
        containers_to_update.append(container_def)

    for container in containers_to_update:
        construct_container_object(
            container,
            data_input_configuration,
            framework,
            framework_version,
            nearest_model_name,
        )

    return container_list or container_def
953
+
954
+
955
def construct_container_object(
    obj, data_input_configuration, framework, framework_version, nearest_model_name
):
    """Add the non-``None`` inference recommender fields to a container dict.

    Args:
        obj (dict): object to be updated in place.
        data_input_configuration (str): Input object for the model; stored as
            ``{"ModelInput": {"DataInputConfig": ...}}``.
        framework (str): Machine learning framework of the model package
            container image; stored under ``"Framework"``.
        framework_version (str): Framework version of the Model Package
            Container Image; stored under ``"FrameworkVersion"``.
        nearest_model_name (str): Name of a pre-trained machine learning
            benchmarked by Amazon SageMaker Inference Recommender; stored under
            ``"NearestModelName"``.

    Returns:
        dict: The same ``obj`` that was passed in.
    """
    simple_fields = (
        ("Framework", framework),
        ("FrameworkVersion", framework_version),
        ("NearestModelName", nearest_model_name),
    )
    for field_name, field_value in simple_fields:
        if field_value is not None:
            obj.update({field_name: field_value})

    if data_input_configuration is not None:
        obj.update({"ModelInput": {"DataInputConfig": data_input_configuration}})

    return obj
1005
+
1006
+
1007
def pop_out_unused_kwarg(arg_name: str, kwargs: dict, override_val: Optional[str] = None):
    """Remove an unused key-word argument from ``kwargs`` and log a warning.

    Nothing happens when ``arg_name`` is not present in ``kwargs``.

    Args:
        arg_name (str): The name of the argument to be checked if it is unused.
        kwargs (dict): The key-word argument dict; modified in place.
        override_val (str): The value used to override the unused argument,
            mentioned in the warning when truthy (default: None).
    """
    if arg_name in kwargs:
        message_parts = ["{} supplied in kwargs will be ignored".format(arg_name)]
        if override_val:
            message_parts.append(" and further overridden with {}.".format(override_val))
        logging.warning("".join(message_parts))
        kwargs.pop(arg_name)
1022
+
1023
+
1024
def to_string(obj: object):
    """Convert an object to its string form.

    Objects recognised by ``is_pipeline_variable`` are converted through their
    own ``to_string()`` method; everything else goes through ``str()``.

    Args:
        obj (object): The object to be converted.
    """
    if is_pipeline_variable(obj):
        return obj.to_string()
    return str(obj)
1033
+
1034
+
1035
def _start_waiting(waiting_time: int):
    """Wait while printing a growing row of dots to stdout.

    The total wait is split into ``WAITING_DOT_NUMBER`` equal sleeps, with one
    more dot printed before each; the row is blanked out at the end.

    Args:
        waiting_time (int): The total waiting time.
    """
    interval = float(waiting_time) / WAITING_DOT_NUMBER

    progress = ""
    for dot_count in range(1, WAITING_DOT_NUMBER + 1):
        progress = "." * dot_count
        # "\r" returns to the line start so the next print overwrites this one.
        print(progress, end="\r")
        time.sleep(interval)

    print(" " * len(progress), end="\r")
1049
+
1050
+
1051
def get_module(module_name):
    """Import a module.

    Args:
        module_name (str): name of the module to import.

    Returns:
        object: The imported module.

    Raises:
        Exception: when the module cannot be imported. The original
            ``ImportError`` is attached as the exception's cause.
    """
    try:
        return import_module(module_name)
    except ImportError as err:
        # Chain the original error so the underlying import failure stays
        # visible in the traceback.
        raise Exception(
            "Cannot import module {}, please try again.".format(module_name)
        ) from err
1067
+
1068
+
1069
def check_and_get_run_experiment_config(experiment_config: Optional[dict] = None) -> dict:
    """Pick the experiment_config to use: the caller's, or the current Run's.

    Args:
        experiment_config (dict): The experiment_config supplied by the user.

    Returns:
        dict: The user supplied experiment_config when it is truthy. Otherwise
        the experiment_config of the current Run object, or ``None`` when
        there is no current Run.
    """
    from sagemaker.core.experiments._run_context import _RunContext

    run_obj = _RunContext.get_current_run()

    if not experiment_config:
        return run_obj.experiment_config if run_obj else None

    if run_obj:
        # The explicit argument wins over the Run context; tell the user.
        logger.warning(
            "The function is invoked within an Experiment Run context "
            "but another experiment_config (%s) was supplied, so "
            "ignoring the experiment_config fetched from the Run object.",
            experiment_config,
        )
    return experiment_config
1093
+
1094
+
1095
def resolve_value_from_config(
    direct_input=None,
    config_path: str = None,
    default_value=None,
    sagemaker_session=None,
    sagemaker_config: dict = None,
):
    """Decide which value the caller should use, consulting the sagemaker config.

    Candidates are tried in this order, and the first one that is not ``None`` wins:
        1. direct_input
        2. config value
        3. default_value

    Args:
        direct_input: The value that the caller of this method starts with. Usually this is an
            input to the caller's class or method.
        config_path (str): A string denoting the path used to lookup the value in the
            sagemaker config. No lookup happens when it is falsy.
        default_value: The value used if not present elsewhere.
        sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker Session object,
            used for SageMaker interactions (default: None).
        sagemaker_config (dict): The sdk defaults config that is normally accessed through a
            Session object by doing `session.sagemaker_config`. (default: None) Passed on to
            ``get_sagemaker_config_value`` together with ``sagemaker_session``.

    Returns:
        The value that should be used by the caller.
    """
    config_value = None
    if config_path:
        config_value = get_sagemaker_config_value(
            sagemaker_session, config_path, sagemaker_config=sagemaker_config
        )

    from sagemaker.core.config.config_utils import _log_sagemaker_config_single_substitution

    _log_sagemaker_config_single_substitution(direct_input, config_value, config_path)

    for candidate in (direct_input, config_value):
        if candidate is not None:
            return candidate

    return default_value
1150
+
1151
+
1152
def get_sagemaker_config_value(sagemaker_session, key, sagemaker_config: dict = None):
    """Return a copy of the config value stored under the provided key path.

    Args:
        sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker Session object,
            used for SageMaker interactions. Its ``sagemaker_config`` attribute is used when
            the session is truthy and has that attribute.
        key: Key Path of the config file entry.
        sagemaker_config (dict): The sdk defaults config that is normally accessed through a
            Session object by doing `session.sagemaker_config`. (default: None) Used only when
            the session cannot supply a config.

    Returns:
        object: A deep copy of the corresponding value in the configuration, or ``None``
        when no (non-empty) config is available.
    """
    from sagemaker.core.config.config_manager import SageMakerConfig

    session_has_config = sagemaker_session and hasattr(sagemaker_session, "sagemaker_config")
    config_to_check = sagemaker_session.sagemaker_config if session_has_config else sagemaker_config

    if not config_to_check:
        return None

    SageMakerConfig().validate_sagemaker_config(config_to_check)
    # Hand back a copy so callers cannot modify the source config through it.
    return copy.deepcopy(get_config_value(key, config_to_check))
1183
+
1184
+
1185
def get_resource_name_from_arn(arn):
    """Extract the resource name from an ARN string.

    The resource name is everything after the first ``/`` in the sixth
    colon-separated field of the ARN.

    Args:
        arn (str): An ARN.

    Returns:
        str: The resource name.
    """
    resource_field = arn.split(":", 5)[5]
    name_after_type = resource_field.split("/", 1)[1]
    return name_after_type
1195
+
1196
+
1197
def list_tags(sagemaker_session, resource_arn, max_results=50):
    """List the tags given an Amazon Resource Name.

    All result pages are fetched, and tags whose key contains ``"aws:"`` are
    left out of the result.

    Args:
        sagemaker_session: Session whose ``sagemaker_client`` is used for the
            ``list_tags`` calls.
        resource_arn (str): The Amazon Resource Name (ARN) for which to get the tags list.
        max_results (int): The maximum number of results to include in a single page.
            This method takes care of that abstraction and returns a full list.

    Returns:
        list: The tag dicts that are not AWS-generated.

    Raises:
        ClientError: If a ``list_tags`` call fails; the error is logged first.
    """
    tags_list = []
    request = {"ResourceArn": resource_arn, "MaxResults": max_results}

    try:
        while True:
            list_tags_response = sagemaker_session.sagemaker_client.list_tags(**request)
            tags_list.extend(list_tags_response["Tags"])

            # The pagination token is returned under "NextToken"; reading
            # "nextToken" never matched, so only the first page was returned.
            next_token = list_tags_response.get("NextToken")
            if next_token is None:
                break
            request["NextToken"] = next_token

        return [tag for tag in tags_list if "aws:" not in tag["Key"]]
    except ClientError as error:
        logger.error("Error retrieving tags. resource_arn: %s", resource_arn)
        raise error
1229
+
1230
+
1231
def resolve_class_attribute_from_config(
    clazz: Optional[type],
    instance: Optional[object],
    attribute: str,
    config_path: str,
    default_value=None,
    sagemaker_session=None,
):
    """Fill a missing attribute on a data class instance from the sagemaker config.

    The attribute keeps its current value when it already has one. Otherwise the
    value is chosen in this order:
        1. config value found at ``config_path``
        2. ``default_value``
        3. nothing is set

    Args:
        clazz (Optional[type]): Class of 'instance'. Used to build a new instance when
            'instance' is None. If None (or not a class), no new object is created.
            Note: the constructor should default its attributes to None; a non-None
            constructor default is left as-is even if a config value is defined.
        instance (Optional[object]): Instance of 'clazz' that carries 'attribute'.
        attribute (str): Name of the attribute to set if it is not already set.
        config_path (str): Path used to look up the value in the sagemaker config.
        default_value: Value to use when the config has no entry.
        sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker Session
            object, used for SageMaker interactions (default: None).

    Returns:
        The instance the caller should use from now on instead of the 'instance'
        argument that was passed in (it may be a freshly constructed object or None).

    Raises:
        TypeError: If the instance has no attribute named 'attribute'.
    """
    value_from_config = get_sagemaker_config_value(sagemaker_session, config_path)

    if value_from_config is None and default_value is None:
        # Nothing to apply: hand back the instance untouched (it may be None).
        return instance

    if instance is None:
        constructible = clazz is not None and inspect.isclass(clazz)
        if not constructible:
            return instance
        # No instance was supplied, so build an empty one to populate.
        instance = clazz()

    if not hasattr(instance, attribute):
        raise TypeError(
            "Unexpected structure of object.",
            "Expected attribute {} to be present inside instance {} of class {}".format(
                attribute, instance, clazz
            ),
        )

    existing_value = getattr(instance, attribute)
    if existing_value is None:
        # At least one of the two candidates is non-None here (see the early return).
        chosen = value_from_config if value_from_config is not None else default_value
        setattr(instance, attribute, chosen)

    from sagemaker.core.config.config_utils import _log_sagemaker_config_single_substitution

    _log_sagemaker_config_single_substitution(existing_value, value_from_config, config_path)

    return instance
1302
+
1303
+
1304
def resolve_nested_dict_value_from_config(
    dictionary: dict,
    nested_keys: List[str],
    config_path: str,
    default_value: object = None,
    sagemaker_session=None,
):
    """Fill a missing nested dictionary entry from the sagemaker config.

    The entry addressed by ``nested_keys`` keeps its current value when it has one.
    Otherwise the value is chosen in this order: (1) config value found at
    ``config_path``, (2) ``default_value``, (3) nothing is set.

    Args:
        dictionary: The dict to update.
        nested_keys: The path of keys whose value should be checked and set if needed.
        config_path (str): Path used to find the value in the sagemaker config.
        default_value: The value to use if the config has no entry.
        sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker Session
            object, used for SageMaker interactions (default: None).

    Returns:
        The dictionary the caller should use from now on instead of the 'dictionary'
        argument that was passed in.
    """
    value_from_config = get_sagemaker_config_value(sagemaker_session, config_path)

    if value_from_config is None and default_value is None:
        # Nothing to apply, so avoid walking the dictionary or creating nested dicts.
        return dictionary

    try:
        existing_value = get_nested_value(dictionary, nested_keys)
    except ValueError as e:
        logging.error("Failed to check dictionary for applying sagemaker config: %s", e)
        return dictionary

    if existing_value is None:
        # At least one of the two candidates is non-None here (see the early return).
        chosen = value_from_config if value_from_config is not None else default_value
        dictionary = set_nested_value(dictionary, nested_keys, chosen)

    from sagemaker.core.config.config_utils import _log_sagemaker_config_single_substitution

    _log_sagemaker_config_single_substitution(existing_value, value_from_config, config_path)

    return dictionary
1357
+
1358
+
1359
def update_list_of_dicts_with_values_from_config(
    input_list,
    config_key_path,
    required_key_paths: List[str] = None,
    union_key_paths: List[List[str]] = None,
    sagemaker_session=None,
):
    """Merge config values into a list of dictionaries, position by position.

    A config file may introduce parameters that only make sense together with other
    parameters in the same list item; without them the service would reject the call.
    ``required_key_paths`` names those companion parameters.

    A config file may also introduce a parameter that is mutually exclusive with one
    already supplied by the caller. ``union_key_paths`` names such groups.

    Args:
        input_list: The input list that was provided as a method parameter.
        config_key_path: The key path in the config file that corresponds to
            ``input_list``.
        required_key_paths (List[str]): Key paths that must be present in a merged
            item. If one is missing, the merge is skipped for that item.
        union_key_paths (List[List[str]]): Groups of key paths of which at most one may
            be present in a merged item. For example, if the result may contain 'X1' or
            'X2' or neither but not both, pass [['X1', 'X2']].
        sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker Session
            object, used for SageMaker interactions (default: None).

    Returns:
        No output. The merge happens in place on ``input_list``.
    """
    if not input_list:
        return

    original_inputs = copy.deepcopy(input_list)
    config_items = get_sagemaker_config_value(sagemaker_session, config_key_path) or []
    original_config_items = copy.deepcopy(config_items)

    # Only positions that exist in both lists can be merged.
    overlap = min(len(input_list), len(config_items))
    for index in range(overlap):
        merged_item = config_items[index]
        # The caller's values are merged on top of the config values.
        merge_dicts(merged_item, input_list[index])

        if not _validate_required_paths_in_a_dict(merged_item, required_key_paths):
            # Config introduced a parameter whose required companion is missing:
            # keep the caller's item unchanged.
            continue
        if not _validate_union_key_paths_in_a_dict(merged_item, union_key_paths):
            # The merge would set mutually exclusive parameters: keep the caller's item.
            continue

        input_list[index] = merged_item

    from sagemaker.core.config.config_utils import _log_sagemaker_config_merge

    _log_sagemaker_config_merge(
        source_value=original_inputs,
        config_value=original_config_items,
        merged_source_and_config_value=input_list,
        config_key_path=config_key_path,
    )
1427
+
1428
+
1429
+ def _validate_required_paths_in_a_dict(source_dict, required_key_paths: List[str] = None) -> bool:
1430
+ """Placeholder docstring"""
1431
+ if not required_key_paths:
1432
+ return True
1433
+ for required_key_path in required_key_paths:
1434
+ if get_config_value(required_key_path, source_dict) is None:
1435
+ return False
1436
+ return True
1437
+
1438
+
1439
+ def _validate_union_key_paths_in_a_dict(
1440
+ source_dict, union_key_paths: List[List[str]] = None
1441
+ ) -> bool:
1442
+ """Placeholder docstring"""
1443
+ if not union_key_paths:
1444
+ return True
1445
+ for union_key_path in union_key_paths:
1446
+ union_parameter_present = False
1447
+ for key_path in union_key_path:
1448
+ if get_config_value(key_path, source_dict):
1449
+ if union_parameter_present:
1450
+ return False
1451
+ union_parameter_present = True
1452
+ return True
1453
+
1454
+
1455
def update_nested_dictionary_with_values_from_config(
    source_dict, config_key_path, sagemaker_session=None
) -> dict:
    """Merge the caller's nested dictionary on top of the matching config dictionary.

    Args:
        source_dict: The input nested dictionary that was provided as method parameter.
        config_key_path: The key path in the config file which corresponds to
            ``source_dict``.
        sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker Session
            object, used for SageMaker interactions (default: None).

    Returns:
        dict: The merged dictionary, or ``source_dict`` itself (possibly None) when the
        config holds no value for ``config_key_path``.
    """
    merged = get_sagemaker_config_value(sagemaker_session, config_key_path) or {}
    config_snapshot = copy.deepcopy(merged)
    merge_dicts(merged, source_dict or {})

    if config_snapshot == {}:
        # No config value was defined, so the merge result is either equal to
        # source_dict or, when source_dict was None, an empty dict. Return source_dict
        # as given: turning None into {} can make boto calls fail with
        # ParamValidationError (e.g. a VpcConfig that is present but lacks its
        # required parameters). Nothing is logged because no config value was used.
        return source_dict

    from sagemaker.core.config.config_utils import _log_sagemaker_config_merge

    _log_sagemaker_config_merge(
        source_value=source_dict,
        config_value=config_snapshot,
        merged_source_and_config_value=merged,
        config_key_path=config_key_path,
    )

    return merged
1497
+
1498
+
1499
def stringify_object(obj: Any) -> str:
    """Render ``obj`` as ``"<ClassName>: {attrs}"`` using only its non-None attributes."""
    populated = {}
    for name, value in obj.__dict__.items():
        if value is not None:
            populated[name] = value
    class_name = type(obj).__name__
    return f"{class_name}: {str(populated)}"
1503
+
1504
+
1505
def volume_size_supported(instance_type: str) -> bool:
    """Tell whether SageMaker allows volume_size to be used for the instance type.

    Args:
        instance_type (str): Instance type such as ``ml.m5.xlarge`` or ``m5.xlarge``.

    Returns:
        bool: False for pipeline variables, for types starting with ``local``, for
        families containing a "d" and for families starting with g5, g6, p5 or trn1;
        True otherwise.

    Raises:
        ValueError: If the instance type is improperly formatted (any error raised
            while parsing is converted to ValueError).
    """
    try:
        # Local mode and pipeline-parameter instance types do not support volume size.
        # The pipeline-variable check must come first: do not reorder the condition.
        if is_pipeline_variable(instance_type) or instance_type.startswith("local"):
            return False

        segments: List[str] = instance_type.split(".")

        # Drop the optional leading "ml" so that both spellings parse the same way.
        if len(segments) == 3 and segments[0] == "ml":
            segments = segments[1:]

        if len(segments) != 2:
            raise ValueError(f"Failed to parse instance type '{instance_type}'")

        family = segments[0]

        # Families with a "d" (c5d, p4d, ...) and the families listed below do not
        # support attaching an EBS volume.
        unsupported_prefixes = ("g5", "g6", "p5", "trn1")
        return not ("d" in family or family.startswith(unsupported_prefixes))
    except Exception as e:
        raise ValueError(f"Failed to parse instance type '{instance_type}': {str(e)}")
1539
+
1540
+
1541
def instance_supports_kms(instance_type: str) -> bool:
    """Tell whether SageMaker allows KMS keys to be attached to the instance.

    The answer is exactly the one given by ``volume_size_supported``.

    Raises:
        ValueError: If the instance type is improperly formatted.
    """
    kms_supported = volume_size_supported(instance_type)
    return kms_supported
1548
+
1549
+
1550
def get_instance_type_family(instance_type: str) -> str:
    """Return the family of the instance type.

    The regex matches either "ml.<family>.<size>" or "ml_<family>". If the input is
    not a string or there is no match, an empty string is returned.
    """
    if not isinstance(instance_type, str):
        return ""
    found = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
    return "" if found is None else found[1]
1562
+
1563
+
1564
def create_paginator_config(max_items: int = None, page_size: int = None) -> Dict[str, int]:
    """Build a paginator config dict, using the module defaults for falsy arguments.

    Args:
        max_items (int): Value for "MaxItems"; MAX_ITEMS is used when falsy.
        page_size (int): Value for "PageSize"; PAGE_SIZE is used when falsy.

    Returns:
        Dict[str, int]: A dict with the keys "MaxItems" and "PageSize".
    """
    config = {}
    config["MaxItems"] = max_items or MAX_ITEMS
    config["PageSize"] = page_size or PAGE_SIZE
    return config
1570
+
1571
+
1572
def format_tags(tags: Tags) -> List[TagsDict]:
    """Convert a plain ``{key: value}`` dict into SageMaker's list-of-tags format.

    Any input that is not a dict is returned unchanged.
    """
    if not isinstance(tags, dict):
        return tags

    formatted = []
    for name, value in tags.items():
        formatted.append({"Key": str(name), "Value": str(value)})
    return formatted
1578
+
1579
+
1580
+ def _get_resolved_path(path):
1581
+ """Return the normalized absolute path of a given path.
1582
+
1583
+ abspath - returns the absolute path without resolving symlinks
1584
+ realpath - resolves the symlinks and gets the actual path
1585
+ normpath - normalizes paths (e.g. remove redudant separators)
1586
+ and handles platform-specific differences
1587
+ """
1588
+ return normpath(realpath(abspath(path)))
1589
+
1590
+
1591
def _is_bad_path(path, base):
    """Checks if the joined path (base directory + file path) is rooted under the base directory

    Ensuring that the file does not attempt to access paths
    outside the expected directory structure.

    Args:
        path (str): The file path.
        base (str): The base directory. Callers in this module pass a path that was
            already resolved with ``_get_resolved_path``.

    Returns:
        bool: True if the path is not rooted under the base directory, False otherwise.
    """
    # joinpath will ignore base if path is absolute
    resolved = _get_resolved_path(joinpath(base, path))

    # The base directory itself is always acceptable.
    if resolved == base:
        return False

    # Compare against the base WITH a trailing separator. A bare startswith(base)
    # would also accept sibling directories that merely share the prefix, e.g.
    # "/data/foobar" for base "/data/foo". Joining with "" appends the separator
    # only when it is not already there.
    base_with_separator = joinpath(base, "")
    return not resolved.startswith(base_with_separator)
1606
+
1607
+
1608
def _is_bad_link(info, base):
    """Checks if the target of a link member is rooted under its own directory.

    Ensuring that the link does not attempt to access paths outside the expected
    directory structure.

    Args:
        info (tarfile.TarInfo): The tar file info.
        base (str): The base directory.

    Returns:
        bool: True if the link is not rooted under the base directory, False otherwise.
    """
    # Links are interpreted relative to the directory containing the link
    containing_dir = dirname(info.name)
    tip = _get_resolved_path(joinpath(base, containing_dir))
    return _is_bad_path(info.linkname, base=tip)
1623
+
1624
+
1625
def _get_safe_members(members):
    """A generator that yields only the members that are safe to extract.

    Members with a bad path, and symlinks or hard links with a bad target, are
    logged as blocked and skipped.

    Args:
        members (list): A list of members to check.

    Yields:
        tarfile.TarInfo: The tar file info.
    """
    root = _get_resolved_path("")

    for member in members:
        if _is_bad_path(member.name, root):
            logger.error("%s is blocked (illegal path)", member.name)
            continue
        if member.issym() and _is_bad_link(member, root):
            logger.error("%s is blocked: Symlink to %s", member.name, member.linkname)
            continue
        if member.islnk() and _is_bad_link(member, root):
            logger.error("%s is blocked: Hard link to %s", member.name, member.linkname)
            continue
        yield member
1647
+
1648
+
1649
def custom_extractall_tarfile(tar, extract_path):
    """Extract a tarfile, using the "data" extraction filter when it is available.

    # TODO: The function and its usages can be deprecated once SageMaker Python SDK
    is upgraded to use Python 3.12+

    If the ``tarfile`` module has a ``data_filter`` attribute, extraction is done with
    ``filter="data"``. Otherwise the ``_get_safe_members`` function is used to filter
    out bad paths and bad links.

    Args:
        tar (tarfile.TarFile): The opened tarfile object.
        extract_path (str): The path to extract the contents of the tarfile.

    Returns:
        None
    """
    if not hasattr(tarfile, "data_filter"):
        # Interpreter without extraction filters: fall back to the manual member check.
        tar.extractall(path=extract_path, members=_get_safe_members(tar))
        return

    tar.extractall(path=extract_path, filter="data")
1669
+
1670
+
1671
def can_model_package_source_uri_autopopulate(source_uri: str):
    """Checks if the source_uri can lead to auto-population of information in the Model registry.

    Args:
        source_uri (str): The source uri.

    Returns:
        bool: True if the source_uri matches the model package ARN pattern or the
        model ARN pattern, False otherwise.
    """
    for pattern in (MODEL_PACKAGE_ARN_PATTERN, MODEL_ARN_PATTERN):
        if re.match(pattern, source_uri):
            return True
    return False
1683
+
1684
+
1685
def flatten_dict(
    d: Dict[str, Any],
    max_flatten_depth=None,
) -> Dict[str, Any]:
    """Flatten a nested dictionary into a dict keyed by tuples of the nested keys.

    Empty nested dictionaries are kept as values because they have no children to
    flatten into.

    Args:
        d (Dict[str, Any]): The dict that will be flattened.
        max_flatten_depth (Optional[int]): Maximum depth to merge. Dictionaries found
            at this depth are stored as values instead of being flattened further.

    Returns:
        Dict[str, Any]: A flat dict whose keys are tuples describing the key path.

    Raises:
        ValueError: If ``max_flatten_depth`` is less than 1, or if a flattened key is
            produced twice.
    """
    if max_flatten_depth is not None and max_flatten_depth < 1:
        raise ValueError("max_flatten_depth should not be less than 1.")

    flat_dict = {}

    def _flatten(node, depth, parent=None):
        """Flatten ``node`` into ``flat_dict``; return True if ``node`` had any item."""
        has_item = False
        for key, value in node.items():
            has_item = True
            flat_key = (key,) if parent is None else parent + (key,)
            may_descend = max_flatten_depth is None or depth < max_flatten_depth
            if isinstance(value, dict) and may_descend:
                # A non-empty child dict is fully represented by its own entries.
                # An empty one falls through and is stored as a value.
                if _flatten(value, depth=depth + 1, parent=flat_key):
                    continue

            if flat_key in flat_dict:
                raise ValueError("duplicated key '{}'".format(flat_key))
            flat_dict[flat_key] = value

        return has_item

    _flatten(d, depth=1)
    return flat_dict
1729
+
1730
+
1731
def nested_set_dict(d: Dict[str, Any], keys: List[str], value: Any) -> None:
    """Set ``value`` at the nested location described by ``keys``, in place.

    Missing intermediate dictionaries are created on the way down.
    """
    # Touch the sequence first so that an empty ``keys`` fails before any mutation.
    leaf_key = keys[-1]

    current = d
    for parent_key in keys[:-1]:
        current = current.setdefault(parent_key, {})

    current[leaf_key] = value
1742
+
1743
+
1744
def unflatten_dict(d: Dict[str, Any]) -> Dict[str, Any]:
    """Rebuild a nested dictionary from a flat dict keyed by key-path sequences.

    Args:
        d (Dict[str, Any]): The flat dict-like object. Each key is the sequence of
            nested keys (for example a tuple) under which its value is stored.

    Returns:
        Dict[str, Any]: The nested dictionary.
    """
    unflattened_dict = {}
    for key_path, value in d.items():
        nested_set_dict(unflattened_dict, key_path, value)

    return unflattened_dict
1757
+
1758
+
1759
def deep_override_dict(
    dict1: Dict[str, Any], dict2: Dict[str, Any], skip_keys: Optional[List[str]] = None
) -> Dict[str, Any]:
    """Overrides any overlapping contents of dict1 with the contents of dict2.

    Both dictionaries are flattened, None values of ``dict1`` are discarded, the
    entries of ``dict2`` are written on top, and the result is unflattened again.

    Args:
        dict1 (Dict[str, Any]): The base dictionary.
        dict2 (Dict[str, Any]): The dictionary whose values take precedence.
        skip_keys (Optional[List[str]]): Top-level keys of ``dict2`` to ignore.

    Returns:
        Dict[str, Any]: The combined nested dictionary, or ``{}`` when nothing is left.
    """
    ignored = [] if skip_keys is None else skip_keys

    combined = {path: val for path, val in flatten_dict(dict1).items() if val is not None}
    overrides = flatten_dict({name: val for name, val in dict2.items() if name not in ignored})
    combined.update(overrides)

    if not combined:
        return {}
    return unflatten_dict(combined)
1773
+
1774
+
1775
+ def _resolve_routing_config(routing_config: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
1776
+ """Resolve Routing Config
1777
+
1778
+ Args:
1779
+ routing_config (Optional[Dict[str, Any]]): The routing config.
1780
+
1781
+ Returns:
1782
+ Optional[Dict[str, Any]]: The resolved routing config.
1783
+
1784
+ Raises:
1785
+ ValueError: If the RoutingStrategy is invalid.
1786
+ """
1787
+
1788
+ if routing_config:
1789
+ routing_strategy = routing_config.get("RoutingStrategy", None)
1790
+ if routing_strategy:
1791
+ if isinstance(routing_strategy, RoutingStrategy):
1792
+ return {"RoutingStrategy": routing_strategy.name}
1793
+ if isinstance(routing_strategy, str) and (
1794
+ routing_strategy.upper() == RoutingStrategy.RANDOM.name
1795
+ or routing_strategy.upper() == RoutingStrategy.LEAST_OUTSTANDING_REQUESTS.name
1796
+ ):
1797
+ return {"RoutingStrategy": routing_strategy.upper()}
1798
+ raise ValueError(
1799
+ "RoutingStrategy must be either RoutingStrategy.RANDOM "
1800
+ "or RoutingStrategy.LEAST_OUTSTANDING_REQUESTS"
1801
+ )
1802
+ return None
1803
+
1804
+
1805
@lru_cache
def get_instance_rate_per_hour(
    instance_type: str,
    region: str,
) -> Optional[Dict[str, str]]:
    """Gets instance rate per hour for the given instance type.

    Results are cached per ``(instance_type, region)`` pair by ``lru_cache``.

    Args:
        instance_type (str): The instance type.
        region (str): The region.
    Returns:
        Optional[Dict[str, str]]: Instance rate per hour, as produced by
            ``extract_instance_rate_per_hour``.
            Example: {'name': 'On-demand Instance Rate', 'unit': 'USD/Hr', 'value': '1.125'}.

    Raises:
        Exception: An exception is raised if
            the IAM role is not authorized to perform pricing:GetProducts.
            or unexpected event happened (including no rate being found).
    """
    # Choose the region used for the pricing client from the prefix of the
    # requested region.
    if region.startswith(("eu", "af")):
        pricing_region = "eu-central-1"
    elif region.startswith(("ap", "cn")):
        pricing_region = "ap-south-1"
    else:
        pricing_region = "us-east-1"

    term_filters = [
        {"Type": "TERM_MATCH", "Field": "instanceName", "Value": instance_type},
        {"Type": "TERM_MATCH", "Field": "locationType", "Value": "AWS Region"},
        {"Type": "TERM_MATCH", "Field": "regionCode", "Value": region},
    ]
    pricing_client = boto3.client("pricing", region_name=pricing_region)
    response = pricing_client.get_products(ServiceCode="AmazonSageMaker", Filters=term_filters)

    price_list = response.get("PriceList", [])
    if len(price_list) > 0:
        # Only the first entry is inspected; string entries are parsed as JSON first.
        first_entry = price_list[0]
        price_data = json.loads(first_entry) if isinstance(first_entry, str) else first_entry

        rate = extract_instance_rate_per_hour(price_data)
        if rate is not None:
            return rate

    raise Exception(f"Unable to get instance rate per hour for instance type: {instance_type}.")
1850
+
1851
+
1852
def extract_instance_rate_per_hour(price_data: Dict[str, Any]) -> Optional[Dict[str, str]]:
    """Extract instance rate per hour for the given Price JSON data.

    The first currency found under ``terms -> OnDemand -> * -> priceDimensions -> * ->
    pricePerUnit`` is used, and its price is rounded to three decimals.

    Args:
        price_data (Dict[str, Any]): The Price JSON data.
    Returns:
        Optional[Dict[str, str]]: Instance rate per hour with the keys "unit", "value"
            and "name", or None when no price is present.
    """
    if price_data is None:
        return None

    on_demand_terms = price_data.get("terms", {}).get("OnDemand", {})
    for term in on_demand_terms.values():
        for dimension in term.get("priceDimensions", {}).values():
            for currency, amount in dimension.get("pricePerUnit", {}).items():
                if amount is not None:
                    amount = str(round(float(amount), 3))
                return {
                    "unit": f"{currency}/Hr",
                    "value": amount,
                    "name": "On-demand Instance Rate",
                }
    return None
1875
+
1876
+
1877
def camel_case_to_pascal_case(data: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively build a copy of ``data`` with all snake_case keys turned into PascalCase.

    Dictionaries nested inside values and inside lists are converted as well; every
    other value is kept as-is. The input dictionary is not modified.

    Args:
        data (dict): The dictionary to convert.

    Returns:
        dict: A new dictionary with keys in PascalCase.
    """

    def _to_pascal(name):
        """Capitalize each underscore-separated part and join the parts."""
        return "".join(map(str.capitalize, name.split("_")))

    def _convert(node):
        """Convert nested dictionaries and lists; return anything else unchanged."""
        if isinstance(node, dict):
            return camel_case_to_pascal_case(node)
        if isinstance(node, list):
            return [_convert(child) for child in node]
        return node

    return {_to_pascal(name): _convert(value) for name, value in data.items()}
1905
+
1906
+
1907
def tag_exists(tag: TagsDict, curr_tags: Optional[Tags]) -> bool:
    """Returns True if a tag with the same "Key" as ``tag`` is already present.

    Args:
        tag (TagsDict): The tag dictionary.
        curr_tags (Optional[Tags]): The current tags.

    Returns:
        bool: True if the tag exists.
    """
    if curr_tags is None:
        return False
    return any(tag["Key"] == existing["Key"] for existing in curr_tags)
1925
+
1926
+
1927
def _validate_new_tags(new_tags: Optional[Tags], curr_tags: Optional[Tags]) -> Optional[Tags]:
    """Add the new tags whose key is not already present to the current tags.

    Args:
        new_tags (Optional[Tags]): The new tags, as a single tag dict or a list of them.
        curr_tags (Optional[Tags]): The current tags. A list passed here is extended
            in place.

    Returns:
        Optional[Tags]: ``new_tags`` when there are no current tags, otherwise the
        current tags extended with the tags that were missing.
    """
    if curr_tags is None:
        return new_tags

    # A single non-empty tag dict is wrapped so that it can be appended to.
    if curr_tags and isinstance(curr_tags, dict):
        curr_tags = [curr_tags]

    if isinstance(new_tags, dict):
        candidates = [new_tags]
    elif isinstance(new_tags, list):
        candidates = new_tags
    else:
        candidates = []

    for candidate in candidates:
        if not tag_exists(candidate, curr_tags):
            curr_tags.append(candidate)

    return curr_tags
1952
+
1953
+
1954
def remove_tag_with_key(key: str, tags: Optional[Tags]) -> Optional[Tags]:
    """Remove the tags with the given key from the tags.

    Args:
        key (str): The key of the tag to remove.
        tags (Optional[Tags]): The current tags, as a single tag dict or a list of them.

    Returns:
        Optional[Tags]: None when no tag is left (or ``tags`` was None), the tag dict
        itself when exactly one tag is left, otherwise the list of remaining tags.
    """
    if tags is None:
        return tags

    candidates = [tags] if isinstance(tags, dict) else tags
    remaining = [tag for tag in candidates if tag["Key"] != key]

    if not remaining:
        return None
    return remaining[0] if len(remaining) == 1 else remaining
1979
+
1980
+
1981
def get_domain_for_region(region: str) -> str:
    """Returns the domain for the given region.

    Args:
        region (str): AWS region name.

    Returns:
        str: The entry of ALTERNATE_DOMAINS for the region, or "amazonaws.com" when
        the region has no entry.
    """
    default_domain = "amazonaws.com"
    return ALTERNATE_DOMAINS.get(region, default_domain)
1988
+
1989
+
1990
def camel_to_snake(camel_case_string: str) -> str:
    """Converts camelCase to snake_case_string using a regex.

    An underscore is inserted before every upper-case letter that is not at the start
    of the string, then the result is lower-cased.

    This regex cannot handle whitespace ("camelString TwoWords")
    """
    with_underscores = re.sub(r"(?<!^)(?=[A-Z])", "_", camel_case_string)
    return with_underscores.lower()
1996
+
1997
+
1998
def walk_and_apply_json(
    json_obj: Dict[Any, Any], apply, stop_keys: Optional[List[str]] = ["metrics"]
) -> Dict[Any, Any]:
    """Recursively walks a json object and applies a given function to the keys.

    A new object is built; the input is not modified. Dictionaries that are direct
    items of a list are converted too, while lists nested directly inside lists are
    copied over unchanged.

    Args:
        json_obj (Dict[Any, Any]): The json object to walk.
        apply: Function applied to every key that is visited.
        stop_keys (Optional[list[str]]): List of field keys that should stop the
            application function. Any children of these keys will not have the
            application function applied to them. The key is compared after ``apply``
            was applied to it. None or an empty list means no key stops the walk.
            The default list is never mutated by this function.

    Returns:
        Dict[Any, Any]: The converted json object.
    """

    def _walk_and_apply_json(json_obj, new):
        if isinstance(json_obj, dict) and isinstance(new, dict):
            for key, value in json_obj.items():
                new_key = apply(key)
                # An empty stop list behaves like None: nothing stops the walk.
                if not stop_keys or new_key not in stop_keys:
                    if isinstance(value, dict):
                        new[new_key] = {}
                        _walk_and_apply_json(value, new=new[new_key])
                    elif isinstance(value, list):
                        new[new_key] = []
                        for item in value:
                            _walk_and_apply_json(item, new=new[new_key])
                    else:
                        new[new_key] = value
                else:
                    # Children of a stop key are copied over without conversion.
                    new[new_key] = value
        elif isinstance(json_obj, dict) and isinstance(new, list):
            new.append(_walk_and_apply_json(json_obj, new={}))
        elif isinstance(json_obj, list) and isinstance(new, dict):
            new.update(json_obj)
        elif isinstance(new, list):
            # Every other list item (nested lists, strings, numbers, booleans, None)
            # is kept as-is so that no element is silently dropped.
            new.append(json_obj)
        return new

    return _walk_and_apply_json(json_obj, new={})
2034
+
2035
+
2036
+ def _wait_until(callable_fn, poll=5):
2037
+ """Placeholder docstring"""
2038
+ elapsed_time = 0
2039
+ result = None
2040
+ while result is None:
2041
+ try:
2042
+ elapsed_time += poll
2043
+ time.sleep(poll)
2044
+ result = callable_fn()
2045
+ except botocore.exceptions.ClientError as err:
2046
+ # For initial 5 mins we accept/pass AccessDeniedException.
2047
+ # The reason is to await tag propagation to avoid false AccessDenied claims for an
2048
+ # access policy based on resource tags, The caveat here is for true AccessDenied
2049
+ # cases the routine will fail after 5 mins
2050
+ if err.response["Error"]["Code"] == "AccessDeniedException" and elapsed_time <= 300:
2051
+ logger.warning(
2052
+ "Received AccessDeniedException. This could mean the IAM role does not "
2053
+ "have the resource permissions, in which case please add resource access "
2054
+ "and retry. For cases where the role has tag based resource policy, "
2055
+ "continuing to wait for tag propagation.."
2056
+ )
2057
+ continue
2058
+ raise err
2059
+ return result
2060
+
2061
+
2062
+ def _flush_log_streams(
2063
+ stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap
2064
+ ):
2065
+ """Placeholder docstring"""
2066
+ if len(stream_names) < instance_count:
2067
+ # Log streams are created whenever a container starts writing to stdout/err, so this list
2068
+ # may be dynamic until we have a stream for every instance.
2069
+ try:
2070
+ streams = client.describe_log_streams(
2071
+ logGroupName=log_group,
2072
+ logStreamNamePrefix=job_name + "/",
2073
+ orderBy="LogStreamName",
2074
+ limit=min(instance_count, 50),
2075
+ )
2076
+ stream_names = [s["logStreamName"] for s in streams["logStreams"]]
2077
+
2078
+ while "nextToken" in streams:
2079
+ streams = client.describe_log_streams(
2080
+ logGroupName=log_group,
2081
+ logStreamNamePrefix=job_name + "/",
2082
+ orderBy="LogStreamName",
2083
+ limit=50,
2084
+ )
2085
+
2086
+ stream_names.extend([s["logStreamName"] for s in streams["logStreams"]])
2087
+
2088
+ positions.update(
2089
+ [
2090
+ (s, sagemaker.core.logs.Position(timestamp=0, skip=0))
2091
+ for s in stream_names
2092
+ if s not in positions
2093
+ ]
2094
+ )
2095
+ except ClientError as e:
2096
+ # On the very first training job run on an account, there's no log group until
2097
+ # the container starts logging, so ignore any errors thrown about that
2098
+ err = e.response.get("Error", {})
2099
+ if err.get("Code", None) != "ResourceNotFoundException":
2100
+ raise
2101
+
2102
+ if len(stream_names) > 0:
2103
+ if dot:
2104
+ print("")
2105
+ dot = False
2106
+ for idx, event in sagemaker.core.logs.multi_stream_iter(
2107
+ client, log_group, stream_names, positions
2108
+ ):
2109
+ color_wrap(idx, event["message"])
2110
+ ts, count = positions[stream_names[idx]]
2111
+ if event["timestamp"] == ts:
2112
+ positions[stream_names[idx]] = sagemaker.core.logs.Position(
2113
+ timestamp=ts, skip=count + 1
2114
+ )
2115
+ else:
2116
+ positions[stream_names[idx]] = sagemaker.core.logs.Position(
2117
+ timestamp=event["timestamp"], skip=1
2118
+ )
2119
+ else:
2120
+ dot = True
2121
+ print(".", end="")
2122
+ sys.stdout.flush()
2123
+
2124
+
2125
class LogState:
    """Integer constants (1 through 5) naming the states of a log-following loop."""

    STARTING, WAIT_IN_PROGRESS, TAILING, JOB_COMPLETE, COMPLETE = range(1, 6)
2133
+
2134
+
2135
+ _STATUS_CODE_TABLE = {
2136
+ "COMPLETED": "Completed",
2137
+ "INPROGRESS": "InProgress",
2138
+ "IN_PROGRESS": "InProgress",
2139
+ "FAILED": "Failed",
2140
+ "STOPPED": "Stopped",
2141
+ "STOPPING": "Stopping",
2142
+ "STARTING": "Starting",
2143
+ "PENDING": "Pending",
2144
+ }
2145
+
2146
+
2147
def _get_initial_job_state(description, status_key, wait):
    """Choose the ``LogState`` to start from for a described job.

    Args:
        description (dict): Job description containing ``status_key``.
        status_key (str): Key under which the job status is stored.
        wait (bool): Whether the caller wants to wait for the job.

    Returns:
        int: ``LogState.TAILING`` when ``wait`` is truthy and the status is not
        one of ``Completed``, ``Failed`` or ``Stopped``; otherwise
        ``LogState.COMPLETE``.
    """
    already_finished = description[status_key] in ("Completed", "Failed", "Stopped")
    if wait and not already_finished:
        return LogState.TAILING
    return LogState.COMPLETE
2152
+
2153
+
2154
+ def _logs_init(boto_session, description, job):
2155
+ """Placeholder docstring"""
2156
+ if job == "Training":
2157
+ if "InstanceGroups" in description["ResourceConfig"]:
2158
+ instance_count = 0
2159
+ for instanceGroup in description["ResourceConfig"]["InstanceGroups"]:
2160
+ instance_count += instanceGroup["InstanceCount"]
2161
+ else:
2162
+ instance_count = description["ResourceConfig"]["InstanceCount"]
2163
+ elif job == "Transform":
2164
+ instance_count = description["TransformResources"]["InstanceCount"]
2165
+ elif job == "Processing":
2166
+ instance_count = description["ProcessingResources"]["ClusterConfig"]["InstanceCount"]
2167
+ elif job == "AutoML":
2168
+ instance_count = 0
2169
+
2170
+ stream_names = [] # The list of log streams
2171
+ positions = {} # The current position in each stream, map of stream name -> position
2172
+
2173
+ # Increase retries allowed (from default of 4), as we don't want waiting for a training job
2174
+ # to be interrupted by a transient exception.
2175
+ config = botocore.config.Config(retries={"max_attempts": 15})
2176
+ client = boto_session.client("logs", config=config)
2177
+ log_group = "/aws/sagemaker/" + job + "Jobs"
2178
+
2179
+ dot = False
2180
+
2181
+ color_wrap = sagemaker.core.logs.ColorWrap()
2182
+
2183
+ return instance_count, stream_names, positions, client, log_group, dot, color_wrap
2184
+
2185
+
2186
def _check_job_status(job, desc, status_key_name):
    """Verify that a job ended successfully and raise if it did not.

    A ``Completed`` job passes silently and a ``Stopped`` job only logs a
    warning. Any other status raises an exception built from the job's
    ``FailureReason``.

    Args:
        job (str): The name of the job to check.
        desc (dict[str, str]): The job description holding the status and,
            optionally, ``FailureReason``.
        status_key_name (str): Status key name to check for.

    Raises:
        exceptions.CapacityError: If the failure reason contains ``CapacityError``.
        exceptions.UnexpectedStatusException: For any other unsuccessful status.
    """
    raw_status = desc[status_key_name]
    # If the status is capital case, then convert it to Camel case
    status = _STATUS_CODE_TABLE.get(raw_status, raw_status)

    if status == "Completed":
        return

    if status == "Stopped":
        logger.warning(
            "Job ended with status 'Stopped' rather than 'Completed'. "
            "This could mean the job timed out or stopped early for some other reason: "
            "Consider checking whether it completed as you expect."
        )
        return

    reason = desc.get("FailureReason", "(No reason provided)")
    message = (
        "Error for {job_type} {job_name}: {status}. Reason: {reason}. "
        "Check troubleshooting guide for common errors: {troubleshooting}"
    ).format(
        job_type=status_key_name.replace("JobStatus", " job"),
        job_name=job,
        status=status,
        reason=reason,
        troubleshooting=(
            "https://docs.aws.amazon.com/sagemaker/latest/dg/"
            "sagemaker-python-sdk-troubleshooting.html"
        ),
    )

    # Capacity problems get their own exception type; both take the same arguments.
    if "CapacityError" in str(reason):
        error_cls = exceptions.CapacityError
    else:
        error_cls = exceptions.UnexpectedStatusException
    raise error_cls(
        message=message,
        allowed_statuses=["Completed", "Stopped"],
        actual_status=status,
    )
2238
+
2239
+
2240
+ def _create_resource(create_fn):
2241
+ """Call create function and accepts/pass when resource already exists.
2242
+
2243
+ This is a helper function to use an existing resource if found when creating.
2244
+
2245
+ Args:
2246
+ create_fn: Create resource function.
2247
+
2248
+ Returns:
2249
+ (bool): True if new resource was created, False if resource already exists.
2250
+ """
2251
+ try:
2252
+ create_fn()
2253
+ # create function succeeded, resource does not exist already
2254
+ return True
2255
+ except ClientError as ce:
2256
+ error_code = ce.response["Error"]["Code"]
2257
+ error_message = ce.response["Error"]["Message"]
2258
+ already_exists_exceptions = ["ValidationException", "ResourceInUse"]
2259
+ already_exists_msg_patterns = ["Cannot create already existing", "already exists"]
2260
+ if not (
2261
+ error_code in already_exists_exceptions
2262
+ and any(p in error_message for p in already_exists_msg_patterns)
2263
+ ):
2264
+ raise ce
2265
+ # no new resource created as resource already exists
2266
+ return False
2267
+
2268
+
2269
+ def _is_s3_uri(s3_uri: Optional[str]) -> bool:
2270
+ """Checks whether an S3 URI is valid.
2271
+
2272
+ Args:
2273
+ s3_uri (Optional[str]): The S3 URI.
2274
+
2275
+ Returns:
2276
+ bool: Whether the S3 URI is valid.
2277
+ """
2278
+ if s3_uri is None:
2279
+ return False
2280
+
2281
+ return re.match("^s3://([^/]+)/?(.*)$", s3_uri) is not None