sagemaker-core 1.0.62__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (362)
  1. sagemaker/core/__init__.py +16 -0
  2. sagemaker/core/_studio.py +116 -0
  3. sagemaker/core/_version.py +11 -0
  4. sagemaker/core/accept_types.py +131 -0
  5. sagemaker/core/analytics.py +744 -0
  6. sagemaker/core/apiutils/__init__.py +13 -0
  7. sagemaker/core/apiutils/_base_types.py +228 -0
  8. sagemaker/core/apiutils/_boto_functions.py +130 -0
  9. sagemaker/core/apiutils/_utils.py +34 -0
  10. sagemaker/core/base_deserializers.py +35 -0
  11. sagemaker/core/base_serializers.py +35 -0
  12. sagemaker/core/clarify/__init__.py +2898 -0
  13. sagemaker/core/collection.py +467 -0
  14. sagemaker/core/common_utils.py +2281 -0
  15. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  16. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  17. sagemaker/core/config/__init__.py +181 -0
  18. sagemaker/core/config/config.py +238 -0
  19. sagemaker/core/config/config_manager.py +595 -0
  20. sagemaker/core/config/config_schema.py +1220 -0
  21. sagemaker/core/config/config_utils.py +297 -0
  22. {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
  23. sagemaker/core/constants.py +73 -0
  24. sagemaker/core/content_types.py +137 -0
  25. sagemaker/core/debugger/__init__.py +39 -0
  26. sagemaker/core/debugger/debugger.py +945 -0
  27. sagemaker/core/debugger/framework_profile.py +292 -0
  28. sagemaker/core/debugger/metrics_config.py +468 -0
  29. sagemaker/core/debugger/profiler.py +42 -0
  30. sagemaker/core/debugger/profiler_config.py +190 -0
  31. sagemaker/core/debugger/profiler_constants.py +40 -0
  32. sagemaker/core/debugger/utils.py +148 -0
  33. sagemaker/core/deprecations.py +254 -0
  34. sagemaker/core/deserializers/__init__.py +10 -0
  35. sagemaker/core/deserializers/base.py +424 -0
  36. sagemaker/core/deserializers/implementations.py +157 -0
  37. sagemaker/core/drift_check_baselines.py +106 -0
  38. sagemaker/core/enums.py +51 -0
  39. sagemaker/core/environment_variables.py +101 -0
  40. sagemaker/core/exceptions.py +108 -0
  41. sagemaker/core/experiments/__init__.py +53 -0
  42. sagemaker/core/experiments/_api_types.py +251 -0
  43. sagemaker/core/experiments/_environment.py +124 -0
  44. sagemaker/core/experiments/_helper.py +294 -0
  45. sagemaker/core/experiments/_metrics.py +333 -0
  46. sagemaker/core/experiments/_run_context.py +58 -0
  47. sagemaker/core/experiments/_utils.py +216 -0
  48. sagemaker/core/experiments/experiment.py +244 -0
  49. sagemaker/core/experiments/run.py +970 -0
  50. sagemaker/core/experiments/trial.py +296 -0
  51. sagemaker/core/experiments/trial_component.py +387 -0
  52. sagemaker/core/explainer/__init__.py +24 -0
  53. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  54. sagemaker/core/explainer/explainer_config.py +44 -0
  55. sagemaker/core/fw_utils.py +1176 -0
  56. sagemaker/core/git_utils.py +349 -0
  57. sagemaker/core/helper/pipeline_variable.py +82 -0
  58. sagemaker/core/helper/session_helper.py +2965 -0
  59. sagemaker/core/huggingface/__init__.py +29 -0
  60. sagemaker/core/huggingface/llm_utils.py +150 -0
  61. sagemaker/core/huggingface/processing.py +139 -0
  62. sagemaker/core/huggingface/training_compiler/config.py +167 -0
  63. sagemaker/core/hyperparameters.py +172 -0
  64. sagemaker/core/image_retriever/__init__.py +3 -0
  65. sagemaker/core/image_retriever/image_retriever.py +640 -0
  66. sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
  67. sagemaker/core/image_retriever/test.py +7 -0
  68. sagemaker/core/image_uri_config/__init__.py +13 -0
  69. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  70. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  71. sagemaker/core/image_uri_config/chainer.json +104 -0
  72. sagemaker/core/image_uri_config/clarify.json +39 -0
  73. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  74. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  75. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  76. sagemaker/core/image_uri_config/debugger.json +34 -0
  77. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  78. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  79. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  80. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  81. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  82. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  83. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  84. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  85. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
  86. sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
  87. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  88. sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
  89. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  90. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  91. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  92. sagemaker/core/image_uri_config/huggingface.json +2138 -0
  93. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  94. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  95. sagemaker/core/image_uri_config/image-classification.json +50 -0
  96. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  97. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  98. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  99. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  100. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  101. sagemaker/core/image_uri_config/kmeans.json +50 -0
  102. sagemaker/core/image_uri_config/knn.json +50 -0
  103. sagemaker/core/image_uri_config/lda.json +26 -0
  104. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  105. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  106. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  107. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  108. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  109. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  110. sagemaker/core/image_uri_config/ntm.json +50 -0
  111. sagemaker/core/image_uri_config/object-detection.json +50 -0
  112. sagemaker/core/image_uri_config/object2vec.json +50 -0
  113. sagemaker/core/image_uri_config/pca.json +50 -0
  114. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  115. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  116. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  117. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  118. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  119. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  120. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  121. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  122. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  123. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  124. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
  125. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  126. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  127. sagemaker/core/image_uri_config/sklearn.json +446 -0
  128. sagemaker/core/image_uri_config/spark.json +280 -0
  129. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  130. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  131. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  132. sagemaker/core/image_uri_config/vw.json +25 -0
  133. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  134. sagemaker/core/image_uri_config/xgboost.json +888 -0
  135. sagemaker/core/image_uris.py +810 -0
  136. sagemaker/core/inference_config.py +144 -0
  137. sagemaker/core/inference_recommender/__init__.py +18 -0
  138. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  139. sagemaker/core/inputs.py +366 -0
  140. sagemaker/core/instance_group.py +61 -0
  141. sagemaker/core/instance_types.py +164 -0
  142. sagemaker/core/instance_types_gpu_info.py +43 -0
  143. sagemaker/core/interactive_apps/__init__.py +41 -0
  144. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  145. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  146. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  147. sagemaker/core/iterators.py +186 -0
  148. sagemaker/core/job.py +380 -0
  149. sagemaker/core/jumpstart/__init__.py +156 -0
  150. sagemaker/core/jumpstart/accessors.py +390 -0
  151. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  152. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  153. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  154. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  155. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  156. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  157. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  158. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  159. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  160. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  161. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  162. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  163. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  164. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  165. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  166. sagemaker/core/jumpstart/cache.py +663 -0
  167. sagemaker/core/jumpstart/configs.py +50 -0
  168. sagemaker/core/jumpstart/constants.py +198 -0
  169. sagemaker/core/jumpstart/deserializers.py +81 -0
  170. sagemaker/core/jumpstart/document.py +76 -0
  171. sagemaker/core/jumpstart/enums.py +168 -0
  172. sagemaker/core/jumpstart/exceptions.py +236 -0
  173. sagemaker/core/jumpstart/factory/utils.py +833 -0
  174. sagemaker/core/jumpstart/filters.py +597 -0
  175. sagemaker/core/jumpstart/hub/constants.py +16 -0
  176. sagemaker/core/jumpstart/hub/hub.py +291 -0
  177. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  178. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  179. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  180. sagemaker/core/jumpstart/hub/types.py +35 -0
  181. sagemaker/core/jumpstart/hub/utils.py +260 -0
  182. sagemaker/core/jumpstart/models.py +499 -0
  183. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  184. sagemaker/core/jumpstart/parameters.py +20 -0
  185. sagemaker/core/jumpstart/payload_utils.py +239 -0
  186. sagemaker/core/jumpstart/region_config.json +163 -0
  187. sagemaker/core/jumpstart/search.py +171 -0
  188. sagemaker/core/jumpstart/serializers.py +81 -0
  189. sagemaker/core/jumpstart/session_utils.py +234 -0
  190. sagemaker/core/jumpstart/types.py +3044 -0
  191. sagemaker/core/jumpstart/utils.py +1731 -0
  192. sagemaker/core/jumpstart/validators.py +257 -0
  193. sagemaker/core/lambda_helper.py +312 -0
  194. sagemaker/core/lineage/__init__.py +42 -0
  195. sagemaker/core/lineage/_api_types.py +239 -0
  196. sagemaker/core/lineage/_utils.py +49 -0
  197. sagemaker/core/lineage/action.py +345 -0
  198. sagemaker/core/lineage/artifact.py +646 -0
  199. sagemaker/core/lineage/association.py +190 -0
  200. sagemaker/core/lineage/context.py +505 -0
  201. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  202. sagemaker/core/lineage/query.py +732 -0
  203. sagemaker/core/lineage/visualizer.py +346 -0
  204. sagemaker/core/local/__init__.py +18 -0
  205. sagemaker/core/local/data.py +413 -0
  206. sagemaker/core/local/entities.py +678 -0
  207. sagemaker/core/local/exceptions.py +17 -0
  208. sagemaker/core/local/image.py +1243 -0
  209. sagemaker/core/local/local_session.py +739 -0
  210. sagemaker/core/local/utils.py +245 -0
  211. sagemaker/core/logs.py +181 -0
  212. sagemaker/core/metadata_properties.py +56 -0
  213. sagemaker/core/metric_definitions.py +91 -0
  214. sagemaker/core/mlflow/__init__.py +38 -0
  215. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  216. sagemaker/core/model_card/__init__.py +26 -0
  217. sagemaker/core/model_life_cycle.py +51 -0
  218. sagemaker/core/model_metrics.py +160 -0
  219. sagemaker/core/model_monitor/__init__.py +66 -0
  220. sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
  221. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  222. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  223. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  224. sagemaker/core/model_monitor/dataset_format.py +102 -0
  225. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  226. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  227. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  228. sagemaker/core/model_monitor/utils.py +793 -0
  229. sagemaker/core/model_registry.py +480 -0
  230. sagemaker/core/model_uris.py +97 -0
  231. sagemaker/core/modules/__init__.py +19 -0
  232. sagemaker/core/modules/configs.py +226 -0
  233. sagemaker/core/modules/constants.py +37 -0
  234. sagemaker/core/modules/distributed.py +182 -0
  235. sagemaker/core/modules/local_core/__init__.py +0 -0
  236. sagemaker/core/modules/local_core/local_container.py +605 -0
  237. sagemaker/core/modules/templates.py +83 -0
  238. sagemaker/core/modules/train/__init__.py +14 -0
  239. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  240. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  241. sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
  242. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  243. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  244. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  245. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  246. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  247. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  248. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  249. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  250. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  251. sagemaker/core/modules/types.py +19 -0
  252. sagemaker/core/modules/utils.py +194 -0
  253. sagemaker/core/network.py +185 -0
  254. sagemaker/core/parameter.py +173 -0
  255. sagemaker/core/payloads.py +185 -0
  256. sagemaker/core/processing.py +1597 -0
  257. sagemaker/core/remote_function/__init__.py +19 -0
  258. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  259. sagemaker/core/remote_function/client.py +1285 -0
  260. sagemaker/core/remote_function/core/__init__.py +0 -0
  261. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  262. sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
  263. sagemaker/core/remote_function/core/serialization.py +422 -0
  264. sagemaker/core/remote_function/core/stored_function.py +226 -0
  265. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  266. sagemaker/core/remote_function/errors.py +104 -0
  267. sagemaker/core/remote_function/invoke_function.py +172 -0
  268. sagemaker/core/remote_function/job.py +2140 -0
  269. sagemaker/core/remote_function/logging_config.py +38 -0
  270. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  271. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  272. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  273. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  274. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  275. sagemaker/core/remote_function/spark_config.py +149 -0
  276. sagemaker/core/resource_requirements.py +168 -0
  277. {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
  278. sagemaker/core/s3/__init__.py +41 -0
  279. sagemaker/core/s3/client.py +367 -0
  280. sagemaker/core/s3/utils.py +175 -0
  281. sagemaker/core/script_uris.py +93 -0
  282. sagemaker/core/serializers/__init__.py +11 -0
  283. sagemaker/core/serializers/base.py +510 -0
  284. sagemaker/core/serializers/implementations.py +159 -0
  285. sagemaker/core/serializers/utils.py +223 -0
  286. sagemaker/core/serverless_inference_config.py +63 -0
  287. sagemaker/core/session_settings.py +55 -0
  288. sagemaker/core/shapes/__init__.py +3 -0
  289. sagemaker/core/shapes/model_card_shapes.py +159 -0
  290. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
  291. sagemaker/core/spark/__init__.py +16 -0
  292. sagemaker/core/spark/defaults.py +16 -0
  293. sagemaker/core/spark/processing.py +1380 -0
  294. sagemaker/core/telemetry/__init__.py +23 -0
  295. sagemaker/core/telemetry/constants.py +84 -0
  296. sagemaker/core/telemetry/telemetry_logging.py +284 -0
  297. sagemaker/core/tools/__init__.py +1 -0
  298. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  299. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  300. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  301. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  302. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  303. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  304. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  305. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  306. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  307. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  308. sagemaker/core/training/__init__.py +14 -0
  309. sagemaker/core/training/configs.py +333 -0
  310. sagemaker/core/training/constants.py +37 -0
  311. sagemaker/core/training/utils.py +77 -0
  312. sagemaker/core/training_compiler/__init__.py +16 -0
  313. sagemaker/core/training_compiler/config.py +197 -0
  314. sagemaker/core/training_compiler_config.py +197 -0
  315. sagemaker/core/transformer.py +793 -0
  316. sagemaker/core/user_agent.py +76 -0
  317. sagemaker/core/utilities/__init__.py +24 -0
  318. sagemaker/core/utilities/cache.py +169 -0
  319. sagemaker/core/utilities/search_expression.py +133 -0
  320. sagemaker/core/utils/__init__.py +48 -0
  321. sagemaker/core/utils/code_injection/__init__.py +0 -0
  322. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  323. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
  324. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  325. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  326. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  327. {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
  328. sagemaker/core/workflow/__init__.py +152 -0
  329. sagemaker/core/workflow/conditions.py +313 -0
  330. sagemaker/core/workflow/entities.py +58 -0
  331. sagemaker/core/workflow/execution_variables.py +89 -0
  332. sagemaker/core/workflow/functions.py +193 -0
  333. sagemaker/core/workflow/parameters.py +222 -0
  334. sagemaker/core/workflow/pipeline_context.py +394 -0
  335. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  336. sagemaker/core/workflow/properties.py +285 -0
  337. sagemaker/core/workflow/step_outputs.py +65 -0
  338. sagemaker/core/workflow/utilities.py +507 -0
  339. sagemaker/lineage/__init__.py +33 -0
  340. sagemaker/lineage/action.py +28 -0
  341. sagemaker/lineage/artifact.py +28 -0
  342. sagemaker/lineage/context.py +28 -0
  343. sagemaker/lineage/lineage_trial_component.py +28 -0
  344. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
  345. sagemaker_core-2.1.1.dist-info/RECORD +355 -0
  346. sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
  347. sagemaker_core/_version.py +0 -3
  348. sagemaker_core/helper/session_helper.py +0 -769
  349. sagemaker_core/resources/__init__.py +0 -1
  350. sagemaker_core/shapes/__init__.py +0 -1
  351. sagemaker_core/tools/__init__.py +0 -1
  352. sagemaker_core-1.0.62.dist-info/RECORD +0 -35
  353. sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
  354. {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
  355. {sagemaker_core/helper → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
  356. {sagemaker_core/main → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  357. {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
  358. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  359. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  360. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  361. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
  362. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,936 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module stores types related to SageMaker JumpStart HubAPI requests and responses."""
14
+ from __future__ import absolute_import
15
+
16
+ from enum import Enum
17
+ import re
18
+ import json
19
+ import datetime
20
+
21
+ from typing import Any, Dict, List, Union, Optional
22
+ from sagemaker.core.jumpstart.enums import JumpStartScriptScope
23
+ from sagemaker.core.jumpstart.types import (
24
+ HubContentType,
25
+ HubArnExtractedInfo,
26
+ JumpStartConfigComponent,
27
+ JumpStartConfigRanking,
28
+ JumpStartMetadataConfig,
29
+ JumpStartMetadataConfigs,
30
+ JumpStartPredictorSpecs,
31
+ JumpStartHyperparameter,
32
+ JumpStartDataHolderType,
33
+ JumpStartEnvironmentVariable,
34
+ JumpStartSerializablePayload,
35
+ JumpStartInstanceTypeVariants,
36
+ )
37
+ from sagemaker.core.jumpstart.hub.parser_utils import (
38
+ snake_to_upper_camel,
39
+ walk_and_apply_json,
40
+ )
41
+
42
+
43
class _ComponentType(str, Enum):
    """String-valued enumeration of hub component kinds.

    Members compare equal to their plain string values (``"Inference"`` and
    ``"Training"``) because the enum mixes in ``str``.
    """

    INFERENCE = "Inference"
    TRAINING = "Training"
48
+
49
+
50
class HubDataHolderType(JumpStartDataHolderType):
    """Common base for Hub API request/response wrappers.

    Adds JSON serialization on top of ``JumpStartDataHolderType``. Attributes
    named in ``__slots__`` are emitted under their snake_case names. Unset
    attributes, ``None`` values, and slots listed in ``_non_serializable_slots``
    are left out.
    """

    def to_json(self) -> Dict[str, Any]:
        """Build a dict of the populated, serializable slot attributes.

        Values are handled as follows:

        - Nested ``JumpStartDataHolderType`` values, either held directly or as
          list items, are converted through their own ``to_json``.
        - ``datetime.datetime`` values are rendered with ``str``.
        - Any other value is copied unchanged.
        """
        serialized: Dict[str, Any] = {}
        for slot_name in self.__slots__:
            if slot_name in self._non_serializable_slots or not hasattr(self, slot_name):
                continue
            value = getattr(self, slot_name)
            if value is None:
                # Null values are deliberately omitted from the output.
                continue
            if issubclass(type(value), JumpStartDataHolderType):
                serialized[slot_name] = value.to_json()
            elif isinstance(value, list):
                serialized[slot_name] = [
                    item.to_json() if issubclass(type(item), JumpStartDataHolderType) else item
                    for item in value
                ]
            elif isinstance(value, datetime.datetime):
                serialized[slot_name] = str(value)
            else:
                serialized[slot_name] = value
        return serialized

    def __str__(self) -> str:
        """Render the object as a JSON string.

        Every key of the ``to_json`` output is passed through
        ``snake_to_upper_camel`` (presumably ``region_name`` -> ``RegionName``).
        Any object that ``json`` cannot serialize natively falls back to its own
        ``to_json`` method.
        """
        camel_cased = walk_and_apply_json(self.to_json(), snake_to_upper_camel)
        return json.dumps(camel_cased, default=lambda obj: obj.to_json())
87
+
88
+
89
class CreateHubResponse(HubDataHolderType):
    """Wrapper around the response returned by ``session.create_hub()``."""

    __slots__ = ["hub_arn"]

    def __init__(self, json_obj: Dict[str, Any]) -> None:
        """Build the wrapper from the raw API dictionary.

        Args:
            json_obj (Dict[str, Any]): Raw ``session.create_hub()`` response.
        """
        self.from_json(json_obj)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Populate attributes from the raw API dictionary.

        Args:
            json_obj (Dict[str, Any]): Raw response; must contain ``HubArn``.

        Raises:
            KeyError: If ``HubArn`` is missing.
        """
        created_arn: str = json_obj["HubArn"]
        self.hub_arn = created_arn
111
+
112
+
113
class HubContentDependency(HubDataHolderType):
    """A single dependency entry attached to a piece of hub content.

    The dependency may be a script, model artifact, dataset, or notebook.
    """

    __slots__ = ["dependency_copy_path", "dependency_origin_path", "dependency_type"]

    def __init__(self, json_obj: Dict[str, Any]) -> None:
        """Build the dependency from its raw API dictionary.

        Args:
            json_obj (Dict[str, Any]): One element of a hub content
                ``Dependencies`` list.
        """
        self.from_json(json_obj)

    def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
        """Populate the dependency fields.

        Any key that is absent defaults to an empty string.

        Args:
            json_obj (Dict[str, Any]): One element of a hub content
                ``Dependencies`` list.
        """
        lookup = json_obj.get
        self.dependency_copy_path: Optional[str] = lookup("DependencyCopyPath", "")
        self.dependency_origin_path: Optional[str] = lookup("DependencyOriginPath", "")
        self.dependency_type: Optional[str] = lookup("DependencyType", "")
139
+
140
+
141
class DescribeHubContentResponse(HubDataHolderType):
    """Data class for the Hub Content from session.describe_hub_contents()"""

    __slots__ = [
        "creation_time",
        "document_schema_version",
        "failure_reason",
        "hub_arn",
        "hub_content_arn",
        "hub_content_dependencies",
        "hub_content_description",
        "hub_content_display_name",
        "hub_content_document",
        "hub_content_markdown",
        "hub_content_name",
        "hub_content_search_keywords",
        "hub_content_status",
        "hub_content_type",
        "hub_content_version",
        "reference_min_version",
        "hub_name",
        "_region",
    ]

    # The region is derived from the hub ARN; it is not part of the API payload.
    _non_serializable_slots = ["_region"]

    def __init__(self, json_obj: Dict[str, Any]) -> None:
        """Instantiates DescribeHubContentResponse object.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of hub content description.
        """
        self.from_json(json_obj)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of hub content description.

        Raises:
            KeyError: If a required key (``CreationTime``, ``DocumentSchemaVersion``,
                ``HubArn``, ``HubContentArn``, ``HubContentDocument``, ``HubContentName``,
                ``HubContentStatus``, ``HubContentVersion`` or ``HubName``) is missing.
            ValueError: If ``HubContentType`` is not Model, ModelReference or Notebook.
        """
        self.creation_time: datetime.datetime = json_obj["CreationTime"]
        self.document_schema_version: str = json_obj["DocumentSchemaVersion"]
        self.failure_reason: Optional[str] = json_obj.get("FailureReason")
        self.hub_arn: str = json_obj["HubArn"]
        self.hub_content_arn: str = json_obj["HubContentArn"]
        # Look the key up by its string name. The previous ``.get(["Dependencies"])``
        # used a list as the key, which raises TypeError (unhashable type).
        self.hub_content_dependencies: List[HubContentDependency] = [
            HubContentDependency(dep) for dep in (json_obj.get("Dependencies") or [])
        ]
        self.hub_content_description: str = json_obj.get("HubContentDescription")
        self.hub_content_display_name: str = json_obj.get("HubContentDisplayName")
        hub_region: Optional[str] = HubArnExtractedInfo.extract_region_from_arn(self.hub_arn)
        self._region = hub_region
        self.hub_content_type: str = json_obj.get("HubContentType")
        # The document arrives as a JSON-encoded string inside the response.
        hub_content_document = json.loads(json_obj["HubContentDocument"])
        if self.hub_content_type in (HubContentType.MODEL, HubContentType.MODEL_REFERENCE):
            # Models and model references share the same document schema.
            self.hub_content_document: HubContentDocument = HubModelDocument(
                json_obj=hub_content_document,
                region=self._region,
                dependencies=self.hub_content_dependencies,
            )
        elif self.hub_content_type == HubContentType.NOTEBOOK:
            self.hub_content_document: HubContentDocument = HubNotebookDocument(
                json_obj=hub_content_document, region=self._region
            )
        else:
            raise ValueError(
                f"[{self.hub_content_type}] is not a valid HubContentType. "
                f"Should be one of: {[item.name for item in HubContentType]}."
            )

        self.hub_content_markdown: str = json_obj.get("HubContentMarkdown")
        self.hub_content_name: str = json_obj["HubContentName"]
        self.hub_content_search_keywords: List[str] = json_obj.get("HubContentSearchKeywords")
        self.hub_content_status: str = json_obj["HubContentStatus"]
        self.hub_content_version: str = json_obj["HubContentVersion"]
        # Declared in __slots__ but previously never populated; None when absent,
        # in which case to_json omits it.
        self.reference_min_version: Optional[str] = json_obj.get("ReferenceMinVersion")
        self.hub_name: str = json_obj["HubName"]

    def get_hub_region(self) -> Optional[str]:
        """Returns the region hub is in."""
        return self._region
229
+
230
+
231
class HubS3StorageConfig(HubDataHolderType):
    """S3 storage configuration attached to a hub (its ``S3StorageConfig`` field)."""

    __slots__ = ["s3_output_path"]

    def __init__(self, json_obj: Dict[str, Any]) -> None:
        """Build the storage config from its raw API dictionary.

        Args:
            json_obj (Dict[str, Any]): The ``S3StorageConfig`` sub-dictionary.
        """
        self.from_json(json_obj)

    def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
        """Populate ``s3_output_path``, defaulting to an empty string when absent.

        Args:
            json_obj (Dict[str, Any]): The ``S3StorageConfig`` sub-dictionary.
        """
        output_path: Optional[str] = json_obj.get("S3OutputPath", "")
        self.s3_output_path = output_path
255
+
256
+
257
class DescribeHubResponse(HubDataHolderType):
    """Data class for the Hub from session.describe_hub()"""

    __slots__ = [
        "creation_time",
        "failure_reason",
        "hub_arn",
        "hub_description",
        "hub_display_name",
        "hub_name",
        "hub_search_keywords",
        "hub_status",
        "last_modified_time",
        "s3_storage_config",
        "_region",
    ]

    # The region is derived from the hub ARN; it is not part of the API payload.
    _non_serializable_slots = ["_region"]

    def __init__(self, json_obj: Dict[str, Any]) -> None:
        """Instantiates DescribeHubResponse object.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of hub description.
        """
        self.from_json(json_obj)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        Optional response fields (``FailureReason``, ``HubDescription``,
        ``HubDisplayName``, ``HubSearchKeywords``, ``S3StorageConfig``) become
        ``None`` when absent, and ``to_json`` then omits them.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of hub description.

        Raises:
            KeyError: If ``CreationTime``, ``HubArn``, ``HubName``, ``HubStatus`` or
                ``LastModifiedTime`` is missing.
        """
        # Timestamps are stored as received, matching DescribeHubContentResponse.
        # The previous ``datetime.datetime(value)`` call always raised TypeError,
        # because the constructor requires year/month/day integers.
        self.creation_time: datetime.datetime = json_obj["CreationTime"]
        self.failure_reason: Optional[str] = json_obj.get("FailureReason")
        self.hub_arn: str = json_obj["HubArn"]
        hub_region: Optional[str] = HubArnExtractedInfo.extract_region_from_arn(self.hub_arn)
        self._region = hub_region
        self.hub_description: Optional[str] = json_obj.get("HubDescription")
        self.hub_display_name: Optional[str] = json_obj.get("HubDisplayName")
        self.hub_name: str = json_obj["HubName"]
        self.hub_search_keywords: Optional[List[str]] = json_obj.get("HubSearchKeywords")
        self.hub_status: str = json_obj["HubStatus"]
        self.last_modified_time: datetime.datetime = json_obj["LastModifiedTime"]
        raw_storage_config = json_obj.get("S3StorageConfig")
        self.s3_storage_config: Optional[HubS3StorageConfig] = (
            HubS3StorageConfig(raw_storage_config) if raw_storage_config is not None else None
        )

    def get_hub_region(self) -> Optional[str]:
        """Returns the region hub is in."""
        return self._region
307
+
308
+
309
class ImportHubResponse(HubDataHolderType):
    """Response wrapper for ``session.import_hub()``.

    Carries the ARN of the hub and the ARN of the imported hub content.
    """

    __slots__ = ["hub_arn", "hub_content_arn"]

    def __init__(self, json_obj: Dict[str, Any]) -> None:
        """Build the response object from the raw service payload.

        Args:
            json_obj (Dict[str, Any]): Raw import response dictionary.
        """
        self.from_json(json_obj)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Copy both ARNs out of ``json_obj``.

        Args:
            json_obj (Dict[str, Any]): Raw response; ``HubArn`` and ``HubContentArn``
                must both be present, otherwise ``KeyError`` is raised.
        """
        arn_of_hub: str = json_obj["HubArn"]
        self.hub_arn = arn_of_hub
        arn_of_content: str = json_obj["HubContentArn"]
        self.hub_content_arn = arn_of_content
333
+
334
+
335
class HubSummary(HubDataHolderType):
    """Data class for the HubSummary from session.list_hubs()"""

    __slots__ = [
        "creation_time",
        "hub_arn",
        "hub_description",
        "hub_display_name",
        "hub_name",
        "hub_search_keywords",
        "hub_status",
        "last_modified_time",
    ]

    def __init__(self, json_obj: Dict[str, Any]) -> None:
        """Instantiates HubSummary object.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of hub description.
        """
        self.from_json(json_obj)

    @staticmethod
    def _to_datetime(value: Any) -> Optional[datetime.datetime]:
        """Normalize a timestamp field to ``datetime.datetime``.

        ``datetime`` values (what boto3 returns for timestamp members) pass through
        unchanged. Epoch seconds and ISO-8601 strings are converted. ``None`` stays ``None``.
        """
        if value is None or isinstance(value, datetime.datetime):
            return value
        if isinstance(value, (int, float)):
            return datetime.datetime.fromtimestamp(value, tz=datetime.timezone.utc)
        text = str(value)
        # fromisoformat() only accepts a trailing "Z" from Python 3.11 on.
        if text.endswith("Z"):
            text = text[:-1] + "+00:00"
        return datetime.datetime.fromisoformat(text)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        ``CreationTime``, ``HubArn``, ``HubName``, ``HubStatus`` and ``LastModifiedTime``
        are required (``KeyError`` otherwise). Description, display name and search
        keywords are optional in a hub summary and default to ``None``.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of hub description.
        """
        # ``datetime.datetime(x)`` with a single argument always raised TypeError;
        # normalize the timestamp instead.
        self.creation_time: Optional[datetime.datetime] = self._to_datetime(
            json_obj["CreationTime"]
        )
        self.hub_arn: str = json_obj["HubArn"]
        self.hub_description: Optional[str] = json_obj.get("HubDescription")
        self.hub_display_name: Optional[str] = json_obj.get("HubDisplayName")
        self.hub_name: str = json_obj["HubName"]
        self.hub_search_keywords: Optional[List[str]] = json_obj.get("HubSearchKeywords")
        self.hub_status: str = json_obj["HubStatus"]
        self.last_modified_time: Optional[datetime.datetime] = self._to_datetime(
            json_obj["LastModifiedTime"]
        )
371
+
372
+
373
class ListHubsResponse(HubDataHolderType):
    """Data class for the Hub from session.list_hubs()"""

    __slots__ = [
        "hub_summaries",
        "next_token",
    ]

    def __init__(self, json_obj: Dict[str, Any]) -> None:
        """Instantiates ListHubsResponse object.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of session.list_hubs() response.
        """
        self.from_json(json_obj)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of session.list_hubs() response.
                ``HubSummaries`` is required; ``NextToken`` is optional.
        """
        self.hub_summaries: List[HubSummary] = [
            HubSummary(item) for item in json_obj["HubSummaries"]
        ]
        # NextToken is only returned when more results remain, so it is absent on the
        # last page. Indexing it directly raised KeyError there.
        self.next_token: Optional[str] = json_obj.get("NextToken")
399
+
400
+
401
class EcrUri(HubDataHolderType):
    """Data class for ECR image uri."""

    __slots__ = ["account", "region_name", "repository", "tag"]

    # Compiled once at class creation instead of on every parse.
    _URI_PATTERN = re.compile(
        r"^(?:(?P<account_id>[a-zA-Z0-9][\w-]*)\.dkr\.ecr\.(?P<region>[a-zA-Z0-9][\w-]*)"
        r"\.(?P<tld>[a-zA-Z0-9\.-]+))\/(?P<repository_name>([a-z0-9]+"
        r"(?:[._-][a-z0-9]+)*\/)*[a-z0-9]+(?:[._-][a-z0-9]+)*)(:*)(?P<image_tag>.*)?"
    )

    def __init__(self, uri: str):
        """Instantiates EcrUri object.

        Args:
            uri (str): ECR image URI of the form
                ``<account>.dkr.ecr.<region>.<tld>/<repository>[:<tag>]``.

        Raises:
            ValueError: If ``uri`` does not match the expected ECR URI format.
        """
        self.from_ecr_uri(uri)

    def from_ecr_uri(self, uri: str) -> None:
        """Parse a given aws ecr image uri into its various components.

        Sets ``account``, ``region_name``, ``repository`` and ``tag``. ``tag`` is the
        remainder after the repository name and any ``:`` separators, so it is empty
        when the URI has no tag.

        Raises:
            ValueError: If ``uri`` does not match the expected ECR URI format.
        """
        parsed_image_uri = self._URI_PATTERN.match(uri)
        # Previously a non-matching URI failed later with an opaque
        # "'NoneType' object has no attribute 'group'".
        if parsed_image_uri is None:
            raise ValueError(f"Unable to parse ECR image URI: {uri!r}")

        self.account = parsed_image_uri.group("account_id")
        self.region_name = parsed_image_uri.group("region")
        self.repository = parsed_image_uri.group("repository_name")
        self.tag = parsed_image_uri.group("image_tag")
429
+
430
+
431
class NotebookLocationUris(HubDataHolderType):
    """Data class for Notebook Location uri."""

    __slots__ = ["demo_notebook", "model_fit", "model_deploy"]

    def __init__(self, json_obj: Dict[str, Any]):
        """Instantiates NotebookLocationUris object.

        Args:
            json_obj (Dict[str, Any]): Dictionary of notebook locations, read from the
                ``demo_notebook``, ``model_fit`` and ``model_deploy`` keys.
        """
        self.from_json(json_obj)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        A key that is missing leaves the corresponding attribute as ``None``.

        Args:
            json_obj (Dict[str, Any]): Dictionary of notebook locations.
        """
        self.demo_notebook = json_obj.get("demo_notebook")
        self.model_fit = json_obj.get("model_fit")
        self.model_deploy = json_obj.get("model_deploy")
449
+
450
+
451
class HubModelDocument(HubDataHolderType):
    """Data class for model type HubContentDocument from session.describe_hub_content()."""

    SCHEMA_VERSION = "2.3.0"

    __slots__ = [
        "url",
        "min_sdk_version",
        "training_supported",
        "model_types",
        "capabilities",
        "incremental_training_supported",
        "dynamic_container_deployment_supported",
        "hosting_ecr_uri",
        "hosting_artifact_s3_data_type",
        "hosting_artifact_compression_type",
        "hosting_artifact_uri",
        "hosting_prepacked_artifact_uri",
        "hosting_prepacked_artifact_version",
        "hosting_script_uri",
        "hosting_use_script_uri",
        "hosting_eula_uri",
        "hosting_model_package_arn",
        "inference_ami_version",
        "model_subscription_link",
        "inference_configs",
        "inference_config_components",
        "inference_config_rankings",
        "training_artifact_s3_data_type",
        "training_artifact_compression_type",
        "training_model_package_artifact_uri",
        "hyperparameters",
        "inference_environment_variables",
        "training_script_uri",
        "training_prepacked_script_uri",
        "training_prepacked_script_version",
        "training_ecr_uri",
        "training_metrics",
        "training_artifact_uri",
        "training_configs",
        "training_config_components",
        "training_config_rankings",
        "inference_dependencies",
        "training_dependencies",
        "default_inference_instance_type",
        "supported_inference_instance_types",
        "default_training_instance_type",
        "supported_training_instance_types",
        "sage_maker_sdk_predictor_specifications",
        "inference_volume_size",
        "training_volume_size",
        "inference_enable_network_isolation",
        "training_enable_network_isolation",
        "fine_tuning_supported",
        "validation_supported",
        "default_training_dataset_uri",
        "resource_name_base",
        "gated_bucket",
        "default_payloads",
        "hosting_resource_requirements",
        "hosting_instance_type_variants",
        "training_instance_type_variants",
        "notebook_location_uris",
        "model_provider_icon_uri",
        "task",
        "framework",
        "datatype",
        "license",
        "contextual_help",
        "model_data_download_timeout",
        "container_startup_health_check_timeout",
        "encrypt_inter_container_traffic",
        "max_runtime_in_seconds",
        "disable_output_compression",
        "model_dir",
        "dependencies",
        "_region",
    ]

    _non_serializable_slots = ["_region"]

    def __init__(
        self,
        json_obj: Dict[str, Any],
        region: str,
        dependencies: Optional[List[HubContentDependency]] = None,
    ) -> None:
        """Instantiates HubModelDocument object.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of hub content document.
            region (str): Hub region; returned by ``get_region()``.
            dependencies (Optional[List[HubContentDependency]]): Dependencies of the hub
                content. Defaults to an empty list.
        """
        self._region = region
        self.dependencies = dependencies or []
        self.from_json(json_obj)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        Training-related attributes are only assigned when ``TrainingSupported`` is
        truthy; otherwise they are left unset on the instance.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of hub model document.
        """
        self.url: str = json_obj.get("Url")
        self.min_sdk_version: str = json_obj.get("MinSdkVersion")
        self.hosting_ecr_uri: Optional[str] = json_obj.get("HostingEcrUri")
        self.hosting_artifact_uri = json_obj.get("HostingArtifactUri")
        self.hosting_script_uri = json_obj.get("HostingScriptUri")
        self.inference_dependencies: List[str] = json_obj.get("InferenceDependencies")
        self.inference_environment_variables: List[JumpStartEnvironmentVariable] = [
            JumpStartEnvironmentVariable(env_variable, is_hub_content=True)
            for env_variable in json_obj.get("InferenceEnvironmentVariables", [])
        ]
        self.model_types: Optional[List[str]] = json_obj.get("ModelTypes")
        self.capabilities: Optional[List[str]] = json_obj.get("Capabilities")
        self.training_supported: bool = bool(json_obj.get("TrainingSupported"))
        self.incremental_training_supported: bool = bool(
            json_obj.get("IncrementalTrainingSupported")
        )
        self.dynamic_container_deployment_supported: Optional[bool] = (
            bool(json_obj.get("DynamicContainerDeploymentSupported"))
            if json_obj.get("DynamicContainerDeploymentSupported")
            else None
        )
        self.hosting_artifact_s3_data_type: Optional[str] = json_obj.get(
            "HostingArtifactS3DataType"
        )
        self.hosting_artifact_compression_type: Optional[str] = json_obj.get(
            "HostingArtifactCompressionType"
        )
        self.hosting_prepacked_artifact_uri: Optional[str] = json_obj.get(
            "HostingPrepackedArtifactUri"
        )
        self.hosting_prepacked_artifact_version: Optional[str] = json_obj.get(
            "HostingPrepackedArtifactVersion"
        )
        # Unlike the flags above, an explicit falsy value is kept as False (only a
        # missing/None value becomes None).
        self.hosting_use_script_uri: Optional[bool] = (
            bool(json_obj.get("HostingUseScriptUri"))
            if json_obj.get("HostingUseScriptUri") is not None
            else None
        )
        self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri")
        self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn")

        self.inference_ami_version: Optional[str] = json_obj.get("InferenceAmiVersion")

        self.model_subscription_link: Optional[str] = json_obj.get("ModelSubscriptionLink")

        # Order matters: _get_configs reads the rankings and components set just before it.
        self.inference_config_rankings = self._get_config_rankings(json_obj)
        self.inference_config_components = self._get_config_components(json_obj)
        self.inference_configs = self._get_configs(json_obj)

        self.default_inference_instance_type: Optional[str] = json_obj.get(
            "DefaultInferenceInstanceType"
        )
        self.supported_inference_instance_types: Optional[str] = json_obj.get(
            "SupportedInferenceInstanceTypes"
        )
        self.sage_maker_sdk_predictor_specifications: Optional[JumpStartPredictorSpecs] = (
            JumpStartPredictorSpecs(
                json_obj.get("SageMakerSdkPredictorSpecifications"),
                is_hub_content=True,
            )
            if json_obj.get("SageMakerSdkPredictorSpecifications")
            else None
        )
        self.inference_volume_size: Optional[int] = json_obj.get("InferenceVolumeSize")
        self.inference_enable_network_isolation: Optional[bool] = json_obj.get(
            "InferenceEnableNetworkIsolation", False
        )
        self.fine_tuning_supported: Optional[bool] = (
            bool(json_obj.get("FineTuningSupported"))
            if json_obj.get("FineTuningSupported")
            else None
        )
        self.validation_supported: Optional[bool] = (
            bool(json_obj.get("ValidationSupported"))
            if json_obj.get("ValidationSupported")
            else None
        )
        self.resource_name_base: Optional[str] = json_obj.get("ResourceNameBase")
        self.gated_bucket: bool = bool(json_obj.get("GatedBucket", False))
        self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = (
            {
                alias: JumpStartSerializablePayload(payload, is_hub_content=True)
                for alias, payload in json_obj.get("DefaultPayloads").items()
            }
            if json_obj.get("DefaultPayloads")
            else None
        )
        self.hosting_resource_requirements: Optional[Dict[str, int]] = json_obj.get(
            "HostingResourceRequirements", None
        )
        self.hosting_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = (
            JumpStartInstanceTypeVariants(
                json_obj.get("HostingInstanceTypeVariants"),
                is_hub_content=True,
            )
            if json_obj.get("HostingInstanceTypeVariants")
            else None
        )
        self.notebook_location_uris: Optional[NotebookLocationUris] = (
            NotebookLocationUris(json_obj.get("NotebookLocationUris"))
            if json_obj.get("NotebookLocationUris")
            else None
        )
        self.model_provider_icon_uri: Optional[str] = None  # Not needed for private beta
        self.task: Optional[str] = json_obj.get("Task")
        self.framework: Optional[str] = json_obj.get("Framework")
        self.datatype: Optional[str] = json_obj.get("Datatype")
        self.license: Optional[str] = json_obj.get("License")
        self.contextual_help: Optional[str] = json_obj.get("ContextualHelp")
        self.model_dir: Optional[str] = json_obj.get("ModelDir")
        # Deploy kwargs
        self.model_data_download_timeout: Optional[str] = json_obj.get("ModelDataDownloadTimeout")
        self.container_startup_health_check_timeout: Optional[str] = json_obj.get(
            "ContainerStartupHealthCheckTimeout"
        )

        if self.training_supported:
            self.default_training_dataset_uri: Optional[str] = json_obj.get(
                "DefaultTrainingDatasetUri"
            )
            self.training_model_package_artifact_uri: Optional[str] = json_obj.get(
                "TrainingModelPackageArtifactUri"
            )
            self.training_artifact_compression_type: Optional[str] = json_obj.get(
                "TrainingArtifactCompressionType"
            )
            self.training_artifact_s3_data_type: Optional[str] = json_obj.get(
                "TrainingArtifactS3DataType"
            )
            self.hyperparameters: List[JumpStartHyperparameter] = []
            hyperparameters: Any = json_obj.get("Hyperparameters")
            if hyperparameters is not None:
                self.hyperparameters.extend(
                    [
                        JumpStartHyperparameter(hyperparameter, is_hub_content=True)
                        for hyperparameter in hyperparameters
                    ]
                )

            self.training_script_uri: Optional[str] = json_obj.get("TrainingScriptUri")
            self.training_prepacked_script_uri: Optional[str] = json_obj.get(
                "TrainingPrepackedScriptUri"
            )
            self.training_prepacked_script_version: Optional[str] = json_obj.get(
                "TrainingPrepackedScriptVersion"
            )
            self.training_ecr_uri: Optional[str] = json_obj.get("TrainingEcrUri")
            # ``_non_serializable_slots`` is a class-level list shared by every instance.
            # An unconditional append added a duplicate entry on every instantiation.
            if "training_ecr_specs" not in self._non_serializable_slots:
                self._non_serializable_slots.append("training_ecr_specs")
            self.training_metrics: Optional[List[Dict[str, str]]] = json_obj.get(
                "TrainingMetrics", None
            )
            self.training_artifact_uri: Optional[str] = json_obj.get("TrainingArtifactUri")

            # Order matters: _get_configs reads the rankings and components set just before it.
            self.training_config_rankings = self._get_config_rankings(
                json_obj, _ComponentType.TRAINING
            )
            self.training_config_components = self._get_config_components(
                json_obj, _ComponentType.TRAINING
            )
            self.training_configs = self._get_configs(json_obj, _ComponentType.TRAINING)

            self.training_dependencies: Optional[str] = json_obj.get("TrainingDependencies")
            self.default_training_instance_type: Optional[str] = json_obj.get(
                "DefaultTrainingInstanceType"
            )
            self.supported_training_instance_types: Optional[str] = json_obj.get(
                "SupportedTrainingInstanceTypes"
            )
            self.training_volume_size: Optional[int] = json_obj.get("TrainingVolumeSize")
            self.training_enable_network_isolation: Optional[bool] = json_obj.get(
                "TrainingEnableNetworkIsolation", False
            )
            self.training_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = (
                JumpStartInstanceTypeVariants(
                    json_obj.get("TrainingInstanceTypeVariants"),
                    is_hub_content=True,
                )
                if json_obj.get("TrainingInstanceTypeVariants")
                else None
            )
            # Estimator kwargs
            self.encrypt_inter_container_traffic: Optional[bool] = (
                bool(json_obj.get("EncryptInterContainerTraffic"))
                if json_obj.get("EncryptInterContainerTraffic")
                else None
            )
            self.max_runtime_in_seconds: Optional[str] = json_obj.get("MaxRuntimeInSeconds")
            self.disable_output_compression: Optional[bool] = (
                bool(json_obj.get("DisableOutputCompression"))
                if json_obj.get("DisableOutputCompression")
                else None
            )

    def get_schema_version(self) -> str:
        """Returns schema version."""
        return self.SCHEMA_VERSION

    def get_region(self) -> str:
        """Returns hub region."""
        return self._region

    def _get_config_rankings(
        self, json_obj: Dict[str, Any], component_type=_ComponentType.INFERENCE
    ) -> Optional[Dict[str, JumpStartConfigRanking]]:
        """Returns config rankings parsed from ``<Type>ConfigRankings``, or None if absent/empty."""
        config_rankings = json_obj.get(f"{component_type.value}ConfigRankings")
        return (
            {
                alias: JumpStartConfigRanking(ranking, is_hub_content=True)
                for alias, ranking in config_rankings.items()
            }
            if config_rankings
            else None
        )

    def _get_config_components(
        self, json_obj: Dict[str, Any], component_type=_ComponentType.INFERENCE
    ) -> Optional[Dict[str, JumpStartConfigComponent]]:
        """Returns config components parsed from ``<Type>ConfigComponents``, or None if absent/empty."""
        config_components = json_obj.get(f"{component_type.value}ConfigComponents")
        return (
            {
                alias: JumpStartConfigComponent(alias, config, is_hub_content=True)
                for alias, config in config_components.items()
            }
            if config_components
            else None
        )

    def _get_configs(
        self, json_obj: Dict[str, Any], component_type=_ComponentType.INFERENCE
    ) -> Optional[JumpStartMetadataConfigs]:
        """Returns configs parsed from ``<Type>Configs``, or None if absent/empty.

        Each config's ``ComponentNames`` are resolved against the
        ``<type>_config_components`` attribute already set on this instance. Names with
        no matching component map to ``None``.
        """
        if not (configs := json_obj.get(f"{component_type.value}Configs")):
            return None

        # The components attribute is None when the document has no
        # ``<Type>ConfigComponents``; fall back to an empty mapping instead of
        # crashing on ``None.get``.
        known_components = (
            getattr(self, f"{component_type.value.lower()}_config_components", None) or {}
        )

        configs_dict = {}
        for alias, config in configs.items():
            config_components = None
            if isinstance(config, dict) and (component_names := config.get("ComponentNames")):
                config_components = {name: known_components.get(name) for name in component_names}
            configs_dict[alias] = JumpStartMetadataConfig(
                alias, config, json_obj, config_components, is_hub_content=True
            )

        if component_type == _ComponentType.INFERENCE:
            config_rankings = self.inference_config_rankings
            scope = JumpStartScriptScope.INFERENCE
        else:
            config_rankings = self.training_config_rankings
            scope = JumpStartScriptScope.TRAINING

        return JumpStartMetadataConfigs(configs_dict, config_rankings, scope)
814
+
815
+
816
class HubNotebookDocument(HubDataHolderType):
    """Data class for notebook type HubContentDocument from session.describe_hub_content()."""

    SCHEMA_VERSION = "1.0.0"

    __slots__ = ["notebook_location", "dependencies", "_region"]

    _non_serializable_slots = ["_region"]

    def __init__(self, json_obj: Dict[str, Any], region: str) -> None:
        """Instantiates HubNotebookDocument object.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of hub content document.
            region (str): Hub region; returned by ``get_region()``.
        """
        self._region = region
        self.from_json(json_obj)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of hub content description.
                ``NotebookLocation`` is required (``KeyError`` otherwise). A missing or
                null ``Dependencies`` yields an empty list.
        """
        self.notebook_location = json_obj["NotebookLocation"]
        # A missing or null Dependencies previously raised KeyError/TypeError.
        self.dependencies: List[HubContentDependency] = [
            HubContentDependency(dep) for dep in (json_obj.get("Dependencies") or [])
        ]

    def get_schema_version(self) -> str:
        """Returns schema version."""
        return self.SCHEMA_VERSION

    def get_region(self) -> str:
        """Returns hub region."""
        return self._region
852
+
853
+
854
# Type alias for a parsed hub content document: HubModelDocument for model-type
# content or HubNotebookDocument for notebook-type content.
HubContentDocument = Union[HubModelDocument, HubNotebookDocument]
855
+
856
+
857
class HubContentInfo(HubDataHolderType):
    """Summary of one hub content entry, as returned by ``session.list_hub_contents()``."""

    __slots__ = [
        "creation_time",
        "document_schema_version",
        "hub_content_arn",
        "hub_content_name",
        "hub_content_status",
        "hub_content_type",
        "hub_content_version",
        "hub_content_description",
        "hub_content_display_name",
        "hub_content_search_keywords",
        "_region",
    ]

    _non_serializable_slots = ["_region"]

    def __init__(self, json_obj: Dict[str, Any]) -> None:
        """Build the summary from one ``HubContentSummaries`` list item.

        Args:
            json_obj (Dict[str, Any]): Raw hub content summary dictionary.
        """
        self.from_json(json_obj)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Populate this summary from ``json_obj``.

        Args:
            json_obj (Dict[str, Any]): Raw hub content summary. ``CreationTime``,
                ``DocumentSchemaVersion``, ``HubContentArn``, ``HubContentName``,
                ``HubContentStatus``, ``HubContentType`` and ``HubContentVersion`` are
                required (``KeyError`` otherwise). Description, display name and search
                keywords default to ``None``.
        """
        # Required fields, read in a fixed order.
        self.creation_time: str = json_obj["CreationTime"]
        self.document_schema_version: str = json_obj["DocumentSchemaVersion"]
        self.hub_content_arn: str = json_obj["HubContentArn"]
        self.hub_content_name: str = json_obj["HubContentName"]
        self.hub_content_status: str = json_obj["HubContentStatus"]
        raw_content_type = json_obj["HubContentType"]
        self.hub_content_type: HubContentType = HubContentType(raw_content_type)
        self.hub_content_version: str = json_obj["HubContentVersion"]

        # Optional fields.
        lookup = json_obj.get
        self.hub_content_description: Optional[str] = lookup("HubContentDescription")
        self.hub_content_display_name: Optional[str] = lookup("HubContentDisplayName")
        # The region is derived from the content ARN, not read from the payload.
        arn_region: Optional[str] = HubArnExtractedInfo.extract_region_from_arn(
            self.hub_content_arn
        )
        self._region = arn_region
        self.hub_content_search_keywords: Optional[List[str]] = lookup(
            "HubContentSearchKeywords"
        )

    def get_hub_region(self) -> Optional[str]:
        """Return the region parsed from ``hub_content_arn`` (may be ``None``)."""
        region = self._region
        return region
909
+
910
+
911
class ListHubContentsResponse(HubDataHolderType):
    """Data class for the Hub from session.list_hub_contents()"""

    __slots__ = [
        "hub_content_summaries",
        "next_token",
    ]

    def __init__(self, json_obj: Dict[str, Any]) -> None:
        """Instantiates ListHubContentsResponse object.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of session.list_hub_contents()
                response.
        """
        self.from_json(json_obj)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of session.list_hub_contents()
                response. ``HubContentSummaries`` is required; ``NextToken`` is optional.
        """
        self.hub_content_summaries: List[HubContentInfo] = [
            HubContentInfo(item) for item in json_obj["HubContentSummaries"]
        ]
        # NextToken is only returned when more results remain, so it is absent on the
        # last page. Indexing it directly raised KeyError there.
        self.next_token: Optional[str] = json_obj.get("NextToken")