sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (363):
  1. sagemaker/core/__init__.py +16 -0
  2. sagemaker/core/_studio.py +116 -0
  3. sagemaker/core/_version.py +11 -0
  4. sagemaker/core/accept_types.py +131 -0
  5. sagemaker/core/analytics.py +744 -0
  6. sagemaker/core/apiutils/__init__.py +13 -0
  7. sagemaker/core/apiutils/_base_types.py +228 -0
  8. sagemaker/core/apiutils/_boto_functions.py +130 -0
  9. sagemaker/core/apiutils/_utils.py +34 -0
  10. sagemaker/core/base_deserializers.py +35 -0
  11. sagemaker/core/base_serializers.py +35 -0
  12. sagemaker/core/clarify/__init__.py +2898 -0
  13. sagemaker/core/collection.py +467 -0
  14. sagemaker/core/common_utils.py +2281 -0
  15. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  16. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  17. sagemaker/core/config/__init__.py +181 -0
  18. sagemaker/core/config/config.py +238 -0
  19. sagemaker/core/config/config_manager.py +595 -0
  20. sagemaker/core/config/config_schema.py +1220 -0
  21. sagemaker/core/config/config_utils.py +297 -0
  22. {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
  23. sagemaker/core/constants.py +73 -0
  24. sagemaker/core/content_types.py +137 -0
  25. sagemaker/core/debugger/__init__.py +39 -0
  26. sagemaker/core/debugger/debugger.py +945 -0
  27. sagemaker/core/debugger/framework_profile.py +292 -0
  28. sagemaker/core/debugger/metrics_config.py +468 -0
  29. sagemaker/core/debugger/profiler.py +42 -0
  30. sagemaker/core/debugger/profiler_config.py +190 -0
  31. sagemaker/core/debugger/profiler_constants.py +40 -0
  32. sagemaker/core/debugger/utils.py +148 -0
  33. sagemaker/core/deprecations.py +254 -0
  34. sagemaker/core/deserializers/__init__.py +10 -0
  35. sagemaker/core/deserializers/base.py +424 -0
  36. sagemaker/core/deserializers/implementations.py +157 -0
  37. sagemaker/core/drift_check_baselines.py +106 -0
  38. sagemaker/core/enums.py +51 -0
  39. sagemaker/core/environment_variables.py +101 -0
  40. sagemaker/core/exceptions.py +108 -0
  41. sagemaker/core/experiments/__init__.py +53 -0
  42. sagemaker/core/experiments/_api_types.py +251 -0
  43. sagemaker/core/experiments/_environment.py +124 -0
  44. sagemaker/core/experiments/_helper.py +294 -0
  45. sagemaker/core/experiments/_metrics.py +333 -0
  46. sagemaker/core/experiments/_run_context.py +58 -0
  47. sagemaker/core/experiments/_utils.py +216 -0
  48. sagemaker/core/experiments/experiment.py +244 -0
  49. sagemaker/core/experiments/run.py +970 -0
  50. sagemaker/core/experiments/trial.py +296 -0
  51. sagemaker/core/experiments/trial_component.py +387 -0
  52. sagemaker/core/explainer/__init__.py +24 -0
  53. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  54. sagemaker/core/explainer/explainer_config.py +44 -0
  55. sagemaker/core/fw_utils.py +1176 -0
  56. sagemaker/core/git_utils.py +349 -0
  57. sagemaker/core/helper/pipeline_variable.py +82 -0
  58. sagemaker/core/helper/session_helper.py +2965 -0
  59. sagemaker/core/huggingface/__init__.py +29 -0
  60. sagemaker/core/huggingface/llm_utils.py +150 -0
  61. sagemaker/core/huggingface/processing.py +139 -0
  62. sagemaker/core/huggingface/training_compiler/config.py +167 -0
  63. sagemaker/core/hyperparameters.py +172 -0
  64. sagemaker/core/image_retriever/__init__.py +3 -0
  65. sagemaker/core/image_retriever/image_retriever.py +640 -0
  66. sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
  67. sagemaker/core/image_retriever/test.py +7 -0
  68. sagemaker/core/image_uri_config/__init__.py +13 -0
  69. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  70. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  71. sagemaker/core/image_uri_config/chainer.json +104 -0
  72. sagemaker/core/image_uri_config/clarify.json +39 -0
  73. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  74. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  75. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  76. sagemaker/core/image_uri_config/debugger.json +34 -0
  77. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  78. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  79. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  80. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  81. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  82. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  83. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  84. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  85. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
  86. sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
  87. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  88. sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
  89. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  90. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  91. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  92. sagemaker/core/image_uri_config/huggingface.json +2138 -0
  93. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  94. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  95. sagemaker/core/image_uri_config/image-classification.json +50 -0
  96. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  97. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  98. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  99. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  100. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  101. sagemaker/core/image_uri_config/kmeans.json +50 -0
  102. sagemaker/core/image_uri_config/knn.json +50 -0
  103. sagemaker/core/image_uri_config/lda.json +26 -0
  104. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  105. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  106. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  107. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  108. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  109. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  110. sagemaker/core/image_uri_config/ntm.json +50 -0
  111. sagemaker/core/image_uri_config/object-detection.json +50 -0
  112. sagemaker/core/image_uri_config/object2vec.json +50 -0
  113. sagemaker/core/image_uri_config/pca.json +50 -0
  114. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  115. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  116. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  117. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  118. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  119. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  120. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  121. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  122. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  123. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  124. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
  125. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  126. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  127. sagemaker/core/image_uri_config/sklearn.json +446 -0
  128. sagemaker/core/image_uri_config/spark.json +280 -0
  129. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  130. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  131. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  132. sagemaker/core/image_uri_config/vw.json +25 -0
  133. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  134. sagemaker/core/image_uri_config/xgboost.json +888 -0
  135. sagemaker/core/image_uris.py +810 -0
  136. sagemaker/core/inference_config.py +144 -0
  137. sagemaker/core/inference_recommender/__init__.py +18 -0
  138. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  139. sagemaker/core/inputs.py +366 -0
  140. sagemaker/core/instance_group.py +61 -0
  141. sagemaker/core/instance_types.py +164 -0
  142. sagemaker/core/instance_types_gpu_info.py +43 -0
  143. sagemaker/core/interactive_apps/__init__.py +41 -0
  144. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  145. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  146. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  147. sagemaker/core/iterators.py +186 -0
  148. sagemaker/core/job.py +380 -0
  149. sagemaker/core/jumpstart/__init__.py +156 -0
  150. sagemaker/core/jumpstart/accessors.py +390 -0
  151. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  152. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  153. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  154. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  155. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  156. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  157. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  158. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  159. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  160. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  161. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  162. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  163. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  164. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  165. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  166. sagemaker/core/jumpstart/cache.py +663 -0
  167. sagemaker/core/jumpstart/configs.py +50 -0
  168. sagemaker/core/jumpstart/constants.py +198 -0
  169. sagemaker/core/jumpstart/deserializers.py +81 -0
  170. sagemaker/core/jumpstart/document.py +76 -0
  171. sagemaker/core/jumpstart/enums.py +168 -0
  172. sagemaker/core/jumpstart/exceptions.py +236 -0
  173. sagemaker/core/jumpstart/factory/utils.py +833 -0
  174. sagemaker/core/jumpstart/filters.py +597 -0
  175. sagemaker/core/jumpstart/hub/__init__.py +0 -0
  176. sagemaker/core/jumpstart/hub/constants.py +16 -0
  177. sagemaker/core/jumpstart/hub/hub.py +291 -0
  178. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  179. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  180. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  181. sagemaker/core/jumpstart/hub/types.py +35 -0
  182. sagemaker/core/jumpstart/hub/utils.py +260 -0
  183. sagemaker/core/jumpstart/models.py +499 -0
  184. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  185. sagemaker/core/jumpstart/parameters.py +20 -0
  186. sagemaker/core/jumpstart/payload_utils.py +239 -0
  187. sagemaker/core/jumpstart/region_config.json +163 -0
  188. sagemaker/core/jumpstart/search.py +171 -0
  189. sagemaker/core/jumpstart/serializers.py +81 -0
  190. sagemaker/core/jumpstart/session_utils.py +234 -0
  191. sagemaker/core/jumpstart/types.py +3044 -0
  192. sagemaker/core/jumpstart/utils.py +1731 -0
  193. sagemaker/core/jumpstart/validators.py +257 -0
  194. sagemaker/core/lambda_helper.py +312 -0
  195. sagemaker/core/lineage/__init__.py +42 -0
  196. sagemaker/core/lineage/_api_types.py +239 -0
  197. sagemaker/core/lineage/_utils.py +49 -0
  198. sagemaker/core/lineage/action.py +345 -0
  199. sagemaker/core/lineage/artifact.py +646 -0
  200. sagemaker/core/lineage/association.py +190 -0
  201. sagemaker/core/lineage/context.py +505 -0
  202. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  203. sagemaker/core/lineage/query.py +732 -0
  204. sagemaker/core/lineage/visualizer.py +346 -0
  205. sagemaker/core/local/__init__.py +18 -0
  206. sagemaker/core/local/data.py +413 -0
  207. sagemaker/core/local/entities.py +678 -0
  208. sagemaker/core/local/exceptions.py +17 -0
  209. sagemaker/core/local/image.py +1243 -0
  210. sagemaker/core/local/local_session.py +739 -0
  211. sagemaker/core/local/utils.py +245 -0
  212. sagemaker/core/logs.py +181 -0
  213. sagemaker/core/metadata_properties.py +56 -0
  214. sagemaker/core/metric_definitions.py +91 -0
  215. sagemaker/core/mlflow/__init__.py +38 -0
  216. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  217. sagemaker/core/model_card/__init__.py +26 -0
  218. sagemaker/core/model_life_cycle.py +51 -0
  219. sagemaker/core/model_metrics.py +160 -0
  220. sagemaker/core/model_monitor/__init__.py +66 -0
  221. sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
  222. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  223. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  224. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  225. sagemaker/core/model_monitor/dataset_format.py +102 -0
  226. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  227. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  228. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  229. sagemaker/core/model_monitor/utils.py +793 -0
  230. sagemaker/core/model_registry.py +480 -0
  231. sagemaker/core/model_uris.py +97 -0
  232. sagemaker/core/modules/__init__.py +19 -0
  233. sagemaker/core/modules/configs.py +226 -0
  234. sagemaker/core/modules/constants.py +37 -0
  235. sagemaker/core/modules/distributed.py +182 -0
  236. sagemaker/core/modules/local_core/__init__.py +0 -0
  237. sagemaker/core/modules/local_core/local_container.py +605 -0
  238. sagemaker/core/modules/templates.py +83 -0
  239. sagemaker/core/modules/train/__init__.py +14 -0
  240. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  241. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  242. sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
  243. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  244. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  245. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  246. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  247. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  248. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  249. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  250. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  251. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  252. sagemaker/core/modules/types.py +19 -0
  253. sagemaker/core/modules/utils.py +194 -0
  254. sagemaker/core/network.py +185 -0
  255. sagemaker/core/parameter.py +173 -0
  256. sagemaker/core/payloads.py +185 -0
  257. sagemaker/core/processing.py +1597 -0
  258. sagemaker/core/remote_function/__init__.py +19 -0
  259. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  260. sagemaker/core/remote_function/client.py +1285 -0
  261. sagemaker/core/remote_function/core/__init__.py +0 -0
  262. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  263. sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
  264. sagemaker/core/remote_function/core/serialization.py +422 -0
  265. sagemaker/core/remote_function/core/stored_function.py +226 -0
  266. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  267. sagemaker/core/remote_function/errors.py +104 -0
  268. sagemaker/core/remote_function/invoke_function.py +172 -0
  269. sagemaker/core/remote_function/job.py +2140 -0
  270. sagemaker/core/remote_function/logging_config.py +38 -0
  271. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  272. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  273. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  274. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  275. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  276. sagemaker/core/remote_function/spark_config.py +149 -0
  277. sagemaker/core/resource_requirements.py +168 -0
  278. {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
  279. sagemaker/core/s3/__init__.py +41 -0
  280. sagemaker/core/s3/client.py +367 -0
  281. sagemaker/core/s3/utils.py +175 -0
  282. sagemaker/core/script_uris.py +93 -0
  283. sagemaker/core/serializers/__init__.py +11 -0
  284. sagemaker/core/serializers/base.py +510 -0
  285. sagemaker/core/serializers/implementations.py +159 -0
  286. sagemaker/core/serializers/utils.py +223 -0
  287. sagemaker/core/serverless_inference_config.py +63 -0
  288. sagemaker/core/session_settings.py +55 -0
  289. sagemaker/core/shapes/__init__.py +3 -0
  290. sagemaker/core/shapes/model_card_shapes.py +159 -0
  291. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
  292. sagemaker/core/spark/__init__.py +16 -0
  293. sagemaker/core/spark/defaults.py +16 -0
  294. sagemaker/core/spark/processing.py +1380 -0
  295. sagemaker/core/telemetry/__init__.py +23 -0
  296. sagemaker/core/telemetry/constants.py +84 -0
  297. sagemaker/core/telemetry/telemetry_logging.py +284 -0
  298. sagemaker/core/tools/__init__.py +1 -0
  299. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  300. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  301. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  302. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  303. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  304. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  305. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  306. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  307. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  308. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  309. sagemaker/core/training/__init__.py +14 -0
  310. sagemaker/core/training/configs.py +333 -0
  311. sagemaker/core/training/constants.py +37 -0
  312. sagemaker/core/training/utils.py +77 -0
  313. sagemaker/core/training_compiler/__init__.py +16 -0
  314. sagemaker/core/training_compiler/config.py +197 -0
  315. sagemaker/core/training_compiler_config.py +197 -0
  316. sagemaker/core/transformer.py +793 -0
  317. sagemaker/core/user_agent.py +76 -0
  318. sagemaker/core/utilities/__init__.py +24 -0
  319. sagemaker/core/utilities/cache.py +169 -0
  320. sagemaker/core/utilities/search_expression.py +133 -0
  321. sagemaker/core/utils/__init__.py +48 -0
  322. sagemaker/core/utils/code_injection/__init__.py +0 -0
  323. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  324. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
  325. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  326. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  327. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  328. {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
  329. sagemaker/core/workflow/__init__.py +152 -0
  330. sagemaker/core/workflow/conditions.py +313 -0
  331. sagemaker/core/workflow/entities.py +58 -0
  332. sagemaker/core/workflow/execution_variables.py +89 -0
  333. sagemaker/core/workflow/functions.py +193 -0
  334. sagemaker/core/workflow/parameters.py +222 -0
  335. sagemaker/core/workflow/pipeline_context.py +394 -0
  336. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  337. sagemaker/core/workflow/properties.py +285 -0
  338. sagemaker/core/workflow/step_outputs.py +65 -0
  339. sagemaker/core/workflow/utilities.py +507 -0
  340. sagemaker/lineage/__init__.py +33 -0
  341. sagemaker/lineage/action.py +28 -0
  342. sagemaker/lineage/artifact.py +28 -0
  343. sagemaker/lineage/context.py +28 -0
  344. sagemaker/lineage/lineage_trial_component.py +28 -0
  345. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
  346. sagemaker_core-2.1.1.dist-info/RECORD +355 -0
  347. sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
  348. sagemaker_core/__init__.py +0 -4
  349. sagemaker_core/_version.py +0 -3
  350. sagemaker_core/helper/session_helper.py +0 -769
  351. sagemaker_core/resources/__init__.py +0 -1
  352. sagemaker_core/shapes/__init__.py +0 -1
  353. sagemaker_core/tools/__init__.py +0 -1
  354. sagemaker_core-1.0.47.dist-info/RECORD +0 -35
  355. sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
  356. {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
  357. {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
  358. {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  359. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  360. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  361. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  362. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
  363. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,833 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module stores JumpStart factory utilities."""
14
+
15
+ from __future__ import absolute_import
16
+ import json
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+ from sagemaker.core.shapes import ModelAccessConfig
19
+ from sagemaker.core import (
20
+ environment_variables,
21
+ image_uris,
22
+ instance_types,
23
+ model_uris,
24
+ script_uris,
25
+ )
26
+ from sagemaker.serve.async_inference.async_inference_config import AsyncInferenceConfig
27
+ from sagemaker.core.deserializers.base import BaseDeserializer
28
+ from sagemaker.core.serializers.base import BaseSerializer
29
+ from sagemaker.core.explainer.explainer_config import ExplainerConfig
30
+ from sagemaker.core.jumpstart.artifacts import (
31
+ _model_supports_inference_script_uri,
32
+ _retrieve_model_init_kwargs,
33
+ _retrieve_model_deploy_kwargs,
34
+ _retrieve_model_package_arn,
35
+ )
36
+ from sagemaker.core.jumpstart.artifacts.resource_names import _retrieve_resource_name_base
37
+ from sagemaker.core.jumpstart.constants import (
38
+ INFERENCE_ENTRY_POINT_SCRIPT_NAME,
39
+ JUMPSTART_DEFAULT_REGION_NAME,
40
+ JUMPSTART_LOGGER,
41
+ )
42
+
43
+ from sagemaker.core.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
44
+ from sagemaker.core.jumpstart.hub.utils import (
45
+ construct_hub_model_arn_from_inputs,
46
+ construct_hub_model_reference_arn_from_inputs,
47
+ )
48
+
49
+ from sagemaker.core.jumpstart.enums import (
50
+ JumpStartScriptScope,
51
+ JumpStartModelType,
52
+ HubContentCapability,
53
+ )
54
+ from sagemaker.core.jumpstart.types import (
55
+ HubContentType,
56
+ JumpStartEstimatorDeployKwargs,
57
+ JumpStartEstimatorFitKwargs,
58
+ JumpStartEstimatorInitKwargs,
59
+ JumpStartModelDeployKwargs,
60
+ JumpStartModelInitKwargs,
61
+ JumpStartModelSpecs,
62
+ )
63
+ from sagemaker.core.jumpstart.utils import (
64
+ add_hub_content_arn_tags,
65
+ add_jumpstart_model_info_tags,
66
+ add_bedrock_store_tags,
67
+ get_default_jumpstart_session_with_user_agent_suffix,
68
+ get_top_ranked_config_name,
69
+ update_dict_if_key_not_present,
70
+ resolve_model_sagemaker_config_field,
71
+ verify_model_region_and_return_specs,
72
+ get_draft_model_content_bucket,
73
+ )
74
+
75
+ from sagemaker.core.model_monitor.data_capture_config import DataCaptureConfig
76
+
77
+ from sagemaker.serve.serverless.serverless_inference_config import ServerlessInferenceConfig
78
+ from sagemaker.core.helper.session_helper import Session
79
+ from sagemaker.core.common_utils import (
80
+ camel_case_to_pascal_case,
81
+ name_from_base,
82
+ format_tags,
83
+ Tags,
84
+ )
85
+ from sagemaker.core.helper.pipeline_variable import PipelineVariable
86
+ from sagemaker.serve.compute_resource_requirements.resource_requirements import ResourceRequirements
87
+ from sagemaker.core import resource_requirements
88
+ from sagemaker.core.enums import EndpointType
89
+
90
+
91
# Union of the JumpStart kwargs containers accepted by the generic helpers in
# this module that are annotated with ``KwargsType`` (e.g.
# ``get_model_info_default_kwargs`` and ``_set_temp_sagemaker_session_if_not_set``).
# Members: model deploy/init kwargs and estimator fit/init/deploy kwargs.
KwargsType = Union[
    JumpStartModelDeployKwargs,
    JumpStartModelInitKwargs,
    JumpStartEstimatorFitKwargs,
    JumpStartEstimatorInitKwargs,
    JumpStartEstimatorDeployKwargs,
]
98
+
99
+
100
def get_model_info_default_kwargs(
    kwargs: KwargsType,
    include_config_name: bool = True,
    include_model_version: bool = True,
    include_tolerate_flags: bool = True,
) -> dict:
    """Build the model-identification kwargs shared by the JumpStart retrieval APIs.

    The result always contains ``model_id``, ``hub_arn``, ``region``,
    ``sagemaker_session`` and ``model_type`` copied from ``kwargs``. Optional groups
    are appended afterwards, in this order, when their flag is True:
    ``config_name``; ``model_version``; ``tolerate_deprecated_model`` and
    ``tolerate_vulnerable_model``.

    Args:
        kwargs: Any JumpStart kwargs container exposing the attributes above.
        include_config_name: Add ``config_name``.
        include_model_version: Add ``model_version``.
        include_tolerate_flags: Add the two ``tolerate_*`` flags.

    Returns:
        A new ``dict`` suitable for ``**``-expansion into JumpStart helpers.
    """
    info = dict(
        model_id=kwargs.model_id,
        hub_arn=kwargs.hub_arn,
        region=kwargs.region,
        sagemaker_session=kwargs.sagemaker_session,
        model_type=kwargs.model_type,
    )

    if include_config_name:
        info["config_name"] = kwargs.config_name

    if include_model_version:
        info["model_version"] = kwargs.model_version

    if include_tolerate_flags:
        info["tolerate_deprecated_model"] = kwargs.tolerate_deprecated_model
        info["tolerate_vulnerable_model"] = kwargs.tolerate_vulnerable_model

    return info
130
+
131
+
132
def _set_temp_sagemaker_session_if_not_set(kwargs: KwargsType) -> Tuple[KwargsType, Session]:
    """Fill in a temporary default session when none is set; return the original one.

    A default JumpStart session (without a custom user agent) is needed in order to
    retrieve config name info. The caller receives the session that was on ``kwargs``
    before this call (possibly ``None``) so it can be restored or replaced later.

    Returns:
        A ``(kwargs, original_session)`` tuple; ``kwargs`` is mutated in place.
    """
    original_session = kwargs.sagemaker_session
    if original_session is None:
        kwargs.sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION
    return kwargs, original_session
143
+
144
+
145
def _add_region_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Resolve ``kwargs.region`` and return the (mutated) kwargs.

    Precedence: an explicit ``kwargs.region``, then the session's
    ``boto_region_name``, then ``JUMPSTART_DEFAULT_REGION_NAME``. The session is
    only consulted when no region was given, so ``kwargs.sagemaker_session`` must be
    set in that case.
    """
    region = kwargs.region
    if not region:
        region = kwargs.sagemaker_session.boto_region_name
    if not region:
        region = JUMPSTART_DEFAULT_REGION_NAME

    kwargs.region = region
    return kwargs
153
+
154
+
155
def _add_sagemaker_session_with_custom_user_agent_to_kwargs(
    kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs],
    orig_session: Optional[Session],
) -> JumpStartModelInitKwargs:
    """Attach the session to use for this model and return the (mutated) kwargs.

    A caller-supplied ``orig_session`` wins. Otherwise a default JumpStart session
    is built whose user agent is tagged with the model id, version, config name,
    and whether the model comes from a hub (``kwargs.hub_arn is not None``).
    """
    if orig_session:
        kwargs.sagemaker_session = orig_session
    else:
        kwargs.sagemaker_session = get_default_jumpstart_session_with_user_agent_suffix(
            model_id=kwargs.model_id,
            model_version=kwargs.model_version,
            config_name=kwargs.config_name,
            is_hub_content=kwargs.hub_arn is not None,
        )

    return kwargs
169
+
170
+
171
def _add_role_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Resolve ``kwargs.role`` via the SageMaker config and return the (mutated) kwargs.

    The current ``kwargs.role`` is passed both as the explicit value and as the
    fallback default to ``resolve_model_sagemaker_config_field``, together with the
    session on ``kwargs``.
    """
    current_role = kwargs.role
    resolved_role = resolve_model_sagemaker_config_field(
        field_name="role",
        field_val=current_role,
        sagemaker_session=kwargs.sagemaker_session,
        default_value=current_role,
    )
    kwargs.role = resolved_role

    return kwargs
182
+
183
+
184
def _add_model_version_to_kwargs(
    kwargs: JumpStartModelInitKwargs,
) -> JumpStartModelInitKwargs:
    """Resolve ``kwargs.model_version`` and return the (mutated) kwargs.

    A missing version becomes the wildcard ``"*"``. For hub models
    (``kwargs.hub_arn`` set) the version is then replaced by ``kwargs.specs.version``.
    """
    # Apply the wildcard first: for hub models ``specs`` is read only after this
    # assignment, matching the original evaluation order.
    kwargs.model_version = kwargs.model_version or "*"

    if kwargs.hub_arn:
        kwargs.model_version = kwargs.specs.version

    return kwargs
196
+
197
+
198
def _add_vulnerable_and_deprecated_status_to_kwargs(
    kwargs: JumpStartModelInitKwargs,
) -> JumpStartModelInitKwargs:
    """Default the deprecation/vulnerability tolerance flags and return the kwargs.

    Each of ``tolerate_deprecated_model`` and ``tolerate_vulnerable_model`` keeps its
    value when truthy and is otherwise set to ``False``.
    """
    for flag_name in ("tolerate_deprecated_model", "tolerate_vulnerable_model"):
        setattr(kwargs, flag_name, getattr(kwargs, flag_name) or False)

    return kwargs
207
+
208
+
209
def _add_instance_type_to_kwargs(
    kwargs: JumpStartModelInitKwargs, disable_instance_type_logging: bool = False
) -> JumpStartModelInitKwargs:
    """Sets instance type based on default or override, returns full kwargs.

    If ``kwargs.instance_type`` is unset (falsy), the JumpStart default inference
    instance type is retrieved, taking ``kwargs.training_instance_type`` into account.
    Unless ``disable_instance_type_logging`` is True, the chosen default is logged.

    When ``kwargs.specs.inference_configs`` contains ``kwargs.config_name`` and its
    resolved config lists ``supported_inference_instance_types``, a warning is logged
    if the selected instance type is not in that list.

    Args:
        kwargs: Model init kwargs; ``instance_type`` is updated in place.
        disable_instance_type_logging: Suppress the info log emitted when a default
            instance type is selected.

    Returns:
        The same ``kwargs`` object, updated.
    """
    orig_instance_type = kwargs.instance_type
    kwargs.instance_type = orig_instance_type or instance_types.retrieve_default(
        **get_model_info_default_kwargs(kwargs),
        scope=JumpStartScriptScope.INFERENCE,
        training_instance_type=kwargs.training_instance_type,
    )

    # Use the same falsy test as the ``or`` above so the log fires whenever the
    # default was actually applied (an empty string previously picked it silently).
    if not disable_instance_type_logging and not orig_instance_type:
        JUMPSTART_LOGGER.info(
            "No instance type selected for inference hosting endpoint. Defaulting to %s.",
            kwargs.instance_type,
        )

    inference_configs = kwargs.specs.inference_configs
    if not inference_configs or kwargs.config_name not in inference_configs.configs:
        return kwargs

    resolved_config = inference_configs.configs[kwargs.config_name].resolved_config
    if resolved_config is None:
        return kwargs

    # Treat a missing, None, or empty list as "no restriction declared". Previously a
    # missing key defaulted to [] and warned on every call, and None raised TypeError.
    supported_instance_types = resolved_config.get("supported_inference_instance_types") or []
    if supported_instance_types and kwargs.instance_type not in supported_instance_types:
        # NOTE(review): nothing is overridden here; the message text is kept verbatim
        # for log compatibility. It flags an instance type outside the supported list.
        JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type)

    return kwargs
244
+
245
+
246
def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Resolve ``kwargs.image_uri`` and return the (mutated) kwargs.

    For proprietary JumpStart models, which use ModelPackages, the image URI is set
    to ``None`` as a placeholder. For all other models an explicit ``image_uri`` is
    kept. If none is given, the inference image for ``kwargs.instance_type`` is
    retrieved through ``image_uris.retrieve``.
    """
    if kwargs.model_type == JumpStartModelType.PROPRIETARY:
        kwargs.image_uri = None
    elif not kwargs.image_uri:
        kwargs.image_uri = image_uris.retrieve(
            **get_model_info_default_kwargs(kwargs),
            framework=None,
            image_scope=JumpStartScriptScope.INFERENCE,
            instance_type=kwargs.instance_type,
        )

    return kwargs
263
+
264
+
265
def _add_model_reference_arn_to_kwargs(
    kwargs: JumpStartModelInitKwargs,
) -> JumpStartModelInitKwargs:
    """Sets Model Reference ARN if the hub content type is Model Reference, returns full kwargs.

    ``kwargs.hub_content_type`` is copied from the specs only for hub models
    (``kwargs.hub_arn`` set) and is ``None`` otherwise. ``kwargs.model_reference_arn``
    is built from the hub ARN, model id and model version only when the model is hub
    content of type ``HubContentType.MODEL_REFERENCE``. In every other case it is
    ``None``.
    """
    specs_content_type = kwargs.specs.hub_content_type
    kwargs.hub_content_type = specs_content_type if kwargs.hub_arn else None

    # A reference ARN needs a hub ARN. Without one, the content type is already
    # treated as meaningless above, so don't call the ARN builder with hub_arn=None.
    if kwargs.hub_arn and specs_content_type == HubContentType.MODEL_REFERENCE:
        kwargs.model_reference_arn = construct_hub_model_reference_arn_from_inputs(
            hub_arn=kwargs.hub_arn, model_name=kwargs.model_id, version=kwargs.model_version
        )
    else:
        kwargs.model_reference_arn = None
    return kwargs
280
+
281
+
282
def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Resolve the model artifact location and store it on ``kwargs.model_data``.

    Proprietary models get ``model_data = None``. Otherwise a truthy caller-supplied
    ``model_data`` is kept, falling back to ``model_uris.retrieve`` for the inference
    scope and ``kwargs.instance_type``.

    A string that starts with ``s3://`` and ends with ``/`` (an S3 prefix) is wrapped
    into an ``S3DataSource`` dictionary with ``S3DataType="S3Prefix"`` and
    ``CompressionType="None"``. The conversion is logged only when the prefix was
    supplied by the caller.

    Args:
        kwargs (JumpStartModelInitKwargs): Model init kwargs; updated in place.

    Returns:
        JumpStartModelInitKwargs: The same ``kwargs`` object.
    """
    if kwargs.model_type == JumpStartModelType.PROPRIETARY:
        kwargs.model_data = None
        return kwargs

    info_kwargs = get_model_info_default_kwargs(kwargs)
    user_supplied = kwargs.model_data
    resolved: Union[str, dict] = user_supplied or model_uris.retrieve(
        **info_kwargs,
        model_scope=JumpStartScriptScope.INFERENCE,
        instance_type=kwargs.instance_type,
    )

    looks_like_s3_prefix = (
        isinstance(resolved, str) and resolved.startswith("s3://") and resolved.endswith("/")
    )
    if looks_like_s3_prefix:
        s3_prefix = resolved
        resolved = {
            "S3DataSource": {
                "S3Uri": s3_prefix,
                "S3DataType": "S3Prefix",
                "CompressionType": "None",
            }
        }
        if user_supplied:
            JUMPSTART_LOGGER.info(
                "S3 prefix model_data detected for JumpStartModel: '%s'. "
                "Converting to S3DataSource dictionary: '%s'.",
                s3_prefix,
                json.dumps(resolved),
            )

    kwargs.model_data = resolved
    return kwargs
316
+
317
+
318
def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Resolve the inference source directory and store it on ``kwargs.source_dir``.

    Proprietary models get ``source_dir = None``. For other models, when the model
    supports an inference script URI and no truthy ``source_dir`` was supplied, the
    default inference script location from ``script_uris.retrieve`` is used.

    Args:
        kwargs (JumpStartModelInitKwargs): Model init kwargs; updated in place.

    Returns:
        JumpStartModelInitKwargs: The same ``kwargs`` object.
    """
    if kwargs.model_type == JumpStartModelType.PROPRIETARY:
        kwargs.source_dir = None
        return kwargs

    chosen_dir = kwargs.source_dir
    model_info = get_model_info_default_kwargs(kwargs)
    if _model_supports_inference_script_uri(**model_info) and not chosen_dir:
        chosen_dir = script_uris.retrieve(
            **model_info, script_scope=JumpStartScriptScope.INFERENCE
        )

    kwargs.source_dir = chosen_dir
    return kwargs
335
+
336
+
337
def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Resolve the inference entry point and store it on ``kwargs.entry_point``.

    Proprietary models get ``entry_point = None``. For other models that support an
    inference script URI, a falsy ``entry_point`` is replaced by
    ``INFERENCE_ENTRY_POINT_SCRIPT_NAME``; otherwise the supplied value is kept.

    Args:
        kwargs (JumpStartModelInitKwargs): Model init kwargs; updated in place.

    Returns:
        JumpStartModelInitKwargs: The same ``kwargs`` object.
    """
    if kwargs.model_type == JumpStartModelType.PROPRIETARY:
        kwargs.entry_point = None
        return kwargs

    chosen_entry_point = kwargs.entry_point
    supports_script = _model_supports_inference_script_uri(
        **get_model_info_default_kwargs(kwargs)
    )
    if supports_script and not chosen_entry_point:
        chosen_entry_point = INFERENCE_ENTRY_POINT_SCRIPT_NAME

    kwargs.entry_point = chosen_entry_point
    return kwargs
353
+
354
+
355
def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Merge the model's default inference environment variables into ``kwargs.env``.

    Proprietary models get ``env = None``. For other models the defaults returned by
    ``environment_variables.retrieve_default`` (AWS SDK env vars excluded) are merged
    in with ``update_dict_if_key_not_present``. Judging by the helper's name,
    caller-provided keys take precedence. An empty result is stored as ``None``.

    The caller's ``env`` dict is shallow-copied before merging. Without the copy it would
    be modified as a side effect of resolving the model kwargs.

    Args:
        kwargs (JumpStartModelInitKwargs): Model init kwargs; ``env`` is replaced.

    Returns:
        JumpStartModelInitKwargs: The same ``kwargs`` object.
    """
    if kwargs.model_type == JumpStartModelType.PROPRIETARY:
        kwargs.env = None
        return kwargs

    # The loop below ignores update_dict_if_key_not_present's return value, so the
    # merge relies on in-place updates. Copy first so the user's dict is not mutated.
    env = dict(kwargs.env) if kwargs.env is not None else {}

    extra_env_vars = environment_variables.retrieve_default(
        **get_model_info_default_kwargs(kwargs),
        include_aws_sdk_env_vars=False,
        script=JumpStartScriptScope.INFERENCE,
        instance_type=kwargs.instance_type,
    )

    for key, value in extra_env_vars.items():
        update_dict_if_key_not_present(env, key, value)

    # An empty environment is represented as None rather than {}.
    kwargs.env = env or None
    return kwargs
387
+
388
+
389
def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Resolve the model package ARN and store it on ``kwargs.model_package_arn``.

    A truthy caller-supplied ARN is kept. Otherwise the ARN is looked up with
    ``_retrieve_model_package_arn`` for the inference scope and ``kwargs.instance_type``.

    Args:
        kwargs (JumpStartModelInitKwargs): Model init kwargs; updated in place.

    Returns:
        JumpStartModelInitKwargs: The same ``kwargs`` object.
    """
    package_arn = kwargs.model_package_arn
    if not package_arn:
        package_arn = _retrieve_model_package_arn(
            **get_model_info_default_kwargs(kwargs),
            scope=JumpStartScriptScope.INFERENCE,
            instance_type=kwargs.instance_type,
        )

    kwargs.model_package_arn = package_arn
    return kwargs
400
+
401
+
402
def _add_extra_model_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Fill unset model-init fields with the model's default init kwargs.

    For each ``(field, default)`` pair from ``_retrieve_model_init_kwargs``, a field that
    is currently ``None`` on ``kwargs`` is set to the default. The default is first passed
    through ``resolve_model_sagemaker_config_field`` with the current session. Fields
    that already hold a non-``None`` value are left untouched.

    Args:
        kwargs (JumpStartModelInitKwargs): Model init kwargs; updated in place.

    Returns:
        JumpStartModelInitKwargs: The same ``kwargs`` object.
    """
    default_init_kwargs = _retrieve_model_init_kwargs(**get_model_info_default_kwargs(kwargs))

    for field_name, default_value in default_init_kwargs.items():
        if getattr(kwargs, field_name) is not None:
            continue
        setattr(
            kwargs,
            field_name,
            resolve_model_sagemaker_config_field(
                field_name=field_name,
                field_val=default_value,
                sagemaker_session=kwargs.sagemaker_session,
            ),
        )

    return kwargs
417
+
418
+
419
def _add_endpoint_name_to_kwargs(
    kwargs: JumpStartModelDeployKwargs,
) -> JumpStartModelDeployKwargs:
    """Default ``kwargs.endpoint_name`` from the model's resource name base.

    A truthy caller-supplied endpoint name is kept. Otherwise, if
    ``_retrieve_resource_name_base`` returns a base, the name is generated with
    ``name_from_base``. If no base is returned, the endpoint name stays ``None``.

    Args:
        kwargs (JumpStartModelDeployKwargs): Deploy kwargs (must not be ``None``; its
            attributes are read unconditionally). ``endpoint_name`` is updated in place.

    Returns:
        JumpStartModelDeployKwargs: The same ``kwargs`` object.
    """
    default_endpoint_name = _retrieve_resource_name_base(**get_model_info_default_kwargs(kwargs))

    kwargs.endpoint_name = kwargs.endpoint_name or (
        name_from_base(default_endpoint_name) if default_endpoint_name is not None else None
    )

    return kwargs
431
+
432
+
433
def _add_model_name_to_kwargs(
    kwargs: JumpStartModelInitKwargs,
) -> JumpStartModelInitKwargs:
    """Default ``kwargs.name`` (the model name) from the model's resource name base.

    A truthy caller-supplied name is kept. Otherwise, if ``_retrieve_resource_name_base``
    returns a base, the name is generated with ``name_from_base``. If no base is
    returned, the name stays ``None``.

    Args:
        kwargs (JumpStartModelInitKwargs): Model init kwargs (must not be ``None``; its
            attributes are read unconditionally). ``name`` is updated in place.

    Returns:
        JumpStartModelInitKwargs: The same ``kwargs`` object.
    """
    default_model_name = _retrieve_resource_name_base(**get_model_info_default_kwargs(kwargs))

    kwargs.name = kwargs.name or (
        name_from_base(default_model_name) if default_model_name is not None else None
    )

    return kwargs
445
+
446
+
447
+ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]:
448
+ """Sets tags based on default or override, returns full kwargs."""
449
+
450
+ full_model_version = kwargs.specs.version
451
+
452
+ if kwargs.sagemaker_session.settings.include_jumpstart_tags:
453
+ kwargs.tags = add_jumpstart_model_info_tags(
454
+ kwargs.tags,
455
+ kwargs.model_id,
456
+ full_model_version,
457
+ kwargs.model_type,
458
+ config_name=kwargs.config_name,
459
+ scope=JumpStartScriptScope.INFERENCE,
460
+ )
461
+
462
+ if kwargs.hub_arn:
463
+ if kwargs.model_reference_arn:
464
+ hub_content_arn = construct_hub_model_reference_arn_from_inputs(
465
+ kwargs.hub_arn, kwargs.model_id, kwargs.model_version
466
+ )
467
+ else:
468
+ hub_content_arn = construct_hub_model_arn_from_inputs(
469
+ kwargs.hub_arn, kwargs.model_id, kwargs.model_version
470
+ )
471
+ kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn)
472
+
473
+ if hasattr(kwargs.specs, "capabilities") and kwargs.specs.capabilities is not None:
474
+ if HubContentCapability.BEDROCK_CONSOLE in kwargs.specs.capabilities:
475
+ kwargs.tags = add_bedrock_store_tags(kwargs.tags, compatibility="compatible")
476
+
477
+ return kwargs
478
+
479
+
480
def _add_deploy_extra_kwargs(kwargs: JumpStartModelDeployKwargs) -> JumpStartModelDeployKwargs:
    """Fill unset deploy fields with the model's default deploy kwargs.

    Defaults come from ``_retrieve_model_deploy_kwargs`` for ``kwargs.instance_type``.
    Each default is applied only to attributes that are currently ``None``, so explicit
    caller values win.

    Args:
        kwargs (JumpStartModelDeployKwargs): Deploy kwargs; updated in place.

    Returns:
        JumpStartModelDeployKwargs: The same ``kwargs`` object.
    """
    deploy_kwargs_to_add = _retrieve_model_deploy_kwargs(
        **get_model_info_default_kwargs(kwargs), instance_type=kwargs.instance_type
    )

    for key, value in deploy_kwargs_to_add.items():
        if getattr(kwargs, key) is None:
            setattr(kwargs, key, value)

    return kwargs
492
+
493
+
494
def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Resolve the resource requirements and store them on ``kwargs.resources``.

    Truthy caller-supplied ``resources`` are kept. Otherwise the default requirements for
    the inference scope and ``kwargs.instance_type`` come from
    ``resource_requirements.retrieve_default``.

    Args:
        kwargs (JumpStartModelInitKwargs): Kwargs object; updated in place.

    Returns:
        JumpStartModelInitKwargs: The same ``kwargs`` object.
    """
    requirements = kwargs.resources
    if not requirements:
        requirements = resource_requirements.retrieve_default(
            **get_model_info_default_kwargs(kwargs),
            scope=JumpStartScriptScope.INFERENCE,
            instance_type=kwargs.instance_type,
        )

    kwargs.resources = requirements
    return kwargs
504
+
505
+
506
+ def _select_inference_config_from_training_config(
507
+ specs: JumpStartModelSpecs, training_config_name: str
508
+ ) -> Optional[str]:
509
+ """Selects the inference config from the training config.
510
+ Args:
511
+ specs (JumpStartModelSpecs): The specs for the model.
512
+ training_config_name (str): The name of the training config.
513
+ Returns:
514
+ str: The name of the inference config.
515
+ """
516
+ if specs.training_configs:
517
+ resolved_training_config = specs.training_configs.configs.get(training_config_name)
518
+ if resolved_training_config:
519
+ return resolved_training_config.default_inference_config
520
+
521
+ return None
522
+
523
+
524
+ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
525
+ """Sets default config name to the kwargs. Returns full kwargs.
526
+ Raises:
527
+ ValueError: If the instance_type is not supported with the current config.
528
+ """
529
+
530
+ kwargs.config_name = kwargs.config_name or get_top_ranked_config_name(
531
+ **get_model_info_default_kwargs(kwargs, include_config_name=False),
532
+ scope=JumpStartScriptScope.INFERENCE,
533
+ )
534
+
535
+ if kwargs.config_name is None:
536
+ return kwargs
537
+
538
+ return kwargs
539
+
540
+
541
+ def _add_additional_model_data_sources_to_kwargs(
542
+ kwargs: JumpStartModelInitKwargs,
543
+ ) -> JumpStartModelInitKwargs:
544
+ """Sets default additional model data sources to init kwargs"""
545
+
546
+ specs = kwargs.specs
547
+ # Append speculative decoding data source from metadata
548
+ speculative_decoding_data_sources = specs.get_speculative_decoding_s3_data_sources()
549
+ for data_source in speculative_decoding_data_sources:
550
+ data_source.s3_data_source.set_bucket(
551
+ get_draft_model_content_bucket(provider=data_source.provider, region=kwargs.region)
552
+ )
553
+ api_shape_additional_model_data_sources = (
554
+ [
555
+ camel_case_to_pascal_case(data_source.to_json())
556
+ for data_source in speculative_decoding_data_sources
557
+ ]
558
+ if specs.get_speculative_decoding_s3_data_sources()
559
+ else None
560
+ )
561
+
562
+ kwargs.additional_model_data_sources = (
563
+ kwargs.additional_model_data_sources or api_shape_additional_model_data_sources
564
+ )
565
+
566
+ return kwargs
567
+
568
+
569
+ def _add_config_name_to_deploy_kwargs(
570
+ kwargs: JumpStartModelDeployKwargs, training_config_name: Optional[str] = None
571
+ ) -> JumpStartModelInitKwargs:
572
+ """Sets default config name to the kwargs. Returns full kwargs.
573
+ If a training_config_name is passed, then choose the inference config
574
+ based on the supported inference configs in that training config.
575
+ Raises:
576
+ ValueError: If the instance_type is not supported with the current config.
577
+ """
578
+
579
+ if training_config_name:
580
+
581
+ specs = kwargs.specs
582
+ default_config_name = _select_inference_config_from_training_config(
583
+ specs=specs, training_config_name=training_config_name
584
+ )
585
+
586
+ else:
587
+ default_config_name = kwargs.config_name or get_top_ranked_config_name(
588
+ **get_model_info_default_kwargs(kwargs, include_config_name=False),
589
+ scope=JumpStartScriptScope.INFERENCE,
590
+ )
591
+
592
+ kwargs.config_name = kwargs.config_name or default_config_name
593
+
594
+ return kwargs
595
+
596
+
597
def get_deploy_kwargs(
    model_id: str,
    model_version: Optional[str] = None,
    hub_arn: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    region: Optional[str] = None,
    initial_instance_count: Optional[int] = None,
    instance_type: Optional[str] = None,
    serializer: Optional[BaseSerializer] = None,
    deserializer: Optional[BaseDeserializer] = None,
    accelerator_type: Optional[str] = None,
    endpoint_name: Optional[str] = None,
    inference_component_name: Optional[str] = None,
    tags: Optional[Tags] = None,
    kms_key: Optional[str] = None,
    wait: Optional[bool] = None,
    data_capture_config: Optional[DataCaptureConfig] = None,
    async_inference_config: Optional[AsyncInferenceConfig] = None,
    serverless_inference_config: Optional[ServerlessInferenceConfig] = None,
    volume_size: Optional[int] = None,
    model_data_download_timeout: Optional[int] = None,
    container_startup_health_check_timeout: Optional[int] = None,
    inference_recommendation_id: Optional[str] = None,
    explainer_config: Optional[ExplainerConfig] = None,
    tolerate_vulnerable_model: Optional[bool] = None,
    tolerate_deprecated_model: Optional[bool] = None,
    sagemaker_session: Optional[Session] = None,
    accept_eula: Optional[bool] = None,
    model_reference_arn: Optional[str] = None,
    endpoint_logging: Optional[bool] = None,
    resources: Optional[ResourceRequirements] = None,
    managed_instance_scaling: Optional[str] = None,
    endpoint_type: Optional[EndpointType] = None,
    training_config_name: Optional[str] = None,
    config_name: Optional[str] = None,
    routing_config: Optional[Dict[str, Any]] = None,
    model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None,
    inference_ami_version: Optional[str] = None,
) -> JumpStartModelDeployKwargs:
    """Build the fully resolved kwargs for calling ``deploy`` on a JumpStart model.

    Caller-supplied values are packed into a ``JumpStartModelDeployKwargs`` object, and
    the model specs are fetched for ``model_version`` (``"*"`` when unset). The remaining
    fields are then resolved in this order:

    1. Config name, derived from ``training_config_name`` when given.
    2. Model version.
    3. SageMaker session (with custom user agent).
    4. Endpoint name.
    5. Instance type.
    6. Initial instance count (1 when not supplied).
    7. Default deploy kwargs from the specs.
    8. Tags.

    For ``EndpointType.INFERENCE_COMPONENT_BASED`` endpoints, resource requirements,
    ``endpoint_type`` and ``managed_instance_scaling`` are also set.

    Returns:
        JumpStartModelDeployKwargs: The resolved deploy kwargs.
    """
    kw: JumpStartModelDeployKwargs = JumpStartModelDeployKwargs(
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        model_type=model_type,
        region=region,
        initial_instance_count=initial_instance_count,
        instance_type=instance_type,
        serializer=serializer,
        deserializer=deserializer,
        accelerator_type=accelerator_type,
        endpoint_name=endpoint_name,
        inference_component_name=inference_component_name,
        tags=format_tags(tags),
        kms_key=kms_key,
        wait=wait,
        data_capture_config=data_capture_config,
        async_inference_config=async_inference_config,
        serverless_inference_config=serverless_inference_config,
        volume_size=volume_size,
        model_data_download_timeout=model_data_download_timeout,
        container_startup_health_check_timeout=container_startup_health_check_timeout,
        inference_recommendation_id=inference_recommendation_id,
        explainer_config=explainer_config,
        tolerate_deprecated_model=tolerate_deprecated_model,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        sagemaker_session=sagemaker_session,
        accept_eula=accept_eula,
        model_reference_arn=model_reference_arn,
        endpoint_logging=endpoint_logging,
        resources=resources,
        config_name=config_name,
        routing_config=routing_config,
        model_access_configs=model_access_configs,
        inference_ami_version=inference_ami_version,
    )
    kw, orig_session = _set_temp_sagemaker_session_if_not_set(kwargs=kw)

    # Both tolerate flags are forced to True so the JSON specs can always be retrieved;
    # exceptions are thrown later if these are not tolerated.
    spec_lookup_kwargs = get_model_info_default_kwargs(
        kw, include_model_version=False, include_tolerate_flags=False
    )
    kw.specs = verify_model_region_and_return_specs(
        **spec_lookup_kwargs,
        version=kw.model_version or "*",
        scope=JumpStartScriptScope.INFERENCE,
        tolerate_deprecated_model=True,
        tolerate_vulnerable_model=True,
    )

    kw = _add_config_name_to_deploy_kwargs(kwargs=kw, training_config_name=training_config_name)
    kw = _add_model_version_to_kwargs(kwargs=kw)
    kw = _add_sagemaker_session_with_custom_user_agent_to_kwargs(
        kwargs=kw, orig_session=orig_session
    )

    for resolve_step in (_add_endpoint_name_to_kwargs, _add_instance_type_to_kwargs):
        kw = resolve_step(kwargs=kw)

    kw.initial_instance_count = initial_instance_count or 1

    for resolve_step in (_add_deploy_extra_kwargs, _add_tags_to_kwargs):
        kw = resolve_step(kwargs=kw)

    if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
        kw = _add_resources_to_kwargs(kwargs=kw)
        kw.endpoint_type = endpoint_type
        kw.managed_instance_scaling = managed_instance_scaling

    return kw
713
+
714
+
715
def get_init_kwargs(
    model_id: str,
    model_from_estimator: bool = False,
    model_version: Optional[str] = None,
    hub_arn: Optional[str] = None,
    model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
    tolerate_vulnerable_model: Optional[bool] = None,
    tolerate_deprecated_model: Optional[bool] = None,
    instance_type: Optional[str] = None,
    region: Optional[str] = None,
    image_uri: Optional[Union[str, PipelineVariable]] = None,
    model_data: Optional[Union[str, PipelineVariable, dict]] = None,
    role: Optional[str] = None,
    env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
    name: Optional[str] = None,
    vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None,
    sagemaker_session: Optional[Session] = None,
    enable_network_isolation: Optional[Union[bool, PipelineVariable]] = None,
    model_kms_key: Optional[str] = None,
    image_config: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
    source_dir: Optional[str] = None,
    code_location: Optional[str] = None,
    entry_point: Optional[str] = None,
    container_log_level: Optional[Union[int, PipelineVariable]] = None,
    dependencies: Optional[List[str]] = None,
    git_config: Optional[Dict[str, str]] = None,
    model_package_arn: Optional[str] = None,
    training_instance_type: Optional[str] = None,
    disable_instance_type_logging: bool = False,
    resources: Optional[ResourceRequirements] = None,
    config_name: Optional[str] = None,
    additional_model_data_sources: Optional[List[Dict[str, Any]]] = None,
) -> JumpStartModelInitKwargs:
    """Build the fully resolved kwargs for instantiating a JumpStart model.

    Caller-supplied values are packed into a ``JumpStartModelInitKwargs`` object, and
    the model specs are fetched for ``model_version`` (``"*"`` when unset). Defaults are
    then resolved for:

    - vulnerability/deprecation status, model version and config name;
    - session and region;
    - model name and instance type;
    - image URI and hub model-reference info;
    - model data (skipped when ``model_from_estimator`` is true, because the training
      job's output artifact is used instead);
    - source dir, entry point, environment, extra model kwargs, role and model package ARN;
    - resource requirements and additional model data sources.

    Args:
        model_from_estimator (bool): Whether the model is built from an estimator's
            training output; if so, ``model_data`` is not defaulted here.
        disable_instance_type_logging (bool): Suppress the log message emitted when
            the instance type is defaulted.
        additional_model_data_sources (Optional[List[Dict[str, Any]]]): API-shaped
            additional model data sources. When unset, they default to the model's
            speculative-decoding data sources, if any.
        Other arguments are forwarded to ``JumpStartModelInitKwargs`` as-is.

    Returns:
        JumpStartModelInitKwargs: The resolved model init kwargs.
    """
    model_init_kwargs: JumpStartModelInitKwargs = JumpStartModelInitKwargs(
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        model_type=model_type,
        instance_type=instance_type,
        region=region,
        image_uri=image_uri,
        model_data=model_data,
        source_dir=source_dir,
        entry_point=entry_point,
        env=env,
        role=role,
        name=name,
        vpc_config=vpc_config,
        sagemaker_session=sagemaker_session,
        enable_network_isolation=enable_network_isolation,
        model_kms_key=model_kms_key,
        image_config=image_config,
        code_location=code_location,
        container_log_level=container_log_level,
        dependencies=dependencies,
        git_config=git_config,
        tolerate_deprecated_model=tolerate_deprecated_model,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        model_package_arn=model_package_arn,
        training_instance_type=training_instance_type,
        resources=resources,
        config_name=config_name,
        additional_model_data_sources=additional_model_data_sources,
    )
    model_init_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(
        kwargs=model_init_kwargs
    )
    model_init_kwargs.specs = verify_model_region_and_return_specs(
        **get_model_info_default_kwargs(
            model_init_kwargs, include_model_version=False, include_tolerate_flags=False
        ),
        version=model_init_kwargs.model_version or "*",
        scope=JumpStartScriptScope.INFERENCE,
        # We set these flags to True to retrieve the json specs.
        # Exceptions will be thrown later if these are not tolerated.
        tolerate_deprecated_model=True,
        tolerate_vulnerable_model=True,
    )

    model_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(kwargs=model_init_kwargs)
    model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs)
    model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs)

    model_init_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(
        kwargs=model_init_kwargs, orig_session=orig_session
    )
    model_init_kwargs = _add_region_to_kwargs(kwargs=model_init_kwargs)

    model_init_kwargs = _add_model_name_to_kwargs(kwargs=model_init_kwargs)

    model_init_kwargs = _add_instance_type_to_kwargs(
        kwargs=model_init_kwargs, disable_instance_type_logging=disable_instance_type_logging
    )

    model_init_kwargs = _add_image_uri_to_kwargs(kwargs=model_init_kwargs)

    if hub_arn:
        model_init_kwargs = _add_model_reference_arn_to_kwargs(kwargs=model_init_kwargs)
    else:
        # Without a hub there is no hub content to reference.
        model_init_kwargs.model_reference_arn = None
        model_init_kwargs.hub_content_type = None

    # we use the model artifact from the training job output
    if not model_from_estimator:
        model_init_kwargs = _add_model_data_to_kwargs(kwargs=model_init_kwargs)
    model_init_kwargs = _add_source_dir_to_kwargs(kwargs=model_init_kwargs)
    model_init_kwargs = _add_entry_point_to_kwargs(kwargs=model_init_kwargs)
    model_init_kwargs = _add_env_to_kwargs(kwargs=model_init_kwargs)
    model_init_kwargs = _add_extra_model_kwargs(kwargs=model_init_kwargs)
    model_init_kwargs = _add_role_to_kwargs(kwargs=model_init_kwargs)
    model_init_kwargs = _add_model_package_arn_to_kwargs(kwargs=model_init_kwargs)

    model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs)

    model_init_kwargs = _add_additional_model_data_sources_to_kwargs(kwargs=model_init_kwargs)

    return model_init_kwargs