sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (363)
  1. sagemaker/core/__init__.py +16 -0
  2. sagemaker/core/_studio.py +116 -0
  3. sagemaker/core/_version.py +11 -0
  4. sagemaker/core/accept_types.py +131 -0
  5. sagemaker/core/analytics.py +744 -0
  6. sagemaker/core/apiutils/__init__.py +13 -0
  7. sagemaker/core/apiutils/_base_types.py +228 -0
  8. sagemaker/core/apiutils/_boto_functions.py +130 -0
  9. sagemaker/core/apiutils/_utils.py +34 -0
  10. sagemaker/core/base_deserializers.py +35 -0
  11. sagemaker/core/base_serializers.py +35 -0
  12. sagemaker/core/clarify/__init__.py +2898 -0
  13. sagemaker/core/collection.py +467 -0
  14. sagemaker/core/common_utils.py +2281 -0
  15. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  16. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  17. sagemaker/core/config/__init__.py +181 -0
  18. sagemaker/core/config/config.py +238 -0
  19. sagemaker/core/config/config_manager.py +595 -0
  20. sagemaker/core/config/config_schema.py +1220 -0
  21. sagemaker/core/config/config_utils.py +297 -0
  22. {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
  23. sagemaker/core/constants.py +73 -0
  24. sagemaker/core/content_types.py +137 -0
  25. sagemaker/core/debugger/__init__.py +39 -0
  26. sagemaker/core/debugger/debugger.py +945 -0
  27. sagemaker/core/debugger/framework_profile.py +292 -0
  28. sagemaker/core/debugger/metrics_config.py +468 -0
  29. sagemaker/core/debugger/profiler.py +42 -0
  30. sagemaker/core/debugger/profiler_config.py +190 -0
  31. sagemaker/core/debugger/profiler_constants.py +40 -0
  32. sagemaker/core/debugger/utils.py +148 -0
  33. sagemaker/core/deprecations.py +254 -0
  34. sagemaker/core/deserializers/__init__.py +10 -0
  35. sagemaker/core/deserializers/base.py +424 -0
  36. sagemaker/core/deserializers/implementations.py +157 -0
  37. sagemaker/core/drift_check_baselines.py +106 -0
  38. sagemaker/core/enums.py +51 -0
  39. sagemaker/core/environment_variables.py +101 -0
  40. sagemaker/core/exceptions.py +108 -0
  41. sagemaker/core/experiments/__init__.py +53 -0
  42. sagemaker/core/experiments/_api_types.py +251 -0
  43. sagemaker/core/experiments/_environment.py +124 -0
  44. sagemaker/core/experiments/_helper.py +294 -0
  45. sagemaker/core/experiments/_metrics.py +333 -0
  46. sagemaker/core/experiments/_run_context.py +58 -0
  47. sagemaker/core/experiments/_utils.py +216 -0
  48. sagemaker/core/experiments/experiment.py +244 -0
  49. sagemaker/core/experiments/run.py +970 -0
  50. sagemaker/core/experiments/trial.py +296 -0
  51. sagemaker/core/experiments/trial_component.py +387 -0
  52. sagemaker/core/explainer/__init__.py +24 -0
  53. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  54. sagemaker/core/explainer/explainer_config.py +44 -0
  55. sagemaker/core/fw_utils.py +1176 -0
  56. sagemaker/core/git_utils.py +349 -0
  57. sagemaker/core/helper/pipeline_variable.py +82 -0
  58. sagemaker/core/helper/session_helper.py +2965 -0
  59. sagemaker/core/huggingface/__init__.py +29 -0
  60. sagemaker/core/huggingface/llm_utils.py +150 -0
  61. sagemaker/core/huggingface/processing.py +139 -0
  62. sagemaker/core/huggingface/training_compiler/config.py +167 -0
  63. sagemaker/core/hyperparameters.py +172 -0
  64. sagemaker/core/image_retriever/__init__.py +3 -0
  65. sagemaker/core/image_retriever/image_retriever.py +640 -0
  66. sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
  67. sagemaker/core/image_retriever/test.py +7 -0
  68. sagemaker/core/image_uri_config/__init__.py +13 -0
  69. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  70. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  71. sagemaker/core/image_uri_config/chainer.json +104 -0
  72. sagemaker/core/image_uri_config/clarify.json +39 -0
  73. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  74. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  75. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  76. sagemaker/core/image_uri_config/debugger.json +34 -0
  77. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  78. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  79. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  80. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  81. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  82. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  83. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  84. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  85. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
  86. sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
  87. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  88. sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
  89. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  90. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  91. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  92. sagemaker/core/image_uri_config/huggingface.json +2138 -0
  93. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  94. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  95. sagemaker/core/image_uri_config/image-classification.json +50 -0
  96. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  97. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  98. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  99. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  100. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  101. sagemaker/core/image_uri_config/kmeans.json +50 -0
  102. sagemaker/core/image_uri_config/knn.json +50 -0
  103. sagemaker/core/image_uri_config/lda.json +26 -0
  104. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  105. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  106. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  107. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  108. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  109. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  110. sagemaker/core/image_uri_config/ntm.json +50 -0
  111. sagemaker/core/image_uri_config/object-detection.json +50 -0
  112. sagemaker/core/image_uri_config/object2vec.json +50 -0
  113. sagemaker/core/image_uri_config/pca.json +50 -0
  114. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  115. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  116. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  117. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  118. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  119. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  120. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  121. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  122. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  123. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  124. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
  125. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  126. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  127. sagemaker/core/image_uri_config/sklearn.json +446 -0
  128. sagemaker/core/image_uri_config/spark.json +280 -0
  129. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  130. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  131. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  132. sagemaker/core/image_uri_config/vw.json +25 -0
  133. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  134. sagemaker/core/image_uri_config/xgboost.json +888 -0
  135. sagemaker/core/image_uris.py +810 -0
  136. sagemaker/core/inference_config.py +144 -0
  137. sagemaker/core/inference_recommender/__init__.py +18 -0
  138. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  139. sagemaker/core/inputs.py +366 -0
  140. sagemaker/core/instance_group.py +61 -0
  141. sagemaker/core/instance_types.py +164 -0
  142. sagemaker/core/instance_types_gpu_info.py +43 -0
  143. sagemaker/core/interactive_apps/__init__.py +41 -0
  144. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  145. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  146. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  147. sagemaker/core/iterators.py +186 -0
  148. sagemaker/core/job.py +380 -0
  149. sagemaker/core/jumpstart/__init__.py +156 -0
  150. sagemaker/core/jumpstart/accessors.py +390 -0
  151. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  152. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  153. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  154. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  155. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  156. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  157. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  158. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  159. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  160. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  161. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  162. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  163. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  164. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  165. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  166. sagemaker/core/jumpstart/cache.py +663 -0
  167. sagemaker/core/jumpstart/configs.py +50 -0
  168. sagemaker/core/jumpstart/constants.py +198 -0
  169. sagemaker/core/jumpstart/deserializers.py +81 -0
  170. sagemaker/core/jumpstart/document.py +76 -0
  171. sagemaker/core/jumpstart/enums.py +168 -0
  172. sagemaker/core/jumpstart/exceptions.py +236 -0
  173. sagemaker/core/jumpstart/factory/utils.py +833 -0
  174. sagemaker/core/jumpstart/filters.py +597 -0
  175. sagemaker/core/jumpstart/hub/__init__.py +0 -0
  176. sagemaker/core/jumpstart/hub/constants.py +16 -0
  177. sagemaker/core/jumpstart/hub/hub.py +291 -0
  178. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  179. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  180. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  181. sagemaker/core/jumpstart/hub/types.py +35 -0
  182. sagemaker/core/jumpstart/hub/utils.py +260 -0
  183. sagemaker/core/jumpstart/models.py +499 -0
  184. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  185. sagemaker/core/jumpstart/parameters.py +20 -0
  186. sagemaker/core/jumpstart/payload_utils.py +239 -0
  187. sagemaker/core/jumpstart/region_config.json +163 -0
  188. sagemaker/core/jumpstart/search.py +171 -0
  189. sagemaker/core/jumpstart/serializers.py +81 -0
  190. sagemaker/core/jumpstart/session_utils.py +234 -0
  191. sagemaker/core/jumpstart/types.py +3044 -0
  192. sagemaker/core/jumpstart/utils.py +1731 -0
  193. sagemaker/core/jumpstart/validators.py +257 -0
  194. sagemaker/core/lambda_helper.py +312 -0
  195. sagemaker/core/lineage/__init__.py +42 -0
  196. sagemaker/core/lineage/_api_types.py +239 -0
  197. sagemaker/core/lineage/_utils.py +49 -0
  198. sagemaker/core/lineage/action.py +345 -0
  199. sagemaker/core/lineage/artifact.py +646 -0
  200. sagemaker/core/lineage/association.py +190 -0
  201. sagemaker/core/lineage/context.py +505 -0
  202. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  203. sagemaker/core/lineage/query.py +732 -0
  204. sagemaker/core/lineage/visualizer.py +346 -0
  205. sagemaker/core/local/__init__.py +18 -0
  206. sagemaker/core/local/data.py +413 -0
  207. sagemaker/core/local/entities.py +678 -0
  208. sagemaker/core/local/exceptions.py +17 -0
  209. sagemaker/core/local/image.py +1243 -0
  210. sagemaker/core/local/local_session.py +739 -0
  211. sagemaker/core/local/utils.py +245 -0
  212. sagemaker/core/logs.py +181 -0
  213. sagemaker/core/metadata_properties.py +56 -0
  214. sagemaker/core/metric_definitions.py +91 -0
  215. sagemaker/core/mlflow/__init__.py +38 -0
  216. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  217. sagemaker/core/model_card/__init__.py +26 -0
  218. sagemaker/core/model_life_cycle.py +51 -0
  219. sagemaker/core/model_metrics.py +160 -0
  220. sagemaker/core/model_monitor/__init__.py +66 -0
  221. sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
  222. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  223. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  224. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  225. sagemaker/core/model_monitor/dataset_format.py +102 -0
  226. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  227. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  228. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  229. sagemaker/core/model_monitor/utils.py +793 -0
  230. sagemaker/core/model_registry.py +480 -0
  231. sagemaker/core/model_uris.py +97 -0
  232. sagemaker/core/modules/__init__.py +19 -0
  233. sagemaker/core/modules/configs.py +226 -0
  234. sagemaker/core/modules/constants.py +37 -0
  235. sagemaker/core/modules/distributed.py +182 -0
  236. sagemaker/core/modules/local_core/__init__.py +0 -0
  237. sagemaker/core/modules/local_core/local_container.py +605 -0
  238. sagemaker/core/modules/templates.py +83 -0
  239. sagemaker/core/modules/train/__init__.py +14 -0
  240. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  241. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  242. sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
  243. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  244. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  245. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  246. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  247. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  248. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  249. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  250. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  251. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  252. sagemaker/core/modules/types.py +19 -0
  253. sagemaker/core/modules/utils.py +194 -0
  254. sagemaker/core/network.py +185 -0
  255. sagemaker/core/parameter.py +173 -0
  256. sagemaker/core/payloads.py +185 -0
  257. sagemaker/core/processing.py +1597 -0
  258. sagemaker/core/remote_function/__init__.py +19 -0
  259. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  260. sagemaker/core/remote_function/client.py +1285 -0
  261. sagemaker/core/remote_function/core/__init__.py +0 -0
  262. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  263. sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
  264. sagemaker/core/remote_function/core/serialization.py +422 -0
  265. sagemaker/core/remote_function/core/stored_function.py +226 -0
  266. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  267. sagemaker/core/remote_function/errors.py +104 -0
  268. sagemaker/core/remote_function/invoke_function.py +172 -0
  269. sagemaker/core/remote_function/job.py +2140 -0
  270. sagemaker/core/remote_function/logging_config.py +38 -0
  271. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  272. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  273. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  274. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  275. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  276. sagemaker/core/remote_function/spark_config.py +149 -0
  277. sagemaker/core/resource_requirements.py +168 -0
  278. {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
  279. sagemaker/core/s3/__init__.py +41 -0
  280. sagemaker/core/s3/client.py +367 -0
  281. sagemaker/core/s3/utils.py +175 -0
  282. sagemaker/core/script_uris.py +93 -0
  283. sagemaker/core/serializers/__init__.py +11 -0
  284. sagemaker/core/serializers/base.py +510 -0
  285. sagemaker/core/serializers/implementations.py +159 -0
  286. sagemaker/core/serializers/utils.py +223 -0
  287. sagemaker/core/serverless_inference_config.py +63 -0
  288. sagemaker/core/session_settings.py +55 -0
  289. sagemaker/core/shapes/__init__.py +3 -0
  290. sagemaker/core/shapes/model_card_shapes.py +159 -0
  291. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
  292. sagemaker/core/spark/__init__.py +16 -0
  293. sagemaker/core/spark/defaults.py +16 -0
  294. sagemaker/core/spark/processing.py +1380 -0
  295. sagemaker/core/telemetry/__init__.py +23 -0
  296. sagemaker/core/telemetry/constants.py +84 -0
  297. sagemaker/core/telemetry/telemetry_logging.py +284 -0
  298. sagemaker/core/tools/__init__.py +1 -0
  299. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  300. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  301. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  302. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  303. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  304. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  305. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  306. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  307. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  308. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  309. sagemaker/core/training/__init__.py +14 -0
  310. sagemaker/core/training/configs.py +333 -0
  311. sagemaker/core/training/constants.py +37 -0
  312. sagemaker/core/training/utils.py +77 -0
  313. sagemaker/core/training_compiler/__init__.py +16 -0
  314. sagemaker/core/training_compiler/config.py +197 -0
  315. sagemaker/core/training_compiler_config.py +197 -0
  316. sagemaker/core/transformer.py +793 -0
  317. sagemaker/core/user_agent.py +76 -0
  318. sagemaker/core/utilities/__init__.py +24 -0
  319. sagemaker/core/utilities/cache.py +169 -0
  320. sagemaker/core/utilities/search_expression.py +133 -0
  321. sagemaker/core/utils/__init__.py +48 -0
  322. sagemaker/core/utils/code_injection/__init__.py +0 -0
  323. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  324. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
  325. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  326. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  327. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  328. {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
  329. sagemaker/core/workflow/__init__.py +152 -0
  330. sagemaker/core/workflow/conditions.py +313 -0
  331. sagemaker/core/workflow/entities.py +58 -0
  332. sagemaker/core/workflow/execution_variables.py +89 -0
  333. sagemaker/core/workflow/functions.py +193 -0
  334. sagemaker/core/workflow/parameters.py +222 -0
  335. sagemaker/core/workflow/pipeline_context.py +394 -0
  336. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  337. sagemaker/core/workflow/properties.py +285 -0
  338. sagemaker/core/workflow/step_outputs.py +65 -0
  339. sagemaker/core/workflow/utilities.py +507 -0
  340. sagemaker/lineage/__init__.py +33 -0
  341. sagemaker/lineage/action.py +28 -0
  342. sagemaker/lineage/artifact.py +28 -0
  343. sagemaker/lineage/context.py +28 -0
  344. sagemaker/lineage/lineage_trial_component.py +28 -0
  345. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
  346. sagemaker_core-2.1.1.dist-info/RECORD +355 -0
  347. sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
  348. sagemaker_core/__init__.py +0 -4
  349. sagemaker_core/_version.py +0 -3
  350. sagemaker_core/helper/session_helper.py +0 -769
  351. sagemaker_core/resources/__init__.py +0 -1
  352. sagemaker_core/shapes/__init__.py +0 -1
  353. sagemaker_core/tools/__init__.py +0 -1
  354. sagemaker_core-1.0.47.dist-info/RECORD +0 -35
  355. sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
  356. {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
  357. {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
  358. {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  359. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  360. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  361. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  362. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
  363. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,640 @@
1
+ import re
2
+ from typing import Optional
3
+ from graphene.utils.str_converters import to_camel_case
4
+
5
+ from sagemaker.core.inference_config import ServerlessInferenceConfig
6
+ from sagemaker.core.training_compiler.config import TrainingCompilerConfig
7
+ from sagemaker.core.common_utils import _botocore_resolver
8
+ from sagemaker.core.workflow import is_pipeline_variable
9
+ from sagemaker.core.image_retriever.image_retriever_utils import (
10
+ _config_for_framework_and_scope,
11
+ _get_final_image_scope,
12
+ _get_image_tag,
13
+ _get_inference_tool,
14
+ _get_latest_versions,
15
+ _processor,
16
+ _registry_from_region,
17
+ _retrieve_latest_pytorch_training_uri,
18
+ _retrieve_pytorch_uri_inputs_are_all_default,
19
+ _validate_arg,
20
+ _validate_for_suppported_frameworks_and_instance_type,
21
+ _validate_instance_deprecation,
22
+ _validate_py_version_and_set_if_needed,
23
+ _validate_version_and_set_if_needed,
24
+ _version_for_config,
25
+ config_for_framework,
26
+ )
27
+ from sagemaker.core.workflow.utilities import override_pipeline_parameter_var
28
+ from sagemaker.core.config.config_schema import IMAGE_RETRIEVER, MODULES, SAGEMAKER, _simple_path
29
+ from sagemaker.core.config.config_manager import SageMakerConfig
30
+
31
# Template for a complete ECR image URI. It is filled with the registry id,
# the regional ECR hostname, and the repository (optionally with ":<tag>").
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}"

# Framework names used when looking up image configurations.
HUGGING_FACE_FRAMEWORK = "huggingface"
PYTORCH_FRAMEWORK = "pytorch"

# Argument names that may be filled from the SageMaker config file when the
# caller leaves them unset (falsy). Order is preserved from the original list.
CONFIGURABLE_ATTRIBUTES = [
    # Framework / runtime selection.
    "version", "py_version",
    # Compute target.
    "instance_type", "accelerator_type", "image_scope",
    # Image tag composition.
    "container_version", "distributed", "smp", "base_framework_version",
    "training_compiler_config",
    # Model identification.
    "model_id", "model_version",
    # Tooling and serving options.
    "sdk_version", "inference_tool", "serverless_inference_config",
]
52
+
53
+
54
+ class ImageRetriever:
55
@staticmethod
def retrieve_hugging_face_uri(
    region: str,
    version: Optional[str] = None,
    py_version: Optional[str] = None,
    instance_type: Optional[str] = None,
    accelerator_type: Optional[str] = None,
    image_scope: Optional[str] = None,
    container_version: Optional[str] = None,
    distributed: bool = False,
    base_framework_version: Optional[str] = None,
    training_compiler_config: Optional[TrainingCompilerConfig] = None,
    sdk_version: Optional[str] = None,
    inference_tool: Optional[str] = None,
    serverless_inference_config: Optional[ServerlessInferenceConfig] = None,
):
    """Retrieves the ECR URI for the Docker image of HuggingFace models.

    Any argument named in ``CONFIGURABLE_ATTRIBUTES`` that is left unset
    (falsy) is looked up in the SageMaker config. The lookup path is built
    from ``SAGEMAKER``, ``MODULES``, ``IMAGE_RETRIEVER`` and the camel-cased
    argument name. The value is replaced when the config provides one.

    Args:
        region (str): The AWS region.
        version (Optional[str]): The framework or algorithm version. This is required if there is
            more than one supported version for the given framework or algorithm.
        py_version (Optional[str]): The Python version. This is required if there is
            more than one supported Python version for the given framework version.
        instance_type (Optional[str]): The SageMaker instance type. For supported types, see
            https://aws.amazon.com/sagemaker/pricing. This is required if
            there are different images for different processor types.
        accelerator_type (Optional[str]): Elastic Inference accelerator type. For more, see
            https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
        image_scope (Optional[str]): The image type, i.e. what it is used for.
            Valid values: "training", "inference", "inference_graviton", "eia".
            If ``accelerator_type`` is set, ``image_scope`` is ignored.
        container_version (Optional[str]): The version of the Docker image. If omitted and
            ``base_framework_version`` is given, it defaults to ``"cu110-ubuntu18.04"`` for
            native (non Training Compiler) GPU images. A ``container_version`` entry in the
            image configuration always takes precedence. For the list of supported
            container versions, see
            https://github.com/aws/deep-learning-containers/blob/master/available_images.md
            (default: None).
        distributed (bool): If the image is for running distributed training.
        base_framework_version (Optional[str]): The base framework the image is built on, in the
            form ``"pytorch<version>"`` or ``"tensorflow<version>"``. Required in practice,
            because its version suffix becomes part of the image tag (default: None).
        training_compiler_config
            (:class:`~sagemaker.core.training_compiler.config.TrainingCompilerConfig`):
            A configuration class for the SageMaker Training Compiler
            (default: None).
        sdk_version (Optional[str]): The version of python-sdk used in the image retrieval.
            For Neuron repositories it defaults to the latest configured SDK version
            (default: None).
        inference_tool (Optional[str]): The tool used to aid inference.
            Valid values: "neuron", "neuronx", or None (default: None).
        serverless_inference_config
            (:class:`~sagemaker.core.inference_config.ServerlessInferenceConfig`):
            Configuration for a serverless endpoint. Serverless inference has no
            instance type, so this is used to determine the processor type.

    Returns:
        str: The ECR URI for the corresponding SageMaker Docker image.

    Raises:
        ValueError: If ``base_framework_version`` is missing or does not start with
            ``"pytorch"`` or ``"tensorflow"``.
    """
    # Fill unset configurable arguments from the SageMaker config. Assigning
    # into ``locals()`` does NOT rebind a function's local variables, so the
    # resolved values are collected in ``args`` and rebound explicitly below.
    args = dict(locals())
    for name, val in args.items():
        if name in CONFIGURABLE_ATTRIBUTES and not val:
            default_value = SageMakerConfig.resolve_value_from_config(
                config_path=_simple_path(
                    SAGEMAKER, MODULES, IMAGE_RETRIEVER, to_camel_case(name)
                )
            )
            if default_value is not None:
                # Replacing the value of an existing key is safe during iteration.
                args[name] = default_value
    version = args["version"]
    py_version = args["py_version"]
    instance_type = args["instance_type"]
    accelerator_type = args["accelerator_type"]
    image_scope = args["image_scope"]
    container_version = args["container_version"]
    distributed = args["distributed"]
    base_framework_version = args["base_framework_version"]
    training_compiler_config = args["training_compiler_config"]
    sdk_version = args["sdk_version"]
    inference_tool = args["inference_tool"]
    serverless_inference_config = args["serverless_inference_config"]

    if training_compiler_config:
        # Training Compiler images have their own config; the requested scope is used as-is.
        final_image_scope = image_scope
        config = _config_for_framework_and_scope(
            HUGGING_FACE_FRAMEWORK + "-training-compiler", final_image_scope, accelerator_type
        )
    else:
        _framework = HUGGING_FACE_FRAMEWORK
        inference_tool = _get_inference_tool(inference_tool, instance_type)
        if inference_tool in ["neuron", "neuronx"]:
            # Neuron images use a separate config, e.g. "huggingface-neuronx".
            _framework = f"{HUGGING_FACE_FRAMEWORK}-{inference_tool}"
        final_image_scope = _get_final_image_scope(
            HUGGING_FACE_FRAMEWORK, instance_type, image_scope
        )
        _validate_for_suppported_frameworks_and_instance_type(
            HUGGING_FACE_FRAMEWORK, instance_type
        )
        config = _config_for_framework_and_scope(_framework, final_image_scope, accelerator_type)

    # Default the container version for native (non Training Compiler) GPU
    # images. Never override a value supplied by the caller or the config.
    # The serverless config is passed here too, matching the processor lookup below.
    if base_framework_version is not None and not container_version:
        default_processor = _processor(
            instance_type, ["cpu", "gpu"], serverless_inference_config
        )
        if default_processor == "gpu" and not training_compiler_config:
            container_version = "cu110-ubuntu18.04"

    original_version = version
    version = _validate_version_and_set_if_needed(
        version, config, HUGGING_FACE_FRAMEWORK, image_scope
    )
    version_config = config["versions"][_version_for_config(version, config)]

    # Some version configs are nested one level deeper, keyed by the full base
    # framework version (which may be reached through an alias).
    if version_config.get("version_aliases"):
        full_base_framework_version = version_config["version_aliases"].get(
            base_framework_version, base_framework_version
        )
        _validate_arg(full_base_framework_version, list(version_config.keys()), "base framework")
        version_config = version_config.get(full_base_framework_version)

    py_version = _validate_py_version_and_set_if_needed(
        py_version, version_config, HUGGING_FACE_FRAMEWORK
    )
    version_config = version_config.get(py_version) or version_config
    registry = _registry_from_region(region, version_config["registries"])
    endpoint_data = _botocore_resolver().construct_endpoint("ecr", region)
    if region == "il-central-1" and not endpoint_data:
        # Fall back to the standard ECR hostname when the resolver returns nothing.
        endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)}
    hostname = endpoint_data["hostname"]

    repo = version_config["repository"]

    processor = _processor(
        instance_type,
        config.get("processors") or version_config.get("processors"),
        serverless_inference_config,
    )

    # A container version in the .json image config takes precedence.
    if version_config.get("container_version"):
        container_version = version_config["container_version"][processor]

    # Prepend the Neuron SDK version for Trainium training images.
    if repo in ["pytorch-training-neuron", "pytorch-training-neuronx"]:
        if not sdk_version:
            sdk_version = _get_latest_versions(version_config["sdk_versions"])
        container_version = sdk_version + "-" + container_version

    # Split e.g. "pytorch<version>" into the framework name and its version suffix.
    framework_version_pattern = re.compile(r"^(pytorch|tensorflow)(.*)$")
    base_match = framework_version_pattern.match(base_framework_version or "")
    if base_match is None:
        raise ValueError(
            f"Invalid base_framework_version: {base_framework_version!r}. "
            "Expected 'pytorch<version>' or 'tensorflow<version>'."
        )
    pt_or_tf_version = base_match.group(2)

    # The tag uses the version exactly as requested, except for Training
    # Compiler repositories, which use the validated version.
    _version = original_version

    if repo in [
        "huggingface-pytorch-trcomp-training",
        "huggingface-tensorflow-trcomp-training",
    ]:
        _version = version
    if repo in [
        "huggingface-pytorch-inference-neuron",
        "huggingface-pytorch-inference-neuronx",
    ]:
        if not sdk_version:
            sdk_version = _get_latest_versions(version_config["sdk_versions"])
        container_version = sdk_version + "-" + container_version
        # ``version_aliases`` is optional in the framework config.
        top_level_aliases = config.get("version_aliases") or {}
        if top_level_aliases.get(original_version):
            _version = top_level_aliases[original_version]
        base_aliases = (
            config.get("versions", {}).get(_version, {}).get("version_aliases", {})
        )
        if base_aliases.get(base_framework_version):
            pt_or_tf_version = framework_version_pattern.match(
                base_aliases[base_framework_version]
            ).group(2)

    tag_prefix = f"{pt_or_tf_version}-transformers{_version}"

    if repo == f"{HUGGING_FACE_FRAMEWORK}-inference-graviton":
        container_version = f"{container_version}-sagemaker"
    _validate_instance_deprecation(HUGGING_FACE_FRAMEWORK, instance_type, version)

    tag = _get_image_tag(
        container_version,
        distributed,
        final_image_scope,
        HUGGING_FACE_FRAMEWORK,
        inference_tool,
        instance_type,
        processor,
        py_version,
        tag_prefix,
        version,
    )

    if tag:
        repo += ":{}".format(tag)

    return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo)
242
+
243
@staticmethod
def retrieve_jumpstart_uri():
    """Retrieves the container image URI for JumpStart models.

    Not implemented yet. JumpStart metadata is moving from S3-based metadata
    to API-based metadata, and this method will be implemented once those
    utility changes are finished.

    Raises:
        NotImplementedError: Always, until the implementation lands.
    """
    # TODO: Implement once the JumpStart utilities switch to API-based metadata.
    # Raising here, instead of implicitly returning None, keeps callers from
    # silently receiving ``None`` as an image URI.
    raise NotImplementedError(
        "ImageRetriever.retrieve_jumpstart_uri is not implemented yet."
    )
250
+
251
@staticmethod
def retrieve_pytorch_uri(
    region: str,
    version: Optional[str] = None,
    py_version: Optional[str] = None,
    instance_type: Optional[str] = None,
    accelerator_type: Optional[str] = None,
    image_scope: Optional[str] = None,
    container_version: str = None,
    distributed: bool = False,
    smp: bool = False,
    training_compiler_config: TrainingCompilerConfig = None,
    sdk_version: Optional[str] = None,
    inference_tool: Optional[str] = None,
    serverless_inference_config: ServerlessInferenceConfig = None,
):
    """Retrieves the ECR URI for the Docker image of PyTorch models.

    If every argument other than ``region`` is left at its default, the latest
    PyTorch training image URI for ``region`` is returned.

    Args:
        region (str): The AWS region.
        version (Optional[str]): The framework or algorithm version. This is required if there is
            more than one supported version for the given framework or algorithm.
        py_version (Optional[str]): The Python version. This is required if there is
            more than one supported Python version for the given framework version.
        instance_type (Optional[str]): The SageMaker instance type. For supported types, see
            https://aws.amazon.com/sagemaker/pricing. This is required if
            there are different images for different processor types.
        accelerator_type (Optional[str]): Elastic Inference accelerator type. For more, see
            https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
        image_scope (Optional[str]): The image type, i.e. what it is used for.
            Valid values: "training", "inference", "inference_graviton", "eia".
            If ``accelerator_type`` is set, ``image_scope`` is ignored.
        container_version (Optional[str]): the version of docker image.
            Ideally the value of parameter should be created inside the framework.
            For custom use, see the list of supported container versions:
            https://github.com/aws/deep-learning-containers/blob/master/available_images.md
            (default: None). Overridden when ``distributed`` and ``smp`` are both set,
            or when the version config pins a container version for the processor.
        distributed (bool): If the image is for running distributed training.
        smp (bool): If using the SMP library for distributed training.
        training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
            A configuration class for the SageMaker Training Compiler
            (default: None).
        sdk_version (Optional[str]): the version of python-sdk that will be used in the image
            retrieval. For the Neuron training repositories, the latest listed SDK version
            is used when this is not set (default: None).
        inference_tool (Optional[str]): the tool that will be used to aid in the inference.
            Valid values: "neuron, neuronx, None"
            (default: None).
        serverless_inference_config (sagemaker.serve.serverless.ServerlessInferenceConfig):
            Specifies configuration related to serverless endpoint. Instance type is
            not provided in serverless inference. So this is used to determine processor type.

    Returns:
        str: The ECR URI for the corresponding SageMaker Docker image.
    """
    if _retrieve_pytorch_uri_inputs_are_all_default(
        version,
        py_version,
        instance_type,
        accelerator_type,
        image_scope,
        container_version,
        distributed,
        smp,
        training_compiler_config,
        sdk_version,
        inference_tool,
        serverless_inference_config,
    ):
        # if all inputs are not set, return the latest pytorch training uri
        return _retrieve_latest_pytorch_training_uri(region)

    framework = PYTORCH_FRAMEWORK
    if training_compiler_config:
        # Training Compiler images come from their own config; the scope is used as given.
        final_image_scope = image_scope
        config = _config_for_framework_and_scope(
            PYTORCH_FRAMEWORK + "-training-compiler", final_image_scope, accelerator_type
        )
    else:
        _framework = PYTORCH_FRAMEWORK
        inference_tool = _get_inference_tool(inference_tool, instance_type)
        if inference_tool in ["neuron", "neuronx"]:
            # Look the config up under the Neuron-specific framework name,
            # e.g. "pytorch-neuronx".
            _framework = f"{PYTORCH_FRAMEWORK}-{inference_tool}"
        final_image_scope = _get_final_image_scope(
            PYTORCH_FRAMEWORK, instance_type, image_scope
        )
        _validate_for_suppported_frameworks_and_instance_type(PYTORCH_FRAMEWORK, instance_type)
        config = _config_for_framework_and_scope(
            _framework, final_image_scope, accelerator_type
        )

    version = _validate_version_and_set_if_needed(
        version, config, PYTORCH_FRAMEWORK, image_scope
    )
    version_config = config["versions"][_version_for_config(version, config)]

    if distributed and smp:
        # SMP switches the framework name to "pytorch-smp" and derives the CUDA
        # container version from the PyTorch version / instance family.
        framework = "pytorch-smp"
        supported_smp_pt_versions_cu124 = ("2.5",)
        supported_smp_pt_versions_cu121 = ("2.1", "2.2", "2.3", "2.4")
        # NOTE(review): this is substring matching, so e.g. "2.1" also matches "2.10".
        if any(pt_version in version for pt_version in supported_smp_pt_versions_cu124):
            container_version = "cu124"
        elif (instance_type and "p5" in instance_type) or any(
            pt_version in version for pt_version in supported_smp_pt_versions_cu121
        ):
            # ``instance_type`` is Optional; guard it so a missing instance type
            # falls through to the version check instead of raising TypeError.
            container_version = "cu121"
        else:
            container_version = "cu118"

    py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework)
    # Version configs may be further keyed by Python version.
    version_config = version_config.get(py_version) or version_config
    registry = _registry_from_region(region, version_config["registries"])
    endpoint_data = _botocore_resolver().construct_endpoint("ecr", region)
    if region == "il-central-1" and not endpoint_data:
        # Fallback hostname when botocore has no ECR endpoint data for this region.
        endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)}
    hostname = endpoint_data["hostname"]

    repo = version_config["repository"]

    processor = _processor(
        instance_type,
        config.get("processors") or version_config.get("processors"),
        serverless_inference_config,
    )

    # if container version is available in .json file, utilize that
    if version_config.get("container_version"):
        container_version = version_config["container_version"][processor]

    # Append sdk version in case of trainium instances
    if repo in ["pytorch-training-neuron", "pytorch-training-neuronx"]:
        if not sdk_version:
            sdk_version = _get_latest_versions(version_config["sdk_versions"])
        container_version = sdk_version + "-" + container_version

    tag_prefix = version_config.get("tag_prefix", version)

    if repo == f"{framework}-inference-graviton":
        container_version = f"{container_version}-sagemaker"
    _validate_instance_deprecation(framework, instance_type, version)

    tag = _get_image_tag(
        container_version,
        distributed,
        final_image_scope,
        framework,
        inference_tool,
        instance_type,
        processor,
        py_version,
        tag_prefix,
        version,
    )

    if tag:
        repo += ":{}".format(tag)

    return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo)
+
409
@staticmethod
@override_pipeline_parameter_var
def retrieve(
    framework: str,
    region: str,
    version: Optional[str] = None,
    py_version: Optional[str] = None,
    instance_type: Optional[str] = None,
    accelerator_type: Optional[str] = None,
    image_scope: Optional[str] = None,
    container_version: str = None,
    distributed: bool = False,
    smp: bool = False,
    base_framework_version: Optional[str] = None,
    training_compiler_config: TrainingCompilerConfig = None,
    model_id: Optional[str] = None,
    model_version: Optional[str] = None,
    sdk_version: Optional[str] = None,
    inference_tool: Optional[str] = None,
    serverless_inference_config: ServerlessInferenceConfig = None,
):
    """Retrieves the ECR URI for the Docker image matching the given arguments.

    Arguments listed in ``CONFIGURABLE_ATTRIBUTES`` that are left unset (falsy) are
    filled from the SageMaker config under
    ``SageMaker.Modules.ImageRetriever.<CamelCaseName>`` when a value is configured.
    Hugging Face and PyTorch requests are delegated to
    ``retrieve_hugging_face_uri`` and ``retrieve_pytorch_uri``.

    Args:
        framework (str): The name of the framework or algorithm.
        region (str): The AWS region.
        version (Optional[str]): The framework or algorithm version. This is required if there is
            more than one supported version for the given framework or algorithm.
        py_version (Optional[str]): The Python version. This is required if there is
            more than one supported Python version for the given framework version.
        instance_type (Optional[str]): The SageMaker instance type. For supported types, see
            https://aws.amazon.com/sagemaker/pricing. This is required if
            there are different images for different processor types.
        accelerator_type (Optional[str]): Elastic Inference accelerator type. For more, see
            https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
        image_scope (Optional[str]): The image type, i.e. what it is used for.
            Valid values: "training", "inference", "inference_graviton", "eia".
            If ``accelerator_type`` is set, ``image_scope`` is ignored.
        container_version (Optional[str]): the version of docker image.
            Ideally the value of parameter should be created inside the framework.
            For custom use, see the list of supported container versions:
            https://github.com/aws/deep-learning-containers/blob/master/available_images.md
            (default: None).
        distributed (bool): If the image is for running distributed training.
        smp (bool): If using the SMP library for distributed training.
        base_framework_version (Optional[str]): The base version number of PyTorch or Tensorflow.
            Only forwarded for the Hugging Face framework (default: None).
        training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
            A configuration class for the SageMaker Training Compiler
            (default: None).
        model_id (Optional[str]): The JumpStart model ID for which to retrieve the image URI.
            When set, retrieval is delegated to ``ImageRetriever.retrieve_jumpstart_uri``,
            which is not implemented yet (see its TODO) (default: None).
        model_version (Optional[str]): The version of the JumpStart model for which to retrieve
            the image URI. Currently unused (default: None).
        sdk_version (Optional[str]): the version of python-sdk that will be used in the image
            retrieval (default: None).
        inference_tool (Optional[str]): the tool that will be used to aid in the inference.
            Valid values: "neuron, neuronx, None"
            (default: None).
        serverless_inference_config (sagemaker.serve.serverless.ServerlessInferenceConfig):
            Specifies configuration related to serverless endpoint. Instance type is
            not provided in serverless inference. So this is used to determine processor type.

    Returns:
        str: The ECR URI for the corresponding SageMaker Docker image.

    Raises:
        NotImplementedError: If the scope is not supported.
        ValueError: If the combination of arguments specified is not supported or
            any PipelineVariable object is passed in.
    """
    # Snapshot the call arguments (this must be the first statement, so only the
    # parameters are captured). Writing into ``locals()`` does NOT rebind function
    # locals in CPython, so resolved config defaults are stored in ``args`` and
    # explicitly re-bound below.
    args = dict(locals())
    for name, val in list(args.items()):
        if name in CONFIGURABLE_ATTRIBUTES and not val:
            default_value = SageMakerConfig.resolve_value_from_config(
                config_path=_simple_path(
                    SAGEMAKER, MODULES, IMAGE_RETRIEVER, to_camel_case(name)
                )
            )
            if default_value is not None:
                args[name] = default_value

    framework = args["framework"]
    region = args["region"]
    version = args["version"]
    py_version = args["py_version"]
    instance_type = args["instance_type"]
    accelerator_type = args["accelerator_type"]
    image_scope = args["image_scope"]
    container_version = args["container_version"]
    distributed = args["distributed"]
    smp = args["smp"]
    base_framework_version = args["base_framework_version"]
    training_compiler_config = args["training_compiler_config"]
    model_id = args["model_id"]
    model_version = args["model_version"]
    sdk_version = args["sdk_version"]
    inference_tool = args["inference_tool"]
    serverless_inference_config = args["serverless_inference_config"]

    for name, val in args.items():
        if is_pipeline_variable(val):
            raise ValueError(
                "When retrieving the image_uri, the argument %s should not be a pipeline variable "
                "(%s) since pipeline variables are only interpreted in the pipeline execution time."
                % (name, type(val))
            )
    if model_id:
        return ImageRetriever.retrieve_jumpstart_uri()

    if framework == HUGGING_FACE_FRAMEWORK:
        return ImageRetriever.retrieve_hugging_face_uri(
            region=region,
            version=version,
            py_version=py_version,
            instance_type=instance_type,
            accelerator_type=accelerator_type,
            image_scope=image_scope,
            container_version=container_version,
            distributed=distributed,
            base_framework_version=base_framework_version,
            training_compiler_config=training_compiler_config,
            sdk_version=sdk_version,
            inference_tool=inference_tool,
            serverless_inference_config=serverless_inference_config,
        )

    if framework == PYTORCH_FRAMEWORK:
        return ImageRetriever.retrieve_pytorch_uri(
            region=region,
            version=version,
            py_version=py_version,
            instance_type=instance_type,
            accelerator_type=accelerator_type,
            image_scope=image_scope,
            container_version=container_version,
            distributed=distributed,
            smp=smp,
            training_compiler_config=training_compiler_config,
            sdk_version=sdk_version,
            inference_tool=inference_tool,
            serverless_inference_config=serverless_inference_config,
        )

    final_image_scope = _get_final_image_scope(framework, instance_type, image_scope)
    _validate_for_suppported_frameworks_and_instance_type(framework, instance_type)
    config = _config_for_framework_and_scope(framework, final_image_scope, accelerator_type)

    version = _validate_version_and_set_if_needed(version, config, framework, image_scope)
    version_config = config["versions"][_version_for_config(version, config)]

    py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework)
    # Version configs may be further keyed by Python version.
    version_config = version_config.get(py_version) or version_config
    registry = _registry_from_region(region, version_config["registries"])
    endpoint_data = _botocore_resolver().construct_endpoint("ecr", region)
    if region == "il-central-1" and not endpoint_data:
        # Fallback hostname when botocore has no ECR endpoint data for this region.
        endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)}
    hostname = endpoint_data["hostname"]

    repo = version_config["repository"]

    processor = _processor(
        instance_type,
        config.get("processors") or version_config.get("processors"),
        serverless_inference_config,
    )

    # if container version is available in .json file, utilize that
    if version_config.get("container_version"):
        container_version = version_config["container_version"][processor]

    # Append sdk version in case of trainium instances
    if repo in ["pytorch-training-neuron", "pytorch-training-neuronx"]:
        if not sdk_version:
            sdk_version = _get_latest_versions(version_config["sdk_versions"])
        container_version = sdk_version + "-" + container_version

    tag_prefix = version_config.get("tag_prefix", version)

    if repo == f"{framework}-inference-graviton":
        container_version = f"{container_version}-sagemaker"
    _validate_instance_deprecation(framework, instance_type, version)

    tag = _get_image_tag(
        container_version,
        distributed,
        final_image_scope,
        framework,
        inference_tool,
        instance_type,
        processor,
        py_version,
        tag_prefix,
        version,
    )

    if tag:
        repo += ":{}".format(tag)

    return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo)
+
611
@staticmethod
def retrieve_base_python_image_uri(region: str, py_version: str = "310") -> str:
    """Build the ECR image URI of the SageMaker base Python image.

    Args:
        region (str): AWS region whose ECR registry and hostname are used.
        py_version (str): Python version suffix appended to the repository name,
            e.g. "310" (the default).

    Returns:
        str: The image URI, with the repository tagged by the base image version.
    """
    image_version = "1.0"

    ecr_endpoint = _botocore_resolver().construct_endpoint("ecr", region)
    if region == "il-central-1" and not ecr_endpoint:
        # Fallback hostname when botocore has no ECR endpoint data for this region.
        ecr_endpoint = {"hostname": f"ecr.{region}.amazonaws.com"}
    ecr_hostname = ecr_endpoint["hostname"]

    base_python_config = config_for_framework("sagemaker-base-python")
    config_key = _version_for_config(image_version, base_python_config)
    image_config = base_python_config["versions"][config_key]

    account_registry = _registry_from_region(region, image_config["registries"])

    # Repository is "<repository>-<py_version>", tagged with the base image version.
    repository = "-".join((image_config["repository"], py_version))
    tagged_repository = ":".join((repository, image_version))

    return ECR_URI_TEMPLATE.format(
        registry=account_registry,
        hostname=ecr_hostname,
        repository=tagged_repository,
    )