sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (363) hide show
  1. sagemaker/core/__init__.py +16 -0
  2. sagemaker/core/_studio.py +116 -0
  3. sagemaker/core/_version.py +11 -0
  4. sagemaker/core/accept_types.py +131 -0
  5. sagemaker/core/analytics.py +744 -0
  6. sagemaker/core/apiutils/__init__.py +13 -0
  7. sagemaker/core/apiutils/_base_types.py +228 -0
  8. sagemaker/core/apiutils/_boto_functions.py +130 -0
  9. sagemaker/core/apiutils/_utils.py +34 -0
  10. sagemaker/core/base_deserializers.py +35 -0
  11. sagemaker/core/base_serializers.py +35 -0
  12. sagemaker/core/clarify/__init__.py +2898 -0
  13. sagemaker/core/collection.py +467 -0
  14. sagemaker/core/common_utils.py +2281 -0
  15. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  16. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  17. sagemaker/core/config/__init__.py +181 -0
  18. sagemaker/core/config/config.py +238 -0
  19. sagemaker/core/config/config_manager.py +595 -0
  20. sagemaker/core/config/config_schema.py +1220 -0
  21. sagemaker/core/config/config_utils.py +297 -0
  22. {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
  23. sagemaker/core/constants.py +73 -0
  24. sagemaker/core/content_types.py +137 -0
  25. sagemaker/core/debugger/__init__.py +39 -0
  26. sagemaker/core/debugger/debugger.py +945 -0
  27. sagemaker/core/debugger/framework_profile.py +292 -0
  28. sagemaker/core/debugger/metrics_config.py +468 -0
  29. sagemaker/core/debugger/profiler.py +42 -0
  30. sagemaker/core/debugger/profiler_config.py +190 -0
  31. sagemaker/core/debugger/profiler_constants.py +40 -0
  32. sagemaker/core/debugger/utils.py +148 -0
  33. sagemaker/core/deprecations.py +254 -0
  34. sagemaker/core/deserializers/__init__.py +10 -0
  35. sagemaker/core/deserializers/base.py +424 -0
  36. sagemaker/core/deserializers/implementations.py +157 -0
  37. sagemaker/core/drift_check_baselines.py +106 -0
  38. sagemaker/core/enums.py +51 -0
  39. sagemaker/core/environment_variables.py +101 -0
  40. sagemaker/core/exceptions.py +108 -0
  41. sagemaker/core/experiments/__init__.py +53 -0
  42. sagemaker/core/experiments/_api_types.py +251 -0
  43. sagemaker/core/experiments/_environment.py +124 -0
  44. sagemaker/core/experiments/_helper.py +294 -0
  45. sagemaker/core/experiments/_metrics.py +333 -0
  46. sagemaker/core/experiments/_run_context.py +58 -0
  47. sagemaker/core/experiments/_utils.py +216 -0
  48. sagemaker/core/experiments/experiment.py +244 -0
  49. sagemaker/core/experiments/run.py +970 -0
  50. sagemaker/core/experiments/trial.py +296 -0
  51. sagemaker/core/experiments/trial_component.py +387 -0
  52. sagemaker/core/explainer/__init__.py +24 -0
  53. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  54. sagemaker/core/explainer/explainer_config.py +44 -0
  55. sagemaker/core/fw_utils.py +1176 -0
  56. sagemaker/core/git_utils.py +349 -0
  57. sagemaker/core/helper/pipeline_variable.py +82 -0
  58. sagemaker/core/helper/session_helper.py +2965 -0
  59. sagemaker/core/huggingface/__init__.py +29 -0
  60. sagemaker/core/huggingface/llm_utils.py +150 -0
  61. sagemaker/core/huggingface/processing.py +139 -0
  62. sagemaker/core/huggingface/training_compiler/config.py +167 -0
  63. sagemaker/core/hyperparameters.py +172 -0
  64. sagemaker/core/image_retriever/__init__.py +3 -0
  65. sagemaker/core/image_retriever/image_retriever.py +640 -0
  66. sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
  67. sagemaker/core/image_retriever/test.py +7 -0
  68. sagemaker/core/image_uri_config/__init__.py +13 -0
  69. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  70. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  71. sagemaker/core/image_uri_config/chainer.json +104 -0
  72. sagemaker/core/image_uri_config/clarify.json +39 -0
  73. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  74. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  75. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  76. sagemaker/core/image_uri_config/debugger.json +34 -0
  77. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  78. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  79. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  80. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  81. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  82. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  83. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  84. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  85. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
  86. sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
  87. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  88. sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
  89. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  90. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  91. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  92. sagemaker/core/image_uri_config/huggingface.json +2138 -0
  93. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  94. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  95. sagemaker/core/image_uri_config/image-classification.json +50 -0
  96. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  97. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  98. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  99. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  100. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  101. sagemaker/core/image_uri_config/kmeans.json +50 -0
  102. sagemaker/core/image_uri_config/knn.json +50 -0
  103. sagemaker/core/image_uri_config/lda.json +26 -0
  104. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  105. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  106. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  107. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  108. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  109. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  110. sagemaker/core/image_uri_config/ntm.json +50 -0
  111. sagemaker/core/image_uri_config/object-detection.json +50 -0
  112. sagemaker/core/image_uri_config/object2vec.json +50 -0
  113. sagemaker/core/image_uri_config/pca.json +50 -0
  114. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  115. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  116. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  117. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  118. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  119. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  120. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  121. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  122. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  123. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  124. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
  125. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  126. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  127. sagemaker/core/image_uri_config/sklearn.json +446 -0
  128. sagemaker/core/image_uri_config/spark.json +280 -0
  129. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  130. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  131. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  132. sagemaker/core/image_uri_config/vw.json +25 -0
  133. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  134. sagemaker/core/image_uri_config/xgboost.json +888 -0
  135. sagemaker/core/image_uris.py +810 -0
  136. sagemaker/core/inference_config.py +144 -0
  137. sagemaker/core/inference_recommender/__init__.py +18 -0
  138. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  139. sagemaker/core/inputs.py +366 -0
  140. sagemaker/core/instance_group.py +61 -0
  141. sagemaker/core/instance_types.py +164 -0
  142. sagemaker/core/instance_types_gpu_info.py +43 -0
  143. sagemaker/core/interactive_apps/__init__.py +41 -0
  144. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  145. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  146. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  147. sagemaker/core/iterators.py +186 -0
  148. sagemaker/core/job.py +380 -0
  149. sagemaker/core/jumpstart/__init__.py +156 -0
  150. sagemaker/core/jumpstart/accessors.py +390 -0
  151. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  152. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  153. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  154. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  155. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  156. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  157. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  158. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  159. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  160. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  161. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  162. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  163. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  164. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  165. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  166. sagemaker/core/jumpstart/cache.py +663 -0
  167. sagemaker/core/jumpstart/configs.py +50 -0
  168. sagemaker/core/jumpstart/constants.py +198 -0
  169. sagemaker/core/jumpstart/deserializers.py +81 -0
  170. sagemaker/core/jumpstart/document.py +76 -0
  171. sagemaker/core/jumpstart/enums.py +168 -0
  172. sagemaker/core/jumpstart/exceptions.py +236 -0
  173. sagemaker/core/jumpstart/factory/utils.py +833 -0
  174. sagemaker/core/jumpstart/filters.py +597 -0
  175. sagemaker/core/jumpstart/hub/__init__.py +0 -0
  176. sagemaker/core/jumpstart/hub/constants.py +16 -0
  177. sagemaker/core/jumpstart/hub/hub.py +291 -0
  178. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  179. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  180. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  181. sagemaker/core/jumpstart/hub/types.py +35 -0
  182. sagemaker/core/jumpstart/hub/utils.py +260 -0
  183. sagemaker/core/jumpstart/models.py +499 -0
  184. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  185. sagemaker/core/jumpstart/parameters.py +20 -0
  186. sagemaker/core/jumpstart/payload_utils.py +239 -0
  187. sagemaker/core/jumpstart/region_config.json +163 -0
  188. sagemaker/core/jumpstart/search.py +171 -0
  189. sagemaker/core/jumpstart/serializers.py +81 -0
  190. sagemaker/core/jumpstart/session_utils.py +234 -0
  191. sagemaker/core/jumpstart/types.py +3044 -0
  192. sagemaker/core/jumpstart/utils.py +1731 -0
  193. sagemaker/core/jumpstart/validators.py +257 -0
  194. sagemaker/core/lambda_helper.py +312 -0
  195. sagemaker/core/lineage/__init__.py +42 -0
  196. sagemaker/core/lineage/_api_types.py +239 -0
  197. sagemaker/core/lineage/_utils.py +49 -0
  198. sagemaker/core/lineage/action.py +345 -0
  199. sagemaker/core/lineage/artifact.py +646 -0
  200. sagemaker/core/lineage/association.py +190 -0
  201. sagemaker/core/lineage/context.py +505 -0
  202. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  203. sagemaker/core/lineage/query.py +732 -0
  204. sagemaker/core/lineage/visualizer.py +346 -0
  205. sagemaker/core/local/__init__.py +18 -0
  206. sagemaker/core/local/data.py +413 -0
  207. sagemaker/core/local/entities.py +678 -0
  208. sagemaker/core/local/exceptions.py +17 -0
  209. sagemaker/core/local/image.py +1243 -0
  210. sagemaker/core/local/local_session.py +739 -0
  211. sagemaker/core/local/utils.py +245 -0
  212. sagemaker/core/logs.py +181 -0
  213. sagemaker/core/metadata_properties.py +56 -0
  214. sagemaker/core/metric_definitions.py +91 -0
  215. sagemaker/core/mlflow/__init__.py +38 -0
  216. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  217. sagemaker/core/model_card/__init__.py +26 -0
  218. sagemaker/core/model_life_cycle.py +51 -0
  219. sagemaker/core/model_metrics.py +160 -0
  220. sagemaker/core/model_monitor/__init__.py +66 -0
  221. sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
  222. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  223. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  224. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  225. sagemaker/core/model_monitor/dataset_format.py +102 -0
  226. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  227. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  228. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  229. sagemaker/core/model_monitor/utils.py +793 -0
  230. sagemaker/core/model_registry.py +480 -0
  231. sagemaker/core/model_uris.py +97 -0
  232. sagemaker/core/modules/__init__.py +19 -0
  233. sagemaker/core/modules/configs.py +226 -0
  234. sagemaker/core/modules/constants.py +37 -0
  235. sagemaker/core/modules/distributed.py +182 -0
  236. sagemaker/core/modules/local_core/__init__.py +0 -0
  237. sagemaker/core/modules/local_core/local_container.py +605 -0
  238. sagemaker/core/modules/templates.py +83 -0
  239. sagemaker/core/modules/train/__init__.py +14 -0
  240. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  241. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  242. sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
  243. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  244. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  245. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  246. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  247. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  248. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  249. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  250. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  251. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  252. sagemaker/core/modules/types.py +19 -0
  253. sagemaker/core/modules/utils.py +194 -0
  254. sagemaker/core/network.py +185 -0
  255. sagemaker/core/parameter.py +173 -0
  256. sagemaker/core/payloads.py +185 -0
  257. sagemaker/core/processing.py +1597 -0
  258. sagemaker/core/remote_function/__init__.py +19 -0
  259. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  260. sagemaker/core/remote_function/client.py +1285 -0
  261. sagemaker/core/remote_function/core/__init__.py +0 -0
  262. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  263. sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
  264. sagemaker/core/remote_function/core/serialization.py +422 -0
  265. sagemaker/core/remote_function/core/stored_function.py +226 -0
  266. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  267. sagemaker/core/remote_function/errors.py +104 -0
  268. sagemaker/core/remote_function/invoke_function.py +172 -0
  269. sagemaker/core/remote_function/job.py +2140 -0
  270. sagemaker/core/remote_function/logging_config.py +38 -0
  271. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  272. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  273. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  274. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  275. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  276. sagemaker/core/remote_function/spark_config.py +149 -0
  277. sagemaker/core/resource_requirements.py +168 -0
  278. {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
  279. sagemaker/core/s3/__init__.py +41 -0
  280. sagemaker/core/s3/client.py +367 -0
  281. sagemaker/core/s3/utils.py +175 -0
  282. sagemaker/core/script_uris.py +93 -0
  283. sagemaker/core/serializers/__init__.py +11 -0
  284. sagemaker/core/serializers/base.py +510 -0
  285. sagemaker/core/serializers/implementations.py +159 -0
  286. sagemaker/core/serializers/utils.py +223 -0
  287. sagemaker/core/serverless_inference_config.py +63 -0
  288. sagemaker/core/session_settings.py +55 -0
  289. sagemaker/core/shapes/__init__.py +3 -0
  290. sagemaker/core/shapes/model_card_shapes.py +159 -0
  291. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
  292. sagemaker/core/spark/__init__.py +16 -0
  293. sagemaker/core/spark/defaults.py +16 -0
  294. sagemaker/core/spark/processing.py +1380 -0
  295. sagemaker/core/telemetry/__init__.py +23 -0
  296. sagemaker/core/telemetry/constants.py +84 -0
  297. sagemaker/core/telemetry/telemetry_logging.py +284 -0
  298. sagemaker/core/tools/__init__.py +1 -0
  299. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  300. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  301. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  302. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  303. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  304. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  305. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  306. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  307. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  308. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  309. sagemaker/core/training/__init__.py +14 -0
  310. sagemaker/core/training/configs.py +333 -0
  311. sagemaker/core/training/constants.py +37 -0
  312. sagemaker/core/training/utils.py +77 -0
  313. sagemaker/core/training_compiler/__init__.py +16 -0
  314. sagemaker/core/training_compiler/config.py +197 -0
  315. sagemaker/core/training_compiler_config.py +197 -0
  316. sagemaker/core/transformer.py +793 -0
  317. sagemaker/core/user_agent.py +76 -0
  318. sagemaker/core/utilities/__init__.py +24 -0
  319. sagemaker/core/utilities/cache.py +169 -0
  320. sagemaker/core/utilities/search_expression.py +133 -0
  321. sagemaker/core/utils/__init__.py +48 -0
  322. sagemaker/core/utils/code_injection/__init__.py +0 -0
  323. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  324. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
  325. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  326. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  327. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  328. {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
  329. sagemaker/core/workflow/__init__.py +152 -0
  330. sagemaker/core/workflow/conditions.py +313 -0
  331. sagemaker/core/workflow/entities.py +58 -0
  332. sagemaker/core/workflow/execution_variables.py +89 -0
  333. sagemaker/core/workflow/functions.py +193 -0
  334. sagemaker/core/workflow/parameters.py +222 -0
  335. sagemaker/core/workflow/pipeline_context.py +394 -0
  336. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  337. sagemaker/core/workflow/properties.py +285 -0
  338. sagemaker/core/workflow/step_outputs.py +65 -0
  339. sagemaker/core/workflow/utilities.py +507 -0
  340. sagemaker/lineage/__init__.py +33 -0
  341. sagemaker/lineage/action.py +28 -0
  342. sagemaker/lineage/artifact.py +28 -0
  343. sagemaker/lineage/context.py +28 -0
  344. sagemaker/lineage/lineage_trial_component.py +28 -0
  345. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
  346. sagemaker_core-2.1.1.dist-info/RECORD +355 -0
  347. sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
  348. sagemaker_core/__init__.py +0 -4
  349. sagemaker_core/_version.py +0 -3
  350. sagemaker_core/helper/session_helper.py +0 -769
  351. sagemaker_core/resources/__init__.py +0 -1
  352. sagemaker_core/shapes/__init__.py +0 -1
  353. sagemaker_core/tools/__init__.py +0 -1
  354. sagemaker_core-1.0.47.dist-info/RECORD +0 -35
  355. sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
  356. {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
  357. {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
  358. {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  359. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  360. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  361. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  362. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
  363. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,540 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module contains functions for obtaining JumpStart predictors."""
14
+ from __future__ import absolute_import
15
+ from typing import List, Optional, Set, Type
16
+ from sagemaker.core.deserializers import BaseDeserializer
17
+ from sagemaker.core.serializers import BaseSerializer
18
+ from sagemaker.core.jumpstart.constants import (
19
+ ACCEPT_TYPE_TO_DESERIALIZER_TYPE_MAP,
20
+ CONTENT_TYPE_TO_SERIALIZER_TYPE_MAP,
21
+ DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
22
+ DESERIALIZER_TYPE_TO_CLASS_MAP,
23
+ SERIALIZER_TYPE_TO_CLASS_MAP,
24
+ )
25
+ from sagemaker.core.jumpstart.enums import (
26
+ JumpStartScriptScope,
27
+ MIMEType,
28
+ JumpStartModelType,
29
+ )
30
+ from sagemaker.core.jumpstart.utils import (
31
+ get_region_fallback,
32
+ verify_model_region_and_return_specs,
33
+ )
34
+ from sagemaker.core.helper.session_helper import Session
35
+
36
+
37
def _retrieve_serializer_from_content_type(
    content_type: MIMEType,
) -> BaseSerializer:
    """Returns serializer object to use for content type.

    The content type is first mapped to a serializer type through
    ``CONTENT_TYPE_TO_SERIALIZER_TYPE_MAP``; that serializer type is then mapped
    to a handle through ``SERIALIZER_TYPE_TO_CLASS_MAP``, and the handle is
    called with no arguments to produce the returned object.

    Args:
        content_type (MIMEType): Content type for which to obtain a serializer.

    Returns:
        BaseSerializer: the serializer produced for the content type.

    Raises:
        RuntimeError: If the content type has no entry in
            ``CONTENT_TYPE_TO_SERIALIZER_TYPE_MAP``, or if the resolved
            serializer type has no entry in ``SERIALIZER_TYPE_TO_CLASS_MAP``.
    """

    serializer_type = CONTENT_TYPE_TO_SERIALIZER_TYPE_MAP.get(content_type)

    if serializer_type is None:
        raise RuntimeError(f"Unrecognized content type: {content_type}")

    serializer_handle = SERIALIZER_TYPE_TO_CLASS_MAP.get(serializer_type)

    if serializer_handle is None:
        raise RuntimeError(f"Unrecognized serializer type: {serializer_type}")

    # Call the handle directly rather than via ``.__call__()``: attribute lookup
    # of ``__call__`` on a class that defines its own instance-level ``__call__``
    # would return that unbound method instead of the constructor.
    return serializer_handle()
53
+
54
+
55
def _retrieve_deserializer_from_accept_type(
    accept_type: MIMEType,
) -> BaseDeserializer:
    """Returns deserializer object to use for accept type.

    The accept type is first mapped to a deserializer type through
    ``ACCEPT_TYPE_TO_DESERIALIZER_TYPE_MAP``; that deserializer type is then
    mapped to a handle through ``DESERIALIZER_TYPE_TO_CLASS_MAP``, and the
    handle is called with no arguments to produce the returned object.

    Args:
        accept_type (MIMEType): Accept type for which to obtain a deserializer.

    Returns:
        BaseDeserializer: the deserializer produced for the accept type.

    Raises:
        RuntimeError: If the accept type has no entry in
            ``ACCEPT_TYPE_TO_DESERIALIZER_TYPE_MAP``, or if the resolved
            deserializer type has no entry in ``DESERIALIZER_TYPE_TO_CLASS_MAP``.
    """

    deserializer_type = ACCEPT_TYPE_TO_DESERIALIZER_TYPE_MAP.get(accept_type)

    if deserializer_type is None:
        raise RuntimeError(f"Unrecognized accept type: {accept_type}")

    deserializer_handle = DESERIALIZER_TYPE_TO_CLASS_MAP.get(deserializer_type)

    if deserializer_handle is None:
        raise RuntimeError(f"Unrecognized deserializer type: {deserializer_type}")

    # Call the handle directly rather than via ``.__call__()``: attribute lookup
    # of ``__call__`` on a class that defines its own instance-level ``__call__``
    # would return that unbound method instead of the constructor.
    return deserializer_handle()
71
+
72
+
73
def _retrieve_default_deserializer(
    model_id: str,
    model_version: str,
    hub_arn: Optional[str],
    region: Optional[str],
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> BaseDeserializer:
    """Retrieves the default deserializer for the model.

    The model's default accept type is looked up with
    ``_retrieve_default_accept_type``, converted with
    ``MIMEType.from_suffixed_type`` and then resolved to a deserializer by
    ``_retrieve_deserializer_from_accept_type``.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the default deserializer.
        model_version (str): Version of the JumpStart model for which to retrieve
            the default deserializer.
        hub_arn (Optional[str]): The arn of the SageMaker Hub from which to
            retrieve model details.
        region (Optional[str]): Region for which to retrieve default deserializer.
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False,
            raises an exception if the script used by this version of the model
            has dependencies with known security vulnerabilities.
            (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False,
            raises an exception if the version of the model is deprecated.
            (Default: False).
        sagemaker_session (Session): A SageMaker Session object, used for
            SageMaker interactions.
            (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (JumpStartModelType): Type of the JumpStart model, forwarded
            to the accept-type lookup.
            (Default: JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the JumpStart Model config to apply.
            (Default: None).

    Returns:
        BaseDeserializer: the default deserializer to use for the model.

    Raises:
        RuntimeError: If the default accept type cannot be resolved to a
            deserializer.
    """

    accept_type = _retrieve_default_accept_type(
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
        config_name=config_name,
    )

    mime_type = MIMEType.from_suffixed_type(accept_type)
    return _retrieve_deserializer_from_accept_type(mime_type)
124
+
125
+
126
def _retrieve_default_serializer(
    model_id: str,
    model_version: str,
    hub_arn: Optional[str],
    region: Optional[str],
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> BaseSerializer:
    """Retrieves the default serializer for the model.

    The model's default content type is looked up with
    ``_retrieve_default_content_type``, converted with
    ``MIMEType.from_suffixed_type`` and then resolved to a serializer by
    ``_retrieve_serializer_from_content_type``.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the default serializer.
        model_version (str): Version of the JumpStart model for which to retrieve
            the default serializer.
        hub_arn (Optional[str]): The arn of the SageMaker Hub from which to
            retrieve model details.
        region (Optional[str]): Region for which to retrieve default serializer.
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False,
            raises an exception if the script used by this version of the model
            has dependencies with known security vulnerabilities.
            (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False,
            raises an exception if the version of the model is deprecated.
            (Default: False).
        sagemaker_session (Session): A SageMaker Session object, used for
            SageMaker interactions.
            (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (JumpStartModelType): Type of the JumpStart model, forwarded
            to the content-type lookup.
            (Default: JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the JumpStart Model config to apply.
            (Default: None).

    Returns:
        BaseSerializer: the default serializer to use for the model.

    Raises:
        RuntimeError: If the default content type cannot be resolved to a
            serializer.
    """

    content_type = _retrieve_default_content_type(
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
        config_name=config_name,
    )

    mime_type = MIMEType.from_suffixed_type(content_type)
    return _retrieve_serializer_from_content_type(mime_type)
176
+
177
+
178
def _retrieve_deserializer_options(
    model_id: str,
    model_version: str,
    hub_arn: Optional[str],
    region: Optional[str],
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> List[BaseDeserializer]:
    """Retrieves the supported deserializers for the model.

    One deserializer is built per supported accept type, in the order returned
    by ``_retrieve_supported_accept_types``. When several accept types yield
    deserializers of the same class, only the first instance is kept.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the supported deserializers.
        model_version (str): Version of the JumpStart model for which to retrieve
            the supported deserializers.
        hub_arn (Optional[str]): The arn of the SageMaker Hub from which to
            retrieve model details.
        region (Optional[str]): Region for which to retrieve deserializer options.
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False,
            raises an exception if the script used by this version of the model
            has dependencies with known security vulnerabilities.
            (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False,
            raises an exception if the version of the model is deprecated.
            (Default: False).
        sagemaker_session (Session): A SageMaker Session object, used for
            SageMaker interactions.
            (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (JumpStartModelType): Type of the JumpStart model, forwarded
            to the accept-type lookup.
            (Default: JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the JumpStart Model config to apply.
            (Default: None).

    Returns:
        List[BaseDeserializer]: the supported deserializers to use for the model,
            at most one per deserializer class.

    Raises:
        RuntimeError: If any supported accept type cannot be resolved to a
            deserializer.
    """

    accept_types = _retrieve_supported_accept_types(
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
        config_name=config_name,
    )

    # Keyed on the deserializer class; dict insertion order keeps the first
    # instance seen for each class in its original position.
    first_instance_by_class = {}

    for suffixed_type in accept_types:
        candidate = _retrieve_deserializer_from_accept_type(
            MIMEType.from_suffixed_type(suffixed_type)
        )
        first_instance_by_class.setdefault(type(candidate), candidate)

    return list(first_instance_by_class.values())
242
+
243
+
244
def _retrieve_serializer_options(
    model_id: str,
    model_version: str,
    hub_arn: Optional[str],
    region: Optional[str],
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> List[BaseSerializer]:
    """Retrieves the supported serializers for the model.

    The supported content types of the model are looked up, each one is mapped to a
    serializer instance, and only the first serializer of each class is kept, so the
    result preserves the order of the supported content types.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the supported serializers.
        model_version (str): Version of the JumpStart model for which to retrieve the
            supported serializers.
        hub_arn (Optional[str]): The arn of the SageMaker Hub for which to retrieve
            model details from.
        region (Optional[str]): Region for which to retrieve serializer options.
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with
            known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        config_name (Optional[str]): Name of the JumpStart Model config to apply.
            (Default: None).
        model_type (JumpStartModelType): Type of the JumpStart model, forwarded to the
            supported-content-type lookup. Declared last so that existing positional
            callers keep working. (Default: JumpStartModelType.OPEN_WEIGHTS).
    Returns:
        List[BaseSerializer]: the supported serializers to use for the model.
    """

    supported_content_types = _retrieve_supported_content_types(
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
        config_name=config_name,
    )

    serializers: List[BaseSerializer] = []
    seen_classes: Set[Type] = set()

    for content_type in supported_content_types:
        serializer = _retrieve_serializer_from_content_type(
            MIMEType.from_suffixed_type(content_type)
        )
        # Several content types can map to the same serializer class; keep only the
        # first instance of each class.
        serializer_class = type(serializer)
        if serializer_class not in seen_classes:
            seen_classes.add(serializer_class)
            serializers.append(serializer)

    return serializers
306
+
307
+
308
def _retrieve_default_content_type(
    model_id: str,
    model_version: str,
    hub_arn: Optional[str],
    region: Optional[str],
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
) -> str:
    """Retrieves the default content type for the model.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the default content type.
        model_version (str): Version of the JumpStart model for which to retrieve the
            default content type.
        hub_arn (Optional[str]): The arn of the SageMaker Hub for which to retrieve
            model details from.
        region (Optional[str]): Region for which to retrieve default content type. When
            falsy, the region is resolved with ``get_region_fallback`` from the session.
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with
            known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        model_type (JumpStartModelType): Type of the JumpStart model, forwarded to the
            model spec lookup. (Default: JumpStartModelType.OPEN_WEIGHTS).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        config_name (Optional[str]): Name of the JumpStart Model config to apply.
            (Default: None).
    Returns:
        str: the default content type to use for the model.
    """

    if not region:
        region = get_region_fallback(sagemaker_session=sagemaker_session)

    # The content type is a predictor setting, so the inference-scope specs are used.
    inference_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=JumpStartScriptScope.INFERENCE,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
        config_name=config_name,
    )

    return inference_specs.predictor_specs.default_content_type
364
+
365
+
366
def _retrieve_default_accept_type(
    model_id: str,
    model_version: str,
    hub_arn: Optional[str],
    region: Optional[str],
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> str:
    """Retrieves the default accept type for the model.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the default accept type.
        model_version (str): Version of the JumpStart model for which to retrieve the
            default accept type.
        hub_arn (Optional[str]): The arn of the SageMaker Hub for which to retrieve
            model details from.
        region (Optional[str]): Region for which to retrieve default accept type. When
            falsy, the region is resolved with ``get_region_fallback`` from the session.
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with
            known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (JumpStartModelType): Type of the JumpStart model, forwarded to the
            model spec lookup. (Default: JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the JumpStart Model config to apply.
            (Default: None).
    Returns:
        str: the default accept type to use for the model.
    """

    resolved_region = (
        region if region else get_region_fallback(sagemaker_session=sagemaker_session)
    )

    # The accept type is a predictor setting, so the inference-scope specs are used.
    specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=JumpStartScriptScope.INFERENCE,
        region=resolved_region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
        config_name=config_name,
    )

    predictor_specs = specs.predictor_specs
    return predictor_specs.default_accept_type
423
+
424
+
425
def _retrieve_supported_accept_types(
    model_id: str,
    model_version: str,
    hub_arn: Optional[str],
    region: Optional[str],
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> List[str]:
    """Retrieves the supported accept types for the model.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the supported accept types.
        model_version (str): Version of the JumpStart model for which to retrieve the
            supported accept types.
        hub_arn (Optional[str]): The arn of the SageMaker Hub for which to retrieve
            model details from.
        region (Optional[str]): Region for which to retrieve accept type options. When
            falsy, the region is resolved with ``get_region_fallback`` from the session.
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with
            known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (JumpStartModelType): Type of the JumpStart model, forwarded to the
            model spec lookup. (Default: JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the JumpStart Model config to apply.
            (Default: None).
    Returns:
        list: the supported accept types to use for the model.
    """

    if not region:
        region = get_region_fallback(sagemaker_session=sagemaker_session)

    # Accept types are predictor settings, so the inference-scope specs are used.
    inference_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        region=region,
        scope=JumpStartScriptScope.INFERENCE,
        model_type=model_type,
        config_name=config_name,
        sagemaker_session=sagemaker_session,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
    )

    return inference_specs.predictor_specs.supported_accept_types
482
+
483
+
484
def _retrieve_supported_content_types(
    model_id: str,
    model_version: str,
    hub_arn: Optional[str],
    region: Optional[str],
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> List[str]:
    """Retrieves the supported content types for the model.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the supported content types.
        model_version (str): Version of the JumpStart model for which to retrieve the
            supported content types.
        hub_arn (Optional[str]): The arn of the SageMaker Hub for which to retrieve
            model details from.
        region (Optional[str]): Region for which to retrieve content type options. When
            falsy, the region is resolved with ``get_region_fallback`` from the session.
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with
            known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (JumpStartModelType): Type of the JumpStart model, forwarded to the
            model spec lookup. (Default: JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the JumpStart Model config to apply.
            (Default: None).
    Returns:
        list: the supported content types to use for the model.
    """

    effective_region = region or get_region_fallback(sagemaker_session=sagemaker_session)

    # Content types are predictor settings, so the inference-scope specs are used.
    specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=JumpStartScriptScope.INFERENCE,
        region=effective_region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
        config_name=config_name,
    )

    predictor_specs = specs.predictor_specs
    return predictor_specs.supported_content_types
@@ -0,0 +1,86 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module contains functions for obtaining JumpStart resource names."""
14
+ from __future__ import absolute_import
15
+ from typing import Optional
16
+ from sagemaker.core.jumpstart.constants import (
17
+ DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
18
+ )
19
+ from sagemaker.core.jumpstart.enums import (
20
+ JumpStartScriptScope,
21
+ JumpStartModelType,
22
+ )
23
+ from sagemaker.core.jumpstart.utils import (
24
+ get_region_fallback,
25
+ verify_model_region_and_return_specs,
26
+ )
27
+ from sagemaker.core.helper.session_helper import Session
28
+
29
+
30
+ def _retrieve_resource_name_base(
31
+ model_id: str,
32
+ model_version: str,
33
+ region: Optional[str],
34
+ hub_arn: Optional[str] = None,
35
+ tolerate_vulnerable_model: bool = False,
36
+ tolerate_deprecated_model: bool = False,
37
+ model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
38
+ sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
39
+ scope: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
40
+ config_name: Optional[str] = None,
41
+ ) -> bool:
42
+ """Returns default resource name.
43
+
44
+ Args:
45
+ model_id (str): JumpStart model ID of the JumpStart model for which to
46
+ get default resource name.
47
+ model_version (str): Version of the JumpStart model for which to retrieve the
48
+ default resource name.
49
+ region (Optional[str]): Region for which to retrieve the
50
+ default resource name.
51
+ hub_arn (str): The arn of the SageMaker Hub for which to retrieve
52
+ model details from. (Default: None).
53
+ tolerate_vulnerable_model (bool): True if vulnerable versions of model
54
+ specifications should be tolerated (exception not raised). If False, raises an
55
+ exception if the script used by this version of the model has dependencies with known
56
+ security vulnerabilities. (Default: False).
57
+ tolerate_deprecated_model (bool): True if deprecated versions of model
58
+ specifications should be tolerated (exception not raised). If False, raises
59
+ an exception if the version of the model is deprecated. (Default: False).
60
+ sagemaker_session (sagemaker.session.Session): A SageMaker Session
61
+ object, used for SageMaker interactions. If not
62
+ specified, one is created using the default AWS configuration
63
+ chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
64
+ config_name (Optional[str]): Name of the JumpStart Model config. (Default: None).
65
+ Returns:
66
+ str: the default resource name.
67
+ """
68
+
69
+ region = region or get_region_fallback(
70
+ sagemaker_session=sagemaker_session,
71
+ )
72
+
73
+ model_specs = verify_model_region_and_return_specs(
74
+ model_id=model_id,
75
+ version=model_version,
76
+ hub_arn=hub_arn,
77
+ scope=scope,
78
+ region=region,
79
+ tolerate_vulnerable_model=tolerate_vulnerable_model,
80
+ tolerate_deprecated_model=tolerate_deprecated_model,
81
+ model_type=model_type,
82
+ sagemaker_session=sagemaker_session,
83
+ config_name=config_name,
84
+ )
85
+
86
+ return model_specs.resource_name_base