sagemaker-core 1.0.62__py3-none-any.whl → 2.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (358)
  1. sagemaker/__init__.py +2 -0
  2. sagemaker/core/__init__.py +16 -0
  3. sagemaker/core/_studio.py +116 -0
  4. sagemaker/core/_version.py +11 -0
  5. sagemaker/core/accept_types.py +131 -0
  6. sagemaker/core/analytics.py +744 -0
  7. sagemaker/core/apiutils/__init__.py +13 -0
  8. sagemaker/core/apiutils/_base_types.py +228 -0
  9. sagemaker/core/apiutils/_boto_functions.py +130 -0
  10. sagemaker/core/apiutils/_utils.py +34 -0
  11. sagemaker/core/base_deserializers.py +35 -0
  12. sagemaker/core/base_serializers.py +35 -0
  13. sagemaker/core/clarify/__init__.py +2898 -0
  14. sagemaker/core/collection.py +467 -0
  15. sagemaker/core/common_utils.py +2399 -0
  16. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  17. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  18. sagemaker/core/config/__init__.py +181 -0
  19. sagemaker/core/config/config.py +238 -0
  20. sagemaker/core/config/config_manager.py +595 -0
  21. sagemaker/core/config/config_schema.py +1220 -0
  22. sagemaker/core/config/config_utils.py +297 -0
  23. {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
  24. sagemaker/core/constants.py +73 -0
  25. sagemaker/core/content_types.py +137 -0
  26. sagemaker/core/debugger/__init__.py +39 -0
  27. sagemaker/core/debugger/debugger.py +945 -0
  28. sagemaker/core/debugger/framework_profile.py +292 -0
  29. sagemaker/core/debugger/metrics_config.py +468 -0
  30. sagemaker/core/debugger/profiler.py +42 -0
  31. sagemaker/core/debugger/profiler_config.py +190 -0
  32. sagemaker/core/debugger/profiler_constants.py +40 -0
  33. sagemaker/core/debugger/utils.py +148 -0
  34. sagemaker/core/deprecations.py +254 -0
  35. sagemaker/core/deserializers/__init__.py +10 -0
  36. sagemaker/core/deserializers/base.py +424 -0
  37. sagemaker/core/deserializers/implementations.py +157 -0
  38. sagemaker/core/drift_check_baselines.py +106 -0
  39. sagemaker/core/enums.py +51 -0
  40. sagemaker/core/environment_variables.py +101 -0
  41. sagemaker/core/exceptions.py +108 -0
  42. sagemaker/core/experiments/__init__.py +53 -0
  43. sagemaker/core/experiments/_api_types.py +251 -0
  44. sagemaker/core/experiments/_environment.py +124 -0
  45. sagemaker/core/experiments/_helper.py +294 -0
  46. sagemaker/core/experiments/_metrics.py +333 -0
  47. sagemaker/core/experiments/_run_context.py +58 -0
  48. sagemaker/core/experiments/_utils.py +216 -0
  49. sagemaker/core/experiments/experiment.py +247 -0
  50. sagemaker/core/experiments/run.py +970 -0
  51. sagemaker/core/experiments/trial.py +296 -0
  52. sagemaker/core/experiments/trial_component.py +387 -0
  53. sagemaker/core/explainer/__init__.py +24 -0
  54. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  55. sagemaker/core/explainer/explainer_config.py +44 -0
  56. sagemaker/core/fw_utils.py +1220 -0
  57. sagemaker/core/git_utils.py +415 -0
  58. sagemaker/core/helper/pipeline_variable.py +82 -0
  59. sagemaker/core/helper/session_helper.py +2977 -0
  60. sagemaker/core/hyperparameters.py +172 -0
  61. sagemaker/core/image_retriever/__init__.py +3 -0
  62. sagemaker/core/image_retriever/image_retriever.py +640 -0
  63. sagemaker/core/image_retriever/image_retriever_utils.py +509 -0
  64. sagemaker/core/image_retriever/test.py +7 -0
  65. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  66. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  67. sagemaker/core/image_uri_config/chainer.json +104 -0
  68. sagemaker/core/image_uri_config/clarify.json +39 -0
  69. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  70. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  71. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  72. sagemaker/core/image_uri_config/debugger.json +34 -0
  73. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  74. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  75. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  76. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  77. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  78. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  79. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  80. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  81. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +770 -0
  82. sagemaker/core/image_uri_config/huggingface-llm.json +1267 -0
  83. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  84. sagemaker/core/image_uri_config/huggingface-neuronx.json +686 -0
  85. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  86. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  87. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  88. sagemaker/core/image_uri_config/huggingface-vllm-neuronx.json +38 -0
  89. sagemaker/core/image_uri_config/huggingface.json +2287 -0
  90. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  91. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  92. sagemaker/core/image_uri_config/image-classification.json +50 -0
  93. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  94. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  95. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  96. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  97. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  98. sagemaker/core/image_uri_config/kmeans.json +50 -0
  99. sagemaker/core/image_uri_config/knn.json +50 -0
  100. sagemaker/core/image_uri_config/lda.json +26 -0
  101. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  102. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  103. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  104. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  105. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  106. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  107. sagemaker/core/image_uri_config/ntm.json +50 -0
  108. sagemaker/core/image_uri_config/object-detection.json +50 -0
  109. sagemaker/core/image_uri_config/object2vec.json +50 -0
  110. sagemaker/core/image_uri_config/pca.json +50 -0
  111. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  112. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  113. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  114. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  115. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  116. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  117. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  118. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  119. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  120. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  121. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +252 -0
  122. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  123. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  124. sagemaker/core/image_uri_config/sklearn.json +494 -0
  125. sagemaker/core/image_uri_config/spark.json +280 -0
  126. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  127. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  128. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  129. sagemaker/core/image_uri_config/vw.json +25 -0
  130. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  131. sagemaker/core/image_uri_config/xgboost.json +972 -0
  132. sagemaker/core/image_uris.py +816 -0
  133. sagemaker/core/inference_config.py +144 -0
  134. sagemaker/core/inference_recommender/__init__.py +18 -0
  135. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  136. sagemaker/core/inputs.py +366 -0
  137. sagemaker/core/instance_group.py +61 -0
  138. sagemaker/core/instance_types.py +164 -0
  139. sagemaker/core/instance_types_gpu_info.py +43 -0
  140. sagemaker/core/interactive_apps/__init__.py +41 -0
  141. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  142. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  143. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  144. sagemaker/core/iterators.py +197 -0
  145. sagemaker/core/job.py +380 -0
  146. sagemaker/core/jumpstart/__init__.py +156 -0
  147. sagemaker/core/jumpstart/accessors.py +390 -0
  148. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  149. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  150. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  151. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  152. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  153. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  154. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  155. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  156. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  157. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  158. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  159. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  160. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  161. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  162. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  163. sagemaker/core/jumpstart/cache.py +663 -0
  164. sagemaker/core/jumpstart/configs.py +50 -0
  165. sagemaker/core/jumpstart/constants.py +198 -0
  166. sagemaker/core/jumpstart/deserializers.py +81 -0
  167. sagemaker/core/jumpstart/document.py +76 -0
  168. sagemaker/core/jumpstart/enums.py +168 -0
  169. sagemaker/core/jumpstart/exceptions.py +236 -0
  170. sagemaker/core/jumpstart/factory/utils.py +833 -0
  171. sagemaker/core/jumpstart/filters.py +597 -0
  172. sagemaker/core/jumpstart/hub/constants.py +16 -0
  173. sagemaker/core/jumpstart/hub/hub.py +291 -0
  174. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  175. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  176. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  177. sagemaker/core/jumpstart/hub/types.py +35 -0
  178. sagemaker/core/jumpstart/hub/utils.py +260 -0
  179. sagemaker/core/jumpstart/models.py +501 -0
  180. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  181. sagemaker/core/jumpstart/parameters.py +20 -0
  182. sagemaker/core/jumpstart/payload_utils.py +239 -0
  183. sagemaker/core/jumpstart/region_config.json +171 -0
  184. sagemaker/core/jumpstart/search.py +171 -0
  185. sagemaker/core/jumpstart/serializers.py +81 -0
  186. sagemaker/core/jumpstart/session_utils.py +234 -0
  187. sagemaker/core/jumpstart/types.py +3044 -0
  188. sagemaker/core/jumpstart/utils.py +1731 -0
  189. sagemaker/core/jumpstart/validators.py +257 -0
  190. sagemaker/core/lambda_helper.py +312 -0
  191. sagemaker/core/lineage/__init__.py +42 -0
  192. sagemaker/core/lineage/_api_types.py +239 -0
  193. sagemaker/core/lineage/_utils.py +49 -0
  194. sagemaker/core/lineage/action.py +345 -0
  195. sagemaker/core/lineage/artifact.py +646 -0
  196. sagemaker/core/lineage/association.py +190 -0
  197. sagemaker/core/lineage/context.py +505 -0
  198. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  199. sagemaker/core/lineage/query.py +732 -0
  200. sagemaker/core/lineage/visualizer.py +346 -0
  201. sagemaker/core/local/__init__.py +18 -0
  202. sagemaker/core/local/data.py +423 -0
  203. sagemaker/core/local/entities.py +678 -0
  204. sagemaker/core/local/exceptions.py +17 -0
  205. sagemaker/core/local/image.py +1243 -0
  206. sagemaker/core/local/local_session.py +739 -0
  207. sagemaker/core/local/utils.py +246 -0
  208. sagemaker/core/logs.py +181 -0
  209. sagemaker/core/metadata_properties.py +56 -0
  210. sagemaker/core/metric_definitions.py +91 -0
  211. sagemaker/core/mlflow/__init__.py +38 -0
  212. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  213. sagemaker/core/model_card/__init__.py +26 -0
  214. sagemaker/core/model_life_cycle.py +51 -0
  215. sagemaker/core/model_metrics.py +160 -0
  216. sagemaker/core/model_monitor/__init__.py +66 -0
  217. sagemaker/core/model_monitor/clarify_model_monitoring.py +1497 -0
  218. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  219. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  220. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  221. sagemaker/core/model_monitor/dataset_format.py +102 -0
  222. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  223. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  224. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  225. sagemaker/core/model_monitor/utils.py +793 -0
  226. sagemaker/core/model_registry.py +480 -0
  227. sagemaker/core/model_uris.py +97 -0
  228. sagemaker/core/modules/__init__.py +19 -0
  229. sagemaker/core/modules/configs.py +239 -0
  230. sagemaker/core/modules/constants.py +37 -0
  231. sagemaker/core/modules/distributed.py +182 -0
  232. sagemaker/core/modules/local_core/local_container.py +605 -0
  233. sagemaker/core/modules/templates.py +83 -0
  234. sagemaker/core/modules/train/__init__.py +14 -0
  235. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  236. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  237. sagemaker/core/modules/train/container_drivers/common/utils.py +205 -0
  238. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  239. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  240. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  241. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  242. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  243. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  244. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  245. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  246. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  247. sagemaker/core/modules/types.py +19 -0
  248. sagemaker/core/modules/utils.py +194 -0
  249. sagemaker/core/network.py +185 -0
  250. sagemaker/core/parameter.py +173 -0
  251. sagemaker/core/payloads.py +185 -0
  252. sagemaker/core/processing.py +1599 -0
  253. sagemaker/core/remote_function/__init__.py +19 -0
  254. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  255. sagemaker/core/remote_function/client.py +1310 -0
  256. sagemaker/core/remote_function/core/__init__.py +0 -0
  257. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  258. sagemaker/core/remote_function/core/pipeline_variables.py +347 -0
  259. sagemaker/core/remote_function/core/serialization.py +410 -0
  260. sagemaker/core/remote_function/core/stored_function.py +223 -0
  261. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  262. sagemaker/core/remote_function/errors.py +102 -0
  263. sagemaker/core/remote_function/invoke_function.py +167 -0
  264. sagemaker/core/remote_function/job.py +2121 -0
  265. sagemaker/core/remote_function/logging_config.py +38 -0
  266. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  267. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  268. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  269. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  270. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  271. sagemaker/core/remote_function/spark_config.py +149 -0
  272. sagemaker/core/resource_requirements.py +168 -0
  273. {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
  274. sagemaker/core/s3/__init__.py +41 -0
  275. sagemaker/core/s3/client.py +367 -0
  276. sagemaker/core/s3/utils.py +175 -0
  277. sagemaker/core/script_uris.py +93 -0
  278. sagemaker/core/serializers/__init__.py +11 -0
  279. sagemaker/core/serializers/base.py +510 -0
  280. sagemaker/core/serializers/implementations.py +159 -0
  281. sagemaker/core/serializers/utils.py +223 -0
  282. sagemaker/core/serverless_inference_config.py +63 -0
  283. sagemaker/core/session_settings.py +55 -0
  284. sagemaker/core/shapes/__init__.py +3 -0
  285. sagemaker/core/shapes/model_card_shapes.py +159 -0
  286. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
  287. sagemaker/core/spark/__init__.py +16 -0
  288. sagemaker/core/spark/defaults.py +16 -0
  289. sagemaker/core/spark/processing.py +1380 -0
  290. sagemaker/core/telemetry/__init__.py +23 -0
  291. sagemaker/core/telemetry/constants.py +82 -0
  292. sagemaker/core/telemetry/telemetry_logging.py +285 -0
  293. sagemaker/core/tools/__init__.py +1 -0
  294. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  295. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  296. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  297. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  298. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  299. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  300. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  301. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  302. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  303. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  304. sagemaker/core/training/__init__.py +14 -0
  305. sagemaker/core/training/configs.py +345 -0
  306. sagemaker/core/training/constants.py +37 -0
  307. sagemaker/core/training/utils.py +77 -0
  308. sagemaker/core/training_compiler/__init__.py +16 -0
  309. sagemaker/core/training_compiler/config.py +197 -0
  310. sagemaker/core/training_compiler_config.py +197 -0
  311. sagemaker/core/transformer.py +793 -0
  312. sagemaker/core/user_agent.py +76 -0
  313. sagemaker/core/utilities/__init__.py +24 -0
  314. sagemaker/core/utilities/cache.py +169 -0
  315. sagemaker/core/utilities/search_expression.py +133 -0
  316. sagemaker/core/utils/__init__.py +48 -0
  317. sagemaker/core/utils/code_injection/__init__.py +0 -0
  318. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  319. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
  320. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  321. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  322. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  323. {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
  324. sagemaker/core/workflow/__init__.py +152 -0
  325. sagemaker/core/workflow/conditions.py +313 -0
  326. sagemaker/core/workflow/entities.py +58 -0
  327. sagemaker/core/workflow/execution_variables.py +89 -0
  328. sagemaker/core/workflow/functions.py +193 -0
  329. sagemaker/core/workflow/parameters.py +222 -0
  330. sagemaker/core/workflow/pipeline_context.py +394 -0
  331. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  332. sagemaker/core/workflow/properties.py +285 -0
  333. sagemaker/core/workflow/step_outputs.py +65 -0
  334. sagemaker/core/workflow/utilities.py +514 -0
  335. sagemaker/lineage/__init__.py +33 -0
  336. sagemaker/lineage/action.py +28 -0
  337. sagemaker/lineage/artifact.py +28 -0
  338. sagemaker/lineage/context.py +28 -0
  339. sagemaker/lineage/lineage_trial_component.py +28 -0
  340. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/METADATA +28 -9
  341. sagemaker_core-2.3.1.dist-info/RECORD +351 -0
  342. sagemaker_core-2.3.1.dist-info/top_level.txt +1 -0
  343. sagemaker_core/_version.py +0 -3
  344. sagemaker_core/helper/session_helper.py +0 -769
  345. sagemaker_core/resources/__init__.py +0 -1
  346. sagemaker_core/shapes/__init__.py +0 -1
  347. sagemaker_core/tools/__init__.py +0 -1
  348. sagemaker_core-1.0.62.dist-info/RECORD +0 -35
  349. sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
  350. {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
  351. {sagemaker_core/helper → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  352. {sagemaker_core/main → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
  353. {sagemaker_core/main/code_injection → sagemaker/core/modules/local_core}/__init__.py +0 -0
  354. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  355. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  356. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  357. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/WHEEL +0 -0
  358. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1731 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module contains utils for JumpStart."""
14
+ from __future__ import absolute_import
15
+
16
+ from typing import Optional
17
+ from sagemaker.core.helper.session_helper import Session
18
+ from sagemaker.core.jumpstart.models import HubContentDocument
19
+
20
+ from copy import copy
21
+ import logging
22
+ import os
23
+ from functools import lru_cache, wraps
24
+ from typing import Any, Dict, List, Set, Optional, Tuple, Union
25
+ from urllib.parse import urlparse
26
+ import boto3
27
+ from botocore.exceptions import ClientError
28
+ from packaging.version import Version, InvalidVersion
29
+ import botocore
30
+ from sagemaker.core.shapes import ModelAccessConfig
31
+ import sagemaker
32
+ from sagemaker.core.config.config_schema import (
33
+ MODEL_ENABLE_NETWORK_ISOLATION_PATH,
34
+ MODEL_EXECUTION_ROLE_ARN_PATH,
35
+ TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH,
36
+ TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
37
+ TRAINING_JOB_ROLE_ARN_PATH,
38
+ )
39
+
40
+ from sagemaker.core.jumpstart import constants, enums
41
+ from sagemaker.core.jumpstart import accessors
42
+ from sagemaker.core.jumpstart.hub.parser_utils import camel_to_snake, snake_to_upper_camel
43
+ from sagemaker.core.s3 import parse_s3_url
44
+ from sagemaker.core.jumpstart.exceptions import (
45
+ DeprecatedJumpStartModelError,
46
+ VulnerableJumpStartModelError,
47
+ get_old_model_version_msg,
48
+ )
49
+ from sagemaker.core.jumpstart.types import (
50
+ JumpStartBenchmarkStat,
51
+ JumpStartMetadataConfig,
52
+ JumpStartModelHeader,
53
+ JumpStartModelSpecs,
54
+ JumpStartVersionedModelId,
55
+ DeploymentConfigMetadata,
56
+ )
57
+ from sagemaker.core.helper.session_helper import Session
58
+ from sagemaker.core.config.config import load_sagemaker_config
59
+ from sagemaker.core.common_utils import (
60
+ resolve_value_from_config,
61
+ TagsDict,
62
+ get_instance_rate_per_hour,
63
+ get_domain_for_region,
64
+ camel_case_to_pascal_case,
65
+ )
66
+ from sagemaker.core.helper.pipeline_variable import PipelineVariable
67
+
68
+
69
def is_pipeline_variable(var: object) -> bool:
    """Report whether ``var`` is a ``PipelineVariable`` instance.

    Args:
        var (object): Any value, typically a user-supplied parameter.

    Returns:
        bool: True if ``var`` is an instance of ``PipelineVariable``, else False.
    """
    matches_pipeline_type = isinstance(var, PipelineVariable)
    return matches_pipeline_type
72
+
73
+
74
+ from sagemaker.core.utils.user_agent import get_user_agent_extra_suffix
75
+
76
+
77
def get_eula_url(document: HubContentDocument, sagemaker_session: Optional[Session] = None) -> str:
    """Get the EULA URL from the HubContentDocument.

    Converts the document's ``HostingEulaUri`` (expected to look like
    ``s3://bucket/key``) into an HTTPS URL of the form
    ``https://{bucket}.s3.{region}.{dns_suffix}/{key}``. The region comes from the
    session, and the DNS suffix is resolved for that region's partition.

    Args:
        document (HubContentDocument): The HubContentDocument object.
        sagemaker_session (Optional[Session]): SageMaker session. A new ``Session``
            is created when None. Defaults to None.
    Returns:
        str: The EULA URL, or an empty string if the document has no
            ``HostingEulaUri``.
    """
    eula_uri = document.HostingEulaUri
    if not eula_uri:
        return ""
    if sagemaker_session is None:
        sagemaker_session = Session()

    # Strip only the leading scheme. ``str.replace`` would also delete any
    # "s3://" substring that happens to appear inside the object key.
    scheme = "s3://"
    if eula_uri.startswith(scheme):
        eula_uri = eula_uri[len(scheme):]

    # Everything before the first "/" is the bucket; the remainder (possibly
    # empty) is the object key.
    bucket, _, key = eula_uri.partition("/")
    region = sagemaker_session.boto_region_name

    # Ask botocore's endpoint resolver for the DNS suffix of the region's partition
    # instead of hard-coding one.
    botocore_session = sagemaker_session.boto_session._session
    endpoint_resolver = botocore_session.get_component("endpoint_resolver")
    partition = endpoint_resolver.get_partition_for_region(region)
    dns_suffix = endpoint_resolver.get_partition_dns_suffix(partition)

    return f"https://{bucket}.s3.{region}.{dns_suffix}/{key}"
103
+
104
+
105
def get_jumpstart_launched_regions_message() -> str:
    """Build a sentence listing the regions where JumpStart is launched.

    Regions come from ``constants.JUMPSTART_REGION_NAME_SET`` and are listed in
    sorted order. Three or more regions are comma-separated, with "and" before
    the last one.
    """
    regions = sorted(constants.JUMPSTART_REGION_NAME_SET)
    region_count = len(regions)

    if region_count == 0:
        return "JumpStart is not available in any region."
    if region_count == 1:
        return f"JumpStart is available in {regions[0]} region."
    if region_count == 2:
        return f"JumpStart is available in {regions[0]} and {regions[1]} regions."

    *leading_regions, final_region = regions
    joined_regions = ", ".join([*leading_regions, f"and {final_region}"])
    return f"JumpStart is available in {joined_regions} regions."
123
+
124
+
125
def get_jumpstart_gated_content_bucket(
    region: str = constants.JUMPSTART_DEFAULT_REGION_NAME,
) -> str:
    """Resolve the regional JumpStart bucket that holds private (gated) content.

    A non-empty value in the environment variable named by
    ``constants.ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE`` takes
    precedence. Otherwise the bucket is looked up in
    ``constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT``.

    The result is always stored on ``accessors.JumpStartModelsAccessor``. When it
    differs from the previously stored bucket, the accessor cache is reset (only
    if a previous bucket existed) and any override message is logged.

    Args:
        region (str): AWS region to resolve the bucket for.

    Returns:
        str: The gated content bucket name.

    Raises:
        ValueError: If JumpStart is not launched in ``region`` or private content
            is unavailable in that region.
    """
    previous_bucket: Optional[str] = (
        accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket()
    )
    pending_messages: List[str] = []

    override_bucket = os.environ.get(
        constants.ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE
    )
    resolved_bucket: Optional[str] = None
    if override_bucket:
        resolved_bucket = override_bucket
        pending_messages.append(f"Using JumpStart gated bucket override: '{resolved_bucket}'")
    else:
        try:
            resolved_bucket = constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[
                region
            ].gated_content_bucket
        except KeyError:
            launched_regions_message = get_jumpstart_launched_regions_message()
            raise ValueError(
                f"Unable to get private content bucket for JumpStart in {region} region. "
                f"{launched_regions_message}"
            )
        # The region is launched but has no private content bucket configured.
        if resolved_bucket is None:
            raise ValueError(
                f"No private content bucket for JumpStart exists in {region} region."
            )

    accessors.JumpStartModelsAccessor.set_jumpstart_gated_content_bucket(resolved_bucket)

    if resolved_bucket == previous_bucket:
        return resolved_bucket

    # The bucket changed, so anything cached under the old bucket is stale.
    if previous_bucket is not None:
        accessors.JumpStartModelsAccessor.reset_cache()
    for message in pending_messages:
        constants.JUMPSTART_LOGGER.info(message)
    return resolved_bucket
175
+
176
+
177
def get_jumpstart_content_bucket(
    region: str = constants.JUMPSTART_DEFAULT_REGION_NAME,
) -> str:
    """Returns the regionalized content bucket name for JumpStart.

    A non-empty value in the environment variable named by
    ``constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE`` takes precedence.
    Otherwise the bucket is looked up in
    ``constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT``.

    The result is always stored on ``accessors.JumpStartModelsAccessor``. When it
    differs from the previously stored bucket, the accessor cache is reset (only
    if a previous bucket existed) and any override message is logged.

    Args:
        region (str): AWS region to resolve the bucket for.

    Returns:
        str: The content bucket name.

    Raises:
        ValueError: If JumpStart is not launched in ``region``.
    """
    old_content_bucket: Optional[str] = (
        accessors.JumpStartModelsAccessor.get_jumpstart_content_bucket()
    )

    info_logs: List[str] = []

    bucket_to_return: Optional[str] = None
    override_bucket = os.environ.get(constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE)
    if override_bucket:
        bucket_to_return = override_bucket
        info_logs.append(f"Using JumpStart bucket override: '{bucket_to_return}'")
    else:
        try:
            bucket_to_return = constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[
                region
            ].content_bucket
        except KeyError:
            formatted_launched_regions_str = get_jumpstart_launched_regions_message()
            # This resolves the JumpStart bucket; the message previously said
            # "Neo" (copy-paste from get_neo_content_bucket).
            raise ValueError(
                f"Unable to get content bucket for JumpStart in {region} region. "
                f"{formatted_launched_regions_str}"
            )

    accessors.JumpStartModelsAccessor.set_jumpstart_content_bucket(bucket_to_return)

    if bucket_to_return != old_content_bucket:
        # The bucket changed, so anything cached under the old bucket is stale.
        if old_content_bucket is not None:
            accessors.JumpStartModelsAccessor.reset_cache()
        for info_log in info_logs:
            constants.JUMPSTART_LOGGER.info(info_log)
    return bucket_to_return
219
+
220
+
221
def get_neo_content_bucket(
    region: str = constants.NEO_DEFAULT_REGION_NAME,
) -> str:
    """Resolve the regional S3 bucket used by the Neo service.

    A non-empty value in the environment variable named by
    ``constants.ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE`` is returned (and logged)
    when present. Otherwise the bucket is looked up in
    ``constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT``. Unlike the
    JumpStart bucket helpers, nothing is stored on the accessors here.

    Args:
        region (str): AWS region to resolve the bucket for.

    Returns:
        str: The Neo content bucket name.

    Raises:
        ValueError: If Neo is not launched in ``region``.
    """
    override_bucket = os.environ.get(constants.ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE)
    if override_bucket:
        constants.JUMPSTART_LOGGER.info(f"Using Neo bucket override: '{override_bucket}'")
        return override_bucket

    try:
        return constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[region].neo_content_bucket
    except KeyError:
        raise ValueError(f"Unable to get content bucket for Neo in {region} region.")
247
+
248
+
249
def get_formatted_manifest(
    manifest: List[Dict],
) -> Dict[JumpStartVersionedModelId, JumpStartModelHeader]:
    """Index a raw JumpStart manifest by versioned model id.

    Each raw entry is wrapped in a ``JumpStartModelHeader``. The header is keyed by
    a ``JumpStartVersionedModelId`` built from its ``model_id`` and ``version``.
    A later entry with the same id and version overwrites an earlier one.

    Args:
        manifest (List[Dict]): Raw manifest entries.

    Returns:
        Dict[JumpStartVersionedModelId, JumpStartModelHeader]: The indexed headers.
    """
    headers_by_id = {}
    for raw_entry in manifest:
        parsed_header = JumpStartModelHeader(raw_entry)
        versioned_id = JumpStartVersionedModelId(parsed_header.model_id, parsed_header.version)
        headers_by_id[versioned_id] = parsed_header
    return headers_by_id
264
+
265
+
266
def get_sagemaker_version() -> str:
    """Return the sagemaker library version, computing it lazily on first use.

    The version is cached on ``accessors.SageMakerSettings``. When the cached value
    is the empty string, ``parse_sagemaker_version`` is called and its result is
    stored before the cached value is returned.
    """
    settings = accessors.SageMakerSettings
    cached_version = settings.get_sagemaker_version()
    # Compare against "" explicitly: only the empty-string sentinel triggers parsing.
    if cached_version == "":
        settings.set_sagemaker_version(parse_sagemaker_version())
    return settings.get_sagemaker_version()
276
+
277
+
278
def parse_sagemaker_version() -> str:
    """Returns sagemaker library version. This should only be called once.

    Function reads ``__version__`` variable in ``sagemaker`` module.
    In order to maintain compatibility with the ``packaging.version``
    library, versions with fewer than 2, or more than 3, periods are rejected.
    A version with exactly 3 periods is truncated at its last period
    (e.g. ``"2.3.1.dev0"`` -> ``"2.3.1"``). All versions that cannot be parsed
    with ``packaging.version`` are also rejected.

    If ``sagemaker.__version__`` does not exist (development environments),
    ``"3.0.0"`` is returned.

    Returns:
        str: The normalized version string.

    Raises:
        RuntimeError: If the SageMaker version is not readable. An exception is also raised if
            the version cannot be parsed by ``packaging.version``.
    """
    # Handle case where __version__ might not be available in development
    try:
        version = sagemaker.__version__
    except AttributeError:
        # Fallback for development environments - return a valid version format
        return "3.0.0"

    num_periods = version.count(".")
    if num_periods == 2:
        parsed_version = version
    elif num_periods == 3:
        parsed_version = version[: version.rfind(".")]
    else:
        raise RuntimeError(f"Bad value for SageMaker version: {version}")

    # ``Version`` raises ``packaging.version.InvalidVersion`` (a ValueError subclass);
    # convert it to the RuntimeError promised by this function's contract.
    try:
        Version(parsed_version)
    except ValueError as err:
        raise RuntimeError(f"Bad value for SageMaker version: {version}") from err

    return parsed_version
311
+
312
+
313
def is_jumpstart_model_input(model_id: Optional[str], version: Optional[str]) -> bool:
    """Tell whether ``model_id`` and ``version`` identify a JumpStart model.

    Args:
        model_id (str): Optional. Model ID of the JumpStart model.
        version (str): Optional. Version of the JumpStart model.

    Returns:
        bool: True when both arguments are given, False when both are None.

    Raises:
        ValueError: If exactly one of the two arguments is None.
    """
    id_missing = model_id is None
    version_missing = version is None

    if id_missing and version_missing:
        return False
    if id_missing or version_missing:
        raise ValueError(
            "Must specify JumpStart `model_id` and `model_version` when getting specs for "
            "JumpStart models."
        )
    return True
334
+
335
+
336
def is_jumpstart_model_uri(uri: Optional[str]) -> bool:
    """Report whether ``uri`` points into a JumpStart-hosted (gated or public) bucket.

    Args:
        uri (Optional[str]): uri for inference/training job. Non-string values
            (including None) are never considered JumpStart URIs.

    Returns:
        bool: True only for ``s3://`` URIs whose bucket is in
        ``constants.JUMPSTART_GATED_AND_PUBLIC_BUCKET_NAME_SET``.
    """
    # Callers may pass non-string objects (e.g. a SourceCode object).
    if not isinstance(uri, str):
        return False

    if urlparse(uri).scheme != "s3":
        bucket = None
    else:
        bucket, _ = parse_s3_url(uri)

    return bucket in constants.JUMPSTART_GATED_AND_PUBLIC_BUCKET_NAME_SET
351
+
352
+
353
def tag_key_in_array(tag_key: str, tag_array: List[Dict[str, str]]) -> bool:
    """Report whether any tag in ``tag_array`` has ``tag_key`` as its ``"Key"``.

    Args:
        tag_key (str): the tag key to look for.
        tag_array (List[Dict[str, str]]): tags to search; each must have a ``"Key"`` entry
            (checked until the first match).
    """
    return any(tag_key == entry["Key"] for entry in tag_array)
364
+
365
+
366
def get_tag_value(tag_key: str, tag_array: List[Dict[str, str]]) -> str:
    """Look up the ``"Value"`` of the single tag whose ``"Key"`` equals ``tag_key``.

    Args:
        tag_key (str): AWS tag key to search for.
        tag_array (List[Dict[str, str]]): AWS tags, each a dict with ``"Key"`` and ``"Value"``.

    Returns:
        str: The value of the matching tag.

    Raises:
        KeyError: If ``tag_key`` matches zero tags or more than one tag.
    """
    matching_values = [entry["Value"] for entry in tag_array if tag_key == entry["Key"]]

    if len(matching_values) == 1:
        return matching_values[0]

    raise KeyError(
        f"Cannot get value of tag for tag key '{tag_key}' -- found {len(matching_values)} "
        f"number of matches in the tag list."
    )
384
+
385
+
386
def add_single_jumpstart_tag(
    tag_value: str,
    tag_key: enums.JumpStartTag,
    curr_tags: Optional[List[Dict[str, str]]],
    is_uri=False,
) -> Optional[List]:
    """Append ``{"Key": tag_key, "Value": tag_value}`` to ``curr_tags`` when appropriate.

    The tag is not added if ``tag_key`` is already present. When ``is_uri`` is True, the tag
    is only added if ``tag_value`` is a JumpStart-hosted model URI, and it is skipped if the
    tags already carry any JumpStart model ID, version, type, or config-name tag.

    Args:
        tag_value (str): Value of the tag. When ``is_uri`` is True this is an S3 URI which may
            correspond to a JumpStart model.
        tag_key (enums.JumpStartTag): Custom tag key to apply.
        curr_tags (Optional[List]): Current tags associated with ``Estimator`` or ``Model``.
            Modified in place when a tag is added; if None, a new list is created when
            the URI check (if any) passes.
        is_uri (boolean): Set to True to indicate a s3 uri is to be tagged. Set to False to indicate
            tags for JumpStart model id / version are being added. (Default: False).

    Returns:
        Optional[List]: The (possibly updated) tag list. None only if ``curr_tags`` was None
        and ``tag_value`` was rejected as a non-JumpStart URI.
    """
    if is_uri and not is_jumpstart_model_uri(tag_value):
        return curr_tags

    if curr_tags is None:
        curr_tags = []

    if tag_key_in_array(tag_key, curr_tags):
        return curr_tags

    # URI tags are not added once the resource already carries model-identity tags.
    if is_uri and any(
        tag_key_in_array(identity_key, curr_tags)
        for identity_key in (
            enums.JumpStartTag.MODEL_ID,
            enums.JumpStartTag.MODEL_VERSION,
            enums.JumpStartTag.MODEL_TYPE,
            enums.JumpStartTag.INFERENCE_CONFIG_NAME,
            enums.JumpStartTag.TRAINING_CONFIG_NAME,
        )
    ):
        return curr_tags

    curr_tags.append({"Key": tag_key, "Value": tag_value})
    return curr_tags
425
+
426
+
427
def get_jumpstart_base_name_if_jumpstart_model(
    *uris: Optional[str],
) -> Optional[str]:
    """Return ``constants.JUMPSTART_RESOURCE_BASE_NAME`` if any URI is JumpStart-hosted.

    URIs are checked in order and checking stops at the first JumpStart URI.

    Args:
        *uris (Optional[str]): URIs to test for association with JumpStart.

    Returns:
        Optional[str]: The JumpStart base name, or None when no URI belongs to JumpStart.
    """
    if any(is_jumpstart_model_uri(candidate) for candidate in uris):
        return constants.JUMPSTART_RESOURCE_BASE_NAME
    return None
441
+
442
+
443
def add_jumpstart_model_info_tags(
    tags: Optional[List[TagsDict]],
    model_id: str,
    model_version: str,
    model_type: Optional[enums.JumpStartModelType] = None,
    config_name: Optional[str] = None,
    scope: enums.JumpStartScriptScope = None,
) -> List[TagsDict]:
    """Add JumpStart model-identity tags (ID, version, type, config name) to ``tags``.

    Returns ``tags`` unchanged when ``model_id`` or ``model_version`` is None. Otherwise tags
    are added in this order, each only if its key is not already present:
    model ID; model version (unless it is the ``"*"`` wildcard); model type (only for
    proprietary models); inference or training config name (only when ``config_name`` is
    set and ``scope`` is INFERENCE or TRAINING respectively).
    """
    if model_id is None or model_version is None:
        return tags

    pending_tags = [(model_id, enums.JumpStartTag.MODEL_ID)]

    # "*" is used for version resolution and is not a valid AWS tag value.
    if model_version != "*":
        pending_tags.append((model_version, enums.JumpStartTag.MODEL_VERSION))

    if model_type == enums.JumpStartModelType.PROPRIETARY:
        pending_tags.append(
            (enums.JumpStartModelType.PROPRIETARY.value, enums.JumpStartTag.MODEL_TYPE)
        )

    if config_name:
        if scope == enums.JumpStartScriptScope.INFERENCE:
            pending_tags.append((config_name, enums.JumpStartTag.INFERENCE_CONFIG_NAME))
        elif scope == enums.JumpStartScriptScope.TRAINING:
            pending_tags.append((config_name, enums.JumpStartTag.TRAINING_CONFIG_NAME))

    for value, key in pending_tags:
        tags = add_single_jumpstart_tag(value, key, tags, is_uri=False)
    return tags
491
+
492
+
493
def add_hub_content_arn_tags(
    tags: Optional[List[TagsDict]],
    hub_content_arn: str,
) -> Optional[List[TagsDict]]:
    """Tag a JumpStart-related resource with its hub content ARN.

    The ``HUB_CONTENT_ARN`` tag is added only if that key is not already in ``tags``.
    """
    return add_single_jumpstart_tag(
        tag_value=hub_content_arn,
        tag_key=enums.JumpStartTag.HUB_CONTENT_ARN,
        curr_tags=tags,
        is_uri=False,
    )
506
+
507
+
508
def add_bedrock_store_tags(
    tags: Optional[List[TagsDict]],
    compatibility: str,
) -> Optional[List[TagsDict]]:
    """Add the Bedrock compatibility tag to JumpStart related resources.

    The tag uses key ``enums.JumpStartTag.BEDROCK`` and value ``compatibility``; it is
    added only if that key is not already present in ``tags``.

    Args:
        tags (Optional[List[TagsDict]]): Current tags; modified in place when not None.
        compatibility (str): Value to store under the Bedrock tag key.

    Returns:
        Optional[List[TagsDict]]: The updated tag list.
    """
    return add_single_jumpstart_tag(
        compatibility,
        enums.JumpStartTag.BEDROCK,
        tags,
        is_uri=False,
    )
521
+
522
+
523
def add_jumpstart_uri_tags(
    tags: Optional[List[TagsDict]] = None,
    inference_model_uri: Optional[Union[str, dict]] = None,
    inference_script_uri: Optional[str] = None,
    training_model_uri: Optional[str] = None,
    training_script_uri: Optional[str] = None,
) -> Optional[List[TagsDict]]:
    """Tag JumpStart resources with the S3 URIs of their artifacts and scripts.

    Each truthy URI is added as a tag via ``add_single_jumpstart_tag(..., is_uri=True)``,
    which ignores URIs that are not JumpStart-hosted. URIs that are pipeline variables
    cannot be resolved until execution, so they are skipped with a warning.

    Args:
        tags (Optional[List[Dict[str,str]]]): Current tags for JumpStart inference
            or training job. (Default: None).
        inference_model_uri (Optional[Union[dict, str]]): S3 URI for inference model artifact,
            or a model data source dict from which ``["S3DataSource"]["S3Uri"]`` is read.
            (Default: None).
        inference_script_uri (Optional[str]): S3 URI for inference script tarball.
            (Default: None).
        training_model_uri (Optional[str]): S3 URI for training model artifact.
            (Default: None).
        training_script_uri (Optional[str]): S3 URI for training script tarball.
            (Default: None).

    Returns:
        Optional[List[TagsDict]]: The updated tags (unchanged if nothing was tagged).
    """
    warn_msg = (
        "The URI (%s) is a pipeline variable which is only interpreted at execution time. "
        "As a result, the JumpStart resources will not be tagged."
    )

    if isinstance(inference_model_uri, dict):
        inference_model_uri = inference_model_uri.get("S3DataSource", {}).get("S3Uri", None)

    # (argument name used in the warning, URI value, tag key) — processed in this order.
    uri_candidates = (
        ("inference_model_uri", inference_model_uri, enums.JumpStartTag.INFERENCE_MODEL_URI),
        ("inference_script_uri", inference_script_uri, enums.JumpStartTag.INFERENCE_SCRIPT_URI),
        ("training_model_uri", training_model_uri, enums.JumpStartTag.TRAINING_MODEL_URI),
        ("training_script_uri", training_script_uri, enums.JumpStartTag.TRAINING_SCRIPT_URI),
    )

    for arg_name, candidate_uri, uri_tag_key in uri_candidates:
        if not candidate_uri:
            continue
        if is_pipeline_variable(candidate_uri):
            logging.warning(warn_msg, arg_name)
            continue
        tags = add_single_jumpstart_tag(candidate_uri, uri_tag_key, tags, is_uri=True)

    return tags
599
+
600
+
601
def update_inference_tags_with_jumpstart_training_tags(
    inference_tags: Optional[List[Dict[str, str]]], training_tags: Optional[List[Dict[str, str]]]
) -> Optional[List[Dict[str, str]]]:
    """Copy JumpStart tags from a training job onto the tags for ``Model.deploy``.

    For every ``enums.JumpStartTag`` key found in ``training_tags`` that is not already in
    ``inference_tags``, its value (via ``get_tag_value``) is appended to ``inference_tags``.

    Args:
        inference_tags (Optional[List[Dict[str, str]]]): Custom tags to apply to the inference
            job. Modified in place; created as a new list if None and a tag must be copied.
        training_tags (Optional[List[Dict[str, str]]]): Tags from the training job.

    Returns:
        Optional[List[Dict[str, str]]]: The updated inference tags.
    """
    if not training_tags:
        return inference_tags

    for jumpstart_key in enums.JumpStartTag:
        if not tag_key_in_array(jumpstart_key, training_tags):
            continue
        copied_value = get_tag_value(jumpstart_key, training_tags)
        if inference_tags is None:
            inference_tags = []
        if not tag_key_in_array(jumpstart_key, inference_tags):
            inference_tags.append({"Key": jumpstart_key, "Value": copied_value})

    return inference_tags
620
+
621
+
622
def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str:
    """Build the EULA notice for ``model_specs``, or return ``""`` if it has no EULA key."""
    eula_key = model_specs.hosting_eula_key
    if eula_key is None:
        return ""
    return get_formatted_eula_message_template(
        model_id=model_specs.model_id,
        region=region,
        hosting_eula_key=eula_key,
    )
629
+
630
+
631
def get_formatted_eula_message_template(model_id: str, region: str, hosting_eula_key: str) -> str:
    """Format the EULA notice pointing at the EULA object in the regional content bucket.

    The link has the form
    ``https://<content bucket>.s3.<region>.<domain for region>/<hosting_eula_key>``.
    """
    eula_url = (
        f"https://{get_jumpstart_content_bucket(region=region)}.s3.{region}."
        f"{get_domain_for_region(region)}/{hosting_eula_key}"
    )
    return (
        f"Model '{model_id}' requires accepting end-user license agreement (EULA). "
        f"See {eula_url} for terms of use."
    )
639
+
640
+
641
def emit_logs_based_on_model_specs(
    model_specs: JumpStartModelSpecs, region: str, s3_client: boto3.client
) -> None:
    """Emits logs based on model specs and region.

    In order, this logs: the EULA notice (when ``hosting_eula_key`` is set); a notice when
    ``model_specs.version`` differs from the newest version of the same model ID in the
    region's manifest; the deprecation warning; ``deprecate_warn_message`` and
    ``usage_info_message`` when present; and a warning for vulnerable models.

    Args:
        model_specs (JumpStartModelSpecs): Specs of the model being used.
        region (str): Region whose manifest is consulted.
        s3_client (boto3.client): S3 client used to fetch the manifest.
    """
    if model_specs.hosting_eula_key:
        constants.JUMPSTART_LOGGER.info(get_eula_message(model_specs, region))

    full_version: str = model_specs.version

    models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(
        region=region, s3_client=s3_client
    )
    max_version_for_model_id: Optional[str] = None
    for header in models_manifest_list:
        if header.model_id == model_specs.model_id:
            if max_version_for_model_id is None or Version(header.version) > Version(
                max_version_for_model_id
            ):
                max_version_for_model_id = header.version

    # If the model ID is absent from this manifest there is no newer version to point to;
    # skip the notice rather than report ``None`` as the latest version.
    if max_version_for_model_id is not None and full_version != max_version_for_model_id:
        constants.JUMPSTART_LOGGER.info(
            get_old_model_version_msg(model_specs.model_id, full_version, max_version_for_model_id)
        )

    if model_specs.deprecated:
        deprecated_message = model_specs.deprecated_message or (
            "Using deprecated JumpStart model "
            f"'{model_specs.model_id}' and version '{model_specs.version}'."
        )

        constants.JUMPSTART_LOGGER.warning(deprecated_message)

    if model_specs.deprecate_warn_message:
        constants.JUMPSTART_LOGGER.warning(model_specs.deprecate_warn_message)

    if model_specs.usage_info_message:
        constants.JUMPSTART_LOGGER.info(model_specs.usage_info_message)

    if model_specs.inference_vulnerable or model_specs.training_vulnerable:
        constants.JUMPSTART_LOGGER.warning(
            "Using vulnerable JumpStart model '%s' and version '%s'.",
            model_specs.model_id,
            model_specs.version,
        )
687
+
688
+
689
def verify_model_region_and_return_specs(
    model_id: Optional[str],
    version: Optional[str],
    scope: Optional[str],
    region: Optional[str] = None,
    hub_arn: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> JumpStartModelSpecs:
    """Fetch JumpStart model specs after validating the scope and the model's status.

    Args:
        model_id (Optional[str]): model ID of the JumpStart model whose specs are fetched.
        version (Optional[str]): version of the JumpStart model.
        scope (Optional[str]): script scope; must be one of
            ``constants.SUPPORTED_JUMPSTART_SCOPES``.
        region (Optional[str]): region to fetch specs from. When falsy, it is derived from
            ``sagemaker_session`` with ``get_region_fallback``. (Default: None).
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (default: None).
        tolerate_vulnerable_model (bool): When False, raise if the specs mark the script for
            ``scope`` as having dependencies with known security vulnerabilities.
            (Default: False).
        tolerate_deprecated_model (bool): When False, raise if the specs mark the model as
            deprecated. (Default: False).
        sagemaker_session (sagemaker.session.Session): Session used for region fallback;
            its ``s3_client`` is used to fetch the specs.
            (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (enums.JumpStartModelType): Model type forwarded to the specs lookup.
            (Default: enums.JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the JumpStart Model config to apply to the
            returned specs for ``scope``. (Default: None).

    Returns:
        JumpStartModelSpecs: The model specs, with ``config_name`` applied if given.

    Raises:
        NotImplementedError: If the scope is not supported.
        ValueError: If ``scope`` is None, or training is requested for a model that does
            not support training.
        VulnerableJumpStartModelError: If any of the dependencies required by the script have
            known security vulnerabilities.
        DeprecatedJumpStartModelError: If the version of the model is deprecated.
    """
    region = region or get_region_fallback(sagemaker_session=sagemaker_session)

    if scope is None:
        raise ValueError(
            "Must specify `model_scope` argument to retrieve model "
            "artifact uri for JumpStart models."
        )
    if scope not in constants.SUPPORTED_JUMPSTART_SCOPES:
        raise NotImplementedError(
            "JumpStart models only support scopes: "
            f"{', '.join(constants.SUPPORTED_JUMPSTART_SCOPES)}."
        )

    model_specs = accessors.JumpStartModelsAccessor.get_model_specs(  # type: ignore
        region=region,
        model_id=model_id,
        hub_arn=hub_arn,
        version=version,
        s3_client=sagemaker_session.s3_client,
        model_type=model_type,
        sagemaker_session=sagemaker_session,
    )

    is_training_scope = scope == constants.JumpStartScriptScope.TRAINING.value
    is_inference_scope = scope == constants.JumpStartScriptScope.INFERENCE.value

    if is_training_scope and not model_specs.training_supported:
        raise ValueError(
            f"JumpStart model ID '{model_id}' and version '{version}' does not support training."
        )

    if model_specs.deprecated and not tolerate_deprecated_model:
        raise DeprecatedJumpStartModelError(
            model_id=model_id, version=version, message=model_specs.deprecated_message
        )

    if not tolerate_vulnerable_model:
        if is_inference_scope and model_specs.inference_vulnerable:
            raise VulnerableJumpStartModelError(
                model_id=model_id,
                version=version,
                vulnerabilities=model_specs.inference_vulnerabilities,
                scope=constants.JumpStartScriptScope.INFERENCE,
            )
        if is_training_scope and model_specs.training_vulnerable:
            raise VulnerableJumpStartModelError(
                model_id=model_id,
                version=version,
                vulnerabilities=model_specs.training_vulnerabilities,
                scope=constants.JumpStartScriptScope.TRAINING,
            )

    if model_specs and config_name:
        model_specs.set_config(config_name, scope)

    return model_specs
796
+
797
+
798
def update_dict_if_key_not_present(
    dict_to_update: Optional[dict], key_to_add: Any, value_to_add: Any
) -> Optional[dict]:
    """Add ``(key_to_add, value_to_add)`` to a dict unless the key is already present.

    Args:
        dict_to_update (Optional[dict]): Dict to update in place. If None, a new dict is
            created.
        key_to_add (Any): Key to insert when absent.
        value_to_add (Any): Value stored under ``key_to_add`` when the key is absent; an
            existing value for the key is left untouched.

    Returns:
        Optional[dict]: The updated dict. It always contains ``key_to_add``, so it is never
        empty and never None; the ``Optional`` annotation is kept for backward compatibility.
    """
    if dict_to_update is None:
        dict_to_update = {}
    dict_to_update.setdefault(key_to_add, value_to_add)
    return dict_to_update
813
+
814
+
815
def resolve_model_sagemaker_config_field(
    field_name: str,
    field_val: Optional[Any],
    sagemaker_session: Optional[Session],
    default_value: Optional[str] = None,
) -> Any:
    """Given a field name, checks if there is a sagemaker config value to set.

    For the role field, which is customer-supplied, we allow ``field_val`` to take precedence
    over sagemaker config values. For all other fields, sagemaker config values take precedence
    over the JumpStart default fields.

    Args:
        field_name (str): Name of the model field being resolved (``"role"`` and
            ``"enable_network_isolation"`` consult the config; any other name returns
            ``field_val`` unchanged).
        field_val (Optional[Any]): Value supplied by the caller or JumpStart defaults.
        sagemaker_session (Optional[Session]): Session used for config lookup. If None, the
            config is loaded with ``load_sagemaker_config()`` instead.
        default_value (Optional[str]): Fallback value passed to the config resolution. For
            ``"role"``, a falsy value is replaced by the session's caller identity ARN when a
            session is available.

    Returns:
        Any: The resolved field value.
    """
    # In case, sagemaker_session is None, get sagemaker_config from load_sagemaker_config()
    # to resolve value from config for the respective field_name parameter
    _sagemaker_config = load_sagemaker_config() if (sagemaker_session is None) else None

    # We allow customers to define a role which takes precedence
    # over the one defined in sagemaker config
    if field_name == "role":
        # Without a session there is no caller identity to fall back on.
        if not default_value and sagemaker_session is not None:
            default_value = sagemaker_session.get_caller_identity_arn()
        return resolve_value_from_config(
            direct_input=field_val,
            config_path=MODEL_EXECUTION_ROLE_ARN_PATH,
            default_value=default_value,
            sagemaker_session=sagemaker_session,
            sagemaker_config=_sagemaker_config,
        )

    # JumpStart Models have certain default field values. We want
    # sagemaker config values to take priority over the model-specific defaults.
    if field_name == "enable_network_isolation":
        resolved_val = resolve_value_from_config(
            direct_input=None,
            config_path=MODEL_ENABLE_NETWORK_ISOLATION_PATH,
            sagemaker_session=sagemaker_session,
            default_value=default_value,
            sagemaker_config=_sagemaker_config,
        )
        return resolved_val if resolved_val is not None else field_val

    # field is not covered by sagemaker config so return as is
    return field_val
856
+
857
+
858
def resolve_estimator_sagemaker_config_field(
    field_name: str,
    field_val: Optional[Any],
    sagemaker_session: Optional[Session],
    default_value: Optional[str] = None,
) -> Any:
    """Given a field name, checks if there is a sagemaker config value to set.

    For the role field, which is customer-supplied, we allow ``field_val`` to take precedence
    over sagemaker config values. For all other fields, sagemaker config values take precedence
    over the JumpStart default fields.

    Args:
        field_name (str): Name of the estimator field being resolved (``"role"``,
            ``"enable_network_isolation"`` and ``"encrypt_inter_container_traffic"`` consult
            the config; any other name returns ``field_val`` unchanged).
        field_val (Optional[Any]): Value supplied by the caller or JumpStart defaults.
        sagemaker_session (Optional[Session]): Session used for config lookup. If None, the
            config is loaded with ``load_sagemaker_config()`` instead.
        default_value (Optional[str]): Fallback value passed to the config resolution. For
            ``"role"``, a falsy value is replaced by the session's caller identity ARN when a
            session is available.

    Returns:
        Any: The resolved field value.
    """

    # Workaround for config injection if sagemaker_session is None, since in
    # that case sagemaker_session will not be initialized until
    # `_init_sagemaker_session_if_does_not_exist` is called later
    _sagemaker_config = load_sagemaker_config() if (sagemaker_session is None) else None

    # We allow customers to define a role which takes precedence
    # over the one defined in sagemaker config
    if field_name == "role":
        # Without a session there is no caller identity to fall back on.
        if not default_value and sagemaker_session is not None:
            default_value = sagemaker_session.get_caller_identity_arn()
        return resolve_value_from_config(
            direct_input=field_val,
            config_path=TRAINING_JOB_ROLE_ARN_PATH,
            default_value=default_value,
            sagemaker_session=sagemaker_session,
            sagemaker_config=_sagemaker_config,
        )

    # JumpStart Estimators have certain default field values. We want
    # sagemaker config values to take priority over the model-specific defaults.
    if field_name == "enable_network_isolation":
        resolved_val = resolve_value_from_config(
            direct_input=None,
            config_path=TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH,
            sagemaker_session=sagemaker_session,
            default_value=default_value,
            sagemaker_config=_sagemaker_config,
        )
        return resolved_val if resolved_val is not None else field_val

    if field_name == "encrypt_inter_container_traffic":
        resolved_val = resolve_value_from_config(
            direct_input=None,
            config_path=TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
            sagemaker_session=sagemaker_session,
            default_value=default_value,
            sagemaker_config=_sagemaker_config,
        )
        return resolved_val if resolved_val is not None else field_val

    # field is not covered by sagemaker config so return as is
    return field_val
913
+
914
+
915
def validate_model_id_and_get_type(
    model_id: Optional[str],
    region: Optional[str] = None,
    model_version: Optional[str] = None,
    script: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE,
    sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    hub_arn: Optional[str] = None,
) -> Optional[enums.JumpStartModelType]:
    """Returns model type if the model ID is supported for the given script.

    Returns None for a missing, empty, or non-string ``model_id``, and for IDs found in
    neither the open-weights nor the proprietary manifest. When ``hub_arn`` is given, the
    type is read from the hub content instead of the manifests.

    Args:
        model_id (Optional[str]): Model ID to look up.
        region (Optional[str]): Region of the manifests; defaults to
            ``constants.JUMPSTART_DEFAULT_REGION_NAME`` for the manifest lookup.
        model_version (Optional[str]): Model version; only used for the hub lookup.
        script (enums.JumpStartScriptScope): Script scope the model will be used with.
        sagemaker_session (Optional[Session]): Session whose ``s3_client`` fetches manifests.
        hub_arn (Optional[str]): ARN of a SageMaker Hub to resolve the model from.

    Raises:
        ValueError: If the script is not supported by JumpStart.
    """
    # Check the type before comparing to "": a set-membership test such as
    # ``model_id in {None, ""}`` raises TypeError for unhashable non-string inputs.
    if not isinstance(model_id, str) or model_id == "":
        return None

    if hub_arn:
        model_types = _validate_hub_service_model_id_and_get_type(
            model_id=model_id,
            hub_arn=hub_arn,
            region=region,
            model_version=model_version,
            sagemaker_session=sagemaker_session,
        )
        # Currently this function only supports one model type
        return model_types[0] if model_types else None

    s3_client = sagemaker_session.s3_client if sagemaker_session else None
    region = region or constants.JUMPSTART_DEFAULT_REGION_NAME

    models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(
        region=region, s3_client=s3_client, model_type=enums.JumpStartModelType.OPEN_WEIGHTS
    )
    open_weight_model_id_set = {model.model_id for model in models_manifest_list}
    if model_id in open_weight_model_id_set:
        return enums.JumpStartModelType.OPEN_WEIGHTS

    proprietary_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(
        region=region, s3_client=s3_client, model_type=enums.JumpStartModelType.PROPRIETARY
    )
    proprietary_model_id_set = {model.model_id for model in proprietary_manifest_list}
    if model_id in proprietary_model_id_set:
        if script == enums.JumpStartScriptScope.INFERENCE:
            return enums.JumpStartModelType.PROPRIETARY
        raise ValueError(f"Unsupported script for Proprietary models: {script}")
    return None
966
+
967
+
968
def _validate_hub_service_model_id_and_get_type(
    model_id: Optional[str],
    hub_arn: str,
    region: Optional[str] = None,
    model_version: Optional[str] = None,
    sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> List[enums.JumpStartModelType]:
    """Returns a list of JumpStartModelType based off the HubContent.

    The hub model specs' ``model_types`` attribute (if present) is treated as a list of
    ``JumpStartModelType`` member names. Names that are not members are skipped.

    Returns:
        List[enums.JumpStartModelType]: The recognized model types, in their original order;
        empty if none are found.
    """
    hub_model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
        region=region,
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        sagemaker_session=sagemaker_session,
    )

    hub_content_model_types: List[enums.JumpStartModelType] = []
    model_type_names: List[str] = getattr(hub_model_specs, "model_types", None) or []
    for model_type in model_type_names:
        try:
            hub_content_model_types.append(enums.JumpStartModelType[model_type])
        except (KeyError, ValueError):
            # Enum lookup by name raises KeyError for unknown names; skip them.
            continue

    return hub_content_model_types
997
+
998
+
999
def _extract_value_from_list_of_tags(
    tag_keys: List[str],
    list_tags_result: List[Dict[str, str]],
    resource_name: str,
    resource_arn: str,
) -> Optional[str]:
    """Resolve one tag value from several candidate tag keys, checking for conflicts.

    Each key in ``tag_keys`` is looked up with ``get_tag_value``. Keys that are absent or
    duplicated (``get_tag_value`` raises ``KeyError`` in both cases) are skipped. If two
    keys yield different values, a warning naming ``resource_name`` and ``resource_arn``
    is logged and None is returned.

    Args:
        tag_keys (List[str]): Candidate tag keys, checked in order.
        list_tags_result (List[Dict[str, str]]): Tags of the resource (dicts with ``"Key"``
            and ``"Value"``).
        resource_name (str): Human-readable name of the value, used in the warning.
        resource_arn (str): ARN of the tagged resource, used in the warning.

    Returns:
        Optional[str]: The value shared by all matching keys, or None if no value is found
        or the values conflict.
    """
    resolved_value = None
    for tag_key in tag_keys:
        try:
            value_from_tag = get_tag_value(tag_key, list_tags_result)
        except KeyError:
            continue
        if value_from_tag is not None:
            if resolved_value is not None and value_from_tag != resolved_value:
                constants.JUMPSTART_LOGGER.warning(
                    "Found multiple %s tags on the following resource: %s",
                    resource_name,
                    resource_arn,
                )
                resolved_value = None
                break
            resolved_value = value_from_tag
    return resolved_value
1026
+
1027
+
1028
def get_jumpstart_model_info_from_resource_arn(
    resource_arn: str,
    sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
    """Read JumpStart model info from the tags on ``resource_arn``.

    Returns:
        Tuple of (model ID, model version, inference config name, training config name).
        Each element is None when it cannot be inferred from the resource's tags.
    """
    resource_tags = sagemaker_session.list_tags(resource_arn)

    # (label used in conflict warnings, candidate tag keys) in return-tuple order.
    lookups = (
        ("model ID", [enums.JumpStartTag.MODEL_ID, *constants.EXTRA_MODEL_ID_TAGS]),
        (
            "model version",
            [enums.JumpStartTag.MODEL_VERSION, *constants.EXTRA_MODEL_VERSION_TAGS],
        ),
        ("inference config name", [enums.JumpStartTag.INFERENCE_CONFIG_NAME]),
        ("training config name", [enums.JumpStartTag.TRAINING_CONFIG_NAME]),
    )

    model_id, model_version, inference_config_name, training_config_name = (
        _extract_value_from_list_of_tags(
            tag_keys=candidate_keys,
            list_tags_result=resource_tags,
            resource_name=label,
            resource_arn=resource_arn,
        )
        for label, candidate_keys in lookups
    )

    return model_id, model_version, inference_config_name, training_config_name
1073
+
1074
+
1075
def get_region_fallback(
    s3_bucket_name: Optional[str] = None,
    s3_client: Optional[boto3.client] = None,
    sagemaker_session: Optional[Session] = None,
) -> str:
    """Infer the JumpStart region from the provided bucket name, S3 client, and session.

    A known JumpStart region is a candidate if it is a substring of ``s3_bucket_name``,
    a substring of the S3 client's endpoint host, or equal to the session's
    ``boto_region_name``.

    Returns:
        str: The single candidate region, or ``constants.JUMPSTART_DEFAULT_REGION_NAME``
        when there are no candidates.

    Raises:
        ValueError: If more than one distinct region is a candidate.
    """
    candidate_regions: Set[str] = set()
    for known_region in constants.JUMPSTART_REGION_NAME_SET:
        if s3_bucket_name is not None and known_region in s3_bucket_name:
            candidate_regions.add(known_region)
        if s3_client is not None and known_region in s3_client._endpoint.host:
            candidate_regions.add(known_region)
        if sagemaker_session and known_region == sagemaker_session.boto_region_name:
            candidate_regions.add(known_region)

    if len(candidate_regions) > 1:
        raise ValueError("Unable to resolve a region name from the s3 bucket and client provided.")

    if not candidate_regions:
        return constants.JUMPSTART_DEFAULT_REGION_NAME
    return next(iter(candidate_regions))
1112
+
1113
+
1114
def get_config_names(
    region: str,
    model_id: str,
    model_version: str,
    sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE,
    model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS,
) -> List[str]:
    """List the config names defined for a JumpStart model in the given scope.

    Returns:
        List[str]: Keys of the inference or training configs (per ``scope``); empty if the
        model has no configs for that scope.

    Raises:
        ValueError: If the script scope is not supported by JumpStart.
    """
    specs = verify_model_region_and_return_specs(
        region=region,
        model_id=model_id,
        version=model_version,
        sagemaker_session=sagemaker_session,
        scope=scope,
        model_type=model_type,
    )

    if scope == enums.JumpStartScriptScope.INFERENCE:
        scoped_configs = specs.inference_configs
    elif scope == enums.JumpStartScriptScope.TRAINING:
        scoped_configs = specs.training_configs
    else:
        raise ValueError(f"Unknown script scope: {scope}.")

    if not scoped_configs:
        return []
    return list(scoped_configs.configs.keys())
1144
+
1145
+
1146
def get_benchmark_stats(
    region: str,
    model_id: str,
    model_version: str,
    config_names: Optional[List[str]] = None,
    hub_arn: Optional[str] = None,
    sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE,
    model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS,
) -> Dict[str, List[JumpStartBenchmarkStat]]:
    """Returns benchmark stats for the given model ID and region.

    Args:
        region (str): AWS region of the model.
        model_id (str): JumpStart model ID.
        model_version (str): Version of the model.
        config_names (Optional[List[str]]): Config names to collect stats for. When empty
            or ``None``, every config available for ``scope`` is used.
        hub_arn (Optional[str]): Hub ARN to resolve the model from.
        sagemaker_session (Optional[Session]): Session used to retrieve the model specs.
        scope (enums.JumpStartScriptScope): Whether to read inference or training configs.
        model_type (enums.JumpStartModelType): Type of the model.

    Returns:
        Mapping of config name to that config's ``benchmark_metrics``.

    Raises:
        ValueError: If the script scope is not supported by JumpStart, or if a requested
            config name is not defined for the model (including when the model has no
            configs at all for ``scope``).
    """
    model_specs = verify_model_region_and_return_specs(
        region=region,
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        sagemaker_session=sagemaker_session,
        scope=scope,
        model_type=model_type,
    )

    if scope == enums.JumpStartScriptScope.INFERENCE:
        metadata_configs = model_specs.inference_configs
    elif scope == enums.JumpStartScriptScope.TRAINING:
        metadata_configs = model_specs.training_configs
    else:
        raise ValueError(f"Unknown script scope: {scope}.")

    # A model without configs for this scope behaves like an empty config set, so that
    # explicitly requested names raise ValueError instead of AttributeError on None.
    available_configs = metadata_configs.configs if metadata_configs else {}

    if not config_names:
        config_names = list(available_configs.keys())

    benchmark_stats = {}
    for config_name in config_names:
        if config_name not in available_configs:
            raise ValueError(f"Unknown config name: {config_name}")
        benchmark_stats[config_name] = available_configs[config_name].benchmark_metrics

    return benchmark_stats
1188
+
1189
+
1190
def get_jumpstart_configs(
    region: str,
    model_id: str,
    model_version: str,
    config_names: Optional[List[str]] = None,
    sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE,
    model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS,
    hub_arn: Optional[str] = None,
) -> Dict[str, JumpStartMetadataConfig]:
    """Returns metadata configs for the given model ID and region.

    Args:
        region (str): AWS region of the model.
        model_id (str): JumpStart model ID.
        model_version (str): Version of the model.
        config_names (Optional[List[str]]): Config names to return. When empty or ``None``,
            the configs listed in the "overall" ranking are returned in ranking order. If
            the model has no "overall" ranking, every available config is returned.
        sagemaker_session (Optional[Session]): Session used to retrieve the model specs.
        scope (enums.JumpStartScriptScope): Whether to read inference or training configs.
        model_type (enums.JumpStartModelType): Type of the model.
        hub_arn (Optional[str]): Hub ARN. When set, each config name is normalized with
            ``camel_to_snake(snake_to_upper_camel(name))`` before lookup. The returned dict
            is still keyed by the name as requested.

    Returns:
        Dict[str, JumpStartMetadataConfig]: Mapping of config name to config, or an empty
        dict when the model has no configs for ``scope``.

    Raises:
        ValueError: If the script scope is not supported by JumpStart.
        KeyError: If a requested config name does not exist.
    """
    model_specs = verify_model_region_and_return_specs(
        region=region,
        model_id=model_id,
        version=model_version,
        sagemaker_session=sagemaker_session,
        scope=scope,
        model_type=model_type,
        hub_arn=hub_arn,
    )

    if scope == enums.JumpStartScriptScope.INFERENCE:
        metadata_configs = model_specs.inference_configs
    elif scope == enums.JumpStartScriptScope.TRAINING:
        metadata_configs = model_specs.training_configs
    else:
        raise ValueError(f"Unknown script scope: {scope}.")

    if not metadata_configs:
        return {}

    if not config_names:
        # ``config_rankings`` may be None or lack an "overall" entry. Previously that
        # raised AttributeError; now fall back to returning every config.
        overall_ranking = (metadata_configs.config_rankings or {}).get("overall")
        if overall_ranking is None:
            return dict(metadata_configs.configs)
        config_names = overall_ranking.rankings

    if hub_arn:
        return {
            config_name: metadata_configs.configs[
                camel_to_snake(snake_to_upper_camel(config_name))
            ]
            for config_name in config_names
        }
    return {config_name: metadata_configs.configs[config_name] for config_name in config_names}
1243
+
1244
+
1245
def get_jumpstart_user_agent_extra_suffix(
    model_id: Optional[str],
    model_version: Optional[str],
    config_name: Optional[str],
    is_hub_content: Optional[bool],
) -> str:
    """Build the user agent suffix that tags requests with JumpStart model metadata.

    The SDK-wide suffix always comes first. Unless the JumpStart telemetry opt-out
    environment variable is set, model id/version metadata is appended. For hub content,
    a hub marker is also appended, and the model metadata is omitted when both id and
    version are ``None``. A config marker is appended whenever ``config_name`` is truthy,
    even with telemetry disabled.

    Args:
        model_id (Optional[str]): JumpStart model ID.
        model_version (Optional[str]): JumpStart model version.
        config_name (Optional[str]): Name of the selected config, if any.
        is_hub_content (Optional[bool]): Whether the model comes from a hub.

    Returns:
        str: Space-separated user agent suffix.
    """
    base_suffix = get_user_agent_extra_suffix()
    model_suffix = f"md/js_model_id#{model_id} md/js_model_ver#{model_version}"
    config_suffix = f"md/js_config#{config_name}"
    hub_suffix = f"md/js_is_hub_content#{is_hub_content}"

    if os.getenv(constants.ENV_VARIABLE_DISABLE_JUMPSTART_TELEMETRY, None):
        extras = []
    elif is_hub_content is True:
        no_model_info = model_id is None and model_version is None
        extras = [hub_suffix] if no_model_info else [model_suffix, hub_suffix]
    else:
        extras = [model_suffix]

    if config_name:
        extras.append(config_suffix)

    headers = base_suffix
    for extra in extras:
        headers = f"{headers} {extra}"
    return headers
1273
+
1274
+
1275
def get_top_ranked_config_name(
    region: str,
    model_id: str,
    model_version: str,
    sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE,
    model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS,
    tolerate_deprecated_model: bool = False,
    tolerate_vulnerable_model: bool = False,
    hub_arn: Optional[str] = None,
    ranking_name: enums.JumpStartConfigRankingName = enums.JumpStartConfigRankingName.DEFAULT,
) -> Optional[str]:
    """Returns the top ranked config name for the given model ID and region.

    Args:
        region (str): AWS region of the model.
        model_id (str): JumpStart model ID.
        model_version (str): Version of the model.
        sagemaker_session (Optional[Session]): Session used to retrieve the model specs.
        scope (enums.JumpStartScriptScope): Whether to rank inference or training configs.
        model_type (enums.JumpStartModelType): Type of the model.
        tolerate_deprecated_model (bool): Forwarded to the specs verification.
        tolerate_vulnerable_model (bool): Forwarded to the specs verification.
        hub_arn (Optional[str]): Hub ARN to resolve the model from.
        ranking_name (enums.JumpStartConfigRankingName): Ranking to pick the top config from.

    Returns:
        Optional[str]: Name of the top ranked config. ``None`` if the model has no configs
        for ``scope`` or the ranking lookup yields no config.

    Raises:
        ValueError: If the script scope is not supported by JumpStart.
    """
    model_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        scope=scope,
        region=region,
        hub_arn=hub_arn,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
    )

    if scope == enums.JumpStartScriptScope.INFERENCE:
        metadata_configs = model_specs.inference_configs
    elif scope == enums.JumpStartScriptScope.TRAINING:
        metadata_configs = model_specs.training_configs
    else:
        raise ValueError(f"Unsupported script scope: {scope}.")

    if not metadata_configs:
        return None

    # Guard against an empty ranking result instead of dereferencing None.
    top_config = metadata_configs.get_top_config_from_ranking(ranking_name=ranking_name)
    return top_config.config_name if top_config else None
1321
+
1322
+
1323
def get_default_jumpstart_session_with_user_agent_suffix(
    model_id: Optional[str] = None,
    model_version: Optional[str] = None,
    config_name: Optional[str] = None,
    is_hub_content: Optional[bool] = False,
) -> Session:
    """Build a copy of the default JumpStart session whose clients carry a model user agent.

    A fresh botocore session and a botocore ``Config`` with the JumpStart user agent suffix
    are created. The copied session then gets new ``boto_session``, ``sagemaker_client``
    and ``sagemaker_runtime_client`` objects, all in
    ``constants.JUMPSTART_DEFAULT_REGION_NAME``.

    Args:
        model_id (Optional[str]): JumpStart model ID for the user agent.
        model_version (Optional[str]): JumpStart model version for the user agent.
        config_name (Optional[str]): Config name for the user agent.
        is_hub_content (Optional[bool]): Whether the model comes from a hub.

    Returns:
        Session: A shallow copy of ``constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION`` with
        replaced boto session and clients.
    """
    core_session = botocore.session.get_session()
    user_agent_config = botocore.config.Config(
        user_agent_extra=get_jumpstart_user_agent_extra_suffix(
            model_id=model_id,
            model_version=model_version,
            config_name=config_name,
            is_hub_content=is_hub_content,
        ),
    )
    core_session.set_default_client_config(user_agent_config)

    default_region = constants.JUMPSTART_DEFAULT_REGION_NAME

    # Shallow copy so the module-level default session constant is left untouched.
    js_session = copy(constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION)
    js_session.boto_session = boto3.Session(
        region_name=default_region, botocore_session=core_session
    )
    js_session.sagemaker_client = boto3.client(
        "sagemaker", region_name=default_region, config=user_agent_config
    )
    js_session.sagemaker_runtime_client = boto3.client(
        "sagemaker-runtime", region_name=default_region, config=user_agent_config
    )
    return js_session
1354
+
1355
+
1356
def add_instance_rate_stats_to_benchmark_metrics(
    region: str,
    benchmark_metrics: Optional[Dict[str, List[JumpStartBenchmarkStat]]],
) -> Optional[Tuple[Dict[str, str], Dict[str, List[JumpStartBenchmarkStat]]]]:
    """Adds instance types metric stats to the given benchmark_metrics dict.

    Instance type keys are normalized to carry the ``ml.`` prefix. For each instance type
    whose stats have no "instance rate" entry, the hourly rate is looked up and appended
    as a new stat (with ``concurrency`` None). The caller's stat lists are never modified;
    a new list is built instead. Once a lookup fails with a ``ClientError``, no further
    lookups are attempted.

    Args:
        region (str): AWS region.
        benchmark_metrics (Optional[Dict[str, List[JumpStartBenchmarkStat]]]):
            Mapping of instance type to its benchmark stats.

    Returns:
        Optional[Tuple[Dict[str, str], Dict[str, List[JumpStartBenchmarkStat]]]]:
            ``None`` when ``benchmark_metrics`` is empty. Otherwise a tuple of the first
            ``ClientError`` ``Error`` payload (or ``None``) and the resulting metrics.
    """
    if not benchmark_metrics:
        return None

    err_message = None
    final_benchmark_metrics = {}
    for raw_instance_type, benchmark_metric_stats in benchmark_metrics.items():
        instance_type = (
            raw_instance_type if raw_instance_type.startswith("ml.") else f"ml.{raw_instance_type}"
        )
        final_benchmark_metrics[instance_type] = benchmark_metric_stats

        # Skip when a rate is already present, or when an earlier lookup already
        # failed with a client error.
        if has_instance_rate_stat(benchmark_metric_stats) or err_message:
            continue

        try:
            instance_type_rate = get_instance_rate_per_hour(
                instance_type=instance_type, region=region
            )
            rate_stat = JumpStartBenchmarkStat({"concurrency": None, **instance_type_rate})
        except ClientError as e:
            err_message = e.response["Error"]
        except Exception:  # pylint: disable=W0703
            # Best effort: keep the original stats without an instance rate.
            pass
        else:
            # New list rather than append(): the input lists may be shared or cached.
            final_benchmark_metrics[instance_type] = [*(benchmark_metric_stats or []), rate_stat]

    return err_message, final_benchmark_metrics
1399
+
1400
+
1401
def has_instance_rate_stat(benchmark_metric_stats: Optional[List[JumpStartBenchmarkStat]]) -> bool:
    """Tell whether a list of benchmark stats already includes an instance rate stat.

    A stat counts when its name equals "instance rate", case-insensitively.

    Args:
        benchmark_metric_stats (Optional[List[JumpStartBenchmarkStat]]):
            List of benchmark metric stats.
    Returns:
        bool: True if an instance rate stat is present, or if the input is ``None``.
    """
    if benchmark_metric_stats is None:
        # None is reported as "present", so no rate gets added for a missing stats list.
        return True
    return any(stat.name.lower() == "instance rate" for stat in benchmark_metric_stats)
1416
+
1417
+
1418
def get_metrics_from_deployment_configs(
    deployment_configs: Optional[List[DeploymentConfigMetadata]],
) -> Dict[str, List[str]]:
    """Flatten deployment-config benchmark metrics into a column-oriented table.

    Each (config, instance type, concurrency) combination becomes one row. The leading
    columns are "Instance Type", "Config Name" and "Concurrent Users". Metric columns
    follow in first-seen order, and instance rate columns are always placed last. Only
    columns for metrics that are present get a value, so columns may differ in length.
    Configs without deployment args or benchmark metrics are skipped.

    Args:
        deployment_configs (Optional[List[DeploymentConfigMetadata]]):
            List of deployment configs metadata.
    Returns:
        Dict[str, List[str]]: Column name to column values, or ``{}`` for empty input.
    """
    if not deployment_configs:
        return {}

    columns = {"Instance Type": [], "Config Name": [], "Concurrent Users": []}
    rate_columns = {}
    for position, config in enumerate(deployment_configs):
        metrics_by_instance = config.benchmark_metrics
        if not config.deployment_args or not metrics_by_instance:
            continue

        for instance_type, instance_stats in metrics_by_instance.items():
            rate_stat, stats_by_concurrency = _normalize_benchmark_metrics(instance_stats)

            for concurrency, concurrency_stats in stats_by_concurrency.items():
                # Mark the first config's default instance type at concurrency 1.
                is_default_row = (
                    position == 0
                    and concurrency
                    and int(concurrency) == 1
                    and instance_type == config.deployment_args.default_instance_type
                )
                columns["Config Name"].append(config.deployment_config_name)
                columns["Instance Type"].append(
                    f"{instance_type} (Default)" if is_default_row else instance_type
                )
                columns["Concurrent Users"].append(concurrency)

                if rate_stat:
                    rate_column = f"{rate_stat.name} ({rate_stat.unit})"
                    rate_columns.setdefault(rate_column, []).append(rate_stat.value)

                for stat in concurrency_stats:
                    column = _normalize_benchmark_metric_column_name(stat.name, stat.unit)
                    columns.setdefault(column, []).append(stat.value)

    return {**columns, **rate_columns}
1475
+
1476
+
1477
+ def _normalize_benchmark_metric_column_name(name: str, unit: str) -> str:
1478
+ """Normalizes benchmark metric column name.
1479
+
1480
+ Args:
1481
+ name (str): Name of the metric.
1482
+ unit (str): Unit of the metric.
1483
+ Returns:
1484
+ str: Normalized metric column name.
1485
+ """
1486
+ if "latency" in name.lower():
1487
+ name = f"Latency, TTFT (P50 in {unit.lower()})"
1488
+ elif "throughput" in name.lower():
1489
+ name = f"Throughput (P50 in {unit.lower()}/user)"
1490
+ return name
1491
+
1492
+
1493
def _normalize_benchmark_metrics(
    benchmark_metric_stats: List[JumpStartBenchmarkStat],
) -> Tuple[JumpStartBenchmarkStat, Dict[str, List[JumpStartBenchmarkStat]]]:
    """Split benchmark stats into the instance rate stat and stats grouped by concurrency.

    A stat whose name contains "instance rate" (case-insensitive) is treated as the rate
    stat; if there are several, the last one wins. Every other stat is grouped under its
    ``concurrency`` value, keeping first-seen group order and input order within a group.

    Args:
        benchmark_metric_stats (List[JumpStartBenchmarkStat]):
            List of benchmark metrics stats.
    Returns:
        Tuple[JumpStartBenchmarkStat, Dict[str, List[JumpStartBenchmarkStat]]]:
            The instance rate stat (``None`` if absent) and the concurrency groups.
    """
    rate_stat = None
    stats_by_concurrency = {}
    for stat in benchmark_metric_stats:
        if "instance rate" in stat.name.lower():
            rate_stat = stat
            continue
        stats_by_concurrency.setdefault(stat.concurrency, []).append(stat)

    return rate_stat, stats_by_concurrency
1520
+
1521
+
1522
def deployment_config_response_data(
    deployment_configs: Optional[List[DeploymentConfigMetadata]],
) -> List[Dict[str, Any]]:
    """Serialize deployment configs for the API response.

    Each config is converted with ``to_json()``. When it has both benchmark metrics and
    deployment args, its "BenchmarkMetrics" entry is narrowed to the deployment args'
    ``instance_type`` only.

    Args:
        deployment_configs (Optional[List[DeploymentConfigMetadata]]):
            List of deployment configs metadata.
    Returns:
        List[Dict[str, Any]]: One JSON-ready dict per config; empty for empty input.
    """
    if not deployment_configs:
        return []

    responses = []
    for config in deployment_configs:
        payload = config.to_json()
        metrics = payload.get("BenchmarkMetrics")
        if metrics and config.deployment_args:
            selected_type = config.deployment_args.instance_type
            payload["BenchmarkMetrics"] = {selected_type: metrics.get(selected_type)}
        responses.append(payload)
    return responses
1549
+
1550
+
1551
def _deployment_config_lru_cache(_func=None, *, maxsize: int = 128, typed: bool = False):
    """LRU cache for deployment configs that does not keep an incomplete first result.

    The decorated function is wrapped with ``functools.lru_cache``. Right after the first
    uncached call (0 hits, 1 miss), the result is inspected and the cache is cleared when
    instance rate metrics look missing:

    * list results: cleared if any ``DeploymentConfigMetadata`` item lacks an instance
      rate stat for one of its instance types;
    * dict (column table) results: cleared if the dict is empty or its last key is not an
      "Instance Rate" column. If that column is shorter than the first column, it is also
      removed from the returned dict.

    Works both bare (``@_deployment_config_lru_cache``) and with arguments
    (``@_deployment_config_lru_cache(maxsize=...)``). The wrapper exposes the underlying
    ``cache_info`` and ``cache_clear``.
    """

    def has_instance_rate_metric(config: DeploymentConfigMetadata) -> bool:
        """Determines whether metadata config contains instance rate metric stat.

        Args:
            config (DeploymentConfigMetadata): Metadata config metadata.
        Returns:
            bool: Whether every instance type in the config has an instance rate stat
            (True when the config has no benchmark metrics).
        """
        if config.benchmark_metrics is None:
            return True
        return all(
            has_instance_rate_stat(benchmark_metric_stats)
            for benchmark_metric_stats in config.benchmark_metrics.values()
        )

    def wrapper_cache(f):
        f = lru_cache(maxsize=maxsize, typed=typed)(f)

        @wraps(f)
        def wrapped_f(*args, **kwargs):
            res = f(*args, **kwargs)

            # Clear the cache after the first call if the output does not contain
            # instance rate metrics, since that is caused by a missing policy.
            info = f.cache_info()
            if info.hits == 0 and info.misses == 1:
                if isinstance(res, list):
                    if any(
                        isinstance(item, DeploymentConfigMetadata)
                        and not has_instance_rate_metric(item)
                        for item in res
                    ):
                        f.cache_clear()
                elif isinstance(res, dict):
                    keys = list(res.keys())
                    if not keys or "Instance Rate" not in keys[-1]:
                        f.cache_clear()
                    # Compare against the first column. keys[1] raised IndexError for a
                    # single-column dict.
                    elif len(res[keys[0]]) > len(res[keys[-1]]):
                        # The rate column is ragged (only some rows have a rate): drop it.
                        del res[keys[-1]]
                        f.cache_clear()
            return res

        wrapped_f.cache_info = f.cache_info
        wrapped_f.cache_clear = f.cache_clear
        return wrapped_f

    if _func is None:
        return wrapper_cache
    return wrapper_cache(_func)
1603
+
1604
+
1605
def _add_model_access_configs_to_model_data_sources(
    model_data_sources: List[Dict[str, Any]],
    model_access_configs: Dict[str, ModelAccessConfig],
    model_id: str,
    region: str,
) -> List[Dict[str, Any]]:
    """Iterate over the accept EULA configs to ensure all channels are matched.

    For every data source with a truthy "HostingEulaKey", the ``ModelAccessConfig``
    registered for ``model_id`` must have ``accept_eula`` set. Its dumped (PascalCase)
    form is then attached as ``S3DataSource.ModelAccessConfig``. "HostingEulaKey" is
    removed from every returned source. The input dicts, including their nested
    "S3DataSource" dicts, are not modified.

    Args:
        model_data_sources (List[Dict[str, Any]]): Model data sources to update.
        model_access_configs (Dict[str, ModelAccessConfig]): Access configs keyed by model
            ID, holding the ``accept_eula`` field.
        model_id (str): JumpStart model id.
        region (str): Region where the user is operating in.
    Returns:
        List[Dict[str, Any]]: List of model data sources with accept EULA configs applied.
        A falsy ``model_data_sources`` is returned unchanged.
    Raises:
        ValueError: If at least one channel that requires EULA acceptance was not accepted.
    """
    if not model_data_sources:
        return model_data_sources

    model_access_config = model_access_configs.get(model_id) if model_access_configs else None

    acked_model_data_sources = []
    for model_data_source in model_data_sources:
        hosting_eula_key = model_data_source.get("HostingEulaKey")
        mutable_model_data_source = model_data_source.copy()
        # "HostingEulaKey" is dropped whether or not an access config ends up applied.
        mutable_model_data_source.pop("HostingEulaKey", None)

        if hosting_eula_key:
            if not model_access_config or not model_access_config.accept_eula:
                eula_message_template = (
                    "{model_source}{base_eula_message}{model_access_configs_message}"
                )
                model_access_config_entry = (
                    '"{model_id}":ModelAccessConfig(accept_eula=True)'.format(model_id=model_id)
                )
                raise ValueError(
                    eula_message_template.format(
                        model_source="Additional " if model_data_source.get("ChannelName") else "",
                        base_eula_message=get_formatted_eula_message_template(
                            model_id=model_id, region=region, hosting_eula_key=hosting_eula_key
                        ),
                        model_access_configs_message=(
                            "Please add a ModelAccessConfig entry:"
                            f" {model_access_config_entry} "
                            "to model_access_configs to accept the EULA."
                        ),
                    )
                )
            # Rebuild S3DataSource instead of writing into it. The shallow copy above still
            # shares that nested dict with the caller's input.
            mutable_model_data_source["S3DataSource"] = {
                **model_data_source["S3DataSource"],
                "ModelAccessConfig": camel_case_to_pascal_case(model_access_config.model_dump()),
            }

        acked_model_data_sources.append(mutable_model_data_source)
    return acked_model_data_sources
1669
+
1670
+
1671
def get_draft_model_content_bucket(provider: Dict, region: str) -> str:
    """Pick the content bucket that hosts a first-party draft model.

    JumpStart-provided drafts use the gated JumpStart bucket when classified as "gated",
    and the regular JumpStart bucket otherwise (the classification defaults to "ungated").
    Any other or missing provider falls back to the Neo content bucket.

    Args:
        provider (Dict): Provider metadata with optional "name" and "classification" keys.
        region (str): AWS region of the bucket.

    Returns:
        str: Name of the content bucket.
    """
    fallback_bucket = get_neo_content_bucket(region=region)

    from_jumpstart = bool(provider) and provider.get("name", "") == "JumpStart"
    if not from_jumpstart:
        return fallback_bucket

    if provider.get("classification", "ungated") == "gated":
        return get_jumpstart_gated_content_bucket(region=region)
    return get_jumpstart_content_bucket(region=region)
1683
+
1684
+
1685
def remove_env_var_from_estimator_kwargs_if_accept_eula_present(
    init_kwargs: dict, accept_eula: Optional[bool]
):
    """Remove env vars if access configs are used.

    When ``accept_eula`` is explicitly given (not ``None``), the gated model S3 URI
    training environment variable is removed in place from ``init_kwargs["environment"]``.
    A missing or empty "environment" entry, or an environment that lacks the variable,
    is left untouched instead of raising ``KeyError``.

    Args:
        init_kwargs (dict): Dictionary of kwargs when Estimator is instantiated.
        accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally
            passed in to Estimator.fit().
    """
    if accept_eula is None:
        return
    environment = init_kwargs.get("environment")
    if environment:
        environment.pop(constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY, None)
1696
+
1697
+
1698
def get_hub_access_config(hub_content_arn: Optional[str]):
    """Build the hub access config dict for a model reference hub content ARN.

    Args:
        hub_content_arn (Optional[str]): Arn of the model reference hub content.

    Returns:
        Optional[dict]: ``{"HubContentArn": hub_content_arn}``, or ``None`` when no ARN
        is given.
    """
    if hub_content_arn is None:
        return None
    return {"HubContentArn": hub_content_arn}
1710
+
1711
+
1712
def get_model_access_config(accept_eula: Optional[bool]):
    """Build the model access config dict from an optional EULA acceptance flag.

    Args:
        accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally
            passed in to Estimator.fit().

    Returns:
        Optional[dict]: ``{"AcceptEula": accept_eula}``, or ``None`` when ``accept_eula``
        is ``None``. An explicit ``False`` is kept.
    """
    if accept_eula is None:
        return None
    return {"AcceptEula": accept_eula}
1724
+
1725
+
1726
def get_latest_version(versions: List[str]) -> Optional[str]:
    """Return the newest version string, comparing semantically when possible.

    Args:
        versions (List[str]): Candidate version strings.

    Returns:
        Optional[str]: The maximum version, or ``None`` for an empty/falsy input. If any
        entry cannot be parsed by ``Version``, plain string ordering is used instead.
    """
    if not versions:
        return None
    try:
        return max(versions, key=Version)
    except InvalidVersion:
        return max(versions)