sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (363) hide show
  1. sagemaker/core/__init__.py +16 -0
  2. sagemaker/core/_studio.py +116 -0
  3. sagemaker/core/_version.py +11 -0
  4. sagemaker/core/accept_types.py +131 -0
  5. sagemaker/core/analytics.py +744 -0
  6. sagemaker/core/apiutils/__init__.py +13 -0
  7. sagemaker/core/apiutils/_base_types.py +228 -0
  8. sagemaker/core/apiutils/_boto_functions.py +130 -0
  9. sagemaker/core/apiutils/_utils.py +34 -0
  10. sagemaker/core/base_deserializers.py +35 -0
  11. sagemaker/core/base_serializers.py +35 -0
  12. sagemaker/core/clarify/__init__.py +2898 -0
  13. sagemaker/core/collection.py +467 -0
  14. sagemaker/core/common_utils.py +2281 -0
  15. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  16. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  17. sagemaker/core/config/__init__.py +181 -0
  18. sagemaker/core/config/config.py +238 -0
  19. sagemaker/core/config/config_manager.py +595 -0
  20. sagemaker/core/config/config_schema.py +1220 -0
  21. sagemaker/core/config/config_utils.py +297 -0
  22. {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
  23. sagemaker/core/constants.py +73 -0
  24. sagemaker/core/content_types.py +137 -0
  25. sagemaker/core/debugger/__init__.py +39 -0
  26. sagemaker/core/debugger/debugger.py +945 -0
  27. sagemaker/core/debugger/framework_profile.py +292 -0
  28. sagemaker/core/debugger/metrics_config.py +468 -0
  29. sagemaker/core/debugger/profiler.py +42 -0
  30. sagemaker/core/debugger/profiler_config.py +190 -0
  31. sagemaker/core/debugger/profiler_constants.py +40 -0
  32. sagemaker/core/debugger/utils.py +148 -0
  33. sagemaker/core/deprecations.py +254 -0
  34. sagemaker/core/deserializers/__init__.py +10 -0
  35. sagemaker/core/deserializers/base.py +424 -0
  36. sagemaker/core/deserializers/implementations.py +157 -0
  37. sagemaker/core/drift_check_baselines.py +106 -0
  38. sagemaker/core/enums.py +51 -0
  39. sagemaker/core/environment_variables.py +101 -0
  40. sagemaker/core/exceptions.py +108 -0
  41. sagemaker/core/experiments/__init__.py +53 -0
  42. sagemaker/core/experiments/_api_types.py +251 -0
  43. sagemaker/core/experiments/_environment.py +124 -0
  44. sagemaker/core/experiments/_helper.py +294 -0
  45. sagemaker/core/experiments/_metrics.py +333 -0
  46. sagemaker/core/experiments/_run_context.py +58 -0
  47. sagemaker/core/experiments/_utils.py +216 -0
  48. sagemaker/core/experiments/experiment.py +244 -0
  49. sagemaker/core/experiments/run.py +970 -0
  50. sagemaker/core/experiments/trial.py +296 -0
  51. sagemaker/core/experiments/trial_component.py +387 -0
  52. sagemaker/core/explainer/__init__.py +24 -0
  53. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  54. sagemaker/core/explainer/explainer_config.py +44 -0
  55. sagemaker/core/fw_utils.py +1176 -0
  56. sagemaker/core/git_utils.py +349 -0
  57. sagemaker/core/helper/pipeline_variable.py +82 -0
  58. sagemaker/core/helper/session_helper.py +2965 -0
  59. sagemaker/core/huggingface/__init__.py +29 -0
  60. sagemaker/core/huggingface/llm_utils.py +150 -0
  61. sagemaker/core/huggingface/processing.py +139 -0
  62. sagemaker/core/huggingface/training_compiler/config.py +167 -0
  63. sagemaker/core/hyperparameters.py +172 -0
  64. sagemaker/core/image_retriever/__init__.py +3 -0
  65. sagemaker/core/image_retriever/image_retriever.py +640 -0
  66. sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
  67. sagemaker/core/image_retriever/test.py +7 -0
  68. sagemaker/core/image_uri_config/__init__.py +13 -0
  69. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  70. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  71. sagemaker/core/image_uri_config/chainer.json +104 -0
  72. sagemaker/core/image_uri_config/clarify.json +39 -0
  73. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  74. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  75. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  76. sagemaker/core/image_uri_config/debugger.json +34 -0
  77. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  78. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  79. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  80. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  81. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  82. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  83. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  84. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  85. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
  86. sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
  87. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  88. sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
  89. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  90. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  91. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  92. sagemaker/core/image_uri_config/huggingface.json +2138 -0
  93. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  94. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  95. sagemaker/core/image_uri_config/image-classification.json +50 -0
  96. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  97. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  98. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  99. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  100. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  101. sagemaker/core/image_uri_config/kmeans.json +50 -0
  102. sagemaker/core/image_uri_config/knn.json +50 -0
  103. sagemaker/core/image_uri_config/lda.json +26 -0
  104. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  105. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  106. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  107. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  108. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  109. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  110. sagemaker/core/image_uri_config/ntm.json +50 -0
  111. sagemaker/core/image_uri_config/object-detection.json +50 -0
  112. sagemaker/core/image_uri_config/object2vec.json +50 -0
  113. sagemaker/core/image_uri_config/pca.json +50 -0
  114. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  115. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  116. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  117. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  118. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  119. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  120. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  121. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  122. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  123. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  124. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
  125. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  126. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  127. sagemaker/core/image_uri_config/sklearn.json +446 -0
  128. sagemaker/core/image_uri_config/spark.json +280 -0
  129. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  130. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  131. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  132. sagemaker/core/image_uri_config/vw.json +25 -0
  133. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  134. sagemaker/core/image_uri_config/xgboost.json +888 -0
  135. sagemaker/core/image_uris.py +810 -0
  136. sagemaker/core/inference_config.py +144 -0
  137. sagemaker/core/inference_recommender/__init__.py +18 -0
  138. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  139. sagemaker/core/inputs.py +366 -0
  140. sagemaker/core/instance_group.py +61 -0
  141. sagemaker/core/instance_types.py +164 -0
  142. sagemaker/core/instance_types_gpu_info.py +43 -0
  143. sagemaker/core/interactive_apps/__init__.py +41 -0
  144. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  145. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  146. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  147. sagemaker/core/iterators.py +186 -0
  148. sagemaker/core/job.py +380 -0
  149. sagemaker/core/jumpstart/__init__.py +156 -0
  150. sagemaker/core/jumpstart/accessors.py +390 -0
  151. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  152. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  153. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  154. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  155. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  156. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  157. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  158. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  159. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  160. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  161. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  162. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  163. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  164. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  165. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  166. sagemaker/core/jumpstart/cache.py +663 -0
  167. sagemaker/core/jumpstart/configs.py +50 -0
  168. sagemaker/core/jumpstart/constants.py +198 -0
  169. sagemaker/core/jumpstart/deserializers.py +81 -0
  170. sagemaker/core/jumpstart/document.py +76 -0
  171. sagemaker/core/jumpstart/enums.py +168 -0
  172. sagemaker/core/jumpstart/exceptions.py +236 -0
  173. sagemaker/core/jumpstart/factory/utils.py +833 -0
  174. sagemaker/core/jumpstart/filters.py +597 -0
  175. sagemaker/core/jumpstart/hub/__init__.py +0 -0
  176. sagemaker/core/jumpstart/hub/constants.py +16 -0
  177. sagemaker/core/jumpstart/hub/hub.py +291 -0
  178. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  179. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  180. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  181. sagemaker/core/jumpstart/hub/types.py +35 -0
  182. sagemaker/core/jumpstart/hub/utils.py +260 -0
  183. sagemaker/core/jumpstart/models.py +499 -0
  184. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  185. sagemaker/core/jumpstart/parameters.py +20 -0
  186. sagemaker/core/jumpstart/payload_utils.py +239 -0
  187. sagemaker/core/jumpstart/region_config.json +163 -0
  188. sagemaker/core/jumpstart/search.py +171 -0
  189. sagemaker/core/jumpstart/serializers.py +81 -0
  190. sagemaker/core/jumpstart/session_utils.py +234 -0
  191. sagemaker/core/jumpstart/types.py +3044 -0
  192. sagemaker/core/jumpstart/utils.py +1731 -0
  193. sagemaker/core/jumpstart/validators.py +257 -0
  194. sagemaker/core/lambda_helper.py +312 -0
  195. sagemaker/core/lineage/__init__.py +42 -0
  196. sagemaker/core/lineage/_api_types.py +239 -0
  197. sagemaker/core/lineage/_utils.py +49 -0
  198. sagemaker/core/lineage/action.py +345 -0
  199. sagemaker/core/lineage/artifact.py +646 -0
  200. sagemaker/core/lineage/association.py +190 -0
  201. sagemaker/core/lineage/context.py +505 -0
  202. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  203. sagemaker/core/lineage/query.py +732 -0
  204. sagemaker/core/lineage/visualizer.py +346 -0
  205. sagemaker/core/local/__init__.py +18 -0
  206. sagemaker/core/local/data.py +413 -0
  207. sagemaker/core/local/entities.py +678 -0
  208. sagemaker/core/local/exceptions.py +17 -0
  209. sagemaker/core/local/image.py +1243 -0
  210. sagemaker/core/local/local_session.py +739 -0
  211. sagemaker/core/local/utils.py +245 -0
  212. sagemaker/core/logs.py +181 -0
  213. sagemaker/core/metadata_properties.py +56 -0
  214. sagemaker/core/metric_definitions.py +91 -0
  215. sagemaker/core/mlflow/__init__.py +38 -0
  216. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  217. sagemaker/core/model_card/__init__.py +26 -0
  218. sagemaker/core/model_life_cycle.py +51 -0
  219. sagemaker/core/model_metrics.py +160 -0
  220. sagemaker/core/model_monitor/__init__.py +66 -0
  221. sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
  222. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  223. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  224. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  225. sagemaker/core/model_monitor/dataset_format.py +102 -0
  226. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  227. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  228. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  229. sagemaker/core/model_monitor/utils.py +793 -0
  230. sagemaker/core/model_registry.py +480 -0
  231. sagemaker/core/model_uris.py +97 -0
  232. sagemaker/core/modules/__init__.py +19 -0
  233. sagemaker/core/modules/configs.py +226 -0
  234. sagemaker/core/modules/constants.py +37 -0
  235. sagemaker/core/modules/distributed.py +182 -0
  236. sagemaker/core/modules/local_core/__init__.py +0 -0
  237. sagemaker/core/modules/local_core/local_container.py +605 -0
  238. sagemaker/core/modules/templates.py +83 -0
  239. sagemaker/core/modules/train/__init__.py +14 -0
  240. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  241. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  242. sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
  243. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  244. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  245. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  246. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  247. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  248. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  249. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  250. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  251. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  252. sagemaker/core/modules/types.py +19 -0
  253. sagemaker/core/modules/utils.py +194 -0
  254. sagemaker/core/network.py +185 -0
  255. sagemaker/core/parameter.py +173 -0
  256. sagemaker/core/payloads.py +185 -0
  257. sagemaker/core/processing.py +1597 -0
  258. sagemaker/core/remote_function/__init__.py +19 -0
  259. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  260. sagemaker/core/remote_function/client.py +1285 -0
  261. sagemaker/core/remote_function/core/__init__.py +0 -0
  262. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  263. sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
  264. sagemaker/core/remote_function/core/serialization.py +422 -0
  265. sagemaker/core/remote_function/core/stored_function.py +226 -0
  266. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  267. sagemaker/core/remote_function/errors.py +104 -0
  268. sagemaker/core/remote_function/invoke_function.py +172 -0
  269. sagemaker/core/remote_function/job.py +2140 -0
  270. sagemaker/core/remote_function/logging_config.py +38 -0
  271. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  272. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  273. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  274. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  275. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  276. sagemaker/core/remote_function/spark_config.py +149 -0
  277. sagemaker/core/resource_requirements.py +168 -0
  278. {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
  279. sagemaker/core/s3/__init__.py +41 -0
  280. sagemaker/core/s3/client.py +367 -0
  281. sagemaker/core/s3/utils.py +175 -0
  282. sagemaker/core/script_uris.py +93 -0
  283. sagemaker/core/serializers/__init__.py +11 -0
  284. sagemaker/core/serializers/base.py +510 -0
  285. sagemaker/core/serializers/implementations.py +159 -0
  286. sagemaker/core/serializers/utils.py +223 -0
  287. sagemaker/core/serverless_inference_config.py +63 -0
  288. sagemaker/core/session_settings.py +55 -0
  289. sagemaker/core/shapes/__init__.py +3 -0
  290. sagemaker/core/shapes/model_card_shapes.py +159 -0
  291. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
  292. sagemaker/core/spark/__init__.py +16 -0
  293. sagemaker/core/spark/defaults.py +16 -0
  294. sagemaker/core/spark/processing.py +1380 -0
  295. sagemaker/core/telemetry/__init__.py +23 -0
  296. sagemaker/core/telemetry/constants.py +84 -0
  297. sagemaker/core/telemetry/telemetry_logging.py +284 -0
  298. sagemaker/core/tools/__init__.py +1 -0
  299. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  300. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  301. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  302. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  303. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  304. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  305. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  306. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  307. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  308. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  309. sagemaker/core/training/__init__.py +14 -0
  310. sagemaker/core/training/configs.py +333 -0
  311. sagemaker/core/training/constants.py +37 -0
  312. sagemaker/core/training/utils.py +77 -0
  313. sagemaker/core/training_compiler/__init__.py +16 -0
  314. sagemaker/core/training_compiler/config.py +197 -0
  315. sagemaker/core/training_compiler_config.py +197 -0
  316. sagemaker/core/transformer.py +793 -0
  317. sagemaker/core/user_agent.py +76 -0
  318. sagemaker/core/utilities/__init__.py +24 -0
  319. sagemaker/core/utilities/cache.py +169 -0
  320. sagemaker/core/utilities/search_expression.py +133 -0
  321. sagemaker/core/utils/__init__.py +48 -0
  322. sagemaker/core/utils/code_injection/__init__.py +0 -0
  323. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  324. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
  325. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  326. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  327. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  328. {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
  329. sagemaker/core/workflow/__init__.py +152 -0
  330. sagemaker/core/workflow/conditions.py +313 -0
  331. sagemaker/core/workflow/entities.py +58 -0
  332. sagemaker/core/workflow/execution_variables.py +89 -0
  333. sagemaker/core/workflow/functions.py +193 -0
  334. sagemaker/core/workflow/parameters.py +222 -0
  335. sagemaker/core/workflow/pipeline_context.py +394 -0
  336. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  337. sagemaker/core/workflow/properties.py +285 -0
  338. sagemaker/core/workflow/step_outputs.py +65 -0
  339. sagemaker/core/workflow/utilities.py +507 -0
  340. sagemaker/lineage/__init__.py +33 -0
  341. sagemaker/lineage/action.py +28 -0
  342. sagemaker/lineage/artifact.py +28 -0
  343. sagemaker/lineage/context.py +28 -0
  344. sagemaker/lineage/lineage_trial_component.py +28 -0
  345. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
  346. sagemaker_core-2.1.1.dist-info/RECORD +355 -0
  347. sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
  348. sagemaker_core/__init__.py +0 -4
  349. sagemaker_core/_version.py +0 -3
  350. sagemaker_core/helper/session_helper.py +0 -769
  351. sagemaker_core/resources/__init__.py +0 -1
  352. sagemaker_core/shapes/__init__.py +0 -1
  353. sagemaker_core/tools/__init__.py +0 -1
  354. sagemaker_core-1.0.47.dist-info/RECORD +0 -35
  355. sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
  356. {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
  357. {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
  358. {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  359. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  360. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  361. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  362. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
  363. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,663 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module defines the JumpStartModelsCache class."""
14
+ from __future__ import absolute_import
15
+ import datetime
16
+ from difflib import get_close_matches
17
+ import os
18
+ from typing import Any, Dict, List, Optional, Tuple, Union
19
+ import json
20
+ import boto3
21
+ import botocore
22
+ from packaging.version import Version
23
+ from packaging.specifiers import SpecifierSet, InvalidSpecifier
24
+ from sagemaker.core.jumpstart.constants import (
25
+ DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
26
+ ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
27
+ ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
28
+ JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
29
+ JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY,
30
+ JUMPSTART_LOGGER,
31
+ MODEL_ID_LIST_WEB_URL,
32
+ MODEL_TYPE_TO_MANIFEST_MAP,
33
+ MODEL_TYPE_TO_SPECS_MAP,
34
+ )
35
+ from sagemaker.core.jumpstart.exceptions import (
36
+ get_wildcard_model_version_msg,
37
+ get_wildcard_proprietary_model_version_msg,
38
+ )
39
+ from sagemaker.core.jumpstart.parameters import (
40
+ JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
41
+ JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,
42
+ JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON,
43
+ JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON,
44
+ )
45
+ from sagemaker.core.jumpstart.types import (
46
+ JumpStartCachedContentKey,
47
+ JumpStartCachedContentValue,
48
+ JumpStartModelHeader,
49
+ JumpStartModelSpecs,
50
+ JumpStartS3FileType,
51
+ JumpStartVersionedModelId,
52
+ HubContentType,
53
+ )
54
+ from sagemaker.core.jumpstart.hub import utils as hub_utils
55
+ from sagemaker.core.jumpstart.hub.interfaces import DescribeHubContentResponse
56
+ from sagemaker.core.jumpstart.hub.parsers import (
57
+ make_model_specs_from_describe_hub_content_response,
58
+ )
59
+ from sagemaker.core.jumpstart.enums import JumpStartModelType
60
+ from sagemaker.core.jumpstart import utils
61
+ from sagemaker.core.utilities.cache import LRUCache
62
+ from sagemaker.core.helper.session_helper import Session
63
+
64
+
65
+ class JumpStartModelsCache:
66
+ """Class that implements a cache for JumpStart models manifests and specs.
67
+
68
+ The manifest and specs associated with JumpStart models provide the information necessary
69
+ for launching JumpStart models from the SageMaker SDK.
70
+ """
71
+
72
def __init__(
    self,
    region: Optional[str] = None,
    max_s3_cache_items: int = JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
    s3_cache_expiration_horizon: datetime.timedelta = (
        JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON
    ),
    max_semantic_version_cache_items: int = JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,
    semantic_version_cache_expiration_horizon: datetime.timedelta = (
        JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON
    ),
    manifest_file_s3_key: str = JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
    proprietary_manifest_s3_key: str = JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY,
    s3_bucket_name: Optional[str] = None,
    s3_client_config: Optional[botocore.config.Config] = None,
    s3_client: Optional[boto3.client] = None,
    sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> None:
    """Build a ``JumpStartModelsCache`` and its three underlying LRU caches.

    Args:
        region (Optional[str]): AWS region to associate with the cache. When falsy, the
            region is resolved through ``utils.get_region_fallback``.
        max_s3_cache_items (int): Capacity of the S3 content cache.
            Default: ``JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS``.
        s3_cache_expiration_horizon (datetime.timedelta): Expiration horizon of the S3
            content cache. Default: ``JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON``.
        max_semantic_version_cache_items (int): Capacity of each model-id/version cache.
            Default: ``JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS``.
        semantic_version_cache_expiration_horizon (datetime.timedelta): Expiration
            horizon of each model-id/version cache.
            Default: ``JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON``.
        manifest_file_s3_key (str): S3 key of the open-weight models manifest.
        proprietary_manifest_s3_key (str): S3 key of the proprietary models manifest.
        s3_bucket_name (Optional[str]): S3 bucket to associate with the cache. When
            ``None``, ``utils.get_jumpstart_content_bucket`` is asked for the bucket of
            the resolved region.
        s3_client_config (Optional[botocore.config.Config]): Config used only when this
            constructor has to create the S3 client itself.
        s3_client (Optional[boto3.client]): Pre-built S3 client. When falsy, a new client
            is created for the resolved region.
        sagemaker_session (Optional[Session]): SageMaker session to use. A falsy value
            falls back to ``DEFAULT_JUMPSTART_SAGEMAKER_SESSION``.
    """
    if region:
        self._region = region
    else:
        self._region = utils.get_region_fallback(
            s3_bucket_name=s3_bucket_name, s3_client=s3_client
        )

    # Cache of raw S3/hub content (manifests and specs).
    content_cache_type = LRUCache[JumpStartCachedContentKey, JumpStartCachedContentValue]
    self._content_cache = content_cache_type(
        max_cache_items=max_s3_cache_items,
        expiration_horizon=s3_cache_expiration_horizon,
        retrieval_function=self._retrieval_function,
    )

    # Both model-id caches share the same shape and limits; only the retrieval
    # function differs between open-weight and proprietary models.
    model_id_cache_type = LRUCache[JumpStartVersionedModelId, JumpStartVersionedModelId]
    self._open_weight_model_id_manifest_key_cache = model_id_cache_type(
        max_cache_items=max_semantic_version_cache_items,
        expiration_horizon=semantic_version_cache_expiration_horizon,
        retrieval_function=self._get_open_weight_manifest_key_from_model_id,
    )
    self._proprietary_model_id_manifest_key_cache = model_id_cache_type(
        max_cache_items=max_semantic_version_cache_items,
        expiration_horizon=semantic_version_cache_expiration_horizon,
        retrieval_function=self._get_proprietary_manifest_key_from_model_id,
    )

    self._manifest_file_s3_key = manifest_file_s3_key
    self._proprietary_manifest_s3_key = proprietary_manifest_s3_key
    self._manifest_file_s3_map = {
        JumpStartModelType.OPEN_WEIGHTS: self._manifest_file_s3_key,
        JumpStartModelType.PROPRIETARY: self._proprietary_manifest_s3_key,
    }

    if s3_bucket_name is None:
        self.s3_bucket_name = utils.get_jumpstart_content_bucket(self._region)
    else:
        self.s3_bucket_name = s3_bucket_name

    # Only create a client when the caller did not hand one in.
    if s3_client:
        self._s3_client = s3_client
    elif s3_client_config:
        self._s3_client = boto3.client(
            "s3", region_name=self._region, config=s3_client_config
        )
    else:
        self._s3_client = boto3.client("s3", region_name=self._region)

    # Fallback in case a caller overrides sagemaker_session to None
    if sagemaker_session:
        self._sagemaker_session = sagemaker_session
    else:
        self._sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION
155
+
156
def set_region(self, region: str) -> None:
    """Switch the cache to ``region``; the cache is cleared only if the region changes."""
    if region == self._region:
        return
    self._region = region
    self.clear()
161
+
162
def get_region(self) -> str:
    """Return the AWS region currently associated with this cache."""
    current_region = self._region
    return current_region
165
+
166
def set_manifest_file_s3_key(
    self,
    key: str,
    file_type: JumpStartS3FileType = JumpStartS3FileType.OPEN_WEIGHT_MANIFEST,
) -> None:
    """Set the manifest file s3 key for ``file_type``; clear the cache if it changed.

    Args:
        key (str): New S3 key of the manifest file.
        file_type (JumpStartS3FileType): Which manifest the key belongs to. Only
            ``OPEN_WEIGHT_MANIFEST`` and ``PROPRIETARY_MANIFEST`` are accepted.

    Raises:
        ValueError: if the file type is not recognized
    """
    # Map each manifest type to the *name* of the attribute that stores its key, plus the
    # model type under which ``_manifest_file_s3_map`` indexes the same key. Mapping to the
    # attribute's current value (as before) made ``setattr`` create a stray attribute
    # named after the old S3 key and left the real key untouched.
    targets = {
        JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: (
            "_manifest_file_s3_key",
            JumpStartModelType.OPEN_WEIGHTS,
        ),
        JumpStartS3FileType.PROPRIETARY_MANIFEST: (
            "_proprietary_manifest_s3_key",
            JumpStartModelType.PROPRIETARY,
        ),
    }
    target = targets.get(file_type)
    if target is None:
        raise ValueError(self._file_type_error_msg(file_type, manifest_only=True))

    attribute_name, model_type = target
    if key != getattr(self, attribute_name):
        setattr(self, attribute_name, key)
        # Keep the per-model-type lookup table in step with the attribute so manifest
        # retrieval uses the new key.
        self._manifest_file_s3_map[model_type] = key
        self.clear()
186
+
187
def get_manifest_file_s3_key(
    self, file_type: JumpStartS3FileType = JumpStartS3FileType.OPEN_WEIGHT_MANIFEST
) -> str:
    """Return the manifest file s3 key stored for ``file_type``.

    Raises:
        ValueError: if ``file_type`` is neither the open-weight nor the proprietary
            manifest type.
    """
    wants_open_weight = file_type == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST
    if wants_open_weight:
        return self._manifest_file_s3_key

    wants_proprietary = file_type == JumpStartS3FileType.PROPRIETARY_MANIFEST
    if not wants_proprietary:
        raise ValueError(self._file_type_error_msg(file_type, manifest_only=True))
    return self._proprietary_manifest_s3_key
196
+
197
def set_s3_bucket_name(self, s3_bucket_name: str) -> None:
    """Point the cache at ``s3_bucket_name``; the cache is cleared only on a change."""
    if s3_bucket_name == self.s3_bucket_name:
        return
    self.s3_bucket_name = s3_bucket_name
    self.clear()
202
+
203
def get_bucket(self) -> str:
    """Return the name of the S3 bucket this cache reads from."""
    bucket_name = self.s3_bucket_name
    return bucket_name
206
+
207
def _file_type_error_msg(self, file_type: str, manifest_only: bool = False) -> str:
    """Build the error message for an unsupported ``file_type``.

    Args:
        file_type (str): The offending value, echoed in the message.
        manifest_only (bool): When True, list only the two manifest file types as valid;
            otherwise list the name of every ``JumpStartS3FileType`` member.
    """
    prefix = f"Bad value when getting manifest '{file_type}': "
    if not manifest_only:
        allowed_names = " ".join([member.name for member in JumpStartS3FileType])
        return prefix + f"must be in '{allowed_names}'."
    return (
        prefix
        + f"must be in {JumpStartS3FileType.OPEN_WEIGHT_MANIFEST} "
        + f"{JumpStartS3FileType.PROPRIETARY_MANIFEST}."
    )
219
+
220
def _model_id_retrieval_function(
    self,
    key: JumpStartVersionedModelId,
    value: Optional[JumpStartVersionedModelId],  # pylint: disable=W0613
    model_type: JumpStartModelType,
) -> JumpStartVersionedModelId:
    """Return model ID and version in manifest that matches semantic version/id.

    Uses ``packaging.version`` to compare each manifest entry's minimum SageMaker
    version against the installed SageMaker version, then delegates to
    ``_select_version`` to pick the model version matching the requested one.

    Args:
        key (JumpStartVersionedModelId): Key for which to fetch versioned model ID.
        value (Optional[JumpStartVersionedModelId]): Unused variable for current value of
            old cached model ID/version.
        model_type (JumpStartModelType): JumpStart model type to indicate whether it is
            open weights model or proprietary (Marketplace) model.

    Returns:
        JumpStartVersionedModelId: The model ID paired with the resolved concrete version.

    Raises:
        KeyError: If the semantic version is not found in the manifest, or is found but
            the SageMaker version needs to be upgraded in order for the model to be used.
        RuntimeError: If the manifest does not yield exactly one minimum SageMaker
            version for the selected incompatible model version.
    """

    model_id, version = key.model_id, key.version
    sm_version = utils.get_sagemaker_version()
    manifest = self._content_cache.get(
        JumpStartCachedContentKey(
            MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]
        )
    )[0].formatted_content

    versions_compatible_with_sagemaker = [
        header.version
        for header in manifest.values()  # type: ignore
        if header.model_id == model_id and Version(header.min_version) <= Version(sm_version)
    ]

    sm_compatible_model_version = self._select_version(
        model_id, version, versions_compatible_with_sagemaker, model_type
    )

    if sm_compatible_model_version is not None:
        return JumpStartVersionedModelId(model_id, sm_compatible_model_version)

    # Every manifest version of this model, regardless of SageMaker compatibility.
    # Reaching this point means none of the compatible ones matched the request.
    all_model_versions = [
        header.version
        for header in manifest.values()  # type: ignore
        if header.model_id == model_id
    ]
    sm_incompatible_model_version = self._select_version(
        model_id, version, all_model_versions, model_type
    )

    if sm_incompatible_model_version is not None:
        model_version_to_use_incompatible_with_sagemaker = sm_incompatible_model_version
        sm_version_to_use_list = [
            header.min_version
            for header in manifest.values()  # type: ignore
            if header.model_id == model_id
            and header.version == model_version_to_use_incompatible_with_sagemaker
        ]
        if len(sm_version_to_use_list) != 1:
            # ``manifest`` dict should already enforce this
            raise RuntimeError("Found more than one incompatible SageMaker version to use.")
        sm_version_to_use = sm_version_to_use_list[0]

        error_msg = (
            f"Unable to find model manifest for '{model_id}' with version '{version}' "
            f"compatible with your SageMaker version ('{sm_version}'). "
            f"Consider upgrading your SageMaker library to at least version "
            f"'{sm_version_to_use}' so you can use version "
            f"'{model_version_to_use_incompatible_with_sagemaker}' of '{model_id}'."
        )
        raise KeyError(error_msg)

    error_msg = f"Unable to find model manifest for '{model_id}' with version '{version}'. "
    error_msg += "Specify a different model ID or try a different AWS Region. "
    error_msg += f"For a list of available models, see {MODEL_ID_LIST_WEB_URL}. "

    other_model_id_version = None
    if model_type == JumpStartModelType.OPEN_WEIGHTS:
        other_model_id_version = self._select_version(
            model_id, "*", all_model_versions, model_type
        )  # all versions here are incompatible with sagemaker
    elif model_type == JumpStartModelType.PROPRIETARY:
        # Suggest the first manifest version of this model, if there is one.
        other_model_id_version = all_model_versions[0] if all_model_versions else None

    if other_model_id_version is not None:
        error_msg += (
            f"Consider using model ID '{model_id}' with version " f"'{other_model_id_version}'."
        )
    else:
        possible_model_ids = [header.model_id for header in manifest.values()]  # type: ignore
        # With cutoff=0 a match is returned whenever the manifest is non-empty. An
        # empty manifest yields an empty list, so guard the index to keep the
        # documented KeyError instead of leaking an IndexError.
        closest_model_ids = get_close_matches(model_id, possible_model_ids, n=1, cutoff=0)
        if closest_model_ids:
            error_msg += f"Did you mean to use model ID '{closest_model_ids[0]}'?"

    raise KeyError(error_msg)
325
+
326
def _get_open_weight_manifest_key_from_model_id(
    self,
    key: JumpStartVersionedModelId,
    value: Optional[JumpStartVersionedModelId],  # pylint: disable=W0613
) -> JumpStartVersionedModelId:
    """Resolve the manifest key for an open weights model.

    Delegates to ``_model_id_retrieval_function`` with the model type fixed to
    ``JumpStartModelType.OPEN_WEIGHTS``.
    """
    open_weights = JumpStartModelType.OPEN_WEIGHTS
    return self._model_id_retrieval_function(key, value, model_type=open_weights)
338
+
339
def _get_proprietary_manifest_key_from_model_id(
    self,
    key: JumpStartVersionedModelId,
    value: Optional[JumpStartVersionedModelId],  # pylint: disable=W0613
) -> JumpStartVersionedModelId:
    """Resolve the manifest key for a proprietary model.

    Delegates to ``_model_id_retrieval_function`` with the model type fixed to
    ``JumpStartModelType.PROPRIETARY``.
    """
    proprietary = JumpStartModelType.PROPRIETARY
    return self._model_id_retrieval_function(key, value, model_type=proprietary)
351
+
352
+ def _get_json_file_and_etag_from_s3(self, key: str) -> Tuple[Union[dict, list], str]:
353
+ """Returns json file from s3, along with its etag."""
354
+ response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=key)
355
+ return json.loads(response["Body"].read().decode("utf-8")), response["ETag"]
356
+
357
def _is_local_metadata_mode(self) -> bool:
    """Tell whether metadata should be read from the local file system.

    Local mode is active only when both the manifest and the specs override
    environment variables are set and each one names an existing directory.
    """
    override_variables = (
        ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
        ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
    )
    for variable_name in override_variables:
        if variable_name not in os.environ:
            return False
        if not os.path.isdir(os.environ[variable_name]):
            return False
    return True
365
+
366
def _get_json_file(
    self, key: str, filetype: JumpStartS3FileType
) -> Tuple[Union[dict, list], Optional[str]]:
    """Load a JSON document from S3 or, in local metadata mode, from disk.

    Returns:
        A ``(content, etag)`` pair. The etag comes from S3; it is ``None`` when
        the content was read from the local file system.
    """
    if not self._is_local_metadata_mode():
        file_content, etag = self._get_json_file_and_etag_from_s3(key)
        return file_content, etag

    local_content = self._get_json_file_from_local_override(key, filetype)
    return local_content, None
379
+
380
+ def _get_json_md5_hash(self, key: str):
381
+ """Retrieves md5 object hash for s3 objects, using `s3.head_object`.
382
+
383
+ Raises:
384
+ ValueError: if the cache should use local metadata mode.
385
+ """
386
+ if self._is_local_metadata_mode():
387
+ raise ValueError("Cannot get md5 hash of local file.")
388
+ return self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=key)["ETag"]
389
+
390
def _get_json_file_from_local_override(
    self, key: str, filetype: JumpStartS3FileType
) -> Union[dict, list]:
    """Read a JSON file from the local override directory and return its data.

    The root directory is taken from the manifest or specs override environment
    variable, depending on ``filetype``; ``key`` is joined onto that root.

    Args:
        key (str): Path of the file relative to the override root directory.
        filetype (JumpStartS3FileType): Must be the open-weight manifest or
            open-weight specs type; no other type has a local override here.

    Returns:
        Union[dict, list]: The parsed JSON content.

    Raises:
        ValueError: If ``filetype`` has no local override directory.
    """
    if filetype == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST:
        metadata_local_root = os.environ[
            ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE
        ]
    elif filetype == JumpStartS3FileType.OPEN_WEIGHT_SPECS:
        metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE]
    else:
        raise ValueError(f"Unsupported file type for local override: {filetype}")
    file_path = os.path.join(metadata_local_root, key)
    # Decode explicitly as UTF-8, matching the S3 code path, instead of relying on
    # the platform's locale-dependent default encoding.
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)
406
+
407
def _retrieval_function(
    self,
    key: JumpStartCachedContentKey,
    value: Optional[JumpStartCachedContentValue],
) -> JumpStartCachedContentValue:
    """Return content given a data type and identifier in ``JumpStartCachedContentKey``.

    If a manifest file is being fetched, we only download the object if the md5 hash in
    ``head_object`` does not match the current md5 hash for the stored value. This prevents
    unnecessarily downloading the full manifest when it hasn't changed. Specs files are
    always fetched, and hub content is described through the SageMaker session.

    Args:
        key (JumpStartCachedContentKey): key for which to fetch content. Its
            ``data_type`` selects the branch; its ``id_info`` is an S3 key for
            manifest/specs types and a hub resource ARN for hub content types.
        value (Optional[JumpStartCachedContentValue]): Current value of old cached
            content. This is used for the manifest file, so that it is only
            downloaded when its content changes.

    Returns:
        JumpStartCachedContentValue: The fetched (or still valid cached) content.

    Raises:
        ValueError: If ``key.data_type`` is not one of the handled types.
    """

    data_type, id_info = key.data_type, key.id_info

    if data_type in {
        JumpStartS3FileType.OPEN_WEIGHT_MANIFEST,
        JumpStartS3FileType.PROPRIETARY_MANIFEST,
    }:
        # The ETag comparison only applies to S3; local files have no ETag.
        if value is not None and not self._is_local_metadata_mode():
            etag = self._get_json_md5_hash(id_info)
            if etag == value.md5_hash:
                return value
        formatted_body, etag = self._get_json_file(id_info, data_type)
        return JumpStartCachedContentValue(
            formatted_content=utils.get_formatted_manifest(formatted_body),
            md5_hash=etag,
        )

    if data_type in {
        JumpStartS3FileType.OPEN_WEIGHT_SPECS,
        JumpStartS3FileType.PROPRIETARY_SPECS,
    }:
        formatted_body, _ = self._get_json_file(id_info, data_type)
        model_specs = JumpStartModelSpecs(formatted_body)
        utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
        return JumpStartCachedContentValue(formatted_content=model_specs)

    if data_type == HubContentType.NOTEBOOK:
        # Read the ARN fields by attribute, exactly as the model branch below does.
        # The object exposes more than four fields, so the previous four-way tuple
        # unpacking of this result could not succeed.
        notebook_arn_info = hub_utils.get_info_from_hub_resource_arn(id_info)
        response: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
            hub_name=notebook_arn_info.hub_name,
            hub_content_name=notebook_arn_info.hub_content_name,
            hub_content_version=notebook_arn_info.hub_content_version,
            hub_content_type=data_type,
        )
        hub_notebook_description = DescribeHubContentResponse(response)
        return JumpStartCachedContentValue(formatted_content=hub_notebook_description)

    if data_type in {
        HubContentType.MODEL,
        HubContentType.MODEL_REFERENCE,
    }:

        hub_resource_arn_extracted_info = hub_utils.get_info_from_hub_resource_arn(id_info)
        hub_arn = hub_utils.construct_hub_arn_from_name(
            hub_name=hub_resource_arn_extracted_info.hub_name,
            region=hub_resource_arn_extracted_info.region,
            account_id=hub_resource_arn_extracted_info.account_id,
        )

        # Resolve the version from the ARN to the version string passed to
        # ``describe_hub_content`` below.
        model_version: str = hub_utils.get_hub_model_version(
            hub_model_name=hub_resource_arn_extracted_info.hub_content_name,
            hub_model_type=data_type.value,
            hub_name=hub_arn,
            sagemaker_session=self._sagemaker_session,
            hub_model_version=hub_resource_arn_extracted_info.hub_content_version,
        )

        hub_model_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
            hub_name=hub_arn,
            hub_content_name=hub_resource_arn_extracted_info.hub_content_name,
            hub_content_version=model_version,
            hub_content_type=data_type.value,
        )

        model_specs = make_model_specs_from_describe_hub_content_response(
            DescribeHubContentResponse(hub_model_description),
        )

        return JumpStartCachedContentValue(formatted_content=model_specs)

    raise ValueError(self._file_type_error_msg(data_type))
496
+
497
def get_manifest(
    self,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> List[JumpStartModelHeader]:
    """Return every header of the JumpStart models manifest for ``model_type``.

    The manifest is obtained through the content cache and its values are
    returned as a new list.
    """
    manifest_key = JumpStartCachedContentKey(
        MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]
    )
    cached_manifest = self._content_cache.get(manifest_key)[0]
    headers_by_id = cached_manifest.formatted_content
    return list(headers_by_id.values())  # type: ignore
509
+
510
def get_header(
    self,
    model_id: str,
    semantic_version_str: str,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> JumpStartModelHeader:
    """Fetch the manifest header for a JumpStart model ID and semantic version.

    This is the public entry point; the lookup itself is done by
    ``_get_header_impl``, called here with its default attempt number.

    Args:
        model_id (str): model ID for which to get a header.
        semantic_version_str (str): The semantic version for which to get a
            header.
        model_type (JumpStartModelType): The type of the model of interest.
    """
    header = self._get_header_impl(
        model_id,
        semantic_version_str=semantic_version_str,
        model_type=model_type,
    )
    return header
527
+
528
def _select_version(
    self,
    model_id: str,
    version_str: str,
    available_versions: List[str],
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> Optional[str]:
    """Perform semantic version search on available versions.

    Args:
        model_id (str): model ID, used only to build the error message for
            proprietary wildcard requests.
        version_str (str): the semantic version for which to filter
            available versions. ``"*"`` selects the latest available version.
        available_versions (List[str]): list of available versions.
        model_type (JumpStartModelType): proprietary models only support ``"*"``
            or an exact version string; open weights models support any
            ``==`` version specifier.

    Returns:
        Optional[str]: The selected version, or ``None`` when nothing matches.

    Raises:
        KeyError: If a proprietary model is requested with a partial wildcard, or
            if ``version_str`` does not form a valid version specifier.
    """

    if version_str == "*":
        return utils.get_latest_version(available_versions)

    if model_type == JumpStartModelType.PROPRIETARY:
        if "*" in version_str:
            raise KeyError(
                get_wildcard_proprietary_model_version_msg(
                    model_id, version_str, available_versions
                )
            )
        return version_str if version_str in available_versions else None

    try:
        spec = SpecifierSet(f"=={version_str}")
    except InvalidSpecifier as err:
        raise KeyError(f"Bad semantic version: {version_str}") from err

    matching_versions = list(spec.filter(available_versions))
    if not matching_versions:
        return None

    def _as_version(candidate):
        """Parse a candidate into a ``Version`` unless it already is one."""
        return candidate if isinstance(candidate, Version) else Version(candidate)

    # ``spec.filter`` yields the original items, so a plain ``max`` over strings
    # compares lexicographically ("1.9.0" > "1.10.0"). Compare parsed versions.
    try:
        best_match = max(matching_versions, key=_as_version)
    except ValueError:
        # ``InvalidVersion`` is a ``ValueError``; keep the previous plain ordering
        # for candidates that cannot be parsed.
        best_match = max(matching_versions)
    return str(best_match)
561
+
562
def _get_header_impl(
    self,
    model_id: str,
    semantic_version_str: str,
    attempt: int = 0,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> JumpStartModelHeader:
    """Resolve a model header, retrying once after clearing the caches.

    The versioned model ID is resolved through the per-model-type key cache and
    then looked up in the manifest. If the manifest lacks that key on the first
    attempt, all caches are cleared and the lookup is repeated once; a second
    miss propagates the ``KeyError``.

    Args:
        model_id (str): model ID for which to get a header.
        semantic_version_str (str): The semantic version for which to get a
            header.
        attempt (int): attempt number at retrieving a header.
        model_type (JumpStartModelType): The type of the model of interest.
    """
    if model_type == JumpStartModelType.OPEN_WEIGHTS:
        requested_id = JumpStartVersionedModelId(model_id, semantic_version_str)
        versioned_model_id = self._open_weight_model_id_manifest_key_cache.get(requested_id)[0]
    elif model_type == JumpStartModelType.PROPRIETARY:
        requested_id = JumpStartVersionedModelId(model_id, semantic_version_str)
        versioned_model_id = self._proprietary_model_id_manifest_key_cache.get(requested_id)[0]

    manifest_key = JumpStartCachedContentKey(
        MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]
    )
    manifest = self._content_cache.get(manifest_key)[0].formatted_content

    try:
        return manifest[versioned_model_id]  # type: ignore
    except KeyError:
        if attempt <= 0:
            # First miss: the cached data may be stale, so drop it and retry once.
            self.clear()
            return self._get_header_impl(
                model_id, semantic_version_str, attempt + 1, model_type
            )
        raise
602
+
603
def get_specs(
    self,
    model_id: str,
    version_str: str,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> JumpStartModelSpecs:
    """Fetch the specs for a JumpStart model ID and semantic version.

    The header is resolved first and its ``spec_key`` is used to get the specs
    through the content cache. When the specs were not already cached and the
    requested version contains a wildcard, a warning is logged.

    Args:
        model_id (str): model ID for which to get specs.
        version_str (str): The semantic version for which to get specs.
        model_type (JumpStartModelType): The type of the model of interest.
    """
    header = self.get_header(model_id, version_str, model_type)
    specs_key = JumpStartCachedContentKey(MODEL_TYPE_TO_SPECS_MAP[model_type], header.spec_key)
    specs, cache_hit = self._content_cache.get(specs_key)

    if not cache_hit:
        if "*" in version_str:
            wildcard_msg = get_wildcard_model_version_msg(
                header.model_id, version_str, header.version
            )
            JUMPSTART_LOGGER.warning(wildcard_msg)
    return specs.formatted_content
628
+
629
def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs:
    """Fetch JumpStart-compatible specs for a Hub model through the content cache.

    Args:
        hub_model_arn (str): Arn for the Hub model to get specs for
    """
    content_key = JumpStartCachedContentKey(HubContentType.MODEL, hub_model_arn)
    details, _ = self._content_cache.get(content_key)
    return details.formatted_content
643
+
644
def get_hub_model_reference(self, hub_model_reference_arn: str) -> JumpStartModelSpecs:
    """Fetch JumpStart-compatible specs for a Hub model reference through the content cache.

    Args:
        hub_model_reference_arn (str): Arn for the Hub model reference to get specs for
    """
    content_key = JumpStartCachedContentKey(
        HubContentType.MODEL_REFERENCE, hub_model_reference_arn
    )
    details, _ = self._content_cache.get(content_key)
    return details.formatted_content
658
+
659
def clear(self) -> None:
    """Empty the content cache and both model ID manifest key caches."""
    cache_attributes = (
        "_content_cache",
        "_open_weight_model_id_manifest_key_cache",
        "_proprietary_model_id_manifest_key_cache",
    )
    for attribute_name in cache_attributes:
        getattr(self, attribute_name).clear()