sagemaker-core 1.0.62__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (362)
  1. sagemaker/core/__init__.py +16 -0
  2. sagemaker/core/_studio.py +116 -0
  3. sagemaker/core/_version.py +11 -0
  4. sagemaker/core/accept_types.py +131 -0
  5. sagemaker/core/analytics.py +744 -0
  6. sagemaker/core/apiutils/__init__.py +13 -0
  7. sagemaker/core/apiutils/_base_types.py +228 -0
  8. sagemaker/core/apiutils/_boto_functions.py +130 -0
  9. sagemaker/core/apiutils/_utils.py +34 -0
  10. sagemaker/core/base_deserializers.py +35 -0
  11. sagemaker/core/base_serializers.py +35 -0
  12. sagemaker/core/clarify/__init__.py +2898 -0
  13. sagemaker/core/collection.py +467 -0
  14. sagemaker/core/common_utils.py +2281 -0
  15. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  16. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  17. sagemaker/core/config/__init__.py +181 -0
  18. sagemaker/core/config/config.py +238 -0
  19. sagemaker/core/config/config_manager.py +595 -0
  20. sagemaker/core/config/config_schema.py +1220 -0
  21. sagemaker/core/config/config_utils.py +297 -0
  22. {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
  23. sagemaker/core/constants.py +73 -0
  24. sagemaker/core/content_types.py +137 -0
  25. sagemaker/core/debugger/__init__.py +39 -0
  26. sagemaker/core/debugger/debugger.py +945 -0
  27. sagemaker/core/debugger/framework_profile.py +292 -0
  28. sagemaker/core/debugger/metrics_config.py +468 -0
  29. sagemaker/core/debugger/profiler.py +42 -0
  30. sagemaker/core/debugger/profiler_config.py +190 -0
  31. sagemaker/core/debugger/profiler_constants.py +40 -0
  32. sagemaker/core/debugger/utils.py +148 -0
  33. sagemaker/core/deprecations.py +254 -0
  34. sagemaker/core/deserializers/__init__.py +10 -0
  35. sagemaker/core/deserializers/base.py +424 -0
  36. sagemaker/core/deserializers/implementations.py +157 -0
  37. sagemaker/core/drift_check_baselines.py +106 -0
  38. sagemaker/core/enums.py +51 -0
  39. sagemaker/core/environment_variables.py +101 -0
  40. sagemaker/core/exceptions.py +108 -0
  41. sagemaker/core/experiments/__init__.py +53 -0
  42. sagemaker/core/experiments/_api_types.py +251 -0
  43. sagemaker/core/experiments/_environment.py +124 -0
  44. sagemaker/core/experiments/_helper.py +294 -0
  45. sagemaker/core/experiments/_metrics.py +333 -0
  46. sagemaker/core/experiments/_run_context.py +58 -0
  47. sagemaker/core/experiments/_utils.py +216 -0
  48. sagemaker/core/experiments/experiment.py +244 -0
  49. sagemaker/core/experiments/run.py +970 -0
  50. sagemaker/core/experiments/trial.py +296 -0
  51. sagemaker/core/experiments/trial_component.py +387 -0
  52. sagemaker/core/explainer/__init__.py +24 -0
  53. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  54. sagemaker/core/explainer/explainer_config.py +44 -0
  55. sagemaker/core/fw_utils.py +1176 -0
  56. sagemaker/core/git_utils.py +349 -0
  57. sagemaker/core/helper/pipeline_variable.py +82 -0
  58. sagemaker/core/helper/session_helper.py +2965 -0
  59. sagemaker/core/huggingface/__init__.py +29 -0
  60. sagemaker/core/huggingface/llm_utils.py +150 -0
  61. sagemaker/core/huggingface/processing.py +139 -0
  62. sagemaker/core/huggingface/training_compiler/config.py +167 -0
  63. sagemaker/core/hyperparameters.py +172 -0
  64. sagemaker/core/image_retriever/__init__.py +3 -0
  65. sagemaker/core/image_retriever/image_retriever.py +640 -0
  66. sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
  67. sagemaker/core/image_retriever/test.py +7 -0
  68. sagemaker/core/image_uri_config/__init__.py +13 -0
  69. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  70. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  71. sagemaker/core/image_uri_config/chainer.json +104 -0
  72. sagemaker/core/image_uri_config/clarify.json +39 -0
  73. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  74. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  75. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  76. sagemaker/core/image_uri_config/debugger.json +34 -0
  77. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  78. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  79. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  80. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  81. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  82. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  83. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  84. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  85. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
  86. sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
  87. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  88. sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
  89. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  90. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  91. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  92. sagemaker/core/image_uri_config/huggingface.json +2138 -0
  93. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  94. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  95. sagemaker/core/image_uri_config/image-classification.json +50 -0
  96. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  97. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  98. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  99. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  100. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  101. sagemaker/core/image_uri_config/kmeans.json +50 -0
  102. sagemaker/core/image_uri_config/knn.json +50 -0
  103. sagemaker/core/image_uri_config/lda.json +26 -0
  104. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  105. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  106. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  107. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  108. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  109. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  110. sagemaker/core/image_uri_config/ntm.json +50 -0
  111. sagemaker/core/image_uri_config/object-detection.json +50 -0
  112. sagemaker/core/image_uri_config/object2vec.json +50 -0
  113. sagemaker/core/image_uri_config/pca.json +50 -0
  114. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  115. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  116. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  117. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  118. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  119. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  120. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  121. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  122. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  123. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  124. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
  125. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  126. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  127. sagemaker/core/image_uri_config/sklearn.json +446 -0
  128. sagemaker/core/image_uri_config/spark.json +280 -0
  129. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  130. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  131. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  132. sagemaker/core/image_uri_config/vw.json +25 -0
  133. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  134. sagemaker/core/image_uri_config/xgboost.json +888 -0
  135. sagemaker/core/image_uris.py +810 -0
  136. sagemaker/core/inference_config.py +144 -0
  137. sagemaker/core/inference_recommender/__init__.py +18 -0
  138. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  139. sagemaker/core/inputs.py +366 -0
  140. sagemaker/core/instance_group.py +61 -0
  141. sagemaker/core/instance_types.py +164 -0
  142. sagemaker/core/instance_types_gpu_info.py +43 -0
  143. sagemaker/core/interactive_apps/__init__.py +41 -0
  144. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  145. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  146. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  147. sagemaker/core/iterators.py +186 -0
  148. sagemaker/core/job.py +380 -0
  149. sagemaker/core/jumpstart/__init__.py +156 -0
  150. sagemaker/core/jumpstart/accessors.py +390 -0
  151. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  152. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  153. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  154. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  155. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  156. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  157. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  158. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  159. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  160. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  161. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  162. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  163. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  164. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  165. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  166. sagemaker/core/jumpstart/cache.py +663 -0
  167. sagemaker/core/jumpstart/configs.py +50 -0
  168. sagemaker/core/jumpstart/constants.py +198 -0
  169. sagemaker/core/jumpstart/deserializers.py +81 -0
  170. sagemaker/core/jumpstart/document.py +76 -0
  171. sagemaker/core/jumpstart/enums.py +168 -0
  172. sagemaker/core/jumpstart/exceptions.py +236 -0
  173. sagemaker/core/jumpstart/factory/utils.py +833 -0
  174. sagemaker/core/jumpstart/filters.py +597 -0
  175. sagemaker/core/jumpstart/hub/constants.py +16 -0
  176. sagemaker/core/jumpstart/hub/hub.py +291 -0
  177. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  178. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  179. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  180. sagemaker/core/jumpstart/hub/types.py +35 -0
  181. sagemaker/core/jumpstart/hub/utils.py +260 -0
  182. sagemaker/core/jumpstart/models.py +499 -0
  183. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  184. sagemaker/core/jumpstart/parameters.py +20 -0
  185. sagemaker/core/jumpstart/payload_utils.py +239 -0
  186. sagemaker/core/jumpstart/region_config.json +163 -0
  187. sagemaker/core/jumpstart/search.py +171 -0
  188. sagemaker/core/jumpstart/serializers.py +81 -0
  189. sagemaker/core/jumpstart/session_utils.py +234 -0
  190. sagemaker/core/jumpstart/types.py +3044 -0
  191. sagemaker/core/jumpstart/utils.py +1731 -0
  192. sagemaker/core/jumpstart/validators.py +257 -0
  193. sagemaker/core/lambda_helper.py +312 -0
  194. sagemaker/core/lineage/__init__.py +42 -0
  195. sagemaker/core/lineage/_api_types.py +239 -0
  196. sagemaker/core/lineage/_utils.py +49 -0
  197. sagemaker/core/lineage/action.py +345 -0
  198. sagemaker/core/lineage/artifact.py +646 -0
  199. sagemaker/core/lineage/association.py +190 -0
  200. sagemaker/core/lineage/context.py +505 -0
  201. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  202. sagemaker/core/lineage/query.py +732 -0
  203. sagemaker/core/lineage/visualizer.py +346 -0
  204. sagemaker/core/local/__init__.py +18 -0
  205. sagemaker/core/local/data.py +413 -0
  206. sagemaker/core/local/entities.py +678 -0
  207. sagemaker/core/local/exceptions.py +17 -0
  208. sagemaker/core/local/image.py +1243 -0
  209. sagemaker/core/local/local_session.py +739 -0
  210. sagemaker/core/local/utils.py +245 -0
  211. sagemaker/core/logs.py +181 -0
  212. sagemaker/core/metadata_properties.py +56 -0
  213. sagemaker/core/metric_definitions.py +91 -0
  214. sagemaker/core/mlflow/__init__.py +38 -0
  215. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  216. sagemaker/core/model_card/__init__.py +26 -0
  217. sagemaker/core/model_life_cycle.py +51 -0
  218. sagemaker/core/model_metrics.py +160 -0
  219. sagemaker/core/model_monitor/__init__.py +66 -0
  220. sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
  221. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  222. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  223. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  224. sagemaker/core/model_monitor/dataset_format.py +102 -0
  225. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  226. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  227. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  228. sagemaker/core/model_monitor/utils.py +793 -0
  229. sagemaker/core/model_registry.py +480 -0
  230. sagemaker/core/model_uris.py +97 -0
  231. sagemaker/core/modules/__init__.py +19 -0
  232. sagemaker/core/modules/configs.py +226 -0
  233. sagemaker/core/modules/constants.py +37 -0
  234. sagemaker/core/modules/distributed.py +182 -0
  235. sagemaker/core/modules/local_core/__init__.py +0 -0
  236. sagemaker/core/modules/local_core/local_container.py +605 -0
  237. sagemaker/core/modules/templates.py +83 -0
  238. sagemaker/core/modules/train/__init__.py +14 -0
  239. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  240. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  241. sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
  242. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  243. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  244. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  245. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  246. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  247. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  248. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  249. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  250. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  251. sagemaker/core/modules/types.py +19 -0
  252. sagemaker/core/modules/utils.py +194 -0
  253. sagemaker/core/network.py +185 -0
  254. sagemaker/core/parameter.py +173 -0
  255. sagemaker/core/payloads.py +185 -0
  256. sagemaker/core/processing.py +1597 -0
  257. sagemaker/core/remote_function/__init__.py +19 -0
  258. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  259. sagemaker/core/remote_function/client.py +1285 -0
  260. sagemaker/core/remote_function/core/__init__.py +0 -0
  261. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  262. sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
  263. sagemaker/core/remote_function/core/serialization.py +422 -0
  264. sagemaker/core/remote_function/core/stored_function.py +226 -0
  265. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  266. sagemaker/core/remote_function/errors.py +104 -0
  267. sagemaker/core/remote_function/invoke_function.py +172 -0
  268. sagemaker/core/remote_function/job.py +2140 -0
  269. sagemaker/core/remote_function/logging_config.py +38 -0
  270. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  271. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  272. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  273. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  274. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  275. sagemaker/core/remote_function/spark_config.py +149 -0
  276. sagemaker/core/resource_requirements.py +168 -0
  277. {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
  278. sagemaker/core/s3/__init__.py +41 -0
  279. sagemaker/core/s3/client.py +367 -0
  280. sagemaker/core/s3/utils.py +175 -0
  281. sagemaker/core/script_uris.py +93 -0
  282. sagemaker/core/serializers/__init__.py +11 -0
  283. sagemaker/core/serializers/base.py +510 -0
  284. sagemaker/core/serializers/implementations.py +159 -0
  285. sagemaker/core/serializers/utils.py +223 -0
  286. sagemaker/core/serverless_inference_config.py +63 -0
  287. sagemaker/core/session_settings.py +55 -0
  288. sagemaker/core/shapes/__init__.py +3 -0
  289. sagemaker/core/shapes/model_card_shapes.py +159 -0
  290. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
  291. sagemaker/core/spark/__init__.py +16 -0
  292. sagemaker/core/spark/defaults.py +16 -0
  293. sagemaker/core/spark/processing.py +1380 -0
  294. sagemaker/core/telemetry/__init__.py +23 -0
  295. sagemaker/core/telemetry/constants.py +84 -0
  296. sagemaker/core/telemetry/telemetry_logging.py +284 -0
  297. sagemaker/core/tools/__init__.py +1 -0
  298. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  299. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  300. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  301. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  302. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  303. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  304. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  305. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  306. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  307. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  308. sagemaker/core/training/__init__.py +14 -0
  309. sagemaker/core/training/configs.py +333 -0
  310. sagemaker/core/training/constants.py +37 -0
  311. sagemaker/core/training/utils.py +77 -0
  312. sagemaker/core/training_compiler/__init__.py +16 -0
  313. sagemaker/core/training_compiler/config.py +197 -0
  314. sagemaker/core/training_compiler_config.py +197 -0
  315. sagemaker/core/transformer.py +793 -0
  316. sagemaker/core/user_agent.py +76 -0
  317. sagemaker/core/utilities/__init__.py +24 -0
  318. sagemaker/core/utilities/cache.py +169 -0
  319. sagemaker/core/utilities/search_expression.py +133 -0
  320. sagemaker/core/utils/__init__.py +48 -0
  321. sagemaker/core/utils/code_injection/__init__.py +0 -0
  322. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  323. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
  324. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  325. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  326. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  327. {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
  328. sagemaker/core/workflow/__init__.py +152 -0
  329. sagemaker/core/workflow/conditions.py +313 -0
  330. sagemaker/core/workflow/entities.py +58 -0
  331. sagemaker/core/workflow/execution_variables.py +89 -0
  332. sagemaker/core/workflow/functions.py +193 -0
  333. sagemaker/core/workflow/parameters.py +222 -0
  334. sagemaker/core/workflow/pipeline_context.py +394 -0
  335. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  336. sagemaker/core/workflow/properties.py +285 -0
  337. sagemaker/core/workflow/step_outputs.py +65 -0
  338. sagemaker/core/workflow/utilities.py +507 -0
  339. sagemaker/lineage/__init__.py +33 -0
  340. sagemaker/lineage/action.py +28 -0
  341. sagemaker/lineage/artifact.py +28 -0
  342. sagemaker/lineage/context.py +28 -0
  343. sagemaker/lineage/lineage_trial_component.py +28 -0
  344. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
  345. sagemaker_core-2.1.1.dist-info/RECORD +355 -0
  346. sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
  347. sagemaker_core/_version.py +0 -3
  348. sagemaker_core/helper/session_helper.py +0 -769
  349. sagemaker_core/resources/__init__.py +0 -1
  350. sagemaker_core/shapes/__init__.py +0 -1
  351. sagemaker_core/tools/__init__.py +0 -1
  352. sagemaker_core-1.0.62.dist-info/RECORD +0 -35
  353. sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
  354. {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
  355. {sagemaker_core/helper → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
  356. {sagemaker_core/main → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  357. {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
  358. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  359. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  360. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  361. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
  362. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,223 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module contains functions for obtaining JumpStart instance types."""
14
+ from __future__ import absolute_import
15
+
16
+ from typing import List, Optional
17
+
18
+ from sagemaker.core.jumpstart.exceptions import NO_AVAILABLE_INSTANCES_ERROR_MSG
19
+ from sagemaker.core.jumpstart.constants import (
20
+ DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
21
+ )
22
+ from sagemaker.core.jumpstart.enums import (
23
+ JumpStartScriptScope,
24
+ JumpStartModelType,
25
+ )
26
+ from sagemaker.core.jumpstart.utils import (
27
+ get_region_fallback,
28
+ verify_model_region_and_return_specs,
29
+ )
30
+ from sagemaker.core.helper.session_helper import Session
31
+
32
+
33
def _retrieve_default_instance_type(
    model_id: str,
    model_version: str,
    scope: str,
    hub_arn: Optional[str] = None,
    region: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    training_instance_type: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> str:
    """Retrieves the default instance type for the model.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the default instance type.
        model_version (str): Version of the JumpStart model for which to retrieve the
            default instance type.
        scope (str): The script type, i.e. what it is used for.
            Valid values: "training" and "inference".
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        region (Optional[str]): Region for which to retrieve default instance type.
            If not set, it is derived from ``sagemaker_session``. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.core.helper.session_helper.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default:
            sagemaker.core.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training
            instance type used for the training job that produced the fine-tuned weights.
            Optionally supply this to get an inference instance type conditioned
            on the training instance, to ensure compatibility of training artifact to inference
            instance. Only used for the inference scope. (Default: None).
        model_type (JumpStartModelType): Type of the JumpStart model whose specs are
            retrieved. (Default: JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).

    Returns:
        str: the default instance type to use for the model. Never None or empty;
        that case raises ``ValueError`` instead.

    Raises:
        NotImplementedError: If ``scope`` is neither inference nor training.
        ValueError: If the model specs define no default instance type for the scope,
            e.g. because the model is not available in the specified region due to
            lack of supported computing instances.
    """
    region = region or get_region_fallback(sagemaker_session=sagemaker_session)

    model_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=scope,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        model_type=model_type,
        sagemaker_session=sagemaker_session,
        config_name=config_name,
    )

    if scope == JumpStartScriptScope.INFERENCE:
        default_instance_type = None
        instance_type_variants = getattr(model_specs, "training_instance_type_variants", None)
        if training_instance_type is not None and instance_type_variants is not None:
            # A fine-tuned artifact may only be servable on instances compatible with the
            # instance that trained it, so prefer the training-instance-specific default.
            default_instance_type = (
                instance_type_variants.get_instance_specific_default_inference_instance_type(
                    training_instance_type
                )
            )
        if default_instance_type is None:
            default_instance_type = model_specs.default_inference_instance_type
    elif scope == JumpStartScriptScope.TRAINING:
        default_instance_type = model_specs.default_training_instance_type
    else:
        raise NotImplementedError(
            f"Unsupported script scope for retrieving default instance type: '{scope}'"
        )

    # Both None and "" mean the specs offer no usable default for this scope/region.
    if default_instance_type is None or default_instance_type == "":
        raise ValueError(NO_AVAILABLE_INSTANCES_ERROR_MSG.format(model_id=model_id, region=region))
    return default_instance_type
127
+
128
+
129
def _retrieve_instance_types(
    model_id: str,
    model_version: str,
    scope: str,
    hub_arn: Optional[str] = None,
    region: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    training_instance_type: Optional[str] = None,
    config_name: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> List[str]:
    """Retrieves the supported instance types for the model.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the supported instance types.
        model_version (str): Version of the JumpStart model for which to retrieve the
            supported instance types.
        scope (str): The script type, i.e. what it is used for.
            Valid values: "training" and "inference".
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        region (Optional[str]): Region for which to retrieve supported instance types.
            If not set, it is derived from ``sagemaker_session``. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.core.helper.session_helper.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default:
            sagemaker.core.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training
            instance type used for the training job that produced the fine-tuned weights.
            Optionally supply this to get inference instance types conditioned
            on the training instance, to ensure compatibility of training artifact to inference
            instance. Must not be set for the training scope. (Default: None).
        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
        model_type (JumpStartModelType): Type of the JumpStart model whose specs are
            retrieved. (Default: JumpStartModelType.OPEN_WEIGHTS).

    Returns:
        list: the supported instance types to use for the model. Never None or empty;
        that case raises ``ValueError`` instead.

    Raises:
        NotImplementedError: If ``scope`` is neither inference nor training.
        ValueError: If ``training_instance_type`` is given with the training scope, or if
            the model is not available in the specified region due to lack of supported
            computing instances.
    """
    # Reject the invalid argument combination before resolving the region or
    # fetching model specs, since neither can make the call succeed.
    if scope == JumpStartScriptScope.TRAINING and training_instance_type is not None:
        raise ValueError("Cannot use `training_instance_type` argument with training scope.")

    region = region or get_region_fallback(sagemaker_session=sagemaker_session)

    model_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=scope,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        model_type=model_type,
        sagemaker_session=sagemaker_session,
        config_name=config_name,
    )

    if scope == JumpStartScriptScope.INFERENCE:
        default_instance_types = model_specs.supported_inference_instance_types or []
        instance_specific_instance_types = []
        instance_type_variants = getattr(model_specs, "training_instance_type_variants", None)
        if training_instance_type is not None and instance_type_variants is not None:
            # Guard against a None result so the emptiness check below cannot fail.
            instance_specific_instance_types = (
                instance_type_variants.get_instance_specific_supported_inference_instance_types(
                    training_instance_type
                )
                or []
            )
        # Prefer the training-instance-specific list; fall back to the model-wide list.
        instance_types = instance_specific_instance_types or default_instance_types
    elif scope == JumpStartScriptScope.TRAINING:
        instance_types = model_specs.supported_training_instance_types
    else:
        raise NotImplementedError(
            f"Unsupported script scope for retrieving supported instance types: '{scope}'"
        )

    if not instance_types:
        raise ValueError(NO_AVAILABLE_INSTANCES_ERROR_MSG.format(model_id=model_id, region=region))

    return instance_types
@@ -0,0 +1,289 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module contains functions for obtaining JumpStart kwargs."""
14
+ from __future__ import absolute_import
15
+ from copy import deepcopy
16
+ from typing import Optional
17
+ from sagemaker.core.helper.session_helper import Session
18
+ from sagemaker.core.common_utils import volume_size_supported
19
+ from sagemaker.core.jumpstart.constants import (
20
+ DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
21
+ )
22
+ from sagemaker.core.jumpstart.enums import (
23
+ JumpStartScriptScope,
24
+ JumpStartModelType,
25
+ )
26
+ from sagemaker.core.jumpstart.utils import (
27
+ get_region_fallback,
28
+ verify_model_region_and_return_specs,
29
+ )
30
+
31
+
32
def _retrieve_model_init_kwargs(
    model_id: str,
    model_version: str,
    hub_arn: Optional[str] = None,
    region: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> dict:
    """Build the keyword arguments used to construct a JumpStart `Model`.

    The inference-scope specs for the model are looked up. A deep copy of their
    ``model_kwargs`` is taken, and ``enable_network_isolation`` is overlaid when
    the specs define it.

    Args:
        model_id (str): JumpStart model ID whose kwargs are requested.
        model_version (str): Version of the JumpStart model whose kwargs are requested.
        hub_arn (str): ARN of the SageMaker Hub to read model details from.
            (Default: None).
        region (Optional[str]): Region to read kwargs for. When falsy, the region
            is resolved from ``sagemaker_session``. (Default: None).
        tolerate_vulnerable_model (bool): When True, specs whose scripts have
            dependencies with known security vulnerabilities are accepted instead
            of raising. (Default: False).
        tolerate_deprecated_model (bool): When True, deprecated model versions are
            accepted instead of raising. (Default: False).
        sagemaker_session (sagemaker.session.Session): Session used for SageMaker
            interactions. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (JumpStartModelType): Open-weights or proprietary model.
            (Default: JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the JumpStart Model config to apply.
            (Default: None).

    Returns:
        dict: a fresh kwargs dictionary that callers may mutate freely.
    """
    resolved_region = (
        region if region else get_region_fallback(sagemaker_session=sagemaker_session)
    )

    specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=JumpStartScriptScope.INFERENCE,
        region=resolved_region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
        config_name=config_name,
    )

    # Deep copy so later edits by the caller never leak back into the specs object.
    init_kwargs = deepcopy(specs.model_kwargs)

    network_isolation = specs.inference_enable_network_isolation
    if network_isolation is not None:
        init_kwargs["enable_network_isolation"] = network_isolation

    return init_kwargs
93
+
94
+
95
def _retrieve_model_deploy_kwargs(
    model_id: str,
    model_version: str,
    instance_type: str,
    hub_arn: Optional[str] = None,
    region: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> dict:
    """Retrieves kwargs for `Model.deploy`.

    The returned dictionary is always a deep copy of the spec's ``deploy_kwargs``.
    Callers can therefore mutate it without affecting the specs object. This matches
    ``_retrieve_model_init_kwargs`` and ``_retrieve_estimator_init_kwargs``.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the kwargs.
        model_version (str): Version of the JumpStart model for which to retrieve the
            kwargs.
        instance_type (str): Instance type of the hosting endpoint, to determine if volume size
            is supported.
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        region (Optional[str]): Region for which to retrieve kwargs. When falsy, the
            region is resolved from ``sagemaker_session``. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (JumpStartModelType): The type of the model, can be open weights model
            or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).

    Returns:
        dict: the kwargs to use for the use case. ``volume_size`` is set from the
        spec's ``inference_volume_size`` when the instance type supports volume size
        and the spec defines one.
    """

    region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )

    model_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=JumpStartScriptScope.INFERENCE,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
        config_name=config_name,
    )

    # Previously the spec's dict was returned by reference on one path and as a
    # fresh dict on the other. Always copy so callers never mutate the specs object.
    kwargs = deepcopy(model_specs.deploy_kwargs)

    if volume_size_supported(instance_type) and model_specs.inference_volume_size is not None:
        kwargs["volume_size"] = model_specs.inference_volume_size

    return kwargs
158
+
159
+
160
def _retrieve_estimator_init_kwargs(
    model_id: str,
    model_version: str,
    instance_type: str,
    hub_arn: Optional[str] = None,
    region: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> dict:
    """Build the keyword arguments used to construct a JumpStart `Estimator`.

    The training-scope specs are looked up and their ``estimator_kwargs`` are deep
    copied. Two optional settings are then overlaid:

    * ``enable_network_isolation``, when the specs define
      ``training_enable_network_isolation``;
    * ``volume_size``, when ``instance_type`` supports a volume size and the
      specs define ``training_volume_size``.

    Args:
        model_id (str): JumpStart model ID whose kwargs are requested.
        model_version (str): Version of the JumpStart model whose kwargs are requested.
        instance_type (str): Training instance type, used to decide whether a
            volume size may be set.
        hub_arn (str): ARN of the SageMaker Hub to read model details from.
            (Default: None).
        region (Optional[str]): Region to read kwargs for. When falsy, the region
            is resolved from ``sagemaker_session``. (Default: None).
        tolerate_vulnerable_model (bool): When True, specs whose scripts have
            dependencies with known security vulnerabilities are accepted instead
            of raising. (Default: False).
        tolerate_deprecated_model (bool): When True, deprecated model versions are
            accepted instead of raising. (Default: False).
        sagemaker_session (sagemaker.session.Session): Session used for SageMaker
            interactions. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        config_name (Optional[str]): Name of the JumpStart Model config to apply.
            (Default: None).
        model_type (JumpStartModelType): Open-weights or proprietary model.
            (Default: JumpStartModelType.OPEN_WEIGHTS).

    Returns:
        dict: a fresh kwargs dictionary that callers may mutate freely.
    """
    resolved_region = region or get_region_fallback(sagemaker_session=sagemaker_session)

    specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=JumpStartScriptScope.TRAINING,
        region=resolved_region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        config_name=config_name,
        model_type=model_type,
    )

    estimator_kwargs = deepcopy(specs.estimator_kwargs)

    # Collect optional settings first, then apply them in a single update.
    overrides = {}
    if specs.training_enable_network_isolation is not None:
        overrides["enable_network_isolation"] = specs.training_enable_network_isolation
    if volume_size_supported(instance_type) and specs.training_volume_size is not None:
        overrides["volume_size"] = specs.training_volume_size

    estimator_kwargs.update(overrides)
    return estimator_kwargs
229
+
230
+
231
def _retrieve_estimator_fit_kwargs(
    model_id: str,
    model_version: str,
    hub_arn: Optional[str] = None,
    region: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> dict:
    """Retrieves kwargs for `Estimator.fit`.

    The returned dictionary is a deep copy of the spec's ``fit_kwargs``, so callers can
    mutate it without affecting the specs object. This matches the other kwargs
    retrievers in this module.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the kwargs.
        model_version (str): Version of the JumpStart model for which to retrieve the
            kwargs.
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        region (Optional[str]): Region for which to retrieve kwargs. When falsy, the
            region is resolved from ``sagemaker_session``. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
        model_type (JumpStartModelType): The type of the model, can be open weights model
            or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).

    Returns:
        dict: the kwargs to use for the use case.
    """

    region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )

    model_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=JumpStartScriptScope.TRAINING,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        config_name=config_name,
        model_type=model_type,
    )

    # Copy rather than hand out the spec's own dict, so caller mutations stay local.
    return deepcopy(model_specs.fit_kwargs)
@@ -0,0 +1,117 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module contains functions for obtaining JumpStart metric definitions."""
14
+ from __future__ import absolute_import
15
+ from copy import deepcopy
16
+ from typing import Dict, List, Optional
17
+ from sagemaker.core.jumpstart.constants import (
18
+ DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
19
+ )
20
+ from sagemaker.core.jumpstart.enums import (
21
+ JumpStartModelType,
22
+ JumpStartScriptScope,
23
+ )
24
+ from sagemaker.core.jumpstart.utils import (
25
+ get_region_fallback,
26
+ verify_model_region_and_return_specs,
27
+ )
28
+ from sagemaker.core.helper.session_helper import Session
29
+
30
+
31
def _retrieve_default_training_metric_definitions(
    model_id: str,
    model_version: str,
    region: Optional[str],
    hub_arn: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    instance_type: Optional[str] = None,
    config_name: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> Optional[List[Dict[str, str]]]:
    """Retrieves the default training metric definitions for the model.

    The spec-level ``metrics`` are the base list. When ``instance_type`` is given and
    the specs expose ``training_instance_type_variants``, instance-specific definitions
    override base definitions that share the same ``"Name"``.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the default training metric definitions.
        model_version (str): Version of the JumpStart model for which to retrieve the
            default training metric definitions.
        region (Optional[str]): Region for which to retrieve default training metric
            definitions. When falsy, the region is resolved from ``sagemaker_session``.
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        instance_type (str): An instance type to optionally supply in order to get
            metric definitions specific for the instance type.
        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
        model_type (JumpStartModelType): The type of the model, can be open weights model
            or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).

    Returns:
        list: the training metric definitions to use for the model. The list is empty
        when neither the specs nor the instance variants define any.
    """

    region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )

    model_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=JumpStartScriptScope.TRAINING,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        config_name=config_name,
        model_type=model_type,
    )

    # Pass a default to getattr so specs without a ``metrics`` attribute fall back to
    # an empty list. Previously getattr had no default and raised AttributeError.
    default_metric_definitions: List[Dict[str, str]] = deepcopy(
        getattr(model_specs, "metrics", None) or []
    )

    instance_type_variants = getattr(model_specs, "training_instance_type_variants", None)
    instance_specific_metric_definitions = (
        instance_type_variants.get_instance_specific_metric_definitions(instance_type)
        if instance_type and instance_type_variants is not None
        else []
    )

    if not instance_specific_metric_definitions:
        return default_metric_definitions

    instance_specific_metric_definitions = list(instance_specific_metric_definitions)

    # Merge rules:
    # - A default definition is dropped if any instance-specific definition has the
    #   same "Name".
    # - When a name repeats among the instance-specific definitions, only its last
    #   occurrence is kept, ordered by the position of that last occurrence.
    # - Kept defaults come first, in their original order.
    last_index_by_name = {
        definition["Name"]: index
        for index, definition in enumerate(instance_specific_metric_definitions)
    }
    merged_metric_definitions = [
        definition
        for definition in default_metric_definitions
        if definition["Name"] not in last_index_by_name
    ]
    merged_metric_definitions.extend(
        definition
        for index, definition in enumerate(instance_specific_metric_definitions)
        if last_index_by_name[definition["Name"]] == index
    )
    return merged_metric_definitions