sagemaker-core 1.0.62__py3-none-any.whl → 2.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (358) hide show
  1. sagemaker/__init__.py +2 -0
  2. sagemaker/core/__init__.py +16 -0
  3. sagemaker/core/_studio.py +116 -0
  4. sagemaker/core/_version.py +11 -0
  5. sagemaker/core/accept_types.py +131 -0
  6. sagemaker/core/analytics.py +744 -0
  7. sagemaker/core/apiutils/__init__.py +13 -0
  8. sagemaker/core/apiutils/_base_types.py +228 -0
  9. sagemaker/core/apiutils/_boto_functions.py +130 -0
  10. sagemaker/core/apiutils/_utils.py +34 -0
  11. sagemaker/core/base_deserializers.py +35 -0
  12. sagemaker/core/base_serializers.py +35 -0
  13. sagemaker/core/clarify/__init__.py +2898 -0
  14. sagemaker/core/collection.py +467 -0
  15. sagemaker/core/common_utils.py +2399 -0
  16. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  17. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  18. sagemaker/core/config/__init__.py +181 -0
  19. sagemaker/core/config/config.py +238 -0
  20. sagemaker/core/config/config_manager.py +595 -0
  21. sagemaker/core/config/config_schema.py +1220 -0
  22. sagemaker/core/config/config_utils.py +297 -0
  23. {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
  24. sagemaker/core/constants.py +73 -0
  25. sagemaker/core/content_types.py +137 -0
  26. sagemaker/core/debugger/__init__.py +39 -0
  27. sagemaker/core/debugger/debugger.py +945 -0
  28. sagemaker/core/debugger/framework_profile.py +292 -0
  29. sagemaker/core/debugger/metrics_config.py +468 -0
  30. sagemaker/core/debugger/profiler.py +42 -0
  31. sagemaker/core/debugger/profiler_config.py +190 -0
  32. sagemaker/core/debugger/profiler_constants.py +40 -0
  33. sagemaker/core/debugger/utils.py +148 -0
  34. sagemaker/core/deprecations.py +254 -0
  35. sagemaker/core/deserializers/__init__.py +10 -0
  36. sagemaker/core/deserializers/base.py +424 -0
  37. sagemaker/core/deserializers/implementations.py +157 -0
  38. sagemaker/core/drift_check_baselines.py +106 -0
  39. sagemaker/core/enums.py +51 -0
  40. sagemaker/core/environment_variables.py +101 -0
  41. sagemaker/core/exceptions.py +108 -0
  42. sagemaker/core/experiments/__init__.py +53 -0
  43. sagemaker/core/experiments/_api_types.py +251 -0
  44. sagemaker/core/experiments/_environment.py +124 -0
  45. sagemaker/core/experiments/_helper.py +294 -0
  46. sagemaker/core/experiments/_metrics.py +333 -0
  47. sagemaker/core/experiments/_run_context.py +58 -0
  48. sagemaker/core/experiments/_utils.py +216 -0
  49. sagemaker/core/experiments/experiment.py +247 -0
  50. sagemaker/core/experiments/run.py +970 -0
  51. sagemaker/core/experiments/trial.py +296 -0
  52. sagemaker/core/experiments/trial_component.py +387 -0
  53. sagemaker/core/explainer/__init__.py +24 -0
  54. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  55. sagemaker/core/explainer/explainer_config.py +44 -0
  56. sagemaker/core/fw_utils.py +1220 -0
  57. sagemaker/core/git_utils.py +415 -0
  58. sagemaker/core/helper/pipeline_variable.py +82 -0
  59. sagemaker/core/helper/session_helper.py +2977 -0
  60. sagemaker/core/hyperparameters.py +172 -0
  61. sagemaker/core/image_retriever/__init__.py +3 -0
  62. sagemaker/core/image_retriever/image_retriever.py +640 -0
  63. sagemaker/core/image_retriever/image_retriever_utils.py +509 -0
  64. sagemaker/core/image_retriever/test.py +7 -0
  65. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  66. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  67. sagemaker/core/image_uri_config/chainer.json +104 -0
  68. sagemaker/core/image_uri_config/clarify.json +39 -0
  69. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  70. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  71. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  72. sagemaker/core/image_uri_config/debugger.json +34 -0
  73. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  74. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  75. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  76. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  77. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  78. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  79. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  80. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  81. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +770 -0
  82. sagemaker/core/image_uri_config/huggingface-llm.json +1267 -0
  83. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  84. sagemaker/core/image_uri_config/huggingface-neuronx.json +686 -0
  85. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  86. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  87. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  88. sagemaker/core/image_uri_config/huggingface-vllm-neuronx.json +38 -0
  89. sagemaker/core/image_uri_config/huggingface.json +2287 -0
  90. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  91. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  92. sagemaker/core/image_uri_config/image-classification.json +50 -0
  93. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  94. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  95. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  96. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  97. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  98. sagemaker/core/image_uri_config/kmeans.json +50 -0
  99. sagemaker/core/image_uri_config/knn.json +50 -0
  100. sagemaker/core/image_uri_config/lda.json +26 -0
  101. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  102. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  103. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  104. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  105. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  106. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  107. sagemaker/core/image_uri_config/ntm.json +50 -0
  108. sagemaker/core/image_uri_config/object-detection.json +50 -0
  109. sagemaker/core/image_uri_config/object2vec.json +50 -0
  110. sagemaker/core/image_uri_config/pca.json +50 -0
  111. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  112. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  113. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  114. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  115. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  116. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  117. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  118. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  119. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  120. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  121. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +252 -0
  122. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  123. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  124. sagemaker/core/image_uri_config/sklearn.json +494 -0
  125. sagemaker/core/image_uri_config/spark.json +280 -0
  126. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  127. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  128. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  129. sagemaker/core/image_uri_config/vw.json +25 -0
  130. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  131. sagemaker/core/image_uri_config/xgboost.json +972 -0
  132. sagemaker/core/image_uris.py +816 -0
  133. sagemaker/core/inference_config.py +144 -0
  134. sagemaker/core/inference_recommender/__init__.py +18 -0
  135. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  136. sagemaker/core/inputs.py +366 -0
  137. sagemaker/core/instance_group.py +61 -0
  138. sagemaker/core/instance_types.py +164 -0
  139. sagemaker/core/instance_types_gpu_info.py +43 -0
  140. sagemaker/core/interactive_apps/__init__.py +41 -0
  141. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  142. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  143. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  144. sagemaker/core/iterators.py +197 -0
  145. sagemaker/core/job.py +380 -0
  146. sagemaker/core/jumpstart/__init__.py +156 -0
  147. sagemaker/core/jumpstart/accessors.py +390 -0
  148. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  149. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  150. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  151. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  152. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  153. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  154. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  155. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  156. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  157. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  158. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  159. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  160. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  161. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  162. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  163. sagemaker/core/jumpstart/cache.py +663 -0
  164. sagemaker/core/jumpstart/configs.py +50 -0
  165. sagemaker/core/jumpstart/constants.py +198 -0
  166. sagemaker/core/jumpstart/deserializers.py +81 -0
  167. sagemaker/core/jumpstart/document.py +76 -0
  168. sagemaker/core/jumpstart/enums.py +168 -0
  169. sagemaker/core/jumpstart/exceptions.py +236 -0
  170. sagemaker/core/jumpstart/factory/utils.py +833 -0
  171. sagemaker/core/jumpstart/filters.py +597 -0
  172. sagemaker/core/jumpstart/hub/constants.py +16 -0
  173. sagemaker/core/jumpstart/hub/hub.py +291 -0
  174. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  175. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  176. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  177. sagemaker/core/jumpstart/hub/types.py +35 -0
  178. sagemaker/core/jumpstart/hub/utils.py +260 -0
  179. sagemaker/core/jumpstart/models.py +501 -0
  180. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  181. sagemaker/core/jumpstart/parameters.py +20 -0
  182. sagemaker/core/jumpstart/payload_utils.py +239 -0
  183. sagemaker/core/jumpstart/region_config.json +171 -0
  184. sagemaker/core/jumpstart/search.py +171 -0
  185. sagemaker/core/jumpstart/serializers.py +81 -0
  186. sagemaker/core/jumpstart/session_utils.py +234 -0
  187. sagemaker/core/jumpstart/types.py +3044 -0
  188. sagemaker/core/jumpstart/utils.py +1731 -0
  189. sagemaker/core/jumpstart/validators.py +257 -0
  190. sagemaker/core/lambda_helper.py +312 -0
  191. sagemaker/core/lineage/__init__.py +42 -0
  192. sagemaker/core/lineage/_api_types.py +239 -0
  193. sagemaker/core/lineage/_utils.py +49 -0
  194. sagemaker/core/lineage/action.py +345 -0
  195. sagemaker/core/lineage/artifact.py +646 -0
  196. sagemaker/core/lineage/association.py +190 -0
  197. sagemaker/core/lineage/context.py +505 -0
  198. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  199. sagemaker/core/lineage/query.py +732 -0
  200. sagemaker/core/lineage/visualizer.py +346 -0
  201. sagemaker/core/local/__init__.py +18 -0
  202. sagemaker/core/local/data.py +423 -0
  203. sagemaker/core/local/entities.py +678 -0
  204. sagemaker/core/local/exceptions.py +17 -0
  205. sagemaker/core/local/image.py +1243 -0
  206. sagemaker/core/local/local_session.py +739 -0
  207. sagemaker/core/local/utils.py +246 -0
  208. sagemaker/core/logs.py +181 -0
  209. sagemaker/core/metadata_properties.py +56 -0
  210. sagemaker/core/metric_definitions.py +91 -0
  211. sagemaker/core/mlflow/__init__.py +38 -0
  212. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  213. sagemaker/core/model_card/__init__.py +26 -0
  214. sagemaker/core/model_life_cycle.py +51 -0
  215. sagemaker/core/model_metrics.py +160 -0
  216. sagemaker/core/model_monitor/__init__.py +66 -0
  217. sagemaker/core/model_monitor/clarify_model_monitoring.py +1497 -0
  218. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  219. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  220. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  221. sagemaker/core/model_monitor/dataset_format.py +102 -0
  222. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  223. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  224. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  225. sagemaker/core/model_monitor/utils.py +793 -0
  226. sagemaker/core/model_registry.py +480 -0
  227. sagemaker/core/model_uris.py +97 -0
  228. sagemaker/core/modules/__init__.py +19 -0
  229. sagemaker/core/modules/configs.py +239 -0
  230. sagemaker/core/modules/constants.py +37 -0
  231. sagemaker/core/modules/distributed.py +182 -0
  232. sagemaker/core/modules/local_core/local_container.py +605 -0
  233. sagemaker/core/modules/templates.py +83 -0
  234. sagemaker/core/modules/train/__init__.py +14 -0
  235. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  236. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  237. sagemaker/core/modules/train/container_drivers/common/utils.py +205 -0
  238. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  239. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  240. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  241. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  242. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  243. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  244. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  245. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  246. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  247. sagemaker/core/modules/types.py +19 -0
  248. sagemaker/core/modules/utils.py +194 -0
  249. sagemaker/core/network.py +185 -0
  250. sagemaker/core/parameter.py +173 -0
  251. sagemaker/core/payloads.py +185 -0
  252. sagemaker/core/processing.py +1599 -0
  253. sagemaker/core/remote_function/__init__.py +19 -0
  254. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  255. sagemaker/core/remote_function/client.py +1310 -0
  256. sagemaker/core/remote_function/core/__init__.py +0 -0
  257. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  258. sagemaker/core/remote_function/core/pipeline_variables.py +347 -0
  259. sagemaker/core/remote_function/core/serialization.py +410 -0
  260. sagemaker/core/remote_function/core/stored_function.py +223 -0
  261. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  262. sagemaker/core/remote_function/errors.py +102 -0
  263. sagemaker/core/remote_function/invoke_function.py +167 -0
  264. sagemaker/core/remote_function/job.py +2121 -0
  265. sagemaker/core/remote_function/logging_config.py +38 -0
  266. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  267. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  268. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  269. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  270. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  271. sagemaker/core/remote_function/spark_config.py +149 -0
  272. sagemaker/core/resource_requirements.py +168 -0
  273. {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
  274. sagemaker/core/s3/__init__.py +41 -0
  275. sagemaker/core/s3/client.py +367 -0
  276. sagemaker/core/s3/utils.py +175 -0
  277. sagemaker/core/script_uris.py +93 -0
  278. sagemaker/core/serializers/__init__.py +11 -0
  279. sagemaker/core/serializers/base.py +510 -0
  280. sagemaker/core/serializers/implementations.py +159 -0
  281. sagemaker/core/serializers/utils.py +223 -0
  282. sagemaker/core/serverless_inference_config.py +63 -0
  283. sagemaker/core/session_settings.py +55 -0
  284. sagemaker/core/shapes/__init__.py +3 -0
  285. sagemaker/core/shapes/model_card_shapes.py +159 -0
  286. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
  287. sagemaker/core/spark/__init__.py +16 -0
  288. sagemaker/core/spark/defaults.py +16 -0
  289. sagemaker/core/spark/processing.py +1380 -0
  290. sagemaker/core/telemetry/__init__.py +23 -0
  291. sagemaker/core/telemetry/constants.py +82 -0
  292. sagemaker/core/telemetry/telemetry_logging.py +285 -0
  293. sagemaker/core/tools/__init__.py +1 -0
  294. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  295. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  296. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  297. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  298. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  299. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  300. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  301. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  302. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  303. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  304. sagemaker/core/training/__init__.py +14 -0
  305. sagemaker/core/training/configs.py +345 -0
  306. sagemaker/core/training/constants.py +37 -0
  307. sagemaker/core/training/utils.py +77 -0
  308. sagemaker/core/training_compiler/__init__.py +16 -0
  309. sagemaker/core/training_compiler/config.py +197 -0
  310. sagemaker/core/training_compiler_config.py +197 -0
  311. sagemaker/core/transformer.py +793 -0
  312. sagemaker/core/user_agent.py +76 -0
  313. sagemaker/core/utilities/__init__.py +24 -0
  314. sagemaker/core/utilities/cache.py +169 -0
  315. sagemaker/core/utilities/search_expression.py +133 -0
  316. sagemaker/core/utils/__init__.py +48 -0
  317. sagemaker/core/utils/code_injection/__init__.py +0 -0
  318. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  319. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
  320. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  321. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  322. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  323. {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
  324. sagemaker/core/workflow/__init__.py +152 -0
  325. sagemaker/core/workflow/conditions.py +313 -0
  326. sagemaker/core/workflow/entities.py +58 -0
  327. sagemaker/core/workflow/execution_variables.py +89 -0
  328. sagemaker/core/workflow/functions.py +193 -0
  329. sagemaker/core/workflow/parameters.py +222 -0
  330. sagemaker/core/workflow/pipeline_context.py +394 -0
  331. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  332. sagemaker/core/workflow/properties.py +285 -0
  333. sagemaker/core/workflow/step_outputs.py +65 -0
  334. sagemaker/core/workflow/utilities.py +514 -0
  335. sagemaker/lineage/__init__.py +33 -0
  336. sagemaker/lineage/action.py +28 -0
  337. sagemaker/lineage/artifact.py +28 -0
  338. sagemaker/lineage/context.py +28 -0
  339. sagemaker/lineage/lineage_trial_component.py +28 -0
  340. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/METADATA +28 -9
  341. sagemaker_core-2.3.1.dist-info/RECORD +351 -0
  342. sagemaker_core-2.3.1.dist-info/top_level.txt +1 -0
  343. sagemaker_core/_version.py +0 -3
  344. sagemaker_core/helper/session_helper.py +0 -769
  345. sagemaker_core/resources/__init__.py +0 -1
  346. sagemaker_core/shapes/__init__.py +0 -1
  347. sagemaker_core/tools/__init__.py +0 -1
  348. sagemaker_core-1.0.62.dist-info/RECORD +0 -35
  349. sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
  350. {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
  351. {sagemaker_core/helper → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  352. {sagemaker_core/main → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
  353. {sagemaker_core/main/code_injection → sagemaker/core/modules/local_core}/__init__.py +0 -0
  354. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  355. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  356. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  357. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/WHEEL +0 -0
  358. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,38 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Utilities related to logging."""
14
+ from __future__ import absolute_import
15
+
16
+ import logging
17
+ import time
18
+
19
+
20
class _UTCFormatter(logging.Formatter):
    """Log formatter that renders record timestamps in UTC instead of local time."""

    # logging.Formatter.formatTime() calls ``converter`` to turn the record's
    # creation time into a struct_time; gmtime makes that conversion UTC-based.
    converter = time.gmtime


def get_logger():
    """Return the logger named 'sagemaker.remote_function'.

    The first call configures the logger: level INFO, one stream handler with
    a UTC timestamp format, and propagation to ancestor loggers disabled.
    Later calls find the handler already attached and return the same logger
    without touching its configuration.

    Returns:
        logging.Logger: The 'sagemaker.remote_function' logger.
    """
    sagemaker_logger = logging.getLogger("sagemaker.remote_function")
    if not sagemaker_logger.handlers:
        sagemaker_logger.setLevel(logging.INFO)
        handler = logging.StreamHandler()
        formatter = _UTCFormatter("%(asctime)s %(name)s %(levelname)-8s %(message)s")
        handler.setFormatter(formatter)
        sagemaker_logger.addHandler(handler)
        # don't stream logs with the root logger handler
        sagemaker_logger.propagate = False

    return sagemaker_logger
@@ -0,0 +1,14 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Sagemaker modules container_drivers directory."""
14
+ from __future__ import absolute_import
@@ -0,0 +1,605 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """An entry point for runtime environment. This must be kept independent of SageMaker PySDK"""
14
+ from __future__ import absolute_import
15
+
16
+ import argparse
17
+ import getpass
18
+ import json
19
+ import multiprocessing
20
+ import os
21
+ import pathlib
22
+ import shutil
23
+ import subprocess
24
+ import sys
25
+ from typing import Any, Dict
26
+
27
+ if __package__ is None or __package__ == "":
28
+ from runtime_environment_manager import (
29
+ RuntimeEnvironmentManager,
30
+ _DependencySettings,
31
+ get_logger,
32
+ )
33
+ else:
34
+ from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import (
35
+ RuntimeEnvironmentManager,
36
+ _DependencySettings,
37
+ get_logger,
38
+ )
39
+
40
# --- Exit codes -------------------------------------------------------------
SUCCESS_EXIT_CODE = 0
DEFAULT_FAILURE_CODE = 1

# --- Workspace / channel names ----------------------------------------------
REMOTE_FUNCTION_WORKSPACE = "sm_rf_user_ws"
JOB_REMOTE_FUNCTION_WORKSPACE = "sagemaker_remote_function_workspace"
SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME = "pre_exec_script_and_dependencies"
PRE_EXECUTION_SCRIPT_NAME = "pre_exec.sh"

# --- Container paths ---------------------------------------------------------
BASE_CHANNEL_PATH = "/opt/ml/input/data"
FAILURE_REASON_PATH = "/opt/ml/output/failure"
JOB_OUTPUT_DIRS = [
    "/opt/ml/input",
    "/opt/ml/output",
    "/opt/ml/model",
    "/tmp",
]

SM_MODEL_DIR = "/opt/ml/model"

SM_INPUT_DIR = "/opt/ml/input"
SM_INPUT_DATA_DIR = "/opt/ml/input/data"
SM_INPUT_CONFIG_DIR = "/opt/ml/input/config"

SM_OUTPUT_DIR = "/opt/ml/output"
SM_OUTPUT_FAILURE = "/opt/ml/output/failure"
SM_OUTPUT_DATA_DIR = "/opt/ml/output/data"

RESOURCE_CONFIG = f"{SM_INPUT_CONFIG_DIR}/resourceconfig.json"
ENV_OUTPUT_FILE = "/opt/ml/input/sm_training.env"

# --- Distributed job defaults ------------------------------------------------
SM_MASTER_ADDR = "algo-1"
SM_MASTER_PORT = 7777

# --- Masking ------------------------------------------------------------------
# NOTE(review): presumably names containing one of these keywords have their
# values replaced by HIDDEN_VALUE when logged; the consumer is not in this excerpt.
SENSITIVE_KEYWORDS = [
    "SECRET",
    "PASSWORD",
    "KEY",
    "TOKEN",
    "PRIVATE",
    "CREDS",
    "CREDENTIALS",
]
HIDDEN_VALUE = "******"

# --- Instance-type lists -----------------------------------------------------
# NOTE(review): how these two lists are consumed is not visible in this excerpt.
SM_EFA_NCCL_INSTANCES = [
    "ml.g4dn.8xlarge",
    "ml.g4dn.12xlarge",
    "ml.g5.48xlarge",
    "ml.p3dn.24xlarge",
    "ml.p4d.24xlarge",
    "ml.p4de.24xlarge",
    "ml.p5.48xlarge",
    "ml.trn1.32xlarge",
]

SM_EFA_RDMA_INSTANCES = [
    "ml.p4d.24xlarge",
    "ml.p4de.24xlarge",
    "ml.trn1.32xlarge",
]
86
+
87
+ logger = get_logger()
88
+
89
+
90
def _bootstrap_runtime_env_for_remote_function(
    client_python_version: str,
    conda_env: str = None,
    dependency_settings: _DependencySettings = None,
):
    """Prepare the job container for a remote function invocation.

    Unpacks the user workspace, runs the pre-execution script if the workspace
    contains one, then installs dependencies. When no workspace could be
    unpacked, only an informational message is logged.

    Args:
        client_python_version (str): Python version at the client side.
        conda_env (str): conda environment to be activated. Default is None.
        dependency_settings (_DependencySettings): Settings for installing
            dependencies. Default is None.
    """
    unpacked_workspace = _unpack_user_workspace()
    if unpacked_workspace:
        _handle_pre_exec_scripts(unpacked_workspace)
        _install_dependencies(
            unpacked_workspace,
            conda_env,
            client_python_version,
            REMOTE_FUNCTION_WORKSPACE,
            dependency_settings,
        )
    else:
        logger.info("No workspace to unpack and setup.")
117
+
118
+
119
def _bootstrap_runtime_env_for_pipeline_step(
    client_python_version: str,
    func_step_workspace: str,
    conda_env: str = None,
    dependency_settings: _DependencySettings = None,
):
    """Prepare the job container for a pipeline step invocation.

    Unpacks the step workspace (creating an empty workspace directory under the
    current working directory when there is nothing to unpack), copies every
    entry of the pre-execution script/dependencies channel into it, runs the
    pre-execution script if present, and installs dependencies. Returns early
    when the script/dependencies channel directory does not exist.

    Args:
        client_python_version (str): Python version at the client side.
        func_step_workspace (str): s3 folder where workspace for FunctionStep is stored
        conda_env (str): conda environment to be activated. Default is None.
        dependency_settings (_DependencySettings): Settings for installing
            dependencies. Default is None.
    """
    workspace_dir = _unpack_user_workspace(func_step_workspace)
    if not workspace_dir:
        # Nothing was unpacked: fall back to a fresh workspace directory in cwd.
        os.mkdir(JOB_REMOTE_FUNCTION_WORKSPACE)
        workspace_dir = pathlib.Path(os.getcwd(), JOB_REMOTE_FUNCTION_WORKSPACE).absolute()

    channel_dir = os.path.join(BASE_CHANNEL_PATH, SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME)
    if not os.path.exists(channel_dir):
        logger.info("No dependencies to bootstrap")
        return

    # Stage the channel contents next to the user's code.
    for entry in os.listdir(channel_dir):
        shutil.copy(os.path.join(channel_dir, entry), os.path.join(workspace_dir, entry))

    _handle_pre_exec_scripts(workspace_dir)

    _install_dependencies(
        workspace_dir,
        conda_env,
        client_python_version,
        SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME,
        dependency_settings,
    )
160
+
161
+
162
def _handle_pre_exec_scripts(script_file_dir: str):
    """Run the pre-execution script found in ``script_file_dir``, if any.

    Args:
        script_file_dir (str): Directory in the container that may hold the
            pre-execution script.
    """
    script_path = os.path.join(script_file_dir, PRE_EXECUTION_SCRIPT_NAME)
    if not os.path.isfile(script_path):
        return

    RuntimeEnvironmentManager().run_pre_exec_script(pre_exec_script_path=script_path)
174
+
175
+
176
def _install_dependencies(
    dependency_file_dir: str,
    conda_env: str,
    client_python_version: str,
    channel_name: str,
    dependency_settings: _DependencySettings = None,
):
    """Install dependencies in the job container.

    Args:
        dependency_file_dir (str): Directory in the container where dependency file exists.
        conda_env (str): conda environment to be activated.
        client_python_version (str): Python version at the client side.
        channel_name (str): Channel where dependency file was uploaded.
        dependency_settings (_DependencySettings): Settings for installing
            dependencies. Default is None.
    """
    if dependency_settings is None:
        # No settings at all: a legacy version of the SDK was used, so take the
        # first file with a .txt, .yml or .yaml extension in the directory.
        dependencies_file = next(
            (
                os.path.join(dependency_file_dir, name)
                for name in os.listdir(dependency_file_dir)
                if name.endswith((".txt", ".yml", ".yaml"))
            ),
            None,
        )
        if not dependencies_file:
            logger.info(
                "Did not find any dependency file in the directory at '%s'."
                " Assuming no additional dependencies to install.",
                os.path.join(BASE_CHANNEL_PATH, channel_name),
            )
            return
    elif dependency_settings.dependency_file is None:
        # Settings were supplied but they name no dependency file.
        logger.info("No dependencies to install.")
        return
    else:
        dependencies_file = os.path.join(dependency_file_dir, dependency_settings.dependency_file)

    RuntimeEnvironmentManager().bootstrap(
        local_dependencies_file=dependencies_file,
        conda_env=conda_env,
        client_python_version=client_python_version,
    )
224
+
225
+
226
def _unpack_user_workspace(func_step_workspace: str = None):
    """Unzip the user workspace archive into the current working directory.

    Args:
        func_step_workspace (str): Name of the input channel that holds
            ``workspace.zip``. When falsy, REMOTE_FUNCTION_WORKSPACE is used.

    Returns:
        pathlib.Path: ``<cwd>/<JOB_REMOTE_FUNCTION_WORKSPACE>`` after a
        successful unpack (the path is returned without checking that the
        archive actually contained that directory), or None when the channel
        directory or the archive file is missing.
    """
    channel_name = func_step_workspace or REMOTE_FUNCTION_WORKSPACE
    workspace_archive_dir_path = os.path.join(BASE_CHANNEL_PATH, channel_name)
    if not os.path.exists(workspace_archive_dir_path):
        logger.info(
            "Directory '%s' does not exist.",
            workspace_archive_dir_path,
        )
        return None

    workspace_archive_path = os.path.join(workspace_archive_dir_path, "workspace.zip")
    if not os.path.isfile(workspace_archive_path):
        # Report the archive file that was checked, not its parent directory
        # (which is known to exist at this point).
        logger.info(
            "Workspace archive '%s' does not exist.",
            workspace_archive_path,
        )
        return None

    workspace_unpack_dir = pathlib.Path(os.getcwd()).absolute()
    shutil.unpack_archive(filename=workspace_archive_path, extract_dir=workspace_unpack_dir)
    logger.info("Successfully unpacked workspace archive at '%s'.", workspace_unpack_dir)
    return pathlib.Path(workspace_unpack_dir, JOB_REMOTE_FUNCTION_WORKSPACE)
254
+
255
+
256
def _write_failure_reason_file(failure_msg):
    """Record why bootstrapping the runtime environment failed.

    The message is written to ``FAILURE_REASON_PATH`` prefixed with
    ``RuntimeEnvironmentError: ``. Nothing is written when that path already exists.

    See: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html
    Args:
        failure_msg: The failure description to store in the file.
    """
    if os.path.exists(FAILURE_REASON_PATH):
        # A failure file is already present; leave its content untouched.
        return

    with open(FAILURE_REASON_PATH, "w") as failure_file:
        failure_file.write("RuntimeEnvironmentError: " + failure_msg)
266
+
267
+
268
+ def _parse_args(sys_args):
269
+ """Parses CLI arguments."""
270
+ parser = argparse.ArgumentParser()
271
+ parser.add_argument("--job_conda_env", type=str)
272
+ parser.add_argument("--client_python_version", type=str)
273
+ parser.add_argument("--client_sagemaker_pysdk_version", type=str, default=None)
274
+ parser.add_argument("--pipeline_execution_id", type=str)
275
+ parser.add_argument("--dependency_settings", type=str)
276
+ parser.add_argument("--func_step_s3_dir", type=str)
277
+ parser.add_argument("--distribution", type=str, default=None)
278
+ parser.add_argument("--user_nproc_per_node", type=str, default=None)
279
+ args, _ = parser.parse_known_args(sys_args)
280
+ return args
281
+
282
+
283
def log_key_value(key: str, value: str):
    """Log a key-value pair, masking sensitive values if necessary.

    * When ``key`` contains any of ``SENSITIVE_KEYWORDS`` (case-insensitive), the value is
      replaced by ``HIDDEN_VALUE``.
    * When ``value`` is a dict, or a string holding a JSON object, it is passed through
      ``mask_sensitive_info`` and logged as JSON.
    * Any other value is logged exactly as received.

    Args:
        key (str): Name of the entry being logged.
        value: Value of the entry; a string, a dict, or any other loggable object.
    """
    if any(keyword.lower() in key.lower() for keyword in SENSITIVE_KEYWORDS):
        logger.info("%s=%s", key, HIDDEN_VALUE)
        return

    if isinstance(value, dict):
        logger.info("%s=%s", key, json.dumps(mask_sensitive_info(value)))
        return

    try:
        decoded_value = json.loads(value)
    except (json.JSONDecodeError, TypeError):
        # Not JSON text (or not text at all): there is nothing to inspect for nested secrets.
        decoded_value = None

    if isinstance(decoded_value, dict):
        logger.info("%s=%s", key, json.dumps(mask_sensitive_info(decoded_value)))
    else:
        # Log the original value: the JSON-decoded form can differ from what was actually
        # set (e.g. "3.10" decodes to 3.1 and "true" decodes to True).
        logger.info("%s=%s", key, value)
300
+
301
+
302
def log_env_variables(env_vars_dict: Dict[str, Any]):
    """Log the process environment followed by the entries of ``env_vars_dict``.

    Each entry is emitted through ``log_key_value``.

    Args:
        env_vars_dict (Dict[str, Any]): Additional variables to log after ``os.environ``.
    """
    for source in (os.environ, env_vars_dict):
        for name, val in source.items():
            log_key_value(name, val)
309
+
310
+
311
def mask_sensitive_info(data):
    """Recursively mask sensitive information in a dictionary.

    A string value is replaced by ``HIDDEN_VALUE`` when its key contains any of
    ``SENSITIVE_KEYWORDS`` (case-insensitive). Nested dictionaries are processed
    recursively. The input is left unmodified; a masked copy is returned so that
    logging a dictionary never alters the data being logged.

    Args:
        data: The object to mask. Anything that is not a dict is returned unchanged.

    Returns:
        A new dict with sensitive string values masked, or ``data`` itself when it is
        not a dict.
    """
    if not isinstance(data, dict):
        return data

    masked = {}
    for k, v in data.items():
        if isinstance(v, dict):
            masked[k] = mask_sensitive_info(v)
        elif isinstance(v, str) and any(
            keyword.lower() in k.lower() for keyword in SENSITIVE_KEYWORDS
        ):
            masked[k] = HIDDEN_VALUE
        else:
            masked[k] = v
    return masked
322
+
323
+
324
def num_cpus() -> int:
    """Return the CPU count reported by ``multiprocessing.cpu_count()``.

    Returns:
        int: Number of CPUs reported for the current container.
    """
    cpu_total = multiprocessing.cpu_count()
    return cpu_total
331
+
332
+
333
def num_gpus() -> int:
    """Count the GPUs listed by ``nvidia-smi --list-gpus``.

    Returns:
        int: Number of output lines starting with ``"GPU "``, or 0 when ``nvidia-smi``
        cannot be executed or exits with an error.
    """
    try:
        listing = subprocess.check_output(["nvidia-smi", "--list-gpus"]).decode("utf-8")
    except (OSError, subprocess.CalledProcessError):
        logger.info("No GPUs detected (normal if no gpus installed)")
        return 0

    gpu_lines = [entry for entry in listing.splitlines() if entry.startswith("GPU ")]
    return len(gpu_lines)
346
+
347
+
348
def num_neurons() -> int:
    """Return the number of neuron cores available in the current container.

    Runs ``neuron-ls -j`` and sums the ``nc_count`` field over the entries of its JSON
    output. Any failure to run the command or to parse its output is treated as
    "no neuron cores" instead of being raised.

    Returns:
        int: Number of Neuron Cores available in the current container, or 0 when
        ``neuron-ls`` is unavailable, fails, or produces output that cannot be parsed.
    """
    try:
        cmd = ["neuron-ls", "-j"]
        output = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8")
        # Assumes the output is a JSON array of objects; entries without "nc_count" add 0.
        neuron_cores = sum(item.get("nc_count", 0) for item in json.loads(output))
        logger.info("Found %s neurons on this instance", neuron_cores)
        return neuron_cores
    except OSError:
        logger.info("No Neurons detected (normal if no neurons installed)")
        return 0
    except subprocess.CalledProcessError as e:
        if e.output is not None:
            try:
                # Keep only the text following "error=" in the captured output.
                msg = e.output.decode("utf-8").partition("error=")[2]
                logger.info(
                    "No Neurons detected (normal if no neurons installed). "
                    "If neuron installed then %s",
                    msg,
                )
            except AttributeError:
                logger.info("No Neurons detected (normal if no neurons installed)")
        else:
            logger.info("No Neurons detected (normal if no neurons installed)")
    except ValueError as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        # stderr is merged into the captured output, so it may not be clean JSON.
        logger.warning(
            "Could not parse the output of 'neuron-ls -j' (%s). Assuming no Neurons.", e
        )

    return 0
381
+
382
+
383
def safe_serialize(data):
    """Serialize the data without wrapping strings in quotes.

    This function handles the following cases:
    1. If `data` is a string, it returns the string as-is without wrapping in quotes.
    2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns
       the JSON-encoded string using `json.dumps()`.
    3. If `data` cannot be serialized (e.g., a custom object, or a container that
       references itself), it returns the string representation of the data using
       `str(data)`.

    Args:
        data (Any): The data to serialize.

    Returns:
        str: The serialized JSON-compatible string or the string representation of the input.
    """
    if isinstance(data, str):
        return data
    try:
        return json.dumps(data)
    except (TypeError, ValueError):
        # TypeError: unsupported type. ValueError: circular reference detected.
        return str(data)
405
+
406
+
407
def set_env(
    resource_config: Dict[str, Any],
    distribution: str = None,
    user_nproc_per_node: str = None,
    output_file: str = ENV_OUTPUT_FILE,
):
    """Set environment variables for the training job container.

    Builds the ``SM_*`` variables from module constants, ``resource_config`` and the
    detected CPU/GPU/Neuron counts, adds distribution-specific variables, writes all of
    them to ``output_file`` as shell ``export`` statements and logs them.

    Args:
        resource_config (Dict[str, Any]): Resource configuration for the training job.
            Must contain ``current_host``, ``current_instance_type``, ``hosts`` and
            ``network_interface_name``.
        distribution (str): ``"torchrun"`` or ``"mpirun"`` to add the variables specific
            to that launcher; any other value adds none.
        user_nproc_per_node (str): Number of processes per node requested by the user
            (anything accepted by ``int()``). When ``None`` or not positive, the GPU
            count is used, then the Neuron core count, then the CPU count.
        output_file (str): Output file to write the environment variables.

    Raises:
        KeyError: If a required ``resource_config`` key or the ``TRAINING_JOB_NAME``
            environment variable is missing.
        ValueError: If ``current_host`` is not listed in ``hosts``, or
            ``user_nproc_per_node`` is not a valid integer.
    """
    # Constants
    env_vars = {
        "SM_MODEL_DIR": SM_MODEL_DIR,
        "SM_INPUT_DIR": SM_INPUT_DIR,
        "SM_INPUT_DATA_DIR": SM_INPUT_DATA_DIR,
        "SM_INPUT_CONFIG_DIR": SM_INPUT_CONFIG_DIR,
        "SM_OUTPUT_DIR": SM_OUTPUT_DIR,
        "SM_OUTPUT_FAILURE": SM_OUTPUT_FAILURE,
        "SM_OUTPUT_DATA_DIR": SM_OUTPUT_DATA_DIR,
        "SM_MASTER_ADDR": SM_MASTER_ADDR,
        "SM_MASTER_PORT": SM_MASTER_PORT,
    }

    # Host Variables
    current_host = resource_config["current_host"]
    current_instance_type = resource_config["current_instance_type"]
    hosts = resource_config["hosts"]
    # Sorting gives every host the same ordering, so the rank below is consistent.
    sorted_hosts = sorted(hosts)

    env_vars["SM_CURRENT_HOST"] = current_host
    env_vars["SM_CURRENT_INSTANCE_TYPE"] = current_instance_type
    env_vars["SM_HOSTS"] = sorted_hosts
    env_vars["SM_NETWORK_INTERFACE_NAME"] = resource_config["network_interface_name"]
    env_vars["SM_HOST_COUNT"] = len(sorted_hosts)
    env_vars["SM_CURRENT_HOST_RANK"] = sorted_hosts.index(current_host)

    env_vars["SM_NUM_CPUS"] = num_cpus()
    env_vars["SM_NUM_GPUS"] = num_gpus()
    env_vars["SM_NUM_NEURONS"] = num_neurons()

    # Misc.
    env_vars["SM_RESOURCE_CONFIG"] = resource_config

    if user_nproc_per_node is not None and int(user_nproc_per_node) > 0:
        env_vars["SM_NPROC_PER_NODE"] = int(user_nproc_per_node)
    else:
        # Fall back to the accelerator count, and to the CPU count when there is none.
        if int(env_vars["SM_NUM_GPUS"]) > 0:
            env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_GPUS"])
        elif int(env_vars["SM_NUM_NEURONS"]) > 0:
            env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_NEURONS"])
        else:
            env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_CPUS"])

    # All Training Environment Variables
    env_vars["SM_TRAINING_ENV"] = {
        "current_host": env_vars["SM_CURRENT_HOST"],
        "current_instance_type": env_vars["SM_CURRENT_INSTANCE_TYPE"],
        "hosts": env_vars["SM_HOSTS"],
        "host_count": env_vars["SM_HOST_COUNT"],
        "nproc_per_node": env_vars["SM_NPROC_PER_NODE"],
        "master_addr": env_vars["SM_MASTER_ADDR"],
        "master_port": env_vars["SM_MASTER_PORT"],
        "input_config_dir": env_vars["SM_INPUT_CONFIG_DIR"],
        "input_data_dir": env_vars["SM_INPUT_DATA_DIR"],
        "input_dir": env_vars["SM_INPUT_DIR"],
        "job_name": os.environ["TRAINING_JOB_NAME"],
        "model_dir": env_vars["SM_MODEL_DIR"],
        "network_interface_name": env_vars["SM_NETWORK_INTERFACE_NAME"],
        "num_cpus": env_vars["SM_NUM_CPUS"],
        "num_gpus": env_vars["SM_NUM_GPUS"],
        "num_neurons": env_vars["SM_NUM_NEURONS"],
        "output_data_dir": env_vars["SM_OUTPUT_DATA_DIR"],
        "resource_config": env_vars["SM_RESOURCE_CONFIG"],
    }

    if distribution == "torchrun":
        logger.info("Distribution: torchrun")

        instance_type = env_vars["SM_CURRENT_INSTANCE_TYPE"]
        network_interface_name = env_vars.get("SM_NETWORK_INTERFACE_NAME", "eth0")

        if instance_type in SM_EFA_NCCL_INSTANCES:
            # Enable EFA use
            env_vars["FI_PROVIDER"] = "efa"
        if instance_type in SM_EFA_RDMA_INSTANCES:
            # Use EFA's RDMA functionality for one-sided and two-sided transfer
            env_vars["FI_EFA_USE_DEVICE_RDMA"] = "1"
            env_vars["RDMAV_FORK_SAFE"] = "1"
        env_vars["NCCL_SOCKET_IFNAME"] = str(network_interface_name)
        env_vars["NCCL_PROTO"] = "simple"
    elif distribution == "mpirun":
        logger.info("Distribution: mpirun")

        env_vars["MASTER_ADDR"] = env_vars["SM_MASTER_ADDR"]
        env_vars["MASTER_PORT"] = str(env_vars["SM_MASTER_PORT"])

        # "<host>:<process count>" entries joined by commas.
        host_list = [
            "{}:{}".format(host, int(env_vars["SM_NPROC_PER_NODE"])) for host in sorted_hosts
        ]
        env_vars["SM_HOSTS_LIST"] = ",".join(host_list)

        instance_type = env_vars["SM_CURRENT_INSTANCE_TYPE"]

        # The values below are "-x NAME=value" fragments, or empty strings when not applicable.
        if instance_type in SM_EFA_NCCL_INSTANCES:
            env_vars["SM_FI_PROVIDER"] = "-x FI_PROVIDER=efa"
            env_vars["SM_NCCL_PROTO"] = "-x NCCL_PROTO=simple"
        else:
            env_vars["SM_FI_PROVIDER"] = ""
            env_vars["SM_NCCL_PROTO"] = ""

        if instance_type in SM_EFA_RDMA_INSTANCES:
            env_vars["SM_FI_EFA_USE_DEVICE_RDMA"] = "-x FI_EFA_USE_DEVICE_RDMA=1"
        else:
            env_vars["SM_FI_EFA_USE_DEVICE_RDMA"] = ""

    with open(output_file, "w") as f:
        for key, value in env_vars.items():
            # A single quote inside the value would end the shell quoting early; replace
            # it with '\'' (close quote, escaped quote, reopen quote). Values without a
            # single quote are written exactly as before.
            quoted_value = safe_serialize(value).replace("'", "'\\''")
            f.write(f"export {key}='{quoted_value}'\n")

    logger.info("Environment Variables:")
    log_env_variables(env_vars_dict=env_vars)
530
+
531
+
532
def main(sys_args=None):
    """Entry point for bootstrap script.

    Parses the arguments, validates the client Python version, makes the job output
    directories writable when not running as root, bootstraps the runtime environment
    (pipeline step or remote function), validates the client SageMaker SDK version and,
    when ``RESOURCE_CONFIG`` exists, writes the training environment variables.

    The function always terminates the process through ``sys.exit``: with
    ``SUCCESS_EXIT_CODE`` when every step succeeded, otherwise with
    ``DEFAULT_FAILURE_CODE`` after recording the failure reason.

    Args:
        sys_args: Command-line arguments to parse (without the program name).
    """

    exit_code = DEFAULT_FAILURE_CODE

    try:
        args = _parse_args(sys_args)

        logger.info("Arguments:")
        for arg in vars(args):
            logger.info("%s=%s", arg, getattr(args, arg))

        client_python_version = args.client_python_version
        client_sagemaker_pysdk_version = args.client_sagemaker_pysdk_version
        job_conda_env = args.job_conda_env
        pipeline_execution_id = args.pipeline_execution_id
        dependency_settings = _DependencySettings.from_string(args.dependency_settings)
        func_step_workspace = args.func_step_s3_dir
        distribution = args.distribution
        user_nproc_per_node = args.user_nproc_per_node

        # The command-line value takes precedence over the environment variable.
        conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV")

        RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env)

        user = getpass.getuser()
        if user != "root":
            log_message = (
                "The job is running on non-root user: %s. Adding write permissions to the "
                "following job output directories: %s."
            )
            logger.info(log_message, user, JOB_OUTPUT_DIRS)
            RuntimeEnvironmentManager().change_dir_permission(
                dirs=JOB_OUTPUT_DIRS, new_permission="777"
            )

        if pipeline_execution_id:
            _bootstrap_runtime_env_for_pipeline_step(
                client_python_version, func_step_workspace, conda_env, dependency_settings
            )
        else:
            _bootstrap_runtime_env_for_remote_function(
                client_python_version, conda_env, dependency_settings
            )

        RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version(
            client_sagemaker_pysdk_version
        )

        if os.path.exists(RESOURCE_CONFIG):
            try:
                logger.info("Found %s", RESOURCE_CONFIG)
                with open(RESOURCE_CONFIG, "r") as f:
                    resource_config = json.load(f)
                set_env(
                    resource_config=resource_config,
                    distribution=distribution,
                    user_nproc_per_node=user_nproc_per_node,
                )
            except (json.JSONDecodeError, FileNotFoundError) as e:
                # Non-fatal: the bootstrap still reports success, but the problem is
                # logged at error level so it is not lost among informational messages.
                logger.error("ERROR: Error processing %s: %s", RESOURCE_CONFIG, str(e))

        exit_code = SUCCESS_EXIT_CODE
    except Exception as e:  # pylint: disable=broad-except
        logger.exception("Error encountered while bootstrapping runtime environment: %s", e)

        _write_failure_reason_file(str(e))
    finally:
        sys.exit(exit_code)
602
+
603
+
604
# Run the bootstrap when this file is executed as a script; the program name is
# dropped so that main() receives only the actual arguments.
if __name__ == "__main__":
    main(sys.argv[1:])