sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (363)
  1. sagemaker/core/__init__.py +16 -0
  2. sagemaker/core/_studio.py +116 -0
  3. sagemaker/core/_version.py +11 -0
  4. sagemaker/core/accept_types.py +131 -0
  5. sagemaker/core/analytics.py +744 -0
  6. sagemaker/core/apiutils/__init__.py +13 -0
  7. sagemaker/core/apiutils/_base_types.py +228 -0
  8. sagemaker/core/apiutils/_boto_functions.py +130 -0
  9. sagemaker/core/apiutils/_utils.py +34 -0
  10. sagemaker/core/base_deserializers.py +35 -0
  11. sagemaker/core/base_serializers.py +35 -0
  12. sagemaker/core/clarify/__init__.py +2898 -0
  13. sagemaker/core/collection.py +467 -0
  14. sagemaker/core/common_utils.py +2281 -0
  15. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  16. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  17. sagemaker/core/config/__init__.py +181 -0
  18. sagemaker/core/config/config.py +238 -0
  19. sagemaker/core/config/config_manager.py +595 -0
  20. sagemaker/core/config/config_schema.py +1220 -0
  21. sagemaker/core/config/config_utils.py +297 -0
  22. {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
  23. sagemaker/core/constants.py +73 -0
  24. sagemaker/core/content_types.py +137 -0
  25. sagemaker/core/debugger/__init__.py +39 -0
  26. sagemaker/core/debugger/debugger.py +945 -0
  27. sagemaker/core/debugger/framework_profile.py +292 -0
  28. sagemaker/core/debugger/metrics_config.py +468 -0
  29. sagemaker/core/debugger/profiler.py +42 -0
  30. sagemaker/core/debugger/profiler_config.py +190 -0
  31. sagemaker/core/debugger/profiler_constants.py +40 -0
  32. sagemaker/core/debugger/utils.py +148 -0
  33. sagemaker/core/deprecations.py +254 -0
  34. sagemaker/core/deserializers/__init__.py +10 -0
  35. sagemaker/core/deserializers/base.py +424 -0
  36. sagemaker/core/deserializers/implementations.py +157 -0
  37. sagemaker/core/drift_check_baselines.py +106 -0
  38. sagemaker/core/enums.py +51 -0
  39. sagemaker/core/environment_variables.py +101 -0
  40. sagemaker/core/exceptions.py +108 -0
  41. sagemaker/core/experiments/__init__.py +53 -0
  42. sagemaker/core/experiments/_api_types.py +251 -0
  43. sagemaker/core/experiments/_environment.py +124 -0
  44. sagemaker/core/experiments/_helper.py +294 -0
  45. sagemaker/core/experiments/_metrics.py +333 -0
  46. sagemaker/core/experiments/_run_context.py +58 -0
  47. sagemaker/core/experiments/_utils.py +216 -0
  48. sagemaker/core/experiments/experiment.py +244 -0
  49. sagemaker/core/experiments/run.py +970 -0
  50. sagemaker/core/experiments/trial.py +296 -0
  51. sagemaker/core/experiments/trial_component.py +387 -0
  52. sagemaker/core/explainer/__init__.py +24 -0
  53. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  54. sagemaker/core/explainer/explainer_config.py +44 -0
  55. sagemaker/core/fw_utils.py +1176 -0
  56. sagemaker/core/git_utils.py +349 -0
  57. sagemaker/core/helper/pipeline_variable.py +82 -0
  58. sagemaker/core/helper/session_helper.py +2965 -0
  59. sagemaker/core/huggingface/__init__.py +29 -0
  60. sagemaker/core/huggingface/llm_utils.py +150 -0
  61. sagemaker/core/huggingface/processing.py +139 -0
  62. sagemaker/core/huggingface/training_compiler/config.py +167 -0
  63. sagemaker/core/hyperparameters.py +172 -0
  64. sagemaker/core/image_retriever/__init__.py +3 -0
  65. sagemaker/core/image_retriever/image_retriever.py +640 -0
  66. sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
  67. sagemaker/core/image_retriever/test.py +7 -0
  68. sagemaker/core/image_uri_config/__init__.py +13 -0
  69. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  70. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  71. sagemaker/core/image_uri_config/chainer.json +104 -0
  72. sagemaker/core/image_uri_config/clarify.json +39 -0
  73. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  74. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  75. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  76. sagemaker/core/image_uri_config/debugger.json +34 -0
  77. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  78. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  79. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  80. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  81. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  82. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  83. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  84. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  85. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
  86. sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
  87. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  88. sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
  89. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  90. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  91. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  92. sagemaker/core/image_uri_config/huggingface.json +2138 -0
  93. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  94. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  95. sagemaker/core/image_uri_config/image-classification.json +50 -0
  96. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  97. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  98. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  99. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  100. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  101. sagemaker/core/image_uri_config/kmeans.json +50 -0
  102. sagemaker/core/image_uri_config/knn.json +50 -0
  103. sagemaker/core/image_uri_config/lda.json +26 -0
  104. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  105. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  106. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  107. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  108. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  109. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  110. sagemaker/core/image_uri_config/ntm.json +50 -0
  111. sagemaker/core/image_uri_config/object-detection.json +50 -0
  112. sagemaker/core/image_uri_config/object2vec.json +50 -0
  113. sagemaker/core/image_uri_config/pca.json +50 -0
  114. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  115. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  116. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  117. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  118. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  119. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  120. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  121. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  122. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  123. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  124. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
  125. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  126. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  127. sagemaker/core/image_uri_config/sklearn.json +446 -0
  128. sagemaker/core/image_uri_config/spark.json +280 -0
  129. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  130. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  131. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  132. sagemaker/core/image_uri_config/vw.json +25 -0
  133. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  134. sagemaker/core/image_uri_config/xgboost.json +888 -0
  135. sagemaker/core/image_uris.py +810 -0
  136. sagemaker/core/inference_config.py +144 -0
  137. sagemaker/core/inference_recommender/__init__.py +18 -0
  138. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  139. sagemaker/core/inputs.py +366 -0
  140. sagemaker/core/instance_group.py +61 -0
  141. sagemaker/core/instance_types.py +164 -0
  142. sagemaker/core/instance_types_gpu_info.py +43 -0
  143. sagemaker/core/interactive_apps/__init__.py +41 -0
  144. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  145. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  146. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  147. sagemaker/core/iterators.py +186 -0
  148. sagemaker/core/job.py +380 -0
  149. sagemaker/core/jumpstart/__init__.py +156 -0
  150. sagemaker/core/jumpstart/accessors.py +390 -0
  151. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  152. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  153. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  154. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  155. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  156. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  157. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  158. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  159. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  160. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  161. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  162. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  163. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  164. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  165. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  166. sagemaker/core/jumpstart/cache.py +663 -0
  167. sagemaker/core/jumpstart/configs.py +50 -0
  168. sagemaker/core/jumpstart/constants.py +198 -0
  169. sagemaker/core/jumpstart/deserializers.py +81 -0
  170. sagemaker/core/jumpstart/document.py +76 -0
  171. sagemaker/core/jumpstart/enums.py +168 -0
  172. sagemaker/core/jumpstart/exceptions.py +236 -0
  173. sagemaker/core/jumpstart/factory/utils.py +833 -0
  174. sagemaker/core/jumpstart/filters.py +597 -0
  175. sagemaker/core/jumpstart/hub/__init__.py +0 -0
  176. sagemaker/core/jumpstart/hub/constants.py +16 -0
  177. sagemaker/core/jumpstart/hub/hub.py +291 -0
  178. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  179. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  180. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  181. sagemaker/core/jumpstart/hub/types.py +35 -0
  182. sagemaker/core/jumpstart/hub/utils.py +260 -0
  183. sagemaker/core/jumpstart/models.py +499 -0
  184. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  185. sagemaker/core/jumpstart/parameters.py +20 -0
  186. sagemaker/core/jumpstart/payload_utils.py +239 -0
  187. sagemaker/core/jumpstart/region_config.json +163 -0
  188. sagemaker/core/jumpstart/search.py +171 -0
  189. sagemaker/core/jumpstart/serializers.py +81 -0
  190. sagemaker/core/jumpstart/session_utils.py +234 -0
  191. sagemaker/core/jumpstart/types.py +3044 -0
  192. sagemaker/core/jumpstart/utils.py +1731 -0
  193. sagemaker/core/jumpstart/validators.py +257 -0
  194. sagemaker/core/lambda_helper.py +312 -0
  195. sagemaker/core/lineage/__init__.py +42 -0
  196. sagemaker/core/lineage/_api_types.py +239 -0
  197. sagemaker/core/lineage/_utils.py +49 -0
  198. sagemaker/core/lineage/action.py +345 -0
  199. sagemaker/core/lineage/artifact.py +646 -0
  200. sagemaker/core/lineage/association.py +190 -0
  201. sagemaker/core/lineage/context.py +505 -0
  202. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  203. sagemaker/core/lineage/query.py +732 -0
  204. sagemaker/core/lineage/visualizer.py +346 -0
  205. sagemaker/core/local/__init__.py +18 -0
  206. sagemaker/core/local/data.py +413 -0
  207. sagemaker/core/local/entities.py +678 -0
  208. sagemaker/core/local/exceptions.py +17 -0
  209. sagemaker/core/local/image.py +1243 -0
  210. sagemaker/core/local/local_session.py +739 -0
  211. sagemaker/core/local/utils.py +245 -0
  212. sagemaker/core/logs.py +181 -0
  213. sagemaker/core/metadata_properties.py +56 -0
  214. sagemaker/core/metric_definitions.py +91 -0
  215. sagemaker/core/mlflow/__init__.py +38 -0
  216. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  217. sagemaker/core/model_card/__init__.py +26 -0
  218. sagemaker/core/model_life_cycle.py +51 -0
  219. sagemaker/core/model_metrics.py +160 -0
  220. sagemaker/core/model_monitor/__init__.py +66 -0
  221. sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
  222. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  223. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  224. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  225. sagemaker/core/model_monitor/dataset_format.py +102 -0
  226. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  227. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  228. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  229. sagemaker/core/model_monitor/utils.py +793 -0
  230. sagemaker/core/model_registry.py +480 -0
  231. sagemaker/core/model_uris.py +97 -0
  232. sagemaker/core/modules/__init__.py +19 -0
  233. sagemaker/core/modules/configs.py +226 -0
  234. sagemaker/core/modules/constants.py +37 -0
  235. sagemaker/core/modules/distributed.py +182 -0
  236. sagemaker/core/modules/local_core/__init__.py +0 -0
  237. sagemaker/core/modules/local_core/local_container.py +605 -0
  238. sagemaker/core/modules/templates.py +83 -0
  239. sagemaker/core/modules/train/__init__.py +14 -0
  240. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  241. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  242. sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
  243. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  244. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  245. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  246. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  247. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  248. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  249. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  250. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  251. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  252. sagemaker/core/modules/types.py +19 -0
  253. sagemaker/core/modules/utils.py +194 -0
  254. sagemaker/core/network.py +185 -0
  255. sagemaker/core/parameter.py +173 -0
  256. sagemaker/core/payloads.py +185 -0
  257. sagemaker/core/processing.py +1597 -0
  258. sagemaker/core/remote_function/__init__.py +19 -0
  259. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  260. sagemaker/core/remote_function/client.py +1285 -0
  261. sagemaker/core/remote_function/core/__init__.py +0 -0
  262. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  263. sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
  264. sagemaker/core/remote_function/core/serialization.py +422 -0
  265. sagemaker/core/remote_function/core/stored_function.py +226 -0
  266. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  267. sagemaker/core/remote_function/errors.py +104 -0
  268. sagemaker/core/remote_function/invoke_function.py +172 -0
  269. sagemaker/core/remote_function/job.py +2140 -0
  270. sagemaker/core/remote_function/logging_config.py +38 -0
  271. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  272. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  273. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  274. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  275. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  276. sagemaker/core/remote_function/spark_config.py +149 -0
  277. sagemaker/core/resource_requirements.py +168 -0
  278. {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
  279. sagemaker/core/s3/__init__.py +41 -0
  280. sagemaker/core/s3/client.py +367 -0
  281. sagemaker/core/s3/utils.py +175 -0
  282. sagemaker/core/script_uris.py +93 -0
  283. sagemaker/core/serializers/__init__.py +11 -0
  284. sagemaker/core/serializers/base.py +510 -0
  285. sagemaker/core/serializers/implementations.py +159 -0
  286. sagemaker/core/serializers/utils.py +223 -0
  287. sagemaker/core/serverless_inference_config.py +63 -0
  288. sagemaker/core/session_settings.py +55 -0
  289. sagemaker/core/shapes/__init__.py +3 -0
  290. sagemaker/core/shapes/model_card_shapes.py +159 -0
  291. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
  292. sagemaker/core/spark/__init__.py +16 -0
  293. sagemaker/core/spark/defaults.py +16 -0
  294. sagemaker/core/spark/processing.py +1380 -0
  295. sagemaker/core/telemetry/__init__.py +23 -0
  296. sagemaker/core/telemetry/constants.py +84 -0
  297. sagemaker/core/telemetry/telemetry_logging.py +284 -0
  298. sagemaker/core/tools/__init__.py +1 -0
  299. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  300. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  301. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  302. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  303. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  304. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  305. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  306. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  307. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  308. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  309. sagemaker/core/training/__init__.py +14 -0
  310. sagemaker/core/training/configs.py +333 -0
  311. sagemaker/core/training/constants.py +37 -0
  312. sagemaker/core/training/utils.py +77 -0
  313. sagemaker/core/training_compiler/__init__.py +16 -0
  314. sagemaker/core/training_compiler/config.py +197 -0
  315. sagemaker/core/training_compiler_config.py +197 -0
  316. sagemaker/core/transformer.py +793 -0
  317. sagemaker/core/user_agent.py +76 -0
  318. sagemaker/core/utilities/__init__.py +24 -0
  319. sagemaker/core/utilities/cache.py +169 -0
  320. sagemaker/core/utilities/search_expression.py +133 -0
  321. sagemaker/core/utils/__init__.py +48 -0
  322. sagemaker/core/utils/code_injection/__init__.py +0 -0
  323. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  324. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
  325. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  326. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  327. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  328. {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
  329. sagemaker/core/workflow/__init__.py +152 -0
  330. sagemaker/core/workflow/conditions.py +313 -0
  331. sagemaker/core/workflow/entities.py +58 -0
  332. sagemaker/core/workflow/execution_variables.py +89 -0
  333. sagemaker/core/workflow/functions.py +193 -0
  334. sagemaker/core/workflow/parameters.py +222 -0
  335. sagemaker/core/workflow/pipeline_context.py +394 -0
  336. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  337. sagemaker/core/workflow/properties.py +285 -0
  338. sagemaker/core/workflow/step_outputs.py +65 -0
  339. sagemaker/core/workflow/utilities.py +507 -0
  340. sagemaker/lineage/__init__.py +33 -0
  341. sagemaker/lineage/action.py +28 -0
  342. sagemaker/lineage/artifact.py +28 -0
  343. sagemaker/lineage/context.py +28 -0
  344. sagemaker/lineage/lineage_trial_component.py +28 -0
  345. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
  346. sagemaker_core-2.1.1.dist-info/RECORD +355 -0
  347. sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
  348. sagemaker_core/__init__.py +0 -4
  349. sagemaker_core/_version.py +0 -3
  350. sagemaker_core/helper/session_helper.py +0 -769
  351. sagemaker_core/resources/__init__.py +0 -1
  352. sagemaker_core/shapes/__init__.py +0 -1
  353. sagemaker_core/tools/__init__.py +0 -1
  354. sagemaker_core-1.0.47.dist-info/RECORD +0 -35
  355. sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
  356. {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
  357. {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
  358. {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  359. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  360. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  361. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  362. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
  363. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,70 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ # pylint: skip-file
14
+ """This module contains utilities related to SageMaker JumpStart Hub."""
15
+ from __future__ import absolute_import
16
+
17
+ import re
18
from typing import Any, Dict, List, Optional, Sequence
19
+
20
+
21
def camel_to_snake(camel_case_string: str) -> str:
    """Convert a camelCase / UpperCamelCase string into snake_case.

    An underscore is inserted in front of every uppercase ASCII letter that is
    not at the very start of the string, and the result is lowercased.
    Consecutive capitals are each split (``"ABC"`` -> ``"a_b_c"``), and
    whitespace is not treated specially, so multi-word input such as
    ``"camelString TwoWords"`` is not handled sensibly.
    """
    uppercase_boundary = r"(?<!^)(?=[A-Z])"
    with_separators = re.sub(uppercase_boundary, "_", camel_case_string)
    return with_separators.lower()
27
+
28
+
29
def snake_to_upper_camel(snake_case_string: str) -> str:
    """Convert a snake_case string into UpperCamelCase.

    The string is split on underscores, each segment is passed through
    ``str.title``, and the segments are concatenated
    (e.g. ``"model_dir"`` -> ``"ModelDir"``).
    """
    segments = snake_case_string.split("_")
    return "".join(map(str.title, segments))
33
+
34
+
35
def walk_and_apply_json(
    json_obj: Dict[Any, Any],
    apply,
    stop_keys: Optional[Sequence[str]] = ("metrics",),
) -> Dict[Any, Any]:
    """Recursively walks a json object and applies a given function to its keys.

    Every key of every dict reachable from ``json_obj`` is replaced by
    ``apply(key)``. This includes dicts nested inside lists at any depth. Values
    themselves are never transformed. Non-container values are copied through
    unchanged, and that includes numbers, booleans and ``None`` inside lists.

    Args:
        json_obj (Dict[Any, Any]): The json object whose keys are rewritten.
            Input that is not a dict yields an empty dict.
        apply (Callable[[Any], Any]): Function applied to each key.
        stop_keys (Optional[Sequence[str]]): Keys at which recursion stops.
            Matching is done against the key *after* ``apply`` has been applied.
            The matching key itself is still transformed, but its value is
            copied as-is, so none of its children have ``apply`` applied.
            ``None`` or an empty sequence means recursion never stops.
            Defaults to ``("metrics",)``.

    Returns:
        Dict[Any, Any]: A new dict with transformed keys. Every dict and list
        that is walked is rebuilt. Values under stop keys and scalar values are
        shared with ``json_obj``, which is itself not modified.
    """

    def _descend(new_key: Any) -> bool:
        # An empty or None ``stop_keys`` means "never stop". A plain truthiness
        # test would otherwise treat ``[]`` as "stop at every key".
        return not stop_keys or new_key not in stop_keys

    def _walk_value(value: Any) -> Any:
        if isinstance(value, dict):
            return _walk_dict(value)
        if isinstance(value, list):
            # Recurse into every item. This handles nested lists and keeps
            # scalar items (ints, floats, bools, None) instead of dropping them.
            return [_walk_value(item) for item in value]
        return value

    def _walk_dict(obj: Dict[Any, Any]) -> Dict[Any, Any]:
        new: Dict[Any, Any] = {}
        for key, value in obj.items():
            new_key = apply(key)
            new[new_key] = _walk_value(value) if _descend(new_key) else value
        return new

    if not isinstance(json_obj, dict):
        # Only dicts have keys to rewrite; anything else (e.g. None) maps to {}.
        return {}
    return _walk_dict(json_obj)
@@ -0,0 +1,288 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ # pylint: skip-file
14
+ """This module stores Hub converter utilities for JumpStart."""
15
+ from __future__ import absolute_import
16
+
17
+ from typing import Any, Dict, List
18
+ from sagemaker.core.jumpstart.enums import ModelSpecKwargType, NamingConventionType
19
+ from sagemaker.core.s3 import parse_s3_url
20
+ from sagemaker.core.jumpstart.types import (
21
+ JumpStartModelSpecs,
22
+ HubContentType,
23
+ JumpStartDataHolderType,
24
+ )
25
+ from sagemaker.core.jumpstart.hub.interfaces import (
26
+ DescribeHubContentResponse,
27
+ HubModelDocument,
28
+ )
29
+ from sagemaker.core.jumpstart.hub.parser_utils import (
30
+ camel_to_snake,
31
+ snake_to_upper_camel,
32
+ walk_and_apply_json,
33
+ )
34
+
35
+
36
def _to_json(dictionary: Dict[Any, Any]) -> Dict[Any, Any]:
    """Convert JumpStartDataHolderType values of a dict into UpperCamelCase json.

    The conversion is done in place, and ``dictionary`` itself is returned.
    Conversion is applied at these positions:

    * top-level values that are data holders;
    * data holders that are items of a top-level list value (the list is
      replaced by a new list);
    * data holders that are values of a top-level dict value (that inner dict
      is mutated in place).

    Each converted holder becomes ``walk_and_apply_json(holder.to_json(),
    snake_to_upper_camel)``. Nothing deeper than these positions is inspected.
    """

    def _convert(item: Any) -> Any:
        # Data holders are serialized and re-keyed; everything else passes through.
        if issubclass(type(item), JumpStartDataHolderType):
            return walk_and_apply_json(item.to_json(), snake_to_upper_camel)
        return item

    for key, value in dictionary.items():
        if issubclass(type(value), JumpStartDataHolderType):
            dictionary[key] = _convert(value)
        elif isinstance(value, list):
            dictionary[key] = [_convert(element) for element in value]
        elif isinstance(value, dict):
            for inner_key, inner_value in value.items():
                if issubclass(type(inner_value), JumpStartDataHolderType):
                    value[inner_key] = _convert(inner_value)
    return dictionary
58
+
59
+
60
def get_model_spec_arg_keys(
    arg_type: ModelSpecKwargType,
    naming_convention: NamingConventionType = NamingConventionType.DEFAULT,
) -> List[str]:
    """Return the hub-document kwarg keys registered for a model spec kwarg type.

    Args:
        arg_type (ModelSpecKwargType): Type of the model spec's kwarg.
            ``MODEL``, ``FIT`` and any unrecognised value yield an empty list.
        naming_convention (NamingConventionType): Naming convention of the
            returned keys. ``SNAKE_CASE`` converts each key with
            ``camel_to_snake``. ``UPPER_CAMEL_CASE`` returns the keys as written.
            NOTE(review): the default ``NamingConventionType.DEFAULT`` is only
            accepted if it compares equal to one of those two members; confirm
            this in the enum definition.

    Returns:
        List[str]: A freshly built list of keys.

    Raises:
        ValueError: If the naming convention is not valid.
    """
    # Rebuilt on each call so callers always receive lists they may mutate.
    keys_by_type = (
        (
            ModelSpecKwargType.DEPLOY,
            [
                "ModelDataDownloadTimeout",
                "ContainerStartupHealthCheckTimeout",
                "InferenceAmiVersion",
            ],
        ),
        (
            ModelSpecKwargType.ESTIMATOR,
            [
                "EncryptInterContainerTraffic",
                "MaxRuntimeInSeconds",
                "DisableOutputCompression",
                "ModelDir",
            ],
        ),
        (ModelSpecKwargType.MODEL, []),
        (ModelSpecKwargType.FIT, []),
    )
    arg_keys: List[str] = next(
        (keys for kind, keys in keys_by_type if arg_type == kind), []
    )

    if naming_convention == NamingConventionType.SNAKE_CASE:
        return [camel_to_snake(key) for key in arg_keys]
    if naming_convention == NamingConventionType.UPPER_CAMEL_CASE:
        return arg_keys
    raise ValueError("Please provide a valid naming convention.")
99
+
100
+
101
def get_model_spec_kwargs_from_hub_model_document(
    arg_type: ModelSpecKwargType,
    hub_content_document: Dict[str, Any],
    naming_convention: NamingConventionType = NamingConventionType.UPPER_CAMEL_CASE,
) -> Dict[str, Any]:
    """Extract the kwargs of one model spec kwarg type from a hub content document.

    Args:
        arg_type (ModelSpecKwargType): Type of the model spec's kwarg.
        hub_content_document (Dict[str, Any]): A dictionary representation of a
            hub content document.
        naming_convention (NamingConventionType): Naming convention of the keys
            to look up (and of the keys in the returned dict).

    Returns:
        Dict[str, Any]: A mapping from each key returned by
        ``get_model_spec_arg_keys(arg_type, naming_convention=...)`` to its
        value in ``hub_content_document``. Keys that are absent, or whose value
        is ``None``, are omitted.

    Raises:
        ValueError: Propagated from ``get_model_spec_arg_keys`` for an invalid
            naming convention.
    """
    wanted_keys = get_model_spec_arg_keys(arg_type, naming_convention=naming_convention)
    candidate_pairs = ((key, hub_content_document.get(key)) for key in wanted_keys)
    return {key: value for key, value in candidate_pairs if value is not None}
121
+
122
+
123
+ def make_model_specs_from_describe_hub_content_response(
124
+ response: DescribeHubContentResponse,
125
+ ) -> JumpStartModelSpecs:
126
+ """Sets fields in JumpStartModelSpecs based on values in DescribeHubContentResponse
127
+
128
+ Args:
129
+ response (Dict[str, any]): parsed DescribeHubContentResponse returned
130
+ from SageMaker:DescribeHubContent
131
+ """
132
+ if response.hub_content_type not in {HubContentType.MODEL, HubContentType.MODEL_REFERENCE}:
133
+ raise AttributeError(
134
+ "Invalid content type, use either HubContentType.MODEL or HubContentType.MODEL_REFERENCE."
135
+ )
136
+ region = response.get_hub_region()
137
+ specs = {}
138
+ model_id = response.hub_content_name
139
+ specs["model_id"] = model_id
140
+ specs["version"] = response.hub_content_version
141
+ hub_model_document: HubModelDocument = response.hub_content_document
142
+ specs["url"] = hub_model_document.url
143
+ specs["min_sdk_version"] = hub_model_document.min_sdk_version
144
+ specs["model_types"] = hub_model_document.model_types
145
+ specs["capabilities"] = hub_model_document.capabilities
146
+ specs["training_supported"] = bool(hub_model_document.training_supported)
147
+ specs["incremental_training_supported"] = bool(
148
+ hub_model_document.incremental_training_supported
149
+ )
150
+ specs["hosting_ecr_uri"] = hub_model_document.hosting_ecr_uri
151
+ specs["inference_configs"] = hub_model_document.inference_configs
152
+ specs["inference_config_components"] = hub_model_document.inference_config_components
153
+ specs["inference_config_rankings"] = hub_model_document.inference_config_rankings
154
+
155
+ if hub_model_document.hosting_artifact_uri:
156
+ _, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable
157
+ hub_model_document.hosting_artifact_uri
158
+ )
159
+ specs["hosting_artifact_key"] = hosting_artifact_key
160
+ specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri
161
+
162
+ if hub_model_document.hosting_script_uri:
163
+ _, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable
164
+ hub_model_document.hosting_script_uri
165
+ )
166
+ specs["hosting_script_key"] = hosting_script_key
167
+
168
+ specs["inference_environment_variables"] = hub_model_document.inference_environment_variables
169
+ specs["inference_vulnerable"] = False
170
+ specs["inference_dependencies"] = hub_model_document.inference_dependencies
171
+ specs["inference_vulnerabilities"] = []
172
+ specs["training_vulnerable"] = False
173
+ specs["training_vulnerabilities"] = []
174
+ specs["deprecated"] = False
175
+ specs["deprecated_message"] = None
176
+ specs["deprecate_warn_message"] = None
177
+ specs["usage_info_message"] = None
178
+ specs["default_inference_instance_type"] = hub_model_document.default_inference_instance_type
179
+ specs["supported_inference_instance_types"] = (
180
+ hub_model_document.supported_inference_instance_types
181
+ )
182
+ specs["dynamic_container_deployment_supported"] = (
183
+ hub_model_document.dynamic_container_deployment_supported
184
+ )
185
+ specs["hosting_resource_requirements"] = hub_model_document.hosting_resource_requirements
186
+
187
+ specs["hosting_prepacked_artifact_key"] = None
188
+ if hub_model_document.hosting_prepacked_artifact_uri is not None:
189
+ (
190
+ hosting_prepacked_artifact_bucket, # pylint: disable=unused-variable
191
+ hosting_prepacked_artifact_key,
192
+ ) = parse_s3_url(hub_model_document.hosting_prepacked_artifact_uri)
193
+ specs["hosting_prepacked_artifact_key"] = hosting_prepacked_artifact_key
194
+
195
+ hub_content_document_dict: Dict[str, Any] = hub_model_document.to_json()
196
+
197
+ specs["fit_kwargs"] = get_model_spec_kwargs_from_hub_model_document(
198
+ ModelSpecKwargType.FIT, hub_content_document_dict
199
+ )
200
+ specs["model_kwargs"] = get_model_spec_kwargs_from_hub_model_document(
201
+ ModelSpecKwargType.MODEL, hub_content_document_dict
202
+ )
203
+ specs["deploy_kwargs"] = get_model_spec_kwargs_from_hub_model_document(
204
+ ModelSpecKwargType.DEPLOY, hub_content_document_dict
205
+ )
206
+ specs["estimator_kwargs"] = get_model_spec_kwargs_from_hub_model_document(
207
+ ModelSpecKwargType.ESTIMATOR, hub_content_document_dict
208
+ )
209
+
210
+ specs["predictor_specs"] = hub_model_document.sage_maker_sdk_predictor_specifications
211
+ default_payloads: Dict[str, Any] = {}
212
+ if hub_model_document.default_payloads is not None:
213
+ for alias, payload in hub_model_document.default_payloads.items():
214
+ default_payloads[alias] = walk_and_apply_json(payload.to_json(), camel_to_snake)
215
+ specs["default_payloads"] = default_payloads
216
+ specs["gated_bucket"] = hub_model_document.gated_bucket
217
+ specs["inference_volume_size"] = hub_model_document.inference_volume_size
218
+ specs["inference_enable_network_isolation"] = (
219
+ hub_model_document.inference_enable_network_isolation
220
+ )
221
+ specs["resource_name_base"] = hub_model_document.resource_name_base
222
+
223
+ specs["hosting_eula_key"] = None
224
+ if hub_model_document.hosting_eula_uri is not None:
225
+ hosting_eula_bucket, hosting_eula_key = parse_s3_url( # pylint: disable=unused-variable
226
+ hub_model_document.hosting_eula_uri
227
+ )
228
+ specs["hosting_eula_key"] = hosting_eula_key
229
+
230
+ if hub_model_document.hosting_model_package_arn:
231
+ specs["hosting_model_package_arns"] = {region: hub_model_document.hosting_model_package_arn}
232
+
233
+ specs["model_subscription_link"] = hub_model_document.model_subscription_link
234
+
235
+ specs["hosting_use_script_uri"] = hub_model_document.hosting_use_script_uri
236
+
237
+ specs["hosting_instance_type_variants"] = hub_model_document.hosting_instance_type_variants
238
+
239
+ if specs["training_supported"]:
240
+ specs["training_ecr_uri"] = hub_model_document.training_ecr_uri
241
+ (
242
+ training_artifact_bucket, # pylint: disable=unused-variable
243
+ training_artifact_key,
244
+ ) = parse_s3_url(hub_model_document.training_artifact_uri)
245
+ specs["training_artifact_key"] = training_artifact_key
246
+ (
247
+ training_script_bucket, # pylint: disable=unused-variable
248
+ training_script_key,
249
+ ) = parse_s3_url(hub_model_document.training_script_uri)
250
+ specs["training_script_key"] = training_script_key
251
+
252
+ specs["training_configs"] = hub_model_document.training_configs
253
+ specs["training_config_components"] = hub_model_document.training_config_components
254
+ specs["training_config_rankings"] = hub_model_document.training_config_rankings
255
+
256
+ specs["training_dependencies"] = hub_model_document.training_dependencies
257
+ specs["default_training_instance_type"] = hub_model_document.default_training_instance_type
258
+ specs["supported_training_instance_types"] = (
259
+ hub_model_document.supported_training_instance_types
260
+ )
261
+ specs["metrics"] = hub_model_document.training_metrics
262
+ specs["training_prepacked_script_key"] = None
263
+ if hub_model_document.training_prepacked_script_uri is not None:
264
+ (
265
+ training_prepacked_script_bucket, # pylint: disable=unused-variable
266
+ training_prepacked_script_key,
267
+ ) = parse_s3_url(hub_model_document.training_prepacked_script_uri)
268
+ specs["training_prepacked_script_key"] = training_prepacked_script_key
269
+
270
+ specs["hyperparameters"] = hub_model_document.hyperparameters
271
+ specs["training_volume_size"] = hub_model_document.training_volume_size
272
+ specs["training_enable_network_isolation"] = (
273
+ hub_model_document.training_enable_network_isolation
274
+ )
275
+ if hub_model_document.training_model_package_artifact_uri:
276
+ specs["training_model_package_artifact_uris"] = {
277
+ region: hub_model_document.training_model_package_artifact_uri
278
+ }
279
+ specs["training_instance_type_variants"] = (
280
+ hub_model_document.training_instance_type_variants
281
+ )
282
+ if hub_model_document.default_training_dataset_uri:
283
+ _, default_training_dataset_key = parse_s3_url( # pylint: disable=unused-variable
284
+ hub_model_document.default_training_dataset_uri
285
+ )
286
+ specs["default_training_dataset_key"] = default_training_dataset_key
287
+ specs["default_training_dataset_uri"] = hub_model_document.default_training_dataset_uri
288
+ return JumpStartModelSpecs(_to_json(specs), is_hub_content=True)
@@ -0,0 +1,35 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module stores types related to SageMaker JumpStart Hub."""
14
+ from __future__ import absolute_import
15
+ from typing import Dict
16
+ from dataclasses import dataclass
17
+
18
+
19
@dataclass
class S3ObjectLocation:
    """A bucket/key pair that identifies one object in Amazon S3."""

    bucket: str
    key: str

    def format_for_s3_copy(self) -> Dict[str, str]:
        """Build the ``{"Bucket": ..., "Key": ...}`` mapping expected by S3 copy calls."""
        copy_source = dict(Bucket=self.bucket, Key=self.key)
        return copy_source

    def get_uri(self) -> str:
        """Render this location as an ``s3://<bucket>/<key>`` URI."""
        return f"s3://{self.bucket}/{self.key}"
@@ -0,0 +1,260 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ # pylint: skip-file
14
+ """This module contains utilities related to SageMaker JumpStart Hub."""
15
+ from __future__ import absolute_import
16
+ import re
17
+ from typing import Optional, List, Any
18
+ from sagemaker.core.helper.session_helper import Session
19
+ from sagemaker.core.common_utils import aws_partition
20
+ from sagemaker.core.jumpstart.types import HubContentType, HubArnExtractedInfo
21
+ from sagemaker.core.jumpstart import constants
22
+ from packaging.specifiers import SpecifierSet, InvalidSpecifier
23
+ from packaging import version
24
+
25
+ PROPRIETARY_VERSION_KEYWORD = "@marketplace-version:"
26
+
27
+
28
+ def _convert_str_to_optional(string: str) -> Optional[str]:
29
+ if string == "None":
30
+ string = None
31
+ return string
32
+
33
+
34
def get_info_from_hub_resource_arn(
    arn: str,
) -> Optional[HubArnExtractedInfo]:
    """Extracts descriptive information from a Hub or HubContent ARN.

    The ARN is first matched against ``constants.HUB_CONTENT_ARN_REGEX``. If that
    fails, it is matched against ``constants.HUB_ARN_REGEX``.

    Args:
        arn (str): A Hub ARN or a HubContent ARN.

    Returns:
        Optional[HubArnExtractedInfo]: The parsed ARN components. For a Hub ARN only
        partition, region, account_id and hub_name are passed. Returns ``None`` when
        ``arn`` matches neither pattern.
    """
    match = re.match(constants.HUB_CONTENT_ARN_REGEX, arn)
    if match:
        return HubArnExtractedInfo(
            partition=match.group(1),
            region=match.group(2),
            account_id=match.group(3),
            hub_name=match.group(4),
            hub_content_type=match.group(5),
            hub_content_name=match.group(6),
            # The version segment may be the literal string "None".
            hub_content_version=_convert_str_to_optional(match.group(7)),
        )

    match = re.match(constants.HUB_ARN_REGEX, arn)
    if match:
        return HubArnExtractedInfo(
            partition=match.group(1),
            region=match.group(2),
            account_id=match.group(3),
            hub_name=match.group(4),
        )

    # Neither a HubContent ARN nor a Hub ARN.
    return None
71
+
72
+
73
def construct_hub_arn_from_name(
    hub_name: str,
    region: Optional[str] = None,
    session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    account_id: Optional[str] = None,
) -> str:
    """Builds a Hub ARN for ``hub_name``.

    Args:
        hub_name (str): Name of the Hub.
        region (Optional[str]): Region for the ARN; falls back to ``session.boto_region_name``.
        session (Optional[Session]): Session used for defaults. ``None`` is replaced by
            ``constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION``.
        account_id (Optional[str]): Account id for the ARN; falls back to ``session.account_id()``.

    Returns:
        str: ``arn:<partition>:sagemaker:<region>:<account_id>:hub/<hub_name>``.
    """
    # Some callers explicitly pass session=None, so the default argument is not enough.
    active_session = (
        session if session is not None else constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION
    )

    resolved_account_id = account_id or active_session.account_id()
    resolved_region = region or active_session.boto_region_name
    partition = aws_partition(resolved_region)

    return f"arn:{partition}:sagemaker:{resolved_region}:{resolved_account_id}:hub/{hub_name}"
89
+
90
+
91
def construct_hub_model_arn_from_inputs(hub_arn: str, model_name: str, version: str) -> str:
    """Constructs a HubContent model ARN from the Hub ARN, model name, and model version.

    Args:
        hub_arn (str): ARN of the Hub that holds the model.
        model_name (str): Name of the model.
        version (str): Version of the model.

    Returns:
        str: ``arn:<partition>:sagemaker:<region>:<account_id>:hub-content/<hub_name>/``
        ``<HubContentType.MODEL.value>/<model_name>/<version>``.
    """
    hub_info = get_info_from_hub_resource_arn(hub_arn)
    content_type = HubContentType.MODEL.value
    return (
        f"arn:{hub_info.partition}:sagemaker:{hub_info.region}:{hub_info.account_id}:hub-content/"
        f"{hub_info.hub_name}/{content_type}/{model_name}/{version}"
    )
101
+
102
+
103
def construct_hub_model_reference_arn_from_inputs(
    hub_arn: str, model_name: str, version: str
) -> str:
    """Constructs a HubContent model *reference* ARN from the Hub ARN, model name, and version.

    Args:
        hub_arn (str): ARN of the Hub that holds the model reference.
        model_name (str): Name of the referenced model.
        version (str): Version of the referenced model.

    Returns:
        str: ``arn:<partition>:sagemaker:<region>:<account_id>:hub-content/<hub_name>/``
        ``<HubContentType.MODEL_REFERENCE.value>/<model_name>/<version>``.
    """
    info = get_info_from_hub_resource_arn(hub_arn)
    arn = (
        f"arn:{info.partition}:sagemaker:{info.region}:{info.account_id}:hub-content/"
        f"{info.hub_name}/{HubContentType.MODEL_REFERENCE.value}/{model_name}/{version}"
    )

    return arn
115
+
116
+
117
def generate_hub_arn_for_init_kwargs(
    hub_name: str, region: Optional[str] = None, session: Optional[Session] = None
):
    """Resolves the Hub ARN for JumpStart class arguments given a Hub name or Hub ARN.

    Args:
        hub_name (str): Hub name or Hub ARN taken from the JumpStart class arguments.
        region (str): Region taken from the JumpStart class arguments.
        session (Session): Custom SageMaker Session taken from the JumpStart class arguments.

    Returns:
        ``None`` when ``hub_name`` is empty or equals ``constants.JUMPSTART_MODEL_HUB_NAME``.
        ``hub_name`` itself when it already matches ``constants.HUB_ARN_REGEX``.
        Otherwise an ARN built by ``construct_hub_arn_from_name``.
    """
    if not hub_name or hub_name == constants.JUMPSTART_MODEL_HUB_NAME:
        return None

    if re.match(constants.HUB_ARN_REGEX, hub_name):
        # Already an ARN; use it verbatim.
        return hub_name

    return construct_hub_arn_from_name(hub_name=hub_name, region=region, session=session)
138
+
139
+
140
def is_gated_bucket(bucket_name: str) -> bool:
    """Tell whether ``bucket_name`` is in ``constants.JUMPSTART_GATED_BUCKET_NAME_SET``."""
    gated_bucket_names = constants.JUMPSTART_GATED_BUCKET_NAME_SET
    return bucket_name in gated_bucket_names
143
+
144
+
145
def get_hub_model_version(
    hub_name: str,
    hub_model_name: str,
    hub_model_type: str,
    hub_model_version: Optional[str] = None,
    sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
    """Returns available Jumpstart hub model version.

    It will attempt both a semantic HubContent version search and Marketplace version search.
    If the Marketplace version is also semantic, this function will default to HubContent version.

    Args:
        hub_name (str): Name of the Hub to search.
        hub_model_name (str): HubContent name of the model.
        hub_model_type (str): HubContent type of the model.
        hub_model_version (Optional[str]): Requested version. ``None`` or ``"*"`` selects
            the highest available HubContent version.
        sagemaker_session (Session): Session used to list versions. ``None`` is replaced by
            ``constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION``.

    Returns:
        str: The resolved HubContent version.

    Raises:
        Exception: If listing the HubContent versions fails (for example because the
            model is not found in the hub). The underlying error (such as a
            ``ClientError``) is chained as ``__cause__``.
        KeyError: If the specified model version is not found.
    """
    if sagemaker_session is None:
        # sagemaker_session is overridden to none by some callers
        sagemaker_session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION

    try:
        hub_content_summaries = _list_hub_content_versions_helper(
            hub_name=hub_name,
            hub_content_name=hub_model_name,
            hub_content_type=hub_model_type,
            sagemaker_session=sagemaker_session,
        )
    except Exception as ex:
        # Chain the original error so its type and traceback stay inspectable.
        raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}") from ex

    try:
        return _get_hub_model_version_for_open_weight_version(
            hub_content_summaries, hub_model_version
        )
    except KeyError:
        # Not a usable semantic version; try it as a Marketplace version instead.
        marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version(
            hub_content_summaries, hub_model_version
        )
        if marketplace_hub_content_version:
            return marketplace_hub_content_version
        raise
186
+
187
+
188
+ def _list_hub_content_versions_helper(
189
+ hub_name, hub_content_name, hub_content_type, sagemaker_session
190
+ ):
191
+ all_hub_content_summaries = []
192
+ list_hub_content_versions_response = sagemaker_session.list_hub_content_versions(
193
+ hub_name=hub_name, hub_content_name=hub_content_name, hub_content_type=hub_content_type
194
+ )
195
+ all_hub_content_summaries.extend(list_hub_content_versions_response.get("HubContentSummaries"))
196
+ while "NextToken" in list_hub_content_versions_response:
197
+ list_hub_content_versions_response = sagemaker_session.list_hub_content_versions(
198
+ hub_name=hub_name,
199
+ hub_content_name=hub_content_name,
200
+ hub_content_type=hub_content_type,
201
+ next_token=list_hub_content_versions_response["NextToken"],
202
+ )
203
+ all_hub_content_summaries.extend(
204
+ list_hub_content_versions_response.get("HubContentSummaries")
205
+ )
206
+ return all_hub_content_summaries
207
+
208
+
209
def _get_hub_model_version_for_open_weight_version(
    hub_content_summaries: List[Any], hub_model_version: Optional[str] = None
) -> str:
    """Resolves ``hub_model_version`` against the HubContent versions listed in the Hub.

    Args:
        hub_content_summaries (List[Any]): HubContent summaries. Each one's version is read
            via ``.get("HubContentVersion")``.
        hub_model_version (Optional[str]): ``None`` or ``"*"`` selects the highest available
            version. Any other value is used as the operand of an ``==`` PEP 440 specifier,
            so wildcards such as ``"1.*"`` are accepted.

    Returns:
        str: The highest version that satisfies the request.

    Raises:
        KeyError: If no versions are listed, ``hub_model_version`` is not a valid specifier,
            or no listed version matches it.
    """
    available_model_versions = [model.get("HubContentVersion") for model in hub_content_summaries]
    if not available_model_versions:
        # max() of an empty sequence would raise ValueError; report it as "not found".
        raise KeyError("Model version not available in the Hub")

    if hub_model_version == "*" or hub_model_version is None:
        return str(max(version.parse(v) for v in available_model_versions))

    try:
        spec = SpecifierSet(f"=={hub_model_version}")
    except InvalidSpecifier as err:
        raise KeyError(f"Bad semantic version: {hub_model_version}") from err

    available_versions_filtered = list(spec.filter(available_model_versions))
    if not available_versions_filtered:
        raise KeyError("Model version not available in the Hub")

    # Compare as versions, not strings: lexicographically "1.9.0" > "1.10.0".
    return str(max(available_versions_filtered, key=version.parse))
227
+
228
+
229
def _get_hub_model_version_for_marketplace_version(
    hub_content_summaries: List[Any], marketplace_version: str
) -> Optional[str]:
    """Find the HubContent version tagged with the given Marketplace version.

    Summaries are scanned in order. The ``"HubContentVersion"`` of the first summary whose
    ``"HubContentSearchKeywords"`` (default ``[]``) carry the proprietary version keyword for
    ``marketplace_version`` is returned. If none match, ``None`` is returned.
    """
    return next(
        (
            summary.get("HubContentVersion")
            for summary in hub_content_summaries
            if _hub_search_keywords_contains_marketplace_version(
                summary.get("HubContentSearchKeywords", []), marketplace_version
            )
        ),
        None,
    )
244
+
245
+
246
def _hub_search_keywords_contains_marketplace_version(
    model_search_keywords: List[str], marketplace_version: str
) -> bool:
    """Checks whether the search keywords tag ``marketplace_version``.

    Only the first keyword that starts with ``PROPRIETARY_VERSION_KEYWORD`` is considered.
    The text after that prefix must equal ``marketplace_version`` exactly.

    Args:
        model_search_keywords (List[str]): The HubContentSearchKeywords of one HubContent.
        marketplace_version (str): The Marketplace version to look for.

    Returns:
        bool: True if the proprietary version keyword carries ``marketplace_version``.
    """
    proprietary_version_keyword = next(
        (kw for kw in model_search_keywords if kw.startswith(PROPRIETARY_VERSION_KEYWORD)),
        None,
    )

    if not proprietary_version_keyword:
        return False

    # Slice off the exact prefix. str.lstrip(prefix) would strip any leading characters that
    # appear anywhere in the prefix, e.g. "@marketplace-version:alpha" -> "ha".
    proprietary_version = proprietary_version_keyword[len(PROPRIETARY_VERSION_KEYWORD):]
    return proprietary_version == marketplace_version