sagemaker-core 1.0.62__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (362) hide show
  1. sagemaker/core/__init__.py +16 -0
  2. sagemaker/core/_studio.py +116 -0
  3. sagemaker/core/_version.py +11 -0
  4. sagemaker/core/accept_types.py +131 -0
  5. sagemaker/core/analytics.py +744 -0
  6. sagemaker/core/apiutils/__init__.py +13 -0
  7. sagemaker/core/apiutils/_base_types.py +228 -0
  8. sagemaker/core/apiutils/_boto_functions.py +130 -0
  9. sagemaker/core/apiutils/_utils.py +34 -0
  10. sagemaker/core/base_deserializers.py +35 -0
  11. sagemaker/core/base_serializers.py +35 -0
  12. sagemaker/core/clarify/__init__.py +2898 -0
  13. sagemaker/core/collection.py +467 -0
  14. sagemaker/core/common_utils.py +2281 -0
  15. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  16. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  17. sagemaker/core/config/__init__.py +181 -0
  18. sagemaker/core/config/config.py +238 -0
  19. sagemaker/core/config/config_manager.py +595 -0
  20. sagemaker/core/config/config_schema.py +1220 -0
  21. sagemaker/core/config/config_utils.py +297 -0
  22. {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
  23. sagemaker/core/constants.py +73 -0
  24. sagemaker/core/content_types.py +137 -0
  25. sagemaker/core/debugger/__init__.py +39 -0
  26. sagemaker/core/debugger/debugger.py +945 -0
  27. sagemaker/core/debugger/framework_profile.py +292 -0
  28. sagemaker/core/debugger/metrics_config.py +468 -0
  29. sagemaker/core/debugger/profiler.py +42 -0
  30. sagemaker/core/debugger/profiler_config.py +190 -0
  31. sagemaker/core/debugger/profiler_constants.py +40 -0
  32. sagemaker/core/debugger/utils.py +148 -0
  33. sagemaker/core/deprecations.py +254 -0
  34. sagemaker/core/deserializers/__init__.py +10 -0
  35. sagemaker/core/deserializers/base.py +424 -0
  36. sagemaker/core/deserializers/implementations.py +157 -0
  37. sagemaker/core/drift_check_baselines.py +106 -0
  38. sagemaker/core/enums.py +51 -0
  39. sagemaker/core/environment_variables.py +101 -0
  40. sagemaker/core/exceptions.py +108 -0
  41. sagemaker/core/experiments/__init__.py +53 -0
  42. sagemaker/core/experiments/_api_types.py +251 -0
  43. sagemaker/core/experiments/_environment.py +124 -0
  44. sagemaker/core/experiments/_helper.py +294 -0
  45. sagemaker/core/experiments/_metrics.py +333 -0
  46. sagemaker/core/experiments/_run_context.py +58 -0
  47. sagemaker/core/experiments/_utils.py +216 -0
  48. sagemaker/core/experiments/experiment.py +244 -0
  49. sagemaker/core/experiments/run.py +970 -0
  50. sagemaker/core/experiments/trial.py +296 -0
  51. sagemaker/core/experiments/trial_component.py +387 -0
  52. sagemaker/core/explainer/__init__.py +24 -0
  53. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  54. sagemaker/core/explainer/explainer_config.py +44 -0
  55. sagemaker/core/fw_utils.py +1176 -0
  56. sagemaker/core/git_utils.py +349 -0
  57. sagemaker/core/helper/pipeline_variable.py +82 -0
  58. sagemaker/core/helper/session_helper.py +2965 -0
  59. sagemaker/core/huggingface/__init__.py +29 -0
  60. sagemaker/core/huggingface/llm_utils.py +150 -0
  61. sagemaker/core/huggingface/processing.py +139 -0
  62. sagemaker/core/huggingface/training_compiler/config.py +167 -0
  63. sagemaker/core/hyperparameters.py +172 -0
  64. sagemaker/core/image_retriever/__init__.py +3 -0
  65. sagemaker/core/image_retriever/image_retriever.py +640 -0
  66. sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
  67. sagemaker/core/image_retriever/test.py +7 -0
  68. sagemaker/core/image_uri_config/__init__.py +13 -0
  69. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  70. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  71. sagemaker/core/image_uri_config/chainer.json +104 -0
  72. sagemaker/core/image_uri_config/clarify.json +39 -0
  73. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  74. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  75. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  76. sagemaker/core/image_uri_config/debugger.json +34 -0
  77. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  78. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  79. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  80. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  81. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  82. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  83. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  84. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  85. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
  86. sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
  87. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  88. sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
  89. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  90. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  91. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  92. sagemaker/core/image_uri_config/huggingface.json +2138 -0
  93. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  94. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  95. sagemaker/core/image_uri_config/image-classification.json +50 -0
  96. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  97. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  98. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  99. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  100. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  101. sagemaker/core/image_uri_config/kmeans.json +50 -0
  102. sagemaker/core/image_uri_config/knn.json +50 -0
  103. sagemaker/core/image_uri_config/lda.json +26 -0
  104. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  105. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  106. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  107. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  108. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  109. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  110. sagemaker/core/image_uri_config/ntm.json +50 -0
  111. sagemaker/core/image_uri_config/object-detection.json +50 -0
  112. sagemaker/core/image_uri_config/object2vec.json +50 -0
  113. sagemaker/core/image_uri_config/pca.json +50 -0
  114. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  115. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  116. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  117. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  118. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  119. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  120. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  121. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  122. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  123. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  124. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
  125. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  126. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  127. sagemaker/core/image_uri_config/sklearn.json +446 -0
  128. sagemaker/core/image_uri_config/spark.json +280 -0
  129. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  130. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  131. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  132. sagemaker/core/image_uri_config/vw.json +25 -0
  133. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  134. sagemaker/core/image_uri_config/xgboost.json +888 -0
  135. sagemaker/core/image_uris.py +810 -0
  136. sagemaker/core/inference_config.py +144 -0
  137. sagemaker/core/inference_recommender/__init__.py +18 -0
  138. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  139. sagemaker/core/inputs.py +366 -0
  140. sagemaker/core/instance_group.py +61 -0
  141. sagemaker/core/instance_types.py +164 -0
  142. sagemaker/core/instance_types_gpu_info.py +43 -0
  143. sagemaker/core/interactive_apps/__init__.py +41 -0
  144. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  145. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  146. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  147. sagemaker/core/iterators.py +186 -0
  148. sagemaker/core/job.py +380 -0
  149. sagemaker/core/jumpstart/__init__.py +156 -0
  150. sagemaker/core/jumpstart/accessors.py +390 -0
  151. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  152. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  153. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  154. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  155. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  156. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  157. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  158. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  159. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  160. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  161. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  162. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  163. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  164. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  165. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  166. sagemaker/core/jumpstart/cache.py +663 -0
  167. sagemaker/core/jumpstart/configs.py +50 -0
  168. sagemaker/core/jumpstart/constants.py +198 -0
  169. sagemaker/core/jumpstart/deserializers.py +81 -0
  170. sagemaker/core/jumpstart/document.py +76 -0
  171. sagemaker/core/jumpstart/enums.py +168 -0
  172. sagemaker/core/jumpstart/exceptions.py +236 -0
  173. sagemaker/core/jumpstart/factory/utils.py +833 -0
  174. sagemaker/core/jumpstart/filters.py +597 -0
  175. sagemaker/core/jumpstart/hub/constants.py +16 -0
  176. sagemaker/core/jumpstart/hub/hub.py +291 -0
  177. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  178. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  179. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  180. sagemaker/core/jumpstart/hub/types.py +35 -0
  181. sagemaker/core/jumpstart/hub/utils.py +260 -0
  182. sagemaker/core/jumpstart/models.py +499 -0
  183. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  184. sagemaker/core/jumpstart/parameters.py +20 -0
  185. sagemaker/core/jumpstart/payload_utils.py +239 -0
  186. sagemaker/core/jumpstart/region_config.json +163 -0
  187. sagemaker/core/jumpstart/search.py +171 -0
  188. sagemaker/core/jumpstart/serializers.py +81 -0
  189. sagemaker/core/jumpstart/session_utils.py +234 -0
  190. sagemaker/core/jumpstart/types.py +3044 -0
  191. sagemaker/core/jumpstart/utils.py +1731 -0
  192. sagemaker/core/jumpstart/validators.py +257 -0
  193. sagemaker/core/lambda_helper.py +312 -0
  194. sagemaker/core/lineage/__init__.py +42 -0
  195. sagemaker/core/lineage/_api_types.py +239 -0
  196. sagemaker/core/lineage/_utils.py +49 -0
  197. sagemaker/core/lineage/action.py +345 -0
  198. sagemaker/core/lineage/artifact.py +646 -0
  199. sagemaker/core/lineage/association.py +190 -0
  200. sagemaker/core/lineage/context.py +505 -0
  201. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  202. sagemaker/core/lineage/query.py +732 -0
  203. sagemaker/core/lineage/visualizer.py +346 -0
  204. sagemaker/core/local/__init__.py +18 -0
  205. sagemaker/core/local/data.py +413 -0
  206. sagemaker/core/local/entities.py +678 -0
  207. sagemaker/core/local/exceptions.py +17 -0
  208. sagemaker/core/local/image.py +1243 -0
  209. sagemaker/core/local/local_session.py +739 -0
  210. sagemaker/core/local/utils.py +245 -0
  211. sagemaker/core/logs.py +181 -0
  212. sagemaker/core/metadata_properties.py +56 -0
  213. sagemaker/core/metric_definitions.py +91 -0
  214. sagemaker/core/mlflow/__init__.py +38 -0
  215. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  216. sagemaker/core/model_card/__init__.py +26 -0
  217. sagemaker/core/model_life_cycle.py +51 -0
  218. sagemaker/core/model_metrics.py +160 -0
  219. sagemaker/core/model_monitor/__init__.py +66 -0
  220. sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
  221. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  222. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  223. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  224. sagemaker/core/model_monitor/dataset_format.py +102 -0
  225. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  226. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  227. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  228. sagemaker/core/model_monitor/utils.py +793 -0
  229. sagemaker/core/model_registry.py +480 -0
  230. sagemaker/core/model_uris.py +97 -0
  231. sagemaker/core/modules/__init__.py +19 -0
  232. sagemaker/core/modules/configs.py +226 -0
  233. sagemaker/core/modules/constants.py +37 -0
  234. sagemaker/core/modules/distributed.py +182 -0
  235. sagemaker/core/modules/local_core/__init__.py +0 -0
  236. sagemaker/core/modules/local_core/local_container.py +605 -0
  237. sagemaker/core/modules/templates.py +83 -0
  238. sagemaker/core/modules/train/__init__.py +14 -0
  239. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  240. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  241. sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
  242. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  243. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  244. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  245. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  246. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  247. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  248. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  249. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  250. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  251. sagemaker/core/modules/types.py +19 -0
  252. sagemaker/core/modules/utils.py +194 -0
  253. sagemaker/core/network.py +185 -0
  254. sagemaker/core/parameter.py +173 -0
  255. sagemaker/core/payloads.py +185 -0
  256. sagemaker/core/processing.py +1597 -0
  257. sagemaker/core/remote_function/__init__.py +19 -0
  258. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  259. sagemaker/core/remote_function/client.py +1285 -0
  260. sagemaker/core/remote_function/core/__init__.py +0 -0
  261. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  262. sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
  263. sagemaker/core/remote_function/core/serialization.py +422 -0
  264. sagemaker/core/remote_function/core/stored_function.py +226 -0
  265. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  266. sagemaker/core/remote_function/errors.py +104 -0
  267. sagemaker/core/remote_function/invoke_function.py +172 -0
  268. sagemaker/core/remote_function/job.py +2140 -0
  269. sagemaker/core/remote_function/logging_config.py +38 -0
  270. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  271. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  272. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  273. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  274. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  275. sagemaker/core/remote_function/spark_config.py +149 -0
  276. sagemaker/core/resource_requirements.py +168 -0
  277. {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
  278. sagemaker/core/s3/__init__.py +41 -0
  279. sagemaker/core/s3/client.py +367 -0
  280. sagemaker/core/s3/utils.py +175 -0
  281. sagemaker/core/script_uris.py +93 -0
  282. sagemaker/core/serializers/__init__.py +11 -0
  283. sagemaker/core/serializers/base.py +510 -0
  284. sagemaker/core/serializers/implementations.py +159 -0
  285. sagemaker/core/serializers/utils.py +223 -0
  286. sagemaker/core/serverless_inference_config.py +63 -0
  287. sagemaker/core/session_settings.py +55 -0
  288. sagemaker/core/shapes/__init__.py +3 -0
  289. sagemaker/core/shapes/model_card_shapes.py +159 -0
  290. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
  291. sagemaker/core/spark/__init__.py +16 -0
  292. sagemaker/core/spark/defaults.py +16 -0
  293. sagemaker/core/spark/processing.py +1380 -0
  294. sagemaker/core/telemetry/__init__.py +23 -0
  295. sagemaker/core/telemetry/constants.py +84 -0
  296. sagemaker/core/telemetry/telemetry_logging.py +284 -0
  297. sagemaker/core/tools/__init__.py +1 -0
  298. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  299. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  300. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  301. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  302. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  303. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  304. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  305. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  306. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  307. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  308. sagemaker/core/training/__init__.py +14 -0
  309. sagemaker/core/training/configs.py +333 -0
  310. sagemaker/core/training/constants.py +37 -0
  311. sagemaker/core/training/utils.py +77 -0
  312. sagemaker/core/training_compiler/__init__.py +16 -0
  313. sagemaker/core/training_compiler/config.py +197 -0
  314. sagemaker/core/training_compiler_config.py +197 -0
  315. sagemaker/core/transformer.py +793 -0
  316. sagemaker/core/user_agent.py +76 -0
  317. sagemaker/core/utilities/__init__.py +24 -0
  318. sagemaker/core/utilities/cache.py +169 -0
  319. sagemaker/core/utilities/search_expression.py +133 -0
  320. sagemaker/core/utils/__init__.py +48 -0
  321. sagemaker/core/utils/code_injection/__init__.py +0 -0
  322. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  323. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
  324. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  325. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  326. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  327. {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
  328. sagemaker/core/workflow/__init__.py +152 -0
  329. sagemaker/core/workflow/conditions.py +313 -0
  330. sagemaker/core/workflow/entities.py +58 -0
  331. sagemaker/core/workflow/execution_variables.py +89 -0
  332. sagemaker/core/workflow/functions.py +193 -0
  333. sagemaker/core/workflow/parameters.py +222 -0
  334. sagemaker/core/workflow/pipeline_context.py +394 -0
  335. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  336. sagemaker/core/workflow/properties.py +285 -0
  337. sagemaker/core/workflow/step_outputs.py +65 -0
  338. sagemaker/core/workflow/utilities.py +507 -0
  339. sagemaker/lineage/__init__.py +33 -0
  340. sagemaker/lineage/action.py +28 -0
  341. sagemaker/lineage/artifact.py +28 -0
  342. sagemaker/lineage/context.py +28 -0
  343. sagemaker/lineage/lineage_trial_component.py +28 -0
  344. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
  345. sagemaker_core-2.1.1.dist-info/RECORD +355 -0
  346. sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
  347. sagemaker_core/_version.py +0 -3
  348. sagemaker_core/helper/session_helper.py +0 -769
  349. sagemaker_core/resources/__init__.py +0 -1
  350. sagemaker_core/shapes/__init__.py +0 -1
  351. sagemaker_core/tools/__init__.py +0 -1
  352. sagemaker_core-1.0.62.dist-info/RECORD +0 -35
  353. sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
  354. {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
  355. {sagemaker_core/helper → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
  356. {sagemaker_core/main → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  357. {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
  358. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  359. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  360. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  361. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
  362. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,302 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module provides mpi related utility functions for the container drivers."""
14
+ from __future__ import absolute_import
15
+
16
+ import os
17
+ import sys
18
+ import subprocess
19
+ import time
20
+
21
+ from pathlib import Path
22
+ from typing import List
23
+
24
+ import paramiko
25
+
26
+ sys.path.insert(0, str(Path(__file__).parent.parent))
27
+
28
+ from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611
29
+ SM_EFA_NCCL_INSTANCES,
30
+ SM_EFA_RDMA_INSTANCES,
31
+ get_python_executable,
32
+ logger,
33
+ )
34
+
35
# Marker file that the master writes onto each worker (see
# write_status_file_to_workers) and that workers poll for locally.
FINISHED_STATUS_FILE: str = "/tmp/done.algo-1"
# printf-style template; "%s" is filled with a worker host name. Workers create
# this file on the master, and the master checks for it locally.
READY_FILE: str = "/tmp/ready.%s"
# Port used by the SSH reachability checks.
DEFAULT_SSH_PORT: int = 22
38
+
39
+
40
def _write_file_to_host(host: str, status_file: str) -> bool:
    """Create (touch) a file on a remote host over SSH.

    Args:
        host (str): Hostname to SSH into.
        status_file (str): Path of the file to create on ``host``.

    Returns:
        bool: True if the remote ``touch`` succeeded, False if the ``ssh``
        command exited with a non-zero status.
    """
    try:
        logger.info(f"Writing {status_file} to {host}")
        subprocess.run(
            ["ssh", host, "touch", status_file],
            capture_output=True,
            text=True,
            check=True,
        )
        logger.info("Finished writing status file")
        return True
    except subprocess.CalledProcessError as err:
        logger.info(f"Cannot connect to {host}")
        # stderr is captured above, so surface it here; otherwise the actual
        # reason ssh failed is silently discarded.
        logger.debug(f"ssh to {host} exited with code {err.returncode}: {err.stderr}")
        return False
55
+
56
+
57
def write_status_file_to_workers(worker_hosts: List[str], status_file: str = FINISHED_STATUS_FILE):
    """Create ``status_file`` on every host in ``worker_hosts`` via SSH.

    Hosts are processed one at a time. Each host gets an initial attempt plus up
    to five retries, with a 5 second pause after every failed attempt.

    Args:
        worker_hosts: Hostnames to write the file to.
        status_file: Path of the file to create on each host.

    Raises:
        TimeoutError: If a host still cannot be written to after all retries.
    """
    max_retries = 5
    for host in worker_hosts:
        failures = 0
        while True:
            if _write_file_to_host(host, status_file):
                break
            time.sleep(5)
            failures += 1
            if failures > max_retries:
                raise TimeoutError(f"Timed out waiting for {host} to be reachable.")
            logger.info(f"Retrying to write status file to {host}")
67
+
68
+
69
def _wait_for_status_file(status_file: str):
    """Block until ``status_file`` exists locally, checking every 30 seconds.

    There is no timeout: this waits indefinitely.

    Args:
        status_file: Local path to poll for.
    """
    logger.info(f"Waiting for status file {status_file}")
    while True:
        if os.path.exists(status_file):
            break
        time.sleep(30)
    logger.info(f"Found status file {status_file}")
75
+
76
+
77
def start_sshd_daemon():
    """Launch the SSH daemon on the current node as a background child process.

    Returns:
        subprocess.Popen: Handle of the spawned ``sshd`` process. Existing
        callers may ignore it; keeping a reference lets a caller poll or
        terminate the daemon.

    Raises:
        RuntimeError: If ``/usr/sbin/sshd`` does not exist.
    """
    sshd_executable = "/usr/sbin/sshd"

    if not os.path.exists(sshd_executable):
        raise RuntimeError("SSH daemon not found.")

    # ``-D`` tells sshd NOT to detach/daemonize, so it stays a child of this
    # process; Popen returns immediately without waiting for it to exit.
    sshd_process = subprocess.Popen([sshd_executable, "-D"])
    logger.info("Started SSH daemon.")
    return sshd_process
87
+
88
+
89
class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy):
    """Host key policy for SageMaker distributed training SSH connections.

    Unknown host keys presented by SageMaker training hosts (``algo-*``) are
    accepted and added to ``client.get_host_keys()``; any other unknown host is
    rejected with ``paramiko.SSHException``.

    Example:
        >>> client = paramiko.SSHClient()
        >>> client.set_missing_host_key_policy(CustomHostKeyPolicy())
        >>> # Will succeed for SageMaker algorithm containers
        >>> client.connect('algo-1234.internal')
        >>> # Will raise SSHException for other unknown hosts
        >>> client.connect('unknown-host')  # raises SSHException
    """

    def missing_host_key(self, client, hostname, key):
        """Accept host keys for algo-* hostnames, reject others.

        Args:
            client: The SSHClient instance.
            hostname: The name paramiko stores/looks up the host key under.
                paramiko passes the bare hostname for port 22 and
                ``"[hostname]:port"`` for any other port.
            key: The server's host key.

        Raises:
            paramiko.SSHException: If the hostname doesn't match the algo-* pattern.
        """
        # Match on the bare host so non-default SSH ports are not rejected.
        if self._strip_port(hostname).startswith("algo-"):
            # Register under the exact name paramiko used so lookups succeed.
            client.get_host_keys().add(hostname, key.get_name(), key)
            return
        raise paramiko.SSHException(f"Unknown host key for {hostname}")

    @staticmethod
    def _strip_port(hostname):
        """Return the bare hostname from a ``"[host]:port"`` key name (else unchanged)."""
        if hostname.startswith("[") and "]:" in hostname:
            return hostname[1 : hostname.index("]:")]
        return hostname
116
+
117
+
118
def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool:
    """Report whether an SSH session to ``host`` on ``port`` can be established.

    System host keys are loaded first; unknown keys are delegated to
    ``CustomHostKeyPolicy``. Any exception raised during the attempt is treated
    as "not reachable" and its details are logged at debug level.

    Returns:
        bool: True on a successful connection, False otherwise.
    """
    try:
        logger.debug("Testing connection to host %s", host)
        with paramiko.SSHClient() as ssh_client:
            ssh_client.load_system_host_keys()
            ssh_client.set_missing_host_key_policy(CustomHostKeyPolicy())
            ssh_client.connect(host, port=port)
            logger.info("Can connect to host %s", host)
            return True
    except Exception as exc:  # pylint: disable=W0703
        logger.info("Cannot connect to host %s", host)
        logger.debug("Connection failed with exception: %s", exc)
        return False
132
+
133
+
134
def _wait_for_workers(worker_hosts: List[str], port: int = DEFAULT_SSH_PORT, timeout: int = 300):
    """Block on the master node until every worker is ready and reachable over SSH.

    A worker counts as ready once its ready file (``READY_FILE % worker``)
    exists locally -- workers create it on the master via ``ssh ... touch`` --
    and the master can open an SSH session to it. Checks repeat every 5 seconds.

    Args:
        worker_hosts: Hostnames of the worker nodes. If empty, returns immediately.
        port: SSH port used for the connectivity check.
        timeout: Seconds to keep retrying before giving up.

    Raises:
        TimeoutError: If not all workers are ready within ``timeout`` seconds.
    """
    if not worker_hosts:
        logger.info("No worker nodes to connect to.")
        return

    start_time = time.time()
    while True:
        logger.info("Master is attempting to connect to all workers...")
        # Check the cheap local ready file first so no SSH session is opened to
        # a worker that has not announced itself yet.
        all_workers_connected = all(
            os.path.exists(READY_FILE % worker) and _can_connect(worker, port)
            for worker in worker_hosts
        )

        if all_workers_connected:
            logger.info("Master can connect to all worker nodes.")
            break
        if time.time() - start_time > timeout:
            raise TimeoutError("Timed out waiting for workers to be reachable.")

        time.sleep(5)  # Wait for 5 seconds before trying again
155
+
156
+
157
def _wait_for_master(master_host: str, port: int = DEFAULT_SSH_PORT, timeout: int = 300):
    """Block on a worker node until the master node accepts SSH connections.

    Retries every 5 seconds.

    Args:
        master_host: Hostname of the master node.
        port: SSH port used for the connectivity check.
        timeout: Seconds to keep retrying before giving up.

    Raises:
        TimeoutError: If the master is still unreachable after ``timeout`` seconds.
    """
    deadline = time.time() + timeout
    while True:
        logger.info(f"Worker is attempting to connect to the master node {master_host}...")
        if _can_connect(master_host, port):
            logger.info(f"Worker can connect to master node {master_host}.")
            return
        if time.time() > deadline:
            raise TimeoutError(f"Timed out waiting for master {master_host} to be reachable.")
        time.sleep(5)
169
+
170
+
171
def bootstrap_worker_node(master_host: str, status_file: str = FINISHED_STATUS_FILE):
    """Bootstrap a worker node.

    Waits until the master is reachable over SSH, announces this worker by
    creating its ready file (``READY_FILE % SM_CURRENT_HOST``) on the master,
    then blocks until ``status_file`` exists locally.

    Args:
        master_host: Hostname of the master node.
        status_file: Local file to wait for before returning.

    Raises:
        KeyError: If the ``SM_CURRENT_HOST`` environment variable is not set.
        TimeoutError: If the master stays unreachable, or the ready file cannot
            be written after retries.
    """
    logger.info("Bootstrapping worker node...")
    _wait_for_master(master_host)
    # Use the retrying writer: if this write failed silently, the master would
    # wait until its own timeout while this worker blocked forever below.
    write_status_file_to_workers([master_host], READY_FILE % os.environ["SM_CURRENT_HOST"])
    _wait_for_status_file(status_file)
177
+
178
+
179
def bootstrap_master_node(worker_hosts: List[str]):
    """Bootstrap the master node by blocking until all workers are ready.

    Args:
        worker_hosts: Hostnames of the worker nodes to wait for.
    """
    logger.info("Bootstrapping master node...")
    _wait_for_workers(worker_hosts=worker_hosts)
183
+
184
+
185
def validate_smddprun() -> bool:
    """Whether smddprun is installed.

    Returns:
        bool: True if an ``smddprun`` executable is found on ``PATH``.
    """
    # shutil.which searches PATH in-process; shelling out to the external
    # ``which`` binary raised FileNotFoundError when that binary was missing.
    import shutil  # pylint: disable=C0415

    return shutil.which("smddprun") is not None
201
+
202
+
203
def validate_smddpmprun() -> bool:
    """Whether smddpmprun is installed.

    Returns:
        bool: True if an ``smddpmprun`` executable is found on ``PATH``.
    """
    # shutil.which searches PATH in-process; shelling out to the external
    # ``which`` binary raised FileNotFoundError when that binary was missing.
    import shutil  # pylint: disable=C0415

    return shutil.which("smddpmprun") is not None
219
+
220
+
221
def write_env_vars_to_file():
    """Append every variable of the current process environment to ``/etc/environment``.

    Each variable is written as one ``NAME=value`` line.
    """
    # NOTE(review): values are written verbatim; a value containing a newline
    # would produce a malformed entry.
    env_lines = [f"{name}={value}\n" for name, value in os.environ.items()]
    with open("/etc/environment", "a", encoding="utf-8") as env_file:
        env_file.writelines(env_lines)
226
+
227
+
228
def get_mpirun_command(
    host_count: int,
    host_list: List[str],
    num_processes: int,
    additional_options: List[str],
    entry_script_path: str,
):
    """Build the ``mpirun`` argv that runs ``entry_script_path`` via ``python -m mpi4py``.

    Args:
        host_count: Number of hosts; used for ``plm_rsh_num_concurrent``.
        host_list: Host entries joined with commas for ``--host``.
        num_processes: Total process count passed to ``-np``.
        additional_options: Extra mpirun arguments inserted after the defaults.
        entry_script_path: Script executed by each rank.

    Returns:
        The full command as a list of strings.

    Raises:
        KeyError: If ``SM_CURRENT_INSTANCE_TYPE`` is not set.
    """
    iface = os.environ.get("SM_NETWORK_INTERFACE_NAME", "eth0")

    # Each pair becomes "-mca <key> <value>", in this order.
    mca_settings = [
        ("btl_tcp_if_include", iface),
        ("oob_tcp_if_include", iface),
        ("plm_rsh_no_tree_spawn", "1"),
        ("pml", "ob1"),
        ("btl", "^openib"),
        ("orte_abort_on_non_zero_status", "1"),
        ("btl_vader_single_copy_mechanism", "none"),
        ("plm_rsh_num_concurrent", str(host_count)),
    ]

    command = [
        "mpirun",
        "--host",
        ",".join(host_list),
        "-np",
        str(num_processes),
        "--allow-run-as-root",
        "--tag-output",
    ]
    for mca_key, mca_value in mca_settings:
        command += ["-mca", mca_key, mca_value]
    for exported in (f"NCCL_SOCKET_IFNAME={iface}", "LD_LIBRARY_PATH", "PATH"):
        command += ["-x", exported]

    if additional_options:
        command.extend(additional_options)

    instance_type = os.environ["SM_CURRENT_INSTANCE_TYPE"]
    # EFA settings; the simple protocol handles EFA's out-of-order data delivery.
    if instance_type in SM_EFA_NCCL_INSTANCES:
        command += ["-x", "FI_PROVIDER=efa", "-x", "NCCL_PROTO=simple"]
    # Use EFA's RDMA functionality for one-sided and two-sided transfer.
    if instance_type in SM_EFA_RDMA_INSTANCES:
        command += ["-x", "FI_EFA_USE_DEVICE_RDMA=1"]

    # Forward whichever AWS credential variables are set (by name only).
    aws_credentials = ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN")
    command += [arg for cred in aws_credentials if cred in os.environ for arg in ("-x", cred)]

    command.append(get_python_executable())
    command += ["-m", "mpi4py", entry_script_path]
    return command
@@ -0,0 +1,129 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module is the entry point for the Torchrun driver script."""
14
+ from __future__ import absolute_import
15
+
16
+ import os
17
+ import sys
18
+ import json
19
+
20
+ from pathlib import Path
21
+ from typing import List, Tuple
22
+
23
+ sys.path.insert(0, str(Path(__file__).parent.parent))
24
+
25
+ from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611
26
+ logger,
27
+ hyperparameters_to_cli_args,
28
+ get_process_count,
29
+ get_python_executable,
30
+ execute_commands,
31
+ write_failure_file,
32
+ SM_EFA_NCCL_INSTANCES,
33
+ SM_EFA_RDMA_INSTANCES,
34
+ )
35
+
36
+
37
def pytorch_version() -> Tuple[int, int]:
    """Return the installed PyTorch ``(major, minor)`` version as integers.

    ``torch`` is imported inside the function rather than at module level,
    so importing this module alone does not import PyTorch.

    Only the first two dot-separated fields of ``torch.__version__`` are
    converted with ``int``; e.g. ``"2.1.0+cu121"`` yields ``(2, 1)``. A field
    that is not a plain integer makes ``int`` raise ``ValueError``.
    """
    import torch

    leading_fields = torch.__version__.split(".")[:2]
    return tuple(int(field) for field in leading_fields)
42
+
43
+
44
def get_base_pytorch_command() -> List[str]:
    """Get the base Torch Distributed launcher to execute.

    The launcher is chosen from the installed PyTorch ``(major, minor)``
    version:

    * ``>= 1.10``: the ``torchrun`` console script.
    * ``1.9``: ``python -m torch.distributed.run``. This is the elastic
      launcher module that ``torchrun`` wraps. The ``torchrun`` script itself
      was only added in PyTorch 1.10, so invoking it on 1.9 fails with
      "command not found".
    * older: the legacy ``python -m torch.distributed.launch``.

    Returns:
        List[str]: the launcher argv prefix. The caller appends launcher
        options, the entry script and its arguments.
    """
    version = pytorch_version()
    if version >= (1, 10):
        return ["torchrun"]
    if version >= (1, 9):
        # Same launcher/arguments as torchrun, invoked as a module.
        return [f"{get_python_executable()}", "-m", "torch.distributed.run"]
    return [f"{get_python_executable()}", "-m", "torch.distributed.launch"]
49
+
50
+
51
def setup_env():
    """Setup the environment variables for PyTorch distributed training.

    Reads ``SM_CURRENT_INSTANCE_TYPE`` (required; ``KeyError`` if unset) and
    ``SM_NETWORK_INTERFACE_NAME`` (defaults to ``"eth0"``). It then writes
    EFA / NCCL settings into this process's ``os.environ``, presumably so
    that the launcher started afterwards by ``main`` inherits them.
    """
    instance_type = os.environ["SM_CURRENT_INSTANCE_TYPE"]
    network_interface_name = os.environ.get("SM_NETWORK_INTERFACE_NAME", "eth0")
    if instance_type in SM_EFA_NCCL_INSTANCES:
        # Enable EFA use
        os.environ["FI_PROVIDER"] = "efa"
    if instance_type in SM_EFA_RDMA_INSTANCES:
        # Use EFA's RDMA functionality for one-sided and two-sided transfer
        os.environ["FI_EFA_USE_DEVICE_RDMA"] = "1"
        os.environ["RDMAV_FORK_SAFE"] = "1"
    # Set for every instance type: NCCL socket interface = the name read above.
    os.environ["NCCL_SOCKET_IFNAME"] = str(network_interface_name)
    # NOTE(review): NCCL_PROTO=simple is set here regardless of instance type.
    # The mpirun command builder in the MPI driver code appears to export it
    # only for SM_EFA_NCCL_INSTANCES ("to handle the out-of-order data delivery
    # from EFA"). Confirm that forcing it on non-EFA instances is intended.
    os.environ["NCCL_PROTO"] = "simple"
64
+
65
+
66
def create_commands():
    """Create the Torch Distributed command to execute.

    Inputs are read from the environment:

    * ``SM_ENTRY_SCRIPT``: the user script handed to the launcher.
    * ``SM_DISTRIBUTED_CONFIG``: a JSON object. Its ``process_count_per_node``
      entry is read; a missing or null value is passed on as ``0``.
    * ``SM_HPS``: JSON hyperparameters, converted to CLI arguments for the
      user script.
    * ``SM_HOST_COUNT``: number of nodes.
    * ``SM_MASTER_ADDR``, ``SM_MASTER_PORT``, ``SM_CURRENT_HOST_RANK``: only
      read when more than one node is used.
    * ``RUN_NEURON_PARALLEL_COMPILE``: when ``"1"``, the command is prefixed
      with ``neuron_parallel_compile``.

    Returns:
        List[str]: ``[prefix...] launcher --nnodes=N --nproc_per_node=P
        [rendezvous args] entry_script [hyperparameter args]``.

    Raises:
        KeyError: if a required ``SM_*`` variable is not set.
    """
    entry_script = os.environ["SM_ENTRY_SCRIPT"]
    distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"])
    hyperparameters = json.loads(os.environ["SM_HPS"])

    # A missing or null count becomes 0 instead of raising KeyError.
    # NOTE(review): presumably get_process_count treats 0 as "use the
    # default" — confirm.
    process_count = int(distributed_config.get("process_count_per_node") or 0)
    process_count = get_process_count(process_count)
    host_count = int(os.environ["SM_HOST_COUNT"])

    torch_cmd = []
    if os.environ.get("RUN_NEURON_PARALLEL_COMPILE") == "1":
        torch_cmd.append("neuron_parallel_compile")

    torch_cmd.extend(get_base_pytorch_command())
    torch_cmd.extend(
        [
            f"--nnodes={host_count}",
            f"--nproc_per_node={process_count}",
        ]
    )

    # If more than one node is used, add node rank information
    if host_count > 1:
        torch_cmd.extend(
            [
                f"--master_addr={os.environ['SM_MASTER_ADDR']}",
                f"--master_port={os.environ['SM_MASTER_PORT']}",
                f"--node_rank={os.environ['SM_CURRENT_HOST_RANK']}",
            ]
        )

    torch_cmd.append(entry_script)
    torch_cmd.extend(hyperparameters_to_cli_args(hyperparameters))

    return torch_cmd
104
+
105
+
106
def main():
    """Run the PyTorch distributed training job and exit with its status.

    Steps, in order:
    1. Export the distributed-training environment variables (``setup_env``).
    2. Build the Torch Distributed launcher command (``create_commands``).
    3. Log and execute that command. The command launches the user script
       named by ``SM_ENTRY_SCRIPT``.
    4. If the exit code is non-zero, record the returned failure trace with
       ``write_failure_file``. In every case, finish via ``sys.exit`` with
       the command's exit code.
    """
    setup_env()
    launch_cmd = create_commands()
    logger.info(f"Executing command: {' '.join(launch_cmd)}")

    return_code, failure_trace = execute_commands(launch_cmd)
    failed = return_code != 0
    if failed:
        write_failure_file(failure_trace)
    sys.exit(return_code)
126
+
127
+
128
# Script entry point: run the driver only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()
@@ -0,0 +1,14 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Sagemaker modules container drivers - scripts directory."""
14
+ from __future__ import absolute_import