sagemaker-core 1.0.62__py3-none-any.whl → 2.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/__init__.py +2 -0
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2399 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +247 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1220 -0
- sagemaker/core/git_utils.py +415 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2977 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +509 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +770 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1267 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +686 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface-vllm-neuronx.json +38 -0
- sagemaker/core/image_uri_config/huggingface.json +2287 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +252 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +494 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +972 -0
- sagemaker/core/image_uris.py +816 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +197 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +501 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +171 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +423 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +246 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1497 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +239 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +205 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1599 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1310 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +347 -0
- sagemaker/core/remote_function/core/serialization.py +410 -0
- sagemaker/core/remote_function/core/stored_function.py +223 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +102 -0
- sagemaker/core/remote_function/invoke_function.py +167 -0
- sagemaker/core/remote_function/job.py +2121 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +82 -0
- sagemaker/core/telemetry/telemetry_logging.py +285 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +345 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +514 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.3.1.dist-info/RECORD +351 -0
- sagemaker_core-2.3.1.dist-info/top_level.txt +1 -0
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.62.dist-info/RECORD +0 -35
- sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
- {sagemaker_core/helper → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/modules/local_core}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module stores inference payload utilities for JumpStart models."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
import base64
|
|
16
|
+
import json
|
|
17
|
+
from typing import Any, Dict, List, Optional, Union
|
|
18
|
+
import re
|
|
19
|
+
import boto3
|
|
20
|
+
|
|
21
|
+
from sagemaker.core.jumpstart.accessors import JumpStartS3PayloadAccessor
|
|
22
|
+
from sagemaker.core.jumpstart.artifacts.payloads import _retrieve_example_payloads
|
|
23
|
+
from sagemaker.core.jumpstart.constants import (
|
|
24
|
+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
|
|
25
|
+
)
|
|
26
|
+
from sagemaker.core.jumpstart.enums import JumpStartModelType, MIMEType
|
|
27
|
+
from sagemaker.core.jumpstart.types import JumpStartSerializablePayload
|
|
28
|
+
from sagemaker.core.jumpstart.utils import (
|
|
29
|
+
get_jumpstart_content_bucket,
|
|
30
|
+
get_region_fallback,
|
|
31
|
+
)
|
|
32
|
+
from sagemaker.core.helper.session_helper import Session
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Anchored pattern: the entire string must be a single ``$s3<key>`` reference.
# The named group ``s3_key`` captures the key, which may contain ASCII letters,
# digits, ``-``, ``_``, ``/`` and ``.``.
S3_BYTES_REGEX = r"^\$s3<(?P<s3_key>[a-zA-Z0-9-_/.]+)>$"
|
|
36
|
+
# Unanchored pattern: matches a ``$s3_b64<key>`` reference anywhere inside a
# larger string, so one string can carry several references. The named group
# ``s3_key`` captures the key (ASCII letters, digits, ``-``, ``_``, ``/``, ``.``).
S3_B64_STR_REGEX = r"\$s3_b64<(?P<s3_key>[a-zA-Z0-9-_/.]+)>"
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _extract_field_from_json(
|
|
40
|
+
json_input: dict,
|
|
41
|
+
keys: List[str],
|
|
42
|
+
) -> Any:
|
|
43
|
+
"""Given a dictionary, returns value at specified keys.
|
|
44
|
+
|
|
45
|
+
Raises:
|
|
46
|
+
KeyError: If a key cannot be found in the json input.
|
|
47
|
+
"""
|
|
48
|
+
curr_json = json_input
|
|
49
|
+
for idx, key in enumerate(keys):
|
|
50
|
+
if idx < len(keys) - 1:
|
|
51
|
+
curr_json = curr_json[key]
|
|
52
|
+
continue
|
|
53
|
+
return curr_json[key]
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _construct_payload(
    prompt: str,
    model_id: str,
    model_version: str,
    region: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    alias: Optional[str] = None,
) -> Optional[JumpStartSerializablePayload]:
    """Build an example payload for a JumpStart model with ``prompt`` embedded.

    The example payloads for the model are looked up, one is selected, and
    ``prompt`` is written into its body at the location named by the payload's
    dotted ``prompt_key`` (for example ``"a.b"`` targets ``body["a"]["b"]``).
    The selected payload's body is modified in place and that same payload
    object is returned.

    Args:
        prompt (str): String-valued prompt to embed in the payload.
        model_id (str): JumpStart model ID for which to construct the payload.
        model_version (str): Version of the JumpStart model for which to
            retrieve the payload.
        region (Optional[str]): Region for which to retrieve the payload.
            (Default: None).
        tolerate_vulnerable_model (bool): Forwarded to the example-payload
            lookup; controls whether vulnerable model specifications are
            tolerated there. (Default: False).
        tolerate_deprecated_model (bool): Forwarded to the example-payload
            lookup; controls whether deprecated model specifications are
            tolerated there. (Default: False).
        sagemaker_session (Session): Session forwarded to the example-payload
            lookup. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (JumpStartModelType): The type of the model, forwarded to
            the example-payload lookup.
            (Default: JumpStartModelType.OPEN_WEIGHTS).
        alias (Optional[str]): Key of the example payload to use. When falsy,
            the first retrieved payload is used. (Default: None).

    Returns:
        Optional[JumpStartSerializablePayload]: The selected payload with the
            prompt embedded, or None if no example payloads are available or
            the selected payload has no prompt key.

    Raises:
        KeyError: If ``alias`` is given but is not among the retrieved
            payloads.
    """
    example_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = (
        _retrieve_example_payloads(
            model_id=model_id,
            model_version=model_version,
            region=region,
            tolerate_vulnerable_model=tolerate_vulnerable_model,
            tolerate_deprecated_model=tolerate_deprecated_model,
            sagemaker_session=sagemaker_session,
            model_type=model_type,
        )
    )
    # Covers both "no payloads at all" (None) and an empty mapping.
    if not example_payloads:
        return None

    selected: JumpStartSerializablePayload
    if alias:
        selected = example_payloads[alias]
    else:
        selected = list(example_payloads.values())[0]

    dotted_key: Optional[str] = selected.prompt_key
    if dotted_key is None:
        return None

    target = selected.body
    # Every segment but the last navigates deeper; the last is the slot that
    # receives the prompt.
    *parent_keys, leaf_key = dotted_key.split(".")
    for parent_key in parent_keys:
        target = target[parent_key]
    target[leaf_key] = prompt

    return selected
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class PayloadSerializer:
    """Serializer for example payloads of JumpStart models.

    Payloads may contain references to S3 objects (images, sounds and other
    content) instead of the content itself. This class resolves those
    references by fetching the objects from S3 and substituting their content.
    """

    def __init__(
        self,
        bucket: Optional[str] = None,
        region: Optional[str] = None,
        s3_client: Optional[boto3.client] = None,
    ) -> None:
        """Store the bucket, region and S3 client used to resolve references.

        Args:
            bucket (Optional[str]): Bucket holding the referenced objects.
                When falsy, ``get_jumpstart_content_bucket()`` is used.
            region (Optional[str]): Region passed to S3 lookups. When falsy,
                ``get_region_fallback(s3_client=s3_client)`` is used.
            s3_client (Optional[boto3.client]): Client passed through to S3
                lookups; stored as given.
        """
        if not bucket:
            bucket = get_jumpstart_content_bucket()
        if not region:
            region = get_region_fallback(s3_client=s3_client)

        self.bucket = bucket
        self.region = region
        self.s3_client = s3_client

    def get_bytes_payload_with_s3_references(
        self,
        payload_str: str,
    ) -> bytes:
        """Fetch the S3 object named by a ``$s3<key>`` payload string.

        Args:
            payload_str (str): String that must consist solely of one
                ``$s3<key>`` reference.

        Returns:
            bytes: Whatever the cached S3 accessor returns for that key.

        Raises:
            ValueError: If the raw bytes payload is not formatted correctly.
        """
        found_keys = re.findall(S3_BYTES_REGEX, payload_str)
        if len(found_keys) != 1:
            raise ValueError("Invalid bytes payload.")

        (s3_key,) = found_keys
        return JumpStartS3PayloadAccessor.get_object_cached(
            bucket=self.bucket,
            key=s3_key,
            region=self.region,
            s3_client=self.s3_client,
        )

    def embed_s3_references_in_str_payload(
        self,
        payload: str,
    ) -> str:
        """Replace S3 references inside a string payload with their content.

        A string without any embedded S3 reference is returned as is.
        """
        return self._embed_s3_b64_references_in_str_payload(payload_body=payload)

    def _embed_s3_b64_references_in_str_payload(
        self,
        payload_body: str,
    ) -> str:
        """Substitute each ``$s3_b64<key>`` reference with base64 text.

        For every referenced key the S3 object is fetched, base64-encoded and
        decoded to a string, which then replaces the reference in the payload.
        """
        referenced_keys = re.findall(S3_B64_STR_REGEX, payload_body)

        result = payload_body
        for s3_key in referenced_keys:
            raw_object = JumpStartS3PayloadAccessor.get_object_cached(
                bucket=self.bucket,
                key=s3_key,
                region=self.region,
                s3_client=self.s3_client,
            )
            encoded_text = base64.b64encode(bytearray(raw_object)).decode()
            result = result.replace(f"$s3_b64<{s3_key}>", encoded_text)
        return result

    def embed_s3_references_in_json_payload(
        self, payload_body: Union[list, dict, str, int, float]
    ) -> Union[list, dict, str, int, float]:
        """Recursively resolve S3 references throughout a JSON-like payload.

        Lists and dicts are rebuilt with every element/value processed,
        strings have their references substituted, and numbers are returned
        unchanged.

        Raises:
            ValueError: If the payload has an unrecognized type.
        """
        if isinstance(payload_body, dict):
            resolved = {}
            for key, value in payload_body.items():
                resolved[key] = self.embed_s3_references_in_json_payload(value)
            return resolved

        if isinstance(payload_body, list):
            return list(map(self.embed_s3_references_in_json_payload, payload_body))

        if isinstance(payload_body, str):
            return self.embed_s3_references_in_str_payload(payload_body)

        if isinstance(payload_body, (int, float)):
            return payload_body

        raise ValueError(f"Payload has unrecognized type: {type(payload_body)}")

    def serialize(self, payload: JumpStartSerializablePayload) -> Union[str, bytes]:
        """Turn a payload into the string or bytes sent to an endpoint.

        JSON, list-text and x-text content types are processed as JSON-like
        bodies; any other content type is treated as a single S3 reference to
        fetch. A resulting dict is dumped to a JSON string.

        Raises:
            ValueError: If the payload has an unrecognized type.
        """
        mime_type = MIMEType.from_suffixed_type(payload.content_type)
        body = payload.body

        json_like_types = {MIMEType.JSON, MIMEType.LIST_TEXT, MIMEType.X_TEXT}
        if mime_type in json_like_types:
            body = self.embed_s3_references_in_json_payload(body)
        else:
            body = self.get_bytes_payload_with_s3_references(body)

        if isinstance(body, dict):
            return json.dumps(body)
        if isinstance(body, (str, bytes)):
            return body

        # Anything else (e.g. a list or number) is rejected here.
        raise ValueError(f"Default payload '{body}' has unrecognized type: {type(body)}")
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
{
|
|
2
|
+
"af-south-1": {
|
|
3
|
+
"content_bucket": "jumpstart-cache-prod-af-south-1",
|
|
4
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-af-south-1"
|
|
5
|
+
},
|
|
6
|
+
"ap-east-1": {
|
|
7
|
+
"content_bucket": "jumpstart-cache-prod-ap-east-1",
|
|
8
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-ap-east-1"
|
|
9
|
+
},
|
|
10
|
+
"ap-east-2": {
|
|
11
|
+
"content_bucket": "jumpstart-cache-prod-ap-east-2",
|
|
12
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-ap-east-2"
|
|
13
|
+
},
|
|
14
|
+
"ap-northeast-1": {
|
|
15
|
+
"content_bucket": "jumpstart-cache-prod-ap-northeast-1",
|
|
16
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-ap-northeast-1",
|
|
17
|
+
"neo_content_bucket": "sagemaker-sd-models-prod-ap-northeast-1"
|
|
18
|
+
},
|
|
19
|
+
"ap-northeast-2": {
|
|
20
|
+
"content_bucket": "jumpstart-cache-prod-ap-northeast-2",
|
|
21
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-ap-northeast-2",
|
|
22
|
+
"neo_content_bucket": "sagemaker-sd-models-prod-ap-northeast-2"
|
|
23
|
+
},
|
|
24
|
+
"ap-northeast-3": {
|
|
25
|
+
"content_bucket": "jumpstart-cache-prod-ap-northeast-3",
|
|
26
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-ap-northeast-3",
|
|
27
|
+
"neo_content_bucket": "sagemaker-sd-models-prod-ap-northeast-3"
|
|
28
|
+
},
|
|
29
|
+
"ap-south-1": {
|
|
30
|
+
"content_bucket": "jumpstart-cache-prod-ap-south-1",
|
|
31
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-ap-south-1",
|
|
32
|
+
"neo_content_bucket": "sagemaker-sd-models-prod-ap-south-1"
|
|
33
|
+
},
|
|
34
|
+
"ap-south-2": {
|
|
35
|
+
"content_bucket": "jumpstart-cache-prod-ap-south-2",
|
|
36
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-ap-south-2"
|
|
37
|
+
},
|
|
38
|
+
"ap-southeast-1": {
|
|
39
|
+
"content_bucket": "jumpstart-cache-prod-ap-southeast-1",
|
|
40
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-ap-southeast-1",
|
|
41
|
+
"neo_content_bucket": "sagemaker-sd-models-prod-ap-southeast-1"
|
|
42
|
+
},
|
|
43
|
+
"ap-southeast-2": {
|
|
44
|
+
"content_bucket": "jumpstart-cache-prod-ap-southeast-2",
|
|
45
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-ap-southeast-2",
|
|
46
|
+
"neo_content_bucket": "sagemaker-sd-models-prod-ap-southeast-2"
|
|
47
|
+
},
|
|
48
|
+
"ap-southeast-3": {
|
|
49
|
+
"content_bucket": "jumpstart-cache-prod-ap-southeast-3",
|
|
50
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-ap-southeast-3"
|
|
51
|
+
},
|
|
52
|
+
"ap-southeast-4": {
|
|
53
|
+
"content_bucket": "jumpstart-cache-prod-ap-southeast-4",
|
|
54
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-ap-southeast-4"
|
|
55
|
+
},
|
|
56
|
+
"ap-southeast-5": {
|
|
57
|
+
"content_bucket": "jumpstart-cache-prod-ap-southeast-5",
|
|
58
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-ap-southeast-5"
|
|
59
|
+
},
|
|
60
|
+
"ap-southeast-6": {
|
|
61
|
+
"content_bucket": "jumpstart-cache-prod-ap-southeast-6",
|
|
62
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-ap-southeast-6"
|
|
63
|
+
},
|
|
64
|
+
"ap-southeast-7": {
|
|
65
|
+
"content_bucket": "jumpstart-cache-prod-ap-southeast-7",
|
|
66
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-ap-southeast-7"
|
|
67
|
+
},
|
|
68
|
+
"ca-central-1": {
|
|
69
|
+
"content_bucket": "jumpstart-cache-prod-ca-central-1",
|
|
70
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-ca-central-1",
|
|
71
|
+
"neo_content_bucket": "sagemaker-sd-models-prod-ca-central-1"
|
|
72
|
+
},
|
|
73
|
+
"ca-west-1": {
|
|
74
|
+
"content_bucket": "jumpstart-cache-prod-ca-west-1",
|
|
75
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-ca-west-1"
|
|
76
|
+
},
|
|
77
|
+
"cn-north-1": {
|
|
78
|
+
"content_bucket": "jumpstart-cache-prod-cn-north-1",
|
|
79
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-cn-north-1"
|
|
80
|
+
},
|
|
81
|
+
"cn-northwest-1": {
|
|
82
|
+
"content_bucket": "jumpstart-cache-prod-cn-northwest-1",
|
|
83
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-cn-northwest-1"
|
|
84
|
+
},
|
|
85
|
+
"eu-central-1": {
|
|
86
|
+
"content_bucket": "jumpstart-cache-prod-eu-central-1",
|
|
87
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-eu-central-1",
|
|
88
|
+
"neo_content_bucket": "sagemaker-sd-models-prod-eu-central-1"
|
|
89
|
+
},
|
|
90
|
+
"eu-central-2": {
|
|
91
|
+
"content_bucket": "jumpstart-cache-prod-eu-central-2",
|
|
92
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-eu-central-2"
|
|
93
|
+
},
|
|
94
|
+
"eu-north-1": {
|
|
95
|
+
"content_bucket": "jumpstart-cache-prod-eu-north-1",
|
|
96
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-eu-north-1",
|
|
97
|
+
"neo_content_bucket": "sagemaker-sd-models-prod-eu-north-1"
|
|
98
|
+
},
|
|
99
|
+
"eu-south-1": {
|
|
100
|
+
"content_bucket": "jumpstart-cache-prod-eu-south-1",
|
|
101
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-eu-south-1"
|
|
102
|
+
},
|
|
103
|
+
"eu-south-2": {
|
|
104
|
+
"content_bucket": "jumpstart-cache-prod-eu-south-2",
|
|
105
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-eu-south-2"
|
|
106
|
+
},
|
|
107
|
+
"eu-west-1": {
|
|
108
|
+
"content_bucket": "jumpstart-cache-prod-eu-west-1",
|
|
109
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-eu-west-1",
|
|
110
|
+
"neo_content_bucket": "sagemaker-sd-models-prod-eu-west-1"
|
|
111
|
+
},
|
|
112
|
+
"eu-west-2": {
|
|
113
|
+
"content_bucket": "jumpstart-cache-prod-eu-west-2",
|
|
114
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-eu-west-2",
|
|
115
|
+
"neo_content_bucket": "sagemaker-sd-models-prod-eu-west-2"
|
|
116
|
+
},
|
|
117
|
+
"eu-west-3": {
|
|
118
|
+
"content_bucket": "jumpstart-cache-prod-eu-west-3",
|
|
119
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-eu-west-3",
|
|
120
|
+
"neo_content_bucket": "sagemaker-sd-models-prod-eu-west-3"
|
|
121
|
+
},
|
|
122
|
+
"il-central-1": {
|
|
123
|
+
"content_bucket": "jumpstart-cache-prod-il-central-1",
|
|
124
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-il-central-1"
|
|
125
|
+
},
|
|
126
|
+
"me-central-1": {
|
|
127
|
+
"content_bucket": "jumpstart-cache-prod-me-central-1",
|
|
128
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-me-central-1"
|
|
129
|
+
},
|
|
130
|
+
"me-south-1": {
|
|
131
|
+
"content_bucket": "jumpstart-cache-prod-me-south-1",
|
|
132
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-me-south-1"
|
|
133
|
+
},
|
|
134
|
+
"mx-central-1": {
|
|
135
|
+
"content_bucket": "jumpstart-cache-prod-mx-central-1",
|
|
136
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-mx-central-1"
|
|
137
|
+
},
|
|
138
|
+
"sa-east-1": {
|
|
139
|
+
"content_bucket": "jumpstart-cache-prod-sa-east-1",
|
|
140
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-sa-east-1",
|
|
141
|
+
"neo_content_bucket": "sagemaker-sd-models-prod-sa-east-1"
|
|
142
|
+
},
|
|
143
|
+
"us-east-1": {
|
|
144
|
+
"content_bucket": "jumpstart-cache-prod-us-east-1",
|
|
145
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-us-east-1",
|
|
146
|
+
"neo_content_bucket": "sagemaker-sd-models-prod-us-east-1"
|
|
147
|
+
},
|
|
148
|
+
"us-east-2": {
|
|
149
|
+
"content_bucket": "jumpstart-cache-prod-us-east-2",
|
|
150
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-us-east-2",
|
|
151
|
+
"neo_content_bucket": "sagemaker-sd-models-prod-us-east-2"
|
|
152
|
+
},
|
|
153
|
+
"us-gov-east-1": {
|
|
154
|
+
"content_bucket": "jumpstart-cache-prod-us-gov-east-1",
|
|
155
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-us-gov-east-1"
|
|
156
|
+
},
|
|
157
|
+
"us-gov-west-1": {
|
|
158
|
+
"content_bucket": "jumpstart-cache-prod-us-gov-west-1",
|
|
159
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-us-gov-west-1"
|
|
160
|
+
},
|
|
161
|
+
"us-west-1": {
|
|
162
|
+
"content_bucket": "jumpstart-cache-prod-us-west-1",
|
|
163
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-us-west-1",
|
|
164
|
+
"neo_content_bucket": "sagemaker-sd-models-prod-us-west-1"
|
|
165
|
+
},
|
|
166
|
+
"us-west-2": {
|
|
167
|
+
"content_bucket": "jumpstart-cache-prod-us-west-2",
|
|
168
|
+
"gated_content_bucket": "jumpstart-private-cache-prod-us-west-2",
|
|
169
|
+
"neo_content_bucket": "sagemaker-sd-models-prod-us-west-2"
|
|
170
|
+
}
|
|
171
|
+
}
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import logging
|
|
3
|
+
from typing import List, Iterator, Optional
|
|
4
|
+
from sagemaker.core.helper.session_helper import Session
|
|
5
|
+
from sagemaker.core.resources import HubContent
|
|
6
|
+
|
|
7
|
+
logger = logging.getLogger(__name__)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class _Filter:
|
|
11
|
+
"""
|
|
12
|
+
A filter that evaluates logical expressions against a list of keyword strings.
|
|
13
|
+
|
|
14
|
+
Supports logical operators (AND, OR, NOT), parentheses for grouping, and wildcard patterns
|
|
15
|
+
(e.g., `text-*`, `*ai`, `@task:foo`).
|
|
16
|
+
|
|
17
|
+
Example:
|
|
18
|
+
filt = _Filter("(@framework:huggingface OR text-*) AND NOT deprecated")
|
|
19
|
+
filt.match(["@framework:huggingface", "text-generation"]) # Returns True
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
def __init__(self, expression: str) -> None:
|
|
23
|
+
"""
|
|
24
|
+
Initialize the filter with a string expression.
|
|
25
|
+
|
|
26
|
+
Args:
|
|
27
|
+
expression (str): A logical expression to evaluate against keywords.
|
|
28
|
+
Supports AND, OR, NOT, parentheses, and wildcard patterns (*).
|
|
29
|
+
"""
|
|
30
|
+
self.expression: str = expression
|
|
31
|
+
|
|
32
|
+
def match(self, keywords: List[str]) -> bool:
|
|
33
|
+
"""
|
|
34
|
+
Evaluate the filter expression against a list of keywords.
|
|
35
|
+
|
|
36
|
+
Args:
|
|
37
|
+
keywords (List[str]): A list of keyword strings to test.
|
|
38
|
+
|
|
39
|
+
Returns:
|
|
40
|
+
bool: True if the expression evaluates to True for the given keywords, else False.
|
|
41
|
+
"""
|
|
42
|
+
expr: str = self._convert_expression(self.expression)
|
|
43
|
+
try:
|
|
44
|
+
return eval(expr, {"__builtins__": {}}, {"keywords": keywords, "any": any})
|
|
45
|
+
except Exception:
|
|
46
|
+
return False
|
|
47
|
+
|
|
48
|
+
def _convert_expression(self, expr: str) -> str:
|
|
49
|
+
"""
|
|
50
|
+
Convert the logical filter expression into a Python-evaluable string.
|
|
51
|
+
|
|
52
|
+
Args:
|
|
53
|
+
expr (str): The raw expression to convert.
|
|
54
|
+
|
|
55
|
+
Returns:
|
|
56
|
+
str: A Python expression string using 'any' and logical operators.
|
|
57
|
+
"""
|
|
58
|
+
tokens: List[str] = re.findall(
|
|
59
|
+
r"\bAND\b|\bOR\b|\bNOT\b|[^\s()]+|\(|\)", expr, flags=re.IGNORECASE
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
def wildcard_condition(pattern: str) -> str:
|
|
63
|
+
pattern = pattern.strip('"').strip("'")
|
|
64
|
+
stripped = pattern.strip("*")
|
|
65
|
+
|
|
66
|
+
if pattern.startswith("*") and pattern.endswith("*"):
|
|
67
|
+
return f"{repr(stripped)} in k"
|
|
68
|
+
elif pattern.startswith("*"):
|
|
69
|
+
return f"k.endswith({repr(stripped)})"
|
|
70
|
+
elif pattern.endswith("*"):
|
|
71
|
+
return f"k.startswith({repr(stripped)})"
|
|
72
|
+
else:
|
|
73
|
+
return f"k == {repr(pattern)}"
|
|
74
|
+
|
|
75
|
+
def convert_token(token: str) -> str:
|
|
76
|
+
upper = token.upper()
|
|
77
|
+
if upper == "AND":
|
|
78
|
+
return "and"
|
|
79
|
+
elif upper == "OR":
|
|
80
|
+
return "or"
|
|
81
|
+
elif upper == "NOT":
|
|
82
|
+
return "not"
|
|
83
|
+
elif token in ("(", ")"):
|
|
84
|
+
return token
|
|
85
|
+
else:
|
|
86
|
+
return f"any({wildcard_condition(token)} for k in keywords)"
|
|
87
|
+
|
|
88
|
+
converted_tokens = [convert_token(tok) for tok in tokens]
|
|
89
|
+
return " ".join(converted_tokens)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _list_all_hub_models(hub_name: str, sm_client: Session) -> Iterator[HubContent]:
    """
    Lazily iterate over every "Model" entry of a hub.

    Pages of up to 100 summaries are requested through ``sm_client.list_hub_contents``
    and each summary is converted into a `HubContent` object. Paging stops when a
    response carries no ``NextToken`` or an empty ``HubContentSummaries`` list.

    Args:
        hub_name (str): The name of the hub to query.
        sm_client (Session): Object exposing ``list_hub_contents``.
            NOTE(review): annotated as ``Session``, but it is called like a SageMaker
            API client here - confirm the intended type.

    Yields:
        HubContent: One object per model summary. Description defaults to an empty
            string and search keywords to an empty list when absent from the summary.
    """
    page_token = None
    has_more_pages = True

    while has_more_pages:
        request = {"HubName": hub_name, "HubContentType": "Model", "MaxResults": 100}
        if page_token:
            # Continue from where the previous page ended.
            request["NextToken"] = page_token

        page = sm_client.list_hub_contents(**request)

        for summary in page["HubContentSummaries"]:
            yield HubContent(
                hub_name=hub_name,
                hub_content_arn=summary["HubContentArn"],
                hub_content_type="Model",
                hub_content_name=summary["HubContentName"],
                hub_content_version=summary["HubContentVersion"],
                hub_content_description=summary.get("HubContentDescription", ""),
                hub_content_search_keywords=summary.get("HubContentSearchKeywords", []),
            )

        page_token = page.get("NextToken", None)
        # An empty page ends the iteration even if a token was returned.
        has_more_pages = bool(page_token) and len(page["HubContentSummaries"]) != 0
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def search_public_hub_models(
    query: str,
    hub_name: Optional[str] = "SageMakerPublicHub",
    sagemaker_session: Optional[Session] = None,
) -> List[HubContent]:
    """
    Return the hub models whose search keywords satisfy a keyword expression.

    Every model of the hub is listed, spaces inside its keywords are replaced with
    hyphens, and the model is kept when the resulting keywords match ``query``.

    Args:
        query (str): A logical expression used to filter models by keywords.
            Example: "@task:text-generation AND NOT @framework:legacy"
        hub_name (Optional[str]): The name of the hub to query. Defaults to "SageMakerPublicHub".
        sagemaker_session (Optional[Session]): SageMaker `Session` whose
            ``sagemaker_client`` performs the listing. When omitted, a default
            `Session()` is created and a warning is logged.

    Returns:
        List[HubContent]: The matching `HubContent` model objects, in listing order.
    """
    if sagemaker_session is None:
        sagemaker_session = Session()
        logger.warning("SageMaker session not provided. Using default Session.")
    client = sagemaker_session.sagemaker_client

    candidates = _list_all_hub_models(hub_name, client)
    keyword_filter = _Filter(query)

    # Keywords are hyphenated before matching because a filter term cannot hold spaces.
    return [
        candidate
        for candidate in candidates
        if keyword_filter.match(
            [keyword.replace(" ", "-") for keyword in candidate.hub_content_search_keywords]
        )
    ]
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""JumpStart serializers module - provides retrieve_default function for backward compatibility."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
from typing import Optional
|
|
17
|
+
|
|
18
|
+
from sagemaker.core.serializers import BaseSerializer
|
|
19
|
+
from sagemaker.core.jumpstart import artifacts, utils as jumpstart_utils
|
|
20
|
+
from sagemaker.core.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
|
|
21
|
+
from sagemaker.core.jumpstart.enums import JumpStartModelType
|
|
22
|
+
from sagemaker.core.helper.session_helper import Session
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def retrieve_default(
    region: Optional[str] = None,
    model_id: Optional[str] = None,
    model_version: Optional[str] = None,
    hub_arn: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
) -> BaseSerializer:
    """Look up the default serializer of a JumpStart model.

    The model ID and version are validated with
    ``jumpstart_utils.is_jumpstart_model_input``; all arguments are then
    forwarded unchanged to ``artifacts._retrieve_default_serializer``.

    Args:
        region (str): The AWS Region for which to retrieve the default
            serializer. (Default: None).
        model_id (str): The ID of the model whose default serializer is
            wanted. (Default: None).
        model_version (str): The version of that model. (Default: None).
        hub_arn (str): The ARN of the SageMaker Hub to retrieve model details
            from. (Default: None).
        tolerate_vulnerable_model (bool): Whether vulnerable model
            specifications are tolerated rather than raising an exception.
            (Default: False).
        tolerate_deprecated_model (bool): Whether deprecated models are
            tolerated rather than raising an exception. (Default: False).
        model_type (JumpStartModelType): The model type.
            (Default: JumpStartModelType.OPEN_WEIGHTS).
        sagemaker_session (Session): The SageMaker Session used for SageMaker
            interactions. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        config_name (Optional[str]): Name of the JumpStart Model config to
            apply. (Default: None).

    Returns:
        BaseSerializer: The default serializer to use for the model.

    Raises:
        ValueError: If ``model_id`` and ``model_version`` do not form a valid
            JumpStart model input.
    """
    is_jumpstart_input = jumpstart_utils.is_jumpstart_model_input(model_id, model_version)
    if not is_jumpstart_input:
        raise ValueError(
            "Must specify JumpStart `model_id` and `model_version` when retrieving serializers."
        )

    default_serializer = artifacts._retrieve_default_serializer(
        region=region,
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        model_type=model_type,
        sagemaker_session=sagemaker_session,
        config_name=config_name,
    )
    return default_serializer
|