sagemaker-core 1.0.62__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (362) hide show
  1. sagemaker/core/__init__.py +16 -0
  2. sagemaker/core/_studio.py +116 -0
  3. sagemaker/core/_version.py +11 -0
  4. sagemaker/core/accept_types.py +131 -0
  5. sagemaker/core/analytics.py +744 -0
  6. sagemaker/core/apiutils/__init__.py +13 -0
  7. sagemaker/core/apiutils/_base_types.py +228 -0
  8. sagemaker/core/apiutils/_boto_functions.py +130 -0
  9. sagemaker/core/apiutils/_utils.py +34 -0
  10. sagemaker/core/base_deserializers.py +35 -0
  11. sagemaker/core/base_serializers.py +35 -0
  12. sagemaker/core/clarify/__init__.py +2898 -0
  13. sagemaker/core/collection.py +467 -0
  14. sagemaker/core/common_utils.py +2281 -0
  15. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  16. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  17. sagemaker/core/config/__init__.py +181 -0
  18. sagemaker/core/config/config.py +238 -0
  19. sagemaker/core/config/config_manager.py +595 -0
  20. sagemaker/core/config/config_schema.py +1220 -0
  21. sagemaker/core/config/config_utils.py +297 -0
  22. {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
  23. sagemaker/core/constants.py +73 -0
  24. sagemaker/core/content_types.py +137 -0
  25. sagemaker/core/debugger/__init__.py +39 -0
  26. sagemaker/core/debugger/debugger.py +945 -0
  27. sagemaker/core/debugger/framework_profile.py +292 -0
  28. sagemaker/core/debugger/metrics_config.py +468 -0
  29. sagemaker/core/debugger/profiler.py +42 -0
  30. sagemaker/core/debugger/profiler_config.py +190 -0
  31. sagemaker/core/debugger/profiler_constants.py +40 -0
  32. sagemaker/core/debugger/utils.py +148 -0
  33. sagemaker/core/deprecations.py +254 -0
  34. sagemaker/core/deserializers/__init__.py +10 -0
  35. sagemaker/core/deserializers/base.py +424 -0
  36. sagemaker/core/deserializers/implementations.py +157 -0
  37. sagemaker/core/drift_check_baselines.py +106 -0
  38. sagemaker/core/enums.py +51 -0
  39. sagemaker/core/environment_variables.py +101 -0
  40. sagemaker/core/exceptions.py +108 -0
  41. sagemaker/core/experiments/__init__.py +53 -0
  42. sagemaker/core/experiments/_api_types.py +251 -0
  43. sagemaker/core/experiments/_environment.py +124 -0
  44. sagemaker/core/experiments/_helper.py +294 -0
  45. sagemaker/core/experiments/_metrics.py +333 -0
  46. sagemaker/core/experiments/_run_context.py +58 -0
  47. sagemaker/core/experiments/_utils.py +216 -0
  48. sagemaker/core/experiments/experiment.py +244 -0
  49. sagemaker/core/experiments/run.py +970 -0
  50. sagemaker/core/experiments/trial.py +296 -0
  51. sagemaker/core/experiments/trial_component.py +387 -0
  52. sagemaker/core/explainer/__init__.py +24 -0
  53. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  54. sagemaker/core/explainer/explainer_config.py +44 -0
  55. sagemaker/core/fw_utils.py +1176 -0
  56. sagemaker/core/git_utils.py +349 -0
  57. sagemaker/core/helper/pipeline_variable.py +82 -0
  58. sagemaker/core/helper/session_helper.py +2965 -0
  59. sagemaker/core/huggingface/__init__.py +29 -0
  60. sagemaker/core/huggingface/llm_utils.py +150 -0
  61. sagemaker/core/huggingface/processing.py +139 -0
  62. sagemaker/core/huggingface/training_compiler/config.py +167 -0
  63. sagemaker/core/hyperparameters.py +172 -0
  64. sagemaker/core/image_retriever/__init__.py +3 -0
  65. sagemaker/core/image_retriever/image_retriever.py +640 -0
  66. sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
  67. sagemaker/core/image_retriever/test.py +7 -0
  68. sagemaker/core/image_uri_config/__init__.py +13 -0
  69. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  70. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  71. sagemaker/core/image_uri_config/chainer.json +104 -0
  72. sagemaker/core/image_uri_config/clarify.json +39 -0
  73. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  74. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  75. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  76. sagemaker/core/image_uri_config/debugger.json +34 -0
  77. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  78. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  79. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  80. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  81. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  82. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  83. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  84. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  85. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
  86. sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
  87. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  88. sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
  89. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  90. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  91. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  92. sagemaker/core/image_uri_config/huggingface.json +2138 -0
  93. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  94. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  95. sagemaker/core/image_uri_config/image-classification.json +50 -0
  96. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  97. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  98. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  99. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  100. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  101. sagemaker/core/image_uri_config/kmeans.json +50 -0
  102. sagemaker/core/image_uri_config/knn.json +50 -0
  103. sagemaker/core/image_uri_config/lda.json +26 -0
  104. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  105. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  106. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  107. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  108. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  109. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  110. sagemaker/core/image_uri_config/ntm.json +50 -0
  111. sagemaker/core/image_uri_config/object-detection.json +50 -0
  112. sagemaker/core/image_uri_config/object2vec.json +50 -0
  113. sagemaker/core/image_uri_config/pca.json +50 -0
  114. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  115. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  116. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  117. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  118. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  119. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  120. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  121. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  122. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  123. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  124. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
  125. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  126. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  127. sagemaker/core/image_uri_config/sklearn.json +446 -0
  128. sagemaker/core/image_uri_config/spark.json +280 -0
  129. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  130. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  131. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  132. sagemaker/core/image_uri_config/vw.json +25 -0
  133. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  134. sagemaker/core/image_uri_config/xgboost.json +888 -0
  135. sagemaker/core/image_uris.py +810 -0
  136. sagemaker/core/inference_config.py +144 -0
  137. sagemaker/core/inference_recommender/__init__.py +18 -0
  138. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  139. sagemaker/core/inputs.py +366 -0
  140. sagemaker/core/instance_group.py +61 -0
  141. sagemaker/core/instance_types.py +164 -0
  142. sagemaker/core/instance_types_gpu_info.py +43 -0
  143. sagemaker/core/interactive_apps/__init__.py +41 -0
  144. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  145. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  146. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  147. sagemaker/core/iterators.py +186 -0
  148. sagemaker/core/job.py +380 -0
  149. sagemaker/core/jumpstart/__init__.py +156 -0
  150. sagemaker/core/jumpstart/accessors.py +390 -0
  151. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  152. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  153. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  154. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  155. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  156. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  157. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  158. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  159. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  160. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  161. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  162. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  163. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  164. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  165. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  166. sagemaker/core/jumpstart/cache.py +663 -0
  167. sagemaker/core/jumpstart/configs.py +50 -0
  168. sagemaker/core/jumpstart/constants.py +198 -0
  169. sagemaker/core/jumpstart/deserializers.py +81 -0
  170. sagemaker/core/jumpstart/document.py +76 -0
  171. sagemaker/core/jumpstart/enums.py +168 -0
  172. sagemaker/core/jumpstart/exceptions.py +236 -0
  173. sagemaker/core/jumpstart/factory/utils.py +833 -0
  174. sagemaker/core/jumpstart/filters.py +597 -0
  175. sagemaker/core/jumpstart/hub/constants.py +16 -0
  176. sagemaker/core/jumpstart/hub/hub.py +291 -0
  177. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  178. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  179. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  180. sagemaker/core/jumpstart/hub/types.py +35 -0
  181. sagemaker/core/jumpstart/hub/utils.py +260 -0
  182. sagemaker/core/jumpstart/models.py +499 -0
  183. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  184. sagemaker/core/jumpstart/parameters.py +20 -0
  185. sagemaker/core/jumpstart/payload_utils.py +239 -0
  186. sagemaker/core/jumpstart/region_config.json +163 -0
  187. sagemaker/core/jumpstart/search.py +171 -0
  188. sagemaker/core/jumpstart/serializers.py +81 -0
  189. sagemaker/core/jumpstart/session_utils.py +234 -0
  190. sagemaker/core/jumpstart/types.py +3044 -0
  191. sagemaker/core/jumpstart/utils.py +1731 -0
  192. sagemaker/core/jumpstart/validators.py +257 -0
  193. sagemaker/core/lambda_helper.py +312 -0
  194. sagemaker/core/lineage/__init__.py +42 -0
  195. sagemaker/core/lineage/_api_types.py +239 -0
  196. sagemaker/core/lineage/_utils.py +49 -0
  197. sagemaker/core/lineage/action.py +345 -0
  198. sagemaker/core/lineage/artifact.py +646 -0
  199. sagemaker/core/lineage/association.py +190 -0
  200. sagemaker/core/lineage/context.py +505 -0
  201. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  202. sagemaker/core/lineage/query.py +732 -0
  203. sagemaker/core/lineage/visualizer.py +346 -0
  204. sagemaker/core/local/__init__.py +18 -0
  205. sagemaker/core/local/data.py +413 -0
  206. sagemaker/core/local/entities.py +678 -0
  207. sagemaker/core/local/exceptions.py +17 -0
  208. sagemaker/core/local/image.py +1243 -0
  209. sagemaker/core/local/local_session.py +739 -0
  210. sagemaker/core/local/utils.py +245 -0
  211. sagemaker/core/logs.py +181 -0
  212. sagemaker/core/metadata_properties.py +56 -0
  213. sagemaker/core/metric_definitions.py +91 -0
  214. sagemaker/core/mlflow/__init__.py +38 -0
  215. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  216. sagemaker/core/model_card/__init__.py +26 -0
  217. sagemaker/core/model_life_cycle.py +51 -0
  218. sagemaker/core/model_metrics.py +160 -0
  219. sagemaker/core/model_monitor/__init__.py +66 -0
  220. sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
  221. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  222. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  223. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  224. sagemaker/core/model_monitor/dataset_format.py +102 -0
  225. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  226. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  227. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  228. sagemaker/core/model_monitor/utils.py +793 -0
  229. sagemaker/core/model_registry.py +480 -0
  230. sagemaker/core/model_uris.py +97 -0
  231. sagemaker/core/modules/__init__.py +19 -0
  232. sagemaker/core/modules/configs.py +226 -0
  233. sagemaker/core/modules/constants.py +37 -0
  234. sagemaker/core/modules/distributed.py +182 -0
  235. sagemaker/core/modules/local_core/__init__.py +0 -0
  236. sagemaker/core/modules/local_core/local_container.py +605 -0
  237. sagemaker/core/modules/templates.py +83 -0
  238. sagemaker/core/modules/train/__init__.py +14 -0
  239. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  240. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  241. sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
  242. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  243. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  244. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  245. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  246. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  247. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  248. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  249. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  250. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  251. sagemaker/core/modules/types.py +19 -0
  252. sagemaker/core/modules/utils.py +194 -0
  253. sagemaker/core/network.py +185 -0
  254. sagemaker/core/parameter.py +173 -0
  255. sagemaker/core/payloads.py +185 -0
  256. sagemaker/core/processing.py +1597 -0
  257. sagemaker/core/remote_function/__init__.py +19 -0
  258. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  259. sagemaker/core/remote_function/client.py +1285 -0
  260. sagemaker/core/remote_function/core/__init__.py +0 -0
  261. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  262. sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
  263. sagemaker/core/remote_function/core/serialization.py +422 -0
  264. sagemaker/core/remote_function/core/stored_function.py +226 -0
  265. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  266. sagemaker/core/remote_function/errors.py +104 -0
  267. sagemaker/core/remote_function/invoke_function.py +172 -0
  268. sagemaker/core/remote_function/job.py +2140 -0
  269. sagemaker/core/remote_function/logging_config.py +38 -0
  270. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  271. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  272. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  273. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  274. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  275. sagemaker/core/remote_function/spark_config.py +149 -0
  276. sagemaker/core/resource_requirements.py +168 -0
  277. {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
  278. sagemaker/core/s3/__init__.py +41 -0
  279. sagemaker/core/s3/client.py +367 -0
  280. sagemaker/core/s3/utils.py +175 -0
  281. sagemaker/core/script_uris.py +93 -0
  282. sagemaker/core/serializers/__init__.py +11 -0
  283. sagemaker/core/serializers/base.py +510 -0
  284. sagemaker/core/serializers/implementations.py +159 -0
  285. sagemaker/core/serializers/utils.py +223 -0
  286. sagemaker/core/serverless_inference_config.py +63 -0
  287. sagemaker/core/session_settings.py +55 -0
  288. sagemaker/core/shapes/__init__.py +3 -0
  289. sagemaker/core/shapes/model_card_shapes.py +159 -0
  290. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
  291. sagemaker/core/spark/__init__.py +16 -0
  292. sagemaker/core/spark/defaults.py +16 -0
  293. sagemaker/core/spark/processing.py +1380 -0
  294. sagemaker/core/telemetry/__init__.py +23 -0
  295. sagemaker/core/telemetry/constants.py +84 -0
  296. sagemaker/core/telemetry/telemetry_logging.py +284 -0
  297. sagemaker/core/tools/__init__.py +1 -0
  298. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  299. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  300. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  301. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  302. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  303. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  304. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  305. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  306. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  307. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  308. sagemaker/core/training/__init__.py +14 -0
  309. sagemaker/core/training/configs.py +333 -0
  310. sagemaker/core/training/constants.py +37 -0
  311. sagemaker/core/training/utils.py +77 -0
  312. sagemaker/core/training_compiler/__init__.py +16 -0
  313. sagemaker/core/training_compiler/config.py +197 -0
  314. sagemaker/core/training_compiler_config.py +197 -0
  315. sagemaker/core/transformer.py +793 -0
  316. sagemaker/core/user_agent.py +76 -0
  317. sagemaker/core/utilities/__init__.py +24 -0
  318. sagemaker/core/utilities/cache.py +169 -0
  319. sagemaker/core/utilities/search_expression.py +133 -0
  320. sagemaker/core/utils/__init__.py +48 -0
  321. sagemaker/core/utils/code_injection/__init__.py +0 -0
  322. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  323. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
  324. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  325. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  326. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  327. {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
  328. sagemaker/core/workflow/__init__.py +152 -0
  329. sagemaker/core/workflow/conditions.py +313 -0
  330. sagemaker/core/workflow/entities.py +58 -0
  331. sagemaker/core/workflow/execution_variables.py +89 -0
  332. sagemaker/core/workflow/functions.py +193 -0
  333. sagemaker/core/workflow/parameters.py +222 -0
  334. sagemaker/core/workflow/pipeline_context.py +394 -0
  335. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  336. sagemaker/core/workflow/properties.py +285 -0
  337. sagemaker/core/workflow/step_outputs.py +65 -0
  338. sagemaker/core/workflow/utilities.py +507 -0
  339. sagemaker/lineage/__init__.py +33 -0
  340. sagemaker/lineage/action.py +28 -0
  341. sagemaker/lineage/artifact.py +28 -0
  342. sagemaker/lineage/context.py +28 -0
  343. sagemaker/lineage/lineage_trial_component.py +28 -0
  344. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
  345. sagemaker_core-2.1.1.dist-info/RECORD +355 -0
  346. sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
  347. sagemaker_core/_version.py +0 -3
  348. sagemaker_core/helper/session_helper.py +0 -769
  349. sagemaker_core/resources/__init__.py +0 -1
  350. sagemaker_core/shapes/__init__.py +0 -1
  351. sagemaker_core/tools/__init__.py +0 -1
  352. sagemaker_core-1.0.62.dist-info/RECORD +0 -35
  353. sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
  354. {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
  355. {sagemaker_core/helper → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
  356. {sagemaker_core/main → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  357. {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
  358. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  359. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  360. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  361. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
  362. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,575 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module stores notebook utils related to SageMaker JumpStart."""
14
+ from __future__ import absolute_import
15
+ import copy
16
+
17
+ from concurrent.futures import ThreadPoolExecutor, as_completed
18
+
19
+ from functools import cmp_to_key
20
+ import json
21
+ import os
22
+ from typing import Any, Generator, List, Optional, Tuple, Union, Set, Dict
23
+ from packaging.version import Version
24
+ from sagemaker.core.jumpstart import accessors
25
+ from sagemaker.core.jumpstart.constants import (
26
+ DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
27
+ PROPRIETARY_MODEL_SPEC_PREFIX,
28
+ )
29
+ from sagemaker.core.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
30
+ from sagemaker.core.jumpstart.filters import (
31
+ SPECIAL_SUPPORTED_FILTER_KEYS,
32
+ ProprietaryModelFilterIdentifiers,
33
+ BooleanValues,
34
+ Identity,
35
+ SpecialSupportedFilterKeys,
36
+ )
37
+ from sagemaker.core.jumpstart.filters import (
38
+ Constant,
39
+ ModelFilter,
40
+ Operator,
41
+ evaluate_filter_expression,
42
+ )
43
+ from sagemaker.core.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
44
+ from sagemaker.core.jumpstart.utils import (
45
+ get_jumpstart_content_bucket,
46
+ get_region_fallback,
47
+ get_sagemaker_version,
48
+ verify_model_region_and_return_specs,
49
+ validate_model_id_and_get_type,
50
+ )
51
+ from sagemaker.core.helper.session_helper import Session
52
+
53
# Worker cap: two per reported CPU (falling back to 4 CPUs when the count is
# unavailable), never more than 32.
# NOTE(review): presumably sizes the ThreadPoolExecutor imported above - confirm at the use site.
MAX_SEARCH_WORKERS = min(32, 2 * (os.cpu_count() or 4))
54
+
55
+
56
+ def _compare_model_version_tuples( # pylint: disable=too-many-return-statements
57
+ model_version_1: Optional[Tuple[str, str]] = None,
58
+ model_version_2: Optional[Tuple[str, str]] = None,
59
+ ) -> int:
60
+ """Performs comparison of sdk specs paths, in order to sort them.
61
+
62
+ Args:
63
+ model_version_1 (Tuple[str, str]): The first model ID and version tuple to compare.
64
+ model_version_2 (Tuple[str, str]): The second model ID and version tuple to compare.
65
+ """
66
+ if model_version_1 is None or model_version_2 is None:
67
+ if model_version_2 is not None:
68
+ return -1
69
+ if model_version_1 is not None:
70
+ return 1
71
+ return 0
72
+
73
+ model_id_1, version_1 = model_version_1
74
+
75
+ model_id_2, version_2 = model_version_2
76
+
77
+ if model_id_1 < model_id_2:
78
+ return -1
79
+
80
+ if model_id_2 < model_id_1:
81
+ return 1
82
+
83
+ if Version(version_1) < Version(version_2):
84
+ return 1
85
+
86
+ if Version(version_2) < Version(version_1):
87
+ return -1
88
+
89
+ return 0
90
+
91
+
92
def _model_filter_in_operator_generator(filter_operator: Operator) -> Generator:
    """Yield the operators in ``filter_operator`` that wrap a ``ModelFilter``.

    Iterates ``filter_operator`` and yields only those items whose
    ``unresolved_value`` is a ``ModelFilter`` instance.

    Args:
        filter_operator (Operator): The operator to iterate over.
    """
    yield from (
        candidate
        for candidate in filter_operator
        if isinstance(candidate.unresolved_value, ModelFilter)
    )
97
+
98
+
99
def _put_resolved_booleans_into_filter(
    filter_operator: Operator, model_filters_to_resolved_values: Dict[ModelFilter, BooleanValues]
) -> None:
    """Write resolved boolean values onto the model-filter operators of a filter.

    Every operator yielded by ``_model_filter_in_operator_generator`` gets its
    ``resolved_value`` set, in place, to the value recorded for its model filter
    in ``model_filters_to_resolved_values``. Filters absent from that mapping are
    resolved to ``BooleanValues.UNKNOWN``.

    Args:
        filter_operator (Operator): The filter whose operators are updated in place.
        model_filters_to_resolved_values (Dict[ModelFilter, BooleanValues]): Known
            resolutions, keyed by model filter.
    """
    for node in _model_filter_in_operator_generator(filter_operator):
        unresolved = node.unresolved_value
        if unresolved in model_filters_to_resolved_values:
            node.resolved_value = model_filters_to_resolved_values[unresolved]
        else:
            node.resolved_value = BooleanValues.UNKNOWN
111
+
112
+
113
def _populate_model_filters_to_resolved_values(
    manifest_specs_cached_values: Dict[str, Any],
    model_filters_to_resolved_values: Dict[ModelFilter, BooleanValues],
    model_filters: Operator,
) -> None:
    """Evaluate every model filter whose key already has a cached value.

    Results are written into ``model_filters_to_resolved_values`` in place.
    Filters whose key is missing from ``manifest_specs_cached_values`` are
    skipped and leave the mapping untouched.

    Args:
        manifest_specs_cached_values (Dict[str, Any]): Cached model values, keyed
            by filter key.
        model_filters_to_resolved_values (Dict[ModelFilter, BooleanValues]): Output
            mapping that receives the evaluated result for each resolvable filter.
        model_filters (Operator): Iterable of the model filters to evaluate.
    """
    for candidate in model_filters:
        if candidate.key not in manifest_specs_cached_values:
            continue
        cached_value = manifest_specs_cached_values[candidate.key]
        outcome: BooleanValues = evaluate_filter_expression(candidate, cached_value)
        model_filters_to_resolved_values[candidate] = outcome
129
+
130
+
131
def extract_framework_task_model(model_id: str) -> Tuple[str, str, str]:
    """Split a model ID into its framework, task and remaining name.

    The ID is treated as ``<framework>-<task>-<name>``, where ``<name>`` may
    itself contain hyphens.

    Args:
        model_id (str): The model ID for which to extract the framework/task/model.

    Returns:
        Tuple[str, str, str]: ``(framework, task, name)``, or three empty strings
        when the ID has fewer than three hyphen-separated parts.
    """
    # Split on the first two hyphens only, so the name keeps any hyphens of its own.
    pieces = model_id.split("-", 2)
    if len(pieces) != 3:
        return "", "", ""

    framework, task, name = pieces
    return framework, task, name
147
+
148
+
149
def extract_model_type_filter_representation(spec_key: str) -> str:
    """Classify a model spec key as proprietary or open weights.

    The portion of ``spec_key`` before the first ``/`` is compared against
    ``PROPRIETARY_MODEL_SPEC_PREFIX``.

    Args:
        spec_key (str): The model spec key for which to extract the model type.

    Returns:
        str: The value of ``JumpStartModelType.PROPRIETARY`` when the prefix
        matches, otherwise the value of ``JumpStartModelType.OPEN_WEIGHTS``.
    """
    prefix, _, _ = spec_key.partition("/")
    is_proprietary = prefix == PROPRIETARY_MODEL_SPEC_PREFIX
    model_type = (
        JumpStartModelType.PROPRIETARY if is_proprietary else JumpStartModelType.OPEN_WEIGHTS
    )
    return model_type.value
161
+
162
+
163
def list_jumpstart_tasks(  # pylint: disable=redefined-builtin
    filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
    region: Optional[str] = None,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> List[str]:
    """List tasks for JumpStart, and optionally apply filters to result.

    Only open-weights models are searched. The task is the second
    hyphen-separated component of each matching model ID.

    Args:
        filter (Union[Operator, str]): Optional. The filter to apply to list tasks. This can be
            either an ``Operator`` type filter (e.g. ``And("task == ic", "framework == pytorch")``),
            or simply a string filter which will get serialized into an Identity filter.
            (e.g. ``"task == ic"``). If this argument is not supplied, all tasks will be listed.
            (Default: Constant(BooleanValues.TRUE)).
        region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
            models. When not supplied, the region is derived from ``sagemaker_session``.
            (Default: None).
        sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to
            use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

    Returns:
        List[str]: The distinct tasks of the matching models, sorted ascending.
    """
    if not region:
        region = get_region_fallback(sagemaker_session=sagemaker_session)

    matching_models = _generate_jumpstart_model_versions(
        filter=filter,
        region=region,
        sagemaker_session=sagemaker_session,
        model_type=JumpStartModelType.OPEN_WEIGHTS,
    )
    # Index 1 of the parsed model ID is the task component.
    unique_tasks: Set[str] = {
        extract_framework_task_model(model_id)[1] for model_id, _ in matching_models
    }
    return sorted(unique_tasks)
195
+
196
+
197
def list_jumpstart_frameworks(  # pylint: disable=redefined-builtin
    filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
    region: Optional[str] = None,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> List[str]:
    """List frameworks for JumpStart, and optionally apply filters to result.

    Only open-weights models are searched. The framework is the first
    hyphen-separated component of each matching model ID.

    Args:
        filter (Union[Operator, str]): Optional. The filter to apply to list frameworks. This can
            be either an ``Operator`` type filter
            (e.g. ``And("task == ic", "framework == pytorch")``), or simply a string filter which
            will get serialized into an Identity filter (e.g. ``"task == ic"``). If this argument
            is not supplied, all frameworks will be listed.
            (Default: Constant(BooleanValues.TRUE)).
        region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
            models. When not supplied, the region is derived from ``sagemaker_session``.
            (Default: None).
        sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session
            to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

    Returns:
        List[str]: The distinct frameworks of the matching models, sorted ascending.
    """
    if not region:
        region = get_region_fallback(sagemaker_session=sagemaker_session)

    matching_models = _generate_jumpstart_model_versions(
        filter=filter,
        region=region,
        sagemaker_session=sagemaker_session,
        model_type=JumpStartModelType.OPEN_WEIGHTS,
    )
    # Index 0 of the parsed model ID is the framework component.
    unique_frameworks: Set[str] = {
        extract_framework_task_model(model_id)[0] for model_id, _ in matching_models
    }
    return sorted(unique_frameworks)
229
+
230
+
231
def list_jumpstart_scripts(  # pylint: disable=redefined-builtin
    filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
    region: Optional[str] = None,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> List[str]:
    """List scripts for JumpStart, and optionally apply filters to result.

    With the match-everything filter, every ``JumpStartScriptScope`` value is
    returned without searching. Otherwise the matching open-weights models are
    walked: inference is recorded for every match, training only when the model
    specs report ``training_supported``, and the walk stops as soon as every
    scope has been seen.

    Args:
        filter (Union[Operator, str]): Optional. The filter to apply to list scripts. This can be
            either an ``Operator`` type filter (e.g. ``And("task == ic", "framework == pytorch")``),
            or simply a string filter which will get serialized into an Identity filter.
            (e.g. ``"task == ic"``). If this argument is not supplied, all scripts will be listed.
            (Default: Constant(BooleanValues.TRUE)).
        region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
            models. When not supplied, the region is derived from ``sagemaker_session``.
            (Default: None).
        sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to
            use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

    Returns:
        List[str]: The script scopes available across the matching models, sorted ascending.
    """
    if not region:
        region = get_region_fallback(sagemaker_session=sagemaker_session)

    every_scope = {scope.value for scope in JumpStartScriptScope}

    is_true_constant = (
        isinstance(filter, Constant) and filter.resolved_value == BooleanValues.TRUE
    )
    is_true_string = isinstance(filter, str) and filter.lower() == BooleanValues.TRUE.lower()
    if is_true_constant or is_true_string:
        # Nothing is filtered out, so skip the per-model specs lookups entirely.
        return sorted(every_scope)

    matching_models = _generate_jumpstart_model_versions(
        filter=filter,
        region=region,
        sagemaker_session=sagemaker_session,
        model_type=JumpStartModelType.OPEN_WEIGHTS,
    )

    found_scopes: Set[str] = set()
    for model_id, version in matching_models:
        found_scopes.add(JumpStartScriptScope.INFERENCE)
        specs = verify_model_region_and_return_specs(
            model_id=model_id,
            version=version,
            scope=JumpStartScriptScope.INFERENCE,
            region=region,
            sagemaker_session=sagemaker_session,
        )
        if specs.training_supported:
            found_scopes.add(JumpStartScriptScope.TRAINING)

        # NOTE(review): enum members are collected but compared against ``.value``s;
        # this early exit relies on members comparing/hashing equal to their values
        # (e.g. a str-mixin enum) - confirm against JumpStartScriptScope.
        if found_scopes == every_scope:
            break

    return sorted(found_scopes)
278
+
279
+
280
def _is_valid_version(version: str) -> bool:
    """Return whether ``version`` can be converted to the ``Version`` class."""
    try:
        Version(version)
    except Exception:  # pylint: disable=broad-except
        # Any failure to construct a Version means the string is not usable as one.
        return False
    return True
287
+
288
+
289
def list_jumpstart_models(  # pylint: disable=redefined-builtin
    filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
    region: Optional[str] = None,
    list_incomplete_models: bool = False,
    list_old_models: bool = False,
    list_versions: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> List[Union[Tuple[str], Tuple[str, str]]]:
    """List JumpStart models, optionally narrowed by a filter.

    Args:
        filter (Union[Operator, str]): Optional. The filter to apply when listing models.
            Either an ``Operator`` (e.g. ``And("task == ic", "framework == pytorch")``) or a
            plain string, which is serialized into an Identity filter (e.g. ``"task == ic"``).
            If omitted, all models are listed. (Default: Constant(BooleanValues.TRUE)).
        region (str): Optional. The AWS region from which to retrieve JumpStart metadata
            regarding models. (Default: None).
        list_incomplete_models (bool): Optional. Whether to include a model that lacks the
            metadata fields requested by the filter, so that the filter cannot be resolved
            to include/not include. Such models are omitted by default. (Default: False).
        list_old_models (bool): Optional. Whether older versions of a model are included in
            the result alongside the newest one. (Default: False).
        list_versions (bool): Optional. True if versions should be returned in addition to
            the model ids. (Default: False).
        sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to
            use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

    Returns:
        The sorted model ids when ``list_versions`` is False; otherwise
        ``(model_id, version)`` tuples ordered by ``_compare_model_version_tuples``.
    """
    if not region:
        region = get_region_fallback(
            sagemaker_session=sagemaker_session,
        )

    # Group every version that was found under its model id.
    versions_by_model: Dict[str, List[str]] = {}
    for model_id, version in _generate_jumpstart_model_versions(
        filter=filter,
        region=region,
        list_incomplete_models=list_incomplete_models,
        sagemaker_session=sagemaker_session,
    ):
        bucket = versions_by_model.setdefault(model_id, [])
        # Parse into a Version when possible so versions compare by version order.
        parsed = Version(version) if _is_valid_version(version) else version
        bucket.append(parsed)

    if not list_versions:
        return sorted(versions_by_model)

    if not list_old_models:
        # Keep only the newest version of each model.
        for model_id, versions in versions_by_model.items():
            try:
                newest = max(versions)
            except TypeError:
                # Parsed and unparsed versions cannot be compared with each other;
                # fall back to comparing their string forms.
                newest = max(str(v) for v in versions)
            versions_by_model[model_id] = {newest}

    id_version_pairs: Set[Tuple[str, str]] = {
        (model_id, str(version))
        for model_id, versions in versions_by_model.items()
        for version in versions
    }

    return sorted(id_version_pairs, key=cmp_to_key(_compare_model_version_tuples))
351
+
352
+
353
def _generate_jumpstart_model_versions(  # pylint: disable=redefined-builtin
    filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
    region: Optional[str] = None,
    list_incomplete_models: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: Optional[JumpStartModelType] = None,
) -> Generator:
    """Generate models for JumpStart, and optionally apply filters to result.

    Each manifest entry is evaluated on a thread pool. The filter is first evaluated with
    values available from the manifest entry alone (plus the task, framework and model type
    derived from it). Only when that does not settle the filter to true/false is the model's
    spec file read from S3 and the filter evaluated a second time.

    Args:
        filter (Union[Operator, str]): Optional. The filter to apply to generate models. This
            can be either an ``Operator`` type filter
            (e.g. ``And("task == ic", "framework == pytorch")``), or simply a string filter
            which will get serialized into an Identity filter (e.g. ``"task == ic"``).
            If this argument is not supplied, all models will be generated.
            (Default: Constant(BooleanValues.TRUE)).
        region (str): Optional. The AWS region from which to retrieve JumpStart metadata
            regarding models. (Default: None).
        list_incomplete_models (bool): Optional. If a model does not contain metadata fields
            requested by the filter, and the filter cannot be resolved to a include/not
            include, whether the model should be included. By default, these models are
            omitted from results. (Default: False).
        sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session
            to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (Optional[JumpStartModelType]): Optional. ``OPEN_WEIGHTS`` or
            ``PROPRIETARY`` restricts the search to that manifest; any other value searches
            both manifests. (Default: None).

    Yields:
        Tuple[str, str]: ``(model_id, version)`` for every model that passes the filter, in
        the order the per-model evaluations complete.

    Raises:
        NotImplementedError: If a filter key contains ``.`` (multi-level metadata indexing).
        RuntimeError: If the filter is still unevaluated after all available values have
            been substituted.
    """

    region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )

    prop_models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(
        region=region,
        s3_client=sagemaker_session.s3_client,
        model_type=JumpStartModelType.PROPRIETARY,
    )
    open_weight_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(
        region=region,
        s3_client=sagemaker_session.s3_client,
        model_type=JumpStartModelType.OPEN_WEIGHTS,
    )
    if model_type == JumpStartModelType.OPEN_WEIGHTS:
        models_manifest_list = open_weight_manifest_list
    elif model_type == JumpStartModelType.PROPRIETARY:
        models_manifest_list = prop_models_manifest_list
    else:
        models_manifest_list = open_weight_manifest_list + prop_models_manifest_list

    if isinstance(filter, str):
        filter = Identity(filter)

    # Attribute names available directly on manifest entries. An empty manifest is
    # tolerated instead of failing with IndexError on its (missing) first entry.
    manifest_keys: Set[str] = set()
    for manifest_list in (open_weight_manifest_list, prop_models_manifest_list):
        if manifest_list:
            manifest_keys.update(manifest_list[0].__slots__)

    all_keys: Set[str] = set()

    model_filters: Set[ModelFilter] = set()

    for operator in _model_filter_in_operator_generator(filter):
        model_filter = operator.unresolved_value
        key = model_filter.key
        all_keys.add(key)
        # A model-type filter whose value is one of the proprietary identifiers is
        # rewritten to the canonical proprietary model type value.
        if model_filter.key == SpecialSupportedFilterKeys.MODEL_TYPE and model_filter.value in {
            identifier.value for identifier in ProprietaryModelFilterIdentifiers
        }:
            model_filter.set_value(JumpStartModelType.PROPRIETARY.value)
        model_filters.add(model_filter)

    for key in all_keys:
        if "." in key:
            raise NotImplementedError(f"No support for multiple level metadata indexing ('{key}').")

    metadata_filter_keys = all_keys - SPECIAL_SUPPORTED_FILTER_KEYS

    # Keys answerable from the manifest entry vs. keys that may only exist in the spec.
    required_manifest_keys = manifest_keys.intersection(metadata_filter_keys)
    possible_spec_keys = metadata_filter_keys - manifest_keys

    is_task_filter = SpecialSupportedFilterKeys.TASK in all_keys
    is_framework_filter = SpecialSupportedFilterKeys.FRAMEWORK in all_keys
    is_model_type_filter = SpecialSupportedFilterKeys.MODEL_TYPE in all_keys

    def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, str]]:
        """Return ``(model_id, version)`` if the model passes the filter, else ``None``."""

        copied_filter = copy.deepcopy(filter)

        manifest_specs_cached_values: Dict[str, Union[bool, int, float, str, dict, list]] = {}

        model_filters_to_resolved_values: Dict[ModelFilter, BooleanValues] = {}

        for val in required_manifest_keys:
            manifest_specs_cached_values[val] = getattr(model_manifest, val)

        if is_task_filter or is_framework_filter:
            # Framework and task are both derived from the model id; parse it once.
            framework, task, _ = extract_framework_task_model(model_manifest.model_id)
            if is_task_filter:
                manifest_specs_cached_values[SpecialSupportedFilterKeys.TASK] = task
            if is_framework_filter:
                manifest_specs_cached_values[SpecialSupportedFilterKeys.FRAMEWORK] = framework

        if is_model_type_filter:
            manifest_specs_cached_values[SpecialSupportedFilterKeys.MODEL_TYPE] = (
                extract_model_type_filter_representation(model_manifest.spec_key)
            )

        # Skip models whose minimum version is above the version reported by
        # get_sagemaker_version().
        if Version(model_manifest.min_version) > Version(get_sagemaker_version()):
            return None

        _populate_model_filters_to_resolved_values(
            manifest_specs_cached_values,
            model_filters_to_resolved_values,
            model_filters,
        )

        _put_resolved_booleans_into_filter(copied_filter, model_filters_to_resolved_values)

        copied_filter.eval()

        # First pass: the manifest values alone may already settle the filter.
        if copied_filter.resolved_value in [BooleanValues.TRUE, BooleanValues.FALSE]:
            if copied_filter.resolved_value == BooleanValues.TRUE:
                return (model_manifest.model_id, model_manifest.version)
            return None

        if copied_filter.resolved_value == BooleanValues.UNEVALUATED:
            raise RuntimeError(
                "Filter expression in unevaluated state after using "
                "values from model manifest. Model ID and version that "
                f"is failing: {(model_manifest.model_id, model_manifest.version)}."
            )
        copied_filter_2 = copy.deepcopy(filter)

        # spec is downloaded to thread's memory. since each thread
        # accesses a unique s3 spec, there is no need to use the JS caching utils.
        # spec only stays in memory for lifecycle of thread.
        model_specs = JumpStartModelSpecs(
            json.loads(
                sagemaker_session.read_s3_file(
                    get_jumpstart_content_bucket(region), model_manifest.spec_key
                )
            )
        )

        for val in possible_spec_keys:
            if hasattr(model_specs, val):
                manifest_specs_cached_values[val] = getattr(model_specs, val)

        _populate_model_filters_to_resolved_values(
            manifest_specs_cached_values,
            model_filters_to_resolved_values,
            model_filters,
        )
        _put_resolved_booleans_into_filter(copied_filter_2, model_filters_to_resolved_values)

        copied_filter_2.eval()

        if copied_filter_2.resolved_value == BooleanValues.UNEVALUATED:
            raise RuntimeError(
                "Filter expression in unevaluated state after using values from model specs. "
                "Model ID and version that is failing: "
                f"{(model_manifest.model_id, model_manifest.version)}."
            )

        # Second pass: include the model when the filter is TRUE, or when it is still
        # UNKNOWN and the caller asked for incomplete models. The resolved value itself
        # must be compared with UNKNOWN: testing the enum member on its own does not
        # look at the result, so a FALSE result could be included as well.
        if copied_filter_2.resolved_value == BooleanValues.TRUE or (
            copied_filter_2.resolved_value == BooleanValues.UNKNOWN and list_incomplete_models
        ):
            return (model_manifest.model_id, model_manifest.version)
        return None

    with ThreadPoolExecutor(max_workers=MAX_SEARCH_WORKERS) as executor:
        futures = [executor.submit(evaluate_model, header) for header in models_manifest_list]

        for future in as_completed(futures):
            # Surface the first worker failure to the consumer of the generator.
            error = future.exception()
            if error:
                raise error
            result = future.result()
            if result:
                yield result
537
+
538
+
539
def get_model_url(
    model_id: str,
    model_version: str,
    region: Optional[str] = None,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
) -> str:
    """Retrieve web url describing pretrained model.

    Args:
        model_id (str): The model ID for which to retrieve the url.
        model_version (str): The model version for which to retrieve the url.
        region (str): Optional. The region from which to retrieve metadata.
            If not supplied, it is resolved with ``get_region_fallback`` from
            ``sagemaker_session``. (Default: None)
        sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use
            to retrieve the model url. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        config_name (Optional[str]): Optional. Forwarded unchanged to
            ``verify_model_region_and_return_specs``. (Default: None).

    Returns:
        str: The ``url`` attribute of the retrieved model specs.
    """
    # Resolve the region before it is first used, so that model validation and the
    # spec lookup below both run against the same region.
    region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )

    model_type = validate_model_id_and_get_type(
        model_id=model_id,
        model_version=model_version,
        region=region,
        sagemaker_session=sagemaker_session,
    )

    model_specs = verify_model_region_and_return_specs(
        region=region,
        model_id=model_id,
        version=model_version,
        sagemaker_session=sagemaker_session,
        scope=JumpStartScriptScope.INFERENCE,
        model_type=model_type,
        config_name=config_name,
    )
    return model_specs.url
@@ -0,0 +1,20 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module stores parameters related to SageMaker JumpStart."""
14
+ from __future__ import absolute_import
15
+ import datetime
16
+
17
# Defaults for the S3 cache: item limit and expiration horizon.
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS = 20
JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON = datetime.timedelta(hours=6)

# Defaults for the semantic-version cache: item limit and expiration horizon.
JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS = 20
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON = datetime.timedelta(hours=6)