sagemaker-core 1.0.62__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff shows the content differences between two package versions that were publicly released to a supported registry, and is provided for informational purposes only.
Files changed (362)
  1. sagemaker/core/__init__.py +16 -0
  2. sagemaker/core/_studio.py +116 -0
  3. sagemaker/core/_version.py +11 -0
  4. sagemaker/core/accept_types.py +131 -0
  5. sagemaker/core/analytics.py +744 -0
  6. sagemaker/core/apiutils/__init__.py +13 -0
  7. sagemaker/core/apiutils/_base_types.py +228 -0
  8. sagemaker/core/apiutils/_boto_functions.py +130 -0
  9. sagemaker/core/apiutils/_utils.py +34 -0
  10. sagemaker/core/base_deserializers.py +35 -0
  11. sagemaker/core/base_serializers.py +35 -0
  12. sagemaker/core/clarify/__init__.py +2898 -0
  13. sagemaker/core/collection.py +467 -0
  14. sagemaker/core/common_utils.py +2281 -0
  15. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  16. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  17. sagemaker/core/config/__init__.py +181 -0
  18. sagemaker/core/config/config.py +238 -0
  19. sagemaker/core/config/config_manager.py +595 -0
  20. sagemaker/core/config/config_schema.py +1220 -0
  21. sagemaker/core/config/config_utils.py +297 -0
  22. {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
  23. sagemaker/core/constants.py +73 -0
  24. sagemaker/core/content_types.py +137 -0
  25. sagemaker/core/debugger/__init__.py +39 -0
  26. sagemaker/core/debugger/debugger.py +945 -0
  27. sagemaker/core/debugger/framework_profile.py +292 -0
  28. sagemaker/core/debugger/metrics_config.py +468 -0
  29. sagemaker/core/debugger/profiler.py +42 -0
  30. sagemaker/core/debugger/profiler_config.py +190 -0
  31. sagemaker/core/debugger/profiler_constants.py +40 -0
  32. sagemaker/core/debugger/utils.py +148 -0
  33. sagemaker/core/deprecations.py +254 -0
  34. sagemaker/core/deserializers/__init__.py +10 -0
  35. sagemaker/core/deserializers/base.py +424 -0
  36. sagemaker/core/deserializers/implementations.py +157 -0
  37. sagemaker/core/drift_check_baselines.py +106 -0
  38. sagemaker/core/enums.py +51 -0
  39. sagemaker/core/environment_variables.py +101 -0
  40. sagemaker/core/exceptions.py +108 -0
  41. sagemaker/core/experiments/__init__.py +53 -0
  42. sagemaker/core/experiments/_api_types.py +251 -0
  43. sagemaker/core/experiments/_environment.py +124 -0
  44. sagemaker/core/experiments/_helper.py +294 -0
  45. sagemaker/core/experiments/_metrics.py +333 -0
  46. sagemaker/core/experiments/_run_context.py +58 -0
  47. sagemaker/core/experiments/_utils.py +216 -0
  48. sagemaker/core/experiments/experiment.py +244 -0
  49. sagemaker/core/experiments/run.py +970 -0
  50. sagemaker/core/experiments/trial.py +296 -0
  51. sagemaker/core/experiments/trial_component.py +387 -0
  52. sagemaker/core/explainer/__init__.py +24 -0
  53. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  54. sagemaker/core/explainer/explainer_config.py +44 -0
  55. sagemaker/core/fw_utils.py +1176 -0
  56. sagemaker/core/git_utils.py +349 -0
  57. sagemaker/core/helper/pipeline_variable.py +82 -0
  58. sagemaker/core/helper/session_helper.py +2965 -0
  59. sagemaker/core/huggingface/__init__.py +29 -0
  60. sagemaker/core/huggingface/llm_utils.py +150 -0
  61. sagemaker/core/huggingface/processing.py +139 -0
  62. sagemaker/core/huggingface/training_compiler/config.py +167 -0
  63. sagemaker/core/hyperparameters.py +172 -0
  64. sagemaker/core/image_retriever/__init__.py +3 -0
  65. sagemaker/core/image_retriever/image_retriever.py +640 -0
  66. sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
  67. sagemaker/core/image_retriever/test.py +7 -0
  68. sagemaker/core/image_uri_config/__init__.py +13 -0
  69. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  70. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  71. sagemaker/core/image_uri_config/chainer.json +104 -0
  72. sagemaker/core/image_uri_config/clarify.json +39 -0
  73. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  74. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  75. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  76. sagemaker/core/image_uri_config/debugger.json +34 -0
  77. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  78. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  79. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  80. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  81. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  82. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  83. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  84. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  85. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
  86. sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
  87. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  88. sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
  89. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  90. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  91. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  92. sagemaker/core/image_uri_config/huggingface.json +2138 -0
  93. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  94. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  95. sagemaker/core/image_uri_config/image-classification.json +50 -0
  96. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  97. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  98. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  99. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  100. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  101. sagemaker/core/image_uri_config/kmeans.json +50 -0
  102. sagemaker/core/image_uri_config/knn.json +50 -0
  103. sagemaker/core/image_uri_config/lda.json +26 -0
  104. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  105. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  106. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  107. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  108. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  109. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  110. sagemaker/core/image_uri_config/ntm.json +50 -0
  111. sagemaker/core/image_uri_config/object-detection.json +50 -0
  112. sagemaker/core/image_uri_config/object2vec.json +50 -0
  113. sagemaker/core/image_uri_config/pca.json +50 -0
  114. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  115. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  116. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  117. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  118. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  119. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  120. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  121. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  122. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  123. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  124. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
  125. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  126. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  127. sagemaker/core/image_uri_config/sklearn.json +446 -0
  128. sagemaker/core/image_uri_config/spark.json +280 -0
  129. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  130. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  131. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  132. sagemaker/core/image_uri_config/vw.json +25 -0
  133. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  134. sagemaker/core/image_uri_config/xgboost.json +888 -0
  135. sagemaker/core/image_uris.py +810 -0
  136. sagemaker/core/inference_config.py +144 -0
  137. sagemaker/core/inference_recommender/__init__.py +18 -0
  138. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  139. sagemaker/core/inputs.py +366 -0
  140. sagemaker/core/instance_group.py +61 -0
  141. sagemaker/core/instance_types.py +164 -0
  142. sagemaker/core/instance_types_gpu_info.py +43 -0
  143. sagemaker/core/interactive_apps/__init__.py +41 -0
  144. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  145. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  146. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  147. sagemaker/core/iterators.py +186 -0
  148. sagemaker/core/job.py +380 -0
  149. sagemaker/core/jumpstart/__init__.py +156 -0
  150. sagemaker/core/jumpstart/accessors.py +390 -0
  151. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  152. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  153. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  154. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  155. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  156. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  157. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  158. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  159. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  160. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  161. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  162. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  163. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  164. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  165. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  166. sagemaker/core/jumpstart/cache.py +663 -0
  167. sagemaker/core/jumpstart/configs.py +50 -0
  168. sagemaker/core/jumpstart/constants.py +198 -0
  169. sagemaker/core/jumpstart/deserializers.py +81 -0
  170. sagemaker/core/jumpstart/document.py +76 -0
  171. sagemaker/core/jumpstart/enums.py +168 -0
  172. sagemaker/core/jumpstart/exceptions.py +236 -0
  173. sagemaker/core/jumpstart/factory/utils.py +833 -0
  174. sagemaker/core/jumpstart/filters.py +597 -0
  175. sagemaker/core/jumpstart/hub/constants.py +16 -0
  176. sagemaker/core/jumpstart/hub/hub.py +291 -0
  177. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  178. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  179. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  180. sagemaker/core/jumpstart/hub/types.py +35 -0
  181. sagemaker/core/jumpstart/hub/utils.py +260 -0
  182. sagemaker/core/jumpstart/models.py +499 -0
  183. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  184. sagemaker/core/jumpstart/parameters.py +20 -0
  185. sagemaker/core/jumpstart/payload_utils.py +239 -0
  186. sagemaker/core/jumpstart/region_config.json +163 -0
  187. sagemaker/core/jumpstart/search.py +171 -0
  188. sagemaker/core/jumpstart/serializers.py +81 -0
  189. sagemaker/core/jumpstart/session_utils.py +234 -0
  190. sagemaker/core/jumpstart/types.py +3044 -0
  191. sagemaker/core/jumpstart/utils.py +1731 -0
  192. sagemaker/core/jumpstart/validators.py +257 -0
  193. sagemaker/core/lambda_helper.py +312 -0
  194. sagemaker/core/lineage/__init__.py +42 -0
  195. sagemaker/core/lineage/_api_types.py +239 -0
  196. sagemaker/core/lineage/_utils.py +49 -0
  197. sagemaker/core/lineage/action.py +345 -0
  198. sagemaker/core/lineage/artifact.py +646 -0
  199. sagemaker/core/lineage/association.py +190 -0
  200. sagemaker/core/lineage/context.py +505 -0
  201. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  202. sagemaker/core/lineage/query.py +732 -0
  203. sagemaker/core/lineage/visualizer.py +346 -0
  204. sagemaker/core/local/__init__.py +18 -0
  205. sagemaker/core/local/data.py +413 -0
  206. sagemaker/core/local/entities.py +678 -0
  207. sagemaker/core/local/exceptions.py +17 -0
  208. sagemaker/core/local/image.py +1243 -0
  209. sagemaker/core/local/local_session.py +739 -0
  210. sagemaker/core/local/utils.py +245 -0
  211. sagemaker/core/logs.py +181 -0
  212. sagemaker/core/metadata_properties.py +56 -0
  213. sagemaker/core/metric_definitions.py +91 -0
  214. sagemaker/core/mlflow/__init__.py +38 -0
  215. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  216. sagemaker/core/model_card/__init__.py +26 -0
  217. sagemaker/core/model_life_cycle.py +51 -0
  218. sagemaker/core/model_metrics.py +160 -0
  219. sagemaker/core/model_monitor/__init__.py +66 -0
  220. sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
  221. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  222. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  223. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  224. sagemaker/core/model_monitor/dataset_format.py +102 -0
  225. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  226. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  227. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  228. sagemaker/core/model_monitor/utils.py +793 -0
  229. sagemaker/core/model_registry.py +480 -0
  230. sagemaker/core/model_uris.py +97 -0
  231. sagemaker/core/modules/__init__.py +19 -0
  232. sagemaker/core/modules/configs.py +226 -0
  233. sagemaker/core/modules/constants.py +37 -0
  234. sagemaker/core/modules/distributed.py +182 -0
  235. sagemaker/core/modules/local_core/__init__.py +0 -0
  236. sagemaker/core/modules/local_core/local_container.py +605 -0
  237. sagemaker/core/modules/templates.py +83 -0
  238. sagemaker/core/modules/train/__init__.py +14 -0
  239. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  240. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  241. sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
  242. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  243. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  244. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  245. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  246. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  247. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  248. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  249. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  250. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  251. sagemaker/core/modules/types.py +19 -0
  252. sagemaker/core/modules/utils.py +194 -0
  253. sagemaker/core/network.py +185 -0
  254. sagemaker/core/parameter.py +173 -0
  255. sagemaker/core/payloads.py +185 -0
  256. sagemaker/core/processing.py +1597 -0
  257. sagemaker/core/remote_function/__init__.py +19 -0
  258. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  259. sagemaker/core/remote_function/client.py +1285 -0
  260. sagemaker/core/remote_function/core/__init__.py +0 -0
  261. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  262. sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
  263. sagemaker/core/remote_function/core/serialization.py +422 -0
  264. sagemaker/core/remote_function/core/stored_function.py +226 -0
  265. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  266. sagemaker/core/remote_function/errors.py +104 -0
  267. sagemaker/core/remote_function/invoke_function.py +172 -0
  268. sagemaker/core/remote_function/job.py +2140 -0
  269. sagemaker/core/remote_function/logging_config.py +38 -0
  270. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  271. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  272. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  273. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  274. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  275. sagemaker/core/remote_function/spark_config.py +149 -0
  276. sagemaker/core/resource_requirements.py +168 -0
  277. {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
  278. sagemaker/core/s3/__init__.py +41 -0
  279. sagemaker/core/s3/client.py +367 -0
  280. sagemaker/core/s3/utils.py +175 -0
  281. sagemaker/core/script_uris.py +93 -0
  282. sagemaker/core/serializers/__init__.py +11 -0
  283. sagemaker/core/serializers/base.py +510 -0
  284. sagemaker/core/serializers/implementations.py +159 -0
  285. sagemaker/core/serializers/utils.py +223 -0
  286. sagemaker/core/serverless_inference_config.py +63 -0
  287. sagemaker/core/session_settings.py +55 -0
  288. sagemaker/core/shapes/__init__.py +3 -0
  289. sagemaker/core/shapes/model_card_shapes.py +159 -0
  290. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
  291. sagemaker/core/spark/__init__.py +16 -0
  292. sagemaker/core/spark/defaults.py +16 -0
  293. sagemaker/core/spark/processing.py +1380 -0
  294. sagemaker/core/telemetry/__init__.py +23 -0
  295. sagemaker/core/telemetry/constants.py +84 -0
  296. sagemaker/core/telemetry/telemetry_logging.py +284 -0
  297. sagemaker/core/tools/__init__.py +1 -0
  298. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  299. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  300. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  301. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  302. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  303. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  304. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  305. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  306. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  307. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  308. sagemaker/core/training/__init__.py +14 -0
  309. sagemaker/core/training/configs.py +333 -0
  310. sagemaker/core/training/constants.py +37 -0
  311. sagemaker/core/training/utils.py +77 -0
  312. sagemaker/core/training_compiler/__init__.py +16 -0
  313. sagemaker/core/training_compiler/config.py +197 -0
  314. sagemaker/core/training_compiler_config.py +197 -0
  315. sagemaker/core/transformer.py +793 -0
  316. sagemaker/core/user_agent.py +76 -0
  317. sagemaker/core/utilities/__init__.py +24 -0
  318. sagemaker/core/utilities/cache.py +169 -0
  319. sagemaker/core/utilities/search_expression.py +133 -0
  320. sagemaker/core/utils/__init__.py +48 -0
  321. sagemaker/core/utils/code_injection/__init__.py +0 -0
  322. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  323. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
  324. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  325. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  326. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  327. {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
  328. sagemaker/core/workflow/__init__.py +152 -0
  329. sagemaker/core/workflow/conditions.py +313 -0
  330. sagemaker/core/workflow/entities.py +58 -0
  331. sagemaker/core/workflow/execution_variables.py +89 -0
  332. sagemaker/core/workflow/functions.py +193 -0
  333. sagemaker/core/workflow/parameters.py +222 -0
  334. sagemaker/core/workflow/pipeline_context.py +394 -0
  335. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  336. sagemaker/core/workflow/properties.py +285 -0
  337. sagemaker/core/workflow/step_outputs.py +65 -0
  338. sagemaker/core/workflow/utilities.py +507 -0
  339. sagemaker/lineage/__init__.py +33 -0
  340. sagemaker/lineage/action.py +28 -0
  341. sagemaker/lineage/artifact.py +28 -0
  342. sagemaker/lineage/context.py +28 -0
  343. sagemaker/lineage/lineage_trial_component.py +28 -0
  344. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
  345. sagemaker_core-2.1.1.dist-info/RECORD +355 -0
  346. sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
  347. sagemaker_core/_version.py +0 -3
  348. sagemaker_core/helper/session_helper.py +0 -769
  349. sagemaker_core/resources/__init__.py +0 -1
  350. sagemaker_core/shapes/__init__.py +0 -1
  351. sagemaker_core/tools/__init__.py +0 -1
  352. sagemaker_core-1.0.62.dist-info/RECORD +0 -35
  353. sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
  354. {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
  355. {sagemaker_core/helper → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
  356. {sagemaker_core/main → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  357. {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
  358. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  359. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  360. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  361. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
  362. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
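
Most entries above are new modules, but the renamed paths (for example items 277, 290, and 322 through 327) show the top-level layout moving from sagemaker_core / sagemaker_core.main to sagemaker.core. A minimal, hypothetical import sketch of that move, based only on the path renames listed above; the exact re-exports in the new __init__.py files are not shown in this listing:

    # sagemaker-core 1.0.x layout (old module paths)
    from sagemaker_core.main import resources, shapes

    # sagemaker-core 2.x layout, following the renames in the listing above
    from sagemaker.core import resources, shapes

Only one of the two import blocks is usable in a given environment, depending on which wheel is installed.
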
sagemaker/core/fw_utils.py (new file; entry 55 in the list above)
@@ -0,0 +1,1176 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Utility methods used by framework classes."""
+from __future__ import absolute_import
+
+import json
+import logging
+import os
+import re
+import shutil
+import tempfile
+import time
+from collections import namedtuple
+from typing import Dict, List, Optional, Union
+
+from packaging import version
+
+import sagemaker.core.common_utils as sagemaker_utils
+from sagemaker.core.deprecations import deprecation_warn_base, renamed_kwargs
+from sagemaker.core.instance_group import InstanceGroup
+from sagemaker.core.s3 import s3_path_join
+from sagemaker.core.session_settings import SessionSettings
+from sagemaker.core.workflow import is_pipeline_variable
+from sagemaker.core.helper.pipeline_variable import PipelineVariable
+
+logger = logging.getLogger(__name__)
+
+_TAR_SOURCE_FILENAME = "source.tar.gz"
+
+UploadedCode = namedtuple("UploadedCode", ["s3_prefix", "script_name"])
+"""sagemaker.fw_utils.UploadedCode: An object containing the S3 prefix and script name.
+
+This is for the source code used for the entry point with an ``Estimator``. It can be
+instantiated with positional or keyword arguments.
+"""
+
+PYTHON_2_DEPRECATION_WARNING = (
+    "{latest_supported_version} is the latest version of {framework} that supports "
+    "Python 2. Newer versions of {framework} will only be available for Python 3."
+    "Please set the argument \"py_version='py3'\" to use the Python 3 {framework} image."
+)
+PARAMETER_SERVER_MULTI_GPU_WARNING = (
+    "If you have selected a multi-GPU training instance type "
+    "and also enabled parameter server for distributed training, "
+    "distributed training with the default parameter server configuration will not "
+    "fully leverage all GPU cores; the parameter server will be configured to run "
+    "only one worker per host regardless of the number of GPUs."
+)
+
+DEBUGGER_UNSUPPORTED_REGIONS = (
+    "us-iso-east-1",
+    "us-isob-east-1",
+    "ap-southeast-3",
+    "ap-southeast-4",
+    "eu-south-2",
+    "me-central-1",
+    "ap-south-2",
+    "eu-central-2",
+    "us-gov-east-1",
+)
+PROFILER_UNSUPPORTED_REGIONS = (
+    "us-iso-east-1",
+    "us-isob-east-1",
+    "ap-southeast-3",
+    "ap-southeast-4",
+    "eu-south-2",
+    "me-central-1",
+    "ap-south-2",
+    "eu-central-2",
+    "us-gov-east-1",
+)
+
+SINGLE_GPU_INSTANCE_TYPES = ("ml.p2.xlarge", "ml.p3.2xlarge")
+SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES = (
+    "ml.p3.16xlarge",
+    "ml.p3dn.24xlarge",
+    "ml.p4d.24xlarge",
+    "ml.p4de.24xlarge",
+    "local_gpu",
+)
+SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS = {
+    # tf 2.12 should not be supported: smdataparallel excludes support for tf>=2.12.
+    "tensorflow": [
+        "2.3",
+        "2.3.1",
+        "2.3.2",
+        "2.4",
+        "2.4.1",
+        "2.4.3",
+        "2.5",
+        "2.5.0",
+        "2.5.1",
+        "2.6",
+        "2.6.0",
+        "2.6.2",
+        "2.6.3",
+        "2.7",
+        "2.7.1",
+        "2.8",
+        "2.8.0",
+        "2.9",
+        "2.9.1",
+        "2.9.2",
+        "2.10",
+        "2.10.1",
+        "2.11",
+        "2.11.0",
+    ],
+    "pytorch": [
+        "1.6",
+        "1.6.0",
+        "1.7",
+        "1.7.1",
+        "1.8",
+        "1.8.0",
+        "1.8.1",
+        "1.9",
+        "1.9.0",
+        "1.9.1",
+        "1.10",
+        "1.10.0",
+        "1.10.2",
+        "1.11",
+        "1.11.0",
+        "1.12",
+        "1.12.0",
+        "1.12.1",
+        "1.13.1",
+        "2.0.0",
+        "2.0.1",
+        "2.1.0",
+        "2.1.2",
+        "2.2.0",
+    ],
+}
+
+TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = [
+    "1.13.1",
+    "2.0.0",
+    "2.0.1",
+    "2.1.0",
+    "2.1.2",
+    "2.2.0",
+    "2.3.0",
+    "2.3.1",
+    "2.4.1",
+    "2.5.1",
+]
+
+TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
+TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS = [
+    "1.11",
+    "1.11.0",
+    "1.12",
+    "1.12.0",
+    "1.12.1",
+    "1.13.1",
+    "2.0.0",
+]
+
+SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]
+
+
+GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY = [
+    "m6g",
+    "m6gd",
+    "c6g",
+    "c6gd",
+    "c6gn",
+    "c7g",
+    "r6g",
+    "r6gd",
+]
+
+
+GRAVITON_ALLOWED_FRAMEWORKS = set(["tensorflow", "pytorch", "xgboost", "sklearn"])
+
+
+def validate_source_dir(script, directory):
+    """Validate that the source directory exists and it contains the user script.
+
+    Args:
+        script (str): Script filename.
+        directory (str): Directory containing the source file.
+    Raises:
+        ValueError: If ``directory`` does not exist, is not a directory, or does
+            not contain ``script``.
+    """
+    if directory:
+        if not os.path.isfile(os.path.join(directory, script)):
+            raise ValueError(
+                'No file named "{}" was found in directory "{}".'.format(script, directory)
+            )
+
+    return True
+
+
+def validate_source_code_input_against_pipeline_variables(
+    entry_point: Optional[Union[str, PipelineVariable]] = None,
+    source_dir: Optional[Union[str, PipelineVariable]] = None,
+    git_config: Optional[Dict[str, str]] = None,
+    enable_network_isolation: Union[bool, PipelineVariable] = False,
+):
+    """Validate source code input against pipeline variables.
+
+    Args:
+        entry_point (str or PipelineVariable): The path to the local Python source file that
+            should be executed as the entry point to training (default: None).
+        source_dir (str or PipelineVariable): The Path to a directory with any other
+            training source code dependencies aside from the entry point file (default: None).
+        git_config (Dict[str, str]): Git configurations used for cloning files (default: None).
+        enable_network_isolation (bool or PipelineVariable): Specifies whether container will run
+            in network isolation mode (default: False).
+    """
+    if is_pipeline_variable(enable_network_isolation) or enable_network_isolation is True:
+        if is_pipeline_variable(entry_point) or is_pipeline_variable(source_dir):
+            raise TypeError(
+                "entry_point, source_dir should not be pipeline variables "
+                "when enable_network_isolation is a pipeline variable or it is set to True."
+            )
+    if git_config:
+        if is_pipeline_variable(entry_point) or is_pipeline_variable(source_dir):
+            raise TypeError(
+                "entry_point, source_dir should not be pipeline variables when git_config is given."
+            )
+    if is_pipeline_variable(entry_point):
+        if not source_dir:
+            raise TypeError(
+                "The entry_point should not be a pipeline variable when source_dir is missing."
+            )
+        if not is_pipeline_variable(source_dir) and not source_dir.lower().startswith("s3://"):
+            raise TypeError(
+                "The entry_point should not be a pipeline variable when source_dir is a local path."
+            )
+        logger.warning(
+            "The entry_point is a pipeline variable: %s. During pipeline execution, "
+            "the interpreted value of entry_point has to be a local path in the container "
+            "pointing to a Python source file which is located at the root of source_dir.",
+            type(entry_point),
+        )
+    if is_pipeline_variable(source_dir):
+        logger.warning(
+            "The source_dir is a pipeline variable: %s. During pipeline execution, "
+            "the interpreted value of source_dir has to be an S3 URI and "
+            "must point to a file with name ``sourcedir.tar.gz``",
+            type(source_dir),
+        )
+
+
+def parse_mp_parameters(params):
+    """Parse the model parallelism parameters provided by the user.
+
+    Args:
+        params: a string representing path to an existing config, or
+            a config dict.
+
+    Returns:
+        parsed: a dict of parsed config.
+
+    Raises:
+        ValueError: if params is not a string or a dict, or
+            the config file cannot be parsed as json.
+    """
+    parsed = None
+    if isinstance(params, dict):
+        parsed = params
+    elif os.path.exists(params):
+        try:
+            with open(params, "r") as fp:
+                parsed = json.load(fp)
+        except json.decoder.JSONDecodeError:
+            pass
+    else:
+        raise ValueError(
+            f"Expected a string path to an existing modelparallel config, or a dictionary. "
+            f"Received: {params}."
+        )
+
+    if parsed is None:
+        raise ValueError(f"Cannot parse {params} as a json file.")
+
+    return parsed
+
+
+def get_mp_parameters(distribution):
+    """Get the model parallelism parameters provided by the user.
+
+    Args:
+        distribution: distribution dictionary defined by the user.
+
+    Returns:
+        params: dictionary containing model parallelism parameters
+            used for training.
+    """
+    try:
+        mp_dict = distribution["smdistributed"]["modelparallel"]
+    except KeyError:
+        mp_dict = {}
+    if mp_dict.get("enabled", False) is True:
+        params = mp_dict.get("parameters", {})
+        params = parse_mp_parameters(params)
+        validate_mp_config(params)
+        return params
+    return None
+
+
+def validate_mp_config(config):
+    """Validate the configuration dictionary for model parallelism.
+
+    Args:
+        config (dict): Dictionary holding configuration keys and values.
+
+    Raises:
+        ValueError: If any of the keys have incorrect values.
+    """
+
+    def validate_positive(key):
+        try:
+            if not isinstance(config[key], int) or config[key] < 1:
+                raise ValueError(f"The number of {key} must be a positive integer.")
+        except KeyError:
+            pass
+
+    def validate_in(key, vals):
+        try:
+            if config[key] not in vals:
+                raise ValueError(f"{key} must be a value in: {vals}.")
+        except KeyError:
+            pass
+
+    def validate_bool(keys):
+        validate_in(keys, [True, False])
+
+    validate_in("pipeline", ["simple", "interleaved", "_only_forward"])
+    validate_in("placement_strategy", ["spread", "cluster"])
+    validate_in("optimize", ["speed", "memory"])
+
+    for key in ["microbatches", "partitions", "active_microbatches"]:
+        validate_positive(key)
+
+    for key in [
+        "auto_partition",
+        "contiguous",
+        "load_partition",
+        "horovod",
+        "ddp",
+        "deterministic_server",
+    ]:
+        validate_bool(key)
+
+    if "partition_file" in config and not isinstance(config.get("partition_file"), str):
+        raise ValueError("'partition_file' must be a str.")
+
+    if config.get("auto_partition") is False and "default_partition" not in config:
+        raise ValueError("default_partition must be supplied if auto_partition is set to False!")
+
+    if "default_partition" in config and config["default_partition"] >= config["partitions"]:
+        raise ValueError("default_partition must be less than the number of partitions!")
+
+    if "memory_weight" in config and (
+        config["memory_weight"] > 1.0 or config["memory_weight"] < 0.0
+    ):
+        raise ValueError("memory_weight must be between 0.0 and 1.0!")
+
+    if "ddp_port" in config and "ddp" not in config:
+        raise ValueError("`ddp_port` needs `ddp` to be set as well")
+
+    if "ddp_dist_backend" in config and "ddp" not in config:
+        raise ValueError("`ddp_dist_backend` needs `ddp` to be set as well")
+
+    if "ddp_port" in config:
+        if not isinstance(config["ddp_port"], int) or config["ddp_port"] < 0:
+            value = config["ddp_port"]
+            raise ValueError(f"Invalid port number {value}.")
+
+    if config.get("horovod", False) and config.get("ddp", False):
+        raise ValueError("'ddp' and 'horovod' cannot be simultaneously enabled.")
+
+
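
Editorial note (not part of the diff): a hypothetical modelparallel distribution dict and what the three helpers above do with it. The key names and limits come from validate_mp_config; the values are made up.

    from sagemaker.core.fw_utils import get_mp_parameters

    # Hypothetical user-supplied distribution; "parameters" may also be a path
    # to a JSON file, which parse_mp_parameters would load instead.
    distribution = {
        "smdistributed": {
            "modelparallel": {
                "enabled": True,
                "parameters": {"partitions": 2, "microbatches": 4, "optimize": "speed"},
            }
        }
    }

    # Returns the validated parameters dict; returns None when modelparallel is
    # absent or not enabled. Invalid values (e.g. partitions=0 or
    # optimize="fast") raise ValueError via validate_mp_config.
    params = get_mp_parameters(distribution)
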
+def tar_and_upload_dir(
+    session,
+    bucket,
+    s3_key_prefix,
+    script,
+    directory=None,
+    dependencies=None,
+    kms_key=None,
+    s3_resource=None,
+    settings: Optional[SessionSettings] = None,
+) -> UploadedCode:
+    """Package source files and upload a compress tar file to S3.
+
+    The S3 location will be ``s3://<bucket>/s3_key_prefix/sourcedir.tar.gz``.
+    If directory is an S3 URI, an UploadedCode object will be returned, but
+    nothing will be uploaded to S3 (this allow reuse of code already in S3).
+    If directory is None, the script will be added to the archive at
+    ``./<basename of script>``. If directory is not None, the (recursive) contents
+    of the directory will be added to the archive. directory is treated as the base
+    path of the archive, and the script name is assumed to be a filename or relative path
+    inside the directory.
+
+    Args:
+        session (boto3.Session): Boto session used to access S3.
+        bucket (str): S3 bucket to which the compressed file is uploaded.
+        s3_key_prefix (str): Prefix for the S3 key.
+        script (str): Script filename or path.
+        directory (str): Optional. Directory containing the source file. If it
+            starts with "s3://", no action is taken.
+        dependencies (List[str]): Optional. A list of paths to directories
+            (absolute or relative) containing additional libraries that will be
+            copied into /opt/ml/lib
+        kms_key (str): Optional. KMS key ID used to upload objects to the bucket
+            (default: None).
+        s3_resource (boto3.resource("s3")): Optional. Pre-instantiated Boto3 Resource
+            for S3 connections, can be used to customize the configuration,
+            e.g. set the endpoint URL (default: None).
+        settings (sagemaker.session_settings.SessionSettings): Optional. The settings
+            of the SageMaker ``Session``, can be used to override the default encryption
+            behavior (default: None).
+    Returns:
+        sagemaker.fw_utils.UploadedCode: An object with the S3 bucket and key (S3 prefix) and
+            script name.
+    """
+    if directory and (is_pipeline_variable(directory) or directory.lower().startswith("s3://")):
+        return UploadedCode(s3_prefix=directory, script_name=script)
+
+    script_name = script if directory else os.path.basename(script)
+    dependencies = dependencies or []
+    key = "%s/sourcedir.tar.gz" % s3_key_prefix
+    if (
+        settings is not None
+        and settings.local_download_dir is not None
+        and not (
+            os.path.exists(settings.local_download_dir)
+            and os.path.isdir(settings.local_download_dir)
+        )
+    ):
+        raise ValueError(
+            "Inputted directory for storing newly generated temporary directory does "
+            f"not exist: '{settings.local_download_dir}'"
+        )
+    local_download_dir = None if settings is None else settings.local_download_dir
+    tmp = tempfile.mkdtemp(dir=local_download_dir)
+    encrypt_artifact = True if settings is None else settings.encrypt_repacked_artifacts
+
+    try:
+        source_files = _list_files_to_compress(script, directory) + dependencies
+        tar_file = sagemaker_utils.create_tar_file(
+            source_files, os.path.join(tmp, _TAR_SOURCE_FILENAME)
+        )
+
+        if kms_key:
+            extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key}
+        elif encrypt_artifact:
+            # encrypt the tarball at rest in S3 with the default AWS managed KMS key for S3
+            # see https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutObject.html#API_PutObject_RequestSyntax
+            extra_args = {"ServerSideEncryption": "aws:kms"}
+        else:
+            extra_args = None
+
+        if s3_resource is None:
+            s3_resource = session.resource("s3", region_name=session.region_name)
+        else:
+            logger.debug("Using provided s3_resource")
+
+        s3_resource.Object(bucket, key).upload_file(tar_file, ExtraArgs=extra_args)
+    finally:
+        shutil.rmtree(tmp)
+
+    return UploadedCode(s3_prefix="s3://%s/%s" % (bucket, key), script_name=script_name)
+
+
+def _list_files_to_compress(script, directory):
+    """Placeholder docstring."""
+    if directory is None:
+        return [script]
+
+    basedir = directory if directory else os.path.dirname(script)
+    return [os.path.join(basedir, name) for name in os.listdir(basedir)]
+
+
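
Editorial note (not part of the diff): a rough usage sketch for tar_and_upload_dir. The bucket, prefix, and local paths are placeholders, and the call assumes valid AWS credentials.

    import boto3
    from sagemaker.core.fw_utils import tar_and_upload_dir

    # Packages train.py plus everything under ./src into
    # s3://example-bucket/jobs/demo/source/sourcedir.tar.gz (placeholder names).
    uploaded = tar_and_upload_dir(
        session=boto3.Session(),
        bucket="example-bucket",
        s3_key_prefix="jobs/demo/source",
        script="train.py",
        directory="./src",
    )
    print(uploaded.s3_prefix)    # s3://example-bucket/jobs/demo/source/sourcedir.tar.gz
    print(uploaded.script_name)  # train.py
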
+def framework_name_from_image(image_uri):
+    # noinspection LongLine
+    """Extract the framework and Python version from the image name.
+
+    Args:
+        image_uri (str): Image URI, which should be one of the following forms:
+            legacy:
+            '<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-<fw>-<py_ver>-<device>:<container_version>'
+            legacy:
+            '<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-<fw>-<py_ver>-<device>:<fw_version>-<device>-<py_ver>'
+            current:
+            '<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-<fw>:<fw_version>-<device>-<py_ver>'
+            current:
+            '<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-rl-<fw>:<rl_toolkit><rl_version>-<device>-<py_ver>'
+            current:
+            '<account>.dkr.ecr.<region>.amazonaws.com/<fw>-<image_scope>:<fw_version>-<device>-<py_ver>'
+            current:
+            '<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-xgboost:<fw_version>-<container_version>'
+
+    Returns:
+        tuple: A tuple containing:
+
+            - str: The framework name
+            - str: The Python version
+            - str: The image tag
+            - str: If the TensorFlow image is script mode
+    """
+    sagemaker_pattern = re.compile(sagemaker_utils.ECR_URI_PATTERN)
+    sagemaker_match = sagemaker_pattern.match(image_uri)
+    if sagemaker_match is None:
+        return None, None, None, None
+
+    # extract framework, python version and image tag
+    # We must support both the legacy and current image name format.
+    name_pattern = re.compile(
+        r"""^(?:sagemaker(?:-rl)?-)?
+        (tensorflow|mxnet|chainer|pytorch|pytorch-trcomp|scikit-learn|xgboost
+        |huggingface-tensorflow|huggingface-pytorch
+        |huggingface-tensorflow-trcomp|huggingface-pytorch-trcomp)(?:-)?
+        (scriptmode|training)?
+        :(.*)-(.*?)-(py2|py3\d*)(?:.*)$""",
+        re.VERBOSE,
+    )
+    name_match = name_pattern.match(sagemaker_match.group(9))
+    if name_match is not None:
+        fw, scriptmode, ver, device, py = (
+            name_match.group(1),
+            name_match.group(2),
+            name_match.group(3),
+            name_match.group(4),
+            name_match.group(5),
+        )
+        return fw, py, "{}-{}-{}".format(ver, device, py), scriptmode
+
+    legacy_name_pattern = re.compile(r"^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$")
+    legacy_match = legacy_name_pattern.match(sagemaker_match.group(9))
+    if legacy_match is not None:
+        return (legacy_match.group(1), legacy_match.group(2), legacy_match.group(4), None)
+
+    # sagemaker-xgboost images are tagged with two aliases, e.g.:
+    # 1. Long tag: "315553699071.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.5-1-cpu-py3"
+    # 2. Short tag: "315553699071.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.5-1"
+    # Note 1: Both tags point to the same image
+    # Note 2: Both tags have full GPU capabilities, despite "cpu" delineation in the long tag
+    short_xgboost_tag_pattern = re.compile(r"^sagemaker-(xgboost):(.*)$")
+    short_xgboost_tag_match = short_xgboost_tag_pattern.match(sagemaker_match.group(9))
+    if short_xgboost_tag_match is not None:
+        return (short_xgboost_tag_match.group(1), "py3", short_xgboost_tag_match.group(2), None)
+    return None, None, None, None
+
+
+def framework_version_from_tag(image_tag):
+    """Extract the framework version from the image tag.
+
+    Args:
+        image_tag (str): Image tag, which should take the form
+            '<framework_version>-<device>-<py_version>'
+            '<xgboost_version>-<container_version>'
+
+    Returns:
+        str: The framework version.
+    """
+    tag_pattern = re.compile(r"^(.*)-(cpu|gpu)-(py2|py3\d*)$")
+    tag_match = tag_pattern.match(image_tag)
+    if tag_match is None:
+        short_xgboost_tag_pattern = re.compile(r"^(\d\.\d+\-\d)$")
+        tag_match = short_xgboost_tag_pattern.match(image_tag)
+    return None if tag_match is None else tag_match.group(1)
+
+
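
Editorial note (not part of the diff): an illustrative example of how the two parsers above decompose a DLC-style image URI. The URI is an assumption for illustration; it only needs to match the ECR URI pattern defined in sagemaker.core.common_utils.

    from sagemaker.core.fw_utils import framework_name_from_image, framework_version_from_tag

    uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.13.1-gpu-py39"
    fw, py, tag, scope = framework_name_from_image(uri)
    # fw == "pytorch", py == "py39", tag == "1.13.1-gpu-py39", scope == "training"

    framework_version_from_tag(tag)  # -> "1.13.1"
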
+def model_code_key_prefix(code_location_key_prefix, model_name, image):
+    """Returns the s3 key prefix for uploading code during model deployment.
+
+    The location returned is a potential concatenation of 2 parts
+        1. code_location_key_prefix if it exists
+        2. model_name or a name derived from the image
+    Args:
+        code_location_key_prefix (str): the s3 key prefix from code_location
+        model_name (str): the name of the model
+        image (str): the image from which a default name can be extracted
+
+    Returns:
+        str: the key prefix to be used in uploading code
+    """
+    name_from_image = f"/model_code/{int(time.time())}"
+    if not is_pipeline_variable(image):
+        name_from_image = sagemaker_utils.name_from_image(image)
+    return s3_path_join(code_location_key_prefix, model_name or name_from_image)
+
+
+def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution):
+    """Warn the user about training when it doesn't leverage all the GPU cores.
+
+    Warn the user that training will not fully leverage all the GPU
+    cores if parameter server is enabled and a multi-GPU instance is selected.
+    Distributed training with the default parameter server setup doesn't
+    support multi-GPU instances.
+
+    Args:
+        training_instance_type (str): A string representing the type of training instance selected.
+        distribution (dict): A dictionary with information to enable distributed training.
+            (Defaults to None if distributed training is not enabled.) For example:
+
+            .. code:: python
+
+                {
+                    "parameter_server": {
+                        "enabled": True
+                    }
+                }
+    """
+    if training_instance_type == "local" or distribution is None:
+        return
+    if is_pipeline_variable(training_instance_type):
+        # The training_instance_type is not available in compile time.
+        # Rather, it's given in Pipeline execution time
+        return
+
+    is_multi_gpu_instance = (
+        training_instance_type == "local_gpu"
+        or training_instance_type.split(".")[1].startswith("p")
+    ) and training_instance_type not in SINGLE_GPU_INSTANCE_TYPES
+
+    ps_enabled = "parameter_server" in distribution and distribution["parameter_server"].get(
+        "enabled", False
+    )
+
+    if is_multi_gpu_instance and ps_enabled:
+        logger.warning(PARAMETER_SERVER_MULTI_GPU_WARNING)
+
+
+def profiler_config_deprecation_warning(
+    profiler_config, image_uri, framework_name, framework_version
+):
+    """Deprecation message if framework profiling is specified TF >= 2.12 and PT >= 2.0."""
+    if profiler_config is None or profiler_config.framework_profile_params is None:
+        return
+
+    if framework_name not in ("pytorch", "tensorflow"):
+        return
+
+    if framework_version is None:
+        framework_name, _, image_tag, _ = framework_name_from_image(image_uri)
+
+        if image_tag is not None:
+            framework_version = framework_version_from_tag(image_tag)
+
+    if framework_version is not None:
+        framework_profile_thresh = (
+            version.parse("2.0") if framework_name == "pytorch" else version.parse("2.12")
+        )
+        framework_profile = version.parse(framework_version)
+        if framework_profile >= framework_profile_thresh:
+            deprecation_warn_base(
+                f"Framework profiling is deprecated from\
+                {framework_name} version {framework_version}.\
+                No framework metrics will be collected"
+            )
+
+
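
Editorial note (not part of the diff): a small example of the parameter-server check above. A multi-GPU instance type plus an enabled parameter server logs PARAMETER_SERVER_MULTI_GPU_WARNING, while types listed in SINGLE_GPU_INSTANCE_TYPES do not.

    from sagemaker.core.fw_utils import warn_if_parameter_server_with_multi_gpu

    distribution = {"parameter_server": {"enabled": True}}

    # ml.p3.16xlarge is multi-GPU and not in SINGLE_GPU_INSTANCE_TYPES -> warning logged
    warn_if_parameter_server_with_multi_gpu("ml.p3.16xlarge", distribution)

    # ml.p3.2xlarge is in SINGLE_GPU_INSTANCE_TYPES -> no warning
    warn_if_parameter_server_with_multi_gpu("ml.p3.2xlarge", distribution)
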
+def validate_smdistributed(
+    instance_type, framework_name, framework_version, py_version, distribution, image_uri=None
+):
+    """Check if smdistributed strategy is correctly invoked by the user.
+
+    Currently, two strategies are supported: `dataparallel` or `modelparallel`.
+    Validate if the user requested strategy is supported.
+
+    Currently, only one strategy can be specified at a time. Validate if the user has requested
+    more than one strategy simultaneously.
+
+    Validate if the smdistributed dict arg is syntactically correct.
+
+    Additionally, perform strategy-specific validations.
+
+    Args:
+        instance_type (str): A string representing the type of training instance selected.
+        framework_name (str): A string representing the name of framework selected.
+        framework_version (str): A string representing the framework version selected.
+        py_version (str): A string representing the python version selected.
+            Ex: `py38, py39, py310, py311`
+        distribution (dict): A dictionary with information to enable distributed training.
+            (Defaults to None if distributed training is not enabled.) For example:
+
+            .. code:: python
+
+                {
+                    "smdistributed": {
+                        "dataparallel": {
+                            "enabled": True
+                        }
+                    }
+                }
+        image_uri (str): A string representing a Docker image URI.
+
+    Raises:
+        ValueError: if distribution dictionary isn't correctly formatted or
+            multiple strategies are requested simultaneously or
+            an unsupported strategy is requested or
+            strategy-specific inputs are incorrect/unsupported
+    """
+    if "smdistributed" not in distribution:
+        # Distribution strategy other than smdistributed is selected
+        return
+    if is_pipeline_variable(instance_type) or is_pipeline_variable(image_uri):
+        # The instance_type is not available in compile time.
+        # Rather, it's given in Pipeline execution time
+        return
+
+    # distribution contains smdistributed
+    smdistributed = distribution["smdistributed"]
+    if not isinstance(smdistributed, dict):
+        raise ValueError("smdistributed strategy requires a dictionary")
+
+    if len(smdistributed) > 1:
+        # more than 1 smdistributed strategy requested by the user
+        err_msg = (
+            "Cannot use more than 1 smdistributed strategy. \n"
+            "Choose one of the following supported strategies:"
+            f"{SMDISTRIBUTED_SUPPORTED_STRATEGIES}"
+        )
+        raise ValueError(err_msg)
+
+    # validate if smdistributed strategy is supported
+    # currently this for loop essentially checks for only 1 key
+    for strategy in smdistributed:
+        if strategy not in SMDISTRIBUTED_SUPPORTED_STRATEGIES:
+            err_msg = (
+                f"Invalid smdistributed strategy provided: {strategy} \n"
+                f"Supported strategies: {SMDISTRIBUTED_SUPPORTED_STRATEGIES}"
+            )
+            raise ValueError(err_msg)
+
+    # smdataparallel-specific input validation
+    if "dataparallel" in smdistributed:
+        _validate_smdataparallel_args(
+            instance_type, framework_name, framework_version, py_version, distribution, image_uri
+        )
+
+
+def _validate_smdataparallel_args(
+    instance_type, framework_name, framework_version, py_version, distribution, image_uri=None
+):
+    """Check if request is using unsupported arguments.
+
+    Validate if user specifies a supported instance type, framework version, and python
+    version.
+
+    Args:
+        instance_type (str): A string representing the type of training instance selected. Ex: `ml.p3.16xlarge`
+        framework_name (str): A string representing the name of framework selected. Ex: `tensorflow`
+        framework_version (str): A string representing the framework version selected. Ex: `2.3.1`
+        py_version (str): A string representing the python version selected.
+            Ex: `py38, py39, py310, py311`
+        distribution (dict): A dictionary with information to enable distributed training.
+            (Defaults to None if distributed training is not enabled.) Ex:
+
+            .. code:: python
+
+                {
+                    "smdistributed": {
+                        "dataparallel": {
+                            "enabled": True
+                        }
+                    }
+                }
+        image_uri (str): A string representing a Docker image URI.
+
+    Raises:
+        ValueError: if
+            `py_version` is not python3 or
+            `framework_version` is not in SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSION
+    """
+    smdataparallel_enabled = (
+        distribution.get("smdistributed").get("dataparallel").get("enabled", False)
+    )
+
+    if not smdataparallel_enabled:
+        return
+
+    err_msg = ""
+
+    if not instance_type:
+        err_msg += "Please specify an instance_type for smdataparallel.\n"
+
+    if not image_uri:
+        # ignore framework_version & py_version if image_uri is set
+        # in case image_uri is not set, then both are mandatory
+        supported = SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS[framework_name]
+        if framework_version not in supported:
+            err_msg += (
+                f"Provided framework_version {framework_version} is not supported by"
+                " smdataparallel.\n"
+                f"Please specify one of the supported framework versions: {supported} \n"
+            )
+
+        if "py3" not in py_version:
+            err_msg += (
+                f"Provided py_version {py_version} is not supported by smdataparallel.\n"
+                "Please specify py_version>=py3"
+            )
+
+    if err_msg:
+        raise ValueError(err_msg)
+
+
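
Editorial note (not part of the diff): a sketch of how validate_smdistributed behaves with the tables defined at the top of this module; the instance type and versions below are taken from those tables.

    from sagemaker.core.fw_utils import validate_smdistributed

    # Accepted: a single "dataparallel" strategy with a framework_version and
    # py_version listed in SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS.
    validate_smdistributed(
        instance_type="ml.p4d.24xlarge",
        framework_name="pytorch",
        framework_version="2.2.0",
        py_version="py310",
        distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
    )

    # Rejected: requesting two smdistributed strategies at once raises ValueError.
    try:
        validate_smdistributed(
            instance_type="ml.p4d.24xlarge",
            framework_name="pytorch",
            framework_version="2.2.0",
            py_version="py310",
            distribution={
                "smdistributed": {
                    "dataparallel": {"enabled": True},
                    "modelparallel": {"enabled": True},
                }
            },
        )
    except ValueError as err:
        print(err)
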
+ def validate_distribution(
819
+ distribution: Dict,
820
+ instance_groups: List[InstanceGroup],
821
+ framework_name: str,
822
+ framework_version: str,
823
+ py_version: str,
824
+ image_uri: str,
825
+ kwargs: Dict,
826
+ ) -> Dict:
827
+ """Check if distribution strategy is correctly invoked by the user.
828
+
829
+ Currently, check for `dataparallel`, `modelparallel` and heterogeneous cluster set up.
830
+ Validate if the user requested strategy is supported.
831
+
832
+ Args:
833
+ distribution (dict): A dictionary with information to enable distributed training.
834
+ (Defaults to None if distributed training is not enabled.) For example:
835
+
836
+ .. code:: python
837
+
838
+ {
839
+ "smdistributed": {
840
+ "dataparallel": {
841
+ "enabled": True
842
+ }
843
+ }
844
+ }
845
+ instance_groups ([InstanceGroup]): A list contains instance groups used for training.
846
+ framework_name (str): A string representing the name of framework selected.
847
+ framework_version (str): A string representing the framework version selected.
848
+ py_version (str): A string representing the python version selected.
849
+ Ex: `py38, py39, py310, py311`
850
+ image_uri (str): A string representing a Docker image URI.
851
+ kwargs(dict): Additional kwargs passed to this function
852
+
853
+ Returns:
854
+ distribution(dict): updated dictionary with validated information
855
+ to enable distributed training.
856
+
857
+ Raises:
858
+ ValueError: if distribution dictionary isn't correctly formatted or
859
+ multiple strategies are requested simultaneously or
860
+ an unsupported strategy is requested or
861
+ strategy-specific inputs are incorrect/unsupported or
862
+ heterogeneous cluster set up is incorrect
863
+ """
864
+     validated_distribution = dict(distribution)
+
+     train_instance_groups = validated_distribution.get("instance_groups", [])
+     if instance_groups is None:
+         if len(train_instance_groups) >= 1:
+             # estimator's instance_groups is not defined but
+             # train_instance_groups are specified in distribution
+             raise ValueError("Instance groups not specified in the estimator !")
+     else:
+         if len(train_instance_groups) > len(instance_groups):
+             # train_instance_groups in distribution outnumber the estimator's instance_groups
+             raise ValueError("Train instance groups oversubscribed !")
+         if len(instance_groups) == 1 and len(train_instance_groups) == 0:
+             # if there is just one instance_group and it is not specified in distribution,
+             # set it for the user
+             train_instance_groups = instance_groups
+         elif len(instance_groups) > 1 and len(train_instance_groups) != 1:
+             # currently we just support one train instance group
+             raise ValueError("Distribution should only contain one instance group name !")
+
+     if len(train_instance_groups) != 0:
+         # in this case, we are handling a heterogeneous cluster training job
+         instance_group_names = []
+         for train_instance_group in train_instance_groups:
+             # a future version will support multiple train_instance_groups, so use a loop here
+             if train_instance_group not in instance_groups:
+                 # check that the train instance group belongs to what the user defined
+                 # in the estimator setup
+                 raise ValueError(
+                     f"Invalid training instance group {train_instance_group.instance_group_name} !"
+                 )
+             instance_type = train_instance_group.instance_type
+             validate_distribution_for_instance_type(
+                 instance_type=instance_type,
+                 distribution=validated_distribution,
+             )
+             validate_smdistributed(
+                 instance_type=instance_type,
+                 framework_name=framework_name,
+                 framework_version=framework_version,
+                 py_version=py_version,
+                 distribution=validated_distribution,
+                 image_uri=image_uri,
+             )
+             if framework_name and framework_name == "pytorch":
+                 # We need to validate only for the PyTorch framework
+                 validate_torch_distributed_distribution(
+                     instance_type=instance_type,
+                     distribution=validated_distribution,
+                     framework_version=framework_version,
+                     py_version=py_version,
+                     image_uri=image_uri,
+                     entry_point=kwargs["entry_point"],
+                 )
+             warn_if_parameter_server_with_multi_gpu(
+                 training_instance_type=instance_type, distribution=validated_distribution
+             )
+             # get instance group names
+             instance_group_names.append(train_instance_group.instance_group_name)
+         validated_distribution["instance_groups"] = instance_group_names
+     else:
+         # in this case, we are handling a normal training job (without a heterogeneous cluster)
+         instance_type = renamed_kwargs(
+             "train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
+         )
+         validate_distribution_for_instance_type(
+             instance_type=instance_type,
+             distribution=validated_distribution,
+         )
+         validate_smdistributed(
+             instance_type=instance_type,
+             framework_name=framework_name,
+             framework_version=framework_version,
+             py_version=py_version,
+             distribution=validated_distribution,
+             image_uri=image_uri,
+         )
+         if framework_name and framework_name == "pytorch":
+             # We need to validate only for the PyTorch framework
+             validate_torch_distributed_distribution(
+                 instance_type=instance_type,
+                 distribution=validated_distribution,
+                 framework_version=framework_version,
+                 py_version=py_version,
+                 image_uri=image_uri,
+                 entry_point=kwargs["entry_point"],
+             )
+         warn_if_parameter_server_with_multi_gpu(
+             training_instance_type=instance_type, distribution=validated_distribution
+         )
+     return validated_distribution
+
+
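For orientation, a minimal sketch of the inputs this validator expects for a heterogeneous-cluster job. The stand-in class, group names, and instance types below are illustrative assumptions; in practice the group objects come from the estimator's instance_groups.

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class _FakeInstanceGroup:  # hypothetical stand-in; the validator reads only these attributes
        instance_group_name: str
        instance_type: str

    gpu_group = _FakeInstanceGroup("train_group", "ml.p4d.24xlarge")
    cpu_group = _FakeInstanceGroup("data_group", "ml.c5.18xlarge")

    # The estimator defines both groups; the distribution dict names the single group
    # that should run the distributed workload.
    distribution = {
        "torch_distributed": {"enabled": True},
        "instance_groups": [gpu_group],
    }

    # validate_distribution(...) would check that gpu_group is one of the estimator's
    # instance_groups, run the per-strategy checks against ml.p4d.24xlarge, and then
    # rewrite distribution["instance_groups"] to ["train_group"].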
+ def validate_distribution_for_instance_type(instance_type, distribution):
+     """Check if the provided distribution strategy is supported for the instance_type.
+
+     Args:
+         instance_type (str): A string representing the type of training instance selected.
+         distribution (dict): A dictionary with information to enable distributed training.
+     """
+     err_msg = ""
+     if isinstance(instance_type, str):
+         match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
+         if match and match[1].startswith("trn"):
+             keys = list(distribution.keys())
+             if len(keys) == 0:
+                 return
+             if len(keys) == 1:
+                 distribution_strategy = keys[0]
+                 if distribution_strategy != "torch_distributed":
+                     err_msg += (
+                         f"Provided distribution strategy {distribution_strategy} is not supported"
+                         " for Trainium instances.\n"
+                         "Please specify one of the following supported distribution strategies:"
+                         f" {TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES} \n"
+                     )
+             elif len(keys) > 1:
+                 err_msg += (
+                     "Multiple distribution strategies are not supported for Trainium instances.\n"
+                     "Please specify one of the following supported distribution strategies:"
+                     f" {TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES} "
+                 )
+
+     if err_msg:
+         raise ValueError(err_msg)
+
+
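A short sketch of how the Trainium check above behaves, assuming this helper and its TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES constant are importable; the instance types and strategy dicts are illustrative.

    # torch_distributed is the one strategy the check accepts on Trainium instances.
    validate_distribution_for_instance_type(
        instance_type="ml.trn1.32xlarge",
        distribution={"torch_distributed": {"enabled": True}},
    )

    # Any other single strategy, or more than one strategy, raises ValueError.
    try:
        validate_distribution_for_instance_type(
            instance_type="ml.trn1.32xlarge",
            distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
        )
    except ValueError as err:
        print(err)

    # Non-Trainium instance types pass through this particular check without error.
    validate_distribution_for_instance_type("ml.p4d.24xlarge", {"mpi": {"enabled": True}})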
+ def validate_torch_distributed_distribution(
+     instance_type,
+     distribution,
+     framework_version,
+     py_version,
+     image_uri,
+     entry_point,
+ ):
+     """Check if the torch_distributed distribution strategy is correctly invoked by the user.
+
+     Args:
+         instance_type (str): A string representing the type of training instance selected.
+         distribution (dict): A dictionary with information to enable distributed training.
+             (Defaults to None if distributed training is not enabled.) For example:
+
+             .. code:: python
+
+                 {
+                     "torch_distributed": {
+                         "enabled": True
+                     }
+                 }
+         framework_version (str): A string representing the framework version selected.
+         py_version (str): A string representing the python version selected.
+             Ex: `py38, py39, py310, py311`
+         image_uri (str): A string representing a Docker image URI.
+         entry_point (str or PipelineVariable): The absolute or relative path to the local Python
+             source file that should be executed as the entry point to training.
+
+     Raises:
+         ValueError: if `py_version` is not python3 or
+             `framework_version` is not compatible with the instance type.
+     """
+     torch_distributed_enabled = False
+     if "torch_distributed" in distribution:
+         torch_distributed_enabled = distribution.get("torch_distributed").get("enabled", False)
+     if not torch_distributed_enabled:
+         # a distribution strategy other than torch_distributed is selected
+         return
+
+     err_msg = ""
+
+     if not image_uri:
+         # framework_version and py_version are ignored when image_uri is set;
+         # when image_uri is not set, both are mandatory
+         if "py3" not in py_version:
+             err_msg += (
+                 f"Provided py_version {py_version} is not supported by torch_distributed.\n"
+                 "Please specify py_version>=py3\n"
+             )
+
+         # Check instance and framework_version compatibility
+         if _is_gpu_instance(instance_type):
+             if framework_version not in TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS:
+                 err_msg += (
+                     f"Provided framework_version {framework_version} is not supported by"
+                     f" torch_distributed for instance {instance_type}.\n"
+                     "Please specify one of the supported framework versions:"
+                     f"{TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS} \n"
+                 )
+         elif _is_trainium_instance(instance_type):
+             if framework_version not in TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS:
+                 err_msg += (
+                     f"Provided framework_version {framework_version} is not supported by"
+                     f" torch_distributed for instance {instance_type}.\n"
+                     "Please specify one of the supported framework versions:"
+                     f"{TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS} \n"
+                 )
+         else:
+             err_msg += (
+                 "Currently torch_distributed is supported only for GPU and Trainium instances.\n"
+             )
+
+     # Check entry point type
+     if not entry_point.endswith(".py"):
+         err_msg += (
+             "Unsupported entry point type for the distribution torch_distributed.\n"
+             "Only python programs (*.py) are supported."
+         )
+
+     if err_msg:
+         raise ValueError(err_msg)
+
+
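A sketch of two invocations the checks above should accept. It assumes "2.0.1" appears in TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS; the image URI is a placeholder, not a real repository.

    # Framework-managed image: py_version must be python3 and framework_version must be
    # listed in the supported-versions constant (assumed here for "2.0.1").
    validate_torch_distributed_distribution(
        instance_type="ml.p4d.24xlarge",
        distribution={"torch_distributed": {"enabled": True}},
        framework_version="2.0.1",
        py_version="py310",
        image_uri=None,
        entry_point="train.py",
    )

    # Custom image: the version checks are skipped, but the entry point must still be a *.py file.
    validate_torch_distributed_distribution(
        instance_type="ml.trn1.32xlarge",
        distribution={"torch_distributed": {"enabled": True}},
        framework_version=None,
        py_version=None,
        image_uri="<account>.dkr.ecr.us-west-2.amazonaws.com/my-image:latest",  # placeholder URI
        entry_point="train.py",
    )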
+ def _is_gpu_instance(instance_type):
+     """Returns bool indicating whether instance_type supports GPU.
+
+     Args:
+         instance_type (str): Name of the instance_type to check against.
+
+     Returns:
+         bool: Whether or not the instance_type supports GPU
+     """
+     if isinstance(instance_type, str):
+         match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
+         if match:
+             if match[1].startswith("p") or match[1].startswith("g"):
+                 return True
+         if instance_type == "local_gpu":
+             return True
+     return False
+
+
+ def _is_trainium_instance(instance_type):
+     """Returns bool indicating whether instance_type is a Trainium instance.
+
+     Args:
+         instance_type (str): Name of the instance_type to check against.
+
+     Returns:
+         bool: Whether or not the instance_type is a Trainium instance
+     """
+     if isinstance(instance_type, str):
+         match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
+         if match and match[1].startswith("trn"):
+             return True
+     return False
+
+
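Both helpers key off the same regular expression, which extracts the instance family from the instance type string. A small self-contained illustration (the instance types are examples only):

    import re

    pattern = r"^ml[\._]([a-z\d]+)\.?\w*$"
    for name in ["ml.p4d.24xlarge", "ml.g5.xlarge", "ml.trn1.32xlarge", "ml.c5.18xlarge", "local_gpu"]:
        match = re.match(pattern, name)
        print(name, "->", match[1] if match else None)
    # ml.p4d.24xlarge  -> p4d   (GPU: family starts with "p")
    # ml.g5.xlarge     -> g5    (GPU: family starts with "g")
    # ml.trn1.32xlarge -> trn1  (Trainium: family starts with "trn")
    # ml.c5.18xlarge   -> c5    (neither)
    # local_gpu        -> None  (no match; _is_gpu_instance special-cases it separately)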
+ def python_deprecation_warning(framework, latest_supported_version):
+     """Return the Python 2 deprecation warning message for the given framework."""
+     return PYTHON_2_DEPRECATION_WARNING.format(
+         framework=framework, latest_supported_version=latest_supported_version
+     )
+
+
+ def _region_supports_debugger(region_name):
+     """Returns boolean indicating whether the region supports Amazon SageMaker Debugger.
+
+     Args:
+         region_name (str): Name of the region to check against.
+
+     Returns:
+         bool: Whether or not the region supports Amazon SageMaker Debugger.
+     """
+     return region_name.lower() not in DEBUGGER_UNSUPPORTED_REGIONS
+
+
+ def _region_supports_profiler(region_name):
+     """Returns bool indicating whether region supports Amazon SageMaker Debugger profiling feature.
+
+     Args:
+         region_name (str): Name of the region to check against.
+
+     Returns:
+         bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature.
+     """
+     return region_name.lower() not in PROFILER_UNSUPPORTED_REGIONS
+
+
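A hypothetical caller might use these region checks to decide whether to attach debugger or profiler defaults at all; the settings dicts below are placeholders for illustration, not the library's actual config objects.

    region = "us-west-2"

    # Skip the features entirely in regions where they are unavailable.
    debugger_settings = {"enabled": True} if _region_supports_debugger(region) else None
    profiler_settings = {"enabled": True} if _region_supports_profiler(region) else None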
+ def _instance_type_supports_profiler(instance_type):
+     """Returns bool indicating whether instance_type supports SageMaker Debugger profiling feature.
+
+     Args:
+         instance_type (str): Name of the instance_type to check against.
+
+     Returns:
+         bool: Whether or not the instance_type supports the Amazon SageMaker Debugger
+             profiling feature.
+     """
+     if isinstance(instance_type, str):
+         match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
+         if match and match[1].startswith("trn"):
+             return True
+     return False
+
+
+ def validate_version_or_image_args(framework_version, py_version, image_uri):
+     """Checks if version or image arguments are specified.
+
+     Validates framework and model arguments to enforce version or image specification.
+
+     Args:
+         framework_version (str): The version of the framework.
+         py_version (str): A string representing the python version selected.
+             Ex: `py38, py39, py310, py311`
+         image_uri (str): The URI of the image.
+
+     Raises:
+         ValueError: if `image_uri` is None and either `framework_version` or `py_version` is
+             None.
+     """
+     if (framework_version is None or py_version is None) and image_uri is None:
+         raise ValueError(
+             "framework_version or py_version was None, yet image_uri was also None. "
+             "Either specify both framework_version and py_version, or specify image_uri."
+         )
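A short sketch of the combinations this guard accepts and rejects; the version strings and image URI are placeholders.

    validate_version_or_image_args("2.0.1", "py310", None)               # framework + py version given
    validate_version_or_image_args(None, None, "my-image:latest")        # image URI alone is enough
    validate_version_or_image_args("2.0.1", "py310", "my-image:latest")  # specifying both is also fine

    try:
        validate_version_or_image_args("2.0.1", None, None)  # py_version and image_uri both missing
    except ValueError as err:
        print(err)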