sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (363)
  1. sagemaker/core/__init__.py +16 -0
  2. sagemaker/core/_studio.py +116 -0
  3. sagemaker/core/_version.py +11 -0
  4. sagemaker/core/accept_types.py +131 -0
  5. sagemaker/core/analytics.py +744 -0
  6. sagemaker/core/apiutils/__init__.py +13 -0
  7. sagemaker/core/apiutils/_base_types.py +228 -0
  8. sagemaker/core/apiutils/_boto_functions.py +130 -0
  9. sagemaker/core/apiutils/_utils.py +34 -0
  10. sagemaker/core/base_deserializers.py +35 -0
  11. sagemaker/core/base_serializers.py +35 -0
  12. sagemaker/core/clarify/__init__.py +2898 -0
  13. sagemaker/core/collection.py +467 -0
  14. sagemaker/core/common_utils.py +2281 -0
  15. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  16. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  17. sagemaker/core/config/__init__.py +181 -0
  18. sagemaker/core/config/config.py +238 -0
  19. sagemaker/core/config/config_manager.py +595 -0
  20. sagemaker/core/config/config_schema.py +1220 -0
  21. sagemaker/core/config/config_utils.py +297 -0
  22. {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
  23. sagemaker/core/constants.py +73 -0
  24. sagemaker/core/content_types.py +137 -0
  25. sagemaker/core/debugger/__init__.py +39 -0
  26. sagemaker/core/debugger/debugger.py +945 -0
  27. sagemaker/core/debugger/framework_profile.py +292 -0
  28. sagemaker/core/debugger/metrics_config.py +468 -0
  29. sagemaker/core/debugger/profiler.py +42 -0
  30. sagemaker/core/debugger/profiler_config.py +190 -0
  31. sagemaker/core/debugger/profiler_constants.py +40 -0
  32. sagemaker/core/debugger/utils.py +148 -0
  33. sagemaker/core/deprecations.py +254 -0
  34. sagemaker/core/deserializers/__init__.py +10 -0
  35. sagemaker/core/deserializers/base.py +424 -0
  36. sagemaker/core/deserializers/implementations.py +157 -0
  37. sagemaker/core/drift_check_baselines.py +106 -0
  38. sagemaker/core/enums.py +51 -0
  39. sagemaker/core/environment_variables.py +101 -0
  40. sagemaker/core/exceptions.py +108 -0
  41. sagemaker/core/experiments/__init__.py +53 -0
  42. sagemaker/core/experiments/_api_types.py +251 -0
  43. sagemaker/core/experiments/_environment.py +124 -0
  44. sagemaker/core/experiments/_helper.py +294 -0
  45. sagemaker/core/experiments/_metrics.py +333 -0
  46. sagemaker/core/experiments/_run_context.py +58 -0
  47. sagemaker/core/experiments/_utils.py +216 -0
  48. sagemaker/core/experiments/experiment.py +244 -0
  49. sagemaker/core/experiments/run.py +970 -0
  50. sagemaker/core/experiments/trial.py +296 -0
  51. sagemaker/core/experiments/trial_component.py +387 -0
  52. sagemaker/core/explainer/__init__.py +24 -0
  53. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  54. sagemaker/core/explainer/explainer_config.py +44 -0
  55. sagemaker/core/fw_utils.py +1176 -0
  56. sagemaker/core/git_utils.py +349 -0
  57. sagemaker/core/helper/pipeline_variable.py +82 -0
  58. sagemaker/core/helper/session_helper.py +2965 -0
  59. sagemaker/core/huggingface/__init__.py +29 -0
  60. sagemaker/core/huggingface/llm_utils.py +150 -0
  61. sagemaker/core/huggingface/processing.py +139 -0
  62. sagemaker/core/huggingface/training_compiler/config.py +167 -0
  63. sagemaker/core/hyperparameters.py +172 -0
  64. sagemaker/core/image_retriever/__init__.py +3 -0
  65. sagemaker/core/image_retriever/image_retriever.py +640 -0
  66. sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
  67. sagemaker/core/image_retriever/test.py +7 -0
  68. sagemaker/core/image_uri_config/__init__.py +13 -0
  69. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  70. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  71. sagemaker/core/image_uri_config/chainer.json +104 -0
  72. sagemaker/core/image_uri_config/clarify.json +39 -0
  73. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  74. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  75. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  76. sagemaker/core/image_uri_config/debugger.json +34 -0
  77. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  78. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  79. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  80. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  81. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  82. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  83. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  84. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  85. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
  86. sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
  87. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  88. sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
  89. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  90. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  91. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  92. sagemaker/core/image_uri_config/huggingface.json +2138 -0
  93. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  94. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  95. sagemaker/core/image_uri_config/image-classification.json +50 -0
  96. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  97. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  98. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  99. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  100. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  101. sagemaker/core/image_uri_config/kmeans.json +50 -0
  102. sagemaker/core/image_uri_config/knn.json +50 -0
  103. sagemaker/core/image_uri_config/lda.json +26 -0
  104. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  105. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  106. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  107. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  108. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  109. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  110. sagemaker/core/image_uri_config/ntm.json +50 -0
  111. sagemaker/core/image_uri_config/object-detection.json +50 -0
  112. sagemaker/core/image_uri_config/object2vec.json +50 -0
  113. sagemaker/core/image_uri_config/pca.json +50 -0
  114. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  115. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  116. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  117. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  118. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  119. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  120. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  121. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  122. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  123. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  124. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
  125. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  126. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  127. sagemaker/core/image_uri_config/sklearn.json +446 -0
  128. sagemaker/core/image_uri_config/spark.json +280 -0
  129. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  130. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  131. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  132. sagemaker/core/image_uri_config/vw.json +25 -0
  133. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  134. sagemaker/core/image_uri_config/xgboost.json +888 -0
  135. sagemaker/core/image_uris.py +810 -0
  136. sagemaker/core/inference_config.py +144 -0
  137. sagemaker/core/inference_recommender/__init__.py +18 -0
  138. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  139. sagemaker/core/inputs.py +366 -0
  140. sagemaker/core/instance_group.py +61 -0
  141. sagemaker/core/instance_types.py +164 -0
  142. sagemaker/core/instance_types_gpu_info.py +43 -0
  143. sagemaker/core/interactive_apps/__init__.py +41 -0
  144. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  145. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  146. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  147. sagemaker/core/iterators.py +186 -0
  148. sagemaker/core/job.py +380 -0
  149. sagemaker/core/jumpstart/__init__.py +156 -0
  150. sagemaker/core/jumpstart/accessors.py +390 -0
  151. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  152. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  153. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  154. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  155. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  156. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  157. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  158. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  159. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  160. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  161. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  162. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  163. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  164. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  165. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  166. sagemaker/core/jumpstart/cache.py +663 -0
  167. sagemaker/core/jumpstart/configs.py +50 -0
  168. sagemaker/core/jumpstart/constants.py +198 -0
  169. sagemaker/core/jumpstart/deserializers.py +81 -0
  170. sagemaker/core/jumpstart/document.py +76 -0
  171. sagemaker/core/jumpstart/enums.py +168 -0
  172. sagemaker/core/jumpstart/exceptions.py +236 -0
  173. sagemaker/core/jumpstart/factory/utils.py +833 -0
  174. sagemaker/core/jumpstart/filters.py +597 -0
  175. sagemaker/core/jumpstart/hub/__init__.py +0 -0
  176. sagemaker/core/jumpstart/hub/constants.py +16 -0
  177. sagemaker/core/jumpstart/hub/hub.py +291 -0
  178. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  179. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  180. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  181. sagemaker/core/jumpstart/hub/types.py +35 -0
  182. sagemaker/core/jumpstart/hub/utils.py +260 -0
  183. sagemaker/core/jumpstart/models.py +499 -0
  184. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  185. sagemaker/core/jumpstart/parameters.py +20 -0
  186. sagemaker/core/jumpstart/payload_utils.py +239 -0
  187. sagemaker/core/jumpstart/region_config.json +163 -0
  188. sagemaker/core/jumpstart/search.py +171 -0
  189. sagemaker/core/jumpstart/serializers.py +81 -0
  190. sagemaker/core/jumpstart/session_utils.py +234 -0
  191. sagemaker/core/jumpstart/types.py +3044 -0
  192. sagemaker/core/jumpstart/utils.py +1731 -0
  193. sagemaker/core/jumpstart/validators.py +257 -0
  194. sagemaker/core/lambda_helper.py +312 -0
  195. sagemaker/core/lineage/__init__.py +42 -0
  196. sagemaker/core/lineage/_api_types.py +239 -0
  197. sagemaker/core/lineage/_utils.py +49 -0
  198. sagemaker/core/lineage/action.py +345 -0
  199. sagemaker/core/lineage/artifact.py +646 -0
  200. sagemaker/core/lineage/association.py +190 -0
  201. sagemaker/core/lineage/context.py +505 -0
  202. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  203. sagemaker/core/lineage/query.py +732 -0
  204. sagemaker/core/lineage/visualizer.py +346 -0
  205. sagemaker/core/local/__init__.py +18 -0
  206. sagemaker/core/local/data.py +413 -0
  207. sagemaker/core/local/entities.py +678 -0
  208. sagemaker/core/local/exceptions.py +17 -0
  209. sagemaker/core/local/image.py +1243 -0
  210. sagemaker/core/local/local_session.py +739 -0
  211. sagemaker/core/local/utils.py +245 -0
  212. sagemaker/core/logs.py +181 -0
  213. sagemaker/core/metadata_properties.py +56 -0
  214. sagemaker/core/metric_definitions.py +91 -0
  215. sagemaker/core/mlflow/__init__.py +38 -0
  216. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  217. sagemaker/core/model_card/__init__.py +26 -0
  218. sagemaker/core/model_life_cycle.py +51 -0
  219. sagemaker/core/model_metrics.py +160 -0
  220. sagemaker/core/model_monitor/__init__.py +66 -0
  221. sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
  222. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  223. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  224. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  225. sagemaker/core/model_monitor/dataset_format.py +102 -0
  226. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  227. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  228. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  229. sagemaker/core/model_monitor/utils.py +793 -0
  230. sagemaker/core/model_registry.py +480 -0
  231. sagemaker/core/model_uris.py +97 -0
  232. sagemaker/core/modules/__init__.py +19 -0
  233. sagemaker/core/modules/configs.py +226 -0
  234. sagemaker/core/modules/constants.py +37 -0
  235. sagemaker/core/modules/distributed.py +182 -0
  236. sagemaker/core/modules/local_core/__init__.py +0 -0
  237. sagemaker/core/modules/local_core/local_container.py +605 -0
  238. sagemaker/core/modules/templates.py +83 -0
  239. sagemaker/core/modules/train/__init__.py +14 -0
  240. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  241. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  242. sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
  243. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  244. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  245. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  246. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  247. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  248. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  249. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  250. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  251. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  252. sagemaker/core/modules/types.py +19 -0
  253. sagemaker/core/modules/utils.py +194 -0
  254. sagemaker/core/network.py +185 -0
  255. sagemaker/core/parameter.py +173 -0
  256. sagemaker/core/payloads.py +185 -0
  257. sagemaker/core/processing.py +1597 -0
  258. sagemaker/core/remote_function/__init__.py +19 -0
  259. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  260. sagemaker/core/remote_function/client.py +1285 -0
  261. sagemaker/core/remote_function/core/__init__.py +0 -0
  262. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  263. sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
  264. sagemaker/core/remote_function/core/serialization.py +422 -0
  265. sagemaker/core/remote_function/core/stored_function.py +226 -0
  266. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  267. sagemaker/core/remote_function/errors.py +104 -0
  268. sagemaker/core/remote_function/invoke_function.py +172 -0
  269. sagemaker/core/remote_function/job.py +2140 -0
  270. sagemaker/core/remote_function/logging_config.py +38 -0
  271. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  272. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  273. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  274. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  275. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  276. sagemaker/core/remote_function/spark_config.py +149 -0
  277. sagemaker/core/resource_requirements.py +168 -0
  278. {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
  279. sagemaker/core/s3/__init__.py +41 -0
  280. sagemaker/core/s3/client.py +367 -0
  281. sagemaker/core/s3/utils.py +175 -0
  282. sagemaker/core/script_uris.py +93 -0
  283. sagemaker/core/serializers/__init__.py +11 -0
  284. sagemaker/core/serializers/base.py +510 -0
  285. sagemaker/core/serializers/implementations.py +159 -0
  286. sagemaker/core/serializers/utils.py +223 -0
  287. sagemaker/core/serverless_inference_config.py +63 -0
  288. sagemaker/core/session_settings.py +55 -0
  289. sagemaker/core/shapes/__init__.py +3 -0
  290. sagemaker/core/shapes/model_card_shapes.py +159 -0
  291. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
  292. sagemaker/core/spark/__init__.py +16 -0
  293. sagemaker/core/spark/defaults.py +16 -0
  294. sagemaker/core/spark/processing.py +1380 -0
  295. sagemaker/core/telemetry/__init__.py +23 -0
  296. sagemaker/core/telemetry/constants.py +84 -0
  297. sagemaker/core/telemetry/telemetry_logging.py +284 -0
  298. sagemaker/core/tools/__init__.py +1 -0
  299. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  300. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  301. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  302. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  303. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  304. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  305. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  306. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  307. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  308. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  309. sagemaker/core/training/__init__.py +14 -0
  310. sagemaker/core/training/configs.py +333 -0
  311. sagemaker/core/training/constants.py +37 -0
  312. sagemaker/core/training/utils.py +77 -0
  313. sagemaker/core/training_compiler/__init__.py +16 -0
  314. sagemaker/core/training_compiler/config.py +197 -0
  315. sagemaker/core/training_compiler_config.py +197 -0
  316. sagemaker/core/transformer.py +793 -0
  317. sagemaker/core/user_agent.py +76 -0
  318. sagemaker/core/utilities/__init__.py +24 -0
  319. sagemaker/core/utilities/cache.py +169 -0
  320. sagemaker/core/utilities/search_expression.py +133 -0
  321. sagemaker/core/utils/__init__.py +48 -0
  322. sagemaker/core/utils/code_injection/__init__.py +0 -0
  323. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  324. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
  325. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  326. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  327. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  328. {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
  329. sagemaker/core/workflow/__init__.py +152 -0
  330. sagemaker/core/workflow/conditions.py +313 -0
  331. sagemaker/core/workflow/entities.py +58 -0
  332. sagemaker/core/workflow/execution_variables.py +89 -0
  333. sagemaker/core/workflow/functions.py +193 -0
  334. sagemaker/core/workflow/parameters.py +222 -0
  335. sagemaker/core/workflow/pipeline_context.py +394 -0
  336. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  337. sagemaker/core/workflow/properties.py +285 -0
  338. sagemaker/core/workflow/step_outputs.py +65 -0
  339. sagemaker/core/workflow/utilities.py +507 -0
  340. sagemaker/lineage/__init__.py +33 -0
  341. sagemaker/lineage/action.py +28 -0
  342. sagemaker/lineage/artifact.py +28 -0
  343. sagemaker/lineage/context.py +28 -0
  344. sagemaker/lineage/lineage_trial_component.py +28 -0
  345. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
  346. sagemaker_core-2.1.1.dist-info/RECORD +355 -0
  347. sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
  348. sagemaker_core/__init__.py +0 -4
  349. sagemaker_core/_version.py +0 -3
  350. sagemaker_core/helper/session_helper.py +0 -769
  351. sagemaker_core/resources/__init__.py +0 -1
  352. sagemaker_core/shapes/__init__.py +0 -1
  353. sagemaker_core/tools/__init__.py +0 -1
  354. sagemaker_core-1.0.47.dist-info/RECORD +0 -35
  355. sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
  356. {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
  357. {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
  358. {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  359. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  360. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  361. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  362. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
  363. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,29 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """HuggingFace framework support for SageMaker."""
14
+ from __future__ import absolute_import
15
+
16
+ from sagemaker.core.huggingface.estimator import HuggingFace # noqa: F401
17
+ from sagemaker.core.huggingface.llm_utils import get_huggingface_llm_image_uri # noqa: F401
18
+ from sagemaker.core.huggingface.model import HuggingFaceModel, HuggingFacePredictor # noqa: F401
19
+ from sagemaker.core.huggingface.processing import HuggingFaceProcessor # noqa: F401
20
+ from sagemaker.core.huggingface.training_compiler.config import TrainingCompilerConfig # noqa: F401
21
+
22
+ __all__ = [
23
+ "HuggingFace",
24
+ "HuggingFaceModel",
25
+ "HuggingFacePredictor",
26
+ "HuggingFaceProcessor",
27
+ "TrainingCompilerConfig",
28
+ "get_huggingface_llm_image_uri",
29
+ ]
@@ -0,0 +1,150 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Functions for generating ECR image URIs for pre-built SageMaker Docker images."""
14
+ from __future__ import absolute_import
15
+
16
+ import os
17
+ from typing import Optional
18
+ import importlib.util
19
+
20
+ import urllib.request
21
+ from urllib.error import HTTPError, URLError
22
+ import json
23
+ from json import JSONDecodeError
24
+ import logging
25
+ from sagemaker.core import image_uris
26
+ from sagemaker.core.helper.session_helper import Session
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
def get_huggingface_llm_image_uri(
    backend: str,
    session: Optional[Session] = None,
    region: Optional[str] = None,
    version: Optional[str] = None,
) -> str:
    """Retrieves the image URI for HuggingFace LLM inference.

    Args:
        backend (str): The backend to use. Valid values are "huggingface",
            "huggingface-neuronx", "huggingface-tei", "huggingface-tei-cpu" and "lmi".
        session (Session): The SageMaker Session whose boto session supplies the
            region when ``region`` is not given. If both ``region`` and ``session``
            are None, a new ``Session()`` is created for this purpose. (Default: None).
        region (str): The AWS region to use for the image URI. (Default: None).
        version (str): The framework version for which to retrieve an image URI.
            If not set, version resolution is left to ``image_uris.retrieve``,
            except for the "lmi" backend, which defaults to "0.24.0". (Default: None).

    Returns:
        str: The image URI string.

    Raises:
        ValueError: If ``backend`` is not one of the supported values.
    """
    if region is None:
        # Derive the region from the provided session, or from a fresh default one.
        if session is None:
            session = Session()
        region = session.boto_session.region_name

    if backend == "huggingface":
        return image_uris.retrieve(
            "huggingface-llm",
            region=region,
            version=version,
            image_scope="inference",
        )
    if backend == "huggingface-neuronx":
        return image_uris.retrieve(
            "huggingface-llm-neuronx",
            region=region,
            version=version,
            image_scope="inference",
            inference_tool="neuronx",
        )
    if backend == "huggingface-tei":
        return image_uris.retrieve(
            "huggingface-tei",
            region=region,
            version=version,
            image_scope="inference",
        )
    if backend == "huggingface-tei-cpu":
        return image_uris.retrieve(
            "huggingface-tei-cpu",
            region=region,
            version=version,
            image_scope="inference",
        )
    if backend == "lmi":
        # The LMI backend uses the DJL DeepSpeed images and pins its own default version.
        version = version or "0.24.0"
        return image_uris.retrieve(framework="djl-deepspeed", region=region, version=version)
    raise ValueError("Unsupported backend: %s" % backend)
88
+
89
+
90
def get_huggingface_model_metadata(model_id: str, hf_hub_token: Optional[str] = None) -> dict:
    """Retrieves the json metadata of the HuggingFace Model via HuggingFace API.

    Args:
        model_id (str): The HuggingFace Model ID.
        hf_hub_token (str): The HuggingFace Hub Token needed for Private/Gated
            HuggingFace Models. When given, it is sent as a Bearer token. (Default: None).

    Returns:
        dict: The model metadata retrieved with the HuggingFace API.

    Raises:
        ValueError: If ``model_id`` is empty; if the API answers HTTP 401
            (missing/invalid credentials for a gated or private model); or if no
            (or empty) metadata could be retrieved. Other HTTP, network, timeout
            and JSON decoding errors are logged as warnings and surface as the
            final "Did not find model metadata" ValueError.
    """
    if not model_id:
        raise ValueError("Model ID is empty. Please provide a valid Model ID.")
    hf_model_metadata_url = f"https://huggingface.co/api/models/{model_id}"
    headers = {"Authorization": "Bearer " + hf_hub_token} if hf_hub_token else {}
    # Keep the URL string separate from the Request object so that the warning
    # below logs the actual URL rather than the Request object's repr.
    hf_request = urllib.request.Request(hf_model_metadata_url, None, headers)
    hf_model_metadata_json = None
    try:
        with urllib.request.urlopen(hf_request) as response:
            hf_model_metadata_json = json.load(response)
    except (HTTPError, URLError, TimeoutError, JSONDecodeError) as e:
        # Check the status code instead of the exception text, whose wording
        # depends on the reason phrase sent by the server.
        if isinstance(e, HTTPError) and e.code == 401:
            raise ValueError(
                "Trying to access a gated/private HuggingFace model without valid credentials. "
                "Please provide a HUGGING_FACE_HUB_TOKEN in env_vars"
            ) from e
        logger.warning(
            "Exception encountered while trying to retrieve HuggingFace model metadata %s. "
            "Details: %s",
            hf_model_metadata_url,
            e,
        )
    if not hf_model_metadata_json:
        raise ValueError(
            "Did not find model metadata for the following HuggingFace Model ID %s" % model_id
        )
    return hf_model_metadata_json
128
+
129
+
130
def download_huggingface_model_metadata(
    model_id: str, model_local_path: str, hf_hub_token: Optional[str] = None
) -> None:
    """Download a HuggingFace model snapshot from the Hub into a local directory.

    The work is delegated to ``huggingface_hub.snapshot_download``. Despite the
    function name, this fetches the model snapshot, not only its metadata.

    Args:
        model_id (str): The HuggingFace Model ID, passed as ``repo_id``.
        model_local_path (str): Local directory that receives the snapshot. It is
            created (including any missing parents) before downloading.
        hf_hub_token (str): Optional HuggingFace Hub token forwarded to
            ``snapshot_download``. (Default: None).

    Raises:
        ImportError: If the ``huggingface_hub`` package is not installed.
    """
    hub_spec = importlib.util.find_spec("huggingface_hub")
    if hub_spec is None:
        raise ImportError("Unable to import huggingface_hub, check if huggingface_hub is installed")

    # Imported lazily so huggingface_hub stays an optional dependency.
    from huggingface_hub import snapshot_download

    os.makedirs(model_local_path, exist_ok=True)
    logger.info("Downloading model %s from Hugging Face Hub to %s", model_id, model_local_path)
    snapshot_download(
        repo_id=model_id,
        local_dir=model_local_path,
        token=hf_hub_token,
    )
@@ -0,0 +1,139 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module contains code related to HuggingFace Processors which are used for Processing jobs.
14
+
15
+ These jobs let customers perform data pre-processing, post-processing, feature engineering,
16
+ data validation, and model evaluation and interpretation on SageMaker.
17
+ """
18
+ from __future__ import absolute_import
19
+
20
+ from typing import Union, Optional, List, Dict
21
+
22
+ from sagemaker.core.helper.session_helper import Session
23
+ from sagemaker.core.network import NetworkConfig
24
+ from sagemaker.core.processing import FrameworkProcessor
25
+ from sagemaker.core.huggingface.estimator import HuggingFace
26
+
27
+ from sagemaker.core.helper.pipeline_variable import PipelineVariable
28
+ from sagemaker.core.common_utils import format_tags, Tags
29
+
30
+
31
class HuggingFaceProcessor(FrameworkProcessor):
    """Handles Amazon SageMaker processing tasks for jobs using HuggingFace containers."""

    # Estimator class used by ``_create_estimator`` to build the underlying job.
    estimator_cls = HuggingFace

    def __init__(
        self,
        role: Optional[Union[str, PipelineVariable]] = None,
        instance_count: Optional[Union[int, PipelineVariable]] = None,
        instance_type: Optional[Union[str, PipelineVariable]] = None,
        transformers_version: Optional[str] = None,
        tensorflow_version: Optional[str] = None,
        pytorch_version: Optional[str] = None,
        py_version: str = "py36",
        image_uri: Optional[Union[str, PipelineVariable]] = None,
        command: Optional[List[str]] = None,
        volume_size_in_gb: Union[int, PipelineVariable] = 30,
        volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
        output_kms_key: Optional[Union[str, PipelineVariable]] = None,
        code_location: Optional[str] = None,
        max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
        base_job_name: Optional[str] = None,
        sagemaker_session: Optional[Session] = None,
        env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
        tags: Optional[Tags] = None,
        network_config: Optional[NetworkConfig] = None,
    ):
        """This processor executes a Python script in a HuggingFace execution environment.

        Unless ``image_uri`` is specified, the environment is an Amazon-built Docker container
        that executes functions defined in the supplied ``code`` Python script.

        The arguments have the same meaning as in ``FrameworkProcessor``, with the following
        exceptions.

        Args:
            transformers_version (str): Transformers version you want to use for
                executing your model training code. Defaults to ``None``. Required unless
                ``image_uri`` is provided. The current supported version is ``4.4.2``.
            tensorflow_version (str): TensorFlow version you want to use for
                executing your model training code. Defaults to ``None``. Required unless
                ``pytorch_version`` is provided. The current supported version is ``2.4.1``.
            pytorch_version (str): PyTorch version you want to use for
                executing your model training code. Defaults to ``None``. Required unless
                ``tensorflow_version`` is provided. The current supported version is ``1.6.0``.
            py_version (str): Python version you want to use for executing your model training
                code. Defaults to ``"py36"``. If using PyTorch, the current supported version
                is ``py36``. If using TensorFlow, the current supported version is ``py37``.

        .. tip::

            You can find additional parameters for initializing this class at
            :class:`~sagemaker.core.processing.FrameworkProcessor`.
        """
        # HuggingFace needs two extra framework versions beyond the single
        # ``framework_version`` slot of FrameworkProcessor; they are read back in
        # ``_create_estimator``.
        self.pytorch_version = pytorch_version
        self.tensorflow_version = tensorflow_version
        # Arguments are passed positionally, so their order must match
        # ``FrameworkProcessor.__init__``. ``transformers_version`` fills the
        # framework-version slot, which ``_create_estimator`` reads back as
        # ``self.framework_version`` (NOTE(review): confirm against FrameworkProcessor).
        super().__init__(
            self.estimator_cls,
            transformers_version,
            role,
            instance_count,
            instance_type,
            py_version,
            image_uri,
            command,
            volume_size_in_gb,
            volume_kms_key,
            output_kms_key,
            code_location,
            max_runtime_in_seconds,
            base_job_name,
            sagemaker_session,
            env,
            format_tags(tags),
            network_config,
        )

    def _create_estimator(
        self,
        entry_point="",
        source_dir=None,
        dependencies=None,
        git_config=None,
    ):
        """Override default estimator factory function for HuggingFace's different parameters

        HuggingFace estimators have 3 framework version parameters instead of one: The version for
        Transformers, PyTorch, and TensorFlow.

        Args:
            entry_point (str): Script to run inside the estimator container.
            source_dir (str): Directory containing ``entry_point`` and its helpers.
            dependencies (list): Extra paths shipped alongside the source.
            git_config (dict): Git settings used to fetch the source, if any.

        Returns:
            HuggingFace: An estimator with network isolation, the debugger hook and the
            profiler all disabled.
        """
        return self.estimator_cls(
            transformers_version=self.framework_version,
            tensorflow_version=self.tensorflow_version,
            pytorch_version=self.pytorch_version,
            py_version=self.py_version,
            entry_point=entry_point,
            source_dir=source_dir,
            dependencies=dependencies,
            git_config=git_config,
            code_location=self.code_location,
            enable_network_isolation=False,
            image_uri=self.image_uri,
            role=self.role,
            instance_count=self.instance_count,
            instance_type=self.instance_type,
            sagemaker_session=self.sagemaker_session,
            debugger_hook_config=False,
            disable_profiler=True,
        )
@@ -0,0 +1,167 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Configuration for the SageMaker Training Compiler."""
14
+ from __future__ import absolute_import
15
+ import logging
16
+ from typing import Union
17
+ from packaging.specifiers import SpecifierSet
18
+ from packaging.version import Version
19
+
20
+ from sagemaker.core.training_compiler.config import TrainingCompilerConfig as BaseConfig
21
+ from sagemaker.core.helper.pipeline_variable import PipelineVariable
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
class TrainingCompilerConfig(BaseConfig):
    """The SageMaker Training Compiler configuration class."""

    # Instance families accepted by the compiler. NOTE(review): not referenced in this
    # class; presumably consumed by ``BaseConfig.validate`` — confirm in the base class.
    SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "p3dn", "g4dn", "p4d", "g5"]
    # Instance types with EFA support. Using any other type together with the
    # ``pytorchxla`` distribution only triggers a performance warning in ``validate``.
    SUPPORTED_INSTANCE_TYPES_WITH_EFA = [
        "ml.g4dn.8xlarge",
        "ml.g4dn.12xlarge",
        "ml.g5.48xlarge",
        "ml.p3dn.24xlarge",
        "ml.p4d.24xlarge",
    ]

    def __init__(
        self,
        enabled: Union[bool, PipelineVariable] = True,
        debug: Union[bool, PipelineVariable] = False,
    ):
        """This class initializes a ``TrainingCompilerConfig`` instance.

        `Amazon SageMaker Training Compiler
        <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
        is a feature of SageMaker Training
        and speeds up training jobs by optimizing model execution graphs.

        You can compile Hugging Face models
        by passing the object of this configuration class to the ``compiler_config``
        parameter of the :class:`~sagemaker.core.huggingface.HuggingFace`
        estimator.

        Args:
            enabled (bool or PipelineVariable): Optional. Switch to enable SageMaker
                Training Compiler. The default is ``True``.
            debug (bool or PipelineVariable): Optional. Whether to dump detailed logs
                for debugging. This comes with a potential performance slowdown.
                The default is ``False``.

        **Example**: The following code shows the basic usage of the
        :class:`sagemaker.core.huggingface.TrainingCompilerConfig()` class
        to run a HuggingFace training job with the compiler.

        .. code-block:: python

            from sagemaker.core.huggingface import HuggingFace, TrainingCompilerConfig

            huggingface_estimator=HuggingFace(
                ...
                compiler_config=TrainingCompilerConfig()
            )

        .. seealso::

            For more information about how to enable SageMaker Training Compiler
            for various training settings such as using TensorFlow-based models,
            PyTorch-based models, and distributed training,
            see `Enable SageMaker Training Compiler
            <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler-enable.html>`_
            in the `Amazon SageMaker Training Compiler developer guide
            <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_.

        """

        super().__init__(enabled=enabled, debug=debug)

    @classmethod
    def validate(cls, estimator):
        """Checks if SageMaker Training Compiler is configured correctly.

        Args:
            estimator (:class:`sagemaker.core.huggingface.HuggingFace`): An estimator object.
                If SageMaker Training Compiler is enabled, it will validate whether
                the estimator is configured to be compatible with Training Compiler.

        Raises:
            ValueError: Raised if the requested configuration is not compatible
                with SageMaker Training Compiler.
        """

        super().validate(estimator)

        # Parse the PyTorch version once; ``None`` means the estimator is not
        # PyTorch-based. Individual SpecifierSet membership tests are kept (rather
        # than one combined range) to preserve packaging's pre-release semantics.
        pt_version = Version(estimator.pytorch_version) if estimator.pytorch_version else None

        if pt_version is not None and (
            pt_version in SpecifierSet("< 1.9") or pt_version in SpecifierSet("> 1.11")
        ):
            error_helper_string = (
                "SageMaker Training Compiler is only supported "
                "with HuggingFace PyTorch 1.9-1.11. "
                "Received pytorch_version={} which is unsupported."
            )
            raise ValueError(error_helper_string.format(estimator.pytorch_version))

        if estimator.image_uri:
            error_helper_string = (
                "Overriding the image URI is currently not supported "
                "for SageMaker Training Compiler. "
                "Specify the following parameters to run the Hugging Face training job "
                "with SageMaker Training Compiler enabled: "
                "transformers_version, tensorflow_version or pytorch_version, and compiler_config."
            )
            raise ValueError(error_helper_string)

        if estimator.distribution:
            pt_xla_present = "pytorchxla" in estimator.distribution
            pt_xla_enabled = estimator.distribution.get("pytorchxla", {}).get("enabled", False)
            if pt_xla_enabled:
                if estimator.tensorflow_version:
                    error_helper_string = (
                        "Distribution mechanism 'pytorchxla' is currently only supported for "
                        "PyTorch >= 1.11 when SageMaker Training Compiler is enabled. Received "
                        "tensorflow_version={} which is unsupported."
                    )
                    raise ValueError(error_helper_string.format(estimator.tensorflow_version))
                if pt_version is not None:
                    if pt_version in SpecifierSet("< 1.11"):
                        error_helper_string = (
                            "Distribution mechanism 'pytorchxla' is currently only supported for "
                            "PyTorch >= 1.11 when SageMaker Training Compiler is enabled."
                            " Received pytorch_version={} which is unsupported."
                        )
                        raise ValueError(error_helper_string.format(estimator.pytorch_version))
                    # Non-EFA instances still work; this is only a performance hint.
                    if estimator.instance_type not in cls.SUPPORTED_INSTANCE_TYPES_WITH_EFA:
                        logger.warning(
                            "Consider using instances with EFA support when "
                            "training with PyTorch >= 1.11 and SageMaker Training Compiler "
                            "enabled. SageMaker Training Compiler leverages EFA to provide better "
                            "performance for distributed training."
                        )
            if not pt_xla_present:
                if pt_version is not None and pt_version in SpecifierSet(">= 1.11"):
                    error_helper_string = (
                        "'pytorchxla' is the only distribution mechanism currently supported "
                        "for PyTorch >= 1.11 when SageMaker Training Compiler is enabled."
                        " Received distribution={} which is unsupported."
                    )
                    raise ValueError(error_helper_string.format(estimator.distribution))
        elif estimator.instance_count and estimator.instance_count > 1:
            # Multi-instance job without any distribution configured.
            if pt_version is not None and pt_version in SpecifierSet(">= 1.11"):
                logger.warning(
                    "Consider setting 'distribution' to 'pytorchxla' for distributed "
                    "training with PyTorch >= 1.11 and SageMaker Training Compiler enabled."
                )
@@ -0,0 +1,172 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Accessors to retrieve hyperparameters for training jobs."""
14
+
15
+ from __future__ import absolute_import
16
+
17
+ import logging
18
+ from typing import Dict, Optional
19
+
20
+ from sagemaker.core.jumpstart import utils as jumpstart_utils
21
+ from sagemaker.core.jumpstart import artifacts
22
+ from sagemaker.core.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
23
+ from sagemaker.core.jumpstart.enums import HyperparameterValidationMode, JumpStartModelType
24
+ from sagemaker.core.jumpstart.validators import validate_hyperparameters
25
+ from sagemaker.core.helper.session_helper import Session
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
def retrieve_default(
    region: Optional[str] = None,
    model_id: Optional[str] = None,
    model_version: Optional[str] = None,
    hub_arn: Optional[str] = None,
    instance_type: Optional[str] = None,
    include_container_hyperparameters: bool = False,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> Dict[str, str]:
    """Look up the default training hyperparameters of a JumpStart model.

    The arguments are validated and then forwarded unchanged, by keyword, to
    ``artifacts._retrieve_default_hyperparameters``.

    Args:
        region (str): AWS Region to query. (Default: None).
        model_id (str): JumpStart model ID whose defaults are wanted. (Default: None).
        model_version (str): Version of that model. (Default: None).
        hub_arn (str): ARN of the SageMaker Hub to read model details from.
            (Default: None).
        instance_type (str): Optional instance type, used to obtain hyperparameters
            specific to that instance type. (Default: None).
        include_container_hyperparameters (bool): When ``True``, also return container
            hyperparameters. These do not tune the algorithm; SageMaker Training jobs use
            them to set up the container (for example, the entrypoint script). They may be
            required when creating a training job directly with boto3, whereas the
            ``Estimator`` classes add them automatically. (Default: False).
        tolerate_vulnerable_model (bool): When ``True``, model versions whose scripts have
            dependencies with known security vulnerabilities are accepted instead of
            raising. (Default: False).
        tolerate_deprecated_model (bool): When ``True``, deprecated models are accepted
            instead of raising. (Default: False).
        sagemaker_session (sagemaker.session.Session): Session used for SageMaker
            interactions. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        config_name (Optional[str]): Name of the JumpStart Model config to apply.
            (Default: None).
        model_type (JumpStartModelType): Open-weights or proprietary model.
            (Default: JumpStartModelType.OPEN_WEIGHTS).

    Returns:
        dict: The hyperparameters to use for the model.

    Raises:
        ValueError: If the combination of arguments specified is not supported.
    """
    has_jumpstart_input = jumpstart_utils.is_jumpstart_model_input(model_id, model_version)
    if not has_jumpstart_input:
        raise ValueError(
            "Must specify JumpStart `model_id` and `model_version` when retrieving hyperparameters."
        )

    default_hyperparameters = artifacts._retrieve_default_hyperparameters(
        region=region,
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        instance_type=instance_type,
        include_container_hyperparameters=include_container_hyperparameters,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        config_name=config_name,
        model_type=model_type,
    )
    return default_hyperparameters
101
+
102
+
103
def validate(
    region: Optional[str] = None,
    model_id: Optional[str] = None,
    hub_arn: Optional[str] = None,
    model_version: Optional[str] = None,
    hyperparameters: Optional[dict] = None,
    validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> None:
    """Validates hyperparameters for models.

    NOTE: unlike ``retrieve_default``, ``hub_arn`` comes *before* ``model_version`` in
    this signature, so pass these arguments by keyword.

    Args:
        region (str): The AWS Region for which to validate hyperparameters. (Default: None).
        model_id (str): The model ID of the model for which to validate hyperparameters.
            (Default: None).
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        model_version (str): The version of the model for which to validate hyperparameters.
            (Default: None).
        hyperparameters (dict): Hyperparameters to validate. Required; a ``ValueError``
            is raised if omitted. (Default: None).
        validation_mode (HyperparameterValidationMode): Method of validation to use with
            hyperparameters. If set to ``VALIDATE_PROVIDED``, only hyperparameters provided
            to this function will be validated, the missing hyperparameters will be ignored.
            If set to ``VALIDATE_ALGORITHM``, all algorithm hyperparameters will be validated.
            If set to ``VALIDATE_ALL``, all hyperparameters for the model will be validated.
            (Default: HyperparameterValidationMode.VALIDATE_PROVIDED).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated models should be tolerated
            (exception not raised). False if these models should raise an exception.
            (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

    Raises:
        JumpStartHyperparametersError: If the hyperparameter is not formatted correctly,
            according to its specs in the model metadata.
        ValueError: If the combination of arguments specified is not supported, or if
            ``hyperparameters`` is not provided.
        RuntimeError: If ``model_id`` or ``model_version`` is still ``None`` after the
            JumpStart input check.
    """

    if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
        raise ValueError(
            "Must specify JumpStart `model_id` and `model_version` when validating hyperparameters."
        )

    # Defensive re-check; also narrows the Optional types for static type checkers.
    if model_id is None or model_version is None:
        raise RuntimeError("Model ID and version must both be non-None")

    if hyperparameters is None:
        raise ValueError("Must specify hyperparameters.")

    return validate_hyperparameters(
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        hyperparameters=hyperparameters,
        validation_mode=validation_mode,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
    )
@@ -0,0 +1,3 @@
1
+ from __future__ import absolute_import
2
+
3
+ from sagemaker.core.image_retriever.image_retriever import ImageRetriever # noqa: F401