sagemaker-core 1.0.62__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (362):
  1. sagemaker/core/__init__.py +16 -0
  2. sagemaker/core/_studio.py +116 -0
  3. sagemaker/core/_version.py +11 -0
  4. sagemaker/core/accept_types.py +131 -0
  5. sagemaker/core/analytics.py +744 -0
  6. sagemaker/core/apiutils/__init__.py +13 -0
  7. sagemaker/core/apiutils/_base_types.py +228 -0
  8. sagemaker/core/apiutils/_boto_functions.py +130 -0
  9. sagemaker/core/apiutils/_utils.py +34 -0
  10. sagemaker/core/base_deserializers.py +35 -0
  11. sagemaker/core/base_serializers.py +35 -0
  12. sagemaker/core/clarify/__init__.py +2898 -0
  13. sagemaker/core/collection.py +467 -0
  14. sagemaker/core/common_utils.py +2281 -0
  15. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  16. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  17. sagemaker/core/config/__init__.py +181 -0
  18. sagemaker/core/config/config.py +238 -0
  19. sagemaker/core/config/config_manager.py +595 -0
  20. sagemaker/core/config/config_schema.py +1220 -0
  21. sagemaker/core/config/config_utils.py +297 -0
  22. {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
  23. sagemaker/core/constants.py +73 -0
  24. sagemaker/core/content_types.py +137 -0
  25. sagemaker/core/debugger/__init__.py +39 -0
  26. sagemaker/core/debugger/debugger.py +945 -0
  27. sagemaker/core/debugger/framework_profile.py +292 -0
  28. sagemaker/core/debugger/metrics_config.py +468 -0
  29. sagemaker/core/debugger/profiler.py +42 -0
  30. sagemaker/core/debugger/profiler_config.py +190 -0
  31. sagemaker/core/debugger/profiler_constants.py +40 -0
  32. sagemaker/core/debugger/utils.py +148 -0
  33. sagemaker/core/deprecations.py +254 -0
  34. sagemaker/core/deserializers/__init__.py +10 -0
  35. sagemaker/core/deserializers/base.py +424 -0
  36. sagemaker/core/deserializers/implementations.py +157 -0
  37. sagemaker/core/drift_check_baselines.py +106 -0
  38. sagemaker/core/enums.py +51 -0
  39. sagemaker/core/environment_variables.py +101 -0
  40. sagemaker/core/exceptions.py +108 -0
  41. sagemaker/core/experiments/__init__.py +53 -0
  42. sagemaker/core/experiments/_api_types.py +251 -0
  43. sagemaker/core/experiments/_environment.py +124 -0
  44. sagemaker/core/experiments/_helper.py +294 -0
  45. sagemaker/core/experiments/_metrics.py +333 -0
  46. sagemaker/core/experiments/_run_context.py +58 -0
  47. sagemaker/core/experiments/_utils.py +216 -0
  48. sagemaker/core/experiments/experiment.py +244 -0
  49. sagemaker/core/experiments/run.py +970 -0
  50. sagemaker/core/experiments/trial.py +296 -0
  51. sagemaker/core/experiments/trial_component.py +387 -0
  52. sagemaker/core/explainer/__init__.py +24 -0
  53. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  54. sagemaker/core/explainer/explainer_config.py +44 -0
  55. sagemaker/core/fw_utils.py +1176 -0
  56. sagemaker/core/git_utils.py +349 -0
  57. sagemaker/core/helper/pipeline_variable.py +82 -0
  58. sagemaker/core/helper/session_helper.py +2965 -0
  59. sagemaker/core/huggingface/__init__.py +29 -0
  60. sagemaker/core/huggingface/llm_utils.py +150 -0
  61. sagemaker/core/huggingface/processing.py +139 -0
  62. sagemaker/core/huggingface/training_compiler/config.py +167 -0
  63. sagemaker/core/hyperparameters.py +172 -0
  64. sagemaker/core/image_retriever/__init__.py +3 -0
  65. sagemaker/core/image_retriever/image_retriever.py +640 -0
  66. sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
  67. sagemaker/core/image_retriever/test.py +7 -0
  68. sagemaker/core/image_uri_config/__init__.py +13 -0
  69. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  70. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  71. sagemaker/core/image_uri_config/chainer.json +104 -0
  72. sagemaker/core/image_uri_config/clarify.json +39 -0
  73. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  74. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  75. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  76. sagemaker/core/image_uri_config/debugger.json +34 -0
  77. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  78. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  79. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  80. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  81. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  82. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  83. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  84. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  85. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
  86. sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
  87. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  88. sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
  89. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  90. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  91. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  92. sagemaker/core/image_uri_config/huggingface.json +2138 -0
  93. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  94. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  95. sagemaker/core/image_uri_config/image-classification.json +50 -0
  96. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  97. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  98. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  99. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  100. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  101. sagemaker/core/image_uri_config/kmeans.json +50 -0
  102. sagemaker/core/image_uri_config/knn.json +50 -0
  103. sagemaker/core/image_uri_config/lda.json +26 -0
  104. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  105. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  106. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  107. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  108. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  109. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  110. sagemaker/core/image_uri_config/ntm.json +50 -0
  111. sagemaker/core/image_uri_config/object-detection.json +50 -0
  112. sagemaker/core/image_uri_config/object2vec.json +50 -0
  113. sagemaker/core/image_uri_config/pca.json +50 -0
  114. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  115. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  116. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  117. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  118. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  119. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  120. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  121. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  122. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  123. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  124. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
  125. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  126. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  127. sagemaker/core/image_uri_config/sklearn.json +446 -0
  128. sagemaker/core/image_uri_config/spark.json +280 -0
  129. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  130. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  131. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  132. sagemaker/core/image_uri_config/vw.json +25 -0
  133. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  134. sagemaker/core/image_uri_config/xgboost.json +888 -0
  135. sagemaker/core/image_uris.py +810 -0
  136. sagemaker/core/inference_config.py +144 -0
  137. sagemaker/core/inference_recommender/__init__.py +18 -0
  138. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  139. sagemaker/core/inputs.py +366 -0
  140. sagemaker/core/instance_group.py +61 -0
  141. sagemaker/core/instance_types.py +164 -0
  142. sagemaker/core/instance_types_gpu_info.py +43 -0
  143. sagemaker/core/interactive_apps/__init__.py +41 -0
  144. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  145. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  146. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  147. sagemaker/core/iterators.py +186 -0
  148. sagemaker/core/job.py +380 -0
  149. sagemaker/core/jumpstart/__init__.py +156 -0
  150. sagemaker/core/jumpstart/accessors.py +390 -0
  151. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  152. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  153. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  154. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  155. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  156. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  157. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  158. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  159. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  160. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  161. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  162. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  163. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  164. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  165. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  166. sagemaker/core/jumpstart/cache.py +663 -0
  167. sagemaker/core/jumpstart/configs.py +50 -0
  168. sagemaker/core/jumpstart/constants.py +198 -0
  169. sagemaker/core/jumpstart/deserializers.py +81 -0
  170. sagemaker/core/jumpstart/document.py +76 -0
  171. sagemaker/core/jumpstart/enums.py +168 -0
  172. sagemaker/core/jumpstart/exceptions.py +236 -0
  173. sagemaker/core/jumpstart/factory/utils.py +833 -0
  174. sagemaker/core/jumpstart/filters.py +597 -0
  175. sagemaker/core/jumpstart/hub/constants.py +16 -0
  176. sagemaker/core/jumpstart/hub/hub.py +291 -0
  177. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  178. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  179. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  180. sagemaker/core/jumpstart/hub/types.py +35 -0
  181. sagemaker/core/jumpstart/hub/utils.py +260 -0
  182. sagemaker/core/jumpstart/models.py +499 -0
  183. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  184. sagemaker/core/jumpstart/parameters.py +20 -0
  185. sagemaker/core/jumpstart/payload_utils.py +239 -0
  186. sagemaker/core/jumpstart/region_config.json +163 -0
  187. sagemaker/core/jumpstart/search.py +171 -0
  188. sagemaker/core/jumpstart/serializers.py +81 -0
  189. sagemaker/core/jumpstart/session_utils.py +234 -0
  190. sagemaker/core/jumpstart/types.py +3044 -0
  191. sagemaker/core/jumpstart/utils.py +1731 -0
  192. sagemaker/core/jumpstart/validators.py +257 -0
  193. sagemaker/core/lambda_helper.py +312 -0
  194. sagemaker/core/lineage/__init__.py +42 -0
  195. sagemaker/core/lineage/_api_types.py +239 -0
  196. sagemaker/core/lineage/_utils.py +49 -0
  197. sagemaker/core/lineage/action.py +345 -0
  198. sagemaker/core/lineage/artifact.py +646 -0
  199. sagemaker/core/lineage/association.py +190 -0
  200. sagemaker/core/lineage/context.py +505 -0
  201. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  202. sagemaker/core/lineage/query.py +732 -0
  203. sagemaker/core/lineage/visualizer.py +346 -0
  204. sagemaker/core/local/__init__.py +18 -0
  205. sagemaker/core/local/data.py +413 -0
  206. sagemaker/core/local/entities.py +678 -0
  207. sagemaker/core/local/exceptions.py +17 -0
  208. sagemaker/core/local/image.py +1243 -0
  209. sagemaker/core/local/local_session.py +739 -0
  210. sagemaker/core/local/utils.py +245 -0
  211. sagemaker/core/logs.py +181 -0
  212. sagemaker/core/metadata_properties.py +56 -0
  213. sagemaker/core/metric_definitions.py +91 -0
  214. sagemaker/core/mlflow/__init__.py +38 -0
  215. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  216. sagemaker/core/model_card/__init__.py +26 -0
  217. sagemaker/core/model_life_cycle.py +51 -0
  218. sagemaker/core/model_metrics.py +160 -0
  219. sagemaker/core/model_monitor/__init__.py +66 -0
  220. sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
  221. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  222. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  223. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  224. sagemaker/core/model_monitor/dataset_format.py +102 -0
  225. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  226. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  227. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  228. sagemaker/core/model_monitor/utils.py +793 -0
  229. sagemaker/core/model_registry.py +480 -0
  230. sagemaker/core/model_uris.py +97 -0
  231. sagemaker/core/modules/__init__.py +19 -0
  232. sagemaker/core/modules/configs.py +226 -0
  233. sagemaker/core/modules/constants.py +37 -0
  234. sagemaker/core/modules/distributed.py +182 -0
  235. sagemaker/core/modules/local_core/__init__.py +0 -0
  236. sagemaker/core/modules/local_core/local_container.py +605 -0
  237. sagemaker/core/modules/templates.py +83 -0
  238. sagemaker/core/modules/train/__init__.py +14 -0
  239. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  240. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  241. sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
  242. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  243. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  244. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  245. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  246. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  247. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  248. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  249. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  250. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  251. sagemaker/core/modules/types.py +19 -0
  252. sagemaker/core/modules/utils.py +194 -0
  253. sagemaker/core/network.py +185 -0
  254. sagemaker/core/parameter.py +173 -0
  255. sagemaker/core/payloads.py +185 -0
  256. sagemaker/core/processing.py +1597 -0
  257. sagemaker/core/remote_function/__init__.py +19 -0
  258. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  259. sagemaker/core/remote_function/client.py +1285 -0
  260. sagemaker/core/remote_function/core/__init__.py +0 -0
  261. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  262. sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
  263. sagemaker/core/remote_function/core/serialization.py +422 -0
  264. sagemaker/core/remote_function/core/stored_function.py +226 -0
  265. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  266. sagemaker/core/remote_function/errors.py +104 -0
  267. sagemaker/core/remote_function/invoke_function.py +172 -0
  268. sagemaker/core/remote_function/job.py +2140 -0
  269. sagemaker/core/remote_function/logging_config.py +38 -0
  270. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  271. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  272. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  273. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  274. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  275. sagemaker/core/remote_function/spark_config.py +149 -0
  276. sagemaker/core/resource_requirements.py +168 -0
  277. {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
  278. sagemaker/core/s3/__init__.py +41 -0
  279. sagemaker/core/s3/client.py +367 -0
  280. sagemaker/core/s3/utils.py +175 -0
  281. sagemaker/core/script_uris.py +93 -0
  282. sagemaker/core/serializers/__init__.py +11 -0
  283. sagemaker/core/serializers/base.py +510 -0
  284. sagemaker/core/serializers/implementations.py +159 -0
  285. sagemaker/core/serializers/utils.py +223 -0
  286. sagemaker/core/serverless_inference_config.py +63 -0
  287. sagemaker/core/session_settings.py +55 -0
  288. sagemaker/core/shapes/__init__.py +3 -0
  289. sagemaker/core/shapes/model_card_shapes.py +159 -0
  290. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
  291. sagemaker/core/spark/__init__.py +16 -0
  292. sagemaker/core/spark/defaults.py +16 -0
  293. sagemaker/core/spark/processing.py +1380 -0
  294. sagemaker/core/telemetry/__init__.py +23 -0
  295. sagemaker/core/telemetry/constants.py +84 -0
  296. sagemaker/core/telemetry/telemetry_logging.py +284 -0
  297. sagemaker/core/tools/__init__.py +1 -0
  298. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  299. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  300. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  301. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  302. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  303. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  304. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  305. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  306. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  307. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  308. sagemaker/core/training/__init__.py +14 -0
  309. sagemaker/core/training/configs.py +333 -0
  310. sagemaker/core/training/constants.py +37 -0
  311. sagemaker/core/training/utils.py +77 -0
  312. sagemaker/core/training_compiler/__init__.py +16 -0
  313. sagemaker/core/training_compiler/config.py +197 -0
  314. sagemaker/core/training_compiler_config.py +197 -0
  315. sagemaker/core/transformer.py +793 -0
  316. sagemaker/core/user_agent.py +76 -0
  317. sagemaker/core/utilities/__init__.py +24 -0
  318. sagemaker/core/utilities/cache.py +169 -0
  319. sagemaker/core/utilities/search_expression.py +133 -0
  320. sagemaker/core/utils/__init__.py +48 -0
  321. sagemaker/core/utils/code_injection/__init__.py +0 -0
  322. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  323. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
  324. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  325. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  326. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  327. {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
  328. sagemaker/core/workflow/__init__.py +152 -0
  329. sagemaker/core/workflow/conditions.py +313 -0
  330. sagemaker/core/workflow/entities.py +58 -0
  331. sagemaker/core/workflow/execution_variables.py +89 -0
  332. sagemaker/core/workflow/functions.py +193 -0
  333. sagemaker/core/workflow/parameters.py +222 -0
  334. sagemaker/core/workflow/pipeline_context.py +394 -0
  335. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  336. sagemaker/core/workflow/properties.py +285 -0
  337. sagemaker/core/workflow/step_outputs.py +65 -0
  338. sagemaker/core/workflow/utilities.py +507 -0
  339. sagemaker/lineage/__init__.py +33 -0
  340. sagemaker/lineage/action.py +28 -0
  341. sagemaker/lineage/artifact.py +28 -0
  342. sagemaker/lineage/context.py +28 -0
  343. sagemaker/lineage/lineage_trial_component.py +28 -0
  344. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
  345. sagemaker_core-2.1.1.dist-info/RECORD +355 -0
  346. sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
  347. sagemaker_core/_version.py +0 -3
  348. sagemaker_core/helper/session_helper.py +0 -769
  349. sagemaker_core/resources/__init__.py +0 -1
  350. sagemaker_core/shapes/__init__.py +0 -1
  351. sagemaker_core/tools/__init__.py +0 -1
  352. sagemaker_core-1.0.62.dist-info/RECORD +0 -35
  353. sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
  354. {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
  355. {sagemaker_core/helper → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
  356. {sagemaker_core/main → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  357. {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
  358. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  359. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  360. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  361. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
  362. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,202 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module contains functions for obtaining JumpStart model packages."""
14
+ from __future__ import absolute_import
15
+ from typing import Optional
16
+ from sagemaker.core.jumpstart.constants import (
17
+ DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
18
+ )
19
+ from sagemaker.core.jumpstart.utils import (
20
+ get_region_fallback,
21
+ verify_model_region_and_return_specs,
22
+ )
23
+ from sagemaker.core.jumpstart.enums import (
24
+ JumpStartScriptScope,
25
+ JumpStartModelType,
26
+ )
27
+ from sagemaker.core.helper.session_helper import Session
28
+
29
+
30
def _retrieve_model_package_arn(
    model_id: str,
    model_version: str,
    instance_type: Optional[str],
    region: Optional[str],
    hub_arn: Optional[str] = None,
    scope: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> Optional[str]:
    """Retrieves the model package arn associated with the model.

    The model specs are fetched (and validated) first. For the inference scope,
    the arn is then resolved in this order:

    1. An instance-type-specific arn from ``hosting_instance_type_variants``,
       when the specs define such variants and one matches.
    2. ``None`` when the specs declare no regional hosting model package arns.
    3. The arn registered for ``region`` in ``hosting_model_package_arns``.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the model package arn.
        model_version (str): Version of the JumpStart model for which to retrieve the
            model package arn.
        instance_type (Optional[str]): An instance type to optionally supply in order to get an arn
            specific for the instance type.
        region (Optional[str]): Region for which to retrieve the model package arn. When
            falsy, a fallback region is derived from ``sagemaker_session``.
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        scope (Optional[str]): Scope for which to retrieve the model package arn.
            Only the inference scope is supported.
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (JumpStartModelType): The type of the model, can be open weights model
            or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).

    Returns:
        Optional[str]: the model package arn to use for the model, or None if the
        model specs declare no model package arns.

    Raises:
        ValueError: If the model declares regional model package arns, but none
            for ``region``.
        NotImplementedError: If ``scope`` is not the inference scope.
        Any error raised by ``verify_model_region_and_return_specs`` propagates
        unchanged (it runs before the scope check).
    """

    region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )

    model_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=scope,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
        config_name=config_name,
    )

    if scope == JumpStartScriptScope.INFERENCE:

        # Prefer an arn tailored to the requested instance type when the specs
        # define instance type variants; the attribute may be absent on the specs.
        instance_type_variants = getattr(model_specs, "hosting_instance_type_variants", None)
        instance_specific_arn: Optional[str] = (
            instance_type_variants.get_model_package_arn(
                region=region, instance_type=instance_type
            )
            if instance_type_variants is not None
            else None
        )

        if instance_specific_arn is not None:
            return instance_specific_arn

        regional_arns = model_specs.hosting_model_package_arns

        # None or an empty mapping: the model is not backed by a model package.
        if not regional_arns:
            return None

        regional_arn = regional_arns.get(region)

        if regional_arn is None:
            raise ValueError(
                f"Model package arn for '{model_id}' not supported in {region}. "
                "Please try one of the following regions: "
                f"{', '.join(regional_arns.keys())}."
            )

        return regional_arn

    raise NotImplementedError(f"Model Package ARN not supported for scope: '{scope}'")
121
+
122
+
123
def _retrieve_model_package_model_artifact_s3_uri(
    model_id: str,
    model_version: str,
    region: Optional[str],
    hub_arn: Optional[str] = None,
    scope: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> Optional[str]:
    """Retrieves s3 artifact uri associated with model package.

    Only the training scope is supported. The scope is checked before any region
    resolution or spec lookup takes place.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the model package artifact.
        model_version (str): Version of the JumpStart model for which to retrieve the
            model package artifact.
        region (Optional[str]): Region for which to retrieve the model package artifact.
            When falsy, a fallback region is derived from ``sagemaker_session``.
            (Default: None).
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        scope (Optional[str]): Scope for which to retrieve the model package artifact.
            Only the training scope is supported. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
        model_type (JumpStartModelType): The type of the model, can be open weights model
            or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).

    Returns:
        Optional[str]: the model package artifact uri to use for the model, or None
        if the model specs declare no training model package artifact uris.

    Raises:
        NotImplementedError: If ``scope`` is not the training scope.
        ValueError: If the model declares regional training model package artifact
            uris, but none for ``region``.
    """

    # Guard clause: only training has model package artifacts.
    if scope != JumpStartScriptScope.TRAINING:
        raise NotImplementedError(
            f"Model Package Artifact URI not supported for scope: '{scope}'"
        )

    region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )

    model_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=scope,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        config_name=config_name,
        model_type=model_type,
    )

    artifact_uris = model_specs.training_model_package_artifact_uris

    # Treat an empty mapping like None, matching the hosting arn lookup. Otherwise
    # the lookup below would raise a ValueError that lists no supported regions.
    if not artifact_uris:
        return None

    model_s3_uri = artifact_uris.get(region)

    if model_s3_uri is None:
        raise ValueError(
            f"Model package artifact s3 uri for '{model_id}' not supported in {region}. "
            "Please try one of the following regions: "
            f"{', '.join(artifact_uris.keys())}."
        )

    return model_s3_uri
@@ -0,0 +1,252 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module contains functions for obtaining JumpStart model uris."""
14
+ from __future__ import absolute_import
15
+ import os
16
+ from typing import Optional
17
+
18
+ from sagemaker.core.jumpstart.constants import (
19
+ DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
20
+ ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE,
21
+ )
22
+ from sagemaker.core.jumpstart.enums import (
23
+ JumpStartModelType,
24
+ JumpStartScriptScope,
25
+ )
26
+ from sagemaker.core.jumpstart.utils import (
27
+ get_jumpstart_content_bucket,
28
+ get_jumpstart_gated_content_bucket,
29
+ get_region_fallback,
30
+ verify_model_region_and_return_specs,
31
+ )
32
+ from sagemaker.core.s3.utils import is_s3_url
33
+ from sagemaker.core.helper.session_helper import Session
34
+ from sagemaker.core.jumpstart.types import JumpStartModelSpecs
35
+
36
+
37
def _retrieve_hosting_prepacked_artifact_key(
    model_specs: JumpStartModelSpecs, instance_type: str
) -> str:
    """Resolve the prepacked hosting artifact key for ``instance_type``.

    A key from ``model_specs.hosting_instance_type_variants`` is used when all of
    these hold:

    * ``instance_type`` is truthy;
    * the variants object exists;
    * it yields a truthy key.

    Otherwise ``model_specs.hosting_prepacked_artifact_key`` is returned. That
    fallback may itself be ``None``.
    """
    variant_key: Optional[str] = None
    if instance_type:
        variants = getattr(model_specs, "hosting_instance_type_variants", None)
        if variants is not None:
            variant_key = variants.get_instance_specific_prepacked_artifact_key(
                instance_type=instance_type
            )

    fallback_key: Optional[str] = model_specs.hosting_prepacked_artifact_key
    return variant_key or fallback_key
57
+
58
+
59
def _retrieve_hosting_artifact_key(model_specs: JumpStartModelSpecs, instance_type: str) -> str:
    """Resolve the (non-prepacked) hosting artifact key for ``instance_type``.

    A key from ``model_specs.hosting_instance_type_variants`` is used when all of
    these hold:

    * ``instance_type`` is truthy;
    * the variants object exists;
    * it yields a truthy key.

    Otherwise ``model_specs.hosting_artifact_key`` is returned.
    """
    variant_key: Optional[str] = None
    if instance_type:
        variants = getattr(model_specs, "hosting_instance_type_variants", None)
        if variants is not None:
            variant_key = variants.get_instance_specific_artifact_key(
                instance_type=instance_type
            )

    fallback_key: str = model_specs.hosting_artifact_key
    return variant_key or fallback_key
73
+
74
+
75
def _retrieve_training_artifact_key(model_specs: JumpStartModelSpecs, instance_type: str) -> str:
    """Resolve the training artifact key for ``instance_type``.

    A key from ``model_specs.training_instance_type_variants`` is used when all of
    these hold:

    * ``instance_type`` is truthy;
    * the variants object exists;
    * it yields a truthy key.

    Otherwise ``model_specs.training_artifact_key`` is returned.
    """
    variant_key: Optional[str] = None
    if instance_type:
        variants = getattr(model_specs, "training_instance_type_variants", None)
        if variants is not None:
            variant_key = variants.get_instance_specific_training_artifact_key(
                instance_type=instance_type
            )

    fallback_key: str = model_specs.training_artifact_key
    return variant_key or fallback_key
89
+
90
+
91
def _retrieve_model_uri(
    model_id: str,
    model_version: str,
    hub_arn: Optional[str] = None,
    model_scope: Optional[str] = None,
    instance_type: Optional[str] = None,
    region: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
):
    """Retrieves the model artifact S3 URI for the model matching the given arguments.

    Optionally uses a bucket override specified by environment variable.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to retrieve
            the model artifact S3 URI.
        model_version (str): Version of the JumpStart model for which to retrieve the model
            artifact S3 URI.
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        model_scope (str): The model type, i.e. what it is used for.
            Valid values: "training" and "inference".
        instance_type (str): The ML compute instance type for the specified scope. (Default: None).
        region (str): Region for which to retrieve model S3 URI. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
        model_type (JumpStartModelType): The type of the model, can be open weights model
            or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).

    Returns:
        str: the model artifact S3 URI for the corresponding model.
            * For an inference request that has a ``hub_arn``, this is
              ``model_specs.hosting_artifact_uri``.
            * When the resolved artifact key is already an S3 URL, it is
              returned unchanged.
            * Otherwise the key is prefixed with the JumpStart (or override)
              bucket.

    Raises:
        ValueError: If the combination of arguments specified is not supported,
            including a ``model_scope`` that is neither inference nor training.
        VulnerableJumpStartModelError: If any of the dependencies required by the script have
            known security vulnerabilities.
        DeprecatedJumpStartModelError: If the version of the model is deprecated.
    """
    region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )

    model_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=model_scope,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        config_name=config_name,
        model_type=model_type,
    )

    model_artifact_key: str

    if model_scope == JumpStartScriptScope.INFERENCE:

        is_prepacked = not model_specs.use_inference_script_uri()

        if hub_arn:
            # Hub-sourced specs carry the full artifact URI directly.
            return model_specs.hosting_artifact_uri

        model_artifact_key = (
            _retrieve_hosting_prepacked_artifact_key(model_specs, instance_type)
            if is_prepacked
            else _retrieve_hosting_artifact_key(model_specs, instance_type)
        )

    elif model_scope == JumpStartScriptScope.TRAINING:

        model_artifact_key = _retrieve_training_artifact_key(model_specs, instance_type)

    else:
        # Without this branch model_artifact_key would be unbound below and the
        # caller would see an UnboundLocalError instead of a meaningful error.
        raise ValueError(
            f"Model artifact URI not supported for scope: '{model_scope}'. "
            f"Supported scopes: '{JumpStartScriptScope.INFERENCE}', "
            f"'{JumpStartScriptScope.TRAINING}'."
        )

    if is_s3_url(model_artifact_key):
        # The spec already stores a fully-qualified S3 URI. Return it verbatim
        # rather than prefixing a bucket; previously this path left model_s3_uri
        # unbound and raised UnboundLocalError.
        return model_artifact_key

    default_jumpstart_bucket: str = (
        get_jumpstart_gated_content_bucket(region)
        if model_specs.gated_bucket
        else get_jumpstart_content_bucket(region)
    )

    # A non-empty environment override takes precedence over the JumpStart bucket.
    bucket = (
        os.environ.get(ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE)
        or default_jumpstart_bucket
    )

    return f"s3://{bucket}/{model_artifact_key}"
193
+
194
+
195
def _model_supports_training_model_uri(
    model_id: str,
    model_version: str,
    region: Optional[str],
    hub_arn: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> bool:
    """Report whether the model's training specs enable the model-URI training field.

    The function fetches the training-scope specs for the model. It then returns
    whatever ``use_training_model_artifact()`` reports for them.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the support status for model uri with training.
        model_version (str): Version of the JumpStart model for which to retrieve the
            support status for model uri with training.
        region (Optional[str]): Region for which to retrieve the support status for
            model uri with training. When falsy, a fallback region is derived from
            ``sagemaker_session``.
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
        model_type (JumpStartModelType): The type of the model, can be open weights model
            or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).

    Returns:
        bool: the support status for model uri with training.
    """
    resolved_region = region or get_region_fallback(sagemaker_session=sagemaker_session)

    training_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=JumpStartScriptScope.TRAINING,
        region=resolved_region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        config_name=config_name,
        model_type=model_type,
    )

    supports_model_uri = training_specs.use_training_model_artifact()
    return supports_model_uri
@@ -0,0 +1,96 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module contains functions to obtain JumpStart model payloads."""
14
+ from __future__ import absolute_import
15
+ from copy import deepcopy
16
+ from typing import Dict, Optional
17
+ from sagemaker.core.jumpstart.constants import (
18
+ DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
19
+ )
20
+ from sagemaker.core.jumpstart.enums import (
21
+ JumpStartScriptScope,
22
+ JumpStartModelType,
23
+ )
24
+ from sagemaker.core.jumpstart.types import JumpStartSerializablePayload
25
+ from sagemaker.core.jumpstart.utils import (
26
+ get_region_fallback,
27
+ verify_model_region_and_return_specs,
28
+ )
29
+ from sagemaker.core.helper.session_helper import Session
30
+
31
+
32
def _retrieve_example_payloads(
    model_id: str,
    model_version: str,
    region: Optional[str],
    hub_arn: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> Optional[Dict[str, JumpStartSerializablePayload]]:
    """Returns example payloads.

    The payloads returned are deep copies of ``model_specs.default_payloads``. A
    payload with no ``accept`` attribute has it filled in on the copy from
    ``model_specs.predictor_specs.default_accept_type``. The spec object itself
    is left untouched.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            get example payloads.
        model_version (str): Version of the JumpStart model for which to retrieve the
            example payloads.
        region (Optional[str]): Region for which to retrieve the
            example payloads.
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (JumpStartModelType): The type of the model, can be open weights model
            or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
    Returns:
        Optional[Dict[str, JumpStartSerializablePayload]]: dictionary mapping payload aliases
            to the serializable payload object, or None when the model has no default payloads.
    """

    region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )

    model_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=JumpStartScriptScope.INFERENCE,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
        config_name=config_name,
    )

    default_payloads = model_specs.default_payloads
    if not default_payloads:
        return None

    # Copy before filling in defaults. The previous code wrote ``accept`` onto the
    # payload objects owned by ``model_specs``, mutating that spec object in place.
    # NOTE(review): that object may be shared through a specs cache; confirm.
    payloads = deepcopy(default_payloads)
    for payload in payloads.values():
        # Fill only a missing attribute, matching the previous
        # getattr(payload, "accept", default) semantics. An explicit ``None`` is
        # preserved. The default is only looked up when it is actually needed.
        if not hasattr(payload, "accept"):
            payload.accept = model_specs.predictor_specs.default_accept_type

    return payloads