sagemaker-core 1.0.62__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (362) hide show
  1. sagemaker/core/__init__.py +16 -0
  2. sagemaker/core/_studio.py +116 -0
  3. sagemaker/core/_version.py +11 -0
  4. sagemaker/core/accept_types.py +131 -0
  5. sagemaker/core/analytics.py +744 -0
  6. sagemaker/core/apiutils/__init__.py +13 -0
  7. sagemaker/core/apiutils/_base_types.py +228 -0
  8. sagemaker/core/apiutils/_boto_functions.py +130 -0
  9. sagemaker/core/apiutils/_utils.py +34 -0
  10. sagemaker/core/base_deserializers.py +35 -0
  11. sagemaker/core/base_serializers.py +35 -0
  12. sagemaker/core/clarify/__init__.py +2898 -0
  13. sagemaker/core/collection.py +467 -0
  14. sagemaker/core/common_utils.py +2281 -0
  15. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  16. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  17. sagemaker/core/config/__init__.py +181 -0
  18. sagemaker/core/config/config.py +238 -0
  19. sagemaker/core/config/config_manager.py +595 -0
  20. sagemaker/core/config/config_schema.py +1220 -0
  21. sagemaker/core/config/config_utils.py +297 -0
  22. {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
  23. sagemaker/core/constants.py +73 -0
  24. sagemaker/core/content_types.py +137 -0
  25. sagemaker/core/debugger/__init__.py +39 -0
  26. sagemaker/core/debugger/debugger.py +945 -0
  27. sagemaker/core/debugger/framework_profile.py +292 -0
  28. sagemaker/core/debugger/metrics_config.py +468 -0
  29. sagemaker/core/debugger/profiler.py +42 -0
  30. sagemaker/core/debugger/profiler_config.py +190 -0
  31. sagemaker/core/debugger/profiler_constants.py +40 -0
  32. sagemaker/core/debugger/utils.py +148 -0
  33. sagemaker/core/deprecations.py +254 -0
  34. sagemaker/core/deserializers/__init__.py +10 -0
  35. sagemaker/core/deserializers/base.py +424 -0
  36. sagemaker/core/deserializers/implementations.py +157 -0
  37. sagemaker/core/drift_check_baselines.py +106 -0
  38. sagemaker/core/enums.py +51 -0
  39. sagemaker/core/environment_variables.py +101 -0
  40. sagemaker/core/exceptions.py +108 -0
  41. sagemaker/core/experiments/__init__.py +53 -0
  42. sagemaker/core/experiments/_api_types.py +251 -0
  43. sagemaker/core/experiments/_environment.py +124 -0
  44. sagemaker/core/experiments/_helper.py +294 -0
  45. sagemaker/core/experiments/_metrics.py +333 -0
  46. sagemaker/core/experiments/_run_context.py +58 -0
  47. sagemaker/core/experiments/_utils.py +216 -0
  48. sagemaker/core/experiments/experiment.py +244 -0
  49. sagemaker/core/experiments/run.py +970 -0
  50. sagemaker/core/experiments/trial.py +296 -0
  51. sagemaker/core/experiments/trial_component.py +387 -0
  52. sagemaker/core/explainer/__init__.py +24 -0
  53. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  54. sagemaker/core/explainer/explainer_config.py +44 -0
  55. sagemaker/core/fw_utils.py +1176 -0
  56. sagemaker/core/git_utils.py +349 -0
  57. sagemaker/core/helper/pipeline_variable.py +82 -0
  58. sagemaker/core/helper/session_helper.py +2965 -0
  59. sagemaker/core/huggingface/__init__.py +29 -0
  60. sagemaker/core/huggingface/llm_utils.py +150 -0
  61. sagemaker/core/huggingface/processing.py +139 -0
  62. sagemaker/core/huggingface/training_compiler/config.py +167 -0
  63. sagemaker/core/hyperparameters.py +172 -0
  64. sagemaker/core/image_retriever/__init__.py +3 -0
  65. sagemaker/core/image_retriever/image_retriever.py +640 -0
  66. sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
  67. sagemaker/core/image_retriever/test.py +7 -0
  68. sagemaker/core/image_uri_config/__init__.py +13 -0
  69. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  70. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  71. sagemaker/core/image_uri_config/chainer.json +104 -0
  72. sagemaker/core/image_uri_config/clarify.json +39 -0
  73. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  74. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  75. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  76. sagemaker/core/image_uri_config/debugger.json +34 -0
  77. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  78. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  79. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  80. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  81. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  82. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  83. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  84. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  85. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
  86. sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
  87. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  88. sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
  89. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  90. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  91. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  92. sagemaker/core/image_uri_config/huggingface.json +2138 -0
  93. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  94. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  95. sagemaker/core/image_uri_config/image-classification.json +50 -0
  96. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  97. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  98. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  99. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  100. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  101. sagemaker/core/image_uri_config/kmeans.json +50 -0
  102. sagemaker/core/image_uri_config/knn.json +50 -0
  103. sagemaker/core/image_uri_config/lda.json +26 -0
  104. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  105. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  106. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  107. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  108. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  109. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  110. sagemaker/core/image_uri_config/ntm.json +50 -0
  111. sagemaker/core/image_uri_config/object-detection.json +50 -0
  112. sagemaker/core/image_uri_config/object2vec.json +50 -0
  113. sagemaker/core/image_uri_config/pca.json +50 -0
  114. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  115. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  116. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  117. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  118. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  119. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  120. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  121. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  122. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  123. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  124. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
  125. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  126. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  127. sagemaker/core/image_uri_config/sklearn.json +446 -0
  128. sagemaker/core/image_uri_config/spark.json +280 -0
  129. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  130. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  131. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  132. sagemaker/core/image_uri_config/vw.json +25 -0
  133. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  134. sagemaker/core/image_uri_config/xgboost.json +888 -0
  135. sagemaker/core/image_uris.py +810 -0
  136. sagemaker/core/inference_config.py +144 -0
  137. sagemaker/core/inference_recommender/__init__.py +18 -0
  138. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  139. sagemaker/core/inputs.py +366 -0
  140. sagemaker/core/instance_group.py +61 -0
  141. sagemaker/core/instance_types.py +164 -0
  142. sagemaker/core/instance_types_gpu_info.py +43 -0
  143. sagemaker/core/interactive_apps/__init__.py +41 -0
  144. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  145. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  146. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  147. sagemaker/core/iterators.py +186 -0
  148. sagemaker/core/job.py +380 -0
  149. sagemaker/core/jumpstart/__init__.py +156 -0
  150. sagemaker/core/jumpstart/accessors.py +390 -0
  151. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  152. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  153. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  154. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  155. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  156. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  157. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  158. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  159. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  160. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  161. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  162. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  163. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  164. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  165. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  166. sagemaker/core/jumpstart/cache.py +663 -0
  167. sagemaker/core/jumpstart/configs.py +50 -0
  168. sagemaker/core/jumpstart/constants.py +198 -0
  169. sagemaker/core/jumpstart/deserializers.py +81 -0
  170. sagemaker/core/jumpstart/document.py +76 -0
  171. sagemaker/core/jumpstart/enums.py +168 -0
  172. sagemaker/core/jumpstart/exceptions.py +236 -0
  173. sagemaker/core/jumpstart/factory/utils.py +833 -0
  174. sagemaker/core/jumpstart/filters.py +597 -0
  175. sagemaker/core/jumpstart/hub/constants.py +16 -0
  176. sagemaker/core/jumpstart/hub/hub.py +291 -0
  177. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  178. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  179. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  180. sagemaker/core/jumpstart/hub/types.py +35 -0
  181. sagemaker/core/jumpstart/hub/utils.py +260 -0
  182. sagemaker/core/jumpstart/models.py +499 -0
  183. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  184. sagemaker/core/jumpstart/parameters.py +20 -0
  185. sagemaker/core/jumpstart/payload_utils.py +239 -0
  186. sagemaker/core/jumpstart/region_config.json +163 -0
  187. sagemaker/core/jumpstart/search.py +171 -0
  188. sagemaker/core/jumpstart/serializers.py +81 -0
  189. sagemaker/core/jumpstart/session_utils.py +234 -0
  190. sagemaker/core/jumpstart/types.py +3044 -0
  191. sagemaker/core/jumpstart/utils.py +1731 -0
  192. sagemaker/core/jumpstart/validators.py +257 -0
  193. sagemaker/core/lambda_helper.py +312 -0
  194. sagemaker/core/lineage/__init__.py +42 -0
  195. sagemaker/core/lineage/_api_types.py +239 -0
  196. sagemaker/core/lineage/_utils.py +49 -0
  197. sagemaker/core/lineage/action.py +345 -0
  198. sagemaker/core/lineage/artifact.py +646 -0
  199. sagemaker/core/lineage/association.py +190 -0
  200. sagemaker/core/lineage/context.py +505 -0
  201. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  202. sagemaker/core/lineage/query.py +732 -0
  203. sagemaker/core/lineage/visualizer.py +346 -0
  204. sagemaker/core/local/__init__.py +18 -0
  205. sagemaker/core/local/data.py +413 -0
  206. sagemaker/core/local/entities.py +678 -0
  207. sagemaker/core/local/exceptions.py +17 -0
  208. sagemaker/core/local/image.py +1243 -0
  209. sagemaker/core/local/local_session.py +739 -0
  210. sagemaker/core/local/utils.py +245 -0
  211. sagemaker/core/logs.py +181 -0
  212. sagemaker/core/metadata_properties.py +56 -0
  213. sagemaker/core/metric_definitions.py +91 -0
  214. sagemaker/core/mlflow/__init__.py +38 -0
  215. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  216. sagemaker/core/model_card/__init__.py +26 -0
  217. sagemaker/core/model_life_cycle.py +51 -0
  218. sagemaker/core/model_metrics.py +160 -0
  219. sagemaker/core/model_monitor/__init__.py +66 -0
  220. sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
  221. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  222. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  223. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  224. sagemaker/core/model_monitor/dataset_format.py +102 -0
  225. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  226. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  227. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  228. sagemaker/core/model_monitor/utils.py +793 -0
  229. sagemaker/core/model_registry.py +480 -0
  230. sagemaker/core/model_uris.py +97 -0
  231. sagemaker/core/modules/__init__.py +19 -0
  232. sagemaker/core/modules/configs.py +226 -0
  233. sagemaker/core/modules/constants.py +37 -0
  234. sagemaker/core/modules/distributed.py +182 -0
  235. sagemaker/core/modules/local_core/__init__.py +0 -0
  236. sagemaker/core/modules/local_core/local_container.py +605 -0
  237. sagemaker/core/modules/templates.py +83 -0
  238. sagemaker/core/modules/train/__init__.py +14 -0
  239. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  240. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  241. sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
  242. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  243. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  244. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  245. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  246. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  247. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  248. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  249. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  250. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  251. sagemaker/core/modules/types.py +19 -0
  252. sagemaker/core/modules/utils.py +194 -0
  253. sagemaker/core/network.py +185 -0
  254. sagemaker/core/parameter.py +173 -0
  255. sagemaker/core/payloads.py +185 -0
  256. sagemaker/core/processing.py +1597 -0
  257. sagemaker/core/remote_function/__init__.py +19 -0
  258. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  259. sagemaker/core/remote_function/client.py +1285 -0
  260. sagemaker/core/remote_function/core/__init__.py +0 -0
  261. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  262. sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
  263. sagemaker/core/remote_function/core/serialization.py +422 -0
  264. sagemaker/core/remote_function/core/stored_function.py +226 -0
  265. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  266. sagemaker/core/remote_function/errors.py +104 -0
  267. sagemaker/core/remote_function/invoke_function.py +172 -0
  268. sagemaker/core/remote_function/job.py +2140 -0
  269. sagemaker/core/remote_function/logging_config.py +38 -0
  270. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  271. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  272. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  273. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  274. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  275. sagemaker/core/remote_function/spark_config.py +149 -0
  276. sagemaker/core/resource_requirements.py +168 -0
  277. {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
  278. sagemaker/core/s3/__init__.py +41 -0
  279. sagemaker/core/s3/client.py +367 -0
  280. sagemaker/core/s3/utils.py +175 -0
  281. sagemaker/core/script_uris.py +93 -0
  282. sagemaker/core/serializers/__init__.py +11 -0
  283. sagemaker/core/serializers/base.py +510 -0
  284. sagemaker/core/serializers/implementations.py +159 -0
  285. sagemaker/core/serializers/utils.py +223 -0
  286. sagemaker/core/serverless_inference_config.py +63 -0
  287. sagemaker/core/session_settings.py +55 -0
  288. sagemaker/core/shapes/__init__.py +3 -0
  289. sagemaker/core/shapes/model_card_shapes.py +159 -0
  290. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
  291. sagemaker/core/spark/__init__.py +16 -0
  292. sagemaker/core/spark/defaults.py +16 -0
  293. sagemaker/core/spark/processing.py +1380 -0
  294. sagemaker/core/telemetry/__init__.py +23 -0
  295. sagemaker/core/telemetry/constants.py +84 -0
  296. sagemaker/core/telemetry/telemetry_logging.py +284 -0
  297. sagemaker/core/tools/__init__.py +1 -0
  298. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  299. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  300. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  301. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  302. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  303. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  304. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  305. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  306. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  307. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  308. sagemaker/core/training/__init__.py +14 -0
  309. sagemaker/core/training/configs.py +333 -0
  310. sagemaker/core/training/constants.py +37 -0
  311. sagemaker/core/training/utils.py +77 -0
  312. sagemaker/core/training_compiler/__init__.py +16 -0
  313. sagemaker/core/training_compiler/config.py +197 -0
  314. sagemaker/core/training_compiler_config.py +197 -0
  315. sagemaker/core/transformer.py +793 -0
  316. sagemaker/core/user_agent.py +76 -0
  317. sagemaker/core/utilities/__init__.py +24 -0
  318. sagemaker/core/utilities/cache.py +169 -0
  319. sagemaker/core/utilities/search_expression.py +133 -0
  320. sagemaker/core/utils/__init__.py +48 -0
  321. sagemaker/core/utils/code_injection/__init__.py +0 -0
  322. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  323. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
  324. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  325. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  326. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  327. {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
  328. sagemaker/core/workflow/__init__.py +152 -0
  329. sagemaker/core/workflow/conditions.py +313 -0
  330. sagemaker/core/workflow/entities.py +58 -0
  331. sagemaker/core/workflow/execution_variables.py +89 -0
  332. sagemaker/core/workflow/functions.py +193 -0
  333. sagemaker/core/workflow/parameters.py +222 -0
  334. sagemaker/core/workflow/pipeline_context.py +394 -0
  335. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  336. sagemaker/core/workflow/properties.py +285 -0
  337. sagemaker/core/workflow/step_outputs.py +65 -0
  338. sagemaker/core/workflow/utilities.py +507 -0
  339. sagemaker/lineage/__init__.py +33 -0
  340. sagemaker/lineage/action.py +28 -0
  341. sagemaker/lineage/artifact.py +28 -0
  342. sagemaker/lineage/context.py +28 -0
  343. sagemaker/lineage/lineage_trial_component.py +28 -0
  344. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
  345. sagemaker_core-2.1.1.dist-info/RECORD +355 -0
  346. sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
  347. sagemaker_core/_version.py +0 -3
  348. sagemaker_core/helper/session_helper.py +0 -769
  349. sagemaker_core/resources/__init__.py +0 -1
  350. sagemaker_core/shapes/__init__.py +0 -1
  351. sagemaker_core/tools/__init__.py +0 -1
  352. sagemaker_core-1.0.62.dist-info/RECORD +0 -35
  353. sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
  354. {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
  355. {sagemaker_core/helper → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
  356. {sagemaker_core/main → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  357. {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
  358. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  359. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  360. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  361. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
  362. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1285 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """SageMaker remote function client."""
14
+ from __future__ import absolute_import
15
+
16
+ from concurrent.futures import ThreadPoolExecutor
17
+ from collections import deque
18
+ import time
19
+ import threading
20
+ from typing import Callable, Dict, List, Optional, Tuple, Any, Union
21
+ import functools
22
+ import itertools
23
+ import inspect
24
+
25
+ from botocore.exceptions import ClientError
26
+ from sagemaker.core.exceptions import UnexpectedStatusException
27
+ from sagemaker.core.experiments._run_context import _RunContext
28
+
29
+ import sagemaker.core.remote_function.core.serialization as serialization
30
+ from sagemaker.core.remote_function.errors import (
31
+ RemoteFunctionError,
32
+ ServiceError,
33
+ DeserializationError,
34
+ )
35
+ from sagemaker.core.remote_function.core.stored_function import RESULTS_FOLDER, EXCEPTION_FOLDER
36
+ from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import (
37
+ RuntimeEnvironmentError,
38
+ )
39
+
40
+ from sagemaker.core.helper.session_helper import Session
41
+ from sagemaker.core.s3 import s3_path_join
42
+ from sagemaker.core.remote_function.job import _JobSettings, _Job, _RunInfo
43
+ from sagemaker.core.remote_function import logging_config
44
+ from sagemaker.core.common_utils import name_from_base, base_from_name
45
+ from sagemaker.core.remote_function.spark_config import SparkConfig
46
+ from sagemaker.core.remote_function.custom_file_filter import CustomFileFilter
47
+ from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
48
+ from sagemaker.core.telemetry.constants import Feature
49
+
50
+ _API_CALL_LIMIT = {
51
+ "SubmittingIntervalInSecs": 1,
52
+ "MinBatchPollingIntervalInSecs": 10,
53
+ "PollingIntervalInSecs": 0.5,
54
+ }
55
+
56
+ # Possible future states.
57
+ _PENDING = "PENDING"
58
+ _RUNNING = "RUNNING"
59
+ # The future was cancelled by the user...
60
+ _CANCELLED = "CANCELLED"
61
+ _FINISHED = "FINISHED"
62
+
63
+ logger = logging_config.get_logger()
64
+
65
+
66
+ @_telemetry_emitter(feature=Feature.REMOTE_FUNCTION, func_name="remote_function.remote")
67
+ def remote(
68
+ _func=None,
69
+ *,
70
+ dependencies: str = None,
71
+ pre_execution_commands: List[str] = None,
72
+ pre_execution_script: str = None,
73
+ environment_variables: Dict[str, str] = None,
74
+ image_uri: str = None,
75
+ include_local_workdir: bool = None,
76
+ custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None,
77
+ instance_count: int = 1,
78
+ instance_type: str = None,
79
+ job_conda_env: str = None,
80
+ job_name_prefix: str = None,
81
+ keep_alive_period_in_seconds: int = 0,
82
+ max_retry_attempts: int = 1,
83
+ max_runtime_in_seconds: int = 24 * 60 * 60,
84
+ role: str = None,
85
+ s3_kms_key: str = None,
86
+ s3_root_uri: str = None,
87
+ sagemaker_session: Session = None,
88
+ security_group_ids: List[str] = None,
89
+ subnets: List[str] = None,
90
+ tags: List[Tuple[str, str]] = None,
91
+ volume_kms_key: str = None,
92
+ volume_size: int = 30,
93
+ encrypt_inter_container_traffic: bool = None,
94
+ spark_config: SparkConfig = None,
95
+ use_spot_instances=False,
96
+ max_wait_time_in_seconds=None,
97
+ disable_output_compression: bool = False,
98
+ use_torchrun: bool = False,
99
+ use_mpirun: bool = False,
100
+ nproc_per_node: Optional[int] = None,
101
+ ):
102
+ """Decorator for running the annotated function as a SageMaker training job.
103
+
104
+ This decorator wraps the annotated code and runs it as a new SageMaker job synchronously
105
+ with the provided runtime settings.
106
+
107
+ If a parameter value is not set, the decorator first looks up the value from the SageMaker
108
+ configuration file. If no value is specified in the configuration file or no configuration file
109
+ is found, the decorator selects the default as specified below. For more information, see
110
+ `Configuring and using defaults with the SageMaker Python SDK <https://sagemaker.readthedocs.io/
111
+ en/stable/overview.html#configuring-and-using-defaults-with-the-sagemaker-python-sdk>`_.
112
+
113
+ Args:
114
+ _func (Optional): A Python function to run as a SageMaker training job.
115
+
116
+ dependencies (str): Either the path to a dependencies file or the reserved keyword
117
+ ``auto_capture``. Defaults to ``None``.
118
+ If ``dependencies`` is provided, the value must be one of the following:
119
+
120
+ * A path to a conda environment.yml file. The following conditions apply.
121
+
122
+ * If job_conda_env is set, then the conda environment is updated by installing
123
+ dependencies from the yaml file and the function is invoked within that
124
+ conda environment. For this to succeed, the specified conda environment must
125
+ already exist in the image.
126
+ * If the environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image, then the
127
+ conda environment is updated by installing dependencies from the yaml file and the
128
+ function is invoked within that conda environment. For this to succeed, the
129
+ conda environment name must already be set in ``SAGEMAKER_JOB_CONDA_ENV``, and
130
+ ``SAGEMAKER_JOB_CONDA_ENV`` must already exist in the image.
131
+ * If none of the previous conditions are met, a new conda environment named
132
+ ``sagemaker-runtime-env`` is created and the function annotated with the remote
133
+ decorator is invoked in that conda environment.
134
+
135
+ * A path to a requirements.txt file. The following conditions apply.
136
+
137
+ * If ``job_conda_env`` is set in the remote decorator, dependencies are installed
138
+ within that conda environment and the function annotated with the remote decorator
139
+ is invoked in the same conda environment. For this to succeed, the specified
140
+ conda environment must already exist in the image.
141
+ * If an environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image,
142
+ dependencies are installed within that conda environment and the function annotated
143
+ with the remote decorator is invoked in the same. For this to succeed, the conda
144
+ environment name must already be set in ``SAGEMAKER_JOB_CONDA_ENV``, and
145
+ ``SAGEMAKER_JOB_CONDA_ENV`` must already exist in the image.
146
+ * If none of the above conditions are met, conda is not used. Dependencies are
147
+ installed at the system level, without any virtual environment, and the function
148
+ annotated with the remote decorator is invoked using the Python runtime available
149
+ in the system path.
150
+
151
+ * The parameter dependencies is set to ``auto_capture``. SageMaker will automatically
152
+ generate an env_snapshot.yml corresponding to the current active conda environment’s
153
+ snapshot. You do not need to provide a dependencies file. The following conditions
154
+ apply:
155
+
156
+ * You must run the remote function within an active conda environment.
157
+ * When installing the dependencies on the training job, the same conditions as when
158
+ dependencies is set to a path to a conda environment file apply. These conditions are
159
+ as follows:
160
+
161
+ * If job_conda_env is set, then the conda environment is updated by installing
162
+ dependencies from the yaml file and the function is invoked within that
163
+ conda environment. For this to succeed, the specified conda environment must
164
+ already exist in the image.
165
+ * If the environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image, then
166
+ the conda environment is updated by installing dependencies from the yaml file
167
+ and the function is invoked within that conda environment. For this to
168
+ succeed, the conda environment name must already be set in
169
+ ``SAGEMAKER_JOB_CONDA_ENV``, and ``SAGEMAKER_JOB_CONDA_ENV`` must already exist
170
+ in the image.
171
+ * If none of the previous conditions are met, a new conda environment with name
172
+ ``sagemaker-runtime-env`` is created and the function annotated with the
173
+ remote decorator is invoked in that conda environment.
174
+
175
+ * ``None``. SageMaker will assume that there are no dependencies to install while
176
+ executing the remote annotated function in the training job.
177
+
178
+ pre_execution_commands (List[str]): List of commands to be executed prior to executing
179
+ remote function. Only one of ``pre_execution_commands`` or ``pre_execution_script``
180
+ can be specified at the same time. Defaults to None.
181
+
182
+ pre_execution_script (str): Path to script file to be executed prior to executing
183
+ remote function. Only one of ``pre_execution_commands`` or ``pre_execution_script``
184
+ can be specified at the same time. Defaults to None.
185
+
186
+ environment_variables (Dict): The environment variables used inside the decorator function.
187
+ Defaults to ``None``.
188
+
189
+ image_uri (str): The universal resource identifier (URI) location of a Docker image on
190
+ Amazon Elastic Container Registry (ECR). Defaults to the following based on where the SDK
191
+ is running:
192
+
193
+ * For users who specify ``spark_config`` and want to run the function in a Spark
194
+ application, the ``image_uri`` should be ``None``. A SageMaker Spark image will
195
+ be used for training, otherwise a ``ValueError`` is thrown.
196
+ * For users on SageMaker Studio notebooks, the image used as the kernel image for the
197
+ notebook is used.
198
+ * For other users, it is resolved to a base Python image with the same Python version
199
+ as the environment running the local code.
200
+
201
+ If no compatible image is found, a ValueError is thrown.
202
+
203
+ include_local_workdir (bool): A flag to indicate that the remote function should include
204
+ local directories. Set to ``True`` if the remote function code imports local modules and
205
+ methods that are not available via PyPI or conda. Only python files are included.
206
+ Default value is ``False``.
207
+
208
+ custom_file_filter (Callable[[str, List], List], CustomFileFilter): Either a function
209
+ that filters job dependencies to be uploaded to S3 or a ``CustomFileFilter`` object
210
+ that specifies the local directories and files to be included in the remote function.
211
+ If a callable is passed in, the function should follow the protocol of ``ignore`` argument
212
+ of ``shutil.copytree``. Defaults to ``None``, which means only python
213
+ files are accepted and uploaded to S3.
214
+
215
+ instance_count (int): The number of instances to use. Defaults to 1.
216
+ NOTE: Remote function supports instance_count > 1 for Spark jobs, torchrun and
217
+ mpirun utilities
218
+
219
+ instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to run
220
+ the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown.
221
+
222
+ job_conda_env (str): The name of the conda environment to activate during job's runtime.
223
+ Defaults to ``None``.
224
+
225
+ job_name_prefix (str): The prefix used used to create the underlying SageMaker job.
226
+
227
+ keep_alive_period_in_seconds (int): The duration in seconds to retain and reuse provisioned
228
+ infrastructure after the completion of a training job, also known as SageMaker managed
229
+ warm pools. The use of warmpools reduces the latency time spent to provision new
230
+ resources. The default value for ``keep_alive_period_in_seconds`` is 0.
231
+ NOTE: Additional charges associated with warm pools may apply. Using this parameter also
232
+ activates a new persistent cache feature, which will further reduce job start up
233
+ latency than over using SageMaker managed warm pools alone by caching the package source
234
+ downloaded in the previous runs.
235
+
236
+ max_retry_attempts (int): The max number of times the job is retried on
237
+ ``InternalServerFailure`` Error from SageMaker service. Defaults to 1.
238
+
239
+ max_runtime_in_seconds (int): The upper limit in seconds to be used for training. After
240
+ this specified amount of time, SageMaker terminates the job regardless of its current
241
+ status. Defaults to 1 day or (86400 seconds).
242
+
243
+ role (str): The IAM role (either name or full ARN) used to run your SageMaker training
244
+ job. Defaults to:
245
+
246
+ * the SageMaker default IAM role if the SDK is running in SageMaker Notebooks or
247
+ SageMaker Studio Notebooks.
248
+ * if not above, a ValueError is be thrown.
249
+
250
+ s3_kms_key (str): The key used to encrypt the input and output data. Default to ``None``.
251
+
252
+ s3_root_uri (str): The root S3 folder to which the code archives and data are
253
+ uploaded to. Defaults to ``s3://<sagemaker-default-bucket>``.
254
+
255
+ sagemaker_session (sagemaker.core.helper.session.Session): The underlying SageMaker session to which
256
+ SageMaker service calls are delegated to (default: None). If not provided, one is created
257
+ using a default configuration chain.
258
+
259
+ security_group_ids (List[str): A list of security group IDs. Defaults to ``None`` and the
260
+ training job is created without VPC config.
261
+
262
+ subnets (List[str): A list of subnet IDs. Defaults to ``None`` and the job is created
263
+ without VPC config.
264
+
265
+ tags (List[Tuple[str, str]): A list of tags attached to the job. Defaults to ``None`` and
266
+ the training job is created without tags.
267
+
268
+ volume_kms_key (str): An Amazon Key Management Service (KMS) key used to encrypt an
269
+ Amazon Elastic Block Storage (EBS) volume attached to the training instance. Defaults to
270
+ ``None``.
271
+
272
+ volume_size (int): The size in GB of the storage volume for storing input and output data
273
+ during training. Defaults to ``30``.
274
+
275
+ encrypt_inter_container_traffic (bool): A flag that specifies whether traffic between
276
+ training containers is encrypted for the training job. Defaults to ``False``.
277
+
278
+ spark_config (SparkConfig): Configurations to the Spark application that runs on
279
+ Spark image. If ``spark_config`` is specified, a SageMaker Spark image uri
280
+ will be used for training. Note that ``image_uri`` can not be specified at the
281
+ same time otherwise a ``ValueError`` is thrown. Defaults to ``None``.
282
+
283
+ use_spot_instances (bool): Specifies whether to use SageMaker Managed Spot instances for
284
+ training. If enabled then the ``max_wait_time_in_seconds`` arg should also be set.
285
+ Defaults to ``False``.
286
+
287
+ max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
288
+ After this amount of time Amazon SageMaker will stop waiting for managed spot training
289
+ job to complete. Defaults to ``None``.
290
+
291
+ disable_output_compression (bool): Optional. When set to true, Model is uploaded to
292
+ Amazon S3 without compression after training finishes.
293
+
294
+ use_torchrun (bool): Specifies whether to use torchrun for distributed training.
295
+ Defaults to ``False``.
296
+
297
+ use_mpirun (bool): Specifies whether to use mpirun for distributed training.
298
+ Defaults to ``False``.
299
+
300
+ nproc_per_node (int): Optional. Specifies the number of processes per node for
301
+ distributed training. Defaults to ``None``.
302
+ This is defined automatically configured on the instance type.
303
+ """
304
+
305
+ def _remote(func):
306
+
307
+ job_settings = _JobSettings(
308
+ dependencies=dependencies,
309
+ pre_execution_commands=pre_execution_commands,
310
+ pre_execution_script=pre_execution_script,
311
+ environment_variables=environment_variables,
312
+ image_uri=image_uri,
313
+ include_local_workdir=include_local_workdir,
314
+ custom_file_filter=custom_file_filter,
315
+ instance_count=instance_count,
316
+ instance_type=instance_type,
317
+ job_conda_env=job_conda_env,
318
+ job_name_prefix=job_name_prefix,
319
+ keep_alive_period_in_seconds=keep_alive_period_in_seconds,
320
+ max_retry_attempts=max_retry_attempts,
321
+ max_runtime_in_seconds=max_runtime_in_seconds,
322
+ role=role,
323
+ s3_kms_key=s3_kms_key,
324
+ s3_root_uri=s3_root_uri,
325
+ sagemaker_session=sagemaker_session,
326
+ security_group_ids=security_group_ids,
327
+ subnets=subnets,
328
+ tags=tags,
329
+ volume_kms_key=volume_kms_key,
330
+ volume_size=volume_size,
331
+ encrypt_inter_container_traffic=encrypt_inter_container_traffic,
332
+ spark_config=spark_config,
333
+ use_spot_instances=use_spot_instances,
334
+ max_wait_time_in_seconds=max_wait_time_in_seconds,
335
+ disable_output_compression=disable_output_compression,
336
+ use_torchrun=use_torchrun,
337
+ use_mpirun=use_mpirun,
338
+ nproc_per_node=nproc_per_node,
339
+ )
340
+
341
+ @functools.wraps(func)
342
+ def wrapper(*args, **kwargs):
343
+
344
+ if instance_count > 1 and not (
345
+ (spark_config is not None and not use_torchrun and not use_mpirun)
346
+ or (spark_config is None and use_torchrun and not use_mpirun)
347
+ or (spark_config is None and not use_torchrun and use_mpirun)
348
+ ):
349
+ raise ValueError(
350
+ "Remote function do not support training on multi instances "
351
+ + "without spark_config or use_torchrun or use_mpirun. "
352
+ + "Please provide instance_count = 1"
353
+ )
354
+
355
+ RemoteExecutor._validate_submit_args(func, *args, **kwargs)
356
+
357
+ job = _Job.start(job_settings, func, args, kwargs)
358
+
359
+ try:
360
+ job.wait()
361
+ except UnexpectedStatusException as usex:
362
+ if usex.actual_status == "Failed":
363
+ try:
364
+ exception = serialization.deserialize_exception_from_s3(
365
+ sagemaker_session=job_settings.sagemaker_session,
366
+ s3_uri=s3_path_join(
367
+ job_settings.s3_root_uri, job.job_name, EXCEPTION_FOLDER
368
+ ),
369
+ hmac_key=job.hmac_key,
370
+ )
371
+ except ServiceError as serr:
372
+ chained_e = serr.__cause__
373
+ if (
374
+ isinstance(chained_e, ClientError)
375
+ and chained_e.response["Error"]["Code"] # pylint: disable=no-member
376
+ == "404"
377
+ and chained_e.response["Error"]["Message"] # pylint: disable=no-member
378
+ == "Not Found"
379
+ ):
380
+ describe_result = job.describe()
381
+ if (
382
+ "FailureReason" in describe_result
383
+ and describe_result["FailureReason"]
384
+ and "RuntimeEnvironmentError: " in describe_result["FailureReason"]
385
+ ):
386
+ failure_msg = describe_result["FailureReason"].replace(
387
+ "RuntimeEnvironmentError: ", ""
388
+ )
389
+ raise RuntimeEnvironmentError(failure_msg)
390
+ raise RemoteFunctionError(
391
+ "Failed to execute remote function. "
392
+ + "Check corresponding job for details."
393
+ )
394
+ raise serr
395
+
396
+ raise exception
397
+
398
+ raise TimeoutError(
399
+ "Job for remote function timed out before reaching a termination status."
400
+ )
401
+
402
+ if job.describe()["TrainingJobStatus"] == "Completed":
403
+ return serialization.deserialize_obj_from_s3(
404
+ sagemaker_session=job_settings.sagemaker_session,
405
+ s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, RESULTS_FOLDER),
406
+ hmac_key=job.hmac_key,
407
+ )
408
+
409
+ if job.describe()["TrainingJobStatus"] == "Stopped":
410
+ raise RemoteFunctionError("Job for remote function has been aborted.")
411
+
412
+ return None
413
+
414
+ wrapper.job_settings = job_settings
415
+ wrapper.wrapped_func = func
416
+ return wrapper
417
+
418
+ if _func is None:
419
+ return _remote
420
+ return _remote(_func)
421
+
422
+
423
class _SubmitRequest:
    """Bundle of everything needed to start one queued remote-function job.

    The attributes mirror the constructor arguments: the ``future`` tied to the
    request, the ``job_settings`` for the job, the target ``func`` together with
    its positional ``args`` and keyword ``kwargs``, and optional ``run_info``.
    """

    def __init__(
        self, future, job_settings: _JobSettings, func, func_args, func_kwargs, run_info=None
    ):
        # The callable and the arguments it will be invoked with.
        self.func, self.args, self.kwargs = func, func_args, func_kwargs
        # Job configuration and the handle/run context attached to this request.
        self.job_settings = job_settings
        self.future = future
        self.run_info = run_info
435
+
436
+
437
def _submit_worker(executor):
    """Background loop that starts queued requests as job slots become free.

    Waits on ``executor._state_condition`` until the pending queue is non-empty
    and fewer than ``executor.max_parallel_jobs`` jobs are running, then starts
    the request at the head of the queue. A ``None`` entry at the head is the
    shutdown anchor: it is removed and the loop ends. Any exception is logged
    and ends the worker.
    """

    def slot_available():
        if not executor._pending_request_queue:
            return False
        return len(executor._running_jobs) < executor.max_parallel_jobs

    try:
        lock = executor._state_condition
        while True:
            with lock:
                lock.wait_for(slot_available)
                # Peek only; the entry is popped once it has been handled.
                head = executor._pending_request_queue[0]

            if head is None:
                # Shutdown anchor reached: drop it and stop the worker.
                with lock:
                    executor._pending_request_queue.popleft()
                return

            # Pause for the configured submitting interval before starting the job.
            time.sleep(_API_CALL_LIMIT["SubmittingIntervalInSecs"])
            started_job = head.future._start_and_notify(
                head.job_settings, head.func, head.args, head.kwargs, head.run_info
            )

            with lock:
                if started_job:
                    executor._running_jobs[started_job.job_name] = started_job
                # The request leaves the queue only after the start attempt finished.
                executor._pending_request_queue.popleft()
    except Exception:  # pylint: disable=broad-except
        logger.exception("Error occurred while submitting CreateTrainingJob requests.")
471
+
472
+
473
def _polling_worker(executor):
    """Background loop that watches running jobs and frees their slots.

    Each pass sleeps for the batch interval (reduced by the per-job polling time,
    never below zero), then describes every running job. Jobs whose status is
    ``Completed``, ``Failed`` or ``Stopped`` are removed from
    ``executor._running_jobs`` and waiters on ``executor._state_condition`` are
    notified. The loop ends once the executor is shut down and no running or
    pending work remains. Any unexpected exception is logged and ends the worker.
    """
    terminal_statuses = ["Completed", "Failed", "Stopped"]

    def forget(name):
        # Drop the job and wake the submit worker so it can use the freed slot.
        with executor._state_condition:
            del executor._running_jobs[name]
            executor._state_condition.notify_all()

    try:
        while True:
            with executor._state_condition:
                if executor._shutdown:
                    if not executor._running_jobs and not executor._pending_request_queue:
                        return

            batch_pause = (
                _API_CALL_LIMIT["MinBatchPollingIntervalInSecs"]
                - len(executor._running_jobs) * _API_CALL_LIMIT["PollingIntervalInSecs"]
            )
            time.sleep(max(batch_pause, 0))

            # Iterate over a snapshot because entries are deleted while polling.
            for job_name in list(executor._running_jobs):
                try:
                    time.sleep(_API_CALL_LIMIT["PollingIntervalInSecs"])
                    status = executor._running_jobs[job_name].describe()["TrainingJobStatus"]
                    if status in terminal_statuses:
                        forget(job_name)
                except Exception as err:  # pylint: disable=broad-except
                    throttled = (
                        isinstance(err, ClientError)
                        and err.response["Error"]["Code"]  # pylint: disable=no-member
                        == "LimitExceededException"
                    )
                    if throttled:
                        # Keep the job; it is polled again on the next pass.
                        continue
                    # Couldn't check the job status, move on
                    logger.exception(
                        "Error occurred while checking the status of job %s", job_name
                    )
                    forget(job_name)
    except Exception:  # pylint: disable=broad-except
        logger.exception("Error occurred while monitoring the job statuses.")
519
+
520
+
521
+ class RemoteExecutor(object):
522
+ """Run Python functions asynchronously as SageMaker jobs"""
523
+
524
def __init__(
    self,
    *,
    dependencies: str = None,
    pre_execution_commands: List[str] = None,
    pre_execution_script: str = None,
    environment_variables: Dict[str, str] = None,
    image_uri: str = None,
    include_local_workdir: bool = None,
    custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None,
    instance_count: int = 1,
    instance_type: str = None,
    job_conda_env: str = None,
    job_name_prefix: str = None,
    keep_alive_period_in_seconds: int = 0,
    max_parallel_jobs: int = 1,
    max_retry_attempts: int = 1,
    max_runtime_in_seconds: int = 24 * 60 * 60,
    role: str = None,
    s3_kms_key: str = None,
    s3_root_uri: str = None,
    sagemaker_session: Session = None,
    security_group_ids: List[str] = None,
    subnets: List[str] = None,
    tags: List[Tuple[str, str]] = None,
    volume_kms_key: str = None,
    volume_size: int = 30,
    encrypt_inter_container_traffic: bool = None,
    spark_config: SparkConfig = None,
    use_spot_instances=False,
    max_wait_time_in_seconds=None,
    disable_output_compression: bool = False,
    use_torchrun: bool = False,
    use_mpirun: bool = False,
    nproc_per_node: Optional[int] = None,
):
    """Configure an executor that runs Python functions as SageMaker jobs.

    A parameter that is not set is first looked up in the SageMaker configuration
    file. If the file is missing or holds no value for it, the default described
    below is used. For more information, see `Configuring and using defaults with
    the SageMaker Python SDK
    <https://sagemaker.readthedocs.io/en/stable/overview.html
    #configuring-and-using-defaults-with-the-sagemaker-python-sdk>`_.

    Args:
        dependencies (str): Path to a dependencies file, or the reserved keyword
            ``auto_capture``. Defaults to ``None``. Accepted values:

            * A path to a conda ``environment.yml`` file:

              * If ``job_conda_env`` is set, that conda environment is updated by
                installing the dependencies from the yaml file and the function is
                invoked inside it. The environment must already exist in the image.
              * If the environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the
                image, the conda environment it names is updated from the yaml file
                and the function is invoked inside it. That environment must already
                exist in the image.
              * If neither applies, a new conda environment named
                ``sagemaker-runtime-env`` is created and the function is invoked in it.

            * A path to a ``requirements.txt`` file:

              * If ``job_conda_env`` is set, the dependencies are installed within that
                conda environment and the function is invoked in the same
                environment. The environment must already exist in the image.
              * If the environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the
                image, the dependencies are installed within the conda environment it
                names and the function is invoked there. That environment must
                already exist in the image.
              * If neither applies, conda is not used. Dependencies are installed at
                the system level, without any virtual environment, and the function is
                invoked with the Python runtime available on the system path.

            * ``auto_capture``: SageMaker automatically generates an
              ``env_snapshot.yml`` for the currently active conda environment, so no
              dependencies file is needed. The remote function must be run from
              within an active conda environment. On the training job, the snapshot
              is installed following the same three rules listed above for a conda
              ``environment.yml`` file.

            * ``None``: SageMaker assumes there are no dependencies to install.

        pre_execution_commands (List[str]): Commands to execute before the remote
            function runs. Only one of ``pre_execution_commands`` and
            ``pre_execution_script`` can be specified. Defaults to None.

        pre_execution_script (str): Path to a script file to execute before the
            remote function runs. Only one of ``pre_execution_commands`` and
            ``pre_execution_script`` can be specified. Defaults to None.

        environment_variables (Dict): The environment variables used inside the
            decorated function. Defaults to ``None``.

        image_uri (str): The universal resource identifier (URI) location of a Docker
            image on Amazon Elastic Container Registry (ECR). Defaults to the
            following, based on where the SDK is running:

            * For users who specify ``spark_config`` and want to run the function in a
              Spark application, ``image_uri`` should be ``None``. A SageMaker Spark
              image is used for training, otherwise a ``ValueError`` is thrown.
            * For users on SageMaker Studio notebooks, the image used as the kernel
              image for the notebook is used.
            * For other users, it is resolved to a base python image with the same
              python version as the environment running the local code.

            If no compatible image is found, a ValueError is thrown.

        include_local_workdir (bool): Whether the remote function should include
            local directories. Set to ``True`` if the remote function code imports
            local modules and methods that are not available via PyPI or conda.
            Default value is ``False``.

        custom_file_filter (Callable[[str, List], List], CustomFileFilter): Either a
            function that filters the job dependencies to be uploaded to S3, or a
            ``CustomFileFilter`` object that specifies the local directories and
            files to be included in the remote function. A callable is passed to the
            ``ignore`` argument of ``shutil.copytree``. Defaults to ``None``, which
            means only python files are accepted and uploaded to S3.

        instance_count (int): The number of instances to use. Defaults to 1.
            NOTE: ``instance_count > 1`` is supported only for Spark jobs and for the
            torchrun and mpirun utilities.

        instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to
            use to run the SageMaker job, e.g. ml.c4.xlarge. If not provided, a
            ValueError is thrown.

        job_conda_env (str): The name of the conda environment to activate during the
            job's runtime. Defaults to ``None``.

        job_name_prefix (str): The prefix used to create the underlying SageMaker job.

        keep_alive_period_in_seconds (int): The duration in seconds to retain and
            reuse provisioned infrastructure after the completion of a training job,
            also known as SageMaker managed warm pools. Warm pools reduce the latency
            spent provisioning new resources. Defaults to 0.
            NOTE: Additional charges associated with warm pools may apply. Using this
            parameter also activates a persistent cache feature, which further
            reduces job start-up latency compared with warm pools alone by caching
            the package source downloaded in previous runs.

        max_parallel_jobs (int): Maximum number of jobs that run in parallel.
            Defaults to 1.

        max_retry_attempts (int): The maximum number of times the job is retried on
            an ``InternalServerFailure`` error from the SageMaker service.
            Defaults to 1.

        max_runtime_in_seconds (int): The upper limit in seconds to be used for
            training. After this amount of time, SageMaker terminates the job
            regardless of its current status. Defaults to 1 day (86400 seconds).

        role (str): The IAM role (either name or full ARN) used to run the SageMaker
            training job. Defaults to:

            * the SageMaker default IAM role if the SDK is running in SageMaker
              Notebooks or SageMaker Studio Notebooks.
            * otherwise, a ValueError is thrown.

        s3_kms_key (str): The key used to encrypt the input and output data.
            Defaults to ``None``.

        s3_root_uri (str): The root S3 folder to which the code archives and data are
            uploaded. Defaults to ``s3://<sagemaker-default-bucket>``.

        sagemaker_session (sagemaker.core.helper.session.Session): The underlying
            SageMaker session to which SageMaker service calls are delegated
            (default: None). If not provided, one is created using a default
            configuration chain.

        security_group_ids (List[str]): A list of security group IDs. Defaults to
            ``None`` and the training job is created without VPC config.

        subnets (List[str]): A list of subnet IDs. Defaults to ``None`` and the job
            is created without VPC config.

        tags (List[Tuple[str, str]]): A list of tags attached to the job. Defaults to
            ``None`` and the training job is created without tags.

        volume_kms_key (str): An Amazon Key Management Service (KMS) key used to
            encrypt an Amazon Elastic Block Storage (EBS) volume attached to the
            training instance. Defaults to ``None``.

        volume_size (int): The size in GB of the storage volume for storing input and
            output data during training. Defaults to ``30``.

        encrypt_inter_container_traffic (bool): Whether traffic between training
            containers is encrypted for the training job. Defaults to ``False``.

        spark_config (SparkConfig): Configuration of the Spark application that runs
            on the Spark image. If ``spark_config`` is specified, a SageMaker Spark
            image uri is used for training. ``image_uri`` cannot be specified at the
            same time, otherwise a ``ValueError`` is thrown. Defaults to ``None``.

        use_spot_instances (bool): Whether to use SageMaker Managed Spot instances
            for training. If enabled, ``max_wait_time_in_seconds`` should also be
            set. Defaults to ``False``.

        max_wait_time_in_seconds (int): Timeout in seconds waiting for a spot
            training job. After this amount of time Amazon SageMaker stops waiting
            for the managed spot training job to complete. Defaults to ``None``.

        disable_output_compression (bool): Optional. When set to true, the model is
            uploaded to Amazon S3 without compression after training finishes.

        use_torchrun (bool): Whether to use torchrun for distributed training.
            Defaults to ``False``.

        use_mpirun (bool): Whether to use mpirun for distributed training.
            Defaults to ``False``.

        nproc_per_node (int): Optional. The number of processes per node for
            distributed training. Defaults to ``None``, in which case it is
            configured automatically based on the instance type.

    Raises:
        ValueError: If ``max_parallel_jobs`` is not greater than 0, or if
            ``instance_count`` is greater than 1 and not exactly one of
            ``spark_config``, ``use_torchrun`` and ``use_mpirun`` is enabled.
    """
    self.max_parallel_jobs = max_parallel_jobs
    if self.max_parallel_jobs <= 0:
        raise ValueError("max_parallel_jobs must be greater than 0.")

    # Multi-instance runs need exactly one distributed launcher: Spark, torchrun or mpirun.
    launchers = [spark_config is not None, bool(use_torchrun), bool(use_mpirun)]
    exactly_one_launcher = launchers.count(True) == 1
    if instance_count > 1 and not exactly_one_launcher:
        raise ValueError(
            "Remote function do not support training on multi instances "
            "without spark_config or use_torchrun or use_mpirun. "
            "Please provide instance_count = 1"
        )

    # Every job submitted through this executor shares the same settings object.
    settings_kwargs = dict(
        dependencies=dependencies,
        pre_execution_commands=pre_execution_commands,
        pre_execution_script=pre_execution_script,
        environment_variables=environment_variables,
        image_uri=image_uri,
        include_local_workdir=include_local_workdir,
        custom_file_filter=custom_file_filter,
        instance_count=instance_count,
        instance_type=instance_type,
        job_conda_env=job_conda_env,
        job_name_prefix=job_name_prefix,
        keep_alive_period_in_seconds=keep_alive_period_in_seconds,
        max_retry_attempts=max_retry_attempts,
        max_runtime_in_seconds=max_runtime_in_seconds,
        role=role,
        s3_kms_key=s3_kms_key,
        s3_root_uri=s3_root_uri,
        sagemaker_session=sagemaker_session,
        security_group_ids=security_group_ids,
        subnets=subnets,
        tags=tags,
        volume_kms_key=volume_kms_key,
        volume_size=volume_size,
        encrypt_inter_container_traffic=encrypt_inter_container_traffic,
        spark_config=spark_config,
        use_spot_instances=use_spot_instances,
        max_wait_time_in_seconds=max_wait_time_in_seconds,
        disable_output_compression=disable_output_compression,
        use_torchrun=use_torchrun,
        use_mpirun=use_mpirun,
        nproc_per_node=nproc_per_node,
    )
    self.job_settings = _JobSettings(**settings_kwargs)

    # Shared state, guarded by ``_state_condition`` wherever it is mutated.
    self._shutdown = False
    self._state_condition = threading.Condition()
    self._pending_request_queue = deque()
    # For thread safety of plain dict operations, see
    # https://web.archive.org/web/20201108091210/http://effbot.org/pyfaq/what-kinds-of-global-value-mutation-are-thread-safe.htm
    self._running_jobs = {}

    # The worker pool is created lazily on the first ``submit`` call.
    self._workers: ThreadPoolExecutor = None
820
+
821
def submit(self, func, *args, **kwargs):
    """Execute the input function as a SageMaker job asynchronously.

    The call is validated against the signature of ``func``, wrapped in a
    request and appended to the pending queue. The two background workers
    (submitting and polling) are started on the first call.

    Args:
        func: Python function to run as a SageMaker job.
        *args: Positional arguments to the input function.
        **kwargs: Keyword arguments to the input function.

    Returns:
        Future: Handle created for this call and attached to the queued request.

    Raises:
        RuntimeError: If the executor has already been shut down.
        TypeError: If the arguments do not match the signature of ``func``.
    """
    shutdown_message = "Cannot schedule new remote function executions after shutdown"

    # Fast path: fail before doing any validation work when already shut down.
    if self._shutdown:
        raise RuntimeError(shutdown_message)

    self._validate_submit_args(func, *args, **kwargs)

    with self._state_condition:
        # Re-check under the lock. ``shutdown`` appends its ``None`` anchor while
        # holding this lock; a request queued behind the anchor would never be
        # started, so its future would never resolve and ``shutdown`` would hang.
        if self._shutdown:
            raise RuntimeError(shutdown_message)

        future = Future()

        # Read the current run once so it cannot change between check and use.
        current_run = _RunContext.get_current_run()
        run_info = None
        if current_run is not None:
            run_info = _RunInfo(current_run.experiment_name, current_run.run_name)

        self._pending_request_queue.append(
            _SubmitRequest(future, self.job_settings, func, args, kwargs, run_info)
        )

        if self._workers is None:
            # One thread submits queued requests, the other polls running jobs.
            self._workers = ThreadPoolExecutor(2)
            self._workers.submit(_submit_worker, self)
            self._workers.submit(_polling_worker, self)

        self._state_condition.notify_all()

    return future
854
+
855
def map(self, func, *iterables):
    """Apply ``func`` to every item of the iterables as SageMaker jobs and return the results.

    If several iterables are passed, ``func`` must take that many arguments and is
    applied to the items from all iterables in lockstep; processing stops when the
    shortest iterable is exhausted. All jobs are submitted before any result is
    awaited, so they can run concurrently within the executor's limits. With no
    iterables, nothing is submitted and an empty list is returned.

    Args:
        func: Python function to run as a SageMaker job.
        iterables: Arguments of the input python function.

    Returns:
        list: The results of the jobs, in the order of the inputs.
    """
    # Submit eagerly: a lazy ``map`` would submit each job only after the previous
    # result had been awaited, which serializes the jobs.
    futures = [self.submit(func, *call_args) for call_args in zip(*iterables)]
    return [future.result() for future in futures]
869
+
870
def shutdown(self):
    """Stop accepting new function executions and wait for the workers to finish."""
    state = self._state_condition
    with state:
        self._shutdown = True
        # The ``None`` anchor tells the submitting worker to stop, so it does
        # not block on an empty queue forever.
        self._pending_request_queue.append(None)
        state.notify_all()

    workers = self._workers
    if workers is None:
        # Nothing was ever submitted, so no worker threads exist.
        return
    workers.shutdown(wait=True)
882
+
883
def __enter__(self):
    """Enter the ``with`` block, handing back this same executor."""
    return self
886
+
887
def __exit__(self, exc_type, exc_val, exc_tb):
    """Shut the executor down when leaving the ``with`` block."""
    self.shutdown()
    # ``False`` means an exception raised inside the block is not suppressed.
    return False
891
+
892
@staticmethod
def _validate_submit_args(func, *args, **kwargs):
    """Check that ``func`` can be called with the given arguments.

    The call is checked against the signature of ``func`` up front, so that an
    invalid call is rejected before a job is created for it.

    Args:
        func: The function that is going to be executed remotely.
        *args: Positional arguments intended for ``func``.
        **kwargs: Keyword arguments intended for ``func``.

    Raises:
        TypeError: If too many positional arguments are given, a required
            positional or keyword-only argument is missing, an argument is
            supplied both positionally and by keyword, or a keyword is not
            accepted by ``func``. Unknown keywords are accepted when ``func``
            declares ``**kwargs``.
    """
    spec = inspect.getfullargspec(func)
    positional_names = spec.args
    required_positional_count = len(positional_names) - len(spec.defaults or ())
    # ``kwonlydefaults`` is ``None`` rather than an empty dict when no
    # keyword-only parameter has a default value.
    kwonly_defaults = spec.kwonlydefaults or {}

    def quoted(names):
        # Format names as "'a'" or "'a', 'b', and 'c'" for the error messages.
        if len(names) > 1:
            return ", ".join(f"'{n}'" for n in names[:-1]) + f", and '{names[-1]}'"
        return f"'{names[0]}'"

    # args related validations

    if spec.varargs is None and len(args) > len(positional_names):
        raise TypeError(
            f"{func.__name__}() takes {len(positional_names)} positional "
            f"{'arguments' if len(positional_names) > 1 else 'argument'} but {len(args)} "
            f"{'were' if len(args) > 1 else 'was'} given."
        )

    # Required positional parameters not covered positionally may still be
    # supplied by keyword; the slice is empty when enough positionals were given.
    missing_args = [
        name
        for name in positional_names[len(args) : required_positional_count]
        if name not in kwargs
    ]
    if missing_args:
        raise TypeError(
            f"{func.__name__}() missing {len(missing_args)} required positional "
            f"{'arguments' if len(missing_args) > 1 else 'argument'}: {quoted(missing_args)}"
        )

    # kwargs related validations

    for k in kwargs:
        if k in positional_names and len(args) > positional_names.index(k):
            raise TypeError(f"{func.__name__}() got multiple values for argument '{k}'")
        # A function that declares ``**kwargs`` accepts arbitrary keywords.
        if spec.varkw is None and k not in spec.kwonlyargs and k not in positional_names:
            raise TypeError(f"{func.__name__}() got an unexpected keyword argument '{k}'")

    missing_kwargs = [
        k for k in spec.kwonlyargs if k not in kwonly_defaults and k not in kwargs
    ]
    if missing_kwargs:
        raise TypeError(
            f"{func.__name__}() missing {len(missing_kwargs)} required keyword-only "
            f"{'arguments' if len(missing_kwargs) > 1 else 'argument'}: "
            f"{quoted(missing_kwargs)}"
        )
954
+
955
+
956
class Future(object):
    """Class representing a reference to a SageMaker job result.

    Reference to the SageMaker job created as a result of the remote function run. The job may
    or may not have finished running.
    """

    def __init__(self):
        # Serializes state transitions and wakes threads blocked in ``wait()``.
        self._condition = threading.Condition()
        self._state = _PENDING
        self._job = None
        self._exception = None
        self._return = None
        # True only after ``done()`` has seen a terminal job status without loading the job's
        # outcome; ``result()`` then still has to load the return value or exception.
        self._outcome_pending = False

    @staticmethod
    def _load_failure_exception(job, sagemaker_session, describe_response):
        """Work out which exception to surface for a job whose status is ``Failed``.

        Args:
            job: The job object; its ``s3_uri`` and ``hmac_key`` locate and verify the stored
                exception.
            sagemaker_session: Session used to read the stored exception.
            describe_response: Describe-training-job response of the job; only its optional
                ``FailureReason`` entry is read.

        Returns:
            The exception stored by the job. If reading it fails with a ``ServiceError``:
            that ``ServiceError`` itself, unless its cause is a ``ClientError`` with code
            ``404`` / message ``Not Found``. In that case a ``RuntimeEnvironmentError`` built
            from ``FailureReason`` is returned when the reason carries that prefix, and a
            generic ``RemoteFunctionError`` otherwise.

        Raises:
            Any error other than ``ServiceError`` raised while reading the stored exception
            propagates to the caller.
        """
        try:
            return serialization.deserialize_exception_from_s3(
                sagemaker_session=sagemaker_session,
                s3_uri=s3_path_join(job.s3_uri, EXCEPTION_FOLDER),
                hmac_key=job.hmac_key,
            )
        except ServiceError as serr:
            chained_e = serr.__cause__
            exception_not_found = (
                isinstance(chained_e, ClientError)
                and chained_e.response["Error"]["Code"] == "404"  # pylint: disable=no-member
                and chained_e.response["Error"]["Message"]  # pylint: disable=no-member
                == "Not Found"
            )
            if not exception_not_found:
                return serr

            failure_reason = (
                describe_response["FailureReason"]
                if "FailureReason" in describe_response
                else None
            )
            if failure_reason and "RuntimeEnvironmentError: " in failure_reason:
                return RuntimeEnvironmentError(
                    failure_reason.replace("RuntimeEnvironmentError: ", "")
                )
            return RemoteFunctionError(
                "Failed to execute remote function. " + "Check corresponding job for details."
            )

    @staticmethod
    def from_describe_response(describe_training_job_response, sagemaker_session):
        """Construct a Future from a describe_training_job_response object.

        Args:
            describe_training_job_response: Describe-training-job response of an existing job.
            sagemaker_session: Session used to read the job's stored outcome.

        Returns:
            A ``Future`` whose state mirrors ``TrainingJobStatus``: cancelled for
            ``Stopping``/``Stopped``, finished for ``Completed``/``Failed`` (with the return
            value or exception already loaded), running for any other status. Errors hit
            while loading the outcome are stored on the future instead of being raised.
        """
        future = Future()
        job_exception = None
        client_exception = None
        job_return = None
        job = _Job.from_describe_response(describe_training_job_response, sagemaker_session)
        status = describe_training_job_response["TrainingJobStatus"]

        if status in ["Stopping", "Stopped"]:
            state = _CANCELLED
        elif status == "Completed":
            state = _FINISHED
            try:
                job_return = serialization.deserialize_obj_from_s3(
                    sagemaker_session=sagemaker_session,
                    s3_uri=s3_path_join(job.s3_uri, RESULTS_FOLDER),
                    hmac_key=job.hmac_key,
                )
            except (DeserializationError, ServiceError) as e:
                client_exception = e
        elif status == "Failed":
            state = _FINISHED
            try:
                job_exception = Future._load_failure_exception(
                    job, sagemaker_session, describe_training_job_response
                )
            except DeserializationError as e:
                client_exception = e
        else:
            state = _RUNNING

        future._job = job
        future._state = state
        future._exception = job_exception or client_exception
        future._return = job_return
        return future

    def _start_and_notify(
        self, job_settings: _JobSettings, func, func_args, func_kwargs, run_info=None
    ):
        """Start and record the newly created job in the future object.

        The job is recorded if one is successfully started. Otherwise, the exception is
        recorded. The state update is broadcast to other waiting threads.

        Returns:
            The started job, or ``None`` when the future is no longer pending or the job
            could not be started.
        """
        with self._condition:
            if self._state in [_PENDING]:

                try:
                    self._job = _Job.start(job_settings, func, func_args, func_kwargs, run_info)
                except (Exception,) as e:  # pylint: disable=broad-except
                    self._exception = e
                    self._state = _FINISHED
                    self._condition.notify_all()
                    return None

                self._state = _RUNNING
                self._condition.notify_all()
                return self._job
            return None

    def result(self, timeout: float = None) -> Any:
        """Returns the SageMaker job result.

        This method waits for the SageMaker job created from the remote function execution to
        complete for up to the timeout value (if specified). If timeout is ``None``,
        this method will wait until the SageMaker job completes.

        Args:
            timeout (float): Timeout in seconds to wait until the job is completed. ``None`` by
                default.

        Returns:
            The Python object returned by the remote function. ``None`` is returned when the
            future is in the cancelled state.

        Raises:
            RuntimeError: If the future is still pending after waiting.
            RemoteFunctionError: If the job status is ``Stopped``.
            TimeoutError: If the job has not reached ``Completed``, ``Failed`` or ``Stopped``.
            Exception: The exception recorded for a failed job or a failed job start.
        """
        try:
            self.wait(timeout)
        except UnexpectedStatusException:
            pass

        with self._condition:
            if self._state == _PENDING:
                raise RuntimeError("Job for remote function has not been started.")

            # ``done()`` may have marked the future finished without loading its outcome; in
            # that case the outcome still has to be fetched here.
            needs_outcome = self._state == _RUNNING or (
                self._state == _FINISHED and self._outcome_pending
            )
            if needs_outcome:
                # One describe call, so every branch below sees the same status snapshot.
                description = self._job.describe()
                status = description["TrainingJobStatus"]

                if status == "Completed":
                    self._return = serialization.deserialize_obj_from_s3(
                        sagemaker_session=self._job.sagemaker_session,
                        s3_uri=s3_path_join(self._job.s3_uri, RESULTS_FOLDER),
                        hmac_key=self._job.hmac_key,
                    )
                    self._state = _FINISHED
                    self._outcome_pending = False
                    return self._return
                if status == "Failed":
                    self._exception = Future._load_failure_exception(
                        self._job, self._job.sagemaker_session, description
                    )
                    self._state = _FINISHED
                    self._outcome_pending = False
                elif status == "Stopped":
                    self._state = _CANCELLED
                    raise RemoteFunctionError("Job for remote function has been aborted.")
                else:
                    raise TimeoutError(
                        "Job for remote function timed out before reaching a termination status."
                    )

            if self._state == _FINISHED:
                if self._exception:
                    raise self._exception
                return self._return

            return None

    def wait(
        self,
        timeout: int = None,
    ) -> None:
        """Wait for the underlying SageMaker job to complete.

        This method waits for the SageMaker job created as a result of the remote function run
        to complete for up to the timeout value (if specified). If timeout is ``None``, this method
        will block until the job is completed.

        Args:
            timeout (int): Timeout in seconds to wait until the job is completed before it is
                stopped. Defaults to ``None``.

        Returns:
            None
        """

        with self._condition:
            # The same ``timeout`` is applied to each phase separately: first waiting for the
            # job to be started, then waiting on the job itself.
            if self._state == _PENDING:
                self._condition.wait(timeout=timeout)

            if self._state == _RUNNING:
                self._job.wait(timeout=timeout)

    def cancel(self) -> bool:
        """Cancel the function execution.

        This method prevents the SageMaker job being created or stops the underlying SageMaker job
        early if it is already in progress.

        Returns:
            ``True`` if the underlying SageMaker job created as a result of the remote function
            run is cancelled. ``False`` if the future has already finished.
        """
        with self._condition:
            if self._state == _FINISHED:
                return False
            if self._state == _CANCELLED:
                return True

            if self._job:
                self._job.stop()
            self._state = _CANCELLED
            return True

    def running(self) -> bool:
        """Check if the underlying SageMaker job is running.

        Returns:
            ``True`` if the underlying SageMaker job is still running. ``False``, otherwise.
        """
        with self._condition:
            return self._state == _RUNNING

    def cancelled(self) -> bool:
        """Check if the underlying SageMaker job was cancelled.

        Returns:
            ``True`` if the underlying SageMaker job was cancelled. ``False``, otherwise.
        """
        with self._condition:
            return self._state == _CANCELLED

    def done(self) -> bool:
        """Check if the underlying SageMaker job is finished.

        Returns:
            ``True`` if the underlying SageMaker job finished running. ``False``, otherwise.
        """
        with self._condition:
            if self._state == _RUNNING and self._job.describe()["TrainingJobStatus"] in [
                "Completed",
                "Failed",
            ]:
                self._state = _FINISHED
                # The return value / exception has not been loaded yet; flag it so that a
                # later ``result()`` loads it instead of returning ``None``.
                self._outcome_pending = True
                return True

            return self._state == _FINISHED
1223
+
1224
+
1225
def get_future(job_name, sagemaker_session=None) -> Future:
    """Build a future describing the existing job called ``job_name``.

    Args:
        job_name (str): name of the underlying SageMaker job created as a result of the remote
            function run.

        sagemaker_session (sagemaker.core.helper.session.Session): A session object that manages
            interactions with Amazon SageMaker APIs and any other AWS services needed. A new
            ``Session`` is created when none is given.

    Returns:
        A `sagemaker.remote_function.client.Future` instance.
    """
    session = sagemaker_session if sagemaker_session else Session()
    job_description = session.sagemaker_client.describe_training_job(TrainingJobName=job_name)
    return Future.from_describe_response(job_description, session)
1244
+
1245
+
1246
def list_futures(job_name_prefix, sagemaker_session=None):
    """Generates Future objects with information about jobs with given job_name_prefix.

    Training jobs are listed page by page, and each listed job is described individually
    before its future is yielded.

    Args:
        job_name_prefix (str): A prefix used to identify the SageMaker jobs associated with remote
            function run.
        sagemaker_session (sagemaker.core.helper.session.Session): A session object that manages
            interactions with Amazon SageMaker APIs and any other AWS services needed. A new
            ``Session`` is created when none is given.

    Yields:
        A `sagemaker.remote_function.client.Future` instance.
    """
    if not sagemaker_session:
        sagemaker_session = Session()

    # Round-trip the prefix through the job-name helpers because the prefix might have been
    # trimmed while creating the job.
    name_filter = base_from_name(name_from_base(job_name_prefix))
    request = {"NameContains": name_filter}

    while True:
        page = sagemaker_session.sagemaker_client.list_training_jobs(**request)
        names_on_page = [summary["TrainingJobName"] for summary in page["TrainingJobSummaries"]]
        for name in names_on_page:
            description = sagemaker_session.sagemaker_client.describe_training_job(
                TrainingJobName=name
            )
            yield Future.from_describe_response(description, sagemaker_session)

        if "NextToken" not in page:
            return
        token = page["NextToken"]
        if token:
            request["NextToken"] = token