sagemaker-core 1.0.62__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (362)
  1. sagemaker/core/__init__.py +16 -0
  2. sagemaker/core/_studio.py +116 -0
  3. sagemaker/core/_version.py +11 -0
  4. sagemaker/core/accept_types.py +131 -0
  5. sagemaker/core/analytics.py +744 -0
  6. sagemaker/core/apiutils/__init__.py +13 -0
  7. sagemaker/core/apiutils/_base_types.py +228 -0
  8. sagemaker/core/apiutils/_boto_functions.py +130 -0
  9. sagemaker/core/apiutils/_utils.py +34 -0
  10. sagemaker/core/base_deserializers.py +35 -0
  11. sagemaker/core/base_serializers.py +35 -0
  12. sagemaker/core/clarify/__init__.py +2898 -0
  13. sagemaker/core/collection.py +467 -0
  14. sagemaker/core/common_utils.py +2281 -0
  15. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  16. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  17. sagemaker/core/config/__init__.py +181 -0
  18. sagemaker/core/config/config.py +238 -0
  19. sagemaker/core/config/config_manager.py +595 -0
  20. sagemaker/core/config/config_schema.py +1220 -0
  21. sagemaker/core/config/config_utils.py +297 -0
  22. {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
  23. sagemaker/core/constants.py +73 -0
  24. sagemaker/core/content_types.py +137 -0
  25. sagemaker/core/debugger/__init__.py +39 -0
  26. sagemaker/core/debugger/debugger.py +945 -0
  27. sagemaker/core/debugger/framework_profile.py +292 -0
  28. sagemaker/core/debugger/metrics_config.py +468 -0
  29. sagemaker/core/debugger/profiler.py +42 -0
  30. sagemaker/core/debugger/profiler_config.py +190 -0
  31. sagemaker/core/debugger/profiler_constants.py +40 -0
  32. sagemaker/core/debugger/utils.py +148 -0
  33. sagemaker/core/deprecations.py +254 -0
  34. sagemaker/core/deserializers/__init__.py +10 -0
  35. sagemaker/core/deserializers/base.py +424 -0
  36. sagemaker/core/deserializers/implementations.py +157 -0
  37. sagemaker/core/drift_check_baselines.py +106 -0
  38. sagemaker/core/enums.py +51 -0
  39. sagemaker/core/environment_variables.py +101 -0
  40. sagemaker/core/exceptions.py +108 -0
  41. sagemaker/core/experiments/__init__.py +53 -0
  42. sagemaker/core/experiments/_api_types.py +251 -0
  43. sagemaker/core/experiments/_environment.py +124 -0
  44. sagemaker/core/experiments/_helper.py +294 -0
  45. sagemaker/core/experiments/_metrics.py +333 -0
  46. sagemaker/core/experiments/_run_context.py +58 -0
  47. sagemaker/core/experiments/_utils.py +216 -0
  48. sagemaker/core/experiments/experiment.py +244 -0
  49. sagemaker/core/experiments/run.py +970 -0
  50. sagemaker/core/experiments/trial.py +296 -0
  51. sagemaker/core/experiments/trial_component.py +387 -0
  52. sagemaker/core/explainer/__init__.py +24 -0
  53. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  54. sagemaker/core/explainer/explainer_config.py +44 -0
  55. sagemaker/core/fw_utils.py +1176 -0
  56. sagemaker/core/git_utils.py +349 -0
  57. sagemaker/core/helper/pipeline_variable.py +82 -0
  58. sagemaker/core/helper/session_helper.py +2965 -0
  59. sagemaker/core/huggingface/__init__.py +29 -0
  60. sagemaker/core/huggingface/llm_utils.py +150 -0
  61. sagemaker/core/huggingface/processing.py +139 -0
  62. sagemaker/core/huggingface/training_compiler/config.py +167 -0
  63. sagemaker/core/hyperparameters.py +172 -0
  64. sagemaker/core/image_retriever/__init__.py +3 -0
  65. sagemaker/core/image_retriever/image_retriever.py +640 -0
  66. sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
  67. sagemaker/core/image_retriever/test.py +7 -0
  68. sagemaker/core/image_uri_config/__init__.py +13 -0
  69. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  70. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  71. sagemaker/core/image_uri_config/chainer.json +104 -0
  72. sagemaker/core/image_uri_config/clarify.json +39 -0
  73. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  74. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  75. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  76. sagemaker/core/image_uri_config/debugger.json +34 -0
  77. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  78. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  79. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  80. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  81. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  82. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  83. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  84. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  85. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
  86. sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
  87. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  88. sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
  89. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  90. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  91. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  92. sagemaker/core/image_uri_config/huggingface.json +2138 -0
  93. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  94. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  95. sagemaker/core/image_uri_config/image-classification.json +50 -0
  96. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  97. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  98. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  99. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  100. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  101. sagemaker/core/image_uri_config/kmeans.json +50 -0
  102. sagemaker/core/image_uri_config/knn.json +50 -0
  103. sagemaker/core/image_uri_config/lda.json +26 -0
  104. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  105. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  106. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  107. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  108. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  109. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  110. sagemaker/core/image_uri_config/ntm.json +50 -0
  111. sagemaker/core/image_uri_config/object-detection.json +50 -0
  112. sagemaker/core/image_uri_config/object2vec.json +50 -0
  113. sagemaker/core/image_uri_config/pca.json +50 -0
  114. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  115. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  116. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  117. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  118. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  119. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  120. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  121. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  122. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  123. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  124. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
  125. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  126. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  127. sagemaker/core/image_uri_config/sklearn.json +446 -0
  128. sagemaker/core/image_uri_config/spark.json +280 -0
  129. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  130. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  131. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  132. sagemaker/core/image_uri_config/vw.json +25 -0
  133. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  134. sagemaker/core/image_uri_config/xgboost.json +888 -0
  135. sagemaker/core/image_uris.py +810 -0
  136. sagemaker/core/inference_config.py +144 -0
  137. sagemaker/core/inference_recommender/__init__.py +18 -0
  138. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  139. sagemaker/core/inputs.py +366 -0
  140. sagemaker/core/instance_group.py +61 -0
  141. sagemaker/core/instance_types.py +164 -0
  142. sagemaker/core/instance_types_gpu_info.py +43 -0
  143. sagemaker/core/interactive_apps/__init__.py +41 -0
  144. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  145. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  146. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  147. sagemaker/core/iterators.py +186 -0
  148. sagemaker/core/job.py +380 -0
  149. sagemaker/core/jumpstart/__init__.py +156 -0
  150. sagemaker/core/jumpstart/accessors.py +390 -0
  151. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  152. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  153. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  154. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  155. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  156. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  157. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  158. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  159. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  160. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  161. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  162. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  163. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  164. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  165. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  166. sagemaker/core/jumpstart/cache.py +663 -0
  167. sagemaker/core/jumpstart/configs.py +50 -0
  168. sagemaker/core/jumpstart/constants.py +198 -0
  169. sagemaker/core/jumpstart/deserializers.py +81 -0
  170. sagemaker/core/jumpstart/document.py +76 -0
  171. sagemaker/core/jumpstart/enums.py +168 -0
  172. sagemaker/core/jumpstart/exceptions.py +236 -0
  173. sagemaker/core/jumpstart/factory/utils.py +833 -0
  174. sagemaker/core/jumpstart/filters.py +597 -0
  175. sagemaker/core/jumpstart/hub/constants.py +16 -0
  176. sagemaker/core/jumpstart/hub/hub.py +291 -0
  177. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  178. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  179. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  180. sagemaker/core/jumpstart/hub/types.py +35 -0
  181. sagemaker/core/jumpstart/hub/utils.py +260 -0
  182. sagemaker/core/jumpstart/models.py +499 -0
  183. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  184. sagemaker/core/jumpstart/parameters.py +20 -0
  185. sagemaker/core/jumpstart/payload_utils.py +239 -0
  186. sagemaker/core/jumpstart/region_config.json +163 -0
  187. sagemaker/core/jumpstart/search.py +171 -0
  188. sagemaker/core/jumpstart/serializers.py +81 -0
  189. sagemaker/core/jumpstart/session_utils.py +234 -0
  190. sagemaker/core/jumpstart/types.py +3044 -0
  191. sagemaker/core/jumpstart/utils.py +1731 -0
  192. sagemaker/core/jumpstart/validators.py +257 -0
  193. sagemaker/core/lambda_helper.py +312 -0
  194. sagemaker/core/lineage/__init__.py +42 -0
  195. sagemaker/core/lineage/_api_types.py +239 -0
  196. sagemaker/core/lineage/_utils.py +49 -0
  197. sagemaker/core/lineage/action.py +345 -0
  198. sagemaker/core/lineage/artifact.py +646 -0
  199. sagemaker/core/lineage/association.py +190 -0
  200. sagemaker/core/lineage/context.py +505 -0
  201. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  202. sagemaker/core/lineage/query.py +732 -0
  203. sagemaker/core/lineage/visualizer.py +346 -0
  204. sagemaker/core/local/__init__.py +18 -0
  205. sagemaker/core/local/data.py +413 -0
  206. sagemaker/core/local/entities.py +678 -0
  207. sagemaker/core/local/exceptions.py +17 -0
  208. sagemaker/core/local/image.py +1243 -0
  209. sagemaker/core/local/local_session.py +739 -0
  210. sagemaker/core/local/utils.py +245 -0
  211. sagemaker/core/logs.py +181 -0
  212. sagemaker/core/metadata_properties.py +56 -0
  213. sagemaker/core/metric_definitions.py +91 -0
  214. sagemaker/core/mlflow/__init__.py +38 -0
  215. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  216. sagemaker/core/model_card/__init__.py +26 -0
  217. sagemaker/core/model_life_cycle.py +51 -0
  218. sagemaker/core/model_metrics.py +160 -0
  219. sagemaker/core/model_monitor/__init__.py +66 -0
  220. sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
  221. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  222. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  223. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  224. sagemaker/core/model_monitor/dataset_format.py +102 -0
  225. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  226. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  227. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  228. sagemaker/core/model_monitor/utils.py +793 -0
  229. sagemaker/core/model_registry.py +480 -0
  230. sagemaker/core/model_uris.py +97 -0
  231. sagemaker/core/modules/__init__.py +19 -0
  232. sagemaker/core/modules/configs.py +226 -0
  233. sagemaker/core/modules/constants.py +37 -0
  234. sagemaker/core/modules/distributed.py +182 -0
  235. sagemaker/core/modules/local_core/__init__.py +0 -0
  236. sagemaker/core/modules/local_core/local_container.py +605 -0
  237. sagemaker/core/modules/templates.py +83 -0
  238. sagemaker/core/modules/train/__init__.py +14 -0
  239. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  240. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  241. sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
  242. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  243. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  244. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  245. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  246. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  247. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  248. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  249. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  250. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  251. sagemaker/core/modules/types.py +19 -0
  252. sagemaker/core/modules/utils.py +194 -0
  253. sagemaker/core/network.py +185 -0
  254. sagemaker/core/parameter.py +173 -0
  255. sagemaker/core/payloads.py +185 -0
  256. sagemaker/core/processing.py +1597 -0
  257. sagemaker/core/remote_function/__init__.py +19 -0
  258. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  259. sagemaker/core/remote_function/client.py +1285 -0
  260. sagemaker/core/remote_function/core/__init__.py +0 -0
  261. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  262. sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
  263. sagemaker/core/remote_function/core/serialization.py +422 -0
  264. sagemaker/core/remote_function/core/stored_function.py +226 -0
  265. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  266. sagemaker/core/remote_function/errors.py +104 -0
  267. sagemaker/core/remote_function/invoke_function.py +172 -0
  268. sagemaker/core/remote_function/job.py +2140 -0
  269. sagemaker/core/remote_function/logging_config.py +38 -0
  270. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  271. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  272. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  273. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  274. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  275. sagemaker/core/remote_function/spark_config.py +149 -0
  276. sagemaker/core/resource_requirements.py +168 -0
  277. {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
  278. sagemaker/core/s3/__init__.py +41 -0
  279. sagemaker/core/s3/client.py +367 -0
  280. sagemaker/core/s3/utils.py +175 -0
  281. sagemaker/core/script_uris.py +93 -0
  282. sagemaker/core/serializers/__init__.py +11 -0
  283. sagemaker/core/serializers/base.py +510 -0
  284. sagemaker/core/serializers/implementations.py +159 -0
  285. sagemaker/core/serializers/utils.py +223 -0
  286. sagemaker/core/serverless_inference_config.py +63 -0
  287. sagemaker/core/session_settings.py +55 -0
  288. sagemaker/core/shapes/__init__.py +3 -0
  289. sagemaker/core/shapes/model_card_shapes.py +159 -0
  290. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
  291. sagemaker/core/spark/__init__.py +16 -0
  292. sagemaker/core/spark/defaults.py +16 -0
  293. sagemaker/core/spark/processing.py +1380 -0
  294. sagemaker/core/telemetry/__init__.py +23 -0
  295. sagemaker/core/telemetry/constants.py +84 -0
  296. sagemaker/core/telemetry/telemetry_logging.py +284 -0
  297. sagemaker/core/tools/__init__.py +1 -0
  298. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  299. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  300. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  301. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  302. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  303. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  304. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  305. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  306. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  307. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  308. sagemaker/core/training/__init__.py +14 -0
  309. sagemaker/core/training/configs.py +333 -0
  310. sagemaker/core/training/constants.py +37 -0
  311. sagemaker/core/training/utils.py +77 -0
  312. sagemaker/core/training_compiler/__init__.py +16 -0
  313. sagemaker/core/training_compiler/config.py +197 -0
  314. sagemaker/core/training_compiler_config.py +197 -0
  315. sagemaker/core/transformer.py +793 -0
  316. sagemaker/core/user_agent.py +76 -0
  317. sagemaker/core/utilities/__init__.py +24 -0
  318. sagemaker/core/utilities/cache.py +169 -0
  319. sagemaker/core/utilities/search_expression.py +133 -0
  320. sagemaker/core/utils/__init__.py +48 -0
  321. sagemaker/core/utils/code_injection/__init__.py +0 -0
  322. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  323. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
  324. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  325. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  326. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  327. {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
  328. sagemaker/core/workflow/__init__.py +152 -0
  329. sagemaker/core/workflow/conditions.py +313 -0
  330. sagemaker/core/workflow/entities.py +58 -0
  331. sagemaker/core/workflow/execution_variables.py +89 -0
  332. sagemaker/core/workflow/functions.py +193 -0
  333. sagemaker/core/workflow/parameters.py +222 -0
  334. sagemaker/core/workflow/pipeline_context.py +394 -0
  335. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  336. sagemaker/core/workflow/properties.py +285 -0
  337. sagemaker/core/workflow/step_outputs.py +65 -0
  338. sagemaker/core/workflow/utilities.py +507 -0
  339. sagemaker/lineage/__init__.py +33 -0
  340. sagemaker/lineage/action.py +28 -0
  341. sagemaker/lineage/artifact.py +28 -0
  342. sagemaker/lineage/context.py +28 -0
  343. sagemaker/lineage/lineage_trial_component.py +28 -0
  344. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
  345. sagemaker_core-2.1.1.dist-info/RECORD +355 -0
  346. sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
  347. sagemaker_core/_version.py +0 -3
  348. sagemaker_core/helper/session_helper.py +0 -769
  349. sagemaker_core/resources/__init__.py +0 -1
  350. sagemaker_core/shapes/__init__.py +0 -1
  351. sagemaker_core/tools/__init__.py +0 -1
  352. sagemaker_core-1.0.62.dist-info/RECORD +0 -35
  353. sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
  354. {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
  355. {sagemaker_core/helper → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
  356. {sagemaker_core/main → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  357. {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
  358. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  359. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  360. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  361. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
  362. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,3044 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ # pylint: skip-file
14
+ """This module stores types related to SageMaker JumpStart."""
15
+ from __future__ import absolute_import
16
+ from __future__ import annotations
17
+ import re
18
+ from copy import deepcopy
19
+ from enum import Enum
20
+ from typing import Any, Callable, Dict, List, Optional, Set, Union, TYPE_CHECKING
21
+ from sagemaker.core.shapes import ModelAccessConfig as CoreModelAccessConfig
22
+ from sagemaker.core.shapes.model_card_shapes import ModelCardContent as ModelCard
23
+ from sagemaker.core.shapes.model_card_shapes import ModelCardContent as ModelPackageModelCard
24
+ from sagemaker.core.common_utils import (
25
+ S3_PREFIX,
26
+ get_instance_type_family,
27
+ format_tags,
28
+ Tags,
29
+ deep_override_dict,
30
+ camel_to_snake,
31
+ walk_and_apply_json,
32
+ )
33
+ from sagemaker.core.model_metrics import ModelMetrics
34
+ from sagemaker.core.metadata_properties import MetadataProperties
35
+ from sagemaker.core.drift_check_baselines import DriftCheckBaselines
36
+ from sagemaker.core.jumpstart.enums import (
37
+ JumpStartModelType,
38
+ JumpStartScriptScope,
39
+ JumpStartConfigRankingName,
40
+ )
41
+
42
+ from sagemaker.core.helper.session_helper import Session
43
+ from sagemaker.core.helper.pipeline_variable import PipelineVariable
44
+ from sagemaker.core.enums import EndpointType
45
+
46
+ from sagemaker.core.model_life_cycle import ModelLifeCycle
47
+
48
+ if TYPE_CHECKING:
49
+ from sagemaker.core.resource_requirements import ResourceRequirements
50
+
51
+
52
class JumpStartDataHolderType:
    """Base class for many JumpStart types.

    Allows objects to be added to dicts and sets, and improves string
    representation. This class overrides the ``__eq__`` and ``__hash__``
    methods so that different objects with the same attributes/types can be
    compared.

    Subclasses declare their fields in ``__slots__``. A slot may be left
    unassigned (several subclasses only assign optional fields when a value is
    present); equality, hashing and the string representations all tolerate
    that. Slots listed in ``_non_serializable_slots`` are left out of
    ``__str__`` and ``__repr__``.
    """

    __slots__: List[str] = []

    _non_serializable_slots: List[str] = []

    def __eq__(self, other: Any) -> bool:
        """Returns True if ``other`` is of the same type and has all attributes equal.

        Objects are equal when they declare the same slots, have the same set
        of slots assigned, and every assigned slot compares equal.

        Args:
            other (Any): Other object to which to compare this object.
        """
        if not isinstance(other, type(self)):
            return False
        if getattr(other, "__slots__", None) is None:
            return False
        if self.__slots__ != other.__slots__:
            return False
        for attribute in self.__slots__:
            self_has_attribute = hasattr(self, attribute)
            if self_has_attribute != hasattr(other, attribute):
                # Exactly one side has this slot assigned.
                return False
            if self_has_attribute and getattr(self, attribute) != getattr(other, attribute):
                return False
        return True

    def __hash__(self) -> int:
        """Makes hash of object.

        Maps the object to a tuple of its type and slot values, which then gets
        hashed. Unassigned slots contribute ``None`` instead of raising
        ``AttributeError``. This keeps the hash consistent with ``__eq__``,
        since equal objects have the same slots assigned with equal values.

        Raises:
            TypeError: if an assigned slot holds an unhashable value (e.g. a list).
        """
        return hash(
            (type(self),) + tuple(getattr(self, att, None) for att in self.__slots__)
        )

    def __str__(self) -> str:
        """Returns string representation of object. Example:

        "JumpStartLaunchedRegionInfo:
        {'content_bucket': 'bucket', 'region_name': 'us-west-2'}"
        """
        att_dict = {
            att: getattr(self, att)
            for att in self.__slots__
            if hasattr(self, att) and att not in self._non_serializable_slots
        }
        return f"{type(self).__name__}: {str(att_dict)}"

    def __repr__(self) -> str:
        """Returns ``__repr__`` string of object. Example:

        "JumpStartLaunchedRegionInfo at 0x7f664529efa0:
        {'content_bucket': 'bucket', 'region_name': 'us-west-2'}"
        """
        att_dict = {
            att: getattr(self, att)
            for att in self.__slots__
            if hasattr(self, att) and att not in self._non_serializable_slots
        }
        return f"{type(self).__name__} at {hex(id(self))}: {str(att_dict)}"
123
+
124
+
125
class JumpStartS3FileType(str, Enum):
    """File kinds found in the JumpStart S3 distribution buckets.

    Members subclass ``str``, so each compares equal to its raw string value.
    """

    OPEN_WEIGHT_MANIFEST = "manifest"
    OPEN_WEIGHT_SPECS = "specs"
    PROPRIETARY_MANIFEST = "proprietary_manifest"
    PROPRIETARY_SPECS = "proprietary_specs"


class HubType(str, Enum):
    """Hub object kinds (currently only the hub itself)."""

    HUB = "Hub"


class HubContentType(str, Enum):
    """Kinds of content that can live inside a hub."""

    MODEL = "Model"
    NOTEBOOK = "Notebook"
    MODEL_REFERENCE = "ModelReference"


# Any of the enums above may be used to identify a JumpStart content/data type.
JumpStartContentDataType = Union[JumpStartS3FileType, HubType, HubContentType]
149
+
150
+
151
class JumpStartLaunchedRegionInfo(JumpStartDataHolderType):
    """S3 bucket names JumpStart associates with a single launched region."""

    __slots__ = ["content_bucket", "region_name", "gated_content_bucket", "neo_content_bucket"]

    def __init__(
        self,
        content_bucket: str,
        region_name: str,
        gated_content_bucket: Optional[str] = None,
        neo_content_bucket: Optional[str] = None,
    ):
        """Instantiates JumpStartLaunchedRegionInfo object.

        Args:
            content_bucket (str): JumpStart S3 content bucket for the region.
            region_name (str): Name of the region in which JumpStart is launched.
            gated_content_bucket (Optional[str]): JumpStart gated S3 content
                bucket for the region, if there is one. Defaults to None.
            neo_content_bucket (Optional[str]): Neo service S3 content bucket for
                the region, if there is one. Defaults to None.
        """
        # Every slot is always assigned, optional buckets default to None.
        self.region_name = region_name
        self.content_bucket = content_bucket
        self.gated_content_bucket = gated_content_bucket
        self.neo_content_bucket = neo_content_bucket
177
+
178
+
179
class JumpStartModelHeader(JumpStartDataHolderType):
    """Header record for a JumpStart model, built from its JSON form."""

    __slots__ = ["model_id", "version", "min_version", "spec_key", "search_keywords"]

    def __init__(self, header: Dict[str, str]):
        """Builds the header from its JSON (dict) representation.

        Args:
            header (Dict[str, str]): Dictionary representation of the header;
                see ``from_json`` for the expected keys.
        """
        self.from_json(header)

    def to_json(self) -> Dict[str, str]:
        """Returns the header as a dict, skipping fields that are unset or None."""
        serialized: Dict[str, str] = {}
        for field_name in self.__slots__:
            field_value = getattr(self, field_name, None)
            if field_value is None:
                continue
            serialized[field_name] = field_value
        return serialized

    def from_json(self, json_obj: Dict[str, str]) -> None:
        """Populates the header fields from ``json_obj``.

        Args:
            json_obj (Dict[str, str]): Dictionary representation of header. It
                must contain ``model_id``, ``version``, ``min_version`` and
                ``spec_key``. ``search_keywords`` is optional and becomes None
                when absent.

        Raises:
            KeyError: if one of the required keys is missing.
        """
        # Required keys, read in this order so the first missing one is reported.
        self.model_id: str = json_obj["model_id"]
        self.version: str = json_obj["version"]
        self.min_version: str = json_obj["min_version"]
        self.spec_key: str = json_obj["spec_key"]
        # Optional key.
        self.search_keywords: Optional[List[str]] = json_obj.get("search_keywords")
212
+
213
+
214
class JumpStartECRSpecs(JumpStartDataHolderType):
    """ECR image specs (framework, framework/python versions) of a JumpStart model."""

    __slots__ = [
        "framework",
        "framework_version",
        "py_version",
        "huggingface_transformers_version",
        "_is_hub_content",
    ]

    # Internal flag only; never emitted by ``to_json`` or the string forms.
    _non_serializable_slots = ["_is_hub_content"]

    def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False):
        """Initializes a JumpStartECRSpecs object from its json representation.

        Args:
            spec (Dict[str, Any]): Dictionary representation of spec.
            is_hub_content (Optional[bool]): Whether ``spec`` comes from Hub
                content, in which case its keys are normalized with
                ``walk_and_apply_json(..., camel_to_snake)`` before reading.
                Defaults to False.
        """
        self._is_hub_content = is_hub_content
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        A falsy ``json_obj`` (None or empty) leaves every field unassigned.
        ``huggingface_transformers_version`` is only assigned when present and
        not None; the other three fields are always assigned (possibly None).

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of spec.
        """
        if not json_obj:
            return

        spec = walk_and_apply_json(json_obj, camel_to_snake) if self._is_hub_content else json_obj

        self.framework = spec.get("framework")
        self.framework_version = spec.get("framework_version")
        self.py_version = spec.get("py_version")

        transformers_version = spec.get("huggingface_transformers_version")
        if transformers_version is not None:
            self.huggingface_transformers_version = transformers_version

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartECRSpecs object.

        Only assigned slots are included; non-serializable slots are skipped.
        """
        excluded = self._non_serializable_slots
        return {
            slot: getattr(self, slot)
            for slot in self.__slots__
            if slot not in excluded and hasattr(self, slot)
        }
264
+
265
+
266
class JumpStartHyperparameter(JumpStartDataHolderType):
    """Definition of one hyperparameter of a JumpStart training container."""

    __slots__ = [
        "name",
        "type",
        "options",
        "default",
        "scope",
        "min",
        "max",
        "exclusive_min",
        "exclusive_max",
        "_is_hub_content",
    ]

    # Internal flag only; never emitted by ``to_json`` or the string forms.
    _non_serializable_slots = ["_is_hub_content"]

    def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False):
        """Initializes a JumpStartHyperparameter object from its json representation.

        Args:
            spec (Dict[str, Any]): Dictionary representation of hyperparameter.
            is_hub_content (Optional[bool]): Whether ``spec`` comes from Hub
                content, in which case its keys are normalized with
                ``walk_and_apply_json(..., camel_to_snake)`` before reading.
                Defaults to False.
        """
        self._is_hub_content = is_hub_content
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        ``name``, ``type``, ``default`` and ``scope`` are required. ``options``,
        ``min`` and ``max`` are assigned only when present and not None.
        ``exclusive_min``/``exclusive_max`` are read the same way, but only for
        non-Hub content.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of hyperparameter.

        Raises:
            KeyError: if one of the required keys is missing.
        """
        spec = walk_and_apply_json(json_obj, camel_to_snake) if self._is_hub_content else json_obj

        self.name = spec["name"]
        self.type = spec["type"]
        self.default = spec["default"]
        self.scope = spec["scope"]

        optional_fields = ["options", "min", "max"]
        # HubContentDocument model schema does not allow exclusive min/max.
        if not self._is_hub_content:
            optional_fields += ["exclusive_min", "exclusive_max"]

        for field_name in optional_fields:
            field_value = spec.get(field_name)
            if field_value is not None:
                setattr(self, field_name, field_value)

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartHyperparameter object.

        Only assigned slots are included; non-serializable slots are skipped.
        """
        excluded = self._non_serializable_slots
        return {
            slot: getattr(self, slot)
            for slot in self.__slots__
            if slot not in excluded and hasattr(self, slot)
        }
338
+
339
+
340
class JumpStartEnvironmentVariable(JumpStartDataHolderType):
    """Data class for JumpStart environment variable definitions in the hosting container."""

    __slots__ = [
        "name",
        "type",
        "default",
        "scope",
        "required_for_model_class",
        "_is_hub_content",
    ]

    _non_serializable_slots = ["_is_hub_content"]

    def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False):
        """Builds an environment variable definition from its dictionary form.

        Args:
            spec (Dict[str, Any]): Dictionary representation of environment variable.
            is_hub_content (Optional[bool]): Stored on the instance; not consulted
                by ``from_json`` (see note there).
        """
        self._is_hub_content = is_hub_content
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Assigns the environment variable fields from a dictionary.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of environment variable.

        Raises:
            KeyError: If ``name``, ``type``, ``default`` or ``scope`` is missing.
        """
        # NOTE: unlike sibling classes, keys are normalized with camel_to_snake
        # regardless of ``_is_hub_content``.
        normalized = walk_and_apply_json(json_obj, camel_to_snake)
        for field in ("name", "type", "default", "scope"):
            setattr(self, field, normalized[field])
        self.required_for_model_class: bool = normalized.get("required_for_model_class", False)

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartEnvironmentVariable object."""
        excluded = getattr(self, "_non_serializable_slots", [])
        return {
            slot: getattr(self, slot)
            for slot in self.__slots__
            if slot not in excluded and hasattr(self, slot)
        }
384
+
385
+
386
class JumpStartPredictorSpecs(JumpStartDataHolderType):
    """Data class for JumpStart Predictor specs (default/supported content and accept types)."""

    __slots__ = [
        "default_content_type",
        "supported_content_types",
        "default_accept_type",
        "supported_accept_types",
        "_is_hub_content",
    ]

    _non_serializable_slots = ["_is_hub_content"]

    def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content: Optional[bool] = False):
        """Builds predictor specs from their dictionary form.

        Args:
            spec (Dict[str, Any]): Dictionary representation of predictor specs.
            is_hub_content (Optional[bool]): Whether ``spec`` comes from hub content;
                if so its keys are normalized with ``camel_to_snake``.
        """
        self._is_hub_content = is_hub_content
        self.from_json(spec)

    def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
        """Assigns the predictor spec fields from a dictionary.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of predictor specs.
                ``None`` leaves the fields unassigned.

        Raises:
            KeyError: If any of the four content/accept type keys is missing.
        """
        if json_obj is not None:
            specs = (
                walk_and_apply_json(json_obj, camel_to_snake) if self._is_hub_content else json_obj
            )
            for field in (
                "default_content_type",
                "supported_content_types",
                "default_accept_type",
                "supported_accept_types",
            ):
                setattr(self, field, specs[field])

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartPredictorSpecs object."""
        excluded = getattr(self, "_non_serializable_slots", [])
        return {
            slot: getattr(self, slot)
            for slot in self.__slots__
            if slot not in excluded and hasattr(self, slot)
        }
433
+
434
+
435
class JumpStartSerializablePayload(JumpStartDataHolderType):
    """Data class for JumpStart serialized payload specs."""

    __slots__ = [
        "raw_payload",
        "content_type",
        "accept",
        "body",
        "prompt_key",
        "_is_hub_content",
    ]

    _non_serializable_slots = ["raw_payload", "prompt_key", "_is_hub_content"]

    def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content: Optional[bool] = False):
        """Builds a payload spec from its dictionary form.

        Args:
            spec (Dict[str, Any]): Dictionary representation of payload specs.
            is_hub_content (Optional[bool]): Whether ``spec`` comes from hub content;
                if so its keys are normalized with ``camel_to_snake``.
        """
        self._is_hub_content = is_hub_content
        self.from_json(spec)

    def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
        """Assigns the payload fields from a dictionary.

        The (possibly normalized) dictionary itself is kept as ``raw_payload``
        and is what ``to_json`` returns a copy of.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of serializable
                payload specs. ``None`` leaves the fields unassigned.

        Raises:
            KeyError: If the dictionary is missing keys.
        """
        if json_obj is None:
            return

        payload = walk_and_apply_json(json_obj, camel_to_snake) if self._is_hub_content else json_obj
        self.raw_payload = payload
        self.content_type = payload["content_type"]
        self.body = payload.get("body")
        self.prompt_key = payload.get("prompt_key")

        # ``accept`` is only assigned when truthy.
        accept_value = payload.get("accept")
        if accept_value:
            self.accept = accept_value

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartSerializablePayload object."""
        payload_copy = deepcopy(self.raw_payload)
        return payload_copy
485
+
486
+
487
class JumpStartInstanceTypeVariants(JumpStartDataHolderType):
    """Data class for JumpStart instance type variants.

    ``variants`` maps an instance type, or an instance type family as returned by
    ``get_instance_type_family``, to a dict that may hold ``properties`` and/or
    ``regional_properties``. Lookups prefer the instance-type entry and fall back
    to, or merge with, the family entry.

    Two input formats are supported:

    * JumpStart metadata JSON (``from_json``). ``regional_aliases`` maps a region
      to alias values, and ``regional_properties`` values reference those aliases
      with a leading ``$``.
    * DescribeHubContent responses (``from_describe_hub_content_response``).
      Keys are normalized with ``camel_to_snake``, and ``aliases`` is populated
      instead of ``regional_aliases``.
    """

    __slots__ = [
        "regional_aliases",
        "aliases",
        "variants",
        "_is_hub_content",
    ]

    _non_serializable_slots = ["_is_hub_content"]

    def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content: Optional[bool] = False):
        """Initializes a JumpStartInstanceTypeVariants object from its json representation.

        Args:
            spec (Dict[str, Any]): Dictionary representation of instance type variants.
            is_hub_content (Optional[bool]): Whether ``spec`` is a DescribeHubContent
                response rather than JumpStart metadata JSON.
        """

        self._is_hub_content = is_hub_content

        if self._is_hub_content:
            self.from_describe_hub_content_response(spec)
        else:
            self.from_json(spec)

    def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
        """Sets fields in object based on JumpStart metadata json.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of instance type variants.
                ``None`` leaves the fields unassigned.
        """

        if json_obj is None:
            return

        self.aliases = None
        self.regional_aliases: Optional[dict] = json_obj.get("regional_aliases")
        self.variants: Optional[dict] = json_obj.get("variants")

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartInstance object."""
        json_obj = {
            att: getattr(self, att)
            for att in self.__slots__
            if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", [])
        }
        return json_obj

    def from_describe_hub_content_response(self, response: Optional[Dict[str, Any]]) -> None:
        """Sets fields in object based on DescribeHubContent response.

        Args:
            response (Dict[str, Any]): Dictionary representation of instance type variants.
                ``None`` leaves the fields unassigned.
        """

        if response is None:
            return

        response = walk_and_apply_json(response, camel_to_snake)
        self.aliases: Optional[dict] = response.get("aliases")
        self.regional_aliases = None
        self.variants: Optional[dict] = response.get("variants")

    def regionalize(self, region: str) -> Optional[Dict[str, Any]]:
        """Returns regionalized instance type variants.

        Only applies to metadata loaded with ``from_json``. Returns ``None`` when
        there are no regional aliases or when ``aliases`` (hub content) is set.

        Args:
            region (str): Region whose aliases are returned.

        Returns:
            Optional[Dict[str, Any]]: ``{"Aliases": ..., "Variants": ...}``. Each
            instance type in ``Variants`` maps to its ``properties`` when present,
            otherwise to its ``regional_properties``.
        """

        if self.regional_aliases is None or self.aliases is not None:
            return None
        aliases = self.regional_aliases.get(region, {})
        variants = {}
        # ``variants`` may be missing from the metadata; treat that as no variants
        # instead of failing on ``None.items()``.
        for instance_name, properties in (self.variants or {}).items():
            if properties.get("regional_properties") is not None:
                variants.update({instance_name: properties.get("regional_properties")})
            # ``properties`` wins over ``regional_properties`` when both exist.
            if properties.get("properties") is not None:
                variants.update({instance_name: properties.get("properties")})
        return {"Aliases": aliases, "Variants": variants}

    def get_instance_specific_metric_definitions(
        self, instance_type: str
    ) -> List[Dict[str, Any]]:
        """Returns instance specific metric definitions.

        Instance-type definitions are returned first. Instance-family definitions
        whose ``Name`` is not already present are appended after them.

        Returns empty list if a model, instance type tuple does not have specific
        metric definitions.
        """

        if self.variants is None:
            return []

        instance_specific_metric_definitions: List[Dict[str, Union[str, Any]]] = (
            self.variants.get(instance_type, {}).get("properties", {}).get("metrics", [])
        )

        instance_type_family = get_instance_type_family(instance_type)

        instance_family_metric_definitions: List[Dict[str, Union[str, Any]]] = (
            self.variants.get(instance_type_family, {}).get("properties", {}).get("metrics", [])
            if instance_type_family not in {"", None}
            else []
        )

        instance_specific_metric_names = {
            metric_definition["Name"] for metric_definition in instance_specific_metric_definitions
        }

        metric_definitions_to_return = deepcopy(instance_specific_metric_definitions)

        for instance_family_metric_definition in instance_family_metric_definitions:
            if instance_family_metric_definition["Name"] not in instance_specific_metric_names:
                metric_definitions_to_return.append(instance_family_metric_definition)

        return metric_definitions_to_return

    def get_instance_specific_prepacked_artifact_key(self, instance_type: str) -> Optional[str]:
        """Returns instance specific prepacked model artifact key.

        Returns None if a model, instance type tuple does not have specific
        artifact key.
        """

        return self._get_instance_specific_property(
            instance_type=instance_type, property_name="prepacked_artifact_key"
        )

    def get_instance_specific_artifact_key(self, instance_type: str) -> Optional[str]:
        """Returns instance specific model artifact key.

        Returns None if a model, instance type tuple does not have specific
        artifact key.
        """

        return self._get_instance_specific_property(
            instance_type=instance_type, property_name="artifact_key"
        )

    def get_instance_specific_training_artifact_key(self, instance_type: str) -> Optional[str]:
        """Returns instance specific training artifact key.

        Looks up ``training_artifact_uri`` first and falls back to ``training_artifact_key``.
        Returns None if a model, instance type tuple does not have specific
        training artifact key.
        """

        return self._get_instance_specific_property(
            instance_type=instance_type, property_name="training_artifact_uri"
        ) or self._get_instance_specific_property(
            instance_type=instance_type, property_name="training_artifact_key"
        )

    def get_instance_specific_resource_requirements(self, instance_type: str) -> Dict[str, Any]:
        """Returns instance specific resource requirements.

        If a value exists for both the instance family and instance type, the instance type value
        is chosen. Returns an empty dict when no variants are defined.
        """

        # Guard like the sibling getters; previously ``None.get`` raised AttributeError.
        if self.variants is None:
            return {}

        instance_specific_resource_requirements: dict = (
            self.variants.get(instance_type, {})
            .get("properties", {})
            .get("resource_requirements", {})
        )

        instance_type_family = get_instance_type_family(instance_type)

        instance_family_resource_requirements: dict = (
            self.variants.get(instance_type_family, {})
            .get("properties", {})
            .get("resource_requirements", {})
            if instance_type_family not in {"", None}
            else {}
        )

        return {**instance_family_resource_requirements, **instance_specific_resource_requirements}

    def _get_instance_specific_property(
        self, instance_type: str, property_name: str
    ) -> Optional[str]:
        """Returns instance specific property.

        If a truthy value exists for the instance type it is returned; otherwise the
        instance family value is returned.

        Returns None if a (model, instance type, property name) tuple does not have
        the requested property.
        """

        if self.variants is None:
            return None

        instance_specific_property: Optional[str] = (
            self.variants.get(instance_type, {}).get("properties", {}).get(property_name, None)
        )

        if instance_specific_property:
            return instance_specific_property

        instance_type_family = get_instance_type_family(instance_type)

        instance_family_property: Optional[str] = (
            self.variants.get(instance_type_family, {})
            .get("properties", {})
            .get(property_name, None)
            if instance_type_family not in {"", None}
            else None
        )

        return instance_family_property

    def get_instance_specific_hyperparameters(
        self, instance_type: str
    ) -> List[JumpStartHyperparameter]:
        """Returns instance specific hyperparameters.

        Instance-type hyperparameters come first. Instance-family hyperparameters
        whose name is not already present are appended after them.

        Returns empty list if a model, instance type tuple does not have specific
        hyperparameters.
        """

        if self.variants is None:
            return []

        instance_specific_hyperparameters: List[JumpStartHyperparameter] = [
            JumpStartHyperparameter(json)
            for json in self.variants.get(instance_type, {})
            .get("properties", {})
            .get("hyperparameters", [])
        ]

        instance_type_family = get_instance_type_family(instance_type)

        instance_family_hyperparameters: List[JumpStartHyperparameter] = [
            JumpStartHyperparameter(json)
            for json in (
                self.variants.get(instance_type_family, {})
                .get("properties", {})
                .get("hyperparameters", [])
                if instance_type_family not in {"", None}
                else []
            )
        ]

        instance_specific_hyperparameter_names = {
            hyperparameter.name for hyperparameter in instance_specific_hyperparameters
        }

        hyperparams_to_return = deepcopy(instance_specific_hyperparameters)

        for hyperparameter in instance_family_hyperparameters:
            if hyperparameter.name not in instance_specific_hyperparameter_names:
                hyperparams_to_return.append(hyperparameter)

        return hyperparams_to_return

    def get_instance_specific_environment_variables(self, instance_type: str) -> Dict[str, str]:
        """Returns instance specific environment variables.

        Instance-family values are overridden by instance-type values. The result
        is a new dict, and the stored variant metadata is not modified.

        Returns empty dict if a model, instance type tuple does not have specific
        environment variables.
        """

        if self.variants is None:
            return {}

        instance_specific_environment_variables: Dict[str, str] = (
            self.variants.get(instance_type, {})
            .get("properties", {})
            .get("environment_variables", {})
        )

        instance_type_family = get_instance_type_family(instance_type)

        instance_family_environment_variables: dict = (
            self.variants.get(instance_type_family, {})
            .get("properties", {})
            .get("environment_variables", {})
            if instance_type_family not in {"", None}
            else {}
        )

        # The family dict above is the object stored inside ``self.variants``.
        # Calling ``.update()`` on it would write this instance type's values back
        # into the metadata and leak them into lookups for other instance types
        # of the same family, so merge into a fresh dict instead.
        return {**instance_family_environment_variables, **instance_specific_environment_variables}

    def get_instance_specific_gated_model_key_env_var_value(
        self, instance_type: str
    ) -> Optional[str]:
        """Returns instance specific gated model env var s3 key.

        Hub content stores this under ``gated_model_env_var_uri``, and JumpStart
        metadata stores it under ``gated_model_key_env_var_value``.

        Returns None if a model, instance type tuple does not have instance
        specific property.
        """

        gated_model_key_env_var_value = (
            "gated_model_env_var_uri" if self._is_hub_content else "gated_model_key_env_var_value"
        )

        return self._get_instance_specific_property(instance_type, gated_model_key_env_var_value)

    def get_instance_specific_default_inference_instance_type(
        self, instance_type: str
    ) -> Optional[str]:
        """Returns instance specific default inference instance type.

        Returns None if a model, instance type tuple does not have instance
        specific inference instance types.
        """

        return self._get_instance_specific_property(
            instance_type, "default_inference_instance_type"
        )

    def get_instance_specific_supported_inference_instance_types(
        self, instance_type: str
    ) -> List[str]:
        """Returns instance specific supported inference instance types.

        The result is the sorted, de-duplicated union of the instance-type and
        instance-family lists.

        Returns empty list if a model, instance type tuple does not have instance
        specific inference instance types.
        """

        if self.variants is None:
            return []

        instance_specific_inference_instance_types: List[str] = (
            self.variants.get(instance_type, {})
            .get("properties", {})
            .get("supported_inference_instance_types", [])
        )

        instance_type_family = get_instance_type_family(instance_type)

        instance_family_inference_instance_types: List[str] = (
            self.variants.get(instance_type_family, {})
            .get("properties", {})
            .get("supported_inference_instance_types", [])
            if instance_type_family not in {"", None}
            else []
        )

        return sorted(
            set(
                instance_specific_inference_instance_types
                + instance_family_inference_instance_types
            )
        )

    def get_image_uri(self, instance_type: str, region: Optional[str] = None) -> Optional[str]:
        """Returns image uri from instance type and region.

        Returns None if no instance type is available or found.
        None is also returned if the metadata is improperly formatted.
        """
        return self._get_regional_property(
            instance_type=instance_type, region=region, property_name="image_uri"
        )

    def get_model_package_arn(self, instance_type: str, region: str) -> Optional[str]:
        """Returns model package arn from instance type and region.

        Returns None if no instance type is available or found.
        None is also returned if the metadata is improperly formatted.
        """
        return self._get_regional_property(
            instance_type=instance_type, region=region, property_name="model_package_arn"
        )

    def _get_regional_property(
        self, instance_type: str, region: Optional[str], property_name: str
    ) -> Optional[str]:
        """Returns regional property from instance type and region.

        With ``regional_aliases`` (JumpStart metadata), the variant holds a
        ``$``-prefixed alias in ``regional_properties`` that is resolved against
        ``regional_aliases[region]``. Without them (hub content), the value is read
        directly from ``properties``. The instance type is tried first, then its
        family.

        Returns None if no instance type is available or found.
        None is also returned if the metadata is improperly formatted.
        """
        # pylint: disable=too-many-return-statements

        if self.variants is None:
            return None

        if region is None and self.regional_aliases is not None:
            return None

        regional_property_alias: Optional[str] = None
        regional_property_value: Optional[str] = None

        if self.regional_aliases:
            regional_property_alias = (
                self.variants.get(instance_type, {})
                .get("regional_properties", {})
                .get(property_name)
            )
        else:
            regional_property_value = (
                self.variants.get(instance_type, {}).get("properties", {}).get(property_name)
            )

        if regional_property_alias is None and regional_property_value is None:
            instance_type_family = get_instance_type_family(instance_type)
            if instance_type_family in {"", None}:
                return None
            if self.regional_aliases:
                regional_property_alias = (
                    self.variants.get(instance_type_family, {})
                    .get("regional_properties", {})
                    .get(property_name)
                )
            else:
                # if reading from HubContent, aliases are already regionalized
                regional_property_value = (
                    self.variants.get(instance_type_family, {})
                    .get("properties", {})
                    .get(property_name)
                )

        if (regional_property_alias is None or len(regional_property_alias) == 0) and (
            regional_property_value is None or len(regional_property_value) == 0
        ):
            return None

        if regional_property_alias and not regional_property_alias.startswith("$"):
            # No leading '$' indicates bad metadata.
            # There are tests to ensure this never happens.
            # However, to allow for fallback options in the unlikely event
            # of a regression, we do not raise an exception here.
            # We return None, indicating the field does not exist.
            return None

        if self.regional_aliases and region not in self.regional_aliases:
            return None

        if self.regional_aliases:
            alias_value = self.regional_aliases[region].get(regional_property_alias[1:], None)
            return alias_value
        return regional_property_value
924
+
925
+
926
class JumpStartAdditionalDataSources(JumpStartDataHolderType):
    """Data class of additional data sources (speculative decoding and scripts)."""

    __slots__ = ["speculative_decoding", "scripts"]

    def __init__(self, spec: Dict[str, Any]):
        """Builds the additional data sources from their dictionary form.

        Args:
            spec (Dict[str, Any]): Dictionary representation of data source.
        """
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Assigns both data-source lists from a dictionary.

        Each list entry is wrapped in a ``JumpStartModelDataSource``. A missing or
        empty list is stored as ``None``.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of data source.
        """
        speculative_entries = json_obj.get("speculative_decoding")
        self.speculative_decoding: Optional[List[JumpStartModelDataSource]] = (
            [JumpStartModelDataSource(entry) for entry in speculative_entries]
            if speculative_entries
            else None
        )

        script_entries = json_obj.get("scripts")
        self.scripts: Optional[List[JumpStartModelDataSource]] = (
            [JumpStartModelDataSource(entry) for entry in script_entries]
            if script_entries
            else None
        )

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of AdditionalDataSources object.

        Lists are rebuilt element-wise, with JumpStart data-holder elements
        serialized through their own ``to_json``. Other values are copied as-is.
        """
        serialized: Dict[str, Any] = {}
        for slot in self.__slots__:
            if not hasattr(self, slot):
                continue
            value = getattr(self, slot)
            if isinstance(value, list):
                value = [
                    item.to_json() if issubclass(type(item), JumpStartDataHolderType) else item
                    for item in value
                ]
            serialized[slot] = value
        return serialized
975
+
976
+
977
class ModelAccessConfig(JumpStartDataHolderType):
    """Data class of model access config that mirrors CreateModel API."""

    __slots__ = ["accept_eula"]

    def __init__(self, spec: Dict[str, Any]):
        """Builds a model access config from its dictionary form.

        Args:
            spec (Dict[str, Any]): Dictionary representation of the model access config.
        """
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Assigns ``accept_eula`` from a dictionary.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of the model access config.

        Raises:
            KeyError: If ``accept_eula`` is missing.
        """
        eula_flag = json_obj["accept_eula"]
        self.accept_eula: bool = eula_flag

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of ModelAccessConfig object."""
        serialized: Dict[str, Any] = {}
        for slot in self.__slots__:
            if hasattr(self, slot):
                serialized[slot] = getattr(self, slot)
        return serialized
1002
+
1003
+
1004
class HubAccessConfig(JumpStartDataHolderType):
    """Data class of hub access config that mirrors CreateModel API."""

    __slots__ = ["hub_content_arn"]

    def __init__(self, spec: Dict[str, Any]):
        """Initializes a HubAccessConfig object.

        Args:
            spec (Dict[str, Any]): Dictionary representation of the hub access config.
        """
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of the hub access
                config. It must contain ``hub_content_arn``.

        Raises:
            KeyError: If ``hub_content_arn`` is missing.
        """
        # Read the same key that ``to_json`` emits, so that
        # ``HubAccessConfig(config.to_json())`` round-trips. Previously this read
        # ``accept_eula``, a ModelAccessConfig field, and raised KeyError on
        # well-formed hub access configs.
        self.hub_content_arn: str = json_obj["hub_content_arn"]

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of HubAccessConfig object."""
        json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
        return json_obj
1029
+
1030
+
1031
class S3DataSource(JumpStartDataHolderType):
    """Data class of S3 data source that mirrors CreateModel API."""

    __slots__ = [
        "compression_type",
        "s3_data_type",
        "s3_uri",
        "model_access_config",
        "hub_access_config",
    ]

    def __init__(self, spec: Dict[str, Any]):
        """Builds an S3 data source from its dictionary form.

        Args:
            spec (Dict[str, Any]): Dictionary representation of data source.
        """
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Assigns the S3 data source fields from a dictionary.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of data source.

        Raises:
            KeyError: If ``compression_type``, ``s3_data_type`` or ``s3_uri`` is missing.
        """
        self.compression_type: str = json_obj["compression_type"]
        self.s3_data_type: str = json_obj["s3_data_type"]
        self.s3_uri: str = json_obj["s3_uri"]

        # Nested access configs are wrapped only when present and non-empty.
        model_access = json_obj.get("model_access_config")
        self.model_access_config: ModelAccessConfig = (
            ModelAccessConfig(model_access) if model_access else None
        )
        hub_access = json_obj.get("hub_access_config")
        self.hub_access_config: HubAccessConfig = (
            HubAccessConfig(hub_access) if hub_access else None
        )

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of S3DataSource object.

        Nested data holders are serialized through their own ``to_json``. Other
        falsy values (for example ``None`` access configs) are omitted.
        """
        serialized: Dict[str, Any] = {}
        for slot in self.__slots__:
            if not hasattr(self, slot):
                continue
            value = getattr(self, slot)
            if issubclass(type(value), JumpStartDataHolderType):
                serialized[slot] = value.to_json()
            elif value:
                serialized[slot] = value
        return serialized

    def set_bucket(self, bucket: str) -> None:
        """Points ``s3_uri`` at ``bucket``.

        If ``s3_uri`` already starts with ``S3_PREFIX``, its bucket component is
        replaced and the rest of the path is kept. Otherwise ``s3_uri`` is
        appended to ``S3_PREFIX`` + ``bucket``, with a ``/`` added after
        ``bucket`` when it lacks one.
        """
        if not self.s3_uri.startswith(S3_PREFIX):
            separator = "" if bucket.endswith("/") else "/"
            self.s3_uri = f"{S3_PREFIX}{bucket}{separator}{self.s3_uri}"  # pylint: disable=W0201
            return

        _old_bucket, slash, remainder = self.s3_uri[len(S3_PREFIX) :].partition("/")
        self.s3_uri = f"{S3_PREFIX}{bucket}{slash}{remainder}"  # pylint: disable=W0201
1096
+
1097
+
1098
class AdditionalModelDataSource(JumpStartDataHolderType):
    """Data class of additional model data source mirrors CreateModel API."""

    # Attribute names dropped from ``to_json()`` unless ``exclude_keys=False``.
    SERIALIZATION_EXCLUSION_SET = {"provider"}

    __slots__ = ["channel_name", "s3_data_source", "hosting_eula_key"]

    def __init__(self, spec: Dict[str, Any]):
        """Builds an additional model data source from its dictionary form.

        Args:
            spec (Dict[str, Any]): Dictionary representation of data source.
        """
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Assigns the data source fields from a dictionary.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of data source.

        Raises:
            KeyError: If ``channel_name`` or ``s3_data_source`` is missing.
        """
        self.channel_name: str = json_obj["channel_name"]
        self.s3_data_source: S3DataSource = S3DataSource(json_obj["s3_data_source"])
        self.hosting_eula_key: str = json_obj.get("hosting_eula_key")
        # NOTE(review): "provider" is not listed in this class's __slots__; only the
        # JumpStartModelDataSource subclass declares it. If JumpStartDataHolderType
        # also defines __slots__, this assignment raises AttributeError for direct
        # AdditionalModelDataSource instances. Confirm against the base class.
        self.provider: Dict = json_obj.get("provider", {})

    def to_json(self, exclude_keys=True) -> Dict[str, Any]:
        """Returns json representation of AdditionalModelDataSource object.

        Args:
            exclude_keys (bool): When True, skip names in ``SERIALIZATION_EXCLUSION_SET``.
        """
        serialized: Dict[str, Any] = {}
        for slot in self.__slots__:
            if not hasattr(self, slot):
                continue
            if exclude_keys and slot in self.SERIALIZATION_EXCLUSION_SET:
                continue
            value = getattr(self, slot)
            serialized[slot] = (
                value.to_json() if issubclass(type(value), JumpStartDataHolderType) else value
            )
        return serialized
1136
+
1137
+
1138
class JumpStartModelDataSource(AdditionalModelDataSource):
    """JumpStart additional model data source.

    Adds a required ``artifact_version`` field. That field is in
    ``SERIALIZATION_EXCLUSION_SET``, so ``to_json()`` omits it by default,
    just like ``provider``.
    """

    SERIALIZATION_EXCLUSION_SET = AdditionalModelDataSource.SERIALIZATION_EXCLUSION_SET.union(
        {"artifact_version"}
    )

    # Excluded attributes still need storage, so they are declared as slots too.
    __slots__ = list(SERIALIZATION_EXCLUSION_SET) + AdditionalModelDataSource.__slots__

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Assigns the base data source fields plus ``artifact_version``.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of data source.

        Raises:
            KeyError: If a required base key or ``artifact_version`` is missing.
        """
        super().from_json(json_obj)
        version = json_obj["artifact_version"]
        self.artifact_version: str = version
1155
+
1156
+
1157
class JumpStartBenchmarkStat(JumpStartDataHolderType):
    """Data class JumpStart benchmark stat."""

    __slots__ = ["name", "value", "unit", "concurrency"]

    def __init__(self, spec: Dict[str, Any]):
        """Builds a benchmark stat from its dictionary form.

        Args:
            spec (Dict[str, Any]): Dictionary representation of benchmark stat.
        """
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Assigns ``name``, ``value``, ``unit`` and ``concurrency`` in that order.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of benchmark stats.

        Raises:
            KeyError: If any of the four keys is missing.
        """
        for field in ("name", "value", "unit", "concurrency"):
            setattr(self, field, json_obj[field])

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartBenchmarkStat object."""
        serialized: Dict[str, Any] = {}
        for slot in self.__slots__:
            if hasattr(self, slot):
                serialized[slot] = getattr(self, slot)
        return serialized
1185
+
1186
+
1187
class JumpStartConfigRanking(JumpStartDataHolderType):
    """Data class JumpStart config ranking."""

    __slots__ = ["description", "rankings"]

    def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content=False):
        """Builds a config ranking from its dictionary form.

        Args:
            spec (Dict[str, Any]): Dictionary representation of training config ranking.
            is_hub_content (bool): When True, keys are normalized with
                ``camel_to_snake`` before being read.
        """
        normalized = walk_and_apply_json(spec, camel_to_snake) if is_hub_content else spec
        self.from_json(normalized)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Assigns ``description`` and ``rankings`` from a dictionary.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of config ranking.

        Raises:
            KeyError: If ``description`` or ``rankings`` is missing.
        """
        description = json_obj["description"]
        self.description: str = description
        rankings = json_obj["rankings"]
        self.rankings: List[str] = rankings

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartConfigRanking object."""
        serialized: Dict[str, Any] = {}
        for slot in self.__slots__:
            if hasattr(self, slot):
                serialized[slot] = getattr(self, slot)
        return serialized
1215
+
1216
+
1217
class JumpStartMetadataBaseFields(JumpStartDataHolderType):
    """Data class JumpStart metadata base fields that can be overridden."""

    __slots__ = [
        "model_id",
        "url",
        "version",
        "min_sdk_version",
        "model_types",
        "capabilities",
        "incremental_training_supported",
        "hosting_ecr_specs",
        "hosting_ecr_uri",
        "hosting_artifact_uri",
        "hosting_artifact_key",
        "hosting_script_key",
        "training_supported",
        "training_ecr_specs",
        "training_ecr_uri",
        "training_artifact_key",
        "training_script_key",
        "hyperparameters",
        "inference_environment_variables",
        "inference_vulnerable",
        "inference_dependencies",
        "inference_vulnerabilities",
        "training_vulnerable",
        "training_dependencies",
        "training_vulnerabilities",
        "deprecated",
        "usage_info_message",
        "deprecated_message",
        "deprecate_warn_message",
        "default_inference_instance_type",
        "supported_inference_instance_types",
        "dynamic_container_deployment_supported",
        "hosting_resource_requirements",
        "default_training_instance_type",
        "supported_training_instance_types",
        "metrics",
        "training_prepacked_script_key",
        "training_prepacked_script_version",
        "hosting_prepacked_artifact_key",
        "hosting_prepacked_artifact_version",
        "model_kwargs",
        "deploy_kwargs",
        "estimator_kwargs",
        "fit_kwargs",
        "predictor_specs",
        "inference_volume_size",
        "training_volume_size",
        "inference_enable_network_isolation",
        "training_enable_network_isolation",
        "resource_name_base",
        "hosting_eula_key",
        "hosting_model_package_arns",
        "training_model_package_artifact_uris",
        "hosting_use_script_uri",
        "hosting_instance_type_variants",
        "training_instance_type_variants",
        "default_payloads",
        "gated_bucket",
        "model_subscription_link",
        "hosting_additional_data_sources",
        "hosting_neuron_model_id",
        "hosting_neuron_model_version",
        "hub_content_type",
        "_is_hub_content",
        "default_training_dataset_key",
        "default_training_dataset_uri",
    ]

    # Slots that ``to_json`` never emits. This list is a class attribute shared by every
    # instance and subclass, so it must never be mutated per instance; per-instance
    # exclusions are computed in ``_get_non_serializable_slots`` instead.
    _non_serializable_slots = ["_is_hub_content"]

    def __init__(self, fields: Dict[str, Any], is_hub_content: Optional[bool] = False):
        """Initializes a JumpStartMetadataFields object.

        Args:
            fields (Dict[str, Any]): Dictionary representation of metadata fields.
            is_hub_content (Optional[bool]): Whether ``fields`` comes from hub content;
                if so, it is normalized with ``camel_to_snake`` before parsing.
        """
        self._is_hub_content = is_hub_content
        self.from_json(fields)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json of header.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of spec.
        """
        if self._is_hub_content:
            json_obj = walk_and_apply_json(json_obj, camel_to_snake)
        self.model_id: str = json_obj.get("model_id")
        self.url: str = json_obj.get("url")
        self.version: str = json_obj.get("version")
        self.min_sdk_version: str = json_obj.get("min_sdk_version")
        self.incremental_training_supported: bool = bool(
            json_obj.get("incremental_training_supported", False)
        )
        # Hub content carries a hosting ECR URI; other metadata carries ECR spec objects.
        # Which of the two is serialized is decided in ``_get_non_serializable_slots``.
        if self._is_hub_content:
            self.capabilities: Optional[List[str]] = json_obj.get("capabilities")
            self.model_types: Optional[List[str]] = json_obj.get("model_types")
            self.hosting_ecr_uri: Optional[str] = json_obj.get("hosting_ecr_uri")
        else:
            self.hosting_ecr_specs: Optional[JumpStartECRSpecs] = (
                JumpStartECRSpecs(
                    json_obj["hosting_ecr_specs"], is_hub_content=self._is_hub_content
                )
                if "hosting_ecr_specs" in json_obj
                else None
            )
        self.hosting_artifact_key: Optional[str] = json_obj.get("hosting_artifact_key")
        self.hosting_artifact_uri: Optional[str] = json_obj.get("hosting_artifact_uri")
        self.hosting_script_key: Optional[str] = json_obj.get("hosting_script_key")
        self.training_supported: Optional[bool] = bool(json_obj.get("training_supported", False))
        self.inference_environment_variables = [
            JumpStartEnvironmentVariable(env_variable, is_hub_content=self._is_hub_content)
            for env_variable in json_obj.get("inference_environment_variables", [])
        ]
        self.inference_vulnerable: bool = bool(json_obj.get("inference_vulnerable", False))
        self.inference_dependencies: List[str] = json_obj.get("inference_dependencies", [])
        self.inference_vulnerabilities: List[str] = json_obj.get("inference_vulnerabilities", [])
        self.training_vulnerable: bool = bool(json_obj.get("training_vulnerable", False))
        self.training_dependencies: List[str] = json_obj.get("training_dependencies", [])
        self.training_vulnerabilities: List[str] = json_obj.get("training_vulnerabilities", [])
        self.deprecated: bool = bool(json_obj.get("deprecated", False))
        self.deprecated_message: Optional[str] = json_obj.get("deprecated_message")
        self.deprecate_warn_message: Optional[str] = json_obj.get("deprecate_warn_message")
        self.usage_info_message: Optional[str] = json_obj.get("usage_info_message")
        self.default_inference_instance_type: Optional[str] = json_obj.get(
            "default_inference_instance_type"
        )
        self.default_training_instance_type: Optional[str] = json_obj.get(
            "default_training_instance_type"
        )
        self.supported_inference_instance_types: Optional[List[str]] = json_obj.get(
            "supported_inference_instance_types"
        )
        self.supported_training_instance_types: Optional[List[str]] = json_obj.get(
            "supported_training_instance_types"
        )
        self.dynamic_container_deployment_supported: Optional[bool] = bool(
            json_obj.get("dynamic_container_deployment_supported")
        )
        self.hosting_resource_requirements: Optional[Dict[str, int]] = json_obj.get(
            "hosting_resource_requirements", None
        )
        self.metrics: Optional[List[Dict[str, str]]] = json_obj.get("metrics", None)
        self.training_prepacked_script_key: Optional[str] = json_obj.get(
            "training_prepacked_script_key", None
        )
        self.hosting_prepacked_artifact_key: Optional[str] = json_obj.get(
            "hosting_prepacked_artifact_key", None
        )
        # New fields required for Hub model.
        if self._is_hub_content:
            self.training_prepacked_script_version: Optional[str] = json_obj.get(
                "training_prepacked_script_version"
            )
            self.hosting_prepacked_artifact_version: Optional[str] = json_obj.get(
                "hosting_prepacked_artifact_version"
            )
        self.model_kwargs = deepcopy(json_obj.get("model_kwargs", {}))
        self.deploy_kwargs = deepcopy(json_obj.get("deploy_kwargs", {}))
        self.predictor_specs: Optional[JumpStartPredictorSpecs] = (
            JumpStartPredictorSpecs(
                json_obj.get("predictor_specs"),
                is_hub_content=self._is_hub_content,
            )
            if json_obj.get("predictor_specs")
            else None
        )
        self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = (
            {
                alias: JumpStartSerializablePayload(payload, is_hub_content=self._is_hub_content)
                for alias, payload in json_obj["default_payloads"].items()
            }
            if json_obj.get("default_payloads")
            else None
        )
        self.gated_bucket = json_obj.get("gated_bucket", False)
        self.inference_volume_size: Optional[int] = json_obj.get("inference_volume_size")
        self.inference_enable_network_isolation: bool = json_obj.get(
            "inference_enable_network_isolation", False
        )
        self.resource_name_base: bool = json_obj.get("resource_name_base")

        self.hosting_eula_key: Optional[str] = json_obj.get("hosting_eula_key")

        # Normalize a missing/None value to an empty dict.
        model_package_arns = json_obj.get("hosting_model_package_arns")
        self.hosting_model_package_arns: Optional[Dict] = (
            model_package_arns if model_package_arns is not None else {}
        )

        self.hosting_use_script_uri: bool = json_obj.get("hosting_use_script_uri", True)

        self.hosting_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = (
            JumpStartInstanceTypeVariants(
                json_obj["hosting_instance_type_variants"], self._is_hub_content
            )
            if json_obj.get("hosting_instance_type_variants")
            else None
        )
        self.hosting_additional_data_sources: Optional[JumpStartAdditionalDataSources] = (
            JumpStartAdditionalDataSources(json_obj["hosting_additional_data_sources"])
            if json_obj.get("hosting_additional_data_sources")
            else None
        )
        self.hosting_neuron_model_id: Optional[str] = json_obj.get("hosting_neuron_model_id")
        self.hosting_neuron_model_version: Optional[str] = json_obj.get(
            "hosting_neuron_model_version"
        )

        # Training-only fields are set only when the model supports training.
        if self.training_supported:
            if self._is_hub_content:
                self.training_ecr_uri: Optional[str] = json_obj.get("training_ecr_uri")
            else:
                self.training_ecr_specs: Optional[JumpStartECRSpecs] = (
                    JumpStartECRSpecs(json_obj["training_ecr_specs"])
                    if "training_ecr_specs" in json_obj
                    else None
                )
            self.training_artifact_key: str = json_obj["training_artifact_key"]
            self.training_script_key: str = json_obj["training_script_key"]
            hyperparameters: Any = json_obj.get("hyperparameters")
            self.hyperparameters: List[JumpStartHyperparameter] = []
            if hyperparameters is not None:
                self.hyperparameters.extend(
                    [
                        JumpStartHyperparameter(hyperparameter, is_hub_content=self._is_hub_content)
                        for hyperparameter in hyperparameters
                    ]
                )
            self.estimator_kwargs = deepcopy(json_obj.get("estimator_kwargs", {}))
            self.fit_kwargs = deepcopy(json_obj.get("fit_kwargs", {}))
            self.training_volume_size: Optional[int] = json_obj.get("training_volume_size")
            self.training_enable_network_isolation: bool = json_obj.get(
                "training_enable_network_isolation", False
            )
            self.training_model_package_artifact_uris: Optional[Dict] = json_obj.get(
                "training_model_package_artifact_uris"
            )
            self.training_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = (
                JumpStartInstanceTypeVariants(
                    json_obj["training_instance_type_variants"], is_hub_content=self._is_hub_content
                )
                if json_obj.get("training_instance_type_variants")
                else None
            )
        self.model_subscription_link = json_obj.get("model_subscription_link")
        self.default_training_dataset_key: Optional[str] = json_obj.get(
            "default_training_dataset_key"
        )
        self.default_training_dataset_uri: Optional[str] = json_obj.get(
            "default_training_dataset_uri"
        )

    def _get_non_serializable_slots(self) -> List[str]:
        """Returns the slots that ``to_json`` must skip for this instance.

        The hub branch of ``from_json`` populates ``*_ecr_uri`` while the non-hub branch
        populates ``*_ecr_specs``. Only the representation that matches this instance's
        ``_is_hub_content`` flag is serialized. The result is derived per call rather
        than by mutating the shared class-level ``_non_serializable_slots`` list.
        """
        excluded = list(getattr(self, "_non_serializable_slots", []))
        if getattr(self, "_is_hub_content", False):
            excluded.extend(["hosting_ecr_specs", "training_ecr_specs"])
        else:
            excluded.extend(["hosting_ecr_uri", "training_ecr_uri"])
        return excluded

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartMetadataBaseFields object.

        Nested ``JumpStartDataHolderType`` values (directly, in lists, or as dict values)
        are serialized through their own ``to_json``. Unset slots and the slots returned
        by ``_get_non_serializable_slots`` are omitted.
        """
        json_obj = {}
        non_serializable_slots = self._get_non_serializable_slots()
        for att in self.__slots__:
            if hasattr(self, att) and att not in non_serializable_slots:
                cur_val = getattr(self, att)
                if issubclass(type(cur_val), JumpStartDataHolderType):
                    json_obj[att] = cur_val.to_json()
                elif isinstance(cur_val, list):
                    json_obj[att] = []
                    for obj in cur_val:
                        if issubclass(type(obj), JumpStartDataHolderType):
                            json_obj[att].append(obj.to_json())
                        else:
                            json_obj[att].append(obj)
                elif isinstance(cur_val, dict):
                    json_obj[att] = {}
                    for key, val in cur_val.items():
                        if issubclass(type(val), JumpStartDataHolderType):
                            json_obj[att][key] = val.to_json()
                        else:
                            json_obj[att][key] = val
                else:
                    json_obj[att] = cur_val
        return json_obj

    def set_hub_content_type(self, hub_content_type: HubContentType) -> None:
        """Sets the hub content type; a no-op unless this object is hub content."""
        if self._is_hub_content:
            self.hub_content_type = hub_content_type
1507
+
1508
+
1509
class JumpStartConfigComponent(JumpStartMetadataBaseFields):
    """Data class of JumpStart config component.

    Unlike its base class, a component only sets the metadata fields that are present
    in its JSON. ``to_json`` therefore yields just the fields this component overrides.
    """

    slots = ["component_name"]

    # List of fields that is not allowed to override to JumpStartMetadataBaseFields
    OVERRIDING_DENY_LIST = [
        "model_id",
        "url",
        "version",
        "min_sdk_version",
        "deprecated",
        "deprecated_message",
        "deprecate_warn_message",
        "resource_name_base",
        "gated_bucket",
        "training_supported",
        "incremental_training_supported",
    ]

    # Map of HubContent fields that map to custom names in MetadataBaseFields
    CUSTOM_FIELD_MAP = {"sage_maker_sdk_predictor_specifications": "predictor_specs"}

    __slots__ = slots + JumpStartMetadataBaseFields.__slots__

    def __init__(
        self, component_name: str, component: Optional[Dict[str, Any]], is_hub_content=False
    ):
        """Initializes a JumpStartConfigComponent object from its json representation.

        Args:
            component_name (str): Name of the component.
            component (Optional[Dict[str, Any]]):
                Dictionary representation of the config component. ``None`` is treated
                as a component with no overrides.
            is_hub_content (bool): Whether ``component`` comes from hub content; if so,
                it is normalized with ``camel_to_snake`` first.
        """
        if component is None:
            component = {}
        if is_hub_content:
            component = walk_and_apply_json(component, camel_to_snake)
        self.component_name = component_name
        # The base initializer records ``is_hub_content`` and then calls
        # ``self.from_json``, which dispatches to the override below. A second explicit
        # call would only repeat the same assignments.
        super().__init__(component, is_hub_content)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets only the known slots that appear in ``json_obj``.

        Args:
            json_obj (Dict[str, Any]):
                Dictionary representation of the config component.
        """
        for field in json_obj.keys():
            if field in self.__slots__:
                setattr(self, field, json_obj[field])

        # Handle custom fields
        for custom_field, field in self.CUSTOM_FIELD_MAP.items():
            if custom_field in json_obj:
                setattr(self, field, json_obj.get(custom_field))
1567
+
1568
+
1569
class JumpStartMetadataConfig(JumpStartDataHolderType):
    """Data class of JumpStart metadata config."""

    __slots__ = [
        "base_fields",
        "benchmark_metrics",
        "acceleration_configs",
        "config_components",
        "resolved_metadata_config",
        "config_name",
        "default_inference_config",
        "default_incremental_training_config",
        "supported_inference_configs",
        "supported_incremental_training_configs",
    ]

    def __init__(
        self,
        config_name: str,
        config: Dict[str, Any],
        base_fields: Dict[str, Any],
        config_components: Dict[str, JumpStartConfigComponent],
        is_hub_content=False,
    ):
        """Initializes a JumpStartMetadataConfig object from its json representation.

        Args:
            config_name (str): Name of the config.
            config (Dict[str, Any]):
                Dictionary representation of the config. ``None`` is treated as empty.
            base_fields (Dict[str, Any]):
                The default base fields that are used to construct the resolved config.
            config_components (Dict[str, JumpStartConfigComponent]):
                The components that are layered, in iteration order, onto the base
                fields to construct the resolved config. May be ``None``.
            is_hub_content (bool): Whether ``config`` and ``base_fields`` come from hub
                content; if so, both are normalized with ``camel_to_snake``.
        """
        if is_hub_content:
            config = walk_and_apply_json(config, camel_to_snake)
            base_fields = walk_and_apply_json(base_fields, camel_to_snake)
        # Guard the ``config.get`` lookups below against a missing config.
        if config is None:
            config = {}
        self.base_fields = base_fields
        self.config_components: Dict[str, JumpStartConfigComponent] = config_components
        self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = (
            {
                stat_name: [JumpStartBenchmarkStat(stat) for stat in stats]
                for stat_name, stats in config.get("benchmark_metrics").items()
            }
            if config and config.get("benchmark_metrics")
            else None
        )
        self.acceleration_configs = config.get("acceleration_configs")
        self.resolved_metadata_config: Optional[Dict[str, Any]] = None
        self.config_name: Optional[str] = config_name
        self.default_inference_config: Optional[str] = config.get("default_inference_config")
        self.default_incremental_training_config: Optional[str] = config.get(
            "default_incremental_training_config"
        )
        self.supported_inference_configs: Optional[List[str]] = config.get(
            "supported_inference_configs"
        )
        self.supported_incremental_training_configs: Optional[List[str]] = config.get(
            "supported_incremental_training_configs"
        )

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartMetadataConfig object."""
        json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
        return json_obj

    @property
    def resolved_config(self) -> Dict[str, Any]:
        """Returns the final config that is resolved from the components map.

        The base fields are serialized to a dict. Each component's ``to_json`` is then
        applied on top of it in turn, skipping that component's
        ``OVERRIDING_DENY_LIST`` keys. The result is always a plain dict, including when
        there are no components, and it is cached after the first call.
        """
        if self.resolved_metadata_config:
            return self.resolved_metadata_config

        # Keep ``resolved_config`` a dict throughout. ``deep_override_dict`` returns a
        # dict, so calling ``.to_json()`` on the running result would fail from the
        # second component onward.
        # NOTE(review): base fields are parsed without ``is_hub_content``, so hub-only
        # fields such as ``hosting_ecr_uri`` are not carried into the resolved config —
        # confirm this is intended.
        resolved_config: Dict[str, Any] = JumpStartMetadataBaseFields(self.base_fields).to_json()
        for component in (self.config_components or {}).values():
            resolved_config = deep_override_dict(
                deepcopy(resolved_config),
                deepcopy(component.to_json()),
                component.OVERRIDING_DENY_LIST,
            )

        # Remove environment variables from resolved config if using model packages
        hosting_model_package_arns = resolved_config.get("hosting_model_package_arns")
        if hosting_model_package_arns is not None and hosting_model_package_arns != {}:
            resolved_config["inference_environment_variables"] = []

        self.resolved_metadata_config = resolved_config

        return resolved_config
1662
+
1663
+
1664
class JumpStartMetadataConfigs(JumpStartDataHolderType):
    """Data class to hold the set of JumpStart Metadata configs."""

    __slots__ = ["configs", "config_rankings", "scope"]

    def __init__(
        self,
        configs: Optional[Dict[str, JumpStartMetadataConfig]],
        config_rankings: Optional[Dict[str, JumpStartConfigRanking]],
        scope: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
    ):
        """Initializes a JumpStartMetadataConfigs object.

        Args:
            configs (Optional[Dict[str, JumpStartMetadataConfig]]):
                The map of JumpStartMetadataConfig object, with config name being the key.
            config_rankings (Optional[Dict[str, JumpStartConfigRanking]]):
                Map from ranking name to the ranking of the configs in the model.
            scope (JumpStartScriptScope):
                The scope of the current config (inference or training).
        """
        self.configs = configs
        self.config_rankings = config_rankings
        self.scope = scope

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartMetadataConfigs object."""
        json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
        return json_obj

    def get_top_config_from_ranking(
        self,
        ranking_name: str = JumpStartConfigRankingName.DEFAULT,
        instance_type: Optional[str] = None,
    ) -> Optional[JumpStartMetadataConfig]:
        """Gets the best config based on config ranking.

        Config names are taken in the order given by ``ranking_name``. If that ranking
        is not available, they fall back to sorted config-name order.

        Args:
            ranking_name (str):
                The ranking name that config priority is based on.
            instance_type (Optional[str]):
                The instance type which the config selection is based on. When given,
                a config is skipped unless its resolved config lists this instance type
                among the supported instance types for the current scope.

        Returns:
            Optional[JumpStartMetadataConfig]: The first qualifying config, or ``None`` if
            there are no configs or none qualifies.

        Raises:
            NotImplementedError: If the scope is unrecognized.
        """

        if self.scope == JumpStartScriptScope.INFERENCE:
            instance_type_attribute = "supported_inference_instance_types"
        elif self.scope == JumpStartScriptScope.TRAINING:
            instance_type_attribute = "supported_training_instance_types"
        else:
            raise NotImplementedError(f"Unknown script scope {self.scope}")

        if not self.configs:
            return None

        ranking = self.config_rankings.get(ranking_name) if self.config_rankings else None
        if ranking:
            ranked_config_names = ranking.rankings or []
        else:
            ranked_config_names = sorted(self.configs.keys())

        for config_name in ranked_config_names:
            config = self.configs.get(config_name)
            if config is None:
                # A ranking may name a config this model does not define.
                continue
            if instance_type:
                # ``resolved_config`` is a dict, so the supported types are looked up by
                # key; a missing list means the instance type cannot be confirmed.
                supported_instance_types = (
                    config.resolved_config.get(instance_type_attribute) or []
                )
                if instance_type not in supported_instance_types:
                    continue
            return config

        return None
1736
+
1737
+
1738
class JumpStartModelSpecs(JumpStartMetadataBaseFields):
    """Data class JumpStart model specs.

    Extends the base metadata fields with the inference config set and, when training
    is supported, the training config set. Each set includes its components and
    rankings.
    """

    slots = [
        "inference_configs",
        "inference_config_components",
        "inference_config_rankings",
        "training_configs",
        "training_config_components",
        "training_config_rankings",
    ]

    __slots__ = JumpStartMetadataBaseFields.__slots__ + slots

    def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False):
        """Initializes a JumpStartModelSpecs object from its json representation.

        After ``spec`` is parsed, the top-ranked inference config (if any) is resolved
        and applied over the base fields.

        Args:
            spec (Dict[str, Any]): Dictionary representation of spec.
            is_hub_content (Optional[bool]): Whether the model is from a private hub.
        """
        # The base initializer stores ``is_hub_content`` and calls ``self.from_json``,
        # which dispatches to the override below. ``spec`` is therefore parsed exactly
        # once here.
        super().__init__(spec, is_hub_content)
        top_config = (
            self.inference_configs.get_top_config_from_ranking()
            if self.inference_configs
            else None
        )
        if top_config:
            super().from_json(top_config.resolved_config)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json of header.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of spec.
        """
        super().from_json(json_obj)
        if self._is_hub_content:
            json_obj = walk_and_apply_json(json_obj, camel_to_snake)
        self.inference_config_components: Optional[Dict[str, JumpStartConfigComponent]] = (
            {
                component_name: JumpStartConfigComponent(component_name, component)
                for component_name, component in json_obj["inference_config_components"].items()
            }
            if json_obj.get("inference_config_components")
            else None
        )
        self.inference_config_rankings: Optional[Dict[str, JumpStartConfigRanking]] = (
            {
                alias: JumpStartConfigRanking(ranking, is_hub_content=self._is_hub_content)
                for alias, ranking in json_obj["inference_config_rankings"].items()
            }
            if json_obj.get("inference_config_rankings")
            else None
        )

        if self._is_hub_content:
            inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
                {
                    alias: JumpStartMetadataConfig(
                        alias,
                        config,
                        json_obj,
                        config.config_components,
                        is_hub_content=self._is_hub_content,
                    )
                    for alias, config in json_obj["inference_configs"]["configs"].items()
                }
                if json_obj.get("inference_configs")
                else None
            )
        else:
            inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
                {
                    alias: JumpStartMetadataConfig(
                        alias,
                        config,
                        json_obj,
                        (
                            {
                                component_name: self.inference_config_components.get(component_name)
                                for component_name in config.get("component_names")
                            }
                            if config and config.get("component_names")
                            else None
                        ),
                    )
                    for alias, config in json_obj["inference_configs"].items()
                }
                if json_obj.get("inference_configs")
                else None
            )

        self.inference_configs: Optional[JumpStartMetadataConfigs] = (
            JumpStartMetadataConfigs(
                inference_configs_dict,
                self.inference_config_rankings,
            )
            if json_obj.get("inference_configs")
            else None
        )

        if self.training_supported:
            self.training_config_components: Optional[Dict[str, JumpStartConfigComponent]] = (
                {
                    alias: JumpStartConfigComponent(alias, component)
                    for alias, component in json_obj["training_config_components"].items()
                }
                if json_obj.get("training_config_components")
                else None
            )
            self.training_config_rankings: Optional[Dict[str, JumpStartConfigRanking]] = (
                {
                    alias: JumpStartConfigRanking(ranking)
                    for alias, ranking in json_obj["training_config_rankings"].items()
                }
                if json_obj.get("training_config_rankings")
                else None
            )

            if self._is_hub_content:
                training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
                    {
                        alias: JumpStartMetadataConfig(
                            alias,
                            config,
                            json_obj,
                            config.config_components,
                            is_hub_content=self._is_hub_content,
                        )
                        for alias, config in json_obj["training_configs"]["configs"].items()
                    }
                    if json_obj.get("training_configs")
                    else None
                )
            else:
                training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
                    {
                        alias: JumpStartMetadataConfig(
                            alias,
                            config,
                            json_obj,
                            (
                                {
                                    component_name: self.training_config_components.get(
                                        component_name
                                    )
                                    for component_name in config.get("component_names")
                                }
                                if config and config.get("component_names")
                                else None
                            ),
                        )
                        for alias, config in json_obj["training_configs"].items()
                    }
                    if json_obj.get("training_configs")
                    else None
                )

            self.training_configs: Optional[JumpStartMetadataConfigs] = (
                JumpStartMetadataConfigs(
                    training_configs_dict,
                    self.training_config_rankings,
                    JumpStartScriptScope.TRAINING,
                )
                if json_obj.get("training_configs")
                else None
            )
        self.model_subscription_link = json_obj.get("model_subscription_link")

    def set_config(
        self, config_name: str, scope: JumpStartScriptScope = JumpStartScriptScope.INFERENCE
    ) -> None:
        """Apply the selected config and resolve it onto the current model spec.

        Args:
            config_name (str): Name of the config.
            scope (JumpStartScriptScope, optional):
                Scope of the config. Defaults to JumpStartScriptScope.INFERENCE.

        Raises:
            ValueError: If the scope is not supported, or cannot find config name
                (including when the model defines no configs for the scope).
        """
        if scope == JumpStartScriptScope.INFERENCE:
            metadata_configs = self.inference_configs
        elif scope == JumpStartScriptScope.TRAINING and self.training_supported:
            metadata_configs = getattr(self, "training_configs", None)
        else:
            raise ValueError(f"Unknown Jumpstart script scope {scope}.")

        # The model may define no configs for this scope; report that as a missing
        # config instead of failing with AttributeError on ``None``.
        configs = (metadata_configs.configs if metadata_configs else None) or {}
        config_object = configs.get(config_name)
        if not config_object:
            error_msg = f"Cannot find Jumpstart config name {config_name}. "
            config_names = list(configs.keys())
            if config_names:
                error_msg += f"List of config names that is supported by the model: {config_names}"
            raise ValueError(error_msg)

        super().from_json(config_object.resolved_config)

    def supports_prepacked_inference(self) -> bool:
        """Returns True if the model has a prepacked inference artifact."""
        return getattr(self, "hosting_prepacked_artifact_key", None) is not None

    def use_inference_script_uri(self) -> bool:
        """Returns True if the model should use a script uri when deploying inference model."""
        if self.supports_prepacked_inference():
            return False
        return self.hosting_use_script_uri

    def use_training_model_artifact(self) -> bool:
        """Returns True if the model should use a model uri when kicking off training job."""
        # gated model never use training model artifact
        if self.gated_bucket:
            return False

        # otherwise, return true is a training model package is not set
        return len(self.training_model_package_artifact_uris or {}) == 0

    def is_gated_model(self) -> bool:
        """Returns True if the model has a EULA key or the model bucket is gated."""
        return self.gated_bucket or self.hosting_eula_key is not None

    def supports_incremental_training(self) -> bool:
        """Returns True if the model supports incremental training."""
        return self.incremental_training_supported

    def get_speculative_decoding_s3_data_sources(self) -> List[JumpStartModelDataSource]:
        """Returns data sources for speculative decoding."""
        if not self.hosting_additional_data_sources:
            return []
        return self.hosting_additional_data_sources.speculative_decoding or []

    def get_additional_s3_data_sources(self) -> List[JumpStartModelDataSource]:
        """Returns a flat list of all additional S3 data sources for use by the model.

        Each key of ``hosting_additional_data_sources.to_json()`` is read back as an
        attribute, and its list (if any) is appended to the result.
        """
        additional_data_sources = []
        if self.hosting_additional_data_sources:
            for data_source in self.hosting_additional_data_sources.to_json():
                data_sources = getattr(self.hosting_additional_data_sources, data_source) or []
                additional_data_sources.extend(data_sources)
        return additional_data_sources
1975
+
1976
+
1977
class JumpStartVersionedModelId(JumpStartDataHolderType):
    """Pairs a JumpStart model ID with a specific model version."""

    __slots__ = ["model_id", "version"]

    def __init__(self, model_id: str, version: str) -> None:
        """Stores the model ID and version.

        Args:
            model_id (str): JumpStart model ID.
            version (str): JumpStart model version.
        """
        self.model_id, self.version = model_id, version
1995
+
1996
+
1997
class JumpStartCachedContentKey(JumpStartDataHolderType):
    """Cache key made of a content data type and an identifier within that type."""

    __slots__ = ["data_type", "id_info"]

    def __init__(self, data_type: JumpStartContentDataType, id_info: str) -> None:
        """Stores the two parts of the cache key.

        Args:
            data_type (JumpStartContentDataType): JumpStart content data type.
            id_info (str): For S3 content, the object key in S3; for hub content, the
                hub content ARN.
        """
        self.data_type, self.id_info = data_type, id_info
2015
+
2016
+
2017
class HubArnExtractedInfo(JumpStartDataHolderType):
    """Data class for info extracted from Hub arn."""

    __slots__ = [
        "partition",
        "region",
        "account_id",
        "hub_name",
        "hub_content_type",
        "hub_content_name",
        "hub_content_version",
    ]

    def __init__(
        self,
        partition: str,
        region: str,
        account_id: str,
        hub_name: str,
        hub_content_type: Optional[str] = None,
        hub_content_name: Optional[str] = None,
        hub_content_version: Optional[str] = None,
    ) -> None:
        """Instantiates HubArnExtractedInfo object.

        Args:
            partition (str): ARN partition.
            region (str): Region of the hub.
            account_id (str): Account that owns the hub.
            hub_name (str): Name of the hub.
            hub_content_type (Optional[str]): Hub content type, if the ARN names content.
            hub_content_name (Optional[str]): Hub content name, if the ARN names content.
            hub_content_version (Optional[str]): Hub content version, if the ARN names
                content.
        """

        self.partition = partition
        self.region = region
        self.account_id = account_id
        self.hub_name = hub_name
        self.hub_content_name = hub_content_name
        self.hub_content_type = hub_content_type
        self.hub_content_version = hub_content_version

    @staticmethod
    def extract_region_from_arn(arn: str) -> Optional[str]:
        """Extracts the region from a SageMaker hub-content ARN or hub ARN.

        The hub-content pattern
        (``arn:<partition>:sagemaker:<region>:<account>:hub-content/<4 segments>``) is
        tried first. The hub pattern
        (``arn:<partition>:sagemaker:<region>:<account>:hub/<name>``) is tried next.

        Args:
            arn (str): The ARN to parse.

        Returns:
            Optional[str]: The region segment of the ARN, or ``None`` if ``arn`` matches
            neither pattern.
        """
        hub_content_arn_regex = (
            r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$"
        )
        hub_arn_regex = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$"

        for pattern in (hub_content_arn_regex, hub_arn_regex):
            match = re.match(pattern, arn)
            if match:
                # Group 2 is the region field in both patterns.
                return match.group(2)

        return None
2071
+
2072
+
2073
class JumpStartCachedContentValue(JumpStartDataHolderType):
    """Data class holding a cached content payload together with its md5 hash."""

    __slots__ = ["formatted_content", "md5_hash"]

    def __init__(
        self,
        formatted_content: Union[
            Dict[JumpStartVersionedModelId, JumpStartModelHeader],
            JumpStartModelSpecs,
        ],
        md5_hash: Optional[str] = None,
    ) -> None:
        """Instantiates JumpStartCachedContentValue object.

        Args:
            formatted_content (Union[Dict[JumpStartVersionedModelId, JumpStartModelHeader],
                JumpStartModelSpecs]): Either model specs, or a mapping from
                versioned model IDs to model headers.
            md5_hash (Optional[str]): md5 hash of the stored file content from s3.
                Defaults to ``None``.
        """
        # Assigned left to right: formatted_content first, then md5_hash.
        self.formatted_content, self.md5_hash = formatted_content, md5_hash
2097
+
2098
+
2099
class JumpStartKwargs(JumpStartDataHolderType):
    """Data class for JumpStart object kwargs.

    Subclasses declare their fields in ``__slots__`` and may list field names
    that should never be forwarded as kwargs in ``SERIALIZATION_EXCLUSION_SET``.
    """

    # Fields excluded for every subclass (in addition to the subclass's own set).
    BASE_SERIALIZATION_EXCLUSION_SET: Set[str] = {"specs"}
    SERIALIZATION_EXCLUSION_SET: Set[str] = set()

    def to_kwargs_dict(self, exclude_keys: bool = True):
        """Serializes object to dictionary to be used for kwargs for method arguments.

        Args:
            exclude_keys (bool): If truthy, skip every field named in
                ``SERIALIZATION_EXCLUSION_SET`` or ``BASE_SERIALIZATION_EXCLUSION_SET``.
                If falsy, every slot is considered.

        Returns:
            dict: Mapping of slot name to value, containing only the slots whose
            value is set and not ``None``, in ``__slots__`` order.
        """
        # Compute the exclusion set once instead of on every loop iteration;
        # set(...) accepts any iterable, so list-typed overrides still work.
        excluded = (
            set(self.SERIALIZATION_EXCLUSION_SET) | set(self.BASE_SERIALIZATION_EXCLUSION_SET)
            if exclude_keys
            else set()
        )

        kwargs_dict = {}
        for field in self.__slots__:
            if field in excluded:
                continue
            # Unassigned slots raise AttributeError; the None default skips them.
            att_value = getattr(self, field, None)
            if att_value is not None:
                kwargs_dict[field] = att_value
        return kwargs_dict
2119
+
2120
+
2121
class JumpStartModelInitKwargs(JumpStartKwargs):
    """Data class for the inputs to `JumpStartModel.__init__` method.

    Every constructor argument is stored on the slot of the same name.
    ``model_data`` and ``env`` are deep-copied, so this holder does not share
    those objects with the caller. The ``hub_content_type``,
    ``model_reference_arn`` and ``specs`` slots are not set by ``__init__``
    (presumably assigned by callers after construction — TODO confirm).
    """

    __slots__ = [
        "model_id",
        "model_version",
        "hub_arn",
        "model_type",
        "instance_type",
        "tolerate_vulnerable_model",
        "tolerate_deprecated_model",
        "region",
        "image_uri",
        "model_data",
        "source_dir",
        "entry_point",
        "env",
        "predictor_cls",
        "role",
        "name",
        "vpc_config",
        "sagemaker_session",
        "enable_network_isolation",
        "model_kms_key",
        "image_config",
        "code_location",
        "container_log_level",
        "dependencies",
        "git_config",
        "model_package_arn",
        "training_instance_type",
        "resources",
        "config_name",
        "additional_model_data_sources",
        "hub_content_type",
        "model_reference_arn",
        "specs",
    ]

    # Fields used for model resolution that are not forwarded as Model kwargs.
    SERIALIZATION_EXCLUSION_SET = {
        "instance_type",
        "model_id",
        "model_version",
        "hub_arn",
        "model_type",
        "tolerate_vulnerable_model",
        "tolerate_deprecated_model",
        "region",
        "model_package_arn",
        "training_instance_type",
        "config_name",
        "hub_content_type",
    }

    def __init__(
        self,
        model_id: str,
        model_version: Optional[str] = None,
        hub_arn: Optional[str] = None,
        model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
        region: Optional[str] = None,
        instance_type: Optional[str] = None,
        image_uri: Optional[Union[str, Any]] = None,
        model_data: Optional[Union[str, Any, dict]] = None,
        role: Optional[str] = None,
        predictor_cls: Optional[Callable] = None,
        env: Optional[Dict[str, Union[str, Any]]] = None,
        name: Optional[str] = None,
        vpc_config: Optional[Dict[str, List[Union[str, Any]]]] = None,
        sagemaker_session: Optional[Any] = None,
        enable_network_isolation: Union[bool, Any] = None,
        model_kms_key: Optional[str] = None,
        image_config: Optional[Dict[str, Union[str, Any]]] = None,
        source_dir: Optional[str] = None,
        code_location: Optional[str] = None,
        entry_point: Optional[str] = None,
        container_log_level: Optional[Union[int, Any]] = None,
        dependencies: Optional[List[str]] = None,
        git_config: Optional[Dict[str, str]] = None,
        tolerate_vulnerable_model: Optional[bool] = None,
        tolerate_deprecated_model: Optional[bool] = None,
        model_package_arn: Optional[str] = None,
        training_instance_type: Optional[str] = None,
        resources: Optional[ResourceRequirements] = None,
        config_name: Optional[str] = None,
        additional_model_data_sources: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Instantiates JumpStartModelInitKwargs object.

        Each argument mirrors the parameter of the same name on
        `JumpStartModel.__init__` and is stored verbatim, except ``model_data``
        and ``env``, which are deep-copied.
        """

        # --- Model identity / resolution ---
        self.model_id = model_id
        self.model_version = model_version
        self.hub_arn = hub_arn
        self.model_type = model_type
        self.region = region
        self.config_name = config_name
        self.model_package_arn = model_package_arn
        self.tolerate_vulnerable_model = tolerate_vulnerable_model
        self.tolerate_deprecated_model = tolerate_deprecated_model

        # --- Compute ---
        self.instance_type = instance_type
        self.training_instance_type = training_instance_type
        self.resources = resources

        # --- Container / artifacts (copied so the caller's objects are not shared) ---
        self.image_uri = image_uri
        self.image_config = image_config
        self.model_data = deepcopy(model_data)
        self.additional_model_data_sources = additional_model_data_sources
        self.env = deepcopy(env)

        # --- Inference code packaging ---
        self.source_dir = source_dir
        self.entry_point = entry_point
        self.code_location = code_location
        self.dependencies = dependencies
        self.git_config = git_config
        self.container_log_level = container_log_level

        # --- Security / session / naming ---
        self.role = role
        self.name = name
        self.predictor_cls = predictor_cls
        self.vpc_config = vpc_config
        self.enable_network_isolation = enable_network_isolation
        self.model_kms_key = model_kms_key
        self.sagemaker_session = sagemaker_session
2240
+
2241
+
2242
class JumpStartModelDeployKwargs(JumpStartKwargs):
    """Data class for the inputs to `JumpStartModel.deploy` method.

    Every constructor argument is stored on the slot of the same name; ``tags``
    is normalized through ``format_tags`` first. The ``specs`` slot is not set
    by ``__init__``.
    """

    __slots__ = [
        "model_id",
        "model_version",
        "hub_arn",
        "model_type",
        "initial_instance_count",
        "instance_type",
        "region",
        "serializer",
        "deserializer",
        "accelerator_type",
        "endpoint_name",
        "inference_component_name",
        "tags",
        "kms_key",
        "wait",
        "data_capture_config",
        "async_inference_config",
        "serverless_inference_config",
        "volume_size",
        "model_data_download_timeout",
        "container_startup_health_check_timeout",
        "inference_recommendation_id",
        "explainer_config",
        "tolerate_vulnerable_model",
        "tolerate_deprecated_model",
        "sagemaker_session",
        "training_instance_type",
        "accept_eula",
        "model_reference_arn",
        "endpoint_logging",
        "resources",
        "endpoint_type",
        "config_name",
        "routing_config",
        "specs",
        "model_access_configs",
        "inference_ami_version",
    ]

    # Fields used for model resolution that are not forwarded as deploy kwargs.
    SERIALIZATION_EXCLUSION_SET = {
        "model_id",
        "model_version",
        "model_type",
        "hub_arn",
        "region",
        "tolerate_deprecated_model",
        "tolerate_vulnerable_model",
        "sagemaker_session",
        "training_instance_type",
        "config_name",
        "model_access_configs",
    }

    def __init__(
        self,
        model_id: str,
        model_version: Optional[str] = None,
        hub_arn: Optional[str] = None,
        model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
        region: Optional[str] = None,
        initial_instance_count: Optional[int] = None,
        instance_type: Optional[str] = None,
        serializer: Optional[Any] = None,
        deserializer: Optional[Any] = None,
        accelerator_type: Optional[str] = None,
        endpoint_name: Optional[str] = None,
        inference_component_name: Optional[str] = None,
        tags: Optional[Tags] = None,
        kms_key: Optional[str] = None,
        wait: Optional[bool] = None,
        data_capture_config: Optional[Any] = None,
        async_inference_config: Optional[Any] = None,
        serverless_inference_config: Optional[Any] = None,
        volume_size: Optional[int] = None,
        model_data_download_timeout: Optional[int] = None,
        container_startup_health_check_timeout: Optional[int] = None,
        inference_recommendation_id: Optional[str] = None,
        explainer_config: Optional[Any] = None,
        tolerate_deprecated_model: Optional[bool] = None,
        tolerate_vulnerable_model: Optional[bool] = None,
        sagemaker_session: Optional[Session] = None,
        training_instance_type: Optional[str] = None,
        accept_eula: Optional[bool] = None,
        model_reference_arn: Optional[str] = None,
        endpoint_logging: Optional[bool] = None,
        resources: Optional[ResourceRequirements] = None,
        endpoint_type: Optional[EndpointType] = None,
        config_name: Optional[str] = None,
        routing_config: Optional[Dict[str, Any]] = None,
        model_access_configs: Optional[Dict[str, CoreModelAccessConfig]] = None,
        inference_ami_version: Optional[str] = None,
    ) -> None:
        """Instantiates JumpStartModelDeployKwargs object.

        Each argument mirrors the parameter of the same name on
        `JumpStartModel.deploy` and is stored verbatim, except ``tags``, which
        is passed through ``format_tags``.
        """

        # --- Model identity / resolution ---
        self.model_id = model_id
        self.model_version = model_version
        self.hub_arn = hub_arn
        self.model_type = model_type
        self.region = region
        self.config_name = config_name
        self.model_reference_arn = model_reference_arn
        self.model_access_configs = model_access_configs
        self.tolerate_vulnerable_model = tolerate_vulnerable_model
        self.tolerate_deprecated_model = tolerate_deprecated_model
        self.accept_eula = accept_eula

        # --- Compute ---
        self.initial_instance_count = initial_instance_count
        self.instance_type = instance_type
        self.training_instance_type = training_instance_type
        self.accelerator_type = accelerator_type
        self.volume_size = volume_size
        self.resources = resources
        self.inference_ami_version = inference_ami_version

        # --- Endpoint ---
        self.endpoint_name = endpoint_name
        self.endpoint_type = endpoint_type
        self.endpoint_logging = endpoint_logging
        self.inference_component_name = inference_component_name
        self.routing_config = routing_config
        self.kms_key = kms_key
        self.tags = format_tags(tags)
        self.wait = wait

        # --- Request / response handling ---
        self.serializer = serializer
        self.deserializer = deserializer

        # --- Optional inference features ---
        self.data_capture_config = data_capture_config
        self.async_inference_config = async_inference_config
        self.serverless_inference_config = serverless_inference_config
        self.explainer_config = explainer_config
        self.inference_recommendation_id = inference_recommendation_id

        # --- Timeouts ---
        self.model_data_download_timeout = model_data_download_timeout
        self.container_startup_health_check_timeout = container_startup_health_check_timeout

        # --- Session ---
        self.sagemaker_session = sagemaker_session
2376
+
2377
+
2378
class JumpStartEstimatorInitKwargs(JumpStartKwargs):
    """Data class for the inputs to `JumpStartEstimator.__init__` method.

    Every constructor argument is stored on the slot of the same name.
    ``hyperparameters``, ``metric_definitions`` and ``environment`` are
    deep-copied, and ``tags`` is normalized through ``format_tags``. The
    ``hub_content_type``, ``model_reference_arn`` and ``specs`` slots are not
    set by ``__init__``.
    """

    __slots__ = [
        "model_id",
        "model_version",
        "hub_arn",
        "model_type",
        "instance_type",
        "instance_count",
        "region",
        "image_uri",
        "model_uri",
        "source_dir",
        "entry_point",
        "hyperparameters",
        "metric_definitions",
        "role",
        "keep_alive_period_in_seconds",
        "volume_size",
        "volume_kms_key",
        "max_run",
        "input_mode",
        "output_path",
        "output_kms_key",
        "base_job_name",
        "sagemaker_session",
        "tags",
        "subnets",
        "security_group_ids",
        "model_channel_name",
        "encrypt_inter_container_traffic",
        "use_spot_instances",
        "max_wait",
        "checkpoint_s3_uri",
        "checkpoint_local_path",
        "enable_network_isolation",
        "rules",
        "debugger_hook_config",
        "tensorboard_output_config",
        "enable_sagemaker_metrics",
        "profiler_config",
        "disable_profiler",
        "environment",
        "max_retry_attempts",
        "git_config",
        "container_log_level",
        "code_location",
        "dependencies",
        "instance_groups",
        "training_repository_access_mode",
        "training_repository_credentials_provider_arn",
        "tolerate_deprecated_model",
        "tolerate_vulnerable_model",
        "container_entry_point",
        "container_arguments",
        "disable_output_compression",
        "enable_infra_check",
        "enable_remote_debug",
        "config_name",
        "enable_session_tag_chaining",
        "hub_content_type",
        "model_reference_arn",
        "specs",
        "training_plan",
    ]

    # Fields used for model resolution that are not forwarded as Estimator kwargs.
    SERIALIZATION_EXCLUSION_SET = {
        "region",
        "tolerate_deprecated_model",
        "tolerate_vulnerable_model",
        "model_id",
        "model_version",
        "hub_arn",
        "model_type",
        "hub_content_type",
        "config_name",
    }

    def __init__(
        self,
        model_id: str,
        model_version: Optional[str] = None,
        hub_arn: Optional[str] = None,
        model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
        region: Optional[str] = None,
        image_uri: Optional[Union[str, Any]] = None,
        role: Optional[str] = None,
        instance_count: Optional[Union[int, Any]] = None,
        instance_type: Optional[Union[str, Any]] = None,
        keep_alive_period_in_seconds: Optional[Union[int, Any]] = None,
        volume_size: Optional[Union[int, Any]] = None,
        volume_kms_key: Optional[Union[str, Any]] = None,
        max_run: Optional[Union[int, Any]] = None,
        input_mode: Optional[Union[str, Any]] = None,
        output_path: Optional[Union[str, Any]] = None,
        output_kms_key: Optional[Union[str, Any]] = None,
        base_job_name: Optional[str] = None,
        sagemaker_session: Optional[Any] = None,
        hyperparameters: Optional[Dict[str, Union[str, Any]]] = None,
        tags: Optional[Tags] = None,
        subnets: Optional[List[Union[str, Any]]] = None,
        security_group_ids: Optional[List[Union[str, Any]]] = None,
        model_uri: Optional[str] = None,
        model_channel_name: Optional[Union[str, Any]] = None,
        metric_definitions: Optional[List[Dict[str, Union[str, Any]]]] = None,
        encrypt_inter_container_traffic: Union[bool, Any] = None,
        use_spot_instances: Optional[Union[bool, Any]] = None,
        max_wait: Optional[Union[int, Any]] = None,
        checkpoint_s3_uri: Optional[Union[str, Any]] = None,
        checkpoint_local_path: Optional[Union[str, Any]] = None,
        enable_network_isolation: Union[bool, Any] = None,
        rules: Optional[List[Any]] = None,
        debugger_hook_config: Optional[Union[Any, bool]] = None,
        tensorboard_output_config: Optional[Any] = None,
        enable_sagemaker_metrics: Optional[Union[bool, Any]] = None,
        profiler_config: Optional[Any] = None,
        disable_profiler: Optional[bool] = None,
        environment: Optional[Dict[str, Union[str, Any]]] = None,
        max_retry_attempts: Optional[Union[int, Any]] = None,
        source_dir: Optional[Union[str, Any]] = None,
        git_config: Optional[Dict[str, str]] = None,
        container_log_level: Optional[Union[int, Any]] = None,
        code_location: Optional[str] = None,
        entry_point: Optional[Union[str, Any]] = None,
        dependencies: Optional[List[str]] = None,
        instance_groups: Optional[List[Any]] = None,
        training_repository_access_mode: Optional[Union[str, Any]] = None,
        training_repository_credentials_provider_arn: Optional[Union[str, Any]] = None,
        tolerate_vulnerable_model: Optional[bool] = None,
        tolerate_deprecated_model: Optional[bool] = None,
        container_entry_point: Optional[List[str]] = None,
        container_arguments: Optional[List[str]] = None,
        disable_output_compression: Optional[bool] = None,
        enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
        enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
        config_name: Optional[str] = None,
        enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
        training_plan: Optional[Union[str, PipelineVariable]] = None,
    ) -> None:
        """Instantiates JumpStartEstimatorInitKwargs object.

        Each argument mirrors the parameter of the same name on
        `JumpStartEstimator.__init__` and is stored verbatim, except
        ``hyperparameters``, ``metric_definitions`` and ``environment``
        (deep-copied) and ``tags`` (passed through ``format_tags``).
        """

        # --- Model identity / resolution ---
        self.model_id = model_id
        self.model_version = model_version
        self.hub_arn = hub_arn
        self.model_type = model_type
        self.region = region
        self.config_name = config_name
        self.tolerate_vulnerable_model = tolerate_vulnerable_model
        self.tolerate_deprecated_model = tolerate_deprecated_model

        # --- Compute ---
        self.instance_type = instance_type
        self.instance_count = instance_count
        self.instance_groups = instance_groups
        self.keep_alive_period_in_seconds = keep_alive_period_in_seconds
        self.volume_size = volume_size
        self.volume_kms_key = volume_kms_key
        self.training_plan = training_plan

        # --- Training container and inputs (copied so the caller's objects are not shared) ---
        self.image_uri = image_uri
        self.model_uri = model_uri
        self.model_channel_name = model_channel_name
        self.hyperparameters = deepcopy(hyperparameters)
        self.metric_definitions = deepcopy(metric_definitions)
        self.environment = deepcopy(environment)
        self.input_mode = input_mode
        self.container_entry_point = container_entry_point
        self.container_arguments = container_arguments
        self.training_repository_access_mode = training_repository_access_mode
        self.training_repository_credentials_provider_arn = (
            training_repository_credentials_provider_arn
        )

        # --- Job control ---
        self.role = role
        self.base_job_name = base_job_name
        self.max_run = max_run
        self.max_retry_attempts = max_retry_attempts
        self.use_spot_instances = use_spot_instances
        self.max_wait = max_wait
        self.checkpoint_s3_uri = checkpoint_s3_uri
        self.checkpoint_local_path = checkpoint_local_path
        self.enable_infra_check = enable_infra_check
        self.tags = format_tags(tags)
        self.enable_session_tag_chaining = enable_session_tag_chaining
        self.sagemaker_session = sagemaker_session

        # --- Output ---
        self.output_path = output_path
        self.output_kms_key = output_kms_key
        self.disable_output_compression = disable_output_compression

        # --- Networking / security ---
        self.subnets = subnets
        self.security_group_ids = security_group_ids
        self.encrypt_inter_container_traffic = encrypt_inter_container_traffic
        self.enable_network_isolation = enable_network_isolation

        # --- Debugging / profiling / metrics ---
        self.rules = rules
        self.debugger_hook_config = debugger_hook_config
        self.tensorboard_output_config = tensorboard_output_config
        self.enable_sagemaker_metrics = enable_sagemaker_metrics
        self.profiler_config = profiler_config
        self.disable_profiler = disable_profiler
        self.enable_remote_debug = enable_remote_debug

        # --- Training code packaging ---
        self.source_dir = source_dir
        self.entry_point = entry_point
        self.git_config = git_config
        self.container_log_level = container_log_level
        self.code_location = code_location
        self.dependencies = dependencies
2580
+
2581
+
2582
class JumpStartEstimatorFitKwargs(JumpStartKwargs):
    """Data class for the inputs to `JumpStartEstimator.fit` method."""

    __slots__ = [
        "model_id",
        "model_version",
        "hub_arn",
        "model_type",
        "region",
        "inputs",
        "wait",
        "logs",
        "job_name",
        "experiment_config",
        "tolerate_deprecated_model",
        "tolerate_vulnerable_model",
        "sagemaker_session",
        "config_name",
        "specs",
    ]

    # Fields used for model resolution that are not forwarded as fit() kwargs.
    SERIALIZATION_EXCLUSION_SET = {
        "model_id",
        "model_version",
        "hub_arn",
        "model_type",
        "region",
        "tolerate_deprecated_model",
        "tolerate_vulnerable_model",
        "sagemaker_session",
        "config_name",
    }

    def __init__(
        self,
        model_id: str,
        model_version: Optional[str] = None,
        hub_arn: Optional[str] = None,
        model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
        region: Optional[str] = None,
        inputs: Optional[Union[str, Dict, Any]] = None,
        wait: Optional[bool] = None,
        logs: Optional[str] = None,
        job_name: Optional[str] = None,
        experiment_config: Optional[Dict[str, str]] = None,
        tolerate_deprecated_model: Optional[bool] = None,
        tolerate_vulnerable_model: Optional[bool] = None,
        sagemaker_session: Optional[Session] = None,
        config_name: Optional[str] = None,
    ) -> None:
        """Instantiates JumpStartEstimatorFitKwargs object.

        Each argument mirrors the parameter of the same name on
        `JumpStartEstimator.fit` and is stored verbatim on the slot of the same
        name. The ``specs`` slot is not set here.
        """

        self.model_id = model_id
        self.model_version = model_version
        self.hub_arn = hub_arn
        self.model_type = model_type
        self.region = region
        self.inputs = inputs
        self.wait = wait
        self.logs = logs
        self.job_name = job_name
        self.experiment_config = experiment_config
        self.tolerate_deprecated_model = tolerate_deprecated_model
        self.tolerate_vulnerable_model = tolerate_vulnerable_model
        self.sagemaker_session = sagemaker_session
        self.config_name = config_name
2648
+
2649
+
2650
class JumpStartEstimatorDeployKwargs(JumpStartKwargs):
    """Data class for the inputs to `JumpStartEstimator.deploy` method."""

    __slots__ = [
        "model_id",
        "model_version",
        "hub_arn",
        "instance_type",
        "initial_instance_count",
        "region",
        "image_uri",
        "source_dir",
        "entry_point",
        "env",
        "predictor_cls",
        "serializer",
        "deserializer",
        "accelerator_type",
        "endpoint_name",
        "tags",
        "kms_key",
        "wait",
        "data_capture_config",
        "async_inference_config",
        "serverless_inference_config",
        "volume_size",
        "model_data_download_timeout",
        "container_startup_health_check_timeout",
        "inference_recommendation_id",
        "explainer_config",
        "role",
        "vpc_config",
        "sagemaker_session",
        "enable_network_isolation",
        "model_kms_key",
        "image_config",
        "code_location",
        "container_log_level",
        "dependencies",
        "git_config",
        "tolerate_deprecated_model",
        "tolerate_vulnerable_model",
        "model_name",
        "use_compiled_model",
        "config_name",
        "specs",
    ]

    # Fields used for model resolution that are not forwarded as deploy kwargs.
    SERIALIZATION_EXCLUSION_SET = {
        "tolerate_vulnerable_model",
        "tolerate_deprecated_model",
        "region",
        "model_id",
        "model_version",
        "hub_arn",
        "sagemaker_session",
        "config_name",
    }

    def __init__(
        self,
        model_id: str,
        model_version: Optional[str] = None,
        hub_arn: Optional[str] = None,
        region: Optional[str] = None,
        initial_instance_count: Optional[int] = None,
        instance_type: Optional[str] = None,
        serializer: Optional[Any] = None,
        deserializer: Optional[Any] = None,
        accelerator_type: Optional[str] = None,
        endpoint_name: Optional[str] = None,
        tags: Optional[Tags] = None,
        kms_key: Optional[str] = None,
        wait: Optional[bool] = None,
        data_capture_config: Optional[Any] = None,
        async_inference_config: Optional[Any] = None,
        serverless_inference_config: Optional[Any] = None,
        volume_size: Optional[int] = None,
        model_data_download_timeout: Optional[int] = None,
        container_startup_health_check_timeout: Optional[int] = None,
        inference_recommendation_id: Optional[str] = None,
        explainer_config: Optional[Any] = None,
        image_uri: Optional[Union[str, Any]] = None,
        role: Optional[str] = None,
        predictor_cls: Optional[Callable] = None,
        env: Optional[Dict[str, Union[str, Any]]] = None,
        model_name: Optional[str] = None,
        vpc_config: Optional[Dict[str, List[Union[str, Any]]]] = None,
        sagemaker_session: Optional[Any] = None,
        enable_network_isolation: Union[bool, Any] = None,
        model_kms_key: Optional[str] = None,
        image_config: Optional[Dict[str, Union[str, Any]]] = None,
        source_dir: Optional[str] = None,
        code_location: Optional[str] = None,
        entry_point: Optional[str] = None,
        container_log_level: Optional[Union[int, Any]] = None,
        dependencies: Optional[List[str]] = None,
        git_config: Optional[Dict[str, str]] = None,
        tolerate_deprecated_model: Optional[bool] = None,
        tolerate_vulnerable_model: Optional[bool] = None,
        use_compiled_model: bool = False,
        config_name: Optional[str] = None,
    ) -> None:
        """Instantiates JumpStartEstimatorDeployKwargs object.

        Each argument mirrors the parameter of the same name on
        `JumpStartEstimator.deploy` and is stored on the slot of the same name,
        except that ``env`` is deep-copied and ``tags`` is passed through
        ``format_tags``. The ``specs`` slot is not set here.
        """

        self.model_id = model_id
        self.model_version = model_version
        self.hub_arn = hub_arn
        self.instance_type = instance_type
        self.initial_instance_count = initial_instance_count
        self.region = region
        self.image_uri = image_uri
        self.source_dir = source_dir
        self.entry_point = entry_point
        # Deep copy so this holder does not share the env dict with the caller.
        self.env = deepcopy(env)
        self.predictor_cls = predictor_cls
        self.serializer = serializer
        self.deserializer = deserializer
        self.accelerator_type = accelerator_type
        self.endpoint_name = endpoint_name
        self.tags = format_tags(tags)
        self.kms_key = kms_key
        self.wait = wait
        self.data_capture_config = data_capture_config
        self.async_inference_config = async_inference_config
        self.serverless_inference_config = serverless_inference_config
        self.volume_size = volume_size
        self.model_data_download_timeout = model_data_download_timeout
        self.container_startup_health_check_timeout = container_startup_health_check_timeout
        self.inference_recommendation_id = inference_recommendation_id
        self.explainer_config = explainer_config
        self.role = role
        self.model_name = model_name
        self.vpc_config = vpc_config
        self.sagemaker_session = sagemaker_session
        self.enable_network_isolation = enable_network_isolation
        self.model_kms_key = model_kms_key
        self.image_config = image_config
        self.code_location = code_location
        self.container_log_level = container_log_level
        self.dependencies = dependencies
        self.git_config = git_config
        self.tolerate_deprecated_model = tolerate_deprecated_model
        self.tolerate_vulnerable_model = tolerate_vulnerable_model
        self.use_compiled_model = use_compiled_model
        self.config_name = config_name
2796
+
2797
+
2798
class JumpStartModelRegisterKwargs(JumpStartKwargs):
    """Data class for the inputs to `JumpStartModel.register` method."""

    __slots__ = [
        "tolerate_vulnerable_model",
        "tolerate_deprecated_model",
        "region",
        "model_id",
        "model_type",
        "model_version",
        "hub_arn",
        "sagemaker_session",
        "content_types",
        "response_types",
        "inference_instances",
        "transform_instances",
        "model_package_group_name",
        "image_uri",
        "model_metrics",
        "metadata_properties",
        "approval_status",
        "description",
        "drift_check_baselines",
        "customer_metadata_properties",
        "validation_specification",
        "domain",
        "task",
        "sample_payload_url",
        "framework",
        "framework_version",
        "nearest_model_name",
        "data_input_configuration",
        "skip_model_validation",
        "source_uri",
        "model_life_cycle",
        "config_name",
        "model_card",
        "accept_eula",
        "specs",
    ]

    # Fields used for model resolution that are not forwarded as register kwargs.
    SERIALIZATION_EXCLUSION_SET = {
        "tolerate_vulnerable_model",
        "tolerate_deprecated_model",
        "region",
        "model_id",
        "model_version",
        "hub_arn",
        "sagemaker_session",
        "config_name",
    }

    def __init__(
        self,
        model_id: str,
        model_version: Optional[str] = None,
        hub_arn: Optional[str] = None,
        region: Optional[str] = None,
        model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
        tolerate_deprecated_model: Optional[bool] = None,
        tolerate_vulnerable_model: Optional[bool] = None,
        sagemaker_session: Optional[Any] = None,
        content_types: Optional[List[str]] = None,
        response_types: Optional[List[str]] = None,
        inference_instances: Optional[List[str]] = None,
        transform_instances: Optional[List[str]] = None,
        model_package_group_name: Optional[str] = None,
        image_uri: Optional[str] = None,
        model_metrics: Optional[ModelMetrics] = None,
        metadata_properties: Optional[MetadataProperties] = None,
        approval_status: Optional[str] = None,
        description: Optional[str] = None,
        drift_check_baselines: Optional[DriftCheckBaselines] = None,
        customer_metadata_properties: Optional[Dict[str, str]] = None,
        validation_specification: Optional[str] = None,
        domain: Optional[str] = None,
        task: Optional[str] = None,
        sample_payload_url: Optional[str] = None,
        framework: Optional[str] = None,
        framework_version: Optional[str] = None,
        nearest_model_name: Optional[str] = None,
        data_input_configuration: Optional[str] = None,
        skip_model_validation: Optional[str] = None,
        source_uri: Optional[str] = None,
        model_life_cycle: Optional[ModelLifeCycle] = None,
        config_name: Optional[str] = None,
        model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None,
        accept_eula: Optional[bool] = None,
    ) -> None:
        """Instantiates JumpStartModelRegisterKwargs object.

        Each argument mirrors the parameter of the same name on
        `JumpStartModel.register` and is stored verbatim on the slot of the
        same name. The ``specs`` slot is not set here.
        """

        self.model_id = model_id
        self.model_version = model_version
        self.hub_arn = hub_arn
        self.model_type = model_type
        self.region = region
        self.sagemaker_session = sagemaker_session
        self.tolerate_deprecated_model = tolerate_deprecated_model
        self.tolerate_vulnerable_model = tolerate_vulnerable_model
        self.content_types = content_types
        self.response_types = response_types
        self.inference_instances = inference_instances
        self.transform_instances = transform_instances
        self.model_package_group_name = model_package_group_name
        self.image_uri = image_uri
        self.model_metrics = model_metrics
        self.metadata_properties = metadata_properties
        self.approval_status = approval_status
        self.description = description
        self.drift_check_baselines = drift_check_baselines
        self.customer_metadata_properties = customer_metadata_properties
        self.validation_specification = validation_specification
        self.domain = domain
        self.task = task
        self.sample_payload_url = sample_payload_url
        self.framework = framework
        self.framework_version = framework_version
        self.nearest_model_name = nearest_model_name
        self.data_input_configuration = data_input_configuration
        self.skip_model_validation = skip_model_validation
        self.source_uri = source_uri
        # Previously accepted but never stored, so to_kwargs_dict silently dropped it.
        self.model_life_cycle = model_life_cycle
        self.config_name = config_name
        self.model_card = model_card
        self.accept_eula = accept_eula
2923
+
2924
+
2925
class BaseDeploymentConfigDataHolder(JumpStartDataHolderType):
    """Base class for Deployment Config Data.

    Subclasses list their serializable attributes in ``__slots__``;
    :meth:`to_json` emits every such attribute that is currently set, keyed by
    its PascalCase name.
    """

    def _convert_to_pascal_case(self, attr_name: str) -> str:
        """Converts a snake_case attribute name into a PascalCased string.

        For example ``"model_data_download_timeout"`` becomes
        ``"ModelDataDownloadTimeout"``.

        Args:
            attr_name (str): The snake_case attribute name.
        Returns:
            str: The PascalCased attribute name.
        """
        return attr_name.replace("_", " ").title().replace(" ", "")

    def to_json(self) -> Dict[str, Any]:
        """Represents ``This`` object as JSON.

        Only attributes named in ``__slots__`` that are actually set on the
        instance are included (unset slots are skipped via ``hasattr``).
        Keys are PascalCased and values are converted with
        :meth:`_val_to_json`.
        """
        json_obj = {}
        for att in self.__slots__:
            if hasattr(self, att):
                cur_val = getattr(self, att)
                json_obj[self._convert_to_pascal_case(att)] = self._val_to_json(cur_val)
        return json_obj

    def _val_to_json(self, val: Any) -> Any:
        """Converts the given value to JSON.

        Nested data holders are serialized via their own ``to_json``; lists and
        dicts are converted element by element; any other value is returned
        unchanged. The input value is never mutated.

        Args:
            val (Any): The value to convert.
        Returns:
            Any: The converted json value.
        """
        if issubclass(type(val), JumpStartDataHolderType):
            if isinstance(val, JumpStartBenchmarkStat):
                # Benchmark stat names are emitted in a human-readable form
                # ("some_metric" -> "Some Metric"). Work on a shallow copy so
                # serializing does not rewrite the caller's object.
                import copy

                val = copy.copy(val)
                val.name = val.name.replace("_", " ").title()
            return val.to_json()
        if isinstance(val, list):
            return [self._val_to_json(obj) for obj in val]
        if isinstance(val, dict):
            dict_obj = {}
            for k, v in val.items():
                # Only keys whose values are nested data holders are
                # PascalCased; all other keys are kept verbatim.
                if isinstance(v, JumpStartDataHolderType):
                    dict_obj[self._convert_to_pascal_case(k)] = self._val_to_json(v)
                else:
                    dict_obj[k] = self._val_to_json(v)
            return dict_obj
        return val
2974
+
2975
+
2976
class DeploymentArgs(BaseDeploymentConfigDataHolder):
    """Dataclass representing a Deployment Args.

    Gathers deployment settings from up to three optional sources: model init
    kwargs, deploy kwargs and a resolved metadata config. Attributes backed by
    a source that was not supplied are simply left unset, which makes
    ``to_json`` skip them.
    """

    __slots__ = [
        "image_uri",
        "model_data",
        "model_package_arn",
        "environment",
        "instance_type",
        "compute_resource_requirements",
        "model_data_download_timeout",
        "container_startup_health_check_timeout",
        "additional_data_sources",
    ]

    def __init__(
        self,
        init_kwargs: Optional[JumpStartModelInitKwargs] = None,
        deploy_kwargs: Optional[JumpStartModelDeployKwargs] = None,
        resolved_config: Optional[Dict[str, Any]] = None,
    ):
        """Instantiates DeploymentArgs object.

        Args:
            init_kwargs: Supplies image URI, model data, model package ARN,
                instance type, environment and, when its ``resources`` is not
                ``None``, the compute resource requirements.
            deploy_kwargs: Supplies the model-data download timeout and the
                container startup health-check timeout.
            resolved_config: Mapping read (via ``.get``) for the default and
                supported inference instance types and the hosting additional
                data sources.
        """
        if init_kwargs is not None:
            self._load_from_init_kwargs(init_kwargs)
        if deploy_kwargs is not None:
            self._load_from_deploy_kwargs(deploy_kwargs)
        if resolved_config is not None:
            self._load_from_resolved_config(resolved_config)

    def _load_from_init_kwargs(self, source: JumpStartModelInitKwargs) -> None:
        """Copies the model-level settings out of the init kwargs."""
        self.image_uri = source.image_uri
        self.model_data = source.model_data
        self.model_package_arn = source.model_package_arn
        self.instance_type = source.instance_type
        self.environment = source.env
        resources = source.resources
        if resources is not None:
            self.compute_resource_requirements = resources.get_compute_resource_requirements()

    def _load_from_deploy_kwargs(self, source: JumpStartModelDeployKwargs) -> None:
        """Copies the timeout settings out of the deploy kwargs."""
        self.model_data_download_timeout = source.model_data_download_timeout
        self.container_startup_health_check_timeout = (
            source.container_startup_health_check_timeout
        )

    def _load_from_resolved_config(self, source: Dict[str, Any]) -> None:
        """Copies instance-type and data-source settings out of the resolved config."""
        lookup = source.get
        # NOTE: these two attributes are not listed in ``__slots__``; they are
        # stored on the instance but ``to_json`` (which iterates ``__slots__``)
        # does not emit them.
        self.default_instance_type = lookup("default_inference_instance_type")
        self.supported_instance_types = lookup("supported_inference_instance_types")
        self.additional_data_sources = lookup("hosting_additional_data_sources")
3019
+
3020
+
3021
class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder):
    """Dataclass representing a Deployment Config Metadata"""

    __slots__ = [
        "deployment_config_name",
        "deployment_args",
        "acceleration_configs",
        "benchmark_metrics",
    ]

    def __init__(
        self,
        config_name: Optional[str] = None,
        metadata_config: Optional[JumpStartMetadataConfig] = None,
        init_kwargs: Optional[JumpStartModelInitKwargs] = None,
        deploy_kwargs: Optional[JumpStartModelDeployKwargs] = None,
    ):
        """Instantiates DeploymentConfigMetadata object.

        Args:
            config_name (Optional[str]): Name stored as ``deployment_config_name``.
            metadata_config (Optional[JumpStartMetadataConfig]): Supplies the
                resolved config passed to :class:`DeploymentArgs`, plus the
                benchmark metrics and acceleration configs. When ``None``,
                ``DeploymentArgs`` is built without a resolved config and
                ``benchmark_metrics`` / ``acceleration_configs`` are ``None``.
            init_kwargs (Optional[JumpStartModelInitKwargs]): Forwarded to
                :class:`DeploymentArgs`.
            deploy_kwargs (Optional[JumpStartModelDeployKwargs]): Forwarded to
                :class:`DeploymentArgs`.
        """
        self.deployment_config_name = config_name
        # ``metadata_config`` is optional in the signature; guard each
        # dereference instead of raising AttributeError on the default.
        resolved_config = None if metadata_config is None else metadata_config.resolved_config
        self.deployment_args = DeploymentArgs(init_kwargs, deploy_kwargs, resolved_config)
        if metadata_config is None:
            self.benchmark_metrics = None
            self.acceleration_configs = None
        else:
            self.benchmark_metrics = metadata_config.benchmark_metrics
            self.acceleration_configs = metadata_config.acceleration_configs