sagemaker-core 1.0.62__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
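The headline change in 2.1.1 is a repackaging: modules that previously shipped under the `sagemaker_core` top-level package (for example `sagemaker_core/main/resources.py`) now live under `sagemaker/core/`, as the brace-style rename entries in the file list below and the changed `top_level.txt` indicate. As a hedged sketch of what this means for downstream imports (the class names `TrainingJob` and `ResourceConfig` are illustrative, and the exact public re-export surface should be confirmed against the 2.1.1 package itself):

    # sagemaker-core 1.x layout (old side of the rename entries below):
    # from sagemaker_core.main.resources import TrainingJob
    # from sagemaker_core.main.shapes import ResourceConfig

    # sagemaker-core 2.x layout (new side of the renames, e.g.
    # {sagemaker_core/main -> sagemaker/core}/resources.py and
    # {sagemaker_core/main -> sagemaker/core/shapes}/shapes.py):
    from sagemaker.core.resources import TrainingJob
    from sagemaker.core.shapes.shapes import ResourceConfig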
Files changed (362)
  1. sagemaker/core/__init__.py +16 -0
  2. sagemaker/core/_studio.py +116 -0
  3. sagemaker/core/_version.py +11 -0
  4. sagemaker/core/accept_types.py +131 -0
  5. sagemaker/core/analytics.py +744 -0
  6. sagemaker/core/apiutils/__init__.py +13 -0
  7. sagemaker/core/apiutils/_base_types.py +228 -0
  8. sagemaker/core/apiutils/_boto_functions.py +130 -0
  9. sagemaker/core/apiutils/_utils.py +34 -0
  10. sagemaker/core/base_deserializers.py +35 -0
  11. sagemaker/core/base_serializers.py +35 -0
  12. sagemaker/core/clarify/__init__.py +2898 -0
  13. sagemaker/core/collection.py +467 -0
  14. sagemaker/core/common_utils.py +2281 -0
  15. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  16. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  17. sagemaker/core/config/__init__.py +181 -0
  18. sagemaker/core/config/config.py +238 -0
  19. sagemaker/core/config/config_manager.py +595 -0
  20. sagemaker/core/config/config_schema.py +1220 -0
  21. sagemaker/core/config/config_utils.py +297 -0
  22. {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
  23. sagemaker/core/constants.py +73 -0
  24. sagemaker/core/content_types.py +137 -0
  25. sagemaker/core/debugger/__init__.py +39 -0
  26. sagemaker/core/debugger/debugger.py +945 -0
  27. sagemaker/core/debugger/framework_profile.py +292 -0
  28. sagemaker/core/debugger/metrics_config.py +468 -0
  29. sagemaker/core/debugger/profiler.py +42 -0
  30. sagemaker/core/debugger/profiler_config.py +190 -0
  31. sagemaker/core/debugger/profiler_constants.py +40 -0
  32. sagemaker/core/debugger/utils.py +148 -0
  33. sagemaker/core/deprecations.py +254 -0
  34. sagemaker/core/deserializers/__init__.py +10 -0
  35. sagemaker/core/deserializers/base.py +424 -0
  36. sagemaker/core/deserializers/implementations.py +157 -0
  37. sagemaker/core/drift_check_baselines.py +106 -0
  38. sagemaker/core/enums.py +51 -0
  39. sagemaker/core/environment_variables.py +101 -0
  40. sagemaker/core/exceptions.py +108 -0
  41. sagemaker/core/experiments/__init__.py +53 -0
  42. sagemaker/core/experiments/_api_types.py +251 -0
  43. sagemaker/core/experiments/_environment.py +124 -0
  44. sagemaker/core/experiments/_helper.py +294 -0
  45. sagemaker/core/experiments/_metrics.py +333 -0
  46. sagemaker/core/experiments/_run_context.py +58 -0
  47. sagemaker/core/experiments/_utils.py +216 -0
  48. sagemaker/core/experiments/experiment.py +244 -0
  49. sagemaker/core/experiments/run.py +970 -0
  50. sagemaker/core/experiments/trial.py +296 -0
  51. sagemaker/core/experiments/trial_component.py +387 -0
  52. sagemaker/core/explainer/__init__.py +24 -0
  53. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  54. sagemaker/core/explainer/explainer_config.py +44 -0
  55. sagemaker/core/fw_utils.py +1176 -0
  56. sagemaker/core/git_utils.py +349 -0
  57. sagemaker/core/helper/pipeline_variable.py +82 -0
  58. sagemaker/core/helper/session_helper.py +2965 -0
  59. sagemaker/core/huggingface/__init__.py +29 -0
  60. sagemaker/core/huggingface/llm_utils.py +150 -0
  61. sagemaker/core/huggingface/processing.py +139 -0
  62. sagemaker/core/huggingface/training_compiler/config.py +167 -0
  63. sagemaker/core/hyperparameters.py +172 -0
  64. sagemaker/core/image_retriever/__init__.py +3 -0
  65. sagemaker/core/image_retriever/image_retriever.py +640 -0
  66. sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
  67. sagemaker/core/image_retriever/test.py +7 -0
  68. sagemaker/core/image_uri_config/__init__.py +13 -0
  69. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  70. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  71. sagemaker/core/image_uri_config/chainer.json +104 -0
  72. sagemaker/core/image_uri_config/clarify.json +39 -0
  73. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  74. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  75. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  76. sagemaker/core/image_uri_config/debugger.json +34 -0
  77. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  78. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  79. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  80. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  81. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  82. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  83. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  84. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  85. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
  86. sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
  87. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  88. sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
  89. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  90. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  91. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  92. sagemaker/core/image_uri_config/huggingface.json +2138 -0
  93. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  94. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  95. sagemaker/core/image_uri_config/image-classification.json +50 -0
  96. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  97. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  98. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  99. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  100. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  101. sagemaker/core/image_uri_config/kmeans.json +50 -0
  102. sagemaker/core/image_uri_config/knn.json +50 -0
  103. sagemaker/core/image_uri_config/lda.json +26 -0
  104. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  105. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  106. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  107. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  108. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  109. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  110. sagemaker/core/image_uri_config/ntm.json +50 -0
  111. sagemaker/core/image_uri_config/object-detection.json +50 -0
  112. sagemaker/core/image_uri_config/object2vec.json +50 -0
  113. sagemaker/core/image_uri_config/pca.json +50 -0
  114. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  115. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  116. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  117. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  118. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  119. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  120. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  121. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  122. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  123. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  124. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
  125. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  126. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  127. sagemaker/core/image_uri_config/sklearn.json +446 -0
  128. sagemaker/core/image_uri_config/spark.json +280 -0
  129. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  130. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  131. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  132. sagemaker/core/image_uri_config/vw.json +25 -0
  133. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  134. sagemaker/core/image_uri_config/xgboost.json +888 -0
  135. sagemaker/core/image_uris.py +810 -0
  136. sagemaker/core/inference_config.py +144 -0
  137. sagemaker/core/inference_recommender/__init__.py +18 -0
  138. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  139. sagemaker/core/inputs.py +366 -0
  140. sagemaker/core/instance_group.py +61 -0
  141. sagemaker/core/instance_types.py +164 -0
  142. sagemaker/core/instance_types_gpu_info.py +43 -0
  143. sagemaker/core/interactive_apps/__init__.py +41 -0
  144. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  145. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  146. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  147. sagemaker/core/iterators.py +186 -0
  148. sagemaker/core/job.py +380 -0
  149. sagemaker/core/jumpstart/__init__.py +156 -0
  150. sagemaker/core/jumpstart/accessors.py +390 -0
  151. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  152. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  153. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  154. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  155. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  156. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  157. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  158. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  159. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  160. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  161. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  162. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  163. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  164. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  165. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  166. sagemaker/core/jumpstart/cache.py +663 -0
  167. sagemaker/core/jumpstart/configs.py +50 -0
  168. sagemaker/core/jumpstart/constants.py +198 -0
  169. sagemaker/core/jumpstart/deserializers.py +81 -0
  170. sagemaker/core/jumpstart/document.py +76 -0
  171. sagemaker/core/jumpstart/enums.py +168 -0
  172. sagemaker/core/jumpstart/exceptions.py +236 -0
  173. sagemaker/core/jumpstart/factory/utils.py +833 -0
  174. sagemaker/core/jumpstart/filters.py +597 -0
  175. sagemaker/core/jumpstart/hub/constants.py +16 -0
  176. sagemaker/core/jumpstart/hub/hub.py +291 -0
  177. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  178. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  179. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  180. sagemaker/core/jumpstart/hub/types.py +35 -0
  181. sagemaker/core/jumpstart/hub/utils.py +260 -0
  182. sagemaker/core/jumpstart/models.py +499 -0
  183. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  184. sagemaker/core/jumpstart/parameters.py +20 -0
  185. sagemaker/core/jumpstart/payload_utils.py +239 -0
  186. sagemaker/core/jumpstart/region_config.json +163 -0
  187. sagemaker/core/jumpstart/search.py +171 -0
  188. sagemaker/core/jumpstart/serializers.py +81 -0
  189. sagemaker/core/jumpstart/session_utils.py +234 -0
  190. sagemaker/core/jumpstart/types.py +3044 -0
  191. sagemaker/core/jumpstart/utils.py +1731 -0
  192. sagemaker/core/jumpstart/validators.py +257 -0
  193. sagemaker/core/lambda_helper.py +312 -0
  194. sagemaker/core/lineage/__init__.py +42 -0
  195. sagemaker/core/lineage/_api_types.py +239 -0
  196. sagemaker/core/lineage/_utils.py +49 -0
  197. sagemaker/core/lineage/action.py +345 -0
  198. sagemaker/core/lineage/artifact.py +646 -0
  199. sagemaker/core/lineage/association.py +190 -0
  200. sagemaker/core/lineage/context.py +505 -0
  201. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  202. sagemaker/core/lineage/query.py +732 -0
  203. sagemaker/core/lineage/visualizer.py +346 -0
  204. sagemaker/core/local/__init__.py +18 -0
  205. sagemaker/core/local/data.py +413 -0
  206. sagemaker/core/local/entities.py +678 -0
  207. sagemaker/core/local/exceptions.py +17 -0
  208. sagemaker/core/local/image.py +1243 -0
  209. sagemaker/core/local/local_session.py +739 -0
  210. sagemaker/core/local/utils.py +245 -0
  211. sagemaker/core/logs.py +181 -0
  212. sagemaker/core/metadata_properties.py +56 -0
  213. sagemaker/core/metric_definitions.py +91 -0
  214. sagemaker/core/mlflow/__init__.py +38 -0
  215. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  216. sagemaker/core/model_card/__init__.py +26 -0
  217. sagemaker/core/model_life_cycle.py +51 -0
  218. sagemaker/core/model_metrics.py +160 -0
  219. sagemaker/core/model_monitor/__init__.py +66 -0
  220. sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
  221. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  222. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  223. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  224. sagemaker/core/model_monitor/dataset_format.py +102 -0
  225. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  226. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  227. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  228. sagemaker/core/model_monitor/utils.py +793 -0
  229. sagemaker/core/model_registry.py +480 -0
  230. sagemaker/core/model_uris.py +97 -0
  231. sagemaker/core/modules/__init__.py +19 -0
  232. sagemaker/core/modules/configs.py +226 -0
  233. sagemaker/core/modules/constants.py +37 -0
  234. sagemaker/core/modules/distributed.py +182 -0
  235. sagemaker/core/modules/local_core/__init__.py +0 -0
  236. sagemaker/core/modules/local_core/local_container.py +605 -0
  237. sagemaker/core/modules/templates.py +83 -0
  238. sagemaker/core/modules/train/__init__.py +14 -0
  239. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  240. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  241. sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
  242. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  243. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  244. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  245. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  246. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  247. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  248. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  249. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  250. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  251. sagemaker/core/modules/types.py +19 -0
  252. sagemaker/core/modules/utils.py +194 -0
  253. sagemaker/core/network.py +185 -0
  254. sagemaker/core/parameter.py +173 -0
  255. sagemaker/core/payloads.py +185 -0
  256. sagemaker/core/processing.py +1597 -0
  257. sagemaker/core/remote_function/__init__.py +19 -0
  258. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  259. sagemaker/core/remote_function/client.py +1285 -0
  260. sagemaker/core/remote_function/core/__init__.py +0 -0
  261. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  262. sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
  263. sagemaker/core/remote_function/core/serialization.py +422 -0
  264. sagemaker/core/remote_function/core/stored_function.py +226 -0
  265. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  266. sagemaker/core/remote_function/errors.py +104 -0
  267. sagemaker/core/remote_function/invoke_function.py +172 -0
  268. sagemaker/core/remote_function/job.py +2140 -0
  269. sagemaker/core/remote_function/logging_config.py +38 -0
  270. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  271. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  272. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  273. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  274. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  275. sagemaker/core/remote_function/spark_config.py +149 -0
  276. sagemaker/core/resource_requirements.py +168 -0
  277. {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
  278. sagemaker/core/s3/__init__.py +41 -0
  279. sagemaker/core/s3/client.py +367 -0
  280. sagemaker/core/s3/utils.py +175 -0
  281. sagemaker/core/script_uris.py +93 -0
  282. sagemaker/core/serializers/__init__.py +11 -0
  283. sagemaker/core/serializers/base.py +510 -0
  284. sagemaker/core/serializers/implementations.py +159 -0
  285. sagemaker/core/serializers/utils.py +223 -0
  286. sagemaker/core/serverless_inference_config.py +63 -0
  287. sagemaker/core/session_settings.py +55 -0
  288. sagemaker/core/shapes/__init__.py +3 -0
  289. sagemaker/core/shapes/model_card_shapes.py +159 -0
  290. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
  291. sagemaker/core/spark/__init__.py +16 -0
  292. sagemaker/core/spark/defaults.py +16 -0
  293. sagemaker/core/spark/processing.py +1380 -0
  294. sagemaker/core/telemetry/__init__.py +23 -0
  295. sagemaker/core/telemetry/constants.py +84 -0
  296. sagemaker/core/telemetry/telemetry_logging.py +284 -0
  297. sagemaker/core/tools/__init__.py +1 -0
  298. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  299. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  300. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  301. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  302. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  303. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  304. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  305. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  306. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  307. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  308. sagemaker/core/training/__init__.py +14 -0
  309. sagemaker/core/training/configs.py +333 -0
  310. sagemaker/core/training/constants.py +37 -0
  311. sagemaker/core/training/utils.py +77 -0
  312. sagemaker/core/training_compiler/__init__.py +16 -0
  313. sagemaker/core/training_compiler/config.py +197 -0
  314. sagemaker/core/training_compiler_config.py +197 -0
  315. sagemaker/core/transformer.py +793 -0
  316. sagemaker/core/user_agent.py +76 -0
  317. sagemaker/core/utilities/__init__.py +24 -0
  318. sagemaker/core/utilities/cache.py +169 -0
  319. sagemaker/core/utilities/search_expression.py +133 -0
  320. sagemaker/core/utils/__init__.py +48 -0
  321. sagemaker/core/utils/code_injection/__init__.py +0 -0
  322. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  323. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
  324. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  325. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  326. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  327. {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
  328. sagemaker/core/workflow/__init__.py +152 -0
  329. sagemaker/core/workflow/conditions.py +313 -0
  330. sagemaker/core/workflow/entities.py +58 -0
  331. sagemaker/core/workflow/execution_variables.py +89 -0
  332. sagemaker/core/workflow/functions.py +193 -0
  333. sagemaker/core/workflow/parameters.py +222 -0
  334. sagemaker/core/workflow/pipeline_context.py +394 -0
  335. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  336. sagemaker/core/workflow/properties.py +285 -0
  337. sagemaker/core/workflow/step_outputs.py +65 -0
  338. sagemaker/core/workflow/utilities.py +507 -0
  339. sagemaker/lineage/__init__.py +33 -0
  340. sagemaker/lineage/action.py +28 -0
  341. sagemaker/lineage/artifact.py +28 -0
  342. sagemaker/lineage/context.py +28 -0
  343. sagemaker/lineage/lineage_trial_component.py +28 -0
  344. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
  345. sagemaker_core-2.1.1.dist-info/RECORD +355 -0
  346. sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
  347. sagemaker_core/_version.py +0 -3
  348. sagemaker_core/helper/session_helper.py +0 -769
  349. sagemaker_core/resources/__init__.py +0 -1
  350. sagemaker_core/shapes/__init__.py +0 -1
  351. sagemaker_core/tools/__init__.py +0 -1
  352. sagemaker_core-1.0.62.dist-info/RECORD +0 -35
  353. sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
  354. {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
  355. {sagemaker_core/helper → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
  356. {sagemaker_core/main → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  357. {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
  358. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  359. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  360. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  361. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
  362. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
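The single hunk below is the new file from item 293 above, `sagemaker/core/spark/processing.py` (the only entry adding exactly 1,380 lines). To make the relocated API concrete, here is a minimal usage sketch driven off the `PySparkProcessor` constructor and `run()` signatures shown in the hunk; the role ARN, bucket, script paths, and `framework_version` value are placeholders, not values taken from this diff:

    from sagemaker.core.spark.processing import PySparkProcessor

    # All resource names below are placeholders; framework_version is illustrative.
    processor = PySparkProcessor(
        role="arn:aws:iam::111122223333:role/SageMakerRole",
        instance_type="ml.m5.xlarge",
        instance_count=2,
        framework_version="3.3",
        base_job_name="spark-preprocess",
    )

    processor.run(
        submit_app="./preprocess.py",                      # primary .py application for smspark-submit
        submit_py_files=["./helpers.py"],                  # staged via the "py-files" input channel
        arguments=["--input", "s3://example-bucket/raw"],  # forwarded to the job as string arguments
        spark_event_logs_s3_uri="s3://example-bucket/spark-events",  # enables start_history_server()
        wait=True,
    )

Per `_extend_processing_args` in the hunk, passing `spark_event_logs_s3_uri` appends a `--local-spark-event-logs-dir` flag to the container command and attaches a continuously uploaded `ProcessingOutput` for the Spark event logs.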
@@ -0,0 +1,1380 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
+ # may not use this file except in compliance with the License. A copy of
+ # the License is located at
+ #
+ #     http://aws.amazon.com/apache2.0/
+ #
+ # or in the "license" file accompanying this file. This file is
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+ # ANY KIND, either express or implied. See the License for the specific
+ # language governing permissions and limitations under the License.
+ """This module is the entry point for running Spark processing scripts.
+
+ This module contains code related to Spark Processors, which are used
+ for Processing jobs. These jobs let customers perform data pre-processing,
+ post-processing, feature engineering, data validation, and model evaluation
+ on SageMaker using Spark and PySpark.
+ """
+ from __future__ import absolute_import
+
+ import json
+ import logging
+ import os.path
+ import shutil
+ import subprocess
+ import tempfile
+ import time
+ import urllib.request
+ from enum import Enum
+ from io import BytesIO
+ from urllib.parse import urlparse
+ from copy import copy
+
+ from typing import Union, List, Dict, Optional
+
+ from sagemaker.core import image_uris
+ from sagemaker.core import s3
+ from sagemaker.core.local.image import _ecr_login_if_needed, _pull_image
+ from sagemaker.core.processing import ProcessingInput, ProcessingOutput, ScriptProcessor
+ from sagemaker.core.s3 import S3Uploader
+ from sagemaker.core.helper.session_helper import Session
+ from sagemaker.core.network import NetworkConfig
+ from sagemaker.core.spark import defaults
+ from sagemaker.core.common_utils import format_tags, Tags
+
+ from sagemaker.core.workflow import is_pipeline_variable
+ from sagemaker.core.workflow.pipeline_context import runnable_by_pipeline
+ from sagemaker.core.helper.pipeline_variable import PipelineVariable
+ from sagemaker.core.workflow.functions import Join
+
+ logger = logging.getLogger(__name__)
+
+
+ class _SparkProcessorBase(ScriptProcessor):
+     """Handles Amazon SageMaker processing tasks for jobs using Spark.
+
+     Base class for either PySpark or SparkJars.
+     """
+
+     _default_command = "smspark-submit"
+     _conf_container_base_path = "/opt/ml/processing/input/"
+     _conf_container_input_name = "conf"
+     _conf_file_name = "configuration.json"
+
+     _submit_jars_input_channel_name = "jars"
+     _submit_files_input_channel_name = "files"
+     _submit_py_files_input_channel_name = "py-files"
+     _submit_deps_error_message = (
+         "Please specify a list of one or more S3 URIs, "
+         "local file paths, and/or local directory paths"
+     )
+
+     # history server vars
+     _history_server_port = "15050"
+     _history_server_url_suffix = f"/proxy/{_history_server_port}"
+     _spark_event_log_default_local_path = "/opt/ml/processing/spark-events/"
+
+     def __init__(
+         self,
+         role=None,
+         instance_type=None,
+         instance_count=None,
+         framework_version=None,
+         py_version=None,
+         container_version=None,
+         image_uri=None,
+         volume_size_in_gb=30,
+         volume_kms_key=None,
+         output_kms_key=None,
+         configuration_location: Optional[str] = None,
+         dependency_location: Optional[str] = None,
+         max_runtime_in_seconds=None,
+         base_job_name=None,
+         sagemaker_session=None,
+         env=None,
+         tags=None,
+         network_config=None,
+     ):
+         """Initialize a ``_SparkProcessorBase`` instance.
+
+         The _SparkProcessorBase handles Amazon SageMaker processing tasks for
+         jobs using SageMaker Spark.
+
+         Args:
+             framework_version (str): The version of SageMaker PySpark.
+             py_version (str): The version of python.
+             container_version (str): The version of spark container.
+             role (str): An AWS IAM role name or ARN. Amazon SageMaker Processing
+                 uses this role to access AWS resources, such as
+                 data stored in Amazon S3 (default: None).
+                 If not specified, the value from the defaults configuration file
+                 will be used.
+             instance_type (str): Type of EC2 instance to use for
+                 processing, for example, 'ml.c4.xlarge'.
+             instance_count (int): The number of instances to run
+                 the Processing job with. Defaults to 1.
+             volume_size_in_gb (int): Size in GB of the EBS volume to
+                 use for storing data during processing (default: 30).
+             volume_kms_key (str): A KMS key for the processing
+                 volume.
+             output_kms_key (str): The KMS key id for all ProcessingOutputs.
+             configuration_location (str): The S3 prefix URI where the user-provided EMR
+                 application configuration will be uploaded (default: None). If not specified,
+                 the default ``configuration location`` is 's3://{sagemaker-default-bucket}'.
+             dependency_location (str): The S3 prefix URI where Spark dependencies will be
+                 uploaded (default: None). If not specified, the default ``dependency location``
+                 is 's3://{sagemaker-default-bucket}'.
+             max_runtime_in_seconds (int): Timeout in seconds.
+                 After this amount of time Amazon SageMaker terminates the job
+                 regardless of its current status.
+             base_job_name (str): Prefix for processing name. If not specified,
+                 the processor generates a default job name, based on the
+                 training image name and current timestamp.
+             sagemaker_session (sagemaker.session.Session): Session object which
+                 manages interactions with Amazon
+                 SageMaker APIs and any other AWS services needed. If not specified,
+                 the processor creates one using the default AWS configuration chain.
+             env (dict): Environment variables to be passed to the processing job.
+             tags (Optional[Tags]): List of tags to be passed to the processing job.
+             network_config (sagemaker.network.NetworkConfig): A NetworkConfig
+                 object that configures network isolation, encryption of
+                 inter-container traffic, security group IDs, and subnets.
+         """
+         self.configuration_location = configuration_location
+         self.dependency_location = dependency_location
+         self.history_server = None
+         self._spark_event_logs_s3_uri = None
+
+         session = sagemaker_session or Session()
+         region = session.boto_region_name
+
+         self.image_uri = self._retrieve_image_uri(
+             image_uri, framework_version, py_version, container_version, region, instance_type
+         )
+
+         env = env or {}
+         command = [_SparkProcessorBase._default_command]
+
+         super(_SparkProcessorBase, self).__init__(
+             role=role,
+             image_uri=self.image_uri,
+             instance_count=instance_count,
+             instance_type=instance_type,
+             command=command,
+             volume_size_in_gb=volume_size_in_gb,
+             volume_kms_key=volume_kms_key,
+             output_kms_key=output_kms_key,
+             max_runtime_in_seconds=max_runtime_in_seconds,
+             base_job_name=base_job_name,
+             sagemaker_session=session,
+             env=env,
+             tags=format_tags(tags),
+             network_config=network_config,
+         )
+
+     def get_run_args(
+         self,
+         code,
+         inputs=None,
+         outputs=None,
+         arguments=None,
+     ):
+         """Returns a RunArgs object.
+
+         For processors (:class:`~sagemaker.spark.processing.PySparkProcessor`,
+         :class:`~sagemaker.spark.processing.SparkJar`) that have special
+         run() arguments, this object contains the normalized arguments for passing to
+         :class:`~sagemaker.workflow.steps.ProcessingStep`.
+
+         Args:
+             code (str): This can be an S3 URI or a local path to a file with the framework
+                 script to run.
+             inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
+                 the processing job. These must be provided as
+                 :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
+             outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
+                 the processing job. These can be specified as either path strings or
+                 :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
+             arguments (list[str]): A list of string arguments to be passed to a
+                 processing job (default: None).
+         """
+         return super().get_run_args(
+             code=code,
+             inputs=inputs,
+             outputs=outputs,
+             arguments=arguments,
+         )
+
+     @runnable_by_pipeline
+     def run(
+         self,
+         submit_app,
+         inputs=None,
+         outputs=None,
+         arguments=None,
+         wait=True,
+         logs=True,
+         job_name=None,
+         experiment_config=None,
+         kms_key=None,
+     ):
+         """Runs a processing job.
+
+         Args:
+             submit_app (str): .py or .jar file to submit to Spark as the primary application.
+             inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
+                 the processing job. These must be provided as
+                 :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
+             outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
+                 the processing job. These can be specified as either path strings or
+                 :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
+             arguments (list[str]): A list of string arguments to be passed to a
+                 processing job (default: None).
+             wait (bool): Whether the call should wait until the job completes (default: True).
+             logs (bool): Whether to show the logs produced by the job.
+                 Only meaningful when wait is True (default: True).
+             job_name (str): Processing job name. If not specified, the processor generates
+                 a default job name, based on the base job name and current timestamp.
+             experiment_config (dict[str, str]): Experiment management configuration.
+                 Optionally, the dict can contain three keys:
+                 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
+                 The behavior of setting these keys is as follows:
+                 * If `ExperimentName` is supplied but `TrialName` is not, a Trial will be
+                   automatically created and the job's Trial Component associated with the Trial.
+                 * If `TrialName` is supplied and the Trial already exists, the job's Trial
+                   Component will be associated with the Trial.
+                 * If neither `ExperimentName` nor `TrialName` is supplied, the Trial Component
+                   will be unassociated.
+                 * `TrialComponentDisplayName` is used for display in Studio.
+             kms_key (str): The ARN of the KMS key that is used to encrypt the
+                 user code file (default: None).
+         """
+         self._current_job_name = self._generate_current_job_name(job_name=job_name)
+
+         if is_pipeline_variable(submit_app):
+             raise ValueError(
+                 "submit_app argument has to be a valid S3 URI or local file path "
+                 + "rather than a pipeline variable"
+             )
+
+         return super().run(
+             submit_app,
+             inputs,
+             outputs,
+             arguments,
+             wait,
+             logs,
+             job_name,
+             experiment_config,
+             kms_key,
+         )
+
+     def _extend_processing_args(self, inputs, outputs, **kwargs):
+         """Extends processing job args such as inputs."""
+
+         # make a shallow copy of user outputs
+         outputs = outputs or []
+         extended_outputs = copy(outputs)
+
+         if kwargs.get("spark_event_logs_s3_uri"):
+             spark_event_logs_s3_uri = kwargs.get("spark_event_logs_s3_uri")
+             SparkConfigUtils.validate_s3_uri(spark_event_logs_s3_uri)
+
+             self._spark_event_logs_s3_uri = spark_event_logs_s3_uri
+             self.command.extend(
+                 [
+                     "--local-spark-event-logs-dir",
+                     _SparkProcessorBase._spark_event_log_default_local_path,
+                 ]
+             )
+
+             output = ProcessingOutput(
+                 source=_SparkProcessorBase._spark_event_log_default_local_path,
+                 destination=spark_event_logs_s3_uri,
+                 s3_upload_mode="Continuous",
+             )
+
+             extended_outputs.append(output)
+
+         # make a shallow copy of user inputs
+         inputs = inputs or []
+         extended_inputs = copy(inputs)
+
+         if kwargs.get("configuration"):
+             configuration = kwargs.get("configuration")
+             SparkConfigUtils.validate_configuration(configuration)
+             extended_inputs.append(self._stage_configuration(configuration))
+
+         return (
+             extended_inputs if extended_inputs else None,
+             extended_outputs if extended_outputs else None,
+         )
+
+     def start_history_server(self, spark_event_logs_s3_uri=None):
+         """Starts a Spark history server.
+
+         Args:
+             spark_event_logs_s3_uri (str): S3 URI where Spark events are stored.
+         """
+         if _ecr_login_if_needed(self.sagemaker_session.boto_session, self.image_uri):
+             logger.info("Pulling spark history server image...")
+             _pull_image(self.image_uri)
+         history_server_env_variables = self._prepare_history_server_env_variables(
+             spark_event_logs_s3_uri
+         )
+         self.history_server = _HistoryServer(
+             history_server_env_variables, self.image_uri, self._get_network_config()
+         )
+         self.history_server.run()
+         self._check_history_server()
+
+     def terminate_history_server(self):
+         """Terminates the Spark history server."""
+         if self.history_server:
+             logger.info("History server is running, terminating history server")
+             self.history_server.down()
+             self.history_server = None
+
+     def _retrieve_image_uri(
+         self, image_uri, framework_version, py_version, container_version, region, instance_type
+     ):
+         """Builds an image URI."""
+         if not image_uri:
+             if (py_version is None) != (container_version is None):
+                 raise ValueError(
+                     "Both or neither of py_version and container_version should be set"
+                 )
+
+             if container_version:
+                 container_version = f"v{container_version}"
+
+             return image_uris.retrieve(
+                 defaults.SPARK_NAME,
+                 region,
+                 version=framework_version,
+                 instance_type=instance_type,
+                 py_version=py_version,
+                 container_version=container_version,
+             )
+
+         return image_uri
+
+     def _stage_configuration(self, configuration):
+         """Serializes and uploads the user-provided EMR application configuration to S3.
+
+         This method prepares an input channel.
+
+         Args:
+             configuration (Dict): the configuration dict for the EMR application configuration.
+         """
+         from sagemaker.core.workflow.utilities import _pipeline_config
+
+         if self.configuration_location:
+             if self.configuration_location.endswith("/"):
+                 s3_prefix_uri = self.configuration_location[:-1]
+             else:
+                 s3_prefix_uri = self.configuration_location
+         else:
+             s3_prefix_uri = s3.s3_path_join(
+                 "s3://",
+                 self.sagemaker_session.default_bucket(),
+                 self.sagemaker_session.default_bucket_prefix,
+             )
+
+         serialized_configuration = BytesIO(json.dumps(configuration).encode("utf-8"))
+
+         if _pipeline_config and _pipeline_config.config_hash:
+             s3_uri = (
+                 f"{s3_prefix_uri}/{_pipeline_config.pipeline_name}/{_pipeline_config.step_name}/"
+                 f"input/{self._conf_container_input_name}/{_pipeline_config.config_hash}/"
+                 f"{self._conf_file_name}"
+             )
+         else:
+             s3_uri = (
+                 f"{s3_prefix_uri}/{self._current_job_name}/"
+                 f"input/{self._conf_container_input_name}/"
+                 f"{self._conf_file_name}"
+             )
+
+         S3Uploader.upload_string_as_file_body(
+             body=serialized_configuration,
+             desired_s3_uri=s3_uri,
+             sagemaker_session=self.sagemaker_session,
+         )
+
+         conf_input = ProcessingInput(
+             source=s3_uri,
+             destination=f"{self._conf_container_base_path}{self._conf_container_input_name}",
+             input_name=_SparkProcessorBase._conf_container_input_name,
+         )
+         return conf_input
+
+     def _stage_submit_deps(self, submit_deps, input_channel_name):
+         """Prepares a list of paths to jars, py-files, or files dependencies.
+
+         This prepared list of paths is provided as `spark-submit` options.
+         The submit_deps list may include a combination of S3 URIs and local paths.
+         Any S3 URIs are appended to the `spark-submit` option value without modification.
+         Any local file paths are copied to a temp directory, uploaded to ``dependency location``,
+         and included as a ProcessingInput channel to provide as local files to the SageMaker
+         Spark container.
+
+         :param submit_deps (list[str]): List of one or more dependency paths to include.
+         :param input_channel_name (str): The `spark-submit` option name associated with
+             the input channel.
+         :return (Optional[ProcessingInput], str): Tuple of (left) optional ProcessingInput
+             for the input channel, and (right) comma-delimited value for
+             `spark-submit` option.
+         """
+         if not submit_deps:
+             raise ValueError(
+                 f"submit_deps value may not be empty. {self._submit_deps_error_message}"
+             )
+         if not input_channel_name:
+             raise ValueError("input_channel_name value may not be empty.")
+
+         use_input_channel = False
+         spark_opt_s3_uris = []
+         spark_opt_s3_uris_has_pipeline_var = False
+
+         with tempfile.TemporaryDirectory() as tmpdir:
+             for dep_path in submit_deps:
+                 if is_pipeline_variable(dep_path):
+                     spark_opt_s3_uris.append(dep_path)
+                     spark_opt_s3_uris_has_pipeline_var = True
+                     continue
+                 dep_url = urlparse(dep_path)
+                 # S3 URIs are included as-is in the spark-submit argument
+                 if dep_url.scheme in ["s3", "s3a"]:
+                     spark_opt_s3_uris.append(dep_path)
+                 # Local files are copied to temp directory to be uploaded to S3
+                 elif not dep_url.scheme or dep_url.scheme == "file":
+                     if not os.path.isfile(dep_path):
+                         raise ValueError(
+                             f"submit_deps path {dep_path} is not a valid local file. "
+                             f"{self._submit_deps_error_message}"
+                         )
+                     logger.info(
+                         "Copying dependency from local path %s to tmpdir %s", dep_path, tmpdir
+                     )
+                     shutil.copy(dep_path, tmpdir)
+                 else:
+                     raise ValueError(
+                         f"submit_deps path {dep_path} references unsupported filesystem "
+                         f"scheme: {dep_url.scheme} {self._submit_deps_error_message}"
+                     )
+
+             # If any local files were found and copied, upload the temp directory to S3
+             if os.listdir(tmpdir):
+                 from sagemaker.core.workflow.utilities import _pipeline_config
+
+                 if self.dependency_location:
+                     if self.dependency_location.endswith("/"):
+                         s3_prefix_uri = self.dependency_location[:-1]
+                     else:
+                         s3_prefix_uri = self.dependency_location
+                 else:
+                     s3_prefix_uri = s3.s3_path_join(
+                         "s3://",
+                         self.sagemaker_session.default_bucket(),
+                         self.sagemaker_session.default_bucket_prefix,
+                     )
+
+                 if _pipeline_config and _pipeline_config.code_hash:
+                     input_channel_s3_uri = (
+                         f"{s3_prefix_uri}/{_pipeline_config.pipeline_name}/"
+                         f"code/{_pipeline_config.code_hash}/{input_channel_name}"
+                     )
+                 else:
+                     input_channel_s3_uri = (
+                         f"{s3_prefix_uri}/{self._current_job_name}/input/{input_channel_name}"
+                     )
+                 logger.info(
+                     "Uploading dependencies from tmpdir %s to S3 %s", tmpdir, input_channel_s3_uri
+                 )
+                 S3Uploader.upload(
+                     local_path=tmpdir,
+                     desired_s3_uri=input_channel_s3_uri,
+                     sagemaker_session=self.sagemaker_session,
+                 )
+                 use_input_channel = True
+
+         # If any local files were uploaded, construct a ProcessingInput to provide
+         # them to the Spark container and form the spark-submit option from a
+         # combination of S3 URIs and container's local input path
+         if use_input_channel:
+             input_channel = ProcessingInput(
+                 source=input_channel_s3_uri,
+                 destination=f"{self._conf_container_base_path}{input_channel_name}",
+                 input_name=input_channel_name,
+             )
+             spark_opt = (
+                 Join(on=",", values=spark_opt_s3_uris + [input_channel.destination])
+                 if spark_opt_s3_uris_has_pipeline_var
+                 else ",".join(spark_opt_s3_uris + [input_channel.destination])
+             )
+         # If no local files were uploaded, form the spark-submit option from a list of S3 URIs
+         else:
+             input_channel = None
+             spark_opt = (
+                 Join(on=",", values=spark_opt_s3_uris)
+                 if spark_opt_s3_uris_has_pipeline_var
+                 else ",".join(spark_opt_s3_uris)
+             )
+
+         return input_channel, spark_opt
+
+     def _get_network_config(self):
+         """Returns the docker network config for the container, depending on the environment."""
+         if self._is_notebook_instance():
+             return "--network host"
+
+         return f"-p 80:80 -p {self._history_server_port}:{self._history_server_port}"
+
+     def _prepare_history_server_env_variables(self, spark_event_logs_s3_uri):
+         """Gets all parameters required to run history server."""
+         # prepare env variables
+         history_server_env_variables = {}
+
+         if spark_event_logs_s3_uri:
+             history_server_env_variables[_HistoryServer.arg_event_logs_s3_uri] = (
+                 spark_event_logs_s3_uri
+             )
+         # this variable will have been set previously by the run() method
+         elif self._spark_event_logs_s3_uri is not None:
+             history_server_env_variables[_HistoryServer.arg_event_logs_s3_uri] = (
+                 self._spark_event_logs_s3_uri
+             )
+         else:
+             raise ValueError(
+                 "SPARK_EVENT_LOGS_S3_URI not present. You can specify spark_event_logs_s3_uri "
+                 "either in run() or start_history_server()"
+             )
+
+         history_server_env_variables.update(self._config_aws_credentials())
+         region = self.sagemaker_session.boto_region_name
+         history_server_env_variables["AWS_REGION"] = region
+
+         if self._is_notebook_instance():
+             history_server_env_variables[_HistoryServer.arg_remote_domain_name] = (
+                 self._get_notebook_instance_domain()
+             )
+
+         return history_server_env_variables
+
+     def _is_notebook_instance(self):
+         """Determine whether it is a notebook instance."""
+         return os.path.isfile("/opt/ml/metadata/resource-metadata.json")
+
+     def _get_notebook_instance_domain(self):
+         """Get the instance's domain."""
+         region = self.sagemaker_session.boto_region_name
+         with open("/opt/ml/metadata/resource-metadata.json") as file:
+             data = json.load(file)
+             notebook_name = data["ResourceName"]
+
+         return f"https://{notebook_name}.notebook.{region}.sagemaker.aws"
+
+     def _check_history_server(self, ping_timeout=40):
+         """Prints a message indicating the status of the history server.
+
+         Pings port 15050 to check whether the history server is up.
+         Times out after `ping_timeout`.
+
+         Args:
+             ping_timeout (int): Timeout in seconds (defaults to 40).
+         """
+         # ping port 15050 to check history server is up
+         timeout = time.time() + ping_timeout
+
+         while True:
+             if self._is_history_server_started():
+                 if self._is_notebook_instance():
+                     logger.info(
+                         "History server is up on %s%s",
+                         self._get_notebook_instance_domain(),
+                         self._history_server_url_suffix,
+                     )
+                 else:
+                     logger.info(
+                         "History server is up on http://0.0.0.0%s", self._history_server_url_suffix
+                     )
+                 break
+             if time.time() > timeout:
+                 logger.error(
+                     "History server failed to start. Please run 'docker logs history_server' "
+                     "to see logs"
+                 )
+                 break
+
+             time.sleep(1)
+
+     def _is_history_server_started(self):
+         """Check if history server is started."""
+         try:
+             response = urllib.request.urlopen(f"http://localhost:{self._history_server_port}")
+             return response.status == 200
+         except Exception:  # pylint: disable=W0703
+             return False
+
+     # TODO (guoqioa@): method only checks urlparse scheme, need to perform deep s3 validation
+     def _validate_s3_uri(self, spark_output_s3_path):
+         """Validate whether the URI uses an S3 scheme.
+
+         In the future, this validation will perform deeper S3 validation.
+
+         Args:
+             spark_output_s3_path (str): The URI of the Spark output S3 Path.
+         """
+         if is_pipeline_variable(spark_output_s3_path):
+             return
+
+         if urlparse(spark_output_s3_path).scheme != "s3":
+             raise ValueError(
+                 f"Invalid s3 path: {spark_output_s3_path}. Please enter something like "
+                 "s3://bucket-name/folder-name"
+             )
+
+     def _config_aws_credentials(self):
+         """Configure AWS credentials."""
+         try:
+             creds = self.sagemaker_session.boto_session.get_credentials()
+             access_key = creds.access_key
+             secret_key = creds.secret_key
+             token = creds.token
+
+             return {
+                 "AWS_ACCESS_KEY_ID": str(access_key),
+                 "AWS_SECRET_ACCESS_KEY": str(secret_key),
+                 "AWS_SESSION_TOKEN": str(token),
+             }
+         except Exception as e:  # pylint: disable=W0703
+             logger.info("Could not get AWS credentials: %s", e)
+             return {}
+
+     def _handle_script_dependencies(self, inputs, submit_files, file_type):
+         """Handles script dependencies.
+
+         Extends the job's inputs and command based on the given input files and file_type.
+         """
+
+         if not submit_files:
+             return inputs
+
+         input_channel_name_dict = {
+             FileType.PYTHON: self._submit_py_files_input_channel_name,
+             FileType.JAR: self._submit_jars_input_channel_name,
+             FileType.FILE: self._submit_files_input_channel_name,
+         }
+
+         files_input, files_opt = self._stage_submit_deps(
+             submit_files, input_channel_name_dict[file_type]
+         )
+
+         inputs = inputs or []
+
+         if files_input:
+             inputs.append(files_input)
+
+         if files_opt:
+             self.command.extend([f"--{input_channel_name_dict[file_type]}", files_opt])
+
+         return inputs
+
+
+ class PySparkProcessor(_SparkProcessorBase):
+     """Handles Amazon SageMaker processing tasks for jobs using PySpark."""
+
+     def __init__(
+         self,
+         role: str = None,
+         instance_type: Union[str, PipelineVariable] = None,
+         instance_count: Union[int, PipelineVariable] = None,
+         framework_version: Optional[str] = None,
+         py_version: Optional[str] = None,
+         container_version: Optional[str] = None,
+         image_uri: Optional[Union[str, PipelineVariable]] = None,
+         volume_size_in_gb: Union[int, PipelineVariable] = 30,
+         volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
+         output_kms_key: Optional[Union[str, PipelineVariable]] = None,
+         configuration_location: Optional[str] = None,
+         dependency_location: Optional[str] = None,
+         max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
+         base_job_name: Optional[str] = None,
+         sagemaker_session: Optional[Session] = None,
+         env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
+         tags: Optional[Tags] = None,
+         network_config: Optional[NetworkConfig] = None,
+     ):
+         """Initialize a ``PySparkProcessor`` instance.
+
+         The PySparkProcessor handles Amazon SageMaker processing tasks for jobs
+         using SageMaker PySpark.
+
+         Args:
+             framework_version (str): The version of SageMaker PySpark.
+             py_version (str): The version of python.
+             container_version (str): The version of spark container.
+             role (str): An AWS IAM role name or ARN. Amazon SageMaker Processing
+                 uses this role to access AWS resources, such as
+                 data stored in Amazon S3 (default: None).
+                 If not specified, the value from the defaults configuration file
+                 will be used.
+             instance_type (str or PipelineVariable): Type of EC2 instance to use for
+                 processing, for example, 'ml.c4.xlarge'.
+             instance_count (int or PipelineVariable): The number of instances to run
+                 the Processing job with. Defaults to 1.
+             volume_size_in_gb (int or PipelineVariable): Size in GB of the EBS volume to
+                 use for storing data during processing (default: 30).
+             volume_kms_key (str or PipelineVariable): A KMS key for the processing
+                 volume.
+             output_kms_key (str or PipelineVariable): The KMS key id for all ProcessingOutputs.
+             configuration_location (str): The S3 prefix URI where the user-provided EMR
+                 application configuration will be uploaded (default: None). If not specified,
+                 the default ``configuration location`` is 's3://{sagemaker-default-bucket}'.
+             dependency_location (str): The S3 prefix URI where Spark dependencies will be
+                 uploaded (default: None). If not specified, the default ``dependency location``
+                 is 's3://{sagemaker-default-bucket}'.
+             max_runtime_in_seconds (int or PipelineVariable): Timeout in seconds.
+                 After this amount of time Amazon SageMaker terminates the job
+                 regardless of its current status.
+             base_job_name (str): Prefix for processing name. If not specified,
+                 the processor generates a default job name, based on the
+                 training image name and current timestamp.
+             sagemaker_session (sagemaker.session.Session): Session object which
+                 manages interactions with Amazon SageMaker APIs and any other
+                 AWS services needed. If not specified, the processor creates one
+                 using the default AWS configuration chain.
+             env (dict[str, str] or dict[str, PipelineVariable]): Environment variables to
+                 be passed to the processing job.
+             tags (Optional[Tags]): List of tags to be passed to the processing job.
+             network_config (sagemaker.network.NetworkConfig): A NetworkConfig
+                 object that configures network isolation, encryption of
+                 inter-container traffic, security group IDs, and subnets.
+         """
+
+         super(PySparkProcessor, self).__init__(
+             role=role,
+             instance_count=instance_count,
+             instance_type=instance_type,
+             framework_version=framework_version,
+             py_version=py_version,
+             container_version=container_version,
+             image_uri=image_uri,
+             volume_size_in_gb=volume_size_in_gb,
+             volume_kms_key=volume_kms_key,
+             output_kms_key=output_kms_key,
+             configuration_location=configuration_location,
+             dependency_location=dependency_location,
+             max_runtime_in_seconds=max_runtime_in_seconds,
+             base_job_name=base_job_name,
+             sagemaker_session=sagemaker_session,
+             env=env,
+             tags=format_tags(tags),
+             network_config=network_config,
+         )
+
+     def get_run_args(
+         self,
+         submit_app,
+         submit_py_files=None,
+         submit_jars=None,
+         submit_files=None,
+         inputs=None,
+         outputs=None,
+         arguments=None,
+         job_name=None,
+         configuration=None,
+         spark_event_logs_s3_uri=None,
+     ):
+         """Returns a RunArgs object.
+
+         This object contains the normalized inputs, outputs and arguments
+         needed when using a ``PySparkProcessor`` in a
+         :class:`~sagemaker.workflow.steps.ProcessingStep`.
+
+         Args:
+             submit_app (str): Path (local or S3) to Python file to submit to Spark
+                 as the primary application. This is translated to the `code`
+                 property on the returned `RunArgs` object.
+             submit_py_files (list[str]): List of paths (local or S3) to provide for
+                 `spark-submit --py-files` option
+             submit_jars (list[str]): List of paths (local or S3) to provide for
+                 `spark-submit --jars` option
+             submit_files (list[str]): List of paths (local or S3) to provide for
+                 `spark-submit --files` option
+             inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
+                 the processing job. These must be provided as
+                 :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
+             outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
+                 the processing job. These can be specified as either path strings or
+                 :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
+             arguments (list[str]): A list of string arguments to be passed to a
+                 processing job (default: None).
+             job_name (str): Processing job name. If not specified, the processor generates
+                 a default job name, based on the base job name and current timestamp.
+             configuration (list[dict] or dict): Configuration for Hadoop, Spark, or Hive.
+                 List or dictionary of EMR-style classifications.
+                 https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html
+             spark_event_logs_s3_uri (str): S3 path where spark application events will
+                 be published to.
+         """
+         self._current_job_name = self._generate_current_job_name(job_name=job_name)
+
+         if not submit_app:
+             raise ValueError("submit_app is required")
+
+         extended_inputs, extended_outputs = self._extend_processing_args(
+             inputs=inputs,
+             outputs=outputs,
+             submit_py_files=submit_py_files,
+             submit_jars=submit_jars,
+             submit_files=submit_files,
+             configuration=configuration,
+             spark_event_logs_s3_uri=spark_event_logs_s3_uri,
+         )
+
+         return super().get_run_args(
+             code=submit_app,
+             inputs=extended_inputs,
+             outputs=extended_outputs,
+             arguments=arguments,
+         )
+
847
+ @runnable_by_pipeline
848
+ def run(
849
+ self,
850
+ submit_app: str,
851
+ submit_py_files: Optional[List[Union[str, PipelineVariable]]] = None,
852
+ submit_jars: Optional[List[Union[str, PipelineVariable]]] = None,
853
+ submit_files: Optional[List[Union[str, PipelineVariable]]] = None,
854
+ inputs: Optional[List[ProcessingInput]] = None,
855
+ outputs: Optional[List[ProcessingOutput]] = None,
856
+ arguments: Optional[List[Union[str, PipelineVariable]]] = None,
857
+ wait: bool = True,
858
+ logs: bool = True,
859
+ job_name: Optional[str] = None,
860
+ experiment_config: Optional[Dict[str, str]] = None,
861
+ configuration: Optional[Union[List[Dict], Dict]] = None,
862
+ spark_event_logs_s3_uri: Optional[Union[str, PipelineVariable]] = None,
863
+ kms_key: Optional[str] = None,
864
+ ):
865
+ """Runs a processing job.
866
+
867
+ Args:
868
+ submit_app (str): Path (local or S3) to Python file to submit to Spark
869
+ as the primary application
870
+ submit_py_files (list[str] or list[PipelineVariable]): List of paths (local or S3)
871
+ to provide for `spark-submit --py-files` option
872
+ submit_jars (list[str] or list[PipelineVariable]): List of paths (local or S3)
873
+ to provide for `spark-submit --jars` option
874
+ submit_files (list[str] or list[PipelineVariable]): List of paths (local or S3)
875
+ to provide for `spark-submit --files` option
876
+ inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
877
+ the processing job. These must be provided as
878
+ :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
879
+ outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
880
+ the processing job. These can be specified as either path strings or
881
+ :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
882
+ arguments (list[str] or list[PipelineVariable]): A list of string arguments to
883
+ be passed to a processing job (default: None).
884
+ wait (bool): Whether the call should wait until the job completes (default: True).
885
+ logs (bool): Whether to show the logs produced by the job.
886
+ Only meaningful when wait is True (default: True).
887
+ job_name (str): Processing job name. If not specified, the processor generates
888
+ a default job name, based on the base job name and current timestamp.
889
+ experiment_config (dict[str, str]): Experiment management configuration.
890
+ Optionally, the dict can contain three keys:
891
+ 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
892
+ The behavior of setting these keys is as follows:
893
+ * If `ExperimentName` is supplied but `TrialName` is not a Trial will be
894
+ automatically created and the job's Trial Component associated with the Trial.
895
+ * If `TrialName` is supplied and the Trial already exists the job's Trial Component
896
+ will be associated with the Trial.
897
+ * If both `ExperimentName` and `TrialName` are not supplied the trial component
898
+ will be unassociated.
899
+ * `TrialComponentDisplayName` is used for display in Studio.
900
+ configuration (list[dict] or dict): Configuration for Hadoop, Spark, or Hive.
901
+ List or dictionary of EMR-style classifications.
902
+ https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html
903
+ spark_event_logs_s3_uri (str or PipelineVariable): S3 path where spark application
904
+ events will be published to.
905
+ kms_key (str): The ARN of the KMS key that is used to encrypt the
906
+ user code file (default: None).
907
+ """
908
+ self._current_job_name = self._generate_current_job_name(job_name=job_name)
909
+
910
+ if not submit_app:
911
+ raise ValueError("submit_app is required")
912
+
913
+ extended_inputs, extended_outputs = self._extend_processing_args(
914
+ inputs=inputs,
915
+ outputs=outputs,
916
+ submit_py_files=submit_py_files,
917
+ submit_jars=submit_jars,
918
+ submit_files=submit_files,
919
+ configuration=configuration,
920
+ spark_event_logs_s3_uri=spark_event_logs_s3_uri,
921
+ )
922
+
923
+ return super().run(
924
+ submit_app=submit_app,
925
+ inputs=extended_inputs,
926
+ outputs=extended_outputs,
927
+ arguments=arguments,
928
+ wait=wait,
929
+ logs=logs,
930
+ job_name=self._current_job_name,
931
+ experiment_config=experiment_config,
932
+ kms_key=kms_key,
933
+ )
934
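For orientation, a hedged end-to-end sketch of `run` (the bucket, role ARN, and framework version are illustrative placeholders; `PySparkProcessor` is assumed to be constructed as defined earlier in this module):

    # Illustrative only: submit a PySpark app plus zipped dependencies.
    processor = PySparkProcessor(
        base_job_name="sm-spark",
        framework_version="3.1",                  # assumed available container version
        role="arn:aws:iam::111122223333:role/SageMakerRole",
        instance_count=2,
        instance_type="ml.m5.xlarge",
    )

    processor.run(
        submit_app="s3://my-bucket/code/app.py",
        submit_py_files=["s3://my-bucket/code/deps.zip"],
        configuration={
            "Classification": "spark-defaults",
            "Properties": {"spark.executor.memory": "4g"},
        },
        spark_event_logs_s3_uri="s3://my-bucket/spark-events/",   # enables the history server
        experiment_config={"ExperimentName": "spark-prep"},       # optional, see docstring above
        wait=True,
        logs=True,
    )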
+
+     def _extend_processing_args(self, inputs, outputs, **kwargs):
+         """Extends inputs and outputs.
+
+         Args:
+             inputs: Processing inputs.
+             outputs: Processing outputs.
+             kwargs: Additional keyword arguments passed to `super()`.
+         """
+
+         if inputs is None:
+             inputs = []
+
+         # make a shallow copy of user inputs
+         extended_inputs = copy(inputs)
+
+         self.command = [_SparkProcessorBase._default_command]
+         extended_inputs = self._handle_script_dependencies(
+             extended_inputs, kwargs.get("submit_py_files"), FileType.PYTHON
+         )
+         extended_inputs = self._handle_script_dependencies(
+             extended_inputs, kwargs.get("submit_jars"), FileType.JAR
+         )
+         extended_inputs = self._handle_script_dependencies(
+             extended_inputs, kwargs.get("submit_files"), FileType.FILE
+         )
+
+         return super()._extend_processing_args(extended_inputs, outputs, **kwargs)
+
+
+ class SparkJarProcessor(_SparkProcessorBase):
+     """Handles Amazon SageMaker processing tasks for jobs using Spark with Java or Scala Jars."""
+
+     def __init__(
+         self,
+         role: str = None,
+         instance_type: Union[str, PipelineVariable] = None,
+         instance_count: Union[int, PipelineVariable] = None,
+         framework_version: Optional[str] = None,
+         py_version: Optional[str] = None,
+         container_version: Optional[str] = None,
+         image_uri: Optional[Union[str, PipelineVariable]] = None,
+         volume_size_in_gb: Union[int, PipelineVariable] = 30,
+         volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
+         output_kms_key: Optional[Union[str, PipelineVariable]] = None,
+         configuration_location: Optional[str] = None,
+         dependency_location: Optional[str] = None,
+         max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
+         base_job_name: Optional[str] = None,
+         sagemaker_session: Optional[Session] = None,
+         env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
+         tags: Optional[Tags] = None,
+         network_config: Optional[NetworkConfig] = None,
+     ):
+         """Initialize a ``SparkJarProcessor`` instance.
+
+         The SparkJarProcessor handles Amazon SageMaker processing tasks for jobs
+         using SageMaker Spark.
+
+         Args:
+             framework_version (str): The version of SageMaker Spark.
+             py_version (str): The version of Python.
+             container_version (str): The version of the Spark container.
+             role (str): An AWS IAM role name or ARN. Amazon SageMaker Processing
+                 uses this role to access AWS resources, such as
+                 data stored in Amazon S3 (default: None).
+                 If not specified, the value from the defaults configuration file
+                 will be used.
+             instance_type (str or PipelineVariable): Type of EC2 instance to use for
+                 processing, for example, 'ml.c4.xlarge'.
+             instance_count (int or PipelineVariable): The number of instances to run
+                 the Processing job with. Defaults to 1.
+             volume_size_in_gb (int or PipelineVariable): Size in GB of the EBS volume to
+                 use for storing data during processing (default: 30).
+             volume_kms_key (str or PipelineVariable): A KMS key for the processing
+                 volume.
+             output_kms_key (str or PipelineVariable): The KMS key id for all ProcessingOutputs.
+             configuration_location (str): The S3 prefix URI where the user-provided EMR
+                 application configuration will be uploaded (default: None). If not specified,
+                 the default ``configuration location`` is 's3://{sagemaker-default-bucket}'.
+             dependency_location (str): The S3 prefix URI where Spark dependencies will be
+                 uploaded (default: None). If not specified, the default ``dependency location``
+                 is 's3://{sagemaker-default-bucket}'.
+             max_runtime_in_seconds (int or PipelineVariable): Timeout in seconds.
+                 After this amount of time Amazon SageMaker terminates the job
+                 regardless of its current status.
+             base_job_name (str): Prefix for processing name. If not specified,
+                 the processor generates a default job name, based on the
+                 training image name and current timestamp.
+             sagemaker_session (sagemaker.session.Session): Session object which
+                 manages interactions with Amazon SageMaker APIs and any other
+                 AWS services needed. If not specified, the processor creates one
+                 using the default AWS configuration chain.
+             env (dict[str, str] or dict[str, PipelineVariable]): Environment variables to
+                 be passed to the processing job.
+             tags (Optional[Tags]): Tags to be passed to the processing job.
+             network_config (sagemaker.network.NetworkConfig): A NetworkConfig
+                 object that configures network isolation, encryption of
+                 inter-container traffic, security group IDs, and subnets.
+         """
+
+         super(SparkJarProcessor, self).__init__(
+             role=role,
+             instance_count=instance_count,
+             instance_type=instance_type,
+             framework_version=framework_version,
+             py_version=py_version,
+             container_version=container_version,
+             image_uri=image_uri,
+             volume_size_in_gb=volume_size_in_gb,
+             volume_kms_key=volume_kms_key,
+             output_kms_key=output_kms_key,
+             configuration_location=configuration_location,
+             dependency_location=dependency_location,
+             max_runtime_in_seconds=max_runtime_in_seconds,
+             base_job_name=base_job_name,
+             sagemaker_session=sagemaker_session,
+             env=env,
+             tags=format_tags(tags),
+             network_config=network_config,
+         )
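A hedged construction sketch (role ARN, versions, and sizes are placeholders; available `framework_version` values depend on the released Spark container images):

    # Illustrative: build a SparkJarProcessor for a Scala/Java Spark job.
    spark_jar_processor = SparkJarProcessor(
        base_job_name="sm-spark-java",
        framework_version="3.1",                 # assumed available image version
        role="arn:aws:iam::111122223333:role/SageMakerRole",
        instance_count=2,
        instance_type="ml.m5.xlarge",
    )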
+
+     def get_run_args(
+         self,
+         submit_app,
+         submit_class=None,
+         submit_jars=None,
+         submit_files=None,
+         inputs=None,
+         outputs=None,
+         arguments=None,
+         job_name=None,
+         configuration=None,
+         spark_event_logs_s3_uri=None,
+     ):
+         """Returns a RunArgs object.
+
+         This object contains the normalized inputs, outputs and arguments
+         needed when using a ``SparkJarProcessor`` in a
+         :class:`~sagemaker.workflow.steps.ProcessingStep`.
+
+         Args:
+             submit_app (str): Path (local or S3) to the Jar file to submit to Spark
+                 as the primary application. This is translated to the `code`
+                 property on the returned `RunArgs` object.
+             submit_class (str): Java class reference to submit to Spark as the primary
+                 application
+             submit_jars (list[str]): List of paths (local or S3) to provide for
+                 `spark-submit --jars` option
+             submit_files (list[str]): List of paths (local or S3) to provide for
+                 `spark-submit --files` option
+             inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
+                 the processing job. These must be provided as
+                 :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
+             outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
+                 the processing job. These can be specified as either path strings or
+                 :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
+             arguments (list[str]): A list of string arguments to be passed to a
+                 processing job (default: None).
+             job_name (str): Processing job name. If not specified, the processor generates
+                 a default job name, based on the base job name and current timestamp.
+             configuration (list[dict] or dict): Configuration for Hadoop, Spark, or Hive.
+                 List or dictionary of EMR-style classifications.
+                 https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html
+             spark_event_logs_s3_uri (str): S3 path to which Spark application events will
+                 be published.
+         """
+         self._current_job_name = self._generate_current_job_name(job_name=job_name)
+
+         if not submit_app:
+             raise ValueError("submit_app is required")
+
+         extended_inputs, extended_outputs = self._extend_processing_args(
+             inputs=inputs,
+             outputs=outputs,
+             submit_class=submit_class,
+             submit_jars=submit_jars,
+             submit_files=submit_files,
+             configuration=configuration,
+             spark_event_logs_s3_uri=spark_event_logs_s3_uri,
+         )
+
+         return super().get_run_args(
+             code=submit_app,
+             inputs=extended_inputs,
+             outputs=extended_outputs,
+             arguments=arguments,
+         )
+
+     @runnable_by_pipeline
+     def run(
+         self,
+         submit_app: str,
+         submit_class: Union[str, PipelineVariable],
+         submit_jars: Optional[List[Union[str, PipelineVariable]]] = None,
+         submit_files: Optional[List[Union[str, PipelineVariable]]] = None,
+         inputs: Optional[List[ProcessingInput]] = None,
+         outputs: Optional[List[ProcessingOutput]] = None,
+         arguments: Optional[List[Union[str, PipelineVariable]]] = None,
+         wait: bool = True,
+         logs: bool = True,
+         job_name: Optional[str] = None,
+         experiment_config: Optional[Dict[str, str]] = None,
+         configuration: Optional[Union[List[Dict], Dict]] = None,
+         spark_event_logs_s3_uri: Optional[Union[str, PipelineVariable]] = None,
+         kms_key: Optional[str] = None,
+     ):
+         """Runs a processing job.
+
+         Args:
+             submit_app (str): Path (local or S3) to the Jar file to submit to Spark as
+                 the primary application
+             submit_class (str or PipelineVariable): Java class reference to submit to Spark
+                 as the primary application
+             submit_jars (list[str] or list[PipelineVariable]): List of paths (local or S3)
+                 to provide for `spark-submit --jars` option
+             submit_files (list[str] or list[PipelineVariable]): List of paths (local or S3)
+                 to provide for `spark-submit --files` option
+             inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
+                 the processing job. These must be provided as
+                 :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
+             outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
+                 the processing job. These can be specified as either path strings or
+                 :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
+             arguments (list[str] or list[PipelineVariable]): A list of string arguments to
+                 be passed to a processing job (default: None).
+             wait (bool): Whether the call should wait until the job completes (default: True).
+             logs (bool): Whether to show the logs produced by the job.
+                 Only meaningful when wait is True (default: True).
+             job_name (str): Processing job name. If not specified, the processor generates
+                 a default job name, based on the base job name and current timestamp.
+             experiment_config (dict[str, str]): Experiment management configuration.
+                 Optionally, the dict can contain three keys:
+                 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
+                 The behavior of setting these keys is as follows:
+                 * If `ExperimentName` is supplied but `TrialName` is not, a Trial will be
+                   automatically created and the job's Trial Component associated with the Trial.
+                 * If `TrialName` is supplied and the Trial already exists, the job's Trial
+                   Component will be associated with the Trial.
+                 * If neither `ExperimentName` nor `TrialName` is supplied, the Trial
+                   Component will be unassociated.
+                 * `TrialComponentDisplayName` is used for display in Studio.
+             configuration (list[dict] or dict): Configuration for Hadoop, Spark, or Hive.
+                 List or dictionary of EMR-style classifications.
+                 https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html
+             spark_event_logs_s3_uri (str or PipelineVariable): S3 path to which Spark
+                 application events will be published.
+             kms_key (str): The ARN of the KMS key that is used to encrypt the
+                 user code file (default: None).
+         """
+         self._current_job_name = self._generate_current_job_name(job_name=job_name)
+
+         if not submit_app:
+             raise ValueError("submit_app is required")
+
+         extended_inputs, extended_outputs = self._extend_processing_args(
+             inputs=inputs,
+             outputs=outputs,
+             submit_class=submit_class,
+             submit_jars=submit_jars,
+             submit_files=submit_files,
+             configuration=configuration,
+             spark_event_logs_s3_uri=spark_event_logs_s3_uri,
+         )
+
+         return super().run(
+             submit_app=submit_app,
+             inputs=extended_inputs,
+             outputs=extended_outputs,
+             arguments=arguments,
+             wait=wait,
+             logs=logs,
+             job_name=self._current_job_name,
+             experiment_config=experiment_config,
+             kms_key=kms_key,
+         )
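A matching hedged sketch of `run` for the Jar case, reusing the `spark_jar_processor` constructed in the sketch above (the jar path and class name are hypothetical):

    # Illustrative only: submit a Spark application packaged as a Jar.
    spark_jar_processor.run(
        submit_app="s3://my-bucket/code/spark-app.jar",
        submit_class="com.example.SparkApp",       # hypothetical main class
        submit_jars=["s3://my-bucket/code/lib/extra.jar"],
        arguments=["--input", "s3://my-bucket/raw/"],
        wait=True,
    )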
+
+     def _extend_processing_args(self, inputs, outputs, **kwargs):
+         """Extends inputs and outputs with the Spark submit class and dependencies."""
+         self.command = [_SparkProcessorBase._default_command]
+         if kwargs.get("submit_class"):
+             self.command.extend(["--class", kwargs.get("submit_class")])
+         else:
+             raise ValueError("submit_class is required")
+
+         if inputs is None:
+             inputs = []
+
+         # make a shallow copy of user inputs
+         extended_inputs = copy(inputs)
+
+         extended_inputs = self._handle_script_dependencies(
+             extended_inputs, kwargs.get("submit_jars"), FileType.JAR
+         )
+         extended_inputs = self._handle_script_dependencies(
+             extended_inputs, kwargs.get("submit_files"), FileType.FILE
+         )
+
+         return super()._extend_processing_args(extended_inputs, outputs, **kwargs)
+
+
+ class _HistoryServer:
+     """History server class that is responsible for starting the history server."""
+
+     _container_name = "history_server"
+     _entry_point = "smspark-history-server"
+     arg_event_logs_s3_uri = "event_logs_s3_uri"
+     arg_remote_domain_name = "remote_domain_name"
+
+     _history_server_args_format_map = {
+         arg_event_logs_s3_uri: "--event-logs-s3-uri {} ",
+         arg_remote_domain_name: "--remote-domain-name {} ",
+     }
+
+     def __init__(self, cli_args, image_uri, network_config):
+         self.cli_args = cli_args
+         self.image_uri = image_uri
+         self.network_config = network_config
+         self.run_history_server_command = self._get_run_history_server_cmd()
+
+     def run(self):
+         """Runs the history server."""
+         self.down()
+         logger.info("Starting history server...")
+         subprocess.Popen(
+             self.run_history_server_command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE
+         )
+
+     def down(self):
+         """Stops and removes the container."""
+         subprocess.call(["docker", "stop", self._container_name])
+         subprocess.call(["docker", "rm", self._container_name])
+         logger.info("History server terminated")
+
+     # This method belongs to _HistoryServer because _container_name (the app name)
+     # belongs to _HistoryServer. In the future, dynamically creating a new app name
+     # and picking an available port should also belong to _HistoryServer rather
+     # than PySparkProcessor.
+     def _get_run_history_server_cmd(self):
+         """Gets the history server command."""
+         env_options = ""
+         ser_cli_args = ""
+         for key, value in self.cli_args.items():
+             if key in self._history_server_args_format_map:
+                 ser_cli_args += self._history_server_args_format_map[key].format(value)
+             else:
+                 env_options += f"--env {key}={value} "
+
+         cmd = (
+             f"docker run {env_options.strip()} --name {self._container_name} "
+             f"{self.network_config} --entrypoint {self._entry_point} {self.image_uri} "
+             f"{ser_cli_args.strip()}"
+         )
+
+         return cmd
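To make the command template concrete, here is a trace of `_get_run_history_server_cmd` for a hypothetical argument set (the image URI, region, S3 path, and the pre-rendered network fragment are placeholders):

    # Keys found in _history_server_args_format_map become CLI flags;
    # all other keys become --env options on the docker run command.
    server = _HistoryServer(
        cli_args={
            "event_logs_s3_uri": "s3://my-bucket/spark-events/",
            "AWS_REGION": "us-east-1",
        },
        image_uri="111122223333.dkr.ecr.us-east-1.amazonaws.com/sagemaker-spark:latest",
        network_config="--network host",           # assumed pre-rendered CLI fragment
    )
    # server.run_history_server_command is now roughly:
    #   docker run --env AWS_REGION=us-east-1 --name history_server --network host
    #     --entrypoint smspark-history-server
    #     111122223333.dkr.ecr.us-east-1.amazonaws.com/sagemaker-spark:latest
    #     --event-logs-s3-uri s3://my-bucket/spark-events/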
+
+
+ class FileType(Enum):
+     """Enum of file types."""
+
+     JAR = 1
+     PYTHON = 2
+     FILE = 3
+
+
+ class SparkConfigUtils:
+     """Utility class for Spark configurations."""
+
+     _valid_configuration_keys = ["Classification", "Properties", "Configurations"]
+     _valid_configuration_classifications = [
+         "core-site",
+         "hadoop-env",
+         "hadoop-log4j",
+         "hive-env",
+         "hive-log4j",
+         "hive-exec-log4j",
+         "hive-site",
+         "spark-defaults",
+         "spark-env",
+         "spark-log4j",
+         "spark-hive-site",
+         "spark-metrics",
+         "yarn-env",
+         "yarn-site",
+         "export",
+     ]
+
+     @staticmethod
+     def validate_configuration(configuration: Dict):
+         """Validates the user-provided Hadoop/Spark/Hive configuration.
+
+         This ensures that the list or dictionary the user provides will serialize to
+         JSON matching the schema of EMR's application configuration.
+
+         Args:
+             configuration (dict or list): A dict (or list of dicts) that contains the
+                 configuration overrides to the default values. For more information,
+                 please visit:
+                 https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html
+         """
+         emr_configure_apps_url = (
+             "https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html"
+         )
+         if isinstance(configuration, dict):
+             keys = configuration.keys()
+             if "Classification" not in keys or "Properties" not in keys:
+                 raise ValueError(
+                     f"Missing one or more required keys in configuration dictionary "
+                     f"{configuration}. Please see {emr_configure_apps_url} for more information."
+                 )
+
+             for key in keys:
+                 if key not in SparkConfigUtils._valid_configuration_keys:
+                     raise ValueError(
+                         f"Invalid key: {key}. "
+                         f"Must be one of {SparkConfigUtils._valid_configuration_keys}. "
+                         f"Please see {emr_configure_apps_url} for more information."
+                     )
+                 if key == "Classification":
+                     if (
+                         configuration[key]
+                         not in SparkConfigUtils._valid_configuration_classifications
+                     ):
+                         raise ValueError(
+                             f"Invalid classification: {configuration[key]}. Must be one of "
+                             f"{SparkConfigUtils._valid_configuration_classifications}"
+                         )
+
+         if isinstance(configuration, list):
+             for item in configuration:
+                 SparkConfigUtils.validate_configuration(item)
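A quick sketch of what passes and fails this validation (property values are illustrative):

    # Passes: recognized keys and a recognized classification.
    SparkConfigUtils.validate_configuration(
        {"Classification": "spark-defaults", "Properties": {"spark.executor.cores": "2"}}
    )

    # Raises ValueError: "bogus-site" is not a recognized classification.
    SparkConfigUtils.validate_configuration(
        {"Classification": "bogus-site", "Properties": {}}
    )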
+
+     # TODO (guoqioa@): method only checks urlparse scheme, need to perform deep s3 validation
+     @staticmethod
+     def validate_s3_uri(spark_output_s3_path):
+         """Validate whether the URI uses an S3 scheme.
+
+         In the future, this validation will perform deeper S3 validation.
+
+         Args:
+             spark_output_s3_path (str): The URI of the Spark output S3 path.
+         """
+         if is_pipeline_variable(spark_output_s3_path):
+             return
+
+         if urlparse(spark_output_s3_path).scheme != "s3":
+             raise ValueError(
+                 f"Invalid s3 path: {spark_output_s3_path}. Please enter something like "
+                 "s3://bucket-name/folder-name"
+             )
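Behavior sketch for the scheme check (paths are illustrative):

    SparkConfigUtils.validate_s3_uri("s3://my-bucket/spark-events/")  # passes silently
    SparkConfigUtils.validate_s3_uri("/tmp/spark-events")             # raises ValueError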