sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/__init__.py +0 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/__init__.py +0 -4
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.47.dist-info/RECORD +0 -35
- sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,302 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module provides mpi related utility functions for the container drivers."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import sys
|
|
18
|
+
import subprocess
|
|
19
|
+
import time
|
|
20
|
+
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
from typing import List
|
|
23
|
+
|
|
24
|
+
import paramiko
|
|
25
|
+
|
|
26
|
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
27
|
+
|
|
28
|
+
from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611
|
|
29
|
+
SM_EFA_NCCL_INSTANCES,
|
|
30
|
+
SM_EFA_RDMA_INSTANCES,
|
|
31
|
+
get_python_executable,
|
|
32
|
+
logger,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
FINISHED_STATUS_FILE = "/tmp/done.algo-1"
|
|
36
|
+
READY_FILE = "/tmp/ready.%s"
|
|
37
|
+
DEFAULT_SSH_PORT = 22
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _write_file_to_host(host: str, status_file: str) -> bool:
    """Create ``status_file`` on ``host`` by running ``touch`` over SSH.

    Args:
        host: Name of the node on which the file is created.
        status_file: Path of the file to create on that node.

    Returns:
        bool: True when the remote ``touch`` exited with status zero, False
        when the ``ssh`` command exited with a non-zero status.
    """
    logger.info(f"Writing {status_file} to {host}")
    touch_command = ["ssh", host, "touch", f"{status_file}"]
    try:
        # check=True turns a non-zero exit status into CalledProcessError.
        subprocess.run(touch_command, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError:
        logger.info(f"Cannot connect to {host}")
        return False
    logger.info("Finished writing status file")
    return True
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def write_status_file_to_workers(worker_hosts: List[str], status_file: str = FINISHED_STATUS_FILE):
    """Create the status file on every worker node, retrying on failure.

    Each worker is attempted repeatedly with a 5 second pause after every
    failed attempt.

    Args:
        worker_hosts: Names of the worker nodes to write to.
        status_file: Path of the file to create on each worker.

    Raises:
        TimeoutError: If a worker still cannot be written to after the
            sixth failed attempt.
    """
    for worker in worker_hosts:
        failed_attempts = 0
        while True:
            if _write_file_to_host(worker, status_file):
                break
            time.sleep(5)
            failed_attempts += 1
            if failed_attempts > 5:
                raise TimeoutError(f"Timed out waiting for {worker} to be reachable.")
            logger.info(f"Retrying to write status file to {worker}")
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _wait_for_status_file(status_file: str):
    """Block until ``status_file`` exists on the local filesystem.

    The path is polled every 30 seconds; there is no timeout.

    Args:
        status_file: Path whose existence is awaited.
    """
    logger.info(f"Waiting for status file {status_file}")
    while True:
        if os.path.exists(status_file):
            break
        time.sleep(30)
    logger.info(f"Found status file {status_file}")
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def start_sshd_daemon():
    """Launch the SSH server on the current node.

    Raises:
        RuntimeError: If ``/usr/sbin/sshd`` does not exist.
    """
    sshd_path = "/usr/sbin/sshd"
    sshd_present = os.path.exists(sshd_path)
    if not sshd_present:
        raise RuntimeError("SSH daemon not found.")

    # "-D" keeps sshd attached (it does not detach into the background);
    # Popen returns immediately, so the server runs as a child process
    # alongside this driver.
    launch_command = [sshd_path, "-D"]
    subprocess.Popen(launch_command)
    logger.info("Started SSH daemon.")
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy):
    """Missing-host-key policy that trusts only hosts named ``algo-*``.

    Example:
        >>> client = paramiko.SSHClient()
        >>> client.set_missing_host_key_policy(CustomHostKeyPolicy())
        >>> # Accepted: the hostname starts with "algo-"
        >>> client.connect('algo-1234.internal')
        >>> # Rejected: any other unknown hostname
        >>> client.connect('unknown-host')  # raises SSHException
    """

    def missing_host_key(self, client, hostname, key):
        """Record the key of an ``algo-*`` host; refuse every other host.

        Args:
            client: The SSHClient instance whose host-key store is updated.
            hostname: The hostname being connected to.
            key: The host key presented by that host.

        Raises:
            paramiko.SSHException: If ``hostname`` does not start with
                ``algo-``.
        """
        # Guard clause: anything outside the algo-* naming scheme is refused.
        if not hostname.startswith("algo-"):
            raise paramiko.SSHException(f"Unknown host key for {hostname}")
        known_hosts = client.get_host_keys()
        known_hosts.add(hostname, key.get_name(), key)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool:
    """Probe whether an SSH connection to ``host``:``port`` can be opened.

    Args:
        host: Hostname to connect to.
        port: TCP port of the SSH server on that host.

    Returns:
        bool: True if the connection attempt succeeded, False if any
        exception was raised while connecting.
    """
    reachable = False
    try:
        logger.debug("Testing connection to host %s", host)
        with paramiko.SSHClient() as ssh:
            ssh.load_system_host_keys()
            ssh.set_missing_host_key_policy(CustomHostKeyPolicy())
            ssh.connect(host, port=port)
            logger.info("Can connect to host %s", host)
            reachable = True
    except Exception as err:  # pylint: disable=W0703
        # Any failure counts as "not reachable"; callers poll this
        # function in a retry loop.
        logger.info("Cannot connect to host %s", host)
        logger.debug(f"Connection failed with exception: {err}")
    return reachable
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def _wait_for_workers(worker_hosts: List[str], port: int = DEFAULT_SSH_PORT, timeout: int = 300):
    """Block on the master node until every worker is reachable and ready.

    A worker counts as ready when an SSH connection to it succeeds and its
    ready file (``READY_FILE % worker``) exists on the local filesystem.
    The check is repeated every 5 seconds.

    Args:
        worker_hosts: Names of the worker nodes to wait for.
        port: SSH port used for the connection probe.
        timeout: Seconds after which waiting is abandoned.

    Raises:
        TimeoutError: If the workers are not all ready once more than
            ``timeout`` seconds have elapsed.
    """
    started_at = time.time()
    if not worker_hosts:
        logger.info("No worker nodes to connect to.")
        return

    def _is_ready(worker: str) -> bool:
        # The ready file is only looked up when the SSH probe succeeded.
        return _can_connect(worker, port) and os.path.exists(READY_FILE % worker)

    while True:
        logger.info("Master is attempting to connect to all workers...")
        # all() stops at the first worker that is not ready yet.
        if all(_is_ready(worker) for worker in worker_hosts):
            logger.info("Master can connect to all worker nodes.")
            return

        elapsed = time.time() - started_at
        if elapsed > timeout:
            raise TimeoutError("Timed out waiting for workers to be reachable.")

        time.sleep(5)  # pause before the next round of probes
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def _wait_for_master(master_host: str, port: int = DEFAULT_SSH_PORT, timeout: int = 300):
    """Block on a worker node until the master accepts an SSH connection.

    The connection probe is repeated every 5 seconds.

    Args:
        master_host: Name of the master node.
        port: SSH port used for the connection probe.
        timeout: Seconds after which waiting is abandoned.

    Raises:
        TimeoutError: If the master is still unreachable once more than
            ``timeout`` seconds have elapsed.
    """
    started_at = time.time()
    while True:
        logger.info(f"Worker is attempting to connect to the master node {master_host}...")
        master_reachable = _can_connect(master_host, port)
        if master_reachable:
            logger.info(f"Worker can connect to master node {master_host}.")
            return

        elapsed = time.time() - started_at
        if elapsed > timeout:
            raise TimeoutError(f"Timed out waiting for master {master_host} to be reachable.")

        time.sleep(5)  # pause before the next probe
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def bootstrap_worker_node(master_host: str, status_file: str = FINISHED_STATUS_FILE):
    """Run the worker-side handshake with the master node.

    The worker waits until the master is reachable over SSH, creates its
    ready file on the master, and then blocks until ``status_file`` appears
    locally.

    Args:
        master_host: Name of the master node.
        status_file: Local path whose creation ends the wait.
    """
    logger.info("Bootstrapping worker node...")
    _wait_for_master(master_host)

    # The ready file is named after this worker (SM_CURRENT_HOST).
    current_host = os.environ["SM_CURRENT_HOST"]
    ready_file = READY_FILE % current_host
    _write_file_to_host(master_host, ready_file)

    _wait_for_status_file(status_file)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def bootstrap_master_node(worker_hosts: List[str]):
    """Run the master-side handshake: wait for all workers to be ready.

    Args:
        worker_hosts: Names of the worker nodes the master waits for.
    """
    logger.info("Bootstrapping master node...")
    _wait_for_workers(worker_hosts=worker_hosts)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def validate_smddprun() -> bool:
    """Report whether the ``smddprun`` executable can be found on ``PATH``.

    The lookup uses :func:`shutil.which` instead of spawning the external
    ``which`` utility, so it does not fail with ``FileNotFoundError`` on
    images that do not ship a ``which`` binary.

    Returns:
        bool: True if installed
    """
    import shutil  # function-scope import: the module import block is left as is

    return shutil.which("smddprun") is not None
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def validate_smddpmprun() -> bool:
    """Report whether the ``smddpmprun`` executable can be found on ``PATH``.

    The lookup uses :func:`shutil.which` instead of spawning the external
    ``which`` utility, so it does not fail with ``FileNotFoundError`` on
    images that do not ship a ``which`` binary.

    Returns:
        bool: True if ``smddpmprun`` is installed
    """
    import shutil  # function-scope import: the module import block is left as is

    return shutil.which("smddpmprun") is not None
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def write_env_vars_to_file():
    """Append every current environment variable to ``/etc/environment``.

    One ``NAME=value`` line is written per variable; the file is opened in
    append mode, so existing content is kept.
    """
    with open("/etc/environment", "a", encoding="utf-8") as env_file:
        env_file.writelines(f"{key}={value}\n" for key, value in os.environ.items())
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def get_mpirun_command(
    host_count: int,
    host_list: List[str],
    num_processes: int,
    additional_options: List[str],
    entry_script_path: str,
):
    """Assemble the ``mpirun`` argument list that launches the entry script.

    Args:
        host_count: Number of hosts; passed as ``plm_rsh_num_concurrent``.
        host_list: Host entries joined with commas for ``--host``.
        num_processes: Total process count passed to ``-np``.
        additional_options: Extra ``mpirun`` arguments appended after the
            built-in ones; skipped when empty.
        entry_script_path: Script run through ``python -m mpi4py``.

    Returns:
        list: The complete command, one argument per element.

    Raises:
        KeyError: If ``SM_CURRENT_INSTANCE_TYPE`` is not set in the
            environment.
    """
    nic = os.environ.get("SM_NETWORK_INTERFACE_NAME", "eth0")

    command = [
        "mpirun",
        "--host",
        ",".join(host_list),
        "-np",
        str(num_processes),
        "--allow-run-as-root",
        "--tag-output",
    ]

    # MCA parameters, emitted as "-mca <name> <value>" in this order.
    mca_settings = (
        ("btl_tcp_if_include", nic),
        ("oob_tcp_if_include", nic),
        ("plm_rsh_no_tree_spawn", "1"),
        ("pml", "ob1"),
        ("btl", "^openib"),
        ("orte_abort_on_non_zero_status", "1"),
        ("btl_vader_single_copy_mechanism", "none"),
        ("plm_rsh_num_concurrent", str(host_count)),
    )
    for mca_name, mca_value in mca_settings:
        command += ["-mca", mca_name, mca_value]

    # Environment passed to the launched processes via "-x".
    for exported in (f"NCCL_SOCKET_IFNAME={nic}", "LD_LIBRARY_PATH", "PATH"):
        command += ["-x", exported]

    if additional_options:
        command.extend(additional_options)

    current_instance_type = os.environ["SM_CURRENT_INSTANCE_TYPE"]

    # EFA settings
    if current_instance_type in SM_EFA_NCCL_INSTANCES:
        # The simple protocol copes with EFA's out-of-order data delivery.
        command += ["-x", "FI_PROVIDER=efa", "-x", "NCCL_PROTO=simple"]

    if current_instance_type in SM_EFA_RDMA_INSTANCES:
        # Use EFA's RDMA functionality for one-sided and two-sided transfer
        command += ["-x", "FI_EFA_USE_DEVICE_RDMA=1"]

    # AWS credentials are forwarded by name, and only when they are set.
    credential_names = ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN")
    for credential_name in credential_names:
        if credential_name in os.environ:
            command += ["-x", credential_name]

    command.append(get_python_executable())
    command += ["-m", "mpi4py", entry_script_path]
    return command
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module is the entry point for the Torchrun driver script."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import sys
|
|
18
|
+
import json
|
|
19
|
+
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
from typing import List, Tuple
|
|
22
|
+
|
|
23
|
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
24
|
+
|
|
25
|
+
from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611
|
|
26
|
+
logger,
|
|
27
|
+
hyperparameters_to_cli_args,
|
|
28
|
+
get_process_count,
|
|
29
|
+
get_python_executable,
|
|
30
|
+
execute_commands,
|
|
31
|
+
write_failure_file,
|
|
32
|
+
SM_EFA_NCCL_INSTANCES,
|
|
33
|
+
SM_EFA_RDMA_INSTANCES,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def pytorch_version() -> Tuple[int, int]:
    """Return the leading components of the installed PyTorch version.

    The version string is split on ``"."`` and at most the first two pieces
    are converted to integers, e.g. ``"1.13.1+cu117"`` yields ``(1, 13)``.
    ``torch`` is imported inside the function, so it is only loaded when the
    version is actually requested.

    Returns:
        Tuple[int, int]: The ``(major, minor)`` version numbers.
    """
    import torch

    leading_parts = torch.__version__.split(".")[:2]
    return tuple(int(part) for part in leading_parts)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def get_base_pytorch_command() -> List[str]:
    """Get the base Torch Distributed launcher to execute.

    The launcher is chosen from the installed PyTorch version:

    * 1.10 and newer: the ``torchrun`` console script.
    * 1.9: ``<python> -m torch.distributed.run`` (the module that
      ``torchrun`` wraps).
    * older: ``<python> -m torch.distributed.launch``.

    Returns:
        List[str]: The leading tokens of the launch command.
    """
    version = pytorch_version()

    # The ``torchrun`` entry point first shipped with PyTorch 1.10; calling it
    # on 1.9 fails because the script does not exist.
    if version >= (1, 10):
        return ["torchrun"]

    # PyTorch 1.9 provides the elastic launcher only as a module.
    if version >= (1, 9):
        return [f"{get_python_executable()}", "-m", "torch.distributed.run"]

    return [f"{get_python_executable()}", "-m", "torch.distributed.launch"]
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def setup_env():
    """Export the EFA/NCCL environment variables for PyTorch distributed training.

    Reads ``SM_CURRENT_INSTANCE_TYPE`` (required) and
    ``SM_NETWORK_INTERFACE_NAME`` (falls back to ``"eth0"``), then updates
    ``os.environ`` in place.

    Raises:
        KeyError: If ``SM_CURRENT_INSTANCE_TYPE`` is not set.
    """
    env = os.environ
    current_instance = env["SM_CURRENT_INSTANCE_TYPE"]
    interface = env.get("SM_NETWORK_INTERFACE_NAME", "eth0")

    overrides = {}
    if current_instance in SM_EFA_NCCL_INSTANCES:
        # Enable EFA use
        overrides["FI_PROVIDER"] = "efa"
    if current_instance in SM_EFA_RDMA_INSTANCES:
        # Use EFA's RDMA functionality for one-sided and two-sided transfer
        overrides["FI_EFA_USE_DEVICE_RDMA"] = "1"
        # NOTE(review): assumed to belong to the RDMA branch; the reviewed
        # text did not preserve indentation - confirm.
        overrides["RDMAV_FORK_SAFE"] = "1"
    overrides["NCCL_SOCKET_IFNAME"] = str(interface)
    overrides["NCCL_PROTO"] = "simple"

    env.update(overrides)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def create_commands():
    """Assemble the Torch Distributed launch command as a list of tokens.

    Environment variables read:
        * ``SM_ENTRY_SCRIPT``: script passed to the launcher.
        * ``SM_DISTRIBUTED_CONFIG``: JSON object; its
          ``process_count_per_node`` entry is used (a falsy value becomes 0
          before being handed to ``get_process_count``).
        * ``SM_HPS``: JSON hyperparameters, converted to CLI arguments.
        * ``SM_HOST_COUNT``: number of nodes.
        * ``RUN_NEURON_PARALLEL_COMPILE``: when ``"1"``, the command is
          prefixed with ``neuron_parallel_compile``.
        * ``SM_MASTER_ADDR``, ``SM_MASTER_PORT``, ``SM_CURRENT_HOST_RANK``:
          only read when more than one host is used.

    Returns:
        list: The launcher, its options, the entry script and the
        hyperparameter arguments, in that order.
    """
    env = os.environ
    script = env["SM_ENTRY_SCRIPT"]
    dist_cfg = json.loads(env["SM_DISTRIBUTED_CONFIG"])
    hps = json.loads(env["SM_HPS"])

    procs_per_node = get_process_count(int(dist_cfg["process_count_per_node"] or 0))
    node_count = int(env["SM_HOST_COUNT"])

    prefix = ["neuron_parallel_compile"] if env.get("RUN_NEURON_PARALLEL_COMPILE") == "1" else []
    command = [
        *prefix,
        *get_base_pytorch_command(),
        f"--nnodes={node_count}",
        f"--nproc_per_node={procs_per_node}",
    ]

    # Rendezvous details are only needed when several nodes take part.
    if node_count > 1:
        command += [
            f"--master_addr={env['SM_MASTER_ADDR']}",
            f"--master_port={env['SM_MASTER_PORT']}",
            f"--node_rank={env['SM_CURRENT_HOST_RANK']}",
        ]

    command.append(script)
    command += hyperparameters_to_cli_args(hps)
    return command
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def main():
    """Run the PyTorch distributed training script.

    Execution lifecycle:
        1. Set the environment variables for PyTorch distributed training.
        2. Build the Torch Distributed command.
        3. Execute that command, which runs the user's entry script.
        4. On a non-zero exit code, write the failure file and exit with
           that code; otherwise return normally.
    """
    setup_env()
    command = create_commands()
    logger.info(f"Executing command: {' '.join(command)}")

    status, error_trace = execute_commands(command)
    if status == 0:
        return

    # Record the failure before propagating the exit status.
    write_failure_file(error_trace)
    sys.exit(status)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
# Run the driver only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Sagemaker modules container drivers - scripts directory."""
|
|
14
|
+
from __future__ import absolute_import
|