sagemaker-core 1.0.62__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (362)
  1. sagemaker/core/__init__.py +16 -0
  2. sagemaker/core/_studio.py +116 -0
  3. sagemaker/core/_version.py +11 -0
  4. sagemaker/core/accept_types.py +131 -0
  5. sagemaker/core/analytics.py +744 -0
  6. sagemaker/core/apiutils/__init__.py +13 -0
  7. sagemaker/core/apiutils/_base_types.py +228 -0
  8. sagemaker/core/apiutils/_boto_functions.py +130 -0
  9. sagemaker/core/apiutils/_utils.py +34 -0
  10. sagemaker/core/base_deserializers.py +35 -0
  11. sagemaker/core/base_serializers.py +35 -0
  12. sagemaker/core/clarify/__init__.py +2898 -0
  13. sagemaker/core/collection.py +467 -0
  14. sagemaker/core/common_utils.py +2281 -0
  15. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  16. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  17. sagemaker/core/config/__init__.py +181 -0
  18. sagemaker/core/config/config.py +238 -0
  19. sagemaker/core/config/config_manager.py +595 -0
  20. sagemaker/core/config/config_schema.py +1220 -0
  21. sagemaker/core/config/config_utils.py +297 -0
  22. {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
  23. sagemaker/core/constants.py +73 -0
  24. sagemaker/core/content_types.py +137 -0
  25. sagemaker/core/debugger/__init__.py +39 -0
  26. sagemaker/core/debugger/debugger.py +945 -0
  27. sagemaker/core/debugger/framework_profile.py +292 -0
  28. sagemaker/core/debugger/metrics_config.py +468 -0
  29. sagemaker/core/debugger/profiler.py +42 -0
  30. sagemaker/core/debugger/profiler_config.py +190 -0
  31. sagemaker/core/debugger/profiler_constants.py +40 -0
  32. sagemaker/core/debugger/utils.py +148 -0
  33. sagemaker/core/deprecations.py +254 -0
  34. sagemaker/core/deserializers/__init__.py +10 -0
  35. sagemaker/core/deserializers/base.py +424 -0
  36. sagemaker/core/deserializers/implementations.py +157 -0
  37. sagemaker/core/drift_check_baselines.py +106 -0
  38. sagemaker/core/enums.py +51 -0
  39. sagemaker/core/environment_variables.py +101 -0
  40. sagemaker/core/exceptions.py +108 -0
  41. sagemaker/core/experiments/__init__.py +53 -0
  42. sagemaker/core/experiments/_api_types.py +251 -0
  43. sagemaker/core/experiments/_environment.py +124 -0
  44. sagemaker/core/experiments/_helper.py +294 -0
  45. sagemaker/core/experiments/_metrics.py +333 -0
  46. sagemaker/core/experiments/_run_context.py +58 -0
  47. sagemaker/core/experiments/_utils.py +216 -0
  48. sagemaker/core/experiments/experiment.py +244 -0
  49. sagemaker/core/experiments/run.py +970 -0
  50. sagemaker/core/experiments/trial.py +296 -0
  51. sagemaker/core/experiments/trial_component.py +387 -0
  52. sagemaker/core/explainer/__init__.py +24 -0
  53. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  54. sagemaker/core/explainer/explainer_config.py +44 -0
  55. sagemaker/core/fw_utils.py +1176 -0
  56. sagemaker/core/git_utils.py +349 -0
  57. sagemaker/core/helper/pipeline_variable.py +82 -0
  58. sagemaker/core/helper/session_helper.py +2965 -0
  59. sagemaker/core/huggingface/__init__.py +29 -0
  60. sagemaker/core/huggingface/llm_utils.py +150 -0
  61. sagemaker/core/huggingface/processing.py +139 -0
  62. sagemaker/core/huggingface/training_compiler/config.py +167 -0
  63. sagemaker/core/hyperparameters.py +172 -0
  64. sagemaker/core/image_retriever/__init__.py +3 -0
  65. sagemaker/core/image_retriever/image_retriever.py +640 -0
  66. sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
  67. sagemaker/core/image_retriever/test.py +7 -0
  68. sagemaker/core/image_uri_config/__init__.py +13 -0
  69. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  70. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  71. sagemaker/core/image_uri_config/chainer.json +104 -0
  72. sagemaker/core/image_uri_config/clarify.json +39 -0
  73. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  74. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  75. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  76. sagemaker/core/image_uri_config/debugger.json +34 -0
  77. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  78. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  79. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  80. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  81. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  82. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  83. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  84. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  85. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
  86. sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
  87. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  88. sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
  89. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  90. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  91. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  92. sagemaker/core/image_uri_config/huggingface.json +2138 -0
  93. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  94. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  95. sagemaker/core/image_uri_config/image-classification.json +50 -0
  96. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  97. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  98. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  99. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  100. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  101. sagemaker/core/image_uri_config/kmeans.json +50 -0
  102. sagemaker/core/image_uri_config/knn.json +50 -0
  103. sagemaker/core/image_uri_config/lda.json +26 -0
  104. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  105. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  106. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  107. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  108. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  109. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  110. sagemaker/core/image_uri_config/ntm.json +50 -0
  111. sagemaker/core/image_uri_config/object-detection.json +50 -0
  112. sagemaker/core/image_uri_config/object2vec.json +50 -0
  113. sagemaker/core/image_uri_config/pca.json +50 -0
  114. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  115. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  116. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  117. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  118. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  119. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  120. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  121. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  122. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  123. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  124. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
  125. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  126. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  127. sagemaker/core/image_uri_config/sklearn.json +446 -0
  128. sagemaker/core/image_uri_config/spark.json +280 -0
  129. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  130. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  131. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  132. sagemaker/core/image_uri_config/vw.json +25 -0
  133. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  134. sagemaker/core/image_uri_config/xgboost.json +888 -0
  135. sagemaker/core/image_uris.py +810 -0
  136. sagemaker/core/inference_config.py +144 -0
  137. sagemaker/core/inference_recommender/__init__.py +18 -0
  138. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  139. sagemaker/core/inputs.py +366 -0
  140. sagemaker/core/instance_group.py +61 -0
  141. sagemaker/core/instance_types.py +164 -0
  142. sagemaker/core/instance_types_gpu_info.py +43 -0
  143. sagemaker/core/interactive_apps/__init__.py +41 -0
  144. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  145. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  146. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  147. sagemaker/core/iterators.py +186 -0
  148. sagemaker/core/job.py +380 -0
  149. sagemaker/core/jumpstart/__init__.py +156 -0
  150. sagemaker/core/jumpstart/accessors.py +390 -0
  151. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  152. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  153. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  154. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  155. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  156. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  157. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  158. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  159. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  160. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  161. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  162. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  163. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  164. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  165. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  166. sagemaker/core/jumpstart/cache.py +663 -0
  167. sagemaker/core/jumpstart/configs.py +50 -0
  168. sagemaker/core/jumpstart/constants.py +198 -0
  169. sagemaker/core/jumpstart/deserializers.py +81 -0
  170. sagemaker/core/jumpstart/document.py +76 -0
  171. sagemaker/core/jumpstart/enums.py +168 -0
  172. sagemaker/core/jumpstart/exceptions.py +236 -0
  173. sagemaker/core/jumpstart/factory/utils.py +833 -0
  174. sagemaker/core/jumpstart/filters.py +597 -0
  175. sagemaker/core/jumpstart/hub/constants.py +16 -0
  176. sagemaker/core/jumpstart/hub/hub.py +291 -0
  177. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  178. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  179. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  180. sagemaker/core/jumpstart/hub/types.py +35 -0
  181. sagemaker/core/jumpstart/hub/utils.py +260 -0
  182. sagemaker/core/jumpstart/models.py +499 -0
  183. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  184. sagemaker/core/jumpstart/parameters.py +20 -0
  185. sagemaker/core/jumpstart/payload_utils.py +239 -0
  186. sagemaker/core/jumpstart/region_config.json +163 -0
  187. sagemaker/core/jumpstart/search.py +171 -0
  188. sagemaker/core/jumpstart/serializers.py +81 -0
  189. sagemaker/core/jumpstart/session_utils.py +234 -0
  190. sagemaker/core/jumpstart/types.py +3044 -0
  191. sagemaker/core/jumpstart/utils.py +1731 -0
  192. sagemaker/core/jumpstart/validators.py +257 -0
  193. sagemaker/core/lambda_helper.py +312 -0
  194. sagemaker/core/lineage/__init__.py +42 -0
  195. sagemaker/core/lineage/_api_types.py +239 -0
  196. sagemaker/core/lineage/_utils.py +49 -0
  197. sagemaker/core/lineage/action.py +345 -0
  198. sagemaker/core/lineage/artifact.py +646 -0
  199. sagemaker/core/lineage/association.py +190 -0
  200. sagemaker/core/lineage/context.py +505 -0
  201. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  202. sagemaker/core/lineage/query.py +732 -0
  203. sagemaker/core/lineage/visualizer.py +346 -0
  204. sagemaker/core/local/__init__.py +18 -0
  205. sagemaker/core/local/data.py +413 -0
  206. sagemaker/core/local/entities.py +678 -0
  207. sagemaker/core/local/exceptions.py +17 -0
  208. sagemaker/core/local/image.py +1243 -0
  209. sagemaker/core/local/local_session.py +739 -0
  210. sagemaker/core/local/utils.py +245 -0
  211. sagemaker/core/logs.py +181 -0
  212. sagemaker/core/metadata_properties.py +56 -0
  213. sagemaker/core/metric_definitions.py +91 -0
  214. sagemaker/core/mlflow/__init__.py +38 -0
  215. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  216. sagemaker/core/model_card/__init__.py +26 -0
  217. sagemaker/core/model_life_cycle.py +51 -0
  218. sagemaker/core/model_metrics.py +160 -0
  219. sagemaker/core/model_monitor/__init__.py +66 -0
  220. sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
  221. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  222. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  223. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  224. sagemaker/core/model_monitor/dataset_format.py +102 -0
  225. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  226. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  227. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  228. sagemaker/core/model_monitor/utils.py +793 -0
  229. sagemaker/core/model_registry.py +480 -0
  230. sagemaker/core/model_uris.py +97 -0
  231. sagemaker/core/modules/__init__.py +19 -0
  232. sagemaker/core/modules/configs.py +226 -0
  233. sagemaker/core/modules/constants.py +37 -0
  234. sagemaker/core/modules/distributed.py +182 -0
  235. sagemaker/core/modules/local_core/__init__.py +0 -0
  236. sagemaker/core/modules/local_core/local_container.py +605 -0
  237. sagemaker/core/modules/templates.py +83 -0
  238. sagemaker/core/modules/train/__init__.py +14 -0
  239. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  240. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  241. sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
  242. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  243. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  244. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  245. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  246. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  247. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  248. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  249. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  250. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  251. sagemaker/core/modules/types.py +19 -0
  252. sagemaker/core/modules/utils.py +194 -0
  253. sagemaker/core/network.py +185 -0
  254. sagemaker/core/parameter.py +173 -0
  255. sagemaker/core/payloads.py +185 -0
  256. sagemaker/core/processing.py +1597 -0
  257. sagemaker/core/remote_function/__init__.py +19 -0
  258. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  259. sagemaker/core/remote_function/client.py +1285 -0
  260. sagemaker/core/remote_function/core/__init__.py +0 -0
  261. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  262. sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
  263. sagemaker/core/remote_function/core/serialization.py +422 -0
  264. sagemaker/core/remote_function/core/stored_function.py +226 -0
  265. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  266. sagemaker/core/remote_function/errors.py +104 -0
  267. sagemaker/core/remote_function/invoke_function.py +172 -0
  268. sagemaker/core/remote_function/job.py +2140 -0
  269. sagemaker/core/remote_function/logging_config.py +38 -0
  270. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  271. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  272. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  273. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  274. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  275. sagemaker/core/remote_function/spark_config.py +149 -0
  276. sagemaker/core/resource_requirements.py +168 -0
  277. {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
  278. sagemaker/core/s3/__init__.py +41 -0
  279. sagemaker/core/s3/client.py +367 -0
  280. sagemaker/core/s3/utils.py +175 -0
  281. sagemaker/core/script_uris.py +93 -0
  282. sagemaker/core/serializers/__init__.py +11 -0
  283. sagemaker/core/serializers/base.py +510 -0
  284. sagemaker/core/serializers/implementations.py +159 -0
  285. sagemaker/core/serializers/utils.py +223 -0
  286. sagemaker/core/serverless_inference_config.py +63 -0
  287. sagemaker/core/session_settings.py +55 -0
  288. sagemaker/core/shapes/__init__.py +3 -0
  289. sagemaker/core/shapes/model_card_shapes.py +159 -0
  290. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
  291. sagemaker/core/spark/__init__.py +16 -0
  292. sagemaker/core/spark/defaults.py +16 -0
  293. sagemaker/core/spark/processing.py +1380 -0
  294. sagemaker/core/telemetry/__init__.py +23 -0
  295. sagemaker/core/telemetry/constants.py +84 -0
  296. sagemaker/core/telemetry/telemetry_logging.py +284 -0
  297. sagemaker/core/tools/__init__.py +1 -0
  298. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  299. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  300. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  301. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  302. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  303. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  304. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  305. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  306. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  307. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  308. sagemaker/core/training/__init__.py +14 -0
  309. sagemaker/core/training/configs.py +333 -0
  310. sagemaker/core/training/constants.py +37 -0
  311. sagemaker/core/training/utils.py +77 -0
  312. sagemaker/core/training_compiler/__init__.py +16 -0
  313. sagemaker/core/training_compiler/config.py +197 -0
  314. sagemaker/core/training_compiler_config.py +197 -0
  315. sagemaker/core/transformer.py +793 -0
  316. sagemaker/core/user_agent.py +76 -0
  317. sagemaker/core/utilities/__init__.py +24 -0
  318. sagemaker/core/utilities/cache.py +169 -0
  319. sagemaker/core/utilities/search_expression.py +133 -0
  320. sagemaker/core/utils/__init__.py +48 -0
  321. sagemaker/core/utils/code_injection/__init__.py +0 -0
  322. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  323. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
  324. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  325. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  326. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  327. {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
  328. sagemaker/core/workflow/__init__.py +152 -0
  329. sagemaker/core/workflow/conditions.py +313 -0
  330. sagemaker/core/workflow/entities.py +58 -0
  331. sagemaker/core/workflow/execution_variables.py +89 -0
  332. sagemaker/core/workflow/functions.py +193 -0
  333. sagemaker/core/workflow/parameters.py +222 -0
  334. sagemaker/core/workflow/pipeline_context.py +394 -0
  335. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  336. sagemaker/core/workflow/properties.py +285 -0
  337. sagemaker/core/workflow/step_outputs.py +65 -0
  338. sagemaker/core/workflow/utilities.py +507 -0
  339. sagemaker/lineage/__init__.py +33 -0
  340. sagemaker/lineage/action.py +28 -0
  341. sagemaker/lineage/artifact.py +28 -0
  342. sagemaker/lineage/context.py +28 -0
  343. sagemaker/lineage/lineage_trial_component.py +28 -0
  344. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
  345. sagemaker_core-2.1.1.dist-info/RECORD +355 -0
  346. sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
  347. sagemaker_core/_version.py +0 -3
  348. sagemaker_core/helper/session_helper.py +0 -769
  349. sagemaker_core/resources/__init__.py +0 -1
  350. sagemaker_core/shapes/__init__.py +0 -1
  351. sagemaker_core/tools/__init__.py +0 -1
  352. sagemaker_core-1.0.62.dist-info/RECORD +0 -35
  353. sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
  354. {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
  355. {sagemaker_core/helper → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
  356. {sagemaker_core/main → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  357. {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
  358. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  359. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  360. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  361. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
  362. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,239 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module stores inference payload utilities for JumpStart models."""
14
+ from __future__ import absolute_import
15
+ import base64
16
+ import json
17
+ from typing import Any, Dict, List, Optional, Union
18
+ import re
19
+ import boto3
20
+
21
+ from sagemaker.core.jumpstart.accessors import JumpStartS3PayloadAccessor
22
+ from sagemaker.core.jumpstart.artifacts.payloads import _retrieve_example_payloads
23
+ from sagemaker.core.jumpstart.constants import (
24
+ DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
25
+ )
26
+ from sagemaker.core.jumpstart.enums import JumpStartModelType, MIMEType
27
+ from sagemaker.core.jumpstart.types import JumpStartSerializablePayload
28
+ from sagemaker.core.jumpstart.utils import (
29
+ get_jumpstart_content_bucket,
30
+ get_region_fallback,
31
+ )
32
+ from sagemaker.core.helper.session_helper import Session
33
+
34
+
35
# Named capture group shared by both placeholder syntaxes below: an S3 key built
# from ASCII letters, digits, "-", "_", "/" and ".".
_S3_KEY_GROUP = r"(?P<s3_key>[a-zA-Z0-9-_/.]+)"

# A payload consisting *entirely* of one ``$s3<key>`` reference (anchored at both ends).
S3_BYTES_REGEX = rf"^\$s3<{_S3_KEY_GROUP}>$"

# A ``$s3_b64<key>`` reference that may occur anywhere inside a string payload.
S3_B64_STR_REGEX = rf"\$s3_b64<{_S3_KEY_GROUP}>"
37
+
38
+
39
+ def _extract_field_from_json(
40
+ json_input: dict,
41
+ keys: List[str],
42
+ ) -> Any:
43
+ """Given a dictionary, returns value at specified keys.
44
+
45
+ Raises:
46
+ KeyError: If a key cannot be found in the json input.
47
+ """
48
+ curr_json = json_input
49
+ for idx, key in enumerate(keys):
50
+ if idx < len(keys) - 1:
51
+ curr_json = curr_json[key]
52
+ continue
53
+ return curr_json[key]
54
+
55
+
56
def _construct_payload(
    prompt: str,
    model_id: str,
    model_version: str,
    region: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    alias: Optional[str] = None,
) -> Optional[JumpStartSerializablePayload]:
    """Returns example payload from prompt.

    Args:
        prompt (str): String-valued prompt to embed in payload.
        model_id (str): JumpStart model ID of the JumpStart model for which to construct
            the payload.
        model_version (str): Version of the JumpStart model for which to retrieve the
            payload.
        region (Optional[str]): Region for which to retrieve the
            payload. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (JumpStartModelType): The type of the model, can be open weights model or
            proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
        alias (Optional[str]): Name of the example payload to use. If not provided (or
            empty), the first available example payload is used. (Default: None).
    Returns:
        Optional[JumpStartSerializablePayload]: serializable payload with prompt, or None if
            this feature is unavailable for the specified model.
    Raises:
        KeyError: If ``alias`` does not name an available example payload, or if an
            intermediate segment of the payload's dot-separated ``prompt_key`` is missing
            from the payload body.
    """
    payloads: Optional[Dict[str, JumpStartSerializablePayload]] = _retrieve_example_payloads(
        model_id=model_id,
        model_version=model_version,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
    )
    if not payloads:
        return None

    if alias:
        if alias not in payloads:
            # Same exception type as a plain dict lookup, but with actionable context.
            raise KeyError(
                f"Example payload alias '{alias}' not found for model '{model_id}'. "
                f"Available aliases: {list(payloads)}."
            )
        payload_to_use: JumpStartSerializablePayload = payloads[alias]
    else:
        payload_to_use = next(iter(payloads.values()))

    prompt_key: Optional[str] = payload_to_use.prompt_key
    if prompt_key is None:
        return None

    # ``prompt_key`` is a dot-separated path into the body (e.g. "inputs.text"): walk every
    # segment but the last, then assign the prompt at the final segment.
    # NOTE(review): this writes into the body of the object returned by
    # _retrieve_example_payloads; whether that object is a private copy is not visible
    # here — confirm it is not a shared cached instance.
    *parent_segments, leaf_segment = prompt_key.split(".")
    target = payload_to_use.body
    for segment in parent_segments:
        target = target[segment]
    target[leaf_segment] = prompt

    return payload_to_use
123
+
124
+
125
class PayloadSerializer:
    """Utility class for serializing payloads associated with JumpStart models.

    Many JumpStart models embed byte-streams into payloads corresponding to images, sounds,
    and other content types which require downloading from S3. Such content is referenced
    inside a payload with one of two placeholder syntaxes:

    * ``$s3<key>`` -- the whole payload is replaced by the raw bytes of the S3 object
      (matched by ``S3_BYTES_REGEX``).
    * ``$s3_b64<key>`` -- each occurrence inside a string is replaced by the base64-encoded
      S3 object (matched by ``S3_B64_STR_REGEX``).
    """

    def __init__(
        self,
        bucket: Optional[str] = None,
        region: Optional[str] = None,
        s3_client: Optional[boto3.client] = None,
    ) -> None:
        """Initializes PayloadSerializer object.

        Args:
            bucket (Optional[str]): S3 bucket holding the referenced objects. Defaults to
                the value returned by ``get_jumpstart_content_bucket()``.
            region (Optional[str]): Region passed to the S3 accessor. Defaults to the value
                returned by ``get_region_fallback(s3_client=s3_client)``.
            s3_client (Optional[boto3.client]): S3 client passed to the S3 accessor.
        """
        # NOTE(review): the default bucket is resolved without passing ``region``, so a
        # caller supplying only ``region`` may get a bucket for another region — confirm
        # against get_jumpstart_content_bucket's default argument.
        self.bucket = bucket or get_jumpstart_content_bucket()
        self.region = region or get_region_fallback(
            s3_client=s3_client,
        )
        self.s3_client = s3_client

    def get_bytes_payload_with_s3_references(
        self,
        payload_str: str,
    ) -> bytes:
        """Returns bytes object corresponding to referenced S3 object.

        Args:
            payload_str (str): Payload consisting solely of one ``$s3<key>`` reference.

        Raises:
            ValueError: If the raw bytes payload is not formatted correctly, i.e. it is not
                exactly one ``$s3<key>`` reference.
        """
        s3_keys = re.findall(S3_BYTES_REGEX, payload_str)
        if len(s3_keys) != 1:
            raise ValueError("Invalid bytes payload.")

        return JumpStartS3PayloadAccessor.get_object_cached(
            bucket=self.bucket, key=s3_keys[0], region=self.region, s3_client=self.s3_client
        )

    def embed_s3_references_in_str_payload(
        self,
        payload: str,
    ) -> str:
        """Inserts serialized S3 content into string payload.

        If no S3 content is embedded in payload, original string is returned.
        """
        return self._embed_s3_b64_references_in_str_payload(payload_body=payload)

    def _embed_s3_b64_references_in_str_payload(
        self,
        payload_body: str,
    ) -> str:
        """Performs base 64 encoding of payloads embedded in a payload.

        This is required so that byte-valued payloads can be transmitted efficiently
        as a utf-8 encoded string.
        """
        # str.replace substitutes every occurrence of a reference at once, so each
        # distinct key only needs fetching/encoding a single time; dict.fromkeys
        # de-duplicates while keeping first-seen order.
        for s3_key in dict.fromkeys(re.findall(S3_B64_STR_REGEX, payload_body)):
            s3_object = JumpStartS3PayloadAccessor.get_object_cached(
                bucket=self.bucket, key=s3_key, region=self.region, s3_client=self.s3_client
            )
            b64_encoded_string = base64.b64encode(bytearray(s3_object)).decode()
            payload_body = payload_body.replace(f"$s3_b64<{s3_key}>", b64_encoded_string)
        return payload_body

    def embed_s3_references_in_json_payload(
        self, payload_body: Union[list, dict, str, int, float, None]
    ) -> Union[list, dict, str, int, float, None]:
        """Finds all S3 references in payload and embeds serialized S3 data.

        Strings are scanned for ``$s3_b64<key>`` references; lists and dicts are walked
        recursively (dict keys are left untouched); numbers, booleans and JSON ``null``
        (``None``) are returned as-is. If no S3 references are found, the payload is
        returned un-modified.

        Raises:
            ValueError: If the payload has an unrecognized type.
        """
        if isinstance(payload_body, str):
            return self.embed_s3_references_in_str_payload(payload_body)
        # JSON null decodes to None and carries no S3 reference; it used to raise here.
        # (bool is a subclass of int and was already accepted.)
        if payload_body is None or isinstance(payload_body, (bool, int, float)):
            return payload_body
        if isinstance(payload_body, list):
            return [self.embed_s3_references_in_json_payload(item) for item in payload_body]
        if isinstance(payload_body, dict):
            return {
                key: self.embed_s3_references_in_json_payload(value)
                for key, value in payload_body.items()
            }
        raise ValueError(f"Payload has unrecognized type: {type(payload_body)}")

    def serialize(self, payload: JumpStartSerializablePayload) -> Union[str, bytes]:
        """Returns payload string or bytes that can be inputted to inference endpoint.

        JSON-like content types (JSON, list-text, x-text) have their ``$s3_b64<key>``
        references embedded; any other content type is treated as a single ``$s3<key>``
        reference whose raw bytes become the payload. Dict and list bodies are then
        encoded with ``json.dumps``.

        Raises:
            ValueError: If the payload has an unrecognized type, or a non-JSON payload is
                not exactly one ``$s3<key>`` reference.
        """
        content_type = MIMEType.from_suffixed_type(payload.content_type)
        body = payload.body

        if content_type in {MIMEType.JSON, MIMEType.LIST_TEXT, MIMEType.X_TEXT}:
            body = self.embed_s3_references_in_json_payload(body)
        else:
            body = self.get_bytes_payload_with_s3_references(body)

        # Encode structured bodies: top-level JSON arrays are as valid as objects and
        # previously fell through to the ValueError below.
        if isinstance(body, (dict, list)):
            body = json.dumps(body)
        elif not isinstance(body, (str, bytes)):
            raise ValueError(f"Default payload '{body}' has unrecognized type: {type(body)}")

        return body
@@ -0,0 +1,163 @@
1
+ {
2
+ "af-south-1": {
3
+ "content_bucket": "jumpstart-cache-prod-af-south-1",
4
+ "gated_content_bucket": "jumpstart-private-cache-prod-af-south-1"
5
+ },
6
+ "ap-east-1": {
7
+ "content_bucket": "jumpstart-cache-prod-ap-east-1",
8
+ "gated_content_bucket": "jumpstart-private-cache-prod-ap-east-1"
9
+ },
10
+ "ap-northeast-1": {
11
+ "content_bucket": "jumpstart-cache-prod-ap-northeast-1",
12
+ "gated_content_bucket": "jumpstart-private-cache-prod-ap-northeast-1",
13
+ "neo_content_bucket": "sagemaker-sd-models-prod-ap-northeast-1"
14
+ },
15
+ "ap-northeast-2": {
16
+ "content_bucket": "jumpstart-cache-prod-ap-northeast-2",
17
+ "gated_content_bucket": "jumpstart-private-cache-prod-ap-northeast-2",
18
+ "neo_content_bucket": "sagemaker-sd-models-prod-ap-northeast-2"
19
+ },
20
+ "ap-northeast-3": {
21
+ "content_bucket": "jumpstart-cache-prod-ap-northeast-3",
22
+ "gated_content_bucket": "jumpstart-private-cache-prod-ap-northeast-3",
23
+ "neo_content_bucket": "sagemaker-sd-models-prod-ap-northeast-3"
24
+ },
25
+ "ap-south-1": {
26
+ "content_bucket": "jumpstart-cache-prod-ap-south-1",
27
+ "gated_content_bucket": "jumpstart-private-cache-prod-ap-south-1",
28
+ "neo_content_bucket": "sagemaker-sd-models-prod-ap-south-1"
29
+ },
30
+ "ap-south-2": {
31
+ "content_bucket": "jumpstart-cache-prod-ap-south-2",
32
+ "gated_content_bucket": "jumpstart-private-cache-prod-ap-south-2"
33
+ },
34
+ "ap-southeast-1": {
35
+ "content_bucket": "jumpstart-cache-prod-ap-southeast-1",
36
+ "gated_content_bucket": "jumpstart-private-cache-prod-ap-southeast-1",
37
+ "neo_content_bucket": "sagemaker-sd-models-prod-ap-southeast-1"
38
+ },
39
+ "ap-southeast-2": {
40
+ "content_bucket": "jumpstart-cache-prod-ap-southeast-2",
41
+ "gated_content_bucket": "jumpstart-private-cache-prod-ap-southeast-2",
42
+ "neo_content_bucket": "sagemaker-sd-models-prod-ap-southeast-2"
43
+ },
44
+ "ap-southeast-3": {
45
+ "content_bucket": "jumpstart-cache-prod-ap-southeast-3",
46
+ "gated_content_bucket": "jumpstart-private-cache-prod-ap-southeast-3"
47
+ },
48
+ "ap-southeast-4": {
49
+ "content_bucket": "jumpstart-cache-prod-ap-southeast-4",
50
+ "gated_content_bucket": "jumpstart-private-cache-prod-ap-southeast-4"
51
+ },
52
+ "ap-southeast-5": {
53
+ "content_bucket": "jumpstart-cache-prod-ap-southeast-5",
54
+ "gated_content_bucket": "jumpstart-private-cache-prod-ap-southeast-5"
55
+ },
56
+ "ap-southeast-7": {
57
+ "content_bucket": "jumpstart-cache-prod-ap-southeast-7",
58
+ "gated_content_bucket": "jumpstart-private-cache-prod-ap-southeast-7"
59
+ },
60
+ "ca-central-1": {
61
+ "content_bucket": "jumpstart-cache-prod-ca-central-1",
62
+ "gated_content_bucket": "jumpstart-private-cache-prod-ca-central-1",
63
+ "neo_content_bucket": "sagemaker-sd-models-prod-ca-central-1"
64
+ },
65
+ "ca-west-1": {
66
+ "content_bucket": "jumpstart-cache-prod-ca-west-1",
67
+ "gated_content_bucket": "jumpstart-private-cache-prod-ca-west-1"
68
+ },
69
+ "cn-north-1": {
70
+ "content_bucket": "jumpstart-cache-prod-cn-north-1",
71
+ "gated_content_bucket": "jumpstart-private-cache-prod-cn-north-1"
72
+ },
73
+ "cn-northwest-1": {
74
+ "content_bucket": "jumpstart-cache-prod-cn-northwest-1",
75
+ "gated_content_bucket": "jumpstart-private-cache-prod-cn-northwest-1"
76
+ },
77
+ "eu-central-1": {
78
+ "content_bucket": "jumpstart-cache-prod-eu-central-1",
79
+ "gated_content_bucket": "jumpstart-private-cache-prod-eu-central-1",
80
+ "neo_content_bucket": "sagemaker-sd-models-prod-eu-central-1"
81
+ },
82
+ "eu-central-2": {
83
+ "content_bucket": "jumpstart-cache-prod-eu-central-2",
84
+ "gated_content_bucket": "jumpstart-private-cache-prod-eu-central-2"
85
+ },
86
+ "eu-north-1": {
87
+ "content_bucket": "jumpstart-cache-prod-eu-north-1",
88
+ "gated_content_bucket": "jumpstart-private-cache-prod-eu-north-1",
89
+ "neo_content_bucket": "sagemaker-sd-models-prod-eu-north-1"
90
+ },
91
+ "eu-south-1": {
92
+ "content_bucket": "jumpstart-cache-prod-eu-south-1",
93
+ "gated_content_bucket": "jumpstart-private-cache-prod-eu-south-1"
94
+ },
95
+ "eu-south-2": {
96
+ "content_bucket": "jumpstart-cache-prod-eu-south-2",
97
+ "gated_content_bucket": "jumpstart-private-cache-prod-eu-south-2"
98
+ },
99
+ "eu-west-1": {
100
+ "content_bucket": "jumpstart-cache-prod-eu-west-1",
101
+ "gated_content_bucket": "jumpstart-private-cache-prod-eu-west-1",
102
+ "neo_content_bucket": "sagemaker-sd-models-prod-eu-west-1"
103
+ },
104
+ "eu-west-2": {
105
+ "content_bucket": "jumpstart-cache-prod-eu-west-2",
106
+ "gated_content_bucket": "jumpstart-private-cache-prod-eu-west-2",
107
+ "neo_content_bucket": "sagemaker-sd-models-prod-eu-west-2"
108
+ },
109
+ "eu-west-3": {
110
+ "content_bucket": "jumpstart-cache-prod-eu-west-3",
111
+ "gated_content_bucket": "jumpstart-private-cache-prod-eu-west-3",
112
+ "neo_content_bucket": "sagemaker-sd-models-prod-eu-west-3"
113
+ },
114
+ "il-central-1": {
115
+ "content_bucket": "jumpstart-cache-prod-il-central-1",
116
+ "gated_content_bucket": "jumpstart-private-cache-prod-il-central-1"
117
+ },
118
+ "me-central-1": {
119
+ "content_bucket": "jumpstart-cache-prod-me-central-1",
120
+ "gated_content_bucket": "jumpstart-private-cache-prod-me-central-1"
121
+ },
122
+ "me-south-1": {
123
+ "content_bucket": "jumpstart-cache-prod-me-south-1",
124
+ "gated_content_bucket": "jumpstart-private-cache-prod-me-south-1"
125
+ },
126
+ "mx-central-1": {
127
+ "content_bucket": "jumpstart-cache-prod-mx-central-1",
128
+ "gated_content_bucket": "jumpstart-private-cache-prod-mx-central-1"
129
+ },
130
+ "sa-east-1": {
131
+ "content_bucket": "jumpstart-cache-prod-sa-east-1",
132
+ "gated_content_bucket": "jumpstart-private-cache-prod-sa-east-1",
133
+ "neo_content_bucket": "sagemaker-sd-models-prod-sa-east-1"
134
+ },
135
+ "us-east-1": {
136
+ "content_bucket": "jumpstart-cache-prod-us-east-1",
137
+ "gated_content_bucket": "jumpstart-private-cache-prod-us-east-1",
138
+ "neo_content_bucket": "sagemaker-sd-models-prod-us-east-1"
139
+ },
140
+ "us-east-2": {
141
+ "content_bucket": "jumpstart-cache-prod-us-east-2",
142
+ "gated_content_bucket": "jumpstart-private-cache-prod-us-east-2",
143
+ "neo_content_bucket": "sagemaker-sd-models-prod-us-east-2"
144
+ },
145
+ "us-gov-east-1": {
146
+ "content_bucket": "jumpstart-cache-prod-us-gov-east-1",
147
+ "gated_content_bucket": "jumpstart-private-cache-prod-us-gov-east-1"
148
+ },
149
+ "us-gov-west-1": {
150
+ "content_bucket": "jumpstart-cache-prod-us-gov-west-1",
151
+ "gated_content_bucket": "jumpstart-private-cache-prod-us-gov-west-1"
152
+ },
153
+ "us-west-1": {
154
+ "content_bucket": "jumpstart-cache-prod-us-west-1",
155
+ "gated_content_bucket": "jumpstart-private-cache-prod-us-west-1",
156
+ "neo_content_bucket": "sagemaker-sd-models-prod-us-west-1"
157
+ },
158
+ "us-west-2": {
159
+ "content_bucket": "jumpstart-cache-prod-us-west-2",
160
+ "gated_content_bucket": "jumpstart-private-cache-prod-us-west-2",
161
+ "neo_content_bucket": "sagemaker-sd-models-prod-us-west-2"
162
+ }
163
+ }
@@ -0,0 +1,171 @@
1
import logging
import re
from typing import Any, Iterator, List, Optional

from sagemaker.core.helper.session_helper import Session
from sagemaker.core.resources import HubContent
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
class _Filter:
    """
    A filter that evaluates logical expressions against a list of keyword strings.

    Supports logical operators (AND, OR, NOT; matched case-insensitively and only when they
    form a whole token), parentheses for grouping, and ``*`` wildcards anywhere in a
    pattern (e.g., `text-*`, `*ai`, `text-*-model`, `@task:foo`). Pattern matching itself
    is case-sensitive. A malformed expression matches nothing.

    Example:
        filt = _Filter("(@framework:huggingface OR text-*) AND NOT deprecated")
        filt.match(["@framework:huggingface", "text-generation"])  # Returns True
    """

    # Whole-token operator spellings (compared upper-cased) and their Python equivalents.
    _OPERATORS = {"AND": "and", "OR": "or", "NOT": "not"}

    def __init__(self, expression: str) -> None:
        """
        Initialize the filter with a string expression.

        Args:
            expression (str): A logical expression to evaluate against keywords.
                Supports AND, OR, NOT, parentheses, and wildcard patterns (*).
        """
        self.expression: str = expression
        # (expression, compiled code or None) for the last expression compiled by
        # ``_compiled``; lets ``match`` skip re-tokenizing for every keyword list.
        self._compiled_cache = None

    def match(self, keywords: List[str]) -> bool:
        """
        Evaluate the filter expression against a list of keywords.

        Args:
            keywords (List[str]): A list of keyword strings to test.

        Returns:
            bool: True if the expression evaluates to True for the given keywords, else
                False (including when the expression is malformed).
        """
        code = self._compiled()
        if code is None:
            return False
        try:
            # The generated code only contains ``any(...)`` calls over repr()-quoted
            # literals, ``and``/``or``/``not`` and parentheses, so evaluating it is safe.
            # ``_fullmatch`` must live in globals: the generator expressions that call it
            # cannot see the locals mapping passed to eval().
            return eval(
                code,
                {"__builtins__": {}, "_fullmatch": re.fullmatch},
                {"keywords": keywords, "any": any},
            )
        except Exception:
            # Best effort: e.g. non-string keywords simply do not match.
            return False

    def _compiled(self):
        """Return compiled code for ``self.expression``, or None if it is not valid syntax."""
        cache = getattr(self, "_compiled_cache", None)
        # Re-check the source so that reassigning ``self.expression`` is honored.
        if cache is not None and cache[0] == self.expression:
            return cache[1]
        source = self._convert_expression(self.expression)
        try:
            code = compile(source, "<_Filter>", "eval")
        except Exception:
            # SyntaxError for malformed queries (unbalanced parentheses, dangling
            # operators, ...); extremely deep nesting may raise other errors as well.
            code = None
        self._compiled_cache = (self.expression, code)
        return code

    def _convert_expression(self, expr: str) -> str:
        """
        Convert the logical filter expression into a Python-evaluable string.

        Operators are recognised only when they form a whole token, so keywords such as
        ``not-deprecated`` or ``or-tools`` are treated as patterns rather than being
        split into an operator plus a remainder.

        Args:
            expr (str): The raw expression to convert.

        Returns:
            str: A Python expression string using 'any' and logical operators. Patterns
                with an interior ``*`` call ``_fullmatch``, which ``match`` supplies.
        """
        # Tokens are single parentheses or maximal runs of non-space, non-paren chars.
        tokens: List[str] = re.findall(r"[()]|[^\s()]+", expr)

        def wildcard_condition(pattern: str) -> str:
            pattern = pattern.strip('"').strip("'")
            stripped = pattern.strip("*")

            if "*" in stripped:
                # Interior wildcard: anchored regex with every literal part escaped.
                regex = "(?s)" + ".*".join(re.escape(part) for part in pattern.split("*"))
                return f"_fullmatch({regex!r}, k) is not None"
            if pattern.startswith("*") and pattern.endswith("*"):
                return f"{stripped!r} in k"
            if pattern.startswith("*"):
                return f"k.endswith({stripped!r})"
            if pattern.endswith("*"):
                return f"k.startswith({stripped!r})"
            return f"k == {pattern!r}"

        def convert_token(token: str) -> str:
            operator = self._OPERATORS.get(token.upper())
            if operator is not None:
                return operator
            if token in ("(", ")"):
                return token
            return f"any({wildcard_condition(token)} for k in keywords)"

        return " ".join(convert_token(tok) for tok in tokens)
90
+
91
+
92
def _list_all_hub_models(hub_name: str, sm_client: Any) -> Iterator[HubContent]:
    """
    Retrieve all model entries from the specified hub and yield them one by one.

    Pages through the ``ListHubContents`` API (up to 100 results per request) for
    content of type "Model" and yields each summary as a `HubContent` object.

    Args:
        hub_name (str): The name of the hub to query.
        sm_client (Any): A low-level SageMaker API client exposing ``list_hub_contents``
            (e.g. ``Session.sagemaker_client``); this is not a ``Session`` object.

    Yields:
        HubContent: A `HubContent` object representing a single model entry from the hub.
    """
    params = {"HubName": hub_name, "HubContentType": "Model", "MaxResults": 100}

    while True:
        response = sm_client.list_hub_contents(**params)
        summaries = response.get("HubContentSummaries", [])

        for content in summaries:
            yield HubContent(
                hub_name=hub_name,
                hub_content_arn=content["HubContentArn"],
                hub_content_type="Model",
                hub_content_name=content["HubContentName"],
                hub_content_version=content["HubContentVersion"],
                hub_content_description=content.get("HubContentDescription", ""),
                hub_content_search_keywords=content.get("HubContentSearchKeywords", []),
            )

        next_token = response.get("NextToken")
        # Stop on the last page; an empty page is also treated as the end so a
        # misbehaving token can never cause an endless loop.
        if not next_token or not summaries:
            break
        params["NextToken"] = next_token
135
+
136
+
137
def search_public_hub_models(
    query: str,
    hub_name: Optional[str] = "SageMakerPublicHub",
    sagemaker_session: Optional[Session] = None,
) -> List[HubContent]:
    """
    Search and filter models from hub using a keyword expression.

    Each model's search keywords have spaces replaced by hyphens before matching, so a
    keyword such as "Text Generation" is matched by the pattern `Text-Generation`.

    Args:
        query (str): A logical expression used to filter models by keywords.
            Example: "@task:text-generation AND NOT @framework:legacy"
        hub_name (Optional[str]): The name of the hub to query. Defaults to
            "SageMakerPublicHub"; passing None also selects that hub.
        sagemaker_session (Optional[Session]): An optional SageMaker `Session` object. If not
            provided, a default session will be created and a warning will be logged.

    Returns:
        List[HubContent]: A list of filtered `HubContent` model objects that match the query.
    """
    if hub_name is None:
        # The parameter is Optional; never forward HubName=None to the API.
        hub_name = "SageMakerPublicHub"
    if sagemaker_session is None:
        sagemaker_session = Session()
        logger.warning("SageMaker session not provided. Using default Session.")
    sm_client = sagemaker_session.sagemaker_client

    filt = _Filter(query)
    results: List[HubContent] = []

    for model in _list_all_hub_models(hub_name, sm_client):
        keywords = model.hub_content_search_keywords or []
        normalized_keywords = [kw.replace(" ", "-") for kw in keywords]
        if filt.match(normalized_keywords):
            results.append(model)

    return results
@@ -0,0 +1,81 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """JumpStart serializers module - provides retrieve_default function for backward compatibility."""
14
+ from __future__ import absolute_import
15
+
16
+ from typing import Optional
17
+
18
+ from sagemaker.core.serializers import BaseSerializer
19
+ from sagemaker.core.jumpstart import artifacts, utils as jumpstart_utils
20
+ from sagemaker.core.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
21
+ from sagemaker.core.jumpstart.enums import JumpStartModelType
22
+ from sagemaker.core.helper.session_helper import Session
23
+
24
+
25
def retrieve_default(
    region: Optional[str] = None,
    model_id: Optional[str] = None,
    model_version: Optional[str] = None,
    hub_arn: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
) -> BaseSerializer:
    """Retrieves the default serializer for the model matching the given arguments.

    Validates the JumpStart model inputs and delegates the lookup to
    ``artifacts._retrieve_default_serializer``.

    Args:
        region (Optional[str]): The AWS Region for which to retrieve the default serializer.
            (Default: None).
        model_id (Optional[str]): The model ID of the model for which to
            retrieve the default serializer. (Default: None).
        model_version (Optional[str]): The version of the model for which to retrieve the
            default serializer. (Default: None).
        hub_arn (Optional[str]): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated models should be tolerated
            (exception not raised). False if these models should raise an exception.
            (Default: False).
        model_type (JumpStartModelType): The model type. (Default: JumpStartModelType.OPEN_WEIGHTS).
        sagemaker_session (sagemaker.core.helper.session_helper.Session): A SageMaker Session
            object, used for SageMaker interactions.
            (Default: sagemaker.core.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).

    Returns:
        BaseSerializer: The default serializer to use for the model.

    Raises:
        ValueError: If ``model_id``/``model_version`` are rejected by
            ``jumpstart_utils.is_jumpstart_model_input``, or (presumably, via the
            delegated lookup) if the combination of arguments specified is not supported.
    """
    if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
        raise ValueError(
            "Must specify JumpStart `model_id` and `model_version` when retrieving serializers."
        )

    return artifacts._retrieve_default_serializer(
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        region=region,
        tolerate_deprecated_model=tolerate_deprecated_model,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
        config_name=config_name,
    )