sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/__init__.py +0 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/__init__.py +0 -4
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.47.dist-info/RECORD +0 -35
- sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Sagemaker modules train directory."""
|
|
14
|
+
from __future__ import absolute_import
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Sagemaker modules container drivers directory."""
|
|
14
|
+
from __future__ import absolute_import
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Sagemaker modules container drivers - common directory."""
|
|
14
|
+
from __future__ import absolute_import
|
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module provides utility functions for the container drivers."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import logging
|
|
18
|
+
import sys
|
|
19
|
+
import subprocess
|
|
20
|
+
import traceback
|
|
21
|
+
import json
|
|
22
|
+
|
|
23
|
+
from typing import List, Dict, Any, Tuple, IO, Optional
|
|
24
|
+
|
|
25
|
+
# Initialize logger
|
|
26
|
+
SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20)
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
console_handler = logging.StreamHandler(sys.stdout)
|
|
29
|
+
logger.addHandler(console_handler)
|
|
30
|
+
logger.setLevel(int(SM_LOG_LEVEL))
|
|
31
|
+
|
|
32
|
+
FAILURE_FILE = "/opt/ml/output/failure"
|
|
33
|
+
DEFAULT_FAILURE_MESSAGE = """
|
|
34
|
+
Training Execution failed.
|
|
35
|
+
For more details, see CloudWatch logs at 'aws/sagemaker/TrainingJobs'.
|
|
36
|
+
TrainingJob - {training_job_name}
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
USER_CODE_PATH = "/opt/ml/input/data/code"
|
|
40
|
+
SOURCE_CODE_JSON = "/opt/ml/input/data/sm_drivers/sourcecode.json"
|
|
41
|
+
DISTRIBUTED_JSON = "/opt/ml/input/data/sm_drivers/distributed.json"
|
|
42
|
+
|
|
43
|
+
HYPERPARAMETERS_JSON = "/opt/ml/input/config/hyperparameters.json"
|
|
44
|
+
|
|
45
|
+
SM_EFA_NCCL_INSTANCES = [
|
|
46
|
+
"ml.g4dn.8xlarge",
|
|
47
|
+
"ml.g4dn.12xlarge",
|
|
48
|
+
"ml.g5.48xlarge",
|
|
49
|
+
"ml.p3dn.24xlarge",
|
|
50
|
+
"ml.p4d.24xlarge",
|
|
51
|
+
"ml.p4de.24xlarge",
|
|
52
|
+
"ml.p5.48xlarge",
|
|
53
|
+
"ml.trn1.32xlarge",
|
|
54
|
+
]
|
|
55
|
+
|
|
56
|
+
SM_EFA_RDMA_INSTANCES = [
|
|
57
|
+
"ml.p4d.24xlarge",
|
|
58
|
+
"ml.p4de.24xlarge",
|
|
59
|
+
"ml.trn1.32xlarge",
|
|
60
|
+
]
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def write_failure_file(message: Optional[str] = None):
    """Write a failure file with the message.

    The file at ``FAILURE_FILE`` is only created when it does not already
    exist, so an earlier failure message is never overwritten.

    Args:
        message (Optional[str]): Text to write. When ``None``, the default
            failure message is rendered using the ``TRAINING_JOB_NAME``
            environment variable.
    """
    if message is None:
        # This runs while reporting a failure: a missing environment variable
        # must not raise KeyError here and hide the error being reported.
        training_job_name = os.environ.get("TRAINING_JOB_NAME", "unknown")
        message = DEFAULT_FAILURE_MESSAGE.format(training_job_name=training_job_name)
    if not os.path.exists(FAILURE_FILE):
        with open(FAILURE_FILE, "w") as f:
            f.write(message)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def read_source_code_json(source_code_json: str = SOURCE_CODE_JSON) -> Dict[str, Any]:
    """Read the source code config json file.

    Args:
        source_code_json (str): Path of the JSON file to read.

    Returns:
        Dict[str, Any]: The parsed content, or an empty dict when the file is
            missing or its parsed content is falsy (for example ``null``).
    """
    try:
        # JSON text is UTF-8; do not depend on the container's locale encoding.
        with open(source_code_json, "r", encoding="utf-8") as f:
            source_code_dict = json.load(f) or {}
    except FileNotFoundError:
        source_code_dict = {}
    return source_code_dict
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def read_distributed_json(distributed_json: str = DISTRIBUTED_JSON) -> Dict[str, Any]:
    """Read the distribution config json file.

    Args:
        distributed_json (str): Path of the JSON file to read.

    Returns:
        Dict[str, Any]: The parsed content, or an empty dict when the file is
            missing or its parsed content is falsy (for example ``null``).
    """
    try:
        # JSON text is UTF-8; do not depend on the container's locale encoding.
        with open(distributed_json, "r", encoding="utf-8") as f:
            distributed_dict = json.load(f) or {}
    except FileNotFoundError:
        distributed_dict = {}
    return distributed_dict
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def read_hyperparameters_json(hyperparameters_json: str = HYPERPARAMETERS_JSON) -> Dict[str, Any]:
    """Read the hyperparameters config json file.

    Args:
        hyperparameters_json (str): Path of the JSON file to read.

    Returns:
        Dict[str, Any]: The parsed content, or an empty dict when the file is
            missing or its parsed content is falsy (for example ``null``).
    """
    try:
        # JSON text is UTF-8; do not depend on the container's locale encoding.
        with open(hyperparameters_json, "r", encoding="utf-8") as f:
            hyperparameters_dict = json.load(f) or {}
    except FileNotFoundError:
        hyperparameters_dict = {}
    return hyperparameters_dict
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def get_process_count(process_count: Optional[int] = None) -> int:
    """Get the number of processes to run on each node in the training job.

    The first truthy value wins, in this order: the explicit ``process_count``
    argument, the ``SM_NUM_GPUS`` environment variable, the ``SM_NUM_NEURONS``
    environment variable, and finally ``1``.
    """
    if process_count:
        return process_count

    # Each variable is only read when the previous candidate was falsy.
    gpu_total = int(os.environ.get("SM_NUM_GPUS", 0))
    if gpu_total:
        return gpu_total

    neuron_total = int(os.environ.get("SM_NUM_NEURONS", 0))
    if neuron_total:
        return neuron_total

    return 1
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def hyperparameters_to_cli_args(hyperparameters: Dict[str, Any]) -> List[str]:
    """Convert the hyperparameters to CLI arguments.

    Each entry becomes two consecutive items: ``--<key>`` followed by the
    value after a ``safe_deserialize`` / ``safe_serialize`` round trip.
    """
    flags = []
    for name in hyperparameters:
        normalized = safe_deserialize(hyperparameters[name])
        flags.append(f"--{name}")
        flags.append(safe_serialize(normalized))
    return flags
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def safe_deserialize(data: Any) -> Any:
    """Safely deserialize data from a JSON string.

    Behavior by input:
        * Non-string input is returned unchanged.
        * A string equal to "true" or "false" (compared case-insensitively)
          becomes the matching boolean.
        * Any other string is passed to `json.loads()`; if it is not valid
          JSON, the original string is returned.

    Args:
        data (Any): The value to deserialize.

    Returns:
        Any: The deserialized data, or the original input if it cannot be JSON-decoded.
    """
    if isinstance(data, str):
        folded = data.lower()
        if folded == "true":
            return True
        if folded == "false":
            return False

        try:
            decoded = json.loads(data)
        except json.JSONDecodeError:
            decoded = data
        return decoded

    return data
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def safe_serialize(data: Any) -> str:
    """Serialize the data without wrapping strings in quotes.

    This function handles the following cases:
    1. If `data` is a string, it returns the string as-is without wrapping in quotes.
    2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns
       the JSON-encoded string using `json.dumps()`.
    3. If `data` cannot be serialized (e.g., a custom object or a container that
       references itself), it returns the string representation of the data
       using `str(data)`.

    Args:
        data (Any): The data to serialize.

    Returns:
        str: The serialized JSON-compatible string or the string representation of the input.
    """
    if isinstance(data, str):
        return data
    try:
        return json.dumps(data)
    except (TypeError, ValueError):
        # TypeError: unsupported type. ValueError: json.dumps detected a
        # circular reference. Both fall back to str() as documented above.
        return str(data)
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def get_python_executable() -> str:
    """Get the python executable path.

    Returns:
        str: The value of ``sys.executable`` for the running interpreter.
    """
    interpreter_path = sys.executable
    return interpreter_path
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def log_subprocess_output(pipe: IO[bytes]):
    """Log the output from the subprocess.

    Reads ``pipe`` line by line until ``readline`` returns ``b""`` and logs
    each line, stripped of surrounding whitespace, at INFO level.

    Args:
        pipe (IO[bytes]): Binary stream to read from.
    """
    for line in iter(pipe.readline, b""):
        # Child output is not guaranteed to be valid UTF-8. Replace bytes that
        # cannot be decoded instead of raising UnicodeDecodeError, which would
        # abort log forwarding for the rest of the stream.
        logger.info(line.decode("utf-8", errors="replace").strip())
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def execute_commands(commands: List[str]) -> Tuple[int, str]:
    """Execute the provided commands and return exit code with failure traceback if any.

    The child's stderr is merged into stdout and forwarded through
    ``log_subprocess_output``.

    Args:
        commands (List[str]): Argument list passed to ``subprocess.Popen``.

    Returns:
        Tuple[int, str]: ``(0, "")`` on success; otherwise the non-zero exit
            code and the formatted traceback of the ``CalledProcessError``.
    """
    try:
        child = subprocess.Popen(commands, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        # Drain the pipe before waiting; the context manager closes it afterwards.
        with child.stdout:
            log_subprocess_output(child.stdout)
        status = child.wait()
        if status == 0:
            return status, ""
        # Raised only so that format_exc() below has an active exception to render.
        raise subprocess.CalledProcessError(status, commands)
    except subprocess.CalledProcessError as err:
        # Capture the traceback in case of failure
        failure_trace = traceback.format_exc()
        print(f"Command failed with exit code {err.returncode}. Traceback: {failure_trace}")
        return err.returncode, failure_trace
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def is_worker_node() -> bool:
    """Check if the current node is a worker node.

    Returns:
        bool: True when the ``SM_CURRENT_HOST`` and ``SM_MASTER_ADDR``
            environment variables differ (an unset variable reads as None).
    """
    this_host = os.environ.get("SM_CURRENT_HOST")
    master_host = os.environ.get("SM_MASTER_ADDR")
    return this_host != master_host
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def is_master_node() -> bool:
    """Check if the current node is the master node.

    Returns:
        bool: True when the ``SM_CURRENT_HOST`` and ``SM_MASTER_ADDR``
            environment variables are equal (an unset variable reads as None).
    """
    this_host = os.environ.get("SM_CURRENT_HOST")
    master_host = os.environ.get("SM_MASTER_ADDR")
    return this_host == master_host
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Sagemaker modules container drivers - drivers directory."""
|
|
14
|
+
from __future__ import absolute_import
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module is the entry point for the Basic Script Driver."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import sys
|
|
18
|
+
import json
|
|
19
|
+
import shlex
|
|
20
|
+
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
from typing import List
|
|
23
|
+
|
|
24
|
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
25
|
+
|
|
26
|
+
from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611
|
|
27
|
+
logger,
|
|
28
|
+
get_python_executable,
|
|
29
|
+
execute_commands,
|
|
30
|
+
write_failure_file,
|
|
31
|
+
hyperparameters_to_cli_args,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def create_commands() -> List[str]:
    """Build the argv used to launch the user's entry script.

    Reads the entry script name from ``SM_ENTRY_SCRIPT`` and the JSON-encoded
    hyperparameters from ``SM_HPS``, converts the hyperparameters to CLI
    arguments with ``hyperparameters_to_cli_args`` and assembles the command:

    * ``*.py`` scripts are run with the interpreter returned by
      ``get_python_executable``, with the arguments appended as separate
      argv entries.
    * ``*.sh`` scripts are made executable and run through ``/bin/sh -c``.

    Returns:
        List[str]: The command, suitable for ``subprocess.Popen``.

    Raises:
        KeyError: If ``SM_ENTRY_SCRIPT`` or ``SM_HPS`` is not set.
        ValueError: If the entry script is neither a ``.py`` nor a ``.sh`` file.
    """
    entry_script = os.environ["SM_ENTRY_SCRIPT"]
    hyperparameters = json.loads(os.environ["SM_HPS"])
    python_executable = get_python_executable()

    args = hyperparameters_to_cli_args(hyperparameters)
    if entry_script.endswith(".py"):
        return [python_executable, entry_script, *args]

    if entry_script.endswith(".sh"):
        # The whole line is interpreted by /bin/sh, so every interpolated piece
        # must be quoted -- the script path as well as the arguments. For plain
        # names (letters, digits, "./-_") shlex.quote returns the text
        # unchanged, so the generated command is identical to the unquoted form.
        chmod_target = shlex.quote(entry_script)
        run_target = shlex.quote(f"./{entry_script}")
        args_str = " ".join(shlex.quote(arg) for arg in args)
        return [
            "/bin/sh",
            "-c",
            f"chmod +x {chmod_target} && {run_target} {args_str}",
        ]

    raise ValueError(
        f"Unsupported entry script type: {entry_script}. Only .py and .sh are supported."
    )
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def main():
    """Entry point for the Basic Script Driver.

    Execution Lifecycle:
    1. Build the command with ``create_commands``.
    2. Log the command and run it with ``execute_commands``.
    3. On a non-zero exit code, hand the returned traceback text to
       ``write_failure_file`` and exit the process with that same code.
    """
    command = create_commands()
    logger.info(f"Executing command: {' '.join(command)}")

    exit_code, failure_trace = execute_commands(command)
    if exit_code == 0:
        return

    write_failure_file(failure_trace)
    sys.exit(exit_code)


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module is the entry point for the MPI driver script."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import sys
|
|
18
|
+
import json
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
|
|
21
|
+
try:
    from mpi_utils import (
        start_sshd_daemon,
        bootstrap_master_node,
        bootstrap_worker_node,
        get_mpirun_command,
        write_status_file_to_workers,
        write_env_vars_to_file,
    )
except ImportError:
    # mpi_utils could not be imported. Rather than failing at import time, bind
    # every helper name to a stub so the error only surfaces when an MPI helper
    # is actually called.

    def _mpi_not_available(*args, **kwargs):
        """Stand-in for any mpi_utils helper; always raises ImportError."""
        raise ImportError(
            "MPI distributed training requires the 'mpi_utils' package. "
            "Please install it to use MPI-based distributed training."
        )

    start_sshd_daemon = bootstrap_master_node = bootstrap_worker_node = _mpi_not_available
    get_mpirun_command = write_status_file_to_workers = _mpi_not_available
    write_env_vars_to_file = _mpi_not_available
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
48
|
+
from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611
|
|
49
|
+
logger,
|
|
50
|
+
hyperparameters_to_cli_args,
|
|
51
|
+
get_process_count,
|
|
52
|
+
execute_commands,
|
|
53
|
+
write_failure_file,
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def main():
    """Entry point for the MPI driver script.

    The MPI Driver is responsible for setting up the MPI environment,
    generating the correct mpi commands, and launching the MPI job.

    Execution Lifecycle:
    1. Read ``SM_ENTRY_SCRIPT``, ``SM_DISTRIBUTED_CONFIG``, ``SM_HPS``,
       ``SM_CURRENT_HOST``, ``SM_HOSTS`` and ``SM_MASTER_ADDR``.
    2. Call ``write_env_vars_to_file`` and ``start_sshd_daemon``.
    3. On a worker node (current host differs from the master address):
       call ``bootstrap_worker_node`` and return.
    4. On the master node:
        a. Call ``bootstrap_master_node`` with the non-master hosts.
        b. Generate the MPI command with ``get_mpirun_command`` and append
           the hyperparameters as CLI arguments.
        c. Execute the MPI command with the user script from ``entry_script``.
        d. Write the status file to the worker nodes.
        e. On a non-zero exit code, write the failure file and exit with
           that code.

    Raises:
        KeyError: If one of the required ``SM_*`` environment variables is
            not set.
    """
    entry_script = os.environ["SM_ENTRY_SCRIPT"]
    distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"])
    hyperparameters = json.loads(os.environ["SM_HPS"])

    sm_current_host = os.environ["SM_CURRENT_HOST"]
    sm_hosts = json.loads(os.environ["SM_HOSTS"])
    sm_master_addr = os.environ["SM_MASTER_ADDR"]

    write_env_vars_to_file()
    start_sshd_daemon()

    if sm_current_host != sm_master_addr:
        # Workers only bootstrap; launching mpirun and notifying workers is
        # the master's job (worker_hosts is only meaningful there).
        bootstrap_worker_node(sm_master_addr)
        return

    worker_hosts = [host for host in sm_hosts if host != sm_master_addr]
    bootstrap_master_node(worker_hosts)

    host_count = int(os.environ["SM_HOST_COUNT"])
    # Both config entries are optional: a missing key is treated exactly like
    # an explicit null/empty value instead of raising KeyError.
    requested_processes = int(distributed_config.get("process_count_per_node") or 0)
    process_count = get_process_count(requested_processes)

    # SM_HOSTS was already parsed above; reuse it rather than decoding again.
    host_list = sm_hosts
    if process_count > 1:
        host_list = ["{}:{}".format(host, process_count) for host in sm_hosts]

    mpi_command = get_mpirun_command(
        host_count=host_count,
        host_list=host_list,
        num_processes=process_count,
        additional_options=distributed_config.get("mpi_additional_options") or [],
        entry_script_path=entry_script,
    )

    args = hyperparameters_to_cli_args(hyperparameters)
    mpi_command += args

    logger.info(f"Executing command: {' '.join(mpi_command)}")
    exit_code, error_traceback = execute_commands(mpi_command)
    # Notify workers whether or not the job succeeded; this call runs before
    # the exit-code check below.
    write_status_file_to_workers(worker_hosts)

    if exit_code != 0:
        write_failure_file(error_traceback)
        sys.exit(exit_code)


if __name__ == "__main__":
    main()
|