sagemaker-core 1.0.62__py3-none-any.whl → 2.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/__init__.py +2 -0
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2399 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +247 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1220 -0
- sagemaker/core/git_utils.py +415 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2977 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +509 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +770 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1267 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +686 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface-vllm-neuronx.json +38 -0
- sagemaker/core/image_uri_config/huggingface.json +2287 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +252 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +494 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +972 -0
- sagemaker/core/image_uris.py +816 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +197 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +501 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +171 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +423 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +246 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1497 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +239 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +205 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1599 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1310 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +347 -0
- sagemaker/core/remote_function/core/serialization.py +410 -0
- sagemaker/core/remote_function/core/stored_function.py +223 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +102 -0
- sagemaker/core/remote_function/invoke_function.py +167 -0
- sagemaker/core/remote_function/job.py +2121 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +82 -0
- sagemaker/core/telemetry/telemetry_logging.py +285 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +345 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +514 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.3.1.dist-info/RECORD +351 -0
- sagemaker_core-2.3.1.dist-info/top_level.txt +1 -0
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.62.dist-info/RECORD +0 -35
- sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
- {sagemaker_core/helper → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/modules/local_core}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Utilities related to logging."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
import time
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class _UTCFormatter(logging.Formatter):
|
|
21
|
+
"""Class that overrides the default local time provider in log formatter."""
|
|
22
|
+
|
|
23
|
+
converter = time.gmtime
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def get_logger():
    """Return the logger named ``sagemaker.remote_function``.

    The first time the logger is requested (it has no handlers yet) it is set
    to ``INFO``, given a ``StreamHandler`` formatted by ``_UTCFormatter``, and
    detached from the root logger. Later calls find the handler already in
    place and return the logger unchanged.

    Returns:
        logging.Logger: The configured ``sagemaker.remote_function`` logger.
    """
    sagemaker_logger = logging.getLogger("sagemaker.remote_function")
    if not sagemaker_logger.handlers:
        sagemaker_logger.setLevel(logging.INFO)
        handler = logging.StreamHandler()
        formatter = _UTCFormatter("%(asctime)s %(name)s %(levelname)-8s %(message)s")
        handler.setFormatter(formatter)
        sagemaker_logger.addHandler(handler)
        # don't stream logs with the root logger handler
        sagemaker_logger.propagate = False

    return sagemaker_logger
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Sagemaker modules container_drivers directory."""
|
|
14
|
+
from __future__ import absolute_import
|
|
@@ -0,0 +1,605 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""An entry point for runtime environment. This must be kept independent of SageMaker PySDK"""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import argparse
|
|
17
|
+
import getpass
|
|
18
|
+
import json
|
|
19
|
+
import multiprocessing
|
|
20
|
+
import os
|
|
21
|
+
import pathlib
|
|
22
|
+
import shutil
|
|
23
|
+
import subprocess
|
|
24
|
+
import sys
|
|
25
|
+
from typing import Any, Dict
|
|
26
|
+
|
|
27
|
+
if __package__ is None or __package__ == "":
|
|
28
|
+
from runtime_environment_manager import (
|
|
29
|
+
RuntimeEnvironmentManager,
|
|
30
|
+
_DependencySettings,
|
|
31
|
+
get_logger,
|
|
32
|
+
)
|
|
33
|
+
else:
|
|
34
|
+
from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import (
|
|
35
|
+
RuntimeEnvironmentManager,
|
|
36
|
+
_DependencySettings,
|
|
37
|
+
get_logger,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
# Process exit statuses.
SUCCESS_EXIT_CODE = 0
DEFAULT_FAILURE_CODE = 1

# Channel names and file/directory names used while bootstrapping.
REMOTE_FUNCTION_WORKSPACE = "sm_rf_user_ws"
SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME = "pre_exec_script_and_dependencies"
JOB_REMOTE_FUNCTION_WORKSPACE = "sagemaker_remote_function_workspace"
PRE_EXECUTION_SCRIPT_NAME = "pre_exec.sh"

# Container paths. BASE_CHANNEL_PATH has the same value as SM_INPUT_DATA_DIR
# and FAILURE_REASON_PATH the same value as SM_OUTPUT_FAILURE.
BASE_CHANNEL_PATH = "/opt/ml/input/data"
FAILURE_REASON_PATH = "/opt/ml/output/failure"
JOB_OUTPUT_DIRS = ["/opt/ml/input", "/opt/ml/output", "/opt/ml/model", "/tmp"]

SM_MODEL_DIR = "/opt/ml/model"

SM_INPUT_DIR = "/opt/ml/input"
SM_INPUT_DATA_DIR = "/opt/ml/input/data"
SM_INPUT_CONFIG_DIR = "/opt/ml/input/config"

SM_OUTPUT_DIR = "/opt/ml/output"
SM_OUTPUT_FAILURE = "/opt/ml/output/failure"
SM_OUTPUT_DATA_DIR = "/opt/ml/output/data"

# Host name and port of the master node.
SM_MASTER_ADDR = "algo-1"
SM_MASTER_PORT = 7777

RESOURCE_CONFIG = f"{SM_INPUT_CONFIG_DIR}/resourceconfig.json"
ENV_OUTPUT_FILE = "/opt/ml/input/sm_training.env"

# Keywords treated as sensitive, and the placeholder shown in place of a value.
# NOTE(review): presumably matched against environment variable names -- confirm at the use site.
SENSITIVE_KEYWORDS = ["SECRET", "PASSWORD", "KEY", "TOKEN", "PRIVATE", "CREDS", "CREDENTIALS"]
HIDDEN_VALUE = "******"

# Instance types in the EFA NCCL list.
SM_EFA_NCCL_INSTANCES = [
    "ml.g4dn.8xlarge",
    "ml.g4dn.12xlarge",
    "ml.g5.48xlarge",
    "ml.p3dn.24xlarge",
    "ml.p4d.24xlarge",
    "ml.p4de.24xlarge",
    "ml.p5.48xlarge",
    "ml.trn1.32xlarge",
]

# Instance types in the EFA RDMA list (each also appears in the NCCL list).
SM_EFA_RDMA_INSTANCES = [
    "ml.p4d.24xlarge",
    "ml.p4de.24xlarge",
    "ml.trn1.32xlarge",
]
|
|
86
|
+
|
|
87
|
+
logger = get_logger()
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def _bootstrap_runtime_env_for_remote_function(
    client_python_version: str,
    conda_env: str = None,
    dependency_settings: _DependencySettings = None,
):
    """Prepare the job container for a remote function invocation.

    Unpacks the user workspace from the default workspace channel. When a
    workspace is found, its pre-execution script (if any) is run and then its
    dependencies are installed. When none is found, that is logged and
    nothing else happens.

    Args:
        client_python_version (str): Python version at the client side.
        conda_env (str): conda environment to be activated. Default is None.
        dependency_settings (_DependencySettings): Settings for installing
            dependencies. Default is None.
    """
    unpacked_workspace = _unpack_user_workspace()
    if unpacked_workspace:
        _handle_pre_exec_scripts(unpacked_workspace)
        _install_dependencies(
            dependency_file_dir=unpacked_workspace,
            conda_env=conda_env,
            client_python_version=client_python_version,
            channel_name=REMOTE_FUNCTION_WORKSPACE,
            dependency_settings=dependency_settings,
        )
        return

    logger.info("No workspace to unpack and setup.")
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _bootstrap_runtime_env_for_pipeline_step(
    client_python_version: str,
    func_step_workspace: str,
    conda_env: str = None,
    dependency_settings: _DependencySettings = None,
):
    """Bootstrap runtime environment for pipeline step invocation.

    Unpacks the step workspace (or creates an empty one when there is nothing
    to unpack), copies every entry of the pre-execution-script-and-dependencies
    channel into it, runs the pre-execution script if present and installs
    dependencies. Returns early when that channel directory does not exist.

    Args:
        client_python_version (str): Python version at the client side.
        func_step_workspace (str): Name of the channel directory under
            ``BASE_CHANNEL_PATH`` that holds the FunctionStep workspace archive.
        conda_env (str): conda environment to be activated. Default is None.
        dependency_settings (_DependencySettings): Settings for installing
            dependencies. Default is None.
    """
    workspace_dir = _unpack_user_workspace(func_step_workspace)
    if not workspace_dir:
        # Nothing was unpacked: fall back to an empty workspace under the cwd.
        workspace_dir = pathlib.Path(os.getcwd(), JOB_REMOTE_FUNCTION_WORKSPACE).absolute()
        # exist_ok: a directory that is already there must not abort the bootstrap.
        os.makedirs(workspace_dir, exist_ok=True)

    pre_exec_script_and_dependencies_dir = os.path.join(
        BASE_CHANNEL_PATH, SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME
    )

    if not os.path.exists(pre_exec_script_and_dependencies_dir):
        logger.info("No dependencies to bootstrap")
        return

    # Bring the pre-execution script and dependency file next to the workspace content.
    for file in os.listdir(pre_exec_script_and_dependencies_dir):
        src_path = os.path.join(pre_exec_script_and_dependencies_dir, file)
        dest_path = os.path.join(workspace_dir, file)
        shutil.copy(src_path, dest_path)

    _handle_pre_exec_scripts(workspace_dir)

    _install_dependencies(
        workspace_dir,
        conda_env,
        client_python_version,
        SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME,
        dependency_settings,
    )
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _handle_pre_exec_scripts(script_file_dir: str):
    """Run the pre-execution script found in a directory, if there is one.

    Args:
        script_file_dir (str): Directory in the container that may contain the
            pre-execution script (``PRE_EXECUTION_SCRIPT_NAME``).
    """
    script_path = os.path.join(script_file_dir, PRE_EXECUTION_SCRIPT_NAME)
    if not os.path.isfile(script_path):
        # No script was shipped; nothing to run.
        return

    RuntimeEnvironmentManager().run_pre_exec_script(pre_exec_script_path=script_path)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _install_dependencies(
    dependency_file_dir: str,
    conda_env: str,
    client_python_version: str,
    channel_name: str,
    dependency_settings: _DependencySettings = None,
):
    """Install dependencies in the job container.

    Three cases are handled:

    * ``dependency_settings`` is given and names no dependency file: nothing
      is installed.
    * ``dependency_settings`` names a dependency file: that file, resolved
      against ``dependency_file_dir``, is bootstrapped.
    * ``dependency_settings`` is None: the first ``.txt``/``.yml``/``.yaml``
      entry (in sorted name order) of ``dependency_file_dir`` is bootstrapped,
      or nothing is installed when there is no such entry.

    Args:
        dependency_file_dir (str): Directory in the container where dependency file exists.
        conda_env (str): conda environment to be activated.
        client_python_version (str): Python version at the client side.
        channel_name (str): Channel where dependency file was uploaded. Only used
            in the "no dependency file found" log message.
        dependency_settings (_DependencySettings): Settings for installing
            dependencies. Default is None.
    """
    if dependency_settings is not None:
        if dependency_settings.dependency_file is None:
            # settings without a dependency file mean no dependencies were specified
            logger.info("No dependencies to install.")
            return
        dependencies_file = os.path.join(dependency_file_dir, dependency_settings.dependency_file)
    else:
        # no dependency file name is passed when a legacy version of the SDK is used
        # we look for a file with .txt, .yml or .yaml extension in the workspace directory.
        # os.listdir returns entries in arbitrary order, so sort to make the pick deterministic.
        dependencies_file = next(
            (
                os.path.join(dependency_file_dir, file)
                for file in sorted(os.listdir(dependency_file_dir))
                if file.endswith((".txt", ".yml", ".yaml"))
            ),
            None,
        )
        if dependencies_file is None:
            logger.info(
                "Did not find any dependency file in the directory at '%s'."
                " Assuming no additional dependencies to install.",
                os.path.join(BASE_CHANNEL_PATH, channel_name),
            )
            return

    RuntimeEnvironmentManager().bootstrap(
        local_dependencies_file=dependencies_file,
        conda_env=conda_env,
        client_python_version=client_python_version,
    )
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def _unpack_user_workspace(func_step_workspace: str = None):
    """Unpack the user workspace archive into the current working directory.

    The archive is looked up at ``<BASE_CHANNEL_PATH>/<workspace>/workspace.zip``, where
    ``<workspace>`` is ``func_step_workspace`` when it is given and
    ``REMOTE_FUNCTION_WORKSPACE`` otherwise.

    Args:
        func_step_workspace (str): Name of the directory under ``BASE_CHANNEL_PATH`` that
            holds the archive. Falls back to ``REMOTE_FUNCTION_WORKSPACE`` when empty or None.

    Returns:
        pathlib.Path: ``<cwd>/<JOB_REMOTE_FUNCTION_WORKSPACE>`` once the archive has been
            unpacked, or None when the directory or the archive does not exist.
    """
    workspace_dir_name = func_step_workspace or REMOTE_FUNCTION_WORKSPACE
    workspace_archive_dir_path = os.path.join(BASE_CHANNEL_PATH, workspace_dir_name)
    if not os.path.exists(workspace_archive_dir_path):
        logger.info(
            "Directory '%s' does not exist.",
            workspace_archive_dir_path,
        )
        return None

    workspace_archive_path = os.path.join(workspace_archive_dir_path, "workspace.zip")
    if not os.path.isfile(workspace_archive_path):
        # Report the archive that is missing, not the directory that was found above.
        logger.info(
            "Workspace archive '%s' does not exist.",
            workspace_archive_path,
        )
        return None

    workspace_unpack_dir = pathlib.Path(os.getcwd()).absolute()
    shutil.unpack_archive(filename=workspace_archive_path, extract_dir=workspace_unpack_dir)
    logger.info("Successfully unpacked workspace archive at '%s'.", workspace_unpack_dir)
    # The returned path is the JOB_REMOTE_FUNCTION_WORKSPACE entry below the unpack directory.
    return pathlib.Path(workspace_unpack_dir, JOB_REMOTE_FUNCTION_WORKSPACE)
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def _write_failure_reason_file(failure_msg):
    """Record why bootstrapping the runtime environment failed.

    The reason is written to ``FAILURE_REASON_PATH``. An already existing file is left
    untouched.

    See: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html

    Args:
        failure_msg: Text written after the "RuntimeEnvironmentError: " prefix.
    """
    if os.path.exists(FAILURE_REASON_PATH):
        # never overwrite a failure file that is already there
        return

    with open(FAILURE_REASON_PATH, "w") as failure_file:
        failure_file.write("RuntimeEnvironmentError: " + failure_msg)
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def _parse_args(sys_args):
|
|
269
|
+
"""Parses CLI arguments."""
|
|
270
|
+
parser = argparse.ArgumentParser()
|
|
271
|
+
parser.add_argument("--job_conda_env", type=str)
|
|
272
|
+
parser.add_argument("--client_python_version", type=str)
|
|
273
|
+
parser.add_argument("--client_sagemaker_pysdk_version", type=str, default=None)
|
|
274
|
+
parser.add_argument("--pipeline_execution_id", type=str)
|
|
275
|
+
parser.add_argument("--dependency_settings", type=str)
|
|
276
|
+
parser.add_argument("--func_step_s3_dir", type=str)
|
|
277
|
+
parser.add_argument("--distribution", type=str, default=None)
|
|
278
|
+
parser.add_argument("--user_nproc_per_node", type=str, default=None)
|
|
279
|
+
args, _ = parser.parse_known_args(sys_args)
|
|
280
|
+
return args
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def log_key_value(key: str, value: str):
    """Log ``key=value`` at info level, hiding sensitive content.

    * A key containing one of ``SENSITIVE_KEYWORDS`` (case-insensitive) is logged with
      ``HIDDEN_VALUE`` instead of its value.
    * A dict value, or a string that decodes to a JSON object, is passed through
      ``mask_sensitive_info`` and logged as JSON.
    * Any other JSON-decodable string is logged in its decoded form; everything else is
      logged unchanged.
    """
    if any(keyword.lower() in key.lower() for keyword in SENSITIVE_KEYWORDS):
        logger.info("%s=%s", key, HIDDEN_VALUE)
        return

    if isinstance(value, dict):
        logger.info("%s=%s", key, json.dumps(mask_sensitive_info(value)))
        return

    try:
        parsed = json.loads(value)
        if isinstance(parsed, dict):
            rendered = json.dumps(mask_sensitive_info(parsed))
        else:
            rendered = parsed
        logger.info("%s=%s", key, rendered)
    except (json.JSONDecodeError, TypeError):
        # not JSON text (or not text at all): log the value untouched
        logger.info("%s=%s", key, value)
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def log_env_variables(env_vars_dict: Dict[str, Any]):
    """Log every variable of the process environment, then every entry of env_vars_dict.

    Each pair goes through ``log_key_value``.

    Args:
        env_vars_dict (Dict[str, Any]): Additional variables to log after ``os.environ``.
    """
    for variables in (os.environ, env_vars_dict):
        for name, content in variables.items():
            log_key_value(name, content)
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
def mask_sensitive_info(data):
    """Recursively mask sensitive information in a dictionary.

    A string value is replaced by ``HIDDEN_VALUE`` when its key contains one of
    ``SENSITIVE_KEYWORDS`` (case-insensitive). Nested dictionaries are processed the same
    way. The input is not modified: masking happens on a copy, so logging a dictionary
    never alters the data the caller keeps using.

    Args:
        data: The object to mask. Anything that is not a dict is returned unchanged.

    Returns:
        A new dict with sensitive string values masked, or ``data`` itself when it is
        not a dict.
    """
    if not isinstance(data, dict):
        return data

    masked = {}
    for k, v in data.items():
        if isinstance(v, dict):
            masked[k] = mask_sensitive_info(v)
        elif isinstance(v, str) and any(
            keyword.lower() in k.lower() for keyword in SENSITIVE_KEYWORDS
        ):
            masked[k] = HIDDEN_VALUE
        else:
            masked[k] = v
    return masked
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def num_cpus() -> int:
    """Report the CPU count given by ``multiprocessing.cpu_count()``.

    Returns:
        int: Number of CPUs reported by ``multiprocessing.cpu_count()``.
    """
    cpu_total = multiprocessing.cpu_count()
    return cpu_total
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def num_gpus() -> int:
    """Count the GPUs listed by ``nvidia-smi --list-gpus``.

    Returns:
        int: Number of output lines that start with "GPU ". 0 when ``nvidia-smi`` cannot
            be launched or exits with a non-zero status.
    """
    try:
        listing = subprocess.check_output(["nvidia-smi", "--list-gpus"]).decode("utf-8")
    except (OSError, subprocess.CalledProcessError):
        logger.info("No GPUs detected (normal if no gpus installed)")
        return 0

    gpu_lines = [entry for entry in listing.splitlines() if entry.startswith("GPU ")]
    return len(gpu_lines)
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
def num_neurons() -> int:
    """Return the number of neuron cores reported by ``neuron-ls -j``.

    The JSON output of ``neuron-ls -j`` is iterated and the ``nc_count`` of every
    entry is summed (entries without ``nc_count`` count as 0).

    Returns:
        int: Total number of neuron cores. 0 when ``neuron-ls`` cannot be launched or
            exits with a non-zero status.
    """
    try:
        cmd = ["neuron-ls", "-j"]
        # stderr is folded into stdout so a failure message ends up in e.output below
        output = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8")
        j = json.loads(output)
        neuron_cores = sum(item.get("nc_count", 0) for item in j)
        logger.info("Found %s neurons on this instance", neuron_cores)
        return neuron_cores
    except OSError:
        logger.info("No Neurons detected (normal if no neurons installed)")
        return 0
    except subprocess.CalledProcessError as e:
        if e.output is not None:
            try:
                # keep only the text that follows "error=" in the tool output
                msg = e.output.decode("utf-8").partition("error=")[2]
                # Adjacent literals instead of a backslash continuation, which would
                # embed the source indentation in the logged message.
                logger.info(
                    "No Neurons detected (normal if no neurons installed). "
                    "If neuron installed then %s",
                    msg,
                )
            except AttributeError:
                logger.info("No Neurons detected (normal if no neurons installed)")
        else:
            logger.info("No Neurons detected (normal if no neurons installed)")

        return 0
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
def safe_serialize(data):
    """Serialize the data without wrapping strings in quotes.

    This function handles the following cases:
    1. If `data` is a string, it returns the string as-is without wrapping in quotes.
    2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns
       the JSON-encoded string using `json.dumps()`.
    3. If `data` cannot be serialized (e.g., a custom object, or a container that
       references itself), it returns the string representation of the data using
       `str(data)`.

    Args:
        data (Any): The data to serialize.

    Returns:
        str: The serialized JSON-compatible string or the string representation of the input.
    """
    if isinstance(data, str):
        return data
    try:
        return json.dumps(data)
    except (TypeError, ValueError):
        # TypeError: unsupported type; ValueError: circular reference detected
        return str(data)
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
def set_env(
    resource_config: Dict[str, Any],
    distribution: str = None,
    user_nproc_per_node: str = None,
    output_file: str = ENV_OUTPUT_FILE,
):
    """Set environment variables for the training job container.

    Collects the ``SM_*`` variables (plus launcher-specific ones for ``torchrun`` and
    ``mpirun``), writes them to ``output_file`` as ``export KEY='value'`` lines and logs
    them through ``log_env_variables``.

    Args:
        resource_config (Dict[str, Any]): Resource configuration for the training job. The
            keys ``current_host``, ``current_instance_type``, ``hosts`` and
            ``network_interface_name`` are read.
        distribution (str): ``"torchrun"`` or ``"mpirun"`` to add the variables of that
            launcher. Any other value adds none.
        user_nproc_per_node (str): Requested number of processes per node, as an integer or
            a string holding one. When None or not positive, the number of GPUs is used,
            then the number of neuron cores, then the number of CPUs.
        output_file (str): Output file to write the environment variables.

    Raises:
        KeyError: If ``resource_config`` lacks one of the keys above or the
            ``TRAINING_JOB_NAME`` environment variable is not set.
        ValueError: If ``current_host`` is not listed in ``hosts`` or
            ``user_nproc_per_node`` is not an integer.
    """
    # Constants
    env_vars = {
        "SM_MODEL_DIR": SM_MODEL_DIR,
        "SM_INPUT_DIR": SM_INPUT_DIR,
        "SM_INPUT_DATA_DIR": SM_INPUT_DATA_DIR,
        "SM_INPUT_CONFIG_DIR": SM_INPUT_CONFIG_DIR,
        "SM_OUTPUT_DIR": SM_OUTPUT_DIR,
        "SM_OUTPUT_FAILURE": SM_OUTPUT_FAILURE,
        "SM_OUTPUT_DATA_DIR": SM_OUTPUT_DATA_DIR,
        "SM_MASTER_ADDR": SM_MASTER_ADDR,
        "SM_MASTER_PORT": SM_MASTER_PORT,
    }

    # Host Variables
    current_host = resource_config["current_host"]
    current_instance_type = resource_config["current_instance_type"]
    hosts = resource_config["hosts"]
    # the rank of a host is its position in the sorted host list
    sorted_hosts = sorted(hosts)

    env_vars["SM_CURRENT_HOST"] = current_host
    env_vars["SM_CURRENT_INSTANCE_TYPE"] = current_instance_type
    env_vars["SM_HOSTS"] = sorted_hosts
    env_vars["SM_NETWORK_INTERFACE_NAME"] = resource_config["network_interface_name"]
    env_vars["SM_HOST_COUNT"] = len(sorted_hosts)
    env_vars["SM_CURRENT_HOST_RANK"] = sorted_hosts.index(current_host)

    env_vars["SM_NUM_CPUS"] = num_cpus()
    env_vars["SM_NUM_GPUS"] = num_gpus()
    env_vars["SM_NUM_NEURONS"] = num_neurons()

    # Misc.
    env_vars["SM_RESOURCE_CONFIG"] = resource_config

    # An explicit positive user value wins; otherwise one process per accelerator,
    # falling back to one per CPU.
    if user_nproc_per_node is not None and int(user_nproc_per_node) > 0:
        env_vars["SM_NPROC_PER_NODE"] = int(user_nproc_per_node)
    elif int(env_vars["SM_NUM_GPUS"]) > 0:
        env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_GPUS"])
    elif int(env_vars["SM_NUM_NEURONS"]) > 0:
        env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_NEURONS"])
    else:
        env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_CPUS"])

    # All Training Environment Variables
    env_vars["SM_TRAINING_ENV"] = {
        "current_host": env_vars["SM_CURRENT_HOST"],
        "current_instance_type": env_vars["SM_CURRENT_INSTANCE_TYPE"],
        "hosts": env_vars["SM_HOSTS"],
        "host_count": env_vars["SM_HOST_COUNT"],
        "nproc_per_node": env_vars["SM_NPROC_PER_NODE"],
        "master_addr": env_vars["SM_MASTER_ADDR"],
        "master_port": env_vars["SM_MASTER_PORT"],
        "input_config_dir": env_vars["SM_INPUT_CONFIG_DIR"],
        "input_data_dir": env_vars["SM_INPUT_DATA_DIR"],
        "input_dir": env_vars["SM_INPUT_DIR"],
        "job_name": os.environ["TRAINING_JOB_NAME"],
        "model_dir": env_vars["SM_MODEL_DIR"],
        "network_interface_name": env_vars["SM_NETWORK_INTERFACE_NAME"],
        "num_cpus": env_vars["SM_NUM_CPUS"],
        "num_gpus": env_vars["SM_NUM_GPUS"],
        "num_neurons": env_vars["SM_NUM_NEURONS"],
        "output_data_dir": env_vars["SM_OUTPUT_DATA_DIR"],
        "resource_config": env_vars["SM_RESOURCE_CONFIG"],
    }

    if distribution == "torchrun":
        logger.info("Distribution: torchrun")

        instance_type = env_vars["SM_CURRENT_INSTANCE_TYPE"]
        network_interface_name = env_vars.get("SM_NETWORK_INTERFACE_NAME", "eth0")

        if instance_type in SM_EFA_NCCL_INSTANCES:
            # Enable EFA use
            env_vars["FI_PROVIDER"] = "efa"
        if instance_type in SM_EFA_RDMA_INSTANCES:
            # Use EFA's RDMA functionality for one-sided and two-sided transfer
            env_vars["FI_EFA_USE_DEVICE_RDMA"] = "1"
            env_vars["RDMAV_FORK_SAFE"] = "1"
        env_vars["NCCL_SOCKET_IFNAME"] = str(network_interface_name)
        env_vars["NCCL_PROTO"] = "simple"
    elif distribution == "mpirun":
        logger.info("Distribution: mpirun")

        env_vars["MASTER_ADDR"] = env_vars["SM_MASTER_ADDR"]
        env_vars["MASTER_PORT"] = str(env_vars["SM_MASTER_PORT"])

        # "<host>:<processes per node>" entries, comma separated
        host_list = [
            "{}:{}".format(host, int(env_vars["SM_NPROC_PER_NODE"])) for host in sorted_hosts
        ]
        env_vars["SM_HOSTS_LIST"] = ",".join(host_list)

        instance_type = env_vars["SM_CURRENT_INSTANCE_TYPE"]

        # The values are "-x NAME=value" fragments, empty when the instance type does
        # not qualify.
        if instance_type in SM_EFA_NCCL_INSTANCES:
            env_vars["SM_FI_PROVIDER"] = "-x FI_PROVIDER=efa"
            env_vars["SM_NCCL_PROTO"] = "-x NCCL_PROTO=simple"
        else:
            env_vars["SM_FI_PROVIDER"] = ""
            env_vars["SM_NCCL_PROTO"] = ""

        if instance_type in SM_EFA_RDMA_INSTANCES:
            env_vars["SM_FI_EFA_USE_DEVICE_RDMA"] = "-x FI_EFA_USE_DEVICE_RDMA=1"
        else:
            env_vars["SM_FI_EFA_USE_DEVICE_RDMA"] = ""

    with open(output_file, "w") as f:
        for key, value in env_vars.items():
            # A single quote inside the value would end the quoted shell string early:
            # close the quote, emit an escaped quote, reopen. Values without a single
            # quote are written exactly as before.
            quoted_value = safe_serialize(value).replace("'", "'\\''")
            f.write(f"export {key}='{quoted_value}'\n")

    logger.info("Environment Variables:")
    log_env_variables(env_vars_dict=env_vars)
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
def main(sys_args=None):
    """Entry point for bootstrap script

    Parses the arguments, validates the Python version, bootstraps the runtime
    environment (pipeline step or remote function), validates the SageMaker SDK version
    and, when ``RESOURCE_CONFIG`` exists, writes the environment variables via ``set_env``.

    The process always ends through ``sys.exit``: with ``SUCCESS_EXIT_CODE`` when all steps
    completed, with ``DEFAULT_FAILURE_CODE`` otherwise. On failure the reason is also
    written by ``_write_failure_reason_file``.

    Args:
        sys_args: List of command-line argument strings, without the program name.
    """

    exit_code = DEFAULT_FAILURE_CODE

    try:
        args = _parse_args(sys_args)

        logger.info("Arguments:")
        for arg_name, arg_value in vars(args).items():
            logger.info("%s=%s", arg_name, arg_value)

        py_version = args.client_python_version
        dependency_settings = _DependencySettings.from_string(args.dependency_settings)

        # the command-line value takes precedence over the environment variable
        conda_env = args.job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV")

        RuntimeEnvironmentManager()._validate_python_version(py_version, conda_env)

        user = getpass.getuser()
        if user != "root":
            logger.info(
                "The job is running on non-root user: %s. Adding write permissions to the "
                "following job output directories: %s.",
                user,
                JOB_OUTPUT_DIRS,
            )
            RuntimeEnvironmentManager().change_dir_permission(
                dirs=JOB_OUTPUT_DIRS, new_permission="777"
            )

        if args.pipeline_execution_id:
            _bootstrap_runtime_env_for_pipeline_step(
                py_version, args.func_step_s3_dir, conda_env, dependency_settings
            )
        else:
            _bootstrap_runtime_env_for_remote_function(
                py_version, conda_env, dependency_settings
            )

        RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version(
            args.client_sagemaker_pysdk_version
        )

        if os.path.exists(RESOURCE_CONFIG):
            try:
                logger.info("Found %s", RESOURCE_CONFIG)
                with open(RESOURCE_CONFIG, "r") as config_file:
                    resource_config = json.load(config_file)
                set_env(
                    resource_config=resource_config,
                    distribution=args.distribution,
                    user_nproc_per_node=args.user_nproc_per_node,
                )
            except (json.JSONDecodeError, FileNotFoundError) as config_error:
                # an unreadable resource config is logged but does not fail the bootstrap
                logger.info(
                    "ERROR: Error processing %s: %s", RESOURCE_CONFIG, str(config_error)
                )

        exit_code = SUCCESS_EXIT_CODE
    except Exception as e:  # pylint: disable=broad-except
        logger.exception("Error encountered while bootstrapping runtime environment: %s", e)

        _write_failure_reason_file(str(e))
    finally:
        sys.exit(exit_code)
|
|
602
|
+
|
|
603
|
+
|
|
604
|
+
# Script entry point: forward the command-line arguments (without the program name) to
# main(), which ends the process through sys.exit().
if __name__ == "__main__":
    main(sys.argv[1:])
|