sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/__init__.py +0 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/__init__.py +0 -4
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.47.dist-info/RECORD +0 -35
- sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,833 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module stores JumpStart factory utilities."""
|
|
14
|
+
|
|
15
|
+
from __future__ import absolute_import
|
|
16
|
+
import json
|
|
17
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
18
|
+
from sagemaker.core.shapes import ModelAccessConfig
|
|
19
|
+
from sagemaker.core import (
|
|
20
|
+
environment_variables,
|
|
21
|
+
image_uris,
|
|
22
|
+
instance_types,
|
|
23
|
+
model_uris,
|
|
24
|
+
script_uris,
|
|
25
|
+
)
|
|
26
|
+
from sagemaker.serve.async_inference.async_inference_config import AsyncInferenceConfig
|
|
27
|
+
from sagemaker.core.deserializers.base import BaseDeserializer
|
|
28
|
+
from sagemaker.core.serializers.base import BaseSerializer
|
|
29
|
+
from sagemaker.core.explainer.explainer_config import ExplainerConfig
|
|
30
|
+
from sagemaker.core.jumpstart.artifacts import (
|
|
31
|
+
_model_supports_inference_script_uri,
|
|
32
|
+
_retrieve_model_init_kwargs,
|
|
33
|
+
_retrieve_model_deploy_kwargs,
|
|
34
|
+
_retrieve_model_package_arn,
|
|
35
|
+
)
|
|
36
|
+
from sagemaker.core.jumpstart.artifacts.resource_names import _retrieve_resource_name_base
|
|
37
|
+
from sagemaker.core.jumpstart.constants import (
|
|
38
|
+
INFERENCE_ENTRY_POINT_SCRIPT_NAME,
|
|
39
|
+
JUMPSTART_DEFAULT_REGION_NAME,
|
|
40
|
+
JUMPSTART_LOGGER,
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
from sagemaker.core.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
|
|
44
|
+
from sagemaker.core.jumpstart.hub.utils import (
|
|
45
|
+
construct_hub_model_arn_from_inputs,
|
|
46
|
+
construct_hub_model_reference_arn_from_inputs,
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
from sagemaker.core.jumpstart.enums import (
|
|
50
|
+
JumpStartScriptScope,
|
|
51
|
+
JumpStartModelType,
|
|
52
|
+
HubContentCapability,
|
|
53
|
+
)
|
|
54
|
+
from sagemaker.core.jumpstart.types import (
|
|
55
|
+
HubContentType,
|
|
56
|
+
JumpStartEstimatorDeployKwargs,
|
|
57
|
+
JumpStartEstimatorFitKwargs,
|
|
58
|
+
JumpStartEstimatorInitKwargs,
|
|
59
|
+
JumpStartModelDeployKwargs,
|
|
60
|
+
JumpStartModelInitKwargs,
|
|
61
|
+
JumpStartModelSpecs,
|
|
62
|
+
)
|
|
63
|
+
from sagemaker.core.jumpstart.utils import (
|
|
64
|
+
add_hub_content_arn_tags,
|
|
65
|
+
add_jumpstart_model_info_tags,
|
|
66
|
+
add_bedrock_store_tags,
|
|
67
|
+
get_default_jumpstart_session_with_user_agent_suffix,
|
|
68
|
+
get_top_ranked_config_name,
|
|
69
|
+
update_dict_if_key_not_present,
|
|
70
|
+
resolve_model_sagemaker_config_field,
|
|
71
|
+
verify_model_region_and_return_specs,
|
|
72
|
+
get_draft_model_content_bucket,
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
from sagemaker.core.model_monitor.data_capture_config import DataCaptureConfig
|
|
76
|
+
|
|
77
|
+
from sagemaker.serve.serverless.serverless_inference_config import ServerlessInferenceConfig
|
|
78
|
+
from sagemaker.core.helper.session_helper import Session
|
|
79
|
+
from sagemaker.core.common_utils import (
|
|
80
|
+
camel_case_to_pascal_case,
|
|
81
|
+
name_from_base,
|
|
82
|
+
format_tags,
|
|
83
|
+
Tags,
|
|
84
|
+
)
|
|
85
|
+
from sagemaker.core.helper.pipeline_variable import PipelineVariable
|
|
86
|
+
from sagemaker.serve.compute_resource_requirements.resource_requirements import ResourceRequirements
|
|
87
|
+
from sagemaker.core import resource_requirements
|
|
88
|
+
from sagemaker.core.enums import EndpointType
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
# Any of the JumpStart kwargs containers handled by the generic factory helpers
# (e.g. ``get_model_info_default_kwargs``). Order matches the historical definition.
KwargsType = Union[
    JumpStartModelDeployKwargs,  # Model.deploy
    JumpStartModelInitKwargs,  # Model.__init__
    JumpStartEstimatorFitKwargs,  # Estimator.fit
    JumpStartEstimatorInitKwargs,  # Estimator.__init__
    JumpStartEstimatorDeployKwargs,  # Estimator.deploy
]
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def get_model_info_default_kwargs(
    kwargs: KwargsType,
    include_config_name: bool = True,
    include_model_version: bool = True,
    include_tolerate_flags: bool = True,
) -> dict:
    """Build the model-identification keyword arguments shared by JumpStart retrieval APIs.

    Args:
        kwargs: JumpStart kwargs object. Its ``model_id``, ``hub_arn``, ``region``,
            ``sagemaker_session`` and ``model_type`` attributes are always copied.
        include_config_name: If true, also copy ``config_name``.
        include_model_version: If true, also copy ``model_version``.
        include_tolerate_flags: If true, also copy ``tolerate_deprecated_model`` and
            ``tolerate_vulnerable_model``.

    Returns:
        A new dict mapping argument names to the values read from ``kwargs``.
    """
    model_info = dict(
        model_id=kwargs.model_id,
        hub_arn=kwargs.hub_arn,
        region=kwargs.region,
        sagemaker_session=kwargs.sagemaker_session,
        model_type=kwargs.model_type,
    )

    if include_config_name:
        model_info["config_name"] = kwargs.config_name

    if include_model_version:
        model_info["model_version"] = kwargs.model_version

    if include_tolerate_flags:
        model_info["tolerate_deprecated_model"] = kwargs.tolerate_deprecated_model
        model_info["tolerate_vulnerable_model"] = kwargs.tolerate_vulnerable_model

    return model_info
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _set_temp_sagemaker_session_if_not_set(kwargs: KwargsType) -> Tuple[KwargsType, Session]:
    """Install a placeholder JumpStart session when none was supplied.

    A default JumpStart session (without a custom user agent) is needed so that
    config name information can be retrieved. The session held on entry is
    returned alongside ``kwargs`` so the caller still has it.

    Returns:
        ``(kwargs, caller_session)`` where ``caller_session`` is the value that
        ``kwargs.sagemaker_session`` had on entry (possibly ``None``).
    """
    caller_session = kwargs.sagemaker_session
    if caller_session is None:
        kwargs.sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION
    return kwargs, caller_session
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def _add_region_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Resolve ``kwargs.region`` in place and return ``kwargs``.

    The first truthy value wins, in this order:
    1. an explicit ``kwargs.region``;
    2. ``kwargs.sagemaker_session.boto_region_name``;
    3. ``JUMPSTART_DEFAULT_REGION_NAME``.

    ``kwargs.sagemaker_session`` is dereferenced whenever ``kwargs.region`` is
    falsy, so a session must already be set by then.
    """
    region = kwargs.region
    if not region:
        region = kwargs.sagemaker_session.boto_region_name or JUMPSTART_DEFAULT_REGION_NAME
    kwargs.region = region
    return kwargs
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def _add_sagemaker_session_with_custom_user_agent_to_kwargs(
    kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs],
    orig_session: Optional[Session],
) -> Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs]:
    """Set ``kwargs.sagemaker_session`` and return the same ``kwargs`` object.

    If ``orig_session`` is truthy it is used unchanged. Otherwise a default
    JumpStart session is created by
    ``get_default_jumpstart_session_with_user_agent_suffix``, parameterized with
    the model id, model version, config name, and whether the model is hub
    content (``kwargs.hub_arn`` is not ``None``).

    Args:
        kwargs: Model init or deploy kwargs to update in place.
        orig_session: Session supplied by the caller, or ``None``.

    Returns:
        The same ``kwargs`` object that was passed in. The return annotation
        mirrors the parameter, since deploy kwargs are accepted and returned too.
    """
    kwargs.sagemaker_session = orig_session or get_default_jumpstart_session_with_user_agent_suffix(
        model_id=kwargs.model_id,
        model_version=kwargs.model_version,
        config_name=kwargs.config_name,
        is_hub_content=kwargs.hub_arn is not None,
    )

    return kwargs
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def _add_role_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Resolve ``kwargs.role`` via the SageMaker config helper and return ``kwargs``.

    The current role is passed to ``resolve_model_sagemaker_config_field`` both as
    the explicit field value and as the default. ``kwargs.role`` is then overwritten
    with whatever that helper returns.
    """
    current_role = kwargs.role
    kwargs.role = resolve_model_sagemaker_config_field(
        field_name="role",
        field_val=current_role,
        sagemaker_session=kwargs.sagemaker_session,
        default_value=current_role,
    )
    return kwargs
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def _add_model_version_to_kwargs(
    kwargs: JumpStartModelInitKwargs,
) -> JumpStartModelInitKwargs:
    """Resolve ``kwargs.model_version`` and return ``kwargs``.

    A falsy version is first replaced by the wildcard ``"*"``. For hub models
    (``kwargs.hub_arn`` truthy), the version is then taken from
    ``kwargs.specs.version`` instead.
    """
    kwargs.model_version = kwargs.model_version or "*"
    if not kwargs.hub_arn:
        return kwargs

    # Hub content carries its own concrete version, which takes precedence.
    kwargs.model_version = kwargs.specs.version
    return kwargs
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _add_vulnerable_and_deprecated_status_to_kwargs(
    kwargs: JumpStartModelInitKwargs,
) -> JumpStartModelInitKwargs:
    """Normalize the deprecated/vulnerable tolerance flags and return ``kwargs``.

    Any falsy value (including ``None``) becomes ``False``. Truthy values are kept as-is.
    """
    for flag_name in ("tolerate_deprecated_model", "tolerate_vulnerable_model"):
        setattr(kwargs, flag_name, getattr(kwargs, flag_name) or False)
    return kwargs
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def _add_instance_type_to_kwargs(
    kwargs: JumpStartModelInitKwargs, disable_instance_type_logging: bool = False
) -> JumpStartModelInitKwargs:
    """Resolve the inference instance type and return ``kwargs``.

    When ``kwargs.instance_type`` is falsy, the JumpStart default inference instance
    type is retrieved, with ``kwargs.training_instance_type`` passed through. An info
    message reports the default when the caller passed ``None`` and logging is enabled.

    Afterwards, if the specs define inference configs and ``kwargs.config_name`` names
    one of them with a non-``None`` resolved config, a warning is logged when the
    instance type is not listed in that config's ``supported_inference_instance_types``.
    This check only logs; it never changes the instance type.

    Args:
        kwargs: Model init kwargs to update in place.
        disable_instance_type_logging: Suppress the "defaulting" info message.

    Returns:
        The same ``kwargs`` object.
    """
    caller_instance_type = kwargs.instance_type

    instance_type = caller_instance_type
    if not instance_type:
        instance_type = instance_types.retrieve_default(
            **get_model_info_default_kwargs(kwargs),
            scope=JumpStartScriptScope.INFERENCE,
            training_instance_type=kwargs.training_instance_type,
        )
    kwargs.instance_type = instance_type

    # Only an explicit ``None`` from the caller triggers the info message.
    if caller_instance_type is None and not disable_instance_type_logging:
        JUMPSTART_LOGGER.info(
            "No instance type selected for inference hosting endpoint. Defaulting to %s.",
            kwargs.instance_type,
        )

    inference_configs = kwargs.specs.inference_configs
    if not inference_configs or kwargs.config_name not in inference_configs.configs:
        return kwargs

    resolved_config = inference_configs.configs[kwargs.config_name].resolved_config
    if resolved_config is None:
        return kwargs

    supported = resolved_config.get("supported_inference_instance_types", [])
    if kwargs.instance_type not in supported:
        # NOTE(review): message says "Overriding" but nothing is overridden here;
        # it flags an instance type outside the config's supported list.
        JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type)

    return kwargs
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Resolve ``kwargs.image_uri`` and return ``kwargs``.

    Proprietary JumpStart models (which use ModelPackages) always get ``None``.
    For all other models, a falsy ``image_uri`` is replaced by the JumpStart
    inference image for ``kwargs.instance_type``.
    """
    if kwargs.model_type == JumpStartModelType.PROPRIETARY:
        kwargs.image_uri = None
        return kwargs

    image_uri = kwargs.image_uri
    if not image_uri:
        image_uri = image_uris.retrieve(
            **get_model_info_default_kwargs(kwargs),
            framework=None,
            image_scope=JumpStartScriptScope.INFERENCE,
            instance_type=kwargs.instance_type,
        )
    kwargs.image_uri = image_uri
    return kwargs
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def _add_model_reference_arn_to_kwargs(
    kwargs: JumpStartModelInitKwargs,
) -> JumpStartModelInitKwargs:
    """Sets Model Reference ARN if the hub content type is Model Reference, returns full kwargs.

    ``kwargs.hub_content_type`` is copied from ``kwargs.specs.hub_content_type`` only
    for hub models (``kwargs.hub_arn`` truthy); otherwise it is set to ``None``.

    ``kwargs.model_reference_arn`` is built from the hub ARN, model id and model
    version when that hub-scoped content type is ``HubContentType.MODEL_REFERENCE``.
    Otherwise it is ``None``.

    Returns:
        The same ``kwargs`` object.
    """
    # Without a hub ARN there is no hub content, so the specs' content type is ignored.
    # This also prevents attempting to build a model-reference ARN from hub_arn=None.
    hub_content_type = kwargs.specs.hub_content_type if kwargs.hub_arn else None
    kwargs.hub_content_type = hub_content_type

    if hub_content_type == HubContentType.MODEL_REFERENCE:
        kwargs.model_reference_arn = construct_hub_model_reference_arn_from_inputs(
            hub_arn=kwargs.hub_arn, model_name=kwargs.model_id, version=kwargs.model_version
        )
    else:
        kwargs.model_reference_arn = None
    return kwargs
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Resolve ``model_data`` on ``kwargs`` and return the same object.

    Proprietary models carry no model data, so ``model_data`` becomes ``None``.
    Otherwise a truthy caller-supplied value is used, falling back to the
    default inference artifact from ``model_uris.retrieve``.

    A string that starts with ``s3://`` and ends with ``/`` (an S3 prefix) is
    rewritten into an uncompressed ``S3DataSource`` dictionary. The rewrite is
    logged at info level only when the prefix came from the caller.
    """
    if kwargs.model_type == JumpStartModelType.PROPRIETARY:
        kwargs.model_data = None
        return kwargs

    info_kwargs = get_model_info_default_kwargs(kwargs)
    resolved: Union[str, dict] = kwargs.model_data
    if not resolved:
        resolved = model_uris.retrieve(
            **info_kwargs,
            model_scope=JumpStartScriptScope.INFERENCE,
            instance_type=kwargs.instance_type,
        )

    is_s3_prefix = (
        isinstance(resolved, str) and resolved.startswith("s3://") and resolved.endswith("/")
    )
    if is_s3_prefix:
        prefix_uri = resolved
        resolved = {
            "S3DataSource": {
                "S3Uri": prefix_uri,
                "S3DataType": "S3Prefix",
                "CompressionType": "None",
            }
        }
        # Only tell the user about the conversion when they supplied the prefix.
        if kwargs.model_data:
            JUMPSTART_LOGGER.info(
                "S3 prefix model_data detected for JumpStartModel: '%s'. "
                "Converting to S3DataSource dictionary: '%s'.",
                prefix_uri,
                json.dumps(resolved),
            )

    kwargs.model_data = resolved
    return kwargs
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Resolve ``source_dir`` on ``kwargs`` and return the same object.

    Proprietary models get ``source_dir = None``. For other models that
    support an inference script URI, a falsy ``source_dir`` is replaced by the
    default inference script location from ``script_uris.retrieve``. Models
    without inference script support keep whatever the caller passed.
    """
    if kwargs.model_type == JumpStartModelType.PROPRIETARY:
        kwargs.source_dir = None
        return kwargs

    info_kwargs = get_model_info_default_kwargs(kwargs)
    if _model_supports_inference_script_uri(**info_kwargs) and not kwargs.source_dir:
        kwargs.source_dir = script_uris.retrieve(
            **info_kwargs, script_scope=JumpStartScriptScope.INFERENCE
        )

    return kwargs
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Resolve ``entry_point`` on ``kwargs`` and return the same object.

    Proprietary models get ``entry_point = None``. For other models that
    support an inference script URI, a falsy ``entry_point`` is replaced by
    ``INFERENCE_ENTRY_POINT_SCRIPT_NAME``. Otherwise the caller's value is left
    untouched.
    """
    if kwargs.model_type == JumpStartModelType.PROPRIETARY:
        kwargs.entry_point = None
        return kwargs

    supports_script = _model_supports_inference_script_uri(
        **get_model_info_default_kwargs(kwargs)
    )
    if supports_script and not kwargs.entry_point:
        kwargs.entry_point = INFERENCE_ENTRY_POINT_SCRIPT_NAME

    return kwargs
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Merge default inference environment variables into ``kwargs.env``.

    Proprietary models get ``env = None``. Otherwise the model's default
    inference environment variables are retrieved with
    ``include_aws_sdk_env_vars=False``. Each one is added only for keys the
    caller did not already set, so caller-supplied values win. An empty result
    is normalised to ``None``.

    The caller's ``env`` mapping is copied before merging, so the dict the user
    passed in is not mutated as a side effect.

    Returns:
        The same ``kwargs`` object with ``env`` replaced.
    """
    if kwargs.model_type == JumpStartModelType.PROPRIETARY:
        kwargs.env = None
        return kwargs

    # Copy so JumpStart defaults are never written back into the caller's own dict.
    env = dict(kwargs.env) if kwargs.env is not None else {}

    extra_env_vars = environment_variables.retrieve_default(
        **get_model_info_default_kwargs(kwargs),
        include_aws_sdk_env_vars=False,
        script=JumpStartScriptScope.INFERENCE,
        instance_type=kwargs.instance_type,
    )

    for key, value in extra_env_vars.items():
        # Use the returned dict rather than relying on in-place mutation.
        env = update_dict_if_key_not_present(env, key, value)

    kwargs.env = env or None
    return kwargs
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Resolve ``model_package_arn`` on ``kwargs`` and return the same object.

    A truthy caller-supplied ARN is kept. Otherwise the default inference
    model package ARN for the resolved instance type is looked up with
    ``_retrieve_model_package_arn``.
    """
    if not kwargs.model_package_arn:
        kwargs.model_package_arn = _retrieve_model_package_arn(
            **get_model_info_default_kwargs(kwargs),
            instance_type=kwargs.instance_type,
            scope=JumpStartScriptScope.INFERENCE,
        )
    return kwargs
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
def _add_extra_model_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Fill unset model-init attributes on ``kwargs`` from JumpStart defaults.

    For every default returned by ``_retrieve_model_init_kwargs``, the
    attribute of the same name is updated only when it is currently ``None``.
    The new value is the default passed through
    ``resolve_model_sagemaker_config_field`` together with the kwargs' session.

    Returns:
        The same ``kwargs`` object, updated in place.
    """
    defaults = _retrieve_model_init_kwargs(**get_model_info_default_kwargs(kwargs))

    for field_name, default_value in defaults.items():
        if getattr(kwargs, field_name) is not None:
            continue  # an explicit caller value always wins
        setattr(
            kwargs,
            field_name,
            resolve_model_sagemaker_config_field(
                field_name=field_name,
                field_val=default_value,
                sagemaker_session=kwargs.sagemaker_session,
            ),
        )

    return kwargs
|
|
417
|
+
|
|
418
|
+
|
|
419
|
+
def _add_endpoint_name_to_kwargs(
    kwargs: Optional[JumpStartModelDeployKwargs],
) -> JumpStartModelDeployKwargs:
    """Resolve ``endpoint_name`` on deploy ``kwargs`` and return the same object.

    A truthy caller-supplied endpoint name is kept. Otherwise, if the model has
    a default resource name base, a name is generated from it with
    ``name_from_base``. If there is no base, ``endpoint_name`` becomes ``None``.
    """
    name_base = _retrieve_resource_name_base(**get_model_info_default_kwargs(kwargs))

    if not kwargs.endpoint_name:
        kwargs.endpoint_name = None if name_base is None else name_from_base(name_base)

    return kwargs
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
def _add_model_name_to_kwargs(
    kwargs: Optional[JumpStartModelInitKwargs],
) -> JumpStartModelInitKwargs:
    """Resolve the model ``name`` on init ``kwargs`` and return the same object.

    A truthy caller-supplied name is kept. Otherwise, if the model has a default
    resource name base, a name is generated from it with ``name_from_base``.
    If there is no base, ``name`` becomes ``None``.
    """
    name_base = _retrieve_resource_name_base(**get_model_info_default_kwargs(kwargs))

    if not kwargs.name:
        kwargs.name = None if name_base is None else name_from_base(name_base)

    return kwargs
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> JumpStartModelDeployKwargs:
    """Attach JumpStart-related tags to ``kwargs.tags`` and return ``kwargs``.

    Tags are added in this order:

    - JumpStart model-info tags (model ID, full spec version, model type,
      config name, inference scope), when the session's settings enable
      ``include_jumpstart_tags``.
    - A hub content ARN tag, when ``hub_arn`` is set. The ARN is a model
      reference ARN if ``model_reference_arn`` is set, otherwise a hub model ARN.
    - A Bedrock store tag with ``compatibility="compatible"``, when the specs
      expose ``capabilities`` containing ``HubContentCapability.BEDROCK_CONSOLE``.

    Returns:
        The same deploy ``kwargs`` object (not a dict), with ``tags`` updated.
    """
    full_model_version = kwargs.specs.version

    if kwargs.sagemaker_session.settings.include_jumpstart_tags:
        kwargs.tags = add_jumpstart_model_info_tags(
            kwargs.tags,
            kwargs.model_id,
            full_model_version,
            kwargs.model_type,
            config_name=kwargs.config_name,
            scope=JumpStartScriptScope.INFERENCE,
        )

    if kwargs.hub_arn:
        if kwargs.model_reference_arn:
            hub_content_arn = construct_hub_model_reference_arn_from_inputs(
                kwargs.hub_arn, kwargs.model_id, kwargs.model_version
            )
        else:
            hub_content_arn = construct_hub_model_arn_from_inputs(
                kwargs.hub_arn, kwargs.model_id, kwargs.model_version
            )
        kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn)

    # Some specs objects may not define ``capabilities`` at all.
    capabilities = getattr(kwargs.specs, "capabilities", None)
    if capabilities is not None and HubContentCapability.BEDROCK_CONSOLE in capabilities:
        kwargs.tags = add_bedrock_store_tags(kwargs.tags, compatibility="compatible")

    return kwargs
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
def _add_deploy_extra_kwargs(kwargs: JumpStartModelDeployKwargs) -> JumpStartModelDeployKwargs:
    """Fill unset deploy attributes on ``kwargs`` from the model's deploy defaults.

    The defaults come from ``_retrieve_model_deploy_kwargs`` for the resolved
    instance type. Each default is applied only to an attribute that is
    currently ``None``, so caller-supplied values are preserved.

    Returns:
        The same deploy ``kwargs`` object (not a dict), updated in place.
    """
    deploy_kwargs_to_add = _retrieve_model_deploy_kwargs(
        **get_model_info_default_kwargs(kwargs), instance_type=kwargs.instance_type
    )

    for key, value in deploy_kwargs_to_add.items():
        if getattr(kwargs, key) is None:
            setattr(kwargs, key, value)

    return kwargs
|
|
492
|
+
|
|
493
|
+
|
|
494
|
+
def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Resolve ``resources`` (resource requirements) on ``kwargs`` and return it.

    A truthy caller-supplied value is kept. Otherwise the default inference
    resource requirements for the resolved instance type are looked up with
    ``resource_requirements.retrieve_default``.
    """
    if not kwargs.resources:
        kwargs.resources = resource_requirements.retrieve_default(
            **get_model_info_default_kwargs(kwargs),
            scope=JumpStartScriptScope.INFERENCE,
            instance_type=kwargs.instance_type,
        )
    return kwargs
|
|
504
|
+
|
|
505
|
+
|
|
506
|
+
def _select_inference_config_from_training_config(
    specs: JumpStartModelSpecs, training_config_name: str
) -> Optional[str]:
    """Look up the default inference config paired with a training config.

    Args:
        specs (JumpStartModelSpecs): Model specs whose ``training_configs`` are searched.
        training_config_name (str): Name of the training config to look up.

    Returns:
        Optional[str]: The matched training config's ``default_inference_config``,
        or ``None`` when the specs have no training configs or none with that name.
    """
    training_configs = specs.training_configs
    if not training_configs:
        return None

    matched = training_configs.configs.get(training_config_name)
    return matched.default_inference_config if matched else None
|
|
522
|
+
|
|
523
|
+
|
|
524
|
+
def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Resolve ``config_name`` on init ``kwargs`` and return the same object.

    A truthy caller-supplied config name is kept. Otherwise the top-ranked
    inference config for the model is selected with
    ``get_top_ranked_config_name``, and the result may be ``None``. This
    function does not itself validate ``instance_type`` against the config.
    """
    kwargs.config_name = kwargs.config_name or get_top_ranked_config_name(
        **get_model_info_default_kwargs(kwargs, include_config_name=False),
        scope=JumpStartScriptScope.INFERENCE,
    )
    return kwargs
|
|
539
|
+
|
|
540
|
+
|
|
541
|
+
def _add_additional_model_data_sources_to_kwargs(
    kwargs: JumpStartModelInitKwargs,
) -> JumpStartModelInitKwargs:
    """Set default additional model data sources on init ``kwargs`` and return it.

    The speculative-decoding S3 data sources from the model specs first get
    their bucket set to the draft-model content bucket for the source's
    provider and the kwargs' region. They are then converted to PascalCase,
    API-shaped dicts. A truthy caller-supplied
    ``additional_model_data_sources`` takes precedence. If neither is
    available, the attribute ends up ``None``.
    """
    speculative_decoding_data_sources = kwargs.specs.get_speculative_decoding_s3_data_sources()

    for data_source in speculative_decoding_data_sources:
        data_source.s3_data_source.set_bucket(
            get_draft_model_content_bucket(provider=data_source.provider, region=kwargs.region)
        )

    # Reuse the list fetched (and bucket-updated) above instead of calling the
    # specs getter a second time just to test for emptiness.
    api_shape_additional_model_data_sources = (
        [camel_case_to_pascal_case(source.to_json()) for source in speculative_decoding_data_sources]
        if speculative_decoding_data_sources
        else None
    )

    kwargs.additional_model_data_sources = (
        kwargs.additional_model_data_sources or api_shape_additional_model_data_sources
    )

    return kwargs
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
def _add_config_name_to_deploy_kwargs(
    kwargs: JumpStartModelDeployKwargs, training_config_name: Optional[str] = None
) -> JumpStartModelDeployKwargs:
    """Resolve ``config_name`` on deploy ``kwargs`` and return the same object.

    A truthy ``kwargs.config_name`` always wins. Otherwise:

    - if ``training_config_name`` is given, the default inference config of that
      training config is used. This is ``None`` when the training config is
      unknown or declares none, and no top-ranked fallback is attempted.
    - else the top-ranked inference config for the model is used.

    Args:
        kwargs (JumpStartModelDeployKwargs): Deploy kwargs with ``specs`` already set.
        training_config_name (Optional[str]): Training config whose paired
            inference config should be preferred.

    Returns:
        JumpStartModelDeployKwargs: The same ``kwargs`` object, updated in place.
    """
    if training_config_name:
        default_config_name = _select_inference_config_from_training_config(
            specs=kwargs.specs, training_config_name=training_config_name
        )
    else:
        default_config_name = kwargs.config_name or get_top_ranked_config_name(
            **get_model_info_default_kwargs(kwargs, include_config_name=False),
            scope=JumpStartScriptScope.INFERENCE,
        )

    kwargs.config_name = kwargs.config_name or default_config_name

    return kwargs
|
|
595
|
+
|
|
596
|
+
|
|
597
|
+
def get_deploy_kwargs(
    model_id: str,
    model_version: Optional[str] = None,
    hub_arn: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    region: Optional[str] = None,
    initial_instance_count: Optional[int] = None,
    instance_type: Optional[str] = None,
    serializer: Optional[BaseSerializer] = None,
    deserializer: Optional[BaseDeserializer] = None,
    accelerator_type: Optional[str] = None,
    endpoint_name: Optional[str] = None,
    inference_component_name: Optional[str] = None,
    tags: Optional[Tags] = None,
    kms_key: Optional[str] = None,
    wait: Optional[bool] = None,
    data_capture_config: Optional[DataCaptureConfig] = None,
    async_inference_config: Optional[AsyncInferenceConfig] = None,
    serverless_inference_config: Optional[ServerlessInferenceConfig] = None,
    volume_size: Optional[int] = None,
    model_data_download_timeout: Optional[int] = None,
    container_startup_health_check_timeout: Optional[int] = None,
    inference_recommendation_id: Optional[str] = None,
    explainer_config: Optional[ExplainerConfig] = None,
    tolerate_vulnerable_model: Optional[bool] = None,
    tolerate_deprecated_model: Optional[bool] = None,
    sagemaker_session: Optional[Session] = None,
    accept_eula: Optional[bool] = None,
    model_reference_arn: Optional[str] = None,
    endpoint_logging: Optional[bool] = None,
    resources: Optional[ResourceRequirements] = None,
    managed_instance_scaling: Optional[str] = None,
    endpoint_type: Optional[EndpointType] = None,
    training_config_name: Optional[str] = None,
    config_name: Optional[str] = None,
    routing_config: Optional[Dict[str, Any]] = None,
    model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None,
    inference_ami_version: Optional[str] = None,
) -> JumpStartModelDeployKwargs:
    """Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object.

    The caller's arguments are wrapped in a ``JumpStartModelDeployKwargs`` and
    the model specs are fetched. Remaining fields are then filled in order:
    config name, model version, session user agent, endpoint name, instance
    type, initial instance count (``1`` when falsy), extra deploy defaults and
    tags. For ``EndpointType.INFERENCE_COMPONENT_BASED`` endpoints, resource
    requirements, ``endpoint_type`` and ``managed_instance_scaling`` are also set.
    """
    resolved = JumpStartModelDeployKwargs(
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        model_type=model_type,
        region=region,
        config_name=config_name,
        model_reference_arn=model_reference_arn,
        initial_instance_count=initial_instance_count,
        instance_type=instance_type,
        accelerator_type=accelerator_type,
        volume_size=volume_size,
        resources=resources,
        inference_ami_version=inference_ami_version,
        endpoint_name=endpoint_name,
        inference_component_name=inference_component_name,
        endpoint_logging=endpoint_logging,
        routing_config=routing_config,
        wait=wait,
        tags=format_tags(tags),
        kms_key=kms_key,
        serializer=serializer,
        deserializer=deserializer,
        data_capture_config=data_capture_config,
        async_inference_config=async_inference_config,
        serverless_inference_config=serverless_inference_config,
        explainer_config=explainer_config,
        inference_recommendation_id=inference_recommendation_id,
        model_data_download_timeout=model_data_download_timeout,
        container_startup_health_check_timeout=container_startup_health_check_timeout,
        tolerate_deprecated_model=tolerate_deprecated_model,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        accept_eula=accept_eula,
        model_access_configs=model_access_configs,
        sagemaker_session=sagemaker_session,
    )
    resolved, orig_session = _set_temp_sagemaker_session_if_not_set(kwargs=resolved)

    # Fetch specs for the pinned version, or any version ("*") when unpinned.
    # Deprecated/vulnerable models are tolerated here only so the JSON specs can
    # be read; those checks are expected to be enforced later.
    spec_lookup_kwargs = get_model_info_default_kwargs(
        resolved, include_model_version=False, include_tolerate_flags=False
    )
    resolved.specs = verify_model_region_and_return_specs(
        **spec_lookup_kwargs,
        version=resolved.model_version or "*",
        scope=JumpStartScriptScope.INFERENCE,
        tolerate_deprecated_model=True,
        tolerate_vulnerable_model=True,
    )

    resolved = _add_config_name_to_deploy_kwargs(
        kwargs=resolved, training_config_name=training_config_name
    )
    resolved = _add_model_version_to_kwargs(kwargs=resolved)
    resolved = _add_sagemaker_session_with_custom_user_agent_to_kwargs(
        kwargs=resolved, orig_session=orig_session
    )
    resolved = _add_endpoint_name_to_kwargs(kwargs=resolved)
    resolved = _add_instance_type_to_kwargs(kwargs=resolved)
    resolved.initial_instance_count = initial_instance_count or 1
    resolved = _add_deploy_extra_kwargs(kwargs=resolved)
    resolved = _add_tags_to_kwargs(kwargs=resolved)

    if endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED:
        return resolved

    resolved = _add_resources_to_kwargs(kwargs=resolved)
    resolved.endpoint_type = endpoint_type
    resolved.managed_instance_scaling = managed_instance_scaling
    return resolved
|
|
713
|
+
|
|
714
|
+
|
|
715
|
+
def get_init_kwargs(
    model_id: str,
    model_from_estimator: bool = False,
    model_version: Optional[str] = None,
    hub_arn: Optional[str] = None,
    model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
    tolerate_vulnerable_model: Optional[bool] = None,
    tolerate_deprecated_model: Optional[bool] = None,
    instance_type: Optional[str] = None,
    region: Optional[str] = None,
    image_uri: Optional[Union[str, PipelineVariable]] = None,
    model_data: Optional[Union[str, PipelineVariable, dict]] = None,
    role: Optional[str] = None,
    env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
    name: Optional[str] = None,
    vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None,
    sagemaker_session: Optional[Session] = None,
    enable_network_isolation: Union[bool, PipelineVariable] = None,
    model_kms_key: Optional[str] = None,
    image_config: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
    source_dir: Optional[str] = None,
    code_location: Optional[str] = None,
    entry_point: Optional[str] = None,
    container_log_level: Optional[Union[int, PipelineVariable]] = None,
    dependencies: Optional[List[str]] = None,
    git_config: Optional[Dict[str, str]] = None,
    model_package_arn: Optional[str] = None,
    training_instance_type: Optional[str] = None,
    disable_instance_type_logging: bool = False,
    resources: Optional[ResourceRequirements] = None,
    config_name: Optional[str] = None,
    additional_model_data_sources: Optional[Dict[str, Any]] = None,
) -> JumpStartModelInitKwargs:
    """Returns kwargs required to instantiate `sagemaker.estimator.Model` object.

    The caller's arguments are wrapped in a ``JumpStartModelInitKwargs`` and
    the model specs are fetched. Remaining fields are then filled from
    JumpStart defaults: vulnerability/deprecation status, version, config name,
    session user agent, region, model name, instance type, image URI, hub
    reference info, model data (skipped when ``model_from_estimator``), source
    dir, entry point, env, extra init kwargs, role, model package ARN,
    resources and additional model data sources.
    """
    init_kwargs = JumpStartModelInitKwargs(
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        model_type=model_type,
        region=region,
        config_name=config_name,
        instance_type=instance_type,
        training_instance_type=training_instance_type,
        resources=resources,
        image_uri=image_uri,
        image_config=image_config,
        model_data=model_data,
        additional_model_data_sources=additional_model_data_sources,
        model_package_arn=model_package_arn,
        source_dir=source_dir,
        entry_point=entry_point,
        code_location=code_location,
        dependencies=dependencies,
        git_config=git_config,
        container_log_level=container_log_level,
        env=env,
        role=role,
        name=name,
        vpc_config=vpc_config,
        enable_network_isolation=enable_network_isolation,
        model_kms_key=model_kms_key,
        sagemaker_session=sagemaker_session,
        tolerate_deprecated_model=tolerate_deprecated_model,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
    )
    init_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(kwargs=init_kwargs)

    # Fetch specs for the pinned version, or any version ("*") when unpinned.
    # Deprecated/vulnerable models are tolerated here only so the JSON specs can
    # be read; those checks are expected to be enforced later.
    spec_lookup_kwargs = get_model_info_default_kwargs(
        init_kwargs, include_model_version=False, include_tolerate_flags=False
    )
    init_kwargs.specs = verify_model_region_and_return_specs(
        **spec_lookup_kwargs,
        version=init_kwargs.model_version or "*",
        scope=JumpStartScriptScope.INFERENCE,
        tolerate_deprecated_model=True,
        tolerate_vulnerable_model=True,
    )

    init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(kwargs=init_kwargs)
    init_kwargs = _add_model_version_to_kwargs(kwargs=init_kwargs)
    init_kwargs = _add_config_name_to_init_kwargs(kwargs=init_kwargs)
    init_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(
        kwargs=init_kwargs, orig_session=orig_session
    )
    init_kwargs = _add_region_to_kwargs(kwargs=init_kwargs)
    init_kwargs = _add_model_name_to_kwargs(kwargs=init_kwargs)
    init_kwargs = _add_instance_type_to_kwargs(
        kwargs=init_kwargs, disable_instance_type_logging=disable_instance_type_logging
    )
    init_kwargs = _add_image_uri_to_kwargs(kwargs=init_kwargs)

    if not hub_arn:
        init_kwargs.model_reference_arn = None
        init_kwargs.hub_content_type = None
    else:
        init_kwargs = _add_model_reference_arn_to_kwargs(kwargs=init_kwargs)

    # A model built from an estimator uses the training job's output artifact,
    # so the default model data lookup is skipped in that case.
    if not model_from_estimator:
        init_kwargs = _add_model_data_to_kwargs(kwargs=init_kwargs)

    init_kwargs = _add_source_dir_to_kwargs(kwargs=init_kwargs)
    init_kwargs = _add_entry_point_to_kwargs(kwargs=init_kwargs)
    init_kwargs = _add_env_to_kwargs(kwargs=init_kwargs)
    init_kwargs = _add_extra_model_kwargs(kwargs=init_kwargs)
    init_kwargs = _add_role_to_kwargs(kwargs=init_kwargs)
    init_kwargs = _add_model_package_arn_to_kwargs(kwargs=init_kwargs)
    init_kwargs = _add_resources_to_kwargs(kwargs=init_kwargs)
    init_kwargs = _add_additional_model_data_sources_to_kwargs(kwargs=init_kwargs)

    return init_kwargs
|