sagemaker-core 1.0.62__py3-none-any.whl → 2.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (358) hide show
  1. sagemaker/__init__.py +2 -0
  2. sagemaker/core/__init__.py +16 -0
  3. sagemaker/core/_studio.py +116 -0
  4. sagemaker/core/_version.py +11 -0
  5. sagemaker/core/accept_types.py +131 -0
  6. sagemaker/core/analytics.py +744 -0
  7. sagemaker/core/apiutils/__init__.py +13 -0
  8. sagemaker/core/apiutils/_base_types.py +228 -0
  9. sagemaker/core/apiutils/_boto_functions.py +130 -0
  10. sagemaker/core/apiutils/_utils.py +34 -0
  11. sagemaker/core/base_deserializers.py +35 -0
  12. sagemaker/core/base_serializers.py +35 -0
  13. sagemaker/core/clarify/__init__.py +2898 -0
  14. sagemaker/core/collection.py +467 -0
  15. sagemaker/core/common_utils.py +2399 -0
  16. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  17. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  18. sagemaker/core/config/__init__.py +181 -0
  19. sagemaker/core/config/config.py +238 -0
  20. sagemaker/core/config/config_manager.py +595 -0
  21. sagemaker/core/config/config_schema.py +1220 -0
  22. sagemaker/core/config/config_utils.py +297 -0
  23. {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
  24. sagemaker/core/constants.py +73 -0
  25. sagemaker/core/content_types.py +137 -0
  26. sagemaker/core/debugger/__init__.py +39 -0
  27. sagemaker/core/debugger/debugger.py +945 -0
  28. sagemaker/core/debugger/framework_profile.py +292 -0
  29. sagemaker/core/debugger/metrics_config.py +468 -0
  30. sagemaker/core/debugger/profiler.py +42 -0
  31. sagemaker/core/debugger/profiler_config.py +190 -0
  32. sagemaker/core/debugger/profiler_constants.py +40 -0
  33. sagemaker/core/debugger/utils.py +148 -0
  34. sagemaker/core/deprecations.py +254 -0
  35. sagemaker/core/deserializers/__init__.py +10 -0
  36. sagemaker/core/deserializers/base.py +424 -0
  37. sagemaker/core/deserializers/implementations.py +157 -0
  38. sagemaker/core/drift_check_baselines.py +106 -0
  39. sagemaker/core/enums.py +51 -0
  40. sagemaker/core/environment_variables.py +101 -0
  41. sagemaker/core/exceptions.py +108 -0
  42. sagemaker/core/experiments/__init__.py +53 -0
  43. sagemaker/core/experiments/_api_types.py +251 -0
  44. sagemaker/core/experiments/_environment.py +124 -0
  45. sagemaker/core/experiments/_helper.py +294 -0
  46. sagemaker/core/experiments/_metrics.py +333 -0
  47. sagemaker/core/experiments/_run_context.py +58 -0
  48. sagemaker/core/experiments/_utils.py +216 -0
  49. sagemaker/core/experiments/experiment.py +247 -0
  50. sagemaker/core/experiments/run.py +970 -0
  51. sagemaker/core/experiments/trial.py +296 -0
  52. sagemaker/core/experiments/trial_component.py +387 -0
  53. sagemaker/core/explainer/__init__.py +24 -0
  54. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  55. sagemaker/core/explainer/explainer_config.py +44 -0
  56. sagemaker/core/fw_utils.py +1220 -0
  57. sagemaker/core/git_utils.py +415 -0
  58. sagemaker/core/helper/pipeline_variable.py +82 -0
  59. sagemaker/core/helper/session_helper.py +2977 -0
  60. sagemaker/core/hyperparameters.py +172 -0
  61. sagemaker/core/image_retriever/__init__.py +3 -0
  62. sagemaker/core/image_retriever/image_retriever.py +640 -0
  63. sagemaker/core/image_retriever/image_retriever_utils.py +509 -0
  64. sagemaker/core/image_retriever/test.py +7 -0
  65. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  66. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  67. sagemaker/core/image_uri_config/chainer.json +104 -0
  68. sagemaker/core/image_uri_config/clarify.json +39 -0
  69. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  70. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  71. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  72. sagemaker/core/image_uri_config/debugger.json +34 -0
  73. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  74. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  75. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  76. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  77. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  78. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  79. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  80. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  81. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +770 -0
  82. sagemaker/core/image_uri_config/huggingface-llm.json +1267 -0
  83. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  84. sagemaker/core/image_uri_config/huggingface-neuronx.json +686 -0
  85. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  86. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  87. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  88. sagemaker/core/image_uri_config/huggingface-vllm-neuronx.json +38 -0
  89. sagemaker/core/image_uri_config/huggingface.json +2287 -0
  90. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  91. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  92. sagemaker/core/image_uri_config/image-classification.json +50 -0
  93. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  94. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  95. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  96. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  97. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  98. sagemaker/core/image_uri_config/kmeans.json +50 -0
  99. sagemaker/core/image_uri_config/knn.json +50 -0
  100. sagemaker/core/image_uri_config/lda.json +26 -0
  101. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  102. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  103. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  104. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  105. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  106. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  107. sagemaker/core/image_uri_config/ntm.json +50 -0
  108. sagemaker/core/image_uri_config/object-detection.json +50 -0
  109. sagemaker/core/image_uri_config/object2vec.json +50 -0
  110. sagemaker/core/image_uri_config/pca.json +50 -0
  111. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  112. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  113. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  114. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  115. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  116. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  117. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  118. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  119. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  120. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  121. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +252 -0
  122. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  123. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  124. sagemaker/core/image_uri_config/sklearn.json +494 -0
  125. sagemaker/core/image_uri_config/spark.json +280 -0
  126. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  127. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  128. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  129. sagemaker/core/image_uri_config/vw.json +25 -0
  130. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  131. sagemaker/core/image_uri_config/xgboost.json +972 -0
  132. sagemaker/core/image_uris.py +816 -0
  133. sagemaker/core/inference_config.py +144 -0
  134. sagemaker/core/inference_recommender/__init__.py +18 -0
  135. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  136. sagemaker/core/inputs.py +366 -0
  137. sagemaker/core/instance_group.py +61 -0
  138. sagemaker/core/instance_types.py +164 -0
  139. sagemaker/core/instance_types_gpu_info.py +43 -0
  140. sagemaker/core/interactive_apps/__init__.py +41 -0
  141. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  142. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  143. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  144. sagemaker/core/iterators.py +197 -0
  145. sagemaker/core/job.py +380 -0
  146. sagemaker/core/jumpstart/__init__.py +156 -0
  147. sagemaker/core/jumpstart/accessors.py +390 -0
  148. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  149. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  150. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  151. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  152. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  153. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  154. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  155. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  156. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  157. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  158. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  159. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  160. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  161. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  162. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  163. sagemaker/core/jumpstart/cache.py +663 -0
  164. sagemaker/core/jumpstart/configs.py +50 -0
  165. sagemaker/core/jumpstart/constants.py +198 -0
  166. sagemaker/core/jumpstart/deserializers.py +81 -0
  167. sagemaker/core/jumpstart/document.py +76 -0
  168. sagemaker/core/jumpstart/enums.py +168 -0
  169. sagemaker/core/jumpstart/exceptions.py +236 -0
  170. sagemaker/core/jumpstart/factory/utils.py +833 -0
  171. sagemaker/core/jumpstart/filters.py +597 -0
  172. sagemaker/core/jumpstart/hub/constants.py +16 -0
  173. sagemaker/core/jumpstart/hub/hub.py +291 -0
  174. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  175. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  176. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  177. sagemaker/core/jumpstart/hub/types.py +35 -0
  178. sagemaker/core/jumpstart/hub/utils.py +260 -0
  179. sagemaker/core/jumpstart/models.py +501 -0
  180. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  181. sagemaker/core/jumpstart/parameters.py +20 -0
  182. sagemaker/core/jumpstart/payload_utils.py +239 -0
  183. sagemaker/core/jumpstart/region_config.json +171 -0
  184. sagemaker/core/jumpstart/search.py +171 -0
  185. sagemaker/core/jumpstart/serializers.py +81 -0
  186. sagemaker/core/jumpstart/session_utils.py +234 -0
  187. sagemaker/core/jumpstart/types.py +3044 -0
  188. sagemaker/core/jumpstart/utils.py +1731 -0
  189. sagemaker/core/jumpstart/validators.py +257 -0
  190. sagemaker/core/lambda_helper.py +312 -0
  191. sagemaker/core/lineage/__init__.py +42 -0
  192. sagemaker/core/lineage/_api_types.py +239 -0
  193. sagemaker/core/lineage/_utils.py +49 -0
  194. sagemaker/core/lineage/action.py +345 -0
  195. sagemaker/core/lineage/artifact.py +646 -0
  196. sagemaker/core/lineage/association.py +190 -0
  197. sagemaker/core/lineage/context.py +505 -0
  198. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  199. sagemaker/core/lineage/query.py +732 -0
  200. sagemaker/core/lineage/visualizer.py +346 -0
  201. sagemaker/core/local/__init__.py +18 -0
  202. sagemaker/core/local/data.py +423 -0
  203. sagemaker/core/local/entities.py +678 -0
  204. sagemaker/core/local/exceptions.py +17 -0
  205. sagemaker/core/local/image.py +1243 -0
  206. sagemaker/core/local/local_session.py +739 -0
  207. sagemaker/core/local/utils.py +246 -0
  208. sagemaker/core/logs.py +181 -0
  209. sagemaker/core/metadata_properties.py +56 -0
  210. sagemaker/core/metric_definitions.py +91 -0
  211. sagemaker/core/mlflow/__init__.py +38 -0
  212. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  213. sagemaker/core/model_card/__init__.py +26 -0
  214. sagemaker/core/model_life_cycle.py +51 -0
  215. sagemaker/core/model_metrics.py +160 -0
  216. sagemaker/core/model_monitor/__init__.py +66 -0
  217. sagemaker/core/model_monitor/clarify_model_monitoring.py +1497 -0
  218. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  219. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  220. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  221. sagemaker/core/model_monitor/dataset_format.py +102 -0
  222. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  223. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  224. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  225. sagemaker/core/model_monitor/utils.py +793 -0
  226. sagemaker/core/model_registry.py +480 -0
  227. sagemaker/core/model_uris.py +97 -0
  228. sagemaker/core/modules/__init__.py +19 -0
  229. sagemaker/core/modules/configs.py +239 -0
  230. sagemaker/core/modules/constants.py +37 -0
  231. sagemaker/core/modules/distributed.py +182 -0
  232. sagemaker/core/modules/local_core/local_container.py +605 -0
  233. sagemaker/core/modules/templates.py +83 -0
  234. sagemaker/core/modules/train/__init__.py +14 -0
  235. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  236. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  237. sagemaker/core/modules/train/container_drivers/common/utils.py +205 -0
  238. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  239. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  240. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  241. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  242. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  243. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  244. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  245. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  246. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  247. sagemaker/core/modules/types.py +19 -0
  248. sagemaker/core/modules/utils.py +194 -0
  249. sagemaker/core/network.py +185 -0
  250. sagemaker/core/parameter.py +173 -0
  251. sagemaker/core/payloads.py +185 -0
  252. sagemaker/core/processing.py +1599 -0
  253. sagemaker/core/remote_function/__init__.py +19 -0
  254. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  255. sagemaker/core/remote_function/client.py +1310 -0
  256. sagemaker/core/remote_function/core/__init__.py +0 -0
  257. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  258. sagemaker/core/remote_function/core/pipeline_variables.py +347 -0
  259. sagemaker/core/remote_function/core/serialization.py +410 -0
  260. sagemaker/core/remote_function/core/stored_function.py +223 -0
  261. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  262. sagemaker/core/remote_function/errors.py +102 -0
  263. sagemaker/core/remote_function/invoke_function.py +167 -0
  264. sagemaker/core/remote_function/job.py +2121 -0
  265. sagemaker/core/remote_function/logging_config.py +38 -0
  266. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  267. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  268. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  269. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  270. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  271. sagemaker/core/remote_function/spark_config.py +149 -0
  272. sagemaker/core/resource_requirements.py +168 -0
  273. {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
  274. sagemaker/core/s3/__init__.py +41 -0
  275. sagemaker/core/s3/client.py +367 -0
  276. sagemaker/core/s3/utils.py +175 -0
  277. sagemaker/core/script_uris.py +93 -0
  278. sagemaker/core/serializers/__init__.py +11 -0
  279. sagemaker/core/serializers/base.py +510 -0
  280. sagemaker/core/serializers/implementations.py +159 -0
  281. sagemaker/core/serializers/utils.py +223 -0
  282. sagemaker/core/serverless_inference_config.py +63 -0
  283. sagemaker/core/session_settings.py +55 -0
  284. sagemaker/core/shapes/__init__.py +3 -0
  285. sagemaker/core/shapes/model_card_shapes.py +159 -0
  286. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
  287. sagemaker/core/spark/__init__.py +16 -0
  288. sagemaker/core/spark/defaults.py +16 -0
  289. sagemaker/core/spark/processing.py +1380 -0
  290. sagemaker/core/telemetry/__init__.py +23 -0
  291. sagemaker/core/telemetry/constants.py +82 -0
  292. sagemaker/core/telemetry/telemetry_logging.py +285 -0
  293. sagemaker/core/tools/__init__.py +1 -0
  294. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  295. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  296. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  297. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  298. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  299. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  300. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  301. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  302. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  303. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  304. sagemaker/core/training/__init__.py +14 -0
  305. sagemaker/core/training/configs.py +345 -0
  306. sagemaker/core/training/constants.py +37 -0
  307. sagemaker/core/training/utils.py +77 -0
  308. sagemaker/core/training_compiler/__init__.py +16 -0
  309. sagemaker/core/training_compiler/config.py +197 -0
  310. sagemaker/core/training_compiler_config.py +197 -0
  311. sagemaker/core/transformer.py +793 -0
  312. sagemaker/core/user_agent.py +76 -0
  313. sagemaker/core/utilities/__init__.py +24 -0
  314. sagemaker/core/utilities/cache.py +169 -0
  315. sagemaker/core/utilities/search_expression.py +133 -0
  316. sagemaker/core/utils/__init__.py +48 -0
  317. sagemaker/core/utils/code_injection/__init__.py +0 -0
  318. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  319. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
  320. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  321. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  322. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  323. {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
  324. sagemaker/core/workflow/__init__.py +152 -0
  325. sagemaker/core/workflow/conditions.py +313 -0
  326. sagemaker/core/workflow/entities.py +58 -0
  327. sagemaker/core/workflow/execution_variables.py +89 -0
  328. sagemaker/core/workflow/functions.py +193 -0
  329. sagemaker/core/workflow/parameters.py +222 -0
  330. sagemaker/core/workflow/pipeline_context.py +394 -0
  331. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  332. sagemaker/core/workflow/properties.py +285 -0
  333. sagemaker/core/workflow/step_outputs.py +65 -0
  334. sagemaker/core/workflow/utilities.py +514 -0
  335. sagemaker/lineage/__init__.py +33 -0
  336. sagemaker/lineage/action.py +28 -0
  337. sagemaker/lineage/artifact.py +28 -0
  338. sagemaker/lineage/context.py +28 -0
  339. sagemaker/lineage/lineage_trial_component.py +28 -0
  340. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/METADATA +28 -9
  341. sagemaker_core-2.3.1.dist-info/RECORD +351 -0
  342. sagemaker_core-2.3.1.dist-info/top_level.txt +1 -0
  343. sagemaker_core/_version.py +0 -3
  344. sagemaker_core/helper/session_helper.py +0 -769
  345. sagemaker_core/resources/__init__.py +0 -1
  346. sagemaker_core/shapes/__init__.py +0 -1
  347. sagemaker_core/tools/__init__.py +0 -1
  348. sagemaker_core-1.0.62.dist-info/RECORD +0 -35
  349. sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
  350. {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
  351. {sagemaker_core/helper → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  352. {sagemaker_core/main → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
  353. {sagemaker_core/main/code_injection → sagemaker/core/modules/local_core}/__init__.py +0 -0
  354. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  355. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  356. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  357. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/WHEEL +0 -0
  358. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1310 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """SageMaker remote function client."""
14
+ from __future__ import absolute_import
15
+
16
+ from concurrent.futures import ThreadPoolExecutor
17
+ from collections import deque
18
+ import time
19
+ import threading
20
+ from typing import Callable, Dict, List, Optional, Tuple, Any, Union
21
+ import functools
22
+ import itertools
23
+ import inspect
24
+
25
+ from botocore.exceptions import ClientError
26
+ from sagemaker.core.exceptions import UnexpectedStatusException
27
+ from sagemaker.core.experiments._run_context import _RunContext
28
+
29
+ import sagemaker.core.remote_function.core.serialization as serialization
30
+ from sagemaker.core.remote_function.errors import (
31
+ RemoteFunctionError,
32
+ ServiceError,
33
+ DeserializationError,
34
+ )
35
+ from sagemaker.core.remote_function.core.stored_function import RESULTS_FOLDER, EXCEPTION_FOLDER
36
+ from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import (
37
+ RuntimeEnvironmentError,
38
+ )
39
+
40
+ from sagemaker.core.helper.session_helper import Session
41
+ from sagemaker.core.s3 import s3_path_join
42
+ from sagemaker.core.remote_function.job import _JobSettings, _Job, _RunInfo
43
+ from sagemaker.core.remote_function import logging_config
44
+ from sagemaker.core.common_utils import name_from_base, base_from_name
45
+ from sagemaker.core.remote_function.spark_config import SparkConfig
46
+ from sagemaker.core.remote_function.custom_file_filter import CustomFileFilter
47
+ from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
48
+ from sagemaker.core.telemetry.constants import Feature
49
+
50
+ _API_CALL_LIMIT = {
51
+ "SubmittingIntervalInSecs": 1,
52
+ "MinBatchPollingIntervalInSecs": 10,
53
+ "PollingIntervalInSecs": 0.5,
54
+ }
55
+
56
+ # Possible future states.
57
+ _PENDING = "PENDING"
58
+ _RUNNING = "RUNNING"
59
+ # The future was cancelled by the user...
60
+ _CANCELLED = "CANCELLED"
61
+ _FINISHED = "FINISHED"
62
+
63
+ logger = logging_config.get_logger()
64
+
65
+
66
+ @_telemetry_emitter(feature=Feature.REMOTE_FUNCTION, func_name="remote_function.remote")
67
+ def remote(
68
+ _func=None,
69
+ *,
70
+ dependencies: str = None,
71
+ pre_execution_commands: List[str] = None,
72
+ pre_execution_script: str = None,
73
+ environment_variables: Dict[str, str] = None,
74
+ image_uri: str = None,
75
+ include_local_workdir: bool = None,
76
+ custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None,
77
+ instance_count: int = 1,
78
+ instance_type: str = None,
79
+ job_conda_env: str = None,
80
+ job_name_prefix: str = None,
81
+ keep_alive_period_in_seconds: int = 0,
82
+ max_retry_attempts: int = 1,
83
+ max_runtime_in_seconds: int = 24 * 60 * 60,
84
+ role: str = None,
85
+ s3_kms_key: str = None,
86
+ s3_root_uri: str = None,
87
+ sagemaker_session: Session = None,
88
+ security_group_ids: List[str] = None,
89
+ subnets: List[str] = None,
90
+ tags: List[Tuple[str, str]] = None,
91
+ volume_kms_key: str = None,
92
+ volume_size: int = 30,
93
+ encrypt_inter_container_traffic: bool = None,
94
+ spark_config: SparkConfig = None,
95
+ use_spot_instances=False,
96
+ max_wait_time_in_seconds=None,
97
+ disable_output_compression: bool = False,
98
+ use_torchrun: bool = False,
99
+ use_mpirun: bool = False,
100
+ nproc_per_node: Optional[int] = None,
101
+ ):
102
+ """Decorator for running the annotated function as a SageMaker training job.
103
+
104
+ This decorator wraps the annotated code and runs it as a new SageMaker job synchronously
105
+ with the provided runtime settings.
106
+
107
+ If a parameter value is not set, the decorator first looks up the value from the SageMaker
108
+ configuration file. If no value is specified in the configuration file or no configuration file
109
+ is found, the decorator selects the default as specified below. For more information, see
110
+ `Configuring and using defaults with the SageMaker Python SDK <https://sagemaker.readthedocs.io/
111
+ en/stable/overview.html#configuring-and-using-defaults-with-the-sagemaker-python-sdk>`_.
112
+
113
+ Args:
114
+ _func (Optional): A Python function to run as a SageMaker training job.
115
+
116
+ dependencies (str): Either the path to a dependencies file or the reserved keyword
117
+ ``auto_capture``. Defaults to ``None``.
118
+ If ``dependencies`` is provided, the value must be one of the following:
119
+
120
+ * A path to a conda environment.yml file. The following conditions apply.
121
+
122
+ * If job_conda_env is set, then the conda environment is updated by installing
123
+ dependencies from the yaml file and the function is invoked within that
124
+ conda environment. For this to succeed, the specified conda environment must
125
+ already exist in the image.
126
+ * If the environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image, then the
127
+ conda environment is updated by installing dependencies from the yaml file and the
128
+ function is invoked within that conda environment. For this to succeed, the
129
+ conda environment name must already be set in ``SAGEMAKER_JOB_CONDA_ENV``, and
130
+ ``SAGEMAKER_JOB_CONDA_ENV`` must already exist in the image.
131
+ * If none of the previous conditions are met, a new conda environment named
132
+ ``sagemaker-runtime-env`` is created and the function annotated with the remote
133
+ decorator is invoked in that conda environment.
134
+
135
+ * A path to a requirements.txt file. The following conditions apply.
136
+
137
+ * If ``job_conda_env`` is set in the remote decorator, dependencies are installed
138
+ within that conda environment and the function annotated with the remote decorator
139
+ is invoked in the same conda environment. For this to succeed, the specified
140
+ conda environment must already exist in the image.
141
+ * If an environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image,
142
+ dependencies are installed within that conda environment and the function annotated
143
+ with the remote decorator is invoked in the same. For this to succeed, the conda
144
+ environment name must already be set in ``SAGEMAKER_JOB_CONDA_ENV``, and
145
+ ``SAGEMAKER_JOB_CONDA_ENV`` must already exist in the image.
146
+ * If none of the above conditions are met, conda is not used. Dependencies are
147
+ installed at the system level, without any virtual environment, and the function
148
+ annotated with the remote decorator is invoked using the Python runtime available
149
+ in the system path.
150
+
151
+ * The parameter dependencies is set to ``auto_capture``. SageMaker will automatically
152
+ generate an env_snapshot.yml corresponding to the current active conda environment’s
153
+ snapshot. You do not need to provide a dependencies file. The following conditions
154
+ apply:
155
+
156
+ * You must run the remote function within an active conda environment.
157
+ * When installing the dependencies on the training job, the same conditions as when
158
+ dependencies is set to a path to a conda environment file apply. These conditions are
159
+ as follows:
160
+
161
+ * If job_conda_env is set, then the conda environment is updated by installing
162
+ dependencies from the yaml file and the function is invoked within that
163
+ conda environment. For this to succeed, the specified conda environment must
164
+ already exist in the image.
165
+ * If the environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image, then
166
+ the conda environment is updated by installing dependencies from the yaml file
167
+ and the function is invoked within that conda environment. For this to
168
+ succeed, the conda environment name must already be set in
169
+ ``SAGEMAKER_JOB_CONDA_ENV``, and ``SAGEMAKER_JOB_CONDA_ENV`` must already exist
170
+ in the image.
171
+ * If none of the previous conditions are met, a new conda environment with name
172
+ ``sagemaker-runtime-env`` is created and the function annotated with the
173
+ remote decorator is invoked in that conda environment.
174
+
175
+ * ``None``. SageMaker will assume that there are no dependencies to install while
176
+ executing the remote annotated function in the training job.
177
+
178
+ pre_execution_commands (List[str]): List of commands to be executed prior to executing
179
+ remote function. Only one of ``pre_execution_commands`` or ``pre_execution_script``
180
+ can be specified at the same time. Defaults to None.
181
+
182
+ pre_execution_script (str): Path to script file to be executed prior to executing
183
+ remote function. Only one of ``pre_execution_commands`` or ``pre_execution_script``
184
+ can be specified at the same time. Defaults to None.
185
+
186
+ environment_variables (Dict): The environment variables used inside the decorator function.
187
+ Defaults to ``None``.
188
+
189
+ image_uri (str): The universal resource identifier (URI) location of a Docker image on
190
+ Amazon Elastic Container Registry (ECR). Defaults to the following based on where the SDK
191
+ is running:
192
+
193
+ * For users who specify ``spark_config`` and want to run the function in a Spark
194
+ application, the ``image_uri`` should be ``None``. A SageMaker Spark image will
195
+ be used for training, otherwise a ``ValueError`` is thrown.
196
+ * For users on SageMaker Studio notebooks, the image used as the kernel image for the
197
+ notebook is used.
198
+ * For other users, it is resolved to a base Python image with the same Python version
199
+ as the environment running the local code.
200
+
201
+ If no compatible image is found, a ValueError is thrown.
202
+
203
+ include_local_workdir (bool): A flag to indicate that the remote function should include
204
+ local directories. Set to ``True`` if the remote function code imports local modules and
205
+ methods that are not available via PyPI or conda. Only python files are included.
206
+ Default value is ``False``.
207
+
208
+ custom_file_filter (Callable[[str, List], List], CustomFileFilter): Either a function
209
+ that filters job dependencies to be uploaded to S3 or a ``CustomFileFilter`` object
210
+ that specifies the local directories and files to be included in the remote function.
211
+ If a callable is passed in, the function should follow the protocol of ``ignore`` argument
212
+ of ``shutil.copytree``. Defaults to ``None``, which means only python
213
+ files are accepted and uploaded to S3.
214
+
215
+ instance_count (int): The number of instances to use. Defaults to 1.
216
+ NOTE: Remote function supports instance_count > 1 for Spark jobs, torchrun and
217
+ mpirun utilities
218
+
219
+ instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to run
220
+ the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown.
221
+
222
+ job_conda_env (str): The name of the conda environment to activate during job's runtime.
223
+ Defaults to ``None``.
224
+
225
+ job_name_prefix (str): The prefix used to create the underlying SageMaker job.
226
+
227
+ keep_alive_period_in_seconds (int): The duration in seconds to retain and reuse provisioned
228
+ infrastructure after the completion of a training job, also known as SageMaker managed
229
+ warm pools. The use of warmpools reduces the latency time spent to provision new
230
+ resources. The default value for ``keep_alive_period_in_seconds`` is 0.
231
+ NOTE: Additional charges associated with warm pools may apply. Using this parameter also
232
+ activates a new persistent cache feature, which will further reduce job start up
233
+ latency than over using SageMaker managed warm pools alone by caching the package source
234
+ downloaded in the previous runs.
235
+
236
+ max_retry_attempts (int): The max number of times the job is retried on
237
+ ``InternalServerFailure`` Error from SageMaker service. Defaults to 1.
238
+
239
+ max_runtime_in_seconds (int): The upper limit in seconds to be used for training. After
240
+ this specified amount of time, SageMaker terminates the job regardless of its current
241
+ status. Defaults to 1 day or (86400 seconds).
242
+
243
+ role (str): The IAM role (either name or full ARN) used to run your SageMaker training
244
+ job. Defaults to:
245
+
246
+ * the SageMaker default IAM role if the SDK is running in SageMaker Notebooks or
247
+ SageMaker Studio Notebooks.
248
+ * if not above, a ValueError is thrown.
249
+
250
+ s3_kms_key (str): The key used to encrypt the input and output data. Default to ``None``.
251
+
252
+ s3_root_uri (str): The root S3 folder to which the code archives and data are
253
+ uploaded to. Defaults to ``s3://<sagemaker-default-bucket>``.
254
+
255
+ sagemaker_session (sagemaker.core.helper.session.Session): The underlying SageMaker session to which
256
+ SageMaker service calls are delegated to (default: None). If not provided, one is created
257
+ using a default configuration chain.
258
+
259
+ security_group_ids (List[str]): A list of security group IDs. Defaults to ``None`` and the
260
+ training job is created without VPC config.
261
+
262
+ subnets (List[str]): A list of subnet IDs. Defaults to ``None`` and the job is created
263
+ without VPC config.
264
+
265
+ tags (List[Tuple[str, str]]): A list of tags attached to the job. Defaults to ``None`` and
266
+ the training job is created without tags.
267
+
268
+ volume_kms_key (str): An Amazon Key Management Service (KMS) key used to encrypt an
269
+ Amazon Elastic Block Storage (EBS) volume attached to the training instance. Defaults to
270
+ ``None``.
271
+
272
+ volume_size (int): The size in GB of the storage volume for storing input and output data
273
+ during training. Defaults to ``30``.
274
+
275
+ encrypt_inter_container_traffic (bool): A flag that specifies whether traffic between
276
+ training containers is encrypted for the training job. Defaults to ``False``.
277
+
278
+ spark_config (SparkConfig): Configurations to the Spark application that runs on
279
+ Spark image. If ``spark_config`` is specified, a SageMaker Spark image uri
280
+ will be used for training. Note that ``image_uri`` can not be specified at the
281
+ same time otherwise a ``ValueError`` is thrown. Defaults to ``None``.
282
+
283
+ use_spot_instances (bool): Specifies whether to use SageMaker Managed Spot instances for
284
+ training. If enabled then the ``max_wait_time_in_seconds`` arg should also be set.
285
+ Defaults to ``False``.
286
+
287
+ max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
288
+ After this amount of time Amazon SageMaker will stop waiting for managed spot training
289
+ job to complete. Defaults to ``None``.
290
+
291
+ disable_output_compression (bool): Optional. When set to true, Model is uploaded to
292
+ Amazon S3 without compression after training finishes.
293
+
294
+ use_torchrun (bool): Specifies whether to use torchrun for distributed training.
295
+ Defaults to ``False``.
296
+
297
+ use_mpirun (bool): Specifies whether to use mpirun for distributed training.
298
+ Defaults to ``False``.
299
+
300
+ nproc_per_node (int): Optional. Specifies the number of processes per node for
301
+ distributed training. Defaults to ``None``.
302
+ If not provided, this is configured automatically based on the instance type.
303
+ """
304
+
305
+ def _remote(func):
306
+
307
+ if job_conda_env:
308
+ RemoteExecutor._validate_env_name(job_conda_env)
309
+
310
+ job_settings = _JobSettings(
311
+ dependencies=dependencies,
312
+ pre_execution_commands=pre_execution_commands,
313
+ pre_execution_script=pre_execution_script,
314
+ environment_variables=environment_variables,
315
+ image_uri=image_uri,
316
+ include_local_workdir=include_local_workdir,
317
+ custom_file_filter=custom_file_filter,
318
+ instance_count=instance_count,
319
+ instance_type=instance_type,
320
+ job_conda_env=job_conda_env,
321
+ job_name_prefix=job_name_prefix,
322
+ keep_alive_period_in_seconds=keep_alive_period_in_seconds,
323
+ max_retry_attempts=max_retry_attempts,
324
+ max_runtime_in_seconds=max_runtime_in_seconds,
325
+ role=role,
326
+ s3_kms_key=s3_kms_key,
327
+ s3_root_uri=s3_root_uri,
328
+ sagemaker_session=sagemaker_session,
329
+ security_group_ids=security_group_ids,
330
+ subnets=subnets,
331
+ tags=tags,
332
+ volume_kms_key=volume_kms_key,
333
+ volume_size=volume_size,
334
+ encrypt_inter_container_traffic=encrypt_inter_container_traffic,
335
+ spark_config=spark_config,
336
+ use_spot_instances=use_spot_instances,
337
+ max_wait_time_in_seconds=max_wait_time_in_seconds,
338
+ disable_output_compression=disable_output_compression,
339
+ use_torchrun=use_torchrun,
340
+ use_mpirun=use_mpirun,
341
+ nproc_per_node=nproc_per_node,
342
+ )
343
+
344
+ @functools.wraps(func)
345
+ def wrapper(*args, **kwargs):
346
+
347
+ if instance_count > 1 and not (
348
+ (spark_config is not None and not use_torchrun and not use_mpirun)
349
+ or (spark_config is None and use_torchrun and not use_mpirun)
350
+ or (spark_config is None and not use_torchrun and use_mpirun)
351
+ ):
352
+ raise ValueError(
353
+ "Remote function do not support training on multi instances "
354
+ + "without spark_config or use_torchrun or use_mpirun. "
355
+ + "Please provide instance_count = 1"
356
+ )
357
+
358
+ RemoteExecutor._validate_submit_args(func, *args, **kwargs)
359
+
360
+ job = _Job.start(job_settings, func, args, kwargs)
361
+
362
+ try:
363
+ job.wait()
364
+ except UnexpectedStatusException as usex:
365
+ if usex.actual_status == "Failed":
366
+ try:
367
+ exception = serialization.deserialize_exception_from_s3(
368
+ sagemaker_session=job_settings.sagemaker_session,
369
+ s3_uri=s3_path_join(
370
+ job_settings.s3_root_uri, job.job_name, EXCEPTION_FOLDER
371
+ ),
372
+
373
+ )
374
+ except ServiceError as serr:
375
+ chained_e = serr.__cause__
376
+ if (
377
+ isinstance(chained_e, ClientError)
378
+ and chained_e.response["Error"]["Code"] # pylint: disable=no-member
379
+ == "404"
380
+ and chained_e.response["Error"]["Message"] # pylint: disable=no-member
381
+ == "Not Found"
382
+ ):
383
+ describe_result = job.describe()
384
+ if (
385
+ "FailureReason" in describe_result
386
+ and describe_result["FailureReason"]
387
+ and "RuntimeEnvironmentError: " in describe_result["FailureReason"]
388
+ ):
389
+ failure_msg = describe_result["FailureReason"].replace(
390
+ "RuntimeEnvironmentError: ", ""
391
+ )
392
+ raise RuntimeEnvironmentError(failure_msg)
393
+ raise RemoteFunctionError(
394
+ "Failed to execute remote function. "
395
+ + "Check corresponding job for details."
396
+ )
397
+ raise serr
398
+
399
+ raise exception
400
+
401
+ raise TimeoutError(
402
+ "Job for remote function timed out before reaching a termination status."
403
+ )
404
+
405
+ if job.describe()["TrainingJobStatus"] == "Completed":
406
+ return serialization.deserialize_obj_from_s3(
407
+ sagemaker_session=job_settings.sagemaker_session,
408
+ s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, RESULTS_FOLDER),
409
+
410
+ )
411
+
412
+ if job.describe()["TrainingJobStatus"] == "Stopped":
413
+ raise RemoteFunctionError("Job for remote function has been aborted.")
414
+
415
+ return None
416
+
417
+ wrapper.job_settings = job_settings
418
+ wrapper.wrapped_func = func
419
+ return wrapper
420
+
421
+ if _func is None:
422
+ return _remote
423
+ return _remote(_func)
424
+
425
+
426
class _SubmitRequest:
    """Record of one queued remote-function invocation waiting to become a job.

    Instances only store the constructor arguments; the function's call
    arguments are exposed as ``args`` and ``kwargs``.
    """

    def __init__(
        self, future, job_settings: _JobSettings, func, func_args, func_kwargs, run_info=None
    ):
        # What to run and with which arguments.
        self.func, self.args, self.kwargs = func, func_args, func_kwargs
        # Where the outcome is reported and how the job is configured.
        self.future, self.job_settings = future, job_settings
        # Optional run information captured at submit time (``None`` if absent).
        self.run_info = run_info
438
+
439
+
440
def _submit_worker(executor):
    """Background loop that turns queued submit requests into running jobs.

    The loop sleeps on the executor's condition until the pending queue is
    non-empty and fewer than ``max_parallel_jobs`` jobs are tracked. It stops
    when the ``None`` sentinel appended by ``RemoteExecutor.shutdown`` reaches
    the head of the queue. Any unexpected exception is logged and ends the loop.

    Args:
        executor: The ``RemoteExecutor`` whose queue is being drained.
    """
    condition = executor._state_condition
    pending = executor._pending_request_queue

    def _slot_available():
        # Something is queued AND the parallelism cap has not been reached.
        if len(pending) == 0:
            return False
        return len(executor._running_jobs) < executor.max_parallel_jobs

    try:
        while True:
            with condition:
                condition.wait_for(_slot_available)
                # Peek only: the entry stays queued until its job has been started.
                head = pending[0]

            if head is None:
                with condition:
                    # remove the anchor from the pending queue
                    pending.popleft()
                    return

            # Space out job submissions by the configured interval.
            time.sleep(_API_CALL_LIMIT["SubmittingIntervalInSecs"])
            # submit a new job
            started_job = head.future._start_and_notify(
                head.job_settings, head.func, head.args, head.kwargs, head.run_info
            )

            with condition:
                if started_job:
                    executor._running_jobs[started_job.job_name] = started_job
                # remove the request from the pending queue
                pending.popleft()
    except Exception:  # pylint: disable=broad-except
        logger.exception("Error occurred while submitting CreateTrainingJob requests.")
474
+
475
+
476
def _polling_worker(executor):
    """Background loop that drops finished jobs from the executor's running set.

    Each pass sleeps for the remainder of the minimum batch interval, then
    describes every tracked job, sleeping for the per-call interval before each
    call. Jobs whose status is ``Completed``, ``Failed`` or ``Stopped`` are
    removed and waiters on the executor's condition are notified. The loop ends
    once the executor is shut down and has no running or pending work.

    Args:
        executor: The ``RemoteExecutor`` whose jobs are being monitored.
    """
    terminal_statuses = ("Completed", "Failed", "Stopped")

    def _stop_tracking(name):
        # Removing a job frees a slot, so wake anything waiting on the condition.
        with executor._state_condition:
            del executor._running_jobs[name]
            executor._state_condition.notify_all()

    try:
        while True:
            with executor._state_condition:
                if executor._shutdown:
                    outstanding = len(executor._running_jobs) + len(
                        executor._pending_request_queue
                    )
                    if outstanding == 0:
                        return

            # The per-job sleeps below count towards the minimum batch interval.
            batch_pause = (
                _API_CALL_LIMIT["MinBatchPollingIntervalInSecs"]
                - len(executor._running_jobs) * _API_CALL_LIMIT["PollingIntervalInSecs"]
            )
            time.sleep(max(batch_pause, 0))

            # check if running jobs are terminated
            for job_name in list(executor._running_jobs.keys()):
                try:
                    time.sleep(_API_CALL_LIMIT["PollingIntervalInSecs"])
                    status = executor._running_jobs[job_name].describe()["TrainingJobStatus"]
                    if status in terminal_statuses:
                        _stop_tracking(job_name)
                except Exception as e:  # pylint: disable=broad-except
                    throttled = (
                        isinstance(e, ClientError)
                        and e.response["Error"]["Code"]  # pylint: disable=no-member
                        == "LimitExceededException"
                    )
                    if not throttled:
                        # Couldn't check the job status, move on
                        logger.exception(
                            "Error occurred while checking the status of job %s", job_name
                        )
                        _stop_tracking(job_name)
    except Exception:  # pylint: disable=broad-except
        logger.exception("Error occurred while monitoring the job statuses.")
522
+
523
+
524
+ class RemoteExecutor(object):
525
+ """Run Python functions asynchronously as SageMaker jobs"""
526
+
527
+ def __init__(
528
+ self,
529
+ *,
530
+ dependencies: str = None,
531
+ pre_execution_commands: List[str] = None,
532
+ pre_execution_script: str = None,
533
+ environment_variables: Dict[str, str] = None,
534
+ image_uri: str = None,
535
+ include_local_workdir: bool = None,
536
+ custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None,
537
+ instance_count: int = 1,
538
+ instance_type: str = None,
539
+ job_conda_env: str = None,
540
+ job_name_prefix: str = None,
541
+ keep_alive_period_in_seconds: int = 0,
542
+ max_parallel_jobs: int = 1,
543
+ max_retry_attempts: int = 1,
544
+ max_runtime_in_seconds: int = 24 * 60 * 60,
545
+ role: str = None,
546
+ s3_kms_key: str = None,
547
+ s3_root_uri: str = None,
548
+ sagemaker_session: Session = None,
549
+ security_group_ids: List[str] = None,
550
+ subnets: List[str] = None,
551
+ tags: List[Tuple[str, str]] = None,
552
+ volume_kms_key: str = None,
553
+ volume_size: int = 30,
554
+ encrypt_inter_container_traffic: bool = None,
555
+ spark_config: SparkConfig = None,
556
+ use_spot_instances=False,
557
+ max_wait_time_in_seconds=None,
558
+ disable_output_compression: bool = False,
559
+ use_torchrun: bool = False,
560
+ use_mpirun: bool = False,
561
+ nproc_per_node: Optional[int] = None,
562
+ ):
563
+ """Constructor for RemoteExecutor
564
+
565
+ If a parameter value is not set, the constructor first looks up the value from the
566
+ SageMaker configuration file. If no value is specified in the configuration file or
567
+ no configuration file is found, the constructor selects the default as specified below.
568
+ For more information, see `Configuring and using defaults with the SageMaker Python SDK
569
+ <https://sagemaker.readthedocs.io/en/stable/overview.html
570
+ #configuring-and-using-defaults-with-the-sagemaker-python-sdk>`_.
571
+
572
+ Args:
573
+ _func (Optional): A Python function to run as a SageMaker training job.
574
+
575
+ dependencies (str): Either the path to a dependencies file or the reserved keyword
576
+ ``auto_capture``. Defaults to ``None``.
577
+ If ``dependencies`` is provided, the value must be one of the following:
578
+
579
+ * A path to a conda environment.yml file. The following conditions apply.
580
+
581
+ * If job_conda_env is set, then the conda environment is updated by installing
582
+ dependencies from the yaml file and the function is invoked within that
583
+ conda environment. For this to succeed, the specified conda environment must
584
+ already exist in the image.
585
+ * If the environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image, then
586
+ the conda environment is updated by installing dependencies from the yaml file and
587
+ the function is invoked within that conda environment. For this to succeed, the
588
+ conda environment name must already be set in ``SAGEMAKER_JOB_CONDA_ENV``, and
589
+ ``SAGEMAKER_JOB_CONDA_ENV`` must already exist in the image.
590
+ * If none of the previous conditions are met, a new conda environment named
591
+ ``sagemaker-runtime-env`` is created and the function annotated with the remote
592
+ decorator is invoked in that conda environment.
593
+
594
+ * A path to a requirements.txt file. The following conditions apply.
595
+
596
+ * If ``job_conda_env`` is set in the remote decorator, dependencies are installed
597
+ within that conda environment and the function annotated with the remote decorator
598
+ is invoked in the same conda environment. For this to succeed, the specified
599
+ conda environment must already exist in the image.
600
+ * If an environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image,
601
+ dependencies are installed within that conda environment and the function annotated
602
+ with the remote decorator is invoked in the same. For this to succeed, the
603
+ conda environment name must already be set in ``SAGEMAKER_JOB_CONDA_ENV``, and
604
+ ``SAGEMAKER_JOB_CONDA_ENV`` must already exist in the image.
605
+ * If none of the above conditions are met, conda is not used. Dependencies are
606
+ installed at the system level, without any virtual environment, and the function
607
+ annotated with the remote decorator is invoked using the Python runtime available
608
+ in the system path.
609
+
610
+ * The parameter dependencies is set to ``auto_capture``. SageMaker will automatically
611
+ generate an env_snapshot.yml corresponding to the current active conda environment’s
612
+ snapshot. You do not need to provide a dependencies file. The following conditions
613
+ apply:
614
+
615
+ * You must run the remote function within an active conda environment.
616
+ * When installing the dependencies on the training job, the same conditions as when
617
+ dependencies is set to a path to a conda environment file apply. These conditions
618
+ are as follows:
619
+
620
+ * If job_conda_env is set, then the conda environment is updated by installing
621
+ dependencies from the yaml file and the function is invoked within that
622
+ conda environment. For this to succeed, the specified conda environment must
623
+ already exist in the image.
624
+ * If the environment variable ``SAGEMAKER_JOB_CONDA_ENV`` is set in the image,
625
+ then the conda environment is updated by installing dependencies from the yaml
626
+ file and the function is invoked within that conda environment. For this to
627
+ succeed, the conda environment name must already be set in
628
+ ``SAGEMAKER_JOB_CONDA_ENV``, and ``SAGEMAKER_JOB_CONDA_ENV`` must already exist
629
+ in the image.
630
+ * If none of the previous conditions are met, a new conda environment with name
631
+ ``sagemaker-runtime-env`` is created and the function annotated with the
632
+ remote decorator is invoked in that conda environment.
633
+
634
+ * ``None``. SageMaker will assume that there are no dependencies to install while
635
+ executing the remote annotated function in the training job.
636
+
637
+ pre_execution_commands (List[str]): List of commands to be executed prior to executing
638
+ remote function. Only one of ``pre_execution_commands`` or ``pre_execution_script``
639
+ can be specified at the same time. Defaults to None.
640
+
641
+ pre_execution_script (str): Path to script file to be executed prior to executing
642
+ remote function. Only one of ``pre_execution_commands`` or ``pre_execution_script``
643
+ can be specified at the same time. Defaults to None.
644
+
645
+ environment_variables (Dict): The environment variables used inside the decorator
646
+ function. Defaults to ``None``.
647
+
648
+ image_uri (str): The universal resource identifier (URI) location of a Docker image on
649
+ Amazon Elastic Container Registry (ECR). Defaults to the following based on where the
650
+ SDK is running:
651
+
652
+ * For users who specify ``spark_config`` and want to run the function in a Spark
653
+ application, the ``image_uri`` should be ``None``. A SageMaker Spark image will
654
+ be used for training, otherwise a ``ValueError`` is thrown.
655
+ * For users on SageMaker Studio notebooks, the image used as the kernel image for
656
+ the notebook is used.
657
+ * For other users, it is resolved to base python image with the same python
658
+ version as the environment running the local code.
659
+
660
+ If no compatible image is found, a ValueError is thrown.
661
+
662
+ include_local_workdir (bool): A flag to indicate that the remote function should include
663
+ local directories. Set to ``True`` if the remote function code imports local modules
664
+ and methods that are not available via PyPI or conda. Default value is ``False``.
665
+
666
+ custom_file_filter (Callable[[str, List], List], CustomFileFilter): Either a function
667
+ that filters job dependencies to be uploaded to S3 or a ``CustomFileFilter`` object
668
+ that specifies the local directories and files to be included in the remote function.
669
+ If a callable is passed in, that function is passed to the ``ignore`` argument of
670
+ ``shutil.copytree``. Defaults to ``None``, which means only python
671
+ files are accepted and uploaded to S3.
672
+
673
+ instance_count (int): The number of instances to use. Defaults to 1.
674
+ NOTE: Remote function supports instance_count > 1 for Spark jobs, torchrun and
675
+ mpirun utilities
676
+
677
+ instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to run
678
+ the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown.
679
+
680
+ job_conda_env (str): The name of the conda environment to activate during job's runtime.
681
+ Defaults to ``None``.
682
+
683
+ job_name_prefix (str): The prefix used used to create the underlying SageMaker job.
684
+
685
+ keep_alive_period_in_seconds (int): The duration in seconds to retain and reuse
686
+ provisioned infrastructure after the completion of a training job, also known as
687
+ SageMaker managed warm pools. The use of warmpools reduces the latency time spent to
688
+ provision new resources. The default value for ``keep_alive_period_in_seconds`` is 0.
689
+ NOTE: Additional charges associated with warm pools may apply. Using this parameter
690
+ also activates a new pesistent cache feature, which will further reduce job start
691
+ up latency than over using SageMaker managed warm pools alone by caching the package
692
+ source downloaded in the previous runs.
693
+
694
+ max_parallel_jobs (int): Maximum number of jobs that run in parallel. Defaults to 1.
695
+
696
+ max_retry_attempts (int): The max number of times the job is retried on
697
+ ``InternalServerFailure`` Error from SageMaker service. Defaults to 1.
698
+
699
+ max_runtime_in_seconds (int): The upper limit in seconds to be used for training. After
700
+ this specified amount of time, SageMaker terminates the job regardless of its current
701
+ status. Defaults to 1 day or (86400 seconds).
702
+
703
+ role (str): The IAM role (either name or full ARN) used to run your SageMaker training
704
+ job. Defaults to:
705
+
706
+ * the SageMaker default IAM role if the SDK is running in SageMaker Notebooks or
707
+ SageMaker Studio Notebooks.
708
+ * if not above, a ValueError is be thrown.
709
+
710
+ s3_kms_key (str): The key used to encrypt the input and output data.
711
+ Default to ``None``.
712
+
713
+ s3_root_uri (str): The root S3 folder to which the code archives and data are
714
+ uploaded to. Defaults to ``s3://<sagemaker-default-bucket>``.
715
+
716
+ sagemaker_session (sagemaker.core.helper.session.Session): The underlying SageMaker session to which
717
+ SageMaker service calls are delegated to (default: None). If not provided, one is
718
+ created using a default configuration chain.
719
+
720
+ security_group_ids (List[str): A list of security group IDs. Defaults to ``None`` and
721
+ the training job is created without VPC config.
722
+
723
+ subnets (List[str): A list of subnet IDs. Defaults to ``None`` and the job is
724
+ created without VPC config.
725
+
726
+ tags (List[Tuple[str, str]): A list of tags attached to the job. Defaults to ``None``
727
+ and the training job is created without tags.
728
+
729
+ volume_kms_key (str): An Amazon Key Management Service (KMS) key used to encrypt an
730
+ Amazon Elastic Block Storage (EBS) volume attached to the training instance.
731
+ Defaults to ``None``.
732
+
733
+ volume_size (int): The size in GB of the storage volume for storing input and output
734
+ data during training. Defaults to ``30``.
735
+
736
+ encrypt_inter_container_traffic (bool): A flag that specifies whether traffic between
737
+ training containers is encrypted for the training job. Defaults to ``False``.
738
+
739
+ spark_config (SparkConfig): Configurations to the Spark application that runs on
740
+ Spark image. If ``spark_config`` is specified, a SageMaker Spark image uri
741
+ will be used for training. Note that ``image_uri`` can not be specified at the
742
+ same time otherwise a ``ValueError`` is thrown. Defaults to ``None``.
743
+
744
+ use_spot_instances (bool): Specifies whether to use SageMaker Managed Spot instances for
745
+ training. If enabled then the ``max_wait_time_in_seconds`` arg should also be set.
746
+ Defaults to ``False``.
747
+
748
+ max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
749
+ After this amount of time Amazon SageMaker will stop waiting for managed spot training
750
+ job to complete. Defaults to ``None``.
751
+
752
+ disable_output_compression (bool): Optional. When set to true, Model is uploaded to
753
+ Amazon S3 without compression after training finishes.
754
+
755
+ use_torchrun (bool): Specifies whether to use torchrun for distributed training.
756
+ Defaults to ``False``.
757
+
758
+ use_mpirun (bool): Specifies whether to use mpirun for distributed training.
759
+ Defaults to ``False``.
760
+
761
+ nproc_per_node (int): Optional. Specifies the number of processes per node for
762
+ distributed training. Defaults to ``None``.
763
+ This is defined automatically configured on the instance type.
764
+ """
765
+ self.max_parallel_jobs = max_parallel_jobs
766
+
767
+ if self.max_parallel_jobs <= 0:
768
+ raise ValueError("max_parallel_jobs must be greater than 0.")
769
+
770
+ if instance_count > 1 and not (
771
+ (spark_config is not None and not use_torchrun and not use_mpirun)
772
+ or (spark_config is None and use_torchrun and not use_mpirun)
773
+ or (spark_config is None and not use_torchrun and use_mpirun)
774
+ ):
775
+ raise ValueError(
776
+ "Remote function do not support training on multi instances "
777
+ + "without spark_config or use_torchrun or use_mpirun. "
778
+ + "Please provide instance_count = 1"
779
+ )
780
+
781
+ if job_conda_env:
782
+ self._validate_env_name(job_conda_env)
783
+
784
+ self.job_settings = _JobSettings(
785
+ dependencies=dependencies,
786
+ pre_execution_commands=pre_execution_commands,
787
+ pre_execution_script=pre_execution_script,
788
+ environment_variables=environment_variables,
789
+ image_uri=image_uri,
790
+ include_local_workdir=include_local_workdir,
791
+ custom_file_filter=custom_file_filter,
792
+ instance_count=instance_count,
793
+ instance_type=instance_type,
794
+ job_conda_env=job_conda_env,
795
+ job_name_prefix=job_name_prefix,
796
+ keep_alive_period_in_seconds=keep_alive_period_in_seconds,
797
+ max_retry_attempts=max_retry_attempts,
798
+ max_runtime_in_seconds=max_runtime_in_seconds,
799
+ role=role,
800
+ s3_kms_key=s3_kms_key,
801
+ s3_root_uri=s3_root_uri,
802
+ sagemaker_session=sagemaker_session,
803
+ security_group_ids=security_group_ids,
804
+ subnets=subnets,
805
+ tags=tags,
806
+ volume_kms_key=volume_kms_key,
807
+ volume_size=volume_size,
808
+ encrypt_inter_container_traffic=encrypt_inter_container_traffic,
809
+ spark_config=spark_config,
810
+ use_spot_instances=use_spot_instances,
811
+ max_wait_time_in_seconds=max_wait_time_in_seconds,
812
+ disable_output_compression=disable_output_compression,
813
+ use_torchrun=use_torchrun,
814
+ use_mpirun=use_mpirun,
815
+ nproc_per_node=nproc_per_node,
816
+ )
817
+
818
+ self._state_condition = threading.Condition()
819
+ self._pending_request_queue = deque()
820
+ # For thread safety, see
821
+ # https://web.archive.org/web/20201108091210/http://effbot.org/pyfaq/what-kinds-of-global-value-mutation-are-thread-safe.htm
822
+ self._running_jobs = dict()
823
+ self._shutdown = False
824
+
825
+ self._workers: ThreadPoolExecutor = None
826
+
827
+ def submit(self, func, *args, **kwargs):
828
+ """Execute the input function as a SageMaker job asynchronously.
829
+
830
+ Args:
831
+ func: Python function to run as a SageMaker job.
832
+ *args: Positional arguments to the input function.
833
+ **kwargs: keyword arguments to the input function
834
+ """
835
+ if self._shutdown:
836
+ raise RuntimeError("Cannot schedule new remote function executions after shutdown")
837
+
838
+ self._validate_submit_args(func, *args, **kwargs)
839
+
840
+ with self._state_condition:
841
+ future = Future()
842
+
843
+ run_info = None
844
+ if _RunContext.get_current_run() is not None:
845
+ run = _RunContext.get_current_run()
846
+ run_info = _RunInfo(run.experiment_name, run.run_name)
847
+
848
+ self._pending_request_queue.append(
849
+ _SubmitRequest(future, self.job_settings, func, args, kwargs, run_info)
850
+ )
851
+
852
+ if self._workers is None:
853
+ self._workers = ThreadPoolExecutor(2)
854
+ self._workers.submit(_submit_worker, self)
855
+ self._workers.submit(_polling_worker, self)
856
+
857
+ self._state_condition.notify_all()
858
+
859
+ return future
860
+
861
+ def map(self, func, *iterables):
862
+ """Return an iterator that applies function to every item of iterable, yielding the results.
863
+
864
+ If additional iterables arguments are passed, function must take that many arguments and
865
+ is applied to the items from all iterables in parallel. With multiple iterables, the
866
+ iterator stops when the shortest iterable is exhausted.
867
+
868
+ Args:
869
+ func: Python function to run as a SageMaker job.
870
+ iterables: Arguments of the input python function.
871
+ """
872
+
873
+ futures = map(self.submit, itertools.repeat(func), *iterables)
874
+ return [future.result() for future in futures]
875
+
876
+ def shutdown(self):
877
+ """Prevent more function executions to be submitted to this executor."""
878
+ with self._state_condition:
879
+ self._shutdown = True
880
+
881
+ # give a signal to the submitting worker so that it doesn't block on empty queue forever
882
+ self._pending_request_queue.append(None)
883
+
884
+ self._state_condition.notify_all()
885
+
886
+ if self._workers is not None:
887
+ self._workers.shutdown(wait=True)
888
+
889
+ def __enter__(self):
890
+ """Create an executor instance and return it"""
891
+ return self
892
+
893
+ def __exit__(self, exc_type, exc_val, exc_tb):
894
+ """Make sure the executor instance is shutdown."""
895
+ self.shutdown()
896
+ return False
897
+
898
+ @staticmethod
899
+ def _validate_submit_args(func, *args, **kwargs):
900
+ """Validates input args passed to submit method."""
901
+
902
+ full_arg_spec = inspect.getfullargspec(func)
903
+
904
+ # args related validations
905
+
906
+ is_accepting_variable_positional_args = full_arg_spec.varargs is not None
907
+ num_default_positional_args = len(full_arg_spec.defaults) if full_arg_spec.defaults else 0
908
+ minimum_num_expected_positional_args = len(full_arg_spec.args) - num_default_positional_args
909
+
910
+ if not is_accepting_variable_positional_args and len(args) > len(full_arg_spec.args):
911
+ raise TypeError(
912
+ f"{func.__name__}() takes {len(full_arg_spec.args)} positional "
913
+ + f"{'arguments' if len(full_arg_spec.args) > 1 else 'argument'} but {len(args)} "
914
+ + f"{'were' if len(args) > 1 else 'was'} given."
915
+ )
916
+
917
+ if len(args) < minimum_num_expected_positional_args:
918
+ missing_positional_args = full_arg_spec.args[
919
+ len(args) : minimum_num_expected_positional_args
920
+ ]
921
+ missing_args = list(filter(lambda arg: arg not in kwargs, missing_positional_args))
922
+ if missing_args:
923
+ missing_args_str = (
924
+ ", ".join(map(lambda x: f"'{x}'", missing_args[:-1]))
925
+ + f", and '{missing_args[-1]}'"
926
+ if len(missing_args) > 1
927
+ else f"'{missing_args[0]}'"
928
+ )
929
+ raise TypeError(
930
+ f"{func.__name__}() missing {len(missing_args)} required positional "
931
+ + f"{'arguments' if len(missing_args) > 1 else 'argument'}: {missing_args_str}"
932
+ )
933
+
934
+ # kwargs related validations
935
+
936
+ for k in kwargs:
937
+ if k in full_arg_spec.args and len(args) > full_arg_spec.args.index(k):
938
+ raise TypeError(f"{func.__name__}() got multiple values for argument '{k}'")
939
+ if k not in full_arg_spec.kwonlyargs and k not in full_arg_spec.args:
940
+ raise TypeError(f"{func.__name__}() got an unexpected keyword argument '{k}'")
941
+
942
+ missing_kwargs = [
943
+ k
944
+ for k in full_arg_spec.kwonlyargs
945
+ if k not in full_arg_spec.kwonlydefaults and k not in kwargs
946
+ ]
947
+ if missing_kwargs:
948
+ missing_kwargs_string = (
949
+ ", ".join(map(lambda x: f"'{x}'", missing_kwargs[:-1]))
950
+ + f", and '{missing_kwargs[-1]}'"
951
+ if len(missing_kwargs) > 1
952
+ else f"'{missing_kwargs[0]}'"
953
+ )
954
+
955
+ raise TypeError(
956
+ f"{func.__name__}() missing {len(missing_kwargs)} required keyword-only "
957
+ + f"{'arguments' if len(missing_kwargs) > 1 else 'argument'}: "
958
+ + f"{missing_kwargs_string}"
959
+ )
960
+
961
+ @staticmethod
962
+ def _validate_env_name(env_name: str) -> None:
963
+ """Validate conda environment name to prevent command injection.
964
+
965
+ Args:
966
+ env_name (str): The environment name to validate
967
+
968
+ Raises:
969
+ ValueError: If the environment name contains invalid characters
970
+ """
971
+
972
+ # Allow only alphanumeric, underscore, and hyphen
973
+ import re
974
+ if not re.match(r'^[a-zA-Z0-9_-]+$', env_name):
975
+ raise ValueError(
976
+ f"Invalid environment name '{env_name}'. "
977
+ "Only alphanumeric characters, underscores, and hyphens are allowed."
978
+ )
979
+
980
+
981
class Future(object):
    """Class representing a reference to a SageMaker job result.

    Reference to the SageMaker job created as a result of the remote function run. The job may
    or may not have finished running.

    An instance is always in one of the module-level states ``_PENDING``, ``_RUNNING``,
    ``_FINISHED`` or ``_CANCELLED``. State reads and writes are performed while holding an
    internal ``threading.Condition``.
    """

    def __init__(self):
        self._condition = threading.Condition()
        self._state = _PENDING
        self._job = None
        self._exception = None
        self._return = None

    @staticmethod
    def _exception_from_service_error(service_error, describe_training_job_response):
        """Map a ``ServiceError`` hit while loading the remote exception to the error to report.

        Args:
            service_error: The ``ServiceError`` raised while reading the exception from S3.
            describe_training_job_response (dict): Describe response of the failed job.

        Returns:
            ``service_error`` itself unless its cause is a 404 "Not Found" ``ClientError``.
            In that case a ``RuntimeEnvironmentError`` is returned when the job's
            ``FailureReason`` carries the "RuntimeEnvironmentError: " marker, and a generic
            ``RemoteFunctionError`` otherwise.
        """
        chained_e = service_error.__cause__
        exception_not_in_s3 = (
            isinstance(chained_e, ClientError)
            and chained_e.response["Error"]["Code"] == "404"  # pylint: disable=no-member
            and chained_e.response["Error"]["Message"]  # pylint: disable=no-member
            == "Not Found"
        )
        if not exception_not_in_s3:
            return service_error

        failure_reason = describe_training_job_response.get("FailureReason")
        if failure_reason and "RuntimeEnvironmentError: " in failure_reason:
            return RuntimeEnvironmentError(
                failure_reason.replace("RuntimeEnvironmentError: ", "")
            )
        return RemoteFunctionError(
            "Failed to execute remote function. " + "Check corresponding job for details."
        )

    @staticmethod
    def from_describe_response(describe_training_job_response, sagemaker_session):
        """Construct a Future from a describe_training_job_response object.

        Args:
            describe_training_job_response (dict): Response describing the training job; its
                ``TrainingJobStatus`` entry decides the state of the returned future.
            sagemaker_session: Session used to build the job reference and to read the
                job's result or exception from S3.

        Returns:
            A ``Future`` whose state, job, return value and exception reflect the response.
            Errors hit while loading the result or exception are stored on the future
            instead of being raised here.
        """
        future = Future()
        job_exception = None
        client_exception = None
        job_return = None
        job = _Job.from_describe_response(describe_training_job_response, sagemaker_session)
        status = describe_training_job_response["TrainingJobStatus"]

        if status in ["Stopping", "Stopped"]:
            state = _CANCELLED
        elif status == "Completed":
            state = _FINISHED
            try:
                job_return = serialization.deserialize_obj_from_s3(
                    sagemaker_session=sagemaker_session,
                    s3_uri=s3_path_join(job.s3_uri, RESULTS_FOLDER),
                )
            except (DeserializationError, ServiceError) as e:
                client_exception = e
        elif status == "Failed":
            state = _FINISHED
            try:
                job_exception = serialization.deserialize_exception_from_s3(
                    sagemaker_session=sagemaker_session,
                    s3_uri=s3_path_join(job.s3_uri, EXCEPTION_FOLDER),
                )
            except ServiceError as serr:
                job_exception = Future._exception_from_service_error(
                    serr, describe_training_job_response
                )
            except DeserializationError as e:
                client_exception = e
        else:
            state = _RUNNING

        future._job = job
        future._state = state
        future._exception = job_exception or client_exception
        future._return = job_return
        return future

    def _start_and_notify(
        self, job_settings: _JobSettings, func, func_args, func_kwargs, run_info=None
    ):
        """Start and record the newly created job in the future object.

        The job is recorded if one is successfully started. Otherwise, the exception is
        recorded. The state update is broadcast to other waiting threads.

        Returns:
            The started job, or ``None`` when the future is no longer pending or the job
            could not be started.
        """
        with self._condition:
            if self._state != _PENDING:
                return None

            try:
                self._job = _Job.start(job_settings, func, func_args, func_kwargs, run_info)
            except (Exception,) as e:  # pylint: disable=broad-except
                self._exception = e
                self._state = _FINISHED
                self._condition.notify_all()
                return None

            self._state = _RUNNING
            self._condition.notify_all()
            return self._job

    def result(self, timeout: float = None) -> Any:
        """Returns the SageMaker job result.

        This method waits for the SageMaker job created from the remote function execution to
        complete for up to the timeout value (if specified). If timeout is ``None``,
        this method will wait until the SageMaker job completes.

        Args:
            timeout (float): Timeout in seconds to wait until the job is completed. ``None`` by
                default.

        Returns:
            The Python object returned by the remote function. ``None`` is returned when the
            future was already in the cancelled state before this call.

        Raises:
            RuntimeError: If the future is still pending after waiting.
            RemoteFunctionError: If the job is found to be stopped.
            TimeoutError: If the job has not reached a terminal status after waiting.
            Exception: The exception recorded for a finished job, if any.
        """
        try:
            self.wait(timeout)
        except UnexpectedStatusException:
            # Deliberately ignored: the job status is read again below and mapped to
            # the appropriate outcome.
            pass

        with self._condition:
            if self._state == _PENDING:
                raise RuntimeError("The job for the remote function has not been started.")

            if self._state == _RUNNING:
                # Describe once so that every branch below works on the same snapshot.
                describe_response = self._job.describe()
                status = describe_response["TrainingJobStatus"]

                if status == "Completed":
                    self._return = serialization.deserialize_obj_from_s3(
                        sagemaker_session=self._job.sagemaker_session,
                        s3_uri=s3_path_join(self._job.s3_uri, RESULTS_FOLDER),
                    )
                    self._state = _FINISHED
                    return self._return
                if status == "Failed":
                    try:
                        self._exception = serialization.deserialize_exception_from_s3(
                            sagemaker_session=self._job.sagemaker_session,
                            s3_uri=s3_path_join(self._job.s3_uri, EXCEPTION_FOLDER),
                        )
                    except ServiceError as serr:
                        self._exception = self._exception_from_service_error(
                            serr, describe_response
                        )
                    self._state = _FINISHED
                elif status == "Stopped":
                    self._state = _CANCELLED
                    raise RemoteFunctionError("Job for remote function has been aborted.")
                else:
                    raise TimeoutError(
                        "Job for remote function timed out before reaching a termination status."
                    )

            if self._state == _FINISHED:
                if self._exception:
                    raise self._exception
                return self._return

            return None

    def wait(
        self,
        timeout: int = None,
    ) -> None:
        """Wait for the underlying SageMaker job to complete.

        This method waits for the SageMaker job created as a result of the remote function run
        to complete for up to the timeout value (if specified). If timeout is ``None``, this method
        will block until the job is completed.

        The same ``timeout`` is applied first to waiting for a pending future to leave the
        pending state and then to waiting on the job itself.

        Args:
            timeout (int): Timeout in seconds to wait until the job is completed before it is
                stopped. Defaults to ``None``.

        Returns:
            None
        """

        with self._condition:
            if self._state == _PENDING:
                self._condition.wait(timeout=timeout)

            if self._state == _RUNNING:
                self._job.wait(timeout=timeout)

    def cancel(self) -> bool:
        """Cancel the function execution.

        This method prevents the SageMaker job being created or stops the underlying SageMaker job
        early if it is already in progress.

        Returns:
            ``True`` if the underlying SageMaker job created as a result of the remote function
            run is cancelled. ``False`` if the future had already finished.
        """
        with self._condition:
            if self._state == _FINISHED:
                return False
            if self._state == _CANCELLED:
                return True

            if self._job:
                self._job.stop()
            self._state = _CANCELLED
            return True

    def running(self) -> bool:
        """Check if the underlying SageMaker job is running.

        Returns:
            ``True`` if the underlying SageMaker job is still running. ``False``, otherwise.
        """
        with self._condition:
            return self._state == _RUNNING

    def cancelled(self) -> bool:
        """Check if the underlying SageMaker job was cancelled.

        Returns:
            ``True`` if the underlying SageMaker job was cancelled. ``False``, otherwise.
        """
        with self._condition:
            return self._state == _CANCELLED

    def done(self) -> bool:
        """Check if the underlying SageMaker job is finished.

        A running future is moved to the finished state when its job is described as
        "Completed" or "Failed".

        Returns:
            ``True`` if the underlying SageMaker job finished running. ``False``, otherwise.
        """
        with self._condition:
            if self._state == _RUNNING and self._job.describe()["TrainingJobStatus"] in [
                "Completed",
                "Failed",
            ]:
                self._state = _FINISHED
                return True

            if self._state == _FINISHED:
                return True

            return False
1248
+
1249
+
1250
def get_future(job_name, sagemaker_session=None) -> Future:
    """Build a future describing the job called ``job_name``.

    Args:
        job_name (str): name of the underlying SageMaker job created as a result of the remote
            function run.

        sagemaker_session (sagemaker.core.helper.session.Session): A session object that manages
            interactions with Amazon SageMaker APIs and any other AWS services needed. A new
            ``Session()`` is created when none is given.

    Returns:
        A `sagemaker.remote_function.client.Future` instance.
    """
    session = sagemaker_session if sagemaker_session else Session()
    response = session.sagemaker_client.describe_training_job(TrainingJobName=job_name)
    return Future.from_describe_response(response, session)
1269
+
1270
+
1271
def list_futures(job_name_prefix, sagemaker_session=None):
    """Yield a Future for each job matching the given job_name_prefix.

    Jobs are looked up page by page with ``list_training_jobs`` using a ``NameContains``
    filter, and every listed job is described individually before its future is yielded.

    Args:
        job_name_prefix (str): A prefix used to identify the SageMaker jobs associated with remote
            function run.
        sagemaker_session (sagemaker.core.helper.session.Session): A session object that manages
            interactions with Amazon SageMaker APIs and any other AWS services needed. A new
            ``Session()`` is created when none is given.

    Yields:
        A `sagemaker.remote_function.client.Future` instance.
    """
    if not sagemaker_session:
        sagemaker_session = Session()

    # Round-trip the prefix through name_from_base/base_from_name because the prefix might
    # have been trimmed while creating the job.
    name_filter = base_from_name(name_from_base(job_name_prefix))
    request = {"NameContains": name_filter}

    while True:
        page = sagemaker_session.sagemaker_client.list_training_jobs(**request)
        names_on_page = [summary["TrainingJobName"] for summary in page["TrainingJobSummaries"]]
        for name in names_on_page:
            described = sagemaker_session.sagemaker_client.describe_training_job(
                TrainingJobName=name
            )
            yield Future.from_describe_response(described, sagemaker_session)

        if "NextToken" not in page:
            break
        token = page["NextToken"]
        if token:
            request["NextToken"] = token