sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (363) hide show
  1. sagemaker/core/__init__.py +16 -0
  2. sagemaker/core/_studio.py +116 -0
  3. sagemaker/core/_version.py +11 -0
  4. sagemaker/core/accept_types.py +131 -0
  5. sagemaker/core/analytics.py +744 -0
  6. sagemaker/core/apiutils/__init__.py +13 -0
  7. sagemaker/core/apiutils/_base_types.py +228 -0
  8. sagemaker/core/apiutils/_boto_functions.py +130 -0
  9. sagemaker/core/apiutils/_utils.py +34 -0
  10. sagemaker/core/base_deserializers.py +35 -0
  11. sagemaker/core/base_serializers.py +35 -0
  12. sagemaker/core/clarify/__init__.py +2898 -0
  13. sagemaker/core/collection.py +467 -0
  14. sagemaker/core/common_utils.py +2281 -0
  15. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  16. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  17. sagemaker/core/config/__init__.py +181 -0
  18. sagemaker/core/config/config.py +238 -0
  19. sagemaker/core/config/config_manager.py +595 -0
  20. sagemaker/core/config/config_schema.py +1220 -0
  21. sagemaker/core/config/config_utils.py +297 -0
  22. {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
  23. sagemaker/core/constants.py +73 -0
  24. sagemaker/core/content_types.py +137 -0
  25. sagemaker/core/debugger/__init__.py +39 -0
  26. sagemaker/core/debugger/debugger.py +945 -0
  27. sagemaker/core/debugger/framework_profile.py +292 -0
  28. sagemaker/core/debugger/metrics_config.py +468 -0
  29. sagemaker/core/debugger/profiler.py +42 -0
  30. sagemaker/core/debugger/profiler_config.py +190 -0
  31. sagemaker/core/debugger/profiler_constants.py +40 -0
  32. sagemaker/core/debugger/utils.py +148 -0
  33. sagemaker/core/deprecations.py +254 -0
  34. sagemaker/core/deserializers/__init__.py +10 -0
  35. sagemaker/core/deserializers/base.py +424 -0
  36. sagemaker/core/deserializers/implementations.py +157 -0
  37. sagemaker/core/drift_check_baselines.py +106 -0
  38. sagemaker/core/enums.py +51 -0
  39. sagemaker/core/environment_variables.py +101 -0
  40. sagemaker/core/exceptions.py +108 -0
  41. sagemaker/core/experiments/__init__.py +53 -0
  42. sagemaker/core/experiments/_api_types.py +251 -0
  43. sagemaker/core/experiments/_environment.py +124 -0
  44. sagemaker/core/experiments/_helper.py +294 -0
  45. sagemaker/core/experiments/_metrics.py +333 -0
  46. sagemaker/core/experiments/_run_context.py +58 -0
  47. sagemaker/core/experiments/_utils.py +216 -0
  48. sagemaker/core/experiments/experiment.py +244 -0
  49. sagemaker/core/experiments/run.py +970 -0
  50. sagemaker/core/experiments/trial.py +296 -0
  51. sagemaker/core/experiments/trial_component.py +387 -0
  52. sagemaker/core/explainer/__init__.py +24 -0
  53. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  54. sagemaker/core/explainer/explainer_config.py +44 -0
  55. sagemaker/core/fw_utils.py +1176 -0
  56. sagemaker/core/git_utils.py +349 -0
  57. sagemaker/core/helper/pipeline_variable.py +82 -0
  58. sagemaker/core/helper/session_helper.py +2965 -0
  59. sagemaker/core/huggingface/__init__.py +29 -0
  60. sagemaker/core/huggingface/llm_utils.py +150 -0
  61. sagemaker/core/huggingface/processing.py +139 -0
  62. sagemaker/core/huggingface/training_compiler/config.py +167 -0
  63. sagemaker/core/hyperparameters.py +172 -0
  64. sagemaker/core/image_retriever/__init__.py +3 -0
  65. sagemaker/core/image_retriever/image_retriever.py +640 -0
  66. sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
  67. sagemaker/core/image_retriever/test.py +7 -0
  68. sagemaker/core/image_uri_config/__init__.py +13 -0
  69. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  70. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  71. sagemaker/core/image_uri_config/chainer.json +104 -0
  72. sagemaker/core/image_uri_config/clarify.json +39 -0
  73. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  74. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  75. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  76. sagemaker/core/image_uri_config/debugger.json +34 -0
  77. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  78. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  79. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  80. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  81. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  82. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  83. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  84. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  85. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
  86. sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
  87. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  88. sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
  89. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  90. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  91. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  92. sagemaker/core/image_uri_config/huggingface.json +2138 -0
  93. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  94. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  95. sagemaker/core/image_uri_config/image-classification.json +50 -0
  96. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  97. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  98. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  99. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  100. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  101. sagemaker/core/image_uri_config/kmeans.json +50 -0
  102. sagemaker/core/image_uri_config/knn.json +50 -0
  103. sagemaker/core/image_uri_config/lda.json +26 -0
  104. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  105. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  106. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  107. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  108. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  109. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  110. sagemaker/core/image_uri_config/ntm.json +50 -0
  111. sagemaker/core/image_uri_config/object-detection.json +50 -0
  112. sagemaker/core/image_uri_config/object2vec.json +50 -0
  113. sagemaker/core/image_uri_config/pca.json +50 -0
  114. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  115. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  116. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  117. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  118. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  119. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  120. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  121. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  122. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  123. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  124. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
  125. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  126. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  127. sagemaker/core/image_uri_config/sklearn.json +446 -0
  128. sagemaker/core/image_uri_config/spark.json +280 -0
  129. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  130. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  131. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  132. sagemaker/core/image_uri_config/vw.json +25 -0
  133. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  134. sagemaker/core/image_uri_config/xgboost.json +888 -0
  135. sagemaker/core/image_uris.py +810 -0
  136. sagemaker/core/inference_config.py +144 -0
  137. sagemaker/core/inference_recommender/__init__.py +18 -0
  138. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  139. sagemaker/core/inputs.py +366 -0
  140. sagemaker/core/instance_group.py +61 -0
  141. sagemaker/core/instance_types.py +164 -0
  142. sagemaker/core/instance_types_gpu_info.py +43 -0
  143. sagemaker/core/interactive_apps/__init__.py +41 -0
  144. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  145. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  146. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  147. sagemaker/core/iterators.py +186 -0
  148. sagemaker/core/job.py +380 -0
  149. sagemaker/core/jumpstart/__init__.py +156 -0
  150. sagemaker/core/jumpstart/accessors.py +390 -0
  151. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  152. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  153. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  154. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  155. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  156. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  157. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  158. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  159. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  160. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  161. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  162. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  163. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  164. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  165. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  166. sagemaker/core/jumpstart/cache.py +663 -0
  167. sagemaker/core/jumpstart/configs.py +50 -0
  168. sagemaker/core/jumpstart/constants.py +198 -0
  169. sagemaker/core/jumpstart/deserializers.py +81 -0
  170. sagemaker/core/jumpstart/document.py +76 -0
  171. sagemaker/core/jumpstart/enums.py +168 -0
  172. sagemaker/core/jumpstart/exceptions.py +236 -0
  173. sagemaker/core/jumpstart/factory/utils.py +833 -0
  174. sagemaker/core/jumpstart/filters.py +597 -0
  175. sagemaker/core/jumpstart/hub/__init__.py +0 -0
  176. sagemaker/core/jumpstart/hub/constants.py +16 -0
  177. sagemaker/core/jumpstart/hub/hub.py +291 -0
  178. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  179. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  180. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  181. sagemaker/core/jumpstart/hub/types.py +35 -0
  182. sagemaker/core/jumpstart/hub/utils.py +260 -0
  183. sagemaker/core/jumpstart/models.py +499 -0
  184. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  185. sagemaker/core/jumpstart/parameters.py +20 -0
  186. sagemaker/core/jumpstart/payload_utils.py +239 -0
  187. sagemaker/core/jumpstart/region_config.json +163 -0
  188. sagemaker/core/jumpstart/search.py +171 -0
  189. sagemaker/core/jumpstart/serializers.py +81 -0
  190. sagemaker/core/jumpstart/session_utils.py +234 -0
  191. sagemaker/core/jumpstart/types.py +3044 -0
  192. sagemaker/core/jumpstart/utils.py +1731 -0
  193. sagemaker/core/jumpstart/validators.py +257 -0
  194. sagemaker/core/lambda_helper.py +312 -0
  195. sagemaker/core/lineage/__init__.py +42 -0
  196. sagemaker/core/lineage/_api_types.py +239 -0
  197. sagemaker/core/lineage/_utils.py +49 -0
  198. sagemaker/core/lineage/action.py +345 -0
  199. sagemaker/core/lineage/artifact.py +646 -0
  200. sagemaker/core/lineage/association.py +190 -0
  201. sagemaker/core/lineage/context.py +505 -0
  202. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  203. sagemaker/core/lineage/query.py +732 -0
  204. sagemaker/core/lineage/visualizer.py +346 -0
  205. sagemaker/core/local/__init__.py +18 -0
  206. sagemaker/core/local/data.py +413 -0
  207. sagemaker/core/local/entities.py +678 -0
  208. sagemaker/core/local/exceptions.py +17 -0
  209. sagemaker/core/local/image.py +1243 -0
  210. sagemaker/core/local/local_session.py +739 -0
  211. sagemaker/core/local/utils.py +245 -0
  212. sagemaker/core/logs.py +181 -0
  213. sagemaker/core/metadata_properties.py +56 -0
  214. sagemaker/core/metric_definitions.py +91 -0
  215. sagemaker/core/mlflow/__init__.py +38 -0
  216. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  217. sagemaker/core/model_card/__init__.py +26 -0
  218. sagemaker/core/model_life_cycle.py +51 -0
  219. sagemaker/core/model_metrics.py +160 -0
  220. sagemaker/core/model_monitor/__init__.py +66 -0
  221. sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
  222. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  223. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  224. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  225. sagemaker/core/model_monitor/dataset_format.py +102 -0
  226. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  227. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  228. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  229. sagemaker/core/model_monitor/utils.py +793 -0
  230. sagemaker/core/model_registry.py +480 -0
  231. sagemaker/core/model_uris.py +97 -0
  232. sagemaker/core/modules/__init__.py +19 -0
  233. sagemaker/core/modules/configs.py +226 -0
  234. sagemaker/core/modules/constants.py +37 -0
  235. sagemaker/core/modules/distributed.py +182 -0
  236. sagemaker/core/modules/local_core/__init__.py +0 -0
  237. sagemaker/core/modules/local_core/local_container.py +605 -0
  238. sagemaker/core/modules/templates.py +83 -0
  239. sagemaker/core/modules/train/__init__.py +14 -0
  240. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  241. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  242. sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
  243. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  244. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  245. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  246. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  247. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  248. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  249. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  250. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  251. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  252. sagemaker/core/modules/types.py +19 -0
  253. sagemaker/core/modules/utils.py +194 -0
  254. sagemaker/core/network.py +185 -0
  255. sagemaker/core/parameter.py +173 -0
  256. sagemaker/core/payloads.py +185 -0
  257. sagemaker/core/processing.py +1597 -0
  258. sagemaker/core/remote_function/__init__.py +19 -0
  259. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  260. sagemaker/core/remote_function/client.py +1285 -0
  261. sagemaker/core/remote_function/core/__init__.py +0 -0
  262. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  263. sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
  264. sagemaker/core/remote_function/core/serialization.py +422 -0
  265. sagemaker/core/remote_function/core/stored_function.py +226 -0
  266. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  267. sagemaker/core/remote_function/errors.py +104 -0
  268. sagemaker/core/remote_function/invoke_function.py +172 -0
  269. sagemaker/core/remote_function/job.py +2140 -0
  270. sagemaker/core/remote_function/logging_config.py +38 -0
  271. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  272. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  273. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  274. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  275. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  276. sagemaker/core/remote_function/spark_config.py +149 -0
  277. sagemaker/core/resource_requirements.py +168 -0
  278. {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
  279. sagemaker/core/s3/__init__.py +41 -0
  280. sagemaker/core/s3/client.py +367 -0
  281. sagemaker/core/s3/utils.py +175 -0
  282. sagemaker/core/script_uris.py +93 -0
  283. sagemaker/core/serializers/__init__.py +11 -0
  284. sagemaker/core/serializers/base.py +510 -0
  285. sagemaker/core/serializers/implementations.py +159 -0
  286. sagemaker/core/serializers/utils.py +223 -0
  287. sagemaker/core/serverless_inference_config.py +63 -0
  288. sagemaker/core/session_settings.py +55 -0
  289. sagemaker/core/shapes/__init__.py +3 -0
  290. sagemaker/core/shapes/model_card_shapes.py +159 -0
  291. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
  292. sagemaker/core/spark/__init__.py +16 -0
  293. sagemaker/core/spark/defaults.py +16 -0
  294. sagemaker/core/spark/processing.py +1380 -0
  295. sagemaker/core/telemetry/__init__.py +23 -0
  296. sagemaker/core/telemetry/constants.py +84 -0
  297. sagemaker/core/telemetry/telemetry_logging.py +284 -0
  298. sagemaker/core/tools/__init__.py +1 -0
  299. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  300. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  301. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  302. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  303. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  304. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  305. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  306. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  307. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  308. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  309. sagemaker/core/training/__init__.py +14 -0
  310. sagemaker/core/training/configs.py +333 -0
  311. sagemaker/core/training/constants.py +37 -0
  312. sagemaker/core/training/utils.py +77 -0
  313. sagemaker/core/training_compiler/__init__.py +16 -0
  314. sagemaker/core/training_compiler/config.py +197 -0
  315. sagemaker/core/training_compiler_config.py +197 -0
  316. sagemaker/core/transformer.py +793 -0
  317. sagemaker/core/user_agent.py +76 -0
  318. sagemaker/core/utilities/__init__.py +24 -0
  319. sagemaker/core/utilities/cache.py +169 -0
  320. sagemaker/core/utilities/search_expression.py +133 -0
  321. sagemaker/core/utils/__init__.py +48 -0
  322. sagemaker/core/utils/code_injection/__init__.py +0 -0
  323. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  324. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
  325. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  326. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  327. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  328. {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
  329. sagemaker/core/workflow/__init__.py +152 -0
  330. sagemaker/core/workflow/conditions.py +313 -0
  331. sagemaker/core/workflow/entities.py +58 -0
  332. sagemaker/core/workflow/execution_variables.py +89 -0
  333. sagemaker/core/workflow/functions.py +193 -0
  334. sagemaker/core/workflow/parameters.py +222 -0
  335. sagemaker/core/workflow/pipeline_context.py +394 -0
  336. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  337. sagemaker/core/workflow/properties.py +285 -0
  338. sagemaker/core/workflow/step_outputs.py +65 -0
  339. sagemaker/core/workflow/utilities.py +507 -0
  340. sagemaker/lineage/__init__.py +33 -0
  341. sagemaker/lineage/action.py +28 -0
  342. sagemaker/lineage/artifact.py +28 -0
  343. sagemaker/lineage/context.py +28 -0
  344. sagemaker/lineage/lineage_trial_component.py +28 -0
  345. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
  346. sagemaker_core-2.1.1.dist-info/RECORD +355 -0
  347. sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
  348. sagemaker_core/__init__.py +0 -4
  349. sagemaker_core/_version.py +0 -3
  350. sagemaker_core/helper/session_helper.py +0 -769
  351. sagemaker_core/resources/__init__.py +0 -1
  352. sagemaker_core/shapes/__init__.py +0 -1
  353. sagemaker_core/tools/__init__.py +0 -1
  354. sagemaker_core-1.0.47.dist-info/RECORD +0 -35
  355. sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
  356. {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
  357. {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
  358. {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  359. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  360. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  361. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  362. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
  363. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,305 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module is used to define the environment variables for the training job container."""
14
+ from __future__ import absolute_import
15
+
16
+ from typing import Dict, Any
17
+ import multiprocessing
18
+ import subprocess
19
+ import json
20
+ import os
21
+ import sys
22
+ from pathlib import Path
23
+ import logging
24
+
25
+ sys.path.insert(0, str(Path(__file__).parent.parent))
26
+
27
+ from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611
28
+ safe_serialize,
29
+ safe_deserialize,
30
+ read_distributed_json,
31
+ read_source_code_json,
32
+ )
33
+
34
# Initialize logger
# SM_LOG_LEVEL keeps whatever the environment supplied (or the default 20) because it is
# exported verbatim by set_env(); only the value handed to the logger is normalised.
SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20)


def _resolve_log_level(raw_level, default=logging.INFO):
    """Translate ``raw_level`` into a numeric ``logging`` level.

    Args:
        raw_level: An int, a numeric string such as ``"20"``, or a level name
            such as ``"INFO"`` (case-insensitive).
        default (int): Level returned when ``raw_level`` cannot be interpreted.

    Returns:
        int: The numeric logging level.
    """
    try:
        return int(raw_level)
    except (TypeError, ValueError):
        # Not numeric: fall through and try to interpret it as a level name.
        pass
    # getLevelName returns the numeric level for a registered name, and a
    # "Level <name>" string for anything it does not recognise.
    named_level = logging.getLevelName(str(raw_level).strip().upper())
    return named_level if isinstance(named_level, int) else default


logger = logging.getLogger(__name__)
console_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(console_handler)
logger.setLevel(_resolve_log_level(SM_LOG_LEVEL))
40
+
41
# --- Fixed container paths -------------------------------------------------
# Model artifacts directory.
SM_MODEL_DIR = "/opt/ml/model"

# Input tree: root, data channels, and the JSON configuration directory.
SM_INPUT_DIR = "/opt/ml/input"
SM_INPUT_DATA_DIR = "/opt/ml/input/data"
SM_INPUT_CONFIG_DIR = "/opt/ml/input/config"

# Output tree: root, failure file, and output data directory.
SM_OUTPUT_DIR = "/opt/ml/output"
SM_OUTPUT_FAILURE = "/opt/ml/output/failure"
SM_OUTPUT_DATA_DIR = "/opt/ml/output/data"

# Locations of the source code and distributed driver channels.
SM_SOURCE_DIR_PATH = "/opt/ml/input/data/code"
SM_DISTRIBUTED_DRIVER_DIR_PATH = "/opt/ml/input/data/sm_drivers/distributed_drivers"

# --- Master node address and port exported by set_env() --------------------
SM_MASTER_ADDR = "algo-1"
SM_MASTER_PORT = 7777

# --- JSON configuration files read by main() -------------------------------
RESOURCE_CONFIG = f"{SM_INPUT_CONFIG_DIR}/resourceconfig.json"
INPUT_DATA_CONFIG = f"{SM_INPUT_CONFIG_DIR}/inputdataconfig.json"
HYPERPARAMETERS_CONFIG = f"{SM_INPUT_CONFIG_DIR}/hyperparameters.json"

# File that set_env() writes the `export KEY='value'` lines to by default.
ENV_OUTPUT_FILE = "/opt/ml/input/sm_training.env"

# --- Log masking -----------------------------------------------------------
# A key containing any of these substrings (case-insensitive) is masked in logs.
SENSITIVE_KEYWORDS = [
    "SECRET",
    "PASSWORD",
    "KEY",
    "TOKEN",
    "PRIVATE",
    "CREDS",
    "CREDENTIALS",
]
# Placeholder logged instead of a sensitive value.
HIDDEN_VALUE = "******"
64
+
65
+
66
def num_cpus() -> int:
    """Report the CPU count that ``multiprocessing`` sees.

    Returns:
        int: The value of ``multiprocessing.cpu_count()``.
    """
    cpu_total = multiprocessing.cpu_count()
    return cpu_total
73
+
74
+
75
def num_gpus() -> int:
    """Count the GPUs listed by ``nvidia-smi --list-gpus``.

    Returns:
        int: Number of output lines starting with ``"GPU "``, or 0 when the
            command cannot be started or exits with a non-zero status.
    """
    list_gpus_cmd = ["nvidia-smi", "--list-gpus"]
    try:
        listing = subprocess.check_output(list_gpus_cmd).decode("utf-8")
        gpu_entries = [entry for entry in listing.splitlines() if entry.startswith("GPU ")]
        return len(gpu_entries)
    except (OSError, subprocess.CalledProcessError):
        # Missing binary or failing command: treat as "no GPUs".
        logger.info("No GPUs detected (normal if no gpus installed)")
        return 0
88
+
89
+
90
def num_neurons() -> int:
    """Return the number of neuron cores available in the current container.

    Runs ``neuron-ls -j`` and sums the ``nc_count`` field of every entry in the
    JSON it prints. Detection is best-effort: every failure is logged and
    reported as zero cores.

    Returns:
        int: Number of Neuron Cores found, or 0 when ``neuron-ls`` is missing,
            fails, or prints output that is not valid JSON.
    """
    not_detected = "No Neurons detected (normal if no neurons installed)"
    try:
        cmd = ["neuron-ls", "-j"]
        # stderr is merged into stdout so a failure reason is available in e.output.
        output = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8")
        devices = json.loads(output)
        neuron_cores = sum(item.get("nc_count", 0) for item in devices)
        logger.info("Found %s neurons on this instance", neuron_cores)
        return neuron_cores
    except OSError:
        # The neuron-ls binary could not be executed.
        logger.info(not_detected)
        return 0
    except json.JSONDecodeError:
        # Because stderr is merged into stdout, extra text can make the output
        # unparseable; do not let that abort the environment setup.
        logger.warning("Could not parse `neuron-ls -j` output as JSON; assuming no Neurons")
        return 0
    except subprocess.CalledProcessError as e:
        if e.output is not None:
            try:
                msg = e.output.decode("utf-8").partition("error=")[2]
                # Adjacent literals instead of a backslash continuation, so no
                # source indentation leaks into the logged message.
                logger.info(
                    "No Neurons detected (normal if no neurons installed). "
                    "If neuron installed then %s",
                    msg,
                )
            except AttributeError:
                # e.output had no .decode(); fall back to the plain message.
                logger.info(not_detected)
        else:
            logger.info(not_detected)

    return 0
123
+
124
+
125
def deserialize_hyperparameters(hyperparameters: Dict[str, str]) -> Dict[str, Any]:
    """Run every hyperparameter value through ``safe_deserialize``.

    Args:
        hyperparameters (Dict[str, str]): Hyperparameters as strings.

    Returns:
        Dict[str, Any]: A new dict with the same keys and deserialized values.
    """
    return {name: safe_deserialize(raw) for name, raw in hyperparameters.items()}
138
+
139
+
140
def set_env(
    resource_config: Dict[str, Any],
    input_data_config: Dict[str, Any],
    hyperparameters_config: Dict[str, Any],
    output_file: str = ENV_OUTPUT_FILE,
):
    """Set environment variables for the training job container.

    Builds the ``SM_*`` variables, writes them to ``output_file`` as
    ``export KEY='value'`` lines, and logs them with sensitive values masked.

    Args:
        resource_config (Dict[str, Any]): Resource configuration for the training job.
            Must contain ``current_host``, ``current_instance_type``, ``hosts`` and
            ``network_interface_name``.
        input_data_config (Dict[str, Any]): Input data configuration for the training job.
            Each key is treated as a channel name.
        hyperparameters_config (Dict[str, Any]): Hyperparameters configuration for the training job.
        output_file (str): Output file to write the environment variables.

    Raises:
        KeyError: If a required ``resource_config`` key or the ``TRAINING_JOB_NAME``
            environment variable is missing.
        ValueError: If ``current_host`` is not listed in ``hosts``.
    """
    # Constants
    env_vars = {
        "SM_MODEL_DIR": SM_MODEL_DIR,
        "SM_INPUT_DIR": SM_INPUT_DIR,
        "SM_INPUT_DATA_DIR": SM_INPUT_DATA_DIR,
        "SM_INPUT_CONFIG_DIR": SM_INPUT_CONFIG_DIR,
        "SM_OUTPUT_DIR": SM_OUTPUT_DIR,
        "SM_OUTPUT_FAILURE": SM_OUTPUT_FAILURE,
        "SM_OUTPUT_DATA_DIR": SM_OUTPUT_DATA_DIR,
        "SM_LOG_LEVEL": SM_LOG_LEVEL,
        "SM_MASTER_ADDR": SM_MASTER_ADDR,
        "SM_MASTER_PORT": SM_MASTER_PORT,
    }

    # SourceCode and DistributedConfig Environment Variables
    source_code = read_source_code_json()
    if source_code:
        env_vars["SM_SOURCE_DIR"] = SM_SOURCE_DIR_PATH
        env_vars["SM_ENTRY_SCRIPT"] = source_code.get("entry_script", "")

    distributed = read_distributed_json()
    if distributed:
        env_vars["SM_DISTRIBUTED_DRIVER_DIR"] = SM_DISTRIBUTED_DRIVER_DIR_PATH
        env_vars["SM_DISTRIBUTED_CONFIG"] = distributed

    # Data Channels
    channels = list(input_data_config.keys())
    for channel in channels:
        env_vars[f"SM_CHANNEL_{channel.upper()}"] = f"{SM_INPUT_DATA_DIR}/{channel}"
    env_vars["SM_CHANNELS"] = channels

    # Hyperparameters
    hps = deserialize_hyperparameters(hyperparameters_config)
    for key, value in hps.items():
        # Hyphens are replaced so the resulting variable name is a valid identifier.
        key_upper = key.replace("-", "_").upper()
        env_vars[f"SM_HP_{key_upper}"] = value
    env_vars["SM_HPS"] = hps

    # Host Variables
    current_host = resource_config["current_host"]
    current_instance_type = resource_config["current_instance_type"]
    hosts = resource_config["hosts"]
    # Sorting gives every host the same ordering, so the rank below is consistent.
    sorted_hosts = sorted(hosts)

    env_vars["SM_CURRENT_HOST"] = current_host
    env_vars["SM_CURRENT_INSTANCE_TYPE"] = current_instance_type
    env_vars["SM_HOSTS"] = sorted_hosts
    env_vars["SM_NETWORK_INTERFACE_NAME"] = resource_config["network_interface_name"]
    env_vars["SM_HOST_COUNT"] = len(sorted_hosts)
    env_vars["SM_CURRENT_HOST_RANK"] = sorted_hosts.index(current_host)

    env_vars["SM_NUM_CPUS"] = num_cpus()
    env_vars["SM_NUM_GPUS"] = num_gpus()
    env_vars["SM_NUM_NEURONS"] = num_neurons()

    # Misc.
    env_vars["SM_RESOURCE_CONFIG"] = resource_config
    env_vars["SM_INPUT_DATA_CONFIG"] = input_data_config

    # All Training Environment Variables
    env_vars["SM_TRAINING_ENV"] = {
        "channel_input_dirs": {
            channel: env_vars[f"SM_CHANNEL_{channel.upper()}"] for channel in channels
        },
        "current_host": env_vars["SM_CURRENT_HOST"],
        "current_instance_type": env_vars["SM_CURRENT_INSTANCE_TYPE"],
        "hosts": env_vars["SM_HOSTS"],
        "master_addr": env_vars["SM_MASTER_ADDR"],
        "master_port": env_vars["SM_MASTER_PORT"],
        "hyperparameters": env_vars["SM_HPS"],
        "input_data_config": input_data_config,
        "input_config_dir": env_vars["SM_INPUT_CONFIG_DIR"],
        "input_data_dir": env_vars["SM_INPUT_DATA_DIR"],
        "input_dir": env_vars["SM_INPUT_DIR"],
        "job_name": os.environ["TRAINING_JOB_NAME"],
        "log_level": env_vars["SM_LOG_LEVEL"],
        "model_dir": env_vars["SM_MODEL_DIR"],
        "network_interface_name": env_vars["SM_NETWORK_INTERFACE_NAME"],
        "num_cpus": env_vars["SM_NUM_CPUS"],
        "num_gpus": env_vars["SM_NUM_GPUS"],
        "num_neurons": env_vars["SM_NUM_NEURONS"],
        "output_data_dir": env_vars["SM_OUTPUT_DATA_DIR"],
        "resource_config": env_vars["SM_RESOURCE_CONFIG"],
    }
    with open(output_file, "w") as f:
        for key, value in env_vars.items():
            serialized = str(safe_serialize(value))
            # A literal ' would end the single-quoted shell string early. Replace it
            # with: close quote, a double-quoted ', reopen quote. Values without a
            # single quote are written exactly as before.
            escaped = serialized.replace("'", "'\"'\"'")
            f.write(f"export {key}='{escaped}'\n")

    logger.info("Environment Variables:")
    log_env_variables(env_vars_dict=env_vars)
244
+
245
+
246
def mask_sensitive_info(data):
    """Return a copy of ``data`` with sensitive string values masked.

    A string value is replaced by ``HIDDEN_VALUE`` when its key contains any of
    ``SENSITIVE_KEYWORDS`` (case-insensitive). Nested dictionaries are processed
    recursively. The input is never modified, so masking a dict for logging
    cannot overwrite secrets in the configuration the caller still uses.

    Args:
        data: The object to mask. Anything that is not a dict is returned unchanged.

    Returns:
        A new dict with masked values when ``data`` is a dict, otherwise ``data`` itself.
    """
    if not isinstance(data, dict):
        return data

    masked = {}
    for k, v in data.items():
        if isinstance(v, dict):
            masked[k] = mask_sensitive_info(v)
        elif isinstance(v, str) and any(
            keyword.lower() in k.lower() for keyword in SENSITIVE_KEYWORDS
        ):
            masked[k] = HIDDEN_VALUE
        else:
            # Non-string values, and strings under non-sensitive keys, pass through.
            masked[k] = v
    return masked
257
+
258
+
259
def log_key_value(key: str, value: str):
    """Log ``key=value`` at INFO level, hiding anything that looks sensitive.

    The whole value is replaced by ``HIDDEN_VALUE`` when the key contains a
    sensitive keyword. Dict values, and strings that decode to a JSON object,
    are passed through ``mask_sensitive_info`` and logged as JSON.
    """
    if any(keyword.lower() in key.lower() for keyword in SENSITIVE_KEYWORDS):
        logger.info("%s=%s", key, HIDDEN_VALUE)
        return

    if isinstance(value, dict):
        logger.info("%s=%s", key, json.dumps(mask_sensitive_info(value)))
        return

    try:
        parsed = json.loads(value)
        if not isinstance(parsed, dict):
            logger.info("%s=%s", key, parsed)
        else:
            redacted = mask_sensitive_info(parsed)
            logger.info("%s=%s", key, json.dumps(redacted))
    except (json.JSONDecodeError, TypeError):
        # Not JSON text (or not a string at all): log the raw value.
        logger.info("%s=%s", key, value)
276
+
277
+
278
def log_env_variables(env_vars_dict: Dict[str, Any]):
    """Log the process environment first, then every entry of ``env_vars_dict``."""
    for variables in (os.environ, env_vars_dict):
        for name, current in variables.items():
            log_key_value(name, current)
285
+
286
+
287
def main():
    """Load the three JSON config files and write the training environment file."""
    config_paths = (RESOURCE_CONFIG, INPUT_DATA_CONFIG, HYPERPARAMETERS_CONFIG)
    loaded = []
    # Files are read in the same order as before: resource, input data, hyperparameters.
    for config_path in config_paths:
        with open(config_path, "r") as config_file:
            loaded.append(json.load(config_file))
    resource_config, input_data_config, hyperparameters_config = loaded

    set_env(
        resource_config=resource_config,
        input_data_config=input_data_config,
        hyperparameters_config=hyperparameters_config,
        output_file=ENV_OUTPUT_FILE,
    )


if __name__ == "__main__":
    main()
File without changes
@@ -0,0 +1,330 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Utility functions for SageMaker training recipes."""
14
+ from __future__ import absolute_import
15
+
16
+ import math
17
+ import os
18
+ import json
19
+ import shutil
20
+ import tempfile
21
+ from urllib.request import urlretrieve
22
+ from typing import Dict, Any, Optional, Tuple
23
+
24
+ import omegaconf
25
+ from omegaconf import OmegaConf, dictconfig
26
+
27
+ from sagemaker.core.image_uris import retrieve
28
+
29
+ from sagemaker.core.modules import logger
30
+ from sagemaker.core.modules.utils import _run_clone_command_silent
31
+ from sagemaker.core.modules.configs import Compute, SourceCode
32
+ from sagemaker.core.modules.distributed import Torchrun, SMP
33
+
34
+
35
def _try_resolve_recipe(recipe, key=None):
    """Resolve the recipe's interpolations, returning None if resolution fails.

    When ``key`` is given, the recipe is first nested under that key in a new
    ``DictConfig`` before resolving, and the entry stored under ``key`` is
    returned afterwards. Without a key the recipe itself is resolved and
    returned.
    """
    wrapped = key is not None
    if wrapped:
        recipe = dictconfig.DictConfig({key: recipe})

    try:
        OmegaConf.resolve(recipe)
    except omegaconf.errors.OmegaConfBaseException:
        return None

    return recipe[key] if wrapped else recipe
46
+
47
+
48
+ def _determine_device_type(instance_type: str) -> str:
49
+ """Determine device type (gpu, cpu, trainium) based on instance type."""
50
+ instance_family = instance_type.split(".")[1]
51
+ if instance_family.startswith(("p", "g")):
52
+ return "gpu"
53
+ if instance_family.startswith("trn"):
54
+ return "trainium"
55
+ return "cpu"
56
+
57
+
58
+ def _load_recipes_cfg() -> str:
59
+ """Load training recipes configuration json."""
60
+ training_recipes_cfg_filename = os.path.join(os.path.dirname(__file__), "training_recipes.json")
61
+ with open(training_recipes_cfg_filename) as training_recipes_cfg_file:
62
+ training_recipes_cfg = json.load(training_recipes_cfg_file)
63
+ return training_recipes_cfg
64
+
65
+
66
def _load_base_recipe(
    training_recipe: str,
    recipe_overrides: Optional[Dict[str, Any]] = None,
    training_recipes_cfg: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Load recipe and apply overrides.

    Args:
        training_recipe (str):
            Either a value ending in ``.yaml`` (a local file if it exists,
            otherwise something ``urlretrieve`` can fetch), or the name of a
            recipe looked up as ``recipes_collection/recipes/<name>.yaml`` in
            the launcher repository.
        recipe_overrides (Optional[Dict[str, Any]]):
            Values merged on top of the loaded recipe.
        training_recipes_cfg (Optional[Dict[str, Any]]):
            Recipes configuration; its ``launcher_repo`` entry is used when the
            ``TRAINING_LAUNCHER_GIT`` environment variable is not set.

    Returns:
        The OmegaConf configuration produced by merging the loaded recipe with
        ``recipe_overrides``.

    Raises:
        ValueError: If the recipe cannot be fetched or found, or if no
            launcher repository is configured for a named recipe.
    """
    if recipe_overrides is None:
        recipe_overrides = dict()

    # mkstemp creates the file atomically and keeps it until we remove it,
    # instead of reusing the name of an already-deleted temporary file.
    temp_fd, temp_local_recipe = tempfile.mkstemp(prefix="recipe_original", suffix=".yaml")
    os.close(temp_fd)

    try:
        if training_recipe.endswith(".yaml"):
            if os.path.isfile(training_recipe):
                shutil.copy(training_recipe, temp_local_recipe)
            else:
                try:
                    urlretrieve(training_recipe, temp_local_recipe)
                except Exception as e:
                    raise ValueError(
                        f"Could not fetch the provided recipe {training_recipe}: exception {str(e)}"
                    ) from e
        else:
            launcher_repo = os.environ.get("TRAINING_LAUNCHER_GIT", None) or (
                training_recipes_cfg or {}
            ).get("launcher_repo")
            if not launcher_repo:
                raise ValueError(
                    "No launcher repository configured: set TRAINING_LAUNCHER_GIT or "
                    "provide launcher_repo in the training recipes configuration."
                )

            # The clone is only needed long enough to copy the recipe out of it.
            with tempfile.TemporaryDirectory(prefix="launcher_") as recipe_launcher_dir:
                _run_clone_command_silent(launcher_repo, recipe_launcher_dir)

                recipe_path = os.path.join(
                    recipe_launcher_dir,
                    "recipes_collection",
                    "recipes",
                    training_recipe + ".yaml",
                )
                if os.path.isfile(recipe_path):
                    shutil.copy(recipe_path, temp_local_recipe)
                else:
                    raise ValueError(f"Recipe {training_recipe} not found.")

        recipe = OmegaConf.load(temp_local_recipe)
    finally:
        # Remove the temporary copy on success and on every failure path.
        if os.path.exists(temp_local_recipe):
            os.unlink(temp_local_recipe)

    return OmegaConf.merge(recipe, recipe_overrides)
110
+
111
+
112
def _register_custom_resolvers():
    """Register the arithmetic OmegaConf resolvers that are not registered yet.

    Covers ``multiply``, ``divide_ceil``, ``divide_floor`` and ``add``; a
    resolver that OmegaConf already knows about is left untouched.
    """
    resolver_specs = (
        ("multiply", lambda x, y: x * y, {"replace": True}),
        ("divide_ceil", lambda x, y: int(math.ceil(x / y)), {"replace": True}),
        ("divide_floor", lambda x, y: int(math.floor(x / y)), {"replace": True}),
        ("add", lambda *numbers: sum(numbers), {}),
    )

    for resolver_name, resolver_fn, extra_kwargs in resolver_specs:
        if OmegaConf.has_resolver(resolver_name):
            continue
        OmegaConf.register_new_resolver(resolver_name, resolver_fn, **extra_kwargs)
126
+
127
+
128
+ def _get_trainining_recipe_gpu_model_name_and_script(model_type: str):
129
+ """Get the model base name and script for the training recipe."""
130
+
131
+ model_type_to_script = {
132
+ "llama_v3": ("llama", "llama_pretrain.py"),
133
+ "mistral": ("mistral", "mistral_pretrain.py"),
134
+ "mixtral": ("mixtral", "mixtral_pretrain.py"),
135
+ "deepseek": ("deepseek", "deepseek_pretrain.py"),
136
+ }
137
+
138
+ for key in model_type_to_script:
139
+ if model_type.startswith(key):
140
+ model_type = key
141
+ break
142
+
143
+ if model_type not in model_type_to_script:
144
+ raise ValueError(f"Model type {model_type} not supported")
145
+
146
+ return model_type_to_script[model_type][0], model_type_to_script[model_type][1]
147
+
148
+
149
def _configure_gpu_args(
    training_recipes_cfg: Dict[str, Any],
    region_name: str,
    recipe: OmegaConf,
    recipe_train_dir: tempfile.TemporaryDirectory,
) -> Dict[str, Any]:
    """Build the GPU-specific source code, training image and distributed args.

    Clones the adapter repository (``TRAINING_ADAPTER_GIT`` or the configured
    ``adapter_repo``) into ``recipe_train_dir``, selects the example directory
    and entry script from the recipe's ``model.model_type``, and takes the
    training image from the ``gpu_image`` configuration: used directly when it
    is a string, otherwise looked up with ``retrieve``.

    Raises:
        ValueError: If the recipe lacks ``model`` or ``model.model_type``, or
            the model type is not supported.
    """
    source_code = SourceCode()

    repo_url = os.environ.get("TRAINING_ADAPTER_GIT", None) or training_recipes_cfg.get(
        "adapter_repo"
    )
    _run_clone_command_silent(repo_url, recipe_train_dir.name)

    if "model" not in recipe:
        raise ValueError("Supplied recipe does not contain required field model.")
    model_section = recipe["model"]
    if "model_type" not in model_section:
        raise ValueError("Supplied recipe does not contain required field model_type.")

    model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(
        model_section["model_type"]
    )
    source_code.source_dir = os.path.join(recipe_train_dir.name, "examples", model_base_name)
    source_code.entry_script = script

    image_cfg = training_recipes_cfg.get("gpu_image")
    if not isinstance(image_cfg, str):
        training_image = retrieve(
            image_cfg.get("framework"),
            region=region_name,
            version=image_cfg.get("version"),
            image_scope="training",
            **image_cfg.get("additional_args"),
        )
    else:
        training_image = image_cfg

    # Setting dummy parameters for now
    return {
        "source_code": source_code,
        "training_image": training_image,
        "distributed": Torchrun(smp=SMP(random_seed="123456")),
    }
197
+
198
+
199
def _configure_trainium_args(
    training_recipes_cfg: Dict[str, Any],
    region_name: str,
    recipe_train_dir: tempfile.TemporaryDirectory,
) -> Dict[str, Any]:
    """Build the Trainium-specific source code, training image and distributed args.

    Clones the configured ``neuron_dist_repo`` into ``recipe_train_dir``, uses
    its ``examples`` directory with ``training_orchestrator.py`` as the entry
    script, and takes the training image from the ``neuron_image``
    configuration: used directly when it is a string, otherwise looked up
    with ``retrieve``.
    """
    source_code = SourceCode()

    _run_clone_command_silent(training_recipes_cfg.get("neuron_dist_repo"), recipe_train_dir.name)

    source_code.source_dir = os.path.join(recipe_train_dir.name, "examples")
    source_code.entry_script = "training_orchestrator.py"

    image_cfg = training_recipes_cfg.get("neuron_image")
    if not isinstance(image_cfg, str):
        training_image = retrieve(
            image_cfg.get("framework"),
            region=region_name,
            version=image_cfg.get("version"),
            image_scope="training",
            **image_cfg.get("additional_args"),
        )
    else:
        training_image = image_cfg

    return {
        "source_code": source_code,
        "training_image": training_image,
        "distributed": Torchrun(),
    }
232
+
233
+
234
def _get_args_from_recipe(
    training_recipe: str,
    compute: Compute,
    region_name: str,
    recipe_overrides: Optional[Dict[str, Any]],
    requirements: Optional[str],
) -> Tuple[Dict[str, Any], tempfile.TemporaryDirectory]:
    """Build ModelTrainer arguments from a training recipe.

    The returned dictionary is meant to be used with ModelTrainer and looks like:
    ```python
        {
            "source_code": SourceCode,
            "training_image": str,
            "distributed": DistributedConfig,
            "compute": Compute,
            "hyperparameters": Dict[str, Any],
        }
    ```

    Args:
        training_recipe (str):
            Name of the training recipe or path to the recipe file.
        compute (Compute):
            Compute configuration for training. ``instance_count`` is filled in
            from the recipe's ``trainer -> num_nodes`` when it is not set.
        region_name (str):
            Name of the AWS region.
        recipe_overrides (Optional[Dict[str, Any]]):
            Overrides for the training recipe.
        requirements (Optional[str]):
            Path to the requirements file.

    Returns:
        Tuple of the arguments dictionary and the temporary directory holding
        the cloned training sources.

    Raises:
        ValueError: If required compute or recipe fields are missing, the
            requirements file does not exist, or the device type is unsupported.
        RuntimeError: If the recipe cannot be resolved.
    """
    if compute.instance_type is None:
        raise ValueError("Must set `instance_type` in compute when using training recipes.")

    training_recipes_cfg = _load_recipes_cfg()
    recipe = _load_base_recipe(training_recipe, recipe_overrides, training_recipes_cfg)

    if "trainer" not in recipe:
        raise ValueError("Supplied recipe does not contain required field trainer.")
    trainer_cfg = recipe["trainer"]

    # Set instance_count: an explicit Compute value wins over the recipe.
    if compute.instance_count is None:
        if "num_nodes" not in trainer_cfg:
            raise ValueError(
                "Must provide Compute with instance_count or set trainer -> num_nodes in recipe."
            )
        compute.instance_count = trainer_cfg["num_nodes"]
    elif compute.instance_count and "num_nodes" in trainer_cfg:
        logger.warning(
            f"Using Compute to set instance_count:\n{compute}."
            "\nIgnoring trainer -> num_nodes in recipe."
        )

    if requirements and not os.path.isfile(requirements):
        raise ValueError(f"Recipe requirements file {requirements} not found.")

    # Get Training Image, SourceCode, and distributed args
    device_type = _determine_device_type(compute.instance_type)
    recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_")
    if device_type == "gpu":
        args = _configure_gpu_args(training_recipes_cfg, region_name, recipe, recipe_train_dir)
    elif device_type == "trainium":
        args = _configure_trainium_args(training_recipes_cfg, region_name, recipe_train_dir)
    else:
        raise ValueError(f"Devices of type {device_type} are not supported with training recipes.")

    _register_custom_resolvers()

    # Resolve Final Recipe: try it as-is first, then nested under each
    # fallback key, stopping at the first attempt that resolves.
    final_recipe = None
    for wrapper_key in (None, "recipes", "training"):
        final_recipe = _try_resolve_recipe(recipe, wrapper_key)
        if final_recipe is not None:
            break
    if final_recipe is None:
        raise RuntimeError("Could not resolve provided recipe.")

    # Save Final Recipe to source_dir
    source_code = args["source_code"]
    OmegaConf.save(config=final_recipe, f=os.path.join(source_code.source_dir, "recipe.yaml"))

    # If recipe_requirements is provided, copy it to source_dir
    if requirements:
        shutil.copy(requirements, source_code.source_dir)
        source_code.requirements = os.path.basename(requirements)

    # Update args with compute and hyperparameters
    args["compute"] = compute
    args["hyperparameters"] = {"config-path": ".", "config-name": "recipe.yaml"}

    return args, recipe_train_dir
@@ -0,0 +1,19 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Types module."""
14
+ from __future__ import absolute_import
15
+
16
+ from typing import Union
17
+ from sagemaker.core.modules.configs import S3DataSource, FileSystemDataSource
18
+
19
# Type alias: a data source may be given as a string, an S3DataSource, or a
# FileSystemDataSource.
DataSourceType = Union[str, S3DataSource, FileSystemDataSource]