sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/__init__.py +0 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/__init__.py +0 -4
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.47.dist-info/RECORD +0 -35
- sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module contains functions for obtaining JumpStart model packages."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
from typing import Optional
|
|
16
|
+
from sagemaker.core.jumpstart.constants import (
|
|
17
|
+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
|
|
18
|
+
)
|
|
19
|
+
from sagemaker.core.jumpstart.utils import (
|
|
20
|
+
get_region_fallback,
|
|
21
|
+
verify_model_region_and_return_specs,
|
|
22
|
+
)
|
|
23
|
+
from sagemaker.core.jumpstart.enums import (
|
|
24
|
+
JumpStartScriptScope,
|
|
25
|
+
JumpStartModelType,
|
|
26
|
+
)
|
|
27
|
+
from sagemaker.core.helper.session_helper import Session
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _retrieve_model_package_arn(
    model_id: str,
    model_version: str,
    instance_type: Optional[str],
    region: Optional[str],
    hub_arn: Optional[str] = None,
    scope: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> Optional[str]:
    """Look up the model package arn associated with a JumpStart model.

    An instance-type-specific arn is preferred when the model specs provide
    one; otherwise the region-keyed ``hosting_model_package_arns`` mapping is
    consulted.

    Args:
        model_id (str): JumpStart model ID of the model whose model package
            arn should be retrieved.
        model_version (str): Version of the JumpStart model whose model package
            arn should be retrieved.
        instance_type (Optional[str]): Instance type used to look up an
            instance-specific arn.
        region (Optional[str]): Region for which to retrieve the model package arn.
            When falsy, the region is obtained from ``get_region_fallback``.
        hub_arn (str): The arn of the SageMaker Hub from which to retrieve
            model details. (Default: None).
        scope (Optional[str]): Scope for which to retrieve the model package arn.
            Only ``JumpStartScriptScope.INFERENCE`` is supported. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False,
            an exception is raised if the script used by this version of the model
            has dependencies with known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False,
            an exception is raised if the version of the model is deprecated.
            (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions.
            (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (JumpStartModelType): Type of the model, forwarded to the
            model specs lookup. (Default: JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the JumpStart Model config to apply.
            (Default: None).

    Returns:
        str: the model package arn to use for the model, or None when the model
            specs define no model package arns.

    Raises:
        NotImplementedError: If ``scope`` is not the inference scope.
        ValueError: If model package arns exist but none is defined for ``region``.
    """
    if not region:
        region = get_region_fallback(sagemaker_session=sagemaker_session)

    # The specs are fetched (and validated) before the scope check, so any
    # error raised by the verification takes precedence over an unsupported scope.
    model_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=scope,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
        config_name=config_name,
    )

    if scope != JumpStartScriptScope.INFERENCE:
        raise NotImplementedError(f"Model Package ARN not supported for scope: '{scope}'")

    # Prefer an arn tied to the requested instance type when variants exist.
    variants = getattr(model_specs, "hosting_instance_type_variants", None)
    if variants is not None:
        instance_specific_arn: Optional[str] = variants.get_model_package_arn(
            region=region, instance_type=instance_type
        )
        if instance_specific_arn is not None:
            return instance_specific_arn

    package_arns = model_specs.hosting_model_package_arns
    if package_arns is None or package_arns == {}:
        return None

    regional_arn = package_arns.get(region)
    if regional_arn is not None:
        return regional_arn

    supported_regions = ", ".join(package_arns.keys())
    raise ValueError(
        f"Model package arn for '{model_id}' not supported in {region}. "
        f"Please try one of the following regions: {supported_regions}."
    )
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _retrieve_model_package_model_artifact_s3_uri(
    model_id: str,
    model_version: str,
    region: Optional[str],
    hub_arn: Optional[str] = None,
    scope: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> Optional[str]:
    """Look up the s3 artifact uri associated with a model package.

    Args:
        model_id (str): JumpStart model ID of the model whose model package
            artifact should be retrieved.
        model_version (str): Version of the JumpStart model whose model package
            artifact should be retrieved.
        region (Optional[str]): Region for which to retrieve the model package
            artifact. When falsy, the region is obtained from
            ``get_region_fallback``.
        hub_arn (str): The arn of the SageMaker Hub from which to retrieve
            model details. (Default: None).
        scope (Optional[str]): Scope for which to retrieve the model package
            artifact. Only ``JumpStartScriptScope.TRAINING`` is supported.
            (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False,
            an exception is raised if the script used by this version of the model
            has dependencies with known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False,
            an exception is raised if the version of the model is deprecated.
            (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions.
            (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        config_name (Optional[str]): Name of the JumpStart Model config to apply.
            (Default: None).
        model_type (JumpStartModelType): The type of the model, can be open weights
            model or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).

    Returns:
        str: the model package artifact uri to use for the model, or None when
            the model specs define no training model package artifact uris.

    Raises:
        NotImplementedError: If ``scope`` is not the training scope.
        ValueError: If artifact uris exist but none is defined for ``region``.
    """
    # Nothing happens before the scope check, so rejecting early is equivalent.
    if scope != JumpStartScriptScope.TRAINING:
        raise NotImplementedError(f"Model Package Artifact URI not supported for scope: '{scope}'")

    if not region:
        region = get_region_fallback(sagemaker_session=sagemaker_session)

    model_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=scope,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        config_name=config_name,
        model_type=model_type,
    )

    artifact_uris = model_specs.training_model_package_artifact_uris
    if artifact_uris is None:
        return None

    model_s3_uri = artifact_uris.get(region)
    if model_s3_uri is not None:
        return model_s3_uri

    supported_regions = ", ".join(artifact_uris.keys())
    raise ValueError(
        f"Model package artifact s3 uri for '{model_id}' not supported in {region}. "
        f"Please try one of the following regions: {supported_regions}."
    )
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module contains functions for obtaining JumpStart model uris."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
import os
|
|
16
|
+
from typing import Optional
|
|
17
|
+
|
|
18
|
+
from sagemaker.core.jumpstart.constants import (
|
|
19
|
+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
|
|
20
|
+
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE,
|
|
21
|
+
)
|
|
22
|
+
from sagemaker.core.jumpstart.enums import (
|
|
23
|
+
JumpStartModelType,
|
|
24
|
+
JumpStartScriptScope,
|
|
25
|
+
)
|
|
26
|
+
from sagemaker.core.jumpstart.utils import (
|
|
27
|
+
get_jumpstart_content_bucket,
|
|
28
|
+
get_jumpstart_gated_content_bucket,
|
|
29
|
+
get_region_fallback,
|
|
30
|
+
verify_model_region_and_return_specs,
|
|
31
|
+
)
|
|
32
|
+
from sagemaker.core.s3.utils import is_s3_url
|
|
33
|
+
from sagemaker.core.helper.session_helper import Session
|
|
34
|
+
from sagemaker.core.jumpstart.types import JumpStartModelSpecs
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _retrieve_hosting_prepacked_artifact_key(
    model_specs: JumpStartModelSpecs, instance_type: str
) -> str:
    """Resolve the prepacked hosting artifact key for a model.

    The instance-type-specific key is preferred. It is only looked up when an
    instance type was given and the specs carry hosting instance type variants.
    Otherwise, or when that lookup yields nothing, the spec-level
    ``hosting_prepacked_artifact_key`` is returned.
    """
    specific_key: Optional[str] = None
    if instance_type:
        variants = getattr(model_specs, "hosting_instance_type_variants", None)
        if variants is not None:
            specific_key = variants.get_instance_specific_prepacked_artifact_key(
                instance_type=instance_type
            )

    # The fallback is always read, even when an instance specific key exists.
    fallback_key: Optional[str] = model_specs.hosting_prepacked_artifact_key

    return specific_key if specific_key else fallback_key
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _retrieve_hosting_artifact_key(model_specs: JumpStartModelSpecs, instance_type: str) -> str:
    """Resolve the hosting artifact key for a model.

    The instance-type-specific key is preferred. It is only looked up when an
    instance type was given and the specs carry hosting instance type variants.
    Otherwise, or when that lookup yields nothing, the spec-level
    ``hosting_artifact_key`` is returned.
    """
    specific_key: Optional[str] = None
    if instance_type:
        variants = getattr(model_specs, "hosting_instance_type_variants", None)
        if variants is not None:
            specific_key = variants.get_instance_specific_artifact_key(
                instance_type=instance_type
            )

    # The fallback is always read, even when an instance specific key exists.
    fallback_key: str = model_specs.hosting_artifact_key

    return specific_key if specific_key else fallback_key
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _retrieve_training_artifact_key(model_specs: JumpStartModelSpecs, instance_type: str) -> str:
    """Resolve the training artifact key for a model.

    The instance-type-specific key is preferred. It is only looked up when an
    instance type was given and the specs carry training instance type variants.
    Otherwise, or when that lookup yields nothing, the spec-level
    ``training_artifact_key`` is returned.
    """
    specific_key: Optional[str] = None
    if instance_type:
        variants = getattr(model_specs, "training_instance_type_variants", None)
        if variants is not None:
            specific_key = variants.get_instance_specific_training_artifact_key(
                instance_type=instance_type
            )

    # The fallback is always read, even when an instance specific key exists.
    fallback_key: str = model_specs.training_artifact_key

    return specific_key if specific_key else fallback_key
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _retrieve_model_uri(
    model_id: str,
    model_version: str,
    hub_arn: Optional[str] = None,
    model_scope: Optional[str] = None,
    instance_type: Optional[str] = None,
    region: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
):
    """Retrieves the model artifact S3 URI for the model matching the given arguments.

    Optionally uses a bucket override specified by environment variable.

    For the inference scope with ``hub_arn`` set, the ``hosting_artifact_uri``
    of the model specs is returned directly. Otherwise an artifact key is
    resolved for the scope and instance type. A key that is already a full
    S3 URI is returned unchanged; any other key is prefixed with the
    JumpStart content bucket (or the override bucket).

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to retrieve
            the model artifact S3 URI.
        model_version (str): Version of the JumpStart model for which to retrieve the model
            artifact S3 URI.
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        model_scope (str): The model type, i.e. what it is used for.
            Valid values: "training" and "inference".
        instance_type (str): The ML compute instance type for the specified scope. (Default: None).
        region (str): Region for which to retrieve model S3 URI. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
        model_type (JumpStartModelType): The type of the model, can be open weights model
            or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).

    Returns:
        str: the model artifact S3 URI for the corresponding model.

    Raises:
        ValueError: If the combination of arguments specified is not supported,
            including a ``model_scope`` that is neither "training" nor "inference".
        VulnerableJumpStartModelError: If any of the dependencies required by the script have
            known security vulnerabilities.
        DeprecatedJumpStartModelError: If the version of the model is deprecated.
    """
    region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )

    model_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=model_scope,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        config_name=config_name,
        model_type=model_type,
    )

    model_artifact_key: str

    if model_scope == JumpStartScriptScope.INFERENCE:

        is_prepacked = not model_specs.use_inference_script_uri()

        if hub_arn:
            # Hub-sourced specs carry the artifact URI directly.
            return model_specs.hosting_artifact_uri

        model_artifact_key = (
            _retrieve_hosting_prepacked_artifact_key(model_specs, instance_type)
            if is_prepacked
            else _retrieve_hosting_artifact_key(model_specs, instance_type)
        )

    elif model_scope == JumpStartScriptScope.TRAINING:

        model_artifact_key = _retrieve_training_artifact_key(model_specs, instance_type)

    else:
        # Fail with a clear error instead of an unbound local further down.
        raise ValueError(f"Model artifact URI not supported for scope: '{model_scope}'")

    # A key that is already a full S3 URI must not be prefixed with a bucket.
    # Previously this path left the result variable unassigned.
    if is_s3_url(model_artifact_key):
        return model_artifact_key

    default_jumpstart_bucket: str = (
        get_jumpstart_gated_content_bucket(region)
        if model_specs.gated_bucket
        else get_jumpstart_content_bucket(region)
    )

    # The environment override, when set and non-empty, wins over the default bucket.
    bucket = (
        os.environ.get(ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE)
        or default_jumpstart_bucket
    )

    return f"s3://{bucket}/{model_artifact_key}"
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _model_supports_training_model_uri(
    model_id: str,
    model_version: str,
    region: Optional[str],
    hub_arn: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> bool:
    """Reports whether the model supports training with the model uri field.

    The model specs are fetched for the training scope and the answer is
    delegated to their ``use_training_model_artifact()`` method.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the support status for model uri with training.
        model_version (str): Version of the JumpStart model for which to retrieve the
            support status for model uri with training.
        region (Optional[str]): Region for which to retrieve the
            support status for model uri with training. When empty, the region
            is taken from ``get_region_fallback`` for the given session.
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
        model_type (JumpStartModelType): The type of the model, can be open weights model
            or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
    Returns:
        bool: the support status for model uri with training.
    """
    resolved_region = (
        region if region else get_region_fallback(sagemaker_session=sagemaker_session)
    )

    training_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=JumpStartScriptScope.TRAINING,
        region=resolved_region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        config_name=config_name,
        model_type=model_type,
    )

    supports_model_uri = training_specs.use_training_model_artifact()
    return supports_model_uri
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module contains functions to obtain JumpStart model payloads."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
from copy import deepcopy
|
|
16
|
+
from typing import Dict, Optional
|
|
17
|
+
from sagemaker.core.jumpstart.constants import (
|
|
18
|
+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
|
|
19
|
+
)
|
|
20
|
+
from sagemaker.core.jumpstart.enums import (
|
|
21
|
+
JumpStartScriptScope,
|
|
22
|
+
JumpStartModelType,
|
|
23
|
+
)
|
|
24
|
+
from sagemaker.core.jumpstart.types import JumpStartSerializablePayload
|
|
25
|
+
from sagemaker.core.jumpstart.utils import (
|
|
26
|
+
get_region_fallback,
|
|
27
|
+
verify_model_region_and_return_specs,
|
|
28
|
+
)
|
|
29
|
+
from sagemaker.core.helper.session_helper import Session
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _retrieve_example_payloads(
    model_id: str,
    model_version: str,
    region: Optional[str],
    hub_arn: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> Optional[Dict[str, JumpStartSerializablePayload]]:
    """Returns a deep copy of the example payloads of a model, if it has any.

    The model specs are fetched for the inference scope. Each payload has its
    ``accept`` attribute assigned, falling back to the predictor specs' default
    accept type when the payload has no such attribute.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            get example payloads.
        model_version (str): Version of the JumpStart model for which to retrieve the
            example payloads.
        region (Optional[str]): Region for which to retrieve the
            example payloads. When empty, the region is taken from
            ``get_region_fallback`` for the given session.
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (JumpStartModelType): The type of the model, can be open weights model
            or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
    Returns:
        Optional[Dict[str, JumpStartSerializablePayload]]: dictionary mapping payload aliases
            to the serializable payload object, or None when the model specs
            have no default payloads.
    """
    resolved_region = (
        region if region else get_region_fallback(sagemaker_session=sagemaker_session)
    )

    inference_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=JumpStartScriptScope.INFERENCE,
        region=resolved_region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
        config_name=config_name,
    )

    example_payloads = inference_specs.default_payloads
    if not example_payloads:
        return None

    # The accept type is set on the spec's own payload objects before copying,
    # so the returned copy and the specs both carry the resolved value.
    for example in example_payloads.values():
        example.accept = getattr(
            example, "accept", inference_specs.predictor_specs.default_accept_type
        )

    return deepcopy(example_payloads)
|