sagemaker-core 1.0.62__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.62.dist-info/RECORD +0 -35
- sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
- {sagemaker_core/helper → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1176 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Utility methods used by framework classes."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
import logging
|
|
18
|
+
import os
|
|
19
|
+
import re
|
|
20
|
+
import shutil
|
|
21
|
+
import tempfile
|
|
22
|
+
import time
|
|
23
|
+
from collections import namedtuple
|
|
24
|
+
from typing import Dict, List, Optional, Union
|
|
25
|
+
|
|
26
|
+
from packaging import version
|
|
27
|
+
|
|
28
|
+
import sagemaker.core.common_utils as sagemaker_utils
|
|
29
|
+
from sagemaker.core.deprecations import deprecation_warn_base, renamed_kwargs
|
|
30
|
+
from sagemaker.core.instance_group import InstanceGroup
|
|
31
|
+
from sagemaker.core.s3 import s3_path_join
|
|
32
|
+
from sagemaker.core.session_settings import SessionSettings
|
|
33
|
+
from sagemaker.core.workflow import is_pipeline_variable
|
|
34
|
+
from sagemaker.core.helper.pipeline_variable import PipelineVariable
|
|
35
|
+
|
|
36
|
+
logger = logging.getLogger(__name__)
|
|
37
|
+
|
|
38
|
+
# File name for a gzipped source tarball; where it is consumed is outside this view.
_TAR_SOURCE_FILENAME = "source.tar.gz"

# Immutable two-field record: ``s3_prefix`` and ``script_name``.
# NOTE(review): the descriptive string below still names ``sagemaker.fw_utils``;
# in this release the module appears to live at ``sagemaker/core/fw_utils.py`` --
# confirm and update the reference if so.
UploadedCode = namedtuple("UploadedCode", ["s3_prefix", "script_name"])
"""sagemaker.fw_utils.UploadedCode: An object containing the S3 prefix and script name.

This is for the source code used for the entry point with an ``Estimator``. It can be
instantiated with positional or keyword arguments.
"""
|
|
46
|
+
|
|
47
|
+
# Template for the Python 2 deprecation notice. Expects the ``str.format`` fields
# ``latest_supported_version`` and ``framework``.
# The adjacent literals are concatenated by the parser, so each fragment that is
# followed by another sentence must end with an explicit space.
PYTHON_2_DEPRECATION_WARNING = (
    "{latest_supported_version} is the latest version of {framework} that supports "
    "Python 2. Newer versions of {framework} will only be available for Python 3. "
    "Please set the argument \"py_version='py3'\" to use the Python 3 {framework} image."
)
|
|
52
|
+
# Warning text about combining a multi-GPU instance type with the default
# parameter-server configuration. The code that emits it is outside this view.
PARAMETER_SERVER_MULTI_GPU_WARNING = (
    "If you have selected a multi-GPU training instance type "
    "and also enabled parameter server for distributed training, "
    "distributed training with the default parameter server configuration will not "
    "fully leverage all GPU cores; the parameter server will be configured to run "
    "only one worker per host regardless of the number of GPUs."
)
|
|
59
|
+
|
|
60
|
+
# Region names treated as unsupported for the debugger feature (per the constant
# name); the checks that consult this tuple are outside this view.
DEBUGGER_UNSUPPORTED_REGIONS = (
    "us-iso-east-1",
    "us-isob-east-1",
    "ap-southeast-3",
    "ap-southeast-4",
    "eu-south-2",
    "me-central-1",
    "ap-south-2",
    "eu-central-2",
    "us-gov-east-1",
)
# Region names treated as unsupported for the profiler feature (per the constant
# name). The entries currently match DEBUGGER_UNSUPPORTED_REGIONS one-for-one,
# but the two tuples are maintained separately.
PROFILER_UNSUPPORTED_REGIONS = (
    "us-iso-east-1",
    "us-isob-east-1",
    "ap-southeast-3",
    "ap-southeast-4",
    "eu-south-2",
    "me-central-1",
    "ap-south-2",
    "eu-central-2",
    "us-gov-east-1",
)
|
|
82
|
+
|
|
83
|
+
# Instance types listed as having a single GPU (per the constant name).
SINGLE_GPU_INSTANCE_TYPES = ("ml.p2.xlarge", "ml.p3.2xlarge")
# Instance types accepted for the smdistributed dataparallel strategy (per the
# constant name); includes the "local_gpu" pseudo instance type.
SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES = (
    "ml.p3.16xlarge",
    "ml.p3dn.24xlarge",
    "ml.p4d.24xlarge",
    "ml.p4de.24xlarge",
    "local_gpu",
)
# Maps a framework name to the version strings listed for smdistributed
# dataparallel. Both short ("2.3") and full ("2.3.1") forms appear as separate
# entries; how lookups match against them is outside this view.
SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS = {
    # tf 2.12 should not be supported: smdataparallel excludes support for tf>=2.12.
    "tensorflow": [
        "2.3",
        "2.3.1",
        "2.3.2",
        "2.4",
        "2.4.1",
        "2.4.3",
        "2.5",
        "2.5.0",
        "2.5.1",
        "2.6",
        "2.6.0",
        "2.6.2",
        "2.6.3",
        "2.7",
        "2.7.1",
        "2.8",
        "2.8.0",
        "2.9",
        "2.9.1",
        "2.9.2",
        "2.10",
        "2.10.1",
        "2.11",
        "2.11.0",
    ],
    "pytorch": [
        "1.6",
        "1.6.0",
        "1.7",
        "1.7.1",
        "1.8",
        "1.8.0",
        "1.8.1",
        "1.9",
        "1.9.0",
        "1.9.1",
        "1.10",
        "1.10.0",
        "1.10.2",
        "1.11",
        "1.11.0",
        "1.12",
        "1.12.0",
        "1.12.1",
        "1.13.1",
        "2.0.0",
        "2.0.1",
        "2.1.0",
        "2.1.2",
        "2.2.0",
    ],
}
|
|
146
|
+
|
|
147
|
+
# Framework version strings listed for torch_distributed on GPU instances
# (per the constant name); the validation that reads it is outside this view.
TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = [
    "1.13.1",
    "2.0.0",
    "2.0.1",
    "2.1.0",
    "2.1.2",
    "2.2.0",
    "2.3.0",
    "2.3.1",
    "2.4.1",
    "2.5.1",
]

# Distribution strategy names listed for Trainium instances (per the constant name).
TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
# Framework version strings listed for torch_distributed on Trainium instances.
TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS = [
    "1.11",
    "1.11.0",
    "1.12",
    "1.12.0",
    "1.12.1",
    "1.13.1",
    "2.0.0",
]

# Strategy keys recognised under the ``smdistributed`` distribution entry
# (per the constant name).
SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
# Instance families listed as allowed Graviton targets (per the constant name).
GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY = [
    "m6g", "m6gd",
    "c6g", "c6gd", "c6gn", "c7g",
    "r6g", "r6gd",
]


# Framework names listed as allowed on Graviton (per the constant name).
GRAVITON_ALLOWED_FRAMEWORKS = {"tensorflow", "pytorch", "xgboost", "sklearn"}
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def validate_source_dir(script, directory):
    """Check that ``script`` is a regular file inside ``directory``.

    When ``directory`` is falsy (``None`` or an empty string) nothing is
    checked and the call succeeds.

    Args:
        script (str): File name of the user script.
        directory (str): Directory expected to contain ``script``.

    Returns:
        bool: Always ``True`` when validation passes.

    Raises:
        ValueError: If ``directory`` is given and ``directory/script`` is not an
            existing file (this also covers a missing or non-directory
            ``directory``).
    """
    # No directory supplied: there is nothing to validate.
    if not directory:
        return True

    script_path = os.path.join(directory, script)
    if os.path.isfile(script_path):
        return True

    raise ValueError(
        'No file named "{}" was found in directory "{}".'.format(script, directory)
    )
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def validate_source_code_input_against_pipeline_variables(
    entry_point: Optional[Union[str, PipelineVariable]] = None,
    source_dir: Optional[Union[str, PipelineVariable]] = None,
    git_config: Optional[Dict[str, str]] = None,
    enable_network_isolation: Union[bool, PipelineVariable] = False,
):
    """Reject unsupported combinations of pipeline variables and source-code inputs.

    Args:
        entry_point (str or PipelineVariable): Path to the local Python source file
            that is executed as the entry point to training (default: None).
        source_dir (str or PipelineVariable): Path to a directory holding any other
            training source code dependencies besides the entry point file
            (default: None).
        git_config (Dict[str, str]): Git configurations used for cloning files
            (default: None).
        enable_network_isolation (bool or PipelineVariable): Whether the container
            runs in network isolation mode (default: False).

    Raises:
        TypeError: If ``entry_point`` or ``source_dir`` is a pipeline variable while
            ``enable_network_isolation`` is a pipeline variable or ``True``, or while
            ``git_config`` is given; or if ``entry_point`` is a pipeline variable and
            ``source_dir`` is missing or is a plain string not starting with
            ``s3://`` (case-insensitive).
    """

    def _code_input_is_variable():
        # True when either source-code input is a pipeline variable.
        return is_pipeline_variable(entry_point) or is_pipeline_variable(source_dir)

    isolation_requested = (
        is_pipeline_variable(enable_network_isolation) or enable_network_isolation is True
    )
    if isolation_requested and _code_input_is_variable():
        raise TypeError(
            "entry_point, source_dir should not be pipeline variables "
            "when enable_network_isolation is a pipeline variable or it is set to True."
        )

    if git_config and _code_input_is_variable():
        raise TypeError(
            "entry_point, source_dir should not be pipeline variables when git_config is given."
        )

    if is_pipeline_variable(entry_point):
        if not source_dir:
            raise TypeError(
                "The entry_point should not be a pipeline variable when source_dir is missing."
            )
        # A plain-string source_dir is only acceptable when it is an S3 URI.
        if not (is_pipeline_variable(source_dir) or source_dir.lower().startswith("s3://")):
            raise TypeError(
                "The entry_point should not be a pipeline variable when source_dir is a local path."
            )
        logger.warning(
            "The entry_point is a pipeline variable: %s. During pipeline execution, "
            "the interpreted value of entry_point has to be a local path in the container "
            "pointing to a Python source file which is located at the root of source_dir.",
            type(entry_point),
        )

    if is_pipeline_variable(source_dir):
        logger.warning(
            "The source_dir is a pipeline variable: %s. During pipeline execution, "
            "the interpreted value of source_dir has to be an S3 URI and "
            "must point to a file with name ``sourcedir.tar.gz``",
            type(source_dir),
        )
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def parse_mp_parameters(params):
    """Parse the model parallelism parameters provided by the user.

    Args:
        params: Either a config dict, which is returned unchanged, or a path
            (``str``, ``bytes`` or ``os.PathLike``) to an existing JSON config file.

    Returns:
        The dict passed in, or the value decoded from the JSON file.

    Raises:
        ValueError: If ``params`` is neither a dict nor a path to an existing
            file system entry, or if the file cannot be parsed as JSON (the
            underlying ``json.JSONDecodeError`` is attached as ``__cause__``).
    """
    if isinstance(params, dict):
        return params

    # Check the type before touching the file system: ``os.path.exists`` raises
    # TypeError for objects such as None or lists and treats integers as open
    # file descriptors, neither of which is a valid config path.
    is_path_like = isinstance(params, (str, bytes, os.PathLike))
    if not is_path_like or not os.path.exists(params):
        raise ValueError(
            f"Expected a string path to an existing modelparallel config, or a dictionary. "
            f"Received: {params}."
        )

    try:
        with open(params, "r") as fp:
            parsed = json.load(fp)
    except json.JSONDecodeError as err:
        # Keep the decoder error as the cause so the offending position is visible.
        raise ValueError(f"Cannot parse {params} as a json file.") from err

    # A file containing only ``null`` decodes to None, which is not a usable config.
    if parsed is None:
        raise ValueError(f"Cannot parse {params} as a json file.")

    return parsed
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def get_mp_parameters(distribution):
    """Extract and validate the user's model parallelism parameters.

    Args:
        distribution: Distribution dictionary defined by the user. The settings
            are read from ``distribution["smdistributed"]["modelparallel"]``.

    Returns:
        The parsed parameters when the ``modelparallel`` entry has ``"enabled"``
        set to exactly ``True``; ``None`` when the entry is absent or not enabled.
    """
    try:
        modelparallel = distribution["smdistributed"]["modelparallel"]
    except KeyError:
        # No modelparallel section at all: nothing to parse.
        return None

    # Only the literal ``True`` enables model parallelism; truthy values do not.
    if modelparallel.get("enabled", False) is not True:
        return None

    parsed = parse_mp_parameters(modelparallel.get("parameters", {}))
    validate_mp_config(parsed)
    return parsed
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def validate_mp_config(config):
    """Validate the configuration dictionary for model parallelism.

    Keys that are absent from ``config`` are not validated, except for the
    cross-key rules noted below.

    Args:
        config (dict): Dictionary holding configuration keys and values.

    Raises:
        ValueError: If any of the keys have incorrect values, namely:
            an unknown ``pipeline``, ``placement_strategy`` or ``optimize``
            value; a non-positive or non-integer count; a non-boolean flag;
            a non-string ``partition_file``; ``auto_partition`` set to False
            without ``default_partition``; ``default_partition`` not below
            ``partitions``; ``memory_weight`` outside [0.0, 1.0];
            ``ddp_port`` / ``ddp_dist_backend`` given without ``ddp``; an
            invalid ``ddp_port``; or both ``ddp`` and ``horovod`` enabled.
    """

    def validate_positive(key):
        if key in config and (not isinstance(config[key], int) or config[key] < 1):
            raise ValueError(f"The number of {key} must be a positive integer.")

    def validate_in(key, vals):
        if key in config and config[key] not in vals:
            raise ValueError(f"{key} must be a value in: {vals}.")

    def validate_bool(keys):
        validate_in(keys, [True, False])

    validate_in("pipeline", ["simple", "interleaved", "_only_forward"])
    validate_in("placement_strategy", ["spread", "cluster"])
    validate_in("optimize", ["speed", "memory"])

    for key in ["microbatches", "partitions", "active_microbatches"]:
        validate_positive(key)

    for key in [
        "auto_partition",
        "contiguous",
        "load_partition",
        "horovod",
        "ddp",
        "deterministic_server",
    ]:
        validate_bool(key)

    if "partition_file" in config and not isinstance(config.get("partition_file"), str):
        raise ValueError("'partition_file' must be a str.")

    if config.get("auto_partition") is False and "default_partition" not in config:
        raise ValueError("default_partition must be supplied if auto_partition is set to False!")

    # The upper bound can only be checked when ``partitions`` is present;
    # indexing it unconditionally used to surface as a KeyError.
    if "default_partition" in config and "partitions" in config:
        try:
            too_large = config["default_partition"] >= config["partitions"]
        except TypeError:
            # A value that cannot be ordered against an int is invalid input.
            too_large = True
        if too_large:
            raise ValueError("default_partition must be less than the number of partitions!")

    if "memory_weight" in config:
        weight = config["memory_weight"]
        try:
            out_of_range = weight > 1.0 or weight < 0.0
        except TypeError:
            # Non-numeric weights (e.g. strings) are reported as invalid
            # instead of leaking a TypeError to the caller.
            out_of_range = True
        if out_of_range:
            raise ValueError("memory_weight must be between 0.0 and 1.0!")

    if "ddp_port" in config and "ddp" not in config:
        raise ValueError("`ddp_port` needs `ddp` to be set as well")

    if "ddp_dist_backend" in config and "ddp" not in config:
        raise ValueError("`ddp_dist_backend` needs `ddp` to be set as well")

    if "ddp_port" in config:
        if not isinstance(config["ddp_port"], int) or config["ddp_port"] < 0:
            value = config["ddp_port"]
            raise ValueError(f"Invalid port number {value}.")

    if config.get("horovod", False) and config.get("ddp", False):
        raise ValueError("'ddp' and 'horovod' cannot be simultaneously enabled.")
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
def tar_and_upload_dir(
    session,
    bucket,
    s3_key_prefix,
    script,
    directory=None,
    dependencies=None,
    kms_key=None,
    s3_resource=None,
    settings: Optional[SessionSettings] = None,
) -> UploadedCode:
    """Package source files and upload a compressed tar file to S3.

    The S3 location will be ``s3://<bucket>/s3_key_prefix/sourcedir.tar.gz``.
    If directory is an S3 URI (or a pipeline variable), an UploadedCode object
    is returned but nothing is uploaded, which allows reuse of code already
    in S3. If directory is None, only the script is added to the archive and
    the returned script name is its basename. Otherwise the entries of the
    directory are added to the archive and the script name is returned
    unchanged, i.e. it is assumed to be a path relative to the directory.

    Args:
        session (boto3.Session): Boto session used to access S3.
        bucket (str): S3 bucket to which the compressed file is uploaded.
        s3_key_prefix (str): Prefix for the S3 key.
        script (str): Script filename or path.
        directory (str): Optional. Directory containing the source file. If it
            starts with "s3://", no action is taken.
        dependencies (List[str]): Optional. Additional paths that are added to
            the archive next to the source files.
        kms_key (str): Optional. KMS key ID used to upload objects to the bucket
            (default: None).
        s3_resource (boto3.resource("s3")): Optional. Pre-instantiated Boto3 Resource
            for S3 connections, can be used to customize the configuration,
            e.g. set the endpoint URL (default: None).
        settings (sagemaker.session_settings.SessionSettings): Optional. The settings
            of the SageMaker ``Session``, can be used to override the default encryption
            behavior and the parent of the temporary working directory (default: None).

    Returns:
        sagemaker.fw_utils.UploadedCode: An object with the S3 bucket and key (S3 prefix) and
            script name.

    Raises:
        ValueError: if ``settings.local_download_dir`` is set but is not an
            existing directory.
    """
    if directory and (is_pipeline_variable(directory) or directory.lower().startswith("s3://")):
        # Code already lives in S3 (or is resolved at pipeline run time).
        return UploadedCode(s3_prefix=directory, script_name=script)

    script_name = script if directory else os.path.basename(script)
    dependencies = dependencies or []
    key = "%s/sourcedir.tar.gz" % s3_key_prefix

    if settings is not None and settings.local_download_dir is not None:
        download_dir_missing = not os.path.exists(
            settings.local_download_dir
        ) or not os.path.isdir(settings.local_download_dir)
        if download_dir_missing:
            raise ValueError(
                "Inputted directory for storing newly generated temporary directory does "
                f"not exist: '{settings.local_download_dir}'"
            )

    local_download_dir = settings.local_download_dir if settings is not None else None
    tmp = tempfile.mkdtemp(dir=local_download_dir)
    # Without explicit settings the tarball is encrypted at rest by default.
    encrypt_artifact = settings.encrypt_repacked_artifacts if settings is not None else True

    try:
        source_files = _list_files_to_compress(script, directory) + dependencies
        archive_path = os.path.join(tmp, _TAR_SOURCE_FILENAME)
        tar_file = sagemaker_utils.create_tar_file(source_files, archive_path)

        extra_args = None
        if kms_key or encrypt_artifact:
            # "aws:kms" without a key id selects the default AWS managed KMS key for S3, see
            # https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutObject.html#API_PutObject_RequestSyntax
            extra_args = {"ServerSideEncryption": "aws:kms"}
            if kms_key:
                extra_args["SSEKMSKeyId"] = kms_key

        if s3_resource is not None:
            logger.debug("Using provided s3_resource")
        else:
            s3_resource = session.resource("s3", region_name=session.region_name)

        s3_resource.Object(bucket, key).upload_file(tar_file, ExtraArgs=extra_args)
    finally:
        # The temporary directory is removed whether or not the upload succeeded.
        shutil.rmtree(tmp)

    return UploadedCode(s3_prefix="s3://%s/%s" % (bucket, key), script_name=script_name)
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
def _list_files_to_compress(script, directory):
|
|
484
|
+
"""Placeholder docstring."""
|
|
485
|
+
if directory is None:
|
|
486
|
+
return [script]
|
|
487
|
+
|
|
488
|
+
basedir = directory if directory else os.path.dirname(script)
|
|
489
|
+
return [os.path.join(basedir, name) for name in os.listdir(basedir)]
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
def framework_name_from_image(image_uri):
    # noinspection LongLine
    """Extract the framework and Python version from the image name.

    Args:
        image_uri (str): Image URI, which should be one of the following forms:
            legacy:
            '<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-<fw>-<py_ver>-<device>:<container_version>'
            legacy:
            '<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-<fw>-<py_ver>-<device>:<fw_version>-<device>-<py_ver>'
            current:
            '<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-<fw>:<fw_version>-<device>-<py_ver>'
            current:
            '<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-rl-<fw>:<rl_toolkit><rl_version>-<device>-<py_ver>'
            current:
            '<account>.dkr.ecr.<region>.amazonaws.com/<fw>-<image_scope>:<fw_version>-<device>-<py_ver>'
            current:
            '<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-xgboost:<fw_version>-<container_version>'

    Returns:
        tuple: A tuple containing:

            - str: The framework name
            - str: The Python version
            - str: The image tag
            - str: If the TensorFlow image is script mode

        All four entries are None when the URI matches none of the forms.
    """
    ecr_match = re.compile(sagemaker_utils.ECR_URI_PATTERN).match(image_uri)
    if ecr_match is None:
        return None, None, None, None

    # Group 9 of the ECR pattern is the part matched by the name patterns
    # below, i.e. "<repository>:<tag>".
    repo_and_tag = ecr_match.group(9)

    # Current naming scheme; also accepts the legacy names whose tag already
    # ends in "-<device>-<py_ver>".
    current_format = re.compile(
        r"""^(?:sagemaker(?:-rl)?-)?
        (tensorflow|mxnet|chainer|pytorch|pytorch-trcomp|scikit-learn|xgboost
        |huggingface-tensorflow|huggingface-pytorch
        |huggingface-tensorflow-trcomp|huggingface-pytorch-trcomp)(?:-)?
        (scriptmode|training)?
        :(.*)-(.*?)-(py2|py3\d*)(?:.*)$""",
        re.VERBOSE,
    )
    current = current_format.match(repo_and_tag)
    if current is not None:
        fw, scriptmode, ver, device, py = current.groups()
        return fw, py, "{}-{}-{}".format(ver, device, py), scriptmode

    # Legacy scheme: python version and device are part of the repository name.
    legacy = re.compile(r"^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$").match(
        repo_and_tag
    )
    if legacy is not None:
        legacy_fw, legacy_py, _device, legacy_tag = legacy.groups()
        return (legacy_fw, legacy_py, legacy_tag, None)

    # sagemaker-xgboost images are tagged with two aliases, e.g.:
    #     1. Long tag: "315553699071.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.5-1-cpu-py3"
    #     2. Short tag: "315553699071.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.5-1"
    # Note 1: Both tags point to the same image
    # Note 2: Both tags have full GPU capabilities, despite "cpu" delineation in the long tag
    short_xgboost = re.compile(r"^sagemaker-(xgboost):(.*)$").match(repo_and_tag)
    if short_xgboost is not None:
        xgb_name, xgb_tag = short_xgboost.groups()
        return (xgb_name, "py3", xgb_tag, None)

    return None, None, None, None
|
|
561
|
+
|
|
562
|
+
|
|
563
|
+
def framework_version_from_tag(image_tag):
    """Extract the framework version from the image tag.

    Args:
        image_tag (str): Image tag, which should take the form
            '<framework_version>-<device>-<py_version>'
            '<xgboost_version>-<container_version>'

    Returns:
        str: The framework version, or None when the tag has neither form.
    """
    tag_forms = (
        # '<framework_version>-<device>-<py_version>'
        r"^(.*)-(cpu|gpu)-(py2|py3\d*)$",
        # short xgboost tag: '<xgboost_version>-<container_version>'
        r"^(\d\.\d+\-\d)$",
    )
    # The forms are tried in order; the first capture group of the first form
    # that matches is the version.
    for form in tag_forms:
        found = re.match(form, image_tag)
        if found is not None:
            return found.group(1)
    return None
|
|
580
|
+
|
|
581
|
+
|
|
582
|
+
def model_code_key_prefix(code_location_key_prefix, model_name, image):
    """Returns the s3 key prefix for uploading code during model deployment.

    The location returned is a potential concatenation of 2 parts
        1. code_location_key_prefix if it exists
        2. model_name or a name derived from the image

    Args:
        code_location_key_prefix (str): the s3 key prefix from code_location
        model_name (str): the name of the model
        image (str): the image from which a default name can be extracted

    Returns:
        str: the key prefix to be used in uploading code
    """
    if is_pipeline_variable(image):
        # No name can be derived from a pipeline variable, so a
        # timestamp-based placeholder is used instead.
        fallback_name = f"/model_code/{int(time.time())}"
    else:
        fallback_name = sagemaker_utils.name_from_image(image)

    # An explicit model name takes precedence over the derived one.
    return s3_path_join(code_location_key_prefix, model_name or fallback_name)
|
|
600
|
+
|
|
601
|
+
|
|
602
|
+
def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution):
    """Warn the user about training when it doesn't leverage all the GPU cores.

    A warning is logged when parameter server is enabled and a multi-GPU
    instance is selected, because distributed training with the default
    parameter server setup doesn't support multi-GPU instances.

    Args:
        training_instance_type (str): A string representing the type of training instance selected.
        distribution (dict): A dictionary with information to enable distributed training.
            (Defaults to None if distributed training is not enabled.) For example:

            .. code:: python

                {
                    "parameter_server": {
                        "enabled": True
                    }
                }
    """
    if training_instance_type == "local" or distribution is None:
        return
    if is_pipeline_variable(training_instance_type):
        # The instance type is only known at pipeline execution time,
        # so nothing can be checked here.
        return

    # "local_gpu" and instance families starting with "p" (the second
    # dot-separated component) count as multi-GPU, unless explicitly listed
    # as single-GPU types.
    if training_instance_type == "local_gpu":
        gpu_family = True
    else:
        gpu_family = training_instance_type.split(".")[1].startswith("p")
    is_multi_gpu_instance = (
        gpu_family and training_instance_type not in SINGLE_GPU_INSTANCE_TYPES
    )

    ps_enabled = False
    if "parameter_server" in distribution:
        ps_enabled = distribution["parameter_server"].get("enabled", False)

    if is_multi_gpu_instance and ps_enabled:
        logger.warning(PARAMETER_SERVER_MULTI_GPU_WARNING)
|
|
641
|
+
|
|
642
|
+
|
|
643
|
+
def profiler_config_deprecation_warning(
    profiler_config, image_uri, framework_name, framework_version
):
    """Deprecation message if framework profiling is specified TF >= 2.12 and PT >= 2.0.

    Args:
        profiler_config: Profiler configuration; nothing is checked when it is
            None or its ``framework_profile_params`` attribute is None.
        image_uri (str): Image URI used to derive the framework name and
            version when ``framework_version`` is None.
        framework_name (str): Framework name; only "pytorch" and "tensorflow"
            are checked.
        framework_version (str): Framework version, or None to derive it from
            ``image_uri``.
    """
    if profiler_config is None or profiler_config.framework_profile_params is None:
        return

    if framework_name not in ("pytorch", "tensorflow"):
        return

    if framework_version is None:
        # The version has to come from the image URI. A missing or non-string
        # URI (e.g. a pipeline variable) cannot be parsed, so there is nothing
        # to compare against.
        if not isinstance(image_uri, str):
            return

        framework_name, _, image_tag, _ = framework_name_from_image(image_uri)

        if image_tag is not None:
            framework_version = framework_version_from_tag(image_tag)

    if framework_version is None:
        return

    framework_profile_thresh = (
        version.parse("2.0") if framework_name == "pytorch" else version.parse("2.12")
    )
    framework_profile = version.parse(framework_version)
    if framework_profile >= framework_profile_thresh:
        # Adjacent literals instead of backslash continuations: a continuation
        # inside the string embeds the source indentation in the message.
        deprecation_warn_base(
            "Framework profiling is deprecated from "
            f"{framework_name} version {framework_version}. "
            "No framework metrics will be collected"
        )
|
|
670
|
+
|
|
671
|
+
|
|
672
|
+
def validate_smdistributed(
    instance_type, framework_name, framework_version, py_version, distribution, image_uri=None
):
    """Check if smdistributed strategy is correctly invoked by the user.

    The ``smdistributed`` entry of ``distribution`` must be a dictionary that
    names at most one strategy, and that strategy must be one of the
    supported ones (`dataparallel` or `modelparallel`). For `dataparallel`,
    strategy-specific validations are performed as well.

    Nothing is validated when ``distribution`` has no ``smdistributed`` entry,
    or when the instance type or image URI is a pipeline variable.

    Args:
        instance_type (str): A string representing the type of training instance selected.
        framework_name (str): A string representing the name of framework selected.
        framework_version (str): A string representing the framework version selected.
        py_version (str): A string representing the python version selected.
            Ex: `py38, py39, py310, py311`
        distribution (dict): A dictionary with information to enable distributed training.
            (Defaults to None if distributed training is not enabled.) For example:

            .. code:: python

                {
                    "smdistributed": {
                        "dataparallel": {
                            "enabled": True
                        }
                    }
                }
        image_uri (str): A string representing a Docker image URI.

    Raises:
        ValueError: if distribution dictionary isn't correctly formatted or
            multiple strategies are requested simultaneously or
            an unsupported strategy is requested or
            strategy-specific inputs are incorrect/unsupported
    """
    if "smdistributed" not in distribution:
        # Distribution strategy other than smdistributed is selected
        return
    if is_pipeline_variable(instance_type) or is_pipeline_variable(image_uri):
        # These values only become available at pipeline execution time.
        return

    smdistributed = distribution["smdistributed"]
    if not isinstance(smdistributed, dict):
        raise ValueError("smdistributed strategy requires a dictionary")

    if len(smdistributed) > 1:
        # more than 1 smdistributed strategy requested by the user
        raise ValueError(
            "Cannot use more than 1 smdistributed strategy. \n"
            "Choose one of the following supported strategies:"
            f"{SMDISTRIBUTED_SUPPORTED_STRATEGIES}"
        )

    # At most one key is left at this point; report it when it is not supported.
    unsupported = [
        name for name in smdistributed if name not in SMDISTRIBUTED_SUPPORTED_STRATEGIES
    ]
    if unsupported:
        raise ValueError(
            f"Invalid smdistributed strategy provided: {unsupported[0]} \n"
            f"Supported strategies: {SMDISTRIBUTED_SUPPORTED_STRATEGIES}"
        )

    # smdataparallel-specific input validation
    if "dataparallel" in smdistributed:
        _validate_smdataparallel_args(
            instance_type, framework_name, framework_version, py_version, distribution, image_uri
        )
|
|
750
|
+
|
|
751
|
+
|
|
752
|
+
def _validate_smdataparallel_args(
|
|
753
|
+
instance_type, framework_name, framework_version, py_version, distribution, image_uri=None
|
|
754
|
+
):
|
|
755
|
+
"""Check if request is using unsupported arguments.
|
|
756
|
+
|
|
757
|
+
Validate if user specifies a supported instance type, framework version, and python
|
|
758
|
+
version.
|
|
759
|
+
|
|
760
|
+
Args:
|
|
761
|
+
instance_type (str): A string representing the type of training instance selected. Ex: `ml.p3.16xlarge`
|
|
762
|
+
framework_name (str): A string representing the name of framework selected. Ex: `tensorflow`
|
|
763
|
+
framework_version (str): A string representing the framework version selected. Ex: `2.3.1`
|
|
764
|
+
py_version (str): A string representing the python version selected.
|
|
765
|
+
Ex: `py38, py39, py310, py311`
|
|
766
|
+
distribution (dict): A dictionary with information to enable distributed training.
|
|
767
|
+
(Defaults to None if distributed training is not enabled.) Ex:
|
|
768
|
+
|
|
769
|
+
.. code:: python
|
|
770
|
+
|
|
771
|
+
{
|
|
772
|
+
"smdistributed": {
|
|
773
|
+
"dataparallel": {
|
|
774
|
+
"enabled": True
|
|
775
|
+
}
|
|
776
|
+
}
|
|
777
|
+
}
|
|
778
|
+
image_uri (str): A string representing a Docker image URI.
|
|
779
|
+
|
|
780
|
+
Raises:
|
|
781
|
+
ValueError: if
|
|
782
|
+
`py_version` is not python3 or
|
|
783
|
+
`framework_version` is not in SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSION
|
|
784
|
+
"""
|
|
785
|
+
smdataparallel_enabled = (
|
|
786
|
+
distribution.get("smdistributed").get("dataparallel").get("enabled", False)
|
|
787
|
+
)
|
|
788
|
+
|
|
789
|
+
if not smdataparallel_enabled:
|
|
790
|
+
return
|
|
791
|
+
|
|
792
|
+
err_msg = ""
|
|
793
|
+
|
|
794
|
+
if not instance_type:
|
|
795
|
+
err_msg += "Please specify an instance_type for smdataparallel.\n"
|
|
796
|
+
|
|
797
|
+
if not image_uri:
|
|
798
|
+
# ignore framework_version & py_version if image_uri is set
|
|
799
|
+
# in case image_uri is not set, then both are mandatory
|
|
800
|
+
supported = SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS[framework_name]
|
|
801
|
+
if framework_version not in supported:
|
|
802
|
+
err_msg += (
|
|
803
|
+
f"Provided framework_version {framework_version} is not supported by"
|
|
804
|
+
" smdataparallel.\n"
|
|
805
|
+
f"Please specify one of the supported framework versions: {supported} \n"
|
|
806
|
+
)
|
|
807
|
+
|
|
808
|
+
if "py3" not in py_version:
|
|
809
|
+
err_msg += (
|
|
810
|
+
f"Provided py_version {py_version} is not supported by smdataparallel.\n"
|
|
811
|
+
"Please specify py_version>=py3"
|
|
812
|
+
)
|
|
813
|
+
|
|
814
|
+
if err_msg:
|
|
815
|
+
raise ValueError(err_msg)
|
|
816
|
+
|
|
817
|
+
|
|
818
|
+
def validate_distribution(
    distribution: Dict,
    instance_groups: List[InstanceGroup],
    framework_name: str,
    framework_version: str,
    py_version: str,
    image_uri: str,
    kwargs: Dict,
) -> Dict:
    """Check if distribution strategy is correctly invoked by the user.

    Currently, check for `dataparallel`, `modelparallel` and heterogeneous cluster set up.
    Validate if the user requested strategy is supported.

    Args:
        distribution (dict): A dictionary with information to enable distributed training.
            (Defaults to None if distributed training is not enabled.) For example:

            .. code:: python

                {
                    "smdistributed": {
                        "dataparallel": {
                            "enabled": True
                        }
                    }
                }
        instance_groups ([InstanceGroup]): A list contains instance groups used for training.
        framework_name (str): A string representing the name of framework selected.
        framework_version (str): A string representing the framework version selected.
        py_version (str): A string representing the python version selected.
            Ex: `py38, py39, py310, py311`
        image_uri (str): A string representing a Docker image URI.
        kwargs(dict): Additional kwargs passed to this function. ``instance_type``
            is read for non-heterogeneous jobs and ``entry_point`` is read when
            the framework is "pytorch".

    Returns:
        distribution(dict): a shallow copy of ``distribution`` with validated
            information to enable distributed training. For heterogeneous
            clusters its "instance_groups" entry is replaced by the list of
            instance group names.

    Raises:
        ValueError: if distribution dictionary isn't correctly formatted or
            multiple strategies are requested simultaneously or
            an unsupported strategy is requested or
            strategy-specific inputs are incorrect/unsupported or
            heterogeneous cluster set up is incorrect
    """
    validated_distribution = dict(distribution)

    train_instance_groups = validated_distribution.get("instance_groups", [])
    if instance_groups is None:
        if len(train_instance_groups) >= 1:
            # if estimator's instance_groups is not defined but
            # train_instance_groups are specified in distribution
            raise ValueError("Instance groups not specified in the estimator !")
    else:
        if len(train_instance_groups) > len(instance_groups):
            # if train_instance_groups in distribution are more than estimator's instance_groups
            raise ValueError("Train instance groups oversubscribed !")
        if len(instance_groups) == 1 and len(train_instance_groups) == 0:
            # if just one instance_group but it is not specified in distribution, we set it for user
            train_instance_groups = instance_groups
        elif len(instance_groups) > 1 and len(train_instance_groups) != 1:
            # currently we just support one train instance group
            raise ValueError("Distribution should only contain one instance group name !")

    if len(train_instance_groups) != 0:
        # in this case, we are handling a heterogeneous cluster training job
        instance_group_names = []
        for train_instance_group in train_instance_groups:
            # in future version we will support multiple train_instance_groups, so use loop here
            if train_instance_group not in instance_groups:
                # check if train instance groups belongs to what user defined in estimator set up
                raise ValueError(
                    f"Invalid training instance group {train_instance_group.instance_group_name} !"
                )
            _validate_distribution_for_one_instance_type(
                instance_type=train_instance_group.instance_type,
                distribution=validated_distribution,
                framework_name=framework_name,
                framework_version=framework_version,
                py_version=py_version,
                image_uri=image_uri,
                kwargs=kwargs,
            )
            # get instance group names
            instance_group_names.append(train_instance_group.instance_group_name)
        validated_distribution["instance_groups"] = instance_group_names
    else:
        # in this case, we are handling a normal training job (without heterogeneous cluster)
        instance_type = renamed_kwargs(
            "train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
        )
        _validate_distribution_for_one_instance_type(
            instance_type=instance_type,
            distribution=validated_distribution,
            framework_name=framework_name,
            framework_version=framework_version,
            py_version=py_version,
            image_uri=image_uri,
            kwargs=kwargs,
        )
    return validated_distribution


def _validate_distribution_for_one_instance_type(
    instance_type, distribution, framework_name, framework_version, py_version, image_uri, kwargs
):
    """Run every per-instance-type distribution check, in a fixed order."""
    validate_distribution_for_instance_type(
        instance_type=instance_type,
        distribution=distribution,
    )
    validate_smdistributed(
        instance_type=instance_type,
        framework_name=framework_name,
        framework_version=framework_version,
        py_version=py_version,
        distribution=distribution,
        image_uri=image_uri,
    )
    if framework_name and framework_name == "pytorch":
        # We need to validate only for PyTorch framework
        validate_torch_distributed_distribution(
            instance_type=instance_type,
            distribution=distribution,
            framework_version=framework_version,
            py_version=py_version,
            image_uri=image_uri,
            entry_point=kwargs["entry_point"],
        )
    warn_if_parameter_server_with_multi_gpu(
        training_instance_type=instance_type, distribution=distribution
    )
|
|
953
|
+
|
|
954
|
+
|
|
955
|
+
def validate_distribution_for_instance_type(instance_type, distribution):
    """Check if the provided distribution strategy is supported for the instance_type.

    Only instance types whose family starts with "trn" (Trainium) are
    restricted: they accept either no strategy or exactly one strategy named
    "torch_distributed". Any other instance type, or a non-string
    ``instance_type``, passes without checks.

    Args:
        instance_type (str): A string representing the type of training instance selected.
        distribution (dict): A dictionary with information to enable distributed training.

    Raises:
        ValueError: if a Trainium instance type is combined with several
            strategies or with a strategy other than "torch_distributed".
    """
    if not isinstance(instance_type, str):
        return

    # Group 1 is the instance family, e.g. "trn1" in "ml.trn1.32xlarge".
    family = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
    if not family or not family[1].startswith("trn"):
        return

    strategies = list(distribution.keys())
    if not strategies:
        return

    if len(strategies) > 1:
        raise ValueError(
            "Multiple distribution strategies are not supported for Trainium instances.\n"
            "Please specify one of the following supported distribution strategies:"
            f" {TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES} "
        )

    distribution_strategy = strategies[0]
    if distribution_strategy != "torch_distributed":
        raise ValueError(
            f"Provided distribution strategy {distribution_strategy} is not supported"
            " for Trainium instances.\n"
            "Please specify one of the following supported distribution strategies:"
            f" {TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES} \n"
        )
|
|
987
|
+
|
|
988
|
+
|
|
989
|
+
def validate_torch_distributed_distribution(
    instance_type,
    distribution,
    framework_version,
    py_version,
    image_uri,
    entry_point,
):
    """Check if torch_distributed distribution strategy is correctly invoked by the user.

    Nothing is validated unless ``distribution["torch_distributed"]["enabled"]``
    is truthy. When it is, every problem found is collected and reported in a
    single ``ValueError``.

    Args:
        instance_type (str): A string representing the type of training instance selected.
        distribution (dict): A dictionary with information to enable distributed training.
            (Defaults to None if distributed training is not enabled.) For example:

            .. code:: python

                {
                    "torch_distributed": {
                        "enabled": True
                    }
                }
        framework_version (str): A string representing the framework version selected.
        py_version (str): A string representing the python version selected.
            Ex: `py38, py39, py310, py311`
        image_uri (str): A string representing a Docker image URI.
        entry_point (str or PipelineVariable): The absolute or relative path to the local Python
            source file that should be executed as the entry point to
            training. The ``.py`` extension check is applied to plain strings only.

    Raises:
        ValueError: if
            `py_version` is missing or not python3, or
            `framework_version` is not compatible with instance types, or
            the instance type is neither GPU nor Trainium, or
            `entry_point` is a string that does not end with ``.py``
    """
    # ``distribution`` may be None and the strategy entry may carry no config;
    # both cases mean torch_distributed was not requested.
    torch_distributed_config = (distribution or {}).get("torch_distributed") or {}
    if not torch_distributed_config.get("enabled", False):
        # Distribution strategy other than torch_distributed is selected
        return

    err_msg = ""

    if not image_uri:
        # ignore framework_version and py_version if image_uri is set
        # in case image_uri is not set, then both are mandatory
        # A missing py_version is reported as unsupported rather than crashing
        # on the substring test.
        if not py_version or "py3" not in py_version:
            err_msg += (
                f"Provided py_version {py_version} is not supported by torch_distributed.\n"
                "Please specify py_version>=py3\n"
            )

        # Check instance and framework_version compatibility
        if _is_gpu_instance(instance_type):
            if framework_version not in TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS:
                err_msg += (
                    f"Provided framework_version {framework_version} is not supported by"
                    f" torch_distributed for instance {instance_type}.\n"
                    "Please specify one of the supported framework versions:"
                    f"{TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS} \n"
                )
        elif _is_trainium_instance(instance_type):
            if framework_version not in TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS:
                err_msg += (
                    f"Provided framework_version {framework_version} is not supported by"
                    f" torch_distributed for instance {instance_type}.\n"
                    "Please specify one of the supported framework versions:"
                    f"{TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS} \n"
                )
        else:
            err_msg += (
                "Currently torch_distributed is supported only for GPU and Trainium instances.\n"
            )

    # Check entry point type
    # Only plain strings can be inspected with ``str.endswith``; other values
    # (the docstring allows a PipelineVariable) are passed through unchecked
    # instead of failing with AttributeError.
    if isinstance(entry_point, str) and not entry_point.endswith(".py"):
        err_msg += (
            "Unsupported entry point type for the distribution torch_distributed.\n"
            "Only python programs (*.py) are supported."
        )

    if err_msg:
        raise ValueError(err_msg)
|
|
1073
|
+
|
|
1074
|
+
|
|
1075
|
+
def _is_gpu_instance(instance_type):
|
|
1076
|
+
"""Returns bool indicating whether instance_type supports GPU.
|
|
1077
|
+
|
|
1078
|
+
Args:
|
|
1079
|
+
instance_type (str): Name of the instance_type to check against.
|
|
1080
|
+
|
|
1081
|
+
Returns:
|
|
1082
|
+
bool: Whether or not the instance_type supports GPU
|
|
1083
|
+
"""
|
|
1084
|
+
if isinstance(instance_type, str):
|
|
1085
|
+
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
|
|
1086
|
+
if match:
|
|
1087
|
+
if match[1].startswith("p") or match[1].startswith("g"):
|
|
1088
|
+
return True
|
|
1089
|
+
if instance_type == "local_gpu":
|
|
1090
|
+
return True
|
|
1091
|
+
return False
|
|
1092
|
+
|
|
1093
|
+
|
|
1094
|
+
def _is_trainium_instance(instance_type):
|
|
1095
|
+
"""Returns bool indicating whether instance_type is a Trainium instance.
|
|
1096
|
+
|
|
1097
|
+
Args:
|
|
1098
|
+
instance_type (str): Name of the instance_type to check against.
|
|
1099
|
+
|
|
1100
|
+
Returns:
|
|
1101
|
+
bool: Whether or not the instance_type is a Trainium instance
|
|
1102
|
+
"""
|
|
1103
|
+
if isinstance(instance_type, str):
|
|
1104
|
+
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
|
|
1105
|
+
if match and match[1].startswith("trn"):
|
|
1106
|
+
return True
|
|
1107
|
+
return False
|
|
1108
|
+
|
|
1109
|
+
|
|
1110
|
+
def python_deprecation_warning(framework, latest_supported_version):
    """Fill in the ``PYTHON_2_DEPRECATION_WARNING`` template.

    Args:
        framework: Value substituted for the template's ``framework`` field.
        latest_supported_version: Value substituted for the template's
            ``latest_supported_version`` field.

    Returns:
        The result of formatting ``PYTHON_2_DEPRECATION_WARNING`` with the two
        values (presumably the warning message text).
    """
    template_fields = {
        "framework": framework,
        "latest_supported_version": latest_supported_version,
    }
    return PYTHON_2_DEPRECATION_WARNING.format(**template_fields)
|
|
1115
|
+
|
|
1116
|
+
|
|
1117
|
+
def _region_supports_debugger(region_name):
    """Report whether Amazon SageMaker Debugger is available in a region.

    The lookup is case-insensitive: the name is lower-cased before it is
    compared against ``DEBUGGER_UNSUPPORTED_REGIONS``.

    Args:
        region_name (str): Name of the region to check against.

    Returns:
        bool: True unless the region is listed as unsupported.
    """
    normalized_region = region_name.lower()
    return normalized_region not in DEBUGGER_UNSUPPORTED_REGIONS
|
|
1127
|
+
|
|
1128
|
+
|
|
1129
|
+
def _region_supports_profiler(region_name):
    """Report whether the SageMaker Debugger profiling feature is available in a region.

    The lookup is case-insensitive: the name is lower-cased before it is
    compared against ``PROFILER_UNSUPPORTED_REGIONS``.

    Args:
        region_name (str): Name of the region to check against.

    Returns:
        bool: True unless the region is listed as unsupported.
    """
    normalized_region = region_name.lower()
    return normalized_region not in PROFILER_UNSUPPORTED_REGIONS
|
|
1139
|
+
|
|
1140
|
+
|
|
1141
|
+
def _instance_type_supports_profiler(instance_type):
|
|
1142
|
+
"""Returns bool indicating whether instance_type supports SageMaker Debugger profiling feature.
|
|
1143
|
+
|
|
1144
|
+
Args:
|
|
1145
|
+
instance_type (str): Name of the instance_type to check against.
|
|
1146
|
+
|
|
1147
|
+
Returns:
|
|
1148
|
+
bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature.
|
|
1149
|
+
"""
|
|
1150
|
+
if isinstance(instance_type, str):
|
|
1151
|
+
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
|
|
1152
|
+
if match and match[1].startswith("trn"):
|
|
1153
|
+
return True
|
|
1154
|
+
return False
|
|
1155
|
+
|
|
1156
|
+
|
|
1157
|
+
def validate_version_or_image_args(framework_version, py_version, image_uri):
    """Ensure the caller supplied either an image URI or a full version pair.

    The arguments are acceptable when ``image_uri`` is given, or when both
    ``framework_version`` and ``py_version`` are given.

    Args:
        framework_version (str): The version of the framework.
        py_version (str): A string representing the python version selected.
            Ex: `py38, py39, py310, py311`
        image_uri (str): The URI of the image.

    Raises:
        ValueError: if `image_uri` is None and either `framework_version` or `py_version` is
            None.
    """
    # An explicit image makes the version arguments optional.
    if image_uri is not None:
        return

    versions_complete = framework_version is not None and py_version is not None
    if versions_complete:
        return

    raise ValueError(
        "framework_version or py_version was None, yet image_uri was also None. "
        "Either specify both framework_version and py_version, or specify image_uri."
    )
|