sagemaker-core 1.0.62__py3-none-any.whl → 2.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/__init__.py +2 -0
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2399 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +247 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1220 -0
- sagemaker/core/git_utils.py +415 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2977 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +509 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +770 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1267 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +686 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface-vllm-neuronx.json +38 -0
- sagemaker/core/image_uri_config/huggingface.json +2287 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +252 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +494 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +972 -0
- sagemaker/core/image_uris.py +816 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +197 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +501 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +171 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +423 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +246 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1497 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +239 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +205 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1599 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1310 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +347 -0
- sagemaker/core/remote_function/core/serialization.py +410 -0
- sagemaker/core/remote_function/core/stored_function.py +223 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +102 -0
- sagemaker/core/remote_function/invoke_function.py +167 -0
- sagemaker/core/remote_function/job.py +2121 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +82 -0
- sagemaker/core/telemetry/telemetry_logging.py +285 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +345 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +514 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.3.1.dist-info/RECORD +351 -0
- sagemaker_core-2.3.1.dist-info/top_level.txt +1 -0
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.62.dist-info/RECORD +0 -35
- sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
- {sagemaker_core/helper → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/modules/local_core}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module contains functions for obtaining JumpStart instance types."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
from typing import List, Optional
|
|
17
|
+
|
|
18
|
+
from sagemaker.core.jumpstart.exceptions import NO_AVAILABLE_INSTANCES_ERROR_MSG
|
|
19
|
+
from sagemaker.core.jumpstart.constants import (
|
|
20
|
+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
|
|
21
|
+
)
|
|
22
|
+
from sagemaker.core.jumpstart.enums import (
|
|
23
|
+
JumpStartScriptScope,
|
|
24
|
+
JumpStartModelType,
|
|
25
|
+
)
|
|
26
|
+
from sagemaker.core.jumpstart.utils import (
|
|
27
|
+
get_region_fallback,
|
|
28
|
+
verify_model_region_and_return_specs,
|
|
29
|
+
)
|
|
30
|
+
from sagemaker.core.helper.session_helper import Session
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _retrieve_default_instance_type(
    model_id: str,
    model_version: str,
    scope: str,
    hub_arn: Optional[str] = None,
    region: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    training_instance_type: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> str:
    """Retrieves the default instance type for the model.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the default instance type.
        model_version (str): Version of the JumpStart model for which to retrieve the
            default instance type.
        scope (str): The script type, i.e. what it is used for.
            Valid values: "training" and "inference".
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        region (Optional[str]): Region for which to retrieve default instance type.
            If not provided, it is derived from ``sagemaker_session`` via
            ``get_region_fallback``. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.core.helper.session_helper.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.core.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training
            instance type used for the training job that produced the fine-tuned weights.
            Optionally supply this to get an inference instance type conditioned
            on the training instance, to ensure compatibility of training artifact to inference
            instance. Only consulted for the inference scope. (Default: None).
        model_type (JumpStartModelType): Type of the JumpStart model; forwarded to the
            model-spec lookup. (Default: JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).

    Returns:
        str: the default instance type to use for the model. ``None`` is never
            returned; a missing or empty default raises ``ValueError`` instead.

    Raises:
        ValueError: If the model is not available in the
            specified region due to lack of supported computing instances
            (i.e. the resolved default instance type is ``None`` or empty).
        NotImplementedError: If ``scope`` is neither the inference nor the training scope.
    """

    region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )

    model_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=scope,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        model_type=model_type,
        sagemaker_session=sagemaker_session,
        config_name=config_name,
    )

    if scope == JumpStartScriptScope.INFERENCE:
        default_instance_type = None
        # Specs may not carry training-instance variants at all, hence the getattr guard.
        variants = getattr(model_specs, "training_instance_type_variants", None)
        if training_instance_type is not None and variants is not None:
            default_instance_type = (
                variants.get_instance_specific_default_inference_instance_type(
                    training_instance_type
                )
            )
        # Fall back to the model-wide default when no training-instance-specific
        # default applies.
        if default_instance_type is None:
            default_instance_type = model_specs.default_inference_instance_type
    elif scope == JumpStartScriptScope.TRAINING:
        default_instance_type = model_specs.default_training_instance_type
    else:
        raise NotImplementedError(
            f"Unsupported script scope for retrieving default instance type: '{scope}'"
        )

    # Both a missing (None) and an empty-string default mean "no usable instance".
    if default_instance_type in {None, ""}:
        raise ValueError(NO_AVAILABLE_INSTANCES_ERROR_MSG.format(model_id=model_id, region=region))
    return default_instance_type
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _retrieve_instance_types(
    model_id: str,
    model_version: str,
    scope: str,
    hub_arn: Optional[str] = None,
    region: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    training_instance_type: Optional[str] = None,
    config_name: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> List[str]:
    """Retrieves the supported instance types for the model.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the supported instance types.
        model_version (str): Version of the JumpStart model for which to retrieve the
            supported instance types.
        scope (str): The script type, i.e. what it is used for.
            Valid values: "training" and "inference".
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        region (Optional[str]): Region for which to retrieve supported instance types.
            If not provided, it is derived from ``sagemaker_session`` via
            ``get_region_fallback``. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.core.helper.session_helper.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.core.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training
            instance type used for the training job that produced the fine-tuned weights.
            Optionally supply this to get inference instance types conditioned
            on the training instance, to ensure compatibility of training artifact to inference
            instance. Must not be given together with the training scope. (Default: None).
        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
        model_type (JumpStartModelType): Type of the JumpStart model; forwarded to the
            model-spec lookup, mirroring ``_retrieve_default_instance_type``.
            (Default: JumpStartModelType.OPEN_WEIGHTS).

    Returns:
        list: the supported instance types to use for the model. ``None`` and empty
            lists are never returned; they raise ``ValueError`` instead.

    Raises:
        ValueError: If the model is not available in the
            specified region due to lack of supported computing instances, or if
            ``training_instance_type`` is supplied with the training scope.
        NotImplementedError: If ``scope`` is neither the inference nor the training scope.
    """

    region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )

    model_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=scope,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        model_type=model_type,
        sagemaker_session=sagemaker_session,
        config_name=config_name,
    )

    if scope == JumpStartScriptScope.INFERENCE:
        instance_specific_instance_types = None
        # Specs may not carry training-instance variants at all, hence the getattr guard.
        variants = getattr(model_specs, "training_instance_type_variants", None)
        if training_instance_type is not None and variants is not None:
            instance_specific_instance_types = (
                variants.get_instance_specific_supported_inference_instance_types(
                    training_instance_type
                )
            )
        # Prefer the training-instance-specific list; fall back to the model-wide list
        # when it is empty or absent (truthiness also tolerates a None result, which
        # the previous len() check did not).
        instance_types = (
            instance_specific_instance_types
            or model_specs.supported_inference_instance_types
            or []
        )

    elif scope == JumpStartScriptScope.TRAINING:
        if training_instance_type is not None:
            raise ValueError("Cannot use `training_instance_type` argument with training scope.")
        instance_types = model_specs.supported_training_instance_types
    else:
        raise NotImplementedError(
            f"Unsupported script scope for retrieving supported instance types: '{scope}'"
        )

    # Covers both a missing (None) and an empty list of supported instance types.
    if not instance_types:
        raise ValueError(NO_AVAILABLE_INSTANCES_ERROR_MSG.format(model_id=model_id, region=region))

    return instance_types
|
|
@@ -0,0 +1,289 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module contains functions for obtaining JumpStart kwargs."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
from copy import deepcopy
|
|
16
|
+
from typing import Optional
|
|
17
|
+
from sagemaker.core.helper.session_helper import Session
|
|
18
|
+
from sagemaker.core.common_utils import volume_size_supported
|
|
19
|
+
from sagemaker.core.jumpstart.constants import (
|
|
20
|
+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
|
|
21
|
+
)
|
|
22
|
+
from sagemaker.core.jumpstart.enums import (
|
|
23
|
+
JumpStartScriptScope,
|
|
24
|
+
JumpStartModelType,
|
|
25
|
+
)
|
|
26
|
+
from sagemaker.core.jumpstart.utils import (
|
|
27
|
+
get_region_fallback,
|
|
28
|
+
verify_model_region_and_return_specs,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _retrieve_model_init_kwargs(
    model_id: str,
    model_version: str,
    hub_arn: Optional[str] = None,
    region: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> dict:
    """Build the keyword arguments used to construct a JumpStart `Model`.

    Args:
        model_id (str): JumpStart model ID of the model whose `Model` kwargs are wanted.
        model_version (str): Version of the JumpStart model whose kwargs are wanted.
        hub_arn (Optional[str]): ARN of the SageMaker Hub to read model details from.
            (Default: None).
        region (Optional[str]): Region to read the kwargs for. When falsy, the region is
            resolved from ``sagemaker_session`` via ``get_region_fallback``.
            (Default: None).
        tolerate_vulnerable_model (bool): If True, model specs whose scripts have
            dependencies with known security vulnerabilities are accepted; if False,
            such specs cause an exception. (Default: False).
        tolerate_deprecated_model (bool): If True, deprecated model versions are
            accepted; if False, they cause an exception. (Default: False).
        sagemaker_session (sagemaker.core.helper.session_helper.Session): Session used
            for SageMaker interactions.
            (Default: sagemaker.core.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (JumpStartModelType): Type of the model, e.g. open weights or
            proprietary. (Default: JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the JumpStart model config to apply.
            (Default: None).

    Returns:
        dict: A deep copy of the spec's ``model_kwargs``. If the spec defines
        ``inference_enable_network_isolation``, that value is stored under the
        ``enable_network_isolation`` key.
    """
    resolved_region = region or get_region_fallback(sagemaker_session=sagemaker_session)

    specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=JumpStartScriptScope.INFERENCE,
        region=resolved_region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
        config_name=config_name,
    )

    # Deep copy so edits to the returned dict never reach `specs.model_kwargs`.
    init_kwargs = deepcopy(specs.model_kwargs)

    network_isolation = specs.inference_enable_network_isolation
    if network_isolation is not None:
        init_kwargs["enable_network_isolation"] = network_isolation

    return init_kwargs
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _retrieve_model_deploy_kwargs(
    model_id: str,
    model_version: str,
    instance_type: str,
    hub_arn: Optional[str] = None,
    region: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> dict:
    """Retrieves kwargs for `Model.deploy`.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the kwargs.
        model_version (str): Version of the JumpStart model for which to retrieve the
            kwargs.
        instance_type (str): Instance type of the hosting endpoint, used to determine
            whether a volume size may be set.
        hub_arn (Optional[str]): The ARN of the SageMaker Hub to retrieve model
            details from. (Default: None).
        region (Optional[str]): Region for which to retrieve kwargs. When falsy, the
            region is resolved from ``sagemaker_session``. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with
            known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.core.helper.session_helper.Session): A SageMaker
            Session object, used for SageMaker interactions.
            (Default: sagemaker.core.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (JumpStartModelType): The type of the model, can be open weights model
            or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the JumpStart Model config to apply.
            (Default: None).

    Returns:
        dict: A copy of the spec's ``deploy_kwargs``. When the instance type supports
        a volume size and the spec defines ``inference_volume_size``, the copy also has
        a ``volume_size`` key holding that value.
    """
    region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )

    model_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=JumpStartScriptScope.INFERENCE,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
        config_name=config_name,
    )

    # Return a copy (as _retrieve_model_init_kwargs does for model_kwargs) so callers
    # that mutate the result do not alter `model_specs.deploy_kwargs` itself.
    kwargs = deepcopy(model_specs.deploy_kwargs)

    if volume_size_supported(instance_type) and model_specs.inference_volume_size is not None:
        kwargs["volume_size"] = model_specs.inference_volume_size

    return kwargs
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def _retrieve_estimator_init_kwargs(
    model_id: str,
    model_version: str,
    instance_type: str,
    hub_arn: Optional[str] = None,
    region: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> dict:
    """Build the keyword arguments used to construct a JumpStart `Estimator`.

    Args:
        model_id (str): JumpStart model ID of the model whose `Estimator` kwargs are
            wanted.
        model_version (str): Version of the JumpStart model whose kwargs are wanted.
        instance_type (str): Training instance type, used to decide whether a volume
            size may be set.
        hub_arn (Optional[str]): ARN of the SageMaker Hub to read model details from.
            (Default: None).
        region (Optional[str]): Region to read the kwargs for. When falsy, the region is
            resolved from ``sagemaker_session`` via ``get_region_fallback``.
            (Default: None).
        tolerate_vulnerable_model (bool): If True, model specs whose scripts have
            dependencies with known security vulnerabilities are accepted; if False,
            such specs cause an exception. (Default: False).
        tolerate_deprecated_model (bool): If True, deprecated model versions are
            accepted; if False, they cause an exception. (Default: False).
        sagemaker_session (sagemaker.core.helper.session_helper.Session): Session used
            for SageMaker interactions.
            (Default: sagemaker.core.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        config_name (Optional[str]): Name of the JumpStart model config to apply.
            (Default: None).
        model_type (JumpStartModelType): Type of the model, e.g. open weights or
            proprietary. (Default: JumpStartModelType.OPEN_WEIGHTS).

    Returns:
        dict: A deep copy of the spec's ``estimator_kwargs``, plus
        ``enable_network_isolation`` when ``training_enable_network_isolation`` is set
        and ``volume_size`` when the instance type supports it and
        ``training_volume_size`` is set.
    """
    resolved_region = region or get_region_fallback(sagemaker_session=sagemaker_session)

    specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=JumpStartScriptScope.TRAINING,
        region=resolved_region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        config_name=config_name,
        model_type=model_type,
    )

    # Deep copy so edits to the returned dict never reach `specs.estimator_kwargs`.
    estimator_kwargs = deepcopy(specs.estimator_kwargs)

    # Gather spec-driven overrides first, then apply them in one step.
    overrides = {}
    if specs.training_enable_network_isolation is not None:
        overrides["enable_network_isolation"] = specs.training_enable_network_isolation
    if volume_size_supported(instance_type) and specs.training_volume_size is not None:
        overrides["volume_size"] = specs.training_volume_size
    estimator_kwargs.update(overrides)

    return estimator_kwargs
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def _retrieve_estimator_fit_kwargs(
    model_id: str,
    model_version: str,
    hub_arn: Optional[str] = None,
    region: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> dict:
    """Retrieves kwargs for `Estimator.fit`.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the kwargs.
        model_version (str): Version of the JumpStart model for which to retrieve the
            kwargs.
        hub_arn (Optional[str]): The ARN of the SageMaker Hub to retrieve model
            details from. (Default: None).
        region (Optional[str]): Region for which to retrieve kwargs. When falsy, the
            region is resolved from ``sagemaker_session``. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with
            known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.core.helper.session_helper.Session): A SageMaker
            Session object, used for SageMaker interactions.
            (Default: sagemaker.core.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        config_name (Optional[str]): Name of the JumpStart Model config to apply.
            (Default: None).
        model_type (JumpStartModelType): The type of the model, can be open weights model
            or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).

    Returns:
        dict: A deep copy of the spec's ``fit_kwargs``.
    """
    region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )

    model_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=JumpStartScriptScope.TRAINING,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        config_name=config_name,
        model_type=model_type,
    )

    # Copy (as the init-kwargs retrievers do) so callers that mutate the result do not
    # alter `model_specs.fit_kwargs` itself.
    return deepcopy(model_specs.fit_kwargs)
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module contains functions for obtaining JumpStart metric definitions."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
from copy import deepcopy
|
|
16
|
+
from typing import Dict, List, Optional
|
|
17
|
+
from sagemaker.core.jumpstart.constants import (
|
|
18
|
+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
|
|
19
|
+
)
|
|
20
|
+
from sagemaker.core.jumpstart.enums import (
|
|
21
|
+
JumpStartModelType,
|
|
22
|
+
JumpStartScriptScope,
|
|
23
|
+
)
|
|
24
|
+
from sagemaker.core.jumpstart.utils import (
|
|
25
|
+
get_region_fallback,
|
|
26
|
+
verify_model_region_and_return_specs,
|
|
27
|
+
)
|
|
28
|
+
from sagemaker.core.helper.session_helper import Session
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _retrieve_default_training_metric_definitions(
    model_id: str,
    model_version: str,
    region: Optional[str],
    hub_arn: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    instance_type: Optional[str] = None,
    config_name: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> Optional[List[Dict[str, str]]]:
    """Retrieves the default training metric definitions for the model.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the default training metric definitions.
        model_version (str): Version of the JumpStart model for which to retrieve the
            default training metric definitions.
        region (Optional[str]): Region for which to retrieve default training metric
            definitions. When falsy, the region is resolved from ``sagemaker_session``.
        hub_arn (Optional[str]): The ARN of the SageMaker Hub to retrieve model
            details from. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with
            known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.core.helper.session_helper.Session): A SageMaker
            Session object, used for SageMaker interactions.
            (Default: sagemaker.core.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        instance_type (Optional[str]): An instance type to optionally supply in order to
            get metric definitions specific to that instance type. (Default: None).
        config_name (Optional[str]): Name of the JumpStart Model config to apply.
            (Default: None).
        model_type (JumpStartModelType): The type of the model, can be open weights model
            or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).

    Returns:
        list: The training metric definitions for the model. The list is empty when
        the model defines none. Instance-specific definitions replace any definition
        that has the same ``"Name"``, and they are placed after the remaining
        definitions.
    """
    region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )

    model_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        scope=JumpStartScriptScope.TRAINING,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        config_name=config_name,
        model_type=model_type,
    )

    # Pass a default to getattr: the two-argument form raises AttributeError instead
    # of falling back to an empty list when `metrics` is absent.
    spec_metrics = getattr(model_specs, "metrics", None)
    default_metric_definitions = deepcopy(spec_metrics) if spec_metrics else []

    instance_type_variants = getattr(model_specs, "training_instance_type_variants", None)
    instance_specific_metric_definitions = (
        instance_type_variants.get_instance_specific_metric_definitions(instance_type)
        if instance_type and instance_type_variants is not None
        else []
    )

    # Each instance-specific definition evicts any existing definition with the same
    # name and is appended at the end. A later override with a repeated name therefore
    # also replaces an earlier override.
    for override in instance_specific_metric_definitions or []:
        override_name = override["Name"]
        default_metric_definitions = [
            definition
            for definition in default_metric_definitions
            if definition["Name"] != override_name
        ]
        default_metric_definitions.append(override)

    return default_metric_definitions
|