sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (363)
  1. sagemaker/core/__init__.py +16 -0
  2. sagemaker/core/_studio.py +116 -0
  3. sagemaker/core/_version.py +11 -0
  4. sagemaker/core/accept_types.py +131 -0
  5. sagemaker/core/analytics.py +744 -0
  6. sagemaker/core/apiutils/__init__.py +13 -0
  7. sagemaker/core/apiutils/_base_types.py +228 -0
  8. sagemaker/core/apiutils/_boto_functions.py +130 -0
  9. sagemaker/core/apiutils/_utils.py +34 -0
  10. sagemaker/core/base_deserializers.py +35 -0
  11. sagemaker/core/base_serializers.py +35 -0
  12. sagemaker/core/clarify/__init__.py +2898 -0
  13. sagemaker/core/collection.py +467 -0
  14. sagemaker/core/common_utils.py +2281 -0
  15. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  16. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  17. sagemaker/core/config/__init__.py +181 -0
  18. sagemaker/core/config/config.py +238 -0
  19. sagemaker/core/config/config_manager.py +595 -0
  20. sagemaker/core/config/config_schema.py +1220 -0
  21. sagemaker/core/config/config_utils.py +297 -0
  22. {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
  23. sagemaker/core/constants.py +73 -0
  24. sagemaker/core/content_types.py +137 -0
  25. sagemaker/core/debugger/__init__.py +39 -0
  26. sagemaker/core/debugger/debugger.py +945 -0
  27. sagemaker/core/debugger/framework_profile.py +292 -0
  28. sagemaker/core/debugger/metrics_config.py +468 -0
  29. sagemaker/core/debugger/profiler.py +42 -0
  30. sagemaker/core/debugger/profiler_config.py +190 -0
  31. sagemaker/core/debugger/profiler_constants.py +40 -0
  32. sagemaker/core/debugger/utils.py +148 -0
  33. sagemaker/core/deprecations.py +254 -0
  34. sagemaker/core/deserializers/__init__.py +10 -0
  35. sagemaker/core/deserializers/base.py +424 -0
  36. sagemaker/core/deserializers/implementations.py +157 -0
  37. sagemaker/core/drift_check_baselines.py +106 -0
  38. sagemaker/core/enums.py +51 -0
  39. sagemaker/core/environment_variables.py +101 -0
  40. sagemaker/core/exceptions.py +108 -0
  41. sagemaker/core/experiments/__init__.py +53 -0
  42. sagemaker/core/experiments/_api_types.py +251 -0
  43. sagemaker/core/experiments/_environment.py +124 -0
  44. sagemaker/core/experiments/_helper.py +294 -0
  45. sagemaker/core/experiments/_metrics.py +333 -0
  46. sagemaker/core/experiments/_run_context.py +58 -0
  47. sagemaker/core/experiments/_utils.py +216 -0
  48. sagemaker/core/experiments/experiment.py +244 -0
  49. sagemaker/core/experiments/run.py +970 -0
  50. sagemaker/core/experiments/trial.py +296 -0
  51. sagemaker/core/experiments/trial_component.py +387 -0
  52. sagemaker/core/explainer/__init__.py +24 -0
  53. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  54. sagemaker/core/explainer/explainer_config.py +44 -0
  55. sagemaker/core/fw_utils.py +1176 -0
  56. sagemaker/core/git_utils.py +349 -0
  57. sagemaker/core/helper/pipeline_variable.py +82 -0
  58. sagemaker/core/helper/session_helper.py +2965 -0
  59. sagemaker/core/huggingface/__init__.py +29 -0
  60. sagemaker/core/huggingface/llm_utils.py +150 -0
  61. sagemaker/core/huggingface/processing.py +139 -0
  62. sagemaker/core/huggingface/training_compiler/config.py +167 -0
  63. sagemaker/core/hyperparameters.py +172 -0
  64. sagemaker/core/image_retriever/__init__.py +3 -0
  65. sagemaker/core/image_retriever/image_retriever.py +640 -0
  66. sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
  67. sagemaker/core/image_retriever/test.py +7 -0
  68. sagemaker/core/image_uri_config/__init__.py +13 -0
  69. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  70. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  71. sagemaker/core/image_uri_config/chainer.json +104 -0
  72. sagemaker/core/image_uri_config/clarify.json +39 -0
  73. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  74. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  75. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  76. sagemaker/core/image_uri_config/debugger.json +34 -0
  77. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  78. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  79. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  80. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  81. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  82. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  83. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  84. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  85. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
  86. sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
  87. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  88. sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
  89. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  90. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  91. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  92. sagemaker/core/image_uri_config/huggingface.json +2138 -0
  93. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  94. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  95. sagemaker/core/image_uri_config/image-classification.json +50 -0
  96. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  97. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  98. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  99. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  100. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  101. sagemaker/core/image_uri_config/kmeans.json +50 -0
  102. sagemaker/core/image_uri_config/knn.json +50 -0
  103. sagemaker/core/image_uri_config/lda.json +26 -0
  104. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  105. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  106. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  107. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  108. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  109. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  110. sagemaker/core/image_uri_config/ntm.json +50 -0
  111. sagemaker/core/image_uri_config/object-detection.json +50 -0
  112. sagemaker/core/image_uri_config/object2vec.json +50 -0
  113. sagemaker/core/image_uri_config/pca.json +50 -0
  114. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  115. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  116. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  117. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  118. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  119. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  120. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  121. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  122. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  123. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  124. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
  125. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  126. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  127. sagemaker/core/image_uri_config/sklearn.json +446 -0
  128. sagemaker/core/image_uri_config/spark.json +280 -0
  129. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  130. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  131. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  132. sagemaker/core/image_uri_config/vw.json +25 -0
  133. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  134. sagemaker/core/image_uri_config/xgboost.json +888 -0
  135. sagemaker/core/image_uris.py +810 -0
  136. sagemaker/core/inference_config.py +144 -0
  137. sagemaker/core/inference_recommender/__init__.py +18 -0
  138. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  139. sagemaker/core/inputs.py +366 -0
  140. sagemaker/core/instance_group.py +61 -0
  141. sagemaker/core/instance_types.py +164 -0
  142. sagemaker/core/instance_types_gpu_info.py +43 -0
  143. sagemaker/core/interactive_apps/__init__.py +41 -0
  144. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  145. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  146. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  147. sagemaker/core/iterators.py +186 -0
  148. sagemaker/core/job.py +380 -0
  149. sagemaker/core/jumpstart/__init__.py +156 -0
  150. sagemaker/core/jumpstart/accessors.py +390 -0
  151. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  152. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  153. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  154. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  155. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  156. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  157. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  158. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  159. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  160. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  161. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  162. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  163. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  164. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  165. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  166. sagemaker/core/jumpstart/cache.py +663 -0
  167. sagemaker/core/jumpstart/configs.py +50 -0
  168. sagemaker/core/jumpstart/constants.py +198 -0
  169. sagemaker/core/jumpstart/deserializers.py +81 -0
  170. sagemaker/core/jumpstart/document.py +76 -0
  171. sagemaker/core/jumpstart/enums.py +168 -0
  172. sagemaker/core/jumpstart/exceptions.py +236 -0
  173. sagemaker/core/jumpstart/factory/utils.py +833 -0
  174. sagemaker/core/jumpstart/filters.py +597 -0
  175. sagemaker/core/jumpstart/hub/__init__.py +0 -0
  176. sagemaker/core/jumpstart/hub/constants.py +16 -0
  177. sagemaker/core/jumpstart/hub/hub.py +291 -0
  178. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  179. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  180. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  181. sagemaker/core/jumpstart/hub/types.py +35 -0
  182. sagemaker/core/jumpstart/hub/utils.py +260 -0
  183. sagemaker/core/jumpstart/models.py +499 -0
  184. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  185. sagemaker/core/jumpstart/parameters.py +20 -0
  186. sagemaker/core/jumpstart/payload_utils.py +239 -0
  187. sagemaker/core/jumpstart/region_config.json +163 -0
  188. sagemaker/core/jumpstart/search.py +171 -0
  189. sagemaker/core/jumpstart/serializers.py +81 -0
  190. sagemaker/core/jumpstart/session_utils.py +234 -0
  191. sagemaker/core/jumpstart/types.py +3044 -0
  192. sagemaker/core/jumpstart/utils.py +1731 -0
  193. sagemaker/core/jumpstart/validators.py +257 -0
  194. sagemaker/core/lambda_helper.py +312 -0
  195. sagemaker/core/lineage/__init__.py +42 -0
  196. sagemaker/core/lineage/_api_types.py +239 -0
  197. sagemaker/core/lineage/_utils.py +49 -0
  198. sagemaker/core/lineage/action.py +345 -0
  199. sagemaker/core/lineage/artifact.py +646 -0
  200. sagemaker/core/lineage/association.py +190 -0
  201. sagemaker/core/lineage/context.py +505 -0
  202. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  203. sagemaker/core/lineage/query.py +732 -0
  204. sagemaker/core/lineage/visualizer.py +346 -0
  205. sagemaker/core/local/__init__.py +18 -0
  206. sagemaker/core/local/data.py +413 -0
  207. sagemaker/core/local/entities.py +678 -0
  208. sagemaker/core/local/exceptions.py +17 -0
  209. sagemaker/core/local/image.py +1243 -0
  210. sagemaker/core/local/local_session.py +739 -0
  211. sagemaker/core/local/utils.py +245 -0
  212. sagemaker/core/logs.py +181 -0
  213. sagemaker/core/metadata_properties.py +56 -0
  214. sagemaker/core/metric_definitions.py +91 -0
  215. sagemaker/core/mlflow/__init__.py +38 -0
  216. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  217. sagemaker/core/model_card/__init__.py +26 -0
  218. sagemaker/core/model_life_cycle.py +51 -0
  219. sagemaker/core/model_metrics.py +160 -0
  220. sagemaker/core/model_monitor/__init__.py +66 -0
  221. sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
  222. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  223. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  224. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  225. sagemaker/core/model_monitor/dataset_format.py +102 -0
  226. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  227. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  228. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  229. sagemaker/core/model_monitor/utils.py +793 -0
  230. sagemaker/core/model_registry.py +480 -0
  231. sagemaker/core/model_uris.py +97 -0
  232. sagemaker/core/modules/__init__.py +19 -0
  233. sagemaker/core/modules/configs.py +226 -0
  234. sagemaker/core/modules/constants.py +37 -0
  235. sagemaker/core/modules/distributed.py +182 -0
  236. sagemaker/core/modules/local_core/__init__.py +0 -0
  237. sagemaker/core/modules/local_core/local_container.py +605 -0
  238. sagemaker/core/modules/templates.py +83 -0
  239. sagemaker/core/modules/train/__init__.py +14 -0
  240. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  241. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  242. sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
  243. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  244. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  245. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  246. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  247. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  248. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  249. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  250. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  251. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  252. sagemaker/core/modules/types.py +19 -0
  253. sagemaker/core/modules/utils.py +194 -0
  254. sagemaker/core/network.py +185 -0
  255. sagemaker/core/parameter.py +173 -0
  256. sagemaker/core/payloads.py +185 -0
  257. sagemaker/core/processing.py +1597 -0
  258. sagemaker/core/remote_function/__init__.py +19 -0
  259. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  260. sagemaker/core/remote_function/client.py +1285 -0
  261. sagemaker/core/remote_function/core/__init__.py +0 -0
  262. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  263. sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
  264. sagemaker/core/remote_function/core/serialization.py +422 -0
  265. sagemaker/core/remote_function/core/stored_function.py +226 -0
  266. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  267. sagemaker/core/remote_function/errors.py +104 -0
  268. sagemaker/core/remote_function/invoke_function.py +172 -0
  269. sagemaker/core/remote_function/job.py +2140 -0
  270. sagemaker/core/remote_function/logging_config.py +38 -0
  271. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  272. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  273. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  274. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  275. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  276. sagemaker/core/remote_function/spark_config.py +149 -0
  277. sagemaker/core/resource_requirements.py +168 -0
  278. {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
  279. sagemaker/core/s3/__init__.py +41 -0
  280. sagemaker/core/s3/client.py +367 -0
  281. sagemaker/core/s3/utils.py +175 -0
  282. sagemaker/core/script_uris.py +93 -0
  283. sagemaker/core/serializers/__init__.py +11 -0
  284. sagemaker/core/serializers/base.py +510 -0
  285. sagemaker/core/serializers/implementations.py +159 -0
  286. sagemaker/core/serializers/utils.py +223 -0
  287. sagemaker/core/serverless_inference_config.py +63 -0
  288. sagemaker/core/session_settings.py +55 -0
  289. sagemaker/core/shapes/__init__.py +3 -0
  290. sagemaker/core/shapes/model_card_shapes.py +159 -0
  291. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
  292. sagemaker/core/spark/__init__.py +16 -0
  293. sagemaker/core/spark/defaults.py +16 -0
  294. sagemaker/core/spark/processing.py +1380 -0
  295. sagemaker/core/telemetry/__init__.py +23 -0
  296. sagemaker/core/telemetry/constants.py +84 -0
  297. sagemaker/core/telemetry/telemetry_logging.py +284 -0
  298. sagemaker/core/tools/__init__.py +1 -0
  299. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  300. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  301. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  302. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  303. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  304. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  305. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  306. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  307. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  308. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  309. sagemaker/core/training/__init__.py +14 -0
  310. sagemaker/core/training/configs.py +333 -0
  311. sagemaker/core/training/constants.py +37 -0
  312. sagemaker/core/training/utils.py +77 -0
  313. sagemaker/core/training_compiler/__init__.py +16 -0
  314. sagemaker/core/training_compiler/config.py +197 -0
  315. sagemaker/core/training_compiler_config.py +197 -0
  316. sagemaker/core/transformer.py +793 -0
  317. sagemaker/core/user_agent.py +76 -0
  318. sagemaker/core/utilities/__init__.py +24 -0
  319. sagemaker/core/utilities/cache.py +169 -0
  320. sagemaker/core/utilities/search_expression.py +133 -0
  321. sagemaker/core/utils/__init__.py +48 -0
  322. sagemaker/core/utils/code_injection/__init__.py +0 -0
  323. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  324. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
  325. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  326. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  327. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  328. {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
  329. sagemaker/core/workflow/__init__.py +152 -0
  330. sagemaker/core/workflow/conditions.py +313 -0
  331. sagemaker/core/workflow/entities.py +58 -0
  332. sagemaker/core/workflow/execution_variables.py +89 -0
  333. sagemaker/core/workflow/functions.py +193 -0
  334. sagemaker/core/workflow/parameters.py +222 -0
  335. sagemaker/core/workflow/pipeline_context.py +394 -0
  336. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  337. sagemaker/core/workflow/properties.py +285 -0
  338. sagemaker/core/workflow/step_outputs.py +65 -0
  339. sagemaker/core/workflow/utilities.py +507 -0
  340. sagemaker/lineage/__init__.py +33 -0
  341. sagemaker/lineage/action.py +28 -0
  342. sagemaker/lineage/artifact.py +28 -0
  343. sagemaker/lineage/context.py +28 -0
  344. sagemaker/lineage/lineage_trial_component.py +28 -0
  345. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
  346. sagemaker_core-2.1.1.dist-info/RECORD +355 -0
  347. sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
  348. sagemaker_core/__init__.py +0 -4
  349. sagemaker_core/_version.py +0 -3
  350. sagemaker_core/helper/session_helper.py +0 -769
  351. sagemaker_core/resources/__init__.py +0 -1
  352. sagemaker_core/shapes/__init__.py +0 -1
  353. sagemaker_core/tools/__init__.py +0 -1
  354. sagemaker_core-1.0.47.dist-info/RECORD +0 -35
  355. sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
  356. {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
  357. {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
  358. {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  359. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  360. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  361. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  362. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
  363. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,51 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Defines enum values."""
14
+
15
+ from __future__ import absolute_import
16
+
17
+ import logging
18
+ from enum import Enum
19
+
20
+
21
# Shared "sagemaker" logger (same object as any other getLogger("sagemaker") call).
LOGGER = logging.getLogger(name="sagemaker")
22
+
23
+
24
class EndpointType(Enum):
    """Kinds of SageMaker endpoint."""

    # Amazon SageMaker Model Based Endpoint
    MODEL_BASED = "ModelBased"
    # Amazon SageMaker Inference Component Based Endpoint
    INFERENCE_COMPONENT_BASED = "InferenceComponentBased"
32
+
33
class RoutingStrategy(Enum):
    """Strategy an endpoint uses to route incoming HTTPS traffic to its instances."""

    RANDOM = "RANDOM"
    """Each request is sent to a randomly chosen instance.
    """

    LEAST_OUTSTANDING_REQUESTS = "LEAST_OUTSTANDING_REQUESTS"
    """Requests are sent to the specific instances that have more capacity
    to process them.
    """
43
+
44
+
45
class Tag(str, Enum):
    """Tag keys the SDK applies to models.

    The enum mixes in ``str``, so every member compares equal to its raw key
    string and can be passed anywhere a plain string key is expected.
    """

    # All keys are namespaced under the "sagemaker-sdk:" prefix.
    OPTIMIZATION_JOB_NAME = "sagemaker-sdk:optimization-job-name"
    SPECULATIVE_DRAFT_MODEL_PROVIDER = "sagemaker-sdk:speculative-draft-model-provider"
    FINE_TUNING_MODEL_PATH = "sagemaker-sdk:fine-tuning-model-path"
    FINE_TUNING_JOB_NAME = "sagemaker-sdk:fine-tuning-job-name"
@@ -0,0 +1,101 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Accessors to retrieve environment variables for hosting containers."""
14
+
15
+ from __future__ import absolute_import
16
+
17
+ import logging
18
+ from typing import Dict, Optional
19
+
20
+ from sagemaker.core.jumpstart import utils as jumpstart_utils
21
+ from sagemaker.core.jumpstart import artifacts
22
+ from sagemaker.core.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
23
+ from sagemaker.core.jumpstart.enums import JumpStartModelType, JumpStartScriptScope
24
+ from sagemaker.core.helper.session_helper import Session
25
+
26
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
27
+
28
+
29
def retrieve_default(
    region: Optional[str] = None,
    model_id: Optional[str] = None,
    model_version: Optional[str] = None,
    hub_arn: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    include_aws_sdk_env_vars: bool = True,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    instance_type: Optional[str] = None,
    script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
    config_name: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> Dict[str, str]:
    """Retrieve the default container environment variables for a JumpStart model.

    The input is validated here, and the lookup is delegated to
    ``artifacts._retrieve_default_environment_variables``.

    Args:
        region (Optional[str]): The AWS Region for which to retrieve the default
            environment variables. (Default: None).
        model_id (Optional[str]): The model ID of the model for which to retrieve the
            default environment variables. (Default: None).
        model_version (Optional[str]): The version of the model for which to retrieve
            the default environment variables. (Default: None).
        hub_arn (Optional[str]): The ARN of the SageMaker Hub from which to retrieve
            model details. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False,
            an exception is raised if the script used by this version of the model
            has dependencies with known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated models should be
            tolerated (exception not raised). False if these models should raise an
            exception. (Default: False).
        include_aws_sdk_env_vars (bool): True if environment variables for low-level
            AWS API calls should be included. The ``Model`` class of the SageMaker
            Python SDK inserts the environment variables required when making the
            low-level AWS API call. (Default: True).
        sagemaker_session (sagemaker.core.helper.session_helper.Session): The
            SageMaker Session object used for SageMaker interactions. If it is not
            supplied, the shared module-level session
            ``sagemaker.core.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION``
            is used. That session is evaluated once at import time, not per call.
        instance_type (Optional[str]): An instance type to optionally supply in order
            to get environment variables specific to that instance type.
            (Default: None).
        script (JumpStartScriptScope): The JumpStart script for which to retrieve
            environment variables. (Default: JumpStartScriptScope.INFERENCE).
        config_name (Optional[str]): Name of the JumpStart Model config to apply.
            (Default: None).
        model_type (JumpStartModelType): The type of the model. It can be an open
            weights model or a proprietary model.
            (Default: JumpStartModelType.OPEN_WEIGHTS).

    Returns:
        Dict[str, str]: The environment variables to use for the model.

    Raises:
        ValueError: If ``jumpstart_utils.is_jumpstart_model_input(model_id,
            model_version)`` rejects the supplied identifiers. Errors raised by the
            underlying artifact lookup propagate unchanged.
    """
    # Reject non-JumpStart input before any artifact or session lookup happens.
    if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
        raise ValueError(
            "Must specify JumpStart `model_id` and `model_version` "
            "when retrieving environment variables."
        )

    return artifacts._retrieve_default_environment_variables(
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        include_aws_sdk_env_vars=include_aws_sdk_env_vars,
        sagemaker_session=sagemaker_session,
        instance_type=instance_type,
        script=script,
        config_name=config_name,
        model_type=model_type,
    )
@@ -0,0 +1,108 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Custom exception classes for Sagemaker SDK"""
14
+ from __future__ import absolute_import
15
+
16
+
17
class UnexpectedStatusException(ValueError):
    """Raised when a resource is in a status that does not allow execution to continue.

    Attributes:
        allowed_statuses: The statuses that would have been accepted.
        actual_status: The status that was actually observed.
    """

    def __init__(self, message, allowed_statuses, actual_status):
        self.allowed_statuses = allowed_statuses
        self.actual_status = actual_status
        super().__init__(message)
24
+
25
+
26
class CapacityError(UnexpectedStatusException):
    """Raised when a resource reaches an unexpected status and fails with a CapacityError reason.

    It adds nothing to ``UnexpectedStatusException`` other than a distinct type
    that callers can catch.
    """
28
+
29
+
30
class AsyncInferenceError(Exception):
    """Base class for the Async Inference exceptions.

    Subclasses override ``fmt``. The keyword arguments given to ``__init__`` are
    substituted into it with ``str.format`` to build the message, and they are
    kept on ``self.kwargs``.
    """

    fmt = "An unspecified error occurred"

    def __init__(self, **kwargs):
        rendered = self.fmt.format(**kwargs)
        Exception.__init__(self, rendered)
        self.kwargs = kwargs
39
+
40
+
41
class ObjectNotExistedError(AsyncInferenceError):
    """Raised when no Amazon S3 object exists at the given path."""

    fmt = "Object not exist at {output_path}. {message}"

    def __init__(self, message, output_path):
        # Insertion order matches the original keyword order, so ``self.kwargs`` is unchanged.
        fields = {"message": message, "output_path": output_path}
        super().__init__(**fields)
48
+
49
+
50
class PollingTimeoutError(AsyncInferenceError):
    """Raised when polling ran longer than expected and no result object is in the Amazon S3 bucket yet."""

    fmt = "No result at {output_path} after polling for {seconds} seconds. {message}"

    def __init__(self, message, output_path, seconds):
        # Insertion order matches the original keyword order, so ``self.kwargs`` is unchanged.
        fields = {"message": message, "output_path": output_path, "seconds": seconds}
        super().__init__(**fields)
57
+
58
+
59
class UnexpectedClientError(AsyncInferenceError):
    """Raised when a ClientError carries an error code that was not expected."""

    fmt = "Encountered unexpected client error: {message}"

    def __init__(self, message):
        fields = {"message": message}
        super().__init__(**fields)
66
+
67
+
68
class AutoMLStepInvalidModeError(Exception):
    """Raised when the AutoML mode passed into AutoMLStep is invalid.

    ``fmt`` has no placeholders, so any keyword arguments are ignored by the
    message. They are still kept on ``self.kwargs``.
    """

    fmt = (
        "Mode in AutoMLJobConfig must be defined for AutoMLStep. "
        "AutoMLStep currently only supports ENSEMBLING mode"
    )

    def __init__(self, **kwargs):
        rendered = self.fmt.format(**kwargs)
        Exception.__init__(self, rendered)
        self.kwargs = kwargs
80
+
81
+
82
class AsyncInferenceModelError(AsyncInferenceError):
    """Raised when the model returns errors for failed requests."""

    fmt = "Model returned error: {message} "

    def __init__(self, message):
        fields = {"message": message}
        super().__init__(**fields)
89
+
90
+
91
class ModelStreamError(Exception):
    """Raised when an invoke_endpoint_with_response_stream response returns ModelStreamError.

    Attributes:
        message: The error message as given.
        code: The optional error code. When it is not None, it is appended to the
            exception text as `` (Code: <code>)``.
    """

    def __init__(self, message="An error occurred", code=None):
        self.message = message
        self.code = code
        rendered = message if code is None else f"{message} (Code: {code})"
        super().__init__(rendered)
101
+
102
+
103
class InternalStreamFailure(Exception):
    """Raised when an invoke_endpoint_with_response_stream response returns InternalStreamFailure.

    Attributes:
        message: The error message, which is also the exception text.
    """

    def __init__(self, message="An error occurred"):
        super().__init__(message)
        self.message = message
@@ -0,0 +1,53 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """SageMaker Experiments module for tracking experiments, trials, and runs."""
14
+ from __future__ import absolute_import
15
+
16
+ # Lazy imports to avoid circular dependencies during package initialization
17
+ # Users should import directly from the specific modules:
18
+ # from sagemaker.core.experiments.experiment import Experiment
19
+ # from sagemaker.core.experiments.run import Run
20
+ # etc.
21
+
22
# Public names of this package. None of them is imported eagerly; each is
# resolved on first access by the module-level ``__getattr__`` below.
__all__ = [
    "Experiment",
    "Run",
    "_RunContext",
    "_Trial",
    "_TrialComponent",
]
29
+
30
+
31
def __getattr__(name):
    """Resolve the exported experiment classes lazily on first attribute access.

    Each submodule is imported only when its class is requested, which keeps
    package initialization free of circular imports.

    Args:
        name (str): The attribute being looked up on this package.

    Returns:
        The class object matching ``name``.

    Raises:
        AttributeError: If ``name`` is not one of the lazily exported classes.
    """
    if name == "_TrialComponent":
        from sagemaker.core.experiments.trial_component import _TrialComponent

        return _TrialComponent
    if name == "_Trial":
        from sagemaker.core.experiments.trial import _Trial

        return _Trial
    if name == "_RunContext":
        from sagemaker.core.experiments._run_context import _RunContext

        return _RunContext
    if name == "Run":
        from sagemaker.core.experiments.run import Run

        return Run
    if name == "Experiment":
        from sagemaker.core.experiments.experiment import Experiment

        return Experiment
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
@@ -0,0 +1,251 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Contains API objects for SageMaker experiments."""
14
+ from __future__ import absolute_import
15
+
16
+ import enum
17
+ import numbers
18
+
19
+ from sagemaker.core.apiutils import _base_types
20
+
21
+
22
class TrialComponentMetricSummary(_base_types.ApiObject):
    """Summary statistics for a single metric logged by a trial component.

    Attributes:
        metric_name (str): The name of the metric.
        source_arn (str): The ARN of the source.
        time_stamp (datetime): When the metric value was last updated.
        max (float): The max value of the metric.
        min (float): The min value of the metric.
        last (float): The last value of the metric.
        count (float): The number of samples used to generate the metric.
        avg (float): The average value of the metric.
        std_dev (float): The standard deviation of the metric.
    """

    # Every field defaults to None on the class; values arrive through __init__.
    metric_name = source_arn = time_stamp = None
    max = min = last = None
    count = avg = std_dev = None

    def __init__(self, metric_name=None, source_arn=None, **kwargs):
        super().__init__(metric_name=metric_name, source_arn=source_arn, **kwargs)
51
+
52
+
53
class TrialComponentParameters(_base_types.ApiObject):
    """Converts trial component parameter values between boto and plain-dict form."""

    @classmethod
    def from_boto(cls, boto_dict, **kwargs):
        """Flatten a boto parameter mapping into ``{name: value}``.

        Args:
            boto_dict (dict): boto response dictionary whose values hold either a
                ``NumberValue`` or a ``StringValue`` entry.
            **kwargs: Arbitrary keyword arguments (unused).

        Returns:
            dict: Parameter name to value; ``NumberValue`` is preferred when both
            keys are present, and None is used when neither is.
        """
        return {
            name: entry.get("NumberValue", entry.get("StringValue", None))
            for name, entry in boto_dict.items()
        }

    @classmethod
    def to_boto(cls, parameters):
        """Wrap each parameter value in its boto value type.

        Args:
            parameters (TrialComponentParameters): Dictionary to convert.

        Returns:
            dict: ``{"NumberValue": value}`` for ``numbers.Number`` instances,
            otherwise ``{"StringValue": str(value)}``, keyed by parameter name.
        """
        return {
            name: (
                {"NumberValue": raw}
                if isinstance(raw, numbers.Number)
                else {"StringValue": str(raw)}
            )
            for name, raw in parameters.items()
        }
89
+
90
+
91
class TrialComponentArtifact(_base_types.ApiObject):
    """An input or output artifact attached to a trial component.

    Attributes:
        value (str): The artifact value.
        media_type (str): The media type.
    """

    value = media_type = None

    def __init__(self, value=None, media_type=None, **kwargs):
        super().__init__(value=value, media_type=media_type, **kwargs)
104
+
105
+
106
class _TrialComponentStatusType(enum.Enum):
    """The type of trial component status.

    Each member's value is the literal status string, identical to its name.
    """

    InProgress = "InProgress"
    Completed = "Completed"
    Failed = "Failed"
112
+
113
+
114
class TrialComponentStatus(_base_types.ApiObject):
    """Current status of a trial component.

    Attributes:
        primary_status (str): The status of a trial component.
        message (str): Status message.
    """

    primary_status = message = None

    def __init__(self, primary_status=None, message=None, **kwargs):
        super().__init__(primary_status=primary_status, message=message, **kwargs)
129
+
130
+
131
class TrialComponentSummary(_base_types.ApiObject):
    """Summary model of a trial component.

    Attributes:
        trial_component_name (str): Name of trial component.
        trial_component_arn (str): ARN of the trial component.
        display_name (str): Friendly display name in UI.
        source_arn (str): ARN of the trial component source.
        status (TrialComponentStatus): Status of the trial component;
            ``_custom_boto_types`` maps this field to ``TrialComponentStatus``.
        start_time (datetime): Start time.
        end_time (datetime): End time.
        creation_time (datetime): Creation time.
        created_by: The user who created the trial component.
            NOTE(review): presumably the raw boto user-context value rather
            than a plain ``str`` — confirm against the API response.
        last_modified_time (datetime): Date last modified.
        last_modified_by: The user who last modified the trial component.
            NOTE(review): presumably the raw boto user-context value; it is a
            user, not a ``datetime``.
    """

    # ``status`` is a single nested object (flag False), not a list.
    _custom_boto_types = {
        "status": (TrialComponentStatus, False),
    }
    trial_component_name = None
    trial_component_arn = None
    display_name = None
    source_arn = None
    status = None
    start_time = None
    end_time = None
    creation_time = None
    created_by = None
    last_modified_time = None
    last_modified_by = None
162
+
163
+
164
class TrialComponentSource(_base_types.ApiObject):
    """The job or resource a trial component originates from.

    Attributes:
        source_arn (str): The ARN of the source.
    """

    source_arn = None

    def __init__(self, source_arn=None, **kwargs):
        super().__init__(source_arn=source_arn, **kwargs)
175
+
176
+
177
class Parent(_base_types.ApiObject):
    """The trial, experiment and run that a trial component is associated with.

    Attributes:
        trial_name (str): Name of the trial.
        experiment_name (str): Name of the experiment.
        run_name (str): Name of the run.
    """

    # All three names default to None on the class.
    trial_name = experiment_name = run_name = None
189
+
190
+
191
class TrialComponentSearchResult(_base_types.ApiObject):
    """Summary model of a trial component search result.

    Attributes:
        trial_component_arn (str): ARN of the trial component.
        trial_component_name (str): Name of the trial component.
        display_name (str): Display name of the trial component for UI display.
        source (dict): The source of the trial component.
        status (dict): The status of the trial component.
        start_time (datetime): Start time.
        end_time (datetime): End time.
        creation_time (datetime): Creation time.
        created_by: The user who created the trial component.
            NOTE(review): presumably the raw boto user-context value rather
            than a plain ``str`` — confirm against the API response.
        last_modified_time (datetime): Date last modified.
        last_modified_by: The user who last modified the trial component.
            NOTE(review): presumably the raw boto user-context value; it is a
            user, not a ``datetime``.
        parameters (dict): The hyperparameters of the component.
        input_artifacts (dict): The input artifacts of the component.
        output_artifacts (dict): The output artifacts of the component.
        metrics (list): The metrics for the component.
        source_detail (dict): The source of the trial component.
        tags (list): The list of tags that are associated with the trial component.
        parents (list[Parent]): The parents of the trial component.
    """

    _custom_boto_types = {
        "parents": (Parent, True),  # parents is a collection (list) of Parent objects
    }
    trial_component_arn = None
    trial_component_name = None
    display_name = None
    source = None
    status = None
    start_time = None
    end_time = None
    creation_time = None
    created_by = None
    last_modified_time = None
    last_modified_by = None
    parameters = None
    input_artifacts = None
    output_artifacts = None
    metrics = None
    source_detail = None
    tags = None
    parents = None
236
+
237
+
238
class TrialSummary(_base_types.ApiObject):
    """Summary model of a trial.

    Attributes:
        trial_arn (str): The ARN of the trial.
        trial_name (str): The name of the trial.
        creation_time (datetime): When the trial was created.
        last_modified_time (datetime): When the trial was last modified.
    """

    # Identity fields, then timestamps; all default to None on the class.
    trial_arn = trial_name = None
    creation_time = last_modified_time = None
@@ -0,0 +1,124 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Contains the _RunEnvironment class."""
14
+ from __future__ import absolute_import
15
+
16
+ import enum
17
+ import json
18
+ import logging
19
+ import os
20
+
21
+ from sagemaker.core.helper.session_helper import Session
22
+ from sagemaker.core.experiments import trial_component
23
+ from sagemaker.core.common_utils import retry_with_backoff
24
+
25
# Environment variable whose presence marks a training job; its value is the job ARN.
TRAINING_JOB_ARN_ENV = "TRAINING_JOB_ARN"
# JSON config file whose "ProcessingJobArn" field identifies the processing job.
PROCESSING_JOB_CONFIG_PATH = "/opt/ml/config/processingjobconfig.json"
# Environment variable whose presence marks a transform job; its value is the job ARN.
TRANSFORM_JOB_ARN_ENV = "TRANSFORM_JOB_ARN"
# Passed to retry_with_backoff when looking up the job's trial component.
# NOTE(review): whether this counts total attempts or retries is defined by that helper.
MAX_RETRY_ATTEMPTS = 7

logger = logging.getLogger(__name__)
31
+
32
+
33
class _EnvironmentType(enum.Enum):
    """Kinds of SageMaker job whose source ARN can be read from the environment."""

    # enum.auto() numbers members from 1 in declaration order, i.e. 1, 2, 3.
    SageMakerTrainingJob = enum.auto()
    SageMakerProcessingJob = enum.auto()
    SageMakerTransformJob = enum.auto()
39
+
40
+
41
class _RunEnvironment(object):
    """Retrieves job specific data from the environment."""

    def __init__(self, environment_type: _EnvironmentType, source_arn: str):
        """Init for _RunEnvironment.

        Args:
            environment_type (_EnvironmentType): The environment type.
            source_arn (str): The ARN of the current job.
        """
        self.environment_type = environment_type
        self.source_arn = source_arn

    @classmethod
    def load(
        cls,
        training_job_arn_env: str = TRAINING_JOB_ARN_ENV,
        processing_job_config_path: str = PROCESSING_JOB_CONFIG_PATH,
        transform_job_arn_env: str = TRANSFORM_JOB_ARN_ENV,
    ):
        """Loads source arn of current job from environment.

        Sources are checked in this order, and the first match wins:
        the training job environment variable, the processing job config
        file, then the transform job environment variable.

        Args:
            training_job_arn_env (str): The environment key for training job ARN
                (default: `TRAINING_JOB_ARN`).
            processing_job_config_path (str): The processing job config path
                (default: `/opt/ml/config/processingjobconfig.json`).
            transform_job_arn_env (str): The environment key for transform job ARN
                (default: `TRANSFORM_JOB_ARN`).

        Returns:
            _RunEnvironment: Job data loaded from the environment. None if no
                supported job environment is detected.

        Raises:
            KeyError: If the processing job config file has no ``ProcessingJobArn``.
            ValueError: If the processing job config file is not valid JSON.
        """
        if training_job_arn_env in os.environ:
            environment_type = _EnvironmentType.SageMakerTrainingJob
            source_arn = os.environ.get(training_job_arn_env)
            return _RunEnvironment(environment_type, source_arn)
        if os.path.exists(processing_job_config_path):
            environment_type = _EnvironmentType.SageMakerProcessingJob
            # Read through a context manager so the file handle is closed
            # deterministically instead of leaking until garbage collection.
            with open(processing_job_config_path, encoding="utf-8") as config_file:
                source_arn = json.load(config_file)["ProcessingJobArn"]
            return _RunEnvironment(environment_type, source_arn)
        if transform_job_arn_env in os.environ:
            environment_type = _EnvironmentType.SageMakerTransformJob
            # TODO: need to update to get source_arn from config file once Transform side ready
            source_arn = os.environ.get(transform_job_arn_env)
            return _RunEnvironment(environment_type, source_arn)

        return None

    def get_trial_component(self, sagemaker_session: Session):
        """Retrieves the trial component from the job in the environment.

        Lists trial components whose source ARN is this job's ARN (lower-cased),
        loads the first match, and retries the whole lookup via
        ``retry_with_backoff`` with ``MAX_RETRY_ATTEMPTS``.

        Args:
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. It is passed through unchanged to
                ``_TrialComponent.list`` and ``_TrialComponent.load``.

        Returns:
            _TrialComponent: The trial component created from the job. None if
                no trial component is found, or if the lookup keeps failing
                (the error is logged rather than raised).
        """

        def _get_trial_component():
            summaries = list(
                trial_component._TrialComponent.list(
                    source_arn=self.source_arn.lower(), sagemaker_session=sagemaker_session
                )
            )
            if summaries:
                summary = summaries[0]
                return trial_component._TrialComponent.load(
                    trial_component_name=summary.trial_component_name,
                    sagemaker_session=sagemaker_session,
                )
            return None

        job_tc = None
        try:
            job_tc = retry_with_backoff(_get_trial_component, MAX_RETRY_ATTEMPTS)
        except Exception as ex:  # pylint: disable=broad-except
            # Best-effort lookup: a failure is logged and None is returned.
            logger.error(
                "Failed to get trial component in the current environment due to %s", str(ex)
            )
        return job_tc