sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/__init__.py +0 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/__init__.py +0 -4
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.47.dist-info/RECORD +0 -35
- sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,540 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module contains functions for obtaining JumpStart predictors."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
from typing import List, Optional, Set, Type
|
|
16
|
+
from sagemaker.core.deserializers import BaseDeserializer
|
|
17
|
+
from sagemaker.core.serializers import BaseSerializer
|
|
18
|
+
from sagemaker.core.jumpstart.constants import (
|
|
19
|
+
ACCEPT_TYPE_TO_DESERIALIZER_TYPE_MAP,
|
|
20
|
+
CONTENT_TYPE_TO_SERIALIZER_TYPE_MAP,
|
|
21
|
+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
|
|
22
|
+
DESERIALIZER_TYPE_TO_CLASS_MAP,
|
|
23
|
+
SERIALIZER_TYPE_TO_CLASS_MAP,
|
|
24
|
+
)
|
|
25
|
+
from sagemaker.core.jumpstart.enums import (
|
|
26
|
+
JumpStartScriptScope,
|
|
27
|
+
MIMEType,
|
|
28
|
+
JumpStartModelType,
|
|
29
|
+
)
|
|
30
|
+
from sagemaker.core.jumpstart.utils import (
|
|
31
|
+
get_region_fallback,
|
|
32
|
+
verify_model_region_and_return_specs,
|
|
33
|
+
)
|
|
34
|
+
from sagemaker.core.helper.session_helper import Session
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _retrieve_serializer_from_content_type(
    content_type: MIMEType,
) -> BaseSerializer:
    """Returns serializer object to use for content type.

    Args:
        content_type (MIMEType): Content type used as the lookup key into
            ``CONTENT_TYPE_TO_SERIALIZER_TYPE_MAP``.

    Returns:
        BaseSerializer: the object produced by calling, with no arguments, the
            handle registered in ``SERIALIZER_TYPE_TO_CLASS_MAP`` for the
            serializer type mapped to ``content_type``.

    Raises:
        RuntimeError: If ``content_type`` has no entry in
            ``CONTENT_TYPE_TO_SERIALIZER_TYPE_MAP``, or if the resulting
            serializer type has no entry in ``SERIALIZER_TYPE_TO_CLASS_MAP``.
    """

    serializer_type = CONTENT_TYPE_TO_SERIALIZER_TYPE_MAP.get(content_type)

    if serializer_type is None:
        raise RuntimeError(f"Unrecognized content type: {content_type}")

    serializer_handle = SERIALIZER_TYPE_TO_CLASS_MAP.get(serializer_type)

    if serializer_handle is None:
        raise RuntimeError(f"Unrecognized serializer type: {serializer_type}")

    # Call the handle directly rather than through the ``__call__`` attribute:
    # attribute lookup on a class could resolve to an instance-level
    # ``__call__`` defined by that class instead of its constructor.
    return serializer_handle()
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _retrieve_deserializer_from_accept_type(
    accept_type: MIMEType,
) -> BaseDeserializer:
    """Returns deserializer object to use for accept type.

    Args:
        accept_type (MIMEType): Accept type used as the lookup key into
            ``ACCEPT_TYPE_TO_DESERIALIZER_TYPE_MAP``.

    Returns:
        BaseDeserializer: the object produced by calling, with no arguments,
            the handle registered in ``DESERIALIZER_TYPE_TO_CLASS_MAP`` for the
            deserializer type mapped to ``accept_type``.

    Raises:
        RuntimeError: If ``accept_type`` has no entry in
            ``ACCEPT_TYPE_TO_DESERIALIZER_TYPE_MAP``, or if the resulting
            deserializer type has no entry in ``DESERIALIZER_TYPE_TO_CLASS_MAP``.
    """

    deserializer_type = ACCEPT_TYPE_TO_DESERIALIZER_TYPE_MAP.get(accept_type)

    if deserializer_type is None:
        raise RuntimeError(f"Unrecognized accept type: {accept_type}")

    deserializer_handle = DESERIALIZER_TYPE_TO_CLASS_MAP.get(deserializer_type)

    if deserializer_handle is None:
        raise RuntimeError(f"Unrecognized deserializer type: {deserializer_type}")

    # Call the handle directly rather than through the ``__call__`` attribute:
    # attribute lookup on a class could resolve to an instance-level
    # ``__call__`` defined by that class instead of its constructor.
    return deserializer_handle()
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _retrieve_default_deserializer(
    model_id: str,
    model_version: str,
    hub_arn: Optional[str],
    region: Optional[str],
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> BaseDeserializer:
    """Builds the deserializer matching the model's default accept type.

    The default accept type is looked up with ``_retrieve_default_accept_type``
    (every argument is forwarded unchanged), converted with
    ``MIMEType.from_suffixed_type`` and then resolved to a deserializer object.

    Args:
        model_id (str): JumpStart model ID of the model whose default
            deserializer is wanted.
        model_version (str): Version of that JumpStart model.
        hub_arn (Optional[str]): The arn of the SageMaker Hub from which to
            retrieve model details.
        region (Optional[str]): Region for which to retrieve the default
            deserializer.
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised).
            (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised).
            (Default: False).
        sagemaker_session (Session): Session object forwarded to the accept
            type lookup. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (JumpStartModelType): Model type forwarded to the accept
            type lookup. (Default: JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the JumpStart Model config to
            apply. (Default: None).

    Returns:
        BaseDeserializer: the default deserializer to use for the model.
    """

    lookup_kwargs = dict(
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
        config_name=config_name,
    )
    accept_type = _retrieve_default_accept_type(**lookup_kwargs)
    mime_type = MIMEType.from_suffixed_type(accept_type)

    return _retrieve_deserializer_from_accept_type(mime_type)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _retrieve_default_serializer(
    model_id: str,
    model_version: str,
    hub_arn: Optional[str],
    region: Optional[str],
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> BaseSerializer:
    """Builds the serializer matching the model's default content type.

    The default content type is looked up with
    ``_retrieve_default_content_type`` (every argument is forwarded unchanged),
    converted with ``MIMEType.from_suffixed_type`` and then resolved to a
    serializer object.

    Args:
        model_id (str): JumpStart model ID of the model whose default
            serializer is wanted.
        model_version (str): Version of that JumpStart model.
        hub_arn (Optional[str]): The arn of the SageMaker Hub from which to
            retrieve model details.
        region (Optional[str]): Region for which to retrieve the default
            serializer.
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised).
            (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised).
            (Default: False).
        sagemaker_session (Session): Session object forwarded to the content
            type lookup. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (JumpStartModelType): Model type forwarded to the content
            type lookup. (Default: JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the JumpStart Model config to
            apply. (Default: None).

    Returns:
        BaseSerializer: the default serializer to use for the model.
    """

    lookup_kwargs = dict(
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
        config_name=config_name,
    )
    content_type = _retrieve_default_content_type(**lookup_kwargs)
    mime_type = MIMEType.from_suffixed_type(content_type)

    return _retrieve_serializer_from_content_type(mime_type)
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _retrieve_deserializer_options(
    model_id: str,
    model_version: str,
    hub_arn: Optional[str],
    region: Optional[str],
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> List[BaseDeserializer]:
    """Lists one deserializer per distinct class supported by the model.

    The supported accept types come from ``_retrieve_supported_accept_types``
    (every argument is forwarded unchanged). Each accept type is resolved to a
    deserializer object; when several accept types yield objects of the same
    class, only the first one is kept, and the order of first appearance is
    preserved.

    Args:
        model_id (str): JumpStart model ID of the model whose supported
            deserializers are wanted.
        model_version (str): Version of that JumpStart model.
        hub_arn (Optional[str]): The arn of the SageMaker Hub from which to
            retrieve model details.
        region (Optional[str]): Region for which to retrieve deserializer
            options.
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised).
            (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised).
            (Default: False).
        sagemaker_session (Session): Session object forwarded to the accept
            type lookup. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (JumpStartModelType): Model type forwarded to the accept
            type lookup. (Default: JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the JumpStart Model config to
            apply. (Default: None).

    Returns:
        List[BaseDeserializer]: the supported deserializers to use for the model.
    """

    accept_types = _retrieve_supported_accept_types(
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
        config_name=config_name,
    )

    # Resolve every accept type first, so an unrecognized type raises before
    # any de-duplication work happens.
    candidates: List[BaseDeserializer] = []
    for supported_type in accept_types:
        mime_type = MIMEType.from_suffixed_type(supported_type)
        candidates.append(_retrieve_deserializer_from_accept_type(mime_type))

    # dicts keep insertion order and setdefault never overwrites, so this
    # retains the first instance seen for each concrete class.
    first_per_class = {}
    for candidate in candidates:
        first_per_class.setdefault(type(candidate), candidate)

    return list(first_per_class.values())
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def _retrieve_serializer_options(
    model_id: str,
    model_version: str,
    hub_arn: Optional[str],
    region: Optional[str],
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> List[BaseSerializer]:
    """Retrieves the supported serializers for the model.

    One serializer is built per supported content type of the model; only the first
    serializer of each class is kept, in the order the content types are listed.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the supported serializers.
        model_version (str): Version of the JumpStart model for which to retrieve the
            supported serializers.
        hub_arn (Optional[str]): The arn of the SageMaker Hub for which to retrieve
            model details from.
        region (Optional[str]): Region for which to retrieve serializer options.
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with
            known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        config_name (Optional[str]): Name of the JumpStart Model config to apply.
            (Default: None).
        model_type (JumpStartModelType): Type of the JumpStart model, forwarded to the
            supported-content-type lookup. Declared last so that existing positional
            callers are unaffected. (Default: JumpStartModelType.OPEN_WEIGHTS).
    Returns:
        List[BaseSerializer]: the supported serializers to use for the model.
    """

    supported_content_types = _retrieve_supported_content_types(
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
        config_name=config_name,
    )

    seen_classes: Set[Type] = set()
    serializers: List[BaseSerializer] = []

    # Several content types can map to the same serializer class; keep only the
    # first instance of each class so the result has no duplicates.
    for content_type in supported_content_types:
        serializer = _retrieve_serializer_from_content_type(
            MIMEType.from_suffixed_type(content_type)
        )
        serializer_class = type(serializer)
        if serializer_class not in seen_classes:
            seen_classes.add(serializer_class)
            serializers.append(serializer)

    return serializers
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def _retrieve_default_content_type(
    model_id: str,
    model_version: str,
    hub_arn: Optional[str],
    region: Optional[str],
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
) -> str:
    """Looks up the default content type in the model's inference predictor specs.

    Args:
        model_id (str): JumpStart model ID of the model whose default content type
            is requested.
        model_version (str): Version of that JumpStart model.
        hub_arn (Optional[str]): The arn of the SageMaker Hub from which to retrieve
            model details.
        region (Optional[str]): Region for which to retrieve the default content type.
            When falsy, the region returned by ``get_region_fallback`` is used.
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, an
            exception is raised when the script used by this version of the model has
            dependencies with known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, an
            exception is raised when the version of the model is deprecated.
            (Default: False).
        model_type (JumpStartModelType): Type of the JumpStart model, forwarded to the
            specs lookup. (Default: JumpStartModelType.OPEN_WEIGHTS).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used
            for SageMaker interactions.
            (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        config_name (Optional[str]): Name of the JumpStart Model config to apply.
            (Default: None).
    Returns:
        str: the default content type to use for the model.
    """

    if not region:
        region = get_region_fallback(sagemaker_session=sagemaker_session)

    specs = verify_model_region_and_return_specs(
        region=region,
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        model_type=model_type,
        config_name=config_name,
        scope=JumpStartScriptScope.INFERENCE,
        sagemaker_session=sagemaker_session,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
    )

    return specs.predictor_specs.default_content_type
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
def _retrieve_default_accept_type(
    model_id: str,
    model_version: str,
    hub_arn: Optional[str],
    region: Optional[str],
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> str:
    """Looks up the default accept type in the model's inference predictor specs.

    Args:
        model_id (str): JumpStart model ID of the model whose default accept type
            is requested.
        model_version (str): Version of that JumpStart model.
        hub_arn (Optional[str]): The arn of the SageMaker Hub from which to retrieve
            model details.
        region (Optional[str]): Region for which to retrieve the default accept type.
            When falsy, the region returned by ``get_region_fallback`` is used.
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, an
            exception is raised when the script used by this version of the model has
            dependencies with known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, an
            exception is raised when the version of the model is deprecated.
            (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used
            for SageMaker interactions.
            (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (JumpStartModelType): Type of the JumpStart model, forwarded to the
            specs lookup. (Default: JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the JumpStart Model config to apply.
            (Default: None).
    Returns:
        str: the default accept type to use for the model.
    """

    if not region:
        region = get_region_fallback(sagemaker_session=sagemaker_session)

    specs = verify_model_region_and_return_specs(
        region=region,
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        model_type=model_type,
        config_name=config_name,
        scope=JumpStartScriptScope.INFERENCE,
        sagemaker_session=sagemaker_session,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
    )

    return specs.predictor_specs.default_accept_type
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
def _retrieve_supported_accept_types(
    model_id: str,
    model_version: str,
    hub_arn: Optional[str],
    region: Optional[str],
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> List[str]:
    """Looks up the supported accept types in the model's inference predictor specs.

    Args:
        model_id (str): JumpStart model ID of the model whose supported accept types
            are requested.
        model_version (str): Version of that JumpStart model.
        hub_arn (Optional[str]): The arn of the SageMaker Hub from which to retrieve
            model details.
        region (Optional[str]): Region for which to retrieve accept type options.
            When falsy, the region returned by ``get_region_fallback`` is used.
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, an
            exception is raised when the script used by this version of the model has
            dependencies with known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, an
            exception is raised when the version of the model is deprecated.
            (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used
            for SageMaker interactions.
            (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (JumpStartModelType): Type of the JumpStart model, forwarded to the
            specs lookup. (Default: JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the JumpStart Model config to apply.
            (Default: None).
    Returns:
        list: the supported accept types to use for the model.
    """

    if not region:
        region = get_region_fallback(sagemaker_session=sagemaker_session)

    specs = verify_model_region_and_return_specs(
        region=region,
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        model_type=model_type,
        config_name=config_name,
        scope=JumpStartScriptScope.INFERENCE,
        sagemaker_session=sagemaker_session,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
    )

    return specs.predictor_specs.supported_accept_types
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
def _retrieve_supported_content_types(
    model_id: str,
    model_version: str,
    hub_arn: Optional[str],
    region: Optional[str],
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> List[str]:
    """Looks up the supported content types in the model's inference predictor specs.

    Args:
        model_id (str): JumpStart model ID of the model whose supported content types
            are requested.
        model_version (str): Version of that JumpStart model.
        hub_arn (Optional[str]): The arn of the SageMaker Hub from which to retrieve
            model details.
        region (Optional[str]): Region for which to retrieve content type options.
            When falsy, the region returned by ``get_region_fallback`` is used.
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, an
            exception is raised when the script used by this version of the model has
            dependencies with known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, an
            exception is raised when the version of the model is deprecated.
            (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used
            for SageMaker interactions.
            (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (JumpStartModelType): Type of the JumpStart model, forwarded to the
            specs lookup. (Default: JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the JumpStart Model config to apply.
            (Default: None).
    Returns:
        list: the supported content types to use for the model.
    """

    if not region:
        region = get_region_fallback(sagemaker_session=sagemaker_session)

    specs = verify_model_region_and_return_specs(
        region=region,
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        model_type=model_type,
        config_name=config_name,
        scope=JumpStartScriptScope.INFERENCE,
        sagemaker_session=sagemaker_session,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
    )

    return specs.predictor_specs.supported_content_types
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module contains functions for obtaining JumpStart resource names."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
from typing import Optional
|
|
16
|
+
from sagemaker.core.jumpstart.constants import (
|
|
17
|
+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
|
|
18
|
+
)
|
|
19
|
+
from sagemaker.core.jumpstart.enums import (
|
|
20
|
+
JumpStartScriptScope,
|
|
21
|
+
JumpStartModelType,
|
|
22
|
+
)
|
|
23
|
+
from sagemaker.core.jumpstart.utils import (
|
|
24
|
+
get_region_fallback,
|
|
25
|
+
verify_model_region_and_return_specs,
|
|
26
|
+
)
|
|
27
|
+
from sagemaker.core.helper.session_helper import Session
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _retrieve_resource_name_base(
    model_id: str,
    model_version: str,
    region: Optional[str],
    hub_arn: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    scope: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
    config_name: Optional[str] = None,
) -> str:
    """Returns default resource name.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            get default resource name.
        model_version (str): Version of the JumpStart model for which to retrieve the
            default resource name.
        region (Optional[str]): Region for which to retrieve the
            default resource name. When falsy, the region returned by
            ``get_region_fallback`` is used.
        hub_arn (Optional[str]): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with
            known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        model_type (JumpStartModelType): Type of the JumpStart model, forwarded to the
            specs lookup. (Default: JumpStartModelType.OPEN_WEIGHTS).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        scope (JumpStartScriptScope): Script scope forwarded to the specs lookup.
            (Default: JumpStartScriptScope.INFERENCE).
        config_name (Optional[str]): Name of the JumpStart Model config. (Default: None).
    Returns:
        str: the default resource name.
    """

    region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )

    model_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=scope,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        model_type=model_type,
        sagemaker_session=sagemaker_session,
        config_name=config_name,
    )

    # The value returned is the resource name base from the specs, not a flag;
    # the return annotation is therefore ``str`` rather than ``bool``.
    return model_specs.resource_name_base
|